diff --git a/docs/kv_smash/hf_example.py b/docs/kv_smash/hf_example.py
new file mode 100644
index 00000000..8b9706b1
--- /dev/null
+++ b/docs/kv_smash/hf_example.py
@@ -0,0 +1,39 @@
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+from mellea.backends.types import ModelOption
+from mellea.stdlib.base import CBlock, LinearContext
+from mellea.stdlib.chat import Message
+
+ctx = LinearContext(window_size=100)
+ctx.insert(
+    CBlock(
+        "Nathan Fulton is a Senior Research Scientist at the MIT-IBM Watson AI Lab, a joint venture between MIT and IBM.",
+        cache=True,
+    )
+)
+ctx.insert(
+    CBlock(
+        "The MIT-IBM Watson AI Lab is located at 314 Main St, Cambridge, Massachusetts.",
+        cache=True,
+    )
+)
+ctx.insert(CBlock("The ZIP code for 314 Main St, Cambridge, Massachusetts is 02142"))
+
+
+msg = Message(
+    role="user", content="What is the likely ZIP code of Nathan Fulton's work address?"
+)
+backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
+result = backend._generate_from_context_with_kv_cache(
+    action=msg, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 1000}
+)
+print(f".{result}.")
+
+msg2 = Message(
+    role="user",
+    content="We know that Nathan does not work for a university. What is the likely name of Nathan's employer?",
+)
+result = backend._generate_from_context_with_kv_cache(
+    action=msg2, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 1000}
+)
+print(f".{result}.")
diff --git a/docs/kv_smash/kv_with_chat.py b/docs/kv_smash/kv_with_chat.py
new file mode 100644
index 00000000..f5a249c8
--- /dev/null
+++ b/docs/kv_smash/kv_with_chat.py
@@ -0,0 +1,110 @@
+import torch
+
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
+from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+from mellea.stdlib.base import CBlock, LinearContext
+from mellea.stdlib.chat import Message
+
+backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
+
+model = backend._model
+tokenizer = backend._tokenizer
+device = backend._device
+
+
+KV_CACHE: dict[str, DynamicCache] = dict()
+
+
+def cache(s: str, store=True) -> DynamicCache:
+    toks = tokenizer(s, return_tensors="pt")
+    dc = DynamicCache()
+    with torch.no_grad():
+        rv = model(
+            toks["input_ids"].to(device),
+            attention_mask=toks["attention_mask"].to(device),
+            past_key_values=dc,
+        ).past_key_values
+    if store:
+        KV_CACHE[s] = rv
+    return rv
+
+
+def merge(toks, dcs):
+    merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
+    merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
+    merged_dcs = merge_dynamic_caches(dcs)
+
+    return merged_toks, merged_masks, merged_dcs
+
+
+c_blocks = ["this is a test", "this is another test"]
+
+# Pretend this stuff already existed in the cache.
+for cb in c_blocks:
+    cache(cb)
+
+
+# Apply the chat template to a conversation that contains these strings, but without tokenization.
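+# NOTE (illustrative, not executed): the split-and-merge below assumes that
+# tokenizing each part separately yields the same ids as tokenizing the whole
+# templatized string. A quick way to spot-check that assumption for a given
+# tokenizer:
+#
+#   whole = tokenizer("prefix " + c_blocks[0], add_special_tokens=False)["input_ids"]
+#   parts = (
+#       tokenizer("prefix ", add_special_tokens=False)["input_ids"]
+#       + tokenizer(c_blocks[0], add_special_tokens=False)["input_ids"]
+#   )
+#   assert whole == parts  # may fail for tokenizers that merge across the boundary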
+
+messages = [
+    {"role": "user", "content": c_blocks[0]},
+    {"role": "user", "content": "Not cached"},
+    {"role": "user", "content": c_blocks[1]},
+    {"role": "user", "content": "Also not cached"},
+]
+templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False)
+
+str_parts = []
+tok_parts = []
+dc_parts = []
+
+current_suffix = templatized_input
+partially_cached_templatized_input: list[str | DynamicCache] = []  # not yet used
+for cb in c_blocks:
+    parts = current_suffix.split(cb)
+    assert len(parts) == 2
+    prefix, next_suffix = parts
+
+    if prefix != "":
+        # Add the prefix.
+        str_parts.append(prefix)
+        # Add the tokens and attention mask for the prefix.
+        tok_parts.append(tokenizer(prefix, return_tensors="pt"))
+        # Add the dynamic cache for the prefix.
+        dc_parts.append(cache(prefix, store=False))
+
+    # Add cb itself.
+    str_parts.append(cb)
+    tok_parts.append(tokenizer(cb, return_tensors="pt"))
+    dc_parts.append(KV_CACHE[cb])
+
+    # Set the current suffix.
+    current_suffix = next_suffix
+
+# REMEMBER: add the final suffix.
+if current_suffix != "":
+    str_parts.append(current_suffix)
+    tok_parts.append(tokenizer(current_suffix, return_tensors="pt"))
+    dc_parts.append(cache(current_suffix, store=False))
+
+# Merge everything together.
+merged_toks = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
+merged_masks = torch.cat([toks["attention_mask"] for toks in tok_parts], dim=1)
+merged_dcs = merge_dynamic_caches(dc_parts)
+
+# Crop the last cached position so generate() still has an input token to process.
+merged_dcs.crop(-1)
+
+# Generate and print the result.
+result = model.generate(
+    merged_toks.to(device),
+    attention_mask=merged_masks.to(device),
+    past_key_values=merged_dcs,
+    use_cache=True,
+    return_dict_in_generate=True,
+    output_scores=True,
+)
+
+result_decoded = tokenizer.decode(
+    result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
+)
+print(result_decoded)
diff --git a/docs/kv_smash/kvcache.py b/docs/kv_smash/kvcache.py
new file mode 100644
index 00000000..15b4d9b7
--- /dev/null
+++ b/docs/kv_smash/kvcache.py
@@ -0,0 +1,55 @@
+import torch
+
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
+from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+from mellea.stdlib.base import CBlock, LinearContext
+from mellea.stdlib.chat import Message
+
+backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
+
+model = backend._model
+tokenizer = backend._tokenizer
+device = backend._device
+
+
+def cache(toks) -> DynamicCache:
+    dc = DynamicCache()
+    with torch.no_grad():
+        rv = model(
+            toks["input_ids"].to(device),
+            attention_mask=toks["attention_mask"].to(device),
+            past_key_values=dc,
+        ).past_key_values
+    return rv
+
+
+def merge(strs: list[str]):
+    strs_toks = [tokenizer(x, return_tensors="pt") for x in strs]
+    strs_dcs = [cache(toks) for toks in strs_toks]
+
+    merged_toks = torch.cat([toks["input_ids"] for toks in strs_toks], dim=1)
+    merged_masks = torch.cat([toks["attention_mask"] for toks in strs_toks], dim=1)
+    merged_dcs = merge_dynamic_caches(strs_dcs)
+
+    return merged_toks, merged_masks, merged_dcs
+
+
+strs = ["this is a test", "this is another test"]
+
+merged_toks, merged_masks, merged_dcs = merge(strs)
+merged_dcs.crop(-1)
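+# crop(-1) above drops the last cached position; this leaves generate() at least
+# one uncached input token to prefill (a cache that already covers the whole
+# prompt can confuse generate's cache-position handling).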
+
+result = model.generate(
+    merged_toks.to(device),
+    attention_mask=merged_masks.to(device),
+    past_key_values=merged_dcs,
+    use_cache=True,
+    return_dict_in_generate=True,
+    output_scores=True,
+)
+
+result_decoded = tokenizer.decode(
+    result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
+)
+print(result_decoded)
diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index f09b4a04..5af34978 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -29,7 +29,7 @@
 )
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 
-from mellea.backends import BaseModelSubclass
+from mellea.backends import BaseModelSubclass, kv_block_helpers
 from mellea.backends.aloras import Alora, AloraBackendMixin
 from mellea.backends.cache import Cache, SimpleLRUCache
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
@@ -264,6 +264,301 @@ def _generate_from_context_alora(
         return alora_output
 
+    _cached_blocks: dict[str, DynamicCache] = dict()
+
+    def _make_dc_cache(self, toks, **model_options):
+        dc = DynamicCache()
+        with torch.no_grad():
+            dc = self._model(
+                toks["input_ids"].to(self._device),
+                attention_mask=toks["attention_mask"].to(self._device),
+                past_key_values=dc,
+                **model_options,
+            ).past_key_values
+        return dc
+
+    def _generate_from_context_with_kv_cache(  # noqa: C901
+        self,
+        action: Component | CBlock,
+        ctx: Context,
+        *,
+        format: type[BaseModelSubclass] | None = None,
+        model_options: dict[str, Any] = {},
+        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
+    ) -> ModelOutputThunk:
+        # Construct input.
+        # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
+        # Otherwise, we will linearize the context and treat it as a raw input.
+        decoded_result: str | None = None
+        if ctx.is_chat_context:
+            linearized_ctx = ctx.render_for_generation()
+
+            assert linearized_ctx is not None, (
+                "If ctx.is_chat_context, then the context should be linearizable."
+            )
+            ctx_as_message_list: list[Message] = self.formatter.to_chat_messages(
+                linearized_ctx
+            )
+            # Add the action.
+            ctx_as_message_list.extend(self.formatter.to_chat_messages([action]))
+
+            ctx_as_conversation = [
+                {"role": m.role, "content": m.content} for m in ctx_as_message_list
+            ]
+
+            # Check that we didn't accidentally end up with CBlocks.
+            for msg in ctx_as_conversation:
+                for v in msg.values():
+                    if "CBlock" in v:
+                        FancyLogger.get_logger().error(
+                            f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}"
+                        )
+
+            # Handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step.
+            system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
+            if system_prompt is not None:
+                system_msg: dict[str, str] = {
+                    "role": "system",
+                    "content": system_prompt,
+                }
+                ctx_as_conversation.insert(0, system_msg)
+
+            # Append tool call information if applicable.
+            tools: dict[str, Callable] = dict()
+            if tool_calls:
+                if format:
+                    FancyLogger.get_logger().warning(
+                        f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
+                    )
+                else:
+                    if isinstance(action, Component) and isinstance(
+                        action.format_for_llm(), TemplateRepresentation
+                    ):
+                        tools = get_tools_from_action(action)
+
+                    model_options_tools = model_options.get(ModelOption.TOOLS, None)
+                    if model_options_tools is not None:
+                        assert isinstance(model_options_tools, dict)
+                        for fn_name in model_options_tools:
+                            # Invariant re: the relationship between the model_options set of tools and the TemplateRepresentation set of tools.
+                            assert fn_name not in tools.keys(), (
+                                f"Cannot add tool {fn_name} because that tool was already defined in the TemplateRepresentation for the action."
+                            )
+                            # Type checking because ModelOptions is an untyped dict and the calling convention for tools isn't clearly documented at our abstraction boundaries.
+                            assert type(fn_name) is str, (
+                                "When providing a `ModelOption.TOOLS` parameter to `model_options`, always use the type Dict[str, Callable] where `str` is the function name and the callable is the function."
+                            )
+                            assert callable(model_options_tools[fn_name]), (
+                                "When providing a `ModelOption.TOOLS` parameter to `model_options`, always use the type Dict[str, Callable] where `str` is the function name and the callable is the function."
+                            )
+                            # Add the model_options tool to the existing set of tools.
+                            tools[fn_name] = model_options_tools[fn_name]
+
+            seed = model_options.get(ModelOption.SEED, None)
+            if seed is not None:
+                set_seed(seed)
+
+            # Plan for the KV-cache path below:
+            # 1. Cache every CBlock that is marked with `cache=True` and store it in _cached_blocks.
+            # 2. Mark each "hit" by adding the CBlock's string value to `cached_block_keys`.
+            # 3. Apply the chat template without tokenization.
+            # 4. Split the templatized text on cache hits.
+            # 5. Prefill the uncached parts and smash everything together.
+            # 6. Generate.
+
+            # 1. Cache every CBlock that is marked with `cache=True` and store it in _cached_blocks.
+            # AND
+            # 2. Mark each "hit" by adding the string value to `cached_block_keys`.
+            cached_block_keys = []
+            for c in linearized_ctx:
+                match c:
+                    case CBlock() if c.cache:
+                        assert c.value is not None
+                        if c.value in self._cached_blocks:
+                            FancyLogger.get_logger().info(
+                                f"KV CACHE HIT for: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})"  # type: ignore
+                            )
+                        else:
+                            FancyLogger.get_logger().debug(
+                                f"HF backend is caching a CBlock with hashed contents: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})"
+                            )
+                            tokens = self._tokenizer(c.value, return_tensors="pt")
+                            dc = DynamicCache()
+                            with torch.no_grad():
+                                dc = self._model(
+                                    tokens["input_ids"].to(self._device),  # type: ignore
+                                    attention_mask=tokens["attention_mask"].to(
+                                        self._device
+                                    ),  # type: ignore
+                                    past_key_values=dc,
+                                    use_cache=True,
+                                ).past_key_values
+                            self._cached_blocks[c.value] = dc
+                        cached_block_keys.append(c.value)
+                    case _:
+                        continue
+
+            # 3. Apply the chat template WITHOUT tokenization.
+            # Doing this without tokenization and then gluing together the tokens is necessary because
+            # things that KV-cache together must tokenize together.
+            input_text = self._tokenizer.apply_chat_template(  # type: ignore
+                ctx_as_conversation,
+                tools=convert_tools_to_json(tools),  # type: ignore
+                **self._make_backend_specific_and_remove(model_options),
+                tokenize=False,
+            )
+
+            # 4. Split the input_text back up again, re-using the DynamicCache where it exists.
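+            # Sketch of the loop below (illustrative): with cached keys ["K1", "K2"] and
+            # input_text == "<pre1>K1<pre2>K2<tail>", str_parts becomes
+            #     ["<pre1>", "K1", "<pre2>", "K2", "<tail>"],
+            # where the uncached parts are prefilled on the fly and the cached parts
+            # reuse the stored DynamicCache. Invariant: "".join(str_parts) == input_text.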
+            str_parts = []
+            tok_parts = []
+            dc_parts = []
+            current_suffix = input_text
+            for key in cached_block_keys:
+                assert key is not None, (
+                    "Some input CBlock must not have been computed yet? The underlying error occurred well before this line."
+                )
+                assert key in current_suffix, (
+                    "Could happen but would be rare; related to the other assert in this block."
+                )
+                parts = current_suffix.split(key)  # type: ignore
+                assert len(parts) == 2, (
+                    "Known issue: a cached substring might occur more than once. We need to handle this situation earlier. Notice if this happens and keep a count."
+                )
+                prefix, suffix = parts
+                # Add the prefix, if any, to str+tok+dc parts.
+                if prefix != "":
+                    FancyLogger.get_logger().debug(
+                        f"Doing a forward pass on uncached block which is prefix to a cached CBlock: {prefix[:3]}.{len(prefix)}.{prefix[-3:]}"
+                    )
+                    str_parts.append(prefix)
+                    tok_parts.append(self._tokenizer(prefix, return_tensors="pt"))
+                    dc_parts.append(self._make_dc_cache(tok_parts[-1]))
+                # Add the cached CBlock to str+tok+dc parts.
+                FancyLogger.get_logger().debug(
+                    f"Replacing a substring with previously computed/retrieved cache with hash value {hash(key)} ({key[:3]}..{key[-3:]})"
+                )
+                # str_parts.append(key)
+                # tok_parts.append(self._tokenizer(key, return_tensors="pt"))
+                # dc_parts.append(self._make_dc_cache(tok_parts[-1]))  # TODO this is wrong.
+                str_parts.append(key)
+                tok_parts.append(self._tokenizer(key, return_tensors="pt"))
+                dc_parts.append(self._cached_blocks[key])
+                # Set the suffix for the next loop iteration.
+                current_suffix = suffix
+            # "Base" case: the final suffix.
+            if current_suffix != "":
+                FancyLogger.get_logger().debug(  # type: ignore
+                    f"Doing a forward pass on final suffix, an uncached block: {current_suffix[:3]}.{len(current_suffix)}.{current_suffix[-3:]}"  # type: ignore
+                )  # type: ignore
+                str_parts.append(current_suffix)
+                tok_parts.append(self._tokenizer(current_suffix, return_tensors="pt"))
+                dc_parts.append(self._make_dc_cache(tok_parts[-1]))
+
+            # Smash together the caches, the input_ids, and the attention masks.
+            assert "".join(str_parts) == input_text, (
+                "Should've ended up with the same input text!"
+            )
+            input_ids = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
+            attention_mask = torch.cat(
+                [toks["attention_mask"] for toks in tok_parts], dim=1
+            )
+            assert input_ids.shape == attention_mask.shape
+            merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts)
+            # TODO: also assert that the merged cache is the correct shape given the input_ids and attention_mask shapes.
+
+            # Rewind the merged cache by 1 for safety.
+            merged_cache.crop(-1)
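+            # After crop(-1), the merged cache should cover input_ids.shape[1] - 1
+            # positions, so generate() re-runs a forward pass only for the final input
+            # token before it starts sampling new ones.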
+
+            if format is None:
+                chat_output = self._model.generate(  # type: ignore
+                    input_ids.to(self._device),
+                    attention_mask=attention_mask.to(self._device),
+                    use_cache=True,
+                    past_key_values=merged_cache,
+                    return_dict_in_generate=True,
+                    output_scores=True,
+                    **self._make_backend_specific_and_remove(model_options),
+                )  # type: ignore
+
+            else:
+                raise NotImplementedError("Copy implementation from above.")
+                # outlines.generate.json always parses the resulting json into a python dict.
+                # We however want to keep it as a json string for later storing it in ModelOutputThunk.
+                schema: dict[str, Any] = format.model_json_schema()
+                schema_json: str = json.dumps(schema)
+                regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
+                    schema_json
+                )
+
+                from outlines.models.transformers import TransformerTokenizer
+                from outlines.processors import RegexLogitsProcessor
+                from transformers import LogitsProcessorList
+
+                chat_output = self._model.generate(  # type: ignore
+                    input_ids,
+                    return_dict_in_generate=True,
+                    output_scores=True,
+                    logits_processor=LogitsProcessorList(
+                        [
+                            RegexLogitsProcessor(
+                                regex_str,
+                                tokenizer=TransformerTokenizer(self._tokenizer),
+                            )
+                        ]
+                    ),
+                    **self._make_backend_specific_and_remove(model_options),
+                )
+
+            decoded_result = self._tokenizer.decode(
+                chat_output.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
+            )
+
+            # Add an entry to the cache for ALora reuse.
+            if self._use_caches:
+                output_complete = chat_output.sequences[0]
+                cache: DynamicCache = chat_output.past_key_values
+
+                cache_info = HFAloraCacheInfo(
+                    kv_cache=cache,
+                    merged_token_ids=output_complete,
+                    merged_attention=torch.ones_like(output_complete).to(self._device),
+                    q_end=len(input_ids[0]),
+                )
+
+                assert decoded_result is not None
+                self.cache_put(decoded_result, cache_info)
+        else:
+            raise Exception("Does not yet support non-chat contexts.")
+
+        assert decoded_result is not None
+
+        result = ModelOutputThunk(value=decoded_result)
+
+        # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model.
+        if format is None and tool_calls:
+            result.tool_calls = self._extract_model_tool_requests(tools, decoded_result)
+
+        parsed_result = self.formatter.parse(action, result)
+        if generate_logs is not None:
+            assert isinstance(generate_logs, list)
+            generate_log = GenerateLog()
+            generate_log.prompt = ctx_as_conversation
+            generate_log.backend = f"hf::{self.model_id!s}"
+            generate_log.model_options = model_options
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = decoded_result
+            generate_log.extra = {
+                "format": format,
+                "tools_available": tools,
+                "tools_called": result.tool_calls,
+                "seed": seed,
+            }
+            generate_log.action = action
+            generate_log.result = parsed_result
+            generate_logs.append(generate_log)
+        return parsed_result
+
     def _generate_from_context_standard(
         self,
         action: Component | CBlock,
diff --git a/mellea/backends/kv_block_helpers.py b/mellea/backends/kv_block_helpers.py
new file mode 100644
index 00000000..f729ade3
--- /dev/null
+++ b/mellea/backends/kv_block_helpers.py
@@ -0,0 +1,58 @@
+"""Utilities for KV smashing."""
+
+from collections.abc import Iterable
+from functools import reduce
+from typing import Any
+
+import torch
+from transformers import BatchEncoding, DynamicCache
+
+TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache]
+LegacyCache = Any
+
+
+def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache:
+    """Concatenates two LegacyCache Ks and Vs along the time axis."""
+    legacy_merged = tuple(
+        (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2))
+        for i in range(len(a))
+    )
+    return legacy_merged
+
+
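+# Note on shapes (for the helpers above and below): a LegacyCache is a tuple with
+# one (keys, values) pair per layer, and each tensor has shape
+# [batch, num_heads, seq_len, head_dim], so concatenating along dim=2 splices two
+# prefills together along the sequence (time) axis.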
+def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache:
+    """Merges the Ks and Vs of an iterable of DynamicCaches along the time axis."""
+    legacies = [c.to_legacy_cache() for c in caches]
+    assert len(legacies) >= 1
+    rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies))  # type: ignore
+    return rv  # type: ignore
+
+
+def combine_representations(
+    tokenizer, reps: Iterable[str | DynamicCache]
+) -> TokenizedCacheIterleaving:
+    """Tokenizes the string representations and passes DynamicCaches through unchanged."""
+    rv = []
+    for rep in reps:
+        if type(rep) is DynamicCache:
+            rv.append(rep)
+        else:
+            rv.append(tokenizer(rep))
+    return rv
+
+
+def tokens_to_legacy_cache(
+    model, device: str, tokens_or_cache: BatchEncoding | DynamicCache
+) -> Iterable[LegacyCache]:
+    """Prefills and returns Ks and Vs as a LegacyCache."""
+    if type(tokens_or_cache) is DynamicCache:
+        return tokens_or_cache.to_legacy_cache()
+    else:
+        tokens = tokens_or_cache
+        dc = DynamicCache()
+        with torch.no_grad():
+            dc = model(
+                tokens["input_ids"].to(device),  # type: ignore
+                attention_mask=tokens["attention_mask"].to(device),  # type: ignore
+                past_key_values=dc,
+            ).past_key_values
+        return dc.to_legacy_cache()
diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py
index 2cb8daa6..2535cd28 100644
--- a/mellea/stdlib/base.py
+++ b/mellea/stdlib/base.py
@@ -22,11 +22,23 @@
 class CBlock:
     """A `CBlock` is a block of content that can serve as input to or output from an LLM."""
 
-    def __init__(self, value: str | None, meta: dict[str, Any] | None = None):
-        """Initializes the CBlock with a string and some metadata."""
+    def __init__(
+        self,
+        value: str | None,
+        meta: dict[str, Any] | None = None,
+        *,
+        cache: bool = False,
+    ):
+        """Initializes the CBlock with a string and some metadata.
+
+        Args:
+            value: The underlying value stored in this CBlock.
+            meta: Any meta-information about this CBlock (e.g., the inference engine's Completion object).
+            cache: If set to `True`, then this CBlock's KV cache might be stored by the inference engine. Experimental."""
         if value is not None and not isinstance(value, str):
             raise TypeError("value to a Cblock should always be a string or None")
         self._underlying_value = value
+        self.cache = cache
         if meta is None:
             meta = {}
         self._meta = meta