Skip to content

Commit

Permalink
Batch: rename BatchInterval to BatchStep, move to tail of StepInfo
Browse files Browse the repository at this point in the history
Previously, we appended BatchInterval at the head of StepInfo and used it
to update the index of the software buffer.
This change renames BatchInterval to BatchStep and appends it to
StepInfo in BatchCollector rather than in BatchAssembler, which simplifies
the logic. Also, putting BatchStep at the tail of StepInfo allows some
operations to be performed after Batch parsing, which is needed by
Incremental transmit.
  • Loading branch information
klin02 committed Feb 9, 2025
1 parent bbe919b commit 8e1c380
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 36 deletions.
73 changes: 40 additions & 33 deletions src/main/scala/Batch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ case class BatchParam(config: GatewayConfig, bundles: Seq[DifftestBundle]) {
val StepGroupSize = bundles.distinctBy(_.desiredCppName).length
val StepDataByteLen = bundles.map(_.getByteAlignWidth).map { w => w / 8 }.sum
val StepDataBitLen = StepDataByteLen * 8
val StepInfoByteLen = StepGroupSize * (infoWidth / 8)
val StepInfoByteLen = (StepGroupSize + 1) * (infoWidth / 8) // Include BatchStep to update buffer index
val StepInfoBitLen = StepInfoByteLen * 8

// Width of statistic for data/info byte length
val StatsDataWidth = log2Ceil(math.max(MaxDataByteLen, StepDataByteLen))
val StatsInfoWidth = log2Ceil(math.max(MaxInfoSize, StepGroupSize))
val StatsInfoWidth = log2Ceil(math.max(MaxInfoSize, StepGroupSize + 1))

// Truncate width when shifting to reduce useless gates
val TruncDataBitLen = math.min(MaxDataBitLen, StepDataBitLen)
Expand Down Expand Up @@ -123,7 +123,7 @@ class BatchCollector(bundles: Seq[Valid[DifftestBundle]], param: BatchParam) ext
val step_enable = IO(Output(Bool()))

val sorted =
in.groupBy(_.bits.desiredCppName).values.toSeq.sortBy(gens => gens.length * gens.head.bits.getByteAlignWidth)
in.groupBy(_.bits.desiredCppName).values.toSeq.sortBy(gen => gen.length * gen.head.bits.getByteAlignWidth).reverse
// Stage 1: concat bundles with same desiredCppName
val group_bitlen = sorted.map(_.head.bits.getByteAlignWidth)
val group_length = sorted.map(_.length)
Expand Down Expand Up @@ -175,23 +175,33 @@ class BatchCollector(bundles: Seq[Valid[DifftestBundle]], param: BatchParam) ext
val info_num = delay_group_status.last.info_size
step_enable := info_num =/= 0.U
step_status := delay_group_status
// append BatchStep to last step_status
step_status.last.info_size := delay_group_status.last.info_size + 1.U
// Use BatchStep to update index of software buffer
val BatchStep = Wire(new BatchInfo)
BatchStep.id := Batch.getTemplate.length.U
BatchStep.num := info_num // unused, only for debugging
// Collect from tail: collect(i) accumulates the last (i + 1) groups, i.e. indices 0~i of the reversed sequences
val toCollect_data = delay_group_data.reverse
val toCollect_info = delay_group_info.reverse
val toCollect_vsize = delay_group_vsize.reverse
val collect_data = Wire(MixedVec(Seq.tabulate(param.StepGroupSize) { idx =>
UInt(delay_group_data.take(idx + 1).map(_.getWidth).sum.W)
UInt(toCollect_data.take(idx + 1).map(_.getWidth).sum.W)
}))
val collect_info = Wire(MixedVec(Seq.tabulate(param.StepGroupSize) { idx =>
UInt(((idx + 1) * param.infoWidth).W)
UInt(((idx + 2) * param.infoWidth).W)
}))
// Collect from head, collect(i) include 0~i
collect_data(0) := delay_group_data(0)
collect_info(0) := delay_group_info(0)

collect_data(0) := toCollect_data(0)
collect_info(0) := Mux(toCollect_vsize(0) =/= 0.U, Cat(BatchStep.asUInt, toCollect_info(0)), BatchStep.asUInt)
(1 until param.StepGroupSize).foreach { idx =>
val cat_map = Seq.tabulate(group_length(idx) + 1) { len =>
(len.U, Cat(collect_data(idx - 1), delay_group_data(idx)(len * group_bitlen(idx) - 1, 0)))
val cat_map = Seq.tabulate(group_length.reverse(idx) + 1) { len =>
(len.U, Cat(collect_data(idx - 1), toCollect_data(idx)(len * group_bitlen.reverse(idx) - 1, 0)))
}
collect_data(idx) := LookupTree(delay_group_vsize(idx), cat_map)
collect_data(idx) := LookupTree(toCollect_vsize(idx), cat_map)
collect_info(idx) := Mux(
delay_group_vsize(idx) =/= 0.U,
Cat(collect_info(idx - 1), delay_group_info(idx)),
toCollect_vsize(idx) =/= 0.U,
Cat(collect_info(idx - 1), toCollect_info(idx)),
collect_info(idx - 1),
)
}
Expand Down Expand Up @@ -227,8 +237,8 @@ class BatchAssembler(
val delay_step_enable = RegNext(step_enable)
val delay_step_trace_info = Option.when(config.hasReplay)(RegNext(step_trace_info.get))
val data_bytes_avail = param.MaxDataByteLen.U -& state_status.data_bytes
// Always leave space for BatchFinish and BatchInterval, use MaxInfoSize - 2
val info_size_avail = (param.MaxInfoSize - 2).U -& state_status.info_size
// Always leave space for BatchFinish, use MaxInfoSize - 1
val info_size_avail = (param.MaxInfoSize - 1).U -& state_status.info_size
val data_exceed = Wire(Bool())
val info_exceed = Wire(Bool())
val append_data = Wire(UInt(param.TruncDataBitLen.W))
Expand All @@ -239,10 +249,6 @@ class BatchAssembler(
val next_state_info = Wire(UInt(param.MaxInfoBitLen.W))
val next_state_stats = Wire(new BatchStats(param))

// Use BatchInterval to update index of software buffer
val BatchInterval = Wire(new BatchInfo)
BatchInterval.id := Batch.getTemplate.length.U
BatchInterval.num := delay_step_status.last.info_size // unused, only for debugging
val BatchFinish = Wire(new BatchInfo)
BatchFinish.id := (Batch.getTemplate.length + 1).U
BatchFinish.num := finish_step
Expand Down Expand Up @@ -274,13 +280,13 @@ class BatchAssembler(
assert(remain_stats.data_bytes <= param.MaxDataByteLen.U)
assert(remain_stats.info_size + 1.U <= param.MaxInfoSize.U)

val concat_data = (delay_step_data >> (remain_stats.data_bytes << 3).asUInt).asUInt
val concat_info = (delay_step_info >> (remain_stats.info_size * param.infoWidth.U)).asUInt
// Note we need only lowest bits to update state, truncate high bits to reduce gates
val remain_data = (~(~0.U(param.TruncDataBitLen.W) <<
(remain_stats.data_bytes << 3).asUInt)).asUInt & delay_step_data
val remain_info = (~(~0.U(param.StepInfoBitLen.W) <<
(remain_stats.info_size * param.infoWidth.U))).asUInt & delay_step_info
val concat_data = (~(~0.U(param.TruncDataBitLen.W) <<
(concat_stats.data_bytes << 3).asUInt)).asUInt & delay_step_data
val concat_info = (~(~0.U(param.StepInfoBitLen.W) <<
(concat_stats.info_size * param.infoWidth.U))).asUInt & delay_step_info
val remain_data = (delay_step_data >> (concat_stats.data_bytes << 3).asUInt).asUInt
val remain_info = (delay_step_info >> (concat_stats.info_size * param.infoWidth.U)).asUInt

// Delay step can be partly appended to output for making full use of transmission param
// Avoid appending when step equals batchSize(delay_step_exceed), last appended data will overwrite first step data
Expand All @@ -290,20 +296,20 @@ class BatchAssembler(
finish_step := state_step_cnt + Mux(append_whole, 1.U, 0.U)

append_data := Mux(has_append, concat_data(param.TruncDataBitLen - 1, 0), 0.U)
val append_finish_map = Seq.tabulate(param.StepGroupSize) { g =>
val append_finish_map = Seq.tabulate(param.StepGroupSize + 2) { g =>
(g.U, (BatchFinish.asUInt << (g * param.infoWidth)).asUInt)
}
append_info := Mux(
has_append,
Cat(concat_info | LookupTree(concat_stats.info_size, append_finish_map), BatchInterval.asUInt),
concat_info | LookupTree(concat_stats.info_size, append_finish_map),
BatchFinish.asUInt,
)

next_state_step_cnt := Mux(has_append && append_whole, 0.U, 1.U)
next_state_data := Mux(has_append, remain_data, delay_step_data)
next_state_info := Mux(has_append, remain_info, Cat(delay_step_info, BatchInterval.asUInt))
next_state_info := Mux(has_append, remain_info, delay_step_info)
next_state_stats.data_bytes := Mux(has_append, remain_stats.data_bytes, delay_step_status.last.data_bytes)
next_state_stats.info_size := Mux(has_append, remain_stats.info_size, delay_step_status.last.info_size + 1.U)
next_state_stats.info_size := Mux(has_append, remain_stats.info_size, delay_step_status.last.info_size)
} else {
data_exceed := delay_step_enable && delay_step_status.last.data_bytes > data_bytes_avail
info_exceed := delay_step_enable && delay_step_status.last.info_size > info_size_avail
Expand All @@ -316,9 +322,9 @@ class BatchAssembler(

next_state_step_cnt := 1.U
next_state_data := delay_step_data
next_state_info := Cat(delay_step_info, BatchInterval.asUInt)
next_state_info := delay_step_info
next_state_stats.data_bytes := delay_step_status.last.data_bytes
next_state_stats.info_size := delay_step_status.last.info_size + 1.U
next_state_stats.info_size := delay_step_status.last.info_size
}

// Stage 2:
Expand Down Expand Up @@ -352,6 +358,7 @@ class BatchAssembler(
out.step := Mux(out.enable, finish_step, 0.U)

val state_update = delay_step_enable || state_flush || timeout

when(state_update) {
when(delay_step_enable) {
when(should_tick) {
Expand All @@ -365,9 +372,9 @@ class BatchAssembler(
state_data := state_data |
(delay_step_data(param.TruncDataBitLen - 1, 0) << (state_status.data_bytes << 3).asUInt).asUInt
state_info := state_info |
(Cat(delay_step_info, BatchInterval.asUInt) << (state_status.info_size * param.infoWidth.U)).asUInt
(delay_step_info << (state_status.info_size * param.infoWidth.U)).asUInt
state_status.data_bytes := state_status.data_bytes + delay_step_status.last.data_bytes
state_status.info_size := state_status.info_size + delay_step_status.last.info_size + 1.U
state_status.info_size := state_status.info_size + delay_step_status.last.info_size
if (config.hasReplay) state_trace_size.get := state_trace_size.get + delay_step_trace_info.get.trace_size
}
}.otherwise { // state_flush without new-coming step
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/DPIC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class DPICBatch(template: Seq[DifftestBundle], batchIO: BatchIO, config: Gateway

override def desiredName: String = "DifftestBatch"
override def dpicFuncAssigns: Seq[String] = {
val bundleEnum = template.map(_.desiredModuleName.replace("Difftest", "")) ++ Seq("BatchInterval", "BatchFinish")
val bundleEnum = template.map(_.desiredModuleName.replace("Difftest", "")) ++ Seq("BatchStep", "BatchFinish")
val bundleAssign = template.zipWithIndex.map { case (t, idx) =>
val bundleName = bundleEnum(idx)
val perfName = "perf_Batch_" + bundleName
Expand Down Expand Up @@ -265,7 +265,7 @@ class DPICBatch(template: Seq[DifftestBundle], batchIO: BatchIO, config: Gateway
| ${bundleEnum.mkString(",\n ")}
| };
| extern void simv_nstep(uint8_t step);
| static int dut_index = -1;
| static int dut_index = 0;
| $batchDecl
| for (int i = 0; i < $infoLen; i++) {
| uint8_t id = info[i].id;
Expand All @@ -277,7 +277,7 @@ class DPICBatch(template: Seq[DifftestBundle], batchIO: BatchIO, config: Gateway
|#endif // CONFIG_DIFFTEST_INTERNAL_STEP
| break;
| }
| else if (id == BatchInterval) {
| else if (id == BatchStep) {
| dut_index = (dut_index + 1) % CONFIG_DIFFTEST_BATCH_SIZE;
|#ifdef CONFIG_DIFFTEST_QUERY
| difftest_query_step();
Expand Down

0 comments on commit 8e1c380

Please sign in to comment.