Skip to content

Commit d2d0bdc

Browse files
refactor dojo async task base task
The custom decorators that we have on Celery tasks interfere with some (advanced) celery functionality like signatures. This PR refactors this to have a clean base task that passes on context, but does not interfere with celery mechanisms. The logic to decide whether or not the task is to be called asynchronously is now in a dispatch method.
1 parent f12f27e commit d2d0bdc

File tree

29 files changed

+806
-293
lines changed

29 files changed

+806
-293
lines changed

dojo/api_v2/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from dojo.api_v2.prefetch.prefetcher import _Prefetcher
4848
from dojo.authorization.roles_permissions import Permissions
49+
from dojo.celery_dispatch import dojo_dispatch_task
4950
from dojo.cred.queries import get_authorized_cred_mappings
5051
from dojo.endpoint.queries import (
5152
get_authorized_endpoint_status,
@@ -679,13 +680,13 @@ def update_jira_epic(self, request, pk=None):
679680
try:
680681

681682
if engagement.has_jira_issue:
682-
jira_helper.update_epic(engagement.id, **request.data)
683+
dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data)
683684
response = Response(
684685
{"info": "Jira Epic update query sent"},
685686
status=status.HTTP_200_OK,
686687
)
687688
else:
688-
jira_helper.add_epic(engagement.id, **request.data)
689+
dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data)
689690
response = Response(
690691
{"info": "Jira Epic create query sent"},
691692
status=status.HTTP_200_OK,

dojo/celery.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,56 @@
1212
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings")
1313

1414

15-
class PgHistoryTask(Task):
15+
class DojoAsyncTask(Task):
16+
17+
"""
18+
Base task class that provides dojo_async_task functionality without using a decorator.
19+
20+
This class:
21+
- Injects user context into task kwargs
22+
- Tracks task calls for performance testing
23+
- Supports all Celery features (signatures, chords, groups, chains)
24+
"""
25+
26+
def apply_async(self, args=None, kwargs=None, **options):
27+
"""Override apply_async to inject user context and track tasks."""
28+
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
29+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
30+
31+
if kwargs is None:
32+
kwargs = {}
33+
34+
# Inject user context if not already present
35+
if "async_user" not in kwargs:
36+
kwargs["async_user"] = get_current_user()
37+
38+
# Control flag used for sync/async decision; never pass into the task itself
39+
kwargs.pop("sync", None)
40+
41+
# Track dispatch
42+
dojo_async_task_counter.incr(
43+
self.name,
44+
args=args,
45+
kwargs=kwargs,
46+
)
47+
48+
# Call parent to execute async
49+
return super().apply_async(args=args, kwargs=kwargs, **options)
50+
51+
52+
class PgHistoryTask(DojoAsyncTask):
1653

1754
"""
1855
Custom Celery base task that automatically applies pghistory context.
1956
20-
When a task is dispatched via dojo_async_task, the current pghistory
21-
context is captured and passed in kwargs as "_pgh_context". This base
22-
class extracts that context and applies it before running the task,
23-
ensuring all database events share the same context as the original
24-
request.
57+
This class inherits from DojoAsyncTask to provide:
58+
- User context injection and task tracking (from DojoAsyncTask)
59+
- Automatic pghistory context application (from this class)
60+
61+
When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current
62+
pghistory context is captured and passed in kwargs as "_pgh_context". This base
63+
class extracts that context and applies it before running the task, ensuring all
64+
database events share the same context as the original request.
2565
"""
2666

2767
def __call__(self, *args, **kwargs):

dojo/celery_dispatch.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol, cast
4+
5+
from celery.canvas import Signature
6+
7+
if TYPE_CHECKING:
8+
from collections.abc import Mapping
9+
10+
11+
class _SupportsSi(Protocol):
12+
def si(self, *args: Any, **kwargs: Any) -> Signature: ...
13+
14+
15+
class _SupportsApplyAsync(Protocol):
16+
def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ...
17+
18+
19+
def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
20+
result: dict[str, Any] = dict(kwargs or {})
21+
if "async_user" not in result:
22+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
23+
24+
result["async_user"] = get_current_user()
25+
return result
26+
27+
28+
def _inject_pghistory_context(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
29+
"""Capture and inject pghistory context if available."""
30+
result: dict[str, Any] = dict(kwargs or {})
31+
if "_pgh_context" not in result:
32+
from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import
33+
34+
if pgh_context := get_serializable_pghistory_context():
35+
result["_pgh_context"] = pgh_context
36+
return result
37+
38+
39+
def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature:
40+
"""
41+
Build a Celery signature with DefectDojo user context and pghistory context injected.
42+
43+
- If passed a task, returns `task_or_sig.si(*args, **kwargs)`.
44+
- If passed an existing signature, returns a cloned signature with merged kwargs.
45+
"""
46+
injected = _inject_async_user(kwargs)
47+
injected = _inject_pghistory_context(injected)
48+
injected.pop("countdown", None)
49+
50+
if isinstance(task_or_sig, Signature):
51+
merged_kwargs = {**(task_or_sig.kwargs or {}), **injected}
52+
return task_or_sig.clone(kwargs=merged_kwargs)
53+
54+
return task_or_sig.si(*args, **injected)
55+
56+
57+
def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any:
58+
"""
59+
Dispatch a task/signature using DefectDojo semantics.
60+
61+
- Inject `async_user` if missing.
62+
- Capture and inject pghistory context if available.
63+
- Respect `sync=True` (foreground execution) and user `block_execution`.
64+
- Support `countdown=<seconds>` for async dispatch.
65+
66+
Returns:
67+
- async: AsyncResult-like return from Celery
68+
- sync: underlying return value of the task
69+
70+
"""
71+
from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import
72+
73+
countdown = cast("int", kwargs.pop("countdown", 0))
74+
injected = _inject_async_user(kwargs)
75+
injected = _inject_pghistory_context(injected)
76+
77+
sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected)
78+
sig_kwargs = dict(sig.kwargs or {})
79+
80+
if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs):
81+
# DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here.
82+
return sig.apply_async(countdown=countdown)
83+
84+
# Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior)
85+
dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs)
86+
87+
sig_kwargs.pop("sync", None)
88+
sig = sig.clone(kwargs=sig_kwargs)
89+
eager = sig.apply()
90+
try:
91+
return eager.get(propagate=True)
92+
except RuntimeError:
93+
# Since we are intentionally running synchronously, we can propagate exceptions directly, and enable sync subtasks
94+
# If the requests desires this. Celery docs explain that this is a rare use case, but we support it _just in case_
95+
return eager.get(propagate=True, disable_sync_subtasks=False)

