36 changes: 22 additions & 14 deletions paddle/phi/api/generator/dist_api_gen.py
@@ -87,8 +87,12 @@
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_XPU_BKCL)
const auto & comm_context_manager_ = phi::distributed::CommContextManager::GetInstance();
if (nranks > 1 && !comm_context_manager_.Has(std::to_string(ring_id))) {{
auto store = phi::distributed::CreateOrGetGlobalTCPStore();
CREATE_COMM_CONTEXT(store, std::to_string(ring_id), rank, nranks);
std::string store_key;
store_key = "nccl_ids/" + std::to_string(ring_id) + "/0";
if (!comm_context_manager_.Has(store_key)) {{
auto store = phi::distributed::CreateOrGetGlobalTCPStore();
CREATE_COMM_CONTEXT(store, std::to_string(ring_id), rank, nranks);
}}
}}
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
const auto & comm_context_manager_ = phi::distributed::CommContextManager::GetInstance();
@@ -103,9 +107,16 @@
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
const auto & comm_context_manager = phi::distributed::CommContextManager::GetInstance();
COMM_CONTEXT* comm_context = nullptr;
if (comm_context_manager.Has(std::to_string(ring_id))) {{
comm_context = static_cast<COMM_CONTEXT*>(
std::string store_key;
store_key = "nccl_ids/" + std::to_string(ring_id) + "/0";
if (comm_context_manager.Has(std::to_string(ring_id))||comm_context_manager.Has(store_key)) {{
if (comm_context_manager.Has(std::to_string(ring_id))) {{
comm_context = static_cast<COMM_CONTEXT*>(
comm_context_manager.Get(std::to_string(ring_id)));
}} else {{
comm_context = static_cast<COMM_CONTEXT*>(
comm_context_manager.Get(store_key));
}}
PADDLE_ENFORCE_NE(
comm_context,
nullptr,
@@ -114,17 +125,14 @@
"has ring_id(%d) attr.",
std::to_string(ring_id)));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_XPU_BKCL)
if (!comm_context->GetDevContext() || !comm_context->GetDevContext()->GetCommContext())
{{
auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{}", kernel_data_type);
}}
Backend act_kernel_backend = kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
dev_context->SetCommContext(comm_context);
auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
if (FLAGS_low_precision_op_list) {{
phi::KernelFactory::Instance().AddToLowPrecisionKernelList("{}", kernel_data_type);
}}
Backend act_kernel_backend = kernel_res.has_fallback_cpu ? Backend::CPU : kernel_backend;
auto* dev_context = GetDeviceContextByBackend(act_kernel_backend);
dev_context->SetCommContext(comm_context);
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
auto kernel_res = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"{}", {{kernel_backend, kernel_layout, kernel_data_type}}, true);
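Taken together, the generator change above makes the emitted code fall back to the per-group store key ("nccl_ids/<ring_id>/0") when no comm context is registered under the bare ring_id, so a context created through the newer group-initialization path is reused instead of being rebuilt. A minimal Python sketch of that lookup order, with a plain dict standing in for CommContextManager and a hypothetical helper name (not part of the PR):

def resolve_comm_context(registry: dict, ring_id: int):
    """Mimic the generated lookup: prefer the ring_id key, then fall back to the store key."""
    store_key = f"nccl_ids/{ring_id}/0"
    if str(ring_id) in registry:
        return registry[str(ring_id)]
    if store_key in registry:
        return registry[store_key]
    return None

# Example: a context registered only under the store key is still found.
registry = {"nccl_ids/3/0": "comm_context_for_ring_3"}
assert resolve_comm_context(registry, 3) == "comm_context_for_ring_3"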
13 changes: 13 additions & 0 deletions test/collective/CMakeLists.txt
@@ -451,5 +451,18 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_new_api_per_op_and_group_intranode
PROPERTIES TIMEOUT "120")
endif()
if((WITH_GPU) AND (LINUX))
bash_test_modules(
test_comm_group_num
START_BASH
test_comm_group_num.sh
TIMEOUT
"120"
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=22304;http_proxy=;https_proxy=")
set_tests_properties(test_comm_group_num PROPERTIES TIMEOUT "120")
endif()
add_subdirectory(fleet)
add_subdirectory(multinode)
62 changes: 62 additions & 0 deletions test/collective/test_comm_group_num.py
@@ -0,0 +1,62 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.base import topology as tp
from paddle.distributed.fleet.layers.mpu import mp_ops


class CommGroupNumTest(unittest.TestCase):
def test_comm_group_num(self):
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 2,
"pp_degree": 2,
"sharding_degree": 2,
"order": ["dp", "pp", "sharding", "sep", "mp"],
}
fleet.init(is_collective=True, strategy=strategy)

place = paddle.framework._current_expected_place()
input = np.random.uniform(
low=-2.0, high=2.0, size=(1, 4096, 16000)
).astype('float32')
input = paddle.to_tensor(input, place=place)
input.stop_gradient = False

label = np.random.randint(
low=1, high=29956, size=(1, 4096, 1), dtype='int64'
)
label = paddle.to_tensor(label, place=place)
label.stop_gradient = True

model_parallel_group = (
tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group()
)
loss = mp_ops._c_softmax_with_cross_entropy(
input,
label,
group=model_parallel_group,
ignore_index=-100,
)


if __name__ == '__main__':
unittest.main()
121 changes: 121 additions & 0 deletions test/collective/test_comm_group_num.sh
@@ -0,0 +1,121 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

function is_a100() {
if [ $(nvidia-smi|grep A100|wc -l) -ne 0 ];then
echo 1
else
echo 0
fi
}

if [ "$(is_a100)" == "1" ]; then
exit 0
fi

unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT
nnodes=$PADDLE_TRAINERS_NUM
rank=$PADDLE_TRAINER_ID

export NCCL_IB_QPS_PER_CONNECTION=8
export NCCL_DEBUG=INFO
# export NCCL_DEBUG=WARN
export NCCL_IB_TIMEOUT=22
export NCCL_IB_ADAPTIVE_ROUTING=0
# export NCCL_IB_GID_INDEX=3
export NCCL_NVLS_ENABLE=0
export NCCL_SOCKET_IFNAME=xgbe0
# export NCCL_DEBUG_SUBSYS=INIT,ENV,GRAPH,ALLOC
export NCCL_DEBUG_SUBSYS=INIT,COLL,TUNING,ALLOC
export NCCL_IB_HCA=mlx5_1,mlx5_8,mlx5_6,mlx5_4,mlx5_2,mlx5_9,mlx5_7,mlx5_5,mlx5_3
# export IB_GID_INDEX=3

for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
unset ${name}
done

START_RANK=0
END_RANK=1

if [[ $rank -lt $START_RANK ]]; then
exit 0
fi

if [[ $rank -ge $END_RANK ]]; then
exit 0
fi
rank=$(($rank-$START_RANK))
nnodes=$(($END_RANK-$START_RANK))
# master=`cat /root/paddlejob/workspace/hostfile | head -n $(($START_RANK+1)) | tail -n 1 | awk '{print $1}'`
# master=`cat hostfile | head -n $(($START_RANK+1)) | tail -n 1 | awk '{print $1}'`
if [ -f "/root/paddlejob/workspace/hostfile" ]; then
# hostfile exists, use it to determine the master as before
master=$(cat /root/paddlejob/workspace/hostfile | head -n $(($START_RANK+1)) | tail -n 1 | awk '{print $1}')
else
# hostfile does not exist, fall back to this machine's IP
master=$(hostname -I | awk '{print $1}') # get the local machine's IP
echo "hostfile not found, using current machine IP: $master"
fi
port=36677

version=2_21_5
new_api=4

if [ "$version" = "2_21_5" ]; then
export NCCL_RUNTIME_CONNECT=0
fi

# change root_path to your own path
root_path=/root/paddlejob/workspace/env_run/output/test_comm_group_num
task_name=llama2_13b_dynamic_hand_"$version"_"$new_api"
export NCCL_DEBUG_FILE=$root_path/Nccl/nccl_log/$task_name/%h.%p.log

export FLAGS_eager_communication_connection=1

export NNODES=1
export PADDLE_TRAINERS_NUM=1
export CUDA_DEVICE_MAX_CONNECTIONS=1

# log
export GLOG_v=3

log_dir="$root_path/Nccl/log/$task_name"

if [ -d $log_dir ]; then
rm -rf $log_dir
fi

shell_dir=$(dirname "$(readlink -f "$0")")

python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir ${log_dir} \
--master $master:$port \
--nnodes $nnodes \
--rank $rank \
--run_mode=collective \
$shell_dir/test_comm_group_num.py

count7=$(grep -c "init NCCLCommContext rank" "${log_dir}/workerlog.7")

if [ $count7 -ne 7 ]; then
echo -e "\033[31m test_comm_group_num failed, got ${count7}, expect 7 \033[0m"
exit 1
fi

rm -rf $root_path
1 change: 1 addition & 0 deletions test/collective/testslist.csv
@@ -52,3 +52,4 @@ test_mpi_comm,linux,,,DIST,test_mpi_comm.sh,2,,http_proxy=;https_proxy=,WITH_MPI
test_strategy_group,linux,rocm;gpu,120,DIST,test_strategy_group.sh,2,,http_proxy=;https_proxy=,
test_orthogonal_strategy,linux,rocm;gpu,120,DIST,test_orthogonal_strategy.sh,2,,http_proxy=;https_proxy=,
test_new_api_per_op_and_group_intranode,linux,gpu,120,DIST,test_new_api_per_op_and_group_intranode.sh,2,,http_proxy=;https_proxy=,
test_comm_group_num,linux,gpu,120,DIST,test_comm_group_num.sh,2,,http_proxy=;https_proxy=,