Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
revert
  • Loading branch information
crusaderky committed Apr 18, 2025
commit 49f9ba72709a4b8ab6886ecd047f8ef1545d7871
11 changes: 6 additions & 5 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def vector_norm(
if axis is None:
# Note: xp.linalg.norm() doesn't handle 0-D arrays
_x = x.ravel()
axis = 0
_axis = 0
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas
# xp.linalg.norm() only supports a single axis for vector norm.
Expand All @@ -176,24 +176,25 @@ def vector_norm(
newshape = axis + rest
_x = xp.transpose(x, newshape).reshape(
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
axis = 0
_axis = 0
else:
_x = x
_axis = axis

res = xp.linalg.norm(_x, axis=axis, ord=ord)
res = xp.linalg.norm(_x, axis=_axis, ord=ord)

if keepdims:
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
axis = cast(
_axis = cast(
"tuple[int, ...]",
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
range(x.ndim) if axis is None else axis,
x.ndim,
),
)
for i in axis:
for i in _axis:
shape[i] = 1
res = xp.reshape(res, tuple(shape))

Expand Down
Loading