Skip to content

Commit d306bc2

Browse files
committed
Added support in KF for estimating state and observation offset parameters
1 parent e326a85 commit d306bc2

15 files changed

Lines changed: 741 additions & 606 deletions

src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ namespace Bonsai.ML.Lds.Torch;
1717
[TypeConverter(typeof(TensorOperatorConverter))]
1818
public class CreateKalmanFilter : IScalarTypeProvider
1919
{
20+
private Tensor _transitionMatrix;
21+
private Tensor _measurementFunction;
22+
private Tensor _processNoiseVariance;
23+
private Tensor _measurementNoiseVariance;
24+
private Tensor _initialMean;
25+
private Tensor _initialCovariance;
26+
private Tensor _stateOffset;
27+
private Tensor _observationOffset;
28+
2029
/// <inheritdoc/>
2130
[Description("The data type of the tensor elements.")]
2231
[TypeConverter(typeof(ScalarTypeConverter))]
@@ -38,7 +47,7 @@ public class CreateKalmanFilter : IScalarTypeProvider
3847
/// The number of observations in the Kalman filter model.
3948
/// </summary>
4049
public int? NumObservations { get; set; } = null;
41-
50+
4251
/// <summary>
4352
/// The state transition matrix.
4453
/// </summary>
@@ -177,12 +186,52 @@ public string InitialCovarianceXml
177186
set => _initialCovariance = TensorConverter.ConvertFromString(value, Type);
178187
}
179188

180-
private Tensor _transitionMatrix;
181-
private Tensor _measurementFunction;
182-
private Tensor _processNoiseVariance;
183-
private Tensor _measurementNoiseVariance;
184-
private Tensor _initialMean;
185-
private Tensor _initialCovariance;
189+
/// <summary>
190+
/// The state offset.
191+
/// </summary>
192+
[XmlIgnore]
193+
[TypeConverter(typeof(TensorConverter))]
194+
public Tensor StateOffset
195+
{
196+
get => _stateOffset;
197+
set => _stateOffset = value;
198+
}
199+
200+
/// <summary>
201+
/// The XML string representation of the state offset for serialization.
202+
/// </summary>
203+
[Browsable(false)]
204+
[XmlElement(nameof(StateOffset))]
205+
[EditorBrowsable(EditorBrowsableState.Never)]
206+
public string StateOffsetXml
207+
{
208+
get => TensorConverter.ConvertToString(_stateOffset, Type);
209+
set => _stateOffset = TensorConverter.ConvertFromString(value, Type);
210+
}
211+
212+
/// <summary>
213+
/// The observation offset.
214+
/// </summary>
215+
[XmlIgnore]
216+
[TypeConverter(typeof(TensorConverter))]
217+
public Tensor ObservationOffset
218+
{
219+
get => _observationOffset;
220+
set => _observationOffset = value;
221+
}
222+
223+
/// <summary>
224+
/// The XML string representation of the observation offset for serialization.
225+
/// </summary>
226+
[Browsable(false)]
227+
[XmlElement(nameof(ObservationOffset))]
228+
[EditorBrowsable(EditorBrowsableState.Never)]
229+
public string ObservationOffsetXml
230+
{
231+
get => TensorConverter.ConvertToString(_observationOffset, Type);
232+
set => _observationOffset = TensorConverter.ConvertFromString(value, Type);
233+
}
234+
186235

187236
/// <summary>
188237
/// Creates a Kalman filter model using the properties of this class.
@@ -198,6 +247,8 @@ public IObservable<KalmanFilter> Process()
198247
measurementNoiseVariance: MeasurementNoiseVariance,
199248
initialMean: InitialMean,
200249
initialCovariance: InitialCovariance,
250+
stateOffset: StateOffset,
251+
observationOffset: ObservationOffset,
201252
device: Device,
202253
scalarType: Type
203254
));
@@ -211,9 +262,7 @@ public IObservable<KalmanFilter> Process(IObservable<KalmanFilterParameters> sou
211262
return source.Select(parameters =>
212263
{
213264
return new KalmanFilter(
214-
parameters: parameters,
215-
device: Device,
216-
scalarType: Type
265+
parameters: parameters
217266
);
218267
});
219268
}

src/Bonsai.ML.Lds.Torch/CreateKalmanFilterParameters.cs

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ namespace Bonsai.ML.Lds.Torch;
1818
[TypeConverter(typeof(TensorOperatorConverter))]
1919
public class CreateKalmanFilterParameters : IScalarTypeProvider
2020
{
21+
private Tensor _transitionMatrix = null;
22+
private Tensor _measurementFunction = null;
23+
private Tensor _processNoiseCovariance = null;
24+
private Tensor _measurementNoiseCovariance = null;
25+
private Tensor _initialMean = null;
26+
private Tensor _initialCovariance = null;
27+
private Tensor _stateOffset = null;
28+
private Tensor _observationOffset = null;
29+
2130
/// <inheritdoc/>
2231
[Description("The data type of the tensor elements.")]
2332
[TypeConverter(typeof(ScalarTypeConverter))]
@@ -178,20 +187,58 @@ public string InitialCovarianceXml
178187
set => _initialCovariance = TensorConverter.ConvertFromString(value, Type);
179188
}
180189

181-
private Tensor _transitionMatrix = null;
182-
private Tensor _measurementFunction = null;
183-
private Tensor _processNoiseCovariance = null;
184-
private Tensor _measurementNoiseCovariance = null;
185-
private Tensor _initialMean = null;
186-
private Tensor _initialCovariance = null;
190+
/// <summary>
191+
/// The state offset.
192+
/// </summary>
193+
[XmlIgnore]
194+
[TypeConverter(typeof(TensorConverter))]
195+
public Tensor StateOffset
196+
{
197+
get => _stateOffset;
198+
set => _stateOffset = value;
199+
}
187200

201+
/// <summary>
202+
/// The XML string representation of the state offset for serialization.
203+
/// </summary>
204+
[Browsable(false)]
205+
[XmlElement(nameof(StateOffset))]
206+
[EditorBrowsable(EditorBrowsableState.Never)]
207+
public string StateOffsetXml
208+
{
209+
get => TensorConverter.ConvertToString(_stateOffset, Type);
210+
set => _stateOffset = TensorConverter.ConvertFromString(value, Type);
211+
}
212+
213+
/// <summary>
214+
/// The observation offset.
215+
/// </summary>
216+
[XmlIgnore]
217+
[TypeConverter(typeof(TensorConverter))]
218+
public Tensor ObservationOffset
219+
{
220+
get => _observationOffset;
221+
set => _observationOffset = value;
222+
}
223+
224+
/// <summary>
225+
/// The XML string representation of the observation offset for serialization.
226+
/// </summary>
227+
[Browsable(false)]
228+
[XmlElement(nameof(ObservationOffset))]
229+
[EditorBrowsable(EditorBrowsableState.Never)]
230+
public string ObservationOffsetXml
231+
{
232+
get => TensorConverter.ConvertToString(_observationOffset, Type);
233+
set => _observationOffset = TensorConverter.ConvertFromString(value, Type);
234+
}
188235

