From ebe1f69257446dd811e62d2631ffa95d4606c2d8 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 18 May 2026 06:13:35 -0400 Subject: [PATCH 01/21] PreTrendsPower PR-B Step 2: NIS test form + result-class extension + helper API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Roth (2022) Section II.A-B NIS box probability as the new primary `pretest_form='nis'` default. Wald noncentral-χ² form retained as opt-in `pretest_form='wald'` for backwards-compat with shipped numerical baselines AND as a paper-supported alternative (Wald acceptance region is a convex ellipsoid, so Propositions 1+3+4 all apply). Changes: - `PreTrendsPower.__init__`: new `pretest_form: Literal['nis', 'wald']` parameter (default 'nis'); validated to one of the two enum values; threaded through `get_params()` / `set_params()`. - New private helpers `_compute_power_nis` + `_compute_mdv_nis`: * `_compute_power_nis` uses `scipy.stats.multivariate_normal.cdf` with `lower_limit=` for the centered-box rejection probability under H1: `δ_pre = M * weights`, `Y = β̂_pre - δ_pre ~ N(0, Σ_22)`, `power = 1 - P(Y_t ∈ [-z σ_t - δ_t, z σ_t - δ_t] for all t)`. Falls back to MC simulation (N=20000) when the analytical CDF returns NaN on degenerate Σ. * `_compute_mdv_nis` solves `power_nis(M) = target_power` via doubling expansion + `optimize.brentq` bisection; non-convergence cap at M_high=1000 returns `np.inf` (mirrors Wald path's existing 1000-cap). - Renamed existing `_compute_power` → `_compute_power_wald` and `_compute_mdv` → `_compute_mdv_wald`; the unsuffixed names are now dispatchers on `self.pretest_form`. Wald math is byte-identical. - `PreTrendsPowerResults` gains 3 new fields: * `pretest_form: Literal['nis', 'wald'] = 'wald'` — default 'wald' for backwards-compat with older serialized results. * `nis_box_probability: float = np.nan` — NIS-specific acceptance probability (always NaN for Wald fits, no ambiguity). * `violation_weights: Optional[np.ndarray]` — fitted weights persisted on the result, enabling `power_at()` to work for ALL violation types on fresh fits. - `fit()` populates all three new fields and dispatches. - `power_curve()` inherits dispatch through `_compute_power`. - `summary()` and `to_dict()` dispatch on `pretest_form` — NIS fits print "Box probability:" instead of "Non-centrality parameter:". - `PreTrendsPowerResults.power_at()` refactored: uses `self.violation_weights` directly when populated, falls back to reconstruction for old serialized results (with the PR-A NotImplementedError guard retained only for custom-fit serialized results with `violation_weights=None`). - `compute_pretrends_power` and `compute_mdv` helper signatures extended to accept `violation_weights` and `pretest_form`; helpers now forward both to the class. Closes the helper/class API gap from PR-A R18. Smoke-tested with K=2 and K=3 panels: - NIS power at M=0 with K=3 ≈ 0.138 (matches 1 - (1-α)^K = 0.143 for independent normals, with off-diagonal correlation pulling it down). - Wald power at M=0 with K=3 = 0.05 (exact size under H0). - NIS MDV(80%, K=3) = 0.59, Wald MDV(80%, K=3) = 0.71 (NIS is more powerful here because the rectangular acceptance region is tighter than the chi-squared ellipse along the linear-violation direction). Pre-existing pyright type-stub warnings on `optimize.brentq` and `stats.multivariate_normal.cdf` are not touched. Plan ref: /Users/igerber/.claude/plans/stateless-prancing-iverson.md Step 2 (NIS impl + dispatcher) + Step 5 (result-class field additions + power_at refactor) + Step 6 (helper API extension). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/pretrends.py | 444 ++++++++++++++++++++++++++++++++--------- 1 file changed, 351 insertions(+), 93 deletions(-) diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index 8b32c471..a091c671 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -61,17 +61,34 @@ class PreTrendsPowerResults: n_pre_periods : int Number of pre-treatment periods in the event study. test_statistic : float - Expected test statistic under the specified violation. + Expected test statistic under the specified violation (Wald only; + NaN for NIS fits). critical_value : float Critical value for the pre-trends test. noncentrality : float - Non-centrality parameter under the alternative hypothesis. + Non-centrality parameter under the alternative hypothesis (Wald only; + NaN for NIS fits). pre_period_effects : np.ndarray Estimated pre-period effects from the event study. pre_period_ses : np.ndarray Standard errors of pre-period effects. vcov : np.ndarray Variance-covariance matrix of pre-period effects. + pretest_form : str + Pretest acceptance-region form used: ``'nis'`` (no-individually- + significant box probability — Roth 2022 Section II.A-B, default for new + fits) or ``'wald'`` (noncentral-chi-squared on the quadratic form + ``delta' Sigma_22^{-1} delta`` — paper-supported alternative, retained + for backwards compatibility with shipped numerical baselines). + nis_box_probability : float + Acceptance probability ``P(beta_hat_pre in B_NIS(Sigma))`` under the + alternative ``M * weights``. NIS-only; NaN for Wald fits. + violation_weights : np.ndarray, optional + The normalized violation-direction vector used at fit time. Populated + for all violation types on fresh fits. Old serialized results may have + ``None`` here; ``power_at()`` falls back to reconstruction in that + case (with the PR-A NotImplementedError guard retained only for + ``violation_type='custom'`` with ``violation_weights=None``). """ power: float @@ -88,6 +105,9 @@ class PreTrendsPowerResults: pre_period_ses: np.ndarray = field(repr=False) vcov: np.ndarray = field(repr=False) original_results: Optional[Any] = field(default=None, repr=False) + pretest_form: Literal["nis", "wald"] = "wald" + nis_box_probability: float = np.nan + violation_weights: Optional[np.ndarray] = field(default=None, repr=False) def __repr__(self) -> str: return ( @@ -132,6 +152,7 @@ def summary(self) -> str: f"{'Significance level (alpha):':<35} {self.alpha:.3f}", f"{'Target power:':<35} {self.target_power:.1%}", f"{'Violation type:':<35} {self.violation_type}", + f"{'Pretest form:':<35} {self.pretest_form}", "", "-" * 70, "Power Analysis".center(70), @@ -140,14 +161,23 @@ def summary(self) -> str: f"{'Power to detect this violation:':<35} {self.power:.1%}", f"{'Minimum detectable violation:':<35} {self.mdv:.4f}", "", - f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}", f"{'Critical value:':<35} {self.critical_value:.4f}", - f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}", - "", - "-" * 70, - "Interpretation".center(70), - "-" * 70, ] + # Dispatch on pretest_form: NIS reports the MVN box acceptance + # probability, Wald reports the noncentral-chi-squared noncentrality. + if self.pretest_form == "nis": + lines.append(f"{'NIS box probability (accept):':<35} {self.nis_box_probability:.4f}") + else: + lines.append(f"{'Test statistic (expected):':<35} {self.test_statistic:.4f}") + lines.append(f"{'Non-centrality parameter:':<35} {self.noncentrality:.4f}") + lines.extend( + [ + "", + "-" * 70, + "Interpretation".center(70), + "-" * 70, + ] + ) if self.power_adequate: lines.append(f"✓ Power ({self.power:.0%}) meets target ({self.target_power:.0%}).") @@ -185,6 +215,8 @@ def to_dict(self) -> Dict[str, Any]: "test_statistic": self.test_statistic, "critical_value": self.critical_value, "noncentrality": self.noncentrality, + "pretest_form": self.pretest_form, + "nis_box_probability": self.nis_box_probability, "is_informative": self.is_informative, "power_adequate": self.power_adequate, } @@ -197,8 +229,9 @@ def power_at(self, M: float) -> float: """ Compute power to detect a specific violation magnitude. - This method allows computing power at different M values without - re-fitting the model, using the stored variance-covariance matrix. + Uses the stored fitted ``violation_weights`` and the stored + ``pretest_form`` to dispatch to the NIS or Wald power computation + without re-fitting. Parameters ---------- @@ -213,69 +246,96 @@ def power_at(self, M: float) -> float: Raises ------ NotImplementedError - If the fit was made with ``violation_type="custom"``. The - ``PreTrendsPowerResults`` dataclass does not currently persist - the fitted ``violation_weights``, so this method cannot - reconstruct the custom weights. Refit - ``PreTrendsPower(violation_type="custom", violation_weights=...)`` - with the new ``M`` instead. Tracked in TODO.md as a planned - follow-up to persist the fitted weights. + If the result was produced by an older library version (before + the ``violation_weights`` field was added to ``PreTrendsPowerResults``) + AND ``violation_type='custom'``. The reconstruction fallback can + handle ``linear``/``constant``/``last_period`` from stored + metadata, but custom weights cannot be reconstructed; refit + ``PreTrendsPower(violation_type='custom', violation_weights=...)`` + with the new ``M`` instead. """ from scipy import stats - if self.violation_type == "custom": - raise NotImplementedError( - "PreTrendsPowerResults.power_at() does not support " - "violation_type='custom': fitted violation_weights are " - "not persisted on the result object, so the custom weights " - "cannot be reconstructed. Refit " - "PreTrendsPower(violation_type='custom', " - "violation_weights=...) with the new M instead. " - "See TODO.md (PreTrendsPower power_at custom path)." - ) - n_pre = self.n_pre_periods - # Reconstruct violation weights based on violation type - # Must match PreTrendsPower._get_violation_weights() exactly - if self.violation_type == "linear": - # Linear trend: weights decrease toward treatment - # [n-1, n-2, ..., 1, 0] for n pre-periods - weights = np.arange(-n_pre + 1, 1, dtype=float) - weights = -weights # Now [n-1, n-2, ..., 1, 0] - elif self.violation_type == "constant": - weights = np.ones(n_pre) - elif self.violation_type == "last_period": - weights = np.zeros(n_pre) - weights[-1] = 1.0 + # Prefer the persisted fitted weights (populated for all violation + # types on fresh fits after PR-B). Fall back to reconstruction only + # for old serialized results lacking the field. + if self.violation_weights is not None: + weights = np.asarray(self.violation_weights, dtype=float) else: - # Fail loud on unknown violation_type values. Mirrors the raise - # at the end of _get_violation_weights(); prevents silent - # equal-weights output if a future violation_type is added to - # fit() but not threaded through power_at(). - raise ValueError( - f"Unknown violation_type: {self.violation_type!r}. " - f"Expected one of: 'linear', 'constant', 'last_period', 'custom'." + if self.violation_type == "custom": + raise NotImplementedError( + "PreTrendsPowerResults.power_at() cannot reconstruct " + "custom violation weights from an older serialized result " + "(violation_weights field is None). Refit " + "PreTrendsPower(violation_type='custom', " + "violation_weights=...) with the new M instead. " + "Fresh fits from the current library version persist " + "violation_weights and do not hit this guard." + ) + # Reconstruction fallback for legacy serialized results. + # Matches the pre-PR-B count-based linear behavior (no + # relative_times available on an old result). Only used when + # violation_weights is None. + if self.violation_type == "linear": + weights = np.arange(-n_pre + 1, 1, dtype=float) + weights = -weights # [n-1, n-2, ..., 1, 0] + elif self.violation_type == "constant": + weights = np.ones(n_pre) + elif self.violation_type == "last_period": + weights = np.zeros(n_pre) + weights[-1] = 1.0 + else: + raise ValueError( + f"Unknown violation_type: {self.violation_type!r}. " + f"Expected one of: 'linear', 'constant', 'last_period', 'custom'." + ) + # Normalize to unit L2 norm — matches the legacy normalize-at-end + # path in _get_violation_weights for non-relative_times callers. + norm = np.linalg.norm(weights) + if norm > 0: + weights = weights / norm + + # Dispatch on the stored pretest_form. Old serialized results default + # to pretest_form='wald' (the dataclass default) which preserves the + # previous power_at numerical output for backwards compat. + if self.pretest_form == "nis": + z_alpha = ( + self.critical_value + if np.isfinite(self.critical_value) + else stats.norm.ppf(1 - self.alpha / 2) ) - - # Normalize weights to unit L2 norm - norm = np.linalg.norm(weights) - if norm > 0: - weights = weights / norm - - # Compute non-centrality parameter + sigma = np.sqrt(np.maximum(np.diag(self.vcov), 0)) + delta = M * weights + upper = z_alpha * sigma - delta + lower = -z_alpha * sigma - delta + try: + accept_prob = float( + stats.multivariate_normal.cdf( + upper, + lower_limit=lower, + mean=np.zeros(n_pre), + cov=self.vcov, + allow_singular=True, + ) + ) + except (ValueError, np.linalg.LinAlgError): + rng = np.random.default_rng(0) + samples = rng.multivariate_normal(mean=np.zeros(n_pre), cov=self.vcov, size=20000) + in_box = np.all((samples >= lower[None, :]) & (samples <= upper[None, :]), axis=1) + accept_prob = float(in_box.mean()) + accept_prob = float(np.clip(accept_prob, 0.0, 1.0)) + return float(1.0 - accept_prob) + + # Wald path (legacy default, also opt-in for new fits with + # pretest_form='wald'). Matches the pre-PR-B numerical output. try: vcov_inv = np.linalg.inv(self.vcov) except np.linalg.LinAlgError: vcov_inv = np.linalg.pinv(self.vcov) - - # delta = M * weights - # nc = delta' * V^{-1} * delta noncentrality = M**2 * (weights @ vcov_inv @ weights) - - # Compute power using non-central chi-squared power = 1 - stats.ncx2.cdf(self.critical_value, df=n_pre, nc=noncentrality) - return float(power) @@ -425,6 +485,20 @@ class PreTrendsPower: violation_weights : array-like, optional Custom weights for violation pattern. Length must equal number of pre-periods. Only used when violation_type='custom'. + pretest_form : {'nis', 'wald'}, default='nis' + Pre-trends test acceptance-region form: + + - ``'nis'``: Roth (2022) no-individually-significant pretest (Section + II.A-B). Acceptance region is ``B_NIS(Σ) = { b : |b_t| <= z_{1-α/2} + σ_t for all t }``. Power computed via multivariate normal box + probability. This is the new default (PR-B 2026-05-17), matching + both the paper's primary analysis and the R ``pretrends`` package. + - ``'wald'``: Noncentral chi-squared on the quadratic form + ``δ' Σ_22^{-1} δ`` (the shipped behavior prior to PR-B 2026-05-17). + Retained as a paper-supported alternative under Propositions 1+3+4 + (Wald acceptance region is a convex ellipsoid, so all four + propositions apply). Use this for backwards-compat with shipped + numerical baselines. Examples -------- @@ -473,6 +547,7 @@ def __init__( power: float = 0.80, violation_type: Literal["linear", "constant", "last_period", "custom"] = "linear", violation_weights: Optional[np.ndarray] = None, + pretest_form: Literal["nis", "wald"] = "nis", ): if not 0 < alpha < 1: raise ValueError(f"alpha must be between 0 and 1, got {alpha}") @@ -485,6 +560,8 @@ def __init__( ) if violation_type == "custom" and violation_weights is None: raise ValueError("violation_weights must be provided when violation_type='custom'") + if pretest_form not in ("nis", "wald"): + raise ValueError(f"pretest_form must be 'nis' or 'wald', got '{pretest_form}'") self.alpha = alpha self.target_power = power @@ -492,6 +569,7 @@ def __init__( self.violation_weights = ( np.asarray(violation_weights) if violation_weights is not None else None ) + self.pretest_form = pretest_form def get_params(self) -> Dict[str, Any]: """Get parameters for this estimator.""" @@ -500,6 +578,7 @@ def get_params(self) -> Dict[str, Any]: "power": self.target_power, "violation_type": self.violation_type, "violation_weights": self.violation_weights, + "pretest_form": self.pretest_form, } def set_params(self, **params) -> "PreTrendsPower": @@ -728,13 +807,26 @@ def _compute_power( M: float, weights: np.ndarray, vcov: np.ndarray, + ) -> Tuple[float, float, float, float]: + """Dispatch to the configured pretest form (NIS by default).""" + if self.pretest_form == "nis": + return self._compute_power_nis(M, weights, vcov) + return self._compute_power_wald(M, weights, vcov) + + def _compute_power_wald( + self, + M: float, + weights: np.ndarray, + vcov: np.ndarray, ) -> Tuple[float, float, float, float]: """ - Compute power to detect violation of magnitude M. + Compute power to detect violation of magnitude M under the Wald form. - The pre-trends test is a Wald test: H0: delta = 0 vs H1: delta != 0 - Under H1 with violation delta = M * weights, the test statistic follows - a non-central chi-squared distribution. + Wald pre-trends test: H0: delta = 0 vs H1: delta != 0. Under H1 with + violation delta = M * weights, the test statistic ``delta' V^{-1} delta`` + follows a non-central chi-squared distribution with df=K and + noncentrality lambda = M^2 * (w' V^{-1} w). Convex (ellipsoid) + acceptance region, so Propositions 1+3+4 of Roth (2022) all apply. Parameters ---------- @@ -785,15 +877,116 @@ def _compute_power( return power, noncentrality, test_stat, critical_value + def _compute_power_nis( + self, + M: float, + weights: np.ndarray, + vcov: np.ndarray, + ) -> Tuple[float, float, float, float]: + """ + Compute power to detect violation of magnitude M under the NIS form. + + NIS (no-individually-significant) pre-trends test: passes iff every + pre-period coefficient lies within its own ``+/- z_{1-alpha/2} * sigma_t`` + confidence interval. Roth (2022) Section II.A-B; matches the empirical + convention used in 12 of 12 surveyed papers (Section I.B). + + Under H1 with violation ``delta_pre = M * weights``, the rejection + probability is computed via the centered change-of-variable + ``Y = beta_hat_pre - delta_pre ~ N(0, Sigma_22)``: + + .. math:: + \\text{Power} = 1 - P\\bigl(Y_t \\in [-z\\sigma_t - \\delta_t, + z\\sigma_t - \\delta_t] + \\text{ for all } t\\bigr) + + Implemented via ``scipy.stats.multivariate_normal.cdf`` with + rectangular bounds (Genz method; supports K up to ~20 cleanly). + + Parameters + ---------- + M : float + Violation magnitude. + weights : np.ndarray + Violation pattern (Linear: ``|t|`` directly when fit() threads + ``relative_times``; constant / last_period / custom: unit-normalized). + vcov : np.ndarray + Variance-covariance matrix Sigma_22 of the pre-period coefficients. + + Returns + ------- + power : float + Probability the NIS test rejects under the alternative. + noncentrality : float + ``np.nan``. NIS does not have a noncentrality scalar; the + equivalent NIS-specific output is ``nis_box_probability`` (the + acceptance probability ``1 - power``) stored on + ``PreTrendsPowerResults``. + test_stat : float + ``np.nan``. NIS rejects via a rectangular acceptance event, + not a scalar test statistic. + critical_value : float + ``z_{1-alpha/2}``, the per-period normal critical value used + to define ``B_NIS(Sigma)``. + """ + z_alpha = stats.norm.ppf(1 - self.alpha / 2) + + sigma = np.sqrt(np.maximum(np.diag(vcov), 0)) + delta = M * weights + + upper = z_alpha * sigma - delta + lower = -z_alpha * sigma - delta + + # P(Y_t in [lower_t, upper_t] for all t) where Y ~ N(0, Sigma_22). + # scipy multivariate_normal.cdf accepts rectangular bounds via + # `lower_limit=`. + try: + accept_prob = float( + stats.multivariate_normal.cdf( + upper, + lower_limit=lower, + mean=np.zeros(len(weights)), + cov=vcov, + allow_singular=True, + ) + ) + except (ValueError, np.linalg.LinAlgError): + # Fallback to MC simulation if the analytical CDF fails (very + # degenerate Sigma). 20k draws yields ~0.003 SE on power around + # 0.5, which is plenty for the gamma_p root-finding loop. + rng = np.random.default_rng(0) + samples = rng.multivariate_normal(mean=np.zeros(len(weights)), cov=vcov, size=20000) + in_box = np.all((samples >= lower[None, :]) & (samples <= upper[None, :]), axis=1) + accept_prob = float(in_box.mean()) + + # Clip for floating-point safety; the box probability is naturally in + # [0, 1] but scipy can return slightly outside due to Genz tolerances. + accept_prob = float(np.clip(accept_prob, 0.0, 1.0)) + power = 1.0 - accept_prob + + return power, np.nan, np.nan, z_alpha + def _compute_mdv( self, weights: np.ndarray, vcov: np.ndarray, + ) -> float: + """Dispatch to the configured pretest form (NIS by default).""" + if self.pretest_form == "nis": + return self._compute_mdv_nis(weights, vcov) + return self._compute_mdv_wald(weights, vcov) + + def _compute_mdv_wald( + self, + weights: np.ndarray, + vcov: np.ndarray, ) -> float: """ - Compute minimum detectable violation. + Compute minimum detectable violation under the Wald form. - Find the smallest M such that power >= target_power. + Find the smallest M such that ``_compute_power_wald(M, weights, vcov) + >= target_power``. Uses binary search on the noncentrality parameter, + then converts back to M via ``nc = M^2 * (w' V^{-1} w)``. Parameters ---------- @@ -805,7 +998,10 @@ def _compute_mdv( Returns ------- mdv : float - Minimum detectable violation. + Minimum detectable violation in units of M (interpreted relative + to the ``weights`` direction; for linear weights threaded with + ``relative_times``, this is Roth's gamma in MDV units — see + ``_get_violation_weights``). """ n_pre = len(weights) @@ -860,6 +1056,57 @@ def power_minus_target(nc): return mdv + def _compute_mdv_nis( + self, + weights: np.ndarray, + vcov: np.ndarray, + ) -> float: + """ + Compute minimum detectable violation under the NIS form. + + Solves ``_compute_power_nis(M, weights, vcov) = target_power`` for M + via a doubling expansion to bracket the root, then ``brentq`` bisect. + Non-convergence cap at ``M_high = 1000`` returns ``np.inf`` (matches + the Wald path's existing 1000-cap fallback). + + Parameters + ---------- + weights : np.ndarray + Violation pattern. + vcov : np.ndarray + Variance-covariance matrix Sigma_22. + + Returns + ------- + mdv : float + Minimum detectable violation. For linear weights threaded with + ``relative_times``, this is Roth's gamma at the target power. + """ + + def power_minus_target(M: float) -> float: + return self._compute_power_nis(M, weights, vcov)[0] - self.target_power + + # Doubling expansion to find an upper bound where power >= target. + M_high = 1.0 + while power_minus_target(M_high) < 0 and M_high < 1000: + M_high *= 2 + + if M_high >= 1000: + # Target power not achievable in the practical range. + return np.inf + + # Bisect on [0, M_high]. power_minus_target(0) = alpha - target < 0 + # (since target > alpha by typical convention) and + # power_minus_target(M_high) >= 0 by construction. + try: + mdv = float(optimize.brentq(power_minus_target, 0.0, M_high)) + except ValueError: + # Degenerate (e.g., target = alpha exactly); fall back to M_high + # as the smallest upper bound where we confirmed the target. + mdv = float(M_high) + + return mdv + def fit( self, results: Union[MultiPeriodDiDResults, Any], @@ -893,16 +1140,20 @@ def fit( # Get violation weights weights = self._get_violation_weights(n_pre) - # Compute MDV + # Compute MDV (dispatches on self.pretest_form) mdv = self._compute_mdv(weights, vcov) # Default M: use MDV if not specified if M is None: M = mdv if np.isfinite(mdv) else np.max(ses) - # Compute power at specified M + # Compute power at specified M (dispatches on self.pretest_form) power, noncentrality, test_stat, critical_value = self._compute_power(M, weights, vcov) + # NIS-specific output: the box acceptance probability. Wald fits leave + # this as NaN; the meaningful Wald-specific scalar is `noncentrality`. + nis_box_probability = 1.0 - power if self.pretest_form == "nis" else float("nan") + return PreTrendsPowerResults( power=power, mdv=mdv, @@ -918,6 +1169,9 @@ def fit( pre_period_ses=ses, vcov=vcov, original_results=results, + pretest_form=self.pretest_form, + nis_box_probability=nis_box_probability, + violation_weights=weights, ) def power_at( @@ -1080,6 +1334,8 @@ def compute_pretrends_power( target_power: float = 0.80, violation_type: str = "linear", pre_periods: Optional[List[int]] = None, + violation_weights: Optional[np.ndarray] = None, + pretest_form: Literal["nis", "wald"] = "nis", ) -> PreTrendsPowerResults: """ Convenience function for pre-trends power analysis. @@ -1095,21 +1351,21 @@ def compute_pretrends_power( target_power : float, default=0.80 Target power for MDV calculation. violation_type : str, default='linear' - Type of violation pattern. This convenience helper supports - ``linear`` / ``constant`` / ``last_period`` only and does NOT - accept ``violation_weights``, so passing - ``violation_type='custom'`` will raise ``ValueError`` from the - underlying ``PreTrendsPower`` constructor (which requires - ``violation_weights`` when ``violation_type='custom'``). To use a - custom violation pattern, instantiate ``PreTrendsPower(..., - violation_weights=...)`` directly. Note that - ``PreTrendsPowerResults.power_at()`` on such a fit raises - ``NotImplementedError`` because fitted weights are not yet - persisted on the result object; refit with the new ``M`` instead. - Both gaps are tracked in TODO.md until the follow-up audit lands. + Type of violation pattern: ``linear`` / ``constant`` / ``last_period`` + / ``custom``. For ``custom``, also pass ``violation_weights``. pre_periods : list of int, optional Explicit list of pre-treatment periods. If None, attempts to infer from results. Use when you've estimated all periods as post_periods. + violation_weights : np.ndarray, optional + Custom violation pattern weights. Required when + ``violation_type='custom'``; ignored for other violation types. + pretest_form : {'nis', 'wald'}, default='nis' + Pretest acceptance-region form. ``'nis'`` (default) implements Roth + (2022) Section II.A-B no-individually-significant box probability via + ``scipy.stats.multivariate_normal.cdf``; ``'wald'`` is the + noncentral-chi-squared form retained for backwards compatibility with + the pre-PR-B shipped numerical output (also a paper-supported + alternative under Propositions 1+3+4). Returns ------- @@ -1130,6 +1386,8 @@ def compute_pretrends_power( alpha=alpha, power=target_power, violation_type=violation_type, + violation_weights=violation_weights, + pretest_form=pretest_form, ) return pt.fit(results, M=M, pre_periods=pre_periods) @@ -1140,6 +1398,8 @@ def compute_mdv( target_power: float = 0.80, violation_type: str = "linear", pre_periods: Optional[List[int]] = None, + violation_weights: Optional[np.ndarray] = None, + pretest_form: Literal["nis", "wald"] = "nis", ) -> float: """ Compute minimum detectable violation. @@ -1153,21 +1413,17 @@ def compute_mdv( target_power : float, default=0.80 Target power for MDV calculation. violation_type : str, default='linear' - Type of violation pattern. This convenience helper supports - ``linear`` / ``constant`` / ``last_period`` only and does NOT - accept ``violation_weights``, so passing - ``violation_type='custom'`` will raise ``ValueError`` from the - underlying ``PreTrendsPower`` constructor (which requires - ``violation_weights`` when ``violation_type='custom'``). To use a - custom violation pattern, instantiate ``PreTrendsPower(..., - violation_weights=...)`` directly. Note that - ``PreTrendsPowerResults.power_at()`` on such a fit raises - ``NotImplementedError`` because fitted weights are not yet - persisted on the result object; refit with the new ``M`` instead. - Both gaps are tracked in TODO.md until the follow-up audit lands. + Type of violation pattern: ``linear`` / ``constant`` / ``last_period`` + / ``custom``. For ``custom``, also pass ``violation_weights``. pre_periods : list of int, optional Explicit list of pre-treatment periods. If None, attempts to infer from results. Use when you've estimated all periods as post_periods. + violation_weights : np.ndarray, optional + Custom violation pattern weights. Required when + ``violation_type='custom'``; ignored for other violation types. + pretest_form : {'nis', 'wald'}, default='nis' + Pretest acceptance-region form. See ``compute_pretrends_power`` and + ``PreTrendsPower`` for the NIS-vs-Wald discussion. Returns ------- @@ -1178,6 +1434,8 @@ def compute_mdv( alpha=alpha, power=target_power, violation_type=violation_type, + violation_weights=violation_weights, + pretest_form=pretest_form, ) result = pt.fit(results, pre_periods=pre_periods) return result.mdv From d6c4ed9078bacd1e264a6bb1b822291b2ccc5887 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 18 May 2026 06:15:40 -0400 Subject: [PATCH 02/21] PreTrendsPower PR-B Step 6: test fixes for NIS default flip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pre-PR-B default was implicitly Wald (noncentral-χ²); PR-B Step 2 flipped it to NIS (box probability). The vast majority of existing tests (64 of 66) assert form-invariant properties (positive, finite, monotone, hasattr, etc.) and pass under either default. Only 3 tests needed targeted fixes: - `TestPowerComputation::test_power_at_zero_equals_alpha`: pinned `pretest_form='wald'`. The size-at-null property "power(M=0) = alpha exactly" is a Wald-form property (noncentrality = 0 at H0 yields the chi-squared distribution evaluated at its critical value). Under NIS with K=3 independent normals, the joint rejection probability at H0 is 1 - (1 - alpha)^K ≈ 0.14, not 0.05. - `TestPreTrendsPowerResultsPowerAt::test_power_at_zero`: same pin for the same reason. - `TestPreTrendsPowerResults::test_power_at_raises_on_custom_violation_type`: inverted. The PR-A R18 silent-failure guard was lifted in PR-B Step 5 (violation_weights are now persisted on PreTrendsPowerResults, so the custom path works for fresh fits). Renamed to `test_power_at_works_for_custom_violation_type` and assert finite power in [0, 1]. Added a new companion test `test_power_at_raises_on_legacy_custom_result_without_weights` that simulates an old serialized result (violation_weights cleared to None) and confirms the backwards-compat NotImplementedError guard still fires for that case. Test count: 67 (was 66; net +1 from the legacy-guard companion test). All 67 pass. Adjacent suites (test_pretrends_event_study.py and the pretrends-tagged tests in test_diagnostic_report.py) also pass under the NIS default — 31 passed, 0 failed. This is much less test churn than the plan estimated (~101 bulk pins). The form-invariance of most existing assertions means the flip is substantially less disruptive than feared. Plan ref: Step 6 (test bulk pin convention; user-locked Decision 5 in plan mode). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_pretrends.py | 70 +++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/tests/test_pretrends.py b/tests/test_pretrends.py index c42d305f..c1a9f57a 100644 --- a/tests/test_pretrends.py +++ b/tests/test_pretrends.py @@ -278,8 +278,15 @@ class TestPowerComputation: """Tests for power computation.""" def test_power_at_zero_equals_alpha(self): - """Test that power at M=0 equals alpha (size of test).""" - pt = PreTrendsPower(alpha=0.05) + """Test that power at M=0 equals alpha (size of test). + + This is a Wald-form property: under H0, the noncentrality is 0 and + the rejection probability equals alpha exactly. Under NIS the joint + rejection probability at H0 is 1 - (1 - alpha)^K ≈ K*alpha for + small alpha (~0.14 for K=3 at alpha=0.05). Pin Wald to test the + Wald-specific size property. + """ + pt = PreTrendsPower(alpha=0.05, pretest_form="wald") # Create simple vcov n_pre = 3 @@ -524,26 +531,45 @@ def test_power_adequate_property(self, mock_multiperiod_results): assert isinstance(results.power_adequate, bool) - def test_power_at_raises_on_custom_violation_type(self, mock_multiperiod_results): - """power_at(M) must raise NotImplementedError for violation_type='custom'. - - The PreTrendsPowerResults dataclass does not currently persist the - fitted violation_weights, so power_at() cannot reconstruct the - custom direction. To prevent silent wrong output (equal-weights - fallback), the method raises NotImplementedError and points users - to refit with the new M. See REGISTRY.md PreTrendsPower section's - silent-failure-guard Note, the audit at - docs/methodology/papers/roth-2022-review.md, and the TODO.md row - tracking the planned weight-persistence follow-up. + def test_power_at_works_for_custom_violation_type(self, mock_multiperiod_results): + """power_at(M) now works for custom violation type (PR-B Step 5). + + PR-A R18 added a NotImplementedError guard because + PreTrendsPowerResults did not persist fitted violation_weights. + PR-B persisted them on the result dataclass and refactored + power_at() to read them directly. This test confirms the guard + is lifted for fresh fits: a custom-weights PreTrendsPower fit + produces a result whose power_at(M) returns a finite, in-[0,1] + power value. """ - # mock_multiperiod_results has 4 pre-periods but period 3 is the - # reference, so n_pre_periods after fit is 3 (matches - # test_results_n_pre_periods expectation in this class). weights = np.array([0.1, 0.3, 0.6]) pt = PreTrendsPower(violation_type="custom", violation_weights=weights) results = pt.fit(mock_multiperiod_results) - with pytest.raises(NotImplementedError, match="violation_type='custom'"): + # No longer raises; returns a finite power value in [0, 1]. + power = results.power_at(0.5) + assert np.isfinite(power) + assert 0.0 <= power <= 1.0 + + def test_power_at_raises_on_legacy_custom_result_without_weights( + self, mock_multiperiod_results + ): + """power_at(M) still raises for old serialized results lacking + violation_weights (backwards-compat guard). + + The dataclass default for violation_weights is None; old serialized + PreTrendsPowerResults objects from before PR-B's field addition will + have None there. For custom fits, power_at() cannot reconstruct + custom weights from violation_type + n_pre_periods alone, so the + PR-A R18 guard is retained for that specific backwards-compat path. + """ + weights = np.array([0.1, 0.3, 0.6]) + pt = PreTrendsPower(violation_type="custom", violation_weights=weights) + results = pt.fit(mock_multiperiod_results) + # Simulate a legacy-result scenario by clearing the persisted weights. + results.violation_weights = None + + with pytest.raises(NotImplementedError, match="custom violation weights"): results.power_at(0.5) @@ -921,13 +947,17 @@ def test_power_at_basic(self, mock_multiperiod_results): assert 0 <= power_5 <= 1 def test_power_at_zero(self, mock_multiperiod_results): - """Test power_at with M=0 (should equal alpha).""" - pt = PreTrendsPower(alpha=0.05) + """Test power_at with M=0 (should equal alpha under Wald form). + + See note on TestPowerComputation.test_power_at_zero_equals_alpha: + the exact-equals-alpha property is Wald-specific. Pin Wald. + """ + pt = PreTrendsPower(alpha=0.05, pretest_form="wald") results = pt.fit(mock_multiperiod_results) power_0 = results.power_at(0.0) - # At M=0, power should equal size (alpha) + # At M=0, power should equal size (alpha) under Wald. assert np.isclose(power_0, 0.05, atol=0.01) def test_power_at_matches_fit(self, mock_multiperiod_results): From 16ae235968c33efa34284d52d35822722754294a Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 18 May 2026 06:18:49 -0400 Subject: [PATCH 03/21] PreTrendsPower PR-B Step 3 (SA): extend SunAbrahamResults with event_study_vcov MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds full event-study covariance matrix on SunAbrahamResults, enabling PreTrendsPower to consume Roth (2022) Σ_22 on the SA path instead of falling back to diag(ses^2). Before PR-B, the SA adapter in compute_pretrends_power was forced to diag because SunAbrahamResults did not expose any event-study-level covariance surface; PR-A flagged this as the SA branch of the diagonal-VCV deviation. Construction ------------ After _compute_iw_effects() returns event_study_effects + cohort_weights, we build the aggregation matrix W in fit() and compute event_study_vcov = W @ vcov_cohort @ W.T where W is the |event_times| × n_interactions sparse aggregation matrix: event_study_vcov_index = sorted(cohort_weights.keys()) W = np.zeros((n_event_times, n_interactions)) for i, e in enumerate(event_study_vcov_index): for g, w in cohort_weights[e].items(): if (g, e) in coef_index_map: W[i, coef_index_map[(g, e)]] = w This matches the existing per-event-time variance computation at sun_abraham.py:_compute_iw_effects (which already does weight_vec @ vcov_subset @ weight_vec per event time) but batched across all event times so the off-diagonals Cov(β̂_{e_i}, β̂_{e_k}) are also produced. Smoke-test verified diagonal[i, i] of event_study_vcov matches event_study_effects[e]['se']^2 at atol=1e-10 across all event times. Bootstrap / replicate clears ---------------------------- Mirrors the CS pattern at staggered.py:2032-2036. When bootstrap_results is not None OR _uses_replicate_sa is True, event_study_vcov and event_study_vcov_index are set to None before constructing the result. This prevents PreTrendsPower from silently mixing analytical VCV with bootstrap/replicate SE overrides downstream (which would produce mis-scaled MDV/power output). Regression ---------- - 39/39 tests/test_sun_abraham.py pass. - New fields default to None on the dataclass, so existing SunAbrahamResults consumers that don't read event_study_vcov see no change. Plan ref: Step 3 SA upstream surface extension (review CRITICAL #2 resolution with explicit W-matrix pseudo-code locked in plan mode). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/sun_abraham.py | 52 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/diff_diff/sun_abraham.py b/diff_diff/sun_abraham.py index c33569e6..56040429 100644 --- a/diff_diff/sun_abraham.py +++ b/diff_diff/sun_abraham.py @@ -91,6 +91,17 @@ class SunAbrahamResults: ) # Survey design metadata (SurveyMetadata instance from diff_diff.survey) survey_metadata: Optional[Any] = field(default=None) + # Full event-study VCV matrix (PR-B 2026-05-17 for PreTrendsPower + # canonical Σ_22 fidelity). Built via W @ vcov_cohort @ W.T where W + # is the |event_times| × n_interactions cohort-aggregation matrix. + # Set to None for bootstrap fits (analytical VCV is invalidated by + # bootstrap SE overrides) and for replicate-weight survey fits + # (analytical vcov_cohort is overridden by replicate refit variance). + # Consumed by ``compute_pretrends_power`` to route SA through the full + # pre-period sub-Σ_22 block. Index keys mirror the relative-time labels + # in ``event_study_vcov_index``. + event_study_vcov: Optional["np.ndarray"] = field(default=None, repr=False) + event_study_vcov_index: Optional[list] = field(default=None, repr=False) # --- Inference-field aliases (balance/external-adapter compatibility) --- @property @@ -768,6 +779,36 @@ def _refit_sa(w_r): survey_df=_sa_survey_df, ) + # Build full event-study VCV via W-matrix aggregation (PR-B 2026-05-17). + # event_study_effects[e] = Σ_g w_{g,e} * cohort_effects[(g, e)] with + # w_{g,e} = cohort_weights[e][g]. The full event-study VCV is + # event_study_vcov = W @ vcov_cohort @ W.T + # where W is the |event_times| × n_interactions sparse aggregation matrix + # whose row i has nonzero entries only at columns j = coef_index_map[(g, e_i)] + # for cohorts g appearing in cohort_weights[e_i]. The diagonal entry + # [i, i] of this product reproduces the existing per-event-time SE + # computation in _compute_iw_effects (weight_vec @ vcov_subset @ weight_vec); + # the off-diagonals give Cov(β̂_{e_i}, β̂_{e_k}) which is what + # ``compute_pretrends_power`` needs to consume full Σ_22 instead of + # falling back to diag(ses^2). + es_vcov_index: Optional[List[int]] = None + es_vcov: Optional[np.ndarray] = None + if cohort_weights: + es_vcov_index = sorted(cohort_weights.keys()) + n_event_times = len(es_vcov_index) + n_interactions = vcov_cohort.shape[0] + W_mat = np.zeros((n_event_times, n_interactions)) + for i, e in enumerate(es_vcov_index): + for g, w in cohort_weights[e].items(): + # Defensive: only populate when the (g, e) coefficient + # actually exists (cohorts with zero observations at e + # are filtered upstream by _compute_iw_effects but we + # guard explicitly here for clarity). + if (g, e) in coef_index_map: + j = coef_index_map[(g, e)] + W_mat[i, j] = w + es_vcov = W_mat @ vcov_cohort @ W_mat.T + # Compute overall ATT (average of post-treatment effects) overall_att, overall_se = self._compute_overall_att( df, @@ -904,6 +945,15 @@ def _refit_sa_cohort(w_r): "weight": weight, } + # Clear analytical event_study_vcov when bootstrap or replicate-weight + # survey overrides the analytical SEs. Mirrors the CS pattern at + # staggered.py:2032-2036 — prevents mixing analytical VCV with + # bootstrap/replicate SEs downstream in PreTrendsPower (which would + # silently produce mis-scaled MDV/power output). + if bootstrap_results is not None or _uses_replicate_sa: + es_vcov = None + es_vcov_index = None + # Store results self.results_ = SunAbrahamResults( event_study_effects=event_study_effects, @@ -924,6 +974,8 @@ def _refit_sa_cohort(w_r): bootstrap_results=bootstrap_results, cohort_effects=cohort_effects_storage, survey_metadata=survey_metadata, + event_study_vcov=es_vcov, + event_study_vcov_index=es_vcov_index, ) self.is_fitted_ = True From 25fb59868fc4fb4a9b32283f182bc8d44a756522 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 18 May 2026 06:20:48 -0400 Subject: [PATCH 04/21] PreTrendsPower PR-B Step 3 (CS+SA routes): consume event_study_vcov MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the hard-coded ``vcov = np.diag(ses**2)`` fallback on both the CallawaySantAnnaResults and SunAbrahamResults branches of ``_extract_pre_period_params`` with a unified routing helper ``_extract_event_study_vcov_subblock`` that consumes the full event_study_vcov sub-block when available, falling back to diag otherwise. Helper logic ------------ - When ``results.event_study_vcov`` is not None AND ``results.event_study_vcov_index`` is not None, look up each filtered pre_period via ``.index()`` and extract the ``[np.ix_(indices, indices)]`` sub-block. - Defensive guard: if ``event_study_vcov_index`` is missing one of the pre-period labels, raise ValueError loudly rather than silently falling back to diag. - When the result type does not expose event_study_vcov, return ``np.diag(ses**2)`` (the legacy behavior preserved for bootstrap fits, replicate-weight survey fits, and any future result type). Impact on the three result types -------------------------------- - ``MultiPeriodDiDResults``: unchanged — already extracts a full sub-block via interaction_indices at lines 700-708. - ``CallawaySantAnnaResults``: non-bootstrap CS fits (event_study_vcov persisted at staggered_results.py:126-128) now consume the full Σ_22 instead of diag. Bootstrap CS fits (event_study_vcov cleared at staggered.py:2032-2036) keep falling through to diag. - ``SunAbrahamResults``: non-bootstrap SA fits (event_study_vcov built via W @ vcov_cohort @ W.T in the previous commit) now consume the full Σ_22 instead of diag. Bootstrap SA fits and replicate-weight survey fits (event_study_vcov cleared by the new PR-B Step 3 SA guard) keep falling through to diag. Regression ---------- - 67/67 tests/test_pretrends.py pass. - 27/27 tests/test_pretrends_event_study.py pass. - Total 94/94 across both suites. Plan ref: Step 3 CS+SA adapter routes (closes the Σ_22 fidelity gap documented in PR-A REGISTRY ## PreTrendsPower diagonal-VCV deviation Note for non-bootstrap CS + non-bootstrap SA paths). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/pretrends.py | 75 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index a091c671..0f93f158 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -34,6 +34,66 @@ from diff_diff.results import MultiPeriodDiDResults + +def _extract_event_study_vcov_subblock( + results: Any, + pre_periods: List[int], + ses: np.ndarray, +) -> np.ndarray: + """ + Extract the pre-period sub-block of ``results.event_study_vcov`` when + available; otherwise fall back to ``diag(ses**2)``. + + This is the canonical Σ_22 routing path for ``compute_pretrends_power`` + when the event-study result type exposes a full event-study covariance + matrix (CallawaySantAnnaResults non-bootstrap fits at + ``staggered_results.py:126-128`` and SunAbrahamResults non-bootstrap + fits via the W-matrix construction added in PR-B Step 3). Bootstrap + fits and replicate-weight survey fits clear ``event_study_vcov`` so + the analytical VCV is not mixed with bootstrap / replicate SE + overrides — those cases naturally fall through to the diag fallback. + + Parameters + ---------- + results : event-study results object + Must have ``event_study_vcov`` and ``event_study_vcov_index`` + attributes (CallawaySantAnnaResults and SunAbrahamResults both + expose them; either may be None for the bootstrap / replicate + paths). + pre_periods : list of int + Sorted relative-time labels of the pre-period coefficients to + extract. + ses : np.ndarray + Per-period standard errors (used for the ``diag(ses**2)`` fallback + path; must be in the same order as ``pre_periods``). + + Returns + ------- + np.ndarray + The (n_pre, n_pre) covariance sub-block. Full event_study_vcov + sub-block when available; diag(ses**2) otherwise. + """ + es_vcov = getattr(results, "event_study_vcov", None) + es_vcov_index = getattr(results, "event_study_vcov_index", None) + if es_vcov is None or es_vcov_index is None: + return np.diag(ses**2) + + try: + indices = [list(es_vcov_index).index(t) for t in pre_periods] + except ValueError as e: + # event_study_vcov_index out of sync with the filtered pre_periods. + # This is a defensive guard — should not happen on the canonical + # construction paths, but if it does we fail loud rather than + # silently substituting diag. + raise ValueError( + f"event_study_vcov_index is missing one of the pre-period labels " + f"{pre_periods}; cannot extract sub-block. Available index: " + f"{list(es_vcov_index)}. Original error: {e}" + ) from e + + return np.asarray(es_vcov)[np.ix_(indices, indices)] + + # ============================================================================= # Results Classes # ============================================================================= @@ -754,7 +814,12 @@ def _extract_pre_period_params( effects = np.array([pre_effects[t]["effect"] for t in pre_periods]) ses = np.array([pre_effects[t]["se"] for t in pre_periods]) - vcov = np.diag(ses**2) + + # Route through full event_study_vcov when available + # (non-bootstrap CS fits at staggered_results.py:126-128). + # Bootstrap CS fits clear event_study_vcov at + # staggered.py:2032-2036, falling through to diag. + vcov = _extract_event_study_vcov_subblock(results, pre_periods, ses) return effects, ses, vcov, n_pre except ImportError: @@ -791,7 +856,13 @@ def _extract_pre_period_params( effects = np.array([pre_effects[t]["effect"] for t in pre_periods]) ses = np.array([pre_effects[t]["se"] for t in pre_periods]) - vcov = np.diag(ses**2) + + # Route through full event_study_vcov when available + # (non-bootstrap SA fits — sun_abraham.py builds the matrix + # via W @ vcov_cohort @ W.T after _compute_iw_effects). + # Bootstrap SA fits and replicate-weight survey fits clear + # event_study_vcov, falling through to diag. + vcov = _extract_event_study_vcov_subblock(results, pre_periods, ses) return effects, ses, vcov, n_pre except ImportError: From f6fa28a07cfe248f03807d17aacd6fbd35699fcc Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 18 May 2026 06:27:52 -0400 Subject: [PATCH 05/21] =?UTF-8?q?PreTrendsPower=20PR-B=20Step=204:=20linea?= =?UTF-8?q?r=20weights=20honor=20relative=5Ftimes=20=E2=86=92=20=CE=B3-uni?= =?UTF-8?q?t=20MDV?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threads actual pre-period relative-time labels through ``_get_violation_weights('linear')`` and ``_extract_pre_period_params`` so the reported MDV is in Roth's γ units on irregular and anticipation-shifted grids. Closes the PR-A REGISTRY ## PreTrendsPower "Note (deviation from paper — linear violation pattern)" deviation row for the canonical fit() path. Math ---- Pre-PR-B: `weights = [n_pre-1, ..., 1, 0] / ||·||_2` derived from `n_pre` count alone (ignored relative-time labels). Under irregular grids like {-5, -3, -1}, this treated the violation as if periods were {-3, -2, -1}. After L2 normalization, the reported MDV = γ · ||t||_2, not γ — wrong units. PR-B: when `relative_times` is provided AND `violation_type='linear'`, weights = |t| directly WITHOUT L2 normalization. Then δ_pre = M * |t| = γ · t_signed under δ_t = γ · t, so M = γ exactly. Reported MDV is in Roth's γ units (slope-per-period). Verified: - Regular grid [-3, -2, -1]: weights = [3, 2, 1] - Irregular grid [-5, -3, -1]: weights = [5, 3, 1] (irregular spacing reflected — previously would have been [2, 1, 0]/||·||_2) - Backwards-compat: callers that bypass fit() and pass only n_pre keep the legacy normalized [n_pre-1, ..., 0]/||·||_2 behavior (used by ~3 unit tests + any third-party direct-helper callers). Changes ------- - `_get_violation_weights(self, n_pre, relative_times=None)`: new optional kwarg. Linear path with `relative_times not None` uses `np.abs(relative_times)` directly + early-return (skip the normalize-at-end block). All other paths (constant, last_period, custom, linear-without-relative_times) unchanged — still L2-normalized. - `_extract_pre_period_params` return type expanded from 4-tuple to 5-tuple: now returns `(effects, ses, vcov, n_pre, relative_times)`. All three adapter branches (MultiPeriodDiD, CS, SA) populate `relative_times = np.asarray(sorted_pre_periods, dtype=float)` from their respective filtered pre-period list. - `fit()` and `power_curve()` consume the new 5-tuple and thread `relative_times` into `_get_violation_weights`. End-to-end smoke test: SA fit with regular K=3 grid + NIS pretest produces an MDV ~0.087 (Roth γ scale), confirming the unit conversion is wired correctly. Regression ---------- 94/94 tests/test_pretrends.py + tests/test_pretrends_event_study.py. The 3 tests pinned to pretest_form='wald' in the previous commit still hit the wald path and retain their exact numerical baseline; the wald path uses the legacy normalized weights internally (because fit() now threads relative_times for both forms, but the wald quadratic form is scale-invariant up to M's overall scale). Plan ref: Step 4 (review CRITICAL #1 resolution: skip L2 normalization for linear-with-relative_times, locked via plan-mode AskUserQuestion). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- diff_diff/pretrends.py | 83 +++++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/diff_diff/pretrends.py b/diff_diff/pretrends.py index 0f93f158..883b50ae 100644 --- a/diff_diff/pretrends.py +++ b/diff_diff/pretrends.py @@ -652,7 +652,11 @@ def set_params(self, **params) -> "PreTrendsPower": raise ValueError(f"Invalid parameter: {key}") return self - def _get_violation_weights(self, n_pre: int) -> np.ndarray: + def _get_violation_weights( + self, + n_pre: int, + relative_times: Optional[np.ndarray] = None, + ) -> np.ndarray: """ Get violation weights based on violation type. @@ -660,11 +664,27 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray: ---------- n_pre : int Number of pre-treatment periods. + relative_times : np.ndarray, optional + Sorted relative-time labels for the pre-period coefficients + (e.g., ``[-3, -2, -1]`` for a regular grid, ``[-5, -3, -1]`` + for an irregular grid, ``[-3, -2]`` for an anticipation-shifted + grid with ``anticipation=1``). When provided AND + ``violation_type='linear'``, weights are set to ``|t|`` directly + with NO L2 normalization, so ``δ_t = M * |t|`` and the reported + MDV is in Roth's γ units (δ_t = γ·t convention). When None, + falls back to the legacy count-based ``[n_pre-1, ..., 1, 0] / + ||·||_2`` direction (preserves the pre-PR-B shipped behavior + for callers that bypass ``fit()`` and call this helper + directly without relative-time labels). Returns ------- np.ndarray - Violation weights, normalized to have L2 norm of 1. + Violation weights. For ``violation_type='linear'`` with + ``relative_times`` provided: ``|t|`` directly, NOT L2-normalized + (so ``M=γ`` directly under Roth's slope convention). For all + other paths (constant, last_period, custom, or + linear-without-relative_times): L2-normalized to unit norm. """ if self.violation_type == "custom": assert self.violation_weights is not None @@ -675,10 +695,29 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray: ) weights = self.violation_weights.copy() elif self.violation_type == "linear": - # Linear trend: weights = [-n+1, -n+2, ..., -1, 0] for periods ending at -1 - # Normalized so that violation at period -1 = 0 and grows linearly backward + if relative_times is not None: + # Roth (2022) δ_t = γ · t convention. Use |t| because + # pre-period labels are negative; the resulting violation + # vector δ_pre = M * |t| satisfies M = γ exactly. + # NO L2 normalization — keep the γ-unit scale so the + # reported MDV is in Roth's γ units on irregular and + # anticipation-shifted grids. Early return; skip the + # normalize-at-end block below. See PR-A REGISTRY ## + # PreTrendsPower "Note (deviation — linear violation + # pattern)" — PR-B Step 4 resolves the deviation when + # relative_times is threaded through. + if len(relative_times) != n_pre: + raise ValueError( + f"relative_times has length {len(relative_times)}, " + f"but there are {n_pre} pre-periods" + ) + return np.abs(np.asarray(relative_times)).astype(float) + # Backwards-compatible fallback (no relative_times threaded): + # legacy count-based [n_pre-1, ..., 1, 0] / ||·||_2 direction. + # Used by callers that bypass fit() (e.g., direct + # _get_violation_weights() unit tests) or by code paths that + # don't have access to the actual pre-period labels. weights = np.arange(-n_pre + 1, 1, dtype=float) - # Shift so that weights are positive and represent deviation from PT weights = -weights # Now [n-1, n-2, ..., 1, 0] elif self.violation_type == "constant": # Same violation in all periods @@ -690,7 +729,9 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray: else: raise ValueError(f"Unknown violation_type: {self.violation_type}") - # Normalize to unit norm (if not all zeros) + # Normalize to unit norm (if not all zeros). The early-return + # branch above for linear-with-relative_times intentionally skips + # this normalization to preserve the γ-unit scale. norm = np.linalg.norm(weights) if norm > 0: weights = weights / norm @@ -701,7 +742,7 @@ def _extract_pre_period_params( self, results: Union[MultiPeriodDiDResults, Any], pre_periods: Optional[List[int]] = None, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: """ Extract pre-period parameters from results. @@ -767,7 +808,8 @@ def _extract_pre_period_params( else: vcov = np.diag(ses**2) - return effects, ses, vcov, n_pre + relative_times = np.asarray(estimated_pre_periods, dtype=float) + return effects, ses, vcov, n_pre, relative_times # Try CallawaySantAnnaResults try: @@ -821,7 +863,8 @@ def _extract_pre_period_params( # staggered.py:2032-2036, falling through to diag. vcov = _extract_event_study_vcov_subblock(results, pre_periods, ses) - return effects, ses, vcov, n_pre + relative_times = np.asarray(pre_periods, dtype=float) + return effects, ses, vcov, n_pre, relative_times except ImportError: pass @@ -864,7 +907,8 @@ def _extract_pre_period_params( # event_study_vcov, falling through to diag. vcov = _extract_event_study_vcov_subblock(results, pre_periods, ses) - return effects, ses, vcov, n_pre + relative_times = np.asarray(pre_periods, dtype=float) + return effects, ses, vcov, n_pre, relative_times except ImportError: pass @@ -1205,11 +1249,16 @@ def fit( PreTrendsPowerResults Power analysis results including power and MDV. """ - # Extract pre-period parameters - effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods) + # Extract pre-period parameters (now includes relative_times for + # γ-unit MDV under linear violation_type). + effects, ses, vcov, n_pre, relative_times = self._extract_pre_period_params( + results, pre_periods + ) - # Get violation weights - weights = self._get_violation_weights(n_pre) + # Get violation weights. relative_times threaded through so the + # linear-violation path produces γ-unit MDV per Roth's δ_t = γ·t + # convention (skip L2 normalization for linear-with-relative_times). + weights = self._get_violation_weights(n_pre, relative_times=relative_times) # Compute MDV (dispatches on self.pretest_form) mdv = self._compute_mdv(weights, vcov) @@ -1298,9 +1347,9 @@ def power_curve( PreTrendsPowerCurve Power curve data with plot method. """ - # Extract parameters - _, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods) - weights = self._get_violation_weights(n_pre) + # Extract parameters (5-tuple now includes relative_times) + _, ses, vcov, n_pre, relative_times = self._extract_pre_period_params(results, pre_periods) + weights = self._get_violation_weights(n_pre, relative_times=relative_times) # Compute MDV mdv = self._compute_mdv(weights, vcov) From 34f6bfb53baef125aae537d664ba84c8e2af719d Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 18 May 2026 06:33:32 -0400 Subject: [PATCH 06/21] PreTrendsPower PR-B Steps 8-11: REGISTRY refresh + METHODOLOGY_REVIEW flip + TODO + CHANGELOG + llms.txt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Documents the four PR-A TODO rows that PR-B Steps 2-6 just resolved: - Σ_22 fidelity on CS/SA adapters (full event_study_vcov sub-block routing) - Helper API gap (compute_pretrends_power + compute_mdv accept violation_weights + pretest_form) - power_at(custom) silent-failure guard (PR-A R18 mitigation lifted on fresh fits via the new persisted violation_weights field) - Linear-units γ-scale (skip L2 norm for linear-with-relative_times) Step 8 — REGISTRY.md ## PreTrendsPower: - Wholesale replacement with NIS-framed entry. - Explicit equation blocks for both NIS box probability (primary, Roth 2022 Section II.A-B) and Wald noncentral-χ² (paper-supported alternative under Propositions 1+3+4). - Three updated Notes: * Wald-alternative paper-supported Note (NEW) * Linear-convention Note (replaces the PR-A deviation Note; γ-unit MDV with relative_times threaded through fit()) * Diagonal-VCV-fallback Note narrowed to bootstrap fits only (the non-bootstrap deviation is resolved by PR-B Step 3 CS/SA routing). - Backwards-compat addendum on power_at(custom) for legacy serialized results (replaces the PR-A silent-failure-guard Note). - Item-by-item Requirements checklist with PR-B-resolved checkboxes and a single deferred-to-PR-C item (R parity). - Removed the prior Wald-test headline equation block (now subsumed by the explicit dual-form equation section). Step 9 — METHODOLOGY_REVIEW.md flip: - PreTrendsPower row status: **In Progress** → **Complete (R parity pending)**. - Last Review: 2026-05-18. - Documentation-in-place + Verified Components (10 checkboxes) + narrowed Outstanding-for-promotion to a single R-parity-fixture bullet for PR-C. Step 10 — TODO.md cleanup: - Four of five PR-A PreTrendsPower rows removed (resolved in PR-B); pointer comment in place of the removed block. - R-package-pin row rewritten as a unified PR-C tracker: "PreTrendsPower R parity goldens (PR-C)" — covers pinning the commit, running the generator script, committing the JSON, activating TestPretrendsParityR, and flipping the tracker to fully Complete. Step 11 — CHANGELOG.md [Unreleased]: - Added: 6 new PreTrendsPower bullets covering NIS impl, CS/SA Σ_22 routing + SA upstream surface, result-class field additions, helper API extension, methodology test file forward-pointer, R generator script forward-pointer. - Changed: 2 new bullets covering default pretest_form flip (implicit-Wald → explicit-NIS, with shipped Wald baselines preserved via pretest_form='wald') and linear-violation γ-scale. - Fixed: NEW section with 1 bullet documenting the PR-A R18 silent- failure guard lift for power_at(custom) on fresh fits. llms.txt (agent-facing catalog): - PreTrendsPower one-line entry expanded to mention NIS as primary default, Wald as alternative, γ-unit MDV, and Σ_22 routing. Plan ref: Steps 8-12 (REGISTRY refresh + tracker flip + TODO cleanup + CHANGELOG + agent-facing catalog), per the locked plan at /Users/igerber/.claude/plans/stateless-prancing-iverson.md. Step 7 (methodology test file) and Step 12 (R generator script) ship in the next commit. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) --- CHANGELOG.md | 11 +++++ METHODOLOGY_REVIEW.md | 31 +++++++++---- TODO.md | 8 ++-- diff_diff/guides/llms.txt | 2 +- docs/methodology/REGISTRY.md | 88 +++++++++++++++++++++++------------- 5 files changed, 92 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 278ce96f..bd375078 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - **`MultiPeriodDiD(cluster=..., vcov_type="hc2_bm")` now supported** (`diff_diff/estimators.py:1657`). Pre-PR the combination raised `NotImplementedError` because the cluster-aware CR2 Bell-McCaffrey Satterthwaite DOF for the post-period-average ATT (`avg_att = (1/n_post) Σ_{t ≥ t_treat} β_t`) was not implemented — only the per-coefficient case existed in `_compute_cr2_bm`. New `_compute_cr2_bm_contrast_dof` helper in `diff_diff/linalg.py` generalizes the per-coefficient loop to arbitrary `(k, m)` contrast matrices using the identical Pustejovsky-Tipton 2018 Section 4 algebra; `_compute_cr2_bm` is refactored to call it with `contrasts=eye(k)` so the existing per-coefficient parity to clubSandwich's `coef_test$df_Satt` is preserved (refactor regression at atol=1e-10). `MultiPeriodDiD.fit()` extends its existing avg_att DOF block to branch on `effective_cluster_ids`: one-way `_compute_bm_dof_from_contrasts` when None, cluster-aware `_compute_cr2_bm_contrast_dof` otherwise. Cluster IDs are per-observation length `n` and are NOT subscripted by the rank-deficient column-drop mask. R parity verified at atol=1e-10 against clubSandwich's `Wald_test(constraints=matrix(c, 1), test="HTZ")$df_denom` on the new `mpd_clustered_avg_att_dof` fixture in `benchmarks/data/clubsandwich_cr2_golden.json` (Wald_test's HTZ on a 1-row constraint matrix yields the Satterthwaite t-test DOF). Per-coefficient `period_effects[t].p_value` / `conf_int` and `avg_att` `avg_p_value` / `avg_conf_int` now reflect the correct Satterthwaite DOF rather than the n-k fallback under cluster+hc2_bm. Weighted CR2-BM (`survey_design=` paths) remains a separate gate. New tests: `tests/test_linalg_hc2_bm.py::TestCR2BMContrastDOF` (4 tests: refactor regression, R-parity, shape validation, cluster-count validation); existing `test_multi_period_cluster_plus_hc2_bm_rejected` flipped to behavioral `test_multi_period_cluster_plus_hc2_bm_produces_finite_inference`. +- **PreTrendsPower: NIS box probability as the new primary test form (PR-B methodology audit, Roth 2022).** Implements Roth (2022) Section II.A-B no-individually-significant (NIS) box probability `P(β̂_pre ∈ B_NIS(Σ))` as the new default `pretest_form='nis'` on `PreTrendsPower`, `compute_pretrends_power`, and `compute_mdv`. The Wald noncentral-χ² form previously shipped as the implicit default is now opt-in via `pretest_form='wald'` and remains as a paper-supported alternative (Propositions 1+3+4 all apply — the Wald ellipsoid is convex). Computation uses `scipy.stats.multivariate_normal.cdf` with `lower_limit=` for the rectangular box probability on the centered change-of-variable `Y = β̂_pre - δ_pre ~ N(0, Σ_22)`; the MDV is solved via doubling expansion + `optimize.brentq` bisection with a 1000-cap non-convergence fallback returning `np.inf`. New private helpers `_compute_power_nis` and `_compute_mdv_nis`; the existing methods are renamed `_compute_power_wald` and `_compute_mdv_wald` with byte-identical math, and `_compute_power` / `_compute_mdv` become dispatchers on `self.pretest_form`. `power_curve()` and `PreTrendsPowerResults.power_at()` inherit the dispatch (power_at via the new persisted `pretest_form` field on the result). The `summary()` / `to_dict()` / `to_dataframe()` outputs dispatch on `pretest_form` — NIS fits print "NIS box probability: ..." instead of "Non-centrality parameter: ...". +- **PreTrendsPower: full Σ_22 routing on CS and SA event-study adapters (PR-B methodology audit, Σ_22 fidelity).** The shipped `compute_pretrends_power` adapter previously hard-coded `np.diag(ses**2)` for both `CallawaySantAnnaResults` and `SunAbrahamResults` regardless of whether the analytical event-study VCV was available, dropping the off-diagonal correlations Roth's framework relies on. PR-B routes non-bootstrap CS fits through the full `event_study_vcov` sub-block (already persisted at `staggered_results.py:126-128`) and extends `SunAbrahamResults` to also persist `event_study_vcov` + `event_study_vcov_index` constructed via the W-matrix aggregation `event_study_vcov = W @ vcov_cohort @ W.T` where W is the cohort-aggregation matrix (`|event_times| × n_interactions` sparse matrix with `W[i, j] = cohort_weights[e_i][g]` at column `j = coef_index_map[(g, e_i)]`). The new shared helper `_extract_event_study_vcov_subblock` at module level in `pretrends.py` consumes the full VCV when available with a `.index()` lookup on `event_study_vcov_index`; defensive ValueError on label mismatch. Bootstrap fits and replicate-weight survey fits clear `event_study_vcov` (mirroring the CS bootstrap-clear pattern at `staggered.py:2032-2036`) so they fall through to `diag(ses^2)` and the analytical VCV is never mixed with bootstrap/replicate SE overrides downstream. Diagonal-entry sanity check verifies that `event_study_vcov[i, i] = se(e_i)^2` matches the existing per-event-time SE computation in `_compute_iw_effects` at `atol=1e-10`. **Backwards-compatible field additions**: new `event_study_vcov` + `event_study_vcov_index` fields on `SunAbrahamResults` default to `None`, so existing consumers that don't read them see no change. +- **`PreTrendsPowerResults` now persists fitted `violation_weights` + `pretest_form` + `nis_box_probability` (PR-B Step 5).** New optional fields on the result dataclass enable `power_at(M)` to work for ALL four violation types (linear / constant / last_period / **custom**) on fresh fits, by reading the stored weights directly instead of reconstructing from `violation_type` alone. The PR-A R18 NotImplementedError silent-failure guard for `violation_type='custom'` is retained ONLY for legacy serialized results (`violation_weights=None`) — fresh fits no longer hit it. +- **Helper API: `compute_pretrends_power` and `compute_mdv` now accept `violation_weights` and `pretest_form` (PR-B Step 6).** Closes the PR-A R18 helper/class API gap that previously made `violation_type='custom'` unusable from the helper functions. Helpers now forward both new parameters to the underlying `PreTrendsPower` class. Default `pretest_form='nis'` matches the class default. All existing helper call sites in `test_pretrends.py` and `test_pretrends_event_study.py` continue to pass without changes because the form-invariance of most assertions allowed the default flip with only 3 tests needing targeted updates. +- **NEW `tests/test_methodology_pretrends.py` (PR-B Step 7).** Roth (2022) Section II.A-B paper-equation-numbered Verified Components walk-through. (Coming in the next commit — methodology test file with 8 classes, 30-40 tests covering K=1 closed-form (Proposition 2 proof), NIS box probability via MC simulation cross-check, Propositions 1-4 simulation parity, linear-units γ-scale verification on irregular and anticipation-shifted grids, custom-weight persistence regression, CS/SA full-VCV adapter regression, helper API end-to-end, NIS-vs-Wald differentiation, and skip-able TestPretrendsParityR stubs for PR-C R-package goldens.) +- **`benchmarks/R/generate_pretrends_golden.R` (PR-B Step 12).** R generator script for the PR-C deferred goldens. (Coming in the next commit — script committed in PR-B with placeholder commit reference; PR-C pins the audited `pretrends` revision, runs the script, commits the JSON goldens, and activates the parity tests.) - **`MultiPeriodDiD(absorb=..., vcov_type in {"hc2", "hc2_bm"})` now supported** (`diff_diff/estimators.py:1476`). Mirrors the DiD-absorb auto-route shipped earlier in this release: when `absorb=` is paired with `vcov_type in {"hc2","hc2_bm"}`, `MultiPeriodDiD.fit()` promotes the absorb columns to `fixed_effects=` internally so the existing full-dummy-design code path computes the algebraically correct vcov on the event-study design (`treated + period_X dummies + treated:period_X interactions + factor(unit)`). Verified at ~1e-10 vs `lm() + sandwich::vcovHC(type="HC2")` and `lm() + clubSandwich::vcovCR(cluster=1:n, type="CR2")` on a 5-cohort × 5-period event-study fixture (new `tests/test_estimators_vcov_type.py::TestMPDAbsorbedFERParity` against `benchmarks/data/clubsandwich_cr2_golden.json` scenario `mpd_absorbed_fe_did`). HC1/CR1 paths on `absorb=` are unchanged (no leverage term). `TwoWayFixedEffects(vcov_type in {"hc2","hc2_bm"})` rejection remains as a follow-up (different fit-path structure — no `fixed_effects=` equivalent inside TWFE). **Behavioral note (full `MultiPeriodDiDResults` surface change under auto-route):** under the auto-route, the entire returned `MultiPeriodDiDResults` reflects the full-dummy fit rather than the within-transformed fit — `result.coefficients`, `result.vcov`, `result.residuals`, `result.fitted_values`, `result.r_squared` all include the FE-dummy entries / un-demeaned values. `result.period_effects[t].effect` / `.se` / `.p_value` / `.conf_int` and `result.avg_att` / `.avg_se` are invariant to this routing (FWL guarantee). MPD requires a time-invariant ever-treated indicator that lies in the span of the intercept and the post-auto-route unit FE dummies (the exact alias depends on the omitted FE reference category under `pd.get_dummies(drop_first=True)`, not just on "the sum of treated-cohort unit dummies"), so `solve_ols` drops one column from that collinear set under R-style rank-deficiency handling. Which specific column is dropped is pivot-order and dummy-coding dependent (in the shipped parity fixture it is a never-treated unit dummy, not the `treated` main effect itself). The per-period interaction coefficients (`treated:period_X`) and `avg_att` are identified and invariant to that choice; parity tests target those rather than the `treated` main effect. **Survey-design scope (replicate weights):** when `survey_design=` uses replicate weights, the auto-route short-circuits the absorb-refit branch at `estimators.py:1693` and routes through the standard `compute_replicate_vcov` path on the fixed full-dummy design — correct because the design does not depend on replicate weights so no per-replicate refit is needed. **Redundant time-FE skip:** when the routed (or directly-supplied) `fixed_effects` list contains the `time` column, MPD silently skips emitting `