@@ -24,48 +24,22 @@ TEST(Pad, real) {
2424 for (size_t imgSizeW : {5 , 32 , 96 }) {
2525 VLOG (3 ) << " numSamples=" << numSamples << " channels=" << channels
2626 << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
27-
28- FunctionCompare compare (" Pad" ,
29- FuncConfig ()
30- .set (" cstart" , 2 )
31- .set (" cend" , 3 )
32- .set (" hstart" , 1 )
33- .set (" hend" , 2 )
34- .set (" wstart" , 3 )
35- .set (" wend" , 2 ));
36- TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
37- TensorShape outDims{
38- numSamples, channels + 5 , imgSizeH + 3 , imgSizeW + 5 };
39- compare.addInputs (BufferArg (VALUE_TYPE_FLOAT, inDims));
40- compare.addOutputs (BufferArg (VALUE_TYPE_FLOAT, outDims, ASSIGN_TO));
41- compare.run ();
42- }
43- }
44- }
45- }
46- }
47-
48- TEST (PadGrad, real) {
49- for (size_t numSamples : {5 , 32 }) {
50- for (size_t channels : {1 , 5 , 32 }) {
51- for (size_t imgSizeH : {5 , 33 , 100 }) {
52- for (size_t imgSizeW : {5 , 32 , 96 }) {
53- VLOG (3 ) << " numSamples=" << numSamples << " channels=" << channels
54- << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
55- FunctionCompare compare (" PadGrad" ,
56- FuncConfig ()
57- .set (" cstart" , 2 )
58- .set (" cend" , 3 )
59- .set (" hstart" , 1 )
60- .set (" hend" , 2 )
61- .set (" wstart" , 3 )
62- .set (" wend" , 2 ));
63- TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
64- TensorShape outDims{
65- numSamples, channels + 5 , imgSizeH + 3 , imgSizeW + 5 };
66- compare.addInputs (BufferArg (VALUE_TYPE_FLOAT, outDims));
67- compare.addOutputs (BufferArg (VALUE_TYPE_FLOAT, inDims, ASSIGN_TO));
68- compare.run ();
27+ for (bool test_grad : {false , true }) {
28+ FunctionCompare compare (
29+ test_grad ? " PadGrad" : " Pad" ,
30+ FuncConfig ()
31+ .set <std::vector<uint32_t >>(" channel" , {2 , 3 })
32+ .set <std::vector<uint32_t >>(" height" , {1 , 2 })
33+ .set <std::vector<uint32_t >>(" width" , {3 , 2 }));
34+ TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
35+ TensorShape outDims{
36+ numSamples, channels + 5 , imgSizeH + 3 , imgSizeW + 5 };
37+ compare.addInputs (
38+ BufferArg (VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
39+ compare.addOutputs (BufferArg (
40+ VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
41+ compare.run ();
42+ }
6943 }
7044 }
7145 }
0 commit comments