From 212c7ee21efdb7d21379017141b562d173073350 Mon Sep 17 00:00:00 2001 From: Frank Sun Date: Mon, 20 Oct 2025 16:23:23 -0400 Subject: [PATCH] Add configurable BHT history parameters via CSR --- src/main/scala/rocket/BTB.scala | 26 +++++++++++++++++++------- src/main/scala/rocket/CSR.scala | 9 +++++++++ src/main/scala/rocket/Frontend.scala | 2 ++ src/main/scala/rocket/RocketCore.scala | 2 +- src/main/scala/tile/CustomCSRs.scala | 2 ++ 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/main/scala/rocket/BTB.scala b/src/main/scala/rocket/BTB.scala index cb8392a7c29..e8529c4d918 100644 --- a/src/main/scala/rocket/BTB.scala +++ b/src/main/scala/rocket/BTB.scala @@ -72,18 +72,26 @@ class BHTResp(implicit p: Parameters) extends BtbBundle()(p) { // - each counter corresponds with the address of the fetch packet ("fetch pc"). // - updated when a branch resolves (and BTB was a hit for that branch). // The updating branch must provide its "fetch pc". -class BHT(params: BHTParams)(implicit val p: Parameters) extends HasCoreParameters { +class BHT(params: BHTParams, historyLengthConfig: UInt, historyBitsConfig: UInt)(implicit val p: Parameters) extends HasCoreParameters { def index(addr: UInt, history: UInt) = { - def hashHistory(hist: UInt) = if (params.historyLength == params.historyBits) hist else { - val k = math.sqrt(3)/2 - val i = BigDecimal(k * math.pow(2, params.historyLength)).toBigInt - (i.U * hist)(params.historyLength-1, params.historyLength-params.historyBits) + def hashHistory(hist: UInt) = { + Mux(historyBitsConfig >= historyLengthConfig, + hist, + { + val k = math.sqrt(3)/2 + val i = (BigDecimal(k * math.pow(2, params.historyLength)).toBigInt.U) >> (params.historyLength.U - historyLengthConfig) + val product = i * hist + (product >> (historyLengthConfig - historyBitsConfig)) & ((1.U << historyLengthConfig) - 1.U) + } + ) } def hashAddr(addr: UInt) = { val hi = addr >> log2Ceil(fetchBytes) hi(log2Ceil(params.nEntries)-1, 0) ^ (hi >> log2Ceil(params.nEntries))(1, 0) } - hashAddr(addr) ^ (hashHistory(history) << (log2Up(params.nEntries) - params.historyBits)) + val slicedInputHistory = history >> (params.historyLength.U - historyLengthConfig) + val hashValue = hashHistory(slicedInputHistory) + hashAddr(addr) ^ (hashValue << (log2Up(params.nEntries).U - historyBitsConfig)) } def get(addr: UInt): BHTResp = { val res = Wire(new BHTResp) @@ -114,6 +122,8 @@ class BHT(params: BHTParams)(implicit val p: Parameters) extends HasCoreParamete private val table = Mem(params.nEntries, UInt(params.counterLength.W)) val history = RegInit(0.U(params.historyLength.W)) + val slicedHistory = history >> (params.historyLength.U - historyLengthConfig) + private val reset_waddr = RegInit(0.U((params.nEntries.log2+1).W)) private val resetting = !reset_waddr(params.nEntries.log2) private val wen = WireInit(resetting) @@ -192,6 +202,8 @@ class BTB(implicit p: Parameters) extends BtbModule { val ras_update = Flipped(Valid(new RASUpdate)) val ras_head = Valid(UInt(vaddrBits.W)) val flush = Input(Bool()) + val historyLengthConfig = Input(UInt(4.W)) + val historyBitsConfig = Input(UInt(4.W)) }) val idxs = Reg(Vec(entries, UInt((matchBits - log2Up(coreInstBytes)).W))) @@ -299,7 +311,7 @@ class BTB(implicit p: Parameters) extends BtbModule { } if (btbParams.bhtParams.nonEmpty) { - val bht = new BHT(btbParams.bhtParams.get) + val bht = new BHT(btbParams.bhtParams.get, io.historyLengthConfig, io.historyBitsConfig) val isBranch = (idxHit & cfiType.map(_ === CFIType.branch).asUInt).orR val res = bht.get(io.req.bits.addr) when (io.bht_advance.valid) { diff --git a/src/main/scala/rocket/CSR.scala b/src/main/scala/rocket/CSR.scala index 0e931388065..6177054300b 100644 --- a/src/main/scala/rocket/CSR.scala +++ b/src/main/scala/rocket/CSR.scala @@ -797,8 +797,17 @@ class CSRFile( require(!read_mapping.contains(csr.id)) val reg = csr.init.map(init => RegInit(init.U(xLen.W))).getOrElse(Reg(UInt(xLen.W))) val read = io.rw.cmd =/= CSR.N && io.rw.addr === csr.id.U + val write = io.rw.cmd.isOneOf(CSR.W, CSR.S, CSR.C) && io.rw.addr === csr.id.U csr_io.ren := read + csr_io.wen := write when (read && csr_io.stall) { io.rw_stall := true.B } + // Handle writes for writable CSRs (mask != 0) + if (csr.mask != 0) { + when (write) { + val wdata = readModifyWriteCSR(io.rw.cmd, reg, io.rw.wdata) + reg := wdata & csr.mask.U + } + } read_mapping += csr.id -> reg reg } diff --git a/src/main/scala/rocket/Frontend.scala b/src/main/scala/rocket/Frontend.scala index 8156e51c475..1dea99c2329 100644 --- a/src/main/scala/rocket/Frontend.scala +++ b/src/main/scala/rocket/Frontend.scala @@ -218,6 +218,8 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) val force_taken = io.ptw.customCSRs.bpmStatic when (io.ptw.customCSRs.flushBTB) { btb.io.flush := true.B } when (force_taken) { btb.io.bht_update.valid := false.B } + btb.io.historyLengthConfig := io.ptw.customCSRs.historyLengthConfig + btb.io.historyBitsConfig := io.ptw.customCSRs.historyBitsConfig val s2_base_pc = ~(~s2_pc | (fetchBytes-1).U) val taken_idx = Wire(UInt()) diff --git a/src/main/scala/rocket/RocketCore.scala b/src/main/scala/rocket/RocketCore.scala index ec6d82cdf29..ccc906226bd 100644 --- a/src/main/scala/rocket/RocketCore.scala +++ b/src/main/scala/rocket/RocketCore.scala @@ -97,7 +97,7 @@ trait HasRocketCoreParameters extends HasCoreParameters { class RocketCustomCSRs(implicit p: Parameters) extends CustomCSRs with HasRocketCoreParameters { override def bpmCSR = { - rocketParams.branchPredictionModeCSR.option(CustomCSR(bpmCSRId, BigInt(1), Some(BigInt(0)))) + rocketParams.branchPredictionModeCSR.option(CustomCSR(bpmCSRId, BigInt(0x1FF), Some(BigInt(0)))) } private def haveDCache = tileParams.dcache.get.scratch.isEmpty diff --git a/src/main/scala/tile/CustomCSRs.scala b/src/main/scala/tile/CustomCSRs.scala index 5d9be4eade4..4636707773f 100644 --- a/src/main/scala/tile/CustomCSRs.scala +++ b/src/main/scala/tile/CustomCSRs.scala @@ -40,6 +40,8 @@ class CustomCSRs(implicit p: Parameters) extends CoreBundle { def flushBTB = getOrElse(bpmCSR, _.wen, false.B) def bpmStatic = getOrElse(bpmCSR, _.value(0), false.B) + def historyLengthConfig = getOrElse(bpmCSR, _.value(4,1), 0.U) + def historyBitsConfig = getOrElse(bpmCSR, _.value(8,5), 0.U) def disableDCacheClockGate = getOrElse(chickenCSR, _.value(0), false.B) def disableICacheClockGate = getOrElse(chickenCSR, _.value(1), false.B) def disableCoreClockGate = getOrElse(chickenCSR, _.value(2), false.B)