Skip to content

Commit 16bedcb

Browse files
committed
[script] predator_following
1 parent 169ebe3 commit 16bedcb

2 files changed

Lines changed: 202 additions & 2 deletions

File tree

scripts/predator_following.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import dataclasses
2+
from pathlib import Path
3+
from typing import NamedTuple
4+
5+
import numpy as np
6+
import polars as pl
7+
import typer
8+
from numpy.lib.npyio import NpzFile
9+
from numpy.typing import NDArray
10+
from scipy.spatial.distance import cdist
11+
12+
from emevo.analysis.log_plotting import load_log
13+
14+
15+
class AgentState(NamedTuple):
16+
angle: NDArray
17+
xy: NDArray
18+
is_active: NDArray
19+
20+
21+
@dataclasses.dataclass
22+
class AgentStateLoader:
23+
size: int
24+
files: list[NpzFile]
25+
cache: dict[int, AgentState] = dataclasses.field(default_factory=dict)
26+
27+
def get(self, time: int) -> AgentState:
28+
index = time // self.size
29+
if index not in self.cache:
30+
self._load_cache(index)
31+
cached = self.cache[index]
32+
offset = time % self.size
33+
angle = cached.angle[offset]
34+
xy = cached.xy[offset]
35+
is_active = cached.is_active[offset]
36+
return AgentState(angle=angle, xy=xy, is_active=is_active)
37+
38+
def _load_cache(self, index: int) -> None:
39+
angle = self.files[index]["circle_axy"].astype(np.float32)[:, :, 0]
40+
xy = self.files[index]["circle_axy"].astype(np.float32)[:, :, 1:]
41+
is_active = self.files[index]["circle_is_active"].astype(bool)
42+
self.cache[index] = AgentState(angle=angle, xy=xy, is_active=is_active)
43+
44+
45+
def get_state_loader(
46+
dirpath: Path,
47+
n_states: int,
48+
size: int = 1024000,
49+
) -> AgentStateLoader:
50+
files = []
51+
for i in range(n_states):
52+
npzfile = np.load(dirpath / f"state-{i + 1}.npz")
53+
files.append(npzfile)
54+
return AgentStateLoader(size=size, files=files)
55+
56+
57+
def load(
58+
logd: Path,
59+
n_states: int = 10,
60+
state_size: int = 1024000,
61+
) -> tuple[AgentStateLoader, pl.DataFrame]:
62+
state_loader = get_state_loader(logd, n_states, state_size)
63+
64+
eaten_path = logd / "eaten.parquet"
65+
if eaten_path.exists():
66+
stepdf = pl.read_parquet(eaten_path).select(
67+
"unique_id",
68+
"slots",
69+
"start",
70+
"end",
71+
)
72+
else:
73+
ldf = load_log(logd, last_idx=n_states).with_columns(
74+
pl.col("step").alias("Step")
75+
)
76+
stepdf = (
77+
ldf.group_by("unique_id")
78+
.agg(
79+
pl.col("slots").first(),
80+
pl.col("step").min().alias("start"),
81+
pl.col("step").max().alias("end"),
82+
)
83+
.collect()
84+
)
85+
return state_loader, stepdf
86+
87+
88+
def find_following_prey(
89+
*,
90+
state_loader: AgentStateLoader,
91+
stepdf: pl.DataFrame,
92+
start: int,
93+
interval: int,
94+
n_max_preys: int,
95+
end: int,
96+
neighbor: float = 50.0,
97+
angle_threshold: float = np.pi / 4, # 45 degrees
98+
) -> pl.DataFrame:
99+
step_list = []
100+
prey_uid_list = []
101+
predator_slot_list = []
102+
103+
for i in range(start, end, interval):
104+
dfi = stepdf.filter((pl.col("start") < i) & (i < pl.col("end")))
105+
angle, xy, is_active = state_loader.get(i)
106+
107+
# Split prey and predators
108+
all_prey_xy = xy[:n_max_preys]
109+
all_prey_active = is_active[:n_max_preys]
110+
all_prey_angles = angle[:n_max_preys]
111+
112+
all_predator_xy = xy[n_max_preys:]
113+
all_predator_active = is_active[n_max_preys:]
114+
115+
active_slots_in_df = set(dfi["slots"].to_list())
116+
117+
# 1. Identify valid prey indices
118+
valid_prey_indices = [
119+
idx
120+
for idx in range(n_max_preys)
121+
if all_prey_active[idx] and idx in active_slots_in_df
122+
]
123+
124+
# 2. Identify valid predator indices (using offset for global indexing)
125+
valid_pred_indices = [
126+
idx for idx, active in enumerate(all_predator_active) if active
127+
]
128+
129+
if not valid_prey_indices or not valid_pred_indices:
130+
continue
131+
132+
# 3. Compute Distance Matrix (Prey x Predators)
133+
prey_coords = all_prey_xy[valid_prey_indices]
134+
pred_coords = all_predator_xy[valid_pred_indices]
135+
dist_mat = cdist(prey_coords, pred_coords)
136+
137+
# 4. Check Following Condition
138+
for p_idx, prey_slot in enumerate(valid_prey_indices):
139+
prey_pos = prey_coords[p_idx]
140+
prey_angle = all_prey_angles[prey_slot]
141+
142+
# Unit vector of prey heading
143+
prey_dir = np.array([np.cos(prey_angle), np.sin(prey_angle)])
144+
145+
for target_idx, pred_slot_local in enumerate(valid_pred_indices):
146+
dist = dist_mat[p_idx, target_idx]
147+
148+
if dist < neighbor:
149+
# Vector from prey to predator
150+
vec_to_pred = pred_coords[target_idx] - prey_pos
151+
vec_to_pred_unit = vec_to_pred / (
152+
np.linalg.norm(vec_to_pred) + 1e-6
153+
)
154+
155+
# Dot product to find cosine of angle between heading and predator
156+
cos_sim = np.dot(prey_dir, vec_to_pred_unit)
157+
158+
# If angle difference is within threshold (e.g., cos(45°))
159+
if cos_sim > np.cos(angle_threshold):
160+
u_id = dfi.filter(pl.col("slots") == prey_slot)["unique_id"][0]
161+
162+
step_list.append(i)
163+
prey_uid_list.append(u_id)
164+
# predator index relative to the start of predator block
165+
predator_slot_list.append(n_max_preys + pred_slot_local)
166+
167+
return pl.DataFrame(
168+
{
169+
"Step": step_list,
170+
"prey_unique_id": prey_uid_list,
171+
"followed_predator_slot": predator_slot_list,
172+
}
173+
)
174+
175+
176+
def main(
177+
logd: Path,
178+
n_states: int = 10,
179+
n_max_preys: int = 450,
180+
start: int = 9216000,
181+
interval: int = 1000,
182+
end: int = 10240000,
183+
neighbor: int = 25,
184+
state_size: int = 1024000,
185+
) -> None:
186+
state_loader, stepdf = load(logd, n_states, state_size)
187+
group_df = find_following_prey(
188+
state_loader=state_loader,
189+
stepdf=stepdf,
190+
start=start,
191+
interval=interval,
192+
n_max_preys=n_max_preys,
193+
neighbor=neighbor,
194+
end=end,
195+
)
196+
group_df.write_parquet(logd / f"group-{start}-{interval}-{neighbor}.parquet")
197+
198+
199+
if __name__ == "__main__":
200+
typer.run(main)

scripts/predator_group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def find_groups(
9090
start: int,
9191
interval: int,
9292
n_max_preys: int,
93-
neighbor: int,
93+
neighbor: float,
9494
end: int,
9595
) -> pl.DataFrame:
9696
step_list = []
@@ -177,7 +177,7 @@ def main(
177177
start: int = 9216000,
178178
interval: int = 1000,
179179
end: int = 10240000,
180-
neighbor: int = 25,
180+
neighbor: float = 25.0,
181181
state_size: int = 1024000,
182182
) -> None:
183183
state_loader, stepdf = load(logd, n_states, state_size)

0 commit comments

Comments
 (0)