WIP: Add normalize parameter to sliced_wasserstein_distance #808
Conversation
Thanks for this PR, we are a bit busy at the moment and will have more time to give some feedback after the NeurIPS deadline in two weeks.
Sounds good, best of luck with NeurIPS.
Hello @Harguna, thanks for the PR. We had a look with @clbonet and we are not really comfortable with having a normalization inside the sliced Wasserstein function. While this might make sense in some applications, it also means that, for instance when optimizing the SWD, the loss between two optimization steps or minibatches is not comparable (since it is normalized locally). This poses a practical problem: it is not an intuitive behavior and leads to different minimizers. But we agree with you that normalization should be easier to handle. So we propose to handle it in a slightly different way, as follows:

```python
scaler = ot.utils.DataScaler(norm='standard').fit([X_s, X_t])  # can take a tensor or a list for joint normalization
swd = ot.sliced_wasserstein_distance(X_s, X_t, scaler=scaler)
```

This means that the normalization is fitted outside, on a class compatible with sklearn (with `fit` and `transform` methods) but that handles backends. A helper `def apply_scaler(X_s, X_t, scaler=None)` would handle the preprocessing of the data (or skip it if `scaler=None`), so that we can add this API to other functions in POT. Would you be OK with implementing our suggestions?
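For concreteness, here is a minimal sketch of what these two helpers could look like. Everything below is illustrative: the names `DataScaler` and `apply_scaler` come from the suggestion above, but the standard-scaling logic and the NumPy-only inputs are assumptions (a real implementation would dispatch through POT's backend mechanism).

```python
# Illustrative sketch only: names follow the suggestion above, but the
# standard-scaling logic and NumPy-only inputs are assumptions; a real
# implementation would dispatch through POT's backend mechanism.
import numpy as np


class DataScaler:
    """sklearn-style scaler that can be fitted jointly on several arrays."""

    def __init__(self, norm='standard'):
        self.norm = norm

    def fit(self, X):
        # Accept a single (n, d) array or a list of arrays to fit jointly.
        X = np.concatenate(X, axis=0) if isinstance(X, list) else X
        self.mean_ = X.mean(axis=0)
        self.scale_ = X.std(axis=0)
        self.scale_[self.scale_ == 0] = 1.0  # guard against constant features
        return self

    def transform(self, X):
        return (X - self.mean_) / self.scale_


def apply_scaler(X_s, X_t, scaler=None):
    """Preprocess both inputs with a fitted scaler, or pass them through."""
    if scaler is None:
        return X_s, X_t
    return scaler.transform(X_s), scaler.transform(X_t)
```

With this shape, any POT function can accept an optional `scaler` argument and simply call `apply_scaler` on its inputs before computing the distance.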
Hello @rflamary, thanks for the detailed feedback, this makes sense. I had accounted for the relative shift between `X_s` and `X_t` by fitting the normalization jointly, and I agree with your suggested design, which decouples the fitting step from the distance computation. I'm happy to implement `DataScaler` with backend compatibility and `apply_scaler` as a standalone helper so it can be reused across other POT functions. I will get started on that.
Types of changes
Motivation and context / Related issue
Addresses #807.
Sliced Wasserstein Distance is sensitive to feature scale: features with larger numerical ranges dominate the random projections, drowning out meaningful differences in smaller-scale features. Users often don't realize this is happening and, when they do, the manual fix (preprocessing inputs with a scaler) is verbose and easy to get wrong — fitting each distribution independently silently corrupts the distance.
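A toy reproduction of this pitfall (the data and the sklearn-based manual fix below are illustrative, not part of this PR):

```python
# Toy illustration of the scale-sensitivity pitfall; the data and the
# sklearn usage here are illustrative and not part of this PR.
import numpy as np
from sklearn.preprocessing import StandardScaler
import ot

rng = np.random.RandomState(0)
# Feature 0 lives on a ~1000x larger scale; the real difference is in feature 1.
X_s = rng.randn(100, 2) * np.array([1000.0, 1.0])
X_t = rng.randn(100, 2) * np.array([1000.0, 1.0]) + np.array([0.0, 5.0])

# Raw SWD is dominated by feature 0, drowning out the shift in feature 1.
print(ot.sliced_wasserstein_distance(X_s, X_t, seed=0))

# Wrong: fitting a scaler on each distribution independently re-centers both,
# silently erasing the shift the distance was supposed to detect.
print(ot.sliced_wasserstein_distance(StandardScaler().fit_transform(X_s),
                                     StandardScaler().fit_transform(X_t),
                                     seed=0))

# Right: fit one scaler jointly on both distributions, then transform each.
scaler = StandardScaler().fit(np.concatenate([X_s, X_t], axis=0))
print(ot.sliced_wasserstein_distance(scaler.transform(X_s),
                                     scaler.transform(X_t), seed=0))
```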
This PR adds optional `normalize` and `normalize_mode` parameters to `sliced_wasserstein_distance` and `max_sliced_wasserstein_distance` to handle this cleanly inside the function. Default behavior (`normalize=None`) is unchanged, so the change is fully backward-compatible.
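A usage sketch of the API surface proposed in this skeleton; the value `normalize=True` and the `'standard'` mode are assumptions about the planned accepted values, and this reflects the original design rather than the scaler-based one discussed above:

```python
# Usage sketch of the proposed parameters; normalize=True and the
# 'standard' mode are assumptions about the planned accepted values.
import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(100, 2) * np.array([1000.0, 1.0])
X_t = rng.randn(120, 2) * np.array([1000.0, 1.0])

# Default behavior, identical to the current release:
d_raw = ot.sliced_wasserstein_distance(X_s, X_t, seed=0)

# Proposed: joint normalization handled inside the function.
d_norm = ot.sliced_wasserstein_distance(
    X_s, X_t, seed=0, normalize=True, normalize_mode='standard')
```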
This is a [WIP] skeleton PR: it establishes the API surface, signatures, docstrings, and a helper function so the design can be reviewed before the full implementation lands. The actual normalization math, edge-case handling, behavioral tests, and example script will follow in subsequent commits on this same branch.

How has this been tested (if it applies)
In this skeleton:
- The existing `test/test_sliced.py` test suite continues to pass (verifies that the new keyword parameters didn't break anything).
- `pre-commit run --all-files` passes locally.

Tests related to the new feature will be added with the full implementation in the subsequent commits.
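For reference, the planned behavioral tests could look like the sketch below; the test name and assertions are assumptions, not part of this PR yet:

```python
# Sketch of a planned behavioral test; the name and assertions are assumptions.
import numpy as np
import ot


def test_sliced_normalize_default_unchanged():
    rng = np.random.RandomState(42)
    X_s, X_t = rng.randn(50, 3), rng.randn(60, 3)
    # normalize=None must reproduce the current behavior exactly.
    d_default = ot.sliced_wasserstein_distance(X_s, X_t, seed=0)
    d_none = ot.sliced_wasserstein_distance(X_s, X_t, seed=0, normalize=None)
    np.testing.assert_allclose(d_default, d_none)
```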
PR checklist