@@ -8368,96 +8368,59 @@ void TranslateStructBufSubscriptUser(Instruction *user, Value *handle,
8368
8368
baseOffset, status, OP, DL);
8369
8369
}
8370
8370
} else if (isa<LoadInst>(user) || isa<StoreInst>(user)) {
8371
- LoadInst *ldInst = dyn_cast<LoadInst>(user);
8372
- StoreInst *stInst = dyn_cast<StoreInst>(user);
8371
+ LoadInst *LdInst = dyn_cast<LoadInst>(user);
8372
+ StoreInst *StInst = dyn_cast<StoreInst>(user);
8373
8373
8374
- Type *Ty = isa<LoadInst>(user) ? ldInst ->getType ()
8375
- : stInst ->getValueOperand ()->getType ();
8374
+ Type *Ty = isa<LoadInst>(user) ? LdInst ->getType ()
8375
+ : StInst ->getValueOperand ()->getType ();
8376
8376
Type *pOverloadTy = Ty->getScalarType ();
8377
- Value *offset = baseOffset;
8378
- unsigned arraySize = 1 ;
8379
- Value *eltSize = nullptr ;
8377
+ Value *Offset = baseOffset;
8380
8378
8381
- if (pOverloadTy->isArrayTy ()) {
8382
- arraySize = pOverloadTy->getArrayNumElements ();
8383
- eltSize = OP->GetU32Const (
8384
- DL.getTypeAllocSize (pOverloadTy->getArrayElementType ()));
8379
+ if (LdInst) {
8380
+ unsigned NumComponents = 0 ;
8381
+ Value *NewLd = nullptr ;
8382
+ if (VectorType *VTy = dyn_cast<VectorType>(Ty))
8383
+ NumComponents = VTy->getNumElements ();
8384
+ else
8385
+ NumComponents = 1 ;
8385
8386
8386
- pOverloadTy = pOverloadTy->getArrayElementType ()->getScalarType ();
8387
- }
8387
+ if (ResKind == HLResource::Kind::TypedBuffer) {
8388
+ // Typed buffer cannot have offsets, they must be loaded all at once
8389
+ ResRetValueArray ResRet = GenerateTypedBufferLoad (
8390
+ handle, pOverloadTy, bufIdx, status, OP, Builder);
8388
8391
8389
- if (ldInst) {
8390
- auto LdElement = [=](Value *offset, IRBuilder<> &Builder) -> Value * {
8391
- unsigned numComponents = 0 ;
8392
- if (VectorType *VTy = dyn_cast<VectorType>(Ty)) {
8393
- numComponents = VTy->getNumElements ();
8394
- } else {
8395
- numComponents = 1 ;
8396
- }
8397
- Constant *alignment =
8392
+ NewLd = ExtractFromTypedBufferLoad (ResRet, Ty, Offset, Builder);
8393
+ } else {
8394
+ Value *ResultElts[4 ];
8395
+ Constant *Alignment =
8398
8396
OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8399
- if (ResKind == HLResource::Kind::TypedBuffer) {
8400
- // Typed buffer cannot have offsets, they must be loaded all at once
8401
- ResRetValueArray ResRet = GenerateTypedBufferLoad (
8402
- handle, pOverloadTy, bufIdx, status, OP, Builder);
8403
-
8404
- return ExtractFromTypedBufferLoad (ResRet, Ty, offset, Builder);
8405
- } else {
8406
- Value *ResultElts[4 ];
8407
- GenerateRawBufLd (handle, bufIdx, offset, status, pOverloadTy,
8408
- ResultElts, OP, Builder, numComponents, alignment);
8409
- return ScalarizeElements (Ty, ResultElts, Builder);
8410
- }
8411
- };
8412
-
8413
- Value *newLd = LdElement (offset, Builder);
8414
- if (arraySize > 1 ) {
8415
- newLd =
8416
- Builder.CreateInsertValue (UndefValue::get (Ty), newLd, (uint64_t )0 );
8417
-
8418
- for (unsigned i = 1 ; i < arraySize; i++) {
8419
- offset = Builder.CreateAdd (offset, eltSize);
8420
- Value *eltLd = LdElement (offset, Builder);
8421
- newLd = Builder.CreateInsertValue (newLd, eltLd, i);
8422
- }
8397
+ GenerateRawBufLd (handle, bufIdx, Offset, status, pOverloadTy,
8398
+ ResultElts, OP, Builder, NumComponents, Alignment);
8399
+ NewLd = ScalarizeElements (Ty, ResultElts, Builder);
8423
8400
}
8424
- ldInst->replaceAllUsesWith (newLd);
8401
+
8402
+ LdInst->replaceAllUsesWith (NewLd);
8425
8403
} else {
8426
- Value *val = stInst->getValueOperand ();
8427
- auto StElement = [&](Value *offset, Value *val, IRBuilder<> &Builder) {
8428
- Value *undefVal = llvm::UndefValue::get (pOverloadTy);
8429
- Value *vals[] = {undefVal, undefVal, undefVal, undefVal};
8430
- uint8_t mask = 0 ;
8431
- if (Ty->isVectorTy ()) {
8432
- unsigned vectorNumElements = Ty->getVectorNumElements ();
8433
- DXASSERT (vectorNumElements <= 4 , " up to 4 elements in vector" );
8434
- assert (vectorNumElements <= 4 );
8435
- for (unsigned i = 0 ; i < vectorNumElements; i++) {
8436
- vals[i] = Builder.CreateExtractElement (val, i);
8437
- mask |= (1 << i);
8438
- }
8439
- } else {
8440
- vals[0 ] = val;
8441
- mask = DXIL::kCompMask_X ;
8442
- }
8443
- Constant *alignment =
8444
- OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8445
- GenerateStructBufSt (handle, bufIdx, offset, pOverloadTy, OP, Builder,
8446
- vals, mask, alignment);
8447
- };
8448
- if (arraySize > 1 )
8449
- val = Builder.CreateExtractValue (val, 0 );
8450
-
8451
- StElement (offset, val, Builder);
8452
- if (arraySize > 1 ) {
8453
- val = stInst->getValueOperand ();
8454
-
8455
- for (unsigned i = 1 ; i < arraySize; i++) {
8456
- offset = Builder.CreateAdd (offset, eltSize);
8457
- Value *eltVal = Builder.CreateExtractValue (val, i);
8458
- StElement (offset, eltVal, Builder);
8404
+ Value *val = StInst->getValueOperand ();
8405
+ Value *undefVal = llvm::UndefValue::get (pOverloadTy);
8406
+ Value *vals[] = {undefVal, undefVal, undefVal, undefVal};
8407
+ uint8_t mask = 0 ;
8408
+ if (Ty->isVectorTy ()) {
8409
+ unsigned vectorNumElements = Ty->getVectorNumElements ();
8410
+ DXASSERT (vectorNumElements <= 4 , " up to 4 elements in vector" );
8411
+ assert (vectorNumElements <= 4 );
8412
+ for (unsigned i = 0 ; i < vectorNumElements; i++) {
8413
+ vals[i] = Builder.CreateExtractElement (val, i);
8414
+ mask |= (1 << i);
8459
8415
}
8416
+ } else {
8417
+ vals[0 ] = val;
8418
+ mask = DXIL::kCompMask_X ;
8460
8419
}
8420
+ Constant *alignment =
8421
+ OP->GetI32Const (DL.getTypeAllocSize (Ty->getScalarType ()));
8422
+ GenerateStructBufSt (handle, bufIdx, Offset, pOverloadTy, OP, Builder,
8423
+ vals, mask, alignment);
8461
8424
}
8462
8425
user->eraseFromParent ();
8463
8426
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(user)) {
0 commit comments