From 5a92a0d874f97f450fcba453f16c3a73a5b14fd7 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Thu, 3 Jul 2025 12:38:52 +0000 Subject: [PATCH 01/18] Add Support for mistralai/Mistral-Small-3.1-24B-Instruct-2503 model --- .../tests/pipeline_tests/test_end2end.py | 523 ++++++++++++++++++ .../tests/pipeline_tests/test_vision_model.py | 86 +++ .../tests/pipeline_tests/test_vision_tower.py | 64 +++ .../mistral_24b/tests/test_conv2d.py | 100 ++++ .../mistral_24b/tests/test_patch_rot_emb.py | 92 +++ .../tests/test_pixtral_transformer.py | 113 ++++ .../tests/test_vision_attention.py | 121 ++++ .../mistral_24b/tests/test_vision_mlp.py | 91 +++ .../mistral_24b/tests/test_vision_rms.py | 96 ++++ models/experimental/mistral_24b/tt/model.py | 125 +++++ .../tt/pipeline/mistral_vision_tower.py | 161 ++++++ .../mistral_24b/tt/pipeline/vision_model.py | 45 ++ models/experimental/mistral_24b/tt/rmsnorm.py | 196 +++++++ .../mistral_24b/tt/vision_attention.py | 249 +++++++++ .../mistral_24b/tt/vision_conv2d.py | 114 ++++ .../experimental/mistral_24b/tt/vision_mlp.py | 114 ++++ .../experimental/mistral_24b/tt/vision_mmp.py | 175 ++++++ .../tt/vision_pixtral_image_block.py | 80 +++ .../tt/vision_pixtral_transformer.py | 53 ++ .../mistral_24b/tt/vision_rope.py | 101 ++++ models/tt_transformers/tt/common.py | 57 ++ models/tt_transformers/tt/generator.py | 1 + models/tt_transformers/tt/load_checkpoints.py | 37 +- models/tt_transformers/tt/model_config.py | 304 ++++++++-- .../pixtral_transformer_inputs/demo_small.jpg | Bin 0 -> 8554 bytes .../pixtral_transformer_inputs/people.jpg | Bin 0 -> 49606 bytes 26 files changed, 3037 insertions(+), 61 deletions(-) create mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py create mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py create mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py create mode 100644 
models/experimental/mistral_24b/tests/test_conv2d.py create mode 100644 models/experimental/mistral_24b/tests/test_patch_rot_emb.py create mode 100644 models/experimental/mistral_24b/tests/test_pixtral_transformer.py create mode 100644 models/experimental/mistral_24b/tests/test_vision_attention.py create mode 100644 models/experimental/mistral_24b/tests/test_vision_mlp.py create mode 100644 models/experimental/mistral_24b/tests/test_vision_rms.py create mode 100644 models/experimental/mistral_24b/tt/model.py create mode 100644 models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py create mode 100644 models/experimental/mistral_24b/tt/pipeline/vision_model.py create mode 100644 models/experimental/mistral_24b/tt/rmsnorm.py create mode 100644 models/experimental/mistral_24b/tt/vision_attention.py create mode 100644 models/experimental/mistral_24b/tt/vision_conv2d.py create mode 100644 models/experimental/mistral_24b/tt/vision_mlp.py create mode 100644 models/experimental/mistral_24b/tt/vision_mmp.py create mode 100644 models/experimental/mistral_24b/tt/vision_pixtral_image_block.py create mode 100644 models/experimental/mistral_24b/tt/vision_pixtral_transformer.py create mode 100644 models/experimental/mistral_24b/tt/vision_rope.py create mode 100644 real_inputs/pixtral_transformer_inputs/demo_small.jpg create mode 100644 real_inputs/pixtral_transformer_inputs/people.jpg diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py new file mode 100644 index 000000000000..a78d1b683371 --- /dev/null +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py @@ -0,0 +1,523 @@ +"""Test for Mistral-24B End-to-End Vision-Text Pipeline""" + +import torch +import pytest +from loguru import logger +import os +import ttnn + +from models.tt_transformers.tt.common import ( + sample_host, + PagedAttentionConfig, + preprocess_inputs_prefill, +) + +from 
models.tt_transformers.tt.model_config import DecodersPrecision +from models.experimental.mistral_24b.tt.model import MistralTransformer as Transformer + +from models.tt_transformers.tt.generator import Generator + +from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer +from models.utility_functions import skip_for_grayskull, skip_for_blackhole + +from models.tt_transformers.tt.model_config import ModelArgs +from transformers import AutoProcessor, AutoModelForVision2Seq + +import re + + +def run_reference_demo_pipeline(messages, model_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503"): + """ + Run Hugging Face reference demo model (Vision-Text pipeline) using given messages. + """ + logger.info("Running reference HF vision-text model...") + + processor = AutoProcessor.from_pretrained(model_id) + model = AutoModelForVision2Seq.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + ) + + model.eval() + + # Apply chat template + prompt_text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" + ) + + # Extract images (already loaded) + image_inputs = [] + for msg in messages: + for item in msg["content"]: + if item["type"] == "image": + image_inputs.append(item["image"]) + + # Tokenize and move to model device + inputs = processor( + text=[prompt_text], + images=image_inputs, + return_tensors="pt", + ).to(model.device, dtype=torch.bfloat16) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=100, + temperature=0.0, + top_p=0.9, + do_sample=False, + pad_token_id=model.config.pad_token_id, + ) + + # Decode + output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + logger.info(f"HF reference model output: {output}") + + chat = parse_chat_output(output) + display_chat(logger, chat) + + return output + + +def parse_chat_output(text): + """Parse chat output format from generated 
text.""" + pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" + matches = re.finditer(pattern, text, re.DOTALL) + return [(match.group("role"), match.group("message").strip()) for match in matches] + + +def display_chat(logger, conversation): + """Display chat conversation in formatted output.""" + for role, message in conversation: + if role == "user": + logger.info(f"👤 User: {message}") + elif role == "assistant": + logger.info(f"🤖 Assistant: {message}") + + +def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): + """Setup model arguments for vision-enabled model (Single Responsibility).""" + instruct = True if weights == "instruct" else False + + model_args = ModelArgs( + mesh_device=mesh_device, + instruct=instruct, + optimizations=optimizations, + max_seq_len=max_seq_len, + max_batch_size=batch_size, + ) + + return model_args, instruct + + +def setup_vision_prompts_and_tokenizer(model_args, instruct): + """Setup multimodal prompts and tokenizer for vision-enabled model.""" + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://www.theeducationmagazine.com/wp-content/uploads/2020/03/18.jpg"}, + {"type": "text", "text": "Tell me who you see in the image and describe the image ?"}, + ], + } + ] + + tokenizer = model_args.tokenizer + return messages, tokenizer + + +def process_vision_info(messages): + """Extract images (already opened) from messages.""" + image_inputs = [] + video_inputs = None # Not used + + for msg in messages: + content = msg.get("content", []) + for item in content: + if item.get("type") == "image": + image_inputs.append(item["image"]) + + return image_inputs, video_inputs + + +def process_real_vision_inputs(messages, model_args): + """Process real image inputs using AutoProcessor (Interface Segregation).""" + processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) + + text = processor.apply_chat_template( + messages, tokenize=False, 
add_generation_prompt=True, padding=True, padding_side="left" + ) + + image_inputs, video_inputs = process_vision_info(messages) + # image_inputs, video_inputs = None, None + + encoded = processor( + text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt", return_dict=True + ).to("cpu", dtype=torch.bfloat16) + input_ids = encoded["input_ids"] + pixel_values = encoded["pixel_values"] if "pixel_values" in encoded else None + attention_mask = encoded["attention_mask"] if "attention_mask" in encoded else None + image_sizes = encoded["image_sizes"] if "image_sizes" in encoded else None + + return { + "input_ids": input_ids, + "pixel_values": pixel_values, + "attention_mask": attention_mask, + "image_sizes": image_sizes, + "processor": processor, + } + + +def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): + """Load separate vision and text models following test_end2end.py pattern.""" + state_dict = model_args.load_state_dict() + + vision_prefix = "vision_tower." 
+ # Setup paged attention config (exactly like test_end2end.py) + paged_attention_config = None + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Load vision model (exactly like test_end2end.py) + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=vision_prefix, + dtype=dtype, + model_args=model_args, + ) + + # Load text model (exactly like test_end2end.py) + text_model = Transformer( + args=model_args, + mesh_device=mesh_device, + dtype=dtype, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + paged_attention_config=paged_attention_config, + ) + logger.info("Separate vision and text models loaded like test_end2end.py") + return vision_model, text_model + + +def run_generation_exactly_like_test_end2end( + vision_model, + text_model, + processed_inputs, + model_args, + page_table=None, + paged_attention_config=None, + max_gen_len=20, + repetition_ngram_size=3, +): + """Run generation following the EXACT pattern from test_end2end.py.""" + input_ids = processed_inputs["input_ids"] + + logger.info("Running generation exactly like test_end2end.py...") + + logger.info("Running Vision Model...") + generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) + tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None + + input_tokens_prefill = input_ids + batch_size = input_tokens_prefill.shape[0] + + prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) + input_prompts = [prompt_text] + + ( + input_tokens_prefill_pt, + encoded_prompts, + decoding_pos, + prefill_lens, + ) = preprocess_inputs_prefill( + input_prompts, + model_args.tokenizer, + [model_args], + instruct=True, + max_generated_tokens=max_gen_len, + max_prefill_len=8192, + ) + + 
input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) + + logger.info("Running prefill...") + logits = generator.prefill_forward_text( + input_tokens_prefill_pt, + page_table=page_table, + kv_cache=tt_kv_cache, + prompt_lens=decoding_pos, + vision_model=vision_model, + processed_inputs=processed_inputs, + ) + + prefilled_token = torch.argmax(logits, dim=-1) + prefilled_token_decoded_res = model_args.tokenizer.decode(prefilled_token[0].item()) + logger.info(f"prefilled_token_decoded_res: {prefilled_token_decoded_res}") + + logger.info(f"Prefilled token: {prefilled_token}") + + import torch.nn.functional as F + + logger.info(f"Encoded prompt: {encoded_prompts[0]}") + logger.info(f"Decoded prompt: {model_args.tokenizer.decode(encoded_prompts[0])}") + + # logits: [1, 1, vocab_size] + last_logits = logits[0, -1] # shape: [vocab_size] + probs = F.softmax(last_logits, dim=-1) + + top_k = 5 + topk_probs, topk_indices = torch.topk(probs, k=top_k) + + topk_tokens = [model_args.tokenizer.decode([idx.item()]) for idx in topk_indices] + + logger.info("🔍 Top-5 predicted tokens (with probabilities):") + for i in range(top_k): + logger.info(f"{i+1}. 
Token: '{topk_tokens[i]}' (ID={topk_indices[i].item()}), P={topk_probs[i].item():.4f}") + + all_outputs = [encoded_prompts[0][: prefill_lens[0]]] + all_outputs[0].append(int(prefilled_token[0].item())) + + current_pos = torch.tensor([decoding_pos[0]]) + out_tok = prefilled_token + generation_length = max_gen_len + + results = [] + + logger.info("Starting decode loop...") + for iteration in range(generation_length): + logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") + + logits = generator.decode_forward_text( + out_tok, + current_pos, + enable_trace=False, + page_table=page_table, + kv_cache=tt_kv_cache, + ) + + _, out_tok = sample_host( + logits, + temperature=0, + top_p=0.9, + ) + + token_id = out_tok[0].item() + decoded_token = model_args.tokenizer.decode([token_id]) + logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") + + # Stop if EOS detected + if token_id == model_args.tokenizer.eos_token_id: + logger.info("EOS token detected, stopping generation.") + break + + # Stop if repetition detected (n-gram) + if len(all_outputs[0]) >= repetition_ngram_size * 2: + last_ngram = tuple(all_outputs[0][-repetition_ngram_size:]) + for i in range(len(all_outputs[0]) - repetition_ngram_size): + if tuple(all_outputs[0][i : i + repetition_ngram_size]) == last_ngram: + logger.info(f"Detected {repetition_ngram_size}-gram repetition, stopping.") + break + + # Create result object + result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() + + results.append(result) + + all_outputs[0].append(token_id) + current_pos += 1 + + # Early stopping (exactly like test_end2end.py) + if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): + logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. 
Stopping.") + break + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"📝 Each iteration Generated Response:\n{response}") + logger.info(f"📝 Each iteration Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f" Each iteration Generated {len(results)} tokens successfully") + + # Final response (exactly like test_end2end.py) + response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) + logger.info(f"📝 Final Generated Response:\n{response}") + logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + chat = parse_chat_output(response) + display_chat(logger, chat) + + logger.info(f"Generated {len(results)} tokens successfully") + return results + + +def validate_e2e_outputs(results, expected_min_tokens=1): + """Validate end-to-end pipeline outputs.""" + if not results: + logger.error("No results generated from E2E pipeline") + return False + + if len(results) < expected_min_tokens: + logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") + return False + + # Check if tokens are valid + for result in results: + if not hasattr(result, "token") or not hasattr(result, "text"): + logger.error("Invalid result format") + return False + + logger.info("E2E pipeline validation passed") + return True + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") +@pytest.mark.timeout(1800) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "weights, layers", + [ + ("instruct", None), + ], + ids=["full"], +) +@pytest.mark.parametrize( + "paged_attention", + ( + True, + # False, + ), + ids=( + "paged_attention", + # "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, 
"page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (1024 * 8,), # Use smaller seq_len like test_end2end.py to avoid memory issues +) +@pytest.mark.parametrize( + "optimizations", + [ + lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), + ], + ids=["accuracy"], +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 10 * 1024}], indirect=True) +def test_e2e_vision_text_pipeline( + weights, + layers, + max_seq_len, + batch_size, + paged_attention, + page_params, + optimizations, + mesh_device, + reset_seeds, + request, + device_params, +): + """Test end-to-end vision-text pipeline using proper Generator methods.""" + logger.info("Starting E2E vision-text pipeline test") + + # Use bfloat8_b like test_end2end.py for better memory efficiency + dtype = ttnn.bfloat8_b + + # Setup vision-enabled model configuration + model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) + + if layers is not None: + model_args.n_layers = layers + + # Setup vision prompts and tokenizer + messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) + + # logger.info("Running reference HF vision-text model using messages..... 
") + # hf_output = run_reference_demo_pipeline(messages) + + # Process real vision inputs from images + processed_inputs = process_real_vision_inputs(messages, model_args) + + # Load separate models following test_end2end.py pattern + logger.info("Loading separate vision and text models like test_end2end.py...") + vision_model, text_model = load_separate_models_like_test_end2end( + model_args, mesh_device, dtype, paged_attention, page_params + ) + + # Setup page table for paged attention (exactly like test_end2end.py) + page_table_tt = None + paged_attention_config = None + + # Prepare page table for paged attention (exactly like test_end2end.py) + page_table = None + + if paged_attention: + paged_attention_config = PagedAttentionConfig( + block_size=page_params["page_block_size"], + max_num_blocks=page_params["page_max_num_blocks"], + ) + + # Implied shuffling of blocks + permutation = torch.randperm(paged_attention_config.max_num_blocks) + # Page table which maps virtual blocks to physical + reverse_permutation = torch.argsort(permutation) + page_table = reverse_permutation.reshape( + model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size + ) + page_table_tt = ttnn.from_torch( + page_table, + device=mesh_device, + dtype=ttnn.int32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, -2) if batch_size > 1 else (None, None), + mesh_shape=model_args.cluster_shape, + ), + ) + + # Run generation following EXACT test_end2end.py pattern + logger.info("Running generation following EXACT test_end2end.py pattern...") + results = run_generation_exactly_like_test_end2end( + vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=1024 * 4 + ) + + # Validate results + validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) + + # Final validation + if validation_passed and len(results) > 0: + logger.info("✅ E2E vision-text 
pipeline test PASSED!") + logger.info(f"Successfully generated {len(results)} tokens") + + # Log generated tokens for debugging + for i, result in enumerate(results[:5]): + logger.info(f"Token {i}: {result.token} -> '{result.text}'") + else: + logger.error("❌ E2E pipeline test failed") + assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py new file mode 100644 index 000000000000..97f5736680ef --- /dev/null +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +def get_image_features(vision_tower, projector, input_tensor, image_sizes): + """ + Get image features from the vision tower and projector. + """ + vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state + image_features = projector(vision_token.squeeze(0), image_sizes) + return image_features + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mistral_vision_model(mesh_device, reset_seeds): + pcc_required = 0.97 + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + + mmp_first_layer_prefix = "multi_modal_projector." + + mmp_partial_state_dict = { + k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix)) + } + + reference_mmp = model_args.reference_vision_multi_modal() + reference_mmp.load_state_dict(mmp_partial_state_dict) + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) + + # ##### TT Model: TtMistralVisionTransformer ##### + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + model_args=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) # [0] + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True) + tt_output = tt_output[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. 
{pcc_message}" diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py new file mode 100644 index 000000000000..35222c0a9b65 --- /dev/null +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import os +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_mistral_vision_tower(mesh_device, reset_seeds): + pcc_required = 0.99 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) + + reference_output = reference_output.last_hidden_state + ##### TT Model: MistralVisionTower ##### + vision_model = MistralVisionTower( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + configuration=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_output.shape[-1] + ] + tt_output = tt_output.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. {pcc_message}" diff --git a/models/experimental/mistral_24b/tests/test_conv2d.py b/models/experimental/mistral_24b/tests/test_conv2d.py new file mode 100644 index 000000000000..69d1ccb35ac9 --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_conv2d.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from ttnn import ConcatMeshToTensor + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_conv2d_inference( + mesh_device, + reset_seeds, +): + pcc_required = 0.9999 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + first_layer_prefix = "vision_tower.patch_conv." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + num_devices = model_args.num_devices + + ##### Create input tensor for the all gather ##### + B, NCH, H, W = (1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size) + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + model_args.vision_dim, + model_args.vision_patch_size, + model_args.vision_patch_size, + False, + ) + + assert NCH == in_channels, "Number of channels in input tensor should match in_channels for the Conv2d patch." + assert type(kernel_size) == int, "Only symmetric kernel_size is currently supported." + assert kernel_size == stride, "Only same kernel_size and stride are currently supported." + + assert H % kernel_size == 0, "Height should be divisible by kernel_size." + assert W % kernel_size == 0, "Width should be divisible by kernel_size." 
+ + ##### Prepare inputs ##### + input_tensor = torch.randn((B, NCH, H, W)).to(dtype=torch.bfloat16) + logger.info(f"Input tensor shape: {input_tensor.shape}") + + reference_model = model_args.reference_conv2d_patch() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor) + + tt_model = TtMistralConv2dPatch( + mesh_device, + state_dict, + first_layer_prefix, + dtype, + in_channels, + out_channels, + kernel_size, + stride, + bias, + ) + tt_output = tt_model(input_tensor) + + ##### Check the outputs ##### + out = ttnn.from_device(tt_output) + tt_output_torch = ttnn.to_torch(out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=2)) + + # Only select output from one device + tt_output_torch = tt_output_torch[0, ..., :out_channels] + + # 1. Restore batch dim + tt_output_torch = tt_output_torch.unsqueeze(0) + # 1 1024 4096 + # 2. Permute to match Conv2D output: (N, C_out, H_out, W_out) + tt_output_torch = tt_output_torch.permute(0, 2, 1).reshape(1, 1024, 64, 64) + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" 
diff --git a/models/experimental/mistral_24b/tests/test_patch_rot_emb.py b/models/experimental/mistral_24b/tests/test_patch_rot_emb.py new file mode 100644 index 000000000000..903cdf395d91 --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_patch_rot_emb.py @@ -0,0 +1,92 @@ +from loguru import logger + +import torch +import pytest +import os +import ttnn + +# models/tt_transformers/tt/common.py +from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rot_emb(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + partial_state_dict = {} + + reference_model = tt_model_args.reference_vision_rot_emb() + reference_model.load_state_dict(partial_state_dict) + + image_size = tt_model_args.vision_image_size + patch_size = tt_model_args.vision_patch_size + dim = tt_model_args.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + position_ids = torch.arange(4096, dtype=torch.long) + + x = torch.randn(batch_size, 4096, 1024) + + cos, sin = reference_model(x, position_ids) + tt_model = RotarySetup( + device, + batch_size, + dim, + image_size, + patch_size, + num_patches, + tt_model_args.vision_rope_theta, + scale_factor=None, + 
orig_context_len=num_patches, + datatype=dtype, + ) + + cos2, sin2 = tt_model.get_rot_mats(position_ids) + cos2 = ttnn.from_device(cos2) + cos2 = ttnn.to_torch(cos2) + cos2 = cos2.squeeze(0) + + sin2 = ttnn.from_device(sin2) + sin2 = ttnn.to_torch(sin2) + sin2 = sin2.squeeze(0) + + passing, pcc_message = comp_pcc(cos, cos2) + + logger.info(comp_allclose(cos, cos2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"COS PCC value is lower than {0.99} for some of the outputs. Check Warnings!" + + passing, pcc_message = comp_pcc(sin, sin2) + + logger.info(comp_allclose(sin, sin2)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"SIN PCC value is lower than {0.99} for some of the outputs. Check Warnings!" diff --git a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py new file mode 100644 index 000000000000..affd89d61d6d --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs + +from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "batch, num_chunks", + ((1, 1),), +) +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +def test_image_transformer_inference(batch, num_chunks, mesh_device): + pcc_required = 0.99 + + model_args = ModelArgs(mesh_device) + dtype = ttnn.bfloat16 + + state_dict = model_args.load_state_dict() + n_layers = model_args.vision_n_layers + first_layer_prefix = "vision_tower.transformer." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + dim = model_args.vision_dim + heads = model_args.vision_attn_n_heads + seq_len = model_args.vision_chunk_ntok - 1 + head_dim = dim // heads + + reference_model = model_args.reference_vision_encoder() + reference_model.load_state_dict(partial_state_dict) + reference_model.eval() + + all_tests_pass = True + + tt_model = TtPixtralTransformer( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=None, + dtype=dtype, + configuration=model_args, + layers=n_layers, + ) + + # Create PT input + pt_attention_input = torch.rand(batch, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + cos_t = ttnn.from_torch( + cos, + 
device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + attention_input = model_args.prepare_residual_tensor_prefill( + pt_attention_input, + force_replicated=True, + ) + + with torch.no_grad(): + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t)) + reference_output = reference_model( + pt_attention_input, attention_mask=attention_mask, position_embeddings=(cos, sin) + )[0] + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[ + : tt_out.shape[0] + ] + tt_output_torch = tt_output_torch.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + if not passing: + logger.warning(f"PCC value -- {pcc_message} -- is lower than {pcc_required} for the output.") + else: + logger.info(f"PCC: {pcc_message}") + logger.info(comp_allclose(reference_output, tt_output_torch)) + all_tests_pass = all_tests_pass and passing + + assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/mistral_24b/tests/test_vision_attention.py b/models/experimental/mistral_24b/tests/test_vision_attention.py new file mode 100644 index 000000000000..8466b102eed9 --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_vision_attention.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention + +from ttnn import ConcatMeshToTensor + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_vision_attention(mesh_device, seq_len, batch_size): + logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}") + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=256) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.attention." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_attention() + reference_model.load_state_dict(partial_state_dict) + + hidden_size = model_args.vision_dim + n_heads = model_args.vision_attn_n_heads + head_dim = hidden_size // n_heads + + tt_model = TtLlamaImageAttention( + mesh_device, + state_dict, + state_dict_prefix=first_layer_prefix, + weight_cache_path=model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=model_args, + ) + + dim = model_args.vision_dim + pt_attention_input = torch.randn(batch_size, seq_len, dim).to(torch.bfloat16) + attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len).to(torch.bfloat16) + + B, T, D = pt_attention_input.shape + cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) + sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) + + # attention_input = model_args.prepare_residual_tensor_prefill( + # pt_attention_input, + # force_replicated=True, + # ) + + attention_input = ttnn.from_torch( + pt_attention_input.unsqueeze(0), + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + cos_t = ttnn.from_torch( + cos, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + sin_t = ttnn.from_torch( + sin, + device=mesh_device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + tt_out = tt_model(attention_input, position_embeddings=(cos_t, sin_t)) + tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_out.shape[-1] + ] + tt_output_torch = tt_output_torch.squeeze(0) + reference_output = 
reference_model(pt_attention_input, attention_mask, position_embeddings=(cos, sin))[0] + pcc_required = 0.99 + + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!" diff --git a/models/experimental/mistral_24b/tests/test_vision_mlp.py b/models/experimental/mistral_24b/tests/test_vision_mlp.py new file mode 100644 index 000000000000..32159eeabf7b --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_vision_mlp.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn + +# from models.tt_transformers.tt.mlp import MLP +from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP +from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (64 * 1024,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_mlp_inference(seq_len, batch_size, mesh_device, reset_seeds): + dtype = ttnn.bfloat8_b + mode = "decode" if seq_len <= 32 else "prefill" + + model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=128) + model_args.n_layers = 1 + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower.transformer.layers.0.feed_forward." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model = model_args.reference_vision_mlp() + reference_model.load_state_dict(partial_state_dict) + + tt_model = MLP( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + weight_cache_path=model_args.weight_cache_path(dtype), + state_dict_prefix="vision_tower.transformer.layers.0.feed_forward.", + dtype=dtype, + ) + torch_input = torch.randn(1, 1, seq_len, 1024).to(torch.bfloat16) + + reference_output = reference_model(torch_input) + tt_input = ttnn.from_torch( + torch_input, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + ) + + logger.info("Run MLP") + tt_output = tt_model(tt_input) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, :1024 + ] + pcc_required = 0.99 + passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + if passing: + logger.info("MLP Passed!") + else: + logger.warning("MLP Failed!") + + assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." 
diff --git a/models/experimental/mistral_24b/tests/test_vision_rms.py b/models/experimental/mistral_24b/tests/test_vision_rms.py new file mode 100644 index 000000000000..93181c2bc95f --- /dev/null +++ b/models/experimental/mistral_24b/tests/test_vision_rms.py @@ -0,0 +1,96 @@ +from loguru import logger + +import torch +import pytest +import os + +import ttnn +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm + +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + +from models.tt_transformers.tt.model_config import ModelArgs + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("device"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "seq_len", + (128,), +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): + dtype = ttnn.bfloat16 + mode = "decode" if seq_len <= 32 else "prefill" + + tt_model_args = ModelArgs( + device, + max_batch_size=batch_size, + max_seq_len=128, + ) + + tt_model_args.n_layers = 1 + state_dict = tt_model_args.load_state_dict() + + reference_model = tt_model_args.reference_vision_rms() + + first_layer_prefix = "vision_tower.transformer.layers.0.ffn_norm." 
+ + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + reference_model.load_state_dict(partial_state_dict) + + tt_model = RMSNorm( + device=device, + dim=1024, + state_dict=state_dict, + state_dict_prefix="vision_tower.transformer.layers.0.", + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + input = torch.rand(batch_size, seq_len, 1024) + + reference_output = reference_model(input) + + tt_input = ttnn.from_torch( + input, + device=device, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device=device), + ) + + tt_output = tt_model(tt_input, mode=mode) + + tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + logger.info(f"tt_output_torch: {tt_output_torch.shape}") + passing, pcc_message = comp_pcc(reference_output, tt_output_torch) + + logger.info(comp_allclose(reference_output, tt_output_torch)) + logger.info(f"PCC: {pcc_message}") + + if passing: + logger.info("rms_norm Passed!") + else: + logger.warning("rms_norm Failed!") + + assert passing, f"rms_norm output does not meet PCC requirement {0.99}." diff --git a/models/experimental/mistral_24b/tt/model.py b/models/experimental/mistral_24b/tt/model.py new file mode 100644 index 000000000000..bfe094b9d8ed --- /dev/null +++ b/models/experimental/mistral_24b/tt/model.py @@ -0,0 +1,125 @@ +""" +This is the end-to-end pipeline for the Mistral-Small-3.1-24B-Instruct-2503 model. + +The `MistralTransformer` class inherits from the `Transformer` class in tt_transformers. +It overrides `prepare_inputs_prefill` to run inference on the vision model and +pass the resulting visual tokens to the text model along with text tokens. 
+""" + + +import ttnn +import torch + +from models.tt_transformers.tt.model import Transformer +from ttnn import ConcatMeshToTensor + + +class MistralTransformer(Transformer): + def __init__( + self, + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=None, + use_paged_kv_cache=False, + ): + super().__init__( + args, + dtype, + mesh_device, + state_dict, + weight_cache_path, + paged_attention_config=paged_attention_config, + use_paged_kv_cache=use_paged_kv_cache, + ) + + def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + """ + Inputs are torch tensors or python types. This function returns ttnn + tensors on device. + TODO: Debate whether this function is responsible for padding + """ + + tokens = tokens.reshape(1, 1, 1, -1) + S = tokens.shape[-1] + tokens = ttnn.from_torch( + tokens, + device=self.mesh_device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + # self.embed_scale = args.dim**0.5 + tokens_embd = self.embd(tokens) + # tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) + + pixel_values = kwargs["processed_inputs"]["pixel_values"] + input_ids = kwargs["processed_inputs"]["input_ids"] + image_sizes = kwargs["processed_inputs"]["image_sizes"] + + if pixel_values is not None: + vision_model = kwargs["vision_model"] + vision_output = vision_model(pixel_values, image_sizes) + vision_output_torch = ttnn.to_torch( + vision_output, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1) + )[:, : vision_output.shape[-1]] + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1)) + sliced_token_embds = tokens_embd[: tokens_embd.shape[0]] + + image_features = vision_output_torch + + input_ids = torch.nn.functional.pad( + input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 + ) + # image_features = image_features.squeeze(0) + 
special_image_mask = (input_ids == 10).unsqueeze(-1)  # NOTE(review): 10 is assumed to be the image-token id — prefer reading it from the model config
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm + +from models.tt_transformers.tt.common import position_ids_in_meshgrid_tt +from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup + +from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from ttnn import ConcatMeshToTensor + + +class MistralVisionTower(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + configuration, + return_intermediate=None, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.dtype = dtype + self.config = configuration + + self.image_size = configuration.vision_chunk_size + self.patch_size = configuration.vision_patch_size + self.width = configuration.vision_dim + self.layers = configuration.vision_n_layers + self.heads = configuration.vision_attn_n_heads + self.vision_head_dim = configuration.vision_head_dim + self.mlp_ratio = configuration.vision_mlp_ratio + self.act_layer = configuration.vision_act_layer + self.in_channels = configuration.vision_in_channels + self.n_global_layers = configuration.vision_n_global_layers + self.max_seq_len = configuration.max_seq_len + self.return_intermediate = return_intermediate + self.n_layers = configuration.vision_n_layers + + in_channels, out_channels, kernel_size, stride, bias = ( + 3, + configuration.vision_dim, + configuration.vision_patch_size, + configuration.vision_patch_size, + False, + ) + + self.patch_conv = TtMistralConv2dPatch( + mesh_device=self.mesh_device, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}patch_conv.", + dtype=self.dtype, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + 
stride=stride, + bias=bias, + ) + + self.ln_pre = RMSNorm( + device=mesh_device, + dim=self.width, + state_dict=self.state_dict, + state_dict_prefix=state_dict_prefix, + weight_dtype=dtype, + weight_key="ln_pre", + is_distributed=False, + simplified_rms=True, + ) + + image_size = configuration.vision_image_size + patch_size = configuration.vision_patch_size + dim = configuration.vision_head_dim + num_patches_per_dim = image_size // patch_size + num_patches = num_patches_per_dim * num_patches_per_dim + self.num_patches = num_patches + + self.patch_positional_embedding = RotarySetup( + self.mesh_device, + 1, + dim, + image_size, + patch_size, + num_patches, + configuration.vision_rope_theta, + scale_factor=None, + orig_context_len=num_patches, + datatype=dtype, + ) + + self.transformer = TtPixtralTransformer( + mesh_device=self.mesh_device, + state_dict=self.state_dict, + state_dict_prefix=f"{state_dict_prefix}transformer.", + weight_cache_path=configuration.weight_cache_path(dtype), + dtype=self.dtype, + configuration=configuration, + layers=self.n_layers, + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + patch_embeds = self.patch_conv(input_tensor) + patch_embeds = ttnn.transpose(patch_embeds, 1, 2) + height, width = image_sizes[0] + patch_embeds = ttnn.reshape( + patch_embeds, + [patch_embeds.shape[0], self.width, height // self.patch_size, width // self.patch_size], + ) + + patch_embeds_list = [ + ttnn.slice( + patch_embeds, + [0, 0, 0, 0], + [1, self.width, size[0] // self.patch_size, size[1] // self.patch_size], + ) + for size in image_sizes + ] + + reshaped_patches = [] + for p in patch_embeds_list: + p = ttnn.reshape(p, (1, self.width, -1)) + p = ttnn.transpose(p, 1, 2) + reshaped_patches.append(p) + + patch_embeds = ttnn.concat(reshaped_patches, dim=0) + + # ln_pre RMS Norm + mode = "prefill" # if self.max_seq_len <= 32 else "prefill" + patch_embeds = self.ln_pre(patch_embeds, mode=mode) + + # # 
positional embeddings + position_ids = position_ids_in_meshgrid_tt( + patch_embeds_list, + max_width=self.config.vision_image_size // self.config.vision_patch_size, + device=self.mesh_device, + ) + + torch_position_ids = ttnn.to_torch(position_ids, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=0))[ + : position_ids.shape[-1] + ] + + position_embeddings = self.patch_positional_embedding.get_rot_mats(torch_position_ids) + + patch_embeds = ttnn.unsqueeze(patch_embeds, 0) + out = self.transformer(patch_embeds, position_embeddings=position_embeddings) + # deallocate position_embeddings + ttnn.deallocate(position_embeddings[0]) + ttnn.deallocate(position_embeddings[1]) + + return out diff --git a/models/experimental/mistral_24b/tt/pipeline/vision_model.py b/models/experimental/mistral_24b/tt/pipeline/vision_model.py new file mode 100644 index 000000000000..098c32bab03f --- /dev/null +++ b/models/experimental/mistral_24b/tt/pipeline/vision_model.py @@ -0,0 +1,45 @@ +""" +This is the end-to-end architecture of the Mistral-24B vision model. + +It brings together all components related to visual and MultiModalProjector together. 
+""" + +import ttnn +from models.common.lightweightmodule import LightweightModule + +from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower +from models.experimental.mistral_24b.tt.vision_mmp import TTMistral3MultiModalProjector + + +class TtMistralVisionTransformer(LightweightModule): + def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, model_args): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + + self.vision_tower = MistralVisionTower( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + dtype=dtype, + configuration=model_args, + ) + + self.mmp = TTMistral3MultiModalProjector( + mesh_device=mesh_device, + args=model_args, + state_dict=state_dict, + state_dict_prefix="multi_modal_projector.", + dtype=dtype, + eps=1e-05, # layer_norm_eps + ) + + def forward(self, input_tensor, image_sizes=None): + """ + input_tensor shape: (B, C, H, W) + """ + + x = self.vision_tower(input_tensor, image_sizes=image_sizes) + x = ttnn.squeeze(ttnn.squeeze(x, 0), 0) + x = self.mmp(x, image_sizes) + return x diff --git a/models/experimental/mistral_24b/tt/rmsnorm.py b/models/experimental/mistral_24b/tt/rmsnorm.py new file mode 100644 index 000000000000..65f12713999d --- /dev/null +++ b/models/experimental/mistral_24b/tt/rmsnorm.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +import ttnn +from models.common.lightweightmodule import LightweightModule + +TILE = 32 +SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile + + +class RMSNorm(LightweightModule): + """ + RMSNorm supporting replication over a MeshDevice and sharding within devices. + + This class implements a Root Mean Square Normalization (RMSNorm) that can be + distributed across multiple devices and cores. 
If the `device` parameter is a + MeshDevice, the weights and computations are replicated across all devices in + the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. + + Args: + device: The device or MeshDevice on which to perform the computations. + state_dict: The state dictionary containing the model parameters. + dim: Input dimension (e.g. model hidden dimension size). + layer_num: The layer number to determine the weight key in the state dictionary. + weight_key: The key for retrieving the weight from the state dictionary. + weight_cache_path: Optional path for caching the tilized weights. + weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. + weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. + model_config: Optional configuration dictionary for the model. + eps (float): Small value to avoid division by zero in normalization, default is 1e-05. + + If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG + and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. 
+ """ + + def __init__( + self, + device, + dim, + state_dict, + weight_key, + layer_num=None, + state_dict_prefix=None, + weight_cache_path=None, + weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, + weight_dtype=ttnn.bfloat16, + is_distributed=None, + eps: float = 1e-05, + add_unit_offset=False, + sharded_program_config=None, + sharded_output_config=None, + output_mem_config=None, + ccl_topology=ttnn.Topology.Ring, + simplified_rms=False, + ): + super().__init__() + self.eps = eps + self.is_distributed = is_distributed + self.ccl_topology = ccl_topology + + if state_dict_prefix: + weight_name = f"{state_dict_prefix}{weight_key}.weight" + else: + if layer_num is None: + weight_name = f"{weight_key}.weight" + else: + weight_name = f"layers.{layer_num}.{weight_key}.weight" + + torch_weight = ( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) + ) + + # Add offset before caching + if add_unit_offset: + torch_weight = torch_weight + 1.0 + + cache_name = None if weight_cache_path is None else weight_cache_path / weight_name + + # Compatibility with models that don't use mesh devices (e.g. 
single-chip Mistral-7b) + is_mesh_device = device.__class__.__name__ == "MeshDevice" + + self.weight = ttnn.as_tensor( + state_dict[weight_name].unsqueeze(0).view(1, 1, dim), + device=device, + dtype=weight_dtype, + layout=ttnn.TILE_LAYOUT if weight_dtype == ttnn.bfloat8_b else ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + ) + + if self.is_distributed: + self.weight_distributed = ttnn.as_tensor( + torch_weight, + device=device, + dtype=weight_dtype, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=weight_memory_config, + cache_file_name=cache_name, + mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) + if is_mesh_device + else None, + ) + + self.sharded_output_config = sharded_output_config + self.sharded_program_config = sharded_program_config + self.output_mem_config = output_mem_config + + self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + self.simplified_rms = simplified_rms + + def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: + # If input is sharded do sharded RMSNorm and optionally return sharded output + program_config = self.sharded_program_config if in_sharded else None + memory_config = self.sharded_output_config if out_sharded else None + distributed = self.is_distributed and self.is_distributed(mode) + norm = ( + self._simplified_rmsnorm + if self.simplified_rms + else self._distributed_rmsnorm + if distributed + else ttnn.rms_norm + ) + + weight = self.weight_distributed if distributed else self.weight + + if in_sharded: + assert not distributed, "Distributed RMSNorm does not support sharded inputs" + else: + assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" + + x = norm( + x, 
+ epsilon=self.eps, + weight=weight, + program_config=program_config, + memory_config=memory_config, + compute_kernel_config=self.compute_kernel_config_hifi2, + ) + + if in_sharded and not out_sharded: + return ttnn.sharded_to_interleaved(x) + else: + return x + + def _simplified_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + inp = ttnn.sharded_to_interleaved(inp, ttnn.DRAM_MEMORY_CONFIG) + xnorm = ttnn.pow(inp, 2) + xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) + xnorm = ttnn.rsqrt(xnorm + epsilon) + xnorm = ttnn.multiply(inp, xnorm) + weight = ttnn.reshape(weight, [1, 1, -1]) + output = ttnn.multiply(xnorm, (weight), use_legacy=False) + + if memory_config is not None: + output = ttnn.to_memory_config(output, memory_config) + + ttnn.deallocate(xnorm) + ttnn.deallocate(weight) + + return output + + def _distributed_rmsnorm( + self, inp, epsilon=None, weight=None, program_config=None, memory_config=None, compute_kernel_config=None + ): + assert program_config is None, "Distributed RMSNorm does not support sharded inputs" + assert memory_config is None, "Distributed RMSNorm does not support sharded outputs" + + # Run distributed rmsnorm part 1 + tt_stats = ttnn.rms_norm_pre_all_gather(inp, compute_kernel_config=compute_kernel_config, dtype=ttnn.bfloat16) + # AllGather stats + tt_stats = ttnn.all_gather( + tt_stats, + dim=3, + num_links=1, + topology=self.ccl_topology, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + # Run distributed rmsnorm part 2 + tt_out = ttnn.rms_norm_post_all_gather( + inp, + tt_stats, + epsilon=epsilon, + weight=weight, + compute_kernel_config=compute_kernel_config, + ) + tt_stats.deallocate(True) + + return tt_out diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/experimental/mistral_24b/tt/vision_attention.py new file mode 100644 index 000000000000..3bcc772f36de --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_attention.py @@ 
-0,0 +1,249 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.utility_functions import nearest_32 + + +def rotate_half(x): + last_dim = x.shape[-1] + half = last_dim // 2 + + x1 = ttnn.slice(x, (0, 0, 0, 0), (x.shape[0], x.shape[1], x.shape[2], half)) + x2 = ttnn.slice(x, (0, 0, 0, half), (x.shape[0], x.shape[1], x.shape[2], last_dim)) + + neg_x2 = ttnn.mul(x2, -1, use_legacy=False) + return ttnn.concat([neg_x2, x1], dim=-1) + + +def apply_rotary_pos_emb_vision_tt(q, k, cos, sin): + cos = ttnn.unsqueeze(cos, 0) + sin = ttnn.unsqueeze(sin, 0) + + q_embed = ttnn.add(ttnn.mul(q, cos, use_legacy=True), ttnn.mul(rotate_half(q), sin, use_legacy=True)) + k_embed = ttnn.add(ttnn.mul(k, cos), ttnn.mul(rotate_half(k), sin)) + return q_embed, k_embed + + +class TtMistralImageAttention(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + self.num_devices = configuration.num_devices + + self.hidden_size = configuration.vision_dim + self.n_heads = configuration.vision_attn_n_heads + self.head_dim = self.hidden_size // self.n_heads + self.n_kv_heads = self.n_heads + + self.n_local_heads = self.n_heads // configuration.num_devices + self.n_local_kv_heads = self.n_kv_heads // configuration.num_devices + + self.dtype = dtype + + self.grid_size = configuration.max_grid_size + + self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 + self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 + self.compute_kernel_config_sdpa = configuration.compute_kernel_config_sdpa + self.configuration = configuration + + self.model_config = configuration.get_model_config() + + if configuration.dummy_weights or (weight_cache_path 
is None): + cache_name = lambda _: None + else: + cache_name = lambda name: weight_cache_path / (f"{state_dict_prefix}{name}") + + wq_str = f"{state_dict_prefix}wq.weight" + wk_str = f"{state_dict_prefix}wk.weight" + wv_str = f"{state_dict_prefix}wv.weight" + wo_str = f"{state_dict_prefix}wo.weight" + + # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices + assert self.n_heads % configuration.num_devices == 0 + assert self.n_kv_heads % configuration.num_devices == 0 + + # Pad head_dim to multiple of 32 + def pad_head_dim(weight, heads_out=True): + # Pad head dim to multiple of 32 + # heads_out means that the output dim of this weight contains heads. + dim = weight.shape[1] + assert weight.shape[0] == dim + padded_head_dim = nearest_32(self.head_dim) + padding_size = padded_head_dim - self.head_dim + if padding_size > 0: + if heads_out: + weight = weight.transpose(-1, -2) + weight = weight.reshape(dim, self.n_heads, self.head_dim) + padding = torch.zeros(dim, self.n_heads, padding_size, dtype=weight.dtype) + weight = torch.cat([weight, padding], dim=-1) + weight = weight.reshape(dim, self.n_heads * padded_head_dim) + if heads_out: + weight = weight.transpose(-1, -2) + return weight + + wq_padded = pad_head_dim(self.state_dict[wq_str]) + wk_padded = pad_head_dim(self.state_dict[wk_str]) + wv_padded = pad_head_dim(self.state_dict[wv_str]) + wo_padded = pad_head_dim(self.state_dict[wo_str], heads_out=False) + wq_chunked, wk_chunked, wv_chunked = ( + torch.chunk(w, configuration.num_devices) for w in [wq_padded, wk_padded, wv_padded] + ) + + self.wqkv = ttnn.as_tensor( + torch.concat( + [ + torch.concat( + [ + torch.transpose( + wq_chunked[i], + -2, + -1, + ), + torch.transpose( + wk_chunked[i], + -2, + -1, + ), + torch.transpose( + wv_chunked[i], + -2, + -1, + ), + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ), + device=self.mesh_device, + 
mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name("wqkv_sharded"), + ) + + self.wo = ttnn.as_tensor( + torch.transpose( + wo_padded, + -2, + -1, + ), + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + dtype=self.dtype, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name("wo_sharded"), + ) + + self.scale = self.head_dim**-0.5 + + def forward(self, x_11SH, position_embeddings=None): + seq_len = x_11SH.shape[-2] + + MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ + + if seq_len > MAX_MM_SEQ_LEN: + x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + xqkv_fused = ttnn.linear( + x_11SH, + self.wqkv, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config_hifi4, + program_config=self.model_config["IMAGE_ATTN_QKV_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) + + # split qkv into heads + ( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + ) = ttnn.experimental.nlp_create_qkv_heads( + xqkv_fused, + num_heads=self.n_local_heads, + num_kv_heads=self.n_local_kv_heads, + transpose_k_heads=False, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + if position_embeddings is not None: + cos, sin = position_embeddings + q_heads_1QSD, k_heads_1KSD = apply_rotary_pos_emb_vision_tt(q_heads_1QSD, k_heads_1KSD, cos, sin) + ttnn.deallocate(xqkv_fused) + # TODO: get this from model_config + sdpa_cfg = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=(8, 8), q_chunk_size=128, k_chunk_size=128, exp_approx_mode=False + ) + attn_output_1QSD = ttnn.transformer.scaled_dot_product_attention( + q_heads_1QSD, + k_heads_1KSD, + v_heads_1VSD, + is_causal=False, + scale=self.scale, + 
program_config=sdpa_cfg, + compute_kernel_config=self.compute_kernel_config_sdpa, + ) + # deallocate keys and values + ttnn.deallocate(q_heads_1QSD) + ttnn.deallocate(k_heads_1KSD) + ttnn.deallocate(v_heads_1VSD) + + ### + # Output matmul + ### + attn_output_11SH = ttnn.experimental.nlp_concat_heads( + attn_output_1QSD, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + ttnn.deallocate(attn_output_1QSD) + + # reshaping long sequence to matmul fit on device + if seq_len > MAX_MM_SEQ_LEN: + attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) + + output_11SH = ttnn.linear( + attn_output_11SH, + self.wo, + compute_kernel_config=self.compute_kernel_config_hifi4, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + program_config=self.model_config["IMAGE_ATTN_OUT_PROGCFG"](seq_len, MAX_MM_SEQ_LEN), + ) + if seq_len > MAX_MM_SEQ_LEN: + output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) + ttnn.deallocate(attn_output_11SH) + + # All reduce + if self.num_devices > 1: # replace with reduce_scatter and all_gather + dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) + output_11SH.deallocate(True) + dense_out_reduced = ttnn.experimental.fast_reduce_nc( + dense_out_gathered, dims=[1], output=None, compute_kernel_config=None + ) + # slicing the required sequence length + dense_out_reduced = dense_out_reduced[:, :, : dense_out_gathered.shape[-2], :] + return dense_out_reduced + else: + return output_11SH diff --git a/models/experimental/mistral_24b/tt/vision_conv2d.py b/models/experimental/mistral_24b/tt/vision_conv2d.py new file mode 100644 index 000000000000..0b16dca7fbcf --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_conv2d.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class TtMistralConv2dPatch(LightweightModule): + """Conv2D Patching layer. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + dtype, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias, + ): + super().__init__() + + self.mesh_device = mesh_device + self.num_devices = self.mesh_device.get_num_devices() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.bias = ( + ttnn.as_tensor( + torch.reshape(state_dict[f"{state_dict_prefix}_linear.bias"], (1, -1)), + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + if bias + else None + ) + + self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) + + weight = state_dict[f"{state_dict_prefix}weight"] + if weight.ndim == 4: + weight = weight.reshape(out_channels, -1).T + # pad_len = nearest_32(weight.shape[-1]) - weight.shape[-1] + # padding = torch.zeros(self.out_channels, pad_len, dtype=weight.dtype) + # padded_weight = torch.cat([weight, padding], dim=-1) + # padded_weight = padded_weight.permute(1, 0).reshape(1, 1, -1, self.out_channels) + + self._linear_weight = ttnn.as_tensor( + weight, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + 
self.compute_kernel_config = ttnn.init_device_compute_kernel_config( + mesh_device.arch(), + math_fidelity=ttnn.MathFidelity.HiFi2, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + def forward(self, x: torch.Tensor): + x = self._unfold(x) + x = x.permute(0, 2, 1) + + # Need to pad the last dimension of x to be a multiple of a tile + # pad_len = nearest_32(x.shape[-1]) - x.shape[-1] + # padding = torch.zeros((x.shape[0], x.shape[1], pad_len), dtype=x.dtype, device=x.device) + # x = torch.cat([x, padding], dim=-1) + + x = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=self.mesh_device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + ) + + out = ttnn.linear( + x, + self._linear_weight, + bias=self.bias, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) + + return out diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/experimental/mistral_24b/tt/vision_mlp.py new file mode 100644 index 000000000000..61ac96c3ed45 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_mlp.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule + + +class MistralTTVisionMLP(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + weight_cache_path, + dtype, + state_dict_prefix=None, + ): + super().__init__() + + self.mesh_device = mesh_device + self.args = args + self.state_dict = state_dict + self.dim = args.dim + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + # Weights and Biases + self.w1 = as_tensor("w1", dtype) + self.b1 = as_tensor("w1", ttnn.bfloat16, is_bias=False) + + self.w3 = as_tensor("w3", dtype) + self.b3 = as_tensor("w3", ttnn.bfloat16, is_bias=False) + + self.w2 = as_tensor("w2", dtype) + self.b2 = as_tensor("w2", ttnn.bfloat16, is_bias=False) + + self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + fp32_dest_acc_en=True, + packer_l1_acc=True, + dst_full_sync_en=False, + ) + + def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + """ + Qwen HF MLP reference: + output = down_proj(act_fn(gate_proj(x)) * up_proj(x)) + Mapping: + w1 -> gate_proj + w3 -> up_proj + w2 -> down_proj + """ + + # if x.shape[-2] >= self.args.prefill_len_cutoff and mode != "decode": + # x = ttnn.reshape(x, [1, x.shape[-2] // self.args.prefill_len_cutoff, self.args.prefill_len_cutoff, -1]) + + # Linear 
with SILU activation + w1_out = ttnn.linear( + x, + self.w1, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="silu", + compute_kernel_config=self.compute_kernel_config, + ) + + w3_out = ttnn.linear( + x, + self.w3, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + # Element-wise multiply + w2_in = ttnn.mul(w1_out, w3_out, dtype=ttnn.bfloat16) + + # Final projection + w2_out = ttnn.linear( + w2_in, + self.w2, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + compute_kernel_config=self.compute_kernel_config, + ) + + ttnn.deallocate(w1_out) + ttnn.deallocate(w3_out) + ttnn.deallocate(w2_in) + return w2_out diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/experimental/mistral_24b/tt/vision_mmp.py new file mode 100644 index 000000000000..a54f3057270a --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_mmp.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + + +import torch +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm +import ttnn +from ttnn import ConcatMeshToTensor + + +class TTMistral3PatchMerger(LightweightModule): + def __init__( + self, + mesh_device, + args, + state_dict, + state_dict_prefix, + weight_cache_path=None, + dtype=ttnn.bfloat16, + ): + super().__init__() + self.device = mesh_device + hidden_size = args.vision_dim + self.spatial_merge_size = 2 # TODO Handle in Model_config spatial_merge_size + self.patch_size = args.vision_patch_size + self.args = args + # self.patch_size = ttnn.from_torch( + # torch.tensor(args.vision_patch_size, dtype=torch.int32), + # device=mesh_device, + # dtype=ttnn.int32 + # ) + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + self.merging_weights = as_tensor("merging_layer", dtype) + self.merging_bias = as_tensor("merging_layer", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes) -> ttnn.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(ttnn.split(image_features, 
tokens_per_image, dim=0)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + + image_tokens = ttnn.to_layout(image_tokens, ttnn.ROW_MAJOR_LAYOUT) + + image_grid = ttnn.view(image_tokens, (h, w, d)) + # Permute the grid to have channels last + image_grid = ttnn.permute(image_grid, (2, 0, 1)) # Channels first + image_grid = ttnn.unsqueeze(image_grid, dim=0) # Add batch dimension + # Reshape the grid to merge patches + if self.args.num_devices > 1: + image_grid_torch = ttnn.to_torch(image_grid, mesh_composer=ConcatMeshToTensor(self.device, dim=0)) + image_grid_torch = image_grid_torch[0].unsqueeze(0) # shape: [1, 1024, 30, 44] + image_grid_torch = image_grid_torch.to(dtype=torch.bfloat16) + else: + image_grid_torch = ttnn.to_torch(image_grid).to(dtype=torch.bfloat16) + + grid = torch.nn.functional.unfold( + image_grid_torch, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + + grid = ttnn.from_torch(grid, device=self.device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + grid = ttnn.view(grid, (d * self.spatial_merge_size**2, -1)) + grid = ttnn.transpose(grid, 0, 1) # Transpose to have features first + + permuted_tensor.append(grid) + + image_features = ttnn.concat(permuted_tensor, dim=0) + # Apply merging layer + image_features = ttnn.linear( + image_features, self.merging_weights, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return image_features + + +class TTMistral3MultiModalProjector(LightweightModule): + def __init__(self, mesh_device, args, state_dict, state_dict_prefix, dtype, eps, weight_cache_path=None): + super().__init__() + + self.norm = RMSNorm( + device=mesh_device, + dim=args.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="norm", + weight_dtype=dtype, + eps=eps, + is_distributed=False, + simplified_rms=True, + ) + + self.patch_merger = TTMistral3PatchMerger( + mesh_device=mesh_device, + args=args, + state_dict=state_dict, + 
state_dict_prefix=f"{state_dict_prefix}patch_merger.", + ) + + def get_weight(name): + return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) + + def get_bias(name): + return state_dict[f"{state_dict_prefix}{name}.bias"] + + def cache_name(name): + if args.dummy_weights: + return None + return weight_cache_path / f"{state_dict_prefix}.{name}" + + def as_tensor(name, dtype, is_bias=False): + tensor_data = get_bias(name) if is_bias else get_weight(name) + return ttnn.as_tensor( + tensor_data, + dtype=dtype, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + # cache_file_name=cache_name(name), + ) + + self.linear_1_weight = as_tensor("linear_1", dtype) + self.linear_1_bias = as_tensor("linear_1", ttnn.bfloat16, is_bias=False) + + self.linear_2_weight = as_tensor("linear_2", dtype) + self.linear_2_bias = as_tensor("linear_2", ttnn.bfloat16, is_bias=False) + + def forward(self, image_features: ttnn.Tensor, image_sizes): + image_features = self.norm(image_features, mode="decode") + image_features = self.patch_merger(image_features, image_sizes) + + hidden_states = ttnn.linear( + image_features, + self.linear_1_weight, + dtype=ttnn.bfloat16, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + activation="gelu", # Using GELU activation as per Mistral 3 model + ) + + hidden_states = ttnn.linear( + hidden_states, self.linear_2_weight, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + return hidden_states diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py new file mode 100644 index 000000000000..8fc053f87164 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm + +from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention +from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP + + +class TtPixtralImageTransformerBlock(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + ): + super().__init__() + self.state_dict = state_dict + self.mesh_device = mesh_device + self.configuration = configuration + self.num_devices = configuration.num_devices + self.hidden_size = configuration.vision_dim + + self.attention_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="attention_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + + self.attention = TtLlamaImageAttention( + mesh_device, + state_dict, + state_dict_prefix=f"{state_dict_prefix}attention.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + + self.ffn_norm = RMSNorm( + device=mesh_device, + dim=configuration.vision_dim, + state_dict=state_dict, + state_dict_prefix=state_dict_prefix, + weight_key="ffn_norm", + weight_dtype=dtype, + is_distributed=False, + simplified_rms=True, + ) + + self.mlp = MLP( + mesh_device=mesh_device, + args=configuration, + state_dict=state_dict, + weight_cache_path=weight_cache_path, + state_dict_prefix=f"{state_dict_prefix}feed_forward.", + dtype=dtype, + ) + + def forward(self, x_input, position_embeddings=None): + mode = "prefill" + # attention norm Input and result replicated + attn_norm_res = self.attention_norm(x_input, mode=mode) + # attention Input and results replicated + attn_out = self.attention(attn_norm_res, 
position_embeddings=position_embeddings) + res = ttnn.add(x_input, attn_out, use_legacy=True) + ffn_norm_res = self.ffn_norm(res, mode=mode) + mlp_out = self.mlp(ffn_norm_res) + out = ttnn.add(res, mlp_out) + return out diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py new file mode 100644 index 000000000000..e28a5862074d --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from tqdm import tqdm + +from models.common.lightweightmodule import LightweightModule +from models.experimental.mistral_24b.tt.vision_pixtral_image_block import TtPixtralImageTransformerBlock + + +class TtPixtralTransformer(LightweightModule): + def __init__( + self, + mesh_device, + state_dict, + state_dict_prefix, + weight_cache_path, + dtype, + configuration, + layers, + ): + super().__init__() + + self.state_dict = state_dict + self.mesh_device = mesh_device + + block_key = "layers" + self.resblocks = [ + TtPixtralImageTransformerBlock( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", + weight_cache_path=weight_cache_path, + dtype=dtype, + configuration=configuration, + ) + for i in tqdm(range(layers), desc=f"Loading vision transformer layers") + ] + + def forward(self, x, return_intermediate=None, position_embeddings=None): + """ + Different from reference impl in that if return_intermediates, it returns + a list of intermediate tensors rather than a stack of intermediates. + Outer code will have to be aware and handle this correctly. 
+ """ + out = [] + for idx, r in enumerate(self.resblocks): + if return_intermediate is not None and idx in return_intermediate: + out.append(x) + x = r(x, position_embeddings=position_embeddings) + if return_intermediate is not None: + return x, out + return x diff --git a/models/experimental/mistral_24b/tt/vision_rope.py b/models/experimental/mistral_24b/tt/vision_rope.py new file mode 100644 index 000000000000..d356e8172807 --- /dev/null +++ b/models/experimental/mistral_24b/tt/vision_rope.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.common import precompute_vision_freqs +from ttnn import ReplicateTensorToMesh + + +def compute_gather_cos_sin(dhead, max_patches_per_side, theta, scale_factor, orig_context_len, position_ids): + cos, sin = precompute_vision_freqs(dhead, max_patches_per_side, theta, scale_factor, orig_context_len) + return cos, sin + + +class VisionRotarySetup(LightweightModule): + def __init__( + self, + device, + batch_size: int, + head_dim: int, + image_size: int, + patch_size: int, + max_seq_len: int, + rope_theta: float, + scale_factor: float, # use None to disable rope scaling + orig_context_len: int, # only used if scaling enabled + datatype=ttnn.bfloat16, + ): + super().__init__() + + self.batch_size = batch_size + self.head_dim = head_dim + self.device = device + self.is_mesh_device = isinstance(device, ttnn._ttnn.multi_device.MeshDevice) + self.num_devices = device.get_num_devices() if self.is_mesh_device else 1 + if self.num_devices == 32: + self.batch_size_per_device_group = max(self.batch_size // list(device.shape)[1], 1) + else: + self.batch_size_per_device_group = self.batch_size + self.core_grid = device.compute_with_storage_grid_size() + + max_patches_per_side = image_size // patch_size + + # Generate the cos/sin matrices needed for 
ttnn.embedding op + cos_matrix, sin_matrix = compute_gather_cos_sin( + dhead=head_dim, + max_patches_per_side=max_patches_per_side, + theta=rope_theta, + scale_factor=scale_factor, + orig_context_len=orig_context_len, + position_ids=torch.arange(max_seq_len), + ) + self.cos_matrix = ttnn.from_torch( + cos_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + self.sin_matrix = ttnn.from_torch( + sin_matrix, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=datatype, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + + def get_rot_mats(self, position_idxs, return_rot_idxs=False): + device = self.device + + # return self.cos_matrix, self.sin_matrix + # If position_idxs is a torch tensor, get the TTNN version of it + if isinstance(position_idxs, torch.Tensor): + rot_idxs = position_idxs.unsqueeze(0) + else: + rot_idxs = position_idxs + assert len(rot_idxs.shape) == 2 and rot_idxs.shape[0] == 1, "rot_idxs must be a [1, batch] tensor" + + rot_idxs = ttnn.from_torch( + rot_idxs, + device=device, + layout=ttnn.TILE_LAYOUT, + dtype=ttnn.uint32, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ) + # Send the idxs to device + if rot_idxs.device != device: + rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + embedding_layout = ttnn.TILE_LAYOUT + cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout) # [1, batch, head_dim] + sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout) # [1, batch, head_dim] + + if return_rot_idxs: + return [cos, sin], rot_idxs + ttnn.deallocate(rot_idxs) + return [cos, sin] diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 28612764db38..d4e0e7cea40a 100644 --- a/models/tt_transformers/tt/common.py +++ 
b/models/tt_transformers/tt/common.py @@ -98,6 +98,26 @@ def rope_scaling_model_factory(rope_scaling_params: dict) -> RopeScaling: raise ValueError(f"Unexpected RoPE scaling type: {rope_scaling_type}") +def position_ids_in_meshgrid_tt(tt_patch_embeds_list, max_width, device): + position_ids_tt = [] + for tt_patch in tt_patch_embeds_list: + shape = tt_patch.shape + height, width = shape[-2], shape[-1] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + + tt_ids = ttnn.from_torch( + ids, + device=device, + dtype=ttnn.uint32, + layout=ttnn.ROW_MAJOR_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + position_ids_tt.append(tt_ids[:, 0]) + return ttnn.concat(position_ids_tt, dim=0) + + def encode_prompt_instruct(tokenizer, prompt_text, system_prompt_text=None): """<|begin_of_text|><|start_header_id|>system<|end_header_id|> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|> @@ -283,6 +303,43 @@ def apply_llama3_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) +def apply_scaling_vision(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + return freqs / scale_factor + + +def precompute_vision_freqs( + dim: int, max_patches_per_side: int, theta: float, scale_factor=None, orig_context_len=None +): + # Compute base frequencies + base_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + if scale_factor is not None: + base_freqs = apply_scaling_vision(base_freqs, scale_factor, orig_context_len) + + # Get height and width indices + h_idx = torch.arange(max_patches_per_side) + w_idx = torch.arange(max_patches_per_side) + + # Compute 2D frequency matrices + freqs_h = torch.outer(h_idx, base_freqs[::2]) + freqs_w = torch.outer(w_idx, base_freqs[1::2]) + + # Broadcast + merge + inv_freq = torch.cat( + [ + 
freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape( + -1, dim // 2 + ) # Shape: [H*W, dim//2] + + full_freqs = torch.cat([inv_freq, inv_freq], dim=-1) + cos = full_freqs.cos() + sin = full_freqs.sin() + return cos, sin # Shape: [H*W, dim] + + def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): """ Precompute the frequency tensor for sine and cosine values with given dimensions. diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index 2eb8863c24c9..5433d11e3538 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -188,6 +188,7 @@ def prefill_forward_single_user_text( chunk_start_idx=chunk_start, get_last_token=(last_token_idx_in_chunk // 32) * 32, kv_cache=kv_cache, + **kwargs, ) if chunk_start == last_chunk_start: diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index fa00a6b882cf..5b658b9dfca9 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -528,7 +528,10 @@ def convert_hf_qkv_to_meta_format(loaded_weights, head_dim): """Convert HuggingFace QKV weights to Meta format for RoPE compatibility.""" converted_weights = {} for key, tensor in loaded_weights.items(): - if "q_proj.weight" in key or "k_proj.weight" in key: + if "vision_tower" in key: + # Skip conversion for vision tower weights + converted_weights[key] = tensor + elif "q_proj.weight" in key or "k_proj.weight" in key: # For weights: n_heads = tensor.shape[0] // head_dim n_heads = tensor.shape[0] // head_dim converted_weights[key] = reverse_permute(tensor, n_heads, tensor.shape[0], tensor.shape[1]) @@ -596,6 +599,38 @@ def map_hf_to_meta_keys(loaded_weights): return replace_keys(loaded_weights, replacements) +def map_vision_meta_to_hf_keys(loaded_weights): + """ + Map Hugging Face 
checkpoint keys to Meta checkpoint keys.
+    You can use this to support other models by adding more mappings.
+    See replace_keys for more details on the format of replacements.
+    """
+    base_mapping = [
+        ("w1", "gate_proj"),
+        ("w2", "down_proj"),
+        ("w3", "up_proj"),
+        ("wq", "q_proj"),
+        ("wk", "k_proj"),
+        ("wv", "v_proj"),
+        ("wo", "o_proj"),
+    ]
+
+    extra_mapping = [
+        ("attention_norm", "input_layernorm"),
+        ("ffn_norm", "post_attention_layernorm"),
+        ("attention", "self_attn"),
+        ("feed_forward", "mlp"),
+    ]
+
+    model_name = os.getenv("HF_MODEL") or ""  # tolerate unset HF_MODEL ("Mistral" in None raises TypeError)
+    if "Mistral" in model_name:
+        mapping = base_mapping
+    else:
+        mapping = base_mapping + extra_mapping
+
+    return replace_keys(loaded_weights, mapping)
+
+
 def convert_vision_meta_to_hf(state_dict, head_dim):
     # state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim)
     state_dict = map_vision_meta_to_hf_keys(state_dict)
diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py
index 65c339b68c9a..1e8c94976919 100644
--- a/models/tt_transformers/tt/model_config.py
+++ b/models/tt_transformers/tt/model_config.py
@@ -1395,6 +1395,8 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False):
 
     def _get_text_prefix(self):
         if self.is_vision():
+            if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name:
+                return "language_model."
             return "text_model."
else:
             return ""
 
@@ -1530,6 +1532,24 @@ def _set_params_from_dict(self, config, is_hf=False):
             self._set_vision_params(config)
         self.is_multimodal = "vision_config" in config or self.is_vision()
 
+        # Vision params (Meta-specific)
+        self.vision_chunk_size = config.get("vision_chunk_size", 896)
+        self.vision_max_num_chunks = config.get("vision_max_num_chunks", 4)
+        self.vision_num_cross_attention_layers = config.get("vision_num_cross_attention_layers", -1)
+
+        # Vision constants
+        self.vision_dim = 1280
+        self.vision_mlp_ratio = 4
+        self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio)
+        self.vision_act_layer = ttnn.UnaryOpType.GELU
+        self.vision_dropout = 0.0
+        self.vision_attn_n_heads = 16
+        self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads
+        self.vision_n_layers = 32
+        self.vision_n_global_layers = 8
+        self.vision_max_num_tiles = 4
+        self.vision_patch_size = 14
+        self.vision_in_channels = 3
 
         self.state_dict_text_prefix = self._get_text_prefix()
         self.state_dict_vision_prefix = self._get_vision_prefix()
@@ -1605,28 +1625,33 @@ def _set_params(self, checkpoint_dir):
             else None
         )
 
-    def _set_vision_params(self, config):
-        vision_config = config.get("vision_config", config)
+    def _set_vision_params(self, vision_config):
+        vision_config = vision_config.get("vision_config", vision_config)  # was `config`: undefined after the parameter rename (NameError)
-        self.vision_chunk_size = vision_config.get("vision_chunk_size", -1)
-        self.image_size = vision_config.get("image_size", 896)
+        self.vision_chunk_size = vision_config.get("vision_chunk_size", 896)
         self.vision_max_num_chunks = vision_config.get("vision_max_num_chunks", 4)
-        self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", -1)
-        self.vision_dim = vision_config.get("hidden_size", 1280)
-
+        self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", 8)
+        self.vision_dim = vision_config.get("hidden_size", 1152)
         intermediate_size = vision_config.get("intermediate_size", self.vision_dim * 4)
+        self.vision_image_size = vision_config.get("image_size",
1540) + self.vision_rope_theta = vision_config.get("rope_theta", 10000.0) + self.image_token_index = vision_config.get("image_token_index", 10) + self.vision_mlp_ratio = intermediate_size // self.vision_dim self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio) - self.vision_attn_n_heads = vision_config.get("num_attention_heads", 16) + self.vision_attn_n_heads = vision_config.get("num_attention_heads") or vision_config.get("num_heads") or 16 self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads - self.vision_n_layers = vision_config.get("num_hidden_layers", 32) + self.vision_n_layers = vision_config.get("num_hidden_layers") or vision_config.get("depth") or 27 self.vision_patch_size = vision_config.get("patch_size", 14) self.vision_in_channels = vision_config.get("num_channels", 3) self.vision_dropout = vision_config.get("attention_dropout", 0.0) self.mm_tokens_per_image = vision_config.get("mm_tokens_per_image", 256) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self.vision_head_dim = vision_config.get("head_dim", 64) + # Optional vision activation layer, defaults to GELU act_layer = vision_config.get("act_layer", "gelu").lower() self.vision_act_layer = { @@ -1638,8 +1663,26 @@ def _set_vision_params(self, config): # Optional tuning knobs self.vision_max_num_tiles = vision_config.get("max_num_tiles", 4) self.vision_n_global_layers = vision_config.get("n_global_layers", 8) + # self.vision_max_num_tiles = vision_config.get("max_num_tiles", 4) + # self.vision_n_global_layers = vision_config.get("n_global_layers", 8) + + # # Optional Meta-specific knobs + # self.vision_max_num_chunks = vision_config.get("max_num_chunks", 4) + # self.vision_num_cross_attention_layers = vision_config.get("num_cross_attention_layers", -1) def _set_hf_params(self, checkpoint_dir): + def merge_text_config(base_config): + text_config = base_config.get("text_config", {}) + # Merge non-nested keys into text_config + text_config.update({k: v for k, 
v in base_config.items() if k not in ["text_config", "vision_config"]}) + return text_config + + def merge_vision_config(base_config): + vision_config = base_config.get("vision_config", {}) + # Merge non-nested keys into vision_config + vision_config.update({k: v for k, v in base_config.items() if k not in ["text_config", "vision_config"]}) + return vision_config + if self.from_hf_url: # Special case Qwen2.5-VL models until they are fully integrated into a HF release if "Qwen/Qwen2.5-VL" in self.model_name: @@ -1655,12 +1698,25 @@ def _set_hf_params(self, checkpoint_dir): else: self.hf_config = AutoConfig.from_pretrained(self.CKPT_DIR) config = self.hf_config.to_dict() + if "text_config" in config or "vision_config" in config: + merged_text_config = merge_text_config(config) + self._set_params_from_dict(merged_text_config, is_hf=True) + + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + self._set_vision_params(config["vision_config"]) + else: + if "vision_config" in config: + merged_vision_config = merge_vision_config(config) + self._set_vision_params(merged_vision_config) + else: + self._set_params_from_dict(config, is_hf=True) + else: config_file = os.path.join(checkpoint_dir, "config.json") assert os.path.exists(config_file), f"config.json file not found at {config_file}" with open(config_file, "r") as f: config = json.load(f) - self._set_params_from_dict(config, is_hf=True) + self._set_params_from_dict(config) def __repr__(self): return f"""ModelArgs( @@ -1753,6 +1809,8 @@ def load_state_dict(self): ) print("Loading Qwen2.5-VL model: ", AutoModelForCausalLM) + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM else: from transformers import AutoModelForCausalLM @@ -2135,55 +2193,76 @@ def create_tokenizer(self): logger.info(f"Model name: {self.model_name}") logger.info(f"Base model name: {self.base_model_name}") - try: - # Try to load tokenizer from the 
original model path - tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) - logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") - except Exception as e: - logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") - - # Try to use base model tokenizer as fallback - fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) - - # If no direct match, try to infer from model name patterns - if not fallback_tokenizer_path: - model_name_lower = self.model_name.lower() - if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" - elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" - elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" - elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" - elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" - elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" - elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: - fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" - elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: - fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" - elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: - 
fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" - elif "mistral" in model_name_lower and "7b" in model_name_lower: - fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" - - if fallback_tokenizer_path: - logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") - try: - tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path) - logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") - except Exception as fallback_e: - logger.error(f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}") - raise fallback_e - else: - logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") - raise e + # Special handling for Mistral-Small-3.1-24B-Instruct-2503 + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + tokenizer = AutoTokenizer.from_pretrained( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", trust_remote_code=True + ) + logger.info("Manually setting Mistral instruct-style chat template on the tokenizer.") + + mistral_template = """{% for message in messages %} + {% if message['role'] == 'system' %} + <|system|> + {{ message['content'] }} + {% elif message['role'] == 'user' %} + [INST] {{ message['content'] }} [/INST] + {% elif message['role'] == 'assistant' %} + {{ message['content'] }}{{ eos_token }} + {% endif %} + {% endfor %}""" + tokenizer.chat_template = mistral_template + else: + try: + # Try to load tokenizer from the original model path + tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_PATH) + logger.info(f"Successfully loaded tokenizer from {self.TOKENIZER_PATH}") + except Exception as e: + logger.warning(f"Failed to load tokenizer from {self.TOKENIZER_PATH}: {e}") + + # Try to use base model tokenizer as fallback + fallback_tokenizer_path = base_model_tokenizer_mapping.get(self.base_model_name) + + # If no direct match, try to infer from model name patterns + if not fallback_tokenizer_path: + 
model_name_lower = self.model_name.lower() + if "qwen2.5" in model_name_lower and "0.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + elif "qwen2.5" in model_name_lower and "1.5b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-1.5B-Instruct" + elif "qwen2.5" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-3B-Instruct" + elif "qwen2.5" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-7B-Instruct" + elif "qwen2.5" in model_name_lower and "14b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-14B-Instruct" + elif "qwen2.5" in model_name_lower and "32b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-32B-Instruct" + elif "qwen2.5" in model_name_lower and "72b" in model_name_lower: + fallback_tokenizer_path = "Qwen/Qwen2.5-72B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "8b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-8B-Instruct" + elif "llama" in model_name_lower and "3.1" in model_name_lower and "70b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.1-70B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "1b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-1B-Instruct" + elif "llama" in model_name_lower and "3.2" in model_name_lower and "3b" in model_name_lower: + fallback_tokenizer_path = "meta-llama/Llama-3.2-3B-Instruct" + elif "mistral" in model_name_lower and "7b" in model_name_lower: + fallback_tokenizer_path = "mistralai/Mistral-7B-Instruct-v0.3" + + if fallback_tokenizer_path: + logger.info(f"Attempting to use fallback tokenizer: {fallback_tokenizer_path}") + try: + tokenizer = AutoTokenizer.from_pretrained(fallback_tokenizer_path) + logger.info(f"Successfully loaded fallback tokenizer from {fallback_tokenizer_path}") + except Exception as fallback_e: + 
logger.error( + f"Failed to load fallback tokenizer from {fallback_tokenizer_path}: {fallback_e}" + ) + raise fallback_e + else: + logger.error(f"No fallback tokenizer found for base model: {self.base_model_name}") + raise e # Add meta-compatible stop token list to the HF tokenizer if not "stop_tokens" in tokenizer.__dict__: @@ -2246,6 +2325,8 @@ def reference_transformer(self, wrap=True, load_checkpoint=False): from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration as AutoModelForCausalLM, ) + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration as AutoModelForCausalLM else: from transformers import AutoConfig, AutoModelForCausalLM @@ -2287,6 +2368,8 @@ def reference_vision_multi_modal(self): layer = model.multi_modal_projector # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_rms_norm(self): @@ -2294,6 +2377,22 @@ def reference_vision_rms_norm(self): layer = model.multi_modal_projector.mm_soft_emb_norm # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm_qwen(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.visual.blocks[0].norm1 + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms_norm_qwen_merger(self): + model = 
self.reference_vision_transformer(wrap=False) + layer = model.visual.merger + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) return layer def reference_rms_norm(self): @@ -2303,7 +2402,8 @@ def reference_rms_norm(self): return RMSNorm(self.dim, self.norm_eps) else: model = self.reference_transformer(wrap=False) - layer = model.model.norm + layers = getattr(model, "layers", getattr(model, "model", {}).layers) + layer = layers[0].input_layernorm layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) return layer @@ -2323,6 +2423,12 @@ def reference_vision_transformer(self, wrap=True, load_checkpoint=False): model = Gemma3ForConditionalGeneration.from_pretrained(self.CKPT_DIR) model = model + elif "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + from transformers import Mistral3ForConditionalGeneration + + model = Mistral3ForConditionalGeneration.from_pretrained(self.CKPT_DIR, torch_dtype=torch.bfloat16) + model = model + else: if self.cached_hf_model is None: model = AutoModelForCausalLM.from_pretrained(self.CKPT_DIR) @@ -2343,6 +2449,17 @@ def reference_gemma_model(self): layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer + def reference_vision_model(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + # Mistral-Small-3.1-24B-Instruct-2503 has a different structure + layer = model.vision_tower + else: + layer = model.vision_tower.vision_model + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + def reference_vision_model(self): model = self.reference_vision_transformer(wrap=False) layer = model.vision_tower.vision_model @@ 
-2357,11 +2474,41 @@ def reference_vision_mlp(self): # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer + def reference_pixtral_image_block(self, layer_num=0): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[layer_num] + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_mlp(self, layer_idx=0): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[layer_idx].feed_forward + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rms(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.transformer.layers[0].ffn_norm + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_conv2d_patch(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.patch_conv + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + def reference_siglip_patch_embed(self): model = self.reference_vision_transformer(wrap=False) layer = model.vision_tower.vision_model.embeddings.patch_embedding # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_pos_embedding(self): 
@@ -2369,6 +2516,8 @@ def reference_vision_pos_embedding(self): layer = model.vision_tower.vision_model.embeddings.position_embedding # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_embedding(self): @@ -2395,6 +2544,33 @@ def reference_vision_attention(self): layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_layernorm(self): + model = self.reference_vision_transformer(wrap=False) + layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm1 + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_attention(self, layer_idx=0): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer.layers[layer_idx].attention + else: + layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + return layer + + def reference_vision_rot_emb(self): + model = self.reference_vision_transformer(wrap=False) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = 
model.vision_tower.patch_positional_embedding + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_encoder_block(self): @@ -2402,6 +2578,8 @@ def reference_vision_encoder_block(self): layer = model.vision_tower.vision_model.encoder.layers[0] # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_encoder(self): @@ -2409,6 +2587,12 @@ def reference_vision_encoder(self): layer = model.vision_tower.vision_model.encoder # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer + else: + layer = model.vision_tower.vision_model.encoder + layer._load_state_dict = layer.load_state_dict + layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_mlp(self): @@ -2469,7 +2653,7 @@ def reference_attention(self): "Gemma3Attention", ) wrapper = HfAttentionWrapper( - layer, self.head_dim, model.model.rotary_emb if use_position_embeddings else None + layer, self.head_dim, model.model.rotary_emb wif use_position_embeddings else None ) return wrapper diff --git a/real_inputs/pixtral_transformer_inputs/demo_small.jpg b/real_inputs/pixtral_transformer_inputs/demo_small.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f51ba21be8d4cbeb5faceca2b88150b50926abf2 GIT binary patch literal 8554 zcmb7pWl$Z>*6!f$?zVA)hTv|&-Q7Zvjk^beLvVL@*WfO}0&E~ig1ZC=5S(wn=hUrp 
zeth><-PPS|o>}Xe)lywO-PNx@UN-@31zCAn00ssIcn?j$>o&{}c_}GVH4Rl+c_kUB z0{~#g0jR)p0sz?2-Bm+Ql3GvSfEsBNfPtoeiMgel^S{yms-djgrGIl5-u$nO|F;pv z%G%8mO1Xd*kSnxuXq)h%7~kgKnCTyE@oy~r5B73*c87A*{=u$V8d6Yf1;tD@|Aj67 z3tKw7{=<)ja)iMS9{<$!kNnd+bZbX#P3RLDTF3!6Km(8iB>(9@^c^ZrMF7Bm0|0Od z|FKzQ0zi8h01&VK#|A0@0Ibgd&^G%Y+keKy+1%CqU*X`P8ur5n0Jtg#01SNq!21pW zsD}UYLDm1IZPZW~Ika6)&}0KR0M-CCAP+bKmH;~xaRZzH7x3w=8W{VB1qk$_Ip z!gW9@WbT3x#3QY-W!dyL=@J@(35 zbI{GsOm`$M(iwA1DYg=?Z2aIdpS08q>17c{MX6P}$DjOBePhr-rfMhAJZsm>l0o2n zw1=sCQ9`VcA0>LJqiKDnEN=&_YhR}HS8wAxDLs&`G*hOs-6l>X=u&aX$EzT(=!mh+ zxoGtl^_J#Ak}c0PgLGI|)2KC%*wRcO11zv3hJH?V*0u4adZ5O_lIAfx zvqVe;#HF$MY)-Ao)x_0IsN2jJb_Dx7mBMne2G~)~?qh$Tmc*fRL3hr)0gRjvJn3kt z9h;rGcol*O)1<8rF<72+rxlEs65Xw{DkMGAXnV6}7;l2AdDg0;^JbXqY5 zzjY$3uNYzQ6CerlgaN>7*L7^uZf2c@aNw+F3GZ=m(pTNljRkB|V*m_fF$yzM$xe86r=kcUB!$FeQLg1d{Lak}CDprK zB@WK@<#|d@j1kz`C}oWGciZZs#D>)K+n9d$==s1DvFIRMO7xpN0ib>-PBL+{P*Kx} zzBkZ;+$R%j9aVG}QPo8u4A-_;Y%%nGImMstLC@uY@~*PzF!^JaLz|Z-GCTN`)_Jb} zQh_A5fv9q37~N~ykLJx4BE&;YQFe0$?3O3*tW2S(nAsF!FML~OGA2?rdK*f?oZeDn z&N9wb1?hAaK|o^9le23TDHBq=GKg3BmF7tL>FEPA#EM#GeO&XzSJbbqOfvPr~47iBDF>zE$#B|g^4 zDQwEAceRs+`roiv>s3aJ$B@^@UIFD@Q_GIB4Qa^PO#geNM#PRlM%vCgVrlsyCL%;L z>|3E|f(~VpCLFQKp!sZ1h}m&doroM~=}`TK!HwER`mSw=WbkM@6)LN{1K+oN6GAJv zEqLGPk)6tI=(zOVyU46w?zPMiD!So_hS!rVB9<*g=%$lrh<23pA&;?KgC1-3PFqjN z8e57*!^^<)C6_QbPEXRsw36^DA4jbrV{y5ShTuvCd>FVkL z)5&MFNq;oLbO2CBoqZxo{fbAvaJO6x4V^hfotxv0l@H7`98g-t4&b=`Ahd?RC8^JN zsE@q3G>w*ux(gqrFjDkRHX1zip%)k5-9yKM{wTYD0`R#fU+?VF6w*5@vU$Iv-*~;a z?GtgVX|^#$0@E92U43&q5<9(=(0NKT$FC~+H9(Rx{ExavYn1m z5sXrod@`9zJ;UQ1MpT1!`H44+;T6rb9b>E3-aDWm6+7LbJ`!P}OPTq#h?~g|J)?-( z&cZRDPAh)wk5yUuEla%u5!*sn0oyIwH}(GK_mR$BwF7^j-sI8?!AphFk0~0iNYVcl z3~l6X_Q$$_q&&dexpFgke#aCgq-~{DYeRg?xK!-2fWM%KXXB!(STq%U)r!K9$$ewA zGoJxMM#Ht@$~kAUr85He*!iLOhMB7;uNjZZYdHI98ax)Tsj3o-R0(eBxj&s?jBJw& zTg4klsmwM~#qkTk*v03mZc%G_3;`-50V4`V=A)=Kd*bqrSU()7WXoHRWUAQp(91s{ zt$t2dHSaxo(5gugY%^hdEM+EnKj8E=%nF4q#h`$4y1}*ZMT;*%{USSvRlD1;8a)^T 
z%f*~5Cl#2B(l#@RGbBN}vpL64Z!=#s{Ebg$pVjR&O)xL$rN8sCgMZVGXD2h3jnOi> zUX1fxJ&hps*X}IPI9mW*tNmvxINn7`zIb&!kZD15`%56`Gv9CbBdlt%mM|_5yIMm2Ml4g%p;)z!v@P_<}?&C82Jt$tL+RoILk4%n}<0waSLq}d0 z=PLYEqUb)GW|Kf!n9+`m##q$;CFO|}hJC;XHY+v-$4y$^a(*`Z)gh zs5_3KU${vpQML7@ZA}o$ErV57D~mWcW`5ugMz~KP{rV0@LrH0N3m0*Jvx7VcK;l*y zsH}*NNA}_K4E`W56CmpSu%>+Bia~BTMseF1L<+wzG(z>8qHzPAj%uF(>DuFg`GTQz z3p3MeEDJxK^t^o5Yk$OrO%xdW$eAlxbXTVHo zS}{6J-x=(-mmU;ZDWa{7WNuPn_&!P-%kI~(EBI=)-R-u-{cUVfVMY=PyK~2E!qJ04 zho8ueHfB{-EQ}ePa(w1-H_>&?1BbJ2s4$BC$U0T!ag;YNQ*B(!0;1j!Yr9~cT7O9B z*AtVSO;a(o&1u+UdGxE<^0lcep70NENRw21Zw7mFy+|3dHv@jbp#=>5l>oSZ&mnM+q2jmD0?#R|+*( zL(*2ceNmf~DTB@+uYi2PhiFbwRvXIeIYIT$RZnT^R9lFl*(yPjBJ zYfOL_GQ)(GQY=_t14!v|(6u~9BfGY1aqp$>Ca74+=;@K)FcwdJpqdYEkA z&T3B^erW0!3cY#@N|Rf-E^jNYg4R%fIx2NxEz+eU4`yU1wprJXtC=l)V$K9(3TjFp zDoL{*)1SZd&VTqI^?YPve7G^`_6*PK^V{e_-I)<<7R0ELtj1NrV2B*sm^QnD(ZJYe z|G6JFvTt{g3^5|+5?j2zrG39*sb5~NQVy6xxorDNFz}BpJ zbXY`}#w&z3vOdq&eFDr07zB9~KwzH{ZV zyDfc9O}OhMbv`SS%_Puyj+zC_iMO}jhc}Z_cgzqB+ALu z9uwY8%}og*TdZL zCI9n0l9)SDBstPxFU`^5WTg9=vLycy@&^qKJwY|+5TZVr*EhGgHOB@NR@e?#Srz# z4L0sKq2$alq!QsinVoj08&SYSuL;P2Q|iO(*9!;^a2pR2 zyQz})*DabV)w$87Ae0;XTiaGxK27PriGNO3z#Bibk7`SA8sxNol?OTI-Bl;U+2r^pT88CYq`^Qm-NOkSS!F+M7`xu>i(y~E2;0BsQ&H;csePClB!{!@?p_tgNzLibWUAv8h{WaG|Hk_Y?23=59y}jv^(^XHXn)Q>;(wZUIY)kw zyO_nLV&8rUB2N=j(Q9Cvp^jjf}0Hs-`x7)|)*rYL%f^ z#u-rSUs&4tzK@k?%qyI76E~)EDplIMLn}O()zp=WeSg_*IgQ61YZF)yof4thPM4}9 zPxcB>@uED1X`6^Bap!rb@JW#+h&_Up0~FPP`1*UZ8?4LFJJkLy7oK6!6COg=12@@JfsC|< zh`nPxF~;&oQeS211u|b>XBv`|(`p;qPu}lOnuUzT;>cs`5R9ex5<8m|75x0|ydRdY zK&yDqevKTnqV5_MnsktY{mE)9|H0>qd$5ufH~Q++Q&0(fvmD=P6vSVsIZ;uHJ)7fN zX^wAGf+|;Ha_eRa#M_-DD;P>%wHa_vz|a{6-Cea~yckG7Tp3Cn%aA)r`8yM47fz!&K1Lqg2Y3!nSXVZB`k2bub>1XbhZ0ac3=QVLt<$Dk{RRpu?{)W6X z8T6F8i8f~n$C`9uoHtRlXqB0QFf>Rv(-OZvrsH(r2J?L1g}W!k#pH}vq1U-ur6o(L z%+(4Fc3Epg`i(C$i!EOz*(`6500GZilgv&;7)a^s4Q)&@Ql)sGTLox=_mgo4hnuM! 
z1I;DgLq|sSy@+F0xz(cC&k^vOmUci8{9BKqBFBoQJ&;zo>abX zf6x{`@{o>x>|$!Go4mSC!|2E-Ak-U@9_1J>XYRpgQhTSwup)TLjgpjJ*gF!yz*k> z5f7Wnp9aR2e<(keJK9?9GiG*i9WBNBv!4n~{<@gZy@%tnrZf?#`b57>LQJ-uBUyav zhE9H7u!Ih{7E^ZA-a8-z2O=WXbL<^4Z2R<~aFg*SAt`emLtyGLjkT)N9 zr?pJ~s?=d}J4-wrCQe-@&2!L@S=O?@U!fknKt9eSkB{9pr@l6>mbkQu_Nb337f-Hk^fotY^vS z;j)gSJ1|P9>sigXNG?EZK}MJD^t{Ah7R}?yiA9}p9F@m*ki4%I`*2>0(bBekd$yVG zLra8?Gy4a9^ZAC%y4udy;l(2ia*Mn9#ZUL+8j@qNFY|c>OX>nI3Pn1Iz8RYtNj>@( zrC+vUweo{J9M)_pC!$^qe37ZMqpNaTYt1y-5JS1|oY-pIIeE4J06#Bp>0Ri^DnU(y zzWv{Z!IN)lef)6uKF!|(mJth8v5#QT$U-_q4Bov03ST;RD7ZkVQM`kBHPx|!{1p~o zB*LcPy;&Wa8n*4nMd?+J{SAkGA+kbejxWdD63lOg2pEZva&8DiSAZhukHl4Wjx zZs&>unUy^3@Xrigloc$t+=KQWhTt~-I`mak+S|H^vCRJ(p3K=)DVXCAt&oHvrKVoc zZkj?AO!Ctbc?C#}j&zGeaQ6*0Q6@Uu5{hj);3*bc2C9Kl`goGO zXx{gaatit`N;nKjpT?pv79Pv!s8H;JhcN~dw%K_W?{dUDO( zv@~W+u*q_x&w^b(tvU5(C*kTl%OZ$O6SUaj?YQN+^+LoEW8SiYI@VK@I(_`h=3uK+gO*7%`FWilFWVFs+{`of$%NG$~;a^|Vre}PRk z!#ILVf8Jk{@p`>y{mpmN*^on1R6;&h7ZFKJxtRZb;c zkbx+J;_lVTV5dU~v`yCq)bNsfjG6-0WU=Qc#9|z05~Ty9XnC4P8Vtj;lw*HwAzr3` zUsB-PoQOF>W4C(Z^r&N4_!NDn60SJMXJqYkKl4`|eD>%vmR%C^+5I%+qxG`R z3ch-t!>o~Ibi1_GFPYA_>9;eW=#Oh2V(Y$NK7&pAFN%*uiZ3yvceK8|5P<8?jSD(3#T}44`BQ85l(r#8`G1|En$*h5sFzBh_a4#PERcUKCfpl z*Jo$&Ow7;lS5lFGn&xENyjqO9Gc+3x<}n$#p@ovC@@@CDS;!;uayZh*zck<`!Tt_ClgVxTZB_93_dgFV&#NIJ7_oG;G+dzts zM#`SIS(INZMp(=`@>MYfu5&S!R0RG zy~+_7dd|-suHW8d>C;nV#!d}W_N3_CrH5%s`v_0VJJuzM*OD??V2#X9NnDThXf2^{ zAnc#2{tQ(uFOL@eBQFZdcTE(Y7wghsSF-_?c4y3bYJE%j8U zAaq%3;vY{%2JPa?wK{;4273}UxMEqbk@$QaWI?0v>A$={Ah~n`px?>)gSi&Jw9GMK zHt6c?48EPmKrkA136FimN9BL^yaF4gu%74kuB;&cAI3{W<^iD{2;^gfmh!?0 zD9|Dn8_8gS+wEM)efq6HNC=UW1~ai+AIFBSHGA|QICI%CDtr88o`#hRq*CtTV2K}3 zM@>@)=Bf9K!Jj+FW1OuVW$u%uOK@OE)ka3w$L+m@+fLERz?Yc&pXF>zsFT0`o{!+k zkSUWPDwnb6nIPZldI`i>OA`dIy|HwHMCtG|em>=bb|1 z`0c%##WW%Qs(9>U*xu1AP^_4=cKn6$vm0Z_4D|kmmsHwv;AhXn9CoK+rJS6GZ|F;^ zE=#{bOQ*+*R#6Z!Bc0f4zM(CJ{}iSyfGZ-Z#W7f#N8Qh%rg~en3F&Tym1Y6%-jg7$ zy=8c0o%A4cb8sj}h#1HTJyHSXndp}|P(8`uxled3(>Z9lv4&eij~&dx+sN)RW8Y2< 
zg}-_3-4=nL8vGlEMFYvN>E?q!``Cw|9zZn9w@?)weW?_R-f7i#iV@oz2+CoEw+58U z;SuW9)|JTUh$sX*qg2PEcsKi%+NWJz(j-S-7yZak8L1;s)%Y?dK#`KC*{8Js^{p<< z3lZGZn!pRhP(>U_k~m7%OInbl9}4g7t!5F7qN8qI_4VZ_dP2-03Mh;75>H~Tt;bP& zUk8VHtUijOL$Xe;Ej5-;;?Vu(`_7l8)vX(_B#BJU;YR1D6qu>u5EO)ztDV3*f}KFT zou6(Y90n(ni;nJJf4DY&njweF;l4KigP66Q_imYgG4v#B@mJL$v1ED@;nl>bkE=7! zryO4_h25PyZ`4Azcy=3FH)EkL?<_ZRFou^hzAm;)?g9Qu^D$|R+?z83 zk%>IRo)iLkZQHLFnTqNDfdT6HNrLjL#T83X9Ve^j*`A7!na3z|$vQRcM0L?G8wRd+ zaZY9*O)R#s48#@L!BlTNUwx;GoPyAPuavUCR9Ln2-fNJaG4?Edo_T_6>X2ti{rR7ZU;#g{ zQh<4yiCG6+yOGA~k&@YcJ_$d+*zf1)G!`QUP~0x`>Nv;hVQFW@ztuD*8;%bDsRIl5 z0GV$@1dOx63^C7lXT-BFc9GIE;M*Vk>zE}mJ>_*-c83S5CGKJing zrluss)-+f!q=ZON9A+O_+AxEXK*YbAnj)VSuu3strrT5w0uHnBvYB~$N1K`|4ht5+ zZJB-Lk6B8gWPeAjtz`)7JX@SjULfD{?=7!tYKkN^K3MeX!-po$M_KqD+(jI+va&>H zmlHYj^g`Hb{F}d?>FaCps`K!VblZGAxm(b2kdFU)q_3|crT}vTFBOb!X6Y*~_6Pi1 zUtj#KlF*y=parMIGM~t{82GiyMM(eT)UljJz+J8EhwUV&0P(Sp9$ GmHz>irf1Ip literal 0 HcmV?d00001 diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16dad8dcbf18374fbb920fa4ad944e7e9aef8b89 GIT binary patch literal 49606 zcmdpdWmp{B)+P{w1%g9xch}$q4>T@KH*QVSH0~ZD2@b&>0t9c|JrLa8Aq4l}1Omx) za?U;9z2AKEW1jgp`+=@qYp=bm-o2`7Rn@QgUu!4?YRan0D5%KBjmnLJ^6M)qtCA1Q z76nCBl@kRE1qI~+%41YClx1X35cyAu`Va*bxqXcMq58W|_M3og+RknsZk8|)dS_dD zcTa0;h`YP3rvp9I&DqoCH_gh^9p;22(z`&MT^u0vmLBwu&NdJ?C*+_l%o;-P;Oq>C zIYEC9SX(+m+$?$N|6Qh?rHhO2KT;eaPV|;eHuUb69x!)X7{unE5pU<;W9}Zv(Ld7M z=`HC!AeN52^ctQH9x#Z9ofFLZzs&mE8h1yS!yk5}%>UgGZlsUD-Tc?~*H;V;I}Z;R z2|hk2cV0^yXDbM=wX-9ikEIKr053luij=I6i={Oj;z4f(fx?`m8Nao3Fw(ct31t5_ znU7Hh2V2U`##TaCQTcBk~^ndGBfp~cMsybRiA?_-w z%KHCEK}zL?{vLz<8T%7(D9qO5KlIyJ|HJR%>E`f9tBo}u!~udtg@-$mR{*K%ujsls zc)I_?jr7ga9rBy>XO4r7Egw?h-?IOn@Y@e1A7qIB4)_1)7m|Tbke^?ShhLCKQ2ZZd zNMX7#sGY}uX#C%iT! 
zq5^^<0%F3y{rp?0gdxOA&%@H=PZa*L-yI3w-*|z!dpKLe|ET*DQU?gs(t*zbxgwq1 zAv~5)H^`q!vYsA}=8i~9JRPNx__9PUB34UB7s=mbcV|yGYlt+|*%|5p`G<;J#4eUF zsFU>XFd>fx(kEneB@`_^EHz+G$i?_pt zoNe7Ko&2QuenTBe`K$AXhwtBKBKtP~wJ`n?Am`xmFG;_lAmL#7+j<`!8;GqXGN~}i z{ClGM|1J6bw}}0T#qWsydz5gb{v2ENKdAX1;{PYVxLCTmL-c)JAku8QkiRg)Ci9;Y z4RqCIelM>76FC3C^dE3S9sdjLJpb(^!v8oh4Hd0cWXT!GDlIF>%PR*oO5m^*|sq zYKp2zlFNGi$ubCzqg~AD!?E5IMxO@3rT5E`Qndo;7seQdN~t2HhSJAjF~{MPl`O~g zA=ncd3b-q?9Nh~0!r3A(e!^ZtA~JY4{r z;2{q5tb%d*%k=>(0Fml`a9<2_MHtfHoQS+Gdo+BES;MB&LQ6%~2x~Ijk>0e#D0zB^ z!4f&4LKdsf=+#{q2GN7HO?S8Ibq^s9lY~jCm>k^gHB5qo-! zr0YuZz7Xze)%C*H58rV-FH`k^76|r7g>y0p!=F@QXi3|Zk`eVd?1WN`*5ymT!-ho8 z4^Al)j>h$F8_hG;!?1UB{klEV_8)gXH2v&1z(fAlpp@P9rk#>izudvZA-*sCJ@0cR zr(71JAjBTe@-^7<$w1xa!02r6 zTJGa9%C7J8heNAenH|C!s%jK|eU7%7Z@42g$T}A~cNCz{yd`>PEd&Z#W|Mr1m7X)3 zn$=^|j28LPbLSMZ%bSfE#$EiZ_E5H~7qHkG-Kj@d4b3yjO2k=Qz;TI&umB0=ChOK# zqaX4g_iGeA7A?=^YUmrk$}{RLw;irt5qDnTFi-;`$}-42l|2ztX8uombN3hw^y?CI zdUbv5;Tn!a=DMHU)$$c*84Jrwn2Okg>b!$P2jTculUwJjnIY9z-0n?t@3y8ZXP>Yn z39-oAeWi-JT~;I)@`Ot%H2s5U2e>S$yo+KI807)l_OXL?on&m@QeO6d!V^ zz3E4>51F=voJZ+HbM=hft6XG0}7o2pks6$N{uV6$0iXtT=(Ypb(L72>LggOcnd zz}amWAo6~i$My3VU_{7$+&P`26JSiEfzv;-lDJ3!ASZeL@IgY2ve6(s1z^C43!7M_ zP5ykq?M&9R(=hCf0=GAHLTXO;#{?!XOy?X{^|F7zUKp)zh+NHxu&wiSSTr+Cw%n*s ztB(AZalpFQ7d(=CK?nlS<_yb<_U1OV821p&vbU}(rcY2+&v75bNEUB8^%XysPU};t zi`(}P=(7sF<;dJP-&M2fj2x(Gx0xHFSkAVdv8=f za1rf-Crw0bUHi|EW_<54S2bE)rG4hWJej;sOH$c<&#K0oWpUsVv*PV08Bmi8!EWeI z9b|k=#XDp9z&xN^+&%G$_>^Fa<~)FRa-G{dZPCDz@?yxoHa=}#O zuZndbp*jgqmmGdJ_iAe;5EA*#n})ynY~meJ+GbSqO*4(}^19E}#KsqOYw*}AW_uu6 zRTpn|_Dor(+2@O$bf~~oGGZaRElVXiPa0+l^vIXSXzD^}ob2Xv#Ag;YgZ9f$B*shf zQf6%{lRPO$MyN0J0_t50cuA8NN5s7@7M?~)#x20f6CKaTwkEibJw?SYy#r+RCJ>9&HZK9|dQ$$U7R$B1uBNOGVFui`oo_T+R#!KzU)K*gUIM@OW6(q*? 
zc5tMl5iS>P%3mY^RK1q<92 zUbr-WHtpKRa+1Brc4WDrw8rCBaoWHOU+=QCdvTVr(w3b!5y18vXLy!P>ouPe>(b18 zVCWyR-}#GzxLx(4(k%#1d|LPNY8QT-5yCVXAA%@T zqXRpGnQ=@C1Zz4($|vGg0@h7%oerBv3%;-RjPhPeI@-+wRL8|_J+{;R7N%&#r>J(s zZncVw^>@m+hTD9tohfcszTW~`?dRKJt~m$PF;3EuOuSy<)_RGSv6>^hgA6&5Ec1$N zk^|TF7X|D>P@SjoI~QqNBap@l^__tswfd^OXawi&&J?RzjotW&VREihF=ChW`Oc*0 z&?CyT9x5Y$d0VQdnT^vsZ9HeQb5#348DD=xu%Dfwgvo6~yTTXdg3FqMX7~69BL{5C z@2U&&4ldTrGQRI0PGz3o-j~_Ixz(q#;`hcfRz;osR;)1ujIsh3MlZ%^?Yz#nmt7|> zz^nLXZoW@>W_!=S`{Ch&v*&mZVmNqped6V1-a@4_pYouKUx|8|P9De%Gl?gQ-`#RR zOTB5#58{Wdw%3-wov$9>P`uu^K6cM&Q*%y#R1qDyqw06}ZcfD4W2QIlU=YqWvfu6} zF~AGhJzo)T@@s=;?948=QA!LyJC~cvQogzJ?~G^Ef$a}(?7G3B9f<|LVD04Ua(r;L z*~bH5YYSUn8v@@aq0l){=(J*irY-82ErAyrAD`wknbn^ozR8va zz<&PpkaX4Kkn+=uSxAhsflGUox?b=_M=i(lwHlku3RSfR90l&^(Fo%)p>oU1&MBOE7c(~M z!q^Fl8{vbv9_c&?cRyqH3i?Q%O`TP1N5qcm>4mx)TGmllgOsy6Q19`+2%?fD%EsEO zfR#EC60wkaR;+`uVPTsvRX|WN2WLZ zER;tg$xRMn*hV>n!o};ITs=cQ&)Ib>v3z*EslBjvtX^N?3s~$2LY~W1W`7I{mIUTF`amd@x%Q7 z@SVK0?Ew;Y1>Z0a!YDutFYp@&OO0Qt7YidO{^hKbNu7xqRgJDGMEG?BQ9}F-TG}YE z#$@-(0cy%7^GbljIK-Q$3HsY@n*S}p()7on=yi0S}-xQbVL4kZe zOKE?d^?^vjmEF*++d@*lsu1_)tyQ9`L{Fp0mTQTU3g;^YeL^%|>b7TMAOl=K;Avr% zn$czgp`DZNyoM4&vnPLUBxq>n{oWWHrgn-J3jGW|7p%{e{WF{ztAe zWDdUU7S6;R#ZCS)#?{NG@|xroB8RnY0+znM%LiRePxCQ?$3M$|B&W?GAyzyG7Y}{d zsza0|H=W(GS0_z3yJmt<>ZB%~?+{;RW=M6?2(66~;37FB_4&P{cDrYd& z3l=jFz0!n(>rCv9eLaGE*2((VrYI-}to@e4Y&W6Wd=E<~N#wC#5^1}Ia&v~-*4v*? zEwBL&SSmTE(aREe0E13}UW8@mB1!VIB42n|ys?WXM?6w>)FbA`&T(l=o{vP7hz=k=Q#a4o5@7;`4mXaUKr^XuV!6o=k8{oGi@GaF1_s6Le$6*wHFDq;O`G&ET%U{*efPZ3WS zWvJ&?=7XmzE~1b7iBYS2NI7V#b!V#QLnY@)najPnddr|56D3hNIi3bO^6przn27T| zgV;xrv}uCc;(h%%3>;j>ef^*jJe%BD;LU)BhP~y2NV8|Wa$JvAM!)=}MGj{VdiAjf zY3n{yEJEH86P^JuNn8Cv6*mr6vUwo(IODiEm}p0>6-_6~F!&vF;;E>dZk29aG$ZAp z{e&?Yh7V)PiTt~_qXpF{;)Tvk*^b;7q$$#^2hLPJdmW!AV5)Atr=3^l^Gyt|rxQt{ zhYKp$=+C50jjD@>ldJMtN4L3&bifZ|vJOP2Cxp4!NWpN(&$V*4di0Z?^x;yx-C25{{V4vB5YA)QFE+P^Ol&<+AW;rxWQAi!> zV`9-Bd?GuyV;e(XK)^QyY&qOiV8M4xwk@1n8suVeHd&`zZT&+Eth4@A{!Iy5bT?T! 
zL4j3lzETx2LBhUa&gjRuv)Ph3x5ZJsm{Q*o&gU>)$|7zPyEs*5!e>AfG9CzzFU7Kz zLDN$1rTeU;m7dmszpzNUnvl)T-OnEt?+|?lJAAa?MzkGxdSy!o^q81*UMGKeX5Bgg z_gz5aOV2~BR&7X2dEq(P4m{i{llR#<-1bannNBkqVb)?5u1iY#Do@L8R7ygmWOXz1 z+TO=nXL$VdFwpH`tbW2QkuetwV93)`$aRrTA;xR%_y$A&#;fx9rma1|H0h0!WtxDS zUF>YRSH{JwSH00W@Qi*hx#f|)z8LbPI_wSJ#v_Xk(UgX;+yilrkL2NDd;$m# zs6ys?=YH-rQ{Qv!zM3MAT;4ba;p`XErEqNUwCl;r zQ`=0p)v{!#r(jw$RyjHCs<_k}gd7*JVz`cSC&rKp9 z*y7YbT?Y^=DvR)9%fRL`OL`EsHLYD5K(5b~@7{Uh3XGLk(8(cK;tk^>2uHuA-QRt8 zg-bvtMq`xGai<2PIgK9sbQ5DWM4H;r2y9;ZyzYx2rH}6mV<*m5I62JkZJkpf(yn45 z)9XvxHR4HD&@UYFvs#QZmWYr>ERKQXeJ-hUs0)PJcxjU$g1VFQL!$XkKAIU%R*f}! zT$MNuo8$|k)$Hise|rBqj9o3?*7hb|_hTaSsa2KkjP(o$IFP+{<>NxyuGQ9aX386> zdO%1n7!^?+&{xh*)TJhwZh$SXbK~*)?XGT3Uwwfj*uygp%%5rP@wIHMo=+1@<3rY) z?=IoM;bu|#tn(p4(cbP2DZ9Jzjzm^Xj}^8wUubprOjt=zO}>vuYm1O6HR$x=vmqmT z{YE$*w4^V7d!-3eqRxck`6W46w7QSB2FRn$lPECdp6!)iLJl^rAkcuF9qg}5-r!9U z_`I%;q2jF(*^J{ByfnGx<~8E|D1L8YJc2lV{;|GjTo z#{E(1nTQ!NO&+GgzA0-ta4hqB1MOz~v}6d*R}7Ear%RTfz^7%`j@vqndwzafyBO4z zl_B`Ug|%k0h}%g>wRJfyuAo4qAv1}d!r=6+dv@#M&;e7@M;3wNBs6HlY@HQ56BbcDlcqA-d5W`ui44`9!cgXREB#XI;0~KZ}JNTS^4HBKtV%8eeeJs z4Gmcy|6Tt^zDV$Z5S@UCm_gtnJ)@wI90`-KJSj8(Gh{Iw3t1vZeTZiG>eZLp;Vq%x ztyzqy7xNbi>7Skd`bpRl*VOy-lk{!N;++)eX?M3PI)T%~9X_jUX7o%Tt1IbFYE0V6Eos;^`>+x~g#sSs zrUEUioE4tsei`K^a4rhM;o*IuwQ+q|x>wKkgrB&mq|{CLType8Q;3*2OdtUliW#3tblKf%s(xSB(5ETVm&s~^_aI%06x@8oK znKuAE8I4m1-CzCqA@)N^T(dalJP}uE;pY%~O;>28BQ%=$&7)BI;6Rnhj;d4F)YLb_ zM_L<8dfxW64&!#GRLX}hHYn-TVLkgbD%K)C!O*j~Z+-PCwj;i$dY3`(AeKd`usc!`++lTqTl;At1!p30} zLLVk-H&Imu2#P&e{4^wZ@O)N|*I&ePI`sVZsP!vHkdDRT#@nBYKNSl`wdF8wJ4&u1 z7qYw2GYJHWA+|?tbm3+Yrp7SL8oNI^wwGCS2TJ1e^N)N*uA|MqKa|0D&Pix$Vz9WJ*)9;@ zj0*?~8J>(>DAJi6CChwjz$IT$uRzoFZbrdLOW$HaDpk+Lv(AMZJgoU4G6Z4Q!6o$9t-S?6^z2=tsqomo_4F3$F7MF z9Uv7`Bvd?R_$HC06+1_oY>9g8OP{GpoB{!ye9{zp^Oftk9|+VQ`yx!7Ry)i4xX4?M zGorra>+=TPtJwZe-D-#SBMvl*bLQhyd+>duDs`s!(=!uoLC>gTG#lb#6B{+v8140G z9TLkctICI8`ol{fdE}k)e=izlqYNMRM)fP^tO6Fa;8w)$W8R{F4>LyeFUYbqJM;`1 zPXP_hQb(;iTWIl&qT7fTj0&xLclCKm 
zzmPSZdDzzyeIW%{FEF0BkLLOJ!h{p_%SU1}&qxhTFXI@83^@vap)ibcA>11i5{#_b zdNN6i>nP#7T?M9rLU+YhY2{#R6uVj%FV4$(EE#FlcdBX5T>&OSVqe9rETj;XsOe+| zgFz&f=XDH5JKf#uGvUeE>iQiu>}U9c-BlmRu?wXht$gAwn0>QtZ~kV$;tL7jwVEfk z4X)s4&i*g`z^>W5JPpuJK4w&=Md@9_h+l$Q+KH7ts;*yJlorICV-0Zd5K%UdgXppM zlQ!7T*-W?-smYnT#NWV5M1}Kp?*Vxj;Vm!^PmA67=w|p-b6-GvrV(fuk1M3Eae5e> zGV*jyhG^?2WcnigW$+%Ca(0^$vEnF0N0XDHjdCq%p3b!4k5tul7At-e{Z#+r9bzpv z!i(whXH{~*9oVGJgeLDSN1C@qKqFg}*Qu!a0^gIlT(??*#J$^bJ0o+g=lU_9p9d#Z zZp9D!fQxjvENY6l#b{(Vm;_5^i4t+y&K|aZFh?t-z72oohM@F|7bpc3BPMbm5F1{o z+(<;u4C0sbdp;V|^j8q>2-Yh);&`#G-4B5Z5vrSgdJHI;ynyZzP^hTTNOL4l!zF|V zEJG4}Gs2UGr-~6v{z0~CU2gu>9~>GDu`jjT`P$`TIRzB+rvjR(@P8Yuepf~ zOST;#O8&N~VtS=N*b+{vClcx!OJ}VcFxtC40#JQlg^TR|}28)5o4d+jwyC;?k8aM=Vphus}j1 z(MDZumh{Pqi1|7ntptldHT9UDSfPTsE~!YhTA%X@tk(}IDPwP0L-y-06bE$;ye7jU z5Y9ToViN1UBje!FWa?{d*hoR5@9KSgVG^n?aesp?W(MiA4%#^SuLC}@Jg3u~uu4x7 zccW5Q3K9*^+-7!74VAZOK|4iXriy6Q>;^Q*Ie?jPif+vGg)@W1?*v0*SA`D*8JYhPs1iukA~Ww>%BC>tpRt{6-XOW;49z%%_{{cz4lQ zK77S;4UL3*1Oe>NXdgB7sdaE2Q!^OAY{QIrQ;H|q$&G?gLX0VD@<>jGM{hIjB?O1S0V~^mYZM|i8iD!-r{59h?wP;JQ8^;!-7wS*}#6~?nILsUUWqG zhF}Pv?S77U)6cY(3TDiLUecKKSTynz<0!*hyHcUhVd$|>T7@3UL`8K|;S5Wj;S2f5 zG0jp3b&o8HyoSo?M!r?MlhWd;(*P&`D7+U7ws}#7I;OTmWitzfA9XZ_?pyuezIqn^ z38YuxAPzrMU?qz@>;C91WWp*D>+`h^BnleU@|k_}Dk3+94t0s&9uBN9&;vZ<7^`{k@8)e(fMyN{+0MEayeR*EH=r@7n4izUp0Q{!HD*R#;TU5kF7aQ3C+h zMr*u`vtF;H5o#=?-CByc=(M$s`kJy4dBa@x`JID`)wm%JHrP@CgFbvJE=Hr=Bg-wxllC6cmFrDX9a%s%Dg6NbfYtVCkIj<>{9cA4GiP6zMH zMzGSj1WOXrayt2)qKr}hUXpFUBQ$OL@Ue1V^S(d0 zprW{lKw$;FcT`5gxktpvHSmS*7h=#RexDumC7f!*PjNyTuR_?&svWk-;s zX^25ll;^gc<36l^-K%ClqKCVQQ)?Sm&^mb>)J%Im7-(Dc3uTcN-<}V00*$f$ymLQH z++Vt7aOKG>bYV+R69VzVFf{454ug6`u^DM4d~b&dJv7`hy{{u`KfkYk!0%&DS9Rzk zbJuRBt{{GK0jNre!c0j#^jBEa_azOs-zj+!wK=9y${eWN?e7|^+Pmnzqv2XV6KUN3 z(MyKwc1%ka*Vf9`;H<4BeXh`7VI3x^zrmQ@yAvP38c*fS?-g8|yJ6spD_IMgy3ftJ z8!+olw=gJn}`i|Dlh;8|(?6f`aq~XmWv@l5N zlzMGX#raDSE8QjT2dm=C;LcN>Z_t#?Zs}4w^Z7$>JAZ>8zfi1Y=dTj4i6e!i$8GT< zZIfx$U$eh-$ZII2J9y{k`wOL5NT_jHg~EuAL&2pnNd2hv&bnKM_fF_q7?ZMs>Pfq- 
z-(u5L!w)`7Zn5G3wNaYZtgecW&V{!>f1wy;9ByMP<@DsIqLqLSOY2K+*2hy*FDOphP2?))J9jQ!(%^#4dTC?9i{{%M#b>MJjVFp) zWptbP-bIBp>ZRvrOW-97*beEKOzy)-L`6cpi!gXKO8@lL1#Y11zh-V+A z4n9cP)-Kq6{QS$OITJlCS?k)1suLOVfj15Of3Fz%i76k`QX)cr-Um}{a>D?nhXUN;K>-~gK zzhfm!!h}y=RPdvHdy6c#{25S3nW5{W%Ci7D zY4sPXiFC)Nb zGE(iJ#PFqt^nO)g9Oyw{8>i|3SNK|1?1g|gf%@T7SJG3fYqHdm+eyo?tbkdg<-LV~ zPR#gq?0uSx!V&2euNi@#DtlqU{b3!)i?8WVe-ICI&5tcl;b-(&;M9M*;*olMC%4&p zgXp{bB)QnzT&po#IQ=E`LA>8<7sod0!>6>JE#Gkz48E8DupjFKEQ1%&uFzbf#TGc{ z<*sBzXCn88^?Esuh8}3k)+&1vP~c*7jG?UXJz=J(9LNX8+@a;`5O3p6KQIy2tIbsQ>z~hB~4S0 zszqeg+CRY)V$k}&=1}r>QbU8QG_rx5o2^Oc!yf$35MPLiBcO7&sJ_bdVkp*A!f1C&|3xh2d@`*`wGf|q zXfR~CH%$k*AX^WJ$##DzoHatUQj3n1S5sOmuCdEBDc_tHWEM2Fb5Z(oZ8)gC#DDSF z>VZ%VH3s*1REiaDX9~9D&vRMN1DJN8d@H!3=A}RP6T$M!W(|?d@;ebe!9!gGSI{!G zfW>8Q?cN+e5yK!9YL>oZiP61oLs9Sxh4zcY=mF>}=q!Y(GOeLU(og$mSxlvtzKKy= zzIDQ0ScwUzygw9K_(2UGSWZn*XM7^n5xYvr8I8)FkF`Zt1cBS4ojIXAX6W}h}l_@R8j?fW@VsubHIozdje@!0@$568O1@hMqjj=V(7D6hD0u zU*Zp}bVc8ic<08yP7xF0wi3xD?rJnA+{Ob^wqd>8nUSOHD}TYbH(T2U8X>k#_czNE z;BY5;Sp&Aq)6;C_{)H0vJW$={AXP_N`CZh00vEo))YuK4sO%P3M)&py8{xI`Rt<;3 zK%yOOnY4ysG@W83r;%gAe+M@zHwL`_vJzKDcnOC@u)sIT5XaA9g78JxX|JN;8e@0oCoy z3z`ty7VR-1!OFFax1~aFuu%vDz~(yo@0AxHcANLPd-8g%y-wBc$~q@p%8XZTy#(yc z2MiD0!R@!qT)Bco&F=?PJ7W^1DC<87p-}UObK4ac3zRnEr%*cx=%wxzWxla~^2rQC zQpZpXL$rIZnDgbQm+DR!^rxvQY^9OA*Bf7`f;xu?N>Ni&L;bNKrq3lEvhp>v>Mu1v zo?`Fw6F8PMmj6uaSbX8;8=V3m;ww-7CJBSm~!7JrP1_sd!bx?t2ISS}`%&kmXBi#`=#mYKSSMkQ4AVC`?TjNIXEGwTghMR?-RW1f93y^02p&g$dO(je4` zCSl|zIFyoI#9A_5pi`p`_lZ)1=p~V_#N5Ga*7uoGKm72=C6?)|%x>x95t|PzwDKNi z5^7d*_~d0U3?4&$LfJK~bs3!aaH;Mn;*!iNv)0e)eXw<9WRSMj^bC0L)3&GjzO9Ma z!nbaj4K6Wg&Bu+_@xWgwr-WXnF|J=bTtfnDgoLWHb)(37wxdFEj$S2pYrhX~vM(=B zOA%^`2bYg9jmTtcQD{{nt#n5x8tTko({xEmOJE^iwY-iI%~H=c#a#A(;g6u&VOe@j zd-6o8VV33{t5^3|G%G>1%O~Lr5d+!z^RS6R&AQ0?ucZf+oQkx0hEx_u80D8?`#^5m zlJ|a@7%%!^(DRE*zk)P;oheDr+HKE^bDi>HSl;a(S=fmBk)*uB z%*JA`+si}xW#3BMMt?b>2A)J)R^|C6!k6KujA<@H5_eYn_Gm|1+J>tzWbVpsFO$5U zzsl08Q+`g9RlMxabGLdo+?gW%Qh~@$>k8j7F7pimlBn};A#)q*)0EAQl0qDs0i_5F 
z?Xt3R$$+T*uoIH5gMHRJbSCa-&=hi6h;#j-bJLbO|2_5P58#t_Vc7(KGqz0vIn*oT z@2v{>wKmTzuA$44B!fr$fC!Zss#@A3`kSaz0kIK_o|d2n{zwPujy&StkaNDewY&Fn z7{5?jWC^p2Y_m1)!bMF+&yfXqcH}$i$O1e%^4)ZlzY6fk3OpeJ5xu|z2L5Nnayr_A zjPe$)e~Rx&9$D0%)fa}>U-t)Y7fle;;g5xn&t?c$z%E59U>>K7@N3oDLQy=L8MCV&ApMzNB&;v9!J5VsA|jk z0|y(QM5;&kGlk*2eLag34%22%MosYoP{+eGEGkWd&oLB|^EDJeW;R&Og>7c&76#4x z87`av$e434?P8Q-l#VUMRp>UTs#Qvo=x!;CXs*Idc6Cw3N*|=S0ELjiWN*R=b`{uf z$25#TSASVhA`RUhbxCD>S2g9ybU3423%JlA3Csk83D@};lE<2mY$Gbpj8GVsu!TMewryEAqf z)~&*Ay}Ef9wkpwjRING0?R1K{Wb=J6cH#XX^Qm=KzkrmQf`1Z3wsUpYwO+$H;D*>P zxYkp4?yYsfl}JQh{2OKcQ@cR(PC!#p+jz~-`e$dZuBD6I7Av>OX!qW^9jEL22(HY~Y%%}j`ggktbyC3VgWLrl(<>>^RfWCO zMzQ2+g~17T-zzUZ{7Bigi#hjV+5D;>6*l}?Y`t_RPRU`oAhWy$>b#x;i3A|B!#UawIuT2}fcXxgA!_k_(@cWChX zOKpoE-!BxB>U)*FioosuYjVJab{KEsi5icUv}n(ia!Q(&O!m$Rc|%UZsRFl8L}5A$ z$;zXGyu*qIJ62bh-8Y11c&js`-A}7-OFx~oiPX!G*Gt*%&K_!UPbU|(;?EdciIM1l za>LobFUe5QJ)%>}lmrMRew85CySUNe%?m0ysjy`>jbar`Q^FVv{T_RoYoCgLCMuK5 zGVL%uah7ul;8-8xo;hy2T(4s7K|KQLF`l#t>B=Yue9TO`k}{)S>a)~g{L*5)4pUFC zn_x^`F3^(^l^11a1~UIRH0G{1e|u=%GcneFZBMS9fT%f~jAH-MbtG^}M5piOV9d1R zZ#HydXBAQHVy@B~V6EX~W)RvnaGuVsC!@Dbr<1nePu8<9OU-m$muQnHt6Uj2nCM4X z+-+x8dMZyo${vt8x)b0j$|r9YqA?v^SgLT+>AkXi{8WWWf9N7RcjdNWy;|sgBr2)* z2jMT2_+#eSX>(s>{uDKHFe~{<6fYkKeV8*ql>K$8PdcQ|H#_qM1GIQ_{+sdY($ z#{^m;9ZS3WEyy!*kLSkvAO;av)hjEZhw6j!(zLvP63n!m z`;{xS-EMiY-ySLZviJZ+yPv(UUh$<~BK8F*C%=s`Q@rqT0D+>-xUkxatR`@ssnsTX zFM|Z=e4v|`5!^;WvNmtu0fxk+7`ASSY~IK2Jo-Gbg@^qW5_+Hb=AtURmj`-)^$`qN z2Pu()th#JMw&-%0A_$eT^GN2pzDIAbWE4rBx4!F8lqq z4+wM`Wt)lmw19zS5fS*QVW|b2_I03MsQwL)rn#ZeujC%$wk0tOBZ> z{z93o;S-Nk+OVz8<(!E-Md&TWUdS|TrWGj60lH@+sP#C*>%|tk)MM6vq2$Gci;>pm zYA3%AUo?;I@A;AeoEo-Fm1yNIZZRlS z60ockH2Gf>dp1D;Q9L@Zvfjq zfiWAtWpLvaD^mwkc&Izd3h-ETacuKoMy%je?TLl=wYUDJgi*eZ-d4pe+4Ia~0(3 z%iaz1^^Jt%;h$!2_n^0A``v(QQzjruZsHlE>Jimur#<1q1{t8dXtj&H82HV%vf2;^ z`&D+|0#_(1X>N~YkZ^T>P2b7o#Ph&6RQllMeyS>SZ$jPnRR1If7s|!T#-zI;iGaB-4Ub&Fnl=XdC`@Y`4Y!z62 zdESs1gS@mAB|t%afc60M0UE~NFP9)cr6eR^_Z&*qwvN!(fLfTm$iN@w-NCRCF`9@k<z6Ln 
z%YDq+F6p1XJl!SRLxh3OKP76Ehm z^%E}@@&;R2z-sjtz*2*f*|f5y#@#9jRW7C6MqP2nzT3X55H!JLsZZLSXmJRAM;ZO_ zeMFAQIEh?Q!M+TTEgM3oYOL`5)}lfcMjf9?Lq0Vhsk!a@@O!#~s`Xb(pRQjhH%gR7 zuUebR`kC?)uWSMs%|?OE^i6|<3*6EV4@oqR=61a0Sh;I7Wmb8 zPFJ5MFQj?auuT+s7w@5YlCjQ{8~9-F@eCGBu0vB*@r=FYgYE|F>sN)7c}H8uTt!+! zxx)LkKj)U;5Qloo>*bZ7QD`!mebx>@HIi;S7|P4EWIC+Lin_ z|BI@(0E#1MyM=Lw#oZQ`#TF;HF1EOaK!Uqla1F35?(VK31Qr4bkf6afI0+7m1xN@1 z0tC;``@P?PZ{4o0nyHztsj04+uG7zR&UwTKf4?wcfyTsfG0ZfZ)wtji#fwB5fjuP% zKhDulqgwnY^~<)P+p76- zt+LmLVhaIBpBleTJ>&uX7x9w+T4kQaTA(iLW~b(f-e<(m=k?UgR2P)#f5cV!@Y3kX zZcoc=EsAmBpqAHv$gzI@kn%$YGP3iGnDFxLyj+LzUx%oi#lbvlY1dU22Q4 z*En{q6rS>0b_XLZ$hb&t79UdUe;7nhb0^*@)DM;{8i-1rRjIIhY{tZQO?*B#R{I2iKc zGtswJ=sH8%WXk8YDxXdnAp6n!>{yx8V7!}#O-G$j$4?7f=tq(Z4UZwtvS(8DB0E7% z-1L_!PU`@Wge~hew*Fd=(SSFB2D?1%jy+k^cc9*vu*R2(;?e$WjP91UE+%%@Zw|hv zF8Z}_(;rn9a?#0EU8S_nLBG{n4OI=VLv4QS`LyCweshUcuf^+9&c^?YX(TS5B>R3s z+d@N@H>@swl?^LZn*K50w+C-#GEDK4H%SfUm1$)BHkPD~JkFgtf2OQi$I$6|s|GT+ zDqdB;t5VMDN@7qX*A9c|guiAAFqLNB8RPve8awiqwsMsXTi_idZyokT_N?R}Ffh6u5SX-$qv zSOysKR%b&=7<1^w8TyhNm}^{IsfXnn8wpWP(=UxP>gKOWsi~Q%(~@842zV3>B7%4) zd5GSfmxUANhJJa$Uy8~~VtL(5a?)CE>zn##0tPb!&ey%-ndSJYL~Hj6+=&1;5_Q2AhRC#2K2j;%|lc z6Ztcb`!h*oJj;)FO+SfMg_)5tQFiHUo18lL6`VOnc1XjJ{XCgmH{xvvlW(psY^IiS+)BwiY< zavN^oG<+O1VNvqX(Pnk0gKtp0^x2-5#`%`y&utuTcp0Wzwim5fV}oJ;=Z^CyULB+1 z@1^^Ie}>73R+tv)bm~7I7>DmK_wX}qrZ+_!$eLnnp41c89X6UOe>A12md5QqWPmYB zEOgJZ$yu=YH=?*r7o1%W{Gw-k@Bw68WY@~c#A=N`&2+VtsvER*KaRQuZDRMZZ{HcM z&$af3%~6%lYPl-=W9FeVrG@l=1QZ0&`i<39zdF{Lhh`i6ivNwF8HYN&G@2Yqmb|T^R zTiYngbK$hY=_vLeh713xxZI7v`u7X!6fJt1B7Nlg@|JbN3HhKZPg~r9RosDD@#IUZ z903^6t6L)=n7BnJO1en4p7kJ?N3sdG_31F9)S{ED&oPMKq6}uFh2uWWZamQ~z#Cj4 z-u=~%cYd&B@ zWxV5=f$x&cJ?(kxUm}c`J#jv#NA(m}ruK(TsV24teq@Up6C7parq!QZ@T~|e8(&qf zyX*3Xm+Mi(sqX~O=Q7t{K8N5As_ZJNCylz-36PX6o8lH*Hj;ZOnVlMdd=bfZPcqjnXcJhYO&?Y1fX z?%8DGaz)jAL#4b6qjbeADt=FA@KqbnJI_#RpQEgU8zO&Be_ zXK(&KW4hP4nR13mDw`=*;ddq(zdy~M(iOAKT@ZNobT%~awW|L5xYcgOC~*c4|aat_!*Mw_i5<%+46g#_hAB-C!c=D(Rl}W_C!A0@7j~{=*;?41N8LH>1s-o 
z`??lHRvp*kE;B~!DVbARbwOaC%(C5U5IqUL3=+)t+oO4;>cW9w|KRw%4)(bXcj zu8x0kB0jdZ4=$1ZWxSL25930yE=M+wx8;Wv{Bru&XiyS$xHB{U6{^Gkx1_SA>woeB z{~_l8Py0`7>A%c;3ZRHG`Qx@JIs0uEw@a9cKsPi`3lYbI0XL+9MzjBq zohGQvmcx}zJFIO(9!Iu8L1r_H_{#d&J$RZrZbBQ*GaSV&&^KnHm1`%Td}@$R`9BPY<`lxv zEb^fou3@UBM%Zga{H`Wag9B`SF!EC;=p=5af_#fd09#NbJ;mPVvwgvmDWpM_dM{`E z0}U88a1ffb5x*T;NI2o-3|E@Wv{i{Q#1s+L1Z9PgXGJASJq}3<6o$d`_+SCR!y)T_ zvBEP>X&|#}>*b)1>2~*MH)m&_S5(T^^G%S$VT5CWyJl_&v zq&GM+Ab(gIM{AC%yn@r$zRJYp*-~;PBuP#A5_lrwePk>d^npMl8)P^24upO5f|gQ5 z+-CPq4F8nV+;nOr6U&}VL2Wqa&rD6?TYEZVN3H;dZ<+*Jn90!`IzB0O)d5*mIhfxP zU~^1|MG;hB&B5O|=~yPxQA;+`kUOBz?>Orkg{uTr%V>aB85RoSq8(r|Ux%Qb6)BgY2NDAIq-z%8n|W?96F|0I5swTl zVM}aAOSNZz4Q~^^K&_Z6uWKTRA7tlVE6h4?Mx`#N8$sqqRc^nmK&7*Ycy}N<6Bz`+ z8`77en0OJ$*%NYn^0{1xTBXV3j<_zKFj&&>TdAh=8}T!^XuLYnUPngK%Uh1mI92X} zQq8x_sQZk67)$AF1ur3z5qvFJ#Qj6UT*8xM=0)^HTiRlv-aNwlY^ak`IjH}ZE1)wd z@U!MEZP#nnntbDOM+*_HMUJ^~onY@w?9SzYv3i69xEh>P@w_3PJI=w#3#n8kXoJX& z$UY4v&+9OSygm^UTsk$${yUpQ58tNTlW!KawH^DupF~UHWTdFPOVy@SQ?Ee(#u2ql z8%wT8qKECQu|;^`k|;s``i7v}=+oc?Sd=U9ojgDl|Ap#L9GiGG&A^7(#xrANWWrU( zL3om`o^2U}AJprKh`6UbML>Bf!4Gy5R<(C8M911vPkN=65*NzkTQd;-?ngCawi77-&x+63^Y*nr-Zz0 z9NDZCgx;OcRpIHCiCifPh>-~wGuF04+?SxGEJg-W;2AI&htDh`M}rc2?eIlJ6v_NK zaN9wT`_e8Ct}j_fNI-QWBl_HoLF^9rItCe5V3>JN;8usHUXdLq!#5g_Wui$+Q97hk za3$Cm|b10B8(DXy)I_GKSB4l1ZReg}&Wu;dX$U6~7sg{nNO^^w{Z{xQga4}b| zbO@#;pV~dKOof!Ghn~PYP)A~&>di^4d}^9_(Z_jgwmg)eTRE|UGt<*s$nWtFriqy3Wl4L=1I}>$dgwP8-)Y~nQ9IY43dNHLR z3haxTW=w6GL}{jaqKUO$)Tm@Mdhr}OM~6iN&OLkvdl8|{n%{3R?S*F6ZV{~FbQ-CF@tA^nQm5-v1c~5c?GeWK(7^=6jOX)>8@L-mc3ef za^BmDci&x}oOQ)e!}szb=roSUqsasPPPqBd&nQoY#Ue-4 zC4SUsF;*HfSaVKvhIM2Cu+#Df4z=%cz@bY9ax3N0B;KMK zU%+s#MVq@FiGLW;ZsSDdTmb%a=Dcm(Fmu9Mp)9*RxnvbeV&yxgN^u2(c8}OPrWDum z;@2D`$~}aHcG;S*ekCQhx!M?GXLhUxzSnt@AP{pD%_y$y!WNp3aKte$B%FUXPa8^K zwbvpp?YY7WtuqjfHBy>RmB+0L^3+X(E-;#ZN{xP6uMS!>X7^RRYvam?BL006`c!W-%!Adk+kp}^$819L>qfOSM(Z5 z%gRCMU}7GV_!I}F6jZcxkj~y*(9NrZ`N;QX=$5@EFM-cA9LKX1Uh3FYK+9m4#Kou^ 
z7W{~IGR)AduR&JNFnPe`sx(IHp!nM750@Ia_M6cg6=gWIO{hQ{@L0{8W`)sLRWQA^ zz*SloqIB$zO`2;lMNF+GF9zxy7(qG#uweICjks1I(LCsecv{i0sXiJgSV(zI)4KVI z3eXdp*#y5Upm!waHsW~~a4s!6PmvcUD(QxZKb+%UsCK{V4FJ}$QX{k3AvQ37?6pQK$`@Wh-$ zI&Q%aN3~C1)XZt~iJW<@zSO4OI*~j%p?@+z4`gm*_00dg|90O=cZo6=L|)pE)n_3_ zK9cf;JEtNElk-;SSiR0TEuV;Ug=5=WAAdLcm1@G0q4<%CTArZfKMW)}|MaDuD6c4iVr()fLjmWQEz(G<$($m zO{QZXSM4Jb#6CEo3O%fj+4NJ5JvbzfG@FY!))RO8Dqmyi`*3;ALsd=ogoLgH9!8tW z#=m)|>NKelQF6Z8c#wi*WxasIB@xE=bZJvNY+0rnNidp+-JInIhNFyK$!Ij8hY?Dy z>gPB~G%)GFP1DX2^m-2RJwy#N$dc#4f1J>gh{yH0YCXx@8}Xm{tL!IK)F9w%al8f{ zPp>N{5}oL5@v?$O330d2wZ5bHm>iST{mCI8L@Jww*Ee(%mLhQC1j)2;T8GgLZ^ArY z7CZ&J_8-rmS+M{{2hxIr4_qmQG78A;t0D?MgT5K8@ryEYRy7xo6REa#j#heV`iuK< zWsD$We$mz~-=YgiN?x9r1y}j{ylty(pekMFqn%m9sCoD@-74qCnMiLD?!4H-{brs| z@%2FQyLmZcp7FDh`YmSkS>t0aTHyUSJ(G3~p=&Pifr(9>CYrADLcb9YlYE$TK0D1k;+8CtM~YuPf?idE(HHqfeUP22Z!q6^8_q>&}f` z4EsyzBcW~h2UhQW$f|D0c`Dto%!XYqT~r36pS(L^D5lpsv+E>;HaNFQLPxNS^m}DQ zba6y)U=zFq<{cVX>jqcq8gWvkZQ5MPKg0t0W|8Ci%fK=4qzwv5kf;!&M8{%ZIYI6RJ1@;AJ9~S9%PlzJ|M=DdJ z?Mxr&_)?bb_gG^JEGTEAitQwxaB{;1TBJ)k&$#k~G^%X$0R6G%XCqg|RAgvy!BIlg zDWs3b|23bLR*ZW*B6g66>aA(?cf1_h2JbJ5Jq?N|vkNwLZcmhj6QAZ;s8 z=TylOPd>jr{x>5X?q@p*~guU1T3cl2Eb!a4|u3_>UfG%GZ;fMQ8ggz z67uLvo+9eGjP-r^)8~fhaq^mbA4)uhrCr)6TcQN+Wh^M$T{pG1K4>A;!vn0(BMcv& zrh0lHXXxF}NVi2#@xKSi~e84 z*j0N9DUCa32NUCxIij>z1amtnIn&Csx&=v#M3{vurC25CZnhIT?+f7NA@kF_-@n)AD>) zQ&?17P3%ip!q~R`wu_I1!HULEvP%Egss!UHfhbQDCH$y4vjRB;a+=; z7Mr0msD6?dOBKRxD~(K`O_-VK7MCeN)D72$r((&u4X{@c!-v2=csSXw=CA1K#zZym zXj82@DT7Fw!uXjXbd*~lu1PM4yH0OLcxqV;>DG15)+(n#IZgBjs&YDobk{*`0H>@c zKyCm@8XS+C%|zA^?0uDqXx6;6uqMdX0j6kPE!FVa!@B9nlJ8U1K?b^XJWkFz5M_0T zLgoHfE6ihWm4a2wy3Qjg2;=Lmq?o#9u2xdtP&Rjr?Fo){>PuV3Q-MrBW`bc?o2=Rk z2{+IG-eHe(RfUU7I01#f7*TgF2oh(#UUL#9btOk|*S*>?5Lc|F)ud`eXNsLAc#XV` zyZ7O&@)}b7yVNQ=gmp-ZBQa-3NG^tU{epuN%O#MLDK5h?6X?h*(+RIQh$l<=hk?T$ z7GB<$F!P(E`Zo`G%?5$oPp2-%Jv8i(YaN#oJyd`JxwBbt&HvcMmORt4nAIS;`CXr-4gIU0Fo?fz#I%|}pB(TP< zXLpxil-^A8t4%!|$mW4g8nsBuuT|m0cK123UqEzLU@a|nf_ogYdaEEl4OikDNiR;m 
z&zHN5ncZgvMtiB|cSO?_D#Q7-!h3JV<>{l!sB^yuR~ZU$(2PeAknC{~=o8%!5^smQ zAaCT7$i!!_`Dn#5(a`udv4d3b_usJL7UCsJJCiJXKN%f)sEaB{VTd$mk$NH^LzQG*RQ!cXw2&_(%sZOgqro5KbIl?=sY*ok@m!MAix8kJ zpQuFmz=YGdQKtE++v7lmYKPqiOsN4|5nSz%FUgNth7GKITjkci9>7siGMbPXsVSgAAZW{2!5a_okgk64nCSL7Q zNEWyD@cE;MWikqfMaNIUHq(Vh87ISFHExFIL*A2-kvCPk{L6Cl9{wrZj;W4RM`P~} z>@h=V7owE5+%ZR^t?DA=>^MDgidYMEJ_t{ayV#rc9keHSdw@XmC$kiORI${uK|I#i zkBm~AQ3SN}ITigU?wP{ok~*+8hy72c1Jy2=bu5Q~EYps1sz5?a{6zZS8ARAtn{WJy zMvJKZFPBnUhAeZMqZrUu@_H`qfQZDf)kZwRm`A5hxb^pdGX2u`@o4IDbxfb0Db^H} zJ9^N3#(I~|^{Z=Tj%LPDzvzLh&_)J4N4j%Gt$r`ArpU!PjfRxpyG=g5-Sc78h3tQ-h~@wk@Q)HQ5DnG8$fec10IFRs>hVafRQ-{$(XTp;f%5+_ z(4j-aoA>y0p$xuf&nXX7f(dYpo1%V0%#YdWT{9qHe zwd>5T?8@gSz?wU5J4TCY`VmBbBbED+&RaYHpH-spC*9OwE`PX+f{I&+6(WsL$IhZ1 zfL+e_NebLF_mGC+n!HLIeeiDe57WCBfqWSRX4La0o}2p*318z4Y1*)Jxh^?r%efdN zw@7j8^7D8T)xM}D{(6YHl(RDrXTheTUnH!oj|VZ4)SuDD5<$Ch#N0^|6R}!+LI{V$ zy0!+}hWaxz)J-C7Eu)5Y7?XA24HhHVc5NIUKH^}FlIHdB1haCTYy;_7M*GVl|IlIX z19u$bHcP0HFJ0DnKWP|8Rw_JHNK`s%H3E_x)Q=S=6#zMc{OP)!| z5tO;LYNzb;L>J=u027Xn`TfI3{2#`UgE_sQjF33mUFTu^Km|)Vqo%(gB0CsZrBFwbehSe9uYRtFc-8%`*&5C zI+}0Y7t*XL6DzeeUXIH8O*E|2 zRD^of8#=9W_npKrq8Arn75$oPo~GJS7o3LrNdM zeUm$S>Mo*H-ifyPEPv+_ERG{p-!UzX%K49;iHV?ZIP+h!QF^mk1ng;?tdAVa7y?c> zyWBh^o`092208s1dZLUK>Nrf42+~XWAdH0PBUw=K_7(H5eM%@qhjxb^NC~;dTDj6Z z=}DuM8>zu74d+}-qQ5UC{i3+w&i(PK`D^wb8@F~pJGf|};yXc%)SAu5j+NAuZG?-Y zq{@!|!&jOO3}X@TDiBS=&ne}Nnd<*A5Qr+A#Oyn|?@u0<240`UIhNy(0`zQD6sQn< zf~n)eFFqon1B|m+KA9|(BRXwDsw7K#I6~Mer3Md(Y{VsWWEzyi6jn*yF5h3|sxnjf z$RL5e=VnPv^L4hb;6>ae6oVN&iMHbQ6x$y-NWNRe3A;8eahmrMFj(8xe*CKQJX z2zy1HAgpKX!H5pag8LG(XC9!-xZy=$>j9560u6F5QndYW{*{G8jv-I0GKmEr50F|* zZtd*=)`Z?r@FW$sC+&zl)qy1zoZW9nya&k|b1ln-I{PeU+%nIZ&)u8;hkKkhZzbo6 zzV6$xE_WU+=Ef?bP_%#4 z&hHEtdJ%xkat|JVv`S~$s49YU@lSB@5Khqmt3ttMUsc~&EFtl3*v-rKZ$hyh!FUI4 zQg)*y1>_?`d2z87y8XR4V$6c9TpeqQQKz>I(c-94&Ua(*Uu?Vn+&nuoK{`h8Tk6A9 zd#KHrRfs)0%S$zG%AJl)rM4Y^sWg;rx9>c|$K#Qe-#0dGw}=y-q@V5FeF$Mv2+xgUs3N_b6<)CweXXKg1%mwH*bUf5G#j8i(cU*N+1}eU+^x2 
zt;FOxg)S#m7r{b{xUZjvjj2`^qo7G2m&axqI*8WKOat)3H5>}nMHH=6XVibDOCVZu z!POM)jC`(_=N2hfL}MP)ecyC-=AxCE#`bji-Cpib1bC`ddSI039lcXZq`N@fHW**x z^j(ZGNv!+6r=HElQR$@Mvd`8*sFz|Yp0XF@G-?b`QX^ac!G=$pidl2HyrpiHBP&aOTvl;b`(x5)(^rndt#N1CCnnezE@TZ2N8PlA-lq{6-BIp=?uZ`bYtkW%ImiSRr=KuZ*QG&0cXO_V}*P z&(QF1$B^Xl2gvZ^?APlj9Z%BB*zoFJ8`)A|y{^R&#Eh<0dtX+h6*Mgr=SyECf1i7c zb&YSFLD5*{Kg7G%-z&uZj9PgEjPrf=4IJCA&IY{O49%{D<-4myxr8)MLtfbj{=-;x z`Sf#`Z6O-qCk}rTO*AoEkruFp$k%P)pxgqAj!J7u`OVjDBU|W1q|f1RJ~5rRHcnhv zYU`{{Ic8X1IBeU2E{A>hHX)A&(yKHTa2zILMvnjNHhYf-Mrew^wl?St^?u>( zDPs)u+*%3I5Pir?Ag(yF(=eM1eG)8rS>o1vco56Ja`(WJ5;(u)fz@JWkTIg<{|p++ z`%~;X@vJbMT>`wbFd=stPfS>S*)XqjViQi+zrO42{QK(q`LFr>uvC{?)#-7>H?@xc=Gbz^H5Sd_iPAOf6e<-=kq{`|q%Ppf$2 z{qA}0x4z4|Wd+{N?G#Hk&%1C9Ql@5*05;z%ES=nJIj0{3T~eB@_%}mIkxR7-)X9(8 zaeKvO8EqWlpG?CU`)@f+Vh22l$b?-#IgnQ%qxd))tsd=D#&J{|)uvBv@EW+>1F5po z)gp+cXC3;SkdBNP7%_D!<9pBdnHe%v=D*e`K~xTc*}>!4GjzF_NeSLbD-)lxW6oaq z6b=+$8sB3s+S93Zy!nT5k8u8O^PH_w4A^&up5eVu^YVU|1#E7`fK^2Ot^X}6x}>IT z1?Ne-C&^h7zG}obfBgp5{n;LAuy0&~fFSwV&=jz~!eiKOW+T zQ94CnvQs^9|HDX>guyVKb^-Lx^ZqWKo~VxGr~dZsC$C`K*&6~XMS0nudt$0^(s16* zJ399L`gy6x)#0C)I{n3q%Ipt2h4&a+k!J}xMs>i&AedV+&&l;4h9Vb%_KW==J%%Z& z9{M;aA9%iSW>TpZwvTknBbnE8l#@qL%8dV$u-x*az+A%4qr7xV6+c75AWK?&!WK7P zbBr?S6~+wV?R~QC@I(E=ZUN~okJ_%}ErJQuq8$Dva0G1>eaVY#5#WBc`jspzIe`J@ zdA2KrWE%niUlPs;7|cgelDEWaKi0S)%5=O$-)^))?gv26Z%bo)&_KcfH?SeKWADy4S|C->u!x&gXJHWiTKKOCz6 z>h%7XOoj2NL=jmr7QPe)~403(UICy)3))wKvFPlw_3T%ajNb#TkQw z?6WV$AkM|90UQK8!u>HTt^{vp`jmsA6R>Eprv zX)gD9K-P_BTGm=R0X#bYUF!{KNSj|?(0aPyueQ`=5!;`RXRG$LE3!e{ct$77o`i+A z7+OnDGL*~yGKZ)TTxS055#nA;3=a20O!e(QhVGYVUA8B)BI! 
z&CYrR#~@v>&Zc{yDUEuhx99@3b9mcZ(WQZ2<<7{+fW~FMJgIaOU z3e2Qz%Xs4+{=*o<8Ybikl2TH?lNjW3{9(m+Iu?!6{FTKz8j;6N*C@PY+mDo|w2zOU z<@gJ+dP}aC@7%MXe>upa+bsgeN3K$`f~`OtWYr#&OufiEL@$VP4Yrp}cKl7pjB-!o zwW5c@TPonO;HxVKS!96gvDxmsltQ-PSyA@V)%N=3Y1+$oOZzMy%%B!gZVx;m;{nZ} z4S-an5!dir%^&7I1!Z%^Sb}NwuTJkCgCKo;M%C9J!Q_Xmah1y$2l+4g=S~M)mheV#vCTK{k~vm`xr47+zM!hz zLV_-}JuTaCv-TzR3M2M@OuxR``KHn!9E`H!u4&RlpTMCXmG8YJY<`qIi_P5}XZfbn z=XAa^K#<{h8Tt`L81!n3HiNbLN#=>P0+Zv)<@k;_Qy8%mEN|VLReUEV^w%<3HpMXo z5JQNot@Dmqpespw_;UuV8zCxWis{ix>fAAu&Fe}#G4Fnf!j9W9oO`+E6F%mdE}!VV z&G)@QEJ-J!vNp0k`)|uZf)T( z3-)1_;rT?Q3k`(hQg*y5@}2FK|Cx1{22kt|C$x`GedlRDG02^owtPlL#`fVVB=K)L zp_h4Yz%=W%0{+t3X#d$*VZxuAyJuExv7Y%ZcxKKCb&3eSf9^k={CYfWAn~t`Z@9CV|e= ztg*iK6RF@oL+9CSD!&13zo$2c|6$N?l~=GwMF)R7ZQB(1wG)2_;9a{{2$?>MZ=u&q z936u|MG8I3NTwB}BD-SaOSbwH(&eZI{wiYG8F5+2b)`Tckd&@NK@_AjzZ|>&W=fAW zaDxnxpL`in1~{cvR+#>F5=M365E0J}kajq^{S%%Z!gK_LAzH|I^ zg=|6ZEz=)`UbJ~vCRS_y1*5_`$ae>M660hZ`8?9{U$u^v~hp zoc4UQgj|ZG*HynKYs)2DYL%YnLT4g}qWi$OTgKU>ww+66{EG{P#am(16%gpyAF(;cg%)l%3lqFzSzLZV!3a*IZrZwI@nByU**SJXG()FJE4{FOG21(z4%ea zRcEY6yS@4qU6{9Ad;c()q^bvRnbc3%sWLqju&rOEOT9adWc5TluD$2?iU2xe4xt8q zFNoZukM{kR4)KrZVT1vSjrP|z0b8V9huYVnsjGdje*feyXU<+F3xY&up9f=^RdjXw z`ES%iMi=ZRMbG|-raI6|gQvOW)35r_x%@w6_PNvaG25ow_RfM87JuwG`eDmvEzzU0 zhve|*PV_3`gRJha)5knJBht0;DE=aSG}lP>uHc@CFD8(V8o=hiTk(u705PCF5e8XzYPlN<=sZhT~QXRf zj{i1Fw=#QG)Yf21Vm@uo*7NZb2)_>pAS^AzBYZR_Zy1u76t#)Dxer!7JC|Ou=-!gN z+?3y_&AGb0Az4L*mLPR@cNX!NB1wiQ!a!3I8>Z%DBL~xgR{GY(i}R5{)Rgb<*gO3A zXuOi%s7HXv{$GIjAFBOh`yU_t{{lp1VMALbga6~=K+Y6ZH*^Pke4YQNcWtFdKMJX1 z(e3%;Co1=v4^01i`^|p_{->RUi<~O-s)n<%UQ45-mV!}=SJG!-Y5QE}yj7|6=wmCU zMQ+YPbJB!KpRcc(n1IhAk7*QC=N%3e^P@%K==5&@{sG$F0RK_b-Sy0f)=U|^9GIce z`TqmalEgoKj|FEQf!R#;L8QjO=+Q_3F7p6YgXfikoJJp$_r|J6pY$EURvIgrmD#mj zB>vd9A$%O)tcCD$%y!l%ZJx~b&Vy#F9=y)dqv4)yYZvlq^jm*o^HU10Js&|Uo#)Ef zO(2pzxv@9A93P3606Q~`n9F&yzNf{zk$S(A^4!1i(PP!=&-tU5S2RDOwa{OLKd1Ok~d{0TS!Ws)B&PoTMLT zWruS7Lw|Psc#gN2wv6(C zR8+BkM!H839Mq=(9C_y5Wvc44ZtF}ox2Bj;#wO+3sbMO-PEkFvc|5uGQLCFaZO0)* 
zLTK;p2E!v3t~VU0p>SXj+y-{Oeotjt&ADj2#fAN1xwQ{?gYX`?&;Q^)CAvkwNcxaK zZ6ParV6%bpUd|g-A?4E6;gNI>_9pk%of&7xO+TYLGneg?Qv4yh$qwDvwQCmZ z(L*l*N(j$bkK;Nh<9_1a&>pqNphm2c{gy650#{-c7Rx|y0?dN&0fxT;U{d7q!4h@i zyccWr2A*`kWoT<4Nx`muSMowHz(8S8KTxl>znr~vWw+%ot2rdKd>N%}nC>BV)Y{vl z9SpFq>FyFpo~D-I7uYrUlJD{pL01ME#$JzYkbKYlBdol{F~R6Y)Yc99<&f8Rxf427 zsTsLm%n{LR{TgS%4R(k`O_MDB#MtE|BP*8taL^;6l@9C;u$$KApov>w9ca7 zMdVd_27O4bJsg@AyTyMQ17p*c&*{8yvsosX}T|OZ%$ut{s zunHW~-R3G0cM#iDBtIr25KDU$Yo+~-TP2V|&hsj$HVeIjlg`t&Oo^MptPTC1BxB3hrgUG79d!ph9U zyCYQ^dWbtOLQ`PysW;(W$Cfry`Z{8fkLaAXBG4PSPvUoq0xl9CZFzKGBxLKf`Y~_{ zaM13IU;1l-K^7e7Wh6=pvk&PYSD-Kzz60TIp?kHuNZ7mbJSeI~P9MR2Rc@N8L$yd+ zicpUC>*Kl@b$GO3J(cozaH8S;X_04#t(^G0o_!B`4mPSS9sfw5QWNESw zf#_6#tI^{UZii%q`4+f|-kv$xXZLwQibo92Z&NK_Kk~jDzsQ~+<&mnwJ_Q~eX_9F3afEsyu{oq*~|X{wRogb?T-b?$CVG4(8X+?p%xW8@sma8 zFX_<*(;;_cFFJZFZ*b?4VL{O~Uu-;jp(!%vTMxui5xyn+(ycuTFH3gpgB^2i;sx^k zrv72N7rj_N%q@DGP$-!(pV)ncNC+~mdC}8)!@?F z!Zh6^yMLp<#-B0E$x)qrbX@ZVKSH#Tr2^Z$=N$jLm2PV; z?YwXMQNN_^&wbZpZI`}amXk32I*J_yKab=F$S&HQZ7>4VgDzK)uw&hyGTUF|W_^ry z4-Jm?fEcZ-f92W{&PdULM;?T@10k*c=xd>@@RN z%%o6IS?HOz{fFVrHrpOMe&yNro4Ca~-ROeMXPTI8_Ij7u`GicaQ%XKPME~H62Isys zv;WI{^rfY;yw$oD{Vhtty7M0fVz9uF*qm^(JkHNLI=qgE3^LSZ#+=@;}~E&pEtRXVE0r}#V^ z4`<$ABZtr6LK}W3(Z4xllO6G!Yw}U=%IrL2)&OGjX4$a$K5XG32M*wZL;Mh3-nd>4 z7@PB)vxAd|`7_FFgzjI-bTY5bEx&D3fSyDCU+Z`U)2QY3{MhHimfR)IZLh{TK2Xgj z9c2DM{UGu1OvlN|#^blQsE?r))C(qN+sSwRum~^6$H3A!S#A7|4o_o;+4^}8&*Xe} z2m1=Nh;JtgA15aaVp(<7+CM&a$vDr1vdHv&VHS>0{sH8AwY~grBhGym?YCoC)P!WF zj|@vO@_!7DGC9s)}fx%;LTlReK!u}28ZTMyk z>V2Q)bG&Cz#{%17x18am$C5-qESA>@IfW#QvNo!#@f64$t;4)C0DrN>u==lBTWwl6TbY&*-s!Qr1jnear+ z-a`?#{{UZ3d}ZS!t4rARz~JZ4@DcL2_&qph`JV?HGhq;N|HJ?%5CH%J0s;X90s{d7 z0RR910096IAu&NwVR3El~YnUA1Ji_{e zgpfuti0@LO+1wl=oS*!{%l-<+3N43p^G`JMPc-w-H1bb0@=rALPc-vSB=eKbPbBlt zJo87kX(I$DPpr_I!Es%)Q~Ax3vb{hb{{T9F02W;nr7uJbe+3wRNw1m6K+?qx%~HyC zu*`S_n12XiT$6_JGRp#Rcw96EI$)$ycmDvuAYn8N^Zx+#fROC}0N_dFh;uVl%MHqS zfbV-mD#CF800Yd2atgquD};tfSpF{`S*iS85s1?@CcnWz%<~h>Mmm4_(h;^1gT&13 
zDpKM){tGdO=ZRyEGXkEN)^!A?C9+b=;{>#-3k8DQSj(b_sc!Hca~BMZ5-g*L%yM%DJjBgHEm(|JARvnbcbEs6sWqM`BB)(l z)c!^22)j)W@rEvw)@Xn6NU$!YtRZrM@h>IZ{$1ihH-k5ZH37wuqDq6x!uYjw;OE3Z z6531rLXAzOSt#?AInA?}F41QxY-yF|JU}d;Fbc^%1;i6EbU+;91Q(IfM~&F*4V3x?(+*+8-fYRKK(@q7^D+<_KFiNbHdPsPIR5uhwc!XVHh# znW;4X7sJ#Sy6QLjq9vm}L4*GQh)PX!lelUvNzMKMsCN@Aq-JY?7>qNhxFb1wdUpeb z0{Ep5xTLjVU1B&CH$z|Y?f9Dg^0GY`1hW99e4qISf*z*;P3ehtnS41*4K5;~guq^4 z8V{IJ01Kd~>!jxxY`M%u7ykf?jobyZHfASlWeVYr<4$JT4p<3*lt4LP<5u{Jp=88= zfQ5?*Lai}rk23g0f;+*FtknKB4-kF%sWkou=ZTTQ6ySvoPyQ(g+Y1K~Z8GgJktHar z{{Y02-l8*&1;0;uQtOyImAV@700C}A3<7Q%{Xj7XT{54(4SblmcNK*2n3iq`M^I{c ziVc^a;%@rHg#d*}myAn+9LI1mRtu%r-|$r&Ty?+To&6$55bhyGFHadM5GBTNufEah$)#LR~}xnRy2-_c12~ATWm!6%aB*DmI4Bf9EqLv0d}B03j(+6Gg@duemcs9q4A~QEXAH&6X5Qw9{{R%!RvRoQNX8%q7sGP2le<(ATHY_*cYrL;4-sH#tVRAp z;)>XE`-sL%m4_r+9YW{#Ihc+Rf;`LMYR#PI5}Zsqj8fuYv2xYCK>OuJ@^=E#8(PFb zmn3Q1R+5O5boeJF&z>5Cw|NchM?ScR3(VyVU0wbLbTUA z+TilTQoRXJmHHz5)l`0J01)d`^*?js1sBt~U{1OTqESoJ%AUmaeKx^~CadfKZ8i!K$5#eh9->d^>)b4Hk z1OzZjrR8VVRY(SgIe;ITSylRr(QJyA&IOVB)^PCzoG^wPx6pra5Lpioh-w{{WRA1pX8m9YT{%u8)h7X_;wwh6$?M!E;ISJY;^{81%VAv$`+{OV?5|rp)F@e9!8q43 z*;Z--7_fRY!)n@+;;16XHEFxrKjL9+sQkt}P~G?lB6aL51>7_ACAy3iiyqTF##vGG zQqTIa$H~HZjuasd(qkg;l<4jHf`Q%Y9s!8tDLX_d4DR+EKp1;JWBNcnWcx|X238zJ znW|+a2{h`sJl#hEZK9sY4PR8oHvJ?%nLiK(sIy?TcHVoIojal&HfIn^f?g%)b?ry2 zQkAbvoh0y&m^xZy2y=60tEC??mX*b3aq2HZKM~SGG%^u!K=MJvAzvb1?2YZ|r|#h) zxWw{96z&2sToV5P8pHw)^l-q^7CsQNO>o2+Mi-;948e$;lf}vPEYM$qt|rw*uvi#{ zgar%10f#DxC4MCXg-2GxS6PM=IH!RG7Qhly3?tkwUZ^_SqG(r4sklI0?8Ig4okx(b zMwkNgxd>(89#qixs7KG3CDCx}bEz&Vf+dC#hlm9V-VcM3*SERKtQn@{TQ8U4lDv(s1(ue zmdcg-1f<0Z@5{t0Y1IO?1;r^u$yABD8lBpQ-DQ%c$cuoNZd3(9+qh@%$sV`?j>YV!V>aq0daY1%GU7PE*Mbqvt-h@1v_Ggz9gn1gAB@o~~&3+;cXwh?G`7+Ru^ z>FQr(x@`WC34oP?-$-GRGDLNv(PQXku*GcG#I(*QX^MV7A|KSMUHVYy@fN;56R0^t z@Zk==CifK7y;K9B4Sc(q6B8nT?aH2KD=!T7NeyRY&+{_p6?JneGc8*Wz9wHouwjdn z3OQyLUf(0s&6w;BOzBn8L)X$2Qu_kq>WonAI%RytsNGAR23+_?XtiV7R8t=j4GOqh zBLsYm;DIoNVXVQf*nVz^T<>MYK_;VV-38l34@3fkna+RM;oh8GXY}* 
z(fj@#N3fecVSw0?4?<922Q_SJ*-YxEFj4dQY#aH7*#{IA?Epesj)eeLAupUC#I-$P ziyO9C@m;pf#gGaGv)hu(9dJ{av@igfv?ISYQvwS_cGe|1HT;D~Ht`jKxDqiUX=)HYJu4t;uM zGlYDi6V;!qm>VQ7c^|ITaadTIG+cS>o@H4<(nZ_7WyBhc3c1OSzw!!blF2X34J}fV z^$`PN7TSNK<^W7q)z9t@qSxv)N5M>(BAvK7IhGs)5q`5temF@@2;-mnp`j>K%kCiP zV!S{32xab_wF053C|oD*P=Io{CdkEHY^e5tn+7VY_BR>KT3X@FlQnf#aPR)}3o^ou zC9mXwd`u5niT(&w+N+n$$^aa)puC*GVjRY0%5VXKY8W~}y3WGGbWDXPGjW0thhksc z5He=Ef+2${a~%%BG>5Ye&*C@PGiI<|5Bv2n;_Ga;c-_Y)fU_mB0}$H7pdz7w$5uHb zp<%$PR4q4&)SHRnO17p$okB^ew*y|e(MUvrZIxXb>F-XvtdRBddJ1a zw)!AwUW4U^I|?5$Yp54oS0u3#vrHAtqtWwmA(l2=vG-D?&u-6L)lAiQd@0qC|2<0+1wb20LL{>GfbNgZvggU&1nJGF;)zTmlUr zRy;>LPTHBSrR)v?>O6$Fx~0GYMLdS*+qiAc7h)y99=BwBqW zrVbxS*CO#QdJb;Mm&;CSeYyg&! z3v?XH4G?K?y^tUvXniBr!Yu|f9Fw<0nQvC#0nrfl@F=;z34&b=n~57EctE2LB?Pfy z+9nAbZim#<5*tPf4;9R`bi(bVtM9pn5GKuL%+Hlx$m?8*yr_ZEdNfQDRiFCGL&mPWN*mDPSF%j%r`Kd8oNGB6)uq5gkT=T4F! zFk&wvbI&8%0uUO0H>9`6kFGeKO+9~Iu8|3-d%qwKbMj~dEVi1rHXt*3P6n>)K`{1IL(p>kmFg8)ZB*30!xqi>NOQyIsK8!0ARrGfO zeik4Gp%U?C&12W%Q}y&(3ANoHk~3Si_fO&smLG6hHku5+C3-=ls3X7-xDFf62;v%I z>YBKT5K0Wo=x=NHl?LqO<}UGcslJ2ERTXV+Cu2ZC>L9zoV>a^{BExB-DXFBTAU)$7 zW}L&~_Qwf!rTdLjfP3HV!mj6d(GmrZGZPP1E4K(-%7j(IbsCibA9=A(i{bAqsQwAO zLyFkI%m)uiU?iy#sYn(!@#>pv3_M(T6d;F$hy5EUzDT1M=?aewORreI>tKue#{uHf@-{8{%&z4 zb+1Q%Pl!cN$eyo zLg%{W|{Qauks1UP_fWeeit`?kN?@YtjHRt7{ci%4wjM zp+d;5kKCxGmCUXHA;qfZwk437h&Fe>cp;wX75RYRFcX}_3(g5r(U}_LgJ^ZP{fI}^ zqptz&iGp#NRV1}HEEx4Vl{1y|efgPhWLR48^E2m1Olbk7SD9QPhi^d|GP8^|Z!E2V zx6}|b5wHR&4O-P_RGNxcjjG@rnX+1-BSjPoS2KqZrMvRMv6*XPvO@Jrw6MqK6=$^B zk}*;_?D=!5_L))$@#a#nW6z&?PizT82vP`~#JkZ~o?g$y!*jja9!^hrb*%5yi=p=k za4ki?At9WZ^B)cWL?ZDrDunFj zc7Wt6ZARP*h`fe{p4*2Buw9;1KBJTsEM}U=QtxR```#IMCnVvXxoz)KdZR;e!gMJ z!!FihU$h#6ZBOeCA#F>9+BQbCm-^TrZ@7_l5Jovv?i;8{^3U00)(CPx)AuYBm{HUR!pD0yu^1-Ik9dy?jG3O# zPZ6b>M<)pKOWfzUl-8+^pY|Y7c>w&uhEJy7P}WR#vGEbDRA2t|lm&|;3=dfJfTLEK zo4BycP{`q_R*of7kd=+{jkO)@G@x2)_Tk~?9v@8RRi~_AcI%@B$6yq0rY=*~3zeox 
zrbW>dDF?`hNYJO&AE<E%qpzvEQ^+zxI0zwlYxyPh+Jf_3qaFs*k602T5mS1&V^1#^Pm!khfRcjXE@W51lO!Yr3LFz*SRh(0y@-3)*Uq2 z=sFttnHB|KqXJ>2fo{(TS%rg_WN**ZLz&53lwn~l_sm-g&k)); zU^z#UC;L^=`-uC9AUMsLjxN^g%Clt)n--I0JjY?6^CO7&f-sS6nHNG0CA#V zGL8QL@#E$mw(2y2LoBcwjdPg1PDf05n_)BX;gs0oS=>)@{{Ui&I)e4*?U>XmeA2F% z!PKE1A;Y*B%*2qA6#L3w?iOa{);g9!?4QzJHhc%r%J(5xA^!kwBC)H(*$z%+kr%R2 z<#h<l6%$Erw0M-LgVM|-7~CeEbNZK|>G>fgyBY6ts0H6sdzDeqpJz`p%Pwo- z&+aaJf%}YdIAKZGIXz(4rx(O(dSj#u>RdX>S&q)?<%41a^gs3y7AhhI%gj(VJ|bA0 z>k@;MtAj(9GYSC;_K%N5N<4XqTFWY~$g->N^n}}_P7{~pex^vqk?GvUPNn_oCZtgt z#o1i0ps&2{{V$L;=9KLGQb=foFXUVlwJqwV18wlgkx8u9(jumMPrf_HWB^8 zzKT5g^?>zw3ZLvZtOHMlD>aJxXw_2bt%o zL4OldNEoz)Nr65pJBgY>8Oay>AwK{{ZAGJRL+AXh1n@1QyM?wXq$55P5FBMrAETo?=R~=%7lFX08$Xaj8^pB?2M`cIIi{}$!siFNr=?6vgPZ%5CFiG5S9@o?r63|0&OJ;W4D|fgXvJK2zFQ`erV(Mk6Te#9NmUh3W#48D7E_1N6#vjbJ`ERx> za)~fqn=liS9|R6x5RB#DN&A|7v&Id;yOb+f-cdicHy36m2 z%8j+|Wc~cZ0A=oj1C|_Im*+4P!0PEgP%6HRYu9n;{{W;)^I&l+2TO~1eM^ATNqo#2 zE9PpMJ3t7ksQ@EIej^5}xHZUoM$34X(M|OpdT)uyGT2n!&%_irW6u*{!K5lc4&W5f zU(m1AVe%!h@5d62Kf<*YMj_Q@^KJ?T0$EsY3AlS9?Cbm?B4)J-1`gc%K{OD`5IGnl z^fN=8LD3y6&RT$gN{(*-0PBhNru_E}<_*kIZ#5eiS(j63xQj56*yi7F2352^OA~BT zyV)af!4b+kT!-!wxv9z2;cluwSf3~N0b<9;U)o(`#OPH1PwErA^9M*~AgB%*j&=hU z5u0(MyE6)K>V1f!c2VxbCX3)qN-q|E(>IqD3_@tay@gqr-kDP*<&MKwWVu$7qSRve zg-hTP#s*VfgtHo1Eq4(Ns8$%u_<)UPFlt^OlqqVkSWFIFM7MYEDaGJ>uKm9(Ma>NQ zzy+b_+G7{36Sf%SJ9ipOhEe3>@dzzPZrI*qw15aREftpQA(n+W4;;$g_Vm5*s?7kLiog*}q!UoSuGZ#U9n_v#lMOvKE3#YgcR*0DBa-owwd zF^ypngfBbqi0?rN;XeH?d$Kg4brlQRPHjkf)+Q4!p})Zcd1Ujk`D6Dq9FuXgm+}fOv@6wH_+`NQDt>k238LqXF&w zIjCOY)A%N9nCsOl`XA)bzu+!Q_gDIq1LK(NLQ?p&#k}J&Q8Ek^{{T_Tz4b7v>G_Fe z_}hnH=fO5xH`HlHgj`I`FH^;MmJk-zK?nfgScAMV5gKj(02zN1BHhG&RKG70SbH@T zd;Za}7L$sVt|$R3jaNE|&AcJD4@_h;=M`VB<*O3Lq4egWQDAsKh>tZGj8Q0kANW20 z0D{{EB4nzI+q0#KpJ&0DS)d1S@1khZf9F zSgfxw=)$bEO0Yp2U!hHY;zVs_E%O?mhvP6ZUg?f-ZJ8NGV~7!^xge-9uZ5)l0J@5! 
zkB1Nrk$&L=v>{nMOcONRQHV4@B7Wj+1UH}Y#YjyEh1FQ~SNfP$8TF}h0 zVnu>{a9#?Sat#sZ3 zy<7NzVOS+IT^}$4)HoWKWgr~z;$21m0N*l(?~eZf*k>kIW+o*8F&X9x)H+0JIj|SQtyhpsyGEi8}={oWhd0<1)c-5w6{^>e`Lm za>W6HnB~vXSlq!y8MarN)6-KQ^h02` z2p(2ggH?b(8Hs`v3VaC$&;W2tx|3=}cp*-?QD~|VISU`A;VFE4LY4so>;0xZIA8~+y z9Ah-R(FD7|eIrD>*FWqO9+O=2FoK)uHMK2VGY$p}#*uD49H>N)Egn*Tr1uEN^5QG6 z0rw57T72X8C1ME!6B?=QTnmc{F% z`}TnVtwnQ~ut42i?h8F(T^I8z9c-*9SLKT8*LBbPF>mkvlpBgTJ%l12u!I=Tzu}X7 z%!QL8zU9N8nNfJS-rMlXx=o|y@@Hhx_C&&tdl~%^v%v}wFW_AJT2>~+R3;wwG` z`J6$RDLlt~@e>fco?1uLNu6T;f2m79t0V0O`l`TjSb{uHX>efyT&wyaVOtw82fR5z z{GzHAX~ea8OG^wgklMY*#!R!3?OTr4=noLP9TmyhBO>B%%*OjfnyV% z_WDb(@{ClpzOmu8%|YC9Jat~z%togwtn-?HfRRpGb&YN);{Z?}vI2aSUvcXwy8|x} z%Mm$^(ow#lVeufOSJ<|r){&=3ROPbe?NKyQwb<$-xKp-4njqQ=z?dltiBmi$BG6;T zB^|M|+E(hYkK!@-m1xPQt6!;DR^SRmN#{fMgAr!M(U;;qSOn<j=9S zXS5`Ey2LarE+=BhqY-w?9l91)aGqc#k6;c5ur!pz5LS;$9BLC1c;1&Rgstyo(6)8_O{DaQnfYk#+YL zSZeB+%wX6T^(qHLsvVmp`2;D~b$%983CIPWfEVi`sMtf#;A2J&esF@_c<0qZg$il~D(gXN72+*$|hw;zDSd&EaMQ{($sVa-aW zrO&|t(lhbzf8NZ~6JB2cvc)%Gx`J{g$!A@dQ>%?1$mg|=(* z#tJg-3bdeYg#ZfBO76OaE0BRQ5TK|mDoX5!HSH~UN{d^e&u5#Mg}fw08BL>}stQOL zE+AMq5J=Rf!MzZuEG)gS#{tApfF^L_9*alZ6k43I1YaEy1--Ax@}9E(nGxtVzemsTT~7<21O1l}(=?Tth7yn9 zRk}~IMb08&TDfeht0|0aSJeedTWaw#uSB}A(<}j{3q%YT3f(#5XH(F)(7+G47-MnSk0=@=Tcst8NIZdy$`f(zbK; zpJ*-6F12tC+!{C`&t@BmV6Mn)M-WEa<`tKH!g6i?qZ+GS&Ta@wV9v7;R|Ueht5UKp z5)p)8GP#+Rv7oY*g@MY!%(H5A+$f>6E}zSm2MNCv;nD>><+7O zW@VW?#U*ZQ8&E~I?;{dj;8gI;@4j}Ykwjc802$grTUC`uM_--aG9c2NxY`PC?m~aH+ z7@0s}`RNY$XOaci6_X>|8cobWb)THZ2-Ex!?%pUpvtIa2;+D=T^* zM^K~>v|NdRcnlLsl*ScE`Ti)dZhpeDE%~U`uYGHbtw9_($gc^pp zwk#ETVJ_3FwjioD%WFX3IF6m=i`7(ZR`F3`EW{MUjaKEVS%E9nHcqT+Ji3C59{o#L z+-d4G#UT7O4>@MQz7da4+965-Pyne~@e)vaL;WF{_M`K1^C{^*37JSxcB%Bj^(W4M zaR7IraAQ|$zNHt7Rez{;TvR ziUb`nyTf8S$;@^`nNaimO5I-LaRjll03C~OT#=ajD0BHG@?&VL?jy*GT+7-~u%{7< zH)k^l1c9^~JBTprxEo8PRu6Tvb=O9&-SsUm4cl;#P&Uj-cs)awHegou1q821{zaw4 zYpo9OJ|m3Q#qp4t1E?F7cOF!@NSd*LG$#5UY8Ld?^l*MY0Q3pEUPJnW2Sr8kU#JiP;HdjTW!fCTEy$EU 
zzEFAgVsLZGh_ESY*~5o1z|Kn+)>2vkH?f%1!f{M7nM>$_Lgk)|ymJN33&MVc7lrQt z<=`sF@i@$2e{d~6QD3P1MH>r#?ag)clMoFL zqlnT^a60#jYhhlJ%GS{4APOxQ&xnr+qtw6L$XD3J$qL2}5v0L+gNYo9zL=n(8~(Je z*h@`%THlGLO&*Jy+(gjENsmLhal7{e_qo6ieE$Ft5<&w#CwVnJP<+qBa4VGMhh6b^qHHWvcG6*D!DD~`j=e11gMd81l+(-DY0XZ z#HQSWEeA;)_K0nrPGIY9p7E#QhdE-v1s z1JOkF_Fi`#jalt6Uanz|qC68XmHtGr1_Nq{IcjL3U`T+xf zEb7c4s}f@gUzM0RKZqk6UFbKNLV74)%*2D*o>H$d>7*$zq=7GpoU3eHOfA~1Fu)@l z-LLrt3)!hf5Z|Mr8vFtT*JSt1R^JlT2(DJ)SA-VislDO?cA5457zd*-O(WVa`h{l^ zIvDg;*vv2uf#Nc%*wsLlW}f1j6;~4hj7~60h)M+=C_m#?JGg6&L2BFR4DTO@4-&u$-PD^n#8GX=Dw?1pai_ut$gIJc|q%zm~o{{V6@ zX%XR=3Ve~wL)#E`p(8`#Ln%!-N>ZSEXRKBg^nXS@kOlKiGScI}{6nt&NUF}GQ(2g& zP@OBMPx2Ff=nxKygas`t`h!aD?)u;LDuE_ZXl5?CX@-MDUFFvG{0j542SXwCC%S=6RZqYfCfx{+oDGde_-2jwtuhYL(0okK)Y z4bqMUEu9I&+96Am%&KaqqX($23D}AenxOo^l)57J%n&GnX>eU)Cli3*5rWA&M0QP{ zm%tK$Onsd}Wqf1vG=_^MDuWz@1cQacM}$(vf>4kwf-4Pf@7@ZgoXfw1{`iTo1@59j z>MMhY7F!R@07;5sAO#E$tyB1mVE2lQH^jXGO2YLh?R;Fc2~BV9A%a-pfw_6yY7H%U zm79YRbQXuqQbk%?sd_3~mi}N;A5F&bM+nrmb8d;32r0#T^D0ER2bSND%yNW8-~d{) z!Q8*%*>?NRHT+$Srtrjv(krr4)Ihki5}B+P%8^>L3BT%W)^U+cAwXD@6w!M#!I*Nu zIJsGj0ygvr6zs&!ULYA-vK5<3mU{OTP6Z5+DXy)nm!k44+kfKa6DXj}!GND+z|ld& z9TM?y?;A|Z+*T>n33mXyWfoqWCda%rw1e+5&00W8#9H)O4%G+t{7a+@Jb&J?cnUzw zjyi%c!2&|G^-7oO2u2;{vH~}`bXb@ELM`WxeGnN$(wc}=CuzmAlAK(iyE!nKpRbB6 zS-d|Iyo?K#u05tf+;*gGpbpE=?oka^MDDK7xY0qsrDtDK{Hl2~9)luN@&*;L{AhAGUe{3<4BYQ`wiWU`KZVfV*5Cp=8E zP^J5w&0a4Nudm1OS!hBcdy_mu>-#9l=MYo*AsMd|`X{_Nr}roxyuVRit?WPCqc9Ko zrZFb}0OJ+=NT6L|d!rQ!0K~VA5C#$Uo`>9?&xu3K zCV#YZRv>w0z-b0JAv7pu&SUXxw=44v^C~0$mKpPk%Mx4`fZGTWyyPWpsZ4V3eMc z*X#W=9fen$MN2m_M=@q<_qSad6#FtJTDu{}^;LImwSaVq0n-cQUP!hypVF z1yT73S)EEHhY?p&xB;Ok(lkP+xkR+7nM0V=8I}uz)w#K>%b9mLXVlUgPf)+<*)@l} zedn!}glQ=4E-cK+L0KUtN`fZhtU%4*>Jc6v$cbXHh{`JHDghQnP4%a!0WHIcj)%D{ zuIRQw-9=qPFEN;0cGLiv;%BHa3L39LKjv(;0yWJ4YZ9BOA5#2T*H?fET(X IHu#_a* Date: Wed, 20 Aug 2025 10:16:00 +0000 Subject: [PATCH 02/18] Rebase Experimental E2E and refactor to include tt_ccl --- .../tests/pipeline_tests/test_end2end.py | 20 
++++++++++++---- .../tests/pipeline_tests/test_vision_model.py | 8 +++++++ .../tests/pipeline_tests/test_vision_tower.py | 8 +++++++ .../tests/test_pixtral_transformer.py | 8 +++++++ .../tests/test_vision_attention.py | 8 +++++++ models/experimental/mistral_24b/tt/model.py | 10 +++++++- .../tt/pipeline/mistral_vision_tower.py | 3 +++ .../mistral_24b/tt/pipeline/vision_model.py | 4 +++- .../mistral_24b/tt/vision_attention.py | 22 ++++++++++++++++-- .../tt/vision_pixtral_image_block.py | 3 +++ .../tt/vision_pixtral_transformer.py | 3 +++ models/tt_transformers/tt/model_config.py | 10 ++++---- .../pixtral_transformer_inputs/demo_small.jpg | Bin 8554 -> 0 bytes .../pixtral_transformer_inputs/people.jpg | Bin 49606 -> 0 bytes 14 files changed, 95 insertions(+), 12 deletions(-) delete mode 100644 real_inputs/pixtral_transformer_inputs/demo_small.jpg delete mode 100644 real_inputs/pixtral_transformer_inputs/people.jpg diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py index a78d1b683371..a91be706502c 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py @@ -6,6 +6,7 @@ import os import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.common import ( sample_host, PagedAttentionConfig, @@ -117,8 +118,14 @@ def setup_vision_prompts_and_tokenizer(model_args, instruct): { "role": "user", "content": [ - {"type": "image", "image": "https://www.theeducationmagazine.com/wp-content/uploads/2020/03/18.jpg"}, - {"type": "text", "text": "Tell me who you see in the image and describe the image ?"}, + { + "type": "image", + "image": "https://img.freepik.com/premium-photo/girl-hugging-dog-with-girl-hugging-her_737761-2565.jpg", + }, + { + "type": "text", + "text": "Is there a cat in this image? If not, what animal do you see in the image? 
Describe the image in detail in 600 words.", + }, ], } ] @@ -182,9 +189,11 @@ def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged max_num_blocks=page_params["page_max_num_blocks"], ) + tt_ccl = TT_CCL(mesh_device) # Load vision model (exactly like test_end2end.py) vision_model = TtMistralVisionTransformer( mesh_device=mesh_device, + tt_ccl=tt_ccl, state_dict=state_dict, state_dict_prefix=vision_prefix, dtype=dtype, @@ -418,6 +427,11 @@ def validate_e2e_outputs(results, expected_min_tokens=1): ], ids=["accuracy"], ) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) @pytest.mark.parametrize( "mesh_device", [ @@ -427,8 +441,6 @@ def validate_e2e_outputs(results, expected_min_tokens=1): ], indirect=True, ) -# @pytest.mark.parametrize("device_params", [{"l1_small_size": 1584864, "trace_region_size": 0}], indirect=True) -@pytest.mark.parametrize("device_params", [{"l1_small_size": 10 * 1024}], indirect=True) def test_e2e_vision_text_pipeline( weights, layers, diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py index 97f5736680ef..3d5cd77cb75f 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py @@ -7,6 +7,7 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -31,6 +32,11 @@ def get_image_features(vision_tower, projector, input_tensor, image_sizes): ], indirect=True, ) +@pytest.mark.parametrize( + "device_params", + 
[{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) def test_mistral_vision_model(mesh_device, reset_seeds): pcc_required = 0.97 dtype = ttnn.bfloat8_b @@ -62,8 +68,10 @@ def test_mistral_vision_model(mesh_device, reset_seeds): reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) # ##### TT Model: TtMistralVisionTransformer ##### + tt_ccl = TT_CCL(mesh_device=mesh_device) vision_model = TtMistralVisionTransformer( mesh_device=mesh_device, + tt_ccl=tt_ccl, state_dict=state_dict, state_dict_prefix=first_layer_prefix, dtype=dtype, diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py index 35222c0a9b65..a6c00009258c 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py @@ -7,6 +7,7 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -22,6 +23,11 @@ ], indirect=True, ) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) def test_mistral_vision_tower(mesh_device, reset_seeds): pcc_required = 0.99 dtype = ttnn.bfloat16 @@ -43,9 +49,11 @@ def test_mistral_vision_tower(mesh_device, reset_seeds): reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) reference_output = reference_output.last_hidden_state + tt_ccl = TT_CCL(mesh_device) ##### TT Model: MistralVisionTower ##### vision_model = MistralVisionTower( 
mesh_device=mesh_device, + tt_ccl=tt_ccl, state_dict=state_dict, state_dict_prefix=first_layer_prefix, dtype=dtype, diff --git a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py index affd89d61d6d..578efa204295 100644 --- a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py +++ b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py @@ -8,6 +8,7 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer @@ -28,6 +29,11 @@ ], indirect=True, ) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) def test_image_transformer_inference(batch, num_chunks, mesh_device): pcc_required = 0.99 @@ -52,8 +58,10 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device): all_tests_pass = True + tt_ccl = TT_CCL(mesh_device) tt_model = TtPixtralTransformer( mesh_device, + tt_ccl, state_dict, state_dict_prefix=first_layer_prefix, weight_cache_path=None, diff --git a/models/experimental/mistral_24b/tests/test_vision_attention.py b/models/experimental/mistral_24b/tests/test_vision_attention.py index 8466b102eed9..821dfa3222d2 100644 --- a/models/experimental/mistral_24b/tests/test_vision_attention.py +++ b/models/experimental/mistral_24b/tests/test_vision_attention.py @@ -9,6 +9,7 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -36,6 +37,11 @@ "batch_size", (1,), ) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, 
"trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) def test_vision_attention(mesh_device, seq_len, batch_size): logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}") dtype = ttnn.bfloat16 @@ -56,8 +62,10 @@ def test_vision_attention(mesh_device, seq_len, batch_size): n_heads = model_args.vision_attn_n_heads head_dim = hidden_size // n_heads + tt_ccl = TT_CCL(mesh_device) tt_model = TtLlamaImageAttention( mesh_device, + tt_ccl, state_dict, state_dict_prefix=first_layer_prefix, weight_cache_path=model_args.weight_cache_path(dtype), diff --git a/models/experimental/mistral_24b/tt/model.py b/models/experimental/mistral_24b/tt/model.py index bfe094b9d8ed..ebf987a4e511 100644 --- a/models/experimental/mistral_24b/tt/model.py +++ b/models/experimental/mistral_24b/tt/model.py @@ -100,6 +100,14 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], ] + if hasattr(self, "rope_local_setup"): + tt_rot_mats_prefill_local = [ + self.rope_local_setup.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_local_setup.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + else: + tt_rot_mats_prefill_local = None + if page_table is not None: tt_page_table = ttnn.from_torch( page_table, @@ -122,4 +130,4 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag else: tt_chunk_page_table = None - return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table + return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table diff --git a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py index c78b9a9a3669..5656ab0232e4 100644 --- a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py +++ b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py @@ -18,6 
+18,7 @@ class MistralVisionTower(LightweightModule): def __init__( self, mesh_device, + tt_ccl, state_dict, state_dict_prefix, dtype, @@ -28,6 +29,7 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device + self.tt_ccl = tt_ccl self.dtype = dtype self.config = configuration @@ -98,6 +100,7 @@ def __init__( self.transformer = TtPixtralTransformer( mesh_device=self.mesh_device, + tt_ccl=tt_ccl, state_dict=self.state_dict, state_dict_prefix=f"{state_dict_prefix}transformer.", weight_cache_path=configuration.weight_cache_path(dtype), diff --git a/models/experimental/mistral_24b/tt/pipeline/vision_model.py b/models/experimental/mistral_24b/tt/pipeline/vision_model.py index 098c32bab03f..19f0b86478e7 100644 --- a/models/experimental/mistral_24b/tt/pipeline/vision_model.py +++ b/models/experimental/mistral_24b/tt/pipeline/vision_model.py @@ -12,13 +12,15 @@ class TtMistralVisionTransformer(LightweightModule): - def __init__(self, mesh_device, state_dict, state_dict_prefix, dtype, model_args): + def __init__(self, mesh_device, tt_ccl, state_dict, state_dict_prefix, dtype, model_args): super().__init__() self.state_dict = state_dict self.mesh_device = mesh_device + self.tt_ccl = tt_ccl self.vision_tower = MistralVisionTower( mesh_device=mesh_device, + tt_ccl=self.tt_ccl, state_dict=state_dict, state_dict_prefix=state_dict_prefix, dtype=dtype, diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/experimental/mistral_24b/tt/vision_attention.py index 3bcc772f36de..f342365e4694 100644 --- a/models/experimental/mistral_24b/tt/vision_attention.py +++ b/models/experimental/mistral_24b/tt/vision_attention.py @@ -6,7 +6,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.utility_functions import nearest_32 +from models.utility_functions import is_blackhole, nearest_32 def rotate_half(x): @@ -33,6 +33,7 @@ class TtMistralImageAttention(LightweightModule): def __init__( self, mesh_device, + tt_ccl, 
state_dict, state_dict_prefix, weight_cache_path, @@ -43,6 +44,7 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device + self.tt_ccl = tt_ccl self.num_devices = configuration.num_devices self.hidden_size = configuration.vision_dim @@ -237,7 +239,23 @@ def forward(self, x_11SH, position_embeddings=None): # All reduce if self.num_devices > 1: # replace with reduce_scatter and all_gather - dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) + # TODO: 26411 + # Remove this blackhole condition once fabric CCLs are working on blackhole + if is_blackhole(): + dense_out_gathered = ttnn.all_gather(output_11SH, dim=1, num_links=1, topology=ttnn.Topology.Linear) + else: + dense_out_gathered = ttnn.experimental.all_gather_async( + output_11SH, + persistent_output_buffer=None, + dim=1, + multi_device_global_semaphore=self.tt_ccl.get_and_cycle_ag_semaphore_handles(), + num_links=1, + topology=ttnn.Topology.Linear, + barrier_semaphore=self.tt_ccl.get_and_cycle_barrier_semaphore_handle(), + chunks_per_sync=10, + num_workers_per_link=2, + num_buffers_per_channel=2, + ) output_11SH.deallocate(True) dense_out_reduced = ttnn.experimental.fast_reduce_nc( dense_out_gathered, dims=[1], output=None, compute_kernel_config=None diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py index 8fc053f87164..e80d8f662856 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py +++ b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py @@ -14,6 +14,7 @@ class TtPixtralImageTransformerBlock(LightweightModule): def __init__( self, mesh_device, + tt_ccl, state_dict, state_dict_prefix, weight_cache_path, @@ -23,6 +24,7 @@ def __init__( super().__init__() self.state_dict = state_dict self.mesh_device = mesh_device + self.tt_ccl = tt_ccl self.configuration = configuration self.num_devices = 
configuration.num_devices self.hidden_size = configuration.vision_dim @@ -40,6 +42,7 @@ def __init__( self.attention = TtLlamaImageAttention( mesh_device, + tt_ccl, state_dict, state_dict_prefix=f"{state_dict_prefix}attention.", weight_cache_path=weight_cache_path, diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py index e28a5862074d..85408be02b9f 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py +++ b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py @@ -12,6 +12,7 @@ class TtPixtralTransformer(LightweightModule): def __init__( self, mesh_device, + tt_ccl, state_dict, state_dict_prefix, weight_cache_path, @@ -23,11 +24,13 @@ def __init__( self.state_dict = state_dict self.mesh_device = mesh_device + self.tt_ccl = tt_ccl block_key = "layers" self.resblocks = [ TtPixtralImageTransformerBlock( mesh_device=mesh_device, + tt_ccl=self.tt_ccl, state_dict=state_dict, state_dict_prefix=f"{state_dict_prefix}{block_key}.{i}.", weight_cache_path=weight_cache_path, diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 1e8c94976919..e79f633b54e3 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1395,8 +1395,6 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): def _get_text_prefix(self): if self.is_vision(): - if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: - return "language_model." return "text_model." 
else: return "" @@ -1742,8 +1740,12 @@ def is_vision(self): return self.vision_chunk_size > 0 def get_state_dict_prefix(self, module_name, layer_num, is_vision=False): - text_prefix = self.state_dict_text_prefix - vision_prefix = self.state_dict_vision_prefix + if self.is_vision() and "Mistral-Small-3.1-24B" not in self.model_name: + text_prefix = self.state_dict_text_prefix + else: + text_prefix = "" if not is_vision else self.state_dict_text_prefix + + vision_prefix = self.state_dict_vision_prefix if is_vision else "" layer_prefix = f"layers.{layer_num}." if layer_num is not None else "" diff --git a/real_inputs/pixtral_transformer_inputs/demo_small.jpg b/real_inputs/pixtral_transformer_inputs/demo_small.jpg deleted file mode 100644 index f51ba21be8d4cbeb5faceca2b88150b50926abf2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8554 zcmb7pWl$Z>*6!f$?zVA)hTv|&-Q7Zvjk^beLvVL@*WfO}0&E~ig1ZC=5S(wn=hUrp zeth><-PPS|o>}Xe)lywO-PNx@UN-@31zCAn00ssIcn?j$>o&{}c_}GVH4Rl+c_kUB z0{~#g0jR)p0sz?2-Bm+Ql3GvSfEsBNfPtoeiMgel^S{yms-djgrGIl5-u$nO|F;pv z%G%8mO1Xd*kSnxuXq)h%7~kgKnCTyE@oy~r5B73*c87A*{=u$V8d6Yf1;tD@|Aj67 z3tKw7{=<)ja)iMS9{<$!kNnd+bZbX#P3RLDTF3!6Km(8iB>(9@^c^ZrMF7Bm0|0Od z|FKzQ0zi8h01&VK#|A0@0Ibgd&^G%Y+keKy+1%CqU*X`P8ur5n0Jtg#01SNq!21pW zsD}UYLDm1IZPZW~Ika6)&}0KR0M-CCAP+bKmH;~xaRZzH7x3w=8W{VB1qk$_Ip z!gW9@WbT3x#3QY-W!dyL=@J@(35 zbI{GsOm`$M(iwA1DYg=?Z2aIdpS08q>17c{MX6P}$DjOBePhr-rfMhAJZsm>l0o2n zw1=sCQ9`VcA0>LJqiKDnEN=&_YhR}HS8wAxDLs&`G*hOs-6l>X=u&aX$EzT(=!mh+ zxoGtl^_J#Ak}c0PgLGI|)2KC%*wRcO11zv3hJH?V*0u4adZ5O_lIAfx zvqVe;#HF$MY)-Ao)x_0IsN2jJb_Dx7mBMne2G~)~?qh$Tmc*fRL3hr)0gRjvJn3kt z9h;rGcol*O)1<8rF<72+rxlEs65Xw{DkMGAXnV6}7;l2AdDg0;^JbXqY5 zzjY$3uNYzQ6CerlgaN>7*L7^uZf2c@aNw+F3GZ=m(pTNljRkB|V*m_fF$yzM$xe86r=kcUB!$FeQLg1d{Lak}CDprK zB@WK@<#|d@j1kz`C}oWGciZZs#D>)K+n9d$==s1DvFIRMO7xpN0ib>-PBL+{P*Kx} zzBkZ;+$R%j9aVG}QPo8u4A-_;Y%%nGImMstLC@uY@~*PzF!^JaLz|Z-GCTN`)_Jb} zQh_A5fv9q37~N~ykLJx4BE&;YQFe0$?3O3*tW2S(nAsF!FML~OGA2?rdK*f?oZeDn 
z&N9wb1?hAaK|o^9le23TDHBq=GKg3BmF7tL>FEPA#EM#GeO&XzSJbbqOfvPr~47iBDF>zE$#B|g^4 zDQwEAceRs+`roiv>s3aJ$B@^@UIFD@Q_GIB4Qa^PO#geNM#PRlM%vCgVrlsyCL%;L z>|3E|f(~VpCLFQKp!sZ1h}m&doroM~=}`TK!HwER`mSw=WbkM@6)LN{1K+oN6GAJv zEqLGPk)6tI=(zOVyU46w?zPMiD!So_hS!rVB9<*g=%$lrh<23pA&;?KgC1-3PFqjN z8e57*!^^<)C6_QbPEXRsw36^DA4jbrV{y5ShTuvCd>FVkL z)5&MFNq;oLbO2CBoqZxo{fbAvaJO6x4V^hfotxv0l@H7`98g-t4&b=`Ahd?RC8^JN zsE@q3G>w*ux(gqrFjDkRHX1zip%)k5-9yKM{wTYD0`R#fU+?VF6w*5@vU$Iv-*~;a z?GtgVX|^#$0@E92U43&q5<9(=(0NKT$FC~+H9(Rx{ExavYn1m z5sXrod@`9zJ;UQ1MpT1!`H44+;T6rb9b>E3-aDWm6+7LbJ`!P}OPTq#h?~g|J)?-( z&cZRDPAh)wk5yUuEla%u5!*sn0oyIwH}(GK_mR$BwF7^j-sI8?!AphFk0~0iNYVcl z3~l6X_Q$$_q&&dexpFgke#aCgq-~{DYeRg?xK!-2fWM%KXXB!(STq%U)r!K9$$ewA zGoJxMM#Ht@$~kAUr85He*!iLOhMB7;uNjZZYdHI98ax)Tsj3o-R0(eBxj&s?jBJw& zTg4klsmwM~#qkTk*v03mZc%G_3;`-50V4`V=A)=Kd*bqrSU()7WXoHRWUAQp(91s{ zt$t2dHSaxo(5gugY%^hdEM+EnKj8E=%nF4q#h`$4y1}*ZMT;*%{USSvRlD1;8a)^T z%f*~5Cl#2B(l#@RGbBN}vpL64Z!=#s{Ebg$pVjR&O)xL$rN8sCgMZVGXD2h3jnOi> zUX1fxJ&hps*X}IPI9mW*tNmvxINn7`zIb&!kZD15`%56`Gv9CbBdlt%mM|_5yIMm2Ml4g%p;)z!v@P_<}?&C82Jt$tL+RoILk4%n}<0waSLq}d0 z=PLYEqUb)GW|Kf!n9+`m##q$;CFO|}hJC;XHY+v-$4y$^a(*`Z)gh zs5_3KU${vpQML7@ZA}o$ErV57D~mWcW`5ugMz~KP{rV0@LrH0N3m0*Jvx7VcK;l*y zsH}*NNA}_K4E`W56CmpSu%>+Bia~BTMseF1L<+wzG(z>8qHzPAj%uF(>DuFg`GTQz z3p3MeEDJxK^t^o5Yk$OrO%xdW$eAlxbXTVHo zS}{6J-x=(-mmU;ZDWa{7WNuPn_&!P-%kI~(EBI=)-R-u-{cUVfVMY=PyK~2E!qJ04 zho8ueHfB{-EQ}ePa(w1-H_>&?1BbJ2s4$BC$U0T!ag;YNQ*B(!0;1j!Yr9~cT7O9B z*AtVSO;a(o&1u+UdGxE<^0lcep70NENRw21Zw7mFy+|3dHv@jbp#=>5l>oSZ&mnM+q2jmD0?#R|+*( zL(*2ceNmf~DTB@+uYi2PhiFbwRvXIeIYIT$RZnT^R9lFl*(yPjBJ zYfOL_GQ)(GQY=_t14!v|(6u~9BfGY1aqp$>Ca74+=;@K)FcwdJpqdYEkA z&T3B^erW0!3cY#@N|Rf-E^jNYg4R%fIx2NxEz+eU4`yU1wprJXtC=l)V$K9(3TjFp zDoL{*)1SZd&VTqI^?YPve7G^`_6*PK^V{e_-I)<<7R0ELtj1NrV2B*sm^QnD(ZJYe z|G6JFvTt{g3^5|+5?j2zrG39*sb5~NQVy6xxorDNFz}BpJ zbXY`}#w&z3vOdq&eFDr07zB9~KwzH{ZV zyDfc9O}OhMbv`SS%_Puyj+zC_iMO}jhc}Z_cgzqB+ALu z9uwY8%}og*TdZL zCI9n0l9)SDBstPxFU`^5WTg9=vLycy@&^qKJwY|+5TZVr*EhGgHOB@NR@e?#Srz# 
z4L0sKq2$alq!QsinVoj08&SYSuL;P2Q|iO(*9!;^a2pR2 zyQz})*DabV)w$87Ae0;XTiaGxK27PriGNO3z#Bibk7`SA8sxNol?OTI-Bl;U+2r^pT88CYq`^Qm-NOkSS!F+M7`xu>i(y~E2;0BsQ&H;csePClB!{!@?p_tgNzLibWUAv8h{WaG|Hk_Y?23=59y}jv^(^XHXn)Q>;(wZUIY)kw zyO_nLV&8rUB2N=j(Q9Cvp^jjf}0Hs-`x7)|)*rYL%f^ z#u-rSUs&4tzK@k?%qyI76E~)EDplIMLn}O()zp=WeSg_*IgQ61YZF)yof4thPM4}9 zPxcB>@uED1X`6^Bap!rb@JW#+h&_Up0~FPP`1*UZ8?4LFJJkLy7oK6!6COg=12@@JfsC|< zh`nPxF~;&oQeS211u|b>XBv`|(`p;qPu}lOnuUzT;>cs`5R9ex5<8m|75x0|ydRdY zK&yDqevKTnqV5_MnsktY{mE)9|H0>qd$5ufH~Q++Q&0(fvmD=P6vSVsIZ;uHJ)7fN zX^wAGf+|;Ha_eRa#M_-DD;P>%wHa_vz|a{6-Cea~yckG7Tp3Cn%aA)r`8yM47fz!&K1Lqg2Y3!nSXVZB`k2bub>1XbhZ0ac3=QVLt<$Dk{RRpu?{)W6X z8T6F8i8f~n$C`9uoHtRlXqB0QFf>Rv(-OZvrsH(r2J?L1g}W!k#pH}vq1U-ur6o(L z%+(4Fc3Epg`i(C$i!EOz*(`6500GZilgv&;7)a^s4Q)&@Ql)sGTLox=_mgo4hnuM! z1I;DgLq|sSy@+F0xz(cC&k^vOmUci8{9BKqBFBoQJ&;zo>abX zf6x{`@{o>x>|$!Go4mSC!|2E-Ak-U@9_1J>XYRpgQhTSwup)TLjgpjJ*gF!yz*k> z5f7Wnp9aR2e<(keJK9?9GiG*i9WBNBv!4n~{<@gZy@%tnrZf?#`b57>LQJ-uBUyav zhE9H7u!Ih{7E^ZA-a8-z2O=WXbL<^4Z2R<~aFg*SAt`emLtyGLjkT)N9 zr?pJ~s?=d}J4-wrCQe-@&2!L@S=O?@U!fknKt9eSkB{9pr@l6>mbkQu_Nb337f-Hk^fotY^vS z;j)gSJ1|P9>sigXNG?EZK}MJD^t{Ah7R}?yiA9}p9F@m*ki4%I`*2>0(bBekd$yVG zLra8?Gy4a9^ZAC%y4udy;l(2ia*Mn9#ZUL+8j@qNFY|c>OX>nI3Pn1Iz8RYtNj>@( zrC+vUweo{J9M)_pC!$^qe37ZMqpNaTYt1y-5JS1|oY-pIIeE4J06#Bp>0Ri^DnU(y zzWv{Z!IN)lef)6uKF!|(mJth8v5#QT$U-_q4Bov03ST;RD7ZkVQM`kBHPx|!{1p~o zB*LcPy;&Wa8n*4nMd?+J{SAkGA+kbejxWdD63lOg2pEZva&8DiSAZhukHl4Wjx zZs&>unUy^3@Xrigloc$t+=KQWhTt~-I`mak+S|H^vCRJ(p3K=)DVXCAt&oHvrKVoc zZkj?AO!Ctbc?C#}j&zGeaQ6*0Q6@Uu5{hj);3*bc2C9Kl`goGO zXx{gaatit`N;nKjpT?pv79Pv!s8H;JhcN~dw%K_W?{dUDO( zv@~W+u*q_x&w^b(tvU5(C*kTl%OZ$O6SUaj?YQN+^+LoEW8SiYI@VK@I(_`h=3uK+gO*7%`FWilFWVFs+{`of$%NG$~;a^|Vre}PRk z!#ILVf8Jk{@p`>y{mpmN*^on1R6;&h7ZFKJxtRZb;c zkbx+J;_lVTV5dU~v`yCq)bNsfjG6-0WU=Qc#9|z05~Ty9XnC4P8Vtj;lw*HwAzr3` zUsB-PoQOF>W4C(Z^r&N4_!NDn60SJMXJqYkKl4`|eD>%vmR%C^+5I%+qxG`R z3ch-t!>o~Ibi1_GFPYA_>9;eW=#Oh2V(Y$NK7&pAFN%*uiZ3yvceK8|5P<8?jSD(3#T}44`BQ85l(r#8`G1|En$*h5sFzBh_a4#PERcUKCfpl 
z*Jo$&Ow7;lS5lFGn&xENyjqO9Gc+3x<}n$#p@ovC@@@CDS;!;uayZh*zck<`!Tt_ClgVxTZB_93_dgFV&#NIJ7_oG;G+dzts zM#`SIS(INZMp(=`@>MYfu5&S!R0RG zy~+_7dd|-suHW8d>C;nV#!d}W_N3_CrH5%s`v_0VJJuzM*OD??V2#X9NnDThXf2^{ zAnc#2{tQ(uFOL@eBQFZdcTE(Y7wghsSF-_?c4y3bYJE%j8U zAaq%3;vY{%2JPa?wK{;4273}UxMEqbk@$QaWI?0v>A$={Ah~n`px?>)gSi&Jw9GMK zHt6c?48EPmKrkA136FimN9BL^yaF4gu%74kuB;&cAI3{W<^iD{2;^gfmh!?0 zD9|Dn8_8gS+wEM)efq6HNC=UW1~ai+AIFBSHGA|QICI%CDtr88o`#hRq*CtTV2K}3 zM@>@)=Bf9K!Jj+FW1OuVW$u%uOK@OE)ka3w$L+m@+fLERz?Yc&pXF>zsFT0`o{!+k zkSUWPDwnb6nIPZldI`i>OA`dIy|HwHMCtG|em>=bb|1 z`0c%##WW%Qs(9>U*xu1AP^_4=cKn6$vm0Z_4D|kmmsHwv;AhXn9CoK+rJS6GZ|F;^ zE=#{bOQ*+*R#6Z!Bc0f4zM(CJ{}iSyfGZ-Z#W7f#N8Qh%rg~en3F&Tym1Y6%-jg7$ zy=8c0o%A4cb8sj}h#1HTJyHSXndp}|P(8`uxled3(>Z9lv4&eij~&dx+sN)RW8Y2< zg}-_3-4=nL8vGlEMFYvN>E?q!``Cw|9zZn9w@?)weW?_R-f7i#iV@oz2+CoEw+58U z;SuW9)|JTUh$sX*qg2PEcsKi%+NWJz(j-S-7yZak8L1;s)%Y?dK#`KC*{8Js^{p<< z3lZGZn!pRhP(>U_k~m7%OInbl9}4g7t!5F7qN8qI_4VZ_dP2-03Mh;75>H~Tt;bP& zUk8VHtUijOL$Xe;Ej5-;;?Vu(`_7l8)vX(_B#BJU;YR1D6qu>u5EO)ztDV3*f}KFT zou6(Y90n(ni;nJJf4DY&njweF;l4KigP66Q_imYgG4v#B@mJL$v1ED@;nl>bkE=7! 
zryO4_h25PyZ`4Azcy=3FH)EkL?<_ZRFou^hzAm;)?g9Qu^D$|R+?z83 zk%>IRo)iLkZQHLFnTqNDfdT6HNrLjL#T83X9Ve^j*`A7!na3z|$vQRcM0L?G8wRd+ zaZY9*O)R#s48#@L!BlTNUwx;GoPyAPuavUCR9Ln2-fNJaG4?Edo_T_6>X2ti{rR7ZU;#g{ zQh<4yiCG6+yOGA~k&@YcJ_$d+*zf1)G!`QUP~0x`>Nv;hVQFW@ztuD*8;%bDsRIl5 z0GV$@1dOx63^C7lXT-BFc9GIE;M*Vk>zE}mJ>_*-c83S5CGKJing zrluss)-+f!q=ZON9A+O_+AxEXK*YbAnj)VSuu3strrT5w0uHnBvYB~$N1K`|4ht5+ zZJB-Lk6B8gWPeAjtz`)7JX@SjULfD{?=7!tYKkN^K3MeX!-po$M_KqD+(jI+va&>H zmlHYj^g`Hb{F}d?>FaCps`K!VblZGAxm(b2kdFU)q_3|crT}vTFBOb!X6Y*~_6Pi1 zUtj#KlF*y=parMIGM~t{82GiyMM(eT)UljJz+J8EhwUV&0P(Sp9$ GmHz>irf1Ip diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg deleted file mode 100644 index 16dad8dcbf18374fbb920fa4ad944e7e9aef8b89..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 49606 zcmdpdWmp{B)+P{w1%g9xch}$q4>T@KH*QVSH0~ZD2@b&>0t9c|JrLa8Aq4l}1Omx) za?U;9z2AKEW1jgp`+=@qYp=bm-o2`7Rn@QgUu!4?YRan0D5%KBjmnLJ^6M)qtCA1Q z76nCBl@kRE1qI~+%41YClx1X35cyAu`Va*bxqXcMq58W|_M3og+RknsZk8|)dS_dD zcTa0;h`YP3rvp9I&DqoCH_gh^9p;22(z`&MT^u0vmLBwu&NdJ?C*+_l%o;-P;Oq>C zIYEC9SX(+m+$?$N|6Qh?rHhO2KT;eaPV|;eHuUb69x!)X7{unE5pU<;W9}Zv(Ld7M z=`HC!AeN52^ctQH9x#Z9ofFLZzs&mE8h1yS!yk5}%>UgGZlsUD-Tc?~*H;V;I}Z;R z2|hk2cV0^yXDbM=wX-9ikEIKr053luij=I6i={Oj;z4f(fx?`m8Nao3Fw(ct31t5_ znU7Hh2V2U`##TaCQTcBk~^ndGBfp~cMsybRiA?_-w z%KHCEK}zL?{vLz<8T%7(D9qO5KlIyJ|HJR%>E`f9tBo}u!~udtg@-$mR{*K%ujsls zc)I_?jr7ga9rBy>XO4r7Egw?h-?IOn@Y@e1A7qIB4)_1)7m|Tbke^?ShhLCKQ2ZZd zNMX7#sGY}uX#C%iT! 
zq5^^<0%F3y{rp?0gdxOA&%@H=PZa*L-yI3w-*|z!dpKLe|ET*DQU?gs(t*zbxgwq1 zAv~5)H^`q!vYsA}=8i~9JRPNx__9PUB34UB7s=mbcV|yGYlt+|*%|5p`G<;J#4eUF zsFU>XFd>fx(kEneB@`_^EHz+G$i?_pt zoNe7Ko&2QuenTBe`K$AXhwtBKBKtP~wJ`n?Am`xmFG;_lAmL#7+j<`!8;GqXGN~}i z{ClGM|1J6bw}}0T#qWsydz5gb{v2ENKdAX1;{PYVxLCTmL-c)JAku8QkiRg)Ci9;Y z4RqCIelM>76FC3C^dE3S9sdjLJpb(^!v8oh4Hd0cWXT!GDlIF>%PR*oO5m^*|sq zYKp2zlFNGi$ubCzqg~AD!?E5IMxO@3rT5E`Qndo;7seQdN~t2HhSJAjF~{MPl`O~g zA=ncd3b-q?9Nh~0!r3A(e!^ZtA~JY4{r z;2{q5tb%d*%k=>(0Fml`a9<2_MHtfHoQS+Gdo+BES;MB&LQ6%~2x~Ijk>0e#D0zB^ z!4f&4LKdsf=+#{q2GN7HO?S8Ibq^s9lY~jCm>k^gHB5qo-! zr0YuZz7Xze)%C*H58rV-FH`k^76|r7g>y0p!=F@QXi3|Zk`eVd?1WN`*5ymT!-ho8 z4^Al)j>h$F8_hG;!?1UB{klEV_8)gXH2v&1z(fAlpp@P9rk#>izudvZA-*sCJ@0cR zr(71JAjBTe@-^7<$w1xa!02r6 zTJGa9%C7J8heNAenH|C!s%jK|eU7%7Z@42g$T}A~cNCz{yd`>PEd&Z#W|Mr1m7X)3 zn$=^|j28LPbLSMZ%bSfE#$EiZ_E5H~7qHkG-Kj@d4b3yjO2k=Qz;TI&umB0=ChOK# zqaX4g_iGeA7A?=^YUmrk$}{RLw;irt5qDnTFi-;`$}-42l|2ztX8uombN3hw^y?CI zdUbv5;Tn!a=DMHU)$$c*84Jrwn2Okg>b!$P2jTculUwJjnIY9z-0n?t@3y8ZXP>Yn z39-oAeWi-JT~;I)@`Ot%H2s5U2e>S$yo+KI807)l_OXL?on&m@QeO6d!V^ zz3E4>51F=voJZ+HbM=hft6XG0}7o2pks6$N{uV6$0iXtT=(Ypb(L72>LggOcnd zz}amWAo6~i$My3VU_{7$+&P`26JSiEfzv;-lDJ3!ASZeL@IgY2ve6(s1z^C43!7M_ zP5ykq?M&9R(=hCf0=GAHLTXO;#{?!XOy?X{^|F7zUKp)zh+NHxu&wiSSTr+Cw%n*s ztB(AZalpFQ7d(=CK?nlS<_yb<_U1OV821p&vbU}(rcY2+&v75bNEUB8^%XysPU};t zi`(}P=(7sF<;dJP-&M2fj2x(Gx0xHFSkAVdv8=f za1rf-Crw0bUHi|EW_<54S2bE)rG4hWJej;sOH$c<&#K0oWpUsVv*PV08Bmi8!EWeI z9b|k=#XDp9z&xN^+&%G$_>^Fa<~)FRa-G{dZPCDz@?yxoHa=}#O zuZndbp*jgqmmGdJ_iAe;5EA*#n})ynY~meJ+GbSqO*4(}^19E}#KsqOYw*}AW_uu6 zRTpn|_Dor(+2@O$bf~~oGGZaRElVXiPa0+l^vIXSXzD^}ob2Xv#Ag;YgZ9f$B*shf zQf6%{lRPO$MyN0J0_t50cuA8NN5s7@7M?~)#x20f6CKaTwkEibJw?SYy#r+RCJ>9&HZK9|dQ$$U7R$B1uBNOGVFui`oo_T+R#!KzU)K*gUIM@OW6(q*? 
zc5tMl5iS>P%3mY^RK1q<92 zUbr-WHtpKRa+1Brc4WDrw8rCBaoWHOU+=QCdvTVr(w3b!5y18vXLy!P>ouPe>(b18 zVCWyR-}#GzxLx(4(k%#1d|LPNY8QT-5yCVXAA%@T zqXRpGnQ=@C1Zz4($|vGg0@h7%oerBv3%;-RjPhPeI@-+wRL8|_J+{;R7N%&#r>J(s zZncVw^>@m+hTD9tohfcszTW~`?dRKJt~m$PF;3EuOuSy<)_RGSv6>^hgA6&5Ec1$N zk^|TF7X|D>P@SjoI~QqNBap@l^__tswfd^OXawi&&J?RzjotW&VREihF=ChW`Oc*0 z&?CyT9x5Y$d0VQdnT^vsZ9HeQb5#348DD=xu%Dfwgvo6~yTTXdg3FqMX7~69BL{5C z@2U&&4ldTrGQRI0PGz3o-j~_Ixz(q#;`hcfRz;osR;)1ujIsh3MlZ%^?Yz#nmt7|> zz^nLXZoW@>W_!=S`{Ch&v*&mZVmNqped6V1-a@4_pYouKUx|8|P9De%Gl?gQ-`#RR zOTB5#58{Wdw%3-wov$9>P`uu^K6cM&Q*%y#R1qDyqw06}ZcfD4W2QIlU=YqWvfu6} zF~AGhJzo)T@@s=;?948=QA!LyJC~cvQogzJ?~G^Ef$a}(?7G3B9f<|LVD04Ua(r;L z*~bH5YYSUn8v@@aq0l){=(J*irY-82ErAyrAD`wknbn^ozR8va zz<&PpkaX4Kkn+=uSxAhsflGUox?b=_M=i(lwHlku3RSfR90l&^(Fo%)p>oU1&MBOE7c(~M z!q^Fl8{vbv9_c&?cRyqH3i?Q%O`TP1N5qcm>4mx)TGmllgOsy6Q19`+2%?fD%EsEO zfR#EC60wkaR;+`uVPTsvRX|WN2WLZ zER;tg$xRMn*hV>n!o};ITs=cQ&)Ib>v3z*EslBjvtX^N?3s~$2LY~W1W`7I{mIUTF`amd@x%Q7 z@SVK0?Ew;Y1>Z0a!YDutFYp@&OO0Qt7YidO{^hKbNu7xqRgJDGMEG?BQ9}F-TG}YE z#$@-(0cy%7^GbljIK-Q$3HsY@n*S}p()7on=yi0S}-xQbVL4kZe zOKE?d^?^vjmEF*++d@*lsu1_)tyQ9`L{Fp0mTQTU3g;^YeL^%|>b7TMAOl=K;Avr% zn$czgp`DZNyoM4&vnPLUBxq>n{oWWHrgn-J3jGW|7p%{e{WF{ztAe zWDdUU7S6;R#ZCS)#?{NG@|xroB8RnY0+znM%LiRePxCQ?$3M$|B&W?GAyzyG7Y}{d zsza0|H=W(GS0_z3yJmt<>ZB%~?+{;RW=M6?2(66~;37FB_4&P{cDrYd& z3l=jFz0!n(>rCv9eLaGE*2((VrYI-}to@e4Y&W6Wd=E<~N#wC#5^1}Ia&v~-*4v*? zEwBL&SSmTE(aREe0E13}UW8@mB1!VIB42n|ys?WXM?6w>)FbA`&T(l=o{vP7hz=k=Q#a4o5@7;`4mXaUKr^XuV!6o=k8{oGi@GaF1_s6Le$6*wHFDq;O`G&ET%U{*efPZ3WS zWvJ&?=7XmzE~1b7iBYS2NI7V#b!V#QLnY@)najPnddr|56D3hNIi3bO^6przn27T| zgV;xrv}uCc;(h%%3>;j>ef^*jJe%BD;LU)BhP~y2NV8|Wa$JvAM!)=}MGj{VdiAjf zY3n{yEJEH86P^JuNn8Cv6*mr6vUwo(IODiEm}p0>6-_6~F!&vF;;E>dZk29aG$ZAp z{e&?Yh7V)PiTt~_qXpF{;)Tvk*^b;7q$$#^2hLPJdmW!AV5)Atr=3^l^Gyt|rxQt{ zhYKp$=+C50jjD@>ldJMtN4L3&bifZ|vJOP2Cxp4!NWpN(&$V*4di0Z?^x;yx-C25{{V4vB5YA)QFE+P^Ol&<+AW;rxWQAi!> zV`9-Bd?GuyV;e(XK)^QyY&qOiV8M4xwk@1n8suVeHd&`zZT&+Eth4@A{!Iy5bT?T! 
zL4j3lzETx2LBhUa&gjRuv)Ph3x5ZJsm{Q*o&gU>)$|7zPyEs*5!e>AfG9CzzFU7Kz zLDN$1rTeU;m7dmszpzNUnvl)T-OnEt?+|?lJAAa?MzkGxdSy!o^q81*UMGKeX5Bgg z_gz5aOV2~BR&7X2dEq(P4m{i{llR#<-1bannNBkqVb)?5u1iY#Do@L8R7ygmWOXz1 z+TO=nXL$VdFwpH`tbW2QkuetwV93)`$aRrTA;xR%_y$A&#;fx9rma1|H0h0!WtxDS zUF>YRSH{JwSH00W@Qi*hx#f|)z8LbPI_wSJ#v_Xk(UgX;+yilrkL2NDd;$m# zs6ys?=YH-rQ{Qv!zM3MAT;4ba;p`XErEqNUwCl;r zQ`=0p)v{!#r(jw$RyjHCs<_k}gd7*JVz`cSC&rKp9 z*y7YbT?Y^=DvR)9%fRL`OL`EsHLYD5K(5b~@7{Uh3XGLk(8(cK;tk^>2uHuA-QRt8 zg-bvtMq`xGai<2PIgK9sbQ5DWM4H;r2y9;ZyzYx2rH}6mV<*m5I62JkZJkpf(yn45 z)9XvxHR4HD&@UYFvs#QZmWYr>ERKQXeJ-hUs0)PJcxjU$g1VFQL!$XkKAIU%R*f}! zT$MNuo8$|k)$Hise|rBqj9o3?*7hb|_hTaSsa2KkjP(o$IFP+{<>NxyuGQ9aX386> zdO%1n7!^?+&{xh*)TJhwZh$SXbK~*)?XGT3Uwwfj*uygp%%5rP@wIHMo=+1@<3rY) z?=IoM;bu|#tn(p4(cbP2DZ9Jzjzm^Xj}^8wUubprOjt=zO}>vuYm1O6HR$x=vmqmT z{YE$*w4^V7d!-3eqRxck`6W46w7QSB2FRn$lPECdp6!)iLJl^rAkcuF9qg}5-r!9U z_`I%;q2jF(*^J{ByfnGx<~8E|D1L8YJc2lV{;|GjTo z#{E(1nTQ!NO&+GgzA0-ta4hqB1MOz~v}6d*R}7Ear%RTfz^7%`j@vqndwzafyBO4z zl_B`Ug|%k0h}%g>wRJfyuAo4qAv1}d!r=6+dv@#M&;e7@M;3wNBs6HlY@HQ56BbcDlcqA-d5W`ui44`9!cgXREB#XI;0~KZ}JNTS^4HBKtV%8eeeJs z4Gmcy|6Tt^zDV$Z5S@UCm_gtnJ)@wI90`-KJSj8(Gh{Iw3t1vZeTZiG>eZLp;Vq%x ztyzqy7xNbi>7Skd`bpRl*VOy-lk{!N;++)eX?M3PI)T%~9X_jUX7o%Tt1IbFYE0V6Eos;^`>+x~g#sSs zrUEUioE4tsei`K^a4rhM;o*IuwQ+q|x>wKkgrB&mq|{CLType8Q;3*2OdtUliW#3tblKf%s(xSB(5ETVm&s~^_aI%06x@8oK znKuAE8I4m1-CzCqA@)N^T(dalJP}uE;pY%~O;>28BQ%=$&7)BI;6Rnhj;d4F)YLb_ zM_L<8dfxW64&!#GRLX}hHYn-TVLkgbD%K)C!O*j~Z+-PCwj;i$dY3`(AeKd`usc!`++lTqTl;At1!p30} zLLVk-H&Imu2#P&e{4^wZ@O)N|*I&ePI`sVZsP!vHkdDRT#@nBYKNSl`wdF8wJ4&u1 z7qYw2GYJHWA+|?tbm3+Yrp7SL8oNI^wwGCS2TJ1e^N)N*uA|MqKa|0D&Pix$Vz9WJ*)9;@ zj0*?~8J>(>DAJi6CChwjz$IT$uRzoFZbrdLOW$HaDpk+Lv(AMZJgoU4G6Z4Q!6o$9t-S?6^z2=tsqomo_4F3$F7MF z9Uv7`Bvd?R_$HC06+1_oY>9g8OP{GpoB{!ye9{zp^Oftk9|+VQ`yx!7Ry)i4xX4?M zGorra>+=TPtJwZe-D-#SBMvl*bLQhyd+>duDs`s!(=!uoLC>gTG#lb#6B{+v8140G z9TLkctICI8`ol{fdE}k)e=izlqYNMRM)fP^tO6Fa;8w)$W8R{F4>LyeFUYbqJM;`1 zPXP_hQb(;iTWIl&qT7fTj0&xLclCKm 
zzmPSZdDzzyeIW%{FEF0BkLLOJ!h{p_%SU1}&qxhTFXI@83^@vap)ibcA>11i5{#_b zdNN6i>nP#7T?M9rLU+YhY2{#R6uVj%FV4$(EE#FlcdBX5T>&OSVqe9rETj;XsOe+| zgFz&f=XDH5JKf#uGvUeE>iQiu>}U9c-BlmRu?wXht$gAwn0>QtZ~kV$;tL7jwVEfk z4X)s4&i*g`z^>W5JPpuJK4w&=Md@9_h+l$Q+KH7ts;*yJlorICV-0Zd5K%UdgXppM zlQ!7T*-W?-smYnT#NWV5M1}Kp?*Vxj;Vm!^PmA67=w|p-b6-GvrV(fuk1M3Eae5e> zGV*jyhG^?2WcnigW$+%Ca(0^$vEnF0N0XDHjdCq%p3b!4k5tul7At-e{Z#+r9bzpv z!i(whXH{~*9oVGJgeLDSN1C@qKqFg}*Qu!a0^gIlT(??*#J$^bJ0o+g=lU_9p9d#Z zZp9D!fQxjvENY6l#b{(Vm;_5^i4t+y&K|aZFh?t-z72oohM@F|7bpc3BPMbm5F1{o z+(<;u4C0sbdp;V|^j8q>2-Yh);&`#G-4B5Z5vrSgdJHI;ynyZzP^hTTNOL4l!zF|V zEJG4}Gs2UGr-~6v{z0~CU2gu>9~>GDu`jjT`P$`TIRzB+rvjR(@P8Yuepf~ zOST;#O8&N~VtS=N*b+{vClcx!OJ}VcFxtC40#JQlg^TR|}28)5o4d+jwyC;?k8aM=Vphus}j1 z(MDZumh{Pqi1|7ntptldHT9UDSfPTsE~!YhTA%X@tk(}IDPwP0L-y-06bE$;ye7jU z5Y9ToViN1UBje!FWa?{d*hoR5@9KSgVG^n?aesp?W(MiA4%#^SuLC}@Jg3u~uu4x7 zccW5Q3K9*^+-7!74VAZOK|4iXriy6Q>;^Q*Ie?jPif+vGg)@W1?*v0*SA`D*8JYhPs1iukA~Ww>%BC>tpRt{6-XOW;49z%%_{{cz4lQ zK77S;4UL3*1Oe>NXdgB7sdaE2Q!^OAY{QIrQ;H|q$&G?gLX0VD@<>jGM{hIjB?O1S0V~^mYZM|i8iD!-r{59h?wP;JQ8^;!-7wS*}#6~?nILsUUWqG zhF}Pv?S77U)6cY(3TDiLUecKKSTynz<0!*hyHcUhVd$|>T7@3UL`8K|;S5Wj;S2f5 zG0jp3b&o8HyoSo?M!r?MlhWd;(*P&`D7+U7ws}#7I;OTmWitzfA9XZ_?pyuezIqn^ z38YuxAPzrMU?qz@>;C91WWp*D>+`h^BnleU@|k_}Dk3+94t0s&9uBN9&;vZ<7^`{k@8)e(fMyN{+0MEayeR*EH=r@7n4izUp0Q{!HD*R#;TU5kF7aQ3C+h zMr*u`vtF;H5o#=?-CByc=(M$s`kJy4dBa@x`JID`)wm%JHrP@CgFbvJE=Hr=Bg-wxllC6cmFrDX9a%s%Dg6NbfYtVCkIj<>{9cA4GiP6zMH zMzGSj1WOXrayt2)qKr}hUXpFUBQ$OL@Ue1V^S(d0 zprW{lKw$;FcT`5gxktpvHSmS*7h=#RexDumC7f!*PjNyTuR_?&svWk-;s zX^25ll;^gc<36l^-K%ClqKCVQQ)?Sm&^mb>)J%Im7-(Dc3uTcN-<}V00*$f$ymLQH z++Vt7aOKG>bYV+R69VzVFf{454ug6`u^DM4d~b&dJv7`hy{{u`KfkYk!0%&DS9Rzk zbJuRBt{{GK0jNre!c0j#^jBEa_azOs-zj+!wK=9y${eWN?e7|^+Pmnzqv2XV6KUN3 z(MyKwc1%ka*Vf9`;H<4BeXh`7VI3x^zrmQ@yAvP38c*fS?-g8|yJ6spD_IMgy3ftJ z8!+olw=gJn}`i|Dlh;8|(?6f`aq~XmWv@l5N zlzMGX#raDSE8QjT2dm=C;LcN>Z_t#?Zs}4w^Z7$>JAZ>8zfi1Y=dTj4i6e!i$8GT< zZIfx$U$eh-$ZII2J9y{k`wOL5NT_jHg~EuAL&2pnNd2hv&bnKM_fF_q7?ZMs>Pfq- 
z-(u5L!w)`7Zn5G3wNaYZtgecW&V{!>f1wy;9ByMP<@DsIqLqLSOY2K+*2hy*FDOphP2?))J9jQ!(%^#4dTC?9i{{%M#b>MJjVFp) zWptbP-bIBp>ZRvrOW-97*beEKOzy)-L`6cpi!gXKO8@lL1#Y11zh-V+A z4n9cP)-Kq6{QS$OITJlCS?k)1suLOVfj15Of3Fz%i76k`QX)cr-Um}{a>D?nhXUN;K>-~gK zzhfm!!h}y=RPdvHdy6c#{25S3nW5{W%Ci7D zY4sPXiFC)Nb zGE(iJ#PFqt^nO)g9Oyw{8>i|3SNK|1?1g|gf%@T7SJG3fYqHdm+eyo?tbkdg<-LV~ zPR#gq?0uSx!V&2euNi@#DtlqU{b3!)i?8WVe-ICI&5tcl;b-(&;M9M*;*olMC%4&p zgXp{bB)QnzT&po#IQ=E`LA>8<7sod0!>6>JE#Gkz48E8DupjFKEQ1%&uFzbf#TGc{ z<*sBzXCn88^?Esuh8}3k)+&1vP~c*7jG?UXJz=J(9LNX8+@a;`5O3p6KQIy2tIbsQ>z~hB~4S0 zszqeg+CRY)V$k}&=1}r>QbU8QG_rx5o2^Oc!yf$35MPLiBcO7&sJ_bdVkp*A!f1C&|3xh2d@`*`wGf|q zXfR~CH%$k*AX^WJ$##DzoHatUQj3n1S5sOmuCdEBDc_tHWEM2Fb5Z(oZ8)gC#DDSF z>VZ%VH3s*1REiaDX9~9D&vRMN1DJN8d@H!3=A}RP6T$M!W(|?d@;ebe!9!gGSI{!G zfW>8Q?cN+e5yK!9YL>oZiP61oLs9Sxh4zcY=mF>}=q!Y(GOeLU(og$mSxlvtzKKy= zzIDQ0ScwUzygw9K_(2UGSWZn*XM7^n5xYvr8I8)FkF`Zt1cBS4ojIXAX6W}h}l_@R8j?fW@VsubHIozdje@!0@$568O1@hMqjj=V(7D6hD0u zU*Zp}bVc8ic<08yP7xF0wi3xD?rJnA+{Ob^wqd>8nUSOHD}TYbH(T2U8X>k#_czNE z;BY5;Sp&Aq)6;C_{)H0vJW$={AXP_N`CZh00vEo))YuK4sO%P3M)&py8{xI`Rt<;3 zK%yOOnY4ysG@W83r;%gAe+M@zHwL`_vJzKDcnOC@u)sIT5XaA9g78JxX|JN;8e@0oCoy z3z`ty7VR-1!OFFax1~aFuu%vDz~(yo@0AxHcANLPd-8g%y-wBc$~q@p%8XZTy#(yc z2MiD0!R@!qT)Bco&F=?PJ7W^1DC<87p-}UObK4ac3zRnEr%*cx=%wxzWxla~^2rQC zQpZpXL$rIZnDgbQm+DR!^rxvQY^9OA*Bf7`f;xu?N>Ni&L;bNKrq3lEvhp>v>Mu1v zo?`Fw6F8PMmj6uaSbX8;8=V3m;ww-7CJBSm~!7JrP1_sd!bx?t2ISS}`%&kmXBi#`=#mYKSSMkQ4AVC`?TjNIXEGwTghMR?-RW1f93y^02p&g$dO(je4` zCSl|zIFyoI#9A_5pi`p`_lZ)1=p~V_#N5Ga*7uoGKm72=C6?)|%x>x95t|PzwDKNi z5^7d*_~d0U3?4&$LfJK~bs3!aaH;Mn;*!iNv)0e)eXw<9WRSMj^bC0L)3&GjzO9Ma z!nbaj4K6Wg&Bu+_@xWgwr-WXnF|J=bTtfnDgoLWHb)(37wxdFEj$S2pYrhX~vM(=B zOA%^`2bYg9jmTtcQD{{nt#n5x8tTko({xEmOJE^iwY-iI%~H=c#a#A(;g6u&VOe@j zd-6o8VV33{t5^3|G%G>1%O~Lr5d+!z^RS6R&AQ0?ucZf+oQkx0hEx_u80D8?`#^5m zlJ|a@7%%!^(DRE*zk)P;oheDr+HKE^bDi>HSl;a(S=fmBk)*uB z%*JA`+si}xW#3BMMt?b>2A)J)R^|C6!k6KujA<@H5_eYn_Gm|1+J>tzWbVpsFO$5U zzsl08Q+`g9RlMxabGLdo+?gW%Qh~@$>k8j7F7pimlBn};A#)q*)0EAQl0qDs0i_5F 
z?Xt3R$$+T*uoIH5gMHRJbSCa-&=hi6h;#j-bJLbO|2_5P58#t_Vc7(KGqz0vIn*oT z@2v{>wKmTzuA$44B!fr$fC!Zss#@A3`kSaz0kIK_o|d2n{zwPujy&StkaNDewY&Fn z7{5?jWC^p2Y_m1)!bMF+&yfXqcH}$i$O1e%^4)ZlzY6fk3OpeJ5xu|z2L5Nnayr_A zjPe$)e~Rx&9$D0%)fa}>U-t)Y7fle;;g5xn&t?c$z%E59U>>K7@N3oDLQy=L8MCV&ApMzNB&;v9!J5VsA|jk z0|y(QM5;&kGlk*2eLag34%22%MosYoP{+eGEGkWd&oLB|^EDJeW;R&Og>7c&76#4x z87`av$e434?P8Q-l#VUMRp>UTs#Qvo=x!;CXs*Idc6Cw3N*|=S0ELjiWN*R=b`{uf z$25#TSASVhA`RUhbxCD>S2g9ybU3423%JlA3Csk83D@};lE<2mY$Gbpj8GVsu!TMewryEAqf z)~&*Ay}Ef9wkpwjRING0?R1K{Wb=J6cH#XX^Qm=KzkrmQf`1Z3wsUpYwO+$H;D*>P zxYkp4?yYsfl}JQh{2OKcQ@cR(PC!#p+jz~-`e$dZuBD6I7Av>OX!qW^9jEL22(HY~Y%%}j`ggktbyC3VgWLrl(<>>^RfWCO zMzQ2+g~17T-zzUZ{7Bigi#hjV+5D;>6*l}?Y`t_RPRU`oAhWy$>b#x;i3A|B!#UawIuT2}fcXxgA!_k_(@cWChX zOKpoE-!BxB>U)*FioosuYjVJab{KEsi5icUv}n(ia!Q(&O!m$Rc|%UZsRFl8L}5A$ z$;zXGyu*qIJ62bh-8Y11c&js`-A}7-OFx~oiPX!G*Gt*%&K_!UPbU|(;?EdciIM1l za>LobFUe5QJ)%>}lmrMRew85CySUNe%?m0ysjy`>jbar`Q^FVv{T_RoYoCgLCMuK5 zGVL%uah7ul;8-8xo;hy2T(4s7K|KQLF`l#t>B=Yue9TO`k}{)S>a)~g{L*5)4pUFC zn_x^`F3^(^l^11a1~UIRH0G{1e|u=%GcneFZBMS9fT%f~jAH-MbtG^}M5piOV9d1R zZ#HydXBAQHVy@B~V6EX~W)RvnaGuVsC!@Dbr<1nePu8<9OU-m$muQnHt6Uj2nCM4X z+-+x8dMZyo${vt8x)b0j$|r9YqA?v^SgLT+>AkXi{8WWWf9N7RcjdNWy;|sgBr2)* z2jMT2_+#eSX>(s>{uDKHFe~{<6fYkKeV8*ql>K$8PdcQ|H#_qM1GIQ_{+sdY($ z#{^m;9ZS3WEyy!*kLSkvAO;av)hjEZhw6j!(zLvP63n!m z`;{xS-EMiY-ySLZviJZ+yPv(UUh$<~BK8F*C%=s`Q@rqT0D+>-xUkxatR`@ssnsTX zFM|Z=e4v|`5!^;WvNmtu0fxk+7`ASSY~IK2Jo-Gbg@^qW5_+Hb=AtURmj`-)^$`qN z2Pu()th#JMw&-%0A_$eT^GN2pzDIAbWE4rBx4!F8lqq z4+wM`Wt)lmw19zS5fS*QVW|b2_I03MsQwL)rn#ZeujC%$wk0tOBZ> z{z93o;S-Nk+OVz8<(!E-Md&TWUdS|TrWGj60lH@+sP#C*>%|tk)MM6vq2$Gci;>pm zYA3%AUo?;I@A;AeoEo-Fm1yNIZZRlS z60ockH2Gf>dp1D;Q9L@Zvfjq zfiWAtWpLvaD^mwkc&Izd3h-ETacuKoMy%je?TLl=wYUDJgi*eZ-d4pe+4Ia~0(3 z%iaz1^^Jt%;h$!2_n^0A``v(QQzjruZsHlE>Jimur#<1q1{t8dXtj&H82HV%vf2;^ z`&D+|0#_(1X>N~YkZ^T>P2b7o#Ph&6RQllMeyS>SZ$jPnRR1If7s|!T#-zI;iGaB-4Ub&Fnl=XdC`@Y`4Y!z62 zdESs1gS@mAB|t%afc60M0UE~NFP9)cr6eR^_Z&*qwvN!(fLfTm$iN@w-NCRCF`9@k<z6Ln 
z%YDq+F6p1XJl!SRLxh3OKP76Ehm z^%E}@@&;R2z-sjtz*2*f*|f5y#@#9jRW7C6MqP2nzT3X55H!JLsZZLSXmJRAM;ZO_ zeMFAQIEh?Q!M+TTEgM3oYOL`5)}lfcMjf9?Lq0Vhsk!a@@O!#~s`Xb(pRQjhH%gR7 zuUebR`kC?)uWSMs%|?OE^i6|<3*6EV4@oqR=61a0Sh;I7Wmb8 zPFJ5MFQj?auuT+s7w@5YlCjQ{8~9-F@eCGBu0vB*@r=FYgYE|F>sN)7c}H8uTt!+! zxx)LkKj)U;5Qloo>*bZ7QD`!mebx>@HIi;S7|P4EWIC+Lin_ z|BI@(0E#1MyM=Lw#oZQ`#TF;HF1EOaK!Uqla1F35?(VK31Qr4bkf6afI0+7m1xN@1 z0tC;``@P?PZ{4o0nyHztsj04+uG7zR&UwTKf4?wcfyTsfG0ZfZ)wtji#fwB5fjuP% zKhDulqgwnY^~<)P+p76- zt+LmLVhaIBpBleTJ>&uX7x9w+T4kQaTA(iLW~b(f-e<(m=k?UgR2P)#f5cV!@Y3kX zZcoc=EsAmBpqAHv$gzI@kn%$YGP3iGnDFxLyj+LzUx%oi#lbvlY1dU22Q4 z*En{q6rS>0b_XLZ$hb&t79UdUe;7nhb0^*@)DM;{8i-1rRjIIhY{tZQO?*B#R{I2iKc zGtswJ=sH8%WXk8YDxXdnAp6n!>{yx8V7!}#O-G$j$4?7f=tq(Z4UZwtvS(8DB0E7% z-1L_!PU`@Wge~hew*Fd=(SSFB2D?1%jy+k^cc9*vu*R2(;?e$WjP91UE+%%@Zw|hv zF8Z}_(;rn9a?#0EU8S_nLBG{n4OI=VLv4QS`LyCweshUcuf^+9&c^?YX(TS5B>R3s z+d@N@H>@swl?^LZn*K50w+C-#GEDK4H%SfUm1$)BHkPD~JkFgtf2OQi$I$6|s|GT+ zDqdB;t5VMDN@7qX*A9c|guiAAFqLNB8RPve8awiqwsMsXTi_idZyokT_N?R}Ffh6u5SX-$qv zSOysKR%b&=7<1^w8TyhNm}^{IsfXnn8wpWP(=UxP>gKOWsi~Q%(~@842zV3>B7%4) zd5GSfmxUANhJJa$Uy8~~VtL(5a?)CE>zn##0tPb!&ey%-ndSJYL~Hj6+=&1;5_Q2AhRC#2K2j;%|lc z6Ztcb`!h*oJj;)FO+SfMg_)5tQFiHUo18lL6`VOnc1XjJ{XCgmH{xvvlW(psY^IiS+)BwiY< zavN^oG<+O1VNvqX(Pnk0gKtp0^x2-5#`%`y&utuTcp0Wzwim5fV}oJ;=Z^CyULB+1 z@1^^Ie}>73R+tv)bm~7I7>DmK_wX}qrZ+_!$eLnnp41c89X6UOe>A12md5QqWPmYB zEOgJZ$yu=YH=?*r7o1%W{Gw-k@Bw68WY@~c#A=N`&2+VtsvER*KaRQuZDRMZZ{HcM z&$af3%~6%lYPl-=W9FeVrG@l=1QZ0&`i<39zdF{Lhh`i6ivNwF8HYN&G@2Yqmb|T^R zTiYngbK$hY=_vLeh713xxZI7v`u7X!6fJt1B7Nlg@|JbN3HhKZPg~r9RosDD@#IUZ z903^6t6L)=n7BnJO1en4p7kJ?N3sdG_31F9)S{ED&oPMKq6}uFh2uWWZamQ~z#Cj4 z-u=~%cYd&B@ zWxV5=f$x&cJ?(kxUm}c`J#jv#NA(m}ruK(TsV24teq@Up6C7parq!QZ@T~|e8(&qf zyX*3Xm+Mi(sqX~O=Q7t{K8N5As_ZJNCylz-36PX6o8lH*Hj;ZOnVlMdd=bfZPcqjnXcJhYO&?Y1fX z?%8DGaz)jAL#4b6qjbeADt=FA@KqbnJI_#RpQEgU8zO&Be_ zXK(&KW4hP4nR13mDw`=*;ddq(zdy~M(iOAKT@ZNobT%~awW|L5xYcgOC~*c4|aat_!*Mw_i5<%+46g#_hAB-C!c=D(Rl}W_C!A0@7j~{=*;?41N8LH>1s-o 
z`??lHRvp*kE;B~!DVbARbwOaC%(C5U5IqUL3=+)t+oO4;>cW9w|KRw%4)(bXcj zu8x0kB0jdZ4=$1ZWxSL25930yE=M+wx8;Wv{Bru&XiyS$xHB{U6{^Gkx1_SA>woeB z{~_l8Py0`7>A%c;3ZRHG`Qx@JIs0uEw@a9cKsPi`3lYbI0XL+9MzjBq zohGQvmcx}zJFIO(9!Iu8L1r_H_{#d&J$RZrZbBQ*GaSV&&^KnHm1`%Td}@$R`9BPY<`lxv zEb^fou3@UBM%Zga{H`Wag9B`SF!EC;=p=5af_#fd09#NbJ;mPVvwgvmDWpM_dM{`E z0}U88a1ffb5x*T;NI2o-3|E@Wv{i{Q#1s+L1Z9PgXGJASJq}3<6o$d`_+SCR!y)T_ zvBEP>X&|#}>*b)1>2~*MH)m&_S5(T^^G%S$VT5CWyJl_&v zq&GM+Ab(gIM{AC%yn@r$zRJYp*-~;PBuP#A5_lrwePk>d^npMl8)P^24upO5f|gQ5 z+-CPq4F8nV+;nOr6U&}VL2Wqa&rD6?TYEZVN3H;dZ<+*Jn90!`IzB0O)d5*mIhfxP zU~^1|MG;hB&B5O|=~yPxQA;+`kUOBz?>Orkg{uTr%V>aB85RoSq8(r|Ux%Qb6)BgY2NDAIq-z%8n|W?96F|0I5swTl zVM}aAOSNZz4Q~^^K&_Z6uWKTRA7tlVE6h4?Mx`#N8$sqqRc^nmK&7*Ycy}N<6Bz`+ z8`77en0OJ$*%NYn^0{1xTBXV3j<_zKFj&&>TdAh=8}T!^XuLYnUPngK%Uh1mI92X} zQq8x_sQZk67)$AF1ur3z5qvFJ#Qj6UT*8xM=0)^HTiRlv-aNwlY^ak`IjH}ZE1)wd z@U!MEZP#nnntbDOM+*_HMUJ^~onY@w?9SzYv3i69xEh>P@w_3PJI=w#3#n8kXoJX& z$UY4v&+9OSygm^UTsk$${yUpQ58tNTlW!KawH^DupF~UHWTdFPOVy@SQ?Ee(#u2ql z8%wT8qKECQu|;^`k|;s``i7v}=+oc?Sd=U9ojgDl|Ap#L9GiGG&A^7(#xrANWWrU( zL3om`o^2U}AJprKh`6UbML>Bf!4Gy5R<(C8M911vPkN=65*NzkTQd;-?ngCawi77-&x+63^Y*nr-Zz0 z9NDZCgx;OcRpIHCiCifPh>-~wGuF04+?SxGEJg-W;2AI&htDh`M}rc2?eIlJ6v_NK zaN9wT`_e8Ct}j_fNI-QWBl_HoLF^9rItCe5V3>JN;8usHUXdLq!#5g_Wui$+Q97hk za3$Cm|b10B8(DXy)I_GKSB4l1ZReg}&Wu;dX$U6~7sg{nNO^^w{Z{xQga4}b| zbO@#;pV~dKOof!Ghn~PYP)A~&>di^4d}^9_(Z_jgwmg)eTRE|UGt<*s$nWtFriqy3Wl4L=1I}>$dgwP8-)Y~nQ9IY43dNHLR z3haxTW=w6GL}{jaqKUO$)Tm@Mdhr}OM~6iN&OLkvdl8|{n%{3R?S*F6ZV{~FbQ-CF@tA^nQm5-v1c~5c?GeWK(7^=6jOX)>8@L-mc3ef za^BmDci&x}oOQ)e!}szb=roSUqsasPPPqBd&nQoY#Ue-4 zC4SUsF;*HfSaVKvhIM2Cu+#Df4z=%cz@bY9ax3N0B;KMK zU%+s#MVq@FiGLW;ZsSDdTmb%a=Dcm(Fmu9Mp)9*RxnvbeV&yxgN^u2(c8}OPrWDum z;@2D`$~}aHcG;S*ekCQhx!M?GXLhUxzSnt@AP{pD%_y$y!WNp3aKte$B%FUXPa8^K zwbvpp?YY7WtuqjfHBy>RmB+0L^3+X(E-;#ZN{xP6uMS!>X7^RRYvam?BL006`c!W-%!Adk+kp}^$819L>qfOSM(Z5 z%gRCMU}7GV_!I}F6jZcxkj~y*(9NrZ`N;QX=$5@EFM-cA9LKX1Uh3FYK+9m4#Kou^ 
z7W{~IGR)AduR&JNFnPe`sx(IHp!nM750@Ia_M6cg6=gWIO{hQ{@L0{8W`)sLRWQA^ zz*SloqIB$zO`2;lMNF+GF9zxy7(qG#uweICjks1I(LCsecv{i0sXiJgSV(zI)4KVI z3eXdp*#y5Upm!waHsW~~a4s!6PmvcUD(QxZKb+%UsCK{V4FJ}$QX{k3AvQ37?6pQK$`@Wh-$ zI&Q%aN3~C1)XZt~iJW<@zSO4OI*~j%p?@+z4`gm*_00dg|90O=cZo6=L|)pE)n_3_ zK9cf;JEtNElk-;SSiR0TEuV;Ug=5=WAAdLcm1@G0q4<%CTArZfKMW)}|MaDuD6c4iVr()fLjmWQEz(G<$($m zO{QZXSM4Jb#6CEo3O%fj+4NJ5JvbzfG@FY!))RO8Dqmyi`*3;ALsd=ogoLgH9!8tW z#=m)|>NKelQF6Z8c#wi*WxasIB@xE=bZJvNY+0rnNidp+-JInIhNFyK$!Ij8hY?Dy z>gPB~G%)GFP1DX2^m-2RJwy#N$dc#4f1J>gh{yH0YCXx@8}Xm{tL!IK)F9w%al8f{ zPp>N{5}oL5@v?$O330d2wZ5bHm>iST{mCI8L@Jww*Ee(%mLhQC1j)2;T8GgLZ^ArY z7CZ&J_8-rmS+M{{2hxIr4_qmQG78A;t0D?MgT5K8@ryEYRy7xo6REa#j#heV`iuK< zWsD$We$mz~-=YgiN?x9r1y}j{ylty(pekMFqn%m9sCoD@-74qCnMiLD?!4H-{brs| z@%2FQyLmZcp7FDh`YmSkS>t0aTHyUSJ(G3~p=&Pifr(9>CYrADLcb9YlYE$TK0D1k;+8CtM~YuPf?idE(HHqfeUP22Z!q6^8_q>&}f` z4EsyzBcW~h2UhQW$f|D0c`Dto%!XYqT~r36pS(L^D5lpsv+E>;HaNFQLPxNS^m}DQ zba6y)U=zFq<{cVX>jqcq8gWvkZQ5MPKg0t0W|8Ci%fK=4qzwv5kf;!&M8{%ZIYI6RJ1@;AJ9~S9%PlzJ|M=DdJ z?Mxr&_)?bb_gG^JEGTEAitQwxaB{;1TBJ)k&$#k~G^%X$0R6G%XCqg|RAgvy!BIlg zDWs3b|23bLR*ZW*B6g66>aA(?cf1_h2JbJ5Jq?N|vkNwLZcmhj6QAZ;s8 z=TylOPd>jr{x>5X?q@p*~guU1T3cl2Eb!a4|u3_>UfG%GZ;fMQ8ggz z67uLvo+9eGjP-r^)8~fhaq^mbA4)uhrCr)6TcQN+Wh^M$T{pG1K4>A;!vn0(BMcv& zrh0lHXXxF}NVi2#@xKSi~e84 z*j0N9DUCa32NUCxIij>z1amtnIn&Csx&=v#M3{vurC25CZnhIT?+f7NA@kF_-@n)AD>) zQ&?17P3%ip!q~R`wu_I1!HULEvP%Egss!UHfhbQDCH$y4vjRB;a+=; z7Mr0msD6?dOBKRxD~(K`O_-VK7MCeN)D72$r((&u4X{@c!-v2=csSXw=CA1K#zZym zXj82@DT7Fw!uXjXbd*~lu1PM4yH0OLcxqV;>DG15)+(n#IZgBjs&YDobk{*`0H>@c zKyCm@8XS+C%|zA^?0uDqXx6;6uqMdX0j6kPE!FVa!@B9nlJ8U1K?b^XJWkFz5M_0T zLgoHfE6ihWm4a2wy3Qjg2;=Lmq?o#9u2xdtP&Rjr?Fo){>PuV3Q-MrBW`bc?o2=Rk z2{+IG-eHe(RfUU7I01#f7*TgF2oh(#UUL#9btOk|*S*>?5Lc|F)ud`eXNsLAc#XV` zyZ7O&@)}b7yVNQ=gmp-ZBQa-3NG^tU{epuN%O#MLDK5h?6X?h*(+RIQh$l<=hk?T$ z7GB<$F!P(E`Zo`G%?5$oPp2-%Jv8i(YaN#oJyd`JxwBbt&HvcMmORt4nAIS;`CXr-4gIU0Fo?fz#I%|}pB(TP< zXLpxil-^A8t4%!|$mW4g8nsBuuT|m0cK123UqEzLU@a|nf_ogYdaEEl4OikDNiR;m 
z&zHN5ncZgvMtiB|cSO?_D#Q7-!h3JV<>{l!sB^yuR~ZU$(2PeAknC{~=o8%!5^smQ zAaCT7$i!!_`Dn#5(a`udv4d3b_usJL7UCsJJCiJXKN%f)sEaB{VTd$mk$NH^LzQG*RQ!cXw2&_(%sZOgqro5KbIl?=sY*ok@m!MAix8kJ zpQuFmz=YGdQKtE++v7lmYKPqiOsN4|5nSz%FUgNth7GKITjkci9>7siGMbPXsVSgAAZW{2!5a_okgk64nCSL7Q zNEWyD@cE;MWikqfMaNIUHq(Vh87ISFHExFIL*A2-kvCPk{L6Cl9{wrZj;W4RM`P~} z>@h=V7owE5+%ZR^t?DA=>^MDgidYMEJ_t{ayV#rc9keHSdw@XmC$kiORI${uK|I#i zkBm~AQ3SN}ITigU?wP{ok~*+8hy72c1Jy2=bu5Q~EYps1sz5?a{6zZS8ARAtn{WJy zMvJKZFPBnUhAeZMqZrUu@_H`qfQZDf)kZwRm`A5hxb^pdGX2u`@o4IDbxfb0Db^H} zJ9^N3#(I~|^{Z=Tj%LPDzvzLh&_)J4N4j%Gt$r`ArpU!PjfRxpyG=g5-Sc78h3tQ-h~@wk@Q)HQ5DnG8$fec10IFRs>hVafRQ-{$(XTp;f%5+_ z(4j-aoA>y0p$xuf&nXX7f(dYpo1%V0%#YdWT{9qHe zwd>5T?8@gSz?wU5J4TCY`VmBbBbED+&RaYHpH-spC*9OwE`PX+f{I&+6(WsL$IhZ1 zfL+e_NebLF_mGC+n!HLIeeiDe57WCBfqWSRX4La0o}2p*318z4Y1*)Jxh^?r%efdN zw@7j8^7D8T)xM}D{(6YHl(RDrXTheTUnH!oj|VZ4)SuDD5<$Ch#N0^|6R}!+LI{V$ zy0!+}hWaxz)J-C7Eu)5Y7?XA24HhHVc5NIUKH^}FlIHdB1haCTYy;_7M*GVl|IlIX z19u$bHcP0HFJ0DnKWP|8Rw_JHNK`s%H3E_x)Q=S=6#zMc{OP)!| z5tO;LYNzb;L>J=u027Xn`TfI3{2#`UgE_sQjF33mUFTu^Km|)Vqo%(gB0CsZrBFwbehSe9uYRtFc-8%`*&5C zI+}0Y7t*XL6DzeeUXIH8O*E|2 zRD^of8#=9W_npKrq8Arn75$oPo~GJS7o3LrNdM zeUm$S>Mo*H-ifyPEPv+_ERG{p-!UzX%K49;iHV?ZIP+h!QF^mk1ng;?tdAVa7y?c> zyWBh^o`092208s1dZLUK>Nrf42+~XWAdH0PBUw=K_7(H5eM%@qhjxb^NC~;dTDj6Z z=}DuM8>zu74d+}-qQ5UC{i3+w&i(PK`D^wb8@F~pJGf|};yXc%)SAu5j+NAuZG?-Y zq{@!|!&jOO3}X@TDiBS=&ne}Nnd<*A5Qr+A#Oyn|?@u0<240`UIhNy(0`zQD6sQn< zf~n)eFFqon1B|m+KA9|(BRXwDsw7K#I6~Mer3Md(Y{VsWWEzyi6jn*yF5h3|sxnjf z$RL5e=VnPv^L4hb;6>ae6oVN&iMHbQ6x$y-NWNRe3A;8eahmrMFj(8xe*CKQJX z2zy1HAgpKX!H5pag8LG(XC9!-xZy=$>j9560u6F5QndYW{*{G8jv-I0GKmEr50F|* zZtd*=)`Z?r@FW$sC+&zl)qy1zoZW9nya&k|b1ln-I{PeU+%nIZ&)u8;hkKkhZzbo6 zzV6$xE_WU+=Ef?bP_%#4 z&hHEtdJ%xkat|JVv`S~$s49YU@lSB@5Khqmt3ttMUsc~&EFtl3*v-rKZ$hyh!FUI4 zQg)*y1>_?`d2z87y8XR4V$6c9TpeqQQKz>I(c-94&Ua(*Uu?Vn+&nuoK{`h8Tk6A9 zd#KHrRfs)0%S$zG%AJl)rM4Y^sWg;rx9>c|$K#Qe-#0dGw}=y-q@V5FeF$Mv2+xgUs3N_b6<)CweXXKg1%mwH*bUf5G#j8i(cU*N+1}eU+^x2 
zt;FOxg)S#m7r{b{xUZjvjj2`^qo7G2m&axqI*8WKOat)3H5>}nMHH=6XVibDOCVZu z!POM)jC`(_=N2hfL}MP)ecyC-=AxCE#`bji-Cpib1bC`ddSI039lcXZq`N@fHW**x z^j(ZGNv!+6r=HElQR$@Mvd`8*sFz|Yp0XF@G-?b`QX^ac!G=$pidl2HyrpiHBP&aOTvl;b`(x5)(^rndt#N1CCnnezE@TZ2N8PlA-lq{6-BIp=?uZ`bYtkW%ImiSRr=KuZ*QG&0cXO_V}*P z&(QF1$B^Xl2gvZ^?APlj9Z%BB*zoFJ8`)A|y{^R&#Eh<0dtX+h6*Mgr=SyECf1i7c zb&YSFLD5*{Kg7G%-z&uZj9PgEjPrf=4IJCA&IY{O49%{D<-4myxr8)MLtfbj{=-;x z`Sf#`Z6O-qCk}rTO*AoEkruFp$k%P)pxgqAj!J7u`OVjDBU|W1q|f1RJ~5rRHcnhv zYU`{{Ic8X1IBeU2E{A>hHX)A&(yKHTa2zILMvnjNHhYf-Mrew^wl?St^?u>( zDPs)u+*%3I5Pir?Ag(yF(=eM1eG)8rS>o1vco56Ja`(WJ5;(u)fz@JWkTIg<{|p++ z`%~;X@vJbMT>`wbFd=stPfS>S*)XqjViQi+zrO42{QK(q`LFr>uvC{?)#-7>H?@xc=Gbz^H5Sd_iPAOf6e<-=kq{`|q%Ppf$2 z{qA}0x4z4|Wd+{N?G#Hk&%1C9Ql@5*05;z%ES=nJIj0{3T~eB@_%}mIkxR7-)X9(8 zaeKvO8EqWlpG?CU`)@f+Vh22l$b?-#IgnQ%qxd))tsd=D#&J{|)uvBv@EW+>1F5po z)gp+cXC3;SkdBNP7%_D!<9pBdnHe%v=D*e`K~xTc*}>!4GjzF_NeSLbD-)lxW6oaq z6b=+$8sB3s+S93Zy!nT5k8u8O^PH_w4A^&up5eVu^YVU|1#E7`fK^2Ot^X}6x}>IT z1?Ne-C&^h7zG}obfBgp5{n;LAuy0&~fFSwV&=jz~!eiKOW+T zQ94CnvQs^9|HDX>guyVKb^-Lx^ZqWKo~VxGr~dZsC$C`K*&6~XMS0nudt$0^(s16* zJ399L`gy6x)#0C)I{n3q%Ipt2h4&a+k!J}xMs>i&AedV+&&l;4h9Vb%_KW==J%%Z& z9{M;aA9%iSW>TpZwvTknBbnE8l#@qL%8dV$u-x*az+A%4qr7xV6+c75AWK?&!WK7P zbBr?S6~+wV?R~QC@I(E=ZUN~okJ_%}ErJQuq8$Dva0G1>eaVY#5#WBc`jspzIe`J@ zdA2KrWE%niUlPs;7|cgelDEWaKi0S)%5=O$-)^))?gv26Z%bo)&_KcfH?SeKWADy4S|C->u!x&gXJHWiTKKOCz6 z>h%7XOoj2NL=jmr7QPe)~403(UICy)3))wKvFPlw_3T%ajNb#TkQw z?6WV$AkM|90UQK8!u>HTt^{vp`jmsA6R>Eprv zX)gD9K-P_BTGm=R0X#bYUF!{KNSj|?(0aPyueQ`=5!;`RXRG$LE3!e{ct$77o`i+A z7+OnDGL*~yGKZ)TTxS055#nA;3=a20O!e(QhVGYVUA8B)BI! 
z&CYrR#~@v>&Zc{yDUEuhx99@3b9mcZ(WQZ2<<7{+fW~FMJgIaOU z3e2Qz%Xs4+{=*o<8Ybikl2TH?lNjW3{9(m+Iu?!6{FTKz8j;6N*C@PY+mDo|w2zOU z<@gJ+dP}aC@7%MXe>upa+bsgeN3K$`f~`OtWYr#&OufiEL@$VP4Yrp}cKl7pjB-!o zwW5c@TPonO;HxVKS!96gvDxmsltQ-PSyA@V)%N=3Y1+$oOZzMy%%B!gZVx;m;{nZ} z4S-an5!dir%^&7I1!Z%^Sb}NwuTJkCgCKo;M%C9J!Q_Xmah1y$2l+4g=S~M)mheV#vCTK{k~vm`xr47+zM!hz zLV_-}JuTaCv-TzR3M2M@OuxR``KHn!9E`H!u4&RlpTMCXmG8YJY<`qIi_P5}XZfbn z=XAa^K#<{h8Tt`L81!n3HiNbLN#=>P0+Zv)<@k;_Qy8%mEN|VLReUEV^w%<3HpMXo z5JQNot@Dmqpespw_;UuV8zCxWis{ix>fAAu&Fe}#G4Fnf!j9W9oO`+E6F%mdE}!VV z&G)@QEJ-J!vNp0k`)|uZf)T( z3-)1_;rT?Q3k`(hQg*y5@}2FK|Cx1{22kt|C$x`GedlRDG02^owtPlL#`fVVB=K)L zp_h4Yz%=W%0{+t3X#d$*VZxuAyJuExv7Y%ZcxKKCb&3eSf9^k={CYfWAn~t`Z@9CV|e= ztg*iK6RF@oL+9CSD!&13zo$2c|6$N?l~=GwMF)R7ZQB(1wG)2_;9a{{2$?>MZ=u&q z936u|MG8I3NTwB}BD-SaOSbwH(&eZI{wiYG8F5+2b)`Tckd&@NK@_AjzZ|>&W=fAW zaDxnxpL`in1~{cvR+#>F5=M365E0J}kajq^{S%%Z!gK_LAzH|I^ zg=|6ZEz=)`UbJ~vCRS_y1*5_`$ae>M660hZ`8?9{U$u^v~hp zoc4UQgj|ZG*HynKYs)2DYL%YnLT4g}qWi$OTgKU>ww+66{EG{P#am(16%gpyAF(;cg%)l%3lqFzSzLZV!3a*IZrZwI@nByU**SJXG()FJE4{FOG21(z4%ea zRcEY6yS@4qU6{9Ad;c()q^bvRnbc3%sWLqju&rOEOT9adWc5TluD$2?iU2xe4xt8q zFNoZukM{kR4)KrZVT1vSjrP|z0b8V9huYVnsjGdje*feyXU<+F3xY&up9f=^RdjXw z`ES%iMi=ZRMbG|-raI6|gQvOW)35r_x%@w6_PNvaG25ow_RfM87JuwG`eDmvEzzU0 zhve|*PV_3`gRJha)5knJBht0;DE=aSG}lP>uHc@CFD8(V8o=hiTk(u705PCF5e8XzYPlN<=sZhT~QXRf zj{i1Fw=#QG)Yf21Vm@uo*7NZb2)_>pAS^AzBYZR_Zy1u76t#)Dxer!7JC|Ou=-!gN z+?3y_&AGb0Az4L*mLPR@cNX!NB1wiQ!a!3I8>Z%DBL~xgR{GY(i}R5{)Rgb<*gO3A zXuOi%s7HXv{$GIjAFBOh`yU_t{{lp1VMALbga6~=K+Y6ZH*^Pke4YQNcWtFdKMJX1 z(e3%;Co1=v4^01i`^|p_{->RUi<~O-s)n<%UQ45-mV!}=SJG!-Y5QE}yj7|6=wmCU zMQ+YPbJB!KpRcc(n1IhAk7*QC=N%3e^P@%K==5&@{sG$F0RK_b-Sy0f)=U|^9GIce z`TqmalEgoKj|FEQf!R#;L8QjO=+Q_3F7p6YgXfikoJJp$_r|J6pY$EURvIgrmD#mj zB>vd9A$%O)tcCD$%y!l%ZJx~b&Vy#F9=y)dqv4)yYZvlq^jm*o^HU10Js&|Uo#)Ef zO(2pzxv@9A93P3606Q~`n9F&yzNf{zk$S(A^4!1i(PP!=&-tU5S2RDOwa{OLKd1Ok~d{0TS!Ws)B&PoTMLT zWruS7Lw|Psc#gN2wv6(C zR8+BkM!H839Mq=(9C_y5Wvc44ZtF}ox2Bj;#wO+3sbMO-PEkFvc|5uGQLCFaZO0)* 
zLTK;p2E!v3t~VU0p>SXj+y-{Oeotjt&ADj2#fAN1xwQ{?gYX`?&;Q^)CAvkwNcxaK zZ6ParV6%bpUd|g-A?4E6;gNI>_9pk%of&7xO+TYLGneg?Qv4yh$qwDvwQCmZ z(L*l*N(j$bkK;Nh<9_1a&>pqNphm2c{gy650#{-c7Rx|y0?dN&0fxT;U{d7q!4h@i zyccWr2A*`kWoT<4Nx`muSMowHz(8S8KTxl>znr~vWw+%ot2rdKd>N%}nC>BV)Y{vl z9SpFq>FyFpo~D-I7uYrUlJD{pL01ME#$JzYkbKYlBdol{F~R6Y)Yc99<&f8Rxf427 zsTsLm%n{LR{TgS%4R(k`O_MDB#MtE|BP*8taL^;6l@9C;u$$KApov>w9ca7 zMdVd_27O4bJsg@AyTyMQ17p*c&*{8yvsosX}T|OZ%$ut{s zunHW~-R3G0cM#iDBtIr25KDU$Yo+~-TP2V|&hsj$HVeIjlg`t&Oo^MptPTC1BxB3hrgUG79d!ph9U zyCYQ^dWbtOLQ`PysW;(W$Cfry`Z{8fkLaAXBG4PSPvUoq0xl9CZFzKGBxLKf`Y~_{ zaM13IU;1l-K^7e7Wh6=pvk&PYSD-Kzz60TIp?kHuNZ7mbJSeI~P9MR2Rc@N8L$yd+ zicpUC>*Kl@b$GO3J(cozaH8S;X_04#t(^G0o_!B`4mPSS9sfw5QWNESw zf#_6#tI^{UZii%q`4+f|-kv$xXZLwQibo92Z&NK_Kk~jDzsQ~+<&mnwJ_Q~eX_9F3afEsyu{oq*~|X{wRogb?T-b?$CVG4(8X+?p%xW8@sma8 zFX_<*(;;_cFFJZFZ*b?4VL{O~Uu-;jp(!%vTMxui5xyn+(ycuTFH3gpgB^2i;sx^k zrv72N7rj_N%q@DGP$-!(pV)ncNC+~mdC}8)!@?F z!Zh6^yMLp<#-B0E$x)qrbX@ZVKSH#Tr2^Z$=N$jLm2PV; z?YwXMQNN_^&wbZpZI`}amXk32I*J_yKab=F$S&HQZ7>4VgDzK)uw&hyGTUF|W_^ry z4-Jm?fEcZ-f92W{&PdULM;?T@10k*c=xd>@@RN z%%o6IS?HOz{fFVrHrpOMe&yNro4Ca~-ROeMXPTI8_Ij7u`GicaQ%XKPME~H62Isys zv;WI{^rfY;yw$oD{Vhtty7M0fVz9uF*qm^(JkHNLI=qgE3^LSZ#+=@;}~E&pEtRXVE0r}#V^ z4`<$ABZtr6LK}W3(Z4xllO6G!Yw}U=%IrL2)&OGjX4$a$K5XG32M*wZL;Mh3-nd>4 z7@PB)vxAd|`7_FFgzjI-bTY5bEx&D3fSyDCU+Z`U)2QY3{MhHimfR)IZLh{TK2Xgj z9c2DM{UGu1OvlN|#^blQsE?r))C(qN+sSwRum~^6$H3A!S#A7|4o_o;+4^}8&*Xe} z2m1=Nh;JtgA15aaVp(<7+CM&a$vDr1vdHv&VHS>0{sH8AwY~grBhGym?YCoC)P!WF zj|@vO@_!7DGC9s)}fx%;LTlReK!u}28ZTMyk z>V2Q)bG&Cz#{%17x18am$C5-qESA>@IfW#QvNo!#@f64$t;4)C0DrN>u==lBTWwl6TbY&*-s!Qr1jnear+ z-a`?#{{UZ3d}ZS!t4rARz~JZ4@DcL2_&qph`JV?HGhq;N|HJ?%5CH%J0s;X90s{d7 z0RR910096IAu&NwVR3El~YnUA1Ji_{e zgpfuti0@LO+1wl=oS*!{%l-<+3N43p^G`JMPc-w-H1bb0@=rALPc-vSB=eKbPbBlt zJo87kX(I$DPpr_I!Es%)Q~Ax3vb{hb{{T9F02W;nr7uJbe+3wRNw1m6K+?qx%~HyC zu*`S_n12XiT$6_JGRp#Rcw96EI$)$ycmDvuAYn8N^Zx+#fROC}0N_dFh;uVl%MHqS zfbV-mD#CF800Yd2atgquD};tfSpF{`S*iS85s1?@CcnWz%<~h>Mmm4_(h;^1gT&13 
zDpKM){tGdO=ZRyEGXkEN)^!A?C9+b=;{>#-3k8DQSj(b_sc!Hca~BMZ5-g*L%yM%DJjBgHEm(|JARvnbcbEs6sWqM`BB)(l z)c!^22)j)W@rEvw)@Xn6NU$!YtRZrM@h>IZ{$1ihH-k5ZH37wuqDq6x!uYjw;OE3Z z6531rLXAzOSt#?AInA?}F41QxY-yF|JU}d;Fbc^%1;i6EbU+;91Q(IfM~&F*4V3x?(+*+8-fYRKK(@q7^D+<_KFiNbHdPsPIR5uhwc!XVHh# znW;4X7sJ#Sy6QLjq9vm}L4*GQh)PX!lelUvNzMKMsCN@Aq-JY?7>qNhxFb1wdUpeb z0{Ep5xTLjVU1B&CH$z|Y?f9Dg^0GY`1hW99e4qISf*z*;P3ehtnS41*4K5;~guq^4 z8V{IJ01Kd~>!jxxY`M%u7ykf?jobyZHfASlWeVYr<4$JT4p<3*lt4LP<5u{Jp=88= zfQ5?*Lai}rk23g0f;+*FtknKB4-kF%sWkou=ZTTQ6ySvoPyQ(g+Y1K~Z8GgJktHar z{{Y02-l8*&1;0;uQtOyImAV@700C}A3<7Q%{Xj7XT{54(4SblmcNK*2n3iq`M^I{c ziVc^a;%@rHg#d*}myAn+9LI1mRtu%r-|$r&Ty?+To&6$55bhyGFHadM5GBTNufEah$)#LR~}xnRy2-_c12~ATWm!6%aB*DmI4Bf9EqLv0d}B03j(+6Gg@duemcs9q4A~QEXAH&6X5Qw9{{R%!RvRoQNX8%q7sGP2le<(ATHY_*cYrL;4-sH#tVRAp z;)>XE`-sL%m4_r+9YW{#Ihc+Rf;`LMYR#PI5}Zsqj8fuYv2xYCK>OuJ@^=E#8(PFb zmn3Q1R+5O5boeJF&z>5Cw|NchM?ScR3(VyVU0wbLbTUA z+TilTQoRXJmHHz5)l`0J01)d`^*?js1sBt~U{1OTqESoJ%AUmaeKx^~CadfKZ8i!K$5#eh9->d^>)b4Hk z1OzZjrR8VVRY(SgIe;ITSylRr(QJyA&IOVB)^PCzoG^wPx6pra5Lpioh-w{{WRA1pX8m9YT{%u8)h7X_;wwh6$?M!E;ISJY;^{81%VAv$`+{OV?5|rp)F@e9!8q43 z*;Z--7_fRY!)n@+;;16XHEFxrKjL9+sQkt}P~G?lB6aL51>7_ACAy3iiyqTF##vGG zQqTIa$H~HZjuasd(qkg;l<4jHf`Q%Y9s!8tDLX_d4DR+EKp1;JWBNcnWcx|X238zJ znW|+a2{h`sJl#hEZK9sY4PR8oHvJ?%nLiK(sIy?TcHVoIojal&HfIn^f?g%)b?ry2 zQkAbvoh0y&m^xZy2y=60tEC??mX*b3aq2HZKM~SGG%^u!K=MJvAzvb1?2YZ|r|#h) zxWw{96z&2sToV5P8pHw)^l-q^7CsQNO>o2+Mi-;948e$;lf}vPEYM$qt|rw*uvi#{ zgar%10f#DxC4MCXg-2GxS6PM=IH!RG7Qhly3?tkwUZ^_SqG(r4sklI0?8Ig4okx(b zMwkNgxd>(89#qixs7KG3CDCx}bEz&Vf+dC#hlm9V-VcM3*SERKtQn@{TQ8U4lDv(s1(ue zmdcg-1f<0Z@5{t0Y1IO?1;r^u$yABD8lBpQ-DQ%c$cuoNZd3(9+qh@%$sV`?j>YV!V>aq0daY1%GU7PE*Mbqvt-h@1v_Ggz9gn1gAB@o~~&3+;cXwh?G`7+Ru^ z>FQr(x@`WC34oP?-$-GRGDLNv(PQXku*GcG#I(*QX^MV7A|KSMUHVYy@fN;56R0^t z@Zk==CifK7y;K9B4Sc(q6B8nT?aH2KD=!T7NeyRY&+{_p6?JneGc8*Wz9wHouwjdn z3OQyLUf(0s&6w;BOzBn8L)X$2Qu_kq>WonAI%RytsNGAR23+_?XtiV7R8t=j4GOqh zBLsYm;DIoNVXVQf*nVz^T<>MYK_;VV-38l34@3fkna+RM;oh8GXY}* 
z(fj@#N3fecVSw0?4?<922Q_SJ*-YxEFj4dQY#aH7*#{IA?Epesj)eeLAupUC#I-$P ziyO9C@m;pf#gGaGv)hu(9dJ{av@igfv?ISYQvwS_cGe|1HT;D~Ht`jKxDqiUX=)HYJu4t;uM zGlYDi6V;!qm>VQ7c^|ITaadTIG+cS>o@H4<(nZ_7WyBhc3c1OSzw!!blF2X34J}fV z^$`PN7TSNK<^W7q)z9t@qSxv)N5M>(BAvK7IhGs)5q`5temF@@2;-mnp`j>K%kCiP zV!S{32xab_wF053C|oD*P=Io{CdkEHY^e5tn+7VY_BR>KT3X@FlQnf#aPR)}3o^ou zC9mXwd`u5niT(&w+N+n$$^aa)puC*GVjRY0%5VXKY8W~}y3WGGbWDXPGjW0thhksc z5He=Ef+2${a~%%BG>5Ye&*C@PGiI<|5Bv2n;_Ga;c-_Y)fU_mB0}$H7pdz7w$5uHb zp<%$PR4q4&)SHRnO17p$okB^ew*y|e(MUvrZIxXb>F-XvtdRBddJ1a zw)!AwUW4U^I|?5$Yp54oS0u3#vrHAtqtWwmA(l2=vG-D?&u-6L)lAiQd@0qC|2<0+1wb20LL{>GfbNgZvggU&1nJGF;)zTmlUr zRy;>LPTHBSrR)v?>O6$Fx~0GYMLdS*+qiAc7h)y99=BwBqW zrVbxS*CO#QdJb;Mm&;CSeYyg&! z3v?XH4G?K?y^tUvXniBr!Yu|f9Fw<0nQvC#0nrfl@F=;z34&b=n~57EctE2LB?Pfy z+9nAbZim#<5*tPf4;9R`bi(bVtM9pn5GKuL%+Hlx$m?8*yr_ZEdNfQDRiFCGL&mPWN*mDPSF%j%r`Kd8oNGB6)uq5gkT=T4F! zFk&wvbI&8%0uUO0H>9`6kFGeKO+9~Iu8|3-d%qwKbMj~dEVi1rHXt*3P6n>)K`{1IL(p>kmFg8)ZB*30!xqi>NOQyIsK8!0ARrGfO zeik4Gp%U?C&12W%Q}y&(3ANoHk~3Si_fO&smLG6hHku5+C3-=ls3X7-xDFf62;v%I z>YBKT5K0Wo=x=NHl?LqO<}UGcslJ2ERTXV+Cu2ZC>L9zoV>a^{BExB-DXFBTAU)$7 zW}L&~_Qwf!rTdLjfP3HV!mj6d(GmrZGZPP1E4K(-%7j(IbsCibA9=A(i{bAqsQwAO zLyFkI%m)uiU?iy#sYn(!@#>pv3_M(T6d;F$hy5EUzDT1M=?aewORreI>tKue#{uHf@-{8{%&z4 zb+1Q%Pl!cN$eyo zLg%{W|{Qauks1UP_fWeeit`?kN?@YtjHRt7{ci%4wjM zp+d;5kKCxGmCUXHA;qfZwk437h&Fe>cp;wX75RYRFcX}_3(g5r(U}_LgJ^ZP{fI}^ zqptz&iGp#NRV1}HEEx4Vl{1y|efgPhWLR48^E2m1Olbk7SD9QPhi^d|GP8^|Z!E2V zx6}|b5wHR&4O-P_RGNxcjjG@rnX+1-BSjPoS2KqZrMvRMv6*XPvO@Jrw6MqK6=$^B zk}*;_?D=!5_L))$@#a#nW6z&?PizT82vP`~#JkZ~o?g$y!*jja9!^hrb*%5yi=p=k za4ki?At9WZ^B)cWL?ZDrDunFj zc7Wt6ZARP*h`fe{p4*2Buw9;1KBJTsEM}U=QtxR```#IMCnVvXxoz)KdZR;e!gMJ z!!FihU$h#6ZBOeCA#F>9+BQbCm-^TrZ@7_l5Jovv?i;8{^3U00)(CPx)AuYBm{HUR!pD0yu^1-Ik9dy?jG3O# zPZ6b>M<)pKOWfzUl-8+^pY|Y7c>w&uhEJy7P}WR#vGEbDRA2t|lm&|;3=dfJfTLEK zo4BycP{`q_R*of7kd=+{jkO)@G@x2)_Tk~?9v@8RRi~_AcI%@B$6yq0rY=*~3zeox 
zrbW>dDF?`hNYJO&AE<E%qpzvEQ^+zxI0zwlYxyPh+Jf_3qaFs*k602T5mS1&V^1#^Pm!khfRcjXE@W51lO!Yr3LFz*SRh(0y@-3)*Uq2 z=sFttnHB|KqXJ>2fo{(TS%rg_WN**ZLz&53lwn~l_sm-g&k)); zU^z#UC;L^=`-uC9AUMsLjxN^g%Clt)n--I0JjY?6^CO7&f-sS6nHNG0CA#V zGL8QL@#E$mw(2y2LoBcwjdPg1PDf05n_)BX;gs0oS=>)@{{Ui&I)e4*?U>XmeA2F% z!PKE1A;Y*B%*2qA6#L3w?iOa{);g9!?4QzJHhc%r%J(5xA^!kwBC)H(*$z%+kr%R2 z<#h<l6%$Erw0M-LgVM|-7~CeEbNZK|>G>fgyBY6ts0H6sdzDeqpJz`p%Pwo- z&+aaJf%}YdIAKZGIXz(4rx(O(dSj#u>RdX>S&q)?<%41a^gs3y7AhhI%gj(VJ|bA0 z>k@;MtAj(9GYSC;_K%N5N<4XqTFWY~$g->N^n}}_P7{~pex^vqk?GvUPNn_oCZtgt z#o1i0ps&2{{V$L;=9KLGQb=foFXUVlwJqwV18wlgkx8u9(jumMPrf_HWB^8 zzKT5g^?>zw3ZLvZtOHMlD>aJxXw_2bt%o zL4OldNEoz)Nr65pJBgY>8Oay>AwK{{ZAGJRL+AXh1n@1QyM?wXq$55P5FBMrAETo?=R~=%7lFX08$Xaj8^pB?2M`cIIi{}$!siFNr=?6vgPZ%5CFiG5S9@o?r63|0&OJ;W4D|fgXvJK2zFQ`erV(Mk6Te#9NmUh3W#48D7E_1N6#vjbJ`ERx> za)~fqn=liS9|R6x5RB#DN&A|7v&Id;yOb+f-cdicHy36m2 z%8j+|Wc~cZ0A=oj1C|_Im*+4P!0PEgP%6HRYu9n;{{W;)^I&l+2TO~1eM^ATNqo#2 zE9PpMJ3t7ksQ@EIej^5}xHZUoM$34X(M|OpdT)uyGT2n!&%_irW6u*{!K5lc4&W5f zU(m1AVe%!h@5d62Kf<*YMj_Q@^KJ?T0$EsY3AlS9?Cbm?B4)J-1`gc%K{OD`5IGnl z^fN=8LD3y6&RT$gN{(*-0PBhNru_E}<_*kIZ#5eiS(j63xQj56*yi7F2352^OA~BT zyV)af!4b+kT!-!wxv9z2;cluwSf3~N0b<9;U)o(`#OPH1PwErA^9M*~AgB%*j&=hU z5u0(MyE6)K>V1f!c2VxbCX3)qN-q|E(>IqD3_@tay@gqr-kDP*<&MKwWVu$7qSRve zg-hTP#s*VfgtHo1Eq4(Ns8$%u_<)UPFlt^OlqqVkSWFIFM7MYEDaGJ>uKm9(Ma>NQ zzy+b_+G7{36Sf%SJ9ipOhEe3>@dzzPZrI*qw15aREftpQA(n+W4;;$g_Vm5*s?7kLiog*}q!UoSuGZ#U9n_v#lMOvKE3#YgcR*0DBa-owwd zF^ypngfBbqi0?rN;XeH?d$Kg4brlQRPHjkf)+Q4!p})Zcd1Ujk`D6Dq9FuXgm+}fOv@6wH_+`NQDt>k238LqXF&w zIjCOY)A%N9nCsOl`XA)bzu+!Q_gDIq1LK(NLQ?p&#k}J&Q8Ek^{{T_Tz4b7v>G_Fe z_}hnH=fO5xH`HlHgj`I`FH^;MmJk-zK?nfgScAMV5gKj(02zN1BHhG&RKG70SbH@T zd;Za}7L$sVt|$R3jaNE|&AcJD4@_h;=M`VB<*O3Lq4egWQDAsKh>tZGj8Q0kANW20 z0D{{EB4nzI+q0#KpJ&0DS)d1S@1khZf9F zSgfxw=)$bEO0Yp2U!hHY;zVs_E%O?mhvP6ZUg?f-ZJ8NGV~7!^xge-9uZ5)l0J@5! 
zkB1Nrk$&L=v>{nMOcONRQHV4@B7Wj+1UH}Y#YjyEh1FQ~SNfP$8TF}h0 zVnu>{a9#?Sat#sZ3 zy<7NzVOS+IT^}$4)HoWKWgr~z;$21m0N*l(?~eZf*k>kIW+o*8F&X9x)H+0JIj|SQtyhpsyGEi8}={oWhd0<1)c-5w6{^>e`Lm za>W6HnB~vXSlq!y8MarN)6-KQ^h02` z2p(2ggH?b(8Hs`v3VaC$&;W2tx|3=}cp*-?QD~|VISU`A;VFE4LY4so>;0xZIA8~+y z9Ah-R(FD7|eIrD>*FWqO9+O=2FoK)uHMK2VGY$p}#*uD49H>N)Egn*Tr1uEN^5QG6 z0rw57T72X8C1ME!6B?=QTnmc{F% z`}TnVtwnQ~ut42i?h8F(T^I8z9c-*9SLKT8*LBbPF>mkvlpBgTJ%l12u!I=Tzu}X7 z%!QL8zU9N8nNfJS-rMlXx=o|y@@Hhx_C&&tdl~%^v%v}wFW_AJT2>~+R3;wwG` z`J6$RDLlt~@e>fco?1uLNu6T;f2m79t0V0O`l`TjSb{uHX>efyT&wyaVOtw82fR5z z{GzHAX~ea8OG^wgklMY*#!R!3?OTr4=noLP9TmyhBO>B%%*OjfnyV% z_WDb(@{ClpzOmu8%|YC9Jat~z%togwtn-?HfRRpGb&YN);{Z?}vI2aSUvcXwy8|x} z%Mm$^(ow#lVeufOSJ<|r){&=3ROPbe?NKyQwb<$-xKp-4njqQ=z?dltiBmi$BG6;T zB^|M|+E(hYkK!@-m1xPQt6!;DR^SRmN#{fMgAr!M(U;;qSOn<j=9S zXS5`Ey2LarE+=BhqY-w?9l91)aGqc#k6;c5ur!pz5LS;$9BLC1c;1&Rgstyo(6)8_O{DaQnfYk#+YL zSZeB+%wX6T^(qHLsvVmp`2;D~b$%983CIPWfEVi`sMtf#;A2J&esF@_c<0qZg$il~D(gXN72+*$|hw;zDSd&EaMQ{($sVa-aW zrO&|t(lhbzf8NZ~6JB2cvc)%Gx`J{g$!A@dQ>%?1$mg|=(* z#tJg-3bdeYg#ZfBO76OaE0BRQ5TK|mDoX5!HSH~UN{d^e&u5#Mg}fw08BL>}stQOL zE+AMq5J=Rf!MzZuEG)gS#{tApfF^L_9*alZ6k43I1YaEy1--Ax@}9E(nGxtVzemsTT~7<21O1l}(=?Tth7yn9 zRk}~IMb08&TDfeht0|0aSJeedTWaw#uSB}A(<}j{3q%YT3f(#5XH(F)(7+G47-MnSk0=@=Tcst8NIZdy$`f(zbK; zpJ*-6F12tC+!{C`&t@BmV6Mn)M-WEa<`tKH!g6i?qZ+GS&Ta@wV9v7;R|Ueht5UKp z5)p)8GP#+Rv7oY*g@MY!%(H5A+$f>6E}zSm2MNCv;nD>><+7O zW@VW?#U*ZQ8&E~I?;{dj;8gI;@4j}Ykwjc802$grTUC`uM_--aG9c2NxY`PC?m~aH+ z7@0s}`RNY$XOaci6_X>|8cobWb)THZ2-Ex!?%pUpvtIa2;+D=T^* zM^K~>v|NdRcnlLsl*ScE`Ti)dZhpeDE%~U`uYGHbtw9_($gc^pp zwk#ETVJ_3FwjioD%WFX3IF6m=i`7(ZR`F3`EW{MUjaKEVS%E9nHcqT+Ji3C59{o#L z+-d4G#UT7O4>@MQz7da4+965-Pyne~@e)vaL;WF{_M`K1^C{^*37JSxcB%Bj^(W4M zaR7IraAQ|$zNHt7Rez{;TvR ziUb`nyTf8S$;@^`nNaimO5I-LaRjll03C~OT#=ajD0BHG@?&VL?jy*GT+7-~u%{7< zH)k^l1c9^~JBTprxEo8PRu6Tvb=O9&-SsUm4cl;#P&Uj-cs)awHegou1q821{zaw4 zYpo9OJ|m3Q#qp4t1E?F7cOF!@NSd*LG$#5UY8Ld?^l*MY0Q3pEUPJnW2Sr8kU#JiP;HdjTW!fCTEy$EU 
zzEFAgVsLZGh_ESY*~5o1z|Kn+)>2vkH?f%1!f{M7nM>$_Lgk)|ymJN33&MVc7lrQt z<=`sF@i@$2e{d~6QD3P1MH>r#?ag)clMoFL zqlnT^a60#jYhhlJ%GS{4APOxQ&xnr+qtw6L$XD3J$qL2}5v0L+gNYo9zL=n(8~(Je z*h@`%THlGLO&*Jy+(gjENsmLhal7{e_qo6ieE$Ft5<&w#CwVnJP<+qBa4VGMhh6b^qHHWvcG6*D!DD~`j=e11gMd81l+(-DY0XZ z#HQSWEeA;)_K0nrPGIY9p7E#QhdE-v1s z1JOkF_Fi`#jalt6Uanz|qC68XmHtGr1_Nq{IcjL3U`T+xf zEb7c4s}f@gUzM0RKZqk6UFbKNLV74)%*2D*o>H$d>7*$zq=7GpoU3eHOfA~1Fu)@l z-LLrt3)!hf5Z|Mr8vFtT*JSt1R^JlT2(DJ)SA-VislDO?cA5457zd*-O(WVa`h{l^ zIvDg;*vv2uf#Nc%*wsLlW}f1j6;~4hj7~60h)M+=C_m#?JGg6&L2BFR4DTO@4-&u$-PD^n#8GX=Dw?1pai_ut$gIJc|q%zm~o{{V6@ zX%XR=3Ve~wL)#E`p(8`#Ln%!-N>ZSEXRKBg^nXS@kOlKiGScI}{6nt&NUF}GQ(2g& zP@OBMPx2Ff=nxKygas`t`h!aD?)u;LDuE_ZXl5?CX@-MDUFFvG{0j542SXwCC%S=6RZqYfCfx{+oDGde_-2jwtuhYL(0okK)Y z4bqMUEu9I&+96Am%&KaqqX($23D}AenxOo^l)57J%n&GnX>eU)Cli3*5rWA&M0QP{ zm%tK$Onsd}Wqf1vG=_^MDuWz@1cQacM}$(vf>4kwf-4Pf@7@ZgoXfw1{`iTo1@59j z>MMhY7F!R@07;5sAO#E$tyB1mVE2lQH^jXGO2YLh?R;Fc2~BV9A%a-pfw_6yY7H%U zm79YRbQXuqQbk%?sd_3~mi}N;A5F&bM+nrmb8d;32r0#T^D0ER2bSND%yNW8-~d{) z!Q8*%*>?NRHT+$Srtrjv(krr4)Ihki5}B+P%8^>L3BT%W)^U+cAwXD@6w!M#!I*Nu zIJsGj0ygvr6zs&!ULYA-vK5<3mU{OTP6Z5+DXy)nm!k44+kfKa6DXj}!GND+z|ld& z9TM?y?;A|Z+*T>n33mXyWfoqWCda%rw1e+5&00W8#9H)O4%G+t{7a+@Jb&J?cnUzw zjyi%c!2&|G^-7oO2u2;{vH~}`bXb@ELM`WxeGnN$(wc}=CuzmAlAK(iyE!nKpRbB6 zS-d|Iyo?K#u05tf+;*gGpbpE=?oka^MDDK7xY0qsrDtDK{Hl2~9)luN@&*;L{AhAGUe{3<4BYQ`wiWU`KZVfV*5Cp=8E zP^J5w&0a4Nudm1OS!hBcdy_mu>-#9l=MYo*AsMd|`X{_Nr}roxyuVRit?WPCqc9Ko zrZFb}0OJ+=NT6L|d!rQ!0K~VA5C#$Uo`>9?&xu3K zCV#YZRv>w0z-b0JAv7pu&SUXxw=44v^C~0$mKpPk%Mx4`fZGTWyyPWpsZ4V3eMc z*X#W=9fen$MN2m_M=@q<_qSad6#FtJTDu{}^;LImwSaVq0n-cQUP!hypVF z1yT73S)EEHhY?p&xB;Ok(lkP+xkR+7nM0V=8I}uz)w#K>%b9mLXVlUgPf)+<*)@l} zedn!}glQ=4E-cK+L0KUtN`fZhtU%4*>Jc6v$cbXHh{`JHDghQnP4%a!0WHIcj)%D{ zuIRQw-9=qPFEN;0cGLiv;%BHa3L39LKjv(;0yWJ4YZ9BOA5#2T*H?fET(X IHu#_a* Date: Fri, 22 Aug 2025 06:49:17 +0000 Subject: [PATCH 03/18] Refactor comments and clean code --- .../tests/pipeline_tests/test_end2end.py | 21 +++++++++---------- 
.../tests/pipeline_tests/test_vision_model.py | 6 +++++- .../tests/pipeline_tests/test_vision_tower.py | 6 +++++- .../mistral_24b/tests/test_conv2d.py | 6 ++++-- .../mistral_24b/tests/test_patch_rot_emb.py | 4 +++- .../tests/test_pixtral_transformer.py | 4 ++-- .../tests/test_vision_attention.py | 8 +------ .../mistral_24b/tests/test_vision_mlp.py | 2 +- .../mistral_24b/tests/test_vision_rms.py | 4 ++++ models/experimental/mistral_24b/tt/model.py | 7 ++++--- .../tt/pipeline/mistral_vision_tower.py | 6 +++++- .../mistral_24b/tt/pipeline/vision_model.py | 4 ++++ .../mistral_24b/tt/vision_attention.py | 5 +++-- .../mistral_24b/tt/vision_conv2d.py | 12 +---------- .../experimental/mistral_24b/tt/vision_mlp.py | 6 +++--- .../experimental/mistral_24b/tt/vision_mmp.py | 11 ++++------ .../tt/vision_pixtral_image_block.py | 6 +++++- .../tt/vision_pixtral_transformer.py | 2 +- .../mistral_24b/tt/vision_rope.py | 6 +++--- models/tt_transformers/tt/common.py | 2 +- models/tt_transformers/tt/model_config.py | 1 + 21 files changed, 70 insertions(+), 59 deletions(-) diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py index a91be706502c..0cc07bf317d5 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + """Test for Mistral-24B End-to-End Vision-Text Pipeline""" import torch @@ -157,7 +160,6 @@ def process_real_vision_inputs(messages, model_args): ) image_inputs, video_inputs = process_vision_info(messages) - # image_inputs, video_inputs = None, None encoded = processor( text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt", return_dict=True @@ -284,7 +286,7 @@ def run_generation_exactly_like_test_end2end( topk_tokens = [model_args.tokenizer.decode([idx.item()]) for idx in topk_indices] - logger.info("🔍 Top-5 predicted tokens (with probabilities):") + logger.info("Top-5 predicted tokens (with probabilities):") for i in range(top_k): logger.info(f"{i+1}. Token: '{topk_tokens[i]}' (ID={topk_indices[i].item()}), P={topk_probs[i].item():.4f}") @@ -347,8 +349,8 @@ def run_generation_exactly_like_test_end2end( # Final response (exactly like test_end2end.py) response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"📝 Each iteration Generated Response:\n{response}") - logger.info(f"📝 Each iteration Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + logger.info(f"Each iteration Generated Response:\n{response}") + logger.info(f"Each iteration Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") chat = parse_chat_output(response) display_chat(logger, chat) @@ -356,8 +358,8 @@ def run_generation_exactly_like_test_end2end( # Final response (exactly like test_end2end.py) response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"📝 Final Generated Response:\n{response}") - logger.info(f"📝 Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") + logger.info(f"Final Generated Response:\n{response}") + logger.info(f"Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") chat = parse_chat_output(response) display_chat(logger, chat) @@ -469,9 +471,6 @@ def test_e2e_vision_text_pipeline( # Setup vision prompts and tokenizer 
messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) - # logger.info("Running reference HF vision-text model using messages..... ") - # hf_output = run_reference_demo_pipeline(messages) - # Process real vision inputs from images processed_inputs = process_real_vision_inputs(messages, model_args) @@ -524,12 +523,12 @@ def test_e2e_vision_text_pipeline( # Final validation if validation_passed and len(results) > 0: - logger.info("✅ E2E vision-text pipeline test PASSED!") + logger.info("E2E vision-text pipeline test PASSED!") logger.info(f"Successfully generated {len(results)} tokens") # Log generated tokens for debugging for i, result in enumerate(results[:5]): logger.info(f"Token {i}: {result.token} -> '{result.text}'") else: - logger.error("❌ E2E pipeline test failed") + logger.error("E2E pipeline test failed") assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py index 3d5cd77cb75f..761a177acaee 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py @@ -1,6 +1,10 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file is a unit test for validating the Mistral-24B Vision Model pipeline. 
+""" + import os import pytest import torch diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py index a6c00009258c..8401ec212ecd 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py @@ -1,6 +1,10 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file is a unit test for validating the Mistral-24B Vision Tower model. +""" + import os import pytest import torch diff --git a/models/experimental/mistral_24b/tests/test_conv2d.py b/models/experimental/mistral_24b/tests/test_conv2d.py index 69d1ccb35ac9..cfb05115560c 100644 --- a/models/experimental/mistral_24b/tests/test_conv2d.py +++ b/models/experimental/mistral_24b/tests/test_conv2d.py @@ -1,7 +1,9 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file is a unit test for validating the Mistral-24B conv2d. +""" import os import pytest diff --git a/models/experimental/mistral_24b/tests/test_patch_rot_emb.py b/models/experimental/mistral_24b/tests/test_patch_rot_emb.py index 903cdf395d91..4cb5a284a912 100644 --- a/models/experimental/mistral_24b/tests/test_patch_rot_emb.py +++ b/models/experimental/mistral_24b/tests/test_patch_rot_emb.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + from loguru import logger import torch @@ -5,7 +8,6 @@ import os import ttnn -# models/tt_transformers/tt/common.py from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py index 578efa204295..0458847993b0 100644 --- a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py +++ b/models/experimental/mistral_24b/tests/test_pixtral_transformer.py @@ -1,6 +1,6 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 + import os import pytest diff --git a/models/experimental/mistral_24b/tests/test_vision_attention.py b/models/experimental/mistral_24b/tests/test_vision_attention.py index 821dfa3222d2..8c216339b02d 100644 --- a/models/experimental/mistral_24b/tests/test_vision_attention.py +++ b/models/experimental/mistral_24b/tests/test_vision_attention.py @@ -1,5 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 import os @@ -81,11 +80,6 @@ def test_vision_attention(mesh_device, seq_len, batch_size): cos = torch.ones((1, T, head_dim)).to(torch.bfloat16) sin = torch.zeros((1, T, head_dim)).to(torch.bfloat16) - # attention_input = model_args.prepare_residual_tensor_prefill( - # pt_attention_input, - # force_replicated=True, - # ) - attention_input = ttnn.from_torch( pt_attention_input.unsqueeze(0), device=mesh_device, diff --git a/models/experimental/mistral_24b/tests/test_vision_mlp.py b/models/experimental/mistral_24b/tests/test_vision_mlp.py index 32159eeabf7b..6097605051d9 100644 --- a/models/experimental/mistral_24b/tests/test_vision_mlp.py +++ b/models/experimental/mistral_24b/tests/test_vision_mlp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 diff --git a/models/experimental/mistral_24b/tests/test_vision_rms.py b/models/experimental/mistral_24b/tests/test_vision_rms.py index 93181c2bc95f..186ddd67c1cb 100644 --- a/models/experimental/mistral_24b/tests/test_vision_rms.py +++ b/models/experimental/mistral_24b/tests/test_vision_rms.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + from loguru import logger import torch diff --git a/models/experimental/mistral_24b/tt/model.py b/models/experimental/mistral_24b/tt/model.py index ebf987a4e511..764c12bf3a1d 100644 --- a/models/experimental/mistral_24b/tt/model.py +++ b/models/experimental/mistral_24b/tt/model.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + """ This is the end-to-end pipeline for the Mistral-Small-3.1-24B-Instruct-2503 model. 
@@ -51,9 +55,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - # self.embed_scale = args.dim**0.5 tokens_embd = self.embd(tokens) - # tokens_embd = ttnn.multiply(tokens_embd, self.embed_scale) pixel_values = kwargs["processed_inputs"]["pixel_values"] input_ids = kwargs["processed_inputs"]["input_ids"] @@ -73,7 +75,6 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag input_ids = torch.nn.functional.pad( input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 ) - # image_features = image_features.squeeze(0) special_image_mask = (input_ids == 10).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(tokens_embd) image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) diff --git a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py index 5656ab0232e4..f5b69bf578da 100644 --- a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py +++ b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py @@ -2,6 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 +""" +This file implements the Vision Tower submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + import ttnn from models.common.lightweightmodule import LightweightModule from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch @@ -139,7 +143,7 @@ def forward(self, input_tensor, image_sizes=None): patch_embeds = ttnn.concat(reshaped_patches, dim=0) # ln_pre RMS Norm - mode = "prefill" # if self.max_seq_len <= 32 else "prefill" + mode = "prefill" patch_embeds = self.ln_pre(patch_embeds, mode=mode) # # positional embeddings diff --git a/models/experimental/mistral_24b/tt/pipeline/vision_model.py b/models/experimental/mistral_24b/tt/pipeline/vision_model.py index 19f0b86478e7..ebc816a71279 100644 --- a/models/experimental/mistral_24b/tt/pipeline/vision_model.py +++ b/models/experimental/mistral_24b/tt/pipeline/vision_model.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + """ This is the end-to-end architecture of the Mistral-24B vision model. diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/experimental/mistral_24b/tt/vision_attention.py index f342365e4694..ac096594ce42 100644 --- a/models/experimental/mistral_24b/tt/vision_attention.py +++ b/models/experimental/mistral_24b/tt/vision_attention.py @@ -1,7 +1,10 @@ # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file implements the vision attention submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" import torch import ttnn @@ -139,7 +142,6 @@ def pad_head_dim(weight, heads_out=True): dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name("wqkv_sharded"), ) self.wo = ttnn.as_tensor( @@ -153,7 +155,6 @@ def pad_head_dim(weight, heads_out=True): memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.dtype, layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name("wo_sharded"), ) self.scale = self.head_dim**-0.5 diff --git a/models/experimental/mistral_24b/tt/vision_conv2d.py b/models/experimental/mistral_24b/tt/vision_conv2d.py index 0b16dca7fbcf..ae9e14e8c3b9 100644 --- a/models/experimental/mistral_24b/tt/vision_conv2d.py +++ b/models/experimental/mistral_24b/tt/vision_conv2d.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -10,7 +10,6 @@ class TtMistralConv2dPatch(LightweightModule): """Conv2D Patching layer. - Column parallel over unfolded input. Arguments: in_channels: Input channels. out_channels: Output channels. 
@@ -61,10 +60,6 @@ def __init__( weight = state_dict[f"{state_dict_prefix}weight"] if weight.ndim == 4: weight = weight.reshape(out_channels, -1).T - # pad_len = nearest_32(weight.shape[-1]) - weight.shape[-1] - # padding = torch.zeros(self.out_channels, pad_len, dtype=weight.dtype) - # padded_weight = torch.cat([weight, padding], dim=-1) - # padded_weight = padded_weight.permute(1, 0).reshape(1, 1, -1, self.out_channels) self._linear_weight = ttnn.as_tensor( weight, @@ -87,11 +82,6 @@ def forward(self, x: torch.Tensor): x = self._unfold(x) x = x.permute(0, 2, 1) - # Need to pad the last dimension of x to be a multiple of a tile - # pad_len = nearest_32(x.shape[-1]) - x.shape[-1] - # padding = torch.zeros((x.shape[0], x.shape[1], pad_len), dtype=x.dtype, device=x.device) - # x = torch.cat([x, padding], dim=-1) - x = ttnn.as_tensor( x, dtype=ttnn.bfloat16, diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/experimental/mistral_24b/tt/vision_mlp.py index 61ac96c3ed45..d79a85f1118a 100644 --- a/models/experimental/mistral_24b/tt/vision_mlp.py +++ b/models/experimental/mistral_24b/tt/vision_mlp.py @@ -1,7 +1,10 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" import torch import ttnn @@ -75,9 +78,6 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: w2 -> down_proj """ - # if x.shape[-2] >= self.args.prefill_len_cutoff and mode != "decode": - # x = ttnn.reshape(x, [1, x.shape[-2] // self.args.prefill_len_cutoff, self.args.prefill_len_cutoff, -1]) - # Linear with SILU activation w1_out = ttnn.linear( x, diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/experimental/mistral_24b/tt/vision_mmp.py index a54f3057270a..3454d7bedfde 100644 --- a/models/experimental/mistral_24b/tt/vision_mmp.py +++ b/models/experimental/mistral_24b/tt/vision_mmp.py @@ -8,6 +8,10 @@ import ttnn from ttnn import ConcatMeshToTensor +""" +This file implements the Vision pixtral image submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + class TTMistral3PatchMerger(LightweightModule): def __init__( @@ -25,11 +29,6 @@ def __init__( self.spatial_merge_size = 2 # TODO Handle in Model_config spatial_merge_size self.patch_size = args.vision_patch_size self.args = args - # self.patch_size = ttnn.from_torch( - # torch.tensor(args.vision_patch_size, dtype=torch.int32), - # device=mesh_device, - # dtype=ttnn.int32 - # ) def get_weight(name): return torch.transpose(state_dict[f"{state_dict_prefix}{name}.weight"], -2, -1) @@ -51,7 +50,6 @@ def as_tensor(name, dtype, is_bias=False): mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - # cache_file_name=cache_name(name), ) self.merging_weights = as_tensor("merging_layer", dtype) @@ -147,7 +145,6 @@ def as_tensor(name, dtype, is_bias=False): mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - # cache_file_name=cache_name(name), ) self.linear_1_weight = as_tensor("linear_1", dtype) diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py index 
e80d8f662856..647b65910c15 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py +++ b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 @@ -9,6 +9,10 @@ from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP +""" +This file implements the Pixtral_image_block submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + class TtPixtralImageTransformerBlock(LightweightModule): def __init__( diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py index 85408be02b9f..b20b6a912999 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py +++ b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 diff --git a/models/experimental/mistral_24b/tt/vision_rope.py b/models/experimental/mistral_24b/tt/vision_rope.py index d356e8172807..8efec81d3cdc 100644 --- a/models/experimental/mistral_24b/tt/vision_rope.py +++ b/models/experimental/mistral_24b/tt/vision_rope.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 @@ -6,12 +6,12 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.common import precompute_vision_freqs +from models.tt_transformers.tt.common import precompute_mistral_vision_freqs from ttnn import ReplicateTensorToMesh def compute_gather_cos_sin(dhead, max_patches_per_side, theta, scale_factor, orig_context_len, position_ids): - cos, sin = precompute_vision_freqs(dhead, max_patches_per_side, theta, scale_factor, orig_context_len) + cos, sin = precompute_mistral_vision_freqs(dhead, max_patches_per_side, theta, scale_factor, orig_context_len) return cos, sin diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index d4e0e7cea40a..7dff033acdd8 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -307,7 +307,7 @@ def apply_scaling_vision(freqs: torch.Tensor, scale_factor: float, orig_context_ return freqs / scale_factor -def precompute_vision_freqs( +def precompute_mistral_vision_freqs( dim: int, max_patches_per_side: int, theta: float, scale_factor=None, orig_context_len=None ): # Compute base frequencies diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index e79f633b54e3..9786babd5a1c 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -571,6 +571,7 @@ def __init__( "Phi-3.5-mini-instruct": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "QwQ-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, "Qwen3-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, + "Mistral-Small-3.1-24B-Instruct-2503": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, } try: max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.base_model_name][self.device_name] From 733e9b35d1e9080892d5d88aa7bf58e48f7012c8 Mon Sep 17 00:00:00 2001 
From: nikileshx Date: Fri, 22 Aug 2025 12:35:29 +0000 Subject: [PATCH 04/18] Register mistral-24b --- models/tt_transformers/tt/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 9786babd5a1c..41c58202668e 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -571,7 +571,7 @@ def __init__( "Phi-3.5-mini-instruct": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, "QwQ-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, "Qwen3-32B": {"N150": None, "N300": None, "T3K": 64, "TG": 128, "P150x4": 128}, - "Mistral-Small-3.1-24B-Instruct-2503": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, + "Mistral-Small-3.1-24B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128, "P150x4": 128}, } try: max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.base_model_name][self.device_name] From 4e6b224bccdaaba93aa0feeb64792ea782bec526 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Fri, 22 Aug 2025 13:56:15 +0000 Subject: [PATCH 05/18] Address comments and refactor comments --- .../tests/pipeline_tests/test_vision_model.py | 2 +- .../mistral_24b/tests/test_vision_mlp.py | 1 - .../tt/pipeline/mistral_vision_tower.py | 1 + models/experimental/mistral_24b/tt/rmsnorm.py | 6 ++ .../mistral_24b/tt/vision_attention.py | 9 ++- .../mistral_24b/tt/vision_conv2d.py | 9 ++- .../experimental/mistral_24b/tt/vision_mlp.py | 7 +- .../experimental/mistral_24b/tt/vision_mmp.py | 9 +-- .../tt/vision_pixtral_image_block.py | 2 +- .../tt/vision_pixtral_transformer.py | 5 ++ .../mistral_24b/tt/vision_rope.py | 8 +- models/tt_transformers/tt/load_checkpoints.py | 16 +--- models/tt_transformers/tt/model_config.py | 77 +++---------------- 13 files changed, 53 insertions(+), 99 deletions(-) diff --git 
a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py index 761a177acaee..2939816b1dbe 100644 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py +++ b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py @@ -82,7 +82,7 @@ def test_mistral_vision_model(mesh_device, reset_seeds): model_args=model_args, ) - tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) # [0] + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ :, : tt_output.shape[-1] ] diff --git a/models/experimental/mistral_24b/tests/test_vision_mlp.py b/models/experimental/mistral_24b/tests/test_vision_mlp.py index 6097605051d9..849b95673058 100644 --- a/models/experimental/mistral_24b/tests/test_vision_mlp.py +++ b/models/experimental/mistral_24b/tests/test_vision_mlp.py @@ -10,7 +10,6 @@ import ttnn -# from models.tt_transformers.tt.mlp import MLP from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP from models.tt_transformers.tt.model_config import ModelArgs from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py index f5b69bf578da..7a244d83543f 100644 --- a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py +++ b/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py @@ -4,6 +4,7 @@ """ This file implements the Vision Tower submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This pipeline constructs the vision tower from vision model architecture. 
""" import ttnn diff --git a/models/experimental/mistral_24b/tt/rmsnorm.py b/models/experimental/mistral_24b/tt/rmsnorm.py index 65f12713999d..0aa7cec84448 100644 --- a/models/experimental/mistral_24b/tt/rmsnorm.py +++ b/models/experimental/mistral_24b/tt/rmsnorm.py @@ -1,6 +1,12 @@ # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 + +""" +This is the modified version of the rmsnorm for the Mistral-Small-3.1-24B-Instruct-2503 model. +We introduced the `simplified_rms_norm` function to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" + import ttnn from models.common.lightweightmodule import LightweightModule diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/experimental/mistral_24b/tt/vision_attention.py index ac096594ce42..57d9b1022da4 100644 --- a/models/experimental/mistral_24b/tt/vision_attention.py +++ b/models/experimental/mistral_24b/tt/vision_attention.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 """ -This file implements the vision attention submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. - +This is the modified version of the vision_attention for the Mistral-Small-3.1-24B-Instruct-2503 model. +We introduced the `apply_rotary_pos_emb_vision_tt` function to llama_image_attention to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. 
""" -import torch +import torch import ttnn + from models.common.lightweightmodule import LightweightModule from models.utility_functions import is_blackhole, nearest_32 @@ -162,7 +163,7 @@ def pad_head_dim(weight, heads_out=True): def forward(self, x_11SH, position_embeddings=None): seq_len = x_11SH.shape[-2] - MAX_MM_SEQ_LEN = self.configuration.VISION_MAX_MM_SEQ + MAX_MM_SEQ_LEN = seq_len if seq_len > MAX_MM_SEQ_LEN: x_11SH = ttnn.reshape(x_11SH, [1, seq_len // MAX_MM_SEQ_LEN, MAX_MM_SEQ_LEN, -1]) diff --git a/models/experimental/mistral_24b/tt/vision_conv2d.py b/models/experimental/mistral_24b/tt/vision_conv2d.py index ae9e14e8c3b9..4dc76f9f5ada 100644 --- a/models/experimental/mistral_24b/tt/vision_conv2d.py +++ b/models/experimental/mistral_24b/tt/vision_conv2d.py @@ -2,9 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 -import torch +""" +This is the modified version of the vision_patch_conv2d for the Mistral-Small-3.1-24B-Instruct-2503 model. +We have modified the llama_patch_conv2d to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" +import torch import ttnn + from models.common.lightweightmodule import LightweightModule @@ -57,7 +62,7 @@ def __init__( self._unfold = torch.nn.Unfold(kernel_size=self.kernel_size, stride=self.stride) - weight = state_dict[f"{state_dict_prefix}weight"] + weight = state_dict[f"{state_dict_prefix}_linear.weight"] if weight.ndim == 4: weight = weight.reshape(out_channels, -1).T diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/experimental/mistral_24b/tt/vision_mlp.py index d79a85f1118a..30c84ea94f03 100644 --- a/models/experimental/mistral_24b/tt/vision_mlp.py +++ b/models/experimental/mistral_24b/tt/vision_mlp.py @@ -1,13 +1,15 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 + """ +This is the modified version of the FeedForward for the Mistral-Small-3.1-24B-Instruct-2503 model. 
This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. - """ -import torch +import torch import ttnn + from models.common.lightweightmodule import LightweightModule @@ -48,7 +50,6 @@ def as_tensor(name, dtype, is_bias=False): mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - # cache_file_name=cache_name(name), ) # Weights and Biases diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/experimental/mistral_24b/tt/vision_mmp.py index 3454d7bedfde..6e88dbf65680 100644 --- a/models/experimental/mistral_24b/tt/vision_mmp.py +++ b/models/experimental/mistral_24b/tt/vision_mmp.py @@ -1,6 +1,9 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file implements the Vision MultiModalProjector submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +""" import torch from models.common.lightweightmodule import LightweightModule @@ -8,10 +11,6 @@ import ttnn from ttnn import ConcatMeshToTensor -""" -This file implements the Vision pixtral image submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
-""" - class TTMistral3PatchMerger(LightweightModule): def __init__( @@ -26,7 +25,7 @@ def __init__( super().__init__() self.device = mesh_device hidden_size = args.vision_dim - self.spatial_merge_size = 2 # TODO Handle in Model_config spatial_merge_size + self.spatial_merge_size = 2 self.patch_size = args.vision_patch_size self.args = args diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py index 647b65910c15..66a010a35af8 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py +++ b/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py @@ -10,7 +10,7 @@ from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP """ -This file implements the Pixtral_image_block submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This file implements the pixtral image block specific for the Mistral-Small-3.1-24B-Instruct-2503 model. """ diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py index b20b6a912999..7e45e9ff8573 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py +++ b/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py @@ -2,6 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 +""" +This file implements the Vision Transformer submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. +This pipeline iterates over the pixtral image blocks to generate the image embeddings. 
+""" + from tqdm import tqdm from models.common.lightweightmodule import LightweightModule diff --git a/models/experimental/mistral_24b/tt/vision_rope.py b/models/experimental/mistral_24b/tt/vision_rope.py index 8efec81d3cdc..bb299dc4ca07 100644 --- a/models/experimental/mistral_24b/tt/vision_rope.py +++ b/models/experimental/mistral_24b/tt/vision_rope.py @@ -2,9 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 -import torch +""" +This is the modified version of the RoPE for the Mistral-Small-3.1-24B-Instruct-2503 model. +We have modified the compute_gather_cos_sin function of RMSNorm to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. +""" +import torch import ttnn + from models.common.lightweightmodule import LightweightModule from models.tt_transformers.tt.common import precompute_mistral_vision_freqs from ttnn import ReplicateTensorToMesh @@ -71,7 +76,6 @@ def __init__( def get_rot_mats(self, position_idxs, return_rot_idxs=False): device = self.device - # return self.cos_matrix, self.sin_matrix # If position_idxs is a torch tensor, get the TTNN version of it if isinstance(position_idxs, torch.Tensor): rot_idxs = position_idxs.unsqueeze(0) diff --git a/models/tt_transformers/tt/load_checkpoints.py b/models/tt_transformers/tt/load_checkpoints.py index 5b658b9dfca9..ddbcccf41557 100644 --- a/models/tt_transformers/tt/load_checkpoints.py +++ b/models/tt_transformers/tt/load_checkpoints.py @@ -595,6 +595,7 @@ def map_hf_to_meta_keys(loaded_weights): ("o_proj", "wo"), ("q_norm", "q_norm"), ("k_norm", "k_norm"), + ("patch_conv.weight", "patch_conv._linear.weight"), ] return replace_keys(loaded_weights, replacements) @@ -613,20 +614,9 @@ def map_vision_meta_to_hf_keys(loaded_weights): ("wk", "k_proj"), ("wv", "v_proj"), ("wo", "o_proj"), + ("_linear.weight", "weight"), ] - - extra_mapping = [ - ("attention_norm", "input_layernorm"), - ("ffn_norm", "post_attention_layernorm"), - ("attention", "self_attn"), - ("feed_forward", "mlp"), - ] - - model_name = 
os.getenv("HF_MODEL") - if "Mistral" in model_name: - mapping = base_mapping - else: - mapping = base_mapping + extra_mapping + mapping = base_mapping return replace_keys(loaded_weights, mapping) diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 41c58202668e..bb94d3e7c251 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1625,9 +1625,8 @@ def _set_params(self, checkpoint_dir): ) def _set_vision_params(self, vision_config): - vision_config = config.get("vision_config", config) - self.vision_chunk_size = vision_config.get("vision_chunk_size", 896) + self.image_size = vision_config.get("image_size", 896) self.vision_max_num_chunks = vision_config.get("vision_max_num_chunks", 4) self.vision_num_cross_attention_layers = vision_config.get("vision_num_cross_attention_layers", 8) self.vision_dim = vision_config.get("hidden_size", 1152) @@ -1662,12 +1661,6 @@ def _set_vision_params(self, vision_config): # Optional tuning knobs self.vision_max_num_tiles = vision_config.get("max_num_tiles", 4) self.vision_n_global_layers = vision_config.get("n_global_layers", 8) - # self.vision_max_num_tiles = vision_config.get("max_num_tiles", 4) - # self.vision_n_global_layers = vision_config.get("n_global_layers", 8) - - # # Optional Meta-specific knobs - # self.vision_max_num_chunks = vision_config.get("max_num_chunks", 4) - # self.vision_num_cross_attention_layers = vision_config.get("num_cross_attention_layers", -1) def _set_hf_params(self, checkpoint_dir): def merge_text_config(base_config): @@ -1741,7 +1734,7 @@ def is_vision(self): return self.vision_chunk_size > 0 def get_state_dict_prefix(self, module_name, layer_num, is_vision=False): - if self.is_vision() and "Mistral-Small-3.1-24B" not in self.model_name: + if self.is_vision() and self.model_name.startswith("Mistral") and "Small-3.1-24B" not in self.model_name: text_prefix = self.state_dict_text_prefix else: 
text_prefix = "" if not is_vision else self.state_dict_text_prefix @@ -2384,20 +2377,6 @@ def reference_vision_rms_norm(self): layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer - def reference_vision_rms_norm_qwen(self): - model = self.reference_vision_transformer(wrap=False) - layer = model.visual.blocks[0].norm1 - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) - return layer - - def reference_vision_rms_norm_qwen_merger(self): - model = self.reference_vision_transformer(wrap=False) - layer = model.visual.merger - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim)) - return layer - def reference_rms_norm(self): if self.checkpoint_type == CheckpointType.Meta: from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm @@ -2463,30 +2442,19 @@ def reference_vision_model(self): layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer - def reference_vision_model(self): - model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - return layer - - def reference_vision_mlp(self): - model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder.layers[0].mlp - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - return layer - - def reference_pixtral_image_block(self, layer_num=0): + def reference_vision_mlp(self, layer_idx=0): model = self.reference_vision_transformer(wrap=False) - layer = 
model.vision_tower.transformer.layers[layer_num] + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + layer = model.vision_tower.transformer.layers[layer_idx].feed_forward + else: + layer = model.vision_tower.vision_model.encoder.layers[0].mlp layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer - def reference_vision_mlp(self, layer_idx=0): + def reference_pixtral_image_block(self, layer_num=0): model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.transformer.layers[layer_idx].feed_forward + layer = model.vision_tower.transformer.layers[layer_num] layer._load_state_dict = layer.load_state_dict layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer @@ -2510,8 +2478,6 @@ def reference_siglip_patch_embed(self): layer = model.vision_tower.vision_model.embeddings.patch_embedding # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_pos_embedding(self): @@ -2519,8 +2485,6 @@ def reference_vision_pos_embedding(self): layer = model.vision_tower.vision_model.embeddings.position_embedding # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_embedding(self): @@ -2542,22 +2506,6 @@ def reference_vision_layernorm(self, layer_name="layer_norm1"): # layer.load_state_dict = lambda x: 
layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer - def reference_vision_attention(self): - model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder.layers[0].self_attn # Common naming - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - return layer - - def reference_vision_layernorm(self): - model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder.layers[0].layer_norm1 - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - return layer - def reference_vision_attention(self, layer_idx=0): model = self.reference_vision_transformer(wrap=False) if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: @@ -2581,15 +2529,10 @@ def reference_vision_encoder_block(self): layer = model.vision_tower.vision_model.encoder.layers[0] # layer._load_state_dict = layer.load_state_dict # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) - layer._load_state_dict = layer.load_state_dict - layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) return layer def reference_vision_encoder(self): model = self.reference_vision_transformer(wrap=False) - layer = model.vision_tower.vision_model.encoder - # layer._load_state_dict = layer.load_state_dict - # layer.load_state_dict = lambda x: layer._load_state_dict(convert_vision_meta_to_hf(x, self.head_dim)) if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: layer = model.vision_tower.transformer else: @@ -2656,7 +2599,7 @@ def reference_attention(self): 
"Gemma3Attention", ) wrapper = HfAttentionWrapper( - layer, self.head_dim, model.model.rotary_emb wif use_position_embeddings else None + layer, self.head_dim, model.model.rotary_emb if use_position_embeddings else None ) return wrapper From 1b4f88b7ed163859c5bebfbd495e9cf38a3f78d1 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Thu, 3 Jul 2025 12:38:52 +0000 Subject: [PATCH 06/18] Add Support for mistralai/Mistral-Small-3.1-24B-Instruct-2503 model --- .../pixtral_transformer_inputs/demo_small.jpg | Bin 0 -> 8554 bytes .../pixtral_transformer_inputs/people.jpg | Bin 0 -> 49606 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 real_inputs/pixtral_transformer_inputs/demo_small.jpg create mode 100644 real_inputs/pixtral_transformer_inputs/people.jpg diff --git a/real_inputs/pixtral_transformer_inputs/demo_small.jpg b/real_inputs/pixtral_transformer_inputs/demo_small.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f51ba21be8d4cbeb5faceca2b88150b50926abf2 GIT binary patch literal 8554 zcmb7pWl$Z>*6!f$?zVA)hTv|&-Q7Zvjk^beLvVL@*WfO}0&E~ig1ZC=5S(wn=hUrp zeth><-PPS|o>}Xe)lywO-PNx@UN-@31zCAn00ssIcn?j$>o&{}c_}GVH4Rl+c_kUB z0{~#g0jR)p0sz?2-Bm+Ql3GvSfEsBNfPtoeiMgel^S{yms-djgrGIl5-u$nO|F;pv z%G%8mO1Xd*kSnxuXq)h%7~kgKnCTyE@oy~r5B73*c87A*{=u$V8d6Yf1;tD@|Aj67 z3tKw7{=<)ja)iMS9{<$!kNnd+bZbX#P3RLDTF3!6Km(8iB>(9@^c^ZrMF7Bm0|0Od z|FKzQ0zi8h01&VK#|A0@0Ibgd&^G%Y+keKy+1%CqU*X`P8ur5n0Jtg#01SNq!21pW zsD}UYLDm1IZPZW~Ika6)&}0KR0M-CCAP+bKmH;~xaRZzH7x3w=8W{VB1qk$_Ip z!gW9@WbT3x#3QY-W!dyL=@J@(35 zbI{GsOm`$M(iwA1DYg=?Z2aIdpS08q>17c{MX6P}$DjOBePhr-rfMhAJZsm>l0o2n zw1=sCQ9`VcA0>LJqiKDnEN=&_YhR}HS8wAxDLs&`G*hOs-6l>X=u&aX$EzT(=!mh+ zxoGtl^_J#Ak}c0PgLGI|)2KC%*wRcO11zv3hJH?V*0u4adZ5O_lIAfx zvqVe;#HF$MY)-Ao)x_0IsN2jJb_Dx7mBMne2G~)~?qh$Tmc*fRL3hr)0gRjvJn3kt z9h;rGcol*O)1<8rF<72+rxlEs65Xw{DkMGAXnV6}7;l2AdDg0;^JbXqY5 zzjY$3uNYzQ6CerlgaN>7*L7^uZf2c@aNw+F3GZ=m(pTNljRkB|V*m_fF$yzM$xe86r=kcUB!$FeQLg1d{Lak}CDprK 
zB@WK@<#|d@j1kz`C}oWGciZZs#D>)K+n9d$==s1DvFIRMO7xpN0ib>-PBL+{P*Kx} zzBkZ;+$R%j9aVG}QPo8u4A-_;Y%%nGImMstLC@uY@~*PzF!^JaLz|Z-GCTN`)_Jb} zQh_A5fv9q37~N~ykLJx4BE&;YQFe0$?3O3*tW2S(nAsF!FML~OGA2?rdK*f?oZeDn z&N9wb1?hAaK|o^9le23TDHBq=GKg3BmF7tL>FEPA#EM#GeO&XzSJbbqOfvPr~47iBDF>zE$#B|g^4 zDQwEAceRs+`roiv>s3aJ$B@^@UIFD@Q_GIB4Qa^PO#geNM#PRlM%vCgVrlsyCL%;L z>|3E|f(~VpCLFQKp!sZ1h}m&doroM~=}`TK!HwER`mSw=WbkM@6)LN{1K+oN6GAJv zEqLGPk)6tI=(zOVyU46w?zPMiD!So_hS!rVB9<*g=%$lrh<23pA&;?KgC1-3PFqjN z8e57*!^^<)C6_QbPEXRsw36^DA4jbrV{y5ShTuvCd>FVkL z)5&MFNq;oLbO2CBoqZxo{fbAvaJO6x4V^hfotxv0l@H7`98g-t4&b=`Ahd?RC8^JN zsE@q3G>w*ux(gqrFjDkRHX1zip%)k5-9yKM{wTYD0`R#fU+?VF6w*5@vU$Iv-*~;a z?GtgVX|^#$0@E92U43&q5<9(=(0NKT$FC~+H9(Rx{ExavYn1m z5sXrod@`9zJ;UQ1MpT1!`H44+;T6rb9b>E3-aDWm6+7LbJ`!P}OPTq#h?~g|J)?-( z&cZRDPAh)wk5yUuEla%u5!*sn0oyIwH}(GK_mR$BwF7^j-sI8?!AphFk0~0iNYVcl z3~l6X_Q$$_q&&dexpFgke#aCgq-~{DYeRg?xK!-2fWM%KXXB!(STq%U)r!K9$$ewA zGoJxMM#Ht@$~kAUr85He*!iLOhMB7;uNjZZYdHI98ax)Tsj3o-R0(eBxj&s?jBJw& zTg4klsmwM~#qkTk*v03mZc%G_3;`-50V4`V=A)=Kd*bqrSU()7WXoHRWUAQp(91s{ zt$t2dHSaxo(5gugY%^hdEM+EnKj8E=%nF4q#h`$4y1}*ZMT;*%{USSvRlD1;8a)^T z%f*~5Cl#2B(l#@RGbBN}vpL64Z!=#s{Ebg$pVjR&O)xL$rN8sCgMZVGXD2h3jnOi> zUX1fxJ&hps*X}IPI9mW*tNmvxINn7`zIb&!kZD15`%56`Gv9CbBdlt%mM|_5yIMm2Ml4g%p;)z!v@P_<}?&C82Jt$tL+RoILk4%n}<0waSLq}d0 z=PLYEqUb)GW|Kf!n9+`m##q$;CFO|}hJC;XHY+v-$4y$^a(*`Z)gh zs5_3KU${vpQML7@ZA}o$ErV57D~mWcW`5ugMz~KP{rV0@LrH0N3m0*Jvx7VcK;l*y zsH}*NNA}_K4E`W56CmpSu%>+Bia~BTMseF1L<+wzG(z>8qHzPAj%uF(>DuFg`GTQz z3p3MeEDJxK^t^o5Yk$OrO%xdW$eAlxbXTVHo zS}{6J-x=(-mmU;ZDWa{7WNuPn_&!P-%kI~(EBI=)-R-u-{cUVfVMY=PyK~2E!qJ04 zho8ueHfB{-EQ}ePa(w1-H_>&?1BbJ2s4$BC$U0T!ag;YNQ*B(!0;1j!Yr9~cT7O9B z*AtVSO;a(o&1u+UdGxE<^0lcep70NENRw21Zw7mFy+|3dHv@jbp#=>5l>oSZ&mnM+q2jmD0?#R|+*( zL(*2ceNmf~DTB@+uYi2PhiFbwRvXIeIYIT$RZnT^R9lFl*(yPjBJ zYfOL_GQ)(GQY=_t14!v|(6u~9BfGY1aqp$>Ca74+=;@K)FcwdJpqdYEkA z&T3B^erW0!3cY#@N|Rf-E^jNYg4R%fIx2NxEz+eU4`yU1wprJXtC=l)V$K9(3TjFp zDoL{*)1SZd&VTqI^?YPve7G^`_6*PK^V{e_-I)<<7R0ELtj1NrV2B*sm^QnD(ZJYe 
z|G6JFvTt{g3^5|+5?j2zrG39*sb5~NQVy6xxorDNFz}BpJ zbXY`}#w&z3vOdq&eFDr07zB9~KwzH{ZV zyDfc9O}OhMbv`SS%_Puyj+zC_iMO}jhc}Z_cgzqB+ALu z9uwY8%}og*TdZL zCI9n0l9)SDBstPxFU`^5WTg9=vLycy@&^qKJwY|+5TZVr*EhGgHOB@NR@e?#Srz# z4L0sKq2$alq!QsinVoj08&SYSuL;P2Q|iO(*9!;^a2pR2 zyQz})*DabV)w$87Ae0;XTiaGxK27PriGNO3z#Bibk7`SA8sxNol?OTI-Bl;U+2r^pT88CYq`^Qm-NOkSS!F+M7`xu>i(y~E2;0BsQ&H;csePClB!{!@?p_tgNzLibWUAv8h{WaG|Hk_Y?23=59y}jv^(^XHXn)Q>;(wZUIY)kw zyO_nLV&8rUB2N=j(Q9Cvp^jjf}0Hs-`x7)|)*rYL%f^ z#u-rSUs&4tzK@k?%qyI76E~)EDplIMLn}O()zp=WeSg_*IgQ61YZF)yof4thPM4}9 zPxcB>@uED1X`6^Bap!rb@JW#+h&_Up0~FPP`1*UZ8?4LFJJkLy7oK6!6COg=12@@JfsC|< zh`nPxF~;&oQeS211u|b>XBv`|(`p;qPu}lOnuUzT;>cs`5R9ex5<8m|75x0|ydRdY zK&yDqevKTnqV5_MnsktY{mE)9|H0>qd$5ufH~Q++Q&0(fvmD=P6vSVsIZ;uHJ)7fN zX^wAGf+|;Ha_eRa#M_-DD;P>%wHa_vz|a{6-Cea~yckG7Tp3Cn%aA)r`8yM47fz!&K1Lqg2Y3!nSXVZB`k2bub>1XbhZ0ac3=QVLt<$Dk{RRpu?{)W6X z8T6F8i8f~n$C`9uoHtRlXqB0QFf>Rv(-OZvrsH(r2J?L1g}W!k#pH}vq1U-ur6o(L z%+(4Fc3Epg`i(C$i!EOz*(`6500GZilgv&;7)a^s4Q)&@Ql)sGTLox=_mgo4hnuM! z1I;DgLq|sSy@+F0xz(cC&k^vOmUci8{9BKqBFBoQJ&;zo>abX zf6x{`@{o>x>|$!Go4mSC!|2E-Ak-U@9_1J>XYRpgQhTSwup)TLjgpjJ*gF!yz*k> z5f7Wnp9aR2e<(keJK9?9GiG*i9WBNBv!4n~{<@gZy@%tnrZf?#`b57>LQJ-uBUyav zhE9H7u!Ih{7E^ZA-a8-z2O=WXbL<^4Z2R<~aFg*SAt`emLtyGLjkT)N9 zr?pJ~s?=d}J4-wrCQe-@&2!L@S=O?@U!fknKt9eSkB{9pr@l6>mbkQu_Nb337f-Hk^fotY^vS z;j)gSJ1|P9>sigXNG?EZK}MJD^t{Ah7R}?yiA9}p9F@m*ki4%I`*2>0(bBekd$yVG zLra8?Gy4a9^ZAC%y4udy;l(2ia*Mn9#ZUL+8j@qNFY|c>OX>nI3Pn1Iz8RYtNj>@( zrC+vUweo{J9M)_pC!$^qe37ZMqpNaTYt1y-5JS1|oY-pIIeE4J06#Bp>0Ri^DnU(y zzWv{Z!IN)lef)6uKF!|(mJth8v5#QT$U-_q4Bov03ST;RD7ZkVQM`kBHPx|!{1p~o zB*LcPy;&Wa8n*4nMd?+J{SAkGA+kbejxWdD63lOg2pEZva&8DiSAZhukHl4Wjx zZs&>unUy^3@Xrigloc$t+=KQWhTt~-I`mak+S|H^vCRJ(p3K=)DVXCAt&oHvrKVoc zZkj?AO!Ctbc?C#}j&zGeaQ6*0Q6@Uu5{hj);3*bc2C9Kl`goGO zXx{gaatit`N;nKjpT?pv79Pv!s8H;JhcN~dw%K_W?{dUDO( zv@~W+u*q_x&w^b(tvU5(C*kTl%OZ$O6SUaj?YQN+^+LoEW8SiYI@VK@I(_`h=3uK+gO*7%`FWilFWVFs+{`of$%NG$~;a^|Vre}PRk z!#ILVf8Jk{@p`>y{mpmN*^on1R6;&h7ZFKJxtRZb;c 
zkbx+J;_lVTV5dU~v`yCq)bNsfjG6-0WU=Qc#9|z05~Ty9XnC4P8Vtj;lw*HwAzr3` zUsB-PoQOF>W4C(Z^r&N4_!NDn60SJMXJqYkKl4`|eD>%vmR%C^+5I%+qxG`R z3ch-t!>o~Ibi1_GFPYA_>9;eW=#Oh2V(Y$NK7&pAFN%*uiZ3yvceK8|5P<8?jSD(3#T}44`BQ85l(r#8`G1|En$*h5sFzBh_a4#PERcUKCfpl z*Jo$&Ow7;lS5lFGn&xENyjqO9Gc+3x<}n$#p@ovC@@@CDS;!;uayZh*zck<`!Tt_ClgVxTZB_93_dgFV&#NIJ7_oG;G+dzts zM#`SIS(INZMp(=`@>MYfu5&S!R0RG zy~+_7dd|-suHW8d>C;nV#!d}W_N3_CrH5%s`v_0VJJuzM*OD??V2#X9NnDThXf2^{ zAnc#2{tQ(uFOL@eBQFZdcTE(Y7wghsSF-_?c4y3bYJE%j8U zAaq%3;vY{%2JPa?wK{;4273}UxMEqbk@$QaWI?0v>A$={Ah~n`px?>)gSi&Jw9GMK zHt6c?48EPmKrkA136FimN9BL^yaF4gu%74kuB;&cAI3{W<^iD{2;^gfmh!?0 zD9|Dn8_8gS+wEM)efq6HNC=UW1~ai+AIFBSHGA|QICI%CDtr88o`#hRq*CtTV2K}3 zM@>@)=Bf9K!Jj+FW1OuVW$u%uOK@OE)ka3w$L+m@+fLERz?Yc&pXF>zsFT0`o{!+k zkSUWPDwnb6nIPZldI`i>OA`dIy|HwHMCtG|em>=bb|1 z`0c%##WW%Qs(9>U*xu1AP^_4=cKn6$vm0Z_4D|kmmsHwv;AhXn9CoK+rJS6GZ|F;^ zE=#{bOQ*+*R#6Z!Bc0f4zM(CJ{}iSyfGZ-Z#W7f#N8Qh%rg~en3F&Tym1Y6%-jg7$ zy=8c0o%A4cb8sj}h#1HTJyHSXndp}|P(8`uxled3(>Z9lv4&eij~&dx+sN)RW8Y2< zg}-_3-4=nL8vGlEMFYvN>E?q!``Cw|9zZn9w@?)weW?_R-f7i#iV@oz2+CoEw+58U z;SuW9)|JTUh$sX*qg2PEcsKi%+NWJz(j-S-7yZak8L1;s)%Y?dK#`KC*{8Js^{p<< z3lZGZn!pRhP(>U_k~m7%OInbl9}4g7t!5F7qN8qI_4VZ_dP2-03Mh;75>H~Tt;bP& zUk8VHtUijOL$Xe;Ej5-;;?Vu(`_7l8)vX(_B#BJU;YR1D6qu>u5EO)ztDV3*f}KFT zou6(Y90n(ni;nJJf4DY&njweF;l4KigP66Q_imYgG4v#B@mJL$v1ED@;nl>bkE=7! 
zryO4_h25PyZ`4Azcy=3FH)EkL?<_ZRFou^hzAm;)?g9Qu^D$|R+?z83 zk%>IRo)iLkZQHLFnTqNDfdT6HNrLjL#T83X9Ve^j*`A7!na3z|$vQRcM0L?G8wRd+ zaZY9*O)R#s48#@L!BlTNUwx;GoPyAPuavUCR9Ln2-fNJaG4?Edo_T_6>X2ti{rR7ZU;#g{ zQh<4yiCG6+yOGA~k&@YcJ_$d+*zf1)G!`QUP~0x`>Nv;hVQFW@ztuD*8;%bDsRIl5 z0GV$@1dOx63^C7lXT-BFc9GIE;M*Vk>zE}mJ>_*-c83S5CGKJing zrluss)-+f!q=ZON9A+O_+AxEXK*YbAnj)VSuu3strrT5w0uHnBvYB~$N1K`|4ht5+ zZJB-Lk6B8gWPeAjtz`)7JX@SjULfD{?=7!tYKkN^K3MeX!-po$M_KqD+(jI+va&>H zmlHYj^g`Hb{F}d?>FaCps`K!VblZGAxm(b2kdFU)q_3|crT}vTFBOb!X6Y*~_6Pi1 zUtj#KlF*y=parMIGM~t{82GiyMM(eT)UljJz+J8EhwUV&0P(Sp9$ GmHz>irf1Ip literal 0 HcmV?d00001 diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16dad8dcbf18374fbb920fa4ad944e7e9aef8b89 GIT binary patch literal 49606 zcmdpdWmp{B)+P{w1%g9xch}$q4>T@KH*QVSH0~ZD2@b&>0t9c|JrLa8Aq4l}1Omx) za?U;9z2AKEW1jgp`+=@qYp=bm-o2`7Rn@QgUu!4?YRan0D5%KBjmnLJ^6M)qtCA1Q z76nCBl@kRE1qI~+%41YClx1X35cyAu`Va*bxqXcMq58W|_M3og+RknsZk8|)dS_dD zcTa0;h`YP3rvp9I&DqoCH_gh^9p;22(z`&MT^u0vmLBwu&NdJ?C*+_l%o;-P;Oq>C zIYEC9SX(+m+$?$N|6Qh?rHhO2KT;eaPV|;eHuUb69x!)X7{unE5pU<;W9}Zv(Ld7M z=`HC!AeN52^ctQH9x#Z9ofFLZzs&mE8h1yS!yk5}%>UgGZlsUD-Tc?~*H;V;I}Z;R z2|hk2cV0^yXDbM=wX-9ikEIKr053luij=I6i={Oj;z4f(fx?`m8Nao3Fw(ct31t5_ znU7Hh2V2U`##TaCQTcBk~^ndGBfp~cMsybRiA?_-w z%KHCEK}zL?{vLz<8T%7(D9qO5KlIyJ|HJR%>E`f9tBo}u!~udtg@-$mR{*K%ujsls zc)I_?jr7ga9rBy>XO4r7Egw?h-?IOn@Y@e1A7qIB4)_1)7m|Tbke^?ShhLCKQ2ZZd zNMX7#sGY}uX#C%iT! 
zq5^^<0%F3y{rp?0gdxOA&%@H=PZa*L-yI3w-*|z!dpKLe|ET*DQU?gs(t*zbxgwq1 zAv~5)H^`q!vYsA}=8i~9JRPNx__9PUB34UB7s=mbcV|yGYlt+|*%|5p`G<;J#4eUF zsFU>XFd>fx(kEneB@`_^EHz+G$i?_pt zoNe7Ko&2QuenTBe`K$AXhwtBKBKtP~wJ`n?Am`xmFG;_lAmL#7+j<`!8;GqXGN~}i z{ClGM|1J6bw}}0T#qWsydz5gb{v2ENKdAX1;{PYVxLCTmL-c)JAku8QkiRg)Ci9;Y z4RqCIelM>76FC3C^dE3S9sdjLJpb(^!v8oh4Hd0cWXT!GDlIF>%PR*oO5m^*|sq zYKp2zlFNGi$ubCzqg~AD!?E5IMxO@3rT5E`Qndo;7seQdN~t2HhSJAjF~{MPl`O~g zA=ncd3b-q?9Nh~0!r3A(e!^ZtA~JY4{r z;2{q5tb%d*%k=>(0Fml`a9<2_MHtfHoQS+Gdo+BES;MB&LQ6%~2x~Ijk>0e#D0zB^ z!4f&4LKdsf=+#{q2GN7HO?S8Ibq^s9lY~jCm>k^gHB5qo-! zr0YuZz7Xze)%C*H58rV-FH`k^76|r7g>y0p!=F@QXi3|Zk`eVd?1WN`*5ymT!-ho8 z4^Al)j>h$F8_hG;!?1UB{klEV_8)gXH2v&1z(fAlpp@P9rk#>izudvZA-*sCJ@0cR zr(71JAjBTe@-^7<$w1xa!02r6 zTJGa9%C7J8heNAenH|C!s%jK|eU7%7Z@42g$T}A~cNCz{yd`>PEd&Z#W|Mr1m7X)3 zn$=^|j28LPbLSMZ%bSfE#$EiZ_E5H~7qHkG-Kj@d4b3yjO2k=Qz;TI&umB0=ChOK# zqaX4g_iGeA7A?=^YUmrk$}{RLw;irt5qDnTFi-;`$}-42l|2ztX8uombN3hw^y?CI zdUbv5;Tn!a=DMHU)$$c*84Jrwn2Okg>b!$P2jTculUwJjnIY9z-0n?t@3y8ZXP>Yn z39-oAeWi-JT~;I)@`Ot%H2s5U2e>S$yo+KI807)l_OXL?on&m@QeO6d!V^ zz3E4>51F=voJZ+HbM=hft6XG0}7o2pks6$N{uV6$0iXtT=(Ypb(L72>LggOcnd zz}amWAo6~i$My3VU_{7$+&P`26JSiEfzv;-lDJ3!ASZeL@IgY2ve6(s1z^C43!7M_ zP5ykq?M&9R(=hCf0=GAHLTXO;#{?!XOy?X{^|F7zUKp)zh+NHxu&wiSSTr+Cw%n*s ztB(AZalpFQ7d(=CK?nlS<_yb<_U1OV821p&vbU}(rcY2+&v75bNEUB8^%XysPU};t zi`(}P=(7sF<;dJP-&M2fj2x(Gx0xHFSkAVdv8=f za1rf-Crw0bUHi|EW_<54S2bE)rG4hWJej;sOH$c<&#K0oWpUsVv*PV08Bmi8!EWeI z9b|k=#XDp9z&xN^+&%G$_>^Fa<~)FRa-G{dZPCDz@?yxoHa=}#O zuZndbp*jgqmmGdJ_iAe;5EA*#n})ynY~meJ+GbSqO*4(}^19E}#KsqOYw*}AW_uu6 zRTpn|_Dor(+2@O$bf~~oGGZaRElVXiPa0+l^vIXSXzD^}ob2Xv#Ag;YgZ9f$B*shf zQf6%{lRPO$MyN0J0_t50cuA8NN5s7@7M?~)#x20f6CKaTwkEibJw?SYy#r+RCJ>9&HZK9|dQ$$U7R$B1uBNOGVFui`oo_T+R#!KzU)K*gUIM@OW6(q*? 
zc5tMl5iS>P%3mY^RK1q<92 zUbr-WHtpKRa+1Brc4WDrw8rCBaoWHOU+=QCdvTVr(w3b!5y18vXLy!P>ouPe>(b18 zVCWyR-}#GzxLx(4(k%#1d|LPNY8QT-5yCVXAA%@T zqXRpGnQ=@C1Zz4($|vGg0@h7%oerBv3%;-RjPhPeI@-+wRL8|_J+{;R7N%&#r>J(s zZncVw^>@m+hTD9tohfcszTW~`?dRKJt~m$PF;3EuOuSy<)_RGSv6>^hgA6&5Ec1$N zk^|TF7X|D>P@SjoI~QqNBap@l^__tswfd^OXawi&&J?RzjotW&VREihF=ChW`Oc*0 z&?CyT9x5Y$d0VQdnT^vsZ9HeQb5#348DD=xu%Dfwgvo6~yTTXdg3FqMX7~69BL{5C z@2U&&4ldTrGQRI0PGz3o-j~_Ixz(q#;`hcfRz;osR;)1ujIsh3MlZ%^?Yz#nmt7|> zz^nLXZoW@>W_!=S`{Ch&v*&mZVmNqped6V1-a@4_pYouKUx|8|P9De%Gl?gQ-`#RR zOTB5#58{Wdw%3-wov$9>P`uu^K6cM&Q*%y#R1qDyqw06}ZcfD4W2QIlU=YqWvfu6} zF~AGhJzo)T@@s=;?948=QA!LyJC~cvQogzJ?~G^Ef$a}(?7G3B9f<|LVD04Ua(r;L z*~bH5YYSUn8v@@aq0l){=(J*irY-82ErAyrAD`wknbn^ozR8va zz<&PpkaX4Kkn+=uSxAhsflGUox?b=_M=i(lwHlku3RSfR90l&^(Fo%)p>oU1&MBOE7c(~M z!q^Fl8{vbv9_c&?cRyqH3i?Q%O`TP1N5qcm>4mx)TGmllgOsy6Q19`+2%?fD%EsEO zfR#EC60wkaR;+`uVPTsvRX|WN2WLZ zER;tg$xRMn*hV>n!o};ITs=cQ&)Ib>v3z*EslBjvtX^N?3s~$2LY~W1W`7I{mIUTF`amd@x%Q7 z@SVK0?Ew;Y1>Z0a!YDutFYp@&OO0Qt7YidO{^hKbNu7xqRgJDGMEG?BQ9}F-TG}YE z#$@-(0cy%7^GbljIK-Q$3HsY@n*S}p()7on=yi0S}-xQbVL4kZe zOKE?d^?^vjmEF*++d@*lsu1_)tyQ9`L{Fp0mTQTU3g;^YeL^%|>b7TMAOl=K;Avr% zn$czgp`DZNyoM4&vnPLUBxq>n{oWWHrgn-J3jGW|7p%{e{WF{ztAe zWDdUU7S6;R#ZCS)#?{NG@|xroB8RnY0+znM%LiRePxCQ?$3M$|B&W?GAyzyG7Y}{d zsza0|H=W(GS0_z3yJmt<>ZB%~?+{;RW=M6?2(66~;37FB_4&P{cDrYd& z3l=jFz0!n(>rCv9eLaGE*2((VrYI-}to@e4Y&W6Wd=E<~N#wC#5^1}Ia&v~-*4v*? zEwBL&SSmTE(aREe0E13}UW8@mB1!VIB42n|ys?WXM?6w>)FbA`&T(l=o{vP7hz=k=Q#a4o5@7;`4mXaUKr^XuV!6o=k8{oGi@GaF1_s6Le$6*wHFDq;O`G&ET%U{*efPZ3WS zWvJ&?=7XmzE~1b7iBYS2NI7V#b!V#QLnY@)najPnddr|56D3hNIi3bO^6przn27T| zgV;xrv}uCc;(h%%3>;j>ef^*jJe%BD;LU)BhP~y2NV8|Wa$JvAM!)=}MGj{VdiAjf zY3n{yEJEH86P^JuNn8Cv6*mr6vUwo(IODiEm}p0>6-_6~F!&vF;;E>dZk29aG$ZAp z{e&?Yh7V)PiTt~_qXpF{;)Tvk*^b;7q$$#^2hLPJdmW!AV5)Atr=3^l^Gyt|rxQt{ zhYKp$=+C50jjD@>ldJMtN4L3&bifZ|vJOP2Cxp4!NWpN(&$V*4di0Z?^x;yx-C25{{V4vB5YA)QFE+P^Ol&<+AW;rxWQAi!> zV`9-Bd?GuyV;e(XK)^QyY&qOiV8M4xwk@1n8suVeHd&`zZT&+Eth4@A{!Iy5bT?T! 
zL4j3lzETx2LBhUa&gjRuv)Ph3x5ZJsm{Q*o&gU>)$|7zPyEs*5!e>AfG9CzzFU7Kz zLDN$1rTeU;m7dmszpzNUnvl)T-OnEt?+|?lJAAa?MzkGxdSy!o^q81*UMGKeX5Bgg z_gz5aOV2~BR&7X2dEq(P4m{i{llR#<-1bannNBkqVb)?5u1iY#Do@L8R7ygmWOXz1 z+TO=nXL$VdFwpH`tbW2QkuetwV93)`$aRrTA;xR%_y$A&#;fx9rma1|H0h0!WtxDS zUF>YRSH{JwSH00W@Qi*hx#f|)z8LbPI_wSJ#v_Xk(UgX;+yilrkL2NDd;$m# zs6ys?=YH-rQ{Qv!zM3MAT;4ba;p`XErEqNUwCl;r zQ`=0p)v{!#r(jw$RyjHCs<_k}gd7*JVz`cSC&rKp9 z*y7YbT?Y^=DvR)9%fRL`OL`EsHLYD5K(5b~@7{Uh3XGLk(8(cK;tk^>2uHuA-QRt8 zg-bvtMq`xGai<2PIgK9sbQ5DWM4H;r2y9;ZyzYx2rH}6mV<*m5I62JkZJkpf(yn45 z)9XvxHR4HD&@UYFvs#QZmWYr>ERKQXeJ-hUs0)PJcxjU$g1VFQL!$XkKAIU%R*f}! zT$MNuo8$|k)$Hise|rBqj9o3?*7hb|_hTaSsa2KkjP(o$IFP+{<>NxyuGQ9aX386> zdO%1n7!^?+&{xh*)TJhwZh$SXbK~*)?XGT3Uwwfj*uygp%%5rP@wIHMo=+1@<3rY) z?=IoM;bu|#tn(p4(cbP2DZ9Jzjzm^Xj}^8wUubprOjt=zO}>vuYm1O6HR$x=vmqmT z{YE$*w4^V7d!-3eqRxck`6W46w7QSB2FRn$lPECdp6!)iLJl^rAkcuF9qg}5-r!9U z_`I%;q2jF(*^J{ByfnGx<~8E|D1L8YJc2lV{;|GjTo z#{E(1nTQ!NO&+GgzA0-ta4hqB1MOz~v}6d*R}7Ear%RTfz^7%`j@vqndwzafyBO4z zl_B`Ug|%k0h}%g>wRJfyuAo4qAv1}d!r=6+dv@#M&;e7@M;3wNBs6HlY@HQ56BbcDlcqA-d5W`ui44`9!cgXREB#XI;0~KZ}JNTS^4HBKtV%8eeeJs z4Gmcy|6Tt^zDV$Z5S@UCm_gtnJ)@wI90`-KJSj8(Gh{Iw3t1vZeTZiG>eZLp;Vq%x ztyzqy7xNbi>7Skd`bpRl*VOy-lk{!N;++)eX?M3PI)T%~9X_jUX7o%Tt1IbFYE0V6Eos;^`>+x~g#sSs zrUEUioE4tsei`K^a4rhM;o*IuwQ+q|x>wKkgrB&mq|{CLType8Q;3*2OdtUliW#3tblKf%s(xSB(5ETVm&s~^_aI%06x@8oK znKuAE8I4m1-CzCqA@)N^T(dalJP}uE;pY%~O;>28BQ%=$&7)BI;6Rnhj;d4F)YLb_ zM_L<8dfxW64&!#GRLX}hHYn-TVLkgbD%K)C!O*j~Z+-PCwj;i$dY3`(AeKd`usc!`++lTqTl;At1!p30} zLLVk-H&Imu2#P&e{4^wZ@O)N|*I&ePI`sVZsP!vHkdDRT#@nBYKNSl`wdF8wJ4&u1 z7qYw2GYJHWA+|?tbm3+Yrp7SL8oNI^wwGCS2TJ1e^N)N*uA|MqKa|0D&Pix$Vz9WJ*)9;@ zj0*?~8J>(>DAJi6CChwjz$IT$uRzoFZbrdLOW$HaDpk+Lv(AMZJgoU4G6Z4Q!6o$9t-S?6^z2=tsqomo_4F3$F7MF z9Uv7`Bvd?R_$HC06+1_oY>9g8OP{GpoB{!ye9{zp^Oftk9|+VQ`yx!7Ry)i4xX4?M zGorra>+=TPtJwZe-D-#SBMvl*bLQhyd+>duDs`s!(=!uoLC>gTG#lb#6B{+v8140G z9TLkctICI8`ol{fdE}k)e=izlqYNMRM)fP^tO6Fa;8w)$W8R{F4>LyeFUYbqJM;`1 zPXP_hQb(;iTWIl&qT7fTj0&xLclCKm 
zzmPSZdDzzyeIW%{FEF0BkLLOJ!h{p_%SU1}&qxhTFXI@83^@vap)ibcA>11i5{#_b zdNN6i>nP#7T?M9rLU+YhY2{#R6uVj%FV4$(EE#FlcdBX5T>&OSVqe9rETj;XsOe+| zgFz&f=XDH5JKf#uGvUeE>iQiu>}U9c-BlmRu?wXht$gAwn0>QtZ~kV$;tL7jwVEfk z4X)s4&i*g`z^>W5JPpuJK4w&=Md@9_h+l$Q+KH7ts;*yJlorICV-0Zd5K%UdgXppM zlQ!7T*-W?-smYnT#NWV5M1}Kp?*Vxj;Vm!^PmA67=w|p-b6-GvrV(fuk1M3Eae5e> zGV*jyhG^?2WcnigW$+%Ca(0^$vEnF0N0XDHjdCq%p3b!4k5tul7At-e{Z#+r9bzpv z!i(whXH{~*9oVGJgeLDSN1C@qKqFg}*Qu!a0^gIlT(??*#J$^bJ0o+g=lU_9p9d#Z zZp9D!fQxjvENY6l#b{(Vm;_5^i4t+y&K|aZFh?t-z72oohM@F|7bpc3BPMbm5F1{o z+(<;u4C0sbdp;V|^j8q>2-Yh);&`#G-4B5Z5vrSgdJHI;ynyZzP^hTTNOL4l!zF|V zEJG4}Gs2UGr-~6v{z0~CU2gu>9~>GDu`jjT`P$`TIRzB+rvjR(@P8Yuepf~ zOST;#O8&N~VtS=N*b+{vClcx!OJ}VcFxtC40#JQlg^TR|}28)5o4d+jwyC;?k8aM=Vphus}j1 z(MDZumh{Pqi1|7ntptldHT9UDSfPTsE~!YhTA%X@tk(}IDPwP0L-y-06bE$;ye7jU z5Y9ToViN1UBje!FWa?{d*hoR5@9KSgVG^n?aesp?W(MiA4%#^SuLC}@Jg3u~uu4x7 zccW5Q3K9*^+-7!74VAZOK|4iXriy6Q>;^Q*Ie?jPif+vGg)@W1?*v0*SA`D*8JYhPs1iukA~Ww>%BC>tpRt{6-XOW;49z%%_{{cz4lQ zK77S;4UL3*1Oe>NXdgB7sdaE2Q!^OAY{QIrQ;H|q$&G?gLX0VD@<>jGM{hIjB?O1S0V~^mYZM|i8iD!-r{59h?wP;JQ8^;!-7wS*}#6~?nILsUUWqG zhF}Pv?S77U)6cY(3TDiLUecKKSTynz<0!*hyHcUhVd$|>T7@3UL`8K|;S5Wj;S2f5 zG0jp3b&o8HyoSo?M!r?MlhWd;(*P&`D7+U7ws}#7I;OTmWitzfA9XZ_?pyuezIqn^ z38YuxAPzrMU?qz@>;C91WWp*D>+`h^BnleU@|k_}Dk3+94t0s&9uBN9&;vZ<7^`{k@8)e(fMyN{+0MEayeR*EH=r@7n4izUp0Q{!HD*R#;TU5kF7aQ3C+h zMr*u`vtF;H5o#=?-CByc=(M$s`kJy4dBa@x`JID`)wm%JHrP@CgFbvJE=Hr=Bg-wxllC6cmFrDX9a%s%Dg6NbfYtVCkIj<>{9cA4GiP6zMH zMzGSj1WOXrayt2)qKr}hUXpFUBQ$OL@Ue1V^S(d0 zprW{lKw$;FcT`5gxktpvHSmS*7h=#RexDumC7f!*PjNyTuR_?&svWk-;s zX^25ll;^gc<36l^-K%ClqKCVQQ)?Sm&^mb>)J%Im7-(Dc3uTcN-<}V00*$f$ymLQH z++Vt7aOKG>bYV+R69VzVFf{454ug6`u^DM4d~b&dJv7`hy{{u`KfkYk!0%&DS9Rzk zbJuRBt{{GK0jNre!c0j#^jBEa_azOs-zj+!wK=9y${eWN?e7|^+Pmnzqv2XV6KUN3 z(MyKwc1%ka*Vf9`;H<4BeXh`7VI3x^zrmQ@yAvP38c*fS?-g8|yJ6spD_IMgy3ftJ z8!+olw=gJn}`i|Dlh;8|(?6f`aq~XmWv@l5N zlzMGX#raDSE8QjT2dm=C;LcN>Z_t#?Zs}4w^Z7$>JAZ>8zfi1Y=dTj4i6e!i$8GT< zZIfx$U$eh-$ZII2J9y{k`wOL5NT_jHg~EuAL&2pnNd2hv&bnKM_fF_q7?ZMs>Pfq- 
z-(u5L!w)`7Zn5G3wNaYZtgecW&V{!>f1wy;9ByMP<@DsIqLqLSOY2K+*2hy*FDOphP2?))J9jQ!(%^#4dTC?9i{{%M#b>MJjVFp) zWptbP-bIBp>ZRvrOW-97*beEKOzy)-L`6cpi!gXKO8@lL1#Y11zh-V+A z4n9cP)-Kq6{QS$OITJlCS?k)1suLOVfj15Of3Fz%i76k`QX)cr-Um}{a>D?nhXUN;K>-~gK zzhfm!!h}y=RPdvHdy6c#{25S3nW5{W%Ci7D zY4sPXiFC)Nb zGE(iJ#PFqt^nO)g9Oyw{8>i|3SNK|1?1g|gf%@T7SJG3fYqHdm+eyo?tbkdg<-LV~ zPR#gq?0uSx!V&2euNi@#DtlqU{b3!)i?8WVe-ICI&5tcl;b-(&;M9M*;*olMC%4&p zgXp{bB)QnzT&po#IQ=E`LA>8<7sod0!>6>JE#Gkz48E8DupjFKEQ1%&uFzbf#TGc{ z<*sBzXCn88^?Esuh8}3k)+&1vP~c*7jG?UXJz=J(9LNX8+@a;`5O3p6KQIy2tIbsQ>z~hB~4S0 zszqeg+CRY)V$k}&=1}r>QbU8QG_rx5o2^Oc!yf$35MPLiBcO7&sJ_bdVkp*A!f1C&|3xh2d@`*`wGf|q zXfR~CH%$k*AX^WJ$##DzoHatUQj3n1S5sOmuCdEBDc_tHWEM2Fb5Z(oZ8)gC#DDSF z>VZ%VH3s*1REiaDX9~9D&vRMN1DJN8d@H!3=A}RP6T$M!W(|?d@;ebe!9!gGSI{!G zfW>8Q?cN+e5yK!9YL>oZiP61oLs9Sxh4zcY=mF>}=q!Y(GOeLU(og$mSxlvtzKKy= zzIDQ0ScwUzygw9K_(2UGSWZn*XM7^n5xYvr8I8)FkF`Zt1cBS4ojIXAX6W}h}l_@R8j?fW@VsubHIozdje@!0@$568O1@hMqjj=V(7D6hD0u zU*Zp}bVc8ic<08yP7xF0wi3xD?rJnA+{Ob^wqd>8nUSOHD}TYbH(T2U8X>k#_czNE z;BY5;Sp&Aq)6;C_{)H0vJW$={AXP_N`CZh00vEo))YuK4sO%P3M)&py8{xI`Rt<;3 zK%yOOnY4ysG@W83r;%gAe+M@zHwL`_vJzKDcnOC@u)sIT5XaA9g78JxX|JN;8e@0oCoy z3z`ty7VR-1!OFFax1~aFuu%vDz~(yo@0AxHcANLPd-8g%y-wBc$~q@p%8XZTy#(yc z2MiD0!R@!qT)Bco&F=?PJ7W^1DC<87p-}UObK4ac3zRnEr%*cx=%wxzWxla~^2rQC zQpZpXL$rIZnDgbQm+DR!^rxvQY^9OA*Bf7`f;xu?N>Ni&L;bNKrq3lEvhp>v>Mu1v zo?`Fw6F8PMmj6uaSbX8;8=V3m;ww-7CJBSm~!7JrP1_sd!bx?t2ISS}`%&kmXBi#`=#mYKSSMkQ4AVC`?TjNIXEGwTghMR?-RW1f93y^02p&g$dO(je4` zCSl|zIFyoI#9A_5pi`p`_lZ)1=p~V_#N5Ga*7uoGKm72=C6?)|%x>x95t|PzwDKNi z5^7d*_~d0U3?4&$LfJK~bs3!aaH;Mn;*!iNv)0e)eXw<9WRSMj^bC0L)3&GjzO9Ma z!nbaj4K6Wg&Bu+_@xWgwr-WXnF|J=bTtfnDgoLWHb)(37wxdFEj$S2pYrhX~vM(=B zOA%^`2bYg9jmTtcQD{{nt#n5x8tTko({xEmOJE^iwY-iI%~H=c#a#A(;g6u&VOe@j zd-6o8VV33{t5^3|G%G>1%O~Lr5d+!z^RS6R&AQ0?ucZf+oQkx0hEx_u80D8?`#^5m zlJ|a@7%%!^(DRE*zk)P;oheDr+HKE^bDi>HSl;a(S=fmBk)*uB z%*JA`+si}xW#3BMMt?b>2A)J)R^|C6!k6KujA<@H5_eYn_Gm|1+J>tzWbVpsFO$5U zzsl08Q+`g9RlMxabGLdo+?gW%Qh~@$>k8j7F7pimlBn};A#)q*)0EAQl0qDs0i_5F 
z?Xt3R$$+T*uoIH5gMHRJbSCa-&=hi6h;#j-bJLbO|2_5P58#t_Vc7(KGqz0vIn*oT z@2v{>wKmTzuA$44B!fr$fC!Zss#@A3`kSaz0kIK_o|d2n{zwPujy&StkaNDewY&Fn z7{5?jWC^p2Y_m1)!bMF+&yfXqcH}$i$O1e%^4)ZlzY6fk3OpeJ5xu|z2L5Nnayr_A zjPe$)e~Rx&9$D0%)fa}>U-t)Y7fle;;g5xn&t?c$z%E59U>>K7@N3oDLQy=L8MCV&ApMzNB&;v9!J5VsA|jk z0|y(QM5;&kGlk*2eLag34%22%MosYoP{+eGEGkWd&oLB|^EDJeW;R&Og>7c&76#4x z87`av$e434?P8Q-l#VUMRp>UTs#Qvo=x!;CXs*Idc6Cw3N*|=S0ELjiWN*R=b`{uf z$25#TSASVhA`RUhbxCD>S2g9ybU3423%JlA3Csk83D@};lE<2mY$Gbpj8GVsu!TMewryEAqf z)~&*Ay}Ef9wkpwjRING0?R1K{Wb=J6cH#XX^Qm=KzkrmQf`1Z3wsUpYwO+$H;D*>P zxYkp4?yYsfl}JQh{2OKcQ@cR(PC!#p+jz~-`e$dZuBD6I7Av>OX!qW^9jEL22(HY~Y%%}j`ggktbyC3VgWLrl(<>>^RfWCO zMzQ2+g~17T-zzUZ{7Bigi#hjV+5D;>6*l}?Y`t_RPRU`oAhWy$>b#x;i3A|B!#UawIuT2}fcXxgA!_k_(@cWChX zOKpoE-!BxB>U)*FioosuYjVJab{KEsi5icUv}n(ia!Q(&O!m$Rc|%UZsRFl8L}5A$ z$;zXGyu*qIJ62bh-8Y11c&js`-A}7-OFx~oiPX!G*Gt*%&K_!UPbU|(;?EdciIM1l za>LobFUe5QJ)%>}lmrMRew85CySUNe%?m0ysjy`>jbar`Q^FVv{T_RoYoCgLCMuK5 zGVL%uah7ul;8-8xo;hy2T(4s7K|KQLF`l#t>B=Yue9TO`k}{)S>a)~g{L*5)4pUFC zn_x^`F3^(^l^11a1~UIRH0G{1e|u=%GcneFZBMS9fT%f~jAH-MbtG^}M5piOV9d1R zZ#HydXBAQHVy@B~V6EX~W)RvnaGuVsC!@Dbr<1nePu8<9OU-m$muQnHt6Uj2nCM4X z+-+x8dMZyo${vt8x)b0j$|r9YqA?v^SgLT+>AkXi{8WWWf9N7RcjdNWy;|sgBr2)* z2jMT2_+#eSX>(s>{uDKHFe~{<6fYkKeV8*ql>K$8PdcQ|H#_qM1GIQ_{+sdY($ z#{^m;9ZS3WEyy!*kLSkvAO;av)hjEZhw6j!(zLvP63n!m z`;{xS-EMiY-ySLZviJZ+yPv(UUh$<~BK8F*C%=s`Q@rqT0D+>-xUkxatR`@ssnsTX zFM|Z=e4v|`5!^;WvNmtu0fxk+7`ASSY~IK2Jo-Gbg@^qW5_+Hb=AtURmj`-)^$`qN z2Pu()th#JMw&-%0A_$eT^GN2pzDIAbWE4rBx4!F8lqq z4+wM`Wt)lmw19zS5fS*QVW|b2_I03MsQwL)rn#ZeujC%$wk0tOBZ> z{z93o;S-Nk+OVz8<(!E-Md&TWUdS|TrWGj60lH@+sP#C*>%|tk)MM6vq2$Gci;>pm zYA3%AUo?;I@A;AeoEo-Fm1yNIZZRlS z60ockH2Gf>dp1D;Q9L@Zvfjq zfiWAtWpLvaD^mwkc&Izd3h-ETacuKoMy%je?TLl=wYUDJgi*eZ-d4pe+4Ia~0(3 z%iaz1^^Jt%;h$!2_n^0A``v(QQzjruZsHlE>Jimur#<1q1{t8dXtj&H82HV%vf2;^ z`&D+|0#_(1X>N~YkZ^T>P2b7o#Ph&6RQllMeyS>SZ$jPnRR1If7s|!T#-zI;iGaB-4Ub&Fnl=XdC`@Y`4Y!z62 zdESs1gS@mAB|t%afc60M0UE~NFP9)cr6eR^_Z&*qwvN!(fLfTm$iN@w-NCRCF`9@k<z6Ln 
z%YDq+F6p1XJl!SRLxh3OKP76Ehm z^%E}@@&;R2z-sjtz*2*f*|f5y#@#9jRW7C6MqP2nzT3X55H!JLsZZLSXmJRAM;ZO_ zeMFAQIEh?Q!M+TTEgM3oYOL`5)}lfcMjf9?Lq0Vhsk!a@@O!#~s`Xb(pRQjhH%gR7 zuUebR`kC?)uWSMs%|?OE^i6|<3*6EV4@oqR=61a0Sh;I7Wmb8 zPFJ5MFQj?auuT+s7w@5YlCjQ{8~9-F@eCGBu0vB*@r=FYgYE|F>sN)7c}H8uTt!+! zxx)LkKj)U;5Qloo>*bZ7QD`!mebx>@HIi;S7|P4EWIC+Lin_ z|BI@(0E#1MyM=Lw#oZQ`#TF;HF1EOaK!Uqla1F35?(VK31Qr4bkf6afI0+7m1xN@1 z0tC;``@P?PZ{4o0nyHztsj04+uG7zR&UwTKf4?wcfyTsfG0ZfZ)wtji#fwB5fjuP% zKhDulqgwnY^~<)P+p76- zt+LmLVhaIBpBleTJ>&uX7x9w+T4kQaTA(iLW~b(f-e<(m=k?UgR2P)#f5cV!@Y3kX zZcoc=EsAmBpqAHv$gzI@kn%$YGP3iGnDFxLyj+LzUx%oi#lbvlY1dU22Q4 z*En{q6rS>0b_XLZ$hb&t79UdUe;7nhb0^*@)DM;{8i-1rRjIIhY{tZQO?*B#R{I2iKc zGtswJ=sH8%WXk8YDxXdnAp6n!>{yx8V7!}#O-G$j$4?7f=tq(Z4UZwtvS(8DB0E7% z-1L_!PU`@Wge~hew*Fd=(SSFB2D?1%jy+k^cc9*vu*R2(;?e$WjP91UE+%%@Zw|hv zF8Z}_(;rn9a?#0EU8S_nLBG{n4OI=VLv4QS`LyCweshUcuf^+9&c^?YX(TS5B>R3s z+d@N@H>@swl?^LZn*K50w+C-#GEDK4H%SfUm1$)BHkPD~JkFgtf2OQi$I$6|s|GT+ zDqdB;t5VMDN@7qX*A9c|guiAAFqLNB8RPve8awiqwsMsXTi_idZyokT_N?R}Ffh6u5SX-$qv zSOysKR%b&=7<1^w8TyhNm}^{IsfXnn8wpWP(=UxP>gKOWsi~Q%(~@842zV3>B7%4) zd5GSfmxUANhJJa$Uy8~~VtL(5a?)CE>zn##0tPb!&ey%-ndSJYL~Hj6+=&1;5_Q2AhRC#2K2j;%|lc z6Ztcb`!h*oJj;)FO+SfMg_)5tQFiHUo18lL6`VOnc1XjJ{XCgmH{xvvlW(psY^IiS+)BwiY< zavN^oG<+O1VNvqX(Pnk0gKtp0^x2-5#`%`y&utuTcp0Wzwim5fV}oJ;=Z^CyULB+1 z@1^^Ie}>73R+tv)bm~7I7>DmK_wX}qrZ+_!$eLnnp41c89X6UOe>A12md5QqWPmYB zEOgJZ$yu=YH=?*r7o1%W{Gw-k@Bw68WY@~c#A=N`&2+VtsvER*KaRQuZDRMZZ{HcM z&$af3%~6%lYPl-=W9FeVrG@l=1QZ0&`i<39zdF{Lhh`i6ivNwF8HYN&G@2Yqmb|T^R zTiYngbK$hY=_vLeh713xxZI7v`u7X!6fJt1B7Nlg@|JbN3HhKZPg~r9RosDD@#IUZ z903^6t6L)=n7BnJO1en4p7kJ?N3sdG_31F9)S{ED&oPMKq6}uFh2uWWZamQ~z#Cj4 z-u=~%cYd&B@ zWxV5=f$x&cJ?(kxUm}c`J#jv#NA(m}ruK(TsV24teq@Up6C7parq!QZ@T~|e8(&qf zyX*3Xm+Mi(sqX~O=Q7t{K8N5As_ZJNCylz-36PX6o8lH*Hj;ZOnVlMdd=bfZPcqjnXcJhYO&?Y1fX z?%8DGaz)jAL#4b6qjbeADt=FA@KqbnJI_#RpQEgU8zO&Be_ zXK(&KW4hP4nR13mDw`=*;ddq(zdy~M(iOAKT@ZNobT%~awW|L5xYcgOC~*c4|aat_!*Mw_i5<%+46g#_hAB-C!c=D(Rl}W_C!A0@7j~{=*;?41N8LH>1s-o 
z`??lHRvp*kE;B~!DVbARbwOaC%(C5U5IqUL3=+)t+oO4;>cW9w|KRw%4)(bXcj zu8x0kB0jdZ4=$1ZWxSL25930yE=M+wx8;Wv{Bru&XiyS$xHB{U6{^Gkx1_SA>woeB z{~_l8Py0`7>A%c;3ZRHG`Qx@JIs0uEw@a9cKsPi`3lYbI0XL+9MzjBq zohGQvmcx}zJFIO(9!Iu8L1r_H_{#d&J$RZrZbBQ*GaSV&&^KnHm1`%Td}@$R`9BPY<`lxv zEb^fou3@UBM%Zga{H`Wag9B`SF!EC;=p=5af_#fd09#NbJ;mPVvwgvmDWpM_dM{`E z0}U88a1ffb5x*T;NI2o-3|E@Wv{i{Q#1s+L1Z9PgXGJASJq}3<6o$d`_+SCR!y)T_ zvBEP>X&|#}>*b)1>2~*MH)m&_S5(T^^G%S$VT5CWyJl_&v zq&GM+Ab(gIM{AC%yn@r$zRJYp*-~;PBuP#A5_lrwePk>d^npMl8)P^24upO5f|gQ5 z+-CPq4F8nV+;nOr6U&}VL2Wqa&rD6?TYEZVN3H;dZ<+*Jn90!`IzB0O)d5*mIhfxP zU~^1|MG;hB&B5O|=~yPxQA;+`kUOBz?>Orkg{uTr%V>aB85RoSq8(r|Ux%Qb6)BgY2NDAIq-z%8n|W?96F|0I5swTl zVM}aAOSNZz4Q~^^K&_Z6uWKTRA7tlVE6h4?Mx`#N8$sqqRc^nmK&7*Ycy}N<6Bz`+ z8`77en0OJ$*%NYn^0{1xTBXV3j<_zKFj&&>TdAh=8}T!^XuLYnUPngK%Uh1mI92X} zQq8x_sQZk67)$AF1ur3z5qvFJ#Qj6UT*8xM=0)^HTiRlv-aNwlY^ak`IjH}ZE1)wd z@U!MEZP#nnntbDOM+*_HMUJ^~onY@w?9SzYv3i69xEh>P@w_3PJI=w#3#n8kXoJX& z$UY4v&+9OSygm^UTsk$${yUpQ58tNTlW!KawH^DupF~UHWTdFPOVy@SQ?Ee(#u2ql z8%wT8qKECQu|;^`k|;s``i7v}=+oc?Sd=U9ojgDl|Ap#L9GiGG&A^7(#xrANWWrU( zL3om`o^2U}AJprKh`6UbML>Bf!4Gy5R<(C8M911vPkN=65*NzkTQd;-?ngCawi77-&x+63^Y*nr-Zz0 z9NDZCgx;OcRpIHCiCifPh>-~wGuF04+?SxGEJg-W;2AI&htDh`M}rc2?eIlJ6v_NK zaN9wT`_e8Ct}j_fNI-QWBl_HoLF^9rItCe5V3>JN;8usHUXdLq!#5g_Wui$+Q97hk za3$Cm|b10B8(DXy)I_GKSB4l1ZReg}&Wu;dX$U6~7sg{nNO^^w{Z{xQga4}b| zbO@#;pV~dKOof!Ghn~PYP)A~&>di^4d}^9_(Z_jgwmg)eTRE|UGt<*s$nWtFriqy3Wl4L=1I}>$dgwP8-)Y~nQ9IY43dNHLR z3haxTW=w6GL}{jaqKUO$)Tm@Mdhr}OM~6iN&OLkvdl8|{n%{3R?S*F6ZV{~FbQ-CF@tA^nQm5-v1c~5c?GeWK(7^=6jOX)>8@L-mc3ef za^BmDci&x}oOQ)e!}szb=roSUqsasPPPqBd&nQoY#Ue-4 zC4SUsF;*HfSaVKvhIM2Cu+#Df4z=%cz@bY9ax3N0B;KMK zU%+s#MVq@FiGLW;ZsSDdTmb%a=Dcm(Fmu9Mp)9*RxnvbeV&yxgN^u2(c8}OPrWDum z;@2D`$~}aHcG;S*ekCQhx!M?GXLhUxzSnt@AP{pD%_y$y!WNp3aKte$B%FUXPa8^K zwbvpp?YY7WtuqjfHBy>RmB+0L^3+X(E-;#ZN{xP6uMS!>X7^RRYvam?BL006`c!W-%!Adk+kp}^$819L>qfOSM(Z5 z%gRCMU}7GV_!I}F6jZcxkj~y*(9NrZ`N;QX=$5@EFM-cA9LKX1Uh3FYK+9m4#Kou^ 
z7W{~IGR)AduR&JNFnPe`sx(IHp!nM750@Ia_M6cg6=gWIO{hQ{@L0{8W`)sLRWQA^ zz*SloqIB$zO`2;lMNF+GF9zxy7(qG#uweICjks1I(LCsecv{i0sXiJgSV(zI)4KVI z3eXdp*#y5Upm!waHsW~~a4s!6PmvcUD(QxZKb+%UsCK{V4FJ}$QX{k3AvQ37?6pQK$`@Wh-$ zI&Q%aN3~C1)XZt~iJW<@zSO4OI*~j%p?@+z4`gm*_00dg|90O=cZo6=L|)pE)n_3_ zK9cf;JEtNElk-;SSiR0TEuV;Ug=5=WAAdLcm1@G0q4<%CTArZfKMW)}|MaDuD6c4iVr()fLjmWQEz(G<$($m zO{QZXSM4Jb#6CEo3O%fj+4NJ5JvbzfG@FY!))RO8Dqmyi`*3;ALsd=ogoLgH9!8tW z#=m)|>NKelQF6Z8c#wi*WxasIB@xE=bZJvNY+0rnNidp+-JInIhNFyK$!Ij8hY?Dy z>gPB~G%)GFP1DX2^m-2RJwy#N$dc#4f1J>gh{yH0YCXx@8}Xm{tL!IK)F9w%al8f{ zPp>N{5}oL5@v?$O330d2wZ5bHm>iST{mCI8L@Jww*Ee(%mLhQC1j)2;T8GgLZ^ArY z7CZ&J_8-rmS+M{{2hxIr4_qmQG78A;t0D?MgT5K8@ryEYRy7xo6REa#j#heV`iuK< zWsD$We$mz~-=YgiN?x9r1y}j{ylty(pekMFqn%m9sCoD@-74qCnMiLD?!4H-{brs| z@%2FQyLmZcp7FDh`YmSkS>t0aTHyUSJ(G3~p=&Pifr(9>CYrADLcb9YlYE$TK0D1k;+8CtM~YuPf?idE(HHqfeUP22Z!q6^8_q>&}f` z4EsyzBcW~h2UhQW$f|D0c`Dto%!XYqT~r36pS(L^D5lpsv+E>;HaNFQLPxNS^m}DQ zba6y)U=zFq<{cVX>jqcq8gWvkZQ5MPKg0t0W|8Ci%fK=4qzwv5kf;!&M8{%ZIYI6RJ1@;AJ9~S9%PlzJ|M=DdJ z?Mxr&_)?bb_gG^JEGTEAitQwxaB{;1TBJ)k&$#k~G^%X$0R6G%XCqg|RAgvy!BIlg zDWs3b|23bLR*ZW*B6g66>aA(?cf1_h2JbJ5Jq?N|vkNwLZcmhj6QAZ;s8 z=TylOPd>jr{x>5X?q@p*~guU1T3cl2Eb!a4|u3_>UfG%GZ;fMQ8ggz z67uLvo+9eGjP-r^)8~fhaq^mbA4)uhrCr)6TcQN+Wh^M$T{pG1K4>A;!vn0(BMcv& zrh0lHXXxF}NVi2#@xKSi~e84 z*j0N9DUCa32NUCxIij>z1amtnIn&Csx&=v#M3{vurC25CZnhIT?+f7NA@kF_-@n)AD>) zQ&?17P3%ip!q~R`wu_I1!HULEvP%Egss!UHfhbQDCH$y4vjRB;a+=; z7Mr0msD6?dOBKRxD~(K`O_-VK7MCeN)D72$r((&u4X{@c!-v2=csSXw=CA1K#zZym zXj82@DT7Fw!uXjXbd*~lu1PM4yH0OLcxqV;>DG15)+(n#IZgBjs&YDobk{*`0H>@c zKyCm@8XS+C%|zA^?0uDqXx6;6uqMdX0j6kPE!FVa!@B9nlJ8U1K?b^XJWkFz5M_0T zLgoHfE6ihWm4a2wy3Qjg2;=Lmq?o#9u2xdtP&Rjr?Fo){>PuV3Q-MrBW`bc?o2=Rk z2{+IG-eHe(RfUU7I01#f7*TgF2oh(#UUL#9btOk|*S*>?5Lc|F)ud`eXNsLAc#XV` zyZ7O&@)}b7yVNQ=gmp-ZBQa-3NG^tU{epuN%O#MLDK5h?6X?h*(+RIQh$l<=hk?T$ z7GB<$F!P(E`Zo`G%?5$oPp2-%Jv8i(YaN#oJyd`JxwBbt&HvcMmORt4nAIS;`CXr-4gIU0Fo?fz#I%|}pB(TP< zXLpxil-^A8t4%!|$mW4g8nsBuuT|m0cK123UqEzLU@a|nf_ogYdaEEl4OikDNiR;m 
z&zHN5ncZgvMtiB|cSO?_D#Q7-!h3JV<>{l!sB^yuR~ZU$(2PeAknC{~=o8%!5^smQ zAaCT7$i!!_`Dn#5(a`udv4d3b_usJL7UCsJJCiJXKN%f)sEaB{VTd$mk$NH^LzQG*RQ!cXw2&_(%sZOgqro5KbIl?=sY*ok@m!MAix8kJ zpQuFmz=YGdQKtE++v7lmYKPqiOsN4|5nSz%FUgNth7GKITjkci9>7siGMbPXsVSgAAZW{2!5a_okgk64nCSL7Q zNEWyD@cE;MWikqfMaNIUHq(Vh87ISFHExFIL*A2-kvCPk{L6Cl9{wrZj;W4RM`P~} z>@h=V7owE5+%ZR^t?DA=>^MDgidYMEJ_t{ayV#rc9keHSdw@XmC$kiORI${uK|I#i zkBm~AQ3SN}ITigU?wP{ok~*+8hy72c1Jy2=bu5Q~EYps1sz5?a{6zZS8ARAtn{WJy zMvJKZFPBnUhAeZMqZrUu@_H`qfQZDf)kZwRm`A5hxb^pdGX2u`@o4IDbxfb0Db^H} zJ9^N3#(I~|^{Z=Tj%LPDzvzLh&_)J4N4j%Gt$r`ArpU!PjfRxpyG=g5-Sc78h3tQ-h~@wk@Q)HQ5DnG8$fec10IFRs>hVafRQ-{$(XTp;f%5+_ z(4j-aoA>y0p$xuf&nXX7f(dYpo1%V0%#YdWT{9qHe zwd>5T?8@gSz?wU5J4TCY`VmBbBbED+&RaYHpH-spC*9OwE`PX+f{I&+6(WsL$IhZ1 zfL+e_NebLF_mGC+n!HLIeeiDe57WCBfqWSRX4La0o}2p*318z4Y1*)Jxh^?r%efdN zw@7j8^7D8T)xM}D{(6YHl(RDrXTheTUnH!oj|VZ4)SuDD5<$Ch#N0^|6R}!+LI{V$ zy0!+}hWaxz)J-C7Eu)5Y7?XA24HhHVc5NIUKH^}FlIHdB1haCTYy;_7M*GVl|IlIX z19u$bHcP0HFJ0DnKWP|8Rw_JHNK`s%H3E_x)Q=S=6#zMc{OP)!| z5tO;LYNzb;L>J=u027Xn`TfI3{2#`UgE_sQjF33mUFTu^Km|)Vqo%(gB0CsZrBFwbehSe9uYRtFc-8%`*&5C zI+}0Y7t*XL6DzeeUXIH8O*E|2 zRD^of8#=9W_npKrq8Arn75$oPo~GJS7o3LrNdM zeUm$S>Mo*H-ifyPEPv+_ERG{p-!UzX%K49;iHV?ZIP+h!QF^mk1ng;?tdAVa7y?c> zyWBh^o`092208s1dZLUK>Nrf42+~XWAdH0PBUw=K_7(H5eM%@qhjxb^NC~;dTDj6Z z=}DuM8>zu74d+}-qQ5UC{i3+w&i(PK`D^wb8@F~pJGf|};yXc%)SAu5j+NAuZG?-Y zq{@!|!&jOO3}X@TDiBS=&ne}Nnd<*A5Qr+A#Oyn|?@u0<240`UIhNy(0`zQD6sQn< zf~n)eFFqon1B|m+KA9|(BRXwDsw7K#I6~Mer3Md(Y{VsWWEzyi6jn*yF5h3|sxnjf z$RL5e=VnPv^L4hb;6>ae6oVN&iMHbQ6x$y-NWNRe3A;8eahmrMFj(8xe*CKQJX z2zy1HAgpKX!H5pag8LG(XC9!-xZy=$>j9560u6F5QndYW{*{G8jv-I0GKmEr50F|* zZtd*=)`Z?r@FW$sC+&zl)qy1zoZW9nya&k|b1ln-I{PeU+%nIZ&)u8;hkKkhZzbo6 zzV6$xE_WU+=Ef?bP_%#4 z&hHEtdJ%xkat|JVv`S~$s49YU@lSB@5Khqmt3ttMUsc~&EFtl3*v-rKZ$hyh!FUI4 zQg)*y1>_?`d2z87y8XR4V$6c9TpeqQQKz>I(c-94&Ua(*Uu?Vn+&nuoK{`h8Tk6A9 zd#KHrRfs)0%S$zG%AJl)rM4Y^sWg;rx9>c|$K#Qe-#0dGw}=y-q@V5FeF$Mv2+xgUs3N_b6<)CweXXKg1%mwH*bUf5G#j8i(cU*N+1}eU+^x2 
zt;FOxg)S#m7r{b{xUZjvjj2`^qo7G2m&axqI*8WKOat)3H5>}nMHH=6XVibDOCVZu z!POM)jC`(_=N2hfL}MP)ecyC-=AxCE#`bji-Cpib1bC`ddSI039lcXZq`N@fHW**x z^j(ZGNv!+6r=HElQR$@Mvd`8*sFz|Yp0XF@G-?b`QX^ac!G=$pidl2HyrpiHBP&aOTvl;b`(x5)(^rndt#N1CCnnezE@TZ2N8PlA-lq{6-BIp=?uZ`bYtkW%ImiSRr=KuZ*QG&0cXO_V}*P z&(QF1$B^Xl2gvZ^?APlj9Z%BB*zoFJ8`)A|y{^R&#Eh<0dtX+h6*Mgr=SyECf1i7c zb&YSFLD5*{Kg7G%-z&uZj9PgEjPrf=4IJCA&IY{O49%{D<-4myxr8)MLtfbj{=-;x z`Sf#`Z6O-qCk}rTO*AoEkruFp$k%P)pxgqAj!J7u`OVjDBU|W1q|f1RJ~5rRHcnhv zYU`{{Ic8X1IBeU2E{A>hHX)A&(yKHTa2zILMvnjNHhYf-Mrew^wl?St^?u>( zDPs)u+*%3I5Pir?Ag(yF(=eM1eG)8rS>o1vco56Ja`(WJ5;(u)fz@JWkTIg<{|p++ z`%~;X@vJbMT>`wbFd=stPfS>S*)XqjViQi+zrO42{QK(q`LFr>uvC{?)#-7>H?@xc=Gbz^H5Sd_iPAOf6e<-=kq{`|q%Ppf$2 z{qA}0x4z4|Wd+{N?G#Hk&%1C9Ql@5*05;z%ES=nJIj0{3T~eB@_%}mIkxR7-)X9(8 zaeKvO8EqWlpG?CU`)@f+Vh22l$b?-#IgnQ%qxd))tsd=D#&J{|)uvBv@EW+>1F5po z)gp+cXC3;SkdBNP7%_D!<9pBdnHe%v=D*e`K~xTc*}>!4GjzF_NeSLbD-)lxW6oaq z6b=+$8sB3s+S93Zy!nT5k8u8O^PH_w4A^&up5eVu^YVU|1#E7`fK^2Ot^X}6x}>IT z1?Ne-C&^h7zG}obfBgp5{n;LAuy0&~fFSwV&=jz~!eiKOW+T zQ94CnvQs^9|HDX>guyVKb^-Lx^ZqWKo~VxGr~dZsC$C`K*&6~XMS0nudt$0^(s16* zJ399L`gy6x)#0C)I{n3q%Ipt2h4&a+k!J}xMs>i&AedV+&&l;4h9Vb%_KW==J%%Z& z9{M;aA9%iSW>TpZwvTknBbnE8l#@qL%8dV$u-x*az+A%4qr7xV6+c75AWK?&!WK7P zbBr?S6~+wV?R~QC@I(E=ZUN~okJ_%}ErJQuq8$Dva0G1>eaVY#5#WBc`jspzIe`J@ zdA2KrWE%niUlPs;7|cgelDEWaKi0S)%5=O$-)^))?gv26Z%bo)&_KcfH?SeKWADy4S|C->u!x&gXJHWiTKKOCz6 z>h%7XOoj2NL=jmr7QPe)~403(UICy)3))wKvFPlw_3T%ajNb#TkQw z?6WV$AkM|90UQK8!u>HTt^{vp`jmsA6R>Eprv zX)gD9K-P_BTGm=R0X#bYUF!{KNSj|?(0aPyueQ`=5!;`RXRG$LE3!e{ct$77o`i+A z7+OnDGL*~yGKZ)TTxS055#nA;3=a20O!e(QhVGYVUA8B)BI! 
z&CYrR#~@v>&Zc{yDUEuhx99@3b9mcZ(WQZ2<<7{+fW~FMJgIaOU z3e2Qz%Xs4+{=*o<8Ybikl2TH?lNjW3{9(m+Iu?!6{FTKz8j;6N*C@PY+mDo|w2zOU z<@gJ+dP}aC@7%MXe>upa+bsgeN3K$`f~`OtWYr#&OufiEL@$VP4Yrp}cKl7pjB-!o zwW5c@TPonO;HxVKS!96gvDxmsltQ-PSyA@V)%N=3Y1+$oOZzMy%%B!gZVx;m;{nZ} z4S-an5!dir%^&7I1!Z%^Sb}NwuTJkCgCKo;M%C9J!Q_Xmah1y$2l+4g=S~M)mheV#vCTK{k~vm`xr47+zM!hz zLV_-}JuTaCv-TzR3M2M@OuxR``KHn!9E`H!u4&RlpTMCXmG8YJY<`qIi_P5}XZfbn z=XAa^K#<{h8Tt`L81!n3HiNbLN#=>P0+Zv)<@k;_Qy8%mEN|VLReUEV^w%<3HpMXo z5JQNot@Dmqpespw_;UuV8zCxWis{ix>fAAu&Fe}#G4Fnf!j9W9oO`+E6F%mdE}!VV z&G)@QEJ-J!vNp0k`)|uZf)T( z3-)1_;rT?Q3k`(hQg*y5@}2FK|Cx1{22kt|C$x`GedlRDG02^owtPlL#`fVVB=K)L zp_h4Yz%=W%0{+t3X#d$*VZxuAyJuExv7Y%ZcxKKCb&3eSf9^k={CYfWAn~t`Z@9CV|e= ztg*iK6RF@oL+9CSD!&13zo$2c|6$N?l~=GwMF)R7ZQB(1wG)2_;9a{{2$?>MZ=u&q z936u|MG8I3NTwB}BD-SaOSbwH(&eZI{wiYG8F5+2b)`Tckd&@NK@_AjzZ|>&W=fAW zaDxnxpL`in1~{cvR+#>F5=M365E0J}kajq^{S%%Z!gK_LAzH|I^ zg=|6ZEz=)`UbJ~vCRS_y1*5_`$ae>M660hZ`8?9{U$u^v~hp zoc4UQgj|ZG*HynKYs)2DYL%YnLT4g}qWi$OTgKU>ww+66{EG{P#am(16%gpyAF(;cg%)l%3lqFzSzLZV!3a*IZrZwI@nByU**SJXG()FJE4{FOG21(z4%ea zRcEY6yS@4qU6{9Ad;c()q^bvRnbc3%sWLqju&rOEOT9adWc5TluD$2?iU2xe4xt8q zFNoZukM{kR4)KrZVT1vSjrP|z0b8V9huYVnsjGdje*feyXU<+F3xY&up9f=^RdjXw z`ES%iMi=ZRMbG|-raI6|gQvOW)35r_x%@w6_PNvaG25ow_RfM87JuwG`eDmvEzzU0 zhve|*PV_3`gRJha)5knJBht0;DE=aSG}lP>uHc@CFD8(V8o=hiTk(u705PCF5e8XzYPlN<=sZhT~QXRf zj{i1Fw=#QG)Yf21Vm@uo*7NZb2)_>pAS^AzBYZR_Zy1u76t#)Dxer!7JC|Ou=-!gN z+?3y_&AGb0Az4L*mLPR@cNX!NB1wiQ!a!3I8>Z%DBL~xgR{GY(i}R5{)Rgb<*gO3A zXuOi%s7HXv{$GIjAFBOh`yU_t{{lp1VMALbga6~=K+Y6ZH*^Pke4YQNcWtFdKMJX1 z(e3%;Co1=v4^01i`^|p_{->RUi<~O-s)n<%UQ45-mV!}=SJG!-Y5QE}yj7|6=wmCU zMQ+YPbJB!KpRcc(n1IhAk7*QC=N%3e^P@%K==5&@{sG$F0RK_b-Sy0f)=U|^9GIce z`TqmalEgoKj|FEQf!R#;L8QjO=+Q_3F7p6YgXfikoJJp$_r|J6pY$EURvIgrmD#mj zB>vd9A$%O)tcCD$%y!l%ZJx~b&Vy#F9=y)dqv4)yYZvlq^jm*o^HU10Js&|Uo#)Ef zO(2pzxv@9A93P3606Q~`n9F&yzNf{zk$S(A^4!1i(PP!=&-tU5S2RDOwa{OLKd1Ok~d{0TS!Ws)B&PoTMLT zWruS7Lw|Psc#gN2wv6(C zR8+BkM!H839Mq=(9C_y5Wvc44ZtF}ox2Bj;#wO+3sbMO-PEkFvc|5uGQLCFaZO0)* 
zLTK;p2E!v3t~VU0p>SXj+y-{Oeotjt&ADj2#fAN1xwQ{?gYX`?&;Q^)CAvkwNcxaK zZ6ParV6%bpUd|g-A?4E6;gNI>_9pk%of&7xO+TYLGneg?Qv4yh$qwDvwQCmZ z(L*l*N(j$bkK;Nh<9_1a&>pqNphm2c{gy650#{-c7Rx|y0?dN&0fxT;U{d7q!4h@i zyccWr2A*`kWoT<4Nx`muSMowHz(8S8KTxl>znr~vWw+%ot2rdKd>N%}nC>BV)Y{vl z9SpFq>FyFpo~D-I7uYrUlJD{pL01ME#$JzYkbKYlBdol{F~R6Y)Yc99<&f8Rxf427 zsTsLm%n{LR{TgS%4R(k`O_MDB#MtE|BP*8taL^;6l@9C;u$$KApov>w9ca7 zMdVd_27O4bJsg@AyTyMQ17p*c&*{8yvsosX}T|OZ%$ut{s zunHW~-R3G0cM#iDBtIr25KDU$Yo+~-TP2V|&hsj$HVeIjlg`t&Oo^MptPTC1BxB3hrgUG79d!ph9U zyCYQ^dWbtOLQ`PysW;(W$Cfry`Z{8fkLaAXBG4PSPvUoq0xl9CZFzKGBxLKf`Y~_{ zaM13IU;1l-K^7e7Wh6=pvk&PYSD-Kzz60TIp?kHuNZ7mbJSeI~P9MR2Rc@N8L$yd+ zicpUC>*Kl@b$GO3J(cozaH8S;X_04#t(^G0o_!B`4mPSS9sfw5QWNESw zf#_6#tI^{UZii%q`4+f|-kv$xXZLwQibo92Z&NK_Kk~jDzsQ~+<&mnwJ_Q~eX_9F3afEsyu{oq*~|X{wRogb?T-b?$CVG4(8X+?p%xW8@sma8 zFX_<*(;;_cFFJZFZ*b?4VL{O~Uu-;jp(!%vTMxui5xyn+(ycuTFH3gpgB^2i;sx^k zrv72N7rj_N%q@DGP$-!(pV)ncNC+~mdC}8)!@?F z!Zh6^yMLp<#-B0E$x)qrbX@ZVKSH#Tr2^Z$=N$jLm2PV; z?YwXMQNN_^&wbZpZI`}amXk32I*J_yKab=F$S&HQZ7>4VgDzK)uw&hyGTUF|W_^ry z4-Jm?fEcZ-f92W{&PdULM;?T@10k*c=xd>@@RN z%%o6IS?HOz{fFVrHrpOMe&yNro4Ca~-ROeMXPTI8_Ij7u`GicaQ%XKPME~H62Isys zv;WI{^rfY;yw$oD{Vhtty7M0fVz9uF*qm^(JkHNLI=qgE3^LSZ#+=@;}~E&pEtRXVE0r}#V^ z4`<$ABZtr6LK}W3(Z4xllO6G!Yw}U=%IrL2)&OGjX4$a$K5XG32M*wZL;Mh3-nd>4 z7@PB)vxAd|`7_FFgzjI-bTY5bEx&D3fSyDCU+Z`U)2QY3{MhHimfR)IZLh{TK2Xgj z9c2DM{UGu1OvlN|#^blQsE?r))C(qN+sSwRum~^6$H3A!S#A7|4o_o;+4^}8&*Xe} z2m1=Nh;JtgA15aaVp(<7+CM&a$vDr1vdHv&VHS>0{sH8AwY~grBhGym?YCoC)P!WF zj|@vO@_!7DGC9s)}fx%;LTlReK!u}28ZTMyk z>V2Q)bG&Cz#{%17x18am$C5-qESA>@IfW#QvNo!#@f64$t;4)C0DrN>u==lBTWwl6TbY&*-s!Qr1jnear+ z-a`?#{{UZ3d}ZS!t4rARz~JZ4@DcL2_&qph`JV?HGhq;N|HJ?%5CH%J0s;X90s{d7 z0RR910096IAu&NwVR3El~YnUA1Ji_{e zgpfuti0@LO+1wl=oS*!{%l-<+3N43p^G`JMPc-w-H1bb0@=rALPc-vSB=eKbPbBlt zJo87kX(I$DPpr_I!Es%)Q~Ax3vb{hb{{T9F02W;nr7uJbe+3wRNw1m6K+?qx%~HyC zu*`S_n12XiT$6_JGRp#Rcw96EI$)$ycmDvuAYn8N^Zx+#fROC}0N_dFh;uVl%MHqS zfbV-mD#CF800Yd2atgquD};tfSpF{`S*iS85s1?@CcnWz%<~h>Mmm4_(h;^1gT&13 
zDpKM){tGdO=ZRyEGXkEN)^!A?C9+b=;{>#-3k8DQSj(b_sc!Hca~BMZ5-g*L%yM%DJjBgHEm(|JARvnbcbEs6sWqM`BB)(l z)c!^22)j)W@rEvw)@Xn6NU$!YtRZrM@h>IZ{$1ihH-k5ZH37wuqDq6x!uYjw;OE3Z z6531rLXAzOSt#?AInA?}F41QxY-yF|JU}d;Fbc^%1;i6EbU+;91Q(IfM~&F*4V3x?(+*+8-fYRKK(@q7^D+<_KFiNbHdPsPIR5uhwc!XVHh# znW;4X7sJ#Sy6QLjq9vm}L4*GQh)PX!lelUvNzMKMsCN@Aq-JY?7>qNhxFb1wdUpeb z0{Ep5xTLjVU1B&CH$z|Y?f9Dg^0GY`1hW99e4qISf*z*;P3ehtnS41*4K5;~guq^4 z8V{IJ01Kd~>!jxxY`M%u7ykf?jobyZHfASlWeVYr<4$JT4p<3*lt4LP<5u{Jp=88= zfQ5?*Lai}rk23g0f;+*FtknKB4-kF%sWkou=ZTTQ6ySvoPyQ(g+Y1K~Z8GgJktHar z{{Y02-l8*&1;0;uQtOyImAV@700C}A3<7Q%{Xj7XT{54(4SblmcNK*2n3iq`M^I{c ziVc^a;%@rHg#d*}myAn+9LI1mRtu%r-|$r&Ty?+To&6$55bhyGFHadM5GBTNufEah$)#LR~}xnRy2-_c12~ATWm!6%aB*DmI4Bf9EqLv0d}B03j(+6Gg@duemcs9q4A~QEXAH&6X5Qw9{{R%!RvRoQNX8%q7sGP2le<(ATHY_*cYrL;4-sH#tVRAp z;)>XE`-sL%m4_r+9YW{#Ihc+Rf;`LMYR#PI5}Zsqj8fuYv2xYCK>OuJ@^=E#8(PFb zmn3Q1R+5O5boeJF&z>5Cw|NchM?ScR3(VyVU0wbLbTUA z+TilTQoRXJmHHz5)l`0J01)d`^*?js1sBt~U{1OTqESoJ%AUmaeKx^~CadfKZ8i!K$5#eh9->d^>)b4Hk z1OzZjrR8VVRY(SgIe;ITSylRr(QJyA&IOVB)^PCzoG^wPx6pra5Lpioh-w{{WRA1pX8m9YT{%u8)h7X_;wwh6$?M!E;ISJY;^{81%VAv$`+{OV?5|rp)F@e9!8q43 z*;Z--7_fRY!)n@+;;16XHEFxrKjL9+sQkt}P~G?lB6aL51>7_ACAy3iiyqTF##vGG zQqTIa$H~HZjuasd(qkg;l<4jHf`Q%Y9s!8tDLX_d4DR+EKp1;JWBNcnWcx|X238zJ znW|+a2{h`sJl#hEZK9sY4PR8oHvJ?%nLiK(sIy?TcHVoIojal&HfIn^f?g%)b?ry2 zQkAbvoh0y&m^xZy2y=60tEC??mX*b3aq2HZKM~SGG%^u!K=MJvAzvb1?2YZ|r|#h) zxWw{96z&2sToV5P8pHw)^l-q^7CsQNO>o2+Mi-;948e$;lf}vPEYM$qt|rw*uvi#{ zgar%10f#DxC4MCXg-2GxS6PM=IH!RG7Qhly3?tkwUZ^_SqG(r4sklI0?8Ig4okx(b zMwkNgxd>(89#qixs7KG3CDCx}bEz&Vf+dC#hlm9V-VcM3*SERKtQn@{TQ8U4lDv(s1(ue zmdcg-1f<0Z@5{t0Y1IO?1;r^u$yABD8lBpQ-DQ%c$cuoNZd3(9+qh@%$sV`?j>YV!V>aq0daY1%GU7PE*Mbqvt-h@1v_Ggz9gn1gAB@o~~&3+;cXwh?G`7+Ru^ z>FQr(x@`WC34oP?-$-GRGDLNv(PQXku*GcG#I(*QX^MV7A|KSMUHVYy@fN;56R0^t z@Zk==CifK7y;K9B4Sc(q6B8nT?aH2KD=!T7NeyRY&+{_p6?JneGc8*Wz9wHouwjdn z3OQyLUf(0s&6w;BOzBn8L)X$2Qu_kq>WonAI%RytsNGAR23+_?XtiV7R8t=j4GOqh zBLsYm;DIoNVXVQf*nVz^T<>MYK_;VV-38l34@3fkna+RM;oh8GXY}* 
z(fj@#N3fecVSw0?4?<922Q_SJ*-YxEFj4dQY#aH7*#{IA?Epesj)eeLAupUC#I-$P ziyO9C@m;pf#gGaGv)hu(9dJ{av@igfv?ISYQvwS_cGe|1HT;D~Ht`jKxDqiUX=)HYJu4t;uM zGlYDi6V;!qm>VQ7c^|ITaadTIG+cS>o@H4<(nZ_7WyBhc3c1OSzw!!blF2X34J}fV z^$`PN7TSNK<^W7q)z9t@qSxv)N5M>(BAvK7IhGs)5q`5temF@@2;-mnp`j>K%kCiP zV!S{32xab_wF053C|oD*P=Io{CdkEHY^e5tn+7VY_BR>KT3X@FlQnf#aPR)}3o^ou zC9mXwd`u5niT(&w+N+n$$^aa)puC*GVjRY0%5VXKY8W~}y3WGGbWDXPGjW0thhksc z5He=Ef+2${a~%%BG>5Ye&*C@PGiI<|5Bv2n;_Ga;c-_Y)fU_mB0}$H7pdz7w$5uHb zp<%$PR4q4&)SHRnO17p$okB^ew*y|e(MUvrZIxXb>F-XvtdRBddJ1a zw)!AwUW4U^I|?5$Yp54oS0u3#vrHAtqtWwmA(l2=vG-D?&u-6L)lAiQd@0qC|2<0+1wb20LL{>GfbNgZvggU&1nJGF;)zTmlUr zRy;>LPTHBSrR)v?>O6$Fx~0GYMLdS*+qiAc7h)y99=BwBqW zrVbxS*CO#QdJb;Mm&;CSeYyg&! z3v?XH4G?K?y^tUvXniBr!Yu|f9Fw<0nQvC#0nrfl@F=;z34&b=n~57EctE2LB?Pfy z+9nAbZim#<5*tPf4;9R`bi(bVtM9pn5GKuL%+Hlx$m?8*yr_ZEdNfQDRiFCGL&mPWN*mDPSF%j%r`Kd8oNGB6)uq5gkT=T4F! zFk&wvbI&8%0uUO0H>9`6kFGeKO+9~Iu8|3-d%qwKbMj~dEVi1rHXt*3P6n>)K`{1IL(p>kmFg8)ZB*30!xqi>NOQyIsK8!0ARrGfO zeik4Gp%U?C&12W%Q}y&(3ANoHk~3Si_fO&smLG6hHku5+C3-=ls3X7-xDFf62;v%I z>YBKT5K0Wo=x=NHl?LqO<}UGcslJ2ERTXV+Cu2ZC>L9zoV>a^{BExB-DXFBTAU)$7 zW}L&~_Qwf!rTdLjfP3HV!mj6d(GmrZGZPP1E4K(-%7j(IbsCibA9=A(i{bAqsQwAO zLyFkI%m)uiU?iy#sYn(!@#>pv3_M(T6d;F$hy5EUzDT1M=?aewORreI>tKue#{uHf@-{8{%&z4 zb+1Q%Pl!cN$eyo zLg%{W|{Qauks1UP_fWeeit`?kN?@YtjHRt7{ci%4wjM zp+d;5kKCxGmCUXHA;qfZwk437h&Fe>cp;wX75RYRFcX}_3(g5r(U}_LgJ^ZP{fI}^ zqptz&iGp#NRV1}HEEx4Vl{1y|efgPhWLR48^E2m1Olbk7SD9QPhi^d|GP8^|Z!E2V zx6}|b5wHR&4O-P_RGNxcjjG@rnX+1-BSjPoS2KqZrMvRMv6*XPvO@Jrw6MqK6=$^B zk}*;_?D=!5_L))$@#a#nW6z&?PizT82vP`~#JkZ~o?g$y!*jja9!^hrb*%5yi=p=k za4ki?At9WZ^B)cWL?ZDrDunFj zc7Wt6ZARP*h`fe{p4*2Buw9;1KBJTsEM}U=QtxR```#IMCnVvXxoz)KdZR;e!gMJ z!!FihU$h#6ZBOeCA#F>9+BQbCm-^TrZ@7_l5Jovv?i;8{^3U00)(CPx)AuYBm{HUR!pD0yu^1-Ik9dy?jG3O# zPZ6b>M<)pKOWfzUl-8+^pY|Y7c>w&uhEJy7P}WR#vGEbDRA2t|lm&|;3=dfJfTLEK zo4BycP{`q_R*of7kd=+{jkO)@G@x2)_Tk~?9v@8RRi~_AcI%@B$6yq0rY=*~3zeox 
zrbW>dDF?`hNYJO&AE<E%qpzvEQ^+zxI0zwlYxyPh+Jf_3qaFs*k602T5mS1&V^1#^Pm!khfRcjXE@W51lO!Yr3LFz*SRh(0y@-3)*Uq2 z=sFttnHB|KqXJ>2fo{(TS%rg_WN**ZLz&53lwn~l_sm-g&k)); zU^z#UC;L^=`-uC9AUMsLjxN^g%Clt)n--I0JjY?6^CO7&f-sS6nHNG0CA#V zGL8QL@#E$mw(2y2LoBcwjdPg1PDf05n_)BX;gs0oS=>)@{{Ui&I)e4*?U>XmeA2F% z!PKE1A;Y*B%*2qA6#L3w?iOa{);g9!?4QzJHhc%r%J(5xA^!kwBC)H(*$z%+kr%R2 z<#h<l6%$Erw0M-LgVM|-7~CeEbNZK|>G>fgyBY6ts0H6sdzDeqpJz`p%Pwo- z&+aaJf%}YdIAKZGIXz(4rx(O(dSj#u>RdX>S&q)?<%41a^gs3y7AhhI%gj(VJ|bA0 z>k@;MtAj(9GYSC;_K%N5N<4XqTFWY~$g->N^n}}_P7{~pex^vqk?GvUPNn_oCZtgt z#o1i0ps&2{{V$L;=9KLGQb=foFXUVlwJqwV18wlgkx8u9(jumMPrf_HWB^8 zzKT5g^?>zw3ZLvZtOHMlD>aJxXw_2bt%o zL4OldNEoz)Nr65pJBgY>8Oay>AwK{{ZAGJRL+AXh1n@1QyM?wXq$55P5FBMrAETo?=R~=%7lFX08$Xaj8^pB?2M`cIIi{}$!siFNr=?6vgPZ%5CFiG5S9@o?r63|0&OJ;W4D|fgXvJK2zFQ`erV(Mk6Te#9NmUh3W#48D7E_1N6#vjbJ`ERx> za)~fqn=liS9|R6x5RB#DN&A|7v&Id;yOb+f-cdicHy36m2 z%8j+|Wc~cZ0A=oj1C|_Im*+4P!0PEgP%6HRYu9n;{{W;)^I&l+2TO~1eM^ATNqo#2 zE9PpMJ3t7ksQ@EIej^5}xHZUoM$34X(M|OpdT)uyGT2n!&%_irW6u*{!K5lc4&W5f zU(m1AVe%!h@5d62Kf<*YMj_Q@^KJ?T0$EsY3AlS9?Cbm?B4)J-1`gc%K{OD`5IGnl z^fN=8LD3y6&RT$gN{(*-0PBhNru_E}<_*kIZ#5eiS(j63xQj56*yi7F2352^OA~BT zyV)af!4b+kT!-!wxv9z2;cluwSf3~N0b<9;U)o(`#OPH1PwErA^9M*~AgB%*j&=hU z5u0(MyE6)K>V1f!c2VxbCX3)qN-q|E(>IqD3_@tay@gqr-kDP*<&MKwWVu$7qSRve zg-hTP#s*VfgtHo1Eq4(Ns8$%u_<)UPFlt^OlqqVkSWFIFM7MYEDaGJ>uKm9(Ma>NQ zzy+b_+G7{36Sf%SJ9ipOhEe3>@dzzPZrI*qw15aREftpQA(n+W4;;$g_Vm5*s?7kLiog*}q!UoSuGZ#U9n_v#lMOvKE3#YgcR*0DBa-owwd zF^ypngfBbqi0?rN;XeH?d$Kg4brlQRPHjkf)+Q4!p})Zcd1Ujk`D6Dq9FuXgm+}fOv@6wH_+`NQDt>k238LqXF&w zIjCOY)A%N9nCsOl`XA)bzu+!Q_gDIq1LK(NLQ?p&#k}J&Q8Ek^{{T_Tz4b7v>G_Fe z_}hnH=fO5xH`HlHgj`I`FH^;MmJk-zK?nfgScAMV5gKj(02zN1BHhG&RKG70SbH@T zd;Za}7L$sVt|$R3jaNE|&AcJD4@_h;=M`VB<*O3Lq4egWQDAsKh>tZGj8Q0kANW20 z0D{{EB4nzI+q0#KpJ&0DS)d1S@1khZf9F zSgfxw=)$bEO0Yp2U!hHY;zVs_E%O?mhvP6ZUg?f-ZJ8NGV~7!^xge-9uZ5)l0J@5! 
zkB1Nrk$&L=v>{nMOcONRQHV4@B7Wj+1UH}Y#YjyEh1FQ~SNfP$8TF}h0 zVnu>{a9#?Sat#sZ3 zy<7NzVOS+IT^}$4)HoWKWgr~z;$21m0N*l(?~eZf*k>kIW+o*8F&X9x)H+0JIj|SQtyhpsyGEi8}={oWhd0<1)c-5w6{^>e`Lm za>W6HnB~vXSlq!y8MarN)6-KQ^h02` z2p(2ggH?b(8Hs`v3VaC$&;W2tx|3=}cp*-?QD~|VISU`A;VFE4LY4so>;0xZIA8~+y z9Ah-R(FD7|eIrD>*FWqO9+O=2FoK)uHMK2VGY$p}#*uD49H>N)Egn*Tr1uEN^5QG6 z0rw57T72X8C1ME!6B?=QTnmc{F% z`}TnVtwnQ~ut42i?h8F(T^I8z9c-*9SLKT8*LBbPF>mkvlpBgTJ%l12u!I=Tzu}X7 z%!QL8zU9N8nNfJS-rMlXx=o|y@@Hhx_C&&tdl~%^v%v}wFW_AJT2>~+R3;wwG` z`J6$RDLlt~@e>fco?1uLNu6T;f2m79t0V0O`l`TjSb{uHX>efyT&wyaVOtw82fR5z z{GzHAX~ea8OG^wgklMY*#!R!3?OTr4=noLP9TmyhBO>B%%*OjfnyV% z_WDb(@{ClpzOmu8%|YC9Jat~z%togwtn-?HfRRpGb&YN);{Z?}vI2aSUvcXwy8|x} z%Mm$^(ow#lVeufOSJ<|r){&=3ROPbe?NKyQwb<$-xKp-4njqQ=z?dltiBmi$BG6;T zB^|M|+E(hYkK!@-m1xPQt6!;DR^SRmN#{fMgAr!M(U;;qSOn<j=9S zXS5`Ey2LarE+=BhqY-w?9l91)aGqc#k6;c5ur!pz5LS;$9BLC1c;1&Rgstyo(6)8_O{DaQnfYk#+YL zSZeB+%wX6T^(qHLsvVmp`2;D~b$%983CIPWfEVi`sMtf#;A2J&esF@_c<0qZg$il~D(gXN72+*$|hw;zDSd&EaMQ{($sVa-aW zrO&|t(lhbzf8NZ~6JB2cvc)%Gx`J{g$!A@dQ>%?1$mg|=(* z#tJg-3bdeYg#ZfBO76OaE0BRQ5TK|mDoX5!HSH~UN{d^e&u5#Mg}fw08BL>}stQOL zE+AMq5J=Rf!MzZuEG)gS#{tApfF^L_9*alZ6k43I1YaEy1--Ax@}9E(nGxtVzemsTT~7<21O1l}(=?Tth7yn9 zRk}~IMb08&TDfeht0|0aSJeedTWaw#uSB}A(<}j{3q%YT3f(#5XH(F)(7+G47-MnSk0=@=Tcst8NIZdy$`f(zbK; zpJ*-6F12tC+!{C`&t@BmV6Mn)M-WEa<`tKH!g6i?qZ+GS&Ta@wV9v7;R|Ueht5UKp z5)p)8GP#+Rv7oY*g@MY!%(H5A+$f>6E}zSm2MNCv;nD>><+7O zW@VW?#U*ZQ8&E~I?;{dj;8gI;@4j}Ykwjc802$grTUC`uM_--aG9c2NxY`PC?m~aH+ z7@0s}`RNY$XOaci6_X>|8cobWb)THZ2-Ex!?%pUpvtIa2;+D=T^* zM^K~>v|NdRcnlLsl*ScE`Ti)dZhpeDE%~U`uYGHbtw9_($gc^pp zwk#ETVJ_3FwjioD%WFX3IF6m=i`7(ZR`F3`EW{MUjaKEVS%E9nHcqT+Ji3C59{o#L z+-d4G#UT7O4>@MQz7da4+965-Pyne~@e)vaL;WF{_M`K1^C{^*37JSxcB%Bj^(W4M zaR7IraAQ|$zNHt7Rez{;TvR ziUb`nyTf8S$;@^`nNaimO5I-LaRjll03C~OT#=ajD0BHG@?&VL?jy*GT+7-~u%{7< zH)k^l1c9^~JBTprxEo8PRu6Tvb=O9&-SsUm4cl;#P&Uj-cs)awHegou1q821{zaw4 zYpo9OJ|m3Q#qp4t1E?F7cOF!@NSd*LG$#5UY8Ld?^l*MY0Q3pEUPJnW2Sr8kU#JiP;HdjTW!fCTEy$EU 
zzEFAgVsLZGh_ESY*~5o1z|Kn+)>2vkH?f%1!f{M7nM>$_Lgk)|ymJN33&MVc7lrQt z<=`sF@i@$2e{d~6QD3P1MH>r#?ag)clMoFL zqlnT^a60#jYhhlJ%GS{4APOxQ&xnr+qtw6L$XD3J$qL2}5v0L+gNYo9zL=n(8~(Je z*h@`%THlGLO&*Jy+(gjENsmLhal7{e_qo6ieE$Ft5<&w#CwVnJP<+qBa4VGMhh6b^qHHWvcG6*D!DD~`j=e11gMd81l+(-DY0XZ z#HQSWEeA;)_K0nrPGIY9p7E#QhdE-v1s z1JOkF_Fi`#jalt6Uanz|qC68XmHtGr1_Nq{IcjL3U`T+xf zEb7c4s}f@gUzM0RKZqk6UFbKNLV74)%*2D*o>H$d>7*$zq=7GpoU3eHOfA~1Fu)@l z-LLrt3)!hf5Z|Mr8vFtT*JSt1R^JlT2(DJ)SA-VislDO?cA5457zd*-O(WVa`h{l^ zIvDg;*vv2uf#Nc%*wsLlW}f1j6;~4hj7~60h)M+=C_m#?JGg6&L2BFR4DTO@4-&u$-PD^n#8GX=Dw?1pai_ut$gIJc|q%zm~o{{V6@ zX%XR=3Ve~wL)#E`p(8`#Ln%!-N>ZSEXRKBg^nXS@kOlKiGScI}{6nt&NUF}GQ(2g& zP@OBMPx2Ff=nxKygas`t`h!aD?)u;LDuE_ZXl5?CX@-MDUFFvG{0j542SXwCC%S=6RZqYfCfx{+oDGde_-2jwtuhYL(0okK)Y z4bqMUEu9I&+96Am%&KaqqX($23D}AenxOo^l)57J%n&GnX>eU)Cli3*5rWA&M0QP{ zm%tK$Onsd}Wqf1vG=_^MDuWz@1cQacM}$(vf>4kwf-4Pf@7@ZgoXfw1{`iTo1@59j z>MMhY7F!R@07;5sAO#E$tyB1mVE2lQH^jXGO2YLh?R;Fc2~BV9A%a-pfw_6yY7H%U zm79YRbQXuqQbk%?sd_3~mi}N;A5F&bM+nrmb8d;32r0#T^D0ER2bSND%yNW8-~d{) z!Q8*%*>?NRHT+$Srtrjv(krr4)Ihki5}B+P%8^>L3BT%W)^U+cAwXD@6w!M#!I*Nu zIJsGj0ygvr6zs&!ULYA-vK5<3mU{OTP6Z5+DXy)nm!k44+kfKa6DXj}!GND+z|ld& z9TM?y?;A|Z+*T>n33mXyWfoqWCda%rw1e+5&00W8#9H)O4%G+t{7a+@Jb&J?cnUzw zjyi%c!2&|G^-7oO2u2;{vH~}`bXb@ELM`WxeGnN$(wc}=CuzmAlAK(iyE!nKpRbB6 zS-d|Iyo?K#u05tf+;*gGpbpE=?oka^MDDK7xY0qsrDtDK{Hl2~9)luN@&*;L{AhAGUe{3<4BYQ`wiWU`KZVfV*5Cp=8E zP^J5w&0a4Nudm1OS!hBcdy_mu>-#9l=MYo*AsMd|`X{_Nr}roxyuVRit?WPCqc9Ko zrZFb}0OJ+=NT6L|d!rQ!0K~VA5C#$Uo`>9?&xu3K zCV#YZRv>w0z-b0JAv7pu&SUXxw=44v^C~0$mKpPk%Mx4`fZGTWyyPWpsZ4V3eMc z*X#W=9fen$MN2m_M=@q<_qSad6#FtJTDu{}^;LImwSaVq0n-cQUP!hypVF z1yT73S)EEHhY?p&xB;Ok(lkP+xkR+7nM0V=8I}uz)w#K>%b9mLXVlUgPf)+<*)@l} zedn!}glQ=4E-cK+L0KUtN`fZhtU%4*>Jc6v$cbXHh{`JHDghQnP4%a!0WHIcj)%D{ zuIRQw-9=qPFEN;0cGLiv;%BHa3L39LKjv(;0yWJ4YZ9BOA5#2T*H?fET(X IHu#_a* Date: Wed, 20 Aug 2025 10:16:00 +0000 Subject: [PATCH 07/18] Rebase Experimental E2E and refactor to include tt_ccl --- .../pixtral_transformer_inputs/demo_small.jpg | Bin 
8554 -> 0 bytes .../pixtral_transformer_inputs/people.jpg | Bin 49606 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 real_inputs/pixtral_transformer_inputs/demo_small.jpg delete mode 100644 real_inputs/pixtral_transformer_inputs/people.jpg diff --git a/real_inputs/pixtral_transformer_inputs/demo_small.jpg b/real_inputs/pixtral_transformer_inputs/demo_small.jpg deleted file mode 100644 index f51ba21be8d4cbeb5faceca2b88150b50926abf2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8554 zcmb7pWl$Z>*6!f$?zVA)hTv|&-Q7Zvjk^beLvVL@*WfO}0&E~ig1ZC=5S(wn=hUrp zeth><-PPS|o>}Xe)lywO-PNx@UN-@31zCAn00ssIcn?j$>o&{}c_}GVH4Rl+c_kUB z0{~#g0jR)p0sz?2-Bm+Ql3GvSfEsBNfPtoeiMgel^S{yms-djgrGIl5-u$nO|F;pv z%G%8mO1Xd*kSnxuXq)h%7~kgKnCTyE@oy~r5B73*c87A*{=u$V8d6Yf1;tD@|Aj67 z3tKw7{=<)ja)iMS9{<$!kNnd+bZbX#P3RLDTF3!6Km(8iB>(9@^c^ZrMF7Bm0|0Od z|FKzQ0zi8h01&VK#|A0@0Ibgd&^G%Y+keKy+1%CqU*X`P8ur5n0Jtg#01SNq!21pW zsD}UYLDm1IZPZW~Ika6)&}0KR0M-CCAP+bKmH;~xaRZzH7x3w=8W{VB1qk$_Ip z!gW9@WbT3x#3QY-W!dyL=@J@(35 zbI{GsOm`$M(iwA1DYg=?Z2aIdpS08q>17c{MX6P}$DjOBePhr-rfMhAJZsm>l0o2n zw1=sCQ9`VcA0>LJqiKDnEN=&_YhR}HS8wAxDLs&`G*hOs-6l>X=u&aX$EzT(=!mh+ zxoGtl^_J#Ak}c0PgLGI|)2KC%*wRcO11zv3hJH?V*0u4adZ5O_lIAfx zvqVe;#HF$MY)-Ao)x_0IsN2jJb_Dx7mBMne2G~)~?qh$Tmc*fRL3hr)0gRjvJn3kt z9h;rGcol*O)1<8rF<72+rxlEs65Xw{DkMGAXnV6}7;l2AdDg0;^JbXqY5 zzjY$3uNYzQ6CerlgaN>7*L7^uZf2c@aNw+F3GZ=m(pTNljRkB|V*m_fF$yzM$xe86r=kcUB!$FeQLg1d{Lak}CDprK zB@WK@<#|d@j1kz`C}oWGciZZs#D>)K+n9d$==s1DvFIRMO7xpN0ib>-PBL+{P*Kx} zzBkZ;+$R%j9aVG}QPo8u4A-_;Y%%nGImMstLC@uY@~*PzF!^JaLz|Z-GCTN`)_Jb} zQh_A5fv9q37~N~ykLJx4BE&;YQFe0$?3O3*tW2S(nAsF!FML~OGA2?rdK*f?oZeDn z&N9wb1?hAaK|o^9le23TDHBq=GKg3BmF7tL>FEPA#EM#GeO&XzSJbbqOfvPr~47iBDF>zE$#B|g^4 zDQwEAceRs+`roiv>s3aJ$B@^@UIFD@Q_GIB4Qa^PO#geNM#PRlM%vCgVrlsyCL%;L z>|3E|f(~VpCLFQKp!sZ1h}m&doroM~=}`TK!HwER`mSw=WbkM@6)LN{1K+oN6GAJv zEqLGPk)6tI=(zOVyU46w?zPMiD!So_hS!rVB9<*g=%$lrh<23pA&;?KgC1-3PFqjN z8e57*!^^<)C6_QbPEXRsw36^DA4jbrV{y5ShTuvCd>FVkL 
z)5&MFNq;oLbO2CBoqZxo{fbAvaJO6x4V^hfotxv0l@H7`98g-t4&b=`Ahd?RC8^JN zsE@q3G>w*ux(gqrFjDkRHX1zip%)k5-9yKM{wTYD0`R#fU+?VF6w*5@vU$Iv-*~;a z?GtgVX|^#$0@E92U43&q5<9(=(0NKT$FC~+H9(Rx{ExavYn1m z5sXrod@`9zJ;UQ1MpT1!`H44+;T6rb9b>E3-aDWm6+7LbJ`!P}OPTq#h?~g|J)?-( z&cZRDPAh)wk5yUuEla%u5!*sn0oyIwH}(GK_mR$BwF7^j-sI8?!AphFk0~0iNYVcl z3~l6X_Q$$_q&&dexpFgke#aCgq-~{DYeRg?xK!-2fWM%KXXB!(STq%U)r!K9$$ewA zGoJxMM#Ht@$~kAUr85He*!iLOhMB7;uNjZZYdHI98ax)Tsj3o-R0(eBxj&s?jBJw& zTg4klsmwM~#qkTk*v03mZc%G_3;`-50V4`V=A)=Kd*bqrSU()7WXoHRWUAQp(91s{ zt$t2dHSaxo(5gugY%^hdEM+EnKj8E=%nF4q#h`$4y1}*ZMT;*%{USSvRlD1;8a)^T z%f*~5Cl#2B(l#@RGbBN}vpL64Z!=#s{Ebg$pVjR&O)xL$rN8sCgMZVGXD2h3jnOi> zUX1fxJ&hps*X}IPI9mW*tNmvxINn7`zIb&!kZD15`%56`Gv9CbBdlt%mM|_5yIMm2Ml4g%p;)z!v@P_<}?&C82Jt$tL+RoILk4%n}<0waSLq}d0 z=PLYEqUb)GW|Kf!n9+`m##q$;CFO|}hJC;XHY+v-$4y$^a(*`Z)gh zs5_3KU${vpQML7@ZA}o$ErV57D~mWcW`5ugMz~KP{rV0@LrH0N3m0*Jvx7VcK;l*y zsH}*NNA}_K4E`W56CmpSu%>+Bia~BTMseF1L<+wzG(z>8qHzPAj%uF(>DuFg`GTQz z3p3MeEDJxK^t^o5Yk$OrO%xdW$eAlxbXTVHo zS}{6J-x=(-mmU;ZDWa{7WNuPn_&!P-%kI~(EBI=)-R-u-{cUVfVMY=PyK~2E!qJ04 zho8ueHfB{-EQ}ePa(w1-H_>&?1BbJ2s4$BC$U0T!ag;YNQ*B(!0;1j!Yr9~cT7O9B z*AtVSO;a(o&1u+UdGxE<^0lcep70NENRw21Zw7mFy+|3dHv@jbp#=>5l>oSZ&mnM+q2jmD0?#R|+*( zL(*2ceNmf~DTB@+uYi2PhiFbwRvXIeIYIT$RZnT^R9lFl*(yPjBJ zYfOL_GQ)(GQY=_t14!v|(6u~9BfGY1aqp$>Ca74+=;@K)FcwdJpqdYEkA z&T3B^erW0!3cY#@N|Rf-E^jNYg4R%fIx2NxEz+eU4`yU1wprJXtC=l)V$K9(3TjFp zDoL{*)1SZd&VTqI^?YPve7G^`_6*PK^V{e_-I)<<7R0ELtj1NrV2B*sm^QnD(ZJYe z|G6JFvTt{g3^5|+5?j2zrG39*sb5~NQVy6xxorDNFz}BpJ zbXY`}#w&z3vOdq&eFDr07zB9~KwzH{ZV zyDfc9O}OhMbv`SS%_Puyj+zC_iMO}jhc}Z_cgzqB+ALu z9uwY8%}og*TdZL zCI9n0l9)SDBstPxFU`^5WTg9=vLycy@&^qKJwY|+5TZVr*EhGgHOB@NR@e?#Srz# z4L0sKq2$alq!QsinVoj08&SYSuL;P2Q|iO(*9!;^a2pR2 zyQz})*DabV)w$87Ae0;XTiaGxK27PriGNO3z#Bibk7`SA8sxNol?OTI-Bl;U+2r^pT88CYq`^Qm-NOkSS!F+M7`xu>i(y~E2;0BsQ&H;csePClB!{!@?p_tgNzLibWUAv8h{WaG|Hk_Y?23=59y}jv^(^XHXn)Q>;(wZUIY)kw zyO_nLV&8rUB2N=j(Q9Cvp^jjf}0Hs-`x7)|)*rYL%f^ z#u-rSUs&4tzK@k?%qyI76E~)EDplIMLn}O()zp=WeSg_*IgQ61YZF)yof4thPM4}9 
zPxcB>@uED1X`6^Bap!rb@JW#+h&_Up0~FPP`1*UZ8?4LFJJkLy7oK6!6COg=12@@JfsC|< zh`nPxF~;&oQeS211u|b>XBv`|(`p;qPu}lOnuUzT;>cs`5R9ex5<8m|75x0|ydRdY zK&yDqevKTnqV5_MnsktY{mE)9|H0>qd$5ufH~Q++Q&0(fvmD=P6vSVsIZ;uHJ)7fN zX^wAGf+|;Ha_eRa#M_-DD;P>%wHa_vz|a{6-Cea~yckG7Tp3Cn%aA)r`8yM47fz!&K1Lqg2Y3!nSXVZB`k2bub>1XbhZ0ac3=QVLt<$Dk{RRpu?{)W6X z8T6F8i8f~n$C`9uoHtRlXqB0QFf>Rv(-OZvrsH(r2J?L1g}W!k#pH}vq1U-ur6o(L z%+(4Fc3Epg`i(C$i!EOz*(`6500GZilgv&;7)a^s4Q)&@Ql)sGTLox=_mgo4hnuM! z1I;DgLq|sSy@+F0xz(cC&k^vOmUci8{9BKqBFBoQJ&;zo>abX zf6x{`@{o>x>|$!Go4mSC!|2E-Ak-U@9_1J>XYRpgQhTSwup)TLjgpjJ*gF!yz*k> z5f7Wnp9aR2e<(keJK9?9GiG*i9WBNBv!4n~{<@gZy@%tnrZf?#`b57>LQJ-uBUyav zhE9H7u!Ih{7E^ZA-a8-z2O=WXbL<^4Z2R<~aFg*SAt`emLtyGLjkT)N9 zr?pJ~s?=d}J4-wrCQe-@&2!L@S=O?@U!fknKt9eSkB{9pr@l6>mbkQu_Nb337f-Hk^fotY^vS z;j)gSJ1|P9>sigXNG?EZK}MJD^t{Ah7R}?yiA9}p9F@m*ki4%I`*2>0(bBekd$yVG zLra8?Gy4a9^ZAC%y4udy;l(2ia*Mn9#ZUL+8j@qNFY|c>OX>nI3Pn1Iz8RYtNj>@( zrC+vUweo{J9M)_pC!$^qe37ZMqpNaTYt1y-5JS1|oY-pIIeE4J06#Bp>0Ri^DnU(y zzWv{Z!IN)lef)6uKF!|(mJth8v5#QT$U-_q4Bov03ST;RD7ZkVQM`kBHPx|!{1p~o zB*LcPy;&Wa8n*4nMd?+J{SAkGA+kbejxWdD63lOg2pEZva&8DiSAZhukHl4Wjx zZs&>unUy^3@Xrigloc$t+=KQWhTt~-I`mak+S|H^vCRJ(p3K=)DVXCAt&oHvrKVoc zZkj?AO!Ctbc?C#}j&zGeaQ6*0Q6@Uu5{hj);3*bc2C9Kl`goGO zXx{gaatit`N;nKjpT?pv79Pv!s8H;JhcN~dw%K_W?{dUDO( zv@~W+u*q_x&w^b(tvU5(C*kTl%OZ$O6SUaj?YQN+^+LoEW8SiYI@VK@I(_`h=3uK+gO*7%`FWilFWVFs+{`of$%NG$~;a^|Vre}PRk z!#ILVf8Jk{@p`>y{mpmN*^on1R6;&h7ZFKJxtRZb;c zkbx+J;_lVTV5dU~v`yCq)bNsfjG6-0WU=Qc#9|z05~Ty9XnC4P8Vtj;lw*HwAzr3` zUsB-PoQOF>W4C(Z^r&N4_!NDn60SJMXJqYkKl4`|eD>%vmR%C^+5I%+qxG`R z3ch-t!>o~Ibi1_GFPYA_>9;eW=#Oh2V(Y$NK7&pAFN%*uiZ3yvceK8|5P<8?jSD(3#T}44`BQ85l(r#8`G1|En$*h5sFzBh_a4#PERcUKCfpl z*Jo$&Ow7;lS5lFGn&xENyjqO9Gc+3x<}n$#p@ovC@@@CDS;!;uayZh*zck<`!Tt_ClgVxTZB_93_dgFV&#NIJ7_oG;G+dzts zM#`SIS(INZMp(=`@>MYfu5&S!R0RG zy~+_7dd|-suHW8d>C;nV#!d}W_N3_CrH5%s`v_0VJJuzM*OD??V2#X9NnDThXf2^{ zAnc#2{tQ(uFOL@eBQFZdcTE(Y7wghsSF-_?c4y3bYJE%j8U zAaq%3;vY{%2JPa?wK{;4273}UxMEqbk@$QaWI?0v>A$={Ah~n`px?>)gSi&Jw9GMK 
zHt6c?48EPmKrkA136FimN9BL^yaF4gu%74kuB;&cAI3{W<^iD{2;^gfmh!?0 zD9|Dn8_8gS+wEM)efq6HNC=UW1~ai+AIFBSHGA|QICI%CDtr88o`#hRq*CtTV2K}3 zM@>@)=Bf9K!Jj+FW1OuVW$u%uOK@OE)ka3w$L+m@+fLERz?Yc&pXF>zsFT0`o{!+k zkSUWPDwnb6nIPZldI`i>OA`dIy|HwHMCtG|em>=bb|1 z`0c%##WW%Qs(9>U*xu1AP^_4=cKn6$vm0Z_4D|kmmsHwv;AhXn9CoK+rJS6GZ|F;^ zE=#{bOQ*+*R#6Z!Bc0f4zM(CJ{}iSyfGZ-Z#W7f#N8Qh%rg~en3F&Tym1Y6%-jg7$ zy=8c0o%A4cb8sj}h#1HTJyHSXndp}|P(8`uxled3(>Z9lv4&eij~&dx+sN)RW8Y2< zg}-_3-4=nL8vGlEMFYvN>E?q!``Cw|9zZn9w@?)weW?_R-f7i#iV@oz2+CoEw+58U z;SuW9)|JTUh$sX*qg2PEcsKi%+NWJz(j-S-7yZak8L1;s)%Y?dK#`KC*{8Js^{p<< z3lZGZn!pRhP(>U_k~m7%OInbl9}4g7t!5F7qN8qI_4VZ_dP2-03Mh;75>H~Tt;bP& zUk8VHtUijOL$Xe;Ej5-;;?Vu(`_7l8)vX(_B#BJU;YR1D6qu>u5EO)ztDV3*f}KFT zou6(Y90n(ni;nJJf4DY&njweF;l4KigP66Q_imYgG4v#B@mJL$v1ED@;nl>bkE=7! zryO4_h25PyZ`4Azcy=3FH)EkL?<_ZRFou^hzAm;)?g9Qu^D$|R+?z83 zk%>IRo)iLkZQHLFnTqNDfdT6HNrLjL#T83X9Ve^j*`A7!na3z|$vQRcM0L?G8wRd+ zaZY9*O)R#s48#@L!BlTNUwx;GoPyAPuavUCR9Ln2-fNJaG4?Edo_T_6>X2ti{rR7ZU;#g{ zQh<4yiCG6+yOGA~k&@YcJ_$d+*zf1)G!`QUP~0x`>Nv;hVQFW@ztuD*8;%bDsRIl5 z0GV$@1dOx63^C7lXT-BFc9GIE;M*Vk>zE}mJ>_*-c83S5CGKJing zrluss)-+f!q=ZON9A+O_+AxEXK*YbAnj)VSuu3strrT5w0uHnBvYB~$N1K`|4ht5+ zZJB-Lk6B8gWPeAjtz`)7JX@SjULfD{?=7!tYKkN^K3MeX!-po$M_KqD+(jI+va&>H zmlHYj^g`Hb{F}d?>FaCps`K!VblZGAxm(b2kdFU)q_3|crT}vTFBOb!X6Y*~_6Pi1 zUtj#KlF*y=parMIGM~t{82GiyMM(eT)UljJz+J8EhwUV&0P(Sp9$ GmHz>irf1Ip diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg deleted file mode 100644 index 16dad8dcbf18374fbb920fa4ad944e7e9aef8b89..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 49606 zcmdpdWmp{B)+P{w1%g9xch}$q4>T@KH*QVSH0~ZD2@b&>0t9c|JrLa8Aq4l}1Omx) za?U;9z2AKEW1jgp`+=@qYp=bm-o2`7Rn@QgUu!4?YRan0D5%KBjmnLJ^6M)qtCA1Q z76nCBl@kRE1qI~+%41YClx1X35cyAu`Va*bxqXcMq58W|_M3og+RknsZk8|)dS_dD zcTa0;h`YP3rvp9I&DqoCH_gh^9p;22(z`&MT^u0vmLBwu&NdJ?C*+_l%o;-P;Oq>C zIYEC9SX(+m+$?$N|6Qh?rHhO2KT;eaPV|;eHuUb69x!)X7{unE5pU<;W9}Zv(Ld7M 
z=`HC!AeN52^ctQH9x#Z9ofFLZzs&mE8h1yS!yk5}%>UgGZlsUD-Tc?~*H;V;I}Z;R z2|hk2cV0^yXDbM=wX-9ikEIKr053luij=I6i={Oj;z4f(fx?`m8Nao3Fw(ct31t5_ znU7Hh2V2U`##TaCQTcBk~^ndGBfp~cMsybRiA?_-w z%KHCEK}zL?{vLz<8T%7(D9qO5KlIyJ|HJR%>E`f9tBo}u!~udtg@-$mR{*K%ujsls zc)I_?jr7ga9rBy>XO4r7Egw?h-?IOn@Y@e1A7qIB4)_1)7m|Tbke^?ShhLCKQ2ZZd zNMX7#sGY}uX#C%iT! zq5^^<0%F3y{rp?0gdxOA&%@H=PZa*L-yI3w-*|z!dpKLe|ET*DQU?gs(t*zbxgwq1 zAv~5)H^`q!vYsA}=8i~9JRPNx__9PUB34UB7s=mbcV|yGYlt+|*%|5p`G<;J#4eUF zsFU>XFd>fx(kEneB@`_^EHz+G$i?_pt zoNe7Ko&2QuenTBe`K$AXhwtBKBKtP~wJ`n?Am`xmFG;_lAmL#7+j<`!8;GqXGN~}i z{ClGM|1J6bw}}0T#qWsydz5gb{v2ENKdAX1;{PYVxLCTmL-c)JAku8QkiRg)Ci9;Y z4RqCIelM>76FC3C^dE3S9sdjLJpb(^!v8oh4Hd0cWXT!GDlIF>%PR*oO5m^*|sq zYKp2zlFNGi$ubCzqg~AD!?E5IMxO@3rT5E`Qndo;7seQdN~t2HhSJAjF~{MPl`O~g zA=ncd3b-q?9Nh~0!r3A(e!^ZtA~JY4{r z;2{q5tb%d*%k=>(0Fml`a9<2_MHtfHoQS+Gdo+BES;MB&LQ6%~2x~Ijk>0e#D0zB^ z!4f&4LKdsf=+#{q2GN7HO?S8Ibq^s9lY~jCm>k^gHB5qo-! zr0YuZz7Xze)%C*H58rV-FH`k^76|r7g>y0p!=F@QXi3|Zk`eVd?1WN`*5ymT!-ho8 z4^Al)j>h$F8_hG;!?1UB{klEV_8)gXH2v&1z(fAlpp@P9rk#>izudvZA-*sCJ@0cR zr(71JAjBTe@-^7<$w1xa!02r6 zTJGa9%C7J8heNAenH|C!s%jK|eU7%7Z@42g$T}A~cNCz{yd`>PEd&Z#W|Mr1m7X)3 zn$=^|j28LPbLSMZ%bSfE#$EiZ_E5H~7qHkG-Kj@d4b3yjO2k=Qz;TI&umB0=ChOK# zqaX4g_iGeA7A?=^YUmrk$}{RLw;irt5qDnTFi-;`$}-42l|2ztX8uombN3hw^y?CI zdUbv5;Tn!a=DMHU)$$c*84Jrwn2Okg>b!$P2jTculUwJjnIY9z-0n?t@3y8ZXP>Yn z39-oAeWi-JT~;I)@`Ot%H2s5U2e>S$yo+KI807)l_OXL?on&m@QeO6d!V^ zz3E4>51F=voJZ+HbM=hft6XG0}7o2pks6$N{uV6$0iXtT=(Ypb(L72>LggOcnd zz}amWAo6~i$My3VU_{7$+&P`26JSiEfzv;-lDJ3!ASZeL@IgY2ve6(s1z^C43!7M_ zP5ykq?M&9R(=hCf0=GAHLTXO;#{?!XOy?X{^|F7zUKp)zh+NHxu&wiSSTr+Cw%n*s ztB(AZalpFQ7d(=CK?nlS<_yb<_U1OV821p&vbU}(rcY2+&v75bNEUB8^%XysPU};t zi`(}P=(7sF<;dJP-&M2fj2x(Gx0xHFSkAVdv8=f za1rf-Crw0bUHi|EW_<54S2bE)rG4hWJej;sOH$c<&#K0oWpUsVv*PV08Bmi8!EWeI z9b|k=#XDp9z&xN^+&%G$_>^Fa<~)FRa-G{dZPCDz@?yxoHa=}#O zuZndbp*jgqmmGdJ_iAe;5EA*#n})ynY~meJ+GbSqO*4(}^19E}#KsqOYw*}AW_uu6 zRTpn|_Dor(+2@O$bf~~oGGZaRElVXiPa0+l^vIXSXzD^}ob2Xv#Ag;YgZ9f$B*shf 
zQf6%{lRPO$MyN0J0_t50cuA8NN5s7@7M?~)#x20f6CKaTwkEibJw?SYy#r+RCJ>9&HZK9|dQ$$U7R$B1uBNOGVFui`oo_T+R#!KzU)K*gUIM@OW6(q*? zc5tMl5iS>P%3mY^RK1q<92 zUbr-WHtpKRa+1Brc4WDrw8rCBaoWHOU+=QCdvTVr(w3b!5y18vXLy!P>ouPe>(b18 zVCWyR-}#GzxLx(4(k%#1d|LPNY8QT-5yCVXAA%@T zqXRpGnQ=@C1Zz4($|vGg0@h7%oerBv3%;-RjPhPeI@-+wRL8|_J+{;R7N%&#r>J(s zZncVw^>@m+hTD9tohfcszTW~`?dRKJt~m$PF;3EuOuSy<)_RGSv6>^hgA6&5Ec1$N zk^|TF7X|D>P@SjoI~QqNBap@l^__tswfd^OXawi&&J?RzjotW&VREihF=ChW`Oc*0 z&?CyT9x5Y$d0VQdnT^vsZ9HeQb5#348DD=xu%Dfwgvo6~yTTXdg3FqMX7~69BL{5C z@2U&&4ldTrGQRI0PGz3o-j~_Ixz(q#;`hcfRz;osR;)1ujIsh3MlZ%^?Yz#nmt7|> zz^nLXZoW@>W_!=S`{Ch&v*&mZVmNqped6V1-a@4_pYouKUx|8|P9De%Gl?gQ-`#RR zOTB5#58{Wdw%3-wov$9>P`uu^K6cM&Q*%y#R1qDyqw06}ZcfD4W2QIlU=YqWvfu6} zF~AGhJzo)T@@s=;?948=QA!LyJC~cvQogzJ?~G^Ef$a}(?7G3B9f<|LVD04Ua(r;L z*~bH5YYSUn8v@@aq0l){=(J*irY-82ErAyrAD`wknbn^ozR8va zz<&PpkaX4Kkn+=uSxAhsflGUox?b=_M=i(lwHlku3RSfR90l&^(Fo%)p>oU1&MBOE7c(~M z!q^Fl8{vbv9_c&?cRyqH3i?Q%O`TP1N5qcm>4mx)TGmllgOsy6Q19`+2%?fD%EsEO zfR#EC60wkaR;+`uVPTsvRX|WN2WLZ zER;tg$xRMn*hV>n!o};ITs=cQ&)Ib>v3z*EslBjvtX^N?3s~$2LY~W1W`7I{mIUTF`amd@x%Q7 z@SVK0?Ew;Y1>Z0a!YDutFYp@&OO0Qt7YidO{^hKbNu7xqRgJDGMEG?BQ9}F-TG}YE z#$@-(0cy%7^GbljIK-Q$3HsY@n*S}p()7on=yi0S}-xQbVL4kZe zOKE?d^?^vjmEF*++d@*lsu1_)tyQ9`L{Fp0mTQTU3g;^YeL^%|>b7TMAOl=K;Avr% zn$czgp`DZNyoM4&vnPLUBxq>n{oWWHrgn-J3jGW|7p%{e{WF{ztAe zWDdUU7S6;R#ZCS)#?{NG@|xroB8RnY0+znM%LiRePxCQ?$3M$|B&W?GAyzyG7Y}{d zsza0|H=W(GS0_z3yJmt<>ZB%~?+{;RW=M6?2(66~;37FB_4&P{cDrYd& z3l=jFz0!n(>rCv9eLaGE*2((VrYI-}to@e4Y&W6Wd=E<~N#wC#5^1}Ia&v~-*4v*? 
zEwBL&SSmTE(aREe0E13}UW8@mB1!VIB42n|ys?WXM?6w>)FbA`&T(l=o{vP7hz=k=Q#a4o5@7;`4mXaUKr^XuV!6o=k8{oGi@GaF1_s6Le$6*wHFDq;O`G&ET%U{*efPZ3WS zWvJ&?=7XmzE~1b7iBYS2NI7V#b!V#QLnY@)najPnddr|56D3hNIi3bO^6przn27T| zgV;xrv}uCc;(h%%3>;j>ef^*jJe%BD;LU)BhP~y2NV8|Wa$JvAM!)=}MGj{VdiAjf zY3n{yEJEH86P^JuNn8Cv6*mr6vUwo(IODiEm}p0>6-_6~F!&vF;;E>dZk29aG$ZAp z{e&?Yh7V)PiTt~_qXpF{;)Tvk*^b;7q$$#^2hLPJdmW!AV5)Atr=3^l^Gyt|rxQt{ zhYKp$=+C50jjD@>ldJMtN4L3&bifZ|vJOP2Cxp4!NWpN(&$V*4di0Z?^x;yx-C25{{V4vB5YA)QFE+P^Ol&<+AW;rxWQAi!> zV`9-Bd?GuyV;e(XK)^QyY&qOiV8M4xwk@1n8suVeHd&`zZT&+Eth4@A{!Iy5bT?T! zL4j3lzETx2LBhUa&gjRuv)Ph3x5ZJsm{Q*o&gU>)$|7zPyEs*5!e>AfG9CzzFU7Kz zLDN$1rTeU;m7dmszpzNUnvl)T-OnEt?+|?lJAAa?MzkGxdSy!o^q81*UMGKeX5Bgg z_gz5aOV2~BR&7X2dEq(P4m{i{llR#<-1bannNBkqVb)?5u1iY#Do@L8R7ygmWOXz1 z+TO=nXL$VdFwpH`tbW2QkuetwV93)`$aRrTA;xR%_y$A&#;fx9rma1|H0h0!WtxDS zUF>YRSH{JwSH00W@Qi*hx#f|)z8LbPI_wSJ#v_Xk(UgX;+yilrkL2NDd;$m# zs6ys?=YH-rQ{Qv!zM3MAT;4ba;p`XErEqNUwCl;r zQ`=0p)v{!#r(jw$RyjHCs<_k}gd7*JVz`cSC&rKp9 z*y7YbT?Y^=DvR)9%fRL`OL`EsHLYD5K(5b~@7{Uh3XGLk(8(cK;tk^>2uHuA-QRt8 zg-bvtMq`xGai<2PIgK9sbQ5DWM4H;r2y9;ZyzYx2rH}6mV<*m5I62JkZJkpf(yn45 z)9XvxHR4HD&@UYFvs#QZmWYr>ERKQXeJ-hUs0)PJcxjU$g1VFQL!$XkKAIU%R*f}! 
zT$MNuo8$|k)$Hise|rBqj9o3?*7hb|_hTaSsa2KkjP(o$IFP+{<>NxyuGQ9aX386> zdO%1n7!^?+&{xh*)TJhwZh$SXbK~*)?XGT3Uwwfj*uygp%%5rP@wIHMo=+1@<3rY) z?=IoM;bu|#tn(p4(cbP2DZ9Jzjzm^Xj}^8wUubprOjt=zO}>vuYm1O6HR$x=vmqmT z{YE$*w4^V7d!-3eqRxck`6W46w7QSB2FRn$lPECdp6!)iLJl^rAkcuF9qg}5-r!9U z_`I%;q2jF(*^J{ByfnGx<~8E|D1L8YJc2lV{;|GjTo z#{E(1nTQ!NO&+GgzA0-ta4hqB1MOz~v}6d*R}7Ear%RTfz^7%`j@vqndwzafyBO4z zl_B`Ug|%k0h}%g>wRJfyuAo4qAv1}d!r=6+dv@#M&;e7@M;3wNBs6HlY@HQ56BbcDlcqA-d5W`ui44`9!cgXREB#XI;0~KZ}JNTS^4HBKtV%8eeeJs z4Gmcy|6Tt^zDV$Z5S@UCm_gtnJ)@wI90`-KJSj8(Gh{Iw3t1vZeTZiG>eZLp;Vq%x ztyzqy7xNbi>7Skd`bpRl*VOy-lk{!N;++)eX?M3PI)T%~9X_jUX7o%Tt1IbFYE0V6Eos;^`>+x~g#sSs zrUEUioE4tsei`K^a4rhM;o*IuwQ+q|x>wKkgrB&mq|{CLType8Q;3*2OdtUliW#3tblKf%s(xSB(5ETVm&s~^_aI%06x@8oK znKuAE8I4m1-CzCqA@)N^T(dalJP}uE;pY%~O;>28BQ%=$&7)BI;6Rnhj;d4F)YLb_ zM_L<8dfxW64&!#GRLX}hHYn-TVLkgbD%K)C!O*j~Z+-PCwj;i$dY3`(AeKd`usc!`++lTqTl;At1!p30} zLLVk-H&Imu2#P&e{4^wZ@O)N|*I&ePI`sVZsP!vHkdDRT#@nBYKNSl`wdF8wJ4&u1 z7qYw2GYJHWA+|?tbm3+Yrp7SL8oNI^wwGCS2TJ1e^N)N*uA|MqKa|0D&Pix$Vz9WJ*)9;@ zj0*?~8J>(>DAJi6CChwjz$IT$uRzoFZbrdLOW$HaDpk+Lv(AMZJgoU4G6Z4Q!6o$9t-S?6^z2=tsqomo_4F3$F7MF z9Uv7`Bvd?R_$HC06+1_oY>9g8OP{GpoB{!ye9{zp^Oftk9|+VQ`yx!7Ry)i4xX4?M zGorra>+=TPtJwZe-D-#SBMvl*bLQhyd+>duDs`s!(=!uoLC>gTG#lb#6B{+v8140G z9TLkctICI8`ol{fdE}k)e=izlqYNMRM)fP^tO6Fa;8w)$W8R{F4>LyeFUYbqJM;`1 zPXP_hQb(;iTWIl&qT7fTj0&xLclCKm zzmPSZdDzzyeIW%{FEF0BkLLOJ!h{p_%SU1}&qxhTFXI@83^@vap)ibcA>11i5{#_b zdNN6i>nP#7T?M9rLU+YhY2{#R6uVj%FV4$(EE#FlcdBX5T>&OSVqe9rETj;XsOe+| zgFz&f=XDH5JKf#uGvUeE>iQiu>}U9c-BlmRu?wXht$gAwn0>QtZ~kV$;tL7jwVEfk z4X)s4&i*g`z^>W5JPpuJK4w&=Md@9_h+l$Q+KH7ts;*yJlorICV-0Zd5K%UdgXppM zlQ!7T*-W?-smYnT#NWV5M1}Kp?*Vxj;Vm!^PmA67=w|p-b6-GvrV(fuk1M3Eae5e> zGV*jyhG^?2WcnigW$+%Ca(0^$vEnF0N0XDHjdCq%p3b!4k5tul7At-e{Z#+r9bzpv z!i(whXH{~*9oVGJgeLDSN1C@qKqFg}*Qu!a0^gIlT(??*#J$^bJ0o+g=lU_9p9d#Z zZp9D!fQxjvENY6l#b{(Vm;_5^i4t+y&K|aZFh?t-z72oohM@F|7bpc3BPMbm5F1{o z+(<;u4C0sbdp;V|^j8q>2-Yh);&`#G-4B5Z5vrSgdJHI;ynyZzP^hTTNOL4l!zF|V 
zEJG4}Gs2UGr-~6v{z0~CU2gu>9~>GDu`jjT`P$`TIRzB+rvjR(@P8Yuepf~ zOST;#O8&N~VtS=N*b+{vClcx!OJ}VcFxtC40#JQlg^TR|}28)5o4d+jwyC;?k8aM=Vphus}j1 z(MDZumh{Pqi1|7ntptldHT9UDSfPTsE~!YhTA%X@tk(}IDPwP0L-y-06bE$;ye7jU z5Y9ToViN1UBje!FWa?{d*hoR5@9KSgVG^n?aesp?W(MiA4%#^SuLC}@Jg3u~uu4x7 zccW5Q3K9*^+-7!74VAZOK|4iXriy6Q>;^Q*Ie?jPif+vGg)@W1?*v0*SA`D*8JYhPs1iukA~Ww>%BC>tpRt{6-XOW;49z%%_{{cz4lQ zK77S;4UL3*1Oe>NXdgB7sdaE2Q!^OAY{QIrQ;H|q$&G?gLX0VD@<>jGM{hIjB?O1S0V~^mYZM|i8iD!-r{59h?wP;JQ8^;!-7wS*}#6~?nILsUUWqG zhF}Pv?S77U)6cY(3TDiLUecKKSTynz<0!*hyHcUhVd$|>T7@3UL`8K|;S5Wj;S2f5 zG0jp3b&o8HyoSo?M!r?MlhWd;(*P&`D7+U7ws}#7I;OTmWitzfA9XZ_?pyuezIqn^ z38YuxAPzrMU?qz@>;C91WWp*D>+`h^BnleU@|k_}Dk3+94t0s&9uBN9&;vZ<7^`{k@8)e(fMyN{+0MEayeR*EH=r@7n4izUp0Q{!HD*R#;TU5kF7aQ3C+h zMr*u`vtF;H5o#=?-CByc=(M$s`kJy4dBa@x`JID`)wm%JHrP@CgFbvJE=Hr=Bg-wxllC6cmFrDX9a%s%Dg6NbfYtVCkIj<>{9cA4GiP6zMH zMzGSj1WOXrayt2)qKr}hUXpFUBQ$OL@Ue1V^S(d0 zprW{lKw$;FcT`5gxktpvHSmS*7h=#RexDumC7f!*PjNyTuR_?&svWk-;s zX^25ll;^gc<36l^-K%ClqKCVQQ)?Sm&^mb>)J%Im7-(Dc3uTcN-<}V00*$f$ymLQH z++Vt7aOKG>bYV+R69VzVFf{454ug6`u^DM4d~b&dJv7`hy{{u`KfkYk!0%&DS9Rzk zbJuRBt{{GK0jNre!c0j#^jBEa_azOs-zj+!wK=9y${eWN?e7|^+Pmnzqv2XV6KUN3 z(MyKwc1%ka*Vf9`;H<4BeXh`7VI3x^zrmQ@yAvP38c*fS?-g8|yJ6spD_IMgy3ftJ z8!+olw=gJn}`i|Dlh;8|(?6f`aq~XmWv@l5N zlzMGX#raDSE8QjT2dm=C;LcN>Z_t#?Zs}4w^Z7$>JAZ>8zfi1Y=dTj4i6e!i$8GT< zZIfx$U$eh-$ZII2J9y{k`wOL5NT_jHg~EuAL&2pnNd2hv&bnKM_fF_q7?ZMs>Pfq- z-(u5L!w)`7Zn5G3wNaYZtgecW&V{!>f1wy;9ByMP<@DsIqLqLSOY2K+*2hy*FDOphP2?))J9jQ!(%^#4dTC?9i{{%M#b>MJjVFp) zWptbP-bIBp>ZRvrOW-97*beEKOzy)-L`6cpi!gXKO8@lL1#Y11zh-V+A z4n9cP)-Kq6{QS$OITJlCS?k)1suLOVfj15Of3Fz%i76k`QX)cr-Um}{a>D?nhXUN;K>-~gK zzhfm!!h}y=RPdvHdy6c#{25S3nW5{W%Ci7D zY4sPXiFC)Nb zGE(iJ#PFqt^nO)g9Oyw{8>i|3SNK|1?1g|gf%@T7SJG3fYqHdm+eyo?tbkdg<-LV~ zPR#gq?0uSx!V&2euNi@#DtlqU{b3!)i?8WVe-ICI&5tcl;b-(&;M9M*;*olMC%4&p zgXp{bB)QnzT&po#IQ=E`LA>8<7sod0!>6>JE#Gkz48E8DupjFKEQ1%&uFzbf#TGc{ z<*sBzXCn88^?Esuh8}3k)+&1vP~c*7jG?UXJz=J(9LNX8+@a;`5O3p6KQIy2tIbsQ>z~hB~4S0 
zszqeg+CRY)V$k}&=1}r>QbU8QG_rx5o2^Oc!yf$35MPLiBcO7&sJ_bdVkp*A!f1C&|3xh2d@`*`wGf|q zXfR~CH%$k*AX^WJ$##DzoHatUQj3n1S5sOmuCdEBDc_tHWEM2Fb5Z(oZ8)gC#DDSF z>VZ%VH3s*1REiaDX9~9D&vRMN1DJN8d@H!3=A}RP6T$M!W(|?d@;ebe!9!gGSI{!G zfW>8Q?cN+e5yK!9YL>oZiP61oLs9Sxh4zcY=mF>}=q!Y(GOeLU(og$mSxlvtzKKy= zzIDQ0ScwUzygw9K_(2UGSWZn*XM7^n5xYvr8I8)FkF`Zt1cBS4ojIXAX6W}h}l_@R8j?fW@VsubHIozdje@!0@$568O1@hMqjj=V(7D6hD0u zU*Zp}bVc8ic<08yP7xF0wi3xD?rJnA+{Ob^wqd>8nUSOHD}TYbH(T2U8X>k#_czNE z;BY5;Sp&Aq)6;C_{)H0vJW$={AXP_N`CZh00vEo))YuK4sO%P3M)&py8{xI`Rt<;3 zK%yOOnY4ysG@W83r;%gAe+M@zHwL`_vJzKDcnOC@u)sIT5XaA9g78JxX|JN;8e@0oCoy z3z`ty7VR-1!OFFax1~aFuu%vDz~(yo@0AxHcANLPd-8g%y-wBc$~q@p%8XZTy#(yc z2MiD0!R@!qT)Bco&F=?PJ7W^1DC<87p-}UObK4ac3zRnEr%*cx=%wxzWxla~^2rQC zQpZpXL$rIZnDgbQm+DR!^rxvQY^9OA*Bf7`f;xu?N>Ni&L;bNKrq3lEvhp>v>Mu1v zo?`Fw6F8PMmj6uaSbX8;8=V3m;ww-7CJBSm~!7JrP1_sd!bx?t2ISS}`%&kmXBi#`=#mYKSSMkQ4AVC`?TjNIXEGwTghMR?-RW1f93y^02p&g$dO(je4` zCSl|zIFyoI#9A_5pi`p`_lZ)1=p~V_#N5Ga*7uoGKm72=C6?)|%x>x95t|PzwDKNi z5^7d*_~d0U3?4&$LfJK~bs3!aaH;Mn;*!iNv)0e)eXw<9WRSMj^bC0L)3&GjzO9Ma z!nbaj4K6Wg&Bu+_@xWgwr-WXnF|J=bTtfnDgoLWHb)(37wxdFEj$S2pYrhX~vM(=B zOA%^`2bYg9jmTtcQD{{nt#n5x8tTko({xEmOJE^iwY-iI%~H=c#a#A(;g6u&VOe@j zd-6o8VV33{t5^3|G%G>1%O~Lr5d+!z^RS6R&AQ0?ucZf+oQkx0hEx_u80D8?`#^5m zlJ|a@7%%!^(DRE*zk)P;oheDr+HKE^bDi>HSl;a(S=fmBk)*uB z%*JA`+si}xW#3BMMt?b>2A)J)R^|C6!k6KujA<@H5_eYn_Gm|1+J>tzWbVpsFO$5U zzsl08Q+`g9RlMxabGLdo+?gW%Qh~@$>k8j7F7pimlBn};A#)q*)0EAQl0qDs0i_5F z?Xt3R$$+T*uoIH5gMHRJbSCa-&=hi6h;#j-bJLbO|2_5P58#t_Vc7(KGqz0vIn*oT z@2v{>wKmTzuA$44B!fr$fC!Zss#@A3`kSaz0kIK_o|d2n{zwPujy&StkaNDewY&Fn z7{5?jWC^p2Y_m1)!bMF+&yfXqcH}$i$O1e%^4)ZlzY6fk3OpeJ5xu|z2L5Nnayr_A zjPe$)e~Rx&9$D0%)fa}>U-t)Y7fle;;g5xn&t?c$z%E59U>>K7@N3oDLQy=L8MCV&ApMzNB&;v9!J5VsA|jk z0|y(QM5;&kGlk*2eLag34%22%MosYoP{+eGEGkWd&oLB|^EDJeW;R&Og>7c&76#4x z87`av$e434?P8Q-l#VUMRp>UTs#Qvo=x!;CXs*Idc6Cw3N*|=S0ELjiWN*R=b`{uf z$25#TSASVhA`RUhbxCD>S2g9ybU3423%JlA3Csk83D@};lE<2mY$Gbpj8GVsu!TMewryEAqf z)~&*Ay}Ef9wkpwjRING0?R1K{Wb=J6cH#XX^Qm=KzkrmQf`1Z3wsUpYwO+$H;D*>P 
zxYkp4?yYsfl}JQh{2OKcQ@cR(PC!#p+jz~-`e$dZuBD6I7Av>OX!qW^9jEL22(HY~Y%%}j`ggktbyC3VgWLrl(<>>^RfWCO zMzQ2+g~17T-zzUZ{7Bigi#hjV+5D;>6*l}?Y`t_RPRU`oAhWy$>b#x;i3A|B!#UawIuT2}fcXxgA!_k_(@cWChX zOKpoE-!BxB>U)*FioosuYjVJab{KEsi5icUv}n(ia!Q(&O!m$Rc|%UZsRFl8L}5A$ z$;zXGyu*qIJ62bh-8Y11c&js`-A}7-OFx~oiPX!G*Gt*%&K_!UPbU|(;?EdciIM1l za>LobFUe5QJ)%>}lmrMRew85CySUNe%?m0ysjy`>jbar`Q^FVv{T_RoYoCgLCMuK5 zGVL%uah7ul;8-8xo;hy2T(4s7K|KQLF`l#t>B=Yue9TO`k}{)S>a)~g{L*5)4pUFC zn_x^`F3^(^l^11a1~UIRH0G{1e|u=%GcneFZBMS9fT%f~jAH-MbtG^}M5piOV9d1R zZ#HydXBAQHVy@B~V6EX~W)RvnaGuVsC!@Dbr<1nePu8<9OU-m$muQnHt6Uj2nCM4X z+-+x8dMZyo${vt8x)b0j$|r9YqA?v^SgLT+>AkXi{8WWWf9N7RcjdNWy;|sgBr2)* z2jMT2_+#eSX>(s>{uDKHFe~{<6fYkKeV8*ql>K$8PdcQ|H#_qM1GIQ_{+sdY($ z#{^m;9ZS3WEyy!*kLSkvAO;av)hjEZhw6j!(zLvP63n!m z`;{xS-EMiY-ySLZviJZ+yPv(UUh$<~BK8F*C%=s`Q@rqT0D+>-xUkxatR`@ssnsTX zFM|Z=e4v|`5!^;WvNmtu0fxk+7`ASSY~IK2Jo-Gbg@^qW5_+Hb=AtURmj`-)^$`qN z2Pu()th#JMw&-%0A_$eT^GN2pzDIAbWE4rBx4!F8lqq z4+wM`Wt)lmw19zS5fS*QVW|b2_I03MsQwL)rn#ZeujC%$wk0tOBZ> z{z93o;S-Nk+OVz8<(!E-Md&TWUdS|TrWGj60lH@+sP#C*>%|tk)MM6vq2$Gci;>pm zYA3%AUo?;I@A;AeoEo-Fm1yNIZZRlS z60ockH2Gf>dp1D;Q9L@Zvfjq zfiWAtWpLvaD^mwkc&Izd3h-ETacuKoMy%je?TLl=wYUDJgi*eZ-d4pe+4Ia~0(3 z%iaz1^^Jt%;h$!2_n^0A``v(QQzjruZsHlE>Jimur#<1q1{t8dXtj&H82HV%vf2;^ z`&D+|0#_(1X>N~YkZ^T>P2b7o#Ph&6RQllMeyS>SZ$jPnRR1If7s|!T#-zI;iGaB-4Ub&Fnl=XdC`@Y`4Y!z62 zdESs1gS@mAB|t%afc60M0UE~NFP9)cr6eR^_Z&*qwvN!(fLfTm$iN@w-NCRCF`9@k<z6Ln z%YDq+F6p1XJl!SRLxh3OKP76Ehm z^%E}@@&;R2z-sjtz*2*f*|f5y#@#9jRW7C6MqP2nzT3X55H!JLsZZLSXmJRAM;ZO_ zeMFAQIEh?Q!M+TTEgM3oYOL`5)}lfcMjf9?Lq0Vhsk!a@@O!#~s`Xb(pRQjhH%gR7 zuUebR`kC?)uWSMs%|?OE^i6|<3*6EV4@oqR=61a0Sh;I7Wmb8 zPFJ5MFQj?auuT+s7w@5YlCjQ{8~9-F@eCGBu0vB*@r=FYgYE|F>sN)7c}H8uTt!+! 
zxx)LkKj)U;5Qloo>*bZ7QD`!mebx>@HIi;S7|P4EWIC+Lin_ z|BI@(0E#1MyM=Lw#oZQ`#TF;HF1EOaK!Uqla1F35?(VK31Qr4bkf6afI0+7m1xN@1 z0tC;``@P?PZ{4o0nyHztsj04+uG7zR&UwTKf4?wcfyTsfG0ZfZ)wtji#fwB5fjuP% zKhDulqgwnY^~<)P+p76- zt+LmLVhaIBpBleTJ>&uX7x9w+T4kQaTA(iLW~b(f-e<(m=k?UgR2P)#f5cV!@Y3kX zZcoc=EsAmBpqAHv$gzI@kn%$YGP3iGnDFxLyj+LzUx%oi#lbvlY1dU22Q4 z*En{q6rS>0b_XLZ$hb&t79UdUe;7nhb0^*@)DM;{8i-1rRjIIhY{tZQO?*B#R{I2iKc zGtswJ=sH8%WXk8YDxXdnAp6n!>{yx8V7!}#O-G$j$4?7f=tq(Z4UZwtvS(8DB0E7% z-1L_!PU`@Wge~hew*Fd=(SSFB2D?1%jy+k^cc9*vu*R2(;?e$WjP91UE+%%@Zw|hv zF8Z}_(;rn9a?#0EU8S_nLBG{n4OI=VLv4QS`LyCweshUcuf^+9&c^?YX(TS5B>R3s z+d@N@H>@swl?^LZn*K50w+C-#GEDK4H%SfUm1$)BHkPD~JkFgtf2OQi$I$6|s|GT+ zDqdB;t5VMDN@7qX*A9c|guiAAFqLNB8RPve8awiqwsMsXTi_idZyokT_N?R}Ffh6u5SX-$qv zSOysKR%b&=7<1^w8TyhNm}^{IsfXnn8wpWP(=UxP>gKOWsi~Q%(~@842zV3>B7%4) zd5GSfmxUANhJJa$Uy8~~VtL(5a?)CE>zn##0tPb!&ey%-ndSJYL~Hj6+=&1;5_Q2AhRC#2K2j;%|lc z6Ztcb`!h*oJj;)FO+SfMg_)5tQFiHUo18lL6`VOnc1XjJ{XCgmH{xvvlW(psY^IiS+)BwiY< zavN^oG<+O1VNvqX(Pnk0gKtp0^x2-5#`%`y&utuTcp0Wzwim5fV}oJ;=Z^CyULB+1 z@1^^Ie}>73R+tv)bm~7I7>DmK_wX}qrZ+_!$eLnnp41c89X6UOe>A12md5QqWPmYB zEOgJZ$yu=YH=?*r7o1%W{Gw-k@Bw68WY@~c#A=N`&2+VtsvER*KaRQuZDRMZZ{HcM z&$af3%~6%lYPl-=W9FeVrG@l=1QZ0&`i<39zdF{Lhh`i6ivNwF8HYN&G@2Yqmb|T^R zTiYngbK$hY=_vLeh713xxZI7v`u7X!6fJt1B7Nlg@|JbN3HhKZPg~r9RosDD@#IUZ z903^6t6L)=n7BnJO1en4p7kJ?N3sdG_31F9)S{ED&oPMKq6}uFh2uWWZamQ~z#Cj4 z-u=~%cYd&B@ zWxV5=f$x&cJ?(kxUm}c`J#jv#NA(m}ruK(TsV24teq@Up6C7parq!QZ@T~|e8(&qf zyX*3Xm+Mi(sqX~O=Q7t{K8N5As_ZJNCylz-36PX6o8lH*Hj;ZOnVlMdd=bfZPcqjnXcJhYO&?Y1fX z?%8DGaz)jAL#4b6qjbeADt=FA@KqbnJI_#RpQEgU8zO&Be_ zXK(&KW4hP4nR13mDw`=*;ddq(zdy~M(iOAKT@ZNobT%~awW|L5xYcgOC~*c4|aat_!*Mw_i5<%+46g#_hAB-C!c=D(Rl}W_C!A0@7j~{=*;?41N8LH>1s-o z`??lHRvp*kE;B~!DVbARbwOaC%(C5U5IqUL3=+)t+oO4;>cW9w|KRw%4)(bXcj zu8x0kB0jdZ4=$1ZWxSL25930yE=M+wx8;Wv{Bru&XiyS$xHB{U6{^Gkx1_SA>woeB z{~_l8Py0`7>A%c;3ZRHG`Qx@JIs0uEw@a9cKsPi`3lYbI0XL+9MzjBq zohGQvmcx}zJFIO(9!Iu8L1r_H_{#d&J$RZrZbBQ*GaSV&&^KnHm1`%Td}@$R`9BPY<`lxv 
zEb^fou3@UBM%Zga{H`Wag9B`SF!EC;=p=5af_#fd09#NbJ;mPVvwgvmDWpM_dM{`E z0}U88a1ffb5x*T;NI2o-3|E@Wv{i{Q#1s+L1Z9PgXGJASJq}3<6o$d`_+SCR!y)T_ zvBEP>X&|#}>*b)1>2~*MH)m&_S5(T^^G%S$VT5CWyJl_&v zq&GM+Ab(gIM{AC%yn@r$zRJYp*-~;PBuP#A5_lrwePk>d^npMl8)P^24upO5f|gQ5 z+-CPq4F8nV+;nOr6U&}VL2Wqa&rD6?TYEZVN3H;dZ<+*Jn90!`IzB0O)d5*mIhfxP zU~^1|MG;hB&B5O|=~yPxQA;+`kUOBz?>Orkg{uTr%V>aB85RoSq8(r|Ux%Qb6)BgY2NDAIq-z%8n|W?96F|0I5swTl zVM}aAOSNZz4Q~^^K&_Z6uWKTRA7tlVE6h4?Mx`#N8$sqqRc^nmK&7*Ycy}N<6Bz`+ z8`77en0OJ$*%NYn^0{1xTBXV3j<_zKFj&&>TdAh=8}T!^XuLYnUPngK%Uh1mI92X} zQq8x_sQZk67)$AF1ur3z5qvFJ#Qj6UT*8xM=0)^HTiRlv-aNwlY^ak`IjH}ZE1)wd z@U!MEZP#nnntbDOM+*_HMUJ^~onY@w?9SzYv3i69xEh>P@w_3PJI=w#3#n8kXoJX& z$UY4v&+9OSygm^UTsk$${yUpQ58tNTlW!KawH^DupF~UHWTdFPOVy@SQ?Ee(#u2ql z8%wT8qKECQu|;^`k|;s``i7v}=+oc?Sd=U9ojgDl|Ap#L9GiGG&A^7(#xrANWWrU( zL3om`o^2U}AJprKh`6UbML>Bf!4Gy5R<(C8M911vPkN=65*NzkTQd;-?ngCawi77-&x+63^Y*nr-Zz0 z9NDZCgx;OcRpIHCiCifPh>-~wGuF04+?SxGEJg-W;2AI&htDh`M}rc2?eIlJ6v_NK zaN9wT`_e8Ct}j_fNI-QWBl_HoLF^9rItCe5V3>JN;8usHUXdLq!#5g_Wui$+Q97hk za3$Cm|b10B8(DXy)I_GKSB4l1ZReg}&Wu;dX$U6~7sg{nNO^^w{Z{xQga4}b| zbO@#;pV~dKOof!Ghn~PYP)A~&>di^4d}^9_(Z_jgwmg)eTRE|UGt<*s$nWtFriqy3Wl4L=1I}>$dgwP8-)Y~nQ9IY43dNHLR z3haxTW=w6GL}{jaqKUO$)Tm@Mdhr}OM~6iN&OLkvdl8|{n%{3R?S*F6ZV{~FbQ-CF@tA^nQm5-v1c~5c?GeWK(7^=6jOX)>8@L-mc3ef za^BmDci&x}oOQ)e!}szb=roSUqsasPPPqBd&nQoY#Ue-4 zC4SUsF;*HfSaVKvhIM2Cu+#Df4z=%cz@bY9ax3N0B;KMK zU%+s#MVq@FiGLW;ZsSDdTmb%a=Dcm(Fmu9Mp)9*RxnvbeV&yxgN^u2(c8}OPrWDum z;@2D`$~}aHcG;S*ekCQhx!M?GXLhUxzSnt@AP{pD%_y$y!WNp3aKte$B%FUXPa8^K zwbvpp?YY7WtuqjfHBy>RmB+0L^3+X(E-;#ZN{xP6uMS!>X7^RRYvam?BL006`c!W-%!Adk+kp}^$819L>qfOSM(Z5 z%gRCMU}7GV_!I}F6jZcxkj~y*(9NrZ`N;QX=$5@EFM-cA9LKX1Uh3FYK+9m4#Kou^ z7W{~IGR)AduR&JNFnPe`sx(IHp!nM750@Ia_M6cg6=gWIO{hQ{@L0{8W`)sLRWQA^ zz*SloqIB$zO`2;lMNF+GF9zxy7(qG#uweICjks1I(LCsecv{i0sXiJgSV(zI)4KVI z3eXdp*#y5Upm!waHsW~~a4s!6PmvcUD(QxZKb+%UsCK{V4FJ}$QX{k3AvQ37?6pQK$`@Wh-$ zI&Q%aN3~C1)XZt~iJW<@zSO4OI*~j%p?@+z4`gm*_00dg|90O=cZo6=L|)pE)n_3_ 
zK9cf;JEtNElk-;SSiR0TEuV;Ug=5=WAAdLcm1@G0q4<%CTArZfKMW)}|MaDuD6c4iVr()fLjmWQEz(G<$($m zO{QZXSM4Jb#6CEo3O%fj+4NJ5JvbzfG@FY!))RO8Dqmyi`*3;ALsd=ogoLgH9!8tW z#=m)|>NKelQF6Z8c#wi*WxasIB@xE=bZJvNY+0rnNidp+-JInIhNFyK$!Ij8hY?Dy z>gPB~G%)GFP1DX2^m-2RJwy#N$dc#4f1J>gh{yH0YCXx@8}Xm{tL!IK)F9w%al8f{ zPp>N{5}oL5@v?$O330d2wZ5bHm>iST{mCI8L@Jww*Ee(%mLhQC1j)2;T8GgLZ^ArY z7CZ&J_8-rmS+M{{2hxIr4_qmQG78A;t0D?MgT5K8@ryEYRy7xo6REa#j#heV`iuK< zWsD$We$mz~-=YgiN?x9r1y}j{ylty(pekMFqn%m9sCoD@-74qCnMiLD?!4H-{brs| z@%2FQyLmZcp7FDh`YmSkS>t0aTHyUSJ(G3~p=&Pifr(9>CYrADLcb9YlYE$TK0D1k;+8CtM~YuPf?idE(HHqfeUP22Z!q6^8_q>&}f` z4EsyzBcW~h2UhQW$f|D0c`Dto%!XYqT~r36pS(L^D5lpsv+E>;HaNFQLPxNS^m}DQ zba6y)U=zFq<{cVX>jqcq8gWvkZQ5MPKg0t0W|8Ci%fK=4qzwv5kf;!&M8{%ZIYI6RJ1@;AJ9~S9%PlzJ|M=DdJ z?Mxr&_)?bb_gG^JEGTEAitQwxaB{;1TBJ)k&$#k~G^%X$0R6G%XCqg|RAgvy!BIlg zDWs3b|23bLR*ZW*B6g66>aA(?cf1_h2JbJ5Jq?N|vkNwLZcmhj6QAZ;s8 z=TylOPd>jr{x>5X?q@p*~guU1T3cl2Eb!a4|u3_>UfG%GZ;fMQ8ggz z67uLvo+9eGjP-r^)8~fhaq^mbA4)uhrCr)6TcQN+Wh^M$T{pG1K4>A;!vn0(BMcv& zrh0lHXXxF}NVi2#@xKSi~e84 z*j0N9DUCa32NUCxIij>z1amtnIn&Csx&=v#M3{vurC25CZnhIT?+f7NA@kF_-@n)AD>) zQ&?17P3%ip!q~R`wu_I1!HULEvP%Egss!UHfhbQDCH$y4vjRB;a+=; z7Mr0msD6?dOBKRxD~(K`O_-VK7MCeN)D72$r((&u4X{@c!-v2=csSXw=CA1K#zZym zXj82@DT7Fw!uXjXbd*~lu1PM4yH0OLcxqV;>DG15)+(n#IZgBjs&YDobk{*`0H>@c zKyCm@8XS+C%|zA^?0uDqXx6;6uqMdX0j6kPE!FVa!@B9nlJ8U1K?b^XJWkFz5M_0T zLgoHfE6ihWm4a2wy3Qjg2;=Lmq?o#9u2xdtP&Rjr?Fo){>PuV3Q-MrBW`bc?o2=Rk z2{+IG-eHe(RfUU7I01#f7*TgF2oh(#UUL#9btOk|*S*>?5Lc|F)ud`eXNsLAc#XV` zyZ7O&@)}b7yVNQ=gmp-ZBQa-3NG^tU{epuN%O#MLDK5h?6X?h*(+RIQh$l<=hk?T$ z7GB<$F!P(E`Zo`G%?5$oPp2-%Jv8i(YaN#oJyd`JxwBbt&HvcMmORt4nAIS;`CXr-4gIU0Fo?fz#I%|}pB(TP< zXLpxil-^A8t4%!|$mW4g8nsBuuT|m0cK123UqEzLU@a|nf_ogYdaEEl4OikDNiR;m z&zHN5ncZgvMtiB|cSO?_D#Q7-!h3JV<>{l!sB^yuR~ZU$(2PeAknC{~=o8%!5^smQ zAaCT7$i!!_`Dn#5(a`udv4d3b_usJL7UCsJJCiJXKN%f)sEaB{VTd$mk$NH^LzQG*RQ!cXw2&_(%sZOgqro5KbIl?=sY*ok@m!MAix8kJ zpQuFmz=YGdQKtE++v7lmYKPqiOsN4|5nSz%FUgNth7GKITjkci9>7siGMbPXsVSgAAZW{2!5a_okgk64nCSL7Q 
zNEWyD@cE;MWikqfMaNIUHq(Vh87ISFHExFIL*A2-kvCPk{L6Cl9{wrZj;W4RM`P~} z>@h=V7owE5+%ZR^t?DA=>^MDgidYMEJ_t{ayV#rc9keHSdw@XmC$kiORI${uK|I#i zkBm~AQ3SN}ITigU?wP{ok~*+8hy72c1Jy2=bu5Q~EYps1sz5?a{6zZS8ARAtn{WJy zMvJKZFPBnUhAeZMqZrUu@_H`qfQZDf)kZwRm`A5hxb^pdGX2u`@o4IDbxfb0Db^H} zJ9^N3#(I~|^{Z=Tj%LPDzvzLh&_)J4N4j%Gt$r`ArpU!PjfRxpyG=g5-Sc78h3tQ-h~@wk@Q)HQ5DnG8$fec10IFRs>hVafRQ-{$(XTp;f%5+_ z(4j-aoA>y0p$xuf&nXX7f(dYpo1%V0%#YdWT{9qHe zwd>5T?8@gSz?wU5J4TCY`VmBbBbED+&RaYHpH-spC*9OwE`PX+f{I&+6(WsL$IhZ1 zfL+e_NebLF_mGC+n!HLIeeiDe57WCBfqWSRX4La0o}2p*318z4Y1*)Jxh^?r%efdN zw@7j8^7D8T)xM}D{(6YHl(RDrXTheTUnH!oj|VZ4)SuDD5<$Ch#N0^|6R}!+LI{V$ zy0!+}hWaxz)J-C7Eu)5Y7?XA24HhHVc5NIUKH^}FlIHdB1haCTYy;_7M*GVl|IlIX z19u$bHcP0HFJ0DnKWP|8Rw_JHNK`s%H3E_x)Q=S=6#zMc{OP)!| z5tO;LYNzb;L>J=u027Xn`TfI3{2#`UgE_sQjF33mUFTu^Km|)Vqo%(gB0CsZrBFwbehSe9uYRtFc-8%`*&5C zI+}0Y7t*XL6DzeeUXIH8O*E|2 zRD^of8#=9W_npKrq8Arn75$oPo~GJS7o3LrNdM zeUm$S>Mo*H-ifyPEPv+_ERG{p-!UzX%K49;iHV?ZIP+h!QF^mk1ng;?tdAVa7y?c> zyWBh^o`092208s1dZLUK>Nrf42+~XWAdH0PBUw=K_7(H5eM%@qhjxb^NC~;dTDj6Z z=}DuM8>zu74d+}-qQ5UC{i3+w&i(PK`D^wb8@F~pJGf|};yXc%)SAu5j+NAuZG?-Y zq{@!|!&jOO3}X@TDiBS=&ne}Nnd<*A5Qr+A#Oyn|?@u0<240`UIhNy(0`zQD6sQn< zf~n)eFFqon1B|m+KA9|(BRXwDsw7K#I6~Mer3Md(Y{VsWWEzyi6jn*yF5h3|sxnjf z$RL5e=VnPv^L4hb;6>ae6oVN&iMHbQ6x$y-NWNRe3A;8eahmrMFj(8xe*CKQJX z2zy1HAgpKX!H5pag8LG(XC9!-xZy=$>j9560u6F5QndYW{*{G8jv-I0GKmEr50F|* zZtd*=)`Z?r@FW$sC+&zl)qy1zoZW9nya&k|b1ln-I{PeU+%nIZ&)u8;hkKkhZzbo6 zzV6$xE_WU+=Ef?bP_%#4 z&hHEtdJ%xkat|JVv`S~$s49YU@lSB@5Khqmt3ttMUsc~&EFtl3*v-rKZ$hyh!FUI4 zQg)*y1>_?`d2z87y8XR4V$6c9TpeqQQKz>I(c-94&Ua(*Uu?Vn+&nuoK{`h8Tk6A9 zd#KHrRfs)0%S$zG%AJl)rM4Y^sWg;rx9>c|$K#Qe-#0dGw}=y-q@V5FeF$Mv2+xgUs3N_b6<)CweXXKg1%mwH*bUf5G#j8i(cU*N+1}eU+^x2 zt;FOxg)S#m7r{b{xUZjvjj2`^qo7G2m&axqI*8WKOat)3H5>}nMHH=6XVibDOCVZu z!POM)jC`(_=N2hfL}MP)ecyC-=AxCE#`bji-Cpib1bC`ddSI039lcXZq`N@fHW**x z^j(ZGNv!+6r=HElQR$@Mvd`8*sFz|Yp0XF@G-?b`QX^ac!G=$pidl2HyrpiHBP&aOTvl;b`(x5)(^rndt#N1CCnnezE@TZ2N8PlA-lq{6-BIp=?uZ`bYtkW%ImiSRr=KuZ*QG&0cXO_V}*P 
z&(QF1$B^Xl2gvZ^?APlj9Z%BB*zoFJ8`)A|y{^R&#Eh<0dtX+h6*Mgr=SyECf1i7c zb&YSFLD5*{Kg7G%-z&uZj9PgEjPrf=4IJCA&IY{O49%{D<-4myxr8)MLtfbj{=-;x z`Sf#`Z6O-qCk}rTO*AoEkruFp$k%P)pxgqAj!J7u`OVjDBU|W1q|f1RJ~5rRHcnhv zYU`{{Ic8X1IBeU2E{A>hHX)A&(yKHTa2zILMvnjNHhYf-Mrew^wl?St^?u>( zDPs)u+*%3I5Pir?Ag(yF(=eM1eG)8rS>o1vco56Ja`(WJ5;(u)fz@JWkTIg<{|p++ z`%~;X@vJbMT>`wbFd=stPfS>S*)XqjViQi+zrO42{QK(q`LFr>uvC{?)#-7>H?@xc=Gbz^H5Sd_iPAOf6e<-=kq{`|q%Ppf$2 z{qA}0x4z4|Wd+{N?G#Hk&%1C9Ql@5*05;z%ES=nJIj0{3T~eB@_%}mIkxR7-)X9(8 zaeKvO8EqWlpG?CU`)@f+Vh22l$b?-#IgnQ%qxd))tsd=D#&J{|)uvBv@EW+>1F5po z)gp+cXC3;SkdBNP7%_D!<9pBdnHe%v=D*e`K~xTc*}>!4GjzF_NeSLbD-)lxW6oaq z6b=+$8sB3s+S93Zy!nT5k8u8O^PH_w4A^&up5eVu^YVU|1#E7`fK^2Ot^X}6x}>IT z1?Ne-C&^h7zG}obfBgp5{n;LAuy0&~fFSwV&=jz~!eiKOW+T zQ94CnvQs^9|HDX>guyVKb^-Lx^ZqWKo~VxGr~dZsC$C`K*&6~XMS0nudt$0^(s16* zJ399L`gy6x)#0C)I{n3q%Ipt2h4&a+k!J}xMs>i&AedV+&&l;4h9Vb%_KW==J%%Z& z9{M;aA9%iSW>TpZwvTknBbnE8l#@qL%8dV$u-x*az+A%4qr7xV6+c75AWK?&!WK7P zbBr?S6~+wV?R~QC@I(E=ZUN~okJ_%}ErJQuq8$Dva0G1>eaVY#5#WBc`jspzIe`J@ zdA2KrWE%niUlPs;7|cgelDEWaKi0S)%5=O$-)^))?gv26Z%bo)&_KcfH?SeKWADy4S|C->u!x&gXJHWiTKKOCz6 z>h%7XOoj2NL=jmr7QPe)~403(UICy)3))wKvFPlw_3T%ajNb#TkQw z?6WV$AkM|90UQK8!u>HTt^{vp`jmsA6R>Eprv zX)gD9K-P_BTGm=R0X#bYUF!{KNSj|?(0aPyueQ`=5!;`RXRG$LE3!e{ct$77o`i+A z7+OnDGL*~yGKZ)TTxS055#nA;3=a20O!e(QhVGYVUA8B)BI! 
z&CYrR#~@v>&Zc{yDUEuhx99@3b9mcZ(WQZ2<<7{+fW~FMJgIaOU z3e2Qz%Xs4+{=*o<8Ybikl2TH?lNjW3{9(m+Iu?!6{FTKz8j;6N*C@PY+mDo|w2zOU z<@gJ+dP}aC@7%MXe>upa+bsgeN3K$`f~`OtWYr#&OufiEL@$VP4Yrp}cKl7pjB-!o zwW5c@TPonO;HxVKS!96gvDxmsltQ-PSyA@V)%N=3Y1+$oOZzMy%%B!gZVx;m;{nZ} z4S-an5!dir%^&7I1!Z%^Sb}NwuTJkCgCKo;M%C9J!Q_Xmah1y$2l+4g=S~M)mheV#vCTK{k~vm`xr47+zM!hz zLV_-}JuTaCv-TzR3M2M@OuxR``KHn!9E`H!u4&RlpTMCXmG8YJY<`qIi_P5}XZfbn z=XAa^K#<{h8Tt`L81!n3HiNbLN#=>P0+Zv)<@k;_Qy8%mEN|VLReUEV^w%<3HpMXo z5JQNot@Dmqpespw_;UuV8zCxWis{ix>fAAu&Fe}#G4Fnf!j9W9oO`+E6F%mdE}!VV z&G)@QEJ-J!vNp0k`)|uZf)T( z3-)1_;rT?Q3k`(hQg*y5@}2FK|Cx1{22kt|C$x`GedlRDG02^owtPlL#`fVVB=K)L zp_h4Yz%=W%0{+t3X#d$*VZxuAyJuExv7Y%ZcxKKCb&3eSf9^k={CYfWAn~t`Z@9CV|e= ztg*iK6RF@oL+9CSD!&13zo$2c|6$N?l~=GwMF)R7ZQB(1wG)2_;9a{{2$?>MZ=u&q z936u|MG8I3NTwB}BD-SaOSbwH(&eZI{wiYG8F5+2b)`Tckd&@NK@_AjzZ|>&W=fAW zaDxnxpL`in1~{cvR+#>F5=M365E0J}kajq^{S%%Z!gK_LAzH|I^ zg=|6ZEz=)`UbJ~vCRS_y1*5_`$ae>M660hZ`8?9{U$u^v~hp zoc4UQgj|ZG*HynKYs)2DYL%YnLT4g}qWi$OTgKU>ww+66{EG{P#am(16%gpyAF(;cg%)l%3lqFzSzLZV!3a*IZrZwI@nByU**SJXG()FJE4{FOG21(z4%ea zRcEY6yS@4qU6{9Ad;c()q^bvRnbc3%sWLqju&rOEOT9adWc5TluD$2?iU2xe4xt8q zFNoZukM{kR4)KrZVT1vSjrP|z0b8V9huYVnsjGdje*feyXU<+F3xY&up9f=^RdjXw z`ES%iMi=ZRMbG|-raI6|gQvOW)35r_x%@w6_PNvaG25ow_RfM87JuwG`eDmvEzzU0 zhve|*PV_3`gRJha)5knJBht0;DE=aSG}lP>uHc@CFD8(V8o=hiTk(u705PCF5e8XzYPlN<=sZhT~QXRf zj{i1Fw=#QG)Yf21Vm@uo*7NZb2)_>pAS^AzBYZR_Zy1u76t#)Dxer!7JC|Ou=-!gN z+?3y_&AGb0Az4L*mLPR@cNX!NB1wiQ!a!3I8>Z%DBL~xgR{GY(i}R5{)Rgb<*gO3A zXuOi%s7HXv{$GIjAFBOh`yU_t{{lp1VMALbga6~=K+Y6ZH*^Pke4YQNcWtFdKMJX1 z(e3%;Co1=v4^01i`^|p_{->RUi<~O-s)n<%UQ45-mV!}=SJG!-Y5QE}yj7|6=wmCU zMQ+YPbJB!KpRcc(n1IhAk7*QC=N%3e^P@%K==5&@{sG$F0RK_b-Sy0f)=U|^9GIce z`TqmalEgoKj|FEQf!R#;L8QjO=+Q_3F7p6YgXfikoJJp$_r|J6pY$EURvIgrmD#mj zB>vd9A$%O)tcCD$%y!l%ZJx~b&Vy#F9=y)dqv4)yYZvlq^jm*o^HU10Js&|Uo#)Ef zO(2pzxv@9A93P3606Q~`n9F&yzNf{zk$S(A^4!1i(PP!=&-tU5S2RDOwa{OLKd1Ok~d{0TS!Ws)B&PoTMLT zWruS7Lw|Psc#gN2wv6(C zR8+BkM!H839Mq=(9C_y5Wvc44ZtF}ox2Bj;#wO+3sbMO-PEkFvc|5uGQLCFaZO0)* 
zLTK;p2E!v3t~VU0p>SXj+y-{Oeotjt&ADj2#fAN1xwQ{?gYX`?&;Q^)CAvkwNcxaK zZ6ParV6%bpUd|g-A?4E6;gNI>_9pk%of&7xO+TYLGneg?Qv4yh$qwDvwQCmZ z(L*l*N(j$bkK;Nh<9_1a&>pqNphm2c{gy650#{-c7Rx|y0?dN&0fxT;U{d7q!4h@i zyccWr2A*`kWoT<4Nx`muSMowHz(8S8KTxl>znr~vWw+%ot2rdKd>N%}nC>BV)Y{vl z9SpFq>FyFpo~D-I7uYrUlJD{pL01ME#$JzYkbKYlBdol{F~R6Y)Yc99<&f8Rxf427 zsTsLm%n{LR{TgS%4R(k`O_MDB#MtE|BP*8taL^;6l@9C;u$$KApov>w9ca7 zMdVd_27O4bJsg@AyTyMQ17p*c&*{8yvsosX}T|OZ%$ut{s zunHW~-R3G0cM#iDBtIr25KDU$Yo+~-TP2V|&hsj$HVeIjlg`t&Oo^MptPTC1BxB3hrgUG79d!ph9U zyCYQ^dWbtOLQ`PysW;(W$Cfry`Z{8fkLaAXBG4PSPvUoq0xl9CZFzKGBxLKf`Y~_{ zaM13IU;1l-K^7e7Wh6=pvk&PYSD-Kzz60TIp?kHuNZ7mbJSeI~P9MR2Rc@N8L$yd+ zicpUC>*Kl@b$GO3J(cozaH8S;X_04#t(^G0o_!B`4mPSS9sfw5QWNESw zf#_6#tI^{UZii%q`4+f|-kv$xXZLwQibo92Z&NK_Kk~jDzsQ~+<&mnwJ_Q~eX_9F3afEsyu{oq*~|X{wRogb?T-b?$CVG4(8X+?p%xW8@sma8 zFX_<*(;;_cFFJZFZ*b?4VL{O~Uu-;jp(!%vTMxui5xyn+(ycuTFH3gpgB^2i;sx^k zrv72N7rj_N%q@DGP$-!(pV)ncNC+~mdC}8)!@?F z!Zh6^yMLp<#-B0E$x)qrbX@ZVKSH#Tr2^Z$=N$jLm2PV; z?YwXMQNN_^&wbZpZI`}amXk32I*J_yKab=F$S&HQZ7>4VgDzK)uw&hyGTUF|W_^ry z4-Jm?fEcZ-f92W{&PdULM;?T@10k*c=xd>@@RN z%%o6IS?HOz{fFVrHrpOMe&yNro4Ca~-ROeMXPTI8_Ij7u`GicaQ%XKPME~H62Isys zv;WI{^rfY;yw$oD{Vhtty7M0fVz9uF*qm^(JkHNLI=qgE3^LSZ#+=@;}~E&pEtRXVE0r}#V^ z4`<$ABZtr6LK}W3(Z4xllO6G!Yw}U=%IrL2)&OGjX4$a$K5XG32M*wZL;Mh3-nd>4 z7@PB)vxAd|`7_FFgzjI-bTY5bEx&D3fSyDCU+Z`U)2QY3{MhHimfR)IZLh{TK2Xgj z9c2DM{UGu1OvlN|#^blQsE?r))C(qN+sSwRum~^6$H3A!S#A7|4o_o;+4^}8&*Xe} z2m1=Nh;JtgA15aaVp(<7+CM&a$vDr1vdHv&VHS>0{sH8AwY~grBhGym?YCoC)P!WF zj|@vO@_!7DGC9s)}fx%;LTlReK!u}28ZTMyk z>V2Q)bG&Cz#{%17x18am$C5-qESA>@IfW#QvNo!#@f64$t;4)C0DrN>u==lBTWwl6TbY&*-s!Qr1jnear+ z-a`?#{{UZ3d}ZS!t4rARz~JZ4@DcL2_&qph`JV?HGhq;N|HJ?%5CH%J0s;X90s{d7 z0RR910096IAu&NwVR3El~YnUA1Ji_{e zgpfuti0@LO+1wl=oS*!{%l-<+3N43p^G`JMPc-w-H1bb0@=rALPc-vSB=eKbPbBlt zJo87kX(I$DPpr_I!Es%)Q~Ax3vb{hb{{T9F02W;nr7uJbe+3wRNw1m6K+?qx%~HyC zu*`S_n12XiT$6_JGRp#Rcw96EI$)$ycmDvuAYn8N^Zx+#fROC}0N_dFh;uVl%MHqS zfbV-mD#CF800Yd2atgquD};tfSpF{`S*iS85s1?@CcnWz%<~h>Mmm4_(h;^1gT&13 
zDpKM){tGdO=ZRyEGXkEN)^!A?C9+b=;{>#-3k8DQSj(b_sc!Hca~BMZ5-g*L%yM%DJjBgHEm(|JARvnbcbEs6sWqM`BB)(l z)c!^22)j)W@rEvw)@Xn6NU$!YtRZrM@h>IZ{$1ihH-k5ZH37wuqDq6x!uYjw;OE3Z z6531rLXAzOSt#?AInA?}F41QxY-yF|JU}d;Fbc^%1;i6EbU+;91Q(IfM~&F*4V3x?(+*+8-fYRKK(@q7^D+<_KFiNbHdPsPIR5uhwc!XVHh# znW;4X7sJ#Sy6QLjq9vm}L4*GQh)PX!lelUvNzMKMsCN@Aq-JY?7>qNhxFb1wdUpeb z0{Ep5xTLjVU1B&CH$z|Y?f9Dg^0GY`1hW99e4qISf*z*;P3ehtnS41*4K5;~guq^4 z8V{IJ01Kd~>!jxxY`M%u7ykf?jobyZHfASlWeVYr<4$JT4p<3*lt4LP<5u{Jp=88= zfQ5?*Lai}rk23g0f;+*FtknKB4-kF%sWkou=ZTTQ6ySvoPyQ(g+Y1K~Z8GgJktHar z{{Y02-l8*&1;0;uQtOyImAV@700C}A3<7Q%{Xj7XT{54(4SblmcNK*2n3iq`M^I{c ziVc^a;%@rHg#d*}myAn+9LI1mRtu%r-|$r&Ty?+To&6$55bhyGFHadM5GBTNufEah$)#LR~}xnRy2-_c12~ATWm!6%aB*DmI4Bf9EqLv0d}B03j(+6Gg@duemcs9q4A~QEXAH&6X5Qw9{{R%!RvRoQNX8%q7sGP2le<(ATHY_*cYrL;4-sH#tVRAp z;)>XE`-sL%m4_r+9YW{#Ihc+Rf;`LMYR#PI5}Zsqj8fuYv2xYCK>OuJ@^=E#8(PFb zmn3Q1R+5O5boeJF&z>5Cw|NchM?ScR3(VyVU0wbLbTUA z+TilTQoRXJmHHz5)l`0J01)d`^*?js1sBt~U{1OTqESoJ%AUmaeKx^~CadfKZ8i!K$5#eh9->d^>)b4Hk z1OzZjrR8VVRY(SgIe;ITSylRr(QJyA&IOVB)^PCzoG^wPx6pra5Lpioh-w{{WRA1pX8m9YT{%u8)h7X_;wwh6$?M!E;ISJY;^{81%VAv$`+{OV?5|rp)F@e9!8q43 z*;Z--7_fRY!)n@+;;16XHEFxrKjL9+sQkt}P~G?lB6aL51>7_ACAy3iiyqTF##vGG zQqTIa$H~HZjuasd(qkg;l<4jHf`Q%Y9s!8tDLX_d4DR+EKp1;JWBNcnWcx|X238zJ znW|+a2{h`sJl#hEZK9sY4PR8oHvJ?%nLiK(sIy?TcHVoIojal&HfIn^f?g%)b?ry2 zQkAbvoh0y&m^xZy2y=60tEC??mX*b3aq2HZKM~SGG%^u!K=MJvAzvb1?2YZ|r|#h) zxWw{96z&2sToV5P8pHw)^l-q^7CsQNO>o2+Mi-;948e$;lf}vPEYM$qt|rw*uvi#{ zgar%10f#DxC4MCXg-2GxS6PM=IH!RG7Qhly3?tkwUZ^_SqG(r4sklI0?8Ig4okx(b zMwkNgxd>(89#qixs7KG3CDCx}bEz&Vf+dC#hlm9V-VcM3*SERKtQn@{TQ8U4lDv(s1(ue zmdcg-1f<0Z@5{t0Y1IO?1;r^u$yABD8lBpQ-DQ%c$cuoNZd3(9+qh@%$sV`?j>YV!V>aq0daY1%GU7PE*Mbqvt-h@1v_Ggz9gn1gAB@o~~&3+;cXwh?G`7+Ru^ z>FQr(x@`WC34oP?-$-GRGDLNv(PQXku*GcG#I(*QX^MV7A|KSMUHVYy@fN;56R0^t z@Zk==CifK7y;K9B4Sc(q6B8nT?aH2KD=!T7NeyRY&+{_p6?JneGc8*Wz9wHouwjdn z3OQyLUf(0s&6w;BOzBn8L)X$2Qu_kq>WonAI%RytsNGAR23+_?XtiV7R8t=j4GOqh zBLsYm;DIoNVXVQf*nVz^T<>MYK_;VV-38l34@3fkna+RM;oh8GXY}* 
z(fj@#N3fecVSw0?4?<922Q_SJ*-YxEFj4dQY#aH7*#{IA?Epesj)eeLAupUC#I-$P ziyO9C@m;pf#gGaGv)hu(9dJ{av@igfv?ISYQvwS_cGe|1HT;D~Ht`jKxDqiUX=)HYJu4t;uM zGlYDi6V;!qm>VQ7c^|ITaadTIG+cS>o@H4<(nZ_7WyBhc3c1OSzw!!blF2X34J}fV z^$`PN7TSNK<^W7q)z9t@qSxv)N5M>(BAvK7IhGs)5q`5temF@@2;-mnp`j>K%kCiP zV!S{32xab_wF053C|oD*P=Io{CdkEHY^e5tn+7VY_BR>KT3X@FlQnf#aPR)}3o^ou zC9mXwd`u5niT(&w+N+n$$^aa)puC*GVjRY0%5VXKY8W~}y3WGGbWDXPGjW0thhksc z5He=Ef+2${a~%%BG>5Ye&*C@PGiI<|5Bv2n;_Ga;c-_Y)fU_mB0}$H7pdz7w$5uHb zp<%$PR4q4&)SHRnO17p$okB^ew*y|e(MUvrZIxXb>F-XvtdRBddJ1a zw)!AwUW4U^I|?5$Yp54oS0u3#vrHAtqtWwmA(l2=vG-D?&u-6L)lAiQd@0qC|2<0+1wb20LL{>GfbNgZvggU&1nJGF;)zTmlUr zRy;>LPTHBSrR)v?>O6$Fx~0GYMLdS*+qiAc7h)y99=BwBqW zrVbxS*CO#QdJb;Mm&;CSeYyg&! z3v?XH4G?K?y^tUvXniBr!Yu|f9Fw<0nQvC#0nrfl@F=;z34&b=n~57EctE2LB?Pfy z+9nAbZim#<5*tPf4;9R`bi(bVtM9pn5GKuL%+Hlx$m?8*yr_ZEdNfQDRiFCGL&mPWN*mDPSF%j%r`Kd8oNGB6)uq5gkT=T4F! zFk&wvbI&8%0uUO0H>9`6kFGeKO+9~Iu8|3-d%qwKbMj~dEVi1rHXt*3P6n>)K`{1IL(p>kmFg8)ZB*30!xqi>NOQyIsK8!0ARrGfO zeik4Gp%U?C&12W%Q}y&(3ANoHk~3Si_fO&smLG6hHku5+C3-=ls3X7-xDFf62;v%I z>YBKT5K0Wo=x=NHl?LqO<}UGcslJ2ERTXV+Cu2ZC>L9zoV>a^{BExB-DXFBTAU)$7 zW}L&~_Qwf!rTdLjfP3HV!mj6d(GmrZGZPP1E4K(-%7j(IbsCibA9=A(i{bAqsQwAO zLyFkI%m)uiU?iy#sYn(!@#>pv3_M(T6d;F$hy5EUzDT1M=?aewORreI>tKue#{uHf@-{8{%&z4 zb+1Q%Pl!cN$eyo zLg%{W|{Qauks1UP_fWeeit`?kN?@YtjHRt7{ci%4wjM zp+d;5kKCxGmCUXHA;qfZwk437h&Fe>cp;wX75RYRFcX}_3(g5r(U}_LgJ^ZP{fI}^ zqptz&iGp#NRV1}HEEx4Vl{1y|efgPhWLR48^E2m1Olbk7SD9QPhi^d|GP8^|Z!E2V zx6}|b5wHR&4O-P_RGNxcjjG@rnX+1-BSjPoS2KqZrMvRMv6*XPvO@Jrw6MqK6=$^B zk}*;_?D=!5_L))$@#a#nW6z&?PizT82vP`~#JkZ~o?g$y!*jja9!^hrb*%5yi=p=k za4ki?At9WZ^B)cWL?ZDrDunFj zc7Wt6ZARP*h`fe{p4*2Buw9;1KBJTsEM}U=QtxR```#IMCnVvXxoz)KdZR;e!gMJ z!!FihU$h#6ZBOeCA#F>9+BQbCm-^TrZ@7_l5Jovv?i;8{^3U00)(CPx)AuYBm{HUR!pD0yu^1-Ik9dy?jG3O# zPZ6b>M<)pKOWfzUl-8+^pY|Y7c>w&uhEJy7P}WR#vGEbDRA2t|lm&|;3=dfJfTLEK zo4BycP{`q_R*of7kd=+{jkO)@G@x2)_Tk~?9v@8RRi~_AcI%@B$6yq0rY=*~3zeox 
zrbW>dDF?`hNYJO&AE<E%qpzvEQ^+zxI0zwlYxyPh+Jf_3qaFs*k602T5mS1&V^1#^Pm!khfRcjXE@W51lO!Yr3LFz*SRh(0y@-3)*Uq2 z=sFttnHB|KqXJ>2fo{(TS%rg_WN**ZLz&53lwn~l_sm-g&k)); zU^z#UC;L^=`-uC9AUMsLjxN^g%Clt)n--I0JjY?6^CO7&f-sS6nHNG0CA#V zGL8QL@#E$mw(2y2LoBcwjdPg1PDf05n_)BX;gs0oS=>)@{{Ui&I)e4*?U>XmeA2F% z!PKE1A;Y*B%*2qA6#L3w?iOa{);g9!?4QzJHhc%r%J(5xA^!kwBC)H(*$z%+kr%R2 z<#h<l6%$Erw0M-LgVM|-7~CeEbNZK|>G>fgyBY6ts0H6sdzDeqpJz`p%Pwo- z&+aaJf%}YdIAKZGIXz(4rx(O(dSj#u>RdX>S&q)?<%41a^gs3y7AhhI%gj(VJ|bA0 z>k@;MtAj(9GYSC;_K%N5N<4XqTFWY~$g->N^n}}_P7{~pex^vqk?GvUPNn_oCZtgt z#o1i0ps&2{{V$L;=9KLGQb=foFXUVlwJqwV18wlgkx8u9(jumMPrf_HWB^8 zzKT5g^?>zw3ZLvZtOHMlD>aJxXw_2bt%o zL4OldNEoz)Nr65pJBgY>8Oay>AwK{{ZAGJRL+AXh1n@1QyM?wXq$55P5FBMrAETo?=R~=%7lFX08$Xaj8^pB?2M`cIIi{}$!siFNr=?6vgPZ%5CFiG5S9@o?r63|0&OJ;W4D|fgXvJK2zFQ`erV(Mk6Te#9NmUh3W#48D7E_1N6#vjbJ`ERx> za)~fqn=liS9|R6x5RB#DN&A|7v&Id;yOb+f-cdicHy36m2 z%8j+|Wc~cZ0A=oj1C|_Im*+4P!0PEgP%6HRYu9n;{{W;)^I&l+2TO~1eM^ATNqo#2 zE9PpMJ3t7ksQ@EIej^5}xHZUoM$34X(M|OpdT)uyGT2n!&%_irW6u*{!K5lc4&W5f zU(m1AVe%!h@5d62Kf<*YMj_Q@^KJ?T0$EsY3AlS9?Cbm?B4)J-1`gc%K{OD`5IGnl z^fN=8LD3y6&RT$gN{(*-0PBhNru_E}<_*kIZ#5eiS(j63xQj56*yi7F2352^OA~BT zyV)af!4b+kT!-!wxv9z2;cluwSf3~N0b<9;U)o(`#OPH1PwErA^9M*~AgB%*j&=hU z5u0(MyE6)K>V1f!c2VxbCX3)qN-q|E(>IqD3_@tay@gqr-kDP*<&MKwWVu$7qSRve zg-hTP#s*VfgtHo1Eq4(Ns8$%u_<)UPFlt^OlqqVkSWFIFM7MYEDaGJ>uKm9(Ma>NQ zzy+b_+G7{36Sf%SJ9ipOhEe3>@dzzPZrI*qw15aREftpQA(n+W4;;$g_Vm5*s?7kLiog*}q!UoSuGZ#U9n_v#lMOvKE3#YgcR*0DBa-owwd zF^ypngfBbqi0?rN;XeH?d$Kg4brlQRPHjkf)+Q4!p})Zcd1Ujk`D6Dq9FuXgm+}fOv@6wH_+`NQDt>k238LqXF&w zIjCOY)A%N9nCsOl`XA)bzu+!Q_gDIq1LK(NLQ?p&#k}J&Q8Ek^{{T_Tz4b7v>G_Fe z_}hnH=fO5xH`HlHgj`I`FH^;MmJk-zK?nfgScAMV5gKj(02zN1BHhG&RKG70SbH@T zd;Za}7L$sVt|$R3jaNE|&AcJD4@_h;=M`VB<*O3Lq4egWQDAsKh>tZGj8Q0kANW20 z0D{{EB4nzI+q0#KpJ&0DS)d1S@1khZf9F zSgfxw=)$bEO0Yp2U!hHY;zVs_E%O?mhvP6ZUg?f-ZJ8NGV~7!^xge-9uZ5)l0J@5! 
zkB1Nrk$&L=v>{nMOcONRQHV4@B7Wj+1UH}Y#YjyEh1FQ~SNfP$8TF}h0 zVnu>{a9#?Sat#sZ3 zy<7NzVOS+IT^}$4)HoWKWgr~z;$21m0N*l(?~eZf*k>kIW+o*8F&X9x)H+0JIj|SQtyhpsyGEi8}={oWhd0<1)c-5w6{^>e`Lm za>W6HnB~vXSlq!y8MarN)6-KQ^h02` z2p(2ggH?b(8Hs`v3VaC$&;W2tx|3=}cp*-?QD~|VISU`A;VFE4LY4so>;0xZIA8~+y z9Ah-R(FD7|eIrD>*FWqO9+O=2FoK)uHMK2VGY$p}#*uD49H>N)Egn*Tr1uEN^5QG6 z0rw57T72X8C1ME!6B?=QTnmc{F% z`}TnVtwnQ~ut42i?h8F(T^I8z9c-*9SLKT8*LBbPF>mkvlpBgTJ%l12u!I=Tzu}X7 z%!QL8zU9N8nNfJS-rMlXx=o|y@@Hhx_C&&tdl~%^v%v}wFW_AJT2>~+R3;wwG` z`J6$RDLlt~@e>fco?1uLNu6T;f2m79t0V0O`l`TjSb{uHX>efyT&wyaVOtw82fR5z z{GzHAX~ea8OG^wgklMY*#!R!3?OTr4=noLP9TmyhBO>B%%*OjfnyV% z_WDb(@{ClpzOmu8%|YC9Jat~z%togwtn-?HfRRpGb&YN);{Z?}vI2aSUvcXwy8|x} z%Mm$^(ow#lVeufOSJ<|r){&=3ROPbe?NKyQwb<$-xKp-4njqQ=z?dltiBmi$BG6;T zB^|M|+E(hYkK!@-m1xPQt6!;DR^SRmN#{fMgAr!M(U;;qSOn<j=9S zXS5`Ey2LarE+=BhqY-w?9l91)aGqc#k6;c5ur!pz5LS;$9BLC1c;1&Rgstyo(6)8_O{DaQnfYk#+YL zSZeB+%wX6T^(qHLsvVmp`2;D~b$%983CIPWfEVi`sMtf#;A2J&esF@_c<0qZg$il~D(gXN72+*$|hw;zDSd&EaMQ{($sVa-aW zrO&|t(lhbzf8NZ~6JB2cvc)%Gx`J{g$!A@dQ>%?1$mg|=(* z#tJg-3bdeYg#ZfBO76OaE0BRQ5TK|mDoX5!HSH~UN{d^e&u5#Mg}fw08BL>}stQOL zE+AMq5J=Rf!MzZuEG)gS#{tApfF^L_9*alZ6k43I1YaEy1--Ax@}9E(nGxtVzemsTT~7<21O1l}(=?Tth7yn9 zRk}~IMb08&TDfeht0|0aSJeedTWaw#uSB}A(<}j{3q%YT3f(#5XH(F)(7+G47-MnSk0=@=Tcst8NIZdy$`f(zbK; zpJ*-6F12tC+!{C`&t@BmV6Mn)M-WEa<`tKH!g6i?qZ+GS&Ta@wV9v7;R|Ueht5UKp z5)p)8GP#+Rv7oY*g@MY!%(H5A+$f>6E}zSm2MNCv;nD>><+7O zW@VW?#U*ZQ8&E~I?;{dj;8gI;@4j}Ykwjc802$grTUC`uM_--aG9c2NxY`PC?m~aH+ z7@0s}`RNY$XOaci6_X>|8cobWb)THZ2-Ex!?%pUpvtIa2;+D=T^* zM^K~>v|NdRcnlLsl*ScE`Ti)dZhpeDE%~U`uYGHbtw9_($gc^pp zwk#ETVJ_3FwjioD%WFX3IF6m=i`7(ZR`F3`EW{MUjaKEVS%E9nHcqT+Ji3C59{o#L z+-d4G#UT7O4>@MQz7da4+965-Pyne~@e)vaL;WF{_M`K1^C{^*37JSxcB%Bj^(W4M zaR7IraAQ|$zNHt7Rez{;TvR ziUb`nyTf8S$;@^`nNaimO5I-LaRjll03C~OT#=ajD0BHG@?&VL?jy*GT+7-~u%{7< zH)k^l1c9^~JBTprxEo8PRu6Tvb=O9&-SsUm4cl;#P&Uj-cs)awHegou1q821{zaw4 zYpo9OJ|m3Q#qp4t1E?F7cOF!@NSd*LG$#5UY8Ld?^l*MY0Q3pEUPJnW2Sr8kU#JiP;HdjTW!fCTEy$EU 
zzEFAgVsLZGh_ESY*~5o1z|Kn+)>2vkH?f%1!f{M7nM>$_Lgk)|ymJN33&MVc7lrQt z<=`sF@i@$2e{d~6QD3P1MH>r#?ag)clMoFL zqlnT^a60#jYhhlJ%GS{4APOxQ&xnr+qtw6L$XD3J$qL2}5v0L+gNYo9zL=n(8~(Je z*h@`%THlGLO&*Jy+(gjENsmLhal7{e_qo6ieE$Ft5<&w#CwVnJP<+qBa4VGMhh6b^qHHWvcG6*D!DD~`j=e11gMd81l+(-DY0XZ z#HQSWEeA;)_K0nrPGIY9p7E#QhdE-v1s z1JOkF_Fi`#jalt6Uanz|qC68XmHtGr1_Nq{IcjL3U`T+xf zEb7c4s}f@gUzM0RKZqk6UFbKNLV74)%*2D*o>H$d>7*$zq=7GpoU3eHOfA~1Fu)@l z-LLrt3)!hf5Z|Mr8vFtT*JSt1R^JlT2(DJ)SA-VislDO?cA5457zd*-O(WVa`h{l^ zIvDg;*vv2uf#Nc%*wsLlW}f1j6;~4hj7~60h)M+=C_m#?JGg6&L2BFR4DTO@4-&u$-PD^n#8GX=Dw?1pai_ut$gIJc|q%zm~o{{V6@ zX%XR=3Ve~wL)#E`p(8`#Ln%!-N>ZSEXRKBg^nXS@kOlKiGScI}{6nt&NUF}GQ(2g& zP@OBMPx2Ff=nxKygas`t`h!aD?)u;LDuE_ZXl5?CX@-MDUFFvG{0j542SXwCC%S=6RZqYfCfx{+oDGde_-2jwtuhYL(0okK)Y z4bqMUEu9I&+96Am%&KaqqX($23D}AenxOo^l)57J%n&GnX>eU)Cli3*5rWA&M0QP{ zm%tK$Onsd}Wqf1vG=_^MDuWz@1cQacM}$(vf>4kwf-4Pf@7@ZgoXfw1{`iTo1@59j z>MMhY7F!R@07;5sAO#E$tyB1mVE2lQH^jXGO2YLh?R;Fc2~BV9A%a-pfw_6yY7H%U zm79YRbQXuqQbk%?sd_3~mi}N;A5F&bM+nrmb8d;32r0#T^D0ER2bSND%yNW8-~d{) z!Q8*%*>?NRHT+$Srtrjv(krr4)Ihki5}B+P%8^>L3BT%W)^U+cAwXD@6w!M#!I*Nu zIJsGj0ygvr6zs&!ULYA-vK5<3mU{OTP6Z5+DXy)nm!k44+kfKa6DXj}!GND+z|ld& z9TM?y?;A|Z+*T>n33mXyWfoqWCda%rw1e+5&00W8#9H)O4%G+t{7a+@Jb&J?cnUzw zjyi%c!2&|G^-7oO2u2;{vH~}`bXb@ELM`WxeGnN$(wc}=CuzmAlAK(iyE!nKpRbB6 zS-d|Iyo?K#u05tf+;*gGpbpE=?oka^MDDK7xY0qsrDtDK{Hl2~9)luN@&*;L{AhAGUe{3<4BYQ`wiWU`KZVfV*5Cp=8E zP^J5w&0a4Nudm1OS!hBcdy_mu>-#9l=MYo*AsMd|`X{_Nr}roxyuVRit?WPCqc9Ko zrZFb}0OJ+=NT6L|d!rQ!0K~VA5C#$Uo`>9?&xu3K zCV#YZRv>w0z-b0JAv7pu&SUXxw=44v^C~0$mKpPk%Mx4`fZGTWyyPWpsZ4V3eMc z*X#W=9fen$MN2m_M=@q<_qSad6#FtJTDu{}^;LImwSaVq0n-cQUP!hypVF z1yT73S)EEHhY?p&xB;Ok(lkP+xkR+7nM0V=8I}uz)w#K>%b9mLXVlUgPf)+<*)@l} zedn!}glQ=4E-cK+L0KUtN`fZhtU%4*>Jc6v$cbXHh{`JHDghQnP4%a!0WHIcj)%D{ zuIRQw-9=qPFEN;0cGLiv;%BHa3L39LKjv(;0yWJ4YZ9BOA5#2T*H?fET(X IHu#_a* Date: Fri, 22 Aug 2025 06:49:17 +0000 Subject: [PATCH 08/18] Refactor comments and clean code --- models/experimental/mistral_24b/tt/vision_attention.py | 2 +- 
models/experimental/mistral_24b/tt/vision_mlp.py | 1 - models/experimental/mistral_24b/tt/vision_mmp.py | 4 ++++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/experimental/mistral_24b/tt/vision_attention.py index 57d9b1022da4..03c6f755dfc9 100644 --- a/models/experimental/mistral_24b/tt/vision_attention.py +++ b/models/experimental/mistral_24b/tt/vision_attention.py @@ -4,8 +4,8 @@ """ This is the modified version of the vision_attention for the Mistral-Small-3.1-24B-Instruct-2503 model. We introduced the `apply_rotary_pos_emb_vision_tt` function to llama_image_attention to be compatible with the Mistral-Small-3.1-24B-Instruct-2503 model. -""" +""" import torch import ttnn diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/experimental/mistral_24b/tt/vision_mlp.py index 30c84ea94f03..46bf41db2e6e 100644 --- a/models/experimental/mistral_24b/tt/vision_mlp.py +++ b/models/experimental/mistral_24b/tt/vision_mlp.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 - """ This is the modified version of the FeedForward for the Mistral-Small-3.1-24B-Instruct-2503 model. This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/experimental/mistral_24b/tt/vision_mmp.py index 6e88dbf65680..101d509d5810 100644 --- a/models/experimental/mistral_24b/tt/vision_mmp.py +++ b/models/experimental/mistral_24b/tt/vision_mmp.py @@ -11,6 +11,10 @@ import ttnn from ttnn import ConcatMeshToTensor +""" +This file implements the Vision pixtral image submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
+""" + class TTMistral3PatchMerger(LightweightModule): def __init__( From c863c6e818cdbd1a0bd8102a214440a02f58ce14 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Fri, 22 Aug 2025 13:56:15 +0000 Subject: [PATCH 09/18] Address comments and refactor comments --- models/experimental/mistral_24b/tt/vision_mlp.py | 1 + models/experimental/mistral_24b/tt/vision_mmp.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/experimental/mistral_24b/tt/vision_mlp.py index 46bf41db2e6e..30c84ea94f03 100644 --- a/models/experimental/mistral_24b/tt/vision_mlp.py +++ b/models/experimental/mistral_24b/tt/vision_mlp.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 + """ This is the modified version of the FeedForward for the Mistral-Small-3.1-24B-Instruct-2503 model. This file implements the Vision FeedForward submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/experimental/mistral_24b/tt/vision_mmp.py index 101d509d5810..6e88dbf65680 100644 --- a/models/experimental/mistral_24b/tt/vision_mmp.py +++ b/models/experimental/mistral_24b/tt/vision_mmp.py @@ -11,10 +11,6 @@ import ttnn from ttnn import ConcatMeshToTensor -""" -This file implements the Vision pixtral image submodule specific for the Mistral-Small-3.1-24B-Instruct-2503 model. 
-""" - class TTMistral3PatchMerger(LightweightModule): def __init__( From a82d10c99fa6e4ecd882beefc505be25ab26d392 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Thu, 3 Jul 2025 12:38:52 +0000 Subject: [PATCH 10/18] Add Support for mistralai/Mistral-Small-3.1-24B-Instruct-2503 model --- models/tt_transformers/tt/model_config.py | 2 ++ .../pixtral_transformer_inputs/demo_small.jpg | Bin 0 -> 8554 bytes .../pixtral_transformer_inputs/people.jpg | Bin 0 -> 49606 bytes 3 files changed, 2 insertions(+) create mode 100644 real_inputs/pixtral_transformer_inputs/demo_small.jpg create mode 100644 real_inputs/pixtral_transformer_inputs/people.jpg diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index bb94d3e7c251..204b228826ab 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1396,6 +1396,8 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): def _get_text_prefix(self): if self.is_vision(): + if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: + return "language_model." return "text_model." 
else: return "" diff --git a/real_inputs/pixtral_transformer_inputs/demo_small.jpg b/real_inputs/pixtral_transformer_inputs/demo_small.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f51ba21be8d4cbeb5faceca2b88150b50926abf2 GIT binary patch literal 8554 zcmb7pWl$Z>*6!f$?zVA)hTv|&-Q7Zvjk^beLvVL@*WfO}0&E~ig1ZC=5S(wn=hUrp zeth><-PPS|o>}Xe)lywO-PNx@UN-@31zCAn00ssIcn?j$>o&{}c_}GVH4Rl+c_kUB z0{~#g0jR)p0sz?2-Bm+Ql3GvSfEsBNfPtoeiMgel^S{yms-djgrGIl5-u$nO|F;pv z%G%8mO1Xd*kSnxuXq)h%7~kgKnCTyE@oy~r5B73*c87A*{=u$V8d6Yf1;tD@|Aj67 z3tKw7{=<)ja)iMS9{<$!kNnd+bZbX#P3RLDTF3!6Km(8iB>(9@^c^ZrMF7Bm0|0Od z|FKzQ0zi8h01&VK#|A0@0Ibgd&^G%Y+keKy+1%CqU*X`P8ur5n0Jtg#01SNq!21pW zsD}UYLDm1IZPZW~Ika6)&}0KR0M-CCAP+bKmH;~xaRZzH7x3w=8W{VB1qk$_Ip z!gW9@WbT3x#3QY-W!dyL=@J@(35 zbI{GsOm`$M(iwA1DYg=?Z2aIdpS08q>17c{MX6P}$DjOBePhr-rfMhAJZsm>l0o2n zw1=sCQ9`VcA0>LJqiKDnEN=&_YhR}HS8wAxDLs&`G*hOs-6l>X=u&aX$EzT(=!mh+ zxoGtl^_J#Ak}c0PgLGI|)2KC%*wRcO11zv3hJH?V*0u4adZ5O_lIAfx zvqVe;#HF$MY)-Ao)x_0IsN2jJb_Dx7mBMne2G~)~?qh$Tmc*fRL3hr)0gRjvJn3kt z9h;rGcol*O)1<8rF<72+rxlEs65Xw{DkMGAXnV6}7;l2AdDg0;^JbXqY5 zzjY$3uNYzQ6CerlgaN>7*L7^uZf2c@aNw+F3GZ=m(pTNljRkB|V*m_fF$yzM$xe86r=kcUB!$FeQLg1d{Lak}CDprK zB@WK@<#|d@j1kz`C}oWGciZZs#D>)K+n9d$==s1DvFIRMO7xpN0ib>-PBL+{P*Kx} zzBkZ;+$R%j9aVG}QPo8u4A-_;Y%%nGImMstLC@uY@~*PzF!^JaLz|Z-GCTN`)_Jb} zQh_A5fv9q37~N~ykLJx4BE&;YQFe0$?3O3*tW2S(nAsF!FML~OGA2?rdK*f?oZeDn z&N9wb1?hAaK|o^9le23TDHBq=GKg3BmF7tL>FEPA#EM#GeO&XzSJbbqOfvPr~47iBDF>zE$#B|g^4 zDQwEAceRs+`roiv>s3aJ$B@^@UIFD@Q_GIB4Qa^PO#geNM#PRlM%vCgVrlsyCL%;L z>|3E|f(~VpCLFQKp!sZ1h}m&doroM~=}`TK!HwER`mSw=WbkM@6)LN{1K+oN6GAJv zEqLGPk)6tI=(zOVyU46w?zPMiD!So_hS!rVB9<*g=%$lrh<23pA&;?KgC1-3PFqjN z8e57*!^^<)C6_QbPEXRsw36^DA4jbrV{y5ShTuvCd>FVkL z)5&MFNq;oLbO2CBoqZxo{fbAvaJO6x4V^hfotxv0l@H7`98g-t4&b=`Ahd?RC8^JN zsE@q3G>w*ux(gqrFjDkRHX1zip%)k5-9yKM{wTYD0`R#fU+?VF6w*5@vU$Iv-*~;a z?GtgVX|^#$0@E92U43&q5<9(=(0NKT$FC~+H9(Rx{ExavYn1m z5sXrod@`9zJ;UQ1MpT1!`H44+;T6rb9b>E3-aDWm6+7LbJ`!P}OPTq#h?~g|J)?-( 
z&cZRDPAh)wk5yUuEla%u5!*sn0oyIwH}(GK_mR$BwF7^j-sI8?!AphFk0~0iNYVcl z3~l6X_Q$$_q&&dexpFgke#aCgq-~{DYeRg?xK!-2fWM%KXXB!(STq%U)r!K9$$ewA zGoJxMM#Ht@$~kAUr85He*!iLOhMB7;uNjZZYdHI98ax)Tsj3o-R0(eBxj&s?jBJw& zTg4klsmwM~#qkTk*v03mZc%G_3;`-50V4`V=A)=Kd*bqrSU()7WXoHRWUAQp(91s{ zt$t2dHSaxo(5gugY%^hdEM+EnKj8E=%nF4q#h`$4y1}*ZMT;*%{USSvRlD1;8a)^T z%f*~5Cl#2B(l#@RGbBN}vpL64Z!=#s{Ebg$pVjR&O)xL$rN8sCgMZVGXD2h3jnOi> zUX1fxJ&hps*X}IPI9mW*tNmvxINn7`zIb&!kZD15`%56`Gv9CbBdlt%mM|_5yIMm2Ml4g%p;)z!v@P_<}?&C82Jt$tL+RoILk4%n}<0waSLq}d0 z=PLYEqUb)GW|Kf!n9+`m##q$;CFO|}hJC;XHY+v-$4y$^a(*`Z)gh zs5_3KU${vpQML7@ZA}o$ErV57D~mWcW`5ugMz~KP{rV0@LrH0N3m0*Jvx7VcK;l*y zsH}*NNA}_K4E`W56CmpSu%>+Bia~BTMseF1L<+wzG(z>8qHzPAj%uF(>DuFg`GTQz z3p3MeEDJxK^t^o5Yk$OrO%xdW$eAlxbXTVHo zS}{6J-x=(-mmU;ZDWa{7WNuPn_&!P-%kI~(EBI=)-R-u-{cUVfVMY=PyK~2E!qJ04 zho8ueHfB{-EQ}ePa(w1-H_>&?1BbJ2s4$BC$U0T!ag;YNQ*B(!0;1j!Yr9~cT7O9B z*AtVSO;a(o&1u+UdGxE<^0lcep70NENRw21Zw7mFy+|3dHv@jbp#=>5l>oSZ&mnM+q2jmD0?#R|+*( zL(*2ceNmf~DTB@+uYi2PhiFbwRvXIeIYIT$RZnT^R9lFl*(yPjBJ zYfOL_GQ)(GQY=_t14!v|(6u~9BfGY1aqp$>Ca74+=;@K)FcwdJpqdYEkA z&T3B^erW0!3cY#@N|Rf-E^jNYg4R%fIx2NxEz+eU4`yU1wprJXtC=l)V$K9(3TjFp zDoL{*)1SZd&VTqI^?YPve7G^`_6*PK^V{e_-I)<<7R0ELtj1NrV2B*sm^QnD(ZJYe z|G6JFvTt{g3^5|+5?j2zrG39*sb5~NQVy6xxorDNFz}BpJ zbXY`}#w&z3vOdq&eFDr07zB9~KwzH{ZV zyDfc9O}OhMbv`SS%_Puyj+zC_iMO}jhc}Z_cgzqB+ALu z9uwY8%}og*TdZL zCI9n0l9)SDBstPxFU`^5WTg9=vLycy@&^qKJwY|+5TZVr*EhGgHOB@NR@e?#Srz# z4L0sKq2$alq!QsinVoj08&SYSuL;P2Q|iO(*9!;^a2pR2 zyQz})*DabV)w$87Ae0;XTiaGxK27PriGNO3z#Bibk7`SA8sxNol?OTI-Bl;U+2r^pT88CYq`^Qm-NOkSS!F+M7`xu>i(y~E2;0BsQ&H;csePClB!{!@?p_tgNzLibWUAv8h{WaG|Hk_Y?23=59y}jv^(^XHXn)Q>;(wZUIY)kw zyO_nLV&8rUB2N=j(Q9Cvp^jjf}0Hs-`x7)|)*rYL%f^ z#u-rSUs&4tzK@k?%qyI76E~)EDplIMLn}O()zp=WeSg_*IgQ61YZF)yof4thPM4}9 zPxcB>@uED1X`6^Bap!rb@JW#+h&_Up0~FPP`1*UZ8?4LFJJkLy7oK6!6COg=12@@JfsC|< zh`nPxF~;&oQeS211u|b>XBv`|(`p;qPu}lOnuUzT;>cs`5R9ex5<8m|75x0|ydRdY zK&yDqevKTnqV5_MnsktY{mE)9|H0>qd$5ufH~Q++Q&0(fvmD=P6vSVsIZ;uHJ)7fN 
zX^wAGf+|;Ha_eRa#M_-DD;P>%wHa_vz|a{6-Cea~yckG7Tp3Cn%aA)r`8yM47fz!&K1Lqg2Y3!nSXVZB`k2bub>1XbhZ0ac3=QVLt<$Dk{RRpu?{)W6X z8T6F8i8f~n$C`9uoHtRlXqB0QFf>Rv(-OZvrsH(r2J?L1g}W!k#pH}vq1U-ur6o(L z%+(4Fc3Epg`i(C$i!EOz*(`6500GZilgv&;7)a^s4Q)&@Ql)sGTLox=_mgo4hnuM! z1I;DgLq|sSy@+F0xz(cC&k^vOmUci8{9BKqBFBoQJ&;zo>abX zf6x{`@{o>x>|$!Go4mSC!|2E-Ak-U@9_1J>XYRpgQhTSwup)TLjgpjJ*gF!yz*k> z5f7Wnp9aR2e<(keJK9?9GiG*i9WBNBv!4n~{<@gZy@%tnrZf?#`b57>LQJ-uBUyav zhE9H7u!Ih{7E^ZA-a8-z2O=WXbL<^4Z2R<~aFg*SAt`emLtyGLjkT)N9 zr?pJ~s?=d}J4-wrCQe-@&2!L@S=O?@U!fknKt9eSkB{9pr@l6>mbkQu_Nb337f-Hk^fotY^vS z;j)gSJ1|P9>sigXNG?EZK}MJD^t{Ah7R}?yiA9}p9F@m*ki4%I`*2>0(bBekd$yVG zLra8?Gy4a9^ZAC%y4udy;l(2ia*Mn9#ZUL+8j@qNFY|c>OX>nI3Pn1Iz8RYtNj>@( zrC+vUweo{J9M)_pC!$^qe37ZMqpNaTYt1y-5JS1|oY-pIIeE4J06#Bp>0Ri^DnU(y zzWv{Z!IN)lef)6uKF!|(mJth8v5#QT$U-_q4Bov03ST;RD7ZkVQM`kBHPx|!{1p~o zB*LcPy;&Wa8n*4nMd?+J{SAkGA+kbejxWdD63lOg2pEZva&8DiSAZhukHl4Wjx zZs&>unUy^3@Xrigloc$t+=KQWhTt~-I`mak+S|H^vCRJ(p3K=)DVXCAt&oHvrKVoc zZkj?AO!Ctbc?C#}j&zGeaQ6*0Q6@Uu5{hj);3*bc2C9Kl`goGO zXx{gaatit`N;nKjpT?pv79Pv!s8H;JhcN~dw%K_W?{dUDO( zv@~W+u*q_x&w^b(tvU5(C*kTl%OZ$O6SUaj?YQN+^+LoEW8SiYI@VK@I(_`h=3uK+gO*7%`FWilFWVFs+{`of$%NG$~;a^|Vre}PRk z!#ILVf8Jk{@p`>y{mpmN*^on1R6;&h7ZFKJxtRZb;c zkbx+J;_lVTV5dU~v`yCq)bNsfjG6-0WU=Qc#9|z05~Ty9XnC4P8Vtj;lw*HwAzr3` zUsB-PoQOF>W4C(Z^r&N4_!NDn60SJMXJqYkKl4`|eD>%vmR%C^+5I%+qxG`R z3ch-t!>o~Ibi1_GFPYA_>9;eW=#Oh2V(Y$NK7&pAFN%*uiZ3yvceK8|5P<8?jSD(3#T}44`BQ85l(r#8`G1|En$*h5sFzBh_a4#PERcUKCfpl z*Jo$&Ow7;lS5lFGn&xENyjqO9Gc+3x<}n$#p@ovC@@@CDS;!;uayZh*zck<`!Tt_ClgVxTZB_93_dgFV&#NIJ7_oG;G+dzts zM#`SIS(INZMp(=`@>MYfu5&S!R0RG zy~+_7dd|-suHW8d>C;nV#!d}W_N3_CrH5%s`v_0VJJuzM*OD??V2#X9NnDThXf2^{ zAnc#2{tQ(uFOL@eBQFZdcTE(Y7wghsSF-_?c4y3bYJE%j8U zAaq%3;vY{%2JPa?wK{;4273}UxMEqbk@$QaWI?0v>A$={Ah~n`px?>)gSi&Jw9GMK zHt6c?48EPmKrkA136FimN9BL^yaF4gu%74kuB;&cAI3{W<^iD{2;^gfmh!?0 zD9|Dn8_8gS+wEM)efq6HNC=UW1~ai+AIFBSHGA|QICI%CDtr88o`#hRq*CtTV2K}3 zM@>@)=Bf9K!Jj+FW1OuVW$u%uOK@OE)ka3w$L+m@+fLERz?Yc&pXF>zsFT0`o{!+k 
zkSUWPDwnb6nIPZldI`i>OA`dIy|HwHMCtG|em>=bb|1 z`0c%##WW%Qs(9>U*xu1AP^_4=cKn6$vm0Z_4D|kmmsHwv;AhXn9CoK+rJS6GZ|F;^ zE=#{bOQ*+*R#6Z!Bc0f4zM(CJ{}iSyfGZ-Z#W7f#N8Qh%rg~en3F&Tym1Y6%-jg7$ zy=8c0o%A4cb8sj}h#1HTJyHSXndp}|P(8`uxled3(>Z9lv4&eij~&dx+sN)RW8Y2< zg}-_3-4=nL8vGlEMFYvN>E?q!``Cw|9zZn9w@?)weW?_R-f7i#iV@oz2+CoEw+58U z;SuW9)|JTUh$sX*qg2PEcsKi%+NWJz(j-S-7yZak8L1;s)%Y?dK#`KC*{8Js^{p<< z3lZGZn!pRhP(>U_k~m7%OInbl9}4g7t!5F7qN8qI_4VZ_dP2-03Mh;75>H~Tt;bP& zUk8VHtUijOL$Xe;Ej5-;;?Vu(`_7l8)vX(_B#BJU;YR1D6qu>u5EO)ztDV3*f}KFT zou6(Y90n(ni;nJJf4DY&njweF;l4KigP66Q_imYgG4v#B@mJL$v1ED@;nl>bkE=7! zryO4_h25PyZ`4Azcy=3FH)EkL?<_ZRFou^hzAm;)?g9Qu^D$|R+?z83 zk%>IRo)iLkZQHLFnTqNDfdT6HNrLjL#T83X9Ve^j*`A7!na3z|$vQRcM0L?G8wRd+ zaZY9*O)R#s48#@L!BlTNUwx;GoPyAPuavUCR9Ln2-fNJaG4?Edo_T_6>X2ti{rR7ZU;#g{ zQh<4yiCG6+yOGA~k&@YcJ_$d+*zf1)G!`QUP~0x`>Nv;hVQFW@ztuD*8;%bDsRIl5 z0GV$@1dOx63^C7lXT-BFc9GIE;M*Vk>zE}mJ>_*-c83S5CGKJing zrluss)-+f!q=ZON9A+O_+AxEXK*YbAnj)VSuu3strrT5w0uHnBvYB~$N1K`|4ht5+ zZJB-Lk6B8gWPeAjtz`)7JX@SjULfD{?=7!tYKkN^K3MeX!-po$M_KqD+(jI+va&>H zmlHYj^g`Hb{F}d?>FaCps`K!VblZGAxm(b2kdFU)q_3|crT}vTFBOb!X6Y*~_6Pi1 zUtj#KlF*y=parMIGM~t{82GiyMM(eT)UljJz+J8EhwUV&0P(Sp9$ GmHz>irf1Ip literal 0 HcmV?d00001 diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16dad8dcbf18374fbb920fa4ad944e7e9aef8b89 GIT binary patch literal 49606 zcmdpdWmp{B)+P{w1%g9xch}$q4>T@KH*QVSH0~ZD2@b&>0t9c|JrLa8Aq4l}1Omx) za?U;9z2AKEW1jgp`+=@qYp=bm-o2`7Rn@QgUu!4?YRan0D5%KBjmnLJ^6M)qtCA1Q z76nCBl@kRE1qI~+%41YClx1X35cyAu`Va*bxqXcMq58W|_M3og+RknsZk8|)dS_dD zcTa0;h`YP3rvp9I&DqoCH_gh^9p;22(z`&MT^u0vmLBwu&NdJ?C*+_l%o;-P;Oq>C zIYEC9SX(+m+$?$N|6Qh?rHhO2KT;eaPV|;eHuUb69x!)X7{unE5pU<;W9}Zv(Ld7M z=`HC!AeN52^ctQH9x#Z9ofFLZzs&mE8h1yS!yk5}%>UgGZlsUD-Tc?~*H;V;I}Z;R z2|hk2cV0^yXDbM=wX-9ikEIKr053luij=I6i={Oj;z4f(fx?`m8Nao3Fw(ct31t5_ znU7Hh2V2U`##TaCQTcBk~^ndGBfp~cMsybRiA?_-w 
z%KHCEK}zL?{vLz<8T%7(D9qO5KlIyJ|HJR%>E`f9tBo}u!~udtg@-$mR{*K%ujsls zc)I_?jr7ga9rBy>XO4r7Egw?h-?IOn@Y@e1A7qIB4)_1)7m|Tbke^?ShhLCKQ2ZZd zNMX7#sGY}uX#C%iT! zq5^^<0%F3y{rp?0gdxOA&%@H=PZa*L-yI3w-*|z!dpKLe|ET*DQU?gs(t*zbxgwq1 zAv~5)H^`q!vYsA}=8i~9JRPNx__9PUB34UB7s=mbcV|yGYlt+|*%|5p`G<;J#4eUF zsFU>XFd>fx(kEneB@`_^EHz+G$i?_pt zoNe7Ko&2QuenTBe`K$AXhwtBKBKtP~wJ`n?Am`xmFG;_lAmL#7+j<`!8;GqXGN~}i z{ClGM|1J6bw}}0T#qWsydz5gb{v2ENKdAX1;{PYVxLCTmL-c)JAku8QkiRg)Ci9;Y z4RqCIelM>76FC3C^dE3S9sdjLJpb(^!v8oh4Hd0cWXT!GDlIF>%PR*oO5m^*|sq zYKp2zlFNGi$ubCzqg~AD!?E5IMxO@3rT5E`Qndo;7seQdN~t2HhSJAjF~{MPl`O~g zA=ncd3b-q?9Nh~0!r3A(e!^ZtA~JY4{r z;2{q5tb%d*%k=>(0Fml`a9<2_MHtfHoQS+Gdo+BES;MB&LQ6%~2x~Ijk>0e#D0zB^ z!4f&4LKdsf=+#{q2GN7HO?S8Ibq^s9lY~jCm>k^gHB5qo-! zr0YuZz7Xze)%C*H58rV-FH`k^76|r7g>y0p!=F@QXi3|Zk`eVd?1WN`*5ymT!-ho8 z4^Al)j>h$F8_hG;!?1UB{klEV_8)gXH2v&1z(fAlpp@P9rk#>izudvZA-*sCJ@0cR zr(71JAjBTe@-^7<$w1xa!02r6 zTJGa9%C7J8heNAenH|C!s%jK|eU7%7Z@42g$T}A~cNCz{yd`>PEd&Z#W|Mr1m7X)3 zn$=^|j28LPbLSMZ%bSfE#$EiZ_E5H~7qHkG-Kj@d4b3yjO2k=Qz;TI&umB0=ChOK# zqaX4g_iGeA7A?=^YUmrk$}{RLw;irt5qDnTFi-;`$}-42l|2ztX8uombN3hw^y?CI zdUbv5;Tn!a=DMHU)$$c*84Jrwn2Okg>b!$P2jTculUwJjnIY9z-0n?t@3y8ZXP>Yn z39-oAeWi-JT~;I)@`Ot%H2s5U2e>S$yo+KI807)l_OXL?on&m@QeO6d!V^ zz3E4>51F=voJZ+HbM=hft6XG0}7o2pks6$N{uV6$0iXtT=(Ypb(L72>LggOcnd zz}amWAo6~i$My3VU_{7$+&P`26JSiEfzv;-lDJ3!ASZeL@IgY2ve6(s1z^C43!7M_ zP5ykq?M&9R(=hCf0=GAHLTXO;#{?!XOy?X{^|F7zUKp)zh+NHxu&wiSSTr+Cw%n*s ztB(AZalpFQ7d(=CK?nlS<_yb<_U1OV821p&vbU}(rcY2+&v75bNEUB8^%XysPU};t zi`(}P=(7sF<;dJP-&M2fj2x(Gx0xHFSkAVdv8=f za1rf-Crw0bUHi|EW_<54S2bE)rG4hWJej;sOH$c<&#K0oWpUsVv*PV08Bmi8!EWeI z9b|k=#XDp9z&xN^+&%G$_>^Fa<~)FRa-G{dZPCDz@?yxoHa=}#O zuZndbp*jgqmmGdJ_iAe;5EA*#n})ynY~meJ+GbSqO*4(}^19E}#KsqOYw*}AW_uu6 zRTpn|_Dor(+2@O$bf~~oGGZaRElVXiPa0+l^vIXSXzD^}ob2Xv#Ag;YgZ9f$B*shf zQf6%{lRPO$MyN0J0_t50cuA8NN5s7@7M?~)#x20f6CKaTwkEibJw?SYy#r+RCJ>9&HZK9|dQ$$U7R$B1uBNOGVFui`oo_T+R#!KzU)K*gUIM@OW6(q*? 
zc5tMl5iS>P%3mY^RK1q<92 zUbr-WHtpKRa+1Brc4WDrw8rCBaoWHOU+=QCdvTVr(w3b!5y18vXLy!P>ouPe>(b18 zVCWyR-}#GzxLx(4(k%#1d|LPNY8QT-5yCVXAA%@T zqXRpGnQ=@C1Zz4($|vGg0@h7%oerBv3%;-RjPhPeI@-+wRL8|_J+{;R7N%&#r>J(s zZncVw^>@m+hTD9tohfcszTW~`?dRKJt~m$PF;3EuOuSy<)_RGSv6>^hgA6&5Ec1$N zk^|TF7X|D>P@SjoI~QqNBap@l^__tswfd^OXawi&&J?RzjotW&VREihF=ChW`Oc*0 z&?CyT9x5Y$d0VQdnT^vsZ9HeQb5#348DD=xu%Dfwgvo6~yTTXdg3FqMX7~69BL{5C z@2U&&4ldTrGQRI0PGz3o-j~_Ixz(q#;`hcfRz;osR;)1ujIsh3MlZ%^?Yz#nmt7|> zz^nLXZoW@>W_!=S`{Ch&v*&mZVmNqped6V1-a@4_pYouKUx|8|P9De%Gl?gQ-`#RR zOTB5#58{Wdw%3-wov$9>P`uu^K6cM&Q*%y#R1qDyqw06}ZcfD4W2QIlU=YqWvfu6} zF~AGhJzo)T@@s=;?948=QA!LyJC~cvQogzJ?~G^Ef$a}(?7G3B9f<|LVD04Ua(r;L z*~bH5YYSUn8v@@aq0l){=(J*irY-82ErAyrAD`wknbn^ozR8va zz<&PpkaX4Kkn+=uSxAhsflGUox?b=_M=i(lwHlku3RSfR90l&^(Fo%)p>oU1&MBOE7c(~M z!q^Fl8{vbv9_c&?cRyqH3i?Q%O`TP1N5qcm>4mx)TGmllgOsy6Q19`+2%?fD%EsEO zfR#EC60wkaR;+`uVPTsvRX|WN2WLZ zER;tg$xRMn*hV>n!o};ITs=cQ&)Ib>v3z*EslBjvtX^N?3s~$2LY~W1W`7I{mIUTF`amd@x%Q7 z@SVK0?Ew;Y1>Z0a!YDutFYp@&OO0Qt7YidO{^hKbNu7xqRgJDGMEG?BQ9}F-TG}YE z#$@-(0cy%7^GbljIK-Q$3HsY@n*S}p()7on=yi0S}-xQbVL4kZe zOKE?d^?^vjmEF*++d@*lsu1_)tyQ9`L{Fp0mTQTU3g;^YeL^%|>b7TMAOl=K;Avr% zn$czgp`DZNyoM4&vnPLUBxq>n{oWWHrgn-J3jGW|7p%{e{WF{ztAe zWDdUU7S6;R#ZCS)#?{NG@|xroB8RnY0+znM%LiRePxCQ?$3M$|B&W?GAyzyG7Y}{d zsza0|H=W(GS0_z3yJmt<>ZB%~?+{;RW=M6?2(66~;37FB_4&P{cDrYd& z3l=jFz0!n(>rCv9eLaGE*2((VrYI-}to@e4Y&W6Wd=E<~N#wC#5^1}Ia&v~-*4v*? zEwBL&SSmTE(aREe0E13}UW8@mB1!VIB42n|ys?WXM?6w>)FbA`&T(l=o{vP7hz=k=Q#a4o5@7;`4mXaUKr^XuV!6o=k8{oGi@GaF1_s6Le$6*wHFDq;O`G&ET%U{*efPZ3WS zWvJ&?=7XmzE~1b7iBYS2NI7V#b!V#QLnY@)najPnddr|56D3hNIi3bO^6przn27T| zgV;xrv}uCc;(h%%3>;j>ef^*jJe%BD;LU)BhP~y2NV8|Wa$JvAM!)=}MGj{VdiAjf zY3n{yEJEH86P^JuNn8Cv6*mr6vUwo(IODiEm}p0>6-_6~F!&vF;;E>dZk29aG$ZAp z{e&?Yh7V)PiTt~_qXpF{;)Tvk*^b;7q$$#^2hLPJdmW!AV5)Atr=3^l^Gyt|rxQt{ zhYKp$=+C50jjD@>ldJMtN4L3&bifZ|vJOP2Cxp4!NWpN(&$V*4di0Z?^x;yx-C25{{V4vB5YA)QFE+P^Ol&<+AW;rxWQAi!> zV`9-Bd?GuyV;e(XK)^QyY&qOiV8M4xwk@1n8suVeHd&`zZT&+Eth4@A{!Iy5bT?T! 
zL4j3lzETx2LBhUa&gjRuv)Ph3x5ZJsm{Q*o&gU>)$|7zPyEs*5!e>AfG9CzzFU7Kz zLDN$1rTeU;m7dmszpzNUnvl)T-OnEt?+|?lJAAa?MzkGxdSy!o^q81*UMGKeX5Bgg z_gz5aOV2~BR&7X2dEq(P4m{i{llR#<-1bannNBkqVb)?5u1iY#Do@L8R7ygmWOXz1 z+TO=nXL$VdFwpH`tbW2QkuetwV93)`$aRrTA;xR%_y$A&#;fx9rma1|H0h0!WtxDS zUF>YRSH{JwSH00W@Qi*hx#f|)z8LbPI_wSJ#v_Xk(UgX;+yilrkL2NDd;$m# zs6ys?=YH-rQ{Qv!zM3MAT;4ba;p`XErEqNUwCl;r zQ`=0p)v{!#r(jw$RyjHCs<_k}gd7*JVz`cSC&rKp9 z*y7YbT?Y^=DvR)9%fRL`OL`EsHLYD5K(5b~@7{Uh3XGLk(8(cK;tk^>2uHuA-QRt8 zg-bvtMq`xGai<2PIgK9sbQ5DWM4H;r2y9;ZyzYx2rH}6mV<*m5I62JkZJkpf(yn45 z)9XvxHR4HD&@UYFvs#QZmWYr>ERKQXeJ-hUs0)PJcxjU$g1VFQL!$XkKAIU%R*f}! zT$MNuo8$|k)$Hise|rBqj9o3?*7hb|_hTaSsa2KkjP(o$IFP+{<>NxyuGQ9aX386> zdO%1n7!^?+&{xh*)TJhwZh$SXbK~*)?XGT3Uwwfj*uygp%%5rP@wIHMo=+1@<3rY) z?=IoM;bu|#tn(p4(cbP2DZ9Jzjzm^Xj}^8wUubprOjt=zO}>vuYm1O6HR$x=vmqmT z{YE$*w4^V7d!-3eqRxck`6W46w7QSB2FRn$lPECdp6!)iLJl^rAkcuF9qg}5-r!9U z_`I%;q2jF(*^J{ByfnGx<~8E|D1L8YJc2lV{;|GjTo z#{E(1nTQ!NO&+GgzA0-ta4hqB1MOz~v}6d*R}7Ear%RTfz^7%`j@vqndwzafyBO4z zl_B`Ug|%k0h}%g>wRJfyuAo4qAv1}d!r=6+dv@#M&;e7@M;3wNBs6HlY@HQ56BbcDlcqA-d5W`ui44`9!cgXREB#XI;0~KZ}JNTS^4HBKtV%8eeeJs z4Gmcy|6Tt^zDV$Z5S@UCm_gtnJ)@wI90`-KJSj8(Gh{Iw3t1vZeTZiG>eZLp;Vq%x ztyzqy7xNbi>7Skd`bpRl*VOy-lk{!N;++)eX?M3PI)T%~9X_jUX7o%Tt1IbFYE0V6Eos;^`>+x~g#sSs zrUEUioE4tsei`K^a4rhM;o*IuwQ+q|x>wKkgrB&mq|{CLType8Q;3*2OdtUliW#3tblKf%s(xSB(5ETVm&s~^_aI%06x@8oK znKuAE8I4m1-CzCqA@)N^T(dalJP}uE;pY%~O;>28BQ%=$&7)BI;6Rnhj;d4F)YLb_ zM_L<8dfxW64&!#GRLX}hHYn-TVLkgbD%K)C!O*j~Z+-PCwj;i$dY3`(AeKd`usc!`++lTqTl;At1!p30} zLLVk-H&Imu2#P&e{4^wZ@O)N|*I&ePI`sVZsP!vHkdDRT#@nBYKNSl`wdF8wJ4&u1 z7qYw2GYJHWA+|?tbm3+Yrp7SL8oNI^wwGCS2TJ1e^N)N*uA|MqKa|0D&Pix$Vz9WJ*)9;@ zj0*?~8J>(>DAJi6CChwjz$IT$uRzoFZbrdLOW$HaDpk+Lv(AMZJgoU4G6Z4Q!6o$9t-S?6^z2=tsqomo_4F3$F7MF z9Uv7`Bvd?R_$HC06+1_oY>9g8OP{GpoB{!ye9{zp^Oftk9|+VQ`yx!7Ry)i4xX4?M zGorra>+=TPtJwZe-D-#SBMvl*bLQhyd+>duDs`s!(=!uoLC>gTG#lb#6B{+v8140G z9TLkctICI8`ol{fdE}k)e=izlqYNMRM)fP^tO6Fa;8w)$W8R{F4>LyeFUYbqJM;`1 zPXP_hQb(;iTWIl&qT7fTj0&xLclCKm 
zzmPSZdDzzyeIW%{FEF0BkLLOJ!h{p_%SU1}&qxhTFXI@83^@vap)ibcA>11i5{#_b zdNN6i>nP#7T?M9rLU+YhY2{#R6uVj%FV4$(EE#FlcdBX5T>&OSVqe9rETj;XsOe+| zgFz&f=XDH5JKf#uGvUeE>iQiu>}U9c-BlmRu?wXht$gAwn0>QtZ~kV$;tL7jwVEfk z4X)s4&i*g`z^>W5JPpuJK4w&=Md@9_h+l$Q+KH7ts;*yJlorICV-0Zd5K%UdgXppM zlQ!7T*-W?-smYnT#NWV5M1}Kp?*Vxj;Vm!^PmA67=w|p-b6-GvrV(fuk1M3Eae5e> zGV*jyhG^?2WcnigW$+%Ca(0^$vEnF0N0XDHjdCq%p3b!4k5tul7At-e{Z#+r9bzpv z!i(whXH{~*9oVGJgeLDSN1C@qKqFg}*Qu!a0^gIlT(??*#J$^bJ0o+g=lU_9p9d#Z zZp9D!fQxjvENY6l#b{(Vm;_5^i4t+y&K|aZFh?t-z72oohM@F|7bpc3BPMbm5F1{o z+(<;u4C0sbdp;V|^j8q>2-Yh);&`#G-4B5Z5vrSgdJHI;ynyZzP^hTTNOL4l!zF|V zEJG4}Gs2UGr-~6v{z0~CU2gu>9~>GDu`jjT`P$`TIRzB+rvjR(@P8Yuepf~ zOST;#O8&N~VtS=N*b+{vClcx!OJ}VcFxtC40#JQlg^TR|}28)5o4d+jwyC;?k8aM=Vphus}j1 z(MDZumh{Pqi1|7ntptldHT9UDSfPTsE~!YhTA%X@tk(}IDPwP0L-y-06bE$;ye7jU z5Y9ToViN1UBje!FWa?{d*hoR5@9KSgVG^n?aesp?W(MiA4%#^SuLC}@Jg3u~uu4x7 zccW5Q3K9*^+-7!74VAZOK|4iXriy6Q>;^Q*Ie?jPif+vGg)@W1?*v0*SA`D*8JYhPs1iukA~Ww>%BC>tpRt{6-XOW;49z%%_{{cz4lQ zK77S;4UL3*1Oe>NXdgB7sdaE2Q!^OAY{QIrQ;H|q$&G?gLX0VD@<>jGM{hIjB?O1S0V~^mYZM|i8iD!-r{59h?wP;JQ8^;!-7wS*}#6~?nILsUUWqG zhF}Pv?S77U)6cY(3TDiLUecKKSTynz<0!*hyHcUhVd$|>T7@3UL`8K|;S5Wj;S2f5 zG0jp3b&o8HyoSo?M!r?MlhWd;(*P&`D7+U7ws}#7I;OTmWitzfA9XZ_?pyuezIqn^ z38YuxAPzrMU?qz@>;C91WWp*D>+`h^BnleU@|k_}Dk3+94t0s&9uBN9&;vZ<7^`{k@8)e(fMyN{+0MEayeR*EH=r@7n4izUp0Q{!HD*R#;TU5kF7aQ3C+h zMr*u`vtF;H5o#=?-CByc=(M$s`kJy4dBa@x`JID`)wm%JHrP@CgFbvJE=Hr=Bg-wxllC6cmFrDX9a%s%Dg6NbfYtVCkIj<>{9cA4GiP6zMH zMzGSj1WOXrayt2)qKr}hUXpFUBQ$OL@Ue1V^S(d0 zprW{lKw$;FcT`5gxktpvHSmS*7h=#RexDumC7f!*PjNyTuR_?&svWk-;s zX^25ll;^gc<36l^-K%ClqKCVQQ)?Sm&^mb>)J%Im7-(Dc3uTcN-<}V00*$f$ymLQH z++Vt7aOKG>bYV+R69VzVFf{454ug6`u^DM4d~b&dJv7`hy{{u`KfkYk!0%&DS9Rzk zbJuRBt{{GK0jNre!c0j#^jBEa_azOs-zj+!wK=9y${eWN?e7|^+Pmnzqv2XV6KUN3 z(MyKwc1%ka*Vf9`;H<4BeXh`7VI3x^zrmQ@yAvP38c*fS?-g8|yJ6spD_IMgy3ftJ z8!+olw=gJn}`i|Dlh;8|(?6f`aq~XmWv@l5N zlzMGX#raDSE8QjT2dm=C;LcN>Z_t#?Zs}4w^Z7$>JAZ>8zfi1Y=dTj4i6e!i$8GT< zZIfx$U$eh-$ZII2J9y{k`wOL5NT_jHg~EuAL&2pnNd2hv&bnKM_fF_q7?ZMs>Pfq- 
z-(u5L!w)`7Zn5G3wNaYZtgecW&V{!>f1wy;9ByMP<@DsIqLqLSOY2K+*2hy*FDOphP2?))J9jQ!(%^#4dTC?9i{{%M#b>MJjVFp) zWptbP-bIBp>ZRvrOW-97*beEKOzy)-L`6cpi!gXKO8@lL1#Y11zh-V+A z4n9cP)-Kq6{QS$OITJlCS?k)1suLOVfj15Of3Fz%i76k`QX)cr-Um}{a>D?nhXUN;K>-~gK zzhfm!!h}y=RPdvHdy6c#{25S3nW5{W%Ci7D zY4sPXiFC)Nb zGE(iJ#PFqt^nO)g9Oyw{8>i|3SNK|1?1g|gf%@T7SJG3fYqHdm+eyo?tbkdg<-LV~ zPR#gq?0uSx!V&2euNi@#DtlqU{b3!)i?8WVe-ICI&5tcl;b-(&;M9M*;*olMC%4&p zgXp{bB)QnzT&po#IQ=E`LA>8<7sod0!>6>JE#Gkz48E8DupjFKEQ1%&uFzbf#TGc{ z<*sBzXCn88^?Esuh8}3k)+&1vP~c*7jG?UXJz=J(9LNX8+@a;`5O3p6KQIy2tIbsQ>z~hB~4S0 zszqeg+CRY)V$k}&=1}r>QbU8QG_rx5o2^Oc!yf$35MPLiBcO7&sJ_bdVkp*A!f1C&|3xh2d@`*`wGf|q zXfR~CH%$k*AX^WJ$##DzoHatUQj3n1S5sOmuCdEBDc_tHWEM2Fb5Z(oZ8)gC#DDSF z>VZ%VH3s*1REiaDX9~9D&vRMN1DJN8d@H!3=A}RP6T$M!W(|?d@;ebe!9!gGSI{!G zfW>8Q?cN+e5yK!9YL>oZiP61oLs9Sxh4zcY=mF>}=q!Y(GOeLU(og$mSxlvtzKKy= zzIDQ0ScwUzygw9K_(2UGSWZn*XM7^n5xYvr8I8)FkF`Zt1cBS4ojIXAX6W}h}l_@R8j?fW@VsubHIozdje@!0@$568O1@hMqjj=V(7D6hD0u zU*Zp}bVc8ic<08yP7xF0wi3xD?rJnA+{Ob^wqd>8nUSOHD}TYbH(T2U8X>k#_czNE z;BY5;Sp&Aq)6;C_{)H0vJW$={AXP_N`CZh00vEo))YuK4sO%P3M)&py8{xI`Rt<;3 zK%yOOnY4ysG@W83r;%gAe+M@zHwL`_vJzKDcnOC@u)sIT5XaA9g78JxX|JN;8e@0oCoy z3z`ty7VR-1!OFFax1~aFuu%vDz~(yo@0AxHcANLPd-8g%y-wBc$~q@p%8XZTy#(yc z2MiD0!R@!qT)Bco&F=?PJ7W^1DC<87p-}UObK4ac3zRnEr%*cx=%wxzWxla~^2rQC zQpZpXL$rIZnDgbQm+DR!^rxvQY^9OA*Bf7`f;xu?N>Ni&L;bNKrq3lEvhp>v>Mu1v zo?`Fw6F8PMmj6uaSbX8;8=V3m;ww-7CJBSm~!7JrP1_sd!bx?t2ISS}`%&kmXBi#`=#mYKSSMkQ4AVC`?TjNIXEGwTghMR?-RW1f93y^02p&g$dO(je4` zCSl|zIFyoI#9A_5pi`p`_lZ)1=p~V_#N5Ga*7uoGKm72=C6?)|%x>x95t|PzwDKNi z5^7d*_~d0U3?4&$LfJK~bs3!aaH;Mn;*!iNv)0e)eXw<9WRSMj^bC0L)3&GjzO9Ma z!nbaj4K6Wg&Bu+_@xWgwr-WXnF|J=bTtfnDgoLWHb)(37wxdFEj$S2pYrhX~vM(=B zOA%^`2bYg9jmTtcQD{{nt#n5x8tTko({xEmOJE^iwY-iI%~H=c#a#A(;g6u&VOe@j zd-6o8VV33{t5^3|G%G>1%O~Lr5d+!z^RS6R&AQ0?ucZf+oQkx0hEx_u80D8?`#^5m zlJ|a@7%%!^(DRE*zk)P;oheDr+HKE^bDi>HSl;a(S=fmBk)*uB z%*JA`+si}xW#3BMMt?b>2A)J)R^|C6!k6KujA<@H5_eYn_Gm|1+J>tzWbVpsFO$5U zzsl08Q+`g9RlMxabGLdo+?gW%Qh~@$>k8j7F7pimlBn};A#)q*)0EAQl0qDs0i_5F 
z?Xt3R$$+T*uoIH5gMHRJbSCa-&=hi6h;#j-bJLbO|2_5P58#t_Vc7(KGqz0vIn*oT z@2v{>wKmTzuA$44B!fr$fC!Zss#@A3`kSaz0kIK_o|d2n{zwPujy&StkaNDewY&Fn z7{5?jWC^p2Y_m1)!bMF+&yfXqcH}$i$O1e%^4)ZlzY6fk3OpeJ5xu|z2L5Nnayr_A zjPe$)e~Rx&9$D0%)fa}>U-t)Y7fle;;g5xn&t?c$z%E59U>>K7@N3oDLQy=L8MCV&ApMzNB&;v9!J5VsA|jk z0|y(QM5;&kGlk*2eLag34%22%MosYoP{+eGEGkWd&oLB|^EDJeW;R&Og>7c&76#4x z87`av$e434?P8Q-l#VUMRp>UTs#Qvo=x!;CXs*Idc6Cw3N*|=S0ELjiWN*R=b`{uf z$25#TSASVhA`RUhbxCD>S2g9ybU3423%JlA3Csk83D@};lE<2mY$Gbpj8GVsu!TMewryEAqf z)~&*Ay}Ef9wkpwjRING0?R1K{Wb=J6cH#XX^Qm=KzkrmQf`1Z3wsUpYwO+$H;D*>P zxYkp4?yYsfl}JQh{2OKcQ@cR(PC!#p+jz~-`e$dZuBD6I7Av>OX!qW^9jEL22(HY~Y%%}j`ggktbyC3VgWLrl(<>>^RfWCO zMzQ2+g~17T-zzUZ{7Bigi#hjV+5D;>6*l}?Y`t_RPRU`oAhWy$>b#x;i3A|B!#UawIuT2}fcXxgA!_k_(@cWChX zOKpoE-!BxB>U)*FioosuYjVJab{KEsi5icUv}n(ia!Q(&O!m$Rc|%UZsRFl8L}5A$ z$;zXGyu*qIJ62bh-8Y11c&js`-A}7-OFx~oiPX!G*Gt*%&K_!UPbU|(;?EdciIM1l za>LobFUe5QJ)%>}lmrMRew85CySUNe%?m0ysjy`>jbar`Q^FVv{T_RoYoCgLCMuK5 zGVL%uah7ul;8-8xo;hy2T(4s7K|KQLF`l#t>B=Yue9TO`k}{)S>a)~g{L*5)4pUFC zn_x^`F3^(^l^11a1~UIRH0G{1e|u=%GcneFZBMS9fT%f~jAH-MbtG^}M5piOV9d1R zZ#HydXBAQHVy@B~V6EX~W)RvnaGuVsC!@Dbr<1nePu8<9OU-m$muQnHt6Uj2nCM4X z+-+x8dMZyo${vt8x)b0j$|r9YqA?v^SgLT+>AkXi{8WWWf9N7RcjdNWy;|sgBr2)* z2jMT2_+#eSX>(s>{uDKHFe~{<6fYkKeV8*ql>K$8PdcQ|H#_qM1GIQ_{+sdY($ z#{^m;9ZS3WEyy!*kLSkvAO;av)hjEZhw6j!(zLvP63n!m z`;{xS-EMiY-ySLZviJZ+yPv(UUh$<~BK8F*C%=s`Q@rqT0D+>-xUkxatR`@ssnsTX zFM|Z=e4v|`5!^;WvNmtu0fxk+7`ASSY~IK2Jo-Gbg@^qW5_+Hb=AtURmj`-)^$`qN z2Pu()th#JMw&-%0A_$eT^GN2pzDIAbWE4rBx4!F8lqq z4+wM`Wt)lmw19zS5fS*QVW|b2_I03MsQwL)rn#ZeujC%$wk0tOBZ> z{z93o;S-Nk+OVz8<(!E-Md&TWUdS|TrWGj60lH@+sP#C*>%|tk)MM6vq2$Gci;>pm zYA3%AUo?;I@A;AeoEo-Fm1yNIZZRlS z60ockH2Gf>dp1D;Q9L@Zvfjq zfiWAtWpLvaD^mwkc&Izd3h-ETacuKoMy%je?TLl=wYUDJgi*eZ-d4pe+4Ia~0(3 z%iaz1^^Jt%;h$!2_n^0A``v(QQzjruZsHlE>Jimur#<1q1{t8dXtj&H82HV%vf2;^ z`&D+|0#_(1X>N~YkZ^T>P2b7o#Ph&6RQllMeyS>SZ$jPnRR1If7s|!T#-zI;iGaB-4Ub&Fnl=XdC`@Y`4Y!z62 zdESs1gS@mAB|t%afc60M0UE~NFP9)cr6eR^_Z&*qwvN!(fLfTm$iN@w-NCRCF`9@k<z6Ln 
z%YDq+F6p1XJl!SRLxh3OKP76Ehm z^%E}@@&;R2z-sjtz*2*f*|f5y#@#9jRW7C6MqP2nzT3X55H!JLsZZLSXmJRAM;ZO_ zeMFAQIEh?Q!M+TTEgM3oYOL`5)}lfcMjf9?Lq0Vhsk!a@@O!#~s`Xb(pRQjhH%gR7 zuUebR`kC?)uWSMs%|?OE^i6|<3*6EV4@oqR=61a0Sh;I7Wmb8 zPFJ5MFQj?auuT+s7w@5YlCjQ{8~9-F@eCGBu0vB*@r=FYgYE|F>sN)7c}H8uTt!+! zxx)LkKj)U;5Qloo>*bZ7QD`!mebx>@HIi;S7|P4EWIC+Lin_ z|BI@(0E#1MyM=Lw#oZQ`#TF;HF1EOaK!Uqla1F35?(VK31Qr4bkf6afI0+7m1xN@1 z0tC;``@P?PZ{4o0nyHztsj04+uG7zR&UwTKf4?wcfyTsfG0ZfZ)wtji#fwB5fjuP% zKhDulqgwnY^~<)P+p76- zt+LmLVhaIBpBleTJ>&uX7x9w+T4kQaTA(iLW~b(f-e<(m=k?UgR2P)#f5cV!@Y3kX zZcoc=EsAmBpqAHv$gzI@kn%$YGP3iGnDFxLyj+LzUx%oi#lbvlY1dU22Q4 z*En{q6rS>0b_XLZ$hb&t79UdUe;7nhb0^*@)DM;{8i-1rRjIIhY{tZQO?*B#R{I2iKc zGtswJ=sH8%WXk8YDxXdnAp6n!>{yx8V7!}#O-G$j$4?7f=tq(Z4UZwtvS(8DB0E7% z-1L_!PU`@Wge~hew*Fd=(SSFB2D?1%jy+k^cc9*vu*R2(;?e$WjP91UE+%%@Zw|hv zF8Z}_(;rn9a?#0EU8S_nLBG{n4OI=VLv4QS`LyCweshUcuf^+9&c^?YX(TS5B>R3s z+d@N@H>@swl?^LZn*K50w+C-#GEDK4H%SfUm1$)BHkPD~JkFgtf2OQi$I$6|s|GT+ zDqdB;t5VMDN@7qX*A9c|guiAAFqLNB8RPve8awiqwsMsXTi_idZyokT_N?R}Ffh6u5SX-$qv zSOysKR%b&=7<1^w8TyhNm}^{IsfXnn8wpWP(=UxP>gKOWsi~Q%(~@842zV3>B7%4) zd5GSfmxUANhJJa$Uy8~~VtL(5a?)CE>zn##0tPb!&ey%-ndSJYL~Hj6+=&1;5_Q2AhRC#2K2j;%|lc z6Ztcb`!h*oJj;)FO+SfMg_)5tQFiHUo18lL6`VOnc1XjJ{XCgmH{xvvlW(psY^IiS+)BwiY< zavN^oG<+O1VNvqX(Pnk0gKtp0^x2-5#`%`y&utuTcp0Wzwim5fV}oJ;=Z^CyULB+1 z@1^^Ie}>73R+tv)bm~7I7>DmK_wX}qrZ+_!$eLnnp41c89X6UOe>A12md5QqWPmYB zEOgJZ$yu=YH=?*r7o1%W{Gw-k@Bw68WY@~c#A=N`&2+VtsvER*KaRQuZDRMZZ{HcM z&$af3%~6%lYPl-=W9FeVrG@l=1QZ0&`i<39zdF{Lhh`i6ivNwF8HYN&G@2Yqmb|T^R zTiYngbK$hY=_vLeh713xxZI7v`u7X!6fJt1B7Nlg@|JbN3HhKZPg~r9RosDD@#IUZ z903^6t6L)=n7BnJO1en4p7kJ?N3sdG_31F9)S{ED&oPMKq6}uFh2uWWZamQ~z#Cj4 z-u=~%cYd&B@ zWxV5=f$x&cJ?(kxUm}c`J#jv#NA(m}ruK(TsV24teq@Up6C7parq!QZ@T~|e8(&qf zyX*3Xm+Mi(sqX~O=Q7t{K8N5As_ZJNCylz-36PX6o8lH*Hj;ZOnVlMdd=bfZPcqjnXcJhYO&?Y1fX z?%8DGaz)jAL#4b6qjbeADt=FA@KqbnJI_#RpQEgU8zO&Be_ zXK(&KW4hP4nR13mDw`=*;ddq(zdy~M(iOAKT@ZNobT%~awW|L5xYcgOC~*c4|aat_!*Mw_i5<%+46g#_hAB-C!c=D(Rl}W_C!A0@7j~{=*;?41N8LH>1s-o 
z`??lHRvp*kE;B~!DVbARbwOaC%(C5U5IqUL3=+)t+oO4;>cW9w|KRw%4)(bXcj zu8x0kB0jdZ4=$1ZWxSL25930yE=M+wx8;Wv{Bru&XiyS$xHB{U6{^Gkx1_SA>woeB z{~_l8Py0`7>A%c;3ZRHG`Qx@JIs0uEw@a9cKsPi`3lYbI0XL+9MzjBq zohGQvmcx}zJFIO(9!Iu8L1r_H_{#d&J$RZrZbBQ*GaSV&&^KnHm1`%Td}@$R`9BPY<`lxv zEb^fou3@UBM%Zga{H`Wag9B`SF!EC;=p=5af_#fd09#NbJ;mPVvwgvmDWpM_dM{`E z0}U88a1ffb5x*T;NI2o-3|E@Wv{i{Q#1s+L1Z9PgXGJASJq}3<6o$d`_+SCR!y)T_ zvBEP>X&|#}>*b)1>2~*MH)m&_S5(T^^G%S$VT5CWyJl_&v zq&GM+Ab(gIM{AC%yn@r$zRJYp*-~;PBuP#A5_lrwePk>d^npMl8)P^24upO5f|gQ5 z+-CPq4F8nV+;nOr6U&}VL2Wqa&rD6?TYEZVN3H;dZ<+*Jn90!`IzB0O)d5*mIhfxP zU~^1|MG;hB&B5O|=~yPxQA;+`kUOBz?>Orkg{uTr%V>aB85RoSq8(r|Ux%Qb6)BgY2NDAIq-z%8n|W?96F|0I5swTl zVM}aAOSNZz4Q~^^K&_Z6uWKTRA7tlVE6h4?Mx`#N8$sqqRc^nmK&7*Ycy}N<6Bz`+ z8`77en0OJ$*%NYn^0{1xTBXV3j<_zKFj&&>TdAh=8}T!^XuLYnUPngK%Uh1mI92X} zQq8x_sQZk67)$AF1ur3z5qvFJ#Qj6UT*8xM=0)^HTiRlv-aNwlY^ak`IjH}ZE1)wd z@U!MEZP#nnntbDOM+*_HMUJ^~onY@w?9SzYv3i69xEh>P@w_3PJI=w#3#n8kXoJX& z$UY4v&+9OSygm^UTsk$${yUpQ58tNTlW!KawH^DupF~UHWTdFPOVy@SQ?Ee(#u2ql z8%wT8qKECQu|;^`k|;s``i7v}=+oc?Sd=U9ojgDl|Ap#L9GiGG&A^7(#xrANWWrU( zL3om`o^2U}AJprKh`6UbML>Bf!4Gy5R<(C8M911vPkN=65*NzkTQd;-?ngCawi77-&x+63^Y*nr-Zz0 z9NDZCgx;OcRpIHCiCifPh>-~wGuF04+?SxGEJg-W;2AI&htDh`M}rc2?eIlJ6v_NK zaN9wT`_e8Ct}j_fNI-QWBl_HoLF^9rItCe5V3>JN;8usHUXdLq!#5g_Wui$+Q97hk za3$Cm|b10B8(DXy)I_GKSB4l1ZReg}&Wu;dX$U6~7sg{nNO^^w{Z{xQga4}b| zbO@#;pV~dKOof!Ghn~PYP)A~&>di^4d}^9_(Z_jgwmg)eTRE|UGt<*s$nWtFriqy3Wl4L=1I}>$dgwP8-)Y~nQ9IY43dNHLR z3haxTW=w6GL}{jaqKUO$)Tm@Mdhr}OM~6iN&OLkvdl8|{n%{3R?S*F6ZV{~FbQ-CF@tA^nQm5-v1c~5c?GeWK(7^=6jOX)>8@L-mc3ef za^BmDci&x}oOQ)e!}szb=roSUqsasPPPqBd&nQoY#Ue-4 zC4SUsF;*HfSaVKvhIM2Cu+#Df4z=%cz@bY9ax3N0B;KMK zU%+s#MVq@FiGLW;ZsSDdTmb%a=Dcm(Fmu9Mp)9*RxnvbeV&yxgN^u2(c8}OPrWDum z;@2D`$~}aHcG;S*ekCQhx!M?GXLhUxzSnt@AP{pD%_y$y!WNp3aKte$B%FUXPa8^K zwbvpp?YY7WtuqjfHBy>RmB+0L^3+X(E-;#ZN{xP6uMS!>X7^RRYvam?BL006`c!W-%!Adk+kp}^$819L>qfOSM(Z5 z%gRCMU}7GV_!I}F6jZcxkj~y*(9NrZ`N;QX=$5@EFM-cA9LKX1Uh3FYK+9m4#Kou^ 
z7W{~IGR)AduR&JNFnPe`sx(IHp!nM750@Ia_M6cg6=gWIO{hQ{@L0{8W`)sLRWQA^ zz*SloqIB$zO`2;lMNF+GF9zxy7(qG#uweICjks1I(LCsecv{i0sXiJgSV(zI)4KVI z3eXdp*#y5Upm!waHsW~~a4s!6PmvcUD(QxZKb+%UsCK{V4FJ}$QX{k3AvQ37?6pQK$`@Wh-$ zI&Q%aN3~C1)XZt~iJW<@zSO4OI*~j%p?@+z4`gm*_00dg|90O=cZo6=L|)pE)n_3_ zK9cf;JEtNElk-;SSiR0TEuV;Ug=5=WAAdLcm1@G0q4<%CTArZfKMW)}|MaDuD6c4iVr()fLjmWQEz(G<$($m zO{QZXSM4Jb#6CEo3O%fj+4NJ5JvbzfG@FY!))RO8Dqmyi`*3;ALsd=ogoLgH9!8tW z#=m)|>NKelQF6Z8c#wi*WxasIB@xE=bZJvNY+0rnNidp+-JInIhNFyK$!Ij8hY?Dy z>gPB~G%)GFP1DX2^m-2RJwy#N$dc#4f1J>gh{yH0YCXx@8}Xm{tL!IK)F9w%al8f{ zPp>N{5}oL5@v?$O330d2wZ5bHm>iST{mCI8L@Jww*Ee(%mLhQC1j)2;T8GgLZ^ArY z7CZ&J_8-rmS+M{{2hxIr4_qmQG78A;t0D?MgT5K8@ryEYRy7xo6REa#j#heV`iuK< zWsD$We$mz~-=YgiN?x9r1y}j{ylty(pekMFqn%m9sCoD@-74qCnMiLD?!4H-{brs| z@%2FQyLmZcp7FDh`YmSkS>t0aTHyUSJ(G3~p=&Pifr(9>CYrADLcb9YlYE$TK0D1k;+8CtM~YuPf?idE(HHqfeUP22Z!q6^8_q>&}f` z4EsyzBcW~h2UhQW$f|D0c`Dto%!XYqT~r36pS(L^D5lpsv+E>;HaNFQLPxNS^m}DQ zba6y)U=zFq<{cVX>jqcq8gWvkZQ5MPKg0t0W|8Ci%fK=4qzwv5kf;!&M8{%ZIYI6RJ1@;AJ9~S9%PlzJ|M=DdJ z?Mxr&_)?bb_gG^JEGTEAitQwxaB{;1TBJ)k&$#k~G^%X$0R6G%XCqg|RAgvy!BIlg zDWs3b|23bLR*ZW*B6g66>aA(?cf1_h2JbJ5Jq?N|vkNwLZcmhj6QAZ;s8 z=TylOPd>jr{x>5X?q@p*~guU1T3cl2Eb!a4|u3_>UfG%GZ;fMQ8ggz z67uLvo+9eGjP-r^)8~fhaq^mbA4)uhrCr)6TcQN+Wh^M$T{pG1K4>A;!vn0(BMcv& zrh0lHXXxF}NVi2#@xKSi~e84 z*j0N9DUCa32NUCxIij>z1amtnIn&Csx&=v#M3{vurC25CZnhIT?+f7NA@kF_-@n)AD>) zQ&?17P3%ip!q~R`wu_I1!HULEvP%Egss!UHfhbQDCH$y4vjRB;a+=; z7Mr0msD6?dOBKRxD~(K`O_-VK7MCeN)D72$r((&u4X{@c!-v2=csSXw=CA1K#zZym zXj82@DT7Fw!uXjXbd*~lu1PM4yH0OLcxqV;>DG15)+(n#IZgBjs&YDobk{*`0H>@c zKyCm@8XS+C%|zA^?0uDqXx6;6uqMdX0j6kPE!FVa!@B9nlJ8U1K?b^XJWkFz5M_0T zLgoHfE6ihWm4a2wy3Qjg2;=Lmq?o#9u2xdtP&Rjr?Fo){>PuV3Q-MrBW`bc?o2=Rk z2{+IG-eHe(RfUU7I01#f7*TgF2oh(#UUL#9btOk|*S*>?5Lc|F)ud`eXNsLAc#XV` zyZ7O&@)}b7yVNQ=gmp-ZBQa-3NG^tU{epuN%O#MLDK5h?6X?h*(+RIQh$l<=hk?T$ z7GB<$F!P(E`Zo`G%?5$oPp2-%Jv8i(YaN#oJyd`JxwBbt&HvcMmORt4nAIS;`CXr-4gIU0Fo?fz#I%|}pB(TP< zXLpxil-^A8t4%!|$mW4g8nsBuuT|m0cK123UqEzLU@a|nf_ogYdaEEl4OikDNiR;m 
z&zHN5ncZgvMtiB|cSO?_D#Q7-!h3JV<>{l!sB^yuR~ZU$(2PeAknC{~=o8%!5^smQ zAaCT7$i!!_`Dn#5(a`udv4d3b_usJL7UCsJJCiJXKN%f)sEaB{VTd$mk$NH^LzQG*RQ!cXw2&_(%sZOgqro5KbIl?=sY*ok@m!MAix8kJ zpQuFmz=YGdQKtE++v7lmYKPqiOsN4|5nSz%FUgNth7GKITjkci9>7siGMbPXsVSgAAZW{2!5a_okgk64nCSL7Q zNEWyD@cE;MWikqfMaNIUHq(Vh87ISFHExFIL*A2-kvCPk{L6Cl9{wrZj;W4RM`P~} z>@h=V7owE5+%ZR^t?DA=>^MDgidYMEJ_t{ayV#rc9keHSdw@XmC$kiORI${uK|I#i zkBm~AQ3SN}ITigU?wP{ok~*+8hy72c1Jy2=bu5Q~EYps1sz5?a{6zZS8ARAtn{WJy zMvJKZFPBnUhAeZMqZrUu@_H`qfQZDf)kZwRm`A5hxb^pdGX2u`@o4IDbxfb0Db^H} zJ9^N3#(I~|^{Z=Tj%LPDzvzLh&_)J4N4j%Gt$r`ArpU!PjfRxpyG=g5-Sc78h3tQ-h~@wk@Q)HQ5DnG8$fec10IFRs>hVafRQ-{$(XTp;f%5+_ z(4j-aoA>y0p$xuf&nXX7f(dYpo1%V0%#YdWT{9qHe zwd>5T?8@gSz?wU5J4TCY`VmBbBbED+&RaYHpH-spC*9OwE`PX+f{I&+6(WsL$IhZ1 zfL+e_NebLF_mGC+n!HLIeeiDe57WCBfqWSRX4La0o}2p*318z4Y1*)Jxh^?r%efdN zw@7j8^7D8T)xM}D{(6YHl(RDrXTheTUnH!oj|VZ4)SuDD5<$Ch#N0^|6R}!+LI{V$ zy0!+}hWaxz)J-C7Eu)5Y7?XA24HhHVc5NIUKH^}FlIHdB1haCTYy;_7M*GVl|IlIX z19u$bHcP0HFJ0DnKWP|8Rw_JHNK`s%H3E_x)Q=S=6#zMc{OP)!| z5tO;LYNzb;L>J=u027Xn`TfI3{2#`UgE_sQjF33mUFTu^Km|)Vqo%(gB0CsZrBFwbehSe9uYRtFc-8%`*&5C zI+}0Y7t*XL6DzeeUXIH8O*E|2 zRD^of8#=9W_npKrq8Arn75$oPo~GJS7o3LrNdM zeUm$S>Mo*H-ifyPEPv+_ERG{p-!UzX%K49;iHV?ZIP+h!QF^mk1ng;?tdAVa7y?c> zyWBh^o`092208s1dZLUK>Nrf42+~XWAdH0PBUw=K_7(H5eM%@qhjxb^NC~;dTDj6Z z=}DuM8>zu74d+}-qQ5UC{i3+w&i(PK`D^wb8@F~pJGf|};yXc%)SAu5j+NAuZG?-Y zq{@!|!&jOO3}X@TDiBS=&ne}Nnd<*A5Qr+A#Oyn|?@u0<240`UIhNy(0`zQD6sQn< zf~n)eFFqon1B|m+KA9|(BRXwDsw7K#I6~Mer3Md(Y{VsWWEzyi6jn*yF5h3|sxnjf z$RL5e=VnPv^L4hb;6>ae6oVN&iMHbQ6x$y-NWNRe3A;8eahmrMFj(8xe*CKQJX z2zy1HAgpKX!H5pag8LG(XC9!-xZy=$>j9560u6F5QndYW{*{G8jv-I0GKmEr50F|* zZtd*=)`Z?r@FW$sC+&zl)qy1zoZW9nya&k|b1ln-I{PeU+%nIZ&)u8;hkKkhZzbo6 zzV6$xE_WU+=Ef?bP_%#4 z&hHEtdJ%xkat|JVv`S~$s49YU@lSB@5Khqmt3ttMUsc~&EFtl3*v-rKZ$hyh!FUI4 zQg)*y1>_?`d2z87y8XR4V$6c9TpeqQQKz>I(c-94&Ua(*Uu?Vn+&nuoK{`h8Tk6A9 zd#KHrRfs)0%S$zG%AJl)rM4Y^sWg;rx9>c|$K#Qe-#0dGw}=y-q@V5FeF$Mv2+xgUs3N_b6<)CweXXKg1%mwH*bUf5G#j8i(cU*N+1}eU+^x2 
zt;FOxg)S#m7r{b{xUZjvjj2`^qo7G2m&axqI*8WKOat)3H5>}nMHH=6XVibDOCVZu z!POM)jC`(_=N2hfL}MP)ecyC-=AxCE#`bji-Cpib1bC`ddSI039lcXZq`N@fHW**x z^j(ZGNv!+6r=HElQR$@Mvd`8*sFz|Yp0XF@G-?b`QX^ac!G=$pidl2HyrpiHBP&aOTvl;b`(x5)(^rndt#N1CCnnezE@TZ2N8PlA-lq{6-BIp=?uZ`bYtkW%ImiSRr=KuZ*QG&0cXO_V}*P z&(QF1$B^Xl2gvZ^?APlj9Z%BB*zoFJ8`)A|y{^R&#Eh<0dtX+h6*Mgr=SyECf1i7c zb&YSFLD5*{Kg7G%-z&uZj9PgEjPrf=4IJCA&IY{O49%{D<-4myxr8)MLtfbj{=-;x z`Sf#`Z6O-qCk}rTO*AoEkruFp$k%P)pxgqAj!J7u`OVjDBU|W1q|f1RJ~5rRHcnhv zYU`{{Ic8X1IBeU2E{A>hHX)A&(yKHTa2zILMvnjNHhYf-Mrew^wl?St^?u>( zDPs)u+*%3I5Pir?Ag(yF(=eM1eG)8rS>o1vco56Ja`(WJ5;(u)fz@JWkTIg<{|p++ z`%~;X@vJbMT>`wbFd=stPfS>S*)XqjViQi+zrO42{QK(q`LFr>uvC{?)#-7>H?@xc=Gbz^H5Sd_iPAOf6e<-=kq{`|q%Ppf$2 z{qA}0x4z4|Wd+{N?G#Hk&%1C9Ql@5*05;z%ES=nJIj0{3T~eB@_%}mIkxR7-)X9(8 zaeKvO8EqWlpG?CU`)@f+Vh22l$b?-#IgnQ%qxd))tsd=D#&J{|)uvBv@EW+>1F5po z)gp+cXC3;SkdBNP7%_D!<9pBdnHe%v=D*e`K~xTc*}>!4GjzF_NeSLbD-)lxW6oaq z6b=+$8sB3s+S93Zy!nT5k8u8O^PH_w4A^&up5eVu^YVU|1#E7`fK^2Ot^X}6x}>IT z1?Ne-C&^h7zG}obfBgp5{n;LAuy0&~fFSwV&=jz~!eiKOW+T zQ94CnvQs^9|HDX>guyVKb^-Lx^ZqWKo~VxGr~dZsC$C`K*&6~XMS0nudt$0^(s16* zJ399L`gy6x)#0C)I{n3q%Ipt2h4&a+k!J}xMs>i&AedV+&&l;4h9Vb%_KW==J%%Z& z9{M;aA9%iSW>TpZwvTknBbnE8l#@qL%8dV$u-x*az+A%4qr7xV6+c75AWK?&!WK7P zbBr?S6~+wV?R~QC@I(E=ZUN~okJ_%}ErJQuq8$Dva0G1>eaVY#5#WBc`jspzIe`J@ zdA2KrWE%niUlPs;7|cgelDEWaKi0S)%5=O$-)^))?gv26Z%bo)&_KcfH?SeKWADy4S|C->u!x&gXJHWiTKKOCz6 z>h%7XOoj2NL=jmr7QPe)~403(UICy)3))wKvFPlw_3T%ajNb#TkQw z?6WV$AkM|90UQK8!u>HTt^{vp`jmsA6R>Eprv zX)gD9K-P_BTGm=R0X#bYUF!{KNSj|?(0aPyueQ`=5!;`RXRG$LE3!e{ct$77o`i+A z7+OnDGL*~yGKZ)TTxS055#nA;3=a20O!e(QhVGYVUA8B)BI! 
z&CYrR#~@v>&Zc{yDUEuhx99@3b9mcZ(WQZ2<<7{+fW~FMJgIaOU z3e2Qz%Xs4+{=*o<8Ybikl2TH?lNjW3{9(m+Iu?!6{FTKz8j;6N*C@PY+mDo|w2zOU z<@gJ+dP}aC@7%MXe>upa+bsgeN3K$`f~`OtWYr#&OufiEL@$VP4Yrp}cKl7pjB-!o zwW5c@TPonO;HxVKS!96gvDxmsltQ-PSyA@V)%N=3Y1+$oOZzMy%%B!gZVx;m;{nZ} z4S-an5!dir%^&7I1!Z%^Sb}NwuTJkCgCKo;M%C9J!Q_Xmah1y$2l+4g=S~M)mheV#vCTK{k~vm`xr47+zM!hz zLV_-}JuTaCv-TzR3M2M@OuxR``KHn!9E`H!u4&RlpTMCXmG8YJY<`qIi_P5}XZfbn z=XAa^K#<{h8Tt`L81!n3HiNbLN#=>P0+Zv)<@k;_Qy8%mEN|VLReUEV^w%<3HpMXo z5JQNot@Dmqpespw_;UuV8zCxWis{ix>fAAu&Fe}#G4Fnf!j9W9oO`+E6F%mdE}!VV z&G)@QEJ-J!vNp0k`)|uZf)T( z3-)1_;rT?Q3k`(hQg*y5@}2FK|Cx1{22kt|C$x`GedlRDG02^owtPlL#`fVVB=K)L zp_h4Yz%=W%0{+t3X#d$*VZxuAyJuExv7Y%ZcxKKCb&3eSf9^k={CYfWAn~t`Z@9CV|e= ztg*iK6RF@oL+9CSD!&13zo$2c|6$N?l~=GwMF)R7ZQB(1wG)2_;9a{{2$?>MZ=u&q z936u|MG8I3NTwB}BD-SaOSbwH(&eZI{wiYG8F5+2b)`Tckd&@NK@_AjzZ|>&W=fAW zaDxnxpL`in1~{cvR+#>F5=M365E0J}kajq^{S%%Z!gK_LAzH|I^ zg=|6ZEz=)`UbJ~vCRS_y1*5_`$ae>M660hZ`8?9{U$u^v~hp zoc4UQgj|ZG*HynKYs)2DYL%YnLT4g}qWi$OTgKU>ww+66{EG{P#am(16%gpyAF(;cg%)l%3lqFzSzLZV!3a*IZrZwI@nByU**SJXG()FJE4{FOG21(z4%ea zRcEY6yS@4qU6{9Ad;c()q^bvRnbc3%sWLqju&rOEOT9adWc5TluD$2?iU2xe4xt8q zFNoZukM{kR4)KrZVT1vSjrP|z0b8V9huYVnsjGdje*feyXU<+F3xY&up9f=^RdjXw z`ES%iMi=ZRMbG|-raI6|gQvOW)35r_x%@w6_PNvaG25ow_RfM87JuwG`eDmvEzzU0 zhve|*PV_3`gRJha)5knJBht0;DE=aSG}lP>uHc@CFD8(V8o=hiTk(u705PCF5e8XzYPlN<=sZhT~QXRf zj{i1Fw=#QG)Yf21Vm@uo*7NZb2)_>pAS^AzBYZR_Zy1u76t#)Dxer!7JC|Ou=-!gN z+?3y_&AGb0Az4L*mLPR@cNX!NB1wiQ!a!3I8>Z%DBL~xgR{GY(i}R5{)Rgb<*gO3A zXuOi%s7HXv{$GIjAFBOh`yU_t{{lp1VMALbga6~=K+Y6ZH*^Pke4YQNcWtFdKMJX1 z(e3%;Co1=v4^01i`^|p_{->RUi<~O-s)n<%UQ45-mV!}=SJG!-Y5QE}yj7|6=wmCU zMQ+YPbJB!KpRcc(n1IhAk7*QC=N%3e^P@%K==5&@{sG$F0RK_b-Sy0f)=U|^9GIce z`TqmalEgoKj|FEQf!R#;L8QjO=+Q_3F7p6YgXfikoJJp$_r|J6pY$EURvIgrmD#mj zB>vd9A$%O)tcCD$%y!l%ZJx~b&Vy#F9=y)dqv4)yYZvlq^jm*o^HU10Js&|Uo#)Ef zO(2pzxv@9A93P3606Q~`n9F&yzNf{zk$S(A^4!1i(PP!=&-tU5S2RDOwa{OLKd1Ok~d{0TS!Ws)B&PoTMLT zWruS7Lw|Psc#gN2wv6(C zR8+BkM!H839Mq=(9C_y5Wvc44ZtF}ox2Bj;#wO+3sbMO-PEkFvc|5uGQLCFaZO0)* 
zLTK;p2E!v3t~VU0p>SXj+y-{Oeotjt&ADj2#fAN1xwQ{?gYX`?&;Q^)CAvkwNcxaK zZ6ParV6%bpUd|g-A?4E6;gNI>_9pk%of&7xO+TYLGneg?Qv4yh$qwDvwQCmZ z(L*l*N(j$bkK;Nh<9_1a&>pqNphm2c{gy650#{-c7Rx|y0?dN&0fxT;U{d7q!4h@i zyccWr2A*`kWoT<4Nx`muSMowHz(8S8KTxl>znr~vWw+%ot2rdKd>N%}nC>BV)Y{vl z9SpFq>FyFpo~D-I7uYrUlJD{pL01ME#$JzYkbKYlBdol{F~R6Y)Yc99<&f8Rxf427 zsTsLm%n{LR{TgS%4R(k`O_MDB#MtE|BP*8taL^;6l@9C;u$$KApov>w9ca7 zMdVd_27O4bJsg@AyTyMQ17p*c&*{8yvsosX}T|OZ%$ut{s zunHW~-R3G0cM#iDBtIr25KDU$Yo+~-TP2V|&hsj$HVeIjlg`t&Oo^MptPTC1BxB3hrgUG79d!ph9U zyCYQ^dWbtOLQ`PysW;(W$Cfry`Z{8fkLaAXBG4PSPvUoq0xl9CZFzKGBxLKf`Y~_{ zaM13IU;1l-K^7e7Wh6=pvk&PYSD-Kzz60TIp?kHuNZ7mbJSeI~P9MR2Rc@N8L$yd+ zicpUC>*Kl@b$GO3J(cozaH8S;X_04#t(^G0o_!B`4mPSS9sfw5QWNESw zf#_6#tI^{UZii%q`4+f|-kv$xXZLwQibo92Z&NK_Kk~jDzsQ~+<&mnwJ_Q~eX_9F3afEsyu{oq*~|X{wRogb?T-b?$CVG4(8X+?p%xW8@sma8 zFX_<*(;;_cFFJZFZ*b?4VL{O~Uu-;jp(!%vTMxui5xyn+(ycuTFH3gpgB^2i;sx^k zrv72N7rj_N%q@DGP$-!(pV)ncNC+~mdC}8)!@?F z!Zh6^yMLp<#-B0E$x)qrbX@ZVKSH#Tr2^Z$=N$jLm2PV; z?YwXMQNN_^&wbZpZI`}amXk32I*J_yKab=F$S&HQZ7>4VgDzK)uw&hyGTUF|W_^ry z4-Jm?fEcZ-f92W{&PdULM;?T@10k*c=xd>@@RN z%%o6IS?HOz{fFVrHrpOMe&yNro4Ca~-ROeMXPTI8_Ij7u`GicaQ%XKPME~H62Isys zv;WI{^rfY;yw$oD{Vhtty7M0fVz9uF*qm^(JkHNLI=qgE3^LSZ#+=@;}~E&pEtRXVE0r}#V^ z4`<$ABZtr6LK}W3(Z4xllO6G!Yw}U=%IrL2)&OGjX4$a$K5XG32M*wZL;Mh3-nd>4 z7@PB)vxAd|`7_FFgzjI-bTY5bEx&D3fSyDCU+Z`U)2QY3{MhHimfR)IZLh{TK2Xgj z9c2DM{UGu1OvlN|#^blQsE?r))C(qN+sSwRum~^6$H3A!S#A7|4o_o;+4^}8&*Xe} z2m1=Nh;JtgA15aaVp(<7+CM&a$vDr1vdHv&VHS>0{sH8AwY~grBhGym?YCoC)P!WF zj|@vO@_!7DGC9s)}fx%;LTlReK!u}28ZTMyk z>V2Q)bG&Cz#{%17x18am$C5-qESA>@IfW#QvNo!#@f64$t;4)C0DrN>u==lBTWwl6TbY&*-s!Qr1jnear+ z-a`?#{{UZ3d}ZS!t4rARz~JZ4@DcL2_&qph`JV?HGhq;N|HJ?%5CH%J0s;X90s{d7 z0RR910096IAu&NwVR3El~YnUA1Ji_{e zgpfuti0@LO+1wl=oS*!{%l-<+3N43p^G`JMPc-w-H1bb0@=rALPc-vSB=eKbPbBlt zJo87kX(I$DPpr_I!Es%)Q~Ax3vb{hb{{T9F02W;nr7uJbe+3wRNw1m6K+?qx%~HyC zu*`S_n12XiT$6_JGRp#Rcw96EI$)$ycmDvuAYn8N^Zx+#fROC}0N_dFh;uVl%MHqS zfbV-mD#CF800Yd2atgquD};tfSpF{`S*iS85s1?@CcnWz%<~h>Mmm4_(h;^1gT&13 
zDpKM){tGdO=ZRyEGXkEN)^!A?C9+b=;{>#-3k8DQSj(b_sc!Hca~BMZ5-g*L%yM%DJjBgHEm(|JARvnbcbEs6sWqM`BB)(l z)c!^22)j)W@rEvw)@Xn6NU$!YtRZrM@h>IZ{$1ihH-k5ZH37wuqDq6x!uYjw;OE3Z z6531rLXAzOSt#?AInA?}F41QxY-yF|JU}d;Fbc^%1;i6EbU+;91Q(IfM~&F*4V3x?(+*+8-fYRKK(@q7^D+<_KFiNbHdPsPIR5uhwc!XVHh# znW;4X7sJ#Sy6QLjq9vm}L4*GQh)PX!lelUvNzMKMsCN@Aq-JY?7>qNhxFb1wdUpeb z0{Ep5xTLjVU1B&CH$z|Y?f9Dg^0GY`1hW99e4qISf*z*;P3ehtnS41*4K5;~guq^4 z8V{IJ01Kd~>!jxxY`M%u7ykf?jobyZHfASlWeVYr<4$JT4p<3*lt4LP<5u{Jp=88= zfQ5?*Lai}rk23g0f;+*FtknKB4-kF%sWkou=ZTTQ6ySvoPyQ(g+Y1K~Z8GgJktHar z{{Y02-l8*&1;0;uQtOyImAV@700C}A3<7Q%{Xj7XT{54(4SblmcNK*2n3iq`M^I{c ziVc^a;%@rHg#d*}myAn+9LI1mRtu%r-|$r&Ty?+To&6$55bhyGFHadM5GBTNufEah$)#LR~}xnRy2-_c12~ATWm!6%aB*DmI4Bf9EqLv0d}B03j(+6Gg@duemcs9q4A~QEXAH&6X5Qw9{{R%!RvRoQNX8%q7sGP2le<(ATHY_*cYrL;4-sH#tVRAp z;)>XE`-sL%m4_r+9YW{#Ihc+Rf;`LMYR#PI5}Zsqj8fuYv2xYCK>OuJ@^=E#8(PFb zmn3Q1R+5O5boeJF&z>5Cw|NchM?ScR3(VyVU0wbLbTUA z+TilTQoRXJmHHz5)l`0J01)d`^*?js1sBt~U{1OTqESoJ%AUmaeKx^~CadfKZ8i!K$5#eh9->d^>)b4Hk z1OzZjrR8VVRY(SgIe;ITSylRr(QJyA&IOVB)^PCzoG^wPx6pra5Lpioh-w{{WRA1pX8m9YT{%u8)h7X_;wwh6$?M!E;ISJY;^{81%VAv$`+{OV?5|rp)F@e9!8q43 z*;Z--7_fRY!)n@+;;16XHEFxrKjL9+sQkt}P~G?lB6aL51>7_ACAy3iiyqTF##vGG zQqTIa$H~HZjuasd(qkg;l<4jHf`Q%Y9s!8tDLX_d4DR+EKp1;JWBNcnWcx|X238zJ znW|+a2{h`sJl#hEZK9sY4PR8oHvJ?%nLiK(sIy?TcHVoIojal&HfIn^f?g%)b?ry2 zQkAbvoh0y&m^xZy2y=60tEC??mX*b3aq2HZKM~SGG%^u!K=MJvAzvb1?2YZ|r|#h) zxWw{96z&2sToV5P8pHw)^l-q^7CsQNO>o2+Mi-;948e$;lf}vPEYM$qt|rw*uvi#{ zgar%10f#DxC4MCXg-2GxS6PM=IH!RG7Qhly3?tkwUZ^_SqG(r4sklI0?8Ig4okx(b zMwkNgxd>(89#qixs7KG3CDCx}bEz&Vf+dC#hlm9V-VcM3*SERKtQn@{TQ8U4lDv(s1(ue zmdcg-1f<0Z@5{t0Y1IO?1;r^u$yABD8lBpQ-DQ%c$cuoNZd3(9+qh@%$sV`?j>YV!V>aq0daY1%GU7PE*Mbqvt-h@1v_Ggz9gn1gAB@o~~&3+;cXwh?G`7+Ru^ z>FQr(x@`WC34oP?-$-GRGDLNv(PQXku*GcG#I(*QX^MV7A|KSMUHVYy@fN;56R0^t z@Zk==CifK7y;K9B4Sc(q6B8nT?aH2KD=!T7NeyRY&+{_p6?JneGc8*Wz9wHouwjdn z3OQyLUf(0s&6w;BOzBn8L)X$2Qu_kq>WonAI%RytsNGAR23+_?XtiV7R8t=j4GOqh zBLsYm;DIoNVXVQf*nVz^T<>MYK_;VV-38l34@3fkna+RM;oh8GXY}* 
z(fj@#N3fecVSw0?4?<922Q_SJ*-YxEFj4dQY#aH7*#{IA?Epesj)eeLAupUC#I-$P ziyO9C@m;pf#gGaGv)hu(9dJ{av@igfv?ISYQvwS_cGe|1HT;D~Ht`jKxDqiUX=)HYJu4t;uM zGlYDi6V;!qm>VQ7c^|ITaadTIG+cS>o@H4<(nZ_7WyBhc3c1OSzw!!blF2X34J}fV z^$`PN7TSNK<^W7q)z9t@qSxv)N5M>(BAvK7IhGs)5q`5temF@@2;-mnp`j>K%kCiP zV!S{32xab_wF053C|oD*P=Io{CdkEHY^e5tn+7VY_BR>KT3X@FlQnf#aPR)}3o^ou zC9mXwd`u5niT(&w+N+n$$^aa)puC*GVjRY0%5VXKY8W~}y3WGGbWDXPGjW0thhksc z5He=Ef+2${a~%%BG>5Ye&*C@PGiI<|5Bv2n;_Ga;c-_Y)fU_mB0}$H7pdz7w$5uHb zp<%$PR4q4&)SHRnO17p$okB^ew*y|e(MUvrZIxXb>F-XvtdRBddJ1a zw)!AwUW4U^I|?5$Yp54oS0u3#vrHAtqtWwmA(l2=vG-D?&u-6L)lAiQd@0qC|2<0+1wb20LL{>GfbNgZvggU&1nJGF;)zTmlUr zRy;>LPTHBSrR)v?>O6$Fx~0GYMLdS*+qiAc7h)y99=BwBqW zrVbxS*CO#QdJb;Mm&;CSeYyg&! z3v?XH4G?K?y^tUvXniBr!Yu|f9Fw<0nQvC#0nrfl@F=;z34&b=n~57EctE2LB?Pfy z+9nAbZim#<5*tPf4;9R`bi(bVtM9pn5GKuL%+Hlx$m?8*yr_ZEdNfQDRiFCGL&mPWN*mDPSF%j%r`Kd8oNGB6)uq5gkT=T4F! zFk&wvbI&8%0uUO0H>9`6kFGeKO+9~Iu8|3-d%qwKbMj~dEVi1rHXt*3P6n>)K`{1IL(p>kmFg8)ZB*30!xqi>NOQyIsK8!0ARrGfO zeik4Gp%U?C&12W%Q}y&(3ANoHk~3Si_fO&smLG6hHku5+C3-=ls3X7-xDFf62;v%I z>YBKT5K0Wo=x=NHl?LqO<}UGcslJ2ERTXV+Cu2ZC>L9zoV>a^{BExB-DXFBTAU)$7 zW}L&~_Qwf!rTdLjfP3HV!mj6d(GmrZGZPP1E4K(-%7j(IbsCibA9=A(i{bAqsQwAO zLyFkI%m)uiU?iy#sYn(!@#>pv3_M(T6d;F$hy5EUzDT1M=?aewORreI>tKue#{uHf@-{8{%&z4 zb+1Q%Pl!cN$eyo zLg%{W|{Qauks1UP_fWeeit`?kN?@YtjHRt7{ci%4wjM zp+d;5kKCxGmCUXHA;qfZwk437h&Fe>cp;wX75RYRFcX}_3(g5r(U}_LgJ^ZP{fI}^ zqptz&iGp#NRV1}HEEx4Vl{1y|efgPhWLR48^E2m1Olbk7SD9QPhi^d|GP8^|Z!E2V zx6}|b5wHR&4O-P_RGNxcjjG@rnX+1-BSjPoS2KqZrMvRMv6*XPvO@Jrw6MqK6=$^B zk}*;_?D=!5_L))$@#a#nW6z&?PizT82vP`~#JkZ~o?g$y!*jja9!^hrb*%5yi=p=k za4ki?At9WZ^B)cWL?ZDrDunFj zc7Wt6ZARP*h`fe{p4*2Buw9;1KBJTsEM}U=QtxR```#IMCnVvXxoz)KdZR;e!gMJ z!!FihU$h#6ZBOeCA#F>9+BQbCm-^TrZ@7_l5Jovv?i;8{^3U00)(CPx)AuYBm{HUR!pD0yu^1-Ik9dy?jG3O# zPZ6b>M<)pKOWfzUl-8+^pY|Y7c>w&uhEJy7P}WR#vGEbDRA2t|lm&|;3=dfJfTLEK zo4BycP{`q_R*of7kd=+{jkO)@G@x2)_Tk~?9v@8RRi~_AcI%@B$6yq0rY=*~3zeox 
zrbW>dDF?`hNYJO&AE<E%qpzvEQ^+zxI0zwlYxyPh+Jf_3qaFs*k602T5mS1&V^1#^Pm!khfRcjXE@W51lO!Yr3LFz*SRh(0y@-3)*Uq2 z=sFttnHB|KqXJ>2fo{(TS%rg_WN**ZLz&53lwn~l_sm-g&k)); zU^z#UC;L^=`-uC9AUMsLjxN^g%Clt)n--I0JjY?6^CO7&f-sS6nHNG0CA#V zGL8QL@#E$mw(2y2LoBcwjdPg1PDf05n_)BX;gs0oS=>)@{{Ui&I)e4*?U>XmeA2F% z!PKE1A;Y*B%*2qA6#L3w?iOa{);g9!?4QzJHhc%r%J(5xA^!kwBC)H(*$z%+kr%R2 z<#h<l6%$Erw0M-LgVM|-7~CeEbNZK|>G>fgyBY6ts0H6sdzDeqpJz`p%Pwo- z&+aaJf%}YdIAKZGIXz(4rx(O(dSj#u>RdX>S&q)?<%41a^gs3y7AhhI%gj(VJ|bA0 z>k@;MtAj(9GYSC;_K%N5N<4XqTFWY~$g->N^n}}_P7{~pex^vqk?GvUPNn_oCZtgt z#o1i0ps&2{{V$L;=9KLGQb=foFXUVlwJqwV18wlgkx8u9(jumMPrf_HWB^8 zzKT5g^?>zw3ZLvZtOHMlD>aJxXw_2bt%o zL4OldNEoz)Nr65pJBgY>8Oay>AwK{{ZAGJRL+AXh1n@1QyM?wXq$55P5FBMrAETo?=R~=%7lFX08$Xaj8^pB?2M`cIIi{}$!siFNr=?6vgPZ%5CFiG5S9@o?r63|0&OJ;W4D|fgXvJK2zFQ`erV(Mk6Te#9NmUh3W#48D7E_1N6#vjbJ`ERx> za)~fqn=liS9|R6x5RB#DN&A|7v&Id;yOb+f-cdicHy36m2 z%8j+|Wc~cZ0A=oj1C|_Im*+4P!0PEgP%6HRYu9n;{{W;)^I&l+2TO~1eM^ATNqo#2 zE9PpMJ3t7ksQ@EIej^5}xHZUoM$34X(M|OpdT)uyGT2n!&%_irW6u*{!K5lc4&W5f zU(m1AVe%!h@5d62Kf<*YMj_Q@^KJ?T0$EsY3AlS9?Cbm?B4)J-1`gc%K{OD`5IGnl z^fN=8LD3y6&RT$gN{(*-0PBhNru_E}<_*kIZ#5eiS(j63xQj56*yi7F2352^OA~BT zyV)af!4b+kT!-!wxv9z2;cluwSf3~N0b<9;U)o(`#OPH1PwErA^9M*~AgB%*j&=hU z5u0(MyE6)K>V1f!c2VxbCX3)qN-q|E(>IqD3_@tay@gqr-kDP*<&MKwWVu$7qSRve zg-hTP#s*VfgtHo1Eq4(Ns8$%u_<)UPFlt^OlqqVkSWFIFM7MYEDaGJ>uKm9(Ma>NQ zzy+b_+G7{36Sf%SJ9ipOhEe3>@dzzPZrI*qw15aREftpQA(n+W4;;$g_Vm5*s?7kLiog*}q!UoSuGZ#U9n_v#lMOvKE3#YgcR*0DBa-owwd zF^ypngfBbqi0?rN;XeH?d$Kg4brlQRPHjkf)+Q4!p})Zcd1Ujk`D6Dq9FuXgm+}fOv@6wH_+`NQDt>k238LqXF&w zIjCOY)A%N9nCsOl`XA)bzu+!Q_gDIq1LK(NLQ?p&#k}J&Q8Ek^{{T_Tz4b7v>G_Fe z_}hnH=fO5xH`HlHgj`I`FH^;MmJk-zK?nfgScAMV5gKj(02zN1BHhG&RKG70SbH@T zd;Za}7L$sVt|$R3jaNE|&AcJD4@_h;=M`VB<*O3Lq4egWQDAsKh>tZGj8Q0kANW20 z0D{{EB4nzI+q0#KpJ&0DS)d1S@1khZf9F zSgfxw=)$bEO0Yp2U!hHY;zVs_E%O?mhvP6ZUg?f-ZJ8NGV~7!^xge-9uZ5)l0J@5! 
zkB1Nrk$&L=v>{nMOcONRQHV4@B7Wj+1UH}Y#YjyEh1FQ~SNfP$8TF}h0 zVnu>{a9#?Sat#sZ3 zy<7NzVOS+IT^}$4)HoWKWgr~z;$21m0N*l(?~eZf*k>kIW+o*8F&X9x)H+0JIj|SQtyhpsyGEi8}={oWhd0<1)c-5w6{^>e`Lm za>W6HnB~vXSlq!y8MarN)6-KQ^h02` z2p(2ggH?b(8Hs`v3VaC$&;W2tx|3=}cp*-?QD~|VISU`A;VFE4LY4so>;0xZIA8~+y z9Ah-R(FD7|eIrD>*FWqO9+O=2FoK)uHMK2VGY$p}#*uD49H>N)Egn*Tr1uEN^5QG6 z0rw57T72X8C1ME!6B?=QTnmc{F% z`}TnVtwnQ~ut42i?h8F(T^I8z9c-*9SLKT8*LBbPF>mkvlpBgTJ%l12u!I=Tzu}X7 z%!QL8zU9N8nNfJS-rMlXx=o|y@@Hhx_C&&tdl~%^v%v}wFW_AJT2>~+R3;wwG` z`J6$RDLlt~@e>fco?1uLNu6T;f2m79t0V0O`l`TjSb{uHX>efyT&wyaVOtw82fR5z z{GzHAX~ea8OG^wgklMY*#!R!3?OTr4=noLP9TmyhBO>B%%*OjfnyV% z_WDb(@{ClpzOmu8%|YC9Jat~z%togwtn-?HfRRpGb&YN);{Z?}vI2aSUvcXwy8|x} z%Mm$^(ow#lVeufOSJ<|r){&=3ROPbe?NKyQwb<$-xKp-4njqQ=z?dltiBmi$BG6;T zB^|M|+E(hYkK!@-m1xPQt6!;DR^SRmN#{fMgAr!M(U;;qSOn<j=9S zXS5`Ey2LarE+=BhqY-w?9l91)aGqc#k6;c5ur!pz5LS;$9BLC1c;1&Rgstyo(6)8_O{DaQnfYk#+YL zSZeB+%wX6T^(qHLsvVmp`2;D~b$%983CIPWfEVi`sMtf#;A2J&esF@_c<0qZg$il~D(gXN72+*$|hw;zDSd&EaMQ{($sVa-aW zrO&|t(lhbzf8NZ~6JB2cvc)%Gx`J{g$!A@dQ>%?1$mg|=(* z#tJg-3bdeYg#ZfBO76OaE0BRQ5TK|mDoX5!HSH~UN{d^e&u5#Mg}fw08BL>}stQOL zE+AMq5J=Rf!MzZuEG)gS#{tApfF^L_9*alZ6k43I1YaEy1--Ax@}9E(nGxtVzemsTT~7<21O1l}(=?Tth7yn9 zRk}~IMb08&TDfeht0|0aSJeedTWaw#uSB}A(<}j{3q%YT3f(#5XH(F)(7+G47-MnSk0=@=Tcst8NIZdy$`f(zbK; zpJ*-6F12tC+!{C`&t@BmV6Mn)M-WEa<`tKH!g6i?qZ+GS&Ta@wV9v7;R|Ueht5UKp z5)p)8GP#+Rv7oY*g@MY!%(H5A+$f>6E}zSm2MNCv;nD>><+7O zW@VW?#U*ZQ8&E~I?;{dj;8gI;@4j}Ykwjc802$grTUC`uM_--aG9c2NxY`PC?m~aH+ z7@0s}`RNY$XOaci6_X>|8cobWb)THZ2-Ex!?%pUpvtIa2;+D=T^* zM^K~>v|NdRcnlLsl*ScE`Ti)dZhpeDE%~U`uYGHbtw9_($gc^pp zwk#ETVJ_3FwjioD%WFX3IF6m=i`7(ZR`F3`EW{MUjaKEVS%E9nHcqT+Ji3C59{o#L z+-d4G#UT7O4>@MQz7da4+965-Pyne~@e)vaL;WF{_M`K1^C{^*37JSxcB%Bj^(W4M zaR7IraAQ|$zNHt7Rez{;TvR ziUb`nyTf8S$;@^`nNaimO5I-LaRjll03C~OT#=ajD0BHG@?&VL?jy*GT+7-~u%{7< zH)k^l1c9^~JBTprxEo8PRu6Tvb=O9&-SsUm4cl;#P&Uj-cs)awHegou1q821{zaw4 zYpo9OJ|m3Q#qp4t1E?F7cOF!@NSd*LG$#5UY8Ld?^l*MY0Q3pEUPJnW2Sr8kU#JiP;HdjTW!fCTEy$EU 
zzEFAgVsLZGh_ESY*~5o1z|Kn+)>2vkH?f%1!f{M7nM>$_Lgk)|ymJN33&MVc7lrQt z<=`sF@i@$2e{d~6QD3P1MH>r#?ag)clMoFL zqlnT^a60#jYhhlJ%GS{4APOxQ&xnr+qtw6L$XD3J$qL2}5v0L+gNYo9zL=n(8~(Je z*h@`%THlGLO&*Jy+(gjENsmLhal7{e_qo6ieE$Ft5<&w#CwVnJP<+qBa4VGMhh6b^qHHWvcG6*D!DD~`j=e11gMd81l+(-DY0XZ z#HQSWEeA;)_K0nrPGIY9p7E#QhdE-v1s z1JOkF_Fi`#jalt6Uanz|qC68XmHtGr1_Nq{IcjL3U`T+xf zEb7c4s}f@gUzM0RKZqk6UFbKNLV74)%*2D*o>H$d>7*$zq=7GpoU3eHOfA~1Fu)@l z-LLrt3)!hf5Z|Mr8vFtT*JSt1R^JlT2(DJ)SA-VislDO?cA5457zd*-O(WVa`h{l^ zIvDg;*vv2uf#Nc%*wsLlW}f1j6;~4hj7~60h)M+=C_m#?JGg6&L2BFR4DTO@4-&u$-PD^n#8GX=Dw?1pai_ut$gIJc|q%zm~o{{V6@ zX%XR=3Ve~wL)#E`p(8`#Ln%!-N>ZSEXRKBg^nXS@kOlKiGScI}{6nt&NUF}GQ(2g& zP@OBMPx2Ff=nxKygas`t`h!aD?)u;LDuE_ZXl5?CX@-MDUFFvG{0j542SXwCC%S=6RZqYfCfx{+oDGde_-2jwtuhYL(0okK)Y z4bqMUEu9I&+96Am%&KaqqX($23D}AenxOo^l)57J%n&GnX>eU)Cli3*5rWA&M0QP{ zm%tK$Onsd}Wqf1vG=_^MDuWz@1cQacM}$(vf>4kwf-4Pf@7@ZgoXfw1{`iTo1@59j z>MMhY7F!R@07;5sAO#E$tyB1mVE2lQH^jXGO2YLh?R;Fc2~BV9A%a-pfw_6yY7H%U zm79YRbQXuqQbk%?sd_3~mi}N;A5F&bM+nrmb8d;32r0#T^D0ER2bSND%yNW8-~d{) z!Q8*%*>?NRHT+$Srtrjv(krr4)Ihki5}B+P%8^>L3BT%W)^U+cAwXD@6w!M#!I*Nu zIJsGj0ygvr6zs&!ULYA-vK5<3mU{OTP6Z5+DXy)nm!k44+kfKa6DXj}!GND+z|ld& z9TM?y?;A|Z+*T>n33mXyWfoqWCda%rw1e+5&00W8#9H)O4%G+t{7a+@Jb&J?cnUzw zjyi%c!2&|G^-7oO2u2;{vH~}`bXb@ELM`WxeGnN$(wc}=CuzmAlAK(iyE!nKpRbB6 zS-d|Iyo?K#u05tf+;*gGpbpE=?oka^MDDK7xY0qsrDtDK{Hl2~9)luN@&*;L{AhAGUe{3<4BYQ`wiWU`KZVfV*5Cp=8E zP^J5w&0a4Nudm1OS!hBcdy_mu>-#9l=MYo*AsMd|`X{_Nr}roxyuVRit?WPCqc9Ko zrZFb}0OJ+=NT6L|d!rQ!0K~VA5C#$Uo`>9?&xu3K zCV#YZRv>w0z-b0JAv7pu&SUXxw=44v^C~0$mKpPk%Mx4`fZGTWyyPWpsZ4V3eMc z*X#W=9fen$MN2m_M=@q<_qSad6#FtJTDu{}^;LImwSaVq0n-cQUP!hypVF z1yT73S)EEHhY?p&xB;Ok(lkP+xkR+7nM0V=8I}uz)w#K>%b9mLXVlUgPf)+<*)@l} zedn!}glQ=4E-cK+L0KUtN`fZhtU%4*>Jc6v$cbXHh{`JHDghQnP4%a!0WHIcj)%D{ zuIRQw-9=qPFEN;0cGLiv;%BHa3L39LKjv(;0yWJ4YZ9BOA5#2T*H?fET(X IHu#_a* Date: Wed, 20 Aug 2025 12:18:08 +0000 Subject: [PATCH 11/18] Migrate mistral 24b to tt_transformers --- .../demo/simple_vision_demo.py | 85 ++++++++++---- 
.../multimodal/mistral_24b}/test_conv2d.py | 2 +- .../mistral_24b}/test_patch_rot_emb.py | 12 +- .../mistral_24b}/test_pixtral_transformer.py | 11 +- .../mistral_24b}/test_vision_attention.py | 12 +- .../mistral_24b}/test_vision_mlp.py | 0 .../mistral_24b}/test_vision_rms.py | 16 +-- models/tt_transformers/tt/common.py | 45 ++++++++ models/tt_transformers/tt/generator.py | 108 +++++++++++++++++- .../mistral_24b/mistral_e2e_model.py} | 61 +++++----- .../mistral_24b}/mistral_vision_tower.py | 10 +- .../tt/multimodal/mistral_24b}/rmsnorm.py | 0 .../mistral_24b}/vision_attention.py | 0 .../multimodal/mistral_24b}/vision_conv2d.py | 0 .../tt/multimodal/mistral_24b}/vision_mlp.py | 0 .../tt/multimodal/mistral_24b}/vision_mmp.py | 5 +- .../multimodal/mistral_24b}/vision_model.py | 5 +- .../vision_pixtral_image_block.py | 9 +- .../vision_pixtral_transformer.py | 2 +- .../tt/multimodal/mistral_24b}/vision_rope.py | 0 20 files changed, 269 insertions(+), 114 deletions(-) rename models/{experimental/mistral_24b/tests => tt_transformers/tests/multimodal/mistral_24b}/test_conv2d.py (97%) rename models/{experimental/mistral_24b/tests => tt_transformers/tests/multimodal/mistral_24b}/test_patch_rot_emb.py (95%) rename models/{experimental/mistral_24b/tests => tt_transformers/tests/multimodal/mistral_24b}/test_pixtral_transformer.py (93%) rename models/{experimental/mistral_24b/tests => tt_transformers/tests/multimodal/mistral_24b}/test_vision_attention.py (93%) rename models/{experimental/mistral_24b/tests => tt_transformers/tests/multimodal/mistral_24b}/test_vision_mlp.py (100%) rename models/{experimental/mistral_24b/tests => tt_transformers/tests/multimodal/mistral_24b}/test_vision_rms.py (94%) rename models/{experimental/mistral_24b/tt/model.py => tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py} (71%) rename models/{experimental/mistral_24b/tt/pipeline => tt_transformers/tt/multimodal/mistral_24b}/mistral_vision_tower.py (93%) rename 
models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/rmsnorm.py (100%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_attention.py (100%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_conv2d.py (100%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_mlp.py (100%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_mmp.py (98%) rename models/{experimental/mistral_24b/tt/pipeline => tt_transformers/tt/multimodal/mistral_24b}/vision_model.py (87%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_pixtral_image_block.py (89%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_pixtral_transformer.py (94%) rename models/{experimental/mistral_24b/tt => tt_transformers/tt/multimodal/mistral_24b}/vision_rope.py (100%) diff --git a/models/tt_transformers/demo/simple_vision_demo.py b/models/tt_transformers/demo/simple_vision_demo.py index d51340a4429f..951f0b7e13c9 100644 --- a/models/tt_transformers/demo/simple_vision_demo.py +++ b/models/tt_transformers/demo/simple_vision_demo.py @@ -27,7 +27,9 @@ import ttnn from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf from models.perf.benchmarking_utils import BenchmarkProfiler +from models.tt_transformers.tt.common import hf_multimodal_encode from models.tt_transformers.tt.generator import Generator +from models.tt_transformers.tt.model_config import CheckpointType def get_batch_sampler(temperature, top_p, tokenizer): @@ -62,6 +64,7 @@ def create_multimodal_model( ): from models.tt_transformers.tt.model_config import ModelArgs from models.tt_transformers.tt.multimodal.llama_vision_model import CrossAttentionTransformer + from models.tt_transformers.tt.multimodal.mistral_24b.mistral_e2e_model import 
MistralTransformer tt_model_args = ModelArgs(mesh_device, max_batch_size=max_batch_size) assert tt_model_args.is_vision(), "This model is multimodal" @@ -76,14 +79,25 @@ def create_multimodal_model( if checkpoint is None: checkpoint = tt_model_args.load_state_dict() - model = CrossAttentionTransformer( - mesh_device, - state_dict=checkpoint, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - dtype=dtype, - configuration=tt_model_args, - use_paged_kv_cache=use_paged_kv_cache, - ) + + if tt_model_args.base_model_name == "Mistral-Small-3.1-24B": + model = MistralTransformer( + mesh_device=mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(ttnn.bfloat8_b), + dtype=ttnn.bfloat8_b, + args=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) + else: + model = CrossAttentionTransformer( + mesh_device, + state_dict=checkpoint, + weight_cache_path=tt_model_args.weight_cache_path(dtype), + dtype=dtype, + configuration=tt_model_args, + use_paged_kv_cache=use_paged_kv_cache, + ) return tt_model_args, model, checkpoint @@ -136,7 +150,7 @@ def prepare_generator_args( ) @pytest.mark.parametrize( "test_type,max_seq_len", - (("normal", 512),), + (("normal", 2048),), ids=["normal"], ) @pytest.mark.parametrize( @@ -182,9 +196,6 @@ def test_multimodal_demo_text( profiler = BenchmarkProfiler() profiler.start("run") - ckpt_dir = os.environ["LLAMA_DIR"] - tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") - num_devices = mesh_device.get_num_devices() if isinstance(mesh_device, ttnn.MeshDevice) else 1 max_batch_size *= data_parallel # input batch_size is interpreted as size per DP group @@ -195,11 +206,27 @@ def test_multimodal_demo_text( max_batch_size=max_batch_size, max_seq_len=max_seq_len, ) + + HF_MODEL = model_args[0].checkpoint_type == CheckpointType.HuggingFace + + if not HF_MODEL: + ckpt_dir = os.environ["LLAMA_DIR"] + tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model") + + tokenizer = 
Tokenizer(model_path=tokenizer_path) + formatter = ChatFormat(tokenizer) + else: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_args[0].CKPT_DIR) + tokenizer = model_args[0].tokenizer + generator = Generator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) - formatter = ChatFormat(tokenizer) - xattn_caches = [model.setup_cache(model_args[i].max_batch_size) for i, model in enumerate(generator.model)] + xattn_caches = [ + model.setup_cache(model_args[i].max_batch_size) if not HF_MODEL else None + for i, model in enumerate(generator.model) + ] # Create random images for trace capture with specific dimensions trace_img_560x560 = create_random_image(560, 560) @@ -264,6 +291,8 @@ def test_multimodal_demo_text( _num_prefill_tokens = 0 _num_decode_tokens = 0 + prompt_encoder = hf_multimodal_encode if HF_MODEL else formatter.encode_dialog_prompt + for iter_num in range(warmup_iters + 1): logger.info(f"Iteration {iter_num}") current_dialogs = trace_dialogs + dialogs @@ -273,9 +302,15 @@ def test_multimodal_demo_text( for msg in dialog: print(f"{msg.role.capitalize()}: {msg.content}\n") batch_model_input = [ - formatter.encode_dialog_prompt(dialog, tool_prompt_format=False) for dialog in batch_dialogs + prompt_encoder(dialog, processor) if HF_MODEL else prompt_encoder(dialog, tool_prompt_format=False) + for dialog in batch_dialogs ] + if HF_MODEL: + image_sizes = [model_input.image_sizes for model_input in batch_model_input] + else: + image_sizes = None + # Do initial prefill vision_images = [ model_input.vision.images if model_input.vision else None for model_input in batch_model_input @@ -288,7 +323,7 @@ def test_multimodal_demo_text( total_lens = prefill_lens + max_gen_len # Create padded tokens tensor for batch - pad_id = tokenizer.pad_id + pad_id = tokenizer.pad_token_id if HF_MODEL else tokenizer.pad_id bsz = len(prompt_tokens) tokens = torch.full((bsz, max(total_lens)), pad_id, 
dtype=torch.long) @@ -312,6 +347,7 @@ def test_multimodal_demo_text( xattn_caches, total_lens, prefill_lens, + image_sizes=image_sizes, ) # Get cached prefill time @@ -323,12 +359,7 @@ def test_multimodal_demo_text( decode_batch_xattn_masks, decode_batch_text_masks, ) = generator.prefill_forward( - vision_images, - vision_mask, - tokens, - xattn_caches, - total_lens, - prefill_lens, + vision_images, vision_mask, tokens, xattn_caches, total_lens, prefill_lens, image_sizes=image_sizes ) prefill_end = time.perf_counter() @@ -375,12 +406,16 @@ def test_multimodal_demo_text( ) # gen_idx is (num_tokens - 1) to avoid counting compile iter # Log full text output for each user in batch - vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] + if HF_MODEL: + # For HF models, get vision tokens from the processor if they exist + vision_tokens = [] + else: + vision_tokens = [tokenizer.special_tokens["<|image|>"], 128256] for user_id in range(max_batch_size): # Remove <|image|> tokens since they break the tokenizer tokens_out = [ - t if t not in vision_tokens else tokenizer.pad_id + t if t not in vision_tokens else pad_id for t in tokens[user_id].tolist()[: position_id[user_id] + 2] ] text = tokenizer.decode(tokens_out) diff --git a/models/experimental/mistral_24b/tests/test_conv2d.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py similarity index 97% rename from models/experimental/mistral_24b/tests/test_conv2d.py rename to models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py index cfb05115560c..f754b220bfc2 100644 --- a/models/experimental/mistral_24b/tests/test_conv2d.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_conv2d.py @@ -12,7 +12,7 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch +from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch from 
models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull from ttnn import ConcatMeshToTensor diff --git a/models/experimental/mistral_24b/tests/test_patch_rot_emb.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py similarity index 95% rename from models/experimental/mistral_24b/tests/test_patch_rot_emb.py rename to models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py index 4cb5a284a912..43f527a464a2 100644 --- a/models/experimental/mistral_24b/tests/test_patch_rot_emb.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py @@ -1,17 +1,17 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 -from loguru import logger +import os -import torch import pytest -import os -import ttnn +import torch +from loguru import logger -from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup +import ttnn -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull +from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @torch.no_grad() diff --git a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py similarity index 93% rename from models/experimental/mistral_24b/tests/test_pixtral_transformer.py rename to models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py index 0458847993b0..54e762e393f6 100644 --- a/models/experimental/mistral_24b/tests/test_pixtral_transformer.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py @@ -8,10 +8,8 @@ from loguru import logger import ttnn -from models.tt_transformers.tt.ccl import TT_CCL -from 
models.tt_transformers.tt.model_config import ModelArgs - from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.model_config import ModelArgs from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -29,11 +27,6 @@ ], indirect=True, ) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) def test_image_transformer_inference(batch, num_chunks, mesh_device): pcc_required = 0.99 @@ -58,10 +51,8 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device): all_tests_pass = True - tt_ccl = TT_CCL(mesh_device) tt_model = TtPixtralTransformer( mesh_device, - tt_ccl, state_dict, state_dict_prefix=first_layer_prefix, weight_cache_path=None, diff --git a/models/experimental/mistral_24b/tests/test_vision_attention.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py similarity index 93% rename from models/experimental/mistral_24b/tests/test_vision_attention.py rename to models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py index 8c216339b02d..7bb2a0052064 100644 --- a/models/experimental/mistral_24b/tests/test_vision_attention.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py @@ -8,12 +8,9 @@ from loguru import logger import ttnn -from models.tt_transformers.tt.ccl import TT_CCL +from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention from models.tt_transformers.tt.model_config import ModelArgs from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - -from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention - from ttnn import ConcatMeshToTensor @@ -36,11 +33,6 @@ "batch_size", (1,), ) -@pytest.mark.parametrize( - 
"device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) def test_vision_attention(mesh_device, seq_len, batch_size): logger.info(f"seq_len: {seq_len}, batch_size: {batch_size}") dtype = ttnn.bfloat16 @@ -61,10 +53,8 @@ def test_vision_attention(mesh_device, seq_len, batch_size): n_heads = model_args.vision_attn_n_heads head_dim = hidden_size // n_heads - tt_ccl = TT_CCL(mesh_device) tt_model = TtLlamaImageAttention( mesh_device, - tt_ccl, state_dict, state_dict_prefix=first_layer_prefix, weight_cache_path=model_args.weight_cache_path(dtype), diff --git a/models/experimental/mistral_24b/tests/test_vision_mlp.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py similarity index 100% rename from models/experimental/mistral_24b/tests/test_vision_mlp.py rename to models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py diff --git a/models/experimental/mistral_24b/tests/test_vision_rms.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py similarity index 94% rename from models/experimental/mistral_24b/tests/test_vision_rms.py rename to models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py index 186ddd67c1cb..0be74c271781 100644 --- a/models/experimental/mistral_24b/tests/test_vision_rms.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py @@ -1,19 +1,13 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger +import os -import torch import pytest -import os +import torch +from loguru import logger import ttnn -from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm from models.tt_transformers.tt.model_config import ModelArgs +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @torch.no_grad() diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 7dff033acdd8..b7d8023a4b51 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -5,9 +5,11 @@ import math import re from enum import Enum +from types import SimpleNamespace from typing import Optional import torch +from llama_models.llama3.api.datatypes import ImageMedia from loguru import logger from pydantic import AliasChoices, BaseModel, Field @@ -672,3 +674,46 @@ def create_tt_model( tt_kv_cache = [l.attention.layer_past for l in model.layers] if paged_attention_config else None return tt_model_args, model, tt_kv_cache, state_dict + + +def hf_multimodal_encode(messages, processor): + hf_messages = [] + + for msg in messages: + hf_content = [] + + for item in msg.content: + if isinstance(item, ImageMedia): + hf_content.append( + { + "type": "image", + "image": item.image, + } + ) + elif isinstance(item, str): + hf_content.append( + { + "type": "text", + "text": item, + } + ) + + hf_messages.append( + { + "role": msg.role, + "content": hf_content, + } + ) + + encoded = processor.apply_chat_template( + hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + return SimpleNamespace( + **encoded, + tokens=encoded["input_ids"].squeeze(0), + vision=SimpleNamespace( + images=encoded["pixel_values"], + 
mask=None, + ), + ) diff --git a/models/tt_transformers/tt/generator.py b/models/tt_transformers/tt/generator.py index 5433d11e3538..115c23ee3b9d 100644 --- a/models/tt_transformers/tt/generator.py +++ b/models/tt_transformers/tt/generator.py @@ -22,6 +22,7 @@ get_padded_prefill_len, num_blocks_in_seq, ) +from models.tt_transformers.tt.model_config import CheckpointType @dataclass(frozen=True) @@ -86,6 +87,7 @@ def prefill_forward_text( seq_len = int(prompt_lens[idx]) last_token_idx = seq_len - 1 prefill_seq_len = get_padded_prefill_len(seq_len) + local_kwargs = kwargs.copy() logger.info(f"Prefilling User {user_id + 1} up to {seq_len} tokens") @@ -101,6 +103,12 @@ def prefill_forward_text( ) model_kv_cache = kv_cache[model_id] if kv_cache is not None else None + # Check if 'pixel_values' exists and index it safely + if "pixel_values" in local_kwargs: + local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx] + if "image_sizes" in local_kwargs: + local_kwargs["image_sizes"] = local_kwargs["image_sizes"][idx] + logits = self.prefill_forward_single_user_text( prefill_ids, page_table=page_table_user, @@ -108,7 +116,7 @@ def prefill_forward_text( last_token_idx=last_token_idx, kv_cache=model_kv_cache, model_id=model_id, - **kwargs, + **local_kwargs, ) out_list.append(logits) @@ -493,6 +501,61 @@ def _prefill_forward_single_user( # Note: This function is called by vLLM def prefill_forward( + self, + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=None, + kv_cache=None, + cross_page_table=None, + empty_slots=None, + **kwargs, + ): + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + logits = self.prefill_forward_text( + tokens, + page_table=page_table, + kv_cache=kv_cache, + prompt_lens=prompt_lens, + pixel_values=vision_images, + **kwargs, + ) + + return logits, None, None, None, None + + else: + ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, 
+ decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) = self.prefill_forward_llama_vision( + vision_images, + vision_masks, + tokens, + xattn_caches, + total_lens, + prompt_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + empty_slots=empty_slots, + ) + + return ( + output_logits, + prefill_output_xattn_masks, + prefill_output_full_text_row_masked_out_masks, + decode_output_xattn_masks, + decode_output_full_text_row_masked_out_masks, + ) + + # Note: This function is called by vLLM + def prefill_forward_llama_vision( self, vision_images, vision_masks, @@ -589,7 +652,7 @@ def prefill_forward( ) # Note: This function is called by vLLM - def decode_forward( + def decode_forward_llama_vision( self, start_pos, tokens, @@ -653,6 +716,47 @@ def decode_forward( else: return tt_logits + def decode_forward( + self, + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches=None, + page_table=None, + kv_cache=None, + cross_page_table=None, + enable_trace=True, + read_from_device=True, + ): + pass + + if self.model_args[0].checkpoint_type == CheckpointType.HuggingFace: + return self.decode_forward_text( + tokens, + start_pos, + enable_trace=enable_trace, + page_table=page_table, + kv_cache=kv_cache, + ) + else: + return self.decode_forward_llama_vision( + start_pos, + tokens, + prefill_cross_attention_masks, + prefill_full_text_row_masked_out_mask, + decode_cross_attention_masks, + decode_full_text_row_masked_out_mask, + xattn_caches, + page_table, + kv_cache, + cross_page_table, + enable_trace, + read_from_device, + ) + # Note: This function is called by vLLM def read_decode_output(self, tt_out, async_read=False): """ diff --git a/models/experimental/mistral_24b/tt/model.py b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py similarity index 71% rename from 
models/experimental/mistral_24b/tt/model.py rename to models/tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py index 764c12bf3a1d..a4257bfe1bc7 100644 --- a/models/experimental/mistral_24b/tt/model.py +++ b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_e2e_model.py @@ -12,10 +12,8 @@ import ttnn -import torch - from models.tt_transformers.tt.model import Transformer -from ttnn import ConcatMeshToTensor +from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer class MistralTransformer(Transformer): @@ -39,17 +37,25 @@ def __init__( use_paged_kv_cache=use_paged_kv_cache, ) - def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): + self.vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + state_dict=state_dict, + state_dict_prefix="vision_tower.", + dtype=dtype, + model_args=args, + tt_ccl=self.tt_ccl, + ) + + def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_page_table=None, **kwargs): """ Inputs are torch tensors or python types. This function returns ttnn tensors on device. 
TODO: Debate whether this function is responsible for padding """ - tokens = tokens.reshape(1, 1, 1, -1) - S = tokens.shape[-1] + S = pt_tokens.shape[-1] tokens = ttnn.from_torch( - tokens, + pt_tokens.reshape(1, 1, 1, -1), device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, @@ -57,37 +63,22 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag ) tokens_embd = self.embd(tokens) - pixel_values = kwargs["processed_inputs"]["pixel_values"] - input_ids = kwargs["processed_inputs"]["input_ids"] - image_sizes = kwargs["processed_inputs"]["image_sizes"] + vision_output = self.compute_vision_token(**kwargs) - if pixel_values is not None: - vision_model = kwargs["vision_model"] - vision_output = vision_model(pixel_values, image_sizes) - vision_output_torch = ttnn.to_torch( - vision_output, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1) - )[:, : vision_output.shape[-1]] - tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ConcatMeshToTensor(self.mesh_device, dim=-1)) - sliced_token_embds = tokens_embd[: tokens_embd.shape[0]] - - image_features = vision_output_torch - - input_ids = torch.nn.functional.pad( - input_ids, (0, tokens_embd.shape[1] - input_ids.shape[1]), "constant", 0 - ) - special_image_mask = (input_ids == 10).unsqueeze(-1) + if vision_output is not None: + tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] + + image_features = comp_vision_output.squeeze(0) + special_image_mask = (pt_tokens == 10).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(tokens_embd) image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) - tokens_embd = ttnn.from_torch( + tokens_embd = 
self.args.prepare_residual_tensor_prefill( tokens_embd, - dtype=ttnn.bfloat16, - device=self.mesh_device, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, dims=(None, 2), mesh_shape=list(self.mesh_device.shape) - ), ) tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) @@ -132,3 +123,9 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag tt_chunk_page_table = None return tokens_embd, tt_rot_mats_prefill_global, tt_rot_mats_prefill_local, tt_page_table, tt_chunk_page_table + + def compute_vision_token(self, pixel_values, image_sizes): + if pixel_values is not None: + vision_output = self.vision_model(pixel_values, image_sizes) + return vision_output + return None diff --git a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py similarity index 93% rename from models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py rename to models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py index 7a244d83543f..d875011d2712 100644 --- a/models/experimental/mistral_24b/tt/pipeline/mistral_vision_tower.py +++ b/models/tt_transformers/tt/multimodal/mistral_24b/mistral_vision_tower.py @@ -9,13 +9,11 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.mistral_24b.tt.vision_conv2d import TtMistralConv2dPatch -from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm - from models.tt_transformers.tt.common import position_ids_in_meshgrid_tt -from models.experimental.mistral_24b.tt.vision_rope import VisionRotarySetup as RotarySetup - -from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.mistral_24b.vision_conv2d import TtMistralConv2dPatch +from 
models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup from ttnn import ConcatMeshToTensor diff --git a/models/experimental/mistral_24b/tt/rmsnorm.py b/models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py similarity index 100% rename from models/experimental/mistral_24b/tt/rmsnorm.py rename to models/tt_transformers/tt/multimodal/mistral_24b/rmsnorm.py diff --git a/models/experimental/mistral_24b/tt/vision_attention.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py similarity index 100% rename from models/experimental/mistral_24b/tt/vision_attention.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_attention.py diff --git a/models/experimental/mistral_24b/tt/vision_conv2d.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py similarity index 100% rename from models/experimental/mistral_24b/tt/vision_conv2d.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_conv2d.py diff --git a/models/experimental/mistral_24b/tt/vision_mlp.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py similarity index 100% rename from models/experimental/mistral_24b/tt/vision_mlp.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_mlp.py diff --git a/models/experimental/mistral_24b/tt/vision_mmp.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py similarity index 98% rename from models/experimental/mistral_24b/tt/vision_mmp.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py index 6e88dbf65680..2dd21e0d0177 100644 --- a/models/experimental/mistral_24b/tt/vision_mmp.py +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_mmp.py @@ -6,9 +6,10 @@ """ import torch -from models.common.lightweightmodule import LightweightModule -from models.experimental.mistral_24b.tt.rmsnorm 
import RMSNorm + import ttnn +from models.common.lightweightmodule import LightweightModule +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm from ttnn import ConcatMeshToTensor diff --git a/models/experimental/mistral_24b/tt/pipeline/vision_model.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py similarity index 87% rename from models/experimental/mistral_24b/tt/pipeline/vision_model.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py index ebc816a71279..f85cbf2e7ecc 100644 --- a/models/experimental/mistral_24b/tt/pipeline/vision_model.py +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_model.py @@ -10,9 +10,8 @@ import ttnn from models.common.lightweightmodule import LightweightModule - -from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower -from models.experimental.mistral_24b.tt.vision_mmp import TTMistral3MultiModalProjector +from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mmp import TTMistral3MultiModalProjector class TtMistralVisionTransformer(LightweightModule): diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py similarity index 89% rename from models/experimental/mistral_24b/tt/vision_pixtral_image_block.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py index 66a010a35af8..a564dc282ba6 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_image_block.py +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_image_block.py @@ -4,10 +4,11 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from models.experimental.mistral_24b.tt.rmsnorm import RMSNorm - -from models.experimental.mistral_24b.tt.vision_attention import 
TtMistralImageAttention as TtLlamaImageAttention -from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm +from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import ( + TtMistralImageAttention as TtLlamaImageAttention, +) +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP """ This file implements the pixtral image block specific for the Mistral-Small-3.1-24B-Instruct-2503 model. diff --git a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py similarity index 94% rename from models/experimental/mistral_24b/tt/vision_pixtral_transformer.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py index 7e45e9ff8573..d21e417875f0 100644 --- a/models/experimental/mistral_24b/tt/vision_pixtral_transformer.py +++ b/models/tt_transformers/tt/multimodal/mistral_24b/vision_pixtral_transformer.py @@ -10,7 +10,7 @@ from tqdm import tqdm from models.common.lightweightmodule import LightweightModule -from models.experimental.mistral_24b.tt.vision_pixtral_image_block import TtPixtralImageTransformerBlock +from models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_image_block import TtPixtralImageTransformerBlock class TtPixtralTransformer(LightweightModule): diff --git a/models/experimental/mistral_24b/tt/vision_rope.py b/models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py similarity index 100% rename from models/experimental/mistral_24b/tt/vision_rope.py rename to models/tt_transformers/tt/multimodal/mistral_24b/vision_rope.py From f26394f4265b6e2973385b67c762cc13c6660502 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Thu, 21 Aug 2025 10:11:01 +0000 Subject: [PATCH 12/18] Add unit tests to adopt tt_ccl --- .../mistral_24b/test_patch_rot_emb.py | 3 +- 
.../mistral_24b/test_pixtral_transformer.py | 10 +- .../mistral_24b/test_vision_attention.py | 12 ++- .../multimodal/mistral_24b/test_vision_mlp.py | 4 +- .../mistral_24b/test_vision_model.py | 95 +++++++++++++++++++ .../multimodal/mistral_24b/test_vision_rms.py | 2 +- .../mistral_24b/test_vision_tower.py | 73 ++++++++++++++ 7 files changed, 193 insertions(+), 6 deletions(-) create mode 100644 models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py create mode 100644 models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py index 43f527a464a2..e62c0c9751c9 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py @@ -8,9 +8,10 @@ from loguru import logger import ttnn +from models.tt_transformers.tt.model_config import ModelArgs +# models/tt_transformers/tt/common.py from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup -from models.tt_transformers.tt.model_config import ModelArgs from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py index 54e762e393f6..618e7122e2f0 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_pixtral_transformer.py @@ -8,8 +8,9 @@ from loguru import logger import ttnn -from models.experimental.mistral_24b.tt.vision_pixtral_transformer import TtPixtralTransformer +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs +from 
models.tt_transformers.tt.multimodal.mistral_24b.vision_pixtral_transformer import TtPixtralTransformer from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull @@ -27,6 +28,11 @@ ], indirect=True, ) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) def test_image_transformer_inference(batch, num_chunks, mesh_device): pcc_required = 0.99 @@ -51,8 +57,10 @@ def test_image_transformer_inference(batch, num_chunks, mesh_device): all_tests_pass = True + tt_ccl = TT_CCL(mesh_device) tt_model = TtPixtralTransformer( mesh_device, + tt_ccl, state_dict, state_dict_prefix=first_layer_prefix, weight_cache_path=None, diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py index 7bb2a0052064..4f4994704f64 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_attention.py @@ -8,8 +8,11 @@ from loguru import logger import ttnn -from models.experimental.mistral_24b.tt.vision_attention import TtMistralImageAttention as TtLlamaImageAttention +from models.tt_transformers.tt.ccl import TT_CCL from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_attention import ( + TtMistralImageAttention as TtLlamaImageAttention, +) from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull from ttnn import ConcatMeshToTensor @@ -33,6 +36,11 @@ "batch_size", (1,), ) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) def test_vision_attention(mesh_device, seq_len, batch_size): logger.info(f"seq_len: {seq_len}, batch_size: 
{batch_size}") dtype = ttnn.bfloat16 @@ -53,8 +61,10 @@ def test_vision_attention(mesh_device, seq_len, batch_size): n_heads = model_args.vision_attn_n_heads head_dim = hidden_size // n_heads + tt_ccl = TT_CCL(mesh_device) tt_model = TtLlamaImageAttention( mesh_device, + tt_ccl, state_dict, state_dict_prefix=first_layer_prefix, weight_cache_path=model_args.weight_cache_path(dtype), diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py index 849b95673058..ac17185e02a0 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_mlp.py @@ -9,9 +9,9 @@ from loguru import logger import ttnn - -from models.experimental.mistral_24b.tt.vision_mlp import MistralTTVisionMLP as MLP from models.tt_transformers.tt.model_config import ModelArgs + +from models.tt_transformers.tt.multimodal.mistral_24b.vision_mlp import MistralTTVisionMLP as MLP from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py new file mode 100644 index 000000000000..ff4d63a111b4 --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.vision_model import TtMistralVisionTransformer +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +def get_image_features(vision_tower, projector, input_tensor, image_sizes): + """ + Get image features from the vision tower and projector. + """ + vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state + image_features = projector(vision_token.squeeze(0), image_sizes) + return image_features + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mistral_vision_model(mesh_device, reset_seeds): + pcc_required = 0.97 + dtype = ttnn.bfloat8_b + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + + mmp_first_layer_prefix = "multi_modal_projector." 
+ + mmp_partial_state_dict = { + k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix)) + } + + reference_mmp = model_args.reference_vision_multi_modal() + reference_mmp.load_state_dict(mmp_partial_state_dict) + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) + + # ##### TT Model: TtMistralVisionTransformer ##### + tt_ccl = TT_CCL(mesh_device=mesh_device) + vision_model = TtMistralVisionTransformer( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + model_args=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) # [0] + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, : tt_output.shape[-1] + ] + + non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True) + tt_output = tt_output[non_zero_indices] + reference_output = reference_output[non_zero_indices] + + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. 
{pcc_message}" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py index 0be74c271781..ccd092631ad1 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py @@ -5,8 +5,8 @@ from loguru import logger import ttnn -from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.rmsnorm import RMSNorm from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py new file mode 100644 index 000000000000..03060b499b1a --- /dev/null +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from loguru import logger + +import ttnn +from models.tt_transformers.tt.ccl import TT_CCL +from models.tt_transformers.tt.model_config import ModelArgs +from models.tt_transformers.tt.multimodal.mistral_24b.mistral_vision_tower import MistralVisionTower +from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull + + +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "device_params", + [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], + indirect=True, +) +def test_mistral_vision_tower(mesh_device, reset_seeds): + pcc_required = 0.99 + dtype = ttnn.bfloat16 + + model_args = ModelArgs(mesh_device) + state_dict = model_args.load_state_dict() + + first_layer_prefix = "vision_tower." 
+ partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) + } + + B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size + input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) + + ##### Reference model output (Torch) ##### + reference_model = model_args.reference_vision_model() + reference_model.load_state_dict(partial_state_dict) + reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) + + reference_output = reference_output.last_hidden_state + tt_ccl = TT_CCL(mesh_device) + ##### TT Model: MistralVisionTower ##### + vision_model = MistralVisionTower( + mesh_device=mesh_device, + tt_ccl=tt_ccl, + state_dict=state_dict, + state_dict_prefix=first_layer_prefix, + dtype=dtype, + configuration=model_args, + ) + + tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) + tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ + :, :, :, : tt_output.shape[-1] + ] + tt_output = tt_output.squeeze(0) + passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) + + logger.info(comp_allclose(reference_output, tt_output)) + logger.info(f"PCC: {pcc_message}") + assert passing, f"PCC below {pcc_required}. 
{pcc_message}" From 12210acf7855c2ee8f9bfb616a221ee808738adc Mon Sep 17 00:00:00 2001 From: nikileshx Date: Thu, 21 Aug 2025 10:44:35 +0000 Subject: [PATCH 13/18] delete samples --- .../pixtral_transformer_inputs/demo_small.jpg | Bin 8554 -> 0 bytes .../pixtral_transformer_inputs/people.jpg | Bin 49606 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 real_inputs/pixtral_transformer_inputs/demo_small.jpg delete mode 100644 real_inputs/pixtral_transformer_inputs/people.jpg diff --git a/real_inputs/pixtral_transformer_inputs/demo_small.jpg b/real_inputs/pixtral_transformer_inputs/demo_small.jpg deleted file mode 100644 index f51ba21be8d4cbeb5faceca2b88150b50926abf2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8554 zcmb7pWl$Z>*6!f$?zVA)hTv|&-Q7Zvjk^beLvVL@*WfO}0&E~ig1ZC=5S(wn=hUrp zeth><-PPS|o>}Xe)lywO-PNx@UN-@31zCAn00ssIcn?j$>o&{}c_}GVH4Rl+c_kUB z0{~#g0jR)p0sz?2-Bm+Ql3GvSfEsBNfPtoeiMgel^S{yms-djgrGIl5-u$nO|F;pv z%G%8mO1Xd*kSnxuXq)h%7~kgKnCTyE@oy~r5B73*c87A*{=u$V8d6Yf1;tD@|Aj67 z3tKw7{=<)ja)iMS9{<$!kNnd+bZbX#P3RLDTF3!6Km(8iB>(9@^c^ZrMF7Bm0|0Od z|FKzQ0zi8h01&VK#|A0@0Ibgd&^G%Y+keKy+1%CqU*X`P8ur5n0Jtg#01SNq!21pW zsD}UYLDm1IZPZW~Ika6)&}0KR0M-CCAP+bKmH;~xaRZzH7x3w=8W{VB1qk$_Ip z!gW9@WbT3x#3QY-W!dyL=@J@(35 zbI{GsOm`$M(iwA1DYg=?Z2aIdpS08q>17c{MX6P}$DjOBePhr-rfMhAJZsm>l0o2n zw1=sCQ9`VcA0>LJqiKDnEN=&_YhR}HS8wAxDLs&`G*hOs-6l>X=u&aX$EzT(=!mh+ zxoGtl^_J#Ak}c0PgLGI|)2KC%*wRcO11zv3hJH?V*0u4adZ5O_lIAfx zvqVe;#HF$MY)-Ao)x_0IsN2jJb_Dx7mBMne2G~)~?qh$Tmc*fRL3hr)0gRjvJn3kt z9h;rGcol*O)1<8rF<72+rxlEs65Xw{DkMGAXnV6}7;l2AdDg0;^JbXqY5 zzjY$3uNYzQ6CerlgaN>7*L7^uZf2c@aNw+F3GZ=m(pTNljRkB|V*m_fF$yzM$xe86r=kcUB!$FeQLg1d{Lak}CDprK zB@WK@<#|d@j1kz`C}oWGciZZs#D>)K+n9d$==s1DvFIRMO7xpN0ib>-PBL+{P*Kx} zzBkZ;+$R%j9aVG}QPo8u4A-_;Y%%nGImMstLC@uY@~*PzF!^JaLz|Z-GCTN`)_Jb} zQh_A5fv9q37~N~ykLJx4BE&;YQFe0$?3O3*tW2S(nAsF!FML~OGA2?rdK*f?oZeDn z&N9wb1?hAaK|o^9le23TDHBq=GKg3BmF7tL>FEPA#EM#GeO&XzSJbbqOfvPr~47iBDF>zE$#B|g^4 
zDQwEAceRs+`roiv>s3aJ$B@^@UIFD@Q_GIB4Qa^PO#geNM#PRlM%vCgVrlsyCL%;L z>|3E|f(~VpCLFQKp!sZ1h}m&doroM~=}`TK!HwER`mSw=WbkM@6)LN{1K+oN6GAJv zEqLGPk)6tI=(zOVyU46w?zPMiD!So_hS!rVB9<*g=%$lrh<23pA&;?KgC1-3PFqjN z8e57*!^^<)C6_QbPEXRsw36^DA4jbrV{y5ShTuvCd>FVkL z)5&MFNq;oLbO2CBoqZxo{fbAvaJO6x4V^hfotxv0l@H7`98g-t4&b=`Ahd?RC8^JN zsE@q3G>w*ux(gqrFjDkRHX1zip%)k5-9yKM{wTYD0`R#fU+?VF6w*5@vU$Iv-*~;a z?GtgVX|^#$0@E92U43&q5<9(=(0NKT$FC~+H9(Rx{ExavYn1m z5sXrod@`9zJ;UQ1MpT1!`H44+;T6rb9b>E3-aDWm6+7LbJ`!P}OPTq#h?~g|J)?-( z&cZRDPAh)wk5yUuEla%u5!*sn0oyIwH}(GK_mR$BwF7^j-sI8?!AphFk0~0iNYVcl z3~l6X_Q$$_q&&dexpFgke#aCgq-~{DYeRg?xK!-2fWM%KXXB!(STq%U)r!K9$$ewA zGoJxMM#Ht@$~kAUr85He*!iLOhMB7;uNjZZYdHI98ax)Tsj3o-R0(eBxj&s?jBJw& zTg4klsmwM~#qkTk*v03mZc%G_3;`-50V4`V=A)=Kd*bqrSU()7WXoHRWUAQp(91s{ zt$t2dHSaxo(5gugY%^hdEM+EnKj8E=%nF4q#h`$4y1}*ZMT;*%{USSvRlD1;8a)^T z%f*~5Cl#2B(l#@RGbBN}vpL64Z!=#s{Ebg$pVjR&O)xL$rN8sCgMZVGXD2h3jnOi> zUX1fxJ&hps*X}IPI9mW*tNmvxINn7`zIb&!kZD15`%56`Gv9CbBdlt%mM|_5yIMm2Ml4g%p;)z!v@P_<}?&C82Jt$tL+RoILk4%n}<0waSLq}d0 z=PLYEqUb)GW|Kf!n9+`m##q$;CFO|}hJC;XHY+v-$4y$^a(*`Z)gh zs5_3KU${vpQML7@ZA}o$ErV57D~mWcW`5ugMz~KP{rV0@LrH0N3m0*Jvx7VcK;l*y zsH}*NNA}_K4E`W56CmpSu%>+Bia~BTMseF1L<+wzG(z>8qHzPAj%uF(>DuFg`GTQz z3p3MeEDJxK^t^o5Yk$OrO%xdW$eAlxbXTVHo zS}{6J-x=(-mmU;ZDWa{7WNuPn_&!P-%kI~(EBI=)-R-u-{cUVfVMY=PyK~2E!qJ04 zho8ueHfB{-EQ}ePa(w1-H_>&?1BbJ2s4$BC$U0T!ag;YNQ*B(!0;1j!Yr9~cT7O9B z*AtVSO;a(o&1u+UdGxE<^0lcep70NENRw21Zw7mFy+|3dHv@jbp#=>5l>oSZ&mnM+q2jmD0?#R|+*( zL(*2ceNmf~DTB@+uYi2PhiFbwRvXIeIYIT$RZnT^R9lFl*(yPjBJ zYfOL_GQ)(GQY=_t14!v|(6u~9BfGY1aqp$>Ca74+=;@K)FcwdJpqdYEkA z&T3B^erW0!3cY#@N|Rf-E^jNYg4R%fIx2NxEz+eU4`yU1wprJXtC=l)V$K9(3TjFp zDoL{*)1SZd&VTqI^?YPve7G^`_6*PK^V{e_-I)<<7R0ELtj1NrV2B*sm^QnD(ZJYe z|G6JFvTt{g3^5|+5?j2zrG39*sb5~NQVy6xxorDNFz}BpJ zbXY`}#w&z3vOdq&eFDr07zB9~KwzH{ZV zyDfc9O}OhMbv`SS%_Puyj+zC_iMO}jhc}Z_cgzqB+ALu z9uwY8%}og*TdZL zCI9n0l9)SDBstPxFU`^5WTg9=vLycy@&^qKJwY|+5TZVr*EhGgHOB@NR@e?#Srz# z4L0sKq2$alq!QsinVoj08&SYSuL;P2Q|iO(*9!;^a2pR2 
zyQz})*DabV)w$87Ae0;XTiaGxK27PriGNO3z#Bibk7`SA8sxNol?OTI-Bl;U+2r^pT88CYq`^Qm-NOkSS!F+M7`xu>i(y~E2;0BsQ&H;csePClB!{!@?p_tgNzLibWUAv8h{WaG|Hk_Y?23=59y}jv^(^XHXn)Q>;(wZUIY)kw zyO_nLV&8rUB2N=j(Q9Cvp^jjf}0Hs-`x7)|)*rYL%f^ z#u-rSUs&4tzK@k?%qyI76E~)EDplIMLn}O()zp=WeSg_*IgQ61YZF)yof4thPM4}9 zPxcB>@uED1X`6^Bap!rb@JW#+h&_Up0~FPP`1*UZ8?4LFJJkLy7oK6!6COg=12@@JfsC|< zh`nPxF~;&oQeS211u|b>XBv`|(`p;qPu}lOnuUzT;>cs`5R9ex5<8m|75x0|ydRdY zK&yDqevKTnqV5_MnsktY{mE)9|H0>qd$5ufH~Q++Q&0(fvmD=P6vSVsIZ;uHJ)7fN zX^wAGf+|;Ha_eRa#M_-DD;P>%wHa_vz|a{6-Cea~yckG7Tp3Cn%aA)r`8yM47fz!&K1Lqg2Y3!nSXVZB`k2bub>1XbhZ0ac3=QVLt<$Dk{RRpu?{)W6X z8T6F8i8f~n$C`9uoHtRlXqB0QFf>Rv(-OZvrsH(r2J?L1g}W!k#pH}vq1U-ur6o(L z%+(4Fc3Epg`i(C$i!EOz*(`6500GZilgv&;7)a^s4Q)&@Ql)sGTLox=_mgo4hnuM! z1I;DgLq|sSy@+F0xz(cC&k^vOmUci8{9BKqBFBoQJ&;zo>abX zf6x{`@{o>x>|$!Go4mSC!|2E-Ak-U@9_1J>XYRpgQhTSwup)TLjgpjJ*gF!yz*k> z5f7Wnp9aR2e<(keJK9?9GiG*i9WBNBv!4n~{<@gZy@%tnrZf?#`b57>LQJ-uBUyav zhE9H7u!Ih{7E^ZA-a8-z2O=WXbL<^4Z2R<~aFg*SAt`emLtyGLjkT)N9 zr?pJ~s?=d}J4-wrCQe-@&2!L@S=O?@U!fknKt9eSkB{9pr@l6>mbkQu_Nb337f-Hk^fotY^vS z;j)gSJ1|P9>sigXNG?EZK}MJD^t{Ah7R}?yiA9}p9F@m*ki4%I`*2>0(bBekd$yVG zLra8?Gy4a9^ZAC%y4udy;l(2ia*Mn9#ZUL+8j@qNFY|c>OX>nI3Pn1Iz8RYtNj>@( zrC+vUweo{J9M)_pC!$^qe37ZMqpNaTYt1y-5JS1|oY-pIIeE4J06#Bp>0Ri^DnU(y zzWv{Z!IN)lef)6uKF!|(mJth8v5#QT$U-_q4Bov03ST;RD7ZkVQM`kBHPx|!{1p~o zB*LcPy;&Wa8n*4nMd?+J{SAkGA+kbejxWdD63lOg2pEZva&8DiSAZhukHl4Wjx zZs&>unUy^3@Xrigloc$t+=KQWhTt~-I`mak+S|H^vCRJ(p3K=)DVXCAt&oHvrKVoc zZkj?AO!Ctbc?C#}j&zGeaQ6*0Q6@Uu5{hj);3*bc2C9Kl`goGO zXx{gaatit`N;nKjpT?pv79Pv!s8H;JhcN~dw%K_W?{dUDO( zv@~W+u*q_x&w^b(tvU5(C*kTl%OZ$O6SUaj?YQN+^+LoEW8SiYI@VK@I(_`h=3uK+gO*7%`FWilFWVFs+{`of$%NG$~;a^|Vre}PRk z!#ILVf8Jk{@p`>y{mpmN*^on1R6;&h7ZFKJxtRZb;c zkbx+J;_lVTV5dU~v`yCq)bNsfjG6-0WU=Qc#9|z05~Ty9XnC4P8Vtj;lw*HwAzr3` zUsB-PoQOF>W4C(Z^r&N4_!NDn60SJMXJqYkKl4`|eD>%vmR%C^+5I%+qxG`R z3ch-t!>o~Ibi1_GFPYA_>9;eW=#Oh2V(Y$NK7&pAFN%*uiZ3yvceK8|5P<8?jSD(3#T}44`BQ85l(r#8`G1|En$*h5sFzBh_a4#PERcUKCfpl 
z*Jo$&Ow7;lS5lFGn&xENyjqO9Gc+3x<}n$#p@ovC@@@CDS;!;uayZh*zck<`!Tt_ClgVxTZB_93_dgFV&#NIJ7_oG;G+dzts zM#`SIS(INZMp(=`@>MYfu5&S!R0RG zy~+_7dd|-suHW8d>C;nV#!d}W_N3_CrH5%s`v_0VJJuzM*OD??V2#X9NnDThXf2^{ zAnc#2{tQ(uFOL@eBQFZdcTE(Y7wghsSF-_?c4y3bYJE%j8U zAaq%3;vY{%2JPa?wK{;4273}UxMEqbk@$QaWI?0v>A$={Ah~n`px?>)gSi&Jw9GMK zHt6c?48EPmKrkA136FimN9BL^yaF4gu%74kuB;&cAI3{W<^iD{2;^gfmh!?0 zD9|Dn8_8gS+wEM)efq6HNC=UW1~ai+AIFBSHGA|QICI%CDtr88o`#hRq*CtTV2K}3 zM@>@)=Bf9K!Jj+FW1OuVW$u%uOK@OE)ka3w$L+m@+fLERz?Yc&pXF>zsFT0`o{!+k zkSUWPDwnb6nIPZldI`i>OA`dIy|HwHMCtG|em>=bb|1 z`0c%##WW%Qs(9>U*xu1AP^_4=cKn6$vm0Z_4D|kmmsHwv;AhXn9CoK+rJS6GZ|F;^ zE=#{bOQ*+*R#6Z!Bc0f4zM(CJ{}iSyfGZ-Z#W7f#N8Qh%rg~en3F&Tym1Y6%-jg7$ zy=8c0o%A4cb8sj}h#1HTJyHSXndp}|P(8`uxled3(>Z9lv4&eij~&dx+sN)RW8Y2< zg}-_3-4=nL8vGlEMFYvN>E?q!``Cw|9zZn9w@?)weW?_R-f7i#iV@oz2+CoEw+58U z;SuW9)|JTUh$sX*qg2PEcsKi%+NWJz(j-S-7yZak8L1;s)%Y?dK#`KC*{8Js^{p<< z3lZGZn!pRhP(>U_k~m7%OInbl9}4g7t!5F7qN8qI_4VZ_dP2-03Mh;75>H~Tt;bP& zUk8VHtUijOL$Xe;Ej5-;;?Vu(`_7l8)vX(_B#BJU;YR1D6qu>u5EO)ztDV3*f}KFT zou6(Y90n(ni;nJJf4DY&njweF;l4KigP66Q_imYgG4v#B@mJL$v1ED@;nl>bkE=7! 
zryO4_h25PyZ`4Azcy=3FH)EkL?<_ZRFou^hzAm;)?g9Qu^D$|R+?z83 zk%>IRo)iLkZQHLFnTqNDfdT6HNrLjL#T83X9Ve^j*`A7!na3z|$vQRcM0L?G8wRd+ zaZY9*O)R#s48#@L!BlTNUwx;GoPyAPuavUCR9Ln2-fNJaG4?Edo_T_6>X2ti{rR7ZU;#g{ zQh<4yiCG6+yOGA~k&@YcJ_$d+*zf1)G!`QUP~0x`>Nv;hVQFW@ztuD*8;%bDsRIl5 z0GV$@1dOx63^C7lXT-BFc9GIE;M*Vk>zE}mJ>_*-c83S5CGKJing zrluss)-+f!q=ZON9A+O_+AxEXK*YbAnj)VSuu3strrT5w0uHnBvYB~$N1K`|4ht5+ zZJB-Lk6B8gWPeAjtz`)7JX@SjULfD{?=7!tYKkN^K3MeX!-po$M_KqD+(jI+va&>H zmlHYj^g`Hb{F}d?>FaCps`K!VblZGAxm(b2kdFU)q_3|crT}vTFBOb!X6Y*~_6Pi1 zUtj#KlF*y=parMIGM~t{82GiyMM(eT)UljJz+J8EhwUV&0P(Sp9$ GmHz>irf1Ip diff --git a/real_inputs/pixtral_transformer_inputs/people.jpg b/real_inputs/pixtral_transformer_inputs/people.jpg deleted file mode 100644 index 16dad8dcbf18374fbb920fa4ad944e7e9aef8b89..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 49606 zcmdpdWmp{B)+P{w1%g9xch}$q4>T@KH*QVSH0~ZD2@b&>0t9c|JrLa8Aq4l}1Omx) za?U;9z2AKEW1jgp`+=@qYp=bm-o2`7Rn@QgUu!4?YRan0D5%KBjmnLJ^6M)qtCA1Q z76nCBl@kRE1qI~+%41YClx1X35cyAu`Va*bxqXcMq58W|_M3og+RknsZk8|)dS_dD zcTa0;h`YP3rvp9I&DqoCH_gh^9p;22(z`&MT^u0vmLBwu&NdJ?C*+_l%o;-P;Oq>C zIYEC9SX(+m+$?$N|6Qh?rHhO2KT;eaPV|;eHuUb69x!)X7{unE5pU<;W9}Zv(Ld7M z=`HC!AeN52^ctQH9x#Z9ofFLZzs&mE8h1yS!yk5}%>UgGZlsUD-Tc?~*H;V;I}Z;R z2|hk2cV0^yXDbM=wX-9ikEIKr053luij=I6i={Oj;z4f(fx?`m8Nao3Fw(ct31t5_ znU7Hh2V2U`##TaCQTcBk~^ndGBfp~cMsybRiA?_-w z%KHCEK}zL?{vLz<8T%7(D9qO5KlIyJ|HJR%>E`f9tBo}u!~udtg@-$mR{*K%ujsls zc)I_?jr7ga9rBy>XO4r7Egw?h-?IOn@Y@e1A7qIB4)_1)7m|Tbke^?ShhLCKQ2ZZd zNMX7#sGY}uX#C%iT! 
zq5^^<0%F3y{rp?0gdxOA&%@H=PZa*L-yI3w-*|z!dpKLe|ET*DQU?gs(t*zbxgwq1 zAv~5)H^`q!vYsA}=8i~9JRPNx__9PUB34UB7s=mbcV|yGYlt+|*%|5p`G<;J#4eUF zsFU>XFd>fx(kEneB@`_^EHz+G$i?_pt zoNe7Ko&2QuenTBe`K$AXhwtBKBKtP~wJ`n?Am`xmFG;_lAmL#7+j<`!8;GqXGN~}i z{ClGM|1J6bw}}0T#qWsydz5gb{v2ENKdAX1;{PYVxLCTmL-c)JAku8QkiRg)Ci9;Y z4RqCIelM>76FC3C^dE3S9sdjLJpb(^!v8oh4Hd0cWXT!GDlIF>%PR*oO5m^*|sq zYKp2zlFNGi$ubCzqg~AD!?E5IMxO@3rT5E`Qndo;7seQdN~t2HhSJAjF~{MPl`O~g zA=ncd3b-q?9Nh~0!r3A(e!^ZtA~JY4{r z;2{q5tb%d*%k=>(0Fml`a9<2_MHtfHoQS+Gdo+BES;MB&LQ6%~2x~Ijk>0e#D0zB^ z!4f&4LKdsf=+#{q2GN7HO?S8Ibq^s9lY~jCm>k^gHB5qo-! zr0YuZz7Xze)%C*H58rV-FH`k^76|r7g>y0p!=F@QXi3|Zk`eVd?1WN`*5ymT!-ho8 z4^Al)j>h$F8_hG;!?1UB{klEV_8)gXH2v&1z(fAlpp@P9rk#>izudvZA-*sCJ@0cR zr(71JAjBTe@-^7<$w1xa!02r6 zTJGa9%C7J8heNAenH|C!s%jK|eU7%7Z@42g$T}A~cNCz{yd`>PEd&Z#W|Mr1m7X)3 zn$=^|j28LPbLSMZ%bSfE#$EiZ_E5H~7qHkG-Kj@d4b3yjO2k=Qz;TI&umB0=ChOK# zqaX4g_iGeA7A?=^YUmrk$}{RLw;irt5qDnTFi-;`$}-42l|2ztX8uombN3hw^y?CI zdUbv5;Tn!a=DMHU)$$c*84Jrwn2Okg>b!$P2jTculUwJjnIY9z-0n?t@3y8ZXP>Yn z39-oAeWi-JT~;I)@`Ot%H2s5U2e>S$yo+KI807)l_OXL?on&m@QeO6d!V^ zz3E4>51F=voJZ+HbM=hft6XG0}7o2pks6$N{uV6$0iXtT=(Ypb(L72>LggOcnd zz}amWAo6~i$My3VU_{7$+&P`26JSiEfzv;-lDJ3!ASZeL@IgY2ve6(s1z^C43!7M_ zP5ykq?M&9R(=hCf0=GAHLTXO;#{?!XOy?X{^|F7zUKp)zh+NHxu&wiSSTr+Cw%n*s ztB(AZalpFQ7d(=CK?nlS<_yb<_U1OV821p&vbU}(rcY2+&v75bNEUB8^%XysPU};t zi`(}P=(7sF<;dJP-&M2fj2x(Gx0xHFSkAVdv8=f za1rf-Crw0bUHi|EW_<54S2bE)rG4hWJej;sOH$c<&#K0oWpUsVv*PV08Bmi8!EWeI z9b|k=#XDp9z&xN^+&%G$_>^Fa<~)FRa-G{dZPCDz@?yxoHa=}#O zuZndbp*jgqmmGdJ_iAe;5EA*#n})ynY~meJ+GbSqO*4(}^19E}#KsqOYw*}AW_uu6 zRTpn|_Dor(+2@O$bf~~oGGZaRElVXiPa0+l^vIXSXzD^}ob2Xv#Ag;YgZ9f$B*shf zQf6%{lRPO$MyN0J0_t50cuA8NN5s7@7M?~)#x20f6CKaTwkEibJw?SYy#r+RCJ>9&HZK9|dQ$$U7R$B1uBNOGVFui`oo_T+R#!KzU)K*gUIM@OW6(q*? 
zc5tMl5iS>P%3mY^RK1q<92 zUbr-WHtpKRa+1Brc4WDrw8rCBaoWHOU+=QCdvTVr(w3b!5y18vXLy!P>ouPe>(b18 zVCWyR-}#GzxLx(4(k%#1d|LPNY8QT-5yCVXAA%@T zqXRpGnQ=@C1Zz4($|vGg0@h7%oerBv3%;-RjPhPeI@-+wRL8|_J+{;R7N%&#r>J(s zZncVw^>@m+hTD9tohfcszTW~`?dRKJt~m$PF;3EuOuSy<)_RGSv6>^hgA6&5Ec1$N zk^|TF7X|D>P@SjoI~QqNBap@l^__tswfd^OXawi&&J?RzjotW&VREihF=ChW`Oc*0 z&?CyT9x5Y$d0VQdnT^vsZ9HeQb5#348DD=xu%Dfwgvo6~yTTXdg3FqMX7~69BL{5C z@2U&&4ldTrGQRI0PGz3o-j~_Ixz(q#;`hcfRz;osR;)1ujIsh3MlZ%^?Yz#nmt7|> zz^nLXZoW@>W_!=S`{Ch&v*&mZVmNqped6V1-a@4_pYouKUx|8|P9De%Gl?gQ-`#RR zOTB5#58{Wdw%3-wov$9>P`uu^K6cM&Q*%y#R1qDyqw06}ZcfD4W2QIlU=YqWvfu6} zF~AGhJzo)T@@s=;?948=QA!LyJC~cvQogzJ?~G^Ef$a}(?7G3B9f<|LVD04Ua(r;L z*~bH5YYSUn8v@@aq0l){=(J*irY-82ErAyrAD`wknbn^ozR8va zz<&PpkaX4Kkn+=uSxAhsflGUox?b=_M=i(lwHlku3RSfR90l&^(Fo%)p>oU1&MBOE7c(~M z!q^Fl8{vbv9_c&?cRyqH3i?Q%O`TP1N5qcm>4mx)TGmllgOsy6Q19`+2%?fD%EsEO zfR#EC60wkaR;+`uVPTsvRX|WN2WLZ zER;tg$xRMn*hV>n!o};ITs=cQ&)Ib>v3z*EslBjvtX^N?3s~$2LY~W1W`7I{mIUTF`amd@x%Q7 z@SVK0?Ew;Y1>Z0a!YDutFYp@&OO0Qt7YidO{^hKbNu7xqRgJDGMEG?BQ9}F-TG}YE z#$@-(0cy%7^GbljIK-Q$3HsY@n*S}p()7on=yi0S}-xQbVL4kZe zOKE?d^?^vjmEF*++d@*lsu1_)tyQ9`L{Fp0mTQTU3g;^YeL^%|>b7TMAOl=K;Avr% zn$czgp`DZNyoM4&vnPLUBxq>n{oWWHrgn-J3jGW|7p%{e{WF{ztAe zWDdUU7S6;R#ZCS)#?{NG@|xroB8RnY0+znM%LiRePxCQ?$3M$|B&W?GAyzyG7Y}{d zsza0|H=W(GS0_z3yJmt<>ZB%~?+{;RW=M6?2(66~;37FB_4&P{cDrYd& z3l=jFz0!n(>rCv9eLaGE*2((VrYI-}to@e4Y&W6Wd=E<~N#wC#5^1}Ia&v~-*4v*? zEwBL&SSmTE(aREe0E13}UW8@mB1!VIB42n|ys?WXM?6w>)FbA`&T(l=o{vP7hz=k=Q#a4o5@7;`4mXaUKr^XuV!6o=k8{oGi@GaF1_s6Le$6*wHFDq;O`G&ET%U{*efPZ3WS zWvJ&?=7XmzE~1b7iBYS2NI7V#b!V#QLnY@)najPnddr|56D3hNIi3bO^6przn27T| zgV;xrv}uCc;(h%%3>;j>ef^*jJe%BD;LU)BhP~y2NV8|Wa$JvAM!)=}MGj{VdiAjf zY3n{yEJEH86P^JuNn8Cv6*mr6vUwo(IODiEm}p0>6-_6~F!&vF;;E>dZk29aG$ZAp z{e&?Yh7V)PiTt~_qXpF{;)Tvk*^b;7q$$#^2hLPJdmW!AV5)Atr=3^l^Gyt|rxQt{ zhYKp$=+C50jjD@>ldJMtN4L3&bifZ|vJOP2Cxp4!NWpN(&$V*4di0Z?^x;yx-C25{{V4vB5YA)QFE+P^Ol&<+AW;rxWQAi!> zV`9-Bd?GuyV;e(XK)^QyY&qOiV8M4xwk@1n8suVeHd&`zZT&+Eth4@A{!Iy5bT?T! 
zL4j3lzETx2LBhUa&gjRuv)Ph3x5ZJsm{Q*o&gU>)$|7zPyEs*5!e>AfG9CzzFU7Kz zLDN$1rTeU;m7dmszpzNUnvl)T-OnEt?+|?lJAAa?MzkGxdSy!o^q81*UMGKeX5Bgg z_gz5aOV2~BR&7X2dEq(P4m{i{llR#<-1bannNBkqVb)?5u1iY#Do@L8R7ygmWOXz1 z+TO=nXL$VdFwpH`tbW2QkuetwV93)`$aRrTA;xR%_y$A&#;fx9rma1|H0h0!WtxDS zUF>YRSH{JwSH00W@Qi*hx#f|)z8LbPI_wSJ#v_Xk(UgX;+yilrkL2NDd;$m# zs6ys?=YH-rQ{Qv!zM3MAT;4ba;p`XErEqNUwCl;r zQ`=0p)v{!#r(jw$RyjHCs<_k}gd7*JVz`cSC&rKp9 z*y7YbT?Y^=DvR)9%fRL`OL`EsHLYD5K(5b~@7{Uh3XGLk(8(cK;tk^>2uHuA-QRt8 zg-bvtMq`xGai<2PIgK9sbQ5DWM4H;r2y9;ZyzYx2rH}6mV<*m5I62JkZJkpf(yn45 z)9XvxHR4HD&@UYFvs#QZmWYr>ERKQXeJ-hUs0)PJcxjU$g1VFQL!$XkKAIU%R*f}! zT$MNuo8$|k)$Hise|rBqj9o3?*7hb|_hTaSsa2KkjP(o$IFP+{<>NxyuGQ9aX386> zdO%1n7!^?+&{xh*)TJhwZh$SXbK~*)?XGT3Uwwfj*uygp%%5rP@wIHMo=+1@<3rY) z?=IoM;bu|#tn(p4(cbP2DZ9Jzjzm^Xj}^8wUubprOjt=zO}>vuYm1O6HR$x=vmqmT z{YE$*w4^V7d!-3eqRxck`6W46w7QSB2FRn$lPECdp6!)iLJl^rAkcuF9qg}5-r!9U z_`I%;q2jF(*^J{ByfnGx<~8E|D1L8YJc2lV{;|GjTo z#{E(1nTQ!NO&+GgzA0-ta4hqB1MOz~v}6d*R}7Ear%RTfz^7%`j@vqndwzafyBO4z zl_B`Ug|%k0h}%g>wRJfyuAo4qAv1}d!r=6+dv@#M&;e7@M;3wNBs6HlY@HQ56BbcDlcqA-d5W`ui44`9!cgXREB#XI;0~KZ}JNTS^4HBKtV%8eeeJs z4Gmcy|6Tt^zDV$Z5S@UCm_gtnJ)@wI90`-KJSj8(Gh{Iw3t1vZeTZiG>eZLp;Vq%x ztyzqy7xNbi>7Skd`bpRl*VOy-lk{!N;++)eX?M3PI)T%~9X_jUX7o%Tt1IbFYE0V6Eos;^`>+x~g#sSs zrUEUioE4tsei`K^a4rhM;o*IuwQ+q|x>wKkgrB&mq|{CLType8Q;3*2OdtUliW#3tblKf%s(xSB(5ETVm&s~^_aI%06x@8oK znKuAE8I4m1-CzCqA@)N^T(dalJP}uE;pY%~O;>28BQ%=$&7)BI;6Rnhj;d4F)YLb_ zM_L<8dfxW64&!#GRLX}hHYn-TVLkgbD%K)C!O*j~Z+-PCwj;i$dY3`(AeKd`usc!`++lTqTl;At1!p30} zLLVk-H&Imu2#P&e{4^wZ@O)N|*I&ePI`sVZsP!vHkdDRT#@nBYKNSl`wdF8wJ4&u1 z7qYw2GYJHWA+|?tbm3+Yrp7SL8oNI^wwGCS2TJ1e^N)N*uA|MqKa|0D&Pix$Vz9WJ*)9;@ zj0*?~8J>(>DAJi6CChwjz$IT$uRzoFZbrdLOW$HaDpk+Lv(AMZJgoU4G6Z4Q!6o$9t-S?6^z2=tsqomo_4F3$F7MF z9Uv7`Bvd?R_$HC06+1_oY>9g8OP{GpoB{!ye9{zp^Oftk9|+VQ`yx!7Ry)i4xX4?M zGorra>+=TPtJwZe-D-#SBMvl*bLQhyd+>duDs`s!(=!uoLC>gTG#lb#6B{+v8140G z9TLkctICI8`ol{fdE}k)e=izlqYNMRM)fP^tO6Fa;8w)$W8R{F4>LyeFUYbqJM;`1 zPXP_hQb(;iTWIl&qT7fTj0&xLclCKm 
zzmPSZdDzzyeIW%{FEF0BkLLOJ!h{p_%SU1}&qxhTFXI@83^@vap)ibcA>11i5{#_b zdNN6i>nP#7T?M9rLU+YhY2{#R6uVj%FV4$(EE#FlcdBX5T>&OSVqe9rETj;XsOe+| zgFz&f=XDH5JKf#uGvUeE>iQiu>}U9c-BlmRu?wXht$gAwn0>QtZ~kV$;tL7jwVEfk z4X)s4&i*g`z^>W5JPpuJK4w&=Md@9_h+l$Q+KH7ts;*yJlorICV-0Zd5K%UdgXppM zlQ!7T*-W?-smYnT#NWV5M1}Kp?*Vxj;Vm!^PmA67=w|p-b6-GvrV(fuk1M3Eae5e> zGV*jyhG^?2WcnigW$+%Ca(0^$vEnF0N0XDHjdCq%p3b!4k5tul7At-e{Z#+r9bzpv z!i(whXH{~*9oVGJgeLDSN1C@qKqFg}*Qu!a0^gIlT(??*#J$^bJ0o+g=lU_9p9d#Z zZp9D!fQxjvENY6l#b{(Vm;_5^i4t+y&K|aZFh?t-z72oohM@F|7bpc3BPMbm5F1{o z+(<;u4C0sbdp;V|^j8q>2-Yh);&`#G-4B5Z5vrSgdJHI;ynyZzP^hTTNOL4l!zF|V zEJG4}Gs2UGr-~6v{z0~CU2gu>9~>GDu`jjT`P$`TIRzB+rvjR(@P8Yuepf~ zOST;#O8&N~VtS=N*b+{vClcx!OJ}VcFxtC40#JQlg^TR|}28)5o4d+jwyC;?k8aM=Vphus}j1 z(MDZumh{Pqi1|7ntptldHT9UDSfPTsE~!YhTA%X@tk(}IDPwP0L-y-06bE$;ye7jU z5Y9ToViN1UBje!FWa?{d*hoR5@9KSgVG^n?aesp?W(MiA4%#^SuLC}@Jg3u~uu4x7 zccW5Q3K9*^+-7!74VAZOK|4iXriy6Q>;^Q*Ie?jPif+vGg)@W1?*v0*SA`D*8JYhPs1iukA~Ww>%BC>tpRt{6-XOW;49z%%_{{cz4lQ zK77S;4UL3*1Oe>NXdgB7sdaE2Q!^OAY{QIrQ;H|q$&G?gLX0VD@<>jGM{hIjB?O1S0V~^mYZM|i8iD!-r{59h?wP;JQ8^;!-7wS*}#6~?nILsUUWqG zhF}Pv?S77U)6cY(3TDiLUecKKSTynz<0!*hyHcUhVd$|>T7@3UL`8K|;S5Wj;S2f5 zG0jp3b&o8HyoSo?M!r?MlhWd;(*P&`D7+U7ws}#7I;OTmWitzfA9XZ_?pyuezIqn^ z38YuxAPzrMU?qz@>;C91WWp*D>+`h^BnleU@|k_}Dk3+94t0s&9uBN9&;vZ<7^`{k@8)e(fMyN{+0MEayeR*EH=r@7n4izUp0Q{!HD*R#;TU5kF7aQ3C+h zMr*u`vtF;H5o#=?-CByc=(M$s`kJy4dBa@x`JID`)wm%JHrP@CgFbvJE=Hr=Bg-wxllC6cmFrDX9a%s%Dg6NbfYtVCkIj<>{9cA4GiP6zMH zMzGSj1WOXrayt2)qKr}hUXpFUBQ$OL@Ue1V^S(d0 zprW{lKw$;FcT`5gxktpvHSmS*7h=#RexDumC7f!*PjNyTuR_?&svWk-;s zX^25ll;^gc<36l^-K%ClqKCVQQ)?Sm&^mb>)J%Im7-(Dc3uTcN-<}V00*$f$ymLQH z++Vt7aOKG>bYV+R69VzVFf{454ug6`u^DM4d~b&dJv7`hy{{u`KfkYk!0%&DS9Rzk zbJuRBt{{GK0jNre!c0j#^jBEa_azOs-zj+!wK=9y${eWN?e7|^+Pmnzqv2XV6KUN3 z(MyKwc1%ka*Vf9`;H<4BeXh`7VI3x^zrmQ@yAvP38c*fS?-g8|yJ6spD_IMgy3ftJ z8!+olw=gJn}`i|Dlh;8|(?6f`aq~XmWv@l5N zlzMGX#raDSE8QjT2dm=C;LcN>Z_t#?Zs}4w^Z7$>JAZ>8zfi1Y=dTj4i6e!i$8GT< zZIfx$U$eh-$ZII2J9y{k`wOL5NT_jHg~EuAL&2pnNd2hv&bnKM_fF_q7?ZMs>Pfq- 
z-(u5L!w)`7Zn5G3wNaYZtgecW&V{!>f1wy;9ByMP<@DsIqLqLSOY2K+*2hy*FDOphP2?))J9jQ!(%^#4dTC?9i{{%M#b>MJjVFp) zWptbP-bIBp>ZRvrOW-97*beEKOzy)-L`6cpi!gXKO8@lL1#Y11zh-V+A z4n9cP)-Kq6{QS$OITJlCS?k)1suLOVfj15Of3Fz%i76k`QX)cr-Um}{a>D?nhXUN;K>-~gK zzhfm!!h}y=RPdvHdy6c#{25S3nW5{W%Ci7D zY4sPXiFC)Nb zGE(iJ#PFqt^nO)g9Oyw{8>i|3SNK|1?1g|gf%@T7SJG3fYqHdm+eyo?tbkdg<-LV~ zPR#gq?0uSx!V&2euNi@#DtlqU{b3!)i?8WVe-ICI&5tcl;b-(&;M9M*;*olMC%4&p zgXp{bB)QnzT&po#IQ=E`LA>8<7sod0!>6>JE#Gkz48E8DupjFKEQ1%&uFzbf#TGc{ z<*sBzXCn88^?Esuh8}3k)+&1vP~c*7jG?UXJz=J(9LNX8+@a;`5O3p6KQIy2tIbsQ>z~hB~4S0 zszqeg+CRY)V$k}&=1}r>QbU8QG_rx5o2^Oc!yf$35MPLiBcO7&sJ_bdVkp*A!f1C&|3xh2d@`*`wGf|q zXfR~CH%$k*AX^WJ$##DzoHatUQj3n1S5sOmuCdEBDc_tHWEM2Fb5Z(oZ8)gC#DDSF z>VZ%VH3s*1REiaDX9~9D&vRMN1DJN8d@H!3=A}RP6T$M!W(|?d@;ebe!9!gGSI{!G zfW>8Q?cN+e5yK!9YL>oZiP61oLs9Sxh4zcY=mF>}=q!Y(GOeLU(og$mSxlvtzKKy= zzIDQ0ScwUzygw9K_(2UGSWZn*XM7^n5xYvr8I8)FkF`Zt1cBS4ojIXAX6W}h}l_@R8j?fW@VsubHIozdje@!0@$568O1@hMqjj=V(7D6hD0u zU*Zp}bVc8ic<08yP7xF0wi3xD?rJnA+{Ob^wqd>8nUSOHD}TYbH(T2U8X>k#_czNE z;BY5;Sp&Aq)6;C_{)H0vJW$={AXP_N`CZh00vEo))YuK4sO%P3M)&py8{xI`Rt<;3 zK%yOOnY4ysG@W83r;%gAe+M@zHwL`_vJzKDcnOC@u)sIT5XaA9g78JxX|JN;8e@0oCoy z3z`ty7VR-1!OFFax1~aFuu%vDz~(yo@0AxHcANLPd-8g%y-wBc$~q@p%8XZTy#(yc z2MiD0!R@!qT)Bco&F=?PJ7W^1DC<87p-}UObK4ac3zRnEr%*cx=%wxzWxla~^2rQC zQpZpXL$rIZnDgbQm+DR!^rxvQY^9OA*Bf7`f;xu?N>Ni&L;bNKrq3lEvhp>v>Mu1v zo?`Fw6F8PMmj6uaSbX8;8=V3m;ww-7CJBSm~!7JrP1_sd!bx?t2ISS}`%&kmXBi#`=#mYKSSMkQ4AVC`?TjNIXEGwTghMR?-RW1f93y^02p&g$dO(je4` zCSl|zIFyoI#9A_5pi`p`_lZ)1=p~V_#N5Ga*7uoGKm72=C6?)|%x>x95t|PzwDKNi z5^7d*_~d0U3?4&$LfJK~bs3!aaH;Mn;*!iNv)0e)eXw<9WRSMj^bC0L)3&GjzO9Ma z!nbaj4K6Wg&Bu+_@xWgwr-WXnF|J=bTtfnDgoLWHb)(37wxdFEj$S2pYrhX~vM(=B zOA%^`2bYg9jmTtcQD{{nt#n5x8tTko({xEmOJE^iwY-iI%~H=c#a#A(;g6u&VOe@j zd-6o8VV33{t5^3|G%G>1%O~Lr5d+!z^RS6R&AQ0?ucZf+oQkx0hEx_u80D8?`#^5m zlJ|a@7%%!^(DRE*zk)P;oheDr+HKE^bDi>HSl;a(S=fmBk)*uB z%*JA`+si}xW#3BMMt?b>2A)J)R^|C6!k6KujA<@H5_eYn_Gm|1+J>tzWbVpsFO$5U zzsl08Q+`g9RlMxabGLdo+?gW%Qh~@$>k8j7F7pimlBn};A#)q*)0EAQl0qDs0i_5F 
z?Xt3R$$+T*uoIH5gMHRJbSCa-&=hi6h;#j-bJLbO|2_5P58#t_Vc7(KGqz0vIn*oT z@2v{>wKmTzuA$44B!fr$fC!Zss#@A3`kSaz0kIK_o|d2n{zwPujy&StkaNDewY&Fn z7{5?jWC^p2Y_m1)!bMF+&yfXqcH}$i$O1e%^4)ZlzY6fk3OpeJ5xu|z2L5Nnayr_A zjPe$)e~Rx&9$D0%)fa}>U-t)Y7fle;;g5xn&t?c$z%E59U>>K7@N3oDLQy=L8MCV&ApMzNB&;v9!J5VsA|jk z0|y(QM5;&kGlk*2eLag34%22%MosYoP{+eGEGkWd&oLB|^EDJeW;R&Og>7c&76#4x z87`av$e434?P8Q-l#VUMRp>UTs#Qvo=x!;CXs*Idc6Cw3N*|=S0ELjiWN*R=b`{uf z$25#TSASVhA`RUhbxCD>S2g9ybU3423%JlA3Csk83D@};lE<2mY$Gbpj8GVsu!TMewryEAqf z)~&*Ay}Ef9wkpwjRING0?R1K{Wb=J6cH#XX^Qm=KzkrmQf`1Z3wsUpYwO+$H;D*>P zxYkp4?yYsfl}JQh{2OKcQ@cR(PC!#p+jz~-`e$dZuBD6I7Av>OX!qW^9jEL22(HY~Y%%}j`ggktbyC3VgWLrl(<>>^RfWCO zMzQ2+g~17T-zzUZ{7Bigi#hjV+5D;>6*l}?Y`t_RPRU`oAhWy$>b#x;i3A|B!#UawIuT2}fcXxgA!_k_(@cWChX zOKpoE-!BxB>U)*FioosuYjVJab{KEsi5icUv}n(ia!Q(&O!m$Rc|%UZsRFl8L}5A$ z$;zXGyu*qIJ62bh-8Y11c&js`-A}7-OFx~oiPX!G*Gt*%&K_!UPbU|(;?EdciIM1l za>LobFUe5QJ)%>}lmrMRew85CySUNe%?m0ysjy`>jbar`Q^FVv{T_RoYoCgLCMuK5 zGVL%uah7ul;8-8xo;hy2T(4s7K|KQLF`l#t>B=Yue9TO`k}{)S>a)~g{L*5)4pUFC zn_x^`F3^(^l^11a1~UIRH0G{1e|u=%GcneFZBMS9fT%f~jAH-MbtG^}M5piOV9d1R zZ#HydXBAQHVy@B~V6EX~W)RvnaGuVsC!@Dbr<1nePu8<9OU-m$muQnHt6Uj2nCM4X z+-+x8dMZyo${vt8x)b0j$|r9YqA?v^SgLT+>AkXi{8WWWf9N7RcjdNWy;|sgBr2)* z2jMT2_+#eSX>(s>{uDKHFe~{<6fYkKeV8*ql>K$8PdcQ|H#_qM1GIQ_{+sdY($ z#{^m;9ZS3WEyy!*kLSkvAO;av)hjEZhw6j!(zLvP63n!m z`;{xS-EMiY-ySLZviJZ+yPv(UUh$<~BK8F*C%=s`Q@rqT0D+>-xUkxatR`@ssnsTX zFM|Z=e4v|`5!^;WvNmtu0fxk+7`ASSY~IK2Jo-Gbg@^qW5_+Hb=AtURmj`-)^$`qN z2Pu()th#JMw&-%0A_$eT^GN2pzDIAbWE4rBx4!F8lqq z4+wM`Wt)lmw19zS5fS*QVW|b2_I03MsQwL)rn#ZeujC%$wk0tOBZ> z{z93o;S-Nk+OVz8<(!E-Md&TWUdS|TrWGj60lH@+sP#C*>%|tk)MM6vq2$Gci;>pm zYA3%AUo?;I@A;AeoEo-Fm1yNIZZRlS z60ockH2Gf>dp1D;Q9L@Zvfjq zfiWAtWpLvaD^mwkc&Izd3h-ETacuKoMy%je?TLl=wYUDJgi*eZ-d4pe+4Ia~0(3 z%iaz1^^Jt%;h$!2_n^0A``v(QQzjruZsHlE>Jimur#<1q1{t8dXtj&H82HV%vf2;^ z`&D+|0#_(1X>N~YkZ^T>P2b7o#Ph&6RQllMeyS>SZ$jPnRR1If7s|!T#-zI;iGaB-4Ub&Fnl=XdC`@Y`4Y!z62 zdESs1gS@mAB|t%afc60M0UE~NFP9)cr6eR^_Z&*qwvN!(fLfTm$iN@w-NCRCF`9@k<z6Ln 
z%YDq+F6p1XJl!SRLxh3OKP76Ehm z^%E}@@&;R2z-sjtz*2*f*|f5y#@#9jRW7C6MqP2nzT3X55H!JLsZZLSXmJRAM;ZO_ zeMFAQIEh?Q!M+TTEgM3oYOL`5)}lfcMjf9?Lq0Vhsk!a@@O!#~s`Xb(pRQjhH%gR7 zuUebR`kC?)uWSMs%|?OE^i6|<3*6EV4@oqR=61a0Sh;I7Wmb8 zPFJ5MFQj?auuT+s7w@5YlCjQ{8~9-F@eCGBu0vB*@r=FYgYE|F>sN)7c}H8uTt!+! zxx)LkKj)U;5Qloo>*bZ7QD`!mebx>@HIi;S7|P4EWIC+Lin_ z|BI@(0E#1MyM=Lw#oZQ`#TF;HF1EOaK!Uqla1F35?(VK31Qr4bkf6afI0+7m1xN@1 z0tC;``@P?PZ{4o0nyHztsj04+uG7zR&UwTKf4?wcfyTsfG0ZfZ)wtji#fwB5fjuP% zKhDulqgwnY^~<)P+p76- zt+LmLVhaIBpBleTJ>&uX7x9w+T4kQaTA(iLW~b(f-e<(m=k?UgR2P)#f5cV!@Y3kX zZcoc=EsAmBpqAHv$gzI@kn%$YGP3iGnDFxLyj+LzUx%oi#lbvlY1dU22Q4 z*En{q6rS>0b_XLZ$hb&t79UdUe;7nhb0^*@)DM;{8i-1rRjIIhY{tZQO?*B#R{I2iKc zGtswJ=sH8%WXk8YDxXdnAp6n!>{yx8V7!}#O-G$j$4?7f=tq(Z4UZwtvS(8DB0E7% z-1L_!PU`@Wge~hew*Fd=(SSFB2D?1%jy+k^cc9*vu*R2(;?e$WjP91UE+%%@Zw|hv zF8Z}_(;rn9a?#0EU8S_nLBG{n4OI=VLv4QS`LyCweshUcuf^+9&c^?YX(TS5B>R3s z+d@N@H>@swl?^LZn*K50w+C-#GEDK4H%SfUm1$)BHkPD~JkFgtf2OQi$I$6|s|GT+ zDqdB;t5VMDN@7qX*A9c|guiAAFqLNB8RPve8awiqwsMsXTi_idZyokT_N?R}Ffh6u5SX-$qv zSOysKR%b&=7<1^w8TyhNm}^{IsfXnn8wpWP(=UxP>gKOWsi~Q%(~@842zV3>B7%4) zd5GSfmxUANhJJa$Uy8~~VtL(5a?)CE>zn##0tPb!&ey%-ndSJYL~Hj6+=&1;5_Q2AhRC#2K2j;%|lc z6Ztcb`!h*oJj;)FO+SfMg_)5tQFiHUo18lL6`VOnc1XjJ{XCgmH{xvvlW(psY^IiS+)BwiY< zavN^oG<+O1VNvqX(Pnk0gKtp0^x2-5#`%`y&utuTcp0Wzwim5fV}oJ;=Z^CyULB+1 z@1^^Ie}>73R+tv)bm~7I7>DmK_wX}qrZ+_!$eLnnp41c89X6UOe>A12md5QqWPmYB zEOgJZ$yu=YH=?*r7o1%W{Gw-k@Bw68WY@~c#A=N`&2+VtsvER*KaRQuZDRMZZ{HcM z&$af3%~6%lYPl-=W9FeVrG@l=1QZ0&`i<39zdF{Lhh`i6ivNwF8HYN&G@2Yqmb|T^R zTiYngbK$hY=_vLeh713xxZI7v`u7X!6fJt1B7Nlg@|JbN3HhKZPg~r9RosDD@#IUZ z903^6t6L)=n7BnJO1en4p7kJ?N3sdG_31F9)S{ED&oPMKq6}uFh2uWWZamQ~z#Cj4 z-u=~%cYd&B@ zWxV5=f$x&cJ?(kxUm}c`J#jv#NA(m}ruK(TsV24teq@Up6C7parq!QZ@T~|e8(&qf zyX*3Xm+Mi(sqX~O=Q7t{K8N5As_ZJNCylz-36PX6o8lH*Hj;ZOnVlMdd=bfZPcqjnXcJhYO&?Y1fX z?%8DGaz)jAL#4b6qjbeADt=FA@KqbnJI_#RpQEgU8zO&Be_ zXK(&KW4hP4nR13mDw`=*;ddq(zdy~M(iOAKT@ZNobT%~awW|L5xYcgOC~*c4|aat_!*Mw_i5<%+46g#_hAB-C!c=D(Rl}W_C!A0@7j~{=*;?41N8LH>1s-o 
z`??lHRvp*kE;B~!DVbARbwOaC%(C5U5IqUL3=+)t+oO4;>cW9w|KRw%4)(bXcj zu8x0kB0jdZ4=$1ZWxSL25930yE=M+wx8;Wv{Bru&XiyS$xHB{U6{^Gkx1_SA>woeB z{~_l8Py0`7>A%c;3ZRHG`Qx@JIs0uEw@a9cKsPi`3lYbI0XL+9MzjBq zohGQvmcx}zJFIO(9!Iu8L1r_H_{#d&J$RZrZbBQ*GaSV&&^KnHm1`%Td}@$R`9BPY<`lxv zEb^fou3@UBM%Zga{H`Wag9B`SF!EC;=p=5af_#fd09#NbJ;mPVvwgvmDWpM_dM{`E z0}U88a1ffb5x*T;NI2o-3|E@Wv{i{Q#1s+L1Z9PgXGJASJq}3<6o$d`_+SCR!y)T_ zvBEP>X&|#}>*b)1>2~*MH)m&_S5(T^^G%S$VT5CWyJl_&v zq&GM+Ab(gIM{AC%yn@r$zRJYp*-~;PBuP#A5_lrwePk>d^npMl8)P^24upO5f|gQ5 z+-CPq4F8nV+;nOr6U&}VL2Wqa&rD6?TYEZVN3H;dZ<+*Jn90!`IzB0O)d5*mIhfxP zU~^1|MG;hB&B5O|=~yPxQA;+`kUOBz?>Orkg{uTr%V>aB85RoSq8(r|Ux%Qb6)BgY2NDAIq-z%8n|W?96F|0I5swTl zVM}aAOSNZz4Q~^^K&_Z6uWKTRA7tlVE6h4?Mx`#N8$sqqRc^nmK&7*Ycy}N<6Bz`+ z8`77en0OJ$*%NYn^0{1xTBXV3j<_zKFj&&>TdAh=8}T!^XuLYnUPngK%Uh1mI92X} zQq8x_sQZk67)$AF1ur3z5qvFJ#Qj6UT*8xM=0)^HTiRlv-aNwlY^ak`IjH}ZE1)wd z@U!MEZP#nnntbDOM+*_HMUJ^~onY@w?9SzYv3i69xEh>P@w_3PJI=w#3#n8kXoJX& z$UY4v&+9OSygm^UTsk$${yUpQ58tNTlW!KawH^DupF~UHWTdFPOVy@SQ?Ee(#u2ql z8%wT8qKECQu|;^`k|;s``i7v}=+oc?Sd=U9ojgDl|Ap#L9GiGG&A^7(#xrANWWrU( zL3om`o^2U}AJprKh`6UbML>Bf!4Gy5R<(C8M911vPkN=65*NzkTQd;-?ngCawi77-&x+63^Y*nr-Zz0 z9NDZCgx;OcRpIHCiCifPh>-~wGuF04+?SxGEJg-W;2AI&htDh`M}rc2?eIlJ6v_NK zaN9wT`_e8Ct}j_fNI-QWBl_HoLF^9rItCe5V3>JN;8usHUXdLq!#5g_Wui$+Q97hk za3$Cm|b10B8(DXy)I_GKSB4l1ZReg}&Wu;dX$U6~7sg{nNO^^w{Z{xQga4}b| zbO@#;pV~dKOof!Ghn~PYP)A~&>di^4d}^9_(Z_jgwmg)eTRE|UGt<*s$nWtFriqy3Wl4L=1I}>$dgwP8-)Y~nQ9IY43dNHLR z3haxTW=w6GL}{jaqKUO$)Tm@Mdhr}OM~6iN&OLkvdl8|{n%{3R?S*F6ZV{~FbQ-CF@tA^nQm5-v1c~5c?GeWK(7^=6jOX)>8@L-mc3ef za^BmDci&x}oOQ)e!}szb=roSUqsasPPPqBd&nQoY#Ue-4 zC4SUsF;*HfSaVKvhIM2Cu+#Df4z=%cz@bY9ax3N0B;KMK zU%+s#MVq@FiGLW;ZsSDdTmb%a=Dcm(Fmu9Mp)9*RxnvbeV&yxgN^u2(c8}OPrWDum z;@2D`$~}aHcG;S*ekCQhx!M?GXLhUxzSnt@AP{pD%_y$y!WNp3aKte$B%FUXPa8^K zwbvpp?YY7WtuqjfHBy>RmB+0L^3+X(E-;#ZN{xP6uMS!>X7^RRYvam?BL006`c!W-%!Adk+kp}^$819L>qfOSM(Z5 z%gRCMU}7GV_!I}F6jZcxkj~y*(9NrZ`N;QX=$5@EFM-cA9LKX1Uh3FYK+9m4#Kou^ 
z7W{~IGR)AduR&JNFnPe`sx(IHp!nM750@Ia_M6cg6=gWIO{hQ{@L0{8W`)sLRWQA^ zz*SloqIB$zO`2;lMNF+GF9zxy7(qG#uweICjks1I(LCsecv{i0sXiJgSV(zI)4KVI z3eXdp*#y5Upm!waHsW~~a4s!6PmvcUD(QxZKb+%UsCK{V4FJ}$QX{k3AvQ37?6pQK$`@Wh-$ zI&Q%aN3~C1)XZt~iJW<@zSO4OI*~j%p?@+z4`gm*_00dg|90O=cZo6=L|)pE)n_3_ zK9cf;JEtNElk-;SSiR0TEuV;Ug=5=WAAdLcm1@G0q4<%CTArZfKMW)}|MaDuD6c4iVr()fLjmWQEz(G<$($m zO{QZXSM4Jb#6CEo3O%fj+4NJ5JvbzfG@FY!))RO8Dqmyi`*3;ALsd=ogoLgH9!8tW z#=m)|>NKelQF6Z8c#wi*WxasIB@xE=bZJvNY+0rnNidp+-JInIhNFyK$!Ij8hY?Dy z>gPB~G%)GFP1DX2^m-2RJwy#N$dc#4f1J>gh{yH0YCXx@8}Xm{tL!IK)F9w%al8f{ zPp>N{5}oL5@v?$O330d2wZ5bHm>iST{mCI8L@Jww*Ee(%mLhQC1j)2;T8GgLZ^ArY z7CZ&J_8-rmS+M{{2hxIr4_qmQG78A;t0D?MgT5K8@ryEYRy7xo6REa#j#heV`iuK< zWsD$We$mz~-=YgiN?x9r1y}j{ylty(pekMFqn%m9sCoD@-74qCnMiLD?!4H-{brs| z@%2FQyLmZcp7FDh`YmSkS>t0aTHyUSJ(G3~p=&Pifr(9>CYrADLcb9YlYE$TK0D1k;+8CtM~YuPf?idE(HHqfeUP22Z!q6^8_q>&}f` z4EsyzBcW~h2UhQW$f|D0c`Dto%!XYqT~r36pS(L^D5lpsv+E>;HaNFQLPxNS^m}DQ zba6y)U=zFq<{cVX>jqcq8gWvkZQ5MPKg0t0W|8Ci%fK=4qzwv5kf;!&M8{%ZIYI6RJ1@;AJ9~S9%PlzJ|M=DdJ z?Mxr&_)?bb_gG^JEGTEAitQwxaB{;1TBJ)k&$#k~G^%X$0R6G%XCqg|RAgvy!BIlg zDWs3b|23bLR*ZW*B6g66>aA(?cf1_h2JbJ5Jq?N|vkNwLZcmhj6QAZ;s8 z=TylOPd>jr{x>5X?q@p*~guU1T3cl2Eb!a4|u3_>UfG%GZ;fMQ8ggz z67uLvo+9eGjP-r^)8~fhaq^mbA4)uhrCr)6TcQN+Wh^M$T{pG1K4>A;!vn0(BMcv& zrh0lHXXxF}NVi2#@xKSi~e84 z*j0N9DUCa32NUCxIij>z1amtnIn&Csx&=v#M3{vurC25CZnhIT?+f7NA@kF_-@n)AD>) zQ&?17P3%ip!q~R`wu_I1!HULEvP%Egss!UHfhbQDCH$y4vjRB;a+=; z7Mr0msD6?dOBKRxD~(K`O_-VK7MCeN)D72$r((&u4X{@c!-v2=csSXw=CA1K#zZym zXj82@DT7Fw!uXjXbd*~lu1PM4yH0OLcxqV;>DG15)+(n#IZgBjs&YDobk{*`0H>@c zKyCm@8XS+C%|zA^?0uDqXx6;6uqMdX0j6kPE!FVa!@B9nlJ8U1K?b^XJWkFz5M_0T zLgoHfE6ihWm4a2wy3Qjg2;=Lmq?o#9u2xdtP&Rjr?Fo){>PuV3Q-MrBW`bc?o2=Rk z2{+IG-eHe(RfUU7I01#f7*TgF2oh(#UUL#9btOk|*S*>?5Lc|F)ud`eXNsLAc#XV` zyZ7O&@)}b7yVNQ=gmp-ZBQa-3NG^tU{epuN%O#MLDK5h?6X?h*(+RIQh$l<=hk?T$ z7GB<$F!P(E`Zo`G%?5$oPp2-%Jv8i(YaN#oJyd`JxwBbt&HvcMmORt4nAIS;`CXr-4gIU0Fo?fz#I%|}pB(TP< zXLpxil-^A8t4%!|$mW4g8nsBuuT|m0cK123UqEzLU@a|nf_ogYdaEEl4OikDNiR;m 
z&zHN5ncZgvMtiB|cSO?_D#Q7-!h3JV<>{l!sB^yuR~ZU$(2PeAknC{~=o8%!5^smQ zAaCT7$i!!_`Dn#5(a`udv4d3b_usJL7UCsJJCiJXKN%f)sEaB{VTd$mk$NH^LzQG*RQ!cXw2&_(%sZOgqro5KbIl?=sY*ok@m!MAix8kJ zpQuFmz=YGdQKtE++v7lmYKPqiOsN4|5nSz%FUgNth7GKITjkci9>7siGMbPXsVSgAAZW{2!5a_okgk64nCSL7Q zNEWyD@cE;MWikqfMaNIUHq(Vh87ISFHExFIL*A2-kvCPk{L6Cl9{wrZj;W4RM`P~} z>@h=V7owE5+%ZR^t?DA=>^MDgidYMEJ_t{ayV#rc9keHSdw@XmC$kiORI${uK|I#i zkBm~AQ3SN}ITigU?wP{ok~*+8hy72c1Jy2=bu5Q~EYps1sz5?a{6zZS8ARAtn{WJy zMvJKZFPBnUhAeZMqZrUu@_H`qfQZDf)kZwRm`A5hxb^pdGX2u`@o4IDbxfb0Db^H} zJ9^N3#(I~|^{Z=Tj%LPDzvzLh&_)J4N4j%Gt$r`ArpU!PjfRxpyG=g5-Sc78h3tQ-h~@wk@Q)HQ5DnG8$fec10IFRs>hVafRQ-{$(XTp;f%5+_ z(4j-aoA>y0p$xuf&nXX7f(dYpo1%V0%#YdWT{9qHe zwd>5T?8@gSz?wU5J4TCY`VmBbBbED+&RaYHpH-spC*9OwE`PX+f{I&+6(WsL$IhZ1 zfL+e_NebLF_mGC+n!HLIeeiDe57WCBfqWSRX4La0o}2p*318z4Y1*)Jxh^?r%efdN zw@7j8^7D8T)xM}D{(6YHl(RDrXTheTUnH!oj|VZ4)SuDD5<$Ch#N0^|6R}!+LI{V$ zy0!+}hWaxz)J-C7Eu)5Y7?XA24HhHVc5NIUKH^}FlIHdB1haCTYy;_7M*GVl|IlIX z19u$bHcP0HFJ0DnKWP|8Rw_JHNK`s%H3E_x)Q=S=6#zMc{OP)!| z5tO;LYNzb;L>J=u027Xn`TfI3{2#`UgE_sQjF33mUFTu^Km|)Vqo%(gB0CsZrBFwbehSe9uYRtFc-8%`*&5C zI+}0Y7t*XL6DzeeUXIH8O*E|2 zRD^of8#=9W_npKrq8Arn75$oPo~GJS7o3LrNdM zeUm$S>Mo*H-ifyPEPv+_ERG{p-!UzX%K49;iHV?ZIP+h!QF^mk1ng;?tdAVa7y?c> zyWBh^o`092208s1dZLUK>Nrf42+~XWAdH0PBUw=K_7(H5eM%@qhjxb^NC~;dTDj6Z z=}DuM8>zu74d+}-qQ5UC{i3+w&i(PK`D^wb8@F~pJGf|};yXc%)SAu5j+NAuZG?-Y zq{@!|!&jOO3}X@TDiBS=&ne}Nnd<*A5Qr+A#Oyn|?@u0<240`UIhNy(0`zQD6sQn< zf~n)eFFqon1B|m+KA9|(BRXwDsw7K#I6~Mer3Md(Y{VsWWEzyi6jn*yF5h3|sxnjf z$RL5e=VnPv^L4hb;6>ae6oVN&iMHbQ6x$y-NWNRe3A;8eahmrMFj(8xe*CKQJX z2zy1HAgpKX!H5pag8LG(XC9!-xZy=$>j9560u6F5QndYW{*{G8jv-I0GKmEr50F|* zZtd*=)`Z?r@FW$sC+&zl)qy1zoZW9nya&k|b1ln-I{PeU+%nIZ&)u8;hkKkhZzbo6 zzV6$xE_WU+=Ef?bP_%#4 z&hHEtdJ%xkat|JVv`S~$s49YU@lSB@5Khqmt3ttMUsc~&EFtl3*v-rKZ$hyh!FUI4 zQg)*y1>_?`d2z87y8XR4V$6c9TpeqQQKz>I(c-94&Ua(*Uu?Vn+&nuoK{`h8Tk6A9 zd#KHrRfs)0%S$zG%AJl)rM4Y^sWg;rx9>c|$K#Qe-#0dGw}=y-q@V5FeF$Mv2+xgUs3N_b6<)CweXXKg1%mwH*bUf5G#j8i(cU*N+1}eU+^x2 
zt;FOxg)S#m7r{b{xUZjvjj2`^qo7G2m&axqI*8WKOat)3H5>}nMHH=6XVibDOCVZu z!POM)jC`(_=N2hfL}MP)ecyC-=AxCE#`bji-Cpib1bC`ddSI039lcXZq`N@fHW**x z^j(ZGNv!+6r=HElQR$@Mvd`8*sFz|Yp0XF@G-?b`QX^ac!G=$pidl2HyrpiHBP&aOTvl;b`(x5)(^rndt#N1CCnnezE@TZ2N8PlA-lq{6-BIp=?uZ`bYtkW%ImiSRr=KuZ*QG&0cXO_V}*P z&(QF1$B^Xl2gvZ^?APlj9Z%BB*zoFJ8`)A|y{^R&#Eh<0dtX+h6*Mgr=SyECf1i7c zb&YSFLD5*{Kg7G%-z&uZj9PgEjPrf=4IJCA&IY{O49%{D<-4myxr8)MLtfbj{=-;x z`Sf#`Z6O-qCk}rTO*AoEkruFp$k%P)pxgqAj!J7u`OVjDBU|W1q|f1RJ~5rRHcnhv zYU`{{Ic8X1IBeU2E{A>hHX)A&(yKHTa2zILMvnjNHhYf-Mrew^wl?St^?u>( zDPs)u+*%3I5Pir?Ag(yF(=eM1eG)8rS>o1vco56Ja`(WJ5;(u)fz@JWkTIg<{|p++ z`%~;X@vJbMT>`wbFd=stPfS>S*)XqjViQi+zrO42{QK(q`LFr>uvC{?)#-7>H?@xc=Gbz^H5Sd_iPAOf6e<-=kq{`|q%Ppf$2 z{qA}0x4z4|Wd+{N?G#Hk&%1C9Ql@5*05;z%ES=nJIj0{3T~eB@_%}mIkxR7-)X9(8 zaeKvO8EqWlpG?CU`)@f+Vh22l$b?-#IgnQ%qxd))tsd=D#&J{|)uvBv@EW+>1F5po z)gp+cXC3;SkdBNP7%_D!<9pBdnHe%v=D*e`K~xTc*}>!4GjzF_NeSLbD-)lxW6oaq z6b=+$8sB3s+S93Zy!nT5k8u8O^PH_w4A^&up5eVu^YVU|1#E7`fK^2Ot^X}6x}>IT z1?Ne-C&^h7zG}obfBgp5{n;LAuy0&~fFSwV&=jz~!eiKOW+T zQ94CnvQs^9|HDX>guyVKb^-Lx^ZqWKo~VxGr~dZsC$C`K*&6~XMS0nudt$0^(s16* zJ399L`gy6x)#0C)I{n3q%Ipt2h4&a+k!J}xMs>i&AedV+&&l;4h9Vb%_KW==J%%Z& z9{M;aA9%iSW>TpZwvTknBbnE8l#@qL%8dV$u-x*az+A%4qr7xV6+c75AWK?&!WK7P zbBr?S6~+wV?R~QC@I(E=ZUN~okJ_%}ErJQuq8$Dva0G1>eaVY#5#WBc`jspzIe`J@ zdA2KrWE%niUlPs;7|cgelDEWaKi0S)%5=O$-)^))?gv26Z%bo)&_KcfH?SeKWADy4S|C->u!x&gXJHWiTKKOCz6 z>h%7XOoj2NL=jmr7QPe)~403(UICy)3))wKvFPlw_3T%ajNb#TkQw z?6WV$AkM|90UQK8!u>HTt^{vp`jmsA6R>Eprv zX)gD9K-P_BTGm=R0X#bYUF!{KNSj|?(0aPyueQ`=5!;`RXRG$LE3!e{ct$77o`i+A z7+OnDGL*~yGKZ)TTxS055#nA;3=a20O!e(QhVGYVUA8B)BI! 
z&CYrR#~@v>&Zc{yDUEuhx99@3b9mcZ(WQZ2<<7{+fW~FMJgIaOU z3e2Qz%Xs4+{=*o<8Ybikl2TH?lNjW3{9(m+Iu?!6{FTKz8j;6N*C@PY+mDo|w2zOU z<@gJ+dP}aC@7%MXe>upa+bsgeN3K$`f~`OtWYr#&OufiEL@$VP4Yrp}cKl7pjB-!o zwW5c@TPonO;HxVKS!96gvDxmsltQ-PSyA@V)%N=3Y1+$oOZzMy%%B!gZVx;m;{nZ} z4S-an5!dir%^&7I1!Z%^Sb}NwuTJkCgCKo;M%C9J!Q_Xmah1y$2l+4g=S~M)mheV#vCTK{k~vm`xr47+zM!hz zLV_-}JuTaCv-TzR3M2M@OuxR``KHn!9E`H!u4&RlpTMCXmG8YJY<`qIi_P5}XZfbn z=XAa^K#<{h8Tt`L81!n3HiNbLN#=>P0+Zv)<@k;_Qy8%mEN|VLReUEV^w%<3HpMXo z5JQNot@Dmqpespw_;UuV8zCxWis{ix>fAAu&Fe}#G4Fnf!j9W9oO`+E6F%mdE}!VV z&G)@QEJ-J!vNp0k`)|uZf)T( z3-)1_;rT?Q3k`(hQg*y5@}2FK|Cx1{22kt|C$x`GedlRDG02^owtPlL#`fVVB=K)L zp_h4Yz%=W%0{+t3X#d$*VZxuAyJuExv7Y%ZcxKKCb&3eSf9^k={CYfWAn~t`Z@9CV|e= ztg*iK6RF@oL+9CSD!&13zo$2c|6$N?l~=GwMF)R7ZQB(1wG)2_;9a{{2$?>MZ=u&q z936u|MG8I3NTwB}BD-SaOSbwH(&eZI{wiYG8F5+2b)`Tckd&@NK@_AjzZ|>&W=fAW zaDxnxpL`in1~{cvR+#>F5=M365E0J}kajq^{S%%Z!gK_LAzH|I^ zg=|6ZEz=)`UbJ~vCRS_y1*5_`$ae>M660hZ`8?9{U$u^v~hp zoc4UQgj|ZG*HynKYs)2DYL%YnLT4g}qWi$OTgKU>ww+66{EG{P#am(16%gpyAF(;cg%)l%3lqFzSzLZV!3a*IZrZwI@nByU**SJXG()FJE4{FOG21(z4%ea zRcEY6yS@4qU6{9Ad;c()q^bvRnbc3%sWLqju&rOEOT9adWc5TluD$2?iU2xe4xt8q zFNoZukM{kR4)KrZVT1vSjrP|z0b8V9huYVnsjGdje*feyXU<+F3xY&up9f=^RdjXw z`ES%iMi=ZRMbG|-raI6|gQvOW)35r_x%@w6_PNvaG25ow_RfM87JuwG`eDmvEzzU0 zhve|*PV_3`gRJha)5knJBht0;DE=aSG}lP>uHc@CFD8(V8o=hiTk(u705PCF5e8XzYPlN<=sZhT~QXRf zj{i1Fw=#QG)Yf21Vm@uo*7NZb2)_>pAS^AzBYZR_Zy1u76t#)Dxer!7JC|Ou=-!gN z+?3y_&AGb0Az4L*mLPR@cNX!NB1wiQ!a!3I8>Z%DBL~xgR{GY(i}R5{)Rgb<*gO3A zXuOi%s7HXv{$GIjAFBOh`yU_t{{lp1VMALbga6~=K+Y6ZH*^Pke4YQNcWtFdKMJX1 z(e3%;Co1=v4^01i`^|p_{->RUi<~O-s)n<%UQ45-mV!}=SJG!-Y5QE}yj7|6=wmCU zMQ+YPbJB!KpRcc(n1IhAk7*QC=N%3e^P@%K==5&@{sG$F0RK_b-Sy0f)=U|^9GIce z`TqmalEgoKj|FEQf!R#;L8QjO=+Q_3F7p6YgXfikoJJp$_r|J6pY$EURvIgrmD#mj zB>vd9A$%O)tcCD$%y!l%ZJx~b&Vy#F9=y)dqv4)yYZvlq^jm*o^HU10Js&|Uo#)Ef zO(2pzxv@9A93P3606Q~`n9F&yzNf{zk$S(A^4!1i(PP!=&-tU5S2RDOwa{OLKd1Ok~d{0TS!Ws)B&PoTMLT zWruS7Lw|Psc#gN2wv6(C zR8+BkM!H839Mq=(9C_y5Wvc44ZtF}ox2Bj;#wO+3sbMO-PEkFvc|5uGQLCFaZO0)* 
zLTK;p2E!v3t~VU0p>SXj+y-{Oeotjt&ADj2#fAN1xwQ{?gYX`?&;Q^)CAvkwNcxaK zZ6ParV6%bpUd|g-A?4E6;gNI>_9pk%of&7xO+TYLGneg?Qv4yh$qwDvwQCmZ z(L*l*N(j$bkK;Nh<9_1a&>pqNphm2c{gy650#{-c7Rx|y0?dN&0fxT;U{d7q!4h@i zyccWr2A*`kWoT<4Nx`muSMowHz(8S8KTxl>znr~vWw+%ot2rdKd>N%}nC>BV)Y{vl z9SpFq>FyFpo~D-I7uYrUlJD{pL01ME#$JzYkbKYlBdol{F~R6Y)Yc99<&f8Rxf427 zsTsLm%n{LR{TgS%4R(k`O_MDB#MtE|BP*8taL^;6l@9C;u$$KApov>w9ca7 zMdVd_27O4bJsg@AyTyMQ17p*c&*{8yvsosX}T|OZ%$ut{s zunHW~-R3G0cM#iDBtIr25KDU$Yo+~-TP2V|&hsj$HVeIjlg`t&Oo^MptPTC1BxB3hrgUG79d!ph9U zyCYQ^dWbtOLQ`PysW;(W$Cfry`Z{8fkLaAXBG4PSPvUoq0xl9CZFzKGBxLKf`Y~_{ zaM13IU;1l-K^7e7Wh6=pvk&PYSD-Kzz60TIp?kHuNZ7mbJSeI~P9MR2Rc@N8L$yd+ zicpUC>*Kl@b$GO3J(cozaH8S;X_04#t(^G0o_!B`4mPSS9sfw5QWNESw zf#_6#tI^{UZii%q`4+f|-kv$xXZLwQibo92Z&NK_Kk~jDzsQ~+<&mnwJ_Q~eX_9F3afEsyu{oq*~|X{wRogb?T-b?$CVG4(8X+?p%xW8@sma8 zFX_<*(;;_cFFJZFZ*b?4VL{O~Uu-;jp(!%vTMxui5xyn+(ycuTFH3gpgB^2i;sx^k zrv72N7rj_N%q@DGP$-!(pV)ncNC+~mdC}8)!@?F z!Zh6^yMLp<#-B0E$x)qrbX@ZVKSH#Tr2^Z$=N$jLm2PV; z?YwXMQNN_^&wbZpZI`}amXk32I*J_yKab=F$S&HQZ7>4VgDzK)uw&hyGTUF|W_^ry z4-Jm?fEcZ-f92W{&PdULM;?T@10k*c=xd>@@RN z%%o6IS?HOz{fFVrHrpOMe&yNro4Ca~-ROeMXPTI8_Ij7u`GicaQ%XKPME~H62Isys zv;WI{^rfY;yw$oD{Vhtty7M0fVz9uF*qm^(JkHNLI=qgE3^LSZ#+=@;}~E&pEtRXVE0r}#V^ z4`<$ABZtr6LK}W3(Z4xllO6G!Yw}U=%IrL2)&OGjX4$a$K5XG32M*wZL;Mh3-nd>4 z7@PB)vxAd|`7_FFgzjI-bTY5bEx&D3fSyDCU+Z`U)2QY3{MhHimfR)IZLh{TK2Xgj z9c2DM{UGu1OvlN|#^blQsE?r))C(qN+sSwRum~^6$H3A!S#A7|4o_o;+4^}8&*Xe} z2m1=Nh;JtgA15aaVp(<7+CM&a$vDr1vdHv&VHS>0{sH8AwY~grBhGym?YCoC)P!WF zj|@vO@_!7DGC9s)}fx%;LTlReK!u}28ZTMyk z>V2Q)bG&Cz#{%17x18am$C5-qESA>@IfW#QvNo!#@f64$t;4)C0DrN>u==lBTWwl6TbY&*-s!Qr1jnear+ z-a`?#{{UZ3d}ZS!t4rARz~JZ4@DcL2_&qph`JV?HGhq;N|HJ?%5CH%J0s;X90s{d7 z0RR910096IAu&NwVR3El~YnUA1Ji_{e zgpfuti0@LO+1wl=oS*!{%l-<+3N43p^G`JMPc-w-H1bb0@=rALPc-vSB=eKbPbBlt zJo87kX(I$DPpr_I!Es%)Q~Ax3vb{hb{{T9F02W;nr7uJbe+3wRNw1m6K+?qx%~HyC zu*`S_n12XiT$6_JGRp#Rcw96EI$)$ycmDvuAYn8N^Zx+#fROC}0N_dFh;uVl%MHqS zfbV-mD#CF800Yd2atgquD};tfSpF{`S*iS85s1?@CcnWz%<~h>Mmm4_(h;^1gT&13 
zDpKM){tGdO=ZRyEGXkEN)^!A?C9+b=;{>#-3k8DQSj(b_sc!Hca~BMZ5-g*L%yM%DJjBgHEm(|JARvnbcbEs6sWqM`BB)(l z)c!^22)j)W@rEvw)@Xn6NU$!YtRZrM@h>IZ{$1ihH-k5ZH37wuqDq6x!uYjw;OE3Z z6531rLXAzOSt#?AInA?}F41QxY-yF|JU}d;Fbc^%1;i6EbU+;91Q(IfM~&F*4V3x?(+*+8-fYRKK(@q7^D+<_KFiNbHdPsPIR5uhwc!XVHh# znW;4X7sJ#Sy6QLjq9vm}L4*GQh)PX!lelUvNzMKMsCN@Aq-JY?7>qNhxFb1wdUpeb z0{Ep5xTLjVU1B&CH$z|Y?f9Dg^0GY`1hW99e4qISf*z*;P3ehtnS41*4K5;~guq^4 z8V{IJ01Kd~>!jxxY`M%u7ykf?jobyZHfASlWeVYr<4$JT4p<3*lt4LP<5u{Jp=88= zfQ5?*Lai}rk23g0f;+*FtknKB4-kF%sWkou=ZTTQ6ySvoPyQ(g+Y1K~Z8GgJktHar z{{Y02-l8*&1;0;uQtOyImAV@700C}A3<7Q%{Xj7XT{54(4SblmcNK*2n3iq`M^I{c ziVc^a;%@rHg#d*}myAn+9LI1mRtu%r-|$r&Ty?+To&6$55bhyGFHadM5GBTNufEah$)#LR~}xnRy2-_c12~ATWm!6%aB*DmI4Bf9EqLv0d}B03j(+6Gg@duemcs9q4A~QEXAH&6X5Qw9{{R%!RvRoQNX8%q7sGP2le<(ATHY_*cYrL;4-sH#tVRAp z;)>XE`-sL%m4_r+9YW{#Ihc+Rf;`LMYR#PI5}Zsqj8fuYv2xYCK>OuJ@^=E#8(PFb zmn3Q1R+5O5boeJF&z>5Cw|NchM?ScR3(VyVU0wbLbTUA z+TilTQoRXJmHHz5)l`0J01)d`^*?js1sBt~U{1OTqESoJ%AUmaeKx^~CadfKZ8i!K$5#eh9->d^>)b4Hk z1OzZjrR8VVRY(SgIe;ITSylRr(QJyA&IOVB)^PCzoG^wPx6pra5Lpioh-w{{WRA1pX8m9YT{%u8)h7X_;wwh6$?M!E;ISJY;^{81%VAv$`+{OV?5|rp)F@e9!8q43 z*;Z--7_fRY!)n@+;;16XHEFxrKjL9+sQkt}P~G?lB6aL51>7_ACAy3iiyqTF##vGG zQqTIa$H~HZjuasd(qkg;l<4jHf`Q%Y9s!8tDLX_d4DR+EKp1;JWBNcnWcx|X238zJ znW|+a2{h`sJl#hEZK9sY4PR8oHvJ?%nLiK(sIy?TcHVoIojal&HfIn^f?g%)b?ry2 zQkAbvoh0y&m^xZy2y=60tEC??mX*b3aq2HZKM~SGG%^u!K=MJvAzvb1?2YZ|r|#h) zxWw{96z&2sToV5P8pHw)^l-q^7CsQNO>o2+Mi-;948e$;lf}vPEYM$qt|rw*uvi#{ zgar%10f#DxC4MCXg-2GxS6PM=IH!RG7Qhly3?tkwUZ^_SqG(r4sklI0?8Ig4okx(b zMwkNgxd>(89#qixs7KG3CDCx}bEz&Vf+dC#hlm9V-VcM3*SERKtQn@{TQ8U4lDv(s1(ue zmdcg-1f<0Z@5{t0Y1IO?1;r^u$yABD8lBpQ-DQ%c$cuoNZd3(9+qh@%$sV`?j>YV!V>aq0daY1%GU7PE*Mbqvt-h@1v_Ggz9gn1gAB@o~~&3+;cXwh?G`7+Ru^ z>FQr(x@`WC34oP?-$-GRGDLNv(PQXku*GcG#I(*QX^MV7A|KSMUHVYy@fN;56R0^t z@Zk==CifK7y;K9B4Sc(q6B8nT?aH2KD=!T7NeyRY&+{_p6?JneGc8*Wz9wHouwjdn z3OQyLUf(0s&6w;BOzBn8L)X$2Qu_kq>WonAI%RytsNGAR23+_?XtiV7R8t=j4GOqh zBLsYm;DIoNVXVQf*nVz^T<>MYK_;VV-38l34@3fkna+RM;oh8GXY}* 
z(fj@#N3fecVSw0?4?<922Q_SJ*-YxEFj4dQY#aH7*#{IA?Epesj)eeLAupUC#I-$P ziyO9C@m;pf#gGaGv)hu(9dJ{av@igfv?ISYQvwS_cGe|1HT;D~Ht`jKxDqiUX=)HYJu4t;uM zGlYDi6V;!qm>VQ7c^|ITaadTIG+cS>o@H4<(nZ_7WyBhc3c1OSzw!!blF2X34J}fV z^$`PN7TSNK<^W7q)z9t@qSxv)N5M>(BAvK7IhGs)5q`5temF@@2;-mnp`j>K%kCiP zV!S{32xab_wF053C|oD*P=Io{CdkEHY^e5tn+7VY_BR>KT3X@FlQnf#aPR)}3o^ou zC9mXwd`u5niT(&w+N+n$$^aa)puC*GVjRY0%5VXKY8W~}y3WGGbWDXPGjW0thhksc z5He=Ef+2${a~%%BG>5Ye&*C@PGiI<|5Bv2n;_Ga;c-_Y)fU_mB0}$H7pdz7w$5uHb zp<%$PR4q4&)SHRnO17p$okB^ew*y|e(MUvrZIxXb>F-XvtdRBddJ1a zw)!AwUW4U^I|?5$Yp54oS0u3#vrHAtqtWwmA(l2=vG-D?&u-6L)lAiQd@0qC|2<0+1wb20LL{>GfbNgZvggU&1nJGF;)zTmlUr zRy;>LPTHBSrR)v?>O6$Fx~0GYMLdS*+qiAc7h)y99=BwBqW zrVbxS*CO#QdJb;Mm&;CSeYyg&! z3v?XH4G?K?y^tUvXniBr!Yu|f9Fw<0nQvC#0nrfl@F=;z34&b=n~57EctE2LB?Pfy z+9nAbZim#<5*tPf4;9R`bi(bVtM9pn5GKuL%+Hlx$m?8*yr_ZEdNfQDRiFCGL&mPWN*mDPSF%j%r`Kd8oNGB6)uq5gkT=T4F! zFk&wvbI&8%0uUO0H>9`6kFGeKO+9~Iu8|3-d%qwKbMj~dEVi1rHXt*3P6n>)K`{1IL(p>kmFg8)ZB*30!xqi>NOQyIsK8!0ARrGfO zeik4Gp%U?C&12W%Q}y&(3ANoHk~3Si_fO&smLG6hHku5+C3-=ls3X7-xDFf62;v%I z>YBKT5K0Wo=x=NHl?LqO<}UGcslJ2ERTXV+Cu2ZC>L9zoV>a^{BExB-DXFBTAU)$7 zW}L&~_Qwf!rTdLjfP3HV!mj6d(GmrZGZPP1E4K(-%7j(IbsCibA9=A(i{bAqsQwAO zLyFkI%m)uiU?iy#sYn(!@#>pv3_M(T6d;F$hy5EUzDT1M=?aewORreI>tKue#{uHf@-{8{%&z4 zb+1Q%Pl!cN$eyo zLg%{W|{Qauks1UP_fWeeit`?kN?@YtjHRt7{ci%4wjM zp+d;5kKCxGmCUXHA;qfZwk437h&Fe>cp;wX75RYRFcX}_3(g5r(U}_LgJ^ZP{fI}^ zqptz&iGp#NRV1}HEEx4Vl{1y|efgPhWLR48^E2m1Olbk7SD9QPhi^d|GP8^|Z!E2V zx6}|b5wHR&4O-P_RGNxcjjG@rnX+1-BSjPoS2KqZrMvRMv6*XPvO@Jrw6MqK6=$^B zk}*;_?D=!5_L))$@#a#nW6z&?PizT82vP`~#JkZ~o?g$y!*jja9!^hrb*%5yi=p=k za4ki?At9WZ^B)cWL?ZDrDunFj zc7Wt6ZARP*h`fe{p4*2Buw9;1KBJTsEM}U=QtxR```#IMCnVvXxoz)KdZR;e!gMJ z!!FihU$h#6ZBOeCA#F>9+BQbCm-^TrZ@7_l5Jovv?i;8{^3U00)(CPx)AuYBm{HUR!pD0yu^1-Ik9dy?jG3O# zPZ6b>M<)pKOWfzUl-8+^pY|Y7c>w&uhEJy7P}WR#vGEbDRA2t|lm&|;3=dfJfTLEK zo4BycP{`q_R*of7kd=+{jkO)@G@x2)_Tk~?9v@8RRi~_AcI%@B$6yq0rY=*~3zeox 
zrbW>dDF?`hNYJO&AE<E%qpzvEQ^+zxI0zwlYxyPh+Jf_3qaFs*k602T5mS1&V^1#^Pm!khfRcjXE@W51lO!Yr3LFz*SRh(0y@-3)*Uq2 z=sFttnHB|KqXJ>2fo{(TS%rg_WN**ZLz&53lwn~l_sm-g&k)); zU^z#UC;L^=`-uC9AUMsLjxN^g%Clt)n--I0JjY?6^CO7&f-sS6nHNG0CA#V zGL8QL@#E$mw(2y2LoBcwjdPg1PDf05n_)BX;gs0oS=>)@{{Ui&I)e4*?U>XmeA2F% z!PKE1A;Y*B%*2qA6#L3w?iOa{);g9!?4QzJHhc%r%J(5xA^!kwBC)H(*$z%+kr%R2 z<#h<l6%$Erw0M-LgVM|-7~CeEbNZK|>G>fgyBY6ts0H6sdzDeqpJz`p%Pwo- z&+aaJf%}YdIAKZGIXz(4rx(O(dSj#u>RdX>S&q)?<%41a^gs3y7AhhI%gj(VJ|bA0 z>k@;MtAj(9GYSC;_K%N5N<4XqTFWY~$g->N^n}}_P7{~pex^vqk?GvUPNn_oCZtgt z#o1i0ps&2{{V$L;=9KLGQb=foFXUVlwJqwV18wlgkx8u9(jumMPrf_HWB^8 zzKT5g^?>zw3ZLvZtOHMlD>aJxXw_2bt%o zL4OldNEoz)Nr65pJBgY>8Oay>AwK{{ZAGJRL+AXh1n@1QyM?wXq$55P5FBMrAETo?=R~=%7lFX08$Xaj8^pB?2M`cIIi{}$!siFNr=?6vgPZ%5CFiG5S9@o?r63|0&OJ;W4D|fgXvJK2zFQ`erV(Mk6Te#9NmUh3W#48D7E_1N6#vjbJ`ERx> za)~fqn=liS9|R6x5RB#DN&A|7v&Id;yOb+f-cdicHy36m2 z%8j+|Wc~cZ0A=oj1C|_Im*+4P!0PEgP%6HRYu9n;{{W;)^I&l+2TO~1eM^ATNqo#2 zE9PpMJ3t7ksQ@EIej^5}xHZUoM$34X(M|OpdT)uyGT2n!&%_irW6u*{!K5lc4&W5f zU(m1AVe%!h@5d62Kf<*YMj_Q@^KJ?T0$EsY3AlS9?Cbm?B4)J-1`gc%K{OD`5IGnl z^fN=8LD3y6&RT$gN{(*-0PBhNru_E}<_*kIZ#5eiS(j63xQj56*yi7F2352^OA~BT zyV)af!4b+kT!-!wxv9z2;cluwSf3~N0b<9;U)o(`#OPH1PwErA^9M*~AgB%*j&=hU z5u0(MyE6)K>V1f!c2VxbCX3)qN-q|E(>IqD3_@tay@gqr-kDP*<&MKwWVu$7qSRve zg-hTP#s*VfgtHo1Eq4(Ns8$%u_<)UPFlt^OlqqVkSWFIFM7MYEDaGJ>uKm9(Ma>NQ zzy+b_+G7{36Sf%SJ9ipOhEe3>@dzzPZrI*qw15aREftpQA(n+W4;;$g_Vm5*s?7kLiog*}q!UoSuGZ#U9n_v#lMOvKE3#YgcR*0DBa-owwd zF^ypngfBbqi0?rN;XeH?d$Kg4brlQRPHjkf)+Q4!p})Zcd1Ujk`D6Dq9FuXgm+}fOv@6wH_+`NQDt>k238LqXF&w zIjCOY)A%N9nCsOl`XA)bzu+!Q_gDIq1LK(NLQ?p&#k}J&Q8Ek^{{T_Tz4b7v>G_Fe z_}hnH=fO5xH`HlHgj`I`FH^;MmJk-zK?nfgScAMV5gKj(02zN1BHhG&RKG70SbH@T zd;Za}7L$sVt|$R3jaNE|&AcJD4@_h;=M`VB<*O3Lq4egWQDAsKh>tZGj8Q0kANW20 z0D{{EB4nzI+q0#KpJ&0DS)d1S@1khZf9F zSgfxw=)$bEO0Yp2U!hHY;zVs_E%O?mhvP6ZUg?f-ZJ8NGV~7!^xge-9uZ5)l0J@5! 
zkB1Nrk$&L=v>{nMOcONRQHV4@B7Wj+1UH}Y#YjyEh1FQ~SNfP$8TF}h0 zVnu>{a9#?Sat#sZ3 zy<7NzVOS+IT^}$4)HoWKWgr~z;$21m0N*l(?~eZf*k>kIW+o*8F&X9x)H+0JIj|SQtyhpsyGEi8}={oWhd0<1)c-5w6{^>e`Lm za>W6HnB~vXSlq!y8MarN)6-KQ^h02` z2p(2ggH?b(8Hs`v3VaC$&;W2tx|3=}cp*-?QD~|VISU`A;VFE4LY4so>;0xZIA8~+y z9Ah-R(FD7|eIrD>*FWqO9+O=2FoK)uHMK2VGY$p}#*uD49H>N)Egn*Tr1uEN^5QG6 z0rw57T72X8C1ME!6B?=QTnmc{F% z`}TnVtwnQ~ut42i?h8F(T^I8z9c-*9SLKT8*LBbPF>mkvlpBgTJ%l12u!I=Tzu}X7 z%!QL8zU9N8nNfJS-rMlXx=o|y@@Hhx_C&&tdl~%^v%v}wFW_AJT2>~+R3;wwG` z`J6$RDLlt~@e>fco?1uLNu6T;f2m79t0V0O`l`TjSb{uHX>efyT&wyaVOtw82fR5z z{GzHAX~ea8OG^wgklMY*#!R!3?OTr4=noLP9TmyhBO>B%%*OjfnyV% z_WDb(@{ClpzOmu8%|YC9Jat~z%togwtn-?HfRRpGb&YN);{Z?}vI2aSUvcXwy8|x} z%Mm$^(ow#lVeufOSJ<|r){&=3ROPbe?NKyQwb<$-xKp-4njqQ=z?dltiBmi$BG6;T zB^|M|+E(hYkK!@-m1xPQt6!;DR^SRmN#{fMgAr!M(U;;qSOn<j=9S zXS5`Ey2LarE+=BhqY-w?9l91)aGqc#k6;c5ur!pz5LS;$9BLC1c;1&Rgstyo(6)8_O{DaQnfYk#+YL zSZeB+%wX6T^(qHLsvVmp`2;D~b$%983CIPWfEVi`sMtf#;A2J&esF@_c<0qZg$il~D(gXN72+*$|hw;zDSd&EaMQ{($sVa-aW zrO&|t(lhbzf8NZ~6JB2cvc)%Gx`J{g$!A@dQ>%?1$mg|=(* z#tJg-3bdeYg#ZfBO76OaE0BRQ5TK|mDoX5!HSH~UN{d^e&u5#Mg}fw08BL>}stQOL zE+AMq5J=Rf!MzZuEG)gS#{tApfF^L_9*alZ6k43I1YaEy1--Ax@}9E(nGxtVzemsTT~7<21O1l}(=?Tth7yn9 zRk}~IMb08&TDfeht0|0aSJeedTWaw#uSB}A(<}j{3q%YT3f(#5XH(F)(7+G47-MnSk0=@=Tcst8NIZdy$`f(zbK; zpJ*-6F12tC+!{C`&t@BmV6Mn)M-WEa<`tKH!g6i?qZ+GS&Ta@wV9v7;R|Ueht5UKp z5)p)8GP#+Rv7oY*g@MY!%(H5A+$f>6E}zSm2MNCv;nD>><+7O zW@VW?#U*ZQ8&E~I?;{dj;8gI;@4j}Ykwjc802$grTUC`uM_--aG9c2NxY`PC?m~aH+ z7@0s}`RNY$XOaci6_X>|8cobWb)THZ2-Ex!?%pUpvtIa2;+D=T^* zM^K~>v|NdRcnlLsl*ScE`Ti)dZhpeDE%~U`uYGHbtw9_($gc^pp zwk#ETVJ_3FwjioD%WFX3IF6m=i`7(ZR`F3`EW{MUjaKEVS%E9nHcqT+Ji3C59{o#L z+-d4G#UT7O4>@MQz7da4+965-Pyne~@e)vaL;WF{_M`K1^C{^*37JSxcB%Bj^(W4M zaR7IraAQ|$zNHt7Rez{;TvR ziUb`nyTf8S$;@^`nNaimO5I-LaRjll03C~OT#=ajD0BHG@?&VL?jy*GT+7-~u%{7< zH)k^l1c9^~JBTprxEo8PRu6Tvb=O9&-SsUm4cl;#P&Uj-cs)awHegou1q821{zaw4 zYpo9OJ|m3Q#qp4t1E?F7cOF!@NSd*LG$#5UY8Ld?^l*MY0Q3pEUPJnW2Sr8kU#JiP;HdjTW!fCTEy$EU 
zzEFAgVsLZGh_ESY*~5o1z|Kn+)>2vkH?f%1!f{M7nM>$_Lgk)|ymJN33&MVc7lrQt z<=`sF@i@$2e{d~6QD3P1MH>r#?ag)clMoFL zqlnT^a60#jYhhlJ%GS{4APOxQ&xnr+qtw6L$XD3J$qL2}5v0L+gNYo9zL=n(8~(Je z*h@`%THlGLO&*Jy+(gjENsmLhal7{e_qo6ieE$Ft5<&w#CwVnJP<+qBa4VGMhh6b^qHHWvcG6*D!DD~`j=e11gMd81l+(-DY0XZ z#HQSWEeA;)_K0nrPGIY9p7E#QhdE-v1s z1JOkF_Fi`#jalt6Uanz|qC68XmHtGr1_Nq{IcjL3U`T+xf zEb7c4s}f@gUzM0RKZqk6UFbKNLV74)%*2D*o>H$d>7*$zq=7GpoU3eHOfA~1Fu)@l z-LLrt3)!hf5Z|Mr8vFtT*JSt1R^JlT2(DJ)SA-VislDO?cA5457zd*-O(WVa`h{l^ zIvDg;*vv2uf#Nc%*wsLlW}f1j6;~4hj7~60h)M+=C_m#?JGg6&L2BFR4DTO@4-&u$-PD^n#8GX=Dw?1pai_ut$gIJc|q%zm~o{{V6@ zX%XR=3Ve~wL)#E`p(8`#Ln%!-N>ZSEXRKBg^nXS@kOlKiGScI}{6nt&NUF}GQ(2g& zP@OBMPx2Ff=nxKygas`t`h!aD?)u;LDuE_ZXl5?CX@-MDUFFvG{0j542SXwCC%S=6RZqYfCfx{+oDGde_-2jwtuhYL(0okK)Y z4bqMUEu9I&+96Am%&KaqqX($23D}AenxOo^l)57J%n&GnX>eU)Cli3*5rWA&M0QP{ zm%tK$Onsd}Wqf1vG=_^MDuWz@1cQacM}$(vf>4kwf-4Pf@7@ZgoXfw1{`iTo1@59j z>MMhY7F!R@07;5sAO#E$tyB1mVE2lQH^jXGO2YLh?R;Fc2~BV9A%a-pfw_6yY7H%U zm79YRbQXuqQbk%?sd_3~mi}N;A5F&bM+nrmb8d;32r0#T^D0ER2bSND%yNW8-~d{) z!Q8*%*>?NRHT+$Srtrjv(krr4)Ihki5}B+P%8^>L3BT%W)^U+cAwXD@6w!M#!I*Nu zIJsGj0ygvr6zs&!ULYA-vK5<3mU{OTP6Z5+DXy)nm!k44+kfKa6DXj}!GND+z|ld& z9TM?y?;A|Z+*T>n33mXyWfoqWCda%rw1e+5&00W8#9H)O4%G+t{7a+@Jb&J?cnUzw zjyi%c!2&|G^-7oO2u2;{vH~}`bXb@ELM`WxeGnN$(wc}=CuzmAlAK(iyE!nKpRbB6 zS-d|Iyo?K#u05tf+;*gGpbpE=?oka^MDDK7xY0qsrDtDK{Hl2~9)luN@&*;L{AhAGUe{3<4BYQ`wiWU`KZVfV*5Cp=8E zP^J5w&0a4Nudm1OS!hBcdy_mu>-#9l=MYo*AsMd|`X{_Nr}roxyuVRit?WPCqc9Ko zrZFb}0OJ+=NT6L|d!rQ!0K~VA5C#$Uo`>9?&xu3K zCV#YZRv>w0z-b0JAv7pu&SUXxw=44v^C~0$mKpPk%Mx4`fZGTWyyPWpsZ4V3eMc z*X#W=9fen$MN2m_M=@q<_qSad6#FtJTDu{}^;LImwSaVq0n-cQUP!hypVF z1yT73S)EEHhY?p&xB;Ok(lkP+xkR+7nM0V=8I}uz)w#K>%b9mLXVlUgPf)+<*)@l} zedn!}glQ=4E-cK+L0KUtN`fZhtU%4*>Jc6v$cbXHh{`JHDghQnP4%a!0WHIcj)%D{ zuIRQw-9=qPFEN;0cGLiv;%BHa3L39LKjv(;0yWJ4YZ9BOA5#2T*H?fET(X IHu#_a* Date: Fri, 22 Aug 2025 11:43:02 +0000 Subject: [PATCH 14/18] remove experimentals files --- .../tests/pipeline_tests/test_end2end.py | 534 ------------------ 
.../tests/pipeline_tests/test_vision_tower.py | 76 --- 2 files changed, 610 deletions(-) delete mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py delete mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py deleted file mode 100644 index 0cc07bf317d5..000000000000 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_end2end.py +++ /dev/null @@ -1,534 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. -# SPDX-License-Identifier: Apache-2.0 - -"""Test for Mistral-24B End-to-End Vision-Text Pipeline""" - -import torch -import pytest -from loguru import logger -import os -import ttnn - -from models.tt_transformers.tt.ccl import TT_CCL -from models.tt_transformers.tt.common import ( - sample_host, - PagedAttentionConfig, - preprocess_inputs_prefill, -) - -from models.tt_transformers.tt.model_config import DecodersPrecision -from models.experimental.mistral_24b.tt.model import MistralTransformer as Transformer - -from models.tt_transformers.tt.generator import Generator - -from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer -from models.utility_functions import skip_for_grayskull, skip_for_blackhole - -from models.tt_transformers.tt.model_config import ModelArgs -from transformers import AutoProcessor, AutoModelForVision2Seq - -import re - - -def run_reference_demo_pipeline(messages, model_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503"): - """ - Run Hugging Face reference demo model (Vision-Text pipeline) using given messages. 
- """ - logger.info("Running reference HF vision-text model...") - - processor = AutoProcessor.from_pretrained(model_id) - model = AutoModelForVision2Seq.from_pretrained( - model_id, - device_map="auto", - torch_dtype=torch.bfloat16, - ) - - model.eval() - - # Apply chat template - prompt_text = processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" - ) - - # Extract images (already loaded) - image_inputs = [] - for msg in messages: - for item in msg["content"]: - if item["type"] == "image": - image_inputs.append(item["image"]) - - # Tokenize and move to model device - inputs = processor( - text=[prompt_text], - images=image_inputs, - return_tensors="pt", - ).to(model.device, dtype=torch.bfloat16) - - with torch.no_grad(): - generated_ids = model.generate( - **inputs, - max_new_tokens=100, - temperature=0.0, - top_p=0.9, - do_sample=False, - pad_token_id=model.config.pad_token_id, - ) - - # Decode - output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - logger.info(f"HF reference model output: {output}") - - chat = parse_chat_output(output) - display_chat(logger, chat) - - return output - - -def parse_chat_output(text): - """Parse chat output format from generated text.""" - pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" - matches = re.finditer(pattern, text, re.DOTALL) - return [(match.group("role"), match.group("message").strip()) for match in matches] - - -def display_chat(logger, conversation): - """Display chat conversation in formatted output.""" - for role, message in conversation: - if role == "user": - logger.info(f"👤 User: {message}") - elif role == "assistant": - logger.info(f"🤖 Assistant: {message}") - - -def setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations): - """Setup model arguments for vision-enabled model (Single Responsibility).""" - instruct = True if weights == "instruct" else False - - 
model_args = ModelArgs( - mesh_device=mesh_device, - instruct=instruct, - optimizations=optimizations, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - ) - - return model_args, instruct - - -def setup_vision_prompts_and_tokenizer(model_args, instruct): - """Setup multimodal prompts and tokenizer for vision-enabled model.""" - messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://img.freepik.com/premium-photo/girl-hugging-dog-with-girl-hugging-her_737761-2565.jpg", - }, - { - "type": "text", - "text": "Is there a cat in this image? If not, what animal do you see in the image? Describe the image in detail in 600 words.", - }, - ], - } - ] - - tokenizer = model_args.tokenizer - return messages, tokenizer - - -def process_vision_info(messages): - """Extract images (already opened) from messages.""" - image_inputs = [] - video_inputs = None # Not used - - for msg in messages: - content = msg.get("content", []) - for item in content: - if item.get("type") == "image": - image_inputs.append(item["image"]) - - return image_inputs, video_inputs - - -def process_real_vision_inputs(messages, model_args): - """Process real image inputs using AutoProcessor (Interface Segregation).""" - processor = AutoProcessor.from_pretrained(os.getenv("HF_MODEL")) - - text = processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, padding=True, padding_side="left" - ) - - image_inputs, video_inputs = process_vision_info(messages) - - encoded = processor( - text=[text], images=image_inputs, videos=video_inputs, return_tensors="pt", return_dict=True - ).to("cpu", dtype=torch.bfloat16) - input_ids = encoded["input_ids"] - pixel_values = encoded["pixel_values"] if "pixel_values" in encoded else None - attention_mask = encoded["attention_mask"] if "attention_mask" in encoded else None - image_sizes = encoded["image_sizes"] if "image_sizes" in encoded else None - - return { - "input_ids": input_ids, - "pixel_values": 
pixel_values, - "attention_mask": attention_mask, - "image_sizes": image_sizes, - "processor": processor, - } - - -def load_separate_models_like_test_end2end(model_args, mesh_device, dtype, paged_attention, page_params): - """Load separate vision and text models following test_end2end.py pattern.""" - state_dict = model_args.load_state_dict() - - vision_prefix = "vision_tower." - # Setup paged attention config (exactly like test_end2end.py) - paged_attention_config = None - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - tt_ccl = TT_CCL(mesh_device) - # Load vision model (exactly like test_end2end.py) - vision_model = TtMistralVisionTransformer( - mesh_device=mesh_device, - tt_ccl=tt_ccl, - state_dict=state_dict, - state_dict_prefix=vision_prefix, - dtype=dtype, - model_args=model_args, - ) - - # Load text model (exactly like test_end2end.py) - text_model = Transformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - paged_attention_config=paged_attention_config, - ) - logger.info("Separate vision and text models loaded like test_end2end.py") - return vision_model, text_model - - -def run_generation_exactly_like_test_end2end( - vision_model, - text_model, - processed_inputs, - model_args, - page_table=None, - paged_attention_config=None, - max_gen_len=20, - repetition_ngram_size=3, -): - """Run generation following the EXACT pattern from test_end2end.py.""" - input_ids = processed_inputs["input_ids"] - - logger.info("Running generation exactly like test_end2end.py...") - - logger.info("Running Vision Model...") - generator = Generator([text_model], [model_args], vision_model.mesh_device, tokenizer=model_args.tokenizer) - tt_kv_cache = [[l.attention.layer_past for l in text_model.layers]] if paged_attention_config else None - - 
input_tokens_prefill = input_ids - batch_size = input_tokens_prefill.shape[0] - - prompt_text = model_args.tokenizer.decode(input_ids[0].tolist()) - input_prompts = [prompt_text] - - ( - input_tokens_prefill_pt, - encoded_prompts, - decoding_pos, - prefill_lens, - ) = preprocess_inputs_prefill( - input_prompts, - model_args.tokenizer, - [model_args], - instruct=True, - max_generated_tokens=max_gen_len, - max_prefill_len=8192, - ) - - input_tokens_prefill_pt = torch.stack(input_tokens_prefill_pt).view(batch_size, -1) - - logger.info("Running prefill...") - logits = generator.prefill_forward_text( - input_tokens_prefill_pt, - page_table=page_table, - kv_cache=tt_kv_cache, - prompt_lens=decoding_pos, - vision_model=vision_model, - processed_inputs=processed_inputs, - ) - - prefilled_token = torch.argmax(logits, dim=-1) - prefilled_token_decoded_res = model_args.tokenizer.decode(prefilled_token[0].item()) - logger.info(f"prefilled_token_decoded_res: {prefilled_token_decoded_res}") - - logger.info(f"Prefilled token: {prefilled_token}") - - import torch.nn.functional as F - - logger.info(f"Encoded prompt: {encoded_prompts[0]}") - logger.info(f"Decoded prompt: {model_args.tokenizer.decode(encoded_prompts[0])}") - - # logits: [1, 1, vocab_size] - last_logits = logits[0, -1] # shape: [vocab_size] - probs = F.softmax(last_logits, dim=-1) - - top_k = 5 - topk_probs, topk_indices = torch.topk(probs, k=top_k) - - topk_tokens = [model_args.tokenizer.decode([idx.item()]) for idx in topk_indices] - - logger.info("Top-5 predicted tokens (with probabilities):") - for i in range(top_k): - logger.info(f"{i+1}. 
Token: '{topk_tokens[i]}' (ID={topk_indices[i].item()}), P={topk_probs[i].item():.4f}") - - all_outputs = [encoded_prompts[0][: prefill_lens[0]]] - all_outputs[0].append(int(prefilled_token[0].item())) - - current_pos = torch.tensor([decoding_pos[0]]) - out_tok = prefilled_token - generation_length = max_gen_len - - results = [] - - logger.info("Starting decode loop...") - for iteration in range(generation_length): - logger.info(f"[Text] Decoding token {iteration}, current_pos: {current_pos.item()}") - - logits = generator.decode_forward_text( - out_tok, - current_pos, - enable_trace=False, - page_table=page_table, - kv_cache=tt_kv_cache, - ) - - _, out_tok = sample_host( - logits, - temperature=0, - top_p=0.9, - ) - - token_id = out_tok[0].item() - decoded_token = model_args.tokenizer.decode([token_id]) - logger.info(f"Generated token {iteration}: ID={token_id}, text='{decoded_token}'") - - # Stop if EOS detected - if token_id == model_args.tokenizer.eos_token_id: - logger.info("EOS token detected, stopping generation.") - break - - # Stop if repetition detected (n-gram) - if len(all_outputs[0]) >= repetition_ngram_size * 2: - last_ngram = tuple(all_outputs[0][-repetition_ngram_size:]) - for i in range(len(all_outputs[0]) - repetition_ngram_size): - if tuple(all_outputs[0][i : i + repetition_ngram_size]) == last_ngram: - logger.info(f"Detected {repetition_ngram_size}-gram repetition, stopping.") - break - - # Create result object - result = type("TokenResult", (), {"token": token_id, "text": decoded_token})() - - results.append(result) - - all_outputs[0].append(token_id) - current_pos += 1 - - # Early stopping (exactly like test_end2end.py) - if len(all_outputs[0]) >= 5 and all(t == all_outputs[0][-1] for t in all_outputs[0][-5:]): - logger.warning(f"Detected exact repetition of token {all_outputs[0][-1]} five times in a row. 
Stopping.") - break - - # Final response (exactly like test_end2end.py) - response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"Each iteration Generated Response:\n{response}") - logger.info(f"Each iteration Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") - chat = parse_chat_output(response) - display_chat(logger, chat) - - logger.info(f" Each iteration Generated {len(results)} tokens successfully") - - # Final response (exactly like test_end2end.py) - response = model_args.tokenizer.decode(all_outputs[0], skip_special_tokens=True) - logger.info(f"Final Generated Response:\n{response}") - logger.info(f"Generated {len(all_outputs[0])} tokens: {all_outputs[0]}") - chat = parse_chat_output(response) - display_chat(logger, chat) - - logger.info(f"Generated {len(results)} tokens successfully") - return results - - -def validate_e2e_outputs(results, expected_min_tokens=1): - """Validate end-to-end pipeline outputs.""" - if not results: - logger.error("No results generated from E2E pipeline") - return False - - if len(results) < expected_min_tokens: - logger.warning(f"Generated only {len(results)} tokens, expected at least {expected_min_tokens}") - return False - - # Check if tokens are valid - for result in results: - if not hasattr(result, "token") or not hasattr(result, "text"): - logger.error("Invalid result format") - return False - - logger.info("E2E pipeline validation passed") - return True - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") -@pytest.mark.timeout(1800) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "weights, layers", - [ - ("instruct", None), - ], - ids=["full"], -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - # False, - ), - ids=( - "paged_attention", - # "default_attention", - ), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, 
"page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (1024 * 8,), # Use smaller seq_len like test_end2end.py to avoid memory issues -) -@pytest.mark.parametrize( - "optimizations", - [ - lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), - ], - ids=["accuracy"], -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_e2e_vision_text_pipeline( - weights, - layers, - max_seq_len, - batch_size, - paged_attention, - page_params, - optimizations, - mesh_device, - reset_seeds, - request, - device_params, -): - """Test end-to-end vision-text pipeline using proper Generator methods.""" - logger.info("Starting E2E vision-text pipeline test") - - # Use bfloat8_b like test_end2end.py for better memory efficiency - dtype = ttnn.bfloat8_b - - # Setup vision-enabled model configuration - model_args, instruct = setup_vision_model_args(weights, max_seq_len, batch_size, mesh_device, optimizations) - - if layers is not None: - model_args.n_layers = layers - - # Setup vision prompts and tokenizer - messages, tokenizer = setup_vision_prompts_and_tokenizer(model_args, instruct) - - # Process real vision inputs from images - processed_inputs = process_real_vision_inputs(messages, model_args) - - # Load separate models following test_end2end.py pattern - logger.info("Loading separate vision and text models like test_end2end.py...") - vision_model, text_model = load_separate_models_like_test_end2end( - model_args, mesh_device, dtype, paged_attention, page_params - ) - - # Setup page table for paged attention (exactly like 
test_end2end.py) - page_table_tt = None - paged_attention_config = None - - # Prepare page table for paged attention (exactly like test_end2end.py) - page_table = None - - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if batch_size > 1 else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Run generation following EXACT test_end2end.py pattern - logger.info("Running generation following EXACT test_end2end.py pattern...") - results = run_generation_exactly_like_test_end2end( - vision_model, text_model, processed_inputs, model_args, page_table, paged_attention_config, max_gen_len=1024 * 4 - ) - - # Validate results - validation_passed = validate_e2e_outputs(results, expected_min_tokens=1) - - # Final validation - if validation_passed and len(results) > 0: - logger.info("E2E vision-text pipeline test PASSED!") - logger.info(f"Successfully generated {len(results)} tokens") - - # Log generated tokens for debugging - for i, result in enumerate(results[:5]): - logger.info(f"Token {i}: {result.token} -> '{result.text}'") - else: - logger.error("E2E pipeline test failed") - assert False, f"E2E pipeline failed - generated {len(results)} tokens, validation: {validation_passed}" diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py 
b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py deleted file mode 100644 index 8401ec212ecd..000000000000 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_tower.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. -# SPDX-License-Identifier: Apache-2.0 - -""" -This file is a unit test for validating the Mistral-24B Vision Tower model. -""" - -import os -import pytest -import torch -from loguru import logger - -import ttnn -from models.tt_transformers.tt.ccl import TT_CCL -from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.mistral_24b.tt.pipeline.mistral_vision_tower import MistralVisionTower -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - - -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -def test_mistral_vision_tower(mesh_device, reset_seeds): - pcc_required = 0.99 - dtype = ttnn.bfloat16 - - model_args = ModelArgs(mesh_device) - state_dict = model_args.load_state_dict() - - first_layer_prefix = "vision_tower." 
- partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) - } - - B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size - input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) - - ##### Reference model output (Torch) ##### - reference_model = model_args.reference_vision_model() - reference_model.load_state_dict(partial_state_dict) - reference_output = reference_model(input_tensor, image_sizes=[(H, W)]) - - reference_output = reference_output.last_hidden_state - tt_ccl = TT_CCL(mesh_device) - ##### TT Model: MistralVisionTower ##### - vision_model = MistralVisionTower( - mesh_device=mesh_device, - tt_ccl=tt_ccl, - state_dict=state_dict, - state_dict_prefix=first_layer_prefix, - dtype=dtype, - configuration=model_args, - ) - - tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) - tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ - :, :, :, : tt_output.shape[-1] - ] - tt_output = tt_output.squeeze(0) - passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output)) - logger.info(f"PCC: {pcc_message}") - assert passing, f"PCC below {pcc_required}. 
{pcc_message}" From d094dabb5a75fbd5710b9ce34107ed7e75fe6690 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Fri, 22 Aug 2025 14:20:33 +0000 Subject: [PATCH 15/18] Refactor comments and clean code --- .../tests/multimodal/mistral_24b/test_vision_model.py | 2 +- .../tests/multimodal/mistral_24b/test_vision_rms.py | 4 ++++ .../tests/multimodal/mistral_24b/test_vision_tower.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py index ff4d63a111b4..a46c6faa7ba7 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 import os diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py index ccd092631ad1..868cfec3c806 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_rms.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + import os import pytest diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py index 03060b499b1a..90d225e79b06 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 import os From e45ddaad7d5d2f8f6768a32aacd07390059395b3 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Fri, 22 Aug 2025 15:09:32 +0000 Subject: [PATCH 16/18] Refactor comments in mistral multimodal test --- .../tests/multimodal/mistral_24b/test_vision_model.py | 4 ++++ .../tests/multimodal/mistral_24b/test_vision_tower.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py index a46c6faa7ba7..e9f1ec9a91ed 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_model.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file is a unit test for validating the Mistral-24B Vision Model pipeline. +""" + import os import pytest diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py index 90d225e79b06..fdfe2f1dcb5b 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_vision_tower.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 +""" +This file is a unit test for validating the Mistral-24B Vision Tower model. 
+""" + import os import pytest From 08aabb5086e3829dd0900f5fc8560c185b809155 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Fri, 22 Aug 2025 15:11:48 +0000 Subject: [PATCH 17/18] Remove prefix in text_state_dict_prefix --- models/tt_transformers/tt/model_config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 204b228826ab..bb94d3e7c251 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -1396,8 +1396,6 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): def _get_text_prefix(self): if self.is_vision(): - if "Mistral-Small-3.1-24B-Instruct-2503" in self.model_name: - return "language_model." return "text_model." else: return "" From c5ea6e34be1185298b7ff16dea727592aa925ce6 Mon Sep 17 00:00:00 2001 From: nikileshx Date: Fri, 22 Aug 2025 16:07:36 +0000 Subject: [PATCH 18/18] Remove comment --- .../tests/pipeline_tests/test_vision_model.py | 98 ------------------- .../mistral_24b/test_patch_rot_emb.py | 2 - 2 files changed, 100 deletions(-) delete mode 100644 models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py diff --git a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py b/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py deleted file mode 100644 index 2939816b1dbe..000000000000 --- a/models/experimental/mistral_24b/tests/pipeline_tests/test_vision_model.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. -# SPDX-License-Identifier: Apache-2.0 - -""" -This file is a unit test for validating the Mistral-24B Vision Model pipeline. 
-""" - -import os -import pytest -import torch -from loguru import logger - -import ttnn -from models.tt_transformers.tt.ccl import TT_CCL -from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.mistral_24b.tt.pipeline.vision_model import TtMistralVisionTransformer -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - - -def get_image_features(vision_tower, projector, input_tensor, image_sizes): - """ - Get image features from the vision tower and projector. - """ - vision_token = vision_tower(input_tensor, image_sizes).last_hidden_state - image_features = projector(vision_token.squeeze(0), image_sizes) - return image_features - - -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "device_params", - [{"fabric_config": ttnn.FabricConfig.FABRIC_1D, "trace_region_size": 30000000, "num_command_queues": 1}], - indirect=True, -) -def test_mistral_vision_model(mesh_device, reset_seeds): - pcc_required = 0.97 - dtype = ttnn.bfloat8_b - - model_args = ModelArgs(mesh_device) - state_dict = model_args.load_state_dict() - - first_layer_prefix = "vision_tower." - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if k.startswith(first_layer_prefix) - } - - ##### Reference model output (Torch) ##### - reference_model = model_args.reference_vision_model() - reference_model.load_state_dict(partial_state_dict) - - mmp_first_layer_prefix = "multi_modal_projector." 
- - mmp_partial_state_dict = { - k[len(mmp_first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(mmp_first_layer_prefix)) - } - - reference_mmp = model_args.reference_vision_multi_modal() - reference_mmp.load_state_dict(mmp_partial_state_dict) - - B, C, H, W = 1, 3, model_args.vision_chunk_size, model_args.vision_chunk_size - input_tensor = torch.rand((B, C, H, W), dtype=torch.bfloat16) - - reference_output = get_image_features(reference_model, reference_mmp, input_tensor, image_sizes=[(H, W)]) - - # ##### TT Model: TtMistralVisionTransformer ##### - tt_ccl = TT_CCL(mesh_device=mesh_device) - vision_model = TtMistralVisionTransformer( - mesh_device=mesh_device, - tt_ccl=tt_ccl, - state_dict=state_dict, - state_dict_prefix=first_layer_prefix, - dtype=dtype, - model_args=model_args, - ) - - tt_output = vision_model(input_tensor, image_sizes=[(H, W)]) - tt_output = ttnn.to_torch(tt_output, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))[ - :, : tt_output.shape[-1] - ] - - non_zero_indices = tt_output.ne(0).nonzero(as_tuple=True) - tt_output = tt_output[non_zero_indices] - reference_output = reference_output[non_zero_indices] - - passing, pcc_message = comp_pcc(reference_output, tt_output, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output)) - logger.info(f"PCC: {pcc_message}") - assert passing, f"PCC below {pcc_required}. 
{pcc_message}" diff --git a/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py index e62c0c9751c9..6ccc36813ae9 100644 --- a/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py +++ b/models/tt_transformers/tests/multimodal/mistral_24b/test_patch_rot_emb.py @@ -9,8 +9,6 @@ import ttnn from models.tt_transformers.tt.model_config import ModelArgs - -# models/tt_transformers/tt/common.py from models.tt_transformers.tt.multimodal.mistral_24b.vision_rope import VisionRotarySetup as RotarySetup from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull