
Commit a9f48ec

support fp8 quant in sharding optimizer
1 parent 3efb8db commit a9f48ec

File tree

paddle/fluid/pybind/eager_method.cc
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py

3 files changed: +97 -0 lines changed


paddle/fluid/pybind/eager_method.cc

Lines changed: 38 additions & 0 deletions
@@ -2467,6 +2467,36 @@ static PyObject* tensor__clear_dataptr(TensorObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+static PyObject* tensor__clear_to_zero_allocation(TensorObject* self,
+                                                  PyObject* args,
+                                                  PyObject* kwargs) {
+  EAGER_TRY
+  auto* dense_tensor =
+      dynamic_cast<phi::DenseTensor*>(self->tensor.impl().get());
+  if (dense_tensor != nullptr && dense_tensor->Holder() != nullptr) {
+    phi::DenseTensor tmp(std::make_shared<phi::Allocation>(
+                             nullptr, 0, dense_tensor->Holder()->place()),
+                         dense_tensor->meta());
+    dense_tensor->ShareBufferWith(std::move(tmp), /*only_buffer=*/true);
+  }
+  RETURN_PY_NONE
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
+static PyObject* tensor__holder_size(TensorObject* self,
+                                     PyObject* args,
+                                     PyObject* kwargs) {
+  EAGER_TRY
+  auto* dense_tensor =
+      dynamic_cast<phi::DenseTensor*>(self->tensor.impl().get());
+  size_t size = 0;
+  if (dense_tensor != nullptr && dense_tensor->Holder() != nullptr) {
+    size = dense_tensor->Holder()->size();
+  }
+  return PyLong_FromSize_t(size);
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
@@ -3952,6 +3982,14 @@ PyMethodDef variable_methods[] = {  // NOLINT
      (PyCFunction)(void (*)())tensor__clear_dataptr,
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
+    {"_clear_to_zero_allocation",
+     (PyCFunction)(void (*)())tensor__clear_to_zero_allocation,
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
+    {"_holder_size",
+     (PyCFunction)(void (*)())tensor__holder_size,
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"_copy_gradient_from",
      (PyCFunction)(void (*)())tensor__copy_gradient_from,
      METH_VARARGS | METH_KEYWORDS,
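
Note: the two new bindings let Python code drop a tensor's device allocation while keeping its metadata (`_clear_to_zero_allocation`) and query the byte size of the current holder (`_holder_size`). A minimal sketch of the expected behavior, assuming a Paddle build that includes this commit; the tensor and printed sizes are illustrative:

import paddle

# Illustrative only: exercise the two private bindings registered above.
x = paddle.ones([1024, 1024], dtype="float32")
print(x._holder_size())        # expected 4194304 bytes (1024 * 1024 * 4) for the fp32 buffer
x._clear_to_zero_allocation()  # swap in a null, zero-sized allocation; keep shape/dtype meta
print(x.shape)                 # still [1024, 1024]; only the storage was released
print(x._holder_size())        # 0, since the holder now wraps an empty allocation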

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

mode changed 100755 → 100644
Lines changed: 36 additions & 0 deletions
@@ -45,6 +45,9 @@
 g_sharding_v2_check_zero_padding = int(
     os.getenv("FLAGS_sharding_v2_check_zero_padding", "0")
 )
+g_shard_bypass_dygraph_optimizer = int(
+    os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
+)
 
 
 def _is_trainable(param):
@@ -618,13 +621,15 @@ def __init__(self, optimizer, hcg):
         self._hcg = hcg
         self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
         self._sharding_rank = self._hcg.get_sharding_parallel_rank()
+        self.clear_color = None
 
         self._parameter_list = optimizer._parameter_list
 
         # param name -> slice_param
         self._slice_params = {}
         # comm_buffer_list = []
         self._comm_buffer_list = []
+        self._color_to_comm_buffer_list = {}
 
         # slice parameter list
         self._local_parameter_list = [
@@ -814,12 +819,42 @@ def _build_comm_buffers(
             group_idx += 1
             self._comm_buffer_list.append(buffer)
 
+            if g_color not in self._color_to_comm_buffer_list.keys():
+                self._color_to_comm_buffer_list[g_color] = []
+            self._color_to_comm_buffer_list[g_color].append(buffer)
+
             for p in parameters:
                 if p.name in self.param2bucket:
                     self.param2bucket[p.name].append(buffer)
                 else:
                     self.param2bucket[p.name] = [buffer]
 
+    def clear_param_storage(self, color):
+        self.clear_color = color
+        if color in self._color_to_comm_buffer_list.keys():
+            for comm_buffer in self._color_to_comm_buffer_list[color]:
+                for param in comm_buffer.params:
+                    grad_view = comm_buffer._sharding_param_grad_view[
+                        param.name
+                    ]
+                    slice_param = self._slice_params[param.name]
+                    if (
+                        not g_shard_bypass_dygraph_optimizer
+                        and grad_view._param_begin < grad_view._param_end
+                    ):
+                        grad_view.fill_slice_param(slice_param)
+                        self._create_master_weight(slice_param)
+                    slice_param._clear_dataptr()
+                comm_buffer._clear_param_storage()
+
+    def reset_param_storage(self):
+        color = self.clear_color
+        if color is None:
+            return
+        if color in self._color_to_comm_buffer_list.keys():
+            for comm_buffer in self._color_to_comm_buffer_list[color]:
+                comm_buffer._reset_param_storage()
+
     def clear_grad(self, set_to_zero=True):
         """
         should clear grad for all parameters in model
@@ -1096,6 +1131,7 @@ def _assign_slice_grad(self):
     def step(self):
         # TODO Check whether the model trainable param changed and update state accordingly
         # hack for pp comm overlap
+        self.reset_param_storage()
 
         if self._all_gather_overlap_forward:
             # Clear the pre forward hook in the optimizer step.
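
Note: the optimizer now groups comm buffers by color, so the fused parameter storage of one color group can be released after its slice params (and master weights) are materialized, and is re-allocated automatically at the start of the next `step()`. A hedged sketch of the intended call pattern; `opt` stands for an instance of the sharding optimizer patched above, and the "expert" color is a placeholder, not a name from this commit:

# Hypothetical usage of the new color-scoped clear/reset pair.
opt.clear_param_storage("expert")  # fill slice params / master weights, then free the fused storage
# ... run work that does not need the "expert" parameters resident in the fused buffer ...
opt.step()                         # step() calls reset_param_storage() first, restoring the storage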

python/paddle/distributed/fleet/utils/tensor_fusion_helper.py

Lines changed: 23 additions & 0 deletions
@@ -304,6 +304,16 @@ def assign_slice_grad(self, slice_param):
         else:
             assert slice_param.grad._is_shared_buffer_with(slice_grad)
 
+    def _clear_param_buffer(self):
+        self._param._clear_to_zero_allocation()
+        self._param_buffer._clear_to_zero_allocation()
+
+    def _reset_param_buffer(self, new_param_storage):
+        new_param = paddle.empty_like(self._param)
+        new_param._share_buffer_to(self._param)
+        new_param_storage._share_buffer_to(self._param_buffer)
+        self._share_param_buffer()
+
     def _clear_grad_buffer(self):
         if self._slice_grad is not None:
             self._slice_grad._clear_dataptr()
@@ -553,6 +563,18 @@ def _record_addr(self):
                 param, self.use_main_grad
             )
 
+    def _clear_param_storage(self):
+        self.param_storage._clear_to_zero_allocation()
+        for param in self._params:
+            self._sharding_param_grad_view[param.name]._clear_param_buffer()
+
+    def _reset_param_storage(self):
+        new_param_storage = paddle.empty_like(self.param_storage)
+        new_param_storage._share_buffer_to(self.param_storage)
+        for param in self._params:
+            grad_view = self._sharding_param_grad_view[param.name]
+            grad_view._reset_param_buffer(new_param_storage)
+
     def _clear_grad_storage(self):
         self.grad_storage._clear_dataptr()
         self.grad_storage = None
@@ -775,6 +797,7 @@ def _comm_grads(self):
                 group=self._comm_group,
                 sync_op=False,
             )
+
         if self._free_grads_in_comm:
             self._reset_grad_storage(reduce_scattered)
