Commit 02a7318

qqaatw authored and pytorchmergebot committed
[MPS] Add aminmax op (pytorch#101691)
Pull Request resolved: pytorch#101691
Approved by: https://github.com/malfet
1 parent 80dd847 commit 02a7318
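
With this change, aminmax dispatches to a native MPS kernel instead of erroring as unimplemented. A minimal usage sketch of what the commit enables (assumes a macOS build with the MPS backend available; illustrative only, not part of the commit):

import torch

# Hedged illustration: aminmax now runs directly on MPS tensors after this commit.
if torch.backends.mps.is_available():
    x = torch.randn(4, 5, device="mps")
    mn, mx = torch.aminmax(x)  # returns both extrema from a single call
    print(mn.item(), mx.item())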

4 files changed: +26 −18 lines changed

aten/src/ATen/native/Histogram.cpp

Lines changed: 5 additions & 17 deletions
@@ -16,8 +16,6 @@
 #include <ATen/ops/_histogramdd_from_bin_tensors.h>
 #include <ATen/ops/_histogramdd_from_bin_tensors_native.h>
 #include <ATen/ops/aminmax.h>
-#include <ATen/ops/amin.h>
-#include <ATen/ops/amax.h>
 #include <ATen/ops/empty.h>
 #include <ATen/ops/histc_native.h>
 #include <ATen/ops/histogram_native.h>
@@ -196,9 +194,8 @@ select_outer_bin_edges(const Tensor& input, c10::optional<c10::ArrayRef<double>>
   // non-empty input
   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "histogramdd", [&]() {
     if (input.is_mps()) {
-      // aminmax has not been implemented on mps.
-      Tensor min = at::amin(input, 0);
-      Tensor max = at::amax(input, 0);
+      Tensor min, max;
+      std::tie(min, max) = at::aminmax(input, 0);
 
       for (const auto i : c10::irange(N)) {
         leftmost_edges[i] = min[i].item().to<scalar_t>();
@@ -239,18 +236,9 @@ std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input,
   double rightmost_edge = max.to<double>();
 
   if (leftmost_edge == rightmost_edge && input.numel() > 0) {
-    if (input.is_mps()) {
-      // aminmax has not been implemented on mps.
-      Tensor min = at::amin(input);
-      Tensor max = at::amax(input);
-
-      leftmost_edge = min.item<double>();
-      rightmost_edge = max.item<double>();
-    } else {
-      auto extrema = aminmax(input);
-      leftmost_edge = std::get<0>(extrema).item<double>();
-      rightmost_edge = std::get<1>(extrema).item<double>();
-    }
+    auto extrema = aminmax(input);
+    leftmost_edge = std::get<0>(extrema).item<double>();
+    rightmost_edge = std::get<1>(extrema).item<double>();
   }
 
   if (leftmost_edge == rightmost_edge) {

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 19 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include <ATen/ops/all_native.h>
 #include <ATen/ops/amax_native.h>
 #include <ATen/ops/amin_native.h>
+#include <ATen/ops/aminmax_native.h>
 #include <ATen/ops/any_native.h>
 #include <ATen/ops/argmax_native.h>
 #include <ATen/ops/argmin_native.h>
@@ -989,6 +990,24 @@ Tensor trace_mps(const Tensor& self) {
   mps::reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, mps::MPSReductionType::AMIN, "amin_out_mps");
 }
 
+TORCH_IMPL_FUNC(aminmax_out_mps)
+(const Tensor& input_t, c10::optional<int64_t> dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) {
+  mps::reduction_out_mps(input_t,
+                         dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : c10::nullopt,
+                         keepdim,
+                         c10::nullopt,
+                         min_t,
+                         mps::MPSReductionType::AMIN,
+                         "aminmax_out_mps_min");
+  mps::reduction_out_mps(input_t,
+                         dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : c10::nullopt,
+                         keepdim,
+                         c10::nullopt,
+                         max_t,
+                         mps::MPSReductionType::AMAX,
+                         "aminmax_out_mps_max");
+}
+
 Tensor prod_mps(const Tensor& self, c10::optional<ScalarType> opt_dtype) {
   std::vector<int64_t> dims(self.dim());
   std::iota(dims.begin(), dims.end(), 0);
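
The new structured kernel composes the existing MPS reduction helper twice, writing the AMIN result into min_t and the AMAX result into max_t. A hedged way to exercise it from Python and cross-check against the CPU reference (assumes an MPS-capable build; not part of the commit's test changes):

import torch

# Compare the MPS aminmax kernel against the CPU implementation.
if torch.backends.mps.is_available():
    x = torch.randn(128, 64)
    mn_mps, mx_mps = torch.aminmax(x.to("mps"), dim=0, keepdim=True)
    mn_cpu, mx_cpu = torch.aminmax(x, dim=0, keepdim=True)
    assert torch.allclose(mn_mps.cpu(), mn_cpu)
    assert torch.allclose(mx_mps.cpu(), mx_cpu)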

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -3571,6 +3571,7 @@
   structured: True
   dispatch:
     CPU, CUDA: aminmax_out
+    MPS: aminmax_out_mps
 
 - func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor
   dispatch:

test/test_mps.py

Lines changed: 1 addition & 1 deletion
@@ -77,6 +77,7 @@ def mps_ops_grad_modifier(ops):
         'cdist': [torch.float32],
         'masked.scatter': [torch.float16, torch.float32],
         'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
+        'aminmax': [torch.float32],
 
         # Correctness issues
         'atanh': [torch.float32],
@@ -394,7 +395,6 @@ def mps_ops_modifier(ops):
         'rounddecimals_3': None,
         'rounddecimals_0': None,
         '__rsub__': None,
-        'aminmax': None,
         'angle': None,
         'bucketize': None,
         'cauchy_': None,
