
Conversation

@xingmingyyj
Contributor

@xingmingyyj xingmingyyj commented Jul 15, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

  • Fix an out-of-bounds memory access in fused_rms_norm for large tensors.
  • The fused_rms_norm kernel loads each full row into shared memory, so when cols is too large the kernel launch fails. The kernel did not enforce this limit, so it returned without ever launching and the output silently became all zeros. This PR adds the missing check (a hedged host-side sketch of such a guard follows the reference implementation below).
  • When the input dtype is float16, fused_rms_norm casts the data to float32 for the norm computation to improve precision. Under float16, the result then matches the precision of the torch reference implementation below.
```python
from typing import Optional, Tuple

import torch


class RMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        norm_weight: torch.Tensor,
        norm_bias: Optional[torch.Tensor],
        epsilon: float,
        begin_norm_axis: int,
        bias: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_scale: float = -1.0,
        quant_round_type: int = 0,
        quant_max_bound: float = 0.0,
        quant_min_bound: float = 0.0,
    ) -> torch.Tensor:
        """Forward pass of RMSNorm."""

        def _flatten_from_axis(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            """Flatten tensor starting from given axis."""
            rows = torch.prod(torch.tensor(tensor.shape[:axis])).item()
            cols = torch.prod(torch.tensor(tensor.shape[axis:])).item()
            return tensor.reshape(rows, cols)

        # Save original dtype and shape for later
        origin_dtype = x.dtype
        origin_shape = x.shape

        # Convert inputs to float32 if needed
        x = x.float() if x.dtype == torch.float16 else x
        norm_weight = norm_weight.float() if norm_weight.dtype == torch.float16 else norm_weight
        residual = residual.float() if residual is not None and residual.dtype == torch.float16 else residual
        bias = bias.float() if bias is not None and bias.dtype == torch.float16 else bias
        norm_bias = norm_bias.float() if norm_bias is not None and norm_bias.dtype == torch.float16 else norm_bias

        # Apply residual and bias if provided
        output = x
        if residual is not None:
            output = output + residual
        if bias is not None:
            output = output + bias

        # Normalization
        output = _flatten_from_axis(output, begin_norm_axis)
        output_sq = output.pow(2)
        mean_output_sq = output_sq.mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_output_sq + epsilon)
        invvar = 1.0 / rms
        output_norm = output * invvar
        output = output_norm * norm_weight

        # Add norm_bias if provided
        if norm_bias is not None:
            output = output + _flatten_from_axis(norm_bias, begin_norm_axis)

        # Quantization if enabled
        if quant_scale > 0:
            output = output / quant_scale
            if quant_round_type == 0:
                output = torch.round(output)
            elif quant_round_type == 1:
                output = torch.where(
                    output >= 0,
                    torch.ceil(output - 0.5),
                    torch.floor(output + 0.5),
                )
            else:
                raise ValueError(f"Unsupported quant_round_type: {quant_round_type}")
            output = output * quant_scale
            output = torch.clamp(output, min=quant_min_bound, max=quant_max_bound)

        # Convert back to original dtype if no quantization
        if origin_dtype == torch.float16 and quant_scale <= 0:
            output = output.to(origin_dtype)
            norm_weight = norm_weight.to(origin_dtype)

        # Save tensors and metadata for backward
        ctx.save_for_backward(x, norm_weight, invvar)
        ctx.epsilon = epsilon
        ctx.exist_residual = residual is not None
        ctx.exist_bias = bias is not None
        ctx.exist_norm_bias = norm_bias is not None
        ctx.quant_scale = quant_scale
        ctx.begin_norm_axis = begin_norm_axis
        ctx.origin_shape = origin_shape
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        """Backward pass of RMSNorm."""

        def _flatten_from_axis(tensor: torch.Tensor, axis: int) -> torch.Tensor:
            """Flatten tensor starting from given axis."""
            rows = torch.prod(torch.tensor(tensor.shape[:axis])).item()
            cols = torch.prod(torch.tensor(tensor.shape[axis:])).item()
            return tensor.reshape(rows, cols)

        exist_bias = ctx.exist_bias
        exist_residual = ctx.exist_residual
        exist_norm_bias = ctx.exist_norm_bias
        quant_scale = ctx.quant_scale
        if quant_scale > 0 or exist_norm_bias or exist_bias or exist_residual:
            raise NotImplementedError

        # Retrieve saved tensors and metadata
        x, weight, invvar = ctx.saved_tensors
        origin_shape = ctx.origin_shape
        origin_dtype = grad_output.dtype

        # Flatten tensors for computation
        grad_output = _flatten_from_axis(grad_output.float(), ctx.begin_norm_axis)
        x = _flatten_from_axis(x.float(), ctx.begin_norm_axis)
        weight = weight.float()

        # Gradient w.r.t. weight (gamma)
        x_norm = x * invvar
        grad_weight = (grad_output * x_norm).sum(dim=tuple(range(grad_output.dim() - 1)), keepdim=False)
        grad_weight = grad_weight.to(origin_dtype)

        # Gradient w.r.t. input (x)
        D = x.size(-1)
        S = (grad_output * weight * x * invvar).sum(dim=1, keepdim=True)
        term1 = invvar / D
        grad_x = (D * grad_output * weight - x * invvar * S) * term1
        grad_x = grad_x.to(origin_dtype).reshape(origin_shape)

        # Return gradients (order matches forward inputs)
        return (
            grad_x,       # x
            grad_weight,  # norm_weight
            None,         # norm_bias
            None,         # epsilon
            None,         # begin_norm_axis
            None,         # bias
            None,         # residual
            None,         # quant_scale
            None,         # quant_round_type
            None,         # quant_max_bound
            None,         # quant_min_bound
        )
```
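
For reference, the gradients in `backward` above follow directly from differentiating the normalization. Writing $r = \big(\tfrac{1}{D}\sum_k x_k^2 + \epsilon\big)^{-1/2}$ (the `invvar` saved in `forward`) for a row of width $D$, the forward per element is $y_i = w_i\, x_i\, r$, so

$$
\frac{\partial L}{\partial x_j} = g_j w_j r - \frac{x_j r^3}{D}\sum_i g_i w_i x_i,
\qquad
\frac{\partial L}{\partial w_j} = \sum_{\text{rows}} g_j x_j r,
$$

which is exactly what the code computes as `grad_x = (D * grad_output * weight - x * invvar * S) * invvar / D` with `S = (grad_output * weight * x * invvar).sum(dim=1, keepdim=True)`.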

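The guard added by this PR lives in the CUDA launch path and is not reproduced in this description. As a rough host-side sketch of the kind of check the second bullet describes, one might verify that a full row fits in the per-block shared memory budget before launching and report an error instead of silently producing zeros. The function name, logging, and structure below are hypothetical illustrations, not the PR's code; the only assumption is that the kernel stages one row of `cols` elements of `elem_size` bytes in shared memory.

```cpp
#include <cstdint>
#include <cstdio>

#include <cuda_runtime.h>

// Hypothetical guard: return false (and report) when one row of `cols`
// elements of `elem_size` bytes would not fit in shared memory per block.
bool RowFitsInSharedMemory(int64_t cols, size_t elem_size) {
  int device = 0;
  cudaGetDevice(&device);

  int max_shmem_per_block = 0;  // bytes of shared memory available per block
  cudaDeviceGetAttribute(&max_shmem_per_block,
                         cudaDevAttrMaxSharedMemoryPerBlock, device);

  const int64_t needed = cols * static_cast<int64_t>(elem_size);
  if (needed > static_cast<int64_t>(max_shmem_per_block)) {
    std::fprintf(stderr,
                 "fused_rms_norm: a row needs %lld bytes of shared memory, "
                 "but the device allows only %d bytes per block\n",
                 static_cast<long long>(needed), max_shmem_per_block);
    return false;
  }
  return true;
}
```

In Paddle itself the failure would presumably surface through the framework's error-enforcement machinery rather than a `fprintf`; the point is only that the limit is checked before the launch.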
Pcard-73263

@paddle-bot

paddle-bot bot commented Jul 15, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@xingmingyyj xingmingyyj force-pushed the fix_rms_norm_kernel branch 2 times, most recently from 23c2faa to 046eee7 on July 15, 2025 at 13:56
@xingmingyyj xingmingyyj force-pushed the fix_rms_norm_kernel branch from 046eee7 to b99d0af on July 15, 2025 at 14:14
```diff
   // bottom half sums
   if (threadIdx.y < offset) {
-    const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
+    const int64_t read_idx = threadIdx.y * blockDim.x + threadIdx.x;
```
Contributor


```cpp
const int64_t read_idx = static_cast<int64_t>(threadIdx.y) * blockDim.x + threadIdx.x;
```

Contributor Author


done
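
A brief aside on the suggested cast (an illustration added here, not part of the review thread): declaring `read_idx` as `int64_t` is not by itself enough to get 64-bit index arithmetic, because two 32-bit operands are multiplied in 32 bits and only widened afterwards; `static_cast<int64_t>(threadIdx.y)` forces the multiplication itself to run in 64 bits. The thread-index operands here are small, so the cast is mainly defensive consistency for large-tensor indexing, but the general pitfall is easy to demonstrate with hypothetical values:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical 32-bit values whose product exceeds UINT32_MAX.
  const unsigned int rows = 70000;
  const unsigned int stride = 70000;

  // The multiply happens in 32 bits and wraps before the widening conversion.
  const int64_t wrapped = static_cast<int64_t>(rows * stride);
  // Casting one operand first forces a 64-bit multiply.
  const int64_t correct = static_cast<int64_t>(rows) * stride;

  std::printf("wrapped = %lld, correct = %lld\n",
              static_cast<long long>(wrapped), static_cast<long long>(correct));
  return 0;
}
```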

@xingmingyyj
Contributor Author

/re-run all-failed

Contributor

@wanghuancoder wanghuancoder left a comment


LGTM

@wanghuancoder wanghuancoder merged commit 971eac1 into PaddlePaddle:develop Jul 21, 2025
72 of 73 checks passed
@xingmingyyj xingmingyyj deleted the fix_rms_norm_kernel branch July 30, 2025 02:26
