@@ -8436,89 +8436,52 @@ void TranslateStructBufSubscriptUser(Instruction *user, Value *handle,
8436
8436
: stInst->getValueOperand ()->getType ();
8437
8437
Type *pOverloadTy = Ty->getScalarType ();
8438
8438
Value *offset = baseOffset;
8439
- unsigned arraySize = 1 ;
8440
- Value *eltSize = nullptr ;
8441
8439
8442
- if (pOverloadTy->isArrayTy ()) {
8443
- arraySize = pOverloadTy->getArrayNumElements ();
8444
- eltSize = OP->GetU32Const (
8445
- DL.getTypeAllocSize (pOverloadTy->getArrayElementType ()));
8440
+ if (ldInst) {
8441
+ unsigned numComponents = 0 ;
8442
+ Value *newLd = nullptr ;
8443
+ if (VectorType *VTy = dyn_cast<VectorType>(Ty))
8444
+ numComponents = VTy->getNumElements ();
8445
+ else
8446
+ numComponents = 1 ;
8446
8447
8447
- pOverloadTy = pOverloadTy->getArrayElementType ()->getScalarType ();
8448
- }
8448
+ if (ResKind == HLResource::Kind::TypedBuffer) {
8449
+ // Typed buffer cannot have offsets, they must be loaded all at once
8450
+ ResRetValueArray ResRet = GenerateTypedBufferLoad (
8451
+ handle, pOverloadTy, bufIdx, status, OP, Builder);
8449
8452
8450
- if (ldInst) {
8451
- auto LdElement = [=](Value *offset, IRBuilder<> &Builder) -> Value * {
8452
- unsigned numComponents = 0 ;
8453
- if (VectorType *VTy = dyn_cast<VectorType>(Ty)) {
8454
- numComponents = VTy->getNumElements ();
8455
- } else {
8456
- numComponents = 1 ;
8457
- }
8453
+ newLd = ExtractFromTypedBufferLoad (ResRet, Ty, offset, Builder);
8454
+ } else {
8455
+ Value *ResultElts[4 ];
8458
8456
Constant *alignment =
8459
8457
OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8460
- if (ResKind == HLResource::Kind::TypedBuffer) {
8461
- // Typed buffer cannot have offsets, they must be loaded all at once
8462
- ResRetValueArray ResRet = GenerateTypedBufferLoad (
8463
- handle, pOverloadTy, bufIdx, status, OP, Builder);
8464
-
8465
- return ExtractFromTypedBufferLoad (ResRet, Ty, offset, Builder);
8466
- } else {
8467
- Value *ResultElts[4 ];
8468
- GenerateRawBufLd (handle, bufIdx, offset, status, pOverloadTy,
8469
- ResultElts, OP, Builder, numComponents, alignment);
8470
- return ScalarizeElements (Ty, ResultElts, Builder);
8471
- }
8472
- };
8473
-
8474
- Value *newLd = LdElement (offset, Builder);
8475
- if (arraySize > 1 ) {
8476
- newLd =
8477
- Builder.CreateInsertValue (UndefValue::get (Ty), newLd, (uint64_t )0 );
8478
-
8479
- for (unsigned i = 1 ; i < arraySize; i++) {
8480
- offset = Builder.CreateAdd (offset, eltSize);
8481
- Value *eltLd = LdElement (offset, Builder);
8482
- newLd = Builder.CreateInsertValue (newLd, eltLd, i);
8483
- }
8458
+ GenerateRawBufLd (handle, bufIdx, offset, status, pOverloadTy,
8459
+ ResultElts, OP, Builder, numComponents, alignment);
8460
+ newLd = ScalarizeElements (Ty, ResultElts, Builder);
8484
8461
}
8462
+
8485
8463
ldInst->replaceAllUsesWith (newLd);
8486
8464
} else {
8487
8465
Value *val = stInst->getValueOperand ();
8488
- auto StElement = [&](Value *offset, Value *val, IRBuilder<> &Builder) {
8489
- Value *undefVal = llvm::UndefValue::get (pOverloadTy);
8490
- Value *vals[] = {undefVal, undefVal, undefVal, undefVal};
8491
- uint8_t mask = 0 ;
8492
- if (Ty->isVectorTy ()) {
8493
- unsigned vectorNumElements = Ty->getVectorNumElements ();
8494
- DXASSERT (vectorNumElements <= 4 , " up to 4 elements in vector" );
8495
- assert (vectorNumElements <= 4 );
8496
- for (unsigned i = 0 ; i < vectorNumElements; i++) {
8497
- vals[i] = Builder.CreateExtractElement (val, i);
8498
- mask |= (1 << i);
8499
- }
8500
- } else {
8501
- vals[0 ] = val;
8502
- mask = DXIL::kCompMask_X ;
8503
- }
8504
- Constant *alignment =
8505
- OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8506
- GenerateStructBufSt (handle, bufIdx, offset, pOverloadTy, OP, Builder,
8507
- vals, mask, alignment);
8508
- };
8509
- if (arraySize > 1 )
8510
- val = Builder.CreateExtractValue (val, 0 );
8511
-
8512
- StElement (offset, val, Builder);
8513
- if (arraySize > 1 ) {
8514
- val = stInst->getValueOperand ();
8515
-
8516
- for (unsigned i = 1 ; i < arraySize; i++) {
8517
- offset = Builder.CreateAdd (offset, eltSize);
8518
- Value *eltVal = Builder.CreateExtractValue (val, i);
8519
- StElement (offset, eltVal, Builder);
8466
+ Value *undefVal = llvm::UndefValue::get (pOverloadTy);
8467
+ Value *vals[] = {undefVal, undefVal, undefVal, undefVal};
8468
+ uint8_t mask = 0 ;
8469
+ if (Ty->isVectorTy ()) {
8470
+ unsigned vectorNumElements = Ty->getVectorNumElements ();
8471
+ DXASSERT (vectorNumElements <= 4 , " up to 4 elements in vector" );
8472
+ assert (vectorNumElements <= 4 );
8473
+ for (unsigned i = 0 ; i < vectorNumElements; i++) {
8474
+ vals[i] = Builder.CreateExtractElement (val, i);
8475
+ mask |= (1 << i);
8520
8476
}
8477
+ } else {
8478
+ vals[0 ] = val;
8479
+ mask = DXIL::kCompMask_X ;
8521
8480
}
8481
+ Constant *alignment =
8482
+ OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8483
+ GenerateStructBufSt (handle, bufIdx, offset, pOverloadTy, OP, Builder,
8484
+ vals, mask, alignment);
8522
8485
}
8523
8486
user->eraseFromParent ();
8524
8487
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(user)) {
0 commit comments