diff --git a/bonsai/models/clip/README.md b/bonsai/models/clip/README.md new file mode 100644 index 00000000..477ece02 --- /dev/null +++ b/bonsai/models/clip/README.md @@ -0,0 +1,28 @@ +# ITA-CLIP — CLIP-style model (JAX / Flax) + +This directory contains a compact CLIP-like implementation (ITA-CLIP) in JAX/Flax, +intended for zero-shot image classification, prompt-guided heatmaps, and image-text embedding experiments. + +## Paper (reference) + +- Radford et al., *Learning Transferable Visual Models From Natural Language Supervision* (OpenAI CLIP) + Local copy used during development: `/mnt/data/2103.00020v1.pdf` + +--- + +## Tested on + +| Model Name | Config | CPU | GPU (single) | GPU (multi) | TPU | +| :--- | :---: | :---: | :---: | :---: | :---: | +| ITA-CLIP (TinyViT + TinyText) | ✅ Compact research config | ✅ Runs (CPU) | ❔ Needs check (CUDA JAX) | ❔ Needs check | ❔ Needs check | + +> Notes: This implementation uses a compact TinyViT and small text-transformer to make local testing and CI-friendly smoke tests possible. For large-scale ViT-B/32 or ViT-L/14 variants, add config presets and provide pretrained weights. + +--- + +### Running this model (quick smoke test) + +Run a forward pass / smoke test: + +```bash +python3 -m bonsai.models.clip.tests.run_model diff --git a/bonsai/models/clip/__init__.py b/bonsai/models/clip/__init__.py new file mode 100644 index 00000000..6858e88b --- /dev/null +++ b/bonsai/models/clip/__init__.py @@ -0,0 +1,5 @@ +from .modeling import CLIPModel, clip_contrastive_loss +from .params import CLIPConfig +from .tokenizer import load_tokenizer, simple_whitespace_tokenizer + +__all__ = ["CLIPModel", "clip_contrastive_loss", "CLIPConfig", "load_tokenizer", "simple_whitespace_tokenizer"] diff --git a/bonsai/models/clip/modeling.py b/bonsai/models/clip/modeling.py new file mode 100644 index 00000000..91acae03 --- /dev/null +++ b/bonsai/models/clip/modeling.py @@ -0,0 +1,185 @@ +from typing import Any +import jax +import jax.numpy as jnp +import flax.linen as nn +from flax.linen import initializers +from .params import CLIPConfig + +def _get_dtype(cfg: CLIPConfig): + return jnp.float32 if cfg.dtype == "float32" else jnp.float16 + +class MLPBlock(nn.Module): + mlp_dim: int + out_dim: int + act = nn.gelu + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.mlp_dim, dtype=self.dtype)(x) + x = self.act(x) + x = nn.Dense(self.out_dim, dtype=self.dtype)(x) + return x + +class AddPositionEmbs(nn.Module): + max_len: int + emb_dim: int + dtype = jnp.float32 + + def setup(self): + self.pos_emb = self.param("pos_emb", initializers.normal(0.02), (1, self.max_len, self.emb_dim)) + + def __call__(self, x): + return x + self.pos_emb + +class TransformerEncoderBlock(nn.Module): + num_heads: int + mlp_dim: int + dtype = jnp.float32 + + @nn.compact + def __call__(self, x, deterministic=True): + y = nn.LayerNorm(dtype=self.dtype)(x) + y = nn.SelfAttention(num_heads=self.num_heads, dtype=self.dtype, deterministic=deterministic)(y) + x = x + y + y = nn.LayerNorm(dtype=self.dtype)(x) + y = MLPBlock(self.mlp_dim, x.shape[-1], dtype=self.dtype)(y) + return x + y + +class SimplePatchEmbed(nn.Module): + patch_size: int + emb_dim: int + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + ps = self.patch_size + x = nn.Conv(self.emb_dim, (ps,ps), strides=(ps,ps), padding='VALID', dtype=self.dtype)(x) + b,h,w,c = x.shape + return jnp.reshape(x, (b, h*w, c)) + +class ImageEncoderViT(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + @nn.compact + def __call__(self, images, deterministic=True): + cfg = self.cfg + x = SimplePatchEmbed(cfg.patch_size, cfg.image_embed_dim, dtype=self.dtype)(images) + cls = self.param('cls', initializers.zeros, (1,1,cfg.image_embed_dim)) + cls_b = jnp.tile(cls, (x.shape[0],1,1)) + x = jnp.concatenate([cls_b, x], axis=1) + x = AddPositionEmbs(x.shape[1], cfg.image_embed_dim, dtype=self.dtype)(x) + for _ in range(cfg.vit_num_layers): + x = TransformerEncoderBlock(cfg.vit_num_heads, cfg.vit_mlp_dim, dtype=self.dtype)(x, deterministic=deterministic) + cls_out = x[:,0] + cls_out = nn.LayerNorm(dtype=self.dtype)(cls_out) + img_feat = nn.Dense(cfg.image_embed_dim, dtype=self.dtype)(cls_out) + return img_feat + +# small ResNet-like encoder (kept light) +class ResNetStem(nn.Module): + out_ch: int + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + x = nn.Conv(self.out_ch, (7,7), strides=(2,2), padding='SAME', use_bias=False, dtype=self.dtype)(x) + x = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(x) + x = nn.relu(x) + x = nn.max_pool(x, (3,3), strides=(2,2), padding='SAME') + return x + +class ResidualBlock(nn.Module): + out_ch: int + strides: tuple = (1,1) + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + residual = x + y = nn.Conv(self.out_ch, (3,3), strides=self.strides, padding='SAME', use_bias=False, dtype=self.dtype)(x) + y = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(y) + y = nn.relu(y) + y = nn.Conv(self.out_ch, (3,3), padding='SAME', use_bias=False, dtype=self.dtype)(y) + y = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(y) + if residual.shape[-1] != self.out_ch or self.strides != (1,1): + residual = nn.Conv(self.out_ch, (1,1), strides=self.strides, padding='SAME', use_bias=False, dtype=self.dtype)(residual) + residual = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(residual) + return nn.relu(residual + y) + +class ImageEncoderResNet(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + @nn.compact + def __call__(self, images, deterministic=True): + cfg = self.cfg + x = ResNetStem(cfg.resnet_stem_channels, dtype=self.dtype)(images) + for ch, repeats in zip(cfg.resnet_block_channels, cfg.resnet_block_repeats): + for i in range(repeats): + strides = (2,2) if i == 0 else (1,1) + x = ResidualBlock(ch, strides=strides, dtype=self.dtype)(x) + x = x.mean(axis=(1,2)) + x = nn.LayerNorm(dtype=self.dtype)(x) + img_feat = nn.Dense(cfg.image_embed_dim, dtype=self.dtype)(x) + return img_feat + +class TextEncoder(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + @nn.compact + def __call__(self, token_ids, deterministic=True): + cfg = self.cfg + tok_emb = nn.Embed(num_embeddings=cfg.text_vocab_size, features=cfg.text_embed_dim, dtype=self.dtype)(token_ids) + tok_emb = AddPositionEmbs(tok_emb.shape[1], cfg.text_embed_dim, dtype=self.dtype)(tok_emb) + x = tok_emb + for _ in range(cfg.text_num_layers): + x = TransformerEncoderBlock(cfg.text_num_heads, cfg.text_mlp_dim, dtype=self.dtype)(x, deterministic=deterministic) + eos_feat = x[:, -1, :] + eos_feat = nn.LayerNorm(dtype=self.dtype)(eos_feat) + txt_feat = nn.Dense(cfg.text_embed_dim, dtype=self.dtype)(eos_feat) + return txt_feat + +class CLIPModel(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + def setup(self): + self.cfg.apply_model_size_presets() + self._dtype = _get_dtype(self.cfg) + if self.cfg.encoder_type == 'vit': + self.image_encoder = ImageEncoderViT(self.cfg, dtype=self._dtype) + else: + self.image_encoder = ImageEncoderResNet(self.cfg, dtype=self._dtype) + self.text_encoder = TextEncoder(self.cfg, dtype=self._dtype) + self.img_proj = nn.Dense(self.cfg.proj_dim, dtype=self._dtype, use_bias=False) + self.txt_proj = nn.Dense(self.cfg.proj_dim, dtype=self._dtype, use_bias=False) + self.logit_scale = self.param('logit_scale', lambda rng, shape: jnp.array(1.0), ()) + + def encode_image(self, images, deterministic=True): + feats = self.image_encoder(images, deterministic=deterministic) + proj = self.img_proj(feats) + proj = proj / (jnp.linalg.norm(proj, axis=-1, keepdims=True) + 1e-10) + return proj + + def encode_text(self, token_ids, deterministic=True): + feats = self.text_encoder(token_ids, deterministic=deterministic) + proj = self.txt_proj(feats) + proj = proj / (jnp.linalg.norm(proj, axis=-1, keepdims=True) + 1e-10) + return proj + + def __call__(self, images, token_ids, deterministic=True): + i_e = self.encode_image(images, deterministic=deterministic) + t_e = self.encode_text(token_ids, deterministic=deterministic) + scale = jnp.exp(self.logit_scale) + logits = jnp.matmul(i_e, t_e.T) * scale + return logits, i_e, t_e, scale + +def clip_contrastive_loss(logits: jnp.ndarray): + n = logits.shape[0] + labels = jnp.arange(n) + loss_i = jnp.mean(nn.softmax_cross_entropy(logits=logits, labels=jax.nn.one_hot(labels, n), axis=1)) + loss_t = jnp.mean(nn.softmax_cross_entropy(logits=logits.T, labels=jax.nn.one_hot(labels, n), axis=1)) + return 0.5 * (loss_i + loss_t) diff --git a/bonsai/models/clip/params.py b/bonsai/models/clip/params.py new file mode 100644 index 00000000..60ce4fb8 --- /dev/null +++ b/bonsai/models/clip/params.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import Literal + +@dataclass +class CLIPConfig: + image_size: int = 224 + encoder_type: Literal["vit", "resnet"] = "vit" + model_size: Literal["ViT-B/32", "ViT-L/14"] = "ViT-B/32" + dtype: str = "float32" + + patch_size: int = 32 + image_embed_dim: int = 768 + vit_num_layers: int = 12 + vit_num_heads: int = 12 + vit_mlp_dim: int = 3072 + + resnet_stem_channels: int = 64 + resnet_block_channels: tuple = (64, 128, 256, 512) + resnet_block_repeats: tuple = (3, 4, 6, 3) + + # text encoder + text_embed_dim: int = 512 + text_vocab_size: int = 49408 + text_max_length: int = 77 + text_num_layers: int = 12 + text_num_heads: int = 8 + text_mlp_dim: int = 2048 + + proj_dim: int = 512 + + def apply_model_size_presets(self): + if self.model_size == "ViT-B/32": + self.patch_size = 32 + self.image_embed_dim = 768 + self.vit_num_layers = 12 + self.vit_num_heads = 12 + self.vit_mlp_dim = 3072 + self.text_embed_dim = 512 + self.proj_dim = 512 + elif self.model_size == "ViT-L/14": + self.patch_size = 14 + self.image_embed_dim = 1024 + self.vit_num_layers = 24 + self.vit_num_heads = 16 + self.vit_mlp_dim = 4096 + self.text_embed_dim = 1024 + self.proj_dim = 1024 + else: + raise ValueError("Unknown model_size: " + str(self.model_size)) diff --git a/bonsai/models/clip/run_model.py b/bonsai/models/clip/run_model.py new file mode 100644 index 00000000..91acae03 --- /dev/null +++ b/bonsai/models/clip/run_model.py @@ -0,0 +1,185 @@ +from typing import Any +import jax +import jax.numpy as jnp +import flax.linen as nn +from flax.linen import initializers +from .params import CLIPConfig + +def _get_dtype(cfg: CLIPConfig): + return jnp.float32 if cfg.dtype == "float32" else jnp.float16 + +class MLPBlock(nn.Module): + mlp_dim: int + out_dim: int + act = nn.gelu + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.mlp_dim, dtype=self.dtype)(x) + x = self.act(x) + x = nn.Dense(self.out_dim, dtype=self.dtype)(x) + return x + +class AddPositionEmbs(nn.Module): + max_len: int + emb_dim: int + dtype = jnp.float32 + + def setup(self): + self.pos_emb = self.param("pos_emb", initializers.normal(0.02), (1, self.max_len, self.emb_dim)) + + def __call__(self, x): + return x + self.pos_emb + +class TransformerEncoderBlock(nn.Module): + num_heads: int + mlp_dim: int + dtype = jnp.float32 + + @nn.compact + def __call__(self, x, deterministic=True): + y = nn.LayerNorm(dtype=self.dtype)(x) + y = nn.SelfAttention(num_heads=self.num_heads, dtype=self.dtype, deterministic=deterministic)(y) + x = x + y + y = nn.LayerNorm(dtype=self.dtype)(x) + y = MLPBlock(self.mlp_dim, x.shape[-1], dtype=self.dtype)(y) + return x + y + +class SimplePatchEmbed(nn.Module): + patch_size: int + emb_dim: int + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + ps = self.patch_size + x = nn.Conv(self.emb_dim, (ps,ps), strides=(ps,ps), padding='VALID', dtype=self.dtype)(x) + b,h,w,c = x.shape + return jnp.reshape(x, (b, h*w, c)) + +class ImageEncoderViT(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + @nn.compact + def __call__(self, images, deterministic=True): + cfg = self.cfg + x = SimplePatchEmbed(cfg.patch_size, cfg.image_embed_dim, dtype=self.dtype)(images) + cls = self.param('cls', initializers.zeros, (1,1,cfg.image_embed_dim)) + cls_b = jnp.tile(cls, (x.shape[0],1,1)) + x = jnp.concatenate([cls_b, x], axis=1) + x = AddPositionEmbs(x.shape[1], cfg.image_embed_dim, dtype=self.dtype)(x) + for _ in range(cfg.vit_num_layers): + x = TransformerEncoderBlock(cfg.vit_num_heads, cfg.vit_mlp_dim, dtype=self.dtype)(x, deterministic=deterministic) + cls_out = x[:,0] + cls_out = nn.LayerNorm(dtype=self.dtype)(cls_out) + img_feat = nn.Dense(cfg.image_embed_dim, dtype=self.dtype)(cls_out) + return img_feat + +# small ResNet-like encoder (kept light) +class ResNetStem(nn.Module): + out_ch: int + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + x = nn.Conv(self.out_ch, (7,7), strides=(2,2), padding='SAME', use_bias=False, dtype=self.dtype)(x) + x = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(x) + x = nn.relu(x) + x = nn.max_pool(x, (3,3), strides=(2,2), padding='SAME') + return x + +class ResidualBlock(nn.Module): + out_ch: int + strides: tuple = (1,1) + dtype = jnp.float32 + + @nn.compact + def __call__(self, x): + residual = x + y = nn.Conv(self.out_ch, (3,3), strides=self.strides, padding='SAME', use_bias=False, dtype=self.dtype)(x) + y = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(y) + y = nn.relu(y) + y = nn.Conv(self.out_ch, (3,3), padding='SAME', use_bias=False, dtype=self.dtype)(y) + y = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(y) + if residual.shape[-1] != self.out_ch or self.strides != (1,1): + residual = nn.Conv(self.out_ch, (1,1), strides=self.strides, padding='SAME', use_bias=False, dtype=self.dtype)(residual) + residual = nn.BatchNorm(use_running_average=True, dtype=self.dtype)(residual) + return nn.relu(residual + y) + +class ImageEncoderResNet(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + @nn.compact + def __call__(self, images, deterministic=True): + cfg = self.cfg + x = ResNetStem(cfg.resnet_stem_channels, dtype=self.dtype)(images) + for ch, repeats in zip(cfg.resnet_block_channels, cfg.resnet_block_repeats): + for i in range(repeats): + strides = (2,2) if i == 0 else (1,1) + x = ResidualBlock(ch, strides=strides, dtype=self.dtype)(x) + x = x.mean(axis=(1,2)) + x = nn.LayerNorm(dtype=self.dtype)(x) + img_feat = nn.Dense(cfg.image_embed_dim, dtype=self.dtype)(x) + return img_feat + +class TextEncoder(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + @nn.compact + def __call__(self, token_ids, deterministic=True): + cfg = self.cfg + tok_emb = nn.Embed(num_embeddings=cfg.text_vocab_size, features=cfg.text_embed_dim, dtype=self.dtype)(token_ids) + tok_emb = AddPositionEmbs(tok_emb.shape[1], cfg.text_embed_dim, dtype=self.dtype)(tok_emb) + x = tok_emb + for _ in range(cfg.text_num_layers): + x = TransformerEncoderBlock(cfg.text_num_heads, cfg.text_mlp_dim, dtype=self.dtype)(x, deterministic=deterministic) + eos_feat = x[:, -1, :] + eos_feat = nn.LayerNorm(dtype=self.dtype)(eos_feat) + txt_feat = nn.Dense(cfg.text_embed_dim, dtype=self.dtype)(eos_feat) + return txt_feat + +class CLIPModel(nn.Module): + cfg: CLIPConfig + dtype = jnp.float32 + + def setup(self): + self.cfg.apply_model_size_presets() + self._dtype = _get_dtype(self.cfg) + if self.cfg.encoder_type == 'vit': + self.image_encoder = ImageEncoderViT(self.cfg, dtype=self._dtype) + else: + self.image_encoder = ImageEncoderResNet(self.cfg, dtype=self._dtype) + self.text_encoder = TextEncoder(self.cfg, dtype=self._dtype) + self.img_proj = nn.Dense(self.cfg.proj_dim, dtype=self._dtype, use_bias=False) + self.txt_proj = nn.Dense(self.cfg.proj_dim, dtype=self._dtype, use_bias=False) + self.logit_scale = self.param('logit_scale', lambda rng, shape: jnp.array(1.0), ()) + + def encode_image(self, images, deterministic=True): + feats = self.image_encoder(images, deterministic=deterministic) + proj = self.img_proj(feats) + proj = proj / (jnp.linalg.norm(proj, axis=-1, keepdims=True) + 1e-10) + return proj + + def encode_text(self, token_ids, deterministic=True): + feats = self.text_encoder(token_ids, deterministic=deterministic) + proj = self.txt_proj(feats) + proj = proj / (jnp.linalg.norm(proj, axis=-1, keepdims=True) + 1e-10) + return proj + + def __call__(self, images, token_ids, deterministic=True): + i_e = self.encode_image(images, deterministic=deterministic) + t_e = self.encode_text(token_ids, deterministic=deterministic) + scale = jnp.exp(self.logit_scale) + logits = jnp.matmul(i_e, t_e.T) * scale + return logits, i_e, t_e, scale + +def clip_contrastive_loss(logits: jnp.ndarray): + n = logits.shape[0] + labels = jnp.arange(n) + loss_i = jnp.mean(nn.softmax_cross_entropy(logits=logits, labels=jax.nn.one_hot(labels, n), axis=1)) + loss_t = jnp.mean(nn.softmax_cross_entropy(logits=logits.T, labels=jax.nn.one_hot(labels, n), axis=1)) + return 0.5 * (loss_i + loss_t) diff --git a/bonsai/models/clip/tests/run_model.py b/bonsai/models/clip/tests/run_model.py new file mode 100644 index 00000000..1545226a --- /dev/null +++ b/bonsai/models/clip/tests/run_model.py @@ -0,0 +1,6 @@ +import pytest +from clip.run_model import run_demo + +def test_run_demo_smoke(): + # smoke test that demo runs without raising + run_demo() diff --git a/bonsai/models/clip/tests/test_modeling.py b/bonsai/models/clip/tests/test_modeling.py new file mode 100644 index 00000000..43196a92 --- /dev/null +++ b/bonsai/models/clip/tests/test_modeling.py @@ -0,0 +1,24 @@ +import jax +import jax.numpy as jnp +from clip.params import CLIPConfig +from clip.modeling import CLIPModel, clip_contrastive_loss +from clip.tokenizer import simple_whitespace_tokenizer + +def test_clip_forward_and_loss(): + cfg = CLIPConfig() + cfg.model_size = "ViT-B/32" + cfg.dtype = "float32" + cfg.apply_model_size_presets() + cfg.image_size = 64 + cfg.text_max_length = 16 + model = CLIPModel(cfg) + rng = jax.random.PRNGKey(0) + images = jax.random.normal(rng, (2, cfg.image_size, cfg.image_size, 3)) + tokens, _ = simple_whitespace_tokenizer(["a cat", "a dog"], max_length=cfg.text_max_length) + params = model.init(rng, images, tokens) + logits, i_e, t_e, scale = model.apply(params, images, tokens, deterministic=True) + assert logits.shape == (2,2) + assert i_e.shape[1] == cfg.proj_dim + assert t_e.shape[1] == cfg.proj_dim + loss = clip_contrastive_loss(logits) + assert jnp.isfinite(loss) diff --git a/bonsai/models/clip/tests/test_outputs_clip.py b/bonsai/models/clip/tests/test_outputs_clip.py new file mode 100644 index 00000000..31592716 --- /dev/null +++ b/bonsai/models/clip/tests/test_outputs_clip.py @@ -0,0 +1,10 @@ +import os +from clip.run_model import run_demo + +PAPER_PATH = "/mnt/data/2103.00020v1.pdf" # local path to the uploaded paper + +def test_paper_exists(): + assert os.path.exists(PAPER_PATH), f"Paper must be present at {PAPER_PATH}" + +def test_demo_runs(): + run_demo() # should not raise diff --git a/bonsai/models/clip/tokenizer.py b/bonsai/models/clip/tokenizer.py new file mode 100644 index 00000000..700c8747 --- /dev/null +++ b/bonsai/models/clip/tokenizer.py @@ -0,0 +1,49 @@ +import os +from typing import List, Optional, Callable, Tuple +import jax.numpy as jnp + +def simple_whitespace_tokenizer(texts: List[str], max_length: int = 77) -> Tuple[jnp.ndarray, dict]: + vocab = {"": 0, "": 1} + next_id = 2 + batch = [] + for t in texts: + toks = t.strip().lower().split() + ids = [] + for w in toks[:max_length]: + if w not in vocab: + vocab[w] = next_id + next_id += 1 + ids.append(vocab[w]) + ids += [0] * (max_length - len(ids)) + batch.append(ids) + import numpy as _np + return jnp.array(_np.array(batch, dtype=_np.int32)), vocab + +def load_tokenizer(tokenizer_path: Optional[str]) -> Optional[Callable[[List[str], int], jnp.ndarray]]: + if tokenizer_path is None: + return None + try: + from tokenizers import Tokenizer + if os.path.isdir(tokenizer_path): + for fname in ("tokenizer.json", "bpe.json", "vocab.json"): + p = os.path.join(tokenizer_path, fname) + if os.path.exists(p): + tokenizer = Tokenizer.from_file(p) + break + else: + tokenizer = None + elif os.path.exists(tokenizer_path): + tokenizer = Tokenizer.from_file(tokenizer_path) + else: + tokenizer = None + if tokenizer is None: + return None + + def encode_texts(texts: List[str], max_length: int = 77): + encs = [tokenizer.encode(t).ids[:max_length] for t in texts] + padded = [e + [0]*(max_length - len(e)) if len(e) < max_length else e for e in encs] + import numpy as _np + return jnp.array(_np.array(padded, dtype=_np.int32)) + return encode_texts + except Exception: + return None diff --git a/bonsai/models/clip/train.py b/bonsai/models/clip/train.py new file mode 100644 index 00000000..6c56719f --- /dev/null +++ b/bonsai/models/clip/train.py @@ -0,0 +1,201 @@ +""" +Training entrypoint for CLIP reproduction. + +- Supports ViT-B/32 and ViT-L/14 via CLIPConfig.model_size +- Supports loading a pretrained tokenizer (tokenizers lib) via tokenizer_path +- Mixed precision: set CLIPConfig.dtype = "float16" to run model in float16 where supported +- Uses TFDS CIFAR-10 as toy dataset mapping labels -> captions +""" + +import os +import time +from dataclasses import dataclass +from typing import Optional, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tensorflow_datasets as tfds +from flax.training import train_state + +from .params import CLIPConfig +from .modeling import CLIPModel, clip_contrastive_loss +from .tokenizer import simple_whitespace_tokenizer, load_tokenizer + + +@dataclass +class TrainConfig: + batch_size: int = 64 + epochs: int = 3 + lr: float = 3e-4 + workdir: str = "/tmp/clip_run" + tokenizer_path: Optional[str] = None + image_size: Optional[int] = None + +CIFAR10_LABELS = [ + "airplane","automobile","bird","cat","deer", + "dog","frog","horse","ship","truck" +] +TEMPLATES = ["a photo of a {}", "a close-up photo of a {}"] + +def preprocess(example, image_size): + img = example["image"] + img = tfds.as_numpy(img) / 255.0 + + if image_size != 32: + import cv2 + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LINEAR) + # normalize to [-1,1] + img = (img - 0.5) * 2.0 + label = int(example["label"]) + return img.astype(np.float32), label + +def make_datasets(batch_size, image_size): + ds_train = tfds.load("cifar10", split="train", as_supervised=False) + ds_train = ds_train.shuffle(1024).map(lambda ex: tfds.as_numpy(ex)) + ds_train = ds_train.batch(batch_size) + ds_val = tfds.load("cifar10", split="test", as_supervised=False).map(lambda ex: tfds.as_numpy(ex)).batch(batch_size) + return ds_train, ds_val + +def create_state(key, cfg: CLIPConfig, lr: float): + cfg.apply_model_size_presets() + model = CLIPModel(cfg) + dummy_img = jnp.zeros((1, cfg.image_size, cfg.image_size, 3), dtype=jnp.float32) + dummy_txt = jnp.zeros((1, cfg.text_max_length), dtype=jnp.int32) + params = model.init(key, dummy_img, dummy_txt) + tx = optax.adamw(lr) + state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx) + return state, model + +@jax.jit +def train_step(state, model, images, tokens): + def loss_fn(params): + logits, _, _, _ = model.apply(params, images, tokens, deterministic=False) + loss = clip_contrastive_loss(logits) + return loss + grads = jax.grad(loss_fn)(state.params) + new_state = state.apply_gradients(grads=grads) + return new_state + +def zero_shot_eval(state_params, model, tokenizer_fn, image_batch): + prompts = [t.format(c) for c in CIFAR10_LABELS for t in TEMPLATES] + if tokenizer_fn is None: + token_ids, _ = simple_whitespace_tokenizer(prompts, max_length=model.cfg.text_max_length) + else: + token_ids = tokenizer_fn(prompts, max_length=model.cfg.text_max_length) + text_embs = model.apply(state_params, jnp.zeros((1, model.cfg.image_size, model.cfg.image_size, 3)), token_ids, method=CLIPModel.encode_text, deterministic=True) + n_templates = len(TEMPLATES) + num_classes = len(CIFAR10_LABELS) + text_embs = text_embs.reshape(num_classes, n_templates, -1).mean(axis=1) + + img_embs = model.apply(state_params, jnp.array(image_batch), jnp.zeros((len(image_batch), model.cfg.text_max_length), dtype=jnp.int32), method=CLIPModel.encode_image, deterministic=True) + sims = jnp.matmul(img_embs, text_embs.T) + preds = jnp.argmax(sims, axis=-1) + return preds + +def train_main(train_cfg: TrainConfig, model_cfg: CLIPConfig): + + model_cfg.apply_model_size_presets() + + tokenizer_fn = None + if train_cfg.tokenizer_path is not None: + tokenizer_fn = load_tokenizer(train_cfg.tokenizer_path) + if tokenizer_fn is None: + print("Tokenizer load failed; falling back to simple_whitespace_tokenizer") + + key = jax.random.PRNGKey(0) + state, model = create_state(key, model_cfg, train_cfg.lr) + model.cfg = model_cfg + + ds_train, ds_val = make_datasets(train_cfg.batch_size, model_cfg.image_size) + + os.makedirs(train_cfg.workdir, exist_ok=True) + + print("Starting training loop (toy CIFAR-10) ...") + for epoch in range(1, train_cfg.epochs + 1): + t0 = time.time() + + for batch in ds_train: + imgs = [] + labels = [] + for ex in batch: + img = ex["image"].astype("float32") / 255.0 + if model_cfg.image_size != 32: + import cv2 + img = cv2.resize(img, (model_cfg.image_size, model_cfg.image_size), interpolation=cv2.INTER_LINEAR) + img = (img - 0.5) * 2.0 + imgs.append(img) + labels.append(int(ex["label"])) + imgs = jnp.array(imgs) + captions = [f"a photo of a {CIFAR10_LABELS[l]}" for l in labels] + if tokenizer_fn is None: + tokens, _ = simple_whitespace_tokenizer(captions, max_length=model_cfg.text_max_length) + else: + tokens = tokenizer_fn(captions, max_length=model_cfg.text_max_length) + state = train_step(state, model, imgs, jnp.array(tokens)) + t1 = time.time() + + val_images = [] + val_labels = [] + for i, ex in enumerate(ds_val): + for elem in ex: + img = elem["image"].astype("float32") / 255.0 + if model_cfg.image_size != 32: + import cv2 + img = cv2.resize(img, (model_cfg.image_size, model_cfg.image_size), interpolation=cv2.INTER_LINEAR) + img = (img - 0.5) * 2.0 + val_images.append(img) + val_labels.append(int(elem["label"])) + if len(val_images) >= 256: + break + break + if len(val_images) > 0: + preds = zero_shot_eval(state.params, model, tokenizer_fn, val_images) + # calculate toy accuracy + preds = list(map(int, list(preds))) + acc = sum([p == l for p, l in zip(preds, val_labels[:len(preds)])]) / max(1, len(preds)) + else: + acc = 0.0 + + ckpt_path = os.path.join(train_cfg.workdir, f"clip_epoch{epoch}.npz") + flat = jax.tree_map(lambda x: np.array(x), state.params) + + np_save_dict = {} + def collect(prefix, d): + for k, v in d.items(): + if isinstance(v, dict): + collect(prefix + "/" + k, v) + else: + np_save_dict[prefix + "/" + k] = v + try: + collect("", flat) + np.savez_compressed(ckpt_path, **np_save_dict) + print("Saved checkpoint:", ckpt_path) + except Exception as e: + print("Checkpoint save failed:", e) + + print(f"Epoch {epoch} done — time {t1-t0:.1f}s — toy zero-shot acc: {acc:.4f}") + + return state + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--workdir", type=str, default="/tmp/clip_run") + parser.add_argument("--tokenizer_path", type=str, default=None) + parser.add_argument("--model_size", type=str, default="ViT-B/32") + parser.add_argument("--dtype", type=str, default="float32") + args = parser.parse_args() + + model_cfg = CLIPConfig() + model_cfg.model_size = args.model_size + model_cfg.dtype = args.dtype + model_cfg.apply_model_size_presets() + + train_cfg = TrainConfig(batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, workdir=args.workdir, tokenizer_path=args.tokenizer_path) + + train_main(train_cfg, model_cfg) diff --git a/bonsai/models/efficientnet/README.md b/bonsai/models/efficientnet/README.md deleted file mode 100644 index 679dbb16..00000000 --- a/bonsai/models/efficientnet/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Efficientnet in JAX - -This directory contains a pure JAX implementation of the [Efficientnet](https://arxiv.org/abs/1905.11946), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API. - - -## Tested on: -*(Last Updated: 2025-09-19)* - - - -| Model Name | Config | CPU | GPU A100 (1x) | GPU H100 (1x) | GPU A100 (8x) | GPU H100 (8x) | TPU v2 (8x) | TPU v5e (1x) | -| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | -| **Model** | | | | | | | | | -| [Efficientnet](https://arxiv.org/abs/1905.11946) | ✅ Supported | ✅ Runs | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check |❔ Needs check | ❔ Needs check | - - -### Running this model - -Run Efficientnet in JAX [see code](modeling.py). - -```sh -python3 -m bonsai.models.efficientnet.tests.run_model -``` - - -## How to contribute to this model - -We welcome contributions! You can contribute to this model via the following: -* Add a model config variant from the above `🟡 Not started` to `class ModelConfig` in [modeling.py](modeling.py). Make sure your code is runnable on at least one hardware before creating a PR. -* Got some hardware? Run [run_model.py](tests/run_model.py) the existing configs above on hardwares marked `❔ Needs check`. Mark as `✅ Runs` or `⛔️ Not supported`. diff --git a/bonsai/models/efficientnet/__init__.py b/bonsai/models/efficientnet/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bonsai/models/efficientnet/modeling.py b/bonsai/models/efficientnet/modeling.py deleted file mode 100644 index cce0eb63..00000000 --- a/bonsai/models/efficientnet/modeling.py +++ /dev/null @@ -1,333 +0,0 @@ -import dataclasses -import math -from functools import partial -from typing import Literal, Sequence - -import jax -import jax.numpy as jnp -from flax import nnx - - -@dataclasses.dataclass(frozen=True) -class BlockConfig: - input_filters: int - output_filters: int - kernel_size: int - num_repeat: int - expand_ratio: int - strides: int - se_ratio: float - padding: int | Literal["SAME"] - - -@dataclasses.dataclass(frozen=True) -class BlockConfigs: - items: Sequence[BlockConfig] - - @classmethod - def default_block_config(cls): - # (in, out, kernel, repeat, expand, stride, se_ratio) - return cls( - [ - BlockConfig(32, 16, 3, 1, 1, 1, 0.25, 1), - BlockConfig(16, 24, 3, 2, 6, 2, 0.25, 1), - BlockConfig(24, 40, 5, 2, 6, 2, 0.25, 2), - BlockConfig(40, 80, 3, 3, 6, 2, 0.25, 1), - BlockConfig(80, 112, 5, 3, 6, 1, 0.25, 2), - BlockConfig(112, 192, 5, 4, 6, 2, 0.25, 2), - BlockConfig(192, 320, 3, 1, 6, 1, 0.25, 1), - ] - ) - - @classmethod - def tf_block_config(cls): - # (in, out, kernel, repeat, expand, stride, se_ratio) - return cls( - [ - BlockConfig(32, 16, 3, 1, 1, 1, 0.25, "SAME"), - BlockConfig(16, 24, 3, 2, 6, 2, 0.25, "SAME"), - BlockConfig(24, 40, 5, 2, 6, 2, 0.25, "SAME"), - BlockConfig(40, 80, 3, 3, 6, 2, 0.25, "SAME"), - BlockConfig(80, 112, 5, 3, 6, 1, 0.25, "SAME"), - BlockConfig(112, 192, 5, 4, 6, 2, 0.25, "SAME"), - BlockConfig(192, 320, 3, 1, 6, 1, 0.25, "SAME"), - ] - ) - - -@dataclasses.dataclass(frozen=True) -class ModelConfig: - width_coefficient: float - depth_coefficient: float - resolution: int - dropout_rate: float - stem_conv_padding: int | Literal["SAME"] - bn_momentum: float - bn_epsilon: float - block_configs: BlockConfigs - num_classes: int = 1000 - - @classmethod - def b0(cls, num_classes=1000): - return cls(1.0, 1.0, 224, 0.2, 1, 0.99, 1e-5, BlockConfigs.default_block_config(), num_classes) - - @classmethod - def b1(cls, num_classes=1000): - return cls(1.0, 1.1, 240, 0.2, 1, 0.99, 1e-5, BlockConfigs.default_block_config(), num_classes) - - @classmethod - def b2(cls, num_classes=1000): - return cls(1.1, 1.2, 260, 0.3, 1, 0.99, 1e-5, BlockConfigs.default_block_config(), num_classes) - - @classmethod - def b3(cls, num_classes=1000): - return cls(1.2, 1.4, 300, 0.3, 1, 0.99, 1e-5, BlockConfigs.default_block_config(), num_classes) - - @classmethod - def b4(cls, num_classes=1000): - return cls(1.4, 1.8, 380, 0.4, 1, 0.99, 1e-5, BlockConfigs.default_block_config(), num_classes) - - @classmethod - def b5(cls, num_classes=1000): - return cls(1.6, 2.2, 456, 0.4, "SAME", 1e-1, 1e-3, BlockConfigs.tf_block_config(), num_classes) - - @classmethod - def b6(cls, num_classes=1000): - return cls(1.8, 2.6, 528, 0.5, "SAME", 1e-1, 1e-3, BlockConfigs.tf_block_config(), num_classes) - - @classmethod - def b7(cls, num_classes=1000): - return cls(2.0, 3.1, 600, 0.5, "SAME", 1e-1, 1e-3, BlockConfigs.tf_block_config(), num_classes) - - -def round_filters(filters: int, width_coefficient: float, divisor: int = 8) -> int: - """Round number of filters based on width multiplier.""" - filters *= width_coefficient - new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor) - if new_filters < 0.9 * filters: - new_filters += divisor - return int(new_filters) - - -def round_repeats(repeats: int, depth_coefficient: float) -> int: - """Round number of repeats based on depth multiplier.""" - return math.ceil(depth_coefficient * repeats) - - -class SqueezeAndExcitation(nnx.Module): - """Squeeze-and-Excitation block""" - - def __init__(self, in_channels: int, se_channels: int, *, rngs: nnx.Rngs): - # conv_reduce - self.conv1 = nnx.Conv( - in_channels, - se_channels, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - use_bias=True, - rngs=rngs, - ) - # conv_expand - self.conv2 = nnx.Conv( - se_channels, - in_channels, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - use_bias=True, - rngs=rngs, - ) - - def __call__(self, x: jax.Array) -> jax.Array: - # 1. Squeeze - squeeze = jnp.mean(x, axis=(1, 2), keepdims=True) - - # 2. Excitation - excitation = self.conv1(squeeze) - excitation = nnx.silu(excitation) - excitation = self.conv2(excitation) - excitation = nnx.sigmoid(excitation) - - # 3. Scale - return x * excitation - - -class MBConv(nnx.Module): - """Mobile Inverted Bottleneck Convolution (MBConv) block.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - strides: int, - expand_ratio: int, - se_ratio: float, - padding: int | Literal["SAME"], - *, - cfg: ModelConfig, - rngs: nnx.Rngs, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.strides = strides - self.expand_ratio = expand_ratio - self.has_skip = strides == 1 and in_channels == out_channels - - # Expansion phase (1x1 Conv) - skipped if expand_ratio is 1 - expanded_channels = in_channels * expand_ratio - if expand_ratio != 1: - self.expand_conv = nnx.Conv(in_channels, expanded_channels, kernel_size=(1, 1), use_bias=False, rngs=rngs) - self.bn0 = nnx.BatchNorm( - expanded_channels, momentum=cfg.bn_momentum, epsilon=cfg.bn_epsilon, use_running_average=True, rngs=rngs - ) - else: - self.expand_conv = None - self.bn0 = None - - # Depthwise convolution - self.depthwise_conv = nnx.Conv( - expanded_channels, - expanded_channels, - kernel_size=(kernel_size, kernel_size), - strides=(strides, strides), - feature_group_count=expanded_channels, - padding=padding, - use_bias=False, - rngs=rngs, - ) - self.bn1 = nnx.BatchNorm( - expanded_channels, momentum=cfg.bn_momentum, epsilon=cfg.bn_epsilon, use_running_average=True, rngs=rngs - ) - - # Squeeze-and-Excitation layer - if 0 < se_ratio and se_ratio <= 1: - se_channels = max(1, int(in_channels * se_ratio)) - self.se = SqueezeAndExcitation(expanded_channels, se_channels, rngs=rngs) - else: - self.se = None - - # Projection phase (1x1 Conv) - self.project_conv = nnx.Conv(expanded_channels, out_channels, kernel_size=(1, 1), use_bias=False, rngs=rngs) - self.bn2 = nnx.BatchNorm( - out_channels, momentum=cfg.bn_momentum, epsilon=cfg.bn_epsilon, use_running_average=True, rngs=rngs - ) - - def __call__(self, x: jax.Array, training: bool) -> jax.Array: - identity = x - - is_inference = not training - - if self.expand_conv is not None: - x = self.expand_conv(x) - x = self.bn0(x, use_running_average=is_inference) - x = nnx.silu(x) - - x = self.depthwise_conv(x) - x = self.bn1(x, use_running_average=is_inference) - x = nnx.silu(x) - - if self.se is not None: - x = self.se(x) - - x = self.project_conv(x) - x = self.bn2(x, use_running_average=is_inference) - - if self.has_skip: - x += identity - - return x - - -class EfficientNet(nnx.Module): - """ - EfficientNet implementation. - See: https://arxiv.org/abs/1905.11946 - """ - - def __init__( - self, - cfg: ModelConfig, - *, - rngs: nnx.Rngs, - ): - super().__init__() - self.cfg = cfg - out_channels = round_filters(32, cfg.width_coefficient) - self.stem_conv = nnx.Conv( - 3, - out_channels, - kernel_size=(3, 3), - strides=(2, 2), - padding=cfg.stem_conv_padding, - use_bias=False, - rngs=rngs, - ) - self.stem_bn = nnx.BatchNorm( - out_channels, momentum=cfg.bn_momentum, epsilon=cfg.bn_epsilon, use_running_average=True, rngs=rngs - ) - - # Build blocks - self.blocks = nnx.List() - for bc in cfg.block_configs.items: - input_filters = round_filters(bc.input_filters, cfg.width_coefficient) - output_filters = round_filters(bc.output_filters, cfg.width_coefficient) - num_repeat = round_repeats(bc.num_repeat, cfg.depth_coefficient) - - for i in range(num_repeat): - strides = bc.strides if i == 0 else 1 - in_ch = input_filters if i == 0 else output_filters - - self.blocks.append( - MBConv( - in_ch, - output_filters, - kernel_size=bc.kernel_size, - strides=strides, - expand_ratio=bc.expand_ratio, - se_ratio=bc.se_ratio, - padding=bc.padding, - cfg=cfg, - rngs=rngs, - ) - ) - # Head - in_channels = round_filters(cfg.block_configs.items[-1].output_filters, cfg.width_coefficient) - out_channels = round_filters(1280, cfg.width_coefficient) - self.head_conv = nnx.Conv( - in_channels, out_channels, kernel_size=(1, 1), padding="SAME", use_bias=False, rngs=rngs - ) - - self.head_bn = nnx.BatchNorm(out_channels, use_running_average=True, rngs=rngs) - - self.gap = partial(jnp.mean, axis=(1, 2)) - self.dropout = nnx.Dropout(rate=cfg.dropout_rate) - self.classifier = nnx.Linear(out_channels, cfg.num_classes, rngs=rngs) - - def __call__(self, x: jax.Array, training: bool = False) -> jax.Array: - # Stem - x = self.stem_conv(x) - x = self.stem_bn(x, use_running_average=not training) - x = nnx.silu(x) - - # Blocks - for block in self.blocks: - x = block(x, training=training) - - # Head - x = self.head_conv(x) - x = self.head_bn(x, use_running_average=not training) - x = nnx.silu(x) - - x = self.gap(x) - x = self.dropout(x, deterministic=not training) - x = self.classifier(x) - return x - - -@jax.jit -def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> jax.Array: - model = nnx.merge(graphdef, state) - return model(x) diff --git a/bonsai/models/efficientnet/params.py b/bonsai/models/efficientnet/params.py deleted file mode 100644 index 5dd968a7..00000000 --- a/bonsai/models/efficientnet/params.py +++ /dev/null @@ -1,161 +0,0 @@ -import jax -import jax.numpy as jnp -import numpy as np -import timm -from flax import nnx - -from bonsai.models.efficientnet import modeling as model_lib - - -def get_timm_pretrained_weights(model_name: str = "efficientnet_b0"): - """ - Downloads and returns pre-trained EfficientNet weights from the 'timm' library. - - This requires PyTorch and timm to be installed: - !pip install -q torch timm - - Returns: - A dictionary mapping pre-trained layer names to NumPy arrays. - """ - # Map to correct timm model names. Some larger models use specific checkpoints. - timm_name_map = { - "efficientnet_b0": "efficientnet_b0", - "efficientnet_b1": "efficientnet_b1", - "efficientnet_b2": "efficientnet_b2", - "efficientnet_b3": "efficientnet_b3", - "efficientnet_b4": "efficientnet_b4", - "efficientnet_b5": "tf_efficientnet_b5_ap", # AdvProp - "efficientnet_b6": "tf_efficientnet_b6_ap", # AdvProp - "efficientnet_b7": "tf_efficientnet_b7_ap", # AdvProp - } - timm_model_name = timm_name_map.get(model_name) - if not timm_model_name: - raise ValueError(f"No timm mapping for '{model_name}'. Available models are: {list(timm_name_map.keys())}") - m = timm.create_model(timm_model_name, pretrained=True) - m.eval() - - # Convert weights to a dictionary of numpy arrays - return {k: v.numpy() for k, v in m.state_dict().items()} - - -def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig) -> dict: - """ - Creates a mapping from the JAX model's parameter names to the timm model's names. - This version correctly handles the different architectures of the MBConv blocks. - """ - bn_map = { - "scale": "weight", - "bias": "bias", - "mean": "running_mean", - "var": "running_var", - } - name_map = {} - - # 1. Stem - name_map["stem_conv"] = {"kernel": "conv_stem.weight"} - name_map["stem_bn"] = {jax_n: f"bn1.{timm_n}" for jax_n, timm_n in bn_map.items()} - - # 2. Blocks - block_configs = model_lib.BlockConfigs.default_block_config().items - total_jax_block_idx = 0 - for i, bc in enumerate(block_configs): - num_repeat = model_lib.round_repeats(bc.num_repeat, cfg.depth_coefficient) - for j in range(num_repeat): - jax_base = f"blocks.{total_jax_block_idx}" - timm_base = f"blocks.{i}.{j}" - - if bc.expand_ratio != 1: - name_map[f"{jax_base}.expand_conv"] = {"kernel": f"{timm_base}.conv_pw.weight"} - name_map[f"{jax_base}.bn0"] = {jax_n: f"{timm_base}.bn1.{timm_n}" for jax_n, timm_n in bn_map.items()} - name_map[f"{jax_base}.depthwise_conv"] = {"kernel": f"{timm_base}.conv_dw.weight"} - name_map[f"{jax_base}.bn1"] = {jax_n: f"{timm_base}.bn2.{timm_n}" for jax_n, timm_n in bn_map.items()} - name_map[f"{jax_base}.project_conv"] = {"kernel": f"{timm_base}.conv_pwl.weight"} - name_map[f"{jax_base}.bn2"] = {jax_n: f"{timm_base}.bn3.{timm_n}" for jax_n, timm_n in bn_map.items()} - else: # This block handles the first MBConv layer where expand_ratio = 1 - name_map[f"{jax_base}.depthwise_conv"] = {"kernel": f"{timm_base}.conv_dw.weight"} - name_map[f"{jax_base}.bn1"] = {jax_n: f"{timm_base}.bn1.{timm_n}" for jax_n, timm_n in bn_map.items()} - name_map[f"{jax_base}.project_conv"] = {"kernel": f"{timm_base}.conv_pw.weight"} - name_map[f"{jax_base}.bn2"] = {jax_n: f"{timm_base}.bn2.{timm_n}" for jax_n, timm_n in bn_map.items()} - - # Squeeze-and-Excitation is the same for both block types - name_map[f"{jax_base}.se.conv1"] = { - "kernel": f"{timm_base}.se.conv_reduce.weight", - "bias": f"{timm_base}.se.conv_reduce.bias", - } - name_map[f"{jax_base}.se.conv2"] = { - "kernel": f"{timm_base}.se.conv_expand.weight", - "bias": f"{timm_base}.se.conv_expand.bias", - } - - total_jax_block_idx += 1 - - # 3. Head - name_map["head_conv"] = {"kernel": "conv_head.weight"} - name_map["head_bn"] = {jax_n: f"bn2.{timm_n}" for jax_n, timm_n in bn_map.items()} - name_map["classifier"] = { - "kernel": "classifier.weight", - "bias": "classifier.bias", - } - - return name_map - - -def load_pretrained_weights(model: model_lib.EfficientNet, pretrained_weights: dict): - """ - Loads pre-trained weights by directly modifying the JAX model's attributes in-place. - """ - name_map = _get_key_and_transform_mapping(model.cfg) - - timm_to_jax_map = {} - for jax_module_path, params_map in name_map.items(): - for jax_param_name, timm_param_name in params_map.items(): - path_parts = jax_module_path.split(".") - path_tuple = (*path_parts, jax_param_name) - timm_to_jax_map[timm_param_name] = path_tuple - - for timm_name, weight in pretrained_weights.items(): - if timm_name not in timm_to_jax_map: - continue - - path = timm_to_jax_map[timm_name] - weight_np = weight - param_name = path[-1] - - if param_name == "kernel" and len(weight_np.shape) == 4: - weight_np = np.transpose(weight_np, (2, 3, 1, 0)) - if param_name == "kernel" and len(weight_np.shape) == 2: - weight_np = np.transpose(weight_np, (1, 0)) - - target_module = model - for part in path[:-1]: - if part.isdigit(): - target_module = target_module[int(part)] - else: - target_module = getattr(target_module, part) - - param_to_update = getattr(target_module, param_name) - if param_to_update.shape != weight_np.shape: - raise ValueError( - f"Shape mismatch for '{'.'.join(path)}': " - f"JAX model has {param_to_update.shape}, " - f"pre-trained weight has {weight_np.shape}." - ) - - param_to_update.value = jnp.array(weight_np) - - return model - - -def create_efficientnet_from_pretrained(version: int): - """ - Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model. - - Returns: - A flax.nnx.Model instance with loaded parameters. - """ - if version < 0 or version >= 8: - raise ValueError(f"Expected efficientnet version between 0 and 7, but got {version}") - config = getattr(model_lib.ModelConfig, f"b{version}")(1000) - torch_state_dict = get_timm_pretrained_weights(f"efficientnet_b{version}") - model = model_lib.EfficientNet(config, rngs=nnx.Rngs(0)) - return load_pretrained_weights(model, torch_state_dict) diff --git a/bonsai/models/efficientnet/tests/EfficientNet_ImageNet_validation_example.ipynb b/bonsai/models/efficientnet/tests/EfficientNet_ImageNet_validation_example.ipynb deleted file mode 100644 index a77e0865..00000000 --- a/bonsai/models/efficientnet/tests/EfficientNet_ImageNet_validation_example.ipynb +++ /dev/null @@ -1,257 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# **ImageNet Classification with EfficientNet-B0**\n", - "\n", - "This notebook demonstrates how to use the EfficientNet-B0 model from the Bonsai library to perform ImageNet classification. Note that this notebook loads a **trained** model. It serves to validate the model's architecture and demonstrate the full inference pipeline.\n", - "\n", - "*This colab demonstrates the EfficientNet implementation from the [Bonsai library](https://github.com/jax-ml/bonsai).*" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## **1. Set-up**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -q git+https://github.com/jax-ml/bonsai@main\n", - "!pip install -q pillow matplotlib requests\n", - "!pip install -q torch timm\n", - "!pip install -q ipywidgets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import requests\n", - "from PIL import Image\n", - "\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX device: {jax.devices()[0].platform}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## **2. Load Utilities**\n", - "\n", - "Download an image, load ImageNet class names, and preprocess the image according to EfficientNet's requirements." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "def load_imagenet_classes():\n", - " \"\"\"Load ImageNet class names from a common source.\"\"\"\n", - " url = \"https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt\"\n", - " response = requests.get(url)\n", - " response.raise_for_status()\n", - " classes = response.text.strip().split(\"\\n\")\n", - " return classes\n", - "\n", - "\n", - "def preprocess_image(image_url, target_size=(224, 224)):\n", - " \"\"\"Download and preprocess an image for EfficientNet inference.\"\"\"\n", - "\n", - " headers = {\"User-Agent\": \"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\"}\n", - "\n", - " response = requests.get(image_url, stream=True, headers=headers)\n", - " response.raise_for_status()\n", - " image = Image.open(response.raw).convert(\"RGB\")\n", - "\n", - " # Resize to slightly larger, then center crop (more standard approach)\n", - " resize_size = int(target_size[0] / 0.875) # ~256 for 224\n", - " image = image.resize((resize_size, resize_size), Image.Resampling.BICUBIC)\n", - "\n", - " # Center crop\n", - " left = (resize_size - target_size[0]) // 2\n", - " top = (resize_size - target_size[1]) // 2\n", - " image = image.crop((left, top, left + target_size[0], top + target_size[1]))\n", - "\n", - " # Convert to array and normalize\n", - " image_array = np.array(image).astype(np.float32) / 255.0\n", - "\n", - " # ImageNet normalization\n", - " mean = np.array([0.485, 0.456, 0.406])\n", - " std = np.array([0.229, 0.224, 0.225])\n", - " image_array = (image_array - mean) / std\n", - "\n", - " return jnp.array(image_array[None, ...]), image\n", - "\n", - "\n", - "imagenet_classes = load_imagenet_classes()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## **3. Load EfficientNet Model and Run Inference**\n", - "\n", - "Now let's load the EfficientNet-B0 model from the Bonsai library. We will initialize it with random weights." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Import model creation and weight loading functions\n", - "from bonsai.models.efficientnet import modeling\n", - "from bonsai.models.efficientnet import params as params_lib\n", - "\n", - "# 1. Define model name\n", - "jax_model = params_lib.create_efficientnet_from_pretrained(0)\n", - "\n", - "# Prepare the input image\n", - "image_url = \"https://upload.wikimedia.org/wikipedia/commons/thumb/3/3c/Giant_Panda_2004-03-2.jpg/1024px-Giant_Panda_2004-03-2.jpg\"\n", - "# 3. Use the correct resolution for B0\n", - "input_tensor, original_image = preprocess_image(image_url, target_size=(224, 224))\n", - "\n", - "# Run inference with your JAX model\n", - "jax_logits = jax_model(input_tensor, training=False)\n", - "\n", - "# Post-process the output using JAX functions\n", - "jax_probs = jax.nn.softmax(jax_logits, axis=-1)\n", - "\n", - "# Get top-5 predictions\n", - "top_k = 5\n", - "top_probs, top_indices = jax.lax.top_k(jax_probs[0], k=top_k)\n", - "\n", - "# Ensure the results are on the host for printing\n", - "top_probs = np.array(top_probs)\n", - "top_indices = np.array(top_indices)\n", - "\n", - "# Print the results\n", - "print(f\"Input image shape: {input_tensor.shape}\")\n", - "print(f\"Output logits shape: {jax_logits.shape}\")\n", - "print(\"\\n--- Top 5 Predictions (from BONSAI JAX model B0) ---\")\n", - "for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):\n", - " print(f\"{i + 1}. {imagenet_classes[idx]}: {prob:.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## **4. Visualize Results**\n", - "\n", - "Let's visualize the input image and the top-5 (random) predictions from our untrained model." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABZYAAAJOCAYAAAAkpFLLAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs/Qe0Ldl6FQbPyjucfHPf2/HlpJwBCwSDaMAyDAsbYxjCCDREsIwjCAlkDCInEQ0IhNEYEsYWDNmYJGFjCwsE/gXKenqh840n7Fj5H3N+a9Wuc7v11O+p+z0Zrdm9+9y7z95Vq1aqrvnNb35R3/c9AgICAgICAgICAgICAgICAgICAgICAt4g4jf6wYCAgICAgICAgICAgICAgICAgICAgIBALAcEBAQEBAQEBAQEBAQEBAQEBAQEBHzcCIrlgICAgICAgICAgICAgICAgICAgICAjwuBWA4ICAgICAgICAgICAgICAgICAgICPi4EIjlgICAgICAgICAgICAgICAgICAgICAjwuBWA4ICAgICAgICAgICAgICAgICAgICPi4EIjlgICAgICAgICAgICAgICAgICAgICAjwuBWA4ICAgICAgICAgICAgICAgICAgICPi4EIjlgICAgICAgICAgICAgICAgICAgICAjwuBWA4ICAgICAgICAgICAgICAgI+BmNf/JP/gmiKNJPj9/wG34DnnnmmTftHH/tr/01neMjH/nIm3bMgIBPJQKxHBDwJsPfKL73e7/3p0Xfrtdr/N7f+3sv3RzfyM30f/qf/qe3vG0BAQEBAQEBAQEBAQEBnzj47PZGXm/0efDNeBZ+vderr776k37/5/7cn3vpOycnJ/jcz/1c/NW/+lfRdR3+v4Q/8Af+AL7927/9U92MgIC3HOlbf4qAgIBPNbH8+37f7xtu1AEBAQEBAQEBAQEBAQH/duBv/I2/cenv3/zN34x/+A//4Wvef8973vNJa9PXf/3X49lnn7303tHR0Rv67p07d/AH/+Af1J/v37+v6/mNv/E34kd/9EfxDd/wDfhk43/4H/6HT4jUJrH8q3/1r8a/9+/9e5fe/3W/7tfh1/yaX4OiKN7EVgYEfOoQiOWAgICAgICAgICAgICAgICA/w/iP/6P/+NLf/9//p//R8Ty4+9/MvFLfskvwed8zud8Qt89PDy81Pbf/Jt/M971rnfhG7/xG/Hf/Xf/HbIse813SPxWVYXJZII3G693vp8KkiTRKyDg3xYEK4yAgE8C6Mu0t7eHl156SRFL/vnatWv4L/6L/wJt2w6fo88SU37+6B/9o/gTf+JP4Omnn8Z0OsUXf/EX4/u///svHZPq49dTII89oHg8noegatmnFNEa4+MBP8/vMUrMmzxv9jzu7/k9vwd93+OFF17Ar/yVvxIHBwe4efMm/tgf+2OXvs+b/Nd+7dfisz/7s/Xd+XyOn/Nzfg6+67u+6zXnevjwoaK4PBaj2r/+1/96fN/3fZ/Oz9SqMX74h39YUWCmSPF/Ivg/L3/37/7dj+vaAgICAgICAgICAgIC/m3GarXC7/ydvxNPPvmklLIkavnMyWe5MfjM9Vt/62/F3/ybf1Of4TMWn+H+z//z//y4z7lYLC49636imM1m+IIv+AJdAxXMj7fzfe97n67pf//f/3f9js/cX/7lX44bN27off6eVhqP48UXX9SzOZ9Nr1+/jq/+6q9GWZav+dzreSyTyP5Tf+pP4QMf+ID6iM/Gv/gX/+LBDpPtY3v/+l//68MzOI/zsTyW/9yf+3PDtTzxxBP4qq/6KpydnV36DJ//3//+9+MHf/AH8fN+3s9T39y+fRt/+A//4de0+8/8mT+j4/Ezx8fHelb+lm/5lk9gBAICPjaCYjkg4JME3lR/0S/6Rfj8z/983cT/0T/6RyJg3/a2t+Erv/IrL32W6T68EfNmst1uddP6ki/5Evybf/NvdIN8o+AN7s//+T+v43/pl34p/v1//9/X+5/2aZ/2CV3Dl33ZlymFiilI/+v/+r/i9//+3y9S9y/+xb+o9v2hP/SHdHMnYU4vrH/n3/l39L2Liwv85b/8l/Ef/of/IX7Tb/pNura/8lf+ivrjn//zf47P+IzPGG7Qv/yX/3K9xza/+93vxt/5O39H5PLj+IEf+AH8rJ/1s3Qj/W/+m/9G/0Pwbd/2bfqfg7/9t/+2rjcgICAgICAgICAgIOBnMkge/4pf8Ssk6qGlBJ+9/v7f//v4L//L/1IkLAVNY/wf/8f/gW/91m/Fb//tv10kJwlPkqZ8RiOp+UZA0nO5XCLPcz3z8bn3He94xyd8DR/60Iek8h3baXznd36nnv9IMF+9elXk7927d0VCe+KZz8N/7+/9PV03n0n/s//sP9N3N5sNfv7P//l4/vnndZ0kcmkdwmO+EfB4JIipzP5P/9P/FE3T4J/+038qtTgJXB6L73/e530evuIrvkLf4XP/xxJyUQj2C37BL9Bz8I/8yI/oOf5f/It/gf/7//6/L6mmT09PNR58tv8P/oP/QLWR/uv/+r8Wyc32ePsOXhdFWL/jd/wOcQr/+l//a3zP93wP/qP/6D/6hMchIOB10QcEBLyp+KZv+iaGfft/8S/+xfDer//1v17vff3Xf/2lz37mZ35m/9mf/dnD3z/84Q/rc9PptH/xxReH97/ne75H73/1V3/18N4Xf/EX6/U4eK6nn356+Pv9+/f13a/7uq97Q+3/ru/6Ln3+b/2tvzW8x+/yva/4iq8Y3muapr9z504fRVH/Dd/wDcP7p6enaj/bMf5sWZaXzsPP3bhxo//yL//y4b2//bf/ts7zJ//knxzea9u2/5Iv+RK9z771+Pk//+f3H/jAB/rtdju813Vd/0Vf9EX9O97xjjd0rQEBAQEBAQEBAQEBAf824au+6qv07OTx7d/+7fr77//9v//S5371r/7Vepb74Ac/OLzHz/H1vd/7vcN7H/3oR/vJZNJ/6Zd+6U967m/91m/tf8Nv+A39X//rf73/X/6X/6X/mq/5mn42m/VXr17tn3/++Z/0+3y+ffe7361nWL5+6Id+qP/tv/23q02//Jf/8kvtjOO4/4Ef+IFL3/+Nv/E39rdu3eofPHhw6f1f82t+TX94eNiv12v9nc+bPMa3fdu3DZ9ZrVb929/+dr3PZ+Kf6Pn6O7/zO/UZtutx8HnUYz6fX3omfpwv4LM/ce/evT7P8/4X/sJfqGdfj2/8xm/U5/7qX/2rl/qH733zN3/z8B6fs2/evNn/ql/1q4b3fuWv/JX9+973vp+wnwMC3kwEK4yAgE8ifstv+S2X/k47CEZfHwdVt1TiejDSSaXz//a//W/4VIJRVw9GjBmN5X2dEVsPRpGZNjW+Ln6W0WqvSn706JGiuvz+v/pX/2r4HNOXGI2lqtkjjmMpt8fg9xlNZoSW6ucHDx7oRRsNRsR/7Md+TNH3gICAgICAgICAgICAn8ngMySfx6hgHYPWGHyWo6J3jC/8wi+U/YXHU089JdtDqpx/MmsLPp990zd9E/6T/+Q/0TMtPZH5PT6n/ff//X//htpLu0MqjflitiwtHX7ZL/tlr7GzoF3ke9/73uHvvBZmrjIDln/2z4h88Rnx/Px8ePZkn9y6dUuKXg9aRnh18ccCz0FF9Nd93de95nd8/+MFM5lpHUk1NZ99PfhMTHtIZgqPQVvNsQc1n7PJF4yfv/lMTqsPKp4DAt5qBCuMgIBPErz30hj0OmIqy+N4vTShd77znUr1+VSC/1MxBv2SeV1MPXr8ff7Pwxj0l2IKFP9Hoa7r4f1xteCPfvSjusHzpj7G29/+9kt//+AHP6j/WaDHM1+vh3v37l0i5wMCAgICAgICAgICAn6mgc9YtHrY39+/9D5JW//7N/Isul6v5XFMK0QKfcbgc+5PVJDuZ//sny2RFAnUNwJaWtDKgSQtnzXZHnogP47xcyTBttGT+C/9pb+k10/0jOivmc+YjxPBFEj9ZPjxH/9x9Sf74c2A7//Hz03C+LnnnnvN+Ny5c+c17SavQKsLD1pjsL9JOPM6f+Ev/IWywKCVZEDAm41ALAcEfJLwZld+5c3k8WILxJtRIOHjuYaf6LrGbfsf/8f/UcUKGLWmlxf/x4Df+4N/8A/qxvzxgqpngl7OjD6/Hh4nowMCAgICAgICAgICAgJ+avju7/5ueSiP8eEPf/g1Be7GYNFA+ga/EbB2Dr2GfzKwyP3rPSNSzft6NXp+KrWGfjrhjTx/M2jA/v6O7/gOZQVTZU2v7K/92q+Vl3NAwJuJQCwHBPw0BK0cHseP/uiPXrpZMyr5ejYaj0c0P5F0nDcbLCjAaOv//D//z5fa83j60NNPP62iEoyGj1XLVCiPwWMRtM14I//TERAQEBAQEBAQEBAQ8DMRfMaiepUWgmPVMjNJ/e/fyLMon8+oTGZBv3/4D//hpd/fvHnzY7aBz62PZ+++2eDxeX0UWv1kz4i85u///u8XGTt+Pn0j5DeL8NHeg6rtj6VafqPP4b7/eW7/nEvQHoOE/Sf6vEuC/su+7Mv04rFY7I92JP/tf/vfSgkeEPBmIXgsBwT8NMS3f/u3X/IIZgVeVnD1VV79DY3/M8CUH4/v+77vU9XYMTxBy7SgT3VUdRxF5fX8s3/2zy59jupj2mQw9Wkcef6zf/bPXvocFc8/9+f+XPzFv/gX8corr7zmfOM+CQgICAgICAgICAgI+JmKX/pLf6nI1m/8xm+89P6f+BN/QuTn+BmT4DPauA7OCy+8gL/zd/6O7BT4XEeBE8nO8csTla/3HEY/43/5L/8lfvEv/sV4K8G2/apf9aukziVp/DjGbWOfvPzyyxJAeVDc9BNZaIzBc/C59vWUv+PnXRK7b+QZnP1H24s//af/9KXv/5W/8lfkC01/6Y8Xj9tS8vj0o+bxx7aUAQFvBoJiOSDgpyFo40Avqq/8yq9EWZb4k3/yT+LKlSv4r/6r/2r4zJd/+Zfjj//xPy4ylsXz6Bf1F/7CX8D73vc+XFxcXEoR4k3kW7/1W+WNxajq+9//fr0+Wfh3/91/V2rlL/3SL9WNkZFXtpXtWi6Xw+dolUEfKBaSoEr53e9+N/7u3/27g4fXOOpLspl99IEPfECFDRjdvXv3rv5HiIUKSLIHBAQEBAQEBAQEBAT8TAaL2dG64nf/7t+Nj3zkI/j0T/90/IN/8A9EFrNgHAVLY/A5kc+YLPZHdTItFIg3YqHwRV/0RfjMz/xMFWln3R0S1Cy6RyuM3/W7fhfeanzDN3yDMmDp6cxnRD5v8lmS7aBq2z9X8nck2llkkKQ36/z8jb/xN15T6+f1wL78db/u14kIprqbhDnFUP/0n/5T/e63/tbfqs+xACLPyWd2ejLTE5rtej2lNVXE7F8e61f8il8h9TL7/XM/93MvFep7o2AQgCpyeirfuHEDP/RDP6Tr5bP4417bAQE/VQRiOSDgpyF4g2NFWBLKJIxJtvJGwBve2Dfpm7/5m+WT9J//5/+5bpq8GX7Lt3wL/sk/+SeXjveX//Jfxm/7bb8NX/3VX600GFpQfDKJZforv/rqq1IYM22IbaXv8t/6W3/rUlsZZWbV29/xO36Hiv2xD0hGs728KY5TdniM7/3e79UN+K/9tb+mqCyVzPwfGfZJQEBAQEBAQEBAQEDAz3TwmYpiHT4jUWz0Td/0TbJY/CN/5I9I0PM4vviLvxhf+IVfqOes559/Xs9dfN56I/7EtF3g8xyJayqA+fxKEpfPcyQ432rwHMz2/fqv/3oJm0jOUqBF8dUf+kN/aPgcCeR//I//sZ6R/8yf+TP6+6/9tb9W6u03oqxmH7I/qCpmDSGS6CTTSax7kFD+iq/4CnzN13wNNpuNfJ9fj1gmfu/v/b0imPnMz2d2isH43T/wB/6A7B8/Xvzm3/yb8Tf/5t9UGyjkYsE/BgrYloCANxtR/3rVvwICAj4lYASZkUze5FmYLmBnDUKC+f/6v/6vUMk2ICAgICAgICAgICDgLQAzRL/qq77qNbYZAQEBAT8RgsdyQEDATyswmjsG/cAYRT44OMBnfdZnfcraFRAQEBAQEBAQEBAQEBAQEBCwQ7DCCAgI+GkFpiORXGb6Ff2lmcL03d/93UoDol90QEBAQEBAQEBAQEBAQEBAQMCnHoFYDggI+GmFL/mSL8Ef+2N/DN/xHd+B7XarQoZULPsiCAEBAQEBAQEBAQEBAQEBAQEBn3oEj+WAgICAgICAgICAgICAgICAgICAgICPC8FjOSAgICAgICAgICAgICAgICAgICAg4ONCIJYDAgICAgICAgICAgICAgICAgICAgI+LgSP5YCAgICAgICAgICAgICAn2boug4vv/wy9vf3EUXRp7o5AQEBAQE/Q9D3PRaLBZ544gnEcfzmEMtf8DnP6sBVWaJpGtRVhaoqkaUpptOpnSixm92mqdG0nf4c9UDbttjWNdq2x2rboml79C0bCkRR7F6u8fynt++y6XEUIUsy5GkO9BH4q67vsWkqtF2HtmnRta19kwfkOcGD+RtvhEmeI89S7E8LHM2nONzfw3NP3dHN+e7pGTZlheVmjW1VIY7sMpIYyDIgjiNMikQ/2S62bzad4eDgEGmSYpJNkaYp5odzpFmKtm7Ungg9YvTIshj78wmSJEKKTu/XdaM+SdIUeVEgz3IcHByoHzbbCnXd4iMv3sVLdx/hlVce4Ed+7AVMJzGefWYfR4cz/Kwv/CzcuX0TcTRBHBUo8gzzaaExyNMEbdvhw688wOlijfV2jW25Qc/xqDucX6zxfT/wEZwv1nj+4T1cbFbI0xhZEuP68THe9eSTuHJygs/57M/EbDbF8y+9gNPzM/zgBz+EH/rgh7Apa5yttho/qNt7IHGD5bu969A3LTgaXRSh5+COh4S/6O0N/RNzHkRIEqif7lw7xNM3TtDVHeplqbmRT4E4gfqGY36xAU7XPcq6wWK9RZYmePqJE8wmOfq60fkfXFzg1UePND5706n65uo8wzRL8MTtqzg52UfVNJqbZ+crfPBDd1E3DbKMbYlwvD/H3qywOcmJrCbHam8x6ZGmPSLUiNAiS3Lk6czmaM25wm7o9drUHSrO/fUW5xcr1E2PzZbzGEizBHESI0GCGLH1Rw9MsgTXDjPNH059XvvLD1e4f75G1UbY1hHSJML+zNrKAWBrOq6rFmiaVvOM/zNaaSwi5PkEaZphEkcoOL+jHpOo03VxxbU98HBdY1W12DT24lrr9FsgSRONU+xeJ3tTXNm3tZ8msdb1+UWJqu6writUnZ8DQJ5l2J/tochiXNuPUGQ29hzczn0mTRNMJxmiqEfbNppb09kcWZ6jbntUTY+yrHHOPqganJ9eoG0a7B/sYTab4PhoimvX5pgUE5zsH6NpOnzoo/dwsdzgoy89xP1HC03VDBHSOMK0YLsSXLs6w3SS4vhohr29ws9K1HWHxarWuR6crlBWnBuF1u2Vwz3cunKMpquwWJ+hqis8vPcI200FNBnQpoiiFFHCtrc436zRdC02le1bfjEkEfvOxnAySfR2xUPwetvO9kr2Y9/YbsbJB6DpbBlxPDQttYH6V4ye83BYbIaub9Fq47X1p384n3eLEhn3O45zHGk/s58x4iTC3t4E8/lEc5p90rUd6rJBEif4jPe+G7dv3sDJ1X2cXNvHiy+/hO/5V/8c63WJ5bJB0/jr4Xzkuu3Qdi2aprb15W5cWjh9jySyMeLcspsY50lk7W6AqItw58Z1PHf7NuazFFev5OrLtubVx3j2uedw/cYNnC4ucP/0FBeLC3zk+efRdg3292eaa8vlCpvNVnvuYrnVnlzVpXpjVkx0XVxH3E+3VY3NtrTuzbgNRDpvlsZI2MYo0rzflo0+kkS2Vji/+PurV49xdLiPqt5is1mgaVusyhJ10+LRKdtga5Qn4F7GucBlnXNPjIG9eap9iX3ed/3QZ1x3e/Nc7eDasXVYo+kapHGCaZJrfs1mqeZYnLI/I81hrueEe0/C+2+Er/ydf++N/u9AQEBAQEBAwCcJJJWffPLJ0N8BAQEBAZ8SvPDCC7hz586bQyyToNo9zhKeaPMv0Rw7GoOc0fARI6JIHokIcZ8XsSxS5PJ3dlTIDkOElt/X96KPK2prhLWRFyTLSCLw+3pId4T07lye9yHj6RvGV4yo7xDFJBcdKdZ7goRsB/upRd91iP3Fj7pLbXd9oF+SdO9ItvD7LeLYtUWnImHBB/8Eabp7+PffVZNHl6+/6wJ3p7v0S3ft/jeOyhl9ksfsRfywb0iOc8xJ6jB4oFeWiSBl35No9XNBzbAowaVLHvf+qHd/0rHy88X6i68RRdb7ltt/+bGERAnHY+gDu16SUwwosO0k30ksk+DM8wRpwmCB9akdmAS3nVNkqO+vx/t5zI/rkke0nO/n0d+H4eefh+HhOUYjwPej0THduf2YevLQNe9SDzp61ubMaET1nVg9SRZu+A7n/mjUhrlOepG/8aQxX8O5Rtf5+LheJi93R/YvXtfrrefdIUjacx35a3ZH9D9/gvnh9xUSqPwQyc6hj3WNrocf6y9/BH9s7WrD99w6emyWikx3JKf/Bc/XurU+hEgsQmIEoZ86fac94tK17Xr/NWv40vrxk+Nj9t9PtqTcEf0kec3h2EuXQ3HWLGvb0FOjvhn2y3HbGUQh6Tnag/wB+tcZ20uf0+R/7R7s59yw9ft14tec9iuS1kBTJ+i4f7bcm0jMM7BSK7jCvZl/H/Y+nl9Bn04vEscKUurv1qFtx16xe4XfN7VVuOtRi93vbM/edcawBkfDp/uMJjn3q1jHHfZzXSuP4/b10XEe7xTds4aOGAZk6OPXju7H3n133/lJ5llAQEBAQEDApwRUKvsHewqRAgICAgICPhm4uLhQYNPfh94UYnmxWtgfRPL1QNabWiuJEGWmtuKDufEFfFJORU6lIOnXSjVl6js+5JPY5cP2ZfLOOD73ud6RXFIXJsgT6gxJKsdo+w4NdY5tg/YxwsvRbI5DsYf2LqLqskfdmToVqyVefOUVHb+UKrND37ZIYyqUYyRSoSXIcqqAI11rHwN5QXI1xnw6U+fWVY0LqiZJTDy8p3PtzaeYTgrkVHLnmcjAmuLLJEaep0hjqglryX035RaL01Od89H5Usraw4MD5HmO46MjJOkcaZqjrEokcYfZPJZ6s+9j8DLYJgnF407EdE/yC7GNgwiSgXEW4d00Faq6FtFC1a8GQOSJKSAXqxIvvXIX222JZ568i+OjA9y4coI7T9xEnBXIpnO8cvc+FqsP6jieHJXo1YuSY5KERviNwgw7hnL0dz9WRhjz76bYnU4mOD48QFPVWHOYuxZ1X6GrWykdqUStSHz3jRSB+9M9kTU5VdnrSqpj9sfJwT6evHMNs8kE164caUz2nKqPCmC+qNxerTtEUYbpbIKE59yUOud+3qJLGzG6UWrzSZrzGFIGGynEGR6jaoGyWxspLl0sP2NEUVVSpd6jbdjXCdqec5FjAs1fHiaNO6SeMOqMPCqrGlEfYW+WocgT7E8ylNVEqv9tyUAEFY0MPFCJaERyVfWoK1NCT+JChFq32ihgsG3YhxX6PJMUkqrNecE5asEaBVw6qt575HWDpE7RdB3KttZ1eIE6r5/NpBKXczhLEiQ8ZmcEdRxRFSldMHqpZDu0faSfnGtRTDVvhI790faomQlBFWxKtSWDMk6fyn7JGkQcqChGliWa10Weax+Ioo3mzbbs0HQ1JkWGctsiRoO6LpUhEUctkqRX0MYTwPxv0/dYlT2ytsNy06tt0ymP7Se0rYs0TxCnCY5B9WqH1arEdlNi2bW4W23VpjkVo1mKdVGj7zKcVw2WW85PKkd57ez7xlG4uhTti6lU5o7k5P5U2dJoWxKPfu1Kzz1aS7ZZ+oCIPubX309EMPsAjSdHR8txTJryvdQplqUUTqhYNqWzSFgqYZvEiNV4dyvoqcZfrHCWnmE65d51AHQJ0myKuOqxLbdSmit7gWrrpkNbcy5wn+d+yP63plJ1bMEiZlGQzLcsBq0nJOpBZUAkES5WK3z4hRdR5CnuPZoqgDSfTBVMKrsP4cVXX8FitcbpYinVd9dXusZqswG7+sH9Rzg7X2BTtlhtOT4uIyECVuta8zBnxkAao497ZLNEeyuzJCioLyuS0hHyNEOfWpCQ16Q+iRrr64b3uhiL5VpZI5wr165cQdVUiC4eYVtWuFhEqEiKKxvHr0N1hjb4yKXQKNjIgBjXTtdLAc8ZVLkMGAYzOU7cB2J3v+Qew/7SeuHdITIFNQexaxpEvOkFBAQEBAQE/LSFD7STVA7EckBAQEDAJxtvRND7hollpnoTCUksEm2OhFWWckxaxB5iRSrCpcy75H6pT/n5qEOWmBKMXzOCjiSos81wafH8vKNPnFrQpxvbEePOpUiLXBqLtx67YKduM67OyKNajABT9hemQqT/gvsaiQSJEkWqUC2cm4qV1xf3SPMCxSRDMZ2imFhK+LYqUZUVNtVGZCT6IyOykKNLE0fIuId7ku0xSRsjYKm0Wyw3ugamyFNNO5vPkZPcmUwRRVOU2y1OjvbQRy3yvBX5TNJWpKRabXYLIhQQOxXeSEk3qPScEpkEiwgQUxd6VR0JjbKucX6xFFG4WiwwzVPcuHkdR8dHOF1tcL4pZYUh9XQ9Vsw5ws6djmSoKc/HiszHVa2XZcCm2DValhYjJOfrKEKdbRExdb4hkUICsZctgshzMGCRyK5CAYy6JbsjAprXOp8UuH39qsj+J65fEeFUFLnm12a7Udr7cl2j64x0IfEvUmddighm+n7XdEofl9x+x9Oj4xxkP0v1HKMl6VbTFoO2BmZ10JMg5jixWQ3n+k6VSCJNJBIJanE7Xqps30taBlhatJz7sJT4Io0xyVNUjEvI0oBBm1jBiiS14/B4Lecf52+WIW5aJNtKZHZTN+q7LI2Qs8Mjs04Rf+XaM81MuUwzl0aEeSP7BH5f82q0trSeGn6K/ZJcTmKQFYBXcJOkNUW8JzW1brlf0LZDY0rFqO0NvA6m/3PpkWhOSDb7FH4R6YlIep6DdCPXEdcerQhoX5GlPFZjJF/E4zlCd6QKH6xKAM2tuO5lUcLjsjO43vkFnbePMKWRTQtZXXRNi5J/qUtMpwUO5/RpoTVJjjQFmr7DhkrathGByHbUzt6nYH+L4Is0rmZB4oIMjO64zAGbJ2yoI5VfRzXvAzJeye/30Mf3wUtrbaRCv7wu7RcKBjLwor3dKfj7RvuH7WWKqAzZF75Dq22NzWqLuqzNJqcnCUpbkxQ1A0Fcmz7Lo3VzyWcfaM67q0q8Yt6NN0l4Z7MjgyES0Ynt6bQuKjdbrWtaVfDnlSMoO6Fua5xdxFhttrhYrhUE3d83uyDap/D869UGFxdLbOsOq5KBjBhxlqu/Ss3+HvMZ/ZBSBW5476Miuqus7dwfbA0miDsGPfwwmNJZI8H7G+1NygrrnjZKuQKTWZNgteV+08gCSHuA+oBnNXKZ8043C/WbqeG1rrnncP9TP3NfZ0CDBLZZNtGGg/3HfXjIfOCduDdimfceBnxsIXhbluDZGBAQEBAQEBAQEBAQEPDx4w0Tywd77qNRbKnrIm74UOyUhCQ7+XtHpPVOUUWyiM/FJAdMEecO44ho2hV40o4ftJR494zd7fg28jzyRc0yU421VJ/1qIf09N2D8c5pwv9DEtURJ87GwjLofYo7XXLtdNJKkuTiP20lVeF8Qo/mBHt7M8ymhRTFszxDRPJyPkVGJXJNkqBzn8kxyQpMC3ogk3QkkdGgrYxE9unPvMZicuCIvBZ1F8vvme2rayOe6e98+9Y1KS5nM5KfiRSbJK0AquKoovOKvwRdlO7SxfWvtORSgJoPdYe6pwK4NbLbBQfkJEFlZV3jYr3B86/ew2K7xfToiLJRNHWPPKNPcYEsTtEkjaWdu5Ru6/ddarcxJZ7U8GzYyHbjcSpD4212K9Vmi8XZhQg8qqtJhtvxqWit5BnML9AbelakONmboMgyHEwmyOIEJVPjLUIhNfJ2u8L6YmnzUD7BwMVmIzJfpFdl6uciZ1+k2G5yJCRTW0jJSAU75zj7eNu1xvGkpAUj5AUJXCODRGqTNCbh5NaCSOSROt8rSwcHXL7fAQ0DHxwrkfIR6o7kt6X2L9YkmTmP7Jp56iSyg1ZNqiNNqLIXMWhcVJYCkwnnEH1nSUZbej9nNoncVdkgI8FetPLVLVzDqIAW2SflZoqSClWuac6buJWylkEab2VDZSXtUTYVfYBtHTFgFHW9FNj8LHeOhJ93dglV412MufZ69E2HumK/GumoOFRsimwq9I3opNdvBiS9+TN3sQhE7iFUbNJ7e71JsV4X6lcqZLWlqC84fuYDLQJOJKmp+El4L1cb1HWCg70Ce3Pbr+h7bYpcIz8pmuaE5fqbTjO0VSP/YI7Jo3ypfaVmYxSUMt/ktuMeR6U2AyG2sXFqdDEtZjgPjSz0K8Fs6S1IZOS3/bTl7HzovdWG2+i0r+1m0+jneHXtbCe8R/PYemVwclCQzo7lfcV71y42n+Rwv63cOmdfmlqZXy7rCmspky3QxiyT6yfXUGRT3Lt3hore0o47taAW+353zb59Foh0HvdOscy+9PcMEwST+Hc2G4r5dFhtN9hWHING40wVP49BK4yqbBRcOjk6xKQo9NOsfabY3z/A3UfnWN29r7nNIAqvj4EqUvrcSzqq3vsYaRRpvC1IQsU+x5hzg4QvrZ4sa8CCFmaCzdiHSHAFu3rMNi3KbYwOKYqs0Gemk60CKGUVoazM2odrnIHYmoGJLkLWMHJkAVtZ/zg/epLFVIfL1zkhaewmqusn70elQEnCfusGublZeLDH+f1ALAcEBAQEBAQEBAQEBAS85cQyU8iZpm2FimjxwAdcpt3rQdsRZUqzlR1DJ8UnCeGUD8WyxdiRBLIBEC+xM1umUktCME8q88HdEctZFGPmVKWriufssaF6rY2Gz+oPjrz2RQCt6F58ifz0alopoV2bpEQluSAFNRVgpkbbK4wsPtyfYjad6gGexY+iPsNsb4qsyZC3VuRtb144YnmCWW6qZloaUOm3LNfo2gYFC/blObo+w2RyIAK+LtdSk25LFqtjgTqq+OKBWGY69v5e7gowsUAfCRCSb2blYZpxkkGOHHJEEfuWNIaIZadibhyx7Il2Uwga8bdtOkTrDV549R7OV2tcvXUL2WQuNSdJojydSCVM4qhNGFzwM8QT2Tti2Sumd0UVd/8dMFJOevsOEsvL04X7qF2bPwfHhqpzjpknlo/3p5hPJrh17YqUzlWbytKBhNFL9x6hLjcoL+4rHZ7qVf5zut5iUdaYTSfYm02R5zEODqxA1ibPEVMt3JGwalRMi+SO7BNqI0CLgsRPjGKSIuGcZPNoR8KCZq7AlldwS7Hu1oU4Hccn8vLEP4l4FGWr9aCVQmK5JJkcoVjRPiBGJ0sYWgRwijNYEIkYJyZS8Y9IuQyYFjG6jO1KUaUR1iTZ2H+0nWha5EmEbZtpdrB4o4qFORsEcxsgORzrHCRj1V6S2W5M+LIicha8EcHJcfcWLX2vQBIV2KZqppKYbbZ5kuscplZtqGjl9VRGhsUxxxiYTiuNDYlAEnu0BShyI99JLLN/1mWpwmoklmlVwXaWlQVoREiTWE4TEcwslkYLD/V522nPWK1bVGWMzfGeSHL1hewC3ABx92CWRkQ1bIJukmFVN1itaS3TqsNV2DAl8W1+6NznpGRXdoHZfvAfrj2NEVXl3EOcPYFmughlWwemNHWRCRHdfkGP4mi9X2dmqSHa38uYH1tkftVxznqfYFOS+zQDF/HwRTid/7S3yuG4lRX3Dc4DzjO3n3dGUjLbgbYo3L9qEstRghsn1xSIKvKPYJVoebisFrs28zne7R/eT90K95FQtuKpJIr1Ce/zbE4O4Kk7EvgNrWEqvXexWbrm27hxnkj5S+uRZILZdA+3b97BfMaA4ETEchM9jxfv31XgqOlZfI+fN4sYqtkZGEh5T2Pgg+13RD6Hnp/NupifdgEIW8vKTnA1TYmyblXMci4bG87LBEXKwCOtf3LdA6Rsb3or9KoEAP69QcN1WVv/kGTnjs4uyWl/Iym3EcsMdlmwwlTg6gMGoJRVZHZO9tN5YdsNUaYsAQEBAQEBAQEBAQEBAQGfCN7wEyVVuXxSbuntKwGpUxyKBLAH7ZwEEZWXNSmOTinhIkD7GI0rrrbT1Ll0aj2MO70UiSASyI6Isidh/2eTdooS7mm24TkQb868e1DenWVgX/Q3pZ8PthqjAmWuHZYibqS4ESdGPBdFJqUyVasidp1ImgQcVclp2iCuS5FuZhVihK5I3YR9R9Unz0eFJdWMVJaxL2usNrUj6UgSk6DKkGSFI4BptVFhu9mI0EoT+i+nmPAzSSpyusiLobgfSSMrKtahqUvUFb11zQtZhD3HR363rhCjI5OotBvltYvYr5pGZFFVNfJTXq43OD29wHK5kgrQFz8UYTiu+jWoIC+XevTvGynmCJFLOf12bv/f0agZIe78S2n7QEKZ1z2ZTDDNC2cBQjsRqnFrPDo/x3Jd4eH5CvceXaCnur3eitAj6a/5RsVoTILRvFGlynQqQxJZJGRpTUC/b/MuNTVtW5u1AhXcNUkkBjU0tjHyLEcbteiaWm0VoewtL1w6euJUvBnJV+dtrb5yni4i6/iScjnR2knjQvYg8lqVB26HCSXJUlZKiKifuRSTEpiLfM5jRwqSLCaRmyTIYirWIwWBmi7Cmn7bfYxCiks2waI4mrsklmOgyEic96BrK39uafPQ0xPayDQRaCoUuCvbJ5Jc85EWBkZsWqE0U76yYVorQ2FC+57Rls5vm13bmoVHl9GonK9OFB7bZgSc+XLLVoPK55pF26h07xzBa6Q0Mw7owUxQ4TyUgGRArLEgGANAm22FPsq0Zn07NFUdx8s1mrpClnk2Edu52dSI4gZITQ1KBbzfH2VbQNXvYO/jl8olB/KRxchjhezcvsZ+5LyjFZBfP4qHUT3rVcaSituY747gvaXdNimFKn9aIGBcFNUU0FwPTq3slPeyZHGZJ7ZnOuJcnvfW3s12K596Wgydnl0gzVPMZzPtE/uzuRTL6y39iWkNwbuDtwfatVS8tvZl5wDhMik8VGDSs+pu/+Ylsy2WceGLAnIPdP7CXGP0gN5WePX+Q6zWW+zRI39vrmyP46NjXF2ucOP6NazLCo8uVkYMS9nNwIDs8M1nncEdZSA4NbJrv9a6I8hlp+FsosbFDtUuquurGqvVGil9zDl9YNYuvLfkmRVPZQCE79n9zt3zZP3Ne5+ZZUhn7MZSN0NXlJb3IBLpygTgOnYdOBQadN3H/cq8tW38bBwDAgICAgICAgICAgICAt4iYpleonwwZeEi+WyS6BBxwgdi02UlNHQgsVmRrGpc2nsrpaWni/zDvqcTlDbuxXJeadVTIconaBJr5kVKk1oSa7Gx2nov24mTHTHiHpy9ss3O4CgeVxCKykengjOvaEcuO5KFafkkDjzxzM8d7E9xdDiTSpdKMacNRpakONw/kL/rslyIxM0z9glJrQ5xwkJKVLVmZn3RUmJnZDvPsViv8eDhuXwzpwcz5GmOrJihmO5Rq4ama3CxXOKF519U//Nzs9kUk/ke8nyK+XxP5I2uzRErVO+R+N2sFlgvzl06u5HUHIuK1+cKipmyOTFrExE9pngmcb8qKymjV5sSq3WJ+/dP8dEX7uLegwfYbEq1LaF1xEAuOj9dKfrMbmAsZh6zapd4aF+4UUPoUrd3s8O96GfKuWXkKP2GDw8PcXx8xeZhw+J4wOlFiR4lfuCHP4wXX7mH9bbBcsMxAQ5nVH9nmO7NkE8LpHWPRDbFsYhOnpuqYI45i4GlcY/FosG24oSlQtsUl3WpiY/N1grSTSZAV3DMcxTTyEhQFurrO5RS7Bv5zM/ye1TDsn8nCUMJvRUh1IwySlmEGf1pSRx2KeI+QZEdYDahP2yrOdXWMVZT88pmv9PreYoIExY+zMwwmaRrJuuKCFmRoU46nGe1vKOV4t/R1qPHaVnJRoOq4JxrJTZ5tXyJI7Pe2E9U0c7sIMBsgUrt1qvrpDC1gIsV2tS10EOcBKWsBIzApa0JAyybTYSGa2lihfE4hsxGkCbaJPeucCeDLx3KbYWCTDlT/lXUs0MSs20J2sy83Hke2mls1rRhiFDTwyBLTamcpNibTVBtzUJltd0Oe4XsX0p6Y3e4WG5UoPMAVoDTAlFGtlJpTIUtAzpUqXe0GJknUkY/Oj8zxbkrMmil+TiGMfI4lZK401hYQUILxvB67c870a4LojibAq8qlcKXxdnokZuScDdS0YW3zKPYBeD4HSrmd7D9zQTP5tVrgQy3ZmkPIsbcBZmkxDUy04WfUDn/awY7lD0iYtnITQUD0ePh2TmirpVtDPvt+rUruH3rFmbFBLeuXJHdyv3TFSIW0KNljEhR3w6nxHbziDYWVObKbcaFKYYAorfzGGx8rAdZFJbXwUwaZYmU9HV26vyqQ3W2wmr1IcwmOcptqeKgb3v7s3j69m0keS6PpvuPHuHiB39YRU5bTjdmZbS2MlMVn3Qqb2YkuPsJ20yLC/WPU8czCLVTnTs9P/9OC5rNGvcfPMRkkuLatQkizidauoi9tiwUBh5ZaJRzkwEsnpU3aq5zZgHo3sd9isbqpinXHN3WVhhXa1HF+lhY14r++aAOrZ04Zxhcs+CA7a8BAQEBAQEBAQEBAQEBAZ8I3ngOrCOHLIXcFeHyakNfNEqFl5iqH6FNYxU38h7GXjk1pDOPlVyXRXPm86oUev6HBJsVSpMuT9JBp9oaFG5G9EoBN2atx77LIyMG83D1RaJMGUd7D09+e99RHdPZFJDYjCgNlSpwd1R59vYR0oYKM0tzNxWepXvrj/JHtX6QMtXZB1iKvSlgjV6yomP0ZKaHMslb/eSroRKyEnGx3ZYiVOi13KoYHYmeVD7O2+1WhRZ5DBGEPDfJTKkEzcNaRNVI1uhLqlnRKEdWqrChtVsWFAwSVLWsBHzhP68cHGsxpXgdF+p7vIKkt8twv96Vc/NK5fH3R/PEETWTIrfic/R6cOrBalOKXIsrUzUvNlusyxpbkuwsathGshlJHSFEWCE4U/yZ0s+KuXmCyxcu2zXCK4pt/MndkcShMpbkjwIJTjE+XJtS/b0nq7HonjD1zga0tqBinGNJMonjmCaZlNks1Eel7dVrR9ibTxBF9NRukeYl2thsVjjH2NSDaYQipeqYdDWJYfN3ldJVa6dDzsCKSEk7v6aiU1PzW3wxyKM2kjjkvKRK3NnHcO6xzSLDSWA66wZZfEjhbBYG+qy8X0nI9ohJJDqzDCnxSVxq7CK0Fk0agj+DK8NY/a7ih62KI2oXGO0htgfZP/Kqli831yvnPtejEWe83nSwqPBzdhgq/Ye2A1uS2Hkmpb7sMNwcUQBjKBDqlJ5OKUxSkXNBcxRGepuVsdubfOBM82dsNXHZGGas7x8se9yFjoM3O8sZR8S7eepJap3TF8YbjuX2aClpLRhhxPrlfjD1rY2n91geVrMrBKdAiLxdqOD1XuGS/Ev1vd5s9VPniSL5Gs8mU0xyemEbSc4CepwbzWAsYW0cZ5JcIpXV5+5z3kpGl+NyU2SjBCSDX/9gNWzEvQJrDaIqwsVypXNQOWx7aYLjw0ONP72YufYrBkRMJu76xQVFXR/5/vDXblkutqh9kNJPYBsDK9iqIFlbK0jmP0simQGQNOlU/FGFKxXQGFtHjQyxte52WTqWkeKKUWrs3GiqLq11hC8aavsRVc0+MDE2wQ8ICAgICAgICAgICAgIeIuIZT4MmxKsRrmlMs0RNSKTTK3FQlsEH6qLqkUaUwlYiwTlgzO5CCrJ+ECuwmwSf+7ILKrU+JA/mU5U2d5xo2iqWi8WLGsbkohUbJnKjin8VZ+IrKL/p68RNzK32JHKEsQylZ4P8UYsMgWZ7aNtgaUzk6AiiUD1L1VvNc7OLlBXW0wnTH+najlFkuQqJjad7zviM0Md81paltCTsW7TbK1IlxcyOkUuC0RRvUl/6v29wgicmKRZi9X6FHWzRlm2UtqV9QIRvYzR43zxEJsyw3a7lg3GjRtXUNdHmBZz7M+PsN3WePmlF1GWW1TVVu1I8gniYoq63aJhgUD1FckNKoCperUifiQ3RZDQ6wSJvKKp1C5k9VFIrbvclNiUtVSAHANPOKpgGclWWgo4AtE7I+8K+3mThF2ZqDGl5dxcTSVp3gg2T9rWxlakYo8bt2/j+tVjnC/WOL1YYbnc4O7dU9kfbOQFDCzLBiX181J3spih2TdEsrYwxXbq/JkZ/KCvKQmvsnTzmKSzJ56liGShMlMUs8Af20q/2bJtkSW1iE8SQpz+VqDOlPf0C65rR06RzJXykarDnYVBMs2kRH/yzk3cuHkVe7MDHB9fHSwX6F9+88YVzOYTtHWFrm6w3lY4W24swCLetAPqFdBV2G6X2FZrVHWJ9XYtBfVywaJqDc6zEl3SICLZHrdW7M+NA4l3EndFESuwIysXzgeVfiRxHmM6JbltKf5c9ySW24GoauSrOy1IjkfIkaHhFa+3KqhGQSaLGqoA4rZCTfU2ckRdquY7MTRiT3Q63/SuMS/27bpC35aIaQGTTdFRSSoFbS9V8CQtVEhuvap0TcuLDfIsRTuhxYzRa9Miwabk2o/Nx10kvFPxIsL5gh7BtfqK5+aeNctzs8PJMq1lDi59tM3XvJUyvW0siMb+IMHHAn8pCzpyrJ11RMY9ximFlb2gYTPy1he1HNtk6AetcUSYcp06UpeWFBoP8zlW+3Ut3uvak8feI9nEyDyOzanMlL3KHqGivBbBq7Xvzl25gpy0KjH+2nxb2FYjlfl5enxboM1bYnB9XKy2ePneI6RZjtVqpe9fu3Iib3qOeJGusCoK3StKWsZsNlbU0hHLsvWRRYovOmrBv0Fw7ayTzBffFfWk/UvqlNfcX1uSyJbW0pvQ3QXVetRVgw+9+LLsXeq+x9nFAteuX8X73/VePHj0ENW2xOnFBT74/IsioKMkRxRnLjPCEfzeO57riHc4Ws1YRHQoNCjvcOfmJEE47zPOJqNqSmSc3+BaSeTznhe8TpLT9Jry/uwxUqZbaDDJdNNmh5bK5pvP+43uzZyLvfk4c2+n5zNJ6i6mPtn+8eQyvcy57/eZciJMdT32GwkICAgICAgICAgICAgIeCuIZREprto9CQl5wI68iVV0yak8TYHFh/1Wac2yoHAeriSUXF0+U4Q5awnCp2ybcjNBSyJDGcKN0um9CpjwRcqYEp4OPsGvVyDudd5xRfuMPCSxTE9bpyAclGbeEzaSUpcksJRkqq4WOX9e+64pE40E2olcjXgw/1NK3S77iYoCEhlP+peFoUx52vVULMeO9DWK1gh38gutCLW6rkyBx0JZNVP/M6fobFBuSSJuRLBQ9ZokGZKUNhpMFY8ujeGQrj00ykaC7ZCXc56bH7RT7VEdrJR/11eDv+6oX20yXPZW9l6kw3xxXq6vUSaPfo79aL3CnK2jsnc6mWKxKjUuZVlhtSGR3mApj+UeLYMSrkCjfIk1d80H1V+3FVBjcTxXMM0V2dO3zJHCeWW7MXWeqiL6SMTK6oWkdIe6sj4ZxteRgHppUpnfqVewKvBCaS99jycTpFmKk+ND3Lx+Ffv7R7h69abmZcQiiWmC69dPMJ1N0VQV2qoW+TlblzqV6poxZX5zjq5hAbspNpsltuVWBSZrFo6sOTcrTLIMk7RG1dWoWqfWd2vG+3B7tbZfX34M+GfOJ85Zr/r189jWl2UU6HMk9KTS5/yza9b6NXm4yGj4+eRsG0RMDvJYJzUl+W7W6k6BbMEYEnD0Kh8KfKoAHheIWczwxUCD1KC0ZnBKUSvA54hKN1dNeG/XokKIzhqDL45vRvI2oSc2HaadAnWwOPBmEbaHmRrUK+BNOSrC0alwNTV2Zu6OpHR7wWj+ex/y3cK4TPzZ2vFhmpHFjOamWVT4r4i0HYI7PrvD9itWLvSBniFDg+PinDGMMh5nGdi8tvVtFhi6zsGwhpkBnQJcW2Y31I1IYNoPcd3y5+BxTyuizsbGJQrofF6xbMTyziLJ32O0BrkPc45pjZlS2CuYZVXEueTnqLP30X7uBpzFHhkwvFgscXZ2jqMjBucmmE9nODo8UC8wKOHKAugada2Dw4jXj9uBbT+lHYevGrsjwmmBo3J7w7pRXoplfThbKFtXtGxhABbOz931u2Omjcu2TUX3FX5m8N/2Ps+2r3OO0WpkPE8G73PzKrJjSMXsZm/glgMCAgICAgICAgICAgLeUisMp6LyRZVIdvpidio9R+/UmjIxI95ovzAtLKWf3B4Ve0oXbyJUoB+kkZRMMTa/TCt4lKQsspUjSzKRhlTZ0u8SsUvvJxnC9OrMCimVUYa4ibGua3Q0jnVEnhEJnhSxh2eSDlQck+TYm8/UfhJ28hLemp9uD6ZHM40bIqf4IH5xQbK2NoUkCbOiQyEimETvxrqHim6m+JsviB78TTRHL99W6kPZZojMzpwCtEAxmdp1ZUaG5CmVkTGmeSYV5DQDpmmHjOTj0b5UrHk+0zEmkwQVC/RVDTbLrbxBu67SmBwcHCEvCkTFHqJijrp5iKY5Q1V2WK8rrFZbNF0lstoUgGyfFcU7PtzDc0/eUnr4wf6ezmmp5jsqTWpiMS271G+RQySASRiZGbXzQPW+qDtlop9SY0KNDBBJFc+48Rj0O+bPq7MTjVMfx7j38Byv3j/FS3cfoawa+QW3ZBwLue1K6U77Bx6cAYPY+HkrYFhVoGsJ+79IEpEwm7bTWFGxTFKUHs78yYKV0xltWEg6G1HJoAf7wKfK08eZ85rCwqJxyntnUUBVe5RDAYBmu0Ve5FJuTooJrh7fkBr85Po1TPfmePbtz+LOU3cwne3h4OiK6y+S+hGm86nIZypyVUDQEdqEinTRrqVcomsqzQeq1alaX67OUa43ePDyq1KPZsWH8ODhIzy6uMDpYiFyi17oHLOM/r0kq7QuqT01tSSJMZ0q4jm3WlObeotNU6FmwMeRqaIkGcxgv6UWgOmiDmUWYVskSGojy6RJV/E/82snu0aikYUAeRwq4zlBON/ZhxxLnpTq743WNdfmwhG5mbMZoaf5zsaGJDSVy1XaiFjnFKQiW9ehApBmtRNLi20WOl61rWKNyxJRtMA0zxEd0Bc7QV9Z4MzbH7DwJtcbz815QpscEtCKlaktdl38rvmZm7Jb6uCBVN4RyuO/qyjbuKioOGBHqjpbnZHQ2eaB2+tEKjuV7BA0sqjHqEgirylFF7vMBWdtw39IKvNcXEmWFDI2DdmZonuLnMH5xV0EVbNxVCmj4O6Dh1Lcl5XZ8zw6fYRX7z2QZ3PJwJkjbI04dl7QriAjC9gV9M92RSTH9jR236FKvJECV9pmV1CWHtgknXMqe7kuWKzRBXrAYIS8kq0PXrn/AIvFQt7zLMbINfau594uD+7ltsR89kDF/KjC7iLuFW4PGwKYrkdcwUMVPXRWFLw/msLcrtPucT7AxsKZ7KOF1O37hwkm+cxU4BE98iNsGTe65OrEPSWWdUjH4rhca31q/vm0puk6VMwmaTrEuu8yfcKK/flx3xXYdMVcudj0fwDjMF5AQEBAQEBAQEBAQEBAwFtBLAtGDHr/XRItpgZ03qleBeWIBxIAtBUgecCaYk0MrBIqFd3RRGbsDJb54G22GSy6RZsCKxRoxIJjUqLRgzsVrCz0p7T81ogfHktF2R5/VLbvWQG/GEVBW4tMxfBIZreolRpPpaKXqJqaMhLRzAf3+bRFO/HkgRgY9CK6vMeo8yp2/pjusix9PWIuNQlaqtNS14e7V1KYGtQcM8yrloR6ykJl/UwKusODObI0Q5bPZAnQ9bW8lOkNuqlLeSwzbZ7Hnk6nmM33EOV7QD5Dka9FeilNvGr0osOpVItSAdJSIMa8SLE/LXBydICTo0MU9HA26a6Ryl6JOyg3jRCWJ6tXJw9qwR1hpL4Y/rPzSB1TGiM31cGP2GwFMhyeHOrn+ekSi+VGhdbOFyvzOpYSlNEOdxQnZfWkim+bWWuwoJ71Ecm1rcbQq2K9d6q1hxYVWcaCjmZ/oHln1sGDepNEjtlB2PVKtU+1sQjBRAQtui26ukaUpVIq700nuHHtmoovPvH0HewdHOCZd74Nt59+Evlkhvn+kdlvuBT2mCSrFLm+uN24I/mfDm25Rkdbg6ZC01DJvcF6dY7NaoU8LrC8WOD8dKHia+qfhgUYWwV4vMMxPYDlZ0yyVAUuE33e6zObrhbxyJ8klWnpYkGGsWe4qUnZ5zwya+5xvckOgAw9iWrXxzw/7UlIvvLyTE1vQQipTp2KnL+QD7QU563UpkYw+s/yvCSmbb3yO5zfLJZJss4U065ApXx/qXC+bD9hGRQ2p7dUv69KFV7cLyaaV0lnGQBewS6fYFoK8Jpp30A7IB7P7Q08G8+jIqEi56kmNRKchOfg0zxSo3oi2PyvnbezyyDgOrX2jsxkdvb0ztPZkcvOc36cEOCV5Tqq2wc5D3SeUZaGJ5fZFi4nCxqY8nf4zKBgdkFEBZzsPV4nAwabssLFaq1AHj0kuJfTmmWxXGj8FT9wxfceV1R7/3Nmuxix7IoeOs93FppUkI40K9dzFznPYgYmXXHDhNkysfy95Y3Pvvft5+f7Xu1br1bY29vD/fsPcHx8hNt3nsFhU+PqyTHKusJqWwHLtdairRTn9e/vL87PWoplN5ZDgM17HQ/f8RuSZZ8wq6DrUhwc7elepCJ+Xab1XFXOJkVTg5kXFi0wVXKLpIs1Thj5YjNrha9OKQO2//k54pXgfjO22gGWeTQEHwICAgICAgICAgICAgIC3ipimSpfPrzOZq0KlWUZ/YktLdkX9iPxQ9JVdItqB5lKmPYBuWwkIGKN5AoJOlM/k4jkJ0nSWaWudUUPShKmLHJEdVaHOLWUYhKnZFGkZosT7FN5KqLaVJwkGqSk9IWNqOxKTY1L0sGTFlSekqwlcZpkKSYtiQOSEduR+pYElKMSqFYmUcjvFbmIW/2ebVdxNG8NYPAP67IQkR+pkxH61GWz3h0VEXTknlE/pn5Th5Lgy0QoT/IJ8py+znvyMG3le03prPU7i5stLrajolIk9mL5mc4nBQ7mc6xXa9cGq0bGqzva38PJ0R72pjNcOzyWmpt/JvEqiqqjjQbV2/IlGUgmkekqpmXt9t6xA4khBaYRXINNhE+ld3993ALA/FqNXKe6984Tt0T8106Zd3GxxqNHF1iutkbUy6LBWXw4gw4SUQNZ4lSgnqg1T+VOhe7keewsTPiFqjaSSkUn3cWYe0OPuq3piGqKSFqFkPWPIvmbmk0DSURTQEeNI6Do69zT9iLBZD7DzRvX8bbnnsHR4THe9vb3iVg+vn4dk/kcJzduYnpwRbYlUTLdEcnikkfknzpq5x/uq2FS8R6TTMtaxF2HpKiRTQ8w29sijWcoN2tk031cnJ3i0cP7OD19gOV6hQePHqAsS5ydX8i6oC5LlE2LIqOKMhmIRRXGtBJ5GluuoaTl+vYWLs4WRfYrjhSNOuRJLwU4p46sBbwgnevd2aooqERFaMd+tmJutJrhsRik4R5CNTFfDPZsSxYxBLLWiEl51XKuar1R9cn1SjuZCF3JOUv1tF0DAyv0JvZWCzYlL5uvcNltKp4jwnpTomVhtSJzftv2WXq++wKYIjBdQU4tWa0VmwPevlZELoNr9LJ283wg5X0xTOc1PNhYuBbJ5ZuKfBfE8VYhXvk8WLW46xGR7gM6oyAQ91HuR1qt3Hd7egJz76WfNslZyzjZEaZjGwx3PK1jX/jNE+O7IIcvlretKjx6dIbpJMfR0VwF/G5ev4Y8L/Dg/Az3Tk8tk4XZEFKfu2v2xWE1nxmAo7rakfVOxc09UN74Zo9tHulb9iEjZ3YNuucwOMlAWsJ7iHNkVhDM9hcjhHs8uljgxz70UVy5ssB0PpNy+fb1mzg6OFRftU0tv3sFHW0rsYBPbFZIgy2MD7qpW/rRuPtijqasz0l+g5YhKzRNgmpbaK9NkGFSzJR1skk2g+e/HT7ejTnvH6TVtVbs/krimUEtXqfZmFiQi17SvO+ZX7z9gvfEIVgru5nR5AkICAgICAgICAgICAgIeKuIZVN2Mm2YD8iucJinEqQGNRWUilSJVCYBZcXQSC6JkMtIApIsdR6uTuJIKrBuScgAm3KLWgX1jKAhYUhimb8k0UxKhsQ2SRE+NMuTUkrDFjXtKZrWcZ9MQXfp1SSX3cM3H/ZJKpOgKCa5Uv6rlqrlRinRphbsHOlt1I48ReV5y+8ZsSzSkUS6U0r6p3Nvh2nqOCtA52gk87KVctgR697T1jmaDoWWqD4mCdVTWWvEclEUeh3s7yMvJlKmUrFs1Hcr+4PtJkHr+00qywSzIsfeZIKD+QwLFdGyfqSNAltMYvmZJ27g+OAQT924JSJaSm4SOyK/mh257NlWJ/P1Re3ix+TInoIzwsuxjmNyeSCVd4yGXb+9+O5sOsNTd57SGLx8/x7qeo2LiyXu3n+EWsXvjNylPQDb2aiN5rnqqO+dIlQ+0WbHwkJ9kyICr04lrGKSar2IZXpq8/c2ta0tJECrtkZEVSrtLURaUVLOAn0dSkcsZ42NHllUWXF3tb5PFThtT27fuon3vf99OLlyHe9+32ditreP6f4h0rxAMtlHnM8ca5W8Rk27s0t1VgRDrzkVv/MAdlpp9w0qrRscHl5T4b+r129hu17h7NFdXJzex4OHD/Dhj/y47AA+9KGPyi7j9GKNzXqLbkrCj2tBUtHBAN2rMc0rtxv5xpqHrQguvtzIklgu6BvriGU5nTj1u4r/aT6R9DICn7YEsrtpE9mYcH6RXObYlC4osNmSPIQCSjzXJE+1LuWJ3pm3c+8Und7OhvYQbACz/9n2VmvcBz6cat3RuJwjIqh7YJ1tZbMyEbnL4xqBKW9ul6HAfYVF4ngOKptpz2AEsptH8p1ne2ir4ork+RGOnIJ5UIyONPzuPc5/EcAM3rkCcp5kfNynfGd54QhmKselxOa6cG3mPOYa4XFpLzOw0sOqMbKXx3F2GIPX+MAn7xTAQ66Bmwe8HpKwDx+eYW8+UdBqWhS4dfMGjk6O0X6kw93TB7t1RJKUynm/X/o5riKpQJp7H3dTLlPZSw94ZcxwTpYNViUzNnitZlnD4q+0ziH5ynPI13voLrP5oLUMffwfnS9xcb7E9YsFjo+OcHiwh9tP3EI+ybFcLrC8OMP5coPNdqMrTcDME5sP3rbEb/o732zb96SyVoDB7pMko1kAlIVOud4YGCm3hyhSFlrNMM3ZZo752lk4+b3A+smyRozM51pgHyiwwUKRJJaHrBcWkjUrIa4RBSZEKtudSGvM514M96+AgICAgICAgICAgICAgLeIWDZCRX+yomZOcextBDwl4b0+HWWjB3qRjGKRSTSzGJ2RT0yN916dRrjSnKFD1TRolartc3mtmpSl9rpiTM4Wg4pcFjmbZKl8iUmYyDdUci970OeDt/eDNocGquJY6IseyXzx4d0Xl9upz1SYSj/pk2lEAtXKJHcnxVREVrmlz6q/evPS1AWLa98VsdvRylb0S8WbpJrzthlGxfqCS2YkwOugT2yJKO7QUDXbkUg34ti6x4pwkewjuUYVLL0zVdiuukAfpyITaBhx49qx7Atu3rgmlfbF4kx+vEyfphdttS2xWW/0Zyu8lWJ+mCMrqMyD+tCK+V2uLTb2br2UCs7rUr626TP9OHvqzFSwfpSoriOBnorcZ39sNxVeeeWePvzw/Bwbqmmrepc270l45/UqKwdHBotA0cupQEk9R0bumaWCU1K7tvp0clKWVCwnUlWbos/sPC6n7LO4ZKY6jqawlViydqS764/ZZII8SXHtypH8sZ986mncvH0HB0cnmB4cIp/OkBT0+KYaVkYKb4Dfcevg0lsjiesl91fzaY6zQuRwsXeEOJuIUMwmEyTTmZTgy4sLNA2wXC6RJQUWF0sVyuOcER3GDtOYG21NwjdzwQNbo6Y4lc9sS1LVAkfkDFm0k2u0TalOtWukj6xISFcAj6rvxBcec2uJhCD7lPPU/HXdfjGyGJA1gV9jsmAhoe0LwjmfdedQo/FVoUZuWqaoNuLVexnvSFUeX9YbVKrTM1iWHO0wX312gfFyFvjyxhq2nfl57Yqeup++WKZVq3OEvHesGLlN2B63I5dFlFK9za+5onXylt59fEcwy/XB2uP3peG6lFXSWfCGsQJHpe8sKIy+1bodvOm9Ut0X3tzt7cOeNxCpwzvaG7mHyAPZBcgO9w9wEEe4d/pQVkRcQ6acZpvHY8YAQiv/dEsMYOaI+XazrSRUOYe8bQozGvIiE9nflK27V9m1cE8gySw7nHi357oNyJHk/UCG0xKDBVAPDvdE1O7P5rh14zqS9BSLzcYFE/yOtbP9MeWyqYp9JoqU82ykvy0oe8fdN/lnqrJhxWHX6xL7WYEin8pKhPdHdoysYFTU0O+iO3Zfnv+gtRPXJX3eLeAiCxtf0HBkQbSL/Y3HdJT9EBAQEBAQEBAQEBAQEBDwVhHLTP0mJFxkcS4plkko+fRv0/XyWZ3eovLkbVuUZKzSBJ18YqlYJlHUocgTKRBV5M17rjK1V8cq9WBMQsrUkab4U0E214ie5EnUYZomKJLC1Ik9iwI2aoP8Y6k8RIeiSFAUtL5I9GBPxqquWEKQxFFlSi4WL/Meqo44Md9MR8b0LfIiwd7+DPv7+zg8OBYhsN2cS30sGtvzw+LPnJLbETEuUd7ZapAUIVlNktiYCCOXRXk7ooafbVE1GyzXF2i6Ccp6ijjpZBGSyH+ahBdJevOvpa3H9es3Rej8+IdfkL3B4WqJ/cUp9veP8b73PINrN47xwoMHeOXeA/ybH/hB3L9fotzWWC1WiJseE2peWWixJ2GT4clnchwwTTuhgjjHemsKOKkHnRrZVMaOGHfjI4WlilkZwe9lfY6GHaV47369tzeXBceMEtc+wdnpEo/uf78pWaVG7bHcbtXHSv+mLpbEIa0nSCC1VuCK1FnPonbOxoKEct+n8mGuuxhVaz68nDRsOz2Aze6iRdSycBbJKTuHbDJ4dSTUR6R6EWciV7dO3UzyjN8XYUj7kSzF1YMj7M3neO9734Xnnn0at556Em9737uRTeeYHt5AnOSs2mjhmMEk13XI8Kexovv18Drveu9VF/1JJimSvsfBdF/qxKP2NtquxHZ5jieeeQarxQVu3HgCi4sL/NiP/DgePHiIs7MHODu7L29crhFdsyMh+ZN9UjodJdlRqizjphMZyBGeTbkuYxXZ5FyiOr5sEhV321bm1UwiUGuMQR5S+lKcGm3ZswhlxMwCBoToVW0qW1HbLqBlReeUlzBY13DOErJT0JhaQciqM0/olONG4k1krfWfPIW9FYPrQv616Xp5BXdNjIb+tyreaAX5ai5THlve8vRnpwE3rVn8vDarBe5DPmihueiVzhqbkcn4wMw6grUzmwNPZNMqwUhjI5grBc984MqCTDH7Qzp1Ryw7na+3p2hde7g9pdr6HVGu06qa4WB7YpYynkJ2ARWXTTAorr00dzQPPe3atA22m1JFW7uqQTwFbt15AofHRzhfLfDhFz6q8TZVuhHLJH+tAB5tIpiN0Wpttq2pk/ln476NsKbamr7rWZxinrLYaYtls9RPKtaVcdDHKKIcLYMSae0CBi6gJ7sP7gWcgx3Ol2v84A//GA72ZipgmuIqbl65gitH+/jRj34Up4sz84/e0FffEbe0wEkTEf8WTLP+ItmtwAFvSj57R/cxC7pwTItJqmtfLdfYrmrs7d3Ewd6x5l2en+p+6L5uWRlSG5tintdFKxOS5rNJMQR5s4qWGs4+xQUHzEbEBWfHcSe3V4sQDwgICAgICAgICAgICAh4K4llURWX60ENtJc9l3oV2OXv+SJu5udrD9ed0ujNlmJMR3jy0ZMTJEI84eOVxiKGff0yfYcEs9k65Km32aA6lcW0TNUlKwx5GY/8RqWQpHLYqQZHJImzBxbBM/jHOinucD0uJdkrny/XPnptR1yyxxj5nvprHxezG6S0l8RkRkyJnPJtHSlU1TwpwulVS4sAqgETlFWF/vwCcZzisGZ6dSOihMT2R/f2pFLlyGy2lYwhzrO1BQZikjaNs9ro5NVJkpQFuIzkchYCI6sGX4TM9+HIhtX92fknP+ZH7dP2izTFVErGWORKSxKHtgzsZ1d8jWSilIAjrbPSwx/rffvM4286JaojMKVsHObdru91PirtXTFIK1hpxM644YP60RhskT5Mv2fa/3RS4OTkBIeHB7h67RpOrl3D/tExitkcac4AAUlIs1jxqsFL1iGDstAXuPyJCebHBcyXF6D7j/MUjkhIxc5iZVJjNt8XAXd0fII0zXBy7dyI+KhB3W1RlRaA4SD6Qp0keKlG9sT7pbH2lgrO/1ZWBolTLmdGisoqozO7CgaZRGQ6i4ydC4QFhXxXiFhEIhKQ69nsvT3hafsLiWKuc792d+vMHU82Czbmo7Khbh+zd2wN7tTLmisiPLlXmGJWBKHsfWiT0qOUxzkJY7sWm1ucs968wCmmR2tf4zF4WlxWFo8VsYMa2+2v3ubF/j4aeRHs46N4nb4ngl2wTFkYzo94MKPZKZd3i9f7oj+WmvCaWbcr3mhtdZS2sxsiCa/gY0tSP0GR5ZjkhchQBi3akvY1wwbh9hSzC6J3OZHS+9yNq+wvNEimCJZ38FA00wUHuf875bOfO/w8CWVz8TFSWb7bo3sUz1zWtTIjLhZL+ULnkxSTYoK92UzFTNfbLbp2YdYaKgrq1dFjCxOnUHbtYmBDJK/WhCt4qzgE1wDJZ0XHbGwYxKIKO0llQdQowLgrdup9pscjoGwVFSNN0KeueKm79iHL4tLMGn3bFxwcp6AEBAQEBAQEBAQEBAQEBLzZxLIrWzRUKDLVHb1dE1W0N39lS3v2xAofePM4lUrLk9A51YcRMCNpmRthQ3Wb/Dydejly7DFVeTyW1NGteTXPJc8lKUB/VqDsNmipcE5j7O/FyKoYZW8+lXGbiYicFfRhBVKeq2lETJRtg74lgQnE9GNVCr/p/XgKNtaKAVJJHSFuTNlG/9y2aVBWW2y2JZarBerGefu6wntj+bInc5hzz/NJTavfe0Uz06lN5a00f1eAjoRJn1Kll8l+gwWlqMIk2eJr6EmVJk4jQlsbYZHKriLGweEVxOkML770PF65+xJOTg5wvrqLopjhc9/7DChArzdr7E0nODs/w4fvniFLFphkpyp6dbI3w2w2wZNPP4Ek2seN68eY7R/iIy/McPfeQ6w3G2w2S/WJbAtkIjFkl9tYxr1sDjwzYj6tXvXI3/k5YcGAq3tzXDs6wmazxcVqI4Ln9PxC/ZfmJMqZ9u38dNkXbn6oCJwUuoPHhhWlUjdZ8Tb1r1SoACn0igW5GvrbeqsLS2snyVXpelQbUESxqQxT86Vt5VysceB0oDdsU5HMoSdvhVmxj+eeuo2Tk2N8zhd8Pm48cRO37tzBlevXUUynmOwdmO0F28f55ixd/Fy4bGXhSSHzWf1YuEQLeTuOxz6zU+NyXjJwMMX84CqK6b6KqlXlFsfXb8hz+cWXPqq5szg/xasvPo+mrlBvNspCmLAgZgeUbYJNm5p/cWLZB/KUpbKUaf5ZhMJiFFZsM01RVrYGt1WCzZaeypy3qanBHSGngE3FwpTWr2nUIYsTzLMUW3kNd2hotdAYTUa1fhrVmCTANLMxXpbedMYIar/PqF/o0CIfhV2hM/OdtQCRcWx8hwp3+/uGCuKuwzwFZhPOQVq3AOuqxea81hrYNLT14BZhClizmHBWGxpzUxOLOuXGMgp02RlNtSuv6l0sZrd2nCULW1b6ee4+qf1QhflI5DsC1dnxaF1QmcvvtQ1SBttqZ8ugLJEeDX3Gnd2LV1ybv/wugCSyXny0BcJ4Eq+OJhMsG5rMlOPMbrh185osilhUdLVYo697ZMhxsneAJ2/d1Dx75dV7Gkv09Aj3lR05vtx/gbiOUdV2Pu4R3LvnRSLl+GQiTxvt/cUkE1FcTydI0hbrJe19WgWpZvJYpqLefjIrQdYtLGKnu2Cvtc7zVi1wUW7xb37og5hPXsJ73/MOPPP0HTzzxFO4enKIB2en+P/98A9itd5gveEewsBLrvYra8MFjKjElvVFJmp7iO/IBop7CrN3ciqWO6xOaT1Uoa1r7ScsLjmfzlFVJeqSJHYnL2lei7pKvvZOaa17hCnp475Ay3oBdY2KVkYsEujWX8IFy/GSHQwDIDbnuijVHvixo1MBAQEBAQEBAQEBAQEBAT9VxbJXog6FprzS1jyXvdrPq3lNteWJxJ0K01KZd6o/EwJ79eBIljVK3+bDtJVec0WyRh7I9D9u5dmaIlUqvKUbW5Els0ZQGrJXBPuU9EH9a20QwSNyhGT5zsfWPETtwXundG5FjunVOlWvu1Z3Fe67/r87VbHxxp5MGnnqus7zHqZemUaShqnWtGIYfFOH/h/Ric4c1pOkWZYjd8T9crlGmkU4O8uxv9/hiRs0M0hxcrgv79/VZiWSlUR6VTVSJ89yFirkwNK/GZgWOeIsw958hjzLZbex0aiY6ls2JaNr8Vn+4/d26lyv8LNiXCTRSRiT0J6QfFRQgV7bLTa0Nul7ZFK/9iJleHDZbDiSTv8dn2ykLB9pIc1ndaQWv6SUHISaZkUgpaonHFUMzHxefdebBQwJRAui8OpSqq6zFEeHBzg5OcL1m9dx89YtHF+9iv2jI3kp8+WLQe7GzvXdeEgvWQw4Im/81qW/7q77cVzWco9/wXlCmwkr+jebzVWUkgr3rMix3CxwsbqQbct0NkVVRujqUmOWOGmsFeekupnLbFfscvyiBQd5cV5vQXsC0KubBF+Miqp6kmb+84ouUTrqNehuj1FhMvOOpVexPGVbemWbz2482odon6M1vEsNMILZrXd+nwQsf15aPCZlHbIGLHmiH/Yg+UebTFokH3tNxSDZfreM5fFrHh02J0gSe7uEx8Z15008HiGbAPqvvC08oexHefxndwyli/CaHcHrt9BhLzGfYU+W75TTZsvjCxj6z1+aY+O/D+3e/cpfwyC6Hu1NHAsqlrl3MWhX17UF5rpea33OOSU7ol2Wg9/YTa3OaWBrUVS78yvmXKkTu6ew0CY/Y17Dbt8hcctYEoMEbhzN2juSSl6hOw0Kl4CdxxPlEvqSZO86nC+WKDcVtttKc2yaF5jOTvSd/T0WQO1Q1fTnZ7DCiiPqXjKqRcB2DeT8pbXq1daOJI7ZD47wVTCF98YEbexKcXorC9dODfkoIKHgiCvEyAAa52lEW59R0cXdvdlZE7l1MXj1v97mERAQEBAQEBAQEBAQEBDwpimWy9bSiqV3JKkUo1AFKD6RW4Gr1j3wy8ZA7Junh8ydUw/AmT3I5hn9IL0aMzH/T6rnVKjIFUojESieqTNVatyDAjBWu6dyjUSyWRSQWmpE+PKERizHtOI031YVZ/Lcq9lwSFnGdGNHCjRNjbpqpCTU7+lfyeJS9JbNzZ+ZirL1comO6jKqyRoSW0YsmMdmjDg16w2vDBMhINFfJF9QquTkASpiiynPpsTlNXrVLDyhTHX3NFWxQEsBT2X1MQieh961FGqRok69SxUbPTz39g5xdELf5RI//COvYH9vgSzdx3Q6w9ufvYbbTxzhhz94hB/98B4uLhZ4lQpCqrlp+5tFmO/t4ejoGHWboGoTnB3u48rhvvp4tV66QolG0HhVnoo7Sn3u/iJvbG/3sGPYqN8mSXjt6FDENa0wqG6kGvrs4kKK9Waw0CChT59fmxNGChlJWNPjVmnyNm8YsLB5GVkRSGqkh+J7TuVK5XNeiEiikr1jYUZ6gXc7EjFNO6SuX0lK8RqpsOVnlmWNWsXCSqR1jasnB3j6iat44omb+KKf80U4uXoFz73rXTg8OUEx21ORPju/0xEPVgScP+59Vma7RAS7PrykYv6pwweH5G1Mv+OmQl3WaJpKhRxJAOZZgYODI7WAyvbNZoX7LApZlvI2j9sec6TokxZd3KNlX1E5yYBE5qxuSJ6RBOb7rRXXo13NZi+Xxzr1rlQbWxE388tuK5KxMWpH8Fe0mmh7TKXKpM81cJCYn3W32GjP4F5SFECR0n7BkZEinbmuKxGFLHxJj2UqiZVZ4UjmcYHNS/y8+7MRjT3WVFBrrhWaW3meSZWbZjXOt62Uy6tNJVuDXWFNC55wCWhNSAVsXvGmH3ZBq8dIXEelKlDBfuH5RGBzJvdm86CAm7PmMPJZrtFGQjqyUUXceE6ez81pepUr26Q3D+CC/sBRhIqBEafwZslGtdx72r8mcLGzVzC7DEd4usAb1bfL1Qqv3LsrewYGi1iY89W79+UpTkL4xvUneIV46eVXbdypxHWBG2W9qIcswKf1KusQCwIyG4WBQlqa81jyMmYmh7Ok8RkfjOFwH+SLKm36MXMP2WwtO6Vp2K+uIiI92c2WWdzruqrkBf6RF1+Sp/2d29fxjnc8KTsgevYv12v82Ieex9n5AovFCuv1VnNs3ffIuxRZOpM6OWVpgV1FWzR1j6qUHbdU+Pzd/nyGvuhRZAm6pkbPC6wjRG2CNC4Qc5345AZuqSKkOccbd68wNXPC6JtZcWsESeD7cITPppHHvYIQ9P5udH7ziw7EckBAQEBAQEBAQEBAQMBbSCy3jalLRa/xJ5XAgzGwS8F2tg7SwjnCRrYY8lJ2Xpiy5zUFIslS0iB8/JUtwEDyeKWVkXCWym5mCySSyHaR1CORUKgxzuNU33d+uIOfpvPVHXn/0l9TBI/Jp41kIonJh3oqdFUc0Lw8pQRLEhGgPEfJ4nHGaDgCyzmYihxzD/jO43PnUWo/STZLNTh4+1rauNR4g2rPK9qoJmWhQxLcuVOZkTh3XqmeonUqy51u1ZgYqSozpohPMZsd4tGjh7h79xzrVY3bTzzEwWGFa1euoJjO5K+8WG5FXrz6yj0jnujDm0TIJzmm0ymyNkbexphPJphPJ6hY0I0ewc772Ht6Dp088ofdmUdf9j1m/5AE3ptNdFz2G9WNTOWmlynnEdPbNf9cSv7OIddbHXBuOIsDadQZaGBAIFJ6/2AhsOsdkdIar4S0cStVrfkG829USjv/bFdCkXNJ/e4IML5YJJKBhYiFFNsae5Mct65fxZ0nbuG5tz+H46tXcPXGDUz39xlNIZO9K3bmfZTVIOcx7GgfLwL1mlSfAeA9pV9XqTy8/zHIIWc7YmpYsy7gqyOR3Ji9C19KryfZnqSYTmbo6hr7BwcKgJzmDxVIoMqYPZ9nMWZRgi7u0KRM1zcy0zIEjOrUmnCBIIameKnTgqRmhO2WalZTWcqL1jOi6hOXsUC7EScuNouHBC2V302MxPlvU/kp8jC1z3Bv4TpmoIlzg3NKmQ0km9WTYzfisZr0sspedgNurBjkUHE555VBdeg0jxWwmuS03IkQb6qBVNYZvCrUq1j9OhUJa8e2wXB+ya9Rjtr8VYudWtr7Ll/KWRiI5cvu5VxbpJa1jkYKbu+/zP+S/KT/tfPNGAhrv2/69bzLKPFnvjzjjGR2ga0eKMsS54sLEctNXqDOcvkWT6cX6FPgYO8Qy+VK86WOjSC1LnOK6lHYzLIh3PywaWve+LKFcC+qxF2f61u8dhfk40v9wb7g1k1rDX7Jx3hc8GqwS496FTTlOnh0dqZro5XQ3nRPSuvpLMVqvcbyYq0gZ1NVWK9XIu5N1c7sj8LuAaqJ6HrJ+bu3LFrL6yos2DphkCVz90NZNJndUdRx/HgP0Ey0+6yP1UnhzCCbe09qbfNwSuln3bHgoRXRtfF26me3k3A9DNNBZHMglgMCAgICAgICAgICAgLeQmK5qY0IJGFpwk8+8EuGZ7QXlcJSDwOtE1qSrKUqstcDrz3jSsiqAnv2clbNllju+DXPs3kriLYlgdOJ+KSHqQgQ8gJqD5VpCVqyDm2rx2cqJ+WQKs7ZaFjTQruHau9tSmKN/s4kIWsqlulz2YysJow4yaiEpFpSrAZJOAnKdsSHJ15EWO4KLHly0LhnEmC7ombsRxIBcZxdstYV4S11bYY4yRxrYsSplHUR/aNJfjaOwLMrlF+qK95HqKgVVdD0rk1SNH2CizXVnyU+9MIr2Hv0CDdWS1lbFAnw7mefxPWTQxztz9XWyYxFq3K1/ezsAlk6QZoWSKMWWdIjc2J1XdngjWvXacW6fPEo8/k0Esgr2KlSBSZZJqXyflFgVuSotlQJNthUpZFbPu3bnYMEGZWqUW1KQ1+UzAcIeM0kKYuEtigJosoCCiJN3Fz15C09SDclVawkIFmgMBJBKLV92+gnyZ2u6VFJqWrtp2+rrrdp6ayK69eu4mR/ire//Vl85md9AFdvXMe127cx3z9AOpkZqSwlssNobrlBdX/VKLq3xvLVnTXM8Pnh52N2KD8pHDnf1GirLepyg81yIf/kzXoppen56RnW6zUenZ7i0dkjlNu1vGTpjYw4R5SQfNwqsMMzT5IIjYILreafzV0uRXpoc707SwiNAUl6yA+W3cBCjW4E3X5hewHJv6ym2tzIYk5pWavTgyOiT6+ts/ksRdPGmM0yTIpU2QxGQFKRSqsaZ1lDwm8ooueksUOPGJHqZsaubx2h27ufnA/0Iy6bFtuG6mzat7AYXY/5hGMcI1uXSOWz7LIURkXmdkf29J7TlfJzshLy5L9lH4z3QIIEtlGW1spxccLLhhrOOoU+yjBfej9LGAxsI1Mqe+aWbje51g25TfqV00XCyF2OhcZtt3NemoNmn+EUwlxezraCnyPBuikbNLSv6ZkJEOHBozM0XYT50VyvLJtgPttTJsZFtZXKfGfLYOZHNmW5F/BvDHhZsJIe60aGmkWQ/4+PTXDeccnTAogqerMHifU92vhYn1qnPr6SdMWMA9FvmVkZTY3jl/fxkQ+/gNm8wPGVObJ5irc99QyuX7mGK8cneHh6Kr/9R2enGkd6uHMOcv9lwNVuOr4nzcqnzuOB2Oc+3nU1tts1yqqWDQcxzWm7wb2PQYsWNWpzs2c2hivsJwuWuDVLGAYsWpLilsVhhvwWvGGQp2W/8PrsJjjYeATFckBAQEBAQEBAQEBAQMBba4VBlRe/wDpFtJrgg6k8U3dK1YwKVhJDsakw9Weprwa9m8hQKf5ICMcd+CtzCaUSyxGvvgCc87UVUUFF4EAsGykgVTIf3LMUUWeKQnINmdVrE01nfzRlqz3Zm3pv8NmlmpGF+qpKaf4kowYpl1N5UnVHK4yEx6cCkkWeZF3BImUklfSkvvOzJMkitax5n8pb1F2frssptUkCy+NWSkhHk1JtKxuBDEmUi/imWtz3M+kiTyyzECIJaivuZernTPYe9mfzrqUVQYa6S3C+7BGtt8BHXsRkkmC9OsfR/gw3bz6Dp9/+NNbbCs88eVvFCFflxvqp7XH66BwH+x0O9hJZoUhgZ1bHr8uXGrFsym9eo4gskXumJCUhkkZUfGaYk1ieFvIwvbtY4WK1VtE+ry3ldejQTpFOlbBXhDonXSMznT0LX9Mila0G5wJtFEhqNaowZ3NJBfrqFuuoVsG2iGMXJ5jkUxF4W1oJc15Q3ciiaM4eg8realuLQJtGiQi5p25ew9uefgLvfv978Pk/6wswPzzElTtPyWYDCW0TzBRcpJv6x83Dx72Pxx04JpMvFfTzqwifmEWGk3fS+qIptyjXK6wXFyKWV6uFPG/PHp5iuVzi4ekjvUiqlbLJIOtbyOql7SupukmIFSltFDq9jNc1ctyKK/oCn2ZtIFuI3gqN8XP1JNcaJ9mrl5tTTUObHM5femqbF22S8s9mLZNHJCuBvb1MAabZJJO3tUS/DBZxv2qoKG9QK0hg+0xHVbQsBEaK5YH0ddTsJbsHI5bZdhWh7CKRytu6lQc5/ZwneYq9aYEotvWYRA1akdFahYNKebC9GNSrztt9CFB5ktgCZl6dbAS3ZSgooOYk7Z4Mt+OO/cJ39hpS4Tt7He5dsimirHzYSyxjgOMo1XIcy9uclVq55XDfNusdTyz763LX5H30SSwri8Vnh1hxxXVZI006U023Pe49eISL1RZPxE9g/+QEeTrF3nxfGQbLi+2gdt4R2U6dzX1Ya95dO9e0CGZneTSyaVbw0xHdnE8ZC0hOLIOBRDvbUVrlQ9QMfrlgxOBtLT9ny9bgSS9WS5yfnWE+yXF1fw/Xrp3g5vUTTCY5Js/OtR9dv3YPD89O8crdV0UOl2WF84uVZeZk9MJIXFTAecO3EfosQV2wECXUPt4PWxHLPbZlg3JTIk0zHEz3FDwt67V+78ecJDPvOWwzswis/806xIhletNzjMzTnXYzlqXTmRWOv8lqP2zf+B4SEBAQEBAQ8CnB+7/u7yMuZqH3AwICAgJ+QnzkG34ZfnorlvkcS9KHz6BMDW/48G4KSD6YqhhRZuna5GAGAwIpGH35vtisZZ1VhinLzG+S77Po0q58lRGGRo5aKr0UzK4wnFS/UgfvCnexQpYvxsQXU6Cz3pSF9J+kwtrU0SR7a8eLkzYhmWuktz+mtc37V7o/O2WytWFcrM9bd4yK8Y2NWx35susVtndXLMuTKIO5q0uhvlQQ0HlDk2j1pI5xMAOjsjufc8P2Bb3kX921KMnYocH5guRFjGmWYbup0eMe2ob+0Bkm+QSZWGOzpVit1rioqVjOcHx4iMmkwMnJsbyqi/w+ypre1ruCVT67387tlHrOCsOa6sj6LMHBfIZpnotsoWKcFhhbKsedT7QnP3y6PyE7EdmrmCcxiS/5S0uAZyap8llOTQ3IcZOKzwUw1IeR85utG1mz0EaBXrZpniHW/K5cgTNn5eKCDBxzegMzgHLj4Ah70wmeffZpPPe2p3Dzzm3Mj48xmc2lEDeVsilWbY24UoDu526CPKZedspK60iye7w2U1m/Rlt5iQt63CpjF8wZFK1jb+W6UjBlu9mgrkrZEpBYXq828otdb7bYbBhoaTQu7PcoyZD0zBCokLZcA94sxJ/DrBCYAWC/IazAIQNDNYvdMTjUmMWF2kfCa/Aj5uLieXrkBfeRBLNpjHkWYVLEyPNYNgXy56biOco0/iR4NaZUblZsb2PBAF3umKB8bY/v7Hcc7+s5N2+h4guJOp92KuY55+k57ouA8twkbjniZkhxScLvTzNY4/hCed79wix83NpxQa+BbN9Vi9wZXQxD6ZTRA8nsZsFwXNel8qVnsUOuCT+fqdTf2V6YbUlstjCutiYJcu6RtDARSTnSLA9ErlubpjFWabyh3/y+5gnykt70iLHelthsKgUeJsVUgTPzkmagx60Zt8fJLsip4M1WxXnwj9rBtvF3bLR8pV2wQlY7nBuO3/c+0AyOxLwXxFTh297lHc9JwnNuMftB2xeDgCR2qwr3Hz6StcrRC3Ptg/J+kf97jKP9Q7R1i6Yyj/g0e6igTBy74JXvi4HMNWskiZm5ZtjnTkxPL3BRvcrMsdCoD7SmtNTRpTao9QXz57cMHiOP/T6jIrMtLYssBGde30awKy7s7h9aqcEKIyAgICAgICAgICAgIOCtJJarwQazRyqL4Q5lxQdS+gXUUs7OpqaUNfLYfIJNoGwF+pS+a3m8Ti2nWk6mSCS5KG9bUx3ys4l7UKatAf2G+Xl9NvbemU4FHHfklN15SGyQvugwyWJ0VFVOc+RFirZmcTBedIu2XKPpMkT9XKnYLCrWkrgiAebU2ZSBqSCZrDDYDiqJTf2VJtmQJi8qkIQl7Sc8ITLieklI7ggW+7u3iRAZOrBZvrddwSX5zdpPkgp5XjhPZrZzNzb2Va++84SmVXESodc32DY1luVW6tTlmoRHj7NHK0yyBCcvPsLx4Ydw+/YdfMZnfDqSJMdelWBb1vj+D7+Au3fvYpIXePrObRwfHeIdzz2Lg4en+OCLr4p4ZOo2SegdmWUeoSQYRTMNHstO5RwB80mBJ65dR875UlcimxbrDS7WG7OhcN7M3mbE2100IvWM5PRqbY6oEWqmsKT/bc6Ci02ncVPZSQZG+MmUBbGoaDUvZ6ldiwJpnmM2m5nK9mwzeCl72xROaqo6Z5NChPJnv/dduHH1Kj7z8z8b73zvu0QqH16/LlKZbZOLqtgirxSmgtBsAgwDRecG0al7pTZ1NFef7EhlH0AYBQ52VOLjr8fhyEmpbkkql6i2a9lgnJ8+wna7xaNHj0SenZ6ey9/69HSBR6cLjWvZlJp3ExYhzCNM6IucbtFUK7048kbKcg0z0GBqWXmZc9/gfkH7Ea7/PsKWRJjGynzJ1SskU3mtHW1SeuynqebKlVmK/SLVeBa0KulbJAo8xNiLqQiP0ZQsOthgU1e4WG6x3bLNJNW8E7llUljRS6//3dHhw1JS0MkCEVrJZl5rRdeo9G47rKoK55sYsyKWKpcEZkKf7sj2IM5EMzHw/ti7onRWrnMXjJKXu6w4zN/Yrx9eo1qtwnTmv25zx7PfWknmAT1SX9syU9jMBXeozrbCojkDLKmR47wWBYRYqI9kfNKhYEFCbcgMQnEHtf1Pan9XE28ogLrznhhCZSrU6EhWvowo5sv+zGtdrDZo+g0m+3vYOzxWJsfB/jHSJMer6YOddY7bL7wi2ixuoGuwoq0+yEeynCuBwT8rhMeiofxemnG8eL8wCyLbToxcTYtYRG5M3w/ec5SRYO1ndopfnbQOkZdzHON8ucCP/NiPYzYrcPfBXRSTHCcnJ/Kff/rpp/HkjSdx8+Qm3vH023C2uMAPf/jHsNos8fDhy9huV7ZvK7AqSbHme81CoKBKnMy+FV9kQ9Ust2/QnobEtgU+7T6YK0hWoVLxWwt2MOuCRTN9YJRH4Nrd0iu6532DBVcTBQtUBJGqaXHnJLkZXAwICAgICAgICAgICAgIeAuJZa+kNTUdyQdf+cfSlJXC3jJF3fyPlaY8IiIsLXfHkVlRPEca6lk7Ruy/44ulebfNkVLZVFckCI1cFrEswsAK5zV6Nt+pm/nwLCLaDJedV6c/nteRGSFO24Ys7aRw5jXqPCKwnXL6ktftY73jzmnKUK9CvAyRQ50VOBx8mV3fXuIOB6sOI66k/vU2G2Nj448xVpYiTjUbCVIq2lorZibG18apVhGpXhYYadpjuVphsVggy6aIk72BJJd1RFXp91VVI8tzFEUuoi/PMiOVXUFCpzUe0ZyuaJwrWMaLIB9Mgp6Fr3gOkrwkjElytSO18phrH6vDLzkLS93t+nJUn8wsMqwgnzyePX8/qOFN9i7FLMkk+mw3rdLjpVL2/rciGW0O0hP65HAP+/M5rt+4jhvXr6tI397hEYrZXLYoChb4hnk/XxbKcqYsLuoyeOz6vnI+DpcJYjZ8JwH3Rx0CPN424LEeGT63O/j4Xe8t3spLmuQ6x5b2I+W2dEple223leYNias46ZHWtAwhcWdBI/PB3c21nSp/d2KRuZyDVPs6FbF8rJ21wWADcWkt0YfZigAWBYMEVEmzGJuNR0JfC2+zwjY0JMpsbL3C3PzZIyms2UqvEyd5raCMV/yrn8ek/evD97uUyx2tHpw9CotWOv5Zdgs+O0GsqL+y0b7gPHaNRPU+GJ5U3o2xt7vwhT53xT3th59lsspwBTzHGnhP1ut3bv6a2wXVul7HbsXkSJDnpMTjFFmcYEZPdnoYS/FKP2Pf7p0CflfYz//YXRN7fDcLbP/kWHX0u2467SGcWzHJ0DSzl2yBEs2L3Zy2Ptee78ZbmSOWUGHzi/2jJWUBNQbguMdz3NmvFB5bMMsH3pxql3u62mX7jY/l2HFdUE/ssr3J7qe1yrYiQb5C2VRI0lTXsrhYYI9e0WobbZNy7O/vycalKvdEhMsTujVbJAZBfDaMBR39/cIaIe7ZhUilRvb2UG4/tAKldr+jWt4XPLxkteL2bqmhue9rDTvrqtHxXKpHsMIICAgICAgICAgICAgIeGuJZaqGDVQbsogZFWH2gJxlmZUdq5gWbsXtJKQT+WyeoSQU+JCeyzc5wrRIkOVWoIkP0OuqxVbEoqnW+JBct0wXpncz0FE1HAPTzLxdD+a5VMbMSE6TVlYNWZaLIcg3JANaRE2PqO1ELtL/lIpoPkTzz/PpBCnbnZrKa1JkyNJUKc7T2UQP5EwjJq8wo+JZHss+XZnqSlMkSvnmHsxVHMqlIA+8+0i5zNRodkiWp1LOkhC16nQkcRzp7SwkmrYUAcHCcrTtoFKt6zNH9npl94hF9ZSPV1FH9FomaQtsNrSYqFD3W/kbZ1T7SgI3QZekWNE+4HyNTf0iTi9WONg/xLvf+Wko8olUvMfHx3h0do7v/X+/D7O9fZxcu4nZbA9XT45F5Dx49Ajrnh6gnBM7Us9+OiJDrI2zN6EykNdRtmjQ43yxUZr8pqpllSBPWJ+/7q7NW4cktOmgx7cncEd2IbINYHGutpdXLM8zyzMVrEqlLjXKjWRgVuQitKg8zOcFqqrBK/ceoqaHbmXko0gZFtBKYxxkCa5dPcHnfOan4crVq/jsL/xCXLt1EwdXr2J+dCiCukfqYi2OJG5LV9CM+kMzhrhklTIaN18B0Tx/nVqbOfmqVmZKdH+9VLkPpNBw0J+cHBL5T7KwodpxqyJ95xcX+nn/wSNs1lvcvfcAC9qfLJa4WK6copeEWYzVupH6VPYiUYYGOWoUUl3Gfat1kIvwMk9fqV/pZUvlZdtjVRqRKd9rktP0TKblCn2snT2Oty84OpzKduRokmCW+f5w6tSO2xYL03EWxYjbBmXbOcUsbRUg7+Ox68i261Hx1dZSTI8o3IEIHdTLmia2Di0o4m1uqOSNwDqG9FpeljKJMFWtvIoT5GknRbU2LhWgdHNJljvj4JAnio3oH5PrnrSlPzSJyGH9jF6ZiFLahTg/e7aR68KRlNyPGaipXVFVFTWNgYKq1g7YtGahsKpb9V02ncjPdzqd4PB4H3VV46UXX5UdCpXGLKpJqwUGg7j/t1Tg61yORJavPM9j9jN+LpLAZYHQPE2xrM5VmPP8Yoni/iPMZwWuX93nN7C3t6/1uVjRIsPSY3jLscwUesfHyhRgH1MhbmYrHbbbjW5jTcd5GcsCR9sqfZZ5d0l4fxAVrqAa93AGxBhLqhuuoxRl2WDTNfodz+ODhBwHlvVsWRjP+fqXXYsHF+fajx4+OtU5mdFxuH+Aq1ev4s6dJ5EVVtiv61s8unqCzXaN89NzLBf0MWfAZqnvT6bchzl3nPc5CWKOO22uVZyUhSI532MUKoTJ/jACPk9YNJL+9RbM4+2VqmsR1VGKhMUQukbEN69is1mjaTPEBTMqHKnOewS3Krc1BQQEBAQEBAQEBAQEBAS8ZcQyCcQxeaUaT84TmUWDZLtARZa8To1YdtySkaxOvCmQIJRthPNlJWEhQtIUWJ5E0RlJMFClKH9ckiSmdM7SGFlGe4rWPE75PlWODa0vSTbtNJCWDD96OfWb+XEa8av2iLzqZLFA0kXFnXje1ArqyY/Sk7ojeeBIdzoU0vKfGfyBh/eteJx/e6cB3B1LBJfsQEw1aN95vaJvnlgeHcORkv4nSVQpgTk2TiMqha+XfpMoktdqi37NAlqPUNUd1uuNkXZxLKsIqnnPF0tEae7S0xPMJhNspyXyLEUpUsv5azsvT6/k9CpBqX+9dzU/S6sCqqFpp0HFspTCdg2eTBuUkiNLjF3Kv79u3/dGBJuq1I5EkpJko1LAnSZQ7SFBk9JOw9L1Sepsy1LEcteabYsfHZJNLDJ4uDfDzevXcPX6dVy/dRNXbt5ENp8jLaY2MoMU1k94KrnJ2vAl4/HRPBnbGjg1+4g4lBLXzMqdetOMFLwS1vfLbtwH3fvld8c+D86rVyrIlmrUBrUUy6Yg3WxLrDZb+WpTubxlFUMnGY0as22JkxZFlovws8JuFhySBYL9SWvME6hUavPqSXxxPWnfYOBJ69n8iTXqTnHr12eW0foiVfE1qpX1KQVxrCAdx4xe11Fv67iVmtUpOWlJQZWpV8lzb3Jznz7Dfk3ufIJ9tsGo90YqYPNM3hXa6wbfcl69X2pOgco2DFT/LkVj5OT82HiNhsfxzf6PFgghqWlz/vGECe2vY6XyyKJiUGiPA1yyjDD1uC9aZ2vf+VFHtibmsxmqtMJ0Qv/zFpsyRsIgne83P01H5/P2RhYAiC+dM01TFVhVu+ht3nC+cd/g/puqgCnHmwHKOGqGgoB+uXh/ZI4zs0kUgOBYjn2ihxCIGwtXyNW88G2tWZNdMUMS7iwgmzLgZ+M49s6XypjXwA7zm7mzxiHhT3sXWfK4jB0W2+NecnR8ghmm2D9m0b0Izd6B1ktbtehqu1e17VZbL0lyi5/tdPs6jQhmOy6DdaLFLS7r1MqmxOZ9cHwb4Hr0nuL+budVy9y/tSdnzNBwY+NV2R87HhUQEBAQEBAQEBAQEBAQ8FMnlieUC8tL1ZSifNCl7zLJQD4wK8u2M6KorknO9c7ugKnr9mAsT0fnk2w16Ehs0HfVHuq7LhEBagpgUgCmEiZZnaATkbQ3pf1CjH0plqlkq5HErXnq0oO1S4wsaOlVSdKyRa72Wco8H8zlkZwVIjtIGpNUptLNvItJWPgU/cLaQQWds6UgmUBVHpWIIpHoH+qIFFO5GtGhn1Jfs3CWU98OtJS7PikdqatzXp+u6NKO4HJ9Tkq4b+W3yePm+UwKbar2pIKmb3BPX1/SeEbW6pj0Uq1aLJcNytKsGKjgKwpTAGYpx4jKP153gU3dYbWlz3GLqvlXmE2nuHnjKg6pyCW5Tn/PYqLiaGz703eexLUrV9F3/DvbV6uNryHBZH9gRGTGAAD7jWTMeityaF3W2NKaQarL5LJNiLMP8Cw9VaxUnPdRquJZJILXGxbbs34naOtAReUkSTHLUvn5ql/YLtpVJPQLnulaNtsNzu9TqdxiU1Ehbp6rbMEszuTv+7anbuPT3vM23Lh5A5/xeZ+Hg6NjHN28g3z/0FSizkQ38hGUji67DAiU+jkQRyR+vP/4QD+6n1KeWtVGs9OgT673GRibHPifO8Xr7jOvp1z2BTY7tHWNTv7KLM63wWaz0c/V2hTKq/UGZ+cLXCxWIpb54kI18j2SZQrHZlpMpO7vNHeoheU8yuSTHpPAoroyM6KV64QOrmXfY11bRkLTGelLX18pUrnSnepf5LXkrxzjxBUCZNcZQc+gEf3SqZ6dxIX5oycZ2hnXI21ZeilQ0ZXq64Ke14ixoQVD22FJn3WqPBk4UjFLs8Eh7EpG5hxDwb0R7+tIbapKV1Wt/S2X17fZd0wyYF3XCnCJgPYBgRFROh5DkaRjcs85EwytGAh35/3giwoOs8epa0V2m/TUn4f7qifAG9rVuKCctNPyEu7QuCDY2bpE8+BMWQPsXxK4Tz75hObNq3cfYrFYYbmpsdrQmzlBBmZQWNBKCmlflHUIjLiigOyXLFGQQHtsAlR1JaU8u712ay7PJ5hOWyxXVNjy5XY+FfZsFUBgy5OYfvhUHOdS9HY9C3CarzQ3YJG9zHSpaT3Bl7PB4bEUGKTdOQlpCxjFRYKIAY+ERfa8iVAkix8S5JUyXbiWeV8wcjbn3HYFEa3+a491ucYLr7yIB2ePtG/eunlD6u/r169hNj1Efn2G61duYbl8hIenr2jPlOH/4LtONX2rwqUq2Me1xTHU5GBWj83MNO0wm8S6zjixIpWc7m0bo42tukDbUYns9hvdlmPdkxXgyoyU5n2cdjbmI+73ooCAgICAgICAgICAgICAt4hYpv2EFLB6yHUODr7ollLbLf1axLJUn6Yw5MO5E9/6bH9TaTkyRg+2ImBJuvHF45PItePvHvat2B+JRb4mJBgzpvOaCszUZySVzeeSNgIkTfigrgJVJLw6kkzmD0triZgFwkQOO/9O+oCOnSUkVutQb9foGiNQSNSQCRKJKK9gHs86gwSD96A2r1fXeSPV3U5V6LxPpYx7TO18SQHolWxGOHjvaLaV0nDzGDWy30R1RkSY5s0sSMqSJBqDAiT+IaU3Fd+0K2Bvk3jln0i8bdZW3G+9eR7TosDJySEmtA1xBe54bKpd2Y9Xj09Q7lV4+dUXsVieWaq27BtGCjw/9k49qmADCR3SRxU9fI38r1xKvxTJftKNfaSlUqRSPdL4M90bcYooarDesg+cGlmFJhsFAfIJSezMjZONB31RSZKzECJfS9pBLFZqB0lPNptWBjwXCwvO0hTXrxzjne98G67duIEnn3sWs/0DFIfHiPMp+q5GRwXjMFRWqE+MVl89lmfu5aVjIsd90auc48F8wvnjjiSsY2J5LJl/DS6r2j0xz3bSXsUCAKZU5qsksb8tZXuw3pRYrbdSb7Moo4hlt0twrlkzIjRZ64qlUUHsAiy0e9BnekRch041T49YrkUWKVNRRNcyFiHTMaQmtb7pnPrZlLG21+x6kL7nvQhKBmRIKJPQ5nG6DCirBsvpRtkVm5VVNpsyeCQSrVVAg4Xr1gxijCj/sXe6Wc344I9fh5czAviObBGaFp2zwDBxKwM2TpE/rHV/jNeO4eP214P6ePSL8d8vu7LvsiGMWN4FVqzeoPPqVcyCKllaw5hNBSlaqVUZcHFZJ5uydtZGtDzZYG82xa0bV6Q2rmuzPEK0QV1v0LnMEtl7sAigAjHOp3hEwJuy2PpFa56e+FTxto2CFsx4kEUS7y9phjzPbe/2ViGatrbvcf2aVzQL85H45h5nwQd6e/SRFSPk9ajfXF1Z7ckjEa9oXJHhJIZNSd0kjRTKVrjRxoHEugoCuvaTVNb+5e5BZn1iJiqbtVn5LM/XeOXePRHLdVlin7ZBh1eQ708wn+7p/sTvNu1KBTTrLYlvy2bQfUMe4HRx4sy0frGYg9djWzFCir+ZWcPADotjbiqzR1KQVvdR3ri8dY7tJjwNgyQM8mpiaDw9sawZ8Tr7SEBAQEBAQEBAQEBAQEDAm0QsmwcquTxLOc9bU8KRFCUpxwdukn7KrhVZQW9Qs1sQ7+qUfyr611qqutlQOLJDJFUif2Kqq2gFweORTFYBpj7CLEtwtJdKAUcFNZXGfZ+4B3JTDYtgdAXY1DZPsDh/XZXy8lYUTiltxDJJRyBxEkGz6HDp/vKvNCWbHsCdn6kV4XOP/7SscOTlzt7Ck0fufCKc2B/ex9Me9oe0b/fRwQ7CtdUbDBh50GFbbkQU0S+V3qWmmHaFusbp/Ry3tkVZlSIVSZ6IEEktpVyiZk9IizCPkWRMf++lXGTRto++8DyWi3PcuHEFN26cYDrdw/HRERPc5VG6LWscHx5isTrHZrOSnzP7lKpxHrgloS1vW1OnF2mKIsswoc91ymJdPfazAhnVpNsKZdPoavl1+Y16ds5k7ciTQrYUtGEhUWmDbkS6pf1TRWk/GeAgWcarnGRm2dAnGfo4wWq1wWK5kdKQ0Px1pP0syZDHCd7+zBO4c+Mq3vf+9+Dt73sv9o+OMD08QjqZGkfcNeipOmxrRyib/YUplndKRD8ipiL0q+myDYJXwhrB6KI2HpckrcObuy8PPz21Pvbw5eRtZGnQViXqcouGlh9lpTnhfWnpNU71sMTTtnq1dmQj4cg/f2QqfXktJAkpJO6THnFGBavMYRHH5pmctKaYrxjYqF2xO2dJoP9qrXZIox5ZbKEQ72tNspvHYzE5efc6OxWtRUfC130jgljkGPeeDJhMqL6nlYZ1wYyF/1iULiOpDMRb632qQzU/2QZn7zB2eLA9wpHjbjUZ+e1WtCxmqA8l6UiC04jGIrO9za8/v55ZNNCT5zvFsZHCPujkZf5m7fDYaLvKbWO1sr3v2imLGZs3zpxklzXgMkX0ffYB57uCLU69LVIz1ngvt1u8fPchppNC+zv9iKkCn872pIRFzyyPGPnEMiTYB9yTlhdrlCr2GCs7wfYjBsJI/Fobspxqd8sKoc93WW2xXK4UDJxOZiJ5zy5W2Gy3Uhszi8Dba3BOLLcbeWRXXYYis6KOzL5IshjF1LJLSIJzThm77iyIaFniirNa7MeUvrI44j2kS9C1manpae/siHfu8bwfaHox0KSAHhXYFhyVXzUV6/kUbZuj3NZYrflejPPFBTblFj/ywR/H3qv3cO3KCY6ODrXTnxw/gbapsF2doW0Y3FmpP5jNwQnNazUlOfuR9xXei6xQJr3cEzLLfY9p2miPYyCtZfQGtVu77rv0fGcYgRkjDfsiRspYET9KcjrlvCepzMUSiOWAgICAgICAgICAgICAt5JYduoz+lwynbrNnepPilMqUUnm0DOTijQ+3NIewbMZlhpNUD1M4oeWDFILu2Juysymujnusamtiv3cEc5pzMdj+twmON6nH6cpFomK9hkkMhzx6tXCUqCSFO7NP1dkjfPi9GYTnkAQsUwSmw/bJBJ0vAg1NdjyyTWvYu/Z2ZM4ErFM0sH6hQSA2Rk4b8vIFMO7RHq7XrbTUu/NHiMi4cSCY87uckcs+0J8ooEGBSJJEto3lHGJdH8f6aQY/FilFvUcpEiVaCCWSbBSda6gABV/SYJKKeFGAurcTJXP6afcYLVYi4D54Ic+jJfyBB/onsXVoxiTvQK3rx0iiQuUVSGl49XjY1ysWMiKBOZaqk1Zp/RAXVJVzfR3zgvIt3U6LVDEEfZSs0Epi4kCFdumwbYyb1uqTAe7Ym+N0hsxvTeZYFXWqOraFIlO027EpX2HymkSU7QKSRErFZ99XqeZ0sXPlhdYrFZStpP7Ylp7Qe/uOMbxJMcsz/Getz+F97zrbXjH+96H93zWZyDJJ0gmc7F0UuF3NfqWFekqU7maRBLomMpPGb4zGvfwJLnzzrX5aBYlYkUj2pvs5otnEG0WOZJ6UL16Gw3/3s5ywZ1ksMFQ++ipXG5Rbzao+LNksKE1pT8LedHvVSpQ51HtgjUWVDEZKOcOCc5K69OK5JGM7lOS07nzbjYiP2cgI3HEcm2KeRJ3GhtvISBPdrt02kmwHBxf7Ev5XbcMnBSyYGGxOhV8pALarWUqw6VbdUXj4rzHZMrgVAzW8WSXzCYJCgasaN3AHuLe1QOb2kg5kqel9+f17XJ8vA0BVazSrZqPtAs8WXFPZgDw2DxZJGserjkVgNtRy2bD4KxN/M5jKmq+eAwvSR593u23NgeMDR5IZeehPrbWYJ9adVF+3HsNu+3XEcsK5DmBvIhl2Rd16GMbbwbvFustFouF9ojVZqOshVs3buBgbx9JlKNIckwmOY4OZkb0c6fvWrzw0n2cnS+xrYCyNlWwFZOLlbnS04InjzGbJthsOS9KBcgWyyWKIsPB4QzzeIaHZ2dYbVZoSwtSaCzkC93jYrtBUjKgVKDIUhzGOebzCfKcxf9yzQnauCjzwO+xVCdbRVBkST7UB+C+mxcx0swK+EXINU+XS+6Guz7jPUGJBK6QoOY8LYSSCEXhvPrjQvej1WqL7DxSUOTs7FTneXS2VlHZd77tOTx95w6Ojg5w/eaTKsi6zGeo6y0uLu4hqtZKdKASWdk8UnZzHTD7grMz03qk7VHMorMa2wZp3egzDBAwi0Pe8uxsF14wEykGg5z9R20e5Ltiqrww65eAgICAgICAgICAgICAgLeMWK7pc8kU4dZS1AmlR7uCUCQ2RPAojdg8UklUSGsrtZhZTpgw05HNjkgx9a4v2LZzOfUPv7EjYng+z6uR3OqdolW8gTPpEFkkKw5L298VZHJ8q7OrIOHa+uJHw/ntsySLSBpVZa207VJF3Sop60goRF2LtDNiWyS5ty3wfN/4etwxjf0dSnoZnC3G8Dk96XtrBF9cbEdMGhEVoaavMq0FWFSvyEX203fULEAcodC35q+qIm1Uf9IyhOPDYojeTcET7juFpn9Jyccx1LhZevzde/dF2h3tX0NRzDCbX0WRR9jfm+DoYB99S9LSK9Ht+jUXeqr7jOwgiSniOUtwOJlIJUmVYN50OD1fm7WHVbQaKUeN9OOLhGZV0c6BpLFZrng3icGOgt6rLLxIhR4VuUiRR7nGalU3qOjF2lifpE6vR99mtmda5HjyxnUc7u3hyWeewq2nnsTh1RPEmaXpe0+XiGplccckiJkm7xXKI6Wy8wHZOSKPVMaPja1+6wrzDTYo7tpHLOfAKV/26/VHfcx7md7KHYn9Fn1LJWd76SVbjKaWApnznGuiH80zrUAXTPG2NWbnYv7hvgnMvBfJqiCQ+R/b99sRL9ohUbDF7HI4R4x+tYALVbAq1Eki0CI1l/xn/TXzbduLOjRsJ+eU5nWMxgUa+MmU6QdOxWvWFEZK00JnPqHKP1Igg/YvJLi5Vqw4mlPwO4Z5WBauW604oCsg6VwFqOh1O9mlNa+rHwpSqhzboFQ2mxB/TCP0fYf6Am1+7u8GfKed1hCPLHa0VTovDfurr9A3ZpjZL4mOm3I+aO/mmvLBMT/zEqmOl5tKwafs9BwbqpEZUBAhXyNlkVMXYFPQTfutkZa6L3DeuyKe9OOmZ/Hh/h6KIsdy3SJLW0yKFJ2sZNi6QsejUnp/b65R35ZmZ8I9wxpojtWmZI4UtCCpStKX65wTjQEknr9hIkFjingp6Wnf4nwyuB6UoeGsZpStkDGI4gpjKtuB64Gezhb0S7y3ckKFst3LMmbN+AwQZkIw2EDbkCTGtqrRMFuD3vhNhfPzc9zNMhHqzMxRhkxM32l6S9MmI0VcxYgYk+Ka5D6oheCzBbrX3i+U4UNrDO5i5sQ9ck9y1k1mK8XrtTiCZRhFHQvccn9nHYOgWA4ICAgICAgICAgICAh4i4llEgt8mKVa2dOPfKCOnUKWD9hUpImr5QNz6w0t7aFdhfLksWkE1EAqe9WyiIpeCtJBVeyICalslbJPb10qzjosNyTDerRUOZNsjVs0SY3Nhp7CjdRnJICyhN7MRjCSwOB3qci1InNG+niimy8+7GcsArapsLhYmnfm8kLkwHSWqbAV09pJMLBdvBYq/kii+4JseugXG+H9kD1Rar6khKnQjCQwIswIH9lpeKte7/Uq0ojkW6a/rpbnqKpSykj6TM9nc8zmc6nPSAoppbuvRZqVTYltReuMRvYOItlIhsu9w4p67chlU/mxPzLmeyNBMaV1RYz7Dx/i9NF9XL/6ANtVi+PjE3z6p82wN5/iievHiFlAMalQ1Y+MvG+MtJlOSFxEOJgXSquvSxL2PQ5mOZ64ehVdlGBek/BtcffhudpJdSntKnwxQ8IT3LTeYLGtqqHHLdWSJEnMG1WOvPISrtFTzdqnmCSFvHjn6Vxezi+eP8TFdou6rUW65H2CaR9jP8vw9NUTHB0e4PM//7NVqO8d738vbj/zFIrZHtLpzMgadq4mUiuFuBUao0rXuwE7dbGzAjBi2c2zkR/3QDV7hnhMKDtnVfNnlvGw922x4n4f01fZa+RbZ9XRqWAfSWT6urZ1JQKW878qS6zXa7222y3KksprrkFn1+J8Z31xO39s2cwwkOOKEXJOU6GvoMF05opz1jJXiFEiRalCa1nMcSGB27k9w3nV0nu2yOQtW5M89IpuBmLo48tzaU1Rld5hu67svE0p8o1qZ9q8dLW9GFAosolWI21zqEKXbQfnfsKMhxTrutH+Qp/kRemJPM9nm30C21ErcGBXbyRkrCCEeXYbwc4gl6PI3X5lBSqp9NZORvLZ8YMi3dWnXIdGrfvsCxaXsz3R5guvlcGPwRpHhLXzIHexCQuE2UQak9W7IJmz9XHZJiR26SE+4fpQgVO7ToXmdCKytCnqvsO9s6X6/96jc10Pg1dFmmGaJVic59qPJ8VE/cognCxsWLCRmQj02u97ZFmE6TTH/t4MV44PpPA9O1/h0elKavG63Yio7jFR8cXjowPMZgWSew+w2WytWF5ONTKLfG60X5FMripnPUPfeUxQpPuydMlow9J2WCxKbLpaxe8qBtUYQKi8HZOp6jl32F1MZkin9GsHomUpP336zPP3LBZLApz3sPmMRV9jTKcWmOD9il+aTXIUKj6YoigSkcpRTs/vFuePSq2rF176KF556UWcnFzB/QePcHi4j7e97RlMpnMUk0Lq+8XqIdbrU2y4XzNYhQYbWfw4D2btz8wEMsU0swS4508mFtxoGhLNLpjrvOa3Da2JrFAhM3hob9RTtdxPEHf0zS8Q57PX2u8EBAQEBAQEBAQEBAQEBLyZxDIJFBGxjoT0qlwpD70C0aVbk1jxFNlA3Do54OAV6t7beYVamvGgan5MjumVzEr1ZTGwlqSEETtMESYZE3f0mySxYHYYXulqiuXLBfLk5aoHdr61O9dQoEtKREv590WxpOtzxxmslL06We12RbHkBeuILP/e65KB/tp3RPvuU5d1rnYdnmx17advrhSpZgExVgq70nkiZKhaJQljZMjO/3lQxI5gFgM79bK/JhVAVNGtCosl07sLrNcLkaBMm6ca8fxipsJVVO2VXWlcqNTsJPhSTGlAK0sGU/tSeUoyrSx7lM5ewdsQj5vmr98Uy+YkSttfXxRy+IymRq/zSf3qTLA5vjy2ji9iuKclrwIP0zzD/qTA0d4erl45wfHRIa5dv4ar16/h4OgQ0/keEhJbw0LYFdqjStmV3BqRwkPPj4oxjg1RxoTyru3jMd+NkDuPSB+v4R+P13j8xnPs8mf8mNqc8Spm9kc7KuRX6SXVsreU8XYMjxeaG5/eLyutQ1PwqxilCvBxTrJ4XusKO9qxfRzHK4T9/N+pfTlvLABlyubd6cyf13zUZW0t33cqZxXhUV+RgDYV524/8XOfgbEi46glmE0ypE0iAjTmvsF1JIWxWUNwgnlXGU/kjttq+4RTgA5rJ7rknfzaPjNLHb+ObS7YJ40E9pPDBZ0kS358LxxG1o4xfP+1ZxxG0EUDSYzzHFJy87rZL876x6uaGbzwauzeEedNxLFz2QxRJ7uLtqMy3fqcPtz0RbZ1a8ERa7/5HrNIJO13WAiSBDfJWhZb3JSlhUEY+KAamAHINFbmQM7ApQIP9GUGktr6e9iflXnAjAzOa+uL1GfNJBFqEtzkkF2/kqgfslaG/mc7SdIym8LuQV6Vb9Y6dk1mi2FtofWG5pVTE4vQzljUlJ7P9DXuFYDk0dPMMjI4l3isqt5itVqIGF6taNtRIKe1huYLj8sMlFb7KwN9Scz+YbbJrm2yX1LNvRixG0Pthbom2wj9HB1ipoPPtpsvGjeuEQY6QvG+gICAgICAgICAgICAgLeYWKY/Lh9WmzZC2pLM8Uo7KhddGrwKV5lSN6EvMYWdTS8ygbaP4ppU+ChCktGr2auI7eF/mptyb5GaetcIUj7ctyISqVJdsNZaBCy2JDqdlbD8d9kOEhg9thVJIlpj8BhGNPLIfOamwpWkGv15+aAt4pg+sro+syNgWjyVXZXSvhtkRY4sSpEXCdI8QcK0ZRYxpL8m/S6p1NRPSjWpqo4dceeIVZmZklRo9D4/J09bFiZzSjEqKj1h7JXLIj34fZindBpnRnK4VHz5vDKdv6LydCMF22RaGFElH1imqNfomy1iNChSsyiw4k4ufd6xWUakc5yd9+/g8Up2LzFFXB/hYtvipXuPcL7aII5aKRGffPpZ3Ln5NkyLSETKxeIUr7z6gkjclJ6zcYT5fIb9+RxNvcFFs8bZosTp2SsqovbSYisF6elqa17JJI+p1OS8orpaivNcZBgDB1XFglbWLpGfEa0oOHYt4r7DlB7SKVWEORokIrYWm1ONddq1OIhJABUiim5du4qnbz2Bq9eu4gOf9j4cHh/hnZ/2XhxdOcZ07wD5dGbzg9UCSSAy5V9zzfkpU7EcmQJyZwrtlcpewexfzjJFpE9ymQx2vgaDYnkg940w9YXbRuXl3Mr0P3cF9vTJS34r5v/ctrS9KFFRrVzXWG82OD+/wMViiXv3HmBFf90NvavZx0Y4O0dgl61gqk7zTPdBFFPpmzUG1eidCMKI5rRM948SzLh+6DudrCzAkdjcalR8j7MxMbUu1ymXD/3aU/Nyp2+xsiTkrWz+vYslbTw61FuS5D2O9nLEE6pVE+RFji4DCm1NHdK21pxQEVDuSySW8wh7yHB4MJeq/97ZGpuqwXpbY1tZgKNmIIWFRF0AYacmtuNwDXpf7y2tGZztj+92rvmEdgMk2kfEvAUiuNZ5TCv+pwCIOzbBOmw+c0HuOY5EFrnqyeeBYLa1yj2ChK3ecUE1CyYATB7hDC0c4c79eoYOWdeiFCHrAwK0I6EnMvcmC76onQxg9R1KFXKkotwKLjIws97wGhzhq2Kp0srr/Aw4tm2FV19+VYTxydEB5rOpPLMP5xOctUtcnJ9pHtFeYjab4MaNazg42JdP+Xp1ZtfJsed+HhW2L8vmgvcDYLlpkec1lsstppMUR3sZMtoCdSzYGeGCWRQd910LSGiP1Q2HgYkGbdqrmCiDXjzPdFYhosVGVWtsabdTJ1QJ54Mi+Wifvt+c89yjOkwmiQoT8h5giuUEbcTihFaMlr7fWZTIkodr8N6D53F2UWCxPsWkmOLmtRuYz6j0Z1DkANNijunsEFsW2mzM4iVLWWiTthdcd43Id2aA+GwbZhBwTbPwLfc/3pfStMeEY6nAHscUKnSowEs2Q5ZRQc75EIjlgICAgICAgICAgICAgLdcsWyKKa9kM+uHkQJ5rFr2qlMWPRqUUl7u59VTRhhaWr0r2OaIH1MZjtV3PgWfZLJ5sFKtTCsMKVLlZUrFnOwpB+WhKSb9a0ziefXmrtifXdflnyRTRG6yGBfJcJHB9mAuf9Hhz7v3vaUGT0qyxbkXXLK+3akeRz7Tr6tfJXU1VkMnRm4O37P2m70HmXFL2x8cXwcfUZLsnuB+TNn2mFWv91v2vKanMKn1Y3I2+122KABOT0/RVFs8/fSzmE+n2N/bw+HBocgkqvhIGsb0b3VKSfqIkpDjXGJ7t9sK27rF+cUS67oFs9u9GtEU5ztPWRJwUkaKsKddg42r0anWF6T6+A8JUKkd6bvrCq2RJCWhxCJ9TOGfF6lUk1cODnDj2lVcu3Edt27fxsHxIU6uXcPe0QHiJEdEMpX2DFI6G8loPiU7P2WSSzu9+ejlFZHO+HTQn3v56GPjPbDBfpLop1NDDxLu11Mlu78PMvqd8tmG0amoHUno1coMIpB054vp+nyRSLNgi/OmHfW/BRvGfstOuevV9k5FKqKZ819z1pThXOOpgi4MSFlRw64lOWfXp97hHuLsdEkmaz8YZTV0bl7UtZGEDDBwCDj1dRh9x4qBxpntCVFfWw1Fp47m8ZRRESXIk1zzmbY6PAuDYAowOVsIXr9XDVuJyHGv75Sz5uetTXFkWeFVzV5J7z2wd2t7Z5Li+tMdVrEefd98i32+hterP74zmmp5p1geFNaDDcsoe8QpwhXUc/3B/Za2I35sBwU07WhEOPt9kYG5Dk0c66cPUGl+OIsP282dJ7buFyw2WqKuaqmQORcmUxbyo9UOx7DS/kZrlrpOREDvzZn5UGA6yWx+OCuRSt79sDnqVMUifhlsqlkskr9mAMqpnFvzdPf3LdvP/d5itkR23dgVsUxjJNqzotco92U/IfshUyxzL6KqmOfV1komXsXyGOBgdkkndTOJ/jxO9drQhmezRst5eZ6gyDeYT/cUPJpMUn1envmy/6D9yET9yuAc7Ybs3uh8lnmvZEe4CaEAAK2N2E+je6qpm/191e5RvJckcW7z1Hk3BwQEBAQEBAQEBAQEBAS8dcSyCAwWjjMVMFXEmdPfecqEvpAiehoWvaOKz6VKy7qXKsYIaW4KxDTN9JDO71AFRyKFD+x85icpyIdxkkR8KNbDPBXSJCt7EpNULvIcJKLs+FRKlltaLERo6LPKdOiMRQON1PR+oyQxSBOQoI4bUyVHSc1ag1LA0jvVrshhlLbP9pCUpEcr1WkkyujJrIdzfcm8Y0Vyioh26fjOQsR7LZNg1fepHnT+tV5lqL8OymUjeawAmZEuHINJQVVypz40AoQFrRqzIMDa0py7HkWS4mBa4Or+DNuqxGpt/rWNo5uUSu1SxEXWOsLQF3jiganiNNLaeoSerFsWzdvUePHFM0zyFQ4Onsd2U6LYn+LT3/9pePmV57FaPsB2s0a1WEm5fH6xxWYT4+K8xGZrqkeSjyRjWASriHt09EClylsd0rtrzm3+0cqi6ZQ6T09WkfksVjjy/Wb/8JBU/7GP6Pv64OG5Jvks6jHJU1w9OsJ0MsHTT9/B9etXcPupp/HcO96J+f4+rt++jaKYYDo/QExpPRWOVO2qMKMv0EcVMkk1vm/6TJHLY1LYeSNbqvxlSxNfcO2Sj4T/pX/fzaFB/+q9lr0tgubZ4wTzThPrJtRg12EBBpLr9HvlOKyxWq9EKKdZrrT7PJ9obW/q0gjlwZ7FK9rpO9xLQUuyalgbsq+xtWFcuFlKUPXcto4c1vjQq3uidiesUEYS2a1/KyBHdbRTZKroo817X1ROfrxUB9ctNlWHpuqwWZtieTqhly+kGk0n9HZn5Mn8n5XVwLnGQBj3C9pycE3R0mCaSQXNuU/FsimHS2xJtitYIP8Okfve57buGpTMAoBlKMgygdELzj8fNHJewyRR+Y+VLPQ7pdsvFWjwFScv77XKIpBCm23eEevevsGbooxh5zQPeqqrRbq6oBlV2SKRmwZrR+Ry4LQ3MnuCwTqpzv1Rreii7UXmlc+Z7u0+ZP+hgBf38swyKpzHM1XejSscWnMakeRlpkUSqxggfYcn21Kvs4sVNuvK+pGZLX2krIZrVw6xWR9hsz4W8bxar20ckwQVgxGyHXeEcw+sqwavPFhIsUxOdUYbCnov98ygYeHQGBwxWp64+JCuoal71EmHPOU8oiK6lz0K71mLtBFpbTZGnvC3Qn9D0EOex9yzlQtj95nMfO5JjJMg3m5rJEmPvUmGGT3mmxjHV1IFRxarEpuqwo9/+IOI+gwHB4fYn+/h5OQAt29fR5FOce1KhrarEfUrRKhQFC3SpNHYiEQmjc21IocfBovok+78opW5YOuXSmbdK1TrIMEEdv9h8cRtuRhsWQICAgICAgICAgICAgIC3hpi2WrpyZLCq1BZcG8neXXKLhE1DToSy05FRyJaXyZPQ0WYU4aRHJRykeScCEISt0bASh3sC1CJpDVFMNO1SVqQVGa6Oi+AhxYpVnpvDCsY5wtpmX/lSGEpEqU39V3bIqGXcpcg7Y3YHpPKY19VX0hQBDMJGV/kbOyP7JSlKkzmUtNFQMqn04hcUzgnwzWZeNhpbx8rzjb4K+u89BQhYZ6hY/Etn/pOepNENClj5omT4pEFRYxplmFvSmUaicVUxA+LOqmYl9puafBeBedhRdlIEJrCleQFh4uqP/Y9x7e+YAHBBPfuPRTZ8uzhc3jqzpPo2lIEES0ImpUVDlxvWFhxi/XaCitqoiS7OZGSIGmBVJJC60z2BAvvESRg6CfNdG8GFaTeoxWJCE0rAkfSX0XGigJFMcV202Cx2GCSRpix8GICXDncw9HBPt753NN4+pmn8MQzz+HZ975PRayKPRLKzo9ARGqFzpHqO82nFWQjkeRJZRHI3oN3pFb2KtWhXz2pPEhfPan4OMF8mTQeK47l4es9eL03uPvOY27Zr1FPdy3JMtpg7PyUbR6mClIkKVP+WbTMDqq5ORJJ78insVrZ1OOebCZBJwJWamQGPayUGD2QUyrAWfjS2bzI+kKMnfk+6zhDwU57YeR3y/XOAowVCcHGLG+oFubfadHDHk+ySMUc1WtthIZqVXcM86yWIYkRkDnVpRHKaY8sSbHa0PO3QUNSubHFb8Sis4/hCiNB31m/cd2IoNR0JuluqnlTRlswy+mAX4NdiVI337zVhutbFeMjMS3FuhOv+3npsxwc8e+VyDynEO8IaO0KTlWduf6jehWZZQ+oGKCz1vHe7NqLXMFQ/iSJbE4nO/9wtUHkNNdcYkvZZUGo8CHJaGZXsMvjVBYbm6pF2bNQYqvXer1FVTZsimInnM/TiRX6OzqY4eRormBU123NAodzqQG2me0VVJfTSokFGM+WW5R1ipO9QmORSSlPijdS4ToGpry62+alqaDNr9uCMCpy54KbCoZ4v//B/9teslaigphjZx48ugDLwomQsXgkKwLSfoSK5b7FbJZgf0YiN0Hbp1ivayzWS/nMnz28QLntcPWkwuagRZZOgFsT9cv+nP7uDJjmQM9ihhvE8Vb2P7wHKFtCrkCcB+1gzzHeTi4Vj9Q1JcPYdR3V5OuhqGxAQEBAQEBAQEBAQEBAwFuoWCaB48g0kqMyTjZCRQSEU0eRQDUHWSMujLwwkjnxZKA/ngrukdglcWnezKbQJdlDT1CmHvNB3aVBk9SUpzPJSKdCdscy60zzW9ZDPjzBzOOxDc5jkwouGQqbcjl1hZVMxUsFbYasbZHluSswWDkFqPXFrjiXsyvwpNtgReAtJew6jYjb0YdGEjtCx5dkcySeiCQxILw6a7t/ibxzZKJsPJyKt6Nnrffa9MrK2mxDDg9mePapW7hYrpA/SLHebrF8+FAK57gw1aVsTHidvkraUI5ODtoi36wfzZN6uVohZgp4m0hJvlhtcb5YYbFYYLW4EFX2xPU7WM4u0G5KbNYblBv6f7bomh50KaDvdjZloILKzBhlQ3KqRakuNcKO3qeH86kFAajE6zoV5bJQgqXd80/0LyXBfbS/hzxNlRq/XK5Qb0tkXY+DYoJnbl/H4cE+3v+BD+DkygmeefvbcePWTRxduY6cCmX2g7N7AFXKKnJXyet1qAyoXnEKZf3ceSj7AnkiklXMzRHBu0mzEymLJRwrjd0vXrfa22PeBwMet8Z4rXXM4+plr4JmRsF2S5J/jYuLBZarjUg+qs6p4FeRQ0eIa276uT5SzTLY4C1i/MvPS+0JzneX32VRN7Ni8RkMPHZnnsBkI+NaQypSLpP5y86+hlYMDGaJIEzQpsAk57i32EYk1kyNudnUmourgn7cKaa0ECCJTXaORKD8jrmH0K+9RazgF3Ws9GYGupTna9F3tdTTtHPh+opIeIsgNcU2Pc3L2gduUkXczCphd73c76hWVnBJqRBONT5YX7j/OnWxEdI7FbIV/GRxPSuE6PesgYoWuWwBKb8fmVKcwTxT9Go/9AXoXGCkbBqs6lp9oVqH8rmm5rWT8pcEtM+wMPsTt1/pnKPsjZ0RiPZpWptoL3bFFQfrErcOWGCvZTFGrVfJzxHVPCt9lWciaauywnq1RrkpUZc1JvkEN67dwGJ5gbLcIK4ieVmzDywLxQW/XEHX9dayLh48WmNNm5uMe1OMTU2VdisCmQp6vmShJEsVy0Spavo1sw+MhO45xwqSwMyC8HYabnwV94pcf9jY7JYb57Lt3wxU9JkVIyTZW7A9GfcX62sGP4+OpthOWlSszxc1qJoNTi8eIYpZaLXHZJLj6ORA9hjMwsizmYjlKN6i6ys03VoEM+ONVB6b7/fu3sANm2tBhj06KffwCZK40Kh2HYNLJZp2c4mMDggICAgICAgICAgICAh404llkhQESWDzWzYVsdVsst+RJIRX5Co12sg6kZYiV/mQTnUkGUp7+CVRXUryFyFOU6UXU3VKIoKFplhwjqRyUZi/blk1agtJGymhxc440sfxfEaFktD2nqqmjpaFBx+z+W9D+wLzmeVLhKyI5QTFpJDlA4v2kS1pRTT6wmzGCRr9tfOnlB+r6yN5tKqP3HWJcvGsovWD92X2qkMSt0YKGFkdJZmKpkkZp+swQn7s1UwCkDYVWeYsQVyRKxIiVKU2dYvjwznme0/h4em5yJhH5wu8cPcuyu0WrK9mttA2Rj7r267IlMzSYDrlIskOEWurpdpWRHORJeeLDYpJgovzCyzPT0WIPXX7GSwW5zh7eF+K5eX5ButlLTIujzPZFsznVJ8bEcdCXAmtN6jEcxQUi76dHMyUkr5cXSjFXpYoJBgdqc5zkaiepAmuHB5gUuR46dUHOD9fIK5a5G2Pw2KCdzz9JK7fuIbP+9mfj+tP3MKVW3dwcHIVSVIgIcMtA9naFLdlZX9mQIESSc9la/Bp28LeMesG31dkufjdQeI7lvoOpLF7X1zyyPh3iDqY6laHf5xMHshtI6w117xa8tKHHieVPeltXs9cAywGtt6ssVgucXZ2htVqi+VqLc/rqmoUkJAY06l19dMb94rE3NGjUuxz7rjik/58XIqE/Jyt/B/y1gqJMTbAttCbNs64mGx+URVPpSfPRDW1V07LooFF9ziBsxjTgrMyxjKqZaVRlS02Iq5ZnCzCtKAamTYzQJLn5o9OYpkFRV2Wgjxym0pznySgFSRsHdnGIodUyFtVTxK5JJeNjGywccrnjGvUDfNYhM75TBWvt8ph34vgdX7nvu+8RYYZVeyIXO97bAX1uKbNzsPtNOi0p+xcn22tMjjGXjdHdXvPyGWqwc0PgcpWWjPEVtwwiqQcFgHtxsyTyD6opZ8M5LmzSRk9KKbZVqq/e6ROe819O3KnM/tsI7QVHPP+0dx7YzPX2Nubay5zP2Kh0c1qg2pTYVrMML81wenZBBfnjzSWG63JzmxSaDfvir6SUKbafFtGiJuV9oj9WYEJrSeo0tZcYj/SYskK//G7KZc372lZj5JWESpuaMEP2mroBinPbgsG2vQ2crkfeaXbPCXZ76xnGAShTRH3SAY2uk7K5Uluo8ZPMLviSj/HtuywXhlhvlmt5Te/XF7g0aMHKmL4rne+A/v7cxwdXMHefIIoLhFFJapmhXbLtVcjSTm3LHND7eE9gPPW+cmbN7Y59ifJFGk6s3t6V6JtN2ibVSCWAwICAgICAgICAgICAt5qYtkXKPMF4IwE84XgfGEt1hKSVQQfcpm2L6LUiCLzGrZ0fnkIkxwUuUyFlnl+OuHtoOQdF/FzvJVgqchUQxqxfakQ3ih1n+fzBJX9fTBCdYX/dgdXcTIVOqJVh3kp6/sNlYjeHsEUwS5nfFd8y6sR1Ub6WbqKUV7mN9aXXkqv9r93vetT9v3bJHyonESDStI2I+t8kT31uYqwmWKN48T36KG7LStsy3OU1YUIQ/p3khi5dnIs5TKJEF6D9x31fWZ180zBbMShFcAiUSE6zKlf27ixgn59IwX4ar3B2emFuy4qzjNcuXoVxWSCzfIuNpuFiHuzADGfbkJkDmJMciPLSJKQHCO5M58Vup6j7T4mZAxJYPrCWiq0ZgpOHkMabwYEmIIex9g7mOKgmOH2rat48pmncOXaFRxfvY69oysopnuIkwKRiD9HDLdMOacKtlL6uUhlWGE3I3sdLeTHyBGC8i51fac/vZ5d6WBh4ebcJeJ4R1Abv+zmw2he7D7mFdA79evugF41PUQeRCp2DVXIljYvm5bUioOlUvMaCSn/ZBWxdOtAqmGnipUS14me3Xq363UvXyzTEW7jZnr1rBXZswKfDBwxpUE/OK9SBnTMX1yF9dy8tsJ+po/NqABOzd92UshAQ9655ptsbSPBvFyW6JoIe5l53vL64jQTqcwgBCc7A0oqhOlIfguQmcd7nqfI21YeyVQl9yR4SUYq88LIUyPujKDWCLmCmmyvPNxd1oCW8WCHQeJv57M8ENGjId6pxL2vu7O3IDE99lO5PG0uqciVtfC4nn0oRAf5LYt+bLy6eaeeHsbVv9xAcl+hrYQKwLkCicpgcFkMXMsKGHC06JHu/ez9fuLn4eDbbaQ6vZmZYaCzKX5Dv2oGChpMpiyuOUFRbOR7Tr/oZLNFnDAo4AKFPNkuXKc1SosUXsy6dEVNuTcq4LMjgc132wrA0gi6qmNlOWj+cU71JMuNZFbxyMF73vZJxV0kCLaIk2XCuHXiCh1y7pgligUYvPpefUkbFan22W8d8iLGpOY+ZOtA6yNmwLPE6dkDVNUa00mMpt7DdMoifxnQF0jiGfXgSOIWHdXJ2QRZxkwC8/6Wt3hbm1nPYG9Dj2XuezW6ntkYnVk2eVudgICAgICAgICAgICAgIC3glhm0So+XOe+WpfINJ8abCxE1VK5RsVgIaVWx3TnmqQCC+uV6LIEbV1IGddQedy02JYNtnUtdTIVwqZsc6SFU/GRMDCLByNs+MhOAogkDo9DewU+jNMrVkXwHEtqRemsiJcnFUgKyMvVWFN9zpNpPD7JAh676XJMZxMjnZqVCgKKUKI/MckQsTA7KwwvXfQelyIbLb/ePuadcl2quHWjWVxILRmTDDD9Iq+h78y/s2vNDzeOeZ12XSSRpQhs5aqpInWreGUkYZRI4fvy/Xu4WCxxcX4Xi4t7uHZyHc8+/Q5UJyzYVmC13eDlB69gsVrK87iUotJIfusjUysmLPbEIogkLjTmrSNeezTRVuRr2VXY1DXu3X+ErmqwtzfB9etHUg+/770fkO3GavkvcbFsZHXA4oyI6LW60Tmm+QRRboQL1Y20xuBrbzbFzWuHuq4Zlctdh9l0KlLUewNvN1ucPXiItm7Q1xW6ssYeCbk8w7NPP4V3v/OdeOLODXzBF306Do6OcO3Zd2Oyf4gongBxLqsDMU20RijXskLo+hX6iES9808WOWhzxbywzVPViMCdJ7PI6RH1N7jrDgTxY+Tg4Kf8mKs2Sc/h885D15/DkdkKCnjizxowSs3v0TeuP5oaVblBw6J8Hdk7S7E/PDzAZmtFENm/RsJZUMJiLzuPZZ1GQn/7C8lIW189IhY4dNYW0WNt9oRy7z5LYpnEcO4KXpK0y2kXUGQoikJ0o9lQ0IebftCmoeWLyt+cczCnXQELM3aoFjW2ceNO0GO5qPDovMb+rEGGmdTPe/u5FK4ZFaoiFbk+aFFh18MQBgll2qnsz6fYKJOBBeFKdFGMnOtCytVIQTOOO9dhw3lflUZokmTk3si20t6H1g+0bZAYmypZ9qUVvfOWJLaL+fF25LDrKxKOcr2R1QjrENIuxpOXfuLs1MoaC61dp1jmucRzSj5tBfi4b5BYrqhYjlDTJsTq8tn5R6Sy3885J0y5bJYezDQxspUFSG0ss5QBJN5ITMnMDATaGqkooy/iV27t+hzbTFsLKs5nkwIHxxOpoetNJeK+3FRYXGwwKa7gYP+a9u+jwxMk2RIXm42sO7j9ch/h5ZKYZV+mrsjgtgR4tnVJAtr2NFqEcF8vEruevjMi1TzGeV8zL2TaLk1136Hdinno+/XL8aWdBovu1epmBjt4zbulbEE+60UFBvpeHvExi0rKzN4KuiZpJnPoGCu1f2+PGRwJ2obXRAK5QbWtULbn+PEPnyJJMiwXz+Dw4AhPPHELN65fo9QfWTZF3FfoehYcLDGf8pgTRLINovKeXvpm99J2me6XJJ+LyT6AMzTdVgT2lET160bDAgICAgICAgICAgICAgLeJGLZK/CMsBhVMvLE6uAA4BS8O6HuToHIB32nbvREk4gsT2ZdotPsoPLRlGekfUYqMVdIz4orGVnhFX5OmDaofXfqyp33qLejEPnrVJeDQlKq5V2hvuE8xv/63rgsnx6rVUcfcb8Y1MWDwnBQLPvXjuQePHt96S2vulb7vMGAVzo7r2F5X1NZG4vkI0FI0muzLbHabOShO51ssdmUIotI4k/yAvszenZS3UyiukIjjs4IU98WKYKH4d6pc82GgU0w1TN9c60tJDGMwMqQYTqfIiYJQtWo7EgGOboj0UylyvezjOOXDMQySX2ps539gydNrTCeG0P5hzobChJhXYe96RSzyQTXrl3BzSdu4tqNazg4PsHewQHyyQxJShKTwQE3f6lSllqZVhgsXrcrzOccVUdj+pjVxeBb8Tp+x96jeOjN1/NWNnsT69fXKpR9H41m2uUX++VSwceRt7Oz5+hp9cKXbFJ2Cn77iK0/zX+n7vfqS9/OgaR8nT1BKl2nePaKfTftLxHMnqykatXWfDS8SHxaIKZRTIh8spF4ntS3zAapoh3J26UkpWN0DQuzOb6Wtg4tid1Ovr6yk+hIjrvFy39pvULFvLMFkebUvA5EVpJQLNIYBXlK2i3wXO57JJbNzGJn5zGMCINWDCjRbkO2FGZ9sdsMXaDBbVC7te9VzeNNY6CcXYDNKXIf73+1a7wl2ZoczjuaO+a/620zfDDN9jadQ+M4pC0M4QrCK5BVXFXWEN4TerdZ+73R751WuNAVNeS880p3HzTUPOB8s8KtzA6xY3Me+CwQKzCnYpwtLR+M7LUCdK54pA/U+Snt7Jh0jSpA6grXcY90wY9IRLtf+2alQcWyWbA4+yIFPnxWi12gCHNzR1Ef6PcMEI5HxxdNdOr/3T61G3fz0Pb3GSsUmA7+46aC7pwfdMvAbA+UFffvNTabjfZ1JlowyEL9eRQx0MY+nOl7fcNgqHrBCv+Zr80lCyZb5wyAmhI9ICAgICAgICAgICAgIOCttcKoR6nT7kneFzSyYmUjEpWqXT2Amw8uycyqiUQAlRX9jIHNppP6q6SMTtykeXuax6cV4FPRJ/qa8tX2yLIEsxmVV5a6zrbUsal2Kbyc5iZWprpYQj1HMDQ97RpiZj2jbFvkVK9lZgXA9pOIIhEaUWE3by3VWwX/XKp+YpYeSgOXkLiXEtR+yQunX6uRIUOaNq1AHAnb0huUxAZbKnljLA9lKtcmWeGK7tELk4nNVPmRDCzlAUtCkOao0lb2tAAwj2FT7GZScpK22zqbDDaIRfLOLxZ4dLbAg4f28+F5gxdfXmBSFLhx7QRFkePT3v4cprMMz798Dy+9ch/3zlYqHmWFwJzqjlYVTFuPPCncKUBAYiXP6YkLTCYxJtMEh8fHuHb9KVTlCi+/eg9FnqHIpyJN0qhAkU4pV0SNyhEaVhyRxyG3wUJrk5ZjlqHtCjR1j5deuieifLElKU4bB3rnJkg5fnmBar3B+YMHmjNX51PMixzv+cB7cPvJJ/D029+Od7z//Zjt7eHKjZv6fJpTrZdJacsJ2TdbtPVCyvC2OhepHEWVkTLe05iTyynKfeBkR6CSGDNTWU9Ujkm2S/YU5r7taDCqJh0zJXLZzRnPuDsizk6yI+Od94IKBzq61r4fcSmTMGIHmqWH/aTXdommLFFutihZpM8V6iNJxQJ+/FlWFWpmETg3Xfn6OjGwb7vJaHcX39FvnapIV0lNxBWva1B4W7EyuSzLr5tEd4SSvuVdh2TDDAPOafMP5x6xXLOAYI+zBQniHnFKf16KPenLS8V7ivmUKmfg6KhAPctQbns09MstY611nn+xWasgJOdJQ3sDei6n3tKFhSLpz23kc55nsss4nhdI0aJot0hWmh5DUc1G6zNC1dle0kWJvI61Zh0hv5UgPEJJUpwE92BrY8UxrbCokcHMmOC64PtparYbu38MiXuJ1OS5aLkw1FjzNjuDWN0FIqT7dsEr+z6zIbhnWlvNckT7hxSsziObBswuEOGVyrIRiSL53NMapHDZBrY/cLypprbpKXsb2jz0EXL1V6ft0Qqldso20CpxVihMWijrGsvVQl7qT964hvlsiv19BrvMboRe4Cw+d+XKNUymEzxc3Efbb5EvzZpDM6qvpX6W576zZzLLB7tm7tUSC2t8aF1iAUOpwjsjyDdbzvpSynWOuQqrJhlykr06DtcmLTOsoxfrFmnSYZpbIUEWGaXaecgd4P2urC0Qwj1LhLFVHKRXeJKy8GyjIrEMCuWZK/qpjBQG/mLM5hOnxJ/YPaq9wOlihe6VChfLC0znMxydHGsOFcUR8gQ4OTrUui+3F6i2S1lplPXSgikin1PZE7EAqvy343bkyB0UywEBAQEBAQEBAQEBAQFvIbHMh2SfGTwWVj6e5C+IaDHlmSmBxXFJ8UeymWQK1cr0wzQ1mRfAelrAfnoVpTgOeoPygZx+vCy85a0oXCN8oT7xK/rpSlt5RfLwMpJPvsFRPKinmbrOxGHzddVFuIY5/mbktTmoQoeCZiZu1Z9cEb3BG1q+zKZetKJn/pjmYWvp1o6EGyTh/CwJeMrOrINUqFAFokwRyP71JLYpUX2xs27wXC7psVzV2JQkGzZYLRvMpzMcHewjz1Iplo+OprhYrFTsbrmpVEBNKj+nBDRy3VLzd3bQigAMnrgkxpQmnxcoJnM0TYXNhqnYHaptg56zTHYBFhCwwIQRV1Lt0buVx5Gij+NOT9JUNifrzUYBhk1ZWVE5p+4mYZhlNertFuv1Rmnp8XyGPM1w9coJnnryNm7fuY2bt28hK6aY7B0gpklvn5mCVQSs+SqTVO7bSoXbZPYq5+iRU60r2uXnhQlDB039oF4eFJwar9foSwcLhJ1K2eaw2Wl4DWj0GsWoHdZcvHerzvsPO2X7zmxlUCvbpLTibZwfJII5zzXX/Z9b57/MOaTvWEFOWV44Dny0zGzY3Bu2BlhMztSnNj+cOnN8/V4hyzZK9dkhas1aIEm49rgvxFLMk7yratq90C+XGxQDEFQwt9ovLGPBfua5W8MyCQfS1vy75TXetSpY1rS2xmjRYMpds2yInQfu4KlOgi5L0OQJyizGjOTb0HauByMON12CisEreouLgDeNv3rbefl6tbIlRzgy2SmWB8dlt/69d7n6l/uG72xHiNpnXTt57NeZVq+da6N54wKA2jNsK3XX7dSzshCK0cSt1ripsncqYNsDqKhlP0awuqt23bLdUN+Yrl8BhVH2iF8jXjXs+8myIMw7nsVAW3pbFzmmk4nao4KPLeenZLfyWm462iWZXcVYEe0MjV0f7nzJBzW49+mXtYff00mmepW/vc9gHL/JICPXOoOPOy9nCy4w4EfVcu0+QwsQXV/ieswpxn1hVa+w3/noe9sja5MPXrLIK/dYBk8GBbOKCHLMLRtgvapFtG+2K4UbuOKLKYunktSeKJBjXvWuDgAV0zU9yGvNG9r+MAhphXV3BjwWDAukckBAQEBAQEBAQEBAQMBbrVhWqrv99GQxySJSW0rjVeG5VkWNyDT08lc21SEf2tfbGlkTo9zmIjhKeqTWzgdZ5C4JCFf8yJGsIp2dX6XSjumlKd/RSN7NZCxEKEdAnkaIJ44oVj02Iy/YsIakWlWqwBwJHxFO8kDt0a5YyAsoWTQp6UVMzmdzVHWN1XKNpq11/aYSttRl45aNXBBE3lhXkhSwz1nKsRX5Yoq/aQk98cDOlO9nLuoMdReJPDefY7bTiE2SXvKTJhUS0wuTfWRqV9pLUPVM/qUtS7uelkrwBtuyxIbeulQepwXqqsNqU2ocqrLCJE+xXD/CyfFMqebPPXkLhyScJxOUVYPFciNyh/60qnNG2wIqLeNI/UgyQipjRJhN5jjYO8J0OlGBNXZQ06ZYn5f47n/2L1HXFR4+OsN6vcXRyQzHV4+xrTY4W5yacjU1xV/XGgFYV1QRbkyxGlmaOEknKkDP1xustyWixYqySMyKDDeuHIss/8LP+yz5j77r/e/FE0/ewd7JVUwPropQjuLCsaSUobZShKOp0bUrdO1Gyl7agej3IpeNPLfx2qmHdwSzJ4GNLhsXmtwFHzyRFV0inR0NKYL7UrDGzjB8fqCRR8SuL6hmYlVHHouo9h7Hnk20IoeewOfLCD4GdFgIcqu5LbVqkkg9yb7f0jNZ1jPmI+BdM3gYjoMCIq54mxHiVniTRJhFBaw0m0hpNpOWMiy4+P9n7z/gblu3uy78N+uqb9v97FPvKbfn5t7chCQQQxWNKFFQRPkIChrFKKIoIhZslAiEqCBFQRRRRJSiokGKEDpIyJ+02+spu79l1dnW//P7jfHMOdd79s3dJ5wTyGWOc9d9917vKnM+83metddv/MZ3CEtt4jLXlmmGNcoixraIsd4Yl7fw/UANxSIgyxutywhs6saqggblKJUwSIHNEA0RdukOUzoycyItUsxmbG7G9Whok4JIgRoYpbkSIKqO2BnnOXPuLfEwPOd6skExHUkNjUo2awxbZYRtk4lVTCb5msmxJsLSz9eJC3oNCnYiRHv+K6XjWfiERuNMN6y5ha0iwhJgpviFJJIlnniVd4iFAgrWZDetC6FsLHqKlBmbtimx5gksY150CTuJ3MZ/FyeZvOuM7GS6kLkuEmFIguzYNptUMijS+iMHgnsM3cZCUvA8yLAWYoRJAnO7c58VcsSRG7qF2b/rnMzrLd+3xIMHZ1gvN3h4ukKa5njqqVsSebNRisnBFOM8wnx2ImH3bHyOzWaNssyQkmHsjSYZgWDdOrp71CLDtXDbsH4BFOnVlLA2rAVfggkOvharB6waxDFHnCdRqscxIWPJk1TnSOE5jJq490r0WWImYIPC/sHPgao04ZdIInMoG6ZmxwwcxWWam1klE0PuYl6F0diajAppU58pGXi+uK/PnTydIE0zXL96RQx6fkaOR8eI4wma3ciud8r1kCDPZvocaepcDQD5ObJTE8RBXB5iiCGGGGKIIYYYYoghhhjiHWYsBzE5fNnml2N++U2FSTD6qDQRspP5H7Vf56UWRSUhTc6wmPfRldhJAC0WlgKEeKDGKZWYKSarHYfEitY8ahK09D+WGFMsoDNOjt5O6KFbU+44HZeJHmpUxtJwCscR+Zs10niHxXSJxcVCLtntdmNOzKR738D3bJukeRsuiRFyQTtn1Pm/DIp04tdSYNBxuyjvoqmEG+orcgDy2PgXEw/5vllGQZnKhzWdinb29zTLkKYUPypEhUmRdp41CjZtY/M2vkxC0bDAeltiR6zE6bkEi2xUYrma4V3PP4sbV4+tgVoUYbMt8fDRuQRqXUc1JDPOLY87K6zbl643RY9shOlkpuMUzYGO8ibBalXiEx/7FFbLBbI817Gm2SFOTo5xvojx6PyBiWbOGC345CaSM2+7YSm5OWDpqMzpVORxLJZYFlvs1iWa5Rbp8SGOn76Jmzeu4X1f9X48++wzeObFF3H1xk3E+QTxeO6CqxV820BzkrCx3VYC867hjaIl2Q+c5J3QIjyvrlf9GH++O4P5O3ee9yXhPm+8c/D2oLjhz/3f9VTkVmsOb+lm93bVSEyjWNi1XZN8J4Btz8XqAvOutx4qiuoN8TY2p3VtuJ7pXpbjtv8+JhATC9G6XFsHvzlIeY0Y1rvMXNKtq94Z0Dokn99OkhUjmTmibekqbGrvquckFNaAlJ33NAcrva+4szIuU1h2HnpKmTnBKEqRsmnffKT3rEsTlrkHBQxDFFFsM3HWhG/DUtD9yfUxznNMRsSl0MVLb7K9NieDIXUibJgQoUDPPYViKlehu7zl+u3RS8SG155hIj/TcBnXtTIzzpuhZhsM6AE+EtzGDY85sH9tn/AchTODDcFD5IKwFMTVtPOk49X3fzpWWiIkj6Umz5x2ZDbSc6t66zx3RrTtl3b9thwbiq3qaBijjhphSKzKpOM6t+z1UD3iXHslACM61Dlna1wsVkoc7Ng4U2xlJgfmmB/MMD+aI8/YUG+GYlRhlG+QpVt3UTPp1rm+Kd63qO8egiYkeuzjydeyNxM03j/FXB6X4534maY1bYkg4lrMDW+se/KPbS0YIih8Jugzy4z8zvcP7PkgcHf9AlLHY9C1T9a9GsUGp7VjmNLM5jfdyDzH9bLAZrXFZsPkX6HP1AQTZNlIiQU63mfTKcajqZA/ZWUO/jw3N7iEeK3HzKo31ATUzmuIIYYYYoghhhhiiCGGGGKIId4xYTmoJEH4JFfVGhR1LkmVJ0uwMQGqtTvKnRlca9ZMyxpsGVKC/7Ec2MReiql08wGxIyTS1G5iH7uAQHYzS5IDuaJruEeRzBuFuZix3dKlTAmz+3LPc4jpEnZXqg63abBcbvDgwSMJngVdzdghcwdhpPLkUD6e6Dj557YhnY8Do9dnzkIimws0LOsvCmzWMRZCe9A1lukYG3JxW0HSDo5jEPi9Ns5py+MV4kAinrnexLTm2JTm0uQ5FBVFH6DRsZLpyUZZEVbbCtH5EtO795HudhhPZ3j3C89JVL9z75FciXT3NbtaYnOxKbFYbxDx2td0gBYS5OiAXS03iOMLNUuryhrzwyOJJNl4joTsAJZ673aYz+e4/dRtTM6mWK5XYvRy/OhyLMsIREVzflDolEveRfmSvGllNkoJZ/mYQuAUTz91FR/8wHtx49YNPPvii7jx1FOYHl1FnE/l1Ka44hfAXOZ0KjfEP6yxq5k42Og1AwIjNOwzZahzPAbdJcx3b8vVNflry+ADAiD8PpgnO5DFfgTxtuUldE0vw/Jp/xLQG+G+DpESWCzW6JHnycSOoV2CQzQgYZQMaakBVpqfZgkyJjuY+LHLpcd2x+TVBOH15ersdEfN7ZBwCnq7kieejFG1AZnKnJt2XKpIoBAqa6+JySZYk99tXNvRqEKamrubrnkKbKpaII7BhyHP+NoUlxNrApelYtQy8cE5y4ZmbOrIgxKvWdfaBV0tMRfl3Y3NeUuquRqspYayoAjOh9aswiDOg25miss7YJQkaJhrEYeborE5r60Robtk1UQQdj/3lFGClB0C1bvQEmLkH3PcYjWZ65zL3X+dPKw8E+diaJqo5BaTVFZVYDJ/EDS9uaVXcWi90a3ORzXmxKZzOUly7RkSxltxtEFJXIoaLZKdT+59g6I2Fn1eUsw0kZhICObEeM7eA6/DflB89makEp/5H5El5Ac3kHPZnmsIiQePzjCZ3sdJWeLo5FjnNcrnmE2ByfgCkzGF6MTEVO4zjqsgziagjTqHdGj86jThHiY9JBobcrO9kV+lOWGuZUso8bnGUw7NRoWwcKG/xW70aDbUjNUYUJ+RXt3i+BW+DpML+pyrdxiPDSNDNjfZ78JhpI534s8YxqEXiiTFZBxjvebnj2FktqxWQYXF8j52zRrb7QE2m7n28c3WcCJxvNA6n8/GqoBJkgoJERpMDCXWgHWIIYYYYoghhhhiiCGGGGKIId5BYdkUJLm7KFSypDshG5JCjDt1Aw+U1jvnhZqHk+IWS/EpDtmNX3qLQu28DFuRUMCwL+CJl4nbF/Id8qRBxlvqpeG7DrFhIoGJFxKlVP4eyqKt/J/MX4pZ4bEScVjO7JxLE9godOxwvliJOctjyMT8tFJ+lf57cz6WFNNRZy5mKQyXQNPunpOo0JVBi+tJVzWdjmQDNxVSNp9KUqSjqUSOgMEwB17XpNC4pIYfYPlzHJEVTMQGGcEUfvhnNrzaoSw4trWE4W3Jn+7WZtM74jfGLIuOcLHZYLFdy5lZnp/j5VdewYfe84r8eV94/Z4aAm7ZRHBXCQuyXK5xtlhKhCDDebMxtySbvl1crMXHpTA/mcxwdOU6kmyE0eQI6y0FsxWaXYGjoyO88MK7MH/wQGNdbDeotyu5UYttjc3GFCmKgzHt7jFF4BoNH1dViHYFkqjGfDLCyWiMF567ha/7uq/BjVu38NL734+jq9cQJWMgoqBMdYcl5JL2TISrtmjqLZpygabaOP5iq8aNrbAsvEXgXetCtu7P1hwcRD7vxmZsU4p2nLfGR5WDWU5hF577a6l1g+7jlSUm9pG5wZraojX8787TDqJo27jN+aqBpxySNR1n3B3IFMm8/J+OcQqygmqwiaUc84aK6NM1gii4iwMGxgQ1WyN+ZOG9KMpKn3OhLKYImBmqpWCSwLi7LZ1EYl2MLLakx3RiCaLRqECWsfyfjdVsvfExlqSyq5KStRxHGI9z4VjY5JPN3qiMrlK6lRM5l5XESOhYrQwVQnyN4OHG3dVYEF8TR9jK0MpmkQ0MMOGObK4v4bnJi7YZNhGCgmMhSjgKJlfI46VISRFdBAk2/wMKpSIaZOME2YRytTXm42M3gcvr15ICc5uY6/uPW5a1zQEmEbS7BHSLu9OZAOLFoDBMfI1Quy46MzFkDuMM0Y7JJiZzMqyLSgI3j4d7CA9nWzHZRXKMcX5FOyKmJIkxySgs8++WyON8Y7NVCsYtQ50JA1aFuKqr52utEK9BBBCw3FTYRrWqLLgXxEkmNNK2rHDz1m1keYLJ6FDi93z6AKv1GbYbMpe5F7KxYKa5z32pBh35hjwKn1u21yetq97WklVjiBntHGh+BnHOqg6GbmjuBY42EZLFnd5CHjlLOlSxhASjrSm7FmR9W3WPO6FZaSLnM3Eo5vDeVo3WT1wwseMVHGzoqCSAJUNGIyZVIiRTspIjVYPk6Q7bbYkHj5b6TD4/b7BaZliOjzAZHxoPunQ8lLA3bHg5w2QywsHBCPP5VElK4YsGFMYQQwwxxBBDDDHEEEMMMcQQ76iwvBfmXlTzLzpRg9jhZfPWN8wkZWtQRQG6+wKusnwvm2aTOJYmp3VwQAf3oH2J17tJpzZxTe4ycmLrzvFsTjwThiUCpyYWBzNi98Xfj70jcLbnE2699mfuQvPnXh6BwB91ke9y/6w+U1NiHwfbRRZDlVL8AIqCx8u66Motb73GWj2BWrxo5/GauBxL/DFHKt2pdPwF/rUJjRSMbJy9uZYcml35vfQoCR4sWU8lTpyeEQPS4NXX72K13SDJajnjyoLCnLmFZ7MxmmaEo4OZhJX5bI48zyUwr1bnEjSOTnZI8wRHJyfisC7Oycgt2qaFWZpjNj2kzIJHFxSYa2yLUskCqnVW0d64sLLDKDNh6LDOkScpbl65iqev38DzLzyPqzdvSlDO6FKOc1NNe2qt5oi7uinGG/LCb3tN+symadeUrtyebbgtr++5kEWjCL5InxPqkGbN+VpRti/b7InKl5rcSTz2ebo3n/pqdh+lodnRHl8rgoeERu+pXYuugCWwtaQmaY6FCc9rD75tfhYOuVsHljAJ/NyWBu2NJSkmd2K6CdL0CdNV33Nu98bAmn1GxqiVQE5XvF7R+LfkKcvnaWPMpJU4zMQE5InWPF3OGeeJRD9vaCix2CDtFH5lrJZwSDHRrnfVVKqykPucznJTxHVrASOGj0bFSgaKtQ2EwWiBMMRs+P4n4dFnIH/y3IjnkfzLBBsdtlmMJOMxmLjKaVSzcaYoOOY0tr0n7IOBL9E1OBVeo9uRbVzcRRuunc2Q7hh3zuwNTQONy841brOU91OYj1hBQiuxM5FtmrlI2/6xs+la88KAlaAgbXuP7c2d8Mo9KaQEwrxQ80etN7pn6T6mCF+rUd1qs8J6vUazy3S98ywXsoRs+jQtdJ2tkZ9hYEIiMAo2/bA2vIYgbIN2aDbOYU1YwtISjuSAmyDNOWFrTGPG+aaKD2N5k0+tSho63bnHc/8VryRUC6izpfZQ7WWOVdJ15UWnYJ1aA8GkNt5zcJVfWiL+eWTzgSLzZJIJlcF5zHWVJWMk4kDzuLnPeYWQrmGtvNlisRSfumkONEPz0Q7T1NAwQwwxxBBDDDHEEEMMMcQQQwzxDjKW+6JmIhFzWxcq/2Ypu1X+NnL20onFb+H8khtn5tSzL+H2BZzCopo3uXO4oJsqiYRUMBGLHGZr4GVl1Or8pS/JEmTrnUqnrfGTCUgstOb39NEowWgy8fJ7E7forKuSHYghNjayC1kSjkJjNStLl5jktjQ1KfNmfEb2cJGYIoGpXSoL77vhLruVKdrxmBkJOcNJKnQFXYSy/VVbpFmOiRrMUdChO5rNyWoTVynI1ZU1A5SOZuX6/Dt1HzYmJPZiWxSGwCjoLJQFFDEyNaUrqaHKwehKF/GaLPdfU7yjYD3CeDzHcl3hhz72OZwvV/jL3/f9WKxXePrWCAfzFJN8JPYsm0DdvnGiBmjXrlyVy3F10aDc7PC5z72Gz33+VTz9zG08/ezTmMxGePHd78Z6tcEnPvY3cO8uy7GtxHw8PsCt6y/g7OwcX/jMXZydrlFst+ZWZPl3aiIRBbgsBcbkNyPCtTnL0xN84IMfwoc/+nW4dvMW3vPBj2A8PcB4dhWIxz2xzZV4Ci8lERgcnLUcythtAGzFd6V4RNFZTtbAcXAHMrnYJqiGZn3+4t5Y0gRlc7zaPSF54nbg1mHqcytYGuXm9E6QffRFOyODSzaIhz3Hc09c3jWOdSAMOqifLjBLZG45s6ZSBTcrV41YudvCBX1ztUrga/2xdmx0EvdRN8K58BHBYUtJdZf69QrroJKQbNzvypoySth317PGifPdtiC9DF2yGyIWgJyLmTicSSKHsnEBnA9dFBqj8YgiWoLJLEfOhn5cs0GMpUOdGB01mKPSm7vJmw5UHg/xMI347Vm1QoPScA8uItPNzERVTcGZblvOW0RYRztcNBVIGNhsbZaJwR5FmJAHLWHT3K3stakmg6rUMFczARRKgI13XHbiNW/JaCbmwJNtXNcSJYXB8aZ7UtltHgqnEMRH5zLL3az5a8gRQ/Y4OodOb+fhy6EdRFaNO/cP7rU23nKLT0fGVGdlirr5Gbs5bsydzDDEiXHfuY8KgQMIu7OpSuP2an7a2BCZY7sg15Q8y5pjJdE/iFE13Ot3GEks5b5U4MHpXSQZcO/BXcznM9y4foJpPsbhfIqiGGO12CCNKADb3JfgSiZxwr2R+24AiRiUJqxj7rG2R3tiMFSx7Bqsd6xuqbyhZYHRiMIvcSscm1TJC/K7+Vk2HbN5oDdr5WcA51NmuJc652s22K2sUSU54TwXIp7y3Ju48nOu2SEfV4hSm49yTAfaidAw+01jzeMOjMcxxuOpXufm7QP9MolyJWAuzkosztfYlXSDW2Ndrld+Ft1/7UzojOvXbuDatRs4PJpgNDpuq3yGGGKIIYYYYoghhhhiiCGGGOKdcSwHB2ZrnHTHo8RWc1GZQSy4Zr3c3pt70eHFm74kq/zcRJzgdDOByt6EQoq5w0yU6Pi24bWtVJu3IB2IzdxQBIhb8Tg4lDvibRfBfdnRK8zJGxzVoeS5dSxfZt0KT0ERqkMkGC7A3rMV0ByFERx+wencOkl9HIW7cAdo7yi9EZpLfTx3Oipb1m14j+BS7S5WW6LdOnc7rIKhNZwtKiFrJ5fyelPi4dlCwjIxFcvNCqv1Tg5Buq1zXsN8jNmUYsRYYg+bQVWbDWo6mpsa200h4YKNDymYj8cjieBs3ieHKSWllv9Lnc9K7rdMElD3YgMxCb3eQMzL6ZkcoMCZT6bI4hGuXb2CGzdv4vjqNYnK+XiKSDXdoQOaS0oBFisrKMV6e20TmrzZXbD2tiXy+2Dl/cf0G/d1DcIkUYl5HBaJIzJatIXjCXru+CDchsZ64SIZgSIkM950OL33DxPOO1UGpq5O2UVlsX6Nv22N1Sw5Yk5l4+Uq+aHf993OgegbhEhvkNYzWXfn2t8g9tdYTwLvjsknqho3BuGbtlUzd+r1iVQwFrAJgF1DTLuf/0kvZoWCbo608eti3PGuwV3AIcjd6cdgTv5G99n9ncPTwT7WkFPOZ19L3pjPpFE/T58KTdiH3B/euocvDY1EWgmS9heiMKiN0umq5nMtf9gSGGEt7y65x4Mz2fjUvme1ztYe/5rjQ2HcsSdhngSUSdtUtN0GXSzlMbZsfMdohGvazoNuH7dmdZdN9d04tIkJnXbnnA/zTe9BDETGJonGdq/qEpvtRolJ+dZja7I4ynNhMLilhCoXCewqFnCus/ActohC/iY0qQsJoDCvwx7ffq4xYVhXSGubXwwbF0MtUSg3p3LvmmjuejJK1zYs2oA0CXuCTxruP/qsC7dQXeMMdN0X9vH+FLJz0HHxf8JmsCLGklXbnAkeZ5zHhjOxzyhLFLCRLpn5q/VWCZmiMBTHEEMMMcQQQwwxxBBDDDHEEEO8Y8KypBwKDBR05YilSNTIaVwRQ0HvJ5ENNMSm5kwcjWOMMn5BZrciul1jTGYsaWYZeCkVROXlauZk9dkUFsaRc3FDI7w6wq6K5SikS5kIjI0YwnVXmk+GaGmN74ptrvLicWrsZ0pDexQIORvNDS1xkWZIHhMd1VmCOI10S9JOtApsUoYJ2+4g9BJxlUy7mMbHiiWr4zEnq1zNZEen5iakE1JuaDnsWERvYo+Jf+5SdG1C4nHTYEsWcExXojVfKgs6fOmIcxFWwow57MbjBNNphostEQJWKB6ESjrDK7onXVB8eL6SiIL4HNGrD7CpKjxaswkWsClrbAryVBOMkhTHJxO89PyLyLIRsiyXiHlxSrfnWq9bFCVOH53iU5/6BA4ODvGuF98ngfm1V0+wXJ5KBVtcLHB+tsarX7iHs7MLPHx4hsVyiVs35zg8GGO1Wauxn65TFCONY0xHuTiqL73rPbhx4xZe+sCH8d6PfC3y8QyTg+sSlWNiMHS2jrhQEzuylemgXPlPcy7TqRwk6DDD25+tO9+EKp9hLvK7gBUe780jzQlqvwsisrcK21tDtkJocQ0ieHA1WyYlNNTr4Mtdw8Du/oCZCO9n7mHNFW+cRp5wUxYS/LerlRztm80G281WrOyL8wUWizXW60IiE38yucDmbUHUNdas8ZTlknVHcKeJS8JzrrQnDIKK5TzygOBQw8zAmw23okAVsVqBrsxUGAC9DhE31U7CsZqy1cRc8JZZ80lvaDbKuA7IbqbgaIkmHrsxpinzSn62Rpc+vhTV6MxXNQETUUmMUV6aRs8OfERw0HkqdE2MqfAC3B+sc9soB7IKIGmYjt6gXvNMCzb1YxVDvROXmfOWbmmKhqlzuDPhD4A4ixGnMQoKixQz4wi5mv+xwR8FcDYPJfLHHLfhrYJEabgDnlaMqOFzI4y9uSgy2+hsV22UEOL6Fb4hpxIboapsv+EYKNGUck6acB40xvEoQZ3GqArjRYcdMCTNtK+4CEtHN6cGufZEjthF5+eENVakMC+PMl3gEvJNYOeuqRaYYi+zqWGM6SzFNmXVCl+/wL17b2C7OcTtp64jTcY4mh8hickTrnF/9kgca7KGeT5s7sjaBrqEtVSdiGGCuSc0HE2UcM5p/lCstiaETDzxsQX3SDZndMwFT4hJS77+KCeP2rjU/B0FcLmUGzZMLbTnkwfP5EZRb1E1JWI6+pHoHE315/lW2LEyJzVWOqttiCWyigkeB9+PCQETyZUwcfGeqJkd144g93bM2djc6geHsRjjyyV502tEBd31TFTwvVklkGOxXKMo7mC1mqlaQ+7zIYYYYoghhhhiiCGGGGKIIYZ4Z4XloGe5n9GbgZmWZGXPEgvkgjTXrkRbul0zclDZlMi+KPO2x3zVH80GRyFC5cnh/UJDKjXiM+FGgohVFFtDPDob3cXMm76IUwtyc5gczD02qDlCe5Tl4PANDmVvBGgOQHPA9Tm3waVNHcAEYHOCirUZBGJ3aPbdymoORidZoG/4m3SmxvAe/n8tI9fEbArZNXnMKt9mUzN7Dx/KXjM1Ez+Sx7jdQjNAiTlszlVSbKP7vES126oknKIzXyyQFaw5GMWUHAezA4nFvN50wJlIZ+IKH8wGWhfn5yYMJxFGowyj0UguZ95HFjNdzcvFEqslRU+K45V4pZNphpLC77Zzd8uxnKYYj3JcuXKCW7du4dr1Gzi6chVxOkaUjXtM5e48hU8hKqIJPGX+tKRIcFL2xZpLNOT272aE9ckYbJnBkdt3Z7pDcfdmWnfv9fpuZZNme6yMnru5/5zec81ieuk8XahVAsLWARMt5G5TYObYVmUlsUsomZIJALufYrCwB/zpwrCtuf7iM2XOXKDBFmvvLTHOWb3Bqb9/3N049kVlRuDrhmJ9u8+HV6xautnpmDcHPrEokubpFg3JoCCwuiOT85DnT2GZ93JOBeGewihdyRRsgyBKgVPYBDnau+uvRyupEbXCsqRVFxrFx7XeeN6QzsRgvoIEPDmNezOAAnLPSWzuVyIydkrMKVHhe4dYw3JQ29ykoJ+SO+zjreMLexP5zHqqs4VtwzVXtYve0tXTIIZSvLTHybFdkVdt17fNX6jigegHS0mRy2xIE2uQ6rkyHXNATViaI6xXrxoJrneej/Y7bxzIi97OXn++zwvtXfycqCOJ2tzwyASmS5n7C3ezPBthOp5iPGJyi+fDTczeXUgioYk0nL4kO15xV5Fgf+fYhvFtN4OAMlJzPxPPdb4+9TXW7i6Wk5yfPZrffLxXBwit43/Xz9Dw0sT1rqrC+wn48XLM+xU3Ii7pM6Yth/ElGdan7XOm+VNgZuNbw4Ew0UfROg7Wej9pXg/uA9tmg80mUaJpEJaHGGKIIYYYYoghhhhiiCGG+NHEE9uUTBgNPk37M522dP6aypGIt0yvX9VELlLZl2aa/SbjBOMR3WzeRI3KC79Q021Ifi65zV6yL3GBTjw19OJ3fYo6OZpdCvZ/2xbmnqbeISdtmiJPMuRxBvrCdiXdy/y27zxlChn8ct14M8Ge8y6gcNlobjTKkbMplJxoFDsoCtTucg6M2f2Ga0GcFlagsuNn8CcFVonBLiIlKd3Q9jOhc5noBonKdGNbM72OxBCEwk44NCG5QlnRXbr1JlkuxvEcnbXK8meyQSdjCrIUrRz9QClIvFdjokpUo9xK9EBM/muDbb1VE7OjowmuXT3E07efwXPPvYT3ve9D+OjXfCNefvF9yNM56iKW4/gzn3oVn/nM5/CZz31OjOJnnrmJ46MDrJcbnJ+e4+6dV3Hv3qs4PjnCu9/9Xly5ck2CJ1EZ5+cPsVg+QpJWGI0jjCaQyz1labf0kh2qTYO4yfD8cy/hPe/9EF75qo/g5Q9/La49+y7E+RxRQlE5tEqz9mp0Ju/qAk21xq5aYFcvgHoF1OQr84ztmsjJyvkSkAnOD5cLNLiHhZbYvxY95kgnn7Z4ktAk0K4VhW1DQISnuQAui+y+5G+v07uvFav73bvCS4Smd3zP7n3Jpm2qLcpig81qhTVvSwr4SywXC93W6w0Kics1yoLCMkVDn0MmqZoYF85JGaPOfd0lZ7pD7P26FZhDssWczCb6BZco10Ng3ZrTc4vS3fdkqNd1gqrKsFoCZ6c1FosdVity2FPk2QT5aKI1RKGXTPH1qsByucX5BfEtdInaukoSVi+MEUeZ5klZRliugcWqwemywtmywmJTYrmthKSI8wzJJEd2kCOdEYibcJOx5n9pjBErAWYp5rMUR5MYB6NYPOhMWBUTTdm4rtrFYshviHoJCBs6sOl4zVKM0gx5lCOPU4ySWFz6jI3ceJPIaPsbL3UaATmfp70u1i3T79noLtLvKLCmTN5RfGeDN78RPM/9bDxOMRqlcsAycededMd6sHqBVSDG2TYxGFqHfLz3SNQw6DiJqzD9upsDQrh07HBLNlgCJ+nhIyicah9MbJ2FSo3gNjfGd4PxZIQrV44xnY6xXq+wXC6xWW9QbivMxke4fuVpXDm5jqOjQxzMZxiPM+TjFJNJjtlk5LdcHOTpKMU4Z7WI3SxB0gnFCA7jirxnSw6JY59wX25QFIVunJ/khasJn7gXHCviZErbj5vS3MsCrNj4mme8SzDYfm1JU8PShOQsPytjpFGqm7Wq5PiyCR/3s0pJRGsiGNZQQHgwNUO2MxOsjrTZUVQG5ocjzA5ScZzTUYE0L1QtdHQS4fZzIxxcqfD7/vCfxF//W9+Hv5fjhRdewHd913fhKyV+7+/9vTg+PsZXUryd1+grcXx+rOM//A//Q3z4wx/GV9p5/LP/7D+Lf/Qf/Ufx4zl+PJzDZz/7WX1+/c2/+Td/TN93WPtDDDHEEEMM8XdYWG5F1SAtB6GZQpwa3lEkiNGIYxpEUvOxUVhg+TBLh/ml3MnM5lWT4MrGdRQ7TchSI6S2wZKLbRGbANKFZWXO4ZjoNONNpd78ws1TYhk8b47XFTGAnFZnb+5xcF0oC2XiFB4C+qLFfbR80I7L2Q8hDpxl2zJqictwoTkIy+0t6Qsr1hiLQoOcjsH/F5y0QcRsebkU6sxtGhzRPXOgOyN3anjHsn2Z/oKw7KIly615U3MwZ8VKbiWugGIoGsymIxweTnHl6lVcv34Lzz7zAl5+6b146tazKkevqxgP7p/hjTfu486de7h7967EzWvXjjGbTVBsSqwWa5ydPcD5+QPMZjPcfvppzOcHus5lWWC1vsBms0QUN0h5rHmkn3JeWoc41AWPMsX1a0/hmWeex+0XXsTtl17G4bWbJio7/sIpx45uqLCrSzUu3DUb7OoNds3W/k53X5910fKuA8/ZrlF/ru/hG/Y4x+H6t8ySjn8dbrqv5yrvlOHHOJk9+mZnZwRbQiM83BuQOSvFXp9zj/PNBO26LFBsN7rRHa6fmw02a5bAu4DrTmVzvbv72t2s3XG4qN4iXfCm474sKnfDsbvk3O/WAOe+ObZNVAtz2ioOuH8w0ZSoQd5qtcN6TfcyBUiKpyOk2Ugitbkvib8hi7fAalNIIDXnM9+L6J0MsfAMfGyM7ZbNP3dYbmos2QTUsTryw3L95ymSSYZknGoB7Uzp1ZhTvKVIOyGyYRRjTJ4tcQYBtKxKC7tVbFhYG+pHLlUl2QynwL0qjdiYkDcKiiYkMoUmV6yLyiGpRQGZnHHeWhd16551XI8jfMwVbXxvuV4pEOcUn01UNmZwYBvbJaO7tqjI3ba9gnsIX1ePD05dicvcc7v72mkS3Pe8xm0C0jEh7vC1YzWHMEVbzQGfd6xAMSyHybB5luLgYIZRnqEojNlebku57Uf5FIcHV3EwPxTvfTIx5zJvfLwqJHKKySlGvC8zp7bt66HBozvwfbIy2aWKC1twxtYm5kLNU4mHCXsuIUKNifbWEk+JEc5fJZIkKvO//WWsjyHt6d4gUnNkf28xbriNX/t5w93ZufBKbPW486EEp58gsySj7Qcc3vEkw3iSIs1rJfDCbXoAnFxLMZ03+Jaf8TTe9S4mXv7eEjaGGGKIvzfi1//6X6+qo9/4G3/jm373b/6b/yb+1J/6U39HjmuIIYYYYoghhhji773mfaHxkb7gOg9ZLl5qLybOUVzQF1+JUUQ/NliXhZiXLMeVULIzgqc1TerJaSrFNglUjFhxLsl85ZdkFw4p0NRBfDajqhzQLpJQeKZYttpU7ozL5Lqr6NBtW3KFZmw9hoEOPzgzg5uUv6MTzMVG52JaKb5zN4O4Hh7uTb347d7KoL1HU9tFyxsEBsymFAyj8Qb4QdcAKzTtCyZWijkmpCUJhfgElcqtXZgzRcHF8L6IaQp1i+4NXaYC3oRuxarGal1jOhnj6VsnmEwmePqZ25hMxrh6NMU0y7G4KPDZz72B9abA6RldoRt88fU3sFptcOf0HIvlChXHv7FmfzdvXsNknOPK4VyvM52MkOc51hePcHrxQI+fzqZyga6qFaqaCI0Kq+VGWYlZniFL6To8wcnJFTz38ntx6/ZtHF65jSQ7RJxMegJtcNRWPj9LcZTFVG728RddJ7wwPu3A7ImjwhR4M0Ur7Q90W+Ozdh3suvnbNY/rJQM0o23u7MlMgZEcrpE7FvdIHGFSBdGyva4+V/2a26rxRI2nCaq6MLfyeoNHDx9hvdniwYNTLFdrPHxwptv5+UIYEl5/c2o6vsG6B7YGaVsfRKdYFYDd707vAP3w5nvW5M6wFDYvLckUmhGG8+oQIDa2EhSJpyAOQQzeWg3sRLetI+QlUBYR6pwHY3tIaDq3rWpxounCLksK5TXrJuTEZmKBTungNi/JDi8M9bKhsF4ZbqNIdijqCsmO7Hfy4cl8Nr637VWpziGNNshKjjbdrDESMsiVSQPiOkLMc3Vxt93kHEVBLAyTazlfk5Zf7qFc40lqLu2t7a08ppIiK7nGdYOkMdeyZowSc96L0pnLoXFeSHCEmR4qIFLumcIzxKjJOvaGnda0k8kkPs45wRRgJchT5Ex9j7VX5PO4H4sxr3nA8zZsxy4Jjer8QFiJ4Z8HEm+JQ3JBmjxwurn125BAdDSPEgvlDjH5/IRpNw3GY44bx4PCcokkomt7hNl4jpPDY13fs/NzicAUpNVYb56gGpn7ncfM691syPRuEFXcs0Ozzg5NIoyN77sUkLXnps6xFi87bDkO/vC9X9fa0cjWPNOd+o03x9Q+zqoevphSBy4a8+OXT3IgtCoPLu1XLTbGji3Rh4HjTHwthvHlfOHE8NaRbVKH7uqDeY7xKEXcJNiOGpwczXByNFelwGhkOKMhhhjinY1+BdsQP7pg9Qj/LflW4vf8nt+DX/krf6V+/lv/1r+197v5fK7b2xmsVszYQXWIIYYYYoghhhji75F48n/duuOuFWC9XNoabNFRZ7xTlmdLShDLlI7ALdbeZI7/qO67Ivvs4cCVtUZVRGkAm6LBalvJUbguTQja1myI1JC8AcOn8jVZgswv3uZovliWuFjSvVhjvaWoQCGIDfIoEOsrd9fZKYRUATb1csXGnaVs/pVELJ82kal18dG5J8YrS5KJtjAXMkeIzFY1JGxNjF72H7Q0Nu1z52Zgv7KBl3i1zlE1oTzwf02EYrO8LM8RZzmilF0Ro1ZUDrfWxdzyeINVjnqyidp2Lc2JzXMkb/diuUaaZnjxuWfx/pdfxjd+9UfwE7/6a/DM9WcwHx3h/NEGP/Sxz+F7v++H8Wf/wl/En/8rfxU/8MlP4+Nf/CK+eO8R3jhd4I2Hj/DavTso6gLPP/cUXnj+ady6eoxrRwc4nk9wMBlhs17iC1/4As7OznBwdICjkyOMZiMko1RO0/PzFXZVjcNRjhvHx3jlJSIw3of3fPXX4pWv/gk4ufkC0vE1xOlsX1imC17YCZaNb82lXG8BCoW8rmZb3xeYfQ7LHdi1geyVmFtSAY+5yaEZdzdLFASdOQj8l32LrXexdTMHu2+g1BoXOjid+3PTcRd0RYpbYbeof15yY7NUv0RdbnBxcaFxpqP8jdfv4s6d+7jzxgO7vf4Ajx6eY0M3syMAhE9wFnBgi4uFHnnTtcacm4HDHNar3O9iO/v8knPShWW6Pcl45o24GOEOOnyCuZY7gVSN4IiRqdmks8K22Al/U2wj3aqSLn8KdCmaOEUTJViVNc7XWyw3JbZbYlZqrNclVqsCZxcrnJ4tsThbYHW6wGaxQbWpUXJvWFVYrSs1BN0WDbZViYIJiaTBhE7P6QjZZIJsPEUymSOezpFPpxhPx5jMxpgdjDGZj5CME8REYmQUQImtsJt48ryxMRybgspRm2CSJ0Iz0O0s9/Mkw2icmVtVjed22FLAJf+YvPgdd68GWVQjSxrdbB+zaoOCDf/Y7E83/zNF4Bq6UawVIoHNOoWcCMJyI776puBYlUoerZclym2jnEwWZ8iTvHV7a08ua2/w6K+j6+sIoGBdt+2/TR6oAoWND5MYE2I7JNYrZWdrj2J3tEPFdnYUutXXNcJEY5NgMokxGnF6F6i2Gwnf42yCw9khrl+9juPDI2QjokBiTIi+GI9x5fAQt65dw/WTY1w5OsDBbGqoEFWKeJVNcNt7RQD/Myay4S1qzgXlUFMhjDgMUWKfgb6hQpkE523zFvZsE5WtEaZEbcH4uWdTTDZvOpuN8haxaSS7PipbEJJhXZIwNEs0R7fP/x3RFymB+4iRItklehoTK6ESgWuPiaU/9Ec/ju/67X8Dv/v3/S187ovn+HN/5fP469/7Oq5fPcLNGwf4vb//Y/jkZy/aj8Lv/M7vxFd91VepyuTZZ5/Fv/wv/8tYLBZvKiX+7u/+brzvfe+TKPMP/oP/IF5//fW2tPy/++/+O/zRP/pH28/3//f//X+fqCz6D/yBP4Cf+BN/IsbjMT74wQ/iz/7ZP9s+hufzS37JL8G73vUuJT/f85734D//z//zxzqlf9Nv+k146qmncPXqVXz7t3+7hJ4QrK75R/6Rf0Svwdf6/b//97/peL7cGDxJhGP5j/6j/wjXr1/H4eEh/qV/6V+SMBbi//6//2980zd9k8aTx/oP/8P/MD71qU+9aVz+t//tf8NP/ak/FdPpFF/91V+Nv/SX/tLee/GaPPfcc/r9P/aP/WN48ODB3u/5mt/6rd+Kmzdv6np93dd9Hf7kn/yTb+l8vhKv0R/5I38Er7zyis7lH/gH/gH926QfnMNf8zVfo9+/+OKLupb8LHura+WP/bE/hve///3qNfH5z3/+iY6NlUb/9r/9b+t1+byXX34Zv/t3/+729xz3n/ATfoJ+x3H8Vb/qV+0dG5//y37ZL8ONGzd0/Jxnf+2v/bX291yTvJ50637t136t5g6v68c+9rG3NIa/83f+Th0jn//zft7P0785+vHf/Df/jfYJHsN73/te/Ff/1X+193ue47vf/W49n2P87//7//7eXAioCr4O5wJf560Ex2m9XuM//o//Y5yfn+Mv/sW/+JaQHo9DsPDxfF4IjuNv/+2/HT/7Z/9szYVf+2t/re7nfS+99JKEcK6F3/f7ft/e64Tnfcu3fIvmOs//D/2hP7T3mL/1t/4WftpP+2n6PdfKt33bt+3NMa65f+Pf+DfaPYQC+uVE4ZfbZ55k3f/BP/gH8ff9fX+fjoP7x8c//nHNJ84d7ik8h3v37r2la9+Pt2vvePToEX7hL/yFODk50ZzicX3iE5/4Ec/xR1rnj0N2nJ6evulz7Qd+4Ac0rtznDw4ONFZhjDlOf//f//fj2rVrODo6wk/+yT8Zf+Nv/I0nGv8hhhhiiCGG+PESb8k20fUvM6GC/xAwprCVTtO5LPcyvwBLIGgk8hRyCELCRNucq8cn7retsy/l1kiMTkq6EENTMYnH/oxQVlw34UahxBom1byfpegufpAfyi/3El71Zb1zmLbI5Mt8C7cOBzd0OL6uoZy5zsTidKFCjZz0DyR3DYfGf17ebC/bNX/b9XEDZnXrYRRcvHMWZx+xYUgPcy5TDCZK5PJJOClXFziJiCMhe9nPt0UXUKCq5WQ7nE9xcnSIG9ev4/joGMW2Eq/2zt0H+OJrd/Dq63fw+ht38PDRI5Wm0zko8Y/CksbcmgCu2ZRvtcLZ+Zlu/AfwYnGBB/cf4s6du3j08BSLi5W4uPwHOP8BNp8f6jaZsCHXGEdHJ7hx8yncvP0Mbj/7PG7efhaT2RHSfIoo7rtAbMyCyGol472fvYL0/eFpu1B2pe9tb7xuTrbokz7feC/6jffai7nvOm7nj4E62t+/idN8yUXff722NZqjW3qP7eamCXvCYVRdaT7ZsfwHMrncxbYwLEZJXmyl9WBuTDu39ogeZ17slexLNJb43aEugqOyW1td8qg9yW7Cu5u7h2NwJEJ3660F4QZMdKT7dLu1mzXyNKyAZryjNviadGgvKTYvt1gstxKaOacp1oYEmGuBwufob2TU0jyqfcrOLbiqVWmgZFDH6mVDSQmVYS9RMzeuS/7e+M/hGum/gFjw59Nda1gMYyYrYWUd+fY4vcorkXOshJY14ctyu5F/TIHTGs11+1uYYoHBzlvdv2n/tn1WfHjfYyU8lzVK3iqKxgFj0s1/26d7jR6Dibeb8Z4Y8wahgVEeHuuCcocX8gnSQ9PYYx0PQTROTLb+FiuhXArNae6B49FEt1GWq1KC46fGrRy5hp8dZIlzvhd+jt55UAm9rlljt/7dTdzHmlvXvxZrE9bh3kLp7S0B72INAMOK3W9c2ds92saxNle898DeGvKqgJZZ7oilVhDvqgqskSrXhj3mj3/35/H5z1/gF/389+EX/4IP4NU3LnD3wcreS65x63BrTS4tOIf/i//iv9AXZQrEf/pP/2mJFf1YrVb6ck+h5M/9uT8nsYwl5Qz+pMAUxGbeKFg9SdBN+Ct+xa/A937v9+Ibv/EbJS4GkZTn88wzz+B/+V/+F/zgD/4g/oP/4D/Ar/7Vv1qCRz/+zJ/5M/pCz588fop7vPXFCQqI/D1FHIodFDL78SRj8CRB0e6HfuiHJED8T//T/ySBmKJFCLLDKQr99b/+1/VYvi+F4Xaeevy7/+6/q3GluEER7p/6p/6pVvj4K3/lr0iQ+Vf+lX9Fv6cA/Z/+p//p3vP5OfwP/UP/kN6DY8trw7F9UpHzK/EacQ5TAPzv//v/Hn/hL/wFiUU//+f//Pb33/M93yOB6l/71/41nQsFVB5jEA2f9Bj4Pt/xHd8hkY2Po9D7JMH35pzh63MO8f2Ds/bVV1/V9aTA933f930SJyk69687j+N//V//Vx0XRSwK0xTPHz58+Ka59Zt/82/WHOTn1i/+xb/4icfwk5/8pK7t//6//+8SLzknKK6HYEKAc4BjxnP4db/u10k45jGF4L8BOa4cYwqJ//V//V/jt/yW3/Km9+G5cP28VSYvx4XrhQ5i/uyL829nUGjm2qUQzDH8w3/4D2vucK18//d/P/7Ff/FfxD/3z/1zmtP94Hj83J/7c3Udf8Ev+AWagxyrsD/wmlEkpTjJdcWEENd6CF47jh/d2H/+z/95XV++dz+edJ/5keLX/Jpfg3/v3/v3NJc4T/7pf/qf1hzjNeNa4TXitX4r174fb+fewfNkMocJOH4mcq30xed+PMk6/3LB9fjN3/zNSvJwD/j//r//T3Mg7NE0efyiX/SLdH3+8l/+y0pm8Zh4/5cK/ludiZD+bYghhhhiiCH+bo5o94Q1sLMb5hJQNTAdaBIld5iPclyfz8ULTZFKYNisWbJco0oa3eg6OzlhY7wER0djREmM1+9tcLYoUbBUuWowGsU4OrTGSizpJfLi9HwtZmqej5CP2IDLSqn5LZsf2BK23Nyl9yvIvzSxiO6/g9lIZdHktBY0rjougMIPWZwq987oUAWm01RuQjJE2bCK//Ay4SjGwXQsAUjiDr+EZ2R45oalCM5huVApaFHIMP4xxXeKR0eHMzmas3HuYlNwdatA2htlmeXVyqYNRyAOJ9medYE4YuOukdxz+WQqbmyQJtarBc5O7wsBEDVbCQ/37i2wXG3xeYrCd+6iRoyKLk95Wk18q6sVml2Jp65dxa2rV/DSCy/im77xJ2ksP/GJzwqV8LFPfhz3HzzAdr3Cdr1Wg0Oyl+mVPWPJft1gsSmwrSpkuxrprsbJwRTP37omTvPTN69L8Hn0aC0H6WqxwGpxIdTG137DR3UKn3/jdaw2G2TNGvGuxDO3nsXzT7+Ak6u38K73fhjj2SGObz0v52jLGJB65uIxm1s5AkMl/RUb9RVWYr5j8ytv6idosz+2FZ4DQoDXi5PkssM4CEJBuHW3viMfWpqGO81bIVJW6MwaW7KGXtxmIkysiRyQG4+W11HNL8NrXWIw95Eb7SEFMX2HHdcBr2VRar6RobzdrnF+tsTdNx5icbHEZz73BSwWK7z62j0lC+4/PMOD03MlAwoJjawMoHhrwqIQLi6UBlFep+Ps6cA2Dmqc1htNmzGbxNGFSb4wmcUhOVR3c77nDSfuQhgMFz4tUWO/D2gYstMpFpKVOx4lmE4TXL9Bpm6kJnoUVpfrc2yLjY3/rsF6XeD+gwudy26XC51xlGeYpmwwGqGJY2yqCg/WS63vF25dwXwyEos2HyXCINR1Yc7YiI7d2KoEiNQQr5vN3BpsyGjebPGFO6cS8sZVIlRBPsqQ5Sk2VYllsZVgPJtYk73bBzPM8lzriF906RZfbyl81jhbcB3VeO1iiQU52BS4a2MT5xSwWemQe2M8rmF+YdlssdgW7v51UTRmdYa5VzkGsyldvJn/js7jBherraGKNkw60B0daUrlaYTpmE32EsxmY11TJucokq7ITSdKhFUj21KvN8oyVW4cjScYZ6khgBpzmxObw2BjVeIrwrzZNBU2ngSka7pt8BcDN4+mOJrmuHJ1huvXZ3JEr1dr7X9Hk+uY5DO88vJ78NStp7Apl1gVF7j/6B5+4ON/SwIy3dCcZSXnc9XgfLHB6flKnzHnhbvld3QNC+aiz6vU93xNZ0ddJCmTLhCnmXPw8CDGyUmC+TTH7VtHYjaTW83xLOh0ryjGV9hszJnPSgAK2Rutq0ZYoPEoU0NBVm/EdDlH3JOIYglrl+PKz1VLTlihiiUKNltjjwvJEiVqMkt3NreONLbXsWauxsrmNSXi5Td859/Ez/s5r+BD77+q7e7RwxK/5Xf9DXzdh5/Gz/vWr0WaN/g13/En8DN+6ov4A3/o+x/72U9hj07b+/fv6+/80k2BhEIC3XgMCn90BL7xxhvtl3sKdXSEPknQFUan2m/4Db9BDkYGP+N537/6r/6rX1IwpMDC9wwuP74vRVwKDxpDQCI35zGdtnTZ0Qn3V//qX5Uox/jhH/5hueooZv3yX/7Ln2gMniR4LBTcKJDSOcf4Hb/jd0iYpavzcTgEvj7dzRSn6AYO40JRkuIxg+LHBz7wAYk1dAFS4OHr/Z//5//Zvg7FKQp9vAZfKvj6PKe+SPX3yjUKc5giz9d//dfvvQeFejqBf8bP+Bn46T/9p+Pf+Xf+nfZ5/8P/8D/oPF977bUnOobwPhRD6TR/0ghj8P/8P/+PjuNyUAym0Mo5EJLhXIO8LpwLdOhSjOT7c34wKKzRfcvx4xzkNWASgkIlz5Pxx//4H8fP+lk/S8//cs5gCqkUsj/3uc/h6aef1n2cc3w+hbZbt25JzP5P/pP/RIJuCD6H73PZORyCCSvOA4qD4X0oSvI1uTbeSlCM43FQYOT48zrQScpkVxDp+frcp4JgfXnvCmPWn3d0LNM5G1zLvAb8fV8Q/0k/6Sdpnf6u3/W72vs4zynyhrXK53G+MDEQ4hu+4RvknuX1pMjOa8o9hEYMBseOyRzOQVYg3L59G//6v/6vt4iPsCY/+tGPfsn99/I+8yPF4/YgXh9eU4rUdFMzuC9wvnEdMb7ctQ+vy2TEl3KMv9W9g85kJt6YKArJTCa96KinCP1P/BP/hI6R1yrsjV9unT/uOPlcri+K2z/lp/wUCeB8f7r9nwSBwu+HdI//j//j/yiX8+OCc6ufhAzx7C//g4hH9nkyxBBDDDHEEI+Lz/6Gn4W3K/hvKVbb8N+XrMp5WxzLwUlofzY3WutUdJ/WnrGToqga2JFjagKHuY73CRT7jFtzLVJYDs2cVP4b2JVtCf1lB547xFq+sMlS/J1Kvlu3cud4bhvp9RryibvZdzG7u3TPadZ3kgmZEBxu3ckHZ7OQCnqMv17Phdx3f7av33dxBr5uS0LuuaTdmZkmdDzSpefwz94xBMSD0B1+s4Zf/jOO5FSeT3KcHM5x/eoJjg4PkWcmeF4sVnh0do6z8wXOLhZYrjdyurK8OzTgsuMN7khr2sjxpli22qzlLlytV3LsXFwscHZ+ISZzsTUnpJovpinmswMcHR7hULdjHB1fwfGV6zi6cg0HR1cxOzhGko68jDzIkh1uIriUhYbQrSXPdlrtm9zGfn37uOTe/A1wjMe72juH/Z4Dd7+LXc+FHK7t49zAPZtu62+8tPD23Mz7v7PXDG52E/XoUg7NyEwQD/gNm296D3fgBoRLvxnY5dhv2HepmWHPidm/7R1+/xjDc9qZ86WGo8PutO5XrmcmogoTzWp3LFNoo3tfTmE2xozJP7Z1v91W4i9vC0v4yIFL1/9up4Z5wttQyHMkDVnR3H+CY9duYf8KeBJb70JXCIVj7tsO2W1rNbhy7bTcdSqWvA1qeJ6tSa5Ra7YZGvJpiYXMk7+8XpO/T7l2KEBaYzwriOg7wLtppb1SDRG7m/Zh31v7VSRqpKcGiuZ65c2axnXv329f2l8S/camodLCKkvMuRz2uH3Hcve88Oz+/qjneCXIZrPBYrXEYrnQfkJBlmJ1llC0HWM8Gvl4u+Pb94eAZbk8NwOguJcj6V2//aWg5/le3XKkw4RtSeNh4Pc/L9v9orfO20/NvUTW/qC+uYhmv1ogjFtvtXneySqKHjzcaB08fXvefrZMJimunlAw4uczH+zJHgGkLYLQRLobyHQAAQAASURBVLGIbsJ/5p/5Z/TlnPt4CAqlQVRmsCz5sqP0RxN0wIbgZwPLrIN7j/Hbfttvk2BCUYTCEEWby65bijlBdLh8bHwtvi5fIwTFWX7J78eTjMGTBMWsICqH86N7OCAXKIRQeGH5Nf+xRhGLcfmcPvShD+2dD6N/TkEcfdw4MviedDxTOOW5cuz4vB+NY/kr5RrxPYJw3X+PcC50kDJZEhi8vP0L/8K/IFEyvMeTHAMxCP3r9yRBkZPjw5L5xwWPkdehq7AyIZPX+Ytf/KKENwrJvC8EBS8K5v1r9eXm1pcL4leCqMzgMXF/osBGAZXHQTGyP4YUF/sYhv/5f/6fdZwUgPl7umIvz5fnn3/+LYvKDDq+uU8FUZ/CIF+L7/l2B9dBPzjO/fFn8O+Xx//yWuXfw2P4k8ceROXwGmGM+SWP87G//sOa7MeT7jM/UvTnCQVtBjEw/fvCvHnSa3853q69oz8exGUwSXN53EM8yTr/chESFl9KVL5z545ek05lfjnnNeBa/ZHGn0I3r2+4Xcb0DDHEEEMMMcSP2+Z9chTyp1WFtwJxYHVa4y86hq1umzzRHZsmVYVETbrnJPjQRagmSfa6PfCAvuTyyzOFIwpAldjIidy2RFuILuGiQ6GSdZZts7y9ExV3UWrcyihGUe4QleSnWpOqJI3F4WQp+mgykngTxWymQhezST+8j+5k+we7lVSbaM3GXZn+QSMRKw0OQD8BU4sk8vLfFsF1KeHOxWUz2QYR0Jynu10l1ySdq2J89tDPKgOPUxdyyWSdSgQc52OJyjH5p3GGckv3dmKN3PhaEY8hlvtwkmc4mExRNXSnGh20JhojifHs08/i6GiK9778Hrzy4isaowcPz3HvwSN87w99HPcfnuJicaYGb0eTEQ6P55hPx7h+fIiirrG6fy5HtfcH1Psii1AiwgWF6IrnRsEsxsNHW6xWFZHYauQo/u2jB5gfHOCVF17EaDwxLm0MXLvxDG4+9TxG00PMrj2DmGOd5iYTe7M6a8zHG4UvspQbROSTuic7NBx7E3qh5aMGQcjmbhRbabmrf3ZJ/Tkh+s3GWmHIHZdOtZbAbuKW2s65AG6TxPjWrTU5TBx/jPNdXHS04nmfXKEzWzilvnglZ3wtNAkFZTqWl8uVRLiqLuUazjJzOBKdkOax+MCHEVCRg855QadltdQ6cpKFn5+t2Y45vS8uh78L/+Kl/3Sq8n5rMrczDEdZ9ZjlnStZCJ3e2BpixxY5qyGk5ZEFLmEyQdnEiIoIZ+c8J24zsaog8nyGCd1VTYWoKcR8zvM16qbCxcrcxXVeYZ1wXSTIR6ncv7MRG0qmmI3HmI6ZuGDJACsudjhfbHU9knql9TebzVWpkKU1UrlIiWhIkKcpZpMRyqRGvbDEy25n1QQUnfl7R6nLLU2nbsQmf3WCtLbkTOpJqinxFglwVGdyZC/ImOcF4j5LrrzvBzpvNmLjfqoFZck7JmwMcGxNJ9UEsNphHXEvrTTuFKMNj2Gu51aWFDCf141cZtvTd6vCRGwX6+Ug596ozwP76BCPu21zSg9wWAvdnm0NLGkN5l5gDG+6gGMdg82DhHuwelxacoRVE8QoidNNrFFZ4c79L6Iudlgu1zg5+iJu3ryO28/cwiib4tmnnsNmu8bdO1/Aer0VPoU2+ijdIR0lqMsIMZ3zws1b00wlIsjHF6/fVhxFcENRhGSgibjWMNKarPbXQKvcezPNNpnja1UIJG/UKsE3JMKUkDX2OyslKORzzLgvW/IwbBiegHQ+dtC/m5TNJO36tUkYTz5QdKCoVPNzkWuTjnTfQvTZ5OfW1PzM4JXjdbFmWHRn0UH1S3/pL1Up8JUrV1S+S4GAjvAgkl7+Ai0X/TvcAJCOMIqjLP2m+EIh7zf+xt8oh2k/Hndsb6Xk+0nH4O0IOg8pdNGZSOchj5MOwj6H+fI5hX33rZwTx43uV7pB6SQkw/Qf/8f/8Te9z99ufCVdIwo/dAz+nJ/zc970O7p5n/QYONb9z84nCT7nxyr+dufWl4rAAebcvpz4CMIgncTEP3CciXyg6MY5xPnTj76w+laC2AviRyg2huC5ERsR3LdfLgJeqx+Pwyr8aI/xxyKedJ95q/Pk8n1h3jzJtf87tXe81XUeKkv6c+Dy9f9y65UYDCaciA3hdSAyg+f4I40/H8PbEEMMMcQQQ3zFCcsSK0IzuMv0V37xjexLc0QxQV2OvCmXO5X1XdsdbNT8+v9Ms+/PHV+T5d+hJJ/fytVuzykDwc1rzYz4WHMjeysm1+VMwHLNr3U+q0GeO+Uo+EpYlohA96z7nFV+fFn465xuweksNm34ht43mzln1RzDhuQIZsbW8R34re6oM9HZvvkHZ2MruMv5aMJy35UpZybF5ZiO5eDkdYe0HGgmZlHUpXjC8wyc1x1Lv7MEV08Oce3qEZ66eQNPP3ULDx4tcP/BXZxdLHH/lLiEUxMnXbAej0eYjEcS0pKSAlssQSk4zW0QE4nwLMfmEa03Gz2OZfEsC092MVKKhGwCtt2gmUxwOD/A/OBIDmqWpJ9cu4XDkxtIRzOk45k32upE1daFq2ZXFGmch0IxKQjPOte+E3jfFahRbi+zZ016GOUgO0sQuuxyDK7D1k3M9wxOdf/Zvl8PTHspmRLWVXtnyJz480zqCg/rLbzea7ROd3dm1j2R2RqqWUl/cM/LJasES2Zrw/Q3zefQ0NEEqu69+6Jy360afhec+BpJsl8lLLs07QzmMLgaUTUy6wzS5hZ1p2nr3O/Gsu9+5ZqnW1lJrdKdvZo3iRjJPN8srYSMYTNC8tWJ4thK6LXXN/JIgjzOtDa4lrhOA4+YgixFzR3FvtId+ukImaaViZWEAdm6pBhMZZJ7oImWYeSsYoB7RVjbFOBNtBV72M9V8qDEVpuEozRGjQTUiVmMwOOwfa+zslKA5I3rhTcePJdfOyXDuLnLe8dk0s6Opav8aHdMu65aZ7ZXcu3QrWx7YSdA09ndUBzWm+1P5tb0G+ZCmEVy1zor3h3IPFUK7UFcVjPCdlLbXJa7XHPCGPpktxebEg8ePBTqYjqb6Px4refTAyE57utLYOAU+9wnQkTO9h5SJlSViJFv49zsJXd6VSsa0J6o3Hcr+54UPa56wSfx/noPa7nXxLNt9BlqU3qD2N9rLlUJdFvI/t4WWPxHh4Z7evW1Ba5eGemVmUR68GiDF5/nMFsy107IvuyTDcmx55f78IX6MuPySYIuTe5FbzWIJiCrksGkFI8noBpCeXOf4fqkDaj6rtTwusGtSudfHxnxdo1BcMMRKxCEB54fHXEszabQwPem+EK3G4PC5FuNgG/oB9+nHxw7lpGTqxrEFAqjP5r4SrlGfA/iFuji7b8Hx5NBHAHvoxD/uHg758nloBOUr83Gc49DYfAYicLoJ3g59hTjyKmlS5NrkPdRyApiGDm9Xwol8qMJOi6JC6BYGeYGx4IO0YBp+PSnPy3x+HFBJAKPj2iPEERrvB1BzAOvL9EJFP1DkEFMfAGRDZxrXy7ong2NSUNJ6Gc+85kv+zxeI44/RcUQ/DubOPaDY0bGb//vH/nIR9rXILqBDuAgXPM1whhTiKdjl+v/8prk/GW8XfvMW4knufaX4+3YOzhePH+ORx+FwfO/PO4hvtw6D055zoFwXS5zvunmJmqDa+xxrmWeG9Em5Coz6D5+K1ilIYYYYoghhviKE5YV7bdz9+TqS78JSmtxjylDUFKg8BlhPh2JsTwZWcMqiiQ0JcrpGcqxeR872m9NoKxKlmJLAXXhysQ9E6fNFYZii6huMEmBZBQhaiiUjV1MoSBBIYE8ylAaTTGEja/IUc0xnTATHaHYLr1Jl4vGvfJsuZNjuh9za7JFl5hzStUwT+Ph48PzbrUEimsm2/Sz3eYI9aZn+kJgYooEJz7Wy+ztCXbU5lDO9d7T8UwCcyhHl24qTZRCcW5l7Sj02nQUb8oS26pAVRZIsxQn87HEmGefv435bIbnXnjeGvUVDT7+sc/h1dfv4ft+6BM4PV/g0fm5GLE8siDC8lhn0wmeuX1bjfoutjucLda4KB5huSl01esoQtnssCnJ2AXm0ksi7AQD3aEpKfaVWG9WeMh/WPG6lg0ypDg8uIHZwQGmR9cwml+RIzskKMRIZkai7a5WmFNZf97uC8oUmDk4fRREYCO3DGR3L7tY0wrBZiveK+kPDRv1dzpSwwUKDwiNGrtVEbIgb1KC28Z7cc9RTcHQtVdTdXtrTGLivqjcVw5NRusUWArKZVGq8QfdyxT26V5ebzdCkyxWKyzXJZYbNl/kdQoN3Hw9XqJuBHFYhlYlKHoNLNtukCYS8g8UTOkwZiLCkjicrwIom/wW8jBiMtvTR3TqpqmOg9gKk+l83VDo5XynW5MsWWIvau4wO1ws2YyQYqmLrRFFYK6ZGDNybJMMkzGTDZ5oAt24QLPZISbiImYTOPKlC38NsxZvih1Wa6AqahQLK4U8u6h0jCdHE8znI4zySFz2bMc1MUJZpkjUsJQJIAq03vTO1ygbmXIKblJL1JCZnDtOJiQ2yO2lo/i42WGSNEjqEkldSWBeEm2ya1AWbMq2w2TEagg6bemKzhDtEjQ1nVVM7GU+3W2vEdajbqxBm18nuroDqTloxEELlSC849wwZn0qbI0lt8IeSQSP1g6PS67bGnXMhpCOSdK6caHa+eTkbmtuSMyMvVLEFkaqRJw5l/mzrkorQyUGI2XVCjCeTZQQWW3X2GwKHJ0cY7XcyL19cngFm2KMN0ZjbLcr0r01R+jSn+xyREktjnis47M9IE0dKUL8CIeGR8XPHQn3hiJJ2s8CPt4qVtr93Ka1udNb9ImNZwv3aPEbYdz52cYqGT6ZfHRHW0hc9soGCe0m5tdhW/LkjhI1Sh7ZOjeEu32WBGE6oJK4rj70gSv4E3/m85jPUhzMR/hTf/YL7WfidrtBmpkjSrq2czH55fi//C//S7nc+IWYXOC3Giy1/u7v/m59YafIRQHkSdiTLIVmuXDg6T569KhtJsb72WiNr0vmJRsHUijjn580KMawcR0baZFrys9VCm19x9nbNQYMOtLojGR5P4VcNsCiCMvrQ0Ynx4bl3hSHKNL9ql/1q97ye/yyX/bLVB5PN/K3fuu3anzIuu0Hx46Nz3g+vP5sovWjdfl9pVwjzkeyodkcj+/B60K+bRCa2TyMjmTiHuju5jVjooCN2FjS/3bOk8etHwqSHFceH3EIFFxZ8k+mLMW37/qu79Lx87i5zji32KCNx0kRkk5qcncpqvIc/rP/7D/TnvqkTt0nCTo6eZycexRcORd5fMRaMOgE5X1c/7ym/HcJxV7OGR4r5wvnPd2qTCKQPXy58dyPNuhW5rUMgms/+F78PR2xXy7IEKa4y2tMVArnxZdy3faDY8+xoBjJ5AB561yDxKf0g83qiK74pm/6JjW8I1s8NBikKMvryjEmc/fevXu65kSuBBwFm86Rb8yxpFD+nd/5nXtJmLdrn3mr8eWu/eV4O/YOvgb3QGIn2ISPiRaeK3EtvP9x8eXWOfcd7gscYx4L1yD3835wDXIfINueCAueMxMEnH/cz3hcPB9eZ64Tzo0fy6qEIYYYYoghhvixiCdmLCuC8qAw8Sc4lonEWBaVmk5RYC5Y8p1EmIxSOVHZfCtLgyOZgkfnrOrK+cmHreWilBm1ZT/aF2pr0NZIzABFz6pAnjSY5vY+k3yEkQSPzgFmLjx3W0pY5nGkKjHKc+IsurL+0JxMJfrkEktMZvl8qi8hFKz4Gi2fWW63/fvkKKYbWsgMc0WHoZP4UrExWK0x6PjUwZXZH1sXDiMKBDlG+UgcUf40gc/VIIl+JpjT0ayx2gElG2xxTPV+pQSb+STD1aM5PvDyu/DV73s33v/Ke/Dul96DNB7hi1+4g09/5vP4/h/8YXz8U5/GxWqJoq68NLybAhy361ev6nY8n+NwOlXDLKos5k41F+y2IgaFbnK6NE14VOM0ikYNy9w3uDg/x5JdkaudGj/OpuQr38JsfhXZ+BBJPvFraNK7xGXiLtikr71RYGbTvbJr1uf85cBh9slmblNxqvs4CncEfklLcWhcF0TR4Dbsntsqsns+/MeVhV9yUAfVEY851q4eoHte347aedT33JKcVxSdqqJEsS1025al+Nibgq7xrd8Kic4Um/l3S8YEUbk7F8nvWnthLLq5b4xXE5hNZA7O4o4Rbo/nfO0e3/dyC4uQxUo+jUdcP45XcE4xpU9iG8RMJiKBZf415zewXldYrSmiV7Z3uLPXmgiyGZ+tca5fHh/3Abp3N0WDLW8b8pfJXmbztUprhg3QKMJuC7rtG5xdbPDodIUHDxa4d/8c5xcF1mvym+l2TZSwGecZJqMMo0mGnM0L01gO4dbJKjHRHLh08rPRZammnBRimYAx4Z7OYzb4m+cZDvMMc2I2KBwL0WGYkJoN+UqbIyb20zGdCo2TJmykZy5s465zL+Cx0rXt65F7q6avo0nCNWznkfPpuYfweFlZID41mwlS3DaRU00VOQfaCgxzy3ec+54YqlSPCczhP8nXOv6eiCtnPcVl7iFshLcRhofXg8fD5qejyRjbcovT81OxlrdMRDTAfHaIg9khRhkbtrKSwwXiNMZonAqZwlwg34PYC97Ep+Z7hz/Hl36yf55jjIQECePa5pS88qQnLHcpoceSqP1zwBq7BkFZqKXmsoe8/7HbT+Siq0wgA5t7tBjS/f3KPssYP+2bb+Hpp2b4fX/w4/g9v/8H8cLzB7hxfaLzq6qtNcHt5Y0pYFGY+I7v+A6VSlPo+PW//tfjrQa/2PMLNb9I0/FF0e1Jgl/eeeNx0FX3x/7YH8O1a9f0OwqNLFf+J//Jf1Ll1XSi9d1tTxr/7X/738pNR34tX+/bvu3bcOPGjfb3TzoGHGcKTj9SkL9LUYHiFo/7Z//sn902/AoNp+gu5PuwAdeTCF2Xg6IH3Ygss+ax/4k/8SfeJHzwfCgw0cVHgYzYgeBoDMHjCuzVr5Rr9CMFURVsjMbmdhTm6STvs3c5Rv/H//F/aDwpRHKcKaQHB/DfzjHQRcv58yO5ximqU+ji+FEw5Jqic5VBoYxN0ChC8jjYAC4kMELwGv3cn/tzJULyWrPZJkU7zoO3Kyiu8/rQifkzf+bPlHOTzswQ//w//8+r6RuvJ13YvJ5cM0Es5HrgvKcwR/4xHcxMejxJ8HW+FGKECR02YOP5Py54P0XMxyEtLgeFQh43xUc2JmTTvj5f/ksFH8c1SdGdXGAKnRwHuqUvC7DcBzh2PCZyoYO7lnOU14wua85BzgfuKb/1t/7W9vm/4lf8Cl1jis8BIREqE97Ofeatxpe79pfj7dw7yGnm9eJ48HOaa+VLJTa/3DpnEJ3Cz0q+LpNcFJz7QeH+T//pP61KEJ4nH8c9ObwnEwUU1LkOea0ouPf3syGGGGKIIYb4Soho94RgxOmJMRhDcywzKu4wzjIcT1mSvMPFaivX4ShNJW4cTjIcTTKk2Q75hI7DCAnZy1GE00UlBxmbam2KGlkSYzpO9aWdgg4FkGLXoNoZG1noCvJi2TyOTrRdJafZU09dw/HRAc7PSzx6VGC93eLB2bm+KOdkEscJtoU1i5tOc8xmY0wmI1y7cqxzuLg4RVUVmI7N4UXxe5SbKExnM3+Sv8oyawlkFHIyCkijVmAz9K8JvRS/5EYOUAvphwEL0AnI/GnMZ2MwZ7m5QVWGTjfnjgzNRM7q44MDueXIWCZXk6gDitQgFzNKsVhd4MGje6jqLbbbM4mIr712H4uLJRYX51gtFjg5PsazT99Wg75XXnoF4/FUTkgKdB/75Gfxyc98AXcfPsInP/9FCV/ris7KnZivFEzIWObtqWtX8cGX36VzeO3eKZbrLT7+xTdw/+xC9j46VclRHaV0YMa4fjSTCLVcUOiskNQ7uQZvXT/BV33wRVy/fgM/6Zt+Jq5dfwpXbj2D6eEJktEE6XhqgocEOiJWyCLzhIIa07ljmSiMumjHO5Svm9HXGtip5JyJjmC41yPoFmT2gj8pRrtXM7gL5ZKnkGb3t83ExNDtYzCc7x2aNQZHNHM2RJTIBUuwKeeOcbENSM7ra+5kIxyEpov2OK2uvlO6ZWK4Js0/E/FSFKoC2KyX+pL06P5DnD46w/n5Anfv3Md6s8WdBw8lIn/h9Yc4O1/hbLHB+XKjxnTr0sRCssxVfeDoGhO0eV6ePGHywpv8td/lnCUdGgfaWTgygWzhiE30mFzw5nvhBIjLoNuJySayjmcjzA/GcqESxyIO+46r3J3cfN2Er2l8vpzJlYjrlHMNuHISYzZlZUGNUW4C6mrD5EWFL7x2itWqFNKCCQzNEnKwmTjKbd09fftIe0NgCZ+ebXD3wRLVtsDmghUNxlTmf1eOZziYT3B4OMLVq2Ro8tzoIm6wWVKoY4KM4r4fuJJHFEbpogcmY3PKnkxHOByRBW/Zvb6YKP5xE+F8VWCxKXGxKXH3fCPURk0hlK7mkzFGoxTrbS2hnMIxhXabHon2mrOLQntfSaxEY6zt0ThzRIdVXJTkT3MeCSnUa6zIa1oxUWPceUrISl7pulKY5XXneTKhs8MszeXCDkmDommwYtUCH+MXcSQsT2hoaqwgc0DvkJOFHAOH4xSTzLjEZLQ05Eh7Ii0RPxpYnq5QrAo8+8zzeOmFl3H16jHe/e4X0exKfPIz34/zi4e4/+geLlZnaHacgxHWmwr3eU0rJrYsycCEA685q1FGWY9tzN/GZQD668fBQYYrxyNMJxmuX58pCSA2PpNlvh+Unrjh/FuumNBj5Ygl23J9tvAzLsHhLHdMx9Zc6OwTwLGvMuxqWzc8Lrmbd/wd1CeAl4OfY0S18NrPZ2NHfdjcCtglJo0kuIvrbGOc5eTxR2L8/9rf8tfwLT/1RXzVu29jPJng6PhQ5/5Lf+UfwN+poMBGweF7v/d7JTD93R4shX/3u9+NH/zBH5Rw/LggeoLOwT/yR/4IfjwEhakfSSz/8XaN/m4Oil+/7tf9Os2fJ3HyD/HmoJOXqBCK9D9eg+uNDm2K0EMM8SMFnc50Qj/7y/8g4tHbx/kfYoghhhjiKy8++xt+1tv++cNGsmw++7agMCR+RMbsNZeuiQMJm85FqcSIbU1hmYKAedKSbITZnHgKuvM2EokcrCAhwxyuVtpLV1tT8cs05Mijy46iMqWSjivZIAX5qRGuHKSYjVJ88JXn8Mwzt/GZzz7Ax8p72O3OJaxSJEnTkQtjjrXw5l0SxwOqQNqBM4J79mFRJ+lAppDW41sas9lcoWwCSPHUmM4uFlMYdOauCKMUbRwR0hq+e4xPOgxbsVPMWxNjyC5OQknziC7lVK8dHN8UsKKEDbn4TDbjo6s7lEqzZJ4uwwr5iILbGLdvX8EH3vciJuM5rhxek1jx/T/0Q3j1jbv4+Oe+gE994VWJIGu9v4lldlAmJFKgpAuW55HF5pCM09yaChLTkaZgL7JK7j7HE7DxIp2dPO+qQVQ1EpbGeYKjwwM88/QLuHHzFp569l24cu0mxsfXkE7mPRQEB43iFmdN4UIxj4E4DP7d7gMoonTl5xKGNNaBDWxl92HMw2Cb3G/8aXuvwATtCWx9lrKE6yD09idMjxWsX+6LrH1HtBqZqZMbxaA+9qLDnxgWQ93eei/a/Qx9JYPjX+KtXPCVVQNw7RSGEuB122yYcChdZCImYSORnw5dzhO+F+dySBYZGbd/zu5svXTugflrzs22s6BdBVrVNQesiZ1cuy1P1pItnN90207HYxwdEPNCUXdp53TJAE7hmkkrOejhgl5DPEOEfElXM13KzmVno006lMkFZ7M6NpSsiWkxLjv3hkgOaOMPKyHGvcXFQmJd6GRuKnKTUzls+Xe5jhsK81ucbCZKehBfMZvmiCNWNdAtTeGPDvDa0Ahq/GNNATkpWYnBNaL9TWz40CvSzplzeEqec2ZN1UZiwcdq4lfuGqxDdsrHUmvVqwSINjFHMYXubo/hnWqIV+0QFRSKE7/etseIv86x6aFh7D0C/sWa+vH1KZCL68zj29n+pNengKzegaHqw9AWhijxlcCHKldgiQi+NwVqXqvJmC5oaG8YpWwqWclFzkqYNZNRTFTOZyZu5wniKsbF8gKf/fznUFYlXn7pJWT5BNeu3pTLa7FZ4mx5KnwOG7amQoFYgshc1WEtGeaC58OfTCryaIl0YWLJHN62GpiIKavY2NO+z9nlCJ9RhrCwOh5vqBeSiZy7SkDys8cE4PD+nCd2OO519sSUUEnCQRnyJmA21NBSCRJrEijUCNduaqJ9ux8hwuv3tmIqM3lC9/n3/MVX9fv3vnKMZrdGUdQ4P78Eyx7iywYdcHTRfilR+cdbcP5SoHun2atDdPOHwvIgKv/o4//6v/6vPefuEEMMMcQQQwwxxBB/5+LJhWXvbG8lwSZ4BbGJwoMhLbxZF91v7ZdzQyFQlOB310wgYSttNydgEOsC6NOdvcE95/gBPlb8TQpJWYyrJwc4nI7kJqasS4QGub3bgozhQInwhn06bitn1q3XxcmEXhNRTDCzZnmh07C5iO2Lu0SYIEbvjU7HJjAHoj1OXFcpC91jAjE36JLWaM+OKyA4NJ4U4uh+9jHaFzAdPyDNqHHRiA3bzMlM0WYyNuZ0ktCJ2WA6mcpJt1tvUKzv6s/3Hz7A6fkZVpuNXI3Ei7JWnMceUyUOOAReE7pPEWNTNTi9WEmczzJzSPJpoyx3qqkLk14mTkQBxR26lFPEOJrPcOXwQKy3W08/h6vXbmA8P5RDOaIlVSJLcAOb2Cu0RWv3DpgLY/nuLl+JPSRFxya2VEGAogQsck8NDlpaYB7vvV57edtXtrH3cvceX9leqtcF0JvxtY7q8EbeTNHE4zDO3evtMZxVYh9GOvCVAwCjE7+toSWTDpVckBSUeVuvN9iIu2y3mopq74htPnbCu1XRG4HXROVOkNd6v9Q4bK/ooXXnu3gXmni2SZVoL0HDdUvMwnaz1c830Xba1/TkjovoXCNKHsQRyoo3gCZTTSE+le+vhnYmMjfSts2x2WtLqOOhu7ig67Mn2ofGnPaTaAvDW4S9bEPsz3KDUU7Xq+2Jclr7VDWEvDeQ7LVlM7ewvbvY1f5nPopCMx9DoTBWoz1I7BYiI08ImxYTO+wflmwLlJQgqreqoifUYuG8Zfz3RFdo/hbc53LaczzDrGubJ/aBDnbdhGFQAsuSZyEkcPs16oRl/12YavGlZqdh3bTnSRexna+WPoVo7lmVJbs4Xyia85yIHOE7rtYrrNcrrNYbjHcZcuJzyO0ejVXloUkS7YS+oHis9opm0XfGd4dBYpKCQreEXDrmNQ9MuDe0hAvAPTJNmC9WmdIbkR6s3LYk7xMQmvcJc9QHX/je1CvtbhEY7edH91kRErs2nN06bp/dzmPgL/21O7j/8HOap888dYBv/yUfxsE8Q7Gp1MyRnxv9930ngiIab48LNpRi2f+Pp/j2b/92fCUF59Lb1TTt74b4lm/5FnzP93zPY3/3q3/1r9bt72SQq/t3exDf8KXmBLEOT9qU7Z0KYkCGeOf3Zgr4QwwxxBBDDDHEEG+bsJyryVKkcuc4YRm/Nbrif3R16Vbx5w4lagnH5IkWDZ1eOyw2LP0F5hQZKQaRa8pGVCpLtxJ0ffkOwjQdXsHh5uwIOrLGWYSjgzG+6n3vwvWTAzTNCNW2wfn5BV6787qaxvE9KYhJ5FYjvUROWZY858R0pBSP/Yu72MDWWIvu2kiOZmvSR4GOLmC14uLBBPdlGz0RU8360J5L6Mkll+3OnL77WF8XnBynQS6z9ylERH4m8R1FhaoaSTimsEsRIPCU+ZMCVA0+ZoP16kJIj2K7Ea7hxrU5kvgQ08lE5c4c6wcPF1ittvjC5+9KiLl/fo7VdoNHmxIVdWjQoU0cyY4cktaJSJcnrdG7aISzdYPF4q5E9olQITHGowlODg6wKkssK4qDZEnTRVxjWa7VmGse5ZhGGV5+5nl84IPvwbPvegk/4af8dEznR5id3ECSj4R+oGuS/0/XNwctlmPZxCVT7fynBtT/3MrGrqe0nOJ+2PxVUzHzFvYuhrXAC4pRK5TSee7CYIek6PFjw/1BYArzIPzlstm4FZV53GHemLXUGneZu7p7StvRr30fnWmPTR6EXYqt5NIW2y1WqzUuLhZ4+PBU1/ne/UcSlh+cbg0RIUwDnZmyFuudmMTgGRkXnEgDXu/g6PcjczeqMZcvifd+FUIihXNWV5NrWnuFAXVD0zO+6HZTKJnRNNwL1m0TwbbXmQa2Y07LyQx33vpalOuVQiQxNHQls2oADXIhBBrhF6KMKY0dypQ4AeOOC11TxRJU14sK5bbB7EDAB1sHqVUhNNoruPdQsd6BQ1cVQHNOTvVGbPeimGtvYQNNjlFRcn/z67erDduRUICOff+xRAGxOfwTxUzul+vSEjVseFiRNZxwv0tQRhkOvCHmalG1rOJ6Z/OZN5638gVsSiXuNdnV3Iv4u0rHLVc01xH31cShGUKveKWDs9LlgA3Xu8XChMapTBlSqG3bK2qcmES0a97haLjn6lLzZfl6HFOq6j5/JI4mfp4TsqoTzPIYoyRCjozkEqw2JZZsxsfPk6JEndQYT3LMD6ZYXdCN+0DIotffuIvDwzluPnVFDfvuPHgD5xenamZaYyuWynyeoywbRGubQ2GOcUvnWFHQpmOZnGfiazgPOFz8TMv4SRmSACHxKZHeq108fMTcmW8Ila66IBInmziLJLEmk+0ecvnWYtw74dkSj3YQnDNEefAIeC5tc1tPPHluCE/dnODbfuH7UTOzEkVKOPI8OS/5GnRhbzfrN2+Xb3OQActGVo8LNjEiM/YJqVw/buLL8Zd/vAX5yz9erhHZruv1+rG/YzO7IZ7MVf2lGMShcdwQf3vxd8N6+nJ78xBDDDHEEEMMMcTbKiyrCZ3cdnbry21Wmd2hIiSHyWFmjaNYzsw/dx5YdxT7F/M3m6Xch9mZmFuHF7/4sxngZDzS7fyiET9WbMuSDaa6IwvPtUM3t6A5lo2L3DmPg1AWSpE7915wmO0LacFRa3iOfhlxMOX1Xc3h+eZ+dBerN98KYxjcZ+E4e37q1hkebsE9J4cky9PV+LCQsCi3MQWEPJfwkOc50jjFpq6wXG2wWKxxenYuwXG1XYunTFExYBk0NmbRM9HdYMWtHZvvWZcsBafoaqJYno1b5yNZ2eL0+lBJ/IwijKcjzEdjnFw5Ef7i6vUbODi6gtF0LmQKucL2rFCOT4H5UjO7yz8D/qI34ParoOo/zoXX+VX3C8BdOL507fZE471Hhj/0vMoSofuisl93szO2zuVWBg+nE95IT798dP2f7WJo51XfPmkJErvxmvLG5l4lHcxEo5S1nMxMFNgI7fbdr0TdhFH1uRjcrf2hDGLYY78S9XAvwQltYJzO7R+eaILoTuJWUnR88W4c9ge+c4u665jJHrqItc8ANB0zOaS14eZ2JmAo4qWePwiip00zmztyemsddY0KlRyKvYLA54BTJdxtTTd+gzJuhJyxvcyEXAnkLp5rf9BrdS55c4FbIqPdTdzNqsabUt4NNyQUjzN3ySdvGyBSVKaQ2O5PXa7Cva3uPHekj59H+5iwp0lE7a73njvd98Lweho2b86oZJ/zobvXs5cLg+T98zq2ec/evyejMudA13fKmzmX6f7lLU2YkKSj2lzwek8m4iiaE/8D4ikKYV+43/GIUvKe2ch1NNa1IIudjfgkqJKtXNrxKenh2KNw7tZA0PjPdnaWxBBKX8lDdx63/UD3KxraP/RLI3oG5oCvYRPT1l1+qRqiq7DwOeoO8r3tIDjq219c2qN8robmtNp1/NrYPDS3vL19QAC9c0ExbxD0hvixCiYqhvjbi37zsiG+cmPYm4cYYoghhhhiiB9TYXk+ydyxnFrZt6Ma+CWZpejin/K/qDEROqGY2eDhig3WyH4tkYlnSceelUeXOwon9h1eom/4Iq8maDsJO3wc3YaxH8PzTx/jaDaR026zbvCpz7yO1944xesPz1BQ8NB70I1pnEpr1GXCTJ5lmORj8UVNwjRpTV/bA8aCZdleq244C1MbA0szSHIRm741EWoyailW8z1lg6aAxlLuHZqSpf6GJRAfekQmqZ+rBHo6BptWSBaJwoWnnZQQYj9SjNJMZd0UTsQEZQNEOhQXDaqCTcM2OLv/UOLHlaMJ8jzDycmxhJU79x7h9S++gdPzJV574yE2RYHT8wUqitE94drwJrEpK0KbWAOtprbmVBR+hDGROJPr/YnFMIvoCkVdIB+NcDyfotwWWC9KuYuzBuLEvv/978bzzz2LD3z4I/jQRz+KycExJkfXENN9rfGW37Vthge4U1l/D/zj4FAOdj668FywD9xZbzqn53dcCW+o1xN6goAdxGzZ/IwFbNOwE7BbgcafuXfTy4aUST/pYCKc18ib29wFb88r+PEEbIY3/JPSU+1hL8wEvOvNabtT4pCjQZqq0XyjI7dsWDPQ6FqRM0zhdlNWWG8LLInC4Lxloz41dQvish0QkxASB1s+tbm2jRde7/kyu+RHwCvwdNyhKWyMidcSll017Ts0+Woa/WqHYh1eOzQx9Pf0axiYuH0Rnhxwvh/d+HT5Btc395ECrKpoJFaSiDAWPoCir4ntHK96yzW8U4O7Rs5NYnR4xOSaZ6jpKKUTPGqwSeii53ib5p3sEqRItT7OFoWNVcLr5vZcPc6uTZpwL+P6iTGNUqRK4HDu8pgbMeXV8NMzaRdbuoIbzPIRJrsEhUzGrPSgOJrrdesmRVOS90sOcSNe9rYms9lEUG5FuVdBcL2yMoLXe8s1S7yP718UZ03YBXI2ugv7cF809alI8Z6cZyF9C3Nij+W6jVB6dQPrOVImt+RGtr2uAz74PucJMj6WY8lXyPMU4zEbpUaY5txv2YDQGgWeLUudHytgqNgfHQLjKRNYCaqGbPctPv25T+Po8BjXb1zFKOc+dAW7myXOFnfw4GyhapfkIJPgz/2xqHZYrdnkjpx6ji0dvIlc04bI4Hzcaf1YzpH7jJ0fE2s8bjrl1ZQzJGH8c4bnr88QJU1M/Xccsp7LJA/nZZyGfcGaIypJIPHcGn5qd+JnJ9d3ZDdDq/gk9IvjKGt9dqj2QkhvO94WAeSJU8O0sCmn7be8RqNJv+5giCGGGGKIIYYYYoghhhhiiCHeAWGZpd78ipo5Yznysmc22DO+Z+e+Cwotq423bHSEGomLFCwztgZD9kXa/+hG3sCYtTuks/gXaIpddJwdzEaYT3MTOko2Hlqq1H+x3raursAzbcNdvkFglhux5b4Gnuj+Te8ZXqPnsuzuMEesdABnvdrXelcNHVUQmNP8ii8529135kwOEpu7WtsjD2J4cBtaAy9jRJuQY2InS6spTBkCIU/ZCGsmQYnM43E+kvB8frbE2dkCj07Z2LCSuCzRU8KIH0uw5hokWudDcTmKKYSGJlNBwaDo4826XNSiUEFhmaX+EXEDeuGY2riSEVevXsHTzz6Np555GjefeRZxNhH+Igi+wZVuAkgPYtoKvK2/trsYrZU4cIqD4NyJz3u84/CUXoIgCMf+245N2hb5B/m0e+/2zxKbugZp/TmyV+HYKsn7E8n0KndBttiH/jF1zuQ913AQxHtjJWc455o3dNNru31eYlJwtrubVhgFl9VD47G2+qDPju1xh1ueSxDsg1N274yCx7PzZ3fMbH+ez/HgOa/I0aYbt8dyN1E54D+69wvXqf+aYnlTVK2Jr+lgInTfJ7k5d9OMImqMVC7uWEKyOPBsdFmy8aE3vawoNGbCVnDT4boTBjvc1JPQhECJw0Q0kHtMeoGcsIYHEmc9XBs2fuM5BpNqT2A13Iy5WEOJRiEcBd3/O2RK7Jh7VlAR7b22NgWNcSxD30Uc2NZWneHYBMKnJU5T/PY9OzST9ERBcDhToDeceVhPJvRryjlPuQ7vFXjRas5o61fs6Nam3K6sdgEGL7Dt+e4IpmNZHGS78fx425aJBFwbI2KF7LW5N6Z5hGzEV6pxvjhXQo64DF5LVlHMpgfYFKcmbjNnltt4sZEjBV06zUuJ8HxNc/AaL7tz5QertbAWuubuwg/u+nb2d9UkbdVJy0YP+xgTZYZo6hFeeiPUv8M+S3pkdt8j+ntQWAfdntU3S/fXZGd1tjUc1jyDY/nY4o4hhhhiiCGGGGKIIYYYYoghhni7hOUZbcb8EuqN7VIyjJMdSom0qZxkacmfxn5kEyaKMxQuaAQdZ8ZCpdjFL7TmlqQbklzfDlvRfleni4vONzaGSyPM0wiHkxwn8wMJp3fuPZAgdHqxkFuOjmE28qMWQtaqnh+aU7k+YNgGsmOtQVbLwAz4i7ZZVxBbO1FazfJcCd4vG3exxL+YW1Mriq7mHlWjMR9DK4E3d7OEJzJI3cpmjQKD6GYisnjKGRtSUejimFpzOz47aHRybYamfwDWm1JC14Oz1yT0fOG1+3jtzkNsC7JK6ZxmiTqdzz1xhEgNcmvFcDb1K+JFY4M5Nj6jA8+F7bZs3x16/N26alS+nyUFpnGCEd3lBwcYj0a4/cxTODw8xEe+/uvx4ssv4cbTzyGZzBHHbKxlF9tEUo1eT/ANopYJ1K0r2Js5hpEyp6wJsiaV+fN6jbQsN2H3d/eGxwY3dCdMt3iNICoFHEQPiRKaybU8ZRdrWmmnq/1vX9v08VDi7uqVzi2kQfxxcgabmKV7wykFF6KpfnLdbrdbcXPLsjAeIt205Phy3lDoL2r9OU1r43+TUbyz+anRDs2/XCxWo0pvRNaf43ZrUy2toBxc/C3rmc3WNAymVrXXpGVSd2PWpn9cjGtfp9dkMYxhtzf0Gdj2HIrBhhig45wVFbJeqyKAPFsms7TmeB+b2UmVbKSo8XjFQyY6hM7pukIm16rkTuTZCEncIGd3QJR2yWqogeiEDfW4jqgvg8/lUjIB0v7fzoNrY7EqlRhjw9EsTlU9Qa46Gqtm4PSmx5pCa1kUch9zrhRF5wlXS1Rn+9IpLQG4NgE2api4oqMfxoZXAjBTNQAFcLqlOc8rrmmJuSZ0M13IChPpxnxuaPDXQ8CwqoR7EpOKNR3O7TUxgZnJP/Ki7bmWlOJrJfL3dnkICv1qjNprYilOPIVyP18y3onT4d4i53gdYzpNhUphY1bOW16rqKkwG6eYjQ+xWdd4dO8MTVPh9ddfF79/PI5xfHgdZbXEcnUPdV2iLjfWUJafKRTs9TkRY0S+cuK31JoxRl41ERjJ3A/ZW6BN7r2J39TVM4TPDE6zOCYn3qtSVMiTIB+PLGmYWjWA1hxbo/rY2vKmG98Z/HxRcpTl+A9VJtq4XWR3zrzc53xSWEdeNbHroz64LihuB1RML5M6xBBDDDHEEEMMMcQQQwwxxBDvlLA8br9MW7m+fcXlfTEqNfril2k2yWoc22BNr8TKjelmZVM8/+LLMnx3vbFMvI0oiGguPLmjOY+BaRZjNspwMJnonT/3xgMsl2ss1oXK+fmFf5xn6u1GodmhA61oJ9yGN/CTu05Nwjq3maQTP8fWuCtxl+42O8991nEnG7buNmd2CgnCEnrn3XZnaGKxBBaNHZ155nY0YdRK0CVMuNhH8UnCIBmrFKLc72cPM4eqhDgXfbdljV1R4+6jcyzXW7xx7xR3H563oA0KHBSZJNNSXOT701oeSASmAeo4eQHYxo2Ppfiv57AZGMVRna6JIAWdnvUOs6QUwiPLc5zMphKU3/Oed+PKtat45f3vw/Mvvoh8doQkn7ZO2IB12BeUey67lkERRF68qQxcY+CCZOfD61x89ufOfxzczSZod037Om9swGf4+7WmwJ6o/Kbi8SAjhuf33fc98TzwrIMlf38SuftYhew+n1oGRzilHlPZ2NpVWemmxmxyINpc0bzR3KFYZhgVXWKtPWNf90XlII5zPIVl8DVAoa1DY/hxt0fTE9sD57lFKLiTvD3NTq1sTzskZswKip2aNdrc6r/L3vl3A2FieGWCa7iGFHAp3tox7yTMpe76FYomppy6k9BMNEGWlKhjE6gpApMfvsvNI8zxozjIpBQbebqajowN51JrcFjFJtJzDZiJ2tZ8u0cwicbrVMVqHlcrp0Lh0XAHxE+E66YGhXWjyoJdHaGIdkrIycWr5ESvUZuWDrNL3DcMg9GKvVoX3I/tuknDptgrPILjLHzGhkZvvAZswqqxRNjqA4DFChzYGLBl6nuzRmF5OKdClYD4RazPCG53S0jIQd6KmOZutvPh8fBmTVUNRWRHx0Z8FInVKFZJL08ENTXG4xGm0zFOscZrm0cS1O8/eKCHPPP0Tcynh1ivDzAdT9XgtMDGkmn8zGqA7TZGVZLpbGJrGtMdnRrtx8o2XOTd+XoK2Isvp8V6hYxY3UyQdnOZ93NOkfudEApOxEVl5yoEFC8nE31qUkqciB2HD2x/Q/OGseKyuAO+czT7tNPnLJ+Uch04P1oVO2Eud+DxIYYYYoghhhhiiCGGGGKIIYZ4B5v3dZBZiQEUOyh6JhW/pJstmK5WaxJEYZWiqJW2q+EfmZIULhpzswkRQTHBvzSbIGLSRCjzNaZpg3Ga42g+Ev9yuyn1/mRXUhygSJCk1sROcgSFCkflGqwDrQhnUrM1lVIzKzJGeyJxEG14XCybV5O8gOUITsvgZiZT2UuszbFM96KJbi0WxL/Um0jL9yQz1ArpTXikM9LQHBK73AWs8n53VIdy6k7o7gQuuqGLYiNxYjSdSEQajWfm2t6tsakKLMsdFrRSehk0z3tK552LijzujqFLUcPleDUXs0ZZKhHnOIeGgS6chRvdkjpWOcx3OJjP8NKL78LJ1St47we/CsfXruLk+lPIJgdI0nHnQt3DQwS8RBBO2+5YXkJON26NXUVxqDS2sErj+Tsv71cjMueUyiUYGtCRauDiZ0u7MF5zi5RoGzEafsWalfUczP6ngDJomc19XIXwAD10RHiInurj5m78VqiV6NTjb/g46Gr1xNP96MZLQlFioiPnuDVyNIawiXWGYzG3Y2gc5870ABWXBuxu5L6K27rxTUnrEBfdPhBwCa0kL1QDxWkTH4PYZX3yJDXKQWxYARPOW6e+9EhzuRqb19ZSGALTv/rl/2EeWTO9SJwDvnIKjOw1LbFD7EEj0VTu5biRC3eUOlI7iYyXW9FAXKNio8Nka0JjZmuFe1vDY+Pr0fCsZnCe/KHoLEGe4rIJeoYusT3DnMXkKUcoiwbbuEI5zky05PnQoUsR0JsH2pjHYgpvyCmvY4yRqjGpRGOJymTuWh6GaaI0SpAnJupL5m4V0G7KhPNQ4z+xs3uN+nqOVo5P2ziU9SjkBouZvEPsx1kxGRWep/Vl5yn0is9dQ2J0iYuQvwhpSa4zvpMaIXLcS6qqmfGG+UedHDnZ1pgxzSyhFj6LqLnmKRu6xpjMEmmqp2cPUJUljo8PlNhK0wmODq6h2C6wqDfiaxt/GPo84bUiasj2Vvv8CBgVuu6t4sSuZeBRh54AbWKstx7aygZHIYld3zZtbAsUumRO2JCiS01Htf9aA0OrULDrzmMwxEn4XOhRzvt5H//M4uekXtrxHHkWmvcxyZjsVUQMMcQQQwwxxBBDDDHEEEMMMcQ7Jyz7TzWZooM5STEZ5cjoYKVgs+OXfAoo/sW6rvTNX6W3ZJtmJlzRVEmBhDe5GyV+0d22A1uOmRPOQkJl02A6SnH9ZIpplmO9LOTo2xaNGjAJDUGVh6oCxTSKAi7yUJ3Qd2YJ3SZEmmhC6YkqSviSHhxhPHSWZtcSAMhapTBStwcURGWeUyqHGf8sqYQN90Lpf5857WKZNcjLOqFBh86y/NRfL3HRz1zJFUu3yZ3tkRUo8Oh4hCwgdoDNpxbW6OuAeIkU4/xYY1viDIvtGufbHR5tTHjlS1FYinL7mSZZW5rOo6+dW61TdeboOKOLL5GwbFeQTNrgwnZJQ8xQu1Zp0+Dq8RE+8pGvxvVbT+Grv/mbcXj1KtLRDHGa76ENOndysEuH8PtM6fafWzXya8oVmmqLpizRFGIYBHXNrO1UfVx4F6OX18hOqDU/mwBkDdJakTa8V6+c3X64mOwCWhPuFxzc5ou9XnBfm8GwI2R0tmnjxLpg5IKtIVGCw9dcu5102srYPUG1cyQqQSJRMvZ5W6GqSqEAhEPoM745//08Wv6r+MGdE1+v2TxGy9Yc6bm3/VxbIc3Xhunt5jplQ759NrQJmRbu1te6cVDvzlArmuO9s6bzX0mLnvgZGMC9nICOiRxjzn0TvyfmZnYsDpEYwg1EXD0V8jxBOierne8ZY5fSYbrDjg0Qd4Wuk5p9ZkRrsGkphUsKt3zbWqLymM5juwiGvogilA1QlLVwNKZNOg+5pIgZYbNpEDclZtPcBH9tril2kbHSWX2hWRHtsCy3WG0KjPIU9Y4NM40RbE7U4IqmG9W46EgNLVFQjPQJYpz6dtSFe1CKw13J7XX3OU2EhpA3zp/nb7mvswFntoswojheNVgpmdN9LsS6JrzetufSKatz8SRe4Ga3NyUbHIXBROG2RMn7mrHWhFY1mSNxjSw3AZ5VL6HRqO1lEUajCJNJjIPDTNiQO3dfw/3oAW7euInrV29hlM1x48rTWK9OUW/OUURbNCiEN6kndPSzAWOKurJKkrDslFBjEzxVD1CoJ0rF1g6F7W4uh3nuTGpPqHC98Sf3We6dAafDsQ0Jm94Cc0HYPztDQsybUbb4FxeoDWfB4/B9WND+FtFtr+O4p6IyPA6Ffz5vlHtjQXBeEx3FhEtAAQ0xxBBDDDHEEEMMMcQQQwwxxDskLFOwYQQhiV+WJ/zGT8ZuYg3B1GU+iHUSyEyQkgNXz6fAw1Jvd3fqBf3hoaFUW0puogadhakLTvyyvCnYaIuCmTl+O7epgxBCbXf7LTu8ibnFJOz2apl5HixFb5mZ3vSqj70QA7p3vyEDOoJCv3ebuUM7zMaefCBRz95fmI/gVu7jNdrDCM2fuuP337TuWBM1TcCiQBJFGZI0s0aBErF5nzX5ssMxJa6QWCL1RC5kYQco1ki46LVbk1GuG9u20Z67iOka5nNyJg6SBFeOj3H7xlU89fTTuH77Nq5cv4HRdI40nzhTud+oryfqtm69dhBdUCa0lpxT8m832DUVtqsFys0K9XaLar1pG25RwIsnZMpaUzUbvxg7Cnuh21e4Iu40tMZll8c3RFBXOyF37yo8hq0sZ3pwIgenZh910XuB9tf9CdIH27q4GiggwU3cuXV76A4Zj91h78ejpnRsTlYyQWFYltDwMczxVmR3h7AhKXq27t7x6wzbRnr747X7Ej+702q7Ifbmdz/FEDyXvWqF8LjegIV19qb3DSKo+OEmpvP8O050bM5TF+0suWPsYYqp40kjfnlJMTA4oP0dlAhwl6oJhPtJkLaSQIKfp0SEDzFhl+Opra5dO/Z3JYb4fgY6dt61CcRKNGm5xUpksHlf3cOihOtG1zLnfsAuUGzUeYVmjL4FtomKkORoG7d1jVd9p3ZRs3fdwr7FPdgAHjoH7slye/s+lTKxJHY+Hcj7z+83xzQBu79n2p7PzwTejI1PnjH3MFIeDPMTkngknBAPoUoLT4rQbT6d5kiTBhfbGmW9xWa7wWq1Qj7aYTyeirE8Gh8gKjOUWGkviZNKvQFaN7yjhcx93X3eGYveE4Y+nkrHhaRGbx33m+KFzweb9x3qwsY9zO5ujlpVQ39u93ni4b36C7NfGtHtBXotNl0N68WRKfp0romOCQI/EyX2GTLEEEMMMcQQQwwxxBBDDDHEEO+osLxiYzCGu6aupRmuzg+w3NSoyy3iXYEkqhCxwZULU9GOzcKMG0qnHsNK0q3Jl33HNncXwzCp1giKPzM2OKLLKksxykdyQT46XVu5d5P6a9PpV7vzjaIPVVZTalrt0pm8ZDyPRyZw7qIEZVVjsaST0NiroRScnE1zELtok2YSBVgybU4yL/PuOZnNobxTKT1dlnsABTV5owiTSOzla5sQ6K8fXNN956gbOa2Znjdncs5zXdGZyvchj7aWqBtPZxKW03SGqALGozmmoxJJcioZhJARNTk0iKqEIcIzOMaNjoGl+XQd18bg9YZ2lT8+SskFpahEN7rV+8d1JQHs6myGeT7Chz/4fnz9130Nnn7+eXz4G74R49kck4NriJmAaBnJfYZyn63sIp7czwVQr9XYrCk3aOoS280ZqnKL09fvYHV2hs1iifXZhQT16WyMhD+Pp0jzFKPJCGmeGeaBiAMyhsdjK9OPrAHijm5IuZbDlQp/6sui5sD0ae8X0wnKnmSQAVkCZphnQWDuP6d9autGlOG51Z9dUNpjE4fEQtcszbMujgYxoSsAKjgHsowOe2vMVRQFzs8vsFissV5vsN0aQkYueeIgPAnCuWqNKU1aDJzxwL5VuqjHiH1T9A/Negp2YjiF1pB8aXEWxmxukyZ9odhFPI1NcLn2ki7tGPa09eAUNe3a0CZlXGKz3shpz4aePKei2KrZJ53GsZpgZsjymfi6x0dzrbODgyVWqw0Wyw1Oz5ZIuKeUhd4/IwqGeSw25KS4rMyQicqq1OC6JItXLmQTix0sjCbmfsf9wsQ94oDo+t9sS+ElgmLfUEmNgGwUg9N0wfIOYk2iBEVlL2gCM69bpWQB9yTPuWkPkTYbKjXc+a/rKoe4jY/wE85m5/wWgzrw80NjvV5CRQJvYs37KPjysXVha4d7Ao8rsJeLusa2YlM54+gLLeECtJAdPmZMAOmciW8BmxTWyNy9TvZ0lEVI8whNUiIriK+IkI3JrWiQ5jmSLFdFCa/tZDLC7dsJ1usKq8V9rIsNTk/v44035rhx4xjXr15Dnk20Z27LNeqzO0CxRlwvgXrja9eqHuj611QVt9/Efz5PPG6uFZ9/JuBa4q4VycP8V/7OkCPcK4WOyhIk7DughqfOMA9oGn2uWEK1zVCG9o9aE4Gn7c0/xXoK18sXmxIRlshg31UJ8hLNua9SUE6FYym3bD7JB/OzWlZmq5gYMMtDDDHEEEMMMcQQQwwxxBBDvKOO5VYx4ndVlnebUGIIVRcQHCnQOpXRowxIgKMg4mXY/rzQCEoP9b8HmxXFspwl2MRFgCznGmsKZNK7jLcZXm9fDnT3X991J0HDRFw5AdXIa9cTdk38DY6vFg8QnMoS3zqha+/N3IHWYRZCMyVT1lp/mQtrct/tIQgCv7q1tzmf9ZLl2YUgK1821IG5Qf31KFDodf1cvVGgMYHN8dhyTd2NLEHcEQYs9dc1dXucCU90tfE9yTDu1DxeO0pKFJwOZ1OczOa4fv0abj71FK5ev4HZ0RGy0VSivKlswaXXO8/+OQdLnhp51YgoKlclqu0GdV1gu16hLDZYnp/j4tEZVucXWJ6dS0wtiykyCckNsnEmsSorKbJbwy2Kz8Sl0Olo86Ln8Ov5/gJyoj3Wx4gtl+9q50KwiLYPsmu498TgXOzdpbF2Zcpc4X0Ea39u2Jtw3G0pBdd6cJR3RxcEWpbxhwaSQXz1g94/ifYXb3Yi7xmq9//vMU0MOxd0/5wvv1VrzO6dl62Z3tiHNw9u0p5buhOWe+vOKxz2XP4+v7smcbyFVnTGKufcphOWuajJdOxCqTXZY5gL2F3dl86hc6x6ssGvnfaTcCHl7DVMRmjq2KJytKbC8YV9K1Qr+HPkLLfj7M41pD/CnmWPNfGYwnzfHm4COE9bVR4+LWV+9kSecZltD3zTTtrb10y8ZoLNbsEB3z83isnEDe3PJzsOO86ucqTm48LyEXM/iKgdqsX2tbBXe5NHg6e3z+X9xIXw+alY/jsUZYHlaomynGn/YdItTsYS/pN0hFRYpK2qRkT+9vnWVmtcus7t+myR3/3Pl85N336G9WshAs7cE0qPT9OEyf6mMgYfn8DuVprgUjlB9zz7zz9bwudwu+c5X7vpJx2C436IIYYYYoghhhhiiCGGGGKIId5BYXldmKuUDi5+cz5dLHGQplgXNc4uSqyrCnVFVyQbY9GZa82U2EhJzeydSWukTw9+CQ9N+FwIM0HZhIebJ4c4PphjPh4hasbYrJd49c6peKmGtLAv2hIGaEJMrUGWBGeVm5sIk9Ftl/GWIqcAGVGktoZbWUbHMl1/jQQRCbJqWMcv53YzQcAYleQlm2hraAXvXSYxU+9ZlXITpnEqN2RDQcfFIrItWbYdpyz1Dk36vGGU2KqClbbl0iaGm2jMKApiIYDVaikHJm8UDFvEhgts8iKKu0lnMpuUGfLC0AAUOOw9ailMNTK6tOmms5p9b7BoZeEVq8X5wnQY1hGaqkJU1XItzqYTzMZjfPM3fj1efte78N6v/hDe/+EPYzSdIZseI4pTQjkuIQ+siaJlB+x6kxptqgznSo2mWKHZXGCzWuLs3l1stxs8enQfm/Uan//kZ/Hg7gOcPTrDowePkI9ynJwcYjTOcevWNbkXR9MMaZ5gMs5VIj+eTnD1xlVkeY7RwYHczSZYWXE+xy7wYU1M5/G4APQlNCCJgj3LuuZ3z0lrjt2ezTY8z7EQ1tROM2fPHRoE6dBcTzI/ndXWDc2ax8lFWWktFhvylMm9rvVT1QDePE8JAaHHjWlO1EOp31uTNxNj7T+5MflenoCxEvlQXh8wIn5+3hBPj3GhtxU+W7G1K83vxtfOlTgcrQcXDoNIqkZ9VScOyzkt8dya1ZljlOvBcAwGdAmi7E5JqEme6udoRH44sN0Qe2DuT+Eq6OzOc73no1O6mnlvjvE4wVO3buHo8EBO4tVyi+VqhTfeuIPtZovzhxcotnTQU/wkZzjCsrCqig47YNeUjlFWRwQaD++kW9XEy8AnN7HPsCV2JqwqUMM1NmCsrPnkKM/EiSYCyNr02TpXwigJDT077ZDvkWo8IuPayxlsN7Kfd1vDFtWV8ZpZkbDjIhfr3RoKcn/oEDjmficzfzzivkWXLpvAZfYYA5gr+af9kUhf4khU9dDhjGwfs9fOs0S8ZO6BYlE7j3lb1lhtSyzXBfKY1yRFGvO6ZnLdpnlhDQuVHLNECdcscwGjjPtziqvXJphMMqzWp/j8FwtkeYSr169Zo9MqRdWMMJpcQZLPUfHjL07kSGb1iqFALNFm2muX5FDlAjfz3pYQ9n5rAkqcR9OuQbqK9UmkzwVL2Kl+J3w28uX8BWztaKQ6vdzfVzzteofNtsJ2WyNJRhhFo15yk9eq0LFylzc0ie25vB783Ntx395liHb87OVHfqxz5phwn+Qcu5xrGmKIIYYYYoghhhhiiCGGGGKId4CxTIHKBKeiKLHZbrFhEz2yXOWgdedhaLCk79UmfBrGsXNZdbhPohWsGRW/7PILLkuuk3inMvaj+RQ5wZo7NliKJDqUZWWvr8Z9FAlN7NXXaUdVtB5MLxGms1aMTmde6D9vVKXS/NaN3LnwAte1deVddhnrDS65R9WMsGnFJivn74xoJq64eOf8AKMb9FnDNkCO7G3xGxQu+Bxycyn2BNeyRFI9p09IJr/V+MtyF7IM3Z2DQVj2R+k9zJkZUTt28cRQH0FcathoiwKw2LUUiyJM8xwHkwmevv0UXn7pRTz7/PNiK4M85WS076pzpmdrYe+7lNs/87Vr7Ni4sNyiXK+wODvDdr3G6f1HYqbevfMA9964h4cPz/Dg/kMJyqvNFpPJWOMwm46RT1IlDCgqFwcTzIoSs9lE2Id0PG5Zp2qcF66zu7lb22grjLozM9zTY6R2rmKPDmrbhQtvnWAcfuGJBCYAemPU+WJ7L+1O7pYlThc5xWEKkOQna16YmNe/9Z2X4b5WVHbRLAhUYf6Z2zkIxx3K4E0VAT3hsftdaGLYjVOfrdzhAnz+7TGT7Q8aDR6Dr7N2HbujNYio4RgkRofxdBY73enWoMyRETXdqV0ygfsAX29bEClDLAUxIkQqTHD16olE+u1hjXO648/PNd8X0aJDkTQmylKkZ4g40J6LO3udt2zHaaKqzY2+K9jWq5IAEjO9moBipBJGOwnhcp/vuft7jR/3nK5dgzcTFm3fC/OPY5VUvWaULiBr7sShQdyb3ebiVeu4nD3s+5fuC/tHYomzpCHKwpRZXa+2WsREZd54XHxs7c1d1VzVEyBsDEgBnMzm4L5V5UWLDtpHD4U1SnE5yyKMVbEArC9Y5VBhtV4Js8F5X6vJJZOII1V3pGmum9rGeoPIvg+/f/1ClmAv2dQuS2dVt47rzlXe5Vg6tIvxzHtFGn59+hUTYfjDa1NcpmjNpEbXMpF/tiTY/vt0lTJysev9uPt3FRAhoREHF3ub+B1iiCGGGGKIIYYYYoghhhhiiHdAWBZyV1+qCfQE1mWFh+cX2FQ7nG6M30vhQd3myS+NWZpcSfyKE37xt7eiOU4sXzXDkkfLPMxNrOfTQXhyMMF0nOP2tSu4de0qNssKmwWts3QbZxJhiooszB3SxppmsWVd5k3b6Fzja7KJlok8BiENIqJEjJKiUiUpi3dT6Mgojoiv7CKUK7ttibMENxfHvBQ8iOmtWBfEFJUY25d9CikmUFLgiMyx7F/kjQLi4p3cgj33oQvfQcigoE7x/WJxjvV6hdVqLb50mmaYjKdodjE2dIDXJTbVGutyyRHHOCP6g+KVHZC9ZCjZ3iFLiRxhE0NqU42JXLvExeQgrEDvRQ7qbDTG0cEcH/2qD+DatWv40Ec/ipfe/W4cX7+JKJk4pqT19bniUe07lTWo3kRQ6gybAdJ5V2C7PMfF/bu4f+c+fuj7fhCLiyVef/0Olqs1Xr3zAI/Ol7hYrnB2sUKWFZgvauR5itcfrlQST44ux/zoYISTowmODmc4e3SB2XyCp59/BtODOabHB0hmEztEWT8DM9u4o3LXuuja14uDSzHc11dcrUS9AxXsUSwCNqPNRwSVNSALgujjOBA+SKxuc8ML6+CNIZmsqek4pDuyoru1xnZboNhasmdblCg4V+igpKAm4a4Tm+X+dbRN13wsiIY9prMea5ztVhBzLqya/+my1b110McH+HVvXdf7bGSOEVcwWbDm7rZhoBO+jxfpkAR8P1Nszd0ZXtveSwJtXaMsCzWuTNORzfPJyJoXBvG9qVBUWzGPt6ta82SUxSi3O5yfFJhPtpjNprh1c46T4yMczOdYLtf4ZPYpnD48w51H59guViirHZbrUnuGWtpxChG3wnHxrFnbVJDnlXtzTIflUjyuShOQeT34XxFVPf4xG/PFGGmf2KH0677eVFpGrMKg4JqQ70ysBxNzxkowuVljaXusXLha58A4S1Anxtqmsz3emfM9IRPY+DguSu6sqiLUGTRNK/zqIRLOnasRPhYk2u/UYE/7YG3HkhNFE0RhTyS1WA2J/DbP+PZ0LJ8v14iyGONJrpef5pnc7ZvJFtsUSJW8K501bNdfXPtdg+mUwnGE1WKj1zo7P8Xde/ftvZMKza7CcnWBqt5KK55kU0TNGhVdv708WJuoCQx1NUV0R3Hb4NDGxwoJrG+AHMtKWlgCVh5zzRFz+isppw6EXiHBxJKvdwnAPuElUrfCM8eWOJ8IxG6v16U3as2x21Uoq7V9jnl7RdtmY+0N3BOMhhKSWjwuuybsWxDTsS9c0hBDDDHEEEMMMcQQQwwxxBBDvJPCsmMlqFRQQNiWFc6qAtsauCiN1lhbbXbbpE5YA4li1hxPrxPK6/XmLRHSXltO2BiHkwkO5mNcPz7CjZNjPKrX2J4tETXWdIsiyZau1l2DjM3rKHJI7OKXdnMv80s6xYyW1xnsvxK5AkqAYqc5g+W4IzKCjkdv3BdYy+0YuDpmTs5Q/twJy331LIgPIls691jF+7Jk5ibCBMHN+dMhJLK5oBwc0mpqKORBKefuYnkhRARFMz6QIgGdkysKihTP6g22bIAXVcizcO4mjPjZSIiS4CRRXa0QEdUEU1CUT+RGFCaBD6cIpiRBitl4hCuHR/jA+96Lp2/fxivvfR+efvFFxOkYiPNLhkfDXZiY7BiMcHMXbvuTYlFToNgscPHoEe69fgef/OFP4+z0Ap//whtYrja4t9zgoqiwLAost6U1d1wUwnncvX+BnEkMATh2ODkc48bJFCfHc6riODqaYz6bqalVPhkBEpY7UTcgSHTWPRZycBPu82cvR3h0wF485rHBGWgWdRemA1vYF0jrWpfE2Iq0csT7jQKj3aycXaX87mLnz9KbOwqJIYHK0AHBERtcqt25uajs4le7WMOV2xkmIxIGIeAXKLZ1LeL2mpc5kEVrywDevTXUCcaB5R30SSVgJHS5q7K37uzwtND3GLb9ydbsDCHRZHTCmmhJAW7XJNhsCzRshLcj9mCHqgTWG5LWI4yzUo1Bl4sCq2WJg3mKq1dO9D5Xr17FarnG4sGFxNezVQEsNhLbKbKyp54QN0mEXGJvz9EaHPAUeany8o+lJwco1FbW2E0YjGiHIuKKc6HRUcJsiqdea7XtW0VFLAXnzEjva17ixMVfNk4NfQO5DrifGtpBt4RYCROU7fG2JCsldVg5EljXNvIBwWOcdReXfb/hfqtL2cf9epUF83i67M5lJpJD2JMwJ8LDvVKC8iub3PE9NgVRGJFEZTV4jSOMmexjs1HiM+IGzbbBrjJkjeEmgtAMjCd0rPPYKmwL2ycfPTpFlicYz+j+rrBcECW0wWyeYzQeoYrLbjb5HzhG9rpeweBs43DctleE5KKx/oWhobOYiJrSnMREdLRJRO8lQByGcB6OctKq7G0X3ZbYtTpVVQ5d3g3HqEI+YgLU3NlttUdbKGDiclMBVcHPOCA3Akbrx6ZTO8/Hxtx/E1t7iCGGGGKIIYYYYoghhhhiiCHeZmG5FdrctUhxqqTzt1f+Lf3S0Qti/sqWTKHBOMf8cl7yS3TTIPMSZ34JFvuRX7LTBOMsxeF8IgTGOMvNfRw6TnkZNUVgso75BZziEbEMfCtpYv3K8E4ydJ3P3X0Nmzs5RsDFVd4fzk1YAP11r8i/J5c6cMKdat1/3kCJSAmeq5fKSzOXU9oFabo85fTsyqODmBfF+w3czOXqghoprDz2osB2s9FrcDyyNEWejeQojCITAem6ztIMeU5WKW3i7jzWyXRiOA+RzyuYAJAqaCKK+Lyt3r2T85GIkusnR3jvC+/CzRs38NJ73osbt25idnSCOKVYzunUl2NNFgliWetMDo5lZ2pLwOG8qEvs6gLL8wUevPEQd954gM+/9gBn5wu8frrAalPgvCixpnBDkZ1iF0WvssKW48vxELvXGguWZOw2wNnGnKDHRzMcHh2i2myRTccYT+jYyxGTRRu0YLMH+jXvNYJry+79zz2+coe66CMtzLLarhl/TdejOtdt23gvvJ7NRzPkGlYmZChaakhoJMm5n9qY93EubneXtMhmdY3Px+aSc7kVlXvr3FAUoTdaV5wfxsHmpDmWaXFvBeU+f6YX+3eZeB/wFdbAznAtcY/RrOcFWkArKofj6YnO7Xx2nALxFLudhLfT85X2hjSMCRm34r7v0FB8rijQ2+usNxTogbt3z+QiJrc62UVq5nf1xhVkhyleeOk5XLl2gvRgioM7D3B+diEUSxQ1Eqq11Gvff3gFHS0hgTIcucRZOx/yh02i9VGS69bcy3JmcwuKzXXMfXVEwZ0oishFaDrRKStLfDUHelMHhrq5y1XxQAGYxxLtkHK8KISTqU5RWsUDOxS8fnIiM/lg3lfbi8JxUxQ2hy0ds3tTPmA3JK4nSMls5+OdQ962mhOSqLumSmwQzZFwlzDuNvc2XoskqjCZVNjwOmQRsjETlXRpcz6bg55iqBzVEs5DA1cD/vBuisvkWnOUydnOsxmuHl+TsLw4u49tXaEqPBHQRMjTkQnoKFuXuaGFOE99DRE75Dz70OS1XTd+XW2ri2xrE/t/hCSjIM7zovOfr2V7NG3eWuO7qm0n2mJKeIf4yrX2Z15rjlOs+cF5FLjQ/EwcaR3xfFgFsNk2um0LSzrxVMqoEfaCzUz5WWTbvKNqLlUpDDHEEEMMMcQQQwwxxBBDDDHE2y4sB8FWYrCLpBT4KJDofjE9dy6GeNM7CmDEEqSxXGd86HZX6UvwSKJQCnoGazU1gpqqzcY5rh8f4uRohvl4iizKHWVBlrA1xUrJ8WTXrIaCdSI8g2gaFHp1eB2zUlo3GacucezABk6NyqTrukRFgZff4MmgVNMjF5alWxn/ll/ETeAz95+pB86tDUzNVmMI+IxITdY4bjxmM2FaIydyX3e1/ITeTMnsiaYv13auLuSLHE33t0QpExnYxG61XGCUZVbSnBkKg2JrHG11rBSbRzmb17EbmgkvLI2WeFeZA7bkfWzoJvdkIaxITuamGhsmurYsHVdjtAiYpRGev3UdP/knfSNu3b6Nr/mJPxHHV68inRwgolv5seKElahTBJc9sr0ZCqOz59Woyw2aao2z+w/xxU+/is985ov4gU98AWfLFR4s19hSZBG+1SRrkVfosizNlXfuGAqKZhy/0brE6HSDSRrh1Vfv4srhFEfTHOunrmFyMMV8NkEybiBwShQj9aZXLZehFV2NDx24tObuDSqoXW+JTC3CIYjGveRAEJ7dvSh2ccBceHLDnsMTNAFK1s+2hN2EPi03zic2kWQTtVGEigJzliApTWw2FrjNL66xekenIxtW0s3uKINweS517Qpiurle+/caPqFDAQQnvWEYWgO0P77vWW8l+FaIs4RDcPmHQnxbe85Jl5M1kBk6rrBer0Vzh7G08eX6oLu6qGqsN0RiJDicTayZXx6rCVxFQXFbgMUKbIrHdXqxrBBHFCDv4PXX7uLhnVOsHy5x8/Z1PPPs05jMxjg6nosjf/0T1/Hqq2/gs5/6AjbnF3KHk9VM8bbOGq1o+oclaWp9OVbCm2fKSeoJOa4sunjZ4pJ31ltPdBEbxPlLjIQuhlV0KClFEVlYD86bME7mOCd/V/tNZvNM3PkgvlKWpJtfSQPeT4xCJFzFBsCGVSCO+2YSwpBCVlURZEf+uqhChYfNlIAyYpUHKz5ynk3EY4yVeGj3YEvd2XRzEZ0XkdiaChE2rG6pGlysYpTlDqN8i/lso6aKo8lYzxuRpx/FKOH4joznb1I4r70xhc2tnnmDQO5fq+W5kpXP3LwlYfn1L34KF8UWJT/HKqKJYiEx6PjesvlqAKgLa2Lz0BJwbIbqnH1P7th8ZFIrFGRw32dzRDvZJJ8im4yQ7jaIUWh9stKAjGdWf9gnRtEhkZQ3ssQpxWQmEQ1nE+o+mCSMhTtinwMlWuOp4alWW5RFjbWQOI7CYcPGyOYbX5eO9YzNL31O6Apfas45xBBDDDHEEEMMMcQQQwwxxBBvv7DcNiGz6NyFkbio+hLdIh8oXhn2wJo29crt/XlEXlA3U0MoCjwS0owrSoEi9S/d1hApOC27xmP2vZ9NoExc7jfy6pcVtz4wOs5Uouz82YrCciXnmzXecoehN0YzR5mXMZM/2rJoTdAKPsSuD5OhDYLgZq62novV0Rq6v8dvNsdoEB5NvAuOQzFIY8OKyP3aCpcuyMm9begO4kdCeXZ4CMUOilty+FG4chU8IBU4HkJ5SGAwp11cu6DsTRpNFGswn05xnbfr13Hz1i1cvX4d4+kMac5meCaEhfcOrOI9BnG4bu3166MwXHymg7yqsF5vcHp6gfOLJZbbAuuixJYOPV3/yErIQ7M0lYuHBnPuEA6sancx8rUvdg3SrMD9RwuM8gw3Hp7j4OQck4MdZrklL3ZCJ7hgeklw7YDJniRw7bTfPM2uqR+FxLk9tfXSy3XNwDT3e67ovkfYfh/k3m4OdGvJubrtXHEUjHjmhsoITfZ6/tieG/vN0ZqPfT618yowmp01Hkr7w7GFt4jax3bzP6ynrhOi399O6dAALQj4fWxLqJawv+017Ou9p8vf/hhZo+Xq5OvSwW8iXIRkxOttHFxbErZ3UNDkaaxWG5ydXWA6n2C9WltibDrCOB7h2rWres5mVeDeGw80Vx+en5qwyflW0U1roxEYue2J9gY8cIHb/bFFr1Ck9OuVuDi8s41aTmPifRLuZXahlBiKPEHDIAVBTOAOOtFef2/WGJqriunuHHeb8/4YH1tjNbNawfcyf51wrK07Ws5ob9An8doc5GoI2ne7C33h58851NtThQWim7dmKz26wCmKNqiZQPTj1vt4I0TDgYc1E5rdhealFGbp5jfBtF0HTJohwXg8xnTKRFyNzWaFLE9N5N1xH820bpX86y7e3r4a5l6Lj9G+FJjMoQmiVSlESlgS+kSoiSWidrtUAjQbVFqCJcjv3rDVt40whzjfTMy2ho36DGGyiHs3XeHcO3fEJFFM5nma55zjleaGSWGyskP9hEoJ48hb0nhwLA8xxBBDDDHEEEMMMcQQQwzxDgrLLHMOXM7AaOUXWyIu2FypbUwWEcdbotpVSJJGJdxEKMjXxqZVEjMbjJMMeZQQ3Wtl1nSeOQpjko8wznN9sy7ViKwWqqGoawnDdA7yORQppuMU01GKomywLa3JmDWxMpxDaCxGgXW92eLR2TnKqsRytZKDOM0aCdxiprK8moKTN5oiG9SElxQxcR3EerROVpaYe/PBHVSazi/35F5yrIJb0L6uW408my2xUV6WkI+ad16+ttyfY2WN/YKINM4nGI/Gcp4t6UYOF4RiE5v2TSbI2YSJ/FuVX1tzM5aWc0zW6wIXy3XH1Q28Vjqfy1q4DhZZs7w+RWKF4BLZKMg1GEVs2LfDu249hQ+++xW8533vx0e+4esxOzzG7Pga4nxknbveJFN6XfiegByA0o7dIARUbs0KEfm42w3K9UpN+z7xyc/i83ce4N75EquiwMaPsRV4O9pEK2pKj6EL0YVd6oRx0+i5222NZbXG3/z4F/HFOw9RZ5maH9567ik8P82RZrzR3ZmonJ7jGCAp5o92IUvn6g32XOwPZe0Glw32dclbndDsh9kX/+01AyLDXs+UJRNhO+e9iX5RcDa3Dcsa1GUlFm3G6z8iaztDxi5tuwbL5UIiqRIKLkqb69iuh3ThwGZlsmPPo9z5jilWqyLAr6FckK0gFhIwQeyz57evbcDw9k65/12IN0GSK9lQKBTn6IjlfaF5ptabnNreHFDc28DttrDHOr7FG7rJ1b4DLlZbo1fv6P9NMZ+NcXw0VRO8sllqb9kIhWGIgbTc4W5zjmK1xqYs8Pznb+PqtWO8+/0v4+jkEEfHV1AUNZ65/TRODo9w/+59/P++9/vEPV8VWyWmqlGMnbhAfnw859SREH5r2HSxKJDGKTLtBdxLEq3FbUn++Q6TLEGWjYU/SLzx22SUqaphteY+a9zz7ZbXiAk5F8+dcUznMW+B1xtkXiWkJLoSjRAjqWNELAUIOTAlgMLuZC7XNKb7NzB/fV1wz6CAy2RgFiHJiKuw5qQV3fO6XjvNW76mUBkpP0f4GSBLr9YIxeI0zvT5UBVsElhiPa2wWXM/tfXEKcTPmjC/hJZgmUtU+oS10ZVLvKmFtxlNMh3tar3Bar3FZl0jy2Jcv3YD09kIn/rMp3Hn7us4ODjBSXwDET+Tskzs5k19gYbcZo6AJ0z7bG9LZFq1C93DrH5hQ0iJ4RKwKSSzFmKEGBOJy/yzPpdq9gDgfmzNOfn5YsK542+YTI15bY0FXtUcpUxiOeeLVVZEKFasgKmw3S7bJpqW8Ip1ntwTJlmmdbUttp6MNN+zKoWEOKEI38f4DDHEEEMMMcQQQwwxxBBDDDHEO8hYprgnx2u4X47leA8VQccZtSxqF2raFJpxtV3uTeSiIBSLCxwcWSZSBI5lYJRKxO7dKBQHhDPF18wbUsUUTKVntWABF4FNLKV4rNJiOpVdFGhNlT2XoTkie8zXN7Fjw5fwIPp1DzCcrjNw3ZHWkp41TnRiOoy6vd+FH3cA0iEo9ylFGwrNcjm/GWBr4nUqt62E+9BQy1mudOlRfCoq4kdC00FjtRra2MeWErvevwFpxKabeqk6BZokwtHBAW7evImr165ifnSMyWyOOM0c1+Cl+u0geMupxznHW4SECc+GhDBRfseGiiXZqlsslmss11slE4jqqMmJvXwhAo+3d0n6193+bseRNDsk1Q5ni63mzYNH57jy4BTzkwOUVOYAZOORjXVwHPe42X20Q39NdGpcz+HbO7b2UOVg7gy7vavoeI2AS7k0XJf0Hrs7iMuGNzE0gt1sHvmj3MkY3MCdA7PDS4TXvKR3t49pE0a9+0LVQH+cbR731kE7bnsDtt+k8nGeaR1Md7z95pkd2/ryXLg0UEJ02PXT/iFnL5nqtj7M5c+EEh36dpXVKlEvz4RMhXVkiSgmH6brsdYPZ8ZkPMZ4HOPkygmu37iuDeP48EAon3pBUbiQM5aO4rCv7c9Nm0fBnW3nyD3PBXUei/Y5F9J9j0l8X6CQK7yGj7VVWdirM4nC19Jabtd3z93uSKO9ig4lG8IU77jg/Zs5l323Cy7d8Ps2n+gYh5AI0N7VVUNw36H4zZsJ1N3+YBUe7lj2RpWqtKgsCak140kLuaLFr748vbr1Lic2H8smfhRnq0pu3m1RCnmUpBlGo7EEdn4OaJ8sDYmTjSjy00lOgZkJT8fLeGXL/p69n9gK11X8dL+pmV/VyfHMhvI6qmEjE2yqRhBzpE3QsEGhuet5vcOeacmuDk3iuCVx/ntVNu1nkN+4PZsxeX8Pbi+/OaEfjzEaYoghhhhiiCGGGGKIIYYYYoi3SVimk5aiayG3Uye48cvvOM5MsBBCdIdNs0WFCtNkhINRLrGDDlrejNtr4isbj5GJGtNhRj7nOMN4NDKBtLES5qYpsN4WWK62WK4LYyM3NWZjc/SOMmBCEGnElnAm6JAvyYOh8MkvzSZG11iuV/TKmSidUGAyFid/yv0mEcFLjil6U4CSq5Jewticvq1EYC5KlVi7y0+caI1JD1cQ5De9PEViYiNU2L5HoA3CBYUjoj2Mn5zJuU0hTLxMlr1zTNyVludjzOYHyEdTud44aHSTUkChS/n0YqXbo4t1JyzJFW1c51KCJAV3CjnmYysjuod3yOtKTvOnjq/iyuEcX/Phr8ZP/uk/HcfXb2J6ch1JlmNH9nPrg8SXcCwbEkXn505f7Hh96Fou7DHixdbYXFxgeXqKhw/O8PqDczy4WIm9WnC8A986IBYeF63Q69dGopddrTW5pNUOn723wp2zDZro03j9tXv40GKF2WyM+dEBbuYJUo43JSaxVV0gD/bNlhjbCeSte1MNKDvmal/E6dFAnMpiImLguJod2F3cbcKBd3lNvLuYewQIb0JXY3mxQLEtcXG+wnZb4HyxRlmssdtVGMnJn8hNSZe+1h8d6j1lUagXuViDw7ZLiJgsSHHPHdRBamRSxueT9oKAPrkksu1J8T0xM8x7iojyrAr5YHxrp1y7KFq7qNcdjyWDHN3i52FCKmdvJ5qKpy6h3kTA5ZZjUGjdjyephESKhly05LILwaCGemyQtkNaN7jYrHHv7gON892b91BvK8xPTjA5OMCNp57C5OBIPPCj6QRnj07xAz/wg3jw4CHur1c4W2+MO5wlOlzbJUT70XQSD50VAXkmByrnQZ7mLCdAs1rJSb1KS+0ZTJyNleBJMB/lyOJazljOAe4JFF8rHntZSiivd6m7eSPkZYNE8832mDixBnARu/CpusOaCNJJzrkgsbLfQbGHXWn5Or08iyUavOmgYyp4Pqk3C9RzyVsGMBnn2m8l5jaV9hxeB+0SanBn6BLu4Ty/1aoAh6SutjrWNKnl6eacqSi6MknX0Bls7vhQScBZmaQ7ZDlQrkusNhfIzsf43Gt3MJmMMJ3FSEczzA/nuLJhYqnB/bv3tJcenzwnF/doNEJTF3Y8EoY5SqzCsZ4BOn1V27CBas4Mp9zjzY68ZBeVqwb37jzUHMziHdKowWw6xfHRodZNQdd2tMNsYuPCxBa5+OOM95EBvsPhgf3cbAt9Hq4XS2yIZ8kz5DxG2tNjcqGrtjlim4ylm7ooXOw2BnVIIpD8kY04TxJkeTY4locYYoghhhhiiCGGGGKIIYZ4p5v3BUHNxKTwDdUcy+Y4s95P9sWe4gHFjDyhkOAChtyDLjlRxA03ISislJvCqrmxnOVKRAIdZZXf5GAzcYEirpr5pXSkmjgsQcosfEJU8DDl1AXxDvzyvVM59oiis5fjC1khh1fbmczOzUvsVS7slmbpFoHG0Do3XTTtcUfD773qf8+1bPCP4Fy08Q10WHMuUvSmWy3wk22cTWzrFDoK88RgEIkRmM3CflCAdofepvDmYm7rlZDoYl1gy3rvPF3XKqJ4s0PO8nvEOJhMcPXwELdu3MSzzz+P0fwI6Xiq5lNsuthnsnbRd8aFsQku0yAuu1ijAzNhmSLIdrXBer3FYr0VW5kNzozo23vl4JrsmWk793BPgJUgrVSCM4GBMwpNW+CNe2doijVu3bqKi0dnmjssaY+bBEnUa6YX2Nntm3UOXdOPO47yvtzdPa4dpFawC8fu1kI9zFQhc0sHZ79Ptl4KohP9zFUuVMy2EIphQ4c3G9OJD0txjckROv87N3Pg9+qcAmojcJD7Duyeq5kjYKJUS8ttdcbgEO2uTh/+EZz/l932+65XO31vsBkES1/Dgeu7N6faseyJ2r1j6qyswUsbOON0FFfGWJc9OQj7lrgIgjtFWt7o9F+u1hL7lhdLTEYjjA8ONQ7T+QzTwyNMx2OsHzzE6XyO1774KrbrDc63hbql6eUl8JpQHpzmchNTb9X0N0+yNZ1LkHhzP63hikx420fjnEmuGLmuJ/cHLmPDLphDO3B+ucYp/LMJnlUfMAGXyglrVSW8U3gU7dm96hN3L1+C1IfLsOd6DziT/WvrrG9WnfCKsIkg9yxP8KihXpbqmCWwM6Hl1zmwl8Mc5LWikGrJRYrw9jkh5rMneWxPtifaYTh6hzzheCdxeQtv5lhscHaxUAUHmwGyUWxGEXeSoywqOdOZqEuTMfKUaIoCTUOxu0SV8Bgq5X5CJYo+IuiAF9+cWCQmozqUTVjHZHDzPXNW1hCdFKWoZpYgIaqIr1Pnhrhgu8Y0mbRTl0J9nidWjRMtseX+GG2sCWtKJAo91qwi4k7Nt/Y1wd9rOCjgG2O825ZsDANbm/sDbwMJY4ghhhhiiCGGGGKIIYYYYoh3FoXhfkWVlXtZOEVPiq/CJrjjjV+6yRHOYUxNOhEbCiQFBWIKH+5sdhYmRTA1HIp2KAvyjYHlaiP9I8v4Gnxu4WKQMSnl7BVzkw5fNuZyoSWx5k9WKt0V2rfl9Pqub/+1Aq6+7Buj1L5sO1MzMGdDXXGvOJxiCB13eqUgGElUcGqol3RLBBZ+195fJdegc5QvyXMJJeb8sl9KwivX56L7Hh+dIJ4foKlHPq7mqjTR3BtwJanYwEJSRJSDrUFh7eO95ZjTpWrKjot8EYrahPnArRWLmk0UhQOohC+hC/NgOsYHPvh+vPTC83j+lZcxPT5BMpq24rixa21YOiHRhWKQoRrcyfxZ2k9yS8lW5v1NaUmCqkJTljh98BB3X30Dd+89xMPzFc43lIWARozuUE7eNWgL72h/6PAAHYbDXYwSgY0RXnB0d8C9CwruNW69eoovfOoNFJsKN5+5pXL4elQ4i7RDOgS2tncm3HcjB5FTzNSOO6zmXf7YII4GbboVcV0IDQ7b4BJt1WdnufKvlgjgPLDyfQluzmIuilIs2dPzC5wtFnL4G/uWAtoIGedAQbGsaZvLGYKkE5U7x3FQd4NwbAI3nfh63B52wOeQJzyCiGU+1Ha1tUgFRnBHd2fq60hICOM5ExOg127ne9Dgg2uc79Xjyu69YIfKoGhm3HE7s9Wmxr2HK/29rI3DrLPTenfRNWLVRYPFtsDde48kPN6/90hrZjQ7UNPKbDrFeD5HlKR44YPvxYqiZQw8uHcPH//Ep/H5L74mUfrsfGGVABR7dXyuoGYRdmmEsVAGNjMMd0HchCEOuIa1bshbZtJMVREUidmwNELlInMr9XKOt/igGhsmZupEjGRrTLkzDntUY024PdcDcRN+7ThfrKGfXW+hKSQ+G2on7P1tIzmxkcl1jjHOxxiNMiVomBy0a8dGcvaTy2aUp0rqBTFTwv1mq8+FaL2VkF6XtndWuwbr7RbZOsJyUSMfJRjPErGhyVrm5wKaxJuB+izynyLg++cC9yEy+S8Wp/jM5z6N8XiCbXkD0+lIaJTZ/BBVucBquUKclijFIqYYbS7gMF9ZabLrf0Zw1HU/01Zchw2iCcVsril+JmY6r+2DR6jXtZzXeRohyalOk3nPfZqVJDusFyXKJNUeWG62hkyig19s7kzc7ovlOdabDRarcyw3S+R5ihGrgShI74h2Yi+CjVfQBFQUXd3+WcsPHW/Cq8+TmI+xPZMIpH3kzBBDDDHEEEMMMcQQQwwxxBBDvM3CMqN1Eoth2Tl6qYywlJqOO363Z6MnE33Nbcuv+mwGV7WOZb4GS7jrlpu8qxvEVYOkiORYpa5HlmmTsTyajEwTpoMbknqRHGDO3DQmZSQWL0XjlgEajv0SnkKuLQk5fA13bzn3uXOldTzL4AhmCB/BUvIWEurqlgusOmcvA4/6ojxducH5JzyEv1/EVkoUXiuU61OVYOdZgvE4k1NORkOOfc+xbMIyS+2J+yCSwlzQFJ5VGl81PmZ6p07ApKBGcSgcu4vK5hI00ZJOw9lkhJPDA7z8ykv44Afej6efew7jgyNEsYnYOqI9Y2NwJ7trUG5kt0Lzpr/3bkI/UFzmtS/QlAXOT89w7859PHx4hrPFGks2xPIy/fAWjzFTtu9uc7T3oFAeL36rFdFXSMSkfbSiS7HGG3cv8MYX7qvRVbXaomGzq5wsVrUVcwZyePHejOJ8CPZ1F1hpQTWmLkWhDmERDrxzBl96PXeTt6Jj62x1zm5IyAQPNcVhiXbG1OXcKAOb+uICj05PQU0yVAGw1D2tKNyWwka0YqQa7j3OndpziFsXvpYF3UvXtOfdNtYLTnKf32EV2poNjHFLvPTnjERzcdNtjFrMhtz0/ipyNAcsRueW5T7SnwTdJfIdgOKf/kqxj0gBctbXtp69aZoZxzthTZiC3Q6rosTDR+dK1Jw+PJOT9fjaGodFiWwaI59MkY8nmM0mKLYbJcfOHj60ygKJ0g9xdu9MuJ+CzdhksDX4bdhnaCoORQiUb4OwrOaTagRXSRyscks8BK683XZoKOTyDD2Xwfems1V7T2lJMPJ5xVbv8akrMeZ9EvjeYfgfHp4UWa0ZWztMdbkI77xeVqnwiiRRpv0+z8YYj3KJyqwisYsahGUmkqCEIz8TOB/zUY4tO9Mlie3vglJXKDPy8rmLNBLG822C9Zrc5RijcaIkHcXrKOYasP3Oki4hfcNJ0OgzgeIyX7Oqt9qvl0uiR8Z6/4ODOSbTGOPJTOsmGxPHVKISI9swFUzoZGxKqA2YY2KNVRP92ZIRWoMJXdlWIZCPPOkWZaqweXhxgU3BRqsQmoPCM+KtXMVNvTYXPVE1TArQob2tlCBgM1YyotOJoaeWayaLVlitl9hs2ZBzrjE0lzLxUDGKcmsNWn156jNCe5E12g1mfo6fTXdHFA2a8hBDDDHEEEMMMcQQQwwxxBDvtLCsUu5+37VO62v5oar6dkcX/8wv+9uK7i/ygemG9WZ5PW+rueBMmCbmIqlrbLalNapKybT0L+8sOZZobQIVBQqKGGKzyqnXNXqyRlBBtHU1xrEXJh6HJkn2WJcPW+REX0zm80PDQFNS7T3Ct/fQ9EquziCEuQpIoT2Mh+vvatm02VQoio2L18YWHWX8gl+Lx6r35fk5I1piAdEGEm2Cay40aAruzf3mbJ1oHURlVxU6koJfBOfVStTaIYtjHE4meOldL+Dmtat45l0v4tZzL2B+fNVEZRedwquYuOxj4wIynaRyWvYdy00nNOv3oXGfs4LZYGtJLvTZwpr2UYR3fIM5wp+kvVRIO3RCf3eyDmrR3RSYrZj+bFXgtbtnyGZTnD5YyL19MJoI4SK9ri8m73UM6/GWQ3NKnW9ohhV+112Px8feQfYE+n24iJjK4cZx8SSD5qYfF+cMBUAmFLZlg+W6wmpTSeAi2qFDsTivOjiP/TXtOLv51D9GE3x7uInesYSh0c/Wwd8RK9oz8zt2PYE6OL6V/tgBFbnKfmgm3IUXCM0098VnW732fmFdKL2hZIkdjwz7rqpZAsaEVIq4oXldf3Kp+R75v1WDBasnohgPT8+151y5cYbD4wMkeY6q2JiTnWsyz3F0/RryyRgvvudCfPTPf/41bEqy3de4c/ZAzlk6sIlnkYU1M4cwE24BlyPEj2NwjEFcoUgj4RxsDXPNWSKNzl3DX/QSXL0mc7rf9672Orggr70yTAFdq64tnHG3uzVl4r3NC2us2u2Z2qeJ6Ehz5NnI3cTe2U5rng3wMr2O0D5kSWf2WKIcuN7oDt5uWbEBlJmNUZaSOW3NSTnFqDvXlbG4E7qkfR9X0izxNUHn8I7n3SBPTRjejmPMZhHKIsJmWaIsY6zJKI4TjEY89inyfK3kQJamqCj27krt0zx2is41qyxCU0CxsIm96IBGoZklz6VWQtGSdGoWW5EQTzQH93NeW9vnk6RGntvzUu774l9vjSLecExi7MoITWHNGC8WF+o1UJaFkgLbcoOL5SMJy9uGx2zVPz2i+97/2+eE/WwRTy4uG+NjUJeHGGKIIYYYYoghhhhiiCGGeAeFZeMjG4/Xu2KZdsAyb3eV5nQAJkRUpHoIn7PYWtMwOtIkdLRinYvSu13roK1p2kWExXojoYBuXMNb2Jd8fpHPVY68k6NXjffEpkzVrEyN7WprKBh4wiYs85mmL/NXFDeMzRzKz+WFbJ3FDHN72RdwOuokcBIX4YKxMCByafv4uHJDUYQMTBMSTBi2cvwdmpKczhpnZ2ucn29drKFAHuFwxjJxYDIxZx8FFTJAKaQQGcKxE6OU5eR07YWSZpWqU9Bo7BjcqabGgwkdmcZB7imc9lPXz4Rbir28l87IaRLj+uEBPvo1H8Fzzz6ND37N1+KFl19BRCEsGe37wNtudO7eVul45U7FbetcNgE5IDH4GHMtR3x/4lBUAl7g0aNzvHH3IR6dL1SqX8hV2RcuLymAjwsX+FscRsuMNexDkJ6JxKiiRI38fugzd1Egwfu/8ADFtsb4cC62qUO4O0eyqXJ7omzL8Q3iuTvdOzm1h1EJ/O696MTd1k3fCvV7j3CR0NzKbOzHG4VlCfCcj2yG6QiF1brAw9OVhM1NSUcohTfOZ55XcJb3GuA5ZoZrS80z26Pr3Po6MqEbXPDn9esxlgMXXc/zdd4Jx/Y7Y5aH1W/BPYTCmJA6OkpLGLWs8n5SK7jSw9hKSLVXtGacJuqVdGgH4bvxpJfWCs9NrThbZ7mdc+e6lud+F2Nd1HhwtlDjtNdev4Nis8HR1UPM52O5SQ9ODhFnGdLpDEmW4cbzz2s+Hx8f4z3vfhk/+P0/LITP/QcP8Oj8IaqycuRMRM6POYUb7i9EKXR8dgqc7Ke22hRipBP7s8otoSTqjZrBOU6jRzCXAOyVGOYutnOpEyb1bMj5PvyzNQGlmMnklyE2Qg7OkBhddYkl/jhqxEuYBM0PDo4phWTu90RhzMZTw3RI3fdqBc1bT8N4Yi8b5WqOxz1rNB7LYV1stzrPpmAzV2CcpZgQ45Km2tNJcigLebqRp2Ok4GcNOwXyehLvYlzxusq8OSDF2wa7WYFdmWO5LLA4WwgJcUE8SQUcHBwjTeaYTktE0QpNQy71ufZkXk+OBwXd5fLcxHgm9+JYfH7hPzI2v3Pshz4DiNYmq92So1xTm8La3UZ0H+ds/McxL4UbmUxtvifEIEmh5t651nxYrWu52RdedaIEEd+AiZddg9X2Avcf2T6xqdgEdSe8Bz8jQtPRsGdqvxK+gyZ9ur69KkeJQX6u0SU9CMtDDDHEEEMMMcQQQwwxxBBDvIPCcis0iYO6X21O0YHCcutE9AfLaBsZl7MtwXeBLjiWw42hx/VuEtHcUWvYCjrM6AgNTGQTlEKJdhCezHHnDuQgbrXCid1njrOeAy+gL4Lo3Zb4u6DeNhEzcdKck3xZOniNIhvup6BDzaNFaTgOxBrr7VBsK2zWLFen45DCTIzYy67JIKVozUsTxxSmzdXo/m4XSO1mp+ViuEyrFC6Ix6D4zNexaxWONUiUfY9scJ/KOZemOBiNcHw4x40b13H9xk1MZ3PEFJVjTpWeMCoRO/B3vRnfZeSFH7O5ue1mAj3dqcHhTWwHUScl1ptCnGCKcRRxzInrRxrYw3tmyo6/G/7+WFKGHMX7jckC/XdT1Thdb/DwYoW79041wie3ryDLUuTTMMfs+a2Vtn2F8He7zj2AhK8NF4t9Dj6xdHNJVDa+cCwRiGbXMD+D21/rhEgZP3+uURNXDSHD8aTQKoEwiMntZdx3uQf9PMyPvT/1xvfy2t3jHD/mIuyde+9yPua3Pddl/7g61nPo8Neui16DzC5lFTogesLn8iXc79PZvkf/qRJlib3hmi1KbIutxOXteo1ys0G93bhwOrF1KJd7jNF0ht1xhZPr13D72dvIxhnu3LuDxWKJxXKja6KGp5UhgJggsIRb3zne2488eWRn57xp7bXOio8duxPoH9ro+FrdPiYR2SdTW43RSy6EPdUEdnPfBkZHSCsYhsPQIpJ4bXPpmnF6IzvuO+2At3sPhWhz8xuXOUOk6oUdmrhBGrFpZqKfbCDKx9AFPcoyjEdTpJkl6niMaTJClpLiHy6XJ8jkYs4QsQFnUiOlcJ3HmE2JAsmwnBpSQ6kMZTE5DhnSJEeeT4QPInedv0+j1DjHLdol8Igd2SLUBkV4X1dKftj+rsaLxPgQSeSoJ4n87ZTsroFEYEeUtK/L6ZACcb1DRHZS+HzVMVg1RPh7GAHL9fYayepYjYNtdJBehcDen3pJuCGGGGKIIYYYYoghhhhiiCGGeMcYyy5KULAUg9PFBEoTlZAMLHl2FcLRtmW0QxGZcBIa71kZtYti7g7kl2T+mU5fNovakrWc+JdyOlsjYEQ3Zp3Ixcb3yvMcWUYBLUZR7FCWfL67+WTr6wTHwAcVH5POMQoCzmemUEKHnbm4zLms79lNg7Iga9dQH6ZheUM/LzGnqEy+KL/AEzXgxArxPPm7LO9QGDxfNicstzXOzza4d2/hTjxzH49zNrdKkLxwHUkyAaIZ0myOOM5dSA4N8SrUdYGy2kp0j9PUxBI6qVOWfc90rtPJBJPRyB11FHopTO7Lb51blg7YCvODKV55+jbe9eIL+OjXfhTPPPccjq7dBJLxJbewi9wRx4eiStFxk1v0hbnogsBMJ6fOgxeJfGDeXxsCY71cY3GxwP2Hp3j93kOcLhYoxFcmvzQwdIPA25czw7H0+Cr937qI5lwFv8uPR40KGzzc1NhWayx3DY7/yvfj5vVjTOdjNE8XOLp5BbPjgFOhVhNep5M+7fUdxxBGNpSW6/Hm8hfIJbj1++7A9pgvS7ZhuC1ZILSBC7gV/2PTy4pIlRLbsrsxUREndIMWauC33la4uNhYk7Yd0QMBHWEibd+pGDAT/QiJkk7A3587/fE2prglmN60d3Qv+OatxZtn9q5W67YMb8m5oyaG2j+CSBZEsT53wxq6yTV/eapoGnFOWWOzOHP0Q8tA92ksnEIkvjd1xqRqcLFaq8Lh9NFDnM5HGOURNlemyCczZKOJOfqzkYTl8ZUrGB0f4uXpBEfXj3D/jbu4cjTDoweP8Mkf/rR4zWVdoBJ8N0K5JSuXx+ZVHUJdcD0nyLjM4hSVTR6MvETCmsbVelzO8aETWlUiQJ0S5RGh8OSM9la/tqy24C5niRvHZDARlsYYZaz82KFJua82KEoeT0DWwLjvWSr+vVaj3rfQPGyqAk2VIckmmOR5bxq7Qq9jtkRXlo0wSkd67awutEeNuc+hRhlxj9thnk1wOJtjNp/i5q0btgZ2dNYC8+kMkwmb75XaB1XRob2fjuqJ4c9BcbrBNItx9TDBalVgkl7onDYbVtGs1fwvwQwkdZCNvF2vsF6+YVUp6cSSfKm5e8XhTyiaR5iMY83X2dgwIOstkw6FBGSymZnE2ZSGf2LFAMXmbZGgoDjOz84m0xoRxgMNahTal0aZNb6Ns52c2EXFZrYEbFNsNgE6YnUPr2VKx7bNe4rwhp+pdV1V0cIbxWdWvHCr7TW6ZLIkkRDOKgFeSGuYOcQQQwwxxBBDDDHEEEMMMcQQ72jzPkZbnu681b1w15T+GASzYGrzh5j7y/if5gS2J1opd3Atm8tL4kdwzenLvTGGQ+M9lbZLBO7cYH2n35651kVGe154XNfsKwh+fYqByeatTi5B/LLj1V6k9772yzeHYwLYgFAu0qKWEM7+VTTFNnSU1im2ai5GF6/zmtvX2rXiWhCEAjNXKAA/BrqVMyFEyG4Ozc4eF3aiOmIKexTv8wzHR4c4OTrC4fEx5kdHSNlxqm0q1xOLghs5NOuTQmFuPxOy7fcqtQ4oiUvuRgnNPcdyURTCDtBp65CFN43hY08DX8I6+xiX9t6fIzYx22FVARebAvcfLST8nJ8ucHgwxeRojimVusTFtXZCda5Zez1zpPcRGDr29ljCeD3uKAL+IjCIu0fatQmv25/XnbBqDnBbM5xbHQu9YzAHx2TncAyH8yV81K1xd3/AL3OX33yfM4D3Hr9HU27XekhymMPaz/NNx/CY42tRJD4C7Xz0vcWF0tZ9vH9w3Yv7BuXLNxifu7HxqRT2KRvfWjgLNnjjjQ0nG4qozg8PZxszsYUEk/kMV65dlWJNcZTu20d3HiKqd1gsL7DeOBOYc8wrB0LLw/Z6x5evf9j/AtonImK49apqLYf+bf38hQ209gQ+kmKucgvaTxyB0TY4jVBHsTco7aZuEJPdEN2uAdP0bS23DTN7Yx56Ne6lJLyRqCpgwn90O/MmDIol+fSTf1cy0+FBMZ3LmRzBasoogbRLmumI3MFtDHuiRXYYjyngW7UIMRW2VigYZ0izqVjv5lxn0rBLptp1cCZ/YNwzOWEGdW86acKujsUTp4ZUcQyM5pCtSeO8Oy9mb3ra9W5fm48LLP03scAdHO57Rnetw7z2NeVu9n4VQluNw/NnQnhwLA8xxBBDDDHEEEMMMcQQQwzxTgvLsSSPrrkXO82roRzMMSd3lEpyWers/fKEmEhQJRGNhC2Dla+zRY06tjJzx6Cikntqh1VRyNF1sBlhRDcW+cI5G5NlmFaTVuqlsEAHJkUF9bbamZMsJ0fSG/jpSz0FCDq9XFhRBXdbNEz52Hir1vTPBD7qFNaMrSVWiu+pVwzKVkNhuzbHowvfapLk4k9dlxKO1ICvbrBebbBelbhYbHFxsZWTrKRBT00JiaIokKYP8PDBAvN8jHmW4uT4EEeTEUo2bCq22Gw3KJoKJa8HS8rpJqZrTcebYj6b62pRJM7ULMuPiYJSuHZyXhtHlqczS1lunuLF27fwDV//tbj93HM4vvkURodXENOFuRdBQO416GNzK4pr4X4em/5uLF8bM1pBHYcR/r6r0DQFVqsFFosLnF0scHqxxKrYoqHT3Y7W3rav0b7JPYvHipwUL8PjWhmzj2zQdY5RRglONw1+6HP3cffhCtevfxL33niID+12mIxypOMxsilnOueV+Vt3cvlRnHMBOfZyfBeRTPOx92uFp/4bt25cZ7gETdqsq3uieBBPoyBs0TmZpkjI7FWSocJ2vcV6scJiucLFco3NtpJbMUp24sVyPOls71y8nVi571zuYV/CMfuxUDjUmmlF206gbR/XE3q7C2YJjEBXFmvZRdDwvsE1bKPmZf6ehLI/diKzU2B6191uJUV04iVaJnZwrds7tA+WyO57kRzQhjyw/YEzls762vaBOEbNxA+TWKys4D6126mZHhu7sQkfsQqqeGhGNi/8ffPpDIe3EowPjzCezrFeLPHMc8/j9P4DfOqHP4YvfPZzKOoCq+UGTRyh5h4X0Wls78kj4QZNnnJON2tCV3GmY60y7r1kHUdi3GuMIyYRdkg2YV+15E1o4MmkyXTC6oMdsiqxRFfBfajGLI0xFe3G9i8lKSo64zkGto/Wft7mhLUGraM8N3TPjgmzjdjATbP1q01RN0Kae1UHE0Z1g8luIiay5oSz3ikmp0mGPKc4G+laPrq4wKYuEKUN8jzD4cFcTRHJGGd1S7ktsdls5Fwuthu75o5HonufzRJ5FNzX2SAxzenorbDZLrDdcB++wPxiiatX5rh17aaa4Z2fn8kFzcfXFasmOHDkl0dIMm8OydOl8Thu0DgSg272tLHPH5FNuFbpPKfj2dFOTJoRXcJKGznEJ1b9w/ljzTj5OWFbY8Q1opss9E5B6idLOM+dX6+9FojodvaqIu79nNv6HGOFThpehpVAQKQLa8duOI79NTXEEEMMMcQQQwwxxBBDDDHEEG+rsNw5cYNgExjDJgo51ticvX4/RQ9+uTXup7vx2oZQXorrzmfXeySGlE2NlJgIIjT4BZilySztTWM1czJmb6nHVkbYVOl6+E+uZAmprtXpe7MJy+bk6s5LZfsujnWNsOxvweEpKVz/4wsFofP/z95/AMuSZGmd+MmQKa547z5VXdVyumFgBjGDFgss6g+LBgMMrcwAQy4s7K4tBotcxKJZbMF2WUMZWhrC0FprLYcRPdPT3dX11JWZGSLzb993znH3yHtfV/V0FzSMn5qc+27eyAgPDw/Pjt/5/Du7ANlcwBpVYlaezM/PrDVwLl03SN+pYhkP+AqW99JDRTjM5Pz8Roaul8vza7m5vJHVolU4i8JUu0EGFLsz2K0AG17MMFqAfciM4KWpG8IFFvezV1R5xpN31WhTlnLUzOT+6bG88d7X5clrr8l8dSxlC4hvoCy59O71rEplJSFq1eHA2dW68UXlMiBX0MLSjIKFo/q+k67bcjk54Av6KdopRJkeIYo3JBjmTnnx9AwTJSrGa/TICD9wDaHO3Ax7eYq+H3byyU8+l3q/kw+9fIN2BShwpufrA8qxsKkGzQFXhZI+Vqy9QdV5AGEncUfrJ0pz92p27udKVisKud9TSQvvboyvbQeoZnYcGCNMLphCNCgWncHd9hZPj3no2qG2LL4MIEhhJ82/BfiT89P9uRY5qskP9MrWRu27uLtEEX6olMZqB79X7aKHWz2cm/3OmyeBz/GiJq999I/HfgGZzY6E9x8gNgAnALQplic2JlDCVjUtfJp2LvOmlX69kX3Xy8XZPbk+fyEXz5/J5fWVXK1vCCN7FCoF2KanslmEeEE9Lzpqvr81DX5thrL2EowDUKIYIMxXaMMTV1FA/Yv509tJwD8OBKB1MZNGKxyqJZAdE/Oqg3yF8TomtF9gDQHvY/WZ3+1gbbOX7WxQ9TF84uk3VEtRztT+YxjU4mMEIIddQ6mJGE/6oTBqpasY4A0O2e71VSFD23Au3FelgWUkEQYZOrOF6QYrDqtzb9fD7oL+GOHe5P5LXLdOOsw5W7x6KWaVHC3v8VrCpkNVxkgManIBHQHbCbVLwjHM09hedLlBUUNfFZMMdk8qEQPTr9tsoZg/goc+BhSvltqZmC+22o/Ely0AmWbYQpIu2gWpol2tOzRhavca5g0TxlNNzcKfuh/apGSynCNHjhw5cnxex7/+pd9LTk5O/ks3I0eOHDly5Pi6g2UosQhmXKGJAHyAiq6trWCU20E4ZHOkBq2vfg7AQpVTru7Ulz50K4gFYARI3W472ZbqIQwvZYBQKLHGPRRpeCCGJyf2BRmZPVRz2bQ+XJsgN8AsfTKPthHOCQk2wpJuhydaSImfNBsBLNP2z8RCfQqYWazPlMt4uS0IQXtRcskxvEBZUA1quwGA2PW86vcL3nC53kg/dPLxT74lbYX3Bnn86IygEEpBnD/AcdO0VIODKQB8Lppa+mEnm60ucwfoAAyBf/Qcx6f9iCEP4xBQKKKI2Bd9+APyDT/4PvlG3+gL5Rt9s28uJ/cfSDNfGlSOIF0BB1SynSI2+qF6IUEv3gePZYcdrm7WzxGEUIGHF3xZ4RPcy3azkfV6w+J9tMLAOJiYRKQuBgcwcBLBoTfATFcqa7EsXebvf+AWVgwScGWzE7noOvlPX/MJOb84l4fvuSeLRSn3nzyUJ01FUFhBPkrVsnreqirZi2epzYA3UfW2Y8LB49L32HLf2KtPamEugkBmaxQCaTIlAbEGSwG/hr6Xi4tLef78pbx4cSEvXl7KuhvoJdujABhUm8VeStx/uNdgdeCAOF06n/x06Oo2Mvo3/bfb09xlmKG15xJQFTi5OyjfFREYpxYbodDnxO4iwu3U5iUq0xWmGV9MIgL/kMiCYpMjFDOIF0Mzuw1TbcMfdyh2cn2zFqyZuLlZU+0KgI8khNpYJBCQvhJ+TE1uoXHlfCVStfL4C75ATp+8JvuylofveUM++YmPy5d92X+Um81aPvXyBT2zPQFAZokEkU4u2tc7tWsAsMZcAO/dAv0BeDnT2QQeuwCTVanJq/Ta1pVaTsBPGfPeBnYwfSGLeUvfYsydKJSHuaTfqUIZL/SDFlA1wF2qLRFWOsAXeN6UUteazKK1zzjSox5z4bCHL/6MVjdQGg/dKJubjZRFRasdzKFbwGGcBwYo2m9zOVZnfOrFCwLxbtzJvG3oHd8YIGcR0HDHA8xa6mpU64yQB+VKlorS3Xm1kLEqZHtzI8/felPun7Qi+9fYJw/PHuu1ePZU1psNofNqfiylFerDXFFV+j0GBTn6AisDpOilLHv692uhxJHJVE264VqU9NBH/3shREB1XJp+nGn/op4A7/WYEHPrE09w6leSJ2itIK193xBSh4SvzTaW6MF3BwcWrTg0Ecu5yr7jcuTIkSNHjhw5cuTIkSNHjncXLJtNBMFygpIAIuq64YMynHEd2lI1ZZvic/y3lZ8zbkZATDUg/60KQC6D7lRFB8jS10MAnHjQxgM6WB0KJAE0YBn4ONtLBcUdC+np0mfaT/DJOqE0VG/FomXqp+kewDIBWlSa2UM8PI35r4l601RtADtQ7tnDffpylSmX12Opu/m0AipzqTnw7ExhJP6GR/3rNZSnIp96+lwq6Qh71jdbgiMuK0fxwRqAExAIy9n3hBXzFr+PhBsE61DQlQo/Wm4XixgS4IjIvCpkURXy4fe9Id/hW39L+cCHPyIf/kbfmAXJimYxtQ8I3sgDC3ZRr1goNFaVsquVVbFM8My+NTVnUDUrcAYgQ8FAqA6RSNhutkwkbLYoyBXB8jQO37kLLh94uZoaMrznli2+B7fEwPJ5AJdukI9+4i158byUD7z/oZydzGVfzOTsyX2pUCyxhMoQ0EotCcyswTHggWeHSQRd7TyB3Yfn4dYP2DYtpsUKabIHcArl7WJihHYG/SDXV9dy/vJczs+v5OLymgUwtz18ynmXWuJHfVXNjtbArbZtz8TM7NPCXk/Q+P0bm59C33hObsEwFWE7KE7fU+B/Wx19+5rHYoLxesYWm+kF1a/62Vtey86ALWGko1H7Nty3yWnh3ofw9Wa9Qck7WW+2Oi9BIYt7mLJPS2TEjIJ1n+0QyS4k31qRB0jY7EZZLI/kPW+8V77iy75MbtY38uIlkgIvpO8HS95B0FpRoUvFeciZ4C5Wux8dAICRI+E451ZY6tCIGAmmUkoI193XGqtIoDDG+1Th45YdpAeobhtCW6iPm7YhTL7pe9qtyLaXYbZTGMokGuYhANbCYGklbVPxdx/LSKJhFQKOPqCgYjGjjzrm+O2+o5oZdi5Nu+AnUPSO3vFsu4JgXEMA/KuLCx4TxVyRVEMBQybMmkaWcxQoraSC17VZlejKF9yjwQBHin0h1ayVfTFKW81lKEW26428HJ7J+vpMZvuR53Hv9KE0zbV84pNv0aIExQMBslV5vONP2EqwLwCWaZeiKnO0sSxg6zFybtQVFgqWixLe91B3KyjmbEFobOMMCQwkNC2HQGsYKxaqYNkTe3rZvD5B9OFO6wXE+4pXxPyx/XbSYqIOqhNT9xw5cuTIkSNHjhw5cuTIkeNd81g2xXKEZRr26GzaVIUrUKLR05ceyLoxHrrxfKwqXf8kPuNqK18KHekO4CuWNAO9qlLSABvBEh69FZgCSpeoWGXKLlXN6bJlqD9xbHqWunUH3rMiTHgjFDsbtWgTHuRVpWwtTXiRKqDNTzkomHWZsyqXDTQGwJwU3TJLAgqsoTBkf5jy2TYfjKgBML+8WMuL8xt58eJSSqoBC6kBUUp4K0OFrMvP9/QyRXvU55dFtKw7oLIGLCK46HWJNcAV1IavPzqTB8cr+YIv+KB88MMfkUevvyFlu5JZ1RJoqL+tW1tEQEy4xcpf7p9sqJYeyuYJajCaSmX6MLtiWf1r4T/tFhib9Vpu1msCZi9ApxHtLlI7hYkLQmrFEF0zbqtjQ9HFW0TTfLaR9NBrtx4wfnbs+7feeinHD06l33bsk2qOc3dMkyQTtKyY8cTUVMW8l3mdoCSPSlw9vjl4m1WIn6lDJSOgVGaq9QruQwPESFL0WvjQkxVQmnZQm47ms5uc9MQy2VXEQaHsIzDxWw5FChPPEVNJOhA7dKQISuGQhUkK7E3sTdLLEqXEvD+8JmZxoFz2IZHA5FRZ7b7ufm4KsX1eiGr2yb08bX64vwn9QqJJzMZmRmU9ID5U9uoZ7PfHtJfCUgnzNQ/pB17SmTRHx3K0E3m83cpHrq7k/MVzHvfq6krefPpUrnE/7PeyBbgmLVWlq1Z+29GWgYX7VI9uY8KbEkr5JVZF7Ezp4V9fwXt9HuZCL4Sq/av2GZgfmlrHM85srHZSYfwb8KWncFXIajFXGN1gTjLLFfr8YkWC7pOrJ9AVdSH7EnPATPYj3oMNhrF3etMjGakAnKtIdlDywuRHLSLWANxQ4AMiw+8YxfZgZVHupMZqDFqIqEqc/tJQpPv50Qe6Yj8sl0dSFA39ma+vruTy8kLOz89ptTQH6Ma8WdRSwJx41LaiD3GeWjRRExelJWyoWsZKGq4mQc8P0lQKonXFwI79BhsTqMznKK7K+dzgPu1GSO0n45DKYhTe9PnVV974mA9JEFPX21wDn29cQ/1eMZUy51wMZJ2OUdAQ11O9nTG/HNwIOXLkyJEjR44cOXLkyJEjx+dWsYxiQEQMk/eBo3oqVF35pw/A9HgEQKDCTpWzBMIAy3u1fuDWLKCmvp5VhQd/Xf5L5XI/yGbWSTmDCkwBExRvCqtVnYzCR3h4L7mMGg/ICk3hN4zCW3xeHwc+s+M9FlUqAQngTaytdqgsM11mzodyLjs3zwguMVYgTjMNFK6qa/5JOTBgKUAIDkJnaf7NIYur71RJjH9YgUB4KxuYckCA9uOwL642Mmw7WR29lI997Kmsjuby8LV7UsIGo2qkrhp+Yhg7qfcKmqGW88JWgMcoHEiQ0TT0IN13nRZsgmKyKuWLP/h++dD73yPf5lt9S/mW3/7bS704kmZ1jwpLhZisfEhYbFpA4tcZXryu7itrRAvQmNu7Ryh+30awjBEyjHwNPVTKa1nf3Mj5+QXVtlCFbjrAI/MsTiP9NYGJkz9PFLfT7YMCOrBN8/8mqBlUPY9rvJ/JeYel+Tv5+Jsv5aSp5OjesWxu1vxMewIJaGJpoeTOYLNWqIQ1gaq29V6BUhvXlOdUYHwESXsA9wp2bYm7FZUDNtSbzKG+KewBqeG1DfXndivrrfpTb2F/0Q+y7kazMTC8GLpFoVQqrI9d5J7KrlKPCZ4AYN0vOlFTu420/yQgNEsMt5AJRQ3NnzdFsJ7Uof2DKSsV7Jog1/YR23Rg1xEK+9ENJ1hPOLRzP2gHzQqoXaWd8O80KYFVD7hO1teYGjabjurei4srKovv3T9msgFwWee1uDMzSLFxovNmHHaYA0SWDx7J8uyRLB88lMevv1cunz+T1x48kPPnz+Wf/rN/Jm9+6k15dn0j6/WaRTpnY2l8HP68KP6mq0gIKzGWOUSgjo15HfVoVvCpYHknm/WNKoxRBNC8lIu0P02Ji5/LeSNDNcq80TmJNjwl5k2Rslb189FySc9kzLmcF5H0qmquPtjvYGNEB3xN+NUKnuExj5cDYAXLCm0VKseVJFA+jwDQ4ygX12uds4+POUchObXdDlQDV4152/O6e8FH9A+ScV5tr2Z/3LtX8bp99Vd/VJ4/fyZP3zqRT37ik3Jyeiyvv+8Naeq5NOVcin0lkDbvO5GigXd9rUlLg/rYN+/lspGCa2e2PGZRDtKNIhUtRESG2cgitFhRs6hbOVosuQ/443M1jA1ctXHR7x4UhcXKEqxwgW+1336ThQycgYO5i06z3ADttLHPvMpevahh3YI5YcT1m8m+1oTjDir5V82dOXLkyJEjR44cOXLkyJEjx+cCLLuCKqoUFQiwgJ4pfeHxqRangLUFC82hsJJuqg/NLEbF5+JCFVm2/xR2qSJMtyIowQ5oBaD+kMo/XDmpEEofus27EnyBINiUhAa8AuhNXi7ATlV+6RJq/uTxYuFBP25QR7ItBuySgnkEjiye5UWx7LzYNijRDG4Z1FGookeGMm8728n1uqNnLvry5P5KikpVkAA4gEUJMg1F3Khg5U9ACSswBlU3fEHrSk5PTuRouZDXX39d3njf++T+g4dmfzE3yampRt2oJPVLBmSmpYUplOkjmhTrmxSdcjsM3YcCVrRHLTCGrqOCcrvdyGa7peo2qPPYsXFNtxbCSwTKUy+GaNcQ1OUOE++y1LDdp2PZvX0NbI27mWy6Qa5vtvR+hl9s1cJf2qW1PmoTWwZ36ubfzAfYFbsJ2VZFsjfUwetUQ6uf08FJQEtvVLVmAKxzReYuealNhFsB2P2a9sDEguLVIGlaLC9t1u2eVDuZaR9Od5bsLwDo/S37i0M/5dA+605XDt9uwEHGwIDbpAihK60ThXN6x3OlhCl2TRd+x2kblGbfawJHAb/6X/P+DZ8wP+87G5uqrmdSNXNZHJ0wsfXgtfdIPZ/Lkzc/KTPYSjx7LruXuNeFRfkY/JmsGnF7kcQmROdIg/Vh4YR2lI4j9W4HbAUgBWjm+YdrpSpfglMcmys4RJqq5guMGn7DWrgPKyhQcA9gGMAbBftqeg23zWAe9Jpgce/7AonCmYFlv78tnzGigCFsc0yp29cA2w2TZSj0yW2ozB+DlRBnQapuzb5IqwsGWEv1OYrCjkiQmZd2sqIERQIvrq6YLKRn/ayURbugshl9DaUviwX2OEYpFWxNbF7X3UDZDasQJPcqfv8g8TebjdKg/iDOF/O1GonwJxOsdq34OQJ7rKRRexsmZO07hryc32v2ZWF2VNgO3unhHkkSNfHe02vuNho6b2gfjkjEFraiIiuWc+TIkSNHjhw5cuTIkSPHuwmW4W+pSmMo+LCsP4Fha/OGxN64RBjCuh3B4QAgV2D5b0u1WGOP4gCntIHAAz+X+qsPM7cpATC0yBSOACCgimL1pWTD7cEe2+HDDVTIAKeFyJwF1ma6VB3WHFiKjaW/XL6tNhgAIOAP8DYmlICNBFWPICm2PBm2A8EoARYfDq9gaaD2E1g+rSBGwTKKURXctz7ta+EtqKsBkRqpWz0PtG3cF3wh1LJA+wRHhep1HGby8bfOpf63Xy5nZyfSLmpZLOcyk1rm7TGVgW7FgXNgETcs636Jpd1Xcn5+LZeX11zu3RSFLMpSzk6P5Nt98y+Vx48eynf+nt9dPvjhL5B7T16X6uiB2pcEX2VVKM9oZQGgjJ+AUYPIrjMVsgFml2RSpmoQgwX4zPoCQHnsZLcbZMBydhQ/wzL0iwu5PL+UZ0+fydOnL7j8H77ZsEuhj/FEVnsAPw7hslJDK3znkDsBqLwcWmAyvGGQBu3U0VSZJ6y6nD672MqyuJDXn17Kxcsr2c8KOUaCg0W8HHqjH3BN4SOr1gGAXVRnG8QLBQNBdAjHpjYXSlv3EysJQiBTw+5H/NyxEBuU1SgAqUvbVZ0O6AVVIxSJuK9gi6uKZfcTp+HMxGvaia5DqKBUnnRo8LWwLnUTm2nHpypk28TUytPkjFqcqCqaiRhQy+RvCpxVLQsYFxXEZvOQWqQcNpXD0FNEmhzAHxyuq8Jexwb2w+u3g0XCnp7CsJjpB6jlB1ofeLvopGGKdFdBe98jgYN7jip8QMuQb2AmK/ph34Hk3S+7mq+krI9kfvpQjs/uS7e+lvvveSwvnz2V//Dv/r185Zd/hVxeXsqzZ28FmIi272uMQ7g+w3hBLXUGeGpbok9vTZjk6Jxq+SzO37NhL5uba4LhJaws5i3HF+EtPJ1N4T1vagP6Og7mzZxe7m7dgvlv2cLjuKTvO/qH13RWSVXspK6WBt1hMmNe57QM12KsPB9cZ04Z+p2COQL/RkIMc+S8bohjYYkEywqMfYBgjPeqrAlwAd2rorfVIbrqxKEyrivNKfaD9PvekoCljmcUIJy3cnGzli//yo/Ja+tO3njf+6lMfu3J63J8dCpvvvlJefHyhYzbUcb1IIvlQu7dOyJIR/IOgFYHuhbLq0sAZnyfzHkt6qpnWxvZsfhjsW9kttP+7vot5x6cQ43CoKWuzNGxBVsQeEqrN3ozYJ9qaaF5lB3HH4ZkhW2wPgJtwXW0669zh750jrYxMuyk3M0EFvm4P1Cs8ZWrPXLkyJEjR44cOXLkyJEjR47PBVgGgKVu1ZSRsXidQmdiKDzMG6yDggvvA3JAU4bANvD8VR9afQd2Aa5GNK2fgpBQ2Mi8iAmeWD1Q4SGV0YpB00XujnBUMaiqT7e3IExK1JqTR+lbgsvkl3TJvAGhqCYDTAniN7O9sApMibJZt7W/uX+mKQtTvSrQLjV+exSrEtobXFzdSN1Ucn2zUbWzHYdNg6UIYUKEXYAwgB4E6+afyWXr84WcHh/Ja689ltceP5FHT16Ts0dPpD06kRnW1k+0mmpnoapZUyzbv2nx4FYXoXMS090D9TKVwVTKqb8ygCkUyw7moAZGQTQCkVStPLk4yWXh8V4hsUtkrfs74PLhLt0NIxGTqhISBcWGndxse1lvetmsO2mXPRMcUWF8IBxNRkg0nJiOJ1emHxhFWxsO9nYgNHZFalDN4rpTPZv821SQPraSntC70MflRNV7qxPvVB+nxfRUGHo3jErE2AfvuxIZbVOV/+FxVPlrkNhWG7yjmE3V4N7WVEzt40Hfc1/r/aQA2rQP0gtw0BAbA/E/b0a48cN+FCInIzbsViHnjCpfFAC9J/W8lbPHT3i/I+Fy+fKc6uHryytC12LWh706NFe1cvLTV3lEUa7NQX7t1Jvd38fc4IktTTBYoVRa4ui1wnt1VUoNCyCzMuI8bQk2zINQ3WrlPVbfgyFDAqZNWUtlNJIqtm9L2rnyG8AZyZhihnPVMbpAgb5iJpt1Fb4rOP4L9IEmiOjJbNfeVeY0PbHtOP8QzuJa25xl7cf31NUNfN63XPFRFqW0DeDvnopsV3n3UG1jUia0xjl5IT7/N17aHxWST1ZUVsqdzADMeapFSFj6NYPNFLodvu5QIDNPRU9+fGfA5glJTdhsYy5Xiyl+lYXFE3G+MEei0E/msBNe/nfOI2YvlSNHjhw5cuTIkSNHjhw5crzrYPlys+az6QZKZDWMVSuMfaGAF36+hnlZNIzAC0oxfTDGQz6g8ryqqUBbyyg9Id0g2wEQYS+N+YHSF5NLnPWBGWq8rXvPmmpRhcIA2wr6WOMKSmD4znYDH9ShbsPxq6amwozKRwdK8H/mPmydsSPqUHBLoYUypQiRUbSPPp8GLGEtQQhuH6d3phvEWpE/qKOpeLYl49gnVN/ox36HYlPqD6rOHApGoELE+V93ozy9WMt23Ev7Hz4qq9Vczh6dyWK1kLpeE0AAzKKvr643VDdeX19K328Joo7qWk6OTuT1157IF3/hN5Qnr70m3/m7fVc5e/hQnrz/w7K4d1+KqrajsmUGg3tT2HaqUqZiuY+KZf4N2zqdRZ/Y341eAM8G2E+gPEq/6ehXDK/Xq4tLubq8YgE/LM0nVqaQNFqkRBoXPBHC7wm7uw1qJ3kBRXsB6qaYzxW7dj3dtgQtv9h09KP++Cefy0f/08fk4fVG7j9+SDH3fNWyAJe3JRRN9LYGlKuJB/U2dnsQjq6k7fiHqg2hiuZYwKfhJ2BwVQtXQv04SNd1LHZ4fX3N/sMy/uvrjazXnWxgL2LjHsmYAPTwH8bmXXYVoSHeR2YXMFlafxuhp6DUEyS+/6DSPgDsKcxVuByudPL3RB1tP1miDqDT7GPS0Pta28F+Sov1mUrbzs6Fm3p2UCfvCyZiCEvNMsaTNSEpYcXkWPyMqnIt5Na2jfrulgCDJgmedtG0nQaTNRRGRuMNSE8XUhS1PPzAR+Tee96QxfGZfOjD31De/OQn5cv/43+S66sL+fjHvkK67YbXeBw66VHEE8cXeNTDd30nu26rNf5oK6R+9yVWamBerQxKGmWsMIfBp7nFfFKrL7MJWMtGE2RQuOL823Yu8zmK25n9kSnd1avcJ0D1DsZ+Z43bb1SakDIrDE0HqrrfYbf3wjD0wWpELS86mbeNdH0ndTGTLYp9btXmpyxrzqkYFxUKxvo0ZJ2tAnbMO9oZox0L31n0yq8LWZ0ccVXN8xcvpWka+dSn3pKjo5XMFy1fT58vpW4aWvbc3Kz5XXRyeUOVO1ar0C6khsc9FMOaUqIynwp2+Nlj9c5IGw2osJnkK/YcM+38SNXwxaAFGfGCsrmAcryifceKCVrsr5NuNhAu78pR6kqkrfC9QUGzfV/h/+N8ZzJgUQmTvIVO2WNqmaPjDyss6G1t1yVHjhw5cuTIkSNHjhw5cuR418DydlClHJfjB4SklIaWFgTLChjo4ek+sIkYEw/BeBAHYC6xbN9AEtRvDBaZUsdKqOEMUali2dSagD8OIYh/g/IQUEl9K315MttgFEqVwgqho7DWIIeBKFci86izKRxTSGNFsMISd8AafTCnUC/IAxOlMh/a8RCP5fcGpV19yvNSuwKixITdGeKVfreT9bbne289O5fr9UbaFZbP1wSMXafF4ABZtpsNIS1BMywX9jtpq4ow/9H9+/LhD3xQnrz+unzoG3xETs8eSHP8QMpmcaC9JY2wonL6kwX5zAaD/soBHnshR9OL81xMzUzFcyKX47JsXW6vKuVOOvgqdx0hu3qxKviLkSparaCiA+awyW2TgfSjU6icFJs7kAPzWpmKkbBvNpPNMMrVXouGvXh+Ie1iLt22l9q9lv3Y5s+rMDRpHPep4y6aCXvqAP2T0m/ziWXSQhtJ9SgBman8TYGPe4B92Peyhdp72wVfbaq+rVhgUJsaUD/EwxMUn6iDg4+1dVBUNVq7Jxv6ccLWt/XbdxgOO6xOujDZpdmp+LHNdsGtFG5fdlfjxn16oUBe++BBG69L2Ddhm/ap2pZEFW1omClnvR9on85CdVDpqs1NyK28atXD5J3pGot4HPgXV7I8vS+yP5F6Vsr94xOZL1ayhdf6i2fy4vmbbOd2MzCJhvUgenAF77qQIRZN9bZpm7XwHN+HYtmuJYAzAPlisbSxo7YSUOtqUmSQ2W5nnsq1KaWtj5L8gfazwXzYTFhCjVYMmGlTsMyVK+6PH+d6HMMV+LgfqgHJm1GaqpJuu5aqxLlteWgqpb1ooIF9Xe/hA9rHkl1XS0DAr56jtJgREPd9J+vNRpM1V9fc19HxUi1SLCm53WL1SC9Vpasr2G8lfJTNson7RjIU7dVVGrAz4v2PMQbbCRTyo2WNSEWffCQlfBzY3ImUlheQne0FKb9it5O6RD8CO4+y5/lqwkAV036fqDURkqYoGoj5xdX7wWs5CPBtfVBYUXNrqObIkSNHjhw5cuTIkSNHjhyfO7DcGfzVR1/DsPZwiudVYIK6qPlQvjHV2Qwepnix5ptBsdlAaDfAqxm/0yNUqRbeB5ygyi4s/YcKeScbeEPuTfU1E1lBPTyjKQN9Z1ESaTcrpd/P5KYfFeAS/gAAqNJY7QLgh6oA2gGwqhoNbrja2NV4LAjoD/C2jatTWXjJrS+suBwAg6ECVYQpTMbDvS8ldxilQl7tTTbHARD7Wf1fRylknJWyHUWen9+wmF+9eCrX6xvZ7c6oLtQl2w1Vv02JwlONvPHwTFZ1KW88fk1ee/RI3vu+98s3+9JvIcf37svxgydSL5ZUzenVNJ/fBChDzenWF+F9+p/Cg9bUzK4qdRuCCXQ2G41dr56pKNhnQLTbACoDimtRPMJQguVU4JqSw7ttMUISYFLR79PYZCTiamvxwTU3yIMl4rKXLYX5e3l2fiVf+7E3ZVaWcnNxQ2/XxWoUqa1tgaVGdbB6amA8m5fyFOMGlW9AalqZT/+cKN65N6gyZwB1NfuqqkYFaqXDzSooZhWAauFAJnYSsBu6yUCeCqmjbUTaPlo0uGWMWdLEa6H7DHAx+Xz0aj6o2heAcDxO+Ld772pHhGuTFhHUucBVzLfHhCvR3dokKqOtS90Kx0Gm2YVg4y2SGyEhZIrTYCXh1h2aCAowNQGqnlAKY8sTDWEP0/EVX24RhFksUf2KKo/r1T2Zla08quYymy/l8uULaZeNXF2cy9d89VfKxcVLznVrmGrDOseP5eeIc94BQFrh0pgTC3UxrUQqwe0C9g9M5HFgu+QXrhW0k9B5CzO9+2FjyCrMLLCRec1jp1x54GdqyTQtoKo+7ngFiw7bhgkyAnK3e4GyV8f2MCIBBbX+VqriUtZQ61oRPIz/tlmxbVjtoPlOK6jIOUiTOaroFRnsb/guQPJxbCoZ5jWTZk+ffkq6bi0PHt+Xpq35qvEaeimbEnpgefnyJRXriyWgs65WQfv7bSeb67Uqg5moEOnGUj3jgZyxWgdF/Up44xeyG3VFAr6GeS/ZShm0n9b8UM6P8IXeSYP+KQd6UrMQa4laBCrL1voA6qs8MKEI+44ieCyznsFgamX0Mr6PuGpGk7Bqi5LJco4cOXLkyJEjR44cOXLkeBfBMgrsmQ6VuMBNAIgj+cAOC4uKfp09i4gZWA6F3cyTeTSwjCJTAMVY6muF79TKU9W7KtxUNTRUu1ssJ96NtM0AUG5rVb8BKqvWC3imlNmukDWWSUONNq/MukIBL9ld4kVLqMHCftOiZsGd1IAyC/M5MLSTptdupSo8vPaFKW4dih36KwMWVCggiCXcqnLjsnWCZdspVIIkPVp0jMvvcW4A3ONeXl4CahZSt6Ws1zcyn1eyWqEcYiFti+J4I/t/0dTyngf35P6ikS/+4m8sX/iF30gev/Fe+cg3+RKpmrkU85X6oLpvsqcLCIYBkK3IXFAgu+0F1MsAy7DC8EX8qgRUI2W2NhbwIxjpNKEAewb6Kg+EyXhhCTqLnwHCeOE/6/rIhw+Ah9lKRLBnUHICl5PBechLHEYnu5+oSE15iau5xTVAv1/dyCc+8VTmi7msL9cyb+fJdfMDRMV3UNUSKltb2TexYRN3iXgTmfLQxppX8iPk3VM5WY1Ydj+EcUcrAE9UhH27z3J4J3qOh6TJYdek6mGDfSmsTVTCqQ+vlZ0MCnAHy0wy3BF3AyxTaFvb3VZjcinZN6ki3FYZTMXR8RoSQruKXrdzwOv2FpifcCwmN2DOUELRC4w3BeKpWjyeZ2KjHKCyK97TBqkCXmfMQ7DsmmVXsfusqrYO1aKSanEkzfGpnD5+LDcXsGso5PLlc+m7ra5yuLiSobth/yiuDTg3qpStfeG82VSb0zC/sehcKXMDyyg1B1AJWwr2H1TCBMO6LMNBsI4zvW+RIkR7AD4xX2qCzRXeNgdCLT3zVrrtkM35VhgRfsNMslgiYb+vpKwrGZnMgiofxe8GqZCt5AoV9YhetEst+Ok1MaG8ttUWDnlRKBX9VOFbB98944yweleXMrRY+THK8+fPZBg6LarXGFRuK6n6itY0+NzF+QVvzc22kqouWPwO9+bmZiNXFzd6L/PmQzXZJTxtOG8XVCirKh39sRs96QR7JPQF+le/Uwjx6WO94fiATRQsRvA9gzbQLsrAMlaToL0ocsuxtitlPwIs72TX7/ldCxsXtcbRJAO2I1gmYIe3cwbLOXLkyJEjR44cOXLkyJHjXQTLbqtM1gXoAcUUgKrRFWIBK9LmhcXcE5mgDiCZKlAghZn0BMsu0rQluqYEpjIZ0MA0kYDUbpfg4I6gdiZUclVcFq3Lh+m7bCpggKIKPpim6gRoCEAoLeblkNlUp75EHKiB0IPel7q8GSSb/qMAe0FhZpTQliVLsM3wlxph0m85fM4f5A0q2dJ1FqPiYVRtiOXu3aDloZzRX697gpkXz6+oNDw+GqQqACxKuX//HuHJvXv3CFQ+8OGPyOP3vl9OHzySopnLrIIaMSGFZncRrC926jcbVMkAOPTBtQJYiYGpwjsDY/yD78eL/unCdC/2GGilLd33omC4PqrmrqStK+nccsPsL6yxYbk8oFVwzlWKdPegjXmAoIh19phCay+25ypZ+IU7IETCYjPs5XzTy+W6pyVBv+1p62FS/VioL9JYA5i47oGb2rlEbGly6cQHIt5gYTm/yi9jgUuzV6GvbAVFZStt2xOCAXDp0vqUHEe9bLzTpgxePZzjvx1MYwl+2mepJnyyr2SfuOJB5Xtw7aLC3D8QVcGp//n0+pkNCC/1oZI4LRSIFQKpmt2V0/4ztUFxVbO1NNiEIKETx0JIQfh9nBTepNp1ALREUTYFmLSf4D2TZkYiwYdK+hbO92s08aROe1e9cKF2bRYrOXvymsyXS3nv02eyWKxk9dZbsnj+XNb9Vs43N7LtS1o19Jiz7PaEF3iBV1gRAfWxF0mNNhR7fAZ+zJwr4XiO87JEi6m49/BRt7mV929wvkESaaDimN8GUDCzmN2EwE8vbfSrCep9HSdqpxTvS4PPZSUloGpVqyUHj71PCt3NZEYvfZ1HWSQPRQJ5bZCE1IKBO4Bc+mjrdwIK7NHtGvY3242U61Kur29kuVzT6uh4dSS7fidX5Q3b1MwqsxXRMdn1ALiDdPB+HpMKeQiYHRMGF1LQ23knuxkSn4U0uzKuhiFQtu8wKsQBzXf0h0ZyC6BYi6BGyx74J+tndfVAuW94zjOeM17uT65QGsUB0Q/k9vh+xS447eu4z5EjR44cOXLkyJEjR44cOd41sExeCLsLWkvMZFEVMocfJjwrAS8AcgCVqUpG4aURa20JYvHeQD6502XIM3gHF6oyJlBQQkQ1MWw3DAxjv9hmO6CIEx2ZDTXAoxJ+zTOpTYmIB3EU7wOE6IZBKhazgpK30uJ9JR66RxlMocxDEpgbwCkVIAFcqGpaFXpBcwggQwWdCssqQA6sV+ayb7zp0NuXfSuIAkCoUHRM9lKXlTQV1GFRYUrFHxvjvytIAc9ldUEWbYNybSa7qpJ+B0uMrdTXPaHD9cW1PDw7JVw/PjmVD36DD8hieSQnp/elmS/k7NETOX3wWIqqkbKBUjkBVw6UAZGhUAZlGFWluAdgpp0HSavBGYMmVngt2GjQa9kANWCUKRnVEUN9rqNCWK0dAFABQUBRAJRRoGvZ1LJsGxYqLPeA51biLgH2yYj8NKM1BdK+rSmbozRdDrGr+3G7MQo0mMBbl/1ePnmxlZOLjVxfrmW12BAkqXerFXvkAFHVpYI29BtUh1Ywzy1GrHhfKCJI5bgVkNRqktNmuwVEsKzAeQAqt9I0o6xWKy51R6Gxdo6CYmXCuKf74/mlBQxNxZqC3gBqoX40aEULgwSQJicZtb3BRhnXOyZtFKSn6uJEZWyJHPc4T2G03oFTaw0v5Jde06hKDh49thLBVdX++cS1I8HTqhR17fWMSQVNgviY1fsasBXzCTYfMY+giGKv3tbj2Em5Q7E1JERQqC4WAJ0e8S54l1pg3PU37KpmIb1l3ch7F99Yhm4rbbOQy2fP5GNf+Z/kk1/zUTm/upBPPn1Tbrad7LpBOniZ22oOgOWyxv2m96gqlDEvIamjBVNxvlg/wmTFolG16wzJvJE1OXV+G/h7UVRSV62CW/PyhqmvAnazdYEyl1Y75jOfJhRCYkjftGFiQ9KSeVxhon1mM6uUdWO+yFgxMErfAer3wTsayn0kWLiChN8kDsNVqbunMhj/wdJH5y7svevmslnO6bH85tOnsh06ef7spVRlI3VRyaOzh3T4uXh5JcWslCWK7nHVzVbG/SA362tZb9dcxYB7kZ0AKyG0mh5HZkUC2yQkOzu1ZmrRVq5SUbukpka/VvadNOM1GLbqzyxDwdU4+A5lIpRJEbU3wfcxjtXsMdoHfndomnJv6neAbPOapppa+52rfQCxkbjMiuUcOXLkyJEjR44cOXLkyPFugmWNPR+CAVhqLMUtCn04HZNl/GYzESCWPd5iE8In8zamvzLUU/7gr1JJ8/jUZbsKRrw4oCOnCKkgIORDOR7yCQ4iPEMAMMALM1UJO/SLhYzsp0HHYOtg9hWzAyiVnlfA0wfMKDHWSNpsKmYoAoM9B7B5VNSmB9JiUArlcUz8mTWw0C8E73tZbzq5Kvb0VF5fr2U+X8p8Ppejo5UcnZ5IO1/JfLmSqm4IeqINhFcvVP/R+BNQOCnaR5V2qHQ4UQY7BgyeDsHb4TZMC16qdv7xpT7BgCl1XUrTlLQ4QfLCEwiTCJ1l6l0/9K1R+gpO58DXPLXD1XIP4sklsDE5g2/pnsmNTT+wzzebLcEi4JbbVDiOnsJRZ6feJ9GX2v8aLBQAlYN89qAKnNsCBMV46ik8LcA10YTaeU3GYzh/P+GptYQX1dPEx1Sf7J7qrir2AnhBkZ7Cw2QlgjbH1bvTq+U9d+sqptYz6TVNXFCSv9wZh0MgBZuuAvcj7yyZpCzToLsnveTQ69n8am1Vxl33BvcfDJ+tnXfYgHh64/B8JgkQVzoDTFYV8ejy+Jj37P3Lh9Jv11K2tWz6rVTrtVzebKWY9WYfhJUSOudYCVKuwKiKURXLVlCPCJ1JIF1tgt8VNhayR9LN+o3e+S6FtoJ8IZGSfAfQGgj7wSSdVIYMvuK8BuFCx6SGKain8733hBVCnfn8qf7CPEbahQdjzMdtsGlJkhLUKtvlg5UELSr2e9mgmN/NTfDFb5pG6rqRsqikAeBGO+EBvS+l6jspUdwWKmv60PuKA8BhSzoCBDNZYar3QucB9N9o32Hef+XObJI41wOe49poMs8L8am1B3N+ViAU72uyjqpk84fnrMTGqB81xjTV2SGtguSX/i1Hjhw5cuTIkSNHjhw5cuR4F8HyyIffZTuXZVtLO6ukRVE5KOPwUA0dGGTJsLlAkTYUBErgWkfgAEinwHKH5c62LZcl4zGXS6D3VCzjoRl2GYDKHQBJAF7W8FKkqWb0E4bSa93v5aa3gkT7HfAJVbCL+VxkVlOthvarVQeKGemyb1UMu6JNPZf9OFD0EejCixWAezB7h7o2wKGgSZcoOx32ImpWSA2gnecIAStAd0VIAQCM3c1mXYR2s8SzFupt213BInAgnLgGquIGFBq2W7l4vpftTSctYUsjZ2dn8vDxY1lZgb6ybLVYlHkyKAi09esoDEXPZFUswz9Zxq0pa1Vxx9NyCG0KP2JJ5123oLJpMHnhbfk2lXHUgNOuAarCpq/5s92NcnS04Ifun8zl8qqVTd9JRf9tLxIXuUc87oEf8dvBRofFekJUXTLZkBbbs3MJthOlApjNMMr5ei/PL9fy5pvP+N7jqxtpV3PorekBG+CNNkyvqY0t9OuOCnB3A0dXmJ9q1ZhaGXAHXtpg+q5i1xctF8x6ASsB4LftIFuRrvnlUhWZFotzaD31IiYQdYgdkjQO2A69kaM9hVtBOHymopIwzpTPyTW5K6x0WPzd9qv80ZX+pjC2Cx6tMRxupjuM4y3+tHMPTHwKvENRP+fLLOZmDse09tH7GvOHK+U154W5Q2TA6onOvML7nuDPof+rTjxZH5GmGyZ/fdsIBteAy7Xcf/11OX38SE4e3ZP3feSD8vytT8nZf/pyFpbb7/6DXF1fy8XlFQsTYqUECk7CI7iilzDauwlF+2BFw3kOiZIeg2HLPoDPLxJzA+ZlFBJFAT3MFRxWPefApm75+9B3Os/RyojlWEVov8Cyrmrf4j7PadLICikGH3B+P+hKEk9QqquOKvU5L9dzegbDlgNzM+5lElbAbBSHDSsWosKd1kYotGcZSS3Gqq92rsAW9+Jbz1/w30+fPpfNppf3vv6anJ7cl+3RIGf3N2xD2yzV33/X00oDNiV13cpmu5UbFO+TUfreE6RuE6SXmnYd8JrmYO8TW6UZv88qvGAz1dh9AL8KqpPxVa3fj/wuok0SmHUhSDXgb8OupgUGEr04L+wTyV+q1FlTVf3zMVbx/UHFOgyfdvNDFp8jR44cOXLkyJEjR44cOXK8C4plwlYoSiupZ3ighdpKH2AdKvkCeTzS89GYhdB89a0uv/VHfvc21XBEpjBZl+rCW9kKkTmSSX1Y7UEdijzlUVFJh9AiSKbiMrWbQidTXRJYmEoay4sprkvgqC8PdlGiNnMqPnSLgnR9f6LOg/JYVc7mlmDqQVUtJ1rNiUBVfakj5NTzp70EizOponlG3+lRNptO1uuNdNue+4W1BFR2ddPaJbaCZqm62FXJunZd1YXmFaueyaAWBs3Z7/CPncJju2wJXA5vBIsFt/xIFctaJE0LVOGF9jYN7DAqFiRsavSPeoGmetpwdFO3pyPnFlx2kWdaeY7jVJeRx/O4Dfai4tj8walm3EnXj3Kz3rKvUXQQIK7AWvJkzKSq52AXYRYYuqjfPXjtDqE6PrnB3KbirvOZnG+yuYtsUwPxNMI+D3YUujLROU+6NlGZehLF+zN4Rk/v4cnu46ZvG7H4on5iWoww3eEd+vRJAcbDfaa/u9rdKfO0I9zHOfodT44SrHMBNdVH3pXkZnGSNiF4Fx8e6PC8p0cJ98u0V9OGEi7X7Vz2TS2yOyWQhIL+3v0X3PZotWKbtpuN7IZB52orCAfVLZIbKKaJY/AeDKplU8OymJurg3HP6jhGobyYnIJyWVXAmsxQla2ej8JUHfdm7WMKblcjH96X085IjzM9dz8m7YRKtRViMsIzBclqifT6w9InJF7sFvGihvSxRp/WsEyqtEjstpOyWOs9jv6zvlMbGk0swjYFYBkKZvQ/vP3LEoVNcbyS1kn+jeT3BvyMg82KK/8Tm3VOw/B5Qr+6SjtZDeGgnc5T9p2FCYozin1PqU2GzcFeQBQH8e8uVzIHP+xXWbHkyJEjR44cOXLkyJEjR44cnyOwjGd3iE6XdS0nbSvVHkXzFNI2JRTFCuBMe8YHVdeLqWWCqpmx5Jx6KlOLUpBlIBquuthHDRXYzhTLVCBz0TXB6sBl2jsu8y5GbFtTawxV87YftA2mxITqsIBSGfLmopJ+qKUqGj5Oj70WTGKBKn3ih02olO1MCoBNWDQA3EAx3Q/0zoSuL1g4+NJ5tkeLI2mxKoAAfYiH5yZV2jJEZSg8NmtVpJUbwAAtkEYVKrmgqTiteBUgMjSutPsw5Sn/gkOQWlRysxnl2bNzOTp+IS+fPZW2hU/qidRUw2qBOQUHhjm8wN7Q06OUiuXR/702wAWwrNcvYYlReUlVuL6rQCMpXujWDIQm1heAOhgruBYtgFgr/cmS9hddv5a6KeX11+4JbFlh9fHy8kY2/V5ebs0LlImIBEKl/DSFT6n7goPS8DeHl76kP9lwYo5hm6OYGZSJ+0K2vcjl1Ua+5mNvSt938uGLL5Dj+yuOFar+aCcCWKeqS17DFNIzKeAWI1otj97ghMuDFAWsStR7e1ZV2lQDeQX8aotSdgL/a7UfIPALqlIkEhoq4VFsbNoVuHMSn+H0BAOnVYCnnCsWBEuDSaFEnezWAl7wcuJV8WkiQsVoQRLhcepRrn1Bywn37Gb/TZMKh+caVx+kNhaJfcdhgcBglRA9tr1ind422neAyf1+xwJtW7w2+NkxmdN3nRR1LZVWQpsmCiY/P5s4UETzOs+kXh1L1c7lrFpI2Szk4flL9t3Fy5fyFV/2ZfLi2XO7Y2dS13NZoBDdfifbrXpTL9ulNFAzY+xUDdW0KJLnxQqpirfhU5aYx1Q5S5MeU3HrIMD8h32qDzPUw1DFMgEBwAqw7TYsydhShbzde058LZmDxKLbPuh4UWVuWbVS21glVGVxQyuEZ/5FqT1LvO72HeSJoKoKhTBxzg1shI6Ppe96ublZy3azlcuHZ3J6esxkJ1a/4FjwN9d5HithkFhV5XZbYwVDxZUF22qrNQPMviYUtOV3nPqOo95ABOh76btRhtlO0M3D4Pe3XgOseNAChDZ34Xuo0nuwKqF8jkp9rg2xugdqo4F7F/UQfNLU+5V9PM6k3yZkO0eOHDly5MiRI0eOHDly5Hg3wDI5MIrllaXMASIAuHZQMMPDU4s4KdRNVIxR8Mvl5rb+V20dXOXsalp6CKsfJbYl8iVcULAQFIOmesYS42GHB21sp+pmQGeHY7p8WB/O1cvXvZZ12Tf9aq20GIDoCFsCHKc2WwF8DiTdVNhogwqYDRAH9wQaIge/Tn8pfDVorvRFoYoprLHcGZDaVZKhexKQ6+pvVywbUlY4x3OHx/VM+mFHFe36Zi3rm2vZrq9l7DvZQ5XIYm6u93bFmgIOQiAUmaINBuwwFC4rnaA5CeEma5FNwEO4cpPfva/UPsDO17SXPNcd4JT6XlOl3EI9icJzDdtzcrSQruvkeNXIoqmAYKTcKvgnqmVDkk6aQM1PM3YnSmLzvrUxFvokwE4TOyd+1D7mtt0g5xfXsljUst1uqVpWiOywPapN1RJClaxRKT7142VyBbCWwxUgFdfK1IqJG4YWMjM/VtcvHnhWs19haUDp/m1IFJSat8KuWArkQ8G8xO4k+Z1Ke7MQSYv/RdVxogieJCfiGzEnEOGyAsGobtfzcxsKv89cn36X8jpZVpAqYYMdRkSNUZV+ANCTey56cevY1cKksPzZ0fbHi5SOsMOg3U2ikL1LWx0XWhz84+3iYE8hEVJIASuKupXlTMFqO5/L9fPnsprP5fknPyn99bX0g/qyN2Ul86Zl8cFxqDk2axTrgxKXkFXHX+x/U81SIYw+KfVlJxOSiFasUK+B9kM6n6UnP7ks3tvhzRRwRuAaviA4tyKxgpUOtezKnZRoN9uXqJF93wdq8JBAmSjtUVgPYBlwupa2VWuP9dVGOqi+t1vpOtzne6logVRKBesbJhbx9QnQq97M2AZQegQZ3mlShApjFLiFJzvhMnyvkQDVYonsQyYn7fuPq0dsCuYXjp4/kp7cDz+PZJN/v8FSSpNHPodBSc0ilPieIsQG2GbP2TWx90yFDricI0eOHDly5MiRI0eOHDlyvKtgGVCWNhjwbbTCVnjSLWd7aUotpLQbdOlvZIDxQZ9wlQ+7BqJsOTAenKFsi6v4Z9QQ431XquIhnLDZCmchruBzSgUYfJ73sulFOtLHWQBwDrq5jLkBDKylnEENqg/43DuWSPsqagpJdzL0KAU40+JsgIqASIC01kh4q0K1CqVqPat1KXahgMOXLusJ+Rkq5GAxwXImdVNI26BonRYIpCIboNsKEnqnOVagAYb1WfTkADwA/y2oAi8BPcpCNluFy932RoZ+qdsoKU32aMpKKGypMIRSGUu48bvZNQQ1nfWLnYnD9KBJ92JcAfU5QDIldyjep6CkqAopGzgTV7Ic5lLVhXT9SuqmksePzqRtGrm8HuXleScvrrZys72QDuOK3q1mEeLE6NAPwlWvqb7zDnVrjNR+IWl9gJMKzdQrFQAHEAq+sTU9awHl6NNsisQAO6ncjvtkgTCWXLtDXc38g0Lu2d58yR3suuTafFY5RgxC4Z5RSF/RmkbbMfVXTsG/3jv2qxenS7biIUM26DCJEMOho+9GkyMK0XifW+E3/btbdqQWBw5zD4GuWlAEpXtiMaPHdDBvHtHpMRJZejTjsSKbXrAtSXo5xJbDcWJzVrKV+vH6uXJuqQg1oXCd+bXH/NAPTOQwmVNXcR82xNLxEBv7ToFeYq8TgKmfp+4bvsvt6oR+6u/50BfI/etLJmkePnooz56+kOcvzrlyg2gWq00KtNEtL4SAtmoa2WNCAqRl8ssU9Czap4rfnVm/BERvSScodctSQeVuryrg6XiPCQhaOCT3MYe4ZQ492QMQ7v7E3J7fQdoP9FTGscqR1hVcJ2H3kd4DvHCT5APf9XqctlrG28TjsBCnyHK5pOr/5nItXaeqdBTrBHxu563ZKNkQMriL7bFyoKlrmbetjLTKKFWpbduqhYUXRxyYNKMXtRWSVe9/hc++Okj9l/UcoGTWgpGdDPR2Npui2U6KXWergAwsM8lQEluj8CiO7xZUs5laoNDLesSxRtmykGCGyzly5MiRI8fnc3yTX/wX5Kt/4w/7L92MHDly5MiR47OwwgA8AFiGatk4DuAhdtACWOz20gEAmBKY4UpXA28OSSPw04djLNklODKouktUunzwN8UwVV62NP6mG1nkbyedtFDk7QoZ9gpCKip54a2sD+xQyDaNg8CKvphQGSrUjuJXVTJDjTjwfAmWragV/DOdzQ5DSbjM5cgEHVD6qVp0YqhgxFohq6nMZiJ1VdACAsWkCLBNFavyVNU7p4/5hAjaOO0fgzAlobEq06imK2ayBVjeXEnfrWXsN1EpbWpABTqo5AQSg58o2Kd2GED62ttoixXg4rVyUOzF8gxmTSSwUZ3sEE/Ft4q+HEgChkqD/oJispWqKWUYVtK0tQz9TpbzuVxc9fLixUbq6lLefHpNKtKZIp79mToMmNDx0CkgNP2AQSd/OdCVmv9y+jkrDKjL8nF8jNMELMOiwmwm1GdXOyd4c6sYPmnkbdgdir5R6emF9HQcH0qMg4KTdhiufC+lAuT0hIb3daIa1mbFJMFEOWyDX7vQEgSvYPEpVI6/R0Cu6k39JW367X/P3hbzO3xUCBjV73qvJPAa80eSjIn9ZF7WB97fh1zfNdCH2uUAhO24egiDiIDLvEZmk4HiilCkDriX4KeDRqUc2YvTRQXqZx5ekC59xYBdSrFspW4XVPv32zWtbs4fnkn9lV8t/Wie9TbKAUJp9+EQGasIAJYticSVFQTLft1wbTH2rVihW1hwFYb7xasqFveJJ5p0uEZle7wkNv/7rRJWhpjCN9hIKJDlHWT3BedbzH9ULI8sVjcOUREdvnfSsep+Hjy3qO5lKzxRgeK0KHgKKLx/xqK0AMvbTUcleLtsOS5hWeGWLrqypZRyhJK55nb0ZcYqA0u2pJfKk5pYg7Hb92qNYdMxvmOghFebG1v5QRulvfTwxd7tuDqFK0wIh/H3nYwoMmrXDXAZ8xOdUszX3e9rHUWaUEWiFOMWvxew/MlgOUeOHDly5MiRI0eOHDlyvJtgGWplgsSwPBngRgEXikNx7X5vijMroqRLdfXzBAuJstUxj1auL+mf68DQVYtGLyEpcw2iKp5N5YvD4Dm7hLqNQMMAmYElAGLAAeymxpJpQiB7sIdi2ZW0QTGtEBwP+GMx8vMInGOFYnPEJuoj7AXZwj4AK+CDmxbyslBoAYCu59POK1mtWmmvNgpkzM7Y1bgpNvJz4r8dhJmSU+0/9lJU8AZdytHxSlarhSwWc4KgceyosMWLYK6Eig5th1IOKj1YYABy7A6KzFlRqNB+g2P8WwSjO1sq74vi/V+q8DTCY0DHlbgzWGzsZ1JCuVyjXUIlIGxHVitAjlIenN2T119bsz+/9lPn0tx00l+uYbodkgw+VlwTrF0SQY7z0yDWJlxy/9V4XQ5BabwCcR+qQFeP7/VmE/xXu00n5bwJhbaCtcZdtgWstmXbOF0zuwEfRAlu1j+bWhaKZYA+WoJgbA4D1ahYok81PawYrJicqx9TreZkHDrHxr3rHsm8thF6qndyAuoOItrd2H/7u8BxVBVHJbRvFFXuETQmthMGHuknraQ+KZCYQrLDxiUzS8go7G5vmcwz4Q2TJaduK6EDePm1SJxbuXCVggXtDcwKIxb/nMLjV1nYxvcPFdzT9mqXJNvwDdxLvg0shkzhjSJ0spfjBw+5kuF6O1D1v96s5eLiXIYRuBEJNvX3oS1OkmS43caYQKOVkBUdVfsLSyTwb7g34aWu6nnLFE7OXRXt8QwL9023cYPP+0qTcGqzg4RLqYkLWErwHuE87AVS7R72RJ3N65y1k0lV7X1sTjMPeqzCwfUFUFclt87nuM+atmVikhZMhLuxYGaE67Bc0hTNiCQi/+nAPd5PsJHScW5rFWzyx71dVTvta+tb9g/AMkAw+s48odMkQ0kwrHZT7F16UatifOTOfW4yD3yx2gX4EuU6DP2ey5EjR44cOXLkyJEjR44cOd41sFwDqhCmeBEktZHAkvwFPC7x1hZLfQ1UEthqwTv3gsXjbxeK/Cl4asqC3p/4TXVYgUsbEFBG4ZhKi/+JdPCGhPKOSjWDmkYCAQDH3SibzVqqYidns73MWyxPnrHYFtrPT2PJ8QxLuBUMKOXa0ScTvgFbFoRSCxAoa2eUlgF+7tVT1SEJVNLYtqwI/AD7FOQaPDBFWIUV3LOZHB3N5cGDY7m83lJ1CphiPaphgCBieCvK5OpNg1o4RyzhrttTefTkgTx6ciZnD07l3r0jKQtYeqwDlGGBq1lDkDL2W4XIu60W8sN5me8yzmviIZ0qCnE8Wl84PIuenXrBFPAAWgb8F6S/Ciyh7qP9CCxB0P+jDkHA0qpseX1kVnFJ+enJW/L8+bm8OL+RzXors35A7oIYBEccUzqXio+TuMXJHFomv6bWDmETZ4oOPYtCiwpeXMi8mcnVxTWXyzerpRUVC/Q6vZKxfQYvua0PclcAu7+sHV0TFqoC5Ta0wTBLDhQG6zq5vr7hEv2u30o/9ASbuHYBMCcnFISytn/efxxgrlRMoXKiep4dwkAHtg7J3Ks7qkNT33AtXOifc9WqWtW47UQ4pv3dA5CNNjawiwmA29TVgKh8Q8fq7fAkiSUzgg3MoYf29Dyn2uZ4vlSBmoc1/HXhD65qXm3H2PcyVn2YE4JK3bNA4Ub6uiiVY/IkbZla1kRrHHVt0Ru3aJdSNXN59IEvkN3QSXN0IidnZ/L0U2/KV/zHL+OYgaIac7WWTEWxPSST7Bom4FOvEhJCmAeRxdPk064YbPwC8mqCEXMMvN9nO01kxWuawF4bT2odY8kggln1LYZnPK1VgkWyJicV6ivw3bFwKRTBUN1qsUq0Q3/qMXk/UPWOa2f3kHnhx7yfAmpMq1x0gi/EumLbsCKhLDoZ+lHW67WsViuZmxoZVkk7K+wZfFJ8jODzmPehDaddRbRisSUplkzSf/Pawf8YL1Nr81wNoiMhio26HhYYo8w3G9l0G7POMONlzsuwghqZaAQvJvtmclW/WXYzVT7ruMF8q2CZSUau3smRI0eOHDly5MiRI0eOHDneZSuMaPRgS5gT+Agl76RQGbcpwucAVAELwIHdf1k9mmGvAU9kXaqtUMb9Tk1NbKqr1GXC/5HqMSfL3r1Imqls0/bTZmMPpZafSwp8EthnTFetDUzJaTYEkVVFRVpULE6XZUdFp8I2qOKapiKMBqxn0Sb36E04X1qm7E5u6mrpckav4qq2petsOmBDfFGtq47VtnP1Rw4K5SCZDgcPAE67Ir3g5vl8YGsSrjvbZqpbUy6r6tf7V4G0qvvg78r0g9SNQo/FspXj44Wcnizl7N6K237q2RVhyo0SExks2aAWxEkBuMMhcudvrrOOXOiuuF3YDWAXKvhB+m0nPZIIgK4JuOTnUguGILtML9pEZ21vp+Mw1axHAB2tKBRCETRb8Ti8BsJlhXZU9ofjHlhYOOSbFEpLldsH/ZCoQBUO21j3MW/gOP18WvQvjn+HuGESuHWMtBigHzOqj6OXsh/D7UemnX3gp+zXMhX8ui2D3WNhS1fK+v7DCVjyinY8XpRR57twbaa9lkjnv25A+da+br2VjPf0FjBaWtAep5D56kiOT0+ZaDs6PpLttpLxZqfKZaxacIWwX8ukxaGYYkiM+ASTjB1TxIYkhTUmtblIY3JN/frw2PH6xnvJZo/0HrA/sC3IUbD43HR/vmpFfbDdQsZV+3oPhnQKp2VPTKrKGhAb/s24x2B7hPscwJcqekvsOYhXxbJ6oTP3B6DuKnuVEAcrmtQKJvh/h9wTALjuF/knfmfRbBnuKhX3XzW1NEzYpGDZioIOo5SYG3dqk4TkjFotIeFrq0pI5WOy1r8DcuTIkSNHjhw5cuTIkSNHjnfXCmMGCKpLovmgXOHhW9VfUCvDLRJwC0t23QMVWrKWStlCmqKkkgpFrmZ8SAcYKGVeVbKoG+mxnBzFr/hk7YBSC6bhIRmfSXkbVIvYr/5UNeceyltTxsLrWTWt8LsEvNT2AuRCydZtOxYcXK5Q0EptDnAIXQoND10s+YZCTn1sHUzDEIM2DvBHxouF+wouywZ8gIoumlZYoSuDaVSHzWYyn5dycrKUo9Vc2raWfT+Tbe9qYAvnu0GM62An6YMSFh8zqdtCmjnUlFAx9zIMW9pglGMhu7GUfYll8lAoKy2jdnwGhbhBZxaJQl/vZW/QxG0T3PGZXsmEMWZPYB6fYcl5GgbfHPzo5yNcI+DHcnaOH+wP7dTl9DWKn5UoolXKalXL2Hfy7PkVRdVPX1zLmxc3crHupdvvZItrBrjqis3PYOAT3jjMCdfHqaW12b2pDZxhfG+3G1nfNHL58lLOn5/L0cMzXXruYNSK0KGvvJBdtBhJhcl+LFW3K5NLfGC5ifrY6hiD4h3JCIxNLSQJdex2s5b1+kZurq/l5vpKNtsNwfcI1SbtWVRpS+9x+0k7Dbt2eq/6dY6wCvcBrg//dsu/2D3QMWLGOzp/PwGdE4gYaH7SBZa8UXDoauR4PehVy1UPhewLXHNKVieWJofhsMy3YRcGIwC/ibS/dwFoh+qGLEKn/uB4r6BafzfbS9dtZbMpOQ4AHKluRQHFGr625rucWmzcqYP+bLCyrgxQe53oeY5/BKWx26vg+hV7OXn0WJbHS1meHEkte7m8upSPfuxr5GZ9I+vNQAXuuB85d/uKCCZ+zPM7qL3ZX75qxXCyjWu7YjbrRwX/FCL7tYkQm1cg8tcAxT3/gGkDbDXYQ9h9qVYxcUz4mhZXKTOhxnZq4UzOSe79bEUgWUyU9h2YG/U+3NN7ek91MsY1+ub8/FxOTk/NHgPzd6NFXkdN7uC7EOULWXaWSVAUVkWRWC1oq/NkzHdqn6rVBucJ+i3jO1G9uTnuKnyxVjKr9Jgo+lrudyx8Oh+bkNTRo+n59LZiAVYdUCNjTqWX8m4nm77TAqgDxrZIJ1vZjlsrILj9LEZnjhw5cuTIkSNHjhw5cuT4+hxfJ8WyPv8bLEy4IovrcXl8fOhl0T88GNML0ywW3DNzFv8+uk9tkMvFgl1BBJdwP1fRubqYPHqiXE5Uy+YtGfyUCY4MIgX65WBQYYWfn8ILbaOCbHBEbXxQNhNO2NL/4DOsnp3RujYBq+bTSsVyCRuRRE2ZCitd8Zd8PBF06xJuZY7KJNEOwAPYCHB59JgoKwEAFUCYvjr0TyzQlfoMH4Zd+2BhYGpPihen/TxRqrqfhF0j9xJm/5i1igO5sgbc2Us7r0X2rRxv5nJ275jX5v7pisDkqhtlC1gzQpG3J0AleEw8WQ/9irVdqaI2/v1VWtKJCNYpl8N09jESIZoMCP0TaGoKYqOmMypB7ffkIABTUe19cK0dzpuViBYfcxiq7WFiJFHRatcDcpsi1AuNIYHAS+LgNQWffj97wuBuz91wbYO3dlS5R6mrn8iBitj7PlGzh4RJvHyJSnpyEYITc2x14g99p/Q8ts+1y/Evybg4VElPVNhRne1+yqpa1g+5t25U1Lpy1/Z/62ZKO2rSobEP7jyDgw3CJbN761ZQPitV00hVLGW5Wsnx6QnfZpG53ShdV8hg41YhsSptk91PjjlRlIdj6goI7S9TCPPtCJenp2rfDz6nh6SM/v+p77nbZUxV/xP7muT+dOit18uU+zYHBfX5rTyHjffouULFMhIG2Ae9zFnI1YsIWgKISRjvmGTBx6TIZboqIJ6zA3v2GRNISX+H8/Wf9l1E6qz+935f6J9tXkcicbejDRK+AwiWi5kmPAnddzLwuw/uRyPrCBC4T+bLHDly5MiRI0eOHDly5MiR410ByyZmg0cqiq9Bu0tVJOqzzfhwCv/bDtXtXbsGdVcJq4tSFnXFJbqbEQ/dRstIDUoZUD4I3sZ4GCaM9YJZugw4YjlV2tKv2GE19wNFLtS25qkK1dmslO24l3LYyZYqwzW9hOfwSpZCNm7xQBiOM8S7WpyrrrVYXw31clFIW+PnjN6qUH2irR3gInx3gycp1MiqbNWCSNYWthXdrGpngIC6GmU+31Ot3NYlvTMFdsd0VYg2IG6DEQCI+ZYqkFAlX83j72TXDTJsellfbaQua5nPV1ReF8UgRQFfaVWAByRMyAywgAJ+XGXNo+H8Vd1nF95UwVR/O693P1Mo6whkDLK52jwUd1SvUYebDO7YgRzWe4sUDRShCnZ2Y0U1eNVUMisb+dC+lEdXa5Gylpfn1/Lv/9PH5OOfei7PL9by1ssbeomuB3iz7mUrqpYPhRXvgGO4xnw/sTwIS+v9lB1MsU3qKU7F8X4mbV3Joq1lvmplfjSXsjaPV/5//FutKGKRS1e9wtfbW6QwWD/oVNX8gpmwgK7UvGDRn1Bhmv+rLBf0eq5RTGwYWaixH0ZZHS1kdbyQ+WUnVbPlZ6BKJkC2BIIq7vXemUHFbvYD7l8cVKeuAlfBtBbg9ISQpmhMHRoTKiE5kRJA+tp6z7tM2f2d3ULC4WCSDLL5RixpoMcwK4Pg2R3bmixkCMkoqN8naulkCAZbDB+ypl5PRoAqUbHCwsYlxhm0r0hmIMHB+w9JIRQFrRcyq5cya+Yya1r8ITnnpArmpCHRRf3V6Y04XhRqemlM32cxtVfxCzb1ElEFcN3I4t49efKhD8nR1aVsx0GuLy/lY1/Ty8VLjJFCxh6rO1AMFJcOSmA3uDCbHiSBzB7YQfZsp0kOvf+9ECjGFq6dljo1HX/ivGKDgPODFqnzv6mKWOfzMPmlF88H4z4xgvdXYrOkZWTjPMSVNsFT3MeSKZ95fEw+2rc4NFTJsBi6vrqSzeZGrq8v5fryQsrK/LUpdVaAq0kmTXbJDuPOVwGg+B52ay2xxJ7f9sG6fKZe/ly9YL7k9EmHRzK9nGMyVosk6jn4NdDFIEjQ0ZCKCS8mPnA43KO7vdQ9VNYYuzqvbNtGunknfY/vxoMxkyNHjhw5cuTIkSNHjhw5cnyuwXKqxqOHpHtv2jM+wDKtLhKFJhgEIFYNKEtF304qKKgcTgEgAFtYQT9CJP7HI5oJQ9yfvmuA25TOQX0ayVZYjg7YCAiHJd7D0PFBHUuKB192b76XLglWT2jAZbXNIOTgv2EnYGC5rmUHyw7adsA+IlWkeRE7g+MOg7WEoWE99e8EJHTFMqB1gGBRpBbO131BTQcdaCkOAfE0wC29dceRhQf7Dp67tgR7ZwX5cLZQqAWG5X7LCj1cRudAD6vLCWesQBt72ZV/DmP8XB0ke9G6VO7LjSNQ0bdNwexq1bIkxCpRBMw8RdU6o5SzEZ7LC1mve7l3cSPnl5cy9Bue7/XVRroRirwZvbsB/NEEPSOFS9aKAOqnrtxol7UhHezOInGOBly0Dyot5FiWHAt1U6m6fXLSCm5C4b8ULruycuJrap+nD4v6n3pLA1dzdS3GIcZ9XUmBscNCcjVfDZIU85o+2xhTuItgH+NtiIXYdOxj3GiBNC+2N00aEC7TzsSEr2F1QLyEwePY780gH7X+j8sPDro2USCb1238u43y9H4O95d9loeIx3XblnAsOrDcVltPVL+JwPkutKtj34wwOD/hdwPdZm3gVgyzstYXrAtKFK7z5MV0z6rQPrCOeQc8z/WkU+Wy3idxP6r2dxVr6uuul6WSer6Q6gw+vY3cP7vPOfnpm3O5KWoYBimMdDDtYJe/OgRP/aTtmuO7wBrnRVvVHgX7K601Whkv2mGn/RDvTh+fCk792Iqqw4VK58ake00AnCRIojjZV5Hc6nArlmdkVucGWKDsdUUJ+uBKoFjuaIHSbTdS7xupykWA6Zhf3erCEzRuuuLfWMHTPYG3LGYZvsu8mJ8l6PhdiLHlKvL4Xaf7s4RCUjsQbeWaB/pKx4KAmN8IsDEn7ESqQYXWnAM4l+FDKOaawXKOHDly5MiRI0eOHDly5Hg3PZax3Bsukvs91bo1lGpQ78GvcdhLN8BzcidDUGLaIzA+xyr3qk4DOMSzLNTNO6tkD6UyCo/1u0Efxt2P131pCSTiwz+Vlwa1qQIzX0tdIg1fWfW87Qc8ZA/y8uKS4GV700s/wGsSyjwDbGgfiyTBT7lUL0vCQgBLfdjflYXArcLVdyjqVFtxJxXfAfopICDT5inAW1qVyg6YFDjupSxKaduGL6hfcf4KFxVm7goATXcydvYC9ZtTRoUPVFGXpbRNTeUqfEHrpiFoTIuseaRINYZ6jgLfKIyL0CKyGAehRkwBjniS5q9qVhvBwdbhvlL+iMUSpqZjo5yAQWAoLF33ngYUWa729BV+8tooJycbub5Zy9FyLscnL6SpG7npRnl+NUg3jPLyBuNpkK7fSw+P0QO4rP9yJXPSl8bN3AfaUbF3PoBgNRNp6hm9sU9PV7I8Wsh8OZcaBRNDt7rvtBXV8kKJDpP9vJKCka64p99q2l+2L906Ka5lQBT3EZWdTB7sE6uMeI8VvBfjuQT7lwR5WU4lJAsw9r0eGAEqP+hiY/2820Dguvs+ZwcK5E9XD0yvtypbk26+NT69TYfezsHexbFhSotdFX2HCDiOhOilHD0kpng7jBaoUs17NwXr0wiS7k+jPv5swF16Mzq0PFDyTo5hBSU9qeCaYbxX1VLP53L/4UOZz1t5+fwZ9359s5Hr9da8jL3IqUJTVRDbRaXFjqfTDILyOJ5U0+Zy5YlbYzDPl2QkQsLJ4bwm2/g9YX72YeRTiI3vBcxT2BHOzQudqi7ZU0nKw+1s3enIiC/OA/O+Jkatl3jv6CoGHDeMj9meiSNOT1a87/r6Rp6/eCGr5VKWi4Xa0YRxKVOF/sRtwyxUmOhLxp0n9MLwsQScXSfYLtEeqFCbILdnCvOIq7CTa699kUi43a8aPtVYY4TvFSYYsRsoq/XV95CoZ7CcI0eOHDly5MiRI0eOHDne1eJ9CmCxWBdgeQZFFTxuR5Gu3xHs9YAwKMyXrFAmqDBlH/TJBLjYoRU+4lJiQOn9KB2kVAZl+cwMqEG7BX34DxYZ7u9sYFkVWHj4VqCsPrQAy+ozfH5+If12LaTDOxTaGxUe06NYQTKXtRtYdtsPqn73CtOhzHNQgqJ9Df2XXZmoiCycK8F3xYd5KrJHnLkCaOwCEAPFDAGW523Dwn2qRLaid+w79KMXMosKQacWM7smTTWLYHnRSgM/1bqa+AhPzKltHxGsuvWDLRlnwbbgDxGhsr0U16olCTZVSKj+zalGT31IfRm8qR6DglWLyjEBEBuhcNWAf+HXc19I07ZSlTXV2Fjmff/0iCpm7OnyppPq2Y2su152+0HWHZqi41D3nBxftZQTIBehpsJ8BbSRfbFdUFMXe1qiACyfnB7J8mhJsKx9HSW8CuMUdgXtZGJHEWyLLfGiBbzM0sU1+gartMkG8p0oo1/MAxZjXH1/Dep6EqcsCeSCF6ydqVp0GECfKOIdLuv4RWBbXbpvbQ6r5fV4alERfYb9Jwt63lVJMR17BhWDP7PtNxSqnChZ47EV2qNdarYTtKhJAsStZPScooo8udrxp6tvo0H7BMQH1bb34yuhsv88MIt+xzD5trL70x9HC3A6CL2dMEr7NYHSGGcEy3s5e/hAutVSXj59qjD22XPZbjvef55Y0aKJZjHE3ZrFDMGyjVG3I6FdjCWVCC7xF+23KQCdnrP62KtCWKGyAtTgkcxxpjYyaubi1wvHikX7PNnmSZkoW7bEH+x6LNHiw4s+xDsofbHywBNEqrqm8p/b7qTre7m5uZYXz5/Tu/61J4/VMihY2qQqaPd0dq6v9wnvJ/o0pwkIz2ihjzxhB9Kt/aBzoL0s4YeVQThGnGPTOTztZz+OgeWw0sQTnKXIDkC5ktq8pHPkyJEjR44cOXLkyJEjR453UbGsD95Q14544If6aVZJN+5ZTA3q4PR53rgwVb/6Uq9eKgDh9bjbSQ9o4B6V7uVrXr8qlIRvpYMjVywrXEhXvrtKNijtqLTUDfBzGPeEzFzmbEDMlXJQYgMs1670NLCsoXRgs+25LLmsG3aYqqpNQWZHURBgBf7Mr1lhcfSvDDDX1GZUHMNPGMf0pdgR9d0SkSm4wef0XBfzRo6PGzlaoTDXQuYLBZ3wAVV/at9X8vI2mwqcSkQW1gPo9YKJiqsCyjoo2OUgR01LIlwhoNJsQhwEwZfXlKKuzg0wNS7fdgWre54CtgDg4E+wnUA7Tu8d8ydgfDcMcnm1lbK5pNpyHAe5uoG1iiYUhsTymUAmqPqi7UgygiZQUZMZukFbFrJaFHK0auXkdCnHJ0tpYINBL28DSTzIyL4g9HNYnyhtg5o3jBoHm4nKMFHb+tvKfNwXVhWjqsLcyXbTyWa9kfXNRm5uNrLd9lT/R3Xy7Qgw+BZqdR9cV5F6YsOvj48hY9yHBefMHuBOjhiO4rYA0yJucVm/FUK7Jcg9OJb//yCE9f24XcBBIcXQ8ZZUSSA2LU8m+DkW4AwZBvb5TKqqkLaFJU6grVEFHbyD01kw7feDtk9g8aeL9I50gDhNNvG/UaGo3leh6mPsH1Pjcv7BXLbfyfG9Uyp5t30v65u12gcR4JtLcbC9sfuKBev8OPGyeKKBLYMNg/ehd3LqX2GZDJ+D3XLIVwzo9fc9pOPJkxhx3nHQHcbSwX3sPRV7+cCKxIeM/ztkRW0VBWyLoGbeC8E7Cvnh3kK+U4th6ndKuM99Tgy5plRp76Q7TUKEmdHGkCZt/R6N/sqe8Cpkh9Ur5guu35OJCjw5t8Miqz7LBfhuGwNQx0KVOXLkyJEjR44cOXLkyJEjx7sAludlxQfYq2ErXTfK8Q7FgGYEy9fbUTaAw6Himz4RAzQAchFIFyh6p57HeA8K545qZRQ+UgXiYMvpURxQwaqprKi60qJ7KBToPsDpInYFy4ZgAPj4ebUXIFSm6s3Ll6mykgX66pIF9AAKCWT5IB+Lmu3GvVyub3guJWwm2jYoQ7WtgMooiIZihnjBAkOVyq7c1KX7oBT2IE+wOpOmKmTeltJuS3rnDonO9TZ0cuUwvJ5LgvB7J0fy2uMTefzogZw9uC+np0fSQkXb1vTgdUji6r+ITxVqELYSWKgHNMG8+40S/uyl2KkHdqr0I9DBf0OvOkIv2udSXFNmq720q4ANq8OHtjBvYlqFJJ65dhxtxygz+E+YIhfb41q8Xj6WRw/P5P6DU3ny2pmcX97I4489lYuLG0EZyBcvL2FiIrP9QP/l7TALpfPQ6nHfKfwJauDYzwGmu9UImj+byVFbycPTVh6dHclr7zmTR4/uy3y1kKpB6USMESuE6DYBBt+C4teuPxXywT/Y+14BO/y6FWU7SXb1sHkgs1DeaOpHGKXuZOx7ubq4kouXF/Ly+YW8fHYh19dresK6Ulv73kmwjwVToya4lgryhP0BoCI3oXYXtpw/gdXBFnYC5lSF6dYvAcwehKvwVX1sSn+DZ67M1DpnKZSLx4qH1iM4d3fFdwT16UGd2zvwR3++YtugBbYCaSi+RqAv0rSVHB0vpJ03AYLrCItjgNOqi/E/bXwmMG+qWL7Vp7udDMNgNilqhaOC1sNifgqeq8VCiraWx+99Q07vn3KP3WYtm20n51c36s1OayLohOGBjpUXSFjofO7KdmeauDep7vZElI+3ide8n4IWCKW63pTK+AnVcPArdn/3SaLC4KytkGD7YJMxjvribeOJMitzaKtJ1AJD2zL9mkquP/ouLrRgO6Dmnbdztuvy8krapmEitKw02aL9YPd4sOWIenpaghgA13Nzxb4qtd2ySI+H762R3YNz5wof2mEUE7gM//t4y0WbnbDvSdY1gcqwWbKxnxY4VUund6quz5EjR44cOXLkyJEjR44cOb6uxftMgQwrDLzoizzAz1b9laeqR9PxcSkwgJ7+5IO2wSVVeCXKSVMPesTiX9PCRco0FF66Js8LfgURoVtT0MPWl9nH4mXYBxXKQTVntdNcamZ75QJlQLw9wLkuaXYlazgU4YXCMV/CbCK2yXmmCuDCln83TS0NVMYE5mbvkTpKHy7nTwo2AYYvlwq5ykrtJYLJhSndHCgquPCl1ynkUeVj6h9LVBL4VQImscw7KdAW8ButQqBiN5Ro+w/FBpPl31OFrh8tnpR/PoAUs1BBjgBghAXzZjNZLue0pcBn7t/bUF17dnpEeHS97Tg2153ZX+zh6a0qSiizqaRPPE7ZztS+wbxacU3Atud1JcfLORXLi5VajlSmCo+qa4V7wfbCFNxBzRxH6SvvsTCqJmLXA8W5gyP7uy+zB+wakKAJFiDYyCnZFEWm3tt6iROgnoyNcM8FP2iFYKkyeapONl9dL8g3ezvVcuoB/k4gaxzbU2WmA8W3+XS6oiLIwdPunMLbkLrym5ygDyAumTeSVQvh+tjm6Tz4GUfasWEenKacJpJhv9/DFoevA6Uu79WChfz241zmy4Usl0ude65uIiyl1ZAVD7VCdUEMnqhyw6knTbKhEBSyh+p89wmO7094aKLw1jk87Ne9zJOieekc64Jx91uOfai/exPT8R0UwQd5BkBv96yH13LXozCqrghwiB59x32lQnKPhd/jSoHkwlqfuY1HvB9iE16hfk866FAVffswnryb7nPqv58Vyzly5MiRI0eOHDly5MiR4121wqhoV7HdQbU8KIjr91QsX/Ujl0+rlYXCQC3W5z7HUH4OQj2W+bNyybV7yUpBzLej0jQ+RmuxIvNAtkJ4LEsGcFVA37sPx1K1cArusGy9orq3qtSrGYWQcB5UGBdCxTKgblOpArgqYZuxk76nabPMCihS96GwIKDCth/UNiNd7gyFJ+00VK0M5d0waLEo9dg0KGV+l3VV0rLieNXKowfHbPGyqbmUvSNYth4IClL30vSiTfA33su902N54/XX5OzsmJAamw3jwAKFIkdsByAYzhN9No5bwkb0g4M4h+1BzexknmTcQZGpWb1r6X0MyDjwJ5WDUNgBtJqHdljt7fYapsilOtoU2859qERP4MZsNpoyHapA3a6qrdAWCikOo9yrZ7I4auRs3cnJyZFcX22kLSt5eX4tp1/1Mfnk0+fy4nwtT5/fcGyuBy3GuNlpomA79tKh/SnYNulrCTUlrkk1k3lVyGv3V/INPvBEnjw6k9dffyT37p3I4mglVTOXoqrYH1BOjlBwox+432iv4C6sKYhypIOxEaGuLYoPRbocsNny+mR5PItOUlKs1i9Dv5O+G2UEBPREUERIB8fU4nzRasB6nsbiOvpg81FWWjDNC/l5GybKZT+f1CtFb78JJJygMlNKQqFJQGee6tPieFPQyNmBMM/7KaSjknsl9m2wnbFxNS3iqAUKocbXvnbbGE9eJb651lZf64BbCasdUNytaVot8kY2qMkVvdr/OfxqAx7V37BSosR8hXnL5dKHoNl7wH4vCpkfrZiYenjzRGbjTt56+kxevDiXcexl7Lecw0bvHgLemKzgXtyTe5Ig0Lk/JA8AXc1bXseaquHpt29F+26fl/8KiyL9mnLLHYwDjB1484895rUImMM5UpafWN/g+ygUlNWCeO5mr8Vei4Nu0n0s5gvZn87osXx+cUEP+/V6HZKDdV3IetPJOAwKl83yaQiQeQp/9VrZOPX7KCy1iYmjFLZrEicZj+49bd8v/oqe59FPmashghd5UvgSynq+9sG6JEeOHDly5MiRI0eOHDly5HjXwLJ4ySiqlkUGQcE9KJXN3sJBWNQFBz4ZFMthP7fqK8U/uLjOILGCCLMMCIpj294FuanCNQA8Cb7HWlAq2g5MPmM40wv2RYDnHqaxsNdkuXHSKyymZB7PrloGHI0K06mK2AsGokDUvK3p2apey4X1YwrVUtVyqkEEKAfYqOgB6udyqFiLsM6LvClYS/nf4VWO/7J+INyNasOoWnZ7Bb2g4fLZaU/FrFMQFo6k0kHra9sm+DJDhW29oS4ghFPYOT2XUesPEL/fS1VWcv/+CX9/9uJINtuNjP1Obq62gjwB7TAAY1AociYyDAoK78JhKEoI1FWXM5nXhSzntRyvFnJkauV23qptCgvoRa/t1O+Vd0vSAcldERTdoR9cKewesUG1nwDRoE5MGpzsPRT7c2Wxj7c7r+z0zUkhx2TpfGi5qzkn0G+6tynITbf1BMZB+8M+fMw6+E3dvQ8xo3sfJ+9M7Exsj6FI4oRsh5/h7k7aouNVzzN4X/t1tCMgoZP6AmvBT1z/dCY7vLESL+S7RKHvmOfdtWHaT1Y1L1EDhy18tUE4G++BmcwAbWci7Xwuy9VK2qsrLaI3QKEMBTySRj536TwTZgZaAd3VrEN1bUyq6DwV1csTBXNy/8dRFKoBhuvkVi5RrTwdI7fUvYT+UbmbrAeJ/TBp8/6WYrkwxTJegMhQLYOOR+uWVC18aEsR+8TPWRXHoYHhsP7522PX4XTiP32HUnlaTNP7Mw7N1Os5XU2RFcs5cuTIkSNHjhw5cuTIkeNdBctbFthTQIxHUQDQDeww8J6VcTtEOVrGTB+TCaUDqHTlW1w2HCxv4ekLZWoAtUVQ3qbcxhGRFnlLfFqhKjRfSn0/YgZAbkBwVZHuZFfOpIO4dw+vT9iiwmvYiagqwBBHR0tadCznc1UlAyalnWPHD8vjDQl6QUIosbVgn4KoChYL5UxWi0YePVgR4JysWnbAsAW0UKVuBNGu6HN7BdgewPPzXD71qTelLkfZvXaPfwcEgb0GQo9vvp+ufCVccQmiqflMAYhzhF+wHxPnQuOPnXXkTpWe2JYwnUpCUzNT1QzQqkUedcU8QG6fikljkUKVOx7S59ihlgAIZrW2KXu+rKXYj1LtdlI3e6maI1ltsEy9kOvLG7b9wemxvPmpZ/RH3nY7ubgZaNlyhe3GUZ5filyu3Z5lx+vXlOqpvKoLaUrsYy4ny0beeHwsrz08kYcPTuTk5FhWx0fSzhdSN3Mpq5r9QMhMAAZrCvQP7gtrOpIa7HL1sI4970W6dlb40AgQO0+9r/GTTgRBRO74W7MhJWxUoJpFkUMmT2IBxqhMTaWYycXwPqf63P7iSZSJAtThaQRf6UtVqXq2aEPcxtX2qXI2aYUlWbQQ2SHcOxgR5n87VUwHDbZ5rpvVSzILaULq9v5SAOn2DlwJEZTeantRQS0/Q4KhkkWjr3kLtXLNew0rBZzq08Zn0tVJAmaiYr5Dcf0ZR9qnvlLEE1JxvtCL4feZ+utqNsgTQ5AOV7I4OuF2sDg6e/CmXF0Vcn31VPrtRm0zeI0arcvJ66SrIGJROfNqt+KcvJamDPZ7g8Cacyd2ol7NzpFdXetNDwDW5jsFqZqsoyoYynpbVcHVM0GtngLpuIIlvTTqsW32GkHxnVoBMb3Ec6zLneybmVyXJY8JsAwfc5zj6hjFUhspi5skh4j5JK3fqAOX3191FZMXmOsHrPjQ71RFzfoNSb9qgGsbu9pPqWJZDxZsOJIVDcGWxBNltuIj+LSPXjBX5ykq91m8LyuWc+TIkSNHjhw5cuTIkSPHuwiWB1vi6w+ghLRULcOvljz44OHeC9elirKUusTH/MiqXPnpNgkKGxwa698Cjo6qTIN1WtjLVcNqGaGF5fyz+hBP4wUUWcM5DLDQAKwwAGaQQ4G3Eu62baQoZ+qFXMa2pELEeMzEY9qVZY47DKRR9QiVYF3K0bKV62VD5fKmG6ToRy5NdtDhRQL9SCyCt+vZNqhyr64uZbNZJYWYAOFVhRiWPidLoVmIjCAzwrewHYEdLC1i4biQCKBSGX9XuIzfqcxmI3fThIEzEOx3tGXc3q9cFq8u0g4CA1ZLx0nSj74FIQhOrQAIL1k4Eada1TNpmkHun3Uyb+Zyc31Db2QU9NreXMtmO0hVdtINem7bfpT1ZivbrlBvYhbq20tdzKQuRJYoqFjN5HTVyL3jOV+nR3M5Ws6pVm7aRsq6jlA5UcLTqtkLdflVwxg2oJRaybrSOwQ+42AN4C7h7pp4cRBty/oxnqCapdWJJQuCIjKNqM6Pb3lRt2QsJ0ppv37pPRvv5bjE30ZQonp24JUW3LOihAE2e5smjUz2l9xcCez2JrviMjm9ya+qVvb5I8XM02OE90NDXSlqhd9sJQMU7Cz0Ccsce6G4oSuW4yqHuxhxvHcPVzrMcD8FCHy4/at7JmwXfjVwzURE+je1oNBIqi0GpayfpUjVtrKczWRxeSTL1UKGYUOrm3Hs6EuOCZ4AkscwKMv+da9tnac1IVRMfuocpj+RlAvjKp3XUq9jzin6FyQR1IZFoasCUoXLQeF/0G3+vTAd28kmPtaDgtnaENqjMJ5pNdgnVTrXAeACyKI4Jl5HSCZWsFZyq6LEyjltgsFbwWoLf8vuk6g2NjicFNcskiqDt1YsGAyf2NyY97T/7isjZimPNisqTSbG+zHYLeXIkSNHjhw5cuTIkSNHjhzvClgOXsEz6nrxf3wUdpiWLF6WxHvYH3CpZCWLUC/SxLQiPNADvzmApFGB219MijsZojIpsiprAYehN3Y1nIFnel5Gb0+wCPhyAlGUaAM8ifuBYGDcAS0qLEEhPIIgAhHh7wDLgMrTgmWKjNEvXd+J7Adp6rlUVaP6Mxb8w4P8YB9SJfEO6u8ZFGtbGYa17HedtE1BJWS1GWQGiaodI/Wz9eJ+9I4uUbyvkQaAE3DTXGD9PzdeoA3JqOpABfCRwoWl0GbhYIYRVCjvCLehXq7VU5mZA/OxNi/gmQyJChLHcLCcLLXmUnovDgjC4prCZCn4oaLWswQTFauq9wB6vHgXrT32gKvYdpTVUSdl2cjZgwdSAIBxV+qNfXWzla4f5dn5jay3APMjBoMM40y6YZCmKuR42RD2P76/lGVby8OzFYHya4/vy72TIzlaLngNWKCO6YlglqLHSoEa4XeoXGbQyYufuZ2I3g+WBTCZP5IZDkdN4V+UOkah3LQVA/SJNahMtbJZOHifebKDbfRl8rx3UpW4KqythWHVvN/AvG4OAN0qwuxw0iX4vm38qMNGH2tBvpnAUFw/3Lvu++qqV29fAsXoixvnC4zNVBUdJ58kfRXsVQ4Q+6TgZzIW023s8xjrNZXshSzblisM5m0T7GdUaR19a+lZS6Vpqk52oJumCEIXK1x2Vf4rQ++3aQvTn+lx5I73Epdvn0Z5P8fkQlE1HLvLoyN5+PAB/eY//jW1dFuRvu8IVJumpEqfY5yQ1IG2dix7E57ruP7wtHe/fVMqE67a3I65AGMY96kXPI3JD/Q9Vj9oog5zZVTior24D+An3sswdLynolWLXdlgzaTjc4d5gnMauuUAbCeJQv3eCaOA45NMmH79On9fXFxQ9Xvv7BH96pFgqqoaWUr68evsYOOXqm1NaEYrCltBYm21u9WU1JaQSZNz+K6amNzHAaR8PCYQJ6N5AqPjOAzWUj7HxuxDjhw5cuTIkSNHjhw5cuTI8e6A5d4UywAHwVfS4XJgGlOnxnT5PpYOKwvGg3fyIG+2Cw4L9XnfoXLyU3dvgtjocQGoRpUuy/oldhmhSJGicCjl4K+Lwm/BO7UQ6QeomlHYrQJutiJYVtyOwG4mNRRrAMtBFBpBGQHyTuELWGVZVAaWobTT12hg2ZWWWII8DKMM/VbGfi2y62TeltL1UETrMcMif1grGKyA4g12GwBbbV1I68vxAc4dLQdIY7ARdiXDnvutCwBomRTVUpBBzXl46XJsvQ5l3RgcgVpRlXwEuYTROCdVHaPIW0wW+LprGo4EeJVC/+i5bOpcwtZk+bi/gmLbQHFluj60jx4RWnwNKubl8Sh1O5d+C1DcSoM+qnH+g2y7rXTdIJ94eiFXNxvZ3KylX6+lH2bS9Ttpm0oensxlMW/kA6/fl+PVXO7fX8rx0VzO7p3K2emxtMuFqW8d4ka4rGp0/I5+0vsDCRT0WQCw6BI/L46xxNeb52KD3CBrAE+Ax/Q2x1gw/+AEKjuw8zswWKcgmWLqRx4jcDf3sE7W7KfsM006TPxe7fruDrxaJ7YY6ayhAFib4+Mi/ZsCNoBiqjrRnyHxkCQXfJTcWrKfSJUnimEF6NF0JJ6Y87lokeCJkfDhcC8BjDaVSFuVslrM5WjREixj5QI80UPNNVrN6HjnjTNRVCf3ROzdBDYfgudXxSE0PgSNKcS/o3/0ZJM52hNAevyiqqWoK1keH8vDRw9ZOLNpKs55ayp0eylmc6lrHTteUC+Iam1Vgxd6xG+VzUWcnzgIHLZ6AT8tJOo+5UwsGvzmFg6rC012sYs52gCWB9mhwCAKleoNMU1senLTzWM4XalliheLTLOEmgxBO2yOpEe+gmX0ATzc60YThhfnF5qgHDGvYr7Hq7ZEV4JvuYJlepXwfQA7Js6ydkOmKzLi0LEaBSxumSYED8aBzeNauG+6CkL7wMeWJXHZx9HuKCRzcuTIkSNHjhw5cuTIkSNHjncTLKuBQgpl1EPW7TFM98Zt1es4qtjc31e5sWEbPNDi5SuyTcXpy3LVjxbA0Lxj8UoEilR8uqrXfJ6jj2Zxe3l18HPVHRD68rNQgvpDfVxCrv6WbiuAZdAK6RRE7ah41od/Ba49QMNsL02zk5o+sApYdNmxNgxAeRxnMqKa3H6U7WZLz068jwJ+i3Evi7ZmsTlYj+AVFKOOoeD32jaynFdydLyU09NjWS7nhL1qm6yAC/DcCyg6tNdl0gmYMAUtFNW7HbyQAdvdpkJ9gSMp1OsQKsSZupmAkDDNlOnmna3KUzg0uxraAGTCwxwohzdT34IAQm4DNF6TYHFi+sBSpG5bKYpKVico4lfT23rYbaXvOrm5uZK66+Vk0xHO3ztu5ea6lb4vZNvBP1v9rgGWz85Wslq2cnq6ktVqLkcnR7I6PSG0rppGrUaoiEeSwhTMQfmNPsT49vFvdix27lFcawpGH7QhWQK1syqfud+gkDa47jYioVDYlFy5B7HeH6ly2bmuAWYrTjeBwtyDKlEJ9DBmqHKOBQnDTXZgieEFHoOy1sYdkjAOtt2LNwWewVo6GZnBBicdqnbPajLKQZhVdPRtkrET3BYOlfCHZcpc2RwOrqwVQJQWGCiQWZbS1KW0bSlNW0szb6VuG6naWsqmkgJ2CJYo0KJsaJcn0EzRO1FuJ+rjA7uR5OaI53ILKGojZ68E14mC+wD0R+sIUzGH66q/jOkYAAEAAElEQVR9AWXufLmkcvnk9B6TWv0WSTCMdZtf0GccYwY2rZVuPeJWOmG+teMEmBvGm66CgaWNt/kuvE4bisLvK0ty2XwRECra40rcg/nCrYjcoknfMeBsVh5UxHuSx+8ASyLihdUqSOIhQdcPKOKH1SYd4TbOTdXrt3IUsaBs8JH27k4sa6J3y+QmcHsiJKPcPz3403MTTQL6OI92LFGJntQFDAmw1Hd8MuXmyJEjR44cOXLkyJEjR44c75oVBp9/41JzYsXdnq+B4EgVvV6kTmGRwjGAJSiFudWukJJWv7BXUDVYgQpZVCdDNQxWq6CjrLRoX1DhYivy3D1BA+AIHq273cgiW1C1USFnx1ZnAvc0hrpzL2VdKOAddjLAjsIwNBS3si9VFWeFpUrsiIAP789kIMcapRpnUlhxJYJlFjLsCODqZpSq2VOVBgsKVZkqXBoGQIidqmW3W9ncXMvNzQ2XmR8ftVI3tZyvoSYr5Xo7yE2H4k6jLi1He2ZCH+BTAOWjhbz25IG88d4ncrxasIgY4AbtP6BwThkugRAK6bklhimgzWdz7AC4N7rkm4Ba+5p9UsIyBDuDuXFtMBpwhx1FBfqwBxy3vmYBLvP9BXDhT/RFZwUCDfRwaXw1URhS1RiNdCdKu7iU34CpCy4JQDUJsTg+5mb1/EiGbidLQPfjhWw2N3Lx/C3ZbjeqwLzZyPb6UhrpaGHSbUpZHc3lve99KItFK48en/Ln0QqgeS6nZw/kwePXqNIsS/NWhhZ77GVWmQrQbD9wdlS8QyVvlghhmXqyDD4iMU1QYN8shghCDrsPIHkWS4xwnap/3mNWFNKSB3p8XHuFyiz8ZUXSVDWpfYx7lbCX/az3Ea5j9C3HbtQzOhbvQ2N599v9P4XLhrYShfyBmhnKT/ZDknTgva4gLLV6dtDorimE48ZgmaTxvjNVKYuWucevnYHnTlyxara2SaQQ189niiKpVkaRzRJF+0pZNKUcrSo5Pq7l6Hghy5OVLI5XMl8tpVnM6U9cNA0TAwos0V+sChrKVUa0Nx7YU6QNjCYdU9DsqwnuOo/D83k7hbMmJ/Sa4qcrrPXvSM6cnj3kffv6G++T5WIl2+teupueqvl9P+jlL3FdXQWtxjhMwVmRPn3ZagMm1tSvXlcfaJpSEzL6E5YjU8gZq0fSoxnju8D8o98T+EbiygAMfaj4MeZx3ZILHr+ttAgpkogKeXXcz2ajQWUFwzYp2qfsGrGvdlLVJb2nkUzabNc87+12LX23pgUSxkrP629X3ZM3DtutXiBt+9kWLSIbmpis3lH4PHJew/cm1fxQyJuyG38f8R1m38luF6S2LL5Pfc9yMSHB6/ehz/H6eieK+Rw5cuTIkSNHjhw5cuTIkeOzAMvpwy8fZJMl6ZNHUoAwU/wRBoRH11DCLnyG4Avskk/heNgFUMPztIJYwB0CHsBiLttNgJw9FYdigl6cKFGGpjpMHpMP7np01OrjeRgM34WH8PgJf9wOqi+zFQhFDA+KiOFXgGSAYgLlxIfWVar+O1WlwSsEMJLOG1IVM75iwUJT/rGI2IxFxLAUH9YN6JtQlcnUqlS4+b5TxSdtCpzS+Zu7qTLOFXZ+0gHQuYdp6MhEeerOu15w0JWyqlhWuwiFmcFfOFUiB5DsHZ6Oq4QSsmBdLILIj5JTmtJXyTS3QzE/XNF2PpfFcsW3u/WC/biZz3k8KJKPjlrpu5l01U5Wq5bvzRetLJZzmc/x77m07VyatjU1tFqlUJ3qI8TuiaCcTCW4YRT5CR0oYw9tC/zeoj+0FUJzZbfBW0CwYB0QFP6pD3mU/0YgHdvn41a9fZPLGVix6WMj14tQNBnvXlwy3iTp8DDMGy6fj58IsALI9iJmd6EtTxykQlcbO64rjSfrPtB+L6rSVFXf3sMxiRHvCh9yyT0Ur0hQmyJxAwsaKJbbpmESqGIxz5pJAYyNUEBRd3RgM5LOfukphoGftMz9oQ+VyGnLkutyqMK+tf/pp8MebPfhVraOpp9w08rq6ETGYZDlaiXrmzWws6m73Svcufl0ZYgOQ4fnar0QoHmimudVDMkWVRyHtvkcm6yQCeM/JBoPztoHblq0LlEgT3o0WFYk49pXENjn4t+QXAO0rojEeybKRtofbbfbYFfhyRFvhiZWfPVOVC1bN995XUIfpcmZZC6+fZ0TFB/uzzQpkYzHV33y1lyUI0eOHDly5MiRI0eOHDlyfI7BMqAmHpKhQuPjKwrh0fsXgKvUB19TExcz8wqGdyp9MWdavIiP6LZkGQC1KuhVimXmqqCqCBxGNUaWtiqkqmZcio4X4CeOCWC7NSZLD2ISYSzPnlGtTMWzgVQojf2ZGWCxbioqlmfwG6YqGkvZCxn2IuthUFUYIIIVQ6Pi2orpoeXKt9UGAuCK/qGukKYqeSebLQrFDQqX+cK20dMW7YAvZ58AwqrcSV3tpC33Mi/20lEBV7LPKysktqxm0taVnBwtqVqGQvbi4gWVscPpEa0zcNyiHKSFPQUMQhLWoPDXryg6T9WK8ICGAg4bBz9mK1SlS7FxnXcyg9Rce1J5buGq5MqYGtS20cYkwl78wHmrX3YofOdev4Hc23J7/USEKgFAG3SlKjphJsgS8HMKZMu2FIgbj4oTaeBdvbmRRVNLt1lLXRSyhkp82MhqOZOh76TfbgiSzx7ek6ady4OHD6SdL2S5WBAqH53eo/csFbfmSapFBP38DASHzrbkhvUvEiWeVgkmDlCx4x32l/YHERQKmlVQ7uP+sSX/alTOw8DyBGMKhRtxjVgAjdCrpM9rWUCJGlX+XEGAe4eK5V410rRLcQsA8zc3aw23iAkVBJO0EJT8CvjsCiVJJvp165v2aT1T9dhWXXO0vNDzUSDn75tti8F3tygIYJyF1zxJ5QkNu+/N1iFtre/b26nF/xJGHsbWtF1U1CcJFdrezKBYXsi90yO5f++ePHjwQO6d3pfl6kSqppWmWapHcdWKlLUlBhxyusfzISROI+hck99fQe0n278qJtmLg0gSOpPiftZ7SJ4gITPby/s//GHZXD+R3bCXk5MTOX9xKecvr3QFQt9z7BaNKtyL4KPsx4UPsqnxOffoqg9YCvnQUrUtxmRJOx61n4h2HV5UTgXySC5WsqNncy27aoiKaLcgSahrMBDx7wfOPTTyVo99859n8g/zEVegpH2tY53K4N3I8TOfw799Jutr9MFWnr94ziRiPV+yqCHagm3jlKUJS1tmEFbQYF/0FdfKe5MkiycnByxWAcqnHU1Ie0RFvie0mKSMVkchCejJN1tNoWp/L/Jnqmxb9hHu3Rw5cuTIkSNHjhw5cuTIkePdAssOkLyoUlSchppQwaNYt/Hl9vr5qe4uwimFYgrXUH5PxX260xqA018VrAHwCYNio6kPzRJAhb/mvRuO4wqyqLLUwntW9MyKoBEA7NW+AgC3TIqyRYUbHsBtUboXqAr+oQbFWLxK7UGCatgLmoWCf4nC1AqMKeLRPoFqmS8KcQ2+QrXNon0lC2qhaF/T1AayFWC76Ix9wWP78RUypCAtdpCpDxPFM/F54ls7hVS2nV/zoIhVWBz8OxP1o/6MqlvfLl2E7WPCZbITJWJS5+9QNR2UtrOpF7FD56quZbZr6fM8LBa8tovFgue8Wi2k6xYydDPpq720rRVmoyK15e+Ayg28dOGtWilA4+r4VJmaKILjMvRoexFVvgdmwq7kTQSG4Z+uxgxwXv+qXskYt2mRxqla1AvzBVDoSvJbel3fsf89Dom4oN/hfvxcVEWGC/FqgAlweKfSMmhFrZ9sv+6rbOcbC/9N1Z7Rkje936NqOT2XW20KgyrCOj1nLx6ZkLvgNg5LDSS5KqqWWTATEJ/WKLAwAVS15FgYG6ly+a423O6PuzWpB6sFwqW4W78az/32fTvtiihT15Ul1jabb2D3slwdcVXE8cmJbNdr2a57uSyurShrvFenivlknqDdQ4TXsNZJxMFRUWy+9Wr7oolLTzil564F/nw1R4TKQVY+mbdSTpyqf80SIgG6s4NVJb7iJb13cSwUGhxQBNbg73bbyWa7kaJupblTsTy97NpXBs9TUbF/hcQpMCZ7kmt4awVBvFVurVCZrF4I7bL/7lAv58iRI0eOHDly5MiRI0eOHO8qWIZ+FZDTnDkDPNYHbgUG9FrGc7mpMPWBX4GPa6KGGcrs7WWE3+5+J1VZybJtqHou6Z3p3qkip8tGllCfAj4XItthkMvNyOJ2W1NoUd01jnC0NWuAWOE+QjVEXIoOSD1vF3y4b6ACnWHfvWz6LVVpy6ZW/84EpBEas1phLMZUFnupCvWeRZErwBECF7JBwJmK50JcHjjTLoAbqEwd1AN+otzc0aKmQnKUQfpxYOEmiNvmTSWPzo5luWjk0cMzOTleybLdy7wVaZqWKj6osCGE7cc9lct7wOiqole1CjZTSAh1MryRreAcADv+26vCNUhJHd64QhiviYouKkjDtmEpvPvXorieArdQ4w7nxUxBAvJQCAsn4ONHXVuts+E/HEG0Hs/gEv6GXVBkCBqp47CgfB0Afi9HJycyLKAq3FG5LEUv8/lMNutLWV+J1HUjx6uGSuDVfC5Ns5D5fCHNQr2vPRlSw0sXx8Ux0Wf0MKlwMPq4ohF7qDndn5UNsWSCgS/tA9MwOzRPqRL7xf2IX7H43cYZFPO4h3xsaQfpxUGSg8UYAfRMLR/tOA4h3BS63hkuLL8rP5H+MskpHBZUU4/l6fF9DHhBOLPltWSBjwX3aI6gLaBh46QRFtJGJox5k3vfdT4G9bgaIfiBcIZypKz39ThQpesv/K5qV/XHhspcFbvpjZJ02sSeQKw9USk8hcufHhpP97N/xSv9W7wWvJ/4p0SxTJsOazObr/05PzmRaj6X19//fjk5OZaqbqQbeum6Ua7XfTwGB4B7Nfs8q/Mz1Pq7GfznscrBCkPuNGHGuaFEH5Sy2w/2+YrfBamTiGUbwxnr94EmAJEs0jnSLDJwLThOUJB0F+a9MA7QHlpzq/GxK/ZDb0/U+AUV0u59DDsUFPuE2hj7uLpaE5a3i+NQ2E9V9KqCZg/jew02KWFVxkxQXgDnU+5x3jssPLGxqwlOnDf81bECh/dw4p2symOfNzUpqNYgM2Gu1X5XNbInXrFyB6tX1MJE7ZJsn3fMBTly5MiRI0eOHDly5MiRI8fnHCzDjsGLyOny8wiMvNgX7C7IAFy1ZwQ2qLhsabwaA+iSXDwQE37OSi515u5hiTATOZ43cjRXMAgrhbLfy3owHGPPwmo1sZO9F34Lnpzpy8PsCIoZFb8sVgeLh/1eum4n3dBJCQsIb7+r3aJuTQXT+GHc3BWesDbgsuvQBwBOqnDdsWiVTDxlASpCMSZTLAOCtCDdUkizgWLbTTf20lSVHB8tZMWicks5Wi6lrkdaaEA9iUJ4BCFU00F9PUoxopAWPKoBir0LIpBwq46C9hzqhQoApIWs7nppf/iJEw4FOHo3nAg2B6EwX+yHqH41UEI4Bcqilhpuo+EWAYeHYM9ZsTvdwCTlPK4uewdUhzdsu1hQbbofe+nbWjYbFPpbS10OIuM1x+C8hY9uKU1V68vsJpAAcIipth+F7AnK8ULfwrKEV1ALFI6WfgHUNUVoVGibuv0Q3rrMkcp4p2r+h3Q7V0MrgAK0wsvBqp57VAMHNWMAbOkBY4cqxL3VoHh5J81J7EsOP2LJhKi+1sTJdPNDlXT00+U1dKmpt91TGGE/DtbiOd1qT7DYOGxgVLrfEm+bajqqaL1CJIrDAS7j2upLVwXYoWhHovdt9BhP5KSHDQuRgvIDxbL7KUw+GhWo03OKVga3Va2pjNXumTAe/RipulXHNWxtqnbO5Njp2RlXSbx8eU7/cSk6udkM09HEhIjO23GuUdsj2lyEpICBdxbjMy92m/P1lNGPnjBIfLNdSW+W6nFK8r6eTQvT8ZVarcQOCoz30MvYfJ+D13myXyb/LOEFOw8WbN32UhRbzqNM3pFiT2h4YkljANsLBXKeVfgLKK4w2r7bLLFS3OG1HH53uOwrPNhmt7hwH/L4+XQVS2p9c3Ab5ciRI0eOHDly5MiRI0eOHO+SYtmWANeFImGyMwJgfeDmgzBVmLGIXFBoAo6amjBCJDz0u2TQFFzmP8ziWFUpD++t5GTZSj900vUbYs+y3MuMx0lgtUEYKG35EE+QOy3a5SpjQmoiC/WHhuIN284bWB7AZgLWCQrMCfgQ5s8aVGPglyNAsJ4nFY8GC4dhlJ4GmfbwD3gLNTP6AP4WgI/uPY3P7kZafTw8u0dg0c73sun2cr09l5eXvbaP/rkl+wSg+ObmRnZDL8dHtVRLgPdCKhYUU9sGvgDr0R+EIQa6ndcBmMme2zhY1kJVNB5VhTnbjj5yNR9NQLQ7aECtykP2C8FJ0jkB5iQElQMk0mFfBh8jUezC39TEnL5vtp2/2+d5UlBBJvYHvPYGJ/GmFdnjR4eKP+v5nMB/uTqWcdgSNG9vztm3bQU1O1TAI72KZwa597uB4AugUX2JAcUUaEdA6tUE0S9IBYw2bHRw0LuXw9xAFUFkhIheVI/ewKIgDn6s3ncc59gvlIcYT746wGwBaoJwQPAx9q5BagXdntQIhzTf5ClUckV70IY78A22Ag7yignUmjJpJ64puPa34305/YhbBHiyxe7asO1BMTOfAcyPeTKOJgrs6f4j8LOx5DA84bjqmo35CGNBC4i6olpVqF440RMAuN5IYGCMxHsiWop4u16lDE3hctDlxnF1Kw7V0Om/71IuvyomFzZJfui/CYOLUprlkvPI2aNHsl7fyMX5lfTDW5zrkMDi+EL3uDI/EUVjj6MXSoVXu9n7EF5DSb+DnzgSc7jHkJrBSoCoLvfvE47JwL7dU7ziuB9gdTOg7/F/bu0TV0P4XKPJAvX311P15FAC2X3uipmBpIsVVCORlxadRMJqsWhls21l0TZsS7dFq9RffwKc3d87JG30+xP/ou865wJVKuPl3yOoLZCuNpn2z20bjVTl7N99nhBKczC6QEK/n3LkyJEjR44cOXLkyJEjR453FSzjeXS7q1ASTnqqrBRCqj8ygFsEZAplo3KKNhl87I/ew1BvEc7Yct+iQAG7Uk6PWpk3tTw+O5b7Rwu5urmUi+uONhzkeVR0OixwD2MtqBaWARuJJPqaCG9xLAAiGBdo4UGAuaqGEg9F9WraaqiIzJRzhCEy9TGGghHM0IqgEQQALI+j9H0fABSXHQN6ADUSyO6D5BkqSPwdytiHD89oZTFfDHKz3clbL7ZSyhUbDZgMxXJTVwRd11fXsgYYLk5k0WoxO/gB121LwAxADqUhYLT6V6sAmPYABBja/yiMqMvAkQRQj1MyWVOCYytgHwVFBtZQfApq3X2h5a1YhErhryriFL44E3Ulnp63jyaXe9t+A8RUr1VCazbVTHYD91KoqRYSkQAFYGn7VYBoZNqVjFB1o5/mC9nVtayOTpgY2fUbWQPKl4W0Na7/XupikKrAUv9BwReB7kDgBUjEMWYWGL7k3AsLKgAGaBuD6jr4n5oPN9oBb160l+p9L6DFbVHA0O8pI1zm1w2gDAgECOdEmIrIAsUWK2nbRsqqt2J3ugQ+9KsXj0wjGeOHYQvzI2zTCxBVvxzfWsRRGZmuQ4jX0q+Z7zEtHHnH0QIQ00QPZ5TEniMtcBbVm35+d6mTY0RvZi8OZyL6oOwOLbTzUsMIJLqqckZ/81lItBhYxntsoq2+YOJB55VppAXtXhX+mcMVFjam3xGUPvz9baByJP0T1bKn3fT+Up+gdnkkLYpaPlkTos/fei4Xlzey3Wzl6upKEx52f2N+mHg2s6MxJ6jNht6OSJoYvreVMONspBXJvki84b137PvBxyFXewD4l/C7bnifYsWJfh/g2wlzghWadMmzFYvUgqFqQRSVzT4PYq47HKdx3NFIBMVUawXLen8DLNeyWMwVLC8aei9v11uDw1Y8MLFrObzU6At3lnabHKiY1a5CX160VuefxCAngclpLuiwD33mDFZENmein1C8NUeOHDly5MiRI0eOHDly5HhXwTIeqI3nMlwpTF/SdJk1l/67WiwCxFiJXrcBeHS1JB+gC/WVBdRpm0Lm85I/6xogE5AMqkCFljz+pLiRNSEtijQpgOQqRhXDwZ8TEBwQgfAVqt6g7lWcR/Wogaug3zS1mSs3/WFeVWP6+zjuIlgGTB5HFr5iZ1c4LvyAVe2J4zW1Wi7M560uqb/CU78qU0cqTWeU/HVdLxcXN9LU1PRRyXzvFEpC9djFPgiXqVbWc0kVaixsl4AMLVilCkyCDFNP45heBI5LsenBbEW4uNQ6grKoVo265tj5+jPAeYJT90VO+WMCSbwAV1KATgWgftFTwWtMWgRKk/hr66C1vZPvmnp5V0ixVy9W+FxXVSNN3VKd6v6j+z3Uk4MqUB1Wjb2MUOZDgcjl7VEK6Fw54DyjN4DuAI3uux3khbjO5e52Ea8w0hyeot8NpRu4B1weh0GLNg5o1wGIo+oZ90oc9+yAxFr5LuQ4tVBIaxOaMpOw2xtq/tBB1ar9SqMb3jIR5k0jKjfpQjtTCw9VcE/tMeKQsOMQVt7Z8gCjoyI7UTrfBdlMqezK6KBs5bzgKnB9YW6gnYytGsC/g4WN216EsaDnpU1Ijunq/fhLoqaNKwGmEDlVIN8V0/dfXY9tqm4NfebX+JX9avcqm6Uewc18LkfHx7LdDnJ6eizrupLN5oaJFMwbnC8NzLpNjnqgGwylWjjei6FiXarmtXsk3NehO3QVBVt7cDvZDBOgP4412r55R4Wxnyjhkw5RW5lE8jutGBiumsNljAF8Zw19L1tTWzNRVxRcbcOEIe7VAJaTlQKH+TQ/iUnXx3lAV8TMZIe5zOxvHFBrMjD5rrP/tK1TBfPU+sI82NFGm2feTtueI0eOHDly5MiRI0eOHDlyfFZgWS0j9N/+sBo8lw0qU8jKB3os1VcoB5LAx1YDLbrKmeXcCF7x4AxApj6suiz45KiWo1UjR0elLBYzudoAzm0V9pF3mMrNC6Cl6rMEze0cWgd1pxbuA8Siv3Ip0ixqqcua/s4smjeOXMqMh24WGKRi1BXQBmJpIQCorkWmqNfcK8xA0Tzsg+rG3Uj4cHN5yTYDCtMNAwXl2BeVLJdLqt3u3TsmWHr6fEuFMwrr9bRqKGQ/FNJ3vQybNa/Dgj7AsM84JqgGUAbwadtWFquFwrBg36Fww31C0UdaNBBN1AKKw7CToesIJHFcgJla+afUBM5QKpuSll2ZJAlYlOtAaRmOiVBrDBeXQpXNQn76iwKdwHjUszh6dOsomYAeB+QETubRakXBws8U1hF0lVLQvwRK4opqQCi7m3Yu88VKjpanst91sh/WMhIIbcyupaWKfr/vpduaL3XbWUtR/NEUu3Y0Lig3r1m8VI2pRQoJl02ZjOsORbk6DVj1SFfPsoNtSbwVOQwofxx4nfptL5v1WrabDa8XQLNCaIPgsO7QDk5FxtFpZBLR49Yj2Jgk/rRueePFwDQxggSGLqXHddSiZmppoB+awiwFurjPbfxQTe9+424rkqw0UKMAvq+82iGyQdmJBYZalLg9waFqc/pvvx+0SKWOwT3nL7yqAis08JpJW0HJDv/tmopwKMNLzCFoOwCzQ+Yo0X+FuDgFxX5RfJx63x/C50P18l07TtXJn0a2Pdk2+cwrEgA6r6NgJixldB5fnZxyJUndzmXotnJ5cSFXly9oKdNtt9L1vSVZzKfY+gZzFe0rzIdYC+zBy14992ktEyyt3U4pXjO1LbEJPVmlEnz9bdUFty11RQALu9o566w9O0gGRNLrRQa5QCLF1D7HYJCbtzairRuuTNlucA/updtueL9jfj89WpkbD5I/mqjjfcI6AQ7T9bsuOVG2gat2mHTFH9Taw22UOHezObqSgX1AgK3uQ6k5lLn8xxHHe9NscTyZa3M07t3Rv6tz5MiRI0eOHDly5MiRI0eOdwssq+IsPnzywRQwy5S6ttUUbTiUmqAP/U0VgbYvPvDaVsVeqhoQVlXFUMnuZ1ryz200bhfqSgoVuQov4SaHykZjlsF7WNllAh7sg3zmDz7QUZOroutUJepHUm9MKnzxMigb/KZ9DXKitD48roPgUJRQ7Zj5sY7F2mbSAvCRdUNFCV/dSv2lDeSgY4PCOlGrhT4OKuvEh9d/uheyn1iyjcNaWJhMQUT0F3bW5+o+7cd0W5fRece5AjGOEMUirkJOgFqiXgyFANOLYvDItItJEyOw9KX6AIPoMxS+A/jZeZLAlNxqbeDw0jTZ5pnt1hRhDLj80xXMQUaoquV0SXvYP64vVeE2tqgMDnpDO13bPghwvXim2mIgIcMXEwS3Fa66t1SdeUgRo9I3vT5RtLk/2MbunUT9Pi10lh5j6uWq42A6E6Sqykmk7NX2GRS2yeWeuDnciilUjmpoPz8bD9Y2tUaI+yXgAwRFIgbK5Uo9fTlu/D47KBw3hb/mJXwwB6X9Yymq6GPt1iNBuT7phGln3oLEdxLiV/TLq7YL6YhwrXXXulqBqy1shcVytZBh6Ajbu24rNzfw3e/N2VjnEN5rTHKVukoAtyTsYpL5MlyVtKhc8Or329wKXk4Sm/Hax/lj2gcxUZFOJX7Mwz699YURFc6pV7h1R8H7Wos49j2SPVuuHoCVE+yKEot/+16KMPyua+ktdftnTS75ipPphyYq5TCXH+b2bA4JXwS+XbLSZlJANYPlHDly5MiRI0eOHDly5MjxLoJlDVVfYhk+XihSh+fSstIl+yjGxIdjAAXSyfCIHJapq7tlQdXoDEuk9zMZxp2g5hjVmeVeVstSjo8qKape+n0v/W4jW6hGoQIboQi2JcCuVDalG9pDH2Eocq0t0TNT4RLaRjsMKNsAYGGzQasNeC4bjDOfA6uxF6BqXUGVrcDbfxIkm4kx/tsNWxnHXqCdraCVG3fsZO6RhdQUWrJAEz1zVTWmSs+ZbLadrDedbLteOoIalnIj4NoVUFfPpFg0VM0dHR3L/fv3ZXV0pNYXUHIb0NRCUDjPfWgvYTfUfCYB3OH6UVVttgr0EsY2UNipwlltIKCAhaoNfV1K6aDObU0AnFjIUPtJVeg9/4Zl45FQqZo3qBJTBWI0LCHoLa2wFhTOASITthzwNofMbsvgINQUjmptkRQkI0zeS902stu10sE+pJpT+djtFLqgX9QPG4UNlfbgPHAqsMTgnrD8HeOM/rEGIauKCnPC4HEneglUoesFtWgBsxMWpFRbDoy3Usq6laKsNZmCcUklYq9d5ncflNYVRtaO/rab9Y1cX1/KzfW1jEMX/a1LE3c6s5wANRvldj/GAnkJEAfuhI0Hz8ttH9CeaCHhak7tfgP29pmprYYnTHjBI1DkdmqdwfsQhdxoW6vjIOY2XGFq/b/D+Ij9mSqBpwkKe/dOkDfltW6qghkDl51K5bqQFr7mVUlv89VyKcdHK64uwNgpay2aiRkNBR75UpMae8X5b6oo3k3VylwVgeucJD94Tfzc6F7/ChXz5ExfDUsnv/vL2nGYo5scw240V3aXOyllJ8ujY3nynseyWs3l5fOnVHOfv3wu11eXWtjV1ORY5YD7Zr9YEMqLLKTcN1IUsJ5pdXUC+myPeROqYXhY44vAilcahOaoIZjVRCMs4tl7ptbVpJ/PL5qMk4MChF7kMnqzx+yEg2N17dCxrC7beEP95FUNr/cvvjswdwEoD30nL58/k09+fCmro6Wc3jvldyPGTEg+ccUCrPV9jJrvva7ZCYX8OEdj3yzWh+KhaqWEj5RNo6sbuFrGveijN7n2zR3j332V+YtnS6MyXL8X08RUjhw5cuTIkSNHjhw5cuTI8c7jwJT27SOqX2MBLYV3rsBLBJyRPSZqOANWXOJv4BlLkflwrbALXsR4yQzLdHtaUtAPMlnS6+HHovaP6uADX+UD2OIerOFzXgQw+AH7A3ridpuoxlzlnPIJfT9RHbMNurRauaRCWpeWBU9ja2co4GZ/GwkiHBzoOavqVRVnUAAC1DR1Q/sL+CpHFWU8FweaE+Uy9o3Cg6kq+i4VuPcSP2vbAnIQPiuADsTCJJ5eCCuOkwgvUqBFK4Kg5nYlb9AFW/EwA2qHHR3R20RkOFXK360gDcvzzVO7tBeAjfq+6vZRoW1qZQPHClS9vQZ37O8OP12xHK9b7N8IeAx4h2ugliRTla/LMrW//O9uI43PjUmSR5MtWsDs1YXs0j9EKBt6avK5A8Xxgbo39Z49kHne9ko+vH4Hf05tD6LwOjl4qnBN6zaGkRDv3agMv33uqf3Bnb1j97ber2on4x7LdYVCcTUtTFS1PC0gGf9zCHwIldM+nwLxCej1e8SKIfI/AMnUw/vW7tJ9HfRt8l/cwTShc5diNZ6B3RN+zWcl+wCe8PPFXJbLhSyXc/YJbRVg19L3BK5QNOOFpFko8EfYylkx9pXZUfi8efuemf48TCKkc8x0rjbPf/fSTtX1BycbztfHVWL9FMCzb+uX3dradR2taXAfuhd3GCMTNXaqLE7uoWTVShTWH8zh6TVN/JS1T+LYSdXxkxoD6e0U5tok5fHKOSNHjhxfH+K//+//e/k5P+fnyH/t8fl+Hn/9r/91zvUvX758V/b/wQ9+UH7Tb/pN78q+c+TIkSNHjhw5PmvFcgdItt/LAKUrl3lDxQb9GhTH6sur/riKDSpXAIoI9K4DPVwdKsNbV3FMT2VnJ2WDYn0zmTd7aQCXZS83m7V041bW/Qb162SY7WU76gv+yloUDV7HCj97AGhYZwBMVKXsd1DD0f1ZNYlQivaq4kXzyxmWb8N3l6aTMlLlrNDDPSijihNtdiCi/qy0oqiDaUOAB2CSgL9QOEPxVvFk9yzYBnViNwzS9ZCADjAJlXLcywCfWsCsesaihcu2klXbCFxLaXM6K+itDL/Xhw9P5P7JSh49fiAPHz2Q4+MjaZpavacH8+R0fmH9rGJhBbnwROV54fj0OC2kqmuFoYCsZieBPoimHVrgiYWoBgNwUNBRVWz+uFQlFlKM0GobfDX7kqCEdf9g/zP6slJfZV0qr5672KdDfe3+uFycMVEwm8LTlOawn5hIZzEu68YgLfZXStkuiFba1VoWx8fSb0rZDWsr3KfeplRq70cpi1qqeq7+0PRGhkJ1FMG4r9sw9vezRqBPp1IZ/y8pHqgg0pMZBVWXULDPRvUnlq6nNTNUmpDv+zi8ywUBtgzz5VzWm6103SA3661cXW3k4nIjXWfKcsLwxFrD7pGQKLFCmCWLZqoyOQAnK3rmB2RSBJ8zj+URn0kBcLDJMD9au+f041qwTa/3Pnob+2oD37+BQipdTaEaFclJ0geW3UZZQ/FGB2jMR1hbgm0B2ux2HtoX6AUqxbkz3YseU20u4NMLVWpdl7I6buV4tZDV8UqWR0fSLOZUj5bwWsbqAUicA2RWlX24+wJEvA1uo1xaZ8loD+EfS2GzgmY9L+vXSdG/2Z15wnjvxp8TC59ku2kCzvaT+IJTWe3AFomtZi6L1V5ee+N1OTo9lreeP5d+GOX8/EI25y/VLsLnnlGTOdWsop+92s/oeFN/XyvqBwWv2dDQw95tMLDKIplP1G7IQLTNaRTsBsiu3z1+TTTZgnuimoDlOCZx++pSAtzz7kOPz7Norc1xVD6zSJ8WJGxbFE+tOd/gHjw63msh1bqRxXIps7KSsS/4XcVxZf7guiqGN1Qgukwcwlk5gPEkSYlz5rzuxSVN/e/XJHiCmN99sN6xpC3rEegqn/0O/aEvKutRkDVT5Rw5cuT4byL+0T/6R7JarT6n+/wlv+SXyJ/8k39S/vk//+eT97/qq75KPvShD8k/+2f/TL7kS75Evr4E4D2SF5/PCYwcOXLkyJHj8xYsD3gwDTgDYFmXGOOhF3CA4AbPqVJIzU9ocT5duG9eokGpZiADkJoWBAONFrDKvy73Qh67FxYmW2/X0o8D6+dhm2G3l2HUh2VdauzsBg/PWmxv5lDCKkJpUSYUdBIZB31oL/k8DTihS40JA7m8XgtN8fGfD/WmocUDOAA1oVkEYu70QBGutwVMsZxRvUYbBRaqUpgAaIL292S6CllxTjs7DouHVQWX4M/rivYVwx7vF1I3hTRNKUfHczk9XcnJyUqOjo64PB8qQgAVgnLlVVGpZ/3DPhtH+oGqyg1wcS81l1nb0u1QhCxFYanqEcBdIZwq9+hfENSMqvzF8WsDH91EWkyQY4zIagrqEnVAG/Nk1eXrWtgPSkfipUShGKDx/nBZP8CKeSsHlmcqxRJJBkAonCdgdiPFfiQgq+cL/g32Bno4V01ie4W8ZQXrg0p9JnC+PF4seqWF/HA/IKHh0FyvuQLXRBZovtMcibAe4ZtQgWM7zUwo3IrXL1wLWhJgHFRURqKIWN8PsmVBPyhEIyhKPxtU1BOVqimcHQrTXiAqKYPTbwBPPpZuq3HDFhwXDndNCoo+MMUpAmA52Ja44ng39e5OwbKPO/8xacfEb1e38eMHeJ147SqT1NkIli5B5ZmoRt3/HUpl2DzMF7W080baecsETFQt233mTQqK2CQFwqb4eUx6yrZyZf5dth3eH54c0D7TzXzsGV4npDz8bLqP+F569eKaAv9bVO1PDHzVc8JuKb0fmnYvp/fvSd3WcnJ6KkcvL+TmZk1bHc4jSHTxY+ZrPitYJJW2PQbcadlDCxIdJFQsw4KGt4Em2/ScbbwyOZIqmpOxYU2NGnLLSQXLC092xLGCY/HYNq4I0TWTF1fX+GoQG9s+79eYE9hXMyZ4sGqANj4l+qZl8ogWQ/g+Su5HtstWQai7B/Zr4DfkaUxRjO+gmSY5MVeoLQj24VYqU7DMH/aWQ2X8JMDnKfv8aDdmWB3yChl/jhw5crwisFoDnvs53t2+gCWSFzp+u3j06NHnzeX4TNr9+RJ933N1Wo4cOXLkyJHjM4t3/G2Px1g8KPuS6Jn7znoxK5WXJTDC9HV4CDcIG1f5KigkUqG3cS9NKXKynMtqPpdh2Mt6M8jldSfnV1u5uhlks93JtttL1++lN8VyWnxIFbQKfByYKBgwUDmBPh7AOiVfCnjC2waciNBTU4zwZyrQ6LNsxZzgv1zpsdE2KMRoTyBQPjbS1PAVrdRTFPARbqJ7eO2O0qMQG4GGyMm9pTx6fCqPn9yX156cyaOH9+Te8UpOVkuZNzVfR6uFHJ+sZLFcEHYBDCvsUsVwAAsOL61Q2m5QtZ8v+dfrB89hU2oGP93EfiJWikpeqpaLP+3frqAkqzDfXYd14RhQekMVqmOnCsc3pbKNLVd/Ug1Nn2VkLSoFx7fML1LA7CrPuJm2wy0waimqWgoojZu5VIultFCiHq2kXi6kXiykbLANjmk2IOYzPdqSfvybPszDKHv4VMNvGarEWWVqXFzfgtAMtgC0EDHLmEDAbFyqajd66EZQltxHBjupBuZ1g2K0lqppCLBw/bFnjCOA5oF+2e5DnmLF5F8Taxeza8E1KKHEBfxTVSmu1cT25mBZ/+ExFFDr5+L5af/rPiNU5jig5US8/v6attd/c/VycqwkeRKTKLfMUeKtHbiuJ7uSe9vsSXAA9EdlAB+rAQARYT+Dvvd5hCsm3MQ6rdaW/gyqa86gE8uLeHbpmE79lJNBPBnnyf54Te481VS3nTQrgdW35sODCPe02wP5ji0RVNbSLuAtfCxPnrwmH3j/++U9rz2RB2f35R5A8+JYlnOspphT8V9WDVW8ahej1kYHTQzWDiF/YH3oPvnp9Q/n5FY/rmSejFXMeeolz4KcyRzhFjOw7+BrGMwmyG054sAmMMYYKHUc0IqobWlFhE26rucDKewwcFy8v5jPmfyZqqRtPglWSAqY1ZbH7j2bx5madTF/MhXrpUGS0Dz+497DjzQnMLl8vorELYnC+Puv58E/R44c705g/vqZP/NnyunpqTx8+FB+0S/6RRMrHig1f/kv/+Xy437cj5OTkxP5KT/lp/D9P/bH/ph88Rd/Mec9bPPrf/2vD5/5rb/1t8o3+SbfJPwO5SvmxN/+2397eO97fI/vIb/wF/7CoI6F+vX3/t7fy32hLT/iR/wIuby8/Jydx3a7lZ//83++vPHGG1T3fttv+21pUeHx7Nkz+ZE/8kfy78vlUr7pN/2m8gf+wB+4ZbmBY0C1imN8r+/1vT6Dnhb5O3/n78g3+2bfTObzuXy7b/ft5F//638d/va7ftfvknv37smf+lN/Sr7oi76I/frVX/3Vd9p8/KAf9IPkJ/yEn/BKKwz09e/4Hb9DfvAP/sE8l2/wDb4B9/tuxKvajb79Nt/m27Cv8ffv+B2/o3z0ox/lZ9B2nEMaOEecq8cf/aN/lNdgsVjIgwcPOF6ur6/fkXr7e37P78nrg7HwXb7Ld5F/+k//6WQb9M9v+22/TX7AD/gBbN//8X/8H3z/T//pPy3f+lt/a14ffB79h0C70Paf+3N/7rTg+9vcB4j/+//+v9n/2OeTJ0/kh/7QH/p16uccOXLkyJHj8zHe8dMkHsmh7eWDrqmtHEyqyjQFkQo1uNjWgVUoBqeqQf4dKjYZpZBOmmonp0dLOV4spO/2cn09yvnlVl6cb+Tiupeb7U7W3V62vcLl0R6qp2BZFbPRL1kfzmP7VF3GJfGuottXUkpFRXOELb4k35bLO1xOgFYEsSpgBVQua0q2CZShHMVPfKZpWmnaOS0VFLNrcS/8eQtbDIBlSjb3cu/+kTx5zz15/T1n8t43Hsprj8/k7N6JnB6vZAH1ZFPL8fFSTk+PWCyqnaOYWEu1LVXHdo6u6A5KPKjnaO+gAJvXryrDK56PgXiL6NecQGWo36iAO4TKBrwoXlZlLaBuCpU9EaE+0Xip9QBf3DYW4XOwrCphqIUBpdLrpC2MhHMXkwFhJbxfILxKmUFlCLVyM5eiXUq9OpLF6anMT06kXa2kWS2kZHE2XTavxQvVO5Zwma9RIRSgcT8oXOZyd4w/qNM1eYAx4IUR1YfcoBXZnILiCFpl4oEdluj7f+bhTEBblYTJTase2/jJ4pX9yAKWgMsce0kEKOtqT/t94u9KP+EEcJnFiXsgRz/oZDzYJfAjEBYn4ygkc3jPYJxpv3oxtRQmp8kNJ4ZQa/ohHGj7NBP/R33yMqjsyubUu3Z2i926xYTe3yyqyXG957wFsNzinpvXUkMhTvsLne/cS5tw1NS5KQiNFhj+cwo1ExQfwF5QLx/S1slvPt7t3juE4weRuj/fBaxvKaynd/7Evzd8jpMeiqvWsliu5Oj4WF5/43X5gi/4AnnvG2/Io4cP5cH9MzleKVyGZQTgclk3/AzuYV8RMDm51O8+TLY2F1vxPh93U79mFVS77XsAzbzf9JrquI1e5g6W1Rda72fACCTfgr2G3avsJ1obNVSsAypD0TQHWJ7PucW2s4KrXc/9AirjIR7bajFbu4fS1RDWDs6H9GbG+IJVCO4Rt1WxtOaBAF2n4ThPxERgMgoSwBw+B1sN9oMnBt0mJIPlHDm+vsfv/t2/m/PQP/yH/1B+82/+zfIbfsNvIJRM49f9ul8n3/ybf3PaHwDY/pN/8k/kh//wH074+6/+1b8iGMb7gIwIwLx/+2//rbz11lv8/W/8jb9BUOcgFwm5v/f3/t4EJH75l385AfSf+TN/hi985lf/6l/9OTsPAGEc8w/+wT8o//Jf/kv5YT/sh8n3/t7fW77sy76Mf99sNvItv+W3lD/7Z/8sgS8A+o/9sT+W+zs8DlTKgMQpKH8n8T//z/8zwSPgJ1TG3//7f3/2hcfNzY38ml/za9juf/Nv/o08fvxYvq7xS3/pL+U1wrl+n+/zfeRH/+gfLc+fP5d3Iw7bfXZ2RnCMcYDjo9/Rn3fWO7gjPvGJTxDy/6Sf9JPk3/27f8dx80N+yA95RwVnkYz48T/+x8vf/tt/W/7+3//7hLo4/8MkBcYswDHGL46D647fsS3G+V/5K3+FYBzxx//4H5f3vve98st+2S9j2/BCvN198I//8T+Wn/2zfzY/9x/+w3+QP//n/7x85+/8nV/ZdiQ/Li4uJq8cOXLkyJHjvwkrDA9aMkCZu4MdhirPgs2CPcUaC1KoSD/l6BOhNgr6pIzH7aYqpJ6VctTWsmxaqcuKamUAuJv1KOtuJ9udyIbLjXe0kcBqfy7xNasNSTWAVENPNZoBSEDtacvc1e9SfS/RRn/M1xXtBp6T/+FD/GNwnEpbQCaD6yxcB6AN9SXgAJbLA6z68mmDZQq1dW9Ylr3fA0QMsqk6ubpe839U7mfwOobH8yh1XUi5FQO5KAao7pgonLU6WrF4Fv5HLXw91dcW7allP4PFxRA8ZMM5hGJR6DVXdCd8MPkfamFJvPmO6qprVXLuzcsYUIbQjl7NaJuBMUKLYKIcVIbqcz2pYGXWJXp+6mGaLs9GWInHQAbRA66Oj9YSsd3T5f5hFCQ+z65n57EBO2uA5EaqBn7JexnhbG3nS3Uyzg+qR7QO/twJZNNtcJ5QQ6vaED9no4HpsJ+dKaa9WFnkoOph7MXxkkQIxziuu33G+kOtNswKxsE9gD3gPDxjaHESFbST/wEepKC3e0nV0fq2e9QGtX6iUlYbGP8Y4G9sf7D4MKAb9LHhksdr74xyAn/DuIjXMQ5hs2gIh06U18m5uIjYt/FjcDsXGAd1dSzk6W3w+1qV1Al2NZjN5IjPET4W07HGcIuTYP4bCR+TSJrASu+46Knhd+fhNYr95uNqNrnWd3hyT653CrPjv7UpdzzkeXvCz8Q2gaplJJBKJufmi4Xs+pEqNqiVN9tO9hfXpsg1X3LaxaQJSLW4wJliPucdad8lvMsniur03k19iJNMgV9nV7F78oHJKh8TvmJgOn402RhdP7Rv8Xd8aUUVPL4zsJhht8N3wGh2H6OM/SBdt+XDIA7U1LDOsRUZ+I4IyuE4e6SlTaffNGl/+/dI7AW6Mu1fdc1i4dBpAsi9xM3vnwUhD5IFOXLk+Hod73vf++Q3/sbfyDnnC7/wCwnI8PtP/sk/OWzz3b7bd5Of9/N+XvgdkPK7f/fvToiG+Ibf8BsSJP/aX/trqUaFWhlwEXAYCk2AQXwewBcBWIv/7fsdvsN3CPtE8g1A7vj4mL8D6gLuuZr0szkPKGh/5+/8nfz5+uuvc3uolwH68P6v/JW/kkplvOfxs37Wz5K/8Bf+gvzhP/yHA2BEAFT+n//n//l16utf/It/MdW0DqgBK//En/gThJMI9AkUroD4n23gOgDOInB+v+W3/Bb2O2D65zoO2w2AfX5+Lt/v+30/+fCHP8z3vvE3/sbveH8At0j6AiZ/4AMf4HtQL7+TwFhN4//5f/4fKqYxFtEejx/1o36U/MSf+BPD74DDeAHIe/j5YCwjEYyx+dprr4W/I3nx6e4DjDcoonFcfBbn8qVf+qWvbPuv+lW/anL8HDly5MiR4/M9PiOZEv5HWtvUspw3smxbvhZQTqJgEVSOVABHwR41eIAz+BcehgcojUVYhwiWDPudHLelvH5/Lk9OV3K2uierZiXPXm7kY5+6kLdebuXl5U5eXg1yfglrDLXEIGCG1QRUX160zfx7qQgzOAGGycJFUBBD3QVgAWiEQlJQzPJV8Sd8Ns052cAZIKAqoMOSfQBl+B83tSwWUKu1tCSAeq2GJcW8lXa5kPlySRUxQSOUfbRIUPi7H7UAF9TLsPy4vL6Rl+dX8slPPZNPvPlUPvXWW/LWs7ekH7cyX1ZS42NdJ/u+k2Ic6D99eu9EHj15KPfu3+MydKihvWAalIFtu6C6DqrTmf2PdGVBpVodVPCJbcwP2VRxh9l/g2xQ8sGTeeg72Q29jHzBy7fje/raytBhm63shq3sxk5kP0yK9/HaYJk5VNO0lAC81Rc+s8dr7LWgYKrEtGX3AiUwXjZko0FJ4rccPBKmy+iVikfFonvswoRFqkbq+VKa5Ura1RFfAMzoI7Qf57XjeXUiu55j1pyH+YKFCZe/w+SkgY3GXAoU+qtbg+xqQTL0UD6jjaYod49p+qLq+FLlYmHKRag6tc+wf7wwjv09V0FjPFY1FMx46eeC6tcigibtHP1PFZET25R0tYHdJwpQDVKZVQSLPaKNpmhOFbFh6f3BEkFvR1jREFjZ1ColDL+JUtutZ6bbsC90+UOAhw7c4pBwZawWKwzi9aCkj0p5vIfCbEiczZtK2hrzg77vO8S/G1gheBLEfI/1uGauHRT8qUrZPXH9Z6o2jqpl/bfel5o289fUAkbV4/5SixYf1Xcpl28f4+Bvh8rlRMluRuFU/MMiBqpjVf7DMqaRsmnl3r378vjJE3njjdflgx98v7z+5IkczVeyaBZS1S2325el7JBsSG1DgnIYKzzUEoiKXvyJCTvdJlhXhMQUN7CCsWqXQ+9us1vxa+tzNn4G5bPZQNDWxiwp/H7kShiuhjHAbP/WMa3KYvhtzxeNzPFz3jAJeLO+lqurSzk/fynr9VqWi4Ucr1YybxupbUVIGKe0jpLJGNE0abwevpKB6viwAsCLsKJdWhRxP2KewHt2j3N+UAW2r5jRwoaaCAFQLmY7Fu2sSvy0VQbZCiNHjq/3AUuG9Hv723/7b08VL+Zmj2/1rb7VpJ+gIoW1QRr43T+H/UGZCaD88uVLwraf/tN/OpNw//7f/3tCPlgOYIWHB2wEHCoj3vOe98inPvWpz8l5ADLjJ8AfapT4C+2AUhqBv8PyAwATIBF/B1gGHEwDquava6BNHjgGADj60gOiEVhlfC4i3Q/gJhLAn0l/fiZx2G6cG8AqrEKgykZCwVW+7yQAdAFscS2gLP9//9//V168ePGOPvvmm28ymYAEAKwwcN5XV1e3ruPhmEaRQhzzM4m3uw+QRABMxsouJEp+3+/7fVR3vyr+t//tfyOQ99fXfM3XfEbtyZEjR44cOT5vwbIWLFIFFtRYKC7XlKXUWMLrYBaAKCyvnkgNFU6FJe2+T+F+lljST4KqD/brbSc3m610AHI7BdE94bAD4wjB4nJuB1dx2bxvdwjN0iXy6eLrqGp1BaZZD4RXurRe/ZW5zNlgXrR7cJAQFWdBJZmKb00pBwBws97K9c1Grq/XfOmSOLd30OJP6kmM5fktFYIA2t6mVE0al/5H1VtUvrk1wbQht+wNEvuB6FuaeJeSuniBO/xUewxaCaRr0u94hf0QVEabCF+afQsSJ+q9KQCztqfKRjtu+M/2easAXKJA9ASA2j8Axiew0EcZrVEiME3HgvatASB6z0KhicUArlC38e+DP4qA7ZXqFadq4tSKhOpxe7mtiVvBuO+5+1pPFI7JrRidKSYa3+n6+TtU33fDyvQ2T6wV0mJ/h3PAgTr+HQklX7nNgf3DRJ17V9uT7dNutn/7SgaHzGrL4kMwYPADtXNqgpscI8x3qR/BQRtcLRtGbELcrR+notRUmn+g1r912qk6OVX+pns7MMpILEOCBj2ZtPT3qFxOxz79h2sUOZzLarmkFyK85SvAZ1rEHFrYxAHqViuh+279TOeD9KY82E/6ueTU/X6Y6POTS5DaI6V/jz7FibI8sY2B9RLmemyF+xFz9nqzkR5JQFqpRP/66FQxVa/rMdLGxHM69MiO/XXb53xqERJPP/19shPzrWahy4PbKEeOHDleFQCTn2nA5gJg+W/9rb9FlSYAn8NmAF3YJKRxWDwNcy4FEp+DAFjEvA3rAgBEfwEMuooaKlP8+3/9X/9X+Wt/7a/x7wCjKND32fbFOw18hx4m5/FdcigCSe0zXhWfbX/iegFuHgYSBQhA20/XbijBYYEBVfof+kN/iFAf1hTv5Jxwrf7SX/pL8uf+3J+jb/P/9X/9X4TwX/mVX/m27YYNBq4druXf/bt/l/+GR/PbXUecw+c6kCiBvzO8upEo+d//9/+d0Nz78DDwnId+T185cuTIkSPHfxNgedHOZdnO5f7xSh6dHsvD4yM5W63k/mIpR81CVvVc2rKRmh6R7nVsteuhsrKfKGyHn1CRtU0l94+X8sbjB3KyXMrl1VqevbyUjz97IR97+lxe3nSyGQtZ91rMb9P16nULWwJ7iFbAC/VgVFaKKbf6fpDtVgsq4cEfYE5Vi1osCfYB476XYdzIsOtk2KFAG8AdQAP22Ug5a6Wu51okrUbRJqjWsA+on0XVonifxdRqqaFQo0dvPQG8BFalhM9C2k0gXTbSDzP5+Cdeykc/+lS+8qs+KV/5lR+XZ89eyma7kbHvBc7J82omp0dzuX9vJY8fP5An73ksq5UqPHCu6uULr19XRZo1BhTZSeEx/o84W93t2kYWGmThKqgeU0A3QTFBNeh+w3jtcS3GXvYjFM1dou510JwmAdA0HKvneVHpnLzUcsM9m72wl1pv6OJ4VSjSMiN0ZrrE3tTJpgTc7wfZ2cuLdzlo9uGvQLmlqrJu5lKjqB9UqKMyLPcJr6FkLWdUsWLcwuMYqvQK8AzesXVD72ap5yLNkUi9kj2KDaasPvgn76YqX1+Kf8jFzBrCQTKVy10n3WYrm82a46MfMH4H9hXGs9topNhYvZyj97EW1jtIKBjADj7KVI06BHMvVrOQmfSz9aQfw2Cbe7Z6osMhe1Q+Jz66oXuid3qA/m5fkHgJx2eWtH2u+LwbpGpCxK+/bx/3B6gMmFyjYB9esOihnY3CzGCV4X2aqEpdkZ8WR7z7pf7jh/fTbZXxq0h63EbnE72fo4d8eqzDOEzIvOoYh1YM2DfuMb/X/N/qec4xPoP9TyvNfCGn98/kPW+8QfXy/fv3+DBUW+FS3Gv6+UMfaRsb9qAbUiqWRGFRRawUsGswMaMxBTRWr3D1wE4LWGLlyvS002yeQuTdOBPcOlT/7lUZTvWz2/Ac5ACCpz+LXM64YgV2RNgvHoQvLi/lEx//uDx/9ozjaNE2spjjhfkX81dMxE28lm0lwPTh2qymzPpHLYlGFg31wqEDag2Yl/8YVoTo3MLVKhibbokROmI/Ue/PkkKAOXLk+Pod/+Af/IPJ7+5Lqwm0uwO2BvAYTgO/Ax7659xn+Y/8kT8SvJTx8y//5b/MbVN/5Xf7PAC28b+noNj9yEc+Mnm5tQHa9AN/4A+UH/NjfgzhH1Sm//E//sfPaRsdrCKgwMX+384iAl7MqdoX55EW/Xu3AiD3Yx/7GBXAaQCUohDd+9///rfdB/odKlwAXtij/P7f//vvPCcEAHAa+J6H+hfWEPA8hioatiFvF7iO8DWGV7IX1Xv69Onbfg6Ka1ivvCpw/FTF/07vAzyHofAg7FPgN/1VX/VV8lf/6l992/bkyJEjR44c/02BZS/Uh2XgbVVJU9X6s6zoiwxLCVf2ptDCFVn+Myz+pgIaXpSVLADoypLFjzbbnmrlm21HlTJW8+KheVKszwv2JQpSO5p5Nztv0s/oCw/uuhU/Z2pEL2akFhpqo+EqNoUMqjork+XVDnHc3xa+uZMCZLTQSNR/9kyvash4bC7BZ8G9QrbbQdabXm5uOr4AxR2OsJBYYYXEWiy/bvk/5lB0T88zKTZlimKe5y0P27uKdZmXr4ONoPh9RZjq1gvA3VYfJwX9QiWtgx2EwloOVPRntBFIANlk2Xzi8+oqYV+mf0sZm0DHW8pWA5e2TwVzvoTeLQ5cs3ngweuFKINKPcI9LzbIAn6lFikL6sOpRNKPHM4rtvBQspkWzlNwxCX8BPtxKX+43rdUwClkPoSJ0WB7er1ThXf8S9y33WNB+TkdX1PFfPQ3Tvd+oNM98E5OrvUdrfNznezzTmXmXdv4dTjYb6JW1sSTqZbtvahOj3JXP887Wa5v5/fJK6Gz78/2STEpbIL83/GeC5/XCTRRDEdta3IL23aeoJmM5qSRaYMDujf2GiS/kyKMd/7OeVATLovlknZAUC/jAUwTjXovhOOnUuzD80uvl313HPbv4RgJoynMTemddPs8Q1dymopFH4Mqe3JPWtJyMpclxWvNSghwGUtb6bNsfvRqxYHvkYPTC/OTr6Lw1Tb+JZX0r4P0ZLWIzvWarJ3salL087BLfbAeRHrcHDlyfL0NWAT8T//T/8TiYlBWQh36P/6P/+On/Qz8kgHhYB0BOAq/4N/6W3/rxKMYoO7+/fuEiSlYRoE+zJeHFgLv5nkA9MEX+sf9uB/HQmxQvsJvGJ62KNqGAISGShYQFErmn/pTf+otqPrZBoq4od8AhmEVgYKGKHL3dp7BaCNesBH5aT/tp71S8fq5DKi1AZfh04w++Yqv+Ar5o3/0j8ov/IW/kP366RIP6F8AZSiWP/rRj8pf/It/kfYQDtFxTihs93t+z+/h+/CeTmE5kgTwhcY2uK64ZigE+U58mnEdf+/v/b28htgPrvs7USOjDRg3+InPwj4FBQlTq5a/+Tf/pnzt135tANVvdx+gCCW8rQHN0Q84X/zvd/Rrjhw5cuTI8fWweJ/qp8rZXhqo+eAdCRVyOZqCD0XnUPwM/sf6AAwwDD/koFwGUJzt6fHbzmtViVZLud508qkXL+R628lV18kW0KwfpZyVBMz4AsZjMZSE1FyZqiz4wBojgxWHgmDzDp6iigh5zJsVSjddyozsPzoEoNAK/ZkSkIpk+K/WMykr/Cz5E6o1QGXARkAYbOtWIYTVXFuPZqBw20zqeSWVFDL2M5GBltNyzLpwgxzN5zxHLNHCz6OTI1mullJASdcJgc3DR2dyfHoky6OFlFVBlfLNeO0muQqDrC5YWarcdoRKkyq2qJgsa8CQCDfgAYxtHBBhIyYI8Ds7T/s+Luf2Qm1egY6fCKo7KL4xNiZ6Z7OS4D5LeIPuZWTxP2HfTggd4PRMl8LNUCyPimRTOjqsZTtxtWORwFlxCOuiEhTHTGHbDMrhFLUQAtWyr0aql6GSpCiTULGQYcBYHKRk2zTZMKsq+szu4VddtiJlK7NqL+X8hO2dlY16wJqdCZMJVPMajKavskM3809GX/JSqJpaT0WTBvC1ppfqMASoyWSN+4qzqGXqImLQLRTFMwidUlkDV1GNu5+AtGnyxu0XVPXuvReZcKRnqvqM42Tycd+Kyk31oFDvZG+LgTb0B/p/8ilXH0flcjy+tf2OWmQRgMfz9nPye4HJrlkhbV3LajmXo+O5nNw7kaPVgupUeFjjI7i3Y4osgvLb0DbCeRTU1HfTXJ6PUb1I0VLkUNHr29rfeS85oE19k+O+tSd8hPs7r1JF30XFJe6Pv+Lec78dgM1B9sXe7kuRWQUf/b20q2POi+N+Jo/fcyU3N2sZnj7jypGeyaQp+A4JIB99mEO4f4Pr5vfNXzAeWJg0bbf6JnP47M0zO4GqB7mRyelGcI0EjdYH8ERIdPzgmgV+p+17fHcZiEYSCUeywq1IEG6QEL2+5nR0+fJclssF57bFvNWkqS2/pb02j6+txL0b4W5A2ZZ8VQU1yst6W0f412O+dWW1zf3qIQ0ltm67u/PE7XuR/tLqWqTfzjly5Pj6HoCt8IhHgTrAQkDDn/JTfsqn/cy3+BbfgkXtsLQfUA3L/AFNAUs9MC99p+/0nQhE/7v/7r8LsBkrWgDWPteWEm93HrBm+BW/4lcQBgIOAurCl9kLugGYAp4CqML7GZ8F9L3LDuIwcN5QosLm49PFr/7Vv5rtAkz9ki/5EvnTf/pPMxH76eIn/aSfJP/iX/wLnh/Urz/35/5c+a7f9bvKZxsApWj3L/klv+TOv+NYAMK/4Bf8AsJlgN0PfehDbD8A/qcL9B8gOEDrs2fPOD5+xs/4GYT1CPQxCt79L//L/yKbzYbniPMDzEVgjADi/qbf9Jvk4uKCPsW//tf/evkf/of/4W3P6//7//4/XjuMURR0BKBOEx6vCiQ9oK7HeMZ1cusWD4xvtB/FCJEYwf+WeLv7AEUDAcXRxzhPQG/Aayipc+TIkSNHjv8WYra/VbXt7vj/fekXEjaiKFFdVVQQA8R2/SDX6062wyDPr6+kwzJdK5Smj+QAiDvpTDEMz2SA4NfuH8nJqpXXH96X9z4+kxeXV/JlH/uE3HSdvHV5Jd04sqgaLCUAIPFig+3/+SKk2sEyljR3yh9qLmsv5AgP9lUpR6tK5vNS2rZi4cG6LmW1qsko94VCItSbwyGqqpWmXqh9AApVFTNZYikz1GmVGFgupJ1XVK1WDYrgwV+3ZbGt7qaXfjvK2K1l2FwRLBzNodqbyQw2ZzORm15kO8zkZt3J+cVWaSCKu+129OjEEqvV8VwWq1Y2l71cPVtTpfzkjSdyfHwkX/ptv1Qev/YY1JiAliA90BaFx1Wjlh3+ttU35DnjfyQqTFNAN3ZbGfttKKIGUAK4SjhIiGFWALb/FKqp8jouxQdERZ/hGBMfWL64npzwaBh6WjpgExRDZOE6Fvmy5fL0Km5k1q70d4BbFjHTpfAKmcuJD3WQd/IHVIRud+Dg0lkzsiEYMGrjIbtOdt1a+sunMnYbuXn+KRm2m6AuL5BAqXW5f3t0X4qqlXr1QGawvahXMqsWIuVcpDpiP27Pn/LaX3/tv5H++pmUxciXFzpD/5bVXIFRDfhcxMKCBgsxpioA69lMtv1WrTBgd9LDEmMnXT/K+fmV/Kt/8+Xy7PmF/M1/+GXy0a99LufrXi62uGau6I5IEZ/T212RY7THUBgXwWsEmsSVJgr3C+4CUy2KOSaALgFjBpYBrai8d3ml7yb8jGp/HUPRpsPF6UXSJtwjUIZGn9n4mipqU7hsCvVEBOrFEtFKloSciZzMW1k1tTy+v5T3PT6Wk5OlfOj9jwkI33j9PbSeuX92yuTOfHUiR/cfyKxqpVwe01N71iyo2g2TFM/b7DcwQfEegpI9KkeDCpZgX8e3lT0NjZ1qyh2a+j3oCTRXtk6Juhd8i/3gRhJ3KPmnWvTk5y4q7g1uauHNHQtvMgljSmsZUYhzI+fPX8rHvvKr5fr6Rr4a8/p6I1sbu67f1mSIa2X1XFlYFMrnGvN0a6s69N4FUEVfDn1vRUU72a63nEu22xuOgZqFUnU0Y3++EkX7FSs8oh89VoXA1sdXK1R1JfMVjllI2fjqGyR3Ctr39D3mhJk0LeYrjNuabYeS7cXz57K5uZbrywt58viRfJfv+B1kdXQkPbpkL/Ls5blcXF6Fwnqcd7m6wzI/vJ9sBYLRd84BKAZKWD7wbtKifPi8tY2dhwSdKdMTm5Aozjdob0lCrlChpZQeD/cx2vNrf7eq9XLkyJEjx9ctYPsB2PsqSPv5FlhlA99heBh/ri1JcnzuAlAdXtbv+zl/WL76N/6w3LU5cuTIkeM/6/cPkutv5/f/jhXLppdSUMT/VF2JB9NhjD6PUCpDBQxoAKjMh9hkabI7F1AAVwh9k5+fr+XF1VperteygQUEC6opJFEbKy+cB2Cs7ZnxGHHJL9We446WERV9nlUReljY79AOAtAbOmyiSVOoEaSV8E9W1TIVy/RDVmWc7nu6Hl/RnBf527GhMyibsawe6maSXVV6Qm1cQQk97qWdQyq90yuxU7iKPp2vGmnmlczg/7mdyRz+paensjpeEXzQ32sAWIbCTzsmwFWA9wEQzuw2TBGM5f2u6FThY8HN0cUAHqqetQXzrqQl+HHFMv6Nz3ihrVhASndr0MrGhQodbYmcXXe2lTBaQY7u1+AYrimAL/yneS6QXPcyQ3/uFdhBpYdXYI3eDh4XynkfZLYBlah3yBa9aNVMFb9UVMOzFn7U7dzGlHozE37hGhWa4ABYg4GJw6uopMa4AIBeELZBlb+DJcZuIPyqUORsXxMcq3I4Qm/1FUab1Jdc7UlwZdRfNdg4mMd1t+2obsc4IOCF5yr+bZYZ0Z/YnQZc4RsLCoacEhW13p1+z5je1fyHfdhwcysqR2FkuASHqkdTQaeWB4llgx//EAynVjO+YUSP0bpj0v54SPvMoVI3iQDbIq4Ne5/YeOBex2qFmioivuCz3tS0oIHFgRfwTMci+zntrPSwds0T4wr7/xjfidKZ54/+NJV98MNIe++uE9f7YmJr8M7yhsl+0v3d1Y/WFtq+4PauZF/sqLLXxAHGbC31YiGnZ/foN39+cS1VWcnF5Q0TOYPBXrXv0XGj/ummvsW3S+njXYuoUj3vvt0Gm1UtHBcxTOaEpLhi9EyOY0p9s7H6wH2vzVpih3lmLwV89kv91vM+9YJ5BNU7U3DPZtK2czk+PuG8uLm54b18cXmh8/jxMceQrqBxWxsvxqn7RXpDx5IlaNBe3sOWXnBHIE/O8dsXczYSSHp/hoUINqccjr54q9v3ofVpOK9shZEjR44cn1XgoevLv/zLg6XGfw2B4oSwo8hQOUeOHDly5Mjx2cQ7BssjH8TV5qKkClEfVgGRWTCJhZPwMrBsy+51qX707VTVrBayw/PzzbaXXX8jL9Y38vTqhtChRCG0opId6ZUrK92T2dDGMEhhXqTBCoKkC0W3VGGsyjVAPAc+5qPq/rwEyzwbA2oqz6RAtppJ3SiMxRJ4gJGiIILlUu/AZ81mQOEAlMMFwTiW3+3wOdhj0GNzxvfRipIe1bANEekHeGHsCDLQ0Dlx/F7qtqJqrpZK6qGlb+n9s/uyPFoSyqJ4074fTbEMPoLjA44qPNnBSYKWGOrxuXfv58DUHfzp9VUY6UpRK0yGAoykKuNU7WigFoA1RGpgTUUcVHRe/CuFiwp+cf1raRLwpWpUqh6xDB4WGHDzYMEvSsV5XBYLg/UI4DOvqSuUFWoRQJm/sfJJt3XAvr0BkZ3pXwCYRGY11I57qecLXr/tzQ37Wc97pLq9HAcuT2c6ws6FoJvqafSZSNWu1OKjnsuOxRlvZOw7jo96PzeluMLEUNyxwPkHOsnxCljs18a9qDH+xhFe5BsuwUP7WLRsHE2BGRXEui+HYtHrOPWKDdYYdh0i2E1Uw6n6N7mWWgTQFKzJdZz6eCdur3bPcIWBFZ5x/+h4bAfiZn+QWJsE/OX2Eqn/c+rFnYDC+PfbQzXtE293nB+smF8LoNxIu2jNDqPRQphQsbO4X3q+EVanKafQz5PCg/43JGDQDwZXg88vxpMrbVOom+4g7efEz5t3xQEInsThPtL9pvDa/0afibg/fhdgXrXx60tR+DeAWlhiLOXs0ZnMl3O5eHlFH/5uO+pqDrNpcIsXDbXiYaKME9i0yGZIddgKCdqrECwbXGbSK66scHhPNbKBWM4fbstCRTCAuMJlTS5iLOqKCSiC8WKahys3DPpaEpPWJvi8FDJfYGVMzTnv4vwlx/PLF+dcYfDa0ZG0KOzKRIT1oIFlTSb5chK1IaLRU0j0qKmHF8CMfYB+0+/cNBmkdhhxfo9X2NaYWO0AVW3rNWUy71ZSKEeOHDk+/wIeu1/0RV/0yr+jQOA7KST3bgXEHyhy919TfN/v+335+q8xjo6OXvk3KLBhv5IjR44cOXLk+DwDyz39cOGdST2ZeimzqN5OekCtRC2pimWHkab5cg9XV0CG5cijDPstlcuOaP2BWp/ltUgerB7UzkE/DPgTPCSdaRpIraEoTJbXe5ErfykcUE9eVaF52xTwweoDqmLsBzBWC7VF6DUphHewzFx3o57LoehXUE8DchBBA0/rkmtAEvQnFMhgBgbEAZ/Rjn21l1m9U89ieluUBBMjlpSrabWxHoAW+F87ZzJ4SIm3t1dlwwkDjH1jxeeinNn/7nBDoa+qK6NqLqhNg0dvdJtVxasTXAeYEX/5cdIuJABmGw0yB6/faTu9P2ORwNRDOfmVfZPqOOPRlZeZr/Rei+7tUbStqnlswMNxNA9Z84j2NhEOs0/c+9SVnPCQ1n3NqkYKWIr0WLKu/U7gzs30p6sOYxE7h5AY72bkYcp8eszCgmbYESrrq5Nu09NigNYCB2rc6T8MBN6h9lV4FYsU+jhWRhXVtf4+kzLmdzsF0dN6oAHyGcxLwTPOD37rvN5EagZEkVC4BYKnyudwn6WDOfyIXsq3nX58kCgs1F07ntPj8uU1IS0hxZnJrGKg0uW9SAucw+KRNgBvd35yMtG2ILnRQgG5OB4OgK8B80nhu8l5Rbge7sCwbdqOTweaD8N3mABu7tOgq3svY6zSKtzmO/gO11bMb7VgYmR+ea0e8j2cfzRp4nOIt4vfGyzYipUCgL5qEOJTSMghHbYR97F3ffiesPve739X6yb3g/c7cbZ3eXh50Tz/voPlx0x2AMqYuwnW1QcZVvhVjSSE2qHcbDYcG0gctv4dVhayG5MEgPuQW/KNo9+gOfvReyhMZJZDhQWSk2J+oR5c0wOBe5iyQl8rhPZlQ1xF8pkI23PkyJHjv0C8/vrrLH726f6e4+tPfLqx8MYbb/xnbUuOHDly5Mjx9T3eMVi+3m4JVeZYDg6BGoByP8i27+mL3I+DbLHMGUqqsZNxN1C9SyBMgGrF7WAlYUvkYd972XVys11LvxukgOIzED+ANH0QBiNtzNIBcEeXD5ua061bafWgMHg1V59Mok8WTDPIaxBILTYAMACnHYDqCzB3sWgJFZu5ev/CXkOZpqp4qZ529SdFhVCYjSqOtaXVVDPCS5dero2qO4uG6lbA9wJWCUUhTV1SmdptOj79t82K6rd5s+DS+9k4yLDdSjkHqFSgNQx7qv9QLEtPcU/IByDWNKaMwzJukpLKQA88oWtFLTD9pEJXl5MDABVjYwDQ+8hAPbxjudZa1bL0jAUD38+CV6gzDHNYNkBvqmgfQKBObJJ7vDpUVpWhQg/smNWkDDGiyN+oxf6sSBb7ANCWY4T+IeZLDEsRr15oZMWA5C0FqTM8BkC/KZypkB9ktuhlX1f0dOUYIuzC+FSAX7Cd5tG8H2QGUjZz9TUKF7Ycs+XiVAvtdWvcQdpEFPACiCLQVkU4kxfoZ6jcregle84K1/ly+G4U6bu9bNadXJ5fysX5jVw8v5LL82vZbAfpCeQULnOUm7dsBOpRhay3gkN22D7YcnxT6wav4kT/6UUc1e4FgK2QsVC4uNt5VfAELFtRR9z7VGh6kUUWNNQijlCu03WXDFQTCgEOO9gL7dfCndwO73BcxoKFUd0ZmxBsNYKCWceHcm5TwprlAIqS1jPY6ezpjc4kjVE5zmFlydUUZdtqwqBqCREBA9mWpKVp+DUM4zQAWe8r67vwuyZuYqTF/Pzvh5Eqk327w/0d/ryrxel+/Jqmx0X/IaFjVjWwk8C/sVKCIBj3uRa2hPIfY/zh4/uyOZrLttvq+V9tpdvCOglKbZ1DjJSaanhUJTJWW0D5b8mLqXVKuL21hQGe2ioVjl0kDzzxg8+HbEMcCwbxAWs5D2EO4XdPNBbnmoYBKwK02CjzdfyuaXkM+u8XlSwXR3Jycsb2f+rZubTXG3n42ntkvlxxrl+0KAoqshXMF7jX9D6YrBKgRQiSRPBUxj1gNkR2wrthL2NvhUdtkg6AOjh+mBLcpnO1+DCjIvXZiQlPqqPdqz5Hjhw5Pn8Dq4U+8pGP/JduRo7Pk8hjIUeOHDly5Piv0QrDHlbVR1kLL7EAkSm5qFa2ZbX0z6R6UpcKE8yYxysLlrmSl8up1UID+wOw0jpGCVAyEBnQSlC9Ri9Wt850SwE89HNfVrwtbhtDwUJUyrFdDsEBkKhStsKAqWgwVSsnyE2P4wpbA2h2XPVjBURMF2knKtFkeXRazAwguyyhnK3YJoJtkgIAAj+vqZCRdhIUER4uu7cTcNToKrig9k4Vfumy90RR6cpA9u1UAUloY74aVhpuqsgM3e3quLdx9bzDG5Y+yJMCc4n482DMBDVfUG4eEj8v8maAjxwpei1Dgsj+huKby/N937vwAgifEYzhJ4AhXomiGiCfBfgUJCvgUWjsVhhUqrM5hRTmvewWMqq0ToAaC21pEb9u29NnGYrl7bbXwl+uwE278BVi1FBX0WSeKbZ09f7UP2LScwfK/Xig8NnpxTm4kHEnfh8EG4Molk9U0rE43eFuXnVeiXjX+tDb4+2NxzpUKvscwnkA9yCSScl9TzAelMrRk/rwhku1x/H3pOVx4kp77xUdFu+J6f5SeeqBwjn8TP+e7tUlvsn188lkcs/eOoGw+mCyX1ddM4mG8a8rPuBLjXHetrW0KJ66hZ87kgSqzNWW2D0LIBpWksS5KKrwoy9wsEM5VK3f2X8+f6Xv6/UMYyP4D6d83cewzc8J4E6Pye+eEr7xtcDdqOu3UvSD+p7bqgRYEmFcaZe55cmhnYoqwMP3ixHkW7dySKb4YpDphYprRg6u1+1hcOsy58iRI0eOHDly5MiRI0eOHJ9zsLwdVBGL4noAKyiUR2vN/V66HdSSUCoPpli2AmL0ZlZlri71LaRuSqqK26alKnc9bmSHfUolLfZLZSaKpqn/I0KL4jnwUVVnwcJqUBwrbMOyZAdCDYrCwbZjC9iGdkbww+JvVDp7QToFlVQoliU/C1W2F4eieE0luroEHlDJ4JKGPrCzvTOFCFAglzPAE1U6ovAe1GkbK3CoEmtVALvSDpYiKgIE0AJIVl9naUuR45oFotCutml4TFh5AFRUhKDob1W79QOK3e2loSe0wl0v4of+ohDW/DX9GqY+u2iHrqxWaw71JPVztf7ifg30ELK4ClIhCK1MAiQ0yOESu4SRBTgTIGfar4ZZXDVIJfOoSuyguDSYDWmpWz8nPsUqzS1VuUv4G/1M4dvMIllQZe5tG2wrgxTlnCCzaDqpAHO7jYxDJzN0JFTaAJX9mhBaIZrCMC0wWIvIUqSopJyfKLS6eEYvb2Rk4KfLU4Rfd2H+wijk2CykagCatRAfzmeYDRG87kW2m62sb9ZyeXklz5+/kJcvr+XNt57J+cVGbjadDKP6mYcKmYFEHQC1wByn9gn/f/b+LEmyLUnPxXT3ZuYezTl5sgpVxQuAV8AXPlzBEPDGVwwBQ8AQwJlgCBgCZkChCIUCPpC4wAUyKytPExHubs1uKf+vqmvpNvfILF6cKOKSS09aerib2W5WZ2Kf/utX0YbLsMwUwUlgbpYxbpGBp+m0YnNP+X0ElPfH9TGjfacWMTnZwuOZJUA+kvnH0q5gJzXf5x52PtCv1y+3svGkD1XXgMZYeeCjXov0bcXH8dDTu+8dH+/kdFIPXVrZoDAo/g0fb/h/Q3LthvHpvCEx4/fy5nM7UhuSN19TJEtQGoekSfJiduVp/RW4vL+mfSPG64nnismFDFn3kdJk2hae+MNGgKqR0+MDC6DCcxmQWapPcrnMMs6VzMuUisDGYpIN1cyavFm2yZIPVmyPCRwkdCZZ8X58FtgtRXiv+udwrclqxPrD1hqmLVKiMxbOhIWFJ8hCH7riOfL0DYUee3l4905u14u8nJ9km0Qul7O8vODzpJGHh5MsVEaf0zXmhCjGpO7YWRZ1Rba0kt67buSwFKv68Ks3OX7iMxlWHADXUELrDhUCeCTLuDzpGsLPqTAPPQn82jKmRIkSJUqUKFGiRIkSJUqU+DUVy/hCWm8szjfhC62BAKqVAXf5JfX1VmVVXroK0+AwrTHcu1ghGMESvW5BWQ067lSS4Uu4/YVxp9Ik0DDLjCTr2qlwE96KCM3sAbJi2aGyniJ/GY/qMldP5x3W6sVJqwvzOFbFMhSPjYLJtC0/3J9v4w5KbIV3Cuzarpa2g+8zYLFun3a1oBayUnUrj4wChkmRGRTDd2phVbqpCj0VMTPVtBeK0wJPXvgwkJSdFtO3lbt8TpXFuT/ueVty0g39ob+pMYHDpV2n7oBY8p/160pvC57NBsFeKUFTk6h3dHoRQay1XZ0VywBCsHzQ03mhOlh0zLItk1SgR9so2wbgDxBsUA/jlB7LBxF4N7OIpBXgg9rcOpyKZYBTFujL7c9bMIboelaqledZpmmS23WUqz0uN0BlJBfSqbOy25vz3kf1a+LOe+T5FYVjetoSM46D70Wt6Zi7Y9j4TJA6w+UdvLtXICeFbD6Oe8a+Brf5zX6erOTM/05Qz612MN9gc9B1+mg7g8q5X5NKOc6v4CGfri0kW+4vLWVCdmD/noxHUhoBsDdI+P3ORziY0NwdM1zGKzL/1nn2M/WrgyHsaLAtE2ob1Gxcu5DYORwGmadJhgEFEBtZkLzaLRC+7ufCr0mdvAPimsxKFhrp+bcGdVyLcvvqDpNw7fc7Nmytyq+7a487qKyKZRRKbWVGss+Sl0j0zfMkPZTb3HliSZ9d4cm4A0ATOtnX3K7J3YjYT27xofeh679n66ygIdche00cmOgXK+jpa2gByyVKlChRokSJEiVKlChR4tsX76NYbJPn601u0yLbjMcqV0CuaUpWGKnQVgQt5musPsgNv3y71YT6utYEshP8Ya0AIL2VDTrBo5UezQalAbH3Rb20kB8sNeZ1Vv/kBOsWVTbD39aU1IiliWoxQIG8DV5hd/7yrWLE7Q5AmbqNF4Lv+WqXgGJsyzJRCUnVMyECLBYA529UdOP6cC3wGKXFCKE64It/2cc1Q5GtoLJvOxm6jh6dfTconBBT8y1WWC8xE21nHA8K64YKSwBpVTDzZZS1WttZO8NnE0CY27a3WuZ5lBo+wiyI2Jp1gxVrs+KNCaKbr3G9aoFFeM5yRzsAHIoOmi+295b/pF2JKe888aC3oGNnXWGhokkM9TQGdNUhW1VQBzvR95J1AZbYdaYOUsqfB7RB8wwGoVi23+GdimvoRqkBgTHW6SO8yTyOUkMh35/NUxZvnUV6HH6QrV6lbk+aUDicqHjsTu+kO76TWiZptpt5C1POKThyGsfWZn3fc4ycX86q/LfbGG9QTk8yjZNcb5NcrpM8XfQxQa0cMZp5grO7zbbYYagqOUNSJaB+TxIkWxdLPpC323HRDloc0Um2Q9IA3HbDMRcoi8keOpAQgiVd5q67VElvRcs8HZKSFgGIuZ2MvyoR6fxwAbv/jpkAdTR2KnSNSN9Uchw6eXgY5Hg8yIC5BgA6dASjnNtRqcpLMwjq6lbOy9AIr5Ixb5D9nYWBg2ZX+P8J+v/qWDFZs08AvTrnn4x4naqmf+u9OmQNYLIKpb/G4bL2Ffq4gR91VcnDuweuXefzTZ6fX+TlLHK5VjKbGt3zU97WGGf0Cvb13nZ5pASNrXvJTijxWvcMx5xV73bt+q/D55w89OSV7nRB4gd/ARzG2k5lPtd1s24JSnN8tsHuY117/sQdnc8vVPX/8Nu/kOPpKJfrSGUxCnAu+BzlnMParzZQqXWRmLybo+qdjw8yKJZtB4p++NgcMR9lb39PctmuEldj70dMVj6XKFGiRIkSJUqUKFGiRIkS31CxjP/f5DrOMtYGllHAb4aP5Jw8mLOCMBTxCl/mqVYmVI7qLFVgAQwnWEQ/y+xz7F+TVbGcFbOOClw97WBat8+rmk0VuAqa1XogbB5PgkMHx2YBQEjhqsmsAM5Fz/T6HWQkhR1B9qIFvao2QzsqumdZVkB4B7NQv7piGNvpHegpxMUW57qyAoK06eipovRN0VRIr0t2d0iADvegBeG0MJyqwx3YZiCb+wfnAI9nG2HT+TzJtmFr9SJrulFV2hIeB1Uh7gN/T2wJal8odNVT4267fRgRBpyTUtcQiiqk0RfuO61WFkqcALsBpBVqpf5PI8LPYzAx2G1k8LiHjplOKTBFcUAqFZte6haFxDBN9Lq0YCGg0Kj30KJd4WvdydbegJcM9DRa4A0FLGlzcZRqraVmYTJV7OOcOh4dqCoQQ5/hXPBSxvn8OcAtjC38bRoXuY2zXMZZrtPCnQQJDwUFZOzuPCdtHASWFHBtst5wGb0nTRxy6bh7NbVfwamkwAx+zAqq1Vc6K/NdcJuVnMlbN+5ICFB5H+6hHUGzqsz3imC/YLU3oKe6FRSFHUbfwwank76HUrmlUplJsBZJjOwrH20Q9PTmFZxdXPaEPdL+v1fsPeEj8Ly/7fyPPyMt393/n/vzXhUe5OLhvH4+H1x3iRo+pX7jtBCSTYbDgWvH6eEgp4eeBV75OeAF89jXqrB1q5xsteMo1CGpF4IMU3mnXPadH/e3lcGq+0lHT3n/TwG2WS3xc0O91HUXRwT37pWsO24wXmD1BEslrM3j7SYvlcj3P6y0AeHOExQjtZ0++lnX7HvakinxwnOuwT6k8gdf6BtLonj7x77ceVHruVKL+taIEiVKlChRokSJEiVKlChR4luBZd/iP4EwQx05z/RGBtACVCZ0iUW7TCGoMFZhGWAalYtwu4C62LyUAYpZ+gyKYioAX0MU/EbbjXHk+VAUiYo2Xly2jyASYAE720qcCgbqFmQFVQoCvG4bVXUo1tVm6J2+cKcCWxls5a3Z+bu6A7QBUKpp5NB38kCQAhWw+oX6z8Re1lWmaaYqLhYY1NfCN7NWwEXvzGClYApdhfmqEsbua4Lytk07peHVS9WyXagC9k3WaeQLqNLeEUZrlPB6qPQcbAS0kwpZaVIgQx231lgA8msUHgQY123hVGYDwFiParE5h/YK191TWf9myQHcI0kmjmlcnTQ9Jy/02tQmJPM270Oofr1I3r31gKtiHdDYOCYQAyxezdJioIJ8na8K0+dZFpwfftrNRvV+xUJ8s0iLn+ZFjf4bDtIdH0SmSmQcdzQsM/ugwDVLFCQRoGCGUpkWGKMV66NieebjMi5yHQGWkRDwHfzRCzcAR5PtvoUg74tSKpzW+bJ/TZiTNldTAbXgm70vlpkBtzquRDgXVaRvKXqz9Yy5PbzxWufKuL8IISOfcwjoKQlNS7Bwn7kT02cbCvxKzAbDEjNWlK3re7OiCSdNLFPny46vph5xOJsZuL8/iaj/ZLylGvZ/3xfJfG2C8aePFVTif1/u/aYSerVWjCp2L4yp3s81itctjRxPB/n48R379dPnZ45pjG+O/7uCkSoKt8qKlhRA8kaTTbMqkrlMu/2SWe2YQjfaPHiSyf8fa1IE9lo7kIbPmNb8HFjTuga4jHXa1tHK1nJdjFK780jcAaPJiHGaTP28Sd/BK7+nIn6sR9qCuO0Td36ERMu+dww4792hQuvnRKgmVaH/Vn9p91DmLh9fs33uhl0oxWK5RIkSJUqUKFGiRIkSJUp8U7AMwIW4wjMSqknYX+CLsWmf3tLUqc0FwDLAqCq1/GszgOVMuArFrSqsAK3pL8vCeijKpsXb0tZoqjbnDJaTp6hBRpU0W+HADFjcugLXg2OkHdHJXxVAFr6qsOlQiEslKcG1Xour29JWaSpoMwJwiN4femnrRo7DQd6dHmgp8eXLL3rdUCt74TJey0pYSNZsbeXbsAnd0QRNr1YWpmLVolJa2ImF3lDoCiB8U29oqJoViIgs0ypNu0prcAXvJcy+XQkVengA0PPTd00bKDboTXiN6zD/ax0DGVZ40oDM1iEu4D8gTAVw3Uq7dFQ+U53t0Na8tAk3UiEt86Kmv7ErT3FsABywqSX7QHO85C3fClZUIc5+ibYYHE8Llc7+OnJNB18OtrH1HHBQBzuRY9X1tEtopps0/SALONBNq2hh7FPg3Pa0v6hQjI/qYvQbkgKVVPTqbqQ9PEj/8EHWa0UwnQagKT8V7BjoQiE8a1sUuJzrmQXARofKt5tcLzc5nyc5X2Y5X2c5YxcBkzLZ33rvia2hfudBxBuAks+P6PEdZJ36es1C5KRA6vNgEv5ay5nWgmRFYbsa9mRWCz5moW9QwsZkg1mHZGuLoG7e+XfHObsHshzu3PngpdBEGqjFOacXzieqlvtO1622ka6HNQZsaJq9SDpCe4L2DOXTvSV6nHcKpDp1O/XsvvX28bXfM1i+h+j/H4Hlr8bXzpvTTNoGVjgwJU3wY6HFjgCy4kpRULRt5eHxRBsZrLOffnmSl7qW5+dLsljRCWFJQAPLPne18RQu0z5i1cJ2WUlu83jTxGOyeeD/NLnlKmVVILvyWe8H9kTYobEtGA85OcnXLjOTIsuiyQXuTDEwy/vnmqJraANf6XmT23gTfEwSLPc9faYfHo4cY5jXVC0TPOM1+fMxJnV0bTDFvI+nkADS1+v9sw1pEaK7c3A/fD9Tt7pjKNlEpSRuiRIlSpQoUaJEiRIlSpQo8Y3BcmN+uMlywq0P7llRCv1yrIX6KmnNS5m+u7bV3v2OfUs6v6xz13U+kHpGavFAfAVPdhmhcFciUr51n1/UFZBGJWbELb5p+B4MJ6Wn+YPmynz+rnRhOyUir79WxfLQ9QS8gFDrqipcnh9wFvfNrd+mkCOUwxf+UAyQQBt3oSozfX9WiHKbdlD6RsW2W4sE3a6CBSqPce6FMBs0pmsG9Qs1ug7rCapw31KO8p8GypJKO0P9DUrRVIDQ35YLb2mf4boVBKV+Sq+N5wkq1QTMstfq7nXe/36gCDCDolx9lwFW9Tp3Sl4CZfR3JRVAFdqPcm5AMSgtO6lhQQIlOWwRDGJlz10vJAhltNt2GBA1wFx3vWxjq3Mm1G2Lw4nzYSOPSrYtbguC9gAkGseJiuXbDT+xY0DV4Vk1avB1pwiO7evgOY+Qvbo4AtMMD+PF5qKNGZxGaLxTjO9G0h5KxvP6+fQwd+rkPUPL3b1jYns7gFikL1yOCd3z2kQbDO5YcGuMNHK0oCf+jt0M9tCdELbjYXdv97+80oTnawtAPDPC19f7ut1eneT1e2JbxfEf291XP/4a5ju74P4A/nPfd/pPG3MBwPoxdFeI+S8DLmPhQ/LCiqMC3GOtPB4HjnMk9GCrtIWdCFwHky8+3pv7eL+a39+z2ijpJTqg9XGfx6+3fU5O3o2Xr+H5oILO1itpViSPeyTXxulCAA1LDHijY17D1kh98vUsPub9npFg8rVg39z5+uK48bmjlku5NqAuTdlmyOdsFT57YrHdEiVKlChRokSJEiVKlChR4puB5YEF2TZ5XjbaNxBypi+2Zo/qCj37kgpwg+3kXd1ID/UqlMu0O4ASFdtzF6rCeGyoih1Qs8gcjqYeujPsNwiyUeDPtvjuLA1yuDIXEALFtFQ5qoWc3LaBW6JdWehqSivKR+UX1F0GlRBaAMoAQDqpH0D9nF3p/P79g7x/eKR9Q7XAA2I1xfYm7dajBBy/0EOphuNCyQyg1feDWXAA/mb1I0AD4TyBLJTA6vWJJ6GmpMUIi8GpjzL9lHEfDiMAKlhcEYo7qMxnuT6/sD37vpFusOJ+bSfQ9E4VVOhuIRGKG945SNBeRLSoH9TIbHtyzXWn4MRz4EGuDlSttb3eioMpCNe/VHELfRS0shMqqdoMYqg3TbDF4BNe7G9n8S8DNO7RTOqi6vYsODfFLNSotGuBytrGTLNIs87SzzdC5XW5medwJcuySeuF7dAfG6SJrcgyGnA1j1kUXDw+yjJeZLaCW7BISfW13KIFCnQoJperNjM8Xs0CBNcCoPzl5Vk+P73IL59e5MvzTW5LJRMU4WxLvxcF3e41/haAc/yqou/8msiyVB2ci+1lWL+3EMDb1aoFdjZmOWLPJz7GF2lSKRY8u78u3xmxunr8jt6mXAZ7PCdYImBNqs8dfMwwlImuppZDj7WpkkOP4n0ifQvQjOfxMqjQN3riQrncw3sZc8UV/m8tPpFPvr6zlIRJMDC91ts2+tK/1WPhPr4Glv9sxHO4dchXEerdc66Ovr82LjSpiJwOKrPA4O+drcOdVM3MooiNYJwv8tvfvJfD0MrPn76oWliQfNHJOaO43Ya1bZG6BnzGOu2fL9q/WIXS/hRPfFlSEyp8T17qZfraZAkYrGvWHrQ6fwPwq6e/rix86yvv59dElp8H3UCV8dPzRW7Xs3z69EWevnzh/Dj0Pa10aFcBxTOTeo36LlNhjAKw2PWQdxJwXiC5VLk/ue7Eyf3iSdUM0H0NUEsM/bvnSlXNjR0pq6yzK7lLlChRokSJEiVKlChRokSJb2WFkf4VlLBeDM5RFLfi6xd8hUS5QBYsBZJ/blJTKR7QbenhSHYcfw2LyZli2bcDZ0VjVjruuYDRJyu2lwuZ2Wt5jXtfWVdBe4E/3168jwhW/BhZ8Qw1WofCTTMKvblfrr3Toaar/LJkVFWRAMJJDXoHqAJXcrC6U2MnUBeUn+meVAHrD9ppJD9cs7QgpAbAabJ36JuK0CTzM8Wwgg/1RjV7Cmvf/M7gyxxFkAYvk7LV7CC0TW2TvbcDX+yQU4FOah82ZW6sXbPtEhABklo7uspQQZ/BKVJpkxUjQQFY37ayrWijlv6qOtYzQFdwjesCGIL1BqC0QWGAoxZQDb6rud9d3Zinj9mDGAxjwcmkjlSvbiR1xmkmZMZPLwR5N0Ry20dxaXqV/3GvhvW2Zj/d1/pK88b7OaqWffyZCnnnZbsfQs729+rleB5/oxcsy37eb9zBG5GPm+8nXEBYZ/L6hB0ZCiIbT8yYfU7yXWfSxhJVd0mtvcY1eWSEZ+8VsfsUld3t2+2RFLZvj++vxdsK1DhC4tr9xpHD2rp/7jVYTsmIsC74Hz2JwF0AlkSifQ4L3TUyDJ2MU08Fc4+xvSCBZmr6AHC1kB52FASP+wBS/8yg0JfkhSStH7HA5H1P7mrivfrnbpblHJivi1Rn07mb146Cm1AtSxh3fmKH1clexqxl3hi2ocVfW8H4Sqte4+EKtz89NtybukSJEiVKlChRokSJEiVKlPh2VhiEkVbsqgI+U79iahDxJRpfhAHhAIGgTKQ/cC0dHlAumyVG8sZVkSgVzE3d0l6jufuCPUPVxnPAIRJqNhFzkUhfth2CtqJQlwretlXfYYcFLqpzeNSIwF4YiuQWNgdUW0KJPcnY3uQ2Ai7VLJrHbfNJ2mo+sHZYvxBXbOLB7d2HXq4vKLB25THd0gA+nIpz1JZCIYte9/Fw4PvneTLfTi0vpnq5AI0r4eu5lZ9K7GbvIW3nUmsReIWq9y7AJ2waoAFvDMwDMEMx3R8f5HA8yVg39GWm/zV+Yts2fK/x2sTGFIIThmOLO4616DCieQcFww5LjM4kJaEW8VPw4fYeqsTGL7h3AFmoE3FfOCHOQVxivtnwSpZ1soJgs6mCrWgYaajafchibWhSa26x52VYQcbYh16UEeOX165t7rv/4bXcHU8sPrYtY1J+qwJ7ojc46vXVK9TK8KY4i2y9SNNxsjRdL9XxUebLsxUBRB9D3ezSQ/V0hWe2K4Qx7pdZ+wDWJSgCdrlc5en5zIJnf/fjZ3m6znK5LfRX1vbZwyGdG2hn8wW+C09Q+GsV5vlN15yPqq52BSjGnfYhX0fxpCrFVW2fi4j5fWj7q5oZY0P9yk21m6xMPEkQfac92WPq9gCrVUVacUzqUIqqXV1f/P2uWE1A2JX4cOnl3Kuk72qqleGLfjocOBeHYZDDcZDjw4HF1vpDx4QR/JZpU+tjkspcA4RUh64JTO9hss9jHdd7Of5b0Haffnn9e3xdVBJ/DTvfK5uj2vYNgvrm+SIY37/WZk7IbelaL8ESo2lgKQNV7iwVivgdB/n++w8yHAb59HSWw9NFfvlykfNlUiUt23eRBXMO6wQvYUtJHHjWcyeE3YuPDS3imj9H9gpfXZv09bkNdK3yaw3JT65Zuk7RxikkWFwT7yMzfXbhs6g/ilTwk34vbdvSI/3Hv/s7OR6P8vjuPdduFITEFSBBtI2zrvbmb491W4dOvh7NZ3gSbKN1UQLv9jcEffyh/ub6lPs2JhJc740+ws6LApZLlChRokSJEiVKlChRosQ3BcvqIWu+lOmPrlQNICHU/IqKZQed/MIOiwLYTCSrBXgvA+CaUtIVYKkYUVYp81d7rV6WqYqtKBFBMB/qCR0uNYuYg4oaNhP0GHaoTSsCgF8ATddR2wHCz/vt7l6TzkFxXamfpvolx2MY8E1KUrWaALDCtfD8FExnj97QBbv3EEwHxbW3B386sDMFnCqD7RlrvwQACYu9wKJ5CLt2Ovlv2vvre9Wp+xpn4LTXz4YhFBAZz78r2GbF67SDk7oxq5rt3pJCONoxRBW9EUzAlibbLrxSGkZeZmPOuLTJ3M3qxBMnUCyjsBvsMGwrOZgnYTphPqA4QPdsP1FczJTVpliG2pkqRiowtfClj4c1eoRbn7H4FopUQmW+LCxYOc3qs3y9opDfooplt55OKvIkZXxT9m4jMP3lrgXTe5OFchZX2tgJCDs872NTExwAfmHMO3D284dCgX6gvdJ874+dL9SKqcU7SocI6tmvCDCTqtQSZFwHmPSqOW8BAb3QKOEffre5iYSK+zMnQBl2JGTf3QBX/WZ2pQxj6I3t7353tft+efXe+FyUqb5+nf//fn7ej4/YsvHmvnbuoN5Ob83S9PSJEVS8UCsj0cNE3KFjUhHwfpo3eTpPUlVqA5FtJ5Cw4qSzy1KwnL2T9ztL7j3Dd2pkU6hr8iHvenCv+vv7U7W8/yUUt7xrDd+Vkc4B5Tvsj+CtjqJ/6yrX64UqbTPx4ZpbEaDPTCppAVJPgHnm1Js477JI5/Zp/kanq8o7J9DiCMif1fkei165RIkSJUqUKFGiRIkSJUp8U7AMvbBuuTebCyj+stsEw4GL4wT9Xuy6MCicRa7jaNCgkbWFWhDgppZmVkUzmSD9l1erf6bQjTyBylODvVYFDa/lt+IGSmn1pfXt1vRZhpqY/scNfVVb+Kg2lXQtQBG8jdWbOHNKQDE9AVSpuUASojNGYn7N0aPZvTjDTxQ8bOua/tILpZ6V4HIdklK/CNV0V8nQa1GruunN+9K8lnlTgJq1TJOpWqGOrUW6vpIalcesOB+3V9etCRgVxLt6DYgWym/cUH888G4ASdmKON+8JLiBY8BnWA8EawxANFMau3UwFXMBsqMtqFSHdtyAq1SEof4+hd2u5tQ2UihkEMT7FO/EDd6DNZr4jvq82L2ycCCu1fsBquJJBMkBzSLoxdW9vX/Wnybi3QEiNGy6PrQrkgsAqTj+wOttuwfZGvijojDjDNdv/oS6eJkmqTbYZuiWdwfMvFbAyr6X9niSdR5lHbO/Ksf4BlXmwi3zt9tIEAXVOKDbhcr3USZAZs7FSqBvnDFGqUzEfAmGCaYI1gloKuJ0k95nXuwR/WsgOcIxL3IYilfq+NCdCjom9PiYCwCvfBmUy9xtoAU341kT8LuzhFAhuSWurJhctK+IuJN/STzx7pndr3v4Z6Jic22341qCpusaGTq1sOm6TmEyveABlBt6w6MfmQjB3GgH2ppk0GhJLVOMZt9kVXPrz/lOsZzvbw9z33oOY8WbzudFLv62b4B4nOrPAGpfvfephX3C5i018/7atI/u33OfALLknz+wRmM3y9DKQTb5/jcfZDge5DpqAmUcMc7U7x07GTRx5xY4ejx+HuminxNUJjFmEkBXvJQ8S3eCOYm+33Sniit/vYgn5zyF6MGrnv2m67/bcqBgqRaKVVTsSTh8BtE+park8fG9LMOB8+N2O3NHCzzv8an2cBykaxuZb88yQ8WNHQHoj1i4Lzk8O9DmJ2/w6cj+5mpfhATdPpGw6y+33zDoj1oFutehoOUSJUqUKFGiRIkSJUqUKPENwTLUpK40pfVs8LVNOMOgRxR/6dZgfOU1E4gFHrQoSAcYAJgDCIAv4qogxsZ3B42p+FhSjFlEfqFvSC9Qi1xTLUOxGmwq3MfYf1dFImBuRQaoDARb2o1PAIqaqlGVoA6OFMg5KEtf/Q0qO3inEhlKSNo2wK43WwS4IA1PES4TeCtwAZBTWwkDJ4QWtW5ZNnkqXtq2UbVtbcX+ATjw5wCWc5Ez+kD3vd2DAjH6Ls+LentaATy3omBvOnBHv2mncgAQbaXuUdsDehCz/7wNcT25SN9uu7l3phXTSopP3L/ZZzgudfDDjsJfAfgAkHGddh8ZWAaolcYG4LB5R5u9wm5cKd3MnWJ+2LyMCgAbauNN2hZWFuZDvdYE8iqg1DascWxAIiqW8W+9d0DKuumk6Qc99HzVJIgBQm9DFO2axiuTC0gkwBKAUHlGAUbYkigI4pnYJ1oMUrsoKGOTItH7MCYBfDu/A6t7xbL3g3tl63t994DCOh0TCrqCZ6zNEjJob2dfH3zuJEW6n9LV8xh966v36LqTGfd+NbgDsQYQs0g1+697jsEdcX1uY94h0YS1gFA5FExDcgh9lwoC1lrocgfvkyo+F5BzHXIq8rm71mBd4Z7hCSIGwJd2HwTAG5Smu/5y4JwayVeh1GNfgYdfA8Z3sPnuZVp4TnZQmWMlXWYs+rjf1qGJOS2q2nS1dNLK47sTrSF+/PmL9C9I2qxST7rrhAUhYaWBoqzh9nUNjTYg+boT9DfLh92VRCsWQmnAZS8FqQk0vUa7cVf+uy1POhuKyOp7co1LHQ9ISOA8x8NJVljhCJJDV1mQVEJhPhE59J3t1rHPvFXtXXw3jlowRTsk/L//nvslrfv0ps7ad34WpjFlY8w/T+3eYkqgRIkSJUqUKFGiRIkSJUqU+HZgOWyZ9U3dVCzb1mIqGgkwM4B1n8t5XWSEf67DQyiXp1nmBS7KIvO2yWWEV+2sCluLpDBzUhSL8PH/3D4DYEB9Xp0D+O753Tb+dFR9PRTCh+FIRfN0g6+wF69zj2BTPRqoVmajQAowyi0FcN7DAA/Whoq0vm9lbFTJygd8QDdYGqh1wbxMBIVoi67rpe96GfqjbsG3LfbXq24Jh/+0WoL4XbvKF0o+By3wwbViirNbj6hPtErA9b1uZQBwSgjcApIpmE3boVlAEPBZAbEmABSc4li8Fvh/rtBgAnEqFMnbsw0uWjcqgFWrjQQWeQ7rC/fI3fLI8b7J5NeL2lnyAuejqs+gHWFztLxQJSGBtr/GQLlen3q+BlngPhxGW7tFhSyViXi4ah0K2808kgGZZZYKCksAIKoQMcLt9VQtq1pcblBWq8UFrnVGIb550t9hgTGvcrmMBMu3SefG7Qaf5Ulu0yITVOZWuE+9VM3z3ECVXr8XBvQb0H52+xeOkwbjXRMWGTKZejjR3AyMvXl8TPp8TnMyKqIxrvRFd3Yp7n0b4F06dlA5h7meLsO7JlnXaGJK4bgpqMPx1ILHhLIEyTpXsHa15rF8PHRyOnTy+HiUh3dHORwPBMw4PzzIJ/iQM6niC4nOfU9gbPSft4PeleEz9hqKWXo2LN6QeVYHiJ9JX4S8/po/B4ZdJex7PTyhEl8fFc8RzN6rpuNro8I5X1uCmVFBnsZELuyZHlh/YA2zUvfP+TT08CwXOR4GOR0PXLZuo6nxYatiP3395V3BzhzJNoDnN7xPPOHBddHrpqYkYVTp5mRAUtVbMiyOS1/nck7M+t/tm2y50YSdHhFzFHN4nC6yzs8yDEdbY3WdA9BWW6OQF0ufeTineiWrkl93jqinsqmOfSeIr/FpCUaiL+TQ0qIQhd3WAqHbSpQoUaJEiRIlSpQoUaJEiW+kWDb7W9/VbEXH1PJB/wxArGBIv7TTeQDYjLBsNLipKrp5hER4kes8y3nENv9VxhkQdu83qkpV+8KfFKb2RZw+lV4oMPu2uquCch5XprlCUpW9gGtd28rx8CBd38tFLiLbjV/2HSoTGCbLAAdiChc6B6VmZ/FwGgiVT8eBSrRLfaUyDYWqUHoQj3nV+5zmUaZpJlAY+kGG4SCH4YGKvc3UsnV9laq6SdN2cLJIW/h1y7e2CbaIo4CVFiGzAmoo0Id7w+sgtE1+skHxhmt3319YadDewFTN5t3cwWaC4Ez9TFeq4XIxNB7TgVkCy95HuFoDiqzYBpsSLfyW+sJfb1AtKXfJ56AsV5W3uynwvFQXKhAlwUHDJNCWLTZUIW2ADdAZ6mbz0d5B9LDNP4NCs2XY0UHb+m4KQuq2bQs+aRWUhqgXCOsUqJznm44LFBncOqlqKFwB03pph5MdF9rcijB5W2eqkvFvgEzaYGBevFzoqTwxIbGyqNnLeZTLdZZp2WQCWAOENrCsgEl3FKS+cG9qbcWg6FVFLnyF/Xmde6oSVRDqKnhXPFtiIdtBJ6iWS5kZNMOY8m33ZvuSCvK5ItMUxvld5rMdx2s2uA2/52tJYNvHVKZ+Otoxzt2CHT1ncBnDvqsrFu07nQZ5PPXy/sODfPzwTh4ej/TGRZJpRFG1eszrkPkE76AiFKhpF0NU44dkVvK7jiDXf78HulGS64rkMCZ3R49/jNrTt87x1nveoorxvfH1UR38tsf5/j25D+kqbH2HXQ3S9lKtSCyo/c7hgCSMyONpkNvjkUAW4xw7XGATQbA867YSh6eAylyObI2K95LSC2Z3IgEs7/zgk4e8jy99Sj+/3NdYX8fETfpMMXBPf3X1gNZrMVsK62vcxzSvcn45y/X8i5xOj5rXwhq3VExupOKr+QPOWD7OOVOJrIVS0Y6aiCJY5rl8rOT2ZpLRHFiSPYtb1di4cqisS8UbybUSJUqUKFGiRIkSJUqUKFHiV/VYNjuFLB42SEwQaF+8DVrlrehZ5QgPWNc9BkYsM4ugAR5kL2Z7a0YhAUbqcQ1eu4gvvI4CVgAA0q89dEmwQFQRC1Vi3w1UDUNhXNdQmmZVIBXLbuVxBx8IqQBwYafR1HLoOvWLNgANlTIA8jiPchtvVNvdDKADlgCgtTXUzVArD3I4AGZ1hKCAMMtUyULwYNulTZGt6jYDsoH5KPQ1hRstRPA7DBNcZFlTmW3UJCnVvM2S9jBZSCiqTADJt/irNNHax72id5RxJxnPBc2gujMAEoqsvdJbpoJdd4AsAENukydA9e341hbcR26gKVkQqAd1llDa3+iP6qrzDGGJkUxmmnSc3LaPtoM60syZl2yFoOMkMzZVd6tKXbMCbQKPddtLDUUzIaQd04v1UcFs8DYU77veRhmnWV5ervL8cpPrDQkL85j9Shsmdbr9wQFSbHPtKi+UaRYtaSQYENxxw/zOu1J7GRYHAAxwDasAf+f+eu6UrWGtSH9Ox7QxG4Ft+Lf6YN89HywwVCFvRSsNLDvIVuWoJoeaYIeBpBNtcqwQKNe6CP88ycSpYd7kyQ8hNFdcpLY3ILOB+dy8hv2dknvxuD/J/nyMe3MqnE/j2p8J8PHt8NdH+B1vJkP9dHtvHk/tdNL1JOiPon2NyIrEln70VGtr7Y92X2UYOjkee7lcJ/YBYl7Ue9+TPuqx7stSLoR6r6TO7Z/MTzLoT8kJswnxLS7haa5tqX9sLeRct/Ho/RnTZBwTeT7Q/7htOFcxh9UKStdxjDN1FfL+vitS6kX70owLa2FI6qh1kxYETOOaC0AcV6/7MtsyxfW2RIkSJUqUKFGiRIkSJUqU+AZg+bagUJntrDWrBXgHg6O2BpYJx6CmCoAZQXZsAFKFWOYaST/hlVue8WUbnrEKVZJ+Ub9QEzzosZJ3cPrC7VW5MiDFl3f8W40lqqTwwrnmGfCilcPhwMfj6b30w0GmcZaRhQV127gqwqBCYwU8Bcyu/qVmdZWhaeTD41H6rpF37060wMC/8U0ffppP589ym0Z5vmpxpvM4yQx4uHWybgApR/nw/oOcjif5zW9+kA7+vVC3Cgr+PUnXvBjhUAUvCw5CcWvQlAWmTKWrijpYKKi6b7yqspc+0oAbfS89vGFhmcBCaeqZHYxuzW4CKltV+uI8KnTT16sfqY8ItSHgIQxCu/pP4bRDC4W+dIzgPwxaU9mpKuJUsM3oES0iUOzM1aEuUYQHt4GlpkJhPDwwXsySwxIXG4t+4bmMrwF7FW6p96mZc5vfs0NBg996ewnMsLhjMxAqrxgfKAA437IFgr9e6b5sUCHjfTPGUyNbo568VYfifQ96zO6g42nM/sywYsG2eYisAZihWh7HSX768bM8n6/yX//wRf72xyf5+fMk47TKzGGANorC0QyTYrE8bUKzNAjKYSrV3cYGvycop+Mr5RYIq4Lq0V4DPTCBmvmhs4ifQWW8nFYGFSxm9DkqTxOYvLtWc4Tw5IR2gUKyLaqf+TMntvBCswfnsRLyt/GvdhjqnYskWAelNl+38tGax3k/tDIcezmeBnl4fJDDAcXVOqq6AQFbFOyLanmMA9z/pgpmKEtT+7iy1S0IvG88Q+VjegfuPSNgvzpc9kTSK6WxrkZ7Wwu0G8a/eeKH12S/3TTA31Arx2PH3996UH8cvKT9tTGRY+fHT5sHek81PdKrdZNma6QfFlqafPz4SKCP/r1eZ02qYK5ZwobjFN7q2JnBIq1IxOEzx4D8TkWt6w92jWhfqI0NxoKGJ6BMfeyJLx3QBpcdPK+puGc6D9Yi9qe2LAtosjor1hvYFNXS9zqGfvpR5On5IpfbRCyO9Xxoe2maxRIYDs1tHV4WHbfYFsLEAex03Mcen5sYfj53FYAzgXI3RGyJDSpqh8o6jKH6hg1PAcslSpQoUaJEiRIlSpQoUeIbW2GE0k+msNpvftev7ASXfI2/T+6KsBlQMmjilhnKbLPyK/zf3pYgKOGSAPJOEZkVmPEO9gBFgWujvsYoyOUezVEVxkdWZ0ZNIKEtt9I30lPh2PB4eBoqU/VmVS9leEnDZxqeyDPAA7d/qxoSUEHVka20HSAJugTKSXg296pCBs8wcO6XRshgRbCi6iw9XIVHe4ygSttxl2wnkJTKtrWbYMSRPPmXA5bIom1M+PuzQHDX5H4tSeFnL8qHycrN3bUCmBCs7KSooY/XV7BMj2HAiNdE9+E8StKhAnSNA8yhUj5LKjpGOm42CBWL+S0JVHshPH9fdVdA0FEqlJpQP8PbePUigamNDC6ZdzAgM5IkV3orj/yJJAy21qsdRTZmZeJjN0/CXI2/ptmq1xgV4gqXstpT7WD8fjK6z/Mv+xlHCxwtGmbYOAJuf0/gmbk7vOCkKi39Umx457+9EamIHGHzG1PY/zMFqOvUdyp7ex13BbCwp/aRwrpg1J61rjZn1H7GVe5Zqe/L1WtFr1uD7GP/l9TSwT4j5IDu3hPfe6f7DsmGPAr9bXcTdbfCWZv6O/4Ef/bXZAXvW/fkSRysSQ090jckhDQ7Zr7qtXRdS9VyD6/6ruNaGm1VQiOGn26HcneRXLtDn8Rkho39NN7jwpaYe3h/el1OBmRgGy9Lzcd9Lulnivrd084CCdRl4fgakIQ0uxZV7MfF2cdYLq6Z6iHeJwTiRpHU7Pu28FvNI8Z7zf2mS5QoUaJEiRIlSpQoUaJEiW8Ilievxkb/Wytc5jDGpFEArPgyPEMlRpWyQTW32jUH16hKTmAKfMGUd4YwrdCc+0Ryv+8OjuSiUApgFAT5Me3bOIp7md8y/VVrkbZr5Hg8yPFwkKFvWXTPH/H42DJv/8jb8aGwqyo5dK0ch17eP5wIQw4D1Ge1nC83Fln78vQkn5+/EAzCI3deV7neNlph9F2l52tbOR07OR6gdAZkBpPAtm8UsxrYKIDR4w2F/NyYwi0qcvDvkLAZ5QCIITRHX0Gx2Wi7XqdRYTZAB1Rv8AQW49wUYS6yQr0GT98ZntOr1K0rV5VqeEGs3BMAoPAFXgyemIexq483AOqgWHZvDvYflIkQRwNku6+vIU/A4RrwSYkKYaBtmQeY1SKC8LZFgsAaAspE+kHDE3emQtj9VYVF9AzIkFGpilLtAgxOO6SnjYUr11epWOBQbUrq/kjrFlitbPOgnMyAqfqOa/E8Wi9skB5Pel1UMIoWLesO0g4PfM94fZYNHr6cGarIvo1Qak7yfD7Ly/kqf/zxs3z6fJZfvkzyfNnkikPSVgD9BlU+1MGwPTFVcQChClQd/r6RLOLWfO0vtQOwfg0CWbUD8B4PBb9s3mV4G4GfPZ1gcsZX2fPcgJoXnoSqmdxebUG0+GcG0Z54SH7gr+KO8vl1cdhpCgJzgv211bLOmywNdhfM3LGg1iJIGGFN6Dg/MZ7zuDAbn7rWwnEYI7P1XA0VdCUVJ7FdCuX+ZiNjqnh6heN9qYCltoNefYClnO/eZnuovw8O7vBvT6QEVX7WcIdWuj+e9+0ebu7b9a3n3gp/TV6b+TvsL1w1bztYqqZj32MNbZqNBRTR9uO4UrHcvdTycn2RetpkRTLGtdhOW9XYWFW7HEevrUN2QJ2fRbrTQBNQnhSzAc9J5X3gPtoOinPCyZOXKTGD1qaVDeY6joG1tpa2P9Bi5cPDe1m/+wvpm0E+//KZivjvPn7Ho+Hetfjq/RxVKK3FQStZIOpfVK2siV6zp+J8SM72BsTT3oOcaGW75SKqWEvVgqQriuUSJUqUKFGiRIkSJUqUKPGNFcteoMs8alXxp8+5HygUvFo4r05QKKs1A2wyQ9S3CnO5AtYFy2H3blBU7lFHUpreqSOD+DKzNrOP0O3tDSG5btNXtaI7QetR99vwXdnmW/5xDMAnAPXWVHcAYrcRVgWmVoZPrnEKtd6tBJX1CCKt6J5u17eiS4BTKEZI6ws95UwoqjeUFHKpLbJi2dtNWbheD4v0cQu4wme8oiHlBSiz4lPwG17z74BfgMwEYLyefO+Ap7viiuzn7BO8QWF8BxdjIaxcECsRD/OGNtCk5McAes1jE5TxfaY6dasVWnpENaHdzwoP60WaVNQrqxb3EVWQ7oMbVNiOsVCQzwc6FMcsgAgYA/FlKxuBWR7Dfh/03qWa1dXT9NXQPsH7ms6gots8+DhRtTKAJ3xZr9dJzlAsjyj8uNECI7eWq/ptN0BkeWFu3f0zKRiz0jz3WfQ33/PFUPDLD3h/wiQAjkrgexG1X3OcaQ6Bd74edqQsAX4NlO/Um7s7tKMHz2m9ZF2/sp24A1/3YTYVqVb8S2rS/bnd412Lk7oXL4tlGnnPc3KfiIlt+VY/xfuKcPk+MZD107mVFZ9nu5z9q+9P8rV2y3P4jRTW60vdvSj0nS/i6RKMnppKWSGueVdj/RVVLOM9w9AzWTdOI9fHZbF1DP7qSKok+W5W2qdCdm/F/TiOuzfC3/PlRllyzMoEFfNe8B5U/2plwfXBdrfADuMwHFkQ8nYbZRgGfv74Z4mOu69Ae/OkT+fTHGsi0VnfnCd2SAG93RT2OcPPCbZnLspaokSJEiVKlChRokSJEiVKfAOwjIAfZoa89hVaFoA3KIIdkBnuop7q1Zd3Ax5vVaJP25r1yztgr3I3fS3facdiUS0VE+9+qmjUgDGhre3ANrWywtxK+q6mZ3ELBXO1ERBDJQyAAbsKfMnXIl+qPFQYpYaV+CLe9T2VdQ1UjVDSQcG4riyu9unTs5yvM+EjxMLwAYWA97gt0i8bt3nzfFBFUlkM6DuZfzVggxa0ItwCJLLv/MtsiskK6lezqQAUg0rXtoy7ahjXRMVc10jDLdfagmgXFBkEFB/6XuGGFZdDC9LDlFu5s2ivcigN4OxtjXMRjLh6/I0d86/Ao6uD7adtjWdhvASEKWE2BWstLcAc+hpQuampCHUtZyrWSP9kI/dm47HBE5yq3lGkRl/25vXqADQrJxMod4sOPrVXwycATNX1JlV3UAXzPBpUn+m5qyJKXM8iFYr04ZwNruFm59BjV01Lz1kqmKEYrEaOESjaYXcB1fJt0sdlXOR83eTlPMvTyyzXeZURdivwK6aCN8ItnSQ7y5Id44v+5wG+mhewb8tHP7AvUg2wXLjTAZZbYejcMvgeRLe7OR0AoxfEdJjva8mO46UXm7LXJjjgY4a8d2sIIaUXX7RrpprTYDHWKFgucH2waw5WMa4eRv9dr1dVwi4Hzq15mnh8qM2RO0H/taeTjp9Fi2Vq54c2jQ2ChIKNH4WanmwIFjc+DN+kzW9BYZdG+zn9oXNE92rEdvoT4HUX5r+eeil6ON9f051tTMzm7WxqckG73fvxkhW7TVquIW2n5zidBvn48cStFKdfBq5t0whfc+yGsSKkWDvQ3txhgd0LmpRhv+sWGWtXU+8n3l0Fb3ZNz2AN1eQO/mb+9fB0t1tjsgc7K2jZoX1FNb3NZ08Mcf209RJjCyp2qILxZIvPgXmRl5cn7ljBLg9YMXV9Rz/v6/VG26Q8GJD8gCe5wnQE8n3Y1UDvct8NZKptb86UDMFrzNbH5wtTjEl8HsdQiRIlSpQoUaJEiRIlSpQo8Q3BMovvsViXgiH/usqyWLBMSPYWFa0iEApiVW0bGQtVfG9wIarh+Bo9OoEVqKsfgVuAM7Shyhf4JDFAs2qgSlcLC+Zt925XoHDZi9o1BMhCwAtVGSHDCiCoumt9nxaR83ugyqxtpYHimfYbqqxVIDXK88uVYBDvIRBoeqm3Sga+BufC+Svp6O+s10QQSq48EHi4h7MrUXlq2+6tFh2Av/AezWrlLSiV+YASGteI7flG7QDWcF4UK3NvZwerRAwG2YzPJKUkVZ3Yg21aSHrRgsobGEtGB3axmY0Ff9oIZ4McXX2UAUgWK86n1+Cwkkpws3pQ6K2v18KCdkUqYU4+0QTVaNMZEBdEUQs6KslU92gWxErj2CPYn/i4Cyp9wGoqCzFWcH3dIDWA+KwFAl0Jrv01S7XUIig2BsDN+3cohnEDWw94LkPxrO2H+5+WRUZYoNjP27TyAcB8uc1yg5oZkN/6KTG0oPj3Ipf5nlzRajA+TT9T1BrABXTVQnpqA5JpsrbRDg6a5DbZgFj376FyfHEOw2Y7+fNXC4gFhXN+td1PUMWmp4Pk2qG1w2/38mViyoajMjYrRlgD+EEtPtLCgOp9Fv2ENcskHdsdliaY/6y4JuuIYnmqqPcdBVn4muTSKaGi921+vXdtkv/0FlSOyD1ifuub9DqMMVdKByXqPfmP579X6u4U49Flfp+xSJ75VuAz6WRTv2Sf7vs+VQGzLuRawBPrqY65w6GTh8cDd35AuYw59WKFId3GhoG1Y33tFe4PncpmNZIuPZkl5VsOnsuWYgjtmuGyQ2uKhTndrW+tLfPuDST/KpnqRrYV9igbITKOcblc5Xi8clyBOcOfH8lG3Ouu/bwwHyG2XhV+5/LGnUF3Kmn8zW1+EjuOtRFcbZ3nZtoF9LW5V6JEiRIlSpQoUaJEiRIlSvw6xfu0UBV8O6EadNtdQIGkVnTQ43bMVsQrfTk3sMGvweabm7lDKODmSjl4BHObrl/FRo6gX/j3RZUIhwFKodSlj6rL/zJ2ogrZlGiq7vXieVDvHuhtfLmeZRwnWTf4BsPbGIpmAGdV1VEFjGJ9fS9N25mibZNpnFiUCYXVYFUAFXA3HA23tAo16pG2CFRU0/7CbCZg1TpBZVtJM8wG0x2iZPWZ8kooatUUeSNRV1sLp4sErwZ+Ab5RELDtW1UDAww725UoHAZMVNCJhgTgrW5IGMAfWdWYaDvfqp38WR2iuV1IAFXJ0iAxIBwPY0LfFwSFBOC4F/Y7zw+YpyBY73uVdZ5lw3XCdqBZFOoGYOTwy9WCC0Ag+ngBPF6oauSrrGCeaffuchtBbRkpn/tr278VVGIgbaqgXnpViDI5gHOZLccM314AzJu2bwWVshbao0q87qRuVLGMjsDbAJZhfwG47KphHAOPGc+hAKT5DCtUzo7bqYCcjRcHtbR0yFg59b2+J0C2ZBVhtjCkVKrC3ZKaU0Git0VK2NDaQD3Ad2PAvJM9QZG26QeQpce2dcGgvDZ/HE95HmRFcIbZPhj5utCF6eEAvBFpWk3atL16nQMQ88FETngtPco1eYS1BfPdVc4cf+g3qNM7A+8pcRKU8Q74nW4niX+0jDFgGWHsLtKA3IPinfb8a9YXzZ2y2c+XVuUAf932Q+F3fjbfV1YvBziZxoRfn6cddQ69dTfJosYNvG2h4G6MeuO6dTz2crv1cjoe+L72+YWe9brcKdn1cZOTb5mb5ibOpF/hsgNv/5zRvtbL8ESkefzHUaRm4LoOW1FX2zvB8WDlMDk34Tufj93Q0mPC3/CYJjleDwq+K5G+7+n3f7kqbOZagyXe/ew5MKLyWOcZ2sK9MXhmT/Al72cH37YLIyVZg5UQmyEn60qUKFGiRIkSJUqUKFGiRIlvpFi2L7xkhao4VnRYS+NwyretI7zQW1KnOoRY9EsteGogDmrjG4tHGQBW2XESjlK5RXAE6JnVVoBCPa0lWsJo9YB2FZp9oSYQ0i3KULECEtEjuevkdABEOpI/vqxXbm/H9ndafNS9VC3Uv6qEbrte+gOKMqGYVCXrssn5OhIuX0cU6sN5ejmcFBZOqOcEb9BpMiVuJa37OuPGF5HpusjSiAztbPcLewpTjxosgO+uwnVYWugWf9oluxcCgCOoNvgwlMp9K+3QSde3siyVbKOq6PSQtm8a14HCWW3Ptmo3bP2fcdP8ucHqgUXyAOy9vbUgn/WcqXsVLNqOcm3jpJ4LikJ/3m02CPus4F2NdqrNekCP7aAcxSDRd3XX0ToDBdYA4Hlc26/uoIeqvlmtQbB1vZJZNqiG1wY3aFJnV7uG+wC8ST6nSQYcwFtWWkO5zH+2vW0/xzlHhbCLjuNqGqXClnhpNSHSQo2uUw72KFuD++mlWQfaaqB353WR6zhSvQh/bvxt2SrazUyLyA2JC5tFqlgM2tU71akrNhMsTjTZW8v6yBWcqciezj31wQWMN8Afi2cGVe4OLBsYcxU0fcy7Kql+dRmxwmIGVInFvOijA65UFNAjw2QHyvTcNW9vJipiwiTddwblfGCXQidqh3Owwp095rQne0yNb8pmJpGwrvSdtEwmqbocY7BqUWBz1eSXm6hrVieseXfOzCDbDjVNnc/3UcTuqnJfAVPq5Q4ab19RF3uBymanXL5XIOf3vqWEfivp4okJ9W/eX8N+zDlYjl71Kd/x6rXe12wY/QeSR9sm/QCID/uJRd6/fyTc/+Xzk0wzdjXQO0hLa1qSSn24rZafFfPT9coGQbxl8+3Jd6BgWRd7nwv6CkXs7l2tRUOR4MFOF8DthkDcvDHMQmZZkWCERQ480bGQVbwPzIFlvsk0n+VwOtKmAmPzcDjwuS/Pz1qYMjtbaCIucX+F1zwPz6H2Ohg2Kvw2OJ4U2/n+NIGooNrhMwXYbhdTwHKJEiVKlChRokSJEiVKlPiWYDmq3Fjwy77Q8zs+i5tVamPA38O2ZHt7wiVUWEYwYtA4ATKHd/5lOZgnE/JmeZqqLX0LPyCQeitnKpLxzN5j1XyIk+JSVJW4qf+qAlKAQqjONpmbib8ruIYvs39Bh+flGgrQ2dlolwAMAO9gVfom5M52ARwBcEZxtkmaZpMNUAuAfJlMOak2FSPUybkT2PYK7wCR9r62O3RDWOB+yV6UT2mF+27OIN6WHEjwyhqk7TpekwNMbvFedet/Uqre6SujftH/ua+l6ApTLWwHFXWCtF7Yz4GjD4jYl2YzQT9pFDSEvYWPLBsjUZWrlhoAOrVUsKWozXuVnuBQHKvKM9l43DVjVly64bSpQtM1BQWqeUKzHd2jw4tyWb+xiB/u2YCSFoeDIr+x8QD4isfMwn3TtMqIgn0LCvZZkcRXmPBNPegeGXq7ugozweSYCNrPl1f6Rfepvb/1xKpdMekqXfdf1uJkekBVocdidkG3vLszt9WJp48jKcLzWIgtQXL/6dY52GVh8zKNJfiG0+A8NVQapWqV4YXlzFrHbEKy4jVDSld8B+lvsCPJCbB8S24N437GajGU7Ap2/fm2wliXG4e5ppRmBPXyV1XQub9f/ynC7fsRZdf5qphhvN676o+pGd5+vcNSXS+8eKLvuqhk6FuZppZ2RdiRMs+WnowF59Kx1RIi7WZJOyri9YVLC0rtNI93zREgdLBNUcsLBbJpBwffrnPexyUgs3+e0C8ZVjcTrFYWmcZRph4FPKGix/jKrU4LDL8u5u6y+VSoYKCfw5bodf/7ZHi/75Gd37nuhIBftSV53uiZEiVKlChRokSJEiVKlChR4lcEy/p/WpjMtsdTcddIM8MD1wWwVYKRC//lKjpXr+3VY6Yd1S/MubqSbocGWKjCl3VsYwcM5Bd1Vf3BmxPvPvStPJ564YWAlyrdzhDa4LMW7UNRPS1uR4ViJTJ0KMbXyvnlRSqzXrhdzoQH802LLD0eDvKAbctQOeOa1kpuZ1UhAwgCGgAWQdUICwm0DY5VL5NCN1Mpwr8VcKSuV4G9btt0cuiPPMfWbdJukxyPBzkMB5kmAAr31sQ28EWmdWTrArLAJ9p9lbNSWBVt4OIUSEL5BxWwFbiqWQwKiYBJVdvDJN0wE65DlQmQdnx4ZJtdzy+yYBv37UqITlCYOWPs0QD8jXI4a9sV9dNiW7RWMcAqi/Vvskmw7fQG7phAIBQ1SwzjzsRo6EN6SBtwWc1bmtYa6kXMIoxLKxXHhsFfuE/wbQp3w9nzmAcTIqPDdc7Jn5j3Y1WyIpiCStwVuQjAIrUXwYugmkbLoyicenxTUVp1tEq5TSLn8yQv51HOL1c5X2/y5cso58soL7dZzssqI2YTVN1hm7uDJ/eAjjhQwZNaqLhqN6mS3R7DIJxCsv18J4S3TlRRqSpzrc5fSq6wffGIYNReo97lmA8LYbkezlTsYUVQ8J6xmR87LRdp2chAGWNcEzp+4UjOGMzG3AuQDXOkxW4GGrUjqaGwHrmJFXNsqmRjZbRR6uqoFjKAfVAmw16n7aXpevXDtgQC16DQPmrnoGvAq7QLEzJYM30XAj15mCDS8S7JekbHhYNiV8zGvtnbzkiFdg0WFTskXL2hXubZXimV9z/f/tdbaY3XIFwTXq6eTVYfyZ9BQT8SKlygLBu5AZ5awVL8126VHIZWPnw4UmH+4/Eo27zJOt1shq1a5NPhOhNPrhRXFXjKl3mxT64RNm9j4srWTt0ho3k7zK+0KiQ4vcm8zVITEs+02sH5GuwwwbzGTgSbNxibt3HkuonnkWicxlUu14Xz/PPTF+3zWj2l4XuvQwVWGgqkFwPjqL/KHRhmfUMrHFsI8G+dK7qTBGuLJq103HPNtzG2ma88lN8TdkVAAR59mEuUKFGiRIkSJUqUKFGiRIlvApbvVI1ZhKbqLRax8i+9yW3TtuMHQrTTKudqShkCuUHqTnGXYy/Ac/BcGWRtCADgeqCKV4fKrkLUre2qDssKROeFXvLI1b6q0l1l5Rd2hSUAHAq9DG5B9ejbjoPnLFFG8PDU+wtFpSr1yIRqGS8jkADEXlAMTiGQ+7mm5g/v328ltzZ2mmPXTSCdikllZScVz2gn2GPgvqEA5nXAixMwTb2DtZhZLasVKORDSavBYivCmPpGLyBZoiTZbAaeCS4BfrvC1QpDmv1yKLy110Ir7FQrlRWQdkFBRahhPT2R/bgdVOe3G+n2P1mhv6Q0DWP7lbjT3q8b4121HHxb7Rjuu6z2G4G+p+PpuPIBnO7OEgcz7VdmenyPNzxmuY2LTPg7fKN3ACiqebOq8i0jhddxp0HdyYH1KGmcJYVxPFeYnmn++ljfNbO9NHqF62u1dzVpwjETLD3uGj4pgXfHTBa5eySWlM67AXgX/j4bi8lnOiYJrPilqpU18RE9at+4yf0pkoVNPmluxyD1fn0LdwplXzfvCkzGueGZllBkL40ANcJ/Aw7/aRHz18fOnQo7XFccJ2l999/Ty20dsNtLlhluPZ3sWLJ9EQA/rErghY/1u4ZXfYL5aUKnop8ci1kEfXdDkSZ/Ratr8uR0+WF7xi5BFu4/261oUsvfrwlB+OqbBz4LPOp8566EeWZCU5OccU2IhS1tB0rwlI9LW76TXGAzphliesH96DXHGXZBFCuMEiVKlChR4r/r+L/9n/9P/9++hBIlSpQoUeJXAMsBHPFLO5hepcotQGVYOVAt6/YJgLIG0vyLLrxyGbYFF994gQepMqRiC8DRtzHblnUDUHCchcILAcEnvogfu1b6ppF3p4M8Hk8yj4u8XG8q69zh7E2Gvpf379/Ju8dHOQwPLJqEImrqWTlz2/I8X2WcLjKbfyWtCrjNfyXcGIZOmgaembDHEJlGBZ30NCaMAoSduUvb4SWFiQANpkRTFfci9VTJ5aYFCsfxRkVl3cGXeZXTcZUWfs6mrF6psoNC072lVUHadHD59CJPIvMyUUndzKrkW2vYWUBdqUXl8M5lmmWlT6z2Cuw4oGDj1vPLhcrl7XhS5W9TS9+jSBmMomc9zuwKRBNPmkKV4D74fEZVKsaC2ngo9N5x8WQdoZ7RdBIFfATwg191Ah8bC+IRik83mccrC6r1xwdVtaMvbSwhabDVGIfqqe1wEH7Y7JcFHqgb1c4O73ldDsiSipED2dqdmkSzvWhtC79NBhy/aVVBX70B2N2n2DxXIb/HWFjGlY/xOsnlcpMvTy/y44+/yMtllD/88VnO11m+PN3k5TLLbcIeANOa2uTZ4TXCIp196hHsXrtvR0oREfZqH7B/BD7huk0++SdHqOywyhMymLsYS5YMQL9Bkc8+XzPAR+IHc3vicHHrmAzQIhR3b20mM+ycTGCxXVXxrHBZbRO4xtwVdEwLVSoCqQUReb/0OIeKWR9QVrNQX93QU3k4DjIcBumHnip+KGt1V4UVYfTilDGzEovwsfihXnNWy94toMyjAShiTPlzGQNm45E7ALrLrrm6PwLS6EETCgqGtXAPiP9cRHx5f4wIav26QmHVBJPvEpK741tShl7AavHjVBh9dDx2nLOPj0ce6nobtcCqrbHo10UlxjqXufZ7ss1WS16D756x9uHuF1+nPOlm90A1uir7wYQBhrkbxA6l1koKkwGG6cEN/+66kduI3Qfa3lArQxXc9vjcULumtu05/5+frlJtjXz3/Ud+FrXYSYM5gjXK1mpPUKKoIVTHblsU7V90x8Veee1rgyc1aUeD68HSR7U1dthosgqPIlkuUaJEiRIlSpQoUaJEiRL/AFYYIYIQEzALW9BXsxSg9zJ/5i/sdxvfA5LIPphRPbwHENmv2ZVV6r8s0rW19G0jPYpswUqC35+paQ1uq3pMfGnv+0E6FuGCHYZCZd4Doa9BcoLyvU8pADq+mxOgVb4NH1/SFXRjmzPhyCv1Z2ZGqhyG8kw13YuDZvxOeKn+ui3+ZpA+P7LaN3vZZnWfB705ueUexwNIhY+sbv3HgyiK1g5ZlQpgjb3fgM0sDQcf00Y9nusGRbTMNgTAGf36SrDpqlV7jd3nPUXSy3ff52yXYTSHVajY7G5TTGqtYF+V0jy7+YxqG7LgYDdrQT+ztUiDI+KuKBx2VTp8lzlgorLZxuRuvFtVwni/eJ9DzKRWVi9hbONPUC3IBbWtVUlOR4ZltYdaleABFeP1epPLZdTHbZGRykaD8sEXdi8y1HGhQCrA2r058a5N/BgJShkEVHCdLSZ8LIaeTreVRNn2Cz2NrYBaVJH6tSg/NCXtG3zTdMxRc2tqbF9H4sXm4nz7ew2e2QkuqmWAdZj2U4CD+aHJFEBwWNpgPAPKJRuFNyHtHTb1c4Z5omrceyG1rU+xj3bQ+B4sJ9RoliPhGmy98vX1dfgqoXMnq5m/9vr7uAPXNgvfVv26sj88s+vfty4vjCyzPlKwC8Uy7Iuwdrd8IAGpti6e4LK+RZ+Gz5zY1Dt38ix3N37tY9Xslewe0hWHnSi+ySb/Tb2gMV4cTAMuu4N52tFjoFu9/RWi02eZ9jAYdy0TDO77n+aeXWQqtpc8uIOnv/3I9kT+WeqfD3n+8pih/oHvtimK5RIlSpQoUaJEiRIlSpQo8U3BsrM/WlAm0Z0BJvsC6yBFlcW5Kr0WP3M/TLwumSvrAd0HE6yOlhWq4LOabkll5XiAhZ1afJmv5PFhkHfHQY7tIH3VyiqLrKZGrSrA44r+lY+Pg7x79yAfP3yU4+EkXXMkDLzC3FZWeX55kcvtKk/PX2TeJloadIdjun9c93Wa5PPzs9yo8J0NBHRauHCa2DBQ0gEOow0yfHFrDYXHdQOFp4Jw+CuTATjQ0JfLSg/MURb4IhOaZFDXmCLX+4GqyEBw8NJ5VHUxoeWkoJpqY/JbQPJKGtuCnZIEwbNzvl1knWEtYdYc20LAA/UzLDR2AkqDPG4dQHRlikNVnmrBPP2DK1udbcGfWm04tklfiz4zAaz6HNMiZOSxYNVBFSHV3Z20Xc/LBgQchqOCbVO7sgAj3oOx16j9B3sE29GXGW6ppkaPwPAe6CnMIdDhQV31aD7B7DxYsBgRN8CO/qQynfNmlc3UkeTp8ya360zl+PnpLNfLhUp7P+5WdzzetNS0wVg2tXghTDcvZPqrqhbcoJL52mo9uD8ZqTYiVbyqJnZdcFZZq92Htpe+FuMuASyq5B2qqW85Yl4nK0fnUNntP4RqTbf8wLGhiGbiw5JLfn6T8vrosmtQlSYBMb3GbQAlBalBsiyD12fNV1rfB6GowmQt9gm1aaXAcoDHLdTrmBMtleyeUFFFvdmfuF8LkyMoRuhF5GwOcoFSpbgnG/TlOv6ofP1qz0R1sf90wKn+1K8L9aVmCpXovgaKo6rYj3E/WN6G0q9f4x7NXxlsySNa4TUhcPLsxnrikBz/xgKvaxMHFXzSsSOgXQWr6+Gg8/43P3yUw/Eol+uN6+yywr8YCT7uS9DZYKdA0m6d4dlsRRgtAZSvP4NhAmGq6j0x6IUq3U7HZgjV6trJ3BHQYsx0moCokYjQDy2Mz67rbShcuevCFljWDug63N8mk1ne4LzYrQKff/jqI3N3rdS7n3PQLEN0HDsI188/vzzupcDniqmq+eCuBe8D3VnzZirgrdxAiRIlSpQoUaJEiRIlSpQo8WuC5aSSTJ6rQYUYVZl3asf0ft8eDfiYmEm2HNDv8qqsAqB0lbArvsAI4hZvwNm2q+V07OT9w0Fa6aXdYAsBiANQiTfgC78QHB2PvTycjvLw8CgDC+UNPO0NlgrrKJ+fv8jzyxdaUqwozlTV0vVD2s6MGGfYbFzojQxwASh8GhQ8waIBX/4BSVl4yavLmcLMIR4ARlOpCpJb792uYNHCS0k1imJ70yTrMu9ViaYe9UJsviXad5t7cy7zolvxDSzHF9FuojY1c2Z6Cv+pil1lHrWNseufhacI47CnupLJt5cb1VDAqDYEvg2b9h+msktwhl62rqrMIwP3SuU2CpmZIhp19gh8FgWJ43gxKK8qUv6cOul6FL9SdTXgTCpE5+p2eq+sshEyAdqaIpZwUsHwtpi1QZI1v1EIzTonez97dULNdGy84Lz9nE8lP1MtWAg4ijbA/Y6Xq6qTz1e5XW/sL1driiVI5rXiY6U5iBbecqVpTrMoZFJfWSiGX8mzQ+SJmuqo2Zb76A2eXsvMTtK3ZiBn1iterBMKTW/7BX7kUMTHmU9LDS8ohuJ/mlzCuTGP8qWZf7nJezU34cbbem0O1jNYzuuMqpKzOlOhvyaukirTxiCBOApnAi73tXSdjSnzXnflshY/u2tM9iuuw+YmizPq9bzyo6aKHXYDyJAoWH7VJeleIvD1Yox233Zi9jNvXkHo3nn670sH7/r5fxVVDMX/dvewP0fuQ7zOgDvvw4sJWgLD/V1caQtQ2yBZUUk/YN7X8v7DI61JfvzpF+mfnuV6g4WRJjFhxcSr4nXYWsukiSa69PMlqqVVQRzEy2l0J+W9q46TatkhrYFl+PrDs79GQgKWRvqahYUfW/NWtjWRnwOYBxsBMM4LP3WolrXoH47VydANMo86X9RKSXIbmWLaHcrxGZKTt/q7+jtruAJfgbvVArjXnhe1cokSJUqUKFGiRIkSJUqU+IcBy4Y2ACsNDt1Zgqqf7d1W54RKfKs85cYZzDmcUuGyfdm3ivaEN4CM5lVKKGVHV3/USvoOStVO1qvI7XKjAozKXEBqOx6+yGMrNbxT+27gl3jAYnzZhlJ5nK9yvZ1lnkees+sVMOG1uHq3IBjnVb48naVrR7leR6rSlhNAeGPKxMqK8WUQTHuLCbDZYTfAAkCW+mkSEgMKdAoqhmHgIyvNgp8noDz5qPWFwWq+xqEyQZg/ZwpZAnpAY0A/qFoVrN1eLlTNQonMgoYogkjwqn7W9DheJlkn+D2b+hLKVbe7MKLkhdsitEjaUU84GKhR2GTKXnggk1oquFkmWIPMVAUTMsMqZFbFK/tmW6VD28D2hG0Cr1KA6VmAZ1aoyHEOXiqU1arY5khaoFyOknv0D2Af6bteWuPyP4PDVOyaqtOAVfIGN9jpMJ1Fuipcux7OqmOZqhkw372MzbfVinXhHr1g3+2GxyzXcZXrtMoNyuYZvtmq2ncH6EjD/P+Z8HnTFsI9iQPkDnPaQXkqphZsVmJEDMm5askNKiSDBQe3+VMluklFsJevQ5MN4bymHN5xVT9Puo86AOV9IcoE1dL7Av52b2VTKtMaO6isMb+GoZfh0NEjF/7pLX92nGdu50Il67rRBz61Ow+IsQO/77jamcKY6uRsf8BndtYMd9rR1JdrSBL4/Iot7/NIz+UFRR015q0HqYdfnSO/5h4Cf0W2mv60SxUEqOy2E34M7zd/sX8iBNif7iF+gHj7WkKD8wXKf9+hgLVRP64eHg/y7v2D3MYbxyF9xy3bsLKhrRgmfOM5p+1SbJKkYpduS2OWO+pd7K/N9+iWO37NcZ69rnmX1+xkyQOLISQX6PGtnthY3+Dl342Y3zfZtiM/o47Ho4yj2izByigeS6019ufSzw9NfuxtkWyHi1sIeaE+5iZtveI4dq/y111fokSJEiVKlChRokSJEiVK/MrF+1SBSu5hXqkZlPhW9LSnPckiCXer7JcJUAABKb2Zw4ZvOPgSMNiXcWJkwFmaCmCruYK9xrbfw1v5MPRyOgzyfJnl/HKR2wzF8WKHANja+IUdUOLQ93LoYW9Rs1jdNE/yy+dPcrm+yLJeuLUa2/UPx55bsftOi8IR6m2bnF+e5Ha9EIZ1AFP9INNHqB5bVSYSNpjXshXugzoVlhaAHIAJuCacA8o7AFqAOBxvMJUbwMLhcEwKULZC8PFcCe30vggKAEwdzgMYQw1nkFW9jN06Aargia/rYAdAxfKswObQE7IBrjWHnt0341pRMHDapFoq2XA9uE8rwobQnqG3Q4Ywd0XE9EoDtjIQm315Udxw4iuXcZb5dpNpQgHFK2EwoKuq3We+/vD4IB3aDiCcwnYoslFYUGSFwhui+L6TqlGXVxwDr6IiGLCz9cSGbeOHuh0FDwGkcI9sS4PoSUVpdhOuvtWbVxyGNnF/Vfax7epH39HTGrDLVbkKgBRsNrLUC9XkM4DydZTz+SbnyyQvl4lF+y7jJtdxk3HeZKJyOypafd7ZhfBu0baYn1bMcQf63L7EPF9d7R621u9Zo8/jTM/87LRiMQU7xxvgHsehKU1ZLGwl5NPLyIUM9VdNeqTCgPZEUplH5WjIXmlyxBMm25sg022DXZ3qvsqEyqa0xvVjTTidBu5k6A+9dEMn/TBINxyoXPakzWJWKp60YQFMjAeM+W3MkDa1r45pPOIamZoT16MVIe2CM/jDvWlWBIXi7vsieIHve+NOVb/fDZB2OuzU0PInFMv3hPFriuZX2te7c8vdNbtfibeVt4HtYPD2s7ZhcnFr9bMBVjyVyPE0SNs38vHjI3eNXC7npND13Re2YYLP41wNdhKYJH+z46hPvirmFbTaJxBtJwzeWpFYXkv6SMtjLar705JnRSlVIZ1VziwaiuSYrccYG+sG//QvUteTzNNFluXEIqnv372X6w0+/wp/9f5yv1W7NIL+dA9wnfNZ367DS3erwHbGc10Olf1RokSJEiVKlChRokSJEiVK/AOBZY+gQAt/2YdLIAFJtBhW/okv7+6T6nW8AGx02zKBaG0qr1DIi0c1tTRA7AHKYnpaAiSs9D6GKgyKZNo0E1ypYhnnVV/iml+u52WWeZ6okAWIRlBF6l/S7d9JoWb3RPCgOi/Derb9Pl1f/IdBj/AF3os9cdtyKM7HrdXc/o2HdgsVt1TIZuMDAkwCKMAu3VrtKl34R6TCfHuzhNSKON4CYAzo0jgkdasOVfmhf1AIj/cNuFFpkcJqDjDQuzgmFcwiIFsCmCUGrtQK20UYmgCNPRwssqghLUVWGUeoyHEEVVt380yonArKJYFehiW1+6WmJEVWIKrft//77jbcKoSe1bYx3mEYrh+QMRHMtI88HQVQlaQKP9ndarehp8zFzgiDUCCu1f6GQlbvdaLn9+U6yxVF+wwoY2wDSulYiL0aFOFhXhIwBVXl/czVbohzKr9OIWgyAAh9Fhw23LHDi5jtMKWNvViM01S9qat3/847H3R83p0rDZRw9fE5n1PxwG+ocf087k6hD03qIKEDj+iu76TvkWBpUyFOzMGq1l0IbqmQriHWeHSG6wmw5CdsT4ZdDA6TX3HY3YHC82+x3nuJ9/1zfvqgvLU/vHHYr8Flv8mdlDzZkry6uMCWXylgI3iPbzHParc2yVDU/fhdjQvVsoJi7FA5nQ5yOAzc+aE2OlaAlfAWS2FY+6gghw2OHpnJUYf5bgljiZW8nNwrsnMxTi+QFxrc2sTtJu7X3rC82HvwbyQ2x6mR63iT2+3KZzj+6PWtdhzbG/PUSyP6HM2plzBP+DYda+ny0/3mtTZ2T4kSJUqUKFGiRIkSJUqUKPEPAJYzSLR6VPolGH+077UJXdj3VsCb46FXiNOrb+aZVgdQlaHgF1TItUBMyi/cgL3Y/m8KafXPtAJxskrfNPL946O8exjoczxOm5yvN/n88kWatpbDcUiFudpmk6Hv5HA40Lqign/tMsnl8izjNMo436jkhc9q2/YZ8jawh1Al2zopKEKhKXiDaqGmVqq24y3iOqHSVCVuUI5Vtcywx4ASF0X4WJjOQFbXZRuMWm068BgGAJODTFSrXunrzIKEBi2gPFvmiQpkqJ9rFuKDitT8Nelnq0DXATMhQkVjEd0eDaUy+0VSsTv2Azw/cZ2w6hh6gpzet5ZPk4ywwaByWsEHwKgq/wB2cHS1VFB7CMA5tIdBcjwHFSte69YTBgI3eh+rJQIskedpodXINE1yvlzMoxd+pSJ119Cq9bCekq+z8jpXLm9StzUfFB4DEhLEAPitUgEU4mdSsoZxDXUzOddMOIx7ICQmSMe5zJPCaRTbTccGr6XvZVsUegvGMCEXlOOmgqTvtip7oYoHIB+OA+H5OC3y6dOT/PzLk/zxp2c532Z5Os9ym1c5j7NMSDK4V2qAko7iXhXYc8XvPXwM/sxU75r6XItHuu8sS6olqEaRpYMsY/UJzrobsCtRg9LY4bPOBVVf4twYQ26Nop7Qbufs/tWqL9b3WXLGoTwTS2b9Ya+hUj9ndLLdwg5i6z1hSmPYYn1Asglz8XQ6yePjg7x//14+fPxAwMxdCFLJ7XoluDycoMrHWMLYtR0JBvaSftSLaFIlr2tbUnxT4Yz/NQp7eeneOGZpkdirJV9ewdu36N+9gv2tn66YjwB5r2jPv/u52bJxUd/1r2yulDZIHk/56jLvIGxS+mLtUPWyW3ykPS8caDAgQWISyuVFmhZr3iYfv3snw+Egl/NVfvrxk3ypa/n5l89MilHlC0ufpmM/o+grd3VgLOKSmWTS+1Pl7t11m4WJ+hubJQmthDJMTx735oOd9+1gfuq8V1W9/sSaZhbofA/GLwD4y/lZpvkmP//yE1/7cPog796/ly/PF1qyTDLJPI48OnbI8L3BWoMrZrIa8WQPdlpwpbAdK9jxATskvKdie+gDRUFf10IoUaJEiRIlSpQoUaJEiRIlvg1Y/pp6jn/LUq+kLzZQkDbymhVG3q5rUDGArqzsUnBEHmPndGALcJY8kwlFNxnXhapO/EdQi6JbVBwrzCIkpmLZvGZRDJCK2Mm2JttW/bT1/tUNZpVigmaqpAN0cPjgRseuaHMRpUM/V4OqbYXCJC/2pj7TXpTPPaajWtt/c6WdtVqVlcLcPm3gMt1HZDneRwAqnhzgwyAElMl2nTVUgIDq9HQ2WEGQqhilBgwKBQTzI4COYGPAre0GJ1VtnEyhk6d2Ui7DV9uuTRWIAMtQb6tymkpu/n2vx4btiIJQGpTq+WwreRZdukz+DVllampVzdOS4GvzwGFaUC3r9n1XR+KeDMwl8KSDOvU5rSQaqVskLBR0wVMaiYhxgkoZ95/bIo4H//+smFUkx3EYlLy7S/Y/3t33m2LFO5ubvUjTxvoraG12talfYrfu530q+ml/3NkNuLj2/vrdz9hsSbJuOc5bu+qds4fbUdgawqJ8+Xqw66G2nQIsxma7K3Zj2xI7SfUeM2n3ZvPWHwzz2U4KZu7GwE4A8xGONx9GSpojUW0c+vteY/zVnkz9tB85/y3hSQv/7fVx9Zz6/2ER3F1jbLO7v1lCQxNjBm5duYyEVtdKv0DdC4U5ALJa0bidCN0vrOhk7SpdWq+453W2gtE+vVN4W4HEVHDW/OZ1bTPvdV9G7lTruzUpPuf3l+bCJiNtfkQul5schqschndaMJKfX7a23imLeRRX+vvRfWyFW4ljVX/3+w3i+VfXV6JEiRIlSpQoUaJEiRIlSnwjsOzuB2kXt1siWPEjfkmtVE1JFZrpCelXayox3QUOxV/23QUGVRdIqBcVI0BdpdYYxkW4LVgIk98NrZyGTo6AC00jnz5d5Xwe6WvbnQ4KnHsomUX6bpO+r+hb/Pj4KG3byTxvVIieL1C4jQSWcJ4gb6Bw14HCIksLlbHBQKp14Y8s0vWV9AMUzTDshJdsVh02bgURdiWTx662nZuwG2BLQYmCcgfhDliFsLFDAToWbRtVqYziT6b2VbUuTqVtjud5PrOrwIVCcezngmzOt1YDHuPCruskDQqQTfQLsevWLd3TbVIP43cHkUNrN7HIOm8ywxsbsHkYtKCgQWJFuepV3FRQZOt1qpoTLaNgTyGzar15qfBErmuZhgPVdAJ/5bYh8ObPFbYgE/vh5Qov7EWa/iD96Sp1hftUpR78imnhQbntSv/jpodfLjxY3UJBdYaEtbS2uAPDVPyprzKV2Bz8aFNVAfpozWJQ92xOKRSpoOjGvxskLkKyZKdsVIdqKJZxP4fDSfruJJs8y8t1lvO4yjiJTABlfgybMfk4r+GQehdDoRmKefEFwXbF7EleuUYkMIZ2wQ4CLWYWi3y9VrnaRRg8VPsWLZbI9hXA2ibNIxwfYwBzHG3vvrjKu63IpKk7sx7bdiu46j4pauNlZG22FiPU+QGlq6rz4YtecffCw+kgfVfTdxvqeK5VlcFkwv6ciKI1jSUMoP4kFfbr69RbvcJi40AUncW10pI13JWhi0CG8/Ast3HDMafKf1qpJCjo7eAQ39tebXt0pcTrPXnhTvXxtf5vvUd9RL/lrymgw4Dgj/vX6Uq3Oxd8paN9RzhCUiC7F4X3oe1s4C4LfhTpzge+3BXCeA9egx0aWMswZ4ZO2rqR9+9O8t3HD0zGoJ/mWRXCOPSE3QtUqeORdzbQ7x+ydYxB+jK7Gtk+pzyBaO23rFAgz7KstcyLjg/tL7Uqgp1Q9CvWB3Zl5PoCwVCeu1MAwsfxKj/9fOZ5+uYP8vnTWer/8SjvP3wnNXbbHFtZq0VuV02iLQtU2zo21SvcEx5qYaSFRq2vuIOE5UzNNEOV2wsUyyjcZzs4SpQoUaJEiRIlSpQoUaJEiX8wsLwTOmYJoan69LeoPs50yBSjbyguXWFJD17CvvgWL+bkv2qxMIJjeNNS8VvJOM7ycr5J03bcPqyetVrYDV/6ad2A93QdLSz0S/pKqDwvI6EhYadDc4sNHp1Un7rK1JS8AE9WtAwPXL9DR1XGOoALYUoxVaHFIlAO+gwS+8tNaatAxNWGGSTciRxT2yMALFXpBmhiKmGSRPdhVlsMQkCzr4AVCYpdAb4AzBI+83eR5bTICqBC+AOIgd9nQjf0A9MH5mNMuEyFc1YHumrZfVId7ipAVICJfoGXCABdDYsRKNFhl4AHrqdaZWYbLlT5ob/gRwx4hLe2gNgGq4h+Z1ipqHexJj8M+EXVtqkYtZ993N0pToMvLvvVt/1rBsJGsnnmujpZvWFM1QsbDQNMu3GfExiww4D1AsavemujsCTG52ogKKr5TWEZ5MMKk/Y7CrhNnmprV3vmOez/iDrqPMn3/3R/6d1Y4xNZ8qgKT/PJ5RxyWJfVxaoOzr7f6TDh/FEIfqc9zhiU7Pa1MjZ1XVqf9tesqmRNcnAt6FtpacdhymDrt2Trc+d9nuYg71cTLButUrQQqSZlVI2cGzsNsvRTL10Ta+p1bKPCdi/sGiH9jBA9j704kvR8Pr6iT/H9sfYJkFed/hUddGjY3W+v3/v19/j9ZXWvq5Pdr93WCfOx3xX6S4plrAWbNI36ykOtDJ9l/PREHec0gfAqNa1tYvFErGmapNTu0eQS2sJbx/3yPYmy2ecFknF8Dj72qyYI9zs18vocFftp7fOkHda7CnY5tdxYpG+Vp+cz7/VGP3k9OeA3P19CYUesVZ7gSdMunJtjk9YibtOSPaKNc++WtdRzb+UXSpQoUaJEiRIlSpQoUaJEiV8VLNNfGFhDC1nFLe+qTXbIE7fBW+E0Cmk3qQGdFlWSTYCZ/E6vBfbappYeUBHnqO2LMZSq9JLMBfXwxXxaRf745YXveTrf5DLNcmp7+hO3BNAAy3g9lLVQAkIR29GndoK38niWZZuoslPgDBiJazOQjNgAEBQUwbuTX+yhmkQBqbaWprPiey6+I7elg2qCs/TZpCsDjSp4f1A5q4rXQLUpnaGoReG9CZ6a7pdMAAHIq8ehdzGvzbfl4/pQ1M08oI0X4Tpg9cH7WhepF70XWEUk71DbQb1ZIcNxuqnHc3+wa1dl7gKF8gSlNpTVUCF70cFcaI9XZRAZd4brBRjF+dtWFdQAqISC8GZ1a4EVlhHaxujzbsC5K2nHq3T9QaZpk8t5lvE2ypenJ97Th8dRTodellX7tB8GeXwASIWKsJWtauR2vfE+21Wkh9pRB6MlODRFAGBEn+plUg1zA2A1qGqUx6plAzA0GK5O2gGqujmrWYZY1sGAmf/eUnnp9h/oN9qOuKIXikmTUUKhDuU2flqvmheqe8La9vxX9h0+7+4gn+0k8H7Zl9izt9rtcFOAQ8PEIbOHtI+7DNwcoqvvqwI5Bae5RKFzT4eY1vbmP+uKcbVjyapYjntPuqRZYlY6Bg+zWjPei9NCjqbUIkw/QRVf19I3NcfOxw+POle31YqlddI1mnjyHQMOmX0nAa8b/uMsDKpjeAdpKYXF2LEM1f3aaAmC/JytnwtgKRIpSKr0BlCxheIVzdfj0dvY2yuqmb3Ipzf826v4Xql8/7q3VMx+n3c0cgeOMdZUSe2+xPtzRMrv4yEru/VQptamitheykPpWq7FXpeUyMJ4gzf2Dz98z/X5/eMDk43P56up5Ss+4Let89Z81eG/XKn/O+1lbKB60qg26xIkdBDcMbJi94SuaxXWq2Xi/MS6Wc/w6Ub/oSjsLJutsV4kUn2eg67cfO37rpeP77+nInocF/m8PMunz1/k02fspJnl/btHjtmnT595zAqJTHzuuFULdwDpmJ+wk4TnWPTz1D5X0tQzTym/1zhuFHh7kqFEiRIlSpQoUaJEiRIlSpT4ZmA5Ktzcg9JBbPCiTda5DiJUdaWevr6tHgX7AI3tSzC/bNfSoTAbVLJuF8DzKCLS7b+qYAWk/uXlyiOP14Xw8wgFW9ertQQwhSklFQwB1LRUno3TWab5SkCOL9dtO/CLOoD3jK3ursImfFSA5Z6+VBmaUhn2FbyyVJgJ0EGVdNzqn0WdqfAegAJtAZIK2oAZpb4KG+cJFhQACD2VeVmRRumxQUq0YPZ8VYDthdVMAc7t0KzatN/cHsCkszFVK0M9q/YetO0w8AzgDaCC7dsO3tTH10GZq/zUC4NID5YFJKJQIWdv250gE5A1eWwDpq3SdrCFgDXGIG0HyDbK9bLK9TLLLz9dCL+3cZXphAKDnXSwRDke5dCjDxUWAabAxoPHRZHFvldlKf1OTFnLa1bFtPoyT0wIKFzCaxX0ZY/kaEEQ0CmVhAvHpPlL2BZ+U2bXsOFwIGjAzLbfE/K7DzTaB5YLKDTYmP0Eob3aSmS9cgbMGejpVWV4nGdsmrkOjt7wVM3uBXt/cbYfQJoVHHNrl+RD7irMdNzoKRsBpquVzS7H7FiY6PAECbbn8zrUQzeU7ks4U8cQ+gRFDO/AaIKrQTBMIJ7XDkA6FEA7DL08Pp7YH+s40q8XYxvJKKqPKW3OcFkTUwbEYX0QwHDS4rJ9oVx2UbGpRnlvdfZZ5iVjHJkHuBe8JCxFUgQKaKj3zVojdUZQ7qb5kydTtqV+Cyjfg+KvKYujYPleuRzh8v37stVKvkddFHMy4y2JrAFygk+dM+4V7wUm9bME6yp2TOD5Oa0vOMTxeJCPH9/L5XKV0+nIJNf5MspMCG1FX82nXz+PfDeHKs3DSqjHTOCfbhIhUYL11JXDSECaKhqfC7DC4E8tSgr4rQVUffxowUhvF9+hgkKyj4/vZFlmeX76wnt4fj7L09MzPy8eH056LXTwQBvYGmEP/VzUMYTPQ+5SoCOJVtV0X+oEltNui/BIG03UJqREiRIlSpQoUaJEiRIlSpT4hmA54aekTXZglbSFd+wgQihHz/yCbrvGaQ9B6AM7A/sCbv6v7r2Kf6RidxUUuABSQkVsogamdCQgwpd5AxuAq3gdFK+325UAFeq2ZR0zTLAHC/DhQA4A0xbs3ALc1m/F9VSdq//pv/N2Z1WAuSWAwnQtxuVeyurf6o9UpAmFCG83vr4fOlSnUlWrU3lXPhoUVSWrFawzEgN4vkJZTU/nrBrnHTjkTFvTc2Erqkh5nfok7BkQAOjp9gnYHCqqJ7F7MkcxXNp2bUrWmmDMYIdFRnN567lf13ib5cuXi/z807P81//6s5zPV/nxp08yTpNczpM8PPSyVK0cjidZl0pOh7N0LdqrlqZuZYYVLgo5tq0s4yg1PZ/NP9kRnPkJr4RQen0tFMykdO6TDdVfKDrIe1XfZAJCFjwEUPIxYTcHmsR2AqzU9uFIXysdm7g6S7wwmTBbMcIESq2NTJXoFrcZ3eV2vNcgh5H4luVtOsLeGMGKRqaBEmFg3MqfbVzIpawwZkJVNi/Xe2sAfz3byK0dIixX5XO+PJ/8u0tOa41fgydJ9sX7ggO1c05Xi2Jcd40cTz37bq037iLQdUatXnyl8h0SsNfhHDX/cwXhOsi9YGdFKs4GSTYXaiGiwFktDBR2Kzj147jaXRNLmmgwmwwoY0NiI2PcrMx2rXnuUb/nnFTQd94pkauoFA5/TF7K9+D+rdHmbzZ1vltbmN5+dx3xsGGhUCDu60tIXGKcBPsY3RLiKmD4U+u5uq6V47GX0+kg7949sFu+PJ1pkeOFMFV1buPCPJRpF2RFMTlHsrW9jcf8+QQ4reuZWge5upk5In6ezLK0rcLllCjRhCHWHCQ6F7ciSk2h14KxpZ7Nqp6+XK/y5cuTHI69PLzD2rbI8XBIBWnzWHfP72yFgXNzzGG9RctR1a1jTvNXr1XqqYjofdeWKFGiRIkSJUqUKFGiRIkSvzZYzsq0/IXU/XMZEVqm7eS+pzxvSXYw7CI+fK+G32mHYnu4GnI9/VLc2BfjxorQcduvQdp5ngjdhqqRrlLP5QGKT1zLgqJ7ageB11yuZ3l5eSZExNZj9TlFMSioEVVp7HBZPYQX+hsn6OEtYBYQtNkAJJZaGlqEKGCtXsFlgxMGHQCnuJW5deWZFgfj6+Dlua1yuZylulVyWgF2B5knsx4x0VkucGbFxKzIFyHSisJ6N93SbX7Ojv01EaDXxmJkwTSbvtPzJqsVv8I1HU699Z/BM1epksKp/QGU4Lql3K7DxYrJVtaBR7RjyOPJzQpqvMFh9bbJ9XyTn378Ir/73S/yH/4fv5Pn56v89OmzTPMsv/n+II+PvUxLLcfjSaZxlWPf0dKgeoSHbkc41EyqQG06qMQ7wmV0CVWM9GJVSERlN7evi3TY2s5Gs2JnVlBN0Qv6CYrJCfheBMkJwmAF+7AvSOPdvLaV7FixP/aJQbgFkMzUhvMi0zhRnewQ1ucV+wXn8PfdtWEucpeeNqys/5/sIv6UGtE8r92zPOsZo/bYLShsTNlBMQcIvCxxxKQRoZ6NT7MA0YKCBt9NganXZjA910/8Kt/y82vRP4N0DuvcEzm+wbIb6uNuThV1JYehlXfvjuy7ucO606jKf8WakS0M1Dq55riiFYYV2HOVup+Xfd3oGhLBchZ/6y4M99omEDSITCsWXhtOqCpmTXJYYiN5l3sHKsx+7TPt1jpBRbyjuVF5n4+1/+mvi73w95GxvqWAzqpua6jXimfvQ6yz5NEGTpn8McWv22rAwx7A1B1IbBfLcOhEtpNcLg/ym998x10Lf/zpk1yukyXkZtpHeGFVz1WwQKhZzTDhxx0zlsT0pBNvodJEnSupk9f9Ig2SgJPa6GAMQRG/0i5HP58aA+D8O+8p+zdrkrGSYRhkWfTzA3Pk+elFfvzxR/ntb38jf/3Xf8kChe8eT7R8ajYt3pf+o62MAnAWdrXiuJiT61ZzuLlbj+cl05rgCRlLiAJ+lyhRokSJEiVKlChRokSJEt9esRwLTiWwgj9rsbKkhfXCbQ4P+SVXlXbJGsD0dACgXQsgCFWW/rUx0lQD+BlbUdVZLmXlCkq3WMj14wIwMXUhwO4yT4ScgKCEeo0V2yOg1WJ3vl04/c2Uq35GXCshd1ISBv1n2gVuKuXgIZu25EOhXAFkmVI5qDTZPobaAA0I4E3F+grdpKJTVmCQ1+pqay+y5yo8bRHaiCSvW+0HB88oJEWIAlgyTix6qIpNsyaI+QPvvSQzduMC99N2hXoQJyZVLLxqQwFHh+S4T4C+aZLpNnJL+E8/P8kvn17k6TzJ+TrLda5kXmp5ucHiYpZPT1c+j3b78O5EqNL3o+3CB0XcZJlmWQh/hP64Sgz1ut2KIdXes6KE9Dt25bJ5fScVJZWGsNRAk6sq3UGRjgFV37o6m0DYWs19luEZqxa0plRMSvEMewizayQ/Qp/7jgGe1+1QXseb42U3ig0ucS77vPb5vH9TVruGoySYbckj0jo7UjpmuJoEsjJs9R0ItEvZXa/D4Oz2kZImQQH9VniBPUZS+Psd+zlz4U22I4uAqn+FFh101bIV8UsP7dOUJAuY9LVi3NdAb4nssewLGc166JUcxosp6p3c67zGa3V+R7icYHO6gFjQz/+mT+r6403hXvj+/H1bBs/mP5OMyK/J/ff6LVFnHbIHXzumX75OquRn7kmg+PA1H/C/bVsZBnjsd0yQZLH729kKn2u6zvv6q6/f/b6zeLG1ikkItdYIm0jsmMEihl70uEaMMYXY/NzZ5dasmF/T8gEF9OVykXG8sQApjnwYBp5gwbpnn6GepOHnSJhz1dca1OdQWss80RM+o0uUKFGiRIkSJUqUKFGiRIlvCZZVyallsRCpHBghqcKZlaozU6yFLdj4Io7tyUkOaFviUdjo0NfyeGxoYzH0rSofZ1g5YKe4+tBC/YXnfZsxjgPlL47eo9ibeSpXLO5kX+ATPBCZx5vcLi+qBMSWaqgRTRW4rpOWSqMKEaABHq4z1XLLosCH99hAEd1L16sy1QFS9BlGACJUlVpujPMi87qoQllQnLCTvuuooKUnsCuaV1NXO9ipKul6+L9OCogBt5ywAzKssyzLpFvp6dnqW63xu96P+oXmomfYGk3/VkGhQgUg6DfaW0hPZSWKSP0yf6EP9LvHI3/2Q0/lXHb6zQpE38YP709AFN8KTmuPyJGgPoffcdOyQKH71eI+qNQeJxbCev78xMf//B9/L/+X/+v/U375dJX/8ncXGWFnIj3PfX0apX6eZFx+kctlkr/4zXs5tK08nFD4T6hiPPRH6dvBPFU3bjnXxEAjTYsChJUsKELJ9kKzmSHD7SzN3EoPD23paGWhckb8u1MoiGKEuO4FymW9Oc4KjB886ICh51NltBcGw2vg31zDL0LqEYUDoVie5XYb6WVdAUI1Lf1XW8wB9M+6h8oAo64I1m7Y0Wczl0Chsvj3IG1OL9Rxw2PSw1VfowUWHSrrXOb5mXzZH8cTLTGdo9YyGH+W+KAafrHEkM5L3CPalWmTJcLDrJROoM89bqmANuPbFFn1qs7q+hrNsWgBP+yEwPqhawiKSaqXOy8HtfZYaBP9MHK9QR9S3dy16nkNT3QW7MO8DwklV++bzUIe8Jb68tfa8b24HccGQWIl1eJbNFQyy+J9vFdVRWsOJPg07Bbku993nDgmvfzF/jBv45Sii6+5R+VR9fy1CHYV8SISHNaElyfxfCyq5YwfIsB4UyiralvtZrBTQNtX+18tRaAKrqXvajkdB3ot4yX9gOJ2nmjD3aKg6l1ijGPJrXDMTskhMZMdPt5zO+ruAcxvt8TRS6NVBlXDntzQnRIdEpCyytAO3NcyzRO9n90+Sj38OfOkH45s/+s4yvTjjzL0jbw8/4Zn/+H77/jZ+dMfP8n1cjNoDVhd6Q4fmaho1qu9L3RpZ/JMDdtbx2WyfXljaJUoUaJEiRIlSpQoUaJEiRK/vmLZ/UX53T/7rWZ/U/x0gGDqzPBlXoVwCoqSwhg2GI0qlqkEJhBwJZV96a1iUTrDIclawY7DrdGqKEtivSBapDqVwChbXmTw7OpAV0TmcN9iwlM/pz8CuIs8J3leBiWqqzpV/XvnNWxf+h10+H9UV7vn8d2xowrUD5IVgeE5swLQl2jbJoiD31kUy/vFjj/rMaGeAwxUCMN/hMJ/bm0S+j9ZEUSV975YoN1Gbi8CNAB8KItnuV1vcj5f5OX5Il+eLvJ8HuU2rzKhgCDgNgo7Lnqe822Wp+ebnA54z0gF7DRNLD641AvB8TK3VP7RrgMKQFgW0I4B2+/Nuzi2PSxLCIMBvHHPVogrqRE1yUBbBCsoZgbYSWVK9Xby58a5XNKIP63h726XkuFVLn7nLRXBcbCAeAMFJeWx+65mLX3qh1icb3eaO4Vn8puN4tb4GrfE2M2AXPAuz5E4x7Ia2cGijx3dxZDHYVaBu/J9e3N+7iJdYx5hQVRv9hZezFPVpEhJwQbAfW59zuj7bEynrrS+/5MX4Ery132XfvpxoViu8fo8VlzBygqW6XjZKDs3+et+9ARXesZvx57lurdTjt+3ZW6zP48ad6tkvrWghk71Jd86XhzntoMlXbKrsX3SxEKhIbyP1LKkoWUJfOm7FsUYa1lWTWj+vaipJ1KSRdP2GpsnH2Nls1rML+9QiTsO9PrcOiYnJ/3zMTN1vXdNeLVco8Z14jo2j6PeF2x8aDmj5UV3n0FWCDBZxHjbxPmVlMp33WbXkfuuRIkSJUqUKFGiRIkSJUqU+JaKZcC4DW9oaQPgxb7wd91OrlXqc5Ej9VPWL85QGDu80YDSrGs2eXdq5bsPAwEL1F8zpF839fql2rSV9KA3MwS59EtV39fOboLgeZmo1aUqbNtkAiStYLGwyTpt9NFs4LnpqkzAh0XPC5ajl4v3dwrqABdx7W5dEanovaer/pEAYl4BQ2e5jSNVuVSUVVBMQk3Xso2gQiMwIKzKPrP0yhT172waFPLD8w0VjPMyqicwlNlUhSpZd89WVVzrT4UgVpwK17UABsNHFvdnqmvuJYcnp9o4OGOkhzUU5vQ1VZsTVadCwatjwZsCp41+zrq1X/tAQZwCdjxwzQ4HcU5sD8d5nj4/y/Vykf/8n34vf/jbH+U//uc/yO/+8EUu0ybXqpalhU+pM6hetq2Tz1eR6qerLGst33/4o3x4d5Th1OnrllqmFnBZTby7ZaVq2eEMFNZ0tHDiaABmGm9SL5O0HQ4CRThI9iZ1h/eoqnCjehnK5UGzGqZUVqX9rMXCoGomJNT22lYt+Ke2MDoPWqhTt0qOpyMtIfDc9XqT23WU2zjRT9o2/McabjuI5fYnOWmRlZoRIrm3q3ufuw9z4qkJ5Km6V4tk4t7tUOarUkFly/Hj1jToQ+wSMFU11cg2T9AOLLxpSnYyOSv0Z2Mgbel3Ow2V7+9mlCJyLZem41T/GkFu5oA297EWmC01dbrVJl3byKHv5DD0chgOXD86qISrSo4V/g3ldsv1RYsxwvtW1wa2lyUkHNy7BQWLjMK3PYJh1YBnWwwuK6YIxnHxbNNJhflkCR8mGajkNiDvJvQOM1NxvAirfRb6wNBxyOv0AULbE8x9h5sxW5DbzVb58PvdIEp/u1/14vHu3+MD1+eYFfmMmb+YOdsBb/MjTm0X7tcWYtwLhg68in/zw0dpu0Z++8N3fPrp6SKXy2iC8pCuST7x1l1exI+WOG6JEpI8llnAZ5kWjLWh2lRcwxt4/HeLdFCfs3BshsY6BrzEq7r2O2TW3T+qbO/6AxXx55dRxtskLy9n+eWXn+V4PMr33/2Wr/08PMl0w2eCjllfYzmOoYC33RH4D77O3OHDCzV/b9r8ZP/o1BbBoqZEiRIlSpQoUaJEiRIlSpT4hoplBSPcOu6iTH0mFarz12T+6urarMz0gEK5ayvpO1hgNOQo4xRrXCk4cPbHLfewC3C8wQJ6uawadVlUJWcfYS3SpOoyPAiw3IPXYI2qeLMlRQYYWdnr3qyvUXL2W9Xf8nHwAHRToOrHMHWZFxYjAHMPTw1CDgK7VWZszedzQBK54FdShAdxrLV2Vjm6EtuuKZfwo/zboK8VC6P3tGMQ93CGVYVCYaiXHQ6zAJ12thUaM1Wp+TzzulyZaRFV3OlvxkRxzPE2cpv3l8/P8vPPX+Tzl7O8nG9yW2tZpCcspoCY76MDt4zzLOdlkedhpLoZp7tNsxyXRaYKULyStptpNQHQpp6lIi3AERXooT1YXMsUy/gPxR/JB+HP3MrWWgE+vsUAIZMpCgnxu3apGTazbzOUcqITASiuCYAb/rCdWXWgUNi8oL31oe2bFcIpAZFHW1bVhzH5KoIVxk7V6NeSeiQXqEwgOqg5dW7n3Qp8BjDLgaDXrzP4rPPMkyam4N9Pn6TwtKmpTRUuO95CHuP3z5sSNyi3/f78XhqDbXjQbiSpQStpNy9gpkXUtLn0aAmVpv7MM17HeiZ1btHgAuwsonZoam2W1kyl3xSyUlLt6uXgGc32836I/XUvM7XXqoG3/SVCev932AURm8+TC6/G0+5FMbMW/Jjfgsqp0QKw1qRDMNXZKZ5Tw+2eiTA9+3+o+lfnJPrzeByo9D2dDnI6H+R6meRWqb96Ut/7HEje8/qZQDuWO9bt1kZ6HZrUSH/j+ND1UL3wrUAli4Pu7z16bsfdKjHpAKi8B9jYvXFlIhJ2RPBp1qSkQWW7iaxath0kNne8wCy98318pkaIkN2vwxTtJUqUKFGiRIkSJUqUKFGixLcr3qdqN9gnRJiVvx5r8S79Z6BD9uWfQBJw0AoHHbpKTn0lD8dGHo4d1cWAakLv3YVgrW3xxV9k6CsZhkbmWVWuWTEI0K0wFKrgvmspeJxHHGOV6w1K0U1u11nGG9S6DbdK4/xUVtZZmcjCd+5UQLqFQnP6mo6eyKpY9C3yfvNQrDlsc3iEW6cad1HP567vkopMVbOSVKoOiKAMXteKfsO49qaGdPJGdTF9dbeZjy39hFJWL4XQlyByb2ysimL1KcXz66TnbOteFeiyqpdsa9u2AeO1VamUrnE9sypSaS/RLAThyqZrE1gGAGfwpUpwzpmm+gLj6ASqgLIVoIyw319ebvL8dJWffn6RP/zxST4/jXKF6hgOotWc7E8UYgPo6nPXVeQ8i3y+zFJ1i1xuqxxGVTEC8txw7Teoxlf2H7ab43rws91aafqGQBHQEapa7S+RZbyh2p/UWysV1MtNJ9t8M2qKa3dY1poXS6uMjeNHgb27XvD/SJ2ZNVDv3wB4vD28QBwTIrOCZfWs9gkY/CkMwDmmZPP4VvwEsXw7/OuZrEmZzCbjmNap5QkdtROgy3ICo+rhqkXSdNy4TYAnGBzbUTkc1oqkVod3ukHMVEAsWWjAS91ey6bL59TRqwDNIS6tK4JtBfMeBLU6nrWAWiVD39GH+3Q8ysPxyHVg6AedD/OiySn4PtvYVUqn0JGJnsV0yLZ7gWU2qeR2ywZ6bSQ7CNqlRJiZEgHe97ZmmB0M34fsiV0vc3EOBd2amE8iCRHG1i7NYPA4rmM2bjI2vPNWdisK93D3NTv9jMm2AJXDWfevDe+JYze9xz3Hc0Iitw/6FuszZbZqRUPhf/awRr/rmMjFVtuulYeHE1/y3Xcf2Z5Y76Eo9jZw6BqTGZijTD7aONfndPy5XzwTfZ5TYgJKFfwYJ/BN5k6UUX3KsYbjQdhryU2sF5zatDfxtlcndPUM13mG12EXQ9sduJ6M0ybHtZK+PzD5dHp4EAimsTbgczJZDPEzSMfN1ujffTfPzjLJ5kvKj6QlxReuApZLlChRokSJEiVKlChRosQ3BcsGjOgx6+rMPWhIX9jvaBa3LLO6nu3MrTcWHjoOtRyHRo6HVqpxleqqQMLVvh2sMBqhBzNUzXjfbdIv9Y7UWGiLNhM1YdE8rXJFoaRlldtNi/+No0KGts3XWdftzuc2MVlX4bqaGZuYGwBpVcruRJJ0kgCYzqDHQRjaB0pfVWa3pjhTGGeyzqT2VfiB+8O9obbbKmMNeHwjXO8AzFFUMEBlLTKoqlAHHWmPc9aLq0p0nWk5MV0mA8sA17CDqNjG2GpOC4T0VoDWjqB7BeSEVUijBQmxex/wBIA5bit3dnqv9vOxodYcujFc6V9t97pRXXh+ucmnzxfC5afzJNNayVRtMlewhLAjBTsAOqZslVznTZ6vq7T9IrcRFiJoL8XPgDnVOPLcDpZRiK0FQOZ9oAMt6UG2osXklmnSIm71SGXyxsJukxY/hJWKqZzTg1AweFp7PyQJrtsZZK9g/zfhbEpaGFhmcsVsQ5LafN+eWTsbnrvnfm/NY4PBLjz3+Vrt4Kf+C/OKSvHsAMDwLfjW9da/e39ht+VISmsfkRym0IWbyjgsFn5tiMVUpDk5ofDO4Ziqzs2iw61YbA4my4/kzy4ydK0cDwd9DAeqxIfhoCYN48RkiIJlt8RhFkyhLrtwo3IeVjocBwTQbpuAXQncVqF2FqkQnoL2pFpONiZe1M53TmCOAoQnIqz34AUMCZd1nlbq4fJqPOgpLM2QCjP6Xo6sYdZrDZLqOKZCwm4fd+ri3Xnv1ccOPD0Bku1f7t+nyYcwT5LdiF53etLF0Q6X6WaEJJACZtgzHY8o3rnJhw+PTIZ9/vwsLy+XcOeOdbMRNa/ABikcf9zfX9tQxyqKavrnHNYJfjZZkgjrYVXBE1kTVPVaM2FV+1iwMZ08lvepoGxdQp9krE+dtN3AzwJ8BoBF911P/+jD4SjTvMrlcpNpGa2Qps5Z7iJyGG9e4j7k7FbyrgOzv8iJCQPehSuXKFGiRIkSJUqUKFGiRIlv7bGsoV/6AV8cqFDZSD6y/3aalVEZ/Lg9A2AtoDG+0FMHBt/hBd7EAJizFo+Dh3ILYAw1F+wyoMbCu9WqIdFdKhXxmo6F4PClHB61uvVeITW8jtWvOMMgggRXJibQpSppt64AgPWCXwrPsoJPC7XlbckqdzUvaP8ybz7UgJhaaMnNX/f0T7/iA0xnS446ndt8WNU+9o1t5+avy6fMB9bQnaqZs5UF/jiPV72OFoo9mE1AJOeQC0dr2P5EEBghhNfgIACCatcJ5Tq9ls3SwRVyav2h3srezgDU6sNtFiZ4fQ0JqKmPAdLHRc6XSV4uo1zHRWCPbbpGg6s6/rSYIOAO1IurjOsmVwCXaZXnyyTDMErX9wI+0xI6KqydZrQB+vYqMwy7a6gRF9k6jME2t09Sk2+0xKihSmyhHlf1O6wxTEZpzMwajd1q0NO8xdUhxkCOK3apNgzKfG6nz/2Tt+xHlWfEvtr/SbBogyIV7zMA7/AwJ0KisjmM3WhFYM8kEGp/1qmfYVXagk/fbUrmpTKbi2gl4IDOj8w1xO8v+t6m84YChTuRrL8+zzM6hLDYpyk0zUvDob0XDqSdTjrQqspSJoRgwwMIqG0AsKzQOBb5a2hTwPFM/1pYwYS5a+daqKzVBBSbir7k9ZsFF32eJDVxKiwKqx+zqaGi/85qYoOVC64Dz5nvdkpi6G4AB4Q+hlQFj3MBPOb+3Hth7NMWd+VL71YohPVhGG8JlMaHW8PwUryopR8r+sG/oXROnszml405j/vI0uHsisHPIPOAr2uDsAMTB/jM8DvEkoo1wyaIDSn1n/aEJ/3ZDXgDFCs/Nvhsn29MMNBYHn83VbOvHbZ7hBZHdnlqJ5SLubJ//VkbB7ra6FxuGiT0RJ5frtL3V65bbSfSwRscSuxpkfV6TUVneb34nGLrqU1SnDppM4K/3sYHd6gw/6HK/j+VjCpRokSJEiVKlChRokSJEiX+m8EyFJ8I25hOjIEv6g5gFADeQwJjAPbAN1jAPDwFBW/fQWlshfcACScUvJtlHFVx3NSDHAYoC/GAUm9VQSBtAvAl2mvFoThgy23D83zlMQCW6VEMXgPF7jLJDGsHqOG2hpDZy0mpck3BgPsWA1RgCzKVrizyBbCMu3ZVM96vSjQtXqctMzscNl9OXCKAFW0weJyGRbugvo4w0+kGeDkLGAawTHjnBQprQLzo12vADQpk2HcAXCbmpapp+AtDxUev5GWV6wvIMK51kX7q5VRv0gzGm8h6Wplm/KGSrUPRNhQPBFBecvFAqJa5/buRthoI1Rze0KaAwMWgEtWlprADBARYJrBraW0x3qDEm+XL81V+/nyRl+skM1TS5l3Mewf14eVBwbwKtNdA39d1lefbJnW7yC9frlRpDg8nOaAfSAkb3tLtNvI6MCaYJFgm6cde5AgRcq/FBLHXHONlAgDGmB9VVb70sq44NhB8pzCXYx44x3yVCVlNjTuNvP8ayIjySgM9aJMFivOVhf7QHzouMT7RV+7j+jZY1gSGDnpaoSAFYGMjWUsEmOmQKelJzTbCSvDl9yV/4nhWBW40egnqZj7jSZemlhYKbvhas7CZ9bvZO0RepRAaliGbVPC7Turm/CrNt1g/ExQbgOVLrOimc90AwrRgIDyqdYyjiRQp+v1s6YE5hDWnbys5Hjq1mWGFQajekSTR42HuIeHQNijWWMmM+8HcghWMUjyuHpgP0w02KbhE89xmwgTjO+/iSEpkV/EmCxm01UzAWWHew4IGc84TUHzraokN211B9faqSRMUDvR1yZIczLmhh1skcFylmkjsn7E9iOrjaJsRR1IcELZ7JEBi3pUnFrSc6t1xDIRT1Zzfp0mBUECQ9485OSVPet3yEs6NpAbWrwbrdS2n01GmcWaRxp5gWUc71yb3+LbkpINiwv7a1y1tS63piQQqkmpWPBLzM9kO4ZgKx/ETySHCbyQHeIsK+ddZ70sTSHqfqih2uKwtxiKR8P7GZ9gy0Wseu2qQcOsHkcPxJE03yPly1eK4ebHm58rmNi42l70faCFD8Xu01smJoaya/xNDokSJEiVKlChRokSJEiVKlPhvBsumWE5Mx71Bd0XBAlBO73QgiC/u/idXDepxAG0AU/ngrm/1bNXt7NkrVsGTenNq0SwvJpbxUbRkcN9VhziutIwEzWG4viwUW7Lf0/np2er78qOodO8/oCrhPeRTxbLCr2SVa29JGN6kZeqk4ddrns5h27YzBd6ebTFXj2iXnBp1MRsPhTq6NR/cCp7Gz+ebbumGL/W0SNO3cjj2dl9mOEClsfrIukqPMJ5ezJvU5nVNlXpSh7pK0ny499wwqSldAUjlMqC122Ism4zwEMW5ku1rLhbI/6jSy20NWHOb4K88y+cvF3bH47ujDH3LV8ECQQWmeiS3UcE29rqBRcok0zza+NGBoz7eqlyG2hsgFMpltmci+laIzOwvsqY4g8S8fd9tD4IClY3hdhemuLTkQ+73HRY2qw1X26vNg2/fTy9L3rpvz2NXkufryOrk3XvSOAxw1NWQNi/VUqMmEM3uHqESZVJDhzmXBNIGVe36HXdnyJwRWVpv7J4jiE6neeOW0xT1eZ3uWxNDSEZxbuL6UZgyFFfTwwbaFncKBJ9kz5ylEoh2//EKoyVB6p94lekFrjpGhsIaKxRVU2CrOzVo00N4CbsfKJ7vZKp+B2wzN/wOBPFV36+v7jspa98YTPk1uz9mQLxTIHvRxrsMRVgH8z/eON8b0NN3aPjSpw5DgMuNdD0KYjbSdpoQ088pXQF0+HjiDklDuzJT9NZW8C5tLLEkiibIzDfcxgDXUK4Rqyy136MmM/0cvkbqOhGKz+YPKbOV0mvTpBfWwlWmCclWWG3MXBNg5aRJlGzP4gndvEPEPPBdZe1TL7WhK/szVL6r6VeiRIkSJUqUKFGiRIkSJUr8+mD50KFoFBCgflEdqSC0AnjcFuy2B+7XmCkS1aAwjOQXc1WwomBc19VUuD5dJ3m5rXIdAQk3mXAcqj3NttSKdeErfwtl2brKTF/SjXYHHWXMldyWTaZ0+kpaqIO5RRrKxI4gs0ahOthSoCAUWY6DNvUZjlCXUIHevLrNfUken2ZLkYofOZhRi5B50Z9ABW2lRQXhs8wd8lSsGW6x3e5457xBg4hDNKpK5jb8xopJqRUFixfCA9iLeTUAY1YkjH/AsfH6GhdBtR2U0UgKTNsot3mS5/Ms/6/ffaIq/N0DAGwn/8O88hqbrpHu0FJjOFFZW8kI3wtADVwHgYVCOcJQtt0qLYVzaGcUAtRxgsQAhZYJyLsie5V5gvqykv5wkm6E/riTaW3lDDuL2yTXCVYYSpZrQDZX1/H8qhRVr15hn//05UJYvtwmeRhaGXGM54v88NuP0qNwYtNID/sLU57WKOx4vck8jarkhgczFO/DkWMTbaNb7FUl63i+6aEafDSrAvhMV6KZECNcjWuBtX3YUTo4DOChcJ8V3jIzX4L7ZWWByIfTI71UoZIF8Kqo7s+qTlxnLNKooN0zJHYmt4rwhAMHuQNpB7OpsqBdXca60b83a479bxgZOh8xltXOpsNQkwW7AeiNnqF3AtwpK2T/9HnmPsqYowZ19X5cEax2JworobYHYMPftQBmoqmhOKKvT/4Tx8UcxDPLVsuy1rLMeKaV4fhAULdi/UACwXMygIayaD+mNvKMjnp+q8ctPLtr1HbMbcZCfN6qMXICa7tXf2NsQvE6X/WVhNU4ts5pgmHrt229WVfomseCfYDLth6rTYaqrC31YuPHVc1+LemCLRnFEZ76eR97/XceG1mhnIrv0eVC1btcAAiAXTGNhlL7G/zktTvxNI/z7E+dC85l+umXoesJ4bIn66ieb+Th8cCDIbn0+HzgXL5dR85hrJe7LBd2PVQNxyz97fEqKJTNEgmv6DAKqJSGxZJ7iWPXhq4TUDFjPHLNrdWXnYkh7sjIiSX64+MMuG47hydLuAbQf16TbdepYqLsfL3J5y/PbL93H96z2GQ/9LzPBZ755q2PhByTsTYPsA8kzbWQ8KJFDq097MOHXvp4n1tBlShRokSJEiVKlChRokSJEt9MsaxfTh1BzFaQiGFKqfSfbe/1JxVx+DZ2/ZqrakfFsqpUVkWtbufNbDoW9+L7gSMAUB2PGUCjWtmUYXwr36YFxpKyMhXNs4OFQ+9AVVAa7mpOvbqe/XZ//0usz0WIbpDaoU48i1qpOp4OCuQdGMzetdHjOV9Y3L6+/6mqabQabCBExnmT83WS6xXwtJaxX+R8Hglfuq2VZoB6OCtpfas3wIn+7oUN93yKPZCKb+V+2De0KeSUoJsliEJahdVQnBM3ZUV2UhkaMI1wy0TA8ObGP16gxJ4XeXq6yNPDIA+PR24lh48yrFd0RDiHN8UhtqgTWFo78vrcykNhJohStQAigSx5giSOSuuTMHAzVHI/1ztrW1cOW0OyeBfbQws87gus5QGgLtf+NwdU4U97NvxmJNuaoP6Nt+Fq4p3i9k6FvFP3M/ERlJluuRGE9PnMMXKb58sIJ9udMyhgU1XANwqF3s0MV1ZnBbhjX+wiUKubqoXKNPoV52vdzbbUJVHta2M+/f5a/bm3rfFklAPaAEhTP6u9SkoEJIU7iaAdBkYxvgr6363tzB4onzYXyot/835UG4vg6ezXulMRh/u/h9Pp+egDHl8TLBruF9T77vfXOFS+O036VPH+TOvClhXLSORRtdzSGiWtG75zxe4LSTcX2LvCl97zSV2sO2uSIjqODPusWqrsoZxAO1+vCVRNjDiQ1t02ye/81SCJSR1NTtLSadZkhlsu+eeDr7Pp2P457MpxX8/S4hNHc0wS3PVziRIlSpQoUaJEiRIlSpQo8asrlgeFJ1BcAfXBrkAmhRIjvvhuWniPHrOOpQJcgtqSCt7GC/K1LLaH4ms3KExHU15RtQVYo/7Nvp0YimduV4e6kepYLZSFY3So1FZVcqOtARS2qgqG5ybA0QC/TSq9HC5bmbVEvcxc2VRwWkSpMa9V+DKbvyobwApKOXgyNSi9kwlH1dsZ7UHfz6aWrutZgFBtLPQ4zgagJiZQpRoOKtBOOsG5TZFIn0548mrxORZeMsmzMgq0C7xlVWWrXqxQtk5qJwKAVjdU2336cuPj7z5Ncj7fpJUL1Xj0+r1d5f3HB/nL/933VMYNaH+KQbVomfqIuu+rtmPbdamNEFSw9r2OAcBas1BRCGlKRVMvo73qtpeun6lc7g8XFt6CsridAWUmVXO6LYRBr6pmNUFtQ/W2kNH8ib9sMwv//cf/5Y/y6ZdP8nS+sMDb6XSQ337/kZ6rTQ+7CVVXc3zQegJHRh+79YA+pttF1uks7XKDplqadZHu8EHqdhbpD5TTw+nZlby0MGDfKr5Gv6h0PRfKS0pFglj0nYP26F2MazRbBofPHDcBfFqBRqgNkzbWx4ZexB11zTBcPV/Vozl5LttPehzzcoPi1vbueyJEYa3+DbMdx9LiiH6N8fzaV/SOtsKF9PrmzoYMpx154di6TR9ris8yv4j9v+7V2QpJMwTEsaBWx9zrO/i095ZgUJ/vfjhI27WyYVsE5tA06rhlwsssBbgWofid9imUsdoISIiop3KDgnoqvU4Emn2rF6F9hwkOxWoClIni81jZ+iODZSra+XOWjVYs6hHNdzMBZxY9tPHOoJXJGls7XqnVPfHj3u5OzINP/G7M6Or2Kn32mvgGL2yorfEpkGj8HUwOfs+htqOOR7NHcl7+po7WgL7euCae0J+t1PLw+E7appePHz7K9TLKl8/PcruqhzW3vxhY5oMgGp87q1SLWmC0niDljg3sjoFn9yaLFaqkKNiKbHIXDlzf11UaL9jJceAJJZZNzF7S9J/PbZlbcX+X6L+uH6RqW7mNo5wvFzle4fGuxQGZJEVqweYRrl+5th4RxSSTzz2LhOJ66Vif7J04Dg2458K8JUqUKFGiRIkSJUqUKFGixDcCywCkESyrAlhVVfTfNRsMhZnuGczy8/bFWkGU+lqa7zCVtChghi/He4/I5IXsnMG+zCOiVzKBLuFs/qKtr9EiXISHbUMlmwPQ+03CGYpkQJWtBBRsK2AN15SKTbnHbgZbULp64SSCcBbhw5ZrBc4uLFS1o26N1iNqET4iRVoD6OtUMZy3TmtxQVcGK3wCxFSvTisOh5+gDXYfsCm43ha5Xhc5Xxd5ucJb40ag9su7Z3n/AL/jSr6/vZcGRfWSV25QoSfQoyAe/Yf7Sj7bpgznawChDbI6rI9bsjEuWJwQqlHYhLBQIoA6ChxG4BmLmGUYQmRu4lW2K+xZZrRtJZ+fzry1dx9O8uXpha/57t1MCweCKLc5wbWHQlpJuWk3C1gPMIM31h1gZCfbMhqca+1td8XHfDxZ3+lx93A0jp3cqLkNOSbuvMuz6nCv5I3KcNO8vgEI/Vkfo0BMOmqhROd4crLrbD1OFT9ngrh+Dht/BNzmyW2KzXjepP5NntKmxt0VGvQ8j64X1qLWPIaX79XJd+tEvC6/hwY7BrjmKLBXdbElj5jcwlhHZcpaK2fSAiWfKK0FvnZExSgB7t4HOQXmXr23O+CdGJf2dlElLcC1q329xw0uww6FCQhdU5TJBpWxAdY8ioOSNym7FWpqMi0NyjQWU2LFX78fZDnSGDGA+lbwWlzvfq/sjzsOckIhqfz99W8o5ONa9ApUe3dsSCAgsSVyOBzkeDjK5XyzNYmLpr7LiC52knhiUEeuasB1vdfxnLzEraDrbmdD8Ddm4oTg2C7WagSk9nrDyNjKHKaxm4XwsGDSdRVqZS1Eq8U+fTeAdqFaYejnbk78MBW4m2uekDJrlbBE+Tq732FSokSJEiVKlChRokSJEiVK/OpWGPoduSP4q6WdoMDTL8tpK2784m/fXuFVqSqpjYqvI1SpnRpqjHAZADdxr1VXGpMAKojE990ZBd1mKCOxNVj5zziq2nPuVlngl7xOVP3C1gB+mQABh4MVcWqrdJ1Q7tK61JSqNRSwBsj1GgAc8bxvac5qwIgHdVtz9vhc3TvTYBC2QrPYEtSPgO9QY5sCkV6XtlVdnaPzdmXFMmaHoLo/tWTYFrPT0IJTaCsCBcIGQCds+dZ/67Z4fVAYuolcryOL2315vsltEhkXXKsqbD+fZ/m7H5+ou/3w/iQPD4u8P32Utm5TAb+qgye2+VTXKIhnPqmmwNPuNngJYNzqsV2hCaUxoJ7+RNu77/UsQw8IdJDH01E+PBxl3DZ5XmayOSjdXdqXNLQBVvEp9B9SFNZhz9NKr+2//flFjv/p7+T7j4/y2PfycBzkUD9IS/rTqvqUomf0i/mlcqjMphi98e/zfJVlHaWHF/Xpi2zryG30FTw2TIGdhZehsiPGltsauPewXb56m/pP7eMFylSOHSseRn9qVchrgTiRWut8UV3tINXbJSmPHeIbLPJky2aAzYETdeQOqXZWDH4zDrvzeTAuNTFkqkeOGi0+yGNzvEEpaecgkfMESd6yv4vEGdWn2pNQeU2xtcTAoEI0JdH0XmdjZQsePd2a15Rqk6Gr5Ti0chhUuTwM8KsdOKYFYxVjHG8cRRoq8dWWJNnaeJsyGdEopIT6mSd1hbonGXxkZshPz+EEGL1tIxGHZzf80oO1hT2gdl5nePdy/4H5KJvy2opl6r+xlrn5OxIGyYBIC88l4OxZmVcdEJIcke769XgyzrF/VBXHceNJpHv1fFTGRiVz+FNUMN8D+5D88OPCg96fxJqNXRRo4r4/yDAcpO8H7gLR9Ec4XlhKmPyza/OUxrLBU14tXhAzx7uOJfYphNmmYs7AObSD8+dQyNDnPZJ8Pue8T2JigJ7jGJOVUK2MZx7fPXIXC5Klp+NRd4XYzgNAZ/xbd3foWFXl/z7PoJ7luEPdraA7SBROf0UbXqJEiRIlSpQoUaJEiRIlSvyKYLkSgloocPGFOirdHBoxoiiNsEshGCDWoW9l6ABkapkmfIEHkLGt2uRDKDaX1Zr4Yq9euAri5rUmtJomLd4H4Lx2OP8ss6n7ABhogXFoZOjxZdxUkxAmkodv6loAWNCphQX9a12tmNig+5LeQ5D9BmaFdQ5mtFAWbqEFe2JhQ712V4+ttn3b4VL8Sq/wLkBUfvFXNbI6KhjIcFUaaDruu8L252CX4GCZhdWgVp7k+fkmL+dRxkVkQjsaWH66zNJtM/v1y3cvIksl818ssvVeWEotSKDupMqYcM2LfdlWdIMjrjIGtHXtpT6vQBn2AXif/o6+6WiBAdAH8Pt4HORlnmUYb/ROntHPppTPjZRdgHM/oBCXvua8bCzq1386S1dvMl4n+avvPjCL8d1jL1tnqlv0TZW3tDMzQayvYHmdz4TIyzKqAnzbZN2hKwABAABJREFU5Hh7BmaSqkNxrtWmEICgwzbAHQA+TV5sa1aMEiwmFfweLgMQaQJBx6/yQVV2K6hDsgQ0S71dM5O8B8v7rfav/a7tfOhTU5cSixJCulo7SyiTWYVBTXQ77CFQADP1r4PwIIZlQTQb57TXsNZRiJUG+17cbY40yozVgoVI020FGrSne37rWoG/414ymtOR4UpUB8uw3+Ha07cs6IgHLWr6TmpYWXAdAqCbCfW4y4CKVj2sKvPV5oFKZdgq0JLApNZ8oY6RcHuWY8ttlQGs+zEE9a0X4uSfssqUNgbLZEmsNowNBckOvNWew0GzAUsaBOd5s7uwu5Vn/zMu5GHd41pj/Z5E9Hey9viepCiPYDkowncj8w24fB/hI4Yj17ctWCFAqNDxa9f3tEhqW1ggdRyP6zbtTuSn4trG3TNZDY0CsZi6LXZe2GeSz4OUN7KdN67iT7kcUynr5e7V3cS6sN7QrOa+eXxYYDdPownHyxUFHVcZx5HrJMYh1NjjbbTPFbOWwfg1+yiuK8lC3TN72l+6ZvlkNZ12USuXKFGiRIkSJUqUKFGiRIlvDZY7wNmw/T0Cqx1HTtup1V9SHyIti9hlxEJfS6ooofpqpSZ5hS8vvsQb8CI00WJyAMlQBRMQABAv9tOgM75JQ8UJAD7Q+qKWAQBpaA3u6hdq38aev/+nTffpXtNWdwfn+iMJ+hzgKTDLoNB9N+dplLoFHACccuaYtzyn7e9B/ZzsLu6K+zH87/fFwYynRMWzdg3xsymxN5nGRaZplnFWL1yAOBRAVHsPkXnbZNoqmQBkx0XGCZ7NsypoQxGzLBS07dP32+/di5Wq6k4vkV6+kqAy1M6wlFDFpYIwtBOgNVTL7x5O8rKu8m6Z5TYB1AHqBuUnFZwOHB3k6LUQL7LrFOzeFpGn8yR9e5Y//OFnubwc5HSoZJoOsiLJgTFGxaHeF0Co4ioowM3Xelmkahb997zQ63aDUlWJsIL8oMZUBwL3D/axhL5QxacrwJkwQdFKbHWf/DGprYerDHFtmEEOfcOo0PHs1jGYK83eFzkh1gBv084C77PsEZ7gsT20G7N1gIuJ1dpF7SXcrFYBmyZX1iVDbleZO+xS9WeeZ7FIYS4WubcU8etTX+hAHHfQMUBRt92JTM2sfJAU6zsFyyjspup7BYdIBrkqWu1b7EE1qiqCNWmApIrucqCi1ZX0TgeT1QTGRW0KUTc9CKpmNi2xZvZm9nEEeMv3uy1QtiDxRB53PagG1VZYu+cErWM72brnbZTaN0LFXIxVk2Rh1qdmNyi5W4TuIPJecrzvpt1b7vrSved3rNtp7R19vfdeToQbwNa9k2sW8HMLpJScSp9ZYcVPsNo+31IBvlz4NVs46dhHS+CnqpbVy9/c6PMhveCm7VTQNZwzOiUPYlJOi1+aIr/C594m58uZvt9YH9BXWCdRnwBrxeF4lKqZ5DKZBRR3zBhrT8UbobpWyxurYOCfYDq+2cVfNTYpUaJEiRIlSpQoUaJEiRIlfh2w3HeqbuW2fbNfSKq6JIryjdcaDnXwzr6upcOXc/tSPwJgyiLDsZfh0Mu8ovjcyC/fHRSRtM9QdSsg9OWmKkWofXENgJ+0wpihAIPYFMW3YKvQSH868OfpYZABYBm02a/NYDW/TBNMA0YokPUbiR63ysdMfWaUTq09HIzpF3UUUcPz+MI/jldpq17qAWA5eiMrOFJgpw9Vj2qhN6o5DVQljBgFniotVZjsu9PdvcChhN1HJR3fv8w3uV4nPi7XUS4jlN2rLPC2rrTQ4m2r5LqIXKZVLmco40aZZih1sYUcxecMMeEc5GDqTey8XP20DXJt6Dt4Jh91S/k8KqhperPI6KVuek05wMsZr6VydJD3797JD99/lLltZGoqOV9QTO1Zi0+ZZciIAmsb2srAsqN4QHSHfaaqfh5nWacLi3e1yyTvHwbpmlm++/gg07bIe9iLNNh2rkUCW8JZjDgHy6Os66SFCptW1naSdRrVsoDeyw6VncQa6GUnBK9uFn9Tj+0VCRErNogkBApXXi6j3G43PqZ5Up2uK4NRXJH3bwUnDYlRxcsJCesM3TofQe1epGpYC3M3eb9m1TQvHUJc8wTncWhbkkEek0NWqBEF8ei5jnlTrQbnre9rTfgk5Wey5sD4zIpebxdCZwftVJWqX7YDVVdUoq0AlhUceoLotQI2pX6SB4sqqrF+nQ69nI4DH4fDwMJ9ULgi0bJBmWoVzfT4qlDWQpk2X7EeUbHcZaPeKoJRB6GeYLOCbTtIqz65Sb1O9bMmYVj8E6/BtgpYoqCNbS3yvlDfdE8F+DpiOyY0nxHSDwZ3Ca6bV22k1+Ag1Y8V+LBXldzF/m8Oo3MXRNmu949lv8J5/fXuS59ArLtMB5sIVYS/0c9JPmxewmjvZqPVUQfbkyOKturuGh/vmkNwOxY/h1sQ2XHt0O6rjGKxGO+A1FirWbCUhWNrKw6JYrSNreV2x0w6ZfuX5MGMv/mchK8NC/xpO3hxVl/ssV58+vzMZMhfj3/N+XE8HuUwDLzG8/ki58tVPr9cZZ1194N9wtg+Ek3v6s4IS0DY55+Oby/yGC1cSpQoUaJEiRIlSpQoUaJEiW8AllXZ53qut7+EvrmJml+maxZK8q3lBMJ2jO5+Gy4VVwq37O1WIDCrufRLu4E8h918Xo+lCjJVqhHe+HtMuQUAvhP0+QUnpWNQB/qW76RcBlzNQDmB5eSnmQt+qeele9oaGmTxJC0cBwcGZSleYMkLqmUBY4Y8oYn40rwNmlua07Z08xJwz1CoSLdgJ5LsMvyOAOygVq5Y/O42Qd0Mv99VgZt54ioc8XpkVmgwFjn0LfiAFVR1WmEzQGjd351Uy/ocYCssMxppsW19UNiHx3Ec5XjTIlynoWNiYaRHtYLFrDTMysyw+zxtPsddzmvF955vizTNLE8vV3pfH9+dpBs6qhqHCqBQZEPBrKRRd6Wsg3sD61AUe4FI+3/NN2QIdn99WYnr7Zm9hqlwn5f0wO/pvElRbwmFoAJ9yyUgFdbzsW1y+QjtXMXoKkm/WO9jPsvxbfdkQNALM7qnLD1ceSF1GutrpVv8UwIkrgppHIUCe8m72a9/D8TzNe7XiKjdDu9+JWV11XQM3ANUn1gj4g4AT8y4VUmafKGIZwwvKphHW7y6hPLj4AmWKK42D8Xg4iSn/QwU1LAYUZsOB7A7LwZSZAfsEbR7k+V17m6k5Ea+h9B+3X6c/f9lj+Y3jxu7Ioy3XfFKn7PxIPftF88V295/mhXK2yc2ewsFwVoQ1Hpo9gF4h6hzNdV8pND3u2KO/pxZaLgdBnfZsH8V5Prx3K7irjGDb4y5V4f54Tsw/HPL1/llni2pq+Qb/uDNpPYevGsfPrHOJ62X/LNTrzn14a4Bv9aZJUqUKFGiRIkSJUqUKFGixK+hWKZ3Jb6dqt8vw8V5CbDsDRIUoKiXa28wZ6GyCnpZdVBthlYGU3R5oTnfbu+KZNhg3FCszxWWALMNPG71bAByG4AcwWMlAwoE9r7NHapkfCm3a2GBp9a8Sq2wmO+gJ//MICmBZJMIJyEiLRxM5UrLh0pG8351dSu2K/codkefXfXOpfgRlgLYztwNMsKeYoIyGLASKkXcIJ1sqc/NdcwUbvtW5wW2DOsi4zjJBNsNtes0MORdqtvrl+0mMzypl43gGL7F0CvjoR6bq1xx6hntUssvT2epukau4yyHCYWhVBGOfpsbqGNVMZ6KRAHQme8sVcntQEUyfuLqYX9C52GoMgFum0GfIzQEVK7l/Xff02v5+998Ly9PT7JhrNWVXI6j9HUjt3GWn1/OcoMaHFYL6q6hQ9ABSQC67pgAew/q5edKfnye5DytcvwvP8nPv3Qyrovcxqu8e3eSj9+9Z5/hP6iesw+y/VwrLQyJIpLTxO3nPc/rlglq6UFPZQe0Do8Jhmb2r9qLAORD1bvKvCz0vr5cb/LycpXnpyuV5Z4EIGx2KBX4I2wQOLYdPBM8Zsao4wxzyGYiFZqZNqlRSlT0eoE/VXtj4GFsE0IZdMV87GEtYGplHpsF59Ct8LadTRW6SkWylQGtOdUkwOYJn+TTzt0DVU4o8W8A1VuyzEnYiwX79qYf2Zol23p4+TgXFvvtQ9l9xI6GYUjJK8wnQDtcZLIm4Hi2In30YA9XwWNhTs/K6NgXwYLDf7L6JHZiqEdysrjYQVVKxfMFAirPowgeLSxXDCzTOsaFu1qEMc1zW49yMslWrwQyoeSNON6hrbdWfPjFe3uHscPEUESQb8FIv79ouXGvfL4DyWhvKmoztLWBt8+sWdu5snjn75ySKPrafujk4eEopxOSVZ3cxk1u8GunX7J+dvDzxhJGEfpqAVLt9wU7aUwJ7qfzGgAoZAt7FdQMwNyAmpnzm7sSJs4FtWyC575qiN0URf9nEmMXoZuimEkOa3v1S97ker3Ky/mc7hefJ8fTSWZ7LT7jWDgAa4Mpp+lHzsJ+WGssJWxDGqpp6uOLv3KJEiVKlChRokSJEiVKlPgHUSzjC/Fe5qU/M2tJgC/8JUXykTQfYIJlAC1Xd4XD7UGS85asuuL1GDxzkJdRTVaRJShiauadh2tSzWV1nKuEo3TSwa7+O/yRD1cs5z/rubX4FyBc3ILu0EqVn41UhN1ZnexbsPftl386+4jwUv8eWi35WGcS4q9K6tvoKw1w4T7LS1YsUykXlW933e5t5qptFrEisNEifaldXLFs3rVa/S1vx66blYW2hsOBnqHc6n27yaHraBkxtJ1xMbNoSI+M1vYqUWsrG1L05d6EquUGVh/XSdoaoGaU8XaT+dirBQP9ZnctvRuJue2yunw3vu8hG2v2OWDW9/jvUVhLRSEgM/yv4bGMJIkr5E1puGNr1qeuhN+d8l7xawXHdo4GxvxS6bJQWO/+QJq4UfDMnfvmKx5tYqI6H22oVhK2Drgvw87b2R0n9kpU//1evZyVy69LaObL9Z0JcTSEf+1uDUmpvc1H9HVW79mgWLZ1b7/0+eCCTcxrtau9KCQDFIoSuoexo4pqp5t+oZq8YwKPC6S1ifVlapzUIn6fDpZNQW3Hy5A5qIDjcdKl+zHv1NP3bf12D+ze86qA4U6lfL+u5XPv+s0zRg5crZDj7nbjH5TwJ2smrK8o2qeqZUuSvLo2XzfuQbclRuwa4kqT3mW2MVyX/PPGKubRLsjXKb/bBPGzajupjL1epr7QLIXsbKYwBrCGH/u8zJpIXbdURDN5ZceEFpXKurYkK467+HN9XaJEiRIlSpQoUaJEiRIlSvyqYJlcDArW2SAZwgrPOeTMKub8/Z9aYXs9lVTgJa0CGwJN+MriizIvxwv+ZUAKP9qRajM7DrYBt1bQa4FXsHpeAuY61MRrL+db8lelJ2yCsUSpASkAjAGIqv9xsrC4K4cWVZNanA2+vNqEE4uYbXI4HGXoenk8DvL+8chrhFLOFW6AEIfhKP1whLsxVcsK6wBOcEbcmHp6EngaFKe6Duo3qvigHF6k7SqpoWoMYM4hIovAAWbVm9StGoYCsmaVnqvkKtUub5W83Fb5+fNFmraX21TJvFA3TUVkbYX3vOAZoQaACoqi4X56eNaepO1P9nynSuUsFVVfTyqXeyva10rXtPLhhx/k9PhO/od/+k9k6Drp/vN/kdv5Ku0qcq5etL9rkbmGz28lM2R3GIPwEonWAM47cXNQDsLjmWpckXqEqnyTT2cULxT5LXyNrzcZxwNhbrXVMlVQFsKrGSpZVemyfQFNqXwH3Jmkov/0LDWk3GYBUsEv1Qr50RvXFIOcH9jGDt9c+nmrXy5Uk/Ajhioa4/vl+SI//fhJPn95kRuV6CjilwtIpqGX7Fbc8sQKlumzqVAclbmHnmNgGkfOV1wumsbth9P/R9DrUBne0yhI5t7GlswhoMO45Is0QeDjgQYni/oo68wxf1c7l9vS5G36eQ1RdbFajajYVWGrq4+hjk46XXsb7QcCwHNTCvRhA/sd3qj2T113qqSHqtxaSouqqVk31MWw64FynYpWs21R2xsDgtHKAh7IHGcG6KJlBr2VDTzzfq3oI9+HdQcKZEuyUG6aPH0UKuO106jqZSi2W4yVnh2H+adaU/3p9hcJiDus5g+3bXH1L+XlO2sHGzYJTGsSBB7jQeGerG7snDvNtQN6HYu6u8ML2HmRVEezAYTvKLGqsPf+DApMLbWR1mxNdOh6xp80LPbMhxd+XeXhtElXD/L05YlrMj826heq3VNiwFToqt9Po8LAbk5k6kcXzqee+apYxppUc03v8Whb9VjGOCLMVYU+2r9aUDB0IhBOIxrzq7fxa58pmkzKe37weXY4nujhPI03eXr6kj7DvA1U9zzLKvgMRUICf0O/43ME14H5GHYLEL57k2FMYP3QZFaJEiVKlChRokSJEiVKlCjxzcCybkF3xbHDoZ0bbYJgrsnTcmsu2ssKWBZY033PZoEBYEq8azBBt7+nU9OL1osg4W0Os3NxJH7xD6AJMWF7+7YSBBAGBMmvXqUDC7WNoMLT/Cru4UvYGb7b5U2knQSOAN6dVG0nw9CzyJJaH6CAnaroAD6gpOvaXqZmURjOE5gnsTS5uGDw4s1KNG1rtAHYl+I8V1MTk5k6FHeo9FjVmbl4VbJ5zYJSnhl2GefbLJcRFhhaKBHyO7V7yF662s65rWFj0XY9H4DSXpSMQIkgw7yWzX85+S0DctW1DKcTfW/ff/wo0+Uqv/z8SY59L9NtlK4C8LVCWm6/UUPt7sUQrXdMhbu5pwnUevb8vFVyw9+XSm7zJv2EZIYqAAF9qc5G0bkgz1aQ6kAt9wXgDR4ARoR2Ne7R1N/mn53IbSg+h/dQ+b5TXWe1+zihiN9VbreRFhle/Gs3Bb2jdMDsxKc7+AygX8FJAdYclayL+h9zyiS1chzYUa2/VyK7Ulb9wfWcPp93hSV3frSKlZMOFUrddKqgK9/RLFe7BrWvp3asOeu7y3b1f5yfNgosOWRrSlLVm5I+sVDDnjYpiMlZ+NGSBfdrQFDd6i6G7D/u4DQXS/Tnzd6Ci1+AtnyJQWAAQb95/52A3ZJbgOFeBNH9yTmvcmIg7R6wP8S1gkAbSwGXijzmMtS9UzKbF3g+thmLuLo9S6d3/eefCXv9e9jy4L4k+1Y1YfBbenRPF9g82qnS86BQD/xGLYOwLtJ/eGCiEkVBsbZgLuw8tXfCabdkCWe7U7J7YsRBMKfGK4/lbI7ftrDB0OdqtzAyuxk9Cz45ckKEt+Y5ABZUrJKPctcg+QTro5FJogneylBi04ZJEtD33UBgy0DwVbQPSc3ufZCvpaiWS5QoUaJEiRIlSpQoUaLEtwfLpukCkLvOi4zzTHBLZSwLyBn4qV4DZ0DgG3wmDdwohNUvvvr1Wt8Lz1U9lWJfqL/wBRiqZlWOQXHrZGFRwNjq7/iifehaGfrOCvdBtwUA4/7Na/IKrupNtkXVllRtQcgHONCqMrcx/2b3gqWnMi0VoOzSe1UP21aOR4DUWpZqpYK1FoCjVY6HTh4fTzw31LwKKBQlnE4nORwfqVhEey7LJLfRQQboDzWneWu3QUtX8NWt+qpS4OgSZUIqsy1ge2qlPfwH70+iB/r9qldzVukpiHB13jTjobYY+lj4OABUdg0L3h0fHthW/TDoz+OjtP1B6lYf2j6mRNUT5qJ+XmiMwFFBUQVIUvXy/T/6SxmOR0IRKHz/+Mef5OkLFMsiw3PHv/cUKULtCT/g6Q5jqf+ubzFHP7vFiaY5KhmnTa6wwritfEAJ723YtAaIUtuYCpC0ZqZd7jzeCMOX6SoNkwhVUJ6q9YeryKGyZR+kbe+qpFYlM2DrKm3fSn/spOngK12zr663UaZUbPKNqbgDujp/HNjvXmPwF0pcespyFNzDPdvyH0FqUApzyz/mBeaHQenqFaPCfAKk13vDmHdL35yIygkgBbemzOT/ub2ItjyhmwHgpnE3dr361M9maUG1v0NpNi+8zCtpWlWUHo6tnPpODsdOhiMSIKpA50xCP/BmNAGDeYVED46L/taabp4mc6IfTdmtEbAGui9vIvP+OlW1MwGVVOxOeRUSI+GibQMf5tU2RNAnSP2ZK0zKRfNfXlQugd7QI6n/vPBcUBezcbDQhc7zRAVDleL5Ph2EuxWK3889/PVz+N8j6LdirbgHqrRDocMEiR30+7W/4cdsSnEH8lm1rmNlR4rRrpjDDT5bRPq+l+PxRK9uqNY12WSA3CxnOLqSQtmU8nZtulMDYxYJyoYJKNr+6FTHRgyONfqFGyTHRwoKovKTo8E6rbtZaNXtimPr/cj0dbxpMgWrA3z64QE/9I2cDkhGYmzCK/8iXYWCp1Dh2zyoK/qx4/ONzcEdFJpq5NlsTOw27bC5zSajKJZLlChRokSJEiVKlChRosQ3B8sAjwusHWaZFi9G5tutAQj1lYaTVQtcAWyuMhEso9gdNqkbUomAmcrXyDr0fTg0ADMxHLcaG/QAKDCIhWuDdQSgct+hwJiqFUd+rdZic6vbZsyrgeVWEYopTgmjCKTN9sHUkFVAACuhsm/Rx+uhSu75unGdpJ78flY5HFo5PRz52r5XT0wKqDfh9mYUXiLoHGeZppbXxkJPtPGIAMtlxaZHplJZ74+XAu65Y4VegMy2+VP9powKFhDk8m7lYTI9F0+iv9C/eADE8GHF5hTy1dIB0p1OWoAQBdBQwGp4kLY7sngfbC5UqaxAttpgPaDb6NVewO4rieYqbvUHXP74F7+Vdx8/qop4vBEC/pf/9L+wbfq6lXFbpEPtPqhvMa5qFKZ6XYCKCuIkxDNls6k/x0WkAVy+rXK7oTCkQSUWVVSwmNTpKMBlhfZwsctUyTKOHCPrfJN1QXFG2CZ0qihn46rVBYFhrYmNnGxx9Wr2aW77RrpDJ3UHSmVgGdvmMWYJi/W6ExBKvewIMyiFXY2blPkKyJi8QSG8VETtDdWoebmGvyRVMPrGC+tl1bD34WZFwnQdAJhzYBcVqEl5b7YRUSnsCnt9nULnhkkIzOuor8wF29xSRMGyHletLQzacWdALQckQgaFyv0AVb0XujMPbNoBKPREQgBjALCOymLMESjNvTgf29h3FmQPZV//VHkaZd8hIYT30fsH88JBbVoJFQBasVNvXs2mASzXssGuBfCSnNqV5EHlGxW/SjCTnU/03sU9pTGwGwZm32Hqcj7lNhrJO+XeruKNBEVcuyxBk5JfsBxBsie9RXdE5Bv2NnU1bRztOWm2Vy17gkrHnXva1/DRqTC/OjkMBxY6bapWliSzz77fnP8ce7Ae2e7Gv499LULrD25MINTVhCj+ne58qWS2zIoq4MG68Rln6RF+jmgSRQtI+p6A3KL4O5TW7x4fOYYPg34+onjoNF/1fJbTomUPrZPwmWz2Kwa/U9NC0e2qaCTlrD+8YG6JEiVKlChRokSJEiVKlCjxTcHy9TbxCzg8gacZgNYsAdI2cFfK+bb5iE30i3IGYLn4Gr6kQ/kLeKz+noC8+3PHzboOEf3r/+4F4au5gl8oj9VHEnC0cf9Kg1x4EB2bei4VQLLv+gksE3YpoKPGmj7M+oUe/s2ABy29eQ1ONi3/3kGFqiLVVNiMEBzwm76pag/Ba0zQxm7JtzUn6ObWD3vFaNJwGnWh4weLOSnmmKdFzuebXK8AErC4cN2q9ZvZGxCImIoU50T/wipCYU3edo8gj0uXal675k9NewuqfP26dIs41bvJj9ig2r3VANpEajl9/CA//PXfyLRW8td/83s5nr7Ip+eJzy3Xm2zbRO9beGu7Ij5dSgBaho3u2lUfaIfbDUrxUcbxZsW9HAJqvy8ztqAzNUHwRxuMJT/oB2zKYwLRAHS9gdLYcqBsoBq+zmhjTygoL1UlrSdC1K/W1a93HW59l2/cJcL6Az7otytSK3ZuqvbvCFLwpk7F9QiKDZKTwVKvb+r9MCd5T9q+Ct5tjLomNSiW45h21WQQRyfVefbkNdsHe5GrOrOX8n7qJzTn8JsKU4NzDQAdvJPxcHUnlJ+qTKaSG9eU/Q2s+9xGILVwIJpZvawqWoP4O4uBvGMjq4M9yZCPlm0s8pYPQMHVrRXYVlnlbatrSDpFdXQYJ3yj2T+g/93KJNHwJBcPYyjD8uTha2OZ0NvXqbzopB5Q++xY5dB2D/CcmmB5JXZOc9SSZoygELdimbnAoRUjNDiu/7ZERloD6nQt6B1YYKA4aAernrYzVo+9D3b+ZBei15J2O4Sxl5IqtFTS9d93j+Rxn/vRLYOwDqcZaskL7mWgf3kjQ4+kJPyXZ51Dlt2D97wmc/C5e5VlGeV6nS2xou0A+yHYLcE2B4ps8OTbGW7LY5prcT45wPY6BVzrc86hRIkSJUqUKFGiRIkSJUqU+LZg+el8Iex6ud7kAhg3TYRiO8Qbdjm7zhePlvBYtxQ7C3G9Hp6DLzGUz+o1akWGDFokb1tT20JQxmOHHdsZxPqWaf1i33adweIbjw+187LW0hhk0+vU16riFEpYtZ8A8FXlcy0btpAbYAWUaky5iedhhQGF88tLhd3zcoACdWjldOz5xV+9crHtHphQ1ZLDMBjsmAgIVlY0fA2WAcMd3uMZ95OFbDGy9KhaJaQm5GzYlAASv3x6kS/PF7mMM9XmDqsBPtRGWCFS7W2MYonjJCOTCXhe20GvC91kdiLpvLqdv9J94aa2nGxcuAd1y7ZV1TJeY/q+ZEVQSdUP0gyVfPfXfyPvPn4nDx+/l9vzKD/BEuPzVYb6Z1m3zwrKtpYWHYBvUPluO7Cjar1NFlO0523/tFjZNFHycq7k5dzLy8uLwSgURVTbB7QpdpVjazn18tUsy9zKPM5SQSmNgpPzJE2LYmzuCQ6YZvCHKmFXVLpS2UDyvNIvlUkaFvXydoaXbs0kC/qeftFU5mof3yuzbTAkS4BYSA3F/+ZR+wDjVY/jPrP3hE/HJ7kWbS1W2Zpsg6DzwPs6WAhgTNOr2mXzPp/uLk8valesz7nmLmnEJIEzUiuimUwVFI6pXcAd6LV387wAyYR/ak/Qt40MXWOJHqiYW2noBQ6f2k6HKtSeGIum2ubRkDgIcDFCZRXJukWOkzlXXGtf5+sywGxqZFWz+66B2IWqhua8ahrZ1sYU/3hPTTsSVeM67FVo69esENl9LjzzY50FWwSoZF01bVNu1/q7PrNWT9fqc8grfoYEQrrPkHjgE5ZEMhUxweq9x3LwZNA9KfhdbTN8cFDdSwsiW5zcT9/HkSVNksrZ7Ei04GMl3dDL6eFBxnlhwVRc0ziioOacPbMtoZkSi/umUDuZumKhPjwPWJ37OyeuHC6r57K2dVUB9OLeV2mqVVpYtLQNIffD6YHve1qemWjSzy/cQit937L5v6Bgnywy3V5kmSfp4bnctjIMB3k8HbmL5MOHjyyW+vI8ye1q9kB2K7opBAVFdb2FnRTzA5aHYJPeJTVLlChRokSJEiVKlChRokSJXx0s36he3WQmoFULilzgKWyPdjBkf4HxBb6UO5xyWMP32RdgWkDg4Qq9vFvdimQZRHBVsT/nX5oNeGVFp8M2hx+GhnwLPG0tMghQAOin2QOGpEg0WKG+m2EzOK9HoQidketGugaqZRRXUqDSoDAcGZDdWFJymt2GqeCgvNVTRdxm9xrUn8k6YAe1rGCX2SwAWAJiQrV6uYxyu0EVZwUU75BQbqH4S/BpteJ9qfiZwS21QMi2DnG7ukJv90dVzp0VskFhG/rFlZeAJO1wlMPpUT785nuC1o8f38s0jfLldpXrBA/ihbYqgHi8pqBkd3sESzGEYowZESZ9Ir23DfhC1U6fbPTT3oI21/AK9iR5EGSW5/2SiqaZ2nLdKyGhNIRS8YZEzfnGon2A/oS0qW+CEpaAy/yjY1ftigF60Ta/DFVMkiOpjfK+yJv3cbqv3eWHyZznmAqwrX25ld6sMNxL2e0qnMilpSGPsJwTikXS8lzMRfNMmxug9VsjVl8T+yd2T54nSX3K+QZFaPT3zXYiqgzNthsJo2Yh7as+yJEVsAnYuoe0gdmd6UHgy3neAXJrYoi7D5LNSVZWx4J9WthvP5/3F2htl/o4/0xjzNadqLwNi2LuxiA3T+N93x27dth12e6Yb7w2KsK/diMOQVMm04rTmVLc54oOJ6wlUCx39Fru+47K4Hke6V+uTXIP1fdtpgUBcwFJZfmuBH79Nv93ni+uoIfaWa2hdCxhxwISP9ihoGAbkBgqeuxogaUTPlf0ORTum2mDgaTU1GDdGBVGw1rpMHAHBoA3d16w8KvPL0/Q6r2mzzfvf/9MLVGiRIkSJUqUKFGiRIkSJb4lWP7l+ZlfQC9QWNJ71woypW3AVngOv5hPJWwhhraWHuqsYeCX8TML/rkXp8g4LvKyXdXXdwIkFBk6or2MkBIbytBMEYOqE6EkIw4yv0iFiVAau5IVYFRVsoAwUCfOXhOvgYIRPrkOEwEedNuxSGdqXLP2wP3xkbdmU20Kz0p4SG+rDF0rj8eDnA5HKpMJ3jYtYlVB8WYeyoCKuJ+OqmqFArigeYPK1qmN+78aAOApAXOhOlZ1N20YCPjUZgEqVdg0PL1c5TZO8oc/fJLf/ddPtJK4jei3DG9cV0kF2wb1uG/7BnjrpK6wVRs+wih6Nkjfn6RtVaENyLncbiygtxxGWagAbTM0A0yDdcQ0qb0JtndDzWx2GRkSGegyjMhb7w7Stkf5+Bet/B/+p0a+/PSLnL88yR//9g8shldtkzSXmoriEYCWYka0Aba4Q/KH/sQYMm9n3q4DfQV9DqWWdaZ6+Hbr5Hq5ytItcuhxmY3as6BIYkKLAD8tkwaq4o5b/4Mymp7Uq2y4d6j6p1mqedFt9D22r9/k5XKV88tF/vD7H+Xnn7/I3/7+J/njj5/k6eVintS4TC0gx8JvplxmhiKdT8G/7gQwdf22UvPp+NiFnhjD6B11OM8wWdXu3u/50FroUaGmjocMfQHAcAyuA9PCApRQSNNCxZTPqC9JRW+yvlElOQ8fEyQJ4PKC1LeW9BrQzFs9u/YqXHeNdgbIrrl04GelOZG1Epqb8+dC9TaKaeLRdr2qu5EwQVKngj0N7Gu0SGdKLETuhp0O8B1wUBttKZL6Frsv5gDqF6rbzSmec6/BPGIjaWHCaK/RdJhzUFGPekTMF1s7vJNgBcMxwt0GbjgUFNYpgeTz3dZEjiWMDxzPEh/aGUkNHOFqBOnpeMHKhEpof7FnCuN8iJw6hd7n3qom/1S1dJhSVhRPNv24YhtyQYUa2H3v3YJER4vC/oZr8PsP7zmO33/4IH1/kXUdZcQ6ajYafEdQ+HJ0mfNIvUDhC79zfNZgPgEQV/l9sK6g7YZaqNi7k51Lx0KrqrzH+irbzLExLZWcX646U22OPB5P8nA8SdN30h4GmeZJPj+d+Vnx/DIyCYX1dJln6duTfPnHTxwrv/3ht/L+/SQ///hJzi9nwucZi+MuiYP5qclL9h/ButUteFUwsUSJEiVKlChRokSJEiVKlPiVwfIExTK+xlsxN7UYcPWcKWbvlIm6fV3VuFDvZrWlf91VhSWsF2Cv4FzDDhzEbq5azoK5CMTU99KgilszOOJMRaz8ENEZ09SGppRNCmcCh+xX6mpoVSz7sfR6aeHhZpXuGc0igG5bESxU6c2JIkvgOWb5YWrlpNRMJMtUqvuaanpmg6MZ9DjIVegMBThgKXyV/QGAn6wHokrSWVISpzpA1GJ7Diz9d0LlVPzMles4tm1JT21jD1PjpbZOcnP7tylVTe9sz2kioB0O8vDuvazzIh+++yjT7SqPDyc5DYNc5pVJBcBftTLRYpH5ppJ0MOE2PaVjyaCuJpRW5TIL89EmJSpj43UHUhZ+5G4zSGZjIqnfw3zhPAKUnRe5Xm4EzNfryEQAikuaLDipSSMyjapkV0T6332cR1Cb505Wa7ufduqGqGCOyuJwzd5+ezGrFv9yCE24b9ecCvUFlfh+DOeCcj6X0rl2sN7GpCmOd0tEWguyl/WrLopWKPas7xJQNb56au/acDfhzB4jqabjTeRrzoUFfW56kc8wV+3vSfkbV6KwE0PHPz0RbLzczZmk1n3dvq62TkTYjrsXDgflctKOW1vdjQc+4/4nd2d61av3i8q9mjlw7rzovPH2nXB4f/+57fw9d2rnqJSmz37DAn7wWEbxRhSc5frMQoi5gGl6z25e5PV790gXuF+Y96PG4XL2/KZDERk4Tffpi0wvaCuO2cAPuusIi3HNOpvUQ17nWcUkznybZcQOh9soQ11LfzhwbYbSGfYvy3K/s8EV57md4lpdokSJEiVKlChRokSJEiVKfPvifQu8ZEUmADiq+yCUsyJ29iVb/YlBzaCMUoVm37TS4UsvFXi2bTl9qTXvYpUVh23w+lC1r36h3qkHqYhUz8vj0MvD6cBraKpNPSj7nkd4ud1kIkhYpanxBT0DTIWm6mnrKjeq+bYIRReqm1XMHBWBqoYDDH95fiE4ALDgEQgRGqrGXl7UlxqKM7zrcKqlUdkaH/h3P7TQKdNblUUGl5lKUAIHU8ZlrpIBn7IogF0o4MyOgNu8Z5nGWT79Al/lq/zy+SJP51ku8BV1SBVsI/x+vMBTRz9aeHwO0neDNI2qlqumkxr/buFPq0rPZdTicAt8m/tZaveDXlZZALLxO60doDLOyYAc/osj+6g/raTuehk+fifVMMg/+5/+j/KX//iv5LYucjgO8j//7u/k5fZ7qatZbrCxICxRBR6L7MEiAypb9p0lOJiIQB9pgUg8YNGCLfIYq7APQaPAw1Q9qKGgbRVq0qO2UssW9I8pxHm1NcYqMauqfaGQZR/BVsOlkOpp7WMLAtpxXOXHP36W3//+R/np58/y5eUqV/hGE3pqm6X5Yt7hDoR5fzZfYmIiQX/jhgplHTibqt9zHgFKa3HNbNmrHsqY6xiD7j1ux+ZBfH7q61jM09oMB0DChWCMan30S048KMDMlcPUQ32l8j8rT3Xgq1K6YXKBauGoSsZ1p5mJPtUbQyFN7B7AlCKwC2psjIlkhdG2PD/6WH92+jcKnOG3jXVCfcGTqPPOxkV5pI9jtbyAOpRz2kc17XBwb41NPVPnmmKWKtYAi7U4He4ManvcsyajOA5tU0iCx0xQQbGMsQhrhZhsUu9r57f0b46F/nggHbPpsa/TmWCuLtH7hMr+33eWEgnm38uVY/LC53oEw6E4oK9OlBNH6w1LZvlYxL/Y9fjLbANcH/DZP51OvJQffvhOLpeDzNNFzmf0x0v6XEvFK1MB0MCZza8bSn2oftmVVCnba80WCE/oXDJ7pGrVdQa+/oLEaqsJmAX+/a189+GRi9DMHSirHA6dHA6tdMdBjo8Pcr5U8nc/rfw8OBwfZOgf5Pz0LJflhdZGv//d76nC/qff/0bk2NBz+ctxkG0bZZqyqv7eLseTU9zlkhT3d91UokSJEiVKlChRokSJEiVK/JpgeTYIBKjs6lMFXVY0i9BStwM7zHLFMhXF/DKbVZzpC60BL4lf6pP3oxYVS9/4gz+wAyNAQMBklt/iFvOGijQcgNt8CYhUgZrxhUvP4mb6rN7Ta1KPXOU3rq5zUKL3DpZwm0ZeR7Uu+RkAyBWqYRRlg12GQqrsSZ3VpoAVaB+3vsD7FBpGFeWe22SFrIO5DKoAPaGGvVxHeTlfWbxvnDaZoQi3d2fF8p4mqPrbVH60fMCWfAXwLAyGf9vvhGg8phdNzAWwErSAJC8VMrtTLFuoQ4XDJO37pBrEdRyPLDz1/T/6Czk+DPLDX/4gz58/y89fXqRvGpnXlQUgkz9s8u3VNsQ29ahCTD0dvE9VwYo+g88yVNAYS3bdNs690WKRyKicTBpBU3ArHIyQL/aX8jwo9c/nqzx9eZHz5aaKZe0IblXPQtEI3u60p698if1qdr8mW4y3AJIDtPtjKQz2Lf536lIbcxy1ZkXC8a1ZJhsDDtvVCmZ3hOidHryWs9hVBzwhHq0NFOJHJaaXLHQuppBfPZLJ9gCLHZYnxbXDYvcPV8U7gKuOA4xvzFdvFPNE4HyMHhD3Lehj39TKd+113z9v+ZCnxBUTAXZ3VONrQoD1BV25bu22U/AyGaHrSCpoF20OYC2SVMm+IyMXmFTVsl7Zax3rPVR+Q+n65viK43Z/uATkPTEZhrndpQFzg+AO3/3zYqf69q7R+0/JqaZWj+UZRfyOarU09PRZvtLSxc/nK+KdFUgYN24Xk1XOvjbYzpig2E4pzLDWq10GKuap9/PheOT6eh0vTCi2LUB4zUKTuMZphhWKfo517UBl83QdZaxv3Nnw9PQkbd9zrYbCGQX/8LjekKh43d5RkR0136ntS5QoUaJEiRIlSpQoUaJEiW8FlhP75Z7eANViQT343waVIV9bwXJ5k9s8UcfHQlPuTysZrrKulcFf+C2D66mnMGC2flVXs2VVS7YNlLWNnI69PD4cVF28zNJxO7DabjhowN8BA9tmk61vd8WKqKYExIsCPsBdguAsCiScJrAC6ACtAizb5HKd+L4OKthK5Hy5yAZlIPwsVwBYhW64x27oDGgBftcyjld5enqW8/lFfvrpD1Q2t92JNhD0911VyQYfYBaZM+9X+jqvUCdmgOXcErASarXLBb6ck1xuC9uT264NjsNn1+G+wyQglr5t5f27k3x4/yAf3r/n43g4Ss9t5IP0h6P6KLe9bGhT7u1Wj1WHtfT5BZyD0pfAys2sI+gw1eGORN0D/hw41vHj99IejvI3//v/kcD7utXy06cn+fz8ItfrTW7rJktVybzVMsE72RWhULuzAJreLxX2hMlmd0EVsYI4TZbo9eG/ZR7VHmObpdlGEkr4VwPeo5BWM07StJPUNcaA+XgbaE/WAg723KYC1zkvcjELjM9fXuSXL8/yckVBwpU614ziTLkc/vtaGykAowlzLh7oBc0MjurY86PkQm25xCXGMIpOZg2uJwnUDsVBr+Eog2YsTEY1p1pA84qTMtLP9DqV4UkUqswJZT1RpfMbPzEPtNBmTsz4sPE7bH0d4XVgGdqkhlocuxparBOdDMdBDqcj7RBYJI27AcyOBp7ZuEq3r3FIbOsdYLMWFk2mDCmBke8pPOdC3ERJrXIiwa+OEUqPbYeEJl/ibg2DovgBywbMJfooV5qsoTcyjoG5h+eDihpjmvY8+pMJlt0cm7NXM99m18EFNxbbc2Ww7/GIN3c3l3dU8q2xGTotWU7cPx+KSO4yIHn8JPccJhvtvnj9HE3h/Pk6AJabvpVh6+X9+0d+Zjx9fmTbXs6Xu+vOyRUkG9gfoegm5w9232BHA8aY26nYvPbmp/rd/K+xUweC8q2Bdzc+izDeAZAH+cu/+C3B8OX6wjXdd/Ys8yRfPn/iZ8k8zrLOm9R9LW3dynA4qHK/ruXT5yfp+gPnycC1+0Eu13cyjps8Pd3sfuzuAlSOPaK+0oUqlyhRokSJEiVKlChRokSJbwyWEVb/LEMAU18CuihkcpUhYIwCJ3xlhcoZW46TDUVSO/u2dNVyedV67OKNBftYCzD5xypo62Bh0LVyPKgVBr6Yz5OwCB6++CsM40VSiYrt5rAfWFctcKbKXQNfuDGDy3iKW/xdKe3cJG1h90Jpartwu6odxNYrsDxfr7IQ/mrNMLYWLSY6eXz3SB9MFsDDFuh5lufnL/L8/CQ///xHKps/fvwr6YeOABPKY1iKuIcxVcAAzdg6vcA+w4SUFlRoz1BIQ7E8y8t5kuu4yrQo3FdVr26t17vxIlR6g1B/v388yYd3AMvv5P27d3I4HBQs94N0w2BF+FBw0HR5LLAWoJjBZbyOtgiu5Hbl527ffdj+7qrgpFrO9wVV3/HDRxkeH+Uf/ZN/KofDUT4/X+R3/+V3bPM//vgj+fZk2QGMNwXL6EdXaJu1glliJBGqKQJtZCY1IpAlimfN06hgWWb1XwZYhrJ5ngmXWwDmBnC5VUsFiiytoBnHkO3hd5ElEgTurXy+ypens3yCYplgGXAct69jnUXsdAAFqGzzz9o753hcL++CXwf3uXmj1YUHz2RwWRW/VjyPE9DAMiBesjIxP2G7PoVtav8SEzbBKCBrQe/BFoG+K911PXH+BfUmrwNFyBafyzuhs74vFAeEb27bmA3GgkJ9QrscrBOHA8DygWAZyk+AOU+IsRihFXzUERE0uyTVtaBjUiImvegtLfd9DiX7/WonuA+J75bAgqOAWEmydpj2i/ZHVWNcWTlTFAXk+63f6I1hvi6uqjalMhJTSYnMucgbskySmf6mBcR7KtgVRQnxjpqHewvO5rvGiI0QlLx5DbiP14kS/7m9eklOqDGZYyD6TZ113UgDFXC1yfv3D9L3jbx798j3//Lzp3RgXxOTutiSn5xFHH/qy103G+01MlgOcxCfG7DtsOOgx7gbBU81ap0EgI/PIIDlH377A8fky/lZpmmU+XqTGd7J8yQvL89ywe8jElmqqkfC9DDA9qmWZRnl85dnOZ5OmhTsWnn37iS36Z18+pyBuc6RMPRsLcyppLdarUSJEiVKlChRokSJEiVKlPi1FctBEbbf1J3/UyDlMMy2yat/hXnM5i3Q6uGqX2np1UzFoAImAFYVe5r9hhf2Cz69AIptUiS2ih+2hVBZi+Sljcp7+4L9TRHeiqhdhsnh+FoCAVdK836SY2pS+zlcBZi8jpPacUBhSCEiscLOFxdemvTxpeHBLC8vT/KHP/yeUKHrG25jhj80wPO6jgTEKmgEzFRATt9ibfid8k+BuMg4zXIbZxnn1ZTKpgY1cJ96kA1TU5kN/2CoPAHlj4eBD2zFVu9hLdYHVSeACnxo4bGM941QUSZrBwcqBqvMkkDhvaE6Sr/rtI0/AyZre15TgMq77tICgsd37+i3+8Nf/IX8zV//FaHg737/e46f8abqVvXdzYBYrVZ0HHIzOsGRQjtPPECZjP5xG5OktnZIB7hvClAtkghAZJYf+J2DE31TsW3YFlqdS8cfFbJapHKeZqoloVi+3EaOnRnA2iBgKh4XgCQTAGbfEP2V81gOat6dUlZ/34tEk+Y4+1ubmtnnblZoeirJ/YnVWzkeJ1lMmHLTYVYsmulF8hxes0dt90O2+8iWA7RiIfwFHA2IOniw7weJvkLtCmppq1b9wqG0H5AU0eJtTWcmxYSnaq/Be0SCxHwzsiVIXuWy2NaLJOb1JBaB3C0uvpsjrJopwWIgmPcXIveBqczxE4pq7MJI1jxM11kihumh/WRRY2CD0RH8hsSOJzw848DXezJNQXdUK7vo2puCdxjVzf4zAfOoOrb2tjEZC+a5avzVPbwVrqqO4zcVpbSf6Xj+El33qKqH1QQK5NEyAuuYFVllssjGfODoft9cwyzjQbiMccl1RJMfnmS0dEBKgCYfbFt78fmGJbPpsJ7iI1F3nzx9+SyX64XrRb1uMtInH+u/+nSzOKndZd0Ckncy3mbu0uBc5LlFTsdB3r17kIPVGPChEMenfo7XwapmlyMpUaJEiRIlSpQoUaJEiRIlvpFiOXl1xupOBojowevbh10VpY+JUFeBGqCNb3EnmyLUhYIXRfRWaejLu8kIqwHAZADmdZMJakFnIYbCOnzBhhKx76hahlKxbUSGvqMdBkAhi3ZJpWpdA4QsgJXAA2wjVF3sX65plQC7CEIf83tW7G1gz+CPQVB48gJMXm4X2daRlgvAjG09SNcMydcXrx1R7A4cY22kWyv58cc/yn/4D/93GQ6d/ON/8ldUo3Xtg1TVgcX/RCZCSwAGVSpP2U85KtDMkxoA/uUyysvLTc5XFOwDXE4lE+lTiwtCISkAi2UD5DVgvm1y6DoWlPr44VHePRypBAecw323bSfDcNTiff1BZtxLq8o9wtQV9gO2NZ0ApTXIqvYM6kGKPgCcNSqfZcM53tpFzz+qx/PHH/5S3n/8Xm7Xm6zXs7z/jw/yX//Tf5bPdS2XWW1DZsBXFjZ0cGXKPBbegqc1ivZpQS60DEAv/o7+AbSFJyvfQcCG651lW2A/okpygn4oyjE+WcgPr4VdiQKuxfwiKkjkqbhfCKbQRvO0yO1yk0+/fJaff/4kn59e5Au8sOdZrWISHFMAzqHjo54+wGqJkKYlEw/Z3zfrg+/BfQDN/OnAXekTFNcsMGcJIVVnat+7mpnJnKaWier57ZXCc032LSaI3bFOXLu+FsdQtghluSat0J4ObZHMQGIDa4UqQaEWj5BaaZl7hSdVpiW3YCdwaDs59q08PD7Iw7ujnB4fWBANgJk7I6gc1USCWn802R7CVNs7lbhBxTyhDXSmNdFbOSt9NaHlf7V0R9fZHDHiZ0UdU3/ycLqeeiE7WGFI10uFtl1vthb7uTFWOyvKaW0DhbMp8JnMcRWztbkOBbdocS93n4cYswZwee2ueP6a0PhO2bx70R38Twpv88vm2rHs/hazZuqt7J7Kro5GEU2307B1wRXLlgBKMNcvwwTdSN7J1srpdOB4g0WK9qkmlwigcS6v/RcQNhNiyElgvHSqkkbiIxvLaBLVLVtSgoXXbFtwsC62Iu3QSN1hvR5ZhO9v//B7+fTpF3kcjnIaDjLCKmdEooulY9Ocxf1hPW6HnvMMzyP5iP7CcPr44b0Mh1Z+/7ufc7HKBJbVQgrFOJkWscK4et13a3CJEiVKlChRokSJEiVKlCjx94z/dd8ok/I1CAj3m5YTfPWv+FmUuN+Ln4uG6e8EpLAEoKewfwHO53PWQ/sMqoNV+aeKy90+9Z2q08GxQ1hVmhoAcO9Lf7jK1M8di6+FUFipqlXA32lW+4qsXHS1aS4kCPuE8XaT6/Uit/HKIlL4an86neT0AKsMeMDq9vaozdttWE4FpLJKT1XNCjypbE6F9OJ1v1YF+gZ/gAgAv8PQyWHoCVdh20FVH0Fg8P61QmcNVKXwXAbsdE/StHXeLAroUatA7g73hGvyrf2vldj7a1U/3KbruAX8w3cf5cPHD/Lh/aM8PpzoEd0SnGuhLGIfKyKJJENrQLlrFSaasF0Vx95+9sD423lxJxWivxZJEPvpY8nhLv+t1iXoc0LoFQkM2Geoah2FHcdRx8zE580S4q6fkyJ9N1/28yapXIPaMl64YzpPpuzn3l0bp10F9y3/tsdz4qzpTxmGpxROLJpoa0O0wcl+63lN2c3dUBSS7ROvO6qyKThWKIzkAUF419DiRcey72a4H2D731XRn4va7ed9UIbvKhLeU9fYkNZ2lozKz/kaaf+lp9yiIniYpEKC5o8crtmVwHG18EvQdVhtf+4VxG++534E+Wu87V1VH1/7qqDh/Xpz335RzRteF861f28G03fD7O46/1SEwpQ7Nfo9E/c1Ox4ymYnn+gB7Pxnr6rf/lnzNQ19jXbhcLnxgrUiWGVRO19K1nbQonkpVtJ6T6x9rEpi6ngnaVc7nMx++00bnVZwjaZC8+RkQ41/8i38h//pf/+s/05Yl3op/9a/+lfzLf/kv/7ttnH/7b/+tfPz4Uf57j3//7/89x/unT25V87/9+Bbz6t/8m38j//yf//Nf9Zgl/v+3H+7v47/39czj3/27fyf/7J/9M35v+vvOsf+t3FuJEiVKlCjx/5NgGR/aeKiSLwvu9rvxraiSC1GxZRjeq/9v9t4DatYsK+vf9caq+r4bOk1ihhmyiOEv5rwwR1wqitllArNLUTGhy6yomDBgXOaIOQdcgDmhqGAAhXFST3ff8IWqemP91/Psvc85b92vp+9t+goznH2n5ktVbzipun77Oc8uaEBgqj37x8JUKq6DAAwgiEBuVm/bDvBtHKU3eEfE4tuaS6jPKvNLNf9RKspMlTsNIpOqUEucAGpK2zQ+Q7UFH2IowlAYaZiojJ5gR9B1Mg2D+ufi3LiGYSBAhFow2iJouSNAgGHS5+12nVxfHmToXf6sijuA7HGeeE/73U6ur67k3ssvyQfe+165ePCKyKqX9bqRj/7oj5V3vesT5PbtZ6SuNwSobNNQogzbudEPAAvqI61capZ5HKUDqD6g2FMnE5S3UIPzZQ5rTIVoKjrVMasit1xNUheTnG8qecvzd+VNz99loavzW2ey2a6lXZsnNNpqOkoF8FHWst6ey/rsjjTtmVT1Rr1gk2J9KxTOgsIZakuoLoNXbBX8YJfK2hvlyotYAQ7WrTzz5jfLx33Kt5RP+JRvIZ/8SR8vn/ixHy3PbTZytiqlPVZSHiupjrXU0nDr+Lau5Kyp5O55Jc+c17JtVwIB40oU9kOB3XUHwv6uV+gPpTEFkWZHO60m6Qf08V6m7iBzf5B56GScUORvMIiM8TeLwIpk6KXvO3ql7vYHub7eycXDC3n48FLuP9zJ/YudPNzt5XJ/kG4wxbL51qrG2rH/khDpONACkAqRFOjH7ffaljoCV1JCqU3JtM4x9SO2ZEzIxZiq3ArxqT2E29lqAxyPKEBWyXwsZTIZKK+BxeOsECKUxcFPW+c5dhKwWKLtGICiuC4qgrOqhDJZga/aCpg/rTpAEOLT65q+4gpyuTuBSYEwq7ku4DVIGjR4FCJtteL43Z5vZb09k3ZzJlWzlpXAOocV1WhZg+KdC6ccwMNK1ayc9/Ro9x6w5IHK08N88jkal1X4OCMpow+FwtgLgbY0NTJ/1nvk69WUl+eWqg1/52vgJVzUUlStrMpWCnyPc7BvfGjYdXmb8D7inFPDBIP4RyjER1kdYc0zmP1FmhQyv2ZfO6h4xXWa53vi03tC+qOyODxctTuf/N1/9nXAr99V02o94/cTDh8ge6Jkhj0M3ifMCsfc/e36/RrdfkQTluxze4onPBza42+64QB9h+QZbIAa9Zi3hIBeqcrz1XveLJjM6YTzFHME48DmGu4H7xeXV1fy9e/+enn3u79e6rKS5559Ttr1hvUE2mYjb3ruTfLc3WelqVqppOSai6QZkn94YMcOkibD2MvXfd3Xydd8zdfK9fWeYxt/bxo8B4k126ljqmmvgxCTwZqIwyNHjhw5vrnDWqzvgJUfifFbf+tv5X87/o7f8Tse+dsv/aW/VP7pP/2n8uEWn/3Zny2f8RmfIf/3//5f+Y2/8Td+Y19Ojhw5cuTI8c02Hhssp76jwQoiVe+lqjZ9RfhTqlRNBbQKimPBMIeMUPfqQ32D1Q7UjuAFAvmBXVV95gIawLZbAgRVpBVeonLPxcfuj5uqIc2LOSiMk4J5AcraAVw96ZYLAIpLtbJHVLKi4B6BdXeQ/X5Hf02AOfh9bjdnsgH4AtByhfCNyr2IGV2BqseGbYipZxOVddyoHfsp9f8ktlxFxXJLxTLUyvBULqlYxkOLf+lx3fezpN8yrldVy/GaEzms+QJHz9hYvPGkmR4R0kVdpis+I1BCIcHtrVtyfvu23KVq+bZs2lZaFG/E2CAswVcFMrROoX0KvKzhteqK5agujhYXKPQISwdNIFieIPhdQ32sz7XxcdM4srHkx4VtxkD/60G6DkUV1VsZiuUxVck/oiq8WQ26VCxHRTWf4b9Pivydtnecz8nOAi/CWbxK33jBxlNZ+Q0C4MUd8BTJdQf1qIJxHxdRKZ0UO/R/C7/apSPFYrhxLVE+6GMa9gFMitHag34GCQxdPoIuOyimLVmmRHOxgEXkv/w5Iv2oblV/6ajmDvd82nDeXwCgtl4tbpDq1UTJvGj8RPG7ECDHLKCNkCTZ5DA3keg+okA+GQQ3jMfT53G9vlFFnYLk1Ps7STKla+0i+XSDqjrZPbO8pOgp70rksIacKq4XYUVdF7+xdvcCibZ+6bhMDnPTLovEUiXOU/jtq9IYXuuaLIEVDdZQLdKH9wQUfMVapu9hYTnV783vHG10vbuWq+srvrfgPLqbxBNO8W7j5cVRoNeY7Hr4JhywKvrmHkxe0t4mR45vnMjz8MO3H/7En/gT8st/+S/n19M4Pz+X5557Tt7I0JoyTy+urq7kgx/8oPzAH/gD5W1ve5vcunXrqZ4vR44cOXLkyPEGgOWmbKSBwhBKwxVwnX5K5nZ/eKXCDoJb/wFkFQpDfeXfH13ISTXwQAVVVR6lqVeyblVhxSJ+N0HHBRpT0MCifVDtQnM6CmHdbndNi4nj1MvqOKoHc9vKZrORzWYrTdMSLhEQ8FArtTOgyXJi0kBF28QCSlCcDj1UqgeZRvjs2r0GkKyKuqKy4nbm2amwEp6o8EfWBy0Quo5bPD/4gfeTmr/rnR8tb/+ot8n27Ja07ZYWFOprjI/9KN40SbGChyYAEHx8oaCG56/6+2rjukovFlELCsCAAlVVGf+hWJwqEMsV2lNk3ZZy5/ZWbsOTdttSSd3aA5A51B2jsLKgAnR9dkvqdi1l02q7pqCKu/pRrQpqvoqF/6hqNoW5c7rHUSqfRrVey+aZZ+T5t3+UfKtP/f/kW3/qt5WP/9h3yLve/mZ5/vZaziuR8+Yot9qjPHNWylufXctbn9/KW54/kzc9dy7n21Ya2Hys8GFdxyQU3wD+l5cXcnHxQK6vr2V/OFA5j8/yKLB3OBzksN8zOTBAHQ7/ayhO3YIksVjx4n49QPK+k4uHV/LBF1+WD7z4krz3Ay/J+196WdXK8ErF6wzu4niE2gbX3OLhFP5oYsbtIjAUEiMNAirAKi2C58pmVSQDtFaaGKBaGEpzF5qrwlJVlgoxPU+hYx4et1rMEd1G/2juUrBdCAbq/eE2K+7PrH7Ges08vs3HheuA35AlhxyIBmscA88pnMbDC3niK0TUONdms5bt2VbW6zXnPwpjlqU/0AZQoJqdjt2D2h1YIopqZcxDt961bFeQOLvfsXp5c6XCMasGVdpUpY854I90yQ2AMCaSCFNN3azG3FYQkQ+0CPzokzlkvtAKc1W1a6VRrdCkqX/dgsKSYdPYyzR2Mk+9yNjTQ3wxd2mXjnvxfnWI64pirJlQc+Mx6DFg64PHjIft8AhtlRQ7XKiUl8kyqwRqUmL7yoetdbbe2ZOjYtme5+MuKrc1ERbgcNLXtJmwtlcIHdIY6SQ7fSMKWQ2MRSSZmIA033E+8P4RjmNJBp4T6w3eI1qZp0IePNjJvftXcn3dS99NHJ/P3L0jZ2drKQr4KR/lzp213L27lrbB3ImwHP7j282G1//yK6/IB158UQ6Hjn9rm0Zuo4jfukl8vq1eQFBrx2Skj+2bous6Ktk+6qM+Ss7OzuQ7f+fvTKsCj6//+q+XH/7Df7g888wz/PunfMqnyN/7e3+Pf7t//778xJ/4E+WFF17ge/AnfMInyJ/8k38yvBYKsx/7Y38s7RmeffZZ+RE/4kdQfX26Xfk3/+bfTGjwSZ/0SfJa8a53vYuqtR//4388rwfX/Qf+wB9YPOcLvuAL5Ft/62/Nv7/jHe+Qn/tzfy4BxallxD/8h/9QPvmTP5nA5Qf9oB8k73//+8NzsA7+kl/yS/g8wBjAmtOk8j/4B/9Avsf3+B7hOT/sh/0w+dqv/Vp5kvBr+Vt/62/Jt/yW31LatpV3v/vdr9kv/tqP/uiPps3Wj/yRP1JeeeWVJ1aMftEXfRHbCMdAXz18+DA859/9u38n3//7f395/vnn5c6dO/K9v/f3lv/4H//j4jiYW3/sj/0xnh/HwBjAvaSB8fKJn/iJHCOf9mmfthgDCFw3+hP3imOg7/7CX/gLT9SOfi1/6A/9IfnBP/gH81wf+7EfK3/1r/7VxXM+93M/l9eC8+Dvn/d5n7eAY94uf+bP/BmONdz3j/txP04uLy/Dc/DfDT/lp/wUjpu3vvWt8rt+1+965Frw+u/wHb4DQdhb3vIW+Qk/4ScQkL2e+FB9hED7YxzjPfBbfItvIX/wD/7B13XPOM7HfMzH8DivFVhfPv/zP58WBRizGIeYx49zTozbX//rf7385//8n0OCF7973Phv/+2/ca7dvn2b7fs9v+f3DPPutcYs+hSB8Yrz+s8frv2Qxpd+6ZfSeuk3/IbfIBcXF/Iv/+W/fCKVONri9/ye37P4HZ6P153OsU//9E/nuoQ+v8mCB4rw9L9n0deY++gv9Nu3//bfXv79v//3H/J+sN45SP4+3+f78Hj43U33gev+UH2J+Yv3Klwz5uzv/t2/+xH7mtdacx/nfQPP/07f6Tvx9Xjud//u353voTly5MiRI8c3K7BcATat8HgU/kaFsT4URJn/8OJ3EnyMYVnBImpeSC0tjCepQC31KtXQrfK6ZV+3LR9Z7G6AnQX+Q4zbvAGuUXQOwLeSpmnosXqq8gsw57iEdEeHBqmCdeHLrDDBVaZa7AygRyLAcbWzKS5xDDwAJq/wQeQ4yzN378qd23ekrlsDfWpxoJdiG+EN6BBg0j9ZVbJU9cHaIAAb2+YeKY0pkp0ruLGGYg8HYlQsox8q9Vj2AohoN1hu4KGFqmI3aBErAOeWnsfYcq/2CQ6z7bkGM9VrWRV5N3r1Rl3fiUI52iykxqewQKk2Wzm7fUfe/FFvlbd+1FvluefuyrN3zuUM91CKrPGoRDbNSm6d1XwAKJ9vG2nhuevK8Nn8r+GTPcASo5MOQHnoZWDfA/QqkAGAhk8yFMgs3gdwZ57K0Wd5+cB4h0VKd+jl6grg+lourq7l8mon/TASKitWTscjxo61dVAex/G/5F2JFUaAruY3bIpxAjXvC/vZlfz4Go/tf48qYp8Trs5XAGuJAR//wWM2uY+Q2LCimL4zwU1yeZz03hIFtLPlZNA5sEsbIvg1GygkvGaCSu8fY5fKT/dYJshWmA3Qd5OaPiqMo72E88SoWDYQuVCZ+24Ota6ICtf0Ee/dVaxLv/RkzUt+F0XIbqORVO9Mdi8EO4nF7o00d+Nj1cduBLdq8ROL9S28hJP1KMJce3BdxLHG+JWQN1Uf23p04ve+VH2fPJLXp1Y+Szm8rQ2L+04U2ak2Oaj3bVaFgrPxHhcO0ot7j3NNDxMTAcuH92GKp1O4rOMOFkmHPeyRkLxEYgvvhyVhAcYrEovI3azbSjbrWvDWRTsZ98RHIoXPWxFkAYxircJxAZ3pk19VJ77t7h0dd+ao4Y4mqW6Kn//zf778q3/1r+Qv/sW/KF/5lV8pP+bH/Bh+WP5f/+t/8e8/7+f9PK6XX/ZlXyb/5b/8F/ntv/238wM1AmDkq77qq+Tv//2/L1/91V9N4ACggwAwgcoMYODLv/zL5V/8i38RPoinSjxszf4f/+N/yD/+x/9Y/s7f+TvyOIFt3t/2235b+Yqv+Ar5Fb/iV8gv+kW/iK/3QNv9vt/3+wig/tSf+lPyJV/yJQTDaUBN/jt/5+8kAMS9AeYCKngAFgIkQPn3z//5P5d79+7JX//rf31xDPQL4DMACe4D5wWwelLFMa4F7QqghGt+05ve9Jr98m/+zb+Rn/Ezfgaf95/+038itPlNv+k3PdF5v+Zrvkb+8l/+y/K3//bfJiRHewLCpzDmp/7Un8r7/9f/+l8TGv+QH/JDFpAVAUgI0IbrxN8BcNBenlz4UT/qRzE5gev8mT/zZ7LP0kAyF6Dp7/7dvyv/9b/+V/msz/os+ck/+SfLv/23/1aeNDAmf/SP/tEEWbgOQGGMTQ+MR/Qrxu3v/b2/V/7oH/2jhExpAFICjGE84gFg99t+228Lf/9lv+yX8Xd/82/+TflH/+gfESadAneMfyRAcB04FmA6EilPGq/VR3/uz/05+bW/9tcS8OE+f8tv+S1sA4z7J7lnnOeLv/iL5a/9tb/Gfnqt+JW/8leyTXwN+PN//s/Lm9/85sc652d+5mfK53zO5zBJBSiHB373OPHe975Xvtf3+l6E2ZjX/+E//Af56T/9p9tujtceswDPCCTAcF7/+cO1H9L443/8jzNBg/cNfMXPTyMAdrHO4f0Abf84gbn49re/ne2NPsMagOv8UPHdvtt343sDAm2C/sLvXk9gncZ7EJJeeK/Ae9LpnH2tNfe13jcwBpEoRTIDr8exsJZ9OOwWypEjR44cOR4nqsdWLFMld5QBHsX8EK+KLfeEDbgAn/Pp46rWCYQUtn+X1ejtQz88Vtf0QwVcBjwFtFIVrYKosHk+QCOcX4sZ6dZ2bBOGDHoCnEMhtG6SqVRfT5y+rtV3WaaS8G6xXR4A1zxQ8WHfrQjwNxZHQmEks7aAEpAi5+DljL+rB7IDR9UDmxoQ5yr0gQCM50d4U6LiQ0XXdyKrW3J+fls2m3OCBwTUyq3B3aauCTE7eEFPuL9D4u+qQAWtBCgB72NcrxaHGwnF41Ut4ZeDb9+8j/6pAeHplWteqFAcQgpuKsvgIetKRvSDAzIC9agCjxRM+x2/Ctv3l3j6ht+9VqTbuAup2lZuv/Ac7HLlnR/3DllvShmLTupmFFmNclyNsmkreeHZVtq6lLMN7DDg/YwxgH53yK4qG/AVaiUn1X2ir5pyJXMFRWwlddksAZ+DKoOkTCBA4Ux7Ejw06TENk+x3B3nw8FIeXlzxcXG1kw6A2hBbmhBwO4TUsiItcqmnVLWrXvvSKkD/tnRODVp2E92qb3n6t8j0/Hzuv8pCmXy5jTt2syuTXV2M+bmSI9T2nBewZo5JggKer/SJtXlEuxmFZIt7DeB09UjSCmsLABqtWEynP9PHPRZcwzHqumEySXcAOFCONix6PvfV9eKTkwjm7RyBr/5Hv49ttvTCG9hZujNwrSUY91bA0137Q49x9J85T/TYBN3olwCi03EQbSp0nOk6KHO0OQDEXYQRcI4JvJyHjNYR/CsTNck0VE14cpClylnXHE0cRTSLp+luD8JpdLi3B57KtcNAPZ7jzRiA7w0WHieJsVAQ0/nxTcuBkfClyYXbYAQnZPOGnplg4LjAGts03Ing14VdIJo4iQDZvcj1nvF+4I4eBmgxz7Frhe8rnnuwZE9IkEWorAbMlRTVmue8vNpzxw2hvMxU3W82Dd+/zm61Zn8w8PpVoA6PcbxPrLmLBv7tOFd36OgTD+uiW7fOZRh0F4WrlTl2QtLX+oFtdzNUxodiABZ8hWIYgQ/JACf4PeAI/gZQBxUpAkq79PXf7tt9OyozEali7C/9pb/E9Raw1D9Y45hQcQHE/YAf8AP4Oyi78Bz00+MGVGAOJ6EABDAAoIFSEZGq0HBNAK4/+2f/7IWCEOv+H/7Df1g+7uM+LkAFKP1SBRzgGaAoAs+FUi0NtEsagNBQbwMcfatv9a0e+35wLbg2wPLH7ReAKUAPB+ZoBygU8ZzHDQDdP/2n/zQVeojf//t/v/zQH/pDCdWhsoVKMI0/8kf+CPsPUBWKUQ8AU4AsBK4NUB9QGNeHZAPa2FW9UKV7gsID50+h/i/4Bb+AbQ2QB/XfkwRgEOA1AmAXEAn35X3/a37Nr1mMDZwXIClNPGDcAgC6WhKQG4kDQEMkeADs/uyf/bPyfb/v9+XfAQ8BzdJIgRvmDNrkO37H78jXe2LmjeijX/frfh2/93EKpSvGH9S1AKyPe89I9uA8GL+vFYC0GH9f+IVfGM6BPoZ63+NDnRNqcrQBksK4hycJ7E6AEhnHcjCJse/xWmPW7w+/e5Jzf1PshzSgUIY6HzAT8ZN+0k+ikhv99CTj7XEC6vuf9tN+2hO9BmsZEjJQciMA/F8r8J6AJBsCO16edKyk4xVzFMkPn7NYR31tfdw197XeN9AHULFjnPnfoWx+taDApdNdSP76HDly5MiR4yMELCtIdUgBaAEl7QwgHNR7CmPAMwCAUlZFP0h+ztdP5VAqb+qCBdSqYk7AMuCEwx/9gE63UquItJpVYUkf3bLgB3uCZYPLU10GsExrChS3GktZwTbC9LrqQakFwwAZ+MEWyuQp+jLjMzgBMT6E4/dUyaJ4VsVrxfXj2oIq2w0mzCqjtEJVQd1JSDnSToCWGH1P0HB+dkc22zMFT0d4AVfStjM9NgEYjmMh8wi7joFgmaAgATBoF1pyzAaWqboF2DQ7BScyqU0CEwPRRxT3C7WyAnUt9OTKwyiPNQUqjsVEAfrBFOC+w3+hMrf+D6agj0LlR5XLrwWXHXR5rKRs13LrheelbCt558dhi1op+8M9KY5XclwNfMBq5e6dVuqyoHqZQndeF5piFdqUYBnKAiYh0K4j+26qCjnWpRybVrbt0fpdjRGCrYMpIDEWoGpmssH6hermYaIdxsOHVwqXTbVMm40Acu0uefxEbew+4faE+NU8ikMBuBQs2zZ/bzYbMGq5YB7j7Lq0PzwxYGA6UQKr3YQVxwy9kPh5h7HkrgW2BR8nsXuh6j2dc2aiTgsch8sJVFZFtFlgGAfFmIb4CNdxFMzFQubS20zva4UkAHcp6BzCB0xXpwfFdvB5RqLEYa4BclNWn6qCHwHLBuSCytihcqDMUY8bVckG98NRAbyxDPNCFEAG5XeivNXFNckE+HHUgmMhBLa1GusWk2Xk0TpOAlTmD6Fqna0RBiF9jQjrup5fC8ElEDcFyqlPs3lxM4mYTH26+XCdsCKoj8zttJ0iGFYrh1PR9HLcRbH70h9Z78MvTe8Du1gwJph4wNigQtaSIbwPjFWF4uHa4xSK927WHADLQRecWL0oXLYx4GPJkgda2FGhMcAy0rUocgll8tl2LWdnNRX+9RrPmc2TX21lcFR4MmNd62TF9xOsVSw+2h2o2r91fia7a7xfuMUO4H5MvgS4HCxJHl17AfhwfSmYQeDDrntx/sJf+Avl5/ycn0Nl5vf7ft+PMPXbfJtvw7/h9/gZyi+AYqi1XFEGpSbUd6eemAA0qV0EgPWTQGXEd/2u3/WRn9Mt3P/kn/wTFrH67//9v/PDOt4zcV6ozbANHYGv/uEfge3RblUAOAB1HLZCewCCAaCndhhQskGlCPXwyy+/HJTKgBNPApZx/96mj9svUEVCNXjaDk8ClmFf4KDMX497gEoQEOfFF18kDEMiAG2Da0Ib4v7SSK8diQJsdfe2xHWm7ejnSQPHBbgBSIYiFf/thHv1vvqGjo1U+YmEByAvxiAgL8YGrjcNQL903KZjA6/D9aX3BOh1auMCVSaUnZgHsIxJxwYsT96IPsI14nqgXP9ZP+tnhefgngBfn+Se3/nOdz42zESfon8c0t0Uj3PO1xPoSwDTV1O7Pu6YfdL4ptgPacA6BuuZJ6dgFYFj4Zy4rjcyPJH4pIphJHyg9MX7CBJA6fr7NON//+//zffQNEmFfknn7OOsua/1voF1AEk27NRBkhP3iZ0ceM5Ngfco7PbIkSNHjhw5PuLAclvX+gGV9YHUIgLuX765OFUsu7LQP2SrMlKBiz8PkAkfgPF6fHAGjA2+qonqL4rX3K7CGI7ZQbDY2koLqfkHZkA2WmVAlboqZJiiMg7QlFdrikSou4LPJgA37QJStqXK1fTeaH8crsv4KSGebrN3JasyGMDimtd16Dr+x6L65ppqugL0qsO2fPV9BfyA0rKVEcpmAIHZ4NHJLlp+kCXEVDgKMN008K4GODlKNYgUvcElL+5mqknfsg3WUVPBbQD5CJuHXoYBYB4+qxAvA14MhKZQ5FLk6UX9HAOF60ukkCkkfBWA8aShPRKBDUBi2VRy6865TP1tee65O3LY3ZXpeJBxPkgNH+817k+4vZy3GeqfGURMFJPuZ+yWDQEW6l/Dz6lfKZSupRdztPEMoKwqci3chwesL4YRSRQziQiDLb0/r8gXNu3r10RFe1r4TdkgYJ5vcZ9NARvVlrxeexWUtMSPyW7/1ODGlf2xQKDeG8eDJXrCNVjugKLIoDj1IpcOseJrOM5MxRotNFxNrW0KVfI8F1bYUOfd6ZVGhXPqF6HHpAUOLFysaJ9/Vc9oG6iA3ilYdpcF97YOmRcbMIDyhLWqMdeFwBTPPkZS9XJyvcH6YnGvcUxZg5+4E0VbB0JdnktVp0v7DMBDB9pYb/V+aG8QFmZN2IWcT+i7kM24we4iwl4HZsmIsyuM9+PKff0CeLoco/TcNsP9UGgyGduJ5DuZDNZHC0G+l2tNY+GMfMMqc6rm1ySDJhGXd6UJTQXqwUrI7TwsARULv8Yxk76/RVulkytxXn5yQjwPa8P1DrtSYIOhOw7g3Y+xP/TwN/dr1rtXq5t4rhEFQnutA6Dvab6rII61qNj3dnx0J4QH4AbmDSAYvqbhKjfAAHxQhk0B4DI+DEOZB1UpvGzhHwkPXShDAZpgnYGtwjg27A2wRfw0UnACEPlGBiwHoBgD9IbCFB/2sS0ecAVA0GHlKZhaFKZ9zIC9A+ANtrRD5YZ+AVB+0qJbUHCm26Ufp1/+XwSUlvA/huoR9wn7AQC10/u7qS2fxA4E1iY4B5ID7o0N1fkbXUQOak5syQfMwZh25eupR/I39H5gkYLj44Hxj/EOsImf38h7ct9wjL9TeO/j5nHv+UnmIcbrh4rHPefridc69+OO2TcyvrH6IQ2o6GGjg/8m8sCYxS6KxwXLeE85XQNvKs53eo2P8zokWaB0xvsIrJOg8Mb9nybH3sjrfJJ43DX3td43oG5GMhYJPkB9JDnw3vhdvst3eeSc2BED4O6BJCg8vHPkyJEjR44Pe7B8vtkobOlWUsIreJykkIGYVqvTa3k4Ii2og7EtXhTmIRTeiUxWKAmF0zbw5j1q4T3AN/+gHrYay4kvKNTG9iEYMBofpMdikPHYU9WrdhBQ9/YyF1D9blgcDP/NiKJWlMxhqzvtHki3pW4bbs3HsaIaWG0ztIBerBelthezzA6U7NrYkPwPNoDhGhuV+T0AGZSS6/aM1//g4UPZ77vwewBnbGdu6rWURc2iX3WzphITr9m0exkOvRznUo5zobWyoMg0DhXEeLRawB+Psl6jiF4l2+0kw7GQw1DIrmOlKRkJp83UgsxKARkL8TUoegYYAZAHBddOlXFrKLVnBXMGBPtuL2UFD88mFFYL2/VNXejqWCMfSVulbfe4SuVHwxk2i9E1lTTbVt78thfk9lkph6tXZF1N0g3Xsu8veE9VhYyIbl1PIZqKqtEumqSISQKlWfjnhSphkaDK4VJ9vQGQWYxypN3JqlSQrApCtSWBFQaKau33B/b9bt/JAeMdSmkAQNgglNWiANyjnruPeh+7UjmGPZ+gW+8jwECzYaB2kb9TGwr8XB4VQAV1sgFC92fmfDPf8AnFKKmqV2/XoACmIlmVkcq83IBFFcbTZOcwH3JarhS4f+9Da3NTfKPtDgco+gHbJhkxrw2apvdLO4ygyo5F3/B72sm0sDtQOwx4gdco3gnbDK4FljDCPdS2C8N2XQRLDRsLwR8ZmHE8cveADg/AZcA/fJjQOWY3rWsNo1xC5YXimCsj/ZjVcgFf3S8iKUAXitkZeNYLs7/HuaeYMC1Qh/NYIs19hx0iPjJ+4nFSqBzG1omxt+tc7WISmxR/rtqPsM0s6eXF9VzdrEJvU1L7G4X3pXtdc0zZdghXTBvEdW7OWZp8eIuAe7nuODTG2osdGlj3y0qLWvrrQqKCavy4Rujfdf3AfKfNjbd/SNhpgUV2kY1LFLD1hKQX9gswPijG9f52OxSK7eWqqeX6asexWrctDwg7HSQXYSej1kKwbNJ70f4+Stf1stvt9f3FwDLmrBbtw/3Q1IbteVxhV4/a0NDm5obRABsLJJChuIIK8NUCH3ZhJYEHPgwDoAAsIwDNAHPwwDGw3Rlg+VM/9VP5wRpbmd8IpWIa8E49/dm3HAMM4D0cwMbXUChhnyQAfaAygxIZfq4IjAkcG/eFALyCWhFt4W0HgP1GxOP0C+4X1/eh2uW1ArDzfe97X9j6jdejzVzJB4sRWEjAo9b9kqHMfpLAdZ4W8zu9TpwHhR2xfR+B/vuf//N/PpGyNz02CuulP6M9EbAKAWz81b/6V4e/P2lhLagVAZfQ9lCxIqBIxvXCWxUBpTzGBzyIHRS9VqGy19NH8DTG76HIBLS8Kd6Iez4N2BgA8MIexG1HnvScUOmzUPATBtTxsDXAf4PdpFp+nDGL1z3pub8p9kOqtsX4gkobiTQP+JyjQB3Go1tQfKjAWp4WogPo/D//5/881utgN4GEikPnm/yhoQbG4xf/4l9M6xxA2NcDlnG+D3zgA0HA82rnS61o0Ofwd/Y5i10pmLO+vj/ue+HjBI6FB94rkdSABcdNYBlJDzxy5MiRI0eOjziwXAL6GDhSEOVOBywlFMGi+Wu6J2j6WNg5+gdu+jADJrkizFWkziw84+uETb8Q7EEVCrgHxbJ9yHd44Go0hYQKhItS1bn4IO1F8lS5BtUaAI6p4+x+AgMy2whnGouwbfDqAe2Azrx3gxhQoQJUEdhyi59ZUIxF8WDJofBpZVuwwaihWAa81YJjlazmSUFNAosUGPk1Om9Sj+i6wsOKmRHirYx1OURx1Z/2Hvo3shgF9PgPqRX+A7sotYghoP2kBboIpd33dYFfTrbsh99GFJV0YxKPAZdvekpi88At7k0t7bpV5Uo5yiRNsEYBLNIkQSzqldpLOGAKhzYGrh6ziSbSC3WZcpFjK3gcWxIEQBWJDxSVpFp5kIFeqQ6Xotoz2FMYGkzPH1TT4TrjNS+u89WC8zGB+cm3rjn143vyJzglLMZE/NlHUCiGRt9e0+KaAjkeOyqAQ1eZylLBsILD9I608KV6Nrv3uZ0xqMxPXSHiQ+0CMLc1GaIgza02XBUdhtRC+pqoiZOZEpIkppUPr1+Ia+Pf+f+2Lvi8ThtfN2/4Qpmc158bFhp/GIwNT/eb1wSBzznOZLdcsESEvkbHo95Dqjw+ga+e1VvMUwOZiz0kvlvA71G/6i3HHQq+vixzcMs1QAs/hncHmz/pGD8Z3OEWTgf9oiOXZ1gUwjSPfbNSwhrsSQ23deFoTrvn5Ih69liUMBTc9PEZzEySuWw7G6J1jO+a8fax9xhz2HAWL4PazyiLV5spd5lK114cv+962e+Q9INSX+8LljAuptQxANX+cvdDsMw5CXzIBwgBiAOIxYfhl156icAIAAceolCOQpmM5wKg/bN/9s8CxIUNBFTJKMKFLcModOZ/w3GhRAUwhAclPGgBUlCUCp6ip560TxKAR5//+Z9P6w2owf7KX/krVMIhPv7jP55rMTxQoSjGc+GJ+aSBgoAAg4BogDJf8AVfIA8ePAh/f+aZZ7hFGh6ugNAAT6dF6V5vPE6/QBUHr2lAfLQxPImfxAYDgUKSSAjgGIBIOCa2bruXKe4dW9ex9R1/R9LgtRSjp4FkBO4BrwWEBJyHf3EaOA/8YQHf0K5oa1gavB6wjLGA64XfL9TC8Hr2ImY4D/oJSkn4HWPMnBZkfK2AehEKUNwP+h+JE8DCNBEMeAVwijGI+0dBQvg9v554rT6CAha/QzIEntaYh4CMmKtQQ74R93zTNX3u534u5zHuE+MQ4xOKWbTN45wTdiOAlgCCWAtgJ/E4kA2etmhXFGUEuMN9A/LC5gCQ93HGLM6NuYTrxjkx5j4c+8ED4xv375A0DZwLf8da/FoBf2rMTayb8KDG+n6q3r0poNLGTpBf9at+FdsASZd0ju/3e/bDZ3zGZ9B7+j3veQ8h76lH/eMGYDnGG94DcEyse1BBv1oCE2MLfYdrAHjHnIVi2v+74HHX3NcKjGe8H3z6p386Ew1IPMIuKU105ciRI0eOHB/OcZNs7cbAfz/AA5IWE2ZlAfsE+NZCfYzv/eOpWy4A8DjYRPk3PLj9HR+Ox6NMA+whjnK1n2XfwatS99MDclaFFvmiqsw8isO2ZBaxGKTb93LoUPgI/pPYyoaCefBWVn9VVy3iA/VwOFAx1zSlrNeVnJ2vZXu2lqqBb7Juk/ct8+rBqdvoqQgL26X1g32ARglqaapK1igYVtZ63QRaqjzs+l6ud3v54Esvy/ve9yJh+N07t+XW+bms2400dSurFXxnK2nXZ7Ld3pYNHme3Zb29Le3mXNabc2m3t6TdnkmzXrNoHdTYfgFoorI4Sl3CDuMo5+tS7mwqft02lbS4J8AUXpf7zMLWQ/2uGyiorZ0B8wjBu076XgtIdGznaxn6vUzTQeaplxUKLcIegSYgelwjhTZqvJDfCSR7RBH55EEIbwNYkwpHqlPb9Vpu370jz77wvNy5+6ycnWlxxAqq8ErV3MG6IBSfdAWg9jl+M0HBjGKLsC5xn1JLUCChwSKJ/ah+j4eDjMOgXquzW7SMVA9eXV5R/XDvwQO5ur6Wfpro4RxKBgai46AqqmZP1dw3bVdnYmVStX9o85DMwLVOydd4ztRGgtYRBLGmPLU6a5iDrlKnMryq6Fu+LESnvsjoe8xb95p2hTXnjKn94QOL66Bavy6lwbiEotjhnvuw01d5VqWy+SvrbgBLxLCAp642WFNwnRy/SBgVRyqiocBfU7Fc2xyvTBluti6jFkPTgm7eLlCP16GY3xGw0dXK9uDYcEBo3abGwdEuRZMOqngqCsxrGrJbIsa818OcUHUvirlB3YzzMVkBskjlORSnGIPYuQCltGUbOOH1OhVWq61OgV0ERR0fuJ/QX4kc+8ZwOyIPvWct0xnHIq0VdPXQAn2YM/hq7RbHLK5/4s4P9ed2oGpQtfDlwtoFWzKozJ40eTXDNsioaLr7IWQUrPnM1iKOOV8ZFOgiwaPJMD0vknrrtpXt2VbOzs+4y0MV+JqwZPvf0FRp2gEPjGe878xUBasCmN7/oUikWHE/WAuhDkCv4447Y2DXNJs9jyZBsNsFBf1WxVqOxYZJsUMvfG/sh5WMI6xhMOdR9BLXiOTqUcYJyn74h96Td7/7vXJ5ccVr2Kxbee65u3L3LvxgkeTqQ3uqrz7ms3mRtzd7kkI1hg++n/M5n0MwA1ibKruw9sDeAsAYwAQfwL0QGqAS4A4+eANq4B4BThAADV/2ZV/G46CgFV4P6ITE6zdUwYxrBbTBh38U5gOIxPZyBDxG8TOKw8GWAnAR9h2v5xwo2gYgAdUZ4ESqrkP74l4BSnEeqPBugjeAIPDdfNJ4rX6BAg5qaWz5xz3DpiQtDvY4AQiPvoG6Ex7Z6Me0wCGAFMAYVNpoC0AjL6b1uIHr/eIv/mL5G3/jb/A6Afm9EJYHrhvnQB+ivQDrcL9pAFSdJlxvCgA+9AvuBUXQ4D3rgBrAB/0EOAkPWoDsz/u8z5MnDfQzVI0AcPBRBcRGgiVVVOJ6AblxbiQoACRPA4ATFgHfkD4CrEfxS4wX2IhANY1zA+C9kfd8GjgGxibgI+b2Z37mZwav2cc5J6Ai1pNP+7RPY3uhnxCYKxgDrxaA+V/yJV9C6wLcK9od88DVy48zZgEOkZCCmtzV7K8V35j98KHGPv77FIUkXw3S4veYB49jFYG1HNcNKyGAVMzBx/FBBqzFNcASCfeOvkzHNd4XoODHeob3DwB5JCtfr78wxhvaHoUcsaYgeZQW/7wp8J6AdRz3hjmLpAKOg4TB4665rxV4z4M6HG2O+/ysz/osvnd+9md/9uu6zxw5cuTIkeObWqyOj2kc+P0/9f/jB9krVJ4fBjn0vex7fFDFh1wFQPi9FoybqJKs7QMrC4kBPuC5Q8/id297diN3z2oZ4R95nKWfjnJ90G3vhwNgGc6qBe0Ao/AhnlhiJdLUhbzj2dty1tZyq93KumrkOA2EnXdut/Kx77wr7bqR9fldwtf3v+9FeXD/oay3lWxvAzat5dad51QJXOn10bZgtO3BBrCpFEMxNoI7fFBupapR8KklEMZrrq72/LCOz+XgPLe2W9m2rakx1dO1KBr+B95X//evkvsP7stb3vyCPPvsXXnb294hn/BJnyJ1s5HN2XOEnrDBALh5+aWX5P79e/LwwSvy4vvfI8N4kMP+gdoRTGj3UeZ+kAmemoCVfa9FnrqeytiX7h/kej/Ki/c6+eCDXg7DJA8Pvfp1ejE0wKrjLG9/5kzeemcrH/OuN8t3+y6fLOfnZ/LC82+VpllLszmXssaWrFrappbN9lzuPvMCQe3m1nNUVVvGQYtsuRepDzDWJVOKpKglQh8NfA/7EPvdDf+BHLWujnf5RFVTHgHcLmXuD7J/8d0yXF/Ke7/+6+T+yy/J5dUDeXhxT6YJftHXLGg4DFoAEePlOI1m92Bqc3J1tWIgrLaiZFB+t3UpbbOWu3ee4zg4v/MmWq2c3bkj6/Nzaddb2Zzdoifw9dWBKuUHD66kO/Ty8isP5MHDC3nlwZW85wP35JUH1/KV/+O9crE7yCtXO9n3gJyAaABOUcWs6kIosQFeXYEfARoBbIDK2m5UlUPdTuW6vt4/dHgBQCSBGiRSANm4HX4lLdXz+Bu+FtK0FR+AgK5uxGt0PKsy83rfy9Wu03mNNpuPsu8Gwq4B8AvOM3bdCns1UXTr1hmTN4Rks7B42cXlNSEb5hHv2SD39WGUrp9lgMXIpBY1FXYe4D4AU1crOWtK9s9Zu5Jbm0LOt2t5+5ue5zj+pE/6BLn7zB15x8d/nLzw1jdzTlLNifuARzhA4Lqx3QB6XzMAYHegPUO1BuiLBfaG/YWM3bWCdsBwtGOz1tcObrNiJh+wxSmhmEdHwYolKtATU+Hof1xpUgPWNmyYaScy7gmdjzgWIHVrhT4JnHWMIqGhvgjq1U6gnFhQHKdeZIYfkMvRbYw53Q2ENrwrLD+oOpS16w7AnVNSC5wyqUc47MfwoqWJxhdtYMA9rA+JMjnMc/7KX5v4TjucdzsbA9G0NmH7usw4Qn7dUcDCALwePBeqXlhOvPvr3i8ffPGevPTyyywKxl0ZvfZTUTd8PtZ7vE/owUSGsZfr60srlgeoPEtTb6Wu1lJjHHIA4zI0CTl0WhR2j8Rc33EMdwN+N0s/YD1G/+maiQQL5kddN7RIwi+RhMK9lCtN3CIxSssicnzYX3TygQ++QiuG820rTV3K2972FnnzW16QBxeX8t73f4BWU5eXB763VuXG3md0F0ZdF7Ld1pxPf/jP/W35cA/AOKio8fhwCGx/B0B5PXD5aQbAD2Dvh9pC/k0poDD80i/9Um73f7XAWgMV6CmU/qYYKCgHSAql5YcCqd/cAmATsPm1gPs3p3icsZ/jyQK2HSjGiCTDG13c8PUGlPBQvEOo8kbbR+XIkSNHjhxvxPvPY1thbLdr9Ym0wljwSsUD6j1ANFUYAmQ4vIB6FirCWJwsqCVlJX0/yXXBnb7Sk6McWfWeqkqAErKDQAcWHpn4qh62k4zVIBMZxMQP21VdyvbsXFp4Jzeq1oM6DdvdmvWKhdzwYRqciAWSAEVN/TgTPMDuYQhAj5dsUJwgSisiEULgflHMjkXz1lB+QRUMMK1qOxZpA4yYCiqsARJwnPVmTdh1zu19W/q/qoJYt+wD2jSwczg7k27YS3u2kVUv0k2lrOC1CYQzF6Z9BMDT6wAkrkrAtyMhG1DEfltKN1RSHI6yG+CzrIkAvtI8D6C23q7XVI8SvgCgAZLJJGWrqmYocEe8vq9k6K7ZTlW9kbkaFcIAvqhs+pFd6equ4Vu+vXycx5Mol6MliVstEGpRBe+e3pVs1hsZzm4FWxIo9Q4dfPMG2R/ghwlbipHQRgvpqXc3/FbZv4FNQX080iMYphqAkoBU04zt5Wq/gjGAgllF0UtZd4RIgDx44PfYnr7b7+Ty+lqudteyO+xlfzjInirw/kRtbC1iPzpEdhUzz0lYC3Wq2TKcqL/1NXofHEumFmbrmWWFQmdVVjZU+a6khUKf6lRXb+sjFGNz62D3kaXPeVRXq2+rtmO0EMB8dkX5SqpJDwf1Ju1pbLs/jqlK+sRCx60nzIZG3YNVZQ+VPHcJ1IVU8KHFnK7057YtqVQG4Fu3Def9ZnOmSaGqCTY7ylPdjkB3JVCx7MXxeFOQ05rvsYFlKJ+rGbDY2bAmStymw5ysY+BkZpPi9inKlu18HG9WUTKxO6EIGnMdyQY2nwJd9jo7w6olhg0AdpwEFBMyU37OPSbmb2zjBX7rblPhqt9gSeEVCNMEkHV6agvsnuWxlKbazPA+j3LE4s7LUvsF91aOi4LZyvCLDYaw/VTNimO5yeQlwQc5FizUe9R20VuMqvJ0GzqaSX2FMW6QLGvpGc6dCngPStetFK57s7BQKtYETQSpAlgf+JYJUYfqdh3aHmojxMKe/rC1x73yae+E9ppmKbAucP5osddZKt0bMLm99szEIgr+8b1gLuTyYsdk6DPP3GX73Do/k4/9mHfJ/tDJe9/3QfqWdz3mHxKo+r7M3SuvvfDmeAoBawD8h1reCv0NDwDYL/zCL5SPlIClDKwHMlSOgQ80X/u1XxtsbXJ8ZI79b4z4iq/4CqqJYRmCcQZ7JgRshHLkyJEjR44cjxePDZYbeDUej9KMlBQCnRKhsOgdQLBBLWVIgAOARQqt8HyFW6qEA8+BeopKRIJlhQGEUAaU9emufk3dKs3JlFuW8XcAPvNkJv+B2rLhg0AYQKsC9MWWexTNw8M+UBtk0D3ZChLwYV9MVKceWwoycBnunUxYh3MTMqjfsao9Ae3wSgVC3CoNBTa2MkNdbMpnXA/gMvzbWAgNymnfbG6Ah1uUWxQea6SsaymPgxa4IkwyT95gx+qqPCg51ccAYlMAi3VTyLoppWOBObxUt6Snfse0NTG1nMMtgA8WZHMoxu35ACKwEUBhtYLqcwdQCgOrR5THS2iRWmL4Xx/HFsMxshe+8mPZ2CCMtLG1gpUJFH8tQSzafBihqu1lHEsZhj1hz8j8gHpJAw7T6sRAHxImBHmEXgaF6GOqthLc3m4+ygp71Y5CH8fF7wiYh4FqxY4gSB9Qu2vByHj/WrzOEypx432Akon38RHkMXjxJkmYMG5TILyk/K5iRr8DUOMrLWfU/Fd1oq/mOe2acbP9jZYd3v4KP5URR1sOJH+wfb9A0UKOJRvDtA7GuFVoHBMQUdXr51DBrRVuo+uK2nfQToC+ypjrVpitMksbFu9DgU4UHMT4jFYMYfy4hYLZORSlAd3Erzh4GkM5inmSWo+4+vi0YFziCBMNdNUERcFsYrVxYnsS2tlBsBW9i0mV5BI4/vVa2fJHg4WJf7gXiIuuNNHTOBwxmML7/6U3EP+WejX7/PWUj+Jqu1ZQVvMRxhqaTGWj56ZRNuAe13kHwctrXNy7Lszxmh3kpm1jJuWq8PZzuZ86kngYI2p/E+aMM/XUVztMUT0BVcb2iB7O6uOsdSET9bfvsgjrRWKXw0dsFn2tWhEhUYvXTQ7czXoEvydIPsJSR9cQJl+OhfQdbDZ6ri04LaD5+Z1WdruD3H9wqQnDsZdRVzO7zrTPv2nHl3/5l3OL9KsFtr9/OAW8p7/yK7/yG+3cr1Yc7Iu+6Ivkwy2w3f0jKWA18Di+rd8YAS/gD+Vv/VVf9VWPbQ/wJIEkDPx3v6nNlVcryPf/Kj7Sxn4aWO+x7t8U8GzG440K2NHA9xifFWGhgvM+//zzb9jxc+TIkSNHjo/0eGywfOf2LQIi+EB2wyhNfeD2836YoBOl/QLsLFBk7oht2qtZ/YpLBdL6OTkWURrnFZXK+AqVmH7QVvWmFzkCMHbFIoIF+cB+3POYhf0UY+lOcigXK7OsqKWf8UFcoQbEe1CkVjVgVK1QCnYVVCuWdAOdzQsVAJWqTvP0LALQUXhB3078IyhWoHs8VmQ/XTdLTy9m9fXE+bteATwA1xbbmdvGvJsBlhwaKsSEOhjgC2BPPaYV+E4AXv4gU1HFsfKwklu2FcSP3CqO+8RVb9aFnA+lTMdSLnalVKMC0olsUOEOFKDwb8U1I9xHtjD/XijciqN6rY7dKLura6mqUVarNdXWLeB9qp1Fm0KleAKDYpyC0/R3rx5LpOzACX0Em4BZyvWWsGR7944UdSn1Dl7UjXSHnczHToa+kK6rqBx0xR7H1By36isAMyUpb0UV2IT4GAsAURUKL8ZHaUr2UKzvoEpCbLcfRqikO3psX17t5P7FTi6uD9IBOpOqmtWEbf3HZalA9dXbwotB8vqs8J9CR4VFmBMRJqbuIpaJMMsSTYhgjGmRL8wF2MGoF6+O89RKIypYFZyxABpV3BgngGAYNyuZOG+tTR2tIWkE6wvMfSRvOG9tPhGEelE7LTTJYpOugF14/7o6Vv8xpWHry0TfdlWONlCjtkgwqHoZOxmQMOCRg8cv5o8VvjOVt0ls6bUcxhk7Xy0uqOinetnApSmdeUX4ls/VBNICLiNhpIPI2tfvybyJk5EdigXyuGYdE+yJte1Cfxh4pBKaNhO4iPnRpI3fm8NWs6lwFBygflJ4UF+SQFKeW0Em78Wv2QtPehE7r0Dn0vbERiNCeL9fXLPbh6QgWS0wQp+crhU8JuxgEh93g8Y8gqvfvcCo3Qvum7tt5oIe33jA7xtr7IzdHlaHNvqdYyTiPAqCaUXE9yoFwKEYpyV+sE/E/aiTywoe3WwC940OSSPf6KPXi6Reaf7SI+2K8L6Adyju1ZHjUesGEC7jPcjeCzCt5nElu0MvF5eXcvvObXnh1rOyaWvZX9/lGvS+6b7M3N1iiTNWA0xB+DfdQMGt17Jm+Lqv+7r/Z9fz4RzwO301X9U3v/nN9Iz+SLMbeEzXtxyvESg89qHmIf7+zWmu5Hh6AV9qFPd7Ne/mNyrgpQ0//Bw5cuTIkSPH/wOwfNfAclHUcugG+qW2/YogVaZK+nGkj+8K24RtuzahFaAt7RcMiKCyPYrd02kB/slQMtoWYwPK+iEerzHA5NXrqbLUAoKqODb1sG2HBliGAq3G9ua6lv1+pKcltz+D31DFCBhYSWVesSi25FvhJ4O8UCxyF7wrlnEd5EAKDLiF37Y2E+S6qhW2F/DPxCd83GUBSH2krQTuB0WScF34WpQKonwbNe7dbQ9AGtSXVqEy7B38cVxNugve2hjfA5TA91mvSWEg7D4AazbrUm6x4FMp27qUbqVgDzCf7Q5IxGsHwEDPGRcCVF4pVJ5LVeyy3+DtPF1JVfdS0ld0UNWfFV7j/8HDFtYYp9usEzgY9+A/GdFINZRGBw0si1RrWHMUsj1OUq9baTZrqdtG9rta+v6KQL4s4Cc6MJngsI9iR3rRKvCh5QMVjQaWTZ3PQmYssNXQWzqMJVimOFgeJtpfsLDf0POx7w6EOhcAyw+v5XLfS4/nGhz2sexNFDiuqeoXTWQczfvZn6Nwiz4ZCSB1sGywLlG4YrACmFHlywKcWkBvsMQKnwY1Pn2Eo6JWwbvpSgmldSz19Fj2XQeeDAoSZ4PhK84xWN4Q+rnfb1DBYiaYLY1hNIXKU+qqHSErwaPZ19BaA1YtumuC/u4o5mhwGTsAFCzrBelt2DUYhAfo5llZhM7V0Q4BXU2NtUILAXl7HwGiyQ4dqrpFQ9IGpniGlwEYarB+sK6JfexeE5rgUJNybQZeq8paDSC7kNZ+bx7UzJSEPne47OrlOLbYn3Z7AQgHsJzwXy+6x2sedMcExkUKe72goduuOGT2CYW2sd0hcXwiCaCWDw6WLVUQ2syMUQLwD23E9U9tgcKkYNP5ypDYxNj7A99DTNHvBSRbeIl7oVb0Tdip4LetY4C7H/gYFSgnUDnC5Zgg8m7XwpWcLUE5zau0hCnfv0LXe/+qpQaLzhIC4y7xPoNr050jOoZN2cxdKuppPowr2e97uby8krPzrdy5dc6dM3O/k+vrVi4eXsp+h7oB+t6kViofBlRZhLY2KJKV443xds6R4/UE/nvnm9M8zHPlGy/gc5wjR44cOXLk+AgDywBQ+NzrYFfVwQUhK1TC5ImAs0coZfEBWlW+uu1doTG2EPO3/GCu9anih3P1ndQP8eZT6tudLdyOglv3sc3dtjGrP7E/ST/o65ZhBbrcNm+e0ACKLNwFfgR9mUHl4BVbzFQA01Ii2UpeABJzG3/cBk5/TChFsX2ZJQsVIsDDmFYC9J6FRQZUjqaCg7oX1haEr7h2vX71WI1WBQo/VLEMu4xpRMExe27Y5m1NZO3sUF59XHENQnsO2GFs2lJubSppBpXEosjciG9nFB6EF7ZC+6gJ1C3tbidBXEL4j7ZSSALvYih+p3GQqbL7oRA4tdqI7qvB+3WhQEwVjDdHvKYIjeLfTNlZzLJCoS0A5s2aDIoeyGMv0zxIy+rOs6zXG4IYFEMcp16TIKHwV8ICCcP0DAr5YRuh96gF6VxVGq1eYEcyjOq5DMUyH0i4dFAtAzDD6xm2GKMWvLQTOuxcwj/1QVblqI5hV7X7OPct+0Eda+3hoJp+x96AHIs6HkP/ulQydcpwyG3JHp7FVL20CLFBx3OjDfyszvfC4RyaJuPAEktpD8bXRQ91BZ6qwlyIXQ2mefsTqAG6qTkzx6QqsAELGyZwsEuB8946NnriWlsEVa6BeM0ynMA2A37W7iEIoM1qgTTcAOkCKscWiaPXx71BxqSA3QJEM7tVJ0phU7uq/bCNfTOIwRgJlDiB4Q7ACYWj77BbcWgberv7uZNbNR/2R65v8d3pzgM/dpwjcdWMa1cKouMTmDUJFiQBgIdDB9odVMCP+Dj4vTi4D9koS4hg/NquF53Hy3uP1tJxroWvCav2wevWF8FSxUC9cmJ9L9KHzrmYm1nMFtulM3O6HWl3YXZPSLOYqn9ezXzo6c3L2trR/8HDHcVCUbyyOxxU2Yx1EEUc8e5cqqJcAXlMaubIkSNHjhw5cuTIkSNHjhxP0WNZ2UpdrQR2vloETO0RprGWsh+k2tcEZti+TWAHBVWvdhD0k1VUqd6RhDCmjKVidrLtZmYNQAZF34ugPcNHdoBgqKBZmAv+qfDUJSgGOFHPWoBQHIDKXKouV1I1gNEA0VALw09WC+XVAIVQ1zpvqtLt9qqUVK9ZhYcByFEdrEX6cB+wNQB4ZXG8WaQC4LRzaBHAlYytKj436zUVxvACpr+yKSQVaiikAFQGHINP5qbdsPGrVS0zbEZMAarQT6Hc5N6dUGjPEyETNtJvm5VUq4ob6o/TkdYl2xpemzMLKALm396gPVHIzRAOgQPsMgr1iR7VImEw2Ikihei3dr+jj2rftgS5DewnyjZcE1syQP8ULhfJ43Sb+5OGKpZBxYv1GUaqlE0pxxHF9ADjj/QRHrrb0q5RwK+T9tDIdARUHmXojnKEKtDsMBg2NulUQQFrRTDdtGuFy1T3mtduUci8WlEZ30Ol3A1yvd/x++v9XnZdF5TKDy528uBqL4d+4vM5V5y1mme3+0T7uFFwrHPGkwcOkdhP5kcbQa3OS0Izh1xMkhzpQ+wt7j7Rp2A5FjqbZYay2LzHobCHQht/1zl+lGIY7Tyxf8MWe7+e5PDukKBaaU8w+FhBgiaCORRW1O3+QZ5sUBmKKVjYWCE7OzoTTrSEUaXydruV7XYjTVtL3SAhg/bVgmtYJ9RWR73Ng5KWmRX4JztMtMJrNt/ZBw4X0cZezNOtLAzSUrHsN0y19VJnLymUBNjk+dUvPn0e7DgwFjCW555pK5EJxUJ9vdD78kJ6bouxzJDotavHr5tae4LAfPEj8Q2A1EaAJttihcB4/TZHoh9yCnh1/cW6prw+URV7cUMeLh43wFauqcxkLmsHLnY4KMjXNVpBO3aPeDosnQ+B8ScJBC/CqslGe68xr/MVJkmqIE8U2At0zueoLRHGjfcj+4vjRN+L8L7HtWHUJBP7IMBrT9/5MXEjujNkKqbg068w2qxfWGDTLDWwS2JOPZ4VFD94cCHd4Zq7Xd71zreyzfe7S+m6To4ySFFpVpfnOWLtgjVQBss5cuTIkSNHjhw5cuTIkeMpgmV+cA170KN3aNyuH9WL4TXuH5mApng8BTY8qimYA7NJGKRbUSh8cCVvtIngluLgcanwDQXTiByCQs+BnbOFZKu+30f6CAXI4oX4fShYsiJcDh14zAgXKDSkbYeqRmG3ocXgFCwHb14AbcCXojQA6+fzxtBzQ4UJkKnQLRYUw89ojyO8YlPPhCDENK/mQqSpVrJtYU0CL9qK0L0r1OMatiYN2zRRs3q/mGJZYYleH7g9rkf71aGGwqtw3Um7RrqYwqGTVg++sa8+BlOlcvxlooKmmts8a5kwKAmRAINhQQJlMmDtWMK6wwp20T81SAZtANqhA6NToKPgKCoUl2rDWMAP428ctUgfCvZBqbw7dHLox5DsiKXKUhAWVctxbkXoroUvXb1p9CtW1Eva6eZmjIrgCMd8bqYtq4plnCeOcZu0UXHpFjUBRqYqzNAqWl8t+U04VADK9lMyN+Nf4jjhfCy80KTuOsD36oWuuyi8CKEX7kO/KzxMVKl+D4kiWf2dT5IcwYbBrzC1lUgbOALy9PXpK5fDVscp8GEs3LdsF21eBd5sDVs/ggIcSSzzYg/2FcGH2mH4iSo2vfbFmVxtnhZNjIUXgzI+WZDDzoHwK/WKj1LnqOBdDL6Ttli06cmmhZuXAb+WeJ1xUU/7w+aWrfmuug5A2FZ3T0ZA0Q6fdCQ44w6ApEcXQzrxxPe54wU0vUifv79xzfSCfQ6A/XnLvnCRfPg7319OeP6ijZI1wd4LWMgSu2NgTdNP0nWD7HZ7zXsQpJfSrlvZTrqbA8lF3QBzOjZy5MiRI0eOHDly5MiRI0eONxgsd8PBQPBAD09+8FUNn32+t+25zj68JlC0v9S/GwQgSKYvqxbe4sM9jLk192QH8xHF+UrZtK2sm1q2m61smlr6/V7Goec1QaXb951cXFwqSCxbswTwXe6qEAQpmFGuz4H1CVimlhZAlr6rRxls2zD8NbXwmCq+AFMJjwG7mkaadSNVMcs8HKUpVrKholuVy7j/jkWWVrLdnsmt28/KenNbinKt9hxQvyZKUmxbHgdohGdp21amaaAaDp6rRdlQjVzIoP6uKCCI2n2AmvBKdjU4rT2OAsF0XRSyrRsq3A63KyqVATqhpkNxJ7Tltq2skB/UylDKiXR9R9Ud22qGdzOuswlQEcpRPn8aZIa1xNTLalVLCRUxgZN3oo+kRSt/A5TKHqk2NlVCqxId/SLzWjbbLe1c1tdrgjkUeGygXJZOjqP6K7vHMhWNVoitwEPcQsUAP61T3HcVSnX1re6HTg5dL1e7KzkcBnnl/oVc7Tp534v35d3ve0WuDpNcHQYtsmUD+2acE7fYx+31DqMSla8X30rmSXButUnoCYYkh7IAvSi6B1V3cCQwGBfEpInXOMEbjkkLl6g3ZgIDPN/PFZISc7R5MQUsbV0JRk2JbSANw19VpHrvI+AX/wgPZU3S4DkotnbrbKPzDpYsR5EWaupCZLtp5db5Rs7PN3J2tpHt2UbW65aqf1hi8GrpIT0lANRUyv6gRDs1uo42BdRak2675NQLhy6L73nRSjPD1vPMBpC98J+/2iWzQRyLc+M5UFDrjgT+FfOK7TrIih7Oq0SljF0MSE7ZUd27mwrs2Jl8hauR2Qe4ZZw/URsng8QF2JH/cjWzW7VdJOkizReYfUs4Vgq1HXwnv6eCOsl4hHb0AodapDWM+ZjqS65Z1dG8R1twVlVh1tx2zsUbkc4rvEdgbYWyHbYRu33JBJC2l9+HaevpRW8WEkwgYe3U9wV433NtwLhdDbr22i4SJJhgjzNOgxba4/udFd+kvUVMaHFscp6plZH+Up/DxBZ3CWnSxOcqir3ORxSjnWV1dpT1upb+MEh36OX+/Wv5uq9/v2zPWnnhTXdkfVbK29e3pR9m+cCLD+WDH3yoCUH6XOfIkSNHjhw5cuTIkSNHjhxPESyPtiWXthXcoq5QOCqvkg/6p59SXWyXgGU+zZR10boygt6Tl8fiS1AsU21qHrIGsFjIidcHJRbsKfB3VQSn0GSh7PXjB4VyIoILv9cbimhE1WRazQ8AIl6XQkdssT9KRcWygmXCVyKrWWYCMhR8a2jL4R7LrtRUlTV8qW0b9KyFBJcF3GzrPv4wwwZEH/DeDJ6mdg++C50iXtg2zIDBhUyAdeVRhmklbYXibbAY8V5S0xL3WCbcoJyukBKFu0wJHrsyKkGPrqY86Xdr6GR4mNLzdRSOelT/Hr17F3rd0C+qTNQCXapUjgXzFMTNbkWhA9Pkg1FprYroZOyaLQLvmcNBPcLRZ1Aq930vh0Mvu33PYlr4ehhUrYzt58nAvrFa/c07061gpSskHdjZVvlkYkW59cnxIvCLCSD3Wg7ceqGQjJ62Qf0fKgxG25rFtLIimHx+AXC99MH1rtKiZNqa6c6HOD4S4T6hmnql1yi2hiQRHBc4v9T7XYt3qv86FKgsfGmKVD/+cjdA0jjByzdpqKCyXyqVwzUmr1uMRR+PXgAwjJhUEe3fLhXLi+MkHtCaXPA1x6GrtWvS36kPt24cUKCqtiHJ9YbXpNedruXJeEnngSXbTtX14ZBpAskzgmapERXLJypuH3+nVutBlb/45fL7+MayfAphsyv6kwGfBNcFqtuRgKqkGsza5DQFc4NS2esCKGuOvwuKZe8L87xf/C0tbBgOr3MgfU9MG1dV9fHB68eab8kOzgMo+YuVjD3eq1csJHt1vTNrIliA1LLZIAEqsr3spF3vZJ5gZfXoipojR44cOXLkyJEjR44cOXK8oWD55VceqOp2ADyDdyQKlQF8ooBfJdPoPp6qBsQHYQBMeBjjQzUVzfz0i78eZRhGgrhggUHapB+Q4xZqpQYEADxWQWALWLPf7WUsOxm6jsXjQCUAcHFM+AhD/TwP1wQMdT1LVZktBD98K2ykj22hfsu0LYW4kgUE9UO7W0PAHZoetcpxqU5FgTaoiUv3r5VCaqiJKwUaKGdYGbhRWBwBhNpmQNEL1KyqS5kHQtzd1SUV2Dt48e46GYeR3piHbi/DNFA1rQUPI5Ql3DClcwSHqqrU7d4AygqJYZtRFnodTQ27ELSBPqepocaD8lghFrfrE4bQ9NS8dkupmlbqppWqbqTEA9678KsNhQVdrem2KTemCl6nWvlkb3j6lwBaEw5oKnUv0oW+IAAeZ5mGiRCmP2AcT7TKoDJ30oJwVQ1ACTW6SncxRgGGQTS7HkUBMR4aGaeSyu4eIPnqIC9/8IFcXe/l3e99RR5c7uXFlx/Kg+tehkmkgxUsxjBsUMI1n4LJDx0uvuT9WHInFOULnsCqJFZVqZVLdDUy1JRH9THGHMbfqJKM4tBFqO+yFdTDbEBSxSAYQC6LF2LGT7PUkGOPaCesB2onQw9aS3IA9qIYJYohos1xLPzVtagco170MxQmw+sKWTeVrNtazratKpZnrAUCI3cqg+mnXcKDeSVNW0ndVlLVsEFRSwzdvVCo7YwnZxYqd1PQuxqYjWaoWSXsZoeir+VugeAxvASQy/5MHmaHE5XcUa2s7Wt1PNMsF4m6FVSkzY8mh3guqK8Tv3UCZCaDcP0md06KzKnfse6MoGoZY91tIxbAOyYV3AVY161lccPg+e3AO2TmtKimPkf9qbWTTYWclOLU55iFBw+jyTt63LNbTBGeZj5S4OxtlCa06F9uknsrfKgq/5g0rNe1bI9bOb91Lnfu3KHq+f7DSzkescZqGyEJhMvF+xve8/R9zjatsD4A3gtnqZhAmWQqbfxKUkR2YZuRjgv35I7fxtuK8BjhCTFPlvgah/GH9Zz1NbGPZSXSthspZM17ePGDr8j1bi937pzL2flR7t59Vtbwiy/W8uztZzQR1sN7OaPlHDly5MiRI0eOHDly5MjxFMHy5fWeX/tRC++plyw+/AOpqhLWoZZhjKCkVa5hxMo+E2vhOMNJpgD0Lf8EzEsxnG6FT5S9Q99T/TsNKLw2yQQ17hEPCfCPBcCs+F4ASwEqK2wEdCN8BTQywK2/0+fy/GQ2x6BKFhl4/7hqWE0o0ymk4od8VcUCtRUAI1RRq/43FIIysKRb6K0EGW0VRun6nfSHg+x2O9lf7xVOAGQPHaGbqsUjs3UlqSvdgkjQe4HqXJwraq4BIvBk7Kh2r1FcExghASsLlSntcOVdYEVU+VkBuxIPtYegWjv6BkTIu9jhfqISDb37WhG3wMefHcykKmjfdZ+2dYR0VAQSGAGgG1zmY4oqdIJ6BcuhiCQbW6Ej+nJFdTL8WEeRYaQ9zAD/5EF9Ta+udnJ5uZd79y7l/sVOLq4OsmcRSy2WReBv1g/BJve1WsA3AyT3FdTGi+cZlGNRNlcyu+rV2KY9l6pLtAPsWzBXWCjzBpsBg1thji9UlkhKHAWuDZxrpdqHhD4Jilu1pHGPdMAxJGw87xCdCnAtJ2DZABuSQoDEsMOA+r40IDcxKaNjlP7mnO+mWEYiiXPfjbMT5ecCqOLeT9XHqZe83owW60uAtHmnhxsO/XQiCF6M5AheU5/paIXxqGqZ/UkLDFUts0CnnWgxO1zmLWnhQIXY6TpMl2+OnylRIvs2h1Spnu44UMDOtYQcWJXToSggvjfbFOs1H2nxAKm3tZHURzzOwxqiEPqRFcJl1InIeuHnnOwy0BUsgbjJMaBUPrZHadeNbDZr6fuBa/dY0IRYc2OmDud4pHWTFR9kMgbvg/o7AOSS7z9xDUnfE5de5jdcT7h09zdPxhyb3d63+J4a55cmbPhOpAkeFv1EshLX0snl5TXPvd93UtcNkzpnm62Uq0bO1me06jj0eG/PYDlHjhw5cuTIkSNHjhw5cjxFsHx90AJuHQAalcYGMo8AqLrFH1BIWax6U0Kt3NA7WFVfeMZAMJoABAM7+Dcb/FwVjxYIBCziZ25C2lmhsX0AxwOMGl7I3TjJvhukHAFhsVf+SC9NQAT9oK8vcHisV6row7fVA45o0TvQP5Ea/pn00YVyTeEzfDOJ4eBtiQ/2NI8VqatKmrKReQZ8huJtlo42IrhqVQceFwW3dKt0Px5kGqF4vZDDfifdvpdh6NW7twNwxvfW9vR3pkaUsGE2KETgzL/RXDkUFSS8gYWFK8ANGrk6FaAIIBxerfBJpmKZBs2mRkS/1rAgaKVp19JuzuhNWjZrKatGCjxKVS2zTSnnozRXaEDsUtgTJePri3S7fiRLcSQ47ML9ABjjnrSY3oAHvVEnbhPHA+2r6sQjFcwETkHJqkdWL2SMcoBllYhDFQ87kUp6KedCrncHeXhxLffvX8oHXnwgF1d7efnBtVxcH2TXjeqr7IX3TDHKr6fF98gt43Z3hKrRzfIlAcnptvhYQMxVphG2EZRD1UzlsB9XrT+AxQuAZQfcRQLgDYjRW9kKWCqM9ITD/EjiQcGXqVtxbCuQprwy2qeYGPgRaOjnVWB3fES5jOtggT4on+1uqOQ/TlLXCtWQhHK7HoWmugYhmESBBQ39hxUOU0EcQGRim0BFsKn2TQVc1jirUz/3R/aEWKpO9qxGUnDTqa2mqaLtgr00+ASbQjz6l5i6eqUq8OitDVUuSbq+Bq9HPwOMLvi0e0lzNBjAdk/89BqiCvroay77yeFxtRinBMhMDupcC/ZAsOjgrd7sob46KSaoSQVvoyTT4skQT6acAtnUvP+GUE5v1kXJ84ITDXdguE2O9gfGOtT2SDhh3Km3MewidIcO1g54zus6rPUBYCUxFhjrs0ylgmdVKluTcN7HjI7fDZJy8G3mSp4Aa21bH1M6ONJilUmuKCQvtIl0xwiexzeyuZBhWEnfz/LKKxfSHUZ55s4zsm3W0lSVbJq17Lu9jJMmMHPkyJEjR44cOXLkyJEjR46nBpavYBdwPMqu63T7vCluV8dJCgBJs0wAw1ENMwrGldJYsTOAHKq4hkGBD0AnmCNVw4A7sfBXAMtGR0LhPYJlukoqyKaHpgIHguVJgd/u0BMUHAUF747StitpYNlBQKUgJaiWo0iPkIxAuTSwbFAcKjBc0jAAICi8U7A8mxhwZv0sCJTrppZ1s5ZhxDVBSSkEy7hnV12n/rzUDs6jDN1B+v4g11cPZb+7khGWI4DK/Sj7/UDrELUPgdJbQQ/YGBR2uD6HGKMRQipHCdQUNrgq3IlQtD1wLxIUZhvlOA1ynEqvpmgASuFL3TbSbNay3p6x+F3VrAmVy6pVBTNVy+YrC3sSEhBD9rhYSlY5QF5nJPA4RAIC2TOudATowgM2FwPBMixG0G949CO8kLUYI291OvL3REAlig8mnrRufSCFjCYEXckoxQTo2glcLa4ud/LKK1cEOO/9wH2C5ZfuX8tVN8ium/R14TasgFwCjyOve9RLVQGr2ka4CtnV9Pi6UEX6dafHDarWZaB7UbyRYNm6GwuCX1ZIf9CWQOeNilMNdDNVoscODgi0Hini+sDClGZxgyTUDV7H7kMc/GUNKBMuT14IUG0ScB1I3tSVri3slbGU8QgrFyiacT71AXfLDl6PFWCj7zak+jQuV6gci0iaejnsIlCv8xQsY5wboj6BxW7PkCiYU6BMaOty4PS10cMa/RDBMuaLH0YtLHBcHd9I3NhXbDMoa1Mkc8uBrODXY8VUA+DlMeP8YSIpPCFeLxTMSIrxJyt4eGKEbxkE23qC83A9H03xb4UKsX6jnVMf9kVS6GQOn3gpp97euj0j2lE7XI8y/mQGpTsWgpA6geh2Hs4SU7b7Az8DCGMHA9YChffarrBbQp0BwmWsG1QpK1jGe1GB94PVSsaqTMAy1nkrVBtuPVJhvB9oTVTsfkAyxAteJv7vrtg3n3hfEzxBE6zgeU/W9lhr0cMjvPQL6bpZXn7poeyuO/mot+zkztm53L69llu3bkl1dZTrHQxuMljOkSNHjhw5cuTIkSNHjhxPESx3g3p7Uu1J0KSSQyCbmoAmUSPah2Iop+Df69t4J0IkqLNkAXIIlheIyiMqFe1js0E1wCOATABshYn0mYR38KrgtZbTSkr4KhcKcflB34AVgLTq0GxrMa6N9g8RWOB3NYghVWz66R2gFvBHX6PWD2C3VBoC3NpDwZB5puJ85n+M40DMBzhML1nCKoVWPYB938k84menEZEcKJxYyTyvCLsAo0sANAAEa08FJXB3BuSgtDZCxRPf11hoK4GYBPdqC8FrSF4CqNFCrdxuZL05k7qBvzIUy5V66JJS6/mURzs0Uzi/gr6USkYbO4/j/+Cj4PSpC5uUm5XKCpUVCmpxO91OXnmxLqjY61ptPYrSPF2jYjItgqYnMuUy+sAU8qpyHkDz5eLqSl6+90Du3b+Uh1edXO162cMaAzYmhLDR5zhYKpxsjY+Adlmkb2ECYnNBPW39svS6tE91TIRjuC2Dv9ZhXTytDXkvMoZp7fAqBdTmf5v83tX3bqedAm4/QVQoJ+3q6szgz5sqt0+gY5I3IGT2sc51Q13ZCegw9pMijbpGeGHMWP1SobcpjRNletRx2/8ntjWLAReAZuLh4WrUk6KM6esWSQrv/2iknOyZ8Hmj3iDqE+yT1cB+Im1VeA+/du+lmBXw+e3WEOn6FueM+VSzn1SxrMUZrVMRXGBCy1i72Yi1woCLs+P5nErY3uHNoBmZ4APvg9KuV5G5z+ckwuA9Gcj+MOV27Jv0tctD+XO0mzxtAr9x909WD2VVLGON9UZU6xvaP9kc4alp2XJMlPWaDGETJONa+8uTCt43ulcG18IVEuPCEyE2NrxFvL1wfNQR8HHptQl1N5D2H5OzIUmrfYXnHPpRVgUKiXay2x9ku90aSJ9kv9/b+1+OHDly5MiRI0eOHDly5MjxtDyWDx0/EPeweDDvSwrmABVrVfHhszHLWQG6rla0wWgMAAPoUR15HAgdCXXNk5iqy7DFOfFw9ZPb9moC5BKFuEqpm7Vue3fF8UotOfDqq+uOgrrtFgBR1c1QpgIWwO4AVgflqpSqqNSHtSgJFwCfqfgz5fK6bQmijgOOcZS+nGUs4E0Ma42Gqj+I8sq6lCMU2a7kA6iEijp4+Oq2fIAxtEvfw+biIE2zZqGoaexkd3FBxfLUQ/lnYNZUmvOMQmeFTBMeeD0K7HWyquGlWZrvp567wjWjHaaDgkKXouK+QsEth5EmnqR6ThkK4EkBGwv6VMOrV8WOgMrnt27L2e3bcvu559kGdb0hVC9LPM+sNAhkcK9qhQGLDECT4rihElhWtag5bvRwff2RjhRAelUo04YDyusZ6msUdpyoYMcYxH2gTTfbMxatmodRDtWOhxgLPNcKLYaEgOIy3CeKtVGnfkR/UBNIT2yoG9/7gZflq//X++ThxUHe8+JD2fWTPJxm6WA3AQBNmxWFn65KRhiHMvikykUW5/I7TNXICaxCYoIFyDD3AlhO/GjZcdhRgFtzZXGC7aygn+tXaduyQkLCxLJIvphC0qGvKqbVm5YeySx4qP2AZEf0nFXLFyZgrCCng0FVAY9MMoQic1a07bTA2cI2FwVDB/iN003WrGrg8E4DdikqVSy3TS1t09BPFnMU6wX6LhTrK+ALXivCJIAzb/BHhpUlsqzAntsIw7IAnjhH9FEovmcKXULcRJXs9iGJclmTUKbeD2DV1cu69viN65roNB+vM9U6VeYKJldFI6sK6wjaHeN36VXh2DYuqNrmmh5JElimbIaqnEXz8Huo6kMS4yjHAkmkivPgeEQb+nH0urWffacD7IGsYCubxSw2gm0KPYesYGuEywlFjkbVIelg64Yli4IcODw/VSyHVlDbkSQ75RYu7lmP3Qxdt5eOu0Y66Tusz1oDwHMFCo1HvkdgPmMt59o+IampNhhotwJriLW9Fvhzjq/9HFk61gJLupU67gm0LWmr9kSaAMX3tI+asOulUBBttQNcqcxmYgIUllKjCB9ov5LXC5/3QzfIK/cf0qP87PyM6/ow9vLKg/sE6zly5MiRI0eOHDly5MiRI8dTA8vcis7v1Js4FbS61jLWZTrVnp1sWzbppKq2HCgocOBTUhGa60YN9NBOAr6YZqGhlg6qnOaHcDIaswuwwnxBjRtkiQnMCcrBeD/8EJ+gbSoig32GKiB1R7iez7cx6/Zn2zpPkBuP4R65PD4Biz7oeUpfT9gyGHhJt4H7tvwUKhJqoGAhdj6nRdrSrfypItRATMKcFmLYhSBT1cpBaWceB/QirRUoAtYROrItohKSsBOexgA2Bpb1NkDfJzkSdNrW+eAhGi/D7/jmSOSHKSRaKJZPvk/UmTpOXNUa1a1uKaFNropzHz9p4TS3E3CvVPfvhW8z1ICwX7nad3J96OUwwGoD9hep5Ukc86dq1jg3TlSuJ1Jt11imOueg/HefgCR8q3zast6nju/0T6YEjWJX/XVS5E6nbarC9GKR2scYH7SGSXolfJdOp4QbhvGRKHEfUYovej+qiJ3LArqha+CLzrEJ6I2CfcGPNirF2Y9BNbw8ejqOdL75FDQAynXkxBc7tGzyOges7m1u9hj+zNTiJHS5fRN9dZe7NuymI6C2ORn7LFVGp9eUqoWjp7sOAv/qg4J+MGpntFj/lnMrjGUH3iEhmPRtCCsQGNZAP+fq0SsNxsHJvYXz6lw7oh/SeW3vCXF+pZ1iBQk5l03d79eSLq9hTNkVAxoT0psFiyuCrTbAooBlGORxhwP9zHkiR+0Of6MqO+4WieuOrlE+XpM24nJj6nXvq6C6Nt14OH98A/Xx7AUMkfAZi5Ucup6K5Q47ZLAL6Si0NVLInSNHjhw5cuTIkSNHjhw5cjwlsDyaQrIhtKkIQalA5gdf3druH6RdyQZl8wAPZihuzRdZLSl02265qnQru7/Wtu6rX7MXK5vVDQJsr0Yl+0aqsqJNhar2VE0HxS0UjLSwqAGfV7JpoSwupGlqqRrdGg+rDD6OOKvaIEDtKPzZtkUfRxbrg5oLlhtQ5qKAn8JCkdWuoLoMosVVVdPn9DD0bIW2KqWm8rcj0KAmz/2PoTSeZ6qVxxGPjgX7qJjbH2ToelOd4rpm9U/lOfXBdp6PMnQDVXWrqdAH2hDSacAMKHUJn63IH20zAI2cYkXIT7Ep+0+hRAF53TRQ1QnlMlTIKMCIgmXNupXNdi3tGirQQn1zbcs1bDmgvD5OnSqlce5xUKgBZTfBMvrbigrCqBrb910mnRbKe9Xw7fjHm9XKQbGsBfugnlZ4r1AJY+PIsQClLBTvlVRVzf53aEd/YLQl7w9tWvBR1LUUjaqz4eWrpEmP/eDyWi6v9vLiyxfyvpeu5Go/yMMORQJn6UvzLlWipSDeC5qZ4l9Bvt+7/84TBY9wtICUI/hWgMZZxCSDJTWcQbqQ1hImhuj4iyp4OQfL4dAFuE5VTnNfgSnuTXkfjq8vwtiD7zGU2/Rix/Mw/wEpp5WsRsyDo5SWNOERV6XMmI+8r4HrgloHa3tA7Op+6KFv0CclvNtx7UcpV4NwbwCK9hUNx+fZrTNZn22kamsmQgKsK1f0DNYid04RXcVtKl3zE9expGvaagU7HNhDOJn0hyt1AxmmQp6KeTRk1cah62PcElHBP56PtMCdwz3zWg5IFn/X6yjKVguGwoImjN1UpawzGyMv5vlo3mJzA4paL9aZslv0ls5pd2wI1+lTT31SdJ4lib/IsNG27o+PfhxPEkJ+IL+ydN7HBE/YhuLwOyiV6QAerjuyWbS9Wxn5cUrtb+72UJUznqNn0h0yPvarciUNbXF094cXXeWIMK9ktQdSVX6ciwahrZ/Q6linw5igbU5SEBLnNaU1l5ASZ9D5y6s2FboCbS/1GtuuIPA2b2vMERZiROFcrX+AHS0oEKq7LSJ4x1cvuPvSy/e4M2a9XtMOA4mPj3nXx3+IdTdHjhw5cuTIkSNHjhw5cuR4A8Cyh2/nPxZQ6LrK2BTKXhTJ/EFZzs/UZPBXnhYKTvWVTLe7u1aLR7xh+796KRehuJ5tLGaoqFT9VvVRxOfSAziBenGTsRW9iipf1x5q4SX9MI/jrWiZ4YplUxEb8FIfZnzAVx9Oglbf0hxE2laQjEXYzGOZENDtAdTHM5DARNXHNkuAqm/DnitnLqrOZEFEe04QFibSuqVW0NqV4DvRXDrMT3xk6ZENMBza0oGkHYmvUZUyHvRpBqCGkhvgG97XhM8Viz3iufRbJjiKymDt90fh8tJHOaoOo7lDBH5e2MvVkfpzhNLcYm5b8R3O8i/pOHQYyUJ0VvyNkFUtTvSQgDwTCwDuu172KNJ3UPXyAPjvcMi9TpO7WaqVb/5bKopd+urGa/RicT4BAoZaNqn+7KrJ050GiYD6JjG1K4ojLFPYRsjpymGsCYByySvCGTi+lyrkoFM+8Rl2pbCO9wSkp+pxQ5BF2KmgRdgAEqu60gd8swGSTdZs4uhoPxEmQYTCDpUVvaeNbfduamJV/Z726XKM6TkccKbz+aRtkiSPPtctG+I6qKpz62daeaCwZqJ6tnkQzquy1qDGDUkKPuVEsey3bv7thK8YRwHOntzmclVaTppElRufFH1ewgqUqotVbhy+anIx8fZ2lbQXXnTAHDo1OTkhq34frGZChVJs7cCPQL+JLN/fu4JSOFXN32BD42c7qR3ovbm4z/B7f69Y3rr3mo98n2G6G8jWuUXz6nqyWAOT3QOaNPI1Pdlxklwz1NaHrpN6t5Ldfs9Hu0Yy5twsX3LkyJEjR44cOXLkyJEjR46nBJbPNvDuXVGRC8B6EHjY+jZ6/SA7EJZSPyzlquDPqprCP/NgNV9Lhx/Ki9RGwzEC1GJeDAkAb8KHbe6CPi6AEjGxHaMuoVRGkTmoz9TuoK4Bmwpp162sN+q7ChUmwFPTNlJXtdT0YS1FekBRLcjGok3lMfjMrtetlFUt+w7F2I4s/NbUjalaa4Lnvp9k7Gfp60HGCoXq1A8aSjGoi3v6eHYKl0dtFwe4gAKA0niuVTYMam2012i+1urGoIo0gGU0CtrZwbKyMgUOAOCzAOTGQm+IiK0S+GBeupocsL4xegdrASh71TYCIBHqTYBhiDP1+sdhL/PU0/cZSm3Ks+FtzH6D8hnm24CyM1WyWmAQ8lTzm4UPM0/o7qHp1XqcqB3Nf1W37kMZ6f7K0Wd56A6yv76WoT/I/vJSxqGX3dUVCyXCTxU2FvBNVRBjYB+6TTJGhZXwhdYHKH4hxwmwfJbDHj6sg9y7fy0v37uSVx7s5HJvNhjwYkXbUr2pxcwUKnk741fmv5sUwnPLBo6/JA0QbXOtGJ4V/sL3XiSMWDexeEhdEpAMSoFX6rwRkL3Z1eK57r+NMYZmKEt9gY4xqPXVa7Y6YnxQ483xgdcz8ZL4XYR2ZV8ZE+S4V5m0A0BN5LidDOag1Ywz2KhJI1WSa4JJZLtupSZAVg9l+MZCidm2jTRtTVU6khxQzx9nUxAHwOqLh0HWVE1rimxtSGtTKOwBALnG4X7wBMzjCKfVHxnPrdRrPQBnU/NaEiaOZ/QZlOY8gY5/nM/hMH+2UUAbGSx0a1UXq/k6r0fV1rGQoxZsg6cxlPvAjQpUsRbQf52qXPVjfsSahUp6V1fb2GRSxSXwaOtkkAUFsSVxzBqII5oDLNqKUEWb2mlwijnQd9Dq0DiB0OZ7rjsTvOBiOqatLe2iFlDeFcNsF1P7sk1xHWoF0XeDXF3tZL/D+hyyGZqs45JqYNgSJM7nYbPERZkqY8wfhdghccdmUeWxKqFTCw+fyw6Wk+HoycSQkDlKgV00guRlJdUKa6q2giZvzW7adgPR5xzJR7P14Eiy8Xh1uefumLPti1IXhTz7/HNy6/Zt1hvIkSNHjhw5cuTIkSNHjhw5njQe+9NkC2orwoJ8gE9TOctQqD0DVbRh27CqrgCAWAzPihDhkzg+tCs40qJOCijiVvAIuewDcfD6xffuQ7n8MB6VyrqlmUC5UksMqJW1aFslddtIicJd2LqPInclCvdVwS9YoarBvmCKqXBBATUgNF4zJj7DCklxHeOA7fyArBOBHK4HIAwBEAc/SxRI0g//ZtGwKMambeMA1e87Pt/uHXCPdqhmrxDUeSdMlmpwqDmhgoRFhrVw8DxN5HMGUgxjJopyVysrEI2n0e3wVCYDBg49LT3mCVYfqliGFYV2Lwr7FVLSVgRbzQcpuUUeBxsMrDkQPfX5DFe0vF6/aANQCpddjQmwbNvWx5EQeQBIPuzZR/2hY4EufD/CD9pV48EDW6FngGuh+JuqlZ139QDI3SDXu14uLg9yDajcj2aXou0SL1MhEZIV6lywVAc6bI6qyRQqJ2pJS8K477ErLB0+u4I8+m2HE9iFRHOF2O/xW3t1LDxm8L4s4nHRVTOK6E2Ar160UYEykw5xYutxOYZ1TVConHpOO00zPJjsaLCd/LFdgp+5+SpjLaorFuwry5qWNW3bcq6zuCASYJiEtPDQ5IGrotWLx9qIX+L40iYMBstL+TfBqdtiGKBOVM/e/lqcrzJLkhN1Otc07GaISSRVIyfjf0Fu3ZfCgDULYMIipzMfdy84mrzEFM5cP1nYMxboU5EurGoAnZE8c3VysRwrbHNLhHhR0qRvY4on/ryYtkl6wce0rhu2nnPQxCTRIxsVwjqMNlIArE+hqcpJwcVUZR7nTjzU6tG2tGtEggdjuWPBPiQgdJ7oppvoe8xxa7tkUn9x3TWD42jhWl8jQhLH4TITlrGfUpWyt2CyLyDMEy2Ua/B8pe8HeJ/gecPOGR9alpSx3UGpypr+3UcolnvpukkePryQVzYti5myxOYNHu05cuTIkSNHjhw5cuTIkSPHGwaWGygDVyvZNFD6VixMJvgg7kqpULDO1XmAB9BYAXCoggrw1GGeqrt867Mp7BLsFQGVecgGRa1/qE8gJ84EkFwBGK+kAliGbYMBZvooY6t8oTAZwBk+zADQFYvQQYVd8sP6SDg8UglMT9nSvIqhvDWVdFNWsmlbguW2qQnNpqLhB/eqauyacQ1Qe5ZS11Domn0CYAXUxATGvp3Z8Wn0bo0WILFgGDc7m4oZStt+WMnhcKSXbF3U2nYERdqOrmZ1oEZwTZ7jXqle/ErBPb6qtYGdy7xmCX6okh4VznYHVZnb7/vDjnAZdheA4AqdAY0pVNa+QYGospC5KGUiTK5UrUgl5qTq1WJtylBVAS4jBcuphYGBQj7UR3YcOpmgSD4cpD8cZOwPMnaA3wrTCHDM5kKSrwq6gtmCQu8V+g7epWgH9VfGWL682Mv19V5evnctLz24lodXezmMA1Xn9HbGkSCzp0WJ9jMLgxFe+zlMBezjeeFfEQWbwcqADDICtOA2kFJFO04oWukFLTm0DNqqzlG9XQ0YElZNK+4OULWpJW7QVS70PEZjbgA5v59g62I2ODy3jQ+3y1C1vSZ66FsNSxW3qAkq5mRbv80Ld68ICmh4lPeDVGhWg4AAy9iN0DZraWs8GnrBYy2Y5pFe6RXVsmaeE9prWeAxtSlYgO8ULnM6OphOleWputkLg3rSLHo0P1Js0qA2LTBC3/o4jApztS3RHRQUAsNP12w5Sk88WcJHJj2f1Qa1Q8Zk2RFrMQqFAjq7rY1bBTE3MxKirgQQ24oKEn67p7PB4NT6IiQvzCCbJ9adDTESrOrJOxv84S+uGOZc1nF6XKkXsW4lcGSdQGLvsoXy2vvX58UN4JSWHwrQMX50DGEdZw5IL8N2zug8BoKdpCtnGeEhbu9pvGa7ZKyjOsZ0zfUUT+r37j07T9p2jptVdawJnZCAcWsf9LaBY685YHkKUyZrskcfsEuK679aNtkiQI/mlXTdKA8vrmV7fikP7j/ke1SOHDly5MiRI0eOHDly5Mjx9MAyYFBRyHbT0gZiP4yyOnSECLBx0KJEuiXbVcj6YVg/+A6Dbt8e6LWLD8oRuqhdp6meE8CjRckSP1wHrFZEKrEZVWUxiq0BLNcGlgmvVgaWKyqWUfQPYLxxsFxBlVezmBhLKE0osjdKBbBs3sfuiUx7C+xGryrZrDc857rVAmFHWAIc8TcteKYQSKicRPHAoHojLDNobRYJ7n+pFghWOCzsAo/gS1+vUBnFE7sOrVhKjetqtSsJgdCWQYEbIQegg5Jl9yY2DIPrmFZyhKLWfukgwzeTowjgBNX1qiesVcGmgu5ud0Ww7P0G1fI89wTUVYWCYytphlbmciXQKpORQT1eQfEK0DEqvawdhqJQmW8p9/BrjjQpqHsDVB4hp5Wxh0r5QJVyv4NSuVOwTJuMRFVrBeEcLlMN7ukKgjRchypEg3czFI7jLA8f7OX+wyv54CuX8uIrV/Lg6qAFHNGPgQIbqAv13lzp6wmTCH6jojf6vKYNoE9dyjr9OfziTNN8ywkKgz+wq4K1MF70LjbrilVh2/SRWFAA5UkUjAmyZxbw9HyEto9vu9fCe66ctXMmbDbsKjAVse4i0Lnp9h5edJD/rEtp8WEq6aC0naYFWMZzYNXSNC3B8qbZEC47WIaKfpCjrKkWDQdaKlwfMbFObDFCu9t6hYvhjgsriBeHZ7RbMHua+AdTnKae33y9F9HTHRz6NfrAR2uY5HqpoC/kOCKRY8UaMa5ZUE/XKIBlPh3gMtgVWYIKa4EliRws6w4T62SueaMm+2wcHZF8U9+KOAadI/uMxC1zAT8pEGh/TBOF2uS+eyLuPtHzm68D1ipaqADqqvUIi4BSPW+ANskSLL3ZDVz7GucK7qWQOSSQsOukrluOoXXb8IiTXwZsX1DIFH1hY6gsVUE9WAKQu0dYbBbrpBuBLwG37oiw4oHWhnTeSZJCXhjQ2XpQmpvvuOuTYwvoT77W4z0hgOVgBa5jlwkJX2BklMNhlAcPL2Wz2cq9+w/4npYjR44cOXLkyJEjR44cOXI8aTz2p8m4jTnZyhzgpH5id+sHL5ynhc7UONILcwVVokstjUyoCmtZ4su3ZKcRixMBzALmml0EBW4Kv3Vbfiy85lYTsKIApyhXsxwOhUxVLUWxoaoW2/oBpVFojpDgiMJsHWHc4bCjnUXX63Zp+AjDn5OQocAx1GtTlbCAEij6tzK+A69ltVZojg3vD8A7QkEF5a78xAUuDAvSQnP2W4cIPW0IJsKhhsDSoGmqulWKu2hLBxbqMwoomajjUkVldCjVbdnwgYbVwwBwBA6mcGqgtYSBZdIbLWBIawBYKcCXd+i1X0xcDLBc8prxu0FV42De1ShStrIqN9bh4f9OgKCmIdD/tORgwUB9AJoBPuIrbRAI3lJ4mFh8WGHCeS6lQFLgCKV7ZZ7SqmZ3NSYTI+MsfT/I9aGX610nu8Moh36WgXDJ6NoCUnqBPdN3nthUBAFogGv+/z4KzK4hGQsROuvcCqpRb6IUoJGbK+zV7jGlOqAtEi58xKSNJ4notRwKetmVmB2FesYaEPdxbGpoV16rH3cspAlgrepl87MIx9PnsGhZsMpIu12v3T2WWUAS12z9ow+D1WZRw6+2BsQWDStIFBgn3RSsS9JhZtYjEfQmDRsK4AFYKrD2oneK5DXBszheWl8vWVN990AoQBeU8z6f0uuOSQm3TlH7F4zzXlXiwY7BrTZURZ/aW6gHuyrKw7GsEKD7rEO5PB8LWaFKqHk34x6XCR+7NGs/38Wgv9QdEktTjNj3/J6+8ji2fU2muN863kPCFAgAP7aFNmMyScLx9bl69tjv4ZoxNyQZr+GhcB9vBfQYx3sIFy1NkiIZiSgGTW7ozhQo5/V9T0G+zTcvhJp43ceUnl27z7HTWo5Rsx6el+abtIjsyYqdHEDXCis262sAldclVdEoPNr1g+z3BxkzWM6RI0eOHDly5MiRI0eOHE8VLFONCGgANapu7V2CTxSSU0BRNzWLvkE5BWsJfIhVN2bd2hwKDnHr80pGsk9+gg8fnhfhKq6E6bAIIJkJ7DhmqYpZ2kbVyVAm44O+Q2ZCZaht51kOKKbXFFIcd1SoFSWUahuR1SRVLVL02GqshfSudhcEjAN8gctKDodJ+n6W7tDLPB5Eyoavpedy0ahaGOpYFsZS2dt0nKRqFFrWa1Uz19hu7ZDACpYNgKHjKCU9ic3nM1hqlIQc7vsJIACvzH6YCcnPj7O020ZV2wQq6gkcPVxdtZiADDs3C1OhIB2L8kX4H8CtAQpAZRQhVJ9QV+hBNTnJ4eqaKuGwR9sUkmiOhkUMtS/gw7wq91KgrdhPNcFuBUV3UUpzdlsqtGdzW5MC5qe6VJGazM8sQxSoDXzM8HkeeyqVD7uddPu9DH0vR6gzAbDMOgHtwCKOTU3vbFiJMJo1z4HxS8V73YoOCi2whqKLu30n19edvHzvQl565VLuXRzkYjfKYUSxPi28toIdhglQVSQY7RRisbQTzhW8h604WerZsNJEDvoX7UJVNOwOzHoiAMaTSeMesRUV+6a2ZXHMgskPANimVisXVz52PTTlR3qqF3VlAlI9Dj2LuQasBK4iZttq16PWLnwe/c1LOVboe91FQChcKQRW6xG1xHAgzOdbAsiVyIR7dtlQH8NPucFug6bhdcN2Bg9s44cPOn63bjGf8XtVRysbD34zcc0Ka5dPkKgOTSGoA7mojE0g4WyNgF0PPJiNI8w9T+qYvtTnWkzW+DmtOJ3t3AjqXlx4UJzq13Ad9P5GukYLaTK5NB5k6i71WWVrQNUKXNKjoUuKCuLbkmuvgmVNdqlFD5Iz8KS2AWxJwsUOB7OWcO98bRibn7TBMU9rKKitD9Migfz/Av2NzQr6PnJkQs3HKDMW5vUN1b8WwPO5GxNniSeEt1GwI4lttdgKEBI46oltdUVZhBIPjBvuJAnHVw99tQHq+bt1W8tQFtJhh8JxZtHYulZPf4xRTUKhIOUs/XEKu14U8FtSle0OYG+WFaF1Y2JJ33cjcA6bM1is09qK13oCl8NhrD0sKcCfUTi1QCJskunQy8X1Xu7df0hf8hw5cuTIkSNHjhw5cuTIkeNJ47E/TRIuLIqLaaRbewOkYFEjoQIx0d7F1yQivKDNPRGUeahtRtw+HcBO+neDUdgGPE/pVnJXwNlzxknmYaDlRQeALEcWcgMIAjj17fi+xRgqZfzcQ20LL+EBRfiw1Ri/N4BigBMKZ4UA6l1K9TCFslF1ql7RgHxuFRKVnrFIWyzKl4IYa6LFbnF6pYqql6E0reAFTM/idBt08kikr86ao9VpSpyd7kf4FVTVVC0rfHTFMreLUzJuxQ9NrQjuMcHuANvGey2MtYJXNR7091W19zw3BL2EHugHQaFEwH5TStrdJ72uRcvMczoWQvSHFYKzLf9QLeO61bsXCmdX3GtCA8kIwBb1IQVkU0Us4B1/B6UikgRISvSq8kOhvlCsz7aee9E0ExJHWMQfHGel4uLEAmChytbXOfR0tXA658JESuYi+zCA0Phcr7/mv1FM6FDRizLqq0IxQ5tT7jGr/s9IWPio8HuK9+XXEtYAJnaS8wTF9vJWlXlHQM6HtZcfC+tPDQANiwtAY6jKE7UykzsA1fybQlJXnqZQeKH+TpXeKVhOIGik0BE6p4U3gx1DsL5IX5u8zue2/yWB3ItuD3Mu3dFhIDoVL1un8le0ItH1iHMb/uUBbPMJwYPZLVX4eiZvHo3lmNTrJhQPxfRMlWu+DQ6hdf2L2T//XlXyyxOEeonmT8zkif9+Iec/yTKeZh3TeRZeFtXKnGHJz+kxolI7KpW1MGusARCU2L4TpsC6ps9f2G+HIrJewE/9aVSJH8+zHGfLdSH8X1IwMA4pmz9MzOguA98hxHGeFBqMxSVjLwbVtrdLsPrQ9zW1YcqRI0eOHDly5MiRI0eOHDmeEljebtSagNYSBsR0CzsUbQAZugVbPwMbZHGz1HkWOjzaB2T+WulRskXYPWH9U7AqMwleuW/flMrTpFoyqBGhJDOL0n6YZL+DogzPMaUW/C5hSTHCO3aUbr+T/fUVvZj7HsrHWuZjKe16Y+qxUsYJoBmqzlmmHpYPKzmYPLPbDzL0ADhWtGnVyDC2MmFrsdaa0nKFsMKAo0O08uVxAMK4Fd18bcXuB9DTiyzptm8tRqjb/A3k2DZ5wlEUGcSZhpGKv7IaZLPr6O+8qlu25AS7Ye7eNqBhINdVc+pmbbhBd6wHJax7J0PpCKWdWmAM0nWwAikIiRWfqA69HwbC5nlUT2pXaeK0eD75x9WefUP1rikPUVARIOfWrXOqTovyFQLd7d03y9lzvZT1RurNXW0r96M2ewGFe6aupF+s+XyzqSnBJDC5vLricwSetLB6oJAcnt/mnbrS/h/LSqZB1ZdVuyG4rJq1lHUj01zIbt/LfjfIy/cu5eLqIC89uJKXLq7lYj/JHskGKM+x7ZwJFqjFca2qwNceVQqF9nGFv25lT+GyD36Hyg4ozbU5WE64b2wE6SjI6JCYzixKHHlEXAV1oCZenI7Y5o/xiGml85ZFLOnDrf09cGgAKrOkJfup3WwNtNOxWzir6bfsgvKoDG3qyiwDzK/ZwBzum+NypR7edEBAm5eqFEXRTCRleOXwTy4K+iVv143cvr2V8+1abt26Jeumls12zeKZ2+1Gttu1bM82sj5bS7tpQrFIzW5ZcVBazagSPiae4DkceW7033YLlMQLnuNuDB7Ceok6jlW+7WpTbdWoslf46ydK8zfaz6n3s5xW3QtkMHqMazKEf4ItDa5n6mQeduqNzEtWK6J4zZMcqUrWkwLAs++w9piVkJ8eY58/uK00rTZSSxYtZMcmYxFAv2ZXAesuEa7f5s8e0x0nxDPawBsIRoJnZcU9KxbApGpZFygDxVBEa3FN7swICmrb+aLfBBjrBkEc8ACwNDe2OUJv/kratpX1euAYwv3D8iYU5nO1+VzyvasqYRGi3vG0gKYVhqqWoWbme9Cx4nsIbq8crfArfcEnmbG7gWu9NkKw/DA7DL4TJsJsJFmKWn36sVbBF1kV/HiP1HWWidEZvs9QVyMpmijJcQIXYFsXHq1oZzcd5eHVte1GyJEjR44cOXLkyJEjR44cOZ4SWAb4WZS1cgsMBzH48Jvqa0+8IlP9rSsxE8FWspH3VSKIZxMlXCJ6i4plPWhUObtiGcraiR7JACp1r3fSUbEMmNISvLLgHbePq1qT+HRQJWDf99J38A7W+4AvL7c4iwQQQSZCGGOf540s027AFI6pcpMAJlFgxzZTaH+zyYHeH16HewZEgGqZ12QF+IK40CWrti1aWYYDM3+CA99EaZgWDKTthRcbBHB2Zbp6F8PKg+pgQmgtzkgFIy9Ei8nhb/paVfc6XK3giVus6PEZ1IP1hrYYuLZ6DaiPwllG4sg/7OZoA+CexieSSFOo8npI/RU8E9kl3tZq8VAq8HWbACiWcT1ULCtMmgfYoIyyPwx8QK0M24hhxP07CI0A+ZHeDGLJqOqNkt2bhv6juDkU87LjRSS9vP+gpLTfUz1PkO++rDoQ1XtZf4ck0UIdb4pv9SZXoBmU06kY1xMF+kO4Fio6k0s7VSunfutsc+978431Z6ryGeC7YOIEEBBJiKquVLVsKmZYYfB3tfYd/bGtAN1ii8Tp+uQS7NPWDnMvfb5PKow775XVDfYlXsBQvw9q7lRdHr56Mi1iV6tmaq9NvJXDy1NLCLPSsEQQFh61KdLJolM6Adypgn1RRE8lrHSxCTYSiYo+VH00y4qwxuic1r/PVpDRIfHp3VprLHZGuGLXWmJhp5Mqy+PrVJmvCbfwXhRuLP1F2qWJkt+SMlGtrEplVcNXnM8T1d9GvpHsW3hRJ0r/pCtoq2EqYlf0I5l4hJOOK+gx5g2C+5Lmyvm428MTR5aUMljvtkjqI67FZ7GuYm6EOZfsQIgj3lTnbO5kTNv7Ad4/8F6SI0eOHDly5MiRI0eOHDlyPDWwXNe1gjlTb2mBPP1b3HrrKkD7vanCqIzEh16CBBSAUkWwFwAM//BJ2zyC03ClrttKEBLYh2ih9QVAwErAf6FAruo1IZPDKm7XnqDxPRJMVVUhFUByCYXyBFktbStWK4BRqG/1Y/gslZ57VMi2u4Z/b2dwCz7SBnSLIz2PAVjbupWavry6ZRmvx3PUH7Njm9x+5q7UNTyRC7WRoP2DAg8FEPjW1XhFEH57q1IJuypkmET6fpJyP8j11U7appHNqmWjQ7UGQEJ/W4Avetqqgk0ZkarpCF0tJQAuBTUyldv9YKAd1z7R/1k6FC1USE+VKUfPUYax4/P6/iBD11G93B/gvzwb+51lv9+pHUbCQNnTZSHn2w2BzLCH6nKUd3ziJ8q7PnmUs2eelTetz6SsavpAB8aG7eYOsuhXakkKFmqEZlyV7JASNm0r84RCfqrMpH0IIbgmGgCUnZkpMltJAe/eupGyWktRrmW/O8jDy2u5vDrIe9/3QC6vD3L/opPL/SQdrVEUFrnqXm1n4bl8XG5jT4queQvg/t07W7ejx4SIQ9xY/CxaTijQtW3zqZ1JApEJZO3BrfNeWI/+r5oEAaDC3HSPVfQDmhrHxnNg87Ea1BO9hvKfdSpXVGey6J/tLKACnJDRYH9gfNH3lr7QFJ/bvfhuBFwLvLhdiYuxaU4OsL9oSwBlPACYJXjjtutGNutGzs+3cna2ke35RtbbtTRtLQULc2LXgrazgrkUbKa+v/7F/X3V9iP+jmbAZvdAzbsVdYRnMXY6oO1S6wlXOScWETEzFuwO9Bw4H+45taxIlM1+PK4FNgJs7GvF0gkSVVr1VP4czFn6cpuKmsMQyRkkSOD/PssI73GAZPunc8BUymb/osU3VWWLdQJ9VKDCptui+NrPtTtU5dQxJgXnUByrrp7Vc0RFvillLUkU4LKvS2bzoY2BtqdBdGxjKpN1rXV1f2zjNPlinsahFfVn3M9meybPPPsck4sPL/YsZvfw4pJFOvl6WvhYEcoCv1HfZbx9al95SlWLVOLmqkLnTwXPZbO18J0Xc401VXfuaG9rYorGRiyciV1A6jkOL2S9ZfVJb9u1PupSzjcV7ZrQPwN2lOx7HYVMxPh48VSFjibiciYP7D2haaXZbnm+HDly5MiRI0eOHDly5MiR46mBZfUrVUhIEGTKKITCAP0ugDH7SM/tv/ZV8ViqzEtiAY6TreqmfPOCXqeqMVfzgVnQ6xYqsbIJH6xVQ6pbwQkyWYwKHrpQyJaEYfSN4Od33Z7vSjKAGZxzgDfzBD/mUQ4Apg2AqCvbtJgdYCp9jstGARmhEawnFPCi3QDfXLVLT2EWxVO1L5uAIjIHy1Hjndode4sDZAKeAoLD8qHrABhXMq91Ez5r1WntreinS79jQGlTOOJJxYnCd5ykKK3gFO7bEgCEF/TiVN9iXF5NyxGRcR5ZpLAbB+mHA4v87S/3qqY+ANbPcnlxIYf9wfo69j8KPm7ahtBz//BaxkMnZbORZ154k4L5qbdBZMDJipq5GjdYSbjKD1YLsD8xNTKAKQoTEjivHLoZWLeCcyomVRsSJgTKSlWvJYoK1jKMB7m66uTi4iD371/L5a6X6z0KKE4yGIC1zvOhHMa0g+FTpW+cV/o339KecM/XCPNePplMQUNru99d8au2NfGprg5G+7uHMa/HfuZsUNouK1MsKxBWSxvdOh/nJi00CAE5WsI5FLfF5zn3c0W9++/G69axDWsPjg8rBEgAjiKQsB9gNwNuayG/tm1ks2mlbWupWqiWzUYhUX1qO0QwH/ywg07b5acKKuMalFivcHeCJjOY9CFcxjrjUDPIoB9Z26L/cGrXkqx14eXWfqEYHkBspfM2rBIJXIYdDMGrWpoEhTLV80myj37BeLr6ogMw43pgYxR8hnznhLWxZg7VyoL9OiVt5FnFoJqOnsv0buaYgyWIJhQWanEvJJoUlOT6a3CdIxf3AKjqxQ3TtqOC2J+bwny7lCCVdxmwX7das+jxNXkHhXddt3J2ds4E0Xa75XOvrmErgiQblh5L1jCZ6q836yYviudJKY4dHZ/oc1cszzP6EO9T6BtNWHJzDdMUvmtF/+m81PceKIx5Rit8SYseFrCsZL1uaVG0P2jS16ZwXOvjUqGqZcsb0PPZrZaQJKVtTAbLOXLkyJEjR44cOXLkyJHjKYLl4wT1VqxKjx/44ddhc7qR3j/HW8EofrR3P9igItPtwAtlXrAJcLUZEItvko5Qg/60Zi8BaAYLTjKCQn0uoYCFopgWDQB/045gFK+rm4Zbnht4EROOVTKTOCJcO0aNI5WI9OWFh/AAcKy6L3wor+qaNg4K2Atp2kYq3C9sHVDIb15JZaARx9K20yJ/9Ja1eyekch9aQgWAb/N+tWJqWniO/hGh8JJbdUDxPE8VGQzACCA4YfigQJsq6NJhHc5gx+F2c/PL9Q4A84A6u4rwi1YIgIZHVXVDfd0delNiau+Mc8et41Al73d7qrovH1zJOM5y2EMtDMXynjYk3HIOiBFgOYoqGgU/dCLjIG96eC390Mk0wjt2EiHoNgBpRsFUxtJgG2pKFAacZOj3fE23u5LD1ZUcrq+l2x84nqoKftpQxaqn9XwsCOVpI4I2PE5S1qriW1WU2BPqYDT0wyyXV/BrHuRiN8gOXttIYrB/U7hlo3VRLCwWD/OCdj5XqBQ18AcYHgB1AMaJFUForgjC9RQGDU0p6jOKXR/GnDsLuIWJ+qQzyQKv8gLb64vl8UzlrMpnjGvzKfaiX5zbCis9EbOYzamHsIt5Lf+kiRS95xWAP9XzSXG7pN4aEwBUQat3O9oPOwKauqZCf90CLLcKmDcbqk8B3qByTx12XdGq9hiYDynoN5NxVxDzgs1nOJBwTcZ4Mbe4yN2QLQhtAEWxjs0wTgJ4tAUrJOP8Pt3oJgHRlhRThmtq5rHHVgqdH9YBMaHhX6KlBv6pojwWufQOovWF21LAYsKKW7pfOopfUnHrQD4AdawtBkBtbVcbjWWikPYrGJ/ulcx2gW2HX6sBZh9cwd7GvcV9u4YbR5gnf7DKiMmCoI5m254UcrV+U59mfQ48nLGWt+u1rPtB1puGawFtVYZSRkuuIBmH5Bn+Rs9qknB7H7PdJfq+pOvrIgmQtIdaZNj7iF0T3bDtfRHXj/lZFZXuVOFf9e84Kngxk0Qctqb2Nysb7KSZGowjE4Cb+luLL+pX3aGBnTs13wvP2rXcOtskCuccOXLkyJEjR44cOXLkyJHjKYDlmcXPAJQKFv/CJ+OqgOJXFVeuDNSIykZ6DdtWd2cZrKdFKAXIY36ZrnI2mMKP5lTWxg/nrhDTbbymXATjMLUWHk1dyu3zM6mbWg69qij3e0BKBWnYQqyKr7UqRb2ulkOkRPmKLc+46KGfpescLB+pBsPxWXDOtj+v1zVhVT+O3L6M26oNnGmRKYVxtL1I7lHBMh4A3zMVhLQjmK3IIRXPKKIXq1ypUBGQAPcHRXCh1hcFLAm0iBNgKHgtoHKFBqflAeAOuIb5wTrMceiBEwIs44UGJmFZMtJCooPpBT2mry4PPLdup59kmPcyzYNcX+/52F0f5N4rD3kt11ewxpjZJgCI66aVddsG0SVA9cXDS4L7zXGWRo7yjvsX0vedDEOvVh30STabACsARxUl1Zeg3YNacRx2Mg5QS1/K7uIhAfN+t5O6qaRdQ4mI4laTzCv0gYJlgPjBwFlZQa1cUB0IzwVXCx66WS4uej4eXg2y7wZBnT+W5FpQUAVi+F0YQwsbDIe/EUYGsIzjWSE7fUoovWb/72AqxcymB3a/ZPNPpq0G59eK6kSMIfS7K5HdKkPBsikhXeHvXq78nfvIasJEEx2qrtRj6hjWgpImKk+F9bQnMUsQB360TcBYAOhUzqqv14dHaKFQ3FPtdzA/kDRq60Y27Vo2AIJrqJVb2Ww2cnZ+rgDarRpWmsTRHQIKld2DF2uPqsR97nuyBcsiHryrYPXg3sW+8yIWAUwj6SEHy64uxqtQlA7ngC+NK2i5S0ITSIoSbRCEOa9FMq0HtH8TsKxuIgu6H8h+THQUPD7GeqoQD7ZCQe0M5qvzQYuKAkSqDYhardiivoKa3xT+XEe9IqunB9y/B8+puU76PCH0pBrZzu+gnRds59LFT+9HzcHDuFCLIHuPSNubxQL9eWojFLvF5+ri5gnJYdmx3mAZGeXsbM22xtreYczg3lmsEsmPUWYWrUT7oE/UtgmWGnzwfcnfq+LICEDZ6L+O5+i3b8NcETPVyvCdx0qoiQj1P4aJBYrQalFYjjzscsDuF/cgbyq9TrwX8JrVGki95bVPsXsE52/KUjbNWs43G7l7C8mYDJZz5MiRI0eOHDly5MiRI8fTVCz7rmovKmVF27it2rw6fds/1bX2IVrhEyKFDhpeqMv/HI6RfCqPqjcFE0FR6ccwE2dsUS4K9YAFEOMHbfjAzoDC6nMZxWuxaJMeV5XBEW2lME9VxgAJ3AYNn1f6vWqxJyo+y0JmKDqDbQcgBKDeFFTIIAIKt8AsUXQMBcYMPCY83tWE6gPsUFnBsjN2tTYAZMN9mwMvLSom6bpOhqKQ3mwsAEjmY0UAQcAXQJO1qUrfzN9aYb7+PcJm9DPh8HgkNH7l5QtCmO5ggHk+yHQc5bDvZX/oaRdyedURYlDZS8AJ7+OVoGbi3MNSA/YhUFzPct1r0UVAf4LwpiG4R7tCgayWIdondgumZFV1pXrAapE+PmaouNV2xO0l2H6mwDY3l3B/tKGgYBV9qWMF/wCdJ4DqfSdX1we5PvTST6MMpoR3OEVFusOjxNXALRZCkUuHxuzrOMgjijLR6GJO+LeJnYR7MLNBfJ6kc9Wek6rbeW0KqoILQ1KAzK9CPb4Bm1WxrM9VoO9H90KZwVrDVPEwRvaxpeprLYjov/OCfj7r0x0OLlxNHwqEC1MgGwBPxj5UlvBgZgHIquaccnDLB5vHla2msjVlaAr9Qg8EiXXsj2Dx4IvS7EU4U7/k2PjR8EGhtELlxPqBc87UtPg+hYypTXNQK1u2wY9NP3HfCWJKVnux+nynlhPhRTamCynhG4/XYT4FGbT/3XrFlPjcDQJFud0z13rY4FQrQk0+GwrdcKIIT9VWw86RXFLsaLOQWNyvA/IEwEbpemIZYokA3zrDcW1q8OT4sVhpMq4Cf/c5qffO8W52K1QEW7E9HUs+7q0/7ZzLlELS5ovGd0uYm8LML7zZbIz5e0BUlnvy1d5T/bW+w8fPkYjo/b2VvcRcgMJ/fxvE/Glgg2HvY1mxnCNHjhw5cuTIkSNHjhw5njJYdn2tfliFihXWBu5Viw/g7rs6z9CFjtSMqW4wUWapvjl4JKsNRQLdHLRaQTNX4mk5KFdTRkiMc4IFwuKgWmGL71HaZiVNW0rZ1PSx7Por6YaSajIq/6ywGdWaVUMgATAKHOwqYPXc1WJkKP42Tz1FhoBY600t6y2K9KnHK+0vCE4VELAg3xH2EaMci4qvAehYr9d8ArZdYxv/iqo/vSdFDApL0STwi+7pn9xTuatAVLeds6BZXclQlTIChqJP+kn6cZYHA/yPVQNN3V6Jc2kRLy3epqCbMAqOFGgTKrEhbwa1RmEyBaW4FhY3HEe5utoTrr780kP5mq95DwH25RXsLvQ+AbHVjkMhLqAxwMx6s1E1LPymVyLXw8gifYC2OxTrw32N6g97+2wtzbaV9e1bcnb7nEXY+sOBns5abAsJg4YAkfAYalC0dQ/APco4dFQsAwbPQ6cWAta28KHGMQDHAUXpe23AGT6luMYWBd/oraw2BfurXnb7UV6+91A+8NI92R0GueoAl2cZvfhkyIL4zKC21pIkcSt+WrRPeZ55FJuSV7fnW5E76z1YzSiUsgKVZk0QYJTPrCQzEdTJBukw26C0p8bTfFd9Djk4dnCGQ6IdMKZhf0HHAwfLAbChj+G3PUmJRA4tQ2AnAEsZ/FxyGKH4IyxQAOtpq4F/UEVifoxakMwBcFAQ2z2wBe26YHVxtoXlRStVWetuA1jZVJWs25q/X7drPpp6TX91+seaYtmPg2J+qkS2HuC6YsXgzBomTQbEhc+VxnhOqepdKmMB1lHx0+Fq6oFsYwHSXCiEubPCQHKB6/BifGhgKHT1OCoaV/NdZadQJGPcJhCZa6yvxJoIwM4KJL9WpRZ841HSRJ2tMehb9e49MrmjCmRdUxRO690jQefqWRb6wxyjVzIUz7MUVMm2CrgTa5oFFPYtCexj0+uqaTB3caAwKMccAbe9hrY8qX+0JSK568PuB/2E9xvuFhgV0rvrhamZdceAJhJ054tbK50o6n12IhlZr/ioa4x/Ba1sL6yFllTh7hCuKbq7hLA3VdknSVHuJGDBvhOo7IkAFpw1AxhzYvG5jHaY7H1Sd6pYwg/rQijcqKAf71Hui+6W3ewSZ+zh3dfaHcMNhS/rSs42G9lC8d9o4iZHjhw5cuTIkSNHjhw5cuR4amAZ7gj6YfhIKwyqtwLg0k/t7utKywf6COuH3dMP8/oKAxnJ9vwTl4Dls1N/4UREqC83HGQ/0x6BaskVlVpqm2Eb2INIzaGLAW4UYyJZivjB1coKl1TJ6zCOrUH4CkMEWCiU0TWaSlgFhNTpYte7b8e3r1CFnqq4g1UIj6tgV1Vy9pykvRyCObigKpVOFoBUvkXdAbGCOof0tB/x4xFCJXI+V8bS7sBAN2wmhkF2u4Nc7/ZyebWTruvl6rpT6MR2QnupfQfrbs0zQVkFCEfFq9KucT5KByXzOBHQ4rr9PgAnm3XLr/S/JmAHWFOABoCpsD8+tPheVHmzHZAcMOuNCHtcDWivCa81gMU21vGFMY5kydCjKCIeg3T4nj7bgEupR62NJYLJOJZT6f2i2KS1u82aqKhMLQQMSJ3OgdDxJ5MpKGjjy+PTnED7a3SQh8sL48mep8BZt/RTHRvmVdyyH9TeC1Xmzc8JF5J4Owd9a5I4ShXXfrecb1AsG+BzJanOQz2eXy/nFq/bvW7NCsNcX0DJg4I1NFxsQI7CpE9T5W0icz2J1HsifUJie+D6aE/mhLl2Is2OAyFRK5sXMY6yUMKm5/R1lE69y0MuCp1awT2sO5hzgMemAtZyn1HBq8dC/1PjulC9a1tOUvC61I5BF92kvdJCezY49CnJmE7v2cfKTQPhhhaPo05/8wi4DYP79NXePqk83K0xzLrJEy5c531FP1VS+yP2WVhjwptQ2lfxe3f9PvUi955Q2J+A+OTGI5xf/n75QyxmGpv5BKWHpeSmd+YcOXLkyJEjR44cOXLkyJHjKYBlFC1D9ONAkMgiceFDOZRyK+nh73sEOOzpuQtVYVtVhIlQOFNNaWpEFbrpnvwjVI3cnqsfcyFiU+Wl/x7FjPT7YYDk6mi2EmZpwUJkuo0ZzrcoIgeo2J6pWriGJzL8c20rOcEBRKnlSgps68Zrp8rqdJnSGJYHoxaDWx07qaTnNnLgYygJqYwdV9J3B6o0t+fPaMEwbBsfUNjtCGcAme28uL623khZVbLdbKVtN1TPdj0UgVDQ4r7hO6soCl7NgJlUDPM3sGnQQoPDcZaOQBfFnACKRKYeVg/md70SadctFda0o4Bad1rJalIrDkhKaclRKZSDanAkYzC4xOJ/UFsPUgwH+vTeu3df3vO+V+Slly/l/37gARXAtIRAssGgKpXQBcaCQlsoV6d+pLK1LKGsK2XXTXK9H3ld46BjYVvD+7qSZ56/K295/o7cvnOLbQpQvLu6ZJu1LdSopUwdVI54PeAx1KOu3sM1zzJ0KK53kKvdTs8L9Spev27ZXvNBFeB4jD28t1WdPtvrZVpJN/dU1t+/t5OLy17u3d/Jg8uDHMZBrrvObDC00Ju2WMk2MixnBfqiB29MmqResf7VUCuBkRcESxlbBNaxiJmhtVObAWOSxvyW+JLPcYUlUT1Vnb7LwH1fq0rtQBzgB7gMF3AqaT3hgTmI5A3grj/HbGBYVBKPI1XL+EquyQmO5JTey9gfZRpWHOP0xzVfXfVMN+/yupLN2or01VhPYH8B5brOaewagN0LlMx1jV0EG1XRwhYjqK1XssLOBHr9qqI7FJALSQVLRvGrNyYsVbDuRfNrKmJZzNAsUzSrY9YIZvKhJsLWX9Z6fH5aUA5fkVzy59hxWJQS/smzzFDew0vZ5pUbydN+2UFjSCBNcsRzrIAoNbNMzkTlfPAPZ3JQkzTqfe3+yjZWsT7SU1nvBf0z9IN5H89SQTFcNKae1qKowezF7RqsQKJbDql62F0sov+zwlN4uKsCWPMjsbCoF7VzVb8qlW1uueqfferqcbM7oWfOsuifz7sVVeNHFu5jP3DHBdTPMy14mqGRGvYqsBTCeB2wY2UOiTMW/Rxx7ep9DYW3bnJAX1TRzoL3qe2r/6x4Ki/JAXNMLOjsnHSXD8YY2tH4v9pLWTFan8nWxp7MM2FzAMuaHDD/cvf1h2+6lDLOk+wPe9l3Ld+3y8yXc+TIkSNHjhw5cuTIkSPH0wTLw6j+w/DunWBDsPggqmACKlUWmzOrgMp+9jpMQTSZPHwbb4RfiXQ3buJdKHMJvY5WbMzVzqY0w4fpcRqkmgyyJCIyfy5f54rlRNGqv9Nt/woH1FsZsAnKY32szM9Yt9Gz6BuoAoujQfUc9vSbx3S8T99erR626mOa+ienSk/62AbFsrUFVY/apgADATxaG/PeaWmxkvqoNiSqfkbJqYIFnRwwqhOGQRgDVcoqofJWkDH71u95kq7vqVje7TrZ7XtuiZ/sVlnEDrjCYCAuzGEHoLleGwCkei7Td5lwS/tBfXQLAsTt2ZYWGARTuM8RdiuwOKnZhrgmsD+okvEIHt5mawEwDxDmQF65KcZHQZ9pB5/a7lqczEN3nZuvtRzlcIBv9EDFMsY95oCPccDXBFUFywqHsF4wTcGyXuCjTqtRfeqQ7VSYmiKxiG99i/1N8n5TmJ6IYJdPiUrFU8UyLW0KFHojal1Q66BUXopRk4jq08V4tq0M0YfdObolYBaq6nhdXBvMT7lMdgvEh1sAmCcuFcyqXHYfDxZTY1eYrwf7QwvNxWl6qjx24OpF91SVay0f7TvS+3Y6GJTayU6EhUpZHvk+NKdbSDBZol7hSKJQbR3a7iSTwHtzgByP6b7LDpa9cKR6ZZuFBdvgdAy5p4RfmxULxbVxh4HOO8wd7oKgYjmumVGxnCj3E4W8HzuOIR9P7kUdx7vx5JN2W7ZZqt9W2wvzP1nodOOz0zmr1DbOKd69vY+421LYCRE8oGNR1VPFr+H8xXlPBcbLsHt0yO7P9x0z6fUnIH6hmE4vITkP+8bf9Ai2/ZQ6h/gewt0Xtp69hlI8R44cOXLkyJEjR44cOXLk+AaB5b5Xz8fe/IMJR6EqBeQzENoTOAAqKzQc4eurlqFBK8kP/2Q3Kq2iOtE37FKt6zupVQHrxeT8Qz6UqmAiYZM5i3kJixC1tQKB3e6acLBan0ktgIHwEzVFm9lf0FuYai5VKELFVhbwKoXs11SEprDeQOW3UqgIZXDXX0t31YfjNU0rzz7zVtm2Gzk0Hf1n6+Io6wIK0EqaWr2UR1Ny7/fXcnV9KR3Uu7tRDkEFqzYiaBq0JVTLCkg9ADVWcrXv5P7VXo77To6HUarVSmqq10RGANTjSioUUoO6edK+g3Jzosc07hsqQ5HtsZGmLmWYVjIC1AOgoA8LU++uJqknVW0qUDLLCynV5RWAApdl9gQOstCjAP8YMVBk49wD7ECKlY4RegCvqAhsy0KeO9/I+aaVt7/tTfLOj36L3Ll7W8HzOMogB5mrWjbrVn1foW6Frm8Y6D+N49RFhP0OovU6AOCPchxG2e337L9Dt5dD11PtrKMNYwgFzQoZpoJdf7G7ln6Y5ZV7e7m86OTh5UF6AHGKo7WdY48EnXJMckCdGcAyQsGUAhxPbCjIoc8q4ZMqXB0es8Ag1YXRQkELpul8YHIl2c4eyuJZciDIFh9JoCjoxRwGrCU6xFc1+NW+s+SJg+7UbgZgza0nWOgM14hxy7kfgVeE6xF+sYXoJaztg50PbkUSky+V/sz8CGDmIGO/krE+CmyzubEA8xlq6sJ3HcAjFsUekZDAQANcNv8Z+CCfWCeorh/trbsf1Noa14vnmtKYSQeFuxziuC62k/qlK1x2FfJNgN/bPx0LKVC2q3GFLrM7quY9woMaQFkNy7kOcjeG9btVUY3nMEU0faRtzeKoM0ueArsmTMEdTkcYjd42Sx5LjBzRJlDs++4JtnPJ9sU8P84A/WW06EDzFAWTSzy0Q0yH3exjV4E78NZBpdYYCR19xA+lOGGwN7Shc+vQxMnYX+QrXMGcJGTMz1kTlabGN5DMXQ3Y0UC1sq19UOyTozO9GAu7BmU096ZogoAe8LE4qM+N0AH8H53Pbd7HhAzeuwRKeTaJQ3K9Rrwf9H0vcoSiWi2pkMhlUU9TtkPRj1REWVdSty3XmMu+kwEJtwG7NPye9X2gqtW7PEeOHDly5MiRI0eOHDly5Hh6Hsu2DV4tLRQ2eEEv96sNCih6RKqySy0eDECaIlG3j7u/rVWwdxVYYjPqtgIUoNnrVUmrz0e43yqsJipYaqxW0vWdQtpxUKsHKyz1qL2pFTtbKbwBEISqNgAMbo0/8oM67hV+u/OoNiAdKBctS1cyr6FkLqSpGqnt0RRHaVBU0FTKuK5+hMXCLP3QyeEAwIkCdgCk0BkrDFXlsKqVg5+v3mkAh4d+lN2hE+lGAe0EVMZebCjURoPtANC4KtgQwHIDbQjFMhpxPGqBv7rWAnvTrMX+2P6QIRdqoVEQWBhgM/jiYARIyrdjq5LUQWLiNwzFMqwNoGik+4LaKWjPHnndTVnIrXUjd87W8uyzt+WFF56V7dkmFg8c1YIF59ekhAIeFOsD9CGgqmt9PuGPFYd0WxW23yRH2C7QFqTnMQlvrM1pb0CQVlCVDO9oFOq7uDzIxWUn+8PA8e99Eb2UfXb4z65QLpZgORnUrt5Vvpx4Q1vqRZWRCqjcmsD9eb3YnaqsTWkc0zUL9SaLT7qNQCLZDz6urqy1P0EVTKBtY8wPtFCGGjQk6KTqPiqH3RPcVZTuV+u7DMjQHMF7MTTz2PbjYx3gdbidghWMY1+PhT4mFG5UJbEnibhG0GPZCrYpcQ5K8ZDSSn5me9N6JBYmZEE2t8Lw5JclwHiNXIhSO4tI65fgOlHIOnh+ZPEJC2L8ncFlKoNh0QLLiWAf5OpWV8rGAqh2cerb7gDVxpqPJxazoz2EvVolrZYEibYS+lIkZ0ytS15px/YEh0qVjcuqnYjfh9tV8KpocaTrPHdyxAp7Jx7ISy9rL9Sp150mBZIqfUmEHEr6g31ZsmpLeoTTmTo8kf46XGYB1nHQ9c997kPX+lrsu2e8T6LdRVT3u9o5PtJ7c5V6VPnrGofz22yzpUNfi+sZxpHdN7pljUFxfS/EvFwx2dg0tWy3Gxmw44SLMf6O9dSSAlxT8dZR6a6bHDly5MiRI0eOHDly5MiR42mBZaijiDhWNf05qXSjZPUUqUQQ6to5/bCu4Iwfkw1opS8M2+ZdaRe2xR8TGwkDWvgQTEWqQW18V5cEy4BeAIDwCL66upSq2ssAhZeBGC/yRQWZWyU4CDOvUJyc54RKb15JAyOJGWAL/prwDAZEhjoMBedK+v/C6zVYQRg8q2oHcAoHYF8BoLnf7+Xq+kr6/ijdAbYN+gEf5/Z7VzsHWI5MESoQqKpH7Tiv6A88HXqpARhgG2oqQFwCQR8AFZSOVmxRnR6OMlCpdpSuA5jAPcHnFu2J50xSTSLrDuctZA1/3ONRNm0jz949l66bCdoBvmGHoQkDVX1C8Q3VKHW3vh3e1ejmbQscXa8A3Vdyq66pVP7ot71Jnrl9Li8894zcun3Otuy7gR6+Va3b99EWqxFtrpAV/QDoiMNTIT8O0nd76Q97gkiH8LjCcYRfLRSE+B5fcf/woC1lmrTN0Z6HAdB/kqvdINf7Xq72vexQtI82MKks0hXRj9pRxMKOUel7Gg6T9XvT8of2crAWn+uKy3QLfDy+zpNg7WIAT8dz9L1V9XFS+A6KRtiyWAE8T/qEY7stgANig2GaiHHrGv298jn1UnaYCt9lzl3ARSPQWjRSvX8xd7Too/nqWoIJc9jnI4buuq31AZ/l9Vo2eGzWst6uZbPdyHq7oSoTcxWAjHMtEG+F4Y/6gRhdTNWyaT8xeWE2GKnvdYDCxwVgtlUtHseUtlbT0QCiJrJOfQsUvEZgGS1OTEVO715VtrpvL8eH+yenQSWx+TRTda2Ju/jzUY4cvqkBhi+8mpCBfzyhvYF7Am2OExtnBiN9HvrXmETxQyKrqGtXUSp89mRk3G9i7xa8L21v9U+2You2bsaJpesz2xuqYYfcZn0RiTGSAdo/3BnATvBxnSqkR3vAUscSl0xOqeezFghNkkILwTGSIKXuvkgKQ3KddSUz5c3LnELYfeNNj59h6RMUy2aflMz/0PcEy/o9i+gSeKv9jxdPdV6Or3XdcPcHzrG+fYc7YGR6Wa5kJ2dnWzk/28rZ2Tppkxw5cuTIkSNHjhw5cuTIkeMpgeWaBBYfeAGWq+id6vX7XO0V4IpDYhoBECKFckUBtKSfZ6MKTKGvIRuqKbVQV1VhS7aB2nB8hXTY/Au4RAUu1dWjXFw8IJzgawwo86vBZ14L7QmiX2vqwYwP5sWxlBYq5rmUvjOjAADUFUBWIe26lHa9liYBWwB4gKxNrT8D4LmnJVTLV/sdiyjhc37fuXoN8GZFdRnVyq5YTuAtLQxGAIWVgtB+ku7QsUAirUMAAo8lYTnboNSt28psDEwfZ+lnKH3RrqrEHadCxlnVbyhcWI9HWa8BYAs5J2A7ytmmlReeK2S/H6Wu1Dv5yGKNZmeARsRueQOvId8QtqUrGAFYLldHOasree68lbu3z+Xj3/k2ef7Zu/LWt7wgzzxzh8/t9j2L9rmtBKDIkeBf+wYgGf6zaNdBBhmHQQ67a+kOey3qZxvfqaQfR9kfrnnvtGywAmfYKs42xjGmo+wPM9v04VUvV7uDXOx6uT4MchjguWxAKyhUDao6CEuK9wVMGRhiquTWX6oK3BWlafhx/ACmjHRglDoGcNwa3jZFsxYjiwUDqSYG/LIXUc1ovsXY/q7F+yIMVHClv3C/WS2K6fegHI8F3rxA4RFjGep237qvykktlIl553JSXSyQjPFEC+xaFBnqnPayh7DQwbzfrhs52+Cxlu1mwwegGB7b8zN6ciOxU9UoVlmHIne0eljqrWOjRUmyY/wFXGPfUClPzXu0b1gAZm1rhcopWKbbuEuwE0KYnH+hpE4hr8NTHWdMoAW4qh2v+QgUIzTwmtBO3jueScsGrFmmJp6h8tYkTDrK0sSEQus5gFa0DXeSYAxwR4DCZd01YLtNDDC7op5/pZ2OHQ+JnyN2PiT3kYiWlfmrvUuqfLZSpcm48fllsNjsWqKHdOJ8ETzTNbkSpcyprtxsZ46wm8ADYFeTpCwKakkr7n5wRXNQKtvpuH6qhQ6KknrRUwfLTHjS4sI9qD15qe8znmdgkVN7H/PJrXPIesYtOrQEn0Fr281htkD+u5CqMMDcNI089+yztIqZyppJs26H+xK5e+e23L17i/PHldM5cuTIkSNHjhw5cuTIkSPH0wPLlW3JhYKVillAOv0grV6Trj7zbf6mSwuwSh96mAhbHD6fFgVLg1vkqVaOKrWAtQjJVD1G9SUepmwjcMY5qbhTz2YvnBfsCvAhPNm2bJWqAkzgFwPO8FmljanAixmervidAiCoMFmEDs8hEIXtA9oG16Yf9PG88libnYF7jDpQtwKIKIrnxaJcEQiZIYyPvSgboHjwv018SwkrTBlsKrxyVUptCk4VmK/MpxjqPEDZiT7MANXsR4Nd8IkuilHGAb7FI7lQ21QssHf7fENYfpj2kAtHRmeeom47wvGQsiHAjhJqb5GzbSPPPHMud26dyfbWWtqzhnCa1hfBpiH6sS5+57DF1JnDMPB1Q3+ghyj+juuj6tDGIQvyAdq7T7QdC2MTW8u7HmB5VFg/oE1gGQLB+k2eykkhviiINIuAExVpSJ48qlw+faKqQrWvQgFJnkNVwmk5r+UJ3BLG2slgratfIw6Lnr6unPRDhgKSJnLVY6ZX7dv6I6tzaw6H2ct5E/cvBCsAg2vRysHAJO0GHJDr7zFnAcDhTw5A1rYtFctI4tRNLTV9YfWBBATm5kLhGsTKZqthdiw+5TQpoFYNDogdXHJNoAUGVLtoEAW4xxvhcqLWXSihrUieS1PNU3sJlV2tbgUC6UUMyOmO9NEewj3YgzLdh1aS4AuqdranPvg7+ImPUHLTXFkvwxII6vVtQJiWLPBYNpsYvw4/r3cas1WL1EfiNa1FSbWXsf6qFdEjqnCCWPfZj/Nq2b460xTCnhYnTLytQ9Zlpnd+fFVyuBTsWztTpW3Xpmp+fd8I9+qrR7JbQJM5vub62u2uGtHWJhb2c/sLTdTwH75al8FHPF0jF3MopqFi2mPZPJbE80csBIjNAyg6+vDigmBZqpa2RHjv26w3VP3jgfem66ur5J5z5MiRI0eOHDly5MiRI0eOpwCWz9c1P9iWvRBEHvuBKlL3VsbWXy1kFPEXVVWJCowfw8uakHdOP4z79yn4MmUzXlNWhTRNJXUNxbICXPqqQslcl/QyVtDUUF2paklAWlMdNo1UBRTNFf2PAVgAgr2AFBjDhIJ2Ak9N2x5N5Z2+fHLlHpXJACJQqgEQA2COUja1NC3gVyPtupX1gC3Hg3RdL2V5lBLnBByv1yIV/KkbmVEsDypj0e3UADUAd1CVDdOgKlMAasBxnB8qZTBTADiaV7tdh1qEUKgJaw8UZYOAeBIpp0naVSvnTR2UflCO7wYtGgeIejiMMkwo4FfJERBtNVOR3Na9jP0sZ7f2rFFWV63cudVI9+yZfPTbnqP/cAev5yOKNqrHNm0n0H6ullXsoSpVKxN23pRyuy3luWdvycd/7Nvk1q2tPPe2u1SgzqtJrq+vFKQDSB1Xcn7LfIVN1cfvCIQBiUcW2Npf3Kd38vXlAxmHnlv3YY9wHNRPGpClHwYbh1rcb8ZNCZTQk1zuDrLbT/LSvQPb5OK6p0q5QzNXpcyDAmqJuDCxRbAwsOtwdqFMNWbmiZc03EpCvcS9kJqqdz1uep2f01XJqowHxDMQ50kJlRdHTa552AbFLfMoqupmAogF6dKtBJYkcDWmKf3VcQLXjHlgns70P7fnOZRkrTmMSxTZs0MbxKY/clmoTQ6hr8JljOtN27Kw5Pn5Vm7f2sidO+dUs283rZydQakMOLaV9WYrdbuRsmlZoE6TGcysJHDU4DCPzwyAfl9osTuFyWgPzBOoc3sWDdTn6LHUukHXJluV7OHf88ZOOsg9DVT5vLTCcHgMdf3AgnkydWpFAaf0laldg5exwlsW6DPlvp6C1Fmhvxff7HsFplCEmxIZayGtKFDgkAlCs2Lx1TYobLGG644G93UOXg6Jh7qPS7eyKOVEQW1bJbCmzKP5WJdoT7SDe9h7f5jvtqmxfWz7eHU1snqOJ8kMt4rgbhhcN56o7QzFevB8tyKV3ifzEQspCkD2bG+WusMulKIKCUqdutp20WNZ+0/XWoftsFWCnRDsoqAiRv+hr0yJ7Ts6knqG7tzh09pRPO2DbAhhLkQsnfzzA1jCS98K1BaDcNnqIWCe3n9wIdeHayZd6s0Zv5YFCs0+K88+i/l0LldXV/KB9703WPPkyJEjR44cOXLkyJEjR44cTwUsE1oRgB2l8A+/QSnlHo/+Qd/+Hlw0H6VisVife1e6j6Q/ARBBoZYX9VKlcSwUF/65d6V5X6qaLF5H+hqFcan67fTa/JrmoHimWJigTsH28QjrCQA4AIRkq7z50aKtoOYGcFd+AjjsijgDWwmcDBDDVLgKeEyl6j62ZDUKOUAy4HWKo6haG/YGZingVgWmhKRSztqH0GReST0reC8Bfayd4pZrVdfB2gAex/gKL+eqVMVdXRe0JoDyd7OuFQoOsJKI/qGulkUA+aiJhGpCoViOvrm1tC1AmXkhA8oQHpdSCOxWfHO3tiEfdtxQ3G8Y6FkNoAwvbSQwSrSzVXz0MeZqQiYxjDO6YhkghgUZ4ac8aHFGLVLpKkGDxlACO2VN1Kv0kU2MMOKIcj1nChRN7eyvN3X3ovih36NPlBO1sm/xX+g/k+vx43lnhD+ZOn8x5O2P6TXHnQOJ2atB8kXhuORpcSd/UoAzVXsmetSowgxPNA9oncOaMCnotY1kEJIn8avuOPCCfWq1kexAuGFWB1Xtot/cCxm/d9CZ3E/iEx4PdPL6cAxXLGN3gd9l2m7p9ZgVUCCIUM16Msserqg1MLvoM/M4VnBr68aj2Yq0N/WaXGVMYO02FT5ivSgqzjfxwSPQJsXB7LJtY7IjjnG99+RabD2OjZrci8Neg+36fXxhKDx5qmD2NSHAZk98LEzHY0E+G/cL5XLyPuVd/IhgNzl1HNduSXNTlieOdVWBJypvv17fzeB2S1TBJ/Lz0K6nFjlRff3I7gNPWCVPdSCPRFrXTbIqJxmxppaVtA3mEKxrtAhg3/eyP3TqV58jR44cOXLkyJEjR44cOXI8LbAMAMhCb75tu4d9wijjEcpO2AxArWrFmkJRL5SC0w+6gJpqT6HQgIXfqLTS4msM/xKYzIoFvODv3FKRbAX6UPwrgDgvsKQeoATfBg1Xdnt1VVLZrL8/Lcal0BdKZxwfgtd5HKmG7fuOfx9rVUHX9VrqslWePEN1PWrRO6jLZnj+dmQxuNZdv5erw7U0VcW2g/+y+jhXiwJ/qwqWGZP0Q0e7hmnqeW4FqHGDOIAyFNCHQycy9VLNo1TFUVa1FhLctmad4DYIAgW5Kh6xtR5+z2frluq2stdzoggfAMP1XkT2UB4f5QCgPMH2An03y27Xq8UCNIkAE+tK3v72u3K960SKUa73nbzy4IpfhwHHMzWs+WDLCgq+Wcp5luI4y+3tmbz1+Tty586Z3DprpGlWcnl1IbvDTtZ1wz7YNmdytj6zQmFqETBNB1oXlFTIVjKMvRx2O7l4+FDe8+73EDBrImIl57dLWVc1faD7oZd+HKjoU7tohYDzBKXsSrpeZHeY5HI/yr3LTrp+kstulAGQDckUAsxC2k2ryvzO1ItpYiLYG+iYOqb2BYlqNqBnKmfdbFYtTjC+4NmqRzJf5dSXwo0XTBXKoneJehr3wsOFoW1F9izZoepT7ZuU/3nCCAp69aDWseaqSy3YafYQhLe6nR72v2xHWgkot2RBMfp0m9KTCRLNdsCvm0XRHMwDfHG8A+LD53oinK+OsMAQuX2OQn2t3Lq1ke35mkkI9VnXRAotbapGinrNr6saBt8YLeb9axRO+wl9DaW/eeky01JaYUsM0YTGzVCy6gPPo4IbvcU2wVYAJJTUV1iXT6Z37IyuQk6tLNyOJY4O7QCDxxO2gOzkOPYyHS6EFULpMQRVPVS4sKTB3DNVNK6NlTJxvwW/V79vLZSI611VNX8/j1BeJ0AS/TH17Idhd2AfqSc8Cimqt7uqfScd821rfvZmB8PLMgWuKZ5dfjuhOCaBdLRhKcsmLrM8vaqo1eIn+MdEhTm9z61t0T/wMA52QXEnAPzzTVLNIoaY0y4MTx3Ko2+L7STgWiKWGDR//iDjx5qP9tY20JyB2gdhx4MWyFObIqi/8bMmDNPNBTZnbZ1HzxfFFHbmJCkOsyZC31lRQ07uOH4cYrtXvSdNg1ocBVlReNStpJJRRqjMyXmUuVdLkml3Ye85VxxPDx9u5OzltewPB3n44EFWLOfIkSNHjhw5cuTIkSNHjqcLlgEYaEUQgK5+YOWH7aBYjvrkIBgOCrO4fdvVjV7NXrfXJ2quwANcwag+yl5MLNTESnfs87XmSxzgmsO26N2q0Mc+0EcpZ1SNmiJNVa5q7bFiscJCamyXRoEwelUCVKEIoB7bFbXEWPRUhgp2IF9VwIbzlieq6Qh9UDCM56MXsAOstEnUUmQGeAIQssKGOFdTrgiJFScnmjYjiFpoSm1DShSqm0uB5eo8A8YfWZQQA2GgF7HabMDCpMTXEapgnFdhBwD8dtsQhNy+3QrsOw9dx2t3NkLw5XYOBvYAlXHNgOAbqJVhbVJBeYpt5APhIvWjx0qmSpXLLilU5asmLQBAV8eJ6mZ4K/ddJ9dXOx6jblCMrlToY/1IT2Wz6tB+jT7ELCxJv3BVLUOtDH/l3hTLqfKXCtngdezja/FlMRZTb1R9jpspJGM97eDg7ZpIJ313up03PfapwNL52HLon/pUm2LZlZ5BVanFDJEscajlthHpvaqCPha61JMpOGVbBmWyzR/6fS93JKQXq88xBW3STitLBiEZUtfwNa+saKeqmcMOBFp9aGE4BZIhDZNYXSeSYbsW+vXyhsoEBIcMTiJrTQ7B19nvvPJc6C9XLKd+1ppAcFX6SU+ZKnm0hBSsKwbCZcLXSXdK0JbCFMYKcJE4MpsFnhvw1TwW0n7H3KMViBeXjEp6eqDTEkYBs1o/YC7YTgku3OaRHNorjpnj0rMoaQ9TI7tHOMZTsHSxPgk+MdFf3sdfMg1i0UPSVjumgWF/51nMuyjsjx2Wep8bWNYr8ZOUofCie+svVMJhKiaO16lyORQJPBliNjs5PlM/5rAiJOHAm+2EtlcLjHBDJ5Hqxv16AnT35yQKZp+TWPtgV6RthB02WIkByaFo7uRw6F9FhZ0jR44cOXLkyJEjR44cOXK8QWD56nrHD6b9CBAH9jGq4hDF00Z4AusHVQWZZkJpW6WhzkQxrgBtE31VBNIKkPUVKlsGX4DFA1TKNbe9R+9Ll3KVTSlNUyiAInxS9Z4/Dd8BQkI4C4sEgEdaWhjC4jkJVtSnkrDbClupWk/o3bzC60qYxELBCodlgBiAKSip0YyqZBvHngXvAD2x1RiKUPeSraDCW+m9YJv/NM5aeA5+zP2O1wmQq0pjVbK5byoKcE1DzweUym0FoFxJJTgW7CUqXvv1ZL7X9MtVyASvZlxjAzNq3HOt6tSuLqk0rI4zyhHK6tDLQ2yLRgkvWiUf5fraAFTVyHqjsHW9XrFg2tveckf6fpDNpqKC+XCAZzOg70zlL8EuYSUgfMUW325rqRqQ2pmKcBQzRIE2+As7KCQkpFpUlbiwXoHi9Tit5LC7psr66sFDuXzlvlxfXsru6pr9zMJudUNYjPMfoPDe72mR0R06UxhCOY6EAMawqpaPq0qm4yjdOMke3sroP7T70IeCbfBhVXWhAuagJAwP+y5xS4iKZcdNyf87BbuB5zg8AvDz5+LpqkpeZFICik1qftk8cy6nFiB4FtxkAdSJ05A/KB08r+Q4jgHEo/+QQOC4hTq5QmJFz099Ll7nADlRSRp91WuNE1nbhsJzbZxgQ+CNyPmvCR+1vcED8w59WsnmrKXHOuwx8ABwxqOsao5tWGLgWlNbAe0DU1sbJFfO6yDYFeGmlDVYGOw4oBBGu/OrAc4FeIw9n/SYnLjmUuYLNAlITNW6Gbcfx4PIeJBpOMi4vyZYnrsDzwE/eIf32npHmSd4g0OwbOCyOXI9KhotKipJ0knb8tSSBWLhUcZ+oF3M/vqK46KB6htrq9mN0AzH2gs7NwC2kYzi3EywpoqBoUq2QqVMiGkCiHY29PONibRl+5kSH6pbGyMoNqqpl6jk51pgENgV1wrcvZkVPvv4D5YaCzsO3T2hr4dVyazveubzPVPtPHJs4P2DXv01CkJqQda0YKB2u/ULxwQSA1oYFMeA07Ras2AMqVUSfnZAr4X90FUxARPedxIbqACXLc+hba1zR4ts6vxBn2Gt04KN2qaeLLV3VB0rhXpg8/2WXvhH2v6M897sT/C+9ug6lCNHjhw5cuTIkSNHjhw5crxhYBlbZqmIPSqMw4dpVSwrXNZtxVYQKvVZNrEdYAmC8NmFfA6YDTg4N3MhlvoCw/6gkMr9lR1U2TnwdwARAGNXNJ8qSPHhGfYSUAFD7IfrwUdtQgxjB/yQ7wWQvKiSgWVAEhbSIyzAdcy8JsA5wGVACCre5pFelZM/2C4o1qfX5Vvoca0AzlTSsv1GGYdOBkJhU9WZp6+TBS/ABTgU7EFWK8Fmc4DlBjYepiqGHQFhENmW+dEalAMgKs2/GDB5Klcyj5Mcx1kOoMlU+wIw6uUeOqiHZznHVvdZgWNVW9HAYiPj1FL1fL2rZX8YZb8fpO9Hud4dCFCmCYrACEqaFkAQx4ZKWO0GfAs7Ve2uPE1YlBaGUzh52HfS7fdy+eChPHjlHv2Vu31HCAQ4DdiI8QnVMa6j73oW9oNVBhXXKL7Iwo7qr6ywtqAWElYgUPb1sGdg/43af2xDB9+w4oC611WTqkxMx1uodhcglxd+03HmvrbRaznea7qlP3hWu/I+PcWrkKBTmOjnVHU2lK7w/wbmNIVp6pVscwUPjn0bh7RYAFSGk0HiBx2U9GG+2/2mReqCtBJj2MFy9GJOloogVlUlvlr8QmUPKxkomN2/XOePQjzvF4eHUWGdyv2THQv2a1ehLpTICSgP89XhaJRA3xCnvRFTDbHP6TFiUBVZjV6OYydzf5Cx2+lOBdi5YGWqNbHiPvJEgVY4cRy4JYBjnX7qtY5Hql1pNXGivPVO4BoycZ1B4qvvDlyrVu1RqiMAvRV/ZLJPr5x/X5ldhivbXSGLNmViTkGqrv1mf4KsDdZubxvzO1ehdrITIXoeqRrZba5NYa+KZxqk2FkVWge4bDYXqoCOuxEsvxP/brtldFqqrQuuFTYVzFqxvfQ+Ma6C3VIAy8nOF7a9WmToe5XCdCSpPPGkc8PKfJq3vovKfZeOplLc197V/nHLgQ5H9NlS66wKZLyJ2S6gVZIc8WeFnSJxFOLtV0G6JvyGCUVikfTAMbxgZ44cOXLkyJEjR44cOXLkyPGUwDKgGz/+G+hT5ZRXvPeKaIkC0dSNxAKrUmoot+xYBFe0tzALCPwybJk2NkDbjcqghr8GH+wV9hrzkaJcxS3yBB1Q0yrSoMPzMarC6Is8Aw7ah3LCbi0rRxivHWoAAQAASURBVAtVwuUILRwSlAYa6NUMlfDqKA0K31mhPSjzvLBYWcITFbYTtWw2W1k3jcIvwyy+fdmBgyvuAIUHFMqzrde+o923UwOgQ6U5N5Ndh4JtvU/f9qw/r1wVNx9lIFwdpSlLgiIqTqtaX0egMNPTtq4UVqNP4FPb9VBcTwRH8E0+P6gCG+0F2EekohUKeWc1OqHWsnsD2mFV6xZseHyCMwFqr1Zytq2pnMYxqBCkB3VDlSAAeQl1XTlLUcIGYJRh6lRNTMh7lPsvvizXF5dycf++PHj5ZY4/9E3T1LLZrmW93VCtTQUgVLmw8QA4N+g+CKTYgPTo61KtMqz9VLHu/ZoWklRbFIxJ2iikalU+zwv6mRLXRvHS8iLaWSx9WeN2fWV3iSIznD5acKgK+qQwWPIIZSGTApcpNI0F9VTluQq2MK7iVVUjABjgOWGwsUCCeKwDJ8JdzmU2i/k6BxWz43MoRa34JpNTiY3Aog1iMUoHyJhHjT1Q8LFtWqrSK6iVT+wvFtkIQvSoHg/Li4PG40TP5Wh9gb8h48FFxaxo0kKf2nYL+4UTgJz+JT7MX5uWHwC/mqiBQnweB7bvqqh0FwVB5lHtKeCffIQ5DfpAPdHJHQ34M9lFtbCrehXgKsNXEA2QzAftFpBAGmXqe1Ui42HFQAmtoVQfoaY2n2Gq2jXTMRdqTaNrZoS3nhz05lMl70qOwdPb7t8WJr0WXZP1lQqy8TSueT6GvO1cAB6HcFxDPWNp/t/+puDWH7bImzWHW2rcbPfgFjRhTeb7gdno2HoQQHUyn9Pf+IXRLsr3JSzsZm5yt/Ckkxcg1GGb/tmEyKpi5vtgtJAhJAYQVyN1ex9WfbwXNQwWNf6+Zt2i49qM0nPkyJEjR44cOXLkyJEjR46nDZYPUMox9JM+FcvmTakF4rwolVccAhiG/QO8iQvaNgTxIsADPoCTDahfsaqck831ZANmXQG4waJPpjzDlnmKFKPvsqridAsyYBSC3s+EAg7KoDTV7dTcPk+oCzBRsECT2nkYPuKWZ1NJGjQGvGyqWmp4BfOYqiIEEIJqGl6wVYXCV5M0bSPn87nUVS1lAasHnCOqyBwqY5s3FbYDLEVGOVLOC4AewURJ4lbKuqlJGtq6pEJ5RXgMYApYZzDBQY1trwZU7g49vY2h2kZbQpGnwAjdD9/jlYyNSFPDqxmqXSjUe9VZzvC6XcmtO710fa9WEmwzKBOhoJ4EzggN+rmFf7KC6E0DMHOUw14VkFUL1Xch23Ut6zVsOdTOAL7Im7aVigAefexgGQAYauM9t5131ygMOMv7/+8H5OFLL8vFw/vy4P4rst2u5YU3P084fX7rTDbbjVztO5l6XBvAOvxko0c0LS6wdZ9geaZqT0sEwj5ElcrwV3ZbF4WtCscBpFeJBYDbWritAsWPCXyKHt7R4zXA3oWoNlFoJwr+NFLYrRxsYezKBA6xrR1KEy3RgTUUBLTt95ogcsWyQzUFzmSyUDVDmU8Vq0FtKvHjOPObgHpf7UUAuVDELbkwvlDhMuaTX7Tb3wTWaxYDUMTjnEhyUInf1Cwit2kbPtZtw77G713RvWgI8+P1InGqvE7CQTHzAMOyOhoXlUpWKIw317YGRYuT2JKvhihPbBgCWMZYgt0ClMo4J3YI9FQoEyyXuBeRqob6X1NdrvDWBB7shzTdpOshlPkAy4DrkKIqAMdahadxbYZ9DiC2QVGCZdhgdJ2MCdRWC4Uj7TFmK4a6KtTLmHXscFz0f2mF9aIOOYBO7D4IvsIA4qb41sKbbjuhOzpCItLGeZoScGMXHgGAmQp3fYRNAT5uzSuba1FSmDKqlD37qKA4DhTzeA6z1MGvKt81kam7SXzdIMzFfXqiwc+fuHyEa8NzzVc/GXRJ0iPF03E3g0+ZcH9JfiS8KwYrDE8awH7DEq7cfRPnpueDwuYBrh+uaHbLIbX38ffeHDly5MiRI0eOHDly5MiR4+mBZapVDYpSYQsANwVvyFC0yJSZ/ohKYi3+Jycf6GMxwNRCw1FDAiC8EJ3DM9sy78W+4nn1+A6x8WGamj6TAAdVmivh/DM1PsFTqelbyV0RF7f9B/jhxcvgy2nb8mnZYYBbIQ9UzqpCdFgR4YEpyEJBOYXgtKCl1axtxwdIAgg2y1YtXKhKcLpoQrVGJTnU2GjfqBqEtzGODb9c+CAPfSl9r0WaKkC5YPMBdSa2hAMQGRAyxSquqSR7W8mhGwib6ZO8qQL0JrYzeuH34lYOOFbbqo1Cs1bf0rZGskH121QRw1OaUKuQudCt9z3sTYpCBsD2XiHobg8/6kkewlv58lK6rteORrvXFX2jXUEJDXKwGBkBlqHGVAjWmeoXNhxIVgDoO2iNW+v9+0QFawMm/MbAUmp7HP2GE9uAxFPYDxtHt0FKHxPuXWtPSOdFVB8vt7jfpIROx2mqbF5eY5wj8X51zKhlBnioKvx9vqaFMVXMqj7oAaGl6shQ5CwZ76Zg1kNFqwquAehzKubVX5kJHar0FSwDJrfrhj+rDQbmVbQLWfj3cn77fapv8iMtoPJvUywnxe8cOnrlt8RuIfZz9DBmAsgHw8I9Pn04CPU0hLUllcmuIEWyS78CGqrdQ1TFKqjVdTSownHttL+Iljl+z2qdof7sqh5G37pCOfZHmiAIX21Nx1rorYLim0imQJluGluRGf7bvpCaZc/iq9pbRJoaCzhGKKtrsaPeOF7jQCP8DmNc/YNVbR9He0jUmVI5Ju48sQPP+YCnY397YT8kFbFrpIc/vq0bloyK1+XvP+mZY/LGx0RINoU5kzzTJ7zfW3hvC5Ms7ByI60K8Yp+fx0K9tAni3fPZvcsNznOnTrBECi21uCYftdkKI0eOHDly5MiRI0eOHDlyPFWwfH+3V8BBL1sobGGNMBgQjYWDgD3pe4wHt8jj59KK5EVQHKCkAU6qRw1eR7VjBJbcYmyeyoROBMumpSQ8wjnMk9Q+rE+Ai7SqqG3rOFRaWoBP1WXGlvDcXreDDz0K8I3Bz1XhGe5ZFZ304oQSuyzliGthcTyoerWIX1kO9lr9HawzAGToy2m2CwAlWtwPhdLgdwnlrH6Fhgz3QMUmLDbKFdXG42qUHqpjWH8AQol6AQO2loDK6Eq0EQEU1MnwaBbpD71cSS/FcZBtc5S6cT9V237fVLKCehFQF0UJ60JGFPbbw4MTFigTYd/m4Y6K5Du31rLdAjBZkTN2IrZkq0IY96PQUZXqZxsFgZt1S4UyzZvh7clKkAMVsMPVXuZ6oHUJioWNdSf9fkdmBrCMdrp/by9dN8jl/Xuyv7qSZl2xsFu1bqQ9X0uzaWUujjIcYZ/RSz91chg62R+g0JykO6ga/XpAQgTH7PnzMK9kMMU6fXrtEcahGx8HuOgG3PoMghmDZsGmwv6onq8+Hn2bvz5XgaIqpRnorwQGq+WGncOgWgFcxP61JAvngCc+Uu2n2adYwcww7+yuFEdpYUQq+e3+OG7owzvICt1En2PbT++FLgtVtyIRQ+9wS9K4hcjk/tQGxKi0HlcyFRMV5wo5kyJomNNIwNDeRcd6+jg/O5Nn7t6VzbqR7abl9TRYW1i8DwXWLAFAqwnrEfNydmWyF2AkYDTf8QBllVAG5bb2sQHJKA+PxzNrC4LTFc5JPW+iTvav8/LnsLvDvOGhNrZklxJlGuLwuXPfcfeEK4px2oYFQlUdrh4HIxpdwTnV9cg86T1BWQ4vdqqQUUCUPvUrVeDSTx1q1yBZN8itiTBaE9l5UfQS11xVg67pbSM192pg3GryDMVD0Y/a/+rfi7GBsar+KDETYyMiGaOwzNGComqrxDtMVOdRYc/xYutxfIfQoppM2iFBxnPq+q7K6mBKYZDWiqL64m99irYYh0nXyqudXF1dq5f7AYVLdefNkWpgm/r+nuD1A29IIgQgHeZvBLlUlhsw1x0nDrCjN7PuUEDC0Y5lCSy+h6B/OXigysct2zy3W9LEiCZJ+0ELMGK3Dg5IhXsY49aG3LmTI0eOHDly5MiRI0eOHDlyPEWwjK3YCoP1Q34s2Bc0kKZldquH5T7h6O0aq/NF5Vf6XIfK8bVRgbn8ORw5VYSmL0p+CP9823DqP0v4Z1ueH9kS7GrnRFVtcEElv1BWqmVGujU/VYymW/3dHkDbwuwNgo1tcs/2fEJsbC8nxFaFqAmpFYjbdvEiURdGMaur8cw3F5B4AnSe1GbEixGGg0aOFjybWWxKpO8n2R8GWbeVDAMUc1Ax2j26OJCQUq9BEwDqx0yrEFOjkoehuBdUwwNAB65lTzgCUwoAtQpWHzX8j48BLF9e7HkNgMu4Fxy7Wa9pOaK2AEUCNhWqASj3+DrM0g3qq4yvCpYB9WcZ4QkNGwxCukRNav2RWlucjr4AcRLVso/1OH6jatdhZVTVu/1Lst09zIc4DtIxmRbyi/PFvgsK1AiXU7CcPNFgm3qKqwWGnf9kG745JYTjB7uVhWo6bhFI5/jirGR5Om7c+oNgmeMa8M6A3anHclXR2sYfRIV8/rKwYVRDJ14DCZAj/rtBsRr1pTcAwiScTTu0i7fuyS//PgHKoU0BMw0eRk+UqKplIsOAK+ZsUJSHgRJ2CMT1wzoGOxPAldkYXhQvLYqaKJHTu0s8xMM6hevgic1v25XnWAO8SFxyDYmONloqOCzl+uQXlvRHMnZcBe+gM75dLNdaV+GmyvJHOyh6K2PHhum8k/m6vFY9y+l7lK5N9Fb23ReJqjndgRM9zP2KY5u/ep1Hv47oQx6bZjn3bn6dD4U4Tr2ugOYo0jVH4TVvMQ65MJBNpB1+ftVLzpEjR44cOXLkyJEjR44cOd4IsAxvWgQLtlnxIFeZ4fv0w7ZSIvVmpVMtAOoAGwbUetMP7bqVXRWSR6hco8BQ1XTc4gurCfgXwzcZ8DDZJsyL0edSwUXFokFWnlvVu1Bl+odmQNoa6mUogVGoyxw98RzAEHgQ46BUKvP4CmrgDevqT8AS/AyPY76+WlGxDAAGFfOp9QDCARjuQy8bkDcBn2gDeF5WaDYtWFiVFf2ZodpctyiJB0sMVYOrJUZBQLrvJ3oSo3/wO3ovmxqVNiBlJeumkqpYyQSV9EpkQF/gztdQU6+kbEqpxlqqGupkqJFhjaEwGucAyLh/CYUvwO4kbQ2v21LunLeqzF4XsqpXMvSwrzBvbai0ixVBNL5OHVTQsxyuD3K47qik211DmQlxobblYVTlNqGiFylzmATwJCu5fWctZ3fuyHNvuiMvvPU5qZtSNtuGUOV6t+e1Xl5fy6Hr5P7DS3nl4U66fpLLqxF1BqWDc8AsVMdTnUv3W5HdYaDSFvfsEO/EwSIWCZOlovhmuV80TA3QzRGOg8EEKFGtDSsTKKY5znW8+nMIugLFcq/XBCSaXYAXL+POAY5HF8pGyKhWILABGeRw2KvymIXf1Egk3IElJFjUDepm2IYA6mMXQg2Ap97TTDChXflQb+0A5ZB0gVK9WFGlr/YojSn94bsO6AoF+yQ11MhQKteVnG3Xcn6+kVu3zuT2rXPO/xYWLtgxMGPMO4g0BapfMBvT1Zyp2tjVudab7sFrySFVrk6yGnorsOdQOq20qKkGtTBxgIw/QmksNpLcU9m+HlEkb5TZvJXpG+yQ2NW6VhiVqmuzt6Gglh1eiswD/ZHxOswP3gl3XqykmGwnRokxWUbVMH3k0e7HsDYxcVbs1E/c106DsSXA/brlnFhBYY5dDF3HwoHB5sgKmXIsMjkkdn+F+v3CXgYPFP2kj7v5YAcP8aSgpBXVg2pZB7vn+LRN2Ke4l5Awofk0X8OVgAVRoyKdxzH1twJ7tE9c433t9nZWGq/Hg8c1V9iigwhcplELJrLApXW+v1/Re7myXReJBY1uZsBcgULcIL2vIfNy7XDzGHqO88V6LbSugDd22AWxkhXGebDqSeCyrSWAyjWsYawmgNcjoJ0N9pFgTtrGIH25thPnagqcc+TIkSNHjhw5cuTIkSNHjqcJll2ZrNt37UO5bgpfaP4cHnio6NWUbKkKLdURpwXN+CJAtOjRHIr0ueLXnqNbqF0JmigNHR6FAmrJ+fyYBlRSdZwXlkqVxaowtS3/iVqMXsoGSFydi2t8tU/pqbqSIM68nF31Tbh1TDyaaSHiD4Bt9UrVK9b70qJe7uUJyw9i8XhPrnymx7MVljLLAgcPynd0mznvIWzHVsChW7FFumGS3WGWth1ktwfkEjnfguMAfmiBQVe4EoxXXvhQzw1LjbGfpetG2e8HOXS9XF7s1MvZ7G73A+xQFFw6POe902u3odXI7btbadatrM/O5Oz2LTJp7u5GocK+J5jpuo4P+kL3I2H49UHB8jCpF/UAD9V55F2iBFbfAzTfpFj3oXCDqi+IkBNlaPhDorRMj5kO11TWm9i+uKdy+pJwjrRgnWdiEhVqqlgPBfl87MUhHcYgoDHV9pUmY04OG5I3qno3uwKbixxBrk52XrcoPhZtL3hOQEvzGlcbCJsvHDhWvM+8lgFDaSXjj0rtMvg8evcu1auWclqsC8s9D66mT9or9RU2Ak/4yETZDQkDHvp4g7LZv5rdBU1t54UFBtdMUnjzWlZz8tjGSf/HazZPbVd7Y65gEBuTZqE8PoWdYs+Nr6etC9qBSSsUJTT4av0Rbi9RLNObl4XdtI312BGgcn2yxJ/enxXY80QjrxP2C6Uc4aWChEXYuaElJhe2I1FGG3YJaP+5rtbfD4IZvtll2CuSMU+lcgDZpgRnYlEBcTjEiXJcd23Q5yl6Gz8yunRdQ/+oNVI6vnwYuLrbfdGXpzpVJAfVf9hZcjLgkp0TcU1PCg+GnTC6m8Xfh7Ce8e/JLpSkPmiyVsUbvSkhmiNHjhw5cuTIkSNHjhw5crxhYDl82sbW+WQruOpI4+ZmVcsaPAkFyLSIEL5Sq+W2psnjdMeyeidH0KIex1EIGirzJQRM/SpVfQzGQlUj+YJepZJSFyka7OLWZ0C2CJSCrcXqyA/rCrEK6fuOSr4ahZ5qqAArFhaDFYP6KAOmmudn4tvhdhgEL2gL96M1X1L8Vj1ES/VpLkqqM6EMbmooluGfPKsvrxX3guIZbUq/XIr7jlIdj9IGDuSqVoUchbgPdBmUbQCrIMdoi3XTSlvP0ha19KtRVkf1V6Ulhoh0UPHtcexBmupStuuK3p1QLjf10ZTORyka6yMoYvujXNy/on/pvXs7ub4epDvA7xhWFUcqiY+Jz/YAtSDA9Ar3Cqg8S1ut2AbP3NrIZrOWt739rfL8C3fl/O5Wbt0+k3HsZXd1IWM/yIOX79Mn9WJ3kF03yL2H1/LSvWvpx6NcHVT1N82AR6tQnGuYR+mngUX8AKPV29RtMWKyJGIgHUvOgFRJvFQkB/Qb6v85mEuAv/0cxgkU2YBgPsbdfuAkGcOfw+ETAO3K6ATUQT26hKl6HLWxwflnOfSzzZPGkhCqSvXprtvrVbVKsEyRLpIJsZAgLVkmT1boPamtDM6l6mioOOEXrr7TqnYN/tDW/21TUaUMtfLt2+dUKp9tt7Jdb6SozPMcc3RCwkKXLgJbU2ZHhwYrIOcevGyzU3n5ylStkPUbdQO0xY6MSNs8c2BgGC8xY3HzW7eLcCOFBVw+zr3Mw56qXpkOZomBe8B2gGUSSvvR1riq0bnHCYGdDSL9HnMFkFjBrCrbC1lhPqMtbI3S6WQWN3Y8egSzz2qp2q0U5ShS1Oy3smlkZpJAPelp7dM0PCZTANgJgeRNIfS1hj/7cTSXZa61UHdDda2KZaqyTW2MpkRRURYIdXjLpo5zKSRLeM3WfN4/1l/uTX4cAaw1+aaZiQQ8W6JmkXQJ4FrlujpfTPkfDo/n6PUVZS1ts5au7bkGB9ukUBjWgbInRD0vop7VM1TkNsTCbh5LtmE+8IH3jqPuTlCbGR20DtxjEUi/d4yXmLQJRXELf6/wGgbqqa6+5VCQW3IkjvZEeO8WGlaY89QqJ0eOHDly5MiRI0eOHDly5HjjwbJ9IE18LxER2kTArFvqVdEY8YCpHU9FfonXsh/wUfFUBHBUgrkvsauHE8UxtjDPKSBmcakIll19Ga6VAESVdrw/fmh3f1oFbaq0w/bwUfqukwF+wNyOjw/3EQZz2z/VtgAVY/g0HwSRScEmXKNCTFOoAhRZcSUAa7XCqGhtAbXmRHBt7Y3X0f4DkAQKXIXjwSYkcFBrMyvQBmsNFjwz0D4DcEz4W03o3JSjFgwkNPLiWaoRHMDGCJTwnL0cNpVszyoZpkpubRpZN4AuCqAAUKAIHodRHj5AIaxB3vfBC3l4cZBxED7UA9QgGZSotC/RsaSK1pkFDKGchD6x3bZyfmsjzz7/jLz5LW+SelNKu63ksDvSYqHb7+XhvQeyv97Lw+tBdt0oD64P8vCqI7DewYnAiayDZViRTIN0Y8frxXECxA2KWtuj7+MwAHtXfSb+xAZg1bjE2j74iOvrT7XNQflO+wZ7jl2Dgto4L06nY1DBJ56rYbu8KZLdA3sB4YyDrpBUgF/2sZQjijgSRCq0DGAZGtagSFY4xu30IbdkiQ0byzGZpG2jCRQ7DvsYCDMCaLp+mIoZdhfb9Vq2m40+thvZtC2TN+acQLCMMYa5p8mkeM8RLls2pDhGYJe2vIE1tV3wjQt44aBgmRA0FjTj1znOiSB9NfCvlhf4I+a8K5VxDTheR2uN1dTbeWuDqzquVNSdKKfRB9x6gO0IKGhZ0LUBvuaIsrEdBWgEwmWAZS+Ap/3B9Yd2Q1VQi2sCp5KybvX3DkFtcPKq2bYoiFjxFmtXPKP4m/FX7QcH6RgTuGddR12VHZKCdFBROEpDdlfoJo/gzG/AnOt5GPapKlgTanqQMlEtezIhmSZJsTxXM7PL4gwzOxMvmKnJLbRLVTdcJ3WXR/TH9/cDTyylZ9CdJ2aj49dqF5WC5ThPDDrbsPHpy+MWS/VyuJVkfWAfczcH3nd0vqpi2a5z9vdLVWmniVyH8G6j8WrrS44cOXLkyJEjR44cOXLkyPGGgmVVNynI9OJQGv5hG5F8dE9Ulm75YM4LBuf0wzDhUygCB6VVqgQEnFBY4YXFFBpFRWhQgxpkoYrP9v3S+iFVH9pDt+UbFOQ3U1Czqe2GQwXbTm7XIikYNisLVykDxiiwiSpBqiVZCEuCepM+mrM/VLGsSmunfQqb4H2sfrOFbJpK5kFtAAAjRhS1g5et+duqeFuB1zDNUro36KqQHlC1n6Vu4a8KIgdlHrbFFzL0I4GINTaL3sESg7AcdgQTvFijpzXYdTfOcrEfpZtmqV/ZSduUcrnueJ0trpdq11km+CUPk1xewvZilAOK8NGvVdWGBCMrBV9VAyivqkmAqIpqPC3gd74FaGzlhbc8K7fht3v3TDZnDUEVQD9UxlcXe9lf7+T+g53srvfyYDfKrpvk8tDL1WGScRY5QFiJHiGQVAsRDDW2JQr6ma9tsI0IA8wk4EEmHKGpj8eZ6N2gWFph0hDUMpeSqgMtacExaXQvfV5KodWDYmEboK+3awqKzaRgWbCT0edivHlyBw8MfSQkMO5ZxK9w5XyRFDFL/C0ICW1OeaLE7yTYHXhROh/P7nGs/hcpdCYQw3hUp1/t87ahB/N6vZZ1u1aVLDMWBpaB7Wndov60AezbfKRSU8mztasqPkNDhjXD7Gce6ROjcIC7OPaE0o64HW07md3DF37KTrLduxeKXVMsozOZRcHAg6GtgVRCTm9DaxjzE/eif7Dj4GxAW+NUTSX1dss5UrWqLF+VarVAOGj37WBT7X0DNo1e8Sim2bTRP5rq9ZHztWyQdFK/ZnzllU362uAnzT6kETitUzxJmLa/5gncisI9hFWp7gpd2uc4pdaDJuMXBR29owywm4yZaxVOb8kEHsLU0AsrEYzl0t9HzCLkZPbFn9JxkSaJ4lzXeY3btuP//+y9BbhsW3bVP6v23lV17n3vdXcMCRACwR0CQf+E4Bbc3d3dXYK7SwjkCxr4kASCBXe+YME1kKTz5Mo5p2Rb/b8xxpxrrV3n3Pdup9/tJN1rJvXOPaeqtiyrrt8ca8x1vDUyfUq+Rd6BXet9ETs+uO5zh4mOz90TnlTUZ0kUdyzzQ+Fvna84+hxzAL7ZUCwzAVckZ1XUVXN3is/O+KTm4EB7YdZpreZnR40aNWrUqFGjRo0aNWrUqPEiwXIrquNfWJNWU4rTrM/0bbVZUcxt8aGeLAFUfNFNX5TXtj6DFMzcmh2whN7M/iU5qfLwBX8BBVTkCEXCWMQKqi+/bllwOlSm36fwhb7U43kUn4IiD8XvrADLUC0KYoVCjdCbqk1cj7amAyo3KP7UdNo+HluoWWRKKtCwVcCXfPraslCg7APgCSy440WbCFcFlrfdyq42jT3cbmyGZBhwBurF8Ux7B4JCQFIv6oXjn0bYDvg2+DVAsJSEm6vO5hb72Vtbdx3h22nfW38a5Ke6HgmCm66xdjrbpoFK2ov3eRE/AFr6JLNg4Mr2/WBtu7aHHYDyyh5uO3tpB9sOUctxnO3JdU+bif3xbKdRxfjk5wwP3Y38k3dqZ5TSg+My1dObzq52W3vnO1+2hw+v7MO/wpe2d7zjZfugD32HvfTKlR2PRz32R3v8xrXd3uzt3a9e2+3t0R7tR7vtZzuOk+2hTD7D41njYLNRcUSB5TMBOe5ndtiPwL0lYhSxUOUHVPUiloRuguUZ+5bAsJRVplJoDoalcp+5RV/9GfOHuHnBuLJJQGzwT0AJlM3hMzf9J8ArBT4C/cEaf2fh5wDrAIewUuH5UmFIKOKReBB80nkzlS7wmxi2J2A0bwACXYUdzDssL8IXmJa2a4FjYrnJNtuOQPnq6soeXj2gDQbGCAvBpTp7UGZKURoSfqpliz7BCUKtm5s+q9WjQZlooad6xoyy/8AgkGUHio5GZkVrQFgWuD+5ZqBnjvBT9hVMho29GYru8c0+NmLbgZ9Nal6pr+lr7p7MLOYIiNusbL2FinbDe95e7QTMAbyxZqDIHovmzfxJFTiKJWJMiWgmHwX063YrQN3BXme1stNwoK0O/Zdhq8BCqYDXUA77mHL4HAVbudZCJV2o36nITQp5v6cGvjixEGPt8s8H9iGpsMugI3Hn48Tl0UgK6PyylsD7vaqhTD+4o0FF/NJGGleqUznNjiyKUdImwpOevrUjM+JImvgukiKPpJyIJ154/15Y0RNCBLRs6tW9YJlJRELeANfawQNLjLDICH9mTbFy3GYrnDJH06AQLXe2wI/crY3cjiaU0bK8ke1NKKJlm9EoQey2VUyqVbhco0aNGjVq1KhRo0aNGjVeJFgGAOSXZ//iH196C9FwVi1fqJXDRoOvKQtBlTvUk1gre2EsCwqFz2uo8PyvAL+xjdqVmWlzeWwzL7YzZ8Wyq4fxZXxCUbHs+RpF+EINKhuPcsu9F9nDgyA7ChTqmLLlKFqk9OtNCj5txs535z/9JWBpYD3wpe06FTQL/2YCAVfcls4kKuYniAQ1W2znBzjA3/FoqBjUczzWeWVDP9g091QVU1VtKAQoT1sUSmNdQQfjUluqaOOxN2tGADRADvk8z6Peu2kBLQC5GuvwO15X2ELgvrYdoPzarq6glBSMgkoTitXtZkNP5Xe96xWC5YcvP7SrB1cEX7iCcUQRwKPt90e7uT3YzS18lUfb95MdBjxmFgyEghs8anLIiTZgacKwE/Et6kuFvQ9Jv+/SMzk1dn51keLIW+8vYwFho6NDnh6+qtpXn4/ndhsEXmEjk/b0F37KaNcFlMse37GVP43bKGLHomYxAgXmSsFkUn+GKjk1io91AvXStyZAaUzFrIqNVgqsPheWOjEvGyr/806BPDHycWL+cuyG93nMp2j+MpGUFphyll0mDdyKJvsRLFTiYSRRdmm0C1ckzglXKc+Avfg5FcplnTcr2ZdtJpd0Vyqny1TagKreptNOCuBLLzDI+3XriXkcbQ4LF/wNaxKUxitARrzO/e65bmYf5ugNKXFjLVPxt5QL8X7OVis5tcE1N9mQSA1tZ+084L0ygZCV//pXWXAxEiuxRrrimaeXh3US4vs6TCW02+fk97vOucwkZu8WzxIUZ40if0U3l+MgCpsm2+xUADY+D2Lulbt2ZO2Tdo7wtBks57Ef63TeZZIsnNI8y+2TPsMIxDWW5D3v8zt3SylzvichpgQoWzasbhy5p8/vu/5TNWrUqFGjRo0aNWrUqFGjxtsHlruu45d1FptzNSU3ACfmlgFMQjd8fQAqqTq1HT7Lr/iF3F1XpbQUMAt1c8BewRTZFgC4ns9QJPoX7Q5WFAIOKuzkSsQZOMbtCZLMU1faU8k72+k40G8XhcOg/GIxpPAs5bUChEHVTLSqHd4oJNbBr7SxtoMCd8UCZbiJYRhtAOxxX1mozwgqEgRsEpjNBafiIWVcsz5b155tuzG72q3tpQcbm4aRfdC0nU2Awc6ueATfig1F3AmQGNukW8Bo7MaHUniy/TDZLVTPuO7thh7O6/WJ2/Rvbm7s8dNrm1edTesrHhttAVVbfx6krLPZTtOgAnOuqAPAxfV3BF9nu2rXtoOCedfZB7/zgdTMr1yxn7pda/0AZeVIGIbj43Vd19jLL8PyoLGW2/Eb28IGgR67D+xDPvRDqGL90A/9EALn9QoK78lubm/tC77gDXvttcf2fz73Nbu9PdmrT452OI325DTw2qSydgDqXgrTuZeinWr4SV7LDpcjSgAU6twSBWrAhz3EJWj2gl9hfeyWEXyGzNeBGAv1hT9x9gsPcBcYL+aRRPzy+Q01s0+A5M0q+FXYuDiciyJdGNuhXNTbMb98V4EPJpWSzNArfgYew/yEmhnWEFCtUgHJp92jnLceprZRTFMK6oBumCtUoPpdwp6kazBmkWAoIKyvHSXc13yXYhnnj0J1C8DN7nbvjGQxkwvFuRac9yKlcYBgtzlw0Mk1JIHVUF/nwmp6D9oCqmZ41Byl7sWYg0+NulIAvIFCuMiJcYdCWvVyr7PDocJ1r+UG69hoZ87VM1XQtDo4nahQHvsTC1cGIMb6BgUyzzMpcUSczFvAnPb+mdw3eY31Rj7n67YrLCGQeBkIr10Dn+EnEl9tyzWR50Ifwr+9wfvVhgKtrpgPu5eYZ1TBa04Kcp5tGrQer8+wSPIxzIU+tPlnFSnkfeBjy/s9FdELGg7feFcrJ64dBfhCoS61OryIdXwvpoqCowMKeY5e0FX34dL0NLuT+N17MGrIslDlOHHohV1MWO6wp73vwwYD5wult+xE3Lu99HWOz0u/AKyX8CJv6YPuUPpy5Qo1vY8HtCn8mLlfZAWv/jX08CmTtPbP4Bo1atSoUaNGjRo1atSoUeMFeiyHY2cUL1votfjfpHlyxWK84g57W0TehrtUbDn4CnWWH+lSa5aLKl3GUg2cVGcBiByaJXB2oULNqj1Btij4FDYFC7UZEJnbUhBYpkMkAhFy01KbnBV7fg+pxRwk+u70xMgCa2n79EJIKOgUvrd+vS685n8APWBJ0Y1e9K9ov9mByIzt8w1gR1lwLg7vStO4lfJeuRkcLENeqG0HdTPQz0rgfb2y7SxoSDHnaLZpWyqVAZh3+Algsu0IF7e7re2udvYA/soPdvTdjaKD49gTeB0OJ7u9Pdj+cLTDcbBjjyJ8UClDmQ1vaG09D3sLwkn4/oZfsgPDS0V96q6LEXu/nq9UMt//ijS2vU8XryzU+s8b4WObriBlcqKwYEApnakUIy5GXzE2dNziWpPKvrzQkHs6U3VYmUXTDrJ9JJeq44DDJWhPWmlPHkG5ToU8i6gV6tDi4iNBtCimVqqT3Xon/60ojJalmUUfZN/0ANC6RU+SlRLuRaOnG/MHBjXA6OA/g/7n62BCLYB7pAzKBTLZn1x0brDmUIk7BD8jeQUwOTrEhke5hMr6TwlyC19rWW14IdMYfCnvlseqPJv9XElrXo6FZTPE2BN0LVTmhSo92kIezEWyLy38hY9+qMgj+cbXRgIhK695jhhX8b74fEoL5IWSt7ye2DUQyYaFCjgN7stZmD8N0tiOegFIWAEIF23pbZ3SM7HjoDhSufMkPhfS7pli0N0zgov/5n5aWNXwlmWPVFo7qchiJMkubrFGjRo1atSoUaNGjRo1atR4O8FygCJ6ECeV5SVMLoGkQAPUcFIuh4WEKGlsEQ5VKZVdrtwKsArIuNm4upZcYWVrqJNh0+A+wikSDcuAIuwupB1GsThXyNEaQZ6dYicqLMVrpHJtFHuJbfi8bik1YStxGga7PewJXh4+2FozdXaGAtoaKZed5UBFHaBK8ANXUSjwoL5sOqqNm6aXem0a6K/aNg9st4WiF7h2tOk82DD3Nkw9zw+QOrHKlatQCWLglQxV68o269m2zUyLgbGVp/Dnv/bUXjr09o5XXrbdbsN72m5ba+F13GxpDXs49jaeVzaiTaisE6AlhMeNpW3i2YcD189d6tPKJsCQyWyPdujW9qEPtizw99JLHdtr5X66XdvZg6sHssJ4sGHhNhQWg/Jxu7uy3dUD22629vAlXFtj43i0YTjaG288kp/yu1+zz/3cL7Cn13t77cnBjv1oT4+D9eNsJ/ShF+gb6HcbykOMX/mRRh8lX9KAbglZZXKWHcEvcXu8Om0qd+tW9EVOFATspzVIANwE/GAfEIrYXOjxDjyOQR7b+8NywCFythzO847uIvRAz4pY8WaHiRI0uie5W6ukk3mRLxb909yF6hF/HlU1Ms17+mRvWs5j2I+cQx3tNyOP9QzZpNim2QOvDwrMBzt5au92WyrTMT4w//XYCJauvfBbqDH9UrWrITIwEObCa913HTD54dD2outgBa0CoaNUrsErKT+Vd3PYPcDLV42sn2w7Fv8cbT7eEiqf+1v5HEO5i2uwbO/iTgSytwhHCwLIwrEj4HZaw7TIzigCeNwTJEOhjGuFwje8eelfTIWwOpUKdR7TFctuHaL2h1pc1gpSrutcTDa50lnXN9l4PPB10gXHzo9of3g4w0cd1wH7Gy+W6Ou62keeyyxuGpmysPegujnAttZCiYe9kCTBuXs2e6HG9WbnKvQtP7ro18xr5mKbLJrE4FXoT+3raUH8zZNf7D0qo6VqF3Tt6N+Ncd76I8ZDEkS713MUt6R1EJXbsCqZbYIFCVxCvB35uZZyT/FZqEKLYcODJFxi2P5ZC0U43+IJRQ3jwr6JCSbfreD+0PJTVhFW7JzB51VAfyZsVg3viSp/zFPOayVCKleuUaNGjRo1atSoUaNGjRovFCwzLoRml1qqgA9JHFio5bJ3amgaXQXrCtyk7Cq+4RIgww+48MSUovGumjbUakvlaVZDh4ozMZxQHS/UmaFGFmAklEYNMhZryq/HFuZhGOh7PKB4Ft/auRA2tsCHejSuJZ9bCsQA1mtadhASUG7ovq0E6yrAJ020CgbClgJWEIB7eRN9nCEAs1tq0JJUwBJFCvfHniC392tnG6M4GEFZw/OOg6Ds2YtqJTuE4iwBSuUz69u8cQe+C5vFBKkaPrMQ3qaT/QfBOiDSaqatx9UO0Hhtu11HCN+5RcdmhyJugOqAilKsT9PA7em3Nwd78uTGnjy5tcdPbuxm39uhH+00TFQr01PZr0l2I/5gt2QLiPAiTt7Kl2QlybKfJUYOu4esTC/hZYbK2eZCrNBfHaLWxnWjochdqFgvRk74L4dEs5iIpYVvMNe7+visc0x7DAIkegImazyzH3mAblkJ6FXqT1c/AvgicVMkoGK3QWzzT0rKhSLU5/kKCZaGvuB8QJ0e6mX3XpYjgZS4ydrmskcSaL/wMb4HKqvfHHRz/ZlshUQNOyJD3aQQLxW36fShIu6VEIJNxTzmQqReyDT1UYLhd6+xvK7U12EhBEAIWwrYtvQ4x2zTqDEtJaoXkQt7lawbd69mL3JH0Bi2FlFo0ter6Cdea1YsB5jX2htruDyQdZmRNBHgD32zciX580D3r+vI2zAu+k6oXn3myR6uTU0UI3UwfW7tjI+u8yg/a7c+yurmWPeRBMO5lqA+Z4wK5bMPkvR/yU7G7WViwoY/fqHCz59deQcEmw2fHck+xAvSph0McdqsOY5xFj7VhcB+IexOt7EczkkxrUROUZDP7VgisVveW57zFS3XqFGjRo0aNWrUqFGjRo0XCJapfiIEzh4MeVu9fwl2chSwI38bzsBJNquqXI8XAJCGrQTxgAMYbo+H0o1SP2OxNqgg4cULwJEAJxW07u8aF+vnpFcnoBdhdH5A8cdt5PS4HKXic/pDJdkY29f9S3gCDVKnHg6jnfaP7MHV1tp1Q1X17oFbPRTgIkBNwD75m7pKknQgAyvcc3Ne2xbK3XZtV9sr222v+LrTsdejH1hoD6rp8PUUOBNEicJsPB8VoWfrXHnXU4080L/55vpk5/Fs27ZVm1IZvqXS7Xw+CQylNpZ3aAnhCd4L7122s62sB1Qez2ans7VPZjudOnvXg876q42945WNXe06qpe3G4BEnRPngV2GoIdUm9jmPwKgsZ8EaW73J+v7wT7n/36Bvf7Gtb3x+Km99voTQuXHtz3H0WGA9YXZiHFKBZ8KDlKkCBVk4eGN8/vIzD8SW1Efkp0tvQkKiqvjhUVMvDfGZfpTsc0/WzgEx8qeyFR0em0y5SaKZEt6vy2tCthcM/1iab9QXF6Gy5hXy230eeO8xhH+A6W52J5La8MrGmpYJDEwTgjworilQz+o7umxiySL/GXDcYRe3An2uqWN+zljxmEsN6uV7TZStCLJAOXzZtOlR7fpqKg/Q1m5mnxOu9LUkzDnAqiG/4u4J3YkYH2YuJakQm5ebC/ALdc0vM7nusCv4Glq/6Jgn8jhYPN8snk42rh/6jYYo9TGkKUjceMqZPJlKmiLYnI8J4qGau6rEB+gMSDyiYB6HHomVKioHuSxDLUqCy9SFY31cGtth6QWxjYgqjzgtd6iZwRFed3jYOv+oLUxjQqpr2Nya5eG2oc+voCjtM9xiwcmDNzKhLeJxF/jvtc4nx9jxn1jh0JZ6LWwLUkJGLU3CoqyWeDyXeyKUBfRCNiV+gDLSAgx4+bKW51LwLrzPiyKTgagL4lsorZr+r5PQ2/D8WSn496Op731/cEGeGZjN0mb/YzZDikpoM+DOAX6ZrBBYNgXeyX2XBnuoDm8lXkZTJ7oPuWFHgpz34ngxQkJyO9ZQ8IiHH/Az3FSodLZkyTYpaPimDFvCmidokLlGjVq1KhRo0aNGjVq1KjxgsEyAF1SahXFyJYKvmWEUi1rOrPHZADK5HNc4Cx9EdYXbCgZCVTG0dYdtvECYDjbSduJcW1epCuf2aGftoZLlRt+zYJbskAA4hL0EUCCpUOhXqPHJdyCsyXCqT/Z6faaoPPhQwDgydbtYN2mTcCwVIAlxV76S9aIhScuQM8ZSl5YQbSNbbqtdd2WcOk4HFlQamRhQG13ls4wlHhL7XLgMdmHqHAYtqz3w8rak9nh0PO57iEK/KGwE2D2xtZNgLToUykQaRviiYBQgSc1JZMBfrdQUuPeexThQgGsya5vdgSX73h5Y9uuo+XBgwfYbg5lahduEAzlBwQCUaxwHgEfsbV7sjce3dCm4/M+7zV79bUn9nR/tMfXe1pf3Byh4jYWHZTa2qFyYa2irqUGMsH9IjOyTExcjuELz5fAyvr/+D0rk4PxxngvwU1yCU/zYqlOpHdxKn4mW4nUz8VRcH/rKPrlQDDrVEOdn4vepfFbKLHDh1UFyJCYwNNe5C7gKotPqjhjKiqGJIjbn2iOoJCeK+AdeoWjBOdCAmU54QI7ASSLBngDT5tUfA5qZdqiRJEyFIVDobi4GxzHAVl41S7bJxl+uLoz+5/nAm+R2ImiabL2SF7BBYBPUPmiyB52AJzHg83DwcYTrDBGrTPJh7pUPuN3jEx2rI4Cm4cka5U3xjygWN5oUw/bi0FwczjmLRV8o8A/7YGazta7K2thEUGi7+1D5wg0KJZ3JNJc2Yt3DidvobBRiKZy6MpGUJKKCYSzrHgAd6MMJtdDLgjqC4wFWPqgAGAKJOxGXx+8n5IqPyT7hZ4e9iX8CXiNtnLVtOZBGkxZZRx/p09LNDfWeRRJLMeDIHfux3ImuskHEitcX9HmJxvwGJGs6jn+8NmhIVJqe5NmPV0jrUEMyQUdu0wcZSG6bDCUdCyKaIYHd4DlNNd0phl9wVvNbRaJpxhGaArmHejDIcifCiESYF/UK0je2zVq1KhRo0aNGjVq1KhRo8aLViyPo/7h/sfpC26u50XQlHYkp6/eDuACqCSulfYCpzfx6zcsIOgF6SorB1GAhGt+YXagEcpm9ygNNbL8YnXepGrEF/5h0Bf8GVuo3QOUStEgsK5eDLYCtbRvwV4DWNjadtvWrjatzVdbG+kdDDuHl6zrNm4lke87FwFzQJDgjfxNAzCAYcPaQp6nZh3Uw7CEAFBrOqo6hwHeqqNNgMrwco6LXBoXpPZO4OMcyu+V9Svfnj+vrD+N1kI5+mAr4AC43Kp4mkSP7oVdlkV0kF9u944nYNGhyJ68UEjj3I+uj3bqJ+s2suB46eHWXu63VEoDotM7FZ4dq5WNo6tHMSxX8FCF/3BHmP7a60/tcDzZo6d7uz2Ndnsc+RPqbcBlNAk4lphekSTIe84X467EWveiFX9BxozZ4kF/LVTJxTb2/MNtFFy1m+ZMen9+fUCzZNVSTpRLIXXoLFNRS/e1DVCcXuM+5A6wcx00JQeQKJIvsyYwbEZgxcIeZNd7UqXYHSBfV/foDbWlKkZmO5qAsoUzDXF+JCziOSosJxuxOwDu5G6DAZgMqKwuizvHfWRYGN67snF2b2FKy91mIVTBeRHyORkwbuklgHVCSl2HlklVW3rZhC2FfIhhfwG1MiwqdDqp8AUSpRIW0IZ/M8x2objO42jtxemUIpLtxNgfpUSnLzj6A22yZR+gv9kOVAhjrYPav6Fied3BgzrGE2mrj73IcPh5kGOCtRCVrwHnQ3HrgH0d1h/hB322psX6puviGkSfZreS4GtiZ0ekEzwx1/ia6EpcieOl4JaiXMpjXYcnWLDu8rU+mZmQcLV5kczibhImHTtXDWtHjXavYCgoGZjXRyUXignqnzno0tH648FOh73t97d22N9yJ4uO68VZ+e/COqmwbsJnE9ahJqxxzvDz19xJwzbtICjX1OUuG32eyfopEjHyKscODJoh+a6h2GkDj+cVcxb4+zBCDY1EXPD2bHkRHtixA4efqVB7e7HDS2uNGjVq1KhRo0aNGjVq1KhR420Fy/ASjjfIAiArc0NXWcpyQ4OcPDYlw8yWCoWiLDu76tiwaMA2bBXoE/wA2KaCzqEFjkd1JIGy4KnAKCCR40C8/xxfxCep66B+hNpT1aVUT4r34tYAVEFrq3gDLR7gJv+9tocPXrbt5gEhy8pgzbGiCldf2FsVnQpA4vfGe4rd0L7FWwW0BFHhCTzHdu71irYaV7srKpbbZmN2PtrpONjpNMoGY4B60OFK1pv6OZNpQkLN6K/NygSW8duEOmAD1ZLzOx9KfdrKYgR+xvBEJqSjB7N0jRIzFn7Fl+CTZANWAjRFIOA9Ylv2ONjnv763bbe2cRrs8ZPO3vnKzt71jisCrq5BAcGVq11XvD+AK4CRgaCzsXW7I1h/96tv2PHY2/VptOMw2/W+t6f7gTYdg7OkaJWwjZBtS4bKulYHwulvSxVj/GUJgQJURrJC/w54nfxZL7aY09PWExg+CcJMOYM9T0C4flRn4pb50vt7mbBRgTb3HidQzeA4oD/BcosCcu6fW6D08GCl8jlgto+pdtUq2RLzJwr7MfeCfpZaWsmX1hrdgAPKnOBIHLb0Vi5RPmH2ZIO/j0rlrrXtVhYYEpEHDFxluEy1MywkkHtA4iG3gU6Ov+WKeGnXREHw8iwJHhuKaiQ0UGzPE00Ba/n/SgrBAgOTCFAZymIojKWibghgBQYdggL+zb2r56MDdSFMnEGtPcM7/KC1oB98KDk8XnXWABy3jTXbztcv/eTMDu9kt6CIRsfuDkLL8Bn2By4LbRzjIxIKpYqWHu2sL5jtKfCec3u20/7AHRRokwkAGvYoW4wX90AudPpsA9h1hEmPJzSYQHMFvCB+l4ubMqGlNRe2J8mHmmu6+6GzQGkk61TAj/Yh2K5AAIv34do8QZKKcuaxUKQ/+C/YYJxuru1w89Sunz6225sbG8c+A+tIBkaRQ/+MQJeywOmEz6eZ9juy1lbbcBXOttWeYFXCNAB7ntMah5Ewis82JFuYKoBViydFJoyZuWFCEq2A5CSapu8nB8voTzUdPkPDw7y0o+JxUCAx7WQolsEaNWrUqFGjRo0aNWrUqFHjhRTv87hTCOsiEnBLSt1LhXJ6VbYq8L/h6zh9fxv4rkqxLHWiF/EKKOTQJ56/fMS5BBN8a7GfD2AgAKl8QYWYwvcy7hHAGtwKIAAPwK8N7TiwTRw/V9bRuzXJnuVPGvgqQdise5Vvq2CygIQgs+CH/KMFyDMmjntJftIFJ9PhXTW6QPTenmuzrsF1rgTrG8D0JWgIb9vkX01FqXxCBTnQDwFpyxOn1lpckZSJegwA5uPK9if3k21659aNbZrRFdO4FsETqFEB26G+E1iGD/dk++Nox36yfjwTXHPbNyFn2dYFMqIUMSjK8oI17lQM8f6iVSVWDtgcthWXUDkX7UvjzhX6MT4FjfIW9OUkcusMVzeXl1p6xS5Vy7nwVppSBaTKxSwLri4pt79C3rABbKPXaH0Se+tprVE+I0Ul0xccCl60jd7V7qfsPrGEq1TNRiFLfz3Gb6H4TkUVYweBF+vkToGkFi5+xmoSvscE5oDzkwA5VNhlxTNv59z2i0ZM/bFQLJOqhr1CmNz6uCIp90J/s/yekUw5A/aSaUtZq3na83WAfwnkFvxb60DY+Og61u77DWAfymcm0JqYI6GGjnbQnOQOBr8+tjnUtmwjzS9JWmWHAQjrs97Pla2f2Rfu7as+079zXsbtSIoNEwsFbgzG8GoOVf/CU9mV2rxmDJiwzygKsMbZsBbdUY4XvF/S/axiRqB/AJXZNmE/UiTi/HMprsNiLR4GG5G4i+RdSgIUOwJcqRw1AcqifZwH9J7OxR4Xn0Xsn6zqT3PYryfWiIDKGsdxDbEa+FxN50a/SsmeNmYUYzwfK6ZCaR/j11euUzVq1KhRo0aNGjVq1KhRo8YLA8t3PILv/juBLfc0pW9kYQmRAKBviY5/6/D6Qg7B7K7tbNN21rrHK4rL7VDwLdSTUHu5hQItMPwLNBXO/iB3ouOuEQgTpMIrExBhhr9rPLeT6muCctCVkNySjPMCuq5sAzVl09jVVv7AgMuwqxBk07biAZ7AgEgAxWwDFdoSINCXdxViAu+YbBhOdhpHO/RSglPz1+Cca+s2uNbsQS0bAG1Fh/cqjgNv5OSIwfZUB7l+Wf21Ntt2Zi/R5nhlD45eGHC3oa2HttevbLvbWNNtbJxn22xWBLcj2grqOhT/g8MA7AcCSZRKO2ljBaoSOJHCdTyvbD+c7QSA/mS0zc1s2ycn27ZnWWHAJxbQvlUSAYAb19TTDxoK89bW3YnQ5poF+iaboPZbrex0XvP4kZwg4BkBhLzAFcmYCvcVpDbBq2keZeFxhytn79Vkf5H8SX1sMQnh/ZykwgX8Sv3tVi6ElEUlwAI4sbggLjJ8V6Pj/HplXZGBarpGqsM19pIdRAB9/zl5YgWwnvMhHcshsSs/g45z94EnPDAWwUn5Hk1mzkOOZwwQ+t8a1eQpEXLGfIJKdW3HE8YrqrsFyIOCfaJFAMYgC5XRt1fXhuNijkGxDIuZZPVRQuWzzzeHknC0JQNGgmDG+IG1ived94sKDer62DPhMV0salIq42UoiCewfF6PhMYrFE0jbHdQDjUw4exs7aozazdm6yvvdiUqxvFgEwvxaa0CaIdfutrcx0DyPVblNdpaQKmN+3ZVsq5b5ycoTQstjit4DB9yHwG29qJ7Ng26ZnosowPwWlhywEpndDU77DOQ4QEQl5cy7TnYLH5EQPKUREF743dpjXEq+WlLMS3HjbPOgXtLuTado7z2CfMOYB7JQlfvnxMg9UKuTELFjhVvC9+XgLbnv9k2XiTRbU5CSnxewesY/R3WKK6sxzvhy8I1Gp4Ro02nIz3zT7c3drg92HGP4olSRLO4H+cSLH6kVIb1hewvZs4VPfTZQf9pVwpr3SgAbniPs7CqEpZqJynsmUCIwrVhyeLzmvMlbEq4CQHnHGnTxEKOaTz7wT1PguGsz+EM5nk8373DIovYMXO5DNaoUaNGjRo1atSoUaNGjRpvJ1jOYkxXL6ZCT/q7hFABkHMhs4AS2XrAX3HPN9nQMi6KDBEiS61cuBAsjxX/cu/OcptzXDOVuIXaK4poBTDEl3QxOt+GnBTLei+3FNOeQ4piwGoCxkmga7USIIrryKrYICyFmpYKN1lhyFNYX/RTYS+qO90DM2wukr8A0I7bK4TutGzbQuQpwAFgB7Uy7kECQd1Ltm7g/QXgJcwJu4j7+yibKiy7M1m8pp7XNm38DdYWYf9BwEiwrPtoGtwjPJeRAABYngmXeW1zQ7CMv40gPCgw6Aq9UOOm/lRFSIkU/brSJd0RrQYEz/CmfNGdf13YKdwZ00mJmt+TXp72wV/UAVw2oRf0KiXGFy3v3KjQv+cxlxTOWbG8VJoulZJpHrtqUq/PxwJkFCzPHRxKTF1KPldWcIbqMisrpXqPJiuv4uLu0pi9u/Ng0WZu7xBzhMkUjJPST7tsveT9W/71on1T8USHkGzkoiqpMkUOLV3169YFKpLnawutVzDn3d4mFSp1X/pCwe49lBMJUHp78oswOcByAf7T0Au1dllMlVY+7h1NxXLgZvzE2uRFDAFT3Vs5PHiLgZqUrnnclC3nNi7hcZ+8lS9U4sluJsZKqdBPg937yy1GvE9i9wahfjk4lm9eqKHTtRWe0vn4+ltK2pThanH0FayWYA0RoDgUy/mnF8orVcte0C+sYcrET9JypzZS35RXwZW2kF+X/5fU0hfzpPwMm9PnQdnXcQmeYEvrfPS1Mk/lXIsrrlGjRo0aNWrUqFGjRo0aNV4YWA40Q2Wtgwn6EQPsODedlhW73KNS7wvoQwEhvsZi27C/THzIfZv1Aj4Ae4ez2W6ztbZzNZerLFnwDrAR1ggNfnbuF4rn3Y6C0Lb8xh92DtmXWEX/GhuGNbkR4av7DkOZ3DatXW13VNd2UJTxC7tUrCyO1zT01pztRPVp7PAGFO6nEyHoNIWDrm4Mr+sBWKGscyEiAQkAbH+y43Fte3ger+CDOVjTdfRWDfVmt1rZFjAc6j96XOOYUrO1LH644rVCab1FwcEHG+vn3rbdLBU2HytrYH2B80Afd25ss2ttd9XCv8L2t6OdaDtxaRdRKHn9L2n79x2knxWpp2ltw7yyk48BFLpqCZTZIrw32iDQ7iQ8gyc7Hw8EKCjgh5/tdsPrx/0GmIdKMluquFdsAVoWthElKruH3V7q8glZYxu5qxwFmBziBZx3QF+O9SWucRCVoE7ejh5QSDr/u4gnSgUmKEkwGZYoAfkCXMtfFUW9oEDnyHNlJNTCSIdgnkDETKjvNhB4bZoaPlEJZFnky1XekwBuTPiwcYAiejplmJh9mScqL/F/LJ6GJqNsXz7o4WmQEjh+Pc2q8SQHrhZzalAL8HzYGQAvYh9buNRRinfD+PbEQiKaLPCWYSmTMDFJo+ikZJ2WJofDZaiEVfQNnsonqYzdr5fFJd26IvoXKlZCykF+6EjSYP2IRA4hOK0xYueGezPTJ9mtMKLAIMcZ7lPjbE2vXyiCHUn7cdwA3UeQK/CnSFZhXuHe0Cbwph5tGPrkjR+ezVj/3OUjQ/akxoc9ivtnI0HlanIo09lv8PEO5S0nQgPbZMfmnhQLdTxVuOtUbE7vQ1viHtC2K3qsA+xiLaBHONtZu0PyxgPuh+DaihHOlR3mxrHGLuxQPLFXWHqEaJ1tjHW6P9rtfm97FPAbBuuhuuay5Qp5nI2e8ag1gLUIVj0TbX5QtA/NHdY/7K6YqNyFgSKsuGf4baM/MPwF9VO6g17rntBMeybcl9p3IMRnJcYY2q6dzkzWNWcUvnQFP3Zh8MIj+ajxh37uqMTHujnxeLSc8rnPz/M7q06NGjVq1KhRo0aNGjVq1KjxdiqWAyb6Fv8FblwUDivUjYWq8y7Dc7VWUpVdKKdKH0t8+Y1t5AmguVfyDFsCL/TlxcJ8p/NSNX2pYgvX3IUPZsh9w9dZ9gzYok/PZ7f2yC93RRolYAEtA7BhuzSUuXGtWRkoCFGqPNMta5s1AdBgw6CtylJrZ3sGQoHUWPLFDV20oK0AOdXKKMjXAdJAsYx2lO+yBJHabs7t1ECAhDlrWxOWqBBU8utNgDZD5CVGXirES8UsgYiDF6nddf4ozRbHbwhdBZzkB+qwzotXIdYskHiPUrn06k4erwsZcTGO87UlpBLj7y2ke8k3O8b5JcdKKuBixIfqsTxGOleAPC+Qd9GWi2PegfmuQExzM+k+s8KSTgT+PGGx4C8esU0/S7qXCYSlmjnUt5f35O2P/uHxcLasQF2oS+MYi3OUCvtSDRvzXXPJU1CFL3EuFghrFALZ8EhnkxT9mkyDvUhcqH6LvheYd0m/o7dcAFC+xQCziRn6P/Ia6EXd3G8XgC/sPMLSI8ZF7FII+KhifaH0xnwWRGbRUfg5k1YKGALiaz3xAn2u3taYdrDshRlRSS6PZ6mcWcR0hTUTuB+vQ4HHsqDhpXo/e5TLfkj+7CjoJ7V23l0SntSrc5P6LBxisge52rX0XdY5NPqRDET78Fysouoe3mFHnDo4ZovOkQBy8Vl1MYMufo81AxB5ZKHRAYrlSFYV8ySAMQXr0cfFQ7mKi/kRxTjdc5z+/VzblEwrJnbeRVCq2stplnYglMlY2XJIhKyW0TV4wimO5zttIulFpO0VBRdj854Wq1GjRo0aNWrUqFGjRo0aNd4+xbJ/AU6wBDDSAdmyvJgrjwMQFVt3+bfkdZrfkuAzgMf5bKcBfrprW21CpbyyzQaK3VIN6uCiKCKGB2FJooauEjxDWTb6ufSlPG9GVwGwgn6krdXwF+YDFg2FTzP9UKlWhHLa3+bXWW53VyE/wBKoxASUBEulXHMaretiYTOj2rGHz3A32Cm8h+GVCv9XPlqpxH0/M84NNAQvWMDkbaNifTuolaFG7loqntvNRPWyCvfBlxMP3A/e3diZYBlgTT60aKMEWQDxSgCbvD9zcTa/i6JIY7A8wY41rELgXZuKd0mRXBKNke/Ba+VNymtwNbL6zGw1uIYU8tlLAEtPWET53H3QJJN8FiksrnehcI97jYQG+6mEPGf3SS0SDjG+iq33yVrZj3mppxSId0i2APcxhAu1tUM+JRP8/AUQD/UvxyZ8sQnxNH+luBdYgvIe/d11gF4qZIj5g2alxy0KU8KrxFW47G33rA6gKzguVaXAP0e5ABcVscLl8KKVchrvzQU5F4Cf0M13A3CdCBjGtELRLlIQp0QREhLrztZQtvKES7W4VLJe4M6LtZUWCeKRDiUx/oNgUvaNXRG9nScUdBuobkXyhQpfen274paQd6bHN/6NHQPrdqN7wIoEa+ZBUDr6UF3kWnT4H5/NxpP81mO+yfcdfsmy3ygyePRKJkh2V3VZ9mRGG/Y/nN8szlcCbUfi4YlcJLm4mQSZqWJuyaXH1ws+r/UjLfBpAOJ3KMTpTOy6WB+5vIaA9e7VjHaAXzfbK6x7cA6ts1EQtUw2skge1hLfSpB8uLEe+7hlu3JtUz9BXa42j5VdimHsNAnP/oHe+9qFo4RlfFb4UMDcIFRWDkU5jGVxvjJhIgCvpGTTtNphgYPCu94LQio5knctaI6oaGocMdlecI5qPnOd444Vh91cOHOhWIxPqP7hid40G7cbksWRbJvgKKT70biRt3WNGjVq1KhRo0aNGjVq1KjxAj2WvWI9VY/+BRiQOflpxgvzVt5FOBC7BH0LFafDxB72D6u1bdvWgScgV5sUveHNmVTS8QWcRZsyIAmpGyGcq9AQZExuR5C8TsOGw5VqvE6HyFTyAijhy3mh2gvwkDwsF36W2uJNTaGfH3A7AEEpWwv1Lf4+DSML50Gx3Pcotoet1IDqgsBQVCaY721NiwP6JJ9ttzbbNCvbtig6KLAByAIwjjZEW1KlnAodqvAeFcsEP95GSZknlePCToBgKODcUpW6yDEU/rKCYO696zcNsJHVvrJXiaOH3Qq8T0tkM46NUCNBm46WMFhS35Yj6xlavFDplv7QabBfPrI/cgmV03vTaRxWunL2ErKHQjfAUQasy0vN3rvx59UdJS4VsRcq6+ybqrfn9o3sh+ZS7DoAbAI8hm3Gegg/WQWLV2ILP+9HKksmP5LPedrDkFS4TJokRbaDZZ+XvPxCuXshJE+MMqx28g6FUKfG0Mrq5eJqbc0ESfjsAs4VB+WbywKJcVInvO5nnJMm8XpYJUipDJBLwI55T2Wz4J7UzLhHWbLALgP2H5qzeXzNg5IoqWBispyQ5QY91QFZ03qK308sAshCkWdB52j1aUBxwAIsu30P5xGnhtYo3poXcpOvcjFUsf7x2IUfN4HuYtTlJsHYoW2HQ/jFuNWqn8TvQureV5rDskBmmUGORe0y8QQZn3Tv91gnUrLP28TBuO5bSYdV4wkHrmOuyic5hfWDdmLofN7WSZWdE5SxM4KfLRwjeQOE1sFQA3sCM9xUvDUXS44vAUzw+OcErFBCDY910O30VdyS1yOVv6wwirUmqaADQntHMoG25rzlnXhhVZ0bnwOC2Ui4MOniRfw0lXW167XGDO5tsQbVqFGjRo0aNWrUqFGjRo0aLwIsi8EVitRQf9EuOZ7Rl+xUfC9BX1sAn1B2XhZRk3rM5PdK6CtVMDE1lZmu9lt422ZvymSdAfBABbEUuLHlP8Cg3w6PG8+1ABOtgDDUfdw+nAovSQZIJBJesPCR5VZpV2KHAhAWEuNACJX3RmcwQkgEBTNUjsDm7kcrkKX7BBjQ9nhBkSiCdxpHOw6jHYbB9gM8XOXPipZRcT7AZAFlQENARCiWN5uNdf3kiuts00FvXEIJtQZARyj1Qikeija0SQgFs8VIkq16X9qlUUMCsKpfGCrzPISoGHblZ4F60rb4BHG9z1TQENcqCJQvwo954UKRL+wCnmQavUx0UG2ZqXJ4u0qpXL49W6ZkO4nCXrwsOOfnEggW5GJRu+Kq0hSJm4iDWbwlKxtT8qLw7Q6bAfxsAABLVbarphPI9kf0MZ4BROZl0orF+9CVs5m5BgjLfRwIPUKJk4BXa6qVqWbG/GukFSWYM7MOHrAYs11n2+2WuxIEBlOLpGMHNJ/hz+znivYM/+uFTw2B5XIIEFSW1xtLWdFJMdcBiDGHAZUBjJmMoRK2tRW8mAFY6d+sxs0MO1TWDgWjnxzoh2o2lNMC0jjXbP1pkKKUsBlJFIFlqsBnAEn0LQA6RM6jFKzcuaAioUrecDuDjwwozrGCBADNSnquQRgUAXahNMZ6x2OKQmabnrCm8fWOfeSwtpgUSnbJT/7ir2pbV5XnRIjU9GUxSVlrCJwm6wxfb9LpFoJzH40BnxfzM+yDXK2b4LasRbBOA9D3p6OdDgfrT6fkaR2WJswVuGKZXtPp8+OyxF5+pHZ2CxN+znD9Uxty80wm8HlnQ5F8zQtIdl+PnRvrwoZj7T7MXGJTkkyqZVlg6Nz8XCtV+qkTlvO3Ro0aNWrUqFGjRo0aNWrUeGFgGQELBkdQCRBDK0ZVnn9ZDfiqokPukUmQeFfVWmDXEHoK6LjSWHATsHXtW/HPC2Va+EeGKpZQizBlbSsWmHLbB1fFEewmL1dBDNwPilGJEwFcjlSbJfB2LsGyexdDzelqZcAHfT+fbXZYMWMLOy0v1EhSv5WwClvjtUWZSuxRhdBwnyz4he3bBDy4b1hunG1/Guz22Nv1qbenp55ADg4hG6i7WaBpZVddI7jcrVnwcLPtbHd1Zf1wtq49qBCZ+ycPw2RNK5sD3scI2wPRCd4jinRBdTfDcsGLNlIRXiQYkmDxUvLrr3BjZSUDQvEbEMXHi/dlqLDDg5ietV6MKo0VWApMsugoPbxzhEI6/36HMMY1ljavCexkWJmAjyvh8925l2wkURKfWRYE1JlL19pCrXvxXHjDRJrCT1y8T6BWCZaAfkkCSzAo8Iikwqoo9Oa94Wa3sdMA5BKJB0LoFZTLLVWMHJs+1sMnN4Fl+t2WUDZDsPDExvhmEU9eMxSaZxY4Q9E+jCWAMHYd3QugyG9st9vag4cPbLvbacu+W2CEX6yGlyeKHG6KYSL5Eobhhcm4Jlvhv42VTkXM1MyOA6MOX9gn+ArHQnITrC96wkdAWForoK26TsdiBcxGbjeTjoc5wnWQ551ScTmqv1mIjqkR90pWsTXCzf7EnQr7mz3Vp0wl0V5E60jMNySSAOBxvVIxKwEHaBi7C2L9TQmRtXZ9yDLFRyrWk0nzJOoYoogfC4QmdW/pAa7CeLKwKFTead4UySItzj6GlQyUAllgOeaW2mLO7DfOSUW01j0qw9mgqaJkLnbJBGPYRsg3vvx4Ce9q9/ZZzn8CXxRZPFp/2Nvh9sZur5/acY/2V7siGYI1l8VhXcWspEi5DpwvHrKC0RhCP0eSAepgrRWwFJH2PLWW2y/5jgHf4SHlto/HIqVLeyACZRWKZXIthPZMVimhgIQjkhAs3scaBdrhkZa15Hld2kvVqFGjRo0aNWrUqFGjRo0aL8oKgxyh8KNdsMSAYaE19binWFd6Yyj2Cm/P4hVJgSj/43yOUuMFHiT4KxVcqAbvw4iCEfLQ9L3KGR6WX7CTp60rwdJWZPn6jvA9HTv3+BWc4hbtdD/ypJXaNhRrZXGmUHHnbc4BwQBfs51BVsJKOTfbME42jgB3Djm8KN54PtvAolArGyhgXFP1ja3dCMJ3AMfWnU+jcF/qQAeRwUZD2LYATM8cGUlJvmjwgjSV4Clvu89q72QrkTr/fLcwX9mjxd9CxSh+nwH0W5WjKmE4ldPx90IJXLodZ/iSNML5WGn8LG8/28FkD+bo0/uv7p6RW9h/L9o7msIvXu26tLMoj3TZjuUYUzFHqSnpFe4uI2FdkM7Lhi0n9IVkNfHme/4WRSuj8BhV8W4VAADWtda2zcW4TH4E7uWwvKbUU6mB4npKL4MM18Ny4ULkXrS8w2gkf7DrAD7GrooXfBb0zpYxARFjrLrmmevWelGYT1NE90LfaQBlT0KNp6NN42jD6Si7DV8LZtpgQN6qhY07FNxyJ6nJucMiW+kEymSiCDYQayRFvP1kNKLEGo6hZVjoN7zUz7mfOf4vGytU80mne5G4YRs5zI15fqedw2sfr8sgezlaXNF8uRuh0IWXCcpyiPKfecEhfC5SU0pKzQDLg51OJzudjnY8HqlY1k6ISGYVfurFHMorhO/owCOdbrnDIlTLWQmtz4mw1kiQvFgnge+lzC9sd4qQD3S2Gcrq51D3h1d5/nfZRkldH0X8Cl15jRo1atSoUaNGjRo1atSo8ULAMrx6EbGFH8plfAnndnQqw/CF3V01qabygkNJbbYsjnf5SBFgt1CMhlqRCub4m8MwXBd8YrmVPxUACxuGYEz69gwFGlS6AB/tZuNqT6naAHAAbdeN1G/ivw6KoW4DeDnN1g8nwiZuSXdwxPsETAaQGgCLTjbPUDqOsuQoLDuk6g4ADaVy8pcQPGKxvoAVUCZCMTfYqR/scDjZ/niyaUIbdjaivSeA5dmGebLO4T8Vy/3KVu2Zhf7aDewwOtvtNryXVZf9P9cstpXk6HaG+hLenYR+3haBhy6K0glWFZrzOxAke7eq6Fa8N2CyjwUWfnNyxMEli4AEgcLuBGOHvxfgxv+Tz/tWgCQrbPMBSj/ZvIU+AeOClaaia67KzLYNfi8lv03KeL0mpUb4/0t1cxkBeUK1mJMSBcAO1fgsL1sogrPrils0BPPH+MTY8/mIwn0qlLYmzFXSRImJdg0fZTwv7+0MagWeeX5XcEb2IZI8Eo8WCQq/a7wwLDDCVgDXgPG33Wzs6mpnDx8+5PiEiplFJGPrPxTQUFHzWqKfvQ3DroZk3O0b4AdNKwspg/PrO/kCe8Wz9Pek4sU4VBG+8Xhr480TV52GN7KUtJj3ALwCryqsp7mOn27VAaW2r09QYGsc+7FhbTFPdjzcWt8fbTqdqJoFWD7e7KVEpkpa8BMPnHfN865t2G7cmgZzVIX0pOR27/YC5a7WA8epCnhK3S07ihV/V7/4yHTveHpon12h7OrwKKDJwoW0TBEU1lQNz+soItnIKsQTbslHGSOgKHQJRW0aRGkHhNs+YI0q7DEieRJwnGt2qKi9mGbaXRDjzcdAslFKHwbwsYcSvbeb66d2/eiRvfH66/b6q6/Z8STQjPGZP2moE/axKzBM52bauKC9zvw5oU/cqoc2RK5i5rigOl1JTXjGwx5GcDlmu3uYe3ICa/aKO2sgho+5pjWZdkUTdpDEZ2sU6/QCsky2wmIJ8wtKaQfLTr5zgqtcQ5+9FtWoUaNGjRo1atSoUaNGjRpvC1iOL6+F0Clpx2LDc6hGM/gKVVX2HCh410K1G8BtIYBLstl7VKKlCjReW0Yp00pqToGmrCC+2M5dqGnzE/k6AQ3wXV6F9Xr6WIIxZfVZtE8gtfu/rS8KMwVTTQLVC7WnQz8qlQnZZJlRghT6UkOteDbrZbkqfuZ9EIUFaTOAdxAQZRsRCa6LAoj0eS26oADHhZA2t3PYE1zc46IrCtVy2cnZV7QU3cX5ioMV/Rn2Km+q9b0QUi+veXkNl2r5gLnRvvcdJQNft4FY3sA94SO48GNOnV6ooUPNXjDvZxwrFyErR1pSVqb5Ura6j9M7EyZsTsqCgaUyuGjbi75e+KXnE2XLk/SiUkTs9+h+0LLvUIFMwc9isSgLJcYBPfmUElDJWiFOHnYH5QCInoy5mq8h3R9BKFTK8kFXFbyiaGXhfS2bDrw21KhRbI6eB3pt+EWn4obw9VUxQKhlh9PJxh4evz3V0SMtdKCWHnUthNEOlpHMQhLBFdMswhgQmB4+2ROb947TMyfT2EQeC5iPRJJsLM6kl0UBvmiXMtOXhmdZ0O6uhcxi5sTYWt1NMEW7p9OW6314mBd+ystTFPOnkPDHWJRVzfJzJmwlwvOEavXwtUYxVfTB0NuAPugHQl+qgdMcuVgT0pyPApjZ3//uro6cNs2fLctdOjFL41rLz830GVkkmkIlnvpq8UrfY1GA41AsX0z14h7e5POzRo0aNWrUqFGjRo0aNWrUeDvB8nYN/00UlAJcBcRkSaIFgJK/JEtNUdlF1bJ/Z3UxbYKd8Vx80Q9ASbWZq7PoFkkfVcHPeC8PxQJ7lLv61mBXzCWP0eCHZ6qCpzOKXeGLNyAWFLw7gWW3mcADCmx+4W71bVtaOCjD9OUd4IfKwmNv+/2BvsxXVy/xOFR8nqWghgoTdgIosEVvU/fo1T0YFWv9ACUbrFx1rQ05FK5l5Bb48wTVX0PocbM/2s1TeIAerD+ebDWhrKG2tau94Aot9WEDT1tb2WAASq1cP88D1cdXD1o7r1prNvCdXtnVbsd7OB17OxKWn3hvPT2MXaEekMdhTPL25Zb/0gNXrw6v3awIVVABeK8NQQmZL0FrVsSm5wvpcJzrkmxm1nkfLQmoIxiYFI33ZydS8TIppovXJNjoxcgcpqYESVLtSvtYtk95Haq0pvdR0Y4xgvnl8DTAu96Hn1J+ot9EneYL1wgkIVRkEqA2OKO60IvAxc4BesnK9wLK38QDU1FET0SkRg2oKIsCKJ6hYJ1ZPM6tE9y+gPYtkahZ5wKUab5BtdqsbQuf8A0eG2s7+Jx70TF4n+N9bgVDn1iKVFWYTC0nWNg1DduDalp4G0Mlip0FTKL4zoJAcmiDES63Pm5WOMpgK/hL9webppPNp73Z1Ot50NlkjHKmmhhgeJqxFpzc51drz7rbyFaCfsxrWVn0B75+OOnncX9LoHl7/cQO+xt6K489dkFASQsFNG4YdhgqLkevdkmTedzdOEjJ7r7XXbehf3MeU3gfrh1W0Dtbr1trRkF7qJ67Tspj9N2KOxK8aB8bNEPX5PtbZDdSAoPrWBT9CxVyqvDonw95TiqRIAU2xy/bSCrrsJ1AMdEYZljbAb5jfZNyPnbGuGoZynCObVeUYwHlGor7iYJ2KvCZlMdhS+GJg9PxaDc3N3a7P9rh1NsIn/sExvUD19kRxOeCfhjN07yyoZcSuo+dBe5Hjevm5x+97N19mX7xUD3754nvTknXmoBzkWTzD4ak3Pb1uNRnC9Rj7cj2F1Apq1ulttZapM9mFaBcWztjl46uP0ob1qhRo0aNGjVq1KhRo0aNGi/OCgPbu/lVHdui/Y8J4IZfblZMRtV6YBt+vw0LjMIHtywCd6lGDNlpWUQtVItJjZXUXL4du/C2TKJUB2oEJsmaQNANAIDb7EOpu1COJjKY1JcjPJZHQCXBakCdptlqa7q/E7AAcBlb4kvvzXSlrtgUpHD7WLc9ECgJr1M9oGCEshGAmWpHqBnBu+IuWdwPlhcAytycrwJpKBRFBOJQYXVmMT/0Q7PRtnrYIHCrPrdhQ60HD2kV9iuLXi3Arqv2skZuCZCXKrrln+/qEPNrk5/rxUtKUJvfdh8GWaqKg4Pez0tirAD4l+rfe45WFr27z6y0kDZyR3kI74vzLiF3ocwvn784XEDpaGvNL41D8OIAwdgWn46bVMG6LzGy4mSFF3N+rXuAO9heCE39/Yu7ztM0JXF4rvLi/dg5EZQTB7TedYAm+A3rFwef4XsekHIhPy1zCuFZ7GzYYbPWDQFEWL5Qseq2FfmGBRfDekbg2+0xpsHmAUkd2Va4H06+gzg25jHnIgAurn/jRQShCF7bGl7R3Mow2uQgE9BYPsqay6f9wY63B/4NCavkJU1PYBVP5FqD+c42nmltAIW3il26ErXIq+Bv6E+Aa7bLubUzgKt7wbdYF0TnNZ4K8Xr2Or+vIOZy/CY1MH56n0mhHT7YGqclkg7ozZUrEiLF+VKhRh/nCxsTX+dkg+Gw2SsP5mRUzPpyHPp5oypsXIMDbbQvlMoD4P6oYn18RyHjlS2SyvLR+xqJRvpDu3KcoNzHRmRP/fe0htMiJQB7HlFpx0wBllMNgLDOKQoPhmJ5OdXKHTjZTzmLuvNE0XoS51axyfGe9apGjRo1atSoUaNGjRo1atR428Fy42ri6YwCcJkASo3Mr9sLX2UoswIch2o5MGQolhN28m/DCc866A0eRG9K2lBAeSxcqmJ0YhmCB1Ic8ywubFu58o4FsuCf3La2htIzwvcwZ1uMuFKH47gPvH9e0cM2ClMF7JP6eLK1QzLBubV1TWvTPEqpSnqluxeYnugzSxwAH2NAqDgrITFUpO7PTMUd4NJA6IX7hpi6W5+tgxcuoDDVciJM5B30KF3ZMJztdJrtxG3e8G4FkJRa8erhlXxiE5maBOJw/LblNQxT7oMokggIKPV2AT6LAlfx97JPc0NnqFG+sMSiaqbLgmEB9yOyAjF7/grupDPeJ1QuAXBgfl5DqK5DuRxQKQjqfQeL88Lb+P5zCfItBNlez6y0BVgqvaVilMdqamOHmRxb9EstrmdhSZFSPEURS7/Tgs3HMTU1oG7V6xr32o7khubxyD6HR6/gk26IPq+uMp0GTzSlgo+yXaGY3cfIHO9jU/s8Wq1s07V2tdtQsQw1bbMOdfGK55hWgMOaUwFDc/aJ2uCLxBN8zqUUnc+YPxi3MxuV84nOEppLgtGtw7so2jdIqTzLO1nv1+ExJ9FYAMRM7lDxiwRNa+vNVdphwYAnMWwWjns73d7QbmF//ZTA9+bxExaJu37y2G6vrwW3o7+8P6k+ZeNNSQWP+YffekBjjBHcx3m2Tbe1TScVM1TIWjt17wDbvP7NylabhmsnjhLjiMUTsbuCyva1rTv4xctTGf1A1Sv9kFvp7nnPUAqHB3qRiEiFUC9zS3gf2hltXKBmZUukHk5rg3u6BxQlnY6kE+XDPk+lskZCgH3nr8UnjZZpgFyfvxgLVKj7EJlnKpWH/mhPnz61N954ZDc3tzaM/vkVn0WExvIVJzN2VTcLJMKqe5qtbXpaeUsM7fOZntpS6CshEN7w7qlML2v/DPJzYe1d2GNE0gtHctBMJTYXByV3meCY8BmD6+wIiWk17gvpNMrXGjUDyq4Jq4+Y41KoKwlTo0aNGjVq1KhRo0aNGjVqvDiw7IpS4pNQsvqXWNdvJZUioDIeKZKtRKFojr8EWMtCLQ/f5utASf7GXrAI37HXAKWu3EvF8GQTEIXhQhEnVeCgL99tt9C1hV61VEeXXrAAF1AfB2RkoTBsdXY4rsJWrtyLrcZQYEZxt0Dq9Gh2QK43EJAAMgjOyIaD8Bwg2wuHqfjTyMKA0IxDgMjCTGuHLg5SkubWrx12CP0wyw96OAoUrRtru5XtrnbWtK1U0CQjUTDLrF23Nq0BxtTe0TphrZCAafTSm/KIQqW7RL/+T1ct+mBKTPlSaZyUr6Utg34PEJP5Vqnw03/ToUqwnQ8u1S45EO7vze6neH8SwZYey8t7vuOX7WMmv909fiP3oNwCFeeCP8v7XxX1ErMPc1l4MMOq3A6FlUg6nJSReAZYjypbL86ZVPoJLLfWtl5MDxDQIZqKWoLbBeh16xT3nV14YhcK1KlYO9quse12Yx2Ky2FscQ6J5EYxQV4/1cg6b1ZkBi3LdyerY/kks0ilaK2PF0BnL9AHkEZw6K2TJvwIbxr9TGNCynDOQRT/PB5tgr/6Zmvt7qHZqrV1t9W6EIAaqmfMrdPRToeDDf3Jbq9vON+ePnpMsPn0scAylM3YQRDt5tbMsv5gAUDnqb5G9oDaNlP5zLWim23qJhUx5fzEa1G0TmAZSYo1rhHN2uCZyWaaL2vd6GANhCYC/A3LHrQ9vZwjUYAF1ylzgteuSj4/Y26nAS0omnYkFEpzDQvPQBTWyWH7onVpfVex7P+mQpeXE8UYi9nn/8A4oJ2IPxs7QE6no93e3BIuw9YIimVeiefbmgsPfoJmruOq7Kpkm4NcV0XnxFlYwpRzIB7yx85CayVeqVQuX10k6iJhEylcfQYrgXBet8k2RtcjGA6wzLV6XkxPqbXh0+/qeG2bURKlRo0aNWrUqFGjRo0aNWrUeGFgOQq04UttW35pTnBmtjV8GwlEGrpihu4qFVzicfyd5Rbmsr5QUdwsKeP47TgeIhOEzPCvxHWRXqxt1Yo48zu6qtfx/GCnUNq6RWs6bz6lvFbTJuVUuK18mYMfej4DSgGGAdho+7tb9ibAIeisa4ZKmcppFtUDsIA6EwitsbV1WW3N+xQUgtKs66Bu05WEtjauMu2QXhQdVLvAv7YfZz6Op9EOR9h9QMU3Wjecrbs6WdepSBhVa2wgeO+ebbMRSGMRwBFKVkGUlippXN2Z8EWKVV1NhsuFdUDKEuRib7G1vRgKy99LlXImT+WLL7jwBYH2BEZQ2ORPHK8rpbsJSgpaYWSnKw5A9CaUOc6zVGYHxi59pnF+h3LFe1VsrBiKUewytvcvVNvZxXzhxXqPS0hG1gFUM5RW0sfVqsH5ME2iDwkY5b8dttN5JuTkkJIo4JaTDRzbAlvJ4mVRcC8JMH13g2AchJ2bTUuwHB7BC+5fWiQ4CJYfbV43wnskrHek1I4x5t6/mGNMNuXxh0mV4SXmAVTKZ5vHk5TJmBOh/vY1aDih6N5M+No0nTXt1toOyn8cm+a57o0+2bjf23Q82vFwa4ebGzudTvRIR8HP65sb+qTf3h7o7QuF67qHDciK1jRox+0W68M6tTc82Qf4OXOouicxdjFQqY/zr2xqXbXv845KYyqAlUTQ+qL2AjSm4hvKXoem7TRZhzUAftXnja1hEu+ewcxhIVEGv2qmGHHcvP4nPEroWshjo6wrEjbwRC7mN6FnAsb6PQ2A6JqYA4uM40VCAer4tG2ggNyhIPY5GrZCsL+43e/tsL+165u93dwc7NSPyTqG90lIKx9qDWNq4/kc1ss1YLV7nS8/UNymCQ8mXJPJjDvg+317LdKY9/pcyWuGlsVsvxMFLen+QeX/bGe3NMk2GEUxQXpwhzI8e7Dr92hYHim3bxUs16hRo0aNGjVq1KhRo0aNL0Tcu5H/3vDvo/g6C7ywWa9sg8JG9EhdWUsFIJS62tJO6BoqRII7/8LvhDTZl5bq1/hHkhJHMSlWwiMEkiIREGeiz7HqjFG+bOtO3hhSBs/W9/DQhG8wOAMgAaBBsc06fckOuh0XdgmgBR1wf1QGdhvb7HbWbbbWoGgYtoqDf0H1NkmNPQJGjRMVhqcBj8n6yWyY0R6dbdqdbbqdbTdb6zoUYnPlLPbr22BtM9tmYwZWHlAMrJy8nLWrVg5zpGymIvM82jjjnLMd+8n2p9Fu9oM9vT7Zo6cH+4I3bu21Rzd2e7O3/e3BhiMKlfUsSIaiZ7DYeLhr7OGutQebzq42HVWQsvdobNvqsWnhI+1AuPChfdOBk01Qkvfo3fdd4PPCG/dCE3mhhnarhIxfL4r63QOq432pIFhZqCz245f78gv/4thOvgC/hZI5ja0C3BTwPL0P08EFmWl7eqheCyl4lDArfVuzStOBann68n6Lv4MtYoc8xpDU0/ibwGMCaLSYEegsj5gsT7htgZX3bDybHYeRY3wAtGMRTCRUZHuRLI2zKJ1NCqiM8bPbbezBS1cEzHGP2djF79ULrWkNiLZxO2JYcrj6Mo2puG2odFdQAmO1aqWsdZU2ASk91vH60ab+xsbTExuHvU3jkTY2si2BwlVrzREw8uaGU41zd/PAtruXre12bldzsn5/bafbx7Z/+obdPH7dbh69YU8fPbInb7xhr7/6Gh9vvPbI3nj9sT16dG1Pntza48fX9ujRY3v0+Ik9vbmxGxT3cwsfrAboq37APN7b9e0toejt4WCH08mOfW97QGv8fjyldQa2DsOo4omxqyJ2bYzwFD711qNI5+Fke5zz+qnd3jy1/e1TOx5urD8drD8dbfDHeDrYiCKEw9HOtAqBnUZ4PStfE4UhU8LDpdeEtFRMw64BfeKJuCh0Rz+jNiuRFwbCxZz1nQGExLFDAMdHP3K8YveHz+Oco2O/U32MBOPY896ePH5ir7Mfntgbbzy1/eGUCiTys2vdWNdCqY9ih51tuo4/USASf6MXOB+aLwnmU02Mh/tXEy6Hujo7QMuXOa9zMR9Uf1ApLgFjJWxYJDN2JrBPNfcFnPUcdsjg85aP5HuNuYI+URHDtG4lT2v/PLUPjPjYj/1Y+5k/82e+353rizK+4lf8ivY7fsfveKHn+JE/8kfa9/ye3/OFnqPGF/04+ZW/8lfa1//6X//9oiuw/v6lv/SX3qtj1HH/vuuL2tY1atSoUaPG+xAshydyqKmWiqrYwi3FFEAzvgzroS+6+rIbxbkCIRXMLbYeJ3VafqRifX4OKrL8GPQEdTCW2LVv12eRPdpyOOgqvkJndXEBkhNHLEwsCksBAnSA5RYPhwxxT348bKiG9yy//BOi6qu7fs/YIAHIwJnRfu5lHPC7RKNUHvKeAAfPF2Cw8PQlRBRcB2Tqh8lOp8n2x94Ox55bvw/7fSoImLZFe2FEKJeXfagHkwf0dL4sIJVDysOAwhfbuxePQpGo5k0/CkeH9McFbnIwuiwGWGprC3VweJeW5/JkwRJOx/uLsb3ofz2bgfKzcMwSapfwNymqF/dbWrCkG1tAtQUoLsHxs3xISnZ+8WfOY4wht8MQcAov55gL+XqSF26hGI+XSilcFqQMX/XiNYXKunijlJoYVwsQJoVl8selgltKU6o0C5ieqgguZkj5Mw2AfP3p/rK1hywuoEaGr7LsNwAL4Z0MeBhqaH8jFcaEmYSbuiaoh+nj3p9oswBF8gkP/Ps06NHrMQDwAvQyieDtxuRZHqfY6YCkFIrK6b2AwvJnH/wnXsM5Pvrfp9lGJp/O4Kl8cH3g3+Z0Tp4XP1mkM6B9hvP0p/bnaD8EX2k8uLsBbSVQWbZznpfF+Ckfsd4VkYdmJKhykcqFKj8lmy7WjsvdBIXcvbQgKseB7ksF+/CQjZFeF59dWM9pY+RQN6mFHSYv7imW8CSizvYhBO+LRTKK/eX1MwpmJqVxrhmYbyuSd4s1NcP7mDf8/Aj1sl9QSoQtvNhz/cPLten9PT71Uz/Vfs2v+TVf1JdR4/0s/tf/+l+cd5/1WZ/1JSrBENeNz+D/9//+3+K5z/u8z+P/vsXzeN17G//iX/wL+/E//se/Kez7uT/359rf/tt/2744B/pz+b1n+cDzH4gQ/ItzYCx/5+/8nd90rv7O3/k77RM/8RO/iK6wRo0aNWrU+ACzwhion1vuRI4vwtxOTYsMgUf6PHrRrthmP/uWdagciVjDqmBdwgX5d7KiWFIRO1hez7Ky6ARd8T+GYQsLUHLqB+s2nW95DvXcmSpKgRvBsxakhV+q3WIjDGu9aBT9i6HY5LEFUqkW883M2661TYdiZjtrNw+0BZ11vlBO6UR12DgNNs6DDfMgj+IVCg6iQtqaSjsqZFlIbPZ7n+nj3BLkmm03G9ttpRgFBBFUk9oUkPgE9eQ422EEosa2bBxXKmY0YZO8WAWAoFjetY0d+8Ge7E+8fhtOdrVt7INfedkebDdUfkMaietoVthmbXwP2uOKBb0yxFiPuA4pxREp0bCAnJfAs9hrvfTBcEuDPAbuoFLu2i+2x5dwmH+73Medx5IUrCUEvqQnLhl2A9IoBphVy8WRC0jKV6hqpCtblwE4t5Tjx72ni85/CkX/NLCPF+2S3h4wqclzLlTcyVO5qDvocC3moWB39nFFP6sgmdSfGHezF/mSCrPhWIb6HsXvMMcCFuL4hJUYj66Ahl0EC5B5cbMA1QSnnviQrYQnV/Dm+cwkzWbXWbeR9QOPDdCLMe2+zuMICHhk4c3mDLXoyhqMYbW0rB2gAI157eA9GBx2EZzh5ZC6XMpOWlb0RzvPUPHeEixjl0XTqSDfquloi3Hc95zHK5wfEHy7tdV2xx0DM9TN42Cn/VP+PKYCfTd2vD3Qx3d/e8ufT69vWcQPSuNxGG04T7SsCb09RwYWtNWayZ/TCbB6INhN6Qi3wgirY3aJ73LgcrLZKSHEgm5rwmrYYUzDaM00s427Vu18HnScdqu1zma0yZoQup8FUsOznjB9vbbNeWsbFk7tkiVQgq1QjnNQ8UbyhwSLIkbC5J6EIf2Gda3sL3r+IrmlQo4cTCjiyLeH73YkWAo4mkZ9FPjzc1M/LNV23x/sdDzY9dNre/L42oYetkNQKaMiqidFPaEQtiSR4xBgRmvmyZbGGt+3sqbVz7ZVAhLjJbg4YDXamZ72YU8jnyPdX5HsUkJAdxjNy3QB3yTvZHlvYyxrx5AU5N7vkdBEsmCcXWGtZIj8yle2QkIxjasYge/f8UEf9EH2/h6w29lgq9OXkFChUI3dD8S2RA0M7lj7IowP//APt0/6pE+yX/SLflH625/4E3+Cf/8//+f/vC1t+KEf+qFv+dqXXnqJjy/uySncE+JzPudz7Jt8k29if+tv/S37Wl/ra/FvX9zm3heH8fVFFTH2vvSX/tJv+dp3vOMd75NrqlGjRo0aNd6f47n/13wo9xY+v4uv+NohH9tyk9KVCjAVs+MW6KBfSRXqis1SzXyPYisZFYRq2R/JLqPwGk46t8U2/xIullut4/x6LimWHQSkbcj0hQUwWLsdRsufcS06A1SIAMmAwVAIesFBqjBzMUNAlChECHBMyO2QUOpuAYm49rjm8v7IYeCbfMFb43pSIUUoGamAPGuLPLbKD4MN/ZCVoK5cFctXO0uxDIsT3LODEwcYue3eIkKVe5GIIJ68o5BbwuFnHj0BqpIXXx6vBMR3/7nwYk0mFqUVRqmn9FMt1MT5tG8W+aX5jXfuzGF29gB+C/Fg4vKlFUdcT1Y/373XrDwugW/8flcFXqq+76r95aOcryV7KzuAv7y2grEnSwBX50v176rVGMxp90Eoen1Lf+GtkTTqaWwVNxJ/TH+7SHzwuFLnJpsNAltfW4Lo+YXIasGrhkbizJXOgMqjz6n+1Osn1MnDmBTFUAxjHsZ6ILCY75H3SRXybAMtfHSM0xE/cQx4tRdWI243wh0M+Jvv0BgvflKlTCWy1MilajkrpcN2RI9ybWIbxb/DdqS0Hrl3gOb5eMfEplCQ32ubXirkLx935k6el8vdC9H/y/kyucey+geFS33dDf/8wloi+xoXnzX+OXA3LhTBhZr4cjplPXH+v0U6qNwtcXF8QeyL9bS4Ta2rhS3PZZukgxZq6/LC38/jUj2KHQW/4Bf8AvvyX/7L23a7tY/6qI+yP/pH/2h6/t//+39PpRtg15f6Ul/KftgP+2H22muvPff5MNZ+6k/9qYQWH/IhH2K/7Jf9soV6HOeHShMA7+HDh/YxH/Mx9pmf+Znp+ddff91+0A/6QXz+wYMH9nW+ztexT/mUT7lzTzgH7gvn+I7f8Ts+17XhPBhnf+2v/TX7ul/369put7Nv+k2/Ke+5jL/wF/4CoRnaB3YGv/W3/tY3Pe5v+22/jdeJ+0G7/uSf/JPt5uYmPQ9V4Dvf+U77y3/5L9vX/Jpfk8d9Xnj5q37VryKgfOWVV+wn/sSfmAAfAmvUb/gNv8E+8iM/0q6uruzrfb2vZ3/+z//59Dzm/o/5MT8mPf/VvtpXo0rxvu3wv+7X/Tr7sl/2y/I1b3egzX//7//99vEf//FsI5zr0aNH9kN+yA/hveHavspX+Sr2x//4H1/00+PHj9MxoLYslcT/+3//b/vu3/2727ve9S4eE/31aZ/2ac99TT/iR/yIdL4I/I6/l/HetGFphYF/I77X9/pevI/4/dIKI47163/9r+f8w7j51b/6V3Ne/byf9/OYKPpyX+7LLa79RbcXzglQiUfA8g/+4A9OfyuTV1grcI+Yu+hTjPn3pC0v46//9b9u3/Jbfku2A8753b7bd7P//t//e3o+1Lh/5s/8GfvW3/pbc05/8id/Mp/7Y3/sj6V5/GW+zJfhmvFmfYHAOP3KX/krE87i+v7kn/yT6TmsY+ivr/AVvgKPib7+6T/9pz9XG+Icv/bX/lr74T/8h3Nt/YiP+Ai2zauvvmrf43t8D/4Na9K//Jf/8m1ZC0tFNtob8Q2+wTdYKMwvrTDw95/2034aj4VxgvH3h//wH7bb21v7UT/qR9nLL7/Mz4tP//RPv7O2lYHzlrse/82/+Tf2bb7Nt+H7sY59o2/0jRb3WaNGjRo1anxAKJYBSREozqeyS1LGUqnMZ1b0W2bQLmGVwIYsG+ToOK6glHIVFr8Hy6syY15ABSj3rKh0X2xJVu0wqizTtmK8F9vBT71vZYaiTspRwBwptmShged5XNpNOBAmUOXmcXqAbprWtpvOXn7wkPCYxcPsbA92LdXE3WZjm6udjePZ9lPPLehQKvfDyfrxZMMEy4m93dxe05eT51pD+buhtBi+pQcU7iKw0gPHBYCHJ6zZxsZpbacex8X2e/k44xjyJ5Uf6dlVdCGLg5foroWXKeC+vD4hkDv12jbfrlv20QYe0fC1pVpW/yP3dBwJnUNsi8KBsD99OEPFubL1aWJ/0f810GEq0KZ/C3RgRBSKwiT7TbpMH1EldAnQ4TiyhLgBNed7COUdgKq/J6/dUlmca/rdte4I6EvVbAnBYnCFJUWmMSoKdgHJoz4gBmfyQw5oN0qBCcVgFiy7QhNDTABvfbHF3cmjHgnW+T8d5j4rOxQMQ8pLwEadD+MVYJLWJjDxxjiBqnjBXyWLpXMsFLss+hYT9+65ZIPhuwW8sSNBwb7H8TGWbW3tqrNuLe9azLNN10ghqiPxHA0Uwm3DRM3Qn2w+t1QeU43rBQizDy1U1vDWlecy/4Z5mxJThaUAPMzHs82wrjhBsTzhiPQAlkq28Z0Iup+m63DFDsChJl1ZP6NI32jYrgD18eF2T+/ix49v7ISiffujAPMwyGMdCZ1J6m7u+1hh1wOU0WuqiXtYaIyjHU4CMAS4VPqisCaKeLa2vdryPrtO13GWnXxaE/v+bDeHwe109JrdpmMibNO1NmNb8/lszTxIqcz7UXPAf56wHMpaWPxsoNjGUoV+9ywWjzvauhlUwJC30XDt1xqkhFj8O4nyU4FLqPJluwMveBkg590ItI1A+3qSEN7X2tkCmO8q27QgaL2LpIMK1/mOEDccSuvEKn9+oWDf7fWNPXr9dXv0xhM7HQ4abrQ50rZv2RB5wsMLUubkXl6UAsTHeE+7AfxatK5jPVD2QEX3IqmT7XY0T1WE1iXpWgP4eYV2ETLn8se6k/i8RdIT5wkbi7JYpYpTojAs38+DleBYv7OV3sRN5wMhADb+yT/5J/a7ftfvIoj8n//zfyZwDDD1cR/3cfZjf+yPtd/+23+7HQ4HQujv//2/v/2dv/N3nuv4UH0CHv3zf/7PCQ9gBQAQ8+N+3I/j84Agn/3Zn21/+k//aYKZv/gX/6J9p+/0nezf/bt/RxB1PB4JHnBeQAhAYMBtwB4oJcvz/KSf9JPsH/2jf/QetwEAHYAWoNgv/sW/mMDtv/yX/0KV47/6V/+K9wuA9AN+wA+wf/yP/zFBMaAWQMx9gTmD9gTA+R//43/w9T//5/98+32/7/el1+z3e/uET/gE+yN/5I/wWB/2YR/2ltcJiwSAMoBDADTAHbwXABMBqPyn/tSfsj/wB/4A2+7v//2/bz/0h/5Qwj9ANsxTQMg/9+f+HN+He0F/ALLhHsvzoK3/5t/8m/aiAu35G3/jbyRohd0EEg4YB4BUAGL/7b/9N463542f8lN+CiE77hmgFMd6T5S/gNxot3/4D/8hwSV+AnZjLJTWMW9XG8IWA30OIIzxzp0dzwjMNZwT94bxjfmE8/5//9//Z//sn/0zQtSf8BN+gn37b//t+br3RXs9byAR8pt+02+y3/ybf7P97t/9u5k8ANQGfH7etiwDUPNn/+yfTeiKZM0v/+W/nEAY4LxU/P/CX/gLmQACPMWcASDG+zDmkCh78uRJWiue1RdYi37Gz/gZHKPf7tt9O/urf/Wvcs7hmgFGkXDCuoi1C8D68z//8wlNnzfwXiQMMPbxb6xr3/ybf3P70T/6R7O9sOZhff4P/+E/8HPy7VoLsRaXKvM3U5jjWFi78B6MMxwX7YI2x1oZ143EGGD38wTGAPoFfYK2Rt89S1GOxCMeEU+fPn2uc9SoUaNGjRpf7MFyqBDpzRpg2a0X+LWfcEBfykM1CywFVEdhGL/XattjkGUaTLj/cigcCSgJ+EofyfwTkTwlkwrMlY3YVokv5KgFlYCnTr48TqgSkyzSFcHZHxNAGbYUgG9wrsBxBGkAhfBTxb+gaMaWdkAGbOMf/TFgC/848L6wnb9pcC0tySOsJ+BvHOprHNe2OwEFbuMGhF7ZODqEcJCc/Xj9rh3KC4yrR2BHApWxVN5iQlJKxnbu7HsdRacAPADJo+ga3hjgftPBLqGx03S2BvYbbLPCAzf9p4zQABevWfzlWarlSxnd5QEcZAfjeZYgOROt5QsAcBbCvXyFSel7qdy7Y0lROFuE7cbiHBpDuoRCuRnb+H2sJZW0A6lQaMrjOS4jCnbFfMl3qiJpsd2/aIOAYMnWo/CJTqAbG+vdJOBC0CtFncbP4lxxNQvltlsCJNXoUomKIDB0h5vQaRK8YRcDLAfcYzkVJ0sKYQHSUJoC1iNRxIJkPKBWn+jGBOnK/ow1o3wQPruvMu1iVIiuVHuHxQKbmMkcEwher1hID/MX1iVzf7QJnsdQJ1NdfLLjoac1D9XK02RD4fPuDiDenYKX0wiYKksfJHaYAECCxxMOeGy2SFZsaP2jnRO+9kVyw+8L75fiFgkKtSGSdLRMwPhzD3gl66S+5nnQTyPGJ2BvYw1BsieO+G+1nbyZkbCCdc5kcBg5N7AhkcVClmAXC3MxUWMXRarqyKWt3LWS13KOaiT78Dxf64kpHm7pwRzD8VLhu5jCSKh44cLj4cAHkgK5DQuFsicg9Rnmdjd+7RrjpVp6SWfTfSTgm7NBWXQdbbTcFaDPqXycpMIv7kQ2KNmTXPsslneuuVT4Nhfr3H1xV+39/h+Ap3/2z/5Zwi+AE8RX+kpfKT3/e37P7yEAAPyIgOoPKly896t+1a/6lufAawEf0P5Q/AEY43eAZcAIwBz8BFRGQL0MVSL+jvNCnYe/RUBB9zf+xt/gdZcwBSAVAOsLE7/iV/wKQrkAKQBHgCeAW1Aff9tv+20JfxC4Z0A4gJ9ngeVSER7KRKiLS7CMhBt+B8x/3gAAQvsD4AAIQbkKKA7wieOhvQCLvtk3+2apLwFI/+Af/IMEy4A3AH0RAN9IKqAtS5AH0Ajg/SItDX7wD/7BhHQRGAMYax/90R/N30vV6PME3v99vs/3oYrzchw/T6BtAOHRvgDL+InfL4HX29WGofSFuvOtbAoAYZGowJqHOYRxjsQEwB4C9h0ApujrH/gDf+D7pL2eNzBHoLJFYHziPgApAXCfty3LwDWXgX5CW2JOfu2v/bUXc/B7f+/vnX7HHPw5P+fnEBRHfONv/I3ftC9+y2/5Lbx+JIYQANP/9J/+U/4dYBltiNdj7cS9IGFWrklvFd/lu3wXJgQQAOQArbim7/f9vh//BoCMufzud7+b53m71sJLlfmbBdanX/pLf+linCHxE4nBuO5/+2//LXd7PE+g3bBuffWv/tXT9T4rkCwrx0iNGjVq1KjxfgOWoYKl/TG3cReSS1dtph22VK5KzSzgjC+4WaC6WavgXVKnOlwI9SXhtX8zD69j/Wz1pZ6Ut7guB6Rht0HG4iAntmtDSde28BdtresAaFr+ToUcwUCGSYKpre02G3t4tSP0HU7aMt42UJ2OROnnc+eWFvBSBkjurR8Geqiexp62E03TEYicTj19UKV+a6hQBGymarDFtTR2tdvZFl6zeA+PD7Xc2iYA5snBr4MdAQb32ST5kB0I7r9bQ4ka1iIAKivrodD0Yn7tCl/s5OnM7fVrWGW4ZQa32wMuyQIDRBBezJvt2vr5bLc91J3ZtiGBlRgDmWwWW379Ou4n0HfBcfwoQIvAR8iBC9J7CaHdr7sE35fnSAilhD6XkmP/WwJXF+cpizXmayisIi7+Fmpu/QxIKwiXBcKx/T7O7VC2vLw70Dt2D5SKbL2GuwSg3i2LA0YCxvsLCs8wmQnYSb9n9y5OZQO9TWid4YUhoWxnAbmwEyDsW1pmxKUzGQHbFhSFXKnwZaiVMebp+xt+ww5EA/JxjuH/oOgelVyBCl/Kbu2YwFyhIjbROylmeZeiqFlhikJ7/YFKZampsUr5ugNjW9wYICkUoVAq81izTUfsRJjscDhSRUL7i9ORcPL25sjtuU+f3qogJordzVIiQ3mG+X8Nj+V+tOPegaYnfmiRAc9ytq8uvz+qXaFWpmL5ZHYcpFLddvJf326wPrr90Nqs2862pV/wmm3KpNkoNS388c/jIGDarmxNhbw888PqZ4P3Y42bof7uCJdnJqDOXKNCGZ1AKrc25LkSSSsB4GLQyghac4ESXt9x4f1Fv238TAkY9x/mez0j4btkCKOpdNb7Y7xyfDugzlNZx4VH9ThA+XOk4guPY/QhMoK8RCnc0T4NxqI8nZz9evImLFnO+HTKKDbmTSTqwlZI/vqu8vcE0qVCOCyQdKtKuHJeJmieIX2a6Wy/c5GcUR+t1vrkZPIIhvtu2UIFmid3NIeLpEnA8w8Mi+VFQCmGtgF0vC+gvvu7f/fv3qtkxPbz5wHLgA3lNmiAEigJkSgDZMbPy+NgXAJ6IPA8gBTgCYqrQWWJ5y/VcVDyfWEjQGxAPMC7//gf/yN/x09sTS/jW3yLb0EVI67tPqUp4C6AyH/6T/+JCjusi1AbAgbGdQM4QnX5ngQgT3nfuG6oNuFzi584fgDyCLQXgG3E7/29v5cwDnAH6zKeL60XEICNL9onNwByBJSQgIb/+l//a/sO3+E7cEs+1JvPG7AgwDE+4zM+g6APx3pP2xdKUZwT4w0qWkBO9N1lvK/bEEmEUo0LS4ISomIMYr58wRd8wfu0vZ4nymMCtkNpW17n87RlGf/1v/5XwkwotbGzIpLCeH/ZJuX4wvk+93M/lwmi9yQw98tiizH3w64DABjrAKA8QDlAMRTuUOC/p22DPkUE6C//husHAH5frIVvdo0xzp51jc8bAPTYBQNbEYw9tCNU1/cFYDZeH4H1FMnKGjVq1KhR40u8x/La4XLAhfBVpZgs1dlzR0389DpOsmXA9nc8Gj66tewm+IA9A4AIviRro3vynqRyGEXkqGoMG4j4iu1wlQpdFf+iyoxb9+UNGt7LOs6aW+thTYH/8SP/5LyFOBk3oI5TiyJ9rT3Ybe1qt7XtBgplbB8HVICvHwr16QGwPKEAGLw7h4FqRRTfGqDsawMsD3ZEQa7jiVvlCZ/GMVkfADDvtijat9WWbJbcapZgGapjPKh8FTcg9HGGQJXiGsDY/IF20bZ+gOUh4HEA5AGwWr8TEEI16a8h1II61GbbbRp76QoAUEWucG4VQ8yKwaWSOoOLZ2+zvquQS0DyHvVzVhNfvv/81sdPnrtLPpuVrNlfO5TOWQF4Fyrfue4L79d4/f3vK4+3VCtKqZ4B8FKO7eN08Z5szVEeM23WJ/tzwFWeKF1JmAY4JrvYDiBbgQydaT7gnr44LsaItvb7owBWi753VT4fXlQOcxAql5hntB8olMfcseBWADwmPX7j4eDOPcb5IIxz6MZHWAtIcUw7BfwN7xtHAuF5GJIqOtqdNiFe8FO7AQDBN9Y2GxtPk/X7kx2e3tr+8bXdPHpqTx89tSfpcW23NweBZySOpsFOw8n2x73tD3s+d3u9t8ePbuzRG3jc2qNHe3vy9GQ3+9H2h9GO/dkOvdn+ONvtfuLfr/ejPb0Z7Mn1ia/F4+n1ya6ve7u5Gez2trf9bW+HfW/9EcppWHR4XUNNapvo/Xzil7FhGlnYE+sUHscj4DfWpd76EyDsaDNU2fRUlq8ygu2Uxlv2WV72mywrctLRkwMxntIjA2idY9J6iIJyyBoUHs5e/tXnDsCpit+lXSscKxku+wdQNACTfvS/hlJ5v7fDfs+2wO84b1wnjsPPmbbxIorFtPWEGc5/n680gb2D7eSXnfyqs/d4ynssVPEFXC79nNPzMbax8Lp82osUAizzM48fYVJ4w2aE+xFo7YRCi0gURCJG18c+ZdFLFWAMdfYHUsDX9M0CsBKgBAC6fADuYBv+exs4PsYb7CbK4wPoBLyBMhj/hnoPkBvPwze09BYOaPXFIWBRAd9XABlslce9AZ4hymtG27+d4y08nLE9vmxLKDnDZxlb9qF4hJUCgCKeh2r47WhLAEMELAYuA5Yql4XBLs8BewJYJPysn/WzEgQMdWZA1fJ/T+B/Z5YBUAXbEWzLR8ICYBHWC+9JAJhBRQmF7df4Gl9jASojXmQbPisuVdMYN/f9Le28eR+11xf22uM6n7cty8B69MYbb9DrF3AZD8Sbtf9brXNf2ADg/M//+T9z5wHOAWUz1sXLtn6etom14L6/RXt9UayFbzX2Lq8xdtiVcdkesMGBvcd3/a7flTYv8JnHDpH7At7VWFvKR40aNWrUqPF+oVjW9n6pkYkYktwYKmaA5DMtIYhoCYGBezK3DsVb6xYYhdjNwdDZRoAAAB6+5qzif/QnJqLmy+Uom7c/B0QLmIRnAHWlrvStw6n43PrO/xiIImbJZ9chAr6sQ70G6Du1a2wgtxU8R1dQcQAKQ5loNqAw14Av8wIr8EJeN1HYCqpDwKqeAH4celpbYCs9bSsA2OHX3DnsJkQHaMO2e//S38hnVGAxWCiA+pnqYyiLcWy0GXxPBVRkS5LuDyJvL9TVrLBlHj9nO/UCK1BXa4t+3p4dm8xd/Mz+kGJP7ZbUsey/8Fq+BKqXdCYAfh4TPFayUghwFe94qy+fz3r+wgbjQqm8eB0jF4/LFgp5PGUgfbHZvlAil4fLXM37LEyVi23xyT+5fMeiaFcU4ioE2Jdb7ov9/9nnurTn0AtC/BjtT3BK5bEKuYXNAd4IL1r0cRSZQ9tgzAocCubST9lVmVRF085F7ZXmT3G5qX3cmHq1wnsCeuXimCz8iTkQnutO43C9q3ldJDPCPkEJJ82Pxpmbj0+SdVfmUvksYJl2VKTmCPAX8z92WeA9sx2PKs538+Ta+uPRbm9u6c2LNqB9BSDtaUiF4WSFI9UyALOSSkg2Ae4Otj/ABkfWHgCSXCV8DOF9mIfDEWA3W2FgkT4PmvNIMlGMPKrNuxY/teZuHJQigYU2hEsx2pnJOXgzw8sa/vPAkuFwwZ9n7pToOtlxNH1vTXhVt2db4QRok/CtTu+9KK6H9SHtMOAHAX+XfzJ+RdLMISlB6Vo2F1gnOe6w7hZzwBNZyrwIYpdzt0wApXkFfw5XumNhlGIZ1iRQOPUsiBj3Qi//c+G57zL4WAsW51kUzJNdhoC2wKzqATic5udVFKLMq14uALrMdCkHEsrorIIOJThfg/mYVlH9HZ8fsWNHOwbiczEcqUNh7fA9WdNkmJznwgcWWAZIw/r19/7e30tWGGV8w2/4DQlHYUvwvCq8ywjwE4Gt5Nj6jDECJS3WDKjdvtW3+lb3vh8+oVAMw5YAgeuFDQdgxNsVuCZsY0fAVxfHB1hE4OelVyl+h8r6PrUyQDKuEarsAHxQGL4dAQU5VJ0BynDdUJMDcEFpHUUAn6VAx3VDkRtb+xFl4bP3JnB+bJHH/Zfnh8IQfsnPo27HFn0Uy8MD4wHb5WE7EFv3P+/zPo9FxBCAapeBdoDlCB5QOgI+wi7gPVUto32wvf9FtyEAHcb/2x3vy/Z6b+I9bUsUrwPIxXXGegH7j7cKFInDGgbfa1hYPG9fxNwvCzji93LtwVwE7MYDvtVITADUY+18u+PtWgtDSf+ixt719TV3RQXgvm/sYT3AA4kkJHJgfQTf5ho1atSoUeNLerzH35ikWs6qAAJlh7j66gu7C0FhgBDaXlC22NCTsz2vrQFMcHgUHsuENKnIn75BE+x2gq0oaid+IbDsm/wTNJOyS9v0ocSjipLHkZcpfJEbbANPqlu4zJpNvl0cICIhVUCM9dm6zZrF7s7T2mZ+gccRAYnXNk8N7SlOx9n6Hl/gceyOICa2zwPQTWcU+MJWeRSPwvbrKEoI5WZrD66ubLPtbLt1FSd8nVGAD3AZlh092ugg1aV/+QdABlTeYAs8+QwAlXtdS9Ym3Te21Y+z9Shw5UAIrXc6CVrdNCgwBlAIhTT6KeCq2oFQFdYa/A4JkCjvV0EbWRxIORoq3YKZJEmwd+Yilo6oKXy7eH5VAYgTk17d//7057fe0+1X64C38CFNsMUTEOjvxf2Urh0OsBJcvhBQuwrSN7cTnrFQ2UVzlCrqMlmzbL9sexHwmMdOFiOFNUl6n8Y/5tyiyV1xjvmG12MsYcRCqYvjAohy3HpRMgJPL/SGnQfsJiqWZeEAxTvmOox/k/i7TCPEdblvr6AcoPJZ8xtF+7zoJ2CrvIzj3jyRAWWs+wGjcJzUyJhDkYDB3MBPWCq4zU5YdHjjCly62jU83L3D2BbYPdDq3OH5DGgcQPmNz3+VNhYo1AfFK17bbLf0UofiF1AZxTuhkIVC+dQfqVY59lDKwgrjyAKZ19eD9bCUQVHPFkk0s4nAeLbD8agdA6PaDLs7uNtjPltPQIj+VFHObSuvZRT93HZrm1eN7R5OLLh3hvUO1r/hROug3dXOtg+21qCYnxc+dG0xLS3wWthfYO2RNdCcwBGKKLIYJd4TXh0x/tF+TEigsKf8ScI+RXYkfpaYBy1ej+KV8JrHjhKoefBTymL2M/tymeRh5qyYO2XfxR819rPNM497hh82lNgHJgNub/deNHVNSyTsBkHSBJAZCQ15g+h4C99iWiS5GB6vL4rAqsifEhxoK6zhUi7lJuIPT4pc2uREhiN4s26jqAMQn7Ux9/26YIGxaWHrJKupwrimWOek7E/K5lg3uHajyK0nYy4Sfh8IAdgCaAKYFsX7oBoF6IXHKUAJIA6++KOAE+AhICGUhvCPfbOCYxEAndjKDC9R2BxAFQnoigBYQCEnFKiKQluvvvoqARAUv1C0AUJDcYvCXoBk8DyG5+jbCZbhVYwt3tjW/Ut+yS8hIIUVAwK+rPA9hY8xivfBHgHe06Vfchkf9VEfxTUP9wnYBBiEonBvR0CZCHUnPE+hjIY3NIofYhwDnkH9CVCDzwj4BEeBMqj80M9oy0/6pE+iLyv8bLEVHYXL8O+3I9DP2KqPdoQFCkAg2g2wqfS7vS9gb4At/LB9wPZ+FEoLuI82BQSF0hGFCgHTYgyVnrpQPWNMITkARWe8/z0JeMdiaz78du+Lt7MNA3bCXgFJgYDA7228L9vrvYn3tC3RPpinf+gP/SEW+MPagiJ9zxNoCwB0FOnDfQN+Ym4ESL+vL5DYwDqIdQmJt7/yV/6KfeqnfiqtbhCf+ImfyP8N+DEf8zG0o0DhTIDmj/iIj7AXEW/XWog2wHXCyx5+8ihueLmj4Asb0Rbw/4bdChKLaKcIJMbQrt/3+35f9vP//b//l31+6Z1do0aNGjVqvN9bYUBNTH/lUoXoX4SBaAEkxvSYbZynpIiciy3Ny2JCJVjLBbvkHym4I59Zr2APbCohYtqCHx6w5RfzsrBTrslWbi/OX+6zAiyrvqDalDo3bDcKZZdvXw7AmrY5x33E9vrCF07vdbzI4xZ/5z1m1RuhJhV0AhXhL42/ifnommkDCpgM8KSN0VJZOsMThCjEhFlUmKw1AM76YaJ3rAr8ud+pA4nsC1pYnxSqukVRtwLFJFXf6gKWJNx4FyovFbyXo6M88BIGl0XXyvcsFO33YZOyz5I68OKcb8ZaCkWjQHs8ltcfhbRyabpi7F+MzfstQpbXmP6+GMeXl+aAvHzt4jjRyzF/Cr/zkEjH2Em8O4/V8nrvb2PB+WwzESHfHG7Nd/uA8jUc1wF2L+ZQ3EcUfksWJIU6m/+38rXGobSSIYDlAZaz8lXJMF8DfJcE5zdUrrCNgOL4cKRH6MEfSFqhgJ/mzCTA7gU7AQUAmQ+Hno+9K5UPJ9hczHYaZsMmgWE266ezHcfZjuNkx2GyExJA4zk9nx6TXgv7CjxGeKX738d44He8f4Cv80g7DDzwb9hb4Ceg9ukEux73dPadGrDDwd9QbBBQiEUH8RilmhZ0z1A0ew0vi+cli4wiw5J73u2GAP/V+bLNSI9sR3J3svoaHv++b8yHpYQnCPRTdkhQLENxDsUybD+gWub1p8+cUCqnhSvPg3tyV+VULYv9lRYWOTkVn2P5HiOxqbHvhfZc/Xyn0uTFZ1kkxNIr0tqRL1ZTRJg5bHLCtkRt4wk1L6IZn21udPIBFVBm4ks+VINQ2wGsQW2GQEE9wBfMb/jeQuEMIAXoVnq+vlkAGgMmoLgUQDWKZ5W+pVCq4TUAuPA2BtAFaAgFMSAq1H/Y8v2xH/ux9BoN6PtWgcJbeM9bBQpS4boANj//8z+fAClUfTg3FMeA6bBGAAAFiH5W4T7AeQCfT/iET+DrP/mTP5l+y29HwB4CcAnb7QG5P/7jP57ALAIQF0UGcT5AQvi+whojQB3gPgAv3gsABPBbqkXfLHCetyqoh+QDYDfuHYkBwCKoFgEt38qOAO0N1Szeh/vD/+ZDmyOQqPqUT/kUelbjeRwfxdjKwBjF+Ir7BjB9Fvx/s4DQAYmFZyn035s2vAzAXhTOBAQufbDf23g72gt9XY6tFxHvaVtizcGYgCoecwtJFNhDPE8gsQI/ZNwjkhewq4Glz5v1BdYZWE9ANY/3oAgm1qtYU7AOIvEGGI12BnDG2hH+8G93vDdrYRkY20gk4n6wxl96yL83geQjAPunfdqn8fMC47AcR5jX6Ges+RhzAPcA/bVAX40aNWrUeH+J1fmtTGQ9vvyX/mCpiKF1863GfGgjM7/ADr6deUNbBjwabtfllmb8xBfYVAjM/Y+h+HL1D6Hx+WzH/kh18Jf7Ui/bu96xswfbh/Zw+7JNZ8AVqHfPttli+/bZrujVDA9kqX4BtA/9SdvnG31pfvkdV3Z1tbHt9sp2Vy+5WljnhK8wC2YNKqgHBeDDq87e+crL9lU/8iPpATv3R8KmeX5s5/ONnddbO68f2Om4sjdeg4/W2fanI/1LT/NMdSFgBpSNZ8Cgw4HK6MdP4HM6W0u7CyhtHtoHf/A76a/8QR/8Cv9HMRTM8J7d7h7YZrO1x0/29u53P7bP/bwvsM/4jH9gjx4/sdvjkSDrqtnYbg2V4cy2F4mA8tTsAYoCrtd2gBJyCu9rs65Z2TsetPz5YLeyDs3gYGOhOKYSb227hzuqSP/Pq7f2Oa/e2PVhss9/2ith4Org7B/qCtPCjkGOI6G8kzIuK+uWfqUBQIrhuRyEPN2zn09Ka08IXAz1pbWGA6RQG8rvA1+osp/x7J6qCTj7T7alK+7DyzvBpLgu+LG1G7ahrCMcaPIpqRzLWwDkxxeeDI1KLFfC8gz4AKU6qvlX/HKK/+GKsdWhAKSDVHmRZ8sa9BXG6TCP9D3f8D1rFqqE8hiKZdwT5gUgJrzJr7ZQ0Dd2td3ynNc3Aqz7U2+3p55zuIM3bWl7QSUxCkTKIoJQd+rtatPaV/6yH2LvevmBfdNv8vXsK33kl7PdprXdbmObTWcPX3rIudF1AnET5jIKXcKLueus6Vrb4Fqb1q62D2jhsHn4krXbnc2T/M45jkBcUUDUrRW4D8ILj4ZXrooTQbWt0RlwEnAVIBYg8o1X3+DP177gVfon0yJkPlu37ah2QdHOp9dPCC5ffw2vPdrhAIg7WT+NdhiwDoz29NHexgH2MyoSiIJ6sKqYzyvu0GBCi2tRTt7Q1YHqUhTkhH/62h5strzezu1wrnaNVMvbtb30UDYiV/Q6PtuE9WcarcVuCDyuNnb1ygOufehv9OXpuKe/8UsvPbCXHmrtuYKyuevspVdesXbT2YNXXrbNdmstdlNsNupv/mxs9+Al7qw4rzv1+br1vse6v9WccK/5aH/aXzDL1dt53FNVPR1RUBHjRKOc6vGu9V0EGOMoqPjAVjzP1myVvQ5jTqhYnvd2f+C4uXn82J688Zq99urr9m//zWfbfo9Egby6G8xPL9oHWxYCXkj4cQi/RCnmfcxwXsx2OMBTf7ZHT29pdXJzi6RDr7mHMdq0/Kxh0diu49qKZCVtZgi9VRQrloCRhVqxW0D+5bIY0W4J+UyfbWTCJMbqyl7a7exDXnpFSZhOoFOJX9UUkPe+dutQoc6TrbJHuTYCcC04Qtl+Pttv/xPyo63xJT9gyYCt788CZJ/5mZ/J56HYfJZCtUYGc5hDpfqwxvtnoAgk4Oinf/qnP1dipkaN91XAWgfqauzGqH7LNWrUqFHji+Pnz3NbYeBLa1gHoAhXAiCO+mgr4WSSVhbulznBdzUKDp0vtvYWkJG64OStrK3uSQW2UG5FZHVowpSuyA1lbVIiF+rQvEV56cObXkPAI4AARSKvpbAoCMuNUPO6qWhWvwGKoOheUlnf5fYCIcUfHEqGSjvsQaBYFvhQW0h9llXWOFs4L9NyBBYGDpDTPV4oAKValh82lMvaXo8jhXdyVsupKFs85kXbSqWYd6hHAxe65buDqPxTeIk+7wBcHGTpUfqMs90TyzMmeOuq4mSkypeG5268M4+npV1Fceg0jh0/hwByIWTORQIX13Kf97SfU+rKuwrq6NpSmXzvXT8zd3Svltuvxudbce3yQi7mZaqCeN8R/L9ByuOeQ63pnrbJ09ouvGvdjoJ+64U6nJCPXr65qFuB3LNSGQkBV28qAPYd2CVlZ1Yrs1tgiwPojsJ2LLSpBwvfQc2LAnMOfVdUEEMJCwUzvHt7OxxPdjyc6KMssDzZAa/pJ6qVxyiiSWuLDJZdRMqfmndaqNw9xNbJA17vZSIICpiwDoJqGWrnYaLnOryXYZdBUD2dKWmmVUUzWQupMxNqs52xPrNYp+A/gDrON3q/4O/rqaGtiArRoX0mt1ORPUkqsld4PyxGRM6mFCkSt5lhklKF/AiE4cnsRPdybC/HbPFz8ZEgSyL9M6uWJ783Fi8cBtpfqGBluevhGYP4GVNH11ZYVpQ+MBfPYbyrCCbuD/2qA18WMEtjkW2Zk3DLcZ5X2SywVnFEAOmkFb/wvOd8iM+kvH1hobKu8f4R+B9+8GuFYrfGexeYP4Dwz+NnW+NLfkBh/nEf93EVKteoUaNGjRo1aryH8dxgeRxh9it/VIAAbbcVjgQ8Bu+QIlIlhvj3GfADhbdWtvLiYMnb07/YE1pGUUAAhrPZlh6WK9tAWUwYcaZXMV7ZrFVYD97HtP70XalQ0GIrOBTL2NqN785QMtOL1W0lqI7G+wGlCGsy8FLxMN0XIM3x2Nsbj5/YFirKTUeF7/kMldzO5nFFpdc44As8/CtXtqM/KYpsHamUGwJQDxO3os8j/JbBC+RlDD/qEmw1q9aadUt/5a7tqBDcbnfWNvBljqJgAx/lVnR3kmY7ARn1pPxnm+BhG36c6RzwbD3bGYAMkLhpbKJ6Vp7N6MfQzAYnwX2ipCK30BOKlbQng1mRlnKLcGGvwJ8FxAxFcQLUoXwuKOYl57gE8RdnuoRQBf+9+/7FdQX04gBI16eXRALAaQyVgG2CRXxmRrvHJvUMhxKsumDPgYyjeGQG1SpLmWxl47SF3UMqzVXs2M9JFV0jxgkgEstdFt3Dn27ZQaGkg1tthlfbh7WM1NrgkTOTDqW9SKi1NZcwT6WijIJ7suG9SADRAkNFAjHWtlsolKGKDbVozE8VrqTyusO8QvtiiZJqnKAQwmfMO8yB7ZWdQVcd2xEswiMaAJRqZAw7KacFxZHogo2FKonTloAif82n26c3dtjvbb8/2PXNLQHzzc2BqmvsloCalfN4nux4e7T+MfySe3v85AmLwr3xxjXXjZub3vZQtZ7NTlwHZiqgqQIHyEVbr5CAQ1/KFiISOTHUlIhTe2OXCOwURlvZYcQ4mQwiVc7lXnMW8x7nwzo1b+GfizUXvdsam2IebaDPcW9ju7Y1vd7NRijCZwBpqZ2nBv7xUt5C5YznoCRf45px75tZPtgOQAf4TTfwjKYhs8YTlP/uWY6rP0NtjGdiwaafsjyVqW7mDokNvZpXZ3l84+I87UWjH61mWs+T71E5tzxhIAiPPp04BlAw8XYva5LjCQUVUcTRxzQGEyG9bCvQ4FD0Zn/nPDc1vPQ86yoyz+q+7KV1S0xOfE7CS9wf6kfZRmEM6LCeTkpWOvl+VGATc9LtiVy5DDsROYtg3qgwrXY/aBdELt+nYzIBEHY9BM5eWJAwWvyaYvKL5bHGswP+pm/m7fnZn/3Zyc7iiyKgKoB/55eUQBG+ZwWUo88qbvi+CMwzeG9/SQz46mJr/n2BImhvl//1+1PA2xyPGl+4+Af/4B/Q3uFZcXNzU5u2Ro0aNWrU+EAHy/yCy+++KBRVbM0HVHaQnJV3AsXxZVhK5VBo6Yu3wISrrWYV+SPocWUkGITsNPTFHrYEsMAQABY0i23uUoFJrowv0lT1MbIHJhWSIR8tVJy6h/At1jds8AZsdYYCEcfbtYAu+vJu51bgCuBiEjjBYdt1y2M1EyAQEJBUfbweKqAzaI1zlgw1eXLSX3ldeCt7O7rvs/yq49pDxa0CTmElknyQQ/HoGjd5VmdpJAohRjvgTlBI7FLcKj9SKSVT7a6S2i6Eg4AVaKNCVZjgT3HgKLDngCjB5WXlviLK916OzHysBc5MAkK/YEC4lNLQOAyV/EKpfOechVoQBRoXvs6lz2nxrsILOv3d50R6TVn0j3Mkc+24zXw/RY/TmmN5zWk8BRxmhUj3TV0IKUOFLU/uBXdPuwJK/W/5N1dBFuCcPt/3CZcx9+OgfjPscp+/bRTtozo5PGoznAuf8vBhRnHKaEMW4ZsCOpYKWQfYbjtCP1mOae2wSG1FlS9QLNYYAFDRcLy+P/Z2uD1wO+zN9TWBMkAxwDDXFzYBwDAsLXq7YSE/WGHcEizDIuR4HOz6+kSYOaEAoAFmG1XLsgvyOYz56urtSACUClOsmXl90P0S9mOXAdoR43cNH+azrVDsDyszmO0MexPxQ9w3W8nXoVUz6Trmsw3NwHNPo+wZ8BPeyoCua6qJ4b88qrDqMNgEaImkmO+qAHBeTQDtshyBjYwK9cUAynMs1mEVWMSFSXUu8KxirMkXIgYgQW9W+mYrnfsqZRbjO5YDrtG6L3lIQ5UtyC/47arheKR12edmyso4AI4loPCsz8M9xnAxeb2gHvqKOmWvnCqnoLwWpVsOh59iHUveyuX9ukI5uH3p5xxriSC7O4gzEVgW/vN2TUry/KPG8wW8OT/rsz7rTZ//4h7Y5v+cLmgvPN6sLT/8wz/8fXot708BP2wUNbwv6lb6Gi8iPvqjP/pN53ONGjVq1KhR4/03nhsslwXIYpO+oxxXbxXYLvbpFliOimWHl4ScjZS+ocqjhhdCtZVZ59u54QFLpAAWEcpEqDEBlZu10SY5FZWLL/wALO4PDMXvprWu3brSVIpavAaAgciGnpq+HT99UQfAWdnN7YHAaNs0NkO1TAjU2vk8EKgAXMiCEwpB+fSymN5KCsNQ2M3nBjgqgWCCJb5xpqcpvDmxhR0et1Bt4loBglRw6mA3N0/tcLil0hownWpE7wFgX9WeCgWpoAEV5E4/gpEEvWjQtgRPNNm29TzZel4TNuL6k0cy1d0uwIM1B1WNDve9iKLUrRl8RmRsHkBuITl2Nu2Aw8fGm8by7Y6LM/wsjpohdwEdM4jMClw+45Ao6++1pbykSOVd5eqVxb24TcnlxcouAO2Ti5qFGpXb491CpbitBKUXLRnK5USgArCBA+J611T+B1zlGIbdCcYDNxjk8orJmSLsJ3gKwTrZCMT1MNWgViGk008BOBXh89nv7ej3XBT4ZP+6UlLwC0pXUM+iSBsZWIaPUF1ibquIZ/RjdtTAL1M/Uv4J9SgEzcmGwS8nLHeE0TCq8VoBXSao3C4B0BHz+PrxUxsAiB8/tv3tLS0vYGMB+4TbPfzM8W/4ik92e3uQ5QWVyfCaHu3x01u+9tHjPb2nTycU4tPljFDBAgj7OgXfau7ScLXomqBZ6moCzeTxvRxNQuFS5mJ8Esw3K+44GHG/84qF/HAEFArEOtE1WhtC/U7LjEFt1UNlvMpgeb0a+Hd4W5+t5dpxc3tr3annhfab3jYPHtjmSv0PX9+pwXjwHSHnta1b7CbppF6ODo5xUsyXVEWUf5P3/mqNjmtsfcai7j7tYYnEjympndnVk/uHIwnBXSl5fiTzmJQ4FJjGsbEbBAk+fqowMemfQZ7hieRbHmxhk6Lr0FqPeev9UyQDkQiU6rnItPBzSOMZOxt4AAx/X1O4g4Vrj5tJRaINc9cTepEhw94d/OTOFn6+IJGpZGnbxhrk8yUdQ7tmdKmFNU9aZ0KyfM96W+NNC0B91Ed9VG2htylqW76Y+LAP+zA+atR4XwUKVdb5XKNGjRo1anxgxnsGlgvolR6hkk2vWYpZ44v+wlKAsJN02GGxvkBjezi+AEP928EGw0GtfIcdZVPdDIUwwEAU38MX6Xw+fJHHF33YSWAbNwsouVouvnSjkBaus/MiY9zGDMsOvwFwj9v9yfp2tJd2W91bh93eDYsvQbF8JmzLAHPtUJmANrZtJ7Wh+5FyWzK2Y6uYGI6H4ml4rxTaunbAJ/iBno4HO+yvWWSLyAcqw8ITE5Ya1J96v1DhScYG+5CkUVN/ONwknHYbA14hxHTuucp+KQSxoaKlipzQGxAKQLg0WI7eLvvd/xU0MKcYMtjg7vNQt77J4LsUKJZi6ULNl1WAWe+7vKoLb4hCJeroKNyFHbTevbd03BAQehIl7CRKhE5f02QNIaCYboHjURDoPq4TuRnNr5CK+/W5bJz9DWC1EjRVRkEWB8TPIubuoR3t7rsELpu4UI+nFk0gDyhRBc/Q/3qx4LKOlKG5Eie8e5/uDksJ7TQzeCWFqrWEjoR9YS1wsZbE7gZYMsC7ZWEJ49slCI0pDtWoTxYYsAiBTQbni9odcxh2F48fPbHD7d5unl7b8XCQW0CDgp6D7Y8o6gl/XlnSPH781G5v97Y/jHZ9C3/l0Z5cH6iIffQUiSg4J+sAmIMAyATMvvRBYRxAUePA/aCTF3wkjvKo0G4Q9RrGEBSvMyAmbCtgYeMK59G9g4+wDMLY2midjLYFeIZqF+vOAAsFvHeQWnttA2HvtG3otgAoud7vCdHou7zp1XIo1Af7E9hhNADYsNUA5GytgaK/jap3sTOheBS5pmz9APsG3DwSf2gH/0xIcxRAFjMzLByw5o8ck2vuJNFOkXz8u/OWxwfgbTtrCZY1lnjtpae676SJZS0SVJrfgMPqzIXLvRfiXMNbxIt+lnBZCRKptOPAsbOAfVkC5WQJJFV8OhP6nhBcCUyBZRXjo72UF6lN42WWNznP4Z7YVIurTmuhIL/w1KlRo0aNGjVq1KhRo0aNGjVeGFgOL0wvMgc4gi/GJbiLL7elCsyfKl4VX56lygp2RMhGNWP+Qh9b4hGEcviyjfcQWgByChTH1uJz4ZVMm4yklkuIgF/IAY5aqOuK4k1hNhn36HfN48KjdGxcHSlESOUjYHG8NymRE0zN9wW/UgA52Uyj2hbeItVhBz9ZqA8dTuIYACfD8WTDOElB2ff0C/Xd1G4PkoVmLFbmsAwwntviC8gq2OzAxdXIq1CGQ9nooN4FgmJ9bldB8EK1MkAMoLQtVNF32bKrFIs2z6bB5VXHvz314HAjgdbSW+Gyct0zSHS0Sqn0u3w2y5391al4XPHiy1NdWGgsJI3lpRZF6KJfqAK+U3gyW1dkIOTX4URZ9+Dj887dhtq5nE9+zJh2d9ov7qA8WobUMQ9oI1H2Hee5wCcUu0qSFApLf41UmAHcsqWGpq0/I5E+oWCA4FAzh/VFdvjIYFHq4wz7A7GHYjQpotNWfySo3GN2ghr4bNM4OFjW/wG4oTAflMrHw9GOXqCvH0e14XS2w6mX1/IAb15YYkz25OZotzdH2x8nu7lFQb/JDkf4NsOCB9pTgFKBUCTK1B+FTcnFGFBCIrJCvpZQ0awEjgp3ZqFvvAc+zQDF5VAEb6cNEdZWXwdnCILTSjbb6TTY1CJ5JbB8xjGQ1MO1c3nFq9TmSExh+ewphe6t3Q22QUFTHNQ7HkpwHAS7HqAiRtKND+Y7cHIf4+63LDsI97hnY+CFeBofRTlhUWYUotCexoZ2lqRKh5500Q4QpQ30Z96YJz+0ZkGxzM8swNZQ8tMXOa9jMe/CCoY3EgUWI6F2jwUG5045bMPAmAmBZnEtSaGsMrfFmPciqSkJe2c5StYXec6GOLqYG/g8YH8oUVuuWTHXlotR7P2oUaNGjRo1atSoUaNGjRo1XhRYpsISwidXWsI3mDpGfeHma4QxXaUlQBLF/MIOIaDVSDWne8viD1Qoy6M3viDHFmOqb0f5DLM4F7f+qnARwQBtHwBkBWPg4QrVMiEuKW6GGLSooOUEABCe9S/g4flcKrlcLQfrUah7t50Z6vcBKk/zYOezCrmBdIS3bYJuDgagwLzaXUldCbuJtZTOeGw2G9ttN7YBeHYgAfg12Exl5PX1rT15/NT6/cFGFMli4TW3COG9iSiifSBExCG23I4txbL6zQGEv1Z/E8TbNmvb0RYDQHqmrQLsRYR3XIXuYBlFwQCfAaLxSGphB/s6sODGAh2nynEOW5NCL73C+xDtl1XGd1S8pbIw3pteVBSETD8DLl+C7OJwySs5K/6yCro8UlAtL9B1cSy+olQ/hiqRkDYhncXtRPE+bb0v7GPSXNLfOF+KrfpllN6rbwqGinuUWnd5j3yJJ3Ewh5L61yEdio3RH3gabE3IhnGeSyZyvHiSKaBwJDJoVzGnWmYEmRTrk/sJYIaveMDlsvihEi4OC92uowTLJVxWYTN6DTjcnm3oj1w3AJZpocBjre10OtmTxyi6B4/kG0JmAGb8TnW2zbTCefW1R7K3GFB8b7bHT25pgXE4SbGMdWcY5ME8zipUp3JpgpBnzCsHwWyTchzHcHYVbHhiy6Neilyucdn0xdXJgMqR4hKP5nu4vuE1sgtCP2KusoAcEgMTfJQnro84EXdKOLofR/nar6EER8HDc2MDla9G1Xa/GqzZ7Ww7jtaMaZGQnQiO3422hmJ67m2eT7Zyn28WJmThRlfYcp6rwKIiq359ZqQinkn/v3KC7rZFqbuxdoefNvsWCQufxw6Ww+cYnwfb7YZtiUQd3sNkKWwq8pncXKjw4MdaHSr8lHAUEKY/diQXL4NQGcVocV8OltNKgp84hl9juFIwqYjPo7y+JO/l2ACT5sY6PbDxB30MlTmTEufRJsB9/wwVyhYgh8IbO09kCxO7Cp69dNSoUaNGjRo1atSoUaNGjRpvC1gOwENrCgAoqhClOC7Q21LPSSqXzAUWAuYllstetgmQOeTFufglOFWOW7479JeFaJRqPCryxJ3SlmwiqRC/0ZtWW8zTNZVK4/gX1ciCceUX8LRNX5LDi2Jn+cFt+W63AWUyz+U0HWAniqBpR7JUzxQ1AwKNstzQc24nENYWqbCVgt7SZC9SOGLLvQqOJWlraql4D7Eezy+wTm4jTs/CY+VG9thGLVFplNJaFqBLvXmnn5KML6m7Qw1bvvPNojjC8g+XGujkd3wX9lz+JYHa+8BQ+aoEcrLKOd/y/Vcv9XkM+lDT3z1P0S0+lguAHLw+XaOru+9rhCj85krtRZHFpCwPdXaGmmW7Xfo9pyJhxS6DXORved8Cb3HS1cU4KB9aS7SrIRI56SJC5p+uO1SWRKilPD5e496y+Xr8/nxespgfHiygifkkpS/A8uFwpHcy5hnU2CMe08if/TTa/nCyw7Gnj/IwqijoaZjtNM7WD3P6G/yTo+lTOThXnWdzlXIEXw7evJjEcS4E4bmb/ZdYf8OnnAX1fL0RYmUdPhVaZdPpvLRJYXIAlhz5WgjGwYLpzStozD4FrPREEttunKxhW7rql8URXYXsymFCXcwX7jBBhwdUpjnI4o5S/8eCEH2cFnZPVt1pMt2rEnawaXGf6mz8kxTAyQbFx0LMp6TST+1a2lHoJ5MBTJD6+u7rX5qwRWjzSsDvsHC6SDgl64myk5ftUSZg81DPn2E6hYrz5XkZ614xtzkNUHAxClhG8mZ9sWOgXFNq1KhRo0aNGjVq1KhRo0aNFwCWN+4Ju2k3KlSE76VQq/HLsRRl9EUtyqMlMOKKYv7bv5NjmzbVVA6M9WUaX/xlb4F3t/BJ3nQ2oXBWj3NokzaPs0Ce7k1KxS2UaVBlwYICkHnmY7VGsSfcg4qboTggdWPwGI2LXEhhVRSNHqC+hT+2PkNdSa/lVD5PKmZufw+oM882DgMVy91u54rotTwv0V7n2a6udryOpoDLk2/FB/CCr+vZC/yxyJ97IAMUYKs71d1uTTBhO/pqbdtWIGPsoaoWlJYS0rflE0joTqEQ33iRtAbtgvvqBGbkRmvWQw09a4P/hoplPOQXHQ4VAWYC3ERIyXw5GmJMSK3O9yb14lvTjQR10+szvA4Lk6zwvE+tfEfuvIykFgzauyqKc5Xw1S0VALVoFeHj2SFf8l0urjKK9S2tHqL9AkmGatnfuVBP+/0t7CICjikpAUWxkjlJVJoSHBROhvK6gFISpWfFsIpoqrojkwks2OjKTYDUSTY26jYlXqSyzTYy+aqjwJ8MIprV2TYd5nXLQnFNK+Vn2tLv0Dnds+8uoJd5go5oSO08IHQGDKVdA6xamuQLLRgNVetgwwD/45Mdbnu7vUbRvd6ePH3KwnwEj+fZDqejHY4H2++P9vRmT4/1d7/2hAD5zGJpK7veT3Z7mO14mu22l7pUIllX3juEpYvEDG9nV26HNYjPRUK+SCFEgiUVnAsO61YTWdLqvWZUUA+uOm5xz6E6RzvDR5j2NipaJ9nrZF27YtE99GWP3RTz2boAw+PMfp1xPNVGtNUg65tug50W2Lkx2WF/YF93my3nRLfBWttYu+nsPK1thuXI0HN3BYGytwmzZVGojwr/RQYvz4WijmeyGUpAN0a0FmRYb9CeBUm4vud7mm6bExGmdRTFT+GL3R+PBOPYfcPPqsg8+i4T1X90Sw3CZLcGSdekGRpezOhBJuV4rGzNgjG48aKsHT30UfDVz4H+mfE5Mdt5jEKumq24JhTiww4BjMlYpvSZKS91W2OdxnjErgEkRBoq5THvU+KHTeRFC93WpNnAjmltXbdhovM8yNpGycVQbdeoUaNGjRo1atSoUaNGjRovUrHshaUAGQRDQyVVqJazkDCBw4jFl9elZPleBaoUy7kwWlIiBtBMMuLV4vXpkVwtwq7BVV/8PRRdhWzUPSmziDJXKrz0os1wL6B48TPABB6uBtaWZamJWYzJJv3O05RqX4dQBVAkhA87i6J9/aYTfIl+KsFcvsLc1GWzJT9sh9utq+5K7Sk5GI8dha8Ki4tnRAk+03UUateEmpOMXT/TqxfV8S5Hxl0Ikl+Wlb3le0o8G9YdZWG5srXS+1yln/s8Q9y7l1DIr+9tmgyB73tngOeSNJanuLTBuHvfoYxOdchSK+Tfix5ZzN3ldQa8173H+M0uJFmRr9fr7wHicwNkP9rF6Es7EZICfNk7xe14/4SFR7ze7TKCTkbxvnSc0nOXYFnqfyR6AJePR4Fl/AQgZ3KIit2Rr4P1BdTK+0NvxxN8lLUu4PyAzAOUu1T2xrxaqlfDMnl598tdA6kYaqnGTc2Q+3+xGyA9rzbnlfOn1gwpluMB2I/XedHQsCJJxQCV2IFqOc5G33fslgDLnqNQYN6tgbYCmG1HWIvIDqKdG/pBJ1Uv1z23p/DzUrnMe3UofO9QLtbAcv4W45t/T40S55lsZjHVUWuYsiLZS9v7XqplQOiZFiPJWz2ph+M05TouuE91r4/F+2ZirIdp5eDQDP/lSCQV8zsN8dxfdxTKaSxFskkDJT7P9FkhD3deI+45Df9Cgeywm4UzPRGT7ZHymlajRo0aNWrUqFGjRo0aNWq8WMWyg6BN09CncVo3NkDpRHqs7cUAoGHtoJpBlMvqC7QXW4ov7Wuod9PGcX3XJrR2hWS2sQjwqqJOAQUBDKAG1BdzQVtYTUDlDBVfssOAR2kCN+7TjIM2DqKSyalUxOtVS0VlOJviMU6DrVZQul1R8TXPfbZGAFThtnKp28ZxsLEfbBx6eXm2XVKFdpvOlZtQuJ6t3cCPdWKRLHlirqXMpBKxtQdXD3hZh+ZIVRx9WimfywWpqJZzFEPFNhTLpWcwbUQKM4LzygYqmaFZhVxUvtZzJAwgUcRb/T3UtLHfzbrGi/wVsFkK9dj671Ct4JVRmCzAe9DJjEHLiG3iJWh8nshWAvq1UBznC+FfAMfTWKAqPSC5w5iLgo+Xx8/FGtVAsgXJVhklepJ6dqmqvlRvl8X34t5DHSw49zz3nn9etheBIIs7Sv3LoZ+SQnheHuPydnXwC2V8ul+pX/H3wV8jtbInT3z0sbimg79sRxIDJStNV2vsIIDnLd4l1SU8faX8l6oeIf/X7ONM5TK7Vn68TddZ17XWdqioqeJ8XIc4xwGID4SNx/0t5+LjNx6zEOZ+f7Lrm1AsX/O+B1phzPwbCvU9erK3V1+7tsNxsEdPTlJon6Hdh1JYyl5azcQaV/qaJ7Wrg3efFKEKFcd0WB9JDYfHBJiLBIGrZSOZlMa2Xl92ON46AoBj3OF8trIea8lE+2GuGxsobLuztRQTzyoCCiA7Q0W+pqIZKuVtjzZcUU2Oa+ZxMS6PgxS8Z7PNdsuipFDk4l5GGNEzHGTiRevWVg0eUWA1tn7kIRvJB3kO56xXJJ30032P6WuPOoK9nWkThH4buMsDauToi3XTpZ0ZeG7sT+mBzwzsNlFCLWCt7o/j09cqKMWlGEcf4PMIE8evibYf2rGREnORTPExy5djrfQH/LnxmcUHzUp8x0AkBVg4sfRl12crz4VXUbkPAbpslPB+jFu8Fh8rgsTrBPBVKBYJHBW5bVlANgpkYv7h9S3tT87WVcVyjRo1atSoUaNGjRo1atR4sWAZ0JI/obxN24DxX3wZ9y/o7sNbimkvFaEqRpYLj5WAMn9Rz6qsJCu+IGZQ3BFmhZLZwTbgaHzZLwvXZbaMbcTh85mkz0nlSI9mUXGxMFeGya8YX9QBUsqCTQ66+ZoVVXGAVQAT+Mltxw7ZYKHB47sVAb0u0/EFEWcoDAkmGtt0Gzs1gtihMFXRtBDUCgAC6o5hJcDt8FEMLmsfk7oUOI/qw9KmwVvQFd9qy3BPEOhuApSEeC71RDIHTtA1KfKS2ruwY8gjwUWp4T4b3sC5WF0+RaFiLa/3vlgolxdddAF0L98kwJTsFmIPevKnzq9d1l/LQPdSyb0ETuU1PVt5rddmheqlqviuktPtNd4EwquwW/RHSGSX9yEwLIAXYvPwQBYoDpjsP5PS1QuyFcrzmOO5Ycq5DLgcddAw7gXZSn9YjduchEj37cUTG9gMtK3mExIjHMysFKjEERSs42DThCTPiXYIx8PBbm9u6Zl8OJ4IkWF7EQplgL/RH4dDb9e3JzueRjscMZ9nG7yAG5JAmOc8ZfLcju4qditcyFN1/UV/BZBMy5AnqRbjQC725ZDm+xamw9lTmwrjYoyEiwZ9oCHobbzo38qouJ5XZ5sHtJd7z69hq+AFF/Ecjgc7D4DV1VmWIDaw7VEEMGxI0DBc71ZIGkw2N/BYbvg3FRktkkpKcZR3yfFBb/lS4y2SvJgiCdjDLxlgGck4KJGRzBt7jo8OSTO3BklWNViP3a9enxleTNATJpHcjLVSeRDckyd3UMAveSvnnTJJdV74FIfAPFm7BOuNYp5uC7PYRVLs5IjdKqVSnOBb4uN0zeneqBbXWOTnko+JZLvjBWvxOYQka9o1wAShig8imXj/7oUaNWrUqFGjRo0aNWrUqFHjbQLLAkErFraiOg6KYXyxxRfi8JM944v7TEVzqPj41RYKyPjyGkWQHGSFX6WAdUBhfRnXNmao5KBmCy/YADL60k+HUFdmQU0Mj0qolmWh4ZXorMVXaMekeKCwVN6GD5Q7nUcCKfgqN6tOX+b9XFDuNeAQXhcKoKXvoZZzxTAuB2AF7TLBz1UPgWVAZlw7VKNSdgMGANTAHzO8mflFnxYeOhFumf6s8DtmYbGZRcKk63NQXTgYr0qwVwh4Aw5SoepoBOfEWQeoNKHM06Z6O68bu6KvNOCSVK5nvz8kEzbN2rpG6j16qMa1+Am5BTsGTAKpZTGt0ku5VBVHsbW3GoTpPxnNFsW2wn4jvSzx2fxcUPHwMs7gtGDZUVirOIeU4lIeln+Pol4snhgJBNCkEhDfGxfPJci7hLTFybxg47IFwuaAfsjeBwkcx6EdQjI5EqpYv2dAwtI/+XyREcgexzmZo3+7PborLlnMjSeRIlUetvJeDpQYXBvrBmEt1cJYRwJRa50I7/Vg/Lw9+ovjmA1VylCdBlyN3RFYK/rjLdXL/emG9gf7m6eci0cU6usn64fJTsNkx36020PvSSBc59r2h4GQ+eYWVhiTnfrZ+hH3trLRVbPcJJGKFcbY9r4g39balMxEXKmsonE54RPzks+x3UvUeKlwz/+Kwp3hC5wKIBbDhde3WsmHmcpbrNNm68l4P+TITBLBYx5g3kko1rvTZPubg7XYntCctQMCa1uzts1GV8L+ck6egPKI+5tshAfwejLsxWg2sREkkoM5eVAWwoxr1/UrScD1rEiCqH0FwadhsLnvCZahWKbHOa0nNFZhdQHvbPwdu0YO+72djkcCaKqQsc6DfZ/l051AMcXAvpMG85lWI3Fu2IAoqXI5/1MCxhN4kaxLa2JR/FXnXz5yUcxIEOBzxBX8PD8K2OIzQc7o2LmC6xswJtvZxs533LQOpIvPyphPYf0R4yYVteSOhUjs1KhRo0aNGjVq1KhRo0aNGi8ILAcfpAUFisXhizuhUQAvV3d5UbmkVI4vuPhiHFuMkx2CK6d8i27jj/A5BgCF7ynOGd+LEwikr2RWZcF2oG07guU4RtZLZ6iccay+X7NwFaXCIC7YTo9t0lATu7/zChYbDpZj2/k0EywLCLfuXSoVZ3i50tdzBFh2GBEkzqFyAssOdcOXE8AMxaWGBtcgJEclJWCcQ4WsBnclWyF9S3AqfD6TN7NU5lQHAlafz9bPs/XsT/TlTIAugwjBP8EzvBrta9a1a2t9C30qosh7Lzw9Qz7JJwvCnfTR9my4fE9kkOrHvtDlhhqwfEX5svK6qIJ3BWFSKbtCO4mkszlpvgZaReRid/fB44DPpXdwfu7+u1ledH5F0KnVHbi8hOt6Oq5Lliahmk14q1DJ8hrRWxxWGIMAT1JuBiDna90eJqwpWPTSFaP0FXawHIkEvDKKc6bo4OmL60kaSU/jyFqAfr1UCAsARjtEEcG8RmQAiWsFOARYxs/wONeagLVlsNP+hurV4+GaoBkqZczH06GnynYYJutHwOXRDieAZcxiWNSs7XicCZdv96Ptj4DQs/WDpq1Kh+bCl2xdtncxztL4yWtPqFljkoqTFvY8nP2eHEjTIPue3xkql2L3ws+e8JjTTvMR98krD/XstBJYxjLka+xEB5GVBLxcHiY72GhdB7uRtc1dQ1sLFPULxTjmQqQtmDSDnU8jBTO9l0ltZ0HlYJlYJGPoxj9jbYr7iWRjGsNh2Z0TiRgrAMuTg+X5PHhfyLcbUBljA/4fGGNKKhxolUGwzMKOIz2J04mLS/N0pNY3X7+Y1sR53Yojp22yMj3tVEgWMJHYLKBysr1QMmEJlv1Y/nxOYLkqH7YmaxSD9PkJFblXlY3+SLbjLASYE7or2F7ELpyUjJI1zWXB1Ro1atSoUaNGjRo1atSoUeOFgGWogBHzDDWh4AW/jhbFf7idmUWi5NGJfcC0cQhlXXG8XAAtA5gQVALQQlELZFEWrsNPeSi7IjIpGrW9Hio72jVgyy/gihcwC/WlfHHxZV8wKiIV4XNf11IvCHQUNrFUy03wMtUDx8d2fJW0E1nC9WPLOL7sEx6MUF331hpAjQoowZcUPtDbzda6FirrTvDOCy2t/D7nDkBb27oBGwSYz8Tj2kWfiwbGfazgT+12Hg18Ul3pSfyRlKiCgmAXA5TRVKCfbTOZDW6BCkiJAEgM8AVQgeuGBzUUketpTFBLLZhaNP1X2CU2pufeTz8T39EAoe6vtFK4Yx9xUdxqwaOjMKADlHtenxIaDviefT3F8zF4F4y3PE9si49rKqwKfGwlJ4BLuB7D84Ihpna9U2Qv2taLCF4AcD5o6eEQFP0WVgnCuqnAo+CTA/JQky/8bTGflcRJimU+im3/AePc8iWgNt9Dqws9Qtsb7YW5Bi/y/nCyabtL2vtQYIeKMinum9aabsP5tnYbDJ0HimOaCNs09ITL8FQ+HQ+EijfXN9aferu5PtDiop+UUIF6mYXt5tmOAJXz2W73R7u9PbBgn9w1wixeBTd9JUt+4jHwmWCSscdCNVzI6pd9G2tf4bKSkx2L0ZG1r55IWjxXAMsYCcwZxDj0Lg1rIK7ZtPMAdPbdE7C3wJrCdUKpt4kbTmaCeA5lrMfwrG9b22zV97THYIKtF7zEOOK1oJWgBFZyjSk9WmIAcPrnBZMHbkXB8a8dHYS9BK0OSr1houX5HJTuSNydTp70munRg/GhcdFyLUXCYDhN1vuDNh4J7s62micVOQ23jaS+LwtTxk+fD6Es9rkXCZNUtLDY9qDXLe0u0Ga4Lqzj+ClLGfdWTrsB4hixpl8klnA+fLbSIzllyfQU5jP6yceA7C8EmxM8Zo0BAXpZg3jisyqWa9SoUaNGjRo1atSoUaPGCy3e16HAD6CpQ198J3d7BSrEoJxiAS1ZSOB7KopIhYVGbG9fyLdSSB2GwlAIbPH1/e8s/gRQ3awn65qV7TaAxwKcYhmUoFFtu92iUJGDZWn1EhyIYmjzaqJ6GqAjlHyBZkLFxYJl5gCXYjuoPAGJzzb2Zxt6syO4xnptmx3ArysX4UM6mR2hikSRL1Da82jDaW/nqbHN5iVaD6Ao326zs+2ms+1269uVN15cTbBgs0Fxss6ePm0dBp2tn6S23OL1RIQs/Vdss0dfiIWhffA+tBXZIVVysi7BZeG5fsLzuGYVI2vasx3GeI/unuprL3jYNitC5asr0PPZ2mGk2jlv/16G/ua+vYk+l7jlMhwuJ9HzPere+96T+jnSCKXquACvfn6NrVDxZhuRpH5esOzSNHxhLh2IdHkZYfjhFjEFmVrYXGSF4112njyTU04kqyDLnwGzc4LBii38UFgWHrDJ51ZJopVvkUfBy6UvcNyV/I9RoQ4WClJRyreXHGoB7uHT2hRFAGUZAouFKMpHiOeqSfk1T3Y6HG3/9NZeunqg3QTuUc4CfF5UkEAVBUM3W2u3V5zbsGeQHQ58c422CGfYPpxubeyhTj3YzfW1nY4ne+0LXrfj4WRPHt0QGq+6ja22OyZSzueGkPX6dm+n02CPHl3bzc2BRe9G3ANujnMS6lAfV9FtBRQmoMYvvN/CDoFrYR6nMQfSfx1eZi/uIoGxGP+5KGQkv+LkXhIvJe/ib85r+ejcJgJ/5H3NgK4OGgeBZRwE6mP0965V0rA5DvRMhn0x19xuY7sHKyWiRi/kOAxMUOCyum6yFoUORwHTdXeyppttvdlyt0SLdcWL663bzmGt2xyxmX2XhKEQYKRtlN6jgh5wGIX6Dgc73t5Ygx0UsEXx4nRNqweOP/Sz3d70dnvb280t1OpQN/v8xBjzuY1lTgri9CmQPi9inKuPNfBDoe+fFMnrmOM+5ZH0WRJJPw2NDJUH2Cj573PxU0UBYx2LtEIB2ZPdUOy2UCKSv4dim4Vr0Ziwf5LFknbTIBOKe8VP7axBYUNeO4y1a9SoUaNGjRo1atSoUaNGjRfqsRwl00oBVWzvLtR1lnY0A3QIyCz2Gt935LIQVdgKuMVGKHjTNmVXW+qFvqU7+RPr97IqmTwrIcN1yMMv5lG8LyBaiTqT58ZdWBpewvyr22qsZLGxXgs+h80GFWwsggVvzkGezQ58N21HtTIgLYuQFT6i6RJwX34BgBayDHAA4QrXBEJC0ejsLh7JPoTHRLtJWTvE7VAECFAkWBSPBLWSg4dOKJuCmXAP2+qzLUMA3TRKCjuMDFaJYXAN93HlRYG+oiDghc49qwjLgVjA24Xi+L7xVjTyQp5cHjKD6fJfC2Dt954U1uUho11KuHyhCFy8p3jqci6kcRoqxgSd/fh8Uur8AFpM0KCvEiiO9rx7n3FPWcwe6v7ltSZl7AVUDruN2JlQWhks5lWyucg+wyx0SbjldhaF0jypdXkfrliGHzTf7wVAaaEBpfKJ6ufT6UilMvx0jyjQdzzZ4YDH0faHI/+9ovWFq/UHFKSb7HTqqVLmTgNXkWpuax7i/JFvSMpuT5bk/Q2pNRbjrkxxSMh+Rz6fXhPrWthiJLCarDMwd8p55yPzIlEX/ywdzeO4i6lTvC0Us0oerCgGBvCkBtnXk5h3WE9hUcSW9M5kXwJ0AtrSVz68ewGuJztPgO+ybljOhZwkiaQadpVkwb+SjrwSLxIYIFfFUOW7TWsU2qNI+U6FPUA/21zWTMVt37GfCKCMhEw5xhdjvlQjFwX3cq7UYbQD4ih2qevB2PJHPJcAdmmJsSxkKZuX5dKSEzoBk5v8YEFYtTM/n9M81HouZbX7bocdx51Eb40aNWrUqFGjRo0aNWrUqPE2g2VgVG4NJyTQ9vYFjCWIkccrt/xS10yEI39UQIdQUKZCaw5a3BxSyj+8V5Ya8BqGTQRUVtymveauZwJSQVsoqVtrN2vroGTuAipLQYj38Ys7lNTjbF3bWrPayK/StwZrq74DMIcUOK+uxxF0Eozi752tbGtmW1utOmuaB5BIW4v2IHSFX6uKnQ39QPXeeo3t/p19yIe8yx7stvbKyy/Zyw/fQbUcbD14Hrap1KYE6L7VGdd5SzB2slM/UrHcuT0I7p+qSoAf+CCvUFwviiGqzwKPt2uzTdOo8OAUPstr6+c11crj1Fo/NXYYZZ9BfTpU0oNgyLltrCHcm63tztbQMoPl4hamDOXPJdrSI167jPssAO7+nc+VYOoeD1q9JihpTjIsVKElA06HkK/w4rxuaUEI4+M4gR4mFBL9zseJYmquhl7ebIm2/F8ho8x+B8n3Oe4xH6ZUMgp86v1RAE7XDDUwPMfDpobOMA5C1f5ykQ21pcZfKIQ96eJZDaotUbQsshn+PtmJhIe3ADCB46g5t7h3gklYqHR8qLAm/H1RbO/E9YRgEMcgDFyl9m3aLQtzbnc72z24UoHQ6eTrkKwvbp48suPtrd08eWJPH71O9fGT6xsW4nv3q28QKD95fGv726OtN62tthuuNdO5IVh+7Y1rOx6hzD/b5ApQKLrpK845GEklgFcUjHMLGO7MgL2Eg2f2G0CsACnXMoLFUqm+HIGyJXBltvtzhx0D2oDgkH7Y6n/ak/j1BYbHnE7Dr/AcCq/lplDg4h6VoFMCLHzHoWQ2esKv7TCsrGVBvMlaFOtcaZdIgFJYUez3eybJVruttVhH+lNKEqDfaVUCqTO85Hv012zTuuU4bbFacZ0MoBx3gnHQ0hIj2/Y4WEZBu2FkAgG7KNAPm9XGNtjx0XXWbqBmx3YN2fQgcTDyphrrNhg3K24z0ZhXAmS1mnT4VHw0j9/5rJ0o+mwImyQpqrFeJssj9x3Plhazjes1x1V4kKNNj0NvwzjyARsSnos7SOQ1LhgNGp6TnrR+8nHWuvd4LB/8W7emF/bu6oqfbfgdnxvDcOJ5tPx6UU1PeMIDm7ZKtGhS4UxcR40aNWrUqFGjRo0aNWrUqPHCFctL9aQiMJe2ia8WsKzc+u8vy4LgOOr9Nar89aFYloItQz03PHAF4FIJGYriOGFcQ1moKKiXq7iSRLB4BJINiZ/DBynkcD2w3YBKrHMgPVHZWBbdivuXCk5wBv7KgABd17oPpm9/Blwo2ihZQgAG0JMzK90yoo27KH1YXa28wKiCgDh/KSQGfAlrA9ptAMZ438lR22EJVW5+cQuAWvRT9EeBjr310zUkpHwPGI7icndEw0XcAbVfKF/QbCexVI8uzlS0XIF1C2/u9E73hX7GJReXeg9Uvrz+0pv5GZeVEP6lkPvimAEn0wuSZDbeEMrki+Jh/lp1d9EGl9LbO5LwYrS5zFO2ItmiJArzBfxcAu240mhNgXFA5abprEGRTPg4O4wDYIb1DsDycII6+WDHw94ON3s7sWAbEjEDlchHqJnh5zyMPj9WCRqOIwr0jXworSPFa9xTGiO871APZ8uWtDJ613GN8qKfhMoXivR4R5rrhbn2cmyVBf7ueY2fKxIt5YqcFcrZCzsra9W/0tEu+1hrAJJMeg0Uy/Q+BmQNOE2lK6Ak2rJJRVQJRun/K0Ux1eRJoey/xyOSQ5fDx9XJUh1f3pfOgXMnu4iioCPGBpKCSQ6f+kQ2LfJ0Lq1pSjVyaZfku0KKFSs9Z6XCuPQczyrm8Cym+huJHK7X+hsLoroyWcfwpKb/Hm2WEl7+eRe7TzhuOQbDe9wV9fSYxj1efP6U8zYMU9g3uXhg9pOuUaNGjRo1atSoUaNGjRo1XiRYhgLNv1wnO4YELwSZ1isdblrHl3+AShULiy/q9O+V8SMjikZxu7Vvu4Y6q4N/pvtFNlDLzijMJO9fgRwpsQANUABPxcPcZ5j/xnO4BikpdROCSTg5jkubAvqD+nZqfJFvOj0ASeDfSsdPd4rFtWxa215d2cuvTAReu6uX9JrpYPMMv1F5PzftSirq1uzqwdq2u609fOmhvfTSS9ZtOheQ4Z2TFJ3eEwRv2M5MJWSGv4S+9EJG+7sDdFLkCjOHBTAeUC4DkhFG4zxeNAsNM517+irvT2c7oUCXg5X1erL9cbQO99mhDaC6hl80lNcrmzdrO54mwjoomdWsmZwFtwklK+8QJrihdl+AjjSyspr4jmXEc2zRdmsIP1RKBoT9wyU8zkkKqXbTmZJ36f3nEPCK+71bYDAKdylxgT7NEB8NdQciF5BQyRCfIQ6cOI6L60n2H3geglj3I5f60q1TEjhzuOxzb3aFcCQcohAkhkPrc1Fzer5Q1Oqn1PHuk+xjhXa8qcCelMyhfo7rlcpZYA3vvdptbbfbcN5ABUuvX1i8QD0KWS1UsWekNFrbdg+okN9ud0zCQCm/Nig+T3a4eUKofLx9yoJ9j1573fZPr+3R62/Ya5//KlX9t8fBhmm2p7cn/n48r6wHnIb6ljsasBasbGARy7WNs37nJaNJvP+ScDYNRVdTJ79tQUbqwKOgqINl7d7QeJHFsda5NIx8aAkcaowFTEzD9nwXvGp0a9eCV2Zj302Ervk9XPcIO9F/DrOnyJYJYBKje79K4Ho2GK1j/eA6Mq6sY529NS1GDnvYXnTWrs82da3tNp1b/EAtjMuTAzE9smGXgfUeqxXV8tiJAiXtaOdpUELOi8Ly35Q6635kKRG0Fusf7IBUEBAAdbPdWLfdWLvpWNQR/s88FlTR2KEBdfx2w3V392AnpfcNQHj0X04eLpIrPr+pIo7xkGwv1DxUKvtPzCUkJY7w+U6Jv8lOw2gtm1/JFIzBHopu2Bqh4OQ8MwEiwBuWIfqpIrVSyutaoZ7HjpFQz2vXTuwWWQW8d3W3ckNe3DLZrehGBLeltGb70pdf6vEaNWrUqFGjRo0aNWrUqFHjhYHlpRLsQpPq8rs1VYZZS0uUkCSygrulAjCpbJPC1pWN3L4PGOOgOFTJoWAOSwTArrUUaQI5rt4NpRehsl5DWFdcO0Fc+mpOcu7vc+WbE0piglA0exEs+CIDeAF8bTYbL5LXA0cIkNPvGdv48VoUtWpp2YHXAnYAlEc7JP1xME6/R4G5AHSCdIJ58kUG/ZJ/b7zPkYjzLl6Db88OBbnUbKGygyXGbCs8HASzsBQLDprNXgSQBacAUlZ4aPu0Hpfe1OrUS8FtKJFTp1++o6Rs3iLlP6NPL59b8LZ8IFfcLp2ZLxWfJVSSR/Ld16UxWsDyMrJRxP1q4SgIuFDvXpylVBfme8oKwqw8vXueu4YiWcl6d54WKmcHjAGv4/rCZ1VQvqwTmAG9D590KPm1FkX/LvxmS69YvBmFAvGQDUPY1YStRiQG8BxUqFsCr26zs80Gy9SJ7uDnebChP9rY93bc39rQ97a/ubXbm1u7eXpjTx49JVDeo3jmrMQJrV4AkJGwIYjVyTCPAF01r5TI4bVjCoRldAGVI4HGpISrRgXs5MsbSYtkaeI3xbUo1qVCC8/XJIWu91t4KF/Wibzoax3L36fadk6N3SM6QXH1MZ7mepC6JCTWsSvDkxuwyhiwZqwITnEW7JZAfghQF+2NdWEYtOMikgnlapC8lR14srd5Hl9t3XdZTBxrYQbnWqOLnSQByR2c44GDwW8b635SLCODR9WyEgRILvJ5jDkmJvQZoQKK0eb3tbC7qbt6O7Wj312snbK/cNunCRYaKt7HBB0KGIY/NZKbRYE/WUh5cU33i47EU/JXFltXMsttS6hMTgmoSEZEZ8bnbd7FoURFrCNZfy0P9lzs883bokaNGjVq1KhRo0aNGjVq1HibwDKUx2nLbrCJBE5C5xXgabUoHKRtvg4dXD0cRa5C6am/yZWZKmao/xaWFkGeYiuw+8c2rXVtR9hbAuZ4Hy6Rdg9QvZG6CgTQg5KACIqy1hXQ+j38ZLPgVDYUx/5kh9ORMPrqwQNXkrW0sWCBPWynJ7ATTH5wBf/PtT142FE5h/fsrh5Yu4bCunVwOgke0D8T4Fr3ENutASiGAWo9GmsSfBPYAw54k7BIojseB4hHG6JzR6r9tC07gWkvsAg/Ww6CBiB/TRXz9f5km2ZNtR3UrIDKhEzDbP002AmQGR7NXtAvYGMGmPdYPJTQYgFoQ2Ks3+5zX748SoEq39o+IraTL8ZQcDWpNon83sx+IsHypXo4AFUJc7Lcd0mGLy1h4pbxN/rCJqhWXMfz1tO6c+mFx27MVSoTpQrFNVKx7OpMWB2omQRLAbzogR4qW9i3dC3fgzkj7uXtiMRLskmRd2u6fE9ohKhWNjDwWW5sB79kKEmvHtA3udts2A6Yy22LxAuAMuZ0Z02LxNBo/elg/fEJ7S6uHz2yYehtf31N0HlzfWu3twe72fd220tNOlqrMU+dbPhgyz95JDSdbH97YDKFPuLF3EugNOBi0e+LcZCU2ZeuLDk5wGRYAsEZLGeYupbi9J6OjzUusF/yB6a9iDy2vSPyGhLQmD3ilhFMvmE1kHocTzNZBjDq45WOwquztWHlgZ0cbg0Szu9cE1Hcb5jsuD/ZPM52uuq9n90b2tfUBsX8kNTDHGF1VUDeztbdBts5pAjHqX29vJsI8oGTskt5/tKKyD3GlcSRwlxjWApo+EAfYY9yPNjheLC+7xfHBrwGsFXSMXYaxFwtLEQKDh+Ansp2V/0zGefFVWM9WLsyeLWOhE3+3NRuASmWS9sZnS+SQ7C1gPWL+o3wvChiu9yJgc8z+b8DrrPYJT9LWhaORR+hEmM4QKkYZrHjyD8LnmutqVGjRo0aNWrUqFGjRo0aNb6wYHnyIneAFQkN+xf/UEdRP8eicVJW4Ystt0S792QJkalCdmmW1MdRCM5UqI+/F4rcpGkOBCeA2hJG6YEv4OlLegEcA/74fwQD4BEavpsASlTJ6TrEkjIknQmdoX7sCZa33ZVdsZDYyuYRRQmxuxueryr8xIZtG7t6sGNBpasHG9td7ewKIG33wAzb/VFAkFvGpVol+AkKR4NTnZv2FyMUcQLLoH0JZhUolkpptxTBYVhsS3cgz09Cw7AwcEDifs2ASOgzKZZ769Zr29raWhwDENm3fR/PgMvwxXWbg0XcRRPJuzqj1Bg5C/h/6aX6nkTBpZ8dbtWQwXJpy/EMuFyoicvnsvrbz15ee7I0KGl5RuH3wmW3Nbir5c4KxLsuGlKBL/x/i2soj5DmQvF0WNng3pu12254YbAZSnQA70KpDcUnxsFqdeJrVEiOxgbpvpJydaFedrBsBVhuG9tuYYlxxaJjuwcPrNugQJ+DZdgatLAxULKobQDNkLTZ2+31Yzvsb+3pk0dULEOpPPYjwfJ+f6T9BcCySHZL/S6uUSpRXAhNYXj/sCF4fH3LeTWfYZ8QPrWtrVgYcEzq1Gi5xQhJOw6yUrsE+nwJnR3cCsP7OCU5vGGadYDMQv1dqE4FpvNYE/hPmb3icrCjAxZAC7lyfg5ro18nWoNq2eI1gMoo8gjnDlzrGtYgwM0A7b4u8dKwg8FgLQIV82zHoywgMH50SfLEhz1FeLdHQmwNsNxu7OzJMY1vFbmTJUcxZVxdLO/+Ionj61XsIEF78BkH2lgfsR5jfT8ej/44EDRHm2ot0DVJxSy4nXyIk0K9+HfqaynbVQgRnyWCyiqAp+tqvJgf4D0eUTjx7J+jeH3pJ18Is32ceXuhKCySKxfe5HgDkz7+7wDL7Rkf524X5YmM7EAkyyPN9VgDVyzaeHctr1GjRo0aNWrUqFGjRo0aNd5msEwPSSjBxtGr2ef6S1S4wQ4iFJohq3L1FyKEf/FTfq1LJVr4ANNPMn2ZFhBMX/LBVWiB4UwYylx4MHsBvwSRC+VoKP8WNgipUFMADh2famYU+Uoet164b3W2oR/scNjT8uNqF3jStxdzi7HgGuEuCirB/7MV1Fk5JOgAXOjdDPUflGthKxrXFwA0gzpuoXYrD91WwPgAydmCItpJACHfL+6T27f5Nmyt9uOR2+Rt2FB3AzLBTkB7sh2mQGnHAlTlTnrdK2CULCVKdeddmJrtPxwtL5IGlxGWFveZWmQYnZ4t4UiyqxXAkzK3hMpxpPy4VKPGTcql4J7ri+JpeZilyygZdXpNCajLa+VNeLok+TL7QRbt4+2R7BSW0LuoAXePbUcJe8u/F4A/qSGl5kwwLFGvgFJ+dgdn8HqWeru4zrKZCmuYltYwMQ8Aj7d84Pcoiil7A8w8DDQU4dvbPJ/s+vEb9gRK5ePR+mNv4zDYCT7Kw2iHfrT9abATLAjgr+5wE8AMCRGok+mLS6X2zMeph6K2BJ9ab1AITckx3O95USxOwNMRI9ch97gu/cMd9pWF3/gnb6M8oh2WLvIueU3JYyyQapHQ8KSSoKrmflmoLfVZFElMCbK8RiWVqv/H3Sd0T4UaW+NFIBai3L6ftKbBI9ttctbNbJ3b6gCqxnwJMJ/hrAC1lMVq53K0qK0udxgoARfJD/gor7H2IRlGNa+KOhbbHnxtB/ie/CEInqC8H7/YZ5OuMa431Rosih5qfSxU+m4pUTS8q5IBjwGyzzZy7KPtRq7j8d6FvU6a69muJorWLq8rVwmcYbTuO2yWftHLRE9OUHnbFMVY2c/++VLRco0aNWrUqFGjRo0aNWrUeKFg+enhwJ9U+FEB6+qw1YoKV+rAwqM1QId21bM4FxVTAX0TZCg0fA6V8SW8a1vbQK3Ignra9iy/Zgc5a2zddsVyu7aua2g5ASClXf8lQIDdgPu2FmCZHphQMJ4Lr1f4DsPO4gwlJ9SV8hENkHc8HGw4HQlVXnnlHe4TC+g7yQoD9hruF4viUVDtoeAg3A5YiGx3RdVyf5ysZ1G7tc3Yiw6Q17qZBbZyE7qgZKBABuDZOI2yCiFTddAMAE+lYm5K3ylP2N+6P2ds/R5GwGH8Ki9S6rsdPswAH1BIO8g4EsTBDkP3DjXzCYCOir1AP3LExZWrRJ/am/9dqODO9xfrKywD/JkLO4wLSnteAtB48+JcBchNiCs5bpRQOQD+BTReXuriSAKL5THs4l7S4Jag0i1wYSWSHE0DdkdBSbddybYyJXfOgFG2I9lSYUHVsy3v8o+FolbXgZkYysdIwqgoXyr+N2kXAlWMsHhpckG+BgXdaHkgdSrHC9We7i1c6i594tBr3NXK201ju21rVzuo96/swcOHLGi53e2oFEZRPyiVMY/XK/gpz3b95N22v3lsj994w9549TXazmCwYb5dXx85N57cHO3mcLL9OFuPQmRQzK5aJsAOx8lOp0HqUcylebIjkmOEzVwdWHiOcwmQkqplrSVIMikZc7aZ60JYami9oFvuAlSGJ29QvOjFQsmNuRcdlsZ4/lWq0ujJy3KSeTyElcIiAeYvCb9sJp5oaYGuBFSGuhbrSl6HF6PGkzBefpK7G+SWg3WwoR3ONAy0M9FdzNb3OB48j/UTSYJY53FOqXfd9znmXfjY5xEqdXu+EIfP/jytOWYqnhsUpOQdz9ZuULwPRfuUmPCGoar6PE42DxNtO5CEwG4S+jPbyjof72XCMv6NfpVvspJokXyQN/FcPOfwlgA3T2Um4Gyyvj/Z1Aw2zSg2GGMpCixGP8acFthPSJ1e0drBEnUKaLHkYBjvZyG/aEt/BFRWIrI4TxqXSpxy7GIdRznMCfNguLe4aI0aNWrUqFGjRo0aNWrUqPH2eSw7rEgKLvp8ZlxHdOtflFWe6F5st9y+v4B0YXAhGEqP5VBIuhdzqaKL5wSLdZRic3FSAC7OWyiC736PzlvR0yNdJkBXofokTBBQcCybFL9pGzx9jgGqAAoAgKI4YOzYXtmcqHWGhlTfEbALMBG2unwudGmpOKHbX2T2mpWVAjPLQnWpOJm3uSwNiMcFox1MEyyR37t384V4NsBYoZNLvfg8eCIK+pXYLAk3F8X+Ft1TaD3fg/OVRSddK53+7CA3K4WLk8W/kuVEFBYLmX5WGmYYXMDj+PVNlMTpHW4ZsXhd6WV+77sKuBzAPRSuhQKy5JcJMgecvqcgYdlmoR5VEbZC35mEll5QLLWh2qqcA7mNAabXxaMh0JWFjVSnUthrTMG64DyPdri9tdvrGzvc7u10OEpxOcrKATYM2EHRwyoGxeUAAn2RiKRMKHNjXMsGB4kgQHa/owTWl3YMpZVIgsNLhhsNdk9P5femXkjLnXecm5MTLGbWeJFYWI6CcqxdXE4g7MW1pZ+lcjjm3vnNSlGWhergq451fdaa4D71KlIHeJzhaiSO8owLc2JBWG058eqIaWpCAX9x9kLRq6HsinI8fB4vis75DxUrnW1CgdHBVcKLdTnW8VA35wRLVib72l480tpZwN18xyrWGZ9PuD+AXfVrAN+wgrozTJbzdDlQUpIASdtY+fTwtEMY7UeLJ8hdJDkWVkM5oUeLmPS4uK4aNWrUqFGjRo0aNWrUqFHj7QTLKNqWRYkqXgQVMjxCpUpc27ZtCZpGKHh9Oy422RIaUKCpAllhI6DDBZoMsCql8Abb5R06IVAgCd+hJwfOAFIdVcoBVqWwTcI4kgk/Mgkq4LC2cANIQHUGFVwCj+kbe2Gw6QCD14HznkcV2sJ2+gF+s1BIQ/eL7dajzdPALc9QBXZQKG93Eh+3M5V8LLplE2Ezrp3eoqNsPWARANVmt92poNuMoli4Xvg4D3w0UOmtz9ZBjbyG0lJqaFljAKhAWYwWx3OhvsT96p4CvIU9SQsrWkAgV3mfqM5bU604ADZwu7/8l1HcDL6167N+x520aD+oVpPyWZHgZqkmjoSAq3gX1hnhZXxhp6F+mO9apqRxuEweFMRzAdXL12ucpUGSkxpRSPLecJuKRGeXp4yx6xhWwy5u2L3HoRAvIV+i+gu4XKqgs59zVhx7UkGly5ZXeJ+SulBSRtIBrUm1qXvJBFSLImihKg37Cvwf1J4ZZ2WYyXnfNhxUkewJuMwHC6TpVjFWt9vGtjvMi862m41t3WMZhfy2m842HTyVAQdHe/rG63Y6Huzdn/N/7cnrr9ErFzY02BVwPJwIiPeHkQUnT2d4xa7sOJideoeCVG3SRIHjFlAUgBTr2M3t0ftUyn0VDmxshbVkjaKWKAA40Vsd1yIVK34/2/qsOVmCW9lZ5D4JKwnu1YjidCwqJ59nd6NP405JE/mgRxvrWbSrbHkypA/rERXii7Kd/J3rn3vgFzbLLCTHvpEdBEF7cY60rgfv9h0nA3dJmO33Z5uGNQv7he86Vz20/QljAz7y5uPGld/s/8lWM9au3ub1yqbhqONDYdxAadxYs9nkdiztHmKORnLxjDVzm48LyL1ukpWQu/pToYxip/Dfvn7y1I63B/5NOzxijoVKWg2Bz43RITnWbvV32Afh+VAzexHYZKuhsoaLtYivm2zsjzZjfUYS0VrNZXzccD1A64XNysWSkj57NJewnmPIwP+ayQ/+H2D1mtc1cfBx/46Pg5wkk/WJg3VerhdDhK3JPKnoYKionyslWKNGjRo1atSoUaNGjRo1anwhwXLiDoVyqtypz6+3DoZRoAgwWT66+Yt3sjBYyjMvvtJmwKzHpWJZ1gEsVJcrGPn1RAGl5VXHecO3MsCOf3/Pr0wK7Phpi2uxWQUG5WGpgn74sk7FV4CKApbSExRgocnF0XIBJRVRyorOgHmFMs9VqFL6CZSvAui5WjQgVyrMRUCoa8j+y3GTS30j3gtgEerUqIklhaegHFt0lbfcBzBciF0DYhV2CKGsjJ+L8XPhrXoZpWqzPKwKeS179743X+r/Lo5UwNf7X/Ws50qrjhKwXr4m/hrAOANEfzfFqhc34mM81OgB/y9PEJYBi+OkQ11ctIti83vv3hfHo3sqJ/FnUsQHfC5sO4oLzr7lS6XynXZLBSW1EyHGN1XLXnQTdjXJ5mUa7XQ60l/5uD/wcTqdrD/18lQ+nqRYPs30/R5XrU0rQMbsgQtVs7b8awVg0T6378B7VdDOQd3qrspWEDa8kkvsJrVvHgUXnRilIUtS6B7t+ddCxXyZFymHaaiSi0RNTooEGo6D3tv0F31dQsw7cudirU7GzMnSgsVDmVxbG9x7aO3AQp6ln2+pdPeDS+KrXR4AmFg3uSjKH/ruRZa2GIXHj3twc00N9fw9kxTXAl/lsVAsu2tIcbZCtey/aqxEoT3vY68RkL3JNf/LZT5d5uL6BZfpkYz7RgJOr1x2WeFhXs6fcpbJRxuKZRW6zbsvPCnriuu4nzjDcvdN2ULxeVkU3KxQuUaNGjVq1KhRo0aNGjVqvGiwjC/0gk/lNmChAajY2kbeyIBR2GoO4EA/Sv/CHnvoAwY7L85gFyrgaXQ/4pV13cq6TWPdprXmAIX0RNkW4VSzsnYDxTK8P+W9vJrieFmVtqI6EIphKLpC1dVK+eXGxNJq4Z7wLzkF07t5bbbxQmLbdkP4NU6AVyj6tbbDodf2fXhdzmfaXtBXmdv6BdAEzda22axss+kEzggcWhbAwvHop0loon9TJVqqWXNTO/xd2bZtbNs0tqFyWX7R9EOdz4RtTXO2zRY+nFD5rWyzRaHAlc29ICFUkOhDtN+2wd/kMw14tIKnAK1PhB8AnnHODvYFbWvHCQUcfbt4YTEQCrsAwHcg4wLg5wqAGb4utZ7pPYmZLVFeLnpWwqX4dyldjz3yDp2peM9b4hcWziHxLMZ2BoCFXUFcjlsY4H7CUzi2wodKOI1xwLQAZmlr/gX15ZiN4mb5uYUFyHnpx7pkmyX+dBCO/vN/S92v+1fxNwloz6vZ1jPGYW6xy6KQcf+EtgX/xLzgWSGH9535KgCH16ld4APeYby2G+5EWGOOYAxvN/bgaksFMzyY53mwJ2+8Sn/a1979eXbc7+36+taO/WS3+5Nd39wSGO5PPcf6adC4P0MMvDbrARPhpTvN1p+gPMXfoNo06wFHvfgk/GvRDhjPsuCQRy12G8wD1qHJ+mFwG4jgeWi3IPPLxYvwMKnIZRkR2u6s3I/xhfkv5XKC2UXxQynNPWmVelK7NHRqL+jmEyCgeb6mSN7AW17XEEkseX2Xo+4e+w9/Adf1prGukR0DlgV6dPO1a9tRySvLEoixJ7QzIDN+Yv1vZ1uhwBxUtSPuebDV0HMHBN2BMcCm1mxyJbcrgDEGWN3U10W8XmsLFMcrJSCkjU6qexV51S6aceit7wc+Tj28tGPchuK/nDa62YDRVKpPUiljjYu+p6d8YSehBFwkC/GZNHHNhXXLOMO4+0ynDyYzpkm7PdzjH0lXjM8Au+Ui5BtL+H6oiMOaiMsB6xUqtcURlorPyqpkNWGun/k5iM8oJga9WGveaAGPffhsYy5gN1HsPLjc/1CjRo0aNWrUqFGjRo0aNWq8zWA5FGTYrk1FmYONUNAKxgoGEK74l9oEHhdbud2jtyxaRaUYtu1ClSa4HH6s4sTaQh4exQLM+VxkDHxJbBWPCxeso3exbxe+LJ4l1VbgMnkb057CLT4AWLjtOlS8Z+N2a9z3xuEii+EBWLnSWKrMxosANoLufj6qQR2GJDVeFJRKAKmwQCj6QN7NAnUorEevZirPpFzD9mZp3FToD3C5bb3gnre9FKiw1QDsgy2ILDDWooyCLSFo9GtBWwP+Qa3YwF/VrSEEOwv4eqnWLEKQqMRa3v4lUI7qj6V6s+Svl+D5QoWcTpsoTfHcpdKxHCf3XO+Cht+jUU03S8ieldKl37LY94UtB/+qQmKLg5Xgmi/PtgiZQ6t9StVrsr0oACH9XZN+NKteyx0A1L1SiCoot3QSuYCOfuBSTZnns2w+0sPBdVyB7DDWHG/0HaeSH4mXhgkX/MR8H/rRDvtrOx2Pdnt9TbB8OvU2oihfPzGZM0Cp3KvQZM9xS7THPlBhMvyEHQKKtcHOwQuqAY46WCYJ5TXICkOwXesIla6ueF0o7cv+K61bLhSm0YMCvwGKszKcsN0LMYY3dRqTi2RJ0aJFPyTbC57Q52Bh+61z5p6PPo++jIRMmZMpPb3zjQSQjvUgkkYzbU2ikCqLAsJ+h39w2OkV7/gZwenh4BawFOAVOz/mhusBE4blwOVnAF4TG0uEUrnyesYDUDnl3gq1OR4AsrBuYR968cUSmC+SPaWvfhRphAUG7S6UtPBX3VkVysSLlNSweArldvZrBmyHJzVtOGL6hh//cpXx/nc1sSuV2YdszuIz1+cvV01XWNOaA8lVwmjNg+SrXXyOqIhlqPFjHFSwXKNGjRo1atSoUaNGjRo1XjBYRsjaIkrzxXb+DMQWm8Sj2BLgb/pWDyiaN0onf89wyAQcoiIYqsPGxnGgihlfxKlyW0GFK+Ul7SWgtvUiXTpKKACl+AyAKFsJ6dwCzAhi64t2eOeG+hoAFTB420mp3MITdN3YwNuHfzTUxmdrCA4AR+ShKpAsX00ofLdQVbeNPbja2Kbr7Gp3ZbvtFQtXCaDLKxNKQkAxXnozctsz4IiKZgESCIiJ3XjZJsKKDKfleU0NcfIlhVp5O68J1QaISvuAuIAWUFnrfvFegBSCV6dUBHb4CfhDRTqOt6afMt4zuG81dcsJ/l54WMSfLuBbUgvSISCbDWTAUsBqP8bzjtAsJ75QI6efLl9M1gfFu+/zv0hK5ovzBBgq9qHfxU/lv5eKUkL22NpeXHMJsEsvX/aFQ7JUaTGI8eJaC9/nfGP331fAbxx3XVx/8fpIHAGUzvBTpxJSJ+J8PWfP2fxY7lCQ8hXJFcyDrW22O/qP77YbeivD+uL66VPrj0d7+viRDUNvI4r3uQp/mM/WQ0U6wb8dyY2O4y683FFEjt7nWF/ajnMTalW2F6AzB7SumbsJOi179Bumv+5k64DKBMuYcxkaLxTzRduooNtSJX4xJIpkgYrV5TXqcqxEQcV4oyukXfWdVKXM/ZztzOKDoaIXbMSejKXiP195FHuTXjgD1VDdyuomF7IMWx1N62wdJM9hQFc9aBmC/ul7OzSeLOgaKtKhHsd7ubatsV7qXnMSLRPfKBAJtTI9lBPIx3X7Z4g3XKhvtZ47ePYCe0gq4DMDSurTMLK4I9ZR2QnpHZcJnWh9/Ff3lxNaaVZ4kgc1BVYd1u2znceVTUxCIqGady2wm+A7jvW7wVor9THgMqEuFe7leCkGTFFAkKOGFjG+06BQwHP+4fMR82DsWSSQn4sOlrkDJxWe1A4C/JufKW5xkvr9GetDjRo1atSoUaNGjRo1atSo8TaC5aguH9Xt8WX/goeFEi7+D0CX+9TTfnKp4AoAmbf5A/Kq2BC2pAMssxAeCtdxW7AUuOsWtbhWtm4FrfkFGZ6pALVxxPBOJWQNRXV4K8uyAQCCP/06WNCMamDZPgAI7zYbqZBZeA97m1FgS+peeo5KKKkifADLbp1BsNyubdu1ttm09nC3s24jsAygBoUwHizoBWDnBaEIV8eRha4AR3CNBAEsGhUeoIIeSTFHqbbDPkB2+vRiGzyuYWXb85oq89V8toMrg0NrDLuPDtvYHSq7IW3qD+JngmXYBsASRJAaBdYaQju0cFafJ6Wmw65CYHsxkkqIHK8MfWvJxkKRW7z3kvEuX3oH5xaj198f404EVtciSJ6hzVKhu/RTlSI0FcRbyKkL2XBhhRCK2DsXdukHHSrr8Mgm0M2HLr2n78qbL6IQfudGXEK1nAOQYpOvvyNedEUv/sXt/qpiFkAxbat3oBzqSR1L812FJpFk2bIA23a7te0OcHlru01nw3Cy68PBjoeDPXn9dfniYj64yhgF+qA87ickMpDsaGhhM4YNAAGpYB3sLawf7dRj3TgTdIYiFxeFf8KihhYCKIx2xnqD485UOdNKw5M4bJ5UZ2+pjE/Q7mJAJi/4C7AcytqAeAKH8aTWpTyIlfxi8igSYwt7HLfwCbWt0Ouiv1PirlScp3MpsVZeVy4+6eMh+Q3reqTo9j537+oA67gXgGW8G+smijCu25a7OuDNAJsMeCozcRXS6bhd9+9GcgH/XqNYYVLzx5pdjHVPmuRZvGJBUbW3+hRJuX4c7TQM/InEgQrtZQi93AqhXS1sF19r9ZL4zPMrid0ygPNMSqqtCJbhGY1En39OYk0l5IVq2s/ForHJw1ntynXWPfQDcKudUZwwdrJkVX1qNir0VVwRiZgZCvCzbKKkVg47ntgfEfNJYDmldN22qUaNGjVq1KhRo0aNGjVq1HifWGG4zHQBIDNIKRXLWVjJr9sOdagZdqCXAWTa1O/FwsiuqB4E7ME2aqm1Mo+IL/mpfl9YR6R9z2GLEYAikMsSBjVUm8G+eS0v2FYK5XYNT2RBKX7RB7T1bfLAAYBbUoYV6j48AKynga/ZbQCWO3rJdoDUXqgsoBRgt7apAx5kywRaC7ANLrc0632Aef0aaEm3BagXfr95G7zbZqxWtlmvrQfgOU82AqQBxiU4erEVOxTGrpadAka4SjGXoYrt9RcqSe97CQxzsbsAWCLxfp3p/GUBx4CzFybcdzeOX0RWZ8bRMozN772jUuZ95PERRfPKY9w508X1J9OJwk930Q6FUjldo7dvVmZ6YuFNxIN5fISdzD13E+dK6sklSI4CgdG84cdLUIz3n916pmjrVGDMbwrwCgOW2/VLSwA/L+Y4rFdijMN3HDCXc2GzsaurK7u62jEJg/Hdn052e3PDAn2H4yBf2hHj9MzfD6fRjngM8I6ViYjzbVm8zD43MVYJObHLQdcfljRQ3jc+l9LuAL8/XAPHYOGj7ZsgFqLynAKLPJuDx3uzHXktW1rvRB/pfPJNd4ufsLlIYxBLGLx5pVrFsQgrC41zDLJoewTTAKtyDGdLhLgWrHuRQcHIox6Wu0n85lJR0NglEWtcWBKFrY4X+qQ/8VQ8kBgbbQ1oP0+2nqTy1a4In3f4TEizJ3aO0Aw57WjQkIsdKLAwcWV1TDBVR2Vq0Hs8zYsMx+O4akCMGbff9x0vXoyQQFmu+3nZykkeJEjyPBLwh8O4+s5380Q2weEx1tqcYCpU4GVa6hnZsgz453RfOWGQlduCyHHPGcSnNT5sORbnyZ8TtJu69wpq1KhRo0aNGjVq1KhRo0aNtwssu5SRRfzcgDd7kMovNLyKQym3drAp5V8U+hMcka/vUhlKcGAokAerCbPjEQXyDoTLhB5evEiFqeTDTHUyq+IFIII6V38DlBHQDKVy+FsKXuB49D5eY6t+S3Cw61oCYfwOIIYznhxyQZXX41q6hg0H4MNt1n5uwonhZPPpYO1qtJcfbGyz3dpLL79kbYdt/1tr202CJ8lD2j2lwS3kPysIo0Jk8C91wEw1ndlpBDibaNcxnhsdzilhuCPI1xYFCNfWnNd2WE12Og82oHgTFOAFMFT/BqB32OZ9OjhUb+bG2nPrRf1cNe0lEYW4L+GyF/MLCBYZAL5UcDn7jLpCk/9OXg5LlBek77LamL9C5wrsHcikNGdJTxRAXOM6xnbBny9vJ4XAbOFjnKqdXW6xL6By+IrkW0pF9ATN/G8BIZPFQflvv6tQLj5LqVwA6GiHpBwnpHJ99sVbkXTA+dfYEuCQjm/DmAsv9SR6xm4BB5IS7tO+BfcJuxq9OlTxLYv27TY7e/DggT186aG9813vsFdefkg1MeDj9dMbe+3dr9kwDnY8nASb4aM8zfb05mD7Y2/Xtyd7uu+Tkj5rctc28rWT9eNgJ9gDoHgc3U6gtIeHc2tnqDsBz6GixQ6HYozivbF+SM3rdggxDJh4CjX9MqMmf95chC37pmfwK8VrkRCIPqI1UChGlx7dWEeZ7IK1DtZRP1ZA5SmSc94G5XHDe15+xbo+qnZ5foFqWuDQTgGJu5kLMwvihXc3Aba0w1CMc45THa4kFs7BHSME81CUo6rnikUPT8PJmn5tQ39kI8HWhMkySHzlxONrdPDjvGOChRC5PkEF7PATu0b42aOSkJGM4zvdcomvQbFAfoaEAjfbKzHJQSsM71sCZvUfCvUN48y1cTqPCf4GvEZ7YQcKPis4V2iVYtYB96LO4Br3iXNADSwluNZXYzFU7BiR+li7AiJpJ0hfLk15UlKxDKshelDHXhxP3Pj8z+Ac3tVQW4clR7G7wWG37Goyr9dzUlXjc+RZNkA1atSoUaNGjRo1atSoUaPG2wSW8xfP8LW8VJZGIali93fahh5fihdbtV3VXJzljsJS24QFqWMbeel+mhSTfENc3f3HK9XKAfbkv7xyiKNie9ranP07M5QLkJS3ipebzyXexfO+9drhOiwyAGHCEzk8U8ODOqs+E91N17uwW/DnwGewHR0wgYSjEGln5iqFWsArKR6zgi0U3rk9lh0S9hVy2JDlAOAUbTwIuLPXdqg2Q3Er9VzA1GcQ2kRv9TPZFYRkuhgrS0Ff0Z/l8S9Oc/esi5EXPXbn+XwZeaxdCgqzCvnN7u9y6LnVRCbLi9ck9WnJLu9RMt4BQK44jvGXjl/cw3JOuSXE5TWWtQ7DR6NU68b9JipVXI8rNtPwTWLTKOrp3rvbjW02G+u61pq2oTIZRfn6U2/H44mJFBTrAwTrB4y12foRCRv5qGsLf+qVdD9pbLqynv2DOZwwvc+vwpo669sLQHwx1qIIoZq3SFBc7CBIzX3RZ8/SySc98oWX9aJPaK+T1cxxr2n+Xg6NYilajs9l5FRNeVVF5xfjSffnPwsLnejgciZlNiyYG7stAEYBrlngr9hNUszuxf0E8PRtEwtQT0/6S8V+sjjiPhKd09XTPLcXqqOq/Vx48sMDuVgnY8dJ3kVRjJvCxiSt+SmJVybEysHj48fbTgVmPelyubBdJIL0p1LlHsOl3CVSPHm+b0eDvNOZQkze3ucMpj3fFZ95FSzXqFGjRo0aNWrUqFGjRo0XCpYBXvW12NVT+JItMpvUqLSGKO0PXOWYC0epWBuVkQ5NEoQKoMBvzoBCUuLOo94pRTHep6PNM4pz4boaa9sVxHm6rvBTJgjyL+ALkDrxNYBbVDSuoeCDcrmzzbqjvzJUljgGt9m7z7DuL9SMKvS1XrWuloZKEP7KnSAsFMVU3QEmoKBYY03XETCvmpYF0CCf61C8bLflzQM66Eu/vuQniO6+IPiJe0UBtROKk0FkeJ6thZcnAMEMeHy27UrgGcrq02ll23ZrW9p7aDs5FeUOFgDuBqo7HRTxEcWe9JrB7U3WKMR1OtthPNvtYbaePtHhF3of7QpEIh/Rcos4+UuZBAiBMqOE9XGkDElDfbt8cqm2zsLlJUxeJjZibCxtCopTLU50yXgzF85+r/FEhjQF5F14KgBQFZLYpDCHWl8F8BaOLrRr0Xukqo3nwjqjgEkh3C58r+UvnkEoxlp6j6v9qcR2rsa3ex/F+GWzugKW/sRuKxFtkZM+q7Re4IlN27FI38OHD+xdH/xOe8crL9uDq51tu85ub55Yf9rbG6+9YW+88cT6frD9/qjdAYN7NrNAJBIpZqOfgbsmoDSFUplWGr1UyrwurVHbzS7ZBcAmIyOtOLfjAAEAAElEQVTpWHdUeDD5aC8e3iEsVFjoo9GmkrnqXO5Vm4bdwozbf7oaOato87iBepjtldoxxp9gMuYsimZSdUow60XhaLkTvvWRfBKIDZVq3EK6mOSh7AXiqIRFP7stCBdbqJodkBa267GzRIrfSOjFwy0x/Epmh/zwiB/7k9aQ00nt48rwpMzXG3LRUYJh3BuApwoNBjSmEnvjdhcBcGmN4dZCWIvmyQ77o93e3Nphv7fDYW99f+L1pNvjybETBr79GqOy7igK2hVrhuyKsLZiDuozUCOxsNeIDzHcEi7XPbpxKnobU3XtliJQevuOgLuJB+8nTzpyl45/vtIShcmSvM7mdVVtFZYzUnTL3xmfd/SuXmmXCrzys5E6irI2LKBZwXKNGjVq1KhRo0aNGjVq1HjximVy3/DWLAs+6YtxeHGGerUEc+lVia9lj9lAhvDcLCScC/WYrDDiN72+VAcmEVcCjBdqvUwzhXlcpRVF8Ai6F0WS3D7jUtAX6tx0j/5+fvmPQmfLomyXnriE3w7I5LnMveEOJQvAFMAp2EVSLEcRP5ZStCakccUdUuEJsNBAOV164OYmIbgj2InrW3r8lj04u38tj0slYqEuvChit1Atl/rQND7KTitBclHpzwty6W7uVK4rXl8eK8Sepfw2w+t4XbRjth6Idi+v+646+H4vXZe2BuAtriXOn6xe8l+KO8iK2DeLbL+RLUYun18+lxMq6V6edf0hVA+nEldXZ7h94atejP/kR1solTPCPVMNCQsBPFCsD6pl2r+sVrRIgPUFwPDQD7S8OQUsJljGXGkJZ6PUXZraTPBI1UylMuBhAfcF4WIXBe4nQ/A0H5OncpbL8v8uJkFCt8ViUOK9ohRl9tuOZlsoSx3ix3oW1xMnyqmNpD7Pu0P86N72RTPkwyd7lnguQ+/liy6mzkVi6HLTRHnvi8MtJ1a6CyYfaM3gimVXEcdOj+Xl3D8mFztiIjHkFhgpSQU/FvQzwbL6H+eB8h2Jhih+mpIhmOdngO+Za2b42vO60vokCyB3qi6scvL6Uv4MUr4YMhf5JC0psSsjf5YuPx9Lm5RgzEtP9bJ43zIiceYQPLWTCo3qPuJDxBPBUayWxWarYrlGjRo1atSoUaNGjRo1arxoxbLjh5ZfTAU7ZvhkhlKZwlpBMhTRUpGtUIFFZKuHBZT1L7nQg0GppmJVpW8oMSwhMAR8AqXxRfwSYJ/lzVlAsjDzZJ2n5Efq2AYKsnLbsbtLQJ6XbSAEKMaAWLz/EN3qOJtuSwXzyy+9w9bnwba7l+gT3QB0QL08A5jhp0pl6Qt+a+umkz/zuc8AmbBRCj78hLMq2kaKwZWNZ3mfjpBx2mydbweHWg6eowAZx36g7yd6jlvIzWy7QXev7HDsCVrgnXpay/9TD6gkBSJGVy6vWgcRNJmNioyh0BOAXPSByG7u8oJzqjkv4FT5WvdavgwxmBLlLZFahj2FavYO9LoofBclwBYU7T64mwfGEjQHjExsKR0bykG8TbYM3kaF33R5rQFooZgMn3IfVumVqcAj+1ly5KU/a7ahyPw4VMmeREkH1fFxPqqRWShS9i66HinaMbfR7VSqjjnhgWOPkwpIDnx4Qb25LO5GaS3/Db/yd77ysr3rne+wd37QO+3hgx3BH7x9nzx5atdPnxIu27ol/EVRTGgr4SHM49FzFz7fZk3XpmuFCnW/3xNGp2QBdyBo7cBcREA5S2fegMa+W+IcRQZ9jQoVNpMxoR4vilNS1V/srojejLFZjtpQz0cxvnKcF7bgqU3RD3xf0fdLr233KMf6+v+z9x/Qsm3ZeRc+a8eqOumm1+91t7rVsiQjy4ANyGQBxmQDJpqcwWQGORtMtkkmDWCQTTQYMNFgopH/gG3CsE022JJsgW1J3S/ce86pqr2r6j++75tzrbXr3vf6daufpJbW7K53UtUOa6+96tZvfvObjQzVmRBzL2a6oAfsTk365Nse+wwYmxMMntAo546vv7AtaWlP717L8KQ+4W+YM4KUWvO1brBaA41J3asXMBlrHeDuqhHgbcIaA8kpNGOdp6TODcIuj/2sUoYiGesZVfWhro9ETYxLrO2455pJ15BNXycq4DE/VJUB5Xo0S1TlB9S9TOxFQsCrZbRFrdGcS2w0iOqZdEf6uMnPmB7YeN/iueCm0XtHsrcpmrQGyOXh+5zKc+e8xPxu/xF+1ty+q6RzMlTru9TKAytvVLWDB56rPgCn8yRFvUPyeA5tai4N12vUqFGjRo0aNWrUqFGjRo2vNVjGh2tCJi8HPudOXsmfNEqdiTtLpR/iolS+jKTEotVFqIhLzidIE4Ya6X8FrMzPDPVYSANDjulAMym/AjaGGrpwEC2UjOdC2atGUAIqS5CEkmI0mGpsXG/sdESzvpGwLfyIqQxOvsdxHIDcgDHgBQ6S0laJ7h0uh5mIM7ukWnaQtzLrAX0CXqH8fAb0ONk4o1kf8BmgcWuz92WjNy0fwWCj7Nu9kuNIACUBmeKiAGzhe0DTD+MRoYhcyCkLuLxAuAtp5Gt63vxleY2X+PnCV/sSbl8cWtb9lfYECz+OC9VyPuJSF142Jyy0pg6YAEbdVzZ0raWnbqm8DLgcvL5ULpZH+JoS/EMi/amAyxf2HNkj2Ec0DTWAoyCbxK0rO7aCzYBnmBm0gvEkC5WpYZeQfGp51xCCAVLC+mK7Xdv2amPjOHCdABh+fNzZ/f2DHZEgoXwUyuQwWxCcC79e/IxxjfsQX6fpQHiIBn0Nmm1688tITHGWcEm6rDBYphIWVQjJ2iA3wgxLiASW8yvTKOeEQHFdCuuVQhSd7uNoaFqC5DQ7070S2t3slSvwrTsUx8V0Q+n7XuwrX+zyQPIjVzJk8AwgTSuKdGtGwqDwqPbLHAmP8IsniAZ0pt/1iYm1PLaxBqqhHl8XY8aGg0lSnqyMspdy4UF+kX4KSxdaVbhVUQBmNdpTYgDJFDwXvyMehr3QxUIVjTxhxUQgHJUNAeeLa5DeARcJq4WWOSW0pKL3ihZP9JRzMs4p1rZ4LYB09touFNzF91EpQ79kvA+lqiI0OtS9iMaHWmfiNXqfja81atSoUaNGjRo1atSoUaPGJ2uF4f9zrKwmSl52yyjKw48Ji3pZNj7A+gdsKtBcGUm/1wJWYNsJ9LJcN0ObUBuK+UjBzKMhRDgSqmaRang9BzTCB+8M4wKGBFCW97DjLEAr2DBQ4iYwKA/OY1HOXdQr++iwvN8aW6+v+GG+G0abz601UAXv9wQd3XhIqsqARS1K/fFKyuGyihAeqlB1AvLB97mHpUALsAalMTWKDt/gvSw1Yd+uqBClAm0FoAw/5qM108Tvx6638wnnvSOgpBczaaZU3oJ3ART9shKqQ/jm8IYKRSimz9ZMx+QB/Ro5LgEpFaj5OW6CUiLlNwLmBNtKFwd/WhIBL6CIw6nXydMbGiMWMzsTtHTOAdmW9isXSQwvaU92Fss+ZGnfi4RFHE/AZVc/JqXvhaqVf3MbB6pm0zGnIypPJSuW06lmhTiB9wKkenNMB5SlfBL7TM3rJgeOsB5w/21lG7LFSqkoj33ib+uht5vrra3XA+fZ8Tjb/Qcv7XiQDQag8mE+0hMcylKsHSeMC4/1ZPMBtgYTqyBgxcKGf7uJath51tzMSn95sGut0HlBkQpPW1QNHGm0K1UyqyscnicA7CeTElMXoDiu4yL1kdamPHfKcQ4LESBKuj8nkBvJNDT39HmkEomshH7DdGWFBTfpd0esR7E/h8gx35LSnRBRDUXLWZNSLMmGx9fh16aWFOyw28b9H+fBagu3IsEvW/YSbbU+uddwbCYSdA2TYUnm7wPuxwjlulugoAEj1suyUWM+ek+AcP4GbPXjdxgb9yerAThOrvanwju7o8sZAqDZ3wvChuciM5ZXCd1kmEdQ3wNg5+Rh+FL7FSqhcvgmp3nnsyklYxxYO3yHNVRaMwo/9QDDMWXhmyyVuapOUhPBFfX4nqLJ9RzL5cF9xCtcrlGjRo0aNWrUqFGjRo0an7RiGZGhiEqVWW3tUElAWeAp1HRRlh0qToGCACJZ9Zo1pIKbAawJNhP09WZr7idJhSMlauzKxJLkAAdLIpPhQ4bKUgrzNJyiBKgFKGGzpePM/c8JLKMJlJcixxH7p3SV4K9svb1mUz78csbf5rPtHnfWHWcbNxPByarrvdQcIEBg+bTq1FiPXAAgBPsDWD4JLLct4TJKuFFWHuyA9iGrs3UEy2YbguWG7aUA0vbzyc7Twc6n1tbdwKaHvBZQndLWWdtrTg4coPpEqX0ozx3oc3cOwVB1jn3S4hTP8hLxEi5ngWwoAHOSgdfhjYriS6icIymHPwyAlNBYVC1PrHhKOq6l7jTgYYC5rM7MSsRkJ7Ag3K5gTKaqGexdqpw5t4pDCvgY8FzjnMFZyWkFfqTGXI5HbL2AVAtxZ4A7L513L9WAYNp2Vn/r8EkG7eQQkkJHJG6QEGlVSs/7m0p796llwzy/x3wMAmKuR4Hl7WbkccynyT54/wM7PO7YaA12LoDKu2mWCtrnXVRETMeZDdji9/M083VK9ggQhpoYCtRIEJ1oqIG5qvvLe5ixmZqami3Vy4tckdtCEOARsiMxki9rho5ZbV4qhZH2YYNBAsxChew7LBXJlw0UY/tkosUSEzNHiZ9sl0DLHP5cqMUXmQ2/m7g+6boUsyYMYRIojuektF5Y87gNEZoJyq5IL+BaCQU6OuMddS3YcPGi8V/4cavBa5mEWCpwCZT9ayQvIvEXnD4tA6kjn1sLFT7M4Y3MBNpcqtH9dcLi3BhuRdbD0A9ZX9kbMd2LhXo9rTHY7pHzMZKOod4PuJwSaH4sYTm0bO75Bt9pQnk03JMdDJu9lmuVJ2cDNiN50nVK+GFNzusUEqNKjmaMXqzNvk6ykWMlyzVq1KhRo0aNGjVq1KhR45NVLEePMkHHTLD0aX+hHvUPwlIPRrltoUAtsCF9kwEDUm21K85cIeuf5h0GBJz2D9cJPDicc5qtUn4HooRkXvqcQGHAEt9v7IPHrhPNkNHLqgkNoimVW1fwea4gc8Vy26/tbINKweFju4KqDX7HAnRJmcpTaKzrezsdYS8A9TWeVZSOu2c1gHnHMn+BvMtS9ij/pkqOABGQTWgIQA5wmbCYwChDcfnKyoJBXrOZOARwiisWFeOhZg7QUoLZhM+Swjg3rkugtWTDvr8LRJqBaJznh3lqpLlZqp0d1hQbvnx2kraHWrVselc0liwlo1IXZ59hKHqTqtMHIyvgl/dNwkiFynhpo3EZmqMl8M267kJVm4WOfl2KJEBqFuZqTdwPDZCwA20/59d7p6UrrqSNM2s1cFMdgGP3hdI82U24BywSLd3KaH0BC4y+73QfQHEPdfIklfJhP9lhnuwwwR/3TAuXsGnJMNDv76M/3IIjzx+Hikh0IeuR1L9FcgB+8ItZVED/UIZSxVo2PVwmSsKTubwwGZHmcSvHLuHJyARlPXeaD0ktnuZr/p53awFFY/1KSYoEqMvLd3GPRCPA7CaxWIfzQ+dGNTenSaiIGx4DLTI6h9spUeTN7RZJRLdwcG/gqBYo7TDSsS4qCTxh2ML7ntJoH/iY7KiTKMoCyJP9OrJJ4JGNIFEhgjkmZXoQ4rxeMnnm66nyO7pPuDLHPeGvidUyA/4wCkqtK9P7QplcCHuTODUq092mhR7LfI0nCEpbHL9A8f6X3geTP1Rx/dJivFyUY5zD8zzMjfjnaNxXzNzX14AaNWrUqFGjRo0aNWrUqFHjawyWw3NUPr6p21Pmvg6Rkz9vsr8QWGCTsARRXJlrAU2jIZN8gaO0t2nb1LTqfIJ6GBJZNGmC4ri3VQtvVXgAr2gbgbJk/tkVWFmNya52LCMPtZ0Ueo54yoZyhExuEesfzlGiz0Zls5fRo7T6CBAMQKBGUz0bQTXWr6+sHQebD5Ptd3s7TnvbP9zb3EDhBrA2J+qF5lFNM9jxONnpdKCamBYYLPk/86lgBmO7soNDnZZezIAzUGTywnBMD8cj1YQqV8d4ohy9tf0J8G62sTW76qDE9gZW5xUVy1A1YwxRwg5sI5R1JhREoImaSr9XbEYIThNNFKFaPoEDNSs7FZ7bZaMynEt8X0ym5dQqvs9AOFtDLCSYkmm+rmxewKByZ+FaWlDmROPKHn5xnu6zGgmHyKgUftyEb6EqDqUnpiYhWvZtDiDIlyxsKOL8CiVrqUj038QIhFoa90pz+YJg75gvBI75bOlJzGNFQzoHzd6Q72LE8zGnppcBi4O/YxtlyXyQNllWJEUq7j3Yt/S9bfrObm6u7OmLJ7bZjrZ/2NlpPtru8WD7x4Pd3+9st9sLLM8HwmQ0lCzBMtXaAHG492YkaeCtLLBMixsmp1o26+va3oYeqvyzTWhe6SptPtAwLpCaN3wLoCxQHdfZm68FmaZtjcY9j8vrsywNXkDlsFMB/Ewq+FB6B6nW3KTq95KzxjpKRbYksPJy1zxUIslYwUCuiPvSp21A8fCkT3ZDRXIvwGKZQAomHUrwHmswmvm1qoroh8aGQcAYySgxWVV/aCppfnSt1m/MgY7N5PQekBXLqoXIiUIcuH9PmNzYqhts1cM+xRcbvG7COLpvOccDLWUFaWGxMh329urVK/vggw/s8fHR9ocDLVOytY2U/7Dz6DqNJd57CIY9eUlA7Q4bZNeobOG9m0FypFd4Jiu9R/B6hFyZDSJ9TOP9BhZC3ugQfvxqTDkne5BFIqyA8ngP5EMk/LVEmyxrdG9qmdKchoJ8Ps02n/G+yRIA7gB2NmzmiiQnE50XLjY1atSoUaNGjRo1atSoUaPGJwGWQ6eoUu4MlvW318t5U6O9i0ZZ8aE44HNWxeYa51C7pYZIpSwvNZHTV0EL0ht3dy7Vb5kPXp5NgJZUe56UhelIi/JtV6VdKAdDYcfGSYCzgClDb/048tznCR/eJypGCSEAELz5WJQ3s2T+LPhFY4lUgq3jCPAt5alDZx6Tg3735YxjDI9oKpdd6E3oUTQwFKCMdox6HUu/CSniuHwMArgWtgmhmiQwuSxpf416foWSuItS/TcSjyTTjednxXJ69pLWLuZRgqev+TwXhxG2KoVlx2Lb0b2r2OfCnuLCOmOhivZTWPh+x4EtDyLPxI8zhIlOefO94ioIGhYK4Etf6XK3cR8UXr1p2IutCkjl80pqUiQm2taGvre+V5k+YjrAjxbWFlIo75F8Afw7wgYDkM2hZpEnSKkFb8gmoBynGOpOKJUFzAQb0Rgukht+37mSOtSjpPAXvrV6QZxbuXCUd+RiluXfpumex0rLSyiDL6+tQ8A8sJwPcc6Bf5NncrJ/KDx38zRZ5ErcKtjKA7lcApOVTrEGlwentbsA0ZeC2YVSO/+B605hu8KluXjx0k7mYjxiRXNrCz3KMQ4/9EieeLUB5oTbFWFu7fcZKKfKCX+u3jZyg1hVraRn5ENZnGW+R+J+4BEU7xV51PJFSMU1vlZma53CXvq1ocj7K8c7p9h03XVay/eiQneeYXHKDKW31+UcjbXq9StSo0aNGjVq1KhRo0aNGjVqfO3A8n7e6wN5O1JBle081VAroFJ8cGVbshXcNl3BCXLpfwsoIpSaNWCAmuCr49jbZj1QcctnOTyG4piKNqqZG2vgK0nVL4wlV3Y8ZoBI72DuX+AJikaqFalmbmW/4YRRSmp8MxO4UtFFeDsTZM2nkz+kAsPzsV+o8aA6Rpn/uBms61u7vrux9dW1Pd4/8ph2D2puBbg1H/Y29VAS93Y8tdb1ow3DmiriceyoUt67ZyxVal0v2AwP03m2x8PeHqbZ9mh2djxJAbdqBb3hs0zbCxgen2wYRsK9026y3e5ozbCyUycX7BFK8HNj8xmOs4LKu/kojsPGXCsbOoGhyRXRk6tJ6WXdNrxWXXOyU5S6R3BeZKpbgs0S0Zb2DikyDVt8Kbf9mop58ZziOC4Yd4I64fXtYDoDzAyLQmWfZvOCui7hNAESnlB4p6r83q0jXj/5xeHpcBx0JpCrP8rH1hX03FcA0tJ+JBTSMe8z5/QswGK7tFfh3I9t4F4WtC1PVFdOCklYuuj5ArYcH8wLV2lC9a6khh+3rez6amtPrq/s6mpr3dBxTN5/9wMqib/4xXepVH758hVV/fTWYWEC/Mxp/G1dP/Ba4UccGn2YH3dKkuAA0SjOfWvHcaRSGfcMmmFibh+hpj+daIsQDT7ZEM3w/F4A+wD7BG8WigPA/E66fQ0Y1ca8qIXfcJpqMT9i1Vuq85MncDQkdQ/gch5x7QmL29R0LSeNuGbJBYL3+BGG5gVHDP9jwsskq47rr2ZymmQY4MJtnlYpsizBWoikVTT7g6JWDfCkrpZ/thzoY0VHwinAOE/RbSWgVmZCgWu0K8bp/YsEWOh83bd4cV/FWi/FskoAypvGmyIyiaZ5KSUxEnNHmw5QwT/ae+9+YF/84gf28AC4LMsiJtG4bmmMoFLuoFzGA6p+v+pqMqo1EbugT3KRPA2kzeRe+LBjHtLjGHMX7zkuG4eS3Ktx6EsNdbQrwKPqQLkFPy8q0P3+pv2Fr0qeKEyqdm/CqKofNUGkfzbWZN6Tmk96hEUHXqP7Ck1dtV2cse5b5mMrWq5Ro0aNGjVq1KhRo0aNGp8kWBaUyRgLkUvlswqNH2KTEjA/8/L7BWsrG2kBALBkGB+2UcofIAsl1/GQYjlgIX0ok+dyxpm5pLxAmqks2plbagjlnq5xIH5ul2rl8E4VgM6l9jhegAOB5t6mbuLv1GzLAYE3AFz5J/+WIEf7A9zhkZfe0pL7ySeTPs+y4oBNAEvRkyOE4JOUxypjX9g2BGAIxTL8bwGF3eNTZdkat7ZRQ8VQA0oRzdp793DOonGN/4cplosolHmXSuSF3+qHxpfT05UazyzBjJn2ul6zOOQ3KFCzUjOe9LqqN7ackVN5OKXfagFrFwLYC+KcCFYxUA4jXW7sf3Zv8MXIFCccwHgBxZZq16xlLgH1m1XbOQHgnrppv7nRYShDg2Pj3sA9sB5H3hNErvC/3e0Jlnf0wT0k5TJuW4BHzl8Hy7S0cdX3cm5nlSaTS+5ZC7jHteFiDmruF97mZJZSor/Rn/hNA1FKp/PykO67WEcW8yUPcZorcd2Smtpe9yJP2y6uVM6lqHoiAGPMwLTtpFqOC5+TeTKiWM4x3JNaStE8VOtUuu+jGV4+vCKZWMikC0ukpFb2+yc/JZTGsfsyOVTeH0Fcwwc4A/40tNxuTO6ceEHiQIrl2fb7iZYpsb7FeSyOryi6CUUx1rl0b72x0iXP/cX1LhTa3I4PbnnuUnGXqmX/O329/X3gDctcOv/i3i3vw6xcjwaQhao6qibyllJjUjWWjO+rXrlGjRo1atSoUaNGjRo1anzCYHl2BgCP1BIiBEArP8Rm8HVJbLIIjSrPEOn5B+XGFWRjP9jYj7aHl+oRTaM669utDWNn683Wuh7AFi8+JkUc/Ie7U+dNpwIyukkxVX/RbCo+dAtGx0dqFeAfrTmHLyW8leWdCbXwNE8s14eXJ54n4TSUlVAeCyZDuUyF5wSVJLW+lBoe9jv6Jz88PHIvm2a0sRttdT6qwSFUa/RObWwNpfaxs5f9vY7IlcKHoxqbQdGphntngml4n6ayaj8XIqRoUEVfbLP2DIVeRyHgdm3c1umAZmgOrKGIO62MrqWE91KPn6k0vEgNJAh/QdyKa17aJsjXlrK4BYJdQI+AogvYGjAU34TW8k1RzrPytwF5MggLC5Eoq0/ANVM+jWeaRwEBF/LKAirn16dHeKaW0CfOMw4l/c4TAWksHAGm48t+LmKj0aARSk2BIt2DMQHkOxxjmc9PzRnpiJzUj9oOX+VfZRuxhH3xN2pXC/9YwTM1hIRUEhYr42ZjY9fai+dP7K1nd0yYfOkH37MZiuP7BzZV++D9D2xms74TqxComgcApm2FD9cEGS/mO4ChvsoZAQkcJXN6VylDYUufbyieZ9yjR+4njU1SWYf3AxSp4bkrJW4p9RUUBSiUej2uU1JllzMhWeIEx9U4Y6t4JcYkbwvqVffLLSo1ItGlBqCSMCclcMwf3oMrJqbwfUc1sfytI6ES85oA0/3No6FehtF+HEiMFYrlWM/7rmG1gjyBHVafcYVxXrDqCSivZNh67G0Y8Aj/5xUTBk0DlbrGS4peqaDjLkhriGf2CHujGR/nPTNintCA9FgqXZdg62SQfOw6O02THXaz7R4me3w48DFPSMZhb1BO91yL8yrhicq4R73hKUYK1wCTHFZFWisC3ef7OWAwkxqeSMGxYcxwnkr+qDIAiZWwScLX6DfAa+Be4vQTxz7xfsQmq3F85TVVAsUXgXRUfBbV3z42uH/ZAyAeOu7jCu9bAvDxfiif6mVjyRo1atSoUaNGjRo1atSoUeMTActyQdAH14AQors5SrFlBnIF7vMy+fzB2F/kMrxQynUNQHJn0zzTLxUNh6IZVN+P1vaAEPiA7KXAsLwIlXOJ8lD+e2m7QJjRLvxzBXHwwRyF2iiZVwMniCcJcd2/k5YfDmvUS8k/8AMcoAy/RRk+YO1Myw2py2BjMfF3h8OBgGFUjXYGSwQr4Dcr65vOzi1Ac0sFHY6DTfZYzu8Azr2Ol6gzcHqorpPgWcpEqJHpX2oCR6uVPSJJwEZd2j4hwxEQ2tWhlHdectvcxGqJW19XK2d/01ASLwHGm+0wLih2OYuSr8TlmV98LZpKLsFuwbBf853NRqMZJWXAnQ91+bpSuRjnE6r4dD7eOO9Ntr3J53QxBi5YLlTFr3mYJ1V9vhbJlqNUxxb+qem6OfzGnI1tCVS/GTBlv+JCxRxD0/ji4Oc2Qqk89HZ1vbHb2yvOs/tXD1QqP766l6XL4yPvJzbnRAn/UXNbS4ErL125jMaZvPd83Snh7DAOUirHwfi9Ec8PRWZSfaeLi/vN4Z9ntnj/c7DU4FArW4bqOL4FVS42WY6Trmee89lyxM1FvMJBTej8+AofXo5zMcXTLeQJBjWZyyg5pThKz+lQEdNepUwS6HvZA5UqWp9r3tyUlRYOW/McR7NSWC2oOacSgboWaFzaI7GHd5PVimphNjlVp9Z0zpeq5XIVWyib3eKDzQCRdmDD07Jdor/3+IMpxOlET/vD4WjTHnNGyUNdQ28uSAgbiYC4t3KyBvYXhLtIOnD9Lnydl8JfH4Pif55A0L7ymp6aR/p7RVaEK5EQ1jQxpjmhlC784vdJPh7prYX4O2e3kpV1UjEfvX/A6WIbMTfeeOvXqFGjRo0aNWrUqFGjRo0aXyMrDPzHbS8a94HUL0/uQxoKNH3e18fXorlbyQX5cy7/TmDMv4sP4RJgCbG1HdRgvY3rDVW3ZlAFC14IImWVppRx2fM1KWcdoElh5vYRFHrlD+RS78p7dZoBlABJ8DyiFMIZAO5x3Ng4rqmaJFDmQa1st3u008PRjvPR5sNkx2ni8dFjGcrnaSZoPh/R1G8yO+MhdSZVcNgOrAG63lZdb8dVY4+wD5gAu6H8xTGckhqSjfgIebIhA842rDmGrrMR4AeNzdwfFSo7+LUG0IRH7gGgDwo7eO7C3frYULUHa4zU4CrK8KGuO7mlBiA7CfZxwXszXH2NyKTnvKk51SKKZolpEr3xmRdQOQGZXN7+GvSWuW56dWw5Pasovw+g9freQ9Hpan5X+8a5+zQsINay0V15vUIbGTYkoSrMwA1q0dNr+wdYzXY0rsT2/0nh6IkbP6fksRvK5QvYt/DLXqjKy+Z/Ae6xNXk+Y06h2mA9DPRH325G22xGJlQe7h/pgQvFPhTyuCaCyrCvyCX86ZqhQoIJD1ljZKar+xcNMsMHO1uphCXBUmGd1oES+qd1Jc61JKixcvlYXkJ9X7eyRUm2FygOQ1t1SJtBf3rBMvlQJDIEEdtFRoBiVCfKYYVBEO9NPBcK8+Jr3N/hC12qbt1NKP0tmtrhOkrmim8cyPr445bRWpiV1qjgmGcB5pP7DVP1i6QYPOk9KdecTjb49QuAnk/fobHk/v4rQWAAeCm4i3mcYK+SgUji4Tjw4PmQNyvhF77OKruId6NIKi7n/imt/YW1S5myi/c37sAVxq6yzg36/HWekJO/dCQes10G599pWf3Da+nHlyxeYrxwTKwOiOejbwDOtVWFBMaI46Q+AnqUnWBdJR/ZRk71fC/XqFGjRo0aNWrUqFGjRo0an7DHcvrOvwQpdsgQZezFcwNaEKIF9ErgK0MkfKAm2glLDFd5cVveQYnl1PRt3dKmAcDidJoKsBwKREE4Hcey3T0+oPODuuqeuTMp2/J5ACTD8gIgmWCZZfhelU1FGqqvAc82bJDXd4OXeQMGoVnfgz0+3qcyZdhiALAAWAAoT4fWjlQwCy6jYaC6J6nEHQAdYKnte2u6AeYc9rCfbTfNqeid0AQXD2CXZdMBh7JqlpCYZe2drbueCnBollmybWfi4+AJtNs4wQYE47Li1waqaSgGAYqchzWFHzWxk4OVBk3dfFuhtizBWahCl1j20oc7ZsyHQeWPiFB0FkpE33Fq5lg8sVAWZxWsnp79VXOy48sh7fB51Tx0YpNenbxwLz1mAcTeCMbL441EjUNUJnPiHGKuC/5lWwHBMylOs41D+nsxpjreJfiOr6WtTQQb3i3U2Bg/QTvMhb5pbDMOtl2Ptt2ubXu1tlfvv7JXL+9tmiZ7uH/gy8ZxsAbey2HWjUnDQ8nG59N00H2dma3AMRSyw7BQkSaoHuAyxitllrLtQgBMriuhEtXdI9/Zwo/3clzKeRDe6IS9MY5JYa7nlirVmAOcJ1CPRsVCShyEMjXbbwAm6/nZLzfdE2gU90Y47fd1kQDQtuN+zz7pcqbQ2IXVAn+HxY7e8IVa173aWSeCLqGu9MVaNjdnO2K9OmEuQk0uFfk8Q5U+c61uaeGhdVzbLNaEmJewJYkLHvMBe3QrDCjbNbB+8O5VL4Atq6JY+8J+gmp459UBk+NrwHctaecEqNWnsUgI8Lq5FQ3HVOPUNLqH2UQyeVxHUkJK5HwsSvSlNQb3ztGrGTBHcPpFA0V6+DtcjmoCJGVo++TvQxoGqZ6ZgIFdDO9rQWUmI6KRaiQ2mLTweyJ8npdLWo0aNWrUqFGjRo0aNWrUqPG1B8ulDrS0DVAs6u0z5CtfHSI17zwUJeNqQLd0sE3N7k5SoQlwAFjIsxLc5XTqXFtYeMJGSbl7eYJllHBJx1EQnELRqIZTgkDYXzTsCzh9PkspXJYkAzBArYxHHDfUefCPle+lZUDmYACPw34v+AxP2nFNUB5Undv3Rn7ABFAmToTbghi0FnUk361WNoRf61ketwJroTSUd2mP0naoQOk7i+NwGMZjjMZQ2eIiqc8vTBpixKLkXzAljttHOGcgFtf99e5UDsFS07AAdo564mQ/Ei5fIpGisVVh2UCuHebeC03nEiC+BrKDySzElTGPaPgrYJPUyssXZtCIuRNWEzEfXz8eqqvLs4nmjsljplA5B2Nzf9wETt9Eicpmf28Yz6USeQlHy9MOcFZC6lBREqA1Kxv6zsa+I+TaH/b0Jkeyht7MTL74tYlzSPt2H18+NyuLAxLj/sd9EvYdCN2bGCNZvPB3JZxroPKO880VErwPXYGrZpxFQzXe59k2Io1ZMW4Z2Kc026Kpou6FAMq64qWGWre6N9OE9U+hatZ5udK0sH/Id6bvz9XKaSm7vFiORlOSIXyLwwrDrS5avzVogcFmkbIFwXpXwumYawFKccwtmyBKLXs6Nui1amZ90dAUKl3ZlhCypsZ1xblCZUzrkWXSSRfKFfFF7ihuIVznBsppegojaXew43xIftnya/b1yNeo0nwjN2QVwA+VcnhxXzaBjEUu1pbkdb4q4LFbXoTqWElSnT99k1+rnCj8KtLEicRAvo5RzxCqaHkwe5XK4uL7c/2Y+D7m1suxOfpDp/ev/NwaNWrUqFGjRo0aNWrUqFHjEwXLjT+VDenKiNL+hXhOCsAE5vIn5QJKZWpXKpj1WdtVaPPBDlT8YhudQeSI5nb4MAxwpUZ6e9lJhAoP8LmXjQYUx3qtfwCn73B8MHflWlIre8kxG1QJSssWA8e79JIOCARoANVyi4ZNNlE1hpL//eMDPWIJ1/Y7Al3AIthgrJqD3b/6wA6HR5ung/Xe1KkfoRIWmICC7bRq6K18mM/2eJhtD9m0w4zBIcS2bWyD59LC40y7i77rWZaO88Pvu2Zl49DZaT7b7nEv+H2GghpARtQhVJG0NPHGhuRxAG+wvYgq6vBVhSIR9gcYJ45chOB3VnQuxyzjNfdzTpXtYfsgKEOWml77YXB5Mbk0t9yvJQFqlpALZoayOOwrAq4k79cE10vhcDRN8wnuTdnybnWt+DxcnsKqQgmC8FQNo+rXzyMAdVJRp4Z7/jMhGnTrhY1G0Shzqc6VrUnZGC35obP8fqlQ1rb8+juQje3RZzc1CCyBePZvFZg7MXHRY571rV1vBttuB5umvX3wwYlN1Q5oxHc26/pBO3X4Ru9ygjSBd5ynmu6pmRkCtjO4vwA9k1VuWBe4P3hYDOAcpumYzkPHiZ/z/KPnuYNl2jrwb1Kgcs41J4FOqmdxX2l7Ao9nekLjuiZv64WyXvSdVz3mTmoWGY7Z6eIykAwaul53j+8DHsUEp7BRiESLz06BcNnRrBokky5uM0JENaoLFW3nfe7i7wTKfCgxJecKt3GgMvlEq5zTyZvMOXzEdR77RsmDAesxxnyixc88Y67hAo2s4EB1CSs6+oHVEuGLHUpczdmTHaeDNe3J2mGz8B+ncpkWQTjPtFBonKI6YDVz3qAqZPf4ynb7B3nb8zWwofAqC1SlNO75jFsVlQr0lV9RNazrq3OHn7/uvwyIY9+wn9BvlMABSOb4wnLIt9HPsCzSfMGdjMaSWJOV7MsNA7EBvCcs6hYSNM7NZoMK63plBbPe66TOjnGJhCyu76rXunfENfRkoqdQFo0BmQBAlUzVLdeoUaNGjRo1atSoUaNGjU8ULCfJsVRtiRu+kZeV5fx8UQbLRdk6/xL+sPEB2b0m1UAPAEiQmeW9/NDsZcIEqrnMPO8nb5swhyrjoilSAj5+Xk6U+KHdbQQCOAdQKz12A4CHMpqqvKaxI47FgRgUdIRm02xHNOrzVwcww99sMkLow2FPMNENg2CGK7npyQylMuG2HqFdDOVf5/YD/Jt7LSfI4Go4go9mZfPqbHOp0iv055eXiqDFoQb2iutQasoDUcTx6HIvmyRmeBkS4o8WHwdUTltIvf6i7L7QaxbHG7jzMvKhlB3z0t7SHF3q2ZezKJSCSfJ38fqyAWDcC2oOlp9XwnKNQdJmX2wvsI/+CyVwOSDhyZygfVJb57nNv7khcToyV55f3o0SQgeovoCjxbN1XZcwPcvT89kT5tGGQOkDzPFpQlM1VB0A9glnxWulGs1N3NToMh9LjK0sGgAlBUQzBI85vFQWX55L/na53iQRMhWcca65V2U0R7yMaL4W+8tJimwrEOtDPCcP53KbXEez+8ci8RZg8fxaw8x0Zmky5MRIsYs4tnKeuio7FMyRiGMyK702X/+YOQLtAPuucnbIqdfmc1iMU6HmlQdxXqfTmRTnnN8jinPkPC7V4stslaot3GMZQHf2pESRsFoM+8W6kR5lQrS4hjoGtwQp1hEdU1iMxD1YHJ5fiLDhCXsmqaGTeXNxJd+UcCqsNS7evwTm8/VdzIqi8qBM6C7GtZjjeXuvHUKNGjVq1KhRo0aNGjVq1KjxtQPLUJ0x3Ec2FHGCkIK3UVKbFZqIsFWAO2fxod8/hUczMSrAOjR+WtlhPthuv7L9hFL6g+0PJ9vtZttue+s7KRT3B7gE69N98EN5Sp69JBuwuPMGR24hQQgs5VjAq3EY1BCsGfj8wzTZbn9wReRMJXACQE1WwrWdexgPUlTOx5WAMZRwBzTpO9h+t6O/KBosUfRLleTJTvudrQ6AY3QstXE9Wtc/tdUwJPi820/28tWDvXp4tMf9wfYTLAU08gNU2a3ZVdfaVdfbAfB6khUGhH4s8Ye36Eq/6/rWdiez3Ur+oVCxARXjfAc0AmSDLAd+vL46FzQLtBn/F3SHwlH9Gk82n00PgGf47CaIcaFov4ywVKXCWNBN1gJ0uH4Nn71WOf4ahCl+dm/aEnjmzSyhVAlXoZotnz7TUFtq0kioRGoB43AJuDTn1cAs1K15g/IPLsTxC9aTPX1zI0GqNAm040WnbJvQwSs7A8Wc5HD/1iN+o3sNE4abSPeIQ6zcDc9NVcoxjfMKv+gl5BJ3RFM0PYXWub4V2dLIn/zxYbb9Dj67J5unE8dmPaylnnSnB85RgkAkjaQMLscSwGscejbtdIMGJpeQbHEvFt7LJz8/3KtQz2aLDYheA0LHeHuyJDIqtPJwv2UOdyKL3rAzA2RZJeTklOxuGltxaZT1zYIvJsjtgLUA0Isjoso8K5a5RsjUPSWbmNzB13JOS67N9ZBDInNgv080dzWHw/4CTUExrlAfu1q6x+uxbkTlgqu4C/DZ950NWOta06PXdcHruk5rIlSveKjqw3Wxoa5N9jxKCqCRH8+5nIGwQMFa5GtBSkS6Ml3zQv7AtCAKBTgTeBPX2t1uz/km+2rNDVYNtA52fX5BjS1Hbd8LToDn7NZAAaWpWs5JSqx/mlda71gXQgV4yMZjTbclTGYitNU8xnkz0RI2FmpCqPHPa0bpsRwN/NQjoPU1AAkCrzDw1/DuZhPY2MzRzngslsJIAOWGgl3fVbBco0aNGjVq1KhRo0aNGjU+acWyQzuHIlnpurQw0Af/rOJbKEcLfrVgF8n2OJfsA+5JmSjIgsZ3+KoP5N6ACl7LhUo21LalYi01r3KlcyjAwjMWAAsfsPu2t67p2cyPcuJUhkwvjcJtWJ6daZvurXk6tQQWSfHrfsoBaksbS1kSyIsZVhkAE+EtGxSAthoTSsyPrlrO+I+eqF6i3mO8cBz50uRDZ1PEaFiYG8qFKjtUpKGATpco7A/C7uCEJn6uOE2YU18DI79ufZEhcPYaLkq8y4lwKSUsv5Y/XKpIC6VmgikENW+awYE/3+BhXMzDdO48+mahztZzL21RdHw6/wvYvDivj5BrF+rx+E860sIzJHyD89OyD0JWH0Lt7ErLlMjJTr2L8bwE9QsVZQBn30ch/OZ5IVkDKwGfFek+8xlGxT7sZGbAXhimaI5zDpALZ3uJ7AGbzzOV/LsyNhT05fPimsbwZk/0sEUpbEOC05czKIbPoSK54aXGczHdsqJV88FBcfLREJhc5EK8qehyKizhcul7neZe0dNu6Yvt1ySbXqdpkoTkCU/G/7KyulQrh78u1hIk42QblMF5nKxUyrgmATOzbzNtF6JJ4YV0OW9r+ft8DQvvb12CnC8olczF9ctTMi6WLH2gWGZzyQRoi1EIL/iLUdRjmTh6/bvy1ojrHx7L8b6V1788h8tqmoDMfhKp0mFZPbO8lnHcy7HMDSEx11KaodSB5xUrNU2N34fNTiT0CsV6VSzXqFGjRo0aNWrUqFGjRo1PEiwPrjijb62rKAmQ0n/Dbza8JyH5i8ZwoSCLCODhkDQUWq7u3O+npOSjSnY+2e7xYLuHR9s9vKLvKvw3m5Ur38pP5gDRaAjlakTtTko3737kH/Zlq7Hdbm3oB+uatbUNlJEPtj9ArSsBLJV+VApC1TbRN/l8gvIXx6jGU/BdtSMUo6Gqa+kni73Qo3XSueO4qWolTRB4mKaJYHu/PxBarrcAaZ0dDlAsv7T7x50djmeb4XVML1rB5IFqRJ0JjpEetLw0aKQFb9jGmlMo3qDcduUmPFwdvqj5k9FOY8RxJQ9ib+jktiS4vPScTVyE2lS/kkuQFirbRfU6n+veoYtZdQk0SsxUfluA12hJFiSnIEAB61KuISBtqOkd6AQ2TntNUN1Rk9s/MKmRvMA9UeLq/OQsnfYhy4BQnsafSkAVwIjQ12W7UvkHuS2HAkmLVnCft5pPyFApe6aCDdw4HA5tg6i+gdXj/syXpICaBbxMDeGYD1kCxnhx4xOh6QQb10NjN1e9jYNUmVBnBuRDogabPiBJMqlhZfhdh/dzXEuuAe4zDm/e8OTFdtAEENUECZana5rHB4mc2G94RicvavowN6/NCd0Tvob4acp/OZoxCp4L2OeWfG3YlUSTSBwCklIxj1xlzPuHiSGsVxk1poaVvkbpaXHMJfSMeRukPM9ZPiKLxEZ2APj5smVonpXSWDN43/u6JjsJ+KY31nTYTON+vNniAs+3MwzrXbmNa7HfUcncjWtWmuB6QfnadYPeB7i9lg/OO2Xikgo7AGusIFSW27HwX0b1x2EBodO5+/0bPvbzfrJpN9u8n5XQS400I/GBtzlQ8/Cp9iqC0wrLcKokUBWFr2x+j+eVDudQ+mUjUYLxl7c07j94MDdtlxObrCLgO0iR7Ys1LNoV5uSnkkZSNyd7G/d6zlAZ9kuaK8fTzHOAFz/mY6tWjK6czmtwVBXxZy4Bek76umLZQY0aNWrUqFGjRo0aNWrUqPHJgWWAB33A9/L68Cwt/puVVm6TkaBLVn3yeQkUCs/FIwDfNAvYYn8ApviwP+2PNh0mmw87W0W5MR9qxFTqtUq4pmMq1XQBCgVJUNK9BhxpNtY2g+0PMz/YC2oBpsqCYBXqXZPvM2grm2QFsCRAEPxjqXbTWjv0Nh8bm/x40cgpFKdx9FTazfCjPbKJlUryW6o8UdoNmAYQDNhFKOQWJCxxDyyT1KK5gVNZWh3l2Hx28lmWQhN4o/OGfNgHbAWSoo65gQwgCR19R1K/lRq50rU5z4mISyuKmC/2EXrBpW3Kkm+WW16wXw82IyyA7WuK4oCOCx/W8Dt16Jf8fstr5uYtya5i6Xl8eRxpRBKY1VFxRrzhuJfjIxAKZeICVRcNuwjI3GIgNfHLktqFHLL0hM63QXh/p2HJiuVFcsCPPYHKla1oLYOESWvr9ei2CrKwgMVFwFlB06Md0ICSSRCBM5T04/np2Ng0UPce/kao7MpmKfinBMgXiSO318Dz4ittacrkgV+zsBAJWwsNYZgipJtoYWMR1yjSYdKLxh2QKzZSwqKA9WEvEV8XJkEO+hJMzi9N17q0YlhoqUMNzFxa+Go7Ar24T/Jqi5f5ehGn5tOFimRYYmBTTb4P8IBxTtovgSzWLKyTvrYicYWKj9SkTirmaHRY+ucHXE7T1BMNqWFjG0lLVXVkOuoEuMjU0D5jlo/9cUYjwIDK5drDmo3F2hRjnBxh8jTKFh0cW7fgiBNPkD9XvvD958L2YqkyjrW3UDMXSubYQalSzrYpRQNBr5CRX7Wej/chXPUWgJuNBX2+FEA5rdB+XoX5VEqiAYrXqFGjRo0aNWrUqFGjRo0anyhYDrEfnSGg6KM3rdtVLA0Gkmou1Hj8bSHbzPrVDJX50duVh4AEM60bAltL4Tsfjvby/Q+sHzprh86h7oIH+bHoA7csHgSiwrsT0CmUezgnACwolpumt2YF31BsV+X9AlWAtCUOyqAQAIBwrOvVXBAf8gFXAMUcQLVH+RlLCQo4kgF3eG8CCkjFKpCLZnwzPZofCdPYdM89f7NKMVR2fpwBsxzWRNl/gg3e8BAM4Ugf3LP1zigxCaCvnldm+wKykMPFdaOKMcC3xgfJBnzTNq4Q96vFq5vkkvLDLQwrFlre/LyyWWLxtKX8+M2RYKh/TZJDbz6WoM6pSIiUit3Y/5sAS4aImr4BpyKZUpy3w8RoRFfCqMVrg+a5YjlAZ9noLTX+ipvPlZOcRyeVsQfc4leHreGVu1BSF8kBceGAkFkRS/XuIv1zed8Wv2MGwhMT9OttbD0ONkCd6tAwXYPFBvI4xD2p49Z9oPs1J0GYeAE85D0YkM+bjflahGsclsvncttUbwZIdigXsM59i+P8ssL50lKjuG6+tpTQNquOC8XrxUqhsXbgr40t55bv//whiZWY/rpNls0il3dEbqjq2NPXWFUyABrDI14+JEWSgY0/A0zrvo0kDCC/9jsxmUf7HXrLd54UwFrZajKGDwfHHyrl1tq+14NrZLdQ4pbXmEfr6unzcaIK93Q62Pl4yDC6XE+8agFWQqhiecDj/tEeH+HJP8tGKewzqEJeVjXkpoLLuc2/Wbvwso4x5hg6IGfCwyEtEjO0Q3JLJVSIxPvQ4v3J5yELQKBkd890+faH43PMIW9eSS99bZuWTWwwW4LubOQRlQd8j0mLdlbvh2I5Ndl0i6S8NtaoUaNGjRo1atSoUaNGjRqfpMeyfwgnguNn4OIDa/6LnuTkk4Ci8ZLj+Ige8KFsKOUvEViAJytUYEf3gOQG7XxubDoc7YP33rN+aG197SXYw8iv5TEECGSJMr1B8QHegQGaR+HDOXjICo2oehvYwG8gXgUAKb2e6ZNcKL8CmvOBxkfDwPJvWl6sjgmkQAHcN+776UB79m2prFq2GQkso6HeCd7KZza3OjhY3k8H9M9zIFqCmPArFgQDgEtA2RXFPAd/LbERgY+SArhuvcM/TIL1qrG9me0Kj1eCz5ADYyzo64mxFBzqoS6lpQYAhcrwlwpp7GcqVIGlijpGs4SeZZKiVBJ/xMQMRks2JkWsLD38xRyuDF4y4PNJR6hXKjTtdRD6Ybt2YB7KZYJpp5OOBjOITxTrwmGiUJ6Wofnr+wd48oZ3CWbK+ltNzBKsC6PdAuDGDsOWJS5GSFZZxu9zKgY09pHUsxmX8tVuN4FGm2x82TW23axlpxBNGdkMUguAti7IXiZ6YPURjd10LzhsLtSaLPlHI0k20cy+4JqSrgaOfAALBhprCu9seIMvJMJuL5DJsjfUJMB23+dizMJbPmP4sJxxxetCgVwCvrjWcd0v5nUkOHxNSkmAop5D17ucjwDbWX8crJVzIixMIgGVjjbsE9xO4zJTwGyWK+9jPXbwOoxQjZvNB1hMIJGke77H+kZo3DlUzpBYsBTXtrO2H/To8JCaOal6fX0q6xy4/s+zEnEAy2eAZWyr92MWyI1umNPhYPev7u0eTU5fPdj9/QObnE6YLwV2jUUx7juOBytS8hKRvrKxa/D3nCSlDUXhq8/0JZOBZ4HfDklKNYdlYoLVLIVliQ8T1+qwSMFc55wFxg4A7tfH/Y8B7wGXAZbxlUDcHdDjnkrJMX+PSI1AV4175GdrGNgiZYsTjUkkW2vUqFGjRo0aNWrUqFGjRo1PDCzL/qHwYnUVrosXF8qnKLsN0BLYjn9LyCQ3iYoP71K7hVqvUFRmDR8hU4MKaX5QdshAlVmognE8oYaWbi+ZIrBk/uiQABAgQxepbAtA56rGUDDG0aeS+PDrpRIMIBwPV3ESjkHBrBBMiZJ5DJZbaCR1Z1EHH23TCLrg5xxgG+AY/s15HON6SMen7cE2IwtlXSV3AQ3Cc/OI8nJXVHYASN4YsZRDRsOvAI7JQsJ9fmVhkrFksZfXr3qCykGLS6VyGUvvZIq881+K73w8C/i2aDjncyrSDQmyhgy0UK2WY/rhkKUArMk7In7WNVueRQF4C2VzNE68jMvfleNS/u2116a542r2tM+wT9B+o7Ff1sVmFW8+r6T7FiDm+EWDPned9XnVto31PRSVYYsiL2F8K5Wx3/uL6xvANNt3AGzRA7xICszznL4SLIe3dTH/S5VxPvbcqI7TwDNigtOhfM4+3EsbirA+yTvK8yvWstDN6vjTQfG2ifs7GPby+ifFenkSr80YV6vGOSQv+rAnydeT6583DM1gNy2mgpZJpRvwNm+7nEZ6nlS09LkubFN4/ztQLcF/gtas/IB9ScMkH2AoLTEKa4/F+4GvmdyrW2fQz5ycG17O+PvMRBFV1p784Pc0vD+luQFv+sPhYBOanMJj26s3aIkStj9h++AP94/xs8Y11HhoPGPhA6wv7rlzqQF3j2qq/AWcY7yUB80K9FSR4TYs+aKUqVgfT0BjgubcuK90EeE+YuLi3vL5R+uSdJ6RQPEKjPh9rL0X93wxnWrUqFGjRo0aNWrUqFGjRo1PDizvT3MCEwhZYXijK+crUArigzUVgAl4ZDGVIIMQ6DnKm0M1BvVj1/BrRo4KQlgv1UepszVQEkMBJ8DMD8/wVQ0ZJxvzoahZjfSi9B2Nnk5oANa31re9DR3AGNRm0fTIwYnDaJyXQIWaEzaEHKG4FOwAsFqtAMEEl8FJAFggP0sl7igZDz9OHCLKnwkfToQhLRucuVIUAA1g5YhmZwc7HidaVxAPSB6cjFrD7uLkemFci90RFhdSL0NRmiCuAw3B+ZPNUG5DveymAWsCw7N1jRdlOxWBChUPAENYnxBoHKXAazsp6KBT1YWK2RKwroTLl8DWFdfxipANpknz5h9fj+JJgXIXfq6OlMIGYuFlUGLXy22W4SiRUM5L1tmgD9dEanqCsQIsLxTHF/YHUkCGqvDDpYKl3cYCTBagNiV6HM4icD2kOhVYjkQPQDDmAeY1oS6f6WBPG0ugCSDXvSZ4rwI49rC6oEWCulqux9Y2687GAfMd9ipmpxlWBlJfinVnBaUaFeqePeL3YW8BOwsciSuYcU6wgMlq5RPv0W4ICCll8TxJyZy805HY4b2Fc+98vmvdQDNMbvMiiRGAu/TmZZVDhwoC2PIgQVSsZXG/8dgExmMMXc+c5gzHKtkn5ATZYt6lBE7MDbegiHnmYFeNBrHGRXJHf80wNCwjioSP3/bIm6Uj9GMJBW7qqxfwsoXVhWxJ2LTPYj3rEjgOWyE8H0tn169szXkw2ABLlGGk3zIbAZ6VOMA6ikoMSp/jwFypjvNAIzr8ap72bL6K+cRmfk1r3bllhUnfeXKTFg8r2+139sHLV/by5YM97g62QxM/rm1o5Bhe8lhw9eD+WXWBrUi9z3cszG9vlMl7ubxbI1GDex+2F1SEy+ap43tCI3/naIzoWulwMiagjosR3tXxfjhjzW4WVkX0usY5t5o7+GvrKnQp0UnM+V6GuanrGUlfeVPLsiga2zrET0k9rA1sl6uqB8rvc8KyRo0aNWrUqFGjRo0aNWrU+ETAMkBKVgELYuYGV0syF5AqflioCtPflo6kSUCWXkOksqCRUTYOywh5RmbrgWUUiqykVs6qxCj1F/wrVJCJUZVqvtQCLinb+G0BUeNR+tMGsNI2UR4OrgCQEtghypjL85L1BqGBA7XYRqgss/4uoGwePzykWPZS6/B9dQXdoqqeMFrHSWgBz03Ci0C10ewrQ01w7WQU+iYiS2BSSBQ/UgaXjz41nHJVd0iU47wXEuriGpeXPalFy2xFcaFKMB2exlnt6a8vPI6LYUrXM9t35OPPStd0CsttfIwa8xIYf5jf6XLe6ViXZ1aYiCRF4pvPJQwXlrh7qYKPcwylqXyjBaMtGkgi6eDl+UFL4/6Kq5sVqq7wJbjN3q6lyjgrWlWZUALfMkGRvbAXs7roDRf+ueE5Hfdj0sh/LIW67tXlGpdsNF5f9gqFa24SmK7ba6O8rCRIa0/K6VxOxPTMNyRZYruFQtpXv9TgMz0zzyNBSXgHh9VDXg8Tqy7mS7IqoWo5NuJKYIfUsA4RGAeMLtf/rF6POQFbnTyQUeWSVeXpHlucr5IK4NOYI9M0y/oCoN8VyuWj3K9flqTuL69j0iFHHq64V0KNvqxAUI4vYC0rVorEQVRx8P2yeHMr3k2KeaY1TzDfk3WeGAoP9rh/9ZanOY3zy6kznxvJdyjP23itRNNlgupD3j5r1KhRo0aNGjVq1KhRo0aNrzVY3h3gwGtU+gIoQR0MxXL+bH6JqrwU2VW7WWEnddUJ6i12tj+xOR37SoUiED+77yPKrwFh8T+o3qZZH9qP05EfvgG8m+h47427qBCG8pFNllTiTB3z8ciGTy1gCBR0hKoq/aby0YFV2AhA9Rj2DfwSDQVLZS59a4lopYxjQ0PsZ7Zpv+d2OvhAQ4XW91RQT9Pe5kLZCeUjFJXYx+PDgzelerDdI5r3HRIYgWoYqOK4OtkpTK+pOGvstIK354nKPSjdnmwHGzF2rRpECboIQcCzFOrWCcd6PBEOogHbuDLbnBsC55nqOiiWpVqGevG06nTN/NwJ/lwtGw2rLolzAmoBc8u/BSSFxyt/dq/ssBZw2Ji3w5mTN1Go8EoIlcGyVOgJjhdl4No24I2urmBmhqJvYo0fBn6zt268duHvUsz9AvZ8BHAOmBpWEaHkzarWAjS/BvFLaOR/LJp5xfCG6rQYtMW1SfPcz4v36WHiua17+L62NvaNbdc91ZVRop8Qp6tnYVcwTcfifKCcdh9zjE171utdkQtbg/hKP3SHa7i3zgZ7hACGWQ0qsCeFftND36l7HvMSFQ4AkPBtDxiHtetSUQ4VtovQnTkuIZ0SWRgLqT3pBh5LQZKnRmSISIReWMWkLXpSQ/7osqDgUkfw7iJbn7dZeSrgjuD1S/Bae5MNB9TWrjiGohprKaGnXpfO122AhhHVG1AiQ5nsam0IWL1ZnRKJUvkOA1TJaMTXevM+r8w4e8UGVfGt9f3INUbVLH7M2C5sjPgSaJGhUTdrT9pZWC1BqQw1saB4buxIv3w/y+kw2eFwZNO+Vw87rnkTwDIfGs/Zfbkt1mdCWLc44lyTOlmpUlctxy0bywkLRAp1N+87+R1r7JWsmJuVFMtsKCBfenmfN9ad86NFrwHcA7i2eIG/L7VY9+ijDCgPS5GW3v/8ftD7Lfz88WS+73riEeeJTKAqaHJ29FRcE9wTnBNsQOvnEx7c/pbHtrNVsVyjRo0aNWrUqFGjRo0aNT5JsAzrBH4IRqMhQJuF9/Cl9tE/lxfPWQqbSxDn/sH8kO/P+RAVVbnN5GkbAGgBD4sGUVly6pYZKiV+XUV2CUZzAyznoq+faiiHL1RiAdFmBykYZIEvAZJ5zs2u4tCkVD4RKnO8AcTQyIrN//KRlT6vC5U3vWqNpeDHsiEZgXAcV4Z+bNp0PhLqUxsudwOCZEJ4FxETuwRtITcpRikdS7ZQeF2lfKkCXP49XhcuvSk5QaCWlZblVXpd85chblJAvwkNLxTHhSrd7VPSkwqyLLBXqmUvjj8kmW9UkpZjkpst5kO+0J2+YV94hD2ErDguTqdQ3Cbva/7SwWeRBImv0Tgt+UNfDtHFDzq9sBUIz1+BR9ojlAd1YehNiHg8LvyLaRmTzin7EWeVf7ankGIzdJnRKNSvVfI7D7WyICDV+OSU2g4baJYJonTdluPu+Yw3guX4Pl/vN1zi4mIkSFnO6by118aX3uaF4j4pitHwDRUaTN7k+UMTB/eRXpRURIPDAIhFU784vtQE0ZvYwSM7LDGkWs7bkMVRHEupWg5LmaJ2Ink3C/QrqRf+49mfXZUmAuZJ3V6o5ENhHet5WkOKtRJrI5MW+EoPbq8DKRIPSR3Nxo8+thyrEvaXj8VypXP3uR73qgCzBiv6qcKKgg1MaTORvb2ZLEAlCNZQ2mjITkNA2c/Rd6SmkN7UMr6mR77vmXj1ipR4Pyw9/y9nWbomoaIv172LdbxGjRo1atSoUaNGjRo1atT4xMCyPqquCC3p8egKuvOboJh/opUPcwbLwUGSYqqAZxk8SAF7Wp1snkGH4K8qlRgUXG3fWYNOc/xgn2EdoRJVXYIeUmWGOivUZ9gO1JZQgfXWQoELSB4exw5JAn6wHB8eyvQqzh/SpfYTPBEAkzouFJL0c+VjtqY9W9vDyXhlPRuQ6bVQbSKkQJTyejXN9urlKwKJh1cPdtgf6KOJ46bCMWGXBUdyFTFUgGc7AGafzPaH2Q5tw4Z8+5VUo7uDq9x0wWw+nu0AP9vzyaYVGmWtbNu30jbinJhIcPWiQ5NAX/jbgeXnH11KHU3UFs8JsJFKs9s3gl8CHVmwJlGuvspbNLafmiLGkxxqLSFyJBpePz5tCvMGPqtZ4UsUVIC2XMaf7RgC4uqX7vEdqlcZaicY+5XYYpT3VPb/lbI14GGOUgmtc9GvwwrAvWX9JXFfxNwvATW9z1P6p2i4eM5JhqFv6E8OBeuA+9H9X3msvNa6dwF091SXHjjnocRM/uS4Rx2+xb20PNeLsn0kaqay8V5u/gnFM+79gJ24p3AP0mfX7Teyv7OSYogSYpeJl8yO3VPdaSPXpqQSLxux6YKHj3KaGK4eLaldSqSUDfTcqoYK24t5qc1oLVokJXItQN6ezwuspUgmBVSmZQn86wO+wzvdPXwj2RMe9Q2u6ahr2rmKl2szIL/PihiXWMewNtNn+oRkQbZN4nrtthixVhrsk2OtPh/tuJr9GkOl7CppKIIdwp4beORTVsv3BPx+93iwh8e9vXr1aC9f3tvD484O88Gmozy0NWUxv1o1P/U5wOau+H2vCgC8D+D9AAtNWOmsGiUhWrzHuK9yBv0x4roHCZMbeD/jGsmnH5Up/JtbWND+iGuLQ2lvOjufT3ZABUrKb5ytW7XWr+Cv3NmA9yiqmKU+j8qH00KVXSZBpbx2RM2DjfdZJgxbVIPkCg2B5jKZUaNGjRo1atSoUaNGjRo1anyCYBlNmKRyFUCTxUXhpVpU5CcVVFm27h9foYgtwWIh+Mw/E4wAegE8oDTby30BKfBB28t7i01k1ZZvP6v23OszNdMSVGADI5QgU8Bc+nAS8RBGCBifuH8ev8Pq2LZO05XB3hCNZffFA8ANoAW0N0ZM5eYCJgEKozx897jjMR12eztOUiwna4WA8dG470J5FmX1gEoo/Q8Qh/JyQIhphtINvxEwApeAJ/MMmxH3UB1NYBmwGV9D+RgSzKycFuw7X5Blqf3CI6C41m+SGSclc1YbX0rDw0ogWW0Q6F02vvNEBQFdBs5pYAroVm67kHHm11wwlssGfK8pR93eIZ1cqTiOdEwoJN8AlrOKNyda3gSX429hD1GOYSiK9aMDI553oRANxWKR3Ck4pu9rqQuX5YN7ydKH2xKU7NH40sv2A8uRKusVdjwqmRGqUjbfW/XpHo2qAkHlDPAvwXJW9aNp5DHEpg7rm6LhnJqu6T6W8l/JHsDOZRIi4Gvp6RyNNktlcwDK3MHNFbhxvovr5PMwVM9xn/j5JF/vWK2SD3IxWxZq9hDh6wJjLcp/LuaTz71UreFrUlpT8WyuNwXs5JoaStzYt3vWA3BChc7mp169MGe4qucExMc2AT7xNuLrWFGlwsRBU1xbkdHirBtbsQEiU2tqvugWECqPUPIDaxQSa0gUYAeYV/vdZLs9GvbtbH8AVD7ajPXa32/SPRTXzI/JkTDXf1gdtdHkkX+OxrJewkGwfHHDhtLf5y9ejMPF/YB9AuDq5b5eoACFXNvv4dURhi6pGqRYsZi865qOgLnjuLqKOam1S79wb4hI4/t8D+WkkxJEca+Ekj+uuRTySnSW6v0aNWrUqFGjRo0aNWrUqFHjk1EsB/RlMzh9GF021MKHXvlAls3Qojw/wTtXUAVWSQ2OQnpbiJ5fKxt3lZwgc0fVm0qjXcF6bggC6ElJf1BqrN3vtbF+HKwdNzYOvV1tb20YBuu6gZAB4Bx7ktrxYPNxcgiWj4cKvW5lHTxm+54+oxHyAZUvM4B11/Vm45nHSM9THpMUuMK6ja1OGU7A35dwp7APYAk/y7g1WoAOVPS5CjaNvSvQqIhz9R22r0ZjodQMcJ6hltCCUdU8NScbVo1tmo7A+QDgnJpEhTp3JUYSkJXl5xkEJluSBOgC+eeEQ9ngUMeWVbgAhFJUOpCKVxfM9nTGeQmyBTBMo+DXUL/W94CnJVCUcrlISDigy835ioZyC2sBWTeAiwnKhCFukSUpFPtZzS/l6IepAj+OijnfX7oemBucTwm8X86FsPWIa7BsDpf5ZDTNy8cR6uYMeF25SW9kAMeV9X2bHrgHaPtygLr/bPv97EplJDNmAjABRiSE3CogwWslnvA8ANhQGjP1UVQESCmMay6ojvsvlP/6invP1yQ/uQBpGKFTWMMgWXLCfQWv5nydfHIv1p40ORJgBmgsrnsxqK/nIwJSexPNWL6KJEDpvxyK53Le4Vt6RkflRWGH8mHzJuxsFndFAZhDSavGi9F8Tkm0gNgYy95V6PQT4foTc9XXl1A68ylaRY7z2Y40ypdfdXpe6frjPu+yg4BgGasHvIAxD6Fc1loU6l/4WeO9Rn7WJzsedP/dv3y0V6/gQb/ndcU1JXjnw1XoR1/zeHw5CbhoIBhzvGwOGZmBJAZ2JbDbrCC5OYyD5jOhNOZXYTvjb4Eh/D/uYXGk9ygkCQ/TyTZ7NRwcD3seO+c//P6blu9dfdfZ0PW8nkpq4jx0rKhGaU5I8rTWNz0hNN8PfL1PqvuLCOisc5A/NLz3+Q5QoXKNGjVq1KhRo0aNGjVq1PjhAMsq1xZYy5aY+VM4oF+US5ci5FxuK0AHFW8YE+AVCTRLDJehz6Uy18ub8eG+BMuACyhhZsM/eBrzg7MwqHiJYOtmXNt2c2VDN9jN9oZwuO9GAgKq4WAncZxtv9/ZBK9jLx0P2Rr2B+Vl3/eC0gBcfmhQIFN1huMidHD1mpd9E7I4L5Gf5sqaBJZdCUv1tOAI1MYnNBxjk0Gp7AAcCM29BD9UygJVoUT1kncOpMaHYlL3afVhdLWj7EQgSNyvTja0Ztuut+PJbHfEfk+G1oHaB66Dni/lsmBPbmpWpAJKdkxoU4JPR76p2V+o6KQ65bh1XYb2Yafrk+B4VpNCKcRjH17qfhKgAZGJsVm5QjIUfRq/gM7ZHoLXOQDUAqrGmArwSCEYze+K5EfyjnX7Az8upjcK0/CPC5LfFLJ1KJTNPO52ObrRQDCpF0uLhSWAjrmQr4yPZfI49qQG4ZWAIRqMRRO3fuh5D8DyBWp42F887ibatxzmyeYj1P64b/yedQsM7qfwMYayfzoqmYL7Dsc0DiOfS7CM+5JqVNniCHyq0Rl9gv0+w4TQ8UstyyuEawdlPiG3oDJU1BqNAhKX4arlJXBLUk+fwQGBY33xlzqkxFGHLcNCiR/zIIF+qV/5Kr9OSojpu1B4l57UOdlRKKbZBC5WWXnIh/9uJFiIHx2QRoM+jBnGVdBf1hW4rniVALwa6WUf63hobh25Rq3sOJ3ZwI5rIBJpvJfz89NwutpWQnCcMxpC4kCx3gp0ngiW5YHDkUDiD/7xO1y7k33wwb29//69PTzsOHdoWXTS9eW8BbCeBZf9iNAtz+9LKKexDymow6c/y7t9HYgffd5GghDr/7ge/b0A7x2tzpfjhr/1C5A+PRw5RvN+ZkPXwwENVlXN8nDYsXfBfrdzqxNdqM7BMsDvqlWiB81ecT6AyvBrpnMzPbJxb8puo1RmXwLySGzwnomJkpKCZSKqRo0aNWrUqFGjRo0aNWrU+ATAcoqy2dKFcW6oNEtf3bLUf9GRPhTKZRl3AiPukexlwGp6FJamrnYLcBLNumDzcEL5+8na3pseJXqCAHyG2rK3AY9RimOpt8BXAMOONs0TgQpAViol9hAsVpl4lPAXer4MzumzyUEqfCwF5JvViT628BkVq4zx1H+Wnq8CsFHknordw2YEYGcFb1SBIyjWYsxCAU5bkaLBoY6xsAZINhJe3u3N+wigoIxzsO4W2GraCJBEL1V9rzFwOLaYK5ffOshNqrqsCA7o20ERDngYFgsFIMEX2G4fzw0TATh3jZH2LY26TytsM6kty/8WMzbGoZjDoaD2qZmOOTXvYom8j29IsQs1fqHR9jNOZD95Laf7J8uFX1d9OzhebL3wqqZ6sXhdgOQA/lLEXlqG+DnmCbf4e6lcFoj3vfuNmWwU6K0rSeYJXsrwFZ8AlE+0vcBXlOnHNU4K1yRbl9qadyUSE96oD69o3VKAoLNpCIHDJ7mEuzoWVypT/RoNM7WtuIAx/2mL4arP2NbluISFxEKwGvA3JkRckNcm+jJ5svQdL/ZTLptxM0Sixa97pA8AENMG3zBTWcFR8m7+LpPEEnbz/P06JK/l4n6VxQkUsyVud4/psE/xJJnU4FldresFwNpZ22FdhZ0F1Lyak0psYP77ofGgI+FTWI2kY/XmfrTHUPUChibsfGCFcZgAmR0klw+3X9IDm8Y11VruPQOTHbrE+QWkDyuloq4G54MkJqFyr6+brcDyMKyZlASIh/c/vKnXm0H2IVjn4bm/VoJw2k02ewJm3AmGj4eRXx8fB95DrFA5QpUPn2W8d+H0ZZzRcL2T57MsNFSXgPuw5zEqMaBxi8uTk2V5jvi6krzkc6VMjRo1atSoUaNGjRo1atSo8QmCZX1AzQrZjFSlKpNnK4vFXZFchiwc4CMp9aNYgmM4fqDXB+Nkg+CAtgewcABLUAg+UUBO+aw2tt8fbL878MM91Mtd39qw2XIbKEMGSQCwvLlaUw15e3MlxbNbHEClfP+4s4fHV7abHqUgc7AS5w7YuR4FptH8j96iUL6dBJvVvKm3sR/sfJ7ctxVF/GrSxPNDQyn36ZRn9UlNqhyesbSb5d3hhyxAGKAQsIkqbfggAzhAyeeNnDCa/QrHhudKHg1gDpgh4C8LDarWcPGPjbVQwFFKDQUcGiNKMt5DfghyTb/VlR0IdGCPsbID4a7ZYW7of0rA7CYK+Aog2YYlRvjVFkwuN8Hyc2IJ+MrWQ2sDLRbQFA7noWaLsgoRFNxNnU0zFH9QyMrSIBTLMSMl6hTQItSC9zQOkn90bFc0kNTRBGB3NWZSLrvHNZ+EZIUrlaNpZPikus0GFbihnI77g9DUfWabwiLGQTMRbbIQAf1y65HCGzqgtVC1w7dQkDv0k4+tFKiclw7ANTdLGlpgpIL061icQCUDXv+bzxuAX1hgQLEMWL7f7W23O9CaAGpSeN9izAMM4+R0bFJcy9/cITzHxXgdAdYA7cb14AkcLU3TLEUzldluM8OHN47Dc9EoMzzNAbbj+kKVGnYcUItOU7jbRiIrEKqSBwKgoTiOuVp4IRfDt+TKsU5cDq+/kjAz5uYyXaBFNYCqVKuyTlGFRAbd+fqpMkHXCmsjfOuTZ67bMORWbmqmd9qrOd5qrTGTFQbmYsN1CmvbMKCCAyYLOtJolchlgw3t3IaE94dU0YKgnQ3rtY2brW02W9tutpJPu5KcCmKvCNGbAdatJldzpHkIy4hIMuDem23llSk4lsc9qkmO9uphb6/uD1THYy2gvzI8vaFGP802o+oj1h2omOdJam3aOMFy/mznFu8leB/SGtMPg89xT3Lx2HG+guTdMNiwHvm+ArDMuYr3l66zzdWVrTdrG8fetlejFMZYAwCW91jPzwksA4YDjEPVv4OPPpTLrx6o1D/sJ5v2sILx+ULrFryPnOxx1dE3/NjpvSESpxxDrNnw8qdyW80Bo6IDfvrxX64CBVXPAN6l3TVq1KhRo0aNGjVq1KhRo8YnZoWx+EFqp4xIlmq9y+dnjfLSj7RwG4gnLirOsy+ugxRXCoedQfkAFEJJPlgWFMeyRBUciU3jA7VUiSqdJnhFy7yz2TTt7XDYJ4Xkwiu4aHgGABqKZamW5a0sFWZukEWRKgEHNuCgw8uOAxFmWFV6E8tKIamMl74SCcBIbCbVWvRmi+J+aH1VHO0Av/R5JYzS6+K5ANuALKXCjfAJ4+Pe0MQ/MrulOlQq0VD9FY/UvWw5A5ZSzzQbVJ7tYwqVHuB/h2tcKNy5NXdKeX2GZZ/gxe4KAHfJ/UJpGd62ZciiYAmcF38v+Wx4M8f8jv2WSth0XIUyvbgmpTJd86AoZy9Hyncs4BSIuQTSC6vliwjFcHFYPrEXCvLi3MoEQID6Sy9hbENKy6PgJXyS3QpCNhT5DJb7VSSQ617dTKD4/ZPOq9hfeA2ne8zPO9TIgJcEmGFhkyobSnuWuFYXKm6fl9kZuzx+b6hYAOUsso1t53mdfncxlxbjn7Ny6TwKi+40VstjX45mAt8Bv90iJ6vp/TlZpJ2rQrKrR2ruxgZ/xc2U1uDYj3vXxzVKvsOEnAE6cxO5mOayVMlndl6dmLNCs7qLq7LwyV7OG0+6QSHPa51VykoA5THM91Cslb6m+T7RGI9N8jqAdNi5yNqI7y1uW8RqFnztB8Lj3sFyS7C8drC84Ws3242Nm5FgfhwHLootm0aSrXN9nrrJZgJlWcYg2dmPPc8H+0RyBEmaaQBIDmskQGQlRtmQssP3OmcmVmAX5Iscxg3uTacjkimoinGv+hUSlxdzpkyMXBYe1ahRo0aNGjVq1KhRo0aNGp8EWGaDPEKBVODuJbnOFhMbk8ItkAHh5cqsp+epg87wV4Wyiz6XUDj6J97mbOiJhJ/xoRzKurFvbDO2tl4Ptt2MhI/Dus2KY/gEz2d7uFczpGG7sv7YWzNu2NyopSel2WHa28uX79pj19nhcO8KTwGA914+2KtHqJb33hzOJdGuOMWHd+x3MwoeDMPG2m5t/bChkjLsL+hD263ssJ/t8PiK4wVBMMu50cRpBV9m9ylOnrxnNj9DzCj5Rkk/IZ0AcygYKROG8jrAsUxNqSxezQInPew+zo1trLOtdfYI+CJLWTUEa8yGFqppswc0x8L2Jqe22Pyoa7vuBqqPJ4fQp+PBDvsTpdLRRMsvFx84e9dEZvgVpqwBpT0iUUB43LS2WQ+Eytfb3taDmsFBaUj/V38NAAxgCxor7qE8Lbxa49ySlQEsMhK6VjctAeqcaHgNdAfY5c6yl3F6qoO72F5GZa4nddBG64XGrUyoWF8mBtRMLMAXjkeviVJ8J3F5oLwpG7avl9L/JQF3jYF7bkPNzhsOvsKhhQ51dvbFLRXKEdFUL8aQwD/90f172exPc/Ds6mAqLaFmh1oS50enEMho4Smuh/bolhWY31CBY8oBmGGMaA8TmQCN6e6w01ges29v74paKEPj9Uwo4RgOsK/BcWgs0WRTilWBSkHDPlmFlCythN0JiCaY6nAyLkkxJeJ8jrixmXRSU8GwWQgY61mmvIGyKeBChRw+uUhE6SW85Tnkrm73a6nkke4hermLYnslSXiuh/IZz+1o4xNqZVWQaHLh2UPb2RpNSfHzjIan2jYWYax79PPtW/cc1vUA2ARQJZjFPbzubdU1XNMxXrB1KG+gpLkHdIWFzrnTNh1Ma45rIZJCHQtna6tzy2uLipSHh8kO8FqejjYdjlxnp8PJToeVnWccO77mVCaqSjAHcS6jJ6/urq9ZEXF1fW3jes1mfMPGVcjjwGMB9KXNhPuIA0IDPkuhLLDcr9e0yeBzOlz78GA/23CG1zLGaWTyUY31fH1gAgTHv+fXx4dHzuHd/aPt2ZAQqmYlOI+T7J0O+4P7MHulBN7XAJrxt4P+9nB/bxPWxwOSpFIw836lf5FWZ/oz8yYNSyXcgzmhUqNGjRo1atSoUaNGjRo1anxCVhiF4rJQZkYzqdee7Wq/KONnKW6o5Rwqy19YCl1qHAu1MqvoCQTxYT03DgtVK0qVAUz847J0wWwAJiuJBsou99l0p1n+/jAdqFZuVvjQ3SSwvNs/2n4vxbKO6nVlbDQwk2K51TEAJhGCyAYhwTuUMR/VAFDQUAo0QmoCyqzWjIZvoWcOtXIhNCxIVoxllLrrK5R4Kwf4erR8TIAp4UPrVge9ewj3q5X1MM0AbDhKuSxuCVglKwHsHKinM5TIzwRd3k/KOiiawRAJlU9qMvgmPlF4C2fFZFZJ4roCGMYDSQJ49GqKCHRGM7mkUiwV5XGFfMxkM5F2nb6WdgJvnrX5igcES5M+5nm5vYVNcmGlUUDnhfp2cbD+lWAVNtECfQu1cjoPh5vl7ec7TofnQP1yzEvlcbhbaBcBz0NxXWoaM/Isx0/sOVS03mxvnmnDwnutoNVlg7cA91mRK9hLwAbg5U0tiyFRs74jkipeDeDezuHxi83SlZmNHOPhDTCxdhReyaFexZoivhvq42WlgFh/Mnq/UC3HfV1C4Tje8K0tZ9abFPuhlI/t56abGS8Xk+piqpRq63RtS7V80UQwXpnuteRLHSrn4jVcb119HCcS3u5RKeLWJkm57D/DHkPN69wmJiUNc0+4sKqJ45Ri2ZtapunqamhPBLCtK4CsJ6awCVitAChLtRsq9bjubt1TVF7ofLFeC5Kj+SmslWCFBIC83mxss1lbPw42btYEyADlbGC4FliWFz/ec6RsBkCGQrnB70Z4Lfv7AHbChKT8nGP/9F9mI1KoiwHUpWAWNB4IlrFfNOcb+s52rmKGLQY9wydZYUxrPDe895WoxH7xu8fHR1XauOWHYLGSAzNuBiRuPLlauOr7xfH38CShr1GjRo0aNWrUqFGjRo0aNT6h5n1edZs+sEcjs+AD9F9mRFMnvQoeoICdCb5FeTpFuFEKj6NpbNU31g1Q/QJU4BP0MdkgADKs2hOBctfLn1Mexyt6s97cXvO1+jDfkjBRIcqDbqR6g6ILat3dg86nHUg7dvC/hB8rVHYXTJlKtPLc0cwJ3psokx6gSAMkkcoWijtAMao5DwLLtJs4NdagydIZMPVE2EHI4GAd9h0cgna0oWlt3fa2bQfbNp1tXDu6caXyFlCjbWzsRxu6wc5QbvfwKW5sC/Vc29ozQJO+p2/0o3t5HnA82EcHBd3ZjoAk05ykl1dDZ2s0pAIk6qG4A5Gh3Nruro62PwAfO6zA9eyhajb7vi990d5/eLD37h/tvYcHHzMH/t7EC3MAAaUkvsfXkVCmsSuUlNO/Gv7K3hSOZeCyNwE4YnNGqpRzY7k075I3btg2xHXLlikX6K74z9KqQa8TgEl2JCnSHha/i+PNr3cQnOwaXqft6W8JWmcYxfuqAMhLNwVXpi4aD/pxeWJAallcOvmZ923vyQmmLRzIliQ0pzB02aSsFChzGAhV46ljAgjqytX5aLtptr2D5QN9zEM1q+uL5AQ9tF2JCzCYbT8yLg97AiQ29o9wDjc7TkooqFmmfHzZXC6petHMTWARx4v7nQr6VurLSD7It3h5Dfk/Jno0DkhIlVYUOH4kY7hUudK4tNagP3roggurCs3AEsznBNJrzfXi7H0tFTCUn3cwPoJcTxyxMVtag1y5XKjWNS+Q/MIyKt/itG/L8wFKXU0v3fdaGgWUI7EhpWs07sR6PLDpqUuktZ0RKt7O+lGKZXoRo3kfk23wdj/ZvEeSDttVE1VeI78nOceQ2Av331Zqb54tjovq747JBaiSd7vZXr33YPev9vZ4v7dpN9txf7TzdDSbztbMjbVHs8EGa1dnV2c3NmCdBCjuW7vaotKkt2dPn9g4jnZ9c01I3K8FmbGmj1dbwuJuUGNXzmEmSLwHABuMemNCqqwxlkc7w+N5Othh98jfKfmJBn/uWe0DfOZ62nGcu17XvVuPgsc3kx2hVIa1jPslH2d5lgMwh38yBlVQvyWEfnj1kvfke18cbff4qMdOKujH/SPvEXg3E8az0SmJMyszUoVGFSzXqFGjRo0aNWrUqFGjRo1PEixLPRkqsOwljM/LZGqFSnRR4u3KWvoih/+o/pIVpqGf9GZOahKlhlT64K5GTjhcNOaTWktNi1Aqj9cDfAASoASeKjqS2+jVFs28znZgwzcowfY83r5HYzDAMjQJE7wsrRFKcBfKPpZuw4OTjaXUUAzqZDWccvUmwApKyiGdnAHEtX1Aj+Qmms4fCl2dyNCpmV7fdDaiKVbT2uANmAZAaROQ7bvGtuNgm/XGWsDk7Zrg7WozUpV3u97w9RsHy4DKULalqwNbCYNSWB65gOIohQeoBrDoWOaNazGw/J1qcHaDA1RWSfyw3TAxMNtkzQpq8MlePsQeGqpw5cuqBoD0LPVHTzAuWLh1K4yeyQKzI31Ooe4DoJfdAaCkRMzud+sQNPv05quVlMUFfMO+k9p3oUJ+PcJmQl6kPmcDHl5CyuTn6jrnhQrTPZgL1XL+Y/nD8jXh5xs/L55X/BjAfKFTDcYIGxmH1LRKwDEc1VAs2znkMdC4yhpBQDs3EUxVCQ7aoY6EMnQ6He3ABMBJDfbcCiYqDrD/oj1hKuXXNXM7Bu4PwA6wDQ0iZXMSliEAmagOkPK0TGZpO9g3qwKQlIKa/6zKBCRqQlmt/UW2Ia59vveoeC3UxuEtHBC6GPCLS+hphiIBkZ4fa6Hbfiyvm/sh84fsSSw7C1wHH8PS3zh9n3/2/EMGywhy2mid6s38EqiWdzDXVR+fS0gdSRI29uO44zW9dUMvmxPwVDTYHAdCZABY+hHDGoPzzNd5LHtIWvEaepIHymb6usQAuT0E71dPVrFBqawe8EIC1x1shWYC5cf7XWqEB6iKRnzwD2qOSCKiAgNT5sxqDSarho7rJED4zd1GYPn5M1uvRyYiN1cbwvE13jv6wYara6mRe1XE4OhQNZMSIq6m5uh6o0w2CERS8bC3/eM9x9nQ3BIg/bTRufCB964OA+glH94o1UbNoS22lStbCJaP8X4F4BzvnFpXMd7wOH/1wUBbjdV5sod72Hm0Nj62tMhoH+RNvWMT25PhkmAdV6IIVSooHche1zVq1KhRo0aNGjVq1KhRo8Ynp1jmB375u5ashNjIwW2Ud8fv5bEsIIrgh1qqdKVSTuDDm0JlhWZZru5epg4YVw38XKXuw4d2emIO+JAM4Enxm0AyWC3g5GllJ4AyV0TLY1WgAApj/EnKaZXKByw8QWFpZ2s7wSF6bno5dPYADcip0m0BEcATHFfn6jU9H36fqebf7RP6Hp7CwBcaDYAPeLU+e3pnn3nnU3a13tq233LsRigCm5VdXUN519nd7Z1dX9+olPtKqt8NGky1ja2hpm5a2+3huYnGUbPt9zuq4Q57eFGf7NX9wQ7TMYEReFlfbwSW+y3AckvYgvMgAOM1aq3poUxs2LwKXtDDzWj/32/7AbPv/j579/2XbFU4JXWnYIkSC4KDQw8FakuwTGsRwC0oaV04J7WegLLgoSxNCArDjSI1biwsMRaAT561gC/xd/4J2yqtHgLQJhuKPP/CMmBpHRERDcMCSmajioB/cSw4n4gAzSUwzk3aknlD/EEl7K72TlrX1HQwA9JFw8HECDWh6Vft/sn+DFd2L7d7ypkYfq9XBfKEbYzuSYhEcb/Az1j3mLxbqY6G9ytgpDX0Ao/z08WVXziVsKkhIdYNwS+V/musokEnmqsN/UCLlGgsh7ksYKzzgVIWwJLJD4BGXmZBSZwHxgvbD1/3PFB+nRdq5Gwv4X9NADq/Jv7ilhwJlHpyYZGAkJI896mLhJLgMa7JMfndekomLEQKxXs5r+VW4er9ZDdSZOgw1GF14XMDY85zQeO38MlPDh0AyZ4UIIzO8FcCbdygWO/cVwheyEwkSJHe8pq7mpxVBq6u9aQe10ZaB7mqmdUoWrdXbe82GlBSY79SkuOWwT72u9k+ePlo+92kcUJCC0mvYbTu0NMLeZpO1ncTweuG7y9mo/txw5P/egsf5d5un17z65OnT1nhAhuMYd2zMR/gMpKF/Rqe+Tgev/dhUE9rCTUOTEmzsOnBb/AcWrcgcYNE3dke9zu9XzQrOx5Ga8e1tf1oqxbbS+986YaVittzYkyerKw9N6x0wbVGtUuo5+O9gxYlWPNPV9bPA9fMzfXW9o872+92BMtXj/e8Xx5ePTL5s98fmeCc9vBx1vfjwRMANWrUqFGjRo0aNWrUqFGjxicJlhFJrcyfBJEBFIh4sqyT6kFpFcMDWBQDjbtUkiuIS6bjEsdQty49IAMqq2T9AMnVCj/D0xfABgirsW6EjUMr9S/RJiCTSvkBiKnAO6m8WlDQEQ027KprwktXLwKYwX4BB9C2Ok8A4sFLvwFK1JhLzQMBlAV5AGHxPWCyYEkPoEBoEmay2OGZYwJAjPJkmsJCjbwerWt6+9Rbz605t7b/1GSfe+ez3D+UyIAV10/WNmw6e/Gpt+zZs+d8DRpRAQgDlLAUHefLpoYAIlK+oVQb5dEfvP+eQPP9LMsBQJHzTPg0OPRhSTh9RAWYAa/DYqRfq+R9vRk5xtur0b73u/9fNj787u/5f6liPXjygAOJ8SF4PFvfrgiUAX3Ww0DVZUeum31yocILr2xc53RtYm5RZSlYFdCX1xNwMql6UU4v1aKgqluP0N5haVtRArjlV1luBETOWE8qUf6vqCIP+5fL7SDCXiAA4VKJrBAoFQgND3I+1+0/lirnfA6x9/J8AiLC8gHcq1SkZqicPaMxzhmICvY6Juf9hfkEsAfQvZ91Dby/oiAYACPmLwAjkgdsFwcYLWUnm7IFCI6EDMbD9wl1Pz2bD5P2yUZpUOb3tgb0c891jNEEewAmGXStcQ/hnkQzN0BACb5blf27qj3Adbwm+xE79HWojMPkn32exQBhReHRp2GDvY3mlRolBgAu/b/1vSzWgyy7ChxzykFsWjpDpE37kLzGLpTTnmyjSpkskp34BJPDDQWqdAfAcmBwX2ouO2wXmZJ+q/A/putNuOHn4HFx8PEEX8J8bQ2bE9iY4IHrygab8yH1oNSaCKUyVMADlc+Ey71bX3h1wqodBHT5foB1frJHNOu739uX3nvFa4ucwarrbRxaWzetzcfJ5vmKdjkPw15j6ZB2sxlp+3F9tba72ysb1qPdvbgjRL59cscEHlTJWE+lUEbHWKx96F6KCyav5CMaA6JqAqrhOC+u4UUm4jRzHOFxjAfWr8P9TrD+NBNir6+vmPwzNPND9Q0VzLIbUoIyOt7mvgSIuBpchopEVgjvu/PJ2kHWI1qPJ5tQobJHlQqaHb7imv/yPdhlzPbwsONXAPvDHvcc7r3qg1GjRo0aNWrUqFGjRo0aNT5pK4wow05avVDOLUvq+dxoCEWFqlsLBCQrlXLl1uPv/sGZzZj8yWf6QgJSCUABVgCMsqQbatoAjKkZoH8lJBZsIGQ9Hq2ZqTFMzY0Ah7Cj+QgPZvg1h4dqHK3Uk4Au8KrtoQSmJDrOIfuicnu0j5AKbtxsCXQAo3HMZ9g7uBIbIIDN9uCPvILiGiC6sc24sa7p7Pb5HaXX0+Fkdw+T1JtQC7cr29x01g+t3T19Yte31wS+gApSbmM70h1iX1Bbq9FVaz3OccbfjwRD0/ZopxljBHByJFgGvMY59Ou1zgWKZUAjNrCSkg+l8NwXhh7ez+verrZrGwH4mgburSCpBWgNa4SsokyKU1Sy81Kf7cjjgK1CbtBH1AXQmlSupapTytqF7UpMrtSoLDtIcEp56b9zucX8LgSEaV+5QWXxhHIO86sr7ZPKWY9LW4w3KZbj57zfzHsvGxOmQ5BEOTV+zH/XfSSVrTcFLOZaunMdOl7eewmuLlC5e8tGBQHVsEgceUVC2DLQr1fexFQse5VCtrNBg8IYi1AVu8gWIJSJH80HQU5ZKxD8qaOcWw+4k7F7Ay9VvZg38qblz/z6hpVMHifOuZ3Gpvs5vJHjtbF91+/GhHa6J/H4otPfItJ1ST/7HtL188v5Ia8vZ15W6sc8jjXZ1zg/PDXho0t4mvthr8LEWaPkhdSxskOI5n0aTrd6SPuTqpsP+mbTpMJ6+l7rIQ9onRMtc7hxVytDnUylsnvTd7DP6JM6N+CzbEHCbQeVImcbRrOra0Ddkw1bzRlWi8Az/yhffLwXPF4VYNnM1kiE9a1tt2u7vt5wDb66uaF1x7AZZd+RCk4wt45qHLgK9a6SkwTF04H7ADDmQTNX5Vr2uJ5ueULlvJlNBzSJBWw+8L3j+smtbR6vrFtvbLjCe1hnTb91W6VBicVI9ryB82anqWgimRt7chu4JoT1vt6F73grSxIcLYAy4D5A8359IljGPcCCisqWa9SoUaNGjRo1atSoUaPGJwqWE8TQw6t2XaksCEO/Rre+AGBEKTweBEwOWQggHEyVzcMkDHXgc1wZsCSb81HVqAZ902T28HCwfpqtH9AMDB6fgp6AXcC2AE/wfhVkynASwQZ2MxSRUE8KIKwMDe2oB7YzPoWf8ehIwai1box+xrBv2Gy2dr29thG0w1WHaGKm8ZGksh9GG7dbgt/1Wio8qC4BZ6DaxPHtH+7tMO2opqXacujt6u5WJf0jgG5rd5/+FD/wc/TO8piFyldV6bM1zUk+z7DZKNSypLeuJRewUlMz/VLH/eJ8k2wW4jx4LrCOoIpOjalSuT8bJKrMW88F1D9Syd2cjnZ3O9ru+Y3dXW1s3XWE1oA0ODA10ZKnMh5UjruPK1SJPAJX6k7zwY6U1woy4fhkmwBVpI6HMK3gIIBYYX+ABEA5VwnL+IvQ3p7tREW8GtNRUZ39MBav18AgYSDoRkuKhYWFZK4BnkuVsrxm3UO8tDEgwGs+RB2tK5H/6+J2L7lfwEmfuS735U+Cx7pWeB1hvOE+kkqbiJj2Erqnlsa/xbH6PahzxXxorF+5qQXH+exgzaxfqTkaphxV55w28veFjzYVxjgHCGpPZpPzWNwH9Nal1UnJajVrMddH+OL2SOSoiRruXTQwC4dyzA+AQ9mEIHEC9SpU+bMANBXYOh/njQ6i/d5gkzpYagTxzokArh3hL116MwfU9cabwaQ1p4u18o2crrAF8cQXkiUx7tGkL8+b0srCEympB5zuJ3pUJxfw/DfYzLBAAsD+whOctixnrGnZX1k+yhpP7cQV0kf3e3ZVMRW+WG9RaWFYE3vbrrHWIakF32DAXlgsOKhG5gkweVzzOg5YF1Hx0QEu4/nO9DmmumewL/wCTfjgpz+uzbY3L3gcrJ5AEtGTZqpwOCV7n2RNwSaoWjugTMY6jGakw3rLtVXvK6hIOfi6drLjNHmTxIPG099DDnsofA9c0whoYZexgW+yYC7GkM0fT0fec1jP58PBXr68ZxO993/g+23aPdqLTz23p8+e2PWTO3v69lvWDWtb3zwjaF/huDAesO2h13jhe5yyMmW9RL6ebCzLZqyy3cC5DGvcX/JpDpujw26nRn/3j1x3DwdYlug94AzbmNfma40aNWrUqFGjRo0aNWrUqPG1Bsv8psRRqX1V8fUNj1J+WVKXaD5W2pcWMC4AgraevVgBTdCICABa9gZRBi5VI1V5F4pP2l2QWuvJBC6E4Q7tVipNL2SLRaNClXLD+7iDFyjJUhyrGmTFJ3+qLekHC69kgUco+nQAR+6PcA3qN7zWlW4CalL2sonh2NFOAEo+wF5tR2B5tTq4P4I3vFLPr4VtZ2AIwQ9XssZ4Z8PXZItAOw/Tvqh09efQ0oTHq/GnYjSsJOQdQpgIcER1aVINa9BZDu4q1mhAhmPhtYzt+EUMixQpMXUybMLIUn3/WkwhsetodpcVfC5uLpqkhVfyylaAZcWW+Pdke3DRn60Up75mXVFsM732Ej6//rsSDl025iNU9NmebhcHoW8mPx+Og8Jiwq94tsDw8b9EVEr0+MmGytWTQq0r+ONuVFNANZFco2kblfe5mSchKZv4wdv8aOcGtiZh4UCmL+/jN5xDWIlEY7/SCiJEvfLizZYBuuf9EVYUCQzH64uBKZh8XiMC6mttiFdksXqsdSnrkBu6veGSCBAvkwEfrkfOR5jhdp5b5djHmiublnxdCmnywo5lsY9oNFmst3HNYt/RQDWNuKuVmZRwRbPuccHmaKyohIUqGmCBAlActheoqIBVCRXLBLtaT9PwEGZ7wsMTemw62Ql8Amhj+/1GYJnVGASi3iz1eGQir/TQVqUMmrPCvkhNVsPuIvlIuNKYEHkukhaejGHygYpleEZ7AsCtROJ65O34+5+PVcz/w36y3cOOVhSw6kE1RjeMNmwma7qNdajsAJQPexV/H8nVDunipLu43G9c9/TeHElBZjyQnIPdCpKlZzuxegXvT7geZjN7BuJa4+8fMTlr1KhRo0aNGjVq1KhRo0aNHypYlpouWQG/Bo8RtN+M0uxScldEfPhPYISfiwFcKT91JRmgKxR5aFaHD77aMNRpu/1My4p5Arht7NRBHScdbWoK6GCZfqAOhrE/qrfwQZ5kV5/J6aNMKC0/VioD4akMVSA9gKUUHtqWNhXbzZZKPYIrAInjXs2s/LxgPbHe4tgG65prV0KqBP1A5dts700He/+996iYmw6Pttlu7HSe6AO6vkW5NprzbagEpNcnIbLJYgLjdRSUPq0AZ+Rt20F1GYrlQgVLOOWggmpFWhOIRCe/4Va+Bg1U29Sgy56DqMOVmThWeYyiJHyfQDfGumvMxnVnQ8BlwJcj4JB5gz5ZbMS1nOh9fbbzMVtYUDMOCAU1I61HVKcuhXueP2zCFs38fF4ywUA7klKxLAuNgNSYSwGhsFv6eielaTTAy8pZn63F9nxy+1yM77N9QLa1KO0tFnO/+Jrhu28xQX9tf2mlofOKWR7PT6CQx+H3U2RykGTx7cnfPJ8rHRKgZg2d7PlsHXxCSOvVmEyWFnjt2XqqMXG9dT02a9gMdHZ9NdLDFkpX/E4VDLr3w5P4eDrY8Qhv79kedphDJ3t5D3Ux5q+gmACjrDRWsFxp4XnbWt/j2sMH1tsttg2TNOO44b6mGfYHJ5umPQHgTM/naMIp2KdGkLju3hQOxxR2O6ECTaJQr5igVY5sD9SMs7CqcLsOqa0Lg+10gYukQrgZ+/NzrUdcb3zvNxjHSj7cWN8kxI57VwkSXVrNLdxTqAagp7Rv/+i2IyFKL5gkz/+EdZJWJp4wc5U+qwLcQxj2FsNmzRdP8NKGbzaag45IOHlWICXbOjsDECPh1Y9UBW+HNdctAuURtkEtrSfUsE9NFqO5aZlgoW2D+89z/uF1bIYqqwd+paJaY5qgfTR4tSEnzTgR1cWQ6wAvoeYSbX9cjTwfHm2eDnY6HOy42y/uw1CVH9D4FH7StGVBgo9Xyu2U/HlQvqu8hOrhYbu2Z++8sMf7R/vSb/+SHd/f22/6Df+fvXr/pT198cQ+/flP283drX3jt3wTm+3dvvWW1Nzra2vXWx8Dn3fNqUiWSRmN45fNiRuCN/D1N+NTizVLQwQLJozjlvdDN2xVqcHXQDUOlXT/ZdIeNWrUqFGjRo0aNWrUqFGjxg9ZsZwb6oVyzv+Q/ZHjV6GgDU/SRRQ1vVGiXigi+UEd5c3p833W+smPFaAJzYq8bD9K0gkCgIRKMJgVo2xo52rh1CiLjggOmfyggbBUjS14AlAKKBlgAWXchCNU2QpAAdRw216a3HZSKQO0JqiDZn1uzUEbiMPBDvsHe7x/yZL83eMDYUU34kO+/EtpyYGDdEmhzhFPE1jGMUNtxoL4kpABMLmNAsY27BQAJYhhmB3AcRwEis7ZQmNlUM/Jn5f7FJ8RjHEwj2ZWagTmqmEASADlpFrNqklAK8Cr8ESO8ycIK8EyQH4CVlIsyvc11NV+TeGDnZlSgmTEL6HYLJSYtApwhS1gVENleobPCSpfqJVfh8tJF5puguSp7MLaBWO8NPd93ez39Vsj7rGwhSiU34K2odWO58mqo1TJvokPLYXscWPFa7zRZuHhKpsLPKQ0h+UFvcJ5jRu7hk3F2NvN1cZub5Fo6Wy7lc2KbC5ONrFh39nmqTHkiQ5uZxGCyqN7BAdUjmSV1J66Xrr0oUJ1v14qWXVfzFTRA+zJDkH2FJmmZisSB9NxD3EOvlmkGXvLzfIukgQ+dGFXka9rVjLHT4tajuRdnYyZ8yTAPVrsT8A7rlPWnGdPm+yxHHvSnPakUVpYL6ZDrHnJ5zvbZ2g9QDrArSKwqlCdq2vO9ZOJuzgAKZQFfVsmBBp4y49rW2+3BMvDOMiLfexd7RyrfAGGY947RKa3PpY8VogI5AJK855wCwsl62JWO5wPhxe3zOG6TJ+VfNcq1ZETarDBgIUKkxLTfvF2FdcjGoniD5h2yxSNKi1UwRGe1fDYb23cbngdW/gnr1q7f7WzHwRkPp5pBzIfjvbiredcYNfX11r3urU1viYu38Py91RXnyavRvDGqKm0Is+NNDNiHeZKCJ7s794trg28nXFSgPI1atSoUaNGjRo1atSoUaPGJwiWO1eDqux5CVqSyrIssU6gyhWEAHvw+HSNXtbpXfg3J+gHaAmvU9ldhP6PoLcVyIXCDyo0fPAOJZrAh8PI9HlfUKO1Xo33qGrEB3Qcj9SzK0AMKKQBj71hHRo+wXv0+mZNxd5mc2Vjv7Zm1bKp0wka4ubggNOLx88HW9GjFI2lKA8kxAJYxvEeTyfb7yd7eNgRJj988JLese/+4A9wfx+8/wGhzO3tnV1d3xDUbK6uHCI2Xp4taAKQA3V1eBJTUOnwmSo9/gwZZ5Ttu00AVb9H299/YPNhb13f0Du16zbWrjEOAEXeECypeVGyfRIM9NLpk++v7RsbRilMISrsz42NODb3OAVYxnnPPH+3vKA3qhSFQz/yukod6ujGm6ulRm8Ah8kXGipEL/WGohugG/7FLn+lFyzLwDEf5uQ7y+30ZsdolAXP4MRjvNmbdO9ShHvCIfns+nEF3MVfAu7pVYk6FveHthvPilvH+Vf5NB9ngHgRIUKk9AzZmpTWDDhfzOeAzDoXrwiIKgNX0dLWwhM4sCqgsF3YkNAQ3rD0zu1b+llDgbwde6rzr7a4D1q7uoKKvrMnz+9sc7WxKyiWr0dCYKiZcV33hz2bZN7v9vQzf3h4ZQ/3L+3hYa/kymGyh3ZyBW/mYUyRQKU6SOXK3I3BZkWKU3j4AtLJVVi2F0jI0CfZkwcC8pqToYrHvEtN6AA3PamxiNS80G1R/AKFonvxtAWsdfzomarSYiYtjhda0PBYxlEpWZGtZUoFb/IIRyNOb8Co1/ncjeMsbBB4DFgXjxjfaF6Jig5tC7dAhwcTXz2vmSxoLD3m6Wi7h0fZmNCKo7HDwdW8SBzxOjd2PGOeDDZc3dl6s7Wruzs2/BzGtY3r0YdUatuwIAKglRc4lMt+r0fDPnO19qRmo2ZoWLrX3OqUYGoHrCk4GSS/dH3ihoKamrt0D2kqiHHtueZp21jX9GS/vw97O+53dtzv7XjYFUkFzSEONRKKHfyl5YUvT++QgefKHDU2bKzBWoqqC9hdjGt76zPvWN+vbb+Hp/zK1tcbm623+/3Rfutv/QHbvPfKpmNvVzeT3b2ztpvhlscJNbjWJEHw41lJwGn/yCoXKpQTyB6SLYx+Hw0v440YCQBUv8iiJI63MJO/mLc1atSoUaNGjRo1atSoUaPG1xgso4kXgh9aC4dHAZVSXSeVLFWrAdlcgckP7F5InBR6F0B5GeG17PCl8aZQVJ8CLIdKGJgGQAIgrlAChn0yWASbIfkHaqrzWjvR50HH0hB29AJYhFid9b1K/qHAQwOocVxbBwUaQeckiwiHylT5+nHieIhKVHcu1ww8vHnfNM+2Pxxst9vbw8Mjj/vl+x8QwibbjsPMx/bq2gbAQwIVNVk6HuD5ebIOLhbsmaWGbfKYdqCcSqelIF3AM6iOj0fbvXqw6fHB+hE+qK2t1pCpXi2K+bPX6omAj2dGaOsQE+pCqBIBJL1hFhq3wTqEzQxbWWEAGAG0n0qwDDju3q06b4GsUIQmMTxV3w7FXdlI8EfbDZ9/dPMQXOldNQiIDWlslPxTjw2b1RXmjtkJUN5Bthq+eZIAKMwV6Bo2n0SaLGluZiuNnCiRnUIhXy5gZQaHvrmCUaURTx7VAmVpM5G8KbfE8wp7gNKb3C1QYl8B40/yTB6gQk73MhT5rW36gdYKVxs0zevs5npDm4vNZrQnd9f0yb17ek0V6ot3XtjVzZVttr1tt7IggC0MFMn3Dw+c3++9urfH3d5eftDa+52u8/2rR5XjYz5AbezJpjhPJBeYiGCSwWFtk+cXLBYAhWFNoKSBkklpUONaADrPx5TAinEtJNLZdiSUxKk6IoB3eBhrbn04eltcpdf+mvYWt58/B+vAsrFe4aThjQZD5ZvV5l55ELYp0aJTt396EtfBgIfcjhIZeOC2ZBUB4T07z+Xk30lWM9P+4OOl6gEkvqgkR0pi1dlpBfUykjqd9esrG7dXtr29s/XVll7KmCu5oZ6nXLDuTUgEHGldJKeb7NuMtRhnhWtLeyIfD631Sn7gXFA5wllMH+iA6lF5oHlIj30kophA0/sCEzZ9zAavjIEa+XBQkhCN/DzxqYSEnjk0a4JlPuj1XMwHTxqlRAQV/Vqrhw7zdbC758/Ywva9L72ylx/suY2jtbY/nOxLX3zfxnFnw/rOjlNj69vJro6oEpB3Nd44jmclEc+wITpNdti9sv3u3kG51s2+nxe+40yqhUVPaozY+c/+XrIsz/jQ2V2jxtd7/Nl/9p9t7733nv17/96/90Pazh/wB/wB9lN/6k+1f+Qf+Ue+qtfj3vylv/SX2h/zx/wx9uMlfjye88eJOi4/uuNX/IpfYT/9p/90e/fdd+3Jkyc/0odTo0aNGjVq/NjzWE7Ngfg5PpBjhlmI1eJ//joHKZdiTvwQPrgqgZc6rMHDgQjLux1EAFDyA757vVI9LXmZ+8bGJ+kASA5daGeh5nt4yLPSFaw3IgjwNIZSue97Ks1wLOM48rWbbUufZXzAf/XyFT+4Q8EpBXaUQIeFQW6QFUCPZd4EPgBes83zbNN0tMNhtt0ejfjMHh93VDSvN1dU1wLOPTw+Eu70643UuQDLp7PtHuENOtObedyMdu5P1q0HtwpBWTxsN4KFCnqHLYX8iRsqN1GWfZiONs+T7R7PdtgD1uEa9NZfaQwCuqWS8ig3p3+tQHmoKzFG49DZvDpZz+OQhQYuDxSkEz2pBZY5V8qGfgbYy1HyhoMZsoZbAq+bJxYIzPzv8AM+e7k35hsU0x1UlafW5hOUzI21NKgGH6c5hjyikQAgdMG4HmnfQCU7FeBIDLhClNQ+e32XbDn7C+S/xfVPX0sLC1dgh8ey7oucdBm7xq7GQUpSNBSDnQSTErCZgJd3thNIVhyx3yDx8L2GZzWBo64LoPGTzZWt+96uN6NdrUfOdfgVw1ri+mpLeHtzNdo4tLbdru1qC6X+YDc3WwKxzXbg15vba/qB9z3U4Y3m9OFAT95mp0SPEYghATJrDhxnAkoAX6qnGxbyp6aOGCLYzOC+ojoVcJlq1ZaKeMxlbQtg2W0OPKnFaoAj7i3AS1fn+7zROEs1m60xEuXNl9DXI6nGiz8Gxy8tdIpY8rmEsdN/5VK8mDBlrcby+e6RLJYtr3POcf5RSu044Nws1beP+wL3NaCxA9De7URQbaJmiLLKKK0VqMxl8gkzBpkjAFd5rkvVHud2dogJ9TrUuFsbxo21qO7oXUnsSRnZFWGdC2W623sA9DK5JNuSsBrC2kL/90ggsVShdSC6ShUZ+hm/L6pmoFQ+Y15gXyfa9OjemXVeBVimb3Eh0gX85leshWGX4gkFqrm53ghqp/2FbD3ZTmQbE1m16PgA7s/d2TZXa875559+y85o/jr0Nm4HrjfH/SPn524328oe7MkB1xzXARAYfvc4PiQwYSkDAL5nhcnxsJc9SdepUuKoJn3hL43vW74XemoFY4ampXj/G5AgxPbjfQHvg6m8qEaNH/fx4xUq/Uv/0r9kf86f8+fYt33bt9n/8X/8H4u//ZJf8kvsZ//sn23f+I3faN/zPd/zI3aMP1bit/7W32pPnz61H03xYYkTzIu/8q/8K5mc+fESv/fv/XvzGt3d3f1IH0qNGjVq1KjxYxAsJwiYP5kHZwllpxCIVKcqspeSjiLZwucUAbVeOGRKQQcg0tEDEh7FLH93sEyo3KGMG5BMTasAirkH1YmrERw/TMv7k8EmfmfaPPSr3vputH5Y0z4A6mMAgG6Qp+h2e2XrEY3IoEyGNYMUZxKoTfyQf9zt7L0vvWfj0Nt6I6UmIKqFxzDV0OHx6Z/p0fyvF5QEgJvmyQ7TzMcetgCPeyqbX766t/U8WD9urW9a2x0mm46v2GgOZdUESyupcF+9/wFVhbdPruzmtDXbnG29vtI+WDoNiCB/3PAnUUl5J8UnfKVh5DHD63m26bCzeXq0cXy0/f7IplvXdqfrEb7JBP1+XZs+KxChOqZvLmwUOlooQH34cDzaDMXdPFNNfsC5uw0BHphPUGJzzGhlcrLZoWiMIo/b4RqUz0o+6BEdyqDuo00LwKSMHWwkWAYYbq0FcMn9D5kUoB0ElJhnQBap1AHC9hy7DIpmt/ogVEZzxDiyc9FIjwpSqBYFuQX35SVNobhsoe3Ee8cTDt5QD3/DOasRIZ1nbdMN9vR6w/vjeEBzxJPtJ9mqHE6NkT1RYV0cTzKKgQpWKm4okgc/z+3Q8tp8/u2n9uT62p4/v7Pnz26YxLh98pT2E0+e3lAtfHcz2nrd2xoN2wCPh842V7K7YOdOv9ZqMofrP9lhf7CHV7C5ONn9q5M1hHoHNrY8nw4JNMJXHN6yuudx3TQ9UR0gD3OAy7jOmnPYP6xNpKadCJABLFNyAoAaDfxmNPHTI6TahMqu6AbYVsVA+OymZSwlAQI4J/OTN4g6y9eF4jnZMbjVTGrMF2DWYXBeMdNW/cq5dzAnaZHAiWZ89MTVmkIVfoKa6dWevMK6iTUNlQ+yVMFlA35Eigg2GIDLAZg5j/x+hO0FwDJB7AQFrIYu7FRYmdCgigN2KPBRvrZxs6FfcIt10pttqmEiPIsnrlHyAnarGDZKhRoYCZOJlSHN2PnaKIDK+829gKmyZUWDvmL/agDo1j+uMca2jpMaOR6nUPiioZ5Xa8hTBx44ySoiwDKPF+CWlSbhD79K+8CaGVUVKdMQlzEmgKuY8zHKYx/nfXV3ZQ1Afz/Y3VufsmE9sAIGNki/7Tf/Fo7Rw/3Bdvcne/4IlX1PVbg1g3vp73h9pv3O5sODTfsHOx4e7Yw3RJZpNFpbsOb4mohDpVU11PzYBhccWDR1NsKCpR+0ZqGZIdTfjZI9NWrU+LEX0zQxifxx4urqyr7/+7/f/vv//r+33+v3+r3S7//5f/6ft89//vP2wx3qy1D22vjRH/h3Dv5N9VHxzjvv2I+V+Djn+/V2T+B8fixdoxo1atSoUeOHIz72v9ZCsZUUdcneIquVqb4MWV9R0l2qxAAp0L2eYJm+uIJM2RJBnpuxr6weDEsB7Xx1bvkwKFCP1EBaswJQAUBGOf9o43oj/8+ra7u5vbXbu1u7e3KXHk+e4Hc3dnt7bdfXG9terVnav9kALneC2ADdbvEApeR+t7P9fm8TIBkhhgAy4Yx7iBKeepM6/A6vw0NKPfdTdduHhfoOZ9EBpHmDwHD8heKT/sQCdNg3ysz5gKctQButAVxxi++PKCmfE8RNvZ0oAGxsheZasPi4vrH11ZWNUErDaoCc/mzH6agHStMBibE9gj15fi6aKhYNrKJRXyg+ZTGh57/WVLFs6OfjReWpPwi9SnuHpDqVR+oK53Y2G9vO1l1nm6G37bq32+utPb29tqd3N/bs7obfP7m5srubrT25veLj2d2VPX+C51zZ09ut/3zjj2u+7gUft/biya299eTO3uLXW3vxNH6Hrzf2Vnrod2891dfndzf2/Pbant1c2dOrK7vdbux6PdjVONqm6/kYAXtWrcE9uMf9gLE6zGbT0WBzjcZ5G/gcQzk8jna7Xtv1MNpV19kWqvqVGdrmDfjKx8rGVWNj09i6a7m/t996Yp/99HP73De8sG/8/Fv2uc+9sM98w1v26c8+t3c+88ze/vQTe/bi2p4829r13Wjba9i/oPkaGq9BlbyypsM1guL4YI/39/bq/fdt93DPuUif46DoSdip6xSWJ+FZG4p3zfusBpUyNbmyp0RG2I1Ilez3kK8xaduuWF5UTpTbu1SQX3TwzK+KuZV/n+1Q8jqU5v9HWggU/sruff36olq4mqRHDGCxncJ3voSfrz1iLFOjvLxuxrFqvLMPi8Y5Eke5wV4uw9Dz8UBCDqpy+G5DyS7rHjfMcSsUAnx8dQsegmr+Xj/TK7pQesdaGSfJJnawz4F6HfYs/UAPYfkIC56HX7ya7x04J2FpMcdjwu90HLICimsWDVez1Y1UuzqgqDZQUkhJLn4t/JRjDcs9AYrmeekqerKH9huwGerpk3/z5M5u7u7s+smtXd/hcWPXtze2voadCCpmZIWE17PhK9dBrcFhH8TtFu8ZUSGztHX3BqkcI429rgsAvMYIIB4PqKDxOM7y0a5R4+sx/u1/+9+23+l3+p1ss9nY8+fP7Q/6g/4gu7+/f+Nz8e+3v+Kv+CvsU5/6lK3Xa/t9f9/f1/6H/+F/4N+gxoVaGQFFKe4zWGlE4L766//6v96ePXtG8PTzft7P+4qOEyrIP/wP/8N5nD/hJ/wEHncZv+W3/BYqg6GUxj5+1s/6WQuFMI7zD/6D/2B78eIF1ZS//+//+9v//D//z4tt4Jj/qX/qn7I/+o/+owmK/56/5+/52MeHRNqf+qf+qfYv/Av/Qvrd933f91HFjd9fxr//7//79rv+rr8rxxHn83f8HX8H/637psA2cGyl8vXX/tpfy9/FOUIdi3P/D/6D/8C+/du/nSKP3/ybfzNf+7v/7r87zwd//31+n9/Hvvd7v5evwfW5tNqAwhYK3K9mfpSBbWBbZWBf5Zz4whe+YH/X3/V32Z/5Z/6Zdnt7az/n5/wcwta/7C/7y+zTn/40xwZK77/v7/v70mtwzmHNgnPHz//uv/vvcu5tt1v7KT/lpxDuR2CeQU1cBpTF2Hc5vh82Rl/LiPHGvPrMZz5jv8Pv8Dvw9//kP/lP2rd+67fyfN9++237E/6EP2ExRpdKaJxP3D94z8L3SF7gmmO7uEc/TpRjGYHzx1wqx/ff/Df/Td4vOL5/7V/71zg2f9Qf9UfxPseY/eSf/JPtl/2yX/bGuRrz8pf/8l9uP+kn/SS7vr62P+wP+8N4P38lc+WjxqhGjRo1atT4caNYjlJnfWaOEv+lByxdTVmuL4VcNAkC/AP8lQcmPvDmrvdoioSGYWwahqZrVCoeaX8BgEJw4WXU3DnUp/jfCcqu3mxGybCUbPgQD/gBL1Yo98Y14ERrL148t5ubW1tvNrbZbqkmQ0Za0BuKLgeiAVDoIazybFoRzPiQfyZIe//dd22zXttpvuL+xu2WUAaKN4yGW4EKNJNRHemnTIXvQbCDGNy9nOHZDD9gwoK2tc0aTdGuqPhjyT+UvEfYZaBhViMf2/uX9ogGW/B3aA6sWb863sq+ARYEUIgfdIHGrVSnVG2GNcg42Gowu/v0O3ZzPNr+4aUdHu/teJjt8AAoc+bXpp2tGVFyDTWcIEzbDLIFSeX+AO5HKgCwQ/ythWrU/WBnNFhzb+myfDysT/D97MBECkJ5lkpdjCaCuC4Z/6EZH6ZZC8sOPABR12tes80aSYXO3nr21G6urlzMLh9UQEnON4x10ZSQpgLMkrjilKX8eylbqTyWGpQqQoKwOI3YgtpRln68oQIVdD3Zfj7aAZ6zh4M97uCpLdU4gM/j7mATwf3e5vPRuv1k0/SKqvztWk0NUU6PeSw19cqmw9F2uwPHHT7GqgVwx100V0Pjyc7saljZ06e39h3f8W327Omtfcs3fcaeP721qyfXdnWLRMJA1bI4oyt5T1COwo4EkFG+2qtBFhYPrz6w6TDZ+9//Jdu9erTtk2vbPrnR/ZzAlyDbya89LF3wQGLCXTxkdYFR48Igqxo8KCqFAt7CFkdWLJgfTM4cUDmQvSmQUKEnMNT9SPIk/2SBa4R82LNNBaGbG+gGRA3gHAkQehInhbLOKXstCzyXfsypUdobF85wDsr2L/qqORQoHf+nDzhV3G51EZ71cmhIPttMxrFCAiOl5JwU80rkhaL8PKvBKdXLtEWRApeqfbcKwRjzf/gKSTPmGM6P9jJQzrbWdAC8I21TUNkBC567J9ecP3jJ6jxLMYyEFpurwh7FYTLvPdzX+c1CFhi4D2WVw4SaN0uktQarRdbWjxupgGGMHvccrxUUz6iGABCFLdBku/uHlNiTxZIAPCF8vHfFXEiJPPhx0I9Fns2cGw6QCevVkFGWM7nZqApi/Gs0h6WnOdtpagwdzCO52bWD3dxubLUarRsH67cjFctIWh7Q5PIeCmuz69s7vifgFCfYEuH+wf19RCJz5roMCyfYkWjdSsOyUMcrUQeorm2oNALKPwBrWWGwUoMNa2URhfh4msYaNX50BeDOn/Kn/Cn29//9f7/9sX/sH2svX760X/krf+UikV0GwPC/8+/8O/aLftEvIvDD6/7QP/QPtf/n//l/7HOf+xz/9sf/8X+8/V//1/9FUAgYGYHX/NV/9V9tv/pX/2qCP0AjADzA3o8TP/fn/lz7+T//59s/+o/+o/av/Cv/iv3Jf/KfbP/L//K/EFbh31A4DiiFcfxYq//uv/vvJsD69b/+1/PfrDi3P+vP+rPsH//H/3Ge3z/0D/1D9kf8EX+E/d//9/9tNzc3aT+AdNgPYB6285XEn/vn/rmEZDhGQE5ANRwDQFgZOEbA1H/sH/vH7Du/8zvtN/7G30ioivjb//a/3b7aeHh4sF/wC36B/XP/3D9HCAzADgj5F/wFf4H9G//Gv8F/R/2aX/Nrlk3Ev4bz46uJf/Af/Aftb/vb/rZ03hgTwPF/69/6twhLkTDA46Pib/lb/hZuB+AR3+OYMSc/zvXDv48AMb/aMfpK47/8L/9L3hv/+X/+n/Pn//F//B8JgjGnYSPxpS99iWP8cQP33C/8hb/QfvEv/sUEvL/tt/02+3W/7td9TY/5b/wb/0beL7/L7/K7EOxirDBO3/Vd38XPXP/7//6/Exh/1LzE9cE54t9ef/qf/qfbX/vX/rWE1B8nfqhjVKNGjRo1avxoj6/sX5wpCnVX6P6SKC4ThMKedNEQrnCpTT6oCxAQm/EP7BKWhVI5bDbktyk429MuAkAZil95JDc2bgSWt1eCIQCQ6zXK+vHhvHcVr9xtpfrDsahMPIMjVzC6l3CoAxf2uqk0Oo9QHp/S17Vo5FY0X1LDugwKkvhtie4FPQFUAGsJ1QD8ZOMA9Z6UiNoHwA5gXz/L/oJ03v1D5U8sCwLA7P44gtzbbJOdDirlD1uTZBMQ6s5CHSfFdaggBYfUFEwNwaK0PXlOp4teqCwLK5V4btb+lTrP7Iur+SDrCPgrr6EK7Fq7vYHva2/Pnt3a3a0DTweGgJCc8A4wAy7p7xn2qSmkyuoDOBMsNyVYLq9LtIhzKwyHrALLUnjvCVeVYHh4GOVtDVB0PNmrbmcH2AYcG5uP8oiGpQd8tjfjyOPdrEfO41COzsPR1l1PWIsmYYRpDiUxKkDzA5INo9mTuy1V2fh6e7uxm5uRXskb3htSJcd9wGTKLBAfjcnOKLWnFQjK5vU7APPdbmfdYbAR6qQ0X4r5X/gZZ+VouAO/Ub+broEa12lORTPENAeTgLRoAulzMjVaLI+jOJYlEC7KKBZGy8VsK8hdrA/p6elZH9YEzasrogHl4hyX7svEkIXNUHkG+bfFuhH3x4W9SlJ+v7YOl1LyAj4mi6OsIpd90Cp7GoeCmJYx+CooDB9nNVmkr5HmIJNmusL82UF9eoNIXtJxfX098soNJhxoiwJLJDw84Rb+zYTQcS2h4nV1NNdDrIHuz+zJClBj3MLxfpPWrUKrnn6BOZeSRtlDOa1/fj/wPisHsrwmrgSPqhoEFci0dxqsbdfWDr3145rHCksMrpdIvE1oxCr//4XS3xX5F++Y5WKaLre7mxTK+mjeGs/F+xzAt2eCkRSA/zRKiGrU+DoNgEOAtT/uj/vjCIoRUKe+KaBShZoXsBTKYcQ/+8/+swRksHv46/66v44gEwFF86XH8u/8O//OCR4CAP4T/8Q/Qcj2ccHyn/gn/on25//5fz6/h8oV+wUkhpIRikrc7wCqAQT/xX/xX+QxQEH5h/whf4j9gX/gH7jY3j/zz/wz/Pt/89/8N/ZH/pF/ZPo91MXwS/5qAuAt1NR/xp/xZ3Cs/uF/+B+23/SbftPieVAnA9YBdCPwGpwTwP0PBSwDsGM8oNpFAMC9//77PL9v/uZv5u8A4j+J+fHVBq7LX/PX/DXpZ6isMT+ghse1jP1+VABS/syf+TPT2AKwAizD8/rLxQcffPBDGqOvNABiMU/DAgNqa/wO+0eCA+eLefRxA+OFCgAoyfHZDDAe6uuvZUBNjDlQ7hMJpJgLmL9fbl7+0//0P53GF4r0v/Pv/Ds/9v6xv69kjFBZgUd5jWvUqFGjRo0fE2AZqlN+pGWToKI6GBAhgZ1Q94XTsHuoUqHbypUWH5TxyRaqQTaEQhM1+eTCW5jqPwcRzck9I6lMbfU8G6wzqJsBkge7e/bMrm5u7PbuiT199oyl0+urjdTHDT5Imw1QiSHr7yXP2PZ+2vkH+MOFSlYl3xlgoHHSRKXXzc2VXW3WUqZx+6F4VDPBgKx4UL13lj8cICER9Sw/Xj2npVoZkAEqbTZnO55st99LqRml27CXgJoOe6KaDurgne13D3bYtzZNnR32O3t8fOU2GgMH//7VS9vDv3m+5YWiIrGXVykgpeAuzuNsw2ZDIHvanmx75VYD0XRrdbAzvDgB2wGNvDmVvJZxnGbbq5ljdn1zbTdXW9sdofZ+6WpljEM0U4uGew09hukBzeZ0gie0zbgsKKdS0+0QcH3chhYKTNhHXI29vf38xm6uN/atP+nzdvfkxj792bft7umdq14B3gF196mEXoMpZS2FyUcBNl53KtQxJ3TOBHqYE2jwJ9GjR1l7Lh1pUqM6VKIy1ucPQP/hMNnjI9SHk33w3j1B8/d//7ts3LiDvck8WYO55PeYlKdm3SCfccz3YVADSjRQw4dQKJcxT+Q7DeUlkgV4zsn69mhXm8E+//nntHjZbnAuOztC8fz+I5Uwp3Es7Gpku0KlOdSu89GaobfhCoqthhYIUIjjeuLYR1fVM1GDJAX+BluKsDChqni2/V62MTFWTI64ol8zXA8cj/zN4Usu/202l6TaXHMYx4V9Y2zxN1oGuO90+L0ETObuHDyHpYvWKs2yrHDOwDPocXq9e+8GtcsCoEiweDPKpRdB2o7mfFHxkfZYqJFXMf/CL/vsc9xtZXjMgKry5lYlhyuUi2SUXg+Qr/GCezHXV3hVY25ACQtfdCjKGymZR6iOoVp1Cx4070QzOCScOvi9w4bl+o6WQvBUbrqeCuLOva/phw+f6+lg0yGU1G1W/5LZam5G4gj3JJJ/tLZopZqlnz0971FtgnUqrFC48GkMjlIqn9DIjrYOe5unvU0TrIEeOa/alDyD77ODfb+Pkw0RgDhHWfZBPEZ6RMsQnZeysDTinEJVgVda6HQ86UQ1Od6bcMzwVXaPfdg0cc6qaWLbb6zttgTLzbCxYdXb3QvYDU02XaPp4NnGzZZzPRpdonoAP6uZobymo8FkXHM2r+RxaF3TbRzVB3ivEHA/wt8azUHHY4b3UI2jAsLn4odrtWrU+NEbAJA/42f8DAIiKH4BYFFi/qbmaFDVAhBBZRyBNQ4Q67Jh3ZsCYLkMWB3Ak/jjRulbHD/DDgIBhSZAYqk8RiCJi+NG/Pbf/tvtb/1b/1aCZuwX72lQUgJalfEd3/Ed9kMJqJYBtQH4AOOhigZELwPH+9/+t//twmoDx4PjxTFB7fzVBGBlOc4A/VCG49oC4AM+wi4EY/+1nh9fbVyON44XxwqbCKi9AROx34+K8pzj3HCNPw5Y/qGO0VcaGMvSVxn7BCgFnMX54gF1+MedA0i4QF0fr8d8g03FV6q2/0quEdTDf/Ff/Bfbf/af/WccL0Dmy/u7DJxLQOWv5t7/SscI1ilIMNSoUaNGjRpfL/Gx37X5ATsUhKFYVn+87POYCvKz92RCN0FkAhyXPsqlapkfygVIosSY+uRV53C5dxACxWVv6y28ka/s9u7Onj5/TkCxudq6d6+sJOgfig/g7lOMI6X/sKt/BZZ13LARSGXrOnFXXJ4Fu9D4iGI8160mJW4Bpy5CoDkacRX+ygRHKgmnWtTLxgHjBA0cbCaxoatK+ZyJ5yIPTahfJ24cY4IAmIAHM1WzgHqE+A63wqvapXtqnAYbBdkTcB+0FgCUCzBceENHg0b6gWqfgBdDP/BDGkF6eCt7m7uAcrnZWKiVsx/vpYo1/ewCwNKHlkXnbADY2M12sNubjb3z9lN79vzO3vmGF/bk+Z3myDD4hx2pkFl6z/Nx/170tzqixBw2ATF2AZahUJTKT/YoGSynaxKq5gSW87wX9EMTvslOUC0fZoLgw36yd7/0ge13Bxu6zl7dPxKW7qfJm++puZhsNs7WtFIVrjejDWt43I62QbPGkwAvfbS92d2q6/loVkfrVrMNQ2M3N2sbB1xj1sQTpM9s7IV7yj183c4hFL2wRcHxtrh/BjVM69u1q2OjYqG0eAjLB1dshz845oF77pZrSfJD9g1pfRGYpIqT90fenrixA3vex+GxHHtKphKlFnlRMRDzcTm7MjQuJ17pq1xWElw8Myty89aS2jn5HHtCZSl3jvvBrSxCcez7zz8Xun2/b5dNVLMaOOZksvqgLU+uDkgQfAWk6qpVv55xP6syRAkk2iXgHuqiGkQewNFcMTyZZWmha8FjgSjWczil9zPVs26rhCQO94eEBNZz+jUj+TVQqRzrXU4pyaaGCTus4W57If9gWX5wHgEm8+TVwDQWDXkl+xyMxSRk7gH+sSt2vov5lCstEmx2T/g8x7QWcW1N1yWlVfV2x4aJgrmR+IHVyjBu7dQerQXghs0FEj1JpQyg7Em+SMxRUY7jVIPK7LPs77xlh8nzpc8/LG6UvMnzMsbjNcF+jRpfN4H3Cyh//7v/7r8jJIICGFYCsKv4pm/6pq/pvi4bfunflV8bxf+rV6/sd/vdfrc3lta/9dZb/Ap18Be/+EXaVABSIRELOI2S/jKgjPyhxJ/2p/1pVB7DUgOq5TcBPhwvwFepAo2A1cBlxL8/SwsKWagtA9Yjl/+OBuQGCPxP/9P/lMpuwHVc89/z9/w907+vyyi3+0OZH19u2x823vCd/u7v/m77T/6T/8T+i//ivyDkBby89NT+sLmV/61z+tjH8VFj9HEC1hZQPV8GfIbh5/1R54tkCLy+kfDAGMMWBHMHnuBQ1H+544cFDaxnMFY45r/kL/lL7B/4B/4BKvG/XOPJ6NPyYdv+sGNG5QBA/H/8H//HPGaAXFhl/OV/+V/+se/9cr9f7hy/3Bhdxt/0N/1NtN0pFcsYpxo1atSoUePrHixT2cXPr1A5eSl6gsuuTPMP0lTf4Y0Y4M8BHt6E4b08UZkKRZ4+wDenlYF3oRqXD/qjQvkFi4OB/6C9vtqwIds49rQ76MfObl/c2rAe7MWn3uIbNsqKt1fy5FwRoJ0SHJ29sVeUk/PEuy79DsAjWS+4FzLPx8lWNHCij+wKgMPVrg55CcQAXPzDPT9sEHjIVmOGgaYrNdmYDPYT42iPu50dDkA8arQHNerzZ1BRjDy+gABQgADa7Cb59L56+YqP7dVox+uNg2ZspyGcxtgB8q7WZqdptg/ee9/6sbfNCX7QLdVrhDsARVAYeh01FYMjPD5xQfSBCUpE+GOriZdjlLANIbhZ2Qgv4K61q+ut3d5c2cvHPWXAUPkBqkAxmN1vPZHglBZMD03qTgVoTupxihWPtL3oqRpeWU+riJVdr0d7st3YW289sW//HX+C3d1d2zd+4W27vd3a9npE9bl17cp6QG68ljzmxMZVhMuh0F2d6G3MOctaeJy7mh5KLa3qd4JlvMJPwMXqRRM3gK1lGX3YunSt4OT63Nn1uefceP7WFa/DZz73grYY02mmcptpFL/HpEJEwgGez0c2AWMzMKjyxy2vM2EeYZ2DylDdIqlyhFJ4b6/e/wGbDzs7ziubMe/tJD9f7G+e5MuLfSLRAfAFyHg42xlNIVcnm+4PUrRuoSZt7fk7b9MPtoen9djR2mM+nHgeh/2BCmXMU/yjOh5qXikrlwQhHcEjEUFwSZWyVKphXaJ70328Z9mH4J7B9XFBuCuPvQGkw9ezq2UJMx1CZ0uXMuSj69+m50jJ6okC9zbmOpZ6EPpsTZCutAKKiSyv3/Tb5LqRrS+o4PdjkK+1rBBivWGlACFhhsPRwE62D9GQTkkPd3XgV/qd04e5aGwY/r+wN+FckaVQJFrkJ+7+51QT93ZeodpEoBnXCZUWeB3XG/jPc/sOUh2gCnTKHgLKJl5XXtuG2xjGUf7Nbc8DxlqUmzX6OZ2RQIB/86OuIda4k9471PS1tVU/ch9IrFENTXsObA+VFX7N3OoF81Trvd8rDohRJXA6xY2thAa93B2w45jDE57KawLybNsEZXQ0RYTSW00I/a21wVx1cM15BuouL+amRdvNk3Xj2c7d2c7R9NRBOp6OscLbNMDz+dTznoaSWUppzSWMERN58IDmXFIyEusc7j3cM7O/p/XDydo2bH4CQH/cfwXUqPGjM3DPQ4WMB4ANoOsv/aW/dAFmEFAcYj2C0jbsCXCPAO5E461QYvLfdl/j+FW/6lfRl7j8OcrhASMBBGHBAcj3psBxwyYCik4EfHt/8Ad/8Gt+nFDAovkfPIJR/v+mwPECBn7Lt3zLx9pmwHFYU4RaONTaHycwTngAuAGm/+v/+r9OaIrt/q//6/+6eC62ewlqP878eNMxlw3aMCewr2jw+FGBa/gn/Ul/Eh9QSEOhCluPsFr5SgLHAd/h+HwR5/hxx+jjBNTVAJ6XARj6E3/iT/yyr8fnFcBzPGCFAlj6X/1X/xUTD5fjCEgK8H6ZUIBKGY+/9C/9S6nUhv845tlHxeW24TcOxfzHCYDav+gv+ov4wJjBFufDwPKXi48zVz5qjC4DSSM8atSoUaNGjR+DimVvsMYPxQsLyvS9UIMruPjRWTYXAbuSmtlfJEWyxLD6qg/rLdBh09g4bGwcRru5vranT1GOPdjTpzdUbd69fWfjZrDnbMx3I7DmpeQJXkHddkQJ/MzfSRUtSACYxeM/CbQF/gHgnSd4hcqCoWyyleAMwv9xB5BL5gSfXreWIEQPkMXP7vqAEq8HoIDKFyBimqBKgzr00ea5U+kzVdayA8B+AOug0NwdoEKe2Ljv4f7BpsONnx8eABGCaFQuAxD2+NB0sP1uZ8ejoCAtSQBn2tbW9EUGpvXEABWvAu5kHPhgNbsNR4JsWWoeYJnAs+uoUtlu1rRMUCm5wHLMlIBx+bpHQiJbAxQ60gyffW51CSw3thmQZNjYi+d39oVv+jTB8jvvPLUNjIVZlm6yAPA5gUZZmBMHwCoHTEw+YF43sxqHAeh46b0A8ipB6UWDNofNAZYTUIzv/STUJC2fFeYcVNGyAsGHHjTik3UEdPRofknVJo6fzdZ03x0OD5zDUjvCH3zgvcHtRAl/klDqAE7zZPPh0V69/MB+48vvp7L5eBqsPQEkpjIDWx1nKaSRYAAIxXygEv1AVk5zktUk1rXG0LZ29/SJnZ9oXrM55W6y/SPm79EO08zyezbsAwiOr8mGIvtuUzhKCwh5ngMGxjjTIzz5Kq8WvtSwwwj1eHD9/HCf7qT4KfyV3wiVl9c1edMmSOwzs/C0TXUbIakOFWv4Kacpv3TGTYmIIugE7Ik3WkbkA3H7h1gwXUksg/AlWGZCDGueJ/Uiucd1Tr7JAtDh/SsbEFSGpMoD93mP+5pzDXAUDfwcIsu+B01HlXyI66KjDc/k8CP39cLtLejLzEQWVMnwxIf1BeC1K2i5HkPVLEUyEyPwUJ5hsXEvhbIcPqxvBbaRgGk6vR/Nw+Bg2deopNIOz/o8poLW/p7j/v1MCvLGlgcxmsdqvH118veNNKf0JK2ToSKPpohsBOvvL65Yju3oONzyZAXgfuI6fW5lQxPJR7wpct2jUhD3LNZTJMmQNPGj5ZTAazBWfj96A1kdJ5Kqen/APchUVXguhyVMlB7UqPF1GlCewucYVgOAsvj5B37gB97oMQvVIsrfw0sZVg9o6gYQ9ef9eX8enwPoiPv4P/qP/iMCXACvj2rq9ZXEL/klv4Ql+fDehTIZDdbg7RwqYag0f9bP+ln0bv2Gb/gG+97v/V7610I9jJ/h24vmX9gG4BzOo2wu+LUMeCsDYqOB3psCgBYWDxhDgFOsj7DHAExD08HLAIAGyINKE/YZv+E3/AYqRL9cAEDCSxqg+zOf+QxhNuBhAHr4G2Pc/uV/+V8mTP1X/9V/lccQwP4rmR+XgW0DPkPViqQEvKah4P1ygefBKgHHgHHBdYeH8JuUqR8n0EwRx4y5irGGKhlq6EhAfLkx+jiB+wJ2J1A9Q80LqInzRjPA//A//A8/8rW4V+DB/fv9fr8fkwa/7Jf9Mr7XAFbHOGI+ARpjDDB3+G9XD/wN71O/x+/xe9AaAtcQ8/rjeFNj2zhuXHts42/4G/6GL6tyRiCRBJ91QPN3333X/uv/+r/+IflSf7m58uXGqEaNGjVq1Phx5rEccNgLlFMZuBRkgBiCyfo54HJRN289en2d8XuAAbNtN9h2WNv1ZmNP7p5QWXt1fUNLi1sokdej3d4BLN9S0Xt9u7a2a2zcDmwghQ8Ah+mQLCakNhSUooUFgZI+pAtq6cP02RsW8VxS0zudE/5xIquIDJalefNq6aajdyaZW2mdWvi7rorSKPklZ8sBgNhxvbEB6r8WHsDy2IQHs0Ac1JoAwXg9lN5gvGc77KEIndJ5EShDCUq4o9Em4IF1CMA1wU3DEmuMK2A2wQ7LzgGSVCJNGO4WGPQSJnuR166+B3DxMSgamgVBW5WAEKXz/Afjh9dXp6aGANw+TwB4MJTUAp6zRUB3grr2bF13IlSGh/N2HOydt5/ZZ7/hLXv77Wf29K07u7rapOZXoR4USAmYKZUrwbuLUaXczAYbCbj4awDWdGxZ3cdEAq1YdT2Tbr88VU+a4LrqRyDjo03wxZ0mAsShG11xLEAHDgUGiKEDBE+sGPCvHe107OUN66piWSojAZCvh+5JNco7TDt7fHhljw/3UmnDZsNhN8eHVgYAV/IK53D5g+X+SI5gqwDCD5O1/Uxf3I5+zkr8AP7CNmSaJ96DUH4RjnkDtaCpxdksElJq9NgKuIf9RVQ/+PVBtQHm/hFey5P8khPlDf/b2GpBgF1DvJiDZUJM8zbb10jhGfcsGtFlX/A3eWCUr01VCuHhW8BG/SLr8HPD0HJbbi4U55Ya3sVB52ROPpdIZLFFnXypvZkemsEtrDW8WaUsgGQRAR1zUihHA1HR8tS4Dz72WDPYuI++3kNqkMoUDxMk8hXWnGitY3M6JU3iXoQdjWwuNG9gr0HLlrhW2HHcK6DHmMNQIB8PdjrsbXrYufJc2+zH1ppelkW8VrgPqO9HhYXgaijC2RCQcF3VEaW1CtZN3ev4wRXITKp4Cux0mXhwoB/+2373a86j6kGg2VMa2c6G7wdIbGJtg7RcvsmwMJI9SJvH3ZuFrphcksWNJjmU4+EZTW8Pt3qKdolpcVXi0NcFXb/OumbgPIJCnM/3JEC2LKpR4+szANe+67u+ix6tgK2AUQCW0ZzvMn7+z//5nPOweHj58iUh7S//5b88qWg/+9nPpsZ0aIAHOAfw9bUIbPcX/+JfzFJ/gEdAu2//9m/n3wDUcB4AY1Aw4thwLPAHDoAICP1zfs7PoYoTkPbv/Xv/XjZ9+zjxhS98gT68ALsfJwD2Pgpaw0YAsAwQ/Bf8gl9AmAeVaTQnvAz8HecLgAkv25/2034aATT8dT8qMC7/5//5f9ov+kW/iDYgGDcoWv/Cv/AvTMfxc3/uzyV8R3Uf/KFxzaB2/WrmRxnYFmA5tge16V/1V/1VH0utDLELIDDgLv6dg3MFSEx2dF9hAHgC8uN6o0Ei/IBx3QGTP84Yfc/3fA9tPwBPAanfFPD+xTjBJgSKWlSd4XoCikNt/VEBWIwECOYWrgESILjWaECIgBoY8BuJCNhq4BxKxTJej/sSYBafa+DhDJj9YUmNMnAtcZ9+53d+J6E6bGL+p//pf/qyr8N+MEbf933fxzmCc/yFv/AX2lcbX26ufLkxqlGjRo0aNb7eY3W+lNB9SDx9IsUG4AL+cUTloJcK0lZztbItAFHT2HbVWr9qbFy1NroCDB/saedwkM0Amn0BJl3fbOzm9srunlzb577pM7Sz+NRn3rHNdmNPnz+lvcLN7ZZ/B8wd1sAoZ5ug4nXvVgKdUMi6eo/WDq5ck1+rGrXRj7ZRIzI1lIrydKn/4Eu8e3wQhMPrUsm74DrAFv6BDMiG/QCqWZRlh390glUNwQFUw7KX0N/e/eKDvfxgZz/w2367fe9v/G7aFOwfv0Sg+M3f+o129/SG/wi5vrkhb4FFMpSg7733ijYDP/D/fp89vnppb3/mhb341DO7e/rC3vmGLxAi9JstzzdsPqA6xjFADarGWA1tOHB8sEkA3AiwzFL3Xh8koD4loHx8tCPO0WXlgExDv3HwJEU1FHWAFb/6//fr7Vd916+z3/hb/j/7Fb/m19njNNke4mXaNMgqhOpADoVsUgBzqGr3iURI59YjuNLQkwOGXq3XVHB/9tMv7OmTG/vCt3yD/cRv/4I9eX5r3/Qtn7EBqmlXKbbtIDsJgvVWtgkO46GATKpANjorfJf9gZL7aDIGKB8Qj9dyVsl9NFYM2WP4Tse1JwvyyTWdJzua5JZ4HT5orNG0EZ7hfqwAZU2PxoodGynqxuLWCL0CLqb/ujeuj5z7z56peJ9Psz3cv7L3v/RFe3z1yr7/u7+XTcLe+szbtr25tjMSFrAJ6Rpbj34NHcY37cas6ezxg3t7/ODB5sPedg/3nFvPPvXC+nGwYQuI31MJv3tEI8nJPvgAzSQn+9J773GOvvfyPXvYPdiXfvB9+8Hvf9d2+9ne+2BPtfGrB/iDn20c1lTAIsEyDOuQqS8A7uMOwPpou/1kD3v5SMY4A2yr6aNXBmTbYIfbSiKkJAzXh+Qsm7zXIwkRimckDybMCel0+Rw2Fk1Jpmjo55fJl1CoQgMyZ1/jgNDx5OyJqzUGyvrw6EZztugQeebvJeqJ9AueL0Xqdt1Z30KVrN8hqTfCm5gqZ9kejGz0SOMit8tAAgvNFhtbD0jQtfbkiZJ3bb+ypvN5gHWkR3XICyb4rq+vbBwHu727tdsnt3wd5gHWmZ7QWWsL7VLWg/VrtNXMTfuousXccrBMVTBODOMAe4sEa3G9ZiZizvPeTtOjHXZ7e/ne+w5LlXy4vr1hYi4apCK5Nh12TLihizq+Blhm0gJrfYDYwrYEc1ENBrF+ab1oOQ9XdvY1IBTh8H1uB1VzoHqAc8OvMVXc4UUNixoC6lDeqwKlHW6s7a+QyTHzZOLuFYD52VbdIDsSn5xQIB+PezXuRJJphTVpz3VqdZ5txaoLVCTAIgmJyEd1ID3h/lBCEMfNHFokftlMEHjbK1rwfoDqFbyPuw/k577jZ3+cfwrUqFHj6yygyAakg8r1w8BijR+7AaCMZAUUs1/LpoU1fvgCSRFAefhgf5hVTo0aNWrUqPEj+f7zVbbczc27qApbZXUy7I0B+IDjqF5mOTRYFRRqZpv1hpDharslDLy5E1S+gZXBZ9629XZtb7393Mb1mkrl9WZtm+1gHdTJ/OwdCmCBo7CCYFm4w6BoNEfqklCCgyNvukfwRLAc9gW+PbewQMivFdB1KXVUUyRXQ4fKsIRHqeGXoFcWVEo1CLiLD/YALlCTnfE9GjpRWFc0OysgtRgxFM1ufRElzaV6MykZ3abCcjMq7JOKV/cxTUJyP96s1JWcF9DkRDAUUlY9TwAsfEjdO5nj4+XiVKGqERbHhipDHIbga2iqwbB6KCfhww0LBPevxetwyCy+d+gGAHIH9fo42Kc//ZY9f35nn3rnORv03dxspaBM+/Rry/MOlaqU6mG/sPCOSKrj4vp6czpCGCgdXfkZQDqgdG6qlm0OFk3VEht2MJw8mOHBjROV4hhNAuFjLUHmUdeKlgaunnQrAiHRLJ1NxxhzDppNNLaD6v0wEfwiAcPxpypZ6uTwl6bnNJTMmC+Yh4S6gFcOrmWYnpSaVLdjO4eW82+e5JtcNrCTxYLsAnCOAvdSUSY76tAwF/B1oWX2yoEEhGFlE6C4aFiWepU5XHaZbgEq83bzGuCt/hb3a3xfKIuz+UtOehQNKMv7cmFxkZrhhb3Fcr5dHleuivBVlVDS7wN4D2NauypVf1/OrdTMrxxLJmrkRRzN5Kio1kxIthawwUHFBMFojyoHeQQ3SJwNIxuj4iuaodLvb7MhaMb9hq+06yFYblI1BNc1V2Vld5YgplpzZQgNYfKJPt6hspWhBqAw5mZRYeDzMRIBMfMJlTk3ZHnENXkxH+GhDusINcFcKtvj2OLeyYUHvFrwmOaa6J7Xca1LCxR/QdiSaI2R171qdmRlQxsmVtHMWg+ZbJUfMjy0iyUkrxVaVPM95T/rq9uapBlaLC9+mywSnIvZ7G9XxSAsLJ5q1KjxYxIsolS/QuUfnwG19N/8N//NFSrXqFGjRo0aNT6x+IrAMtGGmICXJevDM4EyVH3nxvrzytBuYFiZja3Zmg2Iettst4QRT58JGr/9zttUwN09u7UnL57Y5nptT99+QrXmze2Gyk34ZQq2AhrMdjwfbTepLPqEkmfaQ+yluHJ1mSCZABOasQHSnFzxKQCAp2Z1McCEvEr14ZoKMwCPAMtUhQLI6UM/Xwal8rQXQHOQJOikxmOCtLCZcCVsglP6Dmpt8PX1BvBmpM+uHdfWNFLL7h/3drzWcdHWAspbm+2wO9jucW87eNruoWoD9xN8BOhrHILSYoCAN4BTsG8HYxf2HVK1CSo3M8AOwDdKqI2NDNmfkApwjKcsNrA9NJvD19Z/ByAFNTQbcnEwM0uFqhIexmjC17eCyhvArLa17XotKwR6k65sxHwBwGrRhLC1YRzsM59+266utvat3/YFe/udF3b37IYWGExasFkjQJVADxS7Z/gCE4pGu0CfwQGnvAsfoE1ALZHoSAacbWYDwjl5UxPdriJ5kf4jJWCuRU+q4tTk0NWCBMgOyo6TN4LsUEIvGxS8susamw5HemRv1lKDGjxlG+wfNgFSS4YFRCRYZFwB2Lu36fHRHl9+YO//4Jdog0HrgxaqSjRslNUD/ccPJ3t4AFhubBw3goqwxmih7D1RYdogWUPblLPtHqBK3tlq98CkjUr2vWEkkw2NnWjdsrLD4WT73cxGe7gPpA4VZE5N61yxbinRo/sDIO7gTcd2+z09m+fIsyQPZLc4oRpd1RNQleJ+oD1BkjAHqHar4cJxorgxFkkhqejbCziq/+j+Sbh54a9BIJ/MaNRsEnOY5xfrSzRbLPI+WjaEDEPpK2shKG4F6t0pQg3nfI1xk2rdN7jGBJZQ/Pt5wEYFFg9MKsjr43jmkbHqAOvyenNl26u1DZvRhs1AhT6sK3DPPX/rLU/w3diIv0OtPiDJ19uwHhNYlv+wV4bAYoUWPlJHc42cPP2BKoRonInDgdJ4j0afuWqAKmssPADOKNXgPR3jgTUJgwVV+YGJjnk+MNk2HfaCyt7RMV3nU5Eciv58bhmiUccc0vWhyzQTW1JZ43wCVhNkY9tcy3zOUAEviw6Y3eBizrNbi2ANhIr9LOsim/F+sbKmV+NUjA/XU80OX4ui/kDqdSqVmSzF+bpVCd+r8L17YheWK7mXo+Y87/uYx/4cVOHEe4H2fGZT3Y9ZuFSjRo03BPySw3bgMmC98L/9b//bj+i4/cyf+TP5qLGMj/LOhrob9go/FgIe1F+P8St/5a/8SMuSV69e/bAeT40aNWrUqFHjawGWHWIkX1qo4dj0yD+In+WtDCuMse9tg3L/vrc1/ITHwa7vbmgf8eKt57bebOxT77yw27s7u3l6Q7gMz+Sbuxsqk9cbQTBp7Fw97GrapEpzH9ykGnbAJkAqxR5ZldfHJ4GZWwqcTvrAndTOoVJ2pXDpWamIRl2hUCuVowIUmVf5kYfn6sWHdkI1qhEFgU9Qx+EBj1A0WiJUU7OztGuHUvRjBiAiQEnaz9w4zn2QxT6kUA31sppCuYIw+dOWCtuwEQlVa0ArNY8SmI1mjHp+okRuKUAPVYAsd4kIhk2gvGpsM/Z89ADKA3xbW6rXAZbluYsy/cE2I5SRjQ1Dy/nz4i3YolzZ0+dP7MmzW7u63dq4Hhy6zIuma9IVgxaBOruhcqGSDSBTPpKDdvJPjUcxz+J8XWWY+F5S62b/8VCfxjbSPRQqxxhvqG2RUHCVLoYWPrG4t1T+n1Xo5bxcqkF13eDfCmsTADckSGB/Qe9XjC3hshIq7hRDgHucAbrgvXuUjTYUy4RcAGV5iHBs02GSPzMBHLYJSOdN5dIMCtU45mpWKsetxF24Arz0AY4mfaGcpzrf53o0HIt7l1c4lKIBbctrVNw3AeC1/7jOr0dcy/DhLbXM+cyKm3H1YRvJjePi/HLTvdwE1Cm6Vo8Yj2R64e4kvk6E0DfZn4SvdunFnKxSfN/wsWeiBlUE/iDIVOPSYYT9CNZmAGM81oTFsrnA3wau0+MaSuU1wTIsgNRksWPij/7FXBt8nzxub46XmvDh+vj6AzgaanlAU6iN54Ov5wLLWJY1pXIFBuAoXsdzcDV0VK1IoZy7OMb9ulTk5rUqJ318fUzXqNDMu9cy/86XYZ+qUHl9Piy7R/KYYuLFufK9CAm7I62J2i7U93m+ZXWyT6XS2zmdR658KETJi8NJcyvup8UT3jRpa9So8UMNNE1D47E3xcdpJFbjRyZ+7a/9tR/6N/hb1/iRDfiff9Q1qlGjRo0aNWp8XYJl+EvCZcK9Po9m7cmsX61sDd/eprEn49o2Q2ef//Q79gx+nPBOvru2q9sbe/Hpt2lr8extqOBGu7m75c9QSAIgynhXqtHd4dHOezTQghKZcjKVxZ+hTPTSfi82hiIQ30mp5WpSqDEBp/bwNhY8A5xi07NOpdxS1Zntd49UCS99UAvM4/6z2Jk8XWfBGsJa/xvgpoPgjl7LsmWI5lCU/JJPOEzCscBDlU0Ie7PmaPtdx8ZVH7y6pzft5gbjdsv9AjrPpzN9lveHmX61+ArQTNUoIWSGig38gaH6w2F7c66wSZDPqhTeYhuBfqXi5Ti5ahRgBZCpG9ZU7kJ51xR2GlArqnQfNiew9oCfKbw7G+vtZAMeK5TIN/b2szu72W7sM++8ZZ/9zFs2Qpl+vbG+b+3m+prNLjr4tnYdExMjS+4b6wGf+9aub6+psLy+2xJyQbkIUKqScvkoTye3ZgiiDRhOZbJf1oCHBKWubCz8lalIjLL9GJOSrtJaIP8FQeWvJybEFb3FpTcHC8uJaOAYDdLOVE7rGGWzIcgPP2FcHzRBHNrBzh38c+Fn65YPRz2f1y15PZ/sPE2cP7v7V/bq/fdt9+qeDdDofc6ETUdrA1QNhMoTL4U6GjFPO1kajPCAhi0GVOlQIOOeA3ie7f33X3KfN0/vaFmzGnANFoYqVA4f5qMdDlDYI0GiOcpGke6zjiZuaIuI68r71weTzdfQxPJ0ok8ztgW1M+/lExTknvQhII5kU1iUIPGCuR3w1WHmBUuTlYsrWANeFgkgwWWHlaVPcmLVYWmRrREiuE/ec+EgEueW95XSFc4dk5UQk3L+Owf/bWEJEyCeqmD8jlUYAqysXMBKTh/olQ2wuUASZz2yOgKex2ioRxDsUPjuZkvP8usnV0zQQLHcb+R5DZUyPJRRSQL4vLmSuhm+2rhPoypEp+sQ0xvSzXtcc6wVSpLRPMeV7Xbq7NzCRgVr5Sxv5N1j8n/mOZ43dmbjUaife2uH1tbXODl4nofTyaxGjl5tgrHFPIJ1D/bJe8r3GX7HWsbdE9m9nqGuxnvYNJ3seDiqEoNAGRUfUGTLnkj3GdTXeO/B3DjIw5gXAMetNRPex7RjiqoIPrQOI1k474/W9ifrQM+p1pfXuZTGcR8F9C6YtUN5rVdZlSy7ENmAxPscqjawNvIeZtLNy0YcOIfNTqj0qTqnd/TH/pdAjRo13tCsDY8aX1/xLd/yLT/Sh1DjIwLNI+s1qlGjRo0aNX7MgeXSNxPNgFaEG12zssEaNunbAkKMgz17cmtvvXhmT57e2lP44N7d2duf/yxB8tO3n1MZt3HIRZ9MV7dNMxownQmlCA2OE4EmUR3gKeGoO4W6co3AzT8ol1pTftB2xSbh1PFkbc/6YIKGsDuQR7P2IW9ShxHJjHLpj8oP9LAwIPG9UI4mVWt8lncVJeBaUrvqrwQdABv0WpYSDxYCE5SmRzSJO7JEGdCErppe9i/IdhJE41AEDPfzAWihSve8bCLIHwEgZC2iXzieD5AeauzU3A7HKEADr1KpId0XljRMCljBabc3ANRopF6nh7L7Sd9t1/b09so+/amn9oXPfdrWa1iebG0A3Lq5VjNHlNejzB5gmZ6tANsdv45bqSnR4A5AyKWqvt/sPxxKR3JBtxjQtfBjD4Ugxsphb4yfFOaXhMUVzenX4VPqCvYEbRwLBTF06Bdqen1fWLMGTGWyBnPKj5R2BbjmapBZXstQ6Sblcxx3eH5T7X6wab+z43RQAgMK1KG3DmCwU9IjkihSUuq4pKiESrkV9ATEcxuHUMvvdwc+f3OjBpw4+IWwMo6FlQWCyIRgxbFz1kWjwFDT83Lm19K/2f3EOe/9EWA+KdOTb3P2yOWAAu4mlX5xJZM/bz5o3Y46j6TkL9Xr0ViyuNeXc+Ny+7mpW/rr8nLFREk+McH5mlLJndyQ1XBNw6RmjbJXcfWs+8Yrl6K/dw2a9qFyZOC9REUykjFoDOmw+fYWSZ3Otrcb2g8BKg+wpEGVyVogeXsF+6Lexg3sL5CUgxUFJoaSLzGLOI/pOY91Fk3l0GhOqnN5R8guA2xUlhKwNplohXGkpVAAWnDZ1o5sSDhobYHaHt45AMugwIClsCoulcppnYvrmNeospFeeh6rKqCeBuyGjZBgNF+TLI1wr4hkS1UMpX94KEfVw+VaEdY02bM97g+8FpUIeMttWkBs3GPeHDblZvybNOGW0uQsjI4qneLhPi+pMSTX/EjE5HmtP2f/ZgH30hupRo0aNWrUqFGjRo0aNWrU+ATA8oZt+MwGNJyDJ+u6sbFt7Wo92vO7a9tuRvvc5z5l19db++Zv+UZ7/uKZba+3tr3Z0puTqkmo5646quuONtl5lqoNyjURKADPE+00TgaV7kSYnHwjV2p2pnAfZCq2pKbFdqj8Q40+FI67g02Hgx0OsgZYX13b0I+yaXAVtAAEPoRneCwq6QDnBJsPlETLN5W4x4E0AQVUs4Q+gi30cyZQAyiRrzMAC9BX69tkifrQ2ma9ploXau/9/SubZ4BmbBO2uvDVbWwFr89+sHY4WzuO1s4nwh8okgHEHx7uCYRUNu+l3QEhHETLZxrHIQln08AD2aGcZQhKrgiAD7h30Lj0/ZrHC8jIkvTk7evN9ghwBc7XfWM3287eenZl3/qFd3htt7dXNqwH++Zv/py99eKJvfPOW/bpz3yKSmXAZapkvXkfzjOa/4UqnJ6/bA6G4cD5QJXs0JZ+vd7mKhrzFX3yonEZDprwyBtKCrpkWBsN6DImCpLjYkyW6ocfgd84tJWAwvtSEhugu9hOMjmIp0RJv/t+AxL69ug4wkZfZyp2cfF6P08pDaHORuJh8jJ7KCmPNu8Ek+f93o6zEiqCbPA6fqRiEuM5ENYe8+3m0/1Ij+aVnQ4H+uCiEgEAsnWrkrlHokeK6R6ezydv1kj1Oo4HCmMc04kqT5T/wmbh1f2D7ag+Psm3FtcASYMmN3qDGh9qSwBlJJUSWPamfKU/cukpezzPhdXGEspmiwi9MNTEMebpni3sONg0jlBQa03y5eX24EutezcpmaNJHe0vshUN1kfMX3wfuYpge2F9wD6JWAtOXv2BJpVhC1JAe94D3ryP0JiWFmY9bSGMCRjcS/Amvt6OnJdXIzzLOza2pN0FPJS3a3kqQ33ctXZ9veH4b65H/r5fw2d5TXDMex7VFPBcRqKoh4LXIagb+WLOqZ2hBhoJDSbxkNiYoegNgJ4bz8U9cDwebL9Hpchs0+OOf0KSDdcoGkRS9evWLf2AdQfXTcrc1I6VymLc3yc7Y83wxqOYM7LVYbfXfC/i9aucTKHdyiksZ7SWd1yDOmt7KJY1T7Gf0xEq/84Mnsn0jIYyWQ1j5c4c7NabFBJWS6mOaz1Nk02PAOnYL7zMe+s2W19SItmzmLjpJ42xPNTDS53vfz7Rc/PCSOgcbfb5i8SS7vGA8ADkZ6Ptuzf1S41ea9SoUaNGjRo1atSoUaNGjU8KLI/Sytmm6QgmtmNvV+NgdzdX9tl3XtjNzZV967d9o90+ubYvfOsX7Nlbz61fA1gMIccTqOn0IZZK5SM+aB/tOEExrNJvYggCUsAKV8+GpzHAchyxK9PwIZql17MDakJQ2EOgX9LBpt3e9mzwN1sPWED1c6hdA9aFB2iAVtEf/tptLWSD4DYHrqzEn1FwnHTcDmjxF0JPh0sAZYRZrtAMGwd4B283G8F0+PCtztYNAu9Q7gKWsKEaQHKPMurB2n4moBNYPtput3MY40rrUl0dYjZXjobCmY3JaK0QTQtDAQdeKPUjn9k0NgC0kGdKoYxtAfqVYr1QLvbdyrZjZ09vNva5Tz/j356+eELbhG/7yT/B3v7MC3v27M6ev3jKcYAVCIIl8w5XSfJCMOobB1zCn9i0iwrz7K/K8/VrxPNIDgVSGbt9alJPYl4BUOfLXPh0l1EqBXH9ij9xzlDJXTrjLp6RnpeAdVLLBnzW66KpW3iVh2pS4EtgOcFOHn/r3sho4offC7ahSd/xoEaWLI0n2NJ5ocFl00xUqMozNs4rj0E0laOVBJo4AljSfqC1dtPa3OE+lYq4bZBkcoW6ey4DKhMsc3yN0BLAEuiRjfgohgaUbTTHAZa9wSTWgBPQNoAY1gQCvwDLS6icFepZyZzrKIqr4edDiBamJjE3Co7G+RB2Fw7s4nIxYeJzJRq30eokrWdSM+vSOrgmxGa7z2TXsfBCdpsL2T5o6gIu09aHCntvJhnwmbv1doBuc8EGl7AQahvbbAZbj3j0dg2A3Hd2c6UKgOu7K64x4/Xaxi0UyQMTfVQqbwWP0UAUPwMqAy7j2rbtqHPC7RgH6/cBFhECXkyU8DLH/IGvN5TIR31F6B1D/sQ6eSVT8Hc028O6P+0PHCnYomC9w9+ao9Y9zAkks3CeTLTQwSev22qYGntCliQnawDIqUqOezw8n11VrPtLTRGRHIrKDr6uQ7VE8cB8P8LneWWnY2PnI3yhMUEAl/N7UVbMu+e5W6Cw9eV05PvRasb1Bbg+2WpYyzf6IvFU+jcnV/DIBOHdJQzscX8TlAcs1zrB7/E+inu3VUIYPT8X61zKeRWNKGvUqFGjRo0aNWrUqFGjRo1PCiz/1G//Vn74hiJ5HHq7uhrtarvm48XzO9tsRvvMZ5/bZru27e3aIGqEout0hvcrPlhL8dme3G8zQVoovsJ/1u0c8JrwOY2P3FRVqeETf+fbiA/K9M+E/ycABBXKUHGphJkfrgENOjfpBJQI6wzWJPuH80Ta4oO2PIlnlk7LpkBQ2S0nqESb1SCvI4VJfqKyJxBtIOBgObtACC0Y3H8XUIfKUHhinqHqlJeoIBPAwolNtwTroCKVwhfqbzx/gkIQ/p2TlH4dz6EAKpnO+jjIvgTHMc/yJAbQE7h3+Aa/Yiiycb2aLkFyQc1C6R3N01wNDDh0db22ebqxz3/+bZ7vk+fw0h7t+Ytbu0HZ/bpLitnjIcqzdQ1bQG1aMbADX1x4gd3Zm+EFqAvlJOFsVkfK/kKQXRXicR1yRPKg1AaWlgrhhys1bJGA0Dc8LNhVpOaMCTALYYZdS+w3moRpTmgbPtyyWogGfT7bqU/258DyIMZcqsqVnWGR4efHEnso11/e2/7x3qYD7GSgSs7WIPOM+bOywwFKUoeigHgzxlD7EriXSpQJkGQHo+QKxJ/0poWlAZTF+6M1fUs1Ne4x2C8cG8xxaMJlPxM+ycnLtZO9AXy0pWQNaxcBZSQs8Hz4MsMKhuOfEiX54tFqI/xmvaFdWB+UZf3RrDJlGHzVSa3d/F7X1JG6vTBVoSo7AT/3uNV4uQeyKz25R7cjwHh3hOvarlqEhueumtPhefQ7pk/5ytaDmuDBV5xKZ583AO+4z/E7At9mRW9yqKY3a/gdN3Z9teUaDDuLqyt4JHf8Ha7Vze01YbL8kwd66SLJg6QU7kkAW/yOnrzelI/rqsP08l6JKg35GuP+lWKekJPcE00fZdmha+AJKi9r0H0mKx4CaCQbJvgsYztuV0FfeiUvrJnNWqjstc5qnZGnclQQhGI+2cz4+wSrS5C0aOW9zvuRl8297rlIZOufuO90Hl6dEFZAAPmcr74+EQzvCZbhE63j8KSozzVuA2O1wvsDlM6z3d+/svv37q0dtjasV9avz9Zfe+IyPJl9IVKiz5sZ0mIEx6RxFyHG+yPWbvd3L5JYuDdwzKwOoNeyJwNR/cNjjTRgUVlRuXKNGjVq1KhRo0aNGjVq1PikwfLP+M6fxg/xz57f2fZqY1d3G7u63dIDFzARcGK7hkpsZateH5bx4Zjl6vORymF8qEVTKH5YhwqMVgLxQRrem/FhWk378OE8ytUBuvgBGU3K2LSPJIyAUco/NQKDMnn3+Eg4AUCFD9AobyZ0pSo4+zFnsAyWA6WkWw644ld2AlEurKZfsOIg3KACFn64+rDenmXhADEwqXAAhmikh+0Wqlu8Hn+DohBwBepO/G3oAyyvElQGqOoa/K2x49CyvP049oQIu/2Oj/1hzzEfMJ7NUokrcAaFKNR3K6oadUA7Qg80i5M7gIAurhtUrqFinnuoXVEyPxb2GdlN+ujWDcPY2t3TrW03nT27Ebh68vSa0Ov2U09tvN5qrA3+qiibd2Wjq63hXd2yEV8GxbpWq5xQCEhX+HgI4DrwI2wSjAPcBDzXGDhQCXW6z+sA7xIbhk+2/kB1ND2rAfYcRoUMHI0MqdbFFELJfR5vvCYCx0uFevh6w5+bZf2C4lBEJtAVSQAmIgSo5aWdm0+yvB+wjCBpJaX+4WAfvPsum/a1HfaH8XLPW9wzAHUc546qY/pmB8B31ScgpuSzuncD0HL+e8MvePUe26M9PjwwmdEOvQ1oGAaw3PVSJbOkAFAZMG1nh4OsAgDn+vXGr7PGm3CR3uozHwDgByqW1agSx04rArcN0WWRYj4SPLJBDlsOHGsBokuL2iTVDhVtMYd5PZSQ0H9DIr3KYltvlhcJD6nok9lK+hteBiG+7vjw/gWod69mXNlGawLm1ADV8Cj1NhrV0HbGm87hdwDOtBAaB/68hd0FrCygTh46e8ImqWpsCcsLPBfJPYDl6zuB5XaEx7YAMlXr/Ip1OBIxwTRDzVqolMMrBfMakHSCWhdAeO/WREqMaOiUkADwlre8K4RdDa7bAvPxYIfdXu8LhymryQFCsSZwjiA/5tc/EjWoTPEmm1pXNU9zf1ScjyoJMN8CLDMhmPyR/X7l2lJWAmSwzXvUb0V6PHeyx2jOvZ1mJPQO9IJGDwCMDW4dn3beiBDHo/cp2NAgofP+u+/aez/wvg3ra9ten219dbLtU6z12Qs6mvNpLVFSk9vB1xnrMd4PMV64N+b0t0isxDVkNUiv8ycU5/2F98OjEhXqLLmw9qlRo0aNGjVq1KhRo0aNGjU+UbD8zmff4gfg27srW68HGze9DZuOKlV8dg0XA/n5uoer2xUAIJxRuo8Pv21jp9YtDBhhD+BWBtGgz0uXqYQjlHEFqstTKfjFR2yq3lAyvbJTg5+PVBFT+Ub45IAGcM9hA0FVKmOXVylf73AhqaAd8OgovYGWl0wDeAiI+3ED9uErlc0iWqGRDOVkAITUgDA8VAkeBaYBDGyGEnmyw37H5/XdSEgDhaPAkw9Z4asc8DH5wxajK5/l+J03oksC2SgDD9jgKtYC5IUnKQuzi3Jq+Q77fgAnh5YNvyaAyCPUhkZwApVt7COrS/FF281qQWgafW4UpdrxjHB0zU2q9LsAzPmsC+6cbCmy2jhZVCzsLzJwTmNXcJfYbzqHAsQtjTD0dyqMi9f6IRZe3g6D2LhPPthZt+wQPJXxuxUMx6ywOvHrTbuLw0RbFDUwCx9aT9j4fIPFBBIVGGH49sZrY05HYzOAR0LHmPnRhNCbo8V9HuPMuewN0fL465gF1gWyY+xiLAjO6bPuD4I0JG0ur0tOJCybl/l0x/nQ19jtRFypnhSgBNQC6KFOD4Wxmnb61I9rE37NxdWXlllf8dd+1SRbClUX4PVqMtc7NI7mhExqBfD0dQF/x5o0wGt87LhGreGVzt9ju6owIFiGMhxwGBZEUBo7WIba+O72xm5uHCzT3qK1cQPA3NoGthdDT2U5ml5qnwKLasrp0mvdfD6fowFdoVp2y5zs5SulLK7kCetK2Ckg6VbcdWEBk++DqApJjtjFOhP2xALHMJ5eoYsf1gRvHJsqSop95AdgdK42iTlG2M1GffgFoHDxypjvp9y0VdmB4xseb9ivW6PEvRFDpiUu2+HEvYEx5zVFE0QCXlW3RFNUjX80+tR7YTRT1e8Bxv1RHBPV3rAMQXNOquWbvHbjfsJ7Q5FkLFaqqlauUaNGjRo1atSoUaNGjRo/PGD5O376T9ELCBLONu0fbQL4hM8kS4KhHIOidWXTHmXtgMl60LMVcBLNmU6zFJBoRtf1Xr6Ocl38XY2HjtPeAcaBH6abFQBLLwXpJCVXeJfudvJPjmZc+CA9Heak1MQH6WEYbWRJOHxERylCvbETgThAHKuds8qVasK+99J5ATypB5uFMvVAtS88RvEB/2gN2sPBNJUq6gCYGkN88NfvANyPUhDCi7brqFzGi/b7HZtbwYr6tH+07fbKWlfkDbBB6AXQ2AywQym7GuDhmDuqer2hnasHQ3lLQSFsLbAtB40JCnrTKsEKt8R2RSvHGarm3l8DqC95oJTl4UPdrNSose9td39v76O/G8vj0RYOTRjdS7jprOsGqdNPsvNQuTeuF3aOBoze/Cy4B2GmQBhUi1SJF97GJdSlyNRtRtRYD8em8ZC6OHtRLzBL+KSmwn23qXApIsBMBvmuPnZ4HGArLC/oiyuvhTTOPKYkCHW7C1ipUE18UiM7KCxdUauEwMkMle+4l2BBEBYSHCv3UXbw9PKDl/bFH3zXDvsHOxweCfifPn/KY4D6V9d18gaR8F+VNQs9v12Rj3mApntKvgTwU4NNgimcWIvXrtjKk3bYrvQlMGs1Rnyde3rj9+O48Tni4Jl88GzT4Uggvj/MVDhjP7x3Hf5xPF3BSzgGSMbKBiUkOM5QCGP/UOTDLgb3CH3ApdZGkgYHQwC8Wsm3mZURAsmr1FTT9coJ9AWFdvW5H3qY+Ixtb9tRlRrwN0aFAdYG+h8P8qGXzzSgMZ6DZnANATMTRA6fUfEBn2Qkl3Dv4DlQhuM+EQR2sLwGkGyY1KO3MsAyVMlokHq1of86PJKpRh6wrsrLOim434DL+Vv35uXYugIX1RvJkgFPQ+VHqVjmPSulsRJEWdUe6w+vYMih3XZCazbue4B1GMmbzSzjwG7lcM0KAwB8qu2x9jd2XM1+73lyKhIfLhXn3RRJQb+/mJygmli2EIx+8CQMFMeYE5AaxzodVQOTJu4R721IijZMdiWF7xnHoiayvEV4slIDZ8COe0yezP0wcBzW261tr0+23t7Y9e2d9WvYlQxuPxJJHpwzKixwTWSzwa98v8Bx4fsDRkl2JK7A5vsorVNwfLKe4W12PNrMZpjoP6A1Vtel+irXqFGjRo0aNWrUqFGjRo0fZrC8uV4LlLH8Fh9Ysw8kP6/CTxmKqXPDEneCL0AzgGWXFgJ2hMo1eShTjZaVWYQW/CpVHJsgsRmSKzcJlwSHQpEVoFOqOjUii6ZdKo8WNAt1b1aPlmeYtaL5N7nJWuq4xT94G7+yExq3mVWml1tLVcdJySb1pxSKUikmP0wolg8H2zcAjj3tKgBe1DMse9+WoHPZhC4rB6WmKw4x9S0sIENS4YIUyqJhdakOdi/b8vcZImkcqEwdzI5zz9L9E4A07EUCLiX46teECu2sXA5P7UBaoVJfoDCXQmZ9cvhsO4VeelxkxWThkRyXBmAx/IxD3Rhq1qW27yMijiXUuA6gF08JdTWPL46z0HGmRovla30uuVIz/MfT+BS/45yHP/EMQLu33cMjwVZ4X8e+UzNHV+tSAR3q4vjj4rT9d97QMVTK4MOoDsiOJHnOxfdKGLmSOKw9kie37tFQKFOl7MmipCYG9HXPWjdEXjSkLASkssLhV8HhsOdOD2+Eh3uH0NIRq5INUBMLCLdMaGSFad5RnoN87cpsPQy2XUMZDL9jwd6+BywGWO75PR6wrcF9sWFjvIZqZu6LHsAN7SzWm3i9QHI/jgSvtF9AUm1obFjr+fRYhroZlhg9rDE2tL5Aw7luPQooA1bSFgRg3hMoyebC57hXW8StEmuv/JMFd9N08CReWPuU96nuoeU4vaaLDZsG96BWM1G3fcgzTeLphUof6nr8QdA2tkVWnq5PpKzKq1SGewnHbe5e5UmlHpmrdBS+lhDo4v0Ka29YZ2h9XEy+1Pg0eaT4br2yBdf5BH9y2JEMDpMLP2tvRJn3X35d3l/h/p2/5nUnmhjGnC17AUQjy9Q0NIB/Gvg3DFuNGjVq1KhRo0aNGjVq1KjxtQTL+8dXBFwP779vh92OHe6P6HLPxlKCtv1GDeT2Dzs20GMTNDSz6lsbt4NEfw5HAZVgjxHAAKD4eNgLMNHHU03JjvhgD5LQtA5cHThTdYjnQs3pzffC2xYNxPjZWRAboGWAzzKsMw7wvlzR+xmQBYBLjakgsRSc6vB6MnT46K6sHQEIoMiGQlnWHABmUMZR74nmTSwtF4SZTpM33BJkkEUswJrK/APMsFy+hzXGaFdXN1QovvelL9n9q4NNDzt72Z7s6dOzbTfXtloBMoWaVMAI8FAKUzWwCsi+wrg2B+6v6zeEVqHURYTnbpRssyS7690fFJJfV4/TLJTGo0ltTAV5o7EP/+tkD0IReGOrzcbsyVNev93DK0IOKPcAV5hciHnA45J/rWxOCkicKtkDLEaTKuQipErV9XYgHmAqkIsDXEHN8o9Q3oYa1hEOm+FJGYv9wD4lg8wlKE42JgXQ1+BE07IcC9Dv5xWWIuHrrMZkGdIKFumYA8axMSN+D0jV97o36ME62/Ew2QGJHDafXNnD/c7e++K7vL+211cc836zldVKC/DY0QqDfsgOno6N4BOOH0kMKpfpywv/Z1jJdHaGUjnm2RmqfBwPjlFNwab5QLANNeWqOVGFe7y5IpPcoYkfG/5NVG7vWFGAn+GrDNWyHhzGFVTb3hyRcwL3d1jrXFhgiDpqjuKsCJhXtsJOJ72+OWN7rfWwCoASnF7SqD5obHDP8RGgb9WweZ6aabb0Cw8AXHI3VBeEbcWVN8m72sqegq+hZ/JAYIzf4e/0Pd5sOH6A2NFkED9zX6g6aOWLK+9meMIDbuP4OoO4t+sdqiJ5B3gMe4vGm+6xqaI3rtOd4PeBW6gEV7xc1F0pyzl/nGw+PCb/5GR940kmAfoTbT900xaK/4DFF1YRoVSGaho/Q50Oq6LHh709vpLNjxxhYh3VVxpTcBIo4TBDSc3KCt1zajiI+Xy0U/Jg0vorxbk33YOK3Bsg8ngdLGOpw2snvHdQ0Q6hr9azFd4XTrBUOtvq2NrpeLBjJ2DP9Y/zEe9TAM5oholEK95LZr5HpUoFT571ULUPZ7ue0Tx2a22/sW7E19HObDzoFTH+nsCkjScRA37njBneoFABhPcgNOnz/dCLJWC47hl6r2PsJiUKAbU5dpgvKtNhQkeXqVpi1KhRo0aNGjVq1KhRo0aNTxgsz9OBH1gf7+9t9/BgJ8ChgxqYda1KufFBGR9SD497geWE/M7Wn1Wum7RX+gScoAXgERSu4cucHg4tqEtjGbvbM7AcPlSb4QUsKCcP5UJlyMZ+8kSW0lmNAEO9K2vL8IR10BGqrxMa58kagPsGXHG31aSGK/yYUUYOwEtomTx9RRC1r1IBDB6C8fOGfFQmY6wB3SfD/7bbA60+6FOMcQ6FWjQsixL0pPQTJCQAda/drEjWFeDYE7dls1ACWDJcWGDE+em5ShCU6uY8L4BF4pjOAKINys87NXqb2zRvwl5Bh+oAip4eGGsvoV8ofS9UvzzBUnWcFZfp2YViOSkTHdrGk0Ksnb11Q2lcbqQQeBbl428CzeURlGrjUhFO1X3aR3kwfgwXW4pDkXBXcxYK4VDwh0I07FgAiMI/Ff7i+93eDrsDfbpxL4yY/+7pCgDH+eHKYIJ7bMe9dKmKP+Ea0leCMFvAUh4fUHEKMGPeuYczmnSisRoAGwAvmk12UOMKYsW5oxwfgBnHJYV1eCq7r3Ka137FQ7ma7ZKTR24aYQmqZT2DR4BU97OGNQ2GmWDZKxZkG9DaGurRpvGvUBWPUhtDRQzrCYBBb3QZwBbnBBi83mzs6vqagPcKVhQ43022qxAsbq2lB31n2y3gPsZO54f9AbjiONhQDxB0cM9dNDblve1gGUmL5GGv+z3ZldDIHIQ/fHijAWcsbheTiokan3GhXE5e3LJciWaeJ7dkUePBVWH/4P7EcY38hikrNdK1iWyiK9WRSEDCDoCZgB3+JQlFuw1JyOt9TYYPt3bhjVITQPc7iVwZCYCsapYVkFdHRHNMP+ncmNKzExySaFKIBMfKzseDPLOZ5PEmgm2ougGflVwrxx3gXe94ngxjE8GGQw4V+ukIa52RSTyo0fM4aewiLZVcN9JJpjRTtBl1+48YNk9URZ8C98NmwpXNCFdKPnjlTjQbpU92mgtVtlyjRo0aNWrUqFGjRo0aNT5BsPzub/0BgqeX731g+92OykKo/NicjU2DAuIBvqwJm8NSAMCkGUZ5YxJwyRID7rsBhvgh1xs7AbrAA5lglv6oUOs2dlpBhQZ31xMBlSw43PN0GLltgCx4KuMDNyAbNgjQgw/UahYGL9eT7Q97fkzHB3wBTn1GB4vhh3FX6Ako68N5go0AIAANaOIFNTX4jv8Fos3jdLLT6mQz1ZtSGSKALIAFWwIYRzAObOGLCmj95PkzeqUedi9t2j8QLhGke1M2HATgF0rroZrEufa9wL4sNaCU1rHIH3cym/B7ADMdB8CMwJvgvaNxAWKC6tbGAa9t2EBwhnIccBOqcjxwkgAhriamyrDpqFiF1zTgEeAh/JAJVwD1CUHhmYrXRaOw3HzusokcPZYLlBINBcsGfaXpQSgr1RCvaNCmUc4Nrf7/7b0JsG3bVdY/9t5r7eacc+997yWBNERTRgRKIlFpQ2OhKAKGQqWkkAIUkAJCtBJC0UVDUyiCWpQVhKJRbCpiYQGmMAVFF0tsaRVKQmv+wZRCAknee/ees/dee+1/fd83xpxz7Xuf5ubd+xLeG188nPvO2c1ac801j/s3vvkNzBuCIL+W7hSMjOTY0q7ig+IZFFvSwDgeH4BrAGqH0ePUwR274uEkLoBa9Kc4qiMXNiJO4j7SW8k9jfFnDjNyljGGgKHnF4T2dAkzh3dl/XLD/OoASOGuVVM12SHptMdoeSRNnJfiNmB61LkCaCMPHa7l1XzD11w6ZEURaGBeuvJqkYt88/IWQbFytQGW4eAFaMXd6m05o0jEL+VOM2q7TU4IeIy8bb9POEzRwNDnAn6G88P5oKEd74lVZ2erTtDMYWRkc6/Xq8ZFjDzmua0BvgGUV1ivFCuBLGTkHi83yjvGz9u4GOQgcwfEemXrsw1dxmeIokCBqFdRC/cZwDqvA3PV8TNdl3D50/HO66GMXzn4oxCFLG130Edud8QVs2gmoIxCTimGMG6iFjdUL2jiTWKMJ0/RNYFzPZrDMU7l8srvSd1bOEeBcEDs6f0WTUG1hnqRr4leaKz45Tteb3N+5mBZaxKLJ8gORx6w8je4zmvd0JxceiZ/xKnoPlbWMjK2dV/6yZWcYazDO288qkPA/cSYDTrhfS2IeIv4Ho+hmxm7ZhY2YyY85q7nH/t333twx2IY53oT8zOL6xxrtLuVI2KEsL7syIE7mSHrzLRmpj2+mPPcrpV1veAXjwuFHk+S4eM8eoPlRBSG8Jp+PMmUU6lUKpVKpVKpVCp1v8HyI295G0HbzUfRHGxvy5WgJj6VonmWIIc+7NJ5x2QJd/hi2zbhpzvtuA138A/P1YpIZ6zDIPyLmadj/dCu7NOODknAXW7y9TzUng360NCqt83ZGV8H0Asf1ucBUYoD7kBHJ165X+NjdnW7lWZ7DsVn4f49xNbvCOcUGII7l2IzJXcLFze1wPTSlgLdzAxlVz828JIEg+CCNFvatRs3uH36kXcIFMwWXQHLZKNwv/U9AdiSW/cFuiqokuszaJIc2jvHQIKm7PEE0O1Nucp2dEZ8KPOYERvzme227pjll4AHtllDML7xuvB5+A/AaERtyJVK6KIX1FxwF3U0hqtQ2ZuxeR5vwJKWhzEbtoW7zOY+VY2Y4MPcTaksXX+8X6Oay1rnlyCx4h64Rb04zPW4yl9mNiJOwcmxjKAC0YQ4Dt+iASGOHXO2WKYjkrUxYdMVHDmyxVqt92eGOIsEcMMLLC89WuHy1i1eiEW35D0QzuSFQ9ES0RJwHOPrzs/ieAxnOq4go1507+4Pg63sSIgKJAqQiycMHXYWoEHeQDc97rOr7dYGzA/mOsuxjIxhupv99cOBjGmgXGjMFf13gfNRbzicmGtpJo2oBa0F666zfr6wa5sN4fL5prfr56sabYNj9oZ45xdnbJC5XPX8QkM9wGiM0/nmjPfQ+fk5AXS/RnyInheOa649s3m535Ah3m9WBOjrzdrvP917XOfcXczr6qDfZ7yvcXG9A+qGIzWW5RoHU2adCG5Zf/RoPUL3hiPOKJR5k8Pi6o44mDYKmGuLu27ZmHOw/Ra7DNoIDbhssf50sc/B78e4t/Aa2scR91a5v1po6cfCRpVYPBgFhHiJkTCbzea8wKj8fcWj8PbgmrH0nRW+64JRGTJso3Ek3wsRLe1b4r4ZNA6xO0XNKBWz0WaXK69ZkSoaSF/HZ3Aoz+2ADq98mPL9FYHRBEU3J6kM+FpMKVn3HHtVIgWVfdKz4OXHwRtEYFlfDopRJOR7wu3tc5PrjCfCeANAfjEr223PfFnM43BgYz3CV0TxNIbuVCqVSqVSqVQqlUql7gdYhusXH0ABk+lyW8D95w7AlTJC4QgEAEADv/i0SrDXATbDMRUfoL1JFB1aDg2YNal8TEJq91aRH0bkhUc74GfIjj2au409C1X5o4oeiG3IalYlEKAP9YA8clTSjQdHIVy6Svaki5pP5+HLLVi2VDs2ABSJnMrSC26h42fTtEHZmcjKJJwCkHOYyvHAebf7+x30LBYz23hTrnG8Yn4toF5AQA4Tzs9dw3Jfy0UcgBYAczZiuzigg1O6gDQgeZMYiYqttPW9NukaizvWxygcxMFhaEQUHItAE7q9CW6ONuwAHw9ySHqMA8Err0/d/u1W2sJc5RAOg3KQT71Hga0a8ZpF7A5WuTsFdOPYAuAxBoKQB/BaDvnqVo44gJqFXLbmO4xrdtJXly/dioAzKhZgHrXb4OO9tU1ex68Ykjpvwm1YwU44oHEd1CCSMDacjA6F8V506AN09ks6iwE98fI7z11mYYUJFshKNmaYIwNZ7ujqluXrhW89oigclOL10EjywEaL8qmiQSCiDPaIthiQgSun/n442uUlIjgO9ugjt2x7ubOrW1eMPoCDGG5owOEj/huQGI0dPcJCY4u1Q/+EO5qN9dzxGxMP8xp56YDmF2dnfNyNG+d2tlnb2dnSzs/XBcnq+WiCN6cbGWAZOwOWazxfjfMYhQHHssdbrNA4b9lxTYuiVTStjLxp5vvC9Yy4Cxa1lEUdjSLDVKw54Ne04M5aHKkxCJ55Xg3xzWSLIpePDwYrdmpwSjW294hVKfFAEdHQGnIdBLPhqdZj7j6g633Ja6VYCTl6dRrVsV0c9eGkjoJFANaS8FIb5N1Gl5v7g5eWUQ3u8OfxeRPRMWzaWoNZNOGOiihixs6AaIB3MsZxO/mBzJHhzviWugNC95nnyfvBKz4oGDEJvZ7nsJ1NMT1KQjs3YodF877RRM8jKQTKVU2S39vhe+xewZpVXNQRCwVXMb7Dfezuaqfj/OuFYeDC7fMHRVQvMiCWis5k7q4Jt7LWaCbJ+1od3vNJrlAqlUqlUqlUKpVKpVL3GixfXD/XlvLNNQIIxkrQkWt0sxF8Auw6kJPDGNBWza/oghwPbAJIt7KDZQIyuvkQYSC4RVBJiCKogziHcdjZYTbawW2MHaIv2MRKcCjAIBsDElgDVshtuN/B7ammbXgvIKtlLw7QrQAr8H7aIhx5sHy+56EqRcEzewmxAfP6EiEAMUsT57hfmB26kg2KMcGxYhv50onTAFC3R0OqyGyQ6w6vf/3GubjRYmSMQM8Gd964i05YOVgB+/Z7fS0Pq5JJjbHF8OssAWgEHQApkRsasF+qsFiPEeTCdToMe7nN6SoUQCMcLo5uJ2gE94LocMEdjnM6V5HzCxxK9zrGko0FvcGau4ELD6PbUbnXkf8c0Cgc4SXeIlSc2X7cPJ3YDi9nntyUAZbnjWMWIMdzhhvAE85CwTqMe83C5lvq4T5mU8ikQoPyvIOoMYYB8Cxc2UFz/EUJlf0rgH0AyZC4kTvhAZtw7wCwoQkXAOwK1/5oq9Wajn3MxltXW9tcXdnV1SXwoM3PkG+OxpeIJ1FsDJ5Hrz6iYDi35uU74LJcoQKTcEUL2ANMA1zrWiKtYjjObLdDJIbZdjfaIw9f2na7s3e87WE2aXv00a3trwbOCcHl0Q5bgebI18YMKS3yMD9mM4+nmNsGMLj3nOEZjn1p186vMcP4wQevExY/9NA15hyvNivbXGzUcBPnwWZ3+g5XMQAwIinwOPweWcp4LCM18N94TzSYxJgAFjPqwYtd7jCPe1HwPaBl3AfuvJ9MVl1URZ9UwCzQKDgJN6qicn398fkPmCk3c7BiXxfblyqpE80PHHiq6aGAKXGyg2XlMOM93RlryNTGXN1wrmF9g0uW882LH3Ju1zWyFN1YFYhz5MWbwOh6vm3UBCJ33Nnd6bi6I3ZmqOBWHMNRdONtjUgOrUPxmoL+tSlprBtRLCmRI16pYgY8x9Hfh7sn4GhWHktdXiIaqTZa5ZXwWA/tSNE6QlCPIqXvTgi/sgqoGEYU/+K+d5DrJ6gCaED4uR0QdzGg6DKwgS09yEc0UlQhUWsN4jHwt8wXTxw3Ip8A+AmWVeSyrrPZiONSrIrWeqzhnttectaLz/2d/f8KpFKpVCqVSqVSqVQqdfdgGYBGH9r1gXzexQdiWQ5rSkElhpGBqV+4e0v74E8ooTfocmgZjfXmR29QxQ/g7nQrzYm0jXicyzVK4AnwiWN0wIwP1ZFlWd/Htyk7lJRzOCIw3FEWebfu3hQIaB2w9chbd2GYXeO16aAmsGsZ0O3OsHBEypEsF67ykuEUrg3kCC8J7KsTNhzL3N7dbO2G8028QBm7I8DFQdmcTsqr668kPvOHFZC09slKe6pTuyFc1UUsQKZcZX89/33AGG5w9+vMR3izxZg97iM8QR3hGL+Tsy5AUjgpi2X05Kuec3xvnds1iiIIb2RI+2Ma/hLjJrgeY1Lzb9v3CJNmmFf5bg1g1js1luhyPOGsjFtIc7zcO+E89m3xgoOCXMilhXMZgdv94UBwC6jM8YcjEnMGnvgh5hPyimWLV9R5db0yioMN53QCpekeI3DCSauoCDiRUVDZ7w622+7tsJNjE+7kuR86vcjqXFkAOwE8ACay2+FG3iAvurPNpqcrOfLDCZYvrtGdHWD5xgMAy2vlHnt2bxSF9H0mt3LfsRAFuEyQ7IUn3mdsntdrdwHGkbsYamZxuIsDWFY3br0/J/d360CfZE80lzmcqwEhY51BFBBzJupcqPJ3QfEMBZLgr4izmKnIEZkPZb2Ne5Drkpz7sWNA64XeX43xFoz5KO577k7BvKoFrnr/N+fbnFxZP5t4k1KMas8k5jN3gSAbXo71ycuWdbgWP04HMoInyn81f1uYo19iQxTXEc0h497m2PlrxnUuLnH/v0rxqX8nPL1Cjxgb0FsaTuLvHO6lOtYaTj9nQG3fdQFgz1WHvQXmep+TcStnFIWBpm6hSI3mupd1t1mj44L5eU0gekLlVCqVSqVSqVQqlUrdb7C82eiho+1sOOwIZ9hwDr4qb4CkzF48yp3HMstquzR3N8MdhvzOkSBHUcCe4Tseub2e2a3Dnt/hpmRe7ELOQ7qV6eI62vYKEAUuNvweQEtu3RITMJ/b6mwpGKEIWxuO2LaP5kcCt1Bx+hbHqN5L27SrO5Cv4a8b2aja9g64BvcX3H1wwB25rTtcgoQw5ME4T2XT6rsaNSkfE0BHjtqAvv2iIzybwT182DO+g47U45FwDMeCcUTkQbfbMZ6Ah+Wu2N0WjmGct2DjAbm4HHfEksj5Jw6KcwLABszCVmyc18JbZ/Hq8f8BPR2YeXp0Zzq4YEBcjqTHeSyUc73e0GE3ICy3oSAAngcATjheB10vRnXjfHy8ClwmhNVWbYwPfopjCFezuLQ7g3kcclTqxKL5X+AhnQfjMxx6tZnNBKPu6iyFDvr/xgYGRYMtFQ0iHsRf0F3KgOqCtAHM+SsWO2AkVKTE/oA57igrttIH8I+Rd/enehwKFMJJziZqDsgAX48dcoPlQsZugtmit/0w2tvf8YitdjtbnMOJC8cw3ltzmnMV77vV2J4hAgLFgIMgbmQtIO7i8vJScxKNAWdz2+0OjL+YdT2dkRjrZb8WuN7PbLg62uXDW7v5MPLYRxu3cIjOrMf9etT4HOdHvh4iLbCWLP3r4to5vz/00HVbrpa2OettucRjelutAYeXdn5xQUh87do1/vz82pmt1ivrlmvr15sKfQGWAfCc/xcnK+/h06ZlMY/kPI+E71J3CHjsuyGC4+lnXjxoikweAuSOUZ8LWnC8UWW4VtHtczdpUqf4YV+aeTi3F+FsiPzhuGcAF7W20S2LdSZyzgP4MpJnqXuHBQpAf0UuREY0oPL1Bx/gmrzdXun5bQxG+Xec+wmm5HpSAbPWc7htmYjvGeZy4YObA+oCKCMzPvL3Iz5DS6Nc9Sqi4Xzq+6FAcmDmMBmuN9/0+xtjGY3r5ksWQRGDo+PS35DZzJvvoTA693gLxLXgGrEwVws8vOYHj7PxoibWsSiIKuZHhQzEzYx7xADh9RT7omjkmY2zgw3zHdea5QoFDOy4QZSKYnfwXnvsM9hqBwieo/klKK5dLkGdPRqKTmY1zlQcksc0l94HWK8VD1R2ycTaGdb7VCqVSqVSqVQqlUql7idY5md2fJ5l4yOkG2ubrbYU12iAIDEFyzhE4fbd0pQJpEwQgY92Rxe33gMQYwu250nO8YEf7QEX+kDPCA4+pokOcDgNZ6VgxmhHwmY1mqITrnkvun6LE7HQodpADmCtcSdPPV318ep0Vx224ZQrIMZzRuVSrY7f+HfdMh4+Wffd+fMFLn0PvoNtGsQBgh0EVU29vGVMcQ6AiRwfQKSFLcaF4ItnURdLbZxaQbvFrtuMXQUt4d489cCJHxHTEox7imf5JbfjV1Y3eR6BTw2LnfxO32Kk/L/LtvwY79YpHBbyU5dj/XdALIGm9mfxGOW9RlxFHQCP4mgiDsKlGdewGbgG5gguui1YtwIpZ1wCvacfSXHZa1u+u1vDtcy5UJvDLeDI7fSFMRmGgy0ItkY7du7r9KgFNRpsIH7jhsZ9Fy598auIOgiI5q5Mh3eRWcscaZ5abNPXmXQeD6Lsa7POgTwiORBDgdx2FE3gKr5+A0C5t6c97QEHy3AYL/hvgeXezs4v1Gzv4oLxCBu4lVdLWyxX1q025X7Q9MSa4GvQyRRoPa8FDEfIc70ZJpexnVbFSeoNRt1vXJ42dddO7oDqTI+MXB93zQUAQt3r1bEexxNZwHhMCwZjToQ7tvl3XGVv0FjB5GQvR42NwfyBg/ugeIp4YAXMZUROwPLUqVwfo3W7uoXjPokboDbsPF3XmK2MBpAeO3OEKzvOpcmoudM+hnoEEdoeADWOuc1mnu6PiPuhNEJsevydTAsvcgGUN/cIHeUC6ATLSNuA8X+ODHy/h2M9c7f+6A1YBfA9Z7qxx0eD2dtP0ItfpUHm7WtVmYblwFuAf4fXTKVSqVQqlUqlUqlU6l6C5T1yHO1oVzs4lgcb4Rz2xk10fk4+actNrO70gslHd9MpR1NwVMAA0EE/H/ZqSHfzkUcFlz1XdH22ss35xrZXe3v725DRjA/qAgJn6xXzVuXU8maABwGKq+KgXcrZScea3KTK4cX/E/BA3ideEVvh4fqs+RVyGOMg0UhpGND4D65Oh0jKAlAu5gj33YIgrWyp5n58HxZmEQcqhnt4YctuwcZn20uMkZrlMWGh5ALLtUbnHs7hOCOEW5+d2dnZ2jbnKzs7O7f15ozjELRy4bEQYkiA8XsbLw9yBh8HgsgeLlU2MVReNMD8biv3JDETAXZse0e8QjTocqDuW9MBe45oewiH8jjw9Zl7wCaNcvrO5z3zotl26ojmhHIM0v3LHOloqFezrDWE3ogRrjtvGAhvcc1Bnlu30PUCjCvxHSXjtcI+vl4NE5h8VViuzGc5UiM6RW5ANtLzRoFqhiWgq3/7O9FYXYEwhLkBpz6KJnKIHtkkjW5Obps3W/RLm/dwt0aWbUU+KABgdiKbe+tO13G1lPMe0QXL3s6uXXCsH3z60+29nvksOxy2th9u2hwFhfFoi6MRyuLxbewG38MzxenmdkOkYNfc+nln43LjoLTz5mrKByfA2qPIM9p+O9j+am+2H21xmNnFcmP9hRyZcDWj6R3Ase49fV1cXDDaYrVZs3Eeoiyu35AL+caD1+lK7ldoEKosdWYtd/hZ5KsvtY4gKxlxFshGnvs94E0rOeaE4SiI6Wd0vOPfB89Hj8KE75AggveoggB6xbmMx+K/eaF1nUoySUDl8pxoROfro68ngtFYCPZ2xH3JKpnuFT6HBbs4LrnmA+5WuFkjNuquAYFSNSstpauKUR2Q6pxRXMK1PzTZ44o84SEjmgI7G9z+GgW52K0hIOzH7PdsbdQXedPKBUYOfIBl1OwYPQIXMtaL/U5FRzSpY6VE2fC8gjFWbE5aC1tBejGHlsiDLpsGsINFRZWAq5j7aBapS+lZ/h7hBEcwDusQRc2A6A7D5QL3d4TTvizkHvmD5rCTXQ6e3457DjE0NrNbN7f824mNMnAt96vBlrawBdb1QY9hhj2iXpCJbD3/zmA9mI0qzmGcUCAR3Fc+NHf88HJjJ4kgOLKisQZjjVXTPrjf8Se+iQbiGOD5UTdMsJxKpVKpVCqVSqVSqScALCv3WHETiKoYhl7RE3RZecOy1uUXmce+PZfgoJjnqjO0AAN3NANeosnYsN+XiA265/qembFb5LYS0OmtekQSKFOjQYbAxch6FZDoFu7iojPOYyoKCPEM5Yik6OT8rI5EB2gwYxIc4Dxmk6HTdmSPSZgBmnijuGrkK+dbM5sjjlfNxUTdy+b24vwrjrU4ZuTHuisVYwJY2HW9w27BV712tOYxCwAAbUpJREFUhUriQALvI6MIFHvB5obktnBHKlqDebqRGRxuOsJuwaJwMbdZsyUugjElHHnPxA6S5c0deZ0CyAluiH9XJ2OMZ4F0xR5az8txaHHBKg9WjamKLbV1HxceNfWNFxtimLZ5Lr7lHwDbs2+1zT6ub7xzjQcQjKzXuFVskwfgFGAGYVIDMMwnxL9Epjh20WtbfjjYg/NpFqipIrhVR4d+mQ9z5AX3yhg+29jZ+bltt2a7/U2PMw93s7KYyTTdvVo9tQ1gim3yPEcUgGIu+3mWCBAYZ1Xk4ZZ/NOSDS/lotuyWNl/NbQEn8WpFKHy2QWO9OQsMOObr16/ZBk7lzdpWaL63Wtk1B8sX1+FG7myxVAGKGdJs4gfgiUgOzE2PmWCVp3F6hhU32DHBocCpYg585wQGE79GBAYDYDQuhI4eFSKwW2/iZl+BAOWpXbbJZI7IjNaAXwpOpZmjGugJ7sqNq/gJrX2VaEd+eriYp7sJps789l5qiyqR9+2PhSvaISvem4U5L2xZE/PDUyiZw/XVBLpbh/epU1m571gz2biT668KSV23YIGQlyCc+MUOXndcyA3cFoHitQVXdbw+HlyrNA8AjYPEq5Cgx6sW6M5ih/8l3/+20YviYLkpGqdvzIkoIrn5PK4d/3bN+HcMf7dG3PZ7vNbCFiv83fQGtp4XpfmMfGX9PcWcZktaj8iI+Y11GjsDVPCb7hVhQ77Ie45dLvF3oHHP1wJa3ReRSqVSqVQqlUqlUqnUfQXLcirOuW0dOZyAmvygy83ucnKNxf3mQM3jF+pWXY+5ILvB4+DAxTP14R4uOhhoV6uNLRY9m3/RHX0w210NNuzheOuJjQ/N1m98mAdgBrwihGaUBYBgOLUi8xgQ13M8I2fVt12Hg4tgGe7lozKZyxb5yLstDuBBzcY8q7lD0zPETISL111wPjAF1jFnF6CMLcyOtgdAB2zEVmo2GesJ0G5dXZXGbwTkPTKnAdEXdn5+5mBltN1uR7CMHFOcCwE3Qa3c5MATyBFhWcC3jwcc3++QdTzYuEB0iDdkA7B0qCT40DSiaziXHI6CLnB3Ci57U6p5zRIWfFNzrhbGEv07/CkuygYA6zjg6u7u0CCtgWXRyLEBfjUnN0D0kQB0GsVSpnUpdmguR6O1OecPxizgjMZUx0Ow485lnU+NNdGUQGZuwCO9HzOndztBfm+2eHmlQsl8t6VzEe761RoN6ARTdVzR5A33ArKSOzvs9ypKwKU7M1uulzaOG7u4ds2uP/CAXV4u7DBumUGO81FfLwEwOYaVKU1Qy90DcuvGtabbMx7PZnKM9fWGZd5U0wV3PEDhbDWzhx56gPPVPI92fe3M1tfPFVmxWeuemTeOZeRCr3rruKZ0ttrAjQw3vXLAI3e3xr+4C5qwzCNdOMc8JobzRA1CBZE9YoIh6mFk9qIXdmFg/theELiDo9/vlZlidZSN3ADLcCPPOg9kUfM1zTsHtWVu1nCTNngjQKTPegFyXyPrQ2v0SfUchz3a/80UCY90YLaw71AoTVW1xriN3s+lHh/mM53LMAUzX93Xu3iAZwnXg6qQuTap07yK9aJ1/GKNRKEKhUjm6nPt1S4VHDx3K3i2PFy9+OrQaJLXIG7LQP71Pq+MdMFYlXKdStNAFWdQEWKckWdnc8cH/i5w7fKiJzPH69re9NQsHvRYB2+Po9BzYyNHjMk407njteDIBlze3tzxb1h3dWWXV1vrlhvb747MBb84jrZcYVdAxDHNuA4gxmjETgcvaNUixvEkkqbmztcrVY8Jc6Gu4Y3rPNbaZj1MpVKpVCqVSqVSqVTq/oBlCmC5427aNj9UjYHiw66cePyADXhWmiDpZ3AjKzIiPty609C39cONBXgNsIyngnUQLG+xrRpQrLfZXNEacpkJYgBAIgoCYLfvBSZ2gxoZCVwHWG6ycL2pl3NEgWU26EPjJMHIcKwFYIB0boh7AJTztAs2X2tBrDuZlf/hIzWvjegAJADe9ztlQ7tLGG5NgnvPySX0AUxE3vQMsQ+dbTYAy2bb7U3b7bcOpwc7eLQIIFMYCAPGaaTDFS6npu3MRkAc5O8StESerq51BfAVRES+cWwzlysZ26/VSCrGcrbQc9UADz93ABvNvRDtEbkL7tYmnnWHoSC1HJM1VqLJCg13ujtRS9asOuU5LGqBUTg+HTg2DsQavhvg2x34dHI2nnGfLz4LmkiMcBDWOQVwhozjKKzwJ2hYuHew7A3NtpeIrDjYDM77rrPVeqPmjEc09arQS2ZOwT9CKxQjAgbPcc/0HMPN+ZldXL9G5/PV1U2bdxHHUJkkoTKvg9zWuO/iezhf+VDP/OWOBFwiuJLxOhxjuHyrExRN+HAPLB+4Rlq/XKw4V88fumYXT7/Bc1qvV7r/HBienSG+BREgC5vhC8cT8QLlQjdufdy/nM8Cy4dyrF4QIDtWnns0sQRQljM4jL/6OR3kyKCJvGqe70jArpifsWaF8wYX6HaLfUlK5zjMTzK+T1TLK80OjZh/hObIMobLP7qdRsFDx83IoWZUJnA5nL101Pu18VKENWtSzR2PKe/N4DDX+XLeGK/stAg/7+0K13XN9Y3flLuNbmWBZXxpzcdboghRCjyMyNFOif1uJ1c5dlKgoOJjo/FUg76yBnHMsKMgmvGp8FB2UcRa7YXDKFDMl5rfKHzgsYeBARMFtsawBtytewamc7EC5nivursi/gYWsIzdPfudXV3esqubW5strmxxeWU94mVGFFAQAaOCCwqA2gGEmJeetZkSC6TlqMD9Uqz1nSjVcV3LGOpr4EVGriV1x04pynoxOJVKpVKpVCqVSqVSqfucsbwXBAH0ckdoxBy025RnJx+wAcJk+NWHcsAmOGoJmfeDQy24vgCd8WEcecdyWcJhNsfz3QWMD8d9587gTnAHDmeB4NG2uy2dsvP5kh/O+6W2yi/mAnD8wE+oUEx8ooCEywGc/b0iziMiA5xxiC1FQ6qjDdj+3zSyIzR1N2dAwYJnfFt+M1ClwWE8ShDEM2WXS1L1EVEXhC0B8rCVHOe2sq6f2Xq9YRM0AGtAaGRcgwUCLMCxvGCeqOeINhu+WQTg+8rxqMMSICPUpCvXncDuBGXWM18c4ybwPWPG8dEGbH1HNQB+8igu+BiqISPGZCidIDXEgnx4LGGMN8yaFdc58CEerOcJmDVFjTmOyXNtCeAUYxHwWxESNeJDbkkHaH4dauxGjcWm2xuwFfNt0myrAd3uZg7Hsi57XE8VT3A8uCSKdVETMzqX4eJkwUCvDdY/wNW56OTcXCDLG05hzG3485t5OF8wt3bBxyvGAd+7JRy/yN0+s+Gwo/NdqR46alyTA3AsgXLsQggnpMBaOCN1P3hhgcAZBtiVYmZGXYcYyxZqLWYdM7CXHYo8na2vn9nm+jl3AmA+lyuHfFw02+t63d9wPMd9GQWq0qgOh8dqBcdBgNedysXl3kY+xIUVAGdjNHfvo0Eaiyw4ZhRcMDaRe835p+x4NNALl29xtB7rbgflMOOHnn0LFzdd07ovisM/xrQFngUTq4hxhLOYsN/jCxwwHifVrPY+jJcLh6vDQQfHyAPWU2oTtwKV45jCcc7HuuvZXx/3Y7iOowgQ41riWVpo7a+n89Tcx1qjworWAhXGtARqHdbch5u3wukWfsa5F4Rf4x3imkTaBeMfEKXhbS8JiePLc7Wxe2DY8T0iHka/q401q5s3GozKWV8ji2IHQb0m7XHV6+W7JCKrmTt19ra7ujKb71ARtH65s36xoiN5d3lDO12Qq4z7wZtscg1CJj7XEM959y6Acpu3QRieiz6JPOHZTjzMdI1jbvn9hccPXLNTqVQqlUqlUqlUKpW6j2D5CtEMALurDbcfCyBEvmTNw5RzTqCQ7spBEQ+KiFjYchlbn3d063bd0vpubvv9wS4vt/wwvd6srZvPbQcwUbZmH22Brc8Oi+F+k4sTUAgAY2/bW5d0VHfd0bpZb2frjS06NHcSoMMO/YJH6GpVAzrCAzhsvWFeOHTLNu/4UI4Q3IAvdIqpcVKAPcIWgG33JxKmh3WODf6qO7J82C/hnLSS0rEJqLxcrdjQjB67/Z4ZutqCPhI4wqG6pmP1aH235OPxfsjsBVxedEfmdLJZ3sxBD5tAxXkgHxeBvR4PcDi6axWFAzi3fZu3jz0hzHigAxaxI4xXwLhizOCuplt6Z+NhZ4cjwKk3lApX+RZxAygc4GdwJQKO9xwCNH7D+++2W0KO1XrJvGA0ANS8mtvg8y2a5gUUJmQkiEEhgim5yt8GTPfc62hs1UKYaIgWOdKNsdxdujo+NUyMbf9y4Qt8CuY5wuOYRWQFIDi+dxjnxcLQ005O4+rixLkqekTzDI5NgGW8Sr9c2YjM5CXiFjBvAd8EMMNRCmiL5nWY37hudDkv5nZ+7cJuPPQAc67f/naAYI0fMmdxL3LMUKQAkCYQVGYrM2mR3eqFFc6FTu5JvAe/92vew2iQRyCKMXQ3JCNk5gs2kYTjlGAZx4Xme6u+ZiIHwBeGFkku2ciYiHJ4j+x2FjnEctYCwDK/HHOT/9P50IHsbuGSJcvvKFoJTS7gQIb7HE1EsUbheFC0Adz0rFuMswobOKfj1InO8Gh3VSOPnLEYe19bonmdO2dN4+VTYwKWvQriTQD1mkKaWCc1l+eMWxE8j+JL5ATXMlXM5XCcRra0nPuCqvjuBaOAzF5tUfNB3QeRn82dJ2PTSC+iJ3xNlFM9QKsge0fnOn7nLUkR1+KFHEQZMQpjr/s44pCxNihvXPcB/xB5fjbAJ9agtmDBzPYyfO5W5vvXLHs8Jhpw4o3UvA/fBxtRzDqoQKMLAZe2CkbF8R1w2XfNoAleFAzj/dCAtL2UapJYHdsBg9mIdKZ4CjWwPdh+u7Wrm4+qh8BwsOVybfPD0VabMzu7do2geLE6s/lqrV03KKqywLiwcY4IHTVgRQVqNiKGKSpZKnAcm0xsnRsOCP/Hd6dE+QA7C+Yo7grsH/B7WKPvnDCdSqVSqVQqlUqlUqnUvQHLES0QoKP+ombfVodqdXPVKILqyKXbyrc7KwNX8EeNjxRDAUdhyUBuYxBqeOTko7CYq0cFDHgdueVqlIPn4HqQbwBkfHiPeIYSkTFpZVZdf0EzC9gpERHhgPUcUM9njWZRahKnHFO4AUuObXgrvSEfHWf+mjx3B/LYFg2QqGMJ4OLj2TRwK5EVxzttV48t835mzbHLrSlwidcrDuvigKzPXxRQ5s7W9rrwS7/jFnfGEmDc5SQmgA8IE7DKXyUyoCN7ONBj2brdOD3jX7FlvxxAITz6d/gGdSoxKJXpF47cfm8ZYPs7P9aIiSDE8feQUdELFZFjQMg0dbPTlY/IkmEgXMJcYK2BfeT0s8ViZ9vLSzsCQqOIQmAf9F8HHVCO14oO2+qQx3xB0QF5433XExz2yGEGGAZoLbndo4ogHGYvjAAEcx4KaDE3GSAq4PVyQ/BG8En3phdoMG/ZYG/OzORw1GNOctdByb91N2dxiXt8gbYMNPEmtZhUwF9bwKLzPoo8fk/qoggYl7nYxDS01764mb0xWsz1spZ4tjGu2cEdypiXjm/p8yecdkjrnUSZdRuO42YHwETRPM+jKjg/CQjluNeTHPI23ldBQgHi9sXqeIRija6OXUUVxWvF7aERLGs5/tud/uX/NNkQMffDKVxAeskeVmRHWXsaQ6/ulTYywsewvGaNrIjXn/yJaf5PyRpuXssXV7/tUCjBWhLnD/ir3TbY6VKvCN5E91RERnDdbZaRkyOo4N5d9e180ZpVzzt2vNSdLrrHMD9xFCrceZtZZqfvCYzpSC6LW6wpca21WLDhH/5e+VyYHm/dcRBrVRyjxvWkEOx3UmLlVCqVSqVSqVQqlUrdV7AMFyI/mHr8Q2zpJgxgPIWiLQiHHQwRrXgju2E/ENCs3VlJJ6jnzh4OgorMmsWH8D0a0Qk0sZEe3GcKSy4fsAnUAoE6tGMMxHGwy9mV3KtwsSLbFS5rOH7ncpjitZFlTKC86LV1PbZF+3vM3Y1Hhy2bljmgdEDCplOenlqzpXHuR5sdDmpEhq3N3kiKebqHPTOmtdUdrjm5AecOAwEf6OzGFmXGWJjN+6UtL9BcTRZtAJlhBITU6/B/OJ8ezmzwFYdc4k3uqtaxwRVMcIM8X9/KDmHL+jBsGVNS85QjtkAIiHEM48EWR3dTCo8IZIFJMwxbr8fma8uVRz/I/RgwjWC/FCN0XcdezlOeG8bOM6URJQFASnc05oBvUY+Oe+HTZBGCBj1mbTAeI4A+ZxHec/ScUic9g7vHowlWZH0r1iLgJK3cQuJszgiQJofqkfmuHg8SDvhwQXOMcG0VX0JH97Dg/H74HbcEhneKAIjQ5lu3btnlrSu7uoks1pu2Xq2tHw+Mc+l7FT7kJkVTSbMtXMarla3Wa1v0cKfLyb9cLe3i4syG7bldP7/ge6jRpu5DOsR3g233yHldsWEgoysQp9L3er2ut36Ja4h7tbOesBjRLOdl7nKcNPhOvnxusXOhygi6/ZtoAb+9Ro8UYGPG6Jrptku42uOexo8YAlIiMTTHAHkrDKNl2ecXxqDXeXpeMBuTxc1QywwC2bynaOP02hXiFPR4Zqi7FI2jCA42OQTsn8/sQBAq9yvGY/D5EVCbRRE6mAV6a0mEuSLuWva4Ak64aDrpADQAMCMLHCryZotoFu2CkMPYp6sDe818rAnu8IZ7uRRxPD+bufde2WhocGlw6dngei+PhsH5s4Dhbl6PRQo8iZgUHzUv2GGYe/0dCBdxeX81S9VOGGVY83VJ79WULoqRvrD49dDaU5g4C4qIysE6Kzitfqkz39Gy43PVCBKuaBwPFmXtfYBvl0555WqcFEXbmBxfV2J+erYx3hvrlgp2XnDxIquKPdhhsrTNZqOCBHYzdMgcx04DZO1f2bBFEQfwGbOhtyPuIzbtC/c8/t5hZ8eekR6YfxG5gpiXgPu63VSo8f0TFZ5Hw0L+sN4XjNRIspxKpVKpVCqVSqVSqfsJluW+jfzIanGqrlF3Qnmzs4h4mLrPwuUrKFV+B1gSmZkOGxVVECAw8kMFVoubLn4e0DIwDT7/AzQD9AK8RK6uf/CORkaE4tF9jw3v6pboxldYIjH44XzibHW8VSI3wz14FBhoHGHT7euVTxTHnjcNrPmy4bpV8zQAPQEcz3cmtK1bwac+xjrUkwZbpblZ6wBun1fd0KWh2eQR4fT0r+JwDKdpeJ0deoQD2h3n/J07wwmhfDAJTelYlDs0Mqrb8SlORT92gboA1pMjdMhZ8hY8SzheLo7QJ1/5FhO6ugUj8iTmbHNI9Rgah3lxzobr3eF2zDUVWZQwwMaH7ngmEDscmGMe0Q+A5Lurc7MRmauId4E7WSG13MbOGI2ICmnvU7mFEeOBiBTMlSgYxNTG+7OXmj+ug8N5uSSABmzWd8XOMHJjtS7FgshoFTT1Zo3FVdtepxjP1jUZt2p17PoLnEzeCvXifydG3RMnphzMiopx12ycbOvOjLld8oZb57IXJXzHhNazsJ9iTYn1SJEqjJkok8FdplidCFAd/HpG+fSY3bHNsagxLrph46x8TeJaGrsHVMQCZK6u5GZMPKpFpx3rUgDquluhjmFZEOo937r/T69JO/6l+WldR2LXQXnp4i5WnIqmQmR7+DqE/3kUkRzLdW2Idas2iZ28+dRF3PwdUt6yQ2W/hix24t8RHQGn/MSZXP9+teNS3jEytuuCMhnDiNup17MWccq5ISoGxT93+eOeBHAGSOZfNTqrVTSoUD/+bqmAFdntMTfj3te41uJde49FQUPjHRnzsbOlPedUKpVKpVKpVCqVSqXuI1heIfMWWlTHaAACOLZKkzs0+oMb83DQh2p3t+FDteIEvJ2Qg+pZj4Z+chefsS+R8jmRwxlxEIS/bPA32riXk5AGTQIuOf/QmKvnVn+zBR2eaDKHbf1w7zlcAHRzZyePA1m1AZoDHQDWuSOy4Au+wIyOaLgD0dCtOpbjo7+AMrJ8mZuLbc3I2p3PjDHPM7PVcmXLflXiHeAeYwQF4QB8c0fbX93kB/3Lm7cYiQBXZ7fwhk7O4OiiHpFPCueroNJmhUxiz0T2RofsbUZHnfJlo4/TuPfXInfC8fV0ugq8jfR/whXNMAu/ToKQuJZo1OdjyGxejNlexQE0k+NzELMgJ6cgnRzIAosF1zuvPVq3VJO8+eLCzs7hvo0YETStkhOQRQCMgmzwHIcSk+HzMFR3ywfwGu1QYgYqGqo1keqCl/tVoPBwUMxIgdHuPiwv4mnP8Xp6XzVki4Zf+L5cKRrl7Gxv5xdXtt/u7NHdw3SBa3jl6N9eXtnNHZp8bW29Wtn+5pWt10u78eB15k4jjxzuzw6waHGgUxngF19qbgeH99LW67WdnZ3btesPKHrFIV+/XhJw0XG8UDO9s/Mzuvo35+furFwphxXAj+5ROHRxzXEuuoc1YeXwrZCtHYTK5zDHcT/IMaoLwxgc/rdiN7j7wJuS4fEqzniDvn5Bc6+SzgEh3SXszQj5XmxCid0CaG4muFegWzj4mdNb4Skd11gLxrnNO7hdce9orsnd60UeNrLTezDanZEhciIbY0Mwlg6m6ebFDMI54PUxP9SEb4YMmYDpBfSFs74ZRx6cz2vP/hZM1neOCzN0lfVNbzh2VjBDGvnUGEa44fX8KMVpanrhpkHB2m7gbmgeD228/r4nUQtlpwheVZE+dJ3znN3p7Y36OBdXKmyUXHLueij7DJjXzvnGwppWYOYg+2vI1TxDspDuK/4dwIhi90lPpy2+jngATdcVCNNZzbXfC1ikvH4vMHcYueMqHB3HfZkfkbscBcIoainKwseDfQSwEyTibJSdrhxmzGf/m3JEHr2aWR7O0UR1o5rBQu53FH5wD7CA0+Nv5WizcYuFxw57B8jjzsZR6+scfyHw3l7MQeG01nAUR6Nijf6eRgNCzBklYet+VBNWn4XjYMN+m47lVCqVSqVSqVQqlUrd7ygMAYmxcOUAEw7l+GHcXVLMOT4wH1YZxvGhd5r72DrfFkc1IAMgAWQDCCbcCvenRxEoMEHboeVwA/xbgBkYIjRppHQIUJyV7vALwKYP3WGArU7psgXfn6EMzDivmqU5d6dw6+Jz75oaIg0Ahkfbq3dbyRMmSCu5sgDV4SgOHzbOHQ3E0PBqxxxd9pnypmfFGW21MR7eb9ljvADFsEX/xAEaW9mLyRtAW1miop/eFLGL2A9sq5cjWiZUIQllOcNhpxiDaGkoLisYH1vJNY6etVziK2KUmosf4AZxJ/O5LQGC+pq3rC3zcIGqHaK2moeTMTKl/XX9PXg9K9/Ub09Mj3EUypRt3Nzu+qzwGcWI+rwAxXH9ivG1dU43ju72ui+6g5osLtd2RG4vc3R9jnnEDObN7mpnj77jEduvdna+WtmwXdtq5bEky7nNlgsvXHizsQILdT2Uk4yM5aWtVmuCUbj3MV6r9YYQq1utbLFcMg5mA7AMoLxZu9sZ8FpgXO5y0FGvjET8AMco4idORjUGl3xOxynHqCJJSo45xoaAVk3syq6CWEs4ljWeJDKD9Z5xUWL9cbjGyJPIup1mM5eIh2pdd9e11hI17FMjQ5qXi+NVr4HpHTsAuCZ1gqd8S1OBgASUYBGYVc34VHCTWzzWmTps3gS15EuXmVmP0Y8jinIl67u5j4rjuaxR4dYGiMSlwvX0vHlftqqXutllUTdr1OPwXSc+or7Wx1qsoo1c6R4r46+FpyLmQjFG0eTOr7PHNUShRktWVCIiVshz+bkOuOuWsNafi6Ih5qDvTJCT3K+975TAayvbP3bAhOs3HMBqYBdNKEuGsRfYyig0mzSKO3+yc8SfU1zW9VKyCLFAwQfRNb2gMv8moDiJog0KJ57xr5wX/pFlY1XeF2pOq5Khx+7UI/C11XdeNE7/+Jru6QjHcjHiq1h7iAaQqVQqlUqlUqlUKpVK3SewPNDpBAjYZljEtlxlv3JLuG/prlua8cFbRFqgTcAqGo7JHCZXmRr6zQQ5YZyDO600tpLBtousULyGu7T0OOQpw8k2s27pTabgzvTmdn40/gEaH9jp8bK5O/DiIzk+ZB8Gz78lWDxavxTYOswHupYLCAGE8uPreE5w2G7tsB9sd3Vl+2HHczm/QG4tHKJy0gH8yfnqH+zH0Qa45saD7bZbupiZMYzt03SPqtETneGMsgDcQkboUHIzD4e9zdmcysEvKgDAEYTjyhPFz4hR6Fye2XFQA7Lx0NtxROa1ck0BJ8eFriXhGIFg5OZG8z+5yfGYgde42cbNYwyQAyiCZowa/2GAQ25wuKjXcVbt4Mwdgg7/YXYNQEIXcWRZo3hA96YiKXRoaha4IBzV8wso5nwJw2gUE9ytTPBe81/9FxNMEz+BoxLnv/AsZkEzd5Ty9RVFobkkoEpwczgS9j5w40Hbrq4QbM2xsLmA2AMPwdl8Zle3ruzs7BrnzebiGuHvoj+zWbeyfnNu67MNncars40t6VJX5jWco3T/rmRLPxvn9iAMrMjq9uLN6hx5zB0zn+VKRtwF7hm/d8LF67m5fsOXSAVlCbcW1pN8BELWKGIEoHP3Zok6cUjHAdVjMa/H/Z5QWMAc8809lqdxGU0sCdNWwnUa8Jjw7VQeAxJxCSWyxXE5HfJyAvPqwhKKrF4uW/HcAN64Z2Y2g8mZMBs/X9Clf5z3fPycGdtwy0aLOsHoSii91FUSOvwfhTk3jeBKrEWdgzpwX0sdPiNn/MAccd3DHNqZnLhRdpjFGhzjyLUXkyQAZgXMXtXxMSoXjJATmfgDdj3w8ug6xXmiMIWviSK6GPEtg/Ln4eTFfQLHcjyPhS26oj0uxhNGogiGPHnsUJl3a+v6tXYWoDEf2bxfNx+T8nfBQTDXEkJkrHdY3/H6KKRpzQpgi/s7IPjJiJe/Vz5FPfM8mqh65rTPb6VuzGxEMY6DqHUCR6Q1fWH9es37cLnaMOufBRT8LRjwN+CW5sEM12Ww/dWljdiegt0h/JslzMyrE1Ej+FvAaxI5+4p44jzk9a2FVs4yHwPkzSdYTqVSqVQqlUqlUqnUEwKWEYcgSBKfruUSbPMq3RTsH9TdHQsohegJ7ntuoHI4vchK1OhqAQfXeJrjybROWzj844dq36IMiNghAxYutQ6N+QTGACWLKy/ATMkOFsRCEzY1upLQbE4NjtwhNjM2T+MYwCUNJ3bYf+URFXAtvdyu7LA/sBnbw4+8g035xvE6G6NtLuDKljMSYCEc0NhKDacygPJ2u+N3QJISScAkkCNdbEqTFcggwnWoq2ZjoPMC5srrFHjVlmdAUMEGtT3E2QlU0yWH7eUT2Cq3I6FyYzEXZKoucgJjb9JXmqQBfDOaIIoPZVe/jQMiPK6Ki53XD9Eann+qxyp7VOBD78vsWge7hLs4nyYKI7KX4WLsZ2qEqK3uiprwZIVpxrJPjeJ6bbGyJm2bqKp/c6wqliwxDj6vuBXeQaogFK4bHnO0vutteWNpOzqRt7rm457nc2O+tM3Zdbu6vLLV6ozHuUQBgtmsAE9L69fntjo/J1gGYAaYmxWwLECs2BV3dOLLG6RhPJabiMJwt3dccIV8+8+iOWOw9ROwe2JsDOcr55wXJuKeDjivQoEXe0pebVC60Y6Ij0HETRBFRp+oKFRiIQrQFlgu5l/eQ3VHASEp368eoZ1AXP4vliHPN9cccF88wfKCZTJCOxZvHJzS/a21DE5wWwDmAbwrJ1cxGiqyAJTyHoi6lR/UpFxRbMLx+9a5HPne7vJtTOFx/Vg08uMszmLej+H6V6azzhdu7Ih2cbjoYDmKgnLtNvm7XhioYBnjjfVK8FjXqc4HgVmHsrFDxcedmeCH0dDLb0kXPwBzz/Mh9PX7kMk9zCGvSzdjlQjrEf2CPPC1Iiz8cOPvRNkB40XJuN+ZBKQFk4WeaPaq7OxwGdfdN7eBZd8pEby9LHWegx1xGdqBUNdJrAfuEfb3WVi3WLKQs1yfc41HdjmunxoxYtfKzraXDpbnci6Pu60dAZbVKdUjLvSHFn8bFHOhoifWetUN4j5W7JIc3rq/659ugOyTKJZUKpVKpVKpVCqVSqXuNViOHmYltxGf1AFZY0s5VLa5q8lYqOb84jMucofdRexNjtSwzPkJ38cpYLAnz7FUbq2cvszPZPYv4FvnnNsb9gWPKa7H+MAvoBYQkSyB29arZzm2a8e2aB5D/H6GTFvf6h7AiMZlZGa62xo5yvjs71AdoLvrl4R7jODwbFicD6G7u3PlmPTsV88vZdM3QHO4TAFgO8FSZFIfFgvb7Xd0vSIyAzAar61GbYGAT8IKGpjAcXdiXPpEBcgHHPPM0ogvKPm5BfYrQ7SCes4Sh4MHG+CuBrTYXfE7tnvjfPY45v1eL3WILGwBmAC7yjXWtcPvwq3cbtOPjd56bx1gNKaq17AeWQuI40U00v4TOqDdRR4QyLfOR7QGXdi0BDZsGWCOMSTxBgDJiDNpXdEog2jOHAblY3e98sERvQtY1i3N1kez8/Nzu3HjBseKYHmxsPVmwzmwuTijU5kwCnOCMQKe/+Lvj+chh9n6ma03+nE0RcPj5W50l28cM939DqkYIYP8lcrsy2vXsN1yruUOKY7h5gI5cC7RFzF3YjIWUjptmqmFROfFYyk3mpzlk8XB7xPcm8fTBpLt3G8Yrk6tacQZc6RpJMkYkGjO6Gsanh+wshYm5OJl4YZQlWRP93ATo1BMv6d26qhRRa2uTNQaQVEvQDi2PccA9wzfSwCe8UB8P898LpdL13ssuwzC7SrHchTR5PZuiitlgpQ7bOKw1toc7tgKw/k6zERGFAaigbBGHWzY7QmkEQkTdFYZ7E00CXdcxN+SY7kvBWA3tsBa2sPpu/TCGv4uaO0r958X7OKacW5EA7+CZJsvd5Bjd0cpYvo9Vf6e+TmW9RPgnpE0UdDwe8qvTwR4RMGHuf/9gbE/ai7r4HeGrOdIcUEsBcZqb8PA8Hx3H2PHA9aZ5u8S1kfGU/nf0JIp7jtqmFldnfL1tnNntudY4+8G/3ZkFEYqlUqlUqlUKpVKpe43WOZ2ZAchAGTDdk9g2LGBk4MBz48UXI6YDAfLgKiwEaNhHLfje9ZxGFuRPcuGethaPpej7DjabjfYbnfF55ydbwTHHNQC5AKkHeH+3O/4OnA/MhKD2ZUBtKpbFRIEkHuNAMkNg3DwHd3JG42/IsOSzaIABpwUAYoowxMQRc2i1mdz5dtu1rbanOl55Ae+ZXk2I1S2fXW4ikV0dpiNzN1kfIXDK5zzai0HLqHXcbRV3/M8t7du0eEMaL3bIXaj53hE00RB4mkmK67HsBdIiMZ0boITxAF8BJIAx5grD9T3cjvtcdhH96OOswBIApaDjQAjV1c2DHt79B3v4DZ3uGwR7wDwT5egP4fbwqMpnkNswHvAIJxDB/CiTe9UFA3oQtYZebEAL4jzkjuPzkk9w3NsNbfUlK06G+W61rVBkQIS9IfbtAHLNC1WGApnLo+DbmVv9ojXY1EB88ebffF9cS5w0O9ZDMDx9SvAZIzjUufcLdmkEe7F9eacc2nZLf245Cae9Z1g3QQMBwHWUc4ZFQAnKEg1yHLMY/hoPeoAznfEAbAGgHmwMLQYI0zl2Dvc8wJMgDOORbFsBj5rdgGUuIkAoF4c4ZyZT2J0qlOeN52ydPlcdykTyKnRoMokAdXkEg0grnkMoOrXnTsr5GafllViukQRwjMWIkfd51GN7VChQwUw7pXgfGAmcCmIaY5wPAkz94p2OS799SJTuulZeXo4MUebYlfAXPd7O3H2J7AuoOIPiaQCugkakeWteB40Do1NJSgCYnGVwzjqInGh2CzPYXDAbPWoC/Qe11PHVYs23sBuUAEnnM5hMsf15jiNo+13OzZ03V5uCZeXyAQ6E4SXY9hd6J75rtgkebGZ+73oGRWxubhhXb+yrpODX+vbwPtV93WNdOAai7kdqR8OZFugjPW8xJI0RYZovKlIDd17mvYOYwdvRuoO45IPzr8p7m7m30QA5TWbXvbrvQ1cVL1hK3dcoDja2chrhDotmvbtbb/f2n57q4BlFinGPZtTxtgz+py7FeS81iX1Apc3l+XfwIDP8WfW407UJFdNQ3dotptgOZVKpVKpVCqVSqVS9xMs48M0P2CDvaHZmvvY1MhJj4lt7toW7FvjCZSrXZBQyJuV0TEbMMozOPXht4KV6lBUUmgBaN78awGohG3EEW7pzsXb7a2xZb9xVjfRDzymYqCMTnfF3qifx8G228RPslPF+AJUAi4wZcDHJBr3tRCpkUN4fEVGaI0sqPELeJ3FYiR0JqCHe9tjFwjwCboFAUtjugIgAw45j5oJlAyLyI4GaFWusnKH67bwGk0iR3mJlvCfE6+4m7s4RN09CGd1QL8YVx0pHI0ef+CXG+egQoS7zhsYVpzK1W7bOCt1XMWxWKIXFI1QQbM7ugkMAcP0WMIkhy8MWyjN4k7gMsz6ESHA6AfkUgvmCJQ6sGemqWOrgG+AjQs42wEBveDAuAq4kztmKjNXGbC37x3K+hyIOR4u44jvmNiKK06t5ldlDwOGw58PeHXYb6MOQFiNeAHB22ZeRg5v9bfWYb8DiCrN6Yo99+T+KB56b0rXHnpxMzvIhbs/4LAXSMpYtu7LEwd9G50zHZvmuMJ9XxzvLfNtc5A9yoARA/V6CmhWp61PjAKE3d7dzJtmvk66St42hPVnp+fQOrzjGDmpsWbKoaw8+diZ0S5JJ+7tOiH9+Q6LCVvVJK/Ad0+Kjoz0WmOIZnXNehnFAXe3l8ewwFjnTNvoTnOzTKoy5spN11rH5peYo2h2x2KDZwVjPjTrD/9dF3EB9ZLl37x3qWQ2rvDGNR/rRI1tiXUzMr3DFd3Or3ZONi50NuhDdMfclmPHotvBdmWtiXiN22+mdtbG+9fxnx525Ib71/xg81EFmaNX4tRsVXEvUQya/A1LpVKpVCqVSqVSqVTqfoLl7SVyceWeJMgctW0ZAGA4CAb0vbu7OjV9o+PY4wQihxkN4wQrPJfWQQRA4rCVs9L7zsk5iY3Q8956ZDsjN3dEJu/M+m6lY+nh6lQeamwVPtggcMVwYjk8pxAJzxf0Lb7LABINLKgf/t3pya3KyssFVFBUhccLEGxw0zKdoYYIixWyM2cEhfjsLnegmu9JAUFqIyg4yODum68XytCFQ7VAWz0WTdsQr8GsYbrScB0w3ke7vLUllF4bmsh57EFhU74ZHK5ixnF45vJBrnA4ipcrAU5s4cY4c9v3wWzRMwdATsR5hRaKq+js2LlTGQ3pAJALTxRERFM6bIdfrpb8ousdWbBwRV/tCcoX/YIOeGwHH3YHNppbrpW1S7zVwkDmyOJ6+3Z3HIc3a0QUxWF/tA5N/PCaAI4LnLfgoC6uBjRgJOYQzjvgIyEw403VfLB4/ui+P9ijjzxsw27HnFfEYSBDG+dVABfuDsRVMBJA44b50S31bziK6Sxnc68ApnCoa3yiMCGIKeAL4D0OU1crjqeUTAg+A14q8xfnMRy2HCs0k4Rbcnd5i18aR7xnb4sbANlw/y/dMQxLZBP30NxDAUxDldUK+FVO5cfC+7lYWfXaLBK5y3W2oGObrlXer4CmuhYs0vjYlBxtde3zposA9RVAC+irsFAOgYUFdYnkvCk3XQs4w3mr6kYUGnAN0BhPERf1OW20Rzm1Ns+HMF5fk6pVjFhAvYr/b1cBry209lgfb3TKMXD3ulzKALECp9HoUznLBUWWfHzcU9yhwXOXS5/jHY1Fue7phbFLI+JetP5p7DV24WL3alGJo1FzuNLsrilP0FE8AtqjC6J2tzAcgrso4HbG3xnkhGMdXNKxrEihpc0Q9YL5cfB5VBqBqqgT1z+iS1QUiAKZMqJVeNJcjbxlNUL19QUvG+HUPic053RO5T1QBMR7THbD6G8YG/X1Z9Ytz6xbzez8OLPd5aXdeuRRzpeIbuKYR+ET62uHexHzPwo6cBfrOMv9VpreojCIol2NyVj6Jea95Nej3k+4vxShonvPbNn36VhOpVKpVCqVSqVSqdT9BctyfzUghio+zsZpqMxgxjnomeWxFaHEh/X6Fdup+T0im5tmTJGPHJbFthFYzTwNJ23jLJy4scJlWJs9FTDSHE49t/r4lj23WNjbiNU4iPKb2ryJMQ9kfYIb0SiqsfWWL0BLNP87gg40Ds9yrGyaB2SBrOWOQDPc2wI6yv6k6xcdEBsHXuvSrEcakPxgh1EgGQ7eiEAIJ1xzQWqmdsCscpwVOBH1lccKXgOaYyzGzhuo+ctqK7c7MKPJn+O0Drmk+If253sRwBsvHgKOuSM8mrgR+Ai0RGJGPeZwvzoAxTxFj6yYX5MxjxiEyFuuV18QCtCdxIdjHS58+i4B5XFejATwyAWPgwBM7pZwI6NI4tEjPh4qVHgkQzjgG4dpzQfXhRU4bDe7N1eWDeQUD8EvAKiDGkUinkAPc+czzmOhpnOTyI/bZvV0/uiWbFzSxXkc90DdMVAa8U081eEedhcqG3gGCG1XkHAuN9sOijM6nMuRZzM5xPIe4T2XW7o5j+YW1I4IX608EqE4a2PtC2jdnG/NUW4XtrhHYpGL465ZxbXIdOpWfgzg3J57rd404N4BP8GoCg8nYc0nL+tu9tjCQOhYi3D87vNPBZzqPi5Zy/GgcLuXJol+LzbOWK1TU9dyGTNfS6J4IHeydqSwQHPivG6s9bXwUJzzrbu6/e+TAkmZk7EcT/3rdcdHUxto3dixJk+ul+/k4VqN4wcIn7NQBwjMxq0RMRO5zJxndU5FUVMFmHrZioM63q39exwFKC+AoqB2RDYPKbZf21JI9fOf3LOpVCqVSqVSqVQqlUrdJ7BstucH2nHY2kBgifxJNZebL5dyLMOZi4+r0dSIWY7Hkv1atm+zN5dnuvq23hENmPawiMJpGqzEt8SDFDEiVA4rupIHbzpFqAgnJ8AZmuUtbL2WUxhbqOkEQ1wBHWp+CA7++IrhRCO4AOipkGAKJALYyEFKVzL7F8KhO7OeY4H8VbZkottPzQrlfMP7LZCN7JQC0AcZxGhmFyAUQPnRdzxit27CTTpTFAJjCnSZCDLl9yMX6OFs9cgMPAZgc+/5yVeXO7NLbLkWRIUDbrle0wk4WwseofeiTHCePQo4OsN7AEQOdEZ2Ja85AIQDLHIoQd3oJzUMo+32g3VwacNxB5MonNX43f6WXV7ubD+Mtt3u6QTtV2te/5FOdQDtvTjKsGcGMObT+bUrjsPm2gWzg4ftle237tDDmLnLGwcIFzrGa322YaYzrsv86M3tZpibkR+MRloLO/aeVezZzGxaGOgRj4E70ucb5hDBzWHgeFxcuy5Iu9vRqc3zQbxFQFGPRIHgmoYbHA7mbqV7Be5LzkHA8NIQDHMUmavRTLG6S1l04DzWuB8O+5rAwKaNANVwt4us8RoekZ062H5/RRc4YDLum6vLm3Z581FPa5jZYbWyi4vrtkDEDRykCEovBRzP2Qak8ixXQm93v0/yf73oUSItSGa9KODN5fTzCssqHPbHhGvccb4gbkNCeWsGQPTHxK8I7+VILkUGr1QUXBzkM1yb/O5uZB1suWe0fkW6N8ZkGu+ghqLIB6+FLn9T72Pp4BCv57EM/JkWHmep1UnLhnsFjELeQA6PASBswWqMgQLcVYphNjHWQBRZ5JpmcSkyggsgjyzqiCOSy1qQW/e0bl6Bat7jmqTujo7CU7ODg+uxojQ4re1o+9iFMegxK6w/WBtWveKUsG4fUOCoBUBdeUBYvH28vwoktt/Z7HBkNrny+T0nvq17MVMYDmQ1D6Vv2ONMsJsiQDCOQ8WGpgDpBY469nheQHCtldH0jmuCO5cjmLoUP9kjAG7r3ubd2ubLtReVZrYYBpth7UEBcdhpvRmxoyUqSHi9hfWbM12nAc1PBzqNo2ISUxjnyHsfay2cy3sVu3AttsNBaxKfELtqMG3wPWC/Z0zXOmEqlUqlUqlUKpVKpVL3ByyrfRoa1h1sPHj+r+ddEqh63jE+qA8zQeji8vVtwQGW2yZZxSNanHVqAKV0CXfnDgE2cSR6wWgA52S55niy4Z1cvAFU4jHVPNm436IPWMPHtJ1ev5xkiBaXYDSMq1m+rUsxdv/X5/l2/HChOlgGGOB283DgjqPttju7uryy8/1eIC+2NPsW6/Iu3N28sCWbrWHsO4+32NMBDGhB1xqvj64X3dUcdz827/+lXdFqRhfJxx7owSkyMX47YC6+zBLV6k0a6cCbN/MCEEhbstXoS1AYyRpMP/BGkI7POLaAtYf9lS3pcJ7bcjUQFuM1UUCIrd/awn6w3RZN0xQLQDdwr0zWcY65qriQmHfhjFVEqxq+Achpi3uNToDQfE9QSmMB+AxAiOctlyue92I2t8HPkZnajGCpII+5tZ4Ri3mJ5nzMyGaRgR39fLs6igWYp+Fc9MaQxa1c3fU8HkAld5CytZ27VZWPq/sUTQRRIACkx3MOyHcdBja5hGs5LqDO0TOw+TyPcHB3va43IGSYg8PJ7VEJzSpRMrgDIJPPnfw8zN8FxOp3nDeNq5nn7vdLDQFovLxeDNJj/fgcHsc93Do8T19BUNrduQHSY9eDu4+rN7TJuWVCykGHXaKvG3d2gX/uyG0dy00OcfyDj6HTPTJ/250gTfWgjnID2Ju8Zf4nCjQeo4CsXRRLfDyieBFAsXqrm395oUnXV+/WwnTOVzyWocsRB+HufW9eh8NBzM2e4LVG+BB0Ih4GET2+lmPOaT2O667jiHW6rPVz7HpQo9X5AfPSYbw4aR2bxnlbs7G90WBbGJk452//m1CKGz70zI1vc5abjOh4ToXUKlCwkALAvOgJlUtzVEwYv4ejkBXjHQUKRH+o8SIWa//bGUfdNBNkv9zYrXKYl2uhv48zW/h6VYpAUcn1ucc5nmblVCqVSqVSqVQqlUrdf7AMdxQ+5MLdqQ+yxQHsn3vZ7T6wAAEfwJzcWMzlpWPKP9hG7AFtYHD5LujmxLPnc0FOxRM4GAWkhOMLYAKN0pA/ywxUfGiHA+zIbGeASHygl+NUVixgQeVi+vsVyOheXbovwxkpaEzIMjZAgm48h8MOaQtsiS36vhUcj3NDoVzN+4HP6wAgO2QSKy6kG+VkBSPcAwOOo11tt3br8tIeABQG2PRzZ4YwM2LhTBNsATQC8IBTFY7lxRHjs/ZIDbhVazYvYjMwNgJKDk49Bdbg6iWIAqAGhPb3LdEUgicC4+H0q+A/4k8IdecdXyOKDHRJA6quVzbf7wiWkOm5XK3t4voDOje4VGcz2+33NgAeYcv4frBhONqjj9yy5ba39fk1z9te2Go5pwtyvwOYwTXRNWPu9gINDQFvl7bAF5rSAbQB7hAoC+iW4y4RE0dbYA65o5P/7U0XK3xE/IZngzt6olOe1wrX1h25fu2ZI2t4HVw7zWPlQDc2QedpmP9omCgTcTgplXVLjMTHMPjDZqNctry+sJ3z14BYuEZybdNVf6hwLJrytfCqOG8PB9vv93weAPNibLfPV9hG+EdDaAOOUWRpqx2tWzlOsEQ1RDZ3jSLQw9yZ21Ywpqkejj8jdENrT3lIaWIYIDAiWATnGYlSELGiH9S0LvKL43j1HDnbfSWDdZbj5+BfE0FrQHDfcL9y3dHaUw6NVZvIiW7OKUAs1xkVPxYFuE5OfFq0KgDWIz3a4g7/O5o94l8AjTV2pgDGON/m5RUtgzXAz6O9HpXkF9CsIpFnEpeo/BbjezNFGmZ1rzETP5qShoOW16F2u4vrxL8lMzmMDzz4Odc73Muc63w8CpjVNazzj6KIiiTKJda6GcfOecK/PVrX5UKGK9qvb0Rp8G+RjpcOY8QLsfiiEy5xFb5G8k+Bx+wADPOLDVtp+efTsGbdfPRRFs+ubj3Cn6/OzrmDAQWxeY9CntNyn8+ckdh945nRmJIcNTrNy5RsHOkO+4fBhtlOcRzsfzCvnufm2nK+pFKpVCqVSqVSqVQqdX+jMNDEDB9LD5OnyV0cMEAf8vU5Hx9mO8FlQkff1oxtuA6X5fzDduiRMLZfqemTwBA+TgtOAhwv0BCQEQJ9jbngf3dyinKXtjIzCS3wARzQkk69QYft7kA8rjTmA6prKHDrlAueg+MjmOM2a89OxlZuMopKwOggc9cao149FxmRFwQC5Obu7l7MbWTMhQAhmzCNiInY0rEMN3NpIObO8JqRfLQZ3c4CLHQFz+TSZsM/noXiJbD9G+5UxVk04A6gB07oyId1oB5uTwCg1v1Xs2Qr1NMYqlEjHLN4vQDYAipwG68cLC9tsVsyomK1XNp6fWbniF/oOlus1TDu1q1L2yJaYo9zv6S7dnt1afvVYNcRn7Ho6czuCPv2doDLkznFAn1q6Ij3B1gG2EHshJo70jXIqAV9VzVEADF2ocd1ZuwFQXsUTgTr9BiH7GXrvgaD5xtZ2k7sEFeBa+tTv3Hc17lYODSeGm7YEuaKMWxY41zNvEqeNiNGVMxZoBiDCAQ6rQEU52zOqLeqdvw4Dp7PWKHaAVDfwTLnDuMhfL67M1bb9Rt4HLTU7f4l57w4+9t4AYFl8clmGz6dxnWASo5sOLUbqFoyaH29KcbL+H0TbaOYHbg4BScDSgsO+3WetCQ8dQmrEWKMfckBLwehwleBcxOwXNcTFoDazOl69QsABLR2wlcbJrYXrGDbCpRLZEPZWeH+amX9CEICLCNuwrNV1NdR9+wEKpcdB5hf4eavR1t2cvi6oRgGFV+UG+/ub18L6x4NvQKai6oRIu4RRULUeoUX58pztO4Q4A4ot6HghvGZ23wUWGacDy/RwOgTvkoh91Ec8iiQOL5iLvZdLDx0b8hKqFzvxzgHjrI3/xODVjPMON6agxxzD7EW8fdODVBZ1PJ1Esez3w928+Yt22+v7PLht9cIljV6S0ZvgOl8FFjW2sLCHy4xCkdsvBhjV5splmgl/g1AwWjkuoc4DdUBVOQSwJ66vFOpVCqVSqVSqVQqlbovYPnq8hY/iG4A65AD65mp/GDuLlOAKKKBcG5CZDyCjoqQCOohCAKXHJ1g+I5MWM8/hQSWAaPwAVofhtH0jGApwKe7xZR1qg/KsSu8NI0iePOIAQeFzCQO4zQdvg4G+MEe1MFdjwSDyujktuxgD/65P4A54SOeC7g3RzwCHMVw3SGaIxqyBUDzQY0caRwbadCcDt/12d6W6xUzhgGK6ep1F3Rp8ObZzcJ4Nb+2ojKBFToXMS4FJNWt22o0Vw6lOahwzHqEhjfQwjWM89VD3UkLcGydDXtFXwgM+dgTKHV2fu1CGcNoONh1tlyuea6A/ws41bHVHWMN1yaA0h6RHsgFRsayIA1B6wL/nttinFmHSJaFGg9ifLvlhuPVr84IsJHvvOiRt72wGZztnuFbmsQ5GCqgnJDPG9m5g52ebo8TIeAZ3UU8iTyI66jfxXMxR+WWDQ4bkQglf6S+eQtKOQ8909thGd3RPhcxTzpCUweSKGjIXu2AS0QZ9x2gbQGHgIEHOKwdaPLeVT45ChBzjDm/dnTWA+BFTIVO04+F/43XjrkWObcOWd113TqYG790Bcucnw6kG7Yl8Ofj0zqXm4Z+etzkl1OgHfRQN4oPe92VMPmKbRfuLq541LN44zr5FQzwz3UnFpFyAePhEYOBOeUN8byxY6xZcuvWBnD8v7H8lcxdOYPlsPZGiwXZNpErxWnvjvLSDLVpNDgZ4erUr8NYxzfmYfmP0rhQeck11sNXHIe3WtYj1zoaINZmlFqy22Zx7fup0KcBUhEzChHaW4EvxGEg1x7xNzV/m2sooapGRdetFjk4kt50sRQPTxqPluna3Nht+YGFBj+fSfNH9WVtssXdWU13tRc4+BZa6wHaD2zkp3twv91pB8tyYf2I3RsA5cp1VzNURfxEg8oap6T/jrU2gLpg+YHNAssVb+JWSvG3TvMTO30qlUqlUqlUKpVKpVL3GCy/4+1vUzYsXKHM8PSt+oejzZiBDNcsGpLNrV+tmCd7ALA6HAh1ARTVhEqABfCZMCAaPO0H27GRVoXO+OCMD+Pz7sAvbBUHcMVn4AO3pusDP/JkETPRAUoW1iV4pyhhb6xUnLTYkr0UM0O2LNxfe2T9OjxwmMjdyIQIypqFAqQEHIloA2ZoEsiqOdxsj5gGOJXh2oWbOLZptznJej7dchjLBZrVXbd5h+/XbL1ZayyRxzszG7i9G7yyNzt2HklSgYqwjAcFu2OZsIQOX89HLU3DHP4RqDVb9Qv4gTNOAB7XDteTbmjC+4A5woXdEo0DAUqQ3YtjU3YqhIZ7OMpnPPMZlReNcOchb/hcoGUDV9/MulVPwLnqOlt1PZ3WV5eXHiGx8mZYSznUZ4AtAC/KaMbrrDZnfL/1+YUaAxLK18gUOc0RGeEQuLVmts5XAFvOy71iClDEYLEE13rl5K/YgYMsO4AVfCOIAvzyHGJNSd++zsZfAGS+Lb84wgVw5eutTckil3fBxoqjzcejLZYq6CwWaO4VW/8FKiOCA3Nd8BjRBZ5NzS8AZi/mDNiJMLPt1RXn+Xp7ixEgi3C3soAkQDaGG7lxOxYqOW8gXURtTKItQMDc0V34JXKim4DcGM8WvPt9EjCwXqom39ZFR7I3lixz1I8hHKwF+jdQOqCkMc9Wx0PnM5m6xy7QHXp7vq4c/3LQInZY0NhdtB4bFI0flRntbvP4X8Buj62JYpNOFNswHC6rg552RWjaeOHLR40QEseo3G45cvW+JcYC6y7ifTkM7hQukLi6zQNz8z09m5nnGg0sAZdRnPCoEd7vkRfdJFPHfIicdxbsPNOXbu/mKmi3QMQ5HMiE54wcMRbrFFsz6GfjVifhBSLNPw+Md3DKNcOdxgfmmEfTPW4OaAz3dYcJRzrSLbz4qUaFGB7tYFg080X5+JHVrTkz5/2M8/UYnoizYU6yMSoIRS9EYey3yjzvulu8L/t+brbq/G+kGiyyGHQcbdmjaNeXiA/+XWQkCK4N/obp+kTxiF++I8NH+KTo0JRlpmb6VCqVSqVSqVQqlUql7j1YvrwJwLewi2tbNi4TcHGHnwMIZMROGk1hGzFcWvzgHnY1PXb6Sba6Y8VeHPgRLMsdhi3PcoU68AkYGs66cHGVrcThCgznXo25iMaBeutw5wqAKJ+z/q518AZQutPWYYLr9rxKAHU9HkWGVLdeNLibAc6H49kzPWtuqwNHvpQ312KTLAAmAI6wTofzsDoAW+4WDsTWiSgnpAO5AoNaJ2jTgLFxOk+uWzhZI3s5nJ/IASlwSq5qZTYDwsdWcXdrOxCKHGs2uoMzHa6+MRoQChCrGWDn17rnWzGieY4olTXB8gIN8volXbzxvHD0gpqW5nMeKVFOPHJ8CawwVrgmDrC86Z+abAWEa8bpJGYk5gmd6PWqFDCK141x07xswHRxxYfzNrJfHTj7PALEmvcYA9Azh6fBVRsncGytJx8lXPOIgBITAMf+oX5hzH3bv56KY43zbe+N5r9P4dSk42Mdpuk/GudqMceWwN6pkzvu6ylLviNcFlycPrZ9auxjKGMcLtk2UzjGr1zeaDo6V4Fg8jrT9UxLkTf1JDhmt8ziRC/nzXkmyFtB/em9duLa9mIZnLvVLezvU9zHfge3jmV/jO55B7yTEanXosaL1GiF+PdtA1sudzT+a32+J4+La9g0hIxtGBFRpGWgzjFG+JSr5ZCXDUCxHuJvUOxMieOfVC58TNr51k7Utins6QmdfK8Tp45Nc45s7DmpzsX93Dry1Wug73vbM7JHhR9cFzbkpBsc6w1ObpxcrzuOvdcLw6VcHOSNkzz+hpTvMZpeVItCUCqVSqVSqVQqlUqlUvcVLL/5/3uzYgbgHLWjnZ2P1vdomnewAU5cNJDrmepZtnizIRWdzAKo4pS1qdY4VIBCp/MSbmRlKpc816PZbrfjpmi6ZdmoDc4v/zDMIGNBE0QndIu5reD6Opptd4fiysQH7w6xFEuBOWz1J2IYADYV5wBwK5dYEC1BREQt1A/tTe5uwGhGfei7MmujuVM44QSi4P4tQBuN9ugIhqMN+cE7myM+gxENctMKDDh85nvpcgH+1Xxchzo64LJNethHrINnggJeROhuuDJ5KIJqbEbojk/iEIL7mQ0DIkoG6w6jLT0HRFDfW4eFsw9cA9cHzkQATzojkRktR+3yABus3Oy8zgTLuk67q62yiMF9kbWNBlYX5zzPng71GZtbYRwWq411KGwYrpdiLTiu+M787bnN8PuuRqZM9rgXh7Ln/sYEbBURGeEijYxgB+31OS35LESxAUCB3asrOZz4iknADx1483sASAG0ajYMV6tfNN+OD1sx3NtqKrhTruq4sxFRAaWOUwspzFKmc1nxM8jJxXyHQXu/3/Ja4ns3IDpkYQu64t2xSxAl13eogFcfDwJBFmcCobeFlnKjV7hF/2fNnxX5jjG8/R1uG+vmvet7xJi1MK6CYjWK85z0OFZ3DdOBSrdojViIXF7lmOv+OM6RRa38ZN0q4XoNuDi6G5uTyWZwmLqL3DhnUTzBWGJ85ThV4cjnZrneTZHBKeLxuPcdAToPuMhxjUpcQ1M8YoQQ87J1LlxzuXZo50TMkShEWJsnXc6jKcyVSA6s9777wseMdwynrxzybAqJddWLhbE+CRZ7ERHH480U0YwzXMtcO7EeH0fc5Yy00Sq64+/2iMoZsCasbGGr5trXgo94bOTn628Go5JOzrVOJv0QcyDW9VjjtRZEYU7HW1zzUcRogC2jd7xwVQqhfu8vV0u7fuM6c+Jvvu3MdlutfcNuywiaw35bjNhHONDR2dXHeGTBtQX2Wkf2yKVH5j7z0TWP0QhWeetwOmvnBl50PvZoJ6q1lH9P455NuJxKpVKpVCqVSqVSqfuasby1rhtst72y/XZrw2qlxljciqwP0NimXHJDWxdc24F+YjJsmoq5k5XPia32EUMbbt6GVRV3WOPGZG4sX6JsTi7vc5qJ2fIuva5/wHao1xzh5N/hjrvthSa5oc02/DjNxqlX39zBNdzY7niOr8mYhfM5mpYB8AKwl9xMjkB9zxib4tRrXIB6kH+/3XnYwrjJlv/Ijg1zcXPurY+uQConVqU5FgCnXwflrUZ2rztncU6ME/HMZuRaA0Z7dEBxLAMYo0EgYjHwBaDs+cmKyEDVoTdDbEdEfbRXhYflUQ0OnTwA9eSalok5jXOoyNFBYXPSZY6d5taGa9HnTsQOcJwijqNxLJ9OLV2MJkQlwKyPGQEdzveguApAx3DZerZ3O+/iWkb0AECUmmgeGuekHMv1vMJtG/nP7dya3hORr1xb353eMSd20Nal3D6jZBs3Zt7m8pRpHsN6+3Run1WuId2lTbquxt0dxu0uhfaWUXKvonK8GVsc9u2gsvkqA62dEJO9AZG/4DEJ5X4sgcP1XmQxovzMYzHinr5t/ravE8Wk9rnTOJBwgpefTM4r1k4vssSujeY+YKTLfLqro1l5KwZtAGtx6pdaVy3I6XCnYD3yyfl/sV5gb8kIUKp7Pgpqk1uvNX636/Hk3Or9eZvDulz7O8+pNh87Hqe4lpMdDO1rIZYHWfMdGtt2BML7vbKUCYcZkePNNaMJYfxdrTx5cknrPevraLsYl2OJBpr4m6O/r2VnEItadzzJVCqVSqVSqVQqlUql7hFYvnXF3Mq3/vZv2/bypj3wtKfZYdgz87ZbnhHuwTkqw1g0P4uM1umHbH74Jtxw93AA03CNnQAkfP5FvvJ8URv0MbJCL0awcogsWTvYfufN/AAi+UF6YSNdaPoQzkOES5DAUvmXI7NQK6LFh3qcn7aCC7IFnAYEYI4nYxmabfJ8Dx20tnbrHPCeBBfjQu4+j4Bw8urOSf0PcBSO3eWyp7sN+coA7QIp3nwKma4EacrbpEuxOCwjU9VdaISr7jD0KApkj+JXwJB8XmSt6gX4PmzmtlhYjwzpaH4VUGfwpnHh4gOU9EZwGsPIUFVcB6yLl3QuH+x8trBuvfGUkFlxI+P9cX0xhjjn+XKpfFQ3T4Y7GU5kxFkg53u+6HWO7uIVvPfsb3rnmfDqsKluWY95c8oDlY3bmB8LmI6mjbVgcTuHmULYqU5AKsNa498Ow/2YC/zj7wLdeVZucYzrtXTt/JgxEbFDAN/YYBBjj3m7sH45MP98fznYuHcHviH/2XNnbbTt7soO496urtbKC0eeOh3gyNV2d/H8JEqmaQKmYX1soCYYifcNgF6rRMo3bhrE+TxsR+4Ustc39liFphhSr0G9FhEXIDApcMfmoGUt0U4JNrQMuurjM0XbAf8VX8E7o83l5dzA/W7uVI6Om95YDySVa5g7WqNYwfVFbmTvJHpb5I7uQq2tBIm875G7605yj4YoALv0LQxQLXipNcJzn/1xXAm8QWA4cHWZHLl6hAqzex1w67r5sUVUtr90HHGJgcCU77W2+URyKIrr4texGNpndPQuOszn6gTnfcBroiiXgxcQI+NeA+m53Z6p3d6OddNCLbDU+dE4vRFHj10u8yN3lETjQRwHXd8Eue5w59rku3I8v5/X1ddHHCty/PV3UGuMCmQLW23WPOnd7pKu48ubj3Lw+hVy4rEmqrEm12Q/ZjwP6zKuI3cdDAfbXe2Z0XzYw7WMtR6FNzi5exsPagyruY51ofMmqIoJutNqlUqlUqlUKpVKpVKp1D0Hy2iuNz/M7Naj+PA7EHpuztbW9Qdv7hVQ0yFngB9o9lh72SMTuckSdnA7cX/61v8CteJ1CmBzUIAP+wCmB8/l5Qd9B6kBcJhP7CDKYYFYjL9quC0BFwg9aiZphVcOsBrHb9ky3eaxuhUQMEIRyK0bNR7vrrGTRlftV2R1Ft5YcqHVVLC4DJsxCxgc27PjeeCkHRvRYaDVuFBAtgFq3iRsRLbucTomeFyFbfF2nsMajsY4PsYAaOwxf4Zhb6sB2/P9aR5PwGaQHqnB+JTlivCZMJ0WdKAQNYw84gR4bQWW41gF2QTJ+P4ed1Bdna2JuHGN+zzUnK3H1Viv62nG2HrDPad29TUby2zNr21GKpyghHZBI+dqOEfYVnPLJ7ZDdpGs1/GOYkFjbnNEVzCaRM9hUWYBSAVAB2f3UMF7zBk0zxy0jR7XCQ3FIiZDc1LQ3DnjZFhU1GhY8p3Ass8bweWTHPb4fZNaEY7Y8tSTnQDvlOplb1y0NbZArvVobhbk0TO/Jw5Rn1cni1nNco91rnEHl39HxrI3AB193WFRKsbKCzDFsav7SIW3WBvqGKmcIHqrIo7AobekK+U5ubKbIY6IlcaxXO6XBrAW+3cz/csdwmiJ6iSuj7/DtSzPr0UfNeFU/E1xNZeKlR99U7hglBDjhKKaEjEUXugbBVj52JJnXOM6po0h28aSpxPlxDnv2fC8Ah6XEscWYxmFxlIALWuy77YpuwNUdEOjU2Vhe3EIY9EvGDujjPOBcFnzYmmrtX4e6yojOsJd7fcYrgX+RjHahhE3mhOq76IgiMgWrQElmobngGOM3SB3voypVCqVSqVSqVQqlUrdU7DcMW5AkRjIuFyu1rbebGxzZnZ2dqHGcuOeBsrjuKabNpqc6UN6uy/5SLcVIS/cuIy+iMZvFSyz4Rvem5nIco4JBB9t2Hmjo1EN7JjRSwgAp/Fo43y0RSeXLhrA4QO4XJ9AMMjvxGvPCpABkCg5pQ6ykSGNbeoAbXwPdypHIyTBDJHdAxrwedRBwFg2Z0L2M7Iym/yBgDzFswn33X7P90EEhPU9ncJwaTN215138f6R2awGYR5nwAxX1wQy12EnREYkRTgDAWh55p6h6m5fNpha9nS0rdYrbd3uAXLj+rj7MlzONrNu3tm46GzslsJcfC1cA+SjLmzf7RXdscDryC3H38Hp1zWN/3BMdGkv1dCRsBXH7HA9GgDSKi5nrpyMp/v53bl5anmdfpu4WhtKXL5jDha4GXkmfGihxPrGeJLmNYINV053ogA9ymMtcLDEYThx8/tA71Pfs7xGON/L7+YNhI6c10Gu48XKZgvM80vdp3E+BFiAW8h6HWx3tbNFv7Ou36kxIrKVm2aFJW/XXcYFot1xDGNoGfitppMFqLbN8iq0m0QllFiNmmsbDemKszyiOhzylpnfwtQosxSwVl3OAYYjXqeMS33pujbcYdZ4GnqFkpExj8vD3HK4yAfl5I77cpwEpm2ru4C+Di2xpgrn1/lZ4mmKZx2PHXys6jWHABU5KqVAhjVHTl457B1SN7EOddVo4yw8CRnr1FCPr0ZBROEPj1UsEtdHuGvnyM7HbpbI2ff5Zk0MC4JcemWF43GCyoDQKgyWtZNFRawJutc4jnADl0KPu8abORHrnoCvIh/mZXzqVaQTmYC37gAo67PnZ5csbY+QoHu7WWfjGmJnxjjuOMTjsC2xUHg+dsFwdw3GAEUMFjK08WDYo0i74xjg74mud0QKua+eOxJip4D+QmI3B+YWcpURpRFHw8Nkcc+d7O2609yjt8VEpVKpVCqVSqVSqVQqdT/AMj6dXl1d2uHWYJvNmZ2f37QOgHB25JcBLKNxE2Mmlg505NjD511+/HfoCUekQCk+WDuMcHgYH+7R+A8ACrucF2xq52kB49F2yKXEi3pebWxVxkfkAZCXmb1yeR3oGtMWbscIZkcQs2jSFtEcglMyf7qT1gEFM2jxr0ObhxxNjxRTUZ2Vvm3ZoQEgQYCUAE/arh7AEgBi4Be2gAMUwFUcLu2ASYTHTLrwrfsTzhnuvNjKzp+W66fnw02MeGacY3XN0klJ9zOaEKKpE8DykmAHzvTFsucWdhwP4yngjMMxefPFrlNzsx4O4g4FBWFSgOWuQ/O9web7JYEInMaANgTLPeAx3g9b+efcni0aoutOuDPHvONe/eKCFLQLqF9dnqGIMo486AKT7+igv32sqk6e6NA2Hl4clw30VXHh//KS7WsXwBpN7MLF7JA03qBxY1ZzaUQMRCRJkGwHkZwbC95T88WSYMrmS0Uy0MlYo3b5dIBlG22PnQnbvXXLne37HcHWAg076c6O47qDqzqOZ3LSrVvb81/jfHj5ajO12DEQ41viBCavVYs/jzmkJ/+t6zJ1rga4bgF081+NQz/WqziviKI5cXiW38fj6/shwoDTBu5vPI9RFMoZZ7PLU7DcXBhFGNQ3Kc7rJt9BcTgBmhEXo3s7HLqAyyXmwsFyheTxeg6KJ6A9Tmv6vlwH3YV7ep0j0ohxO93CumXnYHk5BcsOflkS8tyMgMh83CQXvK6bMR8CDHNKMmYi1gQsRuHcjvFx93KJ03Ew7I8pDRjdHYwiJF4Q7RBjLpS1fobipTd99MigUtwoUwljeVADTayVh50XAAT50WCWLmMH5VjvIn4EDQxx7QHZUXjljOfhad3lORzQxC/WmjguQHAcPzKaY44eDYk4+ptaG+FGUTT+TiRSTqVSqVQqlUqlUqnUEwKW8cGfLrpRuYz73WCPPPwoP9heu/aI9auVrQETmLAw2DDf83NzNJwrH/Rja683G6KLjADXnVn8gOzgkm60eHb4IQU19tsdP0TjuBR5UZ16cGiVpmUAAXDBcce/AxHmJ6PxoGAVAUyA7nD3NY5RvqY7lGObNnMyYytxCy/LduPqBIPzWFIONFy7ctm553A82ADH8h5gHuceGatNsgBNegDADojp0HbU4m5jGft8HJmv24CjCSiZ5l4L6rtLkG5iwF79W7Ebvn3bYz5oLMfQ+Gkpa1oN9pjLy3iLiNPorB9HuheHYWebs3Pr4GpewBGNDF/kfQoss3li0PLiiI4ogbptO2B8GZj4eZPPG82+mqk35Z8xIiL+kzk6UYlLiONyYh10MX4XbsmT1ymstTRmq4+fPkgFEh37ya8mbmB3zJ8ebnlg4xv23xPoHY/Mrsa8RxwB5q4iYfxp3hANjmXMD+QyD0s4YTs7LlGQqFEndKxGdnAMqt97ZZ4F9G+Og9eHl8zdznxKLZxouEtgcVQRJtfxdNjrOUdDuyl4Lg0cWxIc17zEeJxcNx5jG2fTuLvLcTSO3uKgDpip4+D64EBe96BylxUPgzFgta30cWQhiktSEyUUUez+382hN05TwdDIjWdcBWEkIDbWUOUyAypr3DwzmW7nyI5vzr2cqkPX4vytUSIRSRQAWNFDcY/67wPgO4RdwDkb098jL+I5HdcarKnTqAr9HYhs8XDiKz//CFfv5HrXazbiHAFgPSanFPRs7jBZxcBatJNjWX+TTm+pcCIrQ17TVwUSZVRHAVBuZTsMNh92HLXFqPupNCFlHnRnR2TpL9d8X/Qm4N8wFjD99SOf3wstKE7See1/fwivvTFqnD3W4CiyajcPHjMgYr8UPSLXvriw8drpWE6lUqlUKpVKpVKp1H0Hy0vfqmwrs6GzW7e2tr16i22vrmzZz219dmYPLQGfN7bYrhW5MAf8dPedwx+BUzWNEhCFm43WvpJ9PA77si1/RKxFbNUHUMYHds+FxWfk9dnKliu4jz0+gR+s4YQF6JRbkAhkIZiMLwCuYXelD/qxLdr7owVAEKyAu7PC4jDJMiuUB13hTkkV9sZ8ghYHQoCerlvlaA5wpLlbWPmbiMHY2f7qihEE42zn7tCDDG3+5VHFnv0sqKAMXcFVgkIcMQ2PDu7dpahjbBsMCrwHvJsDOMF1Pp9bB+jYRRRGNCf0SeAGWwIMbg3X6xJUOoTGv8N5iN/himh8cd0OnhuN81/avF/jxcy6pRyrQc3YTJCjS4c43ysiMApDqvSp/Ou2pm2eUztxuZ5i23CatiA6IFvENdQc7AKXvSAxpVCP5f9r3qMA8fBEtscydfhO/n0KAMvhN+dbOHu1IgJqdWiE2HW23Ow4zv2juOc6zkc5HBW5gNffXu5t2GMHQm+LeW/Hw9yWK7jQ5dIsedr+3io6NI7uCaQNR7k/tkB8uUx5w7mrnzi9cZMWoNvEbZSiA9+mXquISjitD0QDT5yz3P7uKC7Pq+7Omo8cXLnGLkSyL3dW8AE6l3BxA/j5Gfhp6VjG+cHmKCAR9Akwzo+ds2bFyoRjXdER3iiSMUFxj/lB4eE4BW4CcQcrnw8ndBNLwd0XWtbHETEMEdegexxZvxgHOGoLmCwNPwPYRw524+blddf7Mwe9E+TFToYoskVxq3XT42fMqsdZHju1FMV1VC9IxRXNZrZaLvkaU9grkFt2gyy8gae/POImcI6+XaWJ8sEra21FPrliruXIj2vWFjTowmYDUs+Pn8sdHXnb0ZgU3+k2xvVhU1kshtEsVH874FCe4w33C5uPvYponBgYp07NCBcrrtvriwubAzCvNra73NvxsCeYBkQeuO4hRkr3A6Ocitu8uqLx91Q5+SgcYS2d628mnPEjYpyu+LfRxpV2ivB4sQtIhQBFmjSxMKlUKpVKpVKpVCqVSt0PsFya05Wt4IizUFayAEDN3Wy6cE0ayhWXXZMTXOx3lQ6WrwIX/HkBQqpl79Rt6s62NrvVoa+2StfN9YpyjeMLl2+7N/jEbVlf7GSL/js/dvWV6xiVsy1jBzZ0J+fsnV87nIw1i3T6u8fiBTFGBYvHhY1hK8zUr3vz/m7obGza8ZjG5VmgtYA3mkl5t776nBbWVip68u+TzInihi2/vcMwtT89fUR9nWmzvcfQaSO6xnTcPmt2R4txcx7/t0O609ve8U3/Hw8+Bc/l923jsub7ibO58m+/Vx2Q3va48l6T0XyMsb7TodZnTWNb6jy680AIVN/pWj0mF3OLs65Ee3E09x9rvkwymJuXKsfBuJLWSR7H0d6ArVn9hHr7+haRwO2xljF4rHv+Dv8quS/cCVEB+qSpZRmH6bjXekfkx7TTuK7F5dyj3OZrRj3WO2Sal/eY2oDrX5N2gO887vVMtU7c+ZqdvmI9zju9WMmUPsmVjjFsNxVMHjM5tHoftb9q8/TbI582Rry90V/riH9MtY0PTx5629jFPXryeu1ulcnPUqlUKpVKvccp/rY//PDD7+5DSaVSqdRTSA/73513xoQ0O6ZVKZVKpVKpVCqVSqVSqfco/cZv/IY9//nPf3cfRiqVSqWeovrN3/xNe5/3eZ9741hOpVKpVCqVSqVSqVQq9cTooYce4vc3velNduPGjRz2++jMe+5zn0uAcv369RznHOff08r5nON8LwQP8iOPPGLPfvaz/5+PTbCcSqVSqVQqlUqlUqnUe5ii/wigcgLP+y+McY5zjvOTRTmfc5wfr97Zgqb+UqVSqVQqlUqlUqlUKpVKpVKpVCr1TirBciqVSqVSqVQqlUqlUqlUKpVKpe5KCZZTqVQqlUqlUqlUKpV6D9NqtbJXvepV/J7Kcf69rpzPOc5PJuV8rpodkcicSqVSqVQqlUqlUqlUKpVKpVKp1DupdCynUqlUKpVKpVKpVCqVSqVSqVTqrpRgOZVKpVKpVCqVSqVSqVQqlUqlUnelBMupVCqVSqVSqVQqlUqlUqlUKpW6KyVYTqVSqVQqlUqlUqlU6t2gb/mWb7HnPe95tl6v7cM+7MPsv/7X//p/ffz3fu/32vu///vz8S94wQvsda973RN2rE+Vcf6O7/gO++iP/mh78MEH+fVxH/dx/8/rkrr7cW71Pd/zPTabzexTPuVTcijv4VyG3v72t9tLXvISe9aznsWGc3/oD/2hXDfuwzh/8zd/s73f+72fbTYbe+5zn2sve9nL7Orq6ikxnxMsp1KpVCqVSqVSqVQq9QTrX/2rf2Uvf/nL7VWvepX97M/+rH3QB32QffzHf7z99m//9h0f/x//43+0T//0T7fP/dzPtZ/7uZ8jhMPXL/7iL+a1u4fj/PrXv57j/BM/8RP2n/7TfyIk+jN/5s/Ym9/85hznezjOoTe+8Y32ile8gjA/dW/HeLfb2Z/+03+aY/yv//W/tl/+5V9m4eQ5z3lODvU9HOfXvOY19uVf/uV8/C/90i/Zd33Xd/E1vvIrv/IpMc6z4/F4fHcfRCqVSqVSqVQqlUqlUk8lwQX3IR/yIfbqV7+a/z2OIyHmS1/6UkKKU33ap32a3bx5037wB3+w/OzDP/zD7YUvfKF927d92xN67E/mcT7V4XCgcxnP/6zP+qwn4Ih/b+pdGWeM7cd8zMfY53zO59i///f/nu7aH/iBH3iCj/zJO8ZYF77pm77J3vCGN1jf9++GI35qjPMXf/EXEyj/2I/9WPnZl3zJl9h/+S//xX7yJ3/SnuxKx3IqlUqlUqlUKpVKpVJPoOAk/Jmf+RnGLITm8zn/Gy7ZOwk/bx8PwUX3WI9PvWvjfKpbt27Zfr+3hx56KIf0Hs5n6Gu/9mvtvd7rvejCT937MX7ta19rH/ERH8EojPd+7/e2D/zAD7S//bf/NoF+6t6N84te9CI+J+IyfuM3foNxI5/4iZ/4lBjm7t19AKlUKpVKpVKpVCqVSj2V9Na3vpVwB7CnFf4b7sI76f/8n/9zx8fj56l7N86n+rIv+zJ79rOffRvUTz2+cYaTE5EBP//zP59DeZ/GGIDzx3/8x+0zPuMzCDp/7dd+zb7oi76IhRLENqTuzTj/5b/8l/m8j/qojzKEQgzDYF/wBV/wlInCSMdyKpVKpVKpVCqVSqVSqdSJvuEbvoGN5b7/+7+fTbxS90aPPPKIfeZnfibzfp/+9KfnsN4nIcIBjvBv//Zvtz/+x/8443S+6qu+KqNz7rFe//rX0wn+j/7RP2Im8/d93/fZv/23/9a+7uu+zp4KSsdyKpVKpVKpVCqVSqVST6AA0xaLhf3Wb/3W5Of472c+85l3fA5+fjePT71r4xz6e3/v7xEs/+iP/qj9kT/yR3I47+F8/vVf/3U2lHvxi188gaBQ13VsMvf85z8/x/xxjDH0rGc9i9nKeF7oAz7gA7jLAZEPy+Uyx/hxzmXob/7Nv8lCyed93ufxv1/wghcwD//zP//zCfIRpfFk1pP77FKpVCqVSqVSqVQqlXoPE4AOHIRtsyeANfw3MlHvJPy8fTz0Iz/yI4/5+NS7Ns7QN37jN9Jt+EM/9EP2wR/8wTmU93g+v//7v7/9wi/8AmMw4uuTP/mT7WM/9mP5bzRKSz2+MYY+8iM/kvEXAe2hX/mVXyFwTqh8b+Zy5LCfwuOA+YjGeLIrHcupVCqVSqVSqVQqlUo9wXr5y19un/3Zn01w+aEf+qH2zd/8zXS5/dW/+lf5+8/6rM+y5zznOfZ3/s7f4X//jb/xN+xP/Ik/YX//7/99+6RP+iRGNPz0T/80t7mn7t04/92/+3ftb/2tv2Wvec1r7HnPe17JsL64uOBX6vGPM2JF0Eiu1QMPPMDvpz9Pvetz+Qu/8Avt1a9+NdeOl770pfarv/qrjGz463/9r+ew3sM148UvfrH9g3/wD+yP/tE/ah/2YR9GmA8XM37eusWfrEqwnEqlUqlUKpVKpVKp1BMs5J2+5S1vIcQEvHzhC19Ih2w0jXrTm940ccG96EUvIux85StfyaZQ7/u+72s/8AM/kCDuHo/zt37rtzIm4FM/9VMnr4NmZ1/91V99L6fAU3qcU/d/jOH8/uEf/mF72ctexjgXwFBAZjSkTN27cX7lK19ps9mM39/85jfbM57xDELlr//6r39KDPPs+FTwZadSqVQqlUqlUqlUKpVKpVKpVOqeKctFqVQqlUqlUqlUKpVKpVKpVCqVuislWE6lUqlUKpVKpVKpVCqVSqVSqdRdKcFyKpVKpVKpVCqVSqVSqVQqlUql7koJllOpVCqVSqVSqVQqlUqlUqlUKnVXSrCcSqVSqVQqlUqlUqlUKpVKpVKpu1KC5VQqlUqlUqlUKpVKpVKpVCqVSt2VEiynUqlUKpVKpVKpVCqVSqVSqVTqrpRgOZVKpVKpVCqVSqVSqVQqlUqlUnelBMupVCqVSqVSqVQqlUqlUqn3OP2Vv/JX7FM+5VMe12u88Y1vtNlsZj//8z//mI95/etfz8e8/e1v539/93d/tz3wwAPl91/91V9tL3zhCx/XcaRST0YlWE6lUqlUKpVKpVKpVCqVSj1uCAw4i6/lcml/8A/+Qfvar/1aG4bhPX5kX/SiF9n//t//227cuHHH37/iFa+wH/uxH7unwDuVejKoe3cfQCqVSqVSqVQqlUqlUqlU6ve+/uyf/bP2T/7JP7Htdmuve93r7CUveYn1fW9f8RVfMXncbrcjfH5PEY7lmc985mP+/uLigl+pVGqqdCynUqlUKpVKpVKpVCqVSqUet1arFQHt7//9v9++8Au/0D7u4z7OXvva1xaH79d//dfbs5/9bHu/93s/Pv4XfuEX7E/+yT9pm83Gnva0p9nnf/7n26OPPnrb637N13yNPeMZz7Dr16/bF3zBFxBMh37oh37IPuqjPorRFXiNP/fn/pz9+q//+m2v8YY3vIHO5PV6bR/4gR9o/+7f/bvHjMI4VRuFgX//03/6T+3f/Jt/UxzaeD7O44u/+Isnz3vLW95CaN26nVOpJ5MSLKdSqVQqlUqlUqlUKpVKpe65AIwDAgOu/vIv/7L9yI/8iP3gD/6g3bx50z7+4z/eHnzwQfupn/op+97v/V770R/90dvgLJ73S7/0S4S3//Jf/kv7vu/7PoLmEF7n5S9/uf30T/80Hzufz+3P//k/b+M4Tl7nS7/0S+1LvuRL7Od+7ufsIz7iI+zFL36x/c7v/M5dnxNiMf7SX/pLdGcjPgNfANaf93mfZ695zWvo1g79i3/xL+w5z3kOoXMq9WRUguVUKpVKpVKpVCqVSqVSqdQ90/F4JCT+4R/+4QJVz8/P7Tu/8zvtD//hP8wvQNirqyv7Z//sn9FBjMe9+tWvtn/+z/+5/dZv/VZ5LTh+//E//sd8zid90icxt/kf/sN/WMDxX/yLf9H+wl/4C8x0hqsYj4UT+n/8j/8xOSYAazz2Az7gA+xbv/Vbmaf8Xd/1XXd9bojEADAPdza+cIw4BghO5hCaAEb2dCr1ZFSC5VQqlUqlUqlUKpVKpVKp1OMWnMgAr4ib+IRP+AT7tE/7NEZHQC94wQsmucpwIX/QB30QgXPoIz/yIwmM4WwO4TFnZ2flv+E2RlzGb/7mb/K/f/VXf9U+/dM/3f7AH/gDjMp43vOex5+/6U1vmhwbnhfqus4++IM/mMdwr4Rz/szP/EyCbehnf/Zn7Rd/8RcJllOpJ6uyeV8qlUqlUqlUKpVKpVKpVOpx62M/9mPpBgZARpYyAG6oBcj3Uoi0QKbzd3zHd/A9AabhgG5zmJ8oIQ4Drun/9b/+F5sYwoWNY0ulnqxKx3IqlUqlUqlUKpVKpVKpVOpxC/AYkRS/7/f9vglUvpMQSfHf/tt/Y0Zy6D/8h//AjORo7gfhMZeXl+W///N//s90RT/3uc9lRjLcza985SvtT/2pP8XXfNvb3nbH98PzQsMw2M/8zM/w8e+KAM4Ph8NtP4crG05oQG5EfXzO53zOu/T6qdTvFSVYTqVSqVQqlUqlUqlUKpVKPaH6jM/4DMZHfPZnfzYjI37iJ37CXvrSlzJO4r3f+73L4+A8/tzP/VxmJr/uda+zV73qVcxLBoBG47+nPe1p9u3f/u32a7/2a/bjP/7jbOR3J33Lt3yLff/3f7+94Q1vsJe85CUE0O8q+EXcxn//7/+dUPutb32r7ff7iWv5G77hG5gzjSaCqdSTWQmWU6lUKpVKpVKpVCqVSqVST6iQm4zmfr/7u79rH/IhH2Kf+qmfStcxGvi1ws/e933f1z7mYz6Gmc2f/MmfXHKbAZe/53u+h+5jxF+87GUvs2/6pm+64/sB9uILmc0/+ZM/aa997Wvt6U9/+rt07H/tr/01uqrhTn7GM55Bp3UIec9wa+M7wHkq9WTW7IgSSiqVSqVSqVQqlUqlUqlUKpV6XHrjG99oz3/+8+2nfuqn7I/9sT+Wo5l6UivBciqVSqVSqVQqlUqlUqlUKvU4hDgMZD6/4hWvsP/5P//nxMWcSj1ZlVEYqVQqlUqlUqlUKpVKpVKp1OMQQPKznvUsOpW/7du+Lccy9ZRQOpZTqVQqlUqlUqlUKpVKpVKpVCp1V0rHciqVSqVSqVQqlUqlUqlUKpVKpe5KCZZTqVQqlUqlUqlUKpVKpVKpVCp1V0qwnEqlUqlUKpVKpVKpVCqVSqVSqbtSguVUKpVKpVKpVCqVSqVSqVQqlUrdlRIsp1KpVCqVSqVSqVQqlUqlUqlU6q6UYDmVSqVSqVQqlUqlUqlUKpVKpVJ3pQTLqVQqlUqlUqlUKpVKpVKpVCqVuislWE6lUqlUKpVKpVKpVCqVSqVSqdRdKcFyKpVKpVKpVCqVSqVSqVQqlUql7G70/wMFtEMmG8GS9AAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "def show_image_with_predictions(image, predictions):\n", - " \"\"\"Display image with top-k predictions.\"\"\"\n", - " _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), gridspec_kw={\"width_ratios\": [3, 2]})\n", - "\n", - " # Show image\n", - " ax1.imshow(image)\n", - " ax1.set_title(\"Input Image\")\n", - " ax1.axis(\"off\")\n", - "\n", - " # Show predictions\n", - " class_names = [p[0] for p in predictions]\n", - " scores = [p[1] for p in predictions]\n", - " y_pos = np.arange(len(class_names))\n", - "\n", - " ax2.barh(y_pos, scores, align=\"center\")\n", - " ax2.set_yticks(y_pos, labels=class_names)\n", - " ax2.invert_yaxis() # labels read top-to-bottom\n", - " ax2.set_xlabel(\"Probability\")\n", - " ax2.set_title(\"Top-5 Predictions\")\n", - "\n", - " plt.tight_layout()\n", - " plt.show()\n", - "\n", - "\n", - "predictions = [(imagenet_classes[i], p) for i, p in zip(top_indices, top_probs)]\n", - "show_image_with_predictions(original_image, predictions)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## **Conclusion**\n", - "\n", - "This notebook demonstrates how to set up and run the Bonsai EfficientNet-B0 model. You can successfully:\n", - "\n", - "1. **Instantiate the EfficientNet model** with the B0 configuration.\n", - "2. **Preprocess a real image** from the web.\n", - "3. **Perform a forward pass** to get classification logits.\n", - "4. **Visualize the output**, confirming the model's end-to-end pipeline is functional." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv (3.12.8)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/bonsai/models/efficientnet/tests/__init__.py b/bonsai/models/efficientnet/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bonsai/models/efficientnet/tests/run_model.py b/bonsai/models/efficientnet/tests/run_model.py deleted file mode 100644 index f6b00e12..00000000 --- a/bonsai/models/efficientnet/tests/run_model.py +++ /dev/null @@ -1,47 +0,0 @@ -import time - -import jax -import jax.numpy as jnp -from flax import nnx - -from bonsai.models.efficientnet import modeling, params - - -def run_model(): - # 1. Create model and PRNG keys - config = modeling.ModelConfig.b0() - model = params.create_efficientnet_from_pretrained(0) - graphdef, state = nnx.split(model) - - # 2. Prepare dummy input - batch_size = 4 - image_size = config.resolution - dummy_input = jnp.ones((batch_size, image_size, image_size, 3), dtype=jnp.float32) - - # 3. Warmup (triggers JIT compilation) - modeling.forward(graphdef, state, dummy_input).block_until_ready() - - # Profile a few steps - jax.profiler.start_trace("/tmp/profile-efficientnet") - for _ in range(5): - logits = modeling.forward(graphdef, state, dummy_input) - jax.block_until_ready(logits) - jax.profiler.stop_trace() - - # 4. Timed execution for inference - num_runs = 10 - t0 = time.perf_counter() - for _ in range(num_runs): - logits = modeling.forward(graphdef, state, dummy_input) - jax.block_until_ready(logits) - t1 = time.perf_counter() - print(f"{num_runs} inference runs took {t1 - t0:.4f} s") - print(f"Average inference time: {(t1 - t0) / num_runs * 1000:.2f} ms") - - # 5. Show output shape - print(f"\nInput shape: {dummy_input.shape}") - print(f"Output logits shape: {logits.shape}") - - -if __name__ == "__main__": - run_model() diff --git a/bonsai/models/efficientnet/tests/test_outputs_efficientnet.py b/bonsai/models/efficientnet/tests/test_outputs_efficientnet.py deleted file mode 100644 index 3fa75241..00000000 --- a/bonsai/models/efficientnet/tests/test_outputs_efficientnet.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax.numpy as jnp -import numpy as np -import timm -import torch -from absl.testing import absltest, parameterized - -from bonsai.models.efficientnet import params - - -class TestModuleForwardPasses(parameterized.TestCase): - def _get_models_and_input_size(version: int): - nnx_name = f"efficientnet_b{version}" - timm_name = nnx_name if version < 5 else "tf_" + nnx_name + "_ap" - - nnx_model = params.create_efficientnet_from_pretrained(version=version) - - timm_model = timm.create_model(timm_name, pretrained=True) - timm_model.eval() - return nnx_model, timm_model, nnx_model.cfg.resolution - - @parameterized.parameters([0, 1, 2, 3, 4, 5, 6, 7]) - def test_full(self, version: int): - nnx_model, timm_model, img_size = TestModuleForwardPasses._get_models_and_input_size(version) - b = 32 - tx = torch.rand((b, 3, img_size, img_size), dtype=torch.float32) - jx = jnp.permute_dims(tx.detach().cpu().numpy(), (0, 2, 3, 1)) - jy = nnx_model(jx, training=False) - with torch.no_grad(): - ty = timm_model(tx) - np.testing.assert_allclose(jy, ty.cpu().detach().numpy(), atol=1e-3) - - -if __name__ == "__main__": - absltest.main()