diff --git a/Bender.lock b/Bender.lock index 731bfca..f737675 100644 --- a/Bender.lock +++ b/Bender.lock @@ -30,7 +30,7 @@ packages: - hwpe-stream - l2_tcdm_hybrid_interco hwpe-ctrl: - revision: 2926867cafb3fb518a1ae849675f281b79ecab8a + revision: 7ba707d837697c2c7c6ea1396ec4e4ab094054a2 version: null source: Git: https://github.com/pulp-platform/hwpe-ctrl diff --git a/Bender.yml b/Bender.yml index f0497c7..6d2c5a6 100644 --- a/Bender.yml +++ b/Bender.yml @@ -18,7 +18,7 @@ dependencies: common_cells: { git: https://github.com/pulp-platform/common_cells, version: 1.23.0 } hwpe-stream: { git: https://github.com/pulp-platform/hwpe-stream, rev: a20f35e62fe2842904797079dc7881e490ff7117 } hci: { git: https://github.com/pulp-platform/hci, rev: 066c7ce7d24b61587e245decb592054669d7a2d1 } - hwpe-ctrl: { git: https://github.com/pulp-platform/hwpe-ctrl, rev: 2926867cafb3fb518a1ae849675f281b79ecab8a } + hwpe-ctrl: { git: https://github.com/pulp-platform/hwpe-ctrl, rev: 7ba707d837697c2c7c6ea1396ec4e4ab094054a2 } scm: { git: https://github.com/pulp-platform/scm, rev: 998466d2a3c2d7d572e43d2666d93c4f767d8d60 } tech_cells_generic: { git: https://github.com/pulp-platform/tech_cells_generic, version: 0.2.11 } diff --git a/src/hwpe/ita_hwpe_ctrl.sv b/src/hwpe/ita_hwpe_ctrl.sv index 3b371bf..1edd454 100644 --- a/src/hwpe/ita_hwpe_ctrl.sv +++ b/src/hwpe/ita_hwpe_ctrl.sv @@ -64,10 +64,10 @@ module ita_hwpe_ctrl always_comb begin ctrl_engine_o = '0; ctrl_engine_o.start = slave_flags.start; - ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0]; - ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4]; - ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8]; - ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12]; + ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][3:0]; + ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][7:4]; + ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][11:8]; + ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][15:12]; ctrl_engine_o.eps_mult[0] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][7:0]; ctrl_engine_o.eps_mult[1] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][15:8]; ctrl_engine_o.eps_mult[2] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][23:16]; diff --git a/src/hwpe/ita_hwpe_package.sv b/src/hwpe/ita_hwpe_package.sv index 984685f..47e54b9 100644 --- a/src/hwpe/ita_hwpe_package.sv +++ b/src/hwpe/ita_hwpe_package.sv @@ -10,7 +10,7 @@ package ita_hwpe_package; // HWPE Configuration parameter int unsigned N_CORES = 9; - parameter int unsigned N_CONTEXT = 2; + parameter int unsigned N_CONTEXT = 4; parameter int unsigned ID_WIDTH = 2; parameter int unsigned ITA_IO_REGS = 17; // 5 address + 11 parameters + 1 sync diff --git a/src/hwpe/tb/ita_hwpe_tb.sv b/src/hwpe/tb/ita_hwpe_tb.sv index e24e70d..7f8e30c 100644 --- a/src/hwpe/tb/ita_hwpe_tb.sv +++ b/src/hwpe/tb/ita_hwpe_tb.sv @@ -329,6 +329,7 @@ endfunction // Signals logic [31:0] status; string STIM_DATA; + int ita_reg_cnt; logic [31:0] ita_reg_tiles_val; logic [5:0][31:0] ita_reg_rqs_val; logic [31:0] ita_reg_gelu_b_c_val; @@ -338,6 +339,7 @@ endfunction // Wait for reset to be released wait (rst_n); + ita_reg_cnt = 0; // Load memory STIM_DATA = {simdir,"/hwpe/mem.txt"}; @@ -356,7 +358,7 @@ endfunction PERIPH_READ( 32'h04, 32'h0, status, clk); // 1: Step Q - ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(Q, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 2: Step K if (SINGLE_ATTENTION == 1) begin @@ -365,7 +367,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8; ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8; end - ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(K, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 3: Step V if (SINGLE_ATTENTION == 1) begin @@ -374,7 +376,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8; ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8; end - ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(V, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); if (SINGLE_ATTENTION == 1) begin // Reset the RQS values @@ -389,7 +391,7 @@ endfunction BASE_PTR_OUTPUT[AV] = BASE_PTR[19] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE; // 4: Step QK - ita_compute_step(QK, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(QK, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // WIESEP: Hack to ensure that during the last tile of AV, the weight pointer is set correctly if (group == N_TILES_SEQUENCE_DIM-1) begin @@ -397,7 +399,7 @@ endfunction end // 5: Step AV - ita_compute_step(AV, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(AV, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); end // 6: Step OW @@ -409,7 +411,9 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8; ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8; end - ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(OW, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + + ita_reg_cnt = 0; // 7: Step FF1 if (SINGLE_ATTENTION == 1) begin @@ -420,7 +424,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16; ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16; end - ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(F1, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // 8: Step FF2 if (SINGLE_ATTENTION == 1) begin @@ -431,7 +435,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24; ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24; end - ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(F2, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); // Wait for the last step to finish wait(evt); @@ -456,6 +460,7 @@ endfunction task automatic ita_compute_step( input step_e step, + inout integer ita_reg_cnt, input logic [31:0] ita_reg_tiles_val, input logic [5:0][31:0] ita_reg_rqs_val, input logic [31:0] ita_reg_gelu_b_c_val, @@ -500,7 +505,12 @@ endfunction ita_reg_en = 1'b1; end else begin // Calculate ita_reg_en - ita_reg_en_compute(step, tile, ita_reg_en); + if (ita_reg_cnt < N_CONTEXT) begin + ita_reg_en = 1'b1; + ita_reg_cnt++; + end else begin + ita_reg_en = 1'b0; + end end // Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en @@ -575,29 +585,6 @@ endfunction $display(" - output_ptr 0x%08h (output_base_ptr 0x%08h)", output_ptr, output_base_ptr); endtask - - task automatic ita_reg_en_compute( - input step_e step, - input integer tile, - output logic enable - ); - enable = 1'b0; - // Write requantization parameters only in first two programming phases - if (step == Q) begin - if (tile == 0 || tile == 1) - enable = 1'b1; - end else if (step == K && N_TILES_OUTER_X[Q]*N_TILES_OUTER_Y[Q]*N_TILES_INNER_DIM[Q] == 1) begin - if (tile == 0) - enable = 1'b1; - end else if (step == F1) begin - if (tile == 0 || tile == 1) - enable = 1'b1; - end else if (step == F2 && N_TILES_OUTER_X[F1]*N_TILES_OUTER_Y[F1]*N_TILES_INNER_DIM[F1] == 1) begin - if (tile == 0) - enable = 1'b1; - end - endtask - task automatic ctrl_val_compute( input step_e step, input integer tile,