11 changes: 9 additions & 2 deletions cmake/CMakeLists.txt
@@ -1192,7 +1192,7 @@ function(onnxruntime_configure_target target_name)

# Keep BinSkim happy
if(MSVC AND NOT onnxruntime_target_platform MATCHES "ARM")
target_link_options(${target_name} PRIVATE "/CETCOMPAT")
target_link_options(${target_name} PRIVATE "$<$<LINK_LANGUAGE:CXX,C>:/CETCOMPAT>" "$<$<LINK_LANGUAGE:CUDA>:-Xlinker=/CETCOMPAT>")
endif()

endfunction()
@@ -1421,7 +1421,6 @@ configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_c
get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)

if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)
if(onnxruntime_CUDA_HOME)
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
@@ -1441,6 +1440,14 @@ if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch")

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
if (UNIX)
# Suppress deprecation errors (e.g., long4 in CUDA 13)
add_compile_options(-Wno-deprecated-declarations)
endif()
endif()

if (NOT WIN32)
list(APPEND CUDA_NVCC_FLAGS --compiler-options -fPIC)
endif()
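For context on the suppression added above: CUDA 13 deprecates the legacy vector types (long4, longlong4, and friends) in favor of explicitly aligned variants, so any translation unit that still touches them now warns. A minimal sketch of the kind of code that trips the warning — illustrative only, not taken from this PR:

```cuda
#include <cuda_runtime.h>

// Illustrative only: under CUDA 13 the legacy vector types such as longlong4
// carry deprecation attributes, so this kernel compiles with warnings that
// -Wno-deprecated-declarations (added above for Linux builds) silences.
__global__ void ScaleKernel(const longlong4* in, longlong4* out, long long factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  longlong4 v = in[i];  // deprecated type under CUDA 13
  v.x *= factor;
  v.y *= factor;
  v.z *= factor;
  v.w *= factor;
  out[i] = v;
}
```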
13 changes: 13 additions & 0 deletions cmake/external/cuda_configuration.cmake
@@ -58,6 +58,19 @@ macro(setup_cuda_compiler)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION)
message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}")
endif()

# For CUDA 13+, explicitly set the compiler front-end to Clang to handle
# MSVC-specific pragmas correctly in device code.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0 AND NOT DEFINED CMAKE_CUDA_COMPILER_FRONTEND_VARIANT)
message(STATUS "Setting CUDA compiler front-end to Clang by default for CUDA 13+.")
set(CMAKE_CUDA_COMPILER_FRONTEND_VARIANT "CLANG")
endif()

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
set(CMAKE_CUDA_RUNTIME_LIBRARY "Hybrid")
else()
set(CMAKE_CUDA_RUNTIME_LIBRARY "Shared")
endif()
endmacro()

macro(setup_cuda_architectures)
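As a rough illustration of the "MSVC-specific pragmas" the comment above refers to — a hypothetical header pattern, not taken from the ONNX Runtime sources — code pulled into device compilation can contain warning pragmas that only make sense to an MSVC-style front end:

```cuda
// Hypothetical pattern: MSVC-style warning pragmas guarded by _MSC_VER.
// Which front end nvcc emulates determines how such pragmas are interpreted
// when this header is included from device code.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)  // deprecated-declaration warnings
#endif

#include <vector>  // stands in for any header that would otherwise warn

#ifdef _MSC_VER
#pragma warning(pop)
#endif
```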
11 changes: 11 additions & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -191,6 +191,17 @@
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=221>")
endif()

if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
if (UNIX)
# Suppress -Wattributes warning from protobuf headers with nvcc on Linux
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-attributes>")
endif()

if (MSVC)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=20199>")
endif()
endif()

if (UNIX)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-reorder>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-reorder>")
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -1523,7 +1523,7 @@ endif()
list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo)
endif()
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS)
list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart)
endif()

if (onnxruntime_USE_TENSORRT)
@@ -1751,6 +1751,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (HAS_QSPECTRE)
list(APPEND custom_op_lib_option "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /Qspectre>")
endif()
set(custom_op_lib_link ${custom_op_lib_link} CUDA::cudart)
endif()

file(GLOB custom_op_src ${custom_op_src_patterns})
14 changes: 10 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -22,6 +22,12 @@

constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128;

// longlong4 is deprecated in CUDA 13.
// LongLong4 is similar to longlong4_32a, except that it is also visible to the host compiler (longlong4_32a is only visible to nvcc).
typedef struct __align__(32) {
long long int x, y, z, w;
} LongLong4;

// A cache for cumulated sequence length. It will be initialized in the first request, then become read-only after that.
struct CumulatedSequenceLengthCache {
onnxruntime::IAllocatorUniquePtr<void> buffer;
@@ -144,14 +150,14 @@
template <typename T>
Status LaunchStridedCopy(
cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
int max_threads_per_block);

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h)
T* out, LongLong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);

} // namespace cuda
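A short host-side sketch of how the new LongLong4 is meant to be used — the helper below is hypothetical, but it restates the struct from the header so the sketch is self-contained and pins down the size and alignment the kernels now assume:

```cuda
#include <cuda_runtime.h>

// Restated from attention_impl.h above so this sketch is self-contained.
typedef struct __align__(32) {
  long long int x, y, z, w;
} LongLong4;

// Hypothetical helper: pack strides for a contiguous BNSH tensor in (b,n,s,h)
// order, the way a caller of LaunchStridedCopy would, using the host-visible type.
inline LongLong4 MakeBnshStrides(int num_heads, int sequence_length, int head_size) {
  LongLong4 s;
  s.x = static_cast<long long>(num_heads) * sequence_length * head_size;  // batch stride
  s.y = static_cast<long long>(sequence_length) * head_size;              // head stride
  s.z = head_size;                                                        // sequence stride
  s.w = 1;                                                                // element stride
  return s;
}

static_assert(sizeof(LongLong4) == 32 && alignof(LongLong4) == 32,
              "LongLong4 must keep the size and alignment the kernels rely on");
```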
26 changes: 13 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
@@ -11,8 +11,8 @@ namespace contrib {
namespace cuda {

template <typename T>
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
__global__ void StridedCopy(const T* in, const int H, LongLong4 in_strides, // coord (b,n,s,h)
T* out, LongLong4 out_strides, // coord (b,n,s,h)
const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) {
const int h = threadIdx.x;
const int n = threadIdx.y;
@@ -30,8 +30,8 @@ __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, //
}

template <typename T>
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
__global__ void StridedCopyLarge(const T* in, const int H, LongLong4 in_strides, // coord (b,n,s,h)
T* out, LongLong4 out_strides, // coord (b,n,s,h)
const int* in_seqlens_offset, const int* out_seqlens_offset) {
// Use when (H*)*num_heads > 1024
int h = threadIdx.x;
@@ -77,7 +77,7 @@ struct ToByteType<16> {

template <>
struct ToByteType<32> {
using T = ulonglong4;
using T = LongLong4;
};

template <int NumBytes>
@@ -86,8 +86,8 @@ using ToBytes = typename ToByteType<NumBytes>::T;
template <typename T>
Status LaunchStridedCopy(
cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
int max_threads_per_block) {
int batch_size = in_shape.x;
int num_heads = in_shape.y;
@@ -157,8 +157,8 @@ Status LaunchStridedCopy(

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h)
T* out, LongLong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block) {
const int* in_seqlens_offset = nullptr;
const int* out_seqlens_offset = nullptr;
@@ -170,14 +170,14 @@ Status LaunchStridedCopy(cudaStream_t stream,

template Status LaunchStridedCopy<float>(
cudaStream_t stream,
const float* in, int4 in_shape, longlong4 in_strides,
float* out, longlong4 out_strides,
const float* in, int4 in_shape, LongLong4 in_strides,
float* out, LongLong4 out_strides,
int max_threads_per_block);

template Status LaunchStridedCopy<half>(
cudaStream_t stream,
const half* in, int4 in_shape, longlong4 in_strides,
half* out, longlong4 out_strides,
const half* in, int4 in_shape, LongLong4 in_strides,
half* out, LongLong4 out_strides,
int max_threads_per_block);

} // namespace cuda
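The ToByteType<32> change above swaps ulonglong4 for LongLong4; that works because the byte-width dispatch only relies on the size and alignment of the chosen vector type, never on its arithmetic. A simplified sketch of the idea (the real kernels add the (b,n,s,h) stride handling):

```cuda
#include <cuda_runtime.h>

// Simplified illustration of the ToBytes trick: copy NumBytes-wide elements by
// reinterpreting them as one aligned vector type. A 32-byte, 32-byte-aligned
// LongLong4 can replace ulonglong4 here because only its layout is used.
template <int NumBytes, typename Vec>
__global__ void CopyAsVec(const void* in, void* out, int n) {
  static_assert(sizeof(Vec) == NumBytes, "vector type must match the byte width");
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    reinterpret_cast<Vec*>(out)[i] = reinterpret_cast<const Vec*>(in)[i];
  }
}

// Usage sketch: CopyAsVec<32, LongLong4><<<grid, block, 0, stream>>>(src, dst, count);
```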
@@ -31,12 +31,13 @@

#pragma once

#include "core/providers/cuda/curand_wrapper.h"

#ifdef HAS_PYTORCH
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#endif

#include <curand_kernel.h>
#include <cmath>
#include <cinttypes>
#include <vector>
@@ -49,7 +49,6 @@
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -4,8 +4,9 @@
#pragma once

#include <stdint.h>
#include "core/providers/cuda/curand_wrapper.h"
#include <cuda_fp16.h>
#include <curand_kernel.h>

#include <cstdio>
#include "contrib_ops/cpu/transformers/generation_shared.h"

12 changes: 8 additions & 4 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -14,6 +14,10 @@
namespace contrib {
namespace rocm {

typedef struct __align__(32) {
long long int x, y, z, w;
} LongLong4;

size_t GetAttentionScratchSize(
size_t element_size,
int batch_size,
@@ -162,14 +166,14 @@
template <typename T>
Status LaunchStridedCopy(
hipStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
const T* in, int4 in_shape, LongLong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, LongLong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
int max_threads_per_block);

template <typename T>
Status LaunchStridedCopy(hipStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
const T* in, int4 in_shape, LongLong4 in_strides, // coord (b,n,s,h)
T* out, LongLong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block);
} // namespace rocm
} // namespace contrib
@@ -143,7 +143,7 @@ struct Strides {
int seqlen_dim,
int head_size_dim) {
ORT_UNUSED_PARAMETER(batch_dim);
return Strides{longlong4{
return Strides{LongLong4{
static_cast<int64_t>(num_head_dim) * seqlen_dim * head_size_dim,
static_cast<int64_t>(seqlen_dim) * head_size_dim,
static_cast<int64_t>(head_size_dim),
@@ -157,15 +157,15 @@
int num_head_dim,
int head_size_dim) {
ORT_UNUSED_PARAMETER(batch_dim);
return Strides{longlong4{
return Strides{LongLong4{
static_cast<int64_t>(seqlen_dim) * num_head_dim * head_size_dim,
static_cast<int64_t>(head_size_dim),
static_cast<int64_t>(num_head_dim) * head_size_dim,
static_cast<int64_t>(1),
}};
}

template <typename T = longlong4>
template <typename T = LongLong4>
T ForBNSHCoord() const {
using E = typename T::value_type;
return T{static_cast<E>(strides_for_bnsh_coord.x),
@@ -174,7 +174,7 @@
static_cast<E>(strides_for_bnsh_coord.w)};
}

template <typename T = longlong4>
template <typename T = LongLong4>
T ForBSNHCoord() const {
using E = typename T::value_type;
return T{static_cast<E>(strides_for_bnsh_coord.x),
@@ -183,7 +183,7 @@
static_cast<E>(strides_for_bnsh_coord.w)};
}

template <typename T = longlong4>
template <typename T = LongLong4>
T ForBNHSCoord() const {
using E = typename T::value_type;
return T{static_cast<E>(strides_for_bnsh_coord.x),
@@ -198,7 +198,7 @@
}

// store intermediate strides in the canonical (b,n,s,h) coordinate order
longlong4 strides_for_bnsh_coord;
LongLong4 strides_for_bnsh_coord;
};

template <typename HipT, typename T>
@@ -178,7 +178,7 @@ struct TreeNodeElement {

inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0x1F); }
inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
inline bool is_missing_track_true() const { return static_cast<uint8_t>(flags) & static_cast<uint8_t>(MissingTrack::kTrue); }

#if defined(_TREE_DEBUG)
std::string str() const {
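The is_missing_track_true() change above spells the bit test out with explicit casts. A standalone sketch of why that form is preferable when the flags field and the mask are distinct enumeration types — the enums below are illustrative stand-ins, not the real NODE_MODE_ORT/MissingTrack definitions:

```cuda
#include <cstdint>

// Illustrative stand-ins only. Bitwise '&' between two different enumeration
// types relies on conversions that C++20 deprecates (and that scoped enums
// forbid outright), so casting both operands to the underlying integer type
// keeps the mask test valid and warning-free.
enum class NodeFlags : uint8_t { LEAF = 1, BRANCH = 2 };
enum class MissingTrack : uint8_t { kFalse = 0, kTrue = 16 };

inline bool IsMissingTrackTrue(NodeFlags flags) {
  // return flags & MissingTrack::kTrue;   // ill-formed for scoped enums
  return (static_cast<uint8_t>(flags) & static_cast<uint8_t>(MissingTrack::kTrue)) != 0;
}
```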
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cuda/curand_wrapper.h
@@ -0,0 +1,12 @@
//
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#if defined(CUDA_VERSION) && CUDA_VERSION == 13000
#define __NV_NO_VECTOR_DEPRECATION_DIAG 1
#endif

#include <curand_kernel.h>
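The wrapper's only job is to define __NV_NO_VECTOR_DEPRECATION_DIAG before curand_kernel.h is parsed, which is why the other files in this PR now include it in place of a bare #include <curand_kernel.h>. A sketch of the intended include order in a .cu file (the file and kernel below are hypothetical):

```cuda
// hypothetical_uniform_fill.cu -- include the wrapper first so the suppression
// macro is already defined when curand_kernel.h is processed.
#include "core/providers/cuda/curand_wrapper.h"

#include <cstdint>

__global__ void FillUniform(float* out, int n, uint64_t seed) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/static_cast<uint64_t>(i), /*offset=*/0, &state);
  out[i] = curand_uniform(&state);
}
```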
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/cuda/fpgeneric.cu
@@ -10,9 +10,8 @@
*/
// NV_TODO: optimize speed -- pass things needed in, optimize kernel speed, add half2
// NV_TODO: investigate cub support for half

#include "core/providers/cuda/curand_wrapper.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include <curand_kernel.h>

#define TRANS_TILE_DIM 32
#define BLOCK_ROWS 8
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/generator/random_impl.cu
@@ -1,9 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/curand_wrapper.h"
#include "core/providers/cuda/generator/random_impl.h"

#include <curand_kernel.h>
#include <algorithm>
#include "core/providers/cuda/cu_inc/common.cuh"

2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/nn/dropout_impl.cu
@@ -15,10 +15,10 @@
*/

/* Modifications Copyright (c) Microsoft. */
#include "core/providers/cuda/curand_wrapper.h"

#include "core/providers/cuda/nn/dropout_impl.h"

#include <curand_kernel.h>
#include <algorithm>
#include "core/providers/cuda/cu_inc/bitmask.cuh"

@@ -1,9 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/curand_wrapper.h"
#include "orttraining/training_ops/cuda/math/bias_softmax_dropout_impl.h"

#include <curand_kernel.h>
#include <algorithm>
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/math/softmax_warpwise_impl.cuh"