diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7b4c3ce..2c97e25 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,9 +26,17 @@ generate_testvectors: stage: test script: - !reference [.setup_test, script] - - python testGenerator.py -H 1 -S 64 -E 64 -P 64 -F 64 --activation gelu - - python testGenerator.py -H 1 -S 128 -E 192 -P 256 -F 256 --activation gelu - - python testGenerator.py -H 1 -S 192 -E 256 -P 128 -F 128 --activation relu + - python testGenerator.py -H 1 -S 64 -E 64 -P 64 -F 64 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 128 -E 192 -P 256 -F 256 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 192 -E 256 -P 128 -F 128 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 1 -E 2 -P 3 -F 3 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 1 -E 2 -P 3 -F 3 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 63 -E 62 -P 61 -F 61 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 65 -E 130 -P 195 -F 195 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 127 -E 190 -P 253 -F 253 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 511 -E 511 -P 127 -F 63 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 63 -E 63 -P 50 -F 129 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 255 -E 63 -P 511 -F 511 --activation identity --skip-vector-validation artifacts: paths: - simvectors @@ -94,6 +102,73 @@ run_sim: - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb +run_sim_padding: + stage: sim + needs: + - generate_testvectors + parallel: + matrix: + - S: 1 + E: 2 + P: 3 + F: 3 + activation: gelu + no_stalls: 0 + single_attention: 0 + - S: 1 + E: 2 + P: 3 + F: 3 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 63 + E: 62 + P: 61 + F: 61 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 65 + E: 130 + P: 195 + F: 195 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 127 + E: 190 + P: 253 + F: 253 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 511 + E: 511 + P: 127 + F: 63 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 63 + E: 63 + P: 50 + F: 129 + activation: gelu + no_stalls: 0 + single_attention: 0 + - S: 255 + E: 63 + P: 511 + F: 511 + activation: identity + no_stalls: 0 + single_attention: 0 + script: + - make bender + - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention + - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb + run_hwpe_sim: stage: sim needs: diff --git a/.vscode/launch.json b/.vscode/launch.json index 4e54398..42f08d8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -17,6 +17,7 @@ "-S${input:seq_len}", "-E${input:emb_len}", "-P${input:prj_len}", + "--no-bias" ], } ], diff --git a/Makefile b/Makefile index 3359ca7..c8192e7 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,8 @@ BENDER_INSTALL_DIR = ${INSTALL_DIR}/bender VENV_BIN=venv/bin/ BENDER_VERSION = 0.28.1 -SIM_PATH ?= modelsim/build +SIM_FOLDER ?= build +SIM_PATH ?= modelsim/${SIM_FOLDER} SYNTH_PATH = synopsys BENDER_TARGETS = -t rtl -t test diff --git a/PyITA/ITA.py b/PyITA/ITA.py index 24f7b0b..8c304f6 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -22,6 +22,9 @@ import numpy as np from numpy.typing import ArrayLike, DTypeLike +import seaborn as sns +import matplotlib.pyplot as plt + from .softmax import fastSoftmax, realSoftmax, streamingPartialSoftmax from .gelu import gelu_requantize, i_gelu_requantized, get_i_gelu_constants, get_i_gelu_requantized_constants from .util import (generate_matrix_mem, pack_8b_to_word, pack_array_8b_to_word, pack_hex_24b, pack_multihead_8b_to_word, @@ -69,10 +72,10 @@ def __init__(self, self._init_paths(path) - self.S_ITA = max(64, S) - self.P_ITA = max(64, P) - self.E_ITA = max(64, E) - self.F_ITA = max(64, F) + self.S_ITA = ((S - 1) // self.ITA_M + 1) * self.ITA_M + self.P_ITA = ((P - 1) // self.ITA_M + 1) * self.ITA_M + self.E_ITA = ((E - 1) // self.ITA_M + 1) * self.ITA_M + self.F_ITA = ((F - 1) // self.ITA_M + 1) * self.ITA_M self.H_ITA = 4 self.split = self.ITA_M // self.ITA_N @@ -110,10 +113,10 @@ def _validate_matrix_constraints(self, K: ArrayLike, V: ArrayLike): assert (np.all(K == V)) # WIESEP: Current restrictions for ITA - assert (self.S % self.ITA_M == 0), "Sequence length must be divisible by ITA_M" - assert (self.P % self.ITA_M == 0), "Projection space must be divisible by ITA_M" - assert (self.E % self.ITA_M == 0), "Embedding size must be divisible by ITA_M" - assert (self.F % self.ITA_M == 0), "Feedforward size must be divisible by ITA_M" + # assert (self.S % self.ITA_M == 0), "Sequence length must be divisible by ITA_M" + # assert (self.P % self.ITA_M == 0), "Projection space must be divisible by ITA_M" + # assert (self.E % self.ITA_M == 0), "Embedding size must be divisible by ITA_M" + # assert (self.F % self.ITA_M == 0), "Feedforward size must be divisible by ITA_M" assert ( self.E <= 512 @@ -172,7 +175,9 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bq_in = np.zeros((self.H, self.P), dtype = np.int8) self.Bq = np.pad(self.Bq_in, ((0, 0), (0, self.P_ITA - self.P))) - self.Bq_broadcast = np.reshape(np.repeat(self.Bq, self.S, axis = 0), (self.H, self.S, self.P)) + self.Bq_broadcast = np.reshape(np.repeat(self.Bq, self.S, axis = 0), (self.H, self.S, self.P_ITA)) + self.Bq_broadcast = np.pad(self.Bq_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) + if self.bias: self.Bk_in = random_shuffled_tensor( @@ -180,7 +185,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bk_in = np.zeros((self.H, self.P), dtype = np.int8) self.Bk = np.pad(self.Bk_in, ((0, 0), (0, self.P_ITA - self.P))) - self.Bk_broadcast = np.reshape(np.repeat(self.Bk, self.S, axis = 0), (self.H, self.S, self.P)) + self.Bk_broadcast = np.reshape(np.repeat(self.Bk, self.S, axis = 0), (self.H, self.S, self.P_ITA)) + self.Bk_broadcast = np.pad(self.Bk_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bv_in = random_shuffled_tensor( @@ -188,7 +194,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bv_in = np.zeros((self.H, self.P), dtype = np.int8) self.Bv = np.pad(self.Bv_in, ((0, 0), (0, self.P_ITA - self.P))) - self.Bv_broadcast = np.reshape(np.repeat(self.Bv, self.S, axis = 0), (self.H, self.S, self.P)) + self.Bv_broadcast = np.reshape(np.repeat(self.Bv, self.S, axis = 0), (self.H, self.S, self.P_ITA)) + self.Bv_broadcast = np.pad(self.Bv_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bo_in = random_shuffled_tensor( @@ -196,7 +203,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bo_in = np.zeros((self.H, self.E), dtype = np.int8) self.Bo = np.pad(self.Bo_in, ((0, 0), (0, self.E_ITA - self.E))) - self.Bo_broadcast = np.reshape(np.repeat(self.Bo, self.S, axis = 0), (self.H, self.S, self.E)) + self.Bo_broadcast = np.reshape(np.repeat(self.Bo, self.S, axis = 0), (self.H, self.S, self.E_ITA)) + self.Bo_broadcast = np.pad(self.Bo_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bff_in = random_shuffled_tensor( @@ -204,14 +212,16 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bff_in = np.zeros((1, self.F), dtype = np.int8) self.Bff = np.pad(self.Bff_in, ((0, 0), (0, self.F_ITA - self.F))) - self.Bff_broadcast = np.reshape(np.repeat(self.Bff, self.S, axis = 0), (1, self.S, self.F)) + self.Bff_broadcast = np.reshape(np.repeat(self.Bff, self.S, axis = 0), (1, self.S, self.F_ITA)) + self.Bff_broadcast = np.pad(self.Bff_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bff2_in = random_shuffled_tensor( (1, self.E), int(np.log2(self.E)) + 8, type = np.int32) if Bff2 is None else Bff2 else: self.Bff2_in = np.zeros((1, self.E), dtype = np.int8) self.Bff2 = np.pad(self.Bff2_in, ((0, 0), (0, self.E_ITA - self.E))) - self.Bff2_broadcast = np.reshape(np.repeat(self.Bff2, self.S, axis = 0), (1, self.S, self.E)) + self.Bff2_broadcast = np.reshape(np.repeat(self.Bff2, self.S, axis = 0), (1, self.S, self.E_ITA)) + self.Bff2_broadcast = np.pad(self.Bff2_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) #### Intermediate tensors #### @@ -348,6 +358,9 @@ def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: # Weight Wqk is H x E x P # Transpose Wqk to H x P x E + # print(f"qk: {qk.shape}") + # print(f"qk: {weight.shape}") + weight = np.transpose(weight, (0, 2, 1)) tile_x = qk.shape[0] // self.ITA_M # S // ITA_M @@ -362,6 +375,19 @@ def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: Input = np.tile(Input, [1, 1, self.split, 1]) # Repeat each tile number of output row tiles times Input = np.tile(Input, [1, tile_y, 1, 1]).reshape((-1, self.ITA_M)) + # fig, ax = plt.subplots(1, 2) # Create a figure with two subplots + # im0 = ax[0].imshow(Input, cmap='viridis') + # im1 = ax[1].imshow(np.squeeze(weight, axis=0)) + + # # Add colorbars for each image if needed + # fig.colorbar(im0, ax=ax[0]) + # fig.colorbar(im1, ax=ax[1]) + + # # Set titles for each subplot + # ax[0].set_title("Inputs") + # ax[1].set_title("Weights") + + plt.show() write_matrix(Input, input_file, self.paths["standalone"]) # Transposed Weight Wqk is H x P x E @@ -373,7 +399,7 @@ def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: # Bias Bqk is H x P # Broadcast Bias Bqk to H x S x P - bias = np.tile(bias, [1, self.S, 1]) + bias = np.tile(bias, [1, self.S_ITA, 1]) for h in range(self.H): Bias = split_matrix(bias[h], (self.ITA_M, self.ITA_N)) write_matrix(Bias, f"{bias_file}_{h}", self.paths["standalone"]) @@ -416,7 +442,7 @@ def tiler_V(self, v, weight, bias, output, input_file, weight_file, bias_file, o # Bias Bv is H x P # Broadcast Bias Bv to H x S x P - bias = np.tile(bias, [1, self.S, 1]) + bias = np.tile(bias, [1, self.S_ITA, 1]) # Transpose Bias Bv to H x P x S bias = np.transpose(bias, (0, 2, 1)) for h in range(self.H): @@ -497,7 +523,7 @@ def tiler_Out(self, O, weight, bias, output, input_file, weight_file, bias_file, # Bias Bo is H x E # Broadcast Bias Bo to H x S x E - bias = np.tile(bias, [1, self.S, 1]) + bias = np.tile(bias, [1, self.S_ITA, 1]) for h in range(self.H): Bias = split_matrix(bias[h], (self.ITA_M, self.ITA_N)) write_matrix(Bias, f"{bias_file}_{h}", self.paths["standalone"]) @@ -512,6 +538,12 @@ def step1_Qp(self): self.Qp = np.clip(self.Qp, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.Qp_requant = requantize(self.Qp, self.requant_eps_mult[0], self.requant_right_shift[0], self.requant_add[0]) + + # Set padded values to zero + if (self.S_ITA - self.S) > 0: + self.Qp_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.Qp_requant[:, :, -(self.P_ITA - self.P):] = 0 self.tiler_QK(self.Q, self.Wq, self.Bq, self.Qp_requant, "Q", "Wq", "Bq", "Qp") @@ -521,6 +553,11 @@ def step2_Kp(self): self.Kp_requant = requantize(self.Kp, self.requant_eps_mult[1], self.requant_right_shift[1], self.requant_add[1]) + if (self.S_ITA - self.S) > 0: + self.Kp_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.Kp_requant[:, :, -(self.P_ITA - self.P):] = 0 + self.tiler_QK(self.K, self.Wk, self.Bk, self.Kp_requant, "K", "Wk", "Bk", "Kp") def step3_Vp(self): @@ -529,6 +566,11 @@ def step3_Vp(self): self.Vp_requant = requantize(self.Vp, self.requant_eps_mult[2], self.requant_right_shift[2], self.requant_add[2]) + if (self.S_ITA - self.S) > 0: + self.Vp_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.Vp_requant[:, :, -(self.P_ITA - self.P):] = 0 + # Compute Vp in transposed form self.tiler_V(self.V, self.Wv, self.Bv, self.Vp_requant, "V", "Wv", "Bv", "Vp") @@ -537,16 +579,27 @@ def step4_QK(self, no_partial_softmax): [np.matmul(self.Qp_requant[i], np.transpose(self.Kp_requant[i]), dtype = np.int32) for i in range(self.H)]) self.A = np.clip(self.A, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.A_requant = requantize(self.A, self.requant_eps_mult[3], self.requant_right_shift[3], self.requant_add[3]) + + if (self.S_ITA - self.S) > 0: + self.A_requant[:, -(self.S_ITA - self.S):, :] = 0 + self.A_requant[:, :, -(self.S_ITA - self.S):] = 0 + self.soft(no_partial_softmax) self.tiler_AV(self.Qp_requant, self.Kp_requant, self.A_requant, "Qp_in", "Kp_in", "A") def soft(self, no_partial_softmax = False): - self.A_real_softmax = realSoftmax(self.A_requant) + self.A_real_softmax = realSoftmax(self.A_requant[:, :self.S, :self.S]) + self.A_real_softmax = np.pad(self.A_real_softmax, ((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S))) + if no_partial_softmax: - self.A_partial_softmax = fastSoftmax(self.A_requant) + self.A_partial_softmax = fastSoftmax(self.A_requant[:, :self.S, :self.S]) + self.A_partial_softmax = np.pad(self.A_partial_softmax, + ((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S))) else: - self.A_partial_softmax = streamingPartialSoftmax(self.A_requant) + self.A_partial_softmax = streamingPartialSoftmax(self.A_requant[:, :self.S, :self.S]) + self.A_partial_softmax = np.pad(self.A_partial_softmax, + ((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S))) if self.H == 1: A_save = [np.tile(self.A_partial_softmax[i], [self.split, 1]) for i in range(self.H)] @@ -564,6 +617,11 @@ def step5_AV(self): self.O_soft_requant = requantize(self.O_soft, self.requant_eps_mult[4], self.requant_right_shift[4], self.requant_add[4]) + if (self.S_ITA - self.S) > 0: + self.O_soft_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.O_soft_requant[:, :, -(self.P_ITA - self.P):] = 0 + self.tiler_AV(self.A_requant, np.transpose(self.Vp_requant, (0, 2, 1)), self.O_soft_requant, "A_stream_soft_in", "Vp_in", "O_soft") @@ -590,6 +648,12 @@ def step6_O(self): self.Out_soft = np.clip(self.Out_soft, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.Out_soft_requant = requantize(self.Out_soft, self.requant_eps_mult[5], self.requant_right_shift[5], self.requant_add[5]) + + if (self.S_ITA - self.S) > 0: + self.Out_soft_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.E_ITA - self.E) > 0: + self.Out_soft_requant[:, :, -(self.E_ITA - self.E):] = 0 + self.tiler_Out(self.O_soft_requant, self.Wo, self.Bo, self.Out_soft_requant, "O_soft_in", "Wo", "Bo", "Out_soft") @@ -599,7 +663,7 @@ def feedforward_layer(self): self.FFp_requant = requantize(self.FFp, self.requant_eps_mult_ffn[0], self.requant_right_shift_ffn[0], self.requant_add_ffn[0]) self.FFp_requant = self.apply_activation(self.FFp_requant, self.activation) - + self.tiler_QK(self.FF, self.Wff, self.Bff, self.FFp_requant, "FF", "Wff", "Bff", "FFp") self.FF2p = np.matmul(self.FFp_requant, self.Wff2, dtype = np.int32) + self.Bff2_broadcast @@ -934,8 +998,8 @@ def export_mempool(self, path): def export_numpy(self): assert np.all(np.equal(self.K, self.V)), "For ITA, keys and values have to be equal" - q = self.Q - k = self.K + q = self.Q_in + k = self.K_in w1 = self.Wq_in b1 = self.Bq_in w2 = self.Wk_in diff --git a/PyITA/ITA_onnx.py b/PyITA/ITA_onnx.py index eda85f3..235cf00 100644 --- a/PyITA/ITA_onnx.py +++ b/PyITA/ITA_onnx.py @@ -259,8 +259,8 @@ def exportONNX(path, verbose = False, **kwargs): # Transform from MUL-DIV-ADD to MUL-ADD-DIV RQ_ADD = (RQ_ADD * 2**RQ_SHIFT.astype(np.float32)) - input0_values = np.expand_dims(inputs['q'][:(S * E // 64), :].reshape(S, E), axis = 0) - input1_values = np.expand_dims(inputs['k'][:(S * E // 64), :].reshape(S, E), axis = 0) + input0_values = np.expand_dims(inputs['q'].reshape(S, E), axis = 0) + input1_values = np.expand_dims(inputs['k'].reshape(S, E), axis = 0) np.savez(path + "inputs.npz", input0_values, input1_values) diff --git a/PyITA/softmax.py b/PyITA/softmax.py index 8cbc5cf..84d063c 100644 --- a/PyITA/softmax.py +++ b/PyITA/softmax.py @@ -14,6 +14,8 @@ # # ---------------------------------------------------------------------- +import argparse + import numpy as np @@ -71,10 +73,7 @@ def streamingPartialSoftmax(x, integerize = True): seq_length = x.shape[-1] n_heads = x.shape[-3] - width = 16 # 16 PE (processing units) - groups = seq_length // width - - assert seq_length % width == 0, f"Sequence length must be a multiple of width ({width})" + PE = 16 # 16 PE (processing units) # Number of bits B = 8 @@ -101,12 +100,14 @@ def streamingPartialSoftmax(x, integerize = True): global_max = np.full((n_heads, seq_length), -np.Infinity, dtype = np.float32) ## STAGE 1: Compute the denominator of the softmax - for i in range(groups): + for i in range((seq_length + PE - 1) // PE): + width = seq_length % PE if i * PE + PE > seq_length else PE + # Find the maximum for each row in the current column block (consisting of 16 columns) if integerize: - current_max = np.max(x[..., 0 + i * width:width + i * width].astype(np.int32), axis = -1) + current_max = np.max(x[..., 0 + i * PE:width + i * PE].astype(np.int32), axis = -1) else: - current_max = np.max(x[..., 0 + i * width:width + i * width].astype(np.float32), axis = -1) + current_max = np.max(x[..., 0 + i * PE:width + i * PE].astype(np.float32), axis = -1) # Initialize all shift values for each row to zero if integerize: @@ -129,11 +130,11 @@ def streamingPartialSoftmax(x, integerize = True): # Find the difference between the maximum and x in the current part of the row if integerize: - diff = np.repeat(global_max, width).reshape( - n_heads, seq_length, width) - x[..., 0 + i * width:width + i * width].astype(np.int32) + diff = np.repeat(global_max, width).reshape(n_heads, seq_length, + width) - x[..., 0 + i * PE:width + i * PE].astype(np.int32) else: - diff = np.repeat(global_max, width).reshape( - n_heads, seq_length, width) - x[..., 0 + i * width:width + i * width].astype(np.float32) + diff = np.repeat(global_max, width).reshape(n_heads, seq_length, + width) - x[..., 0 + i * PE:width + i * PE].astype(np.float32) # Shift the values by B-log2B -> multiply by B/2**B = log2e*eps_x # Make sure to do use round-half-up instead of round-half-to-even @@ -177,7 +178,7 @@ def streamingPartialSoftmax(x, integerize = True): # A_partial_softmax[0] = np.repeat(exp_partial_sum_inverse, seq_length).reshape(seq_length, seq_length) >> shift return np.floor( np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift).astype( - np.int8) + np.uint8) else: return np.repeat(exp_partial_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift @@ -195,7 +196,66 @@ def realSoftmax(A_requant, integerize = True): x = A_requant.astype(np.float64) exp = np.exp(x - np.max(x, axis = 2).reshape(n_heads, -1, 1)) + + # Replace nan with zero + exp = np.nan_to_num(exp) + if integerize: return (exp / exp.sum(axis = 2).reshape(n_heads, -1, 1) * (2**7 - 1)).astype(A_requant.dtype) else: return exp / exp.sum(axis = 2).reshape(n_heads, -1, 1) + + +if __name__ == "__main__": + np.set_printoptions(linewidth = 120) + np.set_printoptions(precision = 4) + + # Always print whole array + np.set_printoptions(threshold = np.inf) + + parser = argparse.ArgumentParser(description = "Test Utility for Softmax.") + # Sequence length + parser.add_argument("-S", default = 64, type = int, help = "Sequence length") + + # ITA sequence length + parser.add_argument("-M", default = 64, type = int, help = "ITA sequence length") + + # Quantization (float or int) + parser.add_argument("--int", action = "store_true", help = "Quantize to int") + parser.add_argument('--seed', default = 0, type = int, help = 'Random seed') + + args = parser.parse_args() + + ITA_WI = 8 + WO = 26 + ITA_N = 16 + ITA_M = args.M + + if args.seed != -1: + np.random.seed(args.seed) + + if args.int: + x = np.random.randint(-128, 128, (1, 1, args.S, args.S)).astype(np.int8) + else: + x = np.random.randn(1, 1, 16, 16).astype(np.float32) + + print("Input:") + print(x) + + # Pad last two dimensions to be a multiple of ITA_M + pad_x = (ITA_M - x.shape[-1] % ITA_M) % ITA_M + pad_y = (ITA_M - x.shape[-2] % ITA_M) % ITA_M + pad_value = -2**(ITA_WI - 1) if args.int else -np.inf + + print(f"Padding x by ({pad_y}, {pad_x}) with {pad_value}") + x_pad = np.pad(x, ((0, 0), (0, 0), (0, pad_y), (0, pad_x)), mode = 'constant', constant_values = pad_value) + + res = realSoftmax(x, integerize = args.int) + res_pad = realSoftmax(x_pad, integerize = args.int) + + res_unpad = res_pad[:, :, :args.S, :args.S] + + # Compare results + print(f"Equal: {np.allclose(res, res_unpad, atol = 1e-3)}") + print(res) + print(res_unpad) diff --git a/README.md b/README.md index 3a29e23..b58d20e 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ To install the required Python packages, create a virtual environment. Make sure $> python -m venv venv $> source venv/bin/activate $> pip install -r requirements.txt +$> pip install -r requirements.dev.txt # Only required for PyITA/test_gelu.py ``` If you want to enable pre-commit hooks, which perform code formatting and linting, run the following command: diff --git a/modelsim/Makefile b/modelsim/Makefile index 7d181aa..8aec4cf 100644 --- a/modelsim/Makefile +++ b/modelsim/Makefile @@ -6,7 +6,7 @@ all: lib build QUESTA_SEPP ?= questa-2023.4 -buildpath ?= build +buildpath ?= $(SIM_FOLDER) VOPT ?= $(QUESTA_SEPP) vopt VSIM ?= $(QUESTA_SEPP) vsim VLIB ?= $(QUESTA_SEPP) vlib diff --git a/modelsim/sim_ita_tb_wave.tcl b/modelsim/sim_ita_tb_wave.tcl index 4f29360..072c1cd 100644 --- a/modelsim/sim_ita_tb_wave.tcl +++ b/modelsim/sim_ita_tb_wave.tcl @@ -11,6 +11,36 @@ add wave -noupdate /ita_tb/dut/i_inp2_mux/rst_ni add wave -noupdate /ita_tb/dut/i_inp2_mux/weight_i add wave -noupdate /ita_tb/dut/i_inp2_mux/inp2_o add wave -noupdate /ita_tb/dut/i_controller/ctrl_i +add wave -noupdate /ita_tb/dut/oup_o +add wave -noupdate /ita_tb/dut/inp1_q +add wave -noupdate /ita_tb/dut/inp2_q +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/count_d +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/count_q +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/bias_count +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/bias_tile_x_d +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/bias_tile_x_q +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/bias_tile_y_d +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/bias_tile_y_q +add wave -noupdate -expand -group Requant /ita_tb/dut/i_controller/requant_add_i +add wave -noupdate -expand -group Requant /ita_tb/dut/i_controller/requant_add_o +add wave -noupdate -expand -group Requant /ita_tb/dut/i_controller/step_q +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias_padded +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias_q1 +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias_q2 +add wave -noupdate /ita_tb/dut/i_accumulator/oup_i +add wave -noupdate /ita_tb/dut/i_accumulator/result_d +add wave -noupdate /ita_tb/dut/i_accumulator/result_o +add wave -noupdate /ita_tb/dut/i_requantizer/requant_oup_o +add wave -noupdate /ita_tb/dut/i_activation/data_i +add wave -noupdate /ita_tb/dut/i_activation/data_q1 +add wave -noupdate /ita_tb/dut/i_activation/data_q2 +add wave -noupdate /ita_tb/dut/i_activation/data_q3 +add wave -noupdate /ita_tb/dut/i_activation/data_q4 +add wave -noupdate /ita_tb/dut/i_activation/data_o +add wave -noupdate /ita_tb/dut/i_fifo/data_i +add wave -noupdate /ita_tb/dut/i_fifo/data_o +add wave -noupdate /ita_tb/dut/oup_o add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/* add wave -expand -group Controller /ita_tb/dut/i_controller/* add wave -group {Softmax Controller} ita_tb/dut/i_softmax_top/i_softmax/* diff --git a/modelsim/sim_ita_tb_wave_important.tcl b/modelsim/sim_ita_tb_wave_important.tcl new file mode 100644 index 0000000..6513e4c --- /dev/null +++ b/modelsim/sim_ita_tb_wave_important.tcl @@ -0,0 +1,241 @@ +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +onerror {resume} +quietly WaveActivateNextPane {} 0 +add wave -noupdate /ita_tb/dut/i_inp1_mux/clk_i +add wave -noupdate /ita_tb/dut/i_inp1_mux/rst_ni +add wave -noupdate /ita_tb/dut/i_inp1_mux/inp_i +add wave -noupdate /ita_tb/dut/i_inp1_mux/inp1_o +add wave -noupdate /ita_tb/dut/i_inp2_mux/rst_ni +add wave -noupdate /ita_tb/dut/i_inp2_mux/weight_i +add wave -noupdate /ita_tb/dut/i_inp2_mux/inp2_o +add wave -noupdate /ita_tb/dut/i_controller/ctrl_i +add wave -noupdate /ita_tb/dut/inp1_q +add wave -noupdate /ita_tb/dut/inp2_q +add wave -noupdate /ita_tb/dut/i_inp2_mux/clk_i +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/count_d +add wave -noupdate -radix unsigned /ita_tb/dut/i_controller/count_q +add wave -noupdate /ita_tb/dut/calc_en +add wave -noupdate /ita_tb/dut/calc_en_q1 +add wave -noupdate /ita_tb/dut/calc_en_q2 +add wave -noupdate /ita_tb/dut/calc_en_q3 +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias_padded +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias_q1 +add wave -noupdate -expand -group Bias /ita_tb/dut/inp_bias_q2 +add wave -noupdate /ita_tb/dut/i_accumulator/oup_i +add wave -noupdate /ita_tb/dut/i_accumulator/result_d +add wave -noupdate /ita_tb/dut/i_accumulator/result_o +add wave -noupdate /ita_tb/dut/i_requantizer/requant_oup_o +add wave -noupdate /ita_tb/dut/i_activation/data_i +add wave -noupdate /ita_tb/dut/i_activation/data_q1 +add wave -noupdate /ita_tb/dut/i_activation/data_q2 +add wave -noupdate /ita_tb/dut/i_activation/data_q3 +add wave -noupdate /ita_tb/dut/i_activation/data_q4 +add wave -noupdate /ita_tb/dut/i_activation/data_o +add wave -noupdate /ita_tb/dut/i_fifo/data_i +add wave -noupdate /ita_tb/dut/i_fifo/data_o +add wave -noupdate /ita_tb/dut/oup_o +add wave -noupdate -expand -group Softmax /ita_tb/dut/i_softmax_top/i_softmax/requant_oup_q +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/clk_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/rst_ni +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/mode_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/eps_mult_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/right_shift_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/calc_en_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/calc_en_q_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/mult_signed +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/product +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/shifted_added +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/shifted_d +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/shifted_q +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q1 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q2 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q3 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q4 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/requant_oup_d +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/requant_oup_q +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/clk_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/rst_ni +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ctrl_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_ready_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/weight_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/weight_ready_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_ready_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/oup_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/oup_ready_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/pop_softmax_fifo_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/step_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/soft_addr_div_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_done_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/calc_en_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/first_inner_tile_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/last_inner_tile_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_x_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_y_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias_pad_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/busy_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/step_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/step_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_x_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_x_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_y_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_y_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_tile_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_tile_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_soft_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_soft_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias_padded +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_dim +add wave -noupdate -group Controller /ita_tb/dut/i_controller/first_outer_dim +add wave -noupdate -group Controller /ita_tb/dut/i_controller/second_outer_dim +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_fifo +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_div +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_div_done_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_div_done_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/busy_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/busy_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/clk_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/rst_ni +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/ctrl_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/step_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/requant_oup_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/soft_addr_div_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/softmax_done_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/pop_softmax_fifo_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inp_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_inp_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_valid_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_ready_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_valid_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_ready_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_oup_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_acc_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_acc_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_acc_data_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_acc_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_acc_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_acc_data_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/prev_max_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_max_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_max_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_max_data_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_max_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_max_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_max_data_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_x_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_y_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inner_tile_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q3 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q4 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q3 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q4 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inner_tile_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_y_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_div_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_div_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/addr_div_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/addr_div_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_read_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_read_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_write_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_write_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/requant_oup_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_diff +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_sum_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_sum_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_diff +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_inp +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_inp_diff +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q3 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_full +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_empty +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/push_to_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/pop_from_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/data_to_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/data_from_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_usage +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_shift +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_row +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/clk_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/rst_ni +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/calc_en_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/calc_en_q_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/first_tile_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/first_tile_q_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/last_tile_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/last_tile_q_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/inp_bias_i +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/read_en +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/read_addr +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/read_data +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/read_data_unused +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/write_en +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/write_addr +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/write_data +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/read_addr_d +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/read_addr_q +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/write_addr_d +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/write_addr_q +add wave -noupdate -expand -group Accumulator /ita_tb/dut/i_accumulator/result_q +TreeUpdate [SetDefaultTree] +WaveRestoreCursors {{Cursor 1} {414600 ps} 1} {{Cursor 2} {550600 ps} 1} {{Cursor 3} {710600 ps} 1} {{Cursor 4} {390540 ps} 0} +quietly wave cursor active 4 +configure wave -namecolwidth 176 +configure wave -valuecolwidth 100 +configure wave -justifyvalue left +configure wave -signalnamewidth 1 +configure wave -snapdistance 10 +configure wave -datasetprefix 0 +configure wave -rowmargin 4 +configure wave -childrowmargin 2 +configure wave -gridoffset 0 +configure wave -gridperiod 1 +configure wave -griddelta 40 +configure wave -timeline 0 +configure wave -timelineunits ns +update +WaveRestoreZoom {371422 ps} {416865 ps} diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000..faa86d4 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,7 @@ +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +torch +pytest +pytest-check diff --git a/requirements.txt b/requirements.txt index 4bfbc47..e8e03c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,4 @@ onnxruntime netron seaborn matplotlib -torch -pytest -pytest-check pre-commit diff --git a/src/ita.sv b/src/ita.sv index 2dad263..6a9c1a2 100644 --- a/src/ita.sv +++ b/src/ita.sv @@ -40,7 +40,7 @@ module ita logic weight_valid, weight_ready; inp_t inp, inp_stream_soft; weight_t inp1, inp1_q, inp2, inp2_q; - bias_t inp_bias, inp_bias_q1, inp_bias_q2; + bias_t inp_bias, inp_bias_padded, inp_bias_q1, inp_bias_q2; oup_t oup, oup_q, accumulator_oup; requant_const_t requant_mult, requant_shift, activation_requant_mult, activation_requant_shift; requant_oup_t requant_oup; @@ -153,8 +153,8 @@ module ita if (!rst_ni) begin inp1_q <= '0; inp2_q <= '0; - inp_bias_q2 <= '0; inp_bias_q1 <= '0; + inp_bias_q2 <= '0; oup_q <= '0; end else begin if (calc_en_q2) begin @@ -162,7 +162,7 @@ module ita oup_q <= oup; end if (calc_en_q1) begin - inp_bias_q1 <= inp_bias; + inp_bias_q1 <= inp_bias_padded; inp1_q <= inp1; inp2_q <= inp2; end @@ -171,6 +171,12 @@ module ita assign oup_o = valid_o ? data_from_fifo : '0; + requant_oup_t requant_add_o; + + counter_t inner_tile; + counter_t tile_x; + counter_t tile_y; + ita_controller i_controller ( .clk_i (clk_i ), .rst_ni (rst_ni ), @@ -190,6 +196,13 @@ module ita .calc_en_o (calc_en ), .first_inner_tile_o (first_inner_tile ), .last_inner_tile_o (last_inner_tile ), + .tile_x_o (tile_x ), + .tile_y_o (tile_y ), + .inner_tile_o (inner_tile ), + .requant_add_i (requant_add ), + .requant_add_o (requant_add_o ), + .inp_bias_i (inp_bias ), + .inp_bias_pad_o (inp_bias_padded ), .busy_o (busy_o ) ); @@ -255,13 +268,16 @@ module ita .soft_addr_div_o (soft_addr_div ), .softmax_done_o (softmax_done ), .pop_softmax_fifo_o (pop_softmax_fifo ), - .inp_stream_soft_o (inp_stream_soft ) + .inp_stream_soft_o (inp_stream_soft ), + .tile_x_i (tile_x ), + .tile_y_i (tile_y ), + .inner_tile_i (inner_tile ) ); ita_requatization_controller i_requantization_controller ( .ctrl_i (ctrl_i ), - .requantizer_step_i (step_q4 ), + .requantizer_step_i (step_q4 ), .requant_mult_o (requant_mult ), .requant_shift_o (requant_shift ), .requant_add_o (requant_add ), @@ -282,8 +298,8 @@ module ita .calc_en_i ( calc_en_q4 && last_inner_tile_q4 ), .calc_en_q_i ( calc_en_q5 && last_inner_tile_q5 ), - .result_i ( accumulator_oup ), - .add_i ( {N {requant_add}} ), + .result_i ( accumulator_oup ), + .add_i ( requant_add_o ), .requant_oup_o( requant_oup ) ); diff --git a/src/ita_controller.sv b/src/ita_controller.sv index 0fa8034..28e1885 100644 --- a/src/ita_controller.sv +++ b/src/ita_controller.sv @@ -10,44 +10,69 @@ module ita_controller import ita_package::*; ( - input logic clk_i , - input logic rst_ni , - input ctrl_t ctrl_i , - input logic inp_valid_i , - output logic inp_ready_o , - input logic weight_valid_i , - output logic weight_ready_o , - input logic bias_valid_i , - output logic bias_ready_o , - input logic oup_valid_i , - input logic oup_ready_i , - input logic pop_softmax_fifo_i , - output step_e step_o , - input counter_t soft_addr_div_i , - input logic softmax_done_i , - output logic calc_en_o , - output logic first_inner_tile_o , - output logic last_inner_tile_o , - output logic busy_o + input logic clk_i , + input logic rst_ni , + input ctrl_t ctrl_i , + input logic inp_valid_i , + output logic inp_ready_o , + input logic weight_valid_i , + output logic weight_ready_o , + input logic bias_valid_i , + output logic bias_ready_o , + input logic oup_valid_i , + input logic oup_ready_i , + input logic pop_softmax_fifo_i , + output step_e step_o , + input counter_t soft_addr_div_i , + input logic softmax_done_i , + output logic calc_en_o , + output logic first_inner_tile_o , + output logic last_inner_tile_o , + output counter_t tile_x_o , + output counter_t tile_y_o , + output counter_t inner_tile_o , + input requant_t requant_add_i , + output requant_oup_t requant_add_o , + input bias_t inp_bias_i , + output bias_t inp_bias_pad_o , + output logic busy_o ); step_e step_d, step_q; - counter_t count_d, count_q; + counter_t count_d, count_q, bias_count; counter_t tile_d, tile_q; counter_t inner_tile_d, inner_tile_q; + counter_t tile_x_d, tile_x_q, bias_tile_x_d, bias_tile_x_q; + counter_t tile_y_d, tile_y_q, bias_tile_y_d, bias_tile_y_q; counter_t softmax_tile_d, softmax_tile_q; ongoing_t ongoing_d, ongoing_q; ongoing_soft_t ongoing_soft_d, ongoing_soft_q; + bias_t inp_bias, inp_bias_padded; + logic last_time; + + tile_t inner_tile_dim; + logic [WO-WI*2-2:0] first_outer_dim, second_outer_dim; + logic [WO-WI*2-2:0] first_outer_dim_d, first_outer_dim_q; + logic [WO-WI*2-2:0] second_outer_dim_d, second_outer_dim_q; + logic softmax_fifo, softmax_div, softmax_div_done_d, softmax_div_done_q, busy_d, busy_q; + requant_oup_t requant_add, requant_add_d, requant_add_q; - assign step_o = step_q; - assign busy_o = busy_q; + assign step_o = step_q; + assign busy_o = busy_q; + assign tile_x_o = tile_x_q; + assign tile_y_o = tile_y_q; + assign inner_tile_o = inner_tile_q; + assign requant_add_o = requant_add_q; + assign inp_bias_pad_o = inp_bias_padded; always_comb begin count_d = count_q; tile_d = tile_q; inner_tile_d = inner_tile_q; + tile_x_d = tile_x_q; + tile_y_d = tile_y_q; first_inner_tile_o = (inner_tile_q == 0) ? 1'b1 : 1'b0; last_inner_tile_o = 1'b0; ongoing_d = ongoing_q; @@ -59,6 +84,8 @@ module ita_controller step_d = step_q; softmax_tile_d = softmax_tile_q; softmax_div_done_d = softmax_div_done_q; + last_time = 1'b0; + requant_add = {N {requant_add_i}}; busy_d = busy_q; softmax_fifo = 1'b0; @@ -98,7 +125,7 @@ module ita_controller busy_d = 1'b1; if (count_d == M*M/N) begin // end of tile busy_d = 1'b0; // Generate done signal for current tile - count_d = '0; + count_d = '0; inner_tile_d = inner_tile_q + 1; end end @@ -108,6 +135,8 @@ module ita_controller case (step_q) Idle : begin inner_tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; tile_d = '0; softmax_tile_d = '0; softmax_div_done_d = 1'b0; @@ -126,51 +155,80 @@ module ita_controller end // Attention Q : begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.proj_space; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_p-1)) begin // end of step Q + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin // end of step Q tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = K; end end end K: begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.proj_space; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_p-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin // end of step K tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = V; end end end V: begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.proj_space; + second_outer_dim = ctrl_i.seq_length; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_s-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin // end of step V tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = QK; end end end QK : begin - if (inner_tile_q == ctrl_i.tile_p-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_p-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.seq_length; if (inner_tile_d == ctrl_i.tile_p) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_s-1)) begin + tile_x_d = '0; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s) begin // end of step QK tile_d = '0; step_d = AV; @@ -178,64 +236,96 @@ module ita_controller end end AV : begin - if (inner_tile_q == ctrl_i.tile_s-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_s-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.proj_space; if (inner_tile_d == ctrl_i.tile_s) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_p-1)) begin + tile_x_d = '0; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_p) begin tile_d = '0; softmax_tile_d = softmax_tile_q + 1; if (softmax_tile_d == ctrl_i.tile_s) begin softmax_tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; if (ctrl_i.layer == Attention) begin step_d = OW; end else if (ctrl_i.layer == SingleAttention) begin step_d = Idle; end end else begin + tile_y_d = tile_y_q + 1; step_d = QK; end end end end OW : begin - if (inner_tile_q == ctrl_i.tile_p-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_p-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.embed_size; if (inner_tile_d == ctrl_i.tile_p) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_e-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_e) begin // end of step OW tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = Idle; end end end // Feedforward F1: begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.ff_size; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; - if (tile_d == ctrl_i.tile_s*ctrl_i.tile_f) begin + if (tile_x_q == (ctrl_i.tile_f-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end + if (tile_d == ctrl_i.tile_s*ctrl_i.tile_f) begin tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = F2; end end end F2: begin - if (inner_tile_q == ctrl_i.tile_f-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_f-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.embed_size; if (inner_tile_d == ctrl_i.tile_f) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_e-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_e) begin tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = Idle; end end @@ -255,6 +345,44 @@ module ita_controller end end endcase + + inp_bias = inp_bias_i; + requant_add_d = requant_add; + bias_count = (count_q == 0) ? 255 : count_q - 1; + bias_tile_x_d = (count_q == 0) ? bias_tile_x_q : tile_x_q; + bias_tile_y_d = (count_q == 0) ? bias_tile_y_q : tile_y_q; + first_outer_dim_d = (count_q == 0) ? first_outer_dim_q : first_outer_dim; + second_outer_dim_d = (count_q == 0) ? second_outer_dim_q : second_outer_dim; + + if ((step_q != Idle && step_q != MatMul) || (step_q == Idle && bias_count == 255)) begin + if (inner_tile_q == inner_tile_dim) begin + last_inner_tile_o = 1'b1; + end + if ((((((bias_count) & (M-1)) + bias_tile_y_d * M)) > ((first_outer_dim_d - 1)))) begin + requant_add_d = {N {1'b0}}; + inp_bias = {N {1'b0}}; + end else begin + if ( ((bias_count) + bias_tile_x_d * M*M/N) >= (second_outer_dim_d / N) * M ) begin + if ( (((bias_count) / M) * N + bias_tile_x_d * M ) < second_outer_dim_d) begin + for (int i = 0; i < N; i++) begin + if (i >= (second_outer_dim_d & (N-1))) begin + requant_add_d[i] = 1'b0; + inp_bias[i] = 1'b0; + end else begin + requant_add_d[i] = requant_add[i]; + inp_bias[i] = inp_bias_i[i]; + end + end + end else begin + requant_add_d = {N {1'b0}}; + inp_bias = {N {1'b0}}; + end + end + end + end + + inp_bias_padded = inp_bias; + if (inp_valid_i && inp_ready_o && oup_valid_i && oup_ready_i && last_inner_tile_o) begin ongoing_d = ongoing_q; end else if (inp_valid_i && inp_ready_o && last_inner_tile_o) begin @@ -276,22 +404,36 @@ module ita_controller step_q <= Idle; count_q <= '0; tile_q <= '0; + tile_x_q <= '0; + tile_y_q <= '0; inner_tile_q <= '0; softmax_tile_q <= '0; ongoing_q <= '0; ongoing_soft_q <= '0; softmax_div_done_q <= 1'b0; + requant_add_q <= '0; busy_q <= 1'b0; + bias_tile_x_q <= '0; + bias_tile_y_q <= '0; + first_outer_dim_q <= '0; + second_outer_dim_q <= '0; end else begin step_q <= step_d; count_q <= count_d; tile_q <= tile_d; + tile_x_q <= tile_x_d; + tile_y_q <= tile_y_d; inner_tile_q <= inner_tile_d; softmax_tile_q <= softmax_tile_d; ongoing_q <= ongoing_d; ongoing_soft_q <= ongoing_soft_d; softmax_div_done_q <= softmax_div_done_d; + requant_add_q <= requant_add_d; busy_q <= busy_d; + bias_tile_x_q <= bias_tile_x_d; + bias_tile_y_q <= bias_tile_y_d; + first_outer_dim_q <= first_outer_dim_d; + second_outer_dim_q <= second_outer_dim_d; end end endmodule diff --git a/src/ita_package.sv b/src/ita_package.sv index 335e173..3a2c25f 100644 --- a/src/ita_package.sv +++ b/src/ita_package.sv @@ -45,13 +45,17 @@ package ita_package; typedef logic [N_REQUANT_CONSTS-1:0][EMS-1:0] requant_const_array_t; typedef logic signed [WI-1:0] requant_t; typedef logic signed [N_REQUANT_CONSTS-1:0][WI-1:0] requant_array_t; - typedef logic [idx_width(S+1)-1:0] seq_length_t; - typedef logic [idx_width(P+1)-1:0] proj_space_t; - typedef logic [idx_width(E+1)-1:0] embed_size_t; - typedef logic [idx_width(H+1)-1:0] n_heads_t; + typedef logic [WO-WI*2-2:0] seq_length_t; + typedef logic [WO-WI*2-2:0] proj_space_t; + typedef logic [WO-WI*2-2:0] embed_size_t; + typedef logic [WO-WI*2-2:0] ff_size_t; typedef logic [ 32-1:0] tile_t; typedef struct packed { logic start ; + seq_length_t seq_length ; + proj_space_t proj_space ; + embed_size_t embed_size ; + ff_size_t ff_size ; layer_e layer ; activation_e activation ; requant_const_array_t eps_mult ; diff --git a/src/ita_requantization_controller.sv b/src/ita_requantization_controller.sv index 3b6c865..e844358 100644 --- a/src/ita_requantization_controller.sv +++ b/src/ita_requantization_controller.sv @@ -33,6 +33,12 @@ module ita_requatization_controller endcase end + // always_comb begin + // requant_mult = ctrl_i.eps_mult[step_q4]; + // requant_shift = ctrl_i.right_shift[step_q4]; + // requant_add = ctrl_i.add[step]; + // end + assign requant_mult_o = ctrl_i.eps_mult[constant_idx]; assign requant_shift_o = ctrl_i.right_shift[constant_idx]; assign requant_add_o = ctrl_i.add[constant_idx]; diff --git a/src/ita_requantizer.sv b/src/ita_requantizer.sv index 6033c09..c67b97c 100644 --- a/src/ita_requantizer.sv +++ b/src/ita_requantizer.sv @@ -22,7 +22,8 @@ module ita_requantizer logic signed [EMS+WO:0] product ; logic signed [EMS+WO:0] shifted_added; logic signed [ N-1:0][EMS+WO-1:0] shifted_d, shifted_q; - requant_oup_t add_q1, requant_oup_d, requant_oup_q; + requant_oup_t add_q1, add_q2, add_q3, add_q4; + requant_oup_t requant_oup_d, requant_oup_q; assign requant_oup_o = requant_oup_q; @@ -49,7 +50,7 @@ module ita_requantizer end end if (calc_en_q_i) begin - shifted_added = shifted_q[i] + (EMS+WO)'(signed'(add_q1[i])); + shifted_added = shifted_q[i] + (EMS+WO)'(signed'(add_q4[i])); requant_oup_d[i] = shifted_added[WI-1:0]; if (~shifted_added[EMS+WO-1] & (|(shifted_added[EMS+WO-2:WI-1]))) begin requant_oup_d[i] = '1; @@ -76,8 +77,14 @@ module ita_requantizer always_ff @(posedge clk_i, negedge rst_ni) begin if (!rst_ni) begin add_q1 <= '0; + add_q2 <= '0; + add_q3 <= '0; + add_q4 <= '0; end else begin add_q1 <= add_i; + add_q2 <= add_q1; + add_q3 <= add_q2; + add_q4 <= add_q3; end end endmodule diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv index 675750c..ac61ed2 100644 --- a/src/ita_softmax.sv +++ b/src/ita_softmax.sv @@ -39,14 +39,19 @@ module ita_softmax input requant_t [1:0] read_max_data_i, output logic write_max_en_o, output logic [InputAddrWidth-1:0] write_max_addr_o, - output requant_t write_max_data_o + output requant_t write_max_data_o, + input counter_t tile_x_i, + input counter_t tile_y_i, + input counter_t inner_tile_i ); counter_t tile_d, tile_q1, tile_q2, tile_q3, tile_q4; counter_t count_d, count_q1, count_q2, count_q3, count_q4; + counter_t inner_tile_q; + counter_t tile_y_q; logic unsigned [SoftmaxAccDataWidth-1:0] exp_sum_d, exp_sum_q; - counter_t count_soft_d, count_soft_q; + counter_t count_soft_d, count_soft_q1, count_soft_q2; counter_t count_div_d, count_div_q, addr_div_d, addr_div_q; logic [NumDiv-1:0] div_read_d, div_read_q, div_write_d, div_write_q; @@ -69,13 +74,19 @@ module ita_softmax logic [SoftmaxAccDataWidth-1:0] data_to_fifo, data_from_fifo; soft_fifo_usage_t fifo_usage ; + logic [N-1:0] disable_shift; + logic disable_row; + logic [M-1:0]disable_col; + + assign disable_row = ((count_soft_q2 & (M-1)) + tile_y_q * M) > (ctrl_i.seq_length - 1); + assign pop_softmax_fifo_o = pop_from_fifo; assign soft_addr_div_o = addr_div_q; always_comb begin tile_d = tile_q1; count_d = count_q1; - count_soft_d = count_soft_q; + count_soft_d = count_soft_q1; count_div_d = count_div_q; div_read_d = div_read_q; div_write_d = div_write_q; @@ -135,13 +146,20 @@ module ita_softmax //************ Pipeline Stage 1 ************// if (calc_en_q1) begin // Find max and accumulate - max_o = requant_oup_q; max_d = max_i; for (int i = 0; i < N; i++) begin shift_diff[i] = max_i - requant_oup_q[i]; - shift_d[i] = unsigned'(shift_diff[i]) >> 5; - if (shift_diff[i][4]) - shift_d[i] = (unsigned'(shift_diff[i]) >> 5) + 1; + disable_shift[i] = ( (tile_q2*M+N*(count_q2 >> $clog2(M))+i ) >= ctrl_i.seq_length); + + if (disable_shift[i]) begin + max_o[i] = 8'h80; + shift_d[i] = 4'hF; + end else begin + max_o[i] = requant_oup_q[i]; + shift_d[i] = unsigned'(shift_diff[i]) >> 5; + if (shift_diff[i][4]) + shift_d[i] = (unsigned'(shift_diff[i]) >> 5) + 1; + end end if (tile_q2 != '0 || count_q2>=M) begin // If not first part of the first row, normalize previous sum read_acc_en_o[0] = 1; @@ -162,7 +180,8 @@ module ita_softmax write_max_addr_o = count_q3; write_max_data_o = max_q; for (int i = 0; i < N; i++) begin - exp_sum_d += unsigned'(9'h100)>>shift_q[i]; + if (shift_d[i] != 4'hF) + exp_sum_d += unsigned'(9'h100)>>shift_q[i]; end if (tile_q3 != '0 || count_q3>=M) begin // If not first part of the first row exp_sum_d += ( unsigned'(read_acc_data_i[0]) >> shift_sum_q); @@ -211,28 +230,39 @@ module ita_softmax //*********** Stream Softmax ***********// // Main controller checks if division is ready if (calc_stream_soft_en_i) begin - count_soft_d = count_soft_q + 1; + count_soft_d = count_soft_q1 + 1; read_acc_en_o[1] = 1; - read_acc_addr_o[1] = count_soft_q[5:0]; + read_acc_addr_o[1] = count_soft_q1[5:0]; read_max_en_o[1] = 1; - read_max_addr_o[1] = count_soft_q[5:0]; + read_max_addr_o[1] = count_soft_q1[5:0]; if (count_soft_d == M*M/N) begin count_soft_d = '0; end end if (calc_stream_soft_en_q) begin - for (int i = 0; i < M; i++) begin - shift_inp_diff[i] = read_max_data_i[1]-inp_i[i]; - shift_inp[i] = unsigned'(shift_inp_diff[i]) >> 5; - if (shift_inp_diff[i][4]) - shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> 5) + 1; - inp_stream_soft_o[i] = read_acc_data_i[1] >> shift_inp[i]; + if (disable_row) begin + inp_stream_soft_o = { M { '0 } }; + end else begin + for (int i = 0; i < M; i++) begin + disable_col[i] = ((inner_tile_q*M + i) >= ctrl_i.seq_length); + if (disable_col[i]) begin + inp_stream_soft_o[i] = '0; + end else begin + shift_inp_diff[i] = read_max_data_i[1]-inp_i[i]; + shift_inp[i] = unsigned'(shift_inp_diff[i]) >> 5; + if (shift_inp_diff[i][4]) + shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> 5) + 1; + inp_stream_soft_o[i] = read_acc_data_i[1] >> shift_inp[i]; + end + end end end end always_ff @(posedge clk_i or negedge rst_ni) begin if(~rst_ni) begin + inner_tile_q <= '0; + tile_y_q <= '0; tile_q4 <= '0; tile_q3 <= '0; tile_q2 <= '0; @@ -240,8 +270,9 @@ module ita_softmax count_q4 <= M*M/N; count_q3 <= M*M/N; count_q2 <= M*M/N; - count_q1 <= M*M/N; - count_soft_q <= '0; + count_q1 <= M*M/N; + count_soft_q1 <= '0; + count_soft_q2 <= '0; count_div_q <= '0; div_read_q <= '0; div_write_q <= '0; @@ -253,6 +284,8 @@ module ita_softmax shift_q <= '0; shift_sum_q <= '0; end else begin + inner_tile_q <= inner_tile_i; + tile_y_q <= tile_y_i; tile_q4 <= tile_q3; tile_q3 <= tile_q2; tile_q2 <= tile_q1; @@ -261,7 +294,8 @@ module ita_softmax count_q3 <= count_q2; count_q2 <= count_q1; count_q1 <= count_d; - count_soft_q <= count_soft_d; + count_soft_q1 <= count_soft_d; + count_soft_q2 <= count_soft_q1; count_div_q <= count_div_d; div_read_q <= div_read_d; div_write_q <= div_write_d; diff --git a/src/ita_softmax_top.sv b/src/ita_softmax_top.sv index e44fb4f..df2b421 100644 --- a/src/ita_softmax_top.sv +++ b/src/ita_softmax_top.sv @@ -19,7 +19,11 @@ module ita_softmax_top output counter_t soft_addr_div_o , output logic softmax_done_o , output logic pop_softmax_fifo_o , - output inp_t inp_stream_soft_o + output inp_t inp_stream_soft_o , + input counter_t tile_x_i , + input counter_t tile_y_i , + input counter_t inner_tile_i + ); logic [1:0] read_acc_en; @@ -113,7 +117,11 @@ module ita_softmax_top .write_max_en_o (write_max_en ), .write_max_addr_o (write_max_addr ), - .write_max_data_o (write_max_data ) + .write_max_data_o (write_max_data ), + + .tile_x_i (tile_x_i ), + .tile_y_i (tile_y_i ), + .inner_tile_i (inner_tile_i ) ); ita_register_file_1w_multi_port_read #( diff --git a/src/tb/ita_tb.sv b/src/tb/ita_tb.sv index e8f84a6..1b2a077 100644 --- a/src/tb/ita_tb.sv +++ b/src/tb/ita_tb.sv @@ -91,9 +91,11 @@ module ita_tb; "_", $sformatf( "%s", ACTIVATION) }; - N_TILES_SEQUENCE_DIM = SEQUENCE_LEN / M_TILE_LEN; - N_TILES_EMBEDDING_DIM = EMBEDDING_SIZE / M_TILE_LEN; - N_TILES_PROJECTION_DIM = PROJECTION_SPACE / M_TILE_LEN; + // Round up + N_TILES_SEQUENCE_DIM = (SEQUENCE_LEN + M_TILE_LEN -1 ) / M_TILE_LEN; + N_TILES_EMBEDDING_DIM = (EMBEDDING_SIZE+ M_TILE_LEN -1 ) / M_TILE_LEN; + N_TILES_PROJECTION_DIM = (PROJECTION_SPACE + M_TILE_LEN -1 ) / M_TILE_LEN; + N_TILES_FEEDFORWARD = (FEEDFORWARD_SIZE + M_TILE_LEN -1) / M_TILE_LEN; N_TILES_LINEAR_PROJECTION = N_TILES_SEQUENCE_DIM * N_TILES_EMBEDDING_DIM * N_TILES_PROJECTION_DIM; N_TILES_ATTENTION = N_TILES_SEQUENCE_DIM * N_TILES_PROJECTION_DIM; N_ENTRIES_PER_TILE = M_TILE_LEN * M_TILE_LEN / N_PE; @@ -103,7 +105,6 @@ module ita_tb; N_ENTRIES_PER_SEQUENCE_DIM = N_ENTRIES_PER_TILE * N_TILES_SEQUENCE_DIM; N_ATTENTION_TILE_ROWS = N_TILES_SEQUENCE_DIM; N_GROUPS = 2 * N_ATTENTION_TILE_ROWS; - N_TILES_FEEDFORWARD = FEEDFORWARD_SIZE / M_TILE_LEN; N_TILES_INNER_DIM_LINEAR_PROJECTION[0] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[1] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[2] = N_TILES_EMBEDDING_DIM; @@ -489,6 +490,10 @@ task automatic apply_ITA_weights(input integer phase); ita_ctrl.tile_p = N_TILES_PROJECTION_DIM; ita_ctrl.tile_s = N_TILES_SEQUENCE_DIM; ita_ctrl.tile_f = N_TILES_FEEDFORWARD; + ita_ctrl.seq_length = SEQUENCE_LEN; + ita_ctrl.proj_space = PROJECTION_SPACE; + ita_ctrl.embed_size = EMBEDDING_SIZE; + ita_ctrl.ff_size = FEEDFORWARD_SIZE; read_activation_constants(ita_ctrl.gelu_b, ita_ctrl.gelu_c, ita_ctrl.activation_requant_mult, ita_ctrl.activation_requant_shift, ita_ctrl.activation_requant_add);