Skip to content

Commit faca4ca

Browse files
committed
Updated to use TensorOperatorConverter class
1 parent 6260dc1 commit faca4ca

5 files changed

Lines changed: 81 additions & 126 deletions

File tree

src/Bonsai.ML.Lds.Torch/CreateKalmanFilter.cs

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,13 @@ namespace Bonsai.ML.Lds.Torch;
1414
[ResetCombinator]
1515
[Description("Creates a Kalman filter model.")]
1616
[WorkflowElementCategory(ElementCategory.Source)]
17+
[TypeConverter(typeof(TensorOperatorConverter))]
1718
public class CreateKalmanFilter : IScalarTypeProvider
1819
{
19-
private ScalarType _scalarType = ScalarType.Float32;
2020
/// <inheritdoc/>
2121
[Description("The data type of the tensor elements.")]
2222
[TypeConverter(typeof(ScalarTypeConverter))]
23-
public ScalarType Type
24-
{
25-
get => _scalarType;
26-
set
27-
{
28-
_scalarType = value;
29-
ConvertTensorsScalarType(value);
30-
}
31-
}
23+
public ScalarType Type { get; set; } = ScalarType.Float32;
3224

3325
/// <summary>
3426
/// The device on which to create the tensor.
@@ -47,8 +39,6 @@ public ScalarType Type
4739
/// </summary>
4840
public int? NumObservations { get; set; } = null;
4941

50-
// Tensor properties with XML serialization support
51-
private Tensor _transitionMatrix;
5242
/// <summary>
5343
/// The state transition matrix.
5444
/// </summary>
@@ -57,7 +47,7 @@ public ScalarType Type
5747
public Tensor TransitionMatrix
5848
{
5949
get => _transitionMatrix;
60-
set => _transitionMatrix = value?.to_type(Type);
50+
set => _transitionMatrix = value;
6151
}
6252

6353
/// <summary>
@@ -68,11 +58,10 @@ public Tensor TransitionMatrix
6858
[EditorBrowsable(EditorBrowsableState.Never)]
6959
public string TransitionMatrixXml
7060
{
71-
get => TensorConverter.ConvertToString(TransitionMatrix, _scalarType);
72-
set => TransitionMatrix = TensorConverter.ConvertFromString(value, _scalarType);
61+
get => TensorConverter.ConvertToString(_transitionMatrix, Type);
62+
set => _transitionMatrix = TensorConverter.ConvertFromString(value, Type);
7363
}
7464

75-
private Tensor _measurementFunction;
7665
/// <summary>
7766
/// The measurement function.
7867
/// </summary>
@@ -81,7 +70,7 @@ public string TransitionMatrixXml
8170
public Tensor MeasurementFunction
8271
{
8372
get => _measurementFunction;
84-
set => _measurementFunction = value?.to_type(Type);
73+
set => _measurementFunction = value;
8574
}
8675

8776
/// <summary>
@@ -92,11 +81,10 @@ public Tensor MeasurementFunction
9281
[EditorBrowsable(EditorBrowsableState.Never)]
9382
public string MeasurementFunctionXml
9483
{
95-
get => TensorConverter.ConvertToString(MeasurementFunction, _scalarType);
96-
set => MeasurementFunction = TensorConverter.ConvertFromString(value, _scalarType);
84+
get => TensorConverter.ConvertToString(_measurementFunction, Type);
85+
set => _measurementFunction = TensorConverter.ConvertFromString(value, Type);
9786
}
9887

99-
private Tensor _processNoiseVariance;
10088
/// <summary>
10189
/// The process noise variance.
10290
/// </summary>
@@ -105,7 +93,7 @@ public string MeasurementFunctionXml
10593
public Tensor ProcessNoiseVariance
10694
{
10795
get => _processNoiseVariance;
108-
set => _processNoiseVariance = value?.to_type(Type);
96+
set => _processNoiseVariance = value;
10997
}
11098

11199
/// <summary>
@@ -116,11 +104,10 @@ public Tensor ProcessNoiseVariance
116104
[EditorBrowsable(EditorBrowsableState.Never)]
117105
public string ProcessNoiseVarianceXml
118106
{
119-
get => TensorConverter.ConvertToString(ProcessNoiseVariance, _scalarType);
120-
set => ProcessNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType);
107+
get => TensorConverter.ConvertToString(_processNoiseVariance, Type);
108+
set => _processNoiseVariance = TensorConverter.ConvertFromString(value, Type);
121109
}
122110

