1 change: 1 addition & 0 deletions changes/8614.fix.md
@@ -0,0 +1 @@
Fix FK constraint violation in purge_scaling_group by deleting kernels before sessions. Without this, deleting a scaling group with active sessions fails due to the fk_kernels_session_id_sessions constraint.
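
A minimal sketch of the corrected deletion order described above, assuming the SessionRow/KernelRow ORM models and async session handling shown in the diff below; the helper name and the exact model import paths are assumptions for illustration, not the actual repository method:

```python
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

# SessionRow and KernelRow refer to the ORM models used in the diff below;
# the import paths here are assumed.
from ai.backend.manager.models.kernel import KernelRow
from ai.backend.manager.models.session import SessionRow


async def _delete_kernels_then_sessions(
    db_sess: AsyncSession, scaling_group_name: str
) -> None:
    """Hypothetical helper showing why kernels must be removed before sessions."""
    # Collect the sessions that belong to the scaling group being purged.
    session_ids = (
        await db_sess.scalars(
            sa.select(SessionRow.id).where(
                SessionRow.scaling_group_name == scaling_group_name
            )
        )
    ).all()

    # Kernels reference sessions via fk_kernels_session_id_sessions, so they
    # must be deleted first; otherwise deleting a session that still has
    # kernels attached raises an FK violation.
    if session_ids:
        await db_sess.execute(
            sa.delete(KernelRow).where(KernelRow.session_id.in_(session_ids))
        )

    # With no kernels left pointing at them, the sessions can be removed.
    await db_sess.execute(
        sa.delete(SessionRow).where(
            SessionRow.scaling_group_name == scaling_group_name
        )
    )
```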
@@ -129,13 +129,7 @@ async def purge_scaling_group(
self,
purger: Purger[ScalingGroupRow],
) -> ScalingGroupData:
"""Purges a scaling group and all related sessions and routes using a purger.

Cascade delete order:
1. RoutingRow (session FK with RESTRICT)
2. EndpointRow (resource_group FK with RESTRICT, has CASCADE to routing)
3. SessionRow (scaling_group FK)
4. ScalingGroupRow
"""Purges a scaling group and all related sessions, routes, endpoints, and kernels.

Raises ScalingGroupNotFound if scaling group doesn't exist.
"""
@@ -162,13 +156,20 @@ async def purge_scaling_group(
)
await session.execute(delete_endpoints_stmt)

# Step 4: Delete all sessions belonging to this scaling group
# Step 4: Delete all kernels belonging to these sessions
if session_ids:
delete_kernels_stmt = sa.delete(KernelRow).where(
KernelRow.session_id.in_(session_ids)
)
await session.execute(delete_kernels_stmt)

# Step 5: Delete all sessions belonging to this scaling group
delete_sessions_stmt = sa.delete(SessionRow).where(
SessionRow.scaling_group_name == scaling_group_name
)
await session.execute(delete_sessions_stmt)

# Step 5: Delete the scaling group itself using purger
# Step 6: Delete the scaling group itself using purger
result = await execute_purger(session, purger)

if result is None:
@@ -102,7 +102,7 @@ async def purge_scaling_group(
self,
purger: Purger[ScalingGroupRow],
) -> ScalingGroupData:
"""Purges a scaling group and all related sessions and routes using a purger.
"""Purges a scaling group and all related sessions, routes, endpoints, and kernels.

Raises ScalingGroupNotFound if scaling group doesn't exist.
"""
@@ -10,6 +10,7 @@
from ai.backend.common.types import AccessKey, DefaultForUnspecified, ResourceSlot, SessionTypes
from ai.backend.manager.data.auth.hash import PasswordHashAlgorithm
from ai.backend.manager.data.user.types import UserStatus
from ai.backend.manager.defs import DEFAULT_ROLE
from ai.backend.manager.errors.resource import ScalingGroupNotFound
from ai.backend.manager.models.agent import AgentRow
from ai.backend.manager.models.deployment_auto_scaling_policy import DeploymentAutoScalingPolicyRow
@@ -1186,3 +1187,194 @@ async def test_disassociate_nonexistent_scaling_group_with_user_groups(
)
# Then: Should not raise any error (BatchPurger deletes 0 rows silently)
await scaling_group_repository.disassociate_scaling_group_with_user_groups(purger)

@pytest.fixture
async def sample_scaling_group_for_hierarchy(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
) -> AsyncGenerator[str, None]:
sgroup_name = f"test-{uuid.uuid4().hex[:8]}"
async with db_with_cleanup.begin_session() as db_sess:
sgroup = ScalingGroupRow(
name=sgroup_name,
description="Test scaling group for full hierarchy cascade delete",
is_active=True,
is_public=True,
created_at=datetime.now(tz=UTC),
wsproxy_addr=None,
wsproxy_api_token=None,
driver="static",
driver_opts={},
scheduler="fifo",
scheduler_opts=ScalingGroupOpts(),
use_host_network=False,
)
db_sess.add(sgroup)
await db_sess.flush()
yield sgroup_name

@pytest.fixture
async def sample_session(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
sample_scaling_group_for_hierarchy: str,
test_user_domain_group: tuple[uuid.UUID, str, uuid.UUID],
) -> AsyncGenerator[SessionId, None]:
"""Create a session referencing the scaling group."""
test_user_uuid, test_domain, test_group_id = test_user_domain_group
session_id = SessionId(uuid.uuid4())
async with db_with_cleanup.begin_session() as db_sess:
db_sess.add(
SessionRow(
id=session_id,
domain_name=test_domain,
group_id=test_group_id,
user_uuid=test_user_uuid,
scaling_group_name=sample_scaling_group_for_hierarchy,
cluster_size=1,
vfolder_mounts={},
)
)
await db_sess.flush()
yield session_id

@pytest.fixture
async def sample_kernel(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
sample_session: SessionId,
test_user_domain_group: tuple[uuid.UUID, str, uuid.UUID],
) -> AsyncGenerator[uuid.UUID, None]:
"""Create a kernel for the session."""
test_user_uuid, test_domain, test_group_id = test_user_domain_group
kernel_id = uuid.uuid4()
async with db_with_cleanup.begin_session() as db_sess:
db_sess.add(
KernelRow(
id=kernel_id,
session_id=sample_session,
domain_name=test_domain,
group_id=test_group_id,
user_uuid=test_user_uuid,
cluster_role=DEFAULT_ROLE,
occupied_slots=ResourceSlot(),
repl_in_port=0,
repl_out_port=0,
stdin_port=0,
stdout_port=0,
vfolder_mounts=None,
)
)
await db_sess.flush()
yield kernel_id

@pytest.fixture
async def sample_endpoint(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
sample_scaling_group_for_hierarchy: str,
test_user_domain_group: tuple[uuid.UUID, str, uuid.UUID],
) -> AsyncGenerator[uuid.UUID, None]:
"""Create an endpoint referencing the scaling group."""
test_user_uuid, test_domain, test_group_id = test_user_domain_group
endpoint_id = uuid.uuid4()
async with db_with_cleanup.begin_session() as db_sess:
db_sess.add(
EndpointRow(
id=endpoint_id,
name="test-endpoint-hierarchy",
domain=test_domain,
project=test_group_id,
resource_group=sample_scaling_group_for_hierarchy,
image=None,
lifecycle_stage=EndpointLifecycle.DESTROYED,
session_owner=test_user_uuid,
created_user=test_user_uuid,
)
)
await db_sess.flush()
yield endpoint_id

@pytest.fixture
async def sample_route(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
sample_session: SessionId,
sample_endpoint: uuid.UUID,
test_user_domain_group: tuple[uuid.UUID, str, uuid.UUID],
) -> AsyncGenerator[uuid.UUID, None]:
"""Create a route connecting the session to the endpoint."""
test_user_uuid, test_domain, test_group_id = test_user_domain_group
route_id = uuid.uuid4()
async with db_with_cleanup.begin_session() as db_sess:
db_sess.add(
RoutingRow(
id=route_id,
endpoint=sample_endpoint,
session=sample_session,
session_owner=test_user_uuid,
domain=test_domain,
project=test_group_id,
traffic_ratio=1.0,
)
)
await db_sess.flush()
yield route_id

async def test_purge_scaling_group_with_full_hierarchy(
self,
scaling_group_repository: ScalingGroupRepository,
sample_scaling_group_for_hierarchy: str,
sample_session: SessionId,
sample_kernel: uuid.UUID,
sample_endpoint: uuid.UUID,
sample_route: uuid.UUID,
db_with_cleanup: ExtendedAsyncSAEngine,
) -> None:
"""Test purging a scaling group with the full FK hierarchy.

Hierarchy: ScalingGroup → Session → Kernel + Endpoint → Route
"""
sgroup_name = sample_scaling_group_for_hierarchy
session_id = sample_session
kernel_id = sample_kernel
endpoint_id = sample_endpoint
route_id = sample_route

purger = Purger(row_class=ScalingGroupRow, pk_value=sgroup_name)
# FK Error should not occur, and all related records should be deleted
result = await scaling_group_repository.purge_scaling_group(purger)

assert result.name == sgroup_name

# Verify all records in the hierarchy are deleted
async with db_with_cleanup.begin_readonly_session() as db_sess:
# Verify scaling group is deleted
sg_result = await db_sess.execute(
sa.select(ScalingGroupRow).where(ScalingGroupRow.name == sgroup_name)
)
assert sg_result.scalar_one_or_none() is None

# Verify session is deleted
session_result = await db_sess.execute(
sa.select(SessionRow).where(SessionRow.id == session_id)
)
assert session_result.scalar_one_or_none() is None

# Verify kernel is deleted
kernel_result = await db_sess.execute(
sa.select(KernelRow).where(KernelRow.id == kernel_id)
)
assert kernel_result.scalar_one_or_none() is None

# Verify endpoint is deleted
endpoint_result = await db_sess.execute(
sa.select(EndpointRow).where(EndpointRow.id == endpoint_id)
)
assert endpoint_result.scalar_one_or_none() is None

# Verify route is deleted
route_result = await db_sess.execute(
sa.select(RoutingRow).where(RoutingRow.id == route_id)
)
assert route_result.scalar_one_or_none() is None