14 changes: 14 additions & 0 deletions csrc/generation/get_padding_offset_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

__global__ void RemovePaddingV2(int64_t *output_data,
37 changes: 37 additions & 0 deletions csrc/generation/helper.h
@@ -15,12 +15,44 @@
#pragma once

#include "paddle/extension.h"
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#include <hipcub/hipcub.hpp>
#include <hiprand.h>
#include <hiprand_kernel.h>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#include <curand_kernel.h>
#endif

constexpr int kBlockSize = 256;
constexpr int kNumWaves = 16;

#ifdef PADDLE_WITH_HIP
inline hipError_t GetNumBlocks(int64_t n, int* num_blocks) {
int dev;
{
hipError_t err = hipGetDevice(&dev);
if (err != hipSuccess) { return err; }
}
int sm_count;
{
hipError_t err = hipDeviceGetAttribute(&sm_count, hipDeviceAttributeMultiprocessorCount, dev);
if (err != hipSuccess) { return err; }
}
int tpm;
{
hipError_t err = hipDeviceGetAttribute(&tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev);
if (err != hipSuccess) { return err; }
}
*num_blocks = std::max<int>(1, std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return hipSuccess;
}
#else
inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
int dev;
{
@@ -41,6 +73,7 @@ inline cudaError_t GetNumBlocks(int64_t n, int* num_blocks) {
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}
#endif

template<typename T>
__device__ T max_func(const T a, const T b) {
@@ -74,7 +107,11 @@ class PDTraits<paddle::DataType::FLOAT16> {
template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
#ifdef PADDLE_WITH_HIP
typedef hip_bfloat16 DataType;
#else
typedef __nv_bfloat16 DataType;
#endif
typedef paddle::bfloat16 data_t;
};

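Reviewer sketch, not part of the diff: how a GetNumBlocks-style helper is typically consumed. The heuristic launches ceil(n / kBlockSize) blocks but never more than roughly kNumWaves waves per multiprocessor, so the caller pairs it with a grid-stride loop to cover all n elements with a capped grid. ScaleKernel and LaunchScale are hypothetical names; on ROCm the same shape holds with hipStream_t and hipSuccess.

// Illustrative only: a hypothetical elementwise kernel using GetNumBlocks from helper.h.
__global__ void ScaleKernel(const float* x, float* y, float alpha, int64_t n) {
  // Grid-stride loop so the capped grid still covers all n elements.
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    y[i] = alpha * x[i];
  }
}

void LaunchScale(const float* x, float* y, float alpha, int64_t n, cudaStream_t stream) {
  int num_blocks = 1;
  if (GetNumBlocks(n, &num_blocks) != cudaSuccess) { return; }  // sketch: bail out on error
  ScaleKernel<<<num_blocks, kBlockSize, 0, stream>>>(x, y, alpha, n);
}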
18 changes: 17 additions & 1 deletion csrc/generation/quant_int8.cu
@@ -22,8 +22,13 @@
#include<sys/mman.h>
#include<stdio.h>
#include<algorithm>
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
#else
#include<cuda_fp16.h>
#include<cuda_bf16.h>
#endif


constexpr int DequantKernelVecSize = 4;
@@ -52,11 +57,17 @@ __forceinline__ __device__ half add_mul<half>(half a, half b, half c) {
return __hmul(__hadd(a, b), c);
}

#ifdef PADDLE_WITH_HIP
template<>
__forceinline__ __device__ hip_bfloat16 add_mul<hip_bfloat16>(hip_bfloat16 a, hip_bfloat16 b, hip_bfloat16 c) {
return (a + b) * c;
}
#else
template<>
__forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
return __hmul(__hadd(a, b), c);
}

#endif


template <typename data_t>
@@ -173,8 +184,13 @@ std::vector<paddle::Tensor> LaunchQuantInt8(const paddle::Tensor& input,
auto output=paddle::full(input_shape, -1, paddle::DataType::INT8, input.place());
int m = input_shape[0];
int n = input_shape[1];
#ifdef PADDLE_WITH_HIP
dim3 grid(((n >> 2) + 63) / 64, (m + 7) / 8);
dim3 block(64, 8);
#else
dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);
#endif
auto stream = input.stream();
if (shift && smooth) {
QuantKernel<DataType_><<<grid, block, 0, stream>>>(reinterpret_cast<const DataType_*>(input.data<data_t>()),
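Reviewer note, illustrative only: the add_mul specializations above exist so the shift/smooth pre-processing can run in the tensor's native 16-bit type on both backends. A minimal per-element sketch of that idea follows; QuantizeOne is a hypothetical helper, and the real QuantKernel may apply scale, rounding, and clipping in a different order.

// Sketch only: one plausible per-element INT8 quantization step built on add_mul.
template <typename T>
__device__ int8_t QuantizeOne(T x, T shift, T smooth, float scale) {
  // (x + shift) * smooth in native precision, then scale/round/clip in float.
  float v = static_cast<float>(add_mul<T>(x, shift, smooth)) * scale;
  v = fmaxf(-127.0f, fminf(127.0f, roundf(v)));
  return static_cast<int8_t>(v);
}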
7 changes: 6 additions & 1 deletion csrc/generation/rebuild_padding.cu
@@ -58,7 +58,12 @@ void InvokeRebuildPadding(T *output_data,
const int *padding_offset,
const int token_num,
const int dim_embed,
cudaStream_t stream) {
#ifdef PADDLE_WITH_HIP
hipStream_t stream
#else
cudaStream_t stream
#endif
) {
// src: [token_num, dim_embed]
// dst: [batch_size * max_seq_len, dim_embed]
RebuildPaddingKernel<<<token_num, 256, 0, stream>>>(
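Reviewer note: the same cudaStream_t/hipStream_t switch reappears in transpose_removing_padding.cu below. A sketch of an alternative that keeps the launcher signatures backend-neutral, assuming a single alias placed in helper.h (a suggestion only, not something this PR does):

#ifdef PADDLE_WITH_HIP
using gpuStream_t = hipStream_t;
#else
using gpuStream_t = cudaStream_t;
#endif

// Launcher signatures could then read, without per-file #ifdefs:
//   void InvokeRebuildPadding(T *output_data, const T *input_data,
//                             const int *padding_offset, const int token_num,
//                             const int dim_embed, gpuStream_t stream);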
14 changes: 14 additions & 0 deletions csrc/generation/rebuild_padding_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <typename T, int VecSize>
14 changes: 14 additions & 0 deletions csrc/generation/set_value_by_flags_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"

__global__ void set_value_by_flag_and_id_v2(const bool *stop_flags,
22 changes: 22 additions & 0 deletions csrc/generation/step.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// #define DEBUG_STEP
@@ -255,7 +269,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
max_decoder_block_num
);
#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
hipDeviceSynchronize();
#else
cudaDeviceSynchronize();
#endif
#endif
auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false);
const int grid_size = cpu_recover_lens.data<int>()[0];
@@ -287,7 +305,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
first_token_id
);
#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
hipDeviceSynchronize();
#else
cudaDeviceSynchronize();
#endif
#endif
}
}
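Reviewer note: the nested DEBUG_STEP / PADDLE_WITH_HIP guards appear twice inside StepPaddle. One way to collapse them, sketched here as a suggestion only (STEP_DEBUG_SYNC is a hypothetical name, not part of the diff):

// Device-wide sync only when DEBUG_STEP is defined, picking the right runtime per backend.
#ifdef DEBUG_STEP
#ifdef PADDLE_WITH_HIP
#define STEP_DEBUG_SYNC() hipDeviceSynchronize()
#else
#define STEP_DEBUG_SYNC() cudaDeviceSynchronize()
#endif
#else
#define STEP_DEBUG_SYNC() ((void)0)
#endif
// Each debug checkpoint in StepPaddle would then be a single STEP_DEBUG_SYNC();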
14 changes: 14 additions & 0 deletions csrc/generation/stop_generation_multi_ends_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include<stdlib.h>
#include<string.h>
14 changes: 14 additions & 0 deletions csrc/generation/token_penalty_multi_scores_v2.cu
@@ -1,3 +1,17 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"


7 changes: 6 additions & 1 deletion csrc/generation/transpose_removing_padding.cu
@@ -65,7 +65,12 @@ void InvokeTransposeRemovePadding(const T* input_data,
const int head_dim,
const int token_num,
const int* padding_offset,
cudaStream_t cu_stream) {
#ifdef PADDLE_WITH_HIP
hipStream_t cu_stream
#else
cudaStream_t cu_stream
#endif
) {
// [batch_size, num_head, max_len_this_time, head_dim] -> [token_num, num_head,
// head_dim]
constexpr int VEC_16B = 16;
35 changes: 31 additions & 4 deletions csrc/generation/write_int8_cache_kv.cu
@@ -14,8 +14,13 @@

#include "helper.h"

#ifdef PADDLE_WITH_HIP
constexpr int32_t WARP_SIZE = 64;
constexpr int32_t HALF_WARP = 32;
#else
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t HALF_WARP = 16;
#endif
constexpr float QUANT_MAX_BOUND = 127.0;
constexpr float QUANT_MIN_BOUND = -127.0;

@@ -47,14 +52,22 @@ struct MaxFunc{
template<>
struct MaxFunc<half>{
__device__ half operator()(half a, half b){
#if __CUDA_ARCH__ >= 800
#if (__CUDA_ARCH__ >= 800) || defined(PADDLE_WITH_HIP)
return __hmax(a, b);
#else
return max(static_cast<float>(a), static_cast<float>(b));
#endif
}
};

#ifdef PADDLE_WITH_HIP
template<>
struct MaxFunc<hip_bfloat16>{
__device__ hip_bfloat16 operator()(hip_bfloat16 a, hip_bfloat16 b){
return static_cast<hip_bfloat16>(max(static_cast<float>(a), static_cast<float>(b)));
}
};
#else
template<>
struct MaxFunc<__nv_bfloat16>{
__device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b){
@@ -65,6 +78,7 @@ struct MaxFunc<__nv_bfloat16>{
#endif
}
};
#endif

template<typename T>
struct AbsFunc{
@@ -76,14 +90,22 @@ struct AbsFunc{
template<>
struct AbsFunc<half>{
__device__ half operator()(half x){
#if __CUDA_ARCH__ >= 800
#if (__CUDA_ARCH__ >= 800) || defined(PADDLE_WITH_HIP)
return __habs(x);
#else
return abs(static_cast<float>(x));
#endif
}
};

#ifdef PADDLE_WITH_HIP
template<>
struct AbsFunc<hip_bfloat16>{
__device__ hip_bfloat16 operator()(hip_bfloat16 x) {
return static_cast<hip_bfloat16>(abs(static_cast<float>(x)));
}
};
#else
template<>
struct AbsFunc<__nv_bfloat16>{
__device__ __nv_bfloat16 operator()(__nv_bfloat16 x){
@@ -94,6 +116,7 @@ struct AbsFunc<__nv_bfloat16>{
#endif
}
};
#endif

template <typename T, typename Vec, int VecSize>
__inline__ __device__ T LocalReduceMax(Vec& vec) {
@@ -109,7 +132,11 @@ template <typename T>
__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) {
#pragma unroll
for (int mask = HALF_WARP; mask > 0; mask >>= 1){
#ifdef PADDLE_WITH_HIP
val = MaxFunc<T>()(val, static_cast<T>(__shfl_xor(static_cast<float>(val), mask, WARP_SIZE)));
#else
val = MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE));
#endif
}
return val;
}
@@ -147,7 +174,7 @@ __global__ void write_cache_k_int8_kernel(const T* k, const int64_t num_head, co
InVec abs_max_vec;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
abs_max_vec[i] = 0.0f;
abs_max_vec[i] = static_cast<T>(0.0f);
}

T local_abs_max;
@@ -205,7 +232,7 @@ __global__ void write_cache_v_int8_kernel(const T* v, const int64_t num_head, co
InVec abs_max_vec;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
abs_max_vec[i] = 0.0f;
abs_max_vec[i] = static_cast<T>(0.0f);
}

T local_abs_max;
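Reviewer note: the shuffle change above is the key portability point in this file. ROCm wavefronts are 64 lanes wide (hence WARP_SIZE 64, HALF_WARP 32) and use the unmasked __shfl_xor, while CUDA warps are 32 lanes wide and use __shfl_xor_sync with a full mask. A self-contained, float-only sketch of the same butterfly abs-max reduction, reusing the WARP_SIZE/HALF_WARP constants defined at the top of this file (illustrative, not from the PR):

// Sketch: butterfly abs-max reduction across one warp/wavefront of floats.
__device__ float WarpAbsMaxFloat(float val) {
  val = fabsf(val);
#pragma unroll
  for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
#ifdef PADDLE_WITH_HIP
    float other = __shfl_xor(val, mask, WARP_SIZE);                  // unmasked, 64-lane wavefront
#else
    float other = __shfl_xor_sync(0xffffffffu, val, mask, WARP_SIZE); // full-warp mask, 32 lanes
#endif
    val = fmaxf(val, other);
  }
  return val;  // every lane now holds the warp/wavefront-wide abs max
}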