diff --git a/PyITA/ITA.py b/PyITA/ITA.py index 24f7b0b..0068723 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -1200,7 +1200,8 @@ def plot_heatmap(tensor, title, ax): def util_main(**kwargs): B = 8 log2e = np.log2(np.exp(1)) - eps_max = B / (2**B) + range_scale = 32 + eps_max = range_scale * B / (2**B) N = 1024 A = np.random.randint(-128, 127, size = (1, N, N), dtype = np.int8) diff --git a/PyITA/softmax.py b/PyITA/softmax.py index 8cbc5cf..7545086 100644 --- a/PyITA/softmax.py +++ b/PyITA/softmax.py @@ -30,7 +30,8 @@ def fastSoftmax(x, integerize = True): B = 8 # Scaling factor - eps_max = B / (2**B) + range_scale = 32 + eps_max = range_scale * B / (2**B) # Find the maximum for each row in the current column block (consisting of 16 columns) max = np.repeat(np.max(x, axis = -1), seq_length).reshape(n_heads, seq_length, seq_length) @@ -80,7 +81,8 @@ def streamingPartialSoftmax(x, integerize = True): B = 8 # Scaling factor - eps_max = B / (2**B) + range_scale = 32 + eps_max = range_scale * B / (2**B) if integerize: x = x @@ -145,20 +147,20 @@ def streamingPartialSoftmax(x, integerize = True): # Calculate exponential sum over the current part of the row and scale it by 2**10 to prevent underflow if integerize: # exp_sum = np.sum(2**8 >> shift, -1) # or - exp_sum = np.floor(np.sum(2**8 / 2**shift, axis = -1)) + exp_sum = np.floor(np.sum(2**8 >> shift, axis = -1)) else: exp_sum = np.sum(1 / 2**shift, axis = -1) # Update the accumulated sum and add the accumulation over the current part of the row if integerize: - exp_partial_sum = np.floor((exp_partial_sum / 2**shift_sum)) + exp_sum + exp_partial_sum = np.floor((exp_partial_sum.astype(np.int32) >> shift_sum)) + exp_sum else: exp_partial_sum = (exp_partial_sum / 2**(shift_sum.astype(np.float32))) + exp_sum ## STAGE 2: Calculate the softmax activation # Invert the partial sum if integerize: - exp_partial_sum_inverse = np.floor((2**8 - 1) * 2**8 / exp_partial_sum).astype(np.int32) + exp_partial_sum_inverse = np.floor((2**8 - 1) * 2**8 // exp_partial_sum).astype(np.int32) else: exp_partial_sum_inverse = 1 / exp_partial_sum @@ -176,8 +178,8 @@ def streamingPartialSoftmax(x, integerize = True): if integerize: # A_partial_softmax[0] = np.repeat(exp_partial_sum_inverse, seq_length).reshape(seq_length, seq_length) >> shift return np.floor( - np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift).astype( - np.int8) + np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) >> shift).astype( + np.uint8) else: return np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift @@ -187,7 +189,8 @@ def realSoftmax(A_requant, integerize = True): B = 8 log2e = np.log2(np.exp(1)) - eps_x = B / (2**B * log2e) + range_scale = 32 + eps_x = range_scale * B / (2**B * log2e) if integerize: x = A_requant * eps_x diff --git a/src/ita_package.sv b/src/ita_package.sv index 335e173..c20ef71 100644 --- a/src/ita_package.sv +++ b/src/ita_package.sv @@ -30,7 +30,7 @@ package ita_package; parameter int unsigned M3AddrWidth = idx_width(S) ; parameter int unsigned NumReadPorts = N ; parameter int unsigned MNumReadPorts = N ; - parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 14 `endif; + parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 12 `endif; localparam int unsigned SplitFactor = 4 ; parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif; @@ -96,12 +96,13 @@ package ita_package; // Softmax localparam int unsigned SoftmaxScalar = 65280; // (2**8-1) * 2**8 + localparam int unsigned SoftmaxShift = 0; localparam int unsigned SoftmaxAccDataWidth = 19; // Up to S = 2048 - localparam int unsigned SoftFifoDepth = 4; + localparam int unsigned SoftFifoDepth = 12; typedef logic [idx_width(SoftFifoDepth)-1:0] soft_fifo_usage_t; typedef logic [idx_width(SoftFifoDepth+1)-1:0] ongoing_soft_t; localparam int unsigned DividerWidth = SoftmaxAccDataWidth + 1; - localparam int unsigned NumDiv = 5; + localparam int unsigned NumDiv = 10; // Requantizer typedef enum {Signed=0, Unsigned=1} requant_mode_e; diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv index 675750c..2eb5255 100644 --- a/src/ita_softmax.sv +++ b/src/ita_softmax.sv @@ -54,11 +54,11 @@ module ita_softmax requant_oup_t requant_oup_q; requant_t max_d, max_q; - logic unsigned [N-1:0][3:0] shift_d, shift_q; + logic unsigned [N-1:0][WI-SoftmaxShift:0] shift_d, shift_q; logic [N-1:0][WI-1:0] shift_diff; - logic unsigned [3:0] shift_sum_d, shift_sum_q; + logic unsigned [WI-SoftmaxShift:0] shift_sum_d, shift_sum_q; logic [WI-1:0] max_diff; - logic unsigned [M-1:0][3:0] shift_inp; + logic unsigned [M-1:0][WI-SoftmaxShift:0] shift_inp; logic [M-1:0][WI-1:0] shift_inp_diff; logic calc_stream_soft_en_q; @@ -139,18 +139,18 @@ module ita_softmax max_d = max_i; for (int i = 0; i < N; i++) begin shift_diff[i] = max_i - requant_oup_q[i]; - shift_d[i] = unsigned'(shift_diff[i]) >> 5; - if (shift_diff[i][4]) - shift_d[i] = (unsigned'(shift_diff[i]) >> 5) + 1; + shift_d[i] = unsigned'(shift_diff[i]) >> SoftmaxShift; + if (SoftmaxShift != 0 && shift_diff[i][SoftmaxShift-1]) + shift_d[i] = (unsigned'(shift_diff[i]) >> SoftmaxShift) + 1; end if (tile_q2 != '0 || count_q2>=M) begin // If not first part of the first row, normalize previous sum read_acc_en_o[0] = 1; read_acc_addr_o[0] = count_q2; prev_max_o = read_max_data_i[0]; max_diff = max_i - prev_max_o; - shift_sum_d = max_diff >> 5; - if (max_diff[4]) - shift_sum_d = (max_diff >> 5) + 1; + shift_sum_d = max_diff >> SoftmaxShift; + if (SoftmaxShift != 0 && max_diff[SoftmaxShift-1]) + shift_sum_d = (max_diff >> SoftmaxShift) + 1; end else begin prev_max_o = 8'h80; end @@ -223,9 +223,9 @@ module ita_softmax if (calc_stream_soft_en_q) begin for (int i = 0; i < M; i++) begin shift_inp_diff[i] = read_max_data_i[1]-inp_i[i]; - shift_inp[i] = unsigned'(shift_inp_diff[i]) >> 5; - if (shift_inp_diff[i][4]) - shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> 5) + 1; + shift_inp[i] = unsigned'(shift_inp_diff[i]) >> SoftmaxShift; + if (SoftmaxShift != 0 && shift_inp_diff[i][SoftmaxShift-1]) + shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> SoftmaxShift) + 1; inp_stream_soft_o[i] = read_acc_data_i[1] >> shift_inp[i]; end end