diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 577c248cc09e..f28a20105500 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -48,6 +48,13 @@ llvm::cl::opt<bool>
     llvm::cl::desc("force use mma sync instead of wmma ops"),
     llvm::cl::init(false));
 
+/// Flag used to set the expected context length.
+/// Note: With a system prompt we start with ~141 tokens.
+llvm::cl::opt<int64_t> clGPUContextLength(
+    "iree-codegen-llvmgpu-context-length",
+    llvm::cl::desc("Sets expected context length to prevent overflow"),
+    llvm::cl::init(512));
+
 namespace {
 
 constexpr StringLiteral kCudaTarget = "cuda";
@@ -299,7 +306,33 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint,
     staticNonUnitParallelDimCount +=
         bounds[nDim] != 1 && !ShapedType::isDynamic(bounds[nDim]);
   }
-  if (staticNonUnitParallelDimCount <= 1)
+
+  // Since all parallel dims are tiled to 1 along the grid, there are cases
+  // where we overflow gridDim. As a temporary heuristic, take the slow path
+  // when the expected context length and kMaxGridDim suggest an overflow.
+  // TODO: This is a conservative, temporary solution; we need a better
+  // distribution of the parallel dims.
+  bool mayOverflowGridinWarpReduc = false;
+  {
+    SmallVector<int64_t> parallelDims;
+    op.getParallelDims(parallelDims);
+    std::optional<int64_t> parallelSize = 1;
+    const int kMaxGridDim = 4190000;
+    bool hasDynParallelDim = false;
+    for (int64_t dim : parallelDims) {
+      if (ShapedType::isDynamic(bounds[dim])) {
+        hasDynParallelDim = true;
+        continue;
+      }
+      *parallelSize *= bounds[dim];
+    }
+    if (parallelSize && hasDynParallelDim &&
+        *parallelSize * clGPUContextLength >= kMaxGridDim) {
+      mayOverflowGridinWarpReduc = true;
+    }
+  }
+
+  if (staticNonUnitParallelDimCount <= 1 && !mayOverflowGridinWarpReduc)
     return failure();
 
   // Don't consider operations that don't have a broadcast, those should go
@@ -311,7 +344,7 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint,
 
   // TODO: Properly rematerialize leading elementwise with shared memory
   // promotion.
-  if (hasFusedLeadingOp(op)) {
+  if (hasFusedLeadingOp(op) && !mayOverflowGridinWarpReduc) {
     return failure();
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 8271580b58a1..45385e21faa7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -163,7 +163,8 @@ void addGPUVectorizationPassPipeline(OpPassManager &pm) {
 void addGPUMatmulSimtPassPipeline(OpPassManager &pm) {
   tileAndDistributeToWorkgroup(pm);
   auto &nestedModulePM = pm.nest<ModuleOp>();
-
+  nestedModulePM.addNestedPass<func::FuncOp>(
+      createRematerializeParallelOpsPass());
  nestedModulePM.addPass(createCanonicalizerPass());
   nestedModulePM.addNestedPass<func::FuncOp>(
       createWorkgroupSpecializationPass());
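
For reference, below is a minimal standalone sketch of the overflow heuristic added in the first hunk, written outside of IREE. It assumes dynamic parallel dimensions are estimated by the expected context length (the new `--iree-codegen-llvmgpu-context-length` flag, default 512) and reuses the patch's `kMaxGridDim` threshold; `kDynamicDim`, `mayOverflowGrid`, and the example bounds are illustrative names and values, not IREE APIs.

```cpp
// Standalone sketch of the grid-overflow estimate (not IREE code).
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamicDim = -1;      // stand-in for ShapedType::kDynamic
constexpr int64_t kMaxGridDim = 4190000; // threshold used by the patch

// Returns true when the product of the parallel dimension sizes may exceed
// kMaxGridDim. As in the patch, the static product is scaled once by
// `contextLength` only when at least one dynamic dimension is present.
bool mayOverflowGrid(const std::vector<int64_t> &parallelBounds,
                     int64_t contextLength) {
  int64_t staticProduct = 1;
  bool hasDynamicDim = false;
  for (int64_t bound : parallelBounds) {
    if (bound == kDynamicDim) {
      hasDynamicDim = true;
      continue;
    }
    staticProduct *= bound;
  }
  // Fully static shapes are left to the existing static checks.
  return hasDynamicDim && staticProduct * contextLength >= kMaxGridDim;
}

int main() {
  // e.g. [batch=32, heads=32, seq=?] with different assumed context lengths.
  std::vector<int64_t> bounds = {32, 32, kDynamicDim};
  std::cout << std::boolalpha
            << mayOverflowGrid(bounds, 512) << "\n"    // false: 524288 < threshold
            << mayOverflowGrid(bounds, 8192) << "\n";  // true: 8388608 >= threshold
  return 0;
}
```

With bounds like [32, 32, dynamic], a 512-token context stays under the threshold and keeps the fast path, while an 8192-token context trips the heuristic and falls back to the slow path, mirroring the intent of `mayOverflowGridinWarpReduc` above.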