diff --git a/pyproject.toml b/pyproject.toml index 25dd12d..27846ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ readme = {file = "README.md", content-type = "text/markdown"} dependencies = [ "aind_behavior_services>=0.13.5", "pydantic-settings", + "scikit-learn>=1.8.0", ] [tool.uv.workspace] diff --git a/schema/aind_behavior_dynamic_foraging.json b/schema/aind_behavior_dynamic_foraging.json index 681846e..38cf055 100644 --- a/schema/aind_behavior_dynamic_foraging.json +++ b/schema/aind_behavior_dynamic_foraging.json @@ -313,6 +313,56 @@ "type": "object", "x-sgen-typename": "AllenNeuralDynamics.AindManipulator.AindManipulatorCalibration" }, + "AntiBiasParameters": { + "properties": { + "threshold": { + "$ref": "#/$defs/BiasThreshold", + "default": { + "upper": 0.7, + "lower": 0.3 + }, + "description": "Thresholds for bias correction intervention." + }, + "intervention_interval": { + "default": 10, + "description": "Trials between bias intervention.", + "minimum": 0, + "title": "Intervention Interval", + "type": "integer" + }, + "maximum_water_corrections": { + "default": 5, + "description": "Number of water correction to attempt.", + "minimum": 0, + "title": "Maximum Water Corrections", + "type": "integer" + }, + "reward_fraction": { + "default": 0.8, + "description": "Fraction of full reward volume delivered during auto water (0=none, 1=full).", + "maximum": 1, + "minimum": 0, + "title": "Reward Fraction", + "type": "number" + }, + "bias_window_length": { + "default": 200, + "description": "Trials to calculate bias over.", + "minimum": 0, + "title": "Bias Window Length", + "type": "integer" + }, + "lickspout_offset_delta": { + "default": 0.05, + "description": "Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.", + "minimum": 0, + "title": "Lickspout Offset Delta", + "type": "number" + } + }, + "title": "AntiBiasParameters", + "type": "object" + }, "AuditorySecondaryReinforcer": { "description": "Represents an auditory secondary reinforcer.", "properties": { @@ -511,7 +561,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -521,6 +571,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "reward_fraction": 0.8, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -656,6 +728,28 @@ "type": "object", "x-sgen-typename": "AllenNeuralDynamics.AindBehaviorServices.Distributions.BetaDistributionParameters" }, + "BiasThreshold": { + "properties": { + "upper": { + "default": 0.7, + "description": "Absolute value of the upper bias threshold.", + "maximum": 1, + "minimum": 0, + "title": "Upper", + "type": "number" + }, + "lower": { + "default": 0.3, + "description": "Absolute value of the lower bias threshold.", + "maximum": 1, + "minimum": 0, + "title": "Lower", + "type": "number" + } + }, + "title": "BiasThreshold", + "type": "object" + }, "BinomialDistribution": { "description": "A binomial probability distribution.\n\nModels the number of successes in a fixed number of independent\nBernoulli trials with constant success probability.", "properties": { @@ -811,7 +905,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -821,6 +915,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "reward_fraction": 0.8, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -1077,7 +1193,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -1087,6 +1203,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "reward_fraction": 0.8, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", @@ -1266,7 +1404,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -1276,6 +1414,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "reward_fraction": 0.8, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "const": true, "default": true, @@ -3403,6 +3563,26 @@ } ] }, + "TrialMetrics": { + "description": "Represents metrics of session", + "properties": { + "bias": { + "default": null, + "description": "Bias of session. Negative values correspond to left bias, positive right.", + "oneOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "title": "Bias" + } + }, + "title": "TrialMetrics", + "type": "object" + }, "TrialOutcome": { "description": "Represents the outcome of a single trial.", "properties": { @@ -3586,7 +3766,7 @@ "min_unrewarded_trials": 3, "reward_fraction": 0.8 }, - "description": "Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + "description": "Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", "oneOf": [ { "$ref": "#/$defs/AutoWaterParameters" @@ -3596,6 +3776,28 @@ } ] }, + "antibias_parameters": { + "default": { + "threshold": { + "lower": 0.3, + "upper": 0.7 + }, + "intervention_interval": 10, + "maximum_water_corrections": 5, + "reward_fraction": 0.8, + "bias_window_length": 200, + "lickspout_offset_delta": 0.05 + }, + "description": "Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", + "oneOf": [ + { + "$ref": "#/$defs/AntiBiasParameters" + }, + { + "type": "null" + } + ] + }, "is_baiting": { "default": false, "description": "Whether uncollected rewards carry over to the next trial.", diff --git a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs index fd4cb0f..021a37d 100644 --- a/src/Extensions/AindBehaviorDynamicForaging.Generated.cs +++ b/src/Extensions/AindBehaviorDynamicForaging.Generated.cs @@ -703,6 +703,184 @@ public override string ToString() } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class AntiBiasParameters + { + + private BiasThreshold _threshold; + + private int _interventionInterval; + + private int _maximumWaterCorrections; + + private double _rewardFraction; + + private int _biasWindowLength; + + private double _lickspoutOffsetDelta; + + public AntiBiasParameters() + { + _threshold = new BiasThreshold(); + _interventionInterval = 10; + _maximumWaterCorrections = 5; + _rewardFraction = 0.8D; + _biasWindowLength = 200; + _lickspoutOffsetDelta = 0.05D; + } + + protected AntiBiasParameters(AntiBiasParameters other) + { + _threshold = other._threshold; + _interventionInterval = other._interventionInterval; + _maximumWaterCorrections = other._maximumWaterCorrections; + _rewardFraction = other._rewardFraction; + _biasWindowLength = other._biasWindowLength; + _lickspoutOffsetDelta = other._lickspoutOffsetDelta; + } + + /// + /// Thresholds for bias correction intervention. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("threshold")] + [System.ComponentModel.DescriptionAttribute("Thresholds for bias correction intervention.")] + public BiasThreshold Threshold + { + get + { + return _threshold; + } + set + { + _threshold = value; + } + } + + /// + /// Trials between bias intervention. + /// + [Newtonsoft.Json.JsonPropertyAttribute("intervention_interval")] + [System.ComponentModel.DescriptionAttribute("Trials between bias intervention.")] + public int InterventionInterval + { + get + { + return _interventionInterval; + } + set + { + _interventionInterval = value; + } + } + + /// + /// Number of water correction to attempt. + /// + [Newtonsoft.Json.JsonPropertyAttribute("maximum_water_corrections")] + [System.ComponentModel.DescriptionAttribute("Number of water correction to attempt.")] + public int MaximumWaterCorrections + { + get + { + return _maximumWaterCorrections; + } + set + { + _maximumWaterCorrections = value; + } + } + + /// + /// Fraction of full reward volume delivered during auto water (0=none, 1=full). + /// + [Newtonsoft.Json.JsonPropertyAttribute("reward_fraction")] + [System.ComponentModel.DescriptionAttribute("Fraction of full reward volume delivered during auto water (0=none, 1=full).")] + public double RewardFraction + { + get + { + return _rewardFraction; + } + set + { + _rewardFraction = value; + } + } + + /// + /// Trials to calculate bias over. + /// + [Newtonsoft.Json.JsonPropertyAttribute("bias_window_length")] + [System.ComponentModel.DescriptionAttribute("Trials to calculate bias over.")] + public int BiasWindowLength + { + get + { + return _biasWindowLength; + } + set + { + _biasWindowLength = value; + } + } + + /// + /// Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute. + /// + [Newtonsoft.Json.JsonPropertyAttribute("lickspout_offset_delta")] + [System.ComponentModel.DescriptionAttribute("Distance (mm) to move the stage spouts by. This is a relative distance to the cur" + + "rent value, not absolute.")] + public double LickspoutOffsetDelta + { + get + { + return _lickspoutOffsetDelta; + } + set + { + _lickspoutOffsetDelta = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new AntiBiasParameters(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new AntiBiasParameters(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("Threshold = " + _threshold + ", "); + stringBuilder.Append("InterventionInterval = " + _interventionInterval + ", "); + stringBuilder.Append("MaximumWaterCorrections = " + _maximumWaterCorrections + ", "); + stringBuilder.Append("RewardFraction = " + _rewardFraction + ", "); + stringBuilder.Append("BiasWindowLength = " + _biasWindowLength + ", "); + stringBuilder.Append("LickspoutOffsetDelta = " + _lickspoutOffsetDelta); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + /// /// Represents an auditory secondary reinforcer. /// @@ -897,6 +1075,8 @@ public partial class BaseCoupledTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private RewardProbabilityParameters _rewardProbabilityParameters; @@ -909,6 +1089,7 @@ public BaseCoupledTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; _rewardProbabilityParameters = new RewardProbabilityParameters(); } @@ -922,6 +1103,7 @@ protected BaseCoupledTrialGeneratorSpec(BaseCoupledTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; } @@ -1016,12 +1198,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -1034,6 +1216,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -1091,6 +1292,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters); return true; @@ -1259,6 +1461,94 @@ public override string ToString() } + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class BiasThreshold + { + + private double _upper; + + private double _lower; + + public BiasThreshold() + { + _upper = 0.7D; + _lower = 0.3D; + } + + protected BiasThreshold(BiasThreshold other) + { + _upper = other._upper; + _lower = other._lower; + } + + /// + /// Absolute value of the upper bias threshold. + /// + [Newtonsoft.Json.JsonPropertyAttribute("upper")] + [System.ComponentModel.DescriptionAttribute("Absolute value of the upper bias threshold.")] + public double Upper + { + get + { + return _upper; + } + set + { + _upper = value; + } + } + + /// + /// Absolute value of the lower bias threshold. + /// + [Newtonsoft.Json.JsonPropertyAttribute("lower")] + [System.ComponentModel.DescriptionAttribute("Absolute value of the lower bias threshold.")] + public double Lower + { + get + { + return _lower; + } + set + { + _lower = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new BiasThreshold(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new BiasThreshold(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("Upper = " + _upper + ", "); + stringBuilder.Append("Lower = " + _lower); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] [Bonsai.CombinatorAttribute(MethodName="Generate")] @@ -1277,6 +1567,8 @@ public partial class BlockBasedTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; public BlockBasedTrialGeneratorSpec() @@ -1287,6 +1579,7 @@ public BlockBasedTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; } @@ -1299,6 +1592,7 @@ protected BlockBasedTrialGeneratorSpec(BlockBasedTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; } @@ -1392,12 +1686,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -1410,6 +1704,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -1449,6 +1762,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting); return true; } @@ -1976,6 +2290,8 @@ public partial class CoupledTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private RewardProbabilityParameters _rewardProbabilityParameters; @@ -1998,6 +2314,7 @@ public CoupledTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; _rewardProbabilityParameters = new RewardProbabilityParameters(); _trialGenerationEndParameters = new CoupledTrialGenerationEndConditions(); @@ -2016,6 +2333,7 @@ protected CoupledTrialGeneratorSpec(CoupledTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; _trialGenerationEndParameters = other._trialGenerationEndParameters; @@ -2115,12 +2433,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -2133,6 +2451,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -2275,6 +2612,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", "); @@ -2437,6 +2775,8 @@ public partial class CoupledWarmupTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private RewardProbabilityParameters _rewardProbabilityParameters; @@ -2451,6 +2791,7 @@ public CoupledWarmupTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = true; _rewardProbabilityParameters = new RewardProbabilityParameters(); _trialGenerationEndParameters = new CoupledWarmupTrialGenerationEndConditions(); @@ -2465,6 +2806,7 @@ protected CoupledWarmupTrialGeneratorSpec(CoupledWarmupTrialGeneratorSpec other) _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _rewardProbabilityParameters = other._rewardProbabilityParameters; _trialGenerationEndParameters = other._trialGenerationEndParameters; @@ -2560,12 +2902,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution Block } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -2578,6 +2920,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -2653,6 +3014,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("RewardProbabilityParameters = " + _rewardProbabilityParameters + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters); @@ -6147,6 +6509,75 @@ public override string ToString() } + /// + /// Represents metrics of session + /// + [System.CodeDom.Compiler.GeneratedCodeAttribute("Bonsai.Sgen", "0.9.0.0 (Newtonsoft.Json v13.0.0.0)")] + [System.ComponentModel.DescriptionAttribute("Represents metrics of session")] + [Bonsai.WorkflowElementCategoryAttribute(Bonsai.ElementCategory.Source)] + [Bonsai.CombinatorAttribute(MethodName="Generate")] + public partial class TrialMetrics + { + + private double? _bias; + + public TrialMetrics() + { + } + + protected TrialMetrics(TrialMetrics other) + { + _bias = other._bias; + } + + /// + /// Bias of session. Negative values correspond to left bias, positive right. + /// + [Newtonsoft.Json.JsonPropertyAttribute("bias")] + [System.ComponentModel.DescriptionAttribute("Bias of session. Negative values correspond to left bias, positive right.")] + public double? Bias + { + get + { + return _bias; + } + set + { + _bias = value; + } + } + + public System.IObservable Generate() + { + return System.Reactive.Linq.Observable.Defer(() => System.Reactive.Linq.Observable.Return(new TrialMetrics(this))); + } + + public System.IObservable Generate(System.IObservable source) + { + return System.Reactive.Linq.Observable.Select(source, _ => new TrialMetrics(this)); + } + + protected virtual bool PrintMembers(System.Text.StringBuilder stringBuilder) + { + stringBuilder.Append("Bias = " + _bias); + return true; + } + + public override string ToString() + { + System.Text.StringBuilder stringBuilder = new System.Text.StringBuilder(); + stringBuilder.Append(GetType().Name); + stringBuilder.Append(" { "); + if (PrintMembers(stringBuilder)) + { + stringBuilder.Append(" "); + } + stringBuilder.Append("}"); + return stringBuilder.ToString(); + } + } + + /// /// Represents the outcome of a single trial. /// @@ -6478,6 +6909,8 @@ public partial class UncoupledTrialGeneratorSpec : TrialGeneratorSpec private AutoWaterParameters _autowaterParameters; + private AntiBiasParameters _antibiasParameters; + private bool _isBaiting; private UncoupledTrialGenerationEndConditions _trialGenerationEndParameters; @@ -6494,6 +6927,7 @@ public UncoupledTrialGeneratorSpec() _interTrialIntervalDuration = new AllenNeuralDynamics.AindBehaviorServices.Distributions.Distribution(); _blockLength = new AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistribution(); _autowaterParameters = new AutoWaterParameters(); + _antibiasParameters = new AntiBiasParameters(); _isBaiting = false; _trialGenerationEndParameters = new UncoupledTrialGenerationEndConditions(); _rewardProbabilities = new System.Collections.Generic.List(); @@ -6509,6 +6943,7 @@ protected UncoupledTrialGeneratorSpec(UncoupledTrialGeneratorSpec other) : _interTrialIntervalDuration = other._interTrialIntervalDuration; _blockLength = other._blockLength; _autowaterParameters = other._autowaterParameters; + _antibiasParameters = other._antibiasParameters; _isBaiting = other._isBaiting; _trialGenerationEndParameters = other._trialGenerationEndParameters; _rewardProbabilities = other._rewardProbabilities; @@ -6605,12 +7040,12 @@ public AllenNeuralDynamics.AindBehaviorServices.Distributions.UniformDistributio } /// - /// Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. + /// Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds. /// [System.Xml.Serialization.XmlIgnoreAttribute()] [Newtonsoft.Json.JsonPropertyAttribute("autowater_parameters")] - [System.ComponentModel.DescriptionAttribute("Auto water settings. If set, free water is delivered when the animal exceeds the " + - "ignored or unrewarded trial thresholds.")] + [System.ComponentModel.DescriptionAttribute("Autowater settings. If set, free water is delivered when the animal exceeds the i" + + "gnored or unrewarded trial thresholds.")] public AutoWaterParameters AutowaterParameters { get @@ -6623,6 +7058,25 @@ public AutoWaterParameters AutowaterParameters } } + /// + /// Antibias settings. If set, trial generator will give water and move lickspouts to combat bias. + /// + [System.Xml.Serialization.XmlIgnoreAttribute()] + [Newtonsoft.Json.JsonPropertyAttribute("antibias_parameters")] + [System.ComponentModel.DescriptionAttribute("Antibias settings. If set, trial generator will give water and move lickspouts to" + + " combat bias.")] + public AntiBiasParameters AntibiasParameters + { + get + { + return _antibiasParameters; + } + set + { + _antibiasParameters = value; + } + } + /// /// Whether uncollected rewards carry over to the next trial. /// @@ -6715,6 +7169,7 @@ protected override bool PrintMembers(System.Text.StringBuilder stringBuilder) stringBuilder.Append("InterTrialIntervalDuration = " + _interTrialIntervalDuration + ", "); stringBuilder.Append("BlockLength = " + _blockLength + ", "); stringBuilder.Append("AutowaterParameters = " + _autowaterParameters + ", "); + stringBuilder.Append("AntibiasParameters = " + _antibiasParameters + ", "); stringBuilder.Append("IsBaiting = " + _isBaiting + ", "); stringBuilder.Append("TrialGenerationEndParameters = " + _trialGenerationEndParameters + ", "); stringBuilder.Append("RewardProbabilities = " + _rewardProbabilities + ", "); @@ -7855,6 +8310,11 @@ public System.IObservable Process(System.IObservable(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -7880,6 +8340,11 @@ public System.IObservable Process(System.IObservable(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -8025,6 +8490,11 @@ public System.IObservable Process(System.IObservable return Process(source); } + public System.IObservable Process(System.IObservable source) + { + return Process(source); + } + public System.IObservable Process(System.IObservable source) { return Process(source); @@ -8087,11 +8557,13 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] @@ -8121,6 +8593,7 @@ public System.IObservable Process(System.IObservable source) [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] + [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(Bonsai.Expressions.TypeMapping))] diff --git a/src/Extensions/DeserializeFromPyObject.cs b/src/Extensions/DeserializeFromPyObject.cs index 6b19749..558e0a0 100644 --- a/src/Extensions/DeserializeFromPyObject.cs +++ b/src/Extensions/DeserializeFromPyObject.cs @@ -15,6 +15,8 @@ [System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] [System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] +[System.Xml.Serialization.XmlIncludeAttribute(typeof(TypeMapping))] + public class DeserializeFromPyObject : SingleArgumentExpressionBuilder { public DeserializeFromPyObject() diff --git a/src/Extensions/TaskEngine.bonsai b/src/Extensions/TaskEngine.bonsai index 46fea6c..4367cfd 100644 --- a/src/Extensions/TaskEngine.bonsai +++ b/src/Extensions/TaskEngine.bonsai @@ -1217,6 +1217,47 @@ true + + Metrics + + + + + + + + trial_generator + + + + + metrics + true + + + + + + + + + GlobalTrialMetrics + + + TrialMetrics + + + + + + + + + + + + + SampleNextTrial @@ -1262,6 +1303,7 @@ + diff --git a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py index 3f2e238..36b5435 100644 --- a/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py +++ b/src/aind_behavior_dynamic_foraging/data_contract/_dataset.py @@ -194,6 +194,13 @@ def make_dataset( name="SoftwareEvents", description="Software events generated by the workflow. The timestamps of these events are low precision and should not be used to align to physiology data.", data_streams=[ + SoftwareEvents( + name="TrialMetrics", + description="An event emitted with the metrics of a trial.", + reader_params=SoftwareEvents.make_params( + root_path / "behavior/SoftwareEvents/TrialMetrics.json" + ), + ), SoftwareEvents( name="TrialGeneratorSpec", description="An event emitted with the specification for the trial generator.", diff --git a/src/aind_behavior_dynamic_foraging/regenerate.py b/src/aind_behavior_dynamic_foraging/regenerate.py index 4dd1717..53a8389 100644 --- a/src/aind_behavior_dynamic_foraging/regenerate.py +++ b/src/aind_behavior_dynamic_foraging/regenerate.py @@ -20,6 +20,7 @@ def main(): Session, aind_behavior_dynamic_foraging.task_logic.trial_models.Trial, aind_behavior_dynamic_foraging.task_logic.trial_models.TrialOutcome, + aind_behavior_dynamic_foraging.task_logic.trial_models.TrialMetrics, ] model = pydantic.RootModel[Union[tuple(models)]] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py index 2405daa..285adaf 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/_base.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from ..trial_models import Trial, TrialOutcome +from ..trial_models import Trial, TrialMetrics, TrialOutcome class BaseTrialGeneratorSpecModel(BaseModel, abc.ABC): @@ -24,3 +24,6 @@ def next(self) -> Trial | None: def update(self, outcome: TrialOutcome | str) -> None: """Update the trial generator with the outcome of the previous trial.""" + + def metrics(self) -> TrialMetrics: + """Return metrics of session at current state of the trial generator.""" diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py index 39ddc92..84b2b72 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_generators/block_based_trial_generator.py @@ -12,7 +12,9 @@ from aind_behavior_services.task.distributions_utils import draw_sample from pydantic import BaseModel, Field -from ..trial_models import Metadata, Trial +from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias + +from ..trial_models import Metadata, Trial, TrialMetrics from ._base import BaseTrialGeneratorSpecModel, ITrialGenerator, TrialOutcome logger = logging.getLogger(__name__) @@ -36,6 +38,31 @@ class AutoWaterParameters(BaseModel): ge=0, le=1, description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", + ) # TODO: Not implemented yet + + +class BiasThreshold(BaseModel): + upper: float = Field(default=0.7, ge=0, le=1, description="Absolute value of the upper bias threshold.") + lower: float = Field(default=0.3, ge=0, le=1, description="Absolute value of the lower bias threshold.") + + +class AntiBiasParameters(BaseModel): + threshold: BiasThreshold = Field( + default=BiasThreshold(), validate_default=True, description="Thresholds for bias correction intervention." + ) + intervention_interval: int = Field(default=10, ge=0, description="Trials between bias intervention.") + maximum_water_corrections: int = Field(default=5, ge=0, description="Number of water correction to attempt.") + reward_fraction: float = Field( + default=0.8, + ge=0, + le=1, + description="Fraction of full reward volume delivered during auto water (0=none, 1=full).", + ) # TODO: Not implemented yet + bias_window_length: int = Field(default=200, ge=0, description="Trials to calculate bias over.") + lickspout_offset_delta: float = Field( + default=0.05, + ge=0, + description="Distance (mm) to move the stage spouts by. This is a relative distance to the current value, not absolute.", ) @@ -84,7 +111,13 @@ class BlockBasedTrialGeneratorSpec(BaseTrialGeneratorSpecModel): autowater_parameters: Optional[AutoWaterParameters] = Field( default=AutoWaterParameters(), validate_default=True, - description="Auto water settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + description="Autowater settings. If set, free water is delivered when the animal exceeds the ignored or unrewarded trial thresholds.", + ) + + antibias_parameters: Optional[AntiBiasParameters] = Field( + default=AntiBiasParameters(), + validate_default=True, + description="Antibias settings. If set, trial generator will give water and move lickspouts to combat bias.", ) is_baiting: bool = Field(default=False, description="Whether uncollected rewards carry over to the next trial.") @@ -106,7 +139,9 @@ class BlockBasedTrialGenerator(ITrialGenerator, ABC): reward_history: Record of whether each trial resulted in a reward. is_left_baited: Whether the left port currently has a baited reward. is_right_baited: Whether the right port currently has a baited reward. - + trials_in_bias_intervention: trials elapsed since last bias intervention + water_corrections: number of water corrections applied to combat bias + bias: bias of session. Negative values correspond to left bias, positive right. """ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: @@ -117,12 +152,19 @@ def __init__(self, spec: BlockBasedTrialGeneratorSpec) -> None: """ self.spec = spec + self.outcome_history: list[TrialOutcome] = [] self.is_right_choice_history: list[bool | None] = [] self.reward_history: list[bool] = [] self.is_left_baited: bool = False self.is_right_baited: bool = False self.block: Block + # antibias parameters + self.trials_in_bias_intervention = 0 + self.water_corrections = 0 + self.bias: Optional[float] = None + self.total_lickspout_offset = 0 + def update(self, outcome: TrialOutcome | str): """Updates generator state from the previous trial outcome. Records choice and reward history and manages baiting state. Args: @@ -132,6 +174,7 @@ def update(self, outcome: TrialOutcome | str): if isinstance(outcome, str): outcome = TrialOutcome.model_validate_json(outcome) + self.outcome_history.append(outcome) self.is_right_choice_history.append(outcome.is_right_choice) self.reward_history.append(outcome.is_rewarded) @@ -146,6 +189,8 @@ def update(self, outcome: TrialOutcome | str): # trial ignored so current baiting state retained pass + self.bias = calculate_bias(outcomes=self.outcome_history) + def next(self) -> Trial | None: """Generates the next trial in the session. @@ -176,10 +221,21 @@ def next(self) -> Trial | None: self.is_right_baited = self.block.p_right_reward > random_numbers[1] or self.is_right_baited logger.debug("Right baited: %s" % self.is_right_baited) - # determine autowater is_auto_response_right = None + + # determine autowater if is_autowater := self._are_autowater_conditions_met(): is_auto_response_right = True if self.block.p_right_reward > self.block.p_left_reward else False + logger.debug("Delivering autowater: is_auto_response_right = %s" % is_auto_response_right) + + # determine bias correction. Overrides autowater + lickspout_offset_delta = 0 + if self._are_antibias_conditions_met(): + is_auto_response_right, lickspout_offset_delta = self._determine_antibias_intervention() + logger.debug( + "Performing bias intervention: is_auto_response_right = %s, lickspout_offset_delta = %s." + % (is_auto_response_right, lickspout_offset_delta) + ) return Trial( p_reward_left=1 if (self.is_left_baited or is_auto_response_right is False) else self.block.p_left_reward, @@ -188,6 +244,7 @@ def next(self) -> Trial | None: response_deadline_duration=self.spec.response_duration, quiescence_period_duration=quiescent, inter_trial_interval_duration=iti, + lickspout_offset_delta=lickspout_offset_delta, is_auto_response_right=is_auto_response_right, metadata=Metadata( p_reward_left=self.block.p_left_reward, @@ -196,6 +253,11 @@ def next(self) -> Trial | None: ), ) + def metrics(self) -> TrialMetrics: + """Return metrics of session at current state of the trial generator.""" + + return TrialMetrics(bias=self.bias) + def _are_autowater_conditions_met(self) -> bool: """Checks whether autowater should be given. @@ -203,7 +265,8 @@ def _are_autowater_conditions_met(self) -> bool: True if autowater conditions are met, False otherwise. """ - if self.spec.autowater_parameters is None: # autowater disabled + if self.spec.autowater_parameters is None: + logger.debug("Autowater not configured.") return False min_ignore = self.spec.autowater_parameters.min_ignored_trials @@ -211,14 +274,71 @@ def _are_autowater_conditions_met(self) -> bool: is_ignored = [choice is None for choice in self.is_right_choice_history] if len(is_ignored) > min_ignore and all(is_ignored[-min_ignore:]): + logger.debug("Past %s trials ignored." % min_ignore) return True is_unrewarded = [not reward for reward in self.reward_history] if len(is_unrewarded) > min_unreward and all(is_unrewarded[-min_unreward:]): + logger.debug("Past %s trials unrewarded." % min_unreward) return True return False + def _are_antibias_conditions_met(self) -> bool: + """Checks whether antibias conditions are met. + + Returns: + True if antibias conditions are met, False otherwise. + """ + + if self.spec.antibias_parameters is None: + logger.debug("Anitbias not configured.") + return False + + if self.trials_in_bias_intervention > self.spec.antibias_parameters.intervention_interval: + if self.bias <= self.spec.antibias_parameters.threshold.lower: + logger.debug("Bias calculated below threshold: %s." % self.bias) + return True + + if self.bias >= self.spec.antibias_parameters.threshold.upper: + logger.debug("Bias calculated above threshold: %s." % self.bias) + return True + + return False + + def _determine_antibias_intervention(self) -> tuple[bool | None, float]: + """Determine anitbias interventions to perform: give water or move lickspouts + + Returns: + Tuple dictating is_auto_response_right and lickspout_offset_delta of trial + """ + + is_right_autowater = None + lickspout_offset_delta = 0 + ab_delta = self.spec.antibias_parameters.lickspout_offset_delta + if abs(self.bias) >= self.spec.antibias_parameters.threshold.upper: + if self.water_corrections < self.spec.antibias_parameters.maximum_water_corrections: + logger.debug("Correcting bias with water.") + is_right_autowater = ( + True if self.bias < 0 else False + ) # - bias values corresponds to left, so give right and vice versa + self.water_corrections += 1 + else: + logger.debug("Correcting bias with lickspout offset.") + lickspout_offset_delta = ab_delta if self.bias < 0 else -ab_delta # + values move lickspout right + self.water_corrections = 0 + + elif ( + abs(self.bias) <= self.spec.antibias_parameters.threshold.lower and self.total_lickspout_offset != 0 + ): # bias below lower threshold, move back towards center + logger.debug("Moving lickspout back toward center.") + delta = min(ab_delta, abs(self.total_lickspout_offset)) + lickspout_offset_delta = -delta if self.total_lickspout_offset > 0 else delta + + self.total_lickspout_offset += lickspout_offset_delta + + return is_right_autowater, lickspout_offset_delta + @abstractmethod def _are_end_conditions_met(self) -> bool: """Checks whether the session should end. @@ -229,6 +349,7 @@ def _are_end_conditions_met(self) -> bool: """ pass + @abstractmethod def _generate_next_block(*args, **kwargs) -> Block: """Abstract method. Subclasses must implement their own block switching logic. diff --git a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py index 7f28f4d..ee0a234 100644 --- a/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py +++ b/src/aind_behavior_dynamic_foraging/task_logic/trial_models.py @@ -120,3 +120,11 @@ class TrialOutcome(BaseModel): description="Reports the choice made by the subject. True for right, False for left, None for no choice." ) is_rewarded: bool = Field(description="Indicates whether the subject received a reward on this trial.") + + +class TrialMetrics(BaseModel): + """Represents metrics of session""" + + bias: Optional[float] = Field( + default=None, description="Bias of session. Negative values correspond to left bias, positive right." + ) diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py new file mode 100644 index 0000000..ed9492e --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/__init__.py @@ -0,0 +1,3 @@ +from .calculate_bias import calculate_bias + +__all__ = ["calculate_bias"] diff --git a/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py new file mode 100644 index 0000000..e2767f8 --- /dev/null +++ b/src/aind_behavior_dynamic_foraging/task_logic/utils/calculate_bias.py @@ -0,0 +1,86 @@ +import logging +from typing import List + +import numpy as np +from sklearn.linear_model import LogisticRegressionCV + +from aind_behavior_dynamic_foraging.task_logic.trial_models import TrialOutcome + +logger = logging.getLogger(__name__) + + +def calculate_bias(outcomes: List[TrialOutcome]) -> float: + """Estimate the side bias of an animal using logistic regression on recent trial history. + + Fits a Su2022-style logistic regression model using rewarded and unrewarded choice + history as predictors. The intercept of the fitted model is returned as the bias, + representing the animal's baseline tendency to choose right independent of reward history. + + Parameters + ---------- + outcomes : List[TrialOutcome] + List of trial outcomes. Auto-response and ignored trials are excluded. + Only the most recent 200 trials are used. + + Returns + ------- + float + The logistic regression intercept, representing side bias. + Positive values indicate a bias toward right, negative toward left. + """ + + solver = "liblinear" + l1_ratios = (0,) + lag = 5 + cv = 10 + cs = 10 + + # reduce outcomes to last 200 + outcomes = outcomes[-200:] + + # exclude auto response and ignored trials + filtered = [t for t in outcomes if t.is_right_choice is not None and t.trial.is_auto_response_right is None] + + if len(filtered) <= lag: + logger.warning("Not enough choices to calculate bias.") + return np.nan + + is_right_choice_history = np.array([t.is_right_choice for t in filtered], dtype=float) + is_rewarded_history = np.array([t.is_rewarded for t in filtered], dtype=float) + + # transform to +-1 space for zero centered logistic regression + choice_signed = 2 * is_right_choice_history - 1 # left=0 → -1, right=1 → +1 + reward_signed = 2 * is_rewarded_history - 1 # unrewarded=0 → -1, rewarded=1 → +1 + + rewarded_choice = choice_signed * (reward_signed == 1) + unrewarded_choice = choice_signed * (reward_signed == -1) + + trial_length = len(is_rewarded_history) + x = np.zeros((trial_length - lag, 2 * lag)) + for i in range(lag, trial_length): + x[i - lag] = np.hstack([choice[i - lag : i] for choice in [rewarded_choice, unrewarded_choice]]) + + y = choice_signed[lag:] + + n_right_choice = np.sum(y == 1) + n_left_choice = np.sum(y == -1) + if min(n_right_choice, n_left_choice) < cv: + logger.warning( + "Not enough trials per class to fit logistic regression (need %d per class, got %d right, %d left).", + cv, + n_right_choice, + n_left_choice, + ) + return np.nan + + logistic_reg = LogisticRegressionCV( + solver=solver, + l1_ratios=l1_ratios, + Cs=cs, + cv=cv, + use_legacy_attributes=True, + ) + logistic_reg.fit(x, y) + + bias = logistic_reg.intercept_[0] + return bias diff --git a/src/main.bonsai b/src/main.bonsai index f64dd49..a5e1dbe 100644 --- a/src/main.bonsai +++ b/src/main.bonsai @@ -150,6 +150,9 @@ GlobalTrial + + GlobalTrialMetrics + TaskLogicParameters @@ -230,21 +233,21 @@ - - - - + + + - - + + - - - - - + + + + + + diff --git a/tests/trial_generators/test_block_based_trial_generator.py b/tests/trial_generators/test_block_based_trial_generator.py index 661439d..29c966e 100644 --- a/tests/trial_generators/test_block_based_trial_generator.py +++ b/tests/trial_generators/test_block_based_trial_generator.py @@ -5,11 +5,14 @@ import numpy as np from aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator import ( + AntiBiasParameters, + AutoWaterParameters, + BiasThreshold, Block, BlockBasedTrialGenerator, BlockBasedTrialGeneratorSpec, ) -from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial +from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome logging.basicConfig(level=logging.DEBUG) @@ -49,6 +52,252 @@ def test_next_returns_correct_reward_probs(self): self.assertEqual(trial.p_reward_right, self.generator.block.p_right_reward) +class TestAntiBiasBlockBasedTrialGenerator(unittest.TestCase): + def _patch_bias(self, bias_value: float) -> dict: + + return patch( + "aind_behavior_dynamic_foraging.task_logic.trial_generators.block_based_trial_generator.calculate_bias", + return_value=bias_value, + ) + + def _make_generator( + self, + bias: float, + trials_in_bias_intervention: int = 15, + water_corrections: int = 0, + maximum_water_corrections: int = 5, + bias_window_length: int = 5, + intervention_interval: int = 10, + total_offset: float = 0.0, + threshold: BiasThreshold = BiasThreshold(upper=0.7, lower=0.3), + ) -> ConcreteBlockBasedTrialGenerator: + ab = AntiBiasParameters( + maximum_water_corrections=maximum_water_corrections, + bias_window_length=bias_window_length, + intervention_interval=intervention_interval, + threshold=threshold, + ) + spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab) + gen = spec.create_generator() + gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10) + gen.total_lickspout_offset = total_offset + gen.bias = bias + gen.trials_in_bias_intervention = trials_in_bias_intervention + gen.water_corrections = water_corrections + gen.is_right_choice_history = [True] * 100 + gen.reward_history = [True] * 100 + return gen + + def test_returns_false_when_antibias_disabled(self): + """Antibias should never trigger when antibias_parameters is None.""" + spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=None) + gen = spec.create_generator() + self.assertFalse(gen._are_antibias_conditions_met()) + + def test_returns_false_before_intervention_interval(self): + """Condition should not trigger before the intervention interval is exceeded.""" + + gen = self._make_generator(bias=0.5, intervention_interval=10) + gen.trials_in_bias_intervention = 5 + self.assertFalse(gen._are_antibias_conditions_met()) + + def test_returns_false_when_bias_within_thresholds(self): + """No intervention when bias sits between lower and upper thresholds.""" + gen = self._make_generator( + bias=0.5, + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + bias_window_length=5, + ) + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [True, False] * 50 + gen.reward_history = [True] * 100 + result = gen._are_antibias_conditions_met() + + self.assertFalse(result) + + def test_returns_true_when_bias_above_upper_threshold(self): + """Intervention when bias is above threshold""" + gen = self._make_generator( + bias=0.9, + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + bias_window_length=5, + ) + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [True] * 100 + gen.reward_history = [True] * 100 + result = gen._are_antibias_conditions_met() + + self.assertTrue(result) + + def test_returns_true_when_bias_below_lower_threshold(self): + """Intervention when bias is below threshold""" + gen = self._make_generator( + bias=0.2, + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + bias_window_length=5, + ) + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [False] * 100 + gen.reward_history = [False] * 100 + result = gen._are_antibias_conditions_met() + + self.assertTrue(result) + + def test_bias_stored_on_generator_after_check(self): + """The computed bias value should be saved on the generator.""" + gen = self._make_generator( + bias=0, + intervention_interval=10, + bias_window_length=5, + ) + + with self._patch_bias(0.42): + gen.update(TrialOutcome(is_rewarded=True, is_right_choice=True, trial=Trial())) + + self.assertAlmostEqual(gen.bias, 0.42) + + def test_gives_right_water_on_left_bias(self): + """Negative bias (left bias) → give right water.""" + gen = self._make_generator(bias=-0.9, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertTrue(is_right) + self.assertEqual(delta, 0.0) + + def test_gives_left_water_on_right_bias(self): + """Positive bias (right bias) → give left water.""" + gen = self._make_generator(bias=0.9, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertFalse(is_right) + self.assertEqual(delta, 0.0) + + def test_water_corrections_counter_increments(self): + gen = self._make_generator(bias=-0.9, water_corrections=2, maximum_water_corrections=5) + gen._determine_antibias_intervention() + self.assertEqual(gen.water_corrections, 3) + + def test_switches_to_lickspout_after_max_corrections_left_bias(self): + """After exhausting water corrections, move lickspout right (combat left bias).""" + gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertIsNone(is_right) + self.assertGreater(delta, 0) + + def test_switches_to_lickspout_after_max_corrections_right_bias(self): + """After exhausting water corrections, move lickspout left (combat right bias).""" + gen = self._make_generator(bias=0.9, water_corrections=5, maximum_water_corrections=5) + is_right, delta = gen._determine_antibias_intervention() + self.assertIsNone(is_right) + self.assertLess(delta, 0) + + def test_water_corrections_reset_after_lickspout_move(self): + gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5) + gen._determine_antibias_intervention() + self.assertEqual(gen.water_corrections, 0) + + # #### Test lickspout centering #### + + def test_no_centering_when_offset_is_zero(self): + """No correction when already centered, even if bias drops below lower threshold.""" + gen = self._make_generator( + bias=0.1, + total_offset=0.0, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertEqual(delta, 0.0) + + def test_centering_moves_toward_zero_from_positive_offset(self): + """Positive offset + low bias → negative delta (move back left).""" + gen = self._make_generator( + bias=0.1, + total_offset=1.0, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertLess(delta, 0) + + def test_centering_moves_toward_zero_from_negative_offset(self): + """Negative offset + low bias → positive delta (move back right).""" + gen = self._make_generator( + bias=0.1, + total_offset=-1.0, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertGreater(delta, 0) + + def test_centering_step_capped_at_offset_magnitude(self): + """Centering delta should not overshoot: capped at min(0.5, |offset|).""" + gen = self._make_generator( + bias=0.1, + total_offset=0.2, + threshold=BiasThreshold(upper=0.7, lower=0.3), + ) + _, delta = gen._determine_antibias_intervention() + self.assertLessEqual(abs(delta), 0.2) + + def test_total_lickspout_offset_updated_after_move(self): + """total_lickspout_offset should accumulate the delta applied.""" + gen = self._make_generator(bias=-0.9, water_corrections=5, maximum_water_corrections=5, total_offset=0.0) + _, delta = gen._determine_antibias_intervention() + self.assertAlmostEqual(gen.total_lickspout_offset, delta) + + #### Test next #### + + def test_next_gives_right_autowater_on_left_bias(self): + gen = self._make_generator(bias=-0.9) + trial = gen.next() + self.assertIsNotNone(trial) + self.assertTrue(trial.is_auto_response_right) + + def test_next_gives_left_autowater_on_right_bias(self): + gen = self._make_generator(bias=0.9) + trial = gen.next() + self.assertIsNotNone(trial) + self.assertFalse(trial.is_auto_response_right) + + def test_next_no_antibias_when_below_interval(self): + """No antibias effect when trials_in_bias_intervention has not exceeded interval.""" + gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) + trial = gen.next() + self.assertIsNone(trial.is_auto_response_right) + + def test_next_antibias_overrides_autowater(self): + """When both autowater and antibias conditions are met, antibias takes precedence.""" + ab = AntiBiasParameters( + intervention_interval=10, + threshold=BiasThreshold(upper=0.7, lower=0.3), + maximum_water_corrections=5, + bias_window_length=5, + ) + aw = AutoWaterParameters(min_ignored_trials=1, min_unrewarded_trials=1, reward_fraction=0.8) + spec = ConcreteBlockBasedTrialGeneratorSpec(antibias_parameters=ab, autowater_parameters=aw) + gen = spec.create_generator() + gen.block = Block(p_left_reward=0.2, p_right_reward=0.8, left_length=10, right_length=10) + gen.bias = -0.9 + gen.trials_in_bias_intervention = 15 + gen.is_right_choice_history = [None] # ignored trial → autowater would also fire + gen.reward_history = [False] + trial = gen.next() + + # Antibias (left bias → give right water) should win + self.assertTrue(trial.is_auto_response_right) + + def test_next_lickspout_delta_nonzero_after_corrections_exhausted(self): + """After max water corrections, next() should produce a nonzero lickspout delta.""" + gen = self._make_generator(bias=-0.9, water_corrections=5) + trial = gen.next() + self.assertEqual(trial.lickspout_offset_delta, 0.05) + + def test_next_no_lickspout_delta_when_antibias_not_triggered(self): + gen = self._make_generator(bias=-0.9, trials_in_bias_intervention=5) + trial = gen.next() + self.assertEqual(trial.lickspout_offset_delta, 0) + + class TestBlockBaseBaitingTrialGenerator(unittest.TestCase): def setUp(self): self.spec = ConcreteBlockBasedTrialGeneratorSpec(is_baiting=True) diff --git a/tests/trial_generators/test_calculate_bias.py b/tests/trial_generators/test_calculate_bias.py new file mode 100644 index 0000000..c3d23a4 --- /dev/null +++ b/tests/trial_generators/test_calculate_bias.py @@ -0,0 +1,107 @@ +import math +import time +import unittest + +import numpy as np + +from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome +from aind_behavior_dynamic_foraging.task_logic.utils.calculate_bias import calculate_bias + + +def make_outcomes(n, right_prob=0.5, reward_prob=0.5, **trial_kwargs) -> list[TrialOutcome]: + outcomes = [] + for _ in range(n): + is_right = bool(np.random.rand() < right_prob) + is_rewarded = bool(np.random.rand() < reward_prob) + outcomes.append(TrialOutcome(is_right_choice=is_right, is_rewarded=is_rewarded, trial=Trial(**trial_kwargs))) + return outcomes + + +class TestCalculateBias(unittest.TestCase): + def setUp(self): + np.random.seed(42) + + def test_right_bias_is_positive(self): + """Always choosing right with alternating rewards should produce positive bias.""" + outcomes = make_outcomes(100, 0.9, 0.5) + bias = calculate_bias(outcomes) + self.assertGreater(bias, 0) + + def test_left_bias_is_negative(self): + """Always choosing left with alternating rewards should produce negative bias.""" + outcomes = make_outcomes(100, 0.1, 0.5) + bias = calculate_bias(outcomes) + self.assertLess(bias, 0) + + def test_unbiased_near_zero(self): + """Alternating choices should produce bias near zero.""" + outcomes = make_outcomes(100) + bias = calculate_bias(outcomes) + self.assertAlmostEqual(bias, 0, delta=1.0) + + def test_uniform_choices_returns_nan(self): + """All same choice and reward should return nan since bias is undefined.""" + outcomes = make_outcomes(100, 1, 1) + bias = calculate_bias(outcomes) + self.assertTrue(math.isnan(bias)) + + def test_ignored_trials_excluded(self): + """Adding ignored trials should not change the result.""" + valid = make_outcomes(80) + ignored = [TrialOutcome(is_right_choice=None, is_rewarded=False, trial=Trial()) for _ in range(20)] + self.assertEqual(calculate_bias(valid + ignored), calculate_bias(valid)) + + def test_auto_response_trials_excluded(self): + """Adding auto response trials should not change the result.""" + valid = make_outcomes(80) + auto = make_outcomes(20, 1, 1, is_auto_response_right=True) + self.assertEqual(calculate_bias(valid + auto), calculate_bias(valid)) + + def test_only_uses_last_200_trials(self): + """Trials beyond the last 200 should not affect the result.""" + old = make_outcomes(100, 0) + recent = make_outcomes(200, 0.9) + self.assertEqual(calculate_bias(old + recent), calculate_bias(recent)) + + def test_too_few_trials_returns_nan(self): + """Fewer trials than lag should return nan since regression cannot be fit.""" + outcomes = make_outcomes(5) + bias = calculate_bias(outcomes) + self.assertTrue(math.isnan(bias)) + + +TIME_LIMIT_MS = 170 +N_REPEATS = 20 + + +class TestCalculateBiasTiming(unittest.TestCase): + def _assert_within_time_limit(self, n_trials: int): + times = [] + for _ in range(N_REPEATS): + outcomes = make_outcomes(n_trials) + t0 = time.perf_counter() + calculate_bias(outcomes) + times.append((time.perf_counter() - t0) * 1000) + mean_ms = float(np.mean(times)) + self.assertLess( + mean_ms, + TIME_LIMIT_MS, + f"calculate_bias with {n_trials} trials too slow: {mean_ms:.1f}ms (limit {TIME_LIMIT_MS}ms)", + ) + + def test_timing_50_trials(self): + self._assert_within_time_limit(50) + + def test_timing_100_trials(self): + self._assert_within_time_limit(100) + + def test_timing_200_trials(self): + self._assert_within_time_limit(200) + + def test_timing_1000_trials(self): + # should be similar to 200 since only last 200 trials evaluated + self._assert_within_time_limit(1000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/trial_generators/test_coupled_trial_generator.py b/tests/trial_generators/test_coupled_trial_generator.py index e4cc59e..cae55c8 100644 --- a/tests/trial_generators/test_coupled_trial_generator.py +++ b/tests/trial_generators/test_coupled_trial_generator.py @@ -6,8 +6,7 @@ from aind_behavior_dynamic_foraging.task_logic.trial_generators import CoupledTrialGeneratorSpec from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome - -from .util import simulate_response +from tests.trial_generators.util import simulate_response logging.basicConfig(level=logging.DEBUG) diff --git a/tests/trial_generators/test_uncoupled_trial_generator.py b/tests/trial_generators/test_uncoupled_trial_generator.py index 2b21422..d952172 100644 --- a/tests/trial_generators/test_uncoupled_trial_generator.py +++ b/tests/trial_generators/test_uncoupled_trial_generator.py @@ -10,8 +10,7 @@ UncoupledTrialGeneratorSpec, ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial - -from .util import simulate_response +from tests.trial_generators.util import simulate_response logging.basicConfig(level=logging.DEBUG) diff --git a/tests/trial_generators/test_warmup_trial_generator.py b/tests/trial_generators/test_warmup_trial_generator.py index e478685..92a89eb 100644 --- a/tests/trial_generators/test_warmup_trial_generator.py +++ b/tests/trial_generators/test_warmup_trial_generator.py @@ -6,8 +6,7 @@ CoupledWarmupTrialGeneratorSpec, ) from aind_behavior_dynamic_foraging.task_logic.trial_models import Trial, TrialOutcome - -from .util import simulate_response +from tests.trial_generators.util import simulate_response def make_outcome(is_right_choice: bool | None, is_rewarded: bool) -> TrialOutcome: diff --git a/uv.lock b/uv.lock index 423b991..fcc655d 100644 --- a/uv.lock +++ b/uv.lock @@ -52,6 +52,7 @@ source = { editable = "." } dependencies = [ { name = "aind-behavior-services" }, { name = "pydantic-settings" }, + { name = "scikit-learn" }, ] [package.optional-dependencies] @@ -81,6 +82,7 @@ requires-dist = [ { name = "aind-behavior-services", specifier = ">=0.13.5" }, { name = "contraqctor", marker = "extra == 'data'", specifier = ">=0.5.3" }, { name = "pydantic-settings" }, + { name = "scikit-learn", specifier = ">=1.8.0" }, ] provides-extras = ["data"] @@ -1086,6 +1088,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joblib" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, +] + [[package]] name = "jsonpointer" version = "3.1.1" @@ -2506,6 +2517,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, ] +[[package]] +name = "scikit-learn" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/d4/40988bf3b8e34feec1d0e6a051446b1f66225f8529b9309becaeef62b6c4/scikit_learn-1.8.0.tar.gz", hash = "sha256:9bccbb3b40e3de10351f8f5068e105d0f4083b1a65fa07b6634fbc401a6287fd", size = 7335585, upload-time = "2025-12-10T07:08:53.618Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/92/53ea2181da8ac6bf27170191028aee7251f8f841f8d3edbfdcaf2008fde9/scikit_learn-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:146b4d36f800c013d267b29168813f7a03a43ecd2895d04861f1240b564421da", size = 8595835, upload-time = "2025-12-10T07:07:39.385Z" }, + { url = "https://files.pythonhosted.org/packages/01/18/d154dc1638803adf987910cdd07097d9c526663a55666a97c124d09fb96a/scikit_learn-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f984ca4b14914e6b4094c5d52a32ea16b49832c03bd17a110f004db3c223e8e1", size = 8080381, upload-time = "2025-12-10T07:07:41.93Z" }, + { url = "https://files.pythonhosted.org/packages/8a/44/226142fcb7b7101e64fdee5f49dbe6288d4c7af8abf593237b70fca080a4/scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5e30adb87f0cc81c7690a84f7932dd66be5bac57cfe16b91cb9151683a4a2d3b", size = 8799632, upload-time = "2025-12-10T07:07:43.899Z" }, + { url = "https://files.pythonhosted.org/packages/36/4d/4a67f30778a45d542bbea5db2dbfa1e9e100bf9ba64aefe34215ba9f11f6/scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ada8121bcb4dac28d930febc791a69f7cb1673c8495e5eee274190b73a4559c1", size = 9103788, upload-time = "2025-12-10T07:07:45.982Z" }, + { url = "https://files.pythonhosted.org/packages/89/3c/45c352094cfa60050bcbb967b1faf246b22e93cb459f2f907b600f2ceda5/scikit_learn-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:c57b1b610bd1f40ba43970e11ce62821c2e6569e4d74023db19c6b26f246cb3b", size = 8081706, upload-time = "2025-12-10T07:07:48.111Z" }, + { url = "https://files.pythonhosted.org/packages/3d/46/5416595bb395757f754feb20c3d776553a386b661658fb21b7c814e89efe/scikit_learn-1.8.0-cp311-cp311-win_arm64.whl", hash = "sha256:2838551e011a64e3053ad7618dda9310175f7515f1742fa2d756f7c874c05961", size = 7688451, upload-time = "2025-12-10T07:07:49.873Z" }, + { url = "https://files.pythonhosted.org/packages/90/74/e6a7cc4b820e95cc38cf36cd74d5aa2b42e8ffc2d21fe5a9a9c45c1c7630/scikit_learn-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5fb63362b5a7ddab88e52b6dbb47dac3fd7dafeee740dc6c8d8a446ddedade8e", size = 8548242, upload-time = "2025-12-10T07:07:51.568Z" }, + { url = "https://files.pythonhosted.org/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5025ce924beccb28298246e589c691fe1b8c1c96507e6d27d12c5fadd85bfd76", size = 8079075, upload-time = "2025-12-10T07:07:53.697Z" }, + { url = "https://files.pythonhosted.org/packages/dd/47/f187b4636ff80cc63f21cd40b7b2d177134acaa10f6bb73746130ee8c2e5/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4496bb2cf7a43ce1a2d7524a79e40bc5da45cf598dbf9545b7e8316ccba47bb4", size = 8660492, upload-time = "2025-12-10T07:07:55.574Z" }, + { url = "https://files.pythonhosted.org/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0bcfe4d0d14aec44921545fd2af2338c7471de9cb701f1da4c9d85906ab847a", size = 8931904, upload-time = "2025-12-10T07:07:57.666Z" }, + { url = "https://files.pythonhosted.org/packages/9f/c4/0ab22726a04ede56f689476b760f98f8f46607caecff993017ac1b64aa5d/scikit_learn-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:35c007dedb2ffe38fe3ee7d201ebac4a2deccd2408e8621d53067733e3c74809", size = 8019359, upload-time = "2025-12-10T07:07:59.838Z" }, + { url = "https://files.pythonhosted.org/packages/24/90/344a67811cfd561d7335c1b96ca21455e7e472d281c3c279c4d3f2300236/scikit_learn-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:8c497fff237d7b4e07e9ef1a640887fa4fb765647f86fbe00f969ff6280ce2bb", size = 7641898, upload-time = "2025-12-10T07:08:01.36Z" }, + { url = "https://files.pythonhosted.org/packages/03/aa/e22e0768512ce9255eba34775be2e85c2048da73da1193e841707f8f039c/scikit_learn-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0d6ae97234d5d7079dc0040990a6f7aeb97cb7fa7e8945f1999a429b23569e0a", size = 8513770, upload-time = "2025-12-10T07:08:03.251Z" }, + { url = "https://files.pythonhosted.org/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:edec98c5e7c128328124a029bceb09eda2d526997780fef8d65e9a69eead963e", size = 8044458, upload-time = "2025-12-10T07:08:05.336Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5a/3f1caed8765f33eabb723596666da4ebbf43d11e96550fb18bdec42b467b/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74b66d8689d52ed04c271e1329f0c61635bcaf5b926db9b12d58914cdc01fe57", size = 8610341, upload-time = "2025-12-10T07:08:07.732Z" }, + { url = "https://files.pythonhosted.org/packages/38/cf/06896db3f71c75902a8e9943b444a56e727418f6b4b4a90c98c934f51ed4/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8fdf95767f989b0cfedb85f7ed8ca215d4be728031f56ff5a519ee1e3276dc2e", size = 8900022, upload-time = "2025-12-10T07:08:09.862Z" }, + { url = "https://files.pythonhosted.org/packages/1c/f9/9b7563caf3ec8873e17a31401858efab6b39a882daf6c1bfa88879c0aa11/scikit_learn-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:2de443b9373b3b615aec1bb57f9baa6bb3a9bd093f1269ba95c17d870422b271", size = 7989409, upload-time = "2025-12-10T07:08:12.028Z" }, + { url = "https://files.pythonhosted.org/packages/49/bd/1f4001503650e72c4f6009ac0c4413cb17d2d601cef6f71c0453da2732fc/scikit_learn-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:eddde82a035681427cbedded4e6eff5e57fa59216c2e3e90b10b19ab1d0a65c3", size = 7619760, upload-time = "2025-12-10T07:08:13.688Z" }, + { url = "https://files.pythonhosted.org/packages/d2/7d/a630359fc9dcc95496588c8d8e3245cc8fd81980251079bc09c70d41d951/scikit_learn-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:7cc267b6108f0a1499a734167282c00c4ebf61328566b55ef262d48e9849c735", size = 8826045, upload-time = "2025-12-10T07:08:15.215Z" }, + { url = "https://files.pythonhosted.org/packages/cc/56/a0c86f6930cfcd1c7054a2bc417e26960bb88d32444fe7f71d5c2cfae891/scikit_learn-1.8.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:fe1c011a640a9f0791146011dfd3c7d9669785f9fed2b2a5f9e207536cf5c2fd", size = 8420324, upload-time = "2025-12-10T07:08:17.561Z" }, + { url = "https://files.pythonhosted.org/packages/46/1e/05962ea1cebc1cf3876667ecb14c283ef755bf409993c5946ade3b77e303/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:72358cce49465d140cc4e7792015bb1f0296a9742d5622c67e31399b75468b9e", size = 8680651, upload-time = "2025-12-10T07:08:19.952Z" }, + { url = "https://files.pythonhosted.org/packages/fe/56/a85473cd75f200c9759e3a5f0bcab2d116c92a8a02ee08ccd73b870f8bb4/scikit_learn-1.8.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:80832434a6cc114f5219211eec13dcbc16c2bac0e31ef64c6d346cde3cf054cb", size = 8925045, upload-time = "2025-12-10T07:08:22.11Z" }, + { url = "https://files.pythonhosted.org/packages/cc/b7/64d8cfa896c64435ae57f4917a548d7ac7a44762ff9802f75a79b77cb633/scikit_learn-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ee787491dbfe082d9c3013f01f5991658b0f38aa8177e4cd4bf434c58f551702", size = 8507994, upload-time = "2025-12-10T07:08:23.943Z" }, + { url = "https://files.pythonhosted.org/packages/5e/37/e192ea709551799379958b4c4771ec507347027bb7c942662c7fbeba31cb/scikit_learn-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf97c10a3f5a7543f9b88cbf488d33d175e9146115a451ae34568597ba33dcde", size = 7869518, upload-time = "2025-12-10T07:08:25.71Z" }, + { url = "https://files.pythonhosted.org/packages/24/05/1af2c186174cc92dcab2233f327336058c077d38f6fe2aceb08e6ab4d509/scikit_learn-1.8.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c22a2da7a198c28dd1a6e1136f19c830beab7fdca5b3e5c8bba8394f8a5c45b3", size = 8528667, upload-time = "2025-12-10T07:08:27.541Z" }, + { url = "https://files.pythonhosted.org/packages/a8/25/01c0af38fe969473fb292bba9dc2b8f9b451f3112ff242c647fee3d0dfe7/scikit_learn-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:6b595b07a03069a2b1740dc08c2299993850ea81cce4fe19b2421e0c970de6b7", size = 8066524, upload-time = "2025-12-10T07:08:29.822Z" }, + { url = "https://files.pythonhosted.org/packages/be/ce/a0623350aa0b68647333940ee46fe45086c6060ec604874e38e9ab7d8e6c/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:29ffc74089f3d5e87dfca4c2c8450f88bdc61b0fc6ed5d267f3988f19a1309f6", size = 8657133, upload-time = "2025-12-10T07:08:31.865Z" }, + { url = "https://files.pythonhosted.org/packages/b8/cb/861b41341d6f1245e6ca80b1c1a8c4dfce43255b03df034429089ca2a2c5/scikit_learn-1.8.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb65db5d7531bccf3a4f6bec3462223bea71384e2cda41da0f10b7c292b9e7c4", size = 8923223, upload-time = "2025-12-10T07:08:34.166Z" }, + { url = "https://files.pythonhosted.org/packages/76/18/a8def8f91b18cd1ba6e05dbe02540168cb24d47e8dcf69e8d00b7da42a08/scikit_learn-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:56079a99c20d230e873ea40753102102734c5953366972a71d5cb39a32bc40c6", size = 8096518, upload-time = "2025-12-10T07:08:36.339Z" }, + { url = "https://files.pythonhosted.org/packages/d1/77/482076a678458307f0deb44e29891d6022617b2a64c840c725495bee343f/scikit_learn-1.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:3bad7565bc9cf37ce19a7c0d107742b320c1285df7aab1a6e2d28780df167242", size = 7754546, upload-time = "2025-12-10T07:08:38.128Z" }, + { url = "https://files.pythonhosted.org/packages/2d/d1/ef294ca754826daa043b2a104e59960abfab4cf653891037d19dd5b6f3cf/scikit_learn-1.8.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:4511be56637e46c25721e83d1a9cea9614e7badc7040c4d573d75fbe257d6fd7", size = 8848305, upload-time = "2025-12-10T07:08:41.013Z" }, + { url = "https://files.pythonhosted.org/packages/5b/e2/b1f8b05138ee813b8e1a4149f2f0d289547e60851fd1bb268886915adbda/scikit_learn-1.8.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:a69525355a641bf8ef136a7fa447672fb54fe8d60cab5538d9eb7c6438543fb9", size = 8432257, upload-time = "2025-12-10T07:08:42.873Z" }, + { url = "https://files.pythonhosted.org/packages/26/11/c32b2138a85dcb0c99f6afd13a70a951bfdff8a6ab42d8160522542fb647/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c2656924ec73e5939c76ac4c8b026fc203b83d8900362eb2599d8aee80e4880f", size = 8678673, upload-time = "2025-12-10T07:08:45.362Z" }, + { url = "https://files.pythonhosted.org/packages/c7/57/51f2384575bdec454f4fe4e7a919d696c9ebce914590abf3e52d47607ab8/scikit_learn-1.8.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15fc3b5d19cc2be65404786857f2e13c70c83dd4782676dd6814e3b89dc8f5b9", size = 8922467, upload-time = "2025-12-10T07:08:47.408Z" }, + { url = "https://files.pythonhosted.org/packages/35/4d/748c9e2872637a57981a04adc038dacaa16ba8ca887b23e34953f0b3f742/scikit_learn-1.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:00d6f1d66fbcf4eba6e356e1420d33cc06c70a45bb1363cd6f6a8e4ebbbdece2", size = 8774395, upload-time = "2025-12-10T07:08:49.337Z" }, + { url = "https://files.pythonhosted.org/packages/60/22/d7b2ebe4704a5e50790ba089d5c2ae308ab6bb852719e6c3bd4f04c3a363/scikit_learn-1.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f28dd15c6bb0b66ba09728cf09fd8736c304be29409bd8445a080c1280619e8c", size = 8002647, upload-time = "2025-12-10T07:08:51.601Z" }, +] + [[package]] name = "scipy" version = "1.17.1" @@ -2817,6 +2878,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072, upload-time = "2024-07-29T01:10:08.203Z" }, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, +] + [[package]] name = "tomli" version = "2.4.0"