diff --git a/src/neuronx_distributed_inference/utils/hf_adapter.py b/src/neuronx_distributed_inference/utils/hf_adapter.py index 34a2b852..6b81b5b4 100644 --- a/src/neuronx_distributed_inference/utils/hf_adapter.py +++ b/src/neuronx_distributed_inference/utils/hf_adapter.py @@ -278,6 +278,7 @@ def prepare_inputs_for_generation( scatter_index = kwargs.get("scatter_index", None) position_ids = kwargs.get("position_ids", None) input_capture_hook = kwargs.get("input_capture_hook", None) + tensor_capture_hook = kwargs.get("tensor_capture_hook", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation