14 changes: 8 additions & 6 deletions paddlenlp/trainer/trainer.py
@@ -1097,12 +1097,12 @@ def _inner_training_loop(
                 if dp_master_grad:
                     is_no_sync = True
 
-                if is_no_sync:
-                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
-                    with model.no_sync():
+                sync_context = model.no_sync() if is_no_sync else contextlib.nullcontext()
+                with sync_context:
+                    if "step_control" in inspect.signature(self.training_step).parameters:
+                        tr_loss_step = self.training_step(model, inputs, step_control=step_control)
+                    else:
                         tr_loss_step = self.training_step(model, inputs)
-                else:
-                    tr_loss_step = self.training_step(model, inputs)
 
                 tr_loss += tr_loss_step

@@ -2279,7 +2279,9 @@ def _enable_delay_scale_loss(self):
         else:
             return False
 
-    def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+    def training_step(
+        self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]], step_control=0
+    ) -> paddle.Tensor:
         """
         Perform a training step on a batch of inputs.
 
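The hunk above collapses the duplicated no_sync/else branches into one conditional context manager and only passes step_control when the (possibly overridden) training_step accepts it. A minimal, self-contained sketch of that dispatch pattern follows; the toy classes and names are illustrative only, not the PaddleNLP API.

import contextlib
import inspect


class ToyModel:
    """Toy stand-in; real DDP-style wrappers expose no_sync() as a context manager."""

    def no_sync(self):
        return contextlib.nullcontext()

    def __call__(self, inputs):
        return sum(inputs)


class BaseTrainer:
    def training_step(self, model, inputs):
        return model(inputs)


class StepControlTrainer(BaseTrainer):
    # Subclass whose training_step takes the extra step_control argument,
    # detected below via inspect.signature.
    def training_step(self, model, inputs, step_control=0):
        return model(inputs) + step_control


def run_step(trainer, model, inputs, is_no_sync, step_control=0):
    # One code path: suppress gradient sync when requested, otherwise a no-op context.
    sync_context = model.no_sync() if is_no_sync else contextlib.nullcontext()
    with sync_context:
        if "step_control" in inspect.signature(trainer.training_step).parameters:
            return trainer.training_step(model, inputs, step_control=step_control)
        return trainer.training_step(model, inputs)


print(run_step(BaseTrainer(), ToyModel(), [1, 2, 3], is_no_sync=True))  # 6
print(run_step(StepControlTrainer(), ToyModel(), [1, 2, 3], is_no_sync=False, step_control=2))  # 8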
65 changes: 65 additions & 0 deletions paddlenlp/transformers/contrastive_loss.py
@@ -0,0 +1,65 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import paddle
import paddle.nn as nn


class SimpleContrastiveLoss(nn.Layer):
def __init__(self, embedding_temperature: float = 0.02):
super().__init__()
self.embedding_temperature = embedding_temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")

def forward(self, q_reps, p_reps):
scores = paddle.matmul(q_reps, p_reps.transpose([1, 0]))
scores = scores / self.embedding_temperature

group_size = p_reps.shape[0] // q_reps.shape[0]
batch_size = q_reps.shape[0]

target = paddle.arange(batch_size, dtype="int64")
target = target * group_size

loss = self.cross_entropy(scores, target)
return loss


class MatryoshkaContrastiveLoss(nn.Layer):
def __init__(self, embedding_temperature: float = 0.02, embedding_matryoshka_dims: Optional[List[int]] = None):
super().__init__()
self.embedding_temperature = embedding_temperature
if embedding_matryoshka_dims is None:
self.embedding_matryoshka_dims = []
else:
self.embedding_matryoshka_dims = embedding_matryoshka_dims
self.loss_fn = SimpleContrastiveLoss(embedding_temperature)

def forward(self, q_reps, p_reps):
if len(self.embedding_matryoshka_dims) > 0:
loss = 0.0
for dim in self.embedding_matryoshka_dims:
reduced_q_reps = q_reps[:, :dim]
reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1)

reduced_p_reps = p_reps[:, :dim]
reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1)

dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps)
loss += dim_loss
else:
loss = self.loss_fn(q_reps, p_reps)
return loss
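A quick usage sketch for the two losses above (shapes and values are illustrative): with 4 queries and 8 passages, group_size is 8 // 4 = 2, so query i is trained to pick the passage at column i * 2 of the score matrix; the Matryoshka variant sums the same loss over re-normalized prefixes of the embeddings.

import paddle
import paddle.nn as nn

from paddlenlp.transformers.contrastive_loss import (
    MatryoshkaContrastiveLoss,
    SimpleContrastiveLoss,
)

# 4 queries, 2 passages per query (1 positive followed by 1 negative), 64-d embeddings.
q_reps = nn.functional.normalize(paddle.randn([4, 64]), axis=-1)
p_reps = nn.functional.normalize(paddle.randn([8, 64]), axis=-1)

loss = SimpleContrastiveLoss(embedding_temperature=0.02)(q_reps, p_reps)

# Matryoshka: truncate to each dim, re-normalize, and sum the per-dim losses.
mrl_loss = MatryoshkaContrastiveLoss(0.02, embedding_matryoshka_dims=[32, 64])(q_reps, p_reps)
print(float(loss), float(mrl_loss))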
51 changes: 51 additions & 0 deletions paddlenlp/transformers/embedding_utils.py
@@ -0,0 +1,51 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed import fleet


def dist_gather_tensor_with_gradient(tensor):
Contributor:
Could this go under some other, more general utils directory? It feels fairly generic (a gather across dp).
How is the "gradient" in _gather_tensor_with_gradient actually reflected inside the function?

Contributor (Author):
It is reflected here: the local tensor is assigned back into the gathered list, all_tensors[sharding_rank] = tensor.

if tensor is None:
return None

if paddle.distributed.get_world_size() <= 1:
return tensor

hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
sharding_rank = sharding_group.rank
data_group = hcg.get_data_parallel_group()
data_rank = data_group.rank

if sharding_group.nranks == 1 and data_group.nranks == 1:
return tensor

if sharding_group.nranks > 1:
all_tensors = []
paddle.distributed.all_gather(all_tensors, tensor.contiguous(), group=sharding_group)
all_tensors[sharding_rank] = tensor
all_tensors = paddle.concat(all_tensors, axis=0)
else:
all_tensors = tensor

if data_group.nranks > 1:
final_tensors = []
paddle.distributed.all_gather(final_tensors, all_tensors.contiguous(), group=data_group)
final_tensors[data_rank] = all_tensors
final_tensors = paddle.concat(final_tensors, axis=0)
else:
final_tensors = all_tensors

return final_tensors
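As the review thread above notes, the tensors produced by all_gather carry no autograd history, so the slot for the local rank is overwritten with the original tensor before concatenation; that keeps the local shard attached to the graph while the remote shards stay constant. A single-process sketch of the same idea (detached copies stand in for gathered tensors; this is an illustration, not the distributed code path):

import paddle

local = paddle.randn([2, 4])
local.stop_gradient = False

# Stand-ins for what all_gather would return on this rank: value-only copies
# with no autograd history, like tensors received from other ranks.
gathered = [local.detach().clone(), paddle.randn([2, 4])]

# Re-insert the live local tensor at this rank's position so gradients can
# flow back into it through the concatenated result.
rank = 0
gathered[rank] = local

merged = paddle.concat(gathered, axis=0)
merged.sum().backward()
print(local.grad)  # ones for the local rows; without the re-assignment no gradient would reach `local`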
1 change: 1 addition & 0 deletions paddlenlp/trl/__init__.py
@@ -14,6 +14,7 @@

 from .dpo_criterion import DPOCriterion
 from .dpo_trainer import DPOTrainer
+from .embedding_trainer import EmbeddingTrainer
 from .kto_criterion import KTOCriterion
 from .kto_trainer import KTOTrainer
 from .model_config import *
181 changes: 181 additions & 0 deletions paddlenlp/trl/embedding_trainer.py
@@ -0,0 +1,181 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext

import paddle
from paddle.base import core
from paddle.distributed import fleet

from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import (
MatryoshkaContrastiveLoss,
SimpleContrastiveLoss,
)
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient

__all__ = ["EmbeddingTrainer"]


class EmbeddingTrainer(Trainer):
def __init__(self, model_args, **kwargs):
super().__init__(**kwargs)

self.model_args = model_args
Contributor:
It would be best to add something like a loss type field. I see that zhangjie's side also needs to add the inf-cl loss.

self.embedding_negatives_cross_device = model_args.embedding_negatives_cross_device
self.accum_data = []
self.accum_freq = 0
self.accum_q_features = []
self.accum_p_features = []
self.accum_rng_states = {}
self.accum_rng_states["cpu"] = []
self.accum_rng_states["cuda"] = []
self.accum_rng_states["hybrid"] = []

if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0:
self.loss_fn = MatryoshkaContrastiveLoss(
model_args.embedding_temperature, model_args.embedding_matryoshka_dims
)
else:
self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)

def clear_memory(self):
self.accum_q_features.clear()
self.accum_p_features.clear()
paddle.device.cuda.empty_cache()

def clear_state(self):
self.accum_data.clear()
self.accum_rng_states["cpu"].clear()
self.accum_rng_states["cuda"].clear()
self.accum_rng_states["hybrid"].clear()
self.accum_freq = 0

@paddle.no_grad()
def forward_no_grad(self, model, inputs):
# Step1: graph-less forward
self.accum_data.append(inputs)
inputs = self._prepare_inputs(inputs)
with self.autocast_smart_context_manager():
# collect rand states
self.accum_rng_states["cpu"].append(paddle.framework.core.default_cpu_generator().get_state())
self.accum_rng_states["cuda"].append(paddle.get_rng_state())
if self.args.use_hybrid_parallel:
self.accum_rng_states["hybrid"].append(
fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
)

query_reps, passage_reps = model(**inputs, return_encode=True)

if self.embedding_negatives_cross_device:
query_reps = dist_gather_tensor_with_gradient(query_reps)
passage_reps = dist_gather_tensor_with_gradient(passage_reps)

self.accum_q_features.append(query_reps)
self.accum_p_features.append(passage_reps)

self.accum_freq += 1

def get_current_rng_state(self):
return {
"cpu": [paddle.framework.core.default_cpu_generator().get_state()],
"cuda": [paddle.get_rng_state()],
"hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()]
if self.args.use_hybrid_parallel
else [],
}

def reset_rng_state(self, states, index=0):
# set random states
if len(states) != 3:
raise ValueError("The length of state should be 3")
cpu_state = states["cpu"][index]
cuda_state = states["cuda"][index]
paddle.framework.core.default_cpu_generator().set_state(cpu_state)
# TODO(daisiming): support xpu and other custom devices.
if core.is_compiled_with_cuda():
for j in range(core.get_cuda_device_count()):
core.default_cuda_generator(j).set_state(cuda_state[j])
if self.args.use_hybrid_parallel:
hybrid_state = states["hybrid"][index]
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state)

def accum_forward_backward(self, model):
# Step2: representation gradient computation and caching
for i in range(len(self.accum_q_features)):
self.accum_q_features[i].stop_gradient = False
q_reps = paddle.concat(self.accum_q_features, axis=0)
for i in range(len(self.accum_p_features)):
self.accum_p_features[i].stop_gradient = False
p_reps = paddle.concat(self.accum_p_features, axis=0)

loss = self.loss_fn(q_reps, p_reps)
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
else:
loss.backward()
# get representation gradient cache
accum_q_grads = [q.grad for q in self.accum_q_features]
accum_p_grads = [p.grad for p in self.accum_p_features]
del q_reps, p_reps

# release cached representations and free the CUDA cache
self.clear_memory()

current_rng_state = self.get_current_rng_state()
# Step3: sub-batch gradient accumulation
for i in range(self.accum_freq):
inputs = self.accum_data[i]
inputs = self._prepare_inputs(inputs)

sync_context = model.no_sync() if i != self.accum_freq - 1 and hasattr(model, "no_sync") else nullcontext()
with sync_context:
self.reset_rng_state(self.accum_rng_states, index=i)

with self.autocast_smart_context_manager():
query_reps, passage_reps = model(**inputs, return_encode=True)

if self.embedding_negatives_cross_device:
query_reps = dist_gather_tensor_with_gradient(query_reps)
passage_reps = dist_gather_tensor_with_gradient(passage_reps)

_loss = paddle.dot(query_reps.flatten(), accum_q_grads[i].flatten()) + paddle.dot(
passage_reps.flatten(), accum_p_grads[i].flatten()
)
_loss.backward()

self.reset_rng_state(current_rng_state)
self.clear_state()
return loss.detach()

def training_step(
self,
model,
inputs,
step_control=0,
):
if self.args.pipeline_parallel_degree > 1:
raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

if self.args.gradient_accumulation_steps == 1:
return super().training_step(model, inputs)
else:
self.forward_no_grad(model, inputs)

# If (step_control + 1) % self.args.gradient_accumulation_steps is not zero, keep accumulating and move on to the next batch.
if (step_control + 1) % self.args.gradient_accumulation_steps != 0:
return 0.0

loss = self.accum_forward_backward(model)
return loss
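EmbeddingTrainer above implements gradient-cache style accumulation: sub-batches are first encoded without autograd (Step 1), one contrastive loss over all cached representations yields per-representation gradients (Step 2), and each sub-batch is then re-encoded with autograd and back-propagated through a dot product with its cached gradient (Step 3). Below is a stripped-down, single-device sketch of those three steps with a toy encoder; the RNG-state bookkeeping, AMP, and cross-device gathering of the real trainer are omitted, and the encoder here is purely illustrative.

import paddle
import paddle.nn as nn

from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss


class ToyEncoder(nn.Layer):
    def __init__(self, hidden=16, dim=8):
        super().__init__()
        self.proj = nn.Linear(hidden, dim)

    def forward(self, q, p):
        return (
            nn.functional.normalize(self.proj(q), axis=-1),
            nn.functional.normalize(self.proj(p), axis=-1),
        )


encoder = ToyEncoder()
loss_fn = SimpleContrastiveLoss(embedding_temperature=0.02)
sub_batches = [(paddle.randn([2, 16]), paddle.randn([4, 16])) for _ in range(3)]

# Step 1: graph-less forward over every sub-batch, caching the representations.
q_cache, p_cache = [], []
with paddle.no_grad():
    for q, p in sub_batches:
        q_reps, p_reps = encoder(q, p)
        q_cache.append(q_reps)
        p_cache.append(p_reps)

# Step 2: one loss over the full effective batch; backward produces gradients
# w.r.t. the cached representations only (the encoder is not in this graph).
for t in q_cache + p_cache:
    t.stop_gradient = False
loss = loss_fn(paddle.concat(q_cache, axis=0), paddle.concat(p_cache, axis=0))
loss.backward()
q_grads = [t.grad for t in q_cache]
p_grads = [t.grad for t in p_cache]

# Step 3: re-encode each sub-batch with autograd and back-propagate a surrogate
# dot product against the cached gradients, accumulating encoder gradients.
for (q, p), gq, gp in zip(sub_batches, q_grads, p_grads):
    q_reps, p_reps = encoder(q, p)
    surrogate = paddle.dot(q_reps.flatten(), gq.flatten()) + paddle.dot(p_reps.flatten(), gp.flatten())
    surrogate.backward()

print(float(loss), encoder.proj.weight.grad.shape)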