Skip to content

Commit cc6de15

Browse files
committed
missed files on previous commit
1 parent 02862b4 commit cc6de15

5 files changed

Lines changed: 79 additions & 31 deletions

File tree

spf/dataset/segmentation.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"window_size": 2048,
3636
"stride": 2048,
3737
"trim": 20.0,
38-
"mean_diff_threshold": 0.2,
38+
# "mean_diff_threshold": 0.2,
3939
"max_stddev_threshold": 0.5,
4040
"drop_less_than_size": 3000,
4141
"min_abs_signal": 40,
@@ -539,7 +539,17 @@ def segment_session(
539539
# 4. Identifies windows as "signal" vs "noise" based on phase stability and amplitude
540540
# 5. Combines adjacent windows with similar phase characteristics
541541
# 6. Returns a list of segment information and signal statistics
542-
segmentation_results.update(simple_segment(v, **kwrgs))
542+
segmentation_results.update(
543+
simple_segment(
544+
v=v,
545+
window_size=kwrgs["window_size"],
546+
stride=kwrgs["stride"],
547+
trim=kwrgs["trim"],
548+
max_stddev_threshold=kwrgs["max_stddev_threshold"],
549+
drop_less_than_size=kwrgs["drop_less_than_size"],
550+
min_abs_signal=kwrgs["min_abs_signal"],
551+
)
552+
)
543553

544554
# Transpose the window statistics for easier processing
545555
# all_windows_stats shape is (3, N_windows) where:
@@ -592,7 +602,12 @@ def segment_session(
592602
# If no signal windows were identified, use placeholder values
593603
segmentation_results["weighted_windows_stats"] = np.array([-1, -1, -1])
594604
else:
595-
segmentation_results['all_windows_stats']=get_all_windows_stats(v=v,window_size=kwrgs['window_size'],stride=kwrgs['stride'],trim=kwrgs['trim'])[1]
605+
segmentation_results["all_windows_stats"] = get_all_windows_stats(
606+
v=v,
607+
window_size=kwrgs["window_size"],
608+
stride=kwrgs["stride"],
609+
trim=kwrgs["trim"],
610+
)[1]
596611

597612
# Transpose the window statistics for easier processing
598613
# all_windows_stats shape is (3, N_windows) where:
@@ -611,6 +626,7 @@ def segment_session(
611626

612627
return segmentation_results
613628

629+
614630
def get_all_windows_stats(
615631
v,
616632
window_size,
@@ -633,6 +649,7 @@ def get_all_windows_stats(
633649
)
634650
return step_idxs, step_stats
635651

652+
636653
def simple_segment(
637654
v,
638655
window_size,
@@ -691,7 +708,9 @@ def simple_segment(
691708
# window_idxs_and_stats = windowed_trimmed_circular_mean_and_stddev(
692709
# v, pd, window_size=window_size, stride=stride, trim=trim
693710
# )
694-
window_idxs_and_stats = get_all_windows_stats(v=v,window_size=window_size,stride=stride,trim=trim)
711+
window_idxs_and_stats = get_all_windows_stats(
712+
v=v, window_size=window_size, stride=stride, trim=trim
713+
)
695714
# window_idxs_and_stats has two components:
696715
# [0] = list of window indices (start_idx, end_idx)
697716
# [1] = array of statistics (trimmed_mean, trimmed_stddev, abs_signal_median)

spf/dataset/spf_dataset.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
import threading
21
import logging
32
import multiprocessing
43
import os
54
import pickle
5+
import queue
6+
import threading
67
import time
78
from contextlib import contextmanager
89
from enum import Enum
910
from functools import cache
1011
from typing import Dict, List
11-
import queue
1212

1313
import numpy as np
1414
import torch
@@ -118,6 +118,9 @@ def v5_collate_keys_fast(keys: List[str], batch: List[List[Dict[str, torch.Tenso
118118
)
119119
if key == "windowed_beamformer" or key == "all_windows_stats":
120120
d[key] = d[key].to(torch.float32)
121+
# d["dropout_mask_rand_values"] = torch.rand(
122+
# (*d["y_rad"].shape,8), device=d["y_rad"].device
123+
# )
121124
# d["random_rotations"] = torch_pi_norm(
122125
# torch.rand(d["y_rad"].shape[0], 1) * 2 * torch.pi
123126
# )
@@ -478,7 +481,7 @@ def __init__(
478481
self.target_ntheta = self.nthetas if target_ntheta is None else target_ntheta
479482

480483
self.realtime = realtime
481-
self.max_store_size=max_store_size
484+
self.max_store_size = max_store_size
482485

483486
self.max_in_memory = max_in_memory
484487
self.min_idx = 0
@@ -597,7 +600,7 @@ def __init__(
597600
self.empirical_data = None
598601
self.serving_idx = -1
599602

600-
#setup the reader thread
603+
# setup the reader thread
601604
self.stop_event = threading.Event()
602605
self.reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
603606
self.reader_thread.start()
@@ -614,16 +617,25 @@ def _reader_loop(self):
614617
}
615618
self.store[idx]["data"][ridx] = rendered_data
616619
self.store[idx]["count"] += 1
617-
print("LENGTH OF STORE IS",len(self.store))
618-
if self.max_store_size is not None and len(self.store)>self.max_store_size:
619-
for idx in sorted(self.store.keys())[:-self.max_store_size]:
620+
print("LENGTH OF STORE IS", len(self.store))
621+
if (
622+
self.max_store_size is not None
623+
and len(self.store) > self.max_store_size
624+
):
625+
for idx in sorted(self.store.keys())[: -self.max_store_size]:
620626
self.store.pop(idx)
621-
print("POPPING",idx)
627+
print("POPPING", idx)
622628

623629
except queue.Empty:
624630
pass # No new item, just keep looping
631+
except (OSError, EOFError):
632+
# Queue closed because process exiting. Exit silently.
633+
break
625634
except Exception as e:
626-
logging.exception("Error in v5inferencedataset queue reader thread")
635+
if not self.stop_event.is_set():
636+
logging.exception("Error in v5inferencedataset queue reader thread")
637+
# If stop_event is set, don't log, just exit.
638+
break
627639

628640
def __len__(self):
629641
return self.serving_idx
@@ -724,7 +736,7 @@ def render_session(self, ridx, data):
724736
# ]
725737
# ).T # torch.Size([1, 2])
726738

727-
data["rx_pos_mm"] = data["tx_pos_mm"] = torch.ones(1, 2) * torch.nan
739+
data["rx_pos_mm"] = data["tx_pos_mm"] = torch.zeros(1, 2) # * torch.nan
728740

729741
data["rx_pos_xy"] = (
730742
data["rx_pos_mm"][snapshot_idxs].unsqueeze(0) / self.distance_normalization
@@ -789,6 +801,15 @@ def render_session(self, ridx, data):
789801
data[key] = data[key].to(self.target_dtype)
790802
return data
791803

804+
def close(self):
805+
if hasattr(self, "stop_event"):
806+
self.stop_event.set()
807+
if hasattr(self, "reader_thread") and self.reader_thread.is_alive():
808+
self.reader_thread.join(timeout=1.0) # avoid blocking forever
809+
810+
def __del__(self):
811+
self.close()
812+
792813

793814
class v5spfdataset(Dataset):
794815
def __init__(

spf/model_training_and_inference/models/single_point_networks.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
TEMP = 10
1313

1414

15-
@torch.jit.script
15+
# @torch.jit.script
1616
def cdf(mean: torch.Tensor, sigma: torch.Tensor, value: float):
1717
return 0.5 * (1 + torch.erf((value - mean) * sigma.reciprocal() / sqrt(2)))
1818

1919

20-
@torch.jit.script
20+
# @torch.jit.script
2121
def normal_correction_for_bounded_range(
2222
mean: torch.Tensor, sigma: torch.Tensor, max_y: float
2323
):
@@ -468,10 +468,16 @@ def __init__(self, model_config, global_config):
468468
)
469469

470470
def prepare_input(self, batch, additional_inputs=[]):
471-
dropout_mask = (
472-
torch.rand((8, *batch["y_rad"].shape), device=batch["y_rad"].device)
473-
< self.input_dropout
474-
)
471+
if self.training:
472+
# dropout_mask = (
473+
# torch.rand((8, *batch["y_rad"].shape), device=batch["y_rad"].device)
474+
# < self.input_dropout
475+
# )
476+
shape = batch["y_rad"].shape
477+
rand_shape = torch.Size((8, shape[0], shape[1]))
478+
dropout_mask = torch.rand(rand_shape, device=batch["y_rad"].device).to(
479+
torch.bool
480+
)
475481
# 1 , 65
476482
# if mask out 65, then scale up 1 by 65?
477483
# [ batch, samples, dim of input ]
@@ -976,6 +982,7 @@ def forward(self, batch):
976982
additional_inputs.append(
977983
self.signal_matrix_net(input_with_spacing.select(1, 0))[:, None]
978984
)
985+
# print(batch["all_windows_stats"].abs().mean())
979986
return_dict["single"] = torch.nn.functional.normalize(
980987
self.single_point_with_beamformer_ffnn(
981988
self.prepare_input.prepare_input(batch, additional_inputs),

spf/model_training_and_inference/models/single_point_networks_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def single_example_realtime_inference(model, global_config, optim_config, realti
5454

5555

5656
def single_example_inference(model, global_config, datasets_config, optim_config):
57+
model.eval()
5758

5859
ds = v5spfdataset(
5960
datasets_config["train_paths"][0],

spf/rf.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def pi_norm_halfpi(x):
5555
return ((x + np.pi / 2) % (2 * np.pi / 2)) - np.pi / 2
5656

5757

58-
@torch.jit.script
58+
# @torch.jit.script
5959
def torch_circular_diff_to_mean(angles: torch.Tensor, means: torch.Tensor):
6060
assert means.ndim == 1
6161
a = torch.abs(means[:, None] - angles) % (2 * torch.pi)
@@ -68,12 +68,12 @@ def torch_circular_diff_to_mean(angles: torch.Tensor, means: torch.Tensor):
6868
# return ((x + max_angle) % (2 * max_angle)) - max_angle
6969

7070

71-
@torch.jit.script
71+
# @torch.jit.script
7272
def torch_pi_norm_pi(x):
7373
return ((x + torch.pi) % (2 * torch.pi)) - torch.pi
7474

7575

76-
@torch.jit.script
76+
# @torch.jit.script
7777
def torch_pi_norm(x: torch.Tensor, max_angle: float = torch.pi):
7878
return ((x + max_angle) % (2 * max_angle)) - max_angle
7979

@@ -102,7 +102,7 @@ def circular_stddev(v, u, trim=50.0):
102102

103103

104104
# returns circular_stddev and trimmed cricular stddev
105-
@torch.jit.script
105+
# @torch.jit.script
106106
def torch_circular_stddev(v: torch.Tensor, u: torch.Tensor, trim: float): # =50.0):
107107
diff_from_mean = torch_circular_diff_to_mean(angles=v, means=u.reshape(-1))
108108

@@ -126,7 +126,7 @@ def torch_circular_stddev(v: torch.Tensor, u: torch.Tensor, trim: float): # =50
126126
return stddev, trimmed_stddev
127127

128128

129-
@torch.jit.script
129+
# @torch.jit.script
130130
def torch_reduce_theta_to_positive_y(ground_truth_thetas):
131131
reduced_thetas = ground_truth_thetas.clone()
132132
# |theta|>np.pi/2 means its on the y<0
@@ -236,7 +236,7 @@ def circular_mean_single(angles, trim, weights=None):
236236
return pi_norm(cm), pi_norm(_cm)
237237

238238

239-
@torch.jit.script
239+
# @torch.jit.script
240240
def torch_circular_mean_notrim(angles: torch.Tensor):
241241
assert angles.ndim == 2
242242
_sin_angles = torch.sin(angles)
@@ -247,7 +247,7 @@ def torch_circular_mean_notrim(angles: torch.Tensor):
247247
return r, r
248248

249249

250-
@torch.jit.script
250+
# @torch.jit.script
251251
def torch_circular_mean_noweight(angles: torch.Tensor, trim: float):
252252
assert angles.ndim == 2
253253
_sin_angles = torch.sin(angles)
@@ -315,7 +315,7 @@ def torch_circular_mean(angles: torch.Tensor, trim: float, weights=None):
315315
return torch_pi_norm_pi(cm), torch_pi_norm_pi(_cm)
316316

317317

318-
@torch.jit.script
318+
# @torch.jit.script
319319
def torch_get_stats_for_signal(v: torch.Tensor, pd: torch.Tensor, trim: float):
320320
trimmed_cm = torch_circular_mean_noweight(pd.reshape(1, -1), trim=trim)[1][
321321
0
@@ -335,7 +335,7 @@ def get_stats_for_signal(v, pd, trim):
335335
return trimmed_cm, trimmed_stddev, abs_signal_median
336336

337337

338-
@torch.jit.script
338+
# @torch.jit.script
339339
def torch_windowed_trimmed_circular_mean_and_stddev(
340340
v: torch.Tensor, pd: torch.Tensor, window_size: int, stride: int, trim: float
341341
):
@@ -415,7 +415,7 @@ def get_phase_diff(signal_matrix):
415415
return pi_norm(np.angle(signal_matrix[0]) - np.angle(signal_matrix[1]))
416416

417417

418-
@torch.jit.script
418+
# @torch.jit.script
419419
def torch_get_phase_diff(signal_matrix: torch.Tensor):
420420
return torch_pi_norm_pi(signal_matrix[:, 0].angle() - signal_matrix[:, 1].angle())
421421

@@ -427,7 +427,7 @@ def get_avg_phase(signal_matrix, trim=0.0):
427427
).reshape(-1)
428428

429429

430-
@torch.jit.script
430+
# @torch.jit.script
431431
def torch_get_avg_phase_notrim(signal_matrix: torch.Tensor):
432432
return torch.hstack(
433433
torch_circular_mean_notrim(

0 commit comments

Comments
 (0)