189236
/// <summary>
190237
/// Creates parameters for a Kalman filter model using the properties of this class.
191238
/// </summary>
192239
public IObservable<KalmanFilterParameters> Process()
193240
{
194-
return Observable.Return(KalmanFilterParameters.Initialize(
241+
return Observable.Return(new KalmanFilterParameters(
195242
numStates: NumStates,
196243
numObservations: NumObservations,
197244
transitionMatrix: _transitionMatrix,
@@ -200,6 +247,8 @@ public IObservable<KalmanFilterParameters> Process()
200247
measurementNoiseCovariance: _measurementNoiseCovariance,
201248
initialMean: _initialMean,
202249
initialCovariance: _initialCovariance,
250+
stateOffset: _stateOffset,
251+
observationOffset: _observationOffset,
203252
scalarType: Type,
204253
device: Device
205254
));
@@ -212,7 +261,7 @@ public IObservable<KalmanFilterParameters> Process<T>(IObservable<T> source)
212261
{
213262
return source.Select(_ =>
214263
{
215-
return KalmanFilterParameters.Initialize(
264+
return new KalmanFilterParameters(
216265
numStates: NumStates,
217266
numObservations: NumObservations,
218267
transitionMatrix: _transitionMatrix,
@@ -221,9 +270,11 @@ public IObservable<KalmanFilterParameters> Process<T>(IObservable<T> source)
221270
measurementNoiseCovariance: _measurementNoiseCovariance,
222271
initialMean: _initialMean,
223272
initialCovariance: _initialCovariance,
273+
stateOffset: _stateOffset,
274+
observationOffset: _observationOffset,
224275
scalarType: Type,
225276
device: Device
226277
);
227278
});
228279
}
229-
}
280+
}

src/Bonsai.ML.Lds.Torch/CreateLinearDynamicalSystemState.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ namespace Bonsai.ML.Lds.Torch;
1717
[TypeConverter(typeof(TensorOperatorConverter))]
1818
public class CreateLinearDynamicalSystemState : IScalarTypeProvider
1919
{
20+
private Tensor _mean = null;
21+
private Tensor _covariance = null;
22+
2023
/// <inheritdoc/>
2124
[Description("The data type of the tensor elements.")]
2225
[TypeConverter(typeof(ScalarTypeConverter))]
@@ -75,9 +78,6 @@ public string CovarianceXml
7578
set => _covariance = TensorConverter.ConvertFromString(value, Type);
7679
}
7780

78-
private Tensor _mean = null;
79-
private Tensor _covariance = null;
80-
8181
/// <summary>
8282
/// Creates an observable sequence and emits the state for a linear gaussian dynamical system.
8383
/// </summary>
@@ -123,4 +123,4 @@ public IObservable<LinearDynamicalSystemState> Process(IObservable<Tuple<Tensor,
123123
return new LinearDynamicalSystemState(input.Item1, input.Item2);
124124
});
125125
}
126-
}
126+
}

src/Bonsai.ML.Lds.Torch/ExpectationMaximization.cs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ namespace Bonsai.ML.Lds.Torch;
1919
[WorkflowElementCategory(ElementCategory.Combinator)]
2020
public class ExpectationMaximization
2121
{
22+
private int _maxIterations = 10;
23+
private double _tolerance = 1e-4;
24+
private bool _verbose = true;
25+
2226
/// <summary>
2327
/// The number of states in the Kalman filter model.
2428
/// </summary>
@@ -98,9 +102,17 @@ public bool Verbose
98102
[Description("If true, the initial covariance will be estimated during the EM algorithm.")]
99103
public bool EstimateInitialCovariance { get; set; } = true;
100104

101-
private int _maxIterations = 10;
102-
private double _tolerance = 1e-4;
103-
private bool _verbose = true;
105+
/// <summary>
106+
/// If true, the state offset will be estimated during the EM algorithm.
107+
/// </summary>
108+
[Description("If true, the state offset will be estimated during the EM algorithm.")]
109+
public bool EstimateStateOffset { get; set; } = false;
110+
111+
/// <summary>
112+
/// If true, the observation offset will be estimated during the EM algorithm.
113+
/// </summary>
114+
[Description("If true, the observation offset will be estimated during the EM algorithm.")]
115+
public bool EstimateObservationOffset { get; set; } = false;
104116

105117
/// <summary>
106118
/// Processes an observable sequence of input tensors, applying the Expectation-Maximization algorithm to learn the parameters of a Kalman filter model.
@@ -124,16 +136,16 @@ public IObservable<ExpectationMaximizationResult> Process(IObservable<Tensor> so
124136
processNoiseCovariance: EstimateProcessNoiseCovariance,
125137
measurementNoiseCovariance: EstimateMeasurementNoiseCovariance,
126138
initialMean: EstimateInitialMean,
127-
initialCovariance: EstimateInitialCovariance);
139+
initialCovariance: EstimateInitialCovariance,
140+
stateOffset: EstimateStateOffset,
141+
observationOffset: EstimateObservationOffset);
128142

