104 changes: 104 additions & 0 deletions bitsandbytes/_ops.py
@@ -348,3 +348,107 @@ def _(
) -> torch.Tensor:
torch._check_is_size(blocksize)
return torch.empty(shape, dtype=dtype, device=A.device)


torch.library.define(
"bitsandbytes::optimizer_update_32bit",
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
)


@register_fake("bitsandbytes::optimizer_update_32bit")
def _(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
skip_zeros: bool = False,
) -> None:
torch._check(
g.numel() == p.numel(),
lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
)
compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]

torch._check(
g.dtype in compute_dtypes,
lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
)
torch._check(
g.dtype == p.dtype,
lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
)
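
For reference, a minimal sketch (not part of this diff) of how the newly defined op can be invoked once a backend kernel is registered for it (the CUDA kernels are added in bitsandbytes/backends/cuda/ops.py below). Tensor sizes and hyperparameters are illustrative only, and importing bitsandbytes is assumed to register the custom ops:

```python
import torch
import bitsandbytes  # noqa: F401  -- assumed to register the custom ops on import

n = 1024  # illustrative parameter count
g = torch.randn(n, device="cuda", dtype=torch.float32)       # gradient
p = torch.randn(n, device="cuda", dtype=torch.float32)       # parameter
state1 = torch.zeros(n, device="cuda", dtype=torch.float32)  # Adam first moment
state2 = torch.zeros(n, device="cuda", dtype=torch.float32)  # Adam second moment

torch.ops.bitsandbytes.optimizer_update_32bit(
    "adam",
    g, p, state1, state2,
    None,    # unorm_vec (no update-norm clipping)
    0.0,     # max_unorm
    0.0,     # param_norm
    0.9,     # beta1
    0.999,   # beta2
    0.0,     # beta3 (used by ademamix)
    0.0,     # alpha (used by ademamix)
    1e-8,    # eps
    0.0,     # weight_decay
    1,       # step
    1e-3,    # lr
    1.0,     # gnorm_scale
)
```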


torch.library.define(
"bitsandbytes::optimizer_update_8bit_blockwise",
"(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
)


@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
def _(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float,
gnorm_scale: float,
skip_zeros: bool = False,
) -> None:
torch._check(
g.numel() == p.numel(),
lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
)
compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]

torch._check(
g.dtype in compute_dtypes,
lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
)
torch._check(
g.dtype == p.dtype,
lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
)
torch._check(
state1.dtype == torch.uint8,
lambda: f"state1 must be uint8, got {state1.dtype}",
)
torch._check(
qmap1.dtype == absmax1.dtype == torch.float32,
lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
)
if state2 is not None:
torch._check(
state2.dtype == torch.uint8,
lambda: f"state2 must be uint8, got {state2.dtype}",
)
torch._check(
qmap2.dtype == absmax2.dtype == torch.float32,
lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
)
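
The fake registrations above are what PyTorch consults when tracing (for example under torch.compile): only the shape/dtype checks run and no real kernel is launched. A minimal sketch of that behavior for the 8-bit blockwise op, assuming a CUDA-enabled environment and illustrative buffer sizes (the absmax block count here is an assumption, not mandated by the schema):

```python
import torch
import bitsandbytes  # noqa: F401  -- assumed to register the ops and their fake impls
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    n = 1024
    g = torch.empty(n, dtype=torch.float16, device="cuda")
    p = torch.empty(n, dtype=torch.float16, device="cuda")
    state1 = torch.empty(n, dtype=torch.uint8, device="cuda")
    qmap1 = torch.empty(256, dtype=torch.float32, device="cuda")
    absmax1 = torch.empty(n // 256, dtype=torch.float32, device="cuda")  # illustrative block count

    # Dispatches to the fake impl above; only the torch._check validations run.
    torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
        "momentum",
        g, p, state1, None,          # single-state optimizer: no state2
        0.9, 0.0, 0.0, 0.0,          # beta1, beta2, beta3, alpha
        1e-8,                        # eps
        1,                           # step
        1e-2,                        # lr
        qmap1, None, absmax1, None,  # qmap2/absmax2 unused without state2
        0.0,                         # weight_decay
        1.0,                         # gnorm_scale
    )
```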
226 changes: 226 additions & 0 deletions bitsandbytes/backends/cuda/ops.py
@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
ct.c_int32(blocksize),
stream,
)


"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
"adam": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"momentum": (
lib.cmomentum32bit_grad_32,
lib.cmomentum32bit_grad_16,
),
"rmsprop": (
lib.crmsprop32bit_grad_32,
lib.crmsprop32bit_grad_16,
),
"lion": (
lib.clion32bit_grad_fp32,
lib.clion32bit_grad_fp16,
lib.clion32bit_grad_bf16,
),
"adagrad": (
lib.cadagrad32bit_grad_32,
lib.cadagrad32bit_grad_16,
),
"lamb": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"ademamix": (
lib.cademamix32bit_grad_fp32,
lib.cademamix32bit_grad_fp16,
lib.cademamix32bit_grad_bf16,
),
}

str2optimizer8bit_blockwise = {
"adam": (
lib.cadam_8bit_blockwise_grad_fp32,
lib.cadam_8bit_blockwise_grad_fp16,
lib.cadam_8bit_blockwise_grad_bf16,
),
"momentum": (
lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp16,
lib.cmomentum_8bit_blockwise_grad_bf16,
),
"rmsprop": (
lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp16,
lib.crmsprop_8bit_blockwise_grad_bf16,
),
"lion": (
lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp16,
lib.clion_8bit_blockwise_grad_bf16,
),
"adagrad": (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
lib.cadagrad_8bit_blockwise_grad_bf16,
),
"ademamix": (
lib.cademamix_8bit_blockwise_grad_fp32,
lib.cademamix_8bit_blockwise_grad_fp16,
lib.cademamix_8bit_blockwise_grad_bf16,
),
}


def _optimizer_update_32bit_impl(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
skip_zeros: bool = False,
) -> None:
optim_fns = str2optimizer32bit.get(optimizer_name, None)
if optim_fns is None:
raise ValueError(
f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
)
if g.dtype == torch.float32:
optim_func = optim_fns[0]
elif g.dtype == torch.float16:
optim_func = optim_fns[1]
elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
optim_func = optim_fns[2]
else:
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
)

with _cuda_device_of(g):
optim_func(
get_ptr(g),
get_ptr(p),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(beta3),
ct.c_float(alpha),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)


def _optimizer_update_8bit_blockwise_impl(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float,
gnorm_scale: float,
skip_zeros: bool = False,
) -> None:
# Argument shape/dtype checks for this op live in the fake kernel registered in bitsandbytes/_ops.py.
optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
if optimizer_fns is None:
raise ValueError(
f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
)

if g.dtype == torch.float32:
optimizer_fn = optimizer_fns[0]
elif g.dtype == torch.float16:
optimizer_fn = optimizer_fns[1]
elif g.dtype == torch.bfloat16:
optimizer_fn = optimizer_fns[2]
else:
raise ValueError(
f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
)

with _cuda_device_of(g):
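# Note: the 8-bit blockwise kernels are called with (p, g, ...), whereas the 32-bit kernels above take (g, p, ...).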
optimizer_fn(
get_ptr(p),
get_ptr(g),
get_ptr(state1),
get_ptr(state2),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(beta3),
ct.c_float(alpha),
ct.c_float(eps),
ct.c_int32(step),
ct.c_float(lr),
get_ptr(qmap1),
get_ptr(qmap2),
get_ptr(absmax1),
get_ptr(absmax2),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)


register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)