@@ -23,7 +23,7 @@ namespace phi {
2323using Tensor = DenseTensor;
2424
2525template <typename DeviceContext, typename T>
26- inline void ResizeToChannelFirst (const DeviceContext& context ,
26+ inline void ResizeToChannelFirst (const DeviceContext& dev_ctx ,
2727 const Tensor* input,
2828 Tensor* transformed_input) {
2929 int dim = input->dims ().size () - 2 ;
@@ -37,7 +37,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
3737 in_dims_vec[3 ] = input->dims ()[2 ];
3838 in_dims_vec[4 ] = input->dims ()[3 ];
3939 transformed_input->Resize (common::make_ddim (in_dims_vec));
40- context .template Alloc <T>(transformed_input);
40+ dev_ctx .template Alloc <T>(transformed_input);
4141 } else if (dim == 2 ) {
4242 // input
4343 transformed_input->Resize (input->dims ());
@@ -47,20 +47,20 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
4747 in_dims_vec[2 ] = input->dims ()[1 ];
4848 in_dims_vec[3 ] = input->dims ()[2 ];
4949 transformed_input->Resize (common::make_ddim (in_dims_vec));
50- context .template Alloc <T>(transformed_input);
50+ dev_ctx .template Alloc <T>(transformed_input);
5151 } else if (dim == 1 ) {
5252 transformed_input->Resize (input->dims ());
5353
5454 auto in_dims_vec = common::vectorize (input->dims ());
5555 in_dims_vec[1 ] = input->dims ()[2 ];
5656 in_dims_vec[2 ] = input->dims ()[1 ];
5757 transformed_input->Resize (common::make_ddim (in_dims_vec));
58- context .template Alloc <T>(transformed_input);
58+ dev_ctx .template Alloc <T>(transformed_input);
5959 }
6060}
6161
6262template <typename DeviceContext, typename T>
63- inline void ResizeToChannelLast (const DeviceContext& context ,
63+ inline void ResizeToChannelLast (const DeviceContext& dev_ctx ,
6464 const Tensor* input,
6565 Tensor* transformed_input) {
6666 int dim = input->dims ().size () - 2 ;
@@ -74,7 +74,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
7474 in_dims_vec[3 ] = input->dims ()[4 ];
7575 in_dims_vec[4 ] = input->dims ()[1 ];
7676 transformed_input->Resize (common::make_ddim (in_dims_vec));
77- context .template Alloc <T>(transformed_input);
77+ dev_ctx .template Alloc <T>(transformed_input);
7878
7979 } else if (dim == 2 ) {
8080 // input
@@ -85,58 +85,58 @@ inline void ResizeToChannelLast(const DeviceContext& context,
8585 in_dims_vec[2 ] = input->dims ()[3 ];
8686 in_dims_vec[3 ] = input->dims ()[1 ];
8787 transformed_input->Resize (common::make_ddim (in_dims_vec));
88- context .template Alloc <T>(transformed_input);
88+ dev_ctx .template Alloc <T>(transformed_input);
8989 } else if (dim == 1 ) {
9090 transformed_input->Resize (input->dims ());
9191
9292 auto in_dims_vec = common::vectorize (input->dims ());
9393 in_dims_vec[1 ] = input->dims ()[2 ];
9494 in_dims_vec[2 ] = input->dims ()[1 ];
9595 transformed_input->Resize (common::make_ddim (in_dims_vec));
96- context .template Alloc <T>(transformed_input);
96+ dev_ctx .template Alloc <T>(transformed_input);
9797 }
9898}
9999
100100template <typename DeviceContext, typename T>
101- inline void TransToChannelFirst (const DeviceContext& context ,
101+ inline void TransToChannelFirst (const DeviceContext& dev_ctx ,
102102 const Tensor* input,
103103 Tensor* transformed_input) {
104104 VLOG (5 ) << " Why am I called?" ;
105105 int dim = input->dims ().size () - 2 ;
106106 if (dim == 3 ) {
107107 std::vector<int > axis{0 , 4 , 1 , 2 , 3 };
108108 phi::funcs::Transpose<DeviceContext, T, 5 > trans5;
109- trans5 (context , *input, transformed_input, axis);
109+ trans5 (dev_ctx , *input, transformed_input, axis);
110110
111111 } else if (dim == 2 ) {
112112 std::vector<int > axis{0 , 3 , 1 , 2 };
113113 phi::funcs::Transpose<DeviceContext, T, 4 > trans4;
114- trans4 (context , *input, transformed_input, axis);
114+ trans4 (dev_ctx , *input, transformed_input, axis);
115115 } else if (dim == 1 ) {
116116 std::vector<int > axis{0 , 2 , 1 };
117117 phi::funcs::Transpose<DeviceContext, T, 3 > trans3;
118- trans3 (context , *input, transformed_input, axis);
118+ trans3 (dev_ctx , *input, transformed_input, axis);
119119 }
120120}
121121
122122template <typename DeviceContext, typename T>
123- inline void TransToChannelLast (const DeviceContext& context ,
123+ inline void TransToChannelLast (const DeviceContext& dev_ctx ,
124124 const Tensor* input,
125125 Tensor* transformed_input) {
126126 int dim = input->dims ().size () - 2 ;
127127 if (dim == 3 ) {
128128 std::vector<int > axis{0 , 2 , 3 , 4 , 1 };
129129 phi::funcs::Transpose<DeviceContext, T, 5 > trans5;
130- trans5 (context , *input, transformed_input, axis);
130+ trans5 (dev_ctx , *input, transformed_input, axis);
131131
132132 } else if (dim == 2 ) {
133133 std::vector<int > axis{0 , 2 , 3 , 1 };
134134 phi::funcs::Transpose<DeviceContext, T, 4 > trans4;
135- trans4 (context , *input, transformed_input, axis);
135+ trans4 (dev_ctx , *input, transformed_input, axis);
136136 } else if (dim == 1 ) {
137137 std::vector<int > axis{0 , 2 , 1 };
138138 phi::funcs::Transpose<DeviceContext, T, 3 > trans3;
139- trans3 (context , *input, transformed_input, axis);
139+ trans3 (dev_ctx , *input, transformed_input, axis);
140140 }
141141}
142142
0 commit comments