diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 7b4c3ce..2c97e25 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,9 +26,17 @@ generate_testvectors: stage: test script: - !reference [.setup_test, script] - - python testGenerator.py -H 1 -S 64 -E 64 -P 64 -F 64 --activation gelu - - python testGenerator.py -H 1 -S 128 -E 192 -P 256 -F 256 --activation gelu - - python testGenerator.py -H 1 -S 192 -E 256 -P 128 -F 128 --activation relu + - python testGenerator.py -H 1 -S 64 -E 64 -P 64 -F 64 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 128 -E 192 -P 256 -F 256 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 192 -E 256 -P 128 -F 128 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 1 -E 2 -P 3 -F 3 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 1 -E 2 -P 3 -F 3 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 63 -E 62 -P 61 -F 61 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 65 -E 130 -P 195 -F 195 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 127 -E 190 -P 253 -F 253 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 511 -E 511 -P 127 -F 63 --activation relu --skip-vector-validation + - python testGenerator.py -H 1 -S 63 -E 63 -P 50 -F 129 --activation gelu --skip-vector-validation + - python testGenerator.py -H 1 -S 255 -E 63 -P 511 -F 511 --activation identity --skip-vector-validation artifacts: paths: - simvectors @@ -94,6 +102,73 @@ run_sim: - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb +run_sim_padding: + stage: sim + needs: + - generate_testvectors + parallel: + matrix: + - S: 1 + E: 2 + P: 3 + F: 3 + activation: gelu + no_stalls: 0 + single_attention: 0 + - S: 1 + E: 2 + P: 3 + F: 3 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 63 + E: 62 + P: 61 + F: 61 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 65 + E: 130 + P: 195 + F: 195 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 127 + E: 190 + P: 253 + F: 253 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 511 + E: 511 + P: 127 + F: 63 + activation: relu + no_stalls: 0 + single_attention: 0 + - S: 63 + E: 63 + P: 50 + F: 129 + activation: gelu + no_stalls: 0 + single_attention: 0 + - S: 255 + E: 63 + P: 511 + F: 511 + activation: identity + no_stalls: 0 + single_attention: 0 + script: + - make bender + - make sim VSIM_FLAGS=-c s=$S e=$E p=$P f=$F bias=1 activation=$activation no_stalls=$no_stalls single_attention=$single_attention + - ./modelsim/return_status.sh modelsim/build/transcript $S $E $P $F ita_tb + run_hwpe_sim: stage: sim needs: diff --git a/.vscode/launch.json b/.vscode/launch.json index 4e54398..42f08d8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -17,6 +17,7 @@ "-S${input:seq_len}", "-E${input:emb_len}", "-P${input:prj_len}", + "--no-bias" ], } ], diff --git a/.yapfignore b/.yapfignore deleted file mode 100644 index 99bbfb7..0000000 --- a/.yapfignore +++ /dev/null @@ -1,3 +0,0 @@ -*third_party/ -*venv/ -*simvectors/ \ No newline at end of file diff --git a/Bender.lock b/Bender.lock index f737675..7ecd6e3 100644 --- a/Bender.lock +++ b/Bender.lock @@ -15,8 +15,8 @@ packages: - common_verification - tech_cells_generic common_verification: - revision: 9c07fa860593b2caabd9b5681740c25fac04b878 - version: 0.2.3 + revision: fb1885f48ea46164a10568aeff51884389f67ae3 + version: 0.2.5 source: Git: https://github.com/pulp-platform/common_verification.git dependencies: [] diff --git a/Bender.yml b/Bender.yml index 6d2c5a6..03bdae0 100644 --- a/Bender.yml +++ b/Bender.yml @@ -32,6 +32,7 @@ sources: # Individual source files are simple string entries: - src/ita_package.sv - src/ita_accumulator.sv + - src/ita_masking.sv - src/ita_controller.sv - src/ita_dotp.sv - src/ita_fifo_controller.sv diff --git a/Makefile b/Makefile index 3359ca7..5dac0d1 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,8 @@ BENDER_INSTALL_DIR = ${INSTALL_DIR}/bender VENV_BIN=venv/bin/ BENDER_VERSION = 0.28.1 -SIM_PATH ?= modelsim/build +SIM_FOLDER ?= build +SIM_PATH ?= modelsim/${SIM_FOLDER} SYNTH_PATH = synopsys BENDER_TARGETS = -t rtl -t test @@ -34,7 +35,29 @@ else ifeq ($(activation), relu) else activation_int = 0 endif -vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) + +mask ?= none +ifeq ($(mask), upper_triangular) + mask_int = 1 +else ifeq ($(mask), lower_triangular) + mask_int = 2 +else ifeq ($(mask), strided) + mask_int = 3 +else ifeq ($(mask), upper_strided) + mask_int = 4 +else ifeq ($(mask), lower_strided) + mask_int = 5 +else ifeq ($(mask), sliding_window) + mask_int = 6 +else ifeq ($(mask), strided_sliding_window) + mask_int = 7 +else + mask_int = 0 +endif + +i ?= 1 + +vlog_defs += -DNO_STALLS=$(no_stalls) -DSINGLE_ATTENTION=$(single_attention) -DSEQ_LENGTH=$(s) -DEMBED_SIZE=$(e) -DPROJ_SPACE=$(p) -DFF_SIZE=$(f) -DBIAS=$(bias) -DACTIVATION=$(activation_int) -DMASK=$(mask_int) -DMASK_INDEX=$(i) ifeq ($(target), sim_ita_hwpe_tb) BENDER_TARGETS += -t ita_hwpe -t ita_hwpe_test diff --git a/PyITA/ITA.py b/PyITA/ITA.py index 0068723..e12c4ad 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -22,6 +22,9 @@ import numpy as np from numpy.typing import ArrayLike, DTypeLike +import seaborn as sns +import matplotlib.pyplot as plt + from .softmax import fastSoftmax, realSoftmax, streamingPartialSoftmax from .gelu import gelu_requantize, i_gelu_requantized, get_i_gelu_constants, get_i_gelu_requantized_constants from .util import (generate_matrix_mem, pack_8b_to_word, pack_array_8b_to_word, pack_hex_24b, pack_multihead_8b_to_word, @@ -43,6 +46,7 @@ def __init__(self, path: Union[str, os.PathLike], bias: bool = True, activation: str = "identity", + mask: str = "none", Q: ArrayLike = None, K: ArrayLike = None, V: ArrayLike = None, @@ -69,10 +73,10 @@ def __init__(self, self._init_paths(path) - self.S_ITA = max(64, S) - self.P_ITA = max(64, P) - self.E_ITA = max(64, E) - self.F_ITA = max(64, F) + self.S_ITA = ((S - 1) // self.ITA_M + 1) * self.ITA_M + self.P_ITA = ((P - 1) // self.ITA_M + 1) * self.ITA_M + self.E_ITA = ((E - 1) // self.ITA_M + 1) * self.ITA_M + self.F_ITA = ((F - 1) // self.ITA_M + 1) * self.ITA_M self.H_ITA = 4 self.split = self.ITA_M // self.ITA_N @@ -83,6 +87,7 @@ def __init__(self, self.H = H self.bias = bias self.activation = activation + self.mask = mask # Setup transformation functions self.split_m_m = partial(split_matrix, block_shape = (self.ITA_M, self.ITA_M)) @@ -110,10 +115,10 @@ def _validate_matrix_constraints(self, K: ArrayLike, V: ArrayLike): assert (np.all(K == V)) # WIESEP: Current restrictions for ITA - assert (self.S % self.ITA_M == 0), "Sequence length must be divisible by ITA_M" - assert (self.P % self.ITA_M == 0), "Projection space must be divisible by ITA_M" - assert (self.E % self.ITA_M == 0), "Embedding size must be divisible by ITA_M" - assert (self.F % self.ITA_M == 0), "Feedforward size must be divisible by ITA_M" + # assert (self.S % self.ITA_M == 0), "Sequence length must be divisible by ITA_M" + # assert (self.P % self.ITA_M == 0), "Projection space must be divisible by ITA_M" + # assert (self.E % self.ITA_M == 0), "Embedding size must be divisible by ITA_M" + # assert (self.F % self.ITA_M == 0), "Feedforward size must be divisible by ITA_M" assert ( self.E <= 512 @@ -172,7 +177,9 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bq_in = np.zeros((self.H, self.P), dtype = np.int8) self.Bq = np.pad(self.Bq_in, ((0, 0), (0, self.P_ITA - self.P))) - self.Bq_broadcast = np.reshape(np.repeat(self.Bq, self.S, axis = 0), (self.H, self.S, self.P)) + self.Bq_broadcast = np.reshape(np.repeat(self.Bq, self.S, axis = 0), (self.H, self.S, self.P_ITA)) + self.Bq_broadcast = np.pad(self.Bq_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) + if self.bias: self.Bk_in = random_shuffled_tensor( @@ -180,7 +187,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bk_in = np.zeros((self.H, self.P), dtype = np.int8) self.Bk = np.pad(self.Bk_in, ((0, 0), (0, self.P_ITA - self.P))) - self.Bk_broadcast = np.reshape(np.repeat(self.Bk, self.S, axis = 0), (self.H, self.S, self.P)) + self.Bk_broadcast = np.reshape(np.repeat(self.Bk, self.S, axis = 0), (self.H, self.S, self.P_ITA)) + self.Bk_broadcast = np.pad(self.Bk_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bv_in = random_shuffled_tensor( @@ -188,7 +196,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bv_in = np.zeros((self.H, self.P), dtype = np.int8) self.Bv = np.pad(self.Bv_in, ((0, 0), (0, self.P_ITA - self.P))) - self.Bv_broadcast = np.reshape(np.repeat(self.Bv, self.S, axis = 0), (self.H, self.S, self.P)) + self.Bv_broadcast = np.reshape(np.repeat(self.Bv, self.S, axis = 0), (self.H, self.S, self.P_ITA)) + self.Bv_broadcast = np.pad(self.Bv_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bo_in = random_shuffled_tensor( @@ -196,7 +205,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bo_in = np.zeros((self.H, self.E), dtype = np.int8) self.Bo = np.pad(self.Bo_in, ((0, 0), (0, self.E_ITA - self.E))) - self.Bo_broadcast = np.reshape(np.repeat(self.Bo, self.S, axis = 0), (self.H, self.S, self.E)) + self.Bo_broadcast = np.reshape(np.repeat(self.Bo, self.S, axis = 0), (self.H, self.S, self.E_ITA)) + self.Bo_broadcast = np.pad(self.Bo_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bff_in = random_shuffled_tensor( @@ -204,14 +214,16 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, else: self.Bff_in = np.zeros((1, self.F), dtype = np.int8) self.Bff = np.pad(self.Bff_in, ((0, 0), (0, self.F_ITA - self.F))) - self.Bff_broadcast = np.reshape(np.repeat(self.Bff, self.S, axis = 0), (1, self.S, self.F)) + self.Bff_broadcast = np.reshape(np.repeat(self.Bff, self.S, axis = 0), (1, self.S, self.F_ITA)) + self.Bff_broadcast = np.pad(self.Bff_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) if self.bias: self.Bff2_in = random_shuffled_tensor( (1, self.E), int(np.log2(self.E)) + 8, type = np.int32) if Bff2 is None else Bff2 else: self.Bff2_in = np.zeros((1, self.E), dtype = np.int8) self.Bff2 = np.pad(self.Bff2_in, ((0, 0), (0, self.E_ITA - self.E))) - self.Bff2_broadcast = np.reshape(np.repeat(self.Bff2, self.S, axis = 0), (1, self.S, self.E)) + self.Bff2_broadcast = np.reshape(np.repeat(self.Bff2, self.S, axis = 0), (1, self.S, self.E_ITA)) + self.Bff2_broadcast = np.pad(self.Bff2_broadcast, ((0, 0), (0, self.S_ITA - self.S), (0, 0))) #### Intermediate tensors #### @@ -231,6 +243,8 @@ def _initialize_tensors(self, Q, V, Wq, Wk, Wv, Wo, Bq, Bk, Bv, Bo, FF_in, Wff, self.A_real_softmax = np.zeros([self.H, self.S, self.S], dtype = np.int8) self.A_partial_softmax = np.zeros([self.H, self.S, self.S], dtype = np.int8) + self.Mask = None + self.O_soft = None self.O_soft_requant = None @@ -348,6 +362,9 @@ def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: # Weight Wqk is H x E x P # Transpose Wqk to H x P x E + # print(f"qk: {qk.shape}") + # print(f"qk: {weight.shape}") + weight = np.transpose(weight, (0, 2, 1)) tile_x = qk.shape[0] // self.ITA_M # S // ITA_M @@ -362,6 +379,19 @@ def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: Input = np.tile(Input, [1, 1, self.split, 1]) # Repeat each tile number of output row tiles times Input = np.tile(Input, [1, tile_y, 1, 1]).reshape((-1, self.ITA_M)) + # fig, ax = plt.subplots(1, 2) # Create a figure with two subplots + # im0 = ax[0].imshow(Input, cmap='viridis') + # im1 = ax[1].imshow(np.squeeze(weight, axis=0)) + + # # Add colorbars for each image if needed + # fig.colorbar(im0, ax=ax[0]) + # fig.colorbar(im1, ax=ax[1]) + + # # Set titles for each subplot + # ax[0].set_title("Inputs") + # ax[1].set_title("Weights") + + plt.show() write_matrix(Input, input_file, self.paths["standalone"]) # Transposed Weight Wqk is H x P x E @@ -373,7 +403,7 @@ def tiler_QK(self, qk: np.ndarray, weight: np.ndarray, bias: np.ndarray, output: # Bias Bqk is H x P # Broadcast Bias Bqk to H x S x P - bias = np.tile(bias, [1, self.S, 1]) + bias = np.tile(bias, [1, self.S_ITA, 1]) for h in range(self.H): Bias = split_matrix(bias[h], (self.ITA_M, self.ITA_N)) write_matrix(Bias, f"{bias_file}_{h}", self.paths["standalone"]) @@ -416,7 +446,7 @@ def tiler_V(self, v, weight, bias, output, input_file, weight_file, bias_file, o # Bias Bv is H x P # Broadcast Bias Bv to H x S x P - bias = np.tile(bias, [1, self.S, 1]) + bias = np.tile(bias, [1, self.S_ITA, 1]) # Transpose Bias Bv to H x P x S bias = np.transpose(bias, (0, 2, 1)) for h in range(self.H): @@ -497,7 +527,7 @@ def tiler_Out(self, O, weight, bias, output, input_file, weight_file, bias_file, # Bias Bo is H x E # Broadcast Bias Bo to H x S x E - bias = np.tile(bias, [1, self.S, 1]) + bias = np.tile(bias, [1, self.S_ITA, 1]) for h in range(self.H): Bias = split_matrix(bias[h], (self.ITA_M, self.ITA_N)) write_matrix(Bias, f"{bias_file}_{h}", self.paths["standalone"]) @@ -512,6 +542,12 @@ def step1_Qp(self): self.Qp = np.clip(self.Qp, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.Qp_requant = requantize(self.Qp, self.requant_eps_mult[0], self.requant_right_shift[0], self.requant_add[0]) + + # Set padded values to zero + if (self.S_ITA - self.S) > 0: + self.Qp_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.Qp_requant[:, :, -(self.P_ITA - self.P):] = 0 self.tiler_QK(self.Q, self.Wq, self.Bq, self.Qp_requant, "Q", "Wq", "Bq", "Qp") @@ -521,6 +557,11 @@ def step2_Kp(self): self.Kp_requant = requantize(self.Kp, self.requant_eps_mult[1], self.requant_right_shift[1], self.requant_add[1]) + if (self.S_ITA - self.S) > 0: + self.Kp_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.Kp_requant[:, :, -(self.P_ITA - self.P):] = 0 + self.tiler_QK(self.K, self.Wk, self.Bk, self.Kp_requant, "K", "Wk", "Bk", "Kp") def step3_Vp(self): @@ -529,24 +570,131 @@ def step3_Vp(self): self.Vp_requant = requantize(self.Vp, self.requant_eps_mult[2], self.requant_right_shift[2], self.requant_add[2]) + if (self.S_ITA - self.S) > 0: + self.Vp_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.Vp_requant[:, :, -(self.P_ITA - self.P):] = 0 + # Compute Vp in transposed form self.tiler_V(self.V, self.Wv, self.Bv, self.Vp_requant, "V", "Wv", "Bv", "Vp") - def step4_QK(self, no_partial_softmax): + def apply_mask(self, index): + # True means this positon gets masked + if (self.mask == 'upper_triangular'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=False, dtype='bool') + if (0 < index and index < self.S): + for h in range(self.Mask.shape[0]): + for i in range(self.Mask.shape[1]): + for j in range((i + index), self.Mask.shape[2]): + self.Mask[h][i][j] = True + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'lower_triangular'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=False, dtype='bool') + if (0 < index and index < self.S): + for h in range(self.Mask.shape[0]): + for i in range(index, self.Mask.shape[1]): + for j in range((i-(index-1))): + self.Mask[h][i][j] = True + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'strided'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool') + if (0 < index and index < self.S): + if (index > 0 and (index & (index - 1)) == 0): + for h in range(self.Mask.shape[0]): + for i in range(self.Mask.shape[1]): + self.Mask[h][i][i] = False + for j in range(i, self.Mask.shape[2], index): + self.Mask[h][i][j] = False + self.Mask[h][j][i] = False + else: + raise ValueError(f"Index has to be a power of two for {self.mask} mask") + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'upper_strided'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool') + if (0 < index and index < self.S): + if (index > 0 and (index & (index - 1)) == 0): + for h in range(self.Mask.shape[0]): + for i in range(self.Mask.shape[1]): + for j in range(i, self.Mask.shape[2], index): + self.Mask[h][i][j] = False + else: + raise ValueError(f"Index has to be a power of two for {self.mask} mask") + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'lower_strided'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool') + if (0 < index and index < self.S): + if (index > 0 and (index & (index - 1)) == 0): + for h in range(self.Mask.shape[0]): + for i in range(self.Mask.shape[1]): + for j in range(i, self.Mask.shape[2], index): + self.Mask[h][j][i] = False + else: + raise ValueError(f"Index has to be a power of two for {self.mask} mask") + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'sliding_window'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool') + if (0 < index and index < self.S): + for h in range(self.Mask.shape[0]): + for i in range(self.Mask.shape[1]): + for j in range(i, min((index + i), self.Mask.shape[2])): + self.Mask[h][i][j] = False + self.Mask[h][j][i] = False + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'strided_sliding_window'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool') + if (0 < index and index < self.S): + if (index > 0 and (index & (index - 1)) == 0): + for h in range(self.Mask.shape[0]): + for i in range(self.Mask.shape[1]): + for j in range(i, self.Mask.shape[2]): + if (j < (index + i) or ((j-i) % index == 0)): + self.Mask[h][i][j] = False + self.Mask[h][j][i] = False + else: + raise ValueError(f"Index has to be a power of two for {self.mask} mask") + else: + raise ValueError(f"Index is out of bounds for {self.mask} mask") + elif (self.mask == 'none'): + self.Mask = np.full((self.H, self.S, self.S), fill_value=False, dtype='bool') + else: + raise ValueError("Mask not supported") + + + def step4_QK(self, no_partial_softmax, index): self.A = np.array( [np.matmul(self.Qp_requant[i], np.transpose(self.Kp_requant[i]), dtype = np.int32) for i in range(self.H)]) self.A = np.clip(self.A, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.A_requant = requantize(self.A, self.requant_eps_mult[3], self.requant_right_shift[3], self.requant_add[3]) + + self.apply_mask(index) + + if (self.S_ITA - self.S) > 0: + self.A_requant[:, -(self.S_ITA - self.S):, :] = 0 + self.A_requant[:, :, -(self.S_ITA - self.S):] = 0 + self.soft(no_partial_softmax) self.tiler_AV(self.Qp_requant, self.Kp_requant, self.A_requant, "Qp_in", "Kp_in", "A") def soft(self, no_partial_softmax = False): - self.A_real_softmax = realSoftmax(self.A_requant) + self.A_real_softmax = realSoftmax(self.A_requant[:, :self.S, :self.S]) + self.A_real_softmax = np.pad(self.A_real_softmax, ((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S))) + if no_partial_softmax: - self.A_partial_softmax = fastSoftmax(self.A_requant) + self.A_partial_softmax = fastSoftmax(self.A_requant[:, :self.S, :self.S]) + self.A_partial_softmax = np.pad(self.A_partial_softmax, + ((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S))) else: - self.A_partial_softmax = streamingPartialSoftmax(self.A_requant) + self.A_partial_softmax = streamingPartialSoftmax(self.A_requant[:, :self.S, :self.S], self.Mask) + self.A_partial_softmax[self.Mask] = 0 + self.A_partial_softmax = np.pad(self.A_partial_softmax, + ((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S))) if self.H == 1: A_save = [np.tile(self.A_partial_softmax[i], [self.split, 1]) for i in range(self.H)] @@ -555,17 +703,25 @@ def soft(self, no_partial_softmax = False): A_save = self.A_partial_softmax[h] write_matrix(A_save, f"A_soft_{h}", self.paths["standalone"]) - def step5_AV(self): + def step5_AV(self): self.O_soft = np.array([ np.matmul(self.A_partial_softmax[i].astype(np.uint8), self.Vp_requant[i], dtype = np.int32) for i in range(self.H) ]) + self.O_soft = np.clip(self.O_soft, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.O_soft_requant = requantize(self.O_soft, self.requant_eps_mult[4], self.requant_right_shift[4], self.requant_add[4]) + + if (self.S_ITA - self.S) > 0: + self.O_soft_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.P_ITA - self.P) > 0: + self.O_soft_requant[:, :, -(self.P_ITA - self.P):] = 0 self.tiler_AV(self.A_requant, np.transpose(self.Vp_requant, (0, 2, 1)), self.O_soft_requant, "A_stream_soft_in", "Vp_in", "O_soft") + + def apply_activation(self, preactivation, activation): if activation not in ["gelu", "relu", "identity"]: @@ -590,6 +746,12 @@ def step6_O(self): self.Out_soft = np.clip(self.Out_soft, -2**(self.WO - 1), 2**(self.WO - 1) - 1) self.Out_soft_requant = requantize(self.Out_soft, self.requant_eps_mult[5], self.requant_right_shift[5], self.requant_add[5]) + + if (self.S_ITA - self.S) > 0: + self.Out_soft_requant[:, -(self.S_ITA - self.S):, :] = 0 + if (self.E_ITA - self.E) > 0: + self.Out_soft_requant[:, :, -(self.E_ITA - self.E):] = 0 + self.tiler_Out(self.O_soft_requant, self.Wo, self.Bo, self.Out_soft_requant, "O_soft_in", "Wo", "Bo", "Out_soft") @@ -599,7 +761,7 @@ def feedforward_layer(self): self.FFp_requant = requantize(self.FFp, self.requant_eps_mult_ffn[0], self.requant_right_shift_ffn[0], self.requant_add_ffn[0]) self.FFp_requant = self.apply_activation(self.FFp_requant, self.activation) - + self.tiler_QK(self.FF, self.Wff, self.Bff, self.FFp_requant, "FF", "Wff", "Bff", "FFp") self.FF2p = np.matmul(self.FFp_requant, self.Wff2, dtype = np.int32) + self.Bff2_broadcast @@ -934,8 +1096,8 @@ def export_mempool(self, path): def export_numpy(self): assert np.all(np.equal(self.K, self.V)), "For ITA, keys and values have to be equal" - q = self.Q - k = self.K + q = self.Q_in + k = self.K_in w1 = self.Wq_in b1 = self.Bq_in w2 = self.Wk_in @@ -971,11 +1133,13 @@ def generateTestVectors(path, **kwargs): f = kwargs['F'] h = kwargs['H'] activation = kwargs['activation'] + mask = kwargs['mask'] + index = kwargs['I'] bias = int(not kwargs['no_bias']) export_snitch_cluster = kwargs['export_snitch_cluster'] export_mempool = kwargs['export_mempool'] - acc1 = Transformer(s, p, e, f, h, bias = bias, path = path, activation = activation) + acc1 = Transformer(s, p, e, f, h, bias = bias, path = path, activation = activation, mask = mask) if kwargs['verbose']: print("=> Generating test vectors...") @@ -983,7 +1147,7 @@ def generateTestVectors(path, **kwargs): acc1.step1_Qp() acc1.step2_Kp() acc1.step3_Vp() - acc1.step4_QK(kwargs['no_partial_softmax']) + acc1.step4_QK(kwargs['no_partial_softmax'], index=index) acc1.step5_AV() acc1.step6_O() acc1.step7_Osum() @@ -1200,7 +1364,7 @@ def plot_heatmap(tensor, title, ax): def util_main(**kwargs): B = 8 log2e = np.log2(np.exp(1)) - range_scale = 32 + range_scale = 1 eps_max = range_scale * B / (2**B) N = 1024 diff --git a/PyITA/ITA_onnx.py b/PyITA/ITA_onnx.py index eda85f3..235cf00 100644 --- a/PyITA/ITA_onnx.py +++ b/PyITA/ITA_onnx.py @@ -259,8 +259,8 @@ def exportONNX(path, verbose = False, **kwargs): # Transform from MUL-DIV-ADD to MUL-ADD-DIV RQ_ADD = (RQ_ADD * 2**RQ_SHIFT.astype(np.float32)) - input0_values = np.expand_dims(inputs['q'][:(S * E // 64), :].reshape(S, E), axis = 0) - input1_values = np.expand_dims(inputs['k'][:(S * E // 64), :].reshape(S, E), axis = 0) + input0_values = np.expand_dims(inputs['q'].reshape(S, E), axis = 0) + input1_values = np.expand_dims(inputs['k'].reshape(S, E), axis = 0) np.savez(path + "inputs.npz", input0_values, input1_values) diff --git a/PyITA/softmax.py b/PyITA/softmax.py index 7545086..feea3ca 100644 --- a/PyITA/softmax.py +++ b/PyITA/softmax.py @@ -14,6 +14,8 @@ # # ---------------------------------------------------------------------- +import argparse + import numpy as np @@ -30,7 +32,7 @@ def fastSoftmax(x, integerize = True): B = 8 # Scaling factor - range_scale = 32 + range_scale = 1 eps_max = range_scale * B / (2**B) # Find the maximum for each row in the current column block (consisting of 16 columns) @@ -66,22 +68,19 @@ def fastSoftmax(x, integerize = True): return np.repeat(exp_sum_inverse, seq_length).reshape(n_heads, seq_length, seq_length) / 2**shift -def streamingPartialSoftmax(x, integerize = True): +def streamingPartialSoftmax(x, mask, integerize = True): if not integerize: x = x.astype(np.float32) seq_length = x.shape[-1] n_heads = x.shape[-3] - width = 16 # 16 PE (processing units) - groups = seq_length // width - - assert seq_length % width == 0, f"Sequence length must be a multiple of width ({width})" + PE = 16 # 16 PE (processing units) # Number of bits B = 8 # Scaling factor - range_scale = 32 + range_scale = 1 eps_max = range_scale * B / (2**B) if integerize: @@ -103,12 +102,17 @@ def streamingPartialSoftmax(x, integerize = True): global_max = np.full((n_heads, seq_length), -np.Infinity, dtype = np.float32) ## STAGE 1: Compute the denominator of the softmax - for i in range(groups): + for i in range((seq_length + PE - 1) // PE): + width = seq_length % PE if i * PE + PE > seq_length else PE + + mask_slice = mask[... ,i*PE:(i*PE)+width] + x_slice = x[..., 0 + i * PE:width + i * PE] + # Find the maximum for each row in the current column block (consisting of 16 columns) if integerize: - current_max = np.max(x[..., 0 + i * width:width + i * width].astype(np.int32), axis = -1) + current_max = np.max(np.where(mask_slice, -128, x_slice.astype(np.int32)), axis = -1) else: - current_max = np.max(x[..., 0 + i * width:width + i * width].astype(np.float32), axis = -1) + current_max = np.max(np.where(mask_slice, -np.inf, x_slice.astype(np.float32)), axis = -1) # Initialize all shift values for each row to zero if integerize: @@ -131,11 +135,11 @@ def streamingPartialSoftmax(x, integerize = True): # Find the difference between the maximum and x in the current part of the row if integerize: - diff = np.repeat(global_max, width).reshape( - n_heads, seq_length, width) - x[..., 0 + i * width:width + i * width].astype(np.int32) + diff = np.repeat(global_max, width).reshape(n_heads, seq_length, + width) - x[..., 0 + i * PE:width + i * PE].astype(np.int32) else: - diff = np.repeat(global_max, width).reshape( - n_heads, seq_length, width) - x[..., 0 + i * width:width + i * width].astype(np.float32) + diff = np.repeat(global_max, width).reshape(n_heads, seq_length, + width) - x[..., 0 + i * PE:width + i * PE].astype(np.float32) # Shift the values by B-log2B -> multiply by B/2**B = log2e*eps_x # Make sure to do use round-half-up instead of round-half-to-even @@ -144,19 +148,23 @@ def streamingPartialSoftmax(x, integerize = True): else: shift = diff * eps_max + # Set shift value so high that 2**8 >> shift gets zero for all masked values + shift[mask_slice] = 32 + # Calculate exponential sum over the current part of the row and scale it by 2**10 to prevent underflow if integerize: # exp_sum = np.sum(2**8 >> shift, -1) # or exp_sum = np.floor(np.sum(2**8 >> shift, axis = -1)) else: exp_sum = np.sum(1 / 2**shift, axis = -1) - + # Update the accumulated sum and add the accumulation over the current part of the row if integerize: exp_partial_sum = np.floor((exp_partial_sum.astype(np.int32) >> shift_sum)) + exp_sum else: exp_partial_sum = (exp_partial_sum / 2**(shift_sum.astype(np.float32))) + exp_sum + ## STAGE 2: Calculate the softmax activation # Invert the partial sum if integerize: @@ -164,9 +172,14 @@ def streamingPartialSoftmax(x, integerize = True): else: exp_partial_sum_inverse = 1 / exp_partial_sum + # Find the difference between the maximum and x diff = np.repeat(global_max, seq_length).reshape(n_heads, seq_length, seq_length) - x.astype(np.int32) + # The global_max can be smaller than a few positions in x because not all values in x were considered for the global_max due to the mask. + # So diff should normally not be smaller than 0 + diff[mask] = 0 + # Shift the values by B-log2B -> multiply by B/2**B = log2e*eps_x # Make sure to do use round-half-up instead of round-half-to-even if integerize: @@ -198,7 +211,66 @@ def realSoftmax(A_requant, integerize = True): x = A_requant.astype(np.float64) exp = np.exp(x - np.max(x, axis = 2).reshape(n_heads, -1, 1)) + + # Replace nan with zero + exp = np.nan_to_num(exp) + if integerize: return (exp / exp.sum(axis = 2).reshape(n_heads, -1, 1) * (2**7 - 1)).astype(A_requant.dtype) else: return exp / exp.sum(axis = 2).reshape(n_heads, -1, 1) + + +if __name__ == "__main__": + np.set_printoptions(linewidth = 120) + np.set_printoptions(precision = 4) + + # Always print whole array + np.set_printoptions(threshold = np.inf) + + parser = argparse.ArgumentParser(description = "Test Utility for Softmax.") + # Sequence length + parser.add_argument("-S", default = 64, type = int, help = "Sequence length") + + # ITA sequence length + parser.add_argument("-M", default = 64, type = int, help = "ITA sequence length") + + # Quantiztion (float or int) + parser.add_argument("--int", action = "store_true", help = "Quantize to int") + parser.add_argument('--seed', default = 0, type = int, help = 'Random seed') + + args = parser.parse_args() + + ITA_WI = 8 + WO = 26 + ITA_N = 16 + ITA_M = args.M + + if args.seed != -1: + np.random.seed(args.seed) + + if args.int: + x = np.random.randint(-128, 128, (1, 1, args.S, args.S)).astype(np.int8) + else: + x = np.random.randn(1, 1, 16, 16).astype(np.float32) + + print("Input:") + print(x) + + # Pad last two dimensions to be a multiple of ITA_M + pad_x = (ITA_M - x.shape[-1] % ITA_M) % ITA_M + pad_y = (ITA_M - x.shape[-2] % ITA_M) % ITA_M + pad_value = -2**(ITA_WI - 1) if args.int else -np.inf + + print(f"Padding x by ({pad_y}, {pad_x}) with {pad_value}") + x_pad = np.pad(x, ((0, 0), (0, 0), (0, pad_y), (0, pad_x)), mode = 'constant', constant_values = pad_value) + + res = realSoftmax(x, integerize = args.int) + res_pad = realSoftmax(x_pad, integerize = args.int) + + res_unpad = res_pad[:, :, :args.S, :args.S] + + # Compare results + print(f"Equal: {np.allclose(res, res_unpad, atol = 1e-3)}") + print(res) + print(res_unpad) diff --git a/PyITA/util.py b/PyITA/util.py index 8690ae2..9f42f9b 100644 --- a/PyITA/util.py +++ b/PyITA/util.py @@ -52,6 +52,15 @@ def write_matrix(matrix: np.ndarray, name: str, path: Union[str, os.PathLike]): name (str): The name of the file. path (Union[str, os.PathLike]): The path to the directory where the file will be saved. """ + # output_files = ["Qp_0", "Kp_0", "Vp_0", "A_0", "Out_soft_0", "FFp_0", "FF2p_0"] + # if name in output_files: + # import matplotlib.pyplot as plt + # heatmap = np.squeeze(matrix) + # plt.imshow(heatmap, cmap='viridis') + # plt.colorbar() + # plt.title(f"{name}") + # plt.show() + with open('%s%s.txt' % (path, name), "wb+") as f: for row in matrix: np.savetxt(f, row, fmt = '%d') diff --git a/README.md b/README.md index 3a29e23..b58d20e 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ To install the required Python packages, create a virtual environment. Make sure $> python -m venv venv $> source venv/bin/activate $> pip install -r requirements.txt +$> pip install -r requirements.dev.txt # Only required for PyITA/test_gelu.py ``` If you want to enable pre-commit hooks, which perform code formatting and linting, run the following command: diff --git a/modelsim/Makefile b/modelsim/Makefile index 7d181aa..a9a5183 100644 --- a/modelsim/Makefile +++ b/modelsim/Makefile @@ -6,13 +6,13 @@ all: lib build QUESTA_SEPP ?= questa-2023.4 -buildpath ?= build +buildpath ?= $(SIM_FOLDER) VOPT ?= $(QUESTA_SEPP) vopt VSIM ?= $(QUESTA_SEPP) vsim VLIB ?= $(QUESTA_SEPP) vlib VMAP ?= $(QUESTA_SEPP) vmap -VSIM_FLAGS ?= -gui -DEBUG ?= ON +VSIM_FLAGS ?= -c # -gui +DEBUG ?= OFF # ON lib: cd $(buildpath) && $(VLIB) work && $(VMAP) work work @@ -21,7 +21,7 @@ build: cd $(buildpath) && $(VSIM) -c -do 'source compile.tcl; quit' sim_ita_tb: lib build - cd $(buildpath) && $(VSIM) $(VSIM_FLAGS) -do 'source ../sim_ita_tb.tcl' + cd $(buildpath) && $(VSIM) $(VSIM_FLAGS) -do 'set DEBUG $(DEBUG); source ../sim_ita_tb.tcl' sim_ita_hwpe_tb: lib build cd $(buildpath) && $(VSIM) $(VSIM_FLAGS) -do 'set DEBUG $(DEBUG); source ../sim_ita_hwpe_tb.tcl' diff --git a/modelsim/sim_ita_tb.tcl b/modelsim/sim_ita_tb.tcl index 2eb5b81..5fff852 100644 --- a/modelsim/sim_ita_tb.tcl +++ b/modelsim/sim_ita_tb.tcl @@ -2,18 +2,12 @@ # Licensed under the Apache License, Version 2.0, see LICENSE for details. # SPDX-License-Identifier: Apache-2.0 -set DEBUG ON - # Set working library. set LIB work -if {$DEBUG == "ON"} { - set VOPT_ARG "+acc" - echo $VOPT_ARG - set DB_SW "-debugdb" -} else { - set DB_SW "" -} +set VOPT_ARG "+acc" +echo $VOPT_ARG +set DB_SW "-debugdb" quit -sim diff --git a/modelsim/sim_ita_tb_wave.tcl b/modelsim/sim_ita_tb_wave.tcl index 4f29360..8f035a7 100644 --- a/modelsim/sim_ita_tb_wave.tcl +++ b/modelsim/sim_ita_tb_wave.tcl @@ -11,7 +11,11 @@ add wave -noupdate /ita_tb/dut/i_inp2_mux/rst_ni add wave -noupdate /ita_tb/dut/i_inp2_mux/weight_i add wave -noupdate /ita_tb/dut/i_inp2_mux/inp2_o add wave -noupdate /ita_tb/dut/i_controller/ctrl_i -add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/* -add wave -expand -group Controller /ita_tb/dut/i_controller/* +add wave -noupdate /ita_tb/dut/oup_o +add wave -noupdate /ita_tb/dut/inp1_q +add wave -noupdate /ita_tb/dut/inp2_q + +add wave -group {Controller} /ita_tb/dut/i_controller/* add wave -group {Softmax Controller} ita_tb/dut/i_softmax_top/i_softmax/* add wave -group {Accumulator} ita_tb/dut/i_accumulator/* +add wave -group {Masking} ita_tb/dut/i_controller/i_masking/* diff --git a/modelsim/sim_ita_tb_wave_important.tcl b/modelsim/sim_ita_tb_wave_important.tcl new file mode 100644 index 0000000..0cff06e --- /dev/null +++ b/modelsim/sim_ita_tb_wave_important.tcl @@ -0,0 +1,309 @@ +onerror {resume} +quietly WaveActivateNextPane {} 0 +add wave -noupdate /ita_tb/dut/i_inp1_mux/clk_i +add wave -noupdate /ita_tb/dut/i_inp1_mux/rst_ni +add wave -noupdate /ita_tb/dut/i_inp1_mux/inp_i +add wave -noupdate /ita_tb/dut/i_inp1_mux/inp1_o +add wave -noupdate /ita_tb/dut/i_inp2_mux/clk_i +add wave -noupdate /ita_tb/dut/i_inp2_mux/rst_ni +add wave -noupdate /ita_tb/dut/i_inp2_mux/weight_i +add wave -noupdate /ita_tb/dut/i_inp2_mux/inp2_o +add wave -noupdate /ita_tb/dut/i_controller/ctrl_i +add wave -noupdate /ita_tb/dut/oup_o +add wave -noupdate /ita_tb/dut/inp1_q +add wave -noupdate /ita_tb/dut/inp2_q +add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/count_d +add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/bias_count +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/max_o +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_d +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_q +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/disable_row +add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/first_outer_dim +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_inp2_mux/clk_i +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/last_inner_tile_q6 +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_controller/calc_en_o +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q1 +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q2 +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q3 +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q4 +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q5 +add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q6 +add wave -noupdate -expand -group {Masking Signals} -expand -group {In Softmax Module} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_d +add wave -noupdate -expand -group {Masking Signals} -expand -group {In Softmax Module} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q1 +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/clk_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/rst_ni +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/ctrl_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/step_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/calc_en_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/last_inner_tile_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/count_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/tile_x_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/tile_y_i +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_o +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_col_offset_d +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_col_offset_q +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_tile_x_pos_d +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_tile_x_pos_q +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_tile_y_pos_d +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_tile_y_pos_q +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_pos_d +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_pos_q +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_d +add wave -noupdate -expand -group {Masking Signals} -expand -group {Masking Module} /ita_tb/dut/i_controller/i_masking/mask_q +add wave -noupdate -group Requant /ita_tb/dut/i_controller/requant_add_i +add wave -noupdate -group Requant /ita_tb/dut/i_controller/requant_add_o +add wave -noupdate -group Bias /ita_tb/dut/inp_bias +add wave -noupdate -group Bias /ita_tb/dut/inp_bias_padded +add wave -noupdate -group Bias /ita_tb/dut/inp_bias_q1 +add wave -noupdate -group Bias /ita_tb/dut/inp_bias_q2 +add wave -noupdate /ita_tb/dut/calc_en_q4 +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/inner_tile_i +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/inner_tile_q +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 +add wave -noupdate -radix binary -childformat {{{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[63]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[62]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[61]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[60]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[59]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[58]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[57]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[56]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[55]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[54]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[53]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[52]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[51]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[50]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[49]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[48]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[47]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[46]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[45]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[44]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[43]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[42]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[41]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[40]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[39]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[38]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[37]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[36]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[35]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[34]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[33]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[32]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[31]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[30]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[29]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[28]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[27]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[26]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[25]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[24]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[23]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[22]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[21]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[20]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[19]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[18]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[17]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[16]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[15]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[14]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[13]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[12]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[11]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[10]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[9]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[8]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[7]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[6]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[5]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[4]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[3]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[2]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[1]} -radix binary} {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[0]} -radix binary}} -subitemconfig {{/ita_tb/dut/i_softmax_top/i_softmax/disable_col[63]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[62]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[61]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[60]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[59]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[58]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[57]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[56]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[55]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[54]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[53]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[52]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[51]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[50]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[49]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[48]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[47]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[46]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[45]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[44]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[43]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[42]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[41]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[40]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[39]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[38]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[37]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[36]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[35]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[34]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[33]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[32]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[31]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[30]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[29]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[28]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[27]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[26]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[25]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[24]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[23]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[22]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[21]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[20]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[19]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[18]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[17]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[16]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[15]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[14]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[13]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[12]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[11]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[10]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[9]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[8]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[7]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[6]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[5]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[4]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[3]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[2]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[1]} {-height 16 -radix binary} {/ita_tb/dut/i_softmax_top/i_softmax/disable_col[0]} {-height 16 -radix binary}} /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate /ita_tb/dut/i_inp2_mux/clk_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/step_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/mask_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/max_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/max_o +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_i +add wave -noupdate -radix hexadecimal /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q +add wave -noupdate /ita_tb/dut/i_requantizer/clk_i +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_mask_q +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_x_d +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_x_q +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_y_d +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_y_q +add wave -noupdate -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate /ita_tb/dut/i_activation/data_q3 +add wave -noupdate -radix decimal /ita_tb/dut/inp_i +add wave -noupdate -group {All in one Phase} -radix decimal -childformat {{{/ita_tb/dut/inp[63]} -radix decimal} {{/ita_tb/dut/inp[62]} -radix decimal} {{/ita_tb/dut/inp[61]} -radix decimal} {{/ita_tb/dut/inp[60]} -radix decimal} {{/ita_tb/dut/inp[59]} -radix decimal} {{/ita_tb/dut/inp[58]} -radix decimal} {{/ita_tb/dut/inp[57]} -radix decimal} {{/ita_tb/dut/inp[56]} -radix decimal} {{/ita_tb/dut/inp[55]} -radix decimal} {{/ita_tb/dut/inp[54]} -radix decimal} {{/ita_tb/dut/inp[53]} -radix decimal} {{/ita_tb/dut/inp[52]} -radix decimal} {{/ita_tb/dut/inp[51]} -radix decimal} {{/ita_tb/dut/inp[50]} -radix decimal} {{/ita_tb/dut/inp[49]} -radix decimal} {{/ita_tb/dut/inp[48]} -radix decimal} {{/ita_tb/dut/inp[47]} -radix decimal} {{/ita_tb/dut/inp[46]} -radix decimal} {{/ita_tb/dut/inp[45]} -radix decimal} {{/ita_tb/dut/inp[44]} -radix decimal} {{/ita_tb/dut/inp[43]} -radix decimal} {{/ita_tb/dut/inp[42]} -radix decimal} {{/ita_tb/dut/inp[41]} -radix decimal} {{/ita_tb/dut/inp[40]} -radix decimal} {{/ita_tb/dut/inp[39]} -radix decimal} {{/ita_tb/dut/inp[38]} -radix decimal} {{/ita_tb/dut/inp[37]} -radix decimal} {{/ita_tb/dut/inp[36]} -radix decimal} {{/ita_tb/dut/inp[35]} -radix decimal} {{/ita_tb/dut/inp[34]} -radix decimal} {{/ita_tb/dut/inp[33]} -radix decimal} {{/ita_tb/dut/inp[32]} -radix decimal} {{/ita_tb/dut/inp[31]} -radix decimal} {{/ita_tb/dut/inp[30]} -radix decimal} {{/ita_tb/dut/inp[29]} -radix decimal} {{/ita_tb/dut/inp[28]} -radix decimal} {{/ita_tb/dut/inp[27]} -radix decimal} {{/ita_tb/dut/inp[26]} -radix decimal} {{/ita_tb/dut/inp[25]} -radix decimal} {{/ita_tb/dut/inp[24]} -radix decimal} {{/ita_tb/dut/inp[23]} -radix decimal} {{/ita_tb/dut/inp[22]} -radix decimal} {{/ita_tb/dut/inp[21]} -radix decimal} {{/ita_tb/dut/inp[20]} -radix decimal} {{/ita_tb/dut/inp[19]} -radix decimal} {{/ita_tb/dut/inp[18]} -radix decimal} {{/ita_tb/dut/inp[17]} -radix decimal} {{/ita_tb/dut/inp[16]} -radix decimal} {{/ita_tb/dut/inp[15]} -radix decimal} {{/ita_tb/dut/inp[14]} -radix decimal} {{/ita_tb/dut/inp[13]} -radix decimal} {{/ita_tb/dut/inp[12]} -radix decimal} {{/ita_tb/dut/inp[11]} -radix decimal} {{/ita_tb/dut/inp[10]} -radix decimal} {{/ita_tb/dut/inp[9]} -radix decimal} {{/ita_tb/dut/inp[8]} -radix decimal} {{/ita_tb/dut/inp[7]} -radix decimal} {{/ita_tb/dut/inp[6]} -radix decimal} {{/ita_tb/dut/inp[5]} -radix decimal} {{/ita_tb/dut/inp[4]} -radix decimal} {{/ita_tb/dut/inp[3]} -radix decimal} {{/ita_tb/dut/inp[2]} -radix decimal} {{/ita_tb/dut/inp[1]} -radix decimal} {{/ita_tb/dut/inp[0]} -radix decimal}} -subitemconfig {{/ita_tb/dut/inp[63]} {-height 16 -radix decimal} {/ita_tb/dut/inp[62]} {-height 16 -radix decimal} {/ita_tb/dut/inp[61]} {-height 16 -radix decimal} {/ita_tb/dut/inp[60]} {-height 16 -radix decimal} {/ita_tb/dut/inp[59]} {-height 16 -radix decimal} {/ita_tb/dut/inp[58]} {-height 16 -radix decimal} {/ita_tb/dut/inp[57]} {-height 16 -radix decimal} {/ita_tb/dut/inp[56]} {-height 16 -radix decimal} {/ita_tb/dut/inp[55]} {-height 16 -radix decimal} {/ita_tb/dut/inp[54]} {-height 16 -radix decimal} {/ita_tb/dut/inp[53]} {-height 16 -radix decimal} {/ita_tb/dut/inp[52]} {-height 16 -radix decimal} {/ita_tb/dut/inp[51]} {-height 16 -radix decimal} {/ita_tb/dut/inp[50]} {-height 16 -radix decimal} {/ita_tb/dut/inp[49]} {-height 16 -radix decimal} {/ita_tb/dut/inp[48]} {-height 16 -radix decimal} {/ita_tb/dut/inp[47]} {-height 16 -radix decimal} {/ita_tb/dut/inp[46]} {-height 16 -radix decimal} {/ita_tb/dut/inp[45]} {-height 16 -radix decimal} {/ita_tb/dut/inp[44]} {-height 16 -radix decimal} {/ita_tb/dut/inp[43]} {-height 16 -radix decimal} {/ita_tb/dut/inp[42]} {-height 16 -radix decimal} {/ita_tb/dut/inp[41]} {-height 16 -radix decimal} {/ita_tb/dut/inp[40]} {-height 16 -radix decimal} {/ita_tb/dut/inp[39]} {-height 16 -radix decimal} {/ita_tb/dut/inp[38]} {-height 16 -radix decimal} {/ita_tb/dut/inp[37]} {-height 16 -radix decimal} {/ita_tb/dut/inp[36]} {-height 16 -radix decimal} {/ita_tb/dut/inp[35]} {-height 16 -radix decimal} {/ita_tb/dut/inp[34]} {-height 16 -radix decimal} {/ita_tb/dut/inp[33]} {-height 16 -radix decimal} {/ita_tb/dut/inp[32]} {-height 16 -radix decimal} {/ita_tb/dut/inp[31]} {-height 16 -radix decimal} {/ita_tb/dut/inp[30]} {-height 16 -radix decimal} {/ita_tb/dut/inp[29]} {-height 16 -radix decimal} {/ita_tb/dut/inp[28]} {-height 16 -radix decimal} {/ita_tb/dut/inp[27]} {-height 16 -radix decimal} {/ita_tb/dut/inp[26]} {-height 16 -radix decimal} {/ita_tb/dut/inp[25]} {-height 16 -radix decimal} {/ita_tb/dut/inp[24]} {-height 16 -radix decimal} {/ita_tb/dut/inp[23]} {-height 16 -radix decimal} {/ita_tb/dut/inp[22]} {-height 16 -radix decimal} {/ita_tb/dut/inp[21]} {-height 16 -radix decimal} {/ita_tb/dut/inp[20]} {-height 16 -radix decimal} {/ita_tb/dut/inp[19]} {-height 16 -radix decimal} {/ita_tb/dut/inp[18]} {-height 16 -radix decimal} {/ita_tb/dut/inp[17]} {-height 16 -radix decimal} {/ita_tb/dut/inp[16]} {-height 16 -radix decimal} {/ita_tb/dut/inp[15]} {-height 16 -radix decimal} {/ita_tb/dut/inp[14]} {-height 16 -radix decimal} {/ita_tb/dut/inp[13]} {-height 16 -radix decimal} {/ita_tb/dut/inp[12]} {-height 16 -radix decimal} {/ita_tb/dut/inp[11]} {-height 16 -radix decimal} {/ita_tb/dut/inp[10]} {-height 16 -radix decimal} {/ita_tb/dut/inp[9]} {-height 16 -radix decimal} {/ita_tb/dut/inp[8]} {-height 16 -radix decimal} {/ita_tb/dut/inp[7]} {-height 16 -radix decimal} {/ita_tb/dut/inp[6]} {-height 16 -radix decimal} {/ita_tb/dut/inp[5]} {-height 16 -radix decimal} {/ita_tb/dut/inp[4]} {-height 16 -radix decimal} {/ita_tb/dut/inp[3]} {-height 16 -radix decimal} {/ita_tb/dut/inp[2]} {-height 16 -radix decimal} {/ita_tb/dut/inp[1]} {-height 16 -radix decimal} {/ita_tb/dut/inp[0]} {-height 16 -radix decimal}} /ita_tb/dut/inp +add wave -noupdate -group {All in one Phase} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o +add wave -noupdate -group {All in one Phase} -radix decimal /ita_tb/dut/inp1 +add wave -noupdate -radix unsigned /ita_tb/dut/inp1_q +add wave -noupdate -radix decimal /ita_tb/dut/i_accumulator/oup_i +add wave -noupdate -radix decimal -childformat {{{/ita_tb/dut/i_accumulator/result_d[15]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[14]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[13]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[12]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[11]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[10]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[9]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[8]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[7]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[6]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[5]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[4]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[3]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[2]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[1]} -radix decimal} {{/ita_tb/dut/i_accumulator/result_d[0]} -radix decimal}} -subitemconfig {{/ita_tb/dut/i_accumulator/result_d[15]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[14]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[13]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[12]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[11]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[10]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[9]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[8]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[7]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[6]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[5]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[4]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[3]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[2]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[1]} {-height 16 -radix decimal} {/ita_tb/dut/i_accumulator/result_d[0]} {-height 16 -radix decimal}} /ita_tb/dut/i_accumulator/result_d +add wave -noupdate -radix decimal /ita_tb/dut/i_accumulator/result_o +add wave -noupdate -radix hexadecimal -childformat {{{/ita_tb/dut/i_activation/data_i[15]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[14]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[13]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[12]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[11]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[10]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[9]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[8]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[7]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[6]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[5]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[4]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[3]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[2]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[1]} -radix decimal} {{/ita_tb/dut/i_activation/data_i[0]} -radix decimal}} -subitemconfig {{/ita_tb/dut/i_activation/data_i[15]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[14]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[13]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[12]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[11]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[10]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[9]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[8]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[7]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[6]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[5]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[4]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[3]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[2]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[1]} {-height 16 -radix decimal} {/ita_tb/dut/i_activation/data_i[0]} {-height 16 -radix decimal}} /ita_tb/dut/i_activation/data_i +add wave -noupdate /ita_tb/dut/i_activation/data_q1 +add wave -noupdate /ita_tb/dut/i_activation/data_q2 +add wave -noupdate /ita_tb/dut/i_activation/data_q3 +add wave -noupdate /ita_tb/dut/i_activation/data_q4 +add wave -noupdate /ita_tb/dut/i_activation/data_o +add wave -noupdate /ita_tb/dut/i_fifo/data_i +add wave -noupdate /ita_tb/dut/i_fifo/data_o +add wave -noupdate /ita_tb/dut/oup_o +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/clk_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/rst_ni +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/mode_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/eps_mult_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/right_shift_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/calc_en_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/calc_en_q_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/result_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_i +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/requant_oup_o +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/mult_signed +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/product +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/shifted_added +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/shifted_d +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/shifted_q +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q1 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q2 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q3 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/add_q4 +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/requant_oup_d +add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/requant_oup_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/clk_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/rst_ni +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ctrl_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_ready_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/weight_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/weight_ready_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_ready_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/oup_valid_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/oup_ready_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/pop_softmax_fifo_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/step_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/soft_addr_div_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_done_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/calc_en_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/first_inner_tile_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/last_inner_tile_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_x_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_y_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias_i +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias_pad_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/mask_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/busy_o +add wave -noupdate -group Controller /ita_tb/dut/i_controller/step_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/step_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/count_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/count_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_count +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_x_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_x_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_tile_x_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_tile_x_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_y_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/tile_y_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_tile_y_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/bias_tile_y_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_tile_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_tile_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_soft_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/ongoing_soft_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inp_bias_padded +add wave -noupdate -group Controller /ita_tb/dut/i_controller/last_time +add wave -noupdate -group Controller /ita_tb/dut/i_controller/inner_tile_dim +add wave -noupdate -group Controller /ita_tb/dut/i_controller/first_outer_dim +add wave -noupdate -group Controller /ita_tb/dut/i_controller/second_outer_dim +add wave -noupdate -group Controller /ita_tb/dut/i_controller/first_outer_dim_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/first_outer_dim_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/second_outer_dim_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/second_outer_dim_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_fifo +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_div +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_div_done_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/softmax_div_done_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/busy_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/busy_q +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_d +add wave -noupdate -group Controller /ita_tb/dut/i_controller/requant_add_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/clk_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/rst_ni +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/ctrl_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/step_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/requant_oup_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/soft_addr_div_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/softmax_done_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/pop_softmax_fifo_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inp_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_inp_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_valid_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_ready_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_valid_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_ready_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_oup_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_acc_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_acc_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_acc_data_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_acc_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_acc_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_acc_data_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/prev_max_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_max_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_max_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/read_max_data_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_max_en_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_max_addr_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/write_max_data_o +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_x_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_y_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inner_tile_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_i +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q3 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_q4 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q3 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_q4 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/inner_tile_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_x_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/tile_y_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_x_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_x_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_y_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_y_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_outer_dim_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/mask_tile_outer_dim_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_mask_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_div_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_div_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/addr_div_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/addr_div_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_read_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_read_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_write_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/div_write_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/requant_oup_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_diff +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_sum_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_sum_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_diff +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_inp +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_inp_diff +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_d +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q1 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q2 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q3 +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_full +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_empty +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/push_to_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/pop_from_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/data_to_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/data_from_fifo +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_usage +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_shift +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_row +add wave -noupdate -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/clk_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/rst_ni +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/calc_en_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/calc_en_q_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/first_tile_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/first_tile_q_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/last_tile_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/last_tile_q_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/oup_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/inp_bias_i +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/result_o +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/read_en +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/read_addr +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/read_data +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/read_data_unused +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/write_en +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/write_addr +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/write_data +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/read_addr_d +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/read_addr_q +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/write_addr_d +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/write_addr_q +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/result_d +add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/result_q + diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000..faa86d4 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,7 @@ +# Copyright 2023 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +torch +pytest +pytest-check diff --git a/requirements.txt b/requirements.txt index 4bfbc47..e8e03c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,4 @@ onnxruntime netron seaborn matplotlib -torch -pytest -pytest-check pre-commit diff --git a/src/hwpe/ita_hwpe_ctrl.sv b/src/hwpe/ita_hwpe_ctrl.sv index 1edd454..421224a 100644 --- a/src/hwpe/ita_hwpe_ctrl.sv +++ b/src/hwpe/ita_hwpe_ctrl.sv @@ -104,6 +104,12 @@ module ita_hwpe_ctrl ctrl_stream_o.bias_disable = reg_file.hwpe_params[ITA_REG_CTRL_STREAM][2]; ctrl_stream_o.bias_direction = reg_file.hwpe_params[ITA_REG_CTRL_STREAM][3]; ctrl_stream_o.output_disable = reg_file.hwpe_params[ITA_REG_CTRL_STREAM][4]; + ctrl_engine_o.seq_length = reg_file.hwpe_params[ITA_REG_SEQ_PROJ_LENGTH][9:0]; + ctrl_engine_o.proj_space = reg_file.hwpe_params[ITA_REG_SEQ_PROJ_LENGTH][19:10]; + ctrl_engine_o.embed_size = reg_file.hwpe_params[ITA_REG_EMBED_FF_SIZE][9:0]; + ctrl_engine_o.ff_size = reg_file.hwpe_params[ITA_REG_EMBED_FF_SIZE][19:10]; + ctrl_engine_o.mask_type = reg_file.hwpe_params[ITA_REG_MASK][2:0]; + ctrl_engine_o.mask_start_index = reg_file.hwpe_params[ITA_REG_MASK][12:3]; end logic [31:0] input_addr, bias_addr, output_addr; @@ -223,4 +229,4 @@ module ita_hwpe_ctrl end end -endmodule : ita_hwpe_ctrl \ No newline at end of file +endmodule : ita_hwpe_ctrl diff --git a/src/hwpe/ita_hwpe_package.sv b/src/hwpe/ita_hwpe_package.sv index 47e54b9..d63c0b6 100644 --- a/src/hwpe/ita_hwpe_package.sv +++ b/src/hwpe/ita_hwpe_package.sv @@ -12,7 +12,7 @@ package ita_hwpe_package; parameter int unsigned N_CORES = 9; parameter int unsigned N_CONTEXT = 4; parameter int unsigned ID_WIDTH = 2; - parameter int unsigned ITA_IO_REGS = 17; // 5 address + 11 parameters + 1 sync + parameter int unsigned ITA_IO_REGS = 20; // 5 address + 11 parameters + 1 sync + 2 length + 1 mask parameter int unsigned ITA_TCDM_DW = 1024; parameter int unsigned ITA_INPUT_DW = M*WI; @@ -38,7 +38,9 @@ package ita_hwpe_package; parameter int unsigned ITA_REG_CTRL_STREAM = 14; // ctrl_stream [0]: weight preload, ctrl_stream [1]: weight nextload, ctrl_stream [2]: bias disable, ctrl_stream [3]: bias direction, ctrl_stream [4]: output disable parameter int unsigned ITA_REG_GELU_B_C = 15; // gelu_b [15:0], gelu_c [31:16] parameter int unsigned ITA_REG_ACTIVATION_REQUANT = 16; // activation_requant_mult [7:0], activation_requant_shift [15:8], activation_requant_add [23:16] - + parameter int unsigned ITA_REG_SEQ_PROJ_LENGTH = 17; // seq_length [9:0], proj_space [19:10] + parameter int unsigned ITA_REG_EMBED_FF_SIZE = 18; // embed_size [9:0], ff_size [19:10] + parameter int unsigned ITA_REG_MASK = 19; // mask_type[2:0], mask_start_index [12:3] typedef struct packed { hci_package::hci_streamer_ctrl_t input_source_ctrl; @@ -74,4 +76,4 @@ package ita_hwpe_package; Done } state_t; -endpackage : ita_hwpe_package \ No newline at end of file +endpackage : ita_hwpe_package diff --git a/src/hwpe/tb/ita_hwpe_tb.sv b/src/hwpe/tb/ita_hwpe_tb.sv index 7f8e30c..a60b0c0 100644 --- a/src/hwpe/tb/ita_hwpe_tb.sv +++ b/src/hwpe/tb/ita_hwpe_tb.sv @@ -28,6 +28,9 @@ module ita_hwpe_tb; parameter integer FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif; parameter activation_e ACTIVATION = `ifdef ACTIVATION `ACTIVATION `else Identity `endif; parameter integer SINGLE_ATTENTION = `ifdef SINGLE_ATTENTION `SINGLE_ATTENTION `else 0 `endif; + parameter mask_e MASK = mask_e'(`ifdef MASK `MASK `else None `endif); + parameter integer MASK_START_INDEX = `ifdef MASK_INDEX `MASK_INDEX `else 1 `endif; + integer N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM; integer N_ELEMENTS_PER_TILE; @@ -133,8 +136,13 @@ module ita_hwpe_tb; "_H1_B", $sformatf("%0d", `ifdef BIAS `BIAS `else 0 `endif), "_", - $sformatf( "%s", ACTIVATION) + $sformatf("%s", ACTIVATION), + "_", + $sformatf("%s", MASK), + "_I", + $sformatf("%0d", MASK_START_INDEX) }; + // Number of tiles in the sequence dimension N_TILES_SEQUENCE_DIM = SEQUENCE_LEN / M_TILE_LEN; // Number of tiles in the embedding dimension @@ -334,6 +342,8 @@ endfunction logic [5:0][31:0] ita_reg_rqs_val; logic [31:0] ita_reg_gelu_b_c_val; logic [31:0] ita_reg_activation_rqs_val; + logic [1:0][31:0] ita_reg_dims_val; + logic [31:0] ita_reg_mask_val; $timeformat(-9, 2, " ns", 10); @@ -348,6 +358,8 @@ endfunction ita_reg_tiles_val_compute(N_TILES_SEQUENCE_DIM, N_TILES_EMBEDDING_DIM, N_TILES_PROJECTION_DIM, N_TILES_FEEDFORWARD_DIM, ita_reg_tiles_val); ita_reg_eps_mult_val_compute(ita_reg_rqs_val); ita_reg_activation_constants_compute(ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val); + ita_reg_dims_compute(SEQUENCE_LEN, EMBEDDING_SIZE, PROJECTION_SPACE, FEEDFORWARD_SIZE, ita_reg_dims_val); + ita_reg_mask_compute(MASK, MASK_START_INDEX, ita_reg_mask_val); // soft clear PERIPH_WRITE( 32'h14, 32'h0, 32'h0, clk); @@ -358,7 +370,7 @@ endfunction PERIPH_READ( 32'h04, 32'h0, status, clk); // 1: Step Q - ita_compute_step(Q, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(Q, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); // 2: Step K if (SINGLE_ATTENTION == 1) begin @@ -367,7 +379,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8; ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8; end - ita_compute_step(K, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(K, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); // 3: Step V if (SINGLE_ATTENTION == 1) begin @@ -376,7 +388,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[2] >> 8; ita_reg_rqs_val[4] = ita_reg_rqs_val[4] >> 8; end - ita_compute_step(V, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(V, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); if (SINGLE_ATTENTION == 1) begin // Reset the RQS values @@ -391,7 +403,7 @@ endfunction BASE_PTR_OUTPUT[AV] = BASE_PTR[19] + group * N_TILES_OUTER_X[AV] * N_ELEMENTS_PER_TILE; // 4: Step QK - ita_compute_step(QK, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(QK, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); // WIESEP: Hack to ensure that during the last tile of AV, the weight pointer is set correctly if (group == N_TILES_SEQUENCE_DIM-1) begin @@ -399,7 +411,7 @@ endfunction end // 5: Step AV - ita_compute_step(AV, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(AV, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); end // 6: Step OW @@ -411,7 +423,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 8; ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 8; end - ita_compute_step(OW, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(OW, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); ita_reg_cnt = 0; @@ -424,7 +436,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 16; ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 16; end - ita_compute_step(F1, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(F1, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); // 8: Step FF2 if (SINGLE_ATTENTION == 1) begin @@ -435,7 +447,7 @@ endfunction ita_reg_rqs_val[2] = ita_reg_rqs_val[3] >> 24; ita_reg_rqs_val[4] = ita_reg_rqs_val[5] >> 24; end - ita_compute_step(F2, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, clk); + ita_compute_step(F2, ita_reg_cnt, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_dims_val, ita_reg_mask_val, clk); // Wait for the last step to finish wait(evt); @@ -465,6 +477,8 @@ endfunction input logic [5:0][31:0] ita_reg_rqs_val, input logic [31:0] ita_reg_gelu_b_c_val, input logic [31:0] ita_reg_activation_rqs_val, + input logic [1:0][31:0] ita_reg_dims_val, + input logic [31:0] ita_reg_mask_val, ref logic clk_i ); @@ -520,7 +534,7 @@ endfunction $display(" - ITA Reg En 0x%0h, Ctrl Stream Val 0x%0h, Weight Ptr En %0d, Bias Ptr En %0d", ita_reg_en, ctrl_stream_val, weight_ptr_en, bias_ptr_en); // Program ITA - PROGRAM_ITA(input_ptr, weight_ptr0, weight_ptr1, weight_ptr_en, bias_ptr, bias_ptr_en, output_ptr, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_en, ctrl_engine_val, ctrl_stream_val, clk_i); + PROGRAM_ITA(input_ptr, weight_ptr0, weight_ptr1, weight_ptr_en, bias_ptr, bias_ptr_en, output_ptr, ita_reg_tiles_val, ita_reg_rqs_val, ita_reg_gelu_b_c_val, ita_reg_activation_rqs_val, ita_reg_en, ctrl_engine_val, ctrl_stream_val, ita_reg_dims_val, ita_reg_mask_val, clk_i); // Wait for ITA to finish @(posedge clk_i); @@ -723,6 +737,25 @@ endfunction activation_requant_reg = activation_requant_mult | activation_requant_shift << 8 | activation_requant_add << 16; endtask + task automatic ita_reg_dims_compute( + input integer seq_length, + input integer proj_space, + input integer embed_size, + input integer ff_size, + output logic [1:0][31:0] reg_val + ); + reg_val[0] = seq_length | proj_space << 10; + reg_val[1] = embed_size | ff_size << 10; + endtask + + task automatic ita_reg_mask_compute( + input mask_e mask_type, + input integer mask_start_index, + output logic [31:0] reg_val + ); + reg_val = mask_type | mask_start_index << 3; + endtask + task automatic read_activation_constants( output gelu_const_t gelu_b, output gelu_const_t gelu_c, @@ -841,6 +874,8 @@ endfunction input logic ita_reg_en, input logic [31:0] ctrl_engine_val, input logic [31:0] ctrl_stream_val, + input logic [2:0][31:0] ita_reg_dims_val, + input logic [31:0] ita_reg_mask_val, ref logic clk_i ); PERIPH_WRITE( 4*ITA_REG_INPUT_PTR, ITA_REG_OFFSET, input_ptr, clk_i); @@ -861,6 +896,9 @@ endfunction PERIPH_WRITE( 4*ITA_REG_ADD1, ITA_REG_OFFSET, ita_reg_rqs_val[5], clk_i); PERIPH_WRITE( 4*ITA_REG_GELU_B_C, ITA_REG_OFFSET, ita_reg_gelu_b_c_val, clk_i); PERIPH_WRITE( 4*ITA_REG_ACTIVATION_REQUANT, ITA_REG_OFFSET, ita_reg_activation_rqs_val, clk_i); + PERIPH_WRITE( 4*ITA_REG_SEQ_PROJ_LENGTH, ITA_REG_OFFSET, ita_reg_dims_val[0], clk_i); + PERIPH_WRITE( 4*ITA_REG_EMBED_FF_SIZE, ITA_REG_OFFSET, ita_reg_dims_val[1], clk_i); + PERIPH_WRITE( 4*ITA_REG_MASK, ITA_REG_OFFSET, ita_reg_mask_val, clk_i); end PERIPH_WRITE( 4*ITA_REG_CTRL_ENGINE, ITA_REG_OFFSET, ctrl_engine_val, clk_i); diff --git a/src/ita.sv b/src/ita.sv index 2dad263..7822050 100644 --- a/src/ita.sv +++ b/src/ita.sv @@ -40,7 +40,7 @@ module ita logic weight_valid, weight_ready; inp_t inp, inp_stream_soft; weight_t inp1, inp1_q, inp2, inp2_q; - bias_t inp_bias, inp_bias_q1, inp_bias_q2; + bias_t inp_bias, inp_bias_padded, inp_bias_q1, inp_bias_q2; oup_t oup, oup_q, accumulator_oup; requant_const_t requant_mult, requant_shift, activation_requant_mult, activation_requant_shift; requant_oup_t requant_oup; @@ -48,6 +48,9 @@ module ita requant_mode_e requant_mode, activation_requant_mode; requant_oup_t post_activation; + //Masking + logic [N-1:0] mask, mask_q1, mask_q2, mask_q3, mask_q4, mask_q5, mask_q6; + // FIFO signals logic fifo_full, fifo_empty, push_to_fifo, pop_from_fifo; fifo_data_t data_to_fifo, data_from_fifo; @@ -106,6 +109,12 @@ module ita activation_q3 <= Identity; activation_q2 <= Identity; activation_q1 <= Identity; + mask_q6 <= '0; + mask_q5 <= '0; + mask_q4 <= '0; + mask_q3 <= '0; + mask_q2 <= '0; + mask_q1 <= '0; end else begin calc_en_q10 <= calc_en_q9; calc_en_q9 <= calc_en_q8; @@ -146,6 +155,12 @@ module ita activation_q3 <= activation_q2; activation_q2 <= activation_q1; activation_q1 <= ctrl_i.activation; + mask_q6 <= mask_q5; + mask_q5 <= mask_q4; + mask_q4 <= mask_q3; + mask_q3 <= mask_q2; + mask_q2 <= mask_q1; + mask_q1 <= mask; end end @@ -153,8 +168,8 @@ module ita if (!rst_ni) begin inp1_q <= '0; inp2_q <= '0; - inp_bias_q2 <= '0; inp_bias_q1 <= '0; + inp_bias_q2 <= '0; oup_q <= '0; end else begin if (calc_en_q2) begin @@ -162,7 +177,7 @@ module ita oup_q <= oup; end if (calc_en_q1) begin - inp_bias_q1 <= inp_bias; + inp_bias_q1 <= inp_bias_padded; inp1_q <= inp1; inp2_q <= inp2; end @@ -171,6 +186,12 @@ module ita assign oup_o = valid_o ? data_from_fifo : '0; + requant_oup_t requant_add_o; + + counter_t inner_tile; + counter_t tile_x; + counter_t tile_y; + ita_controller i_controller ( .clk_i (clk_i ), .rst_ni (rst_ni ), @@ -190,6 +211,14 @@ module ita .calc_en_o (calc_en ), .first_inner_tile_o (first_inner_tile ), .last_inner_tile_o (last_inner_tile ), + .tile_x_o (tile_x ), + .tile_y_o (tile_y ), + .inner_tile_o (inner_tile ), + .requant_add_i (requant_add ), + .requant_add_o (requant_add_o ), + .inp_bias_i (inp_bias ), + .inp_bias_pad_o (inp_bias_padded ), + .mask_o (mask ), .busy_o (busy_o ) ); @@ -255,13 +284,17 @@ module ita .soft_addr_div_o (soft_addr_div ), .softmax_done_o (softmax_done ), .pop_softmax_fifo_o (pop_softmax_fifo ), - .inp_stream_soft_o (inp_stream_soft ) + .inp_stream_soft_o (inp_stream_soft ), + .tile_x_i (tile_x ), + .tile_y_i (tile_y ), + .inner_tile_i (inner_tile ), + .mask_i (mask_q6 ) ); ita_requatization_controller i_requantization_controller ( .ctrl_i (ctrl_i ), - .requantizer_step_i (step_q4 ), + .requantizer_step_i (step_q4 ), .requant_mult_o (requant_mult ), .requant_shift_o (requant_shift ), .requant_add_o (requant_add ), @@ -283,7 +316,7 @@ module ita .calc_en_i ( calc_en_q4 && last_inner_tile_q4 ), .calc_en_q_i ( calc_en_q5 && last_inner_tile_q5 ), .result_i ( accumulator_oup ), - .add_i ( {N {requant_add}} ), + .add_i ( requant_add_o ), .requant_oup_o( requant_oup ) ); diff --git a/src/ita_controller.sv b/src/ita_controller.sv index 0fa8034..3a834bb 100644 --- a/src/ita_controller.sv +++ b/src/ita_controller.sv @@ -10,44 +10,72 @@ module ita_controller import ita_package::*; ( - input logic clk_i , - input logic rst_ni , - input ctrl_t ctrl_i , - input logic inp_valid_i , - output logic inp_ready_o , - input logic weight_valid_i , - output logic weight_ready_o , - input logic bias_valid_i , - output logic bias_ready_o , - input logic oup_valid_i , - input logic oup_ready_i , - input logic pop_softmax_fifo_i , - output step_e step_o , - input counter_t soft_addr_div_i , - input logic softmax_done_i , - output logic calc_en_o , - output logic first_inner_tile_o , - output logic last_inner_tile_o , - output logic busy_o + input logic clk_i , + input logic rst_ni , + input ctrl_t ctrl_i , + input logic inp_valid_i , + output logic inp_ready_o , + input logic weight_valid_i , + output logic weight_ready_o , + input logic bias_valid_i , + output logic bias_ready_o , + input logic oup_valid_i , + input logic oup_ready_i , + input logic pop_softmax_fifo_i , + output step_e step_o , + input counter_t soft_addr_div_i , + input logic softmax_done_i , + output logic calc_en_o , + output logic first_inner_tile_o , + output logic last_inner_tile_o , + output counter_t tile_x_o , + output counter_t tile_y_o , + output counter_t inner_tile_o , + input requant_t requant_add_i , + output requant_oup_t requant_add_o , + input bias_t inp_bias_i , + output bias_t inp_bias_pad_o , + output logic [N-1:0] mask_o , + output logic busy_o ); step_e step_d, step_q; - counter_t count_d, count_q; + counter_t count_d, count_q, bias_count; + counter_t tile_d, tile_q; counter_t inner_tile_d, inner_tile_q; + counter_t tile_x_d, tile_x_q, bias_tile_x_d, bias_tile_x_q; + counter_t tile_y_d, tile_y_q, bias_tile_y_d, bias_tile_y_q; counter_t softmax_tile_d, softmax_tile_q; ongoing_t ongoing_d, ongoing_q; ongoing_soft_t ongoing_soft_d, ongoing_soft_q; + bias_t inp_bias, inp_bias_padded; + logic last_time; + + tile_t inner_tile_dim; + input_dim_t first_outer_dim, second_outer_dim; + input_dim_t first_outer_dim_d, first_outer_dim_q; + input_dim_t second_outer_dim_d, second_outer_dim_q; + logic softmax_fifo, softmax_div, softmax_div_done_d, softmax_div_done_q, busy_d, busy_q; + requant_oup_t requant_add, requant_add_d, requant_add_q; - assign step_o = step_q; - assign busy_o = busy_q; + assign step_o = step_q; + assign busy_o = busy_q; + assign tile_x_o = tile_x_q; + assign tile_y_o = tile_y_q; + assign inner_tile_o = inner_tile_q; + assign requant_add_o = requant_add_q; + assign inp_bias_pad_o = inp_bias_padded; + always_comb begin count_d = count_q; tile_d = tile_q; inner_tile_d = inner_tile_q; + tile_x_d = tile_x_q; + tile_y_d = tile_y_q; first_inner_tile_o = (inner_tile_q == 0) ? 1'b1 : 1'b0; last_inner_tile_o = 1'b0; ongoing_d = ongoing_q; @@ -59,10 +87,12 @@ module ita_controller step_d = step_q; softmax_tile_d = softmax_tile_q; softmax_div_done_d = softmax_div_done_q; - - busy_d = busy_q; - softmax_fifo = 1'b0; - softmax_div = 1'b0; + last_time = 1'b0; + requant_add = {N {requant_add_i}}; + inp_bias = inp_bias_i; + busy_d = busy_q; + softmax_fifo = 1'b0; + softmax_div = 1'b0; if (step_q != AV) begin softmax_div_done_d = 1'b0; @@ -98,7 +128,7 @@ module ita_controller busy_d = 1'b1; if (count_d == M*M/N) begin // end of tile busy_d = 1'b0; // Generate done signal for current tile - count_d = '0; + count_d = '0; inner_tile_d = inner_tile_q + 1; end end @@ -108,6 +138,8 @@ module ita_controller case (step_q) Idle : begin inner_tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; tile_d = '0; softmax_tile_d = '0; softmax_div_done_d = 1'b0; @@ -126,51 +158,80 @@ module ita_controller end // Attention Q : begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.proj_space; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_p-1)) begin // end of step Q + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin // end of step Q tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = K; end end end K: begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.proj_space; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_p-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin // end of step K tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = V; end end end V: begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.proj_space; + second_outer_dim = ctrl_i.seq_length; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_s-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_p) begin // end of step V tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = QK; end end end QK : begin - if (inner_tile_q == ctrl_i.tile_p-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_p-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.seq_length; if (inner_tile_d == ctrl_i.tile_p) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_s-1)) begin + tile_x_d = '0; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s) begin // end of step QK tile_d = '0; step_d = AV; @@ -178,64 +239,96 @@ module ita_controller end end AV : begin - if (inner_tile_q == ctrl_i.tile_s-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_s-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.proj_space; if (inner_tile_d == ctrl_i.tile_s) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_p-1)) begin + tile_x_d = '0; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_p) begin tile_d = '0; softmax_tile_d = softmax_tile_q + 1; if (softmax_tile_d == ctrl_i.tile_s) begin softmax_tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; if (ctrl_i.layer == Attention) begin step_d = OW; end else if (ctrl_i.layer == SingleAttention) begin step_d = Idle; end end else begin + tile_y_d = tile_y_q + 1; step_d = QK; end end end end OW : begin - if (inner_tile_q == ctrl_i.tile_p-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_p-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.embed_size; if (inner_tile_d == ctrl_i.tile_p) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_e-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_e) begin // end of step OW tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = Idle; end end end // Feedforward F1: begin - if (inner_tile_q == ctrl_i.tile_e-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_e-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.ff_size; if (inner_tile_d == ctrl_i.tile_e) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; - if (tile_d == ctrl_i.tile_s*ctrl_i.tile_f) begin + if (tile_x_q == (ctrl_i.tile_f-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end + if (tile_d == ctrl_i.tile_s*ctrl_i.tile_f) begin tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = F2; end end end F2: begin - if (inner_tile_q == ctrl_i.tile_f-1) begin - last_inner_tile_o = 1'b1; - end + inner_tile_dim = ctrl_i.tile_f-1; + first_outer_dim = ctrl_i.seq_length; + second_outer_dim = ctrl_i.embed_size; if (inner_tile_d == ctrl_i.tile_f) begin // end of inner tile inner_tile_d = '0; tile_d = tile_q + 1; + if (tile_x_q == (ctrl_i.tile_e-1)) begin + tile_x_d = '0; + tile_y_d = tile_y_q + 1; + end else begin + tile_x_d = tile_x_q + 1; + end if (tile_d == ctrl_i.tile_s*ctrl_i.tile_e) begin tile_d = '0; + tile_x_d = '0; + tile_y_d = '0; step_d = Idle; end end @@ -255,6 +348,41 @@ module ita_controller end end endcase + + bias_count = (count_q == 0) ? ((M*M/N)-1) : count_q - 1; + bias_tile_x_d = (count_q == 0) ? bias_tile_x_q : tile_x_q; + bias_tile_y_d = (count_q == 0) ? bias_tile_y_q : tile_y_q; + first_outer_dim_d = (count_q == 0) ? first_outer_dim_q : first_outer_dim; + second_outer_dim_d = (count_q == 0) ? second_outer_dim_q : second_outer_dim; + + if ((step_q != Idle && step_q != MatMul) || (step_q == Idle && bias_count == ((M*M/N)-1))) begin + if (inner_tile_q == inner_tile_dim) begin + last_inner_tile_o = 1'b1; + end + if ((((((bias_count) & (M-1)) + bias_tile_y_d * M)) > ((first_outer_dim_d - 1)))) begin + requant_add_d = {N {1'b0}}; + inp_bias = {N {1'b0}}; + end else begin + if ( ((bias_count) + bias_tile_x_d * M*M/N) >= (second_outer_dim_d / N) * M ) begin + if ( (((bias_count) / M) * N + bias_tile_x_d * M ) < second_outer_dim_d) begin + for (int i = 0; i < N; i++) begin + if (i >= (second_outer_dim_d & (N-1))) begin + requant_add_d[i] = 1'b0; + inp_bias[i] = 1'b0; + end else begin + requant_add_d[i] = requant_add[i]; + inp_bias[i] = inp_bias_i[i]; + end + end + end else begin + requant_add_d = {N {1'b0}}; + inp_bias = {N {1'b0}}; + end + end + end + end + inp_bias_padded = inp_bias; + if (inp_valid_i && inp_ready_o && oup_valid_i && oup_ready_i && last_inner_tile_o) begin ongoing_d = ongoing_q; end else if (inp_valid_i && inp_ready_o && last_inner_tile_o) begin @@ -276,22 +404,50 @@ module ita_controller step_q <= Idle; count_q <= '0; tile_q <= '0; + tile_x_q <= '0; + tile_y_q <= '0; inner_tile_q <= '0; softmax_tile_q <= '0; ongoing_q <= '0; ongoing_soft_q <= '0; softmax_div_done_q <= 1'b0; + requant_add_q <= '0; busy_q <= 1'b0; + bias_tile_x_q <= '0; + bias_tile_y_q <= '0; + first_outer_dim_q <= '0; + second_outer_dim_q <= '0; end else begin step_q <= step_d; count_q <= count_d; tile_q <= tile_d; + tile_x_q <= tile_x_d; + tile_y_q <= tile_y_d; inner_tile_q <= inner_tile_d; softmax_tile_q <= softmax_tile_d; ongoing_q <= ongoing_d; ongoing_soft_q <= ongoing_soft_d; softmax_div_done_q <= softmax_div_done_d; + requant_add_q <= requant_add_d; busy_q <= busy_d; + bias_tile_x_q <= bias_tile_x_d; + bias_tile_y_q <= bias_tile_y_d; + first_outer_dim_q <= first_outer_dim_d; + second_outer_dim_q <= second_outer_dim_d; end end + + ita_masking i_masking ( + .clk_i (clk_i), + .rst_ni (rst_ni), + .ctrl_i (ctrl_i), + .step_i (step_o), + .calc_en_i (calc_en_o), + .last_inner_tile_i (last_inner_tile_o), + .count_i (count_q), + .tile_x_i (tile_x_o), + .tile_y_i (tile_y_o), + .mask_o (mask_o) + ); + endmodule diff --git a/src/ita_masking.sv b/src/ita_masking.sv new file mode 100644 index 0000000..7622655 --- /dev/null +++ b/src/ita_masking.sv @@ -0,0 +1,233 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Solderpad Hardware License, Version 0.51, see LICENSE for details. +// SPDX-License-Identifier: SHL-0.51 + +/** + ITA masking module. +*/ + +module ita_masking + import ita_package::*; +( + input logic clk_i, + input logic rst_ni, + input ctrl_t ctrl_i, + input step_e step_i, + input logic calc_en_i, + input logic last_inner_tile_i, + input counter_t count_i, + input counter_t tile_x_i, + input counter_t tile_y_i, + output logic [N-1:0] mask_o +); + + logic [3:0] mask_col_offset_d, mask_col_offset_q; + counter_t mask_tile_x_pos_d, mask_tile_x_pos_q; + counter_t mask_tile_y_pos_d, mask_tile_y_pos_q; + counter_t mask_pos_d, mask_pos_q; + logic [N-1:0] mask_d, mask_q; + + assign mask_o = mask_q; + + always_comb begin + mask_col_offset_d = '0; + mask_tile_x_pos_d = '0; + mask_tile_y_pos_d = '0; + mask_pos_d = '0; + mask_d = '0; + case (ctrl_i.mask_type) + UpperTriangular: begin + mask_col_offset_d = (step_i == QK || step_i == AV) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1)); + mask_tile_x_pos_d = (step_i == QK || step_i == AV) ? mask_tile_x_pos_q : ((ctrl_i.mask_start_index) / M); + mask_tile_y_pos_d = mask_tile_y_pos_q; + mask_pos_d = (step_i == QK || step_i == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1)); + + if (step_i == QK && last_inner_tile_i == 1'b1) begin + if (mask_tile_x_pos_q == tile_x_i && mask_tile_y_pos_q == tile_y_i) begin + if (count_i == ((M * M / N) - 1)) begin + mask_tile_x_pos_d = mask_tile_x_pos_q + 1'b1; + end + if ((count_i >= mask_pos_q) && (count_i < (mask_pos_q + N))) begin + if ((count_i & (M - 1)) == (M - 1)) begin + mask_tile_y_pos_d = tile_y_i + 1'b1; + mask_tile_x_pos_d = tile_x_i; + if (((count_i + mask_col_offset_q) & (N-1)) == (N-1)) begin + mask_pos_d = ((count_i + 1) & ((M*M/N)-1)); + if ((count_i / M) == ((M / N) - 1)) begin + mask_tile_x_pos_d = tile_x_i + 1'b1; + end + end else begin + mask_pos_d = ((count_i + (((ctrl_i.tile_s * (M*M/N)) - M) + 1)) & ((M*M/N)-1)); + end + end else if (((count_i + mask_col_offset_q) & (N-1)) == (N-1)) begin + mask_pos_d = (mask_pos_q + (N - ((mask_pos_q + mask_col_offset_q) & (N-1))) + M) & ((M*M/N)-1); + end + for (int i = 0; i < N; i++) begin + if (((count_i + mask_col_offset_q) & (N - 1)) <= i) begin + mask_d[i] = 1'b1; + end else begin + mask_d[i] = 1'b0; + end + end + end else if ((count_i & (M - 1)) < (mask_pos_q & (M - 1))) begin + mask_d = '1; + end + end else if (mask_tile_x_pos_q <= tile_x_i && mask_tile_y_pos_q != tile_y_i) begin + mask_d = '1; + end else if (mask_tile_x_pos_q != tile_x_i && mask_tile_y_pos_q == tile_y_i) begin + mask_d = '0; + end + end + end + LowerTriangular: begin + mask_tile_x_pos_d = mask_tile_x_pos_q; + mask_tile_y_pos_d = (step_i == QK || step_i == AV) ? mask_tile_y_pos_q : ((ctrl_i.mask_start_index) / M); + mask_pos_d = (step_i == QK || step_i == AV) ? mask_pos_q : (ctrl_i.mask_start_index & (M-1)); + + if (step_i == QK && last_inner_tile_i == 1'b1) begin + if (mask_tile_x_pos_q == tile_x_i && mask_tile_y_pos_q == tile_y_i) begin + if (count_i == ((M * M / N) - 1)) begin + mask_tile_x_pos_d = mask_tile_x_pos_q + 1'b1; + end + if ((count_i >= mask_pos_q) && (count_i < (mask_pos_q + N))) begin + if ((count_i & (M - 1)) == (M - 1)) begin + mask_tile_y_pos_d = tile_y_i + 1'b1; + mask_tile_x_pos_d = tile_x_i; + if (((count_i + (N - (ctrl_i.mask_start_index & (N-1)))) & (N-1)) == (N-1)) begin + mask_pos_d = ((count_i + 1) & ((M*M/N)-1)); + if ((count_i / M) == ((M / N) - 1)) begin + mask_tile_x_pos_d = tile_x_i + 1'b1; + end + end else begin + mask_pos_d = ((count_i + (((ctrl_i.tile_s * (M*M/N)) - M) + 1)) & ((M*M/N)-1)); + end + end else if (((count_i + (N - (ctrl_i.mask_start_index & (N-1)))) & (N-1)) == (N-1)) begin + mask_pos_d = (mask_pos_q + (count_i - mask_pos_q + 1) + M) & ((M*M/N)-1); + end + for (int i = 0; i < N; i++) begin + if (((count_i + (N - (ctrl_i.mask_start_index & (N - 1)))) & (N - 1)) >= i) begin + mask_d[i] = 1'b1; + end else begin + mask_d[i] = 1'b0; + end + end + end else if ((count_i & (M - 1)) >= (mask_pos_q & (M - 1))) begin + mask_d = '1; + end + end else if (mask_tile_x_pos_q > tile_x_i && mask_tile_y_pos_q == tile_y_i) begin + mask_d = '1; + end else if (mask_tile_x_pos_q >= tile_x_i && mask_tile_y_pos_q != tile_y_i) begin + mask_d = '0; + end + end + end + Strided: begin + if (step_i == QK && last_inner_tile_i == 1'b1) begin + for (int i = 0; i < N; i++) begin + //col_pos = count_i/M * N + i + tile_x_i * M + //row_pos = count_i & (M-1) + tile_y_i * M + //Marcel Kant: Does only work if ctrl_i.mask_start_index is a power of two + if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) + & (ctrl_i.mask_start_index-1)) == 0) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end + end + end + UpperStrided: begin + if (step_i == QK && last_inner_tile_i == 1'b1) begin + for (int i = 0; i < N; i++) begin + //Marcel Kant: Does only work if ctrl_i.mask_start_index is a power of two + if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0 && + ((((count_i / M) * N) + i + (tile_x_i * M)) >= ((count_i & (M-1)) + (tile_y_i * M)))) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end + end + end + LowerStrided: begin + if (step_i == QK && last_inner_tile_i == 1'b1) begin + for (int i = 0; i < N; i++) begin + //Marcel Kant: Does only work if ctrl_i.mask_start_index is a power of two + if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0 && + ((((count_i / M) * N) + i + (tile_x_i * M)) <= ((count_i & (M-1)) + (tile_y_i * M)))) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end + end + end + SlidingWindow: begin + if (step_i == QK && last_inner_tile_i == 1'b1) begin + for (int i = 0; i < N; i++) begin + if (((count_i & (M-1)) + (tile_y_i * M)) < ctrl_i.mask_start_index) begin + if ((((count_i / M) * N) + i + (tile_x_i * M)) < (ctrl_i.mask_start_index + ((count_i & (M-1)) + (tile_y_i * M)))) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end else begin + if ((((count_i & (M-1)) + (tile_y_i * M) - (ctrl_i.mask_start_index-1)) <= (((count_i / M) * N) + i + (tile_x_i * M))) && + ((((count_i / M) * N) + i + (tile_x_i * M)) < ((count_i & (M-1)) + (tile_y_i * M) + ctrl_i.mask_start_index))) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end + end + end + end + StridedSlidingWindow: begin + if (step_i == QK && last_inner_tile_i == 1'b1) begin + for (int i = 0; i < N; i++) begin + //Strided logic + if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0) begin + mask_d[i] = 1'b0; + end else begin + //Sliding window logic + if (((count_i & (M-1)) + (tile_y_i * M)) < ctrl_i.mask_start_index) begin + if ((((count_i / M) * N) + i + (tile_x_i * M)) < (ctrl_i.mask_start_index + ((count_i & (M-1)) + (tile_y_i * M)))) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end else begin + if ((((count_i & (M-1)) + (tile_y_i * M) - (ctrl_i.mask_start_index-1)) <= (((count_i / M) * N) + i + (tile_x_i * M))) && + ((((count_i / M) * N) + i + (tile_x_i * M)) < ((count_i & (M-1)) + (tile_y_i * M) + ctrl_i.mask_start_index))) begin + mask_d[i] = 1'b0; + end else begin + mask_d[i] = 1'b1; + end + end + end + end + end + end + default: ; + endcase + end + + always_ff @(posedge clk_i or negedge rst_ni) begin + if (~rst_ni) begin + mask_pos_q <= '0; + mask_tile_x_pos_q <= '0; + mask_tile_y_pos_q <= '0; + mask_col_offset_q <= '0; + mask_q <= '0; + end else begin + if (calc_en_i) begin + mask_pos_q <= mask_pos_d; + mask_tile_x_pos_q <= mask_tile_x_pos_d; + mask_tile_y_pos_q <= mask_tile_y_pos_d; + end + mask_col_offset_q <= mask_col_offset_d; + mask_q <= mask_d; + end + end + +endmodule diff --git a/src/ita_package.sv b/src/ita_package.sv index c20ef71..2d3e0fe 100644 --- a/src/ita_package.sv +++ b/src/ita_package.sv @@ -40,18 +40,36 @@ package ita_package; typedef logic signed [GELU_CONSTANTS_WIDTH-1:0] gelu_const_t; typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t; + // Masking + typedef enum {None=0, + UpperTriangular=1, + LowerTriangular=2, + Strided=3, + UpperStrided=4, + LowerStrided=5, + SlidingWindow=6, + StridedSlidingWindow=7} mask_e; + typedef logic [WO-WI*2-2:0] mask_index_t; + // IO typedef logic [EMS-1:0] requant_const_t; typedef logic [N_REQUANT_CONSTS-1:0][EMS-1:0] requant_const_array_t; typedef logic signed [WI-1:0] requant_t; typedef logic signed [N_REQUANT_CONSTS-1:0][WI-1:0] requant_array_t; - typedef logic [idx_width(S+1)-1:0] seq_length_t; - typedef logic [idx_width(P+1)-1:0] proj_space_t; - typedef logic [idx_width(E+1)-1:0] embed_size_t; - typedef logic [idx_width(H+1)-1:0] n_heads_t; + typedef logic [WO-WI*2-1:0] input_dim_t; + typedef input_dim_t seq_length_t; + typedef input_dim_t proj_space_t; + typedef input_dim_t embed_size_t; + typedef input_dim_t ff_size_t; typedef logic [ 32-1:0] tile_t; typedef struct packed { logic start ; + seq_length_t seq_length ; + proj_space_t proj_space ; + embed_size_t embed_size ; + ff_size_t ff_size ; + mask_e mask_type ; + mask_index_t mask_start_index; layer_e layer ; activation_e activation ; requant_const_array_t eps_mult ; @@ -96,7 +114,7 @@ package ita_package; // Softmax localparam int unsigned SoftmaxScalar = 65280; // (2**8-1) * 2**8 - localparam int unsigned SoftmaxShift = 0; + localparam int unsigned SoftmaxShift = 5; localparam int unsigned SoftmaxAccDataWidth = 19; // Up to S = 2048 localparam int unsigned SoftFifoDepth = 12; typedef logic [idx_width(SoftFifoDepth)-1:0] soft_fifo_usage_t; diff --git a/src/ita_requantization_controller.sv b/src/ita_requantization_controller.sv index 3b6c865..e844358 100644 --- a/src/ita_requantization_controller.sv +++ b/src/ita_requantization_controller.sv @@ -33,6 +33,12 @@ module ita_requatization_controller endcase end + // always_comb begin + // requant_mult = ctrl_i.eps_mult[step_q4]; + // requant_shift = ctrl_i.right_shift[step_q4]; + // requant_add = ctrl_i.add[step]; + // end + assign requant_mult_o = ctrl_i.eps_mult[constant_idx]; assign requant_shift_o = ctrl_i.right_shift[constant_idx]; assign requant_add_o = ctrl_i.add[constant_idx]; diff --git a/src/ita_requantizer.sv b/src/ita_requantizer.sv index 6033c09..c67b97c 100644 --- a/src/ita_requantizer.sv +++ b/src/ita_requantizer.sv @@ -22,7 +22,8 @@ module ita_requantizer logic signed [EMS+WO:0] product ; logic signed [EMS+WO:0] shifted_added; logic signed [ N-1:0][EMS+WO-1:0] shifted_d, shifted_q; - requant_oup_t add_q1, requant_oup_d, requant_oup_q; + requant_oup_t add_q1, add_q2, add_q3, add_q4; + requant_oup_t requant_oup_d, requant_oup_q; assign requant_oup_o = requant_oup_q; @@ -49,7 +50,7 @@ module ita_requantizer end end if (calc_en_q_i) begin - shifted_added = shifted_q[i] + (EMS+WO)'(signed'(add_q1[i])); + shifted_added = shifted_q[i] + (EMS+WO)'(signed'(add_q4[i])); requant_oup_d[i] = shifted_added[WI-1:0]; if (~shifted_added[EMS+WO-1] & (|(shifted_added[EMS+WO-2:WI-1]))) begin requant_oup_d[i] = '1; @@ -76,8 +77,14 @@ module ita_requantizer always_ff @(posedge clk_i, negedge rst_ni) begin if (!rst_ni) begin add_q1 <= '0; + add_q2 <= '0; + add_q3 <= '0; + add_q4 <= '0; end else begin add_q1 <= add_i; + add_q2 <= add_q1; + add_q3 <= add_q2; + add_q4 <= add_q3; end end endmodule diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv index 2eb5255..1de77a8 100644 --- a/src/ita_softmax.sv +++ b/src/ita_softmax.sv @@ -39,14 +39,22 @@ module ita_softmax input requant_t [1:0] read_max_data_i, output logic write_max_en_o, output logic [InputAddrWidth-1:0] write_max_addr_o, - output requant_t write_max_data_o + output requant_t write_max_data_o, + input counter_t tile_x_i, + input counter_t tile_y_i, + input counter_t inner_tile_i, + input logic [N-1:0] mask_i ); counter_t tile_d, tile_q1, tile_q2, tile_q3, tile_q4; counter_t count_d, count_q1, count_q2, count_q3, count_q4; + counter_t inner_tile_q; + counter_t tile_x_q, tile_y_q; + counter_t mask_tile_x_d, mask_tile_x_q, mask_tile_y_d, mask_tile_y_q; + counter_t mask_tile_outer_dim_d, mask_tile_outer_dim_q; logic unsigned [SoftmaxAccDataWidth-1:0] exp_sum_d, exp_sum_q; - counter_t count_soft_d, count_soft_q; + counter_t count_soft_d, count_soft_q1, count_soft_q2, count_soft_mask_q; counter_t count_div_d, count_div_q, addr_div_d, addr_div_q; logic [NumDiv-1:0] div_read_d, div_read_q, div_write_d, div_write_q; @@ -69,13 +77,19 @@ module ita_softmax logic [SoftmaxAccDataWidth-1:0] data_to_fifo, data_from_fifo; soft_fifo_usage_t fifo_usage ; + logic [N-1:0] disable_shift; + logic disable_row; + logic [M-1:0]disable_col; + + assign disable_row = ((count_soft_q2 & (M-1)) + tile_y_q * M) > (ctrl_i.seq_length - 1); + assign pop_softmax_fifo_o = pop_from_fifo; assign soft_addr_div_o = addr_div_q; always_comb begin tile_d = tile_q1; count_d = count_q1; - count_soft_d = count_soft_q; + count_soft_d = count_soft_q1; count_div_d = count_div_q; div_read_d = div_read_q; div_write_d = div_write_q; @@ -109,6 +123,10 @@ module ita_softmax shift_inp_diff = '0; inp_stream_soft_o = '0; softmax_done_o = 0; + mask_tile_x_d = mask_tile_x_q; + mask_tile_y_d = mask_tile_y_q; + mask_tile_outer_dim_d = mask_tile_outer_dim_q; + //************ Accumulation ************// case (step_i) @@ -135,13 +153,20 @@ module ita_softmax //************ Pipeline Stage 1 ************// if (calc_en_q1) begin // Find max and accumulate - max_o = requant_oup_q; max_d = max_i; for (int i = 0; i < N; i++) begin shift_diff[i] = max_i - requant_oup_q[i]; - shift_d[i] = unsigned'(shift_diff[i]) >> SoftmaxShift; - if (SoftmaxShift != 0 && shift_diff[i][SoftmaxShift-1]) - shift_d[i] = (unsigned'(shift_diff[i]) >> SoftmaxShift) + 1; + disable_shift[i] = ((tile_q2*M+N*(count_q2 >> $clog2(M))+i ) >= ctrl_i.seq_length); + + if (disable_shift[i] || mask_i[i]) begin + max_o[i] = 8'h80; + shift_d[i] = 4'hF; + end else begin + max_o[i] = requant_oup_q[i]; + shift_d[i] = unsigned'(shift_diff[i]) >> SoftmaxShift; + if (SoftmaxShift != 0 && shift_diff[i][SoftmaxShift-1]) + shift_d[i] = (unsigned'(shift_diff[i]) >> SoftmaxShift) + 1; + end end if (tile_q2 != '0 || count_q2>=M) begin // If not first part of the first row, normalize previous sum read_acc_en_o[0] = 1; @@ -162,10 +187,12 @@ module ita_softmax write_max_addr_o = count_q3; write_max_data_o = max_q; for (int i = 0; i < N; i++) begin - exp_sum_d += unsigned'(9'h100)>>shift_q[i]; + //Marcel Kant: This if statement is most likely not required + if (shift_q[i] != 4'hF) + exp_sum_d += unsigned'(9'h100)>>shift_q[i]; end if (tile_q3 != '0 || count_q3>=M) begin // If not first part of the first row - exp_sum_d += ( unsigned'(read_acc_data_i[0]) >> shift_sum_q); + exp_sum_d += (unsigned'(read_acc_data_i[0]) >> shift_sum_q); end end @@ -211,28 +238,172 @@ module ita_softmax //*********** Stream Softmax ***********// // Main controller checks if division is ready if (calc_stream_soft_en_i) begin - count_soft_d = count_soft_q + 1; + count_soft_d = count_soft_q1 + 1; read_acc_en_o[1] = 1; - read_acc_addr_o[1] = count_soft_q[5:0]; + read_acc_addr_o[1] = count_soft_q1[5:0]; read_max_en_o[1] = 1; - read_max_addr_o[1] = count_soft_q[5:0]; + read_max_addr_o[1] = count_soft_q1[5:0]; if (count_soft_d == M*M/N) begin count_soft_d = '0; end end if (calc_stream_soft_en_q) begin - for (int i = 0; i < M; i++) begin - shift_inp_diff[i] = read_max_data_i[1]-inp_i[i]; - shift_inp[i] = unsigned'(shift_inp_diff[i]) >> SoftmaxShift; - if (SoftmaxShift != 0 && shift_inp_diff[i][SoftmaxShift-1]) - shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> SoftmaxShift) + 1; - inp_stream_soft_o[i] = read_acc_data_i[1] >> shift_inp[i]; + if (count_soft_mask_q == (((M*M)/N)-1)) begin + if (mask_tile_x_q == (ctrl_i.tile_s - 1)) begin + mask_tile_x_d = '0; + mask_tile_outer_dim_d = mask_tile_outer_dim_q + 1; + if (mask_tile_outer_dim_q == (ctrl_i.tile_p - 1)) begin + mask_tile_outer_dim_d = '0; + mask_tile_y_d = mask_tile_y_q + 1; + end + end else begin + mask_tile_x_d = mask_tile_x_q + 1; + end + if (mask_tile_y_q == ctrl_i.tile_s) begin + mask_tile_outer_dim_d = '0; + mask_tile_x_d = '0; + mask_tile_y_d = '0; + end + end + + if (disable_row) begin + inp_stream_soft_o = { M { '0 } }; + end else begin + for (int i = 0; i < M; i++) begin + if ((inner_tile_q*M + i) >= ctrl_i.seq_length) begin + disable_col[i] = 1'b1; + end else begin + case (ctrl_i.mask_type) + None: begin + disable_col[i] = 1'b0; + end + UpperTriangular: begin + // (ctrl_i.mask_start_index / M) -> tile where masking starts + if (mask_tile_x_q == mask_tile_y_q + (ctrl_i.mask_start_index / M)) begin + if (i >= ((count_soft_mask_q & (M-1)) + (ctrl_i.mask_start_index & (M-1)))) begin + disable_col[i] = 1'b1; + end else begin + disable_col[i] = 1'b0; + end + end else if (mask_tile_x_q == ((ctrl_i.mask_start_index / M) + 1'b1 + mask_tile_y_q)) begin + if (i < signed'((count_soft_mask_q & (M-1)) - (M - (ctrl_i.mask_start_index & (M-1))))) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end else if (mask_tile_x_q > ((ctrl_i.mask_start_index / M) + 1'b1 + mask_tile_y_q)) begin + disable_col[i] = 1'b1; + end else begin + disable_col[i] = 1'b0; + end + end + LowerTriangular: begin + if (mask_tile_y_q == mask_tile_x_q + (ctrl_i.mask_start_index / M)) begin + if (i <= signed'((count_soft_mask_q & (M-1)) - (ctrl_i.mask_start_index & (M-1)))) begin + disable_col[i] = 1'b1; + end else begin + disable_col[i] = 1'b0; + end + end else if (mask_tile_y_q == ((ctrl_i.mask_start_index / M) + 1'b1 + mask_tile_x_q)) begin + if (i <= ((count_soft_mask_q & (M-1)) + (M - (ctrl_i.mask_start_index & (M-1))))) begin + disable_col[i] = 1'b1; + end else begin + disable_col[i] = 1'b0; + end + end else if (mask_tile_y_q > ((ctrl_i.mask_start_index / M) + 1'b1 + mask_tile_x_q)) begin + disable_col[i] = 1'b1; + end else begin + disable_col[i] = 1'b0; + end + end + Strided: begin + //col_pos = i + mask_tile_x_q * M + //row_pos = count_soft_mask_q & (M-1) + mask_tile_y_q * M + if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end + UpperStrided: begin + if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0 && + ((i + (mask_tile_x_q * M)) >= ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)))) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end + LowerStrided: begin + if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0 && + ((i + (mask_tile_x_q * M)) <= ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)))) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end + SlidingWindow: begin + if (((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)) < ctrl_i.mask_start_index) begin + if ((i + (mask_tile_x_q * M)) < ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M) + ctrl_i.mask_start_index)) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end else begin + if ((((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M) - (ctrl_i.mask_start_index-1)) <= (i + (mask_tile_x_q * M))) && + ((i + (mask_tile_x_q * M)) < ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M) + ctrl_i.mask_start_index))) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end + end + StridedSlidingWindow: begin + //Strided logic + if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin + disable_col[i] = 1'b0; + end else begin + //Sliding window logic + if (((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)) < ctrl_i.mask_start_index) begin + if ((i + (mask_tile_x_q * M)) < ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M) + ctrl_i.mask_start_index)) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end else begin + if ((((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M) - (ctrl_i.mask_start_index-1)) <= (i + (mask_tile_x_q * M))) && + ((i + (mask_tile_x_q * M)) < ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M) + ctrl_i.mask_start_index))) begin + disable_col[i] = 1'b0; + end else begin + disable_col[i] = 1'b1; + end + end + end + end + endcase + end + + if (disable_col[i]) begin + inp_stream_soft_o[i] = '0; + end else begin + shift_inp_diff[i] = read_max_data_i[1]-inp_i[i]; + shift_inp[i] = unsigned'(shift_inp_diff[i]) >> SoftmaxShift; + if (SoftmaxShift != 0 && shift_inp_diff[i][SoftmaxShift-1]) + shift_inp[i] = (unsigned'(shift_inp_diff[i]) >> SoftmaxShift) + 1; + inp_stream_soft_o[i] = read_acc_data_i[1] >> shift_inp[i]; + end + end end end end always_ff @(posedge clk_i or negedge rst_ni) begin if(~rst_ni) begin + inner_tile_q <= '0; + tile_x_q <= '0; + tile_y_q <= '0; + mask_tile_x_q <= '0; + mask_tile_y_q <= '0; + mask_tile_outer_dim_q <= '0; tile_q4 <= '0; tile_q3 <= '0; tile_q2 <= '0; @@ -240,8 +411,10 @@ module ita_softmax count_q4 <= M*M/N; count_q3 <= M*M/N; count_q2 <= M*M/N; - count_q1 <= M*M/N; - count_soft_q <= '0; + count_q1 <= M*M/N; + count_soft_q1 <= '0; + count_soft_q2 <= '0; + count_soft_mask_q <= '0; count_div_q <= '0; div_read_q <= '0; div_write_q <= '0; @@ -253,6 +426,9 @@ module ita_softmax shift_q <= '0; shift_sum_q <= '0; end else begin + inner_tile_q <= inner_tile_i; + tile_x_q <= tile_x_i; + tile_y_q <= tile_y_i; tile_q4 <= tile_q3; tile_q3 <= tile_q2; tile_q2 <= tile_q1; @@ -261,7 +437,14 @@ module ita_softmax count_q3 <= count_q2; count_q2 <= count_q1; count_q1 <= count_d; - count_soft_q <= count_soft_d; + count_soft_q1 <= count_soft_d; + count_soft_q2 <= count_soft_q1; + if (calc_stream_soft_en_i) begin + count_soft_mask_q <= count_soft_q1; + end + mask_tile_x_q <= mask_tile_x_d; + mask_tile_y_q <= mask_tile_y_d; + mask_tile_outer_dim_q <= mask_tile_outer_dim_d; count_div_q <= count_div_d; div_read_q <= div_read_d; div_write_q <= div_write_d; diff --git a/src/ita_softmax_top.sv b/src/ita_softmax_top.sv index e44fb4f..56a6078 100644 --- a/src/ita_softmax_top.sv +++ b/src/ita_softmax_top.sv @@ -19,7 +19,12 @@ module ita_softmax_top output counter_t soft_addr_div_o , output logic softmax_done_o , output logic pop_softmax_fifo_o , - output inp_t inp_stream_soft_o + output inp_t inp_stream_soft_o , + input counter_t tile_x_i , + input counter_t tile_y_i , + input counter_t inner_tile_i , + input logic [N-1:0] mask_i + ); logic [1:0] read_acc_en; @@ -34,7 +39,7 @@ module ita_softmax_top logic unsigned [ NumDiv-1:0][DividerWidth-1:0] div_oup ; logic unsigned [ DividerWidth-1:0] val ; - requant_oup_t max_in ; + requant_oup_t max_in; requant_t prev_max, max_out; ita_max_finder i_max_finder ( @@ -113,7 +118,12 @@ module ita_softmax_top .write_max_en_o (write_max_en ), .write_max_addr_o (write_max_addr ), - .write_max_data_o (write_max_data ) + .write_max_data_o (write_max_data ), + + .tile_x_i (tile_x_i ), + .tile_y_i (tile_y_i ), + .inner_tile_i (inner_tile_i ), + .mask_i (mask_i ) ); ita_register_file_1w_multi_port_read #( diff --git a/src/tb/ita_tb.sv b/src/tb/ita_tb.sv index e8f84a6..1fdddec 100644 --- a/src/tb/ita_tb.sv +++ b/src/tb/ita_tb.sv @@ -46,6 +46,8 @@ module ita_tb; integer N_TILES_INNER_DIM_LINEAR_PROJECTION[N_PHASES]; integer N_ATTENTION_TILE_ROWS, N_GROUPS; activation_e ACTIVATION; + mask_e MASK; + integer MASK_START_INDEX; // Signals logic clk, rst_n; @@ -76,6 +78,8 @@ module ita_tb; EMBEDDING_SIZE = `ifdef EMBED_SIZE `EMBED_SIZE `else M_TILE_LEN `endif; FEEDFORWARD_SIZE = `ifdef FF_SIZE `FF_SIZE `else M_TILE_LEN `endif; ACTIVATION = activation_e'(`ifdef ACTIVATION `ACTIVATION `else Identity `endif); + MASK = mask_e'(`ifdef MASK `MASK `else None `endif); + MASK_START_INDEX = `ifdef MASK_INDEX `MASK_INDEX `else 1 `endif; simdir = { "../../simvectors/data_S", @@ -89,11 +93,17 @@ module ita_tb; "_H1_B", $sformatf("%0d", `ifdef BIAS `BIAS `else 0 `endif), "_", - $sformatf( "%s", ACTIVATION) + $sformatf("%s", ACTIVATION), + "_", + $sformatf("%s", MASK), + "_I", + $sformatf("%0d", MASK_START_INDEX) }; - N_TILES_SEQUENCE_DIM = SEQUENCE_LEN / M_TILE_LEN; - N_TILES_EMBEDDING_DIM = EMBEDDING_SIZE / M_TILE_LEN; - N_TILES_PROJECTION_DIM = PROJECTION_SPACE / M_TILE_LEN; + // Round up + N_TILES_SEQUENCE_DIM = (SEQUENCE_LEN + M_TILE_LEN -1 ) / M_TILE_LEN; + N_TILES_EMBEDDING_DIM = (EMBEDDING_SIZE+ M_TILE_LEN -1 ) / M_TILE_LEN; + N_TILES_PROJECTION_DIM = (PROJECTION_SPACE + M_TILE_LEN -1 ) / M_TILE_LEN; + N_TILES_FEEDFORWARD = (FEEDFORWARD_SIZE + M_TILE_LEN -1) / M_TILE_LEN; N_TILES_LINEAR_PROJECTION = N_TILES_SEQUENCE_DIM * N_TILES_EMBEDDING_DIM * N_TILES_PROJECTION_DIM; N_TILES_ATTENTION = N_TILES_SEQUENCE_DIM * N_TILES_PROJECTION_DIM; N_ENTRIES_PER_TILE = M_TILE_LEN * M_TILE_LEN / N_PE; @@ -103,7 +113,6 @@ module ita_tb; N_ENTRIES_PER_SEQUENCE_DIM = N_ENTRIES_PER_TILE * N_TILES_SEQUENCE_DIM; N_ATTENTION_TILE_ROWS = N_TILES_SEQUENCE_DIM; N_GROUPS = 2 * N_ATTENTION_TILE_ROWS; - N_TILES_FEEDFORWARD = FEEDFORWARD_SIZE / M_TILE_LEN; N_TILES_INNER_DIM_LINEAR_PROJECTION[0] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[1] = N_TILES_EMBEDDING_DIM; N_TILES_INNER_DIM_LINEAR_PROJECTION[2] = N_TILES_EMBEDDING_DIM; @@ -489,6 +498,12 @@ task automatic apply_ITA_weights(input integer phase); ita_ctrl.tile_p = N_TILES_PROJECTION_DIM; ita_ctrl.tile_s = N_TILES_SEQUENCE_DIM; ita_ctrl.tile_f = N_TILES_FEEDFORWARD; + ita_ctrl.seq_length = SEQUENCE_LEN; + ita_ctrl.proj_space = PROJECTION_SPACE; + ita_ctrl.embed_size = EMBEDDING_SIZE; + ita_ctrl.ff_size = FEEDFORWARD_SIZE; + ita_ctrl.mask_type = MASK; + ita_ctrl.mask_start_index = MASK_START_INDEX; read_activation_constants(ita_ctrl.gelu_b, ita_ctrl.gelu_c, ita_ctrl.activation_requant_mult, ita_ctrl.activation_requant_shift, ita_ctrl.activation_requant_add); diff --git a/testGenerator.py b/testGenerator.py index 0c94a55..c26a2a4 100644 --- a/testGenerator.py +++ b/testGenerator.py @@ -48,7 +48,11 @@ def generateMHA(**args): NO_BIAS = args['no_bias'] NO_PARTIAL_SOFTMAX = args['no_partial_softmax'] ACTIVATION = args['activation'].capitalize() - base_path = f'{current_dir}/simvectors/data_S{S}_E{E}_P{P}_F{F}_H{H}_B{int(not NO_BIAS)}_{ACTIVATION}' + features = args['mask'].split('_') + seperator = "" + MASK = seperator.join(feature.capitalize() for feature in features) + INDEX = args['I'] + base_path = f'{current_dir}/simvectors/data_S{S}_E{E}_P{P}_F{F}_H{H}_B{int(not NO_BIAS)}_{ACTIVATION}_{MASK}_I{INDEX}' if NO_PARTIAL_SOFTMAX: path = f'{base_path}_noPartialSoftmax/' @@ -102,6 +106,19 @@ class ArgumentDefaultMetavarTypeFormatter(argparse.ArgumentDefaultsHelpFormatter type = str, help = 'Activation function', choices = ['gelu', 'relu', 'identity']) + self.group1.add_argument('--mask', + default = 'none', + type = str, + help = 'Attention-Mask', + choices = ['none', + 'upper_triangular', + 'lower_triangular', + 'strided', + 'upper_strided', + 'lower_strided', + 'sliding_window', + 'strided_sliding_window']) + self.group1.add_argument('-I', default = 1, type = int, help = 'Masking starting index') self.group1.add_argument('--no-partial-softmax', action = 'store_true', help = 'Disable partial softmax calculation') diff --git a/tests/masking_test.sh b/tests/masking_test.sh new file mode 100755 index 0000000..84aeef6 --- /dev/null +++ b/tests/masking_test.sh @@ -0,0 +1,162 @@ +#!/bin/bash + +# Copyright 2023 ETH Zurich and University +# of Bologna. Licensed under the Apache License, +# Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 + +echo "Testing ITA ..." + +# Set the log file +log_file=tests/logs/run_loop_$(date +%Y%m%d%H%M%S).log + +# Create folder and log file +mkdir -p tests/logs +touch $log_file + +# Activate the virtual environment +source venv/bin/activate + +# Set the simulation path +export buildpath=build +export SIM_PATH=modelsim/$buildpath + +if [ -z "$target" ]; then + no_stalls=0 + echo "Target not set. Using default value: $target" +fi + +# Set to -gui to use the GUI of QuestaSim +export vsim_flags=-c + +# Set the no_stalls if not set +if [ -z "$no_stalls" ]; then + no_stalls=0 + echo "No_stalls not set. Using default value: $no_stalls" +fi + +# Set the n_tests if not set +if [ -z "$n_tests" ]; then + n_tests=250 + echo "Granularity not set. Using default value: $n_tests" +fi + +# Log the parameters +echo "no_stalls=$no_stalls" >> $log_file +echo "n_tests=$n_tests" >> $log_file + +# List of masking names +masking_names=("upper_triangular" "lower_triangular" "strided" + "upper_strided" "lower_strided" + "sliding_window" "strided_sliding_window") + +# List of activation names +activation_names=("identity" "relu" "gelu") + +# Helper function: checks if a mask is one of the strided ones +is_strided_mask() { + case "$1" in + "strided"|"upper_strided"|"lower_strided"|"strided_sliding_window") + return 0 # True + ;; + *) + return 1 # False + ;; + esac +} + +# Helper function: returns all powers of two < s +# (2, 4, 8, 16, ...), stored in an array +powers_of_two_less_than_s() { + local limit=$1 + local val=1 + local results=() + + # If you also want to allow i=1 (which is 2^0), + # set val=1 and do while [ $val -lt $limit ] + # If you need strictly 2,4,8..., set val=2. + val=2 + while [ $val -lt $limit ]; do + results+=($val) + val=$((val*2)) + done + + echo "${results[@]}" +} + +# Run the tests +for test_idx in $(seq 1 $n_tests); do + # Randomly pick s, e, p, f in [2..512] + s=$((2 + RANDOM % 511)) + e=$((1 + RANDOM % 511)) + p=$((1 + RANDOM % 511)) + f=$((1 + RANDOM % 511)) + + # Pick one random masking + random_mask_idx=$((RANDOM % ${#masking_names[@]})) + masking=${masking_names[$random_mask_idx]} + + # Pick one random activation + random_activation_idx=$((RANDOM % ${#activation_names[@]})) + activation=${activation_names[$random_activation_idx]} + + # Pick one random bias (0 or 1) + bias=$((RANDOM % 2)) + + # Decide how to pick i based on whether masking is strided + if is_strided_mask "$masking"; then + # 1) We need i that is < s and also a power of two + valid_i_list=( $(powers_of_two_less_than_s $s) ) + + # If no valid i found, skip this iteration + if [ ${#valid_i_list[@]} -eq 0 ]; then + echo "No valid i for mask=$masking with s=$s (need i < s and i a power of two). Skipping..." + continue + fi + + # Pick a random valid i from the list + i=${valid_i_list[$((RANDOM % ${#valid_i_list[@]}))]} + else + # 2) Non-strided masks: pick i in [1 .. s-1] + if [ "$s" -le 1 ]; then + echo "No valid i for mask=$masking with s=$s (need i < s). Skipping..." + continue + fi + i=$((1 + (RANDOM % (s-1)))) + fi + + echo "Index is: $i (Masking = $masking, s=$s)" + + # Create test vectors (no-bias and bias) + if [ "$bias" -eq 1 ]; then + python testGenerator.py -H 1 -S $s -P $p -E $e -F $f \ + --activation "$activation" --mask "$masking" -I "$i" + else + python testGenerator.py -H 1 -S $s -P $p -E $e -F $f \ + --activation "$activation" --mask "$masking" -I "$i" --no-bias + fi + + # Log the test + echo "Testing ita_tb: S=$s E=$e P=$p F=$f Activation=$activation Masking=$masking I=$i Bias=$bias" >> $log_file + + # Run the test + make sim VSIM_FLAGS=$vsim_flags DEBUG=OFF target=sim_$target \ + no_stalls=$no_stalls s=$s e=$e p=$p f=$f bias=$bias \ + activation=$activation mask=$masking i=$i + + # Check the simulation status + ./modelsim/return_status.sh "${SIM_PATH}/transcript" \ + "$s" "$e" "$p" "$f" ita_tb "$masking" "$i" >> $log_file + + # Format masking for directory name (e.g. "upper_strided" -> "UpperStrided") + formatted_masking="" + for word in ${masking//_/ }; do + formatted_masking+="${word^}" + done + + # echo "simvectors/data_S${s}_E${e}_P${p}_F${f}_H1_B${bias}_${activation^}_${formatted_masking}_I${i}" >> $log_file + + # Remove the test vectors + rm -rf simvectors/data_S${s}_E${e}_P${p}_F${f}_H1_B${bias}_${activation^}_${formatted_masking}_I${i} + +done