|
| 1 | +// Copyright (C) 2018-2022 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include "helper_transforms/unique_replacer.hpp" |
| 6 | + |
| 7 | +#include <memory> |
| 8 | +#include <vector> |
| 9 | + |
| 10 | +#include "helper_ops/unique.hpp" |
| 11 | +#include "openvino/core/rt_info.hpp" |
| 12 | +#include "openvino/opsets/opset9.hpp" |
| 13 | +#include "openvino/pass/graph_rewrite.hpp" |
| 14 | +#include "openvino/pass/pattern/matcher.hpp" |
| 15 | +#include "openvino/pass/pattern/op/wrap_type.hpp" |
| 16 | +#include "utils.hpp" |
| 17 | + |
| 18 | +using namespace std; |
| 19 | +using namespace ov; |
| 20 | +using namespace ov::pass; |
| 21 | +using namespace ov::opset9; |
| 22 | +using namespace ov::frontend::tensorflow; |
| 23 | + |
| 24 | +ov::frontend::tensorflow::pass::UniqueReplacer::UniqueReplacer() { |
| 25 | + auto unique = pattern::wrap_type<Unique>(); |
| 26 | + |
| 27 | + matcher_pass_callback callback = [=](pattern::Matcher& matcher) { |
| 28 | + NodeRegistry rg; |
| 29 | + |
| 30 | + auto unique_node = std::dynamic_pointer_cast<Unique>(matcher.get_match_root()); |
| 31 | + if (!unique_node) { |
| 32 | + return false; |
| 33 | + } |
| 34 | + |
| 35 | + auto x = unique_node->input_value(0); |
| 36 | + auto output_indices_type = unique_node->get_output_indices_type(); |
| 37 | + auto x_type = x.get_element_type(); |
| 38 | + if (!x_type.is_real() && !x_type.is_integral_number()) { |
| 39 | + return false; |
| 40 | + } |
| 41 | + |
| 42 | + // denote a number of elements in x as n |
| 43 | + auto n = get_elements_number_1d(x, element::i32, rg); |
| 44 | + |
| 45 | + // create auxiliry constants to be re-used by different operations |
| 46 | + auto zero_const = rg.make<Constant>(element::i32, Shape{1}, 0); |
| 47 | + auto one_const = rg.make<Constant>(element::i32, Shape{1}, 1); |
| 48 | + auto one_const_scalar = rg.make<Constant>(element::i32, Shape{}, 1); |
| 49 | + auto minus_one_const = rg.make<Constant>(element::i32, Shape{1}, -1); |
| 50 | + auto true_const = rg.make<Constant>(element::boolean, Shape{1}, true); |
| 51 | + auto one_const_out_idx = rg.make<Constant>(output_indices_type, Shape{1}, 1); |
| 52 | + auto zero_const_out_idx = rg.make<Constant>(output_indices_type, Shape{1}, 0); |
| 53 | + |
| 54 | + // compute unique elements but not in the original order |
| 55 | + // 1. sort elements in x in order to compute unique elements |
| 56 | + auto x_sorted = rg.make<TopK>(x, n, 0, TopK::Mode::MIN, TopK::SortType::SORT_VALUES, element::i32); |
| 57 | + // 2. generate two vectors from x_sorted vector by padding in the beginning and in the end: |
| 58 | + // x1 = [0, x0, x1, ..., xn] |
| 59 | + // x2 = [x0, x1, ..., xn, 0] |
| 60 | + auto pad = rg.make<Constant>(x_type, Shape{1}, 0); |
| 61 | + auto x1 = rg.make<Concat>(OutputVector{pad, x_sorted->output(0)}, 0); |
| 62 | + auto x2 = rg.make<Concat>(OutputVector{x_sorted->output(0), pad}, 0); |
| 63 | + // 3. compare two vectors to see where unique elements are placed |
| 64 | + // and correct a mask because the first element is always unique |
| 65 | + // because the latest boolean element must be removed from the mask since |
| 66 | + // the vectors are padded |
| 67 | + auto mask1 = rg.make<NotEqual>(x1, x2); |
| 68 | + auto mask1_part = rg.make<Slice>(mask1, one_const, minus_one_const, one_const, zero_const); |
| 69 | + auto is_unique = rg.make<Concat>(OutputVector{true_const, mask1_part}, 0); |
| 70 | + // 5. compute positions where unique elements are placed in the sorted x |
| 71 | + auto is_unique_01 = rg.make<Select>(is_unique, one_const, zero_const); |
| 72 | + auto indices = rg.make<NonZero>(is_unique_01, element::i64); |
| 73 | + auto unique_element_indices = rg.make<Squeeze>(indices, zero_const); |
| 74 | + // 6. collect unique elements but currently they are not in the original order |
| 75 | + auto unique_elements = rg.make<Gather>(x_sorted->output(0), unique_element_indices, zero_const); |
| 76 | + |
| 77 | + // compute unique elements in the original order |
| 78 | + auto unsqueeze_x = rg.make<Unsqueeze>(x, zero_const); |
| 79 | + auto unsqueeze_unique_elements = rg.make<Unsqueeze>(unique_elements, one_const); |
| 80 | + // 1. compute a mask of pair comparison where each unique element is placed in the original |
| 81 | + auto nplus1 = rg.make<Add>(n, one_const_scalar); |
| 82 | + auto unique_vs_x = rg.make<Equal>(unsqueeze_unique_elements, unsqueeze_x); |
| 83 | + auto unique_vs_x_01 = rg.make<Select>(unique_vs_x, one_const_scalar, nplus1); |
| 84 | + auto range_1nplus1 = rg.make<Range>(one_const_scalar, nplus1, one_const_scalar, element::i32); |
| 85 | + auto unsqueeze_range_1nplus1 = rg.make<Unsqueeze>(range_1nplus1, zero_const); |
| 86 | + // 2. compute a mask with indices counting from one |
| 87 | + auto unique_vs_x_ind = rg.make<Multiply>(unique_vs_x_01, unsqueeze_range_1nplus1); |
| 88 | + // 3. compute positions of the first occurence for each unique element |
| 89 | + // or these are positions of unique elements in the original order |
| 90 | + auto minimum_indices_plus1 = rg.make<ReduceMin>(unique_vs_x_ind, one_const); |
| 91 | + auto minimum_indices = rg.make<Subtract>(minimum_indices_plus1, one_const); |
| 92 | + // denote a number of unique elements as m |
| 93 | + auto m = get_elements_number_1d(minimum_indices, element::i32, rg); |
| 94 | + auto sorted_minumum_indices = |
| 95 | + rg.make<TopK>(minimum_indices, m, 0, TopK::Mode::MIN, TopK::SortType::SORT_VALUES, element::i32); |
| 96 | + auto output_unique_elements = rg.make<Gather>(x, sorted_minumum_indices->output(0), zero_const); |
| 97 | + |
| 98 | + if (!unique_node->get_output_target_inputs(0).empty()) { |
| 99 | + output_unique_elements->set_friendly_name(unique_node->get_friendly_name() + ":0"); |
| 100 | + unique_node->output(0).replace(output_unique_elements->output(0)); |
| 101 | + } |
| 102 | + |
| 103 | + if (!unique_node->get_output_target_inputs(1).empty()) { |
| 104 | + // compute the second output |
| 105 | + // indices of elements of x in the vector of unique elements |
| 106 | + // 1. compute a mask for unique elements in the original order |
| 107 | + auto unsqueeze_output_unique_elements = rg.make<Unsqueeze>(output_unique_elements, one_const); |
| 108 | + auto unique_vs_x_orig = rg.make<Equal>(unsqueeze_output_unique_elements, unsqueeze_x); |
| 109 | + auto mplus1 = rg.make<Add>(m, one_const_scalar); |
| 110 | + auto unique_vs_x_orig_01 = rg.make<Select>(unique_vs_x_orig, one_const_out_idx, zero_const_out_idx); |
| 111 | + // 2. compute positions where each element from x is located in unique elements vector |
| 112 | + // the position counts from 1 |
| 113 | + auto range_1mplus1 = rg.make<Range>(one_const_scalar, mplus1, one_const_scalar, output_indices_type); |
| 114 | + auto unsqueeze_range_1mplus1 = rg.make<Unsqueeze>(range_1mplus1, one_const); |
| 115 | + auto unique_vs_x_ind_orig = rg.make<Multiply>(unique_vs_x_orig_01, unsqueeze_range_1mplus1); |
| 116 | + auto output_idx_plus1 = rg.make<ReduceMax>(unique_vs_x_ind_orig, zero_const); |
| 117 | + auto output_idx = rg.make<Subtract>(output_idx_plus1, one_const_out_idx); |
| 118 | + |
| 119 | + output_idx->set_friendly_name(unique_node->get_friendly_name() + ":1"); |
| 120 | + unique_node->output(1).replace(output_idx->output(0)); |
| 121 | + } |
| 122 | + |
| 123 | + copy_runtime_info(unique_node, rg.get()); |
| 124 | + |
| 125 | + return true; |
| 126 | + }; |
| 127 | + |
| 128 | + auto m = make_shared<pattern::Matcher>(unique, "ov::frontend::tensorflow::pass::UniqueReplacer"); |
| 129 | + register_matcher(m, callback); |
| 130 | +} |
0 commit comments