28 changes: 10 additions & 18 deletions python/paddle/amp/auto_cast.py
@@ -693,22 +693,13 @@ def master_grad_hook():
].append(param)
amp_global_state().already_classify_params_meshes = True

- if os.getenv("FLAGS_enable_tensor_fusion") not in [
- "True",
- "true",
- "1",
- ] and os.getenv("FLAGS_enable_main_grad") not in [
- "True",
- "true",
- "1",
- ]:
- if len(amp_global_state().mesh2params):
- for _, params in amp_global_state().mesh2params.items():
- core.eager.set_master_grads(params)
- else:
- core.eager.set_master_grads(
- amp_global_state().model_parameters
- )
+ if len(amp_global_state().mesh2params):
+ for _, params in amp_global_state().mesh2params.items():
+ core.eager.set_master_grads(params)
+ else:
+ core.eager.set_master_grads(
+ amp_global_state().model_parameters
+ )

amp_global_state().already_register_final_backward_hook = False

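
The hunk above removes the FLAGS_enable_tensor_fusion / FLAGS_enable_main_grad environment checks, so master gradients are now always installed: per mesh group when the parameters were classified by mesh, otherwise over the flat parameter list. A minimal sketch of that control flow, with a hypothetical apply_master_grads() standing in for core.eager.set_master_grads:

```python
# Sketch only: apply_master_grads is a hypothetical stand-in for
# core.eager.set_master_grads; mesh2params maps each mesh to the parameters
# placed on it (built earlier in master_grad_hook).
def apply_master_grads(params):
    for p in params:
        # placeholder for the real behavior: keep a master copy of the grad
        p.master_grad = getattr(p, "grad", None)


def set_master_grads_for_all(mesh2params, model_parameters):
    if mesh2params:
        # parameters were grouped by mesh: handle each group separately
        for params in mesh2params.values():
            apply_master_grads(params)
    else:
        # no mesh classification: fall back to the flat parameter list
        apply_master_grads(model_parameters)
```
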
@@ -750,8 +741,9 @@ def param_hook(tmp_grad):
if not hasattr(param, "main_grad"):
param.main_grad = None
param._register_grad_hook(_update_main_grad_hook(param))

- core.eager._add_backward_final_hook(master_grad_hook)
+ os.environ["FLAGS_enable_tensor_fusion"] = "0"
+ else:
+ core.eager._add_backward_final_hook(master_grad_hook)
amp_global_state().already_register_final_backward_hook = True

if tracer:
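
In the second auto_cast.py hunk the two paths become an explicit either/or: when FLAGS_enable_main_grad is on, each parameter gets a grad hook that maintains main_grad and FLAGS_enable_tensor_fusion is forced to "0"; otherwise the single backward-final master_grad_hook is registered. A rough sketch of that choice, where register_param_hook and register_final_hook are hypothetical stand-ins for param._register_grad_hook(_update_main_grad_hook(param)) and core.eager._add_backward_final_hook(master_grad_hook):

```python
import os


def register_grad_handling(parameters, use_main_grad,
                           register_param_hook, register_final_hook):
    # Sketch of the registration choice; the two callbacks are stand-ins for
    # the Paddle calls named in the lead-in above.
    if use_main_grad:
        for param in parameters:
            if not hasattr(param, "main_grad"):
                param.main_grad = None
                register_param_hook(param)  # per-parameter grad hook
        # main_grad path: tensor fusion is switched off explicitly
        os.environ["FLAGS_enable_tensor_fusion"] = "0"
    else:
        register_final_hook()  # one hook that runs after the whole backward pass
```
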
29 changes: 12 additions & 17 deletions python/paddle/distributed/auto_parallel/api.py
@@ -1163,14 +1163,8 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
self._mp_group = None
self.do_tensor_fusion_once = True
self._strategy = Strategy()
- self.enable_tensor_fusion = os.getenv("FLAGS_enable_tensor_fusion") in [
- "True",
- "true",
- "1",
- ]
- self.enable_sharding_overlap = os.getenv(
- "FLAGS_enable_sharding_overlap"
- ) in ["True", "true", "1"]
+ self.enable_tensor_fusion = False
+ self.enable_sharding_overlap = False

def _set_and_check_sharding_prop_from_param(self):
global_mesh = fleet.auto.get_mesh()
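
The constructor no longer derives enable_tensor_fusion / enable_sharding_overlap from environment variables; both start as False and are switched on later through _enable_tensor_fusion() / _enable_sharding_overlap(). For reference, the removed idiom behaves like the helper below, which recognizes only three spellings of "on" (a sketch, not Paddle code):

```python
import os


def env_flag(name):
    # The idiom removed above: only these exact spellings count as enabled,
    # so values such as "TRUE", "yes" or "2" silently leave the feature off.
    return os.getenv(name) in ["True", "true", "1"]


# An unset variable also reads as disabled.
assert env_flag("FLAGS_some_unset_example_flag") is False
```
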
@@ -1507,14 +1501,14 @@ def get_mesh(pp_idx=0):
self.param_storage[idx].is_sync = False

def _enable_tensor_fusion(self):
- # TODO: enable after clear FLAGS_enable_tensor_fusion
- # self.enable_tensor_fusion = True
- pass
+ os.environ["FLAGS_enable_tensor_fusion"] = "1"
+ self.enable_tensor_fusion = True
+ self._shard_fn._enable_tensor_fusion()

def _enable_sharding_overlap(self, layers):
if hasattr(layers, 'config') and layers.config.get("to_static", False):
return
- # self.enable_sharding_overlap = True
+ self.enable_sharding_overlap = True
if not isinstance(layers, paddle.nn.Layer):
raise RuntimeError(
f"`layers` must be `paddle.nn.Layer` but got {type(layers)}"
@@ -1951,15 +1945,19 @@ def __init__(self, mesh, sharding_mesh_dim):
self._mesh = mesh
self._sharding_axis = 0
self._sharding_mesh_dim = sharding_mesh_dim
+ self.enable_tensor_fusion = False

def _set_sharding_axis(self, sharding_axis):
self._sharding_axis = sharding_axis

+ def _enable_tensor_fusion(self):
+ self.enable_tensor_fusion = True

def shard_master_weight(
self, param: Tensor, master_weight: Tensor
) -> Tensor:
if param.is_dist():
- if os.getenv("FLAGS_enable_tensor_fusion") in ["True", "true", "1"]:
+ if self.enable_tensor_fusion:
placements = param.placements
else:
placements = get_placement_with_sharding(
@@ -2095,10 +2093,7 @@ def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
return tensor

# Only deal with momentum in optimizer, beta should be replicated cross param's mesh
- if (
- os.getenv("FLAGS_enable_tensor_fusion") not in ["True", "true", "1"]
- and 'beta' not in key
- ):
+ if not self.enable_tensor_fusion and 'beta' not in key:
placements = get_placement_with_sharding(param, self._sharding_axis)
else:
placements = [
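
The last two api.py hunks swap the environment check for the shard_fn's own enable_tensor_fusion attribute when choosing placements: in shard_master_weight the master weight simply follows the parameter's existing placements once fusion is on, and in __call__ the same attribute (together with the 'beta' check) decides whether an optimizer state is sharded via get_placement_with_sharding. A rough stand-alone rendering of the master-weight branch, with a hypothetical shard_along helper in place of get_placement_with_sharding:

```python
def master_weight_placements(param_placements, enable_tensor_fusion, sharding_axis):
    # Sketch of the branch in shard_master_weight above; not Paddle's API.
    def shard_along(axis):
        # hypothetical stand-in for get_placement_with_sharding(param, axis)
        return [f"Shard({axis})"]

    if enable_tensor_fusion:
        # fused parameter storage: keep whatever placements the parameter has
        return list(param_placements)
    # unfused: the master weight is sharded along the sharding axis
    return shard_along(sharding_axis)
```
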
8 changes: 3 additions & 5 deletions python/paddle/optimizer/optimizer.py
@@ -2016,11 +2016,9 @@ def step(self) -> None:
for param in self._param_groups:
if param.stop_gradient:
continue
- if os.getenv("FLAGS_enable_tensor_fusion") in [
- "True",
- "true",
- "1",
- ] or os.getenv("FLAGS_enable_main_grad") in [
+ if getattr(self, 'enable_tensor_fusion', False) or os.getenv(
+ "FLAGS_enable_main_grad"
+ ) in [
"True",
"true",
"1",
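
In Optimizer.step() the tensor-fusion side of the condition is now read from the optimizer itself, with getattr and a False default so plain optimizers that never define the attribute keep their old behavior, while main_grad is still driven by its environment flag. A small sketch of the resulting predicate (helper name is illustrative):

```python
import os


def needs_main_grad_path(optimizer):
    # Mirrors the condition in the hunk above: attribute-driven tensor fusion
    # OR the FLAGS_enable_main_grad environment switch.
    return getattr(optimizer, 'enable_tensor_fusion', False) or os.getenv(
        "FLAGS_enable_main_grad"
    ) in ["True", "true", "1"]
```
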
8 changes: 0 additions & 8 deletions test/auto_parallel/semi_auto_parallel_sharding_stage_1.py
@@ -135,9 +135,6 @@ def test_sharding_stage_1_overlap_to_static(self):

def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion(self):
def run_sharding_test(enable_tensor_fusion):
- os.environ['FLAGS_enable_tensor_fusion'] = (
- '1' if enable_tensor_fusion else '0'
- )
paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
paddle.seed(self._seed)
model = paddle.nn.Linear(10, 10)
@@ -169,7 +166,6 @@ def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion_with_chip(
self,
):
dist.init_parallel_env()
- os.environ['FLAGS_enable_tensor_fusion'] = '1'
paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
paddle.seed(self._seed)
model = paddle.nn.Linear(10, 10)
@@ -195,10 +191,6 @@ def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion_with_chip(

def test_pure_sharding_multi_mesh_stage_1_with_sharding_overlap(self):
def run_sharding_test(enable_sharding_overlap):
- os.environ['FLAGS_enable_tensor_fusion'] = '1'
- os.environ['FLAGS_enable_sharding_overlap'] = (
- '1' if enable_sharding_overlap else '0'
- )
paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
paddle.seed(self._seed)
model = paddle.nn.Linear(10, 10)
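
The test no longer flips FLAGS_enable_tensor_fusion / FLAGS_enable_sharding_overlap before building the model; the replacement lines fall outside the hunks shown here. Given the API change above, the toggle would plausibly go through the optimizer, along the lines of the hypothetical helper below (opt stands for the dist.shard_optimizer(...) result; this is an illustration, not the literal test code):

```python
def toggle_tensor_fusion(opt, enable_tensor_fusion):
    # hypothetical replacement for the removed os.environ assignment
    if enable_tensor_fusion:
        opt._enable_tensor_fusion()  # explicit switch on the sharded optimizer
```
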