38 changes: 38 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
@@ -2467,6 +2467,36 @@ static PyObject* tensor__clear_dataptr(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__clear_to_zero_allocation(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto* dense_tensor =
dynamic_cast<phi::DenseTensor*>(self->tensor.impl().get());
if (dense_tensor != nullptr && dense_tensor->Holder() != nullptr) {
phi::DenseTensor tmp(std::make_shared<phi::Allocation>(
nullptr, 0, dense_tensor->Holder()->place()),
dense_tensor->meta());
dense_tensor->ShareBufferWith(std::move(tmp), /*only_buffer=*/true);
}
RETURN_PY_NONE
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__holder_size(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY
auto* dense_tensor =
dynamic_cast<phi::DenseTensor*>(self->tensor.impl().get());
size_t size = 0;
if (dense_tensor != nullptr && dense_tensor->Holder() != nullptr) {
size = dense_tensor->Holder()->size();
}
return PyLong_FromSize_t(size);
EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__copy_gradient_from(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
@@ -3952,6 +3982,14 @@ PyMethodDef variable_methods[] = { // NOLINT
(PyCFunction)(void (*)())tensor__clear_dataptr,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_clear_to_zero_allocation",
(PyCFunction)(void (*)())tensor__clear_to_zero_allocation,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_holder_size",
(PyCFunction)(void (*)())tensor__holder_size,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_copy_gradient_from",
(PyCFunction)(void (*)())tensor__copy_gradient_from,
METH_VARARGS | METH_KEYWORDS,
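A minimal Python-side sketch of the two methods registered above, assuming they are exposed on paddle.Tensor exactly as listed in variable_methods; the tensor name, shape, and size comments are illustrative, and the allocator may report a rounded-up holder size.

import paddle

# Hypothetical usage of the new pybind methods (names taken from the diff above).
x = paddle.randn([1024, 1024], dtype="float32")

print(x._holder_size())        # bytes held by the allocation, roughly 4 * 1024 * 1024

x._clear_to_zero_allocation()  # swap the holder for a zero-size allocation on the
                               # same place while keeping the tensor meta

print(x._holder_size())        # 0
print(x.shape)                 # [1024, 1024] -- shape/dtype metadata is preserved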
36 changes: 36 additions & 0 deletions .../paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
100755 → 100644
@@ -45,6 +45,9 @@
g_sharding_v2_check_zero_padding = int(
os.getenv("FLAGS_sharding_v2_check_zero_padding", "0")
)
g_shard_bypass_dygraph_optimizer = int(
os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
)


def _is_trainable(param):
@@ -618,13 +621,15 @@ def __init__(self, optimizer, hcg):
self._hcg = hcg
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
self._sharding_rank = self._hcg.get_sharding_parallel_rank()
self.clear_color = None

self._parameter_list = optimizer._parameter_list

# param name -> slice_param
self._slice_params = {}
# comm_buffer_list = []
self._comm_buffer_list = []
self._color_to_comm_buffer_list = {}

# slice parameter list
self._local_parameter_list = [
@@ -814,12 +819,42 @@ def _build_comm_buffers(
group_idx += 1
self._comm_buffer_list.append(buffer)

if g_color not in self._color_to_comm_buffer_list.keys():
self._color_to_comm_buffer_list[g_color] = []
self._color_to_comm_buffer_list[g_color].append(buffer)

for p in parameters:
if p.name in self.param2bucket:
self.param2bucket[p.name].append(buffer)
else:
self.param2bucket[p.name] = [buffer]

def clear_param_storage(self, color):
self.clear_color = color
if color in self._color_to_comm_buffer_list.keys():
for comm_buffer in self._color_to_comm_buffer_list[color]:
for param in comm_buffer.params:
grad_view = comm_buffer._sharding_param_grad_view[
param.name
]
slice_param = self._slice_params[param.name]
if (
not g_shard_bypass_dygraph_optimizer
and grad_view._param_begin < grad_view._param_end
):
grad_view.fill_slice_param(slice_param)
self._create_master_weight(slice_param)
slice_param._clear_dataptr()
comm_buffer._clear_param_storage()

def reset_param_storage(self):
color = self.clear_color
if color is None:
return
if color in self._color_to_comm_buffer_list.keys():
for comm_buffer in self._color_to_comm_buffer_list[color]:
comm_buffer._reset_param_storage()

def clear_grad(self, set_to_zero=True):
"""
should clear grad for all parameters in model
@@ -1096,6 +1131,7 @@ def _assign_slice_grad(self):
def step(self):
# TODO Check whether the model trainable param changed and update state accordingly
# hack for pp comm overlap
self.reset_param_storage()

if self._all_gather_overlap_forward:
# Clear the pre forward hook in the optimizer step.
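As a rough, self-contained illustration of the clear/reset lifecycle that clear_param_storage and reset_param_storage implement for a fused parameter storage (reset_param_storage is triggered from step() above), the sketch below mimics the same steps on a single standalone tensor. It assumes the _clear_to_zero_allocation and _holder_size methods added in eager_method.cc; the tensor names and shape are illustrative.

import paddle

# "clear": drop the underlying allocation but keep shape/dtype metadata,
# roughly what clear_param_storage() does per comm buffer and slice param.
param_storage = paddle.zeros([1024])
param_storage._clear_to_zero_allocation()
assert param_storage._holder_size() == 0

# "reset": allocate fresh storage with the same meta and share it back in,
# mirroring _reset_param_storage()'s use of paddle.empty_like().
new_storage = paddle.empty_like(param_storage)
new_storage._share_buffer_to(param_storage)
assert param_storage._holder_size() > 0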
25 changes: 24 additions & 1 deletion python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -284,8 +284,8 @@ def fill_slice_param(self, slice_param):
slice_begin = self._param_begin
slice_end = self._param_end
slice_buffer = self._param_buffer._slice(slice_begin, slice_end)
slice_param.get_tensor()._set_dims([slice_end - slice_begin])
slice_buffer._share_buffer_to(slice_param)
slice_param.get_tensor()._set_dims([slice_end - slice_begin])

def assign_slice_grad(self, slice_param):
assert self._param_buffer._is_shared_buffer_with(self._param)
@@ -304,6 +304,16 @@ def assign_slice_grad(self, slice_param):
else:
assert slice_param.grad._is_shared_buffer_with(slice_grad)

def _clear_param_buffer(self):
self._param._clear_to_zero_allocation()
self._param_buffer._clear_to_zero_allocation()

def _reset_param_buffer(self, new_param_storage):
new_param = paddle.empty_like(self._param)
new_param._share_buffer_to(self._param)
new_param_storage._share_buffer_to(self._param_buffer)
self._share_param_buffer()

def _clear_grad_buffer(self):
if self._slice_grad is not None:
self._slice_grad._clear_dataptr()
@@ -553,6 +563,18 @@ def _record_addr(self):
param, self.use_main_grad
)

def _clear_param_storage(self):
self.param_storage._clear_to_zero_allocation()
for param in self._params:
self._sharding_param_grad_view[param.name]._clear_param_buffer()

def _reset_param_storage(self):
new_param_storage = paddle.empty_like(self.param_storage)
new_param_storage._share_buffer_to(self.param_storage)
for param in self._params:
grad_view = self._sharding_param_grad_view[param.name]
grad_view._reset_param_buffer(new_param_storage)

def _clear_grad_storage(self):
self.grad_storage._clear_dataptr()
self.grad_storage = None
@@ -775,6 +797,7 @@ def _comm_grads(self):
group=self._comm_group,
sync_op=False,
)

if self._free_grads_in_comm:
self._reset_grad_storage(reduce_scattered)

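One aliasing detail worth noting: the fused param_storage and the per-parameter views (_param / _param_buffer) share a single allocation, which is presumably why _clear_param_storage() clears the storage tensor and every view's buffer rather than the storage alone. The standalone sketch below illustrates that aliasing, assuming the _clear_to_zero_allocation and _holder_size methods added in eager_method.cc; shapes and variable names are illustrative.

import paddle

storage = paddle.zeros([16])                   # stand-in for a fused param_storage
view = paddle.empty([8])                       # stand-in for one parameter view
storage._slice(0, 8)._share_buffer_to(view)    # view aliases the first half of storage

storage._clear_to_zero_allocation()            # the storage tensor itself now holds nothing
print(storage._holder_size())                  # 0
print(view._holder_size())                     # still nonzero: the view keeps the original
                                               # allocation alive
view._clear_to_zero_allocation()               # clear the view as well, as
                                               # _clear_param_buffer() does
print(view._holder_size())                     # 0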