diff --git a/pyproject.toml b/pyproject.toml
index 25dd12d..27846ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ readme = {file = "README.md", content-type = "text/markdown"}
dependencies = [
"aind_behavior_services>=0.13.5",
"pydantic-settings",
+ "scikit-learn>=1.8.0",
]
[tool.uv.workspace]
diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json
index 681846e..38cf055 100644
--- a/schema/aind_behavior_dynamic_foraging.json
+++ b/schema/aind_behavior_dynamic_foraging.json
@@ -313,6 +313,56 @@
"type": "object",
"x-sgen-typename": "AllenNeuralDynamics.AindManipulator.AindManipulatorCalibration"
},
+ "AntiBiasParameters": {
+ "properties": {
+ "threshold": {
+ "$ref": "#/$defs/BiasThreshold",
+ "default": {
+ "upper": 0.7,
+ "lower": 0.3
+ },
+ "description": "Thresholds for bias correction intervention."
+ },
+ "intervention_interval": {
+ "default": 10,
+ "description": "Trials between bias intervention.",
+ "minimum": 0,
+ "title": "Intervention Interval",
+ "type": "integer"
+ },
+ "maximum_water_corrections": {
+ "default": 5,
+ "description": "Number of water correction to attempt.",
+ "minimum": 0,
+ "title": "Maximum Water Corrections",
+ "type": "integer"
+ },
+ "reward_fraction": {
+ "default": 0.8,
+ "description": "Fraction of full reward volume delivered during auto water (0=none, 1=full).",
+ "maximum": 1,
+ "minimum": 0,
+ "title": "Reward Fraction",
+ "type": "number"
+ },
+ "bias_window_length": {
+ "default": 200,
+ "description": "Trials to calculate bias over.",
+ "minimum": 0,
+ "title": "Bias Window Length",
+ "type": "integer"
+ },
+ "lickspout_offset_delta": {
+ "default": 0.05,
+ "description": "Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.",
+ "minimum": 0,
+ "title": "Lickspout Offset Delta",
+ "type": "number"
+ }
+ },
+ "title": "AntiBiasParameters",
+ "type": "object"
+ },
"AuditorySecondaryReinforcer": {
"description": "Represents an auditory secondary reinforcer.",
"properties": {
@@ -511,7 +561,7 @@
"min_unrewarded_trials": 3,
"reward_fraction": 0.8
},
- "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
"oneOf": [
{
"$ref": "#/$defs/AutoWaterParameters"
@@ -521,6 +571,28 @@
}
]
},
+ "antibias_parameters": {
+ "default": {
+ "threshold": {
+ "lower": 0.3,
+ "upper": 0.7
+ },
+ "intervention_interval": 10,
+ "maximum_water_corrections": 5,
+ "reward_fraction": 0.8,
+ "bias_window_length": 200,
+ "lickspout_offset_delta": 0.05
+ },
+ "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.",
+ "oneOf": [
+ {
+ "$ref": "#/$defs/AntiBiasParameters"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
"is_baiting": {
"default": false,
"description": "Whether uncollected rewards carry over to the next trial.",
@@ -656,6 +728,28 @@
"type": "object",
"x-sgen-typename": "AllenNeuralDynamics.AindBehaviorServices.Distributions.BetaDistributionParameters"
},
+ "BiasThreshold": {
+ "properties": {
+ "upper": {
+ "default": 0.7,
+ "description": "Absolute value of the upper bias threshold.",
+ "maximum": 1,
+ "minimum": 0,
+ "title": "Upper",
+ "type": "number"
+ },
+ "lower": {
+ "default": 0.3,
+ "description": "Absolute value of the lower bias threshold.",
+ "maximum": 1,
+ "minimum": 0,
+ "title": "Lower",
+ "type": "number"
+ }
+ },
+ "title": "BiasThreshold",
+ "type": "object"
+ },
"BinomialDistribution": {
"description": "A binomial probability distribution.\n\nModels the number of successes in a fixed number of independent\nBernoulli trials with constant success probability.",
"properties": {
@@ -811,7 +905,7 @@
"min_unrewarded_trials": 3,
"reward_fraction": 0.8
},
- "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
"oneOf": [
{
"$ref": "#/$defs/AutoWaterParameters"
@@ -821,6 +915,28 @@
}
]
},
+ "antibias_parameters": {
+ "default": {
+ "threshold": {
+ "lower": 0.3,
+ "upper": 0.7
+ },
+ "intervention_interval": 10,
+ "maximum_water_corrections": 5,
+ "reward_fraction": 0.8,
+ "bias_window_length": 200,
+ "lickspout_offset_delta": 0.05
+ },
+ "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.",
+ "oneOf": [
+ {
+ "$ref": "#/$defs/AntiBiasParameters"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
"is_baiting": {
"default": false,
"description": "Whether uncollected rewards carry over to the next trial.",
@@ -1077,7 +1193,7 @@
"min_unrewarded_trials": 3,
"reward_fraction": 0.8
},
- "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
"oneOf": [
{
"$ref": "#/$defs/AutoWaterParameters"
@@ -1087,6 +1203,28 @@
}
]
},
+ "antibias_parameters": {
+ "default": {
+ "threshold": {
+ "lower": 0.3,
+ "upper": 0.7
+ },
+ "intervention_interval": 10,
+ "maximum_water_corrections": 5,
+ "reward_fraction": 0.8,
+ "bias_window_length": 200,
+ "lickspout_offset_delta": 0.05
+ },
+ "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.",
+ "oneOf": [
+ {
+ "$ref": "#/$defs/AntiBiasParameters"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
"is_baiting": {
"default": false,
"description": "Whether uncollected rewards carry over to the next trial.",
@@ -1266,7 +1404,7 @@
"min_unrewarded_trials": 3,
"reward_fraction": 0.8
},
- "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
"oneOf": [
{
"$ref": "#/$defs/AutoWaterParameters"
@@ -1276,6 +1414,28 @@
}
]
},
+ "antibias_parameters": {
+ "default": {
+ "threshold": {
+ "lower": 0.3,
+ "upper": 0.7
+ },
+ "intervention_interval": 10,
+ "maximum_water_corrections": 5,
+ "reward_fraction": 0.8,
+ "bias_window_length": 200,
+ "lickspout_offset_delta": 0.05
+ },
+ "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.",
+ "oneOf": [
+ {
+ "$ref": "#/$defs/AntiBiasParameters"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
"is_baiting": {
"const": true,
"default": true,
@@ -3403,6 +3563,26 @@
}
]
},
+ "TrialMetrics": {
+ "description": "Represents metrics of session",
+ "properties": {
+ "bias": {
+ "default": null,
+ "description": "Bias of session. Negative values correspond to left bias, positive right.",
+ "oneOf": [
+ {
+ "type": "number"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "title": "Bias"
+ }
+ },
+ "title": "TrialMetrics",
+ "type": "object"
+ },
"TrialOutcome": {
"description": "Represents the outcome of a single trial.",
"properties": {
@@ -3586,7 +3766,7 @@
"min_unrewarded_trials": 3,
"reward_fraction": 0.8
},
- "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
"oneOf": [
{
"$ref": "#/$defs/AutoWaterParameters"
@@ -3596,6 +3776,28 @@
}
]
},
+ "antibias_parameters": {
+ "default": {
+ "threshold": {
+ "lower": 0.3,
+ "upper": 0.7
+ },
+ "intervention_interval": 10,
+ "maximum_water_corrections": 5,
+ "reward_fraction": 0.8,
+ "bias_window_length": 200,
+ "lickspout_offset_delta": 0.05
+ },
+ "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.",
+ "oneOf": [
+ {
+ "$ref": "#/$defs/AntiBiasParameters"
+ },
+ {
+ "type": "null"
+ }
+ ]
+ },
"is_baiting": {
"default": false,
"description": "Whether uncollected rewards carry over to the next trial.",
diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs
index fd4cb0f..021a37d 100644
--- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs
+++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs
@@ -703,6 +703,184 @@ public override string ToString()
}
+ [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")]
+ [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)]
+ [Bonsai.CombinatorAttribute(MethodName="Generate")]
+ public partial class AntiBiasParameters
+ {
+
+ private BiasThreshold _threshold;
+
+ private int _interventionInterval;
+
+ private int _maximumWaterCorrections;
+
+ private double _rewardFraction;
+
+ private int _biasWindowLength;
+
+ private double _lickspoutOffsetDelta;
+
+ public AntiBiasParameters()
+ {
+ _threshold = new BiasThreshold();
+ _interventionInterval = 10;
+ _maximumWaterCorrections = 5;
+ _rewardFraction = 0.8D;
+ _biasWindowLength = 200;
+ _lickspoutOffsetDelta = 0.05D;
+ }
+
+ protected AntiBiasParameters(AntiBiasParameters other)
+ {
+ _threshold = other._threshold;
+ _interventionInterval = other._interventionInterval;
+ _maximumWaterCorrections = other._maximumWaterCorrections;
+ _rewardFraction = other._rewardFraction;
+ _biasWindowLength = other._biasWindowLength;
+ _lickspoutOffsetDelta = other._lickspoutOffsetDelta;
+ }
+
+ ///
+ /// Thresholds for bias correction intervention.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("threshold")]
+ [System.ComponentModel.DescriptionAttribute("Thresholds for bias correction intervention.")]
+ public BiasThreshold Threshold
+ {
+ get
+ {
+ return _threshold;
+ }
+ set
+ {
+ _threshold = value;
+ }
+ }
+
+ ///
+ /// Trials between bias intervention.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("intervention_interval")]
+ [System.ComponentModel.DescriptionAttribute("Trials between bias intervention.")]
+ public int InterventionInterval
+ {
+ get
+ {
+ return _interventionInterval;
+ }
+ set
+ {
+ _interventionInterval = value;
+ }
+ }
+
+ ///
+ /// Number of water correction to attempt.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("maximum_water_corrections")]
+ [System.ComponentModel.DescriptionAttribute("Number of water correction to attempt.")]
+ public int MaximumWaterCorrections
+ {
+ get
+ {
+ return _maximumWaterCorrections;
+ }
+ set
+ {
+ _maximumWaterCorrections = value;
+ }
+ }
+
+ ///
+ /// Fraction of full reward volume delivered during auto water (0=none, 1=full).
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("reward_fraction")]
+ [System.ComponentModel.DescriptionAttribute("Fraction of full reward volume delivered during auto water (0=none, 1=full).")]
+ public double RewardFraction
+ {
+ get
+ {
+ return _rewardFraction;
+ }
+ set
+ {
+ _rewardFraction = value;
+ }
+ }
+
+ ///
+ /// Trials to calculate bias over.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("bias_window_length")]
+ [System.ComponentModel.DescriptionAttribute("Trials to calculate bias over.")]
+ public int BiasWindowLength
+ {
+ get
+ {
+ return _biasWindowLength;
+ }
+ set
+ {
+ _biasWindowLength = value;
+ }
+ }
+
+ ///
+ /// Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("lickspout_offset_delta")]
+ [System.ComponentModel.DescriptionAttribute("Distance (mm) to move the stage spouts by. This is a relative distance to the cur" +
+ "rent value, not absolute.")]
+ public double LickspoutOffsetDelta
+ {
+ get
+ {
+ return _lickspoutOffsetDelta;
+ }
+ set
+ {
+ _lickspoutOffsetDelta = value;
+ }
+ }
+
+ public System.IObservable Generate()
+ {
+ return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new AntiBiasParameters(this)));
+ }
+
+ public System.IObservable Generate(System.IObservable source)
+ {
+ return System.Reactive.Linq.Observable.Select(source, _ => new AntiBiasParameters(this));
+ }
+
+ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder)
+ {
+ stringBuilder.Append("Threshold = " + _threshold + ", ");
+ stringBuilder.Append("InterventionInterval = " + _interventionInterval + ", ");
+ stringBuilder.Append("MaximumWaterCorrections = " + _maximumWaterCorrections + ", ");
+ stringBuilder.Append("RewardFraction = " + _rewardFraction + ", ");
+ stringBuilder.Append("BiasWindowLength = " + _biasWindowLength + ", ");
+ stringBuilder.Append("LickspoutOffsetDelta = " + _lickspoutOffsetDelta);
+ return true;
+ }
+
+ public override string ToString()
+ {
+ System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder();
+ stringBuilder.Append(GetType().Name);
+ stringBuilder.Append(" { ");
+ if (PrintMembers(stringBuilder))
+ {
+ stringBuilder.Append(" ");
+ }
+ stringBuilder.Append("}");
+ return stringBuilder.ToString();
+ }
+ }
+
+
///
/// Represents an auditory secondary reinforcer.
///
@@ -897,6 +1075,8 @@ public partial class BaseCoupledTrialGeneratorSpec : TrialGeneratorSpec
private AutoWaterParameters _autowaterParameters;
+ private AntiBiasParameters _antibiasParameters;
+
private bool _isBaiting;
private RewardProbabilityParameters _rewardProbabilityParameters;
@@ -909,6 +1089,7 @@ public BaseCoupledTrialGeneratorSpec()
_interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_autowaterParameters = new AutoWaterParameters();
+ _antibiasParameters = new AntiBiasParameters();
_isBaiting = false;
_rewardProbabilityParameters = new RewardProbabilityParameters();
}
@@ -922,6 +1103,7 @@ protected BaseCoupledTrialGeneratorSpec(BaseCoupledTrialGeneratorSpec other) :
_interTrialIntervalDuration = other._interTrialIntervalDuration;
_blockLength = other._blockLength;
_autowaterParameters = other._autowaterParameters;
+ _antibiasParameters = other._antibiasParameters;
_isBaiting = other._isBaiting;
_rewardProbabilityParameters = other._rewardProbabilityParameters;
}
@@ -1016,12 +1198,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block
}
///
- /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
+ /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
///
[System.Xml.Serialization.XmlIgnoreAttribute()]
[Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")]
- [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " +
- "ignored or unrewarded trial thresholds.")]
+ [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" +
+ "gnored or unrewarded trial thresholds.")]
public AutoWaterParameters AutowaterParameters
{
get
@@ -1034,6 +1216,25 @@ public AutoWaterParameters AutowaterParameters
}
}
+ ///
+ /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")]
+ [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" +
+ " combat bias.")]
+ public AntiBiasParameters AntibiasParameters
+ {
+ get
+ {
+ return _antibiasParameters;
+ }
+ set
+ {
+ _antibiasParameters = value;
+ }
+ }
+
///
/// Whether uncollected rewards carry over to the next trial.
///
@@ -1091,6 +1292,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", ");
stringBuilder.Append("BlockLength = " + _blockLength + ", ");
stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", ");
+ stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", ");
stringBuilder.Append("IsBaiting = " + _isBaiting + ", ");
stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters);
return true;
@@ -1259,6 +1461,94 @@ public override string ToString()
}
+ [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")]
+ [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)]
+ [Bonsai.CombinatorAttribute(MethodName="Generate")]
+ public partial class BiasThreshold
+ {
+
+ private double _upper;
+
+ private double _lower;
+
+ public BiasThreshold()
+ {
+ _upper = 0.7D;
+ _lower = 0.3D;
+ }
+
+ protected BiasThreshold(BiasThreshold other)
+ {
+ _upper = other._upper;
+ _lower = other._lower;
+ }
+
+ ///
+ /// Absolute value of the upper bias threshold.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("upper")]
+ [System.ComponentModel.DescriptionAttribute("Absolute value of the upper bias threshold.")]
+ public double Upper
+ {
+ get
+ {
+ return _upper;
+ }
+ set
+ {
+ _upper = value;
+ }
+ }
+
+ ///
+ /// Absolute value of the lower bias threshold.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("lower")]
+ [System.ComponentModel.DescriptionAttribute("Absolute value of the lower bias threshold.")]
+ public double Lower
+ {
+ get
+ {
+ return _lower;
+ }
+ set
+ {
+ _lower = value;
+ }
+ }
+
+ public System.IObservable Generate()
+ {
+ return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new BiasThreshold(this)));
+ }
+
+ public System.IObservable Generate(System.IObservable source)
+ {
+ return System.Reactive.Linq.Observable.Select(source, _ => new BiasThreshold(this));
+ }
+
+ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder)
+ {
+ stringBuilder.Append("Upper = " + _upper + ", ");
+ stringBuilder.Append("Lower = " + _lower);
+ return true;
+ }
+
+ public override string ToString()
+ {
+ System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder();
+ stringBuilder.Append(GetType().Name);
+ stringBuilder.Append(" { ");
+ if (PrintMembers(stringBuilder))
+ {
+ stringBuilder.Append(" ");
+ }
+ stringBuilder.Append("}");
+ return stringBuilder.ToString();
+ }
+ }
+
+
[System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")]
[Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)]
[Bonsai.CombinatorAttribute(MethodName="Generate")]
@@ -1277,6 +1567,8 @@ public partial class BlockBasedTrialGeneratorSpec : TrialGeneratorSpec
private AutoWaterParameters _autowaterParameters;
+ private AntiBiasParameters _antibiasParameters;
+
private bool _isBaiting;
public BlockBasedTrialGeneratorSpec()
@@ -1287,6 +1579,7 @@ public BlockBasedTrialGeneratorSpec()
_interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_autowaterParameters = new AutoWaterParameters();
+ _antibiasParameters = new AntiBiasParameters();
_isBaiting = false;
}
@@ -1299,6 +1592,7 @@ protected BlockBasedTrialGeneratorSpec(BlockBasedTrialGeneratorSpec other) :
_interTrialIntervalDuration = other._interTrialIntervalDuration;
_blockLength = other._blockLength;
_autowaterParameters = other._autowaterParameters;
+ _antibiasParameters = other._antibiasParameters;
_isBaiting = other._isBaiting;
}
@@ -1392,12 +1686,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block
}
///
- /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
+ /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
///
[System.Xml.Serialization.XmlIgnoreAttribute()]
[Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")]
- [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " +
- "ignored or unrewarded trial thresholds.")]
+ [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" +
+ "gnored or unrewarded trial thresholds.")]
public AutoWaterParameters AutowaterParameters
{
get
@@ -1410,6 +1704,25 @@ public AutoWaterParameters AutowaterParameters
}
}
+ ///
+ /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")]
+ [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" +
+ " combat bias.")]
+ public AntiBiasParameters AntibiasParameters
+ {
+ get
+ {
+ return _antibiasParameters;
+ }
+ set
+ {
+ _antibiasParameters = value;
+ }
+ }
+
///
/// Whether uncollected rewards carry over to the next trial.
///
@@ -1449,6 +1762,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", ");
stringBuilder.Append("BlockLength = " + _blockLength + ", ");
stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", ");
+ stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", ");
stringBuilder.Append("IsBaiting = " + _isBaiting);
return true;
}
@@ -1976,6 +2290,8 @@ public partial class CoupledTrialGeneratorSpec : TrialGeneratorSpec
private AutoWaterParameters _autowaterParameters;
+ private AntiBiasParameters _antibiasParameters;
+
private bool _isBaiting;
private RewardProbabilityParameters _rewardProbabilityParameters;
@@ -1998,6 +2314,7 @@ public CoupledTrialGeneratorSpec()
_interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_autowaterParameters = new AutoWaterParameters();
+ _antibiasParameters = new AntiBiasParameters();
_isBaiting = false;
_rewardProbabilityParameters = new RewardProbabilityParameters();
_trialGenerationEndParameters = new CoupledTrialGenerationEndConditions();
@@ -2016,6 +2333,7 @@ protected CoupledTrialGeneratorSpec(CoupledTrialGeneratorSpec other) :
_interTrialIntervalDuration = other._interTrialIntervalDuration;
_blockLength = other._blockLength;
_autowaterParameters = other._autowaterParameters;
+ _antibiasParameters = other._antibiasParameters;
_isBaiting = other._isBaiting;
_rewardProbabilityParameters = other._rewardProbabilityParameters;
_trialGenerationEndParameters = other._trialGenerationEndParameters;
@@ -2115,12 +2433,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block
}
///
- /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
+ /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
///
[System.Xml.Serialization.XmlIgnoreAttribute()]
[Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")]
- [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " +
- "ignored or unrewarded trial thresholds.")]
+ [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" +
+ "gnored or unrewarded trial thresholds.")]
public AutoWaterParameters AutowaterParameters
{
get
@@ -2133,6 +2451,25 @@ public AutoWaterParameters AutowaterParameters
}
}
+ ///
+ /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")]
+ [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" +
+ " combat bias.")]
+ public AntiBiasParameters AntibiasParameters
+ {
+ get
+ {
+ return _antibiasParameters;
+ }
+ set
+ {
+ _antibiasParameters = value;
+ }
+ }
+
///
/// Whether uncollected rewards carry over to the next trial.
///
@@ -2275,6 +2612,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", ");
stringBuilder.Append("BlockLength = " + _blockLength + ", ");
stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", ");
+ stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", ");
stringBuilder.Append("IsBaiting = " + _isBaiting + ", ");
stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", ");
stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", ");
@@ -2437,6 +2775,8 @@ public partial class CoupledWarmupTrialGeneratorSpec : TrialGeneratorSpec
private AutoWaterParameters _autowaterParameters;
+ private AntiBiasParameters _antibiasParameters;
+
private bool _isBaiting;
private RewardProbabilityParameters _rewardProbabilityParameters;
@@ -2451,6 +2791,7 @@ public CoupledWarmupTrialGeneratorSpec()
_interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_autowaterParameters = new AutoWaterParameters();
+ _antibiasParameters = new AntiBiasParameters();
_isBaiting = true;
_rewardProbabilityParameters = new RewardProbabilityParameters();
_trialGenerationEndParameters = new CoupledWarmupTrialGenerationEndConditions();
@@ -2465,6 +2806,7 @@ protected CoupledWarmupTrialGeneratorSpec(CoupledWarmupTrialGeneratorSpec other)
_interTrialIntervalDuration = other._interTrialIntervalDuration;
_blockLength = other._blockLength;
_autowaterParameters = other._autowaterParameters;
+ _antibiasParameters = other._antibiasParameters;
_isBaiting = other._isBaiting;
_rewardProbabilityParameters = other._rewardProbabilityParameters;
_trialGenerationEndParameters = other._trialGenerationEndParameters;
@@ -2560,12 +2902,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block
}
///
- /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
+ /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
///
[System.Xml.Serialization.XmlIgnoreAttribute()]
[Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")]
- [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " +
- "ignored or unrewarded trial thresholds.")]
+ [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" +
+ "gnored or unrewarded trial thresholds.")]
public AutoWaterParameters AutowaterParameters
{
get
@@ -2578,6 +2920,25 @@ public AutoWaterParameters AutowaterParameters
}
}
+ ///
+ /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")]
+ [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" +
+ " combat bias.")]
+ public AntiBiasParameters AntibiasParameters
+ {
+ get
+ {
+ return _antibiasParameters;
+ }
+ set
+ {
+ _antibiasParameters = value;
+ }
+ }
+
///
/// Whether uncollected rewards carry over to the next trial.
///
@@ -2653,6 +3014,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", ");
stringBuilder.Append("BlockLength = " + _blockLength + ", ");
stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", ");
+ stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", ");
stringBuilder.Append("IsBaiting = " + _isBaiting + ", ");
stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", ");
stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters);
@@ -6147,6 +6509,75 @@ public override string ToString()
}
+ ///
+ /// Represents metrics of session
+ ///
+ [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")]
+ [System.ComponentModel.DescriptionAttribute("Represents metrics of session")]
+ [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)]
+ [Bonsai.CombinatorAttribute(MethodName="Generate")]
+ public partial class TrialMetrics
+ {
+
+ private double? _bias;
+
+ public TrialMetrics()
+ {
+ }
+
+ protected TrialMetrics(TrialMetrics other)
+ {
+ _bias = other._bias;
+ }
+
+ ///
+ /// Bias of session. Negative values correspond to left bias, positive right.
+ ///
+ [Newtonsoft.Json.JsonPropertyAttribute("bias")]
+ [System.ComponentModel.DescriptionAttribute("Bias of session. Negative values correspond to left bias, positive right.")]
+ public double? Bias
+ {
+ get
+ {
+ return _bias;
+ }
+ set
+ {
+ _bias = value;
+ }
+ }
+
+ public System.IObservable Generate()
+ {
+ return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new TrialMetrics(this)));
+ }
+
+ public System.IObservable Generate(System.IObservable source)
+ {
+ return System.Reactive.Linq.Observable.Select(source, _ => new TrialMetrics(this));
+ }
+
+ protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder)
+ {
+ stringBuilder.Append("Bias = " + _bias);
+ return true;
+ }
+
+ public override string ToString()
+ {
+ System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder();
+ stringBuilder.Append(GetType().Name);
+ stringBuilder.Append(" { ");
+ if (PrintMembers(stringBuilder))
+ {
+ stringBuilder.Append(" ");
+ }
+ stringBuilder.Append("}");
+ return stringBuilder.ToString();
+ }
+ }
+
+
///
/// Represents the outcome of a single trial.
///
@@ -6478,6 +6909,8 @@ public partial class UncoupledTrialGeneratorSpec : TrialGeneratorSpec
private AutoWaterParameters _autowaterParameters;
+ private AntiBiasParameters _antibiasParameters;
+
private bool _isBaiting;
private UncoupledTrialGenerationEndConditions _trialGenerationEndParameters;
@@ -6494,6 +6927,7 @@ public UncoupledTrialGeneratorSpec()
_interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution();
_blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistribution();
_autowaterParameters = new AutoWaterParameters();
+ _antibiasParameters = new AntiBiasParameters();
_isBaiting = false;
_trialGenerationEndParameters = new UncoupledTrialGenerationEndConditions();
_rewardProbabilities = new System.Collections.Generic.List();
@@ -6509,6 +6943,7 @@ protected UncoupledTrialGeneratorSpec(UncoupledTrialGeneratorSpec other) :
_interTrialIntervalDuration = other._interTrialIntervalDuration;
_blockLength = other._blockLength;
_autowaterParameters = other._autowaterParameters;
+ _antibiasParameters = other._antibiasParameters;
_isBaiting = other._isBaiting;
_trialGenerationEndParameters = other._trialGenerationEndParameters;
_rewardProbabilities = other._rewardProbabilities;
@@ -6605,12 +7040,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistributio
}
///
- /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
+ /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.
///
[System.Xml.Serialization.XmlIgnoreAttribute()]
[Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")]
- [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " +
- "ignored or unrewarded trial thresholds.")]
+ [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" +
+ "gnored or unrewarded trial thresholds.")]
public AutoWaterParameters AutowaterParameters
{
get
@@ -6623,6 +7058,25 @@ public AutoWaterParameters AutowaterParameters
}
}
+ ///
+ /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.
+ ///
+ [System.Xml.Serialization.XmlIgnoreAttribute()]
+ [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")]
+ [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" +
+ " combat bias.")]
+ public AntiBiasParameters AntibiasParameters
+ {
+ get
+ {
+ return _antibiasParameters;
+ }
+ set
+ {
+ _antibiasParameters = value;
+ }
+ }
+
///
/// Whether uncollected rewards carry over to the next trial.
///
@@ -6715,6 +7169,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder)
stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", ");
stringBuilder.Append("BlockLength = " + _blockLength + ", ");
stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", ");
+ stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", ");
stringBuilder.Append("IsBaiting = " + _isBaiting + ", ");
stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", ");
stringBuilder.Append("RewardProbabilities = " + _rewardProbabilities + ", ");
@@ -7855,6 +8310,11 @@ public System.IObservable Process(System.IObservable(source);
}
+ public System.IObservable Process(System.IObservable source)
+ {
+ return Process(source);
+ }
+
public System.IObservable Process(System.IObservable source)
{
return Process(source);
@@ -7880,6 +8340,11 @@ public System.IObservable Process(System.IObservable(source);
}
+ public System.IObservable Process(System.IObservable source)
+ {
+ return Process(source);
+ }
+
public System.IObservable Process(System.IObservable source)
{
return Process(source);
@@ -8025,6 +8490,11 @@ public System.IObservable Process(System.IObservable
return Process(source);
}
+ public System.IObservable Process(System.IObservable source)
+ {
+ return Process(source);
+ }
+
public System.IObservable Process(System.IObservable source)
{
return Process(source);
@@ -8087,11 +8557,13 @@ public System.IObservable Process(System.IObservable source)
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
+ [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
+ [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
@@ -8121,6 +8593,7 @@ public System.IObservable Process(System.IObservable source)
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
+ [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))]
diff --git a/src/Extensions/DeserializeFromPyObject.cs b/src/Extensions/DeserializeFromPyObject.cs
index 6b19749..558e0a0 100644
--- a/src/Extensions/DeserializeFromPyObject.cs
+++ b/src/Extensions/DeserializeFromPyObject.cs
@@ -15,6 +15,8 @@
[System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))]
[System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))]
+[System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))]
+
public class DeserializeFromPyObject : SingleArgumentExpressionBuilder
{
public DeserializeFromPyObject()
diff --git a/src/Extensions/TaskEngine.bonsai b/src/Extensions/TaskEngine.bonsai
index 46fea6c..4367cfd 100644
--- a/src/Extensions/TaskEngine.bonsai
+++ b/src/Extensions/TaskEngine.bonsai
@@ -1217,6 +1217,47 @@
true
+
+ Metrics
+
+
+
+
+
+
+
+ trial_generator
+
+
+
+
+ metrics
+ true
+
+
+
+
+
+
+
+
+ GlobalTrialMetrics
+
+
+ TrialMetrics
+
+
+
+
+
+
+
+
+
+
+
+
+
SampleNextTrial
@@ -1262,6 +1303,7 @@
+
diff --git a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py
index 3f2e238..36b5435 100644
--- a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py
+++ b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py
@@ -194,6 +194,13 @@ def make_dataset(
name="SoftwareEvents",
description="Software events generated by the workflow. The timestamps of these events are low precision and should not be used to align to physiology data.",
data_streams=[
+ SoftwareEvents(
+ name="TrialMetrics",
+ description="An event emitted with the metrics of a trial.",
+ reader_params=SoftwareEvents.make_params(
+ root_path / "behavior/SoftwareEvents/TrialMetrics.json"
+ ),
+ ),
SoftwareEvents(
name="TrialGeneratorSpec",
description="An event emitted with the specification for the trial generator.",
diff --git a/src/aind_behavior_dynamic_foraging/regenerate.py b/src/aind_behavior_dynamic_foraging/regenerate.py
index 4dd1717..53a8389 100644
--- a/src/aind_behavior_dynamic_foraging/regenerate.py
+++ b/src/aind_behavior_dynamic_foraging/regenerate.py
@@ -20,6 +20,7 @@ def main():
Session,
aind_behavior_dynamic_foraging.task_logic.trial_models.Trial,
aind_behavior_dynamic_foraging.task_logic.trial_models.TrialOutcome,
+ aind_behavior_dynamic_foraging.task_logic.trial_models.TrialMetrics,
]
model = pydantic.RootModel[Union[tuple(models)]]
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py
index 2405daa..285adaf 100644
--- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py
+++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel
-from ..trial_models import Trial, TrialOutcome
+from ..trial_models import Trial, TrialMetrics, TrialOutcome
class BaseTrialGeneratorSpecModel(BaseModel, abc.ABC):
@@ -24,3 +24,6 @@ def next(self) -> Trial | None:
def update(self, outcome: TrialOutcome | str) -> None:
"""Update the trial generator with the outcome of the previous trial."""
+
+ def metrics(self) -> TrialMetrics:
+ """Return metrics of session at current state of the trial generator."""
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py
index 39ddc92..84b2b72 100644
--- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py
+++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py
@@ -12,7 +12,9 @@
from aind_behavior_services.task.distributions_utils import draw_sample
from pydantic import BaseModel, Field
-from ..trial_models import Metadata, Trial
+from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias
+
+from ..trial_models import Metadata, Trial, TrialMetrics
from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome
logger = logging.getLogger(__name__)
@@ -36,6 +38,31 @@ class AutoWaterParameters(BaseModel):
ge=0,
le=1,
description="Fraction of full reward volume delivered during auto water (0=none, 1=full).",
+ ) # TODO: Not implemented yet
+
+
+class BiasThreshold(BaseModel):
+ upper: float = Field(default=0.7, ge=0, le=1, description="Absolute value of the upper bias threshold.")
+ lower: float = Field(default=0.3, ge=0, le=1, description="Absolute value of the lower bias threshold.")
+
+
+class AntiBiasParameters(BaseModel):
+ threshold: BiasThreshold = Field(
+ default=BiasThreshold(), validate_default=True, description="Thresholds for bias correction intervention."
+ )
+ intervention_interval: int = Field(default=10, ge=0, description="Trials between bias intervention.")
+ maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.")
+ reward_fraction: float = Field(
+ default=0.8,
+ ge=0,
+ le=1,
+ description="Fraction of full reward volume delivered during auto water (0=none, 1=full).",
+ ) # TODO: Not implemented yet
+ bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.")
+ lickspout_offset_delta: float = Field(
+ default=0.05,
+ ge=0,
+ description="Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.",
)
@@ -84,7 +111,13 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel):
autowater_parameters: Optional[AutoWaterParameters] = Field(
default=AutoWaterParameters(),
validate_default=True,
- description="Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ description="Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.",
+ )
+
+ antibias_parameters: Optional[AntiBiasParameters] = Field(
+ default=AntiBiasParameters(),
+ validate_default=True,
+ description="Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.",
)
is_baiting: bool = Field(default=False, description="Whether uncollected rewards carry over to the next trial.")
@@ -106,7 +139,9 @@ class BlockBasedTrialGenerator(ITrialGenerator, ABC):
reward_history: Record of whether each trial resulted in a reward.
is_left_baited: Whether the left port currently has a baited reward.
is_right_baited: Whether the right port currently has a baited reward.
-
+ trials_in_bias_intervention: trials elapsed since last bias intervention
+ water_corrections: number of water corrections applied to combat bias
+ bias: bias of session. Negative values correspond to left bias, positive right.
"""
def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None:
@@ -117,12 +152,19 @@ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None:
"""
self.spec = spec
+ self.outcome_history: list[TrialOutcome] = []
self.is_right_choice_history: list[bool | None] = []
self.reward_history: list[bool] = []
self.is_left_baited: bool = False
self.is_right_baited: bool = False
self.block: Block
+ # antibias parameters
+ self.trials_in_bias_intervention = 0
+ self.water_corrections = 0
+ self.bias: Optional[float] = None
+ self.total_lickspout_offset = 0
+
def update(self, outcome: TrialOutcome | str):
"""Updates generator state from the previous trial outcome. Records choice and reward history and manages baiting state.
Args:
@@ -132,6 +174,7 @@ def update(self, outcome: TrialOutcome | str):
if isinstance(outcome, str):
outcome = TrialOutcome.model_validate_json(outcome)
+ self.outcome_history.append(outcome)
self.is_right_choice_history.append(outcome.is_right_choice)
self.reward_history.append(outcome.is_rewarded)
@@ -146,6 +189,8 @@ def update(self, outcome: TrialOutcome | str):
# trial ignored so current baiting state retained
pass
+ self.bias = calculate_bias(outcomes=self.outcome_history)
+
def next(self) -> Trial | None:
"""Generates the next trial in the session.
@@ -176,10 +221,21 @@ def next(self) -> Trial | None:
self.is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited
logger.debug("Right baited: %s" % self.is_right_baited)
- # determine autowater
is_auto_response_right = None
+
+ # determine autowater
if is_autowater := self._are_autowater_conditions_met():
is_auto_response_right = True if self.block.p_right_reward > self.block.p_left_reward else False
+ logger.debug("Delivering autowater: is_auto_response_right = %s" % is_auto_response_right)
+
+ # determine bias correction. Overrides autowater
+ lickspout_offset_delta = 0
+ if self._are_antibias_conditions_met():
+ is_auto_response_right, lickspout_offset_delta = self._determine_antibias_intervention()
+ logger.debug(
+ "Performing bias intervention: is_auto_response_right = %s, lickspout_offset_delta = %s."
+ % (is_auto_response_right, lickspout_offset_delta)
+ )
return Trial(
p_reward_left=1 if (self.is_left_baited or is_auto_response_right is False) else self.block.p_left_reward,
@@ -188,6 +244,7 @@ def next(self) -> Trial | None:
response_deadline_duration=self.spec.response_duration,
quiescence_period_duration=quiescent,
inter_trial_interval_duration=iti,
+ lickspout_offset_delta=lickspout_offset_delta,
is_auto_response_right=is_auto_response_right,
metadata=Metadata(
p_reward_left=self.block.p_left_reward,
@@ -196,6 +253,11 @@ def next(self) -> Trial | None:
),
)
+ def metrics(self) -> TrialMetrics:
+ """Return metrics of session at current state of the trial generator."""
+
+ return TrialMetrics(bias=self.bias)
+
def _are_autowater_conditions_met(self) -> bool:
"""Checks whether autowater should be given.
@@ -203,7 +265,8 @@ def _are_autowater_conditions_met(self) -> bool:
True if autowater conditions are met, False otherwise.
"""
- if self.spec.autowater_parameters is None: # autowater disabled
+ if self.spec.autowater_parameters is None:
+ logger.debug("Autowater not configured.")
return False
min_ignore = self.spec.autowater_parameters.min_ignored_trials
@@ -211,14 +274,71 @@ def _are_autowater_conditions_met(self) -> bool:
is_ignored = [choice is None for choice in self.is_right_choice_history]
if len(is_ignored) > min_ignore and all(is_ignored[-min_ignore:]):
+ logger.debug("Past %s trials ignored." % min_ignore)
return True
is_unrewarded = [not reward for reward in self.reward_history]
if len(is_unrewarded) > min_unreward and all(is_unrewarded[-min_unreward:]):
+ logger.debug("Past %s trials unrewarded." % min_unreward)
return True
return False
+ def _are_antibias_conditions_met(self) -> bool:
+ """Checks whether antibias conditions are met.
+
+ Returns:
+ True if antibias conditions are met, False otherwise.
+ """
+
+ if self.spec.antibias_parameters is None:
+ logger.debug("Anitbias not configured.")
+ return False
+
+ if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval:
+ if self.bias <= self.spec.antibias_parameters.threshold.lower:
+ logger.debug("Bias calculated below threshold: %s." % self.bias)
+ return True
+
+ if self.bias >= self.spec.antibias_parameters.threshold.upper:
+ logger.debug("Bias calculated above threshold: %s." % self.bias)
+ return True
+
+ return False
+
+ def _determine_antibias_intervention(self) -> tuple[bool | None, float]:
+ """Determine anitbias interventions to perform: give water or move lickspouts
+
+ Returns:
+ Tuple dictating is_auto_response_right and lickspout_offset_delta of trial
+ """
+
+ is_right_autowater = None
+ lickspout_offset_delta = 0
+ ab_delta = self.spec.antibias_parameters.lickspout_offset_delta
+ if abs(self.bias) >= self.spec.antibias_parameters.threshold.upper:
+ if self.water_corrections < self.spec.antibias_parameters.maximum_water_corrections:
+ logger.debug("Correcting bias with water.")
+ is_right_autowater = (
+ True if self.bias < 0 else False
+ ) # - bias values corresponds to left, so give right and vice versa
+ self.water_corrections += 1
+ else:
+ logger.debug("Correcting bias with lickspout offset.")
+ lickspout_offset_delta = ab_delta if self.bias < 0 else -ab_delta # + values move lickspout right
+ self.water_corrections = 0
+
+ elif (
+ abs(self.bias) <= self.spec.antibias_parameters.threshold.lower and self.total_lickspout_offset != 0
+ ): # bias below lower threshold, move back towards center
+ logger.debug("Moving lickspout back toward center.")
+ delta = min(ab_delta, abs(self.total_lickspout_offset))
+ lickspout_offset_delta = -delta if self.total_lickspout_offset > 0 else delta
+
+ self.total_lickspout_offset += lickspout_offset_delta
+
+ return is_right_autowater, lickspout_offset_delta
+
@abstractmethod
def _are_end_conditions_met(self) -> bool:
"""Checks whether the session should end.
@@ -229,6 +349,7 @@ def _are_end_conditions_met(self) -> bool:
"""
pass
+ @abstractmethod
def _generate_next_block(*args, **kwargs) -> Block:
"""Abstract method. Subclasses must implement their own block switching logic.
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py
index 7f28f4d..ee0a234 100644
--- a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py
+++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py
@@ -120,3 +120,11 @@ class TrialOutcome(BaseModel):
description="Reports the choice made by the subject. True for right, False for left, None for no choice."
)
is_rewarded: bool = Field(description="Indicates whether the subject received a reward on this trial.")
+
+
+class TrialMetrics(BaseModel):
+ """Represents metrics of session"""
+
+ bias: Optional[float] = Field(
+ default=None, description="Bias of session. Negative values correspond to left bias, positive right."
+ )
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py
new file mode 100644
index 0000000..ed9492e
--- /dev/null
+++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py
@@ -0,0 +1,3 @@
+from .calculate_bias import calculate_bias
+
+__all__ = ["calculate_bias"]
diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py
new file mode 100644
index 0000000..e2767f8
--- /dev/null
+++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py
@@ -0,0 +1,86 @@
+import logging
+from typing import List
+
+import numpy as np
+from sklearn.linear_model import LogisticRegressionCV
+
+from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome
+
+logger = logging.getLogger(__name__)
+
+
+def calculate_bias(outcomes: List[TrialOutcome]) -> float:
+ """Estimate the side bias of an animal using logistic regression on recent trial history.
+
+ Fits a Su2022-style logistic regression model using rewarded and unrewarded choice
+ history as predictors. The intercept of the fitted model is returned as the bias,
+ representing the animal's baseline tendency to choose right independent of reward history.
+
+ Parameters
+ ----------
+ outcomes : List[TrialOutcome]
+ List of trial outcomes. Auto-response and ignored trials are excluded.
+ Only the most recent 200 trials are used.
+
+ Returns
+ -------
+ float
+ The logistic regression intercept, representing side bias.
+ Positive values indicate a bias toward right, negative toward left.
+ """
+
+ solver = "liblinear"
+ l1_ratios = (0,)
+ lag = 5
+ cv = 10
+ cs = 10
+
+ # reduce outcomes to last 200
+ outcomes = outcomes[-200:]
+
+ # exclude auto response and ignored trials
+ filtered = [t for t in outcomes if t.is_right_choice is not None and t.trial.is_auto_response_right is None]
+
+ if len(filtered) <= lag:
+ logger.warning("Not enough choices to calculate bias.")
+ return np.nan
+
+ is_right_choice_history = np.array([t.is_right_choice for t in filtered], dtype=float)
+ is_rewarded_history = np.array([t.is_rewarded for t in filtered], dtype=float)
+
+ # transform to +-1 space for zero centered logistic regression
+ choice_signed = 2 * is_right_choice_history - 1 # left=0 → -1, right=1 → +1
+ reward_signed = 2 * is_rewarded_history - 1 # unrewarded=0 → -1, rewarded=1 → +1
+
+ rewarded_choice = choice_signed * (reward_signed == 1)
+ unrewarded_choice = choice_signed * (reward_signed == -1)
+
+ trial_length = len(is_rewarded_history)
+ x = np.zeros((trial_length - lag, 2 * lag))
+ for i in range(lag, trial_length):
+ x[i - lag] = np.hstack([choice[i - lag : i] for choice in [rewarded_choice, unrewarded_choice]])
+
+ y = choice_signed[lag:]
+
+ n_right_choice = np.sum(y == 1)
+ n_left_choice = np.sum(y == -1)
+ if min(n_right_choice, n_left_choice) < cv:
+ logger.warning(
+ "Not enough trials per class to fit logistic regression (need %d per class, got %d right, %d left).",
+ cv,
+ n_right_choice,
+ n_left_choice,
+ )
+ return np.nan
+
+ logistic_reg = LogisticRegressionCV(
+ solver=solver,
+ l1_ratios=l1_ratios,
+ Cs=cs,
+ cv=cv,
+ use_legacy_attributes=True,
+ )
+ logistic_reg.fit(x, y)
+
+ bias = logistic_reg.intercept_[0]
+ return bias
diff --git a/src/main.bonsai b/src/main.bonsai
index f64dd49..a5e1dbe 100644
--- a/src/main.bonsai
+++ b/src/main.bonsai
@@ -150,6 +150,9 @@
GlobalTrial
+
+ GlobalTrialMetrics
+
TaskLogicParameters
@@ -230,21 +233,21 @@
-
-
-
-
+
+
+
-
-
+
+
-
-
-
-
-
+
+
+
+
+
+
diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py
index 661439d..29c966e 100644
--- a/tests/trial_generators/test_block_based_trial_generator.py
+++ b/tests/trial_generators/test_block_based_trial_generator.py
@@ -5,11 +5,14 @@
import numpy as np
from aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator import (
+ AntiBiasParameters,
+ AutoWaterParameters,
+ BiasThreshold,
Block,
BlockBasedTrialGenerator,
BlockBasedTrialGeneratorSpec,
)
-from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial
+from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome
logging.basicConfig(level=logging.DEBUG)
@@ -49,6 +52,252 @@ def test_next_returns_correct_reward_probs(self):
self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward)
+class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase):
+ def _patch_bias(self, bias_value: float) -> dict:
+
+ return patch(
+ "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.calculate_bias",
+ return_value=bias_value,
+ )
+
+ def _make_generator(
+ self,
+ bias: float,
+ trials_in_bias_intervention: int = 15,
+ water_corrections: int = 0,
+ maximum_water_corrections: int = 5,
+ bias_window_length: int = 5,
+ intervention_interval: int = 10,
+ total_offset: float = 0.0,
+ threshold: BiasThreshold = BiasThreshold(upper=0.7, lower=0.3),
+ ) -> ConcreteBlockBasedTrialGenerator:
+ ab = AntiBiasParameters(
+ maximum_water_corrections=maximum_water_corrections,
+ bias_window_length=bias_window_length,
+ intervention_interval=intervention_interval,
+ threshold=threshold,
+ )
+ spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab)
+ gen = spec.create_generator()
+ gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10)
+ gen.total_lickspout_offset = total_offset
+ gen.bias = bias
+ gen.trials_in_bias_intervention = trials_in_bias_intervention
+ gen.water_corrections = water_corrections
+ gen.is_right_choice_history = [True] * 100
+ gen.reward_history = [True] * 100
+ return gen
+
+ def test_returns_false_when_antibias_disabled(self):
+ """Antibias should never trigger when antibias_parameters is None."""
+ spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=None)
+ gen = spec.create_generator()
+ self.assertFalse(gen._are_antibias_conditions_met())
+
+ def test_returns_false_before_intervention_interval(self):
+ """Condition should not trigger before the intervention interval is exceeded."""
+
+ gen = self._make_generator(bias=0.5, intervention_interval=10)
+ gen.trials_in_bias_intervention = 5
+ self.assertFalse(gen._are_antibias_conditions_met())
+
+ def test_returns_false_when_bias_within_thresholds(self):
+ """No intervention when bias sits between lower and upper thresholds."""
+ gen = self._make_generator(
+ bias=0.5,
+ intervention_interval=10,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ bias_window_length=5,
+ )
+ gen.trials_in_bias_intervention = 15
+ gen.is_right_choice_history = [True, False] * 50
+ gen.reward_history = [True] * 100
+ result = gen._are_antibias_conditions_met()
+
+ self.assertFalse(result)
+
+ def test_returns_true_when_bias_above_upper_threshold(self):
+ """Intervention when bias is above threshold"""
+ gen = self._make_generator(
+ bias=0.9,
+ intervention_interval=10,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ bias_window_length=5,
+ )
+ gen.trials_in_bias_intervention = 15
+ gen.is_right_choice_history = [True] * 100
+ gen.reward_history = [True] * 100
+ result = gen._are_antibias_conditions_met()
+
+ self.assertTrue(result)
+
+ def test_returns_true_when_bias_below_lower_threshold(self):
+ """Intervention when bias is below threshold"""
+ gen = self._make_generator(
+ bias=0.2,
+ intervention_interval=10,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ bias_window_length=5,
+ )
+ gen.trials_in_bias_intervention = 15
+ gen.is_right_choice_history = [False] * 100
+ gen.reward_history = [False] * 100
+ result = gen._are_antibias_conditions_met()
+
+ self.assertTrue(result)
+
+ def test_bias_stored_on_generator_after_check(self):
+ """The computed bias value should be saved on the generator."""
+ gen = self._make_generator(
+ bias=0,
+ intervention_interval=10,
+ bias_window_length=5,
+ )
+
+ with self._patch_bias(0.42):
+ gen.update(TrialOutcome(is_rewarded=True, is_right_choice=True, trial=Trial()))
+
+ self.assertAlmostEqual(gen.bias, 0.42)
+
+ def test_gives_right_water_on_left_bias(self):
+ """Negative bias (left bias) → give right water."""
+ gen = self._make_generator(bias=-0.9, maximum_water_corrections=5)
+ is_right, delta = gen._determine_antibias_intervention()
+ self.assertTrue(is_right)
+ self.assertEqual(delta, 0.0)
+
+ def test_gives_left_water_on_right_bias(self):
+ """Positive bias (right bias) → give left water."""
+ gen = self._make_generator(bias=0.9, maximum_water_corrections=5)
+ is_right, delta = gen._determine_antibias_intervention()
+ self.assertFalse(is_right)
+ self.assertEqual(delta, 0.0)
+
+ def test_water_corrections_counter_increments(self):
+ gen = self._make_generator(bias=-0.9, water_corrections=2, maximum_water_corrections=5)
+ gen._determine_antibias_intervention()
+ self.assertEqual(gen.water_corrections, 3)
+
+ def test_switches_to_lickspout_after_max_corrections_left_bias(self):
+ """After exhausting water corrections, move lickspout right (combat left bias)."""
+ gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5)
+ is_right, delta = gen._determine_antibias_intervention()
+ self.assertIsNone(is_right)
+ self.assertGreater(delta, 0)
+
+ def test_switches_to_lickspout_after_max_corrections_right_bias(self):
+ """After exhausting water corrections, move lickspout left (combat right bias)."""
+ gen = self._make_generator(bias=0.9, water_corrections=5, maximum_water_corrections=5)
+ is_right, delta = gen._determine_antibias_intervention()
+ self.assertIsNone(is_right)
+ self.assertLess(delta, 0)
+
+ def test_water_corrections_reset_after_lickspout_move(self):
+ gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5)
+ gen._determine_antibias_intervention()
+ self.assertEqual(gen.water_corrections, 0)
+
+ # #### Test lickspout centering ####
+
+ def test_no_centering_when_offset_is_zero(self):
+ """No correction when already centered, even if bias drops below lower threshold."""
+ gen = self._make_generator(
+ bias=0.1,
+ total_offset=0.0,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ )
+ _, delta = gen._determine_antibias_intervention()
+ self.assertEqual(delta, 0.0)
+
+ def test_centering_moves_toward_zero_from_positive_offset(self):
+ """Positive offset + low bias → negative delta (move back left)."""
+ gen = self._make_generator(
+ bias=0.1,
+ total_offset=1.0,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ )
+ _, delta = gen._determine_antibias_intervention()
+ self.assertLess(delta, 0)
+
+ def test_centering_moves_toward_zero_from_negative_offset(self):
+ """Negative offset + low bias → positive delta (move back right)."""
+ gen = self._make_generator(
+ bias=0.1,
+ total_offset=-1.0,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ )
+ _, delta = gen._determine_antibias_intervention()
+ self.assertGreater(delta, 0)
+
+ def test_centering_step_capped_at_offset_magnitude(self):
+ """Centering delta should not overshoot: capped at min(0.5, |offset|)."""
+ gen = self._make_generator(
+ bias=0.1,
+ total_offset=0.2,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ )
+ _, delta = gen._determine_antibias_intervention()
+ self.assertLessEqual(abs(delta), 0.2)
+
+ def test_total_lickspout_offset_updated_after_move(self):
+ """total_lickspout_offset should accumulate the delta applied."""
+ gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5, total_offset=0.0)
+ _, delta = gen._determine_antibias_intervention()
+ self.assertAlmostEqual(gen.total_lickspout_offset, delta)
+
+ #### Test next ####
+
+ def test_next_gives_right_autowater_on_left_bias(self):
+ gen = self._make_generator(bias=-0.9)
+ trial = gen.next()
+ self.assertIsNotNone(trial)
+ self.assertTrue(trial.is_auto_response_right)
+
+ def test_next_gives_left_autowater_on_right_bias(self):
+ gen = self._make_generator(bias=0.9)
+ trial = gen.next()
+ self.assertIsNotNone(trial)
+ self.assertFalse(trial.is_auto_response_right)
+
+ def test_next_no_antibias_when_below_interval(self):
+ """No antibias effect when trials_in_bias_intervention has not exceeded interval."""
+ gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5)
+ trial = gen.next()
+ self.assertIsNone(trial.is_auto_response_right)
+
+ def test_next_antibias_overrides_autowater(self):
+ """When both autowater and antibias conditions are met, antibias takes precedence."""
+ ab = AntiBiasParameters(
+ intervention_interval=10,
+ threshold=BiasThreshold(upper=0.7, lower=0.3),
+ maximum_water_corrections=5,
+ bias_window_length=5,
+ )
+ aw = AutoWaterParameters(min_ignored_trials=1, min_unrewarded_trials=1, reward_fraction=0.8)
+ spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab, autowater_parameters=aw)
+ gen = spec.create_generator()
+ gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10)
+ gen.bias = -0.9
+ gen.trials_in_bias_intervention = 15
+ gen.is_right_choice_history = [None] # ignored trial → autowater would also fire
+ gen.reward_history = [False]
+ trial = gen.next()
+
+ # Antibias (left bias → give right water) should win
+ self.assertTrue(trial.is_auto_response_right)
+
+ def test_next_lickspout_delta_nonzero_after_corrections_exhausted(self):
+ """After max water corrections, next() should produce a nonzero lickspout delta."""
+ gen = self._make_generator(bias=-0.9, water_corrections=5)
+ trial = gen.next()
+ self.assertEqual(trial.lickspout_offset_delta, 0.05)
+
+ def test_next_no_lickspout_delta_when_antibias_not_triggered(self):
+ gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5)
+ trial = gen.next()
+ self.assertEqual(trial.lickspout_offset_delta, 0)
+
+
class TestBlockBaseBaitingTrialGenerator(unittest.TestCase):
def setUp(self):
self.spec = ConcreteBlockBasedTrialGeneratorSpec(is_baiting=True)
diff --git a/tests/trial_generators/test_calculate_bias.py b/tests/trial_generators/test_calculate_bias.py
new file mode 100644
index 0000000..c3d23a4
--- /dev/null
+++ b/tests/trial_generators/test_calculate_bias.py
@@ -0,0 +1,107 @@
+import math
+import time
+import unittest
+
+import numpy as np
+
+from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome
+from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias
+
+
+def make_outcomes(n, right_prob=0.5, reward_prob=0.5, **trial_kwargs) -> list[TrialOutcome]:
+ outcomes = []
+ for _ in range(n):
+ is_right = bool(np.random.rand() < right_prob)
+ is_rewarded = bool(np.random.rand() < reward_prob)
+ outcomes.append(TrialOutcome(is_right_choice=is_right, is_rewarded=is_rewarded, trial=Trial(**trial_kwargs)))
+ return outcomes
+
+
+class TestCalculateBias(unittest.TestCase):
+ def setUp(self):
+ np.random.seed(42)
+
+ def test_right_bias_is_positive(self):
+ """Always choosing right with alternating rewards should produce positive bias."""
+ outcomes = make_outcomes(100, 0.9, 0.5)
+ bias = calculate_bias(outcomes)
+ self.assertGreater(bias, 0)
+
+ def test_left_bias_is_negative(self):
+ """Always choosing left with alternating rewards should produce negative bias."""
+ outcomes = make_outcomes(100, 0.1, 0.5)
+ bias = calculate_bias(outcomes)
+ self.assertLess(bias, 0)
+
+ def test_unbiased_near_zero(self):
+ """Alternating choices should produce bias near zero."""
+ outcomes = make_outcomes(100)
+ bias = calculate_bias(outcomes)
+ self.assertAlmostEqual(bias, 0, delta=1.0)
+
+ def test_uniform_choices_returns_nan(self):
+ """All same choice and reward should return nan since bias is undefined."""
+ outcomes = make_outcomes(100, 1, 1)
+ bias = calculate_bias(outcomes)
+ self.assertTrue(math.isnan(bias))
+
+ def test_ignored_trials_excluded(self):
+ """Adding ignored trials should not change the result."""
+ valid = make_outcomes(80)
+ ignored = [TrialOutcome(is_right_choice=None, is_rewarded=False, trial=Trial()) for _ in range(20)]
+ self.assertEqual(calculate_bias(valid + ignored), calculate_bias(valid))
+
+ def test_auto_response_trials_excluded(self):
+ """Adding auto response trials should not change the result."""
+ valid = make_outcomes(80)
+ auto = make_outcomes(20, 1, 1, is_auto_response_right=True)
+ self.assertEqual(calculate_bias(valid + auto), calculate_bias(valid))
+
+ def test_only_uses_last_200_trials(self):
+ """Trials beyond the last 200 should not affect the result."""
+ old = make_outcomes(100, 0)
+ recent = make_outcomes(200, 0.9)
+ self.assertEqual(calculate_bias(old + recent), calculate_bias(recent))
+
+ def test_too_few_trials_returns_nan(self):
+ """Fewer trials than lag should return nan since regression cannot be fit."""
+ outcomes = make_outcomes(5)
+ bias = calculate_bias(outcomes)
+ self.assertTrue(math.isnan(bias))
+
+
+TIME_LIMIT_MS = 170
+N_REPEATS = 20
+
+
+class TestCalculateBiasTiming(unittest.TestCase):
+ def _assert_within_time_limit(self, n_trials: int):
+ times = []
+ for _ in range(N_REPEATS):
+ outcomes = make_outcomes(n_trials)
+ t0 = time.perf_counter()
+ calculate_bias(outcomes)
+ times.append((time.perf_counter() - t0) * 1000)
+ mean_ms = float(np.mean(times))
+ self.assertLess(
+ mean_ms,
+ TIME_LIMIT_MS,
+ f"calculate_bias with {n_trials} trials too slow: {mean_ms:.1f}ms (limit {TIME_LIMIT_MS}ms)",
+ )
+
+ def test_timing_50_trials(self):
+ self._assert_within_time_limit(50)
+
+ def test_timing_100_trials(self):
+ self._assert_within_time_limit(100)
+
+ def test_timing_200_trials(self):
+ self._assert_within_time_limit(200)
+
+ def test_timing_1000_trials(self):
+ # should be similar to 200 since only last 200 trials evaluated
+ self._assert_within_time_limit(1000)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/trial_generators/test_coupled_trial_generator.py b/tests/trial_generators/test_coupled_trial_generator.py
index e4cc59e..cae55c8 100644
--- a/tests/trial_generators/test_coupled_trial_generator.py
+++ b/tests/trial_generators/test_coupled_trial_generator.py
@@ -6,8 +6,7 @@
from aind_behavior_dynamic_foraging.task_logic.trial_generators import CoupledTrialGeneratorSpec
from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome
-
-from .util import simulate_response
+from tests.trial_generators.util import simulate_response
logging.basicConfig(level=logging.DEBUG)
diff --git a/tests/trial_generators/test_uncoupled_trial_generator.py b/tests/trial_generators/test_uncoupled_trial_generator.py
index 2b21422..d952172 100644
--- a/tests/trial_generators/test_uncoupled_trial_generator.py
+++ b/tests/trial_generators/test_uncoupled_trial_generator.py
@@ -10,8 +10,7 @@
UncoupledTrialGeneratorSpec,
)
from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial
-
-from .util import simulate_response
+from tests.trial_generators.util import simulate_response
logging.basicConfig(level=logging.DEBUG)
diff --git a/tests/trial_generators/test_warmup_trial_generator.py b/tests/trial_generators/test_warmup_trial_generator.py
index e478685..92a89eb 100644
--- a/tests/trial_generators/test_warmup_trial_generator.py
+++ b/tests/trial_generators/test_warmup_trial_generator.py
@@ -6,8 +6,7 @@
CoupledWarmupTrialGeneratorSpec,
)
from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome
-
-from .util import simulate_response
+from tests.trial_generators.util import simulate_response
def make_outcome(is_right_choice: bool | None, is_rewarded: bool) -> TrialOutcome:
diff --git a/uv.lock b/uv.lock
index 423b991..fcc655d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -52,6 +52,7 @@ source = { editable = "." }
dependencies = [
{ name = "aind-behavior-services" },
{ name = "pydantic-settings" },
+ { name = "scikit-learn" },
]
[package.optional-dependencies]
@@ -81,6 +82,7 @@ requires-dist = [
{ name = "aind-behavior-services", specifier = ">=0.13.5" },
{ name = "contraqctor", marker = "extra == 'data'", specifier = ">=0.5.3" },
{ name = "pydantic-settings" },
+ { name = "scikit-learn", specifier = ">=1.8.0" },
]
provides-extras = ["data"]
@@ -1086,6 +1088,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" },
]
+[[package]]
+name = "joblib"
+version = "1.5.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" },
+]
+
[[package]]
name = "jsonpointer"
version = "3.1.1"
@@ -2506,6 +2517,56 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" },
]
+[[package]]
+name = "scikit-learn"
+version = "1.8.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "joblib" },
+ { name = "numpy" },
+ { name = "scipy" },
+ { name = "threadpoolctl" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/0e/d4/40988bf3b8e34feec1d0e6a051446b1f66225f8529b9309becaeef62b6c4/scikit_learn-1.8.0.tar.gz", hash = "sha256:9bccbb3b40e3de10351f8f5068e105d0f4083b1a65fa07b6634fbc401a6287fd", size = 7335585, upload-time = "2025-12-10T07:08:53.618Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c9/92/53ea2181da8ac6bf27170191028aee7251f8f841f8d3edbfdcaf2008fde9/scikit_learn-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:146b4d36f800c013d267b29168813f7a03a43ecd2895d04861f1240b564421da", size = 8595835, upload-time = "2025-12-10T07:07:39.385Z" },
+ { url = "https://files.pythonhosted.org/packages/01/18/d154dc1638803adf987910cdd07097d9c526663a55666a97c124d09fb96a/scikit_learn-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f984ca4b14914e6b4094c5d52a32ea16b49832c03bd17a110f004db3c223e8e1", size = 8080381, upload-time = "2025-12-10T07:07:41.93Z" },
+ { url = "https://files.pythonhosted.org/packages/8a/44/226142fcb7b7101e64fdee5f49dbe6288d4c7af8abf593237b70fca080a4/scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5e30adb87f0cc81c7690a84f7932dd66be5bac57cfe16b91cb9151683a4a2d3b", size = 8799632, upload-time = "2025-12-10T07:07:43.899Z" },
+ { url = "https://files.pythonhosted.org/packages/36/4d/4a67f30778a45d542bbea5db2dbfa1e9e100bf9ba64aefe34215ba9f11f6/scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ada8121bcb4dac28d930febc791a69f7cb1673c8495e5eee274190b73a4559c1", size = 9103788, upload-time = "2025-12-10T07:07:45.982Z" },
+ { url = "https://files.pythonhosted.org/packages/89/3c/45c352094cfa60050bcbb967b1faf246b22e93cb459f2f907b600f2ceda5/scikit_learn-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:c57b1b610bd1f40ba43970e11ce62821c2e6569e4d74023db19c6b26f246cb3b", size = 8081706, upload-time = "2025-12-10T07:07:48.111Z" },
+ { url = "https://files.pythonhosted.org/packages/3d/46/5416595bb395757f754feb20c3d776553a386b661658fb21b7c814e89efe/scikit_learn-1.8.0-cp311-cp311-win_arm64.whl", hash = "sha256:2838551e011a64e3053ad7618dda9310175f7515f1742fa2d756f7c874c05961", size = 7688451, upload-time = "2025-12-10T07:07:49.873Z" },
+ { url = "https://files.pythonhosted.org/packages/90/74/e6a7cc4b820e95cc38cf36cd74d5aa2b42e8ffc2d21fe5a9a9c45c1c7630/scikit_learn-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5fb63362b5a7ddab88e52b6dbb47dac3fd7dafeee740dc6c8d8a446ddedade8e", size = 8548242, upload-time = "2025-12-10T07:07:51.568Z" },
+ { url = "https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5025ce924beccb28298246e589c691fe1b8c1c96507e6d27d12c5fadd85bfd76", size = 8079075, upload-time = "2025-12-10T07:07:53.697Z" },
+ { url = "https://files.pythonhosted.org/packages/dd/47/f187b4636ff80cc63f21cd40b7b2d177134acaa10f6bb73746130ee8c2e5/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4496bb2cf7a43ce1a2d7524a79e40bc5da45cf598dbf9545b7e8316ccba47bb4", size = 8660492, upload-time = "2025-12-10T07:07:55.574Z" },
+ { url = "https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0bcfe4d0d14aec44921545fd2af2338c7471de9cb701f1da4c9d85906ab847a", size = 8931904, upload-time = "2025-12-10T07:07:57.666Z" },
+ { url = "https://files.pythonhosted.org/packages/9f/c4/0ab22726a04ede56f689476b760f98f8f46607caecff993017ac1b64aa5d/scikit_learn-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:35c007dedb2ffe38fe3ee7d201ebac4a2deccd2408e8621d53067733e3c74809", size = 8019359, upload-time = "2025-12-10T07:07:59.838Z" },
+ { url = "https://files.pythonhosted.org/packages/24/90/344a67811cfd561d7335c1b96ca21455e7e472d281c3c279c4d3f2300236/scikit_learn-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:8c497fff237d7b4e07e9ef1a640887fa4fb765647f86fbe00f969ff6280ce2bb", size = 7641898, upload-time = "2025-12-10T07:08:01.36Z" },
+ { url = "https://files.pythonhosted.org/packages/03/aa/e22e0768512ce9255eba34775be2e85c2048da73da1193e841707f8f039c/scikit_learn-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0d6ae97234d5d7079dc0040990a6f7aeb97cb7fa7e8945f1999a429b23569e0a", size = 8513770, upload-time = "2025-12-10T07:08:03.251Z" },
+ { url = "https://files.pythonhosted.org/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:edec98c5e7c128328124a029bceb09eda2d526997780fef8d65e9a69eead963e", size = 8044458, upload-time = "2025-12-10T07:08:05.336Z" },
+ { url = "https://files.pythonhosted.org/packages/2d/5a/3f1caed8765f33eabb723596666da4ebbf43d11e96550fb18bdec42b467b/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74b66d8689d52ed04c271e1329f0c61635bcaf5b926db9b12d58914cdc01fe57", size = 8610341, upload-time = "2025-12-10T07:08:07.732Z" },
+ { url = "https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8fdf95767f989b0cfedb85f7ed8ca215d4be728031f56ff5a519ee1e3276dc2e", size = 8900022, upload-time = "2025-12-10T07:08:09.862Z" },
+ { url = "https://files.pythonhosted.org/packages/1c/f9/9b7563caf3ec8873e17a31401858efab6b39a882daf6c1bfa88879c0aa11/scikit_learn-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:2de443b9373b3b615aec1bb57f9baa6bb3a9bd093f1269ba95c17d870422b271", size = 7989409, upload-time = "2025-12-10T07:08:12.028Z" },
+ { url = "https://files.pythonhosted.org/packages/49/bd/1f4001503650e72c4f6009ac0c4413cb17d2d601cef6f71c0453da2732fc/scikit_learn-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:eddde82a035681427cbedded4e6eff5e57fa59216c2e3e90b10b19ab1d0a65c3", size = 7619760, upload-time = "2025-12-10T07:08:13.688Z" },
+ { url = "https://files.pythonhosted.org/packages/d2/7d/a630359fc9dcc95496588c8d8e3245cc8fd81980251079bc09c70d41d951/scikit_learn-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:7cc267b6108f0a1499a734167282c00c4ebf61328566b55ef262d48e9849c735", size = 8826045, upload-time = "2025-12-10T07:08:15.215Z" },
+ { url = "https://files.pythonhosted.org/packages/cc/56/a0c86f6930cfcd1c7054a2bc417e26960bb88d32444fe7f71d5c2cfae891/scikit_learn-1.8.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:fe1c011a640a9f0791146011dfd3c7d9669785f9fed2b2a5f9e207536cf5c2fd", size = 8420324, upload-time = "2025-12-10T07:08:17.561Z" },
+ { url = "https://files.pythonhosted.org/packages/46/1e/05962ea1cebc1cf3876667ecb14c283ef755bf409993c5946ade3b77e303/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72358cce49465d140cc4e7792015bb1f0296a9742d5622c67e31399b75468b9e", size = 8680651, upload-time = "2025-12-10T07:08:19.952Z" },
+ { url = "https://files.pythonhosted.org/packages/fe/56/a85473cd75f200c9759e3a5f0bcab2d116c92a8a02ee08ccd73b870f8bb4/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:80832434a6cc114f5219211eec13dcbc16c2bac0e31ef64c6d346cde3cf054cb", size = 8925045, upload-time = "2025-12-10T07:08:22.11Z" },
+ { url = "https://files.pythonhosted.org/packages/cc/b7/64d8cfa896c64435ae57f4917a548d7ac7a44762ff9802f75a79b77cb633/scikit_learn-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ee787491dbfe082d9c3013f01f5991658b0f38aa8177e4cd4bf434c58f551702", size = 8507994, upload-time = "2025-12-10T07:08:23.943Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/37/e192ea709551799379958b4c4771ec507347027bb7c942662c7fbeba31cb/scikit_learn-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf97c10a3f5a7543f9b88cbf488d33d175e9146115a451ae34568597ba33dcde", size = 7869518, upload-time = "2025-12-10T07:08:25.71Z" },
+ { url = "https://files.pythonhosted.org/packages/24/05/1af2c186174cc92dcab2233f327336058c077d38f6fe2aceb08e6ab4d509/scikit_learn-1.8.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c22a2da7a198c28dd1a6e1136f19c830beab7fdca5b3e5c8bba8394f8a5c45b3", size = 8528667, upload-time = "2025-12-10T07:08:27.541Z" },
+ { url = "https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:6b595b07a03069a2b1740dc08c2299993850ea81cce4fe19b2421e0c970de6b7", size = 8066524, upload-time = "2025-12-10T07:08:29.822Z" },
+ { url = "https://files.pythonhosted.org/packages/be/ce/a0623350aa0b68647333940ee46fe45086c6060ec604874e38e9ab7d8e6c/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:29ffc74089f3d5e87dfca4c2c8450f88bdc61b0fc6ed5d267f3988f19a1309f6", size = 8657133, upload-time = "2025-12-10T07:08:31.865Z" },
+ { url = "https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb65db5d7531bccf3a4f6bec3462223bea71384e2cda41da0f10b7c292b9e7c4", size = 8923223, upload-time = "2025-12-10T07:08:34.166Z" },
+ { url = "https://files.pythonhosted.org/packages/76/18/a8def8f91b18cd1ba6e05dbe02540168cb24d47e8dcf69e8d00b7da42a08/scikit_learn-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:56079a99c20d230e873ea40753102102734c5953366972a71d5cb39a32bc40c6", size = 8096518, upload-time = "2025-12-10T07:08:36.339Z" },
+ { url = "https://files.pythonhosted.org/packages/d1/77/482076a678458307f0deb44e29891d6022617b2a64c840c725495bee343f/scikit_learn-1.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:3bad7565bc9cf37ce19a7c0d107742b320c1285df7aab1a6e2d28780df167242", size = 7754546, upload-time = "2025-12-10T07:08:38.128Z" },
+ { url = "https://files.pythonhosted.org/packages/2d/d1/ef294ca754826daa043b2a104e59960abfab4cf653891037d19dd5b6f3cf/scikit_learn-1.8.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:4511be56637e46c25721e83d1a9cea9614e7badc7040c4d573d75fbe257d6fd7", size = 8848305, upload-time = "2025-12-10T07:08:41.013Z" },
+ { url = "https://files.pythonhosted.org/packages/5b/e2/b1f8b05138ee813b8e1a4149f2f0d289547e60851fd1bb268886915adbda/scikit_learn-1.8.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:a69525355a641bf8ef136a7fa447672fb54fe8d60cab5538d9eb7c6438543fb9", size = 8432257, upload-time = "2025-12-10T07:08:42.873Z" },
+ { url = "https://files.pythonhosted.org/packages/26/11/c32b2138a85dcb0c99f6afd13a70a951bfdff8a6ab42d8160522542fb647/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c2656924ec73e5939c76ac4c8b026fc203b83d8900362eb2599d8aee80e4880f", size = 8678673, upload-time = "2025-12-10T07:08:45.362Z" },
+ { url = "https://files.pythonhosted.org/packages/c7/57/51f2384575bdec454f4fe4e7a919d696c9ebce914590abf3e52d47607ab8/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15fc3b5d19cc2be65404786857f2e13c70c83dd4782676dd6814e3b89dc8f5b9", size = 8922467, upload-time = "2025-12-10T07:08:47.408Z" },
+ { url = "https://files.pythonhosted.org/packages/35/4d/748c9e2872637a57981a04adc038dacaa16ba8ca887b23e34953f0b3f742/scikit_learn-1.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:00d6f1d66fbcf4eba6e356e1420d33cc06c70a45bb1363cd6f6a8e4ebbbdece2", size = 8774395, upload-time = "2025-12-10T07:08:49.337Z" },
+ { url = "https://files.pythonhosted.org/packages/60/22/d7b2ebe4704a5e50790ba089d5c2ae308ab6bb852719e6c3bd4f04c3a363/scikit_learn-1.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f28dd15c6bb0b66ba09728cf09fd8736c304be29409bd8445a080c1280619e8c", size = 8002647, upload-time = "2025-12-10T07:08:51.601Z" },
+]
+
[[package]]
name = "scipy"
version = "1.17.1"
@@ -2817,6 +2878,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072, upload-time = "2024-07-29T01:10:08.203Z" },
]
+[[package]]
+name = "threadpoolctl"
+version = "3.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" },
+]
+
[[package]]
name = "tomli"
version = "2.4.0"