5 | 5 | import torch |
6 | 6 | import torch.onnx.symbolic_helper as sym_help |
7 | 7 | import warnings |
| 8 | +import numpy |
8 | 9 |
9 | 10 | from torch.onnx.symbolic_helper import parse_args, _unimplemented |
10 | 11 | from torch.onnx.symbolic_opset9 import expand |
@@ -514,6 +515,101 @@ def __lshift_(g, self, other): |
514 | 515 | return lshift |
515 | 516 |
516 | 517 |
| 518 | +def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding_d, stride_d): |
| 519 | + # Input is always 4-D (N, C, H, W) |
| 520 | + # Calculate indices of sliding blocks along spatial dimension |
| 521 | +    # Slide the kernel over the input along each spatial dim d: |
| 522 | +    # indices along dim d range from 0 to input[d] + 2*padding[d] - dilation[d]*(kernel_size[d]-1) |
| 523 | +    # in steps of stride[d] |
| 524 | + |
| 525 | + blocks_d = g.op("Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2))) |
| 526 | + blocks_d = g.op("Sub", blocks_d, g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1)))) |
| 527 | + |
| 528 | + # Stride kernel over input and find starting indices along dim d |
| 529 | + blocks_d_indices = g.op("Range", g.op("Constant", value_t=torch.tensor(0)), |
| 530 | + blocks_d, g.op("Constant", value_t=torch.tensor(stride_d))) |
| 531 | + |
| 532 | + # Apply dilation on kernel and find its indices along dim d |
| 533 | + kernel_grid = numpy.arange(0, kernel_size_d * dilation_d, dilation_d) |
| 534 | + kernel_grid = g.op("Constant", value_t=torch.tensor([kernel_grid])) |
| 535 | + |
| 536 | +    # Broadcast and add the kernel starting positions (indices) with |
| 537 | + # kernel_grid along dim d, to get block indices along dim d |
| 538 | + blocks_d_indices = g.op('Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1] |
| 539 | + kernel_mask = g.op('Reshape', kernel_grid, g.op('Constant', value_t=torch.tensor([-1, 1]))) |
| 540 | + block_mask = g.op("Add", blocks_d_indices, kernel_mask) |
| 541 | + |
| 542 | + return block_mask |
| 543 | + |
| 544 | + |
| 545 | +def _get_im2col_padded_input(g, input, padding_h, padding_w): |
| 546 | +    # Input is always a 4-D tensor (N, C, H, W) |
| 547 | + # Padding tensor has the following format: (padding_h, padding_w) |
| 548 | + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) |
| 549 | + pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) |
| 550 | + return g.op("Pad", input, pad) |
| 551 | + |
| 552 | + |
| 553 | +def _get_im2col_output_shape(g, input, kernel_h, kernel_w): |
| 554 | + batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) |
| 555 | + channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) |
| 556 | + channel_unfolded = g.op("Mul", channel_dim, |
| 557 | + g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))) |
| 558 | + |
| 559 | + return g.op("Concat", |
| 560 | + g.op("Unsqueeze", batch_dim, axes_i=[0]), |
| 561 | + g.op("Unsqueeze", channel_unfolded, axes_i=[0]), |
| 562 | + g.op("Constant", value_t=torch.tensor([-1])), axis_i=0) |
| 563 | + |
| 564 | + |
| 565 | +@parse_args('v', 'is', 'is', 'is', 'is') |
| 566 | +def im2col(g, input, kernel_size, dilation, padding, stride): |
| 567 | +    # Input is always a 4-D tensor (N, C, H, W) |
| 568 | + # All other args are int[2] |
| 569 | + |
| 570 | + input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) |
| 571 | + input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) |
| 572 | + |
| 573 | + stride_h, stride_w = stride[0], stride[1] |
| 574 | + padding_h, padding_w = padding[0], padding[1] |
| 575 | + dilation_h, dilation_w = dilation[0], dilation[1] |
| 576 | + kernel_h, kernel_w = kernel_size[0], kernel_size[1] |
| 577 | + |
| 578 | + blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h, dilation_h, padding_h, stride_h) |
| 579 | + blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w, dilation_w, padding_w, stride_w) |
| 580 | + |
| 581 | + output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) |
| 582 | + padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) |
| 583 | + |
| 584 | +    # For a 4-D input tensor of shape (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 |
| 585 | + # [[[[1., 2., 3.,], |
| 586 | + # [4., 5., 6.,], |
| 587 | + # [7., 8., 9.,]]]] |
| 588 | + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: |
| 589 | + # [[[[[1., 2., 3.], |
| 590 | + # [4., 5., 6.]], |
| 591 | + # [[4., 5., 6.], |
| 592 | + # [7., 8., 9.]]]]] |
| 593 | +    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get: |
| 594 | + # [[[[[[1., 2.], |
| 595 | + # [4., 5.]], |
| 596 | + # [[2., 3.], |
| 597 | + # [5., 6]]], |
| 598 | + # [[[4., 5.], |
| 599 | + # [7., 8.]], |
| 600 | + # [[5., 6.], |
| 601 | + # [8., 9.]]]]]] |
| 602 | +    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 4, 4) to get: |
| 603 | + # [[[1., 2., 4., 5.], |
| 604 | + # [2., 3., 5., 6.], |
| 605 | + # [4., 5., 7., 8.], |
| 606 | + # [5., 6., 8., 9.]]] |
| 607 | + output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) |
| 608 | + output = g.op("Gather", output, blocks_col_indices, axis_i=4) |
| 609 | + output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) |
| 610 | + return g.op("Reshape", output, output_shape) |
| 611 | + |
| 612 | + |
517 | 613 | @parse_args('v', 'i', 'i') |
518 | 614 | def flatten(g, input, start_dim, end_dim): |
519 | 615 | dim = input.type().dim() |
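
To make the index arithmetic in _get_im2col_indices_along_dim above easier to follow, here is a minimal eager NumPy sketch (not part of the patch; the function name is purely illustrative) that mirrors the Add/Sub, Range and broadcast-add sequence the symbolic emits as ONNX ops, using the same parameters as the worked 3x3 example in the im2col comment:

import numpy

def im2col_indices_along_dim(input_d, kernel_size_d, dilation_d, padding_d, stride_d):
    # Illustrative eager re-implementation, not the actual export path.
    # Number of block starting positions along dim d (the Add/Sub pair above).
    blocks_d = input_d + 2 * padding_d - dilation_d * (kernel_size_d - 1)
    # Starting index of every sliding block along dim d (the Range op), as a row.
    blocks_d_indices = numpy.arange(0, blocks_d, stride_d).reshape(1, -1)
    # Dilated offsets inside the kernel along dim d, as a column.
    kernel_grid = numpy.arange(0, kernel_size_d * dilation_d, dilation_d).reshape(-1, 1)
    # Broadcasting the addition gives one row of block indices per kernel offset.
    return blocks_d_indices + kernel_grid

# input_d=3, kernel_size_d=2, dilation_d=1, padding_d=0, stride_d=1:
print(im2col_indices_along_dim(3, 2, 1, 0, 1))
# [[0 1]
#  [1 2]]   <- matches blocks_row_indices in the im2col comment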
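
The worked 3x3 example in the im2col comment can also be checked eagerly. torch.nn.functional.unfold dispatches to aten::im2col for 4-D inputs, so a quick sanity check of the expected (1, 4, 4) result might look like the sketch below (again illustrative, not part of the patch); a module that calls unfold, exported with opset_version=11, would then be lowered through the symbolic added here.

import torch
import torch.nn.functional as F

x = torch.arange(1., 10.).reshape(1, 1, 3, 3)  # the (1, 1, 3, 3) input from the comment
out = F.unfold(x, kernel_size=2, dilation=1, padding=0, stride=1)
print(out.shape)  # torch.Size([1, 4, 4])
print(out)
# tensor([[[1., 2., 4., 5.],
#          [2., 3., 5., 6.],
#          [4., 5., 7., 8.],
#          [5., 6., 8., 9.]]])

ONNX has no dedicated Im2Col operator, which is presumably why the export is composed out of Pad, Gather, Transpose and Reshape here.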