From 516082c65853d48b50850a05782267f36bcdc170 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 4 Dec 2025 16:49:47 -0800 Subject: [PATCH 1/2] add bf16 master weights example Signed-off-by: Masahiro Tanaka --- training/bf16_master_weight/README.md | 217 +++++++++++ .../bf16_master_weight/configs/baseline.json | 31 ++ .../bf16_master_weight/configs/bf16_full.json | 34 ++ training/bf16_master_weight/gather_memory.py | 175 +++++++++ training/bf16_master_weight/plot_loss.py | 93 +++++ training/bf16_master_weight/run_comparison.sh | 112 ++++++ training/bf16_master_weight/train.py | 358 ++++++++++++++++++ 7 files changed, 1020 insertions(+) create mode 100644 training/bf16_master_weight/README.md create mode 100644 training/bf16_master_weight/configs/baseline.json create mode 100644 training/bf16_master_weight/configs/bf16_full.json create mode 100644 training/bf16_master_weight/gather_memory.py create mode 100644 training/bf16_master_weight/plot_loss.py create mode 100755 training/bf16_master_weight/run_comparison.sh create mode 100644 training/bf16_master_weight/train.py diff --git a/training/bf16_master_weight/README.md b/training/bf16_master_weight/README.md new file mode 100644 index 000000000..2e039b2ea --- /dev/null +++ b/training/bf16_master_weight/README.md @@ -0,0 +1,217 @@ +# BF16 Low-Precision Master Weights and Optimizer States + +This example demonstrates DeepSpeed's new low-precision training options that can significantly reduce memory usage: + +- `bf16_master_weights_and_grads`: Keep master parameters and gradients in BF16 instead of FP32 +- `bf16_optimizer_states`: Keep optimizer states (e.g., Adam moments) in BF16 + +These options work with ZeRO Stage 3 and `torch.autocast` to provide memory-efficient training while maintaining numerical stability. + +## Memory Savings + +Using a 254M parameter simple transformer model with the following configuration: +- Hidden dimension: 1024 +- Layers: 12 +- Attention heads: 16 +- Batch size: 4 +- Sequence length: 512 +- ZeRO Stage: 3 + +### 1-GPU Results + +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 4.14 GB | 5.93 GB | 0.1042s | +| BF16 low-precision (master + opt states) | **2.71 GB** | 5.73 GB | 0.1121s | + +**Allocated memory reduction: 1.43 GB (34.5%)** + +### 4-GPU Results (per GPU) - 254M Model + +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 1.29 GB | 3.57 GB | 0.1189s | +| BF16 low-precision (master + opt states) | **0.94 GB** | 4.44 GB | 0.1249s | + +**Allocated memory reduction: 0.35 GB per GPU (27%)** + +### 4-GPU Results (per GPU) - 6.86B Model + +Using a 6.86B parameter model (hidden=4096, layers=32, heads=32, batch=1, seq=512): + +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 25.74 GB | 41.28 GB | 0.5078s | +| BF16 low-precision (master + opt states) | **16.17 GB** | **33.20 GB** | 0.5064s | + +**Memory reduction: 9.57 GB allocated (37%), 8.08 GB peak (19.6%)** + +### 4-GPU Results (per GPU) - 6.86B Model with Activation Checkpointing + +With activation checkpointing enabled, the optimizer state memory becomes the dominant factor, making the savings even more visible: + +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s | +| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s | + +**Memory reduction: 9.57 GB allocated (37%), 12.45 GB peak (39.7%)** + +With activation checkpointing, peak memory drops significantly for both configurations, but the bf16 low-precision option shows an even larger relative improvement - nearly **40% reduction in peak memory**. + +The allocated memory reflects the optimizer state memory, which is where the low-precision options provide savings. Peak memory includes activations and temporary buffers which can vary based on execution order. + +## Loss Curve Comparison + +To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset: + +![Loss Comparison](logs/7b_loss_run/loss_comparison.png) + +| Configuration | Final Loss | Mean Loss | Loss Std | +|---------------|------------|-----------|----------| +| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 | +| BF16 Low-Precision | 3.12 | 2.90 | 2.37 | + +The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings. + +To reproduce the loss curve comparison: + +```bash +# Run 1000 steps with wikitext dataset +deepspeed --num_gpus=4 train.py --deepspeed_config configs/baseline.json \ + --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \ + --num_steps 1000 --activation_checkpointing \ + --loss_log_file logs/baseline_loss.csv --use_real_data --seed 42 + +deepspeed --num_gpus=4 train.py --deepspeed_config configs/bf16_full.json \ + --num_layers 32 --hidden_dim 4096 --num_heads 32 --batch_size 1 \ + --num_steps 1000 --activation_checkpointing \ + --loss_log_file logs/bf16_full_loss.csv --use_real_data --seed 42 + +# Generate comparison plot +python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_full_loss.csv \ + --output loss_comparison.png +``` + +## Configuration + +### Baseline (FP32 master weights and optimizer states) + +```json +{ + "bf16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3 + } +} +``` + +### BF16 Low-Precision (BF16 master weights, gradients, and optimizer states) + +```json +{ + "bf16": { + "enabled": true, + "bf16_master_weights_and_grads": true, + "bf16_optimizer_states": true + }, + "zero_optimization": { + "stage": 3 + }, + "torch_autocast": { + "enabled": true, + "dtype": "torch.bfloat16" + } +} +``` + +## Usage + +### Run Individual Configurations + +```bash +# Run baseline configuration +deepspeed --num_gpus=1 train.py --deepspeed_config configs/baseline.json + +# Run BF16 low-precision configuration +deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_full.json +``` + +### Run Memory Comparison + +```bash +# Run both configurations and generate comparison report +./run_comparison.sh + +# With custom settings +./run_comparison.sh --num_layers 24 --hidden_dim 2048 --batch_size 2 +``` + +### Gather Results from Logs + +```bash +python gather_memory.py --log_dir logs/ +``` + +## Training Script Options + +``` +--hidden_dim Hidden dimension size (default: 1024) +--num_layers Number of transformer layers (default: 12) +--num_heads Number of attention heads (default: 16) +--vocab_size Vocabulary size (default: 50000) +--batch_size Batch size per GPU (default: 4) +--seq_length Sequence length (default: 512) +--num_steps Number of training steps (default: 20) +--warmup_steps Warmup steps before measuring (default: 5) +--deepspeed_config Path to DeepSpeed config file +``` + +## Requirements + +- DeepSpeed with BF16 support +- PyTorch with BF16 support +- GPU with BF16 support (e.g., NVIDIA Ampere or newer) + +## How It Works + +### Standard BF16 Training (Baseline) + +In standard BF16 training with DeepSpeed: +- Model parameters are stored in BF16 +- Forward/backward computations use BF16 via `torch.autocast` +- Master weights are maintained in FP32 for optimizer updates +- Optimizer states (Adam momentum and variance) are in FP32 + +This requires significant memory for the FP32 copies. + +### BF16 Low-Precision Training + +With the new options enabled: +- `bf16_master_weights_and_grads=true`: Master weights and gradients stay in BF16 +- `bf16_optimizer_states=true`: Adam momentum and variance buffers use BF16 + +This eliminates the FP32 copies, reducing memory by approximately 2 bytes per parameter for master weights and 4 bytes per parameter for optimizer states (for Adam which has 2 state buffers). + +### Memory Breakdown + +For a model with N parameters: + +| Component | Baseline | BF16 Low-Precision | +|-----------|----------|-------------------| +| Model params | 2N bytes (BF16) | 2N bytes (BF16) | +| Master weights | 4N bytes (FP32) | 2N bytes (BF16) | +| Gradients | 4N bytes (FP32) | 2N bytes (BF16) | +| Adam momentum | 4N bytes (FP32) | 2N bytes (BF16) | +| Adam variance | 4N bytes (FP32) | 2N bytes (BF16) | +| **Total** | **18N bytes** | **10N bytes** | + +This gives a theoretical ~44% reduction in optimizer-related memory. The actual savings depend on activation memory and other factors. + +## Related Resources + +- [DeepSpeed BF16 Documentation](https://www.deepspeed.ai/docs/config-json/#bf16-training-options) +- [DeepSpeed Core API Updates Blog](../../blogs/core_api_update/README.md) +- [Low-precision master params PR](https://github.com/deepspeedai/DeepSpeed/pull/7700) diff --git a/training/bf16_master_weight/configs/baseline.json b/training/bf16_master_weight/configs/baseline.json new file mode 100644 index 000000000..525c4e00a --- /dev/null +++ b/training/bf16_master_weight/configs/baseline.json @@ -0,0 +1,31 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + + "bf16": { + "enabled": true + }, + + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_param_persistence_threshold": 0 + }, + + "torch_autocast": { + "enabled": false + } +} diff --git a/training/bf16_master_weight/configs/bf16_full.json b/training/bf16_master_weight/configs/bf16_full.json new file mode 100644 index 000000000..cc854baae --- /dev/null +++ b/training/bf16_master_weight/configs/bf16_full.json @@ -0,0 +1,34 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 100, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01 + } + }, + + "bf16": { + "enabled": true, + "bf16_master_weights_and_grads": true, + "bf16_optimizer_states": true + }, + + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "reduce_bucket_size": 5e7, + "stage3_param_persistence_threshold": 0 + }, + + "torch_autocast": { + "enabled": true, + "dtype": "torch.bfloat16" + } +} diff --git a/training/bf16_master_weight/gather_memory.py b/training/bf16_master_weight/gather_memory.py new file mode 100644 index 000000000..6eb9ab813 --- /dev/null +++ b/training/bf16_master_weight/gather_memory.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +Script to gather and compare memory usage from training logs. + +Usage: + python gather_memory.py --log_dir logs/20231201_120000 +""" + +import argparse +import os +import re +from pathlib import Path + + +def parse_summary_line(line): + """Parse the SUMMARY line from log output.""" + pattern = r"SUMMARY: config=(\S+) params=(\d+) peak_mem_bytes=(\d+) alloc_mem_bytes=(\d+) avg_step_time=(\S+)" + match = re.search(pattern, line) + if match: + return { + "config": match.group(1), + "params": int(match.group(2)), + "peak_mem_bytes": int(match.group(3)), + "alloc_mem_bytes": int(match.group(4)), + "avg_step_time": float(match.group(5)), + } + return None + + +def format_bytes(bytes_val): + """Format bytes to human-readable string.""" + gb = bytes_val / (1024 ** 3) + return f"{gb:.2f} GB" + + +def format_bytes_mb(bytes_val): + """Format bytes to MB.""" + mb = bytes_val / (1024 ** 2) + return f"{mb:.1f} MB" + + +def get_config_name(config_path): + """Extract clean config name from path.""" + name = Path(config_path).stem + if name == "baseline": + return "Baseline (fp32 master)" + elif name == "bf16_master_wg": + return "bf16_master_weights_and_grads" + elif name == "bf16_full": + return "bf16_full (master + opt states)" + return name + + +def main(): + parser = argparse.ArgumentParser(description="Gather memory usage from training logs") + parser.add_argument("--log_dir", type=str, required=True, help="Directory containing log files") + parser.add_argument("--output", type=str, default=None, help="Output file for summary") + args = parser.parse_args() + + log_dir = Path(args.log_dir) + if not log_dir.exists(): + print(f"Error: Log directory '{log_dir}' does not exist") + return 1 + + # Find and parse all log files + results = [] + log_files = ["baseline.log", "bf16_full.log"] + + for log_file in log_files: + log_path = log_dir / log_file + if not log_path.exists(): + print(f"Warning: Log file '{log_path}' not found, skipping") + continue + + with open(log_path, "r") as f: + for line in f: + summary = parse_summary_line(line) + if summary: + results.append(summary) + break + + if not results: + print("No results found in log files") + return 1 + + # Calculate baseline for comparison + baseline_peak = None + for r in results: + if "baseline" in r["config"]: + baseline_peak = r["peak_mem_bytes"] + break + + # Generate summary + output_lines = [] + output_lines.append("=" * 80) + output_lines.append("BF16 Low-Precision Master Weights - Memory Usage Comparison") + output_lines.append("=" * 80) + output_lines.append("") + + # Table header + output_lines.append(f"{'Configuration':<40} {'Peak Memory':<15} {'Reduction':<15} {'Step Time':<12}") + output_lines.append("-" * 80) + + for r in results: + config_name = get_config_name(r["config"]) + peak_mem = format_bytes(r["peak_mem_bytes"]) + step_time = f"{r['avg_step_time']:.4f}s" + + if baseline_peak and baseline_peak > 0: + reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100 + reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-" + else: + reduction_str = "-" + + output_lines.append(f"{config_name:<40} {peak_mem:<15} {reduction_str:<15} {step_time:<12}") + + output_lines.append("-" * 80) + output_lines.append("") + + # Detailed breakdown + output_lines.append("Detailed Results:") + output_lines.append("-" * 40) + for r in results: + config_name = get_config_name(r["config"]) + output_lines.append(f"\n{config_name}:") + output_lines.append(f" Parameters: {r['params']:,}") + output_lines.append(f" Peak Memory: {format_bytes(r['peak_mem_bytes'])} ({r['peak_mem_bytes']:,} bytes)") + output_lines.append(f" Allocated Memory: {format_bytes(r['alloc_mem_bytes'])} ({r['alloc_mem_bytes']:,} bytes)") + output_lines.append(f" Avg Step Time: {r['avg_step_time']:.4f}s") + + output_lines.append("") + output_lines.append("=" * 80) + + # Generate markdown table + output_lines.append("") + output_lines.append("Markdown Table (for README):") + output_lines.append("-" * 40) + output_lines.append("") + output_lines.append("| Configuration | Peak Memory | Memory Reduction | Avg Step Time |") + output_lines.append("|---------------|-------------|------------------|---------------|") + + for r in results: + config_name = get_config_name(r["config"]) + peak_mem = format_bytes(r["peak_mem_bytes"]) + step_time = f"{r['avg_step_time']:.4f}s" + + if baseline_peak and baseline_peak > 0: + reduction = ((baseline_peak - r["peak_mem_bytes"]) / baseline_peak) * 100 + reduction_str = f"{reduction:+.1f}%" if reduction != 0 else "-" + else: + reduction_str = "-" + + output_lines.append(f"| {config_name} | {peak_mem} | {reduction_str} | {step_time} |") + + output_lines.append("") + + # Print to stdout + summary_text = "\n".join(output_lines) + print(summary_text) + + # Save to file + output_path = args.output or (log_dir / "summary.txt") + with open(output_path, "w") as f: + f.write(summary_text) + + print(f"\nSummary saved to: {output_path}") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/training/bf16_master_weight/plot_loss.py b/training/bf16_master_weight/plot_loss.py new file mode 100644 index 000000000..2ec0ec850 --- /dev/null +++ b/training/bf16_master_weight/plot_loss.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +Plot loss curves from training logs. + +Usage: + python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_loss.csv --output loss_comparison.png +""" + +import argparse +import csv +import matplotlib.pyplot as plt +import numpy as np + + +def load_loss_data(filepath): + """Load loss data from CSV file.""" + steps = [] + losses = [] + with open(filepath, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + steps.append(int(row['step'])) + losses.append(float(row['loss'])) + return np.array(steps), np.array(losses) + + +def smooth_curve(values, window=10): + """Apply moving average smoothing.""" + if len(values) < window: + return values + smoothed = np.convolve(values, np.ones(window)/window, mode='valid') + # Pad the beginning to maintain length + padding = np.full(window-1, smoothed[0]) + return np.concatenate([padding, smoothed]) + + +def main(): + parser = argparse.ArgumentParser(description="Plot loss curves comparison") + parser.add_argument("--baseline", type=str, required=True, help="Baseline loss CSV file") + parser.add_argument("--bf16", type=str, required=True, help="BF16 low-precision loss CSV file") + parser.add_argument("--output", type=str, default="loss_comparison.png", help="Output plot file") + parser.add_argument("--smooth", type=int, default=20, help="Smoothing window size") + parser.add_argument("--title", type=str, default="Loss Comparison: Baseline vs BF16 Low-Precision", help="Plot title") + args = parser.parse_args() + + # Load data + baseline_steps, baseline_losses = load_loss_data(args.baseline) + bf16_steps, bf16_losses = load_loss_data(args.bf16) + + # Apply smoothing + baseline_smooth = smooth_curve(baseline_losses, args.smooth) + bf16_smooth = smooth_curve(bf16_losses, args.smooth) + + # Create plot + fig, ax = plt.subplots(figsize=(12, 6)) + + # Plot raw data with low alpha + ax.plot(baseline_steps, baseline_losses, alpha=0.2, color='blue') + ax.plot(bf16_steps, bf16_losses, alpha=0.2, color='orange') + + # Plot smoothed curves + ax.plot(baseline_steps, baseline_smooth, label='Baseline (fp32 master)', color='blue', linewidth=2) + ax.plot(bf16_steps, bf16_smooth, label='BF16 Low-Precision', color='orange', linewidth=2) + + ax.set_xlabel('Step', fontsize=12) + ax.set_ylabel('Loss', fontsize=12) + ax.set_title(args.title, fontsize=14) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + + # Add final loss values as text + final_baseline = baseline_losses[-1] + final_bf16 = bf16_losses[-1] + textstr = f'Final Loss:\n Baseline: {final_baseline:.4f}\n BF16: {final_bf16:.4f}' + props = dict(boxstyle='round', facecolor='wheat', alpha=0.5) + ax.text(0.98, 0.98, textstr, transform=ax.transAxes, fontsize=10, + verticalalignment='top', horizontalalignment='right', bbox=props) + + plt.tight_layout() + plt.savefig(args.output, dpi=150) + print(f"Plot saved to: {args.output}") + + # Print statistics + print(f"\nStatistics:") + print(f" Baseline - Final loss: {final_baseline:.4f}, Mean: {baseline_losses.mean():.4f}, Std: {baseline_losses.std():.4f}") + print(f" BF16 - Final loss: {final_bf16:.4f}, Mean: {bf16_losses.mean():.4f}, Std: {bf16_losses.std():.4f}") + + +if __name__ == "__main__": + main() diff --git a/training/bf16_master_weight/run_comparison.sh b/training/bf16_master_weight/run_comparison.sh new file mode 100755 index 000000000..760e507d0 --- /dev/null +++ b/training/bf16_master_weight/run_comparison.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# Run memory comparison across different bf16 configurations +# +# Usage: +# ./run_comparison.sh [--num_gpus N] [--num_layers L] [--hidden_dim H] +# +# This script runs training with two configurations: +# 1. baseline - Standard bf16 with fp32 master weights/grads/optimizer states +# 2. bf16_full - bf16 master weights, gradients, and optimizer states + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Default settings +NUM_GPUS=1 +NUM_LAYERS=12 +HIDDEN_DIM=1024 +BATCH_SIZE=4 +SEQ_LENGTH=512 +NUM_STEPS=20 +WARMUP_STEPS=5 +MASTER_PORT=29600 + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + --num_gpus) + NUM_GPUS="$2" + shift 2 + ;; + --num_layers) + NUM_LAYERS="$2" + shift 2 + ;; + --hidden_dim) + HIDDEN_DIM="$2" + shift 2 + ;; + --batch_size) + BATCH_SIZE="$2" + shift 2 + ;; + --seq_length) + SEQ_LENGTH="$2" + shift 2 + ;; + --num_steps) + NUM_STEPS="$2" + shift 2 + ;; + --master_port) + MASTER_PORT="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + +# Create logs directory +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +LOG_DIR="logs/${TIMESTAMP}" +mkdir -p "$LOG_DIR" + +echo "==============================================" +echo "BF16 Low-Precision Master Weights Memory Test" +echo "==============================================" +echo "Configuration:" +echo " NUM_GPUS: $NUM_GPUS" +echo " NUM_LAYERS: $NUM_LAYERS" +echo " HIDDEN_DIM: $HIDDEN_DIM" +echo " BATCH_SIZE: $BATCH_SIZE" +echo " SEQ_LENGTH: $SEQ_LENGTH" +echo " NUM_STEPS: $NUM_STEPS" +echo " LOG_DIR: $LOG_DIR" +echo "==============================================" + +COMMON_ARGS="--num_layers $NUM_LAYERS --hidden_dim $HIDDEN_DIM --batch_size $BATCH_SIZE --seq_length $SEQ_LENGTH --num_steps $NUM_STEPS --warmup_steps $WARMUP_STEPS" + +# Run baseline configuration +echo "" +echo "[1/2] Running BASELINE (bf16 with fp32 master weights/grads/optimizer states)..." +echo "----------------------------------------------" +deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train.py \ + --deepspeed_config configs/baseline.json \ + $COMMON_ARGS \ + 2>&1 | tee "$LOG_DIR/baseline.log" + +# Run bf16_full configuration +echo "" +echo "[2/2] Running BF16_FULL (bf16 master weights, gradients, and optimizer states)..." +echo "----------------------------------------------" +deepspeed --num_gpus=$NUM_GPUS --master_port=$MASTER_PORT train.py \ + --deepspeed_config configs/bf16_full.json \ + $COMMON_ARGS \ + 2>&1 | tee "$LOG_DIR/bf16_full.log" + +echo "" +echo "==============================================" +echo "All runs complete. Gathering results..." +echo "==============================================" + +# Run the gather script to create summary +python gather_memory.py --log_dir "$LOG_DIR" + +echo "" +echo "Results saved to: $LOG_DIR/summary.txt" +echo "==============================================" diff --git a/training/bf16_master_weight/train.py b/training/bf16_master_weight/train.py new file mode 100644 index 000000000..230c95367 --- /dev/null +++ b/training/bf16_master_weight/train.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +""" +Training script demonstrating bf16_master_weights_and_grads and bf16_optimizer_states +options in DeepSpeed for reduced memory usage. + +Usage: + deepspeed --num_gpus=1 train.py --deepspeed_config configs/baseline.json + deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_master_wg.json + deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_full.json +""" + +import argparse +import time +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +import deepspeed +from deepspeed.accelerator import get_accelerator +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + + +class SimpleTransformerBlock(nn.Module): + """A simple transformer block for demonstration.""" + + def __init__(self, hidden_dim, num_heads=8, ff_dim=None): + super().__init__() + ff_dim = ff_dim or hidden_dim * 4 + + self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True) + self.norm1 = nn.LayerNorm(hidden_dim) + self.norm2 = nn.LayerNorm(hidden_dim) + + self.ff = nn.Sequential( + nn.Linear(hidden_dim, ff_dim), + nn.GELU(), + nn.Linear(ff_dim, hidden_dim), + ) + + def forward(self, x): + # Self-attention with residual + attn_out, _ = self.attention(x, x, x) + x = self.norm1(x + attn_out) + + # Feed-forward with residual + ff_out = self.ff(x) + x = self.norm2(x + ff_out) + + return x + + +class SimpleTransformerModel(nn.Module): + """A simple transformer model for memory benchmarking.""" + + def __init__(self, vocab_size=50000, hidden_dim=1024, num_layers=12, num_heads=16, max_seq_len=512): + super().__init__() + self.hidden_dim = hidden_dim + self.activation_checkpointing = False + + self.embedding = nn.Embedding(vocab_size, hidden_dim) + self.pos_embedding = nn.Embedding(max_seq_len, hidden_dim) + + self.layers = nn.ModuleList([ + SimpleTransformerBlock(hidden_dim, num_heads) for _ in range(num_layers) + ]) + + self.output_proj = nn.Linear(hidden_dim, vocab_size) + + def enable_activation_checkpointing(self): + self.activation_checkpointing = True + + def forward(self, input_ids): + batch_size, seq_len = input_ids.shape + positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1) + + x = self.embedding(input_ids) + self.pos_embedding(positions) + + for layer in self.layers: + if self.activation_checkpointing: + x = checkpoint(layer, x, use_reentrant=False) + else: + x = layer(x) + + logits = self.output_proj(x) + return logits + + +def get_args(): + parser = argparse.ArgumentParser(description="BF16 Low-Precision Master Weights Demo") + + # Model configuration + parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension size") + parser.add_argument("--num_layers", type=int, default=12, help="Number of transformer layers") + parser.add_argument("--num_heads", type=int, default=16, help="Number of attention heads") + parser.add_argument("--vocab_size", type=int, default=50000, help="Vocabulary size") + + # Training configuration + parser.add_argument("--batch_size", type=int, default=4, help="Batch size per GPU") + parser.add_argument("--seq_length", type=int, default=512, help="Sequence length") + parser.add_argument("--num_steps", type=int, default=20, help="Number of training steps") + parser.add_argument("--warmup_steps", type=int, default=5, help="Warmup steps before measuring") + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") + + # DeepSpeed configuration + parser.add_argument("--deepspeed_config", type=str, required=True, help="Path to DeepSpeed config") + parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training") + + # Logging + parser.add_argument("--log_interval", type=int, default=5, help="Log interval") + + # Activation checkpointing + parser.add_argument("--activation_checkpointing", action="store_true", help="Enable activation checkpointing") + + # Loss logging + parser.add_argument("--loss_log_file", type=str, default=None, help="File to save loss values for plotting") + + # Dataset + parser.add_argument("--use_real_data", action="store_true", help="Use wikitext dataset instead of random data") + parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") + + return parser.parse_args() + + +def load_wikitext_data(tokenizer_name, seq_length, batch_size, world_size, rank): + """Load wikitext dataset for training.""" + from datasets import load_dataset + from transformers import AutoTokenizer + + print(f"[Rank {rank}] Loading wikitext dataset...") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Load dataset + dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train[:5%]') + + # Filter empty texts + dataset = dataset.filter(lambda x: len(x['text'].strip()) > 50) + + # Tokenize + def tokenize_function(examples): + return tokenizer( + examples['text'], + padding='max_length', + max_length=seq_length, + truncation=True, + return_tensors=None + ) + + tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text']) + tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) + + # Create distributed sampler and dataloader + sampler = DistributedSampler(tokenized_dataset, num_replicas=world_size, rank=rank, shuffle=True) + dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, sampler=sampler, num_workers=2) + + print(f"[Rank {rank}] Dataset loaded: {len(tokenized_dataset)} examples, vocab_size={tokenizer.vocab_size}") + return dataloader, tokenizer.vocab_size + + +def count_parameters(model): + """Count total and trainable parameters.""" + total = sum(p.numel() for p in model.parameters()) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + return total, trainable + + +def get_memory_stats(): + """Get current GPU memory statistics.""" + if not torch.cuda.is_available(): + return {"allocated": 0, "reserved": 0, "peak": 0} + + return { + "allocated": torch.cuda.memory_allocated(), + "reserved": torch.cuda.memory_reserved(), + "peak": torch.cuda.max_memory_allocated(), + } + + +def format_memory(bytes_val): + """Format bytes to human readable string.""" + gb = bytes_val / (1024 ** 3) + return f"{gb:.2f} GB" + + +def main(): + args = get_args() + + # Initialize distributed + deepspeed.init_distributed() + local_rank = args.local_rank + if local_rank == -1: + local_rank = int(deepspeed.comm.get_rank()) + + world_size = deepspeed.comm.get_world_size() + + device = get_accelerator().device_name(local_rank) + torch.cuda.set_device(local_rank) + + # Set seed for reproducibility + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + # Reset memory stats + torch.cuda.reset_peak_memory_stats() + + # Load real data if requested + dataloader = None + actual_vocab_size = args.vocab_size + if args.use_real_data: + dataloader, actual_vocab_size = load_wikitext_data( + tokenizer_name="gpt2", + seq_length=args.seq_length, + batch_size=args.batch_size, + world_size=world_size, + rank=local_rank + ) + + # Create model + print(f"[Rank {local_rank}] Creating model with hidden_dim={args.hidden_dim}, " + f"num_layers={args.num_layers}, num_heads={args.num_heads}, vocab_size={actual_vocab_size}") + + model = SimpleTransformerModel( + vocab_size=actual_vocab_size, + hidden_dim=args.hidden_dim, + num_layers=args.num_layers, + num_heads=args.num_heads, + max_seq_len=args.seq_length, + ) + + total_params, trainable_params = count_parameters(model) + print(f"[Rank {local_rank}] Model parameters: {total_params:,} total, {trainable_params:,} trainable") + + # Enable activation checkpointing if requested + if args.activation_checkpointing: + model.enable_activation_checkpointing() + print(f"[Rank {local_rank}] Activation checkpointing enabled") + + # Read config to check if torch_autocast is enabled + import json + with open(args.deepspeed_config, 'r') as f: + ds_config = json.load(f) + + use_autocast = ds_config.get("torch_autocast", {}).get("enabled", False) + autocast_dtype_str = ds_config.get("torch_autocast", {}).get("dtype", "torch.bfloat16") + autocast_dtype = torch.bfloat16 if "bfloat16" in autocast_dtype_str else torch.float16 + + # Initialize DeepSpeed - use config file path directly (not via args to avoid conflict) + model_engine, optimizer, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=args.deepspeed_config, + ) + + print(f"[Rank {local_rank}] DeepSpeed initialized with config: {args.deepspeed_config}") + + mem_after_init = get_memory_stats() + print(f"[Rank {local_rank}] Memory after init: allocated={format_memory(mem_after_init['allocated'])}, " + f"reserved={format_memory(mem_after_init['reserved'])}") + + # Training loop + model_engine.train() + loss_fn = nn.CrossEntropyLoss() + + total_time = 0 + step_times = [] + loss_history = [] + + # Setup data iterator + if dataloader is not None: + data_iter = iter(dataloader) + + for step in range(args.num_steps): + start_time = time.time() + + # Get input data + if dataloader is not None: + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(dataloader) + batch = next(data_iter) + input_ids = batch['input_ids'].to(device) + labels = input_ids.clone() # For language modeling, labels = input_ids + else: + # Generate random input data + input_ids = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device) + labels = torch.randint(0, actual_vocab_size, (args.batch_size, args.seq_length), device=device) + + # Forward pass with optional autocast + if use_autocast: + with torch.autocast(device_type="cuda", dtype=autocast_dtype): + logits = model_engine(input_ids) + loss = loss_fn(logits.view(-1, actual_vocab_size), labels.view(-1)) + else: + logits = model_engine(input_ids) + loss = loss_fn(logits.view(-1, actual_vocab_size), labels.view(-1)) + + # Backward pass - use PyTorch-style backward + loss.backward() + + # Optimizer step + model_engine.step() + + step_time = time.time() - start_time + + if step >= args.warmup_steps: + step_times.append(step_time) + + # Record loss for plotting + loss_history.append((step, loss.item())) + + if step % args.log_interval == 0 or step == args.num_steps - 1: + mem_stats = get_memory_stats() + print(f"[Rank {local_rank}] Step {step}: loss={loss.item():.4f}, " + f"time={step_time:.3f}s, " + f"alloc_mem={format_memory(mem_stats['allocated'])}, " + f"peak_mem={format_memory(mem_stats['peak'])}") + + # Final statistics + final_mem = get_memory_stats() + avg_step_time = sum(step_times) / len(step_times) if step_times else 0 + + print("\n" + "=" * 60) + print(f"[Rank {local_rank}] FINAL RESULTS") + print(f" Config: {args.deepspeed_config}") + print(f" Model: hidden_dim={args.hidden_dim}, num_layers={args.num_layers}") + print(f" Parameters: {total_params:,}") + print(f" Batch size: {args.batch_size}, Seq length: {args.seq_length}") + print(f" Average step time: {avg_step_time:.3f}s") + print(f" Peak memory: {format_memory(final_mem['peak'])}") + print(f" Final allocated memory: {format_memory(final_mem['allocated'])}") + print("=" * 60) + + # Output machine-readable summary line for parsing + print(f"SUMMARY: config={args.deepspeed_config} params={total_params} " + f"peak_mem_bytes={final_mem['peak']} alloc_mem_bytes={final_mem['allocated']} " + f"avg_step_time={avg_step_time:.4f}") + + # Save loss history to file if requested (only rank 0) + if args.loss_log_file and local_rank == 0: + import csv + with open(args.loss_log_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['step', 'loss']) + writer.writerows(loss_history) + print(f"Loss history saved to: {args.loss_log_file}") + + model_engine.destroy() + + +if __name__ == "__main__": + main() From fe10b7ee823edc5e34ab05152652fb0056d50b1a Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 4 Dec 2025 19:05:25 -0800 Subject: [PATCH 2/2] add results Signed-off-by: Masahiro Tanaka --- training/bf16_master_weight/README.md | 184 +-- .../logs/7b_loss_run/baseline_loss.csv | 1001 +++++++++++++++++ .../logs/7b_loss_run/bf16_full_loss.csv | 1001 +++++++++++++++++ .../logs/7b_loss_run/loss_comparison.png | Bin 0 -> 165941 bytes 4 files changed, 2022 insertions(+), 164 deletions(-) create mode 100644 training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv create mode 100644 training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv create mode 100644 training/bf16_master_weight/logs/7b_loss_run/loss_comparison.png diff --git a/training/bf16_master_weight/README.md b/training/bf16_master_weight/README.md index 2e039b2ea..2af73a278 100644 --- a/training/bf16_master_weight/README.md +++ b/training/bf16_master_weight/README.md @@ -1,80 +1,17 @@ # BF16 Low-Precision Master Weights and Optimizer States -This example demonstrates DeepSpeed's new low-precision training options that can significantly reduce memory usage: +This example demonstrates DeepSpeed's [new low-precision training options](https://github.com/deepspeedai/DeepSpeed/pull/7700) that can significantly reduce memory usage: - `bf16_master_weights_and_grads`: Keep master parameters and gradients in BF16 instead of FP32 - `bf16_optimizer_states`: Keep optimizer states (e.g., Adam moments) in BF16 -These options work with ZeRO Stage 3 and `torch.autocast` to provide memory-efficient training while maintaining numerical stability. -## Memory Savings +### Running an Example -Using a 254M parameter simple transformer model with the following configuration: -- Hidden dimension: 1024 -- Layers: 12 -- Attention heads: 16 -- Batch size: 4 -- Sequence length: 512 -- ZeRO Stage: 3 +The following commands run training for 1000 steps on the Wikitext-103 dataset using both the baseline and BF16 low-precision configurations, then generates a loss comparison plot. +The model has approximately 6.86 billion parameters (hidden=4096, layers=32, heads=32, batch=1, seq=512). +For BF16 low-precision training, we use `torch.autocast`. -### 1-GPU Results - -| Configuration | Allocated Memory | Peak Memory | Avg Step Time | -|---------------|------------------|-------------|---------------| -| Baseline (fp32 master) | 4.14 GB | 5.93 GB | 0.1042s | -| BF16 low-precision (master + opt states) | **2.71 GB** | 5.73 GB | 0.1121s | - -**Allocated memory reduction: 1.43 GB (34.5%)** - -### 4-GPU Results (per GPU) - 254M Model - -| Configuration | Allocated Memory | Peak Memory | Avg Step Time | -|---------------|------------------|-------------|---------------| -| Baseline (fp32 master) | 1.29 GB | 3.57 GB | 0.1189s | -| BF16 low-precision (master + opt states) | **0.94 GB** | 4.44 GB | 0.1249s | - -**Allocated memory reduction: 0.35 GB per GPU (27%)** - -### 4-GPU Results (per GPU) - 6.86B Model - -Using a 6.86B parameter model (hidden=4096, layers=32, heads=32, batch=1, seq=512): - -| Configuration | Allocated Memory | Peak Memory | Avg Step Time | -|---------------|------------------|-------------|---------------| -| Baseline (fp32 master) | 25.74 GB | 41.28 GB | 0.5078s | -| BF16 low-precision (master + opt states) | **16.17 GB** | **33.20 GB** | 0.5064s | - -**Memory reduction: 9.57 GB allocated (37%), 8.08 GB peak (19.6%)** - -### 4-GPU Results (per GPU) - 6.86B Model with Activation Checkpointing - -With activation checkpointing enabled, the optimizer state memory becomes the dominant factor, making the savings even more visible: - -| Configuration | Allocated Memory | Peak Memory | Avg Step Time | -|---------------|------------------|-------------|---------------| -| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s | -| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s | - -**Memory reduction: 9.57 GB allocated (37%), 12.45 GB peak (39.7%)** - -With activation checkpointing, peak memory drops significantly for both configurations, but the bf16 low-precision option shows an even larger relative improvement - nearly **40% reduction in peak memory**. - -The allocated memory reflects the optimizer state memory, which is where the low-precision options provide savings. Peak memory includes activations and temporary buffers which can vary based on execution order. - -## Loss Curve Comparison - -To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset: - -![Loss Comparison](logs/7b_loss_run/loss_comparison.png) - -| Configuration | Final Loss | Mean Loss | Loss Std | -|---------------|------------|-----------|----------| -| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 | -| BF16 Low-Precision | 3.12 | 2.90 | 2.37 | - -The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings. - -To reproduce the loss curve comparison: ```bash # Run 1000 steps with wikitext dataset @@ -93,107 +30,27 @@ python plot_loss.py --baseline logs/baseline_loss.csv --bf16 logs/bf16_full_loss --output loss_comparison.png ``` -## Configuration +Here is a summary of the memory usage and training time results using 4xH100. +This shows a significant memory reduction: Memory reduction: 9.57 GB allocated (37%), 12.45 GB peak (39.7%)**. -### Baseline (FP32 master weights and optimizer states) - -```json -{ - "bf16": { - "enabled": true - }, - "zero_optimization": { - "stage": 3 - } -} -``` - -### BF16 Low-Precision (BF16 master weights, gradients, and optimizer states) - -```json -{ - "bf16": { - "enabled": true, - "bf16_master_weights_and_grads": true, - "bf16_optimizer_states": true - }, - "zero_optimization": { - "stage": 3 - }, - "torch_autocast": { - "enabled": true, - "dtype": "torch.bfloat16" - } -} -``` - -## Usage - -### Run Individual Configurations - -```bash -# Run baseline configuration -deepspeed --num_gpus=1 train.py --deepspeed_config configs/baseline.json - -# Run BF16 low-precision configuration -deepspeed --num_gpus=1 train.py --deepspeed_config configs/bf16_full.json -``` - -### Run Memory Comparison - -```bash -# Run both configurations and generate comparison report -./run_comparison.sh - -# With custom settings -./run_comparison.sh --num_layers 24 --hidden_dim 2048 --batch_size 2 -``` - -### Gather Results from Logs - -```bash -python gather_memory.py --log_dir logs/ -``` - -## Training Script Options - -``` ---hidden_dim Hidden dimension size (default: 1024) ---num_layers Number of transformer layers (default: 12) ---num_heads Number of attention heads (default: 16) ---vocab_size Vocabulary size (default: 50000) ---batch_size Batch size per GPU (default: 4) ---seq_length Sequence length (default: 512) ---num_steps Number of training steps (default: 20) ---warmup_steps Warmup steps before measuring (default: 5) ---deepspeed_config Path to DeepSpeed config file -``` - -## Requirements - -- DeepSpeed with BF16 support -- PyTorch with BF16 support -- GPU with BF16 support (e.g., NVIDIA Ampere or newer) - -## How It Works +| Configuration | Allocated Memory | Peak Memory | Avg Step Time | +|---------------|------------------|-------------|---------------| +| Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s | +| BF16 low-precision (master + opt states) | **16.17 GB** | **18.93 GB** | 0.6427s | -### Standard BF16 Training (Baseline) -In standard BF16 training with DeepSpeed: -- Model parameters are stored in BF16 -- Forward/backward computations use BF16 via `torch.autocast` -- Master weights are maintained in FP32 for optimizer updates -- Optimizer states (Adam momentum and variance) are in FP32 +## Loss Curve Comparison -This requires significant memory for the FP32 copies. +To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset: -### BF16 Low-Precision Training +![Loss Comparison](logs/7b_loss_run/loss_comparison.png) -With the new options enabled: -- `bf16_master_weights_and_grads=true`: Master weights and gradients stay in BF16 -- `bf16_optimizer_states=true`: Adam momentum and variance buffers use BF16 +| Configuration | Final Loss | Mean Loss | Loss Std | +|---------------|------------|-----------|----------| +| Baseline (fp32 master) | 3.09 | 2.78 | 1.56 | +| BF16 Low-Precision | 3.12 | 2.90 | 2.37 | -This eliminates the FP32 copies, reducing memory by approximately 2 bytes per parameter for master weights and 4 bytes per parameter for optimizer states (for Adam which has 2 state buffers). +The loss curves show that both configurations converge similarly, demonstrating that the reduced precision does not significantly impact training quality while providing substantial memory savings. ### Memory Breakdown @@ -208,10 +65,9 @@ For a model with N parameters: | Adam variance | 4N bytes (FP32) | 2N bytes (BF16) | | **Total** | **18N bytes** | **10N bytes** | -This gives a theoretical ~44% reduction in optimizer-related memory. The actual savings depend on activation memory and other factors. +This gives a theoretical ~44% reduction in optimizer-related memory. The actual savings depend on activation memory and other factors, but our results show a very close match to the theoretical savings. ## Related Resources - [DeepSpeed BF16 Documentation](https://www.deepspeed.ai/docs/config-json/#bf16-training-options) -- [DeepSpeed Core API Updates Blog](../../blogs/core_api_update/README.md) - [Low-precision master params PR](https://github.com/deepspeedai/DeepSpeed/pull/7700) diff --git a/training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv b/training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv new file mode 100644 index 000000000..0915cc541 --- /dev/null +++ b/training/bf16_master_weight/logs/7b_loss_run/baseline_loss.csv @@ -0,0 +1,1001 @@ +step,loss +0,11.125 +1,7.15625 +2,5.875 +3,3.984375 +4,13.4375 +5,12.1875 +6,3.765625 +7,5.0625 +8,4.15625 +9,0.96484375 +10,8.625 +11,5.28125 +12,5.09375 +13,6.625 +14,4.0 +15,9.375 +16,7.0625 +17,5.28125 +18,4.625 +19,1.7890625 +20,1.453125 +21,2.0625 +22,2.984375 +23,2.234375 +24,5.53125 +25,2.859375 +26,4.3125 +27,5.3125 +28,2.546875 +29,5.15625 +30,2.0 +31,4.1875 +32,3.859375 +33,6.0625 +34,1.8671875 +35,4.96875 +36,1.8203125 +37,6.03125 +38,1.3359375 +39,1.515625 +40,1.1953125 +41,1.421875 +42,3.15625 +43,4.96875 +44,5.5 +45,1.53125 +46,2.0625 +47,3.078125 +48,3.375 +49,2.78125 +50,2.34375 +51,3.203125 +52,2.609375 +53,2.578125 +54,4.25 +55,3.578125 +56,6.21875 +57,3.734375 +58,4.8125 +59,2.859375 +60,3.546875 +61,0.8359375 +62,4.375 +63,4.71875 +64,2.203125 +65,2.015625 +66,3.703125 +67,3.234375 +68,1.296875 +69,1.515625 +70,1.5703125 +71,3.984375 +72,3.671875 +73,5.46875 +74,3.03125 +75,6.59375 +76,1.8828125 +77,2.875 +78,1.3984375 +79,0.984375 +80,4.65625 +81,2.78125 +82,2.9375 +83,2.59375 +84,5.5 +85,0.98046875 +86,1.4453125 +87,2.015625 +88,4.53125 +89,1.0234375 +90,1.9296875 +91,1.4453125 +92,2.5625 +93,2.703125 +94,3.484375 +95,6.1875 +96,2.265625 +97,1.8359375 +98,6.1875 +99,5.15625 +100,1.2734375 +101,2.90625 +102,1.59375 +103,0.54296875 +104,1.3046875 +105,2.171875 +106,2.484375 +107,1.5546875 +108,1.625 +109,0.60546875 +110,2.125 +111,2.046875 +112,5.96875 +113,1.9609375 +114,2.390625 +115,3.921875 +116,2.65625 +117,3.96875 +118,2.109375 +119,3.140625 +120,3.421875 +121,2.3125 +122,3.6875 +123,1.4140625 +124,0.63671875 +125,5.5625 +126,3.40625 +127,6.75 +128,2.40625 +129,2.71875 +130,3.34375 +131,4.90625 +132,6.28125 +133,2.203125 +134,5.34375 +135,1.4609375 +136,1.53125 +137,2.515625 +138,3.15625 +139,4.90625 +140,3.671875 +141,1.734375 +142,2.359375 +143,1.6796875 +144,2.15625 +145,0.6171875 +146,1.078125 +147,2.15625 +148,2.25 +149,3.5 +150,3.171875 +151,3.1875 +152,3.328125 +153,0.8984375 +154,0.87890625 +155,2.65625 +156,4.03125 +157,3.0625 +158,3.015625 +159,2.171875 +160,3.28125 +161,1.8046875 +162,4.125 +163,1.4140625 +164,4.25 +165,4.5 +166,3.359375 +167,3.1875 +168,3.5 +169,2.78125 +170,3.53125 +171,2.953125 +172,0.5546875 +173,2.015625 +174,1.1796875 +175,2.046875 +176,1.8125 +177,1.640625 +178,2.875 +179,4.59375 +180,0.65234375 +181,7.25 +182,3.75 +183,1.46875 +184,2.046875 +185,3.0625 +186,0.77734375 +187,1.5859375 +188,2.1875 +189,0.51953125 +190,4.28125 +191,2.890625 +192,4.46875 +193,1.0 +194,2.578125 +195,1.90625 +196,5.0 +197,1.71875 +198,6.59375 +199,0.6015625 +200,1.2578125 +201,4.25 +202,0.43359375 +203,1.2265625 +204,1.875 +205,2.921875 +206,2.859375 +207,1.8828125 +208,1.09375 +209,7.8125 +210,2.390625 +211,1.8125 +212,1.2578125 +213,0.671875 +214,1.9765625 +215,1.59375 +216,2.796875 +217,2.03125 +218,1.875 +219,3.0625 +220,1.0859375 +221,3.140625 +222,1.265625 +223,0.7265625 +224,0.96875 +225,3.6875 +226,4.03125 +227,4.53125 +228,2.59375 +229,1.9921875 +230,0.82421875 +231,0.7890625 +232,3.078125 +233,3.65625 +234,3.578125 +235,2.265625 +236,1.1484375 +237,2.875 +238,2.5625 +239,0.94921875 +240,2.78125 +241,0.64453125 +242,4.03125 +243,1.2421875 +244,3.171875 +245,1.9765625 +246,0.5859375 +247,2.828125 +248,3.9375 +249,1.140625 +250,2.5 +251,3.90625 +252,5.0625 +253,6.15625 +254,2.609375 +255,2.4375 +256,2.453125 +257,1.8359375 +258,1.890625 +259,2.640625 +260,1.609375 +261,4.84375 +262,3.859375 +263,0.86328125 +264,0.81640625 +265,5.28125 +266,1.9921875 +267,5.8125 +268,0.828125 +269,0.984375 +270,0.90234375 +271,2.890625 +272,6.03125 +273,2.40625 +274,3.453125 +275,3.015625 +276,2.3125 +277,2.984375 +278,3.65625 +279,3.3125 +280,5.8125 +281,0.7421875 +282,5.1875 +283,3.109375 +284,2.34375 +285,0.8203125 +286,1.8671875 +287,1.671875 +288,0.6875 +289,4.09375 +290,2.890625 +291,1.171875 +292,2.953125 +293,3.65625 +294,0.7578125 +295,1.59375 +296,2.171875 +297,2.1875 +298,3.609375 +299,3.34375 +300,3.515625 +301,7.75 +302,3.203125 +303,4.84375 +304,3.125 +305,3.6875 +306,3.9375 +307,0.8046875 +308,0.625 +309,1.609375 +310,0.62890625 +311,2.0625 +312,2.296875 +313,1.0859375 +314,2.671875 +315,2.859375 +316,1.703125 +317,2.71875 +318,2.546875 +319,2.234375 +320,3.53125 +321,3.234375 +322,2.84375 +323,1.015625 +324,3.5625 +325,0.60546875 +326,1.5546875 +327,1.328125 +328,4.46875 +329,3.109375 +330,2.171875 +331,1.8984375 +332,4.34375 +333,2.515625 +334,2.75 +335,2.6875 +336,2.953125 +337,3.828125 +338,2.171875 +339,4.5625 +340,4.0 +341,1.6015625 +342,3.28125 +343,4.65625 +344,2.28125 +345,9.375 +346,5.46875 +347,3.5 +348,1.1796875 +349,2.890625 +350,3.03125 +351,2.640625 +352,1.75 +353,4.96875 +354,2.359375 +355,1.7109375 +356,3.8125 +357,4.65625 +358,2.6875 +359,1.5703125 +360,2.53125 +361,2.5625 +362,2.25 +363,1.8515625 +364,1.9453125 +365,2.71875 +366,0.890625 +367,3.875 +368,3.46875 +369,3.0 +370,0.8828125 +371,3.359375 +372,1.6171875 +373,5.25 +374,3.84375 +375,3.140625 +376,0.98046875 +377,4.875 +378,2.671875 +379,1.0546875 +380,1.3828125 +381,3.078125 +382,1.1640625 +383,3.234375 +384,2.140625 +385,3.671875 +386,3.8125 +387,1.0 +388,4.5 +389,4.09375 +390,6.96875 +391,2.9375 +392,2.140625 +393,5.84375 +394,5.875 +395,3.25 +396,3.453125 +397,1.8125 +398,1.3828125 +399,0.9140625 +400,2.765625 +401,1.40625 +402,0.84375 +403,3.171875 +404,2.359375 +405,0.65625 +406,2.296875 +407,2.390625 +408,2.578125 +409,0.75390625 +410,4.1875 +411,1.3828125 +412,3.640625 +413,4.5625 +414,3.671875 +415,5.96875 +416,4.1875 +417,1.90625 +418,2.046875 +419,4.6875 +420,0.66796875 +421,1.9140625 +422,3.375 +423,4.78125 +424,1.8046875 +425,2.8125 +426,2.296875 +427,1.0390625 +428,1.9921875 +429,1.6484375 +430,0.9140625 +431,4.5 +432,2.390625 +433,2.03125 +434,2.546875 +435,2.109375 +436,5.65625 +437,6.53125 +438,2.28125 +439,4.96875 +440,0.90234375 +441,2.25 +442,3.3125 +443,5.0625 +444,2.46875 +445,3.1875 +446,1.953125 +447,6.46875 +448,4.78125 +449,6.5625 +450,3.015625 +451,1.1640625 +452,1.7109375 +453,4.125 +454,3.15625 +455,5.03125 +456,1.3671875 +457,1.7890625 +458,1.734375 +459,4.125 +460,4.53125 +461,3.578125 +462,0.77734375 +463,1.546875 +464,4.5625 +465,4.40625 +466,1.3984375 +467,2.03125 +468,2.109375 +469,3.5 +470,2.21875 +471,1.4140625 +472,5.09375 +473,3.40625 +474,3.34375 +475,4.40625 +476,1.90625 +477,2.484375 +478,1.84375 +479,1.3359375 +480,2.765625 +481,9.125 +482,3.625 +483,3.828125 +484,1.8046875 +485,4.8125 +486,1.84375 +487,1.1484375 +488,1.3828125 +489,1.5078125 +490,1.84375 +491,2.921875 +492,2.859375 +493,2.234375 +494,4.1875 +495,2.796875 +496,1.2109375 +497,2.5625 +498,1.921875 +499,3.5 +500,0.85546875 +501,3.34375 +502,2.5 +503,2.859375 +504,2.96875 +505,0.91796875 +506,2.796875 +507,1.359375 +508,3.796875 +509,2.390625 +510,5.59375 +511,5.65625 +512,1.703125 +513,1.1796875 +514,2.21875 +515,0.61328125 +516,2.703125 +517,1.921875 +518,1.8828125 +519,2.796875 +520,2.859375 +521,3.328125 +522,1.21875 +523,2.234375 +524,1.421875 +525,1.3984375 +526,1.6875 +527,2.296875 +528,4.53125 +529,2.796875 +530,2.28125 +531,4.40625 +532,2.703125 +533,3.21875 +534,1.4609375 +535,2.578125 +536,1.9375 +537,2.765625 +538,4.0 +539,2.453125 +540,3.8125 +541,3.703125 +542,2.46875 +543,3.625 +544,2.734375 +545,2.46875 +546,2.625 +547,3.109375 +548,0.99609375 +549,0.7109375 +550,3.28125 +551,0.6171875 +552,3.265625 +553,2.734375 +554,2.1875 +555,1.734375 +556,7.65625 +557,5.34375 +558,3.03125 +559,3.671875 +560,2.21875 +561,2.5 +562,2.796875 +563,4.78125 +564,3.609375 +565,0.81640625 +566,0.78515625 +567,4.46875 +568,3.921875 +569,3.375 +570,0.9765625 +571,4.90625 +572,4.03125 +573,3.453125 +574,2.1875 +575,0.59375 +576,0.4921875 +577,1.0 +578,4.28125 +579,2.421875 +580,1.984375 +581,3.34375 +582,2.03125 +583,3.84375 +584,1.7265625 +585,1.3359375 +586,3.0625 +587,2.671875 +588,2.59375 +589,3.171875 +590,2.875 +591,0.61328125 +592,1.328125 +593,0.609375 +594,2.875 +595,1.6640625 +596,2.640625 +597,4.28125 +598,0.7890625 +599,3.484375 +600,1.6328125 +601,3.59375 +602,0.84375 +603,0.65625 +604,2.515625 +605,1.453125 +606,1.2734375 +607,3.578125 +608,5.78125 +609,0.67578125 +610,3.453125 +611,1.4296875 +612,2.09375 +613,2.90625 +614,3.3125 +615,0.65625 +616,2.265625 +617,1.8359375 +618,3.578125 +619,5.375 +620,2.9375 +621,5.6875 +622,1.7265625 +623,1.578125 +624,0.75390625 +625,3.875 +626,3.25 +627,1.953125 +628,3.796875 +629,1.6953125 +630,2.328125 +631,2.34375 +632,3.265625 +633,0.79296875 +634,2.953125 +635,3.140625 +636,2.453125 +637,2.53125 +638,3.171875 +639,2.78125 +640,2.421875 +641,5.625 +642,3.015625 +643,3.203125 +644,0.62109375 +645,0.67578125 +646,2.375 +647,3.8125 +648,4.53125 +649,1.34375 +650,4.5 +651,2.84375 +652,1.46875 +653,3.0 +654,3.5625 +655,3.03125 +656,1.6875 +657,0.94140625 +658,1.3671875 +659,1.109375 +660,0.765625 +661,4.03125 +662,2.546875 +663,3.453125 +664,3.328125 +665,5.25 +666,2.0 +667,2.5625 +668,1.859375 +669,1.984375 +670,2.203125 +671,0.64453125 +672,3.171875 +673,1.8828125 +674,5.5 +675,0.546875 +676,2.0 +677,1.4765625 +678,2.28125 +679,3.0 +680,1.9921875 +681,3.640625 +682,2.609375 +683,1.7890625 +684,3.046875 +685,0.6875 +686,3.34375 +687,0.6875 +688,1.265625 +689,1.75 +690,2.984375 +691,1.5546875 +692,1.09375 +693,3.296875 +694,0.9765625 +695,0.7109375 +696,2.53125 +697,5.78125 +698,3.4375 +699,5.0 +700,3.078125 +701,3.5625 +702,4.78125 +703,2.234375 +704,0.62109375 +705,2.21875 +706,3.03125 +707,3.25 +708,2.421875 +709,4.59375 +710,3.0625 +711,2.609375 +712,2.953125 +713,3.015625 +714,1.5390625 +715,1.1328125 +716,0.83984375 +717,1.9140625 +718,3.296875 +719,3.34375 +720,4.03125 +721,3.28125 +722,1.421875 +723,3.59375 +724,5.09375 +725,2.890625 +726,2.65625 +727,3.234375 +728,4.625 +729,1.4296875 +730,2.234375 +731,3.578125 +732,1.5234375 +733,2.375 +734,1.3359375 +735,0.72265625 +736,3.515625 +737,2.09375 +738,6.8125 +739,1.984375 +740,2.296875 +741,4.125 +742,2.203125 +743,3.71875 +744,2.203125 +745,4.1875 +746,1.5546875 +747,5.46875 +748,1.71875 +749,2.703125 +750,2.765625 +751,4.21875 +752,4.59375 +753,1.6640625 +754,5.5625 +755,1.5859375 +756,0.859375 +757,0.73828125 +758,2.71875 +759,3.90625 +760,2.078125 +761,2.1875 +762,2.078125 +763,1.59375 +764,2.609375 +765,2.53125 +766,4.875 +767,0.5078125 +768,4.03125 +769,3.328125 +770,1.703125 +771,2.734375 +772,3.59375 +773,0.82421875 +774,2.734375 +775,0.69921875 +776,2.21875 +777,3.15625 +778,2.421875 +779,1.734375 +780,2.578125 +781,2.578125 +782,1.0859375 +783,1.65625 +784,2.59375 +785,4.1875 +786,2.515625 +787,2.15625 +788,4.1875 +789,0.984375 +790,0.71875 +791,2.734375 +792,1.03125 +793,1.796875 +794,1.9296875 +795,3.296875 +796,5.09375 +797,4.25 +798,2.75 +799,2.015625 +800,1.65625 +801,2.859375 +802,2.3125 +803,0.8515625 +804,4.5625 +805,3.015625 +806,2.734375 +807,1.5234375 +808,1.328125 +809,1.3828125 +810,2.546875 +811,2.25 +812,1.9453125 +813,1.71875 +814,0.921875 +815,2.625 +816,3.53125 +817,4.125 +818,1.53125 +819,4.0 +820,1.359375 +821,1.7578125 +822,2.796875 +823,4.5625 +824,2.40625 +825,0.71875 +826,1.7265625 +827,0.69921875 +828,1.4609375 +829,3.96875 +830,1.078125 +831,3.296875 +832,2.109375 +833,1.390625 +834,4.28125 +835,3.078125 +836,1.671875 +837,1.5625 +838,1.890625 +839,2.109375 +840,1.5546875 +841,2.4375 +842,1.1015625 +843,2.015625 +844,2.15625 +845,1.609375 +846,3.65625 +847,2.625 +848,2.109375 +849,1.546875 +850,4.6875 +851,3.140625 +852,3.328125 +853,0.78125 +854,2.46875 +855,6.0 +856,4.28125 +857,0.8984375 +858,4.40625 +859,1.3828125 +860,1.53125 +861,0.96875 +862,2.3125 +863,1.640625 +864,1.7734375 +865,4.1875 +866,1.78125 +867,2.984375 +868,2.484375 +869,1.203125 +870,3.984375 +871,2.890625 +872,2.171875 +873,3.234375 +874,3.140625 +875,0.625 +876,1.171875 +877,3.5 +878,1.3125 +879,0.58984375 +880,2.21875 +881,1.3359375 +882,2.46875 +883,2.953125 +884,4.21875 +885,2.78125 +886,2.28125 +887,0.734375 +888,2.421875 +889,2.890625 +890,1.6328125 +891,3.28125 +892,3.234375 +893,4.59375 +894,1.390625 +895,2.359375 +896,1.4921875 +897,2.578125 +898,2.65625 +899,3.734375 +900,3.140625 +901,1.96875 +902,1.7578125 +903,3.03125 +904,0.703125 +905,2.921875 +906,1.734375 +907,1.3046875 +908,2.53125 +909,1.7109375 +910,3.828125 +911,2.828125 +912,1.7734375 +913,1.3515625 +914,2.46875 +915,3.125 +916,0.62109375 +917,5.03125 +918,4.125 +919,3.171875 +920,0.7734375 +921,1.5703125 +922,3.109375 +923,1.9609375 +924,2.671875 +925,2.84375 +926,5.59375 +927,0.73046875 +928,3.03125 +929,2.78125 +930,3.375 +931,2.84375 +932,3.03125 +933,5.34375 +934,1.3984375 +935,1.1484375 +936,2.5 +937,1.734375 +938,2.421875 +939,3.90625 +940,2.515625 +941,3.453125 +942,2.1875 +943,3.375 +944,4.34375 +945,4.03125 +946,1.4765625 +947,2.453125 +948,1.921875 +949,4.6875 +950,1.71875 +951,1.9140625 +952,3.265625 +953,1.6640625 +954,6.4375 +955,2.75 +956,2.171875 +957,5.28125 +958,3.328125 +959,1.6484375 +960,2.546875 +961,2.140625 +962,6.09375 +963,2.09375 +964,1.203125 +965,1.328125 +966,2.734375 +967,4.09375 +968,2.203125 +969,2.296875 +970,6.75 +971,1.0234375 +972,5.46875 +973,6.84375 +974,5.875 +975,2.828125 +976,2.4375 +977,3.03125 +978,5.5 +979,2.484375 +980,3.0625 +981,1.375 +982,3.265625 +983,3.25 +984,2.109375 +985,0.80078125 +986,0.470703125 +987,3.234375 +988,1.7578125 +989,2.109375 +990,0.9296875 +991,4.28125 +992,2.15625 +993,2.0625 +994,1.3984375 +995,1.65625 +996,0.66796875 +997,1.6328125 +998,2.875 +999,3.09375 diff --git a/training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv b/training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv new file mode 100644 index 000000000..362c11199 --- /dev/null +++ b/training/bf16_master_weight/logs/7b_loss_run/bf16_full_loss.csv @@ -0,0 +1,1001 @@ +step,loss +0,11.1134033203125 +1,7.132085800170898 +2,4.012870788574219 +3,36.0167236328125 +4,24.2044677734375 +5,26.455078125 +6,20.6685791015625 +7,27.5283203125 +8,21.50885009765625 +9,4.9122314453125 +10,8.48201847076416 +11,3.808837413787842 +12,8.899911880493164 +13,8.652142524719238 +14,3.5635316371917725 +15,6.893402099609375 +16,6.127050399780273 +17,4.923583984375 +18,4.4448394775390625 +19,2.13494873046875 +20,2.0114288330078125 +21,1.949951171875 +22,2.407745361328125 +23,1.5159912109375 +24,6.731402397155762 +25,3.2044296264648438 +26,4.803050994873047 +27,5.274627685546875 +28,2.7762451171875 +29,5.0152130126953125 +30,2.579620361328125 +31,4.108551025390625 +32,3.868255615234375 +33,6.349365234375 +34,1.6423149108886719 +35,5.202238082885742 +36,1.54669189453125 +37,6.46185302734375 +38,1.2036895751953125 +39,1.492156982421875 +40,1.2420654296875 +41,1.4814453125 +42,3.099506378173828 +43,4.7874755859375 +44,5.309574127197266 +45,1.6868438720703125 +46,2.1453781127929688 +47,3.08148193359375 +48,3.3328895568847656 +49,2.7371826171875 +50,2.27471923828125 +51,3.1996421813964844 +52,2.597991943359375 +53,2.57806396484375 +54,4.214580535888672 +55,3.5824012756347656 +56,6.238430023193359 +57,3.6706466674804688 +58,4.710197448730469 +59,2.9067230224609375 +60,3.5498733520507812 +61,0.910888671875 +62,4.35565185546875 +63,4.719722747802734 +64,2.167003631591797 +65,1.9737205505371094 +66,3.7718124389648438 +67,3.2602691650390625 +68,1.2766494750976562 +69,1.5045166015625 +70,1.5656356811523438 +71,3.9177169799804688 +72,3.622711181640625 +73,5.317466735839844 +74,3.068389892578125 +75,6.45367431640625 +76,1.9294891357421875 +77,2.8223037719726562 +78,1.311737060546875 +79,0.8475265502929688 +80,4.85711669921875 +81,2.817279815673828 +82,2.951894760131836 +83,2.5832557678222656 +84,5.453367233276367 +85,1.0596847534179688 +86,1.5508003234863281 +87,2.1005859375 +88,4.404327392578125 +89,1.083099365234375 +90,1.9316482543945312 +91,1.4068145751953125 +92,2.5340423583984375 +93,2.6919097900390625 +94,3.4820175170898438 +95,6.285247802734375 +96,2.244140625 +97,1.8041229248046875 +98,6.160301208496094 +99,5.0840911865234375 +100,1.2824630737304688 +101,2.894184112548828 +102,1.6238822937011719 +103,0.6098289489746094 +104,1.3268051147460938 +105,2.1666526794433594 +106,2.453876495361328 +107,1.551025390625 +108,1.6232891082763672 +109,0.5829753875732422 +110,2.1284561157226562 +111,2.0172958374023438 +112,6.07049560546875 +113,1.935211181640625 +114,2.33990478515625 +115,3.9076614379882812 +116,2.6385498046875 +117,3.9417266845703125 +118,2.087188720703125 +119,3.135406494140625 +120,3.41455078125 +121,2.2973289489746094 +122,3.6464080810546875 +123,1.4052810668945312 +124,0.600006103515625 +125,5.67963981628418 +126,3.4544477462768555 +127,6.842742919921875 +128,2.3925647735595703 +129,2.6814651489257812 +130,3.316925048828125 +131,4.820037841796875 +132,6.195037841796875 +133,2.1667022705078125 +134,5.387062072753906 +135,1.3815345764160156 +136,1.4716072082519531 +137,2.519887924194336 +138,3.1788883209228516 +139,4.956596374511719 +140,3.654022216796875 +141,1.7361297607421875 +142,2.3483810424804688 +143,1.7004280090332031 +144,2.1595535278320312 +145,0.64605712890625 +146,1.0794601440429688 +147,2.1403236389160156 +148,2.229206085205078 +149,3.4716796875 +150,3.1315536499023438 +151,3.1563339233398438 +152,3.2884521484375 +153,0.9025955200195312 +154,0.8674774169921875 +155,2.6904144287109375 +156,4.048305511474609 +157,3.055718421936035 +158,2.996180534362793 +159,2.1611242294311523 +160,3.25347900390625 +161,1.7946586608886719 +162,4.068458557128906 +163,1.45806884765625 +164,4.2074127197265625 +165,4.471534729003906 +166,3.3179588317871094 +167,3.15625 +168,3.4902725219726562 +169,2.7586212158203125 +170,3.4999923706054688 +171,2.938079833984375 +172,0.5755615234375 +173,2.0198135375976562 +174,1.1787986755371094 +175,2.032512664794922 +176,1.7835540771484375 +177,1.6204299926757812 +178,2.8943557739257812 +179,4.579692840576172 +180,0.6011581420898438 +181,7.233245849609375 +182,3.7090682983398438 +183,1.4494476318359375 +184,2.058147430419922 +185,3.034748077392578 +186,0.79302978515625 +187,1.5850868225097656 +188,2.1826953887939453 +189,0.502166748046875 +190,4.2438507080078125 +191,2.8563690185546875 +192,4.449333190917969 +193,1.0401611328125 +194,2.5916748046875 +195,1.8848648071289062 +196,5.068851470947266 +197,1.7050399780273438 +198,6.735166549682617 +199,0.5814323425292969 +200,1.2588768005371094 +201,4.211669921875 +202,0.4931640625 +203,1.26312255859375 +204,1.9409255981445312 +205,2.9423751831054688 +206,2.848846435546875 +207,1.8715667724609375 +208,1.0339775085449219 +209,7.9362335205078125 +210,2.3682479858398438 +211,1.7879524230957031 +212,1.24407958984375 +213,0.6689834594726562 +214,1.9735183715820312 +215,1.6094589233398438 +216,2.7777099609375 +217,2.0335693359375 +218,1.8849143981933594 +219,3.063079833984375 +220,1.0706634521484375 +221,3.130584716796875 +222,1.2679176330566406 +223,0.7518539428710938 +224,0.9654579162597656 +225,3.7058658599853516 +226,4.020160675048828 +227,4.510618209838867 +228,2.578065872192383 +229,1.9994659423828125 +230,0.9467849731445312 +231,0.8559494018554688 +232,3.0728759765625 +233,3.6545753479003906 +234,3.615123748779297 +235,2.252063751220703 +236,1.118560791015625 +237,2.8727340698242188 +238,2.560455322265625 +239,0.967529296875 +240,2.76373291015625 +241,0.689727783203125 +242,3.9878807067871094 +243,1.2500762939453125 +244,3.176746368408203 +245,1.9695472717285156 +246,0.5463180541992188 +247,2.8335800170898438 +248,3.9300308227539062 +249,1.123687744140625 +250,2.5000457763671875 +251,3.8977737426757812 +252,5.036094665527344 +253,6.100738525390625 +254,2.600006103515625 +255,2.4392242431640625 +256,2.445037841796875 +257,1.8540763854980469 +258,1.8763656616210938 +259,2.6569042205810547 +260,1.5930023193359375 +261,4.9168548583984375 +262,3.9106521606445312 +263,0.8489055633544922 +264,0.8064384460449219 +265,5.282951354980469 +266,1.9851341247558594 +267,5.7763519287109375 +268,0.8550567626953125 +269,1.0300445556640625 +270,0.9309310913085938 +271,2.892791748046875 +272,6.195487976074219 +273,2.441011428833008 +274,3.4674148559570312 +275,3.0429115295410156 +276,2.309520721435547 +277,2.98455810546875 +278,3.6409454345703125 +279,3.296466827392578 +280,5.745349884033203 +281,0.7671279907226562 +282,5.163974761962891 +283,3.103363037109375 +284,2.3465194702148438 +285,0.8010101318359375 +286,1.8536376953125 +287,1.6527175903320312 +288,0.6404037475585938 +289,4.074607849121094 +290,2.8751220703125 +291,1.1809844970703125 +292,2.9337730407714844 +293,3.6357154846191406 +294,0.79071044921875 +295,1.6123733520507812 +296,2.164337158203125 +297,2.1856632232666016 +298,3.6139373779296875 +299,3.333059310913086 +300,3.5030879974365234 +301,7.875446319580078 +302,3.210662841796875 +303,4.817474365234375 +304,3.1644973754882812 +305,3.7071151733398438 +306,3.9317092895507812 +307,0.7196159362792969 +308,0.5052833557128906 +309,1.595855712890625 +310,0.5883102416992188 +311,2.079524040222168 +312,2.30682373046875 +313,1.0888118743896484 +314,2.637971878051758 +315,2.823902130126953 +316,1.80908203125 +317,2.781005859375 +318,2.5951385498046875 +319,2.2735366821289062 +320,3.5009918212890625 +321,3.203327178955078 +322,2.8149452209472656 +323,0.9082870483398438 +324,3.644315719604492 +325,0.5691452026367188 +326,1.5577869415283203 +327,1.3331146240234375 +328,4.3954315185546875 +329,3.05267333984375 +330,2.176136016845703 +331,1.9119300842285156 +332,4.3045654296875 +333,2.5377044677734375 +334,2.7671661376953125 +335,2.6728286743164062 +336,2.9331512451171875 +337,3.832061767578125 +338,2.130687713623047 +339,4.605674743652344 +340,4.019569396972656 +341,1.5959663391113281 +342,3.273557662963867 +343,4.640205383300781 +344,2.2807388305664062 +345,9.244529724121094 +346,5.3750152587890625 +347,3.497802734375 +348,1.22119140625 +349,2.889404296875 +350,3.0145111083984375 +351,2.6391067504882812 +352,1.7414531707763672 +353,5.096355438232422 +354,2.3698997497558594 +355,1.70794677734375 +356,3.8087921142578125 +357,4.649024963378906 +358,2.6874847412109375 +359,1.6042404174804688 +360,2.5404624938964844 +361,2.5606842041015625 +362,2.267913818359375 +363,1.85089111328125 +364,1.9569931030273438 +365,2.70562744140625 +366,0.8981552124023438 +367,3.89678955078125 +368,3.4724559783935547 +369,3.019786834716797 +370,0.8579788208007812 +371,3.39013671875 +372,1.5927085876464844 +373,5.293540954589844 +374,3.840545654296875 +375,3.153797149658203 +376,1.0394439697265625 +377,4.864471435546875 +378,2.69451904296875 +379,1.115753173828125 +380,1.3917999267578125 +381,3.06719970703125 +382,1.1477203369140625 +383,3.240375518798828 +384,2.1412124633789062 +385,3.6806678771972656 +386,3.8046875 +387,0.9766197204589844 +388,4.5174560546875 +389,4.07159423828125 +390,6.927947998046875 +391,2.9377593994140625 +392,2.157146453857422 +393,5.8182525634765625 +394,5.853107452392578 +395,3.243896484375 +396,3.44207763671875 +397,1.8158073425292969 +398,1.38507080078125 +399,0.9081001281738281 +400,2.771198272705078 +401,1.39654541015625 +402,0.8424606323242188 +403,3.1727371215820312 +404,2.365640640258789 +405,0.6451492309570312 +406,2.2880172729492188 +407,2.3833656311035156 +408,2.5667648315429688 +409,0.7316741943359375 +410,4.179656982421875 +411,1.382537841796875 +412,3.6302032470703125 +413,4.5171966552734375 +414,3.6554718017578125 +415,5.948081970214844 +416,4.1723175048828125 +417,1.8863525390625 +418,2.03424072265625 +419,4.659786224365234 +420,0.68743896484375 +421,1.9088134765625 +422,3.3668212890625 +423,4.7425079345703125 +424,1.8061065673828125 +425,2.8161659240722656 +426,2.3002777099609375 +427,1.03265380859375 +428,1.9893379211425781 +429,1.6477088928222656 +430,0.9208526611328125 +431,4.5027618408203125 +432,2.3863067626953125 +433,2.038341522216797 +434,2.5567398071289062 +435,2.1070709228515625 +436,5.679634094238281 +437,6.574100494384766 +438,2.2461776733398438 +439,4.967735290527344 +440,0.93743896484375 +441,2.2551002502441406 +442,3.2996063232421875 +443,5.007568359375 +444,2.4714508056640625 +445,3.1668853759765625 +446,1.9462738037109375 +447,6.492755889892578 +448,4.795219421386719 +449,6.624614715576172 +450,3.016498565673828 +451,1.133056640625 +452,1.70758056640625 +453,4.1140289306640625 +454,3.1417903900146484 +455,4.935401916503906 +456,1.3822021484375 +457,1.7961349487304688 +458,1.7515106201171875 +459,4.09033203125 +460,4.515331268310547 +461,3.5706787109375 +462,0.7114334106445312 +463,1.5092239379882812 +464,4.604736328125 +465,4.4217529296875 +466,1.3947181701660156 +467,2.0403671264648438 +468,2.1144866943359375 +469,3.482555389404297 +470,2.2158432006835938 +471,1.4167633056640625 +472,5.098869323730469 +473,3.3986968994140625 +474,3.340484619140625 +475,4.4418487548828125 +476,1.8847389221191406 +477,2.4641494750976562 +478,1.8467636108398438 +479,1.3627281188964844 +480,2.7605056762695312 +481,8.9556884765625 +482,3.627307891845703 +483,3.8278656005859375 +484,1.8056793212890625 +485,4.8571014404296875 +486,1.7985496520996094 +487,1.1248779296875 +488,1.3766860961914062 +489,1.5183334350585938 +490,1.8641853332519531 +491,2.907207489013672 +492,2.8539066314697266 +493,2.224214553833008 +494,4.172332763671875 +495,2.781158447265625 +496,1.2123565673828125 +497,2.5675811767578125 +498,1.9204216003417969 +499,3.5189361572265625 +500,0.8372611999511719 +501,3.3610305786132812 +502,2.4884376525878906 +503,2.8520660400390625 +504,2.9636001586914062 +505,0.9283828735351562 +506,2.7838211059570312 +507,1.3746414184570312 +508,3.7743759155273438 +509,2.39794921875 +510,5.595367431640625 +511,5.6392364501953125 +512,1.721435546875 +513,1.1995849609375 +514,2.213287353515625 +515,0.5738868713378906 +516,2.71051025390625 +517,1.9071578979492188 +518,1.8789024353027344 +519,2.816082000732422 +520,2.8899879455566406 +521,3.3472518920898438 +522,1.2237567901611328 +523,2.2347335815429688 +524,1.4557609558105469 +525,1.4267578125 +526,1.7025146484375 +527,2.2992019653320312 +528,4.513965606689453 +529,2.7763900756835938 +530,2.2656707763671875 +531,4.434913635253906 +532,2.70062255859375 +533,3.222929000854492 +534,1.4761962890625 +535,2.5851573944091797 +536,1.9381179809570312 +537,2.7728805541992188 +538,4.004940032958984 +539,2.4510154724121094 +540,3.8232421875 +541,3.688995361328125 +542,2.46881103515625 +543,3.64715576171875 +544,2.7339134216308594 +545,2.4784927368164062 +546,2.6420669555664062 +547,3.0912399291992188 +548,1.0083999633789062 +549,0.7198677062988281 +550,3.266571044921875 +551,0.6218490600585938 +552,3.2571983337402344 +553,2.7480716705322266 +554,2.1891021728515625 +555,1.7352447509765625 +556,7.654598236083984 +557,5.308502197265625 +558,3.0606536865234375 +559,3.6773757934570312 +560,2.2095947265625 +561,2.490966796875 +562,2.7908401489257812 +563,4.944637298583984 +564,3.6980762481689453 +565,0.799713134765625 +566,0.7726097106933594 +567,4.5063323974609375 +568,3.887918472290039 +569,3.3334732055664062 +570,1.0847396850585938 +571,4.8308258056640625 +572,4.020942687988281 +573,3.4620361328125 +574,2.176300048828125 +575,0.508270263671875 +576,0.40605926513671875 +577,0.9608879089355469 +578,4.364368438720703 +579,2.4404983520507812 +580,1.984304428100586 +581,3.3433990478515625 +582,2.0698928833007812 +583,3.8393402099609375 +584,1.7667045593261719 +585,1.3889541625976562 +586,3.0649337768554688 +587,2.6718292236328125 +588,2.5797119140625 +589,3.1608619689941406 +590,2.879131317138672 +591,0.5877914428710938 +592,1.3176078796386719 +593,0.5999565124511719 +594,2.8893566131591797 +595,1.660898208618164 +596,2.6410675048828125 +597,4.2732391357421875 +598,0.851531982421875 +599,3.4852752685546875 +600,1.6572265625 +601,3.6005859375 +602,0.8543777465820312 +603,0.6684913635253906 +604,2.5104446411132812 +605,1.4377632141113281 +606,1.2602081298828125 +607,3.561370849609375 +608,5.76171875 +609,0.68505859375 +610,3.462726593017578 +611,1.422576904296875 +612,2.0976028442382812 +613,2.9023895263671875 +614,3.3157958984375 +615,0.6492080688476562 +616,2.2681961059570312 +617,1.8462715148925781 +618,3.5692367553710938 +619,5.419681549072266 +620,2.9370269775390625 +621,5.6828460693359375 +622,1.737152099609375 +623,1.602386474609375 +624,0.746856689453125 +625,3.9183425903320312 +626,3.311798095703125 +627,1.9642677307128906 +628,3.844409942626953 +629,1.7015380859375 +630,2.332853317260742 +631,2.3421859741210938 +632,3.2565040588378906 +633,0.8407440185546875 +634,2.966064453125 +635,3.1615371704101562 +636,2.4944839477539062 +637,2.5630569458007812 +638,3.1669464111328125 +639,2.7607421875 +640,2.4143142700195312 +641,5.789052963256836 +642,3.057220458984375 +643,3.2291336059570312 +644,0.5886573791503906 +645,0.6615276336669922 +646,2.36865234375 +647,3.7959442138671875 +648,4.51214599609375 +649,1.4258270263671875 +650,4.4889984130859375 +651,2.864990234375 +652,1.502471923828125 +653,3.0164642333984375 +654,3.56158447265625 +655,3.042694091796875 +656,1.6647262573242188 +657,0.9025287628173828 +658,1.3644561767578125 +659,1.0993881225585938 +660,0.7659111022949219 +661,4.01934814453125 +662,2.541027069091797 +663,3.4581985473632812 +664,3.3175582885742188 +665,5.193931579589844 +666,2.0252456665039062 +667,2.5653457641601562 +668,1.87408447265625 +669,1.9904556274414062 +670,2.2085418701171875 +671,0.61199951171875 +672,3.2067089080810547 +673,1.8816394805908203 +674,5.542570114135742 +675,0.5315284729003906 +676,2.0016250610351562 +677,1.4842510223388672 +678,2.2986984252929688 +679,3.00762939453125 +680,1.9947052001953125 +681,3.6446876525878906 +682,2.6217994689941406 +683,1.7988967895507812 +684,3.031036376953125 +685,0.7153358459472656 +686,3.334339141845703 +687,0.7108306884765625 +688,1.2816200256347656 +689,1.7671051025390625 +690,2.9856719970703125 +691,1.5478477478027344 +692,1.0807342529296875 +693,3.3064193725585938 +694,0.9369049072265625 +695,0.66357421875 +696,2.5416908264160156 +697,5.779876708984375 +698,3.4454498291015625 +699,5.008609771728516 +700,3.0802230834960938 +701,3.5379295349121094 +702,4.722465515136719 +703,2.2318572998046875 +704,0.6578369140625 +705,2.223125457763672 +706,3.0366439819335938 +707,3.2618484497070312 +708,2.4131393432617188 +709,4.5950469970703125 +710,3.08062744140625 +711,2.605419158935547 +712,2.960582733154297 +713,2.996685028076172 +714,1.509735107421875 +715,1.1052627563476562 +716,0.8219146728515625 +717,1.942230224609375 +718,3.301952362060547 +719,3.3530731201171875 +720,4.0272674560546875 +721,3.260211944580078 +722,1.4404525756835938 +723,3.5693702697753906 +724,5.06439208984375 +725,2.8951263427734375 +726,2.647674560546875 +727,3.245880126953125 +728,4.684795379638672 +729,1.3997955322265625 +730,2.2204971313476562 +731,3.5927658081054688 +732,1.5444793701171875 +733,2.3971214294433594 +734,1.34686279296875 +735,0.7172088623046875 +736,3.4994773864746094 +737,2.089111328125 +738,6.775016784667969 +739,1.99200439453125 +740,2.3129653930664062 +741,4.1109619140625 +742,2.208709716796875 +743,3.71685791015625 +744,2.1857223510742188 +745,4.196357727050781 +746,1.5334548950195312 +747,5.4618072509765625 +748,1.7282867431640625 +749,2.693511962890625 +750,2.766704559326172 +751,4.1958160400390625 +752,4.5892486572265625 +753,1.6612091064453125 +754,5.5735626220703125 +755,1.5728797912597656 +756,0.8423728942871094 +757,0.7233734130859375 +758,2.7063446044921875 +759,3.910602569580078 +760,2.0574302673339844 +761,2.21533203125 +762,2.0984954833984375 +763,1.6245613098144531 +764,2.629138946533203 +765,2.523406982421875 +766,4.832759857177734 +767,0.5534858703613281 +768,4.025787353515625 +769,3.3492164611816406 +770,1.6903953552246094 +771,2.7217559814453125 +772,3.6035385131835938 +773,0.7267303466796875 +774,2.7258644104003906 +775,0.664794921875 +776,2.2199172973632812 +777,3.159778594970703 +778,2.4090499877929688 +779,1.7539825439453125 +780,2.5580062866210938 +781,2.5738601684570312 +782,1.1121826171875 +783,1.669921875 +784,2.601806640625 +785,4.244083404541016 +786,2.5022811889648438 +787,2.133392333984375 +788,4.2078857421875 +789,0.935394287109375 +790,0.6654891967773438 +791,2.7285995483398438 +792,1.06890869140625 +793,1.8234825134277344 +794,1.9538841247558594 +795,3.265777587890625 +796,5.0235137939453125 +797,4.219125747680664 +798,2.740842819213867 +799,2.019315719604492 +800,1.6578369140625 +801,2.8594970703125 +802,2.3069114685058594 +803,0.8043365478515625 +804,4.569328308105469 +805,3.020465850830078 +806,2.734027862548828 +807,1.5362548828125 +808,1.3424072265625 +809,1.3950881958007812 +810,2.55596923828125 +811,2.247222900390625 +812,1.9441699981689453 +813,1.72515869140625 +814,0.92816162109375 +815,2.6270751953125 +816,3.5499114990234375 +817,4.100822448730469 +818,1.6193313598632812 +819,4.028602600097656 +820,1.40875244140625 +821,1.700958251953125 +822,2.763885498046875 +823,4.5920257568359375 +824,2.3800392150878906 +825,0.6684722900390625 +826,1.7244644165039062 +827,0.7157058715820312 +828,1.4762954711914062 +829,3.9496231079101562 +830,1.1088714599609375 +831,3.2770729064941406 +832,2.113544464111328 +833,1.4053840637207031 +834,4.242767333984375 +835,3.085693359375 +836,1.6646347045898438 +837,1.5428619384765625 +838,1.8729248046875 +839,2.1006317138671875 +840,1.5426254272460938 +841,2.4358062744140625 +842,1.105804443359375 +843,2.0183258056640625 +844,2.1589813232421875 +845,1.6298904418945312 +846,3.618694305419922 +847,2.62042236328125 +848,2.1177101135253906 +849,1.5579643249511719 +850,4.641153335571289 +851,3.1571388244628906 +852,3.3424758911132812 +853,0.7713623046875 +854,2.4743270874023438 +855,6.051445007324219 +856,4.31634521484375 +857,0.8591461181640625 +858,4.4473419189453125 +859,1.3624038696289062 +860,1.5254631042480469 +861,0.9824981689453125 +862,2.3231887817382812 +863,1.666168212890625 +864,1.8040504455566406 +865,4.148765563964844 +866,1.8070945739746094 +867,2.983673095703125 +868,2.4856948852539062 +869,1.2143898010253906 +870,4.0149078369140625 +871,2.8915557861328125 +872,2.1803131103515625 +873,3.2860031127929688 +874,3.1650390625 +875,0.5963287353515625 +876,1.1540031433105469 +877,3.526172637939453 +878,1.3190460205078125 +879,0.6318016052246094 +880,2.2330856323242188 +881,1.3703231811523438 +882,2.4706382751464844 +883,2.9503707885742188 +884,4.2059326171875 +885,2.784942626953125 +886,2.27197265625 +887,0.7115211486816406 +888,2.415630340576172 +889,2.8975830078125 +890,1.611419677734375 +891,3.3005638122558594 +892,3.244609832763672 +893,4.608329772949219 +894,1.3816642761230469 +895,2.3633384704589844 +896,1.5139007568359375 +897,2.5995101928710938 +898,2.6572418212890625 +899,3.7224998474121094 +900,3.1385421752929688 +901,1.982940673828125 +902,1.7619056701660156 +903,3.04119873046875 +904,0.682769775390625 +905,2.91143798828125 +906,1.73614501953125 +907,1.3114356994628906 +908,2.529857635498047 +909,1.6997261047363281 +910,3.850330352783203 +911,2.8565673828125 +912,1.7781143188476562 +913,1.3589935302734375 +914,2.483814239501953 +915,3.1075439453125 +916,0.6572761535644531 +917,4.9974822998046875 +918,4.10919189453125 +919,3.1724166870117188 +920,0.80682373046875 +921,1.5555534362792969 +922,3.0896377563476562 +923,1.930419921875 +924,2.658203125 +925,2.85784912109375 +926,5.639961242675781 +927,0.7335128784179688 +928,3.0355072021484375 +929,2.7886390686035156 +930,3.3843994140625 +931,2.8586044311523438 +932,3.031139373779297 +933,5.3276824951171875 +934,1.4003448486328125 +935,1.14642333984375 +936,2.52325439453125 +937,1.7171630859375 +938,2.42291259765625 +939,3.911651611328125 +940,2.5111732482910156 +941,3.4689865112304688 +942,2.1542510986328125 +943,3.3801727294921875 +944,4.362659454345703 +945,4.00445556640625 +946,1.4822425842285156 +947,2.4615707397460938 +948,1.9614181518554688 +949,4.649784088134766 +950,1.73931884765625 +951,1.9136409759521484 +952,3.306222915649414 +953,1.6706218719482422 +954,6.530183792114258 +955,2.7719345092773438 +956,2.188152313232422 +957,5.286003112792969 +958,3.335968017578125 +959,1.7009239196777344 +960,2.5676422119140625 +961,2.1608810424804688 +962,6.158561706542969 +963,2.0723800659179688 +964,1.153411865234375 +965,1.3055076599121094 +966,2.72601318359375 +967,4.1406707763671875 +968,2.2024497985839844 +969,2.3154220581054688 +970,6.7754364013671875 +971,1.038095474243164 +972,5.451904296875 +973,6.805778503417969 +974,5.848621368408203 +975,2.8549118041992188 +976,2.46435546875 +977,3.0417861938476562 +978,5.574432373046875 +979,2.4814796447753906 +980,3.07537841796875 +981,1.3266258239746094 +982,3.3192214965820312 +983,3.3089828491210938 +984,2.109272003173828 +985,0.8027496337890625 +986,0.4718017578125 +987,3.2317256927490234 +988,1.7699203491210938 +989,2.112060546875 +990,0.9800643920898438 +991,4.264732360839844 +992,2.1623687744140625 +993,2.0734634399414062 +994,1.4274139404296875 +995,1.6518898010253906 +996,0.6370849609375 +997,1.6137313842773438 +998,2.902008056640625 +999,3.12066650390625 diff --git a/training/bf16_master_weight/logs/7b_loss_run/loss_comparison.png b/training/bf16_master_weight/logs/7b_loss_run/loss_comparison.png new file mode 100644 index 0000000000000000000000000000000000000000..9394b12416f2c3cb3c651344d2b1b3732f504a46 GIT binary patch literal 165941 zcmeGEcRZGF{|Am=n$n_$l2wtJQ4yIbl9|2BNcPCy6(JcFqKs5RRwR3rk)jgWdn9CM z@9*=x@6YFR-|v5azsL8#@AY`xcfEz{Jdg7@Ua#kRU4cpp(z_`RQ;DYXYq%S!$nO86`LCl&e!ctNb=VmY%OgZEX@oVolNX+nb}zLbMx|W z3ve>tba1e}C3NbP)&G45w~f8&DY>j;s`xECY-O}=kx110i2sq9#y4Cgk&#H3r6g2c zV#c~0bq_3UY@cdr*u9CJog!H?z)w#zswT zW(~&`2bh^_cGK$3WLMx{(l*>3SLUsT|NN0OFkq;Q7i)16oQ>1!@}z0M zr|2qjmM*X#yk=#!+tbstwYBxRYy|t%{QRdl5qC-khRkuLK>G)?+gVJ!3kwV1W@k&j zH*ezcShQVSS~3&XWhVKAhKAM*sN?E?{GfDla=O`FeD3$RXZ~;BzO`(6vR+qHBWosd z`}XbDJR1fU#kdENTsq~SKCwuPhAPIL-(h286V9Qr=lAd5JZH}MJ$}6Z{Q2`d7R{XJ z-RC)L9cSVs;tGu&!+qMnwd@-8OobA9lRae`g<}tB`5X0h?%sC&z26MQ`VfX z3ZiZc3FI646nx~z&!4qTb#KDM56}rYkRQABKs@;H8Kq(uJDS6X%NrXHoRAF{f6O93 zReP$e_iHpiE&Y=>AvQ;k9@W&+`dC%9756IWxtvoY^lGn|?D2iYt}~ZjT7}79{T9W0 zY`;#tK-_uPM=1-opTn38mmZzCazoRr{>4SZrE|A_{}QhZ*Lrv3+a?mHK`Gr{`tv{R zPfLYol1oX2ztk`D;v!L!$R{Qy+V>1SdlD6Oa&>)W@=ifPL4Hr;+Cg?znJ5wW+_0#q zuA{)ptTdLU`=@rL6p8jQ9$Fn^b?ey}4?+Od1QlpG1D|-aSbrC8c|V8tK|6 zG_|#FxI8dv$ug?ma^%RfT`fP8m5(tUKYqMonO@Z6WPU+`iDnNTKm@m#=klzao<*v9 z`kVT7Rq3{KBIb=JgCoZqQy!OR<6zaPpFejZb3(#?rto>-;-X7}Vb%S%_48_`L*HX1 zu~SZqat9q?V_u!96O+`{jcHisyY(Q?x=VX;c`jA*Xv5@Ll5b0IBLdc6AFkZ# zG&VJ5)#p$f7fZ1ptTFl1Ir<~vuI4iyE!H@9^RK+FGo8#|zWCxUq)6N0h^g?7v)IUthQ1P*hO(R2|A-H$SG8rkS zD=S`98?K8><3DBtt!wm-3OgN$*NnU$99%UsR5#t@z0rrANWNpo?W82u)wMO{*K*Nv zEl=VyGBV0fGwJQ5L&*w1Wk8>#5ZhpTGdCY3T+Ys|` zo#y-HXJ@U-U&}wUPxrEtm6t!LGder+leHp}+k17|U$ay*%W!cdDgMWEiOmPNb!l)@ zF6*o36yimLjjBUTI5LWyCts)}DagpmZX%tw>82A~Upy8Q6GM6QXt4Xj#PyGNwvzAJ z^J%d5N#DpwphZiDgtW9PtD={e*VNRMxTYqJgQH`8k+Y49n_GE(J(ZiAo6E+!h=-?V zmiLBNaR0M27KenK#!jvqP;QaBbg6%$t4Ood;~b9Rot1^jmsTCQQ!_JnOG~|S=eLt+ zjvP70&AkVkx1zJ)*7a-Go@ZOi$OJ6X%RG5mEMV1s`n=aF&BdTYRY{8R_Bj(cNe}2n z4%bFmm!GbBfw)0NY;&qT~o zpVHTzpl3vFWMX0x^j<$7$*p%kIr(JFdDneig^ns6Pp8})PyxqUvo`P8v14j(POa?` zkD#FD*ii`$4Qe(vwyEi9iMDK$h)6}Ocv+o}rsjul-{fP41_vWX^?v=X68al-`7ui$ z_M4#V-{Tp2MfaBHMlFPoGuQ0RiFtmUd3oaP8J`=E9zF8KHV;~P<22q*D;mAHzUF=% zwdsx&bN!;3!V8;d%bk*r*xAR?^2mI!m(>`xf5z@VI*oG%{f`JjiTTzE9ppGa(~BFtv( zEDl7XNqI%Zp!Nvec^4(Kg3q625>!*J8jhyucc^1WiKj9866{1(|-nNii2cbCP*eXd=<&J-zv%HO#<`F?BP(=!&Q z@7t}+k8dJ$Z<9jLiaKiP9 zdn6!8^7F2LrE00D%r51=ADtU*xn(6XKR+LJ+H4>CHXuZ7dST=#gOH9qo6RI*%QdI# zq;h3u9NaY+;sd>B2n{spiS4 zP=JyO?0%D_vAC9o=F@1>QG6!%8~D`K)oBE*W6JqPytv>c+!k(%HHi zUd$baN_u;F8eHf8G+n?^J;S~C)%&Z*#~;cZAMPm=^T`=luJGTh9TE`mcQXOON|nES z`O=zicdY%r1?!$hYv3*c0Rg$V^Ryp7exz-lUYs@}4t>)bwY!OlCuZm7u3K0Hp=~6n zrK%IaWqNi*C9(FNpI={Q*}7!Kz1{ftoj`7Ae*{!XHYgJXwmiL1=vrAPrM$4=bYY{9^PFT=g0q$HZMk`j^S`SBp)Yy|63 zZqY5U-};zU@e-<1dRtD~&wONO0yMx6rd`>yy*?i zIr~?azFZt}ZOpaOnRYqL%PaBa(Fx@eH^*JHd|PnJE*KiVs1ZUnduumv$zk}1L|R2d z!#)C7wq_Ykqw;;7+(J7N`t!kW|9kfg+jGop9Uc7&3eJ7~_D%WarQHNl2nY^tx*3c6 zuLx&TQ7Uq>YHe>X|L}o?a_{Hwzg;~|^R}L4xhfsNd2-A z&9+kTm6-NKcZtrbM{`Sye?&xtq~|%uZ})?ODo`JYRta=j-r7p@>C>l^IBJCsBX^oo zHI6YeOI|-`_VrIlL0VkPh}dSeFC*^Bs8FjblVzyUENN+J*9;BsfMZ;5cy*b4=S~%! zCoH0#%f9bS>o$=JohLJ#R?sU>d=dld!=5v7U;Bf;x`_mo(!|1SG8y~g#ft;KfjErN z$Ven-XXn%43#OaZnDR2Zy1TncK4|tSt9{F;JuZt=3A(QGvBJ)KfH3?)FVF<80bKw` z8BO(m%)d1ha>=o>x;m&j9r&JjS_Sqda?yNS(Ke&+v*u0u+1CQ(PYu@}!%~{GjytQb zPFFMZkDj;^`QhV7CY4$Ny!p@CS4n^Pj|)_N`ecZGLolGRv9W8{u95ip`2__9pR1>9 z)9~Mr%PkU6uWoCj1;i(j7MGXboXFoTxdVo%o^CY;D#Oy zE5w~=LlIza%(Cs1Ok9zR9-|Ao!$_WxkYL|c&%W~W=gX#6!xk)EhodsQ+cx#Y`B!IjzZqMlc^U!wxhPMGTH>0Mu6U1D#U zS(-6TaGU#c_~?~4h0AVJy|;VE#?qX$aKqOJYa-YOKHp?8Z7g{*whA&$fqT%Py@KM; zKlZ{FT|Wim{(dZz1ks8O~&STHa(@gYH`!bBu1cA&5YVMbf4>3mBS_8 z=uF?PrbUnJn*|CrY&&b!Zd_FQ^>fY8G+IxU$)^Xj#}okF4|SKgvjXQ|waW8#r807B zTN|_Z@%%PP9QZn|VF|w_BK1nskW#FW;}+7Lpr9*}xfdw;oHw6D5f+4#j(ibMqU2rX=&}{q^hktJkkV2 z@&mX?5%Izv=ehBfhUvtKizdYY2$D;G2i2OD!LR$Ue1{!*pHeSK7WAPw+yQ+diI2_m zvldl&{rwfFse!ZW_v_a$x`5@4%cJTo#5O4hky7oB|8rcHPqQ0RYnHrs=7^3%eLZ>45m(=H`Ct{dl);Vj`s6 zZ|AWSC$=((dVE0bv|U}a^Gixcp{Xx%clKENyA=R->hIq>1e$=%g0_Or^`WPS0gUsT zy?}qV-fhR>A52#qUp?8sBCl>~$%zf;6Bb4(7b8fOuKgYl^Bia-uhFMnTefX8#GL_` z6Yuoa>?cc(Mav$D97Io%yl^4QZO(kGr;GvEhz2~M^Jhd&nr^{90h?~AMA--)>rTP^ z!a`nSbj8>U55@Ga#{L&f`Xbm=A-Z<7U62s@1h+Nq|Wqv%tzt^_V2&DV?T?yqT*iAPqiZ|&P(PFDk&CW`J z{8cvd5g6jzGl4AgCU##xzw4d(b|=-|94K@geXUge`QfpuwXR~>B5TvheJg9i@)f_G!Pph7d_abr;{(bT7w=PY&|6*_R^ z#*Jaf7<%-7Aavt)JA}3SmdI7e*H3D1-3QM zRD_OFK2RM>MuJ9jyR@`a(B|9IGp)t0{Gc)53N|!DIdm z!~iTc3iv6?0|$JsCdz*KCgwa5Q@-K7GA>YHW>DsJu*k#;1xmr@_0@%muKfmAbfjJ_ z0J0LbkSbZCx5fYe40CMyH@Ijow{16K+LLZ~T#o(d_u$xeBcN6cyZ~EPs?S^=1 z)BT)UEt4%0ls&i5EP-D8aAJv`*_j2vcz4fHpFMd#f?KRU8#YT%sU#?WpiE>9 zM!bA^%xi5?*_NHy`tl-*9gqDl z8A1*2>+2J=?Io#a>Ib7p8Mn@S$a6aNDy6bgD$7C{ig9q^+E>oDnOulW+C?qRzxa={ zqERYZrnJ7d&~TXSJ_=6C0}K!G2Itv#CpC939!kMs>7TFVvm8g&NoHJQ4S9=?b3gF& z@$D;$wCc>~CDdMg<29f{`$0&ztlbO+MMXoyBFYX9rvXbtz-Bhdtb}AkqYQ3QPF3ex zTv-|Tk#J#gZJ`Gw^Q@&fD%-D`M4Q1i=vSQl@@MIh1P>K@zIavZ}_} zQ9n>RiHfp1Rk4@k14WljH%QHqmFuAlfp=mA?Ka^VOM7^DWaK?%ec_?dK`YDrVp85I z{w<*2p(hOIUEVgZHb8%19UdLsDXwji1ZksS38?wtsL~mqB2^$K$*u+y6P-Y_?6lV? zZ>$*ukmI!894m=wwYkH%QJIX*W4p2L>Bnc&T}%Z`y@fG-E<}o1`P>ipBB$}nW%~0j z&w0ew9tT5EmW|(1<~SO84U#$*o+11_Y1gh@U#Hmq60~Eq&|@jPfiF9~9r)f4ve%n* z6}_G1|AuyizL)QTU-(N~_Ogh}l(>g|6OIbbeuN}#TT-1PNY*jwkdwg)5dIC%shQZP zZcQjFE9>kjabNJo4edZBp$-+$(b4I{{j!hN4tr4S+__&TKu}niBZd{5um4;_iVExd zbhZ~SUmgs}118IAk^*G#Lo-gSy>#*74j?-fckUZGp=Ycl$X|&}XQmyrK?zxOD1)gn zHCRjS$A=EcGv`|dLDQ{Nw#T)I-H zqh|l-=^XBF+6=(g^L@HcV!HBAmAtMtkxYJbiKrtBO6U;f={k7_YL1^c@nw#pcXBe* z-kt9%6cVz}`O-mHLOahh7R?7x;h)bJpXKM@A3`tkbTEGl(|8F%xiwAE3Vcx%ccRdj zx4kp<0}Q1=$AGj?pc_21q?|a2JAsaA24(tCZ9a|h9ik4;oO9}{q z)C(M>zEH9C&&^&kJ$IBS8$>k{e`e?wFl%^0ZtZvFvbz3Md1v{?`if%AYpZ)dyrqGbinxX=AbbPz+mP%WNMF0>vREKyQ;?4|hj-Mhy{k-v<0mVdmvJ;*T~ z6!sYv$D`)-jIGxp-73A1+L1Me3myDDj;1SKy}o=i)%3P*qhrbAKi&a_C7f$KB(*Fz@zzbo5z@NnmpC z^*dn!C;d%^b~fTbZ2>$ZlNnWKUE6D$ZS38tS?rS4d443gBqBoX^}rJ^-tCMqVNU(5 zwLojt)|I~0wQR4meOZjGnsLftz#7sOqJUXNMw@3)OOg*o; zrJJk6q}J7!e5^xF;gZ26lrzRi`S@^3-xzsH01^!oI>4-hwQN@?pn0j27*#WL&sA8z zC7eSD%&2X4Pi(S@4&1k{C?i8zT0|^guCQpsJ z?4cACqE|j z%_Ke-SVWFhaB9gXCkw}AyX&W3q5Rz}5}rN5ml#5KjyY1N(7`M+mkp{3Y1x)d(ZRT~c5?09?~xIs=GydRIJx7^fT-lNESD?o+`!ZOhlX;@euhUrWYM^tvwbgc6;3ag!n4z6 zk7}0RJ!aG`^(e?6gJaG2o|Ed}`S#Jh{s93878e&8d3mYC+r=AUK(g2FfgJ1sxN275 z3gGcCnH+L4s|v5oldV1g`dDB&uFOC*gz5Lu(OS8qwJB?nafKOPZ!|KwODAx^YPiG9 z81|k&&`i8arWEB;?kfX%qZeHdJ3<-#^{>QaW^3(Z@$l zKf_VIobsY*Q9t@cmnO?L!XwW)jnR~PENvpm79Ea^y9GyLU(w~CuVGahtu47M5p97up@bO7P_xkZtdarJgll<79g|dzH$D?mJGk-Oj#^1h8cKEC% zxp?VL+un^LYA%)6yVx}{IGt(}(>jzH$NB3T8u&P#z3j!$+?8Uc+JMroj8nC`Jet}6 zJ_XQHGut?@M#ybpg3I~al~YL>T7Ya@NO0*55vicH*10kG+cj2aL22{0T}tN^yBsxC zz%FHEWJn}U9k?FP){7yWR*(El;ze45kZK|JMzE`uzfnu&F{wE~GV6K;N$cE|tjmO9 z3IO>5tTf{L6Q`sCtB$h~@#7;#4E#5C0XlBc(9pp9zL0syOgNpkFpQ^X+!1tGgjtYl z2+l$W-M|kA(KwY}t+=Fwj09SCNltDTiIIm#sd(GvA=6*zI|TM5J%??Vtdwx4CcA-q z)9sw`D=>$KzI5Mmn(WpacN|NK_quxIC+w^6hp%282WveB`iXOVsU<`2mX%)g5`8Oj z5+L2|Yqi5PJ9g~~U)hBRECJ<-)OWvp=w)E zYqNgPQ|28Cq)U<+(oO@`Nz74LJeA zV#I@p9d|o(Ld1Lb?%n)jhYg&bsl}GzFi_Ve;ra|Ar5RFJgWGtr z!!gSEVsiU_Un!}cFg(D3Oqe2>xoLOrZoW@_Y7eQO+9(A=4m>m(@%GSm$Xb;#=Wf{& zc<4$#JHeD8GmuE|buK_iA=nN9u;azNzt5g!GlC~q`rUGqWOPisP4>S8>Q_+8j+d%0 z82MVBYdck0lYB&MyUdYze|}0D$6E8JnT?ZRiN|ZeacWNup+9ycyRaRdv zZ{EH~RnY12qeoK{UBY1MM~TA$9#>!JXkiIPc4aPc<6U-k&GdNKBmGRDoUCn-g_CE= z!uqR%jNl#gZxE)I1#ilu$MK#GcAEjFrlHf@6`r1EjXkxdaaIC7NE)T<^S39mva`1uQT;GSY_)_z_Adv~C|1 zeKZkTs5(^;i+SKG=U8AA-bWm(#mER{v%J7nr9%K#WzV^cq zvLL4siZr%L1xmOQB(7ZyCu{GMZQdphlU`X_xfQ_8C#N!S_HgD%)cfdjUbl1Xw%*Pe zP!OW0WAr|68s3@G3Fz6u1 z>$h}a^Ob*KP*x@#1em%{$bK-}h=>}YCD?SIw-dM10_7HT8jFfbbDQqp`ul72-CU~< zLLEd2-1I4w;gqOe)WPh)JRl8|0{Rr<95?J$kBb}Y9sW~O6Iq))sDaAX^nLrf_eyw` z*GR(&LZE<8EC6*srf=EUTUUkUCF|ca4`YpZQzugy$;tQb)zDJ-xp0=BYt9~>WB``SATap^D}^j67;s&R(XJxKwTEWc*uKVI$6TC?O!t`myECmf3W2R1rJKwIp$n zo+^ec6R_8vsQmtcfjbaXu;WG}n6nMkx(OO>vHYrNsRW0?iZQN@TY^Mf)nZ{chDn4Dkf>03{=iu)v8ztcR zixnA`sTH6FUz%2s_J~LqLq*R5-jA4e8f&FO4ynHfw0s)>M&xA(QxnDhUPny*kfws} zJJgEGNEuwp=c}HyU%t&~cPuDrt))e*DXVE2%14g5?*1HKG(t;5Lu@NU%LWE-{$l`R zBJT5*Pfi&So*TSrLTCQ|{W~_852R*d^&totLMTQ0z?j|QGV>947}Jdt9b@{B2^n;3 z?16#T(pKn{{&8{iqNvgUQMK=LYj=Q*UI6eVTlLPq+RkwS`>c^=-Na;WQNAo~Z1llI z_1xTClfF)Al{=FoA%}XeFKs6gd#}JV{D^Vc_8U2w5O%npR7T4Gx$(+ny5ByYiHwYl zeAg~1=t|I^dAwF#BKAtwuoRvimzU0s$(Cxv4z0X3iz*2OvO~yu;sN2b{@Z%5KrrG= zcj?&8hF5`$+FEzs!YZX0dY*GJA^R`M{`kW@<1hQB-AZ8?5VWVm@rJr}Oj*soNe+z+ zDYzT&a=7IytEvJso914LOKstikdd10GKXEMU+i+5$Xs-E9KpJg8oZ`v=H{wUKN0n! zP1I8Z)a!@Ko$oTu3~6vD0t?grQJ{l{5Jw3xv1r2g&bVl(7~)#wA4XRQc*E`QIZD0O zC?IPS{sPo-2_+@U{{H?j#kJPbC8#Nc~G5Z!G!0 zlMAd7Fp7L}Ko}vEp=R{2oa%^umTmL_4V#FCL6w_=Df9v|m?5%9md<$ z`lPks$#Zja+l$p0&*mY+64Q41oZNi%Jeshv`7`9Uq7ohPj{EUFq~|}zJHZR9nzQ~C z)~q4PNJ&EzfI9gSamD`7Sb!W95*i}pfJ$ll$3HhM`|aBgQ+*X=uhwqHX>|R)5n(A% z_vg{wk)a{go$gjCD7~!|SDrG4vyj1+idsbPqel#G@R5JRUl0dFjl^>wk++E8&^WZg zuA1zFzE*qU3Ev|e1h2vyo>zwE^2BpbWEC2p2x~5hAG3ShF*FYyZhC%Phh!$ghCpKq zr|_-ca=asXPUG5sbtYDNj{z?40Wp$DaCaj)x}-J;k+6Z@RtNfgYix>xn^#191_x=X zPHck_zz*B{Zb}N9(?n;=l=H98R9UMy163m*~& za^xd|;lZt4YiFAcAuRK~7prdsfsG3y1y zN(}ifyT#Qyonk)`I}ehHz?){q`7r(+M=VenXX7t?}FNJ9`VbQ%Zn(?99?C)aM^lc6(nR#w*kX!8CLADB@~h};USvcbZp zpHj%n5)7a80SQ8|8lm`zcMkIONdTL4SiYDip zE011V>APe96i`o$kRt`$gsmjv&c)jckk?4b-2>`IM3?CJ$55{Y9seAF!oJ(v+xxu7 zA|=x%RnlrgK<{qP5%hI+~Ama6q@7z*b5^NMCi+VvJiEtpiO>AlH z_itMJFFf6H{?_kZ*aO=UkIvG??>bm$4@LCr9D~riV}Mr8&CNeR5+DJKdMt)jP}S51 z&j^)uU!Xm5jy|hWYeqRl1^|AQdJHnKj09`n+q|*_r%>42rU{ zSKYr3=D9PCXvV?LIvhMV5Au{1Q&i;sI&^2-Uu0*cfD%JQ!lP7-Co|6!aLd-R>4=s< zo(nK)WMyIL#|1!vV#Gz=1sNqW)F=*w;r1?TE3F00AcUxc;ZcB#?#91RXFkm4+Lic5 zHDwbA2L}Pt;pd#x$cPbKzXnJ#I5C%PKKvtr2#eg?MSBP2?L$}BQMiTs8MwdgBpxz0 z3>8vK!1hBjqeVn#0U+EOHZL*`4%3=rwqHBvW`aY$PYk1b@FOQ1S*cePI+XjUJTKrq znM+HZ#-Fxl!vQDasy;aklc$xTz!BC>b;!~E-n#b#13zL@^XWE28lb1A*L-h&5J3-i z1=x~BmqsI-K&$QO%7F$6yDvK9g1Yr&UUC2EXsXjXDt~2C>d$I!VfE<3Q-|UFyFmVE zzWbHwDP2&stHD5je`Vzbz=hI|3{_*DoJ0mwWw%|hROPG$bhJjCP+3ld6gtGF&71j4;%wvrT&mPr-xR8S+YJd|zc=S&C;r2fl#J`cJxPiT?*|yz zRA8zRtp$PnIF_bcbt+E~X$XR&5P+^2l; z?$rhcXCmzfyl{WWZP|WF0=Q##cJ|kIwhi)qpKc!Pso?T;uos02(lcF>=ey2 zd3@`s?u>Emt1|B|wfR45$1U=pS}%q@&yl|KHJASD*J%!&Pvzy3mJ(Y(Gt|KiRMM$N zOKl3-x>nM)b1(Gz! zWJ`2}7oDV;g*a$h!Swb-oTW@hG+5#{OfsYY7lO|ysJykGAl|gn)ZmY z+d2y?tA&I7m(M7`Jf6!ABs0AhF*qk2R=Gq)P23vo75SM}-fGw|{dk^IH40{H5-5 z{Y&Y=LWP>zwa90XBVNw?6UKI6rb*9c@pWlB$ITBjHM5vCs_APf)T&ho_nNnPexi#b zH%#>gRHl9>Z;MF_1tcdUTJ;{F5O>mWf7oEy{HJ>ZhH$|xMh?28N0m!GTuJwbyrDl- zwV)M$Yj!K$ih9x?N{>4J^&R&^kb!&9=r)n|Fo-@tMzvs?yrj4|TQ?5|O=qa2R>3{f zNXx52FS-Bbq)P?vGZcu13$4-7ZwF_#e>7&*PELi`BhPw;?SGMH7?fr3=*@5JWe^Qt zq2f={p&s;$DZJr`z&%v@805Z;g)_d^e0PnFpGZQq^tY>09Qq#=&UJ>Ib3qm-#HGF|8jM)5wq!+!Z`ICYCP zyiH;X1cNl_deIYVuTUEe`94)wOHW8o&qDG#;iNU(*T>W{Zb_UE{mq1H35pppmxjcb zmE6*}{xA$-`+-%qmB-_M>ajCo;+ir{O%Ww4r9WfdDDXrkqWdn?#1hlf+lJ6ZDz_LA z#b{s4Qy_;(m-V5l60K!B^2A9Km`NNg3_5NyOJ7``9k~l5Eh;MN8Xg*vO+a0tzzB)(3Z_dlHHC6MTDHr@QUi&cxz01TalN0t^BuOXZdF#17Z!w7s~rHMs_ zwB$8u;{6#s=;}Sik)vW>x+B@qypYqT*(yP>U%<+-H+so7gC0oKhND6bVMq27c1G?A zR#qeQE{cIdG9qeH$;-gS$$7x=35zA^8Y+zJC}I6~QvWRui#1XP8$yPZ6xRH=o|wtn zKL*TB;2YG@%ZatWBB5zvwhAL#zuI`iA3eGQAt7kPrX?d*(pj#-s2858tWGG>NrrGX z`u9GpDR!Ha9or-wfhr%+5->aD4W6Z(Ai0(F0o9#f?h-nmAui&uV`eHB|BbIcppA^= z!`&rC?Fd@&j8m+{EKmO^($y#B3s;tw0^`xoShQ*RIg-RyUKy>Dndzyq z9zQOQ)bp6R?i=7XB96;QZb)4Li!T}Cg@UDeTuq8aGth1o_aG7T0N6-QI|c*7(l05c z^L*yaj7YfUBfgjVG>E3N62}zhUc~C z)~#E1dhjmP4y#}kgzze+yI&(XL&0_F_`V&OGugyAzP+FL$&tJIfN<3NA+gU7V^XKV z(fkjfh~dKROP4O$?|2-e3W*8w^)XRVdO~r&dp(-TbC?(i0F#ndRSlPn#+kl`Do9Z) zGcr0F)HUOe^bG!6_OT>dVy$i32Urt-r>ChxFIt^Xex!z`h_E>wl&6M2&FLI1cN8Lu z53_s(5$4j)A%)HJeEs$1v07v>8YwZAG-LlYikD(x7!t_0x7+t(F@630H!H2L2_=Agh@eaC=I&8@IdM8`txuHeHj^YATTb4x{)nRVa)?SKsqrKLIkCd^E_SK)qiEu zie%hcUJ6;G-`yJl9>m18tTaXto`Cd|soT*2 zLuJ1PCmeG#L^2yOmFqWduxPY^jQ&6eCL}b3ko<#!c0%}k2&oIsJTXp!MEJI_voZx^ zdvk%+s`2Oowcnd30F4q_CXp9ICGc2XqypkPpp*T@EBZSpC&O<-~WbMfa z)fr(PB~Vc0J~4^83DK=e2IoNwd0*=Vd;Wt%${78nJHSn)rX67mN3|%&sENQrkdBu<`2yy^!j^ysDu=0JKlbV7&nBd3~5=Mj#1AKaHuuP#3(m@Y`|OYe>K zn_Y$P7p1Ik-aI+cQx=L&HUL_htdqy>1-H*KJr8HIU-J}1ehDe5Ej~HZJDR33E1GXN zK!{2Ab{~0xjB83!*^Wi9Hz8)YUR~Z#O!9EHAf(~Z_WqWF z-aBA2Uu?{ehz=u`i6NSukOTt2M4#iV{%C5dLWi`495&Wnas)Qd=f=)lmtE$+!! z@Rv|?FqO0!7ZKaHS(2DSxVvpPXqg|*F%jM(ejjxAY51C)I=OOk_tlVx!;REl)kU&! z5i1p#$L=&-=tU0Vs;k^8W;9w$+yy|f{1NIT#&L*2 z2x3$w)1a&$Wgb&-TenE~Aa8`6(00gx?|DnTH}w13WAW>d#Gn)P(4?C=M0%|d`(}=p zwP^%)0mLIeJkLICnPlzv?Ef9sLP`x`Esl9nK8!}5zWGBuArAVgArOzk!LlLtjhH(Q>9)T?gR!Y1P7oy4t80Aug>zyXeTl9R_rpp32wL{ z^A){2Z=tU82nZ0}p*7JC4@wydwQgW07DW*Ki_qRMT1gBIvMR(VwQqeR1PtX99Zie0 z5U=TXIxtRRCdF{n6{)xabtkCT5+J`+I8FpzL)o~SmnT5nE3%kCKLc1IVjB~yg_ty) zrc2gzgm@ZMZ2{WVH$Kiq{SyVMZ(!gq%o`(U3j{wT<_k4Gb$8P%C@8qBEuXW=}`$k@fhpJA=>)`Ssb}>4rcb}TQE1SjE!-KWW0Rp{CV9p;*BNosG1`GH_v~X z$O#8GYk@>DF*1rp@$8U+H_ZWB#P2 zrG*-tBy!MYYhOpcY1j#OPnEn-jKr~PSVJ>yYu|hIoE@j}?-?z2$Rog7M7D+q8Ydbl zb~%|Y6eHEhb80eyuf6+gl%U-U7f~+^wO0HktZ_X0*G|F~O>a&VyDuo19L-+2h2>_i z#fiD{qot*8%JY`J{XV=+%Gn8!g9Lz8^HE-;uW-(jviHbOPA7U}ki$Il@zXGoq$IiM zM~Onl3k$FbA|eBWgAWo(VkdWu1`)G#M38>4@Xnn(zNpu!X++U02cSArSp7N3Jpl(L z&;*j&8z-?F{$r*QuY2~Meh>FL0bbu>h9?LEl>=b-MMcph&RwoGvp-t%8M4+EYHI4> z{y**SQ)^b<%*Lm`Sw|w!7a3*tbR!@R!vZLCChi@GV-T?o+wX|Frs}b*IVPh>n%A1n zMyGu-ohK$1VC78<4Gj!D9uZn!ol%Tf(QD&9N=L_uLL>QdK_0ewK2(tWF_;`xlfR&B zvrl^CoN?e&4#-Kq7@64tib!E{)}}knzT{2FG+ZYNlJC9~vg9y7WNRs@z8LP5WPV{2 zhRGp9k^|r|0&>rHU*IO<3V3@=1toy;Lg)0YVXZ`}B@kV#A2^&~dZJ#d0RVHT@aP6F z(b0fg_b)F@W={EH%x{30FGH537e$@uBZTdcpD##^q+#A1sn`?S_R_~BTKhEoM2-N< zNOX&__V+uX5mn$3VBYIK;78Q6XB8+7kW-`}i2r~NHl0!cm9pm!plBL?%47iv+#n?tsQaAT4n53iMGT^@BJPG3E#2^a0Y=%<7IQWHGM-8bR|TGQtG^ z=qz^K4>F7TDh><#2h!#bXm~ead0h*wTtA|4bSmQK2`@=als6qJ1>Biy|4x-0BTN#E zgAtS3gbRgPg!)vCQ^Wws;4-u}UzEK8+$&KfvFV5_A~Nh~;shKzEMWaW;_kMTqf^}& zqk{JF3HBc`2ZcaZKTbXYu3-2O&LvvTeZ^XBEiIEG@$Afm82=eXR=Z` zEG#UDVL?2;i&(ivQO{UfLPEt)R&F4ME>zLHHI!bY9BPRj;1fQPVi$b>OF%~0t=DX9 z2>h0S9a+#m7#0$;2X(R!p8_G|_~$u-d zf-A>}o55I4eo>JTNSN)dTPiw=@Tlyd=yGahZ$VIj7$+jV3MOVOrJ+2j4;&hBwL9fN zgfrg1v8E_uIl=W37`z=mzGop6vo!nd}zwuari6*HNSO}Uym)YAem z{7u{l5iv)R48)BbLw+BxlUmuw;b8cnkG&bcQG_l+jB5~>3uCQ)5c3viNA{!j14WaW znwk7DqU!J&MiDgFfZARJ&6i7O^SnAvr=Y|q(297P~B;(-Gt zODaxoc)vk3T?}pH-a!uaUUUWM@9)~VtIIr=nD=bwM~9?a;ym_ZGG2C~S`@F*J;Ck?EdTqX zdz9k>?%yv5BgAMdVfdK~)723BzZ}n-7-qsDwGgITzH|{F3n7NBB)O-j1HVWxH_#CB z=D%Lyry@AD_U&hpBa!6dMUSGwV<`LNf?lD+4*VVYXny4JGRAFBQtRd0l96~EhLsL* zsLR#X|Hw0kMuv~SkRT#Jz{6Jl7o?WRtC30g_oaGRAdZ61b-4iN217(E@(=dPhCc&u z+a5xr%0uW-h$*!=(aZ+(3kXyL-_>X62?Gk<6sBXYktVVY|MPolK?mUgH967XinV#s zey#k)rIrp9(VY1AxBUD6*~O1sCDYXwEgyd@(T_c!FyWht>347e|G_3ESCiq8U=fC+o_=*2isKtT3` z7GCC>I+xwYUaW@vKBN2KoT18x%fsVb0|}fcBe7U%!6Mp$i1CawD=`$a^azLmc(=sk!hgx~eTq z|NC3;XCSWWEl?810jMNpnZurLA-|Jb@nTX+(vX?N*ziHT_Mx%VCK-n*!CtZu`ZMCt zIDQ}$h1qSZ?UBd+d1oZEFCU479p|sjv&S#nat{RXC0;=frhzBez9+PF^A0;Z66u1G z5tsCfr;OzIbfS+>-rAn``n&04ykwvI8atj6LBJ=OFMs~WMUwqXPJbJ=_oVT4AF>kn z^3K(ciEWSm#^<+cl6s%iS%scA1*mv4=q{R7b#u7;%A<-oA&U!{MJxXJ&vf zBHGPXrW5ofg|2+(WM;+G*fqpspQFH{I%|%YvBt}FC%~NLdrgT;??joBaIxM2fOAThKCa| zEGz%VR&ub%{`118K#<50B5ia!*>mQ?D24gP`SykFq^~h02ktc=eCxX6m0megUn;gM zougO&hj4I>L9zV>?)AxFsIxGc8OZ}eL-840ZLgEYN&k1*0Rb{ATO>$~;?n+BL4PGT zx?{r?PEG0x*=%MCzKjRxA7=E=i>@&mjjU4__8wMT`%!X5_`o}k7-=eHiSmCujCz4j z&95mfttiZdcP%Sq#q#{;E~x4SE;Z5kW*cp$7jiljeBSLvn5ScFM8@6qyIvoQ78nlx zp;%v}orz4lu(Qs~*4}K_yB8TJ)e@N8cF_cOc&)p>zf9p_(|xg}v_h%8fd1>au5yMs z|AW=Jy&mN2uX0uc`Kjw)&+H|B@RqQKpt`y5YpdmJ%`qnSgIgp%gosH0$dP9V=jSKASe7OUDI`IX;o)O%&3;%0LOLr9!p zsJ%(-k1szrZ@afm^WF9Bf6k8IXRXu!5@AqaeMD$(m6kfQQ=B&G>TqZ233fHxOFugH zMn!W=2%0VHx~!e2cb*XOO}*mk@OM_cjc<<&&D4RNPck$Ds(uvSiE`jN*;;k-zgAdT z^qu*dQ;b;dK}QRX?WD?2pAxK|FxamB)*EYg-le}*c;95SB_*UtBlP*zi>3ua3h_k` zGCcjSH!g}z56`p79Dj5^aIMQfl;*6{VbZn9+=DlJUel6@Pv_b16@ND|b-iNm1=im7 zf|7o|osnt&P+JFo3iqjZsl{-UZ{pS~VOtjy6(!?nw)QKCX0z54-NL%v-A>=$p3lRk zt9CHGmf@eZ|77T8o%F!CVH{j9zsI*zQd5Uo$7%{21#!GJVI&{yGH{yLLc0}Bq_-(D;)9Y%AQR=K)r{XQ&T7v&zZ7@-*UQB8t0{8 z)h`gr_1~vE;IHPtj)nmLnWmddhvbi(l)J?}N=_nQr+LA1iiwFT>AT)1tFrCnJ9b}y zF+l3=dzn+&8}#IMvTjQ8rP1#hqXLtgNN$ccnLKqrevtHpaz14@J+o_CIQyq1os5*M z+_Y80bFz2np$d;%)RP8ZR?-1B%grnOPWh7?G%D*+p$yyS#m%Z8nGZ~t*)c!U3Ev>X zQ2%*_S2uIT$EFWnIgqCkb>l{4q%Nl1!k?{A)vGv+@b9`UF{#!}eemexbQivXfg~l8 z4nY@XX~Pij|Hu!JI#ihZ3NF+b@|! zvR(0c&vSzxX;uaWuWtVQ_>oUXo+84cT1ecIaJb>a~<2k}W7EG&KcW%ZU{D`Jc_8V`{u7bosEEHy-ioVb55 z;+eIvaX`qp_O-zm(d7A#6waCSw{hkcy9@ny>&(|gynoAcA=7$1>-r#F@l=`5l_aO~ z()^=)=*|W2H|73au0D9`pY_@^iJ7@{lT!ByGnYz@n3;m7Eldge<9|p#T`M7MU$0=y zb^QV>8EN;QW>p=<#QWcA{u(im#6JuSnb$U_sO|Q2bNZ2NDe>W-h1L0uoZ0r_;o%QM zk50vvc^p%~2Pl3BFu1|UxoQzt-B{+#iqAp1zOX(Z=+!PA6ZiD#{S$FB&H$Vx*4Ss9h4(NK{}DoJUoL_6)Y z(ISa7Xh?fdNqeU#B<<2(+9(a}`n^u~2f3f;`F+2?=bvAH+&7o2>pIW(`Fv$c< z@jCWe6ZjHy8__l-uE#H92 znPA?nJnN+DBF3U+*+0xH2f?Yekji$&;lWy&<#led2FOBx3tj-`mA${atygE!Q*1n(3)hRT1r+ zsL7c*9SX_Ki;Le^m7kG*Ve^e!HPeQ1{jw)Ti&b2-{X-Fv4xBcV!){?h6Rwuq_mhb8 z{D^CePc`$iXQ_-vhkvNh8R@&7{g=E~Yde)Zy&-_^ulo_r!ORk1mQoqX$%v{yXJ=>9 zL2(};V27U?GilpaZ;-1WIeJQLnl4FG%(wmffBxtjgzjpWV|?rpdd2j*vJKMBot4&X zhj?j7mkhPSdrluV8Hg|Uxcjgp=9MVefzY9$vhtP0DC zik)}<+o1gqFSIkCsI(&gV`f;3BOIJ)l@;&La=nbB$f>sM002rK7JUAdn&*nH8d*;~ z@57)UyZaddiVE@8gD3rX0+L#bDy!t>VkEW_K;X(rHBH6rb;?$>&PF-wPQ2|J{z@Ez zTQsejYA;uQk!(2O&UDLk-YxX~DB41l@~=?Bci$BGQ788)YOF=g7%6%QvwLx-QO_W(d8G!Z~0>*Z&FBT+|g3$8HTb8CT_{e8Bgk=#%~zd9uFGPGZ(FI%bm4~& z)ttl(4Y#rR?!UFpD5%@u@@5!@S1#>WyZdjO^P67UIX<4hINmr};QHs0ij7Ot3ia;E zH%KoOSwDm?!t|yk;J+)QXdhld`C9L%uhxTst%vW3OecV--vH#TiF_G_D zCwrm`Y_@pNJehxC(9SvjrT4Z~LZJTn_9LMNO6j&-hR+kOeXl*sA7DyO!qKJlUbmdX z=Jn;494qn*{=Hp?r7kSEcgNUw$i|bO-S)}j@YyI5N#;qONYF?1M-#(lHrxgt*=f%_ z%V_E7x{AJkja5szqk5PfN)$DV$}U-+^;)kp=`tN1Lsjjy9fMpyed7NR_=B>(THY zA(B9LHd~T6s(^Ho_E`+To!3QLjOdM2CUQOaKZIR^!#sT0fl~K>5 z{?4;e#fr&Sd*JOtsA5yWjaiTl{5#1;{3L6zqbU0 z_=avyPeUp_J@|QtcPOXHyQ8TiUsml2DB;*Yc<0;}`tVZc2bo{{l2Q}zS8|-0))tJf zH(E91!y}7hq^h4*v|iMyIN{j(&F|e^#XBa2YTm2RB;6@m$5R#GExhfR<%#Yo1+NpC zh;ne;3qLeUZ5aQyeB~VXx?RCTuU^_H)Ed-p5tn(Qc(JcLhMiM)O3XK|g}m1)aqq=^ zvYGiQN+tovL{rtqu%u9bdoB)+%2<0<)#LVuVs>rHvCxg;D3B^m)!S2{x-rJV>u%Ar z1}OEFQ0ebK(A~B-A#LFCnBx8WPaCJ4fgus7-apOCQQ5FH2K1q*_@3_6N7CE(d(}T| z|J$Z1KZH$~T^~v!vH7y^*+#zFaqG4*X3Wi1rJ4%B*-{kPk@dHkxD@fCFj)I4 zEZ6st4RYwuh1yumZblgn>C$4Syv?Sm zbx(`}H7ope%np`_h!l@Zo5xNq!pM#Pp!doczV&_b?swW_MJfjtyA)WOEax{AEM}1s z0{L!AMVU1=)gL^zPV2)plew`{8&gKYGJKa`c=cuaSx43LFr^=euWi;pCOJ3=Z2OeF z1`ld$U&zs*^wu^ZnZk%;G2F(j0UO$btdqYLsGXLOc;oIJYFa+I&gY^0NU1NU?)d`$ zQzk&| z+A~i5SPoVpNp|=%-Pp(x*3A#(I7imo+)+%rB^#zn_gwk7RM?m5X|BpdsixU%%f9WI z=PgcMKF^(~T0oT(?rM6i54LzHvg${#^@_w5oi|rdRz^u}g}ekogCog!?b~yxp-DiViT@#kL@AppdmK2~MI8qZy<~eM!7m^sQz6spXyfxc7-nocE z&;5$|W>KSS5R~aQ=5ENgGW5a~QjbhM<`?)Z#gC@Wi0n135;W;){!-(N0@umbGYg3J z{W!>JTI=2}zhIev)U(Bgc z%{Jpgw?3!&AiEjyWp&{jR@zM#HBY|cH&xE(xhyf;wYTpJxfJEtnbOI@aus6RH(N8i zRqWD;jEvNpiIk%~*0suO2s9uW>CxNNCaH?r*R&*R#*lim%C~Q+d7de*B;l8Y!{(xm zW?wf{RdGrV9l!ROJN*m42aaw>-4IJ=wr^=V6=lcgOG&cnx#oJoLdjHh&n^@_yvm+D zTqzNJ#A(I%@2hpIU#w$O*!!%rzSoIZdP3p`aW2ZjR}jKp1g=97#|pHLyhhrJD1d*V ziuw%M^yghSI(vKlhdAZ0m=Y?e|M>|wc_5zR_F}(2CF!*EmXXfxt1Y<>%nZE4CpM@s zoSCtVEjkR)z;F5Ou<@*b@M&FIWLik1^~_PbR%+y~-K^qY7dyC&q?BnPaEB)SM3DbC z6c)(Smo-uF{R`x%R6D@TJH$1@S4Tm8!$$lQcC1^8n-;~ zE|b5mN;QBye4?-K)6WDd3#!dhB8+Q{=t)79KZ8d`eztwxuv z8@?EkI$IGuI`rrjj1oW1!FCE*jdA<}Q74u=PpUU=A;d1xw<#(fo?G7BeJqBI6sA$s zX+69>dNGZ{GRUg<{Yj^Ar{RxgI_#>WCzEbOT+%377qiajrb^oTsf|oGq*Z=sek`iv zcN#dF_@s5`z1!0Q`>(AoLM_TG6kD7+HA?pdxr9qdLP4gB8U-+`1O%qlxX|&B>YX0V z$R%*3J>a{x<4wEbqh>X*B9-~GEFY%qIC1kg6bll&a^1b3giD2LQmnO2diinhV}G6U zmLbjB${R}A`dl-M-0QGrvFcG91NjoXZ#d+Hr)oA)Q@h-3N&K(yg%CMy-!>A?fZ=d(jm8m7B1dA zX&!wg^Ld~d3 zLhKlOx1uClHa$a5Ts++LqMqu!o%K1DM1PfUy_`ho)*k)wRaJK%k0iDKjK+>{^X}e~ zeWkBo?||)ne&FEi*Ri!p5i7$$=J8*f+KjF19_4zZDb+%}du5v3jMMFE$t|060u@TC zi?$i5@VyzFXg2-W7ie=lgx_I7-1#OOQ+D+gcWRy-w{t}-{wl#%=RlB3vFyV0);kh4eOjNBoIF;G%r2%4&W1VhnKorS6*E0X_{bw%PGt{1U%qg0>q{48 zc|m{aozGXjxCeFGEUesX5yYQ6UDWC@LP6>XzkEi5L@@9OH~oWX0H@kkIf*>3UMKm` zHvW^c3K>a62ld8>4k}n2+0U56IcE@AGW4#(*ZPw($q1$j`9{~26t3WqE$cs`7(*N7 zq-emTK=RNPR!n!h~+6*z$?I6WZpTJw2{S^+%vOik9ONiO2y$G6HSOieYo8i+N%5U5k)| zBifx%wo;X#!iOA+!FJno`_;_N{@mw`m8*+eXN8mhGLmw})Z%A3tR`5=8*n$rou*F( zGckucZz%X6tTd(Qn!S$km?uNVw|!?Ob0(=n1)Xj@e!Tkmfg0sImftEjZYc|R1Gai( zvSkY(m#@kxXcUi46!T~ld@;Ga`!3x#DsAnZXM%sQuVXrrz##+%a&~>l>zv6^wd)HU z<9wO*dbi-7baNANN_HL>F%@e2F;-g1S%dj7GZy|qlMBUr@Sf!jM~yc7GYEX3Byrd+ zNM=obnCw-1K5McSN1?JnaZi>Goa^_g0V&t)dgG?syY?@Nf+f8>l~vG^+a2R@ZQiYG zb|V9WyIu-SS9P7ri~Y*9U}-n@r5Uzi*uhEygK;gWAF%3KLy`$CC<;vpfH0fAXo7^f z{Ee?#27ZI zBu`Jz>H*jq7ZhwkcrhjY#0rSTjcz^A`j+B*+S7nLwd7^TK~?NEuhuAqnDr+0VvoHA z-IjMfM!D3B_wu3Cc*pkbWUIe__we*8ugaD`SM9@T-kVfZa?p+}(kpFpzviWx zpZNKYe9|6G5`AqM6ht54lyq;(WR|Q}zjn13L#@jLa9!=Wo&>BMsA2-NI05aT%n#Q? zpDCp|uDa8_<%U7e=nD=m(UBNx)q?9=;o=@ml~G5s?%8?h2p#E#<=5jD1AN<@=2FNB z!d&X))y?V2B>^MXrdAynh1Xtq!nv;Q+4=j;PtKpX-z@$eRex6R2_aIcqfq79*BVnj zzma)zaW>aQP4l_R4?ahFPC9{A_e8)g5x_@&leYcT6dmGomfgJT0?vMQmflX8oXB@$ zc;w5pfs2b{)zo~@aqUA_=JF^MRNqIOf7eBDY(Y;aZKnF^NKeh})~}_nlupb?3mL)9 zkv7-M7g+JEolfSvkkEdk$?I5{M0fUS$umW+Wqd3Y%lykj*X`W7{PiAIkFs2cNwT`Z z>)~3{hxdL}S~D|ao7uNWYqG384Bzr;zI|h4H0#boycbnVCf%5O52!x;A(-QP%QfqW zpsP!fPzPPn&U`c!i5#x5fFrpAYbTtXYIZFdx#jp|4S6h%F3PPVE!wC#u~Q^x$&sdnOuc(*iQtSHysUEQKIOi)urSdEJ zD#ahVrExlbkcpK~iB~nFCat-6vBQ0#_1afAX7EtAvE})F*>1F#tSfoDefj6?5I=pS ziNdwDOy^6~rdpj`Mm$c_GSC&hOJ>cTwS0AHSRLj;Niz#>UN+hO;&bJK-BSae z9Gtu&TLMP=8`Bu^&hj*umW)$7D86`aI9=A(0AWI?A^$9iz!!v>w3ML6|)nLBxK_K&HT zdh8YI=DBNDgKolWeZ4x|MoP=62D5c?Npdu&Rr=_H_U>NDx0F&eI(0diT!@NknVsFN zs)mMf@4Yo&Xk7--ia|65QInGq1tYq2D6AnN%YMz+?bH&`DCUc_sqc1&!QHH^ViBWg zFeg=KGe()LW_YtSr5jg79(EU1F3U*oecxr;NuHo4DxXNih#fPYLx21VJZSK_ryE@e z`A?LN5L)+q0{ix@hIH5)xTi~~C_&4>veO0w-e#siC9-?n5Rh{6algpVrbl^$kFxycFEq=V+v~5oj0IMQzKS zwH6O$V%+Q>bi{5S7)@29)|mB+t0vtb($2}5%x^Qk)_7W=@awUK%*7GTjhh~h%(`?w zI%;~^%~a54ikvdQsPIi0ANA<2wHN$pVr3J#XC?++fbg$qdHPX$Pu+pv?flW|ZZmzD zdb>suiF7aOOI38#u@WzOa+M$v`i20sVJ&x;i7`qF6`K3DiiMYV9gu2KcF%kBgBvt< znvQ-rx$W3>GCt$>QumGu4(?jR;8~-YJ)CPlq1dcwW(t=Jom+uc}@#$qYZKvk4kYS?2dw)yXmv*dt9M4Tcha4@HwU z!i8sAOwbS zuQf03yU&hPCHb)*Oyv*AJKfiZa#&X>D{Xx#7PaPMto$>2@ERQm-)Ih~d573}cg%hN zdd=)bLu`I*KkfLusbRF1Ql{bZD-{|m&=wY-xJ9lxfeqNTbB0QYfcjXd&i*z_u#x^M zhmf<%%PZ}utm1m{GF6yiy|w*n%F>16m}^}hP6i8(GwWw5JWz=?cfT;TktUJ*9{BR@ zh%o5u&CAZuP1F6ACOqG9mwD{rrzV-f|4KBw_8d($N|d7%MnGaSIkYl^-cK$id;i(s zd=G?QKAmP;4^3-*`lPq_>-jaa=?V7@s4#}gLtI8(_k|^0N(Q^zP~T*@vY#aI--ejd z(mj{(UYLLC-O{sLk1dw*Um;aC)QGk}VU&vqq}0D){U&8=()SvVclrUGJU>ZSj?1#l zzAxDiopo9mrE0QTAR{uwq0{?+k{xA6P1);-S_6~=@H-hY+@X<>8fW6SUz3-6O1qhF9R zO>1t9e|A@X-*e@|ltR}hL?+@7funm~UjM263NfK=vUk@SH640Ln&9pl??}BzEW{J9 z;7=seA3WH;Fei|(>uE&~1AW6)ef^tM`};1vu2?RcVP#26M}H;rgG;hd_b^?-`(3NZ z3bXF+ydo_lwdPxs^%>2qH~V?XNB}(QXV#%FlU$&gySPm82EW-{@F~;K0vpCx%%Za4 zN^;yrs)`Xh`K!)QS6iR1v?BdE>b^2tKtYp09P~|RCr!_Eey`c8+wy`Jj9F;kP=L>F zE~>S@-irZS<1-^IyXoh?r?iYW01AC>ltR`t%t*@GwI|y`cLhrVPwZ~1FGtpr`WwsT zw_ltQ(EdBOeOhWwo6{6INp`#1P+o3_dE-`+Y|;fwXq9MBE@+SBFq706)rgF1BBKh* zM^0;Il{ZgzBTHG8dE#L3k(awvI3Yf%*&D1wa^_{avpO?(lx^f7sqpL-RdJHdNCB)> zIM?{ZmQ9c978{DHg*q;}4M7UMG%VsOl%{Ph-d@O^ruus)Ki!3%|DK_YM=6)KsOPN{ zNn`V?jK3an+`IdGP3rd*@~-D=jjQP@o)q{8Eta-RUAxrxew&Hf)Y#j7rzNCM*od)y zYojcj?U0Gph*`O?(A}_2P3Fo06+i#}al_^_#xp>+B+{-s&HL06#crn#oG@v>nt89d zv|yq$8M)kFLS{LB1znm;gMr&&QLsmWT@!7G^o~LQ_{NfKj!{?Mb26iv?k@nzs`c989hQo1O#%v5> z3>j&+r1H6C&I5}-q|RFyx(zj^?@zF+x>;{aed*@?e5q}>mee2Cj{VBK#l>lCt8QBNRSmv*yge8#fnC8InJm~(34S?h%i^0bzC5=oWKulNSz zGud|uyTITQg-0XZg-l&$i~<6cRW&=yhP)SlZ3W%!Um=eLYjPsV_(|5wdnW&lDcb}m@F0i=LG+K- zRhlL3o^l@&DRO5l6_gT1n|lfGmH7Jdl`CsVu}b;R!BZUO<`%9negZz8@sLK&t^k$~ z``&!H^Hi#|{@Sz7H5n@zM_`Z$Rvjam%8C{+ z-b0i`a?W}@^T>Kw$PLZ#+$QIQ8hq}zAO6?a&hdkL{~>|MNd8HUK7B*%69tDv7$yB= zk7^v9H-8(T>Qx+ec{O*e;wEMxS0vq6Wn~|V!x^niZ1$#rTa7F5Djaz=@8|!}Dq?1# zSUhU1y+9^NJw9^OVQ-pt#Rk%f6gxTRzXyRnkl$xVD)dymaXtSVatO;vsLNmrVP<;5 ztJYM_!PnTN1gb-{*vvug6|+zTN~NghILUW9GC8+9F@(J|^n0(@nzQn_<2?urx=vd& zA74fS`E1{_GPcy4A-J3y4%wI5+8@^XTbsQQcQw-? zN}Yh-**>&8O>HJUReKiAHuNd^Si;1mQmccbtc^1VwXl(eZa)1=@yE9BWnQzIsyKA< zzzR~TaTR|>gNcDW1vr3H78VB_A5U$C%+{OhDpuPHvAK`{6hoO)kN+`?@wKW$sjuMe zYi!hVCK-#1#bvfcBQtmD(<42b`QC#tsqXZ3XyRGa>^C9GL*R?veXie!A$+1SPlUQV zBd+$>$5CB4^knsIF*wTz_v<)-6_55ev<|-7iI&*c*M0YWb^c)#^i0*UwC!eV-d<9_ z-tTg10zHQ3@&ZGEd?%6a2ePqDX;T=t`CSu-prz;NXv-T zB!-MHZwAa2NhB*P*W{~ZhqH99EFK}_Gifiy>AN;D!$=rpBz{c$zKNowN$X9KkX&{C zqn5)OJ0pl_7`CU5KfR`pJz`O&yW_LcM!)@jw{9nI5F|&8xXK&z*1(0{(HfdTA5Koj zuwX~ySi=-OU6gDepip7nte%>^yVZmf_d9BYRGk_DWB^kE=Ut?qEubG`bn_?J2V%Xy+@U!& zwtIrBO|E_4U)wCU@CZ_eEsNzJ&UC1jYZxRvJvE z2oQ#9Ap&Envhhc#4Mo!)j%7b^@B^1Z-@_R8YhxcBvh+aDZ zbONfZ9K2a?7B0{htoUyVUKS_0#Xf~BhQ=AJQMynm0yW7%`OJG8HuUWLtJqr3m}uz# z){=WSn#$lQa_XpoZ;$(@6j-9o`}-rc|1QAeU%9mGd4n7os#+b~KkfDZbDTU*Z~1Ha z`<+-j-d+}K7y9c}3IKGHb7bV@gW@9NMv`aFEMwVF;Wu~pFS~RTjhQS}Cv|lnp-g?I z7d0TLD`mRhUkH)R=FZCun_R{jwZ^P}M}grVFr~*$tN{%Bs_mSdaz9Iq!5Xjb?+=4c zh~B$#`+I+XK}ZOrp`qb};2p0XQz{;z*xRcBXMI54=&ll0|FGJ89| zK0M7zOo~941!cjk3^x5OBJBz|*gwD-BD_N8@?~x&AqTQ{l}YL?LCc!Mf1pi1FunNw zhT{=SxDn9HU`{|0pmBs$(QaQd7GOfQeZ{F$Zud`r+P7Ut!|a^$WSZ^FO67ZEd6oJ(V*b6DTC zv3eb=Hlrv?qsSO{_R~ssbaar=wg=%?NC}HI%k|LirWlwlc1-?rkuRgLYDC8j8lzm% znFcWi^D`pvH)!PBlabJJr+Z{0R6)lv{%aZLv>Cbk?EkxNEDEV+A5N_DQ^1piF->?{3vIqX!o*o`+ zsQ*_l49xDNXu7v_Gum4%U}2Y_hBRM{oKMHj_pgUS|k59b^pU<4;*`a zNc_)Z_3$|VziTDQU%t(J>4rNkj|gK@{8G_hFaCY}|A$LjdNu$3S`Qp!{nI@3)7$;K zf>mq<#U>Zu#npzK|Hmc43eYUApc0G!zbdBt6F7dUZD_unGkHoSQ8gW zEU#QkOM3W_vK@d%4lT5)yU`a(H1`oLNNARO3UUq-ZI{|i$p1PYGy{6a zoO7)(^nzgIz+fBCfNmp(m}5Q%1=>_^JyCB;Jc5zqH07(-AsGiUwqU(i5;Byy%+~{yGzUl|LVQOHbT8T- z-66pP=dcVPTQOg_&z#(fQ*z0tMV|Sc*Fc6kkGMP9LCU3v`iNhgM(=>f06{gN$#RI1SfF~+-Iq?Gc)&wEI^Ef z7(fN(z8dMhl~A6USzMHmb}_baX_@o=EGryS$kc}3z_)5d)`{+~tkcf~)ERRmBeWvY zgD+glp+%?h0d#5oljI@?myGt(ZVXDTw$fd)&}e9kF62vo38`R=0VL*|5%YS8WF&-i zh}k~qf+j;@x3p>7IQ(?lZe`YTuPrM&fQ}#1Nd?z zw@#9TF(jNsLJuycDbkTHp?PXe4|F^zNtjaQfbB%TiQGtBc**LFOS8$YE`nNV7xayg zA@&P;;_FMa=WAhgICyza;UnH$FPIypBwd0UA|WWlk37$iQl2dj#(R!noNRCT%KK~1 ztXyi9^dx1P$v#Ixkb_W}g)#s-es4kV9l6HN;bCvcVW13@(0M@@Ze->eOok!k2M(y` z5`E&g+}&3Z`o+B;=+I3MeFzMym9%3xx#X(Fd{?hvJJFBnbs+Q!(HdetHa;ARN+fY{ zaYFA0&sqe1d0YHdQgt_=E7}ds%fw`JqLH2WKD6T2LQjNJMTAcKnaI*}qO<1w%}R8) zIboiazyAbg-vtCHeq6cq>&1La*VNLw{PofbM%d2~t-rah_y3(QX(vmkS^CVYH~(K* zUYCEHV!VLWlaXRP;D4Q=we?F1M8OC$!J|vXb*q*NUB7adPvD^N|#% zuij9jNuNB*)i+t1U3|7Vq%q_rSKpXpUS){<@M8gUt6tSUyLNIH(yt$9S8D1Vf{)K{ zS@!FdDG_0cUwa9@T~8lQ)Ith zwcpM{{_D@A11`?gzyI=rdYQ=UJrDoy-@7lHcEzvX9&!EmCkfVwr#(&5`}GDA3yWB= zc~jJ#Tj%OFh=~4vS>kq;F^T$u->S%_@R`s}K2e`PRvE4T_0Ph-awk0aC$!TpiSqB2 zP&iVh_xrwu8U7DzbS0JH*T;}5cWfACpSeszSmU_4)&jVR=%&OvV=x1JI6=kFcq zrA-=$2{x1b)J*mJVqES-Eh_@2*=kmHFuU)G9kfQcj$v!`@g_^e~^5Iz@*1Fo-Sg3rxPdFhQ?NoXY56bOz zni+^`!g!zfXV2=O6Q^cTb@Hwdt>+o}!*_Kvem_kXeF_Sv@fPz62_K?b8N@@<^x{;h zPzw{BzlzY}bVVAbND@E$~1fb|FXXpHxVwhgdb~ zJuu`|zWaU6FH_>2cJV07hmRlYF+ft`{P{?`;XFm_b_eUiA?VnuAPlKQvD-waJM+qI z)X2Wq^^Lo|FTx*Iw5*()vaS0Rd8fMY!rT7cjYSXBmkUc>Yg5!ueNx7qU))rIP9y@vRq|aau<4Nunn`V|+P=gdzVI_HaBAOY%7OQvN zsve7d9dUA3V6@&KwDJBuqaVrsrO; zRHA<@74i3E?|9)~2Eo~R19g!W#in1Y;GBlPE64Kw{!AM2KfnCbE+24PdT-<|FKCzj zwsn_`fFjs$Ge(`=`In>cbfz|aJ0esTdGq(NS7IRB-sOC?zd=m+^wx-9ms8la#BvZG zQf!Vh_KzJLamNRFaQ<;HK|12>nAMh=o``8bbEx$j7ShWqIb1Pk!o+`W))33Gp*Ebbz1&+JH62DBE z?f3cwI|$}bO(_03WTYpZ+HvGD0 zd-AGZ=aR&Nb7{8ZC=k1fpZZMylH*C7?ca+?!0z9FDHOq}IJkSs(%@M9KH@*|y6Jll ztL49cp^24vD1Y{+b2KhcYv~0N$Cng_`hb6a2(0`2T8V@E_ovf{T1WnR>pz=q?aroy zzg{8fKK$2jiv@>)%=ypH(|v!o{wLeWlwc^!qjV=d@-SW+~Gh`w$KchdbL0w)%y1!OlrVg@)mOmP=4ux2^-Jbhx%Sj~Vy_`){9w>nS zKjS?f6KJ{G0&LV`$QmZdo`M?OLssGsD86jK5Cm>$VZqCZ%?uRodds;&`eFEg_7{tA zwBCE}Te&M&ksSF$&Sor@ausu2POS>7f?ozqoXB378OnkTLN^dru&ZloAgQ}2Dk|z^&XGeuvE`l+3m?|F=;#tl z*$WB^12r;2eynPyZ?>opai5tTc3NZzyAZC^bmi!O;=|OFKpRjIKWab4(aGX`U+VN} zSqy1O4cfqo?*evt|KY=VJw3hDpl+?WHAT^1Fwym#179Rmx*o88{gf=q!^1NOU3Ab* zeQ6&d5ymP38S#6nqW?J{EO+0(e}Demxg8RNRT1tGPSuU`oF)}LOH4d_&VgyuaPZj} z0AB*hjgFe>T)hgE8Yg{z<3{R*{;a+y4_T#y`Q<<-`b0tr@U&0GHCZ2QlBBzL*KXOe zCGPq2GXnzy5A+Tw-(INx&lft;v#DYxWjIjBsVp8+msj}}^I50kDUF7Mm0`DtUw;1l zfL}nsv&cvn=NAbGK-P0ZNOjk)T?4oGJh?{TrT|t@;~$5!0aq0vEnP^~3FSqHv+XHY zz0n91i6K!eDkjoY)*KQ(3GuS0)RI}>ojL2Da=6YxvnI^trB;KeyMV?^)Bfcii6y|o z9z4Otq}NA3YLm1svVPR0nmsKgWv(zo5{~8Hc2AdYI~$&&g0gZvkkzi;EB)p0>I;Qy zrB272@BSCSyyMAlZ$eE`!`)+XREO3Lf~yg&+ejV^Sg6BmMbE~_N7KzeS$9zFlwU|l z0*HG8;ZL^NC4I}EHO`nOH27H+E?k&t#W-zO))g&z{g5bjIBODn;0yjB+AHe`a-n91PEJ=&5kAswaZwxS>Cei#+&49&(nSXGw z;q5~tGBNipB`k&qF!0t*b1x`nBn+tf=jcCv*Encr#eU#G?TW3cPa&>nfCkLKW7dfc zhmo1M&#I%NGu=X5(;ZLGx-9dnCCqpCRmt@JHnrhze8@BrKOV~N-vPNrX+oK8e^Ty_ z;*oD|W<;ndAt46{{4ymauh(g^)!HX=lgx3>FGjx&AxY$w8JVWEv=_$aM1q7?-l}dB z)@IX9Rj?}%dwiy`!g05RGUhQi%wev?X`e$mUtW&z(Ec`;6x#tx5zr`VIHU>AkING_ zOBpjtQF2u0$ZYY`Fvtg)8$Fqw^ zg*LsQedgfcV0dfr1&Lo^-WF54cGa!#vcvypRgEDtns~x-55|4x4Ckv<2uuc_1&R)V z7IJ(-0%1vY0;6HIED2jTxQ}J?ncrBCnw~I);AC^7V&M>{T9~>pJc+?w^;obdC=>R>(6Aji zaGu};T1-pz25aN8;Z^FOzLzwC$CidE5)Py9_9R+j3MDu6{?u#-qKL@77USYBWU$v` z45i#~?$|xz7(yRe+fn-YSt%nslv1#MO#+;K&-kyxs=Jk zDhuA~dI$^60i~yCy%}(4c9b2Ft#)J$sk7ILX}L3P@sYK{&%2MWWbt>M z-PUVyG!2;=iE7-dPRhmjP1!Vh6cmuSOoHraX7tw0n+cjfJ}&M-`IQ4qq`TS6(7SVA(HwVR@Q2UpJA4IZ37jrX=$B4gR5J4PH zR})4ALC2Q)<@}Y)mt!L%#ju+PF~|7u;lsW#0iZl2s-gZo{%%jG&DRvD%EKu8uv%i4 zIdP^E>+4VJ#pX((N7=S@1tpCx&g#*=VDs_DtD&x$ zIyd6tM$8Q9`qZe4dFlISf2zkW`gtPJC;UaJ+ z%FwrdTXf<-M@rZ*r_+rg)pkO^p@7+O>I7~%yX(&W&lX25K>@@=A-Ab~?PElKV$B&Q z-nXw|4y+#`9E#ZVcoKX8CiCryl!a)ep@(YZ8;PQ!FFEl3$Z)QB!kf0K?Rl{BmJl37N zry4TF@6v`5BFrFeuzwn`X=z7!Bd$}GR%;cj};Xk0T^`GSw_O<--Ne-Hv z{;YNN^|s|2$y%~EpxM-K~1>=k$s0!uKO zf8ZpZ!FX?w$&BhAAt4R)4pl>a3gS$8&}+ z!bHyqJOP1`!sUcmH}RDqwOUO}8QIv_Msr#*(moHs@aWL``$5C zHRF{LSSR9quPv^O3Q!;99Zsnbe5gsAnEVB{WmhqW>LUFUmF_;@(9k5X3;O8fht03T zv9WA-n2bwJ5^j4JAeK9b)eaLEYjy|{HZ!+G^i$!K$A~Bwno<)E)_s4F^sLw#_5bkM-B0SiQ}hk z|NSI^u?6c9iAvyap`tiyX89hjxT_JRUp_91Z!O!6lD1!wwqhM2wIN0y7Kl1WqPAckAD<#*iO)gTAKJ=$8|LspGI}W2tl@?-HzWTwVF)^($3GGZ z=;a~D6tUnnaA|>Se;mhHVP%}iEgx*FtKv)HwN|95-}><33^=z_z#&H$rn&*H(r|OF z_G9JPyVVE_EC{v@d@KDCv&e1h+z)ec@jwtCcWib z7-tCfqfW&nFAD0sN--f;fz7srVZg|qeQ}P@-vFV=0YV?O8F`BFCB&Ihv*~p8i3ID$ zcH~H7-!Q?w88F~9vvM`)vsfNX1FAplN!Juej7{Ul$HlVsGl7! zNFp{a;<5&)ap(HaBIt<6c~=BwUJVMKd5(Si`zDJOVc5}VG~>n$Wa-C;PohoA+v?$e zwo&2B;IsSK*##GlgmgM2tf8b#L)#VcE+1~pcJkVGq}m{haHKd~1Tl!JkTi%@gCaIS zr2?@zsQ6ks>gGRVp7Q04<_UmoHguGx7drC5wc7zE{Lts;nL7@*;pgocjioY0ueO^gR1LeoTyy-v?7hoYXaI)>vRQXP&ZCs`G=21^Vd% zh>}*V+g3@0aI@cXWN_+j5uGk~6!b}ToVO)Bkq@zLR-_rhfb%UJ8YUcW-9l>$eeL6?&J~TYv7KWnhFSF z2yFNgZ2pl`K`8M-g%j|1??V|uJq)N@nfDbg^po%9{OR5x;~LDjIwNrf@ttsU7Hpri zM6YFoTA#V2l~sBpTpe-Qm``sZ-40S>-@9%6*Pb>UMG(rp>oGQt;UPYAJnG^;s;r;TfbraG3(`&eL{n{4nl94_^!b1oHxRtU6(ggB#}6NN+j^77XSFMI^XC6(G7#)5z-GqnAmJgH zaqt~F8%+z1j?9nOEZ4U)?^Dz$ddzq-?VV4)17*noCd)kPu}LF5senW@BHF!!=qN&iXy;)0#P=-ZuvM=j$& z3QVeoT%r87~UrQGbFEc-m36CSK>osv0dBy?a*XaS zpu94|b?YF12x$>k{oMR8DtH~$hX|MA1N{GrhzSdqug9!t33?&N+#w}n{zQ7_GU(VV zB3#0-(?EMC@}2wlOJSl?8a!r<#=>ZTU_chmH;!*4;LHFTUe(8rS$*k`pEHsy=JX2+ zYR*0-E>T;`(t;Ess~-MqT!SFAK+4BBfBH(pi3KF|)V+_7dJv%xa!n^wLp(OPZTL#_ zA4C7*DT8t53E*|pAR=R~|<6Fjc4WGJ~o8SPb9SLc&ky2* z`=94$QG8EK^GQRA;_k`jDW$VWL*-4jJ^N{xkx=A8ffH!O>qooq5)|MRjhA~CAXY<# z&}?JH+@Y&;k9J+w(TPJiVLWzE;D^#KR#xtuBsZ3vPFhbMx%CudV`F|0SO1t|we?Pw!sR=0`k4iDykKAxu#~2S9Nd^HFyT9R%`31Q&@3Ye|9*Gnpc(;!M}Yh)i@t@*J@{QL0u$$)5WZ z1VJX-S|ajC-50@>j;BU2l_c_48obHQP>&}#UtYy zGcz+^KffqssP5UNtxxVDxuvP^boVb6aSeX1ty+2)A6f$O(F^UM<``|;H$ zRhW?mdNl%uTteKgItX;sx6JN!NfXESuVMgT2?St?M*K$r#!!bJkK63Wq#dR@-Dglm zIsyLU88Yd{4egklsJ2DRNLuvRB5c*b`~6948?6{_=D2ah@B|J@z6{IegyK`j*NNeF zjXh6n!Z9)`33D}PsVFd~Gp2pvM+T7-#aJ3OwS@FtRHdU(F0**IE5N3F()s{{TNn}U zAhtr&$9;?jOdvvH0wA~94jm*|Si)jJ|7w}Oc(G}%Y!%&Es*1R{j?9Cew!)vVMYhn; z#3M2V8Uf7>(CA}PLU(mabpUG-`XDQO zA^`=O`)mQ+jjWnshUze)j)08kAn;BY)fi{w%(vHo)eF}Ez(NQF zSlWO%=Vp^*q)V65x0)F5{ZmhP4E#?aIELBL$2}3jBr9XWrSU#|%YD|bpr{oT6a-6C z*7~`1hiR&oC~13IEvI+9NOdbgazY|Sla?SdTKo_0+b5@|7pKVl{@0;QRX~-cKP$0& zj4~74!XE>vC0SzV5mRNRg%?b{E{10KL}JfnMe~sxJI&OD-ZZcw=SeliByhy(tMSau z-0)m;ZZ`QVW5hci^DvJC3MQLg>UiP+gZQK#`GIgF3N_X0)m1Dh$>{hPL=xbkUf>KM zU^uXRevjAsl<0dsIA7-x*Q&`jGXQ-&GpH&lH52W?z4e&qLEAweVSRjJVm;0db*k4Y387 zndtIlm-I=@`=Jc8mf6d9UIjuxWe1fS!M6}GZH5kmufM-Oa2>%)<;H((b6Rw0$TT3> z3WFJfgTwI~ggLDNoEhHIRa;*_2z;k&kEY5jj7Sor& z49FG+c?bLpmIa|H4B_{1!Sw3 zZ2PwhOw`J&cdrFsU;qb}J)A$QJZBHCjEKmIt)y4{fb)LH zJB++qr9LYmBhvsslcx1%jRTH7j^qUxK`_M@3)6i>MkaeO{#kB4gMe)+!7vxh^|0Fl z#SDf&=dt@?=4z8^I}n zY>UG|B)$VFj6}i#LmI|-=#wEYv5)}s{zSrqAR3Sog6XnQ`347AhtOsm;jjS)=QZWo zW$5ec*P^0|cu)cAws9Eu+0@mCnT#S5`eS|}(wrYkEQ34(LDWe4U@A9#kN;~ox19Z` zo7(L+{7|2A=xXGIU$y;Rxr?cChZ!e?tha67mF{^)l&o-l22pWG`@7=qKuaG7&KI7~ zX&;b|ZCxxYbPb4_KEhVnq8ihFWvujLL5FN$je>zHappRdRu$C%l@37n1eKMX+)8^AS+CF2R4l#Lj8yjqGqt2_taF@ciEEG8lN0K)WW z3;Lbpru|be%W?CKYZoviD5N z?6!)?-g{*$N+H=Zl9ZWIR`&P2ddk!1`TUOG@jbr%@f^=#+}`i&eO>4II$y8XdA_lM zNSK0~4zj->c4A65a0?3YcA#32fv#MBd$w|J5%v~DCIrwRzehki?tHeD+$y|&hCk@a z)A-Lj77Y#cu3dG&en{eNKRS(<%Xm+YRX#9DuT~KJ58}EwYUf zrz@`vwblcBWzWzD&LWSaqS7Gr{9>Ed;8xGnbD{$Atx?vUF^Fo+GH_5HI120tpq`G{ zCC_{)mw|G$ic!0MJsxUDcHoU+juAm`65yA#NTO0puk`lO1wPt;jRFp zDC+9Qnhikj6N?PuXnTJ^j)nf+K#g^*86hd5{7$gJV0tK)Jm-N#M$iNpJe^L2HbJR` zdZ5{&9}XE*rH*z|J0@ui1E=D-*D+AKqbG_iwV>LlD<}k`-$PZm)7J>gW_QT4unaJC zELh2`{iNc-9seFSAL1jFUulR51pt7;#CZ>Lm)VMG4&d2eS@eS%(!@!Q0!llM?dB5$ z=`a3&mKW20fE3>USOQv-(E7I`ts6dF5mo!Z)hSvYyXkH4Gv;8s(ax_Ft6H37K;36ckO0h}@#QXmtmT3^ZwyEI zmqS9|UmXTrK_y(v|L4npJ#Q~61|%qxz)sVB&0|fu8i=a9muSz+cg{RRdc^j+{^nxZ z8r(%1i@g+HzwaaQ*2d58x%0{2YmxkMTu`G|$0i6UHJaM# zb*6Hz+Pq)duz!U8JnQhkWAo zm}pBtML)thx>hfV5CBIhAGXpt*|U=7pWpv+0*rKm4z`bq^TkqGe7p!+6mU>qL!t=M zfp~5;kk?K>v#Q$z0(DrcleizbwaW&C2ZSQcthHfCSq@<8(?BT0v_t;ozFUSx{iy!D z5^yvF=z4zGf7(@>j-5Xply!8n;fx|UhHzAz?j{Y(S0`NE|ZA*CPys2x%$S`c3kmXe5FgvZ~(~4Nlt+#ZaUU z*?+J=vJUbM04S3Im(KU{t!LXC#zFL#h2z~&Egq$F_W?+dS)eeOf^!&1Uw4t|S*Rfa zbUXqt&eEI9TZ)L?>geduS%9m{h5zv^HKd5Ejj5{>8v^DH{RxCPq4Ky6?=B6}+qF3X zAQaU}c#C_c0yqIUO-6=>tYF)P9xqe-exy^wPr_k&DAN4rK;X*3q+<$|6L=^x07x!D zo}DUe<}h2b7h>-LgGHxjZ}rzmZ52JeT;y;^UV-W{dMImhEcL6VJ5mA=sfBPUnpB5= zg7+AfQ2xOhl`jM}xD_&-W+yk;^g)nH=6wDqPI>mgP?02fm-HDKP?Um9wQDw)wcP;x zO!XI~fks@KUU~(O7;eOH!LwHc=m$TY4f5gi9gp=GD2_iALS%~qsRsaSuRRpxGt^A; zP)saEUxqnw5F%?gXYMeXxxmiSO+fFH)$OEdNY=@};e|9K*j@;r!b`99Zpi@SxgO1J zf=A#I=E@+iwz`h>J+LA4eyOQln_~YZ*Fw}nf+VOJX5ab>kX{fDoUIlIP|tIHp-9dI zjA;vBM#K@6#GNS2L3Wx9@z_XGx9pr>TLUzfWE|%tiE;(_u+l*-rGuqJwM%$Z)VNJ7 z80^aQ(m*o90nsqkoA(Tzz1wfG=rzVb%gNX5$Hl_-1XLA>wh1Itp2}}h^lr&hQ>I8Y zK8QPZGwfLb@#fRA-%ps~DOyI@mon{n?_B^yvNU$<=x=W zQOnjyK)flUf`e)K5)Ug}5(O>02LwBAeExhBWh#VAmVnSt2ig!dv!EpbhHvWC;cH!>;LJH}LMk*x zp9k^_(9F5eaLWoK*?!S~uujBX#Km%eIDouuE}YklE$M#4zQ-5-LsZekl$Ej&RwcDF>(ZGBm0-l1B@_q>&D79*G-Y#Rw67%zb(V*(lI zwC6y4dGv*j5Z>Ibes8%K&R&o;e2SBPBHhsHAd>1+WZs zfYy@%7EnYAu$?~L2Bl9H++t>NvC1}Vpj~YF&S1I(+B^%{Nr6p=7mLTB?>M#=kv zQ6Ta#A?>ceNDB4>@;M=1%mhjSf|6Y?10XOV1Bjl}R|>?B!RADIoE@Cg3`n#1sW+!L zfLbZO-FOJf?@S~f7?$N<3VZR}tLg6_O=w~gf_r{~J9SJpV@SNT9HH-_aDoZU=boYUbR> zyAUV%VIAxsHBUkDhKN<$Ce}1*yzu(B3%wNyO7Phzq{3#xCqwBwVDm?t#}KwdOp0#n0t6tC8HTqvpCN$w zd=rZA3Ch}^S43up3vS&{UbD)BU1Yi7y!`E$6>$;J04TrXlfjFo3QiyEmg(-SD3~vh z4fV7cXcS0MvK_MPb+SkTq6b9{VA;W=#$9!H2AssIi^`qqSy z+lA~d4x7JX9>3j>=#SvK2rVh7s?xzD0s$EFS7(c#!kGk;_ZK}5tW6Kne<88nGan8* zt6eIGw*_lJ**%|qmKBpH5y}7y%7YrnAP5WLgY6FvVAo_iKvz-)jv0*Ay%bnDqaDz) zrb1Ji>e3Dpd-GByD+fmpgjHP%>k#3!bayC%oO5hmQ>dd476x4t*mxAVEk|zLLT&?4 z8Ag`{5h5BK)UliDl7X=g=HG1|ec@L?1skBrzVSH*EW&6IYrzue*gKbAfP7|PjB0N_ zSDOt5Wt*{xBjB#}cdkf=LuYrO$oN{ofZx~^A%^f|2w0}@wsIu-EFv8c493u-7Rn4f zlHxK~@G@A^yox;>|UT8smaHab;U$moh_A$W5$veY5o z9f&OlB2REA<^rT*&k#PcF>}X*9m;ToN|dGv>)^qZxx@5UHRx4bgo089)~Cy9Ja*iM z&SPX4atezKigh4FOxiqwL@Ki+bcPo4b}^(1O{eOnwn3NlNdwKRXZ=;)h>Pc^ZO{Gxo}|G%k^vnSqdPu z8t`4n5{1?WGA2AqHSoE@vVT?^DhKAS*U>Xb*Y*2rM6dKNMzE_B({Tc>{0Mz~-|1}b z-qq2qSyKiEe61WpI^I?73Gpv!0gh1ym_2b2Y1aV1CU|U&(&sv%|Jp%`bFnhEXX7}{ zPnTFIwvyLAvK%}Pm6-V%R@TOG0>kGYkmKPParceiumx+)syqE=Hh_dxm#H)hs=&W= zqRy$k`H}pf5Fyhd^6BcPCN#B8+WO_tLDk0whX;1+28Rv20ll{{h-bGi(u|<)$wf{( zc&6ar3RTgU`3VGu38KlWolDGO8xYB=xKw^{~L`5JryZL8|JEI3B>*rEgYNaHeogLU| zh$(eJ!~k2ImUCSKX-KxN;kMwch{C;z*hL zvjJ_nRSJ=%5jeO2EdkUe0zSE>5i6KZXl@RxhdzvXl}`KI5o8IM=|C!CMFoG}YF_nn zG!WTc@Sb}0^!`tMw?b553JdbiAhQhq%LN7~AZ;QV{{eq)V1cD7@j=j1_x+~ zJn6{j1Hi^0$B$eMXnth90#*wdsZm#LRYU)WT8~;7Fp~v_)yh`{eJB;67qy;3^5WMX?K|CINvvuh0vBRh*25p+$_gg-0)Sb2fIvShQvoV8IzPQt zsB(ZOg6KQIbXg{ysXUcJ5^6EDK3LNIUcbAiJSUKgsTj!T-qD8-_kyO=%40ML;VfH2{$6yF0qT|?S65S$#Ur&9Bqf*JNk{;lzif!JKwBas2ugX#bi)H(-LNgD!1so{ zl2B5Ei=pml$uy`WEo#zZhpwc0a*#sOJ zhzNGGMBD~S{c@dMaMV8Y)C|Ek&jO75vcz0_;Vn}LD*Dl-z*wO!;NSPr@vEeoun7o! z;ISNrq%l?c5qTrPa_VWdaX{b_HU?Pp&}gEk81NA@LW` zD5CU1t2KrGA)v^hbnXCEgtE=%36vU0G(~Q;V)6joWPrXTVKFHGQ);K_ zPK2FK(D6EQF9|8g$*BOi;W4P&t!vk<8ZGSR{5F$hdTI*5)+e57=o7MVt$@6fX9&n* z5sA16x&iLZ1x%d*T{=j*K{yvRT0=!;u2>1#wAFs~E>a6%I?UQA{dPL?Hlk3jr4{{N zc-_3^F_$>fTg-WfV;R<_l)7x(#k-BP$w~Lsw;?hDu{p>DiIi-lgW|hDP(ois*_kCu z=X1#{@WN84O@9rZS2l!3C6MBk;707AJzyG6Y&YbbMNun~%B z98k^v^Mx06=D?A+Iq!D=^A(KOiT`#n{zXD!4iPY(c>tC5KR>QdJ~|P!oZ4u|jUlB- z>zrIG|6b0k^>0@re(IdoDQ0GAJZK+@gVY#*dgSU!2Bv$bO;hAU2Q+=iC~0bHhHDdI zzQ5#_2MfFm9w2wpV6^kEfhJn+N-Mmf4Zoeu1@Z7ufh0nOA50$VMxkAm5mT%tt{xLq zIT)x8(H8y;)RaolBE$@mkb>r6D~D(cob8ziVTYG{lUtPBw7u zENyEsN?w}zpbi$Xc1u6h61t1{>hDl4*1$=Qh#B}EeE!WdVi?A+R+pH&LtYyb0gIDM zb8bNP(&>Mj`vpNGvYQ>$L{@y0=-h|&7+elgGnJlO4kZra=P!^awW!G+!q%M_7^C^E zE8wOFeh^t=3f=%mu}RElz0)Nc3Jthd zZ|*>0%ne=W8r}8(__#5icK_ZQ=AX$g;_8WiE)65k$l!YgOS^ML(nrZjd%@lURDu#L zwklZ8Z&Ue^*Lyf}gIl`ye9aaT-VT}zxsu(L6Ak;f!`;)4MY59M(K4#|q0rptT#c~T z<6FVsqurs9^i9;<`I{mV8XOFQz!2JfesIbZC^2dQ`)t$|6dfIH)LS{38Uvis;{3%O z(FQ}x)J)`iKyY59IK~r)OQiYtv)Bs;seDD+ZXMmZgI5QeRk5XAppRq%2b2j^IPxfg z()jA9&(0K}KwxuZ*%5%v8k7^cKwPJT%%W&YQPm#tr`Zc5qD)E}WBrcn4G#LT z(2NLfP(@+8Pl71ci*$Vi?t#`v&0D)Ze*AB<;dLINT*&HBMw~u}f7*>_HkS8#m&~8{ z)k5{k`KgC2GDdd`pGL4AG&zO&+4)nD7*pW)&+rX$ku#&T#0)*w$H|`7kD#_rjZ%4$ zzCdt7Hq2Nfaw0%~Lktx!l0nodAiPL{!>8LME%s>2(Gx$ph>rc$qfJR*K(_S%J7MEq zc92>dD4nbYMOoh0=_+I!W2-2$)5A?^moBG46XAFMT{xOcbG!TvciJI;w-9P9)Ts)c zmMP7*=qCiy;G|4bW8;InS9ez0^92@LDh6Fj$!nHU?LEi-)%P$WF(M*;Vt)_JKoeK~ zZDRuw>xxL9Z&j6G^?V%}gb$k-Jz$fOmUUBG; z)dqum_?55@t_3_0@6DV?O`wqip9~!M|IK)e9{Xa#myCbb#FLqf*cS%Y^n#snmliMZ z>}WiOBaJVX)Lu^XZw>GDzgp6-e)647jt=s_;uac^b*2utc;T-${^|eb1nUNz`@uh# zfFb)|ZS@SHTtpc=juuQPv?pCLV(>cg=T}5r|NDu5zJoz6{2D(0E9d+j9ANyy`|GSb zm)v$_C!1g}sOkUU`a6*d@|41zGcR(2!Iy>W-{}VV9RK-^9bxs?^bH*@`Ha$@UfRD_ zFEsfTNPdvyln0i55~wS?@}6}Sr9C71A1dh;BF4{>)&(jnx0Y_Gsj0;5tZ1)~hE4;!YpceR`E^zvdbo@jgn{*yJgK8s3nk{tmGN#?-%~>s5&@k5`*yqJ0n{>$Kbi; z9;Dv60k~e@EB{?NCH=FwpO43EMF0FX%+>!MFY6cPi~1paNB&)kGDy}uRDW#+n6+^t z?q~YZMO{&8&7Ju2zmCem-W_oPP8%Jpr(~xA!+)0QqMr-)?O*X`7H;+8gu)I38t(+_ zY3A#p$RPV;v%}zy6ENOX81=lK+G!Ct z$<-6jewq94zYLFp{B9P5EpH|bd6{phmiV8~{DqbHeoKq7hM7)Rj~g#H&nNGAH=RnX8ezZe)Y8B7M8d7_35GDzn{3oJL^0KTgd{J zKUaV$fRR94Y1SeWx}p=pJ%6u0+R*PzmZ*Fe1^&pKm@}{;xzfH<`3CJ0(=VXcWcc$r zo&vBdX`k<{J6-x47=uIPs3r#kjM=S5o-8!K3J?!;Ue4YcSS^4Cpja5#g?N9|^$0A1 zhQ8a9>TWIaYqN0S)!A9zoSws6J0DNM{6AovDRO>Fq2_bo_>uP_9RL+-TLZ^SHZn&- z|KW(q8CTQ6*Xb(sgw%fCr%xr%?nH5oC`g(}=SQ~bKf-leyeM!`-zzu3@Q0?NWdL$P zmXBZX0sRkhn?q+c4(fTJ-2B@vcpAS>V`o_O`Zs7{2H}=nKY5XOYi)tKB8+*`p>{o% zoi=>RrvdpES}$)`ArAv;EoZbiwXpz9`PSt>cZu030@(O69f&xZXE`4R}d4AZMYviWwUFFTRZA>A4BWNcS=&1z$w?nTW^%74xGqlHO)9`N?RG3EkUDTcz=Ur{QHG$!*9w|v9#f6rdv-nx^v z6K$e+my+l-pnqicro*5$@V!mhIExz^KzZbCl4w|BHmr9(R`%Llx!}=1OwRO8Wq`Djtf_Kaa6s>3=sd%NfZtLbmI3^S*ae7Rl)L_KT7 ze(=bWnwrR0IPBkm3qzOsd_hRx{APSPjQK3I)tPPmYDtRHwbzvRK-6gTlgg(poH|5) zf_G;iCk7AYfZbg6iY(q8fIsY9-P}p=I<-Srw>o&tCx9~#R9nbyRu`fA^kZkgsYO!F z5K9c8_D(nn;1UKR32!@Kl2)FIW*8t_2eS#%)S+Dmx#A{4O8jNEhL-EI`ZYQLl0o=W zmz0#0c(zN>nTls(j>x@x(}j0q7+~hHba}bdoiqQ$$o%uo zc3cmCe#*Mvetowy3rh!ZotFh@VJ2kG@JS5DKljI7la<-EXJ)3!TboHAj{mwiPaMR_ zZ~a?OE23ucNu_G6?~kNhVxf@Ll}+Bir@6aKpxswO<^%H+1qgQS7RgJG$ma+L9*Qfq z(zBD9k})rl8B`t?kajy4F>ru<1O~#2*c29pY(0&+`Ws2ZSku0uuhvT$XS27WtV zArmr^KeNld#rnZtGa#?Z6w*Ix$qr|{uQY8BqkjHbX+^A&8rs zzb}N->91*Lw=}ePl3ChqU6$993$3kZ8M?U6b;5`-qeheG?2KHBjg3mOg(_+seCDx( zM@#yXgD3H~W$jX=jG66JMho_W-Csswy#39w*s7y@;)*2st3FSRK6F#+x6ByL$C}f^ zhmtdI)qkCgcPvO$zy&lA`>CoPM6s4v$^pQP7r};ALi*LcT!=t7?S*?$LUt|yl0h|AjBh`L& z>vhHFc_Lptw`p*fnfHpVOS#JXp+59O=pVy*I7?xIl~fir-feHlTc4D$=)sZe|H12A zz#=>PG2D6Uor|8l9;pL_=s~#OU%&Oq4{*b|zPMmH<; zZc5wR4W>WDt9~tU*YBZupNRlfDYa4f21Q=B5*aAGGT`yTWp%glgAb#jYFOT$5vr9p z-Dt!emg5ih{1gnZKEqms>XoT#x00dZW-HO^M>Fcwp&9EwScBQwkVzQG=jp%4=>Wm6 z>%wrM3v0|d6#pc^HRv26Ki$&ywc_!{>#-qZzO#qxj$%`tcl-YOBfN1$zQ0R_7gGpI z#?TzynZF*zubN5ZXW1#=)Xs2o+us-}I{;?> zGS`oRGZfD{Ou0sV-!-USL_Ad;GqHNfRX5Py#zPLx8uSe^CQ9g(+4D~r{WS#eWXRsS z7qX<)yChnteI$!3s94Z04isb~oxJrA9M*uTcQ^6clu0dTrtSB&U0_XYwiCCf(S2hPmJB0*?zKzXy`Zi#pk12C_qcAL+PL*=B8+GLpK*xsTxe6)7X`GR zENi)A^iH&u%+H|GeWL$$dXZ#d)O{0U3`RrS#$%3PcSDiDFXC6^8-QK&y(bvVi}AhT zW6o3o8k9+gM$uDZ{8N8r!&BmFEb``T3vDyw%X@Qh>Lz!A&+V)B&(vBj`NF;XI2<7O zv(#v^VpyQ0l%-A7&*%zbsuC_>{3{=ZqB)Os-hZZI55g&2E@va1&f2?gE>0qZV?eEe zA-w7hei%yc_#X)-y@q09@}D+U82 z+X5eMD#<5=_ejZ3ADijHaMiE-#anH8{X>W1*3%gz5g1(j zXG7w=tEb-3?0KrBs#dJ+qwy&{Le@dX(1kfoXdvgo`9EbFn$lR%IIk}{ks%)}dKxgD z=ch|%s(Kt>X(KbUl7^^W`8MssSF2>QM`rH7QpwZ)wg29wKNAQ8Am_-PkDk2QrB9JC z{`EzAu;@g;&<#}uDhD9v_-H#``Y2gHD)#Vz3wV58J+bh7MwQN>?`N;y<1ug81&aDV zzU_F@lR$M3?`7hIao>D3Z0eGYhmK5?7iNP!@6uXPXU5ussv@8*(&LI6eJ)H?g8@kmq-#WBcg($%|l&ae^A8I+z9)?dMozyVkxTy(-3J!RW zIYV8`e)?Rl%5v@*)o>Nb%3n&u00!-`R!`2y*2FQB-K16ArSJB=|YvHoA*|q)iDTf!9dRq1sCaP}GHxDQB7+ zpjdkg>x`??+T=Y!g_R8yCyjqymQS+#{laG5IteNHors?!JmKCZ?3FMX;%F4p@ex5B|goQw}j&P4kDbJ*{NsG!~ZRb)fDrbUY+8 zi*Vu8pAksOnu}k=T}<1f)3;0P9*46R=KJ%LmcPk^Ybw|iMowt)JBOld);R{@4GH)v zGxB=cLoi5JLboT_?QZyZI<4=hZ~#QRHY(CJ&=Jmue>xMT;^8A=D5DszJxWhNf)#_C@cfxs`gP`TjG#x< zKeJBZiM%^+%gwM@MNn!@G14AD<=0CTB*I}OQS5yk&sJ1lCXSI0NXUxQNgyJWe6hdy z$XtZuh}&@#CNP*^fvGvtikryxQn11S+YQmiUMk-><$7A|NxSevLJfj2*PX$1z}86Q z1%WR)WrNp?pKHqXRB_*5HqN1am@f#AD7%5c;0A@SFvs1WPbpt|UcX5z7Qc^c;O@lH z#@P@r$4)%Q?$1w6WbfTml94Ia_AY(b+j6?f?Pi!d5#v9{I}(_q$A;Ye&NFAB5~JNJ zC-M!d5s?yP$ObMZG3BiKSEq3ri|08a+M{*`Pn(W?r7K`J8z-0YwLmdXvD&s%V1m9YSu#KYr>-)uC`E3Zfv+&n!zBUZ5VxM=!+(- zl_p0U;aSNSjC76fl<-W~WR_{^`eNy4#^Qs1tg4M{9u5<{6UcxT%U^%L)Iu{-@oqjo zJ2EjsgXbcp>4_5en+n&eY&QE~AS&V?Y1JdcqMec$q-0@TtbO?6O;&FIlv2X>F;&sG z>rxMby^}8Y8c7wf;69`?1Zr;_C-4=K53>DfwYg1zv8dmYyI-TPpY1~;y{G-3b-jh{ zx$h`NzfR}VX2Vo-sLs#gEe9(#g)d2cmH4ZBnc|r_xr&Ah$$0@m8LqsP$j@3>zB|nsFj67jS#nyj_ z@QyNE3#p#`m3zyx5w~T9oTYUuB`U9cXUtYTs4A75Y0kC=El{T;@-F-R5shWUzFwQITel_Xgkn1I-kU^<(_hw~npN4zFD`YF4|4U*mL5teRbjeeo)Gr{Bgloa_v5YD(7n z(H7r7U3Q6b@5f%D`jv|4R24mr{uw;&%L5&v09o$pd2BmvGblZ^&#;YBsbD&J2eZ)s z8b@x|8kAIZS?bp-30{TqSANN_B=NQ@8LaYxFnYphm5g*#4ZC);NByR%+FG2&W4jT0 zAI|ck@X#^gmJ>sPMusv14zeFgH>si-ZsFf zemOnrwIUxVx#9N^69^E0)YQ50MGC|a) zwIPJ@#PBhxzaBq^<)141^zhdO9V_io$MVN2IcXkV=QQ)7f5+2zt&4mNbW#$yjS_o_9=9&rZarmJWL5#S|scH3I-FcUMj`*@wDpC zsS1u_Xy54^6)TPP#I~S_xwpnUf0B#T&GPEObU;nJ>V618`H^wHrjnZHDUWQXK>Thg z0t(-kpPMzLC9>1s-l{*#pV#*KG+dQ+xQX-q?JD6}f0AlCz7vh%nOP~H2`}r)|5Vi0 zOLiNcyFNBH({#snmx_S_?gQ1HnqhX>@sHFglhBTTu%8+ee=L}qFEc-hy9<$3#toiUI~`mUtgT6m-5IpOFYm;6?H?c%)PXmLZX1B;C`N7 z!id{Nx7&LCkvah@TZG3G8yGmOy~a-K&zb4bbH7#`-(ulsc|Rc*J2}u-5~Y~di!pZeIm8oIL6c0nQzcTbeY(I zv&mkT#|eM7(apL05$Eg?JhN1ZiuM;(7VI?uV(wT`+j~roU&t;+i`~1{>9}2%pq;~g@NRr*KY;dVsO!&lUwra8^hxJ_+n=%3u6N(eU0G3-;6Yo zF5Nv`J^C~Iy~#_NGDF=!;xhWq@OSjr2284l##OHvd=pL6U2Y;|NiB}T^GmpIKl=9Z zok^6ed+3*NrHZRFBodax2_5ggKMi97yVG`IFwDPX;VCaObaKns*LlC>x_Oj~Fx!^7E3nP@XlN&`PwZhh9lZdO$mdCs*g{^t5nZZduKGJ82Wtm6g1PCUdE==jjwD zp7G;*Pi7s+O03J+aBb9k1PJj`GShueAS zQe9=N{#ANXygpxl_($P%SV1(k@)bU(sgAsBua@adN=g*yE}d=+H#JSJ5lUv#G-P`} zz}_jX){%azn^CBA{qWV?W)<%u_YGNNi|FLV3EuRUhu4V`)aU!wY#v6iYcYe;TqDcN zBtA8KSxn&UhqF)Yl;HhmfP$qTv<3T&+YAb&vDSM(_gg)~Q2cZK&9wJ%3SY}_C2{(R zIWAQ{tNLGjo_khi5FMFwh6GCuPp0kZ(Wj@f7E+pgl$0vM-R0U`7RgEF({vTm2&q%a zdR`@X^%+tLzw!O9EsBdKSUKHOcD{X*yzq!XTkM90sD^P=cqm)Sveh&Fh{J>K;_0cp z8}#3*^;6&KK7HFPq$RJVrM12NIa}!b?WmS)m7e$C+GgGsJls#GBror4Z_2xO-<0f$;?Bqf@A=cidG|`# zjY%D&jq=QNi4u;D^E^2!6!A(>|4l)d4S}cHr|CVI-KRq5C{&5o?dIc^SNT|!EarRZ z1b2x+Z&*T%A{gfUsm25t!nO28`+nv=oVR;Esq)^yWQAk$lc>xLaYf77W&s#W%{f!C zZ$3LSWVxl;k9w0JhkbvGWUQT1qgsY}n#!@dHZEzuvo+^bO8t6`i#YEQf21dzyj#p8 zr>yivO+n2Y3poXxH|K_eKEh#BqDKpvkT?HCC0EzNMOvwO{8{8ciMhBE9Ma2tp}_emusK zK-aZ8FX*K3VGXZd%5L7HI;yIynHOc8?767FxxATQaOl;y>Wf_6sKkA}c+)ldEkEPbk+= z(pFBZUQf8T_($c&lPnTbCP}AMYW%I#=&6wwG+tT?C@BIAg7xsH%Rj~~> zMJ595T3HH$pp0jvd>F&v#T!2_ozc|H>@pN0~0vVC7qFgtA+FovOjtYTXJdi+wJ-M+obvT&mqjd=+uhG@qWYV6ti$ zOxOH;qNyjXg2kooepBQ#Dzv9p@-EWhEz*)&{`SCi*AZZAHr+&mUV64Q&4oGEgQzPO zLPK?so{hcS$a@|qLz$E_k|)ItDy#A@e(g1Rj7w-QXa8IYW#=4=Qzz$b?r4(>D$mrP zEZyoPih4$5U%0X%W?4Etw_KLs&3PzM@yfxWa|a&vKWFmx@R0AcRXypbs&0MGU%x$z zf!zJF3ikNd)e+wC@%$e@)~1XtmjspfniojTJG23i-1nxx1uzw*NOf+O+z~1=)Mmk)GHsK&XDlWJxo<^I2zU z&XcpgR2x=JI+(Cm41meuSe8OQcCp%9^n!2bAINbjOjqiduhuX<&$IfTDPYiUVGAHxpU`&obDG5Q*yq4&nyq-l~<)!8cTkgC3r@!O2(A17oU?ITTc^J-93 zA77H>zZp&iFleO+icNLVM*-7VAqBt=$Q?)tFT%&w`%)RxRk4BZj3Gix%9 z;G~k*Nzcjc9;xn^l~YifV{|j@%l^I@08jc{_K&zLU!3A7>xW!?Y8k`iAPsH=d<`kd zzocB_qQD&(sT1{mKWT_%h`o(gZmwKrW7tz3B`4bo>tLz-_H@^; zC??1T(<*KyY~U`Qx|!(Ic@8V9S1oA8$)w_4{hd?iV$PdSD_DvBzk;bU!^PFzcP)9` zbb?L$N)14;$2H6PR0+$`WGI$TUPfBnO|P-tBN<|1PQ0~bpuPEd~ZzD5!F!F9v9;3>t&1+GkIna z95g;+EBpFNs%c`IvPE|OR1niyMRm=owSGO-o~fCLqmgEBdN1RKBj3HEJdtOf8alJ6!c%wDR%iZtO1^Q?)$|QdS z$xu-Jt&j7X&-U_@msbRFbpZ{2b^~j^M!lUjgvnT&xH2j6X2joDAdV8lWG&jZuuJ`R zt(@b19>23ME!sp}Vs#(L$Q(3R#$<*_mk{Bjx6Vrv88lxpC$EdNNNKdtdQg8n4&F84 zILB1xcJdOFwuMM{tKwE%X-X~gYMIQN-g}`q+sjlld&Q)5N!ds@iF~;ZW%Hams=(46 zQNo`uq`jPZqA@b{mdD+vhr$xwY_CM+*#P_z#n}pzzaH+)rnhej&6vW=ex;XHUvbyf zu5J;bI3f-2J9uuXRLODh!`aVR-M|xij`mzrwwJ;~KaF}b4FpdKc1O^v+%zRQ78@L7 zZ?^G+{~GUMVUl$^p0)*!abs?H3)U-6rrQ+0aZyR+bo)pRc{^X@$SU_Q?j4bnnHhS+ zC|6jip6?r=u0iQlnWE)VAQUdwxtMuk>04dCeENG;#e(o>b3<1(M-Lp_?d!F$m4Eb! zXFiL%`l`@zidfIK1$oLr;p&QFa%2an@YNni_v?@v^o4(lTV?b znbcw`&T0E89IHf5-TmL)7w5$4_!&u_X?`9MapL@1XR9+jMDzeJT zJ|7;VO^=-&rY{)s3rXT<8#1_<5%i2nob}M~WUO6D(@>H2>Jt}3UN-|`99L_|vl`(8 zi586lfX5wT1}cl>jN}#J^&sl`7nOW-yz{>=k#~l(EnhXPWFgVD@PuVKUq9wx<>=8H zxuKu;{kRY3r@dm`9xtKcR#p~IOiq5r2{MBoV6b5eE&OP{Y4mESuC{jVwnFTel_gxw z*+;btc{~m@12amRPr3?{y6`a-uByXMv%`A5%Qddy5j6bYj5V54A=yn+{zzn18v5+$ z;b18vGO^lMQLEm1OYo3t>qjVh%3FD{SglEK)m!e{%i4=ep|0Eemg(4}ARRRkB+T_P zP6Tn6Z{jY$x$)|Ad2&_ypp?V3#hi`qtCB}6j@hQJ1RRmu?%{9d=O4tqYod)f&X*is z%$%}J#rkZBjIrtjr=;`8tzB$_O7w{z4uy56=+}nM?V8_kk-q2lA?lg>_V$2QMtZ!Y z?wZ7Np)Fql2J*!UU+qv0Hwjgi)w_Ku%L?}c73zHkPkr8pXu1*W>XNP7##FPLzhVq}4}x9X&nz)TY^0sq^y|-!>Ao>>^lXGbXPSId+y-oLl1~{a)>kAidiI)MQm3pK)8cX)u@L%%Q^ygK8pPSxuBr%DaYj|-_n-ZTcqE4ncnJUiS^h#Q1NJQG}oA&7=N7kS;ufT zYrEXoqP)U(uW6LOc)_X+HkA38C54S-QrOLWIH&AXlA}$`eYh)5WV0V}j)U(( zKZDn&ojIL;fP|&ngU^()wKAp7>7e18;I*mCdA+PJE_H+GONEi zzeF+s7ASH81GCyQ#Ft*A))#rVne;f&wDnXBo9sGq(Nc1H<*>J&?5JNSe7jQH_tlk)Z|N64z`xhXThUue?#-9=#Q-`OTVf zyFKBPJrGs&pI4mExLv7Y%_)(y*Ne*iPFc-%w34Wvb`3vmt=8IVS#rEi%KZbGsoeS` z?tL`DsRW|m1??xj1saM+KJrxwgqfaXSX?vj`60r^$k;x97slS`j!EB^EES^xOTo?e ztsWc{NA1JqSQ~NR((&W-qvd5~N!fByu&~(d`f{6$&xU zP-XY|sl6(Zs`0OwV>+CJ!w=!L2mKsJhn9>V2X8W5{<3=Xp~WkbG^V|`uxtFr&kX}b zCqBR0+T%3dWWM+w`i$+x-SB1pnL9_f6wEx zY>~s%yaDUkoW#|8^{=PbeU%>xzVW*n!ykUuw&r{HRau6ng@q8|e09J4D{rA6c1_Z_cFrp~BTG*b`Z*+hDreiej1Osxc~1`*jee!IpWcWSyFp623%z0r zMsYUHZNNxO6PPS^(EO1d^opuz^kJ=bAGWnMH*(i%re@a0e8^o{{s1O3S!lQ%zgZ+j zndJAauzO5~dcwwqheydh!axN@+`gOb?0MX~eT5a) zGK6jq;0-TTjKCX?F%17Dm9*mNlKJUvKx~tX6t>nBsm`V$F+h_`XP8 z>D+Kl-r?ZWu8$T46_|!!r{~EwVWD6)^9pPo?X#@IE8E zo9cIpu#N$Z7>2nW5A82t>tU4dR@JwqdxLt6s^6lXbX=9}Cg{Db-NLv!nttYP^#co| z61gJ7;h?G-D!<8qn_;~SS~{YA`1q97S*E(~*ElmO`EdFQB42sm7}xg}ah(QDM?YJh zs&ZdB%g$*2BSG~iEqyV?T#mSa{YCTS#;>lyG#|Qt=oRgY2=7|)|28U|6go=i#N-a9 zmKX+=*!e6AZ_{*omzDCfX zd^Ntr_lv-Oteai}`Ng(klepTv)Ggx2%h8$r7q_3%nB_`k#=}8#nCR5|wHf%+w(z_q+l9IBMmkU=CNY1J% zb9aeQ<=tHEPNGOHFYkyvtJHgYVH@lB- z9GKY{yy>#KC$05zPG_mT?#VeH5%0wmqUfA9@4lM}mEoP1Sq{YGU+Fga#zU2l`H0sZ z7^X{0E zqMKd8J~d%hs6w**O3tm6g?+zp_?D2Zm3)f)!q&$_8h5nYKTQj2swIOQeCYkeQ8 z!|C9+xBJ(pE?vz%Ka-Z3Te5E3Di^XYCwbnQ(VO|DnCFdv$}?oc_s^8h7rZSzlmCsQ z^y<8!m$}u6SK?x~@iOOyJ7-}&oOb`}RysS&c=%w=*|5p+m8+?D8Z(qt(}Jt!)&=w2 zb)R>r=e_WyAQ{>Fp1*2)^Tj*4qnNg)`&^u)w=WFT*y8BG!MRrPV2_2i&4qlE!=B}w zYwIOdZbs6=1nzt~8`Y+Eq&y|U=c92-(TsWa!U@kUj;N1(w!EIv z0yR(1Q>C39+o(FUqGWR9sSgGlPE7E0Xn)<;Kr5tY(oqx5rkF&Mc;+XG|#07lx>!(y+6fL4}EBP&n*fp%N&aV!5 z{1k`RL3@uonGe^{T^zDK3?r9oySn+)L>Hb+2&|X$F@zD=7at?>i>79X=-4`W)L}1) zNB*J{R@LiF{SvcQ#09?!9apZG^xHUFY|Yx0vDcbg!#D05R5B=B&wJEDSoK6S{gBAK zaEE4~0T~OI76D0uT|1WtF{9lNsj)eRY+H$xvp$^qrAH@956z4nCo`9j+qbLSWL|T8 zn*Sz?7$wGMsFJj`E=w*u_SU_yG0d54pHNFuy|$|y+WEJe_FMyZmFX)>!PGmJHY2;3v*yM*H(#5ZY{ODA*)6UOGo~ zrhGg~rg|(Oh1|P1Fr&G+_QCtc$FoAaf=@Ga7-bJ>T%0?vPUTA$qCrkYca$JeL^Z3e zzr2gb*|KiX9=;JkB6eL4bg)A0LgGvPOF}&&#JS{^Cgt;S%K6^=#pryjSUhhL*`GeO zxxg26&TaDQ&yOoL_vc*41%AR?qXMD#?K89;epXVSe(J~ffX~7rw2_LX5)+@XXMR1v z_({_z22!yC^Zk1xoJc(@FxTa&d!y)|p5LgM6aHiyHEhN2HMXd{OOhl($fr&E>Ny87 z%i}tf54mob1=`IpWQ%lVkr5lPoXDi?6xAp@bw#cDeZXCgo1#I6z6`=bL0+O~ufHtL zdir@xx50*h`$79_PRYk~lnT8%rv1FmOYd7Gl_OuJhhrEA%5)xhKKe>LK=o#7{b<+~ z>4*)z+q?D!a7D>~J+m||6y_Z&NWjx`uRt*;?)|Dv&bkfXr$v*N{!8x_goN(Mg@lhK zJt+5WdBoB-Yh1pa_gLjcqb6e)Vq@ZY+LbJ$*=9 z(z1Z7{UTx9mlvXv=GncgyD#P!r!^g^^gbj>e1Y_Y%Ry(9 z#uE5+_RXh8##gsy<0{iioQN9coxJdwuVnuZRbLg>MjLI76Fg8PxVyVk2=4AuT#9Sa z7I!GF!HY|AEAA59tw3?tLb3MGch0%_pJ(neH#n)V?HNjy9QRL75My|V*{9>%)jVb4+ISp%BMe*4u4mvj7ioLQ{+=EnSx!FaFak-DY z(Q-6n^r^$!ybB8en~BWIxM8a??)>fZb2d6PDG_Gw0&ciJ^y3~M0(}sGcEOUQiyEkV z&?AOJ+atCaY)H$(2$j8bT*!-aC!TSZN6hxc7AIYZ_J*C z_#Vxs#yCP3<&M(Zx2?zGFFnZ}hJy|0>cB?7zuEXddU-RufzdLw3b}Qh&}oEzPfh_g z2^g^mW0+FXEs~>QOc6I8vU~{l`!`bU2GNyq)4d|btR78rw2hc%poDTReDH}Is3+Ka zTq7$^7teB8B+{Licc+^0`<5K7=U7&OG)>}KrJ`FX<=LZ4*v`Up(k7HS+Hd)@2xC!a z8Db;T#+R3crZL#|^Ye?JwWa1YU0j_4X=r~$%*Vvhcff^CznHjIf&@9X8R?rRV2)u= zh7(=&a1x@}6~Yz$OTj=7?*#hHi(5e3Tb7>+ z5!%hBpq)0c?RMdxAOOd_AVGo*=I^Y$Y$-W2BSt;pz5FVskvlN_AG^(#1tsZZf=O|@ zTW-*}6Lu}>GP_i5je?9Ed<*`)PRx@moh2EqwUp$NO2VGsoc-dg`o@wC_o3`0aYqMy zg{8&0L~M3;_izvfH<74*Q9|;`FV}2Zz|mwHHb}VnyQ_~FV=C9W z#l>=mXPNiw{XJjc-{FPR1ey2;y`@pw^k`{HkMgS(JcitIs3%};RU{>u&hvA>ZdClk z!mGDsLza}V3-ilT`e&#(zHCUnq{zZietLyU;4a38zTBVp3{F40h+%4~t35%Y5-@R+ zJ(y&ZINT0Q#6pZmWbgu=5XS`+9?X9nx(8Q4S>Za%i-z$c1Wb&YDbE{{3b9vtEN5Lg z<>N#B(Z;9Peu849;LCF`I9}&p!4gHwNOBXm{SgB-`GbwnbU740w#AvbShlJDTmkeE zTOt;^5@^B7J+WyKn01%vPd`Ku2BfCO6OqzfU7g+Zq+W&{^V|hUI)~p%Q(3%&goOcb^FC9 zB#cch`m5@-h?-*bD+VBHh6bldIcotoZIr4x+luq!d&hJmQBsTXIh{)EI?Pq^!9Vx6 z+<_e76Y-@(9s~)5mE^WOeVb%aO)+GLdy@ujHGQH`fnow!!bfCAuB^aQ?wkVL2|BI( zHURj=iCt31*<5ps^MYcQEaAPCkd8Q@_Ge5I81Z-f&VdT=|ODtqz4s`leYd^0((c1s?jNE~6{U~B2Hj3^)+ zk9}kktL&nHz;a1tDVil4Dm4OAAi?RWkz|1Q$K|glohggbTou26+m`>edoL1^A0kKg zUiE85z@abe>8WjBwK1IbwY;M)VTGU^g&YICA-^173ttKHuKLHhl(+{ieFMUe=P&?TjxBux}yWRC84G6Oj3tnp-hBN zfDf!{lvmF_#lCY?q?&J^C7Z!1-gkINxX9Q3y_)#M4WD_Wr)-3QudQ`prZmY`%nA;$ zya>u~f=pT%IY^q4Qz=VoOTP<;%hS9w3D40hWLSuFHp=x#4D=4A8PBrlitCJ&3%_;l zckZRE#*to>zk)MF5*o6 zCOZ~0->0Sh!7gAbJPABuvxX4@^+E`BuwX5dj-Jx9dOQ56Y!3A#4uwiX;*47Lmzp*b z3IPf$>C+C*4k&+yWBB;n#=+!!$$UJic6}j|72A=A54R;bd)}U7*@R_F-rm>RY_{xE zQvkiDZ2fen$!u#v_k}mX)vodg$&Wk;y$MAfb)9^Pr`x(z_+Upz)}fl{imA5k#l`T* zWro4lReM1hEg^ZM`WIDx+2m!}<&F+D5lHaNV&>52Hf_oIYC*O z<#qPDM^Mn9u(0hZu4|UmbehH%5n#4JIgk(_knve?C?+-V*Mo9aIc#?I^myS+OGJ|+uNNe0@up5~y z)IOP6X)mM%?P;mWgwI`hiuF1@=Vy0{NNY&apQh);k{)97b8L9{FgRqDw>CE$V6Gvj zElAiW(V07$V?$EFBfaRP0(s;$w=!0sHJx$2x%v5zH|%Vg0zLw?MnsYOtn5i(e*Bk3 zbV8&PH?qpW_gXfM`VGK+M1`}F#WJ}EP$3{8{&X;MlR#ynl0ZD_Ak@+!XdIfUgL)r)LR!Qp%Gl%91sKuG0og z4^;I7R5bvYO$Q2GR+d3u^fm#Pg)Oi#>CP2p_haGpI+FJGW47lHcQRA}HPWpW%6Ho= zVSqzF84vv`6Nv85*Dt4oWvchl3B299Kv`^-4;T@=!pZZ?W2Q8mY+?1)aXUk3tfU?i z8rD7)`Sit_v4|e|J_XDN7R_?ozSW5AS~b?EBLfBLskm_S3y8-T$#WaDwERngSEss|I+5ocA|Eq$ z*spk^I)u$V@VJJ7TGTPCgBfzD%ntqR6lf9s5r!>Oz^x2nD}T< zDiZW#(b8rkMMgG%$wbGvfwH78cyNQnwXKMehT#9CbK#{*V|?yb3K1j@w{duCemlZf z3A)!hN`0SMcqNE0n2J67^J~*O;n`!TRS|Kxz>A-Duw#Q$wOINazuq-uG&ZarAzFq8 zE;uLWiC0Z;hn!5-GHrRK$^@{sqqdK*!DN#~7A}SZD54n*4MX6TpP7>MG+>VBuI*MG zkjB}+?m9Aj9OwC3C&7M|z0Z3PqKsqwUD*?uA>2?#+=$E%Z|J7I3ihH2;} z=u1>s<>VM*1JpUzBC(Z@Tv24XYcTuzx(;40OjCq?V#; zTQ1t}y5`HbCpXmM`+Dwc@1mrk`}W4lYJC$k#4Cwg&2YFHI$4XO;hN`;@w@%Vy>WOHL`^=|q+Z}}?o5xJ>ivny;Hgf>l&d~%pT z^AVB7xJH$TZjz3JeO@j1$ebx{$9R$^xk2F+J>GARDu5)VF^q?3;hzxJ;@dYK=W}cE zufx`Z^QRhFS3m?zEQMo>M!)Uknv}gB)NO%i zeJPXhdjnZLnTH{VvhSF9)eJ4@t2PcdAr|l8aB0XDTS$xR?azECT?xX5Ab(vM1E;JP zX2P27`}wUX?SC|f00c`;gj>*S$4`3C!i_^Jm9K{j0lvCx^(e}jVZ+BInW@@q7|^Y+ zbP{dPE5kB$2oMRZkH9XQ+U`ZQZ=L`f7&C|KWomB|d@p#yH|s6ezXMk^q-yZpmEtVg zqV}Y=r&Wrj%=^H@;HuW#&vIDm7y+gt4Em3bY|QXGEG*3Tfb~O+4od_~G8H#3FHc!# z(3AAg+@3HDXg+CUggryFgOh%12NYHo-`*q^C035E0`H*JbS0jw;l;iR z`iP`-sXt^sSrvGTCV^{RoP$k_JosxJW@gV6lb{~By}Ryz%cD{UP>y<{*jZO{+|ioV>~x?O+)=Y0UtW&D+;n}syu5Z`LW=fIPD8V^!2jG7 zteXta*P2jarkj&m`gf2Bq87Tr&51}UgVhrpy^uXTDx@7nlwyF6F7elUlCFv>7D*i- z+|h-~uKK4o{6_*%IAM(DyRanSRJ}vpZ?OGe42U~Ij-vm}mTqv~-MGo40YsYmj}QsG z-aCRF-)i`UJ@JKn+{r(N?jzeYT!8LFDn8|u{{muoUjehpN_!_Q{y?-Ujk{|euY@qF z=@Bpgord)q$zlUwg#GN%>C04U;@JDl_+6QJTJE}o{pt@tu&y1avl@Q7UsrLj;ml+n zB0O}+9%GD~$bmWPkRSJdH6;yuUn8^RvTd?a)-bWWGL5C_vIO@dhPfrvRFr5^@r>5# z(})>)soZq9!&SqB!cbVW^rLvt9i5u+arEzTGFmYtuvOV5h8TSDZMD;_fhi{VH#ZX3 zVo{y^8*pFpq_-w5fb*^FSAQcoxSz=c?mLjqx+`PRf7+do`ZmeLu^@u`yw`3$<9977 zI498WjTnRhQ@;&PpJH1$+0`w*n6Cy-h;uw&*Rgo3iY<4>9L9>H6KhH5W|wthSk`Iy z1>?p8Hxt+n>g%T7Iu3Y*^S-&H&hg33+=o0qX)?1tHpC%&x_u)>v0$K(7RitqtkU}f ztNg2S;MYzhtbe&LwtO1UKP$pzr$K16)xTraR$w#&VrIk#=rF4NMEPxkpox!3E*j$` zC-w5IYkz9ewy4`UO-)A}I@spgg0XID{@1Gg-Jdin@aBHTBDgBiU=uKAcTjK zNYfLClc6RaN!`9hS5^3hHQuzB=3~^DpEUfU@(O{8YazkR(O%q-XQ{KzAO4w}E5dFBiKydeu z`03WaML;bFir=A%mWxAFI=aEtg@GwSizV*hG*#lW8kQ>Al^kSIs~TOqjTE>lH>$g? z|A9&NE+`@@5_&;F>?^h;vaDyb#;X-%wN9eM`Xtb+q1tbQmf)1;e^u42`nJ~XnTkc< z7Xys!W9A~JrLrdRFdt=i?$8W+Tj0SqK~y$U$xD`rwBm#|{a1d)J&Pm9o^~me1)vdM zGsBc53O=V!7%#EhmVMNt3l5)-RXjm)y}`Dmt+l(6GU{m^IfEz)y-QbDj)5eS#_(>q z0hAtxt4Wdo#Pn7lELSmpKKvD+b+VV5k9^e1NmUi=B^Gn~Y*c-ps7?GNT#_9mZ$(Jh zN_?2DgHrUy&+C9(60~rQinv-|+aa`)mzyQ!O9lrJNps5GFJp z;m1wo;^vDx?IiJXY5X+LL7#;UPB-Wg#ZXQ=UXM*dOyk>eWvY7zeXG1BO9i#~dQI4? z+cG9Mw99${{gdWWZOkj|u;KbUr;N5sm?SHzJmP)sjE1_0jt6+bn*-xQj+c%X`A_#_ zbVD^|(Of^bzD5eXd=8SJVu2s^%*y6iV$5z94lz>j!vr%v%v_IixM>tIr9)kRKv6kS zpxe=@;Y~D(Adxu?a4_tVy#svMtbaE<+U|i)6TrG5)*nG<5A5$zi%K0>4B>PIQ1KD$ z^Nq0wI&KuL=Zy;Zgz?P?vB7sune@c{ECHxR^a?|ImR`H$;r?z_%;E%79QGD9EO4(+ zgREy%;W1L}X(!lv^@0Bro-ifRU6?!2rIx-GGM}5tTCDN&0#vhcI3h+vm#sjbcd@1L zNCBgmFi%Ict#tPtz&c_wb9lF*{M3Dq-Zo~=!WnKCeH6a$)#hOnQZ7hbfgVCbLcbZU zWU?XGJ=Wn4@zc=h4eEH<0;)#dC6f*6FvOBO(#rCOxh;4d#cOOO`9-?rw zVlz?M_mG@vKm!WLwPHEWHrW2fO;SS%_xPzTRz|~<3ZASoT3qfJlh#$+? zo6l+fZ&^33*-qRWWJ{EXmjvBeMF3foE&LU@5TgwrGY9)$Rj}spzON!rvIQDbT;@qj zwoX(c`ZGI4dX`|K?p}|ZecMArs;;hodfghiC7WRCLUt^WJbS6d9`Lx`OYlmNQdE~l z;aY)FUIWrM0sgMMMav^F-_hno%6qM|)6 z%o6IE8g;a(A8TY{9T|#5M8-8;c=@>yt?|lua`EZEa?2a@1~;-v(ozIEc372i(P!U3@7ir{APrVZvk|8ePi|U*_D`)QF&j> zrReVg+BxuvE1DK)n%P(V{9$g?&nkBrzvXA}{iNp?i!pcl3JL zjM2EdYCUl{|3x?VW!HB)DO0&gz96Xo3#)-|jSxWJEhCK87|-PWQ|mrA+dN9pj#n}` zoWXs!ftG$~mHNXM)uH--{i7sFN&1EfYcW5671SUfhH?SC5p_k9VWg*SnbIuc@Cf)+ zf;VXZFzp#PEebf1XVF7cHD@0J3McsL#$(#c<4L^;RJ}g2I5inV&StQC&m!yRM>kc(nQvYfMeR zX?#MBS4&4G#T}U$QY{gGz_n?dT%8wkK;sa^{I&7 z$!j(;6CFJCn<2Il;;A`0U?|VKB(_UhiBEmj$Qpz?LO4CU9oy8@Mu{S2JI+Tr&;bpt zMCD}PZM|!T-6Gk_uac2}2#os6?It^@sKH5xv3Mj;AhB}tXZV|`N61F>tr4FoZk2QeE zqp~7mgzkruaNr>LI6_~(|E?6wLk$1#xqvh|NyQKFHLxp7FlIQ|2#8YT!kw&a#NJ{C z`8=e(Un3kF?Gnr+GNL1?!7m4giUt9%v=+{v%;3dn8n)aM&>p_9H4`qaQ)G0oR$M30!3s;a&dE5-Z#tf8@_HyS7rm(t%=#SJux5{A}| zzr9J@OVC)uq{sGP%z?)GeUArcPWPd~joeKRnT2xjZ4E;rH60<~cMB>AYZSvtuwnDs7Hoo%4sIdoc z#)Wy6PD0n9%ciDo1dY0jE7TM)8!8#+ndQeF3$?14Y;qd;rQ%U!t(W<!c$ zEu5BfcwsW$*o-lnZQEWSzQkyu1hO;e;N^xvIJR{*&kSteD;I z(laEU3y+4kv>ooi)Ai65vT(Hv`yw0C?#7A?bTN5BnwfC=>k}8GeQuf2ixzpm)H}h< zB4J@Zw=zY<>Y1i}dggs9KG|kH|76twPUJlY`uIo*-i>+W4C!wT>~|w#iGJ2GnWUwT z*#0c7$Zt=`s73iw3V!*A=Abuj!_*?3na^dYbsacr4g2h6V;-~PzRKGS=k?!`=fc@O z>;Wdn)dI{gBhWlwrfh}+=Pt>^wF2?GgCx~)k`kd@U#scpG^HLFwvPQz@$obrCy?ZI zLZyzWmj=Ma^Jl3~128}$T&E`nYzkcE)n}?A5O}nwIAi6XDO8GQK}OmAr5wNBGRQ{H z)WKs=-1a1y}(6&lGfJ zWyNgSj%-$@j1pg7UEhBVxt1FgI2jVW!=+5O0Al|h!M!|dv~j>4l;rPy`Q#ySayB!+ z-^!MP4i(OP_d!q&XcRO7-yowiy&`4b_WhcS48`h)0zDgN%-j&f+h%8f(-W~tdl6I6 zStOlqOf?UzaT-Q)Ri1xktutdAl6=Xma({0Xn3l}9P-ceFAvc~)79N)z2F1Fc3rCTU zS`eijx)tw?Raq?9TZ*Y4!UUH;G??^wD+jAle6%eCabIYcwAy5CN5Bp;-*`>^&oFOl zeRE*v46P2f{Y5M>zRn0}_$hE2)$p_#PQ`^lVT67qxh&RlYN=?uJW8LKV)=({e!=^n zfm@awjEpEnagVlR!)ZJsyqL~O6KS4T;To!Pj?={*GyeH^lp$KoQ=oQW;j&){TR_GG zBHSS#{3KG1>tI5s9JRu%i5xZ6;KWdR(b)Z&s+va-V3ygg>#=KfoWK z70f?(zS^9q+}%Q$1F{ahRS5JR=qGNJNz6yvVkV7&pv=%?oNL~ENe8fkwhlOq?N%(r|jzVISHF3koCv5pUo!Q!|jflVQyxa;Cx)AZlyf}&)1lfZZbdsN?Z^&vu@_*^6X&;AeG|@R>VnrC>J-4_RZT&TDa-Xh$ zKfD2=-(yXf`N4XTO3C{4Ceiq__39x}RB*>2R55==OECxFgAKAJ_(NBZgZKE^ufD+RQ(BJDIkcIOg?BzM|4B&1 ze(70#7WcOSYLPd$Uirq5hueV2GLUSEzXi5*B^^L9wnF#G_d<7&t%`m8i)?AZKbf2R zU^Ar2-1hoB~xqLWUOmS3Td@u%u_nL*c(F zi9=jo`%IlUpQJMg&Dlys)QD`M&8>)*C%qYyppU_O&m;JfdPyLDad~ECJ0L>v&=_@^ z(iz1p4Y3w)gI4qVauHOUY>P`&2Vs{y^Xl7%G)(YFODSZ9O!=T+689( z@#bAj_#Y|LJuYYGHPSqerc1tBHZy{2e2AHIWv`k{K_bG@ChJ4{Qw5o-ac~US)Y!T5PtSvoC*g#r13slXI zJD7I?;A6a!J8!b;>=c@C2qfqK3sjZjEgrz{zqGtzQm)|(S0z5*>P`(q3!yYjzlqz> z`abea(q1XF^Qg8kE<0O+PH&t^9x?QHS#T@c2(5tZce55)BZTPVgi zQbaI@or)2jVI=<>QS6BXqV6GvTUA>UYeSBdHKr z9C*SfM7tpQT*?>oE5`F1TT3%C4^vFG%>};HoNzB5Lq!3VAwu_L z;Rhxg7edy)oCLN&+Z17iN+vie(%g_N@9C3b7WAgeL-)z*E>3qqcw?c^RiP*YmW>7| zoT*V$*X#XO6mOK8F=sNTG9H@FxV`;6s*-QUytU9Lbhj%?$oA?x;!9*yVn|BgzJ`q< z0-N+os4$@NPQ8iui{E`4iVv$RC6k3aki`p?I#mkL8s<>sQL|>5|7JkHKc0~=RG~^) z5tZB1fybm@G|De~Y%d3vF1Kl;F+bi|qTYEpi z)~S(bO5t$GKUs5T8ThG*mtJ_bBM>sM@Ael5LFHlN$ik-k%MZUUdMe+1HXT`bHOlOf zlX7|7ypqQ=y(hJ+75w%HES=Ux%}0~2)n@x783g8e)Lt=o=6MY6wrgkQrlpfMb|#{>$R_mJBc44eg6Wl zFkSp_FyXM<$o&&`h5UB(9>D#BQ$O;Mgh{4NINDn*A%sGmx9qrPYM9`=Z)qLRZdFyK z=N^(%-tEEIV6cKjOvm6$y-d6zqcR4;N!=lyf*Mws!g&|o?k-oAhkQ~V`8PEamH{fS z&i8`Y;1KaZ!o1~l+i1GW>XyVSKBCy}U`4s|Up!T1cE=O2ZQaAp?Hr6*i{$As_nNnhpG?j! zR;6~;o&&RVa(APu*s9V~vX$^O5mt>SHR+Wyi@#dBD`;jD1xYPgsY|a2l$N{MGDzEP zkoC{~Gj?!U!|Jv0Lhd3jiLaksH)TIxg14O{^f*9t4ts2lNgohUce{>u(%pdTze^`m zq}sW#2fUObdd`^ji^}ohla%^z)*0j=&hMSwr!!Z+STEr5RL%=3u$F^@Cjn4z#Og0e z<~6Q(V^blimLLgnolV9yE=a5r7RR#x0Z%@gRbdSc?iBm+(NXhelJ(2qq!yRozo;PJ zQ~jK{u)jr8D`3#2<#r9(W#cYAW5#AK1 zw=&8}V?1t{hLH7?y*GnrO;KY*-*oFF}ezzf8UGL;Qzz9=m}#tKp;V0%&bQq_{dq<@8pFkO1OXWg|Mmwv&Biti5sXO zs;5-kxdA(^fmKy6MM)k8&^An{0444)8QvF|i?gL7g0;)(Xr`6{DJq&0bQ|$BZr#4n-`Gdog4-fb08E_ zb|bX0saj^oeB_}rh&JlJ^C!guGTzw=($Hlgwbi7P!+4wqKyG~CKLthp^sAqZbQE-g zzc4$I!aM9zGAnI`X;>Q`8QSAbo|ZF>v|rgZ!~6%QTSd<}$T@j(VOJ@jyZZl{7`gp#kw*N_4Fd@xDlgUEsCkp7Tqs=8!gIDE z?|F@t2+y=_xi-F#?b)eO6XmikjpEff^ImdSH66n?Xz*Z|TXZc!2RXZvHyVc$6<1tA zd%f?x-eW_*UPmhf@9=0Q1iy4@D}PXrq!KEo$G=vD&mCqWKq+>>N6~5cK}e|Pw=-Lw zR;;Y;e3kPJnQI0}w~{XNpcvE=S~?gT=7wx17|5TYdT{U9)t3QU;sp@#Hs0F-JVEs?$# z!wQFM(q@K6Xm?7&<%Mw__e&`_A2pWSzn`huLFfL1zZQhxwRe#BozvDvB>_qdYk-g* z_H;5?+1V1V9~AP+=dmDJX5?ALdM1a>lBh^kj9kVxo!MhkMmB#i!~}1(X&9p}3jV1W zMGJk~EKhM8)Y2q4`_ zEq*kld)Me^z4V9Ps#*G{`1}d&bDV{JHx_Nh9g`g)M(iL@wunjf@xEiB_OP#l0X-C33`i@^v!z6VH%0<-* z1j~5l7|q-+2;2jAA3xMPxyF4aqB;yjxjg$NkP;7~skX-uIvUrp^}*ErrFi6cQstTq z$7&LVR|?R;%UhdH#epRSp<%~P2BmH)fd?$%di&4Nq(F4Dq zOkXFN!bb890jNbCr_CmzF5k8`$G{>*cz7ib)4id*0o=v)EZhISS?P5MmTzX5?dp4B z+$?ha&KlFf-~Gv(btVT0&gnZ+}s- zvHsH6mpeP$+Nn0Kj1nHl(Zg#bB^I@`0p=iDDMjfeN>$k9o?_2WmyF4MOjI!?7QibJ zt39{OJ4aNF#RgAw`==WVm}NekZ_WfwW((^2Q*^Ilf-#IWG?^P8&jpCH*$!y)NxwTu z$cfiD=WgNP>9&X@$G#=~HuEq{cG;YqM>NR@ml9Hku4Em?m{N@-WecHj2jW-MtB|*R zF^{KL>nA;^A-YHL5w5Ncg0i{QD%n7XGgH9FH|u6}rw%itX#Vs`1KZ7h_?oqE1Ut38 z%@oFEcXjP$kJD5T@?7whwmjF8`UmZ_=_ZR7_7wc)1Ha61s%G8gdL;0aEBUBLyp1<^ z{<8a!e z8yaX0SdzbdONN-J56$Hs`1RdfEdCd3cERzTkDP9y&UTg&=L#Yg0gYROv9QmU;b9l4 z;J|E%2r*uEGE1TSL4d4jZK%Nl=}mIMm%;k2^64Upx8O>SOUBkZ#(o*1ma92Qv*3PW z^8n^n?jA04BYIqghWwhw`|AvC3kqy7$_CTs3I>b?Gw%K;J>@kPMBYf|D!eUasq;@w zSgMf3JVVKHQ7oiX7_cRFrL;K3jJL8B@Cg^*ktn-N6$;_p(_o@sU}{~n|IroQN0XIK`NL`=3zfc=;eOHXBk-l&b|K`gN6 zLT%BpZJEy+fGDbo_G_U@0V-FhH6l^gOdgK{I1DJ}dS`Fr&)=znm1xl}82a~}6;ONG zHNm!v{h=V^??w=6A=A2u(M&6!eA=mnTfDA;|CpF-(xyT@4=B93BmRgi9Fy4_LCshv zcZRc(%)SO=Y4(wbf^$jA9;Q72yX%lGmOX5?dv|TABmOV1gk$=9pUlCzVg4idwwj%Y z&itw>e8k?UCqE7>`qAkn6J}5t9E3b4wiphxM2jP@r1c&1sCl1pN-ycm6rCqXf;aS} zL0IoM5&7K|`z8H%BWo|XurQ>cZuG{cuyr+Vbh>A_@jMgSI}(azk3?8y3n3*W5S&N@ zl^*;vVe!pl+}J88{a2B+7fY}RcFfIREuzSRv~H{X4a!Ij;}R1>SJdGWhZ3|E{L>fF5->kT~3BN6u2GHxW@XSyS+c{(H^2z+{iihWz`2?ipqS zQ_VFgi-UnplyiepWDP`Q!ul)dG&ZeBAWGF({z`q1a`_^UEi&;kM)9P<0=e66k-I{1 z>{>&8k$Mt+HpPnuCT83BxLmuMe`?)a04XN~Mf6JSdq-?t8?HwpMg>dcNg`J40T_ba z5kd+T6A=w5V{lH*?U@oI+Vd-H)q7dt5_adJ<0<{uJ|P4^qyBT78REM2f-B~A`f*TV zd+5sT-kSgKBTj7==8HIce3uMuBVN`K#{A!O?O>Hu2CpDuV=FOMRPDLwL!MB3W_mf; z#J0;nIL`hf_&62MbnRKOnCiNQEvuOvSK z6gL!2RfW!Nlh@}sPe9O}*5frwE40tp*r^1NfEOD+R_jJcm$K&^wcuEARUZ+&f~h<60Jvp5uL6WEe`5+f%^=F9&c~mTyk)L24&lqX zs?}lM4XQRkfA`lxCn}b=sk{3lE5?77IX~a&TLVu$zu3Q2p35w7!?uy_x4okB+sx7S zDlZXtN2c$!61`ngGaPm7-Kdd3motU`RMLX_5c)`Bz8jFcP#6HaXCQJj3wA&pb^ULj zOZ|zjwy6#dX}Dx{;eZkAkwZ?2ou;-z0QT+l%B^@G9BO{j5$Fz_)=epz8Ta5MNfL^n z=jH$}1j5vk*-(U3$Pyj@U{S~zyJ|nvwa(+cX&%_t%so^8CI6g4bU48S6t2w7ed-Nq zcErgb3nwtjYzQ@)AeN6CoxA)2ebnFxo89a^#Y|F^VOBh_wiRjyz%dU3WViuH;}(ey`&H58nds5cuIc0rInMiv@plt2f zk&@3=mg`r~&dvZ$FPWyu9hMAAjMupM&nQA^_p_`{Y`Oqf|O~4O~{@gD888ffk>qyJkpXV~;x& zhY4PMUPF;VHC?KMnzyo!ZSD^rrqfoJ<|-#dlqMsH2#DK`j3-LOjd7EP$ru>Cg>!k7$a_-E6l>NA?{i}0^dMy&vd0Uk$Br`ZzMrFW$ zn{0mk_NTN6A*&Cx@|)>3Nze~{6zT}}(MHTKmU43Ky+Ed0=g_AVKF@NaM61Buw%0mjV^BgaT>ugCkRyT(=XgZBu{<2pVr zfd6J_oK#c4vmP=TH7ylIOw75m&om57j9P#2KCfFyypx^Ll1W>VVnezA4L?^A=bxK} zFD7lN^8v?p{tg+|z*d(#iPI;-98lVa392WO4#u&zu8SFv)qVZm9Hx=^g-gTKadp(z zs1hT;Eo{%viu9MzHNw^Edm_n;sOUSHcI#2Q`@8syMw8E?o+PcOydAq_qVN$ycd0tVF9Ll!3CeUq|0 zVY-RP{i~Wyg|+J%rdXSgl5n4Jf};Ibdw!2L7(Uf``uTA&#-_0lB-)J6tEi7B(Mx6o zYijTKtO@_KA2B2?&)u+|G(+im0;Oy!5#yT+fbkG8^RvM&x1h*QhQcve9rV`i&K9GY zmXB49rK2<UgK!oj0V(8#vDxn{mf#K3dVtPM6BUzFAFriuGbY z!=5=^bT+|=e)$a@mz555?_k& zpuM_)NG_xjLq1ngZpkxN4$$c9RKeX)Q2C%}D2<@RH3SdC+)>^y z7pLqA>B9z5Jkg7vx2(E_y2I(mH9cA9snU0;pryPpFr5A3myZa#1NK{7@^ZGHx5vZ7 z>6K$>62}9gp4Y|Yn>@chc%ARRpTY#2;(>$U6MuW9fzs(1v}aX2*r1v82Q02x(rG~* zJHn2}T%`Miky7@skE7B*-JhHELS62ON zKG7(9p%y_VxeIC)EDEv`?k8!k5>z*As#1K>MDD3_70fdzTSbZ{kxh5Qf2DbFZ`Uoh zDPy&|(Q<#Jct`eg6gr0VTp%8Dycpkx-Wj~&mzhbnY)0NdGdxFcOF!bNM;s!+aW#73 zduchRZL|`E4dzng(t+H2^elsYpBgj^G@BA19jkEUR`sGfFe+oc1qq}i1r z6&k$1@>_{RQ5VszhLT$Zu!S;8IOA>(G=g~4__sr)?2|7Rg9O=7sFkxPMI!w?fkwBi z0|s_0*x<6E3PIxo^l7aBhNRaJS$Wj@pdS+73vnYmvivN|UTSAZrN1_t>-6TUF1iV`eTIM+k8#`2$v@5^KW^RNE)K??LSUXVX z@*|^UoFz43{;AS!RQv+bwzMz5xD#KV>hOs(z56AJr;D_&>6Qdc8l^rwr>y{RA5nP? zQdzSV5)Me&85 z+Ixv$R62hC$c~fd6S2_0wB=l0#i&-MZ;2~yCZCFkkh9(QwcM>!n5l&Ox?Ex!B;KP^ zJMK>XKbqdMq0O%A+6^v6Qrro}t+;ETXwl-uo#O89ZY}N(!Ci|(pt!rVxD|>!Z?1bk z`@{JUlC|ajJURVTFlwbh zbbncOMs&g0?^j37kX*ms)@mB_$?jy?XcB5v%{2;KYW`;AnFEiv8ktMonkq5K0 z#&hNBM?FVdbVf4ow|?Hkba~m3pCI`KqAO#V`wevj{=&LtON+*{hZ4nb8Qx{3taWR* zz?wzOjoJvKRD~={DYE^IOPDRYy;8I;)(8Jneh(1$sg+G5MQ?jSc+W7~Qwst#dX^_t z9Js=oS3t3#jyrULAm*s?!6uo^b~Ci2#T`|m!oH}iW9Dw$NgB+=Pi3@9cj8W(r)Zp~Ki>I>MZoJIBcqsMU@{>;eWK$HdG$AP1V&CCXJ@gX6s$z% z`@Gp_1Kn|oyv+gocS=sus>G%Sn9ZfmT%%2Aukgj$ZpMEHwH=e28ov|aV+f!!uT3rA z)StN5NUGTYbJiBcFVe_DA=ZAcuzRc)O9CNeIx!@%Z?M(rq5>UuetFqoN$`PbxZMx~ z>bRG_m`80KWc*_FMgL1Aut$9SlkBtUx(>(^L|ZCcKl{xtVJ`G!-2WQ<#K`gg3o@^Effqvi_Ia}u-70qVCTynS zn;DJ{OcI#ky57!w;P;F%_jL5WS71Dr`p&ZBO9mOgQ6-Llk`Nj#aXAdQMU;lc!-Tni zw=1c;%Mc5o3S^xNW%0%rq$rhb>zMo)axM~%iI$Dsr?J42An3G!w4x+su=E$H7VlTJ zi%Y#v$I=p2h7h)}8kyhLWz#6Rl){Iwr^%bzbE50O_YW9DJNWww3r#cgyp4e!ZopJ( zxheQ+Y`#=QA}^b7Hu~z?`>dQ^{OYQtWUyPx&H*YA;sS044TTIJ$SWUoa%FgD~ceFu0hN`D9ENN*oj1 z?N*x2xR{OV+YAYbj7QT(7=9!l145;`${oeaV{RVNQwd=J8fJcaBT3-61G^u(uWaks zruZbx#IEH*XC?5*7JEj2;9dy73y3?B7fREl6TdO=_6{^eK31&AlqDn-WW=Sifp#wu z?uQf7R|;YB_20Rg6GFY78b3f-UGxaH6efMimw~-Vc8BXRF00_|>xn{n-&%K?w`jrZ(lb=rspO*J|3xNi*j04vO08otRg0vU$8INMaEuelWj!A z7d#zN#A1$933*}E7nr7JIkn2VMHjlJIE;mbOpiucZQ2BMxuh_``;q?fk||AD`36Q! z9Gl!aG_j{7VGuZsDi&os5dVnjDA04^a;ciuH;^Hx6DBI zbcjVtVs#!(NC5LMY7shFHeFdiS80o?#KQ(l&L;v5b5Thp)*IDV6^(xmOizO*U5muO znLxA-$klm7B3yw%MKCo+hJ4EM=E{T#nenV^AAViHEBbut+Sj?_mjJ4cU3qp_PZ)=V zw`ZaDL3_QE^35n6jtWhsWM#IINN4T9y4W(w!C2r@cbLdsK2omd6CKN!42E*xbPhiY zo&N^|L4R*yi*@C|AXC3ctei`AJ|R4pd~evCgN0{;#LrJ$X|Aef%`Wu)fVt1X7???` zO1>K3t$29w+A?k>eW5Sl_-mqkZ^MmG%-zJdeiq9!Y=J{ciyls%H|+toW!XQ+O0I-V z+*4)|NlisDUQ>_;Wg91O*z1@LpLD=pU%7d0DVaQaV+Ch~wNoT%wlG2Vb_gA(z@NFP z#R%(1((K51e26ULzHeO0S81MQo{BQgvBDI~&HdY~Ql^aZ*e2o>EMY?-*EJkK&fNV4iJpjcJbvGC3$@^rx?qn)q~!?89WtOdsNdmmS|3-B1lVG~=u z8G$TE!m3SqY*4dA=~M@U@)myiMO@<>?EaHf>GQ-(Fk_$PiuyE9LwQIOgSHc@3|u~| z+`TrE`qUX9{r!ddf>~{Q79pL;__Lu7jr6yvU}J6WfX|H~BHJMF#86TeIRFJ7i?cLh zgP4|VE$|&enZ8-0JjdzmyLyt&p9ZzFUx3~}eR2U^uDOfwVEgc4mYf}#mRx>@~5j24DeoX+!cnsq##(?u_3vUZ%t-@+B1 zOx+O_8F!1<+NfbpBn>;r^t5`MFO7Z}amR8IvZbED#HU>HQ1@^{3Bx)#SL#k`waLG_ z#R4_1(?Xj342J>xL?7n%#z;wyhaNQrCVQLaK6j~MriZHJpvC*K(WVvrjrjXW<_0f^ zD)P0vZc2q>Pausb^t457S zv-?&Mw0(~@;H6y6CH>m&^_6S2*C*4g9-+o7b!;K}OX40aOK;(ggH_5HmMr$FSvji~|d=i?}2O1}>8J`K>lgD=?({e^1hb9g>0vk5lAdo~_d{awrz{0Va zV{-IdY%*PxDwHf#P~VN1|JD`5nKy1&No?^i;`1$yHXu+r4)0A#F#$m8@g<>A!RQ-K z5e5PV^r`QPlXW8DAODaP9$$e?4e9Z3ChL`bqHuzrzkQ>%s>mW-hgN`(J_nu%A*;_h zV}vHlIxV;N6YF;d zS|f82NcN4_F$r6~^sjrx7hwBGD?#i1ZqKwErMPfL3j?2Cp9Mondcvvi#KN>3o`LY9 zXIY0<>4UvdQ$g#VwI#lcJN+vbQJ#E%_Z4R6s=U_>g6RtAnJ-Y9<0VoRwsNz>e#BF& z%%P^!p53^a!XgCYCtbHL)&`Or`yDm0(F!nW?B6USWr7`fWWC+orYY4$=ox^IvjPE* zqR>i&SEb}TiY8upw6(T!eerT}_9x>{<6FJlTf%8AE=cgP9ycayw1oWgeInG8!eYS2 zzhof$Ufw~@U^m;|4tfse4?vRuE_;>K%m~&_Xjrn`ChwShm@!JN(%KM=eY5+fpGZ zZOoOjWI2+opj%4Jwnr5>2{Z7StfZ`XWNo?iy-H$kr#7r!F{y?BY@33mr^!NpbISI8 zeC^LpF4aS(s($0VOG?&;rLE14!hipC%r1EOl^hB%$J!qZ9HYa5kvf8Vm=LK{Yyu{P zwZW4nN^s64+kRBQz)b9QRpKC#z*{@36P7!_NC?gaybPQ2fw20z{~S`3Y9A4c3?A`u zV_tvIqD&wta8x%?4nQARajh8|O7t~&&I%2X7|10?J}~L=ws_+QV8-ZgeEr z$Yu*M7mII=3G#FF5Of`63NJ*MC(`ugZ(Q3hx-^ZbX=2NFCII)My6R*jj26@=oLO(r zRQMZpK>$_7`Tid!C>vYUyamQQjpt3nDe~K3bllkLC%t>^@EqSU{7FL2_ePTC-okle zTTxW5m@O@X`o{_1W)a=WcK%j}=Tal4FS5C9-;<&XL_||wez#|0a{fzPtpN6>+*46v zi-6N|Vnh`QQQ3|?hd$?UIlt#87T)NhzbcGM#~>OS5m9n*Y>GN&;X*O~o#keYn_(6N zKHLP4KyWo)rN6M4>ym}WN-17tVrgOZ%gZf2z%L}0%zj{cKapS29K8|F`XGz<=)U2= zJTOlh8}Hzupj3zh1iEf`kZw9}S3Xn+4A<`j3{w3r1>jvan8Vce3tTZa zhAlU`T_lKbim>c5m?$fm8`a@p;>%A2sq@I0*M);uhSB#c#F^+p!b}a&?l2zR&+m6d z!F`xQ0dQmwNW6BRvkk$Wbi5S>#d~<7ep`{yw3fJ~)ZT!cN+mZNk zypt%ddbiJO%>g=Rx;%TJYb^~hUUj_X6^U#Teo;u`2TK7SesD@z6=LTG z&t$wJSOc%A>KofH9v@hWaM)wc>1q@_f}wOd31S%^o4`u2WfF3LrPKC<)}GY>v~D3s zy1zn_(G`IB8QO?J5GK?<>#=4uO87|h0EWR-l z`j?L~m|u9NYCkV=-@T+hv(%jI60m3h7VD9w7t^m@vh#}4aapg99{6`Pq?)T9SZv55 zoa6$;_OIcq^fFo8@su;xR{AE=2b?)ON7t$$h!uD$`Z;&+bcVxWs#NSeT~o_@3bKgo z_@H;({Diiop}iJ~aElI*Wh_+2G3|=6mif8S@%he z_CWWouHLzmvMG;4XF**IRN4kjt@sECL{{X9{M61b)LY5Nt=4>PGIPuL&*6af&o`#q zAtt=^y8kl*9w3sr09Sk4^4-os*AHj(M%l0k^k1C)ayJ9u#h)~dofEMa#Mj^WhU|;r zI5OZ%a)`>?I9?cZP8mUY{anV@ z7Kc*w9(g`A_Dd$S#{i3PexkJOSh_?q(<q_&&*opvG%_YS5>F|A-cvbZyIY1~fp9K2gs^U~`Y(03w*{b1IRQp@37>9@ek%HSmDWzQM>@yo720ntpd(5@gV1rAZhQ11jK zzt|I)?FOYl00i~5b0K~VxLK;h{xB0S;6Ih;w6nNXpJ1Brb$_0uC}v%ib+p5c^iA!= z?2Pc#WvM;mX*!m{JEZC%^T;Gu{#S!?m^v6@D&IVKHR2uqXo-t;s&VrVPd03&mbJ3Uv~6o>slJyVi9>@ycs=A%_4|#VMG&n_P2pDX zXs1|rNSzQ2^RLMX)ZkUfz0q;h^dl#|WWgq;5iCd#oq2PP)>vf@FD#3t-EnaeScQ5> zxmqfOm-f_?B9BztU9c)NfCd$TW+Cx^6A1SI$pGofYyBKheAB{h>}A6r2VP6L@|r)$ zpm1d8lBOZSR9ik5XFeAH`cArcAma~4ClGsE{Tq*=t)7PspKa_gagfnMSna2nKof2} zRSn@wRB+*esBM*4Yz!Y||CW0^e9r0RmhTEHz#dNU^rXJ@8+`3;&-;Ma1;aYiRwq>j zA1(XXyUZ>e-SB?auG~rp(=108&3p%3>VbgmI%@F_nF;mKI}(D!p!_XW6Ig-De1boD zWvk+sgiedG^9$%yH#zZ^=EaS!(~yl}JkLqxXSc#nW>9G*Nl@mFtehf3of9yc^SKYu zkMK2~X7NY*jwg4EI(1G{f*v&lUgyKb!^&2QZUs&uj$^Q*z{hMa_ZZaGS7(C!6odSe zu^q)P&iit?+EmC?X9O&lqI3x1+Z-rVTwm5`nHNTX=gVed=+lygHrX&Xn2dfOG}**g zQ9FP_Fi%*pgy@*?_AF~;=yCzQlR6gu4rEP6D6mk5x_vt`LO$r*Sw9#xxZst0Um9BH zeL^=LtGEq;V1IS~I}5sCsKG-Semo+L0dQCR?(&cYN_U$3SzqeKEsvc zyyzB@7@8p=$S6zvlDp-lhe`N(E&O>t*EVh51GPZ_$IGC1@6X!E9HCxh8*OHE}`)kiodz^EAH{zN<#5lM78+x#XYP+`{wO7f)Iq7n<7 zxltaX9j~fmUktvQ`W}3O+bVwokJmrEwt_00&V(3dGen7zm2J1BKhsYs!RIsY6+Fn1 z%jgQwM@T z4=b@_Cz&8~P9i~n64o3Wl{Iw}M$eSp7E@*DeuHcn3LF4e-^no>9f;Ka2}2j#zLjXZ z{G4Md-kf6|xMNiAhu1%hrMyOyx@n((5K#!%PgU>>5iI^vN9!NJ)&A)gQV@3mkCNNk@aF{_ zF!S5QUadkFaiA=sw*yH^bK9kk^>E<>Dk697*(WpGkz2Athc$JmyAcv@r9!u)0D!?h33x{S`S8|sQzG_V=`6fr#9X&6LP ze0a|k$TL$n){bZeTM$wjeMH7sV%3}4Wnf|v6}^~b^xeia95U>z8SLpkqSl_{HfB5+ znf8Bv$`gyFfBHCls7LRApX>2wKPDigifAxu=E~95lX9^+)=2UMs1*OV?{;BhdCm)l z_X8OGe0v(|>Gb|PjNu>2+F&vGuysy3bh6Y_xRxDqhP+f+! z___-BW%p|WE|fOb13MUd(RgFhyJ|~6_fxeN6cYtr!$OD<1D7D$cuBYbEgU|4D4W3v z#011hHp%ka^+?!hUi;fzy%cxwQha~c*waN6)G89X=pKjpDJ%P;*o_kJ#|JCs(?FQ9<${SHGd(Q((qU9;ZT2y zch|1+QtX4IM1_xnPF*lxiaRiM6vVw%mOTcet)bmN`&Vt$<98|Nli*_F6osK`zd&`6 z5#~mXw7y#L^P}S9H>7io)N*qp0U=n%*1}45t+uy>CFjZ2FhE@e$;RFCd_6jOmJnQg z$PW{tMYOgLLe*b}DPpRWdEq!Lms%KYKj;+5`f7%Q8RW?R{^W}8mdS;y9aR1K8@(K} zcO0Y6g?@SwIS`mFIG=Drx*5NU$tVtA_v@+Iy@?WvMTTMlHx8eq7Izr`f~#q_89Z3A zRcJu;iRNBH-OvN~#y5Uv8~Z-VQ~;GRW+$UwE>h>_A64H^P;SMz&W+^{dxJpl z7)9d=&)YE-dERF~?kmNXO$}VNzLNBvmAslCPCA%Amna)e>}#aHj+`3t4a9K-TCnrL zLh`T|>F%V}dCq|%h6G=GHN&>83$TNki(Faqn7nr}1`Ds1kz#Lde4Id};h*e32dpnn z>q(qwF(I)#%L*z|ny18}9S63P8W1+a3pAh)!OBaTUtosG02Q&q9NZ&Z=Yo05GPSgJA&Snatbdv1YY2G@L*Nr2qoi zW_XFdd8j)uO9>V@j37jus4y8QhO1JFaMq9noO$tcphn5F3yC5XGFa)X8IS#PF&}At z*u>Z2DVK$Lkh@gA(sPi9s`Dn?`6mu-a_Ysv_{5uJ$j%>O(#Ms#k8$-(joL4;)lMu= zlgK8HdVY=rZm7|B5~`lQd$rW}88iES?PToBq|j79j*9~1Hn_2sVP~Lx%MQz)zUr?j zbLUO6(2rLizY>KCN-~|0k%ZJ;u-4D!v)AgdG-ANi3JS(OT9eJ5b`$bbqR^U0e5B_n zHX}cuw8|3-(#@FeBl1JV9z3^&Zw&2-Ure}Qazx{yRW=hvYxAq|__`D0mz%f9V>d!$ z6C;ch_o_0K>W=!(QAY7{cTsz!nBf1_`k;+N!6U@Zdf*tkSt-A{eKd5#^YrbmZQtqf z#(0YUFDngBfAk!2!h#XS?h}lC|DMzimA+lccm2LJ{4-mvygh!eB)sWDY%zP{xPmWw z+CvwmzvMU)g|>b<$m;nv26sl9{@`I){EgmyZA^UK5`%eF~5R{K3@h2Icj%|gZ1W(raM4`9wk5V>@hhNy6{O`{vMN|`HhxL2S0k3)m1n(cMJN>Hx zMX5fo9rR8I0e*OFucsb;PJfHiw4sd2r_pGGPQv^FJc5hr#?Is2evuhu*2gQjTCKD1 zSBp1?!B(bib>G9;Q-8x`CYGXccU}SEnD!*t7!PbQtZO84vLA!(Ahc!$1`q@CM;u+u zfL{nKSJbh@v3~z{aLZl>Tf55gIpd( zHRp4(pD`nAUTBUr({RGuQ(6$4q%w8-$3Ff_y?clum#Py)r2+tM* zCHb^FGv?%bohGYD7L<%3+5YNhwbg;Sqwl9;Mr=YC(*7vOU29)pxE=ztBnQovp$%4r z`k}xN2O@`uOt+tx5F@Fhb{U~^Pv5cmm6ax*^$-?L$8RikPrH5OmrK4_?&%?%vgsLW zIB$saK&a^atR#sd%>B>fu(154ov6lLw#Arh)Yl)){*sG%Blz%`u0{_-u`0L{NR|zYqfRf% zO?Mbmmm*73nd2nRI?1j`TeN4m9z8jmGHBo5|0J)>?MwJutji`HwuFIozb8$dl}%dPbC3_$ljyT&i%YMcXg5icaV;NnWZ<+ClYJ zq9}g`@VK#V#9h;9;t!&WAe<5qpcBS2O5jHH(zVG08|?KAG}NI!35qi8BE?7u6`_(( z(!2{>id)8cLUW!DT}H0JhVHqC5qgr^^I&j!`8)_iD4Y=U6YacHS7P$uJD6*(->=Q} z;t0P4!>y^h$9mJqr{=@xG`3#@VyI02;qv|>c*FuYYs81uKvZT?7aWqY$C;juRQ#PY z+@};|IVV;54(fDqG(6d#$(BNso*`dYUxJJoC;a3?H<%KVxJa-b>!W7?->e-g%9h{W z#R<>IjZgm6>$vX;V+j7UEXYSs3fgJwdf`~JQIuYDGF+;xdtLPabiOQaP;I9Wp7o7u z6U|9xsy~WgKMNVCz3{3d$?tgE*bz-!qFn5bDueki1OX=%EdC~w#!wMVQAJEQVa(ct`pHt8{8;Q08?-kEGb|r(@Svhb| z@Z0m9&bqfM#@5R}n}2f=TjAuobDk$zoNccVRduA#A)IWeZ4c&OxkrZmI#o2ZGGTq3 zND7#jRyG>biCOyseh2r%(0Z#97vo@b;*P{NH z5*>PBT10|i z-*4^7sL?Ty-Zf9t6%|>QA93+A!||F>y~UFKHKjkv$pB?9q=OOajNazBAt`=Im3=50 zscny+T2=5 zT5~tw2&dI?!E@u7ft2F7iQc5!5}#!a@9ipDZNk=3 ze;>xU5+B&8DJJ_VF+iNe(spXpnjRz}YG3|wb0&1Z*K!rL)urTXYkkXUyGAe^*4AQ+ zgcP=zl9EjIpwC1le;qs)A^g3Jn`W{_a=?2Udf;&X9&6))g3PxM2@u!^_dbZraB%H* z-JIW^cP5tp_1ae*1;?r;7k1g&4eM=7p;ifEALzR2DcZs??iusZ4Rf zVws#Jjo(b!lXGl$d})4kaSC$ZLg4$@jEa3QzL=oPU6faE*!UQH%n4IwV}&}6KNn@z zFNOcEdGd!1cHSrj`tfqTx%dWFY_OnTu_;r&Rg z?)8=|h1q3$-@g0;J*Z6xCX2L`JdkOYUE^S|_qzauT`(;!{wdCA!y2AK9gunYT zpEN42>~(&K|7k;Bo2Y)A{L5!5Hx{Bju(|c|No|?2! zg<9LhmX$MsP$Ly(wh^^rmGs4!?$Lu>h4C7zina%24T+NeN^YkR;ng5l-WtSc*vPwJ zxa8;3me<2$NXRA6yp80bG8U<<8Aj%$QF5;jMEU&f)nI*#^=x@RWg!|JuYb@-2~<(j zGB*J-h4W2XciKN%Yp7OXNT@xH!0u564qmuG*eNF#qCyF(67ap*N`aIEs+r4>Nxdz( z+Eq_Mp&vQqZK&PPhARigT9nLXV@Zf$GFjtNLZ)$gb+iv4=*|`V!8~q&DR)Ng^OfaVjBE-NHYHMthKVn?VJ$qq z9hAT#S*0))mwyylU4I!P8QlXd>+Q|uO)f?AIPNHnQNq*alyNi|Q>4s72MuJqc^M== zI|=4DDftzkqmb)xjvAU&yv7wolC&wo3KzmQ- z+dg`~YCd1S>N@FPp-ivneUba^75v%e4NWp4AckvWDgtLUP$5<3?DS277Jl~R4SRlq z)b5-IKSPoeoTxAoNzb|LnLplz&RkCPe|AT#uDf9v_^*o5dgdWbzAw>6de5c+H4d6(kX3hqsKEQmlY+=V2huM1>Nii?ZW&0W1ew3B7HCRe6FqUhYVn6@5E(AF@OW898XGKo#ghf zpE-y(-UG}Q68ffPsH(e(0V-RI!S_w6T|aFOo%NzU@W$*bv-&pMvo4mCvJ$I1(FqL# z3ici!X0tPoGH2L{LW>Bs3gnkiD zW8h?P_bELv*ZexqT1(J)GDLJFlOXg*#_hLxPbmr~Ks+q{syUB1H)Lwq{j$a3xkl*} zu=^&V>Ur^pruRW*LfP1EflCD~qJH-yO+6Alr6?kU!2MSbQ>a1FKMM<%SkLtb0ephh z>>iRzk5j-dw|9J;cq#*ns`R`wTn(+XMsPu{9EdA?y(w^Y_2+;6W)XJBXNzDffr-h! zJ+}z_qB8G%2otAMk&EHzxAR4og}$q7xA$M$s^l@SDvjKDW@oHk`JckFs|zNi1tlrn z81QJN^*OcDA7s}@3pzNKOH4*hyi81xbEwbZU0d=?GX=PZ48umu-YKEw$z0uz!F)}$ zl&EK-g9?)~1y!(ofWvxh!~p+GgaeEpv;5YPusz=RGo2_~QjoFG`w^wh_#^;+t=Z-E zLX2`?O!SH^{vqbH_t#jeVo623_@YY*G0pxM=v0h&xYDEJ1Y1W%14}u~Fs9?(E7t3= zwCoCtwXUY|#9mrafkw=|ok1&{^IvoR+iAl@{udux^j9<+5@|B|MaRud%gcZl;&NuY zt@YKAO-ujyOq(rBNW!fG9rAz*mkVIKSVrX&$Mh$@_{a?I+lTJshEN>q;BR^?J!nPH z#U0OEgk{g8cXm&3rAg^K?{mKYJLKDnBlUPebQ}K_En!#FsosMwadGjkP zyNt=_2flmBdexKq%`Th{E`O`E+BD$YHfX)#HTuJQCPN}F0zmTI-Xu-Zy|cGnZzm~a zko3R3YI7~}IM_CoWD6PR16%cA)E}YxV&-UjoKc{PvNBZ3KC|B$l10ovCV(=0QLH6z zXE*GpS6>Je?BAFN$yDZ^1oWDpM8`N--CR76x(x{f&Y`M+9t42vMm_l*#S3NEq1*5S zk$(6HZOC+ZyWEK~ze~46nrW#3Te%$|0s4?@@G&lR2a47}1HO2di1QT)kAh!+B)}xu z3E;4C1rlI5a!ZyK9ukFq0zZ!^*Yn(yhR$Ur@xNV(-;e|{Tzno#9N9i#kI{h1hy}zl7&tv-!jMSn2O$R93PwN#i)@KbMDWs6k#&xobKFa>a2gwBr3wi` z>Y2tOCrdeIN=9o}{KUQ}h5q=iqHj1LEd3_>E?y~b%gVYEQq7wq2@-!7?-KlB4Az$W zuYiHGf=qC+HNRk#KcE^!3D5r5g+k?`sj zaQ;GpIfos9%V+nK=j}wM?Jc1e+B?jb?`DUBoaaTtXLpYXV9gnk_65qq7oI{xA~q47 zH)QIqH<%cbO)?A{_HXTvt8Sz$!mx0!KHsefD9UVHle;USH9G&eWph|prdpa=*4UGc z?Jx%7gwb*wD>^5#W*qrtH;wrH(u`956gVhpyM znoZ9R#XB(X;|6H@nCZRA?s?16KkDbBZlsu_~B1<_!D6r?_mqu0f4CW2IV9q_f0yWBk(tT=Kvy&N(7dB89uhoq@ zK9D1P{GaahV4j4gU2o>!J>|;*Lu=V$~F6{5D5Ge(#JgbzSSUFTu#6a$)5(o zs-dHQLNhLxYHjXMIpKb=VGF>I5IG%;e+gkoTEPb4`~9eQ7wY;>PNa2u1*~lX6zl`iT2iblhGR>8bh{1;bkF@U zv4?qljh4FIs1)oOeuV~z1O|^d1f4%m?HvBIsJavhOB9PJP2@Q@hNJ($CUy0$Cn}SC zEm7y{qwq~{0t`~D+IWI-VxlM)`7>@bmcpAlA z?{H%nv$jgw@42^mj&Lvf`uwz1P`gqUmTwWAfqg?+Xb@8$~Q+sc)$bzjo!x z)MDdZTfxrvL!n0|+m)V~f6MY*sF=9cD9FYjyk6M=#DShS_S~MANYj-v%bH#YX+ZRa zHn(9nb~TD&!a~|N$-(+9G3=+=bN1}UdN%;6u9%q`B&jcPd~sVj;T19K)*i+efzL(p z{2&pZz%j|z=a|<`(yN8gCx)OSuN$AjGzg-Bpqn*yIH>@!BUHEQI-0idL#B1-(+dGA5bwdip#g?rW*4Coesi2 zsO01x`%5L^cJiVxR=b!##S`2qA%J`LZSFVzT6erqcj2F*yoYeRkMu%ym!EctlCoJO zg9Nrd!kfYt5|Rhw5t0|UeQ4n-A4(u&;IZpycslw!79jF%pR;Z8eWrFw7J z;4jF58bgKw7npjQH;nh8m(Ire^RJ|_k>`P8DvyikXPl&Y<~hTNP~8I%m+}kQ6JIQ8 zypQjN)QIgfBFR@sqStXuoAXs(hPgX0W)W`Qpa^z1DXGCV_;!eROFuvvecizzEbVWD zC3}s1=3pn=BllnL?Ju_Lnhqkp1?MGxreU`lo~0SFvGqPMRsQe2+f3A7*lXgAdbZui zX^TtjljC0ly`D3;fmhcW^t*9ia@gD;{dbQ}?HHg~S!c`=Dh(YW;sgypSEAM30!c@3Z^TI10I(Vii;{{&CE zzzsu zQ9_oZx;f`_d(Zl#!%~%X`EN&nnQm0^wY46Bj7N$-GJ?16YlIc zg8N{e5%a1kI%37b?yyUi z3-lDtpJLEGNyk(bT^QkDD(YfFr%UeRa68lzD50Nqw<1(}I}AGP%Lwk1Ctt7<@;dNs zT2hzEH07=ys$G2EF6q0v4$0;(f?0duFPM#RmZw0;h`u;ATeqg4l%6RP;-2vcH$8vu z5iQ<9n3!f6iC-cpLO}6UZGV|;HyalnF*KrvuN8mESL6u7UZ?D0FhJ{PED36%tTGkU zpPbDZOCD5ZH4xWD0Ul8AzJbmSa}3zZ>GCTr4ASNx&doxd6?x!sq|<9FNx>Q!>aNI< z(?2r$@xBsLzt^ke5yS*J7R(PTDeXV*yjc?fv-1Ncl;kwoQwteX@bRMrCW3w_nI$Qf z{;q|TuEGin!3$*f*}d+P=`J+5DcY9|HekS2Q}YZ_z{hSaR_t<7q<3qY`5wUNQGt0^ zKMY-6o^NkW+9Ai{=+hO}EsIKU-?2frR|KbM6YvwOw;d{Egpl)r!Flg|f5O4v^fR+B z@0(f$*Jc3Q9vgXmW1?OH$hjgl)&KN=Wt`O8IAD_0^pN*FAE#Y}i{!ZocUM?o70|Qw z9&s54<4^{&@u6N&H5pzxaohbwh5V4>{Mvv$-=e~hz^v{1|KC#rsLlVPP%b49!h&jw zkBiUsMEFr3B=oyriT`o2#@`{lNz@QjookLF>Iw{r{#@W?SN?)zxx{uaJ#fMyFPb8V z`++7B0@7ZB8H5osDv=Buk^95RxBW=;f0j%{%$J-?@KypUjLUUcImPI4${`lhqew0M z(y>ulqM_=iI`fFt!w0198|cE`%oMNxt~i*_BsLz-hlUm5QQ>fu{mI$ZUe}(J?m#I# z(ypb>JVaIxyt96HfBxJW;%m$$dNuy1XwElcofs%F=p=38g3 zeZQ|1NLp!3Ia=vPOgP919@JU1R#ur{XiLuciN=TER=|~sp(*f~{!lQAy;1U0T2<$? zCcaMChf37YST)SGFlAN2f!8&*ea9S2ED7ky=AYs`+EzL$@8`R=x?9qVR!^6&`uSZb z-+P|X>3ID}bPcX2|MaRS<5S5-_<&0PSgpGTnQu3If@Xez=5^hawRLTixkjbgKWMGL zp80vML~I@(b;09FifMfHsc=rs2uMWc_7m?@hUg6V8(Q_kzG%UDW54Q zENJ_A#^(^tu`6&jvFnzN%-XjG9a~Pkj zN!=wH(y^8)^r~ja^E{4!=0glnzN7uJ@era+49v|iIm&BEoR#TzH?*m3O+X%SkM1lm|(G^KOOsAyWV4+tG zek)UFGtSuK4!70D!kans|Ag8?*iL)1TN37|d1tlrVa)Rs1}O=RrlB(BO#S9uU^|fd zgLMx?Cgg_<%ejM|MXPIV9n!feSl-YZm0Q46@a}FzcVW*0q0cYlCshvoianIzVes!t z{>N(5<;&DO1b}qzEAz*S*MMciyeWF9tccp)l!?oglkKn*nV@jkBeQ72;g1xjl<46& zED5PkAMqvYN;kh_RAsOt_HrEN17eC~mX^|c%vaZ(nK`~PeCCadSczTe&5fsLSwmOx zV+FR;g%%s=ldA)Qziq$)HIVr$k4D((T;RowZdX1WJzVLCpbA%%Qi zD3ne@#A@JT;fcp;Qzd6Qajqc7ceXqq6tCDZT%AQ$5#3_%(P$AnL9`&Y8|(JII1R{x zJbScwpR~eV19|a*xTqu?8zG;RiN<>#Tb9?%Ot$VTvN^oIr^poFIV8Nej8jf=^M`{u z6ZAfAMRxYQcvW>?n@bn+@Aw2f)}T1$UJ~oonF9g&z|D7yOEh0Y)WLK6|J0kPnA< zWlIpSc0B~Ef)bc5Pdx$JDMA9$x_UVf{@4+Rj`IjmA98$`=1fI&GA7Cw(ABZ5&u#az zPfyPcPoE(E5rBwWG?Ph~KC#a&WJ@peK~@S`Wn(#1u`@2$FLiF5{+>mt2Y$T0NJFD} zGDQ!0Zn?l0xd8iLhQPImGvVr}U-_-fzl%|9(}A$5##F6)f7Bq~)5_(JmLGSWg1xC= zJJgvaiAv=jME;z-j|-~U_1i}UG5?;i{T63<&k&`ORsSCV&OkB0SI^;fU;He-|H*dT zy!Df4u8r89Q{_)%SLP#HN{YOxDEHVOFfjHrG;P1K6eq>|P{U6E*fhUNu3+9+aS*QH zoIts32%un-EiJrl7n!=j%*pRp-0#LgsL~gh=4f&)xwmyd*CK@2QDd}8kgV85_6Pz% z89+(n0r(mZ0026+Z$?8~6GqMqAvv3;yB8S?z=n++03Sw9b4cdxFmp2^W|Jz;+iR6! zprNLHGkX0$Ma|Hcv9M>Vse*nM($R>Y)pa1G>Xk5JtHy53Q`Eo+WMb?vZ??cD- zy{K#1i$@+Ic;5P77WZAV?gu#hr<-x%%dHWMhG$VB2cal&2>6L%(Z_MX@Topb-}wPl zaZ?x+-i=KcU0U3K`p^G^mvkO6l%uK~o*+L!)2=-r4N+7#G$J)DAjEOtcme)^U`S+d zo*?ai-T%c*-}M;1>y6mbY%JRgssdTO=D_#xiwA>fy6~N7Ymrb>?J&Wkd?P+TJy|)i z3!w*lkwFMBJQhT6R}SjL19;hv*Wg3j2LNcSk01Z=NBG;XuGgJlZ!9e5{TI$j#?r`U zw7+9>S!9m)p`~+vNt^8?rb;rJI@uL);CK##AwrHG2@e4NK&ih_6$P>=qwm>MP-O{i z8#khLZOK~h=Giz!+!i9ET^oMq$sjby4rZf>oV*)V(n)NpdkVW6DDo&Kq;plM4q91& z9d!f1%)2rCU;lx4^^3sQu0d$^0VF4;5gPsvT+s12AOOb>T@K;GZ-Dptk&MP5XHqZ+ zx9X!Fi9*mB|7tPv&WZN7x1V5zWl)uh$=!7 zRjN#lUG$>NFpR5;f<%P9pyCCA>L5(@-V1Ci{f?xY&~)*;p!k9iwr_SjscAYuy^3Ti36JnAr~i@P|TZS=){W@BRv|Yx|fhX5lLFw*amN zFV7%1brQk09RLnE8cb>606hKOcjJP_A7FJoG?`2ta#8>dFZecm!P2+w_Y1h!yfdv&jNktxXCaV_Xniyz z#~ST<=^2DmV+p))-M=%QoOfiz5T^J13G4i!V*8XO8FdXI)Lpb6JkO)Pxdxed%uq*j zpT-M3c#bPN7MKRP{CSG3AlSGDZH*;^jpEObB^+nS8>8hl@*}nFE0`I7D_Uzy{o00B zV_xVd6CXq4`ESP9G|_fCJ4&$7*NuC^~Twm5I4;-4X1=-(cp6 z7Bp5zv9s_U_yH;BgMV1Y?tJ@9W@~WAUEjmLD|Z*q7c-@PRX>JF62UVMo`fHdqVwSM zuqv-`?q>WUVthdzgod${5e(xG1d#<@CRP$<5XGF6R<-R0q-+i#1c})w_DUavzuG>a zrKQ$+Tq_6rSKWovKf4~|nSHo!!;d_le@W-d5WEj6;n5fPIQ$!4LM_fG&%p1?Aw${$&PScc zQJt?SF$j4V5~EJ@Y_7!8LB<|bwkn*1NM?R58TZ)@)YW-h_dnT;5U?F4pz99u3ZNeSxP zDIq}DLwBQDeH?Sb<=A?~3oL9&0Eo>dur=}ltnG9syY2?pulYQFcV7-rve?~p7YN}o zk=%%(Q}09jww)k~gpSRfIPr_m;-%~U!*H&^t9bd2?;<|<-*{&HI*4jD(xHP`v-d(M zqJ-IzVXWG;8SQ47M8!BcmU0pVUl267R<0rR+?fUjVc|P5q<>H#z>ickiP+E-g0(g1 z+_91R{bUkJTPbE5^nHvk5TG)KRYhe6Tx!?T)z_wfW2FyZ1w+|^Rv%fCjq7)gXU-Gt zzG2pwZ#TaXGg()Cj@vN+0CE|+qhG%kOGV=t?H;xqW47g)?@4CZ4OB(3I0!Ryli&el zBNIr*no!$NRXU#XtYk@sFUSZ61+>5}Lx8e1$E6}GWhul7fhs7XHcp%0mz!}uGC7U8 zhrfi2*4}}at=G5Lci_%DK8HO0_SUDdZ~bxCbqe}KY#Mt#o*n!& zB(4d<>?7FG_#npPo6!75_o=H^urOG05U${yzWatM$9?7yJSJdN{Peof826AI+;uJb?SSY9;Uh}QF8UOI2jJCL6Fy}`JDAo^1L z^{4-dE8qGqW;ye55Q-vj`pF)&r+WA=9*DtYV#}>pRXr|UqbUP%`nZiTnNd>0i0&g74kU?jSfQVB&l`F zZZyC5aXkH#&tMDx695E>Y$N8f>k;&e@c9&^0v8~a3Bo6hpw-`n%`J2W)88EYH$eS2 zJo@P6=#K1!sMa9ryA0R9Vg~|#=Y=-M2!nE(RGSxSgaG&ct`F;R8xE{Hj2pHN0!GCi zR0lJ7>yCHfuKzuV7r%CU{yY~9La)e|X^sdN(+R|9t7vDiqZkP=F`9)U%joE=h3ZNo zMi*6?kEmoyjG82G1znyMRaG%JF;nb|rUwVny0!&`0I{eFDP#X^6iIP(coZUa-7*MS zQHC@rk)iHs5CUx2-c~A8RRKaU_0+AnsPTI^uhlvEIRG3Txd77}zl%VVfZmhSxVrjp zv8CP0zF1#>9KdlP3t#})3>erH`&K`O`=9t91mz)IxhXHd*bhAM_?vLi2Y(E~?=!UR zr~mgBT-(fMc>usvd<}Y%*MacWcwR?|5mpt*;#GkgaO*$*3ts#GK5iK=XJ-^7#!iDL zIfMsRp{c$dkduJF-e5@Jd_DyIK1d?n$gM>j)87j%f*Sx7DTl;djJgl7vf)R2!>Atq zSKPSG$R0eEd;?lTHEemFz#8~A-ORp@Sz^?!@j~B5InBwq>5}K+v7!G(YHAp5+xJ5f zWkf~?P}jZ=iSPto(p<{3cmCQ(arp38abDve5%TAtjPdCH^=H6ue6_pl>o9us5Uy2I;=<^O=E{^zH7%lV9j-|Mhto5qN1S+@qQl;8?_)NqN>kvjH6H0+LBj(KwIWk}UNA8Kj> z!In*^Y3qQP%At8(ClY`C3=VeE^9??g{T!~k@)C4ERa!tB)Epp84O^O{>tK`6gPWH!;s`yWnQX;32eVfmC|bx3xi8|6}z1;Dy-U ze27_w8Lf5JXz2pOt$^B43NLB>0PgtXdyxz7Kttya?7yrL6T=C}vRqljQ@gJ>E07}{$F^A*zm=|B&L?gZ9e@n(bu{tfGK8-!Y&q6j{Rq0xTG{?(A% zZh*>bC#oa2W4-v#2>P;!Wm_@nyAh=0lQ8Jaf;|^2s8G0(1NOoYXMJB@V7?yh9?b8| zZf(!;d{JJOs466z0KQ;Nby5D1a!~_2%j}R9d;)roj$!DxpTsp={|H~B0)rF+fWsrt z2fz1I@Spw%>|IUouQ#*~Vngd7dPn*Zn?8imQ>PH07((-=Qs3YGm%H)G&7aBtx*dOh z=amB_v49C23yD`7?PMiX8Cs6QjNNM@u%%9mSl8Pz}pL=u$3HA(p^D00v3@B7;!JR#X%fegNTt5!6;!L(FCo zX{e{KqXxeQl5t<=6~D zTJFwL+1?n003fE_r8pqO>I`Ga@=Mv0ZA!18qVEeP`-Z6J3IY#BAHZUEma*1HvTPfe zBBs+o-a)7Z>sg;wfhkIr^94m>ln{$kQjryKj@C#kn$~n+W^kZTu%%2pjGQ=$gZL;~ zH_-F^y=2`NF!0PT;pem1(==iUymZ37CtieA_0!nD{)o2|0rX66#kLMj31BwajGuM= z7=FKujk`K=;F47UKxHTcfInCTnI_9+g8BLE+*1v*`-30^NS%KWzyIzZK_(wT4L5{W zUf`_2a=t&G(jc!Wvc2sXdPj%R+S=-P?g|zTD-OaHoD-N&4npmGzPRjGRTcVik<$$> zejkUN3anx1=Ic>%5IP@gkN_|_+zb?aq8jW2!Riol4igfVqouip$5i8a5R+2~$i|#R zg23DG5%YB^IGQL<2R(=>T8@JRIXUIt!A1zc325Ga8L<5_bVpu;K=&7MUdI#8=TdMy z`(CsLOLBsO0kNtZ^s!U5Xs^y80B?YI?_%Qs;qiE)jBNq4JgD?7S zQBIJB+E5BRC$7i$zxo@z;%x_Y%rc;>zZ&C%Nu*;@1RL70|3V)+ca#zBm@kuN@R4(w z!a)chy9*oCpF>hY7*GETnpa!dtdW^CV#69EXHwaKmQ~hhd&~9p3!c28S|7n@Fls8s zen>F&BWCu6l#>u1Pa-xx$V@6;fD?O==lSACBa1TnpE?9J8^!udUWCpa8>r_L*_jCu zV<(OQ*=aOwyBHlCR+WxX33~4R8D6pTTL6H^9(o~$+uw_A7av5h%7^g7U&l2YZ>5%B zz-9B z&+*rv?89Ym`$%!y;FFKw`nHc#%SdKx(A)99*!RM#(eAGXKN`l&@F@QJcoJ8x`I+I^ z*KPlMJo=}ram9@n8@7iB$H5bAQKx3YAh9{*vND1-wNT{7;&GA5DMb5^Ld@h4sPDk) z9Xl|5vI~lsL;Lm}`CuD1gOD~f_8&Q8=vsAlHUPRjMgS7g40a~ohjm-5k+Z8VeJcoy zUFHJiFbH!w802&~U4(4gh7D-jP-?r4t2F9o=k7h2dZZog4O7K^7q;D2{9E8vY!|xkOdaG51xo?Y4;C@HSlbu1^+!^MX(` zBor2iiX>y?_!%g&g3$Q^T>bX9V@vuA*w%iATJH6mz5|^24z};Ne^w{rug2Eb z{Rh@Hbl}Y6nmT|`w;sa&zsEC=g>i80Lx$sdo}jTNhBsdLZS;rlL|5;RuyJcGriaEH z9fYZL{ypZ6R*(JRFL>?lpBs+vnz<0lo-d$oeH-dbleSP3P_t?kG6%nmGsixNogGIF z`==ADaM2q-l&@z=R?y%~r{seq6oVVbYYf8TF{&|OtZ3|EC;=GCGTn#1ZriVZF@|s3 ziuFysmh+|)?bvwmBJ>@hzhhM;5JI)s@RmC;_}|aN#wI#5v|Q9xCGq(o`;$%^Vn8`rNL=hc4f>2#EnpUp@0ewDSKzufWi#q5=o8IsajK#O& zvW@pxocvvryKzkTDw1&t>%|}7%B}ZdEYgWT^nVtY-26(2b0e60>}z<{0c{Mq`^1e< zcYg+#zH=)k`v!2^6TP^$_Je4xqZ`%qP*o0ho_;eFt_s=ut5CaY9kM;Y$E)|09D?Vs zeHedv<_}nP@D-NxiK>F*d5Zzp3j>9ogV6as)7TKKpgUd=@(za<44v#NItXDiE*k4o zaE8+jwInx<^vPSXHt-;#;yTEyUJBpZYc+{bLg4HAXSBBrQqSFx`!4Xncd>~vL~c-i z0w}=Psox^8>wA!=k71Ym1qeQ^P7R?MG~h>=e(*jd#Q#9u&X*VGlx&GnAI6Rz!t{fm zhoB}9RkmTq-+;N3JFxwQH>0|#G*~$*Kj_{SOO3!2EG!eV!_dn3VHsip7=t`i<9f`D zP)!vwi3F3~p=W35jkh!5Ec(Csueh}SpRjGm9QE9s6o9h*qo`?U0xtPBo(vB|$>ngN z^cG8S1^{eb)eYR+R7zjn2Xs&E#&mieZrXIW;n*9t{}@R8*b;C_!`*%F!m3N&h7C0# zKzKVIIQ}&J(|>|Q>QTGlcz&s>T} zzxYB#>RyfO-}f#{o3ML_?iqwCvT1_&TpSIp4G4zv@0B7ks$X*zoM(DRVth7Y_}2WP z5TtwnMJf_=?I1Ku`s>q4=%G>=Ud4Tu+?6&CLPn6gt2~Y-(8g@pBxvct4$QnXE6vl&iG!4s#>loOES_Cdf#S&Pz5L%1RFb0U0ZVGoA0@BjoA1w5<^cR z*t!jETlXT9h#~y!9e6?Kx2S!i9-CTv+||j1zxKTlYhUseWOE#D?Uo^r{RL~eyWtmd zcw+QMT(SPQxNy@`MYS18_F*>Jf|KJ1k?8ygHXdlkf8IX?F4cqjjThjCS2dOj&(Cjp zf!EfC#Q2>BszRw%XA_$YKz39I$IQkGaWegiXF3ju*6Q+pJ-W zvM&0CHkCnWJy$@^F!{ElzMG@$tD4$DF8Vk8<;d+=&E1FX&1^^E(A*Y`sV~L_jo(Mm zC!%+DAGW{vU6`9PGE~-F@i#d9^l4n!`9N`--zQ`J(d~Hb{tsaFmG8u+1N#9Y146J^ zFFPnQ!I66(!M?hmBAIKznZ%8_wEf5UoAa!>ehPTq-ml^(zxCnTjX$^S+u3m(n7b-Z zJu2a^gRemA)y4>aLWm*y+#k{$gn9#^n}|o_qQxjxR)DX%3H;T6z@dkKh6`)|6#{QR z$$#7VpT_?@@GD&M{EfwJC!d)_-I-V6r5hdxcIuBlgjo7c-2dAbPqck@XRCIL(og*}u3me4ar<=tTOnl!;13xYhT*}{;@`=sAv88K9N9XkhTsD$gODXL z%JIB4_lD`iPDT=lj}N%&?VQ?PVS(KD>}g~lcoSZ_`WXN~X6TD}=$n6s3*PoQ)HRiy z`~|7LQ}^G47q0ndT(m}u#!jYc@bHmauxamhs0u;XeZRr0^3Fj3xP0?{0PX|wpJeAo zcAcK+y8>P0?{Lu#Td|e{I<{;?Q->C9JfQOzbZ)4EwDv2wt?R3p>wW?++w>)@Z5n5k z>JC@d{5^8N{0EGLKMvw=#C2_-0B=gH^H=dd5v*&*rsrLW1M5~7w-;oFS}u7Xjy*Gl zeQib~&lSl_aNGa;M{Ky}&A9m5O~|APkTd2T>+h4%edt-VrG5dC)FIpcdWg|6T-*E^ z1RF#Ytw|%$qyhoNM(gp=*Io;??ON2fZ@|>ZEOPNV2-VG4yJa;R8W~1UgL9BWbweGJ zky3PhBAUU_u`Uq48qG7`raA}vW_DrQ8!kf6)9k+136>m$4hGA~;V9-th7F8-j^~); zZvnOSK|GfDG+xp0VcX-Ht0vHlpK3p5f$5>u7;XFq$fO!;=DvWn4O+D9cx)ra+rNV3 z@DW_I%6k8Q?R`F?xlYs_d=IuZmHHG~zQyp+SW$v#yk`j88XK|UyzR*3euBR~{v}+w z>2`R-fVT`Vr+XL^2)3t9E@2SM9elKvc|;d|F$F!X;> zx8)l8#F{7=x&J@0xuZ0{RaFi{2l~*`xfazkVXT^(!lOeUgxvGrXbKD=r&MD~z6|?r zdOOZbX$;Rm^}phQ(`hvC0B*QQ#pd%v7&*m)l~gR#C_v9@Wfxb5neFJkc7 z|HW)>GeUd>jloF_WS$Rc|GyznTjwd~pcyo0pquhUirgEE&8wd?Spx%{9fSqey2YSQ z2x2FHhyB5CqN+tRNUz2ap!^ld-UiHMHe+MMvz8pG`!P-@Y`*~FRM*`T*N$^T4AD`3YU)XPK@Rkf`n&+|w`k_cA$5ex;O$SMYooPd%XN8^@@ zu=$_@A4(2FMJ71;m&fqJj(yjAArF7d`xLlyBvh(fD1c? z9YS$H0F`49Dw1p%vpJ4uItc9;gyl$#TJqBx3q*GE2y(Gm)NMT9;tW$M-O;A(JoDsn62CZ`Qb3Xjd$?7>+ZN3RG3KJoChC z1gd18OBDdB8k^C+X(Lie8BHyIWa0@-b)P~md;*-m4ja$gieOC*a)~HLPjw;u*bi{^ z+8+P~vPWSVxzSp7Mhv;iA#cY0-LF8h{XMvtybG%vCor98MPl72(ZT^$A%e@VX+`V0 z4`T4B*3Y~IKsxh6{BF<>-seNd#x=;sW>DQ0$Lf6w)^GqoaP>>p16XglUST}qIKW?B z4Jl=`a|WtHNJpLPtthII*RZKp6 zH~8ErNX-_sY`*~hss>Dt&LDg0M|jijl2qE|pi7gP@wN-;gDMRfF$A*Z9hU8c5Ji515-Re!KI`MvNY&ZrM{5}~sw!a&<-SJ&qdfj%M zezFhOHNF?k8zR)>8>-`Y)deM1*GmuFj+f$g?ZbK&*x3Z&QtGKpWTh$$r7s7+?&Sy# ze+TOtjnQD2Ynt)jShS{M3oOq-guipA;Lpimf6fTxNga`lfS?me?0I%BV zUOyiI?mqP@Oeq&5*nR-Ju4=`_y1u{QbO)>~=|F(Td3#h~&&BAA+=IslMnI+>K^=N= z(YpK4Ql}Yp;+Z@U~ZFmgMIN*ql4cOTHNy2~;Ac1nu@#GvP=g^_5 z-VZh1U13*s^_-F5c{S_0bm>fWg%Ny9A$8TKc=CZyvODwdj(y&=`S+OG z|0y0S-Am`N1uhx}%*`w4#T90r_%B{P@>jNchu^5kh4@)RVvcM3J_CFP=CNe2ves{t zCv$x7-d|wwiknGi{N2<B?+e|&megrlO21(JVS5gp4M*r)c`LL!_kaHM z4hnnzm8L$KCLo$hIevH7y8&2PaxTm(3&W^N=91syj(d|_I~2^gwcC-&VdN{nM(6Gw zj_r2oOn~+sm-58Jui%`v@A{tWrfvTp_uYPq=Sm+W*E>vGPXyD1r=M8j%)t{F#TBA2 z{1)%tT?NGZTmDXXt|$38joIL6y1L53kEPO-_WT96Kl~qbL>}UlehnA?Omz;nzi$UP zRL<^oNhZ6coZb5v0%pPL!JgmU=JvfDY81M`xwrf@hd=YT4CI`CSI1@V_S}oz?DI&v zt%ub69_P_7{|kjv_oA4~NIkEo`JAnoKyLS&IHtdvmGONPmlhd3?;M(kcX05bhcSx_ zj6HKFUD9`uC4(LTo3U)K+w)c6tG>UN?NwVo%fnwiOykdei&Qqj%EBVkqn2-R=FkZm zJK7vUu~b^);Jyh|O(ol!71nNA3lLR)FU3NU(()`*ho40%Ok$Rn&~=Ge=N{Sy`>O|- zCX3@oSlRzoF7N)LLw>q!t?5&=@+)u*9y$3+7PB{#>Dj{czK4h|KgpKHAF#FE8EAjx z)GN909RmO?&shDZy{`qc@1sbVB!TvLjK=l`Mh{v5p)48n_u1!vYb;e1xZslNxFO4u z!!PU}YDK=_KK6a^E|CZf=yz{$vNZ=VOD8>vWrIp#k4`EC^ zj@ouUkcnnTV?O|o;lf*LBad~NW zmcsH1@pRG==PvoMXX-J@6*%{>HeF=}$~C-w&9&AN)m_s|28SQIhY|I&T-5e7OXG2# z8$F+FY?Rlu?z7nc1|0my5zZ`aVxj4+Y`g9{PT%pzym_;~wA=Jjj?wu(CZacT+511n z3lBWuY?`XjwrLwWpXK4PS6xm})LL@!uiLveW#y{MyTig2FSY-Jrv z#i9pv(~BiX>?n~*GjrOCLlu2M8Ijq6AVt;u=Y8F9_%t&Ub4;H)LoAb}cgGgzz9Gg^ zw}oY4I<6yP5SEffEJBn=Rf+^4g^QZ1z!f?>KN*-?^-q2 z8co0;G>X;sC9(}H&CYw~=|ocp?fq>;V-oR%#-Yca!o24#T;2RUlqH96{T%?JScK={ zkaYqdIDRwI_FpB_m}K7_U*&Z}f8^Qcxzjs2xwMBBsgo6>5ejpZv{trUaWS1Sg=da^ zh3_6XPRCFmo3=GlShUvshOT?wc_MDGV{4vlMB}JWI2n>kL9?7tKY%cdnGU+Xi$UmfGqS1?OV%FnYXvBU zzpn;1MfJNX?X=ZbsI*B05C$!xR;{~RLW{(BUQ{;4ti71;x$<2+^PL2-BY#ap

A- zjqG~+ACh|d1V*WZsiYvi#kY-M^_guy^><9%`n^Dzq=wie4Sb2cKRCofp^a<1s&x{k z4(B*MKggvmkE3p`J*kh+x08;|64OG;QrK*EIm6S5f1$Zs;OUsAW6@j&2xT8RNk6kY zYl^)Xy+Z=cCOaZWiC_O^jvxFoOEZ(mTAan1S$0PM&cQ6qrKWlRj<+)Qd<(m~$7{A@ zz3j-;7W8euPa)Ee97$3t=2?jKbMoM?(3P%VGUC;{{ezaT+xls~^S~&R@uk@RvVkuC);#ekFyHHV?(Tl9mr9uV8RiZ@NVKtsSlhPBZz2h- z3WLxVbavDrH0%p3xw5ZrR0E%Ot1Dx*M#i_SDjfWuKjO*`@psshJ3{N}8+hWow~=d2 zGci%3XZ}xwfKhXLf;Xju*;#H|F2wlu*x#`2raklm$y_swQ_i+TY>XxM7_CMk6{(-L z<|0?V`bOYJ0FEDCW0-Or`Ban zsVUyF<6}JY?6+7s=hMXEVo9UM)<))Mtml|tDbYUq+jMvOXCB^p?5EIMuHwAc-2(W{ zM{=m-2uYTmpxLTZVCV)bOL=06Xkfq_f&&xweKOf7-&^<{2K1kPf|MeO_K5%q*yas(oV+Wr6c!pDfjZ& z7tZJ5h3zzGC%ACnZg!Q~0-J-Lh-HtDy@tMf$!KbQDbpB-uXiwz4f0djz5J=0ki!TTy*{= zj`q0CBekbSXxw@o`JM}f`=i8(weKbdyBF}=6{La?OVvURVQXqok+tlF%6wl&pw7Z^Bi&WBpK}#QDupV;ucmOn539`me%wc zc7j#^Eopm(C5Z>V{_Es+zk&S1EXGqGKf}vJK3$ ze{so%j0==@CFJ=eTS$@Ojoj`Gci+=BuqnRV(%kN zJ@f=pyd9Fg#7Do#w$^)ys;gs`!fNu(zrbD?d$gO)sS|A0sInc_V4pwq-|V?&JD>c_ z?Y!rrkI|BuV5JbHsB1iabPsoreu##Q!3TD~)uT?}wnq+~$HGdAtGC_h*q)6~!J{`b zz1&8J3`ungO(}YkRl?X%?r|P`>Qlrn{}_g@GjaEyaMi$S{kS|#{b!EH=U>I(Pkj8W zBu2~S3fyz1wrXq1OEju^;#rJOn*H^V)+aUPE%v`%w~BcELzjCz7Jj(41v4wbMd^;~pT}z?i656-!rZX1v-CvDF7~Hi9 zSq(_$GEFXg+xvLr);84re}*RKnQVMNEf=1{lwM#~F3{HBN2z3R^x?ZGnQ=0g4P!63 z(0{?DJTdkS=EhIcIy6kQFUGUaKZ`p0Z4$~17xvyJzBX{<*8hg(|K`NwUF@IO!xPPK zM$U{du&bZuW`$gP6N)O;RfEv+w_%{*0ti>fRC^MlUMzY*Yj={SW|szC^(`FZuAw6; zfM`FjRTtMUF~|aNrcRB~bmBL;r0of4u+}Tz{pa81l^^~6%22!(*8O?`gt|at6rX!4 zKWo!4Ie6EtT-^MROfI%@Ec4rRZ}X}Eg?-PMCg#c*W6wXtu>9|AYdqrkSyM`!*YpJ# z`vUtX&t<0Ahtz>(}*S>U&>Vs(z2rc$`?G`aPBM`O134&lFb-(f(X!Paa}!vYMD3 ziAAjEv^viDG5^3{ZOCqAFxl@srfD*M=y}dO_XJujLCdD?fiknxCyp@n=zma%4AOP} z^{ATloVthFD)Ffo?z)S%y+1)?vpC_r9{cXeU*(#Q{9g_{Vo3yTy7m@se|CiPF^FDYIYq%&|%0&4bYVDV9`;-REvc)}ly~H%rhr=vC6mwdeS5 z>EpbDpYv@?**3R4;@Or{rqi?mzxo zF73V5QA6*@#4!35-(Yi7j)~(ZJja=a>2S&-F-sQJe(z!Gwb-RV?2e5aXmrR?# z_#4cQO>t`fVWRD@&wK}^M}D21tv_&_zqP>&4o)t%GNHYa&5_U7Y+t=L+w)sQqE^8S zy|`KfE#A%i_J3#Z$y@1M{tW4eIFofPcH_1$1FlTp8AR$3fi~Fm98cZ)Id=WTPgS-# z7SyK6*uJOeIQV9^B`0A8?tbF8X+GyC0Na=%Vi3Ay6wAsIP6pvBn5ZfhcC}t~24(a@ z_>5J*dI@D)n3>~(rq2cL6V(cwGjR>iKfROT>=82E;Q`9B%`c}(yBK^kOAS11{w>=s z+giO=BF>`gv(^FtLM|wzAcKp0tW5u{ zo}bsVZPz%ZoLYODh{hwNnwnW$urkDt-Sr>5cJq^>(_k19-<$Yjdatm~ zQ+XePpg<&%Aioe?AHh0@>;Q#Ig#x1o4>3DFgRYxsu{c9}wj#+ArGj{#3?nfrNfuM_ zf@`<(FW>nuqVo^35ZTRz*KOn4w8_3h|G@(fe}St9zsl@V1KH#Z-MQfGcej1@&>wN> z+b?JOR8{KM-DmB7LV1zy-bWM&XlQLji$=KUogbiBct0;Z`V{)Y85XC`5KU%C_VwFw z5#jR?`Wli%wEY~S?dJfN7k2Nn7js$4D>rs--y@uhcRsFVkDYlHT|e9=y*6aU8GIa`?=TjI|D_!_BGf5mJ1?r}&k%CgDTLtlY2Ux735!|`9|st^4U zk}Or;kzUL*dH8wa4LPQd?qjR^k6hn$7&6XwZN^_g!G66<8$Q9}-oIg`7^kUg{S&1Q zPYiLv`(9o7Tk!$ul8qVc-N?;`DX$0I<^~AMLFsI3BfEa~pE&;felj6I`pR`nMRweH z3!nRTjMs1e)4)15QLVs*{0`rHM5g-!t44EvrNrq&r;ybM*|sJ!4dTQTsQNbxvw0#h zl}JH1Un7ga3*)!p4T(q`!1!Y{x1XBC)t*m#erVb0R381UjN>0r>EQLXcT2v zTK341i+RKUc@wAhKF_Ov<}$wV*gZ%}7n)+SICX~2dzN|YKm?RHpZJ?E(!$-$FGYF( zwf{&eR+R=Wlr;YN?*GGu@A@dE(bN3-_wVKPTmFo}*29kLHO6bp`>zzD+_CR1UsS`GY7;+vwW1yRwECke`L7#bOnfieH?g{9f9( z?qR<9Q$V$>vZO@l+P;fa)|#`FScWqPz4FEY4XsTyv{ncFuC3c)>vjMh8vh`Ly`Sgu zo-YZH1?bM5T=l!GB&D$jTKN8%-{kVQ+(OTmUf*QuGY5}SJo+7UIYrwAZ}2ed%CQYQ zFrp|5ON%Ct-#?C7kZ5i)F-k>B1$UpU9y2#n;EvCHhS=;6NokWDD_l?KmaQ~2S%b~^ z;Umn9Pm*nIrlsG>ytXsy6;<^Gzx9&I{Nywpr#{SJTb1oE%O*E={sG_rbOYCa@c&qC z3lsQqpMMPip~xpK5R zK1f58MSl)6M@$}ij`rzKGTi(e+d{}Zk2F07JO+G?*?;>3jw$bC+naw6CEiG3xyqo6 z#A8I`5$r*~{BoZBqCcoFD>AZdqUbZ|a;9>eURdGGvFS>9Q&KdtZ8@gTjG-5;dkkWs zmVLO%(^JfiS7Rn8o_dV<^tULO4UD#5K}*j&h37f5_i2WfK1P4;5JRgj=giZMJoCUi zXxR0Ws1Yam;mmn72IQ6?ceC%3gjLql^uR<>tDDD3@&TIQ0MW)8$e#`Ct5tR`fpdExEf}6rra6hzQH0*LsB9{V=^nt z)xlZSuo#f5r;OHM8Tw@g*LsL$$JQ-HC^(sClkQr4$%?g4_vSuM?DabM!hn+PnXSxJ zgDXv)Ib<#Bdu_=lNK$rAum8FC@y$=I(47Atve8ovx9+R3I@ZJPJN+)Q7k+?@TpgZU zE_ye2J=w&z$Y*FyS!qA}CNDtR^GTWq8~DaOS94wSZ0S6^($Q(2E`EfjoflS^ZuXpEQ;Lw-`F_${KE~9OUqN~NLo~#!ET@Hhnz7`&Xu0rR zoIIhUpSpwY`9Gr}K2!7gt*7l;`e`2jhkJ;2oJY22h{C+g^rQbycl3UG%hGH%pYV#t zk8|{?uQL{XBaw672#6VkTEtdc5|d)SL?osPYhQ^NRcoSNPuC zKhwJKubkI&k7M7h9bVbcGqas67BURAyx{q5ekH|y)4xK`g|9-ACGyKfc;eSNr}It@ zjBVy<;~&vI>|8)c!Qdb`=dAKKzw;-5nN$D$Uvy@~!IZEzxsc<*<&Ps# zf9g-YDc2cPynGU%c!=bh~@bOmTg#aab=nL z>3K3u8Hb#+-6To&oo}l9NUMoJaOF({z(Dyf02myUx%{S&@z7l#AQfL`GB0u4owxFq zoqz1;6W@FKdT!Y19b3P=_YZje@BRieqpqw?7kKIz(^tQJ4?rkLDzSgdrw^UMGz{cK zE9r!oHX=+&3MseHd&Qf1=C<#$qcxmdx80_e+F8h_F_D-q^iynm9jS{hLW!m*76aN`~1Bsc^5d3y8&|B-)D}^QY1r1SJYK&SS%%O zH(~v@kn=$;dFg6fCFGsI#tY9+v!(Sh$4@0);l7z)q2+=%kd`D``g?iup-whWzMX7h z#<8!e=xi>2jENt1aB1IN$eIZulqs)n`6PG!V~Q(({Z#Nw5V9vx-V(ywt&%RvBE#-e=t!H=-zhBpb9GKOI?dTgAj={lU3-Gc^Slp5V3)eKf9OVQ?Dh~5k7$`tQ{U3;0IG54-11Yyz%Zr>|Maa}D4e>F zlzN7xQjUD~D&F|sRyqeGoRUj)z4IboIQDfW4*fhA4%`>2%|CeIM$Tk@jvW_tGq9>m z@b$+|!RhaCN!M3BK<8!pPk3ba?bhCKigx|=Og-0xVM;tP{uk{4mic-_J-(?b_9?nQ`8l`9FN;i;vLxJD>N|nq8V(Wa8n^bN%4I zq6}GWeCo-&S?qNtjF#s^`&xTye3{A5Uc=@6cQQWL%>Ul=AqFqMiAYX7hMYJw$?(+A zb5Z-D3Oai+ck$%Ee2dwEPjLQ?mossC5=@vlIYBht!r|wRQdnA|b)c81<{Y<*C7FC) zA{sTBIC_-)vG36xdzclu#D&dIRuNj4U7!0U?*GbWE_uVNexwY-{PK!peeaj3XXpk` zeBlqdw(GCiG2}h>c|CUncQf_XKjE&(r;yAddVUcpJxFWcYMpn}T7uQ0>d_dktf-oZ-#5@NeP2uh6n(h4g! z8B({MeEMNBPraExl^n7#JNG%lW5;c=WrL%5t(5?RTXQTxA_8vrFk+Lg;c8c#k5$I(u#9!5f&sX`y%zo z+dCM9N$ZGs6juQqkndGY2cb z`|10a(Ha{6CI;_hdc;i=TlPRuj)_r0VuHJVPJP@W|aSz@Go~)F8~y zj5?=9X`=+v4;T7v!95+PjXvh5l=Hp_l%gyYK!OR}6mB=V;4bi2n)k zj%HG4ewwySHQ@BZ=q`$T{*GR)Dz(wl+ez2;Kji7V9z*IX(J64?-F;b&Isrp>@i|$M zQ8cY00aVBr8~~wTZ40+0QxW!Ddn2RAk2ChQ3mIs(#>6`hyoRfO?Ng!Tn?>?6iNWv&+q4+LtkX+ z=)-7D+v&P;Cjfmrf0HvO&hW$?57OA2pkwn0C+__MS2cgKa^7Kv;N~{xcTgy4JfQyrO#_>V^~LGv>|pwotrpGdeTtrs8V#34Dao>L z4TiEU&d!l(%1~Gl_h-2d24S%2^mMLN5;0oecr&-}zmf454nScJt+AKO-!j7IzVuDr zvGt?KvdR2%im#ve8{YQ2KaHYPYv|dp>u9ThLFfsB3IcFtfY3ByVRFn)hMNDnsja3&Lj7HnoINV^~mH6x9($@@bx{* zc4fnFv-H_NVm_ba!pw;^2mQ+SZ#?>zsLWsTD;k(cL z55~T)@Dn?{Y7-9}x{_ms8%XwFMDOMtjoTrfaZ(#4fn-!Xow8~SJD9C5|6CSJC0o*? z7VuP_TS`S8EvgA6JZqUf7`AD9GZ$Tb6ZwV9h-Dht_sH|yd**kUeDI%m?e@=B_G?Lx z(K2xbdR#_HR*#QG^4zldkGbuRaSE}`NQ?VeQMwtttF z0CIYe9TyLf=zt!@q_yAbxbAhX`@46HV-no->KNc;_s0{mL#lRoSVB7IHvZ#ZKhD*C zpQkH#3Xv5`8sD7#8n^u6&CE^Caq@+#Z1&c3JArr8eBr10>c@YBTQ+}Ec#a^;B`p@q z^{LOosn!0mnrpK%_bDEH_EpS9E~fW_tI$-1QZbJfiDDQAXZAir=DGK>l$FC6kxq*7VYFzB z(((W0itaBEQH!o)2ALfGC-&TU9pJ~zw|%pcjHH$mMQbJhR{Mp0%@9kS1vN)tb~a5d z2BBY|+@72&@=xk(PmGlvpXP^8e3h)cpL5#02?=hyar1xh-G984-aS?S;?Tnna(?t< zoHy(QyLWXx!d+keWlBRIrc@}<)Sju_m;IR^`1;4VaoCdYZOKgVu5*5mBZoiB1NXiT zi~_0*i|S^EF1m!NQ{(I@et@19Z#jb9-48PsEbaz2}E-q-ZAS zNj}1LZO=fu=4p83@Wnj7@)6?M99Ko(Lr>0s&DEQIp5tG9$~Zj zDKbq~>>(S!7r2+DLkSMgoWp|I!?M(i+H)NZxg=xH-b-ZZDW(+|-~4tW=|)aGdOH_p z{*qk3H}Rl1caSdQr+E18n^;I*gQU+{;9!7^bwhlXEBn3#e2Fv9bdb;%$aK}O5On|H z>tNUKbNbQm(I`F1w7vwE5)!6SJ)df0?+5`xKL()- zw469{7#W?$oJOW8%Z^L802lJ$w|i;c|8|C2step(pZ_33nS0pL^+aXccSk?YD}L#> z7(YIWVZH%)0}X%@@s~Y9l29|sc6=iiWv(b3l!s;}FF7=Brr%ZoaT3p&Y^c+xHv zUD61uIB;khhLA<7YJ&TeR5H7Nnp5UGIB?=|;-jBuOY{AVEe?^m@oPj<0pHK+c}r>s zdVGLHW17KRE^+K)=!Q3VDOKCyC$Y79Bh`P$nlZa?K`&X)(Or`v(z2a{^4}3Z{%K@% zmKv^OVR91LTqZj5W~RpGI9LA=4Xc3P(WybEmO9wc{veuK;^^c6&rH7-wfkCnF6*LS zfAxDx7UbBY|1HN|K`*{71fAVj0NISpfzc~6YCs20(22xRD;m{8;v}KGZ z{()WEGBm=Ik33A|qCcmwoQK9#CD@gpJ%jYfAF!k4VOuSkFGG24Q;|@4gnJ&5XuAAY z{2OBGNM@l@GsY<88GZU5Oi7}B$7Ph37id2GL9$IVj>DQR`v_q5!Sp4+$({EpTs3s7 z=k$9=cXCYqX@;-5fR*`q_AdN2rRk5-)p(r!W80Y>{1Y;oQ-{AiekC&r@{1OW!jK}! zX4R>W#9R!*h#O>bew0?tvMNAW$~%KQVc8Ys@2agQ$FmoFn=gImzcJM;=f37fBGGWp zwqN}TCE3E@b?*RRzUKlC{(rAwDEEZp*uzuXnP_|;(dH47eXUi3>Q)Pk&*tJNtHVp{ z;1CQt({K&LAk~y(c#jomlN5!XE8foO2R=tfgL7`ZYV+;PEwXxel>H@s_r@|w&Ia(lod#W|*nOcWe1 zZ<`*SMIu~nN4M3w|FWs73X+1>c0Sp*^O-Z>%;<@S(a+pPCOSniIz-d<>zO}z7d?f4 zK~oAmxbPb^?Yh|aoEEqdd|gFVNo7+=vLdYcXZK(H-go659rdoobEkjZY3I3%h2TEy zsX^%e8XB^qa4hEsmy|`u4^r9XJvf2~N!kDH(+fY9zUN{1hMp3l_I&u^;{pWkz zqZ>%FOxKppJbd&&;Q0S=RnKQV`?>ApOgBekA0@A}qaObvR}6l`vri(rL?RlT2=v_P zU7RRhMlKvi>+bsZaemTV>$F5+-$WiM0NiKZLg-o41$yXlS$7S(fqP@WxKt46y z#q5Q5GQ7>CxHyZ6Os1`!`B7`E+;LGa(j~vdQ&aEZ$oD_PYq$LqQwxoJ@6>Oy^ZIMp z=}}T~G1~>nFUjm`LY6~C#e1Hh`r@*xR>py{AH5a^q1|LfMp9$Mva!``$>f?e8gKdq z?z?N0OZ&g=*xxR}Z?{(tI|IO@2mX=AioZ>BpHq^ulwV%;WlM}b@on1m|DmWSnAI-9 zTpk4^;+x(;s#Q#9GY!+x$LA*|xc_fI!lmv1R`Xu(+?3N0p6QkpOqPhdzfE_mg3Mz&mmBAYyaUzK&Mug*{Q zcWfA@qs)M;D4`b8vZ6TDL`+?0)7BItfBH#|ANvIV_Syg9+Qwg_xA{0v9oxwd_P>r? zYl^1buj7^vS#{w3*0h=im;TI0Sh@2rJ#xWi+dsOQeesVowDd8ywAYsM1&C>Rt{VIr z@HIFMGm8ymle3g`h2?z-+GFz$x#6f*;GDvHxc|Wq)41tkq79vxrXggAKK!MBXWQZj zIJd3-WWk4zUd5<+IlEr}Zm!(wR2@3@+(8znXJGIbnCbZhCm#A19TR`RQX$Qe((7qI z??zfXtF;Qm62Ri|yB1@A?7$188lp7yZ&?%26V_BaK= zQ9Z!T!(a9}J6zfSWzfG23XC6W=CKDaLzmLb<(t6NIlt|z+&Jt7+520L(BJY`&W#Ir zAwTpoQN*5z0Jl7FZ?N-!5SyakAMC@VIFpL1mC|R8(iG`E#OLpyQIui9}VhH@p^jtuUWD-IPPsWOQ9;VRF{d2i3B_w8j9lBw6J~3QP3^gjNEe zx9mcYpY8Vd)J!j>xz)shxOMmGrX2H=Vg|Vy(J19t0YV=Jp<&d=AS~JLVMAi{$ICVV z!VP;_*x5^rI%~6(N(jI~vDJ!g)mwwmbY-zh5}3%I;G3da0l!$nDIux`0Z;<-TlkBD zL$ap9s>F_nK`07riMBa2>qW;n|Mt_Sx3W6@fh^Iv{eqG9$=EbOHmeIh#jMINUG%C8 zSXq97Z-4F^tc*U5lI~*UmAe@`c?&;${?F-Z*w2%*Z{_vB{hQYHT+cC#gJ5rVEDkHH z31Ybpr=Y!G5XK(Jm3?XluFQ^c^vMa@hKFft^QON!^K`ZZv&+hM zUfZqgTmAr%xGmGz$TRljZFK5iqqlJ{QMJJ14_!;4_kH9}evE5dSLJ-lWz}3Y8}2;z zcCLEc`J6s@f|)a;L}L+Jul*gqci`QaQ%_NvKZB7NqJ8Twl6_gaslMlyfkCG7U*z=s zB=J2hWK?#@|yKe~hBsgxb7^)H#=ulKj6JMO6Yb zK|L9el7yngcy|6xbTs^xXa9!89BUlMn zU*I0^obtitanyNDcMTk?ZAF$Bi;NxdX0VG)s4@_O2@787UL=)x&n3`v2|YWvlE_#| zQ5_e(iGBCK6{BR5lVgH2xvyQDL8z!QU0X(&pPFN7cADaHHEY`;G3xU!B|nK#0YK>I zqsXrLt0rT@4G4;SJu{D~+uL88LFgVw-2Ryx5Uf?t+HTF=t+e#C*Q^U^H!Yr^RJ6eQ z*1ldEF8l+YKKMpD-d0eM@!!NIXhQUpKoV8mH#J z$KHE?o6GvY$3Y2M)6~zV+T77aY3_`oP{l_YJ4)F_r5BDx8T=$(H}Tl7f|EWL&uH{QaT z$z^_c|D&L2oO9JKcEzjDcQP-ryex6SPM7qnEYmU2#);$2q|ypKJ>}OV(a_w$?8GW7 zq|K7ov?qmFj7NJ30a-X-x=m6PM?W%jy;8Tjr&r>qfA6h4`M`Dj@cyIhxIp12UeiF| zc4y*{uUxZf(%#?355MzAT&;anxXXU&Lk#X5Al3dI9{%z>xuD0tMya1=lh$)l6`iyq zW~Rg=D_qOG~h{ zlt)XPK5L*q$Q%s98WN+v3>DjU`+V5Y-eYy=IswAUc}pdtsVr@mz1Ok*6!6I7-{ZR6 zZ#cFmV@uq!r9S3sOL~Hqv^UZ5M`4;;Vo(1Qq4z&DKFpckzo%nImIv-R#l?N!t^9UT z*EjgqeHWt0kS;#Rve}9uCy5wim_$e_W3(stv0O?qUuXpkW{u4>Y`&h^(Khzpb1y4L z?jU16%aYN?qjuw!jL zSxHeb^%8?ynuwhDX%^<@h$>cDpwB-z#pThTW1!_wsPkWz`YoRP?p>_pUP;5?#h6G) z^Lt4j`Z(KL&#vt+m3{y4duYAv{TzMaH2rhG&(7}0Yaa9Lsa>2|+>Attq1@y2H63Ac zA;;nAolK?Q#<{P53$JKRvG;-6=TagO+6S$1eg7lRVAv$Y%PIh2OCY#Sk|eZPgu+sF zj&Ucm`k5&J2o+7MB>VVW(D7&YP%{UfE?}BEeFK&xYS*TI7G@ebbMTmF-KuC2N_h)x z_LCCTOF9UAJ3oNX9_y8r4M^(8$u{1`hP^CojpKL4>~(?+-Jn>|#UIa_GD*TX*wuY> z-Kg+U$_p{cLJ1vK4qnDmPJqx}Q!I=L7YmX&u&HV)<>}M3Wqp;;5IXp0Rws=xprT)J z$utd0MLR(03UJlIAhcg&an6~&>ZZoW;xbz=+0M#*H4C||I|JQ+Knb`Y6-7~3v#VXt zu~yL8-R_pYUM7znC!R@>ZWh;da$Bibvh^YVMVN%a=z+sDwe9kp!=W86yRwU2naF;% z3d5Kw#{z-7$|W5VSU$jL4S+Bx5brssU1Bq20NZZSWQvv@m#&_`I$n#%Nn|sX^O;!1 z)BUm_BM~_79UtY+{qJMy@&BfYhuGToC@q<)MC}6yFK4guelR7F6<}ar}hmID2qKjmA);{zYsmHBM{}Zew!< zIktm*-k`fjL6V)$xmE)plqAW)pe?zh-E}eu-T$sxjG@>0 zp39IDj>HL7Q$0o49s3)mLykn2B@|VodFQK`Ir%xV32!~oAba@KO>BS9-*D{tqmF0j z4klL^gus$Fl8huvta9On`*`|~=sdFQ z$uPCQd)4I^7$xU99rDu^*(+y%5w$}vl{o(VA%~1kkcp+x!u98%#dUsZr2-HFn8Yzd z00^NNgsaz0Hzr9pCONUU-~b5i=65erMPk$m5Zb?w3ZUy(ciPp@Yt>xx$jJJf*8%3) z`m5i)xB~}v;Mn@UGz#vzh9MC@NsWQ5AFxswFO0yQB;Y+?c#||jc7EsHPYN|1#8S}zRu)?NnfMHkgaP+V+vh^ zR;J9{Wc9h3hJk5&g`M5Dd8kHqIaKA}EA7!#VyVRH;XXc&BnGYR$p@n30>8N?thuje z|GR(3Fn1sLQ!Y>bs-r%*Zpi%L*n82WIj-vauYhx#Y@Yjh4o-Z5OtbZziwom4?EfHr z%?Evt1Gu=y%Y+ze*$WH*%Gh^WQI?yyrsV};;>`cF&o>9_KY;60E4ui;rGGNxlY%3OuXwe8)|J*O}(6`QI^x21r=bz!i&aaV* zF7SgR@1*}KtL9xMo223W-{d=g@w42p=|9&57KgQ5a*~USKhKkQr$~2LVCTuFA7)SK z1Jz`^RnrZb2jzcb>m|E+=JpSI+4i^GASZP(Z z-olef;*n*pZ7sj|>R5d2i9VhiyOag`D8_fty!<8Jxz(u(kcrRphAt=Q3-HZjzfb?g zub{Qh`W9DsKhH}lEBAer%}w_)x00p3!Jo}`dZvR1&b%GTERvK@GGlII>nqm_>lx8yJhr>_HsCg}=jRE8m=;$O+`0co zUbXY9f#-Yj=q|Q*KkJxR>#-jkxP{2h50c8JICSUNxOwo;Jo3n8J2lbLN2Y)KdUXKr+#0pcDP5-ctT?9d0m*M#m)z7>`&*iOq>|?0EAjZ ztHgJ8T^BP5ONP+5^9KlR_prer{IRkPfN;ZJ7IyXwLYpIP8W=^tqQ7eZ2*b=%<$|=j z;Rq6zgEgWcRyk;;Xfc{wU}_Sw+4=!Ozq)}UhrP^N1PjPC4O7@&*U18}F9Zl(3mAWZ z&~#)CDykF6S~Ea6Gv~UF_dG$FmC;|su37(egtmj~|NrUJ*hw^_<>Bu*UYI(M#|A(mRH|JV*&-8E7 zbn)xaVkr+J$;;CA&-|2R*)fm#nWV@lnj!@5BqfHaTkTz(o%aW7B)N7CLfK*`OOjlv z|KPTunw=FhO=L;J@ENx}tbA9SWfv<1yyL8hDQSV`;lM!v+ySV^4)yFFU_SXWg?h%TWA0+!EQffMZV}Sn!Oh#Q0-q&n^&J@Ek7cfOwzY=OC`YOW@b)LEYf-XM>(+{IPaaNQ*@=HsA!qDY6aOHKodH&ui zGe}ccYo69ccYIC{^QQA=mgwN{$8TtU}WEfk9-ZogaQlJ2jd$U0Mt zO>|_P!QJNcI7`QWj)P}DN_6%IoYV5ZBywk0Q@S-XMr&qF_}lbiBgrj439R}Yw(mZ) z?3l}f<`zZMXzFpwsJa=1c8O6%kuhECTSZkJ0X|rj71-OfmSd$4kY8G+bD*j+VwxtIY>X{${2gw4?7hU(De_ZCh@bvE1_}#_ zt7zYLF2$8)re3&*%Ub@L#?`uOB3oEVa@V2Pp|HX$w|Rr;QLVtvg`eWFQ~w`J2XEts z-oN!k*q%ChAL!P%HnV{HmP{<#h0la`J{F%DYV;nOIpieB_Ow6d*CN?21h z)w8y2zVN*~`S@OTcit~NW@f2@d*^F5045tr|qEAlJec`K#cD$F_xXyux9%1^_ z8H{29V|f%wjWclht7z}<=HP$5k3B8_6WET2j$TLdvQHpyk{LH)a0jFs)0O_0X=>of zu}^V5Zx4NMYqeC&-&5;tI!bTT(Ynw1z|mJQbmgl(+nahiX#M5ySLXWhqR#Oj{3ADx zd=^w83;6~{SGLl5j z(2FH8gHRVU3;h`0`i8{lkF9M0gd6spulAn84Xv)hOlnU zs#UV@ln5>DXEsEUCF|4A!5%yT$*(OghmOyWL8*#*2VyUFi~YPxSyk{sU*94FpUyMsl?=% zx37663;h;PzD=l@gF)!CV4R{+7i?^yMjz-TmL}Z2C5*=ldy? zqP$|8QzmD3`yFe(&QdkdfNBfn;vCS%8En4hS!M(bhe+JkW-$9{g=D}&}|>E;}h z3scDE3c8%O7%4^#R<1{4)EZ0TEyK|A5dfxIKSBQFeJmb%koMA7Xi7M1b3HKn4w^2! zW%Yh!&+}Fkg=~8pGiOH8;&B>Uvev$WxlK{!P%_+Qd2d-!FkOrMRIZWvacfz-JimyE ze{j7P_*D(SOMduZCAI60wpX0@^F~5$XU4{baS> z?D3F+O-Yhy>C7@eIgQ`#sjh{sdvWckLFmics3v<=n`mO%e4DDos1JkCb#9U5tK*s; z2BEiYr9#0Os8xf5&LEH&bnV%=N>5a6sBrEN=ubq?- z336?N{fr(u;@DQNvV|kb&UDtbl@tYc-jb}U7$vWoiX=;pL@jRulE$`?3( z>?BKbPQOl3RT3H3OIp2#UMey$lCrk7_rtc|0U(k{P%0ELN^u(7UdywmE}`kzud=n{ zX^u|z^Vq@%h&BxqOD2&Nm1IYnfz2|e315ET7JBmk$VCHpJ1!ZItZ-TJHPDO4?mYT7 zuK3k2RQ#>-v+moy3#clzcx1AD%+G%`{1}8HwS*dM_L?S1p@4oRl0Yxi53W0E5C+K3 zg3}jKY#`sn1S~J_qkHRSR}DhdqmlsS7=gIHrJqHU2}-V{VKo+mi$BE!r;n39@fVJI zZHxH?&!qm1!L3z*zI}KYbMWuE|IBgp;tHA~F@NL%qUs{5T%=ig1W7VEv3wEPk;`c8 z9cIs4E@OIZj#S@8(038@^K(eRQYy_86ECoE<`AZ&lIht_^I#uJH116FG7K6zdg3KA^J=meeTXAZJVj>qR_4tCS}%ILEm?i_ z9cZefA8ozr3V!&tpX2K8f2ix({y5lT%iC!X{owL{RZU^w6(8USN8XGiDYRdjWhh_> zmShRZbSJtvV<&Datt_J%Q*3(m2l%JE-+?;*U78{XIm5NQ=v{{5Qw zb$X_QCo^B+ijSPbv8SH{V9%x9z#FUAGoib;n{#_DRfF*p;|z`@7#ukdI1jz_YQAy% zzw+8GRymbB58O=b{NLi(L$@;$`v&97{VcV-k&!EhNK2q`t4@8r?18|py{G47jGXX)0E}FW6t_{ zfyoXbR23r8aBIP}ndGVcuKFaMddxHh$so@DB1IZ73)S^|J&Ekz1T5K6szWd7!Z@2z z(#3OvX;zY%?WUiFy4Z%jxV8ZhZrIDh*0Tm-v1lTjRnR6DgOxnAx2@sCA*~7!y0SHG z5~J=R*&!L09){8asHwWIEvj)SNfN5+y;pl6$utZBpc#HKSnGGm?lc522zgU!3#&UMiMQR1MQGIJx&A`9*)QC7!YY2(HDL8c~r|jpl((7$u#_ zu_wIu6%vSU$jVHe4Bp57&fE-yvYmL|l2t0^i?sH2a{BP;s^cL+bB9y2!oDzYXUI9^ z9W@ZF(L5F=&>;$7yUnNp5V|F{9A+9!kIk|?GvWD3RwN;dv|cj}(-BOSRTWwAooTTc z`6UYkGxY+cqKT^6G7oD6>dTf)!c}j$6u6ZA`~@$3{m*HQKEt7vo5*dwj<~i+dv6ns zPri!5mi-(Y8)isQyI}w5Cd_NT&rp}f=>xNDzxiz#yp1+ses-DVvD=78f*D}<9lV-- z3pa6N-(S$4Ei=|6?mF-a7E4VG=N{wG%z3CIKgA^29JSM;(F#Cl_&f){elD`| z`oXY8$3zPPEgGSHsK>KiXp-cV5)7zsS<4pZmYF){tV!XPiwtgxJt$~a>IIc;Zd_cj zTx8kC2usVGnVnpur6X;fN0zO;8Nn?XhUrWmkR@@DT#3+4{CriPxcy^Su3 zQ$OF&{cYQyS$<(Ilm}G~P>b#}zf=u)+nUve|vgU zF-KoB4NSuaAWX(t2r-n$(+QU6mryl@RAV(P>JKJ}8J&IrVI?4B`v$dm!UJB`GLNLG z{kze{2@JNK!89DnF|s0282Lqxj_;$d=}6$Uk52S+wC~Flmlv5la)9Q65ytjDLAH25 zadn2>UH6eqSpM<;(aq?l#+G)kB;1`R-^;}x{5!6TYmA>5W8$b)%h{TT?I2nWwrnab z9Q#58J(d0KK8cgPZ23&i?^@L~N-npq!h}{7bRU4wPv*L^&8M#|OwV}0L4R8X2z}-n z(=<^utpXA?b+(|leujG|W@y-zWy|XG?A$zzUMw~B8z=uS5AGue@ZX_L{kC$ib=%h$mm3_IUSlD}Gu-Qt-POLG!N zVG$DrS<*x0=Q!l{eA*;p5GtDL7?T$l;mpB97&gFLji?k?Lep26M$X2+ZZvLs97dWR}~GSfXmV`n?5#tg@w-3LIY z7GcwcJFE90ke@}8h_-Ac+OpNQsXR_spJ{o0o_?%1Zbn#CMvySj zlmfBhK`!q31`P@8887A&JU4zh%X%ZZ$U(Xr4iHiDB%)qFqF7REO1I{V5uW2$*~XuE zSZw>pchYmo9`vGD1;cNh*4f__c)hYol1wC#sMMk_Z!4QSb^@?Cw_-^`s&K_M8OmVZ zp2i%>l=Zyd@cX~V=l^&&$?5;$rmg?!S$pe-%tMDRV{YJM-1v@5$S?T=;#Lx5wa>UE zFcno|StDL)2YlAi6tnkOahQ)w$^wdEO}L5z0|Qnf(ORz&CKz1aMR#SBR|aBFZ3r3F z9&C_g(A)w_;mxXpWLue)jm zCgbiyeA`o=$n;`KxF{+WibUfYb5pa-PYD(aHnTGb-D$ZF$xmNdr@+%!EY{FWXV)nf z7p%6=3dleXc0i2`p2-uZY3=J|Wv;rgY3V3O|8vz;qreUx=H!u9n>ND%ZAkAPm25J4dT!I0_XHv zQaFdl2RPC9Wv)55P1x_#4RP*z=#9K`(-(x-ed5^p9LoMFo3BoDkqk#?ujGdh-3LfC z4xLNy4OUWv0rZ#eKcJs#u_)9az177aG)?P!w++*QK(gq1f(#W!6f~_|GCV-I?2=K9 zCgL=#vRFd3&Gy|$tY-PG2C_`OBo2HC%tS%eBnF36jDd#gd??Atl2~cf&6JQNnOHL6 z+0XI)D-k*`c`ID@V4HKi$S>hzM=^ydtStt7BL95r~n{TYyhF{ zCliTuHo@uDAhcmYLzlO!544&TuC8HA7+70MIazQIT=RzE=Kk+t_G zi1|r%TlIpyLD5t;&i?#glUaP1P;RoME8Ma7)#$??X7~11W=F^9?AZ-KN4J$|1+ag< z$jsyttvxM7`mAr~-czTU*>^9EJkFfaPt$qt;HtM>UOBg0>f5ZN#&skb5rT;I)FK8i z=wt_;WJPfREFiA=S{@5!MXAis!vvcE0EdZ5L_t&sK@7rr)`W8t;cF6`ru9Ae`NF1Q zphaVjTEmK}kZW)9?57vS=V)jE^MYSw`poUV@}%E6^9ip1=&zz`uB0WOU})GNseR6| zGd=z4y^F?|TC%}4%{@Q)Sh68auA`Ny(-Wj}4aBq+rca(g(Nw0!{lN))qFX%@qau)R zcG~HisG2JLt~|+Ms_Y2Fi!2)V#u4nMmrAr>`C(o#-bd=(4t8uAq_8l}-1w9SD0KrN zhHiLjjoB>>!wIVI)SiS=KCdXD)X;U77nVq5lGb=yt~IMDp&0*GjYsGCVu=XX{`x=g z$oJmMseAq%kjZR&HErkQ$RaLhpJ{1fX>N`~vma*i%p7g~Jq+z0p!N6}mZmJZ=F?aS z4WoySQkWkjn#qyOwUBCTz%&d>#R8|Fx|d=3-xzB4%3(gA{wFqF_a^T7T#9QNf1?6k z&ab3VF8vIKETaot+q0F)mSlxg@!$ zEJ5`?r1C+y*>P^G0(tJFw;BLpz&T}EcHDbQZ!gP>2Pl=GYp~6+-&)Kx4WeIzh6OTH5+Lm^d}=IKCc$&>druWZ4m$)JwWB z4r-O(5ufwK3kT`nwG}m5U5DGpJ9P~SavS#I*#POGO=dRqLxf zhF0f?mySVbmqYXvkC$S;+R|$>*90xeGKp*>k;o}*u_MQ1Y8VEF?vj!bJ@^{+wtOcw zM|PTjmZ)1+(r&SoHyB-Mr)!ogKl`GOLsP6|MAZ(~N@_JLDV2$#C{~;+9(OJ3ecBRw zqL*YEdbZJ^FEKH?46~;^yVzv|!Uh!Vj$fjw!}DGpG8s02)y&ceA5sSel)tv7?PM$7j)(CQ(z(WE!f^T_oir2LL~YVfY5X{c3X7ZKffV z7!_QRNGw8r*;-r|mh&VtHh{1eb3|Z<;j9N|4>(1lGQFExDXuJ&ZB!^;eJlU+rO#rP zinN`36Bk`zIm>i=b4c?9t0gfb%XU%Q&IlKZWh*RIBV5r$l)-bh)-}PoK_aUU zSli}j7JL`M<@vDkx1!_`}m6)?ODJku2jMUsdmqbQLWdZCKm2R?5lLjR@~mS-l2 zH@2?!TNlI22ly9f)hiwbrWT8`Jnxi!smMLLWUFPx0`;w$UqO}Y$_&C90AX3e&y&%q zf*}sT+AWeu5KSg&Xvs2h%tq&eD7yS&`6O3(;gln%OgFJjY0~>|rX4YcD&`=N@02Y$~6itgW7O#Y|WI zULQF}$8~JK={El>+0@MZ_?T;dSI{~Hh-SZ6KYzg;#BwBQmFEaqMUIpuJnNA}u3j?` zaPxdG_0P}cnt-*w(UOg=;;t#N#9~B)ELeGC=4vcd3zx5FiBT8t&x2CqB*jrC*}0VU%J? z<@qza*t_sr(!*ENdt<5s#QRw)tC@ z?X{kd8jXr2B&@8^}j_Ul4-td`VtEwd-Ng^?H)CA9@98-y`K zF9rIhWlLpKL}Dt*Ou6WIig-g4D@%E%#-9y!ty&VJ_VLpV#(EAxM%85DIo)HlqD9c% zG+AfEs!1U}R!$%>S0-_O7`*&d44wb0ym0>$oV@?PxV-5viE0IYc>2R!{n4Ldd10B( zbFb&w#a+ygoN{#`(P(<@`#HMw^EB-=xC~$wWRgWbr@-v%dcQS(Sd$ewbWL4*d4<~NRUOAv z&JomCYuvYceo*2K3~cMQ)?KmIaMuX1@*F?7y@SQY4w9C_xRs1!d(9=KiC8K{Dg`qq zEhf96o364a3R*Nm_vU`v72I)$Iv9kCdu`xm5Q=LJx=o3uDCHNCFew&vCQpsiHqc%9 zJxJkg!+x~3@isQ>WnpI@pWyAEK80= zZAp?qiI8nhF?OoTVA1S>RTnrX`!e%_!3@9S?AbqNQ7lnaDioZIERV#fIP*^=Jz%af ziHdnKNK^)oOh2~Bs=4PtS&P!Uxr4>&X)XGUR|fP$er z0|z2GLwnzJH~<1iDv~M(qWuBFFzvEmvsfrmEREddI>NV}}L2-Wn<#Scoqt&r5k!`HIAG$#TdtfrChM5&)tzKJD`mcHoS)2iC>dth- za-U0Pv&@az0GnDMTB?^svbzduyT)$6_vOb{Q{4+g#bOmYnN;FFUtU1RQi+-Xp(IH) z`;uK|$I&KLMGwfnwc4UFw8vJ9w8r8&5|tU+-ipZFOeOpO!NhJc<%6(iYD_c4F+9XEnv6cA+PTSNKBt>;3HkyVJs@AY!7(xj~ zyLDgmYMyxVdz{nrNM+kShi>A+5B{ZZH}Ug9!`5LEN`C83>CRd;2>qUO*ynGQbWD39 zVN4~KOfok$@A#G@F+r$Nsx7bX++Dn{D1k5A+Q{O}40;JNxfHXL-kO!=Hn`2rid8j@ z!m>X=C@V_MHHV}2p{>uQI|fO&n;d`PFsz8#v4(DVm_p9$ORKd&!_^y>U3J93#qhVu z=`YRA0qxFYjvzX_ zkH0;`&19QX%umfSvzp_nmr8--LixOozv~!pMGQjwHdEJq{VTg=n$qY-hPg8qIbmhV zIhHz;4>lNtKTfs*5N_DZ!UEK9LYX9it~8R?xw^ej5KmO6 z_BrJ%JvCk=8Cmw!n)a0*(2FHsNrW5{ zv;`wn)ej)FWDm>mLt!=jbmGL=n!rKuG6>z@s}_w=4tJFbMOW>()x}lF?82rdiKeDx z?HyO!;RYgtfZQ+uVLbsTPXI!a>T@m&HY16o;*z|mm}SnBHidyuUQmSbh8dr|L5%2 zPQUts4%G-LIOAIuF`FcrzHPl?-5*!~FV&DzgK#wn0f$tvTQ=Ks&N}=8dtV5;z8Cyl z3xm-5mi@lbVxj05BPCh!j2(_P1ulHkX+yf ziE^FoOr5lXNbyVq#pnR}V%{{h(f@0|&_JVi3xTf+V{nMpfAZ8rvmC%jRC=nKb}HMU|1`tr*K= z7*Y(`46a`t##Bus+Stg#gf%HQ_x99eyqDWrtEP7{9c6a=r}*--&tsPIq&i19@146z zXZ_dVeyr)R^7r;Jvpg5O{k2-b{CXrtMH1DP^kj8g=jKhsQ*qQ|<4hg5f*p~xe?7K( z*gbweoCqL{r5YICJq%`vlAL9B@(Jg?sA9Gec;|8OQ>i<4gvm%pf9rVK{enZ4c`gsEI(MNMIOe$o1NP3QM*nziRrRZ}qx zJ0NdoxrX-w<<**OWz%(?xrw>Z$J@^Xt2OGEmr!C^z&|m|+9vz{2G*Qi&%H=xsiXKQ!_o<_XRd5+rjk{i%hh-NvVZ2N2@${`jMUJ?uh(Q*a%AvS zoT_M+zL7@q^V7r|+b~_utHYV`nM8|H3K$dqX1|u;hw3?Mt9n{Z5U|D~u?Q?FB260zUlFmB8W8J9lnTJj6^#-*q7Mhzus26pM z5n|g`9xL5*O8Gm9i{z_aR+h;Vf#B69i39MWxy*ZPjm)Sz#!Pp=_hZrroB`pLSI9bWWou`WsG3ZvXb_3V9OIQAi_A~%$(a^en*evZ*x~ zgvCOUlI=UFC2t)xceM0%`~DVWlDx`ht1;-wicHj1bGzO|7fF^q3K`WbfB#^VQ`XMQ zAe0rE*4}pVOUslB1=h^G`lVeGif!7Va5+1z{E;MvnS0&;}K^$2cK)yG}kk9)O8a`3(AU&tSAm< zUiq63Vi1aw%iPwxwTG$6E)ww)XI^-^zWWfEqG`1EcQPtjfzDi8BeN5;HIFlOBNUL> z&=@0o<@psLM5!Tj={FCAO_p#3YYlz1*0D|U9Z6Wf+-`ndSOjcqu_*Z^0VCE?BG+nr zu06-n+zO^?(zChCdyH#;;csE*1V4ojwZUpPASqEyz3RP{=OEAcC&?Z*d#z-GaywN` zqFAUt3^f`P0)#?vBnTiZF0n9C1^<+&*RPDHQ#Ap?^0#fDf6M*a4uGd~4a|>^TfdnG z(_>Sv)5sWM0K$msNH{Va48l4`8y;9Z){1}Amx8AdqbzB6H z5T$G+S#mr_dy=i2nZHqk@W;zG0KyG>S(tlVtXIvaGFD15g~$N9-iT!8i9}A+)rMAq zdRHO=i#2?^nl%Vr1bjc~%r%>vL1+(nH4Ku-SVgSe(@)7RG3qnPH#9dgGnzv-mxP(3 zzU&XdWM?8_N>Sg8Q&IL)Bpwe95b7n}mqA#bEFFPfEd{blTYnd+l!+Egpr!mXDn+&) zfCTtEL-qwjt!o*E5lS=MU%JNM0kEL8wIfS)%P9#!qMB?TQTtYduYfqiz78`D6T`Ns zR3Z^n&3isca)Lp{RRBqi#nD|0#b_diDYvsc<*l*hxVB4E-s5OH7M3L!m{@&oilWxc zxUR=!MR8;R)CwFZs*J2gFiQT8x!*ZbO|3Mx=a?Crt^kCl=?pZkMgB^Z$>(lEao&1A zN1k`G2SoCPk`Hhfjn~g0v@)2h-PX=t3<`|Z1IPr{DiicMS&>M!bg?jg5)*}n=El(1 z0z!`@-XN_wqgK~-q6z1-tw@Zz=9IGRT09l636|U0fxZA?g*E145QfN3^0HLchd~Jb zwRi*DqsVI-A5#r+CXQSGZfuSbNjC!P1Q6B<9Crc?{_>!< zImBkmb8|KKv0vByt=Y4nYXOA0wk(q;PGg#u6uNsDw}Q^$#y|TQW)EtK8H9C=C3a?7 zz*rn4c^Ss+wMhN0Q_2@B^4wws-~euRk}A-fp}4Q1h<)HIDinGEU8*+DPJae$g$AA zUpa}bRJ13PtUbIZT66mXgo5L=h{C{*p1^&fnjj_qt+uo`lo_7R1VZ&u)NLs{s&}_t!Lu^3qL`tXP9c##nT|AcchGrRsNRND;DaUWvnaI?2pwSG#=X zdV3JQ%)YeuZ!Myd%x0ONwEU~tQKt-stjL7<6je<{R&0Q^?VIYBr!fgRlU!UKegs^A!<1@FD4#Hg=qwp$X?JC3NTM#)~nWc9fQB??Kh z;`vS+T60VtUn4*$Nka7|zkU7MrzjGABN_B!i9}1k=a|ahK|bO-0fg?pQ&2lWO_NFK ztV|s6J-jign+Kpc<)OAs!iK<_$>jp-tgsiHL z_^{}j8}_5M^@%)h*oOU>+j=nw%cd(?N@`nO5~F_B(QXA}WEuPv;>y7mM~yJsU@hYW z#G>)q5<2ymC)3I8*x3Gd)St7n&3s-$CIh+N0kr1r$Pt0;pkFXdu-TL%fdMR$W761K zy=Fz3H3%Rq+U2sU&$;Y)=^{;SH9@j{6Ors1vcyFJl67JbDvIRD-mjT)Y?_XYdA~r5 zpHmjJXx5@<=N!}vqza``*C2r^*Z(WM=_1uo-Sq{E#$2Uv>C6cU4 zgGHi>r&esOAie#Cxh0I&1t63q-*3lQ8gt2RxN1wt;#zC9)uz1 zCK%k_N%uelotyf-=l9e<@^5RIJ)jrzj!b8f46^7v!{_g$Z|By)ed`JSgkj6I_B4@g zX&{=2(l$`_)%?bkTtGj=J%`Xol3*Of@*IxtaC5>2F< zy#NUk#4~A>07*DS0L+A$Lk+_W#gwS`+Qarzq7u97(4f*XlijK8^m2@3w!`V z4K1>ZlfUYjsveR@yr?G8t&^4bI{=O1dKh%P!?w zzE_7N*g9my2!kPO0SHBM7H-}85YDMRW8RO&;f zKP$KF$>x7C7=!`AZQFS6$)uKL;i9LOyj^euZ3}rbEI=re*Q{^WM5d_uXN(6*jGE#= zv7$h0N1RNy{y8=@K>BB#eQ1EebGDHY)U+%2LpfkP_mdcX zsTqXelxJFASPa zq3_|cESwH1+3G>r+%*7YN%jUf?eEMz7K`pZ($GP=bC`H0w<>37YoqI4Be4k4R2qP2 zDvf6Iy$X5DA6lHPW|~WCObGUio}sE~fm!|Te+TUc;ofs3m7`%YbzI=b||w;syvSzLVE4cl(aw1Wns9 ztNx4{t%7u_Dt?b95~`s8`mt!;$9DE*g};fIu%5)qav2E4ndoI-&w~X#ryJ8Kg6#Jy zU>4j~S+Ros_Ge%kItLzon*EPF#meH6@3ESP;V0LnFyLq9Z#e2M!?4)4)q zDTM!6?&EePll(Hw=bB9;QY-N3bz9gh%(4lKI*S52wZ3t?l$K|KVVy6Nt<7WsDCF(Z zOA#wEL6SVRq+fn&<+d#J`^SCP6ym9;Gjp~cjG=bL|^v*>S7-sK=4zS1&)F2G#hl&8e5hu;rlU#yp5UvY@Frdw4 z)~@Jz^=|2;xn02cvB}ex3Po1t=c?Cp&r?dZpOpj9!e?0L`0nrn>~#fK=f!#$2}0Ir z`P?CD5IWlhsg$)o7A2Te`e49&6-qnG65U%j(a_nC91)LAMzK&AOQ>wMq?mnc5Y~}Y z>IPx$5~J%a*A^zR=WJNeZc8<}n4*3dO$Y{IEa@CW%Fo6VK$j&?o@p&6X8E3@jXjlb zQOouT*Zv+>U-`Nr#v?`52x^LZ-4!JHZ0|RP<-F|}OHNXfAeD0_RhEAbW06%Q6VJSk zRaRK|;2`k6kl^(3?o=HpRuz3zJCc*PPz%BX!R&0FkInJn1cgBrB>fpe1Gy`lwgA zm)dKrL8z*hf>RjtQ(ICavC5cWNO2n5vX1@y5||A9<6nO9DDgT(kXatmylhw5vptMN zePx|&TN|9(dx(Qi9AJ6AnkaBq2`wA;qSyvNxM42~J6nKI)4ZTtGz!S;FsMmB6P@kf z<@Wkz$}Q1Z`FCC$T6d;)J((ze48kyuj}TDvl^7KS$z!p=w)r~yvI<(dX0Xs&gxLW? zTXAGejz0A)`Nb-;!q17R$AW?~b)6>!KnT#aX^?bFD?QsrP-0on?{--OzaX`Lpe^`$ zL{eE2*`~UKw${0W`&ZoMlZqD5A{oD0!8gZqb=L`7y+-R`AF;TMrhyuOm)XZ9XYDBW zb6#DD+6JZK%8C#mbT1tIo|CRS8HAo}N?RKYH<)NUry9VhT8nvQ?bZ;G_i=VTCJHo} zrlamd*hRA4OaX&X)Gm>U7XIR~Z|zwj_A#bZGCA<{DV~4iFpIPPnZA;u(BE&#I!m%b zPe92T-{)-_bq5Q>SlAjR-tWEo9Gbh^>Dk=N;O=2!sU(R^l4RDBh0C_JRGyg|AXK#Q z>@Gj{jr%z3|L#R-h`u9We2C5+)?&7WeSXQmZ}^#2AM`~62yNp|G83=J?biCn>XsM{ zDN9?=Il#$OaJNa9u8!&uGQY(L1SkJ zxvnmxz|KGik7SfVn;Ae^>q0GMqVZC!u z(7O)=2z{7f$+Sl2KofBvupw_}JIPw;Z?<*~LfP}apf<2*H@h#Y>-xG$jCwB=_=(!u zb9UpInrJ|YWr^oHXzt89e)E-YHcj7T#j?fR^`lprH=_QW&NUEOEq9S=O4s~`$_b3w zrYKE=J4ttMqpi1Dxc2u8f`>`e5;6MhmdaJfauY!b1G_qhtsRa!X0~J8HMgzb!}ZPR zs-L(Pf8M>0kz`4z)D?|uXfd%r0p4d_@{tL)Z}<3is=prss}iF=Y_Gv#hGV zLWP;k6^|+XdfB25?rFyKAUJ1 z$Yldks!CKStMpXP?ZW~UCcJy{lIvg)y21UjVvb#6)D{cPw6@Z_dylsb<#n>l|1K=f zGI`=uZg8xkZNtU#B%Im!5$1(rq+R^$vy1sj9VPh!-^DO=AB z(|XU{*wsVn=y8fgt6oGc=@_saJrfAbAGR|HUBE=SziF{J1Ly1j)8yz2$0^Os)V#L4 zc2wS0XS96DNIb^U+&o}O#h60rDOXU~=Q){%9-2YuUNbm@wj#!xV##LJFp5fCJGE!0?C2zq(&_{AD__fM?_*VN`>W6 z%!x2-9Yv8nNqzN5a69Wk84iij^#lkNMe$|hmTf)8vSd0&=Ww_1w|FW}Lu(WB6H_#_ zHZy%@yz-kJfEESyBT*SciXfY-Yu89t*k(zxWON62Z)SO6oeg!wq%@i8q@UuR#Z(xS8}j?#%fbjKn7#^T7KY?kDruZm}l0N z;BTTLC)#N5_WC-Wwd{J8kCD6&SO^fxGN}|SEC3SZT7u(OKp={!Y7&^W$C-3#46Hv4k!uG*j zU!LNH$73WS&H=e%l1xc;>RYd;d8If4CS6CiaCKqmSD4stVyDJuRuch)@ffAWfLZV? z7cusN-2l;e%oFsr+j>Zh)++<&dmh1a6xA-UTQVh)98Pm!R-?GG3J~UBoEn5@Yc<&r z@f5vV`&pPAC)LyxDmzSxXf@@9+-x()@%HwAbcPDW1UYs?!wi{TNo49(-nV9imDJ)DWhI#YEJ-S4GUZMyL5y5ZeSzk zlD7>4vfR(S%%lWw(9dt4E?YEFn?Yzd)i!693I!JD7AnC(Nmfau#e;<(P+yCY5*YmP zFg+y{Es9ZG@j1f(xrQ0P64{2D`*>`fe7l-J>Usf@g7f&vfH|1qiZf8GmvrCSvm%Qo z6D->RLNPP1dX6$X(e9V`?$|;ylK@GgH0=zYN7Bv2(>=)KVcuW^nC2Vpl!#{*Akg;X#gKaujG#pI+GF15~ZSnT4vV=C{@%l(HK-| zcc8>*@LDHK;#VXQ$0+3m95+Iw%J3mvHqvUNi?pq;H2QU?YG7k zJ%4w7J@<8ktQNqnXbOG1HlvqxaveF&9C9i=IO==b+E!Yr#x(5TmWXj{|9vI039Wb% zLr$aWqYh9Yp2?IiiXzJn1(sUo_E99sk{In8>E+0?hbnVOGLu5r4d#57D~hVpKiI+H z=OtuWqP@4uu#}@2jQpH+eyz&r^7;(5naTP2naaN->1Hw=ePpv@#;vyop}>z_%XcIQ zXo}0FD6|a?5IK3AL^?$zm0i=h9oGqy+^DK3F1BYZnP7Q#rn--!dXtFP01$dWu{vuI zu4UXbO3wZrbj^4w%BJ%+1)eV+haI~j=(<<=sy+r`_-Em^dTZob>Lsi1QMau@R{KZI z@oPh3^vBLN-pGc%ENv|SValaupeR5z2ZdGC@NC$G+G>W05`)=!z_9*pXAP~KV%Zkk zu(RztJK@wSTYGqT?alDBAU%7xi5wg)s?yxuTJt*M#a}oEp@~d1rm?(EK^KwP0YXVp zkrj!?_7br4* zzV-~kwJ`|Az)Ywtb5U&(!S^Mqu@ky|YhWbT^B%$`N~~7`K|Gyg2B7s-7?2l5Bo)7=(Tc`3ZLHM;u9y{MlEqP$~dlKIN>OMj5#$A8MtTDH_ zWyWfKpY|X?DW7NJ_?hatk}2Ou?f>p-u{VN#_ugA17Gos{AoNo{2r=fTa`gj*^@7uW zR#OIjour$C7!U3_#IeSbBxFg#SlwvoZfWV@=MYaPS+>-ZWHqvhbbsP=HJJ{^r~6x`)q>yWjy<4yJUFW=gxNMAcU7o4di|dfF)1zcP;o+-|MwglR`49z!ogJwB?)!X&a;>~`A*Nv_Km z(;{jmLXclBV%X|No~^XM-DFjzp}CRqBPSe5zYbfY^VHVc#OU!!Vlj!vju1)CtALR# z%a!$3!4??t8|!)o5@ef)h{soBH~0YF-U|jc$IlNIZI8yq^L&_h?gtDw&LdUXOCDCf z?dH^I9Iz5D^2>G>SGCWGYj#$jZ9`L<>FF&<5@^e39Q#G$G3POL0EBKIrF_kJLnB@N zO^ltGB%V-dXvt!lGIOJ6SY9r9rh7%xh$m#Wo!5zJI^S*Y)_x8@wHMPg=^m~!3uPbQ zF6=iMPbjpARU@r7`6wkguUY?hB$5I3V-QACIodYoJiqywl57tAT{LcLu^7db>NhDG z3!^SMc12Su6`c1`O+FGOw}-3e8K%Cw-PYpIuOXaV(*y)wz@KOl2H318Zs?yZ=QG}| zK4(Q$DcL5q^0!lKjk5s|{`lF(8`-dzrL6@ZbeAxR#{n$^u?{HBL%i$k8>Yk{y8}!e zV|flZ?z^k3RU6dSW;wm83Jq&WWUgiI2|8%u=?0S`u?PdZw;%@WOJ^5fA$5Z z8z^YZsdW()48s|;6S8y^6+-|J+MRF@K&UFAoowbY341qf^3i>QXAX#Dl7v9p%;iOdZQ)syKOIzYXSW4GcX=VUjpiptuqpUpxx z8)AI%>jxenF9MN_bCZX~Zv~r;AoWZi=BoeiwJj1)uw0z2xu0mkC(CkO3<*CI0%X2y z_I0j7=-NL_ElAHf+zdjWz#(ACGPx`$U?QqRTiM6;jC7y6JY{p&ZZeJF{q1v>URXl6 zO%;uu{=r{ACPys`8iPdOXCxkDB>({Oubo_@J(5N`8%};#OSe;;w*so^_Wn@E)Uua~ zK`4ALl`+iygI8|dSnWT0Ne?YC3J^~@<=SMt{m_w=W(oj8TLVSXf<8APG+7Q_XOzE1 zH-H^x&U7b(LHPO8{T;1W4Z^bQrRY0o>Fwss{&^%*C*Cq>JDG&4sg)X?OY=(<@&z)D zDW}Z#dNK%oz?@j7vG$2WILonQ45JWdd9F5!9F9TgW-Qe=m&B4W=BKQ=cx7P;(BRi`s)7ej^r5Ti{wVng|cW=iiE_3S0 z6!Q}km3{60p!*fZlNy^Y*ny#&M54AIJxQ_kCNBhkzwWw-Lcd}S3GSuktfGFMu!-k!WtVGNccYjRDxtT%9tHN(Bm?D1L*La>}?Hy~jmJJwaAItyU~RifaS+hJcX=_E<= zR->u{gtm?AQiD)De%7ovBZ(bucV9P=*925lEW|x&(TD>etmi$1U1-0U$sIoq<`$srS+XmH75xJ-2=6}fAPdD`j|#7txx$uyqgU7e4^L@o^HGvCDOUR0?CaNId+CDtM*#74MXzcNq zUI#1bX=VnHptZ|VGRQT_& zHg{XpXZ{J(j(QUSgXAnz4N=s1hSKt!^B5PXrl=Zvv6|&$A0Jc|+S{#v>v9^!u;Nup z%2`*f-XsMz;%c|z@+mZzA0eyA(KL$c^hs+mE>#s!gV|>Fs@2rm-_s`qF^e^g^@6a2 zDA_*jyU64+H1!Xos#Xo)S|&-dLoz&(NwP3$18W6qfoRa;EL{iyd+q@Q~pvtkUbuRtZS><=+Q9N3G zAzD-oOi-y+V|Bw`dfNa9H|%9;Yhe%qG&Mmkx2{A7Ukkd18EGV)dFg_w_0%p}UUD*w zvaR9e3}W%MON^F-KkiAecOXV2Dxo#>hHgg~vx6DJ>=&P`p1Q9ilKNOn2UKFTv?{c; z=pe@+HB?8YzK_JHUpA2}d9n|!^MrdIP0cbZ`4I{Q=n6^9U5hk_$P$e$BpPy*i)q&m zzWFl<#Rr502<;3)cc51)Ut31iB3@a}H8{C?$(iLTuJa{k1=-ErCnQA<{oX}=-R*F$ zCqU@gR?b*-z0+F2X^4vn|EwvYhZL&#xT1=rWW$?VT1RAV2fn^qIsb1IAb+s|O0PU+ zIqT95{)w_d1pwh$KG(I(Pcffl7`KT6l(G_>EYP_*oIZLz5~HG;s1D|Ez3<%f?X5v# z)UC4MViv@+EiBKBVDlUio;}+(F>zv)bgsb@ zq!F1eXHX0)kf_Zpl>7HA4b~xvmZt z^C^~RrvZb`p6~(gSX^OvM|p3NZP!Lsh})9(g53bDb?HibT38X9FD9z#tR^07Jh& zx4&Jle5cR5EcY)-g1k0}VoamU4JfKrcgXV5{gW2#W)B=&=d2eX6iMfm6T6(}N;M|w z=nq#?pjKdBl2sxx6+J{9-mvqk`~B2ApT}(gp{cJp4>XvUd+EL!)iR}w;4voq@K%=UWvytnJsKTi7k9j@&r5|LPa3AJ{sw}gPK zdi_PaW!qYqI8p@&O(YUoC%9p$hKK<|yG65|6sC`3$O$@KrDWF9FVgM(%#NNyOJvuq za^zS3B*?t}xlh{#Nx4va-R+fq6`Xo438-mr5_K&ZZnw4eb})AAv{kQhum`oa0E@3B zGENZJt<+Q5F2|7!E72;MYi4D^Ns91QSiwITtsaTdSUSnV)Rbe}S`%%v%}HqXCiAVg z2@{Ayj(PRYwdwS#VYzLx4B4JxCQnVGl}>w(7x_IK_M^3pH?m3{j!-*srGFzdv09v*JfQ2zX zryxicK$c|>TO_U$*>eToa+QO1=Uk>NsJoqT1Iu6-}qT8_99QJK8ZGAq~W|EqPK6T@ie7nf8~LC5`FwEmI^nY zdlr0B86b4_n`kmc+dvl}k!%RJz!n39KD9YqH3%b#1o?$(k*~9r7+uSFr0Y(nE|zdc zV>*U<7(JOGk&2N_3%+ehLPBy#OR3Q^gD}C!`8$0Rl+uk!(v5?m+uAW4kjM(w(#xW~>m_{tEg2+)Kr}B}RQzI4ZTcBxq@mF*epg zVRycNOQEzbmrB;9N>sxg(2t5*f^|!zCMwdZv}`;rUf1<*8{p*r<7m++xwb54mKI%mlmSA~ z7!*=sG~nB6XwEWm)LF)9y%LaLEEC1tVu^5lz1}gpy|0P**d+0UjFz&J;C$)l3eo?C8dp zmpYIdJ}W2N*3A6WET#$FoBQg!j2t5tgGIZa;5(@~Ut&2s+;B0W){+s_+&{p?5v#^R z`^X5R`(N5>_t7m3kIW@IgiQEwYf)AY3Atexu`eNvVcFLSDnakbW>H`Crpb7#xC zNHU*|z0_Io{!A0WwLdB>=bLHqB-!Q^GZWUhk#0;mlC<3C zx1HGaXYKbS3Vg05DGBQDn|>-B_UrfbKu3r89k?w_-Lw~zL0FGOV6AzetWXo?oMGRW z`&?^TA4L;!3i$=k_H_xk`h5=(1xZ#M_2GV$3_`p4fjS~v4^Y!LVD4sz27EA~#)Zh) zeJ9>TPKo|qo0*?2(9jqW&h1)6W7+n%-x!-r!lq5)cC58u)Ji^Zn|tlJmi`}BK0eye zPBh!W*qJ9WbsKxN#@OKgUBs*tjlt$>I+)vrc66adWJI5WBwH+Z)8VHFybDPpn$G(F zUWdtwOxKoS#*d#N*^nWg633iU4RJP~U+p-4(%+swBuSwqcZ2RZ5{sb4;}n;xz5K@= zAe0nuzi#Oc$C9siyi?OXG{w@OBtW*V>-SAm*0KT8#$YAWDXO{w5dJ@48*gO8HtfgM zv<#X$il#!QL1A_wiexTh#Cl#*c~Pj3K^RSE8Q!xAFhQ%c#)_;EjcY6}Rs(2$!2=m5 zYbu_!oDf5I0)%Ucl{I(fP@{1aRUzFF;mm&7!%TDocZ#Yym|W`>q!T$pFLk!&TD00l z3xl-=I>J~`fXKJTlRyU5vx_G%@06&kJHQ+w80KM(x_~63RQ*08F|YF-lqndt>j*Ot zLr9Fe`-N$mf&EoW<|r&L(bC(AET{?Le!lDFz>8!xE1T@Uxfz7(5D1Zc7z?TdXx!1F(t@UsVO!(J~K_(P{*a3nk({yEq&qS1H8a`Ek?Fg zmrTO@`~w+;ii{RdQe0j{)}kac>s!zIY%OeU=d->}KvIpz0^}|YU3W;3hMijy0gpXs zjuFW-JN~Y54q1_DA08x@Jc+7lq_PdnO}P>=L)@4nuY0_c z6^J^)N8jHA#ZT<~W(G#pIF& zSvzh9p_^J?PbOoKX%Wrq+rQKRp};BdlXR;bBMNfAh*_pVeX@PHf6eE0J2;vfY%S-LQ=*>k)V-7twJnCd?eQ~N@iGYOv0C5# z*|z^|tfjY`yxK=8v4!@|`edBywa~TILIQ+($?3?X8e*i%WIPFSt!dXW;oBNy5f@?< ziF+gWZaJk&aLL_>H8^ASEOgytVJzuvtGo7UZF9D_);uV()~@QH*c`rmox|K?K)n)V zFP3GSs$*Cp0ZK@@s`i0S8rs@v?&~F!OV>UC0ZSvq*xWQQgdrs`BR zw{>C3_tp+3djgvw`llp&Ys-f1Z_+v}Y{kkEYkNO027^#!vA$<`Z|{TFcPj zUWcwFrV%tZ1p+;FnmcK7Ej7*DlApBzKX-C&7;vd)>j2S)c3K98i3&>Jh~gNss(8mz zNvY}2pUrsqqcG7LN3=a?UU9#}TH`@uTNMD473C!#Kh|O`ooV86$mO89w=eX>7CONy+WxQ%Q8}_5O4aj!GUY7R%1VAWBknFMk2@DUzj}Fs!FR)*{D_H>HT;*@m z6kn{6cfQ23|6L*ptauMN3sY2;lI~HTr z#}>#oCC<9uUgg}v_f`f7om{q5Dg0aSyp7?sU&k_Ydd|=cP~j&T~c7{Z5^F5*~Vrfk+reS>oMQ4!lx!y&^Adj zDM-!0Gyny)lSssp#GW~Il6WSKmT&_t>oade6EPO2tAA_pD8;;Ue)I;p{iF!j01&!| z&>-f8Yv)MJIkxJ>`q+Gi?qct*q3>o|J26(~Fm#=4&v4zxgI6m0#bglHV~y>tEYD71 z=+IF|VvZKo=p62<>sp?GQkYtHok=o{Tha4N9DnwC-*dSw5>sgnDdqH1+e-!z*1;gu zW#`yq_aO~kcQG5*XIzpc8AHHG7O6nk&6K3d*Ca`@r^cYL?p+Ab(JKobZmf_L)lrE3 zrFy?$JsUGbJ?bzPiG`>^xQ>bpXCpD{mVreHMLqUPD)QCng)`7Wj6o8^u-~4a4bgqH1!>fPD(RGPY)w*8~fEwca%{0cyH2UAC zsQ*S{QA#B!u2k2dYl$!Y*n?jorN4t>03Oy{jgAmcCur{JWNu=LTxVN7t^JW%vbmdN z@)W2^h}A(DSr>p%JhznlmLJF)bdL3k>1^2g+6F+lVH>t#KUNlX6@ky^#gp2+(L)r= zu`H;)mF;q}i_+2@pp%cRE%@d(SyjQth&PeYqOM8KHR-KIF2>IkO>+T+>vrl;CUPtz zHL_^KuwY(nizRCmG6*3t>b{Ri+!-9VJ3zXeK`|_FyCi4*uW2(8cqsyWdV1j0X;2l& zwN`6(i6lmY0wc{moh(jIQplUM4|Fg$U7+A94({hvivka#_ndC)Bb{piH3f38Jf;0N z&jM@p+^vHnwDvAo=UXQe*HaQNNbDHGdPtD?WkyEh&i+~kBV-vGo2$RM150NM5Y}RF z$cjQ^dyd)hY4VE;v<~zJ?(aE1FpKy{?}pYRp?FpnHq%qcZgmBGP-$w9ab$k@|JUA~ zN6A^$`5*s0_0(R~)z$mnopg5+LN*qS!(?Itq6E;1NQgRmBvFSmibsyfcxH|uh#p4= z38*7D!)VZPorxO>E=QCT0}%v*AZzjqNoOPJEZv=?_vQZnD7w14s=90G>gw*)`*lvA zlRQh+eQLS&-21)vdu3@r#TCVpB$q#=vgXWY@24vEH`;dEaY-9vgTmux0?5K>t%-hD zJTdFe#m+!ZcLrf>M_We04cnveoF$Rpo(%63VrP#|dXPYt*q(su7A=y@kmQ`&RGgW1 z2K#(cpJHD39CvVMQ?Bl-629J}en)%U^X2K=D`Pn|=p5S#I%Wpp(9oFF!rAH|w;T>B zK}UFOeH#0n9sTr{!%}RFF_qUh(QjznEW;EB$k|R?qsJqr`ub#J+-t%>XNDtZQppUE z(d?;l?%U&M%w#Xc=;(oX#nn~n?r2kAPoI1R1!_MLy9RiQ?{nNBnhZMqU0(V8ZgqE# zd|+~MPZgV1Jl8V_ZI>%ys@F-?g;~CY>gGBf+2@n5AfVFn@!QMEe#fj(`u!NOb>2iN zamN!Igw13i-$d;t+vQf(+!F15De`+T_q_77AAgwJ>r#1Rqt2WfdU6Ef+gWQ+}L zWTkDG;IrWJdZd0MotgN!(j|;^gMf51^bjK5-Szv9&;8uz{k`j~XT9tEdCn}>lI_g? z=5<}4*qc35V}~yp8dyyZ)EM$pCKl+5J6Gav!$)tGGd_`(d^mk+%z1O}h7p|tyTBoP zcEdf$j+;hATzpSra^kecn2{j$#k-%1dasnzO6X~Mamsy-3FWQpWLc9m9~*^`eJMuS zKOiS5qS&nLd$k^}$0KN+XcOmABiZ*R;gT?Y^T*F5-&>;top1WN%--O9U^fSk0~vZa z%Dp`FZWP+3Z0>C#Mu^@Bvy6vUV;x4b^y_TtbvVd0u15YI$LH5&!3zm~(~@;W?#QnY zx5%Q+#5u;;rTcYrGu^%|L=gBRb5GW5gFAF>zYPtp{habjSQaDAW}-TBCvaX4C7_{T z{z5x#gyiM1R#|^BJ+N0&d-=M!H>sWM!{B75JjauZ8JsS|uB|vp`P$In0^UB8T$lYP za$A;yLrdRc9KWmKxFD?JR18y8)%0`TVdXLe99kO|NoC_pT?!-C6?Sk~zg3oLs%)zE zN@~^Ks*j>-P?R8Q5pXP)ZyAoUr=E6_8fnWa?mV7uC;KA6uiJVzMA$n~bpX|*z`;f@ z`@}JvJ7)Z@1G5?JcX|?4T9|kkwzFrf4EQcy_S_b!s}a2Y&?K?9X6CrFZU@iL$vzMW zNxnaO|I%4=(yHub?(Cort>f!+gJz29Ei-$A6rvg~@%Pb`h--n&T{1%7p)RrB;QrU3 z%{2;^s!7^2X0-Z7S__=Cf2F+l7-ATF^rIPVvlUGg81X8+P=u&XriK$5@iNS zIO&H=zsx;WhB)?Iq_sdYJnwOgDv^noPqqKBPJPjqERSguy~Stub3Mdnyyg^V4gse& z>G*FBf8&TZgtXX?kn7J8xB`kD`tT%3PILKVYhnU4zDK%Z^8&2Wo6BtSZWHN{ix}b7 z?OjRTB&ThI`58Psly>ZRD^%K_&%_G3x|T^R^*AZMCBXaEz5K$W((HV>R_UHS$#T!S z!9kN)+#Aevi1}EL$TvRyZ$_y@Cb%e}#=QJBRCjACFna9V3cvk|j8_;jc?($7;%; z6}GNm65ORozGg|%S`fhNmC!{;s~syz#+d7*7c-=~y52EY9Btrqm4h-O7aot-Ki6TZr^?f(W}+hQ6g##o(Cv^#+)d^{C>>0#8fmGnzIdm{rYNOo z>@O5S%9om%gxrTP&bP9ryKADUo)AvYh0z!3W+dVP996f=)FX$ z649t%*(P*_7jN{loUly@oHyy$Bg&aN96FX{?+YoMc=s}0xBh0Jl&U~Qdh5Ox|662@ zQ5;YF$8nfKg`k+l9_dJ{+^mUx4XyLy@{XTKLxKC!bz15YrrJ_OCP~t{TuM;Ae=}A) z;=!8C5fzONnd(~O8$6c^##hQo=O0w>Hq54X{yy~J;)|Z(ph3|J9x3cJ%g_D;*q2c(p-nPPZ+!P$pZOo7-~=F@ileKNI@5q zLm%?ei8OxN7*hxv?7uTiv2amUF<^;}YOB@O8t0`7p z)w_i4*s7Uhdefi3Q(4%H&8aDo!b!>=LfH$=+S^;=$eZ}Fg%YMKLkqSeD8sf$5I-H; zTu8D`>st!{;u^;v$;fqjf_`5CzMI(Zg-iD7jn{a`n7I1-=ei?$}Jb<`)5hxKFs}x*cBMen_-xr5#97F8+5^JvdQpf1RF7dP1 z@g9Z0D)GxTCz12tnceTaSmmFJJ^uxZe>+kKC3+~+6)Ult9rT`iq`zI&p`s08pD^`$ zPMr?6x!0OOBG2c~=6(SOk63)k$wtuOhC55x)2cv7Bd_Le9i(+HDSDJKAm74yeLH{7 zR8o4pa?8vrcwneYc;Iej5C>PK7nZwstVgw!-kGWyOKT5J_qOw4U~A#g_)HwQ=YYrk z&B$)jzGuR_m+#OzNU?alSQJ=s*=)2F-mcSY#LN<1^te{vb)sPb?{#PQ_lDPRqry4l z_I;paZ-tXuEo~phjji%MJULokxXC1Ip~s^!r7QW);Gw7>`@;6AFd0dRIb)V`TlaQ! zT2ZS!J6VPL+Oc)qxKO_(10Tis)|Shd%jCpH^H$6#BawRFKnKA=-Zuy6`OVg&><7UK zK_E;uUZ~V?_3`J^%=a7iojBZew&+)iKmX9-5GU$kR_46n-^)>_39eSaWm1x|=1@PB!I{vn$@rk)p-m$TQbl08T z6u)>VGfxsuuiv=c3d~Z=4BB3Y*O97i{!<(dd2cu0;E`&lzA{bMog0yGn_D9%cT?nU z88WyU8^|bkbT|+zR)~=sp$byzXF4eBGCZNH6WSefC!}FwI?T^^3G?W4$*oO&TrK=c z^;_ESW&uaj=O)7h+0vad*-8%wl}5T#(t8hGoDUD%4bcn*1w*NEw7DbL60_0NW_hFF z218Yf{>eMBog}orY<1%mHWfqPzw0F>DVv+`&av|oIU7<6j*4*<8?W!zb~Y%jtTS8t<`W5M&h66dBTlwIq716YY=Y7-BQq zEzN3#R<}rDORS7#@2orCX9r!1QlVqUZQrDnXdV8*f(`o7mZ*9LjnZ_V*DNsRs)_d# zRHZna&g0<@4ciNx1QC$@SNi`9s>;kHUFTT+Ia6rs=<&K+&yL*S8<4=VUF$S66>j zRpoZ^7DvavzpjF&HmaF)aDU??YdT@D_D$L=G)mL|{&-wYk4W`&0WxlN^}&rBna5Z? zx!=@6CTXAFq229z7bslWrvI-iO}V3w zT!-5)Lo=MtZoM%p%S5q1AlD%)_L{?eYOT%iEK(u)+i(Q!(~hSU(nwl{q+jLiQFrPU zGBV0#%7>A2fr=Ie&lzrp*H*;R5e@gDf@rO8$9z;(6Rz6Yb+crfV2OJnp_gQ{rX)5- zTftj49@MJBkgEpgGM2yeT-bqnDi+_5ap0q~&HWayFlmLK5s#NgCIfBh?k}UrOKf-* zgJVYL>Lw5z!xM6lRQfOZ}iN}U%&f5{lS=X&&(hEwqM z-k?3KXEogB;t4PgJ`zEgUdu^V$rJbV0O~Fv29X(|FqA5qD<*`WJx^}Vg&W)yg?~TH zb%v@YaR>6YX)RC})KYsM@3@Gu?JuU9ssG{CM{Jh7Onk;2bbncbZX*CAc*^Y;6m zlP0%;tjMb{GyXBDdY3Oh*WJjSOqMJ@t&8a?Lp0X|n)MZg7i~jrp|Z`WlDiO`KGFSl zx17H}+_)RuEQa9rX9LN)(<}WXncb*~viybfXF{@dxDG!F4Ru%XeAov4Cm6p63yCY< zi|m3&U95pCgAog#$?pr3VKwD0In@lK$N5dsZYzca$<(77GEXg@{MN;afaftRJmGsP zK^Ak#wBXeDer)_*kR)3R?`jEgi=!t1GJY!IhiE(Jtdh#~=2(iVd-w^2dRMA?ElqF2 zC%P8%&bl>+dviabMWYv+HgMQW8s7N#NUreO8KIK1Q0$Lk$xRLiE~TCA%I5IvT1(xe z_zymsXew)te+X=YF;)u4@Seqr&ra|&eEX(qs}Wb*ScKQmX%WYe-p~?mG!;15+&7tmmz#1Tq4M%-YUu0_3Qye-cBUg3({qJDtQ1?JHyrd$eC{ z#jS30wm}72dY_eR7n?0Htr=!==k^>Tzuj3`Zj=tR7@j;9;MR?jW#kc5Q-?|6OW|5a zYbajPWxou$l&a}cn#~`c5j|y+_DO3^QHu$U zo@#E@j&ro5spoW_#;ypX%vUu`o!@y%Msqigcy(-gzK0DwDGu1zjR3)i0rxr*W+}2AkRxFcDN6*2%tLefh8qP)|h5yQ=2eMr>|gDTYrPdL5_DvT%yzqM1Ex-y=S$tX_Qe z+>B}ZalMXSJ>f6-agOI=$(+`>hA{aUk6x#oh6;TmdB?WR(y7P>3i3@)`|3BiyIB!~ z6{Stck><#=bps1g6sf&H2?1aDSzSmkO?1=2p7ZXqIgtioyKtl5$Sd=p>IpRM5`Xp* zr*Gk2!24Az@-GFClMrZ?agE^hN6&A9hjth#Usl#);<+^i3>daf;<}XIzN6YT$zo}?S(K|*}zy_J;IfyU!{v9 zPZoCERn=pJOx}u#&9f4SZ*ssQ_cQ#%f~cWiW>R-vDX^(Nl5t-@hW-L|r@?D2>Y*TM#A+*^5j`$U>AEA%lD^mDA(c7VdEZI@;RhH>2Hl>eAI7N9u@#zX_{~9Bcy0v44 zfR?t}E3Me!ZegW&xSD21uvl?!=Q;r|w!x42g*|T|O%88+DOP^GYmcA!mJw7&0XOrw zk#yp3LqsTT)+>`qH<-(5h<8xlMamS#*>cLJ!GzKE5|J~XDSYEvIrUCkRIxv_lpU;Y z`M_{4#2IX@I$Z|REE4jDcRLMX$|l~6DZECZGqy@9mig6LPglkQ(}KTOD`@0ZWeX6vYHrtC2&c(I-==&W|Q4+Sh}5j z6z+}76xFD`O1 z4%|=|Q`^g+s`1eC536;NoA%SsB;sWAK2UWG%xt2An@D*OvOn{=TE8lImVb*;)!|(t zi!4%PCy_})QuO2wr*=5Ymj_vjrN2w4L0RQ-;tRh(%)^mKse!V3?z%fspC0sW3w_T@+-5HLHS++_^cyTQZL^0s=BC2zEYV3>iyrE?`(fui75PCH|q;K>2;aIYH8 zcg6B9X2A$8`l^jmu3t}M9hakV3;cxMjUNau-n-7d)zcN-rTICCP<1Hh>tpgpKl{SQ zL+OAe82rUMDt-!NC&{zdh$+NIeefRFOpSK2m1Zhprd1&bg`8-8079x`7O9@9n~K`N zVPHyVoa#eSP*7Y5K4hUJ{lxL>DctcXhhzRsj1Y}2zi^(;cU2C56@jw>eVxLpr2|Vl zyI1c9yBh-Q-@kv$Fn}dkFZy_xAI!bYK5>1KNAboT~3U593bmuYZAshDFABN+wT?dV~%{lp^kQ24Wat#gTV(kDalaeaQCA{%&3OErMNJWLN=GJ(rfU)M0Cp)Ow3UV6yfJJSuUoTUpuaM zspJfP{AOUVac#6_a4zc=29|@IoD82V4ObOQGG)Ca+115kr<_b}g=|&5FK5m*S?}A( zPhJcqi7HARuE(J^)FqndQO63gsFjE=x=_8m?v?0oOj-5$@!I*#p7ZToJ z;Hu+X^aXA9Rg@O>cJ&^fA=r&OjPbh@-#aav2}IO4-dC-z^D*m-u`S8s7^hWEBBwd6 ziz$@d?d-I@4S+t}>ao4ZZY zc8@}-GdmZn1Y9zdTQc=}zSE5>bBdd4nFP=V@cuYtcwM3j@3<@z#3>GDjZbJxR%qmR zxJLHaas|e~oPxNx!6~tjc$-qtPEId8^LIkzJLk*;X94^EN3_|CTp4`JZHAVEoKaCl z6uYrQ@~4Kide3xxlP45h^7Bkx@M7EPW1Ha{E|%YstS60gPYY7g;k$FilL3)*nQU|0 zg^=eDY%1n&a_8mO=!-PS&M`M?3{U2y6|1|yf27&FUVi9hBEK8ZC*8ZEX`Az{DR*$U zm~`rqwqS$84ObBN*%n`vr3IJvyR?0Lpd@oL7tWMZ18T3M^=`z`-&aPZS>Wv-p@*U7V)!l;OwQ>O2-czx*f^DhT(VuSr@fo;I*6WX~tk z@Wpck7CcXje8m~fs#Dl1p`L;2)fM7o)C^Qg+5Qbi>VEa~$aqwtFR6u9KQKq)^NxKVGe=ax#W8e=CKfo#>ADQxMh?LvBa6#z!x0%dGoIm4H?o-0o50Y8(c%0%&j~9?=UkSAh7;=;= zzJN0_WkzT*(k4G}XUu+d^ERZ}3P~}0g3=9s#+kl3N>Tn^B7k9SrGJN);~F8MY1N%3 zhxS=&1Tz9=z_z7{-@$vqsSs8CP1<8YDN!$lEUsG#@ap$;ClxHrgBrX@9cCm^`+|LL zaRQ4{Hzby>YVJl1AB8giFqx@eTAqV*xy0a9wb!Is(&uAtIVD%uip!0eOL@`Dop9%0 zKd^|1i7k5)xKEpi>%S(D?NZ__3B*~hdl1JXipyEh*SrEG5^9({*F=e;M|0<1O1I9c zIO2(Rt#mKGe0X9>EB2y`zu6-W3sR!>RV*{J4_u)LiNwx`lZ>spBQwupu>F^^iVhRe z!_#?bR!%0!d;#G9G8UFjfw+X?sg_m_*iSWpCtLaBLq3Ssq= zql-OX%B;~m!aF|+EUdaW-?TvdNSQvH6R)i@pWqmtTqGo|Mf{qaZEtX<#)7DH&%$3U zb=Ypi7(?&TZbIac(Iq_)2o$z}7sFtileqcDAUIJx|Cd4%tIy77kKaM`P91b{Z;v*o zPCFER(&$sXlwD`ME5fKn3-9?JH=}f4LVf`GBtt|*rg?=2f&9|>eOG#xPT9?Teg*~wyI~n3H7AQ-%A8h|t!L_W`-JwoorwOzeB2MwRouRG z1N_#_cI{@A+xbr0@W4Q7R@M_;x7h&uv(W-=de}s@vt{9DT+HYHyfQ)|h7bCLj}oCP zX`fkSDt+C~_WDFqN;t6o{Rs%-@$>kf?{IEFVP$1yw~p@F+M1hpXd?c3-%pM|y9UBJ zHC|#y2lHtVoxOZcQ`~-v;51YAsFnEqkZ!WxbxW@Y!6`+rkl6k;`XoC&-KhFA1QPW^ z7}tO&9Uq4ctrU(G-;>h}>2?`1$|a|uh_*(n?2dKQT>Gb&-J6^@qi*8*!!rsC3qPC* z3JTJ}u3x|YOnCY;Ykj=k*>-o+;NYOs=lD^#xJ~{_Oj|&bq~UQ{~rz$H8)X_-Z-YCTeDD6hz^jveMGG%^&e5 z$}Ho^$jP&7KK(O`=Q3a+9;ftW3vSh_xK6t-IT2s~=goE!a4XNhkNkvQ_}%~a0i>Pw z+W$3g{Z8D(ruqHj9nR|3)@Yi<=d7BiB~8D6eT8(f*61F0D&V(iEB`vPKCqv zMJW&~ej}$%*%ef^;n7jW4kJ!~yN@@oVLMdmpmaktjjk>)CK`ONoQQEC9FqV0=)Q9e z+rq*k0p3;~!fq-I2nUfjBK+AvV&=lA7KK4I5;Qj&f=%0fKDF8VC9Q~)aesZ+N zudw{DjPjc_!2LgI=Rfyu8nd}+U z$n>74r5<>H{m%S>2hq)IrB>q{Q|?Pi0-lHKl}jSWzx=i8y{g_fehvQo*?Cw^!+fbX zi9lp#E?l%Qzpz@SrKROI?w^m;AJuJ{oE1ani=yZi^J;5#Voh?Zs@$85a@!!?(6(^u zQRa%twCi^swlD4FR#v_y)^&=!JW09YpWV@}uo=!%%O|`Y0Dek64lIPvw*hNx0|g~# zt*}PH#KdIx#7nE#IB2u(NDmJWPrii?tDRO+_#k=&sgGP;wW!iz31~Jf(5X`X!fPAL z{$%sGu~BTwu}>Itl9xzIXJ==r&xJ?b`ED11h~g7xYKZh8XP;0ytM5sdN1x|(`ox5m zW-MJ-F+4dg-(@aUJh0S$PEt)nqigNJ;`02Y`r;5hj6y|?I4^SJbU&Dynp(8IT(k{= zs;H@%fHBy{vg+GF5yizEC9Yf3D=d%wgtmX++#`WW!Zj=HQmZzr*Sd)`bPR9bz8wN} z1UZf{Y71*gaz)AWlEPY-_7G*3qX~?Kl}&s}3P$MilaoH*AbZD~tES5M$zrIUp{<;=}< zBaI)j>RDlW1;sHGTJomlHVTE}9PjG8bL}R(rPpCe3OBr)E>7o7Zz4azl+V$WD>)3) zK1zcR;y%}SgPMzX($3#~B{QzoO4MQgQ`L6UolxlB zHZxU@+4eP4Pbl!el$V#YFC_}O4}N&@h_>T=KU*}^M$}_J$K1w-2i~0?E|PGth7+Mr zsY;<_G*Uv2-)jBRIE z7Z|ud^|?WJJy_xnLSkZa*doYWai_EG=fmKmh;*rN$qL)?;udL+o}Qkp1~5k>6XheU zA8+*m1ol{0Oqo)7EI-;iT|<(Gw{aT!vBhn)5;3Ziz!sxY&W)Y4eSLkM*jhF!W|DWD zm%sz>xJ>?AzQzMOCHYSKFd;!o(!rf0-va?g2Aoq&^^wD;>2~nZYX!0qag%}Mo z*kZEhW{}_#XfGquxc+n)88&_ygY*JJvK%c86{B zjC)0|yh=QeHgDm47&3yG!o$7SjX(q~@-`qwc6g_pR3_`?zwnU476|}Q$8U(9bSQ%L zu^KO~bWC<%L@d8`hDeXVTeL|K0b|6%GLSO!GtApMSJEnR)=M@vJJzmVAVA>2Nzqu`tP1UC+D@1 zTS*LBO_y8!{+Xp(Iy5xYV>W=e)c5MrDfk{_*G=ddrG~QdW7tH6ZHk`HiOGqVh%V^G zzRg5Ad>|Z`2JJ9zLm{$+Av!uqT1rZ(F^n2{1qJfq!iQtf!l$D5X?}j&o^LNI<92znFLz^;YM#(;>=1KgHlEiv$si4sMT&(G@ zEq<{bX%oG)pYq(Q6L6SskDo)M(HH>&li?dHEi6>X%*?bfHy>C*Wv8Rf{juX2z~liI z59TB5^QbSk1C0(bmHHH9FETCdv94neH(YWxz2!48J;hK=j!0+eCXS!nLh)3)x7HzM znrvPwQ8^Bzb#>ley{b2Rl7eZ3(wfyb*MX=+8R;Q(9QJRtdcsK z@of+}$W;W$rYU+l5T4d{Z`4piLgMx{>|4~FO5QKY34f>-S!-h|4yM6(_q?71URgAK zRDZOBkv)T%vTEAe`7>9SM+0Knh#@J&h&Lo6SEro3O0kOylZm~_!i=!pvkSn8XV&j( zus|X#8$f#PZ`k^3v&6~jA2$)Fp`(WNmAxo=yn0_QDQ#i35AoLT;0YfLsZ;~FOv}xG zf+6%z%q=aWpu+q8;-@{fKIyhTdtWT9t%ol!&YZiLD{^XTv`6)Q>oA;Iy7L>p?1B(r z=}|?X@?zu+*bod=_1TR|Sssw0j%l9lN6%ckaP#m;hWBLr_Z_$fF@Wp2FA#Ze`%_1= zOS(=uYdA5#fW({QQ}BJ&wygv3RWSGwk36e}4-!$Hn)(Pf_~VD)X6>GW_xT*PhNdQh zjEu}_RM&L{OhCmK^N54BLUWd9kdQ4Y(bGtPNQKqa0>UTrQGk|wbJ?0Mz$52b-mKXg zF^SX7nw|!T`vmffLdx%c$qP1;XsENZGu`CO`1msz!_%i2P4OI#Pc8BdKwb_=5Z)QD z5}WFkEcUbiwv0bRd8c#g>qWfRv}a^aC|21aoE^)jCLrML&jKsJvUqpztPC0ECQ4D_ z(x?D}G;lDgj~pCSwnmrBcg3-4(A@8@b>G`hxjISlSQ4Dc1-bj{|jMn+JwzhU@$y4>T(@#84{c?BbsIOjMTwG|C znD)f??id1$kAq(9g!@KEYkEhUPu+kZI!$UK&q~3yo==*Fe7ZA*x-}0rsi*6mRt6Je zZaw!A1e;S>yDL)%a_0}Lx1(C*-Toz(#Ap8k@Wa~yF9tSB#jxaJX2aHI6?R@XqXS2EC=0{& z+-9#I938b$ru=WJ$afI?L=PWYY)#hE0bT8%V8(?ZLFzIiM zc4;#57D-{LS?7QaUR}&wwQNu$FRM>i^V`;`K~z|;uB;3!#ObpjM+I{&u_E+C*l%4O zZlLL5pxg!9kS7O-V=^@ti5x99=`s!p%KD4a+6`?-qm)o7(wTwpc{px{ff4ha-z@{x z4xR~UJdC*FN@8yH(8(pbGfi`zQ@CxRCj7q^>x zRL2O40N9ey+uNJ5biV1@zzH6C2MZbp$l!F=nq&WeoMavp&NrEoNbm9w$ZyFj9Y-<@AnGndQiv)X@iZg4%VjQ@h!Jpjz}zJ&Kq!4Vs$CM6gjWH9 zz**N_3j}(JB=FtbyR@gl2!O*?;&XAjeAU(2**#t7iP8PLKkAOA$*sFFu(k*&n&PqR z`SA|>(v%ZQUJ|}@c3}#_Y!S?wOCcLIR&MQcxtBukRmiMeP7>H`$kfGwA>Nu5Hc7Vv zn?c9GgL1ShgX~HaLsHo>oGNw=*)aVt5)K60;-D;RsEgLcoam*|r=tzY+x%3%=R3>nwqWSprw42DYHBf&#v=vVW*^u4nVz|6#zSH_NV-KBeV{qR{-`5+>!*`W88v0&P>%1EPks0&NYHNe@7i?neI`*)?w=yFUSwrLu09-v_Z(U+KteJ8i{MfBXgr z4|zTc9>4fEQn9rb0}n96Ir@maM1Og-$;Mx6~@v#Uj@ zPmZ`G4JM-ygld@bKMn+>%;-0^ld%SXy6y5En#S;g{|#q&CG?3u3h7G zlw5t?OvoL_V zzsqqjCpd%?repr~i}qR;ryc|q+XP)Ma5T)~$KGj%Bg*AZ$C$Vf4hWw-w0d!pyx3+oQGhKGjQ z=XNlK-4^f`=c|jOeDGY`S^j(tWUPs1^<0?X(CBD3#^}(hcFM(m=NB9tynhHDK+_$@ff?9fwdtSUS6f(Ez<_kR zSjzxDr9j4Bub^U5Qj!NibLYnE#4nEu^y;+G6}E9u?=$pF*qz7tcz7$IP!CgI9ZNB5 zl}fYLZoj}q{;?Cgeg1aeGv@2hHc>#_XFaE(F$zicKGdATOs`!^Bn_ZvPI0j^hS`E4 z#172)f_7t}-oJ3DrAa8FGl0FN>#>pv&zQM9thg1Po1c$r4JDr|7o@rG&l&~wpT~$32(}^0vge|;v@qHs&S>X>n5hFC2U}L zV}60K)et^>xC)$)uOi2d_hjH69v=IkVy~>}UsadeOyv>KD=Y&pGz>f=;8m*lOai_j z4Od7JDFA|wUbsF#pKD_nY5kW@)7?K?`y&GsTV0=% z1&`&_K&Oj?QQ6>cM8_cVo)ol^f-4@-Dc`FLj#lV}=XT?b67Of2lOu@#E}q4UAIWV# zU7i7NM}BRxvr`5(myIt!Jlh=DHZ_Vy=L7Fk1`a$1`U4@%d2HH${p-?+u{;||$R&Di z=jGLg-OGfO)_D6q-J&-EP;-EQQ4GzX!nAoOgKPo!EkB3vXYtmVEBF10$`T@Qn+te7 zS3W!3Ky6OdL{S}0ZTaw?ZUNWe>FL(1r%)KEe$+>kv*UpvZr*P;TwUHu$u22T0mSIm z5hb6UfgJxwz=~I7^-tl@72pGm!s%q6c;RE3<SNgJLEV9r@ZEKZM zPKUDh!DdMu9UUI72GJ+?n~Cg^a&vLzl$CMM&S5s%i25v^%Iz1{!T!E99G^-k=cwW8 z!ef6xiVmiK{+pQ73p`*Dl+(kK?w2S3f`y2Eph49Eo|BgJY`8FjqEb*`Ehp25u}^@9 ztOjJC6B?63nDlBT880sY`W;vqFfszGbx@En@(w8ct1z28xc=Zu!RqpIe}Dfgr!(O9 zs;a0ETwDRxQs#Z;a