Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 45 additions & 43 deletions csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,14 +653,15 @@ class DispatchFFNCombineKernel {
m_prevSumBeforeRank = prevSumBeforeRank;
}
int prevSum = prevSumBeforeRank;
uint32_t prevGroupSum1 = 0;
uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0;
uint32_t dequantSum = 0;
int32_t syncLoopIdx = -1;
uint32_t n = params.problemShape.n();
BlockEpilogue1 blockEpilogue(resource, n);
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
// The ith core reads data from the ith rank's peermem
groupIdxDeq = groupIdx - 2;
uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
if (rowStart < params.maxOutputSize) {
Expand All @@ -687,57 +688,58 @@ class DispatchFFNCombineKernel {
}
}

if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) {
syncLoopIdx++;
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
}

AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete prevGroupSum1 += currentM;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This line introduces two issues:

  1. The synchronization variable syncgmm1Idx is no longer incremented inside the loop. This will break the synchronization protocol with the GMM1 function, which expects a unique, incrementing flag for each iteration, likely leading to a deadlock or incorrect execution.
  2. The statement prevGroupSum1 += currentM; is appended to the end of the CrossCoreSetFlag call. This is poor coding style and harms readability. It should be on its own line.
            AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);   // V notifies C that the current communication round is complete
            syncgmm1Idx++;
            prevGroupSum1 += currentM;

syncgmm1Idx++;

if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
uint32_t dequantLen = prevGroupSum1 - dequantSum;
if (dequantLen >= params.maxOutputSize) {
dequantLen = dequantLen - params.maxOutputSize;
prevGroupSum1 += currentM;
//第一次swglu的token数以及截断逻辑
if (groupIdx + 1 <= params.epilogueGranularity) {
if (dequantSum1 + currentM <= params.maxOutputSize) {
dequantSum1 += currentM;
} else if (dequantSum1 < params.maxOutputSize) {
dequantSum1 = params.maxOutputSize;
}

MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
}
prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx);
if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) {
dequantSum = 0;

//第二次swglu的token数以及截断逻辑
if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) {
if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) {
dequantSum2 += currentM;
} else if (dequantSum1 + dequantSum2 < params.maxOutputSize) {
dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2;
}
}
}
syncLoopIdx ++;
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V);
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); // Swiglu等GMM1【1】
AscendC::SyncAll<true>();

uint32_t lastDequantExpertNum = params.expertPerRank;
if (params.epilogueGranularity < params.expertPerRank) {
lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity;
}
if (lastDequantExpertNum < params.expertPerRank) {
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C);
}
if (prevGroupSum1 - dequantSum < params.maxOutputSize) {
uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;;
MatrixCoord offsetC{rowStartThisCore, 0};
uint32_t dequantLen = dequantSum;
if (prevGroupSum1 >= params.maxOutputSize) {
dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize);
}
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
//第一次swglu
if (dequantSum1 > 0) { //开启了swglu深融合
uint32_t rowStartThisCore = 0;
MatrixCoord offsetC{0U, 0};
MatrixCoord shapeC{dequantSum1, params.problemShape.n()};
LayoutC layoutC{dequantSum1, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); // swiglu通知GMM2【1】
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) {
AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); // Swiglu等GMM1【1】
AscendC::SyncAll<true>();
if (dequantSum2 > 0) {
uint32_t rowStartThisCore = dequantSum1;
MatrixCoord offsetC{rowStartThisCore, 0};
uint32_t dequantLen = dequantSum2;
MatrixCoord shapeC{dequantLen, params.problemShape.n()};
LayoutC layoutC{dequantLen, params.problemShape.n()};
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC);
blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum);
}
AscendC::SyncAll<true>();
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); // swiglu通知GMM2【2】
}
blockEpilogue.Finalize();
}
Expand Down