Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Bender.lock
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ packages:
- hwpe-stream
- l2_tcdm_hybrid_interco
hwpe-ctrl:
revision: 2926867cafb3fb518a1ae849675f281b79ecab8a
revision: 7ba707d837697c2c7c6ea1396ec4e4ab094054a2
version: null
source:
Git: https://github.com/pulp-platform/hwpe-ctrl
Expand Down
2 changes: 1 addition & 1 deletion Bender.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
common_cells: { git: https://github.com/pulp-platform/common_cells, version: 1.23.0 }
hwpe-stream: { git: https://github.com/pulp-platform/hwpe-stream, rev: a20f35e62fe2842904797079dc7881e490ff7117 }
hci: { git: https://github.com/pulp-platform/hci, rev: 066c7ce7d24b61587e245decb592054669d7a2d1 }
hwpe-ctrl: { git: https://github.com/pulp-platform/hwpe-ctrl, rev: 2926867cafb3fb518a1ae849675f281b79ecab8a }
hwpe-ctrl: { git: https://github.com/pulp-platform/hwpe-ctrl, rev: 7ba707d837697c2c7c6ea1396ec4e4ab094054a2 }
scm: { git: https://github.com/pulp-platform/scm, rev: 998466d2a3c2d7d572e43d2666d93c4f767d8d60 }
tech_cells_generic: { git: https://github.com/pulp-platform/tech_cells_generic, version: 0.2.11 }

Expand Down
8 changes: 4 additions & 4 deletions src/hwpe/ita_hwpe_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ module ita_hwpe_ctrl
always_comb begin
ctrl_engine_o = '0;
ctrl_engine_o.start = slave_flags.start;
ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0];
ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4];
ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8];
ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12];
ctrl_engine_o.tile_s = reg_file.hwpe_params[ITA_REG_TILES][3:0] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][3:0];
ctrl_engine_o.tile_e = reg_file.hwpe_params[ITA_REG_TILES][7:4] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][7:4];
ctrl_engine_o.tile_p = reg_file.hwpe_params[ITA_REG_TILES][11:8] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][11:8];
ctrl_engine_o.tile_f = reg_file.hwpe_params[ITA_REG_TILES][15:12] == 0 ? 1 : reg_file.hwpe_params[ITA_REG_TILES][15:12];
ctrl_engine_o.eps_mult[0] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][7:0];
ctrl_engine_o.eps_mult[1] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][15:8];
ctrl_engine_o.eps_mult[2] = reg_file.hwpe_params[ITA_REG_EPS_MULT0][23:16];
Expand Down
2 changes: 1 addition & 1 deletion src/hwpe/ita_hwpe_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ package ita_hwpe_package;

// HWPE Configuration
parameter int unsigned N_CORES = 9;
parameter int unsigned N_CONTEXT = 2;
parameter int unsigned N_CONTEXT = 4;
parameter int unsigned ID_WIDTH = 2;
parameter int unsigned ITA_IO_REGS = 17; // 5 address + 11 parameters + 1 sync

Expand Down
51 changes: 19 additions & 32 deletions src/hwpe/tb/ita_hwpe_tb.sv
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ endfunction
// Signals
logic [31:0] status;
string STIM_DATA;
int ita_reg_cnt;
logic [31:0] ita_reg_tiles_val;
logic [5:0][31:0] ita_reg_rqs_val;
logic [31:0] ita_reg_gelu_b_c_val;
Expand All @@ -338,6 +339,7 @@ endfunction

// Wait for reset to be released
wait (rst_n);
ita_reg_cnt = 0;

// Load memory
STIM_DATA = {simdir,"/hwpe/mem.txt"};
Expand All @@ -356,7 +358,7 @@ endfunction
PERIPH_READ( 32'h04, 32'h0, status, clk);

// 1: Step Q
ita_compute_step(Q, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(Q, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 2: Step K
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -365,7 +367,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(K, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(K, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 3: Step V
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -374,7 +376,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8;
end
ita_compute_step(V, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(V, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

if (SINGLE_ATTENTION == 1) begin
// Reset the RQS values
Expand All @@ -389,15 +391,15 @@ endfunction
BASE_PTR_OUTPUT[AV] = BASE_PTR[19] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE;

// 4: Step QK
ita_compute_step(QK, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(QK, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// WIESEP: Hack to ensure that during the last tile of AV, the weight pointer is set correctly
if (group == N_TILES_SEQUENCE_DIM-1) begin
BASE_PTR_WEIGHT0[QK] = BASE_PTR_WEIGHT0[OW];
end

// 5: Step AV
ita_compute_step(AV, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(AV, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
end

// 6: Step OW
Expand All @@ -409,7 +411,9 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8;
end
ita_compute_step(OW, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(OW, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

ita_reg_cnt = 0;

// 7: Step FF1
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -420,7 +424,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16;
end
ita_compute_step(F1, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(F1, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// 8: Step FF2
if (SINGLE_ATTENTION == 1) begin
Expand All @@ -431,7 +435,7 @@ endfunction
ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24;
ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24;
end
ita_compute_step(F2, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);
ita_compute_step(F2, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk);

// Wait for the last step to finish
wait(evt);
Expand All @@ -456,6 +460,7 @@ endfunction

task automatic ita_compute_step(
input step_e step,
inout integer ita_reg_cnt,
input logic [31:0] ita_reg_tiles_val,
input logic [5:0][31:0] ita_reg_rqs_val,
input logic [31:0] ita_reg_gelu_b_c_val,
Expand Down Expand Up @@ -500,7 +505,12 @@ endfunction
ita_reg_en = 1'b1;
end else begin
// Calculate ita_reg_en
ita_reg_en_compute(step, tile, ita_reg_en);
if (ita_reg_cnt < N_CONTEXT) begin
ita_reg_en = 1'b1;
ita_reg_cnt++;
end else begin
ita_reg_en = 1'b0;
end
end

// Calculate ctrl_stream_val, weight_ptr_en, and bias_ptr_en
Expand Down Expand Up @@ -575,29 +585,6 @@ endfunction
$display(" - output_ptr 0x%08h (output_base_ptr 0x%08h)", output_ptr, output_base_ptr);
endtask


// Decide whether the requantization parameters are written for the given
// step/tile combination.
//
//   step   - step being programmed
//   tile   - tile index within that step
//   enable - set to 1'b1 when the parameters should be written, 1'b0 otherwise
//
// Requantization parameters are written only in the first two programming
// phases of the Q and F1 steps. When Q (respectively F1) consists of a single
// tile, the second phase is tile 0 of the following step K (respectively F2).
task automatic ita_reg_en_compute(
  input step_e step,
  input integer tile,
  output logic enable
);
  // Default: no write unless one of the cases below matches.
  enable = 1'b0;

  case (step)
    // First two tiles of Q and of F1.
    Q, F1: begin
      if (tile == 0 || tile == 1) begin
        enable = 1'b1;
      end
    end

    // Q had a single tile, so the second phase is the first tile of K.
    K: begin
      if (N_TILES_OUTER_X[Q]*N_TILES_OUTER_Y[Q]*N_TILES_INNER_DIM[Q] == 1 && tile == 0) begin
        enable = 1'b1;
      end
    end

    // F1 had a single tile, so the second phase is the first tile of F2.
    F2: begin
      if (N_TILES_OUTER_X[F1]*N_TILES_OUTER_Y[F1]*N_TILES_INNER_DIM[F1] == 1 && tile == 0) begin
        enable = 1'b1;
      end
    end

    // Every other step keeps the default.
    default: begin
    end
  endcase
endtask

task automatic ctrl_val_compute(
input step_e step,
input integer tile,
Expand Down
Loading