166 changes: 165 additions & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -444,7 +444,11 @@ def __init__(
self._set_tokenizer_params() # set tokenizer params
# Skip inputs/outputs
self._session.skip_buffers(
[x for x in self._session.input_names + self._session.output_names if x.startswith("past_")]
[
x
for x in self._session.input_names + self._session.output_names
if x.startswith("past_") or x.endswith("_RetainedState")
]
)

def _set_tokenizer_params(self):
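
The hunk above widens the set of session buffers skipped during I/O binding: besides the legacy "past_" KV-cache names, any buffer whose name ends in "_RetainedState" (for example, reused vision embeddings) is now skipped as well. A standalone sketch of the same filter, with invented buffer names purely for illustration:

    # Sketch only: buffer names are invented, not the real model's I/O names.
    input_names = ["input_ids", "position_ids", "past_key.0"]
    output_names = ["logits", "past_key.0_RetainedState", "vision_embeds_RetainedState"]
    skipped = [
        name
        for name in input_names + output_names
        if name.startswith("past_") or name.endswith("_RetainedState")
    ]
    # skipped == ['past_key.0', 'past_key.0_RetainedState', 'vision_embeds_RetainedState']
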
@@ -822,6 +826,166 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

return decode_pause_time

def run_vision_language_continuous_batching_decode(
self, prompt_queue, generation_len, shared_vision_embeddings=None
):
"""
Runs continuous batching decode for vision language models with shared vision embeddings.

Method sets up the initial conditions for decoding and preparing the decode inputs. Then enters a loop that continues as long as there are prompts in the queue or any decoding is ongoing. In each iteration of the loop, it runs the session with the current decode inputs, prepares the inputs for the next iteration and updates the decode inputs. If a prompt has been fully decoded, it runs prefill for the next prompt in the queue if available.

Args:
prompt_queue (deque): The queue of prompts to be decoded.
generation_len (int): The generation length.
shared_vision_embeddings (np.array, optional): Shared vision embeddings for vision-language models. Defaults to None.

"""
# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
)
self._session.set_buffers({"logits": logits_out_placeholder})

# Set shared vision embeddings if provided
if shared_vision_embeddings is not None:
self._session.set_buffers(shared_vision_embeddings)

# Generate flag for tracking progress for each batch ID
current_decode_ongoing = np.full((self.full_batch_size, 1), True)

# Generate an array for maintaining the tokens generated in each batch ID
generated_id_current_index = np.ones((self.full_batch_size, 1), np.int64)

# Generate a batch ID map for mapping the batch ID if input > full_batch_size.
# This ID map will be used for storing all generated tokens
batch_id_map = {i: i for i in range(self.full_batch_size)}
decode_pause_time = 0

# Prepare decode inputs.
decode_inputs = self.prepare_decode_inputs()

while prompt_queue or current_decode_ongoing.any():
outputs = self._session.run(decode_inputs)

# Prepare inputs for next iteration
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

for decode_batch_id in range(self.full_batch_size):
if (
next_token_id[decode_batch_id, -1] == self.tokenizer.eos_token_id
or generated_id_current_index[decode_batch_id] >= self.generation_len[decode_batch_id]
):
if prompt_queue:
start = perf_counter()
# run prefill for next prompt input.
outputs, position_ids, generation_len = self.run_vision_language_prefill(
prompt_queue.popleft(),
generation_len,
decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1),
shared_vision_embeddings=shared_vision_embeddings,
)

new_token_id = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

batch_id_map[decode_batch_id] = max(batch_id_map.values()) + 1
self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
generated_id_current_index[decode_batch_id] = 1

self._session.set_buffers({"logits": logits_out_placeholder})

# Re-set shared vision embeddings so decode continues to reuse them after prefill
if shared_vision_embeddings is not None:
self._session.set_buffers(shared_vision_embeddings)

decode_pause_time += perf_counter() - start

if self._prompt_to_lora_id_mapping_decode:
decode_inputs["lora_ids"][decode_batch_id] = self._prompt_to_lora_id_mapping_decode[
batch_id_map[decode_batch_id]
]

else:
current_decode_ongoing[decode_batch_id] = False
else:
# If the generated sequence is valid and within generation len prepare for next decode
decode_inputs["input_ids"][decode_batch_id, -1] = next_token_id[decode_batch_id, -1]
decode_inputs["position_ids"][decode_batch_id, -1] += 1
self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
next_token_id[decode_batch_id, -1]
)

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time
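
The decode loop above reuses a fixed number of decode slots (full_batch_size) for an arbitrary number of prompts; batch_id_map records which global prompt each slot currently serves, so generated tokens land in the correct row of generated_ids. A minimal standalone sketch of that bookkeeping (slot and prompt counts are invented):

    # Sketch only: 2 decode slots serving 4 prompts.
    full_batch_size = 2
    batch_id_map = {i: i for i in range(full_batch_size)}  # slot -> global prompt index

    def reassign_slot(slot):
        # A finished slot takes over the next unseen global prompt index.
        batch_id_map[slot] = max(batch_id_map.values()) + 1
        return batch_id_map[slot]

    reassign_slot(0)  # slot 0 now writes tokens for prompt 2
    reassign_slot(1)  # slot 1 now writes tokens for prompt 3
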

def run_vision_language_prefill(
self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None, shared_vision_embeddings=None
):
"""
Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length.
Args:
prompt (str): The prompt for which to run prefill.
generation_len (int): Max allowed length for generating tokens. The decoding process will be terminated when generation length is reached.
decode_batch_id (np.ndarray, optional): The decode batch ID for continuous batching. Defaults to None.
"""
# Run prefill
inputs = self.tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float
padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len

# Initialize variables specific to request
# Calculate the max generation length.
max_gen_len = self._ctx_len - position_ids.max()
generation_len = self._fetch_generation_len(generation_len, max_gen_len)

# Set the prefill logits placeholder buffer
logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
self._session.set_buffers({"logits": logits_out_placeholder})

# Set shared vision embeddings if provided
if shared_vision_embeddings is not None:
self._session.set_buffers(shared_vision_embeddings)

inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)

if decode_batch_id is not None:
inputs["batch_index"] = decode_batch_id
if self.is_tlm:
inputs["num_logits_to_keep"] = np.zeros((1, 1))

if self._prompt_to_lora_id_mapping_prefill:
if self.full_batch_size:
inputs["lora_ids"] = np.array(
self._prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64
).reshape(1, 1)
else:
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
]
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
]
outputs = self._session.run(chunk_inputs)
if self._write_io_dir is not None:
write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
return (
outputs,
position_ids,
generation_len,
)
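
Prefill above pads the tokenized prompt up to a multiple of the prefill sequence length and feeds it to the session one chunk at a time. A standalone sketch of the chunking arithmetic (lengths are example values, not tied to a real model):

    import numpy as np

    prefill_seq_len = 32
    prompt_len = 70                                    # example tokenized prompt length
    num_chunks = -(prompt_len // -prefill_seq_len)     # ceil division without floats -> 3
    padded_len = num_chunks * prefill_seq_len          # 96

    input_ids = np.zeros((1, padded_len), dtype=np.int64)
    position_ids = np.where(np.arange(padded_len) < prompt_len, np.arange(padded_len), -1)[None, :]

    for i in range(num_chunks):
        chunk_ids = input_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
        chunk_pos = position_ids[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
        # each (chunk_ids, chunk_pos) pair is what the prefill session consumes per chunk
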

def run_decode(self, decode_inputs, generation_len, streamer: Optional[transformers.TextStreamer] = None):
"""
Default method for running decode. Executes the decoding process for a given set of inputs and a specified generation length.
29 changes: 23 additions & 6 deletions QEfficient/transformers/cache_utils.py
@@ -443,6 +443,7 @@ def update(

else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch the batch index value from the kwargs
is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx]))

# Update the position_ids to handle the sliding window
@@ -460,10 +461,22 @@
valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
if batch_index is not None:
invalid_scatter_index = torch.iinfo(torch.int32).max
scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)

self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)

self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
@@ -483,8 +496,12 @@ def update(
final_indices = torch.where(
(is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
)
k_out = CtxGatherFunc.apply(k_out, final_indices)
v_out = CtxGatherFunc.apply(v_out, final_indices)
if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, final_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, final_indices)
else:
k_out = CtxGatherFunc.apply(k_out, final_indices)
v_out = CtxGatherFunc.apply(v_out, final_indices)
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out
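
The new branches route the KV-cache scatter and gather through CtxScatterFuncCB / CtxGatherFuncCB when a batch_index is present, after remapping padded positions (-1) to a large sentinel so they never overwrite valid cache entries. A standalone PyTorch sketch of that idea (shapes and the explicit clamp are illustrative only, not the custom ops themselves):

    import torch

    # Sketch only: a tiny per-slot cache, not the CtxScatterFuncCB/CtxGatherFuncCB kernels.
    ctx_len, head_dim = 8, 4
    cache = torch.zeros(2, ctx_len, head_dim)        # [decode slots, context, head_dim]
    new_states = torch.randn(3, head_dim)            # new K (or V) states for one slot
    position_ids = torch.tensor([5, 6, -1])          # -1 marks a padded position
    batch_index = 1                                  # continuous-batching slot to update

    sentinel = torch.iinfo(torch.int32).max
    scatter_pos = torch.where(position_ids < 0, torch.full_like(position_ids, sentinel), position_ids)
    valid = scatter_pos < ctx_len                    # sentinel writes are dropped here
    cache[batch_index, scatter_pos[valid]] = new_states[valid]

    gathered = cache[batch_index, torch.arange(ctx_len)]   # per-slot gather for attention
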