Skip to content

Commit 46442b0

Browse files
Support 4bit torch.compile fullgraph with PyTorch nightly (#1616)
1 parent c244e98 commit 46442b0

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

bitsandbytes/nn/modules.py

Lines changed: 8 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -290,6 +290,13 @@ def from_prequantized(
290290

291291
return self
292292

293+
@classmethod
294+
def __torch_function__(cls, func, types, args=(), kwargs=None):
295+
if kwargs is None:
296+
kwargs = {}
297+
with torch._C.DisableTorchFunctionSubclass():
298+
return func(*args, **kwargs)
299+
293300
def _quantize(self, device):
294301
w = self.data.contiguous().to(device)
295302
w_4bit, quant_state = bnb.functional.quantize_4bit(
@@ -486,7 +493,7 @@ def forward(self, x: torch.Tensor):
486493

487494
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
488495

489-
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
496+
return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
490497

491498

492499
class LinearFP4(Linear4bit):

0 commit comments

Comments (0)