123-
private Tensor _measurementNoiseVariance;
124111
/// <summary>
125112
/// The measurement noise variance.
126113
/// </summary>
@@ -129,7 +116,7 @@ public string ProcessNoiseVarianceXml
129116
public Tensor MeasurementNoiseVariance
130117
{
131118
get => _measurementNoiseVariance;
132-
set => _measurementNoiseVariance = value?.to_type(Type);
119+
set => _measurementNoiseVariance = value;
133120
}
134121

135122
/// <summary>
@@ -140,11 +127,10 @@ public Tensor MeasurementNoiseVariance
140127
[EditorBrowsable(EditorBrowsableState.Never)]
141128
public string MeasurementNoiseVarianceXml
142129
{
143-
get => TensorConverter.ConvertToString(MeasurementNoiseVariance, _scalarType);
144-
set => MeasurementNoiseVariance = TensorConverter.ConvertFromString(value, _scalarType);
130+
get => TensorConverter.ConvertToString(_measurementNoiseVariance, Type);
131+
set => _measurementNoiseVariance = TensorConverter.ConvertFromString(value, Type);
145132
}
146133

147-
private Tensor _initialMean;
148134
/// <summary>
149135
/// The initial mean.
150136
/// </summary>
@@ -153,7 +139,7 @@ public string MeasurementNoiseVarianceXml
153139
public Tensor InitialMean
154140
{
155141
get => _initialMean;
156-
set => _initialMean = value?.to_type(Type);
142+
set => _initialMean = value;
157143
}
158144

159145
/// <summary>
@@ -164,11 +150,10 @@ public Tensor InitialMean
164150
[EditorBrowsable(EditorBrowsableState.Never)]
165151
public string InitialMeanXml
166152
{
167-
get => TensorConverter.ConvertToString(InitialMean, _scalarType);
168-
set => InitialMean = TensorConverter.ConvertFromString(value, _scalarType);
153+
get => TensorConverter.ConvertToString(_initialMean, Type);
154+
set => _initialMean = TensorConverter.ConvertFromString(value, Type);
169155
}
170156

171-
private Tensor _initialCovariance;
172157
/// <summary>
173158
/// The initial covariance.
174159
/// </summary>
@@ -177,7 +162,7 @@ public string InitialMeanXml
177162
public Tensor InitialCovariance
178163
{
179164
get => _initialCovariance;
180-
set => _initialCovariance = value?.to_type(Type);
165+
set => _initialCovariance = value;
181166
}
182167

183168
/// <summary>
@@ -188,19 +173,16 @@ public Tensor InitialCovariance
188173
[EditorBrowsable(EditorBrowsableState.Never)]
189174
public string InitialCovarianceXml
190175
{
191-
get => TensorConverter.ConvertToString(InitialCovariance, _scalarType);
192-
set => InitialCovariance = TensorConverter.ConvertFromString(value, _scalarType);
176+
get => TensorConverter.ConvertToString(_initialCovariance, Type);
177+
set => _initialCovariance = TensorConverter.ConvertFromString(value, Type);
193178
}
194179

195-
private void ConvertTensorsScalarType(ScalarType scalarType)
196-
{
197-
_transitionMatrix = _transitionMatrix?.to_type(scalarType);
198-
_measurementFunction = _measurementFunction?.to_type(scalarType);
199-
_processNoiseVariance = _processNoiseVariance?.to_type(scalarType);
200-
_measurementNoiseVariance = _measurementNoiseVariance?.to_type(scalarType);
201-
_initialMean = _initialMean?.to_type(scalarType);
202-
_initialCovariance = _initialCovariance?.to_type(scalarType);
203-
}
180+
private Tensor _transitionMatrix;
181+
private Tensor _measurementFunction;
182+
private Tensor _processNoiseVariance;
183+
private Tensor _measurementNoiseVariance;
184+
private Tensor _initialMean;
185+
private Tensor _initialCovariance;
204186

205187
/// <summary>
206188
/// Creates a Kalman filter model using the properties of this class.
@@ -226,13 +208,13 @@ public IObservable<KalmanFilter> Process()
226208
/// </summary>
227209
public IObservable<KalmanFilter> Process(IObservable<KalmanFilterParameters> source)
228210
{
229-
return source.SelectMany(parameters =>
211+
return source.Select(parameters =>
230212
{
231-
return Observable.Return(new KalmanFilter(
213+
return new KalmanFilter(
232214
parameters: parameters,
233215
device: Device,
234216
scalarType: Type
235-
));
217+
);
236218
});
237219
}
238220
}

0 commit comments

Comments
 (0)