@@ -1192,9 +1192,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                                        platform::EventRole::kInnerOp);
     if (run_pten_kernel_) {
       pten::KernelContext pt_kernel_context;
+      // Do data transform before building KernelContext
+      PreparePtenData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
+                      runtime_ctx);
       BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
       (*pt_kernel_)(&pt_kernel_context);
-      WriteBackToOutputs(runtime_ctx, &pt_kernel_context);
     } else {
       (*kernel_func_)(
           ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
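Note on the hunk above: the dropped `WriteBackToOutputs` call was the step that synchronized results from the pten kernel context back into the fluid output variables. The `BuildPtenKernelContext` changes further down make the context hold pointers to the `framework::Tensor` objects stored in the output variables themselves, so the kernel already writes its results in place and no write-back is required. A simplified sketch of the resulting dispatch order (condensed from the diff; profiling and error handling omitted):

    if (run_pten_kernel_) {
      pten::KernelContext pt_kernel_context;
      // 1. move inputs onto the place the kernel expects (currently a no-op; see PreparePtenData below)
      PreparePtenData(exec_scope, *pt_kernel_, *pt_kernel_signature_, runtime_ctx);
      // 2. wire raw input/output Tensor pointers into the kernel context
      BuildPtenKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
      // 3. run the kernel; outputs are written in place, so no write-back step follows
      (*pt_kernel_)(&pt_kernel_context);
    }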
@@ -1786,6 +1788,62 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
       pten::TransToPtenKernelName(Type()));
 }
 
+Scope* OperatorWithKernel::PreparePtenData(
+    const Scope& scope, const pten::Kernel& pt_kernel,
+    const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
+  auto& input_names = std::get<0>(pt_kernel_signature.args);
+  auto input_defs = pt_kernel.args_def().input_defs();
+  PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
+                    platform::errors::InvalidArgument(
+                        "The size of inputs_args names (%d) must be equal to "
+                        "the size of kernel input_defs (%d).",
+                        input_names.size(), input_defs.size()));
+  Scope* new_scope = nullptr;
+  for (size_t i = 0; i < input_defs.size(); ++i) {
+    auto& in_def = input_defs.at(i);
+    auto& ins_vector = ctx->inputs.at(input_names[i]);
+    for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
+      // Only a tensor can be transferred to another device.
+      auto* var = ins_vector[offset];
+      if (var == nullptr || !VarIsTensor(*var)) {
+        continue;
+      }
+
+      auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
+      if (!tensor_in->IsInitialized()) {
+        continue;
+      }
+
+      auto expected_place = pten::TransToFluidPlace(in_def.backend);
+      if (platform::is_same_place(tensor_in->place(), expected_place)) {
+        continue;
+      }
+
+      // TODO(zyfncg): There is currently no kernel that needs to transform
+      // input data, so the following code is commented out temporarily;
+      // it will be used in the future.
+
+      // VLOG(3) << "PTen Transform Variable " << input_names[i] << " from "
+      //         << tensor_in->place() << " to " << expected_place;
+
+      // if (!new_scope) {
+      //   new_scope = &scope.NewScope();
+      // }
+
+      // // Create new var with the same name in transfer scopes
+      // auto* trans_var = new_scope->Var(input_names[i]);
+      // ins_vector[i] = trans_var;
+
+      // // Do transfer
+      // Tensor out;
+      // framework::TensorCopySync(*tensor_in, expected_place, &out);
+      // SetTensorToVariable(*var, out, trans_var);
+    }
+  }
+
+  return new_scope;
+}
+
 void OperatorWithKernel::BuildPtenKernelContext(
     const RuntimeContext& ctx, platform::DeviceContext* dev_ctx,
     pten::KernelContext* pt_kernel_context) const {
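`PreparePtenData` (added above) creates a transfer scope only lazily, and with the copy branch commented out it currently always returns `nullptr`. An input is even considered for transfer only when it is a tensor, is initialized, and lives on a different place than the one the kernel backend maps to. A condensed, hypothetical helper showing just those skip conditions (the name `NeedTransfer` and the `pten::TensorArgDef` parameter type are illustrative assumptions; the checks themselves mirror the loop body above):

    // Illustrative only: true when PreparePtenData would reach its (currently
    // commented-out) transfer branch for a given input variable.
    bool NeedTransfer(const framework::Variable* var,
                      const pten::TensorArgDef& in_def) {
      if (var == nullptr || !VarIsTensor(*var)) return false;  // non-tensor inputs are skipped
      auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
      if (!tensor_in->IsInitialized()) return false;  // nothing allocated to move yet
      auto expected_place = pten::TransToFluidPlace(in_def.backend);
      return !platform::is_same_place(tensor_in->place(), expected_place);
    }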
@@ -1818,7 +1876,6 @@ void OperatorWithKernel::BuildPtenKernelContext(
                         attr_names.size(), attr_defs.size()));
 
   for (size_t i = 0; i < input_names.size(); ++i) {
-    auto& in_def = input_defs.at(i);
     auto& ins_vector = ctx.inputs.at(input_names[i]);
 
     // calculate the start and end index of the input tensors
@@ -1827,24 +1884,44 @@ void OperatorWithKernel::BuildPtenKernelContext(
     size_t end_idx = start_idx + ins_vector.size();
 
     for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
-      pt_kernel_context->EmplaceBackInputWithoutSetRange(
-          experimental::MakePtenTensorBaseFromVar(*ins_vector[offset], in_def));
+      const framework::Tensor* tensor_in = nullptr;
+      auto* var = ins_vector[offset];
+      if (var->IsType<framework::LoDTensor>()) {
+        tensor_in = &(var->Get<framework::LoDTensor>());
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Unsupported input `%s` type when calling pt kernel.",
+            framework::ToTypeName(var->Type())));
+      }  // TODO(zyfncg): Add support for SelectedRows
+
+      pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
     }
     pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
   }
 
   for (size_t i = 0; i < output_names.size(); ++i) {
-    auto& out_def = output_defs.at(i);
     auto& outs_vector = ctx.outputs.at(output_names[i]);
 
     size_t start_idx =
         (i == 0 ? 0 : pt_kernel_context->OutputRangeAt(i - 1).second);
     size_t end_idx = start_idx + outs_vector.size();
 
     for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
-      pt_kernel_context->EmplaceBackOutputWithoutSetRange(
-          experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
-                                                  out_def));
+      framework::Tensor* tensor_out = nullptr;
+      auto* var = outs_vector[offset];
+      if (var->template IsType<framework::LoDTensor>()) {
+        tensor_out = var->template GetMutable<framework::LoDTensor>();
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Unsupported output `%s` type when calling pt kernel.",
+            framework::ToTypeName(var->Type())));
+      }  // TODO(zyfncg): Add support for SelectedRows
+
+      experimental::ResetTensorByArgDef(tensor_out, output_defs.at(i));
+      SetAllocationForOutputTenosr(
+          tensor_out, pten::TransToFluidPlace(output_defs.at(i).backend));
+
+      pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
     }
 
     // Deal with the case that some outputs are NULL when running the kernel.
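Compared with the removed `MakePtenTensorBaseFromVar` calls, which built intermediate pten tensors from each variable, the new code passes the variables' own `LoDTensor` objects into the kernel context by pointer. For outputs, each tensor is first aligned with the kernel's declared arg def via `experimental::ResetTensorByArgDef` and given an allocation on the place the backend maps to, so the kernel can write directly into the op's outputs. Condensed from the hunk above, preparing a single output slot looks like:

    // One output slot: fetch the LoDTensor, align it with the kernel's arg def,
    // ensure memory on the target place, then hand the pointer to the context.
    framework::Tensor* tensor_out = var->GetMutable<framework::LoDTensor>();
    experimental::ResetTensorByArgDef(tensor_out, output_defs.at(i));
    SetAllocationForOutputTenosr(
        tensor_out, pten::TransToFluidPlace(output_defs.at(i).backend));
    pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);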