90 changes: 90 additions & 0 deletions onnxruntime/contrib_ops/cpu/inverse.cc
@@ -0,0 +1,90 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/platform/threadpool.h"
#include "core/util/math_cpuonly.h"
#include "Eigen/src/Core/Map.h"
#include "Eigen/LU"
#include <functional>

namespace onnxruntime {
namespace contrib {
class Inverse final : public OpKernel {
public:
explicit Inverse(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* ctx) const override;

private:
template <typename T>
struct ComputeImpl;
};

ONNX_OPERATOR_KERNEL_EX(
Inverse,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", BuildKernelDefConstraints<float, double, MLFloat16>()),
Inverse);

template <typename T>
using MatrixT = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

template <typename T>
struct Inverse::ComputeImpl {
void operator()(const Tensor* input, Tensor* output,
int64_t batch_num, int64_t rows, int64_t cols) const {
auto batch_offset = batch_num * rows * cols;
const auto* input_data = input->Data<T>() + batch_offset;
auto* output_data = output->MutableData<T>() + batch_offset;

Eigen::Map<const MatrixT<T>> input_matrix(input_data, rows, cols);
Eigen::Map<MatrixT<T>> output_matrix(output_data, rows, cols);
output_matrix = input_matrix.inverse();
}
};

template <>
struct Inverse::ComputeImpl<MLFloat16> {
void operator()(const Tensor* input, Tensor* output,
int64_t batch_num, int64_t rows, int64_t cols) const {
auto batch_offset = batch_num * rows * cols;
// A direct cast to Eigen::half is safe because MLFloat16 contains only a uint16_t.
const auto* input_data = reinterpret_cast<const Eigen::half*>(input->Data<MLFloat16>() + batch_offset);
auto* output_data = reinterpret_cast<Eigen::half*>(output->MutableData<MLFloat16>() + batch_offset);

Eigen::Map<const MatrixT<Eigen::half>> input_matrix(input_data, rows, cols);
Eigen::Map<MatrixT<Eigen::half>> output_matrix(output_data, rows, cols);
output_matrix = input_matrix.inverse();
}
};

Status Inverse::Compute(OpKernelContext* ctx) const {
const auto* input = ctx->Input<Tensor>(0);
const auto elem_type = input->GetElementType();
const auto& input_shape = input->Shape();
const auto num_dim = input_shape.NumDimensions();
auto* output = ctx->Output(0, input_shape);

int64_t num_batches = 1;
const int64_t rows = input_shape.GetDims()[num_dim - 2];
const int64_t cols = input_shape.GetDims()[num_dim - 1];
if (num_dim > 2) {
num_batches = input_shape.SizeToDimension(num_dim - 2);
}

std::function<void(ptrdiff_t)> fn = [elem_type, input, output, rows, cols](ptrdiff_t batch_num) {
utils::MLTypeCallDispatcher<ComputeImpl, float, double, MLFloat16> t_disp(elem_type);
t_disp.Invoke(input, output, batch_num, rows, cols);
};

concurrency::ThreadPool::TryBatchParallelFor(ctx->GetOperatorThreadPool(), num_batches, std::move(fn), 0);

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
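
For reference, the per-batch computation above is plain Eigen: map the raw tensor buffer as a row-major matrix and assign its inverse to the output buffer. Below is a minimal standalone sketch of that step outside the ONNX Runtime kernel machinery (the 2x2 values are purely illustrative):

#include <Eigen/Dense>
#include <iostream>

int main() {
  using MatrixF = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

  // One row-major 2x2 "batch" [[4, 7], [2, 6]], laid out as it would be in a tensor buffer.
  float data[4] = {4.f, 7.f, 2.f, 6.f};
  float out[4] = {};

  Eigen::Map<const MatrixF> input_matrix(data, 2, 2);
  Eigen::Map<MatrixF> output_matrix(out, 2, 2);
  output_matrix = input_matrix.inverse();

  // input_matrix * output_matrix should be (approximately) the 2x2 identity.
  std::cout << input_matrix * output_matrix << std::endl;
  return 0;
}
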
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu_contrib_kernels.cc
@@ -62,6 +62,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);

Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@@ -127,6 +128,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
};

for (auto& function_table_entry : function_table) {
160 changes: 160 additions & 0 deletions onnxruntime/contrib_ops/cuda/inverse.cc
@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

class Inverse final : public ::onnxruntime::cuda::CudaKernel {
public:
explicit Inverse(const OpKernelInfo& info) : CudaKernel{info} {
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
using Base = CudaKernel;
using CublasHandle = cublasHandle_t;

template <typename T>
struct ComputeImpl;
};

ONNX_OPERATOR_KERNEL_EX(
Inverse,
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", BuildKernelDefConstraints<float, double, MLFloat16>()),
Inverse);

namespace inverse_internal {

template <typename T>
Status ComputeMatrixOffsets(T* workspace_data, size_t num_batches, size_t rows, IAllocatorUniquePtr<T*>& matrix_ptrs) {
std::vector<T*> cuda_ptrs;
const size_t matrix_size = rows * rows;
for (size_t i = 0; i < num_batches; ++i) {
cuda_ptrs.push_back(workspace_data);
workspace_data += matrix_size;
}
CUDA_RETURN_IF_ERROR(cudaMemcpy(matrix_ptrs.get(), cuda_ptrs.data(), sizeof(T*) * num_batches,
cudaMemcpyHostToDevice));
return Status::OK();
}

Status CheckForSingularity(const IAllocatorUniquePtr<int>& info, const std::unique_ptr<int[]>& info_cpu, size_t num_batches) {
// A non-zero info value indicates that the corresponding matrix is singular.
CUDA_RETURN_IF_ERROR(cudaMemcpy(info_cpu.get(), info.get(), sizeof(int) * num_batches,
cudaMemcpyDeviceToHost));
for (size_t i = 0; i < num_batches; ++i) {
if (info_cpu[i] != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Matrix is singular at batch:", i);
}
}
return Status::OK();
}

} // namespace inverse_internal

template <typename T>
struct Inverse::ComputeImpl {
Status operator()(Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output,
const IAllocatorUniquePtr<int>& info, const IAllocatorUniquePtr<int>& pivots,
size_t num_batches, size_t rows) const {
using namespace onnxruntime::cuda;
using namespace inverse_internal;
using CudaT = typename ToCudaType<T>::MappedType;
const size_t input_count = static_cast<size_t>(input.Shape().Size());
auto info_cpu = onnxruntime::make_unique<int[]>(num_batches);
const auto dim = static_cast<int>(rows);
const auto n_batches = static_cast<int>(num_batches);

// Make a copy of the input which will serve as a workspace as well.
if (std::is_same<T, float>::value || std::is_same<T, MLFloat16>::value) {
IAllocatorUniquePtr<float> input_workspace = inst->GetScratchBuffer<float>(input_count);
if (std::is_same<T, MLFloat16>::value) {
// Convert from MLFloat16(half) to float
Impl_Cast<CudaT, float>(reinterpret_cast<const CudaT*>(input.Data<MLFloat16>()), input_workspace.get(), input_count);
} else {
CUDA_RETURN_IF_ERROR(cudaMemcpy(input_workspace.get(), input.Data<float>(), sizeof(float) * input_count,
cudaMemcpyDeviceToDevice));
}
IAllocatorUniquePtr<float*> matrix_ptrs = inst->GetScratchBuffer<float*>(n_batches);
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(input_workspace.get(), num_batches, rows, matrix_ptrs));
// Do LU factorization
CUBLAS_RETURN_IF_ERROR(cublasSgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));

// Compute device pointer offsets for the output buffers.
// The MLFloat16 path writes into a float staging buffer and casts back to half at the end.
IAllocatorUniquePtr<float*> output_ptrs = inst->GetScratchBuffer<float*>(n_batches);
if (std::is_same<T, MLFloat16>::value) {
IAllocatorUniquePtr<float> ml_float_output = inst->GetScratchBuffer<float>(input_count);
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(ml_float_output.get(), num_batches, rows, output_ptrs));
// Do the inverse
CUBLAS_RETURN_IF_ERROR(cublasSgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
// Copy the result to output with casting
Impl_Cast<float, CudaT>(ml_float_output.get(), reinterpret_cast<CudaT*>(output.MutableData<MLFloat16>()), input_count);
// We are done here
} else {
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(output.MutableData<float>(), num_batches, rows, output_ptrs));
// Do the inverse
CUBLAS_RETURN_IF_ERROR(cublasSgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
// We are done here
}
} else if (std::is_same<T, double>::value) {
IAllocatorUniquePtr<double> input_workspace = inst->GetScratchBuffer<double>(static_cast<int>(input_count));
CUDA_RETURN_IF_ERROR(cudaMemcpy(input_workspace.get(), input.Data<double>(), sizeof(double) * input_count,
cudaMemcpyDeviceToDevice));

IAllocatorUniquePtr<double*> matrix_ptrs = inst->GetScratchBuffer<double*>(n_batches);
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<double>(input_workspace.get(), num_batches, rows, matrix_ptrs));
// Do LU factorization
CUBLAS_RETURN_IF_ERROR(cublasDgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));

// Need to compute ptrs for output buffers
IAllocatorUniquePtr<double*> output_ptrs = inst->GetScratchBuffer<double*>(n_batches);
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<double>(output.MutableData<double>(), num_batches, rows, output_ptrs));
CUBLAS_RETURN_IF_ERROR(cublasDgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
// We are done here
} else {
ORT_THROW("Type is not supported");
}
return Status::OK();
}
};

Status Inverse::ComputeInternal(OpKernelContext* ctx) const {
const auto* input = ctx->Input<Tensor>(0);
const auto& input_shape = input->Shape();
const auto num_dim = input_shape.NumDimensions();
auto* output = ctx->Output(0, input_shape);

size_t num_batches = 1;
const size_t rows = static_cast<size_t>(input_shape.GetDims()[num_dim - 2]);
const size_t cols = static_cast<size_t>(input_shape.GetDims()[num_dim - 1]);
ORT_ENFORCE(rows == cols, "Expecting square matrices");
if (num_dim > 2) {
num_batches = static_cast<size_t>(input_shape.SizeToDimension(num_dim - 2));
}

IAllocatorUniquePtr<int> info = GetScratchBuffer<int>(num_batches);
CUDA_RETURN_IF_ERROR(cudaMemset(info.get(), 0, sizeof(int) * num_batches));
IAllocatorUniquePtr<int> pivots = GetScratchBuffer<int>(rows * num_batches);

utils::MLTypeCallDispatcherRet<Status, ComputeImpl, float, double, MLFloat16> t_disp(input->GetElementType());
return t_disp.Invoke(Base::CublasHandle(), this, *input, *output, info, pivots, num_batches, rows);
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
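
The CUDA path relies on cuBLAS batched LU factorization (getrfBatched) followed by batched inversion (getriBatched); ComputeMatrixOffsets builds the array of per-matrix device pointers those routines expect, and CheckForSingularity inspects the returned info values. Below is a condensed standalone sketch of that flow for a single 2x2 float matrix, with status and info checks omitted for brevity (values are purely illustrative). Since the inverse of a transpose is the transpose of the inverse, row-major tensor data can be handed to the column-major cuBLAS routines without an explicit transpose:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 2, batch = 1;
  // Column-major 2x2 matrix [[4, 7], [2, 6]]; its inverse is [[0.6, -0.7], [-0.2, 0.4]].
  float h_A[4] = {4.f, 2.f, 7.f, 6.f};
  float h_Y[4] = {0.f};

  float *d_A = nullptr, *d_Y = nullptr;
  cudaMalloc((void**)&d_A, sizeof(h_A));
  cudaMalloc((void**)&d_Y, sizeof(h_Y));
  cudaMemcpy(d_A, h_A, sizeof(h_A), cudaMemcpyHostToDevice);

  // The batched routines take device arrays of per-matrix device pointers,
  // which is what ComputeMatrixOffsets assembles in the kernel above.
  float *h_Aptr[1] = {d_A}, *h_Yptr[1] = {d_Y};
  float **d_Aptr = nullptr, **d_Yptr = nullptr;
  cudaMalloc((void**)&d_Aptr, sizeof(h_Aptr));
  cudaMalloc((void**)&d_Yptr, sizeof(h_Yptr));
  cudaMemcpy(d_Aptr, h_Aptr, sizeof(h_Aptr), cudaMemcpyHostToDevice);
  cudaMemcpy(d_Yptr, h_Yptr, sizeof(h_Yptr), cudaMemcpyHostToDevice);

  int *d_pivots = nullptr, *d_info = nullptr;
  cudaMalloc((void**)&d_pivots, sizeof(int) * n * batch);
  cudaMalloc((void**)&d_info, sizeof(int) * batch);

  cublasHandle_t handle;
  cublasCreate(&handle);

  // In-place LU factorization, then out-of-place inversion into d_Y.
  cublasSgetrfBatched(handle, n, d_Aptr, n, d_pivots, d_info, batch);
  cublasSgetriBatched(handle, n, d_Aptr, n, d_pivots, d_Yptr, n, d_info, batch);

  cudaMemcpy(h_Y, d_Y, sizeof(h_Y), cudaMemcpyDeviceToHost);
  std::printf("%.2f %.2f\n%.2f %.2f\n", h_Y[0], h_Y[2], h_Y[1], h_Y[3]);

  cublasDestroy(handle);
  cudaFree(d_A); cudaFree(d_Y); cudaFree(d_Aptr); cudaFree(d_Yptr);
  cudaFree(d_pivots); cudaFree(d_info);
  return 0;
}
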
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda_contrib_kernels.cc
@@ -56,6 +56,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear);
@@ -110,6 +111,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear)>,
56 changes: 56 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2357,6 +2357,62 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);

// This used to be the ONNX 1.7 Inverse(12) op.
// The docs are commented out so they do not increase the binary size.
//
// static const char* Inverse_ver1_doc = R"DOC(
//Calculates inverse of a square matrix or batches of square matrices.
//Inverse takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
//and the inner-most 2 dimensions form square matrices. These matrices must be invertible (full-rank).
//The behavior where one of the matrices is not invertible is undefined. The implementation can choose
//to throw an error or output (garbage) results as is. The output is a tensor of shape `[*, M, M]`,
//containing the individual inverses of all input submatrices.
//)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(Inverse)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.Input(0, "X", "Input tensor. Every matrix in the batch must be invertible.", "T")
.Output(0, "Y", "Output tensor of the same type and shape as the input tensor.", "T")
.TypeConstraint(
"T",
{"tensor(float16)",
"tensor(float)",
"tensor(double)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
using namespace ONNX_NAMESPACE;
propagateElemTypeFromInputToOutput(ctx, 0, 0);

// Shape inference
if (hasInputShape(ctx, 0)) {
const TensorShapeProto& input_shape =
ctx.getInputType(0)->tensor_type().shape();
const int rank = static_cast<int>(input_shape.dim_size());

if (rank < 2) {
fail_shape_inference("Input rank must be >= 2.");
}

const auto mat_w = input_shape.dim(rank - 1);
const auto mat_h = input_shape.dim(rank - 2);
if (mat_w.has_dim_value() && mat_h.has_dim_value() &&
(mat_w.dim_value() != mat_h.dim_value())) {
fail_shape_inference(
"The inner-most 2 dimensions must have the same size (mat_w:",
mat_w.dim_value(),
" != mat_h:",
mat_h.dim_value(),
").");
}

// The output shape matches the input shape.
propagateShapeFromInputToOutput(ctx, 0, 0);
}
});

RegisterBertSchemas();
}
} // namespace contrib
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1052,7 +1052,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Clip)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Min)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Max)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MaxPool)>,