@@ -982,6 +982,24 @@ def get_pp_group() -> GroupCoordinator:
     return _PP
 
 
+_CP: Optional[GroupCoordinator] = None
+
+
+def get_cp_group() -> GroupCoordinator:
+    assert _CP is not None, ("context parallel group is not initialized")
+    return _CP
+
+
+def get_context_model_parallel_world_size():
+    """Return world size for the context model parallel group."""
+    return get_cp_group().world_size
+
+
+def get_context_model_parallel_rank():
+    """Return my rank for the context model parallel group."""
+    return get_cp_group().rank_in_group
+
+
 @deprecated("`get_pipeline_model_parallel_group` has been replaced with "
             "`get_pp_group` and may be removed in v0.12. Please use "
             "`get_pp_group` instead.")
@@ -1088,6 +1106,7 @@ def init_distributed_environment(world_size: int = -1,
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    context_model_parallel_size: int = 1,
     decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
 ) -> None:
@@ -1138,7 +1157,7 @@ def initialize_model_parallel(
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
         -1, data_parallel_size, pipeline_model_parallel_size,
-        tensor_model_parallel_size)  # noqa
+        context_model_parallel_size, tensor_model_parallel_size)  # noqa
 
     # Build the tensor model-parallel groups.
     global _TP
@@ -1174,7 +1193,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(2, 3).reshape(
+    group_ranks = all_ranks.transpose(2, 4).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
@@ -1185,7 +1204,7 @@ def initialize_model_parallel(
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
     group_ranks = all_ranks.transpose(1,
-                                      3).reshape(-1,
+                                      4).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(group_ranks,
@@ -1196,23 +1215,34 @@ def initialize_model_parallel(
     global _EP
     assert _EP is None, ("expert parallel group is already initialized")
     group_ranks = all_ranks.transpose(1, 2).reshape(
-        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+        -1, data_parallel_size * tensor_model_parallel_size * context_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _EP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="ep")
 
+    global _CP
+    assert _CP is None, ("context parallel group is already initialized")
+    group_ranks = all_ranks.transpose(3, 4).reshape(
+        -1, context_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _CP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="cp")
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s, CP rank %s", rank, world_size,
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
-        _EP.rank_in_group)
+        _EP.rank_in_group, _CP.rank_in_group)
 
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    context_model_parallel_size: int,
     decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
 ) -> None:
@@ -1225,6 +1255,7 @@ def ensure_model_parallel_initialized(
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
                                   pipeline_model_parallel_size,
+                                  context_model_parallel_size,
                                   decode_context_model_parallel_size, backend)
         return
 
@@ -1238,6 +1269,11 @@ def ensure_model_parallel_initialized(
         "pipeline parallel group already initialized, but of unexpected size. "
         f"got: {pp_world_size=} vs. "
         f"wanted: {pipeline_model_parallel_size=}")
+    cp_world_size = get_cp_group().world_size
+    assert (cp_world_size == context_model_parallel_size), (
+        "context parallel group already initialized, but of unexpected size: "
+        f"{cp_world_size=} vs. "
+        f"{context_model_parallel_size=}")
 
 
 def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1345,6 +1381,11 @@ def destroy_model_parallel():
         _EP.destroy()
     _EP = None
 
+    global _CP
+    if _CP:
+        _CP.destroy()
+    _CP = None
+
 
 def destroy_distributed_environment():
     global _WORLD, _NODE_COUNT
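
For reference, a small standalone sketch of how the 5-D rank grid produces the per-group rank lists after the transpose-index changes above. The parallel sizes below are illustrative assumptions only and are not taken from the diff:

```python
import torch

# Illustrative sizes (assumed for this sketch only).
tp, cp, pp, dp = 2, 2, 2, 1
world_size = tp * cp * pp * dp  # 8 ranks total

# Same layout as the updated reshape: [outer, DP, PP, CP, TP].
all_ranks = torch.arange(world_size).reshape(-1, dp, pp, cp, tp)

# TP groups: contiguous ranks along the last dimension.
tp_groups = [x.tolist() for x in all_ranks.reshape(-1, tp).unbind(0)]

# CP groups: transpose(3, 4) moves the CP axis last, so each group spans CP.
cp_groups = [x.tolist() for x in all_ranks.transpose(3, 4).reshape(-1, cp).unbind(0)]

# PP groups: transpose(2, 4) (previously transpose(2, 3)) moves the PP axis last.
pp_groups = [x.tolist() for x in all_ranks.transpose(2, 4).reshape(-1, pp).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(cp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(pp_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]
```

Each parallelism type is carved out by transposing its axis to the last position and flattening the rest; inserting the CP axis pushes TP from position 3 to 4, which is why the PP and DP transposes now swap with dim 4 instead of dim 3.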