Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion PyITA/ITA.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,8 @@ def plot_heatmap(tensor, title, ax):
def util_main(**kwargs):
B = 8
log2e = np.log2(np.exp(1))
eps_max = B / (2**B)
range_scale = 32
eps_max = range_scale * B / (2**B)

N = 1024
A = np.random.randint(-128, 127, size = (1, N, N), dtype = np.int8)
Expand Down
19 changes: 11 additions & 8 deletions PyITA/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def fastSoftmax(x, integerize = True):
B = 8

# Scaling factor
eps_max = B / (2**B)
range_scale = 32
eps_max = range_scale * B / (2**B)

# Find the maximum for each row in the current column block (consisting of 16 columns)
max = np.repeat(np.max(x, axis = -1), seq_length).reshape(n_heads, seq_length, seq_length)
Expand Down Expand Up @@ -80,7 +81,8 @@ def streamingPartialSoftmax(x, integerize = True):
B = 8

# Scaling factor
eps_max = B / (2**B)
range_scale = 32
eps_max = range_scale * B / (2**B)

if integerize:
x = x
Expand Down Expand Up @@ -145,20 +147,20 @@ def streamingPartialSoftmax(x, integerize = True):
# Calculate exponential sum over the current part of the row and scale it by 2**10 to prevent underflow
if integerize:
# exp_sum = np.sum(2**8 >> shift, -1) # or
exp_sum = np.floor(np.sum(2**8 / 2**shift, axis = -1))
exp_sum = np.floor(np.sum(2**8 >> shift, axis = -1))
else:
exp_sum = np.sum(1 / 2**shift, axis = -1)

# Update the accumulated sum and add the accumulation over the current part of the row
if integerize:
exp_partial_sum = np.floor((exp_partial_sum / 2**shift_sum)) + exp_sum
exp_partial_sum = np.floor((exp_partial_sum.astype(np.int32) >> shift_sum)) + exp_sum
else:
exp_partial_sum = (exp_partial_sum / 2**(shift_sum.astype(np.float32))) + exp_sum

## STAGE 2: Calculate the softmax activation
# Invert the partial sum
if integerize:
exp_partial_sum_inverse = np.floor((2**8 - 1) * 2**8 / exp_partial_sum).astype(np.int32)
exp_partial_sum_inverse = np.floor((2**8 - 1) * 2**8 // exp_partial_sum).astype(np.int32)
else:
exp_partial_sum_inverse = 1 / exp_partial_sum

Expand All @@ -176,8 +178,8 @@ def streamingPartialSoftmax(x, integerize = True):
if integerize:
# A_partial_softmax[0] = np.repeat(exp_partial_sum_inverse, seq_length).reshape(seq_length, seq_length) >> shift
return np.floor(
np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift).astype(
np.int8)
np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) >> shift).astype(
np.uint8)
else:
return np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift

Expand All @@ -187,7 +189,8 @@ def realSoftmax(A_requant, integerize = True):

B = 8
log2e = np.log2(np.exp(1))
eps_x = B / (2**B * log2e)
range_scale = 32
eps_x = range_scale * B / (2**B * log2e)

if integerize:
x = A_requant * eps_x
Expand Down
7 changes: 4 additions & 3 deletions src/ita_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ package ita_package;
parameter int unsigned M3AddrWidth = idx_width(S) ;
parameter int unsigned NumReadPorts = N ;
parameter int unsigned MNumReadPorts = N ;
parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 14 `endif;
parameter int unsigned FifoDepth = `ifdef ITA_OUTPUT_FIFO_DEPTH `ITA_OUTPUT_FIFO_DEPTH `else 12 `endif;
localparam int unsigned SplitFactor = 4 ;
parameter int unsigned N_WRITE_EN = `ifdef TARGET_ITA_HWPE 8 `else M `endif;

Expand Down Expand Up @@ -96,12 +96,13 @@ package ita_package;

// Softmax
localparam int unsigned SoftmaxScalar = 65280; // (2**8-1) * 2**8
localparam int unsigned SoftmaxShift = 0;
localparam int unsigned SoftmaxAccDataWidth = 19; // Up to S = 2048
localparam int unsigned SoftFifoDepth = 4;
localparam int unsigned SoftFifoDepth = 12;
typedef logic [idx_width(SoftFifoDepth)-1:0] soft_fifo_usage_t;
typedef logic [idx_width(SoftFifoDepth+1)-1:0] ongoing_soft_t;
localparam int unsigned DividerWidth = SoftmaxAccDataWidth + 1;
localparam int unsigned NumDiv = 5;
localparam int unsigned NumDiv = 10;

// Requantizer
typedef enum {Signed=0, Unsigned=1} requant_mode_e;
Expand Down
24 changes: 12 additions & 12 deletions src/ita_softmax.sv
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ module ita_softmax
requant_oup_t requant_oup_q;
requant_t max_d, max_q;

logic unsigned [N-1:0][3:0] shift_d, shift_q;
logic unsigned [N-1:0][WI-SoftmaxShift:0] shift_d, shift_q;
logic [N-1:0][WI-1:0] shift_diff;
logic unsigned [3:0] shift_sum_d, shift_sum_q;
logic unsigned [WI-SoftmaxShift:0] shift_sum_d, shift_sum_q;
logic [WI-1:0] max_diff;
logic unsigned [M-1:0][3:0] shift_inp;
logic unsigned [M-1:0][WI-SoftmaxShift:0] shift_inp;
logic [M-1:0][WI-1:0] shift_inp_diff;

logic calc_stream_soft_en_q;
Expand Down Expand Up @@ -139,18 +139,18 @@ module ita_softmax
max_d = max_i;
for (int i = 0; i < N; i++) begin
shift_diff[i] = max_i - requant_oup_q[i];
shift_d[i] = unsigned'(shift_diff[i]) >> 5;
if (shift_diff[i][4])
shift_d[i] = (unsigned'(shift_diff[i]) >> 5) + 1;
shift_d[i] = unsigned'(shift_diff[i]) >> SoftmaxShift;
if (SoftmaxShift != 0 && shift_diff[i][SoftmaxShift-1])
shift_d[i] = (unsigned'(shift_diff[i]) >> SoftmaxShift) + 1;
end
if (tile_q2 != '0 || count_q2>=M) begin // If not first part of the first row, normalize previous sum
read_acc_en_o[0] = 1;
read_acc_addr_o[0] = count_q2;
prev_max_o = read_max_data_i[0];
max_diff = max_i - prev_max_o;
shift_sum_d = max_diff >> 5;
if (max_diff[4])
shift_sum_d = (max_diff >> 5) + 1;
shift_sum_d = max_diff >> SoftmaxShift;
if (SoftmaxShift != 0 && max_diff[SoftmaxShift-1])
shift_sum_d = (max_diff >> SoftmaxShift) + 1;
end else begin
prev_max_o = 8'h80;
end
Expand Down Expand Up @@ -223,9 +223,9 @@ module ita_softmax
if (calc_stream_soft_en_q) begin
for (int i = 0; i < M; i++) begin
shift_inp_diff[i] = read_max_data_i[1]-inp_i[i];
shift_inp[i] = unsigned'(shift_inp_diff[i]) >> 5;
if (shift_inp_diff[i][4])
shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> 5) + 1;
shift_inp[i] = unsigned'(shift_inp_diff[i]) >> SoftmaxShift;
if (SoftmaxShift != 0 && shift_inp_diff[i][SoftmaxShift-1])
shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> SoftmaxShift) + 1;
inp_stream_soft_o[i] = read_acc_data_i[1] >> shift_inp[i];
end
end
Expand Down