129-
var parameters = ModelParameters?.Copy() ?? KalmanFilterParameters.Initialize(
143+
var parameters = ModelParameters?.Copy() ?? new KalmanFilterParameters(
130144
numStates: NumStates,
131145
numObservations: numObservations,
132146
scalarType: input.dtype,
133147
device: input.device);
134148

135-
parameters.Validate();
136-
137149
for (int i = 0; i < MaxIterations; i++)
138150
{
139151
// Check for cancellation before each iteration
@@ -200,4 +212,4 @@ public IObservable<ExpectationMaximizationResult> Process(IObservable<Tensor> so
200212
cancellationToken);
201213
})).Concat();
202214
}
203-
}
215+
}

src/Bonsai.ML.Lds.Torch/FilteredState.cs

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,53 @@ namespace Bonsai.ML.Lds.Torch;
55
/// <summary>
66
/// Represents the state of a Kalman filter.
77
/// </summary>
8-
/// <param name="predictedMean"></param>
9-
/// <param name="predictedCovariance"></param>
10-
/// <param name="updatedMean"></param>
11-
/// <param name="updatedCovariance"></param>
12-
public struct FilteredState(
13-
Tensor predictedMean,
14-
Tensor predictedCovariance,
15-
Tensor updatedMean,
16-
Tensor updatedCovariance) : ILinearDynamicalSystemState
8+
/// <param name="predictedState"></param>
9+
/// <param name="updatedState"></param>
10+
/// <param name="innovation"></param>
11+
/// <param name="innovationCovariance"></param>
12+
/// <param name="kalmanGain"></param>
13+
/// <param name="logLikelihood"></param>
14+
public readonly struct FilteredState(
15+
LinearDynamicalSystemState predictedState,
16+
LinearDynamicalSystemState updatedState,
17+
Tensor innovation = null,
18+
Tensor innovationCovariance = null,
19+
Tensor kalmanGain = null,
20+
Tensor logLikelihood = null) : ILinearDynamicalSystemState
1721
{
1822
/// <summary>
19-
/// The predicted mean after the prediction step.
23+
/// The predicted state following the prediction step.
2024
/// </summary>
21-
public Tensor PredictedMean = predictedMean;
25+
public readonly LinearDynamicalSystemState PredictedState => predictedState;
2226

2327
/// <summary>
24-
/// The predicted covariance after the prediction step.
28+
/// The updated state following the update step.
2529
/// </summary>
26-
public Tensor PredictedCovariance = predictedCovariance;
30+
public readonly LinearDynamicalSystemState UpdatedState => updatedState;
2731

2832
/// <summary>
29-
/// The updated mean after the update step.
33+
/// The innovation (residual) between the observation and the prediction.
3034
/// </summary>
31-
public Tensor UpdatedMean = updatedMean;
35+
public readonly Tensor Innovation => innovation;
3236

3337
/// <summary>
34-
/// The updated covariance after the update step.
38+
/// The innovation (residual) covariance.
3539
/// </summary>
36-
public Tensor UpdatedCovariance = updatedCovariance;
40+
public readonly Tensor InnovationCovariance => innovationCovariance;
41+
42+
/// <summary>
43+
/// The Kalman gain.
44+
/// </summary>
45+
public readonly Tensor KalmanGain => kalmanGain;
46+
47+
/// <summary>
48+
/// The log likelihood of the observation given the updated state.
49+
/// </summary>
50+
public readonly Tensor LogLikelihood => logLikelihood;
3751

3852
/// <inheritdoc/>
39-
public readonly Tensor Mean => UpdatedMean.isnan().any().item<bool>() ? PredictedMean : UpdatedMean;
53+
public readonly Tensor Mean => updatedState.Mean;
4054

4155
/// <inheritdoc/>
42-
public readonly Tensor Covariance => UpdatedCovariance.isnan().any().item<bool>() ? PredictedCovariance : UpdatedCovariance;
43-
}
56+
public readonly Tensor Covariance => updatedState.Covariance;
57+
}

0 commit comments

Comments
 (0)