diff --git a/dojo/apps.py b/dojo/apps.py index 4b1af1ef192..d23a3c6f853 100644 --- a/dojo/apps.py +++ b/dojo/apps.py @@ -94,7 +94,7 @@ def ready(self): import dojo.product_type.signals # noqa: PLC0415, F401 raised: AppRegistryNotReady import dojo.risk_acceptance.signals # noqa: PLC0415, F401 raised: AppRegistryNotReady import dojo.sla_config.helpers # noqa: PLC0415, F401 raised: AppRegistryNotReady - import dojo.tags_signals # noqa: PLC0415, F401 raised: AppRegistryNotReady + import dojo.tags.signals # noqa: PLC0415, F401 raised: AppRegistryNotReady import dojo.test.signals # noqa: PLC0415, F401 raised: AppRegistryNotReady import dojo.tool_product.signals # noqa: PLC0415, F401 raised: AppRegistryNotReady import dojo.url.signals # noqa: PLC0415, F401 raised: AppRegistryNotReady diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 0a7eb0bcdfe..79b54c20ab7 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -742,7 +742,7 @@ def bulk_clear_finding_m2m(finding_qs): FileUpload.delete() fires and removes files from disk storage. Tags are handled via bulk_remove_all_tags to maintain tag counts. """ - from dojo.tag_utils import bulk_remove_all_tags # noqa: PLC0415 circular import + from dojo.tags.utils import bulk_remove_all_tags # noqa: PLC0415 circular import finding_ids = finding_qs.values_list("id", flat=True) diff --git a/dojo/finding/views.py b/dojo/finding/views.py index 0d32235bf1f..d3cc025be26 100644 --- a/dojo/finding/views.py +++ b/dojo/finding/views.py @@ -94,7 +94,7 @@ User, ) from dojo.notifications.helper import create_notification -from dojo.tag_utils import bulk_add_tags_to_instances +from dojo.tags.utils import bulk_add_tags_to_instances from dojo.test.queries import get_authorized_tests from dojo.tools import tool_issue_updater from dojo.utils import ( diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 052b006356c..d87524185fe 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -33,7 +33,7 @@ Test_Type, ) from dojo.notifications.helper import create_notification -from dojo.tag_utils import bulk_add_tags_to_instances +from dojo.tags.utils import bulk_add_tags_to_instances from dojo.tools.factory import get_parser from dojo.tools.parser_test import ParserTest from dojo.utils import max_safe diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 428b8ad01a7..bc78a6b9592 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -18,7 +18,9 @@ Test_Import, ) from dojo.notifications.helper import async_create_notification -from dojo.tag_utils import bulk_apply_parser_tags +from dojo.tags import inheritance as tag_inheritance +from dojo.tags.inheritance import apply_inherited_tags_for_findings +from dojo.tags.utils import bulk_apply_parser_tags from dojo.utils import get_full_url, perform_product_grading from dojo.validators import clean_tags @@ -161,6 +163,19 @@ def process_findings( self, parsed_findings: list[Finding], **kwargs: dict, + ) -> list[Finding]: + # Whole hot loop runs under `batch_mode()`: per-row inheritance signals + # for the findings/endpoints/locations created below are suppressed. + # Inheritance is then applied in bulk per-batch (right before + # `post_process_findings_batch` dispatch) so rules/dedup see inherited + # tags on `finding.tags`. + with tag_inheritance.suppress_tag_inheritance(): + return self._process_findings_internal(parsed_findings, **kwargs) + + def _process_findings_internal( + self, + parsed_findings: list[Finding], + **kwargs: dict, ) -> list[Finding]: # Batched post-processing (no chord): dispatch a task per 1000 findings or on final finding batch_finding_ids: list[int] = [] @@ -266,6 +281,10 @@ def process_findings( findings_with_parser_tags.clear() # Apply import-time tags before post-processing so rules/deduplication see them. self.apply_import_tags_for_batch(batch_findings) + # Apply inherited Product tags to this batch's findings (and + # their endpoints/locations) BEFORE post_process_findings_batch + # dispatches, so rules/dedup see inherited tags on .tags. + apply_inherited_tags_for_findings(batch_findings) batch_findings.clear() finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 3df3c0bd4cc..9940cb337b8 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -24,7 +24,9 @@ Test, Test_Import, ) -from dojo.tag_utils import bulk_apply_parser_tags +from dojo.tags import inheritance as tag_inheritance +from dojo.tags.inheritance import apply_inherited_tags_for_findings +from dojo.tags.utils import bulk_apply_parser_tags from dojo.utils import perform_product_grading from dojo.validators import clean_tags @@ -263,6 +265,19 @@ def process_findings( the finding may be appended to a new or existing group based upon user selection at import time """ + # Whole hot loop runs under `batch_mode()`: per-row inheritance signals + # for the findings/endpoints/locations created below are suppressed. + # Inheritance is then applied in bulk per-batch (right before + # `post_process_findings_batch` dispatch) so rules/dedup see inherited + # tags on `finding.tags`. + with tag_inheritance.suppress_tag_inheritance(): + return self._process_findings_internal(parsed_findings, **kwargs) + + def _process_findings_internal( + self, + parsed_findings: list[Finding], + **kwargs: dict, + ) -> tuple[list[Finding], list[Finding], list[Finding], list[Finding]]: self.deduplication_algorithm = self.determine_deduplication_algorithm() # Only process findings with the same service value (or None) # Even though the service values is used in the hash_code calculation, @@ -302,6 +317,11 @@ def process_findings( batch_finding_ids: list[int] = [] batch_findings: list[Finding] = [] + # Findings that were newly created (else branch below) — pass these to + # `apply_inherited_tags_for_findings` instead of `batch_findings` so + # matched/existing findings (which already have correct inherited tags) + # don't trigger a redundant through-table read on no-change reimports. + new_findings_in_batch: list[Finding] = [] findings_with_parser_tags: list[tuple] = [] # Batch size for deduplication/post-processing (only new findings) dedupe_batch_max_size = getattr(settings, "IMPORT_REIMPORT_DEDUPE_BATCH_SIZE", 1000) @@ -384,6 +404,8 @@ def process_findings( candidates_by_uid, candidates_by_key, ) + if finding: + new_findings_in_batch.append(finding) # This condition __appears__ to always be true, but am afraid to remove it if finding: @@ -422,6 +444,14 @@ def process_findings( findings_with_parser_tags.clear() # Apply import-time tags before post-processing so rules/deduplication see them. self.apply_import_tags_for_batch(batch_findings) + # Apply inherited Product tags to NEWLY CREATED findings only + # (and their endpoints/locations) BEFORE post_process_findings_batch + # dispatches, so rules/dedup see inherited tags on .tags. + # Matched/existing findings already have inheritance applied from + # their original creation; re-running it on no-change reimports + # would be ~8 wasted queries per batch. + apply_inherited_tags_for_findings(new_findings_in_batch) + new_findings_in_batch.clear() batch_findings.clear() finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() @@ -949,7 +979,7 @@ def finding_post_processing( finding_from_report: Finding, *, is_matched_finding: bool = False, - tag_accumulator: list | None = None, + tag_accumulator: list, ) -> Finding: """ Save all associated objects to the finding after it has been saved @@ -971,15 +1001,10 @@ def finding_post_processing( finding_from_report.unsaved_tags = merged_tags if finding_from_report.unsaved_tags: cleaned_tags = clean_tags(finding_from_report.unsaved_tags) - if tag_accumulator is not None: - if isinstance(cleaned_tags, list): - tag_accumulator.append((finding, cleaned_tags)) - elif isinstance(cleaned_tags, str): - tag_accumulator.append((finding, [cleaned_tags])) - elif isinstance(cleaned_tags, list): - finding.tags.add(*cleaned_tags) + if isinstance(cleaned_tags, list): + tag_accumulator.append((finding, cleaned_tags)) elif isinstance(cleaned_tags, str): - finding.tags.add(cleaned_tags) + tag_accumulator.append((finding, [cleaned_tags])) # Process any files if finding_from_report.unsaved_files: finding.unsaved_files = finding_from_report.unsaved_files diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 21e9673ed8c..f3af161f159 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -14,7 +14,7 @@ Finding, Product, ) -from dojo.tags_signals import inherit_instance_tags +from dojo.tags.inheritance import apply_inherited_tags_for_endpoints logger = logging.getLogger(__name__) @@ -231,10 +231,16 @@ def get_or_create_endpoints(self) -> tuple[dict[EndpointUniqueKey, Endpoint], li if to_create: created = Endpoint.objects.bulk_create(to_create, batch_size=1000) endpoints_by_key.update(zip(to_create_keys, created, strict=True)) - # bulk_create bypasses post_save signals, so manually trigger tag inheritance - # this is not ideal, but we need to take a separate look at the tag inheritance feature itself later - for ep in created: - inherit_instance_tags(ep) + # bulk_create bypasses post_save so per-row inheritance signals never + # fire here. The importer hot path already covers these endpoints via + # the per-batch `apply_inherited_tags_for_findings` sweep (it picks + # them up through `Endpoint.status_finding.finding`), so this call is + # redundant for the importer. We keep a bulk call anyway as a defensive + # measure: if anything outside the importer ever bulk-creates endpoints + # through this manager, they still receive their inherited Product tags + # instead of silently missing them. The bulk helper costs ~2 queries + # when there's nothing to apply, vs N per-row signal fires. + apply_inherited_tags_for_endpoints(created) self._endpoints_to_create.clear() return endpoints_by_key, created diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index dc0171db731..a2e08e5272e 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -9,19 +9,15 @@ from django.db import transaction from django.utils import timezone -from dojo import tag_inheritance from dojo.importers.base_location_manager import BaseLocationManager from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus -from dojo.models import Product, _manage_inherited_tags +from dojo.tags import inheritance as tag_inheritance from dojo.tools.locations import LocationData from dojo.url.models import URL -from dojo.utils import get_system_setting if TYPE_CHECKING: - from tagulous.models import TagField - - from dojo.models import Dojo_User, Finding + from dojo.models import Dojo_User, Finding, Product logger = logging.getLogger(__name__) @@ -214,8 +210,18 @@ def _persist_locations(self) -> None: all_product_refs, batch_size=1000, ignore_conflicts=True, ) - # Trigger bulk tag inheritance - self._bulk_inherit_tags(loc.location for loc in saved) + # Trigger bulk tag inheritance only when the Location's product + # membership actually changed. New product refs are the only thing + # that can add a Product to a Location's inherited-tags target set + # (new finding refs are always to findings in `self._product`, so + # they don't introduce a new Product); skipping when `all_product_refs` + # is empty avoids the through-table read on no-change reimports. + if all_product_refs: + new_ref_location_ids = {ref.location_id for ref in all_product_refs} + tag_inheritance.apply_inherited_tags_for_locations( + [loc.location for loc in saved if loc.location_id in new_ref_location_ids], + product=self._product, + ) # Clear accumulators self._locations_by_finding.clear() @@ -477,105 +483,3 @@ def type_id(x: tuple[int, AbstractLocation]) -> int: # Restore the original input ordering saved.sort(key=itemgetter(0)) return [loc for _, loc in saved] - - # ------------------------------------------------------------------ - # Tag inheritance - # ------------------------------------------------------------------ - - def _bulk_inherit_tags(self, locations): - """ - Bulk equivalent of calling inherit_instance_tags(loc) for many Locations. Actually persisting updates is handled - by a per-location call to _manage_inherited_tags(), but at least determining what the tags are is more efficient - (plus we can skip locations that don't need an update at all). - - When tag inheritance is enabled, computes the target inherited tags for each location from all related products - and updates only locations that are out of sync. - """ - locations = list(locations) - if not locations: - return - - # Check whether tag inheritance is enabled at either the product level or system-wide; quit early if neither - product_inherit = getattr(self._product, "enable_product_tag_inheritance", False) - system_wide_inherit = bool(get_system_setting("enable_product_tag_inheritance")) - if not system_wide_inherit and not product_inherit: - return - - # A location can be shared across multiple products. Its inherited tags should be the union of - # tags from ALL contributing products, not just the one running this import. - location_ids = [loc.id for loc in locations] - product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} - - # Find associations through LocationProductReference entries - for loc_id, prod_id in LocationProductReference.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", "product_id"): - product_ids_by_location[loc_id].add(prod_id) - - # Find associations through LocationFindingReference entries and the finding.test.engagement.product chain. - # This shouldn't add anything new, but just in case. - for loc_id, prod_id in ( - LocationFindingReference.objects - .filter(location_id__in=location_ids) - .values_list("location_id", "finding__test__engagement__product_id") - ): - product_ids_by_location[loc_id].add(prod_id) - - # Fetch all products that will contribute to tag inheritance, and their tags - all_product_ids = {pid for pids in product_ids_by_location.values() for pid in pids} - product_qs = Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") - if not system_wide_inherit: - # Product-level inheritance only - product_qs = product_qs.filter(enable_product_tag_inheritance=True) - # Materialize into a dict for ease of use - products: dict[int, Product] = {p.id: p for p in product_qs} - # Get distinct tags, per-product - tags_by_product: dict[int, set[str]] = { - pid: {t.name for t in p.tags.all()} - for pid, p in products.items() - } - - # Helper method for getting all tags from the given TagField - def _get_tags(tags_field: TagField) -> dict[int, set[str]]: - through_model = tags_field.through - fk_name = tags_field.field.m2m_reverse_field_name() - tags_by_location: dict[int, set[str]] = {loc.id: set() for loc in locations} - for l_id, t_name in through_model.objects.filter( - location_id__in=location_ids, - ).values_list("location_id", f"{fk_name}__name"): - tags_by_location[l_id].add(t_name) - return tags_by_location - - # Gather inherited and 'regular' tags per location - existing_inherited_by_location: dict[int, set[str]] = _get_tags(Location.inherited_tags) - existing_tags_by_location: dict[int, set[str]] = _get_tags(Location.tags) - - # Perform the bulk updates inside a `tag_inheritance.batch()` context. - # While the batch is active, signal handlers in `dojo/tags_signals.py` - # short-circuit per-row inheritance work that would otherwise fire on - # every `(inherited_)tags.set()` and defeat the bulk update. - # - # This replaces a previous `signals.m2m_changed.disconnect(...)` / - # `connect(...)` dance which was process-global and therefore unsafe - # under threaded gunicorn / Celery thread pools / ASGI threadpools: - # while disconnected, every thread in the process lost sticky - # enforcement. Thread-local batch state avoids that hazard. - with tag_inheritance.batch_mode(): - for location in locations: - target_tag_names: set[str] = set() - for pid in product_ids_by_location[location.id]: - # product_ids_by_location may contain products that shouldn't to contribute to tag inheritance (we - # didn't filter either location ref lookups to check), so do a last-minute check here - if pid in products: - target_tag_names |= tags_by_product[pid] - - if target_tag_names == existing_inherited_by_location[location.id]: - # The existing set matches the expected set, so nothing more to do for this location - continue - - # Update tags for this location - _manage_inherited_tags( - location, - list(target_tag_names), - potentially_existing_tags=existing_tags_by_location[location.id], - ) diff --git a/dojo/location/models.py b/dojo/location/models.py index 48ffeb6878a..6c6b738655d 100644 --- a/dojo/location/models.py +++ b/dojo/location/models.py @@ -34,7 +34,7 @@ LocationQueryset, ) from dojo.location.status import FindingLocationStatus, ProductLocationStatus -from dojo.models import Dojo_User, Finding, Product, _manage_inherited_tags, copy_model_util +from dojo.models import Dojo_User, Finding, Product, copy_model_util from dojo.tools.locations import LocationAssociationData if TYPE_CHECKING: @@ -231,19 +231,48 @@ def all_related_products(self) -> QuerySet[Product]: | Q(engagement__test__finding__locations__location=self), ).distinct() + def iter_related_products(self) -> list[Product]: + """ + Prefetch-friendly equivalent of `all_related_products()`. + + Walks `self.products.all()` (LocationProductReference) and + `self.findings.all()` (LocationFindingReference -> Finding -> Test -> + Engagement -> Product) via Django related managers, so a caller that + already issued + + Location.objects.filter(...).prefetch_related( + "products__product__tags", + "findings__finding__test__engagement__product__tags", + ) + + gets every Product (and its tags) in 0 extra queries per Location. + + Use this method from bulk paths where many Locations are processed at + once. The original `all_related_products()` still issues a single + DISTINCT JOIN query and is kept for per-instance signal paths where + prefetching is not possible. + """ + seen: set[int] = set() + result: list[Product] = [] + for ref in self.products.all(): + if ref.product_id not in seen: + seen.add(ref.product_id) + result.append(ref.product) + for ref in self.findings.all(): + product = ref.finding.test.engagement.product + if product.id not in seen: + seen.add(product.id) + result.append(product) + return result + def products_to_inherit_tags_from(self) -> list[Product]: from dojo.utils import get_system_setting # noqa: PLC0415 - system_wide_inherit = get_system_setting("enable_product_tag_inheritance") - return [ - product for product - in self.all_related_products() - if product.enable_product_tag_inheritance or system_wide_inherit - ] - - def inherit_tags(self, potentially_existing_tags): - # get a copy of the tags to be inherited - incoming_inherited_tags = [tag.name for product in self.products_to_inherit_tags_from() for tag in product.tags.all()] - _manage_inherited_tags(self, incoming_inherited_tags, potentially_existing_tags=potentially_existing_tags) + # System-wide setting is cached — short-circuit before reading the + # per-product flag on every related product. + products = self.all_related_products() + if get_system_setting("enable_product_tag_inheritance"): + return products + return [product for product in products if product.enable_product_tag_inheritance] class Meta: verbose_name = "Locations - Location" diff --git a/dojo/models.py b/dojo/models.py index 420d85909eb..a41f5640889 100644 --- a/dojo/models.py +++ b/dojo/models.py @@ -41,7 +41,7 @@ from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicModel from tagulous.models import TagField -from tagulous.models.managers import FakeTagRelatedManager +from tagulous.models.managers import FakeTagRelatedManager # noqa: F401 -- backward compat re-export from titlecase import titlecase from dojo.base_models.base import BaseModel @@ -110,28 +110,12 @@ def _get_statistics_for_queryset(qs, annotation_factory): return stats -def _manage_inherited_tags(obj, incoming_inherited_tags, potentially_existing_tags=None): - # get copies of the current tag lists - if potentially_existing_tags is None: - potentially_existing_tags = [] - current_inherited_tags = [] if isinstance(obj.inherited_tags, FakeTagRelatedManager) else [tag.name for tag in obj.inherited_tags.all()] - tag_list = potentially_existing_tags if isinstance(obj.tags, FakeTagRelatedManager) or len(potentially_existing_tags) > 0 else [tag.name for tag in obj.tags.all()] - # Clean existing tag list from the old inherited tags. This represents the tags on the object and not the product - cleaned_tag_list = [tag for tag in tag_list if tag not in current_inherited_tags] - # Add the incoming inherited tag list - if incoming_inherited_tags: - for tag in incoming_inherited_tags: - if tag not in cleaned_tag_list: - cleaned_tag_list.append(tag) - # Update the current list of inherited tags. iteratively do this because of tagulous object restraints - if isinstance(obj.inherited_tags, FakeTagRelatedManager): - obj.inherited_tags.set_tag_list(incoming_inherited_tags) - if incoming_inherited_tags: - obj.tags.set_tag_list(cleaned_tag_list) - else: - obj.inherited_tags.set(incoming_inherited_tags) - if incoming_inherited_tags: - obj.tags.set(cleaned_tag_list) +def _sync_inherited_tags(obj, incoming_inherited_tags): + # Backward-compat shim. Implementation lives in dojo.tags.inheritance; lazy + # import keeps dojo.models loadable before dojo.tags.inheritance (which + # transitively imports dojo.utils -> dojo.models) is ready. + from dojo.tags.inheritance import _sync_inherited_tags as _impl # noqa: PLC0415 + return _impl(obj, incoming_inherited_tags) def copy_model_util(model_in_database, exclude_fields: list[str] | None = None): @@ -1585,11 +1569,6 @@ def delete(self, *args, **kwargs): from dojo.utils import perform_product_grading # noqa: PLC0415 circular import perform_product_grading(self.product) - def inherit_tags(self, potentially_existing_tags): - # get a copy of the tags to be inherited - incoming_inherited_tags = [tag.name for tag in self.product.tags.all()] - _manage_inherited_tags(self, incoming_inherited_tags, potentially_existing_tags=potentially_existing_tags) - class CWE(models.Model): url = models.CharField(max_length=1000) @@ -2035,11 +2014,6 @@ def from_uri(uri): fragment=fragment, ) - def inherit_tags(self, potentially_existing_tags): - # get a copy of the tags to be inherited - incoming_inherited_tags = [tag.name for tag in self.product.tags.all()] - _manage_inherited_tags(self, incoming_inherited_tags, potentially_existing_tags=potentially_existing_tags) - class Development_Environment(models.Model): name = models.CharField(max_length=200) @@ -2240,11 +2214,6 @@ def statistics(self): """Queries the database, no prefetching, so could be slow for lists of model instances""" return _get_statistics_for_queryset(Finding.objects.filter(test=self), _get_annotations_for_statistics) - def inherit_tags(self, potentially_existing_tags): - # get a copy of the tags to be inherited - incoming_inherited_tags = [tag.name for tag in self.engagement.product.tags.all()] - _manage_inherited_tags(self, incoming_inherited_tags, potentially_existing_tags=potentially_existing_tags) - class Test_Import(TimeStampedModel): @@ -3544,11 +3513,6 @@ def vulnerability_ids(self): # Remove duplicates return list(dict.fromkeys(vulnerability_ids)) - def inherit_tags(self, potentially_existing_tags): - # get a copy of the tags to be inherited - incoming_inherited_tags = [tag.name for tag in self.test.engagement.product.tags.all()] - _manage_inherited_tags(self, incoming_inherited_tags, potentially_existing_tags=potentially_existing_tags) - @property def violates_sla(self): return (self.sla_expiration_date and self.sla_expiration_date < timezone.now().date()) diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py deleted file mode 100644 index 7bbb2937103..00000000000 --- a/dojo/product/helpers.py +++ /dev/null @@ -1,135 +0,0 @@ -import contextlib -import logging -from collections import defaultdict - -from django.conf import settings -from django.db.models import Q - -from dojo.celery import app -from dojo.location.models import Location -from dojo.models import Endpoint, Engagement, Finding, Product, Test -from dojo.tag_utils import bulk_add_tag_mapping, bulk_remove_tags_from_instances - -logger = logging.getLogger(__name__) - - -@app.task -def propagate_tags_on_product(product_id, *args, **kwargs): - with contextlib.suppress(Product.DoesNotExist): - product = Product.objects.get(id=product_id) - propagate_tags_on_product_sync(product) - - -def propagate_tags_on_product_sync(product): - """ - Bulk-apply Product tag changes to all children using through-table SQL. - - Replaces the previous per-row `.save()` loop. For every child model owned - by the product (Engagement, Test, Finding, plus Endpoint or Location - depending on the V3_FEATURE_LOCATIONS flag), reads the existing - `inherited_tags` per child in one query, computes the diff against the - Product's current tags, and applies adds/removes via the bulk tag - helpers. Both `tags` and `inherited_tags` fields are kept in sync. - """ - target_names = {tag.name for tag in product.tags.all()} - - logger.debug("Propagating tags from %s to all engagements", product) - _sync_inheritance_for_qs( - Engagement.objects.filter(product=product), - target_names_per_child=lambda _child: target_names, - ) - logger.debug("Propagating tags from %s to all tests", product) - _sync_inheritance_for_qs( - Test.objects.filter(engagement__product=product), - target_names_per_child=lambda _child: target_names, - ) - logger.debug("Propagating tags from %s to all findings", product) - _sync_inheritance_for_qs( - Finding.objects.filter(test__engagement__product=product), - target_names_per_child=lambda _child: target_names, - ) - if settings.V3_FEATURE_LOCATIONS: - logger.debug("Propagating tags from %s to all locations", product) - location_qs = Location.objects.filter( - Q(products__product=product) - | Q(findings__finding__test__engagement__product=product), - ).distinct() - # Locations can be linked to multiple products, so the inherited target - # is the union of every related product's tags. Compute per-location. - _sync_inheritance_for_qs( - location_qs, - target_names_per_child=_location_target_names, - ) - else: - logger.debug("Propagating tags from %s to all endpoints", product) - _sync_inheritance_for_qs( - Endpoint.objects.filter(product=product), - target_names_per_child=lambda _child: target_names, - ) - - -def _location_target_names(location): - names: set[str] = set() - for related_product in location.all_related_products(): - if related_product is None: - continue - names.update(tag.name for tag in related_product.tags.all()) - return names - - -def _sync_inheritance_for_qs(queryset, *, target_names_per_child): - """ - Sync inherited_tags + tags for every child in `queryset` to its target tag set. - - target_names_per_child: callable(child) -> set[str]. - - Issues bulk SQL: one through-table read for current inherited_tags, then - bulk add/remove on `tags` and `inherited_tags` fields. - """ - children = list(queryset) - if not children: - return - - model_class = type(children[0]) - inherited_field = model_class._meta.get_field("inherited_tags") - inherited_through = inherited_field.remote_field.through - inherited_tag_model = inherited_field.related_model - - # Resolve through-table FK column for the source side. - source_field_name = None - for field in inherited_through._meta.fields: - if hasattr(field, "remote_field") and field.remote_field and field.remote_field.model == model_class: - source_field_name = field.name - break - - child_ids = [c.pk for c in children] - # One query: pull every (child_id, tag_name) pair from the inherited_tags through table. - existing_pairs = inherited_through.objects.filter( - **{f"{source_field_name}__in": child_ids}, - ).values_list(source_field_name, f"{inherited_tag_model._meta.model_name}__name") - - old_inherited_by_child: dict[int, set[str]] = defaultdict(set) - for child_id, tag_name in existing_pairs: - old_inherited_by_child[child_id].add(tag_name) - - # Compute per-child diff and bucket by tag name. - add_map: dict[str, list] = defaultdict(list) - remove_map: dict[str, list] = defaultdict(list) - for child in children: - target = target_names_per_child(child) - old = old_inherited_by_child.get(child.pk, set()) - for name in target - old: - add_map[name].append(child) - for name in old - target: - remove_map[name].append(child) - - # Apply adds. Both `tags` and `inherited_tags` get the same set of new - # inherited names — `_manage_inherited_tags` did the same. - if add_map: - bulk_add_tag_mapping(add_map, tag_field_name="inherited_tags") - bulk_add_tag_mapping(add_map, tag_field_name="tags") - - # Apply removes. - for name, instances in remove_map.items(): - bulk_remove_tags_from_instances(name, instances, tag_field_name="inherited_tags") - bulk_remove_tags_from_instances(name, instances, tag_field_name="tags") diff --git a/dojo/settings/settings.dist.py b/dojo/settings/settings.dist.py index 522ad54e7c0..70c58bddbee 100644 --- a/dojo/settings/settings.dist.py +++ b/dojo/settings/settings.dist.py @@ -329,7 +329,7 @@ def generate_url(scheme, double_slashes, user, password, host, port, path, param _populate_notifications_settings(env, globals()) TAG_PREFETCHING = env("DD_TAG_PREFETCHING") -# Tag bulk add batch size (used by dojo.tag_utils.bulk_add_tag_to_instances) +# Tag bulk add batch size (used by dojo.tags.utils.bulk_add_tag_to_instances) TAG_BULK_ADD_BATCH_SIZE = env("DD_TAG_BULK_ADD_BATCH_SIZE") diff --git a/dojo/tag_inheritance.py b/dojo/tag_inheritance.py deleted file mode 100644 index e9d0e98a5fc..00000000000 --- a/dojo/tag_inheritance.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Tag inheritance — central coordination module. - -Provides a thread-local ``batch()`` context manager that suppresses -per-instance inheritance work driven by ``m2m_changed`` and ``post_save`` -signals. While inside a batch, the signal handlers in -``dojo/tags_signals.py`` early-return; the calling code is responsible for -applying inheritance in bulk (e.g. via the importer's existing -``_bulk_inherit_tags`` path or ``propagate_tags_on_product_sync``). - -This replaces the previous pattern of ``signals.m2m_changed.disconnect(...)`` -in importer hot loops, which was process-global and unsafe under threaded -gunicorn / Celery thread pools / ASGI threadpools (see PR description for -the full rationale). -""" -from __future__ import annotations - -import contextlib -import threading -from contextlib import contextmanager - -_state = threading.local() - - -def is_in_batch_mode() -> bool: - """Return True when the current thread is inside an active ``batch()``.""" - return bool(getattr(_state, "depth", 0)) - - -@contextmanager -def batch_mode(): - """ - Suppress per-instance inheritance signals for the calling thread. - - Usage: - with tag_inheritance.batch(): - # Bulk operations that would otherwise fire `make_inherited_tags_sticky` - # or `inherit_tags_on_instance` per row. - ... - - The context is reentrant; nested ``with`` blocks share the suppression - until the outermost block exits. State lives in ``threading.local()``, - so concurrent threads (and Celery workers in non-prefork pools) are - unaffected by other threads' batches. - """ - _state.depth = getattr(_state, "depth", 0) + 1 - try: - yield - finally: - _state.depth -= 1 - if _state.depth <= 0: - # Clean up the attribute so leak-free thread reuse stays simple. - with contextlib.suppress(AttributeError): - del _state.depth diff --git a/dojo/tags/__init__.py b/dojo/tags/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dojo/tags/inheritance.py b/dojo/tags/inheritance.py new file mode 100644 index 00000000000..206298e6282 --- /dev/null +++ b/dojo/tags/inheritance.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +import logging +import threading +from collections import defaultdict +from contextlib import contextmanager, suppress +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from django.db.models import Model + +from django.conf import settings +from django.db.models import Q +from tagulous.models.managers import FakeTagRelatedManager + +# Top-level imports of dojo internals are safe here because +# ``dojo.tags.inheritance`` is loaded lazily — never during the initial +# evaluation of ``dojo.models``. By the time anything imports this module +# (signals registration, importers, the per-model ``inherit_tags()`` shim +# in ``dojo.models``), the full model layer is initialised. +from dojo.celery import app +from dojo.location.models import Location +from dojo.models import Endpoint, Engagement, Finding, Product, Test +from dojo.tags.utils import bulk_add_tag_mapping, bulk_remove_tags_from_instances +from dojo.utils import get_system_setting + +logger = logging.getLogger(__name__) + +_state = threading.local() + + +def is_suppressed() -> bool: + """Return True when the current thread is inside an active ``suppress_tag_inheritance()``.""" + return bool(getattr(_state, "depth", 0)) + + +@contextmanager +def suppress_tag_inheritance(): + """ + Suppress per-instance inheritance signals for the calling thread. + + Usage: + with tag_inheritance.suppress_tag_inheritance(): + # Bulk operations that would otherwise fire `make_inherited_tags_sticky` + # or `inherit_tags_on_instance` per row. + ... + + The context is reentrant; nested ``with`` blocks share the suppression + until the outermost block exits. State lives in ``threading.local()``, + so concurrent threads (and Celery workers in non-prefork pools) are + unaffected by other threads' batches. + """ + _state.depth = getattr(_state, "depth", 0) + 1 + try: + yield + finally: + _state.depth -= 1 + if _state.depth <= 0: + # Clean up the attribute so leak-free thread reuse stays simple. + with suppress(AttributeError): + del _state.depth + + +def _sync_inherited_tags(obj, incoming_inherited_tags): + """ + Sync ``obj.inherited_tags`` and ``obj.tags`` to match ``incoming_inherited_tags``. + + Diff-based: only the inherited names that changed are added/removed. Also + re-adds any inherited name that has been stripped from ``obj.tags`` directly + (sticky enforcement). + + Writes are wrapped in ``suppress_tag_inheritance()`` so the m2m_changed + signal fired by each ``.add()``/``.remove()`` does not dispatch + ``make_inherited_tags_sticky`` back into this function. The context + manager is reentrant so callers that already opened a batch (e.g. + ``auto_inherit_product_tags`` in ``dojo.tags.signals``, or the importer's + bulk path) + nest harmlessly. + """ + target = set(incoming_inherited_tags or []) + + # Unsaved instance: FakeTagRelatedManager has no .all()/.add()/.remove(). + # Set in-memory tag lists directly, merging incoming into any preset tags. + # set_tag_list() is purely in-memory — no DB write, no m2m_changed — so it + # doesn't need the suppress wrap. The `obj.tags.add(*target)` fallback + # below covers a theoretical mixed-state case (saved tags manager next to + # an unsaved inherited_tags manager) and DOES fire m2m_changed, so it + # gets wrapped. + if isinstance(obj.inherited_tags, FakeTagRelatedManager): + obj.inherited_tags.set_tag_list(list(target)) + if target: + if isinstance(obj.tags, FakeTagRelatedManager): + existing = obj.tags.get_tag_list() + obj.tags.set_tag_list(list(dict.fromkeys([*existing, *target]))) + else: + # avoid reentrancy: the `add(*target)` write fires m2m_changed + with suppress_tag_inheritance(): + obj.tags.add(*target) + return + + current_inherited = {tag.name for tag in obj.inherited_tags.all()} + current_tags = {tag.name for tag in obj.tags.all()} + to_remove = current_inherited - target + to_add = target - current_inherited + # Sticky: any target name already absent from obj.tags AND not covered by + # to_add (user-driven m2m_changed stripped it). Re-add separately. + sticky_missing = (target - current_tags) - to_add + + # avoid reentrancy: the `remove(*to_remove)` / `add(*to_add)` / `add(*sticky_missing)` writes fire m2m_changed + with suppress_tag_inheritance(): + if to_remove: + obj.inherited_tags.remove(*to_remove) + obj.tags.remove(*to_remove) + if to_add: + obj.inherited_tags.add(*to_add) + obj.tags.add(*to_add) + if sticky_missing: + obj.tags.add(*sticky_missing) + + +def get_products(instance): + if isinstance(instance, Product): + return [instance] + if isinstance(instance, Endpoint): + return [instance.product] + if isinstance(instance, Engagement): + return [instance.product] + if isinstance(instance, Test): + return [instance.engagement.product] + if isinstance(instance, Finding): + return [instance.test.engagement.product] + if isinstance(instance, Location): + return list(instance.all_related_products()) + return [] + + +def get_products_to_inherit_tags_from(instance): + products = [p for p in get_products(instance) if p] + # System-wide setting is cached — short-circuit before reading the + # per-product flag on every related product. + if get_system_setting("enable_product_tag_inheritance"): + return products + return [product for product in products if product.enable_product_tag_inheritance] + + +def is_tag_inheritance_enabled(instance) -> bool: + # delegate so we have logic centralized. no products -> no inheritance enabled. + return bool(get_products_to_inherit_tags_from(instance)) + + +# --------------------------------------------------------------------------- +# Bulk product-wide inheritance +# --------------------------------------------------------------------------- + + +def propagate_tags_on_product_sync(product): + """ + Bulk-apply Product tag changes to all children using through-table SQL. + + Replaces the previous per-row `.save()` loop. For every child model owned + by the product (Engagement, Test, Finding, plus Endpoint or Location + depending on the V3_FEATURE_LOCATIONS flag), reads the existing + `inherited_tags` per child in one query, computes the diff against the + Product's current tags, and applies adds/removes via the bulk tag + helpers. Both `tags` and `inherited_tags` fields are kept in sync. + """ + if not (get_system_setting("enable_product_tag_inheritance") or product.enable_product_tag_inheritance): + return + + inherited_tag_names = {tag.name for tag in product.tags.all()} + + logger.debug("Propagating tags from %s to all engagements", product) + _sync_inheritance_for_ids( + Engagement, + Engagement.objects.filter(product=product).values_list("pk", flat=True), + target_tag_names=inherited_tag_names, + ) + logger.debug("Propagating tags from %s to all tests", product) + _sync_inheritance_for_ids( + Test, + Test.objects.filter(engagement__product=product).values_list("pk", flat=True), + target_tag_names=inherited_tag_names, + ) + logger.debug("Propagating tags from %s to all findings", product) + _sync_inheritance_for_ids( + Finding, + Finding.objects.filter(test__engagement__product=product).values_list("pk", flat=True), + target_tag_names=inherited_tag_names, + ) + if settings.V3_FEATURE_LOCATIONS: + logger.debug("Propagating tags from %s to all locations", product) + # Locations can be linked to multiple products, so the inherited target + # is the union of every related product's tags. Materialize the full + # Locations (with the related-product prefetch chain) into a pk-keyed + # dict so the per-pk callback can look up each Location's instance. + locations_by_pk = { + loc.pk: loc + for loc in Location.objects.filter( + Q(products__product=product) + | Q(findings__finding__test__engagement__product=product), + ).distinct().prefetch_related(*_LOCATION_PREFETCH_FOR_INHERITANCE) + } + _sync_inheritance_for_ids( + Location, + locations_by_pk.keys(), + target_tag_names=lambda pk: _inherited_tag_names_for_location(locations_by_pk[pk]), + ) + else: + logger.debug("Propagating tags from %s to all endpoints", product) + _sync_inheritance_for_ids( + Endpoint, + Endpoint.objects.filter(product=product).values_list("pk", flat=True), + target_tag_names=inherited_tag_names, + ) + + +@app.task(name="dojo.product.helpers.propagate_tags_on_product") +def propagate_tags_on_product_deprecated(product_id, *args, **kwargs): + # kept to make sure tasks are still processed if someone didn't do a clean shutdown before upgrading + logger.warning("propagate_tags_on_product_deprecated is deprecated and will be removed in a future version. Use propagate_tags_on_product instead.") + propagate_tags_on_product(product_id, *args, **kwargs) + + +@app.task +def propagate_tags_on_product(product_id, *args, **kwargs): + """Load Product by id and run ``propagate_tags_on_product_sync`` (Celery worker).""" + with suppress(Product.DoesNotExist): + product = Product.objects.get(id=product_id) + propagate_tags_on_product_sync(product) + + +def apply_inherited_tags_for_endpoints(endpoints): + """ + Bulk inheritance for a list of Endpoints, e.g. those just created via + `Endpoint.objects.bulk_create` (which bypasses post_save signals). + + All endpoints are assumed to share a single Product — true for the + importer's `EndpointManager`, which is per-product. If callers ever + mix products, split the list before calling. + """ + if not endpoints: + return + product = endpoints[0].product + if not (get_system_setting("enable_product_tag_inheritance") or product.enable_product_tag_inheritance): + return + inherited_tag_names = {tag.name for tag in product.tags.all()} + _sync_inheritance_for_ids( + Endpoint, + [e.pk for e in endpoints], + target_tag_names=inherited_tag_names, + ) + + +def apply_inherited_tags_for_findings(findings): + """ + Per-batch bulk inheritance for findings created during an import. + + Apply the owning Product's inherited tags to the given findings plus the + Endpoints (V2) / Locations (V3) reachable from them. Called from the + importer hot path right before each batch dispatches to + `post_process_findings_batch` so rules / deduplication see inherited tags + on `finding.tags`. + + Test and Engagement inheritance is handled by their own post_save handlers + (those run outside the importer's `batch_mode()`, so per-instance signal + work fires normally and applies inheritance on create). + """ + if not findings: + return + # Single-product invariant inside one importer call. Smart upload calls + # this per-product so the assumption holds there too. + product = findings[0].test.engagement.product + if not (get_system_setting("enable_product_tag_inheritance") or product.enable_product_tag_inheritance): + return + inherited_tag_names = {tag.name for tag in product.tags.all()} + finding_ids = [f.pk for f in findings] + + _sync_inheritance_for_ids( + Finding, + finding_ids, + target_tag_names=inherited_tag_names, + ) + if settings.V3_FEATURE_LOCATIONS: + locations_by_pk = { + loc.pk: loc + for loc in Location.objects.filter( + findings__finding_id__in=finding_ids, + ).distinct().prefetch_related(*_LOCATION_PREFETCH_FOR_INHERITANCE) + } + _sync_inheritance_for_ids( + Location, + locations_by_pk.keys(), + target_tag_names=lambda pk: _inherited_tag_names_for_location(locations_by_pk[pk]), + ) + else: + _sync_inheritance_for_ids( + Endpoint, + Endpoint.objects.filter(status_endpoint__finding_id__in=finding_ids).distinct().values_list("pk", flat=True), + target_tag_names=inherited_tag_names, + ) + + +def _inherited_tag_names_for_location(location): + """ + Compute the tag-name set this Location should have as `inherited_tags`. + + Unlike Finding / Test / Engagement / Endpoint (each owned by exactly one + Product), a Location can be attached to multiple Products — directly via + `LocationProductReference` or indirectly via `LocationFindingReference` + -> Finding -> Test -> Engagement -> Product. The target inherited set is + therefore the UNION of every related Product's tags, restricted to + Products whose own `enable_product_tag_inheritance` flag is on (or where + the system-wide setting is on). + + Used as the `target_tag_names` callback for `_sync_inheritance_for_ids` + on Location querysets; it must be called per Location because each Location + has its own set of related Products. Uses `iter_related_products()` so + that an upstream `prefetch_related(...)` reduces per-call cost to 0 + queries. + """ + system_wide = bool(get_system_setting("enable_product_tag_inheritance")) + names: set[str] = set() + for related_product in location.iter_related_products(): + if related_product is None: + continue + if not system_wide and not related_product.enable_product_tag_inheritance: + continue + names.update(tag.name for tag in related_product.tags.all()) + return names + + +def apply_inherited_tags_for_locations(locations, *, product): + """ + Per-batch bulk inheritance for Locations touched during an import. + + A Location can be linked to multiple Products via `LocationProductReference` + (direct) or `LocationFindingReference` -> Finding -> Test -> Engagement -> + Product (indirect). Target inherited set is the union of every contributing + Product's tags, filtered by each Product's `enable_product_tag_inheritance` + flag (skipped entirely when the system-wide setting is on). + + Gated on the importing `product`: when neither the system setting nor the + importing product's flag is on, this is a no-op. Tags from other products + propagate via their own `Product.tags.through` m2m_changed handler when + they change, so skipping here is safe. + + Uses values_list-based ref-table lookups (4 small queries) rather than + `prefetch_related(_LOCATION_PREFETCH_FOR_INHERITANCE)` to keep the + importer hot path lean. + """ + locations = list(locations) + if not locations: + return + system_wide = bool(get_system_setting("enable_product_tag_inheritance")) + if not system_wide and not getattr(product, "enable_product_tag_inheritance", False): + return + + from dojo.location.models import ( # noqa: PLC0415 + LocationFindingReference, + LocationProductReference, + ) + + location_ids = [loc.id for loc in locations] + product_ids_by_location: dict[int, set[int]] = {loc.id: set() for loc in locations} + + for loc_id, prod_id in LocationProductReference.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", "product_id"): + product_ids_by_location[loc_id].add(prod_id) + + # LocationFindingReference -> Finding -> Test -> Engagement -> Product. + # Shouldn't add anything new (LocationProductReference is created alongside), + # but covers edge cases where only the finding ref exists. + for loc_id, prod_id in ( + LocationFindingReference.objects + .filter(location_id__in=location_ids) + .values_list("location_id", "finding__test__engagement__product_id") + ): + product_ids_by_location[loc_id].add(prod_id) + + # prefetch all tags for all linked products up front so we can pass a callable to the sync function + all_product_ids = {pid for pids in product_ids_by_location.values() for pid in pids} + product_qs = Product.objects.filter(id__in=all_product_ids).prefetch_related("tags") + if not system_wide: + product_qs = product_qs.filter(enable_product_tag_inheritance=True) + tags_by_product: dict[int, set[str]] = { + p.id: {t.name for t in p.tags.all()} for p in product_qs + } + + def _target_tag_names_for_location_pk(pk): + names: set[str] = set() + for pid in product_ids_by_location[pk]: + # product_ids_by_location may contain products that shouldn't contribute + # (ref lookups weren't flag-filtered); check membership in tags_by_product. + tags = tags_by_product.get(pid) + if tags: + names |= tags + return names + + _sync_inheritance_for_ids( + Location, + [loc.id for loc in locations], + target_tag_names=_target_tag_names_for_location_pk, + ) + + +_LOCATION_PREFETCH_FOR_INHERITANCE = ( + "products__product__tags", + "findings__finding__test__engagement__product__tags", +) + + +def _sync_inheritance_for_ids( + model_class: type[Model], + child_ids: Iterable[int], + *, + target_tag_names: set[str] | Callable[[int], set[str]], +) -> None: + """ + Sync ``inherited_tags`` and ``tags`` for every child pk to its target tag set. + + Parameters + ---------- + model_class + The child model class (``Finding``, ``Engagement``, ``Endpoint``, + ``Location``, …). Used to resolve the ``inherited_tags`` field's + through-table and to build minimal pk-only stubs for the bulk helpers. + child_ids + Iterable of primary keys for ``model_class``. Pass a ``values_list("pk", + flat=True)`` queryset directly to avoid materializing model instances + — fetching full rows for 14000 findings was the bottleneck (~22s per + product-tag toggle) that motivated the pk-based design. + target_tag_names + The desired inherited-tag-name set for each child, in one of two forms: + + - ``set[str]`` — **constant target**. All children share the same + inherited set (Product → Engagement/Test/Finding/Endpoint + propagation, where every child has the same one parent product). + The value is hoisted out of the per-pk loop so there is no per-row + function-call overhead. + - ``Callable[[int], set[str]]`` — **per-pk target**. Looks up the + target set for each pk. Used for ``Location``, which can be linked + to multiple Products via ``LocationProductReference`` / + ``LocationFindingReference``, so the inherited set is the per-row + union of every linked Product's tags. Callers typically build a + ``{pk: location}`` dict with the relevant ``prefetch_related`` chain + and close over it inside the callback. + + Implementation notes + -------------------- + + + Avoids materializing children as full model instances. The previous + ``list(queryset)`` path fetched all 70+ columns per Finding row, which + dominated wall-clock time on large products. ``bulk_add_tag_mapping`` / + ``bulk_remove_tags_from_instances`` only ever read ``instance.pk`` and + ``instance.__class__``, so a bare ``model_class(pk=pid)`` stub is enough. + + Issues bulk SQL: one through-table read for the current ``inherited_tags`` + rows, then one INSERT per added tag-name and one DELETE per removed + tag-name (each batched if needed by the helpers). + + """ + child_ids = list(child_ids) + if not child_ids: + return + + inherited_field = model_class._meta.get_field("inherited_tags") + inherited_through = inherited_field.remote_field.through + inherited_tag_model = inherited_field.related_model + + # Resolve through-table FK column for the source side. + source_field_name = None + for field in inherited_through._meta.fields: + if hasattr(field, "remote_field") and field.remote_field and field.remote_field.model == model_class: + source_field_name = field.name + break + + # One query: pull every (child_id, tag_name) pair from the inherited_tags through table. + existing_pairs = inherited_through.objects.filter( + **{f"{source_field_name}__in": child_ids}, + ).values_list(source_field_name, f"{inherited_tag_model._meta.model_name}__name") + + old_inherited_by_child: dict[int, set[str]] = defaultdict(set) + for child_id, tag_name in existing_pairs: + old_inherited_by_child[child_id].add(tag_name) + + # Per-pk stub instances reused across tag buckets (bulk helpers only read + # .pk and .__class__). + stubs: dict[int, object] = {} + + def _stub(pk): + s = stubs.get(pk) + if s is None: + s = model_class(pk=pk) + stubs[pk] = s + return s + + constant_target: set[str] | None = None if callable(target_tag_names) else target_tag_names + + # Compute per-pk diff and bucket by tag name. + add_map: dict[str, list] = defaultdict(list) + remove_map: dict[str, list] = defaultdict(list) + for pk in child_ids: + target = constant_target if constant_target is not None else target_tag_names(pk) + old = old_inherited_by_child.get(pk, set()) + if target == old: + continue + for name in target - old: + add_map[name].append(_stub(pk)) + for name in old - target: + remove_map[name].append(_stub(pk)) + + # Apply adds. Both `tags` and `inherited_tags` get the same set of new + # inherited names — `_sync_inherited_tags` did the same. + if add_map: + bulk_add_tag_mapping(add_map, tag_field_name="inherited_tags") + bulk_add_tag_mapping(add_map, tag_field_name="tags") + + # Apply removes. + for name, instances in remove_map.items(): + bulk_remove_tags_from_instances(name, instances, tag_field_name="inherited_tags") + bulk_remove_tags_from_instances(name, instances, tag_field_name="tags") diff --git a/dojo/tags/signals.py b/dojo/tags/signals.py new file mode 100644 index 00000000000..6b8acaec59b --- /dev/null +++ b/dojo/tags/signals.py @@ -0,0 +1,91 @@ +import contextlib +import logging + +from django.db.models import signals +from django.dispatch import receiver + +from dojo.celery_dispatch import dojo_dispatch_task +from dojo.location.models import Location, LocationFindingReference, LocationProductReference +from dojo.models import Endpoint, Engagement, Finding, Product, Test +from dojo.tags import inheritance as tag_inheritance +from dojo.tags.inheritance import ( + _sync_inherited_tags, + get_products_to_inherit_tags_from, + is_suppressed, +) + +logger = logging.getLogger(__name__) + + +def auto_inherit_product_tags(instance): + """ + Apply product-inherited tags to ``instance`` from the auto-inheritance + signal path. + + Skipped while a ``suppress_tag_inheritance()`` context is active so bulk + callers (e.g. the importer hot loop) can defer per-instance work and run + inheritance once at batch time. The underlying ``_sync_inherited_tags`` + diffs the current vs target inherited set and only writes the delta. + """ + if is_suppressed(): + return + products = get_products_to_inherit_tags_from(instance) + if not products: + return + incoming_inherited_tags = [tag.name for product in products for tag in product.tags.all()] + _sync_inherited_tags(instance, incoming_inherited_tags) + + +@receiver(signals.m2m_changed, sender=Product.tags.through) +def product_tags_post_add_remove(sender, instance, action, **kwargs): + if action in {"post_add", "post_remove"}: + # `running_async_process` is an in-memory dedup flag on the Product + # instance. `tags.set([...])` fires m2m_changed twice on the SAME + # instance — once `post_remove` for dropped tags, once `post_add` for + # added tags — and we only want one `propagate_tags_on_product` task + # per Python-level operation. Not persisted: scope is exactly the + # lifetime of this in-memory instance. Two separate `Product.objects + # .get(id=X).tags.add(...)` calls still dispatch twice; the + # downstream task is idempotent (diff-based sync, no-op when nothing + # changed) so duplicates waste a slot but don't corrupt state. + running_async_process = False + with contextlib.suppress(AttributeError): + running_async_process = instance.running_async_process + if not running_async_process and tag_inheritance.is_tag_inheritance_enabled(instance): + dojo_dispatch_task(tag_inheritance.propagate_tags_on_product, instance.id, countdown=5) + instance.running_async_process = True + + +@receiver(signals.m2m_changed, sender=Endpoint.tags.through) +@receiver(signals.m2m_changed, sender=Engagement.tags.through) +@receiver(signals.m2m_changed, sender=Test.tags.through) +@receiver(signals.m2m_changed, sender=Finding.tags.through) +@receiver(signals.m2m_changed, sender=Location.tags.through) +def make_inherited_tags_sticky(sender, instance, action, **kwargs): + """Make sure inherited tags are added back in if they are removed.""" + if action in {"post_add", "post_remove"}: + auto_inherit_product_tags(instance) + + +@receiver(signals.post_save, sender=Endpoint) +@receiver(signals.post_save, sender=Engagement) +@receiver(signals.post_save, sender=Test) +@receiver(signals.post_save, sender=Finding) +@receiver(signals.post_save, sender=Location) +@receiver(signals.post_save, sender=LocationFindingReference) +@receiver(signals.post_save, sender=LocationProductReference) +def inherit_tags_on_instance(sender, instance, created, **kwargs): + # Only inherit on creation. Previously fired on every save (create OR + # update), repeatedly re-applying inherited tags to children whose tag + # state had not changed. Sticky enforcement on user-driven tag edits is + # handled by `make_inherited_tags_sticky` (m2m_changed). + # `auto_inherit_product_tags` itself early-returns when suppressed. + # + # For LocationFindingReference / LocationProductReference, the new link + # means the referenced Location may have a different set of related + # Products, so re-sync the Location's inherited tags. Ref status updates + # via `set_status` don't change the related-product set and are skipped. + if not created: + return + target = instance.location if isinstance(instance, LocationFindingReference | LocationProductReference) else instance + auto_inherit_product_tags(target) diff --git a/dojo/tag_utils.py b/dojo/tags/utils.py similarity index 100% rename from dojo/tag_utils.py rename to dojo/tags/utils.py diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py deleted file mode 100644 index 93ac1886930..00000000000 --- a/dojo/tags_signals.py +++ /dev/null @@ -1,132 +0,0 @@ -import contextlib -import logging - -from django.db.models import signals -from django.dispatch import receiver - -from dojo import tag_inheritance -from dojo.celery_dispatch import dojo_dispatch_task -from dojo.location.models import Location, LocationFindingReference, LocationProductReference -from dojo.models import Endpoint, Engagement, Finding, Product, Test -from dojo.product import helpers as async_product_funcs -from dojo.utils import get_system_setting - -logger = logging.getLogger(__name__) - - -@receiver(signals.m2m_changed, sender=Product.tags.through) -def product_tags_post_add_remove(sender, instance, action, **kwargs): - if action in {"post_add", "post_remove"}: - running_async_process = False - with contextlib.suppress(AttributeError): - running_async_process = instance.running_async_process - # Check if the async process is already running to avoid calling it a second time - if not running_async_process and inherit_product_tags(instance): - dojo_dispatch_task(async_product_funcs.propagate_tags_on_product, instance.id, countdown=5) - instance.running_async_process = True - - -@receiver(signals.m2m_changed, sender=Endpoint.tags.through) -@receiver(signals.m2m_changed, sender=Engagement.tags.through) -@receiver(signals.m2m_changed, sender=Test.tags.through) -@receiver(signals.m2m_changed, sender=Finding.tags.through) -@receiver(signals.m2m_changed, sender=Location.tags.through) -def make_inherited_tags_sticky(sender, instance, action, **kwargs): - """Make sure inherited tags are added back in if they are removed""" - # Inside a `tag_inheritance.batch()` block the caller takes responsibility - # for applying inheritance in bulk; per-row signal work would defeat the - # purpose. This replaces the old `signals.m2m_changed.disconnect(...)` - # pattern, which was process-global and unsafe under threaded workers. - if tag_inheritance.is_in_batch_mode(): - return - if action in {"post_add", "post_remove"}: - if inherit_product_tags(instance): - tag_list = [tag.name for tag in instance.tags.all()] - if propagate_inheritance(instance, tag_list=tag_list): - instance.inherit_tags(tag_list) - - -def inherit_instance_tags(instance): - """Usually nothing to do when saving a model, except for new models?""" - if inherit_product_tags(instance): - # TODO: Is this change OK to make? - # tag_list = instance._tags_tagulous.get_tag_list() - tag_list = instance.tags.get_tag_list() - if propagate_inheritance(instance, tag_list=tag_list): - instance.inherit_tags(tag_list) - - -def inherit_linked_instance_tags(instance: LocationFindingReference | LocationProductReference): - inherit_instance_tags(instance.location) - - -@receiver(signals.post_save, sender=Endpoint) -@receiver(signals.post_save, sender=Engagement) -@receiver(signals.post_save, sender=Test) -@receiver(signals.post_save, sender=Finding) -@receiver(signals.post_save, sender=Location) -def inherit_tags_on_instance(sender, instance, created, **kwargs): - # Only inherit on creation. The previous behavior fired on every save - # (create OR update), repeatedly re-applying inherited tags to children - # whose tag state had not changed. Sticky enforcement on user-driven - # tag edits is handled by `make_inherited_tags_sticky` (m2m_changed). - if not created: - return - inherit_instance_tags(instance) - - -@receiver(signals.post_save, sender=LocationFindingReference) -@receiver(signals.post_save, sender=LocationProductReference) -def inherit_tags_on_linked_instance(sender, instance, created, **kwargs): - inherit_linked_instance_tags(instance) - - -def propagate_inheritance(instance, tag_list=None): - # Get the expected product tags - if tag_list is None: - tag_list = [] - product_inherited_tags = [ - tag.name - for product in get_products_to_inherit_tags_from(instance) - for tag in product.tags.all() - ] - existing_inherited_tags = [tag.name for tag in instance.inherited_tags.all()] - # Check if product tags already matches inherited tags - product_tags_equals_inherited_tags = product_inherited_tags == existing_inherited_tags - # Check if product tags have already been inherited - tags_have_already_been_inherited = set(product_inherited_tags) <= set(tag_list) - return not (product_tags_equals_inherited_tags and tags_have_already_been_inherited) - - -def inherit_product_tags(instance) -> bool: - products = get_products(instance) - # Save a read in the db - if any(product.enable_product_tag_inheritance for product in products if product): - return True - - return get_system_setting("enable_product_tag_inheritance") - - -def get_products_to_inherit_tags_from(instance) -> list[Product]: - products = get_products(instance) - system_wide_inherit = get_system_setting("enable_product_tag_inheritance") - - return [ - product for product in products if product.enable_product_tag_inheritance or system_wide_inherit - ] - - -def get_products(instance) -> list[Product]: - if isinstance(instance, Product): - return [instance] - if isinstance(instance, Endpoint): - return [instance.product] - if isinstance(instance, Engagement): - return [instance.product] - if isinstance(instance, Test): - return [instance.engagement.product] - if isinstance(instance, Finding): - return [instance.test.engagement.product] - if isinstance(instance, Location): - return list(instance.all_related_products()) - return [] diff --git a/dojo/utils_cascade_delete.py b/dojo/utils_cascade_delete.py index 992e072411c..1fb54113aa7 100644 --- a/dojo/utils_cascade_delete.py +++ b/dojo/utils_cascade_delete.py @@ -190,7 +190,7 @@ def cascade_delete_related_objects(from_model, instance_pk_query, skip_relations # Clear M2M through tables before deleting (not discovered by _meta.related_objects). # Skip if the caller already handled M2M cleanup for this model (e.g. bulk_clear_finding_m2m). if from_model not in skip_m2m_for: - from dojo.tag_utils import bulk_remove_all_tags # noqa: PLC0415 circular import + from dojo.tags.utils import bulk_remove_all_tags # noqa: PLC0415 circular import bulk_remove_all_tags(from_model, instance_pk_query) diff --git a/unittests/test_tag_inheritance.py b/unittests/test_tag_inheritance.py index 1d3bb7a8b18..964094fa8ed 100644 --- a/unittests/test_tag_inheritance.py +++ b/unittests/test_tag_inheritance.py @@ -26,8 +26,13 @@ from dojo.location.models import Location, LocationProductReference from dojo.location.status import ProductLocationStatus from dojo.models import Endpoint, Engagement, Finding, Product, Product_Type, Test, Test_Type -from dojo.product.helpers import propagate_tags_on_product_sync -from dojo.tags_signals import get_products, inherit_product_tags, propagate_inheritance +from dojo.tags.inheritance import ( + _sync_inherited_tags, # noqa: PLC2701 -- private API tested directly + get_products, + is_tag_inheritance_enabled, + propagate_tags_on_product_sync, +) +from dojo.tags.signals import auto_inherit_product_tags from dojo.tools.locations import LocationData from unittests.dojo_test_case import ( DojoAPITestCase, @@ -118,48 +123,47 @@ def _make_product(self, *, per_product_flag): p.enable_product_tag_inheritance = per_product_flag return p - @patch("dojo.tags_signals.get_system_setting", return_value=True) - @patch("dojo.tags_signals.get_products") + @patch("dojo.tags.inheritance.get_system_setting", return_value=True) + @patch("dojo.tags.inheritance.get_products") def test_system_setting_on_returns_true(self, mock_get_products, mock_setting): mock_get_products.return_value = [self._make_product(per_product_flag=False)] - self.assertTrue(inherit_product_tags(MagicMock())) + self.assertTrue(is_tag_inheritance_enabled(MagicMock())) - @patch("dojo.tags_signals.get_system_setting", return_value=False) - @patch("dojo.tags_signals.get_products") + @patch("dojo.tags.inheritance.get_system_setting", return_value=False) + @patch("dojo.tags.inheritance.get_products") def test_per_product_flag_on_system_off_returns_true(self, mock_get_products, mock_setting): mock_get_products.return_value = [self._make_product(per_product_flag=True)] - self.assertTrue(inherit_product_tags(MagicMock())) + self.assertTrue(is_tag_inheritance_enabled(MagicMock())) - @patch("dojo.tags_signals.get_system_setting", return_value=False) - @patch("dojo.tags_signals.get_products") + @patch("dojo.tags.inheritance.get_system_setting", return_value=False) + @patch("dojo.tags.inheritance.get_products") def test_both_off_returns_false(self, mock_get_products, mock_setting): mock_get_products.return_value = [self._make_product(per_product_flag=False)] - self.assertFalse(inherit_product_tags(MagicMock())) + self.assertFalse(is_tag_inheritance_enabled(MagicMock())) - @patch("dojo.tags_signals.get_system_setting", return_value=False) - @patch("dojo.tags_signals.get_products") + @patch("dojo.tags.inheritance.get_system_setting", return_value=False) + @patch("dojo.tags.inheritance.get_products") def test_no_products_returns_false(self, mock_get_products, mock_setting): mock_get_products.return_value = [] - self.assertFalse(inherit_product_tags(MagicMock())) + self.assertFalse(is_tag_inheritance_enabled(MagicMock())) - @patch("dojo.tags_signals.get_system_setting", return_value=False) - @patch("dojo.tags_signals.get_products") + @patch("dojo.tags.inheritance.get_system_setting", return_value=False) + @patch("dojo.tags.inheritance.get_products") def test_none_entries_in_product_list_are_skipped(self, mock_get_products, mock_setting): mock_get_products.return_value = [None, self._make_product(per_product_flag=False)] - self.assertFalse(inherit_product_tags(MagicMock())) + self.assertFalse(is_tag_inheritance_enabled(MagicMock())) -class TestPropagateInheritanceEarlyExit(unittest.TestCase): +class TestManageInheritedTagsDiff(unittest.TestCase): """ - Unit tests for propagate_inheritance() — the optimization guard that skips redundant DB writes. - - Returns False ("nothing to do") only when BOTH conditions hold: - 1. product tags match what is stored in instance.inherited_tags (already recorded) - 2. those tags are already present in the instance's full tag_list (already applied) - If either condition is false, returns True and the caller proceeds to write tags. - get_products_to_inherit_tags_from and instance.inherited_tags.all() are mocked - to isolate the boolean logic from DB access. + Unit tests for _sync_inherited_tags() — the diff primitive. + + Verifies that the function: + - Adds inherited tags that aren't yet recorded. + - Removes inherited tags that no longer belong. + - Re-adds inherited tags missing from instance.tags (sticky enforcement). + - Does no work when target matches current state. """ def _tag(self, name): @@ -167,43 +171,62 @@ def _tag(self, name): t.name = name return t - def _make_instance(self, inherited_names): + def _make_instance(self, inherited_names, tag_names): instance = MagicMock() + # Skip the FakeTagRelatedManager branch so we exercise the diff path. + instance.inherited_tags.__class__ = MagicMock + instance.tags.__class__ = MagicMock instance.inherited_tags.all.return_value = [self._tag(n) for n in inherited_names] + instance.tags.all.return_value = [self._tag(n) for n in tag_names] return instance - def _make_product(self, tag_names): - product = MagicMock() - product.tags.all.return_value = [self._tag(n) for n in tag_names] - return product - - @patch("dojo.tags_signals.get_products_to_inherit_tags_from") - def test_already_in_sync_returns_false(self, mock_get): - """inherited_tags matches product tags and all present in tag_list → skip.""" - instance = self._make_instance(["alpha", "beta"]) - mock_get.return_value = [self._make_product(["alpha", "beta"])] - self.assertFalse(propagate_inheritance(instance, tag_list=["alpha", "beta"])) - - @patch("dojo.tags_signals.get_products_to_inherit_tags_from") - def test_product_tags_changed_returns_true(self, mock_get): - """Stored inherited_tags differ from current product tags → must propagate.""" - instance = self._make_instance(["old"]) - mock_get.return_value = [self._make_product(["new"])] - self.assertTrue(propagate_inheritance(instance, tag_list=["old", "new"])) - - @patch("dojo.tags_signals.get_products_to_inherit_tags_from") - def test_tags_not_yet_applied_to_instance_returns_true(self, mock_get): - """inherited_tags already correct but not yet reflected in tag_list → must propagate.""" - instance = self._make_instance(["alpha"]) - mock_get.return_value = [self._make_product(["alpha"])] - self.assertTrue(propagate_inheritance(instance, tag_list=[])) - - @patch("dojo.tags_signals.get_products_to_inherit_tags_from") - def test_no_products_no_inherited_tags_returns_false(self, mock_get): - """No products, no inherited tags, empty tag_list → already in sync, skip.""" - instance = self._make_instance([]) + def test_already_in_sync_no_writes(self): + instance = self._make_instance(["alpha", "beta"], tag_names=["alpha", "beta"]) + _sync_inherited_tags(instance, ["alpha", "beta"]) + instance.inherited_tags.add.assert_not_called() + instance.inherited_tags.remove.assert_not_called() + instance.tags.add.assert_not_called() + instance.tags.remove.assert_not_called() + + def test_target_adds_new_inherited(self): + instance = self._make_instance(["old"], tag_names=["old", "user"]) + _sync_inherited_tags(instance, ["old", "new"]) + instance.inherited_tags.add.assert_called_once_with("new") + instance.tags.add.assert_called_once_with("new") + instance.inherited_tags.remove.assert_not_called() + instance.tags.remove.assert_not_called() + + def test_target_removes_dropped_inherited(self): + instance = self._make_instance(["alpha", "beta"], tag_names=["alpha", "beta", "user"]) + _sync_inherited_tags(instance, ["alpha"]) + instance.inherited_tags.remove.assert_called_once_with("beta") + instance.tags.remove.assert_called_once_with("beta") + instance.inherited_tags.add.assert_not_called() + instance.tags.add.assert_not_called() + + def test_sticky_readds_missing_inherited(self): + # inherited_tags already records "alpha", target is "alpha", but user + # stripped it from tags via m2m_changed. Sticky enforcement re-adds it. + instance = self._make_instance(["alpha"], tag_names=["user"]) + _sync_inherited_tags(instance, ["alpha"]) + instance.inherited_tags.add.assert_not_called() + instance.inherited_tags.remove.assert_not_called() + instance.tags.remove.assert_not_called() + instance.tags.add.assert_called_once_with("alpha") + + +class TestInheritInstanceTagsEarlyExit(unittest.TestCase): + + """No-products case: auto_inherit_product_tags must short-circuit before touching the instance.""" + + @patch("dojo.tags.signals.get_products_to_inherit_tags_from") + def test_no_products_skips_write(self, mock_get): + instance = MagicMock() mock_get.return_value = [] - self.assertFalse(propagate_inheritance(instance, tag_list=[])) + auto_inherit_product_tags(instance) + instance.inherit_tags.assert_not_called() + instance.inherited_tags.add.assert_not_called() + instance.tags.add.assert_not_called() # --------------------------------------------------------------------------- @@ -417,7 +440,7 @@ class TestLocationMultipleProductInheritance(DojoTestCase): linked to many products via LocationProductReference. These tests verify that all_related_products() is used correctly and tags are merged from every linked product, and that the per-product flag filters correctly when the system setting is off. - inherit_instance_tags() is called directly rather than relying on signal chaining. + auto_inherit_product_tags() is called directly rather than relying on signal chaining. Skipped when V3_FEATURE_LOCATIONS is disabled. """ @@ -427,7 +450,7 @@ def setUp(self): self.system_settings(enable_product_tag_inheritance=True) def test_location_inherits_from_multiple_products(self): - from dojo.tags_signals import inherit_instance_tags # noqa: PLC0415 + from dojo.tags.signals import auto_inherit_product_tags # noqa: PLC0415 p1 = self.create_product("Product A", tags=["p1-tag"]) p2 = self.create_product("Product B", tags=["p2-tag"]) @@ -441,7 +464,7 @@ def test_location_inherits_from_multiple_products(self): location=location, product=p2, status=ProductLocationStatus.Active, ) - inherit_instance_tags(location) + auto_inherit_product_tags(location) location.refresh_from_db() tag_names = sorted(t.name for t in location.tags.all()) @@ -449,7 +472,7 @@ def test_location_inherits_from_multiple_products(self): self.assertIn("p2-tag", tag_names) def test_location_inherits_only_from_flagged_product_when_system_off(self): - from dojo.tags_signals import inherit_instance_tags # noqa: PLC0415 + from dojo.tags.signals import auto_inherit_product_tags # noqa: PLC0415 self.system_settings(enable_product_tag_inheritance=False) @@ -467,7 +490,7 @@ def test_location_inherits_only_from_flagged_product_when_system_off(self): location=location, product=p_no, status=ProductLocationStatus.Active, ) - inherit_instance_tags(location) + auto_inherit_product_tags(location) location.refresh_from_db() tag_names = sorted(t.name for t in location.tags.all()) diff --git a/unittests/test_tag_inheritance_perf.py b/unittests/test_tag_inheritance_perf.py index a0cfa3958ab..9cbf221c038 100644 --- a/unittests/test_tag_inheritance_perf.py +++ b/unittests/test_tag_inheritance_perf.py @@ -23,7 +23,7 @@ from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.models import Endpoint, Engagement, Finding, Product, Product_Type, Test, Test_Type -from dojo.product.helpers import propagate_tags_on_product_sync +from dojo.tags.inheritance import propagate_tags_on_product_sync from unittests.dojo_test_case import ( DojoAPITestCase, DojoTestCase, @@ -214,6 +214,24 @@ def _do_finding_add_user_tag(self, name: str, expected: int) -> None: self.assertIn("user-only", finding_tag_names) self.assertIn("inherited", finding_tag_names) # still sticky + def _do_propagate_sync_only(self, name: str, expected: int, *, with_endpoints: bool, with_locations: bool) -> None: + """ + Measure `propagate_tags_on_product_sync(product)` in isolation — no tag change. + + Captures the raw sweep cost for a product with a realistic mix of children: + N findings + (V2) N endpoints or (V3) N locations. Should be roughly idempotent + (no add/remove to apply) so the number reflects diff-detection overhead. + """ + product = _make_product_with_findings(name, n_findings=100, tags=["t1", "t2"]) + if with_endpoints: + _make_endpoints(product, n=100) + if with_locations: + _make_locations(product, n=100) + with self.assertNumQueries(expected): + propagate_tags_on_product_sync(product) + finding = Finding.objects.filter(test__engagement__product=product).first() + self.assertEqual({"t1", "t2"}, {t.name for t in finding.tags.all()}) + def _do_finding_remove_inherited(self, name: str, expected: int) -> None: product = _make_product_with_findings(name, n_findings=1, tags=["inherited"]) finding = Finding.objects.filter(test__engagement__product=product).first() @@ -282,6 +300,29 @@ def test_baseline_finding_remove_inherited_tag_sticky_re_adds_v2(self): def test_baseline_finding_remove_inherited_tag_sticky_re_adds_v3(self): self._do_finding_remove_inherited("perf-sticky-rm-v3", self.EXPECTED_FINDING_REMOVE_INHERITED_V3) + # ------------------------------------------------------------------ + # propagate_tags_on_product_sync direct invocation (no tag change). + # Measures the raw sweep cost over a product's children. + # ------------------------------------------------------------------ + + @override_settings(V3_FEATURE_LOCATIONS=False) + def test_baseline_propagate_tags_on_product_sync_v2(self): + self._do_propagate_sync_only( + "perf-sync-v2", + self.EXPECTED_PROPAGATE_SYNC_V2, + with_endpoints=True, + with_locations=False, + ) + + @override_settings(V3_FEATURE_LOCATIONS=True) + def test_baseline_propagate_tags_on_product_sync_v3(self): + self._do_propagate_sync_only( + "perf-sync-v3", + self.EXPECTED_PROPAGATE_SYNC_V3, + with_endpoints=False, + with_locations=True, + ) + # ------------------------------------------------------------------ # V2: propagation to Endpoints (skipped under V3_FEATURE_LOCATIONS) # ------------------------------------------------------------------ @@ -364,28 +405,34 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self): EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 53 EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 53 - EXPECTED_CREATE_ONE_FINDING_V2 = 64 - EXPECTED_CREATE_ONE_FINDING_V3 = 64 - EXPECTED_CREATE_100_FINDINGS_V2 = 4024 - EXPECTED_CREATE_100_FINDINGS_V3 = 4024 + EXPECTED_CREATE_ONE_FINDING_V2 = 55 + EXPECTED_CREATE_ONE_FINDING_V3 = 55 + EXPECTED_CREATE_100_FINDINGS_V2 = 3124 + EXPECTED_CREATE_100_FINDINGS_V3 = 3124 EXPECTED_FINDING_ADD_USER_TAG_V2 = 17 EXPECTED_FINDING_ADD_USER_TAG_V3 = 17 - EXPECTED_FINDING_REMOVE_INHERITED_V2 = 44 - EXPECTED_FINDING_REMOVE_INHERITED_V3 = 44 + EXPECTED_FINDING_REMOVE_INHERITED_V2 = 18 + EXPECTED_FINDING_REMOVE_INHERITED_V3 = 18 # V2 endpoint paths. Pre-Phase-A: 3958 add, 3740 remove. EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 91 EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 53 # V3 location paths. Pre-Phase-A: 4532 add, 4307 remove. - EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 316 - EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 266 + EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 125 + EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 75 + + # propagate_tags_on_product_sync direct invocation (no tag change). + # Product with 100 findings + 100 endpoints (V2) or + 100 locations (V3). + EXPECTED_PROPAGATE_SYNC_V2 = 9 + EXPECTED_PROPAGATE_SYNC_V3 = 18 @override_settings( CELERY_TASK_ALWAYS_EAGER=True, CELERY_TASK_EAGER_PROPAGATES=True, + SECURE_SSL_REDIRECT=False, ) @versioned_fixtures class TagInheritanceImportPerfBaselines(DojoAPITestCase): @@ -394,10 +441,10 @@ class TagInheritanceImportPerfBaselines(DojoAPITestCase): Pinned query-count baselines for the importer hot path. Real production tag-inheritance cost lives in scan import / reimport: the - importer creates findings + endpoints/locations, then `_manage_inherited_tags` + importer creates findings + endpoints/locations, then `_sync_inherited_tags` runs per row. Phase A (bulk product-side propagation + post_save gated on create) doesn't touch this loop because the importer's hot path is - creation-driven. Phase B's `tag_inheritance.batch()` context manager + creation-driven. Phase B's `tag_inheritance.suppress_tag_inheritance()` context manager targets it. Two scenarios: @@ -421,6 +468,11 @@ def setUp(self): self.product = self.create_product("Tag Perf Import Product", tags=["inherit", "these"]) self.engagement = self.create_engagement("Tag Perf Import Engagement", self.product) self.scan_path = get_unit_tests_scans_path("zap") / "dvwa_baseline_dojo.xml" + # Subset of the full report (10 findings vs 19) used to exercise the + # reimport-with-new-findings code path: initial import uses the subset, + # then reimport uses the full report so 9 findings get created during + # reimport while 10 match existing ones. + self.scan_path_subset = get_unit_tests_scans_path("zap") / "dvwa_baseline_dojo_subset.xml" @override_settings(V3_FEATURE_LOCATIONS=False) def test_baseline_zap_scan_import_v2(self): @@ -429,7 +481,7 @@ def test_baseline_zap_scan_import_v2(self): Captures total query count for: scan parse + finding creation + endpoint attachment + per-row inherit_tags signal chain. Production hot path. - Phase A leaves this number ~unchanged; Phase B's `tag_inheritance.batch()` + Phase A leaves this number ~unchanged; Phase B's `tag_inheritance.suppress_tag_inheritance()` targets it. """ with self.assertNumQueries(self.EXPECTED_ZAP_IMPORT_V2): @@ -487,6 +539,42 @@ def test_baseline_zap_scan_reimport_no_change_v3(self): finding = Finding.objects.filter(test_id=test_id).first() self.assertEqual({"inherit", "these"}, {t.name for t in finding.tags.all()}) + @override_settings(V3_FEATURE_LOCATIONS=False) + def test_baseline_zap_scan_reimport_with_new_findings_v2(self): + """ + V2: import 10-finding subset, then reimport 19-finding full report. + + 9 findings are NEW (must run inheritance), 10 are matched (skip). + Exercises the realistic "scheduled rescan with drift" path where a + reimport actually creates findings. + """ + response = self.import_scan_with_params( + self.scan_path_subset, + engagement=self.engagement.id, + ) + test_id = response["test"] + + with self.assertNumQueries(self.EXPECTED_ZAP_REIMPORT_WITH_NEW_V2): + self.reimport_scan_with_params(test_id, str(self.scan_path)) + + finding = Finding.objects.filter(test_id=test_id).first() + self.assertEqual({"inherit", "these"}, {t.name for t in finding.tags.all()}) + + @override_settings(V3_FEATURE_LOCATIONS=True) + def test_baseline_zap_scan_reimport_with_new_findings_v3(self): + """V3: same as V2 but Location-backed.""" + response = self.import_scan_with_params( + self.scan_path_subset, + engagement=self.engagement.id, + ) + test_id = response["test"] + + with self.assertNumQueries(self.EXPECTED_ZAP_REIMPORT_WITH_NEW_V3): + self.reimport_scan_with_params(test_id, str(self.scan_path)) + + finding = Finding.objects.filter(test_id=test_id).first() + self.assertEqual({"inherit", "these"}, {t.name for t in finding.tags.all()}) + # Pinned baselines per mode. Each test forces its own V3_FEATURE_LOCATIONS # via @override_settings so all four import paths run in a single suite # invocation regardless of the ambient `DD_V3_FEATURE_LOCATIONS` env var. @@ -497,11 +585,9 @@ def test_baseline_zap_scan_reimport_no_change_v3(self): # import path because the previous process-global signal-disconnect was # narrower in scope (Location.tags.through only). Net-positive trade for # eliminating the threading bug; full Phase B reductions land in Stage 2. - # Track B legacy auth + Alert FK-validation skip: notifications dispatch - # to the full active-user set (legacy doesn't filter by RBAC role) but - # the per-Alert ForeignKey.validate EXISTS probe is gone, netting -7 - # against the pre-Track-B numbers. - EXPECTED_ZAP_IMPORT_V2 = 1378 - EXPECTED_ZAP_IMPORT_V3 = 1256 + EXPECTED_ZAP_IMPORT_V2 = 420 + EXPECTED_ZAP_IMPORT_V3 = 444 EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 69 - EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 87 + EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 81 + EXPECTED_ZAP_REIMPORT_WITH_NEW_V2 = 169 + EXPECTED_ZAP_REIMPORT_WITH_NEW_V3 = 198 diff --git a/unittests/test_tag_utils_bulk.py b/unittests/test_tag_utils_bulk.py index 3a815041fb4..5e22515dd50 100644 --- a/unittests/test_tag_utils_bulk.py +++ b/unittests/test_tag_utils_bulk.py @@ -5,7 +5,7 @@ from dojo.location.models import Location from dojo.models import Endpoint, Engagement, Finding, Product, Product_Type, Test, Test_Type -from dojo.tag_utils import bulk_add_tag_mapping, bulk_add_tags_to_instances, bulk_apply_parser_tags +from dojo.tags.utils import bulk_add_tag_mapping, bulk_add_tags_to_instances, bulk_apply_parser_tags from dojo.url.models import URL from unittests.dojo_test_case import DojoAPITestCase, versioned_fixtures