dojo/endpoint/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dojo.authorization.authorization import user_has_permission_or_403
1919
from dojo.authorization.authorization_decorators import user_is_authorized
2020
from dojo.authorization.roles_permissions import Permissions
21+
from dojo.celery_dispatch import dojo_dispatch_task
2122
from dojo.endpoint.queries import get_authorized_endpoints_for_queryset
2223
from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import
2324
from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups
@@ -345,7 +346,7 @@ def endpoint_bulk_update_all(request, pid=None):
345346
product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct())
346347
endpoints.delete()
347348
for prod in product_calc:
348-
calculate_grade(prod.id)
349+
dojo_dispatch_task(calculate_grade, prod.id)
349350

350351
if skipped_endpoint_count > 0:
351352
add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.")

dojo/engagement/services.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.dispatch import receiver
66

77
import dojo.jira_link.helper as jira_helper
8+
from dojo.celery_dispatch import dojo_dispatch_task
89
from dojo.models import Engagement
910

1011
logger = logging.getLogger(__name__)
@@ -16,7 +17,7 @@ def close_engagement(eng):
1617
eng.save()
1718

1819
if jira_helper.get_jira_project(eng):
19-
jira_helper.close_epic(eng.id, push_to_jira=True)
20+
dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True)
2021

2122

2223
def reopen_engagement(eng):

dojo/engagement/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from dojo.authorization.authorization import user_has_permission_or_403
3838
from dojo.authorization.authorization_decorators import user_is_authorized
3939
from dojo.authorization.roles_permissions import Permissions
40+
from dojo.celery_dispatch import dojo_dispatch_task
4041
from dojo.endpoint.utils import save_endpoints_to_add
4142
from dojo.engagement.queries import get_authorized_engagements
4243
from dojo.engagement.services import close_engagement, reopen_engagement
@@ -392,7 +393,7 @@ def copy_engagement(request, eid):
392393
form = DoneForm(request.POST)
393394
if form.is_valid():
394395
engagement_copy = engagement.copy()
395-
calculate_grade(product.id)
396+
dojo_dispatch_task(calculate_grade, product.id)
396397
messages.add_message(
397398
request,
398399
messages.SUCCESS,

dojo/finding/deduplication.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from django.db.models.query_utils import Q
99

1010
from dojo.celery import app
11-
from dojo.decorators import dojo_async_task
1211
from dojo.models import Finding, System_Settings
1312

1413
logger = logging.getLogger(__name__)
@@ -45,13 +44,11 @@ def get_finding_models_for_deduplication(finding_ids):
4544
)
4645

