From c7023ae6f807485169e4dcedbdaef08699fed4ad Mon Sep 17 00:00:00 2001 From: Kris Dong Date: Fri, 19 Apr 2024 14:03:48 -0700 Subject: [PATCH] Add test cases for working mesh --- software/gemmini-rocc-tests | 2 +- src/main/scala/gemmini/Mesh.notscala | 129 -------------------------- src/main/scala/gemmini/Mesh.scala | 2 +- src/main/scala/gemmini/Tile.notscala | 132 --------------------------- 4 files changed, 2 insertions(+), 263 deletions(-) delete mode 100644 src/main/scala/gemmini/Mesh.notscala delete mode 100644 src/main/scala/gemmini/Tile.notscala diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests index 60523d54..d2415bbf 160000 --- a/software/gemmini-rocc-tests +++ b/software/gemmini-rocc-tests @@ -1 +1 @@ -Subproject commit 60523d54b71a835a7a0c54bf08d0a080d300cfb7 +Subproject commit d2415bbf4980938a829df8a4bed19c6c77852008 diff --git a/src/main/scala/gemmini/Mesh.notscala b/src/main/scala/gemmini/Mesh.notscala deleted file mode 100644 index cd056658..00000000 --- a/src/main/scala/gemmini/Mesh.notscala +++ /dev/null @@ -1,129 +0,0 @@ - -package gemmini - -import chisel3._ -import chisel3.util._ -import chisel3.experimental._ - -/** - * A Grid is a 2D array of Tile modules with registers in between each tile and - * registers from the bottom row and rightmost column of tiles to the Grid outputs. - * @param width - * @param tileRows - * @param tileColumns - * @param meshRows - * @param meshColumns - */ -class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, - df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int, - max_simultaneous_matmuls: Int, output_delay: Int, - val tileRows: Int, val tileColumns: Int, - val meshRows: Int, val meshColumns: Int) extends Module { - val io = IO(new Bundle { - val in_a = Input(Vec(meshRows, Vec(tileRows, inputType))) - val in_b = Input(Vec(meshColumns, Vec(tileColumns, inputType))) - val in_d = Input(Vec(meshColumns, Vec(tileColumns, inputType))) - val in_control = Input(Vec(meshColumns, Vec(tileColumns, new PEControl(accType)))) - val in_id = Input(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) // The unique id of this particular matmul - val in_last = Input(Vec(meshColumns, Vec(tileColumns, Bool()))) - val out_b = Output(Vec(meshColumns, Vec(tileColumns, outputType))) - val out_c = Output(Vec(meshColumns, Vec(tileColumns, outputType))) - val in_valid = Input(Vec(meshColumns, Vec(tileColumns, Bool()))) - val out_valid = Output(Vec(meshColumns, Vec(tileColumns, Bool()))) - val out_control = Output(Vec(meshColumns, Vec(tileColumns, new PEControl(accType)))) - val out_id = Output(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) - val out_last = Output(Vec(meshColumns, Vec(tileColumns, Bool()))) - }) - - // mesh(r)(c) => Tile at row r, column c - val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns))) - val meshT = mesh.transpose - - def pipe[T <: Data](valid: Bool, t: T, latency: Int): T = { - // The default "Pipe" function apparently resets the valid signals to false.B. We would like to avoid using global - // signals in the Mesh, so over here, we make it clear that the reset signal will never be asserted - chisel3.withReset(false.B) { Pipe(valid, t, latency).bits } - } - - // Chain tile_a_out -> tile_a_in (pipeline a across each row) - // TODO clock-gate A signals with in_garbage - for (r <- 0 until meshRows) { - mesh(r).foldLeft(io.in_a(r)) { - case (in_a, tile) => - tile.io.in_a := ShiftRegister(in_a, tile_latency+1) - tile.io.out_a - } - } - - // Chain tile_out_b -> tile_b_in (pipeline b across each column) - for (c <- 0 until meshColumns) { - meshT(c).foldLeft((io.in_b(c), io.in_valid(c))) { - case ((in_b, valid), tile) => - tile.io.in_b := pipe(valid.head, in_b, tile_latency+1) - (tile.io.out_b, tile.io.out_valid) - } - } - - // Chain tile_out -> tile_propag (pipeline output across each column) - for (c <- 0 until meshColumns) { - meshT(c).foldLeft((io.in_d(c), io.in_valid(c))) { - case ((in_propag, valid), tile) => - tile.io.in_d := pipe(valid.head, in_propag, tile_latency+1) - (tile.io.out_c, tile.io.out_valid) - } - } - - // Chain control signals (pipeline across each column) - assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_))) - for (c <- 0 until meshColumns) { - meshT(c).foldLeft((io.in_control(c), io.in_valid(c))) { - case ((in_ctrl, valid), tile) => - (tile.io.in_control, in_ctrl, valid).zipped.foreach { case (tile_ctrl, ctrl, v) => - tile_ctrl.shift := pipe(v, ctrl.shift, tile_latency+1) - tile_ctrl.dataflow := pipe(v, ctrl.dataflow, tile_latency+1) - tile_ctrl.propagate := pipe(v, ctrl.propagate, tile_latency+1) - } - (tile.io.out_control, tile.io.out_valid) - } - } - - // Chain in_valid (pipeline across each column) - for (c <- 0 until meshColumns) { - meshT(c).foldLeft(io.in_valid(c)) { - case (in_v, tile) => - tile.io.in_valid := ShiftRegister(in_v, tile_latency+1) - tile.io.out_valid - } - } - - // Chain in_id (pipeline across each column) - for (c <- 0 until meshColumns) { - meshT(c).foldLeft(io.in_id(c)) { - case (in_id, tile) => - tile.io.in_id := ShiftRegister(in_id, tile_latency+1) - tile.io.out_id - } - } - - // Chain in_last (pipeline across each column) - for (c <- 0 until meshColumns) { - meshT(c).foldLeft(io.in_last(c)) { - case (in_last, tile) => - tile.io.in_last := ShiftRegister(in_last, tile_latency+1) - tile.io.out_last - } - } - - // Capture out_vec and out_control_vec (connect IO to bottom row of mesh) - // (The only reason we have so many zips is because Scala doesn't provide a zipped function for Tuple4) - for (((((((b, c), v), ctrl), id), last), tile) <- io.out_b zip io.out_c zip io.out_valid zip io.out_control zip io.out_id zip io.out_last zip mesh.last) { - // TODO we pipelined this to make physical design easier. Consider removing these if possible - // TODO shouldn't we clock-gate these signals with "garbage" as well? - b := ShiftRegister(tile.io.out_b, output_delay) - c := ShiftRegister(tile.io.out_c, output_delay) - v := ShiftRegister(tile.io.out_valid, output_delay) - ctrl := ShiftRegister(tile.io.out_control, output_delay) - id := ShiftRegister(tile.io.out_id, output_delay) - last := ShiftRegister(tile.io.out_last, output_delay) - } -} diff --git a/src/main/scala/gemmini/Mesh.scala b/src/main/scala/gemmini/Mesh.scala index eeee17ea..7af51b19 100644 --- a/src/main/scala/gemmini/Mesh.scala +++ b/src/main/scala/gemmini/Mesh.scala @@ -61,7 +61,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, // TODO clock-gate A signals with in_garbage for (c <- 0 until meshColumns) { for (r <- 0 until meshRows) { - mesh(r)(c).io.in_a := io.in_a(r) + mesh(r)(c).io.in_a := ShiftRegister(io.in_a(r), tile_latency+1) } } diff --git a/src/main/scala/gemmini/Tile.notscala b/src/main/scala/gemmini/Tile.notscala deleted file mode 100644 index 9c2a418c..00000000 --- a/src/main/scala/gemmini/Tile.notscala +++ /dev/null @@ -1,132 +0,0 @@ -// See README.md for license details. - -package gemmini - -import chisel3._ -import chisel3.util._ -import Util._ - -/** - * A Tile is a purely combinational 2D array of passThrough PEs. - * a, b, s, and in_propag are broadcast across the entire array and are passed through to the Tile's outputs - * @param width The data width of each PE in bits - * @param rows Number of PEs on each row - * @param columns Number of PEs on each column - */ -class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, tree_reduction: Boolean, max_simultaneous_matmuls: Int, val rows: Int, val columns: Int)(implicit ev: Arithmetic[T]) extends Module { - val io = IO(new Bundle { - val in_a = Input(Vec(rows, inputType)) - val in_b = Input(Vec(columns, outputType)) // This is the output of the tile next to it - val in_d = Input(Vec(columns, outputType)) - - val in_control = Input(Vec(columns, new PEControl(accType))) - val in_id = Input(Vec(columns, UInt(log2Up(max_simultaneous_matmuls).W))) - val in_last = Input(Vec(columns, Bool())) - - val out_a = Output(Vec(rows, inputType)) - val out_c = Output(Vec(columns, outputType)) - val out_b = Output(Vec(columns, outputType)) - - val out_control = Output(Vec(columns, new PEControl(accType))) - val out_id = Output(Vec(columns, UInt(log2Up(max_simultaneous_matmuls).W))) - val out_last = Output(Vec(columns, Bool())) - - val in_valid = Input(Vec(columns, Bool())) - val out_valid = Output(Vec(columns, Bool())) - - val bad_dataflow = Output(Bool()) - }) - - import ev._ - - val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, max_simultaneous_matmuls))) - val tileT = tile.transpose - - // TODO: abstract hori/vert broadcast, all these connections look the same - // Broadcast 'a' horizontally across the Tile - for (r <- 0 until rows) { - tile(r).foldLeft(io.in_a(r)) { - case (in_a, pe) => - pe.io.in_a := in_a - pe.io.out_a - } - } - - // Broadcast 'b' vertically across the Tile - for (c <- 0 until columns) { - tileT(c).foldLeft(io.in_b(c)) { - case (in_b, pe) => - pe.io.in_b := (if (tree_reduction) in_b.zero else in_b) - pe.io.out_b - } - } - - // Broadcast 'd' vertically across the Tile - for (c <- 0 until columns) { - tileT(c).foldLeft(io.in_d(c)) { - case (in_d, pe) => - pe.io.in_d := in_d - pe.io.out_c - } - } - - // Broadcast 'control' vertically across the Tile - for (c <- 0 until columns) { - tileT(c).foldLeft(io.in_control(c)) { - case (in_ctrl, pe) => - pe.io.in_control := in_ctrl - pe.io.out_control - } - } - - // Broadcast 'garbage' vertically across the Tile - for (c <- 0 until columns) { - tileT(c).foldLeft(io.in_valid(c)) { - case (v, pe) => - pe.io.in_valid := v - pe.io.out_valid - } - } - - // Broadcast 'id' vertically across the Tile - for (c <- 0 until columns) { - tileT(c).foldLeft(io.in_id(c)) { - case (id, pe) => - pe.io.in_id := id - pe.io.out_id - } - } - - // Broadcast 'last' vertically across the Tile - for (c <- 0 until columns) { - tileT(c).foldLeft(io.in_last(c)) { - case (last, pe) => - pe.io.in_last := last - pe.io.out_last - } - } - - // Drive the Tile's bottom IO - for (c <- 0 until columns) { - io.out_c(c) := tile(rows-1)(c).io.out_c - io.out_control(c) := tile(rows-1)(c).io.out_control - io.out_id(c) := tile(rows-1)(c).io.out_id - io.out_last(c) := tile(rows-1)(c).io.out_last - io.out_valid(c) := tile(rows-1)(c).io.out_valid - - io.out_b(c) := { - if (tree_reduction) { - val prods = tileT(c).map(_.io.out_b) - accumulateTree(prods :+ io.in_b(c)) - } else { - tile(rows - 1)(c).io.out_b - } - } - } - io.bad_dataflow := tile.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_) - - // Drive the Tile's right IO - for (r <- 0 until rows) { - io.out_a(r) := tile(r)(columns-1).io.out_a - } -}