Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import packaging.version
from connexion import FlaskApi
from flask import Blueprint, url_for
from flask import Blueprint, g, url_for
from packaging.version import Version
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
Expand Down Expand Up @@ -183,9 +183,20 @@ def get_user_display_name(self) -> str:
return f"{first_name} {last_name}".strip()

def get_user(self) -> User:
"""Return the user associated to the user in session."""
"""
Return the user associated to the user in session.

Attempt to find the current user in g.user, as defined by the kerberos authentication backend.
If no such user is found, return the `current_user` local proxy object, linked to the user session.

"""
from flask_login import current_user

# If a user has gone through the Kerberos dance, the kerberos authentication manager
# has linked it with a User model, stored in g.user, and not the session.
if current_user.is_anonymous and getattr(g, "user", None) is not None and not g.user.is_anonymous:
return g.user

return current_user

def init(self) -> None:
Expand Down
26 changes: 23 additions & 3 deletions providers/tests/fab/auth_manager/test_fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
from __future__ import annotations

from contextlib import contextmanager
from itertools import chain
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import Mock

import pytest
from flask import Flask
from flask import Flask, g

from airflow.exceptions import AirflowConfigException, AirflowException

Expand Down Expand Up @@ -72,6 +73,13 @@
}


@contextmanager
def user_set(app, user):
g.user = user
yield
g.user = None


@pytest.fixture
def auth_manager():
return FabAuthManager(None)
Expand Down Expand Up @@ -114,12 +122,24 @@ def test_get_user_display_name(
assert auth_manager.get_user_display_name() == expected

@mock.patch("flask_login.utils._get_user")
def test_get_user(self, mock_current_user, auth_manager):
def test_get_user(self, mock_current_user, minimal_app_for_auth_api, auth_manager):
user = Mock()
user.is_anonymous.return_value = True
mock_current_user.return_value = user
with minimal_app_for_auth_api.app_context():
assert auth_manager.get_user() == user

assert auth_manager.get_user() == user
@mock.patch("flask_login.utils._get_user")
def test_get_user_from_flask_g(self, mock_current_user, minimal_app_for_auth_api, auth_manager):
session_user = Mock()
session_user.is_anonymous = True
mock_current_user.return_value = session_user

flask_g_user = Mock()
flask_g_user.is_anonymous = False
with minimal_app_for_auth_api.app_context():
with user_set(minimal_app_for_auth_api, flask_g_user):
assert auth_manager.get_user() == flask_g_user

@pytest.mark.db_test
@mock.patch.object(FabAuthManager, "get_user")
Expand Down