@@ -55,7 +55,7 @@ def pi_norm_halfpi(x):
5555 return ((x + np .pi / 2 ) % (2 * np .pi / 2 )) - np .pi / 2
5656
5757
58- @torch .jit .script
58+ # @torch.jit.script
5959def torch_circular_diff_to_mean (angles : torch .Tensor , means : torch .Tensor ):
6060 assert means .ndim == 1
6161 a = torch .abs (means [:, None ] - angles ) % (2 * torch .pi )
@@ -68,12 +68,12 @@ def torch_circular_diff_to_mean(angles: torch.Tensor, means: torch.Tensor):
6868# return ((x + max_angle) % (2 * max_angle)) - max_angle
6969
7070
71- @torch .jit .script
71+ # @torch.jit.script
7272def torch_pi_norm_pi (x ):
7373 return ((x + torch .pi ) % (2 * torch .pi )) - torch .pi
7474
7575
76- @torch .jit .script
76+ # @torch.jit.script
7777def torch_pi_norm (x : torch .Tensor , max_angle : float = torch .pi ):
7878 return ((x + max_angle ) % (2 * max_angle )) - max_angle
7979
@@ -102,7 +102,7 @@ def circular_stddev(v, u, trim=50.0):
102102
103103
104104# returns circular_stddev and trimmed cricular stddev
105- @torch .jit .script
105+ # @torch.jit.script
106106def torch_circular_stddev (v : torch .Tensor , u : torch .Tensor , trim : float ): # =50.0):
107107 diff_from_mean = torch_circular_diff_to_mean (angles = v , means = u .reshape (- 1 ))
108108
@@ -126,7 +126,7 @@ def torch_circular_stddev(v: torch.Tensor, u: torch.Tensor, trim: float): # =50
126126 return stddev , trimmed_stddev
127127
128128
129- @torch .jit .script
129+ # @torch.jit.script
130130def torch_reduce_theta_to_positive_y (ground_truth_thetas ):
131131 reduced_thetas = ground_truth_thetas .clone ()
132132 # |theta|>np.pi/2 means its on the y<0
@@ -236,7 +236,7 @@ def circular_mean_single(angles, trim, weights=None):
236236 return pi_norm (cm ), pi_norm (_cm )
237237
238238
239- @torch .jit .script
239+ # @torch.jit.script
240240def torch_circular_mean_notrim (angles : torch .Tensor ):
241241 assert angles .ndim == 2
242242 _sin_angles = torch .sin (angles )
@@ -247,7 +247,7 @@ def torch_circular_mean_notrim(angles: torch.Tensor):
247247 return r , r
248248
249249
250- @torch .jit .script
250+ # @torch.jit.script
251251def torch_circular_mean_noweight (angles : torch .Tensor , trim : float ):
252252 assert angles .ndim == 2
253253 _sin_angles = torch .sin (angles )
@@ -315,7 +315,7 @@ def torch_circular_mean(angles: torch.Tensor, trim: float, weights=None):
315315 return torch_pi_norm_pi (cm ), torch_pi_norm_pi (_cm )
316316
317317
318- @torch .jit .script
318+ # @torch.jit.script
319319def torch_get_stats_for_signal (v : torch .Tensor , pd : torch .Tensor , trim : float ):
320320 trimmed_cm = torch_circular_mean_noweight (pd .reshape (1 , - 1 ), trim = trim )[1 ][
321321 0
@@ -335,7 +335,7 @@ def get_stats_for_signal(v, pd, trim):
335335 return trimmed_cm , trimmed_stddev , abs_signal_median
336336
337337
338- @torch .jit .script
338+ # @torch.jit.script
339339def torch_windowed_trimmed_circular_mean_and_stddev (
340340 v : torch .Tensor , pd : torch .Tensor , window_size : int , stride : int , trim : float
341341):
@@ -415,7 +415,7 @@ def get_phase_diff(signal_matrix):
415415 return pi_norm (np .angle (signal_matrix [0 ]) - np .angle (signal_matrix [1 ]))
416416
417417
418- @torch .jit .script
418+ # @torch.jit.script
419419def torch_get_phase_diff (signal_matrix : torch .Tensor ):
420420 return torch_pi_norm_pi (signal_matrix [:, 0 ].angle () - signal_matrix [:, 1 ].angle ())
421421
@@ -427,7 +427,7 @@ def get_avg_phase(signal_matrix, trim=0.0):
427427 ).reshape (- 1 )
428428
429429
430- @torch .jit .script
430+ # @torch.jit.script
431431def torch_get_avg_phase_notrim (signal_matrix : torch .Tensor ):
432432 return torch .hstack (
433433 torch_circular_mean_notrim (
0 commit comments