3 changes: 2 additions & 1 deletion cmake/cinn.cmake
@@ -87,9 +87,10 @@ if(WITH_GPU)

message(
STATUS
"copy paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h to $ENV{runtime_include_dir}"
"copy paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h paddle/cinn/common/float8e4m3.h to $ENV{runtime_include_dir}"
)
file(COPY paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h
paddle/cinn/common/float8e4m3.h
DESTINATION $ENV{runtime_include_dir})

find_library(CUDASTUB libcuda.so HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/
3 changes: 3 additions & 0 deletions paddle/cinn/backends/codegen_c.cc
@@ -117,6 +117,7 @@ std::string CodeGenC::GetTypeName(Type type) {
GET_SCALAR_TYPE(type.is_uint(32), "uint32_t");
GET_SCALAR_TYPE(type.is_uint(64), "uint64_t");

GET_SCALAR_TYPE(type.is_float8e4m3(), "float8e4m3");
GET_SCALAR_TYPE(type.is_bfloat16(), "bfloat16");
GET_SCALAR_TYPE(type.is_float16(), "float16");
GET_SCALAR_TYPE(type.is_float(32), "float")
@@ -987,6 +988,8 @@ void CodeGenC::PrintRuntimeType(const cinn_type_t &type) {
str_ += "cinn_uint64_t()";
} else if (type == cinn_bfloat16_t()) {
str_ += "cinn_bfloat16_t()";
} else if (type == cinn_float8e4m3_t()) {
str_ += "cinn_float8e4m3_t()";
} else if (type == cinn_float16_t()) {
str_ += "cinn_float16_t()";
} else if (type == cinn_float32_t()) {
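The GetTypeName change above slots float8e4m3 into the same scalar-name dispatch already used for bfloat16 and float16. A minimal sketch of how a GET_SCALAR_TYPE-style macro could work (a hypothetical reconstruction, not the actual definition in codegen_c.cc):

// Hypothetical: return the C name for the first matching scalar predicate.
#define GET_SCALAR_TYPE(pred_expr, scalar_name) \
  if (pred_expr) return scalar_name;

// Under that expansion, GET_SCALAR_TYPE(type.is_float8e4m3(), "float8e4m3")
// becomes: if (type.is_float8e4m3()) return "float8e4m3";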
10 changes: 8 additions & 2 deletions paddle/cinn/backends/codegen_cuda_dev.cc
@@ -23,16 +23,19 @@ const std::string CodeGenCudaDev::general_source_header_ = // NOLINT
#define CINN_WITH_CUDA
#include "bfloat16.h"
#include "float16.h"
#include "float8e4m3.h"
using cinn::common::bfloat16;
using cinn::common::float16;
using cinn::common::float8;
using cinn::common::float8e4m3;
using cinn::common::half4;
using cinn::common::half8;
using cinn::common::float168;
using cinn::common::float164;
using cinn::common::float162;
using cinn::common::bfloat168;
using cinn::common::bfloat164;
using cinn::common::float8e4m32;
using cinn::common::float8e4m34;
using cinn::common::bfloat162;
#include <cooperative_groups.h>
#include "cinn_cuda_runtime_source.cuh"
@@ -47,7 +50,7 @@ const std::string CodeGenCudaDev::source_header_ = // NOLINT
#include <float16_h>
using cinn::common::bfloat16;
using cinn::common::float16;
using cinn::common::float8;
using cinn::common::float8e4m3;
using cinn::common::half4;
using cinn::common::half8;
using cinn::common::float168;
@@ -56,6 +59,9 @@ using cinn::common::float162;
using cinn::common::bfloat168;
using cinn::common::bfloat164;
using cinn::common::bfloat162;
using cinn::common::float8e4m3;
using cinn::common::float8e4m32;
using cinn::common::float8e4m34;
#include <cooperative_groups.h>
#include <cinn_cuda_runtime_source_h>
)";
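The CUDA source headers now pull in float8e4m3 plus two vectorized aliases, float8e4m32 and float8e4m34, mirroring the existing bfloat162/bfloat164 pattern. Assuming float8e4m3.h follows that pattern, the lane types are plausibly small aligned PODs along these lines (a sketch, not the real header):

// Hypothetical 2- and 4-lane FP8 vectors for coalesced loads/stores.
struct alignas(2) float8e4m32 {
  float8e4m3 x, y;
};
struct alignas(4) float8e4m34 {
  float8e4m3 x, y, z, w;
};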
8 changes: 8 additions & 0 deletions paddle/cinn/backends/llvm/codegen_llvm.cc
@@ -56,6 +56,7 @@ namespace backends {
using BinaryInstruction = llvm::Instruction::BinaryOps;
using cinn::common::bfloat16;
using cinn::common::float16;
using cinn::common::float8e4m3;

namespace {

@@ -263,6 +264,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::FloatImm *op) {
return llvm::ConstantFP::get(b_->getBFloatTy(), op->value);
} else if (op->type().is_float16()) {
return llvm::ConstantFP::get(b_->getHalfTy(), op->value);
} else if (op->type().is_float8e4m3()) {
PADDLE_THROW(::common::errors::InvalidArgument(
    "LLVM does not support float8 yet."));  // TODO(YuhanXu)
} else {
PADDLE_THROW(::common::errors::InvalidArgument("illegal float type."));
}
@@ -566,6 +570,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_bfloat16);
} else if (op->type().is_float16()) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_float16);
} else if (op->type().is_float8e4m3()) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_float8e4m3);
} else if (op->type() == type_of<void *>()) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_void_p);
} else if (op->type() == type_of<cinn_buffer_t *>() ||
@@ -1794,6 +1800,8 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::PodValueToX *op) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_double);
} else if (to_type == type_of<bfloat16>()) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_bfloat16);
} else if (to_type == type_of<float8e4m3>()) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_float8e4m3);
} else if (to_type == type_of<float16>()) {
callee = m_->getFunction(runtime::intrinsic::pod_value_to_float16);
} else if (to_type == type_of<bool>()) {
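LLVM has no first-class FP8 type here, so float8e4m3 literals and casts are rejected in the LLVM backend while POD-value conversion is routed through the pod_value_to_float8e4m3 runtime intrinsic. For reference, E4M3 packs 1 sign bit, 4 exponent bits (bias 7), and 3 mantissa bits into a byte; in the FN flavor common in ML there are no infinities, and only exponent 0b1111 with mantissa 0b111 is NaN. A standalone decoder sketch (illustrative only, not CINN code):

#include <cmath>
#include <cstdint>

// Decode an FP8 E4M3(FN) byte to float.
float fp8_e4m3_to_float(uint8_t v) {
  int sign = (v >> 7) & 0x1;
  int exp = (v >> 3) & 0xF;
  int man = v & 0x7;
  float out;
  if (exp == 0xF && man == 0x7) {
    out = std::nanf("");  // no infinities in E4M3FN, only NaN
  } else if (exp == 0) {
    out = std::ldexp(man / 8.0f, -6);  // subnormal: (m/8) * 2^(1-7)
  } else {
    out = std::ldexp(1.0f + man / 8.0f, exp - 7);  // normal
  }
  return sign ? -out : out;
}

This yields a largest finite value of (1 + 6/8) * 2^8 = 448, which is why E4M3 suits activations and weights while the wider-range E5M2 is preferred for gradients.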
11 changes: 11 additions & 0 deletions paddle/cinn/backends/llvm/llvm_util.cc
@@ -25,6 +25,7 @@ namespace backends {

using cinn::common::bfloat16;
using cinn::common::float16;
using cinn::common::float8e4m3;

llvm::Type *CinnTypeToLLVMType(cinn::common::Type type,
llvm::Module *m,
@@ -52,6 +53,8 @@ llvm::Type *CinnTypeToLLVMType(cinn::common::Type type,
llvm::Type *f16 = llvm::Type::getHalfTy(m->getContext());
llvm::Type *f32 = llvm::Type::getFloatTy(m->getContext());
llvm::Type *f64 = llvm::Type::getDoubleTy(m->getContext());
llvm::Type *f8e4m3 = llvm::Type::getInt8Ty(
    m->getContext());  // TODO(YuhanXu): LLVM does not support fp8 natively
llvm::Type *arr =
llvm::Type::getPrimitiveType(m->getContext(), llvm::Type::ArrayTyID);
if (type.is_void() && type.is_cpp_handle()) {
@@ -87,6 +90,13 @@ llvm::Type *CinnTypeToLLVMType(cinn::common::Type type,
ir_type = bf16;
} else if (type.is_float16()) {
ir_type = f16;
} else if (type.is_float8e4m3()) {
  // TODO(YuhanXu): LLVM does not support fp8 yet.
  PADDLE_ENFORCE_NOT_NULL(ir_type,
                          ::common::errors::InvalidArgument(
                              "LLVM can't convert type: f8e4m3."));
  ir_type = f8e4m3;
} else if (type.is_void()) {
ir_type = v;
} else if (type.is_string()) {
@@ -140,6 +150,7 @@ __(bfloat16)
__(float16)
__(float)
__(double)
__(float8e4m3)
__(cinn_buffer_t)
__(cinn_buffer_t *)
__(cinn_pod_value_t *)
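The __(float8e4m3) entry extends an X-macro list that stamps out one llvm_type_of specialization per POD type. The expansion plausibly looks like this (a hypothetical reconstruction of the macro in llvm_util.cc, not the verified definition):

// Hypothetical: each __(T) instantiates llvm_type_of<T>.
#define __(T)                                     \
  template <>                                     \
  llvm::Type *llvm_type_of<T>(llvm::Module * m) { \
    return CinnTypeToLLVMType(type_of<T>(), m);   \
  }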
4 changes: 4 additions & 0 deletions paddle/cinn/backends/nvrtc/header_generator.cc
@@ -67,6 +67,8 @@ static const std::string cinn_float16_header = // NOLINT
read_file_as_string("float16.h");
static const std::string cinn_bfloat16_header = // NOLINT
read_file_as_string("bfloat16.h");
static const std::string cinn_float8e4m3_header = // NOLINT
read_file_as_string("float8e4m3.h");
static const std::string cinn_with_cuda_header = // NOLINT
R"(
#pragma once
@@ -86,6 +88,8 @@ JitSafeHeaderGenerator::JitSafeHeaderGenerator() {
headers_.emplace_back(cinn_float16_header.data());
include_names_.emplace_back("bfloat16_h");
headers_.emplace_back(cinn_bfloat16_header.data());
include_names_.emplace_back("float8e4m3_h");
headers_.emplace_back(cinn_float8e4m3_header.data());
include_names_.emplace_back("cinn_with_cuda_h");
headers_.emplace_back(cinn_with_cuda_header.data());
include_names_.emplace_back("cinn_cuda_runtime_source_h");
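JitSafeHeaderGenerator pairs each in-memory header body with the pseudo-include name that the generated CUDA source references (e.g. #include <float8e4m3_h>). Those two parallel arrays are exactly what NVRTC's standard API consumes; a sketch of the assumed hand-off (kernel_source and the member types are assumptions, nvrtcCreateProgram is the real NVRTC entry point):

// Sketch: how paired header bodies and include names reach NVRTC.
nvrtcProgram prog;
nvrtcCreateProgram(&prog,
                   kernel_source,           // generated .cu text
                   "cinn_kernel.cu",
                   static_cast<int>(headers_.size()),
                   headers_.data(),         // bodies, e.g. float8e4m3.h text
                   include_names_.data());  // names like "float8e4m3_h"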
5 changes: 5 additions & 0 deletions paddle/cinn/common/cinn_value.h
@@ -137,6 +137,9 @@ class CINNValue : public cinn_pod_value_t {
explicit CINNValue(float value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<float>();
}
explicit CINNValue(float8e4m3 value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<float8e4m3>();
}
explicit CINNValue(bfloat16 value) : cinn_pod_value_t(value) {
type_code_ = ::cinn_type_code<bfloat16>();
}
@@ -163,6 +166,7 @@
using cinn_pod_value_t::operator double;
using cinn_pod_value_t::operator float;
using cinn_pod_value_t::operator cinn::common::bfloat16;
using cinn_pod_value_t::operator cinn::common::float8e4m3;
using cinn_pod_value_t::operator cinn::common::float16;
using cinn_pod_value_t::operator bool;
using cinn_pod_value_t::operator int32_t;
@@ -189,6 +193,7 @@
CINNValue& operator=(float value);
CINNValue& operator=(double value);
CINNValue& operator=(bfloat16 value);
CINNValue& operator=(float8e4m3 value);
CINNValue& operator=(float16 value);
CINNValue& operator=(char* value);
CINNValue& operator=(const std::string& value);
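With the new constructor, conversion operator, and assignment overload, float8e4m3 round-trips through CINNValue like the other scalar types. A usage sketch (assuming float8e4m3 is constructible from float, as float16 and bfloat16 are):

cinn::common::float8e4m3 x(0.5f);
CINNValue v(x);  // stores the payload plus cinn_type_code<float8e4m3>()
auto back = static_cast<cinn::common::float8e4m3>(v);  // operator float8e4m3
v = cinn::common::float8e4m3(1.0f);                    // operator=(float8e4m3)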