Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions airflow-core/src/airflow/api_fastapi/common/cursors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Cursor-based (keyset) pagination helpers.

:meta private:
"""

from __future__ import annotations

import base64
import uuid as uuid_mod
from typing import Any

import msgspec
from fastapi import HTTPException, status
from sqlalchemy import and_, or_
from sqlalchemy.sql import Select
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.sqltypes import Uuid

from airflow.api_fastapi.common.parameters import SortParam


def _b64url_decode_padded(token: str) -> bytes:
    """
    Decode URL-safe base64 text whose trailing ``=`` padding may have been stripped.

    The input is topped up with ``=`` until its length is a multiple of four and
    then handed to ``base64.urlsafe_b64decode``.

    :param token: URL-safe base64 text, with or without padding.
    :return: The decoded bytes.
    """
    # ``-len % 4`` is the number of characters missing to reach the next multiple of 4
    # (0 when the length is already aligned).
    missing = -len(token) % 4
    padded = token + "=" * missing
    return base64.urlsafe_b64decode(padded.encode("ascii"))


def _nonstrict_bound(col: ColumnElement, value: Any, is_desc: bool) -> ColumnElement[bool]:
    """
    Build the inclusive comparison of ``col`` against ``value``.

    :param col: Left-hand operand of the comparison.
    :param value: Right-hand operand of the comparison.
    :param is_desc: When true the result is ``col <= value``, otherwise ``col >= value``.
    :return: The comparison expression.
    """
    if is_desc:
        return col <= value
    return col >= value


def _strict_bound(col: ColumnElement, value: Any, is_desc: bool) -> ColumnElement[bool]:
    """
    Build the strict comparison of ``col`` against ``value``.

    :param col: Left-hand operand of the comparison.
    :param value: Right-hand operand of the comparison.
    :param is_desc: When true the result is ``col < value``, otherwise ``col > value``.
    :return: The comparison expression.
    """
    if is_desc:
        return col < value
    return col > value


def _nested_keyset_predicate(
    resolved: list[tuple[str, ColumnElement, bool]], values: list[Any]
) -> ColumnElement[bool]:
    """
    Build the keyset predicate selecting rows strictly after the cursor position.

    The expression is assembled from the last sort key outwards. The innermost term
    is a strict bound on the last key; every earlier key ``k`` wraps what has been
    built so far as ``and_(nonstrict(k), or_(strict(k), <inner>))``. For two ascending
    keys ``(a, b)`` this yields ``a >= va AND (a > va OR b > vb)``.

    :param resolved: ``(attr_name, column, is_descending)`` tuples in ``ORDER BY`` order.
    :param values: Cursor values, positionally matching ``resolved``.
    :return: The combined boolean expression.
    """
    last = len(resolved) - 1

    # Innermost term: strict inequality on the final sort key.
    _, tail_col, tail_desc = resolved[last]
    predicate: ColumnElement[bool] = _strict_bound(tail_col, values[last], tail_desc)

    # Wrap outwards, from the second-to-last key back to the leading key.
    for idx in reversed(range(last)):
        _, column, descending = resolved[idx]
        bound = values[idx]
        predicate = and_(
            _nonstrict_bound(column, bound, descending),
            or_(_strict_bound(column, bound, descending), predicate),
        )
    return predicate


def _coerce_value(column: ColumnElement, value: Any) -> Any:
    """
    Normalize a decoded cursor value so it can be used as a SQL bind parameter.

    Only string values destined for a column whose ``type`` is a ``Uuid`` are
    converted (to ``uuid.UUID``); every other value is returned unchanged.

    :param column: Sort column the value will be compared against.
    :param value: Value decoded from the cursor token.
    :return: A ``uuid.UUID`` for a string aimed at a ``Uuid`` column, otherwise ``value`` as-is.
    :raises HTTPException: 400 if a string aimed at a ``Uuid`` column is not a valid UUID.
    """
    # ``None`` is not a ``str`` either, so one check covers both pass-through cases.
    if not isinstance(value, str):
        return value
    if not isinstance(getattr(column, "type", None), Uuid):
        return value
    try:
        return uuid_mod.UUID(value)
    except ValueError:
        # The cursor is client-supplied: reject a malformed UUID here instead of
        # handing an unusable bind value to the database layer.
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token") from None


def encode_cursor(row: Any, sort_param: SortParam) -> str:
    """
    Build the cursor token for the boundary row of a result page.

    The sort-key values are read from ``row`` by attribute name (``None`` when the
    attribute is absent), packed as a MessagePack list and wrapped in URL-safe
    base64 with the trailing ``=`` padding removed. MessagePack output is binary,
    which is why the base64 layer is needed for use in a query string.

    :param row: Object exposing the sort-key attributes.
    :param sort_param: Sort parameter providing the resolved sort columns.
    :return: The cursor token as ASCII text.
    :raises ValueError: If ``sort_param`` resolves to no columns.
    """
    sort_keys = sort_param.get_resolved_columns()
    if not sort_keys:
        raise ValueError("SortParam has no resolved columns.")

    key_values = []
    for attr_name, _column, _is_desc in sort_keys:
        key_values.append(getattr(row, attr_name, None))

    packed = msgspec.msgpack.encode(key_values)
    token = base64.urlsafe_b64encode(packed)
    return token.decode("ascii").rstrip("=")


def decode_cursor(token: str) -> list[Any]:
    """
    Decode a cursor token to the list of sort-key values.

    The token is base64url-decoded and the resulting bytes are parsed as MessagePack.

    :param token: Cursor token as received from the client.
    :return: The decoded list of sort-key values.
    :raises HTTPException: 400 if the token cannot be decoded, or if the decoded
        payload is not a list.
    """
    try:
        raw = _b64url_decode_padded(token)
        data: Any = msgspec.msgpack.decode(raw)
    except Exception:
        # The token is client-supplied, so any decoding failure is a bad request.
        # ``from None`` keeps the low-level decoder error out of the exception chain.
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token") from None

    if not isinstance(data, list):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token structure")

    return data


def apply_cursor_filter(statement: Select, cursor: str, sort_param: SortParam) -> Select:
    """
    Add the keyset-pagination ``WHERE`` clause described by a cursor token.

    The token is decoded, its values are matched positionally against the resolved
    sort columns and coerced for binding, and the resulting nested
    ``and_(col <=/>= v, or_(col </> v, ...))`` predicate is appended to ``statement``.

    :param statement: Select statement to filter.
    :param cursor: Cursor token received from the client.
    :param sort_param: Sort parameter providing the resolved sort columns.
    :return: ``statement`` with the keyset predicate applied.
    :raises HTTPException: 400 if the token is invalid or holds a different number
        of values than there are resolved sort columns.
    """
    cursor_values = decode_cursor(cursor)
    sort_keys = sort_param.get_resolved_columns()

    if len(cursor_values) != len(sort_keys):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Cursor token does not match current query shape")

    bind_values = []
    for (_name, column, _is_desc), raw in zip(sort_keys, cursor_values, strict=True):
        bind_values.append(_coerce_value(column, raw))

    predicate = _nested_keyset_predicate(sort_keys, bind_values)
    return statement.where(predicate)
52 changes: 39 additions & 13 deletions airflow-core/src/airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ def to_orm(self, select: Select) -> Select:
return select.offset(self.value)

@classmethod
def depends(cls, offset: NonNegativeInt = 0) -> OffsetFilter:
def depends(
cls,
offset: NonNegativeInt = 0,
) -> OffsetFilter:
return cls().set_value(offset)


Expand Down Expand Up @@ -281,10 +284,18 @@ def __init__(
self.allowed_attrs = allowed_attrs
self.model = model
self.to_replace = to_replace
self._cached_resolution: tuple[list[ColumnElement], list[tuple[str, ColumnElement, bool]]] | None = (
None
)

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")
def set_value(self, value: list[str] | None) -> Self:
self._cached_resolution = None
return super().set_value(value)

def _resolve(self) -> tuple[list[ColumnElement], list[tuple[str, ColumnElement, bool]]]:
"""Resolve sort columns and ORDER BY expressions. Cached after first call."""
if self._cached_resolution is not None:
return self._cached_resolution

if self.value is None:
self.value = [self.get_primary_key_string()]
Expand All @@ -297,8 +308,10 @@ def to_orm(self, select: Select) -> Select:
)

columns: list[ColumnElement] = []
resolved: list[tuple[str, ColumnElement, bool]] = []
for order_by_value in order_by_values:
lstriped_orderby = order_by_value.lstrip("-")
attr_name = lstriped_orderby
column: Column | None = None
if self.to_replace:
replacement = self.to_replace.get(lstriped_orderby, lstriped_orderby)
Expand All @@ -316,22 +329,35 @@ def to_orm(self, select: Select) -> Select:
if column is None:
column = getattr(self.model, lstriped_orderby)

if order_by_value.startswith("-"):
is_desc = order_by_value.startswith("-")
if is_desc:
columns.append(column.desc())
else:
columns.append(column.asc())

# Reset default sorting
select = select.order_by(None)
resolved.append((attr_name, column, is_desc))

primary_key_column = self.get_primary_key_column()
# Always add a final discriminator to enforce deterministic ordering.
if order_by_values and order_by_values[0].startswith("-"):
columns.append(primary_key_column.desc())
else:
pk_name = self.get_primary_key_string()
# Always use ascending PK as the final tie-breaker so keyset pagination is stable when
# sort columns contain duplicates.
if not any(name == pk_name for name, _, _ in resolved):
columns.append(primary_key_column.asc())
resolved.append((pk_name, primary_key_column, False))

self._cached_resolution = (columns, resolved)
return self._cached_resolution

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

columns, _ = self._resolve()
return select.order_by(None).order_by(*columns)

return select.order_by(*columns)
def get_resolved_columns(self) -> list[tuple[str, ColumnElement, bool]]:
"""Return resolved sort columns as (attr_name, column_element, is_descending) tuples."""
_, resolved = self._resolve()
return resolved

def get_primary_key_column(self) -> Column:
"""Get the primary key column of the model of SortParam object."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,33 @@ class TaskInstanceResponse(BaseModel):


class TaskInstanceCollectionResponse(BaseModel):
"""Task Instance Collection serializer for responses."""
"""
Task instance collection response supporting both offset and cursor pagination.

A single flat model is used instead of a discriminated union
(``Annotated[Offset | Cursor, Field(discriminator=...)]``) because
the OpenAPI ``oneOf`` + ``discriminator`` construct is not handled
correctly by ``@hey-api/openapi-ts`` / ``@7nohe/openapi-react-query-codegen``:
return types degrade to ``unknown`` in JSDoc and can produce
incorrect TypeScript types (see hey-api/openapi-ts#1613, #3270).
"""

task_instances: Iterable[TaskInstanceResponse]
total_entries: int
total_entries: int | None = Field(
default=None,
description="Total number of matching items. Populated for offset pagination, "
"``null`` when using cursor pagination.",
)
next_cursor: str | None = Field(
default=None,
description="Token pointing to the next page. Populated for cursor pagination, "
"``null`` when using offset pagination or when there is no next page.",
)
previous_cursor: str | None = Field(
default=None,
description="Token pointing to the previous page. Populated for cursor pagination, "
"``null`` when using offset pagination or when on the first page.",
)


class TaskDependencyResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6524,10 +6524,22 @@ paths:
description: 'Get list of task instances.


This endpoint allows specifying `~` as the dag_id, dag_run_id to retrieve
Task Instances for all DAGs
This endpoint allows specifying `~` as the dag_id, dag_run_id

and DAG runs.'
to retrieve task instances for all DAGs and DAG runs.


Supports two pagination modes:


**Offset (default):** use `limit` and `offset` query parameters. Returns `total_entries`.


**Cursor:** pass `cursor` (empty string for the first page, then `next_cursor`
from the response).

When `cursor` is provided, `offset` is ignored and `total_entries` is not
returned.'
operationId: get_task_instances
security:
- OAuth2PasswordBearer: []
Expand All @@ -6545,6 +6557,20 @@ paths:
schema:
type: string
title: Dag Run Id
- name: cursor
in: query
required: false
schema:
anyOf:
- type: string
- type: 'null'
description: Cursor for keyset-based pagination (mutually exclusive with
offset). Pass an empty string for the first page, then use ``next_cursor``
from the response.
title: Cursor
Comment on lines +6567 to +6570
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah nice. I was about to say "so does this break compat" but you've covered it.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, backward comp is handled.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, with the models flattened, the 'total_entries' field can now be "None | int", not only 'int'. Technically that's a breaking change. Do you think that's fine, or should we add a newsfragment?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me have a play and see if we can get discrimination working again

description: Cursor for keyset-based pagination (mutually exclusive with offset).
Pass an empty string for the first page, then use ``next_cursor`` from the
response.
- name: task_id
in: query
required: false
Expand Down Expand Up @@ -12576,14 +12602,45 @@ components:
type: array
title: Task Instances
total_entries:
type: integer
anyOf:
- type: integer
- type: 'null'
title: Total Entries
description: Total number of matching items. Populated for offset pagination,
``null`` when using cursor pagination.
next_cursor:
anyOf:
- type: string
- type: 'null'
title: Next Cursor
description: Token pointing to the next page. Populated for cursor pagination,
``null`` when using offset pagination or when there is no next page.
previous_cursor:
anyOf:
- type: string
- type: 'null'
title: Previous Cursor
description: Token pointing to the previous page. Populated for cursor pagination,
``null`` when using offset pagination or when on the first page.
type: object
required:
- task_instances
- total_entries
title: TaskInstanceCollectionResponse
description: Task Instance Collection serializer for responses.
description: 'Task instance collection response supporting both offset and cursor
pagination.


A single flat model is used instead of a discriminated union

(``Annotated[Offset | Cursor, Field(discriminator=...)]``) because

the OpenAPI ``oneOf`` + ``discriminator`` construct is not handled

correctly by ``@hey-api/openapi-ts`` / ``@7nohe/openapi-react-query-codegen``:

return types degrade to ``unknown`` in JSDoc and can produce

incorrect TypeScript types (see hey-api/openapi-ts#1613, #3270).'
TaskInstanceHistoryCollectionResponse:
properties:
task_instances:
Expand Down
Loading
Loading