Skip to content

feat(aggregation): Add ExcessMTLWeighting#747

Open
KhusPatel4450 wants to merge 1 commit into
SimplexLab:mainfrom
KhusPatel4450:feat/excess-mtl-weighting
Open

feat(aggregation): Add ExcessMTLWeighting#747
KhusPatel4450 wants to merge 1 commit into
SimplexLab:mainfrom
KhusPatel4450:feat/excess-mtl-weighting

Conversation

@KhusPatel4450

Copy link
Copy Markdown
Contributor

feat(aggregation): Add ExcessMTLWeighting

Implements ExcessMTLWeighting from Robust Multi-Task Learning with Excess Risks (He et al., ICML 2024).

At each forward call, per-task excess risks are estimated via a second-order Taylor approximation (Equations 6-7) using an AdaGrad-style diagonal Hessian accumulated across all calls. Task weights are then updated via an exponentiated gradient step (Equation 9) and normalised to the probability simplex.

Design notes

  • State: two registered buffers, _grad_sum ([m, n], accumulates squared gradients) and _weights ([m], current task weights). Both move with .to(device) and appear in state_dict().
  • Warmup (n_warmup_steps, default 0): during warmup, weights stay uniform and gradient statistics are collected. On the first post-warmup call, the average excess risk over the warmup period is saved as a normalisation baseline (initial_w), following Appendix C.1. Setting n_warmup_steps=0 matches the official implementation and LibMTL behaviour (first call's excess used as baseline directly).
  • Normalisation convention: weights initialised to [1/m, ..., 1/m] and always sum to 1, following the paper (vs. LibMTL's sum-to-m).
  • _n_steps: stored as a registered buffer (scalar torch.long) so warmup progress survives checkpointing. Zeroed in-place in reset() to preserve device placement.

References

@KhusPatel4450 KhusPatel4450 added cc: feat Conventional commit type for new features. package: aggregation labels Jun 20, 2026
@github-actions github-actions Bot changed the title feat(Aggregation): Add ExcessMTLWeighting feat(aggregation): Add ExcessMTLWeighting Jun 20, 2026
@KhusPatel4450 KhusPatel4450 force-pushed the feat/excess-mtl-weighting branch from 5143ce9 to d4a29f8 Compare June 20, 2026 15:20
@KhusPatel4450 KhusPatel4450 force-pushed the feat/excess-mtl-weighting branch from d4a29f8 to 5a6358f Compare June 20, 2026 15:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant