112 changes: 51 additions & 61 deletions paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
@@ -69,14 +69,15 @@ Buffer::Buffer(int rank,
calc_ctx = reinterpret_cast<phi::GPUContext*>(
reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
->GetDeviceContext(place, true));
// Task fifo memory
int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;

// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);

// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
(num_nvl_bytes <= std::numeric_limits<int64_t>::max() ||
(num_nvl_bytes <= std::numeric_limits<int>::max() ||
num_rdma_bytes == 0));
EP_HOST_ASSERT(
num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
@@ -90,40 +91,35 @@ Buffer::Buffer(int rank,
EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode);

// Get ranks
// CUDA_CHECK(cudaGetDevice(&device_id));
rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS),
num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);

// Get device info
cudaDeviceProp device_prop = {};
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));

if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handle
CUDA_CHECK(cudaMalloc(
&buffer_ptrs[nvl_rank],
num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
// Local IPC: alloc local memory and set local IPC handles
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank],
num_nvl_bytes + barrier_signal_bytes +
buffer_ptr_bytes + barrier_signal_ptr_bytes));
CUDA_CHECK(
cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
buffer_ptrs_gpu = reinterpret_cast<void**>(
reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
fifo_bytes);

// Set task fifo
EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(
reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
task_fifo_ptrs_gpu = reinterpret_cast<int**>(
reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
fifo_bytes + buffer_ptr_bytes);
buffer_ptrs_gpu =
reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) +
num_nvl_bytes + barrier_signal_bytes);

// Set barrier signals
barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(
static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
barrier_signal_ptrs_gpu = reinterpret_cast<int**>(
static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
barrier_signal_bytes + buffer_ptr_bytes);

// No need to synchronize, will do a full device sync during `sync`
CUDA_CHECK(cudaMemsetAsync(
buffer_ptrs[nvl_rank],
0,
num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes,
comm_stream));
barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
}

// Create 32 MiB workspace
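Illustration (not part of the diff): the constructor now packs four regions into the single cudaMalloc above — the NVL data buffer, one barrier-signal int per peer, and the two pointer tables the kernels read. A minimal sketch of the offset arithmetic, with NUM_MAX_NVL_PEERS assumed to be 8:

```cpp
#include <cstdint>

constexpr int64_t kMaxNvlPeers = 8;  // assumed stand-in for NUM_MAX_NVL_PEERS

struct NvlBufferLayout {
  int64_t data = 0;                      // communication data, [0, num_nvl_bytes)
  int64_t barrier_signals = 0;           // one int per NVL peer
  int64_t buffer_ptr_table = 0;          // void* per peer -> buffer_ptrs_gpu
  int64_t barrier_signal_ptr_table = 0;  // int* per peer -> barrier_signal_ptrs_gpu
  int64_t total_bytes = 0;               // size handed to cudaMalloc
};

inline NvlBufferLayout MakeNvlBufferLayout(int64_t num_nvl_bytes) {
  NvlBufferLayout l;
  l.data = 0;
  l.barrier_signals = num_nvl_bytes;
  l.buffer_ptr_table = l.barrier_signals + kMaxNvlPeers * sizeof(int);
  l.barrier_signal_ptr_table = l.buffer_ptr_table + kMaxNvlPeers * sizeof(void*);
  l.total_bytes = l.barrier_signal_ptr_table + kMaxNvlPeers * sizeof(int*);
  return l;
}
```

The cudaMemsetAsync now clears only the barrier-signal region, presumably because the pointer tables are overwritten wholesale during `sync` and the data region needs no zeroing.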
@@ -165,8 +161,7 @@ Buffer::~Buffer() noexcept(false) {
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(
task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
move_fifo_slots();
barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());

// Close remote IPC
@@ -197,10 +192,6 @@ Buffer::~Buffer() noexcept(false) {
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}

void Buffer::move_fifo_slots(int num_slots) {
head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
}

bool Buffer::is_available() const { return available; }

bool Buffer::is_internode_available() const {
@@ -249,7 +240,7 @@ void Buffer::sync(

// Sync IPC handles
if (num_nvl_bytes > 0) {
EP_HOST_ASSERT(num_ranks == static_cast<int64_t>(device_ids.size()));
EP_HOST_ASSERT(num_ranks == device_ids.size());
EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks;
++i) {
@@ -261,22 +252,22 @@
ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(
&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
task_fifo_ptrs[i] = reinterpret_cast<int*>(
reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
barrier_signal_ptrs[i] = reinterpret_cast<int*>(
static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved,
handle_str.c_str(),
CUDA_IPC_HANDLE_SIZE) == 0);
}
}

// Copy all buffer and task pointers to GPU
// Copy all buffer and barrier signal pointers to GPU
CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu,
buffer_ptrs,
sizeof(void*) * NUM_MAX_NVL_PEERS,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu,
task_fifo_ptrs,
CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu,
barrier_signal_ptrs,
sizeof(int*) * NUM_MAX_NVL_PEERS,
cudaMemcpyHostToDevice));
CUDA_CHECK(cudaDeviceSynchronize());
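Illustration (not part of the diff): the loop above deserializes each peer's IPC handle, maps the peer's buffer, and derives its barrier-signal pointer from the same layout the constructor established. A hedged host-side sketch, with error handling reduced to an early return:

```cpp
#include <cstdint>
#include <cstring>
#include <string>
#include <cuda_runtime.h>

// Sketch only; `handle_str` and `num_nvl_bytes` mirror the names above.
bool OpenPeerBuffer(const std::string& handle_str, int64_t num_nvl_bytes,
                    void** peer_buffer, int** peer_signals) {
  cudaIpcMemHandle_t handle;
  std::memcpy(handle.reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
  if (cudaIpcOpenMemHandle(peer_buffer, handle,
                           cudaIpcMemLazyEnablePeerAccess) != cudaSuccess)
    return false;
  // The peer's barrier signals sit immediately after its data region.
  *peer_signals = reinterpret_cast<int*>(
      static_cast<uint8_t*>(*peer_buffer) + num_nvl_bytes);
  return true;
}
```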
@@ -520,7 +511,7 @@ Buffer::intranode_dispatch(

// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -529,6 +520,8 @@
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}

// Allocate all tensors on comm stream if set
@@ -564,12 +557,10 @@
intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(),
num_memset_int,
buffer_ptrs_gpu,
task_fifo_ptrs_gpu,
head,
barrier_signal_ptrs_gpu,
rank,
num_ranks,
comm_stream);
move_fifo_slots(2);
} else {
rank_prefix_matrix = ConvertPaddleTensorToDetailTensor(
paddle::experimental::empty({num_ranks, num_ranks},
@@ -604,12 +595,10 @@
num_memset_int,
expert_alignment,
buffer_ptrs_gpu,
task_fifo_ptrs_gpu,
head,
barrier_signal_ptrs_gpu,
rank,
comm_stream,
num_channels);
move_fifo_slots(3);

// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
@@ -719,10 +708,13 @@
is_token_in_rank.data_ptr<bool>(),
channel_prefix_matrix.data_ptr<int>(),
num_tokens,
0, // num_worst_tokens (not exposed)
static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
num_topk,
num_experts,
num_scales,
scale_token_stride,
scale_hidden_stride,
buffer_ptrs_gpu,
rank,
num_ranks,
@@ -867,15 +859,11 @@ Buffer::intranode_combine(
num_channels,
num_recv_tokens,
num_channels * num_ranks * 2,
task_fifo_ptrs_gpu,
head,
barrier_signal_ptrs_gpu,
rank,
num_ranks,
comm_stream);

// NOTES: this function uses two FIFO slots (barrier before and after)
move_fifo_slots(2);

// Combine data
auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
{num_recv_tokens, hidden}, x.dtype(), x.place()));
@@ -895,6 +883,8 @@ Buffer::intranode_combine(
recv_topk_weights_ptr,
x.data_ptr(),
topk_weights_ptr,
nullptr, // bias_ptrs[0] (not exposed)
nullptr, // bias_ptrs[1] (not exposed)
src_idx.data_ptr<int>(),
rank_prefix_matrix.data_ptr<int>(),
channel_prefix_matrix.data_ptr<int>(),
@@ -1084,7 +1074,7 @@ Buffer::internode_dispatch(

// FP8 scales checks
float* x_scales_ptr = nullptr;
int num_scales = 0;
int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
if (x_scales.has_value()) {
EP_HOST_ASSERT(x.element_size() == 1);
EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -1093,6 +1083,8 @@
EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
x_scales_ptr = x_scales->data_ptr<float>();
scale_token_stride = static_cast<int>(x_scales->stride(0));
scale_hidden_stride = static_cast<int>(x_scales->stride(1));
}

// Allocate all tensors on comm stream if set
Expand Down Expand Up @@ -1144,15 +1136,13 @@ Buffer::internode_dispatch(
config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu,
config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu,
head,
barrier_signal_ptrs_gpu,
rank,
comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes,
true,
low_latency_mode);
move_fifo_slots(2);
} else {
rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
paddle::experimental::empty({num_rdma_ranks, num_channels},
@@ -1196,14 +1186,12 @@
config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu,
config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu,
head,
barrier_signal_ptrs_gpu,
rank,
comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes,
low_latency_mode);
move_fifo_slots(3);

// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
@@ -1320,12 +1308,14 @@
recv_rdma_rank_prefix_sum.data_ptr<int>(),
gbl_channel_prefix_matrix.data_ptr<int>(),
recv_gbl_rank_prefix_sum.data_ptr<int>(),
is_token_in_rank.data_ptr<bool>(),
num_tokens,
hidden_int4,
num_scales,
num_topk,
num_experts,
is_token_in_rank.data_ptr<bool>(),
scale_token_stride,
scale_hidden_stride,
rdma_buffer_ptr,
config.num_max_rdma_chunked_send_tokens,
config.num_max_rdma_chunked_recv_tokens,
@@ -1523,15 +1513,13 @@ Buffer::internode_combine(
config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu,
config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu,
head,
barrier_signal_ptrs_gpu,
rank,
comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes,
false,
low_latency_mode);
move_fifo_slots(2);

// Launch data combine
auto combined_x =
@@ -1543,6 +1531,8 @@
is_combined_token_in_rank.data_ptr<bool>(),
x.data_ptr(),
topk_weights_ptr,
nullptr, // bias_ptrs[0] (not exposed)
nullptr, // bias_ptrs[1] (not exposed)
combined_rdma_head.data_ptr<int>(),
combined_nvl_head.data_ptr<int>(),
src_meta.data_ptr(),
10 changes: 3 additions & 7 deletions paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
@@ -77,10 +77,9 @@ struct Buffer {
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;

// Task fifo
int head = 0;
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** task_fifo_ptrs_gpu = nullptr;
// Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr;

// Workspace
void* workspace = nullptr;
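With the FIFO head and task_fifo_ptrs gone, ranks synchronize through these per-peer signal words instead. A hypothetical sketch of how such a signal barrier can work — not the actual intranode::barrier kernel, which lives in the CUDA sources:

```cuda
// Hypothetical signal barrier (assumes sm_60+ for *_system atomics and
// blockDim.x >= num_ranks); NOT the real intranode::barrier.
__global__ void naive_nvl_barrier(int** barrier_signal_ptrs, int rank,
                                  int num_ranks) {
  const int peer = static_cast<int>(threadIdx.x);
  if (peer < num_ranks) {
    // Announce arrival in each peer's signal array (over NVLink via IPC).
    atomicAdd_system(&barrier_signal_ptrs[peer][rank], 1);
    // Wait for the matching announcement, then consume it so the signal
    // words are immediately reusable by the next barrier.
    while (atomicAdd_system(&barrier_signal_ptrs[rank][peer], 0) == 0) {
    }
    atomicSub_system(&barrier_signal_ptrs[rank][peer], 1);
  }
  __syncthreads();
}
```

A launch shaped like naive_nvl_barrier<<<1, NUM_MAX_NVL_PEERS, 0, comm_stream>>>(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks) would replace one FIFO barrier round; since nothing advances through slots, the head/move_fifo_slots() bookkeeping disappears from every call site, as the next hunk shows.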
@@ -97,9 +96,6 @@ struct Buffer {
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;

private:
void move_fifo_slots(int num_slots = 1);

public:
Buffer(int rank,
int num_ranks,
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/collective/deep_ep/include/types.h
@@ -73,6 +73,8 @@ struct Tensor {
}

int64_t element_size() const { return phi::SizeOf(raw_tensor_.dtype()); }

int64_t stride(int64_t d) const { return raw_tensor_.strides().at(d); }
};

} // namespace deep_ep::detail
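This new accessor is what lets the host side pass scale_token_stride / scale_hidden_stride to the dispatch kernels instead of assuming a contiguous scale tensor. An illustrative (hypothetical) helper, assuming a 2-D [num_tokens, num_scales] float tensor:

```cpp
#include <cstdint>

// Sketch only: strided indexing as the dispatch path now performs it.
float read_scale(deep_ep::detail::Tensor& x_scales, int64_t token,
                 int64_t scale_idx) {
  const int64_t scale_token_stride = x_scales.stride(0);
  const int64_t scale_hidden_stride = x_scales.stride(1);
  return x_scales.data_ptr<float>()[token * scale_token_stride +
                                    scale_idx * scale_hidden_stride];
}
```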