@@ -14,21 +14,13 @@ namespace Bonsai.ML.Lds.Torch;
1414[ ResetCombinator ]
1515[ Description ( "Creates a Kalman filter model." ) ]
1616[ WorkflowElementCategory ( ElementCategory . Source ) ]
17+ [ TypeConverter ( typeof ( TensorOperatorConverter ) ) ]
1718public class CreateKalmanFilter : IScalarTypeProvider
1819{
19- private ScalarType _scalarType = ScalarType . Float32 ;
2020 /// <inheritdoc/>
2121 [ Description ( "The data type of the tensor elements." ) ]
2222 [ TypeConverter ( typeof ( ScalarTypeConverter ) ) ]
23- public ScalarType Type
24- {
25- get => _scalarType ;
26- set
27- {
28- _scalarType = value ;
29- ConvertTensorsScalarType ( value ) ;
30- }
31- }
23+ public ScalarType Type { get ; set ; } = ScalarType . Float32 ;
3224
3325 /// <summary>
3426 /// The device on which to create the tensor.
@@ -47,8 +39,6 @@ public ScalarType Type
4739 /// </summary>
4840 public int ? NumObservations { get ; set ; } = null ;
4941
50- // Tensor properties with XML serialization support
51- private Tensor _transitionMatrix ;
5242 /// <summary>
5343 /// The state transition matrix.
5444 /// </summary>
@@ -57,7 +47,7 @@ public ScalarType Type
5747 public Tensor TransitionMatrix
5848 {
5949 get => _transitionMatrix ;
60- set => _transitionMatrix = value ? . to_type ( Type ) ;
50+ set => _transitionMatrix = value ;
6151 }
6252
6353 /// <summary>
@@ -68,11 +58,10 @@ public Tensor TransitionMatrix
6858 [ EditorBrowsable ( EditorBrowsableState . Never ) ]
6959 public string TransitionMatrixXml
7060 {
71- get => TensorConverter . ConvertToString ( TransitionMatrix , _scalarType ) ;
72- set => TransitionMatrix = TensorConverter . ConvertFromString ( value , _scalarType ) ;
61+ get => TensorConverter . ConvertToString ( _transitionMatrix , Type ) ;
62+ set => _transitionMatrix = TensorConverter . ConvertFromString ( value , Type ) ;
7363 }
7464
75- private Tensor _measurementFunction ;
7665 /// <summary>
7766 /// The measurement function.
7867 /// </summary>
@@ -81,7 +70,7 @@ public string TransitionMatrixXml
8170 public Tensor MeasurementFunction
8271 {
8372 get => _measurementFunction ;
84- set => _measurementFunction = value ? . to_type ( Type ) ;
73+ set => _measurementFunction = value ;
8574 }
8675
8776 /// <summary>
@@ -92,11 +81,10 @@ public Tensor MeasurementFunction
9281 [ EditorBrowsable ( EditorBrowsableState . Never ) ]
9382 public string MeasurementFunctionXml
9483 {
95- get => TensorConverter . ConvertToString ( MeasurementFunction , _scalarType ) ;
96- set => MeasurementFunction = TensorConverter . ConvertFromString ( value , _scalarType ) ;
84+ get => TensorConverter . ConvertToString ( _measurementFunction , Type ) ;
85+ set => _measurementFunction = TensorConverter . ConvertFromString ( value , Type ) ;
9786 }
9887
99- private Tensor _processNoiseVariance ;
10088 /// <summary>
10189 /// The process noise variance.
10290 /// </summary>
@@ -105,7 +93,7 @@ public string MeasurementFunctionXml
10593 public Tensor ProcessNoiseVariance
10694 {
10795 get => _processNoiseVariance ;
108- set => _processNoiseVariance = value ? . to_type ( Type ) ;
96+ set => _processNoiseVariance = value ;
10997 }
11098
11199 /// <summary>
@@ -116,11 +104,10 @@ public Tensor ProcessNoiseVariance
116104 [ EditorBrowsable ( EditorBrowsableState . Never ) ]
117105 public string ProcessNoiseVarianceXml
118106 {
119- get => TensorConverter . ConvertToString ( ProcessNoiseVariance , _scalarType ) ;
120- set => ProcessNoiseVariance = TensorConverter . ConvertFromString ( value , _scalarType ) ;
107+ get => TensorConverter . ConvertToString ( _processNoiseVariance , Type ) ;
108+ set => _processNoiseVariance = TensorConverter . ConvertFromString ( value , Type ) ;
121109 }
122110
123- private Tensor _measurementNoiseVariance ;
124111 /// <summary>
125112 /// The measurement noise variance.
126113 /// </summary>
@@ -129,7 +116,7 @@ public string ProcessNoiseVarianceXml
129116 public Tensor MeasurementNoiseVariance
130117 {
131118 get => _measurementNoiseVariance ;
132- set => _measurementNoiseVariance = value ? . to_type ( Type ) ;
119+ set => _measurementNoiseVariance = value ;
133120 }
134121
135122 /// <summary>
@@ -140,11 +127,10 @@ public Tensor MeasurementNoiseVariance
140127 [ EditorBrowsable ( EditorBrowsableState . Never ) ]
141128 public string MeasurementNoiseVarianceXml
142129 {
143- get => TensorConverter . ConvertToString ( MeasurementNoiseVariance , _scalarType ) ;
144- set => MeasurementNoiseVariance = TensorConverter . ConvertFromString ( value , _scalarType ) ;
130+ get => TensorConverter . ConvertToString ( _measurementNoiseVariance , Type ) ;
131+ set => _measurementNoiseVariance = TensorConverter . ConvertFromString ( value , Type ) ;
145132 }
146133
147- private Tensor _initialMean ;
148134 /// <summary>
149135 /// The initial mean.
150136 /// </summary>
@@ -153,7 +139,7 @@ public string MeasurementNoiseVarianceXml
153139 public Tensor InitialMean
154140 {
155141 get => _initialMean ;
156- set => _initialMean = value ? . to_type ( Type ) ;
142+ set => _initialMean = value ;
157143 }
158144
159145 /// <summary>
@@ -164,11 +150,10 @@ public Tensor InitialMean
164150 [ EditorBrowsable ( EditorBrowsableState . Never ) ]
165151 public string InitialMeanXml
166152 {
167- get => TensorConverter . ConvertToString ( InitialMean , _scalarType ) ;
168- set => InitialMean = TensorConverter . ConvertFromString ( value , _scalarType ) ;
153+ get => TensorConverter . ConvertToString ( _initialMean , Type ) ;
154+ set => _initialMean = TensorConverter . ConvertFromString ( value , Type ) ;
169155 }
170156
171- private Tensor _initialCovariance ;
172157 /// <summary>
173158 /// The initial covariance.
174159 /// </summary>
@@ -177,7 +162,7 @@ public string InitialMeanXml
177162 public Tensor InitialCovariance
178163 {
179164 get => _initialCovariance ;
180- set => _initialCovariance = value ? . to_type ( Type ) ;
165+ set => _initialCovariance = value ;
181166 }
182167
183168 /// <summary>
@@ -188,19 +173,16 @@ public Tensor InitialCovariance
188173 [ EditorBrowsable ( EditorBrowsableState . Never ) ]
189174 public string InitialCovarianceXml
190175 {
191- get => TensorConverter . ConvertToString ( InitialCovariance , _scalarType ) ;
192- set => InitialCovariance = TensorConverter . ConvertFromString ( value , _scalarType ) ;
176+ get => TensorConverter . ConvertToString ( _initialCovariance , Type ) ;
177+ set => _initialCovariance = TensorConverter . ConvertFromString ( value , Type ) ;
193178 }
194179
195- private void ConvertTensorsScalarType ( ScalarType scalarType )
196- {
197- _transitionMatrix = _transitionMatrix ? . to_type ( scalarType ) ;
198- _measurementFunction = _measurementFunction ? . to_type ( scalarType ) ;
199- _processNoiseVariance = _processNoiseVariance ? . to_type ( scalarType ) ;
200- _measurementNoiseVariance = _measurementNoiseVariance ? . to_type ( scalarType ) ;
201- _initialMean = _initialMean ? . to_type ( scalarType ) ;
202- _initialCovariance = _initialCovariance ? . to_type ( scalarType ) ;
203- }
180+ private Tensor _transitionMatrix ;
181+ private Tensor _measurementFunction ;
182+ private Tensor _processNoiseVariance ;
183+ private Tensor _measurementNoiseVariance ;
184+ private Tensor _initialMean ;
185+ private Tensor _initialCovariance ;
204186
205187 /// <summary>
206188 /// Creates a Kalman filter model using the properties of this class.
@@ -226,13 +208,13 @@ public IObservable<KalmanFilter> Process()
226208 /// </summary>
227209 public IObservable < KalmanFilter > Process ( IObservable < KalmanFilterParameters > source )
228210 {
229- return source . SelectMany ( parameters =>
211+ return source . Select ( parameters =>
230212 {
231- return Observable . Return ( new KalmanFilter (
213+ return new KalmanFilter (
232214 parameters : parameters ,
233215 device : Device ,
234216 scalarType : Type
235- ) ) ;
217+ ) ;
236218 } ) ;
237219 }
238220}
0 commit comments