diff --git a/comfy_cli/tracking.py b/comfy_cli/tracking.py index a42d3b9e..8c631515 100644 --- a/comfy_cli/tracking.py +++ b/comfy_cli/tracking.py @@ -2,6 +2,7 @@ import atexit import functools +import json import logging as logginglib import os import sys @@ -38,11 +39,47 @@ # historical streams stay continuous. POSTHOG_EVENT_PREFIX = "cli:" -# Kwargs whose values must never reach tracking system. -# The key is kept (with a redacted marker) so we can still see whether the option was supplied. -# `token` is the registry publisher PAT; `changelog` is bulky free text (up to a whole -# GitHub release body) with no analytics value beyond its presence. -SENSITIVE_TRACKING_KEYS = frozenset({"api_key", "token", "changelog"}) +# Sanitize command kwargs before sending them as telemetry: _is_sensitive() +# masks credential-bearing names, _is_trackable() drops ctx/private/unserializable +# values, and _scrub_value() strips query strings off URL values. + +_SENSITIVE_SUFFIXES = ("_token", "_api_key", "_secret", "_password") +# `token` is the publish PAT; `changelog` is bulky free text with no analytics +# value beyond its presence. Sensitive values become "" (the key is +# kept so we can still tell the option was supplied). +_SENSITIVE_EXACT = frozenset({"api_key", "token", "password", "secret", "changelog"}) + + +def _is_sensitive(name: str) -> bool: + """True if *name* looks like a credential. Case-insensitive; matches the + snake_case suffixes only (Typer kwargs are always snake_case).""" + lower = name.lower() + return lower in _SENSITIVE_EXACT or lower.endswith(_SENSITIVE_SUFFIXES) + + +def _is_trackable(name: str, value: object) -> bool: + """True if the (name, value) kwarg is safe to send. Drops ctx/context, + underscore-prefixed names, and values json can't serialize -- posthog-python + coerces unserializable values and ships them (e.g. a Click Context) rather + than raising the way Mixpanel does, so we must reject them ourselves.""" + if name in ("ctx", "context"): + return False + if name.startswith("_"): + return False + try: + json.dumps(value) + except (TypeError, ValueError, OverflowError, RecursionError): + return False + return True + + +def _scrub_value(value: object) -> object: + """Strip the query string and fragment from URL values; CivitAI download + links carry the token as ?token=. Only top-level http(s) strings are touched.""" + if isinstance(value, str) and value.startswith(("http://", "https://")): + return value.partition("?")[0].partition("#")[0] + return value + # Generate a unique tracing ID per command. config_manager = ConfigManager() @@ -183,11 +220,13 @@ def track_event(event_name: str, properties: Any = None, *, mixpanel_name: str | def filter_command_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: - """Drop ``ctx``/``context`` and redact ``SENSITIVE_TRACKING_KEYS`` values.""" + """Drop untrackable kwargs (see ``_is_trackable``), redact sensitive values + (see ``_is_sensitive``), and strip credentials embedded in URL values + (see ``_scrub_value``).""" return { - k: ("" if v is not None else None) if k in SENSITIVE_TRACKING_KEYS else v + k: ("" if v is not None else None) if _is_sensitive(k) else _scrub_value(v) for k, v in kwargs.items() - if k != "ctx" and k != "context" + if _is_trackable(k, v) } diff --git a/tests/comfy_cli/test_tracking.py b/tests/comfy_cli/test_tracking.py index 557da54e..f642ee32 100644 --- a/tests/comfy_cli/test_tracking.py +++ b/tests/comfy_cli/test_tracking.py @@ -138,6 +138,213 @@ def publish(token=None, changelog=None, changelog_file=None): assert "pat-supersecret" not in str(properties) assert "fix things" not in str(properties) + def test_set_civitai_api_token_is_redacted(self, tracking_module): + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command("model") + def download(url, set_civitai_api_token=None, set_hf_api_token=None): + return None + + download(url="https://example.com", set_civitai_api_token="civ-real-token") + + tracking_module.provider.track.assert_called_once() + _, _, properties = _last_track_call(tracking_module.provider) + assert properties["set_civitai_api_token"] == "" + assert "civ-real-token" not in str(properties) + + def test_set_hf_api_token_is_redacted(self, tracking_module): + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command("model") + def download(url, set_civitai_api_token=None, set_hf_api_token=None): + return None + + download(url="https://example.com", set_hf_api_token="hf_real-token") + + tracking_module.provider.track.assert_called_once() + _, _, properties = _last_track_call(tracking_module.provider) + assert properties["set_hf_api_token"] == "" + assert "hf_real-token" not in str(properties) + + def test_bare_token_kwarg_is_redacted(self, tracking_module): + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command() + def some_cmd(workflow, token=None): + return None + + some_cmd(workflow="wf.json", token="my-secret-token") + + tracking_module.provider.track.assert_called_once() + _, _, properties = _last_track_call(tracking_module.provider) + assert properties["token"] == "" + assert "my-secret-token" not in str(properties) + + def test_underscore_ctx_is_excluded(self, tracking_module): + import click + + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command("model") + def download(_ctx, url, set_civitai_api_token=None): + return None + + ctx = click.Context(click.Command("download")) + download(_ctx=ctx, url="https://example.com") + + tracking_module.provider.track.assert_called_once() + _, _, properties = _last_track_call(tracking_module.provider) + assert "_ctx" not in properties + assert properties["url"] == "https://example.com" + + def test_non_serializable_value_is_excluded(self, tracking_module): + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command() + def some_cmd(workflow, callback=None): + return None + + some_cmd(workflow="wf.json", callback=lambda x: x) + + tracking_module.provider.track.assert_called_once() + _, _, properties = _last_track_call(tracking_module.provider) + assert "callback" not in properties + assert properties["workflow"] == "wf.json" + + def test_url_query_string_is_scrubbed(self, tracking_module): + # CivitAI download links carry the API key as `?token=`. + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command("model") + def download(url=None, relative_path=None): + return None + + download(url="https://civitai.com/api/download/models/12345?token=civ-url-secret") + + _, _, properties = _last_track_call(tracking_module.provider) + assert properties["url"] == "https://civitai.com/api/download/models/12345" + assert "civ-url-secret" not in str(properties) + + def test_url_without_query_is_unchanged(self, tracking_module): + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + @tracking_module.track_command("model") + def download(url=None): + return None + + download(url="https://huggingface.co/org/repo/resolve/main/m.safetensors") + + _, _, properties = _last_track_call(tracking_module.provider) + assert properties["url"] == "https://huggingface.co/org/repo/resolve/main/m.safetensors" + + +class TestSensitiveNameMatcher: + @pytest.mark.parametrize( + "name", + [ + "api_key", + "token", + "password", + "secret", + "changelog", + "set_civitai_api_token", + "set_hf_api_token", + "access_token", + "client_secret", + "admin_password", + "API_KEY", + "Set_HF_Api_Token", + ], + ) + def test_matches(self, name): + import comfy_cli.tracking as tm + + assert tm._is_sensitive(name) is True + + @pytest.mark.parametrize("name", ["url", "workflow", "changelog_file", "max_tokens", "tokenizer", "relative_path"]) + def test_does_not_match(self, name): + import comfy_cli.tracking as tm + + assert tm._is_sensitive(name) is False + + +class TestCliParamNameDriftGate: + """BE-992 happened because credential flags were added after the redaction + set was written. Walk the real CLI tree so the next one cannot land + unredacted.""" + + # Params whose names merely contain a credential-ish substring but are + # reviewed as safe to track verbatim go here. + ALLOWLIST = frozenset() + + def test_credentialish_cli_params_are_redacted(self): + import click + from typer.main import get_command + + import comfy_cli.tracking as tm + from comfy_cli.cmdline import app + + suspicious = ("token", "secret", "password", "api_key", "apikey", "credential") + + def walk(cmd, path): + if isinstance(cmd, click.Group): + for name, sub in cmd.commands.items(): + yield from walk(sub, [*path, name]) + return + for param in cmd.params: + if param.name: + yield " ".join(path), param.name + + offenders = sorted( + { + (path, pname) + for path, pname in walk(get_command(app), ["comfy"]) + if any(s in pname.lower() for s in suspicious) + and pname not in self.ALLOWLIST + and not tm._is_sensitive(pname) + } + ) + assert offenders == [], f"credential-looking CLI params not redacted by _is_sensitive: {offenders}" + + +class TestTrackCommandRealTyperWiring: + def test_model_download_kwargs_are_filtered_and_redacted(self, tracking_module): + # `model download` is the command whose `_ctx` + credential kwarg + # combination motivated BE-992; invoke it through Typer for real so + # the Click context actually lands in the tracked kwargs. + from typer.testing import CliRunner + + import comfy_cli.command.models.models as models + + tracking_module.config_manager.set(constants.CONFIG_KEY_ENABLE_TRACKING, "True") + + with ( + patch.object(models, "config_manager", MagicMock()), + patch.object(models, "check_civitai_url", side_effect=RuntimeError("halt after tracking")), + ): + result = CliRunner().invoke( + models.app, + [ + "download", + "--url", + "https://example.com/model.safetensors?token=url-secret", + "--set-civitai-api-token", + "civ-secret", + ], + ) + + # The command body aborted at the patched helper, after tracking fired. + assert isinstance(result.exception, RuntimeError) + + tracking_module.provider.track.assert_called_once() + event_name, _, properties = _last_track_call(tracking_module.provider) + assert event_name == "model:download" + assert "_ctx" not in properties + assert properties["set_civitai_api_token"] == "" + assert "civ-secret" not in str(properties) + assert properties["url"] == "https://example.com/model.safetensors" + assert "url-secret" not in str(properties) + class TestInitTrackingRoundTrip: """End-to-end: init_tracking() writes the string "False"/"True", and track_event honors it. diff --git a/tests/comfy_cli/test_tracking_providers.py b/tests/comfy_cli/test_tracking_providers.py index f3a9846e..591cc754 100644 --- a/tests/comfy_cli/test_tracking_providers.py +++ b/tests/comfy_cli/test_tracking_providers.py @@ -241,6 +241,35 @@ def fake_cmd(workflow, api_key=None): assert "sk-supersecret" not in str(mp_kwargs["properties"]) assert "sk-supersecret" not in str(ph_kwargs["properties"]) + def test_download_credentials_never_reach_either_provider(self, tracking_with_two_providers): + # BE-992 kwarg shape: before the suffix matcher and the underscore + # filter, the un-redacted token still shipped to PostHog because its + # client coerces the unserializable _ctx instead of raising the way + # Mixpanel's does. + import click + + tracking_mod, mp_provider, ph_provider = tracking_with_two_providers + + @tracking_mod.track_command("model") + def download(_ctx=None, url=None, set_civitai_api_token=None, set_hf_api_token=None): + return None + + download( + _ctx=click.Context(click.Command("download")), + url="https://example.com/model.safetensors", + set_civitai_api_token="civ-secret", + set_hf_api_token="hf-secret", + ) + + _, mp_kwargs = mp_provider.client.track.call_args + ph_kwargs = _posthog_capture_kwargs(ph_provider.client) + for properties in (mp_kwargs["properties"], ph_kwargs["properties"]): + assert "_ctx" not in properties + assert properties["set_civitai_api_token"] == "" + assert properties["set_hf_api_token"] == "" + assert "civ-secret" not in str(properties) + assert "hf-secret" not in str(properties) + class TestAtexitFlush: def test_flush_all_providers_calls_each_flush(self):