Skip to content
Open
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7904256
Naive implementation, do not merge
jessegrabowski May 23, 2025
db5b23c
Implement suggestions
jessegrabowski May 23, 2025
c687856
Simplify perf test
jessegrabowski May 23, 2025
4db2a33
float32 compat in tests
jessegrabowski May 23, 2025
3504f0b
Remove np.pad
jessegrabowski May 23, 2025
1bcf463
set dtype correctly
jessegrabowski May 23, 2025
0ce2cae
fix signature, add infer_shape
jessegrabowski May 23, 2025
161e172
micro-optimizations
jessegrabowski May 23, 2025
1ddd529
Rename b to x, matching BLAS docs
jessegrabowski May 24, 2025
b16189e
Add numba dispatch for banded_dot
jessegrabowski May 24, 2025
a902694
Eliminate extra copy in numba impl
jessegrabowski May 24, 2025
6becc7d
Create `A_banded` as F-contiguous array
jessegrabowski May 24, 2025
22578f3
Remove benchmark
jessegrabowski May 24, 2025
65c485e
Don't cache numba function
jessegrabowski May 24, 2025
905fc7c
all hail mypy
jessegrabowski May 24, 2025
687877c
set INCX by strides
jessegrabowski May 24, 2025
62ccf13
relax tolerance of float32 test
jessegrabowski May 24, 2025
8d30a29
Add suggestions
jessegrabowski May 25, 2025
e3d0b14
Test strides
jessegrabowski May 25, 2025
21873a9
Add L_op
jessegrabowski May 25, 2025
c1b6e01
*remove* type hints to make mypy happy
jessegrabowski May 25, 2025
e62b613
Remove order argument from numba A_to_banded
jessegrabowski May 25, 2025
025879a
Incorporate feedback
jessegrabowski May 25, 2025
beeec6a
Adjust numba test
jessegrabowski May 25, 2025
f467322
Remove more useful type information for mypy
jessegrabowski May 25, 2025
976422f
Fix negative strides
jessegrabowski Jun 10, 2025
72ba0dc
Rename `BandedDot` to `BandedGEMV` and move to `blas.py`
jessegrabowski Jun 24, 2025
eb50ca6
Add numba `gemv` overload
jessegrabowski Jun 24, 2025
976fd5b
All hail mypy
jessegrabowski Jun 26, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adjust numba test
  • Loading branch information
jessegrabowski committed Jun 10, 2025
commit beeec6ad4ea493d6b681fbcea0617b8f3b1a1b81
29 changes: 20 additions & 9 deletions tests/link/numba/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,25 +724,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
np.testing.assert_allclose(b_val_not_contig, b_val)


@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}")
def test_banded_dot(stride):
def test_banded_dot():
rng = np.random.default_rng()

A = pt.tensor("A", shape=(10, 10), dtype=config.floatX)
A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)

x_shape = (10 * abs(stride),)
x_val = rng.normal(size=x_shape).astype(config.floatX)
x_val = x_val[::stride]

A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)
x = pt.tensor("x", shape=(10,), dtype=config.floatX)
x_val = rng.normal(size=(10,)).astype(config.floatX)

output = banded_dot(A, x, upper_diags=1, lower_diags=1)

compare_numba_and_py(
fn, _ = compare_numba_and_py(
[A, x],
output,
test_inputs=[A_val, x_val],
numba_mode=numba_inplace_mode,
eval_obj_mode=False,
)

for stride in [2, -1, -2]:
x_shape = (10 * abs(stride),)
x_val = rng.normal(size=x_shape).astype(config.floatX)
x_val = x_val[::stride]

nb_output = fn(A_val, x_val)
expected = A_val @ x_val

np.testing.assert_allclose(
nb_output,
expected,
strict=True,
err_msg=f"Test failed for stride = {stride}",
)