Skip to content

Commit b7e60ca

Browse files
Improve torch.compile support for int8 with torch>=2.8 nightly (#1617)
1 parent 46442b0 commit b7e60ca

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def forward(
236236
ctx.state = state
237237

238238
ctx.grad_shape = input_shape
239-
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
239+
ctx.dtype_A = A.dtype
240+
ctx.dtype_bias = None if bias is None else bias.dtype
240241

241242
if any(ctx.needs_input_grad[:2]):
242243
ctx.tensors = (CAt, subA, A)

0 commit comments

Comments
 (0)