4746

48-
@dojo_async_task
4947
@app.task
5048
def do_dedupe_finding_task(new_finding_id, *args, **kwargs):
5149
return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs)
5250

5351

54-
@dojo_async_task
5552
@app.task
5653
def do_dedupe_batch_task(finding_ids, *args, **kwargs):
5754
"""

dojo/finding/helper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import dojo.jira_link.helper as jira_helper
1717
import dojo.risk_acceptance.helper as ra_helper
1818
from dojo.celery import app
19-
from dojo.decorators import dojo_async_task
2019
from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add
2120
from dojo.file_uploads.helper import delete_related_files
2221
from dojo.finding.deduplication import (
@@ -395,7 +394,6 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group
395394
finding_group.findings.add(*findings)
396395

397396

398-
@dojo_async_task
399397
@app.task
400398
def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002
401399
issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed
@@ -440,7 +438,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option
440438

441439
if product_grading_option:
442440
if system_settings.enable_product_grade:
443-
calculate_grade(finding.test.engagement.product.id)
441+
from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import
442+
443+
dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id)
444444
else:
445445
deduplicationLogger.debug("skipping product grading because it's disabled in system settings")
446446

@@ -457,7 +457,6 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option
457457
jira_helper.push_to_jira(finding.finding_group)
458458

459459

460-
@dojo_async_task
461460
@app.task
462461
def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True,
463462
issue_updater_option=True, push_to_jira=False, user=None, **kwargs):
@@ -500,7 +499,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op
500499
tool_issue_updater.async_tool_issue_update(finding)
501500

502501
if product_grading_option and system_settings.enable_product_grade:
503-
calculate_grade(findings[0].test.engagement.product.id)
502+
from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import
503+
504+
dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id)
504505

505506
if push_to_jira:
506507
for finding in findings:

dojo/finding/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
user_is_authorized,
3939
)
4040
from dojo.authorization.roles_permissions import Permissions
41+
from dojo.celery_dispatch import dojo_dispatch_task
4142
from dojo.filters import (
4243
AcceptedFindingFilter,
4344
AcceptedFindingFilterWithoutObjectLookups,
@@ -1099,7 +1100,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict):
10991100
product = finding.test.engagement.product
11001101
finding.delete()
11011102
# Update the grade of the product async
1102-
calculate_grade(product.id)
1103+
dojo_dispatch_task(calculate_grade, product.id)
11031104
# Add a message to the request that the finding was successfully deleted
11041105
messages.add_message(
11051106
request,
@@ -1374,7 +1375,7 @@ def copy_finding(request, fid):
13741375
test = form.cleaned_data.get("test")
13751376
product = finding.test.engagement.product
13761377
finding_copy = finding.copy(test=test)
1377-
calculate_grade(product.id)
1378+
dojo_dispatch_task(calculate_grade, product.id)
13781379
messages.add_message(
13791380
request,
13801381
messages.SUCCESS,

dojo/finding_group/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from dojo.authorization.authorization import user_has_permission_or_403
1717
from dojo.authorization.authorization_decorators import user_is_authorized
1818
from dojo.authorization.roles_permissions import Permissions
19+
from dojo.celery_dispatch import dojo_dispatch_task
1920
from dojo.filters import (
2021
FindingFilter,
2122
FindingFilterWithoutObjectLookups,
@@ -100,7 +101,7 @@ def view_finding_group(request, fgid):
100101
elif not finding_group.has_jira_issue:
101102
jira_helper.finding_group_link_jira(request, finding_group, jira_issue)
102103
elif push_to_jira:
103-
jira_helper.push_to_jira(finding_group, sync=True)
104+
dojo_dispatch_task(jira_helper.push_to_jira, finding_group, sync=True)
104105

105106
finding_group.save()
106107
return HttpResponseRedirect(reverse("view_test", args=(finding_group.test.id,)))
@@ -200,7 +201,7 @@ def push_to_jira(request, fgid):
200201

201202
# it may look like success here, but the push_to_jira are swallowing exceptions
202203
# but cant't change too much now without having a test suite, so leave as is for now with the addition warning message to check alerts for background errors.
203-
if jira_helper.push_to_jira(group, sync=True):
204+
if dojo_dispatch_task(jira_helper.push_to_jira, group, sync=True):
204205
messages.add_message(
205206
request,
206207
messages.SUCCESS,

0 commit comments

Comments
 (0)