Skip to content

Fix infinite norm negative axis mismatch bug for matrices with 2 or more dimensions.#3756

Open
danlee2002 wants to merge 3 commits into
ml-explore:mainfrom
danlee2002:fix-inf-norm-axis
Open

Fix infinite norm negative axis mismatch bug for matrices with 2 or more dimensions.#3756
danlee2002 wants to merge 3 commits into
ml-explore:mainfrom
danlee2002:fix-inf-norm-axis

Conversation

@danlee2002

@danlee2002 danlee2002 commented Jun 23, 2026

Copy link
Copy Markdown

Proposed changes

The following changes aims to fix an axis mismatch bug that occurs when one tries to calculate the $\pm\infty$ norm of some matrix A with 2 or more dimensions by passing in a series of negative axes and when keepdim is false.
As an example consider the following snippet of code:

import mlx.core as mx 
from mlx.core import linalg as la 
import numpy as np 
a_np = np.arange(9, dtype =np.float32).reshape((3,3))
a_mx = mx.arange(9, dtype = mx.float32).reshape((3,3)) 
# numpy as reference: 21.0
print(np.linalg.norm(a_np, ord = np.inf , axis  = (-2,-1)))
# raises IndexError: Invalid axis -2 for array with 1 dimensions.
print(la.norm(a_mx, ord = float('inf'), axis  = (-2,-1)))

Based on the numpy implementation, one should expect the following to return the float value of 21.0. However, we get an index error and the culprit of this error is how the matrix_norm is calculated when keepdims is false and we receive negative index.
If we take a look at the matrix_norm code. It's clear that the norm method for pos/neg infinite norm is implemented series of chained reductions:
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s)
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s)
Hence, when we specify that we want to reduce along (row_axis, col_axis) and keepdims is false, we need to account for the reduction in dimensions when we call max or min. Otherwise, we either end up with an index out of bound error or we reduce the wrong axis by one due to an off by one error.

The non-negative cases handle this via:
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
However, this is skipped for negative cases due to the row_axis > 0. To address this issue, this pr converts the respective for (row_axis, col_axis) to the equivalent non-negative values when negative axes values. This ensures proper behavior is achieved.

Testing

  • Ran existing unit tests for C++ and Python API on M5 Macbook Air
  • Added tests for pos/neg infinite norm for all possible negative indexing axis tuples for shapes (3, 3), (2, 3, 3), (2, 3, 3, 3) where outputs were verified against expected numpy outputs.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@danlee2002 danlee2002 changed the title Fix inf norm axis Fix infinite norm negative axis mismatch bug for matrices with 2 or more dimensions. Jun 23, 2026
@danlee2002

Copy link
Copy Markdown
Author

@zcbenz Could you review this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant