Skip to content

Commit c406d90

Browse files
【AutoParallel】Add strategy with more options (#8114)
* add strategy * polish
1 parent 880d2ea commit c406d90

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ def _wrap_for_auto(self, model, train_dataloader):
112112
dist_loader = self._wrap_for_dist_loader(train_dataloader)
113113

114114
if self.args.to_static:
115+
unified_strategy = dist.Strategy()
116+
unified_strategy._from_legacy_strategy(self.args.strategy)
115117
return (
116-
dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=self.args.strategy),
118+
dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy),
117119
dist_loader,
118120
)
119121
else:

0 commit comments

Comments
 (0)