88
99#include " RawBufferMethods.h"
1010#include " AlignmentSizeCalculator.h"
11+ #include " LowerTypeVisitor.h"
1112#include " clang/AST/ASTContext.h"
1213#include " clang/AST/CharUnits.h"
1314#include " clang/AST/RecordLayout.h"
@@ -284,44 +285,48 @@ SpirvInstruction *RawBufferHandler::processTemplatedLoadFromBuffer(
284285 // aligned like their field with the largest alignment.
285286 // As a result, there might exist some padding after some struct members.
286287 if (const auto *structType = targetType->getAs <RecordType>()) {
287- const auto *decl = structType->getDecl ();
288+ LowerTypeVisitor lowerTypeVisitor (astContext, theEmitter.getSpirvContext (),
289+ theEmitter.getSpirvOptions (), spvBuilder);
290+ auto *decl = targetType->getAsTagDecl ();
291+ assert (decl && " Expected all structs to be tag decls." );
292+ const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType (
293+ targetType, theEmitter.getSpirvOptions ().sBufferLayoutRule , llvm::None,
294+ decl->getLocation ()));
288295 llvm::SmallVector<SpirvInstruction *, 4 > loadedElems;
289- uint32_t fieldOffsetInBytes = 0 ;
290- uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
291- std::tie (structAlignment, structSize) =
292- AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
293- .getAlignmentAndSize (targetType,
294- theEmitter.getSpirvOptions ().sBufferLayoutRule ,
295- llvm::None, &stride);
296- for (const auto *field : decl->fields ()) {
297- AlignmentSizeCalculator alignmentCalc (astContext,
298- theEmitter.getSpirvOptions ());
299- uint32_t fieldSize = 0 , fieldAlignment = 0 ;
300- std::tie (fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize (
301- field->getType (), theEmitter.getSpirvOptions ().sBufferLayoutRule ,
302- /* isRowMajor*/ llvm::None, &stride);
303- fieldOffsetInBytes = roundToPow2 (fieldOffsetInBytes, fieldAlignment);
304- auto *byteOffset = address.getByteAddress ();
305- if (fieldOffsetInBytes != 0 ) {
306- byteOffset = spvBuilder.createBinaryOp (
307- spv::Op::OpIAdd, astContext.UnsignedIntTy , byteOffset,
308- spvBuilder.getConstantInt (astContext.UnsignedIntTy ,
309- llvm::APInt (32 , fieldOffsetInBytes)),
310- loc, range);
311- }
312-
313- loadedElems.push_back (processTemplatedLoadFromBuffer (
314- buffer, byteOffset, field->getType (), range));
296+ forEachSpirvField (
297+ structType, spvType,
298+ [this , &buffer, &address, range,
299+ &loadedElems](size_t spirvFieldIndex, const QualType &fieldType,
300+ const auto &field) {
301+ auto *baseOffset = address.getByteAddress ();
302+ if (field.offset .hasValue () && field.offset .getValue () != 0 ) {
303+ const auto loc = buffer->getSourceLocation ();
304+ SpirvConstant *offset = spvBuilder.getConstantInt (
305+ astContext.UnsignedIntTy ,
306+ llvm::APInt (32 , field.offset .getValue ()));
307+ baseOffset = spvBuilder.createBinaryOp (
308+ spv::Op::OpIAdd, astContext.UnsignedIntTy , baseOffset, offset,
309+ loc, range);
310+ }
315311
316- fieldOffsetInBytes += fieldSize;
317- }
312+ loadedElems.push_back (processTemplatedLoadFromBuffer (
313+ buffer, baseOffset, fieldType, range));
314+ return true ;
315+ });
318316
319317 // After we're done with loading the entire struct, we need to update the
320318 // byteAddress (in case we are loading an array of structs).
321319 //
322320 // struct size = 34 bytes (34 / 8) = 4 full words (34 % 8) = 2 > 0,
323321 // therefore need to move to the next aligned address So the starting byte
324322 // offset after loading the entire struct is: 8 * (4 + 1) = 40
323+ uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
324+ std::tie (structAlignment, structSize) =
325+ AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
326+ .getAlignmentAndSize (targetType,
327+ theEmitter.getSpirvOptions ().sBufferLayoutRule ,
328+ llvm::None, &stride);
329+
325330 assert (structAlignment != 0 );
326331 SpirvInstruction *structWidth = spvBuilder.getConstantInt (
327332 astContext.UnsignedIntTy ,
@@ -577,7 +582,7 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
577582 return ;
578583 default :
579584 theEmitter.emitError (
580- " templated load of ByteAddressBuffer is only implemented for "
585+ " templated store of ByteAddressBuffer is only implemented for "
581586 " 16, 32, and 64-bit types" ,
582587 loc);
583588 return ;
@@ -604,40 +609,36 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
604609 // aligned like their field with the largest alignment.
605610 // As a result, there might exist some padding after some struct members.
606611 if (const auto *structType = valueType->getAs <RecordType>()) {
607- const auto *decl = structType->getDecl ();
608- uint32_t fieldOffsetInBytes = 0 ;
609- uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
610- std::tie (structAlignment, structSize) =
611- AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
612- .getAlignmentAndSize (valueType,
613- theEmitter.getSpirvOptions ().sBufferLayoutRule ,
614- llvm::None, &stride);
615- uint32_t fieldIndex = 0 ;
616- for (const auto *field : decl->fields ()) {
617- AlignmentSizeCalculator alignmentCalc (astContext,
618- theEmitter.getSpirvOptions ());
619- uint32_t fieldSize = 0 , fieldAlignment = 0 ;
620- std::tie (fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize (
621- field->getType (), theEmitter.getSpirvOptions ().sBufferLayoutRule ,
622- /* isRowMajor*/ llvm::None, &stride);
623- fieldOffsetInBytes = roundToPow2 (fieldOffsetInBytes, fieldAlignment);
624- auto *byteOffset = address.getByteAddress ();
625- if (fieldOffsetInBytes != 0 ) {
626- byteOffset = spvBuilder.createBinaryOp (
627- spv::Op::OpIAdd, astContext.UnsignedIntTy , byteOffset,
628- spvBuilder.getConstantInt (astContext.UnsignedIntTy ,
629- llvm::APInt (32 , fieldOffsetInBytes)),
630- loc, range);
631- }
612+ LowerTypeVisitor lowerTypeVisitor (astContext, theEmitter.getSpirvContext (),
613+ theEmitter.getSpirvOptions (), spvBuilder);
614+ auto *decl = valueType->getAsTagDecl ();
615+ assert (decl && " Expected all structs to be tag decls." );
616+ const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType (
617+ valueType, theEmitter.getSpirvOptions ().sBufferLayoutRule , llvm::None,
618+ decl->getLocation ()));
619+ assert (spvType);
620+ forEachSpirvField (
621+ structType, spvType,
622+ [this , &address, loc, range, buffer, value](size_t spirvFieldIndex,
623+ const QualType &fieldType,
624+ const auto &field) {
625+ auto *baseOffset = address.getByteAddress ();
626+ if (field.offset .hasValue () && field.offset .getValue () != 0 ) {
627+ SpirvConstant *offset = spvBuilder.getConstantInt (
628+ astContext.UnsignedIntTy ,
629+ llvm::APInt (32 , field.offset .getValue ()));
630+ baseOffset = spvBuilder.createBinaryOp (
631+ spv::Op::OpIAdd, astContext.UnsignedIntTy , baseOffset, offset,
632+ loc, range);
633+ }
632634
633- processTemplatedStoreToBuffer (
634- spvBuilder.createCompositeExtract (field->getType (), value,
635- {fieldIndex}, loc, range),
636- buffer, byteOffset, field->getType (), range);
637-
638- fieldOffsetInBytes += fieldSize;
639- ++fieldIndex;
640- }
635+ processTemplatedStoreToBuffer (
636+ spvBuilder.createCompositeExtract (
637+ fieldType, value, {static_cast <uint32_t >(spirvFieldIndex)},
638+ loc, range),
639+ buffer, baseOffset, fieldType, range);
640+ return true ;
641+ });
641642
642643 // After we're done with storing the entire struct, we need to update the
643644 // byteAddress (in case we are storing an array of structs).
@@ -647,6 +648,13 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
647648 // (34 % 8) = 2 > 0, therefore need to move to the next aligned address
648649 // So the starting byte offset after loading the entire struct is:
649650 // 8 * (4 + 1) = 40
651+ uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
652+ std::tie (structAlignment, structSize) =
653+ AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
654+ .getAlignmentAndSize (valueType,
655+ theEmitter.getSpirvOptions ().sBufferLayoutRule ,
656+ llvm::None, &stride);
657+
650658 assert (structAlignment != 0 );
651659 auto *structWidth = spvBuilder.getConstantInt (
652660 astContext.UnsignedIntTy ,
0 commit comments