
Conversation

@lsy323 lsy323 commented Jun 10, 2024

int4 weight can be enabled via torch.ops.xla.quantized_matmul(x, weight, weight_scaler, int4_weight=True) or XlaQuantizedLinear(..., int4_weight=True).
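
A minimal usage sketch of both entry points, assuming the import path, the (in, out) weight layout, the XlaQuantizedLinear constructor arguments, and the load_quantized_weight call from the existing quantized matmul module; only the int4_weight=True flag is what this PR adds.

```python
import torch
import torch_xla.core.xla_model as xm
# Import path is an assumption; adjust to wherever XlaQuantizedLinear lives.
from torch_xla.experimental.xla_quantized_matmul import XlaQuantizedLinear

device = xm.xla_device()

# int4 values are carried unpacked in an int8 tensor, so they stay in [-8, 7].
w_int4 = torch.randint(-8, 8, (64, 128), dtype=torch.int8)  # (in, out) layout assumed
scaler = torch.rand(128, dtype=torch.bfloat16)              # per-output-channel scale
x = torch.rand(4, 64, dtype=torch.bfloat16)

# Functional op, as described above.
out = torch.ops.xla.quantized_matmul(
    x.to(device), w_int4.to(device), scaler.to(device), int4_weight=True)

# Module form; load_quantized_weight is assumed to be the weight-loading hook.
layer = XlaQuantizedLinear(64, 128, int4_weight=True)
layer.load_quantized_weight(w_int4, scaler)
layer = layer.to(device)
out2 = layer(x.to(device))
```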

The matmul with int4 weight works as follows (a reference sketch follows the list):

  1. The int4 weight is stored unpacked in an int8 container.
  2. During HLO lowering, an xla::Literal is created for the int4 weight.
  3. F.linear is applied to the activation and the int4 weight.
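
The following pure-PyTorch sketch illustrates steps 1 and 3 as reference semantics. The quantization helper is hypothetical and not part of this PR; it only shows int4-range values living unpacked in an int8 tensor and the F.linear computation the lowered matmul should match.

```python
import torch
import torch.nn.functional as F

def quantize_per_channel_int4(w_fp: torch.Tensor):
    """Hypothetical helper: symmetric per-channel quantization to the int4
    range [-8, 7], stored unpacked in an int8 container (step 1 above)."""
    max_abs = w_fp.abs().amax(dim=1, keepdim=True)
    scaler = max_abs / 7.0
    w_int4 = torch.clamp(torch.round(w_fp / scaler), -8, 7).to(torch.int8)
    return w_int4, scaler.squeeze(1)

# Step 3 reference: F.linear on the activation and the (dequantized) int4 weight.
w_fp = torch.randn(128, 64)  # (out_features, in_features)
w_int4, scaler = quantize_per_channel_int4(w_fp)
x = torch.randn(4, 64)
out_ref = F.linear(x, w_int4.to(torch.float32) * scaler.unsqueeze(1))
```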

The original plan was to pack the int4 values into an int8 container and reinterpret-cast them, but reinterpret cast doesn't currently work on TPU.

Test:
Added tests for the quantized op and the quantized linear module.

@lsy323 lsy323 marked this pull request as ready for review June 10, 2024 22:41
@lsy323 lsy323 force-pushed the lsiyuan/int4-quant-ops branch from 4330117 to 03f46f1 June 10, 2024 23:01
@JackCaoG JackCaoG self-requested a review June 10, 2024 23:05
@lsy323 lsy323 requested a review from JackCaoG June 10, 2024 23:17

lsy323 commented Jun 10, 2024

Removed the pack/unpack logic and its test since they are no longer used.

@lsy323 lsy323 merged commit ac371fb into master Jun 11, 2024
@miladm miladm assigned miladm and lsy323 and unassigned miladm Jun 13, 2024
@lsy323 lsy323 deleted the lsiyuan/int4-quant-ops branch December 6, 2024 18:54
