
Commit 0b57b38

neginraoof authored and facebook-github-bot committed

Im2col export (pytorch#30972)
Summary: Added im2col to opset 11. This symbolic is used to export torch.nn.Unfold.
Pull Request resolved: pytorch#30972
Reviewed By: hl475
Differential Revision: D18946921
Pulled By: houseroad
fbshipit-source-id: 13dd0cbae899700df32fd74d6dff1f29033a2b4c
1 parent 6cd987e commit 0b57b38
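
With this symbolic in place, a module that goes through torch.nn.Unfold can be exported at opset 11 and run under ONNX Runtime, as the new test below exercises. A minimal export sketch (the module, input shape, and file name are illustrative, not part of this commit):

    import torch

    unfold = torch.nn.Unfold(kernel_size=(2, 2), dilation=1, padding=0, stride=3)
    dummy = torch.rand(1, 1, 200, 100)  # (N, C, H, W), same shape as in the new test

    # opset_version=11 is required, since the im2col symbolic is registered in symbolic_opset11.py
    torch.onnx.export(unfold, dummy, "unfold.onnx", opset_version=11)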

File tree: 2 files changed, +110 -4 lines

test/onnx/test_pytorch_onnx_onnxruntime.py (14 additions, 4 deletions)
@@ -2296,7 +2296,6 @@ def forward(self, x, pad):
         y = pad = (torch.tensor(2, dtype=torch.int64), torch.tensor(4, dtype=torch.int64))
         self.run_test(Pad(), (x, y))

-
     def test_reflection_pad(self):
         model = torch.nn.ReflectionPad1d(2)
         x = torch.randn(2, 4, 4)
@@ -2315,6 +2314,17 @@ def test_replication_pad(self):
         x = torch.randn(2, 2, 4, 4)
         self.run_test(model, x)

+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_im2col(self):
+        class Unfold(torch.nn.Module):
+            def forward(self, input):
+                return torch.nn.functional.unfold(input, kernel_size=(10, 15), dilation=2, padding=5, stride=3), \
+                    torch.nn.functional.unfold(input, kernel_size=(2, 2), dilation=1, padding=0, stride=3), \
+                    torch.nn.functional.unfold(input, kernel_size=(1, 1), dilation=5, padding=2, stride=3)
+
+        x = torch.rand(1, 1, 200, 100)
+        self.run_test(Unfold(), x)
+
     @skipIfNoLapack
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_det(self):
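
For reference, run_test compares these outputs against ONNX Runtime; the expected unfold output shape follows the usual sliding-block count. A quick check of the second call above (kernel_size=(2, 2), dilation=1, padding=0, stride=3 on the 1x1x200x100 input), assuming standard torch.nn.functional.unfold semantics:

    import torch

    # blocks per spatial dim: (size + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
    # H: (200 - 1 - 1) // 3 + 1 = 67, W: (100 - 1 - 1) // 3 + 1 = 33
    out = torch.nn.functional.unfold(torch.rand(1, 1, 200, 100), kernel_size=(2, 2), stride=3)
    assert out.shape == (1, 1 * 2 * 2, 67 * 33)  # (N, C * kernel_h * kernel_w, num_blocks) = (1, 4, 2211)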
@@ -2582,23 +2592,23 @@ def setup_rnn_tests():
                                  dict(TestONNXRuntime.__dict__, opset_version=11))


-# opset 9 tests, with keep_initializers_as_inputs=False for
+# opset 9 tests, with keep_initializers_as_inputs=False for
 # IR version 4 style export.
 TestONNXRuntime_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
                                    (unittest.TestCase,),
                                    dict(TestONNXRuntime.__dict__,
                                         keep_initializers_as_inputs=False))


-# opset 10 tests, with keep_initializers_as_inputs=False for
+# opset 10 tests, with keep_initializers_as_inputs=False for
 # IR version 4 style export.
 TestONNXRuntime_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
                                     (unittest.TestCase,),
                                     dict(TestONNXRuntime.__dict__, opset_version=10,
                                          keep_initializers_as_inputs=False))


-# opset 11 tests, with keep_initializers_as_inputs=False for
+# opset 11 tests, with keep_initializers_as_inputs=False for
 # IR version 4 style export.
 TestONNXRuntime_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
                                     (unittest.TestCase,),

torch/onnx/symbolic_opset11.py (96 additions, 0 deletions)
@@ -5,6 +5,7 @@
 import torch
 import torch.onnx.symbolic_helper as sym_help
 import warnings
+import numpy

 from torch.onnx.symbolic_helper import parse_args, _unimplemented
 from torch.onnx.symbolic_opset9 import expand
@@ -514,6 +515,101 @@ def __lshift_(g, self, other):
     return lshift


+def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding_d, stride_d):
+    # Input is always 4-D (N, C, H, W)
+    # Calculate indices of sliding blocks along spatial dimension d.
+    # Sliding the kernel over the input, each dim d ranges
+    # from 0 to input[d] + 2 × padding[d] − dilation[d] × (kernel_size[d] − 1),
+    # with step size = stride[d]
+
+    blocks_d = g.op("Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)))
+    blocks_d = g.op("Sub", blocks_d, g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))))
+
+    # Stride the kernel over the input and find the starting indices along dim d
+    blocks_d_indices = g.op("Range", g.op("Constant", value_t=torch.tensor(0)),
+                            blocks_d, g.op("Constant", value_t=torch.tensor(stride_d)))
+
+    # Apply dilation to the kernel and find its indices along dim d
+    kernel_grid = numpy.arange(0, kernel_size_d * dilation_d, dilation_d)
+    kernel_grid = g.op("Constant", value_t=torch.tensor([kernel_grid]))
+
+    # Broadcast and add the kernel starting positions (indices) with
+    # kernel_grid along dim d, to get block indices along dim d
+    blocks_d_indices = g.op('Unsqueeze', blocks_d_indices, axes_i=[0])  # Reshape to [1, -1]
+    kernel_mask = g.op('Reshape', kernel_grid, g.op('Constant', value_t=torch.tensor([-1, 1])))
+    block_mask = g.op("Add", blocks_d_indices, kernel_mask)
+
+    return block_mask
+
+
+def _get_im2col_padded_input(g, input, padding_h, padding_w):
+    # Input is always a 4-D tensor (N, C, H, W)
+    # The padding tensor has the format (padding_h, padding_w);
+    # reshape it to follow the ONNX format (dim1_begin, dim2_begin, ..., dim1_end, dim2_end, ...)
+    pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
+    return g.op("Pad", input, pad)
+
+
+def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
+    batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
+    channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
+    channel_unfolded = g.op("Mul", channel_dim,
+                            g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)))
+
+    return g.op("Concat",
+                g.op("Unsqueeze", batch_dim, axes_i=[0]),
+                g.op("Unsqueeze", channel_unfolded, axes_i=[0]),
+                g.op("Constant", value_t=torch.tensor([-1])), axis_i=0)
+
+
+@parse_args('v', 'is', 'is', 'is', 'is')
+def im2col(g, input, kernel_size, dilation, padding, stride):
+    # Input is always a 4-D tensor (N, C, H, W)
+    # All other args are int[2]
+
+    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
+    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))
+
+    stride_h, stride_w = stride[0], stride[1]
+    padding_h, padding_w = padding[0], padding[1]
+    dilation_h, dilation_w = dilation[0], dilation[1]
+    kernel_h, kernel_w = kernel_size[0], kernel_size[1]
+
+    blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h, dilation_h, padding_h, stride_h)
+    blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w, dilation_w, padding_w, stride_w)
+
+    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
+    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
+
+    # For a 4-D input of size (1, 1, 3, 3) as below, with kernel_size=2, stride=1, and dilation=1:
+    # [[[[1., 2., 3.],
+    #    [4., 5., 6.],
+    #    [7., 8., 9.]]]]
+    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
+    # [[[[[1., 2., 3.],
+    #     [4., 5., 6.]],
+    #    [[4., 5., 6.],
+    #     [7., 8., 9.]]]]]
+    # Then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
+    # [[[[[[1., 2.],
+    #      [4., 5.]],
+    #     [[2., 3.],
+    #      [5., 6.]]],
+    #    [[[4., 5.],
+    #      [7., 8.]],
+    #     [[5., 6.],
+    #      [8., 9.]]]]]]
+    # Transpose dims 3 (depth) and 4 (rows), then reshape to the output shape (1, 4, 4) to get:
+    # [[[1., 2., 4., 5.],
+    #   [2., 3., 5., 6.],
+    #   [4., 5., 7., 8.],
+    #   [5., 6., 8., 9.]]]
+    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
+    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
+    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
+    return g.op("Reshape", output, output_shape)
+
+
 @parse_args('v', 'i', 'i')
 def flatten(g, input, start_dim, end_dim):
     dim = input.type().dim()
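
To make the worked example in the comment above concrete, here is a small standalone sketch (not part of the commit) that mirrors the same gather/transpose/reshape sequence with plain tensor ops on the 1x1x3x3 input, assuming kernel_size=2, stride=1, dilation=1, padding=0, and checks the result against torch.nn.functional.unfold:

    import torch

    x = torch.arange(1., 10.).reshape(1, 1, 3, 3)  # the 3x3 example from the comment

    # Block indices along one spatial dim, as in _get_im2col_indices_along_dim:
    # block starts = range(0, 3 + 2*0 - 1*(2 - 1), 1) = [0, 1]; dilated kernel offsets = [0, 1]
    starts = torch.arange(0, 3 - 1, 1)
    offsets = torch.arange(0, 2, 1)
    block_idx = starts.unsqueeze(0) + offsets.reshape(-1, 1)  # [[0, 1], [1, 2]]

    # Gather rows (dim=2), then cols (dim=4); index_select + reshape emulates ONNX Gather
    # with a 2-D index tensor
    out = x.index_select(2, block_idx.reshape(-1)).reshape(1, 1, 2, 2, 3)
    out = out.index_select(4, block_idx.reshape(-1)).reshape(1, 1, 2, 2, 2, 2)

    # Transpose dims 3 and 4, then reshape to (N, C * kernel_h * kernel_w, -1)
    out = out.permute(0, 1, 2, 4, 3, 5).reshape(1, 4, -1)

    assert torch.equal(out, torch.nn.functional.unfold(x, kernel_size=2))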
