Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions finn/track_data_views/views/layers/track_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from finn.track_data_views.views_coordinator.tracks_viewer import TracksViewer
from finn.utils.events import Event

from finn.track_data_views.graph_attributes import NodeAttr


def new_label(layer: TrackLabels):
"""A function to override the default finn labels new_label function.
Expand Down Expand Up @@ -47,7 +45,7 @@ def _new_label(layer: TrackLabels, new_track_id=True):
"to be able to set the current select label to the next one available"
)
else:
if new_track_id:
if new_track_id or layer.selected_track is None:
new_selected_track = layer.tracks_viewer.tracks.get_next_track_id()
layer.selected_track = new_selected_track
layer.selected_label = new_selected_label
Expand Down Expand Up @@ -345,7 +343,9 @@ def update_selected_label(self):

self.events.selected_label.disconnect(self._ensure_valid_label)
if len(self.tracks_viewer.selected_nodes) > 0:
self.selected_label = int(self.tracks_viewer.selected_nodes[0])
node = int(self.tracks_viewer.selected_nodes[0])
self.selected_label = node
self.selected_track = int(self.tracks_viewer.tracks.get_track_id(node))
self.events.selected_label.connect(self._ensure_valid_label)

def _ensure_valid_label(self, event: Event | None = None):
Expand Down Expand Up @@ -384,12 +384,10 @@ def _ensure_valid_label(self, event: Event | None = None):
# if a node with the given label is already in the graph
if self.tracks_viewer.tracks.graph.has_node(self.selected_label):
# Update the track id
self.selected_track = self.tracks_viewer.tracks._get_node_attr(
self.selected_label, NodeAttr.TRACK_ID.value
)
existing_time = self.tracks_viewer.tracks._get_node_attr(
self.selected_label, NodeAttr.TIME.value
self.selected_track = self.tracks_viewer.tracks.get_track_id(
self.selected_label
)
existing_time = self.tracks_viewer.tracks.get_time(self.selected_label)
if existing_time == current_timepoint:
# we are changing the existing node. This is fine
pass
Expand All @@ -402,9 +400,7 @@ def _ensure_valid_label(self, event: Event | None = None):
self.selected_track
]:
if (
self.tracks_viewer.tracks._get_node_attr(
node, NodeAttr.TIME.value
)
self.tracks_viewer.tracks.get_time(node)
== current_timepoint
):
self.selected_label = int(node)
Expand All @@ -430,12 +426,7 @@ def _ensure_valid_label(self, event: Event | None = None):
for node in self.tracks_viewer.tracks.track_id_to_node[
self.selected_track
]:
if (
self.tracks_viewer.tracks._get_node_attr(
node, NodeAttr.TIME.value
)
== current_timepoint
):
if self.tracks_viewer.tracks.get_time(node) == current_timepoint:
self.selected_label = int(node)
edit = True
break
Expand Down
Loading