Fix Imputer.fit discarding weights and accepting NaN weights#178
Fix Imputer.fit discarding weights and accepting NaN weights#178
Conversation
Previously, weight_col was used only to construct bootstrap-resample
probabilities over X_train, and the resampled dataset was then fed
UNWEIGHTED into the underlying estimator (QRF, OLS, statmatch, MDN).
Effective sample size shrank, rare donors were dropped entirely, and
variance was inflated relative to the correct weighted estimator. Worse,
the "positive weight" guard used (weights <= 0).any() which returns False
on NaN, allowing NaN weights to propagate into .sample() probabilities.
Changes
- Imputer.fit now threads sample_weight through to _fit() so each learner
uses its native weighted-fit API:
- QRF: sample_weight=... to RandomForestQuantileRegressor.fit and
RandomForestClassifier.fit (for categorical/boolean targets).
- OLS: sm.WLS replaces sm.OLS when weights are provided;
LogisticRegression.fit gets sample_weight for classification targets.
- Matching: donor_sample_weight is passed as weight.don into R
StatMatch.NND.hotdeck.
- Models that do NOT support weighted fit (QuantReg, MDN) now raise
NotImplementedError when weights are provided, instead of silently
ignoring them.
- The weight-validation check now detects NaN AND non-positive weights
with np.isnan(weights) | (weights <= 0) and raises a clear error.
- QRF max_train_samples subsampling is switched to positional indexing so
sample_weight stays aligned with X_train.reset_index.
Tests
- test_imputer_rejects_nan_weights / test_imputer_rejects_zero_weights
- test_weighted_fit_differs_from_unweighted: WLS-fit OLS on an asymmetric
dataset (block of weight 50 vs weight 1 with different slopes) must
differ from unweighted OLS.
- Updates test_weighted_training so QuantReg/MDN assert the expected
NotImplementedError.
- Fixes an unrelated pre-existing test_reproducibility parametrize that
referenced Matching unconditionally (fails when rpy2 isn't installed).
|
The latest updates on your projects. Learn more about Vercel for GitHub.
|
MaxGhenis
left a comment
There was a problem hiding this comment.
Weights threaded to learners verified on all five paths:
- QRF:
RandomForestQuantileRegressor.fit(sample_weight=...)andRandomForestClassifier.fit(sample_weight=...). - OLS → WLS:
sm.WLS(y, X_with_const, weights=weights).fit()when sample_weight supplied. - Logistic:
LogisticRegression.fit(..., sample_weight=...). - StatMatch:
donor_sample_weight→ Rweight_donvia localconverter (statmatch_hotdeck.py). - QuantReg and MDN raise
NotImplementedError— explicit failure beats silent discard.
Weight validation fixed: np.isnan(weights) | (weights <= 0) (previously only <=0 which missed NaN). weight_col now accepts pd.Series. max_train_samples subsampling switched to positional rng.choice so sample_weight[sel] stays aligned.
Tests cover each path: test_imputer_rejects_nan_weights, test_imputer_rejects_zero_weights, test_weighted_fit_differs_from_unweighted (WLS vs OLS on asymmetric weights/slopes), test_weighted_training parametrised with expected NotImplementedError for QuantReg/MDN. Pre-existing test_reproducibility import guard added for missing rpy2.
Minor nit (not blocking): weights_arr is defined inside the if weights is not None: block above and re-referenced in a second conditional block — both guarded by the same condition, so it's correct, just slightly noisy.
CI all green. Mergeable. LGTM.
Summary
Fixes finding #4 from the bug hunt.
Imputer.fit(weight_col=...)previously used weights only to construct a bootstrap-resample probability overX_train, then fed the resampled dataset UNWEIGHTED into the underlying estimator. Effective sample size shrank, rare donors were dropped entirely, and variance was inflated relative to the correct weighted estimator. The weight validation(weights <= 0).any()also silently accepted NaN values.Changes
Imputer.fitnow threadssample_weightthrough to each model's_fit()so learners use their native weighted-fit API:RandomForestQuantileRegressor.fit(sample_weight=...)andRandomForestClassifier.fit(sample_weight=...)sm.WLSreplacessm.OLSwhen weights are provided;LogisticRegression.fit(sample_weight=...)for classificationdonor_sample_weightflows into R StatMatch viaweight.donNotImplementedErrorrather than silently discarding weights.np.isnan(weights) | (weights <= 0)to catch both NaN and non-positive values.QRF.max_train_samplessubsampling switches to positional indexing sosample_weightstays aligned.Test plan
test_imputer_rejects_nan_weightspassestest_imputer_rejects_zero_weightspassestest_weighted_fit_differs_from_unweightedconfirms WLS differs from OLS on asymmetric datatest_weighted_trainingpasses (OLS/QRF fit weighted; QuantReg/MDN raise)tests/test_models/suite (100 passed, 2 skipped)