diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6f8ec1c5..3a851f73 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
+- **SunAbraham `vcov_type` parameter (Phase 1b PR 1/8).** `SunAbraham(vcov_type=...)` now accepts `{"classical","hc1","hc2","hc2_bm"}` (defaults to `"hc1"`, which preserves prior behavior bit-equally - SA historically hard-coded HC1). Auto-cluster-at-unit dropped when the user opts into explicit `vcov_type="hc2"` or `vcov_type="classical"` (one-way only); preserved for `"hc1"` and `"hc2_bm"`. When `vcov_type in {"classical","hc2","hc2_bm"}`, `_fit_saturated_regression` auto-routes to a full-dummy saturated design (mirrors TWFE Gate 1 from PR #469): FWL preserves cohort coefficients but not the hat matrix, so HC2 leverage and Bell-McCaffrey Satterthwaite DOF must be computed on the full FE projection. Empirically matches R `lm()` summary classical SE, `sandwich::vcovHC(type="HC2")`, and `clubSandwich::vcovCR(..., type="CR2")` + `coef_test()$df_Satt` at atol=1e-10 (cohort SE and BM DOF pinned in `tests/test_methodology_sun_abraham.py`). For `vcov_type="hc2_bm"`, the user-facing aggregated inference (`event_study_effects[e]['p_value']`/`['conf_int']`, `overall_p_value`/`overall_conf_int`) uses CR2 Bell-McCaffrey contrast DOF — matches `clubSandwich::Wald_test(test="HTZ")$df_denom` at atol=1e-10 (mirrors PR #465's `_compute_cr2_bm_contrast_dof` pattern for MultiPeriodDiD's post-period-average ATT). `vcov_type` is now propagated to `SunAbrahamResults.vcov_type` for downstream introspection. `SurveyDesign` (any kind — analytical weights, stratified, PSU, or replicate-weight) combined with `vcov_type in {"classical","hc2","hc2_bm"}` raises `NotImplementedError`: the survey-design TSL (or replicate-weight refit) variance overrides the analytical sandwich family, and the auto-cluster guard for one-way families would silently downgrade unit-level PSUs to per-observation PSUs. Use `vcov_type="hc1"` (default) for survey designs. `conley` rejected at `__init__` with a deferral message (would require threading 6+ `conley_*` params through the saturated regression call). **Deviation from R:** SA's within-transform HC1 SE differs from `fixest::sunab()` by ~1-2% (~2e-3 absolute) on typical panel sizes due to a different `(n-k)` finite-sample correction (fixest counts absorbed FE in k_total; SA's `solve_ols` counts only within-transformed columns); the IW aggregation step is otherwise identical (pinned at atol=5e-3, tracked in TODO.md). First PR of the Phase 1b standalone-estimator threading initiative (7 PRs to follow: StackedDiD, WooldridgeDiD-OLS, CallawaySantAnna, ImputationDiD, TripleDifference, TwoStageDiD, EfficientDiD).
- **PreTrendsPower R `pretrends` parity goldens (PR-C closes PR-B's deferred R-parity row).** JSON goldens at `benchmarks/data/r_pretrends_golden.json` generated from the committed `benchmarks/R/generate_pretrends_golden.R` script against `jonathandroth/pretrends` commit `122731d082` (package version 0.1.0, R 4.5.2). 4 fixtures cover regular K=3 grid (`uniform_3_pre_periods_no_anticipation`), irregular K=3 grid `[-5,-3,-1]` (`irregular_pre_periods` — locks the PR-B Step 4 γ-unit linear-weight fix), anticipation-shifted K=4 grid (`anticipation_shifted`), and K=1 closed form (`single_pre_period_closed_form` — Roth Proposition 2 univariate truncated-normal). `TestPretrendsParityR` in `tests/test_methodology_pretrends.py` now active (4 tests): NIS power vs R `pretrends::pretrends()` at `atol=1e-4` across all 4 fixtures × 4 γ values; γ_p MDV vs R `slope_for_power()` at `atol=1e-4` across all 4 fixtures × 2 target_power values; end-to-end `fit()` on irregular grid vs R γ_p at `atol=1e-4` (locks the full `fit() → _extract_pre_period_params → _get_violation_weights → _compute_mdv_nis` chain through the public API); K=1 three-way cross-check (Python ≡ analytical truncated-normal closed form `1 - Φ(z - γ/σ) + Φ(-z - γ/σ)` at `atol=1e-7`; both within `atol=1e-4` of R). Tolerance rationale: R hardcodes `thresholdTstat.Pretest=1.96` while Python uses `scipy.stats.norm.ppf(0.975) = 1.959963984540054` (`dz ≈ 3.6e-5`); R `slope_for_power` uses `uniroot(tol = .Machine$double.eps^0.25 ≈ 1.22e-4)` versus Python `brentq(xtol=2e-12)`; the inverse-solver tolerance gap dominates γ_p, and `mvtnorm::pmvnorm` (R) vs `scipy.stats.multivariate_normal.cdf` (Python) Genz-Bretz randomized-lattice differences bound the K=4 NIS power gap at ~5e-5. `METHODOLOGY_REVIEW.md` PreTrendsPower row promoted `**Complete** (R parity pending)` → `**Complete**`. Roth (2022) paper review's `R \`pretrends\` package version pin (provisional)` Gaps bullet struck. Closes the PR-C TODO row.
- **`SpilloverDiD(survey_design=...)` integration on HC1 / CR1 paths via Binder TSL (Wave E.1).** Lifts the Wave B/C/D upfront `NotImplementedError` and adds design-based variance for `vcov_type ∈ {"hc1"}` plus `cluster=
` (CR1). **Documented synthesis** of Gerber (2026, arXiv:2605.04124) Proposition 1 — Binder Taylor Series Linearization for IF representations of smooth functionals; explicitly derived for TwoStageDiD in the paper's Appendix — composed with the Wave D Gardner GMM first-stage uncertainty correction (Butts 2021 §3.1 + Gardner 2022 §4) applied to SpilloverDiD's ring-indicator stage-2 design. No reference software combines all ingredients. **Mechanical composition:** SpilloverDiD's per-obs Wave D IF `psi_i = gamma_hat' * X_{10,i} * eps_{10,i} - X_{2,i} * eps_{2,i}` (with survey weights threaded through `gamma_hat` solve, eps construction, and bread inversion via Hájek normalization) is aggregated to PSU totals and passed to the audited `_compute_stratified_meat_from_psu_scores` Binder TSL meat helper. Stage-1 FE estimation extends `_iterative_fe_subset` with a `weights=` kwarg implementing WLS-FE via weighted bincount (numerator `bincount(w*resid)` / denominator `bincount(w)`); the `weights is None` path is bit-identical to the Wave B / C / D unweighted bincount. **Degrees of freedom:** t-distribution lookup uses `ResolvedSurveyDesign.df_survey` (4-way branch: PSU+strata → `n_PSU - n_strata`; PSU only → `n_PSU - 1`; strata only → `n_obs - n_strata`; neither → `n_obs - 1`), threaded through all four `safe_inference` call sites (aggregate `tau_total`, per-ring `delta_j`, event-study per-event-time `tau_k` / `delta_jk`, scalar `att` lincom). **Survey-array subsetting:** when `finite_mask` drops baseline-treated rows, `survey_weights` and `ResolvedSurveyDesign.{weights, strata, psu, fpc, replicate_weights}` are subsetted in parallel; `n_psu`, `n_strata`, and `survey_metadata` are recomputed (mirrors `TwoStageDiD.fit:567-601`). **Cluster + survey resolution:** when `cluster=` and `survey_design.psu` are both supplied with different groupings, a `UserWarning` fires and PSU wins (mirrors `_resolve_effective_cluster` at `survey.py:1253-1275`; TwoStageDiD parity). When `cluster=` is supplied without `survey_design.psu`, the cluster column is injected as the effective PSU via `_inject_cluster_as_psu`, which now honors `SurveyDesign.nest`: under `nest=False`, cluster labels must be globally unique across strata (raises if they repeat, matching the explicit-PSU resolver's contract). **Saturated `df_survey = 0` NaN-fail:** when `lonely_psu="remove"` removes all strata (singleton PSUs), the meat helper returns `(_, var_computed=False, legit_zero=0)` and SpilloverDiD's Wave E.1 path returns NaN meat with a `UserWarning` matching `"df_survey"` so callers can `pytest.warns(UserWarning, match="df_survey")`. This is a **departure from TwoStageDiD** (`two_stage.py:2003-2005`) which currently NaN-fails SILENTLY; Wave E.1 surfaces the diagnostic per `feedback_no_silent_failures`. **Subpopulation limitation (Wave E.3 follow-up):** `SurveyDesign.subpopulation()`-derived designs with zero-weight padding rows that lose stage-1 FE support have those rows physically removed by `finite_mask`, so `n_psu` / `df_survey` / Binder centering reflect the reduced fit sample rather than the full domain design (documented in REGISTRY; Wave E.3 will preserve full-design bookkeeping). **Public surface restrictions:** `vcov_type="conley" + survey_design=` raises `NotImplementedError` pointing at planned Wave E.2 (Conley × survey product-kernel synthesis with within-stratum Conley sandwich on PSU totals); replicate-weight variance (BRR / Fay / JK1 / JKn / SDR) raises `NotImplementedError` — per Gerber (2026) Appendix A, the IF-reweighting shortcut does not apply to TwoStageDiD-class estimators because `gamma_hat` is weight-sensitive; correct support requires per-replicate full re-fit and is queued as a follow-up; non-pweight (`weight_type ∈ {"fweight", "aweight"}`) raises `ValueError` (the Binder TSL assumes probability weights). **Implementation:** `_compute_gmm_corrected_meat` extended with `survey_weights` + `resolved_survey` kwargs at `diff_diff/two_stage.py:56` (TYPE_CHECKING forward reference for `ResolvedSurveyDesign` to avoid circular import); new module-level helper `_compute_binder_tsl_meat` at `diff_diff/two_stage.py` wraps `_compute_stratified_meat_from_psu_scores` with implicit per-obs PSU synthesis for no-PSU survey designs + the Wave E.1 NaN-fail + warning; `_iterative_fe_subset` weighted path at `diff_diff/spillover.py:1382` (in-place extension, bit-identical fallback, positive-weight identification gate); `_inject_cluster_as_psu` honors `nest` (shared survey-helper fix that also benefits TwoStageDiD); `ResolvedSurveyDesign` gains a `nest` field propagated through all 5 construction sites. `SpilloverDiDResults` extended with `survey_metadata`, `n_psu`, `n_strata` fields at `diff_diff/results.py`. **Tests:** new `TestSpilloverDiDWaveE1SurveyDesignHc1` (17 tests: bit-identity fallback, Binder TSL hand-check uniform + non-uniform weights, lonely_psu modes, FPC degenerate limits ×3, saturated NaN-fail with `pytest.warns(match="df_survey")`, cluster+survey warn-and-use-PSU, no-PSU regressions (weights-only, weights+strata, cluster-without-PSU, cluster overlap with nest=False/True), zero-weight Omega_0 exclusion + all-zero raises, replicate-weight + non-pweight + Conley+survey rejections, fit idempotency, finite_mask subsetting) and `TestSpilloverDiDWaveE1SurveyDesignEventStudy` (7 tests: event-study + survey on both `is_staggered` branches with `df_survey` lincom verification, distinguishability between survey-share and sample-share lincom rules via manual reconstruction with cohort-correlated weights + non-constant tau_k, aggregate-vs-event-study parity, drift goldens, subset-path invariant). Wave B/C/D bullets below are unchanged; this entry replaces the pre-Wave-E.1 `survey_design=` rejection.
diff --git a/TODO.md b/TODO.md
index 5ccfaa09..9098ef1e 100644
--- a/TODO.md
+++ b/TODO.md
@@ -98,7 +98,9 @@ Deferred items from PR reviews that were not addressed before merge.
| PreTrendsPower: CS/SA `anticipation=1` R-parity fixture. The PR-C R-parity goldens cover NIS power + γ_p MDV at `atol=1e-4` on four shifted-grid / regular / irregular / K=1 fixtures, but R `pretrends` has no anticipation parameter so the Python-side `_extract_pre_period_params` anticipation filter (`if t < _pre_cutoff` in `pretrends.py` lines 1138-1150 for CS; mirror in SA branch) is not R-parity-locked. Build a synthetic `CallawaySantAnnaResults` (or `SunAbrahamResults`) with `anticipation=1` and a t=-1 event-study entry that should be filtered before reaching `_compute_power_nis`, then assert the resulting γ_p matches R's `slope_for_power()` on the K=4 shifted-grid fixture. Existing PR-B MC-based tests (`TestPretrendsPropositions`) and full-VCV tests (`TestPretrendsCovarianceSource`) already cover the filter mechanically; this would close the loop against R. | `tests/test_methodology_pretrends.py::TestPretrendsParityR`, `benchmarks/R/generate_pretrends_golden.R` | PR-C follow-up | Low |
-| Thread `vcov_type` (classical / hc1 / hc2 / hc2_bm) through the 8 standalone estimators that expose `cluster=`: `CallawaySantAnna`, `SunAbraham`, `ImputationDiD`, `TwoStageDiD`, `TripleDifference`, `StackedDiD`, `WooldridgeDiD`, `EfficientDiD`. Phase 1a added `vcov_type` to the `DifferenceInDifferences` inheritance chain only. | multiple | Phase 1a | Medium |
+| Thread `vcov_type` (classical / hc1 / hc2 / hc2_bm) through the 7 standalone estimators that expose `cluster=` but not yet `vcov_type=`: `CallawaySantAnna`, `ImputationDiD`, `TwoStageDiD`, `TripleDifference`, `StackedDiD`, `WooldridgeDiD`, `EfficientDiD`. Phase 1a added the chain to DiD/MPD/TWFE; Phase 1b PR 1/8 added `SunAbraham` (this row tracks the remaining 7). | multiple | Phase 1b | Medium |
+| Extend `SunAbraham` with `vcov_type="conley"` (Conley spatial-HAC) as a first-class feature: thread `conley_coords` / `conley_cutoff_km` / `conley_metric` / `conley_kernel` / `conley_time` / `conley_unit` / `conley_lag_cutoff` through `_fit_saturated_regression`. Phase 1b PR 1/8 deferred this; SA currently rejects `vcov_type="conley"` at `__init__` with a deferral message. | `diff_diff/sun_abraham.py` | follow-up | Medium |
+| Harmonize SunAbraham's HC1 within-transform finite-sample correction with `fixest::sunab()`. SA's `solve_ols` applies `n / (n - k_dm)` (within-transform columns only); fixest applies `n / (n - k_total)` (counts absorbed FE). SE values differ by ~1-2% on typical panel sizes (documented in REGISTRY.md "Deviation from R"; pinned at `atol=5e-3` in `tests/test_methodology_sun_abraham.py`). Either thread `df_adjustment` into the vcov scaling or document as an intentional difference. | `diff_diff/sun_abraham.py`, `diff_diff/linalg.py::compute_robust_vcov` | follow-up | Low |
| Weighted one-way Bell-McCaffrey (`vcov_type="hc2_bm"` + `weights`, no cluster) currently raises `NotImplementedError`. `_compute_bm_dof_from_contrasts` builds its hat matrix from the unscaled design via `X (X'WX)^{-1} X' W`, but `solve_ols` solves the WLS problem by transforming to `X* = sqrt(w) X`, so the correct symmetric idempotent residual-maker is `M* = I - sqrt(W) X (X'WX)^{-1} X' sqrt(W)`. Rederive the Satterthwaite `(tr G)^2 / tr(G^2)` ratio on the transformed design and add weighted parity tests before lifting the guard. | `linalg.py::_compute_bm_dof_from_contrasts`, `linalg.py::_validate_vcov_args` | Phase 1a | Medium |
| Weighted CR2 Bell-McCaffrey cluster-robust (`vcov_type="hc2_bm"` + `cluster_ids` + `weights`) currently raises `NotImplementedError`. Weighted hat matrix and residual rebalancing need threading per clubSandwich WLS handling. | `linalg.py::_compute_cr2_bm` | Phase 1a | Medium |
| `TwoWayFixedEffects(vcov_type in {"hc2","hc2_bm"})` with replicate-weight survey designs raises `NotImplementedError` (`twfe.py:~233`). The replicate path re-demeans per replicate (re-demeaning depends on the per-replicate weight vector), which doesn't compose with the full-dummy HC2/HC2-BM build — a correct implementation would need per-replicate full-dummy refit. Workaround: use `vcov_type="hc1"` for replicate-weight CR1. | `twfe.py::fit` | follow-up | Low |
diff --git a/benchmarks/R/generate_clubsandwich_golden.R b/benchmarks/R/generate_clubsandwich_golden.R
index a2dc1338..77976fa3 100644
--- a/benchmarks/R/generate_clubsandwich_golden.R
+++ b/benchmarks/R/generate_clubsandwich_golden.R
@@ -283,6 +283,191 @@ output$twfe_two_period <- list(
dof_bm_unit = as.numeric(ct_twfe_unit$df_Satt)
)
+# --- SunAbraham saturated regression HC2 / HC2-BM scenario (Phase 1b PR 1/8) -
+# Mirrors SunAbraham(vcov_type in {"classical","hc2","hc2_bm"}) on a
+# 5-cohort × 8-period balanced panel. SA's Part G auto-route builds a
+# full-dummy saturated design when vcov_type needs the hat matrix —
+# matches lm(y ~ D_ge interactions + factor(unit) + factor(period)).
+# Targets the (g=4, e=0) cohort × event-time interaction (the canonical
+# at-treatment effect of the earliest cohort).
+
+set.seed(42)
+n_units_per_cohort <- 8
+n_sa_periods <- 8
+sa_cohorts <- c(0, 4, 5, 6, 7) # 0 = never-treated
+
+d_sa <- expand.grid(u_in_cohort = seq_len(n_units_per_cohort),
+ period = seq_len(n_sa_periods),
+ cohort_idx = seq_along(sa_cohorts))
+d_sa <- d_sa[order(d_sa$cohort_idx, d_sa$u_in_cohort, d_sa$period), ]
+d_sa$unit <- (d_sa$cohort_idx - 1) * n_units_per_cohort + d_sa$u_in_cohort - 1
+d_sa$first_treat <- sa_cohorts[d_sa$cohort_idx]
+d_sa$time <- d_sa$period
+d_sa$rel_time <- ifelse(d_sa$first_treat > 0,
+ d_sa$time - d_sa$first_treat, -999L)
+sa_unit_fe <- rnorm(max(d_sa$unit) + 1, mean = 0, sd = 1)
+d_sa$treated <- as.integer(d_sa$first_treat > 0 & d_sa$time >= d_sa$first_treat)
+d_sa$y <- sa_unit_fe[d_sa$unit + 1] + 0.3 * d_sa$time +
+ 1.0 * d_sa$treated + rnorm(nrow(d_sa), sd = 0.5)
+
+# Build cohort × event-time interaction columns (excluding ref period -1).
+# Sanitize negative event times for R formula compatibility (e=-3 → "n3").
+sa_treatment_groups <- sort(unique(d_sa$first_treat[d_sa$first_treat > 0]))
+sa_all_rel_times <- sort(unique(d_sa$rel_time[d_sa$first_treat > 0]))
+sa_all_rel_times <- sa_all_rel_times[sa_all_rel_times != -1]
+sa_interaction_cols <- c()
+sa_col_map <- list()
+for (g in sa_treatment_groups) {
+ for (e in sa_all_rel_times) {
+ e_safe <- if (e < 0) paste0("n", abs(e)) else as.character(e)
+ col_name <- paste0("D_", g, "_", e_safe)
+ original_name <- paste0("D_", g, "_", e)
+ ind <- as.integer(d_sa$first_treat == g & d_sa$rel_time == e)
+ if (sum(ind) > 0) {
+ d_sa[[col_name]] <- ind
+ sa_interaction_cols <- c(sa_interaction_cols, col_name)
+ sa_col_map[[original_name]] <- col_name
+ }
+ }
+}
+
+sa_target_orig <- "D_4_0" # the (g=4, e=0) interaction
+sa_target_safe <- sa_col_map[[sa_target_orig]]
+stopifnot(!is.null(sa_target_safe))
+
+sa_rhs <- paste(c(sa_interaction_cols, "factor(unit)", "factor(time)"),
+ collapse = " + ")
+fit_sa <- lm(as.formula(paste("y ~", sa_rhs)), data = d_sa)
+sa_coef_names <- names(coef(fit_sa))
+sa_target_idx <- which(sa_coef_names == sa_target_safe)
+stopifnot(length(sa_target_idx) == 1L)
+
+# Extract SE/DOF for the target only (atol=1e-10 pin in Python tests).
+sa_classical_se <- summary(fit_sa)$coefficients[sa_target_safe, "Std. Error"]
+sa_vcov_hc2 <- sandwich::vcovHC(fit_sa, type = "HC2")
+sa_hc2_se <- sqrt(sa_vcov_hc2[sa_target_safe, sa_target_safe])
+# Singleton-cluster CR2 reduces to one-way HC2-BM.
+sa_vcov_cr2_singleton <- vcovCR(fit_sa, cluster = seq_len(nrow(d_sa)),
+ type = "CR2")
+sa_cr2_singleton_se <- sqrt(sa_vcov_cr2_singleton[sa_target_safe,
+ sa_target_safe])
+sa_ct_singleton <- coef_test(fit_sa, vcov = sa_vcov_cr2_singleton)
+sa_dof_bm_singleton <- sa_ct_singleton[sa_target_safe, "df_Satt"]
+# CR2-BM clustered at unit (the SA auto-cluster default for hc2_bm).
+sa_vcov_cr2_unit <- vcovCR(fit_sa, cluster = d_sa$unit, type = "CR2")
+sa_cr2_unit_se <- sqrt(sa_vcov_cr2_unit[sa_target_safe, sa_target_safe])
+sa_ct_unit <- coef_test(fit_sa, vcov = sa_vcov_cr2_unit)
+sa_dof_bm_unit <- sa_ct_unit[sa_target_safe, "df_Satt"]
+# fixest::sunab() parity for SA's HC1 cluster-at-unit default path.
+# SA HC1 uses within-transform; fixest also uses within-transform.
+# Note: fixest::sunab requires a specific encoding — first_treat=0 means
+# never-treated. fixest auto-handles that.
+suppressPackageStartupMessages(library(fixest, quietly = TRUE))
+fit_sunab <- fixest::feols(
+ y ~ sunab(first_treat, time) | unit + time,
+ data = d_sa,
+ cluster = ~unit
+)
+# fixest::sunab aggregates to event-study coefficients (IW-aggregated
+# across cohorts). The coefficient labels are "time::".
+# Compare SA's event_study_effects[0] (overall e=0 ATT) against fixest's
+# "time::0" event-study SE.
+sunab_coef_table <- as.data.frame(summary(fit_sunab)$coeftable)
+sunab_target_label <- "time::0"
+sunab_hc1_es0_se <- if (sunab_target_label %in% rownames(sunab_coef_table)) {
+ sunab_coef_table[sunab_target_label, "Std. Error"]
+} else {
+ warning("Could not locate fixest sunab event-study target ",
+ sunab_target_label)
+ NA_real_
+}
+
+# CR2-BM Bell-McCaffrey contrast DOF for the IW-aggregated event-time e=0
+# effect (under cluster=unit). The contrast at e=0 aggregates all cohorts
+# present at relative time 0 with weights w_{g,0} = n_{g,0} / Σ_g n_{g,0}.
+# All 4 treated cohorts (g=4,5,6,7) have 8 units each at e=0 → equal
+# weights 0.25 each. Build the contrast in full-coef space and call
+# Wald_test(test="HTZ") — on a 1-row constraint matrix HTZ reduces to a
+# Satterthwaite t-test whose df_denom IS the BM DOF.
+sa_all_coef_names <- names(coef(fit_sa))
+sa_n_coef <- length(sa_all_coef_names)
+sa_es0_contrast <- setNames(rep(0, sa_n_coef), sa_all_coef_names)
+sa_es0_cols <- c("D_4_0", "D_5_0", "D_6_0", "D_7_0")
+sa_es0_contrast[sa_es0_cols] <- 0.25
+# Subset to non-NA coefficients (clubSandwich's convention).
+sa_finite_mask <- !is.na(coef(fit_sa))
+sa_es0_kept <- sa_es0_contrast[sa_finite_mask]
+sa_dof_bm_es0_unit <- Wald_test(
+ fit_sa,
+ constraints = matrix(sa_es0_kept, 1),
+ vcov = sa_vcov_cr2_unit,
+ test = "HTZ"
+)$df_denom
+
+# CR2-BM Bell-McCaffrey contrast DOF for the IW-aggregated OVERALL ATT.
+# SA's overall ATT = Σ_e w_e × Σ_g w_{g,e} × δ_{g,e} where w_e is the
+# mass at post-period event-time e and w_{g,e} is the IW cohort share.
+# Post-period event-times e ∈ {0, 1, 2, 3} on this panel; n_{g,e} = 8
+# for e=0 (all 4 cohorts), 6 for e=1 (3 cohorts), 4 for e=2 (2 cohorts),
+# 2 for e=3 (1 cohort) — actually, per fixest::sunab construction:
+# cohort g treats at time g; observed event-times for cohort g are
+# t - g for t ∈ {1..8}. Compute the cohort × event-time mass matrix
+# empirically.
+# Post-period event-times: SA includes ALL observed e >= 0, not just
+# those where multiple cohorts contribute. For the 4-cohort × 8-period
+# panel, max observed e = 8 - 4 = 4 (cohort g=4 at t=8).
+sa_post_event_times <- sort(unique(d_sa$rel_time[d_sa$first_treat > 0 & d_sa$rel_time >= 0]))
+sa_overall_contrast <- setNames(rep(0, sa_n_coef), sa_all_coef_names)
+sa_per_event_mass <- numeric(length(sa_post_event_times))
+for (i in seq_along(sa_post_event_times)) {
+ e <- sa_post_event_times[i]
+ cohorts_at_e <- sort(unique(d_sa$first_treat[d_sa$first_treat > 0 & d_sa$rel_time == e]))
+ if (length(cohorts_at_e) == 0) next
+ n_per_cohort <- sapply(cohorts_at_e, function(g) sum(d_sa$first_treat == g & d_sa$rel_time == e))
+ sa_per_event_mass[i] <- sum(n_per_cohort)
+}
+sa_post_weights <- sa_per_event_mass / sum(sa_per_event_mass)
+for (i in seq_along(sa_post_event_times)) {
+ e <- sa_post_event_times[i]
+ cohorts_at_e <- sort(unique(d_sa$first_treat[d_sa$first_treat > 0 & d_sa$rel_time == e]))
+ if (length(cohorts_at_e) == 0) next
+ n_per_cohort <- sapply(cohorts_at_e, function(g) sum(d_sa$first_treat == g & d_sa$rel_time == e))
+ iw_weights <- n_per_cohort / sum(n_per_cohort)
+ for (j in seq_along(cohorts_at_e)) {
+ g <- cohorts_at_e[j]
+ e_safe <- if (e < 0) paste0("n", abs(e)) else as.character(e)
+ col_name <- paste0("D_", g, "_", e_safe)
+ sa_overall_contrast[col_name] <- sa_post_weights[i] * iw_weights[j]
+ }
+}
+sa_overall_kept <- sa_overall_contrast[sa_finite_mask]
+sa_dof_bm_overall_unit <- Wald_test(
+ fit_sa,
+ constraints = matrix(sa_overall_kept, 1),
+ vcov = sa_vcov_cr2_unit,
+ test = "HTZ"
+)$df_denom
+
+output$sun_abraham_two_cohort <- list(
+ unit = d_sa$unit,
+ time = d_sa$time,
+ first_treat = d_sa$first_treat,
+ y = d_sa$y,
+ target_cohort_g = 4L,
+ target_event_time_e = 0L,
+ target_col_safe = sa_target_safe,
+ classical_se = unname(sa_classical_se),
+ hc2_se = unname(sa_hc2_se),
+ cr2_bm_singleton_se = unname(sa_cr2_singleton_se),
+ dof_bm_singleton = unname(sa_dof_bm_singleton),
+ cr2_bm_unit_se = unname(sa_cr2_unit_se),
+ dof_bm_unit = unname(sa_dof_bm_unit),
+ sunab_hc1_event_study_e0_se = unname(sunab_hc1_es0_se),
+ sunab_event_study_target_label = sunab_target_label,
+ dof_bm_contrast_es0_unit = unname(sa_dof_bm_es0_unit),
+ dof_bm_contrast_overall_unit = unname(sa_dof_bm_overall_unit)
+)
+
output$meta <- list(
source = "clubSandwich",
clubSandwich_version = as.character(packageVersion("clubSandwich")),
diff --git a/benchmarks/data/clubsandwich_cr2_golden.json b/benchmarks/data/clubsandwich_cr2_golden.json
index 539f5efb..34eb0b2c 100644
--- a/benchmarks/data/clubsandwich_cr2_golden.json
+++ b/benchmarks/data/clubsandwich_cr2_golden.json
@@ -98,11 +98,30 @@
"vcov_cr2_unit": [0.007651392098640002, -0.01530278419727998, -0.007651392098640009, -1.340972185906478e-17, -0.007651392098640004, -1.786362460978703e-17, -0.007651392098640002, -1.652260459528918e-17, -0.007651392098640006, 8.815980097977497e-19, -0.01530278419727999, 0.04018425503992974, 0.02009212751996491, 2.723300676932653e-17, 0.0200921275199649, 3.747012207819093e-17, 0.02009212751996489, 3.337015158794735e-17, 0.0200921275199649, -0.009578686645369807, -0.00765139209864001, 0.02009212751996491, 0.01004606375998247, 1.361650338466344e-17, 0.01004606375998247, 1.873506103909563e-17, 0.01004606375998246, 1.668507579397384e-17, 0.01004606375998247, -0.00478934332268491, -1.340972185906478e-17, 2.723300676932653e-17, 1.361650338466343e-17, 1.668199508743466e-31, 1.361650338466343e-17, 1.656912369884294e-31, 1.361650338466342e-17, 1.631551769724017e-31, 1.361650338466343e-17, -4.135630511972947e-19, -0.007651392098640005, 0.0200921275199649, 0.01004606375998247, 1.361650338466343e-17, 0.01004606375998247, 1.873506103909562e-17, 0.01004606375998246, 1.668507579397383e-17, 0.01004606375998247, -0.004789343322684908, -1.786362460978703e-17, 3.747012207819092e-17, 1.873506103909563e-17, 1.656912369884295e-31, 1.873506103909562e-17, 1.697016749718632e-31, 1.873506103909561e-17, 1.615308081491078e-31, 1.873506103909562e-17, -1.742872858617179e-18, -0.007651392098640003, 0.02009212751996489, 0.01004606375998247, 1.361650338466342e-17, 0.01004606375998246, 1.873506103909562e-17, 0.01004606375998246, 1.668507579397383e-17, 0.01004606375998246, -0.004789343322684904, -1.652260459528918e-17, 3.337015158794737e-17, 1.668507579397385e-17, 1.631551769724017e-31, 1.668507579397384e-17, 1.615308081491077e-31, 1.668507579397383e-17, 1.665183639542917e-31, 1.668507579397384e-17, -3.249423973693037e-19, -0.007651392098640007, 0.0200921275199649, 0.01004606375998247, 1.361650338466343e-17, 0.01004606375998247, 1.873506103909563e-17, 0.01004606375998246, 1.668507579397384e-17, 0.01004606375998246, -0.004789343322684903, 3.612539513184754e-18, -0.009578686645369814, -0.004789343322684914, -4.135630511972954e-19, -0.004789343322684916, -1.742872858617175e-18, -0.004789343322684907, -3.249423973692947e-19, -0.004789343322684909, 0.009578686645369807],
"dof_bm_unit": [3, 6.000000000000002, 5.999999999999998, 1.027080069278844, 6.000000000000001, 1.038147635656014, 6.000000000000003, 1.078234257225623, 5.999999999999999, 2.999999999999998]
},
+ "sun_abraham_two_cohort": {
+ "unit": [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39],
+ "time": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8],
+ "first_treat": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7],
+ "y": [1.773957747246796, 1.790429797872335, 2.650040064996427, 2.207606033608381, 2.186817924937021, 3.387367460091027, 3.065261859053333, 4.493009078007295, -0.4804212727027615, 0.3631257703050145, 0.4962644612058844, 0.2433823581637234, 1.7231655884999, 1.556751481462569, 1.580182151903714, 1.973577202249643, 1.002772819364975, 1.00804485462688, -0.2334166302391283, 1.705569888102669, 1.679511089966852, 2.255743693770143, 2.754040275020093, 3.462996824983678, 0.569216575223808, 1.884133920983112, 1.700786664837078, 2.352115654309851, 2.593226889106364, 2.793301686394472, 2.211303135677113, 2.987769411655687, 1.016027404140771, 0.5275066442548271, 1.032853915854071, 1.89476657198184, 2.288357692058295, 2.436152117411083, 2.06138017443616, 2.254377873817071, 0.9502289888109801, 0.6228362026745314, 0.8380955984883091, 1.033427215138971, 0.7967110363282518, 1.999873932928709, 1.885305561035256, 2.202497130742555, 2.278195161724519, 2.522408552693063, 3.107580185406075, 2.473435035911602, 3.336696277802091, 4.007077225633939, 3.05612755771499, 3.481125704000018, -0.3605283788399826, -0.2242660381642955, 0.8453322382074829, 1.431943131411497, 2.00582364938615, 2.227716505170765, 1.503736638166978, 3.229581912423276, 1.985037009498133, 2.671180620105076, 2.707295772942613, 4.157248627899556, 4.612520231127791, 4.878004192875545, 5.10587743844334, 5.472460077848058, -0.005431716975754841, 0.2851773356036273, 0.00673636099017294, 1.94611903751067, 2.180960772008679, 4.088231401119978, 2.356227785352719, 3.405914010226882, 0.85805712056534, 1.169651783505087, 2.267220847321989, 3.006550086781467, 3.803958347071131, 3.890740213510577, 4.098033850998738, 3.692530731513931, 1.974271417521114, 2.976403613260076, 3.470455689912874, 4.240206715924369, 4.786676834733782, 5.64809021439109, 6.106573264189202, 5.138088508498196, -1.147520481237428, -0.1881115006524887, -0.7237254913954898, 0.7849045564181625, 1.068085649769116, 0.9672997899344449, 1.488797296445292, 1.996416859343541, -0.1857231913463332, 0.8779042448667302, 0.3807148123556375, 1.704626716882264, 2.56964252145868, 1.993027026597174, 2.800861995607021, 2.345438822008834, 0.750263438224183, 0.3298558129193017, 0.5327560012702151, 1.447552499613235, 2.362797646717679, 2.266537574630512, 2.699932498631124, 3.910516286398637, 0.8481874629490108, 0.7000592059947366, 1.617553839303766, 2.654581190256099, 3.430957172063744, 4.152161361935569, 3.239604142515328, 4.263275546860216, 0.05819610792317124, 0.7635298697162, 0.5008580091107943, 1.334056612814234, 1.343219147915581, 3.360476539240614, 3.248136067843217, 3.040359084141054, -3.08095898597436, -1.734951070883785, -1.514858488997392, -1.459633234115471, -0.08072747447356395, -0.1485099060796785, 0.6279479454103449, 0.8908717489549816, -2.280096615246807, -2.508585256022096, -1.190092519355502, -0.9633686174385022, -0.3586202249762264, -0.4377610095786406, 0.7620123617182981, 0.786989082438036, 1.746419197412419, 1.273112112987915, 1.74052812354001, 3.063000772570142, 4.022000798088049, 4.413357114089841, 5.327727568807168, 4.784524060031383, -1.00710321294423, 0.4602500046383105, 1.179023969600923, 1.923131027071177, 1.504930606801265, 1.917933623107972, 2.440450708541465, 2.56633351488293, -1.804180295551246, -1.273997417818252, -1.481919459349993, 0.4371776495115745, 0.7725789384627728, 0.976637515767209, 1.566501386822297, 1.637399159078982, 0.06203862576242368, 1.166476356016427, 0.6195675391943267, 0.3862815421942659, 2.52091658946208, 2.452326207475832, 2.667184597562244, 2.694017043896793, 1.728857650805945, 1.727665581959102, 2.372508563496614, 2.297492060519638, 3.385422986261713, 4.639793001211957, 4.1787928416169, 5.088650697110197, 1.594402246210496, 2.262135413077214, 2.660517763688375, 2.899710757199534, 4.069546967260822, 4.68381111061576, 5.117306386816691, 4.824007607333003, -0.4950777698609869, 0.6685653226712245, 1.09877170069158, 1.393962712798848, 0.3792123436311352, 3.394511215211997, 3.177967283293571, 2.956172136324671, 0.3945345066302006, -0.1429619973448615, 0.09465249645164164, 0.9672558426998689, 0.6434826889521804, 2.637740116510944, 3.491583567033403, 2.625793755748695, -1.832383462301312, -1.139881115469348, -1.371961145107833, -0.7548050651379192, 0.1732146206409191, 1.521609421790278, 1.528760247318767, 0.7110590832714875, 0.7330989864181885, 1.592483962000413, 1.766694873537074, 1.56468911777913, 0.6101324505494201, 3.290580674231829, 3.84697320357306, 3.882999144693821, -0.2612886058721816, 0.1757878104742661, 0.06173025601477453, 1.214994236915655, 1.095201823976852, 1.53866998875258, 3.150792852232647, 3.362234592544793, 1.16748710508216, 0.2241354221484223, 1.070796951438479, 1.973207031885982, 1.977311127030148, 3.429456275071873, 4.785246897676426, 3.446259961041385, -0.05176272023574935, 1.441684973444026, 1.261038916607113, 2.127857863709752, 1.79864497533352, 4.610865077396387, 3.742984351440104, 3.866169584213793, 1.251972776226618, 2.066385213783138, 1.983773764571332, 1.422295152364278, 2.532793138056693, 3.215224605800949, 4.154598978406204, 4.802639592777142, -0.3821626889066463, -0.03787004309785271, 0.5322583576450548, 1.087545442980656, 0.2678758755767675, 1.174329862169984, 2.455592533967791, 2.411613297727313, 0.2877754428283937, 0.7895891463125604, 1.698358983304468, 1.496793795206144, 1.61251121854531, 2.386663283041019, 2.986598005743295, 4.427892011416306, -1.659306387182698, -1.02244427323301, -0.7915055132858565, -0.517129013516564, 0.68768234214809, -0.3296726576283155, 1.95572654704065, 1.698777914703278, -0.9020619110328562, -0.2188408328877526, 0.488926835126737, 0.2027816244232016, 0.3294998740237753, 1.091923044994451, 2.809839414209903, 2.578811824284417, -1.24442087095639, -0.9042455463641998, -0.3351052567154878, 0.08553834313625253, 0.6383788733026693, 1.284341441321622, 2.031783886537278, 1.992152514167058, -1.810654652507383, -1.676479165576837, -0.935534115239264, -2.05544807968576, -0.8705481054991951, 0.06247329702470183, 1.047879250395052, 0.5695159371712619, 0.7023868502764627, 0.2001591718936493, 0.7094238512773592, 1.829889746206794, 1.391049950997216, 2.250395679374974, 2.990508752502671, 2.647941404518303],
+ "target_cohort_g": 4,
+ "target_event_time_e": 0,
+ "target_col_safe": "D_4_0",
+ "classical_se": 0.3267751000983616,
+ "hc2_se": 0.3089737532360686,
+ "cr2_bm_singleton_se": 0.3089737532360685,
+ "dof_bm_singleton": 27.439999999999753,
+ "cr2_bm_unit_se": 0.3311337313280216,
+ "dof_bm_unit": 13.999999999999952,
+ "sunab_hc1_event_study_e0_se": 0.1528572531585001,
+ "sunab_event_study_target_label": "time::0",
+ "dof_bm_contrast_es0_unit": 34.999999999999865,
+ "dof_bm_contrast_overall_unit": 21.211009174311531
+ },
"meta": {
"source": "clubSandwich",
"clubSandwich_version": "0.7.0",
"R_version": "R version 4.5.2 (2025-10-31)",
- "generated_at": "2026-05-19 01:30:25 UTC",
+ "generated_at": "2026-05-20 00:53:14 UTC",
"note": "CR2 Bell-McCaffrey cluster-robust parity target for diff_diff._compute_cr2_bm"
}
}
diff --git a/diff_diff/guides/llms-full.txt b/diff_diff/guides/llms-full.txt
index 1fd632e0..8bfbc55a 100644
--- a/diff_diff/guides/llms-full.txt
+++ b/diff_diff/guides/llms-full.txt
@@ -335,10 +335,11 @@ SunAbraham(
control_group: str = "never_treated", # "never_treated" or "not_yet_treated"
anticipation: int = 0,
alpha: float = 0.05,
- cluster: str | None = None, # Defaults to unit-level clustering
+ cluster: str | None = None, # Defaults to unit-level clustering (dropped on explicit vcov_type='hc2' / 'classical')
n_bootstrap: int = 0, # 0 = analytical cluster-robust SEs
seed: int | None = None,
rank_deficient_action: str = "warn",
+ vcov_type: str = "hc1", # {"classical","hc1","hc2","hc2_bm"}; classical/hc2/hc2_bm route through full-dummy saturated design. survey_design=... is rejected for classical/hc2/hc2_bm (use hc1 default for surveys)
)
```
diff --git a/diff_diff/sun_abraham.py b/diff_diff/sun_abraham.py
index 56040429..fd47d06c 100644
--- a/diff_diff/sun_abraham.py
+++ b/diff_diff/sun_abraham.py
@@ -63,6 +63,17 @@ class SunAbrahamResults:
Significance level used for confidence intervals.
control_group : str
Type of control group used.
+ vcov_type : str
+ Variance-covariance family from the fit-time configuration
+ (``classical``, ``hc1``, ``hc2``, or ``hc2_bm``). Note: when a
+ ``survey_design=`` is supplied, the survey-design Taylor Series
+ Linearization (or replicate-weight refit) variance overrides
+ this analytical family — the field still records the
+ configured value but ``survey_metadata`` indicates the survey
+ path was active. Likewise, on bootstrap fits (``n_bootstrap >
+ 0``) the SE comes from the pairs bootstrap (or Rao-Wu rescaled
+ bootstrap under stratified / PSU survey designs), not the
+ analytical family.
"""
event_study_effects: Dict[int, Dict[str, Any]]
@@ -79,6 +90,7 @@ class SunAbrahamResults:
n_control_units: int
alpha: float = 0.05
control_group: str = "never_treated"
+ vcov_type: str = "hc1"
# Anticipation periods (``k``) used at fit time. Persisted so
# downstream diagnostics (``BusinessReport`` / ``DiagnosticReport``
# / ``compute_pretrends_power``) can classify pre-period vs
@@ -372,7 +384,13 @@ class SunAbraham:
Significance level for confidence intervals.
cluster : str, optional
Column name for cluster-robust standard errors.
- If None, clusters at the unit level by default.
+ If None, clusters at the unit level by default — UNLESS
+ ``vcov_type`` is explicitly set to ``"hc2"`` or ``"classical"``,
+ in which case the unit auto-cluster is dropped (both are
+ one-way families and the linalg validator rejects them with
+ ``cluster_ids``). Use ``vcov_type="hc1"`` (default) or
+ ``vcov_type="hc2_bm"`` for cluster-robust inference; the latter
+ routes to CR2 Bell-McCaffrey at the cluster level.
n_bootstrap : int, default=0
Number of bootstrap iterations for inference.
If 0, uses analytical cluster-robust standard errors.
@@ -383,6 +401,54 @@ class SunAbraham:
- "warn": Issue warning and drop linearly dependent columns (default)
- "error": Raise ValueError
- "silent": Drop columns silently without warning
+ vcov_type : {"classical", "hc1", "hc2", "hc2_bm"}, default "hc1"
+ Variance-covariance family for analytical inference. Defaults to
+ ``"hc1"`` (preserves prior behavior bit-equally; SA historically
+ hard-coded HC1).
+
+ - ``"classical"``: homoskedastic OLS standard errors. One-way
+ only (linalg validator rejects ``classical + cluster_ids``);
+ the unit auto-cluster is dropped when ``classical`` is
+ explicitly opted into.
+ - ``"hc1"``: Eicker-Huber-White HC1 finite-sample correction
+ (default; cluster-robust when ``cluster=`` is set or the unit
+ auto-cluster fires).
+ - ``"hc2"``: Eicker-Huber-White HC2 leverage correction. One-way
+ only; the linalg validator rejects combining ``hc2`` with
+ clusters. The unit auto-cluster is dropped when ``hc2`` is
+ explicitly opted into.
+ - ``"hc2_bm"``: HC2 + Bell-McCaffrey CR2 Satterthwaite DOF for
+ cluster-robust inference. Routes to CR2-BM at the cluster
+ level; preserves the auto-cluster default.
+
+ When ``vcov_type ∈ {"classical","hc2","hc2_bm"}``, the
+ saturated regression switches from the within-transform path
+ to a full-dummy ``[intercept + interactions + covariates +
+ unit_dummies + time_dummies]`` build. For ``hc2`` and
+ ``hc2_bm``, the Frisch-Waugh-Lovell theorem preserves
+ coefficients but NOT the hat matrix, so HC2 leverage and BM
+ Satterthwaite DOF must be computed on the full FE projection.
+ ``classical`` also routes through full-dummy so the ``(n-k)``
+ finite-sample correction in ``s² × (X'X)^{-1}`` matches R's
+ ``lm()`` interpretation. Empirically matches
+ ``lm(...) + sandwich::vcovHC(type="HC2")`` and
+ ``clubSandwich::vcovCR(..., type="CR2")`` at atol=1e-10.
+
+ ``"hc1"`` keeps the within-transform path (cluster-robust HC1
+ does not depend on the hat matrix); empirically close to
+ ``fixest::sunab(cluster=~unit)``. See REGISTRY.md for the
+ documented HC1 finite-sample-correction deviation.
+
+ Survey designs (``survey_design=``) are rejected for
+ ``vcov_type ∈ {"classical","hc2","hc2_bm"}`` because the
+ survey-design Taylor Series Linearization (or replicate-weight
+ refit) variance overrides the analytical sandwich family, and
+ the auto-cluster guard for one-way families would silently
+ downgrade unit-level PSUs to per-observation PSUs. Use
+ ``vcov_type="hc1"`` (default) for survey designs.
+
+ ``conley`` spatial-HAC is not yet wired up for SunAbraham; see
+ TODO.md.
Attributes
----------
@@ -460,6 +526,7 @@ def __init__(
n_bootstrap: int = 0,
seed: Optional[int] = None,
rank_deficient_action: str = "warn",
+ vcov_type: str = "hc1",
):
if control_group not in ["never_treated", "not_yet_treated"]:
raise ValueError(
@@ -473,6 +540,20 @@ def __init__(
f"got '{rank_deficient_action}'"
)
+ if vcov_type not in ("classical", "hc1", "hc2", "hc2_bm"):
+ if vcov_type == "conley":
+ raise ValueError(
+ "vcov_type='conley' is not yet wired up for SunAbraham: "
+ "would require threading conley_coords / conley_cutoff_km / "
+ "conley_metric / conley_kernel / conley_time / conley_unit / "
+ "conley_lag_cutoff through the saturated regression call. "
+ "Tracked in TODO.md (SA Conley follow-up row)."
+ )
+ raise ValueError(
+ f"vcov_type must be one of "
+ f"{{'classical','hc1','hc2','hc2_bm'}}; got '{vcov_type}'"
+ )
+
self.control_group = control_group
self.anticipation = anticipation
self.alpha = alpha
@@ -480,6 +561,15 @@ def __init__(
self.n_bootstrap = n_bootstrap
self.seed = seed
self.rank_deficient_action = rank_deficient_action
+ self.vcov_type = vcov_type
+ # Track whether the user explicitly opted out of the "hc1" default.
+ # The auto-cluster-at-unit default in `fit` is suppressed only when
+ # the user explicitly opts into a one-way family — currently
+ # ``vcov_type in {"hc2","classical"}``. Both are rejected by the
+ # linalg validator when combined with ``cluster_ids``. Leaving the
+ # auto-cluster on the default "hc1" path preserves backward compat;
+ # ``hc2_bm`` also keeps the auto-cluster (routes to CR2-BM at unit).
+ self._vcov_type_explicit = vcov_type != "hc1"
self.is_fitted_ = False
self.results_: Optional[SunAbrahamResults] = None
@@ -537,6 +627,31 @@ def fit(
if missing:
raise ValueError(f"Missing columns: {missing}")
+ # Validate explicit cluster column upfront. Without this guard, a
+ # missing `cluster=` column would cascade through cluster_var=None
+ # and silently downgrade clustered inference to one-way (HC1 →
+ # heteroskedasticity-only; HC2-BM → singleton CR2-BM). Explicit
+ # user input must error, not silently weaken the SE convention.
+ if self.cluster is not None:
+ if self.cluster not in data.columns:
+ raise ValueError(
+ f"cluster column {self.cluster!r} not found in data; "
+ f"available columns: {list(data.columns)}"
+ )
+ # NA cluster labels are silently dropped by the meat-side
+ # `groupby(cluster_ids)` but counted by `np.unique(cluster_ids)`
+ # in `n_clusters`, producing malformed cluster-robust SEs. Reject
+ # explicitly so the user fixes the cluster column rather than
+ # consuming silently-wrong inference.
+ if data[self.cluster].isna().any():
+ n_na = int(data[self.cluster].isna().sum())
+ raise ValueError(
+ f"cluster column {self.cluster!r} contains {n_na} "
+ "NA/NaN values. Cluster labels must be non-missing for "
+ "all observations to produce well-formed cluster-robust "
+ "standard errors. Drop or impute the NA rows before fit."
+ )
+
# Resolve survey design if provided
from diff_diff.survey import (
_resolve_effective_cluster,
@@ -560,6 +675,36 @@ def fit(
"Replicate weights provide their own variance estimation."
)
+ # Survey-design + non-HC1 analytical family reject: survey-design
+ # Taylor Series Linearization (or replicate-weight refit) variance
+ # overrides the analytical sandwich family, so the requested
+ # vcov_type ∈ {classical, hc2, hc2_bm} would either silently downgrade
+ # unit-as-PSU injection to per-observation PSUs (auto-cluster guard
+ # drops cluster_var=None before the survey path injects unit as PSU)
+ # or hit the linalg validator's hc2/classical + cluster_ids reject.
+ # Explicit reject preserves the "survey TSL overrides analytical"
+ # contract documented in REGISTRY. Use vcov_type='hc1' (default) for
+ # survey designs.
+ if resolved_survey is not None and self.vcov_type in ("classical", "hc2", "hc2_bm"):
+ raise NotImplementedError(
+ f"SunAbraham(vcov_type={self.vcov_type!r}) with survey_design "
+ "is not yet supported: the survey-design TSL (or replicate-"
+ "weight refit) variance overrides the analytical sandwich, "
+ "so the requested HC2/HC2-BM/classical family would be "
+ "silently discarded. Additionally, the auto-cluster guard "
+ "for explicit one-way families (classical/hc2) would drop "
+ "the unit auto-cluster before survey-PSU injection, "
+ "downgrading the panel structure from unit-level to "
+ "per-observation PSUs. Use vcov_type='hc1' (default) for "
+ "survey designs; the survey TSL machinery computes the "
+ "design-aware SE on the within-transform path."
+ )
+
+ # Note: the broader survey reject above (line ~625) already covers
+ # the replicate-weight + hc2/hc2_bm combo (replicate is a subset of
+ # survey). The replicate-only reject that previously lived here is
+ # redundant and was removed; see commit history for the rationale.
+
# Bootstrap + survey supported via Rao-Wu rescaled bootstrap.
# Determine Rao-Wu eligibility from the *original* survey_design
# (before cluster-as-PSU injection which adds PSU to weights-only designs).
@@ -632,7 +777,24 @@ def fit(
]
# Determine cluster variable
- cluster_var = self.cluster if self.cluster is not None else unit
+ # One-way HC2 and classical are single-way only — the linalg
+ # validator rejects `vcov_type ∈ {"hc2","classical"} + cluster_ids`.
+ # Drop the unit auto-cluster when the user opts into either
+ # explicitly. `hc1` and `hc2_bm` preserve the auto-cluster
+ # (route to CR1 / CR2-Bell-McCaffrey at unit respectively).
+ # SA has no `inference=` parameter — its bootstrap path uses the
+ # pairs bootstrap (or Rao-Wu rescaled bootstrap on stratified /
+ # PSU survey designs) via `n_bootstrap > 0`, which overrides the
+ # analytical SE downstream and does NOT consume the cluster
+ # structure of the main fit. So the SA guard simplifies to
+ # "explicit-vcov-only", without TWFE's `inference == "analytical"`
+ # subguard.
+ if self.cluster is not None:
+ cluster_var: Optional[str] = self.cluster
+ elif self.vcov_type in ("hc2", "classical") and self._vcov_type_explicit:
+ cluster_var = None
+ else:
+ cluster_var = unit
# Filter data based on control_group setting
if self.control_group == "never_treated":
@@ -642,8 +804,14 @@ def fit(
# Keep all units (not_yet_treated will be handled by the regression)
df_reg = df.copy()
- # Resolve effective cluster and inject cluster-as-PSU
- cluster_ids_raw = df_reg[cluster_var].values if cluster_var in df_reg.columns else None
+ # Resolve effective cluster and inject cluster-as-PSU.
+ # When `cluster_var is None` (one-way HC2 explicit path), the survey
+ # path skips PSU injection and the saturated regression receives
+ # `cluster_ids=None` downstream.
+ if cluster_var is not None and cluster_var in df_reg.columns:
+ cluster_ids_raw = df_reg[cluster_var].values
+ else:
+ cluster_ids_raw = None
effective_cluster_ids = _resolve_effective_cluster(
resolved_survey, cluster_ids_raw, cluster_var if self.cluster is not None else None
)
@@ -665,6 +833,7 @@ def fit(
cohort_ses,
vcov_cohort,
coef_index_map,
+ bm_artifacts,
) = self._fit_saturated_regression(
df_reg,
outcome,
@@ -681,6 +850,7 @@ def fit(
# computing bogus replicate vcov on already-demeaned data. We
# override vcov_cohort below with the correct estimator-level refit.
resolved_survey=None if _uses_replicate_sa else resolved_survey,
+ vcov_type=self.vcov_type,
)
# Replicate variance override: fully refit the IW estimator per
@@ -699,7 +869,7 @@ def _refit_sa(w_r):
nz = w_r > 0
df_reg_nz = df_reg[nz] if not np.all(nz) else df_reg
w_nz = w_r[nz] if not np.all(nz) else w_r
- ce_r, _, vcov_r, cim_r = self._fit_saturated_regression(
+ ce_r, _, vcov_r, cim_r, _ = self._fit_saturated_regression(
df_reg_nz,
outcome,
unit,
@@ -712,6 +882,7 @@ def _refit_sa(w_r):
survey_weights=w_nz,
survey_weight_type=survey_weight_type,
resolved_survey=None,
+ vcov_type=self.vcov_type,
)
# Create temp weight column for IW aggregation with w_r
# Use full w_r (including zeros) for correct mass computation
@@ -809,8 +980,9 @@ def _refit_sa(w_r):
W_mat[i, j] = w
es_vcov = W_mat @ vcov_cohort @ W_mat.T
- # Compute overall ATT (average of post-treatment effects)
- overall_att, overall_se = self._compute_overall_att(
+ # Compute overall ATT (average of post-treatment effects).
+ # Capture overall_weights_by_coef for the hc2_bm contrast-DOF path.
+ overall_att, overall_se, _overall_weights_by_coef = self._compute_overall_att(
df,
first_treat,
event_study_effects,
@@ -819,10 +991,109 @@ def _refit_sa(w_r):
vcov_cohort,
coef_index_map,
survey_weight_col=survey_weight_col,
+ return_overall_weights=True,
)
+ # Bell-McCaffrey contrast-DOF for analytical hc2_bm aggregated
+ # inference. Cohort-level coefficients already use BM DOF via
+ # `LinearRegression.get_inference()` inside `_fit_saturated_regression`,
+ # but `event_study_effects` (IW-aggregated) and `overall_att` are
+ # linear contrasts of the cohort × event-time coefficients. Per
+ # the registry contract for `vcov_type="hc2_bm"`, the user-facing
+ # aggregated inference must use CR2 Bell-McCaffrey Satterthwaite
+ # DOF for each contrast — not the normal distribution that
+ # `safe_inference(..., df=None)` would otherwise default to.
+ # Mirrors the MultiPeriodDiD post-period-average contrast pattern
+ # added in PR #465 (`_compute_cr2_bm_contrast_dof`).
+ _es_contrast_dofs: Dict[int, float] = {}
+ _overall_att_contrast_dof: Optional[float] = None
+ if bm_artifacts is not None and not _uses_replicate_sa:
+ from diff_diff.linalg import _compute_cr2_bm_contrast_dof
+
+ X_full, cluster_ids_full, bread_matrix = bm_artifacts
+ n_full_coef = X_full.shape[1]
+ # `coef_index_map` is 0-indexed within the cohort-effects
+ # block; under full-dummy the interactions occupy columns
+ # `coef_offset .. coef_offset + n_interactions - 1` in
+ # X_full (where coef_offset == 1 for the intercept). Shift
+ # by the same offset when building the contrast vector in
+ # full-coef space — otherwise the contrast lands on the
+ # wrong columns (off-by-one with the intercept).
+ _coef_offset_bm = 1 # full-dummy → interactions at cols 1..n
+ # Per-event-time contrasts (IW aggregation across cohorts at
+ # each event-time): c_e[full_idx(g, e)] = w_{g,e} for each g.
+ es_contrast_keys: List[int] = []
+ es_contrast_columns: List[np.ndarray] = []
+ for e in sorted(event_study_effects.keys()):
+ w_dict = cohort_weights.get(e, {})
+ if not w_dict:
+ continue
+ col = np.zeros(n_full_coef)
+ for g, w_ge in w_dict.items():
+ key = (g, e)
+ if key in coef_index_map:
+ col[coef_index_map[key] + _coef_offset_bm] = w_ge
+ if np.any(col != 0):
+ es_contrast_keys.append(e)
+ es_contrast_columns.append(col)
+ # Overall ATT contrast: c_overall[full_idx(g,e)] = period_w × cohort_w
+ overall_col: Optional[np.ndarray] = None
+ if _overall_weights_by_coef:
+ overall_col = np.zeros(n_full_coef)
+ for (g, e), w in _overall_weights_by_coef.items():
+ if (g, e) in coef_index_map:
+ overall_col[coef_index_map[(g, e)] + _coef_offset_bm] = w
+ if es_contrast_columns or overall_col is not None:
+ contrast_cols: List[np.ndarray] = list(es_contrast_columns)
+ if overall_col is not None:
+ contrast_cols.append(overall_col)
+ contrasts_matrix = np.column_stack(contrast_cols)
+ try:
+ dof_vec = _compute_cr2_bm_contrast_dof(
+ X_full, cluster_ids_full, bread_matrix, contrasts_matrix
+ )
+ for idx, e in enumerate(es_contrast_keys):
+ _es_contrast_dofs[e] = float(dof_vec[idx])
+ if overall_col is not None:
+ _overall_att_contrast_dof = float(dof_vec[-1])
+ except (ValueError, np.linalg.LinAlgError) as exc:
+ # Rank-deficient or other linalg issue: fall back to
+ # the shared analytical df (downgraded to normal
+ # inference). Emit a UserWarning so the deviation is
+ # visible.
+ warnings.warn(
+ f"SunAbraham(vcov_type='hc2_bm') aggregated inference "
+ f"could not compute Bell-McCaffrey contrast DOF "
+ f"({type(exc).__name__}: {exc}). Falling back to "
+ "shared df; aggregated p-values/CIs may use normal "
+ "distribution instead of t(BM DOF).",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ # Apply contrast DOFs to the user-facing aggregated inference.
+ # Override the per-event-time inference fields with BM-DOF-aware
+ # values when available; otherwise leave the `safe_inference`
+ # output from `_compute_iw_effects` in place (which used
+ # `df=_sa_survey_df`).
+ if _es_contrast_dofs:
+ for e, df_e in _es_contrast_dofs.items():
+ eff_e = event_study_effects[e]["effect"]
+ se_e = event_study_effects[e]["se"]
+ t_e, p_e, ci_e = safe_inference(eff_e, se_e, alpha=self.alpha, df=df_e)
+ event_study_effects[e]["t_stat"] = t_e
+ event_study_effects[e]["p_value"] = p_e
+ event_study_effects[e]["conf_int"] = ci_e
+
overall_t, overall_p, overall_ci = safe_inference(
- overall_att, overall_se, alpha=self.alpha, df=_sa_survey_df
+ overall_att,
+ overall_se,
+ alpha=self.alpha,
+ df=(
+ _overall_att_contrast_dof
+ if _overall_att_contrast_dof is not None
+ else _sa_survey_df
+ ),
)
# Replicate variance override: refit fully re-aggregated estimates
@@ -871,7 +1142,7 @@ def _refit_sa_cohort(w_r):
nz = w_r > 0
df_reg_nz = df_reg[nz] if not np.all(nz) else df_reg
w_nz = w_r[nz] if not np.all(nz) else w_r
- ce_r, _, _, _ = self._fit_saturated_regression(
+ ce_r, _, _, _, _ = self._fit_saturated_regression(
df_reg_nz,
outcome,
unit,
@@ -884,6 +1155,7 @@ def _refit_sa_cohort(w_r):
survey_weights=w_nz,
survey_weight_type=survey_weight_type,
resolved_survey=None,
+ vcov_type=self.vcov_type,
)
return np.array([ce_r.get(k, np.nan) for k in _keys_ordered])
@@ -971,6 +1243,7 @@ def _refit_sa_cohort(w_r):
alpha=self.alpha,
control_group=self.control_group,
anticipation=self.anticipation,
+ vcov_type=self.vcov_type,
bootstrap_results=bootstrap_results,
cohort_effects=cohort_effects_storage,
survey_metadata=survey_metadata,
@@ -991,22 +1264,33 @@ def _fit_saturated_regression(
treatment_groups: List[Any],
rel_periods: List[int],
covariates: Optional[List[str]],
- cluster_var: str,
+ cluster_var: Optional[str],
survey_weights: Optional[np.ndarray] = None,
survey_weight_type: str = "pweight",
resolved_survey: object = None,
+ vcov_type: str = "hc1",
) -> Tuple[
Dict[Tuple[Any, int], float],
Dict[Tuple[Any, int], float],
np.ndarray,
Dict[Tuple[Any, int], int],
+ Optional[Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]],
]:
"""
Fit saturated TWFE regression with cohort × relative-time interactions.
Y_it = α_i + λ_t + Σ_g Σ_e [δ_{g,e} × D_{g,e,it}] + X'γ + ε
- Uses within-transformation for unit fixed effects and time dummies.
+ Uses within-transformation for unit + time fixed effects when
+ ``vcov_type == "hc1"`` (cluster-robust HC1 does not depend on
+ the hat matrix; matches ``fixest::sunab()`` convention). Routes
+ to a full-dummy saturated design when
+ ``vcov_type ∈ {"classical","hc2","hc2_bm"}``. For ``hc2`` /
+ ``hc2_bm``, FWL preserves coefficients/residuals but NOT the
+ hat matrix —
+ HC2 leverage and Bell-McCaffrey DOF must be computed on the full
+ FE projection. Mirrors the TwoWayFixedEffects Gate 1 pattern
+ from PR #469.
Returns
-------
@@ -1015,15 +1299,25 @@ def _fit_saturated_regression(
cohort_ses : dict
Mapping (cohort, rel_period) -> standard error
vcov : np.ndarray
- Variance-covariance matrix for cohort effects
+ Variance-covariance matrix for cohort effects (size
+ n_interactions × n_interactions; extracted from the full
+ vcov regardless of which path was taken).
coef_index_map : dict
- Mapping (cohort, rel_period) -> index in coefficient vector
+ Mapping (cohort, rel_period) -> index in the cohort_effects
+ block (0-based, NOT the index in the full coefficient vector
+ of the underlying regression).
"""
df = df.copy()
# Create cohort × relative-time interaction dummies
# Exclude reference period
- # Build all columns at once to avoid fragmentation
+ # Build all columns at once to avoid fragmentation.
+ # `coef_index_map` is 0-based within the interactions block; the
+ # index in the full coefficient vector depends on the branch:
+ # - Within-transform branch: matches coef_index_map directly
+ # (X has no intercept; interactions occupy positions 0..n-1)
+ # - Full-dummy branch: shift by 1 (intercept at position 0;
+ # interactions occupy positions 1..n)
interaction_data = {}
coef_index_map: Dict[Tuple[Any, int], int] = {}
idx = 0
@@ -1051,58 +1345,149 @@ def _fit_saturated_regression(
"No valid cohort × relative-time interactions found. " "Check your data structure."
)
- # Apply within-transformation for unit and time fixed effects
- variables_to_demean = [outcome] + interaction_cols
- if covariates:
- variables_to_demean.extend(covariates)
-
- df_demeaned = _within_transform_util(
- df, variables_to_demean, unit, time, suffix="_dm", weights=survey_weights
- )
+ n_interactions = len(interaction_cols)
+ n_units_fe = df[unit].nunique()
+ n_times_fe = df[time].nunique()
+ # Route through the full-dummy saturated design when the variance
+ # family depends on the hat matrix (hc2 / hc2_bm) — FWL preserves
+ # coefficients but not the hat matrix, so HC2 leverage and BM DOF
+ # must be computed on the full FE projection. Also route classical
+ # through full-dummy so the (n-k) finite-sample correction in
+ # ``s² × (X'X)^{-1}`` matches R's ``lm(y ~ ... + factor(unit) +
+ # factor(time))`` interpretation at atol=1e-12.
+ #
+ # hc1 stays on the within-transform path: cluster-robust HC1
+ # uses the cluster-mean residual outer product (no hat matrix), and
+ # matches ``fixest::sunab(cluster=~unit)`` (which also uses
+ # within-transform) at atol=1e-8 — fixest is the natural R parity
+ # anchor for SA's HC1 default.
+ use_full_dummy = vcov_type in ("hc2", "hc2_bm", "classical")
+
+ if use_full_dummy:
+ # Full-dummy auto-route: build [intercept, interactions,
+ # covariates, unit_dummies, time_dummies] explicitly. FWL
+ # preserves cohort coefficients but NOT the hat matrix, so HC2
+ # leverage and Bell-McCaffrey Satterthwaite DOF must be
+ # computed on the full FE projection (matches lm() +
+ # sandwich::vcovHC / clubSandwich::vcovCR). Memory guard
+ # mirrors PR #469's TWFE Gate 1 threshold.
+ n_obs = len(df)
+ n_cov = len(covariates or [])
+ dense_cells = n_obs * (1 + n_interactions + n_cov + (n_units_fe - 1) + (n_times_fe - 1))
+ if dense_cells > 50_000_000:
+ import warnings
- # Build design matrix
- X_cols = [f"{col}_dm" for col in interaction_cols]
- if covariates:
- X_cols.extend([f"{cov}_dm" for cov in covariates])
+ warnings.warn(
+ f"SunAbraham(vcov_type={vcov_type!r}) builds a dense "
+ f"full-dummy saturated design (~{dense_cells:,} float64 "
+ "cells, >50M). FWL preserves coefficients but not the hat "
+ "matrix, so HC2/HC2-BM requires the full-dummy projection "
+ "(within-transform would produce a methodologically "
+ "different statistic). For very high-cardinality panels, "
+ "consider vcov_type='hc1' (within-transform; no full-"
+ "dummy needed) or reducing the panel size.",
+ UserWarning,
+ stacklevel=2,
+ )
- X = df_demeaned[X_cols].values
- y = df_demeaned[f"{outcome}_dm"].values
+ interaction_arrs = [df[c].values.astype(np.float64) for c in interaction_cols]
+ cov_arrs = [df[c].values.astype(np.float64) for c in (covariates or [])]
+ unit_dummies = pd.get_dummies(
+ df[unit], prefix=f"_fe_{unit}", drop_first=True
+ ).values.astype(np.float64)
+ time_dummies = pd.get_dummies(
+ df[time], prefix=f"_fe_{time}", drop_first=True
+ ).values.astype(np.float64)
+ intercept = np.ones(len(df))
+ X = np.column_stack(
+ [intercept] + interaction_arrs + cov_arrs + [unit_dummies, time_dummies]
+ )
+ y = df[outcome].values.astype(np.float64)
+ if cluster_var is not None and cluster_var in df.columns:
+ cluster_ids = df[cluster_var].values
+ else:
+ cluster_ids = None
+ # Full-dummy already counts unit + time dummies in n_params, so
+ # no extra adjustment (matches TWFE PR #469 Gate 1).
+ df_adj = 0
+ # Interactions occupy columns 1..n_interactions (intercept at 0)
+ coef_offset = 1
+ else:
+ # Within-transform path (existing) — used for hc1 only.
+ # classical now routes through the full-dummy branch above so its
+ # (n-k) finite-sample correction matches R's lm() interpretation.
+ variables_to_demean = [outcome] + interaction_cols
+ if covariates:
+ variables_to_demean.extend(covariates)
+
+ df_demeaned = _within_transform_util(
+ df, variables_to_demean, unit, time, suffix="_dm", weights=survey_weights
+ )
- # Fit OLS using LinearRegression helper (more stable than manual X'X inverse)
- cluster_ids = df_demeaned[cluster_var].values
+ X_cols = [f"{col}_dm" for col in interaction_cols]
+ if covariates:
+ X_cols.extend([f"{cov}_dm" for cov in covariates])
- # Degrees of freedom adjustment for absorbed unit and time fixed effects
- n_units_fe = df[unit].nunique()
- n_times_fe = df[time].nunique()
- df_adj = n_units_fe + n_times_fe - 1
+ X = df_demeaned[X_cols].values
+ y = df_demeaned[f"{outcome}_dm"].values
+ if cluster_var is not None and cluster_var in df_demeaned.columns:
+ cluster_ids = df_demeaned[cluster_var].values
+ else:
+ cluster_ids = None
+ df_adj = n_units_fe + n_times_fe - 1
+ # Interactions occupy columns 0..n_interactions-1 (no intercept)
+ coef_offset = 0
reg = LinearRegression(
- include_intercept=False, # Already demeaned, no intercept needed
- robust=True,
+ include_intercept=False, # Full design already built (with or without intercept)
+ robust=True, # legacy alias; vcov_type below overrides
cluster_ids=cluster_ids,
rank_deficient_action=self.rank_deficient_action,
weights=survey_weights,
weight_type=survey_weight_type,
survey_design=resolved_survey,
+ vcov_type=vcov_type,
).fit(X, y, df_adjustment=df_adj)
vcov = reg.vcov_
- # Extract cohort effects and standard errors using get_inference
+ # Extract cohort effects and standard errors using get_inference.
+ # coef_index_map is 0-based within the interactions block; under
+ # full-dummy we shift by +1 to skip the intercept.
cohort_effects: Dict[Tuple[Any, int], float] = {}
cohort_ses: Dict[Tuple[Any, int], float] = {}
- n_interactions = len(interaction_cols)
for (g, e), coef_idx in coef_index_map.items():
- inference = reg.get_inference(coef_idx)
+ full_idx = coef_idx + coef_offset
+ inference = reg.get_inference(full_idx)
cohort_effects[(g, e)] = inference.coefficient
cohort_ses[(g, e)] = inference.se
- # Extract just the vcov for cohort effects (excluding covariates)
+ # Extract the vcov sub-block for cohort effects only (covariates
+ # and FE dummies excluded). Under full-dummy the interactions
+ # start at column 1; under within-transform they start at 0.
assert vcov is not None
- vcov_cohort = vcov[:n_interactions, :n_interactions]
+ vcov_cohort = vcov[
+ coef_offset : coef_offset + n_interactions,
+ coef_offset : coef_offset + n_interactions,
+ ]
- return cohort_effects, cohort_ses, vcov_cohort, coef_index_map
+ # Stash BM contrast-DOF artifacts when hc2_bm — needed by the
+ # aggregated inference layer to compute per-event-time and
+ # overall-ATT Satterthwaite DOF on user-facing outputs. Under
+ # other vcov_type values aggregated inference falls back to the
+ # shared analytical df (None → normal distribution).
+ if vcov_type == "hc2_bm":
+ bread_matrix = X.T @ X
+ bm_artifacts: Optional[Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]] = (
+ X,
+ cluster_ids,
+ bread_matrix,
+ )
+ else:
+ bm_artifacts = None
+
+ return cohort_effects, cohort_ses, vcov_cohort, coef_index_map, bm_artifacts
def _within_transform(
self,
@@ -1225,6 +1610,7 @@ def _compute_overall_att(
vcov_cohort: np.ndarray,
coef_index_map: Dict[Tuple[Any, int], int],
survey_weight_col: Optional[str] = None,
+ return_overall_weights: bool = False,
) -> Tuple[float, float]:
"""
Compute overall ATT as weighted average of post-treatment effects.
@@ -1232,11 +1618,19 @@ def _compute_overall_att(
When survey weights are provided, the per-period weights use
survey-weighted mass rather than raw observation counts.
- Returns (att, se) tuple.
+ Returns (att, se) tuple. When ``return_overall_weights=True``,
+ the returned tuple is extended to (att, se, overall_weights_by_coef)
+ where the dict maps (g, e) → weight in the overall ATT
+ contrast (i.e. ``c[full_idx(g,e)] = period_weight × cohort_weight``).
+ Used by the analytical hc2_bm path to build Bell-McCaffrey
+ contrast DOFs for the user-facing aggregated inference. The dict
+ is ``None`` when the simplified-variance fallback path was taken.
"""
post_effects = [(e, eff) for e, eff in event_study_effects.items() if e >= 0]
if not post_effects:
+ if return_overall_weights:
+ return np.nan, np.nan, None
return np.nan, np.nan
# Weight by (survey-weighted) mass of treated observations at each relative time
@@ -1288,6 +1682,8 @@ def _compute_overall_att(
(post_weights_arr**2) * np.array([eff["se"] ** 2 for _, eff in post_effects])
)
)
+ if return_overall_weights:
+ return overall_att, np.sqrt(overall_var), None
return overall_att, np.sqrt(overall_var)
# Build full weight vector and compute variance
@@ -1297,6 +1693,8 @@ def _compute_overall_att(
overall_var = float(weight_vec @ vcov_subset @ weight_vec)
overall_se = np.sqrt(max(overall_var, 0))
+ if return_overall_weights:
+ return overall_att, overall_se, overall_weights_by_coef
return overall_att, overall_se
def _run_bootstrap(
@@ -1309,7 +1707,7 @@ def _run_bootstrap(
treatment_groups: List[Any],
rel_periods_to_estimate: List[int],
covariates: Optional[List[str]],
- cluster_var: str,
+ cluster_var: Optional[str],
original_event_study: Dict[int, Dict[str, Any]],
original_overall_att: float,
resolved_survey: object = None,
@@ -1400,6 +1798,7 @@ def _run_bootstrap(
cohort_ses_b,
vcov_b,
coef_map_b,
+ _,
) = self._fit_saturated_regression(
df_b,
outcome,
@@ -1413,6 +1812,7 @@ def _run_bootstrap(
survey_weights=boot_survey_weights,
survey_weight_type=survey_weight_type,
resolved_survey=None, # Use explicit weights, not stale design
+ vcov_type=self.vcov_type,
)
# Compute IW effects for this bootstrap sample
@@ -1509,7 +1909,7 @@ def _run_rao_wu_bootstrap(
treatment_groups: List[Any],
rel_periods_to_estimate: List[int],
covariates: Optional[List[str]],
- cluster_var: str,
+ cluster_var: Optional[str],
original_event_study: Dict[int, Dict[str, Any]],
original_overall_att: float,
resolved_survey: object,
@@ -1631,6 +2031,7 @@ def _run_rao_wu_bootstrap(
cohort_ses_b,
vcov_b,
coef_map_b,
+ _,
) = self._fit_saturated_regression(
df_b,
outcome,
@@ -1644,6 +2045,7 @@ def _run_rao_wu_bootstrap(
survey_weights=boot_weights_b,
survey_weight_type=survey_weight_type,
resolved_survey=None,
+ vcov_type=self.vcov_type,
)
# Compute IW effects using rescaled weights for cohort shares
@@ -1742,6 +2144,7 @@ def get_params(self) -> Dict[str, Any]:
"n_bootstrap": self.n_bootstrap,
"seed": self.seed,
"rank_deficient_action": self.rank_deficient_action,
+ "vcov_type": self.vcov_type,
}
def set_params(self, **params) -> "SunAbraham":
@@ -1751,6 +2154,10 @@ def set_params(self, **params) -> "SunAbraham":
setattr(self, key, value)
else:
raise ValueError(f"Unknown parameter: {key}")
+ # Refresh the explicit-vcov-type flag if vcov_type changed, so the
+ # auto-cluster guard at fit time uses the updated value.
+ if "vcov_type" in params:
+ self._vcov_type_explicit = self.vcov_type != "hc1"
return self
def summary(self) -> str:
diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md
index ebc9475a..ce624616 100644
--- a/docs/methodology/REGISTRY.md
+++ b/docs/methodology/REGISTRY.md
@@ -975,7 +975,77 @@ Interaction-weighted estimator:
where weights ŵ_{g,e} = n_{g,e} / Σ_g n_{g,e} (sample share of cohort g at event-time e).
*Standard errors:*
-- Default: Cluster-robust at unit level
+- Default: Cluster-robust HC1 at unit level (`vcov_type="hc1"`)
+- `vcov_type ∈ {"classical","hc1","hc2","hc2_bm"}` supported as of Phase 1b
+ PR 1/8 (mirrors the DiD/MPD/TWFE chain established in Phase 1a):
+ - `"hc1"` (default): Eicker-Huber-White HC1 with cluster-at-unit default.
+ Auto-clusters at unit unless an explicit `cluster=` is passed.
+ - `"classical"`: homoskedastic OLS standard errors. Auto-cluster is
+ dropped (one-way only). Routes through the full-dummy saturated
+ design (see Implementation note below) for R-parity.
+ - `"hc2"`: HC2 leverage correction. Auto-cluster is dropped (one-way
+ only); the linalg validator rejects `hc2 + cluster_ids`. Routes
+ through full-dummy.
+ - `"hc2_bm"`: HC2 + Bell-McCaffrey CR2 Satterthwaite DOF for
+ cluster-robust inference. Auto-cluster fires at unit (or explicit
+ `cluster=`); routes through full-dummy. R-parity matches
+ `clubSandwich::vcovCR(..., type="CR2")` + `coef_test()$df_Satt` at
+ atol=1e-10.
+ - `"conley"`: rejected at `__init__` (deferred; would require
+ threading `conley_coords` / `conley_cutoff_km` / ... through
+ `_fit_saturated_regression`).
+- **Note (Phase 1b auto-route):** When `vcov_type ∈ {"classical","hc2",
+ "hc2_bm"}`, `_fit_saturated_regression` bypasses the within-transform
+ path and builds the full-dummy saturated design `[intercept +
+ cohort × event-time interactions + covariates + unit_dummies +
+ time_dummies]` directly. The FWL theorem preserves cohort coefficients
+ and residuals but does NOT preserve the hat matrix, so HC2 leverage
+ and Bell-McCaffrey Satterthwaite DOF must be computed on the full FE
+ projection (matches `lm() + sandwich::vcovHC` / `clubSandwich::vcovCR`).
+ Classical SE also routes through full-dummy so the `(n-k)` finite-
+ sample correction matches R's `lm()` interpretation at atol=1e-10.
+ `hc1` stays on the within-transform path (cluster-robust HC1 doesn't
+ depend on the hat matrix); matches `fixest::sunab()` event-study
+ aggregates closely (see deviation note below).
+- **Note (Phase 1b aggregated BM contrast DOF):** Under
+ `vcov_type="hc2_bm"`, the user-facing aggregated inference
+ (`event_study_effects[e]['p_value']`/`['conf_int']`,
+ `overall_p_value`/`overall_conf_int`) uses CR2 Bell-McCaffrey
+ Satterthwaite DOF per contrast — not the normal distribution.
+ Per-event-time contrast `c_e[full_idx(g,e)] = w_{g,e}` (IW weight)
+ and overall ATT contrast `c_overall[full_idx(g,e)] = w_e × w_{g,e}`
+ are passed to `_compute_cr2_bm_contrast_dof` (the helper PR #465
+ added for MultiPeriodDiD's post-period-average DOF). The resulting
+ per-contrast DOF threads into `safe_inference(..., df=)`.
+ Matches `clubSandwich::Wald_test(constraints=matrix(c, 1),
+ test="HTZ")$df_denom` at atol=1e-10 (pinned in
+ `tests/test_methodology_sun_abraham.py`). Cohort-level coefficients
+ separately get per-coefficient BM DOF via
+ `LinearRegression.get_inference()` inside `_fit_saturated_regression`.
+ If the linalg helper fails (rank-deficient design, singular bread),
+ the aggregated inference falls back to the shared analytical df with
+ an explicit `UserWarning`.
+- **Deviation from R (HC1 finite-sample correction):** SA's
+ within-transform HC1 SE differs from `fixest::sunab(cluster=~unit)`
+ by ~1-2% on typical panel sizes. fixest's correction counts the
+ absorbed unit + time FE in the effective parameter count
+ (`n / (n - k_total)`) whereas SA's `solve_ols` counts only the
+ within-transformed design columns (`n / (n - k_dm)`). The IW
+ aggregation step is otherwise identical. Tracked as a follow-up
+ (harmonizing the correction or documenting it as an intentional
+ difference).
+- Survey designs (`survey_design=`) + `vcov_type ∈ {"classical","hc2",
+ "hc2_bm"}` are rejected at fit-time: the survey-design Taylor Series
+ Linearization (or replicate-weight refit) variance overrides the
+ analytical sandwich family, so the requested HC2/HC2-BM/classical
+ family would be silently discarded. Additionally, the auto-cluster
+ guard for one-way families (classical/hc2) would drop the unit
+ auto-cluster before survey-PSU injection, downgrading the panel
+ structure from unit-level to per-observation PSUs. Mirrors the
+ TWFE Gate 1 + replicate-weight gate from PR #469 and the
+ `linalg.py::_validate_vcov_args` `hc2_bm + weights` gate. Use
+ `vcov_type="hc1"` (default) for survey designs; the survey TSL
+ machinery computes the design-aware SE on the within-transform path.
- Delta method for aggregated coefficients
- Optional: Pairs bootstrap for robustness
@@ -999,9 +1069,21 @@ where weights ŵ_{g,e} = n_{g,e} / Σ_g n_{g,e} (sample share of cohort g at eve
- **Note**: Defensive enhancement matching CallawaySantAnna behavior; R's `fixest::sunab()` may produce Inf/NaN without warning
- Inference distribution:
- Cohort-level p-values: t-distribution (via `LinearRegression.get_inference()`)
- - Aggregated event study and overall ATT p-values: normal distribution (via `compute_p_value()`)
- - This is asymptotically equivalent and standard for delta-method-aggregated quantities
- - **Deviation from R**: R's fixest uses t-distribution at all levels; aggregated p-values may differ slightly for small samples
+ - Aggregated event study and overall ATT p-values:
+ - Under `vcov_type="hc2_bm"`: t-distribution with CR2 Bell-McCaffrey
+ contrast DOF per aggregated effect (see "Phase 1b aggregated BM
+ contrast DOF" Note above). Matches `clubSandwich::Wald_test(
+ test="HTZ")$df_denom`.
+ - Under `vcov_type ∈ {"classical","hc1","hc2"}` (no replicate-weight
+ survey): normal distribution (via `compute_p_value()`), which is
+ asymptotically equivalent and standard for delta-method-aggregated
+ quantities.
+ - Under replicate-weight survey: t-distribution with replicate-derived
+ DOF (`survey_metadata.df_survey`).
+ - **Deviation from R**: R's fixest uses t-distribution at all levels
+ under `vcov_type ∈ {"classical","hc1","hc2"}`; aggregated p-values
+ may differ slightly for small samples on those families. The
+ `hc2_bm` aggregated path matches clubSandwich exactly.
**Reference implementation(s):**
- R: `fixest::sunab()` (Laurent Bergé's implementation)
diff --git a/tests/test_methodology_sun_abraham.py b/tests/test_methodology_sun_abraham.py
new file mode 100644
index 00000000..f3a687ba
--- /dev/null
+++ b/tests/test_methodology_sun_abraham.py
@@ -0,0 +1,264 @@
+"""R-parity tests for SunAbraham vcov_type threading (Phase 1b PR 1/8).
+
+Pins SA(vcov_type=...) cohort SEs and event-study SEs against R goldens
+generated by benchmarks/R/generate_clubsandwich_golden.R. The golden file
+contains the panel data so Python and R fit the SAME numerical panel.
+
+Parity targets:
+- classical → R lm() summary at atol=1e-10 (full-dummy)
+- hc2 → R sandwich::vcovHC(type="HC2") at atol=1e-10 (full-dummy)
+- hc2_bm → R clubSandwich::vcovCR(cluster=..., type="CR2") + coef_test$df_Satt at atol=1e-10 (full-dummy)
+- hc1 event-study e=0 → R fixest::sunab + cluster=~unit at atol=5e-3
+ (within-transform; documented HC1 finite-sample-correction deviation,
+ see REGISTRY.md SunAbraham section and TODO.md row tracking the gap)
+"""
+
+import json
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from diff_diff import SunAbraham
+
+GOLDEN_PATH = (
+ Path(__file__).resolve().parents[1] / "benchmarks" / "data" / "clubsandwich_cr2_golden.json"
+)
+
+
+@pytest.fixture(scope="module")
+def sa_golden():
+ """Load the sun_abraham_two_cohort scenario from the clubSandwich golden."""
+ if not GOLDEN_PATH.exists():
+ pytest.skip(f"Golden file missing: {GOLDEN_PATH}")
+ with GOLDEN_PATH.open() as f:
+ all_goldens = json.load(f)
+ if "sun_abraham_two_cohort" not in all_goldens:
+ pytest.skip("sun_abraham_two_cohort scenario missing from golden")
+ return all_goldens["sun_abraham_two_cohort"]
+
+
+@pytest.fixture(scope="module")
+def sa_panel(sa_golden):
+ """Reconstruct the panel DataFrame from the golden arrays."""
+ return pd.DataFrame(
+ {
+ "unit": sa_golden["unit"],
+ "time": sa_golden["time"],
+ "first_treat": sa_golden["first_treat"],
+ "y": sa_golden["y"],
+ }
+ )
+
+
+def _target_cohort_se(sa_results, g, e):
+ """Extract SE for the (cohort=g, event_time=e) interaction from SA results."""
+ return float(sa_results.cohort_effects[(g, e)]["se"])
+
+
+class TestSunAbrahamRParity:
+ """SA vcov_type R-parity tests against clubSandwich/sandwich/fixest goldens."""
+
+ def test_sa_classical_cohort_se_matches_lm_summary(self, sa_panel, sa_golden):
+ """SA(vcov_type='classical') cohort SE matches R lm summary() classical SE."""
+ sa = SunAbraham(vcov_type="classical")
+ res = sa.fit(sa_panel, outcome="y", unit="unit", time="time", first_treat="first_treat")
+ sa_se = _target_cohort_se(
+ res, sa_golden["target_cohort_g"], sa_golden["target_event_time_e"]
+ )
+ r_se = sa_golden["classical_se"]
+ assert np.isclose(sa_se, r_se, atol=1e-10), (
+ f"SA classical SE {sa_se} does not match R lm() classical SE {r_se} "
+ f"at atol=1e-10 (diff={abs(sa_se - r_se):.2e})"
+ )
+
+ def test_sa_hc2_cohort_se_matches_lm_vcovHC(self, sa_panel, sa_golden):
+ """SA(vcov_type='hc2') cohort SE matches R sandwich::vcovHC(type='HC2')."""
+ sa = SunAbraham(vcov_type="hc2")
+ res = sa.fit(sa_panel, outcome="y", unit="unit", time="time", first_treat="first_treat")
+ sa_se = _target_cohort_se(
+ res, sa_golden["target_cohort_g"], sa_golden["target_event_time_e"]
+ )
+ r_se = sa_golden["hc2_se"]
+ assert np.isclose(sa_se, r_se, atol=1e-10), (
+ f"SA hc2 SE {sa_se} does not match R vcovHC(type='HC2') SE {r_se} "
+ f"at atol=1e-10 (diff={abs(sa_se - r_se):.2e})"
+ )
+ # Sanity: hc2 == cr2_bm_singleton (one-way HC2-BM via singleton-cluster trick)
+ assert np.isclose(sa_se, sa_golden["cr2_bm_singleton_se"], atol=1e-10), (
+ "SA hc2 SE should match clubSandwich CR2 with singleton clusters "
+ "(both reduce to one-way HC2-BM)"
+ )
+
+ def test_sa_hc2_bm_cohort_se_matches_clubsandwich_cr2(self, sa_panel, sa_golden):
+ """SA(vcov_type='hc2_bm') cohort SE matches R clubSandwich::vcovCR(cluster=unit, type='CR2').
+
+ Auto-cluster fires (no explicit cluster=); routes to CR2-BM at unit.
+ """
+ sa = SunAbraham(vcov_type="hc2_bm")
+ res = sa.fit(sa_panel, outcome="y", unit="unit", time="time", first_treat="first_treat")
+ sa_se = _target_cohort_se(
+ res, sa_golden["target_cohort_g"], sa_golden["target_event_time_e"]
+ )
+ r_se = sa_golden["cr2_bm_unit_se"]
+ assert np.isclose(sa_se, r_se, atol=1e-10), (
+ f"SA hc2_bm SE {sa_se} does not match R clubSandwich CR2 (cluster=unit) "
+ f"SE {r_se} at atol=1e-10 (diff={abs(sa_se - r_se):.2e})"
+ )
+
+ def test_bm_dof_matches_clubsandwich_singleton_and_unit(self, sa_panel, sa_golden):
+ """Bell-McCaffrey Satterthwaite DOF parity vs R clubSandwich.
+
+ SA's cohort_effects dict doesn't surface per-coefficient BM DOF, so
+ we exercise the underlying LinearRegression machinery directly on
+ the same full-dummy design SA's Part G builds. Pins both the
+ singleton-cluster (one-way HC2-BM) DOF and the cluster=unit DOF at
+ atol=1e-10. A regression in BM DOF logic would change inference
+ output (p-values, CIs) without changing SEs; this test guards that.
+ """
+ import pandas as pd
+ from diff_diff.linalg import LinearRegression
+
+ # Reconstruct the full-dummy design SA's Part G builds (same shape
+ # as benchmarks/R/generate_clubsandwich_golden.R `sun_abraham_two_cohort`).
+ df = sa_panel.copy()
+ df["rel_time"] = np.where(df["first_treat"] > 0, df["time"] - df["first_treat"], -999)
+ treatment_groups = sorted(g for g in df["first_treat"].unique() if g > 0)
+ all_rel = sorted(e for e in df.loc[df["first_treat"] > 0, "rel_time"].unique() if e != -1)
+ interaction_cols = []
+ for g in treatment_groups:
+ for e in all_rel:
+ name = f"D_{g}_{e}"
+ ind = ((df["first_treat"] == g) & (df["rel_time"] == e)).astype(float)
+ if ind.sum() > 0:
+ df[name] = ind
+ interaction_cols.append(name)
+
+ target_col = f"D_{sa_golden['target_cohort_g']}_{sa_golden['target_event_time_e']}"
+ target_within_idx = interaction_cols.index(target_col)
+ # Full-dummy: intercept at 0, interactions at 1..n_interactions
+ target_full_idx = target_within_idx + 1
+
+ interaction_arrs = [df[c].values.astype(np.float64) for c in interaction_cols]
+ unit_dummies = pd.get_dummies(df["unit"], prefix="_fe_unit", drop_first=True).values.astype(
+ np.float64
+ )
+ time_dummies = pd.get_dummies(df["time"], prefix="_fe_time", drop_first=True).values.astype(
+ np.float64
+ )
+ intercept = np.ones(len(df))
+ X = np.column_stack([intercept] + interaction_arrs + [unit_dummies, time_dummies])
+ y = df["y"].values.astype(np.float64)
+
+ # Singleton-cluster CR2 (one-way HC2-BM): cluster_ids = seq_len(n)
+ reg_singleton = LinearRegression(
+ include_intercept=False,
+ cluster_ids=np.arange(len(df)),
+ vcov_type="hc2_bm",
+ ).fit(X, y)
+ bm_dof_singleton = float(reg_singleton._bm_dof[target_full_idx])
+ r_dof_singleton = sa_golden["dof_bm_singleton"]
+ assert np.isclose(bm_dof_singleton, r_dof_singleton, atol=1e-10), (
+ f"Singleton-cluster BM DOF {bm_dof_singleton} does not match R "
+ f"{r_dof_singleton} at atol=1e-10"
+ )
+
+ # CR2-BM at unit (the SA auto-cluster default for hc2_bm)
+ reg_unit = LinearRegression(
+ include_intercept=False,
+ cluster_ids=df["unit"].values,
+ vcov_type="hc2_bm",
+ ).fit(X, y)
+ bm_dof_unit = float(reg_unit._bm_dof[target_full_idx])
+ r_dof_unit = sa_golden["dof_bm_unit"]
+ assert np.isclose(bm_dof_unit, r_dof_unit, atol=1e-10), (
+ f"CR2-BM (cluster=unit) DOF {bm_dof_unit} does not match R "
+ f"{r_dof_unit} at atol=1e-10"
+ )
+
+ def test_sa_hc2_bm_aggregated_inference_uses_bm_contrast_dof(self, sa_panel, sa_golden):
+ """SA(vcov_type='hc2_bm') user-facing aggregated inference uses
+ Bell-McCaffrey contrast DOF, NOT normal-distribution DOF.
+
+ Both the IW-aggregated event-study effects and the overall ATT
+ are linear contrasts of cohort × event-time coefficients. The
+ registry contract for ``vcov_type="hc2_bm"`` requires CR2 BM
+ Satterthwaite DOF on these aggregated outputs (matching
+ ``clubSandwich::Wald_test(constraints=matrix(c, 1), test="HTZ")``).
+ This test guards against a regression to ``df=None`` →
+ normal-distribution inference, which would silently change
+ ``p_value``/``conf_int`` even though SE matches R at atol=1e-10.
+
+ Reverse-engineers the DOF from ``p_value`` and ``|t|`` to pin
+ the threading end-to-end.
+ """
+ from scipy import stats
+
+ sa = SunAbraham(vcov_type="hc2_bm")
+ res = sa.fit(sa_panel, outcome="y", unit="unit", time="time", first_treat="first_treat")
+
+ def _dof_from_p(t_stat, p_value, lo=2.0, hi=1e6):
+ target = 1 - p_value / 2
+ for _ in range(80):
+ mid = (lo + hi) / 2
+ if stats.t.cdf(abs(t_stat), mid) < target:
+ lo = mid
+ else:
+ hi = mid
+ return 0.5 * (lo + hi)
+
+ # Event-study e=0 contrast DOF: 4 cohorts contribute (g=4,5,6,7),
+ # IW weights 0.25 each. R clubSandwich Wald_test(test='HTZ') = 35.0.
+ eff_0 = res.event_study_effects[0]
+ t_es0 = eff_0["effect"] / eff_0["se"]
+ df_es0 = _dof_from_p(t_es0, eff_0["p_value"])
+ r_df_es0 = sa_golden["dof_bm_contrast_es0_unit"]
+ assert np.isclose(df_es0, r_df_es0, atol=1e-6), (
+ f"SA event-study e=0 inferred BM DOF {df_es0:.6f} does not match "
+ f"R clubSandwich Wald_test HTZ DOF {r_df_es0:.6f}. If this regresses, "
+ "the user-facing event_study_effects[0]['p_value']/['conf_int'] is "
+ "using normal-distribution DOF instead of CR2-BM contrast DOF."
+ )
+
+ # Overall ATT contrast DOF.
+ t_all = res.overall_att / res.overall_se
+ df_all = _dof_from_p(t_all, res.overall_p_value)
+ r_df_all = sa_golden["dof_bm_contrast_overall_unit"]
+ assert np.isclose(df_all, r_df_all, atol=1e-6), (
+ f"SA overall ATT inferred BM DOF {df_all:.6f} does not match "
+ f"R clubSandwich Wald_test HTZ DOF {r_df_all:.6f}. The overall_att "
+ "inference is using normal-distribution DOF instead of CR2-BM "
+ "contrast DOF."
+ )
+
+ def test_sa_hc1_event_study_e0_se_close_to_fixest_sunab(self, sa_panel, sa_golden):
+ """SA(vcov_type='hc1') event-study e=0 SE matches fixest::sunab at atol=5e-3.
+
+ SA hc1 uses within-transform path; fixest::sunab also uses within-transform.
+ **Deviation from R:** fixest applies a finite-sample correction that
+ counts absorbed unit + time FE in the effective parameter count
+ (n / (n - k_total)) whereas SA's solve_ols counts only the
+ within-transformed design columns (n / (n - k_dm)) — the SE values
+ differ by ~1-2% (typically ~2e-3 in absolute terms on this panel).
+ The IW aggregation step is otherwise identical. Tolerance set to
+ atol=5e-3 (matches REGISTRY.md deviation note and the TODO.md
+ follow-up row that tracks harmonizing the finite-sample correction).
+
+ Tests the IW-aggregated event-study coefficient (matches fixest's
+ natural sunab() output), not the per-cohort coefficient.
+ """
+ sa = SunAbraham(vcov_type="hc1")
+ res = sa.fit(sa_panel, outcome="y", unit="unit", time="time", first_treat="first_treat")
+ sa_se = float(res.event_study_effects[0]["se"])
+ r_se = sa_golden["sunab_hc1_event_study_e0_se"]
+ if r_se is None or (isinstance(r_se, float) and np.isnan(r_se)):
+ pytest.skip("fixest::sunab parity SE missing from golden")
+ # Absolute tolerance pin matching the documented deviation; harmonized
+ # with REGISTRY.md SunAbraham deviation note and TODO.md row.
+ assert abs(sa_se - r_se) < 5e-3, (
+ f"SA hc1 event-study e=0 SE {sa_se:.6f} diverges from fixest::sunab "
+ f"event-study SE {r_se:.6f} by {abs(sa_se - r_se):.2e} > 5e-3. "
+ "Expected ~2e-3 from finite-sample correction (see REGISTRY.md "
+ "deviation note); investigate if larger."
+ )
diff --git a/tests/test_sun_abraham.py b/tests/test_sun_abraham.py
index 4dd79381..3c4bcf1a 100644
--- a/tests/test_sun_abraham.py
+++ b/tests/test_sun_abraham.py
@@ -362,11 +362,7 @@ def test_bootstrap_basic(self, ci_params):
assert results.bootstrap_results.n_bootstrap == n_boot
assert results.bootstrap_results.weight_type == "pairs"
assert results.overall_se > 0
- assert (
- results.overall_conf_int[0]
- < results.overall_att
- < results.overall_conf_int[1]
- )
+ assert results.overall_conf_int[0] < results.overall_att < results.overall_conf_int[1]
def test_bootstrap_reproducibility(self, ci_params):
"""Test that bootstrap is reproducible with same seed."""
@@ -503,9 +499,7 @@ def test_both_recover_treatment_effect(self):
"""Test that both estimators recover the treatment effect."""
from diff_diff import CallawaySantAnna
- data = generate_staggered_data(
- n_units=200, treatment_effect=3.0, seed=42
- )
+ data = generate_staggered_data(n_units=200, treatment_effect=3.0, seed=42)
# Sun-Abraham
sa = SunAbraham()
@@ -543,9 +537,7 @@ def test_pre_period_difference_expected_between_cs_sa(self):
"""
from diff_diff import CallawaySantAnna
- data = generate_staggered_data(
- n_units=200, treatment_effect=3.0, seed=42
- )
+ data = generate_staggered_data(n_units=200, treatment_effect=3.0, seed=42)
# Sun-Abraham (uses fixed reference period e=-1)
sa = SunAbraham()
@@ -747,9 +739,7 @@ def test_single_cohort(self):
def test_many_cohorts(self):
"""Test with many treatment cohorts."""
- data = generate_staggered_data(
- n_units=200, n_periods=15, n_cohorts=8, seed=42
- )
+ data = generate_staggered_data(n_units=200, n_periods=15, n_cohorts=8, seed=42)
sa = SunAbraham()
results = sa.fit(
@@ -858,7 +848,7 @@ def test_rank_deficient_action_error_raises(self):
unit="unit",
time="time",
first_treat="first_treat",
- covariates=["cov1", "cov1_dup"]
+ covariates=["cov1", "cov1_dup"],
)
def test_rank_deficient_action_silent_no_warning(self):
@@ -882,12 +872,15 @@ def test_rank_deficient_action_silent_no_warning(self):
unit="unit",
time="time",
first_treat="first_treat",
- covariates=["cov1", "cov1_dup"]
+ covariates=["cov1", "cov1_dup"],
)
# No warnings about rank deficiency should be emitted
- rank_warnings = [x for x in w if "Rank-deficient" in str(x.message)
- or "rank-deficient" in str(x.message).lower()]
+ rank_warnings = [
+ x
+ for x in w
+ if "Rank-deficient" in str(x.message) or "rank-deficient" in str(x.message).lower()
+ ]
assert len(rank_warnings) == 0, f"Expected no rank warnings, got {rank_warnings}"
# Should still get valid results
@@ -923,14 +916,12 @@ def test_per_effect_tstat_consistency(self):
if not np.isfinite(se) or se == 0:
assert np.isnan(t_stat), (
- f"t_stat for e={e} should be NaN when SE={se}, "
- f"got t_stat={t_stat}"
+ f"t_stat for e={e} should be NaN when SE={se}, " f"got t_stat={t_stat}"
)
else:
expected = effect_data["effect"] / se
assert np.isclose(t_stat, expected), (
- f"t_stat for e={e} should be effect/SE, "
- f"expected {expected}, got {t_stat}"
+ f"t_stat for e={e} should be effect/SE, " f"expected {expected}, got {t_stat}"
)
def test_overall_tstat_nan_when_se_invalid(self):
@@ -956,22 +947,20 @@ def test_overall_tstat_nan_when_se_invalid(self):
t_stat = results.overall_t_stat
if not np.isfinite(se) or se == 0:
- assert np.isnan(t_stat), (
- f"overall_t_stat should be NaN when SE={se}, got {t_stat}"
- )
+ assert np.isnan(t_stat), f"overall_t_stat should be NaN when SE={se}, got {t_stat}"
assert np.isnan(results.overall_p_value), (
f"overall_p_value should be NaN when SE={se} (analytical inference), "
f"got {results.overall_p_value}"
)
ci = results.overall_conf_int
- assert np.isnan(ci[0]) and np.isnan(ci[1]), (
- f"overall_conf_int should be (NaN, NaN) when SE={se}, got {ci}"
- )
+ assert np.isnan(ci[0]) and np.isnan(
+ ci[1]
+ ), f"overall_conf_int should be (NaN, NaN) when SE={se}, got {ci}"
else:
expected = results.overall_att / se
- assert np.isclose(t_stat, expected), (
- f"overall_t_stat should be ATT/SE, expected {expected}, got {t_stat}"
- )
+ assert np.isclose(
+ t_stat, expected
+ ), f"overall_t_stat should be ATT/SE, expected {expected}, got {t_stat}"
def test_bootstrap_tstat_nan_when_se_invalid(self, ci_params):
"""Bootstrap t_stat uses NaN (not 0.0) when SE is non-finite or zero."""
@@ -998,9 +987,9 @@ def test_bootstrap_tstat_nan_when_se_invalid(self, ci_params):
t_stat = results.overall_t_stat
if not np.isfinite(se) or se == 0:
- assert np.isnan(t_stat), (
- f"bootstrap overall_t_stat should be NaN when SE={se}, got {t_stat}"
- )
+ assert np.isnan(
+ t_stat
+ ), f"bootstrap overall_t_stat should be NaN when SE={se}, got {t_stat}"
# Check event study effects
for e, effect_data in results.event_study_effects.items():
@@ -1030,12 +1019,14 @@ def test_aggregated_event_study_tstat_nan(self):
first_treat = 3 if unit < n_units // 2 else 0
for t in range(1, n_periods + 1):
outcome = np.random.randn()
- data.append({
- "unit": unit,
- "time": t,
- "outcome": outcome,
- "first_treat": first_treat,
- })
+ data.append(
+ {
+ "unit": unit,
+ "time": t,
+ "outcome": outcome,
+ "first_treat": first_treat,
+ }
+ )
df = pd.DataFrame(data)
@@ -1059,9 +1050,9 @@ def test_aggregated_event_study_tstat_nan(self):
f"got t_stat={t_stat}"
)
ci = effect_data["conf_int"]
- assert np.isnan(ci[0]) and np.isnan(ci[1]), (
- f"Aggregated CI for e={e} should be (NaN, NaN) when SE={se}, got {ci}"
- )
+ assert np.isnan(ci[0]) and np.isnan(
+ ci[1]
+ ), f"Aggregated CI for e={e} should be (NaN, NaN) when SE={se}, got {ci}"
else:
expected_t = effect / se
assert np.isclose(t_stat, expected_t, rtol=1e-10), (
@@ -1096,12 +1087,14 @@ def test_no_post_effects_returns_nan(self):
time_fe = np.tile(np.arange(n_periods) * 0.1, n_units)
outcomes = unit_fe + time_fe + np.random.randn(len(units)) * 0.3
- data = pd.DataFrame({
- "unit": units,
- "time": times,
- "outcome": outcomes,
- "first_treat": first_treat_expanded.astype(int),
- })
+ data = pd.DataFrame(
+ {
+ "unit": units,
+ "time": times,
+ "outcome": outcomes,
+ "first_treat": first_treat_expanded.astype(int),
+ }
+ )
sa = SunAbraham(n_bootstrap=0)
results = sa.fit(
@@ -1109,22 +1102,18 @@ def test_no_post_effects_returns_nan(self):
)
# Overall ATT and SE should be NaN
- assert np.isnan(results.overall_att), (
- f"Expected NaN overall_att, got {results.overall_att}"
- )
- assert np.isnan(results.overall_se), (
- f"Expected NaN overall_se, got {results.overall_se}"
- )
+ assert np.isnan(results.overall_att), f"Expected NaN overall_att, got {results.overall_att}"
+ assert np.isnan(results.overall_se), f"Expected NaN overall_se, got {results.overall_se}"
# Downstream inference should propagate NaN
- assert np.isnan(results.overall_t_stat), (
- f"Expected NaN overall_t_stat, got {results.overall_t_stat}"
- )
- assert np.isnan(results.overall_p_value), (
- f"Expected NaN overall_p_value, got {results.overall_p_value}"
- )
- assert np.isnan(results.overall_conf_int[0]) and np.isnan(results.overall_conf_int[1]), (
- f"Expected (NaN, NaN) overall_conf_int, got {results.overall_conf_int}"
- )
+ assert np.isnan(
+ results.overall_t_stat
+ ), f"Expected NaN overall_t_stat, got {results.overall_t_stat}"
+ assert np.isnan(
+ results.overall_p_value
+ ), f"Expected NaN overall_p_value, got {results.overall_p_value}"
+ assert np.isnan(results.overall_conf_int[0]) and np.isnan(
+ results.overall_conf_int[1]
+ ), f"Expected (NaN, NaN) overall_conf_int, got {results.overall_conf_int}"
def test_no_post_effects_bootstrap_returns_nan(self, ci_params):
"""Test that no post-treatment effects returns NaN even with bootstrap.
@@ -1151,12 +1140,14 @@ def test_no_post_effects_bootstrap_returns_nan(self, ci_params):
time_fe = np.tile(np.arange(n_periods) * 0.1, n_units)
outcomes = unit_fe + time_fe + np.random.randn(len(units)) * 0.3
- data = pd.DataFrame({
- "unit": units,
- "time": times,
- "outcome": outcomes,
- "first_treat": first_treat_expanded.astype(int),
- })
+ data = pd.DataFrame(
+ {
+ "unit": units,
+ "time": times,
+ "outcome": outcomes,
+ "first_treat": first_treat_expanded.astype(int),
+ }
+ )
n_boot = ci_params.bootstrap(50)
sa = SunAbraham(n_bootstrap=n_boot, seed=42)
@@ -1167,21 +1158,17 @@ def test_no_post_effects_bootstrap_returns_nan(self, ci_params):
)
# All overall inference fields should be NaN
- assert np.isnan(results.overall_att), (
- f"Expected NaN overall_att, got {results.overall_att}"
- )
- assert np.isnan(results.overall_se), (
- f"Expected NaN overall_se, got {results.overall_se}"
- )
- assert np.isnan(results.overall_t_stat), (
- f"Expected NaN overall_t_stat, got {results.overall_t_stat}"
- )
- assert np.isnan(results.overall_p_value), (
- f"Expected NaN overall_p_value with bootstrap, got {results.overall_p_value}"
- )
- assert np.isnan(results.overall_conf_int[0]) and np.isnan(results.overall_conf_int[1]), (
- f"Expected (NaN, NaN) overall_conf_int, got {results.overall_conf_int}"
- )
+ assert np.isnan(results.overall_att), f"Expected NaN overall_att, got {results.overall_att}"
+ assert np.isnan(results.overall_se), f"Expected NaN overall_se, got {results.overall_se}"
+ assert np.isnan(
+ results.overall_t_stat
+ ), f"Expected NaN overall_t_stat, got {results.overall_t_stat}"
+ assert np.isnan(
+ results.overall_p_value
+ ), f"Expected NaN overall_p_value with bootstrap, got {results.overall_p_value}"
+ assert np.isnan(results.overall_conf_int[0]) and np.isnan(
+ results.overall_conf_int[1]
+ ), f"Expected (NaN, NaN) overall_conf_int, got {results.overall_conf_int}"
def test_event_time_no_truncation(self):
"""Test that event times beyond ±20 are estimated (Step 5d).
@@ -1206,12 +1193,14 @@ def test_event_time_no_truncation(self):
post = (times >= first_treat_expanded) & (first_treat_expanded > 0)
outcomes = unit_fe + time_fe + 2.0 * post + np.random.randn(len(units)) * 0.3
- data = pd.DataFrame({
- "unit": units,
- "time": times,
- "outcome": outcomes,
- "first_treat": first_treat_expanded.astype(int),
- })
+ data = pd.DataFrame(
+ {
+ "unit": units,
+ "time": times,
+ "outcome": outcomes,
+ "first_treat": first_treat_expanded.astype(int),
+ }
+ )
sa = SunAbraham(n_bootstrap=0)
results = sa.fit(
@@ -1220,12 +1209,8 @@ def test_event_time_no_truncation(self):
# Verify that event times beyond ±20 are present
event_times = sorted(results.event_study_effects.keys())
- assert min(event_times) < -20, (
- f"Expected event times < -20, got min={min(event_times)}"
- )
- assert max(event_times) > 20, (
- f"Expected event times > 20, got max={max(event_times)}"
- )
+ assert min(event_times) < -20, f"Expected event times < -20, got min={min(event_times)}"
+ assert max(event_times) > 20, f"Expected event times > 20, got max={max(event_times)}"
def test_df_adjustment_sets_regression_df(self):
"""Test that df_adjustment for absorbed FE is applied correctly (Step 5a).
@@ -1249,30 +1234,31 @@ def test_df_adjustment_sets_regression_df(self):
# the last call's state.
def capturing_fit(self_reg, X, y, **kwargs):
result = original_fit(self_reg, X, y, **kwargs)
- captured_df['df'] = self_reg.df_
- captured_df['n_obs'] = self_reg.n_obs_
- captured_df['n_params_effective'] = self_reg.n_params_effective_
- captured_df['df_adjustment'] = kwargs.get('df_adjustment', 0)
+ captured_df["df"] = self_reg.df_
+ captured_df["n_obs"] = self_reg.n_obs_
+ captured_df["n_params_effective"] = self_reg.n_params_effective_
+ captured_df["df_adjustment"] = kwargs.get("df_adjustment", 0)
return result
sa = SunAbraham(n_bootstrap=0)
- with patch.object(LinearRegression, 'fit', capturing_fit):
- results = sa.fit(data, outcome="outcome", unit="unit",
- time="time", first_treat="first_treat")
+ with patch.object(LinearRegression, "fit", capturing_fit):
+ results = sa.fit(
+ data, outcome="outcome", unit="unit", time="time", first_treat="first_treat"
+ )
# Verify df_adjustment was passed and applied
n_units = data["unit"].nunique()
n_times = data["time"].nunique()
expected_df_adj = n_units + n_times - 1
- assert captured_df['df_adjustment'] == expected_df_adj, (
- f"Expected df_adjustment={expected_df_adj}, got {captured_df['df_adjustment']}"
- )
- expected_df = captured_df['n_obs'] - captured_df['n_params_effective'] - expected_df_adj
- assert captured_df['df'] == expected_df, (
- f"Expected df={expected_df}, got {captured_df['df']}"
- )
- assert captured_df['df'] > 0, "Regression df must be positive"
+ assert (
+ captured_df["df_adjustment"] == expected_df_adj
+ ), f"Expected df_adjustment={expected_df_adj}, got {captured_df['df_adjustment']}"
+ expected_df = captured_df["n_obs"] - captured_df["n_params_effective"] - expected_df_adj
+ assert (
+ captured_df["df"] == expected_df
+ ), f"Expected df={expected_df}, got {captured_df['df']}"
+ assert captured_df["df"] > 0, "Regression df must be positive"
def test_variance_fallback_warning(self):
"""Test that the variance fallback path emits a warning (Step 5e).
@@ -1289,37 +1275,52 @@ def test_variance_fallback_warning(self):
# Patch _compute_overall_att to simulate the fallback path
original_method = sa._compute_overall_att
- def patched_compute_overall_att(df, first_treat, event_study_effects,
- cohort_effects, cohort_weights,
- vcov_cohort, coef_index_map,
- survey_weight_col=None):
+ def patched_compute_overall_att(
+ df,
+ first_treat,
+ event_study_effects,
+ cohort_effects,
+ cohort_weights,
+ vcov_cohort,
+ coef_index_map,
+ survey_weight_col=None,
+ return_overall_weights=False,
+ ):
# Pass an empty coef_index_map to trigger the fallback
return original_method(
- df, first_treat, event_study_effects,
- cohort_effects, cohort_weights,
- vcov_cohort, {}, # Empty coef_index_map forces fallback
+ df,
+ first_treat,
+ event_study_effects,
+ cohort_effects,
+ cohort_weights,
+ vcov_cohort,
+ {}, # Empty coef_index_map forces fallback
+ survey_weight_col=survey_weight_col,
+ return_overall_weights=return_overall_weights,
)
- with patch.object(sa, '_compute_overall_att', side_effect=patched_compute_overall_att):
+ with patch.object(sa, "_compute_overall_att", side_effect=patched_compute_overall_att):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
results = sa.fit(
- data, outcome="outcome", unit="unit", time="time",
+ data,
+ outcome="outcome",
+ unit="unit",
+ time="time",
first_treat="first_treat",
)
fallback_warnings = [
- x for x in w
- if "simplified variance" in str(x.message).lower()
+ x for x in w if "simplified variance" in str(x.message).lower()
]
- assert len(fallback_warnings) > 0, (
- "Expected warning about simplified variance fallback"
- )
+ assert (
+ len(fallback_warnings) > 0
+ ), "Expected warning about simplified variance fallback"
# The result should still have a positive SE (simplified variance)
- assert results.overall_se > 0, (
- f"Expected positive SE from fallback, got {results.overall_se}"
- )
+ assert (
+ results.overall_se > 0
+ ), f"Expected positive SE from fallback, got {results.overall_se}"
def test_iw_weights_match_cohort_shares(self):
"""Test that IW weights match event-time sample shares.
@@ -1337,25 +1338,20 @@ def test_iw_weights_match_cohort_shares(self):
for e, weights in results.cohort_weights.items():
# Weights should sum to 1
total = sum(weights.values())
- assert abs(total - 1.0) < 1e-10, (
- f"Weights for e={e} sum to {total}, expected 1.0"
- )
+ assert abs(total - 1.0) < 1e-10, f"Weights for e={e} sum to {total}, expected 1.0"
# Individual weights should match event-time sample shares
cohort_counts = {}
for g in weights.keys():
cohort_counts[g] = len(
- data[
- (data["first_treat"] == g)
- & (data["time"] - data["first_treat"] == e)
- ]
+ data[(data["first_treat"] == g) & (data["time"] - data["first_treat"] == e)]
)
total_count = sum(cohort_counts.values())
for g, w in weights.items():
expected_w = cohort_counts[g] / total_count
- assert abs(w - expected_w) < 1e-10, (
- f"Weight for cohort {g} at e={e}: got {w}, expected {expected_w}"
- )
+ assert (
+ abs(w - expected_w) < 1e-10
+ ), f"Weight for cohort {g} at e={e}: got {w}, expected {expected_w}"
def test_iw_weights_unbalanced_panel(self):
"""Test that IW weights use event-time counts, not cohort sizes, for unbalanced panels."""
@@ -1386,9 +1382,9 @@ def test_iw_weights_unbalanced_panel(self):
# The dropped units are from the first cohort at max_time
affected_e = max_time - first_cohort
- assert affected_e in results.cohort_weights, (
- f"Expected event-time {affected_e} in cohort_weights but not found"
- )
+ assert (
+ affected_e in results.cohort_weights
+ ), f"Expected event-time {affected_e} in cohort_weights but not found"
weights = results.cohort_weights[affected_e]
# Verify weights use actual observation counts, not total cohort sizes
@@ -1403,9 +1399,9 @@ def test_iw_weights_unbalanced_panel(self):
total_count = sum(cohort_counts.values())
for g, w in weights.items():
expected_w = cohort_counts[g] / total_count
- assert abs(w - expected_w) < 1e-10, (
- f"Weight for cohort {g} at e={affected_e}: got {w}, expected {expected_w}"
- )
+ assert (
+ abs(w - expected_w) < 1e-10
+ ), f"Weight for cohort {g} at e={affected_e}: got {w}, expected {expected_w}"
def test_never_treated_inf_encoding(self):
"""Test that first_treat=np.inf is handled as never-treated, not as a cohort."""
@@ -1428,26 +1424,24 @@ def test_never_treated_inf_encoding(self):
# np.inf must not appear as a cohort in weights
for e, weights in results_inf.cohort_weights.items():
- assert np.inf not in weights, (
- f"np.inf found as cohort key in weights at e={e}"
- )
+ assert np.inf not in weights, f"np.inf found as cohort key in weights at e={e}"
# No ±inf in event study periods
for e in results_inf.event_study_effects.keys():
assert np.isfinite(e), f"Non-finite event time {e} in event study"
# np.inf must not appear in results.groups
- assert np.inf not in results_inf.groups, (
- f"np.inf found in results.groups: {results_inf.groups}"
- )
+ assert (
+ np.inf not in results_inf.groups
+ ), f"np.inf found in results.groups: {results_inf.groups}"
# Results should be identical to first_treat=0 encoding
- assert np.isclose(results_inf.overall_att, results_zero.overall_att), (
- f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}"
- )
- assert np.isclose(results_inf.overall_se, results_zero.overall_se), (
- f"SE differs: inf={results_inf.overall_se}, zero={results_zero.overall_se}"
- )
+ assert np.isclose(
+ results_inf.overall_att, results_zero.overall_att
+ ), f"ATT differs: inf={results_inf.overall_att}, zero={results_zero.overall_att}"
+ assert np.isclose(
+ results_inf.overall_se, results_zero.overall_se
+ ), f"SE differs: inf={results_inf.overall_se}, zero={results_zero.overall_se}"
def test_removed_params_raise_typeerror(self):
"""Removed min_pre_periods/min_post_periods raise TypeError."""
@@ -1474,3 +1468,497 @@ def test_all_never_treated_inf_raises(self):
time="time",
first_treat="first_treat",
)
+
+
+class TestSunAbrahamVcovType:
+ """Tests for SunAbraham `vcov_type` parameter (Phase 1b PR 1/8).
+
+ Threads `vcov_type ∈ {"classical","hc1","hc2","hc2_bm"}` through SA.
+ `vcov_type="hc1"` is the default and preserves prior behavior bit-equally.
+ `hc2`/`hc2_bm` route through a full-dummy saturated design (FWL preserves
+ coefficients but not the hat matrix). `classical` also uses full-dummy to
+ match R's `lm()` interpretation. `conley` is rejected at __init__ (deferred).
+ """
+
+ @staticmethod
+ def _panel():
+ return generate_staggered_data(n_units=40, n_periods=8, n_cohorts=3, seed=42)
+
+ def test_default_matches_hc1_explicit(self):
+ """SA() (default) and SA(vcov_type='hc1') produce identical ATT/SE
+ (modulo floating-point representation noise at the last bit)."""
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res_default = SunAbraham().fit(data, **kwargs)
+ res_explicit = SunAbraham(vcov_type="hc1").fit(data, **kwargs)
+ # Both paths execute the same code under the hood (default vcov_type
+ # is "hc1"); any divergence is float64 representation noise.
+ assert np.isclose(res_default.overall_att, res_explicit.overall_att, atol=1e-14)
+ assert np.isclose(res_default.overall_se, res_explicit.overall_se, atol=1e-14)
+ assert np.isclose(res_default.overall_p_value, res_explicit.overall_p_value, atol=1e-12)
+
+ def test_classical_finite(self):
+ """SA(vcov_type='classical') produces finite ATT/SE; drops auto-cluster."""
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res = SunAbraham(vcov_type="classical").fit(data, **kwargs)
+ assert np.isfinite(res.overall_att)
+ assert np.isfinite(res.overall_se)
+ assert np.isfinite(res.overall_p_value)
+ # Classical differs from HC1 cluster-at-unit in heteroscedastic / panel data
+ res_hc1 = SunAbraham(vcov_type="hc1").fit(data, **kwargs)
+ assert res.overall_se != res_hc1.overall_se
+
+ def test_hc2_finite_no_auto_cluster(self):
+ """SA(vcov_type='hc2') runs without explicit cluster (linalg validator
+ would reject hc2 + cluster_ids; SA drops the auto-cluster)."""
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res = SunAbraham(vcov_type="hc2").fit(data, **kwargs)
+ assert np.isfinite(res.overall_att)
+ assert np.isfinite(res.overall_se)
+ # HC2 differs from cluster-at-unit HC1 on this panel
+ res_hc1 = SunAbraham(vcov_type="hc1").fit(data, **kwargs)
+ assert res.overall_se != res_hc1.overall_se
+
+ def test_hc2_bm_finite_with_auto_cluster(self):
+ """SA(vcov_type='hc2_bm') auto-clusters at unit by default; routes to CR2-BM."""
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res = SunAbraham(vcov_type="hc2_bm").fit(data, **kwargs)
+ assert np.isfinite(res.overall_att)
+ assert np.isfinite(res.overall_se)
+ # HC2-BM applies BM Satterthwaite DOF; should differ from cluster-at-unit HC1
+ res_hc1 = SunAbraham(vcov_type="hc1").fit(data, **kwargs)
+ assert res.overall_se != res_hc1.overall_se
+
+ def test_hc2_bm_explicit_cluster_works(self):
+ """Explicit cluster= overrides the auto-cluster default; CR2-BM at named cluster."""
+ data = self._panel()
+ # Add a second cluster column
+ data["firm"] = (data["unit"] // 5).astype(int)
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res = SunAbraham(cluster="firm", vcov_type="hc2_bm").fit(data, **kwargs)
+ assert np.isfinite(res.overall_att)
+ assert np.isfinite(res.overall_se)
+ # Differs from cluster-at-unit (different cluster grouping)
+ res_unit = SunAbraham(vcov_type="hc2_bm").fit(data, **kwargs)
+ assert res.overall_se != res_unit.overall_se
+
+ def test_hc2_rejects_explicit_cluster(self):
+ """SA(cluster=..., vcov_type='hc2') is rejected by the linalg validator."""
+ data = self._panel()
+ data["firm"] = (data["unit"] // 5).astype(int)
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ sa = SunAbraham(cluster="firm", vcov_type="hc2")
+ with pytest.raises(ValueError, match="hc2 is one-way only"):
+ sa.fit(data, **kwargs)
+
+ def test_hc2_bm_replicate_weights_rejected(self):
+ """SA(vcov_type='hc2_bm' or 'hc2') + replicate-weight survey raises NotImplementedError."""
+ from diff_diff.survey import SurveyDesign
+
+ data = self._panel()
+ # Build a minimal BRR replicate-weight design (unit-constant weights)
+ data["w"] = 1.0
+ rep_cols = []
+ for j in range(4):
+ colname = f"r{j}"
+ data[colname] = 1.0
+ rep_cols.append(colname)
+ sd = SurveyDesign(
+ weights="w",
+ replicate_weights=rep_cols,
+ replicate_method="BRR",
+ )
+ kwargs = dict(
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ survey_design=sd,
+ )
+ for vt in ("hc2", "hc2_bm"):
+ sa = SunAbraham(vcov_type=vt)
+ with pytest.raises(NotImplementedError, match="replicate-weight"):
+ sa.fit(data, **kwargs)
+
+ def test_hc1_replicate_weights_path_unchanged(self):
+ """SA(vcov_type='hc1') + replicate-weight survey still works (regression
+ test that the Part C reject didn't broaden to non-hc2 paths)."""
+ from diff_diff.survey import SurveyDesign
+
+ data = self._panel()
+ # Unit-constant weights and replicate cols.
+ data["w"] = 1.0
+ units_sorted = sorted(data["unit"].unique())
+ rep_cols = []
+ for j, u_drop in enumerate(units_sorted[:5]):
+ colname = f"r{j}"
+ data[colname] = data["unit"].apply(lambda u: 0.0 if u == u_drop else 1.0).astype(float)
+ rep_cols.append(colname)
+ sd = SurveyDesign(
+ weights="w",
+ replicate_weights=rep_cols,
+ replicate_method="JK1",
+ )
+ sa = SunAbraham(vcov_type="hc1")
+ res = sa.fit(
+ data,
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ survey_design=sd,
+ )
+ assert np.isfinite(res.overall_att)
+ # Replicate path → bootstrap-style SE; should be finite
+ assert np.isfinite(res.overall_se)
+
+ def test_get_params_includes_vcov_type(self):
+ """get_params() returns dict with the new 'vcov_type' key."""
+ sa = SunAbraham(vcov_type="hc2_bm")
+ params = sa.get_params()
+ assert "vcov_type" in params
+ assert params["vcov_type"] == "hc2_bm"
+
+ def test_set_params_updates_vcov_type_and_explicit_flag(self):
+ """set_params(vcov_type=...) updates both self.vcov_type and _vcov_type_explicit."""
+ sa = SunAbraham() # default hc1; _vcov_type_explicit=False
+ assert sa.vcov_type == "hc1"
+ assert sa._vcov_type_explicit is False
+
+ sa.set_params(vcov_type="hc2")
+ assert sa.vcov_type == "hc2"
+ assert sa._vcov_type_explicit is True # opted out of default
+
+ sa.set_params(vcov_type="hc1")
+ assert sa.vcov_type == "hc1"
+ assert sa._vcov_type_explicit is False # back to default
+
+ def test_clone_repeat_fit_idempotent(self):
+ """fit() → clone (via get_params/set_params) → re-fit produces identical results.
+ Validates per feedback_fit_does_not_mutate_config: fit() doesn't mutate
+ configuration; vcov_type and _vcov_type_explicit survive a clone."""
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ sa1 = SunAbraham(vcov_type="hc2_bm")
+ res1 = sa1.fit(data, **kwargs)
+ # Clone via get_params → new SunAbraham → set_params
+ sa2 = SunAbraham()
+ sa2.set_params(**sa1.get_params())
+ res2 = sa2.fit(data, **kwargs)
+ assert res1.overall_att == res2.overall_att
+ assert res1.overall_se == res2.overall_se
+ assert sa2.vcov_type == "hc2_bm"
+ assert sa2._vcov_type_explicit is True
+
+ def test_invalid_vcov_type_value_rejects(self):
+ """Unknown vcov_type raises ValueError; conley emits a deferral message."""
+ with pytest.raises(ValueError, match="vcov_type must be one of"):
+ SunAbraham(vcov_type="foo")
+ with pytest.raises(ValueError, match="conley.*not yet wired up"):
+ SunAbraham(vcov_type="conley")
+
+ def test_survey_design_rejects_non_hc1_vcov_type(self):
+ """SA(vcov_type ∈ {classical, hc2, hc2_bm}) + survey_design= raises.
+
+ Survey-design TSL (or replicate-weight refit) variance overrides
+ the analytical sandwich family, so requesting a non-HC1 analytical
+ family under a survey design would either silently discard the
+ request OR downgrade unit-level PSUs to per-observation PSUs
+ (the auto-cluster guard for one-way families drops cluster_var=None
+ before survey-PSU injection). Explicit reject preserves the
+ contract.
+ """
+ from diff_diff.survey import SurveyDesign
+
+ data = self._panel()
+ data["w"] = 1.0
+ sd = SurveyDesign(weights="w")
+ for vt in ("classical", "hc2", "hc2_bm"):
+ sa = SunAbraham(vcov_type=vt)
+ with pytest.raises(NotImplementedError, match="survey_design"):
+ sa.fit(
+ data,
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ survey_design=sd,
+ )
+
+ def test_hc1_survey_design_preserves_unit_psu_injection(self):
+ """SA(vcov_type='hc1') + survey_design= (no explicit PSU/cluster)
+ injects unit as PSU; verifies n_psu == n_units and df_survey > 0.
+
+ Regression test for the P1-A cascade bug: prior to the survey-
+ reject contract, non-HC1 fits would drop the auto-cluster
+ BEFORE PSU injection, downgrading n_psu from n_units to n_obs.
+ HC1 must still inject unit-as-PSU correctly.
+ """
+ from diff_diff.survey import SurveyDesign
+
+ data = self._panel()
+ data["w"] = 1.0
+ sd = SurveyDesign(weights="w")
+ sa = SunAbraham(vcov_type="hc1") # default; auto-cluster fires
+ res = sa.fit(
+ data,
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ survey_design=sd,
+ )
+ assert np.isfinite(res.overall_att)
+ assert np.isfinite(res.overall_se)
+ # Survey metadata should record unit-level PSUs (n_psu == n_units),
+ # NOT per-observation PSUs (n_psu == n_obs). On the panel: 40
+ # units × 8 periods = 320 obs.
+ assert res.survey_metadata is not None
+ n_units = data["unit"].nunique()
+ assert res.survey_metadata.n_psu == n_units, (
+ f"Expected n_psu == n_units ({n_units}); "
+ f"got {res.survey_metadata.n_psu}. If this is n_obs ({len(data)}), "
+ "the unit-as-PSU injection regressed."
+ )
+ # df_survey should be n_psu - n_strata. No explicit strata → n_strata
+ # defaults to 1 → df_survey == n_psu - 1 == n_units - 1. Asserting
+ # both n_psu and df_survey guards against a regression where one is
+ # computed correctly but the other is downgraded (e.g. n_psu=40 but
+ # df_survey=319 would suggest df was computed against n_obs).
+ assert res.survey_metadata.df_survey == n_units - 1, (
+ f"Expected df_survey == n_psu - 1 == {n_units - 1}; "
+ f"got {res.survey_metadata.df_survey}."
+ )
+
+ def test_vcov_type_propagated_to_results(self):
+ """SunAbrahamResults.vcov_type reflects the fit-time configuration.
+
+ Per the P1-B fix: downstream consumers need to know which
+ variance family generated the SEs (matters for survey
+ precedence claims, bootstrap-vs-analytical distinctions, etc.).
+ """
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ for vt in ("hc1", "classical", "hc2", "hc2_bm"):
+ sa = SunAbraham(vcov_type=vt)
+ res = sa.fit(data, **kwargs)
+ assert hasattr(
+ res, "vcov_type"
+ ), "SunAbrahamResults must expose vcov_type for downstream consumers"
+ assert res.vcov_type == vt, f"Expected res.vcov_type == {vt!r}; got {res.vcov_type!r}"
+
+ def test_hc2_bm_routes_through_full_dummy_path(self):
+ """Sanity: hc2_bm coefficient on cohort effect matches the hc2 path
+ coefficient (FWL preserves the cohort coefficient; only the SE differs
+ across vcov families). Verifies the full-dummy auto-route preserves
+ cohort_effects extraction (intercept-offset indexing)."""
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res_hc1 = SunAbraham(vcov_type="hc1").fit(data, **kwargs)
+ res_hc2 = SunAbraham(vcov_type="hc2").fit(data, **kwargs)
+ res_hc2_bm = SunAbraham(vcov_type="hc2_bm").fit(data, **kwargs)
+ # ATT should match across all vcov families (FWL)
+ assert np.isclose(res_hc1.overall_att, res_hc2.overall_att, atol=1e-12)
+ assert np.isclose(res_hc1.overall_att, res_hc2_bm.overall_att, atol=1e-12)
+
+ def test_hc2_robustness_value_differs_from_within_transform(self):
+ """Confirms hc2 routes through full-dummy: the SE differs from a manual
+ within-transform HC2 (sanity check that the auto-route is active, not
+ silently using within-transform)."""
+ from diff_diff.linalg import solve_ols
+ from diff_diff.utils import within_transform as _within_transform_util
+
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ res_hc2 = SunAbraham(vcov_type="hc2").fit(data, **kwargs)
+
+ # Manually construct the within-transform HC2 for comparison:
+ # this is what SA WOULD compute if Part G didn't auto-route.
+ df = data.copy()
+ df["_rel_time"] = np.where(df["first_treat"] > 0, df["time"] - df["first_treat"], -999)
+ treatment_groups = sorted(g for g in df["first_treat"].unique() if g > 0)
+ all_rel = sorted(e for e in df.loc[df["first_treat"] > 0, "_rel_time"].unique() if e != -1)
+ interaction_cols = []
+ for g in treatment_groups:
+ for e in all_rel:
+ name = f"_D_{g}_{e}"
+ ind = ((df["first_treat"] == g) & (df["_rel_time"] == e)).astype(float)
+ if ind.sum() > 0:
+ df[name] = ind
+ interaction_cols.append(name)
+ df_dm = _within_transform_util(
+ df, ["outcome"] + interaction_cols, "unit", "time", suffix="_dm"
+ )
+ X_dm = df_dm[[f"{c}_dm" for c in interaction_cols]].values
+ y_dm = df_dm["outcome_dm"].values
+ _, _, vcov_within = solve_ols(X_dm, y_dm, vcov_type="hc2")
+ # Within-transform HC2 SE for first interaction column
+ within_se = float(np.sqrt(vcov_within[0, 0]))
+
+ # SA's full-dummy HC2 SE for the same target should differ from
+ # within_se (different hat matrix → different leverage correction).
+ # Pick any present cohort × event-time interaction.
+ target_key = next(iter(res_hc2.cohort_effects))
+ sa_se = res_hc2.cohort_effects[target_key]["se"]
+ assert not np.isclose(sa_se, within_se, atol=1e-6), (
+ f"SA hc2 SE ({sa_se}) too close to within-transform HC2 ({within_se}); "
+ "Part G full-dummy auto-route may not be active."
+ )
+
+ def test_point_estimates_invariant_across_vcov_type_with_covariates(self):
+ """FWL invariance: cohort effects + overall ATT identical across all
+ vcov_type values when a covariate is present.
+
+ The new full-dummy branch is a manual second implementation of SA's
+ regression design. FWL preserves coefficients/residuals across the
+ within-transform vs full-dummy choice, so vcov_type must not change
+ the point estimates — only the SEs. This guards against a
+ regression in the new design-builder that would silently change
+ the estimator (not just the SEs) when covariates are passed.
+ """
+ data = self._panel()
+ # Add a time-varying covariate uncorrelated with the treatment
+ rng = np.random.default_rng(7)
+ data["x"] = rng.normal(0, 1, size=len(data))
+ kwargs = dict(
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ covariates=["x"],
+ )
+ results = {}
+ for vt in ("hc1", "classical", "hc2", "hc2_bm"):
+ results[vt] = SunAbraham(vcov_type=vt).fit(data, **kwargs)
+ # Overall ATT and event_study effects should match across vcov_type
+ for vt in ("classical", "hc2", "hc2_bm"):
+ assert np.isclose(results["hc1"].overall_att, results[vt].overall_att, atol=1e-10), (
+ f"overall_att diverges between hc1 ({results['hc1'].overall_att}) "
+ f"and {vt} ({results[vt].overall_att}); covariate threading "
+ "regression in the full-dummy design-builder."
+ )
+ for e in results["hc1"].event_study_effects:
+ hc1_eff = results["hc1"].event_study_effects[e]["effect"]
+ vt_eff = results[vt].event_study_effects[e]["effect"]
+ assert np.isclose(
+ hc1_eff, vt_eff, atol=1e-10
+ ), f"event_study[{e}] effect diverges: hc1={hc1_eff}, {vt}={vt_eff}"
+
+ def test_point_estimates_invariant_across_vcov_type_not_yet_treated(self):
+ """FWL invariance under control_group='not_yet_treated': cohort effects
+ and overall ATT identical across all vcov_type values.
+
+ The not_yet_treated control group changes the regression sample
+ (keeps all units, not just never-treated controls). The new
+ full-dummy branch must reproduce the same point estimates as the
+ within-transform path under this sample composition.
+ """
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ results = {}
+ for vt in ("hc1", "classical", "hc2", "hc2_bm"):
+ results[vt] = SunAbraham(control_group="not_yet_treated", vcov_type=vt).fit(
+ data, **kwargs
+ )
+ for vt in ("classical", "hc2", "hc2_bm"):
+ assert np.isclose(results["hc1"].overall_att, results[vt].overall_att, atol=1e-10), (
+ f"overall_att diverges between hc1 ({results['hc1'].overall_att}) "
+ f"and {vt} ({results[vt].overall_att}) under "
+ "control_group='not_yet_treated'; sample-composition regression "
+ "in the full-dummy branch."
+ )
+ for e in results["hc1"].event_study_effects:
+ hc1_eff = results["hc1"].event_study_effects[e]["effect"]
+ vt_eff = results[vt].event_study_effects[e]["effect"]
+ assert np.isclose(hc1_eff, vt_eff, atol=1e-10), (
+ f"event_study[{e}] effect diverges under not_yet_treated: "
+ f"hc1={hc1_eff}, {vt}={vt_eff}"
+ )
+
+ def test_cluster_missing_column_raises_across_vcov_types(self):
+ """Explicit cluster= referencing a missing column must raise, not
+ silently downgrade clustered inference to one-way.
+
+ Regression test for the cluster-resolution cascade: the survey-
+ path fix that tolerates `cluster_var=None` for explicit one-way
+ families (hc2, classical) would, without this guard, also treat
+ a missing user-supplied cluster as cluster_var=None, silently
+ producing HC1-no-cluster / HC2-singleton / classical SEs
+ instead of the requested cluster-robust inference. Must error
+ upfront for at least hc1, classical, hc2, hc2_bm.
+ """
+ data = self._panel()
+ for vt in ("hc1", "classical", "hc2", "hc2_bm"):
+ sa = SunAbraham(cluster="nonexistent_col", vcov_type=vt)
+ with pytest.raises(ValueError, match="cluster column"):
+ sa.fit(
+ data,
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ )
+
+ def test_cluster_na_values_raise_across_vcov_types(self):
+ """Cluster columns with NA/NaN values must raise — otherwise
+ the meat-side `groupby(cluster_ids)` drops NA rows while
+ `np.unique(cluster_ids)` still counts the NA group, producing
+ silently-malformed cluster-robust SEs.
+
+ Per codex CI R4: column-existence validation is insufficient
+ — must also validate non-NA values upfront.
+ """
+ data = self._panel()
+ # Add a cluster column with NA values in some rows
+ data["my_cluster"] = data["unit"].astype("Int64")
+ data.loc[data.index[:5], "my_cluster"] = pd.NA
+ for vt in ("hc1", "hc2_bm"):
+ sa = SunAbraham(cluster="my_cluster", vcov_type=vt)
+ with pytest.raises(ValueError, match="NA/NaN values"):
+ sa.fit(
+ data,
+ outcome="outcome",
+ unit="unit",
+ time="time",
+ first_treat="first_treat",
+ )
+
+ def test_bootstrap_finite_and_point_estimates_invariant(self, ci_params):
+ """Bootstrap (n_bootstrap > 0) coverage for the new vcov_type paths.
+
+ Both `_run_bootstrap` and `_run_rao_wu_bootstrap` re-enter
+ `_fit_saturated_regression(..., vcov_type=self.vcov_type)`, so
+ bootstrap fits with the new families execute the new full-dummy
+ plumbing on every refit. This test asserts:
+ 1. Bootstrap output is finite for each vcov_type
+ 2. Point estimates (overall ATT, event-study effects) are
+ invariant across vcov_type at the same seed (FWL)
+ The bootstrap SE may legitimately differ across vcov_type because
+ the multiplier-bootstrap inputs depend on the per-fit residuals,
+ but the per-fit residuals are FWL-invariant; the SE divergence
+ comes from numerical precision in the iterated refits.
+ """
+ n_boot = ci_params.bootstrap(50)
+ data = self._panel()
+ kwargs = dict(outcome="outcome", unit="unit", time="time", first_treat="first_treat")
+ results = {}
+ for vt in ("hc1", "classical", "hc2", "hc2_bm"):
+ results[vt] = SunAbraham(vcov_type=vt, n_bootstrap=n_boot, seed=42).fit(data, **kwargs)
+ # Bootstrap outputs must be finite for each family
+ assert np.isfinite(
+ results[vt].overall_att
+ ), f"overall_att non-finite for vcov_type={vt} under bootstrap"
+ assert np.isfinite(
+ results[vt].overall_se
+ ), f"bootstrap SE non-finite for vcov_type={vt}"
+ # Point estimates must match across vcov_type (FWL preserves them
+ # across within-transform vs full-dummy refits).
+ for vt in ("classical", "hc2", "hc2_bm"):
+ assert np.isclose(results["hc1"].overall_att, results[vt].overall_att, atol=1e-10), (
+ f"Bootstrap overall_att diverges: hc1={results['hc1'].overall_att}, "
+ f"{vt}={results[vt].overall_att}. Bootstrap refits in the new "
+ "full-dummy branch are not FWL-invariant."
+ )