@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
128128 }
129129}
130130
template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
                                               float alpha, float limit) {
  // SwiGLU variant used by GPT-OSS style models. For one (gate, up) pair:
  //   gate' = min(gate, limit)               (clamped from above only)
  //   up'   = clamp(up, -limit, limit)
  //   out   = (up' + 1) * gate' * sigmoid(alpha * gate')
  // All arithmetic is performed in float and cast back to T on return.
  // Ternary clamps (rather than fminf/fmaxf) preserve NaN propagation.

  // clamp gate: min=None, max=limit
  const float gate_f = (float)gate;
  const float clamped_gate = gate_f > limit ? limit : gate_f;

  // clamp up: min=-limit, max=limit
  const float up_f = (float)up;
  const float clamped_up =
      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);

  // glu = gate * sigmoid(gate * alpha)
  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
  const float glu = clamped_gate * sigmoid_val;

  // (up + 1) * glu
  return (T)((clamped_up + 1.0f) * glu);
}
150+
template <typename scalar_t,
          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
                             const float)>
__global__ void swigluoai_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d, const float alpha, const float limit) {
  // One block per token; threads of the block stride across the d outputs.
  const int64_t token_idx = blockIdx.x;
  const scalar_t* token_in = input + token_idx * 2 * d;
  scalar_t* token_out = out + token_idx * d;
  // TODO: Vectorize loads and stores.
  for (int64_t elem = threadIdx.x; elem < d; elem += blockDim.x) {
    // Interleaved layout: gate = x[..., ::2] (even), up = x[..., 1::2] (odd).
    const scalar_t gate = VLLM_LDG(&token_in[2 * elem]);
    const scalar_t up = VLLM_LDG(&token_in[2 * elem + 1]);
    token_out[elem] = ACT_FN(gate, up, alpha, limit);
  }
}
169+
131170} // namespace vllm
132171
133172#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM (KERNEL, PARAM ) \
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
145184 PARAM); \
146185 });
147186
// Launches vllm::swigluoai_and_mul_kernel: one block per token, up to 1024
// threads striding over the d = input.size(-1) / 2 output elements. KERNEL
// is the per-element activation (e.g. vllm::swigluoai_and_mul); ALPHA and
// LIMIT are forwarded to it (narrowed double -> float at the call site).
// NOTE: multi-statement macro — intended to expand as the entire body of a
// host wrapper that has `out` and `input` tensors in scope; do not use it
// inside an unbraced if/else.
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                   \
  int d = input.size(-1) / 2;                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                  \
  dim3 grid(num_tokens);                                                \
  dim3 block(std::min(d, 1024));                                        \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
  VLLM_DISPATCH_FLOATING_TYPES(                                         \
      input.scalar_type(), "swigluoai_and_mul_kernel", [&] {            \
        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>      \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),      \
                                         input.data_ptr<scalar_t>(), d, \
                                         ALPHA, LIMIT);                 \
      });
201+
// Gate-and-mul entry point using FATReLU as the activation; `threshold` is
// forwarded to vllm::fatrelu_kernel. Grid/block setup lives in the launch
// macro (presumably one block per token — macro body not fully visible here;
// TODO confirm against LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM).
void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                     torch::Tensor& input,  // [..., 2 * d]
                     double threshold) {
  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
// GPT-OSS style SwiGLU ("swigluoai") entry point. For each output element,
// with gate/up taken from the even/odd interleaved halves of `input`:
//   gate' = min(gate, limit); up' = clamp(up, -limit, limit);
//   out   = (up' + 1) * gate' * sigmoid(alpha * gate')
void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
                       torch::Tensor& input,  // [..., 2 * d]
                       double alpha, double limit) {
  // alpha/limit are narrowed double -> float when passed to the kernel.
  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
}
153212namespace vllm {
154213
155214// Element-wise activation kernel template.
0 commit comments