From e4f79912b331556a9d78b536a5a1fc08c2887077 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:02:23 +0800 Subject: [PATCH 01/11] add improved debugging Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 339 ++++++++++++++++++ .../poisoned_code.png | Bin 0 -> 150458 bytes 2 files changed, 339 insertions(+) create mode 100644 _posts/2025-11-27-improved-cuda-debugging.md create mode 100644 assets/figures/2025-improved-cuda-debugging/poisoned_code.png diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md new file mode 100644 index 0000000..d811633 --- /dev/null +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -0,0 +1,339 @@ +--- +layout: post +title: "Blaming Hanging and Complicated GPU Kernels Down To The Source Code" +author: "Kaichao You (vLLM)" +image: /assets/logos/vllm-logo-text-light.png +--- + +Several months ago, we wrote a blog post about [CUDA Core Dump: An Effective Tool to Debug Memory Access Issues and Beyond](https://blog.vllm.ai/2025/08/11/cuda-debugging.html), which introduced a powerful tool to debug illegal memory access issues in CUDA kernels. That blog post itself is a huge milestone for debugging GPU kernels, as it can faithfully trace down the exact GPU kernel that caused the issue. Prior to this, due to the asynchronous nature of GPU kernels, people often have no idea which kernel caused the issue, and the error message is often misleading. + +As more and more people are trying out the CUDA core dump technique, people also want to get fine-grained information about the GPU kernel, such as the exact line of code that caused the issue, so that they can fix the issue quickly. In this blog post, we will fill in a missing piece of how to find hanging kernels first, and then proceed to explain how to blame the problematic kernel down to the source code. + +## How to find hanging kernels + +GPUs are becoming more and more powerful, the computation power is increasing exponentially. However, the memory bandwidth is not increasing as fast. As a result, the memory access patterns are becoming more and more complicated. In more recent years, flagship datacenter GPUs start to introduce asynchronous memory access patterns, with complicated synchronization required when implementing high-performance kernels. Such synchronization is easily prone to race conditions and deadlocks, especially in a complicated codebase. + +When a GPU kernel hangs, the program will typically freeze or become unresponsive (even hitting Ctrl-C cannot stop it). One solution is to just kill the process. However, this is not a very effective way to debug the issue, as it does not provide any information about the root cause of the issue. People have to blindly guess the root cause of the issue, bisecting code changes and running tests until they find the root cause. + +Can we do better? It turns out we can. There is a feature inside cuda driver called `user induced GPU core dump generation`: the cuda driver will open some pipes in the operating system, and we as users can trigger a core dump by writing to these pipes. When the core dump is triggered, the cuda driver will dump the GPU state to core dump files, so that we can inspect the core dump to know what's happening inside the GPU, and most importantly, which GPU kernel is hanging. + +Here is a simple example of a conditional hanging kernel: + +```python +# save as conditional_hang.py + +import triton +import triton.language as tl +import torch + + +@triton.jit +def conditional_hang_kernel(x_ptr, + flag, # int32 scalar + n_elements, # int32 scalar + BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < n_elements + + # Load values + x = tl.load(x_ptr + offs, mask=mask, other=0) + + # If flag == 1: do a normal "+1" update + if flag == 1: + x = x + 1 + tl.store(x_ptr + offs, x, mask=mask) + else: + # Else: non-terminating loop, no break. + # The loop condition depends on `flag`, which is invariant, + # so this is effectively an infinite loop when flag == 0. + while flag == 0: + # do something trivial so the loop isn't optimized away + x = x + 1 + tl.store(x_ptr + offs, x, mask=mask) + + +x = torch.ones(16, dtype=torch.float32, device="cuda") +n_elements = x.numel() +BLOCK_SIZE = 16 + + +# 1) Normal behavior: increment by 1 +conditional_hang_kernel[(1,)]( + x, + flag=1, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE, +) +print("After flag=1:", x) # should be all 2s + + +# 2) Hanging behavior: this will spin forever +conditional_hang_kernel[(1,)]( + x, + flag=0, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE, +) + +# this print will hang, because printing x will synchronize the device, +# and the kernel will never finish. + +print("After flag=0:", x) + +# the following line will never be reached + +x = x + 2 + +torch.cuda.synchronize() +``` + +Directly executing the code will hang forever. We can enable the `user induced GPU core dump generation` to debug the issue: + +```bash +CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ +CUDA_COREDUMP_SHOW_PROGRESS=1 \ +CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \ +CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" +``` + +Then, we can run the code and trigger the core dump: + +```bash +CUDA_ENABLE_USER_TRIGGERED_COREDUMP=1 \ +CUDA_COREDUMP_PIPE="/tmp/cuda_coredump_pipe_%h.%p.%t" \ +CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ +CUDA_COREDUMP_SHOW_PROGRESS=1 \ +CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \ +CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \ +python conditional_hang.py +``` + +While the code is running forever, and we suspect it is hanging in the `conditional_hang_kernel`, we can trigger the core dump by writing to the pipe: + +```bash +dd if=/dev/zero bs=1M count=1 > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276 +``` + +Here we write 1MB of zeros to the pipe, which will trigger the core dump. Simple `echo aaa > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276` might not work due to the buffering of the pipe. + +After we trigger the core dump, in the original terminal where we run the `python conditional_hang.py`, we will see the progress of the core dump: + +```text +[01:39:15.256278] coredump: Writing ELF file to /tmp/cuda_coredump_hostname.3000837.1764236276 +[01:39:15.256350] coredump: Writing out global memory (0 bytes) +[01:39:15.256354] coredump: Writing out device table +[01:39:15.292027] coredump: Writing out metadata +[01:39:15.292039] coredump: Finalizing +[01:39:15.292124] coredump: Writing done +[01:39:15.292128] coredump: All done (took 00s) +``` + +Then we can use `cuda-gdb` to open the core dump file, and see exactly where the kernel is hanging: + +```text +Opening GPU coredump: /tmp/cuda_coredump_hostname.3000837.1764236276 +[Current focus set to CUDA kernel 0, grid 53, block (0,0,0), thread (0,0,0), device 0, sm 124, warp 0, lane 0] +#0 0x00007f2e6fbff300 in conditional_hang_kernel<<<(1,1,1),(128,1,1)>>> () at conditional_hang.py:31 +31 tl.store(x_ptr + offs, x, mask=mask) +``` + +Excitingly, we can not only exactly locate the kernel `conditional_hang_kernel`, but also the exact line of code that the kernel is hanging at. This is a huge improvement over the previous situation where we have no idea which kernel is hanging, not to mention the exact line of code that caused the hanging. + +One slightly annoying thing is that the core dump pipe's path is dynamically generated by the cuda driver, and it is not easy to find out. We can properly use `CUDA_COREDUMP_PIPE` environment variable to specify the path of the core dump pipe, so that we can find it easily by looking at the file descriptors of the process: + +```bash +$ ls /proc/3037675/fd/ -alth | grep /tmp/cuda_coredump_pipe_ +lr-x------ 1 user user 64 Nov 27 01:50 98 -> /tmp/cuda_coredump_pipe_hostname.3037675.1764237014 +``` + +## How to trace down the source code of a complicated kernel + +In the previous [blogpost](https://blog.vllm.ai/2025/08/11/cuda-debugging.html), we mentioned that compiling with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable will embed line information into the compiled binary, so that we can trace down the exact line of code that caused the issue. After discussing and debugging several real-world issues, we find that the default way of showing line information in `cuda-gdb` is imperfect: + +1. For some complicated kernels, `cuda-gdb` will fail to find the correct line of code that caused the issue, even if the line information is embedded into the compiled binary. +2. Even if `cuda-gdb` can find the correct line of code, it will only show the last line of code after compiler inlining the code, which might not be the actual line of code that caused the issue. C++ code heavily relies on inlining to remove runtime function calling overhead, and we need the full inline stack of the code to understand the issue. + +Let's take a concrete example to illustrate the issue. Here is a simple Python script that can cause an illegal memory access issue: + +```python +# save as illegal_memory_access.py + +from dataclasses import dataclass +import torch + +@dataclass +class TensorWrapper: + data_ptr: int + size_in_bytes: int + + @property + def __cuda_array_interface__(self): + return { + "shape": (self.size_in_bytes,), + "typestr": '|u1', + "data": (self.data_ptr, False), + "version": 3, + } + + +def from_buffer(data_ptr: int, size_in_bytes: int, device: str, dtype: torch.dtype) -> torch.Tensor: + return torch.as_tensor(TensorWrapper(data_ptr, size_in_bytes), device=device).view(dtype) + +data = from_buffer(123456, 1024, device="cuda:0", dtype=torch.uint8) + +index = torch.ones(10, device="cuda", dtype=torch.int32) + 100 +print(data[index]) +``` + +Run the code with PyTorch >= 2.9.0 (to be specific, make sure it includes [this commit](https://github.com/pytorch/pytorch/commit/dae7710bf2561e9e8a8dc76fd30c68e25bd755b8), otherwise you will see an error like `RuntimeError: The specified pointer resides on host memory and is not registered with any CUDA device.`), and you will hit an illegal memory access issue. + +First, let's run with CUDA core dump enabled: + +```bash +CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ +CUDA_COREDUMP_SHOW_PROGRESS=1 \ +CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \ +CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \ +python illegal_memory_access.py +``` + +The core dump progress will explicitly show the kernel that caused the issue: + +```text +_ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ +``` + +From the kernel name, we can see that the issue is caused by the `index_elementwise_kernel` in PyTorch. To locate the exact line of code that caused the issue, we need to build PyTorch from source with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, and then run the code again. + +When the compiled GPU kernel has line information embedded, we can use `cuda-gdb` to open the core dump file, and see exactly which line of code caused the issue: + +```text +(cuda-gdb) target cudacore /tmp/cuda_coredump_flow-matic.3756036.1764250282 +Opening GPU coredump: /tmp/cuda_coredump_flow-matic.3756036.1764250282 +[Current focus set to CUDA kernel 0, grid 4, block (0,0,0), thread (0,0,0), device 0, sm 124, warp 3, lane 0] + +CUDA Exception: Warp Illegal Address +The exception was triggered at PC 0x7ff533bb91d0 void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, lon +g)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorI +teratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native: +:gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef): +:{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{l +ambda(int)#1}) (IndexKernel.cu:118 in _ZZN2at6native16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIterator +BaseEN3c108ArrayRefIlEES9_EUlPcPKclE_EEvS6_S9_S9_RKT_bENKUliE_clEi inlined from IndexKernel.cu:37) +#0 void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at +::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef< +long>, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayR +ef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::Ten +sorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, +c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1})<<<(1,1,1),(128,1,1)>>> () + at /data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:203 in _ZZN2at6native17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS +_18TensorIteratorBaseEN3c108ArrayRefIlEES8_ENKUlPcPKclE_clES9_SB_l inlined from IndexKernel.cu:118 +203 *reinterpret_cast(out_data) = *reinterpret_cast(in_data + offset); +``` + +Next, inside `cuda-gdb`, we can use `info symbol $errorpc` to get more information about the location of the error: + +```text +(cuda-gdb) info symbol $errorpc +void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}) + 11472 in section .text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ of /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn +``` + +This gives us more information about the location of the error. `cuda-gdb` will unpack the compiled library, and `/tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn` is a cubin file that contains the `index_elementwise_kernel`. The error is happening at the `0x7ff533bb91d0` location in the cubin file. We can use `nvdisasm` to disassemble the cubin file, and see exactly which line of code is causing the issue: + +```bash +$ nvdisasm -ndf -c -gi /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > output.txt +$ grep -C20 7ff533bb91d0 output.txt +... + /*7ff533bb9190*/ IMAD.IADD R19, R23, 0x1, R3 ; +.L_x_27840: + //## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 203 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118 + //## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37 + //## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37 + /*7ff533bb91a0*/ ULDC.64 UR4, c[0x0][0x480] ; + /*7ff533bb91b0*/ IADD3 R2, P0, P1, R22, UR4, R2 ; + /*7ff533bb91c0*/ IADD3.X R3, R19, UR5, RZ, P0, P1 ; + /*7ff533bb91d0*/ LDG.E.U8 R3, desc[UR36][R2.64] ; +... +``` + +Now we can see the full inline stack of the code that caused the issue. What `cuda-gdb` shows by default, is only the last inline expansion. + +A bit explanation about the command: + +- `-ndf`: Disable dataflow analyzer after disassembly. +- `-c`: Only print code sections. +- `-gi`: Annotate disassembly with source line information obtained from .debug_line section along with function inlining info, if present. +- `-C20`: a `grep` argument showing the 20 lines of context around the founded Program Counter number `7ff533bb91d0` . + +In case the cubin file contains multiple kernels with the same Program Counter number, i.e. `grep` shows multiple matches, then we need to further filter the information: + +```bash +$ cuobjdump -elf /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > elf.txt +$ cat elf.txt | grep ".text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_" | grep PROGBITS + + 1ac 1b83f80 b200 0 80 PROGBITS 6 3 26a .text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ + +$ nvdisasm -ndf -c -gi -fun 0x26a /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > output.txt +$ grep -C20 7ff533bb91d0 output.txt +... + /*7ff533bb9190*/ IMAD.IADD R19, R23, 0x1, R3 ; +.L_x_27840: + //## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 203 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118 + //## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37 + //## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37 + /*7ff533bb91a0*/ ULDC.64 UR4, c[0x0][0x480] ; + /*7ff533bb91b0*/ IADD3 R2, P0, P1, R22, UR4, R2 ; + /*7ff533bb91c0*/ IADD3.X R3, R19, UR5, RZ, P0, P1 ; + /*7ff533bb91d0*/ LDG.E.U8 R3, desc[UR36][R2.64] ; +... +``` + +The main difference is to get the cuda function index (the `-fun` argument) from `cuobjdump`, by searching the function's elf section, which is `26a` in this case. + +Note that this is a simplified example to showcase the usage. Real-world kernels can be much more complicated. For example, here is a complicated inline case: + +```text + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 185 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 185 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_traits.hpp", line 133 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_traits.hpp", line 133 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 103 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 103 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 124 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 124 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 211 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 211 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 412 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 412 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/epilogue_fwd.hpp", line 265 + //## File "/data/youkaichao/data/vllm_flash_attn/hopper/epilogue_fwd.hpp", line 265 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/flash_fwd_kernel_sm90.h", line 454 + //## File "/data/youkaichao/data/vllm_flash_attn/hopper/flash_fwd_kernel_sm90.h", line 454 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/utils.h", line 41 + //## File "/data/youkaichao/data/vllm_flash_attn/hopper/utils.h", line 41 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cutlass/device_kernel.h", line 122 + //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cutlass/device_kernel.h", line 122 + /*7eebf5e9eb80*/ STSM.16.M88.4 [R13], R4 ; + /*7eebf5e9eb90*/ MOV R34, R26 ; +``` + +In this case, the code to blame is: + +

+ + +
+A line of poisoned code in the attention kernel. +

+ +The faulty source code calls some cutlass functions, and the function it lives in also gets inlined by upper-level caller. In this case, we find that `cuda-gdb` cannot correctly associate the line. In fact, it does not show any line information around the error location. But even if it shows the correct line, it will only show the last inline frame, which is `File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158`, an internal inline expansion of the cutlass function, still useless to debug the underlying issue. + +With the approach outlined above, we can uncover the full inline chain of the source code, and carefully check them one by one to see which line is guilty of the error. + +Warning: to get the max benefit out of CUDA core dump, line information is crucial. It is recommended to compile with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, as this will transparently apply to all the compiled kernels, without having to dive deep into the compilation script to find the right place to add the flag. However, the flag is so transparent, that if you use some compilation caching mechanism such as `ccache`, the `ccache` will directly ignore the flag and reuse previous compiled results without actual compilation. When compiling from source, please make sure to disable the compilation caching mechanism. + +## Conclusion + +This blog post introduced two advanced debugging techniques for CUDA kernels. The first one is to find hanging kernels using user-triggered core dump, and the second one is to trace down the source code of a complicated kernel via tracing down the line information embedded in the compiled binary. These techniques are powerful tools to debug complicated issues in CUDA kernels, and are especially useful for debugging illegal memory access issues. + +The vLLM project aims to provide easy, fast, and cheap LLM serving for everyone, and easy debugging is also an important aspect. We will continue to share more debugging tips and techniques in the future, to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). + +# Acknowledgement + +We would like to thank Ze Long and Sandarbh Jain from NVIDIA for their helpful discussions. Chao Hong from Moonshot AI helped providing the motivating example. diff --git a/assets/figures/2025-improved-cuda-debugging/poisoned_code.png b/assets/figures/2025-improved-cuda-debugging/poisoned_code.png new file mode 100644 index 0000000000000000000000000000000000000000..e9606b08662cc98774005f4d079a536ce2a2a788 GIT binary patch literal 150458 zcmdqJ2UJsC*EI@)6qO<>2#5p}P^kif^sXpPML>|=0!kA?=tKl5f`SDQks>0f^xlbx z3eqE;&>lJ@5PBdX$=&hsdEf8*@BQz1`R*8Z+;JIV0!a?%oW1v2bIm!|j=5oYmE{o6 zAsQMQmTS7&w`gdNM9|RC?l3Wc--P`f^Z-BHbk)+jaZO7La>Lux+11U7hUR3VZ^A`g zO}pcd)!I2aKE8YRu1O=iNwEB*37=pzyWYpj`Ci$c`DxdPa?VSK3^cUP$;z8JXz6Jk zKhARW*T|`wSM@8s*To44iTP=-YD-qwyM@3Xtd13rj;68HB%7IzJ>w$GC_U)O`xu~d zGNwZ>!gD#eg3&E${Hb|e#~8U_dZ~XopwioFQP9461^e8U$6RH@> z`XGJg+p0|G0op9ej4*GFlmhI@$^4gs2DI`T4`9+~AJ6yP;BFDYLN%S+mhgU_;*w^O zDe?;v74_9H|4O>%c=2u@dOEaM;Cs=?_1>HU=kB6|dj{~wnYs7g6j^&=wHh4X;GBB% z%^WGmEL}3r9&_}MxVGR4?(q|+JJ*~I3}`NZ=S(#8v^+Ei!82O$txC)Lujf~2&(YBR z{`~f%|!0e@K?n6VvDMbCIy>{#DDmeb2>+L(RI|llS4xa83c8;F* zP7?m^UcW!4tmf}!=WyQ%2C;X7x_YPx;_8rs5LZVPL9vs)r!T-j2?Sx3sVP`7roNMbHHX^HP+Q^z-wR@RO17^oB}aR8UZml#-T| zmKF!^5cdi2fZ6$rd-$B89^&s~Xgm2hc)NPRTs=J?)MMJ&d-}pu1O+9je?;=nUje5p z`LFW@Km6;8ojky0dV}k6bymFL>hI)cuI=jXiBVtxHqr)&eBzm1fJ$Lk3qym>^mtb5 zInk(mIdFwf`uzFxy-oK93&oqG*zj#?M63+Af_HY6yP2UH3C67!T0iz`3p3dMke04~ z*~~iN``wAgkh;yK{&u&z**f_0i*%Ksz%AJKF-p56Qn%$9pG-xpOTzabPfmzB4i(!S z4eD0j&<~glV-##_+?hIDGoKzL*yioyu+;bIR+ct1E`e9x<5eh$5P*iyKIXDc(0qI- zrBlIYt~FM&(Zx^a^-fyDbz~vEN(SbXp4x(IjzOyMHa9+F0vUpKuliZ(j@DmmW>XV0 z`FwlL>}Pq%OzTlsbY;PL6;Utzji-SLhC+b2S@PY&*Be>+cm0CR>c@OL*1^?Ad*g{VeI$brS_PD5N z2YLXfK~fbnFU?jaZT8wD!+U#sXOR?Ak=t;ojzsN3Rx~c|dn{M`nRsRUXMb5k^Kk`! zN4Z4_19e@!wi~x;W6$tgXbJ(KO|efn3-`jK5hhYY>s64g+w6mJ-0sH65^Kku+$!| z^~KCI5?5l+wHL7P&E(U)TVCH9v}sTC8+VmIqiPen>*VR>|LvKWoa*k(F>V=W{cq!$ z^Rc=7kw>@%eK7-1FBf?%JlqDeGBs=*o=}&BiOmZiUCu&tS^L|4YgierG5+=IS9b93 zdd?d(j;tQ7CiXsj&G$y5)#hQ_%>@)(rm#CRJ&=PSN!hXyOT&fQ9CU>uZNOOxC)c)q+ts60a!Gv;-$vSHg><78g9Z0JM5SNWjx(Vy66hFNkCW_t(`*GtvUv#MQ&r=E zd5gbV_z{)^g;uGbf_&o^Ki+)b;lD}5A#LQ`M_zq;hJ2{5E78Tyah5%cZ<4isyw0y^ z;o&{M2l%=s8u@csY&&gubpI*lfcZ2(!KEK{ z?~n_KAKXy?sNwp6ve5@EeB+e}p#k-#7eac1&Pp*`!G%6zq)FD>gF!vSUyvR$iP$Ob zI(fzBUfm9zK&53fzoOowkd(2;T{01;V{Bgwemi0`Vgs+imbPbKjXw=Hp6kw#G|tgq zt86a(wseg`+V+@@BF9!tG>0NB17;I{ebD2=!B*T(TlkZ<2HPjd+wz{st`{3?E&QqG z$@ls{JGG5*o4Dz)JWE56cj3l`jf8t;r?(rnST`|kB}XC_gp>&j?{AgZ>)=YUq=19} zgN0H|9fd`7Rguf2R?7nGn8lL0&Er(!jUV&Vdw=C=)=K}dc;Ghn ze9$&FrdDl0(?GaC-@Nz!cBlMA>jZ`D=czPuT{-NL)cj#I+sZ>&0~=p;p-jfmcm9(< zFR^c~lpJ}+_e)6K|UB8{xzp{*(hMh*rUDxMQJx2_1#^G0~ z5wKFk;LTZdrTa$ZSfUgNZR+t-k_01Y7^*MBDm`dL0X2Nt+u~8>VInE3J6$y#zNVJJ zL-Qb`&blEmrVD!!ws1p%h%IxMsadW>3#PZ^A2?<=aTlkJEIfr#k8aEjWPuB?+Y9Yz z)6mkhtMB7K?PtmA(F_{0?CT&>8m%#-KwHGPmY(=ePfAmNekRs3LrKT zmPEOw9izE~QG6ftl9PBn>u+``;gnqm3)9T-qdwhw*iXg2@+1to^cZ{5EkaB8zbPz7n-krHTwHx|C>jt6mK9i2tErNhrjFb)mYFcIF556IL<5!kWY;?~{qRPqbhW-uA10%)7vp(k)jb8|#Th5yI4ZjU? z52D_+@@qM-55>hXGLHJWmEv{Z-*%vH#dbd&b~BKeGYEsn;Wqd zXE4fXC)Cl{@dosflY{V;EP<4um6{gOC!9>Ew7?@(Miu%QohxH?n0q-wE|9bN(41q6 z2LaeV`f!X^C7$`n$T=%k<`E5jUXGR2d`f%Q_`=CIVRHfDL4I1|6ga? zC|gAcor>s>-x&4>9+)6r2bhQ0zdajMyIylkLLyFJf075zI z*E=65wXkQ7a1plc`*_p2gJ5i=BBYlr(YaKt2|0Ua4^`*q={k^ao}^3WQ|vKs=Cq>| z2&@M=KAPe3UMxbYoRHCHe?n}j5-jAM_tz6ISf>@-t-79}8q!gTm}3ul zuu{Kah)Z5;xNZ-@{Dv#!KmFiF-6fJ@Lm>hfjrzD9>cIarpmS9Qw;SEyOd~Z)m7ri}& zcGD)@1Ec1p61d48@?+VX3N14zd(b-WE#q(*K1Epc5Z7EoKV1YfJi?tzmc0~!0+~GH zZ!D!@`-EpqZzm7zs<3CNJyf;-#&q7S@U1ckPo*xXEX(DRSczKQLXfm~SL$#b*(X-$ zxmIGBA{^;8A04?4M}_XKU0wY3A@=si8ycsBuopcCL9TaN=y^9@sU-SHNh`WQDq!95 zyj5-am@Himf56O9nGA1ZP9;62+0?>+kpHt?!g&osQ^JzuvP`f9nN#hgW!}5kr}$tyf(B!k z-%{Us`zxFckhGtf>R|`6m?i3}$KzEAA8+U1L9IO5=~LyEcGT5Q@mR3k0U$kamS?CZ zRNSVq-Ymz)KVFSoT@OkX>lDOyz9U898|W%eT#^PzE0wJkQMio8qYqP6|G2elE}S&rqJRwNibdHx~rAyo)^5IU~~8p7c(N!5rk=^ z+dgF;^0^>gD_Lq;5sp14Zc!f1v)Ajeq}IkCy7SJ(E8^fjQy9BC+>e>@?7Rt4X%Fc-F9BuXHH*cMb_f}0Rir?1c#beV&=DCNvtd``PL1Jb9xFO zZ+x(*xaaIu!dg3C>3;Vcmg6G-D!n^nl zYjmjM-!JU%neIaxaDZ0o3DO7M`uO;Na+%=Psvn>S2hb8`2LXNPrKb4)0N=k)nFFE@ zIeSt%jQxN>Q~L3H0{bnN{(y}H8d`9G1m*~KCg$zz8;AG3djB4hFnT6%fJl8vgt~6A zh)|>jz>&{YLkRxjWD*Y1UWaSMxR-gm9R#(%Sbv^bGJty5-)TPX_>LVa(aJGQyEj0@ z<0P+{*K@tU8n0T(g?p=lRE8=9C?>zn%ab32y?k}Sr0rDiGSc7B8`ppj?4;IVhdISo zD~^5#&8Jf zzt+!a@oe0^^5ZFQ^8E>}7D7n;VDk2#!TnF?1bZrVYFq!A+F5;{ZIF#ePF?Q7DtJvk zgsqGY=gWA>A2L17>D13gHfH6K)dja=UQ!M%cf(7W<)PZ|tNS=qb6Xd3HW^Z;(6P3u z>gdC(utLSTqX3**e`%et#jlKc6PK%EL%WH|*zxdiMz%H;{HVjDM`U=32R{a{GJyG@ zy1&6=`X&Na!E64%SHY`u6~g+jA-QSrH^Up)vhKFg8lO%@I=}a}vvxO4Nt0@`ptjV3 z(iXRaqCjW^DHD%Z~K1tql-0}&h_~$%lyGV9X0gk?nD$bg^E0D8(DrJfzYwN0x zz9fw{=-=gPN97Vo+u6|W^rOBbRi2W|BNegdHav%lQ;^C7?qsVNlQ_cu2H@1h0f0re z{#?d!n*vbWw7tIQp#O<{C^(9hub$_~=+?RQV4*rN@kUwLPbtE5VN!mE?R0F)5NLWp zj@eW7=&J&D%82`Xls7@-(UUNn#^8VyS=SeScK_l5^P^|x@d9oXO9W|ay$VU{yx-`H zG3CB!AHn>Hg-`MJljG;7mNC7#@e84K`l)gTR6bMQX>ICzsmGY2%-vVKQrwsBy%Gae za~_xTV}#k)k0^$Hg^RFhGwNDvJj3pN;GiaXHA?L@H$QB$ign7s&@f3QXuH*KB1FU! zNvf2;tDSJQh3{m6O!&eP0eFm@m*3R481G@H^eoZQB0;iMOO63-8vxbV7ja}yS@QZui= zpoznS2_`pd6S9@ae9CR2aZE0jkYc6g5-wp?>%UxBzkc1*icVl-z8haT#;sZZHSPp{ z$Kv!^c8+eJ9EK(H&$rL_;y}ogFH07)X&gNwH&Nx8;TKS|_BMd=a5YC>?p?%>IPT%} z4>#Xm$?9&9AmpQIH$jE>!nD95C4`anRDw<{=g&t64+|bR51?$h3rcgdu2}+=o}%5r zR+wk8h}ROEkhCi_#wm3WZ5zG@GlcNNI5YAJYunyx(s6?j^&IfPcQ!);{(8*S;kWlx ze%z%Ix~X&>S+u2eg^Qvx<~wLL-Uzrc-CV;oeH9XwKxzksuv_Jr)EiDLAPGGhcb&_+ z`K3`Q(F{jU+HOB4`G}3Hn>$DNxVc(2mv`$)xor^zp2 z!vmIyN(}Rg`2~A@-Ps~Kug&ULnV(G?^vsj;;GIkfV{r4xklAY)Q8qF$? zvHE5I3>lm{t=U)q#`QMQaTE@eR=WgDQG+|>Nmtjq`JUPL@v1M@rd!%~=QAWjh}$P9 zWqW|9GL!XpnD~@W=texlR+8}ZUF=f`v(3wZAb!RCV%R}=DI~gTqS-ok98lfi6){^f z19**$gK7O*T%-(sd;n^7&`6DRe9dDK-^(ID3gMA@qv-qj)w!Er%eT5u-L7;q4_N;x z(y2;%Px}>=%c<@ZnRf!0RgxrlG|%7um|#ArREpL%wnG zF#Uk<^}msJdIE5{bj#sE^CMZ0PmV_Nw0<-zNn@6Km%Yf+?8#rkZwo{4YehQ2+z{Hwn7?EwMn@8?9x z^J3Q%PsM)q?|O4d?+~Y$%QF>F$eDxkCux3duz&a@5Is0AR;%UxFyFE<-Flx$#C0=h z!r&ZfR$r~CtS&0WAN!&eFOZH)_6JMp(g41R@QuKYAx<)AK1$b{Fi8_?W$~Oj&h#Kd z2I1vNE!j+3=&`F$?i>YpHNS`{;;niZJL$wuw;IJV zNz7a@RnA?St}|8ERgjQj1!H&yBF#{zeg+5xiHp6tsg;Bmu2!VP06!2d&Eo6qni+Pd z`6=gq{`_gZWu&^R4dRJIrKnkl&FasJHVlU-7LQ6SWfta{7Q8qtQT-)#Rs&y0_tBl?62`9b zgBrU*y(y0zumMtqOsaM{_wCKg8Uzi@7H+~uts^TEj-nlDPZBRnE%li<#0`=fzQ!fq zVf=!DY)EydD7}o_e$F6Jj`NLtr0Bb(8_Ox)A;vh|$(}XHH>2Q#lN{K(qrwl$^G`S4 zJyLke5H1UO)j!2E3P-y;52Vg@re3{0IB%T+O?%K{m8Ak*Aww{z0|K7)YmuX&#O0%% zmuPZm&(@tKghX$XN)eRMje3_Gon3FT2IoVO--6Nt>+ei3x*v0XkntFT3s+;$BE6re z@L=)+a>$&D;$rs84k2-?+S^y+PrEw@o;S{k^GhM5y|F$x z0W{DM0`5gs_J&omf+{I`Da^ z>rnAKX0e@SUc&2zj)Q=dl{1QKyzGvy>|9@bfFOnJuFoEAVnUxA`7+j(rj$CV3o@qT zmxr{KpNGtu%57&_pWn(e75AI~IZ*PLh@mW6vDmz*m*wS<10kF*kPT2Z2_O}}u3V9v zydTDJ1D5Fz2tfJuxNa}V3TWZ;YXy3xi{Xc`b1bhGe9*PN_l=#yPhGm#Cl$U?(bE{-txC*4e*V_CzFxK3 zf^dTG8d7xovNN<*jU2iILhOed?=+9fyxPfdQZdPzh{6&`Z+*I+6IbfGk9i?&rWy=Y zTtGne6S$baLlkHwPO<@tOp}5cDtWI^7TN+dNsZC@2c2&&WtAOydh()Xd4pBs?i??- zBJ3k`n(Jz32t&m)7WhKhB*RO$~rPRY7Vl0OF?e^(~x zGO?7;lMzbz#OY-5RTY}Tdb|`8J@1rk<^1(p;1iG#e**Mzj;a;j&NJ;)%1gTBKGK(( z(!xM9I$bA&cF|&KaanDZ2sp}A(o&g07?rqF{KX?MBPNvp>vmB6;3xUF(dyCZA|N=k zL=2fLOVrG8)XXN`-Sl#vJC)XXE!>-+zm)aYg*29IAcsqPO{yDn&cu_fw;44#B`g<0 zy0;g5uM)PGd5^yBO>qV+5Xa2%o0y1Y{71(3?EAvEwDju@)XZq*ctAk0=N6(Q6ok)p zz;}vd%>q0rMkNcmGYMez_!EJQFYG>2{v52^pJ-3yaw0krlzOQqeDrA<$=Pv0>R8Fl z(BM)kRzfX5vbHG??9DMah896>*Ev;i{3M zf+Ve3Sst7s9uG5+`PH` z5asyq1-y=9YBGsj5LD%K$e;1Vzxb%Gn~&2fSC1y6LofwZZzP702;`_& zGs~}nyYe-A>+J@^KwuYb2*S%@$mUP=?Ov(Dw|T}A`!C(AHqi0_K%OHc2WuaG7OphV zT}(QS3HTVlPmHsxMJ-iAbwA%JnL5gsX_=``+dDNLgge3ar6`tlvd{cD3`L6p%;F^s z;VRVU9Plj+4}_u0oIBUI7oV{3f0OP}nWAAHxqS4106FE7SOMz|wcwp%X1nUJ*ZHA8 zt%%jOO^HRtiaV{M2r#OP5}vSpv5YucY!|+Xo0!O( zVvG2uctg=e&0l2CS4j6^(}Zvm?dqB;l$;{7H@Lphtrj$KwN?`AGpFV;w49l~;!}Vd z*jbxCWB5si#a7_U#whKs^NRX<*qLmBtMJPS|!|%HEmHl(*&S|fw~sr zl6CTf=^xiVTu-de0YE~LSD#u4QDu}X9o2tb0=E+T#R{*;c3SxyRmZqo-j$8|Y$95u zk9cg2csdROBjBkzih9vcfX@61AaJxniW;I(KSf#yrsM}{Y-9siZ1WkPyAw#*&JU4e zTt<+yNukB_Wf?2Jb#ztBWS5LHSD(EC6=^^n2v4Gfy{mUe$-N}IbeRx1nJ|c?RW^TUPN`4SLIpK_l>ssMK50^sB>Cc`Kse5|UB z$3r`+;sj^JzZoN~sE1S;iebpynzB+k1KSOQ4mwKln1pD+79d*yHlmc&vstK(3oQ^6 zc*EF(3J9lMhI|0f$#DUo(i@Z+sarTC#oD8GQMhcS>$W3$1Zq}gg>#vhuQcePF2K#5 zKCS(6w_$riwmyfI;FNQ-*ZKI`Uq_3UD@4C~^-BA##EQ?^G$0II6*hc(j>`3LV?o9; zD$xi7fv&yhomMASKrcz`V%dcunWdue3}ADZ8*KJXzykLVe_ea9Mf9{?WJ2^|3ZF+c zUCY+?URlr%$Lb~e1zt?c| z`Fx8$nnkzmiwav)Y=6ce3g#%(+3wfm<4BC#m#(gp(_QrM-Ie@S!~iqUW8w-hsrETh ziyO_iH`gR zNJRBu9;uMz$(6=;1sa3;0OW_c-C__=bo)P>ICZ?f2v25KS**BvcQdG7JBDMt5IE1e zQa^NvcHCh`<5n)-xw_Ahp1%5<6#`Ut!DqTr#Liz)`UQuMpR;{Lqek9w$O$3j{iRr} zEG{nIt9n2nMfaI`6Br6Ci`-!BNoD_HAOnq9b_*q23RaDJ>wT-(r4s<)?d5nwP^u!i zSVz0rM*b`7C^dG@PQ^-Yvo5A%o`@DE00u9R$jYa9b{SnIj|MZp{6_e(qK54^vMxQ4 zUhE|~7r+ZKa8P+ptJug}z&ka?GB8N>Kyap^X6iS(UV?lU;Oi^o201$r8hzRI$CVzM zaI~Qo6NnJoPhdHRSP~YGaFT`6=k*y=Fm95Hb+L8-D2zoez=C3pi2P@Gdp$qq_7i)| z?~Q6~SgdqqPqL}k*u92kgbbwK)JD%bLA2x5g0|fN=h1Ic27x3 z#-0Nc$-@n73?efbcP66z=>!xpE#Sk#JbUG_k&3eVW>k~JS)QdHZg=byoA}~JzKpP| ztuB&oIjg(2tK+!)n%Z3giJFk~O`5lZa_DaUr^N9;@B70D%pBBwWWr`0tFG&nS6`?= zbqd(n4rng{eFPy7byQor{d;iCbH9&B(u?}#V0|!*z zqq+rkM;6EehxpLDS;Y-aZ|~B=z;=L6vxD#}&Pr}dEJmL`1-UgCvfN%&xO8roHQf&i z?>QF567g+>ZQOuT%OU$^vNF%IrZkREX+T^W6Nn^aL^ioeiKNeWSf{F{DozFij&E)^ z|JH}=+7n1HH-1=43Z~DC^|CxZS3H<+P?q{Y$l{!9*IR{kFlnwtz&0=6t@7A%CP#A) z{PNSt%_uF4apARES7i6@yfoNaF=!pM>Bx=nofBbEUVO3ePuwi}9XGM|83ma*t@<$b zv`g10odX-L>_xiI1?NaUMHE27B~a}R;vA%^ED&yvI+0mOhMQ4Sql#o)@Bx8BR?PlV z;@j#2il*kte07Ld-X9DNPeycVEW93@wi94~@d~EAX>2(N%%fT$8rV!aREH%1BE0=e z7@YwcSms{aw01A5+cCOf9sto&LeV%H@ue=P3-GZ6sU_O&q$bvo-LiW{NK2&un!QhT z@p5?y;s=M%7cue-V9NE?b|hO^r3crs2ab^8|DbsrtUI7c`<+$5oo|rCeQ7+wqP5^A z$Fqb23!}qt;YRAprTLlNHqj#Ox6O?XWOMb+;XY`egp6~C6zv&c)FIjB*0V+|1Y`uv zr9>xu0gS#5^a4ugIZjKNbh_L34R7bpvCWX;t>)m`Zd23cF}y(40Ri*0`xG)MOGb2g zv_{7G1cJT8<}YC3cmY5o8l2HFrnPf-H1-#zXhD#Q4u9TxAuIvZ;#s63x_3&xOUEzz zJWeXQQ=$pt6{O<+;DA9H)xOqoJCpaAuzo_Xp;G57%%mqDeR}V$cbW%LDcaal@7e3r zVsh-bpI3bs11ir<(leuQA)g)SA#yTo)k>Qm(g9A8*NPaMXM=aVRWOFZ9|PM8`hxQu zYeO<%djg#7(8F|ZAN=5!w*5kL{%&QyIPe8%5U8{apo?{08Ko+ z1crufnpK*OXQiZNVYn>U(yQs9^Bj%5+xw7TKrz27zUd{51SCxgqqh?D0+NlOP%AXHlb z&y?M8eE1LaM8mbNS$qtc0G31)^jb1trq{NEhA^`7yO@kjTsT@dX$!Ily$r};$LSwW*8wj9y~kT=l|$!r{4JQu&YUcKaRma&p#0MzT_U--`&>{OASLu} z>7e=Beda19fq{klb>Uz1Y0lTgkz~UCjr!Ux!*bsc#x51S<*`1ChCts&B&j}_-+NTU3DE`_VzoEERh5M3vPqbixanL|h>2cB=hh)BI=E)>j6lua)gp?3Q;b4^XAr5{DM&yCQjMn8BiC zS~>>D3MHFYlyRJ-4TQ0pi6B73Ve)VF{y%#bc8dPBz*)vj8^2@Jzly!~v!&vX;F1tqC;3ydh4{&F-3*I>N?Yy|B<39Mf} z<~!>jz#obuy|fv}Pt?iaqxMC5sE7U*9z=jM?&7M-$%(a_m|9HXOR8pAHw3Pr9f+zx z?J_+bxFA3^*eQIQYzlYUnCtWg?%8(07&J5&S~pn0`1Qt_wK0s8VJgQ(wWJ-FupC0_ zh3?LBEdrL68ZC45Q(m|t#TP)uL)A__PWN^@klr3#>(I*_E`#cUW(|eM!0XXv_K%7k z{ghY!Ix>_5F7m;we<1(VGJi<=rB<^VZ>RIt_0?2C9Jt|x?++GsZYybYqe04X{goX< zmBS)_ZU*cDH@D+x^@E!2kO%n_c>j=nFa8o{u=?k8{`gz27yJQEo#_BL>CShh{db=M zuv2G)Y)ZB-bEiAl(o3`R`|9GJr0G4T*YswVPHjGAgwv9%dOlsqz{1ly<~Ol?Fm8_p zm_Ws;og)kQdc@BamCK+<23V|BgwHK_a$&YhLMI#=te4LG1G@h+V5lBU^%m^@c5Pv= z#oW68HGOA<090N&mrbTb5%gQ1_hRqk@r z>e6AEh&#;BGt-e3tz(xC2-ILw`cco4@d9}mwp&1q7dwBa*tXKd50HI8MTq;@+ZYlA z%;HwsDTwCC&b`?Zj}6@QeEYur6^WBOzq2!iDMFh$Vj*)E5IZL|_|73g1A+RR7!R(6 zEa?kYeq+Ag{U&As{k{<5@2M&!N_GETbolW)hakENSejMV<}X(9g574NV0!)>qiF)p z8f`|O`Eq5D&%EbSpn?5N5&p+Her!|?tQFPFcR$s@|6|hOs{B($S9Lzo4W)9=-Jt76 z5a6jG%%umrPRqep#oehFA+?Z3+S`0H#AXBDOL;DiI zw4Z=Hs}Hm@-SHEO<8lAb^1(lm`*P?(D($6oK;Y&9fn+Oq3TVq|X1WK-2Imnbdu&Xi zN5|Wag-`__s`nKXT|yvb>}5v9o%~dN7>z1YKCbdvM}XE2x=H)GTDh);oYISm_oZMqP(XKlBycHEX1hgM1nX zZ4hI5??vs_a$SXuK`(uK^hA$G!3%(dwl6z?c2C=B0O@Y&u* zdj^%gSXkSO<@pHxbnAmY5;uXI=&&vg-u0Fl3I^?V)-M6^QXhWMGN!+`Kn=KD-MW8$ zymW-jA;Aq$@ZF)_MauP-dUBzzMr?M$CyhnFV5I+0IZ;27@%n(`UGBf=d>N>t7E`zQ*l-vWB<#Y`POQ|3Um46{q-J3wh4 zHA+QH7P2ZA-YLlfZnZ0Yz21DmSQ|1KkGx8+JW2wz%zMC3>B@PIV^0w2B>96^!OJYjq;$$z=`fhX6`JbPa zZkZtKqJa4(2F&K-Y#0eo(g2OQY9>$p>|dV7!kLZQx~m(F@z$@Q(uC+ZAzL~ zy9!5>i9SfSJ=~Mi^Pn7*rsWUn_@^l8BgSoo6w8%_{vfC>lx3C$E@V+QrbF?rJDA*z z&?o*?o5aOPUJrJblLl#uCf8q`vl}5mSucR%OL=$ZrE>-um2ejr+&U{MDLw3OPwR29 zygPp3c56N+W7J=Ddf5+Wzu$^8Y+94bi+KpLPaW1dcQ*@4UvusEPXrYNG8h>R=Fe!#*X`4gnw`U^bU#Vy;1LPQl`n==ZbhXwcOUWW9aksHIX zpbgpl0SDY}4oy4p0PDj0+cXXc&ZBjEVF}S`T~`dl*azDhB!ieJG>E&Wmu2#3(Iu=8 zGyZryM$QYi!77i)i*5YlxVYJZd~OW+W-~Fitp;6b%*5C!|8dI6=Dt=AC2WTCM{c;UPF>-%0?{a8^3 zieY7`<20|FTY?&SE&j#Zw^$PjXfJi7)`@k%^e+v^-T}pZ{@p4=s`>?p(n9%S>-Mb? zMz+)Wc~p9fE;$)EI{;Z%d0f(Z6e!v&q{r^L8ti;`Jh*&!_BOhm0O{)8bdbxxOnD*H zkEWu4Uz|OwWrrHf!5p6(B#4@Pc65FjEf@aew1}>nO9KFpm?-nU9tlPDpy5&{xG{WX zbdgkxG<#jbfuZr|b1W=l`4~m;94Y^)ENbvZ_9&@af}p%m=D_7A%(8rjMbl*hbL1eP z?w)gsTf8=B1vU}q)QHfodb)@^G5s4QEvH{nt#C4~Qo3}S=Hi7-?p0&LW>&t5fVx)B zs#V1klreEiD;&!TK6pm+cuuB(>8J8DAQs2F6oy@7feZ_wfV788CK2(e^E%Cr@PXqE zs(7LhkdLOn6I5MNrDBO<^a1>PfFc^=iC@1)26W1)c6)`5ydS5_m*ji9wQ+BTwq-+j^pAV&kQY>jXRDwT8B z1S>{wA>}=z`^h9PsR{?2P|2YC=#4wyP`->g!4NHvzs1@*P! zW!dtqzfROGbg)Vy_e#fPLw|QH{#R>;%jK&EScqG}$LEP?Ac4HG?Fge|#0u0KEh9lkX{vh&127=pA0TPaQ}iL*H2)jN z2CkUl`eu#$SZ&4mG7VN1=f%ljtNc#TR>*sF;FmjJUr{W@s-)$jk`{RlA0CQ7apCqk z`$!H-S$^*IO=5x-GQ+1!!FwVEA8=p=uLqQU#;02$s}g4;PaxF(o)IiI3kW_XxGZ!4 z_ZpD*(M07!^Rgcu+p>VZLT$r=uZ7k)u~x=G`PYCV zkwiHFD`gvf2S{j?7zQ*pi@1zbs1!@;#u8NO5N7rYjK2G1C?J&IF@#Gb7?KSf4OCW` zzL!vu#0sZMLdQl78J|a4?KMS#j{mC>(jE<4rZqcLu?@`t2B9Mu>kmM`^U+h^nHq-d z@+eyZ_6_TmtvY0A5nY6DX0e2EuHhvOrn2X2Ev(9wXsyU2&%v@BG2oVSd!2u$M9{dn znH5%f0yM54z3U}C50lnGRxW$7=CqNCk zR1gRE%j|{#e{=#`xafdDpkZ2n7C_F<@JC-CAJ)5m{YlMn`HuGoJt(!^5i7msN6S5Z zO`ihu;q6#yal4E%ea1uP?B6y@+a->%h%8cDxN6pZ@Xo%AI{k#{rK>rJ`3>C0ZZuz1 z8kNePp4Mx=$TM212h8T!lI2|0Dq!L-P928=rw-5walT6|zs2R4-yt8)8LuZ%qwP^M z%c=5qh&y1@ybpSdi^Xiw4FYatGMJ)oN+Aer5QZ%RB$*y!`ZB1qH^^rfj~9`-p5&I&0-1 zZ5zA=vhRB?u9@VNB=yA_ClEc7mSfDTg5zt`gP+k$Zu4_qQdBR=1|&`-ma z?y1An9Qet+%(*X6@k;-aCHVhn#x3Sw@GCj1Kd|DEJA4B+qAEMr%*lFad01frOztWaCxkJBk@7-QAdPe|6p%>#;HXCt;gU0CV5B zINqZtWaNNI>kBpPVK)T*4uY%&GOp$K`rAZ%ZQczAtxZe-q6-?DHf{>m^amJje9R*d z&sCYG#BPIhF)g%QeXkb&R5*)Ux#lk zqs@S&ea5{`)Pv%>OLcm%>q13(oB%aM-Cv`u5CImQcrSFm2tmrDmb9RDJ9hE16&{lr zsuC|@mm$-8nTt%n1%n!(R-@9GwHaqJ0+&8z*~NqE@cMH>oeq;R1O;pY5XWm9y{$5{ z;>WlY)o_V&M1W`J=6&Kg=E3{47GK?O|12%aopHtjejLYUNGs4!8cai2uP2F}2Q1bl z9iV_EfWGMA{QD=@b5cE|F)R8yl?5SiU-Z_q?GO(%J?RJ>HalmAWgB?gd__l$YlV~MS}yQd5j1i~@%38lWE=rv7aI(g7cu?@+q)R#ZdfYTQWy zMOzy#zLP~|SeIh^@%n(QvSyc6&2U$_x6aVG2hc~TO2soMkV0T4p!JmJXxkIcKB*H9 zJJW0tZll#V0ly~#s{U8tG3e?q^5FaziGLXm@x1MzF7e)w$-@iC(+R9>-9z4|pmm^8 zD06!l)t(!midB^tyH#<|1hi(yjX7x=1ummFNv9=mW<^wCz*apkLHY_p!-d?C=kE)vl58A;(e3f?iV!(10E;=$QS<0uPi^84wtd^o}_{f#%jNC`-b;)?Hof-O%(s7Y;U>G}}x>jBy2SKQM~T865$NOtNz z5GP#K?}4TrXNJQ?tAmLkq$*4)@cc1sU>pAI>k#8I1swMYHC@|TWp&khj}M)g0Tsr^ z6If>Ta(8ieieK?%vqymhmaXzYe#OJ_N-oyzV=2e?XwS(Xp>}u$$~y zR}V1&P4FlktdxteZ^CZs$w}vi|s+!(zfa>$vNo8CYHj|4( zjy)^Z(Noe^iyb1peW9%u2R**#_$ad{8NxquhFX zxw}+g#H|O*ppZ9byFWLOy-DEOIt7|5dw-C*^*U(Fc3$%S_F(AclO~a;lv*v@%MJtZ zpt9Mcb*sSQ5~8FEm?z9qx4?Y@EG zXP}2?)FI!RTAdTsbziCL%2~I}7J?E^mm2x+4{D4iTL*r6C22D*+RHTu_|IgpjX}3) zZzA(C;a1T6?%YEVpf3ogGTTS^6#G1$C|$OOhOaVuz)Yh7b(5OHSbv3E zH1H|4^UgWCn7G&UEJX(Tq^odpJcBiJFo@c%;GT3|m!mvQZ{tKYV@wWEv%+Hq}0J67EVY%&p@TTe4&2Wm;3S=YmuhT(sz=_m$kjv86nboUV zZbW3I>;v7(Z!?J2lhY1Bn5J^#pRDRCRkBla6oPc}pRzu)c^ssvU6K+vOcB^m> zAQI!BOalWL-DAs`@5a#blCseSbu#D|PA$9lMJL0gPCJ#aiBj(pRMB%IM#b5KUUO3+ z6s*$a53OJ}WFsi#@k-w7zciPLpbscoRj(Eh(SJ9L)%>)f(1-Ak%9Z4|5LBrf7hHLE z%GJ;gI@9ye9#k&|DP7im75B|#^7H_?t)If19|Sh__yAJzP05?cW?*&W@XZj98?c7T zqY#X$z=L;Jd?^>S#B>uu>X+$%S~VfVmvCUexML6)lT5grnSM5=0+R&*kol)+)4Fni z$rjvxPIF*r$hr5AUgOO-n%DF~TY}m}PHhIFc2xs!2x!Ni`C(Tm(Or)LKTEX=g?)b@Z$}jD6?*(;dQ0b%?WF@jQ zDTL7ppiVPYVZ#_tOKC8Q_oisTA`|wi9Zz6e7cH7A+ZU^A8YupYQKOBx`pfB&abLyY zoN-d_$A;(dp@^9BxOvM>E#<)SeVRRJ>9wib8ysQ+`vLr1zXmGH9O~wL*m$R;C^5Fz zSC#%YJ-kP^t{1oHsW}GF?K(zvJ(Thot-fJa2Kt0+aE(OG|L2|q;KkYGhEVrNSSWkp ziFwQd_J>LK6sRtCJ9y#m7PihxLsKPSW0M36xb>c2<2ed&m^7-hdL7-=qB;9{pY|V` z0}4!1TDq+2$dx}kGrWkuU-74T!s;(jwm6dacE{nSK-Y5eRc_m;QU|>qvxRjlhkT!d z4SEu5eC8wXrGC7E#1r*iIn>X;(QT`5HDnHlc!23J4CR1ult>8D4WojnbW0=M zNHpmBo6Mv)v&?dgOuLA=+31?qTKchf zU)Hhfv9|}gmR#gJ@L-|8a|aFWFteEd#S5Mzg9W;B(Ef7`7}mR&pQjSyQ*+ZeLoXZo zi)1qHFxMGIX4naEXGzW3_TPm&5pZny;A=hbv2c9T(Yf?e-Pz2lbKGHxy7U&ylFUB_ zESIBa%~MCs+-TgYrjGd7A}yJT$z}xvCx=SfA4tX$j^(D2z!@$*VAa7vuM=sf-geAj=N(u zr*8ma+D5V##@_Z61R$v=WdqG!xXtYdG^?@+)oxsoq=P#T$ z*X0O?EZ7#Wtq1d|R6Gey-j&;&a=Wy%D%m^YBll!AekC(KQHP$E@b;iJqt|W&o!RNf z;kQ5V+1uO69LSvuQ)9U{9&-y{YUxvLo%2RMrHxZsy5VDqelb%r9GN0z^u;~*pATyT zLuEIYBj_zA_rfO@?D@lu(&c*NeB_oo&jk;*G5Rbvrd_*tY$lXk=X0wu>g7=qfk(q@ zTekNXK7iF|j712M70}ALfbH&31;Y_^*j%_f#7tS?vEeOg1R!{t%LQxgxjnwK{7i9fFt9Nu<}rXyja+til~ogVsA z(ixymr*vU2$8V1mV-gHf^V50j)WgZ@Ioht?yL@c_cKuTtSB2>c(~lnos`vMteWG4h z5;%7)Y;hekN%dZoc|u@FG`ta%Ub$`+9HD=GM!8nlKHZy3t>~SvV8*+OqB;la6<+Oh z@vEhYw-4QAjxc{#u4@&2_L1_qF0yU7FC$^>&9l&G(&zK9r(}Y~Ufw?BX)XR2RrlfN z2eC`T)4m|e6u5HcH&3t7qF)N7izjIs&!R3)?4(QQ*fU%Sz3rAl{I!Gzb+caGjo+~) zEnNLNI@~akj#+N<&H)fOrcy808J+vWSg@}Z-Z06s*fkQz+FBqgtyt+$bAAU9MIIZIerGJhz!`% zn1V#91TZ`K1CUne$jHcqLDAI;qSCe1N*ECN$^$&CQ(?{Qb7&%YbG`MFcCB*(B2O&w z^5}jd;sI3LP{ax)8S#-MUukbV+Tk3oxnshROBoHU+;+pGO)`XIk^mSS_&heY0(puR zk$mSHHVXR>uKcE7)Dta)RzpNf7!KHqURH6sOPy{o<)MTsb+tnH-s-Ddm#Qyj;ofGF zf8PJ-L+7;r!?!ZPT?Lm#K2Ke)*M{8eBv%f>cH}m`4!sazxU(XcjMZq%nB9|k>hP0! zB5gah1ebA&N!>SR-1_KdkGr-CKVF=vI3r_N7~aqP3|xW^%7BXX23BRV*=nH4fKTyR z)I#H=!8NO}>?6TTM_#VrWr#RtB64K@*+@0IJ9lbxNTn8jJ#qC&c3qx6OmAPG88Dm?CQOk+8Og5)jrbUI)Nl97 z`$NMNTnFQ+*!J^Zmi`|uM4Pz;2z2sy~22;xggP+qw*8`Kjuy$xE zT^&Y;6RZSOg^G>)cZB6S-Dn>RnlW3KX=iQ38rRiwj=1aOY=%1zO-s4V8nxa4{?+5@ z4ti|ScvpG4kb5niPo;KmKveigY}V_NnR!XM{p8VC>OyKWG0ExQ-Y%z(tA1#&Ye+o~ zbn}L%e3E9{a3(Lj^{906oIq*t!L>t=14PS>95T@A&b@J29_+$bMb3rWFfy|>gkJd2 z=oN6L6boyXSR8^0gnx>H@zb}&p2o{^*FQbd1^@F{FhgN(=O9Gbx4^eI@9exTAor?U zCslYYwbVbsdBoy2c!5;DcKXdW_Jtbqc;6w9H}ly+(F#qQ+z~T#(Oskdx)Wco>b6m1 z#Y~qa`8WNa+o!MjJ!rgmk>kf47w)6^i;@=y%d6E`>#fi5d$xQzQGJGiAo8SP(XDgN zK_z&;>E@?zgb#VL$kbMti86GAaap#9Svu9@%NcCPrL(zqk9OP4-k%KoAsyQ%9~BvR zl{cIVMdMcd^~5*#mA~qtSVf1bCa3U;ZA-PJO;KK%BcqV7C#Ga5j%aiG{5#%s@3n>L zT#!$EdMfy$O}AspY2wd5n9}2?$SGqo>f=vQF2Atn`PD~r3X44bF_2q>PsY&;TSkhy z&%*)&E^8~R9YLbf0Op(ZBn7_BDE&H3CFH6MX5QacUh0((mAmg)UKMhEJkxDm0lyU~ z%P}2>N5;Y1a1OPCn=2bHz244(MLg8L>mt>-IJ323ed040-$hU;viai|U z3NQKTq_p|izm^WWvek8-=FAbc$G0ND{9RG^APpln*@c?FfbNkFBeN}K;_B45N$N}d z8R2vVn9B5S&v#aFpqQ?oVGggG7c6SfhuTWv;<%7(I9_^O>PZuusFNNQqvI+f1m>O1 zZ@kh2MwL`pTePn*5%%Whr5paxbja-g!Yk`8`I&yT#-r0AY+qD~OSk<^Fm?tGolHeg zeC>0(X3e!634Zm+mu^1Kdr9uihsH+_I&TmaTOaA+DT)Y)omw{pIPB!4$ieuu_xj<) zk=JK)aG^n{z+lDk!j;m=1lXOw=t#%KmZLl-yJ?1G6k}R&Qz8~^q#nuSsh?;6MHTUh zztnz~e|K%1Rp8W1AH><~@H>Ie%_gYtxIO2K(nfdrK*sL|KK&U8Yivz$q+I#=|F_Zl z|LvWtS&Wh* zoq6ZYi11r|fgJueTv_n6-uEbjK+dF$)%(^3;L^1s&^Mn6$ZG)4n^O;9y!gZb>qxcP1zu zj|SmY0>}v;8nj{2h&#s?qIYHkCtAfTcU7$YLcuvgA7(A=limjf2=Afne1h*Er?3BU z0!Q?ZlZ1c%um76DcUl=|V|rwP1T}cRR>gFBw`Aher4oanaDRPes)|sbaV4F$5&X5= zcr*&pKB~R0GQ_CC+&cjey0CO8G4_KH^+rsIr{KPK+QRkYy(j9}dgm;c_}0B+tYE3K zPDtL7XMAgWPln}}L*=~dJi%D~p_;_}yDs7U`EeWS&EzlZZG48-#?%IPFwgj{OCASv zCpx~VS1&WvJ&wDFE2uD5tD0$a>V0di{#)rwvy3G+Z)Jky3p8oh zg?wVH$51|QAb`}LCAhw6XM^%Hl zYN~#rE)#B2l7Uyk=#Sk~WUm9&tcbs;oe8)s*+h~`#JV&{KtlxZ!wtdu+=2hQFA=Ka zoy_mOlz+XzS7Lwo5sDNt`6k@dFgtYlbIdIBa?fKAuCGIFg(h}9hx!Zl{M&V>+ZH<~>xTi30OVpzU&XDBc#5Ntf^v(_U)$E>vS zOvBH@5>XC(nn~NQl7k5@9t^R$^HiZrjo7Lb*KsAm_7N_tD7{>Rqf&JdHz@{d=DYod zRj7OX{mGByHko~ueRd3etOyn~d!2>XJm?n>-n01T>y*m|^Uu8&=;}vYcj6yaXkTl3 z(H6ld2Zmd8@?oZ6(H#Y?)L43y9L)@zFHsF%2ZL?4V1U|NY$Ej~SPv!$%X*cJ`Da4( zUeWXD{TR}unMn=c2v2k#X%ly0PdnW=edRaKo`0WHPquMW2&YU`GEI$eCA-ah9XyzW@h6heNc?!`=VsNpZ+=r zwGAbO=BudjrH-@G1MNNav_&~8yFVKw7sbZiT?}q4*Kg@CtkonxqP0@0E!q9qF{|pe zN|>`h*U{-j9fDpcHi{2AtluF$+)=^#BCe6_?kC;yNWpzTGl>ap$kvnoh%9XR<-h*dOK>s^B158(en%TtHt$tR(nje>gRgQ)PPk5AAW}Row;onq zYdElRI0OcLdmCeNLN)2Od;*^e5Bv{LU(vpIZNQ@eBd^8;$6gr*cMeu2_hOi#Y|h4f znNWLw+wS~tUIfU?5aog0P~tWtV7ZM)vn*gkJDIpE^)YL|cRFKfc|;=7`Mon~m6FJ1 zi^}S3l=MzA2a`ozI=*-58m-?mPyF7YHMbY}qdPw7xy_qW1BblnmQAV=(=`f$tuE&h znI>(=3o1_E5sD9@h-B_^p?Jh#)lRN`*N4kn0_Np*^kXP8 zUavckDKsn#*6n(JJGvd#@lFgoV2d5zrBC(?^)yCdqcV@m9z8t2eom7{yucamSao%b z*m$~3glSq3&+Q}dkkNV2`P%4ke;r7aAg~#3q7;D)3;R9=-m5-E2KDKNjo!^g$FRp9 zdc&n143f4yzP&6FtPIi(P8lE$tDyC`;`6l;hLUo?^6a87SG9+bm7^}mkSUHc0^~qZ38oK)OftHLvO^+@gurnm?YYRJLiM$=_mi|MgH&21h4%%KXRyD zK@K(W3T|C`!~4Q|%pk6FKabmMW$t|Sdc|<{V7NrU1@Vbw+JT~aCM%!nqSg^|ztEv0 zl}*>!r%_L`zCGPObuiDK!7{*|c$2%sy?NxCOOZ+mAET|1NAXfdOU6EDd2Q>c!!B9(Fb4oE2A5|drN==rNj7d zPDj{}kb?iUApca|}7Gf;dzvhOW_dg1*W zLe`s5|ObIel6uFB)f4$AmzwAQX;95 zTDz)2wQ|dD%b|mEz?K9B%hSnXAUDdod_Y5629W@2ooW-ne0&d9U~DkTr1*>o_aacn8+Nx5BXNZIt`PLX~UAs!sF15y{5xb zsymBP66_Nr;{=gT$CK~msRm4xqH;;(xLTSvJu8Q!Z`|3_EuXn-QDSLUT*7Qo({Cta zmz8vNCx)u`*AoozXwjgV0l$w3ur%za>jUE;y50M)UHd=&EBOUvbLW!#5t?gm z-H|#UwqCo6mMfIBdt-UDd?R2YjDgzcxBaK@w%;^u>cIS88`LG#GRcXQPwqoV@-Q(R83+?ydfnno*K@}h>Lc=-Aw$tLv)aa zf)*;!G{V%@hEY3SZI2o4Ueo$ae%wM{yY&3M_hb3X8No48{O7KooHW?HpShM>LiHu#GoniD`u#kH$!y@7zDR(!4Iif(a4EqdDP67_eH) zUKz12c&OJfG@0y7Z1F|4p$K=>bX7-tJ&Dz{K7@26MY|FCY;Lq7T1cnDf50GSRDXP2 zq>7!_rQi>H?sU*^_%XNWz6FKaUyti;u>G|(!tj9e^Z%p(&E`DGSDS-_Gp8mh87qsj9b{=a{2iKg7Zf$B;#h{+dzFf#i-{Rh9}@3k5E z=O1`9-k~Zu_h>#Tsp6dZ2e0Jszt406S1jZyt|Qtj=8hMaXiXs&&`gB`@5>L|Knh+(d(3Q12k;^_dhJ86{(fuy$I01y(7`pji~i>$u_U}h^x-lJ9zQy{ zdRh-LTlMQOOsVLF!(`aLgFOR0UKl3e{I%@eP^3e>1M6rxuq#3g$RdH-N!ClaMy>fLK1=2!af=(k`9| zIF?#pGA3!wyB8}iaDg_LPK}2*1e`0{fi^#`_Vb^t-iLI^X3u*_jDjB##Hf2IPjNx^ zENN4_%H!Y<;B1Nz@YxJpAOzw7Nm9}=r$O}Drr<|8a7&j?|LiuY*1!D<; z_u%s3g5<1qVGvvd@|)TOj9%FxIS!IJvV9u#H@wc0vFoS@=1UX(gGKh_EDCv%uTX!# z$k(D{XZ8q$e*6hK>HbiaJUmEs;yD6!79#srW!(L%AY}ZjKh3?rZ@GG}KsX!GrXhZ; zXQr!8cvgeuEKO;O;WV}LS0hQ8%wksG^rogts<9gUeJ^4ZcX$2X1S^aOVxX{ZyVXKT zB%jIDf(+iQYga1}Yksh1P*T~vuJTWAQgTvc^DW`~5yB5TKV`z~$0w07MfFE1YA<#5 zus5LxO`P5|NM-sy38rj;eqpGMDFaP8>gkpc-YMnnAVbr zZ7w8CWz&oCl}m7cytRwQ;fF3K!)1MfA(zE6+g`0Wk%c`c-y76U_AA9Lq=JII|L8LZ+Aaetvsq242T=V+cDLIDkDhNxSbsY{6N(QW7Au>r12<3U=Cc_em|7 ze2F{h=tM6!`G?N1%&B|XZ=i8hz9LIc9gS`Sy#XUQNOB7J02l;j9SGyI z0t5}qri6UYiPlYU!tSs=(9EEBC%ln^s=$yjrc58DHy(qB?jA|5)F&w3P)y{MA<`P5F{ zNo9^KJT_<<6|c?4-T%p?H#PVXmKVpUP3@zbR&J0;kpUhKbDlJQq2MQB8kr$$gl!(Y zK#E#Yr(paCrC3lja#N#2P&n`-9L0pAe&q>Jskusj8nR0{F7>yK843G^LXX-ZN%6?^ zgA=%4BOUBv?jNb7wE6K6S0l9XiD>uc#=7vq3N0w%YhEo z7f7`#7ct=Py{sjy5^(9B6_Z`nY5eKDpP%7yXKT7zZVX9mH-#0vO0| zUSR~P$4~n&Om<7}(KDYnv(cf+hck@Lo3cdn`kKuo}86J4t^Ffnortv%l?7NfRSWedI*SQcGO8H3^8i#09gDk&Ap#s z)Qhj<|9Dz*Cr<*uOKBV|N$5fTxFu_h+`WkHgKzy`N58>_bqb#zcVW@gvx$zg6xP@G zKQw@;hACJM@>mT+a4h1qIj2P;@tlAOlQQPXW#&=j9EI|x5t&-&hdv36!Kx^ zrVu1HsQ*JNbYlvtfchosxi4lIN1^bTeX|Eh%g?8-0q;H%rmWpR12)!Eo?Q$lQsjG2 zhXrYr90&M#7@La4{qG}&zl%rin@EXx4 zSOB}RAMg3HUNZzpM0_ge{ZnPti@zQy*2oUGtieeAn0{F=%0zw<<E8ig#NuJ!R4-{MH(C=B``A&Xb)tc zzbLqJ|7kr$@gf;I)hZ>2P3kx*!caiOw<)p5oleaAQw_Gxv<^KPT3D4ny(}Tm+PDw4 zE3X+0A9q6y|H-$%exHX}^SqqMfh^DX!)M`_EI=3DmRD0#W9r7xhJ#cb;jILAG|HDw zx9ig&h*~d5vp?>G*ex6wLK6WV{H(ruc1M+q{^}`VENzC+0`) zFEB~wevcIHhzY@l$WFf+5;JKgq1i-?eubpQ3}U{$N0>=5diF9u9OtTDjlDY z!k~J;zzZE*73A0;(St>8Or#O@fP0mD9Bh1oHiLA|70YWD=+GlmH;AAkvv?q(^N*Ii z@<}bIkoo?;pq}DJ)}mSwyan2gn|FS|(b)m;W1VFhEOncz?`KUc#B9Z*ZsJulCVQ#^ z2Y(v%9?yUpa%2(E=C7(R1h+^-(;151X==X5DUc&1#G3e5bzsw-Lr;We(5|qKMGSHi z)H-OF5KI}-qiy(OSu;_s2>h=zx#m@Ewg9eNfz%=s2`uS^sqIp^zUdtD_nP&j7bk&8 z&+M^U9wYMz{0;3vVUihnyX~LUK%YYqbYIisKmDiZ86qYZzfkbIK|(U*FCeO2@I;

4(@q=BHB=i& z(|khrDlcERq)*o>ghtFBuKu_g^mp^mDfgP%UpfAxrbzTxO~Us7dgU)tD+lR6M3UO0 zyI`GLMp9WH!m-V#`T6sfUDZPrL|iB5gWY3us#M{bBZ$ICpoBa+21E2?6^R&C{tkNx zoBM?#=esXpuBRbq0+bxQYGpSi0t^qQ$PjM~hx$;|3xb%NXN7lC5`OXLNC4{-ki zt1_5x3WtBQmA0!Ig-0{^>ebncwe$JDYJ5lmlmXfBse_vPd+W)lS9B--?y3LhiJGVw z`zcYXD^%dQ&>L#WLp&h>`ximR&{$-J77r zWgF6S2yzj+aK%5MIiAlh)w*t@IhE6RmUymW!rcn7*XznaC~d$DA|=gdn&+rlxG@Y0 zA$bgd>FmQs{@3J%joJwi>vXu43~g(re4@Y|t@O_(`iGOC-At2CZO6}L`?fsw_)8gh#jt^u8uu`p zLETra?^-3MFZm6??;kVZ11H)`B)p-e^DMz7k(oi~r;8$OLLqh*`Q>6S2X7_wlE8I!+Ez~}qub3kM ze)*pl&wl`+3sJu=P5-|yO~NOL(id_6TU&s@gfOU#`*C$G{Rr9rEFb#;RZ8ohCr&Lz zGOzeffS37eHT-K(LJkcJ28k9;}dn0(?{^{L<- z`h!Xa(Q>4jlT$f2Zrwc0gR*bPK=Deen_BGFj42!ehXC`9uXCxz|s6);>@E4K>*?kGI7#zX$&}SCE?<~?u z{|eSVOH3RKO|SymKms!Z^-9;Pi0yafam9c;s7O}pAO*%=pya*VrweE%pB#iagZ z79J9>hSJ#0dyND!bUP`Oz@Z3l(jyV;7YHmryalv7H|YDHN^m*dK?+!JM5p+QJmYch z;{outi5LJgU^CEbXS7BO{e$I0$Yk=XzP$LK`tpw_CD-p4povvmnF2mYviY)vSi^ov zm1PCACc1LMoz)nx=|ThC8tuUPU`O^0(*HJfu^BDpe(D9GT$<+=bgmv5&M=LWt+W{z zCy@>oSYj&Tvp78hF?m78K?-stk0LkVB10VFcDWn>QIT-t*XQQ>KR)+ozyyI)_m)`> zmjAz?%R*IRWJ?y8SjYz&pV+7Ag*+MSNW1Ml0E?x=)n}@H)zC0^xP=*Q1rbaXOk_sH zy#LW9CjRRZv;NP?1yy{u)Je<4={T#xm*}T|9h@dj^r{xP5AWKRv@sy~MeHpT$uw10 z7Iq8!XaCvO|MMe>rR6M z7bB9;3079^3H%PSc&Q+mL;S7~00^1ss_N^rBY@j?HOx>e7hE)#4qthTnG(vTz}GDSJ?&>JKg9~LF%eq!`>?nJ{l5y&s5Mu_5Fqdi zMD$TNN6sR~8A5_KPygDdHCg2OB5Mb4`gtk$yk@ox5ahH=a1FV6ps?q4u-ZRT^Skae z;>-cXzC(eYkX?189ELT6Qn=>}+zZ#idaA&o1F;TaOma=_c2gwwe+P&=%l_CyiX(`8 zWmjPk-Z5Z0wJ+y~6E){v-G6FRJ8s|RVfFRo?{_pf^ghT9xpwmmy1y7T;Eg;=iQWno zyLjuV5FFDyRvVt_9kXsPSd)z^wL6b&;%3jYdSoQAlJT?NndBbr;hw1>-b1UakLj>CHIRtvvHG*|x0g&n;(>@xRK5iw;SlX>w3!A|*F4xtM%3hz>G zeei#31kn^@oP~f|I2Sj;8m}F?#3e67yJ^gW6#NQXm_ra9La!C+Pvzdqx&+ib))5O0 ziOWNPsw;sudNE11Y4YDw2%Z32eER%$+9lz`A0JWoW)aL1T;34nJ3E=)<9qLaL=lUU z2_wvnx#-A`(1(Us*CQ$ko=MNGPCJes2sWE5kdNeS)LV3DIKr_8=qFW{RL6En{RE06 z>%$nQ_Xa*c{BDV9q|%SQUK45%)l}H9ztmZsBGo6doL?%hSHAm3@f;~=nyZoYSysoRN4>uAx_p$n~J|s!z!AOGU*SIK)^UMAIt>t{u=Q(xjCQ zz2Ba02=Sa`9bp)8No=olauzLYu_Sr0yE42R4j8jGl+UT1#XFwRc}UtJju?{FreLzS ztGd)^SDkJUDw&kB?tplvA(0W+bS!rsf^NjA+aZ`2)z3xO{%5)Q{}{# znWzslbz-kRYnHpHu$<=l1$xFAuDZAJ{c304tq!?R z3G$X|R|9bT?F9#+FK+K3#Y!Y7A20KgIaTTXR0$ly%0#7QH_i|2X7pDp8$az;u?6lp z|LOTWO-lGu$V=2b;X%=vlMMRPZ*_O$2IgJN&Do2cf~(cSl}DfaDju*-ZqEv_|% z+xnacor`-g1t(r9 z-!(3UN$1xCMWla#>u

2+sFcMfMNCoeL0qN+=Q4@~xNH$E();v(E&N7@X@1IS}6TD*A*8&23i^4rm*=9TS~) zg@Yr;I5Oi_*5U!Bg2MHaLtO9gWEcAvr6Ystr9eJ zeGQ7q!gp18X8Ua=go)5*^^%gp*o}h*-mMHR#hynE_R$^X&DWDN9$c%ew_(D*ufnI4 z(!{lsKNhimcxI#`MP<=T!O~IRdiuQw|Ubnae0CfhI{ZxH78jl z@!TZbKMEuZ45l?#wTrZ#M?p%vRJMD}Yba9h!M>$IBl3C?gO$mR6h zNdEpQtmiRDly~&}8zD$&~Y)LLuEgK=Jf`P1S~HU*&g>x*w=%FBk?$kWrD^ zuc3w@-I&G1C|$X(Y6#8c3O5bIJ}4uugK4tY`}9<2nBw7T;(YPj8AP>$-Rwhoebz&m zF336-r#rFaH3*LL4iYFvqs4w+10ZJwpz}8pz*Iz{#%6#?Tgj&NR+MMhl!zY&u3WDq z48aCX@?Q6Y6jDodj(#tI3{zLzVP47!w1SkwJ)dJ2Nfq&qon{K#uMf2w*QBCT%Ae=| z$xNk%e+dq=AWDywV&+UeVHpW8kx$HS{ktPwv}*@U?H=^q{_2%_2>wzB!@mc1jk{+Y zjm-ksD_X~zAxdyl=UvBtv@?(xRa2h~w_V^re`J8ksj8?mQJ=u*VH;|Ji`3tN6E|)-Q)h+S>&*oVo zUIU%ga)qq1{by@mYs1-mxufUbx1+1!%wRJgGq1j~K^~Am@=QOHyWsGXg-F`ASa_$o zM0N7pw<_&dNe#_z#WekR`y=Xcp>UE2HYs&5e_-=%w-W&cPyPbabf*bvy)wkhVsuF8WLw)wHe-V>VLInj45eTNxgjekBYD=1DTG`xk80cI z6@Rx(X~OKi2EL592~*oeY5mmGi(@VL||j~Q5;nghGUPxULSwt>_Mb>Aap-T{59iOK$9i#HtjbI;I-`^Ta^^RMZm z*4%4=R{%d_bXqgroN4P#mB4rKhMCwez5Hv)xV_@@IwJ@y(%Xx?cUTcv=_}G|kUq;r zwL`DRwMf=sLV^OMT5fT&(Wkg1=ci*@#Slkm)p{D!dK4eEs=i;~&FZ^rRnvi~krN%@ z9Wo29OU|~e#UdOrpRPuTTAmJJD>E6L&PhD|qmPnW*Ds#YK)N#u5Plb8K z4{CNIvKrwiX%jc9e7z_SpUQ;xt4DO71V^=I*r*b>jkPh%&Ihrk!mRbRK{2W(E1UyE zTUi0A2;|6O`4?8^8Z~#7wOc6Dyf^1;X^%zkUYJ7G*B+1lc%?MtI?3+u!?&Hq!*q5L z!#|l{HA8zo_=fd-2xOr@-6Q7KZ>qBs+B`l(F(0YM-&ZwcoMFUIQtQ*@D&TV=+I?K@ zsQ1PHeIaO!7U~Nk;y+5KewB$mBHkpWv-&lTVc@RUkGl^AX$O$TmJ)`Dz)t4!ovlno zvYMqSTc?`OXfhpbUY9k!Sbm4hyYjO`apJlVw4LBO^^(7D44PCk!1}p{%JUI4cLQ*x z)weY^a4CTTbh9%j1=R{=4rkm1V3I7ym>!=QaNsG5kYunAg4wL_W<>JgCoPH3XSQKE zWvl>?<3Wib`^SYTmx8{OzVRspZyE-QbI|30Xqk_e-Zk0moy-RSNhvw{dtB@HghABi zcYkW~i0K}oEqSfTPJZ@?2~B`W6!hV@k8AD!&|V-lKS?j^eY_=dY;q-A^&QMdBA82; zhapVV6wp|nEU>wC;PSPDqgEb9Sr1H|fzgwiuPBT$Rm<1Iqw#C^ zgoCbw#-Gyo>Ay$6R~v;(s7c3H-U~xlb zy-2b7l9G5AXWJIpJU`xqVn1$W)lSy>;<)o@m)0rWI@jWdjo+Z0y^SRwe}DTcuh2ru zGkG871CNnO+!`o8Rde(56m)+;D{`7Mrfy=HIOQE~hrXHNsOxGPcE|@C6T9FJ&%bLc zB|86uUbuTmuOVK?=)4~Np2y08R+JKwsQv8Y9VIUqmK|Kb?X_2hH7z=S`E0H!M!m`8 zQC4{);Zdul*Zt)6`qVrTOs_0hB;~N!P_D)%k z9i`0WILJkZ=)4?oV@Sk5*h)T5D35*ioW1Ow9gM}=w8W7SOwh@?W>%M>D@o@HPO)mv zPE0DOcH2jnH{UR?xahrGnjBl4!Lj`iCQZeBSSXca?dA~|O$xTHcvH$Gkjt3Nwq#)g z1%l;nC(lOT$`!HXAKy+3>>E!ab5)_^u?iETN& zHPnk)aUh@fnu+esoZhOls(apVIOU;NvhGuuAikqpjS`ly- zSG<=M5Wq9?kapO0I%t{h&1~y{Mad}~{VJ@_@t)6|nETBW=nBk7>Z4G8EQmRKU*m#v zxW!z;&cah$_GyHqrW88~<$wG=pFx>A7^~(n81jtOuCK`s947!YQ6QlG3aNy$_KDbp^tM;sY z9g5vj8%>cCP?x}86(-PdqhtL{aE?|ou$8Gc>Vx|iz$hP(N|JS;UJM)#NsJ&BHvLzP zhk|b&J>}@U=wLUJLAUPDv}#Vi>h$$)?Yw<`NNSxXcRgHB<9n{J{Xzjy*5;`Cc2aE& zkrDS_4ql$Bqh9goRAB6+Zk8A|ctrGZ40sQg6ZM6cGblKM6Pl_;czd#s_%QTtM^i#l z?>$QNyo4XN6g29sTcB44Bsi9yVo$P~=TaJ(ePypjJMKd^I0x_8<7XZ+Q~hEZC+Yd# zjyau}0o{^n);QWvY!*yQxh|}>zi6V<`p5r~H!ZQljB8G1n-_)zVMw?uIr#Ey|ss2>?)0Ep< ztxf18&Pbc$Lf*5OGOw8HV%I)`S*H9Tpq!ob<&JI;4*MW@+C98%$R>#=(RlaV9g%X5 zYzn^Y5Bjgw{JsPA(41pd4d+V`o#T;cqPMHiVtT~wjrQ@r|H}r{IcMWRsKF?ZS+m_M zSIR+u)xQ>38?ymLSnxCyVbx}VOCdUn@!tJ*7pEoVI1YJPl;Yc|kAF&)o(P~?9&;(= z=b(_MhSdahZQl-70=iFM4W}GKa zx*PLM=?~Y^-Wj#gx6f~KnjLVw=ZV3@P`~V#@|kKG17*m+ww_CJdO}kgS<>xki(6K+ zy(2Gb)a|3^Tm_bD7_kj=-Z4X6-bffY$NgK~umR6q_&@gFeUZ%>-LIFOpV>|OWso>4 zlHI0s{q|%o-1~41{O333)o^}wPn;g;WT; z-KQX66ks`izJLEujuP|`SjPt1zuumX$@)UU(QFB#a`{xjKXtPxdzEos#!n&{sz{)9 zQ}w?&BN9aD8_I42P2wye)y}WG#DDyVWN8Fzj{!`pcD8E`>j4;W#At2g?*yUTei|qN zL8iTMF6WhshVzb=i}}F3_&5Yo%z!qT*Bi?VAkr3PuAxYc->P)cP=5{yL%BwZA{l*N zCvo4fUylVlC0_J^H%{ez`Q0xPUu$T6*NpQxCQ(D#tKpes^H{L9tM|;t+P1_>kRmI|k+JRtFvX2@9z*nQpIV(^}D_XsJ@A`ha3Tr>ZC4Kv#DtuOLy$t+KQ@vVmX4Rh8ZCdq=m-cPCkR&Gj zOgx-=v>_lfef@MIuU(X=qrp9r=0}ry1--dWBXk^XTi@oa`j|C*l($ruZE9#6o!X_@ zKj*ULI_a1Zjdcmdpk4$^ZELG!>7BbkZVJqj4WUo;F(OG(tQ5I=2emTo7VBN}Wc*ui zi0^hP3a-+T={ttK%=U78r0KY(JKqU@&FBV#(eq4gd}HhxQOT&)tmc>heYM`fR!vhj z2O0CUNjr>X+uO2-Ev7OgGIGw02?svk!5WYNTBEG}%5GdZi)WQ-ot8q%=aGP4(q7e; zd7PBjPE0D>Lyo<9eg0fl^nq%;?%}aUp*3D>7X8MwZAP0p^L&PP`jsAkPf{vdjX2NT z^hiqsmL&DTX8Qw<&c=SUWq#3-c(}k$P<*;R%Pb$G48NCl`la9vNA{*~?hSO-a?Iuy zJ1%L39w-{Rm%iy_WT2+luyEIs!3C0-D2DK}gb749a-G?qJjNO#HDetf8E=y26wMR3 z0cmSG>T5jlXl&HS%LiB_f37{8X%3lk5#gSMVjH1e7A+Jq+a&7?jNe;A{P)QywX+`7 z6`Z*+e#RQ=$6eLAp=+WGh1i0?bFcT?9)t3mf2E2bHX%!5)fIa4h#;51?T^i?7d;d80* zy_KbDr^!J{_Z~7RC0Uc2ESwY57k%rvsH!H3Hc)W^dQS%+Cgz70hpZCbg+lw$+ELDk z-p95B#iJjTr9s>4C_{o$Z<9&Jbz*teLgittv9X=@6*{xdRHFJ6{Kt;BvN%M)__e_% zD4dL7P-uTeCiHoHe&F^C0e!;UcBdwLm^&&r%XT5oJ*D_=BW~5yAIk3*uRfhi!;i{0 zHZ0UHNvn{1zj@F=9bS>Vw)MTwpmsKJOZe4+f!!ASj0*|M-cKJODt7 zKPGzlVHN~5Q9%ERpOE#J@fg6U1U}F=oi%K%g{s7%uEn@vxyM z?1HTi`w)t{Y`Q=BAW4H>1ucxupmrKrTe{(=hdp^vCxym&sY?(omE+Cs(#2WzbK|v1 za2S9pdznmrk&#|Y;|2b7*wSq4_Ug%q^y5@727RqL0kpHK*$M=|#+{+u|K7=Qa?m@I zMKVVlBKFv%(NO1m9vR6F1{v#-|%qPfGL-OD% z183m}6X7!>Z>x;vLwjzrv~>?R1UwSd!;h+wc~U-Q8HMe8ZTogLHd;u==R>AiAYk(f z2RA^4o^{p1zstnpioQ^eY~%VTPeiz5~jM!VA z2QWpKzTNz1;;Zz#oKdXQ$y*ic9}eYKAH=Nf_3y7J{p{E8y=hS{_b;3_y$-L*TuA?g za>;U2G#9%*Id(jIIY}8uhJnnY3DJLqmYKeE1Yqtq{kOP~;8gBQyasybg|8*uA5kHV zJ3B3TcThqP>CP0m(Vd^$E>M&0JPW**zC_npe)Im8oy+6i?osGIA8?v^kDm}Plqlu- z0ZT4%n0W(wV_y8Ht9K6a1e=%dE}&dzSC?0AoOQWXmv@(P*KP?SRqjNixCv%M)oJJ~ zymN^&0z#dPlH!y!C7t)&vC7%A3Vl&GM6N9uLJn`(D=dIseM}9XUYe7H+%S91_*m7C zP;lW3oBiFEc4?77?SRuZc27x($E*q`^~Z?RQyzb8+FAZXk1TxY=@qLtVr214#$;$c z*-_Y8ADur`tOm2V=Cnch*wBvW8)xP#njv131Tv#N*{kKcgM900RqG_YlPGU$KSF)S z9mqe?2XCm?G8_ea&p64-bInFyza8fvufbg_vBke8(cN1k67TY6k`6=fb!0v4nM_MD zaimj-LF!$Ry}Q>2i&P}g<4+{@-x}Ue-eba}ky(1b$uv+jTZr-7nOw!rmy&1M{k3(t zx;S64&U-8M9tEtGWKh4~*2hBANbYBaPEwa(Tvq?CzNKO&>$HO=s1Q=|X4Mx)x0ygE zz$?9S(xjuMfbkmR!qe<34jZs3>r|@io#k3tL*uUQFBkC+-YYZAGC4C;W3H8eg6)Iy zO*O1sw^?;rc2!G6-&D&>KsZJG=4M(2@*{2}4m8ao@zIhBy~%GEiPJ=~%^Y}1zIg1G zzg=YUMLdNAB~}?&-YY0@Cszf}({NNDZshA(WYcYGuN>Lv%yufJIln2oMnr~ZRkZAR z4SQ0K_HbzY65_ng6W`$L-zgkF$T!xm)UeAXI6p#5UF|$w?Z9AUBEV}M=bda{gPng{ zLb6ZwuId)nE=|yXYpC&P!1{EQ5{ET4!`fc1r$x-iI|m7t+TnE})&iI$ml&^pOwaxb zCx+vp4Ngxjt3$>^!?DAup9Ojrdaz^e!;ZkB&T`14HsVzJi!#bUkM zSgV$%a|Ktx~j=BF=7=u%|M%_bT3&rZJB&wIn+ruPjl+k5T`1 z-SN=@1*=M2H&7eRR=yT9v~f=m#D#K_RjziJose_6U;`3dPq6p?Mdh`xIZkcW*4dk% zOYTXgYfSG+W{+8SF#R}c*l}m_S7936;ymu88G+VnybmIe_LU{d6|%4McVD5T!`~ZhQ@ib;P7Os4)YFaTL7Lwf<1%OjvP}K2F#cq(3l=Aq6 z38x)x0vmo$ta>Rdd!-xmsEg52#=owls@rO8K02$a#W1ao zD%GNVR1=uR{&MP&RJYqYr*C0hQhzUfU1S+rrrCOa|DcF0x!kQX^0mp!-UyOKSxe94Aup~KkGWQRNRjGX)e%gh%OS7(OV>3g0Ar3a&>*wpJ zzWa52Ej7#a&e8qjj(rN7;8Ng0PxR{oYk(2qZR1Zb1l+!MN7v>(p=iB$-Bb9_QHe2| zL(d~b$Y|icNKYhABGLZ=1&oGGiAyd|#azA&oN2Aj4?>txSMw--RgF5B>OuRmm73|? zP4;M3C-0`3|A)1+j;lKD*1my?Ac%nol2X!2NJvPhQqpCBbeFV(fQp2GAR-{$Akr-= zEo{0&>D&mL+Q5c){hXO+<~{E@&wQRU?|;r8GtS<7-S=AeTG#qsq4!fR0k(j{bnf$g z=kgrFIlkRGEmz_0P;T^n*3p7q-N_~sRoQNpbOKeRW3jr}-pLwdE;QPfj7ZakH_3bI zdO)hlwDW|)+O&LX$fUTVG5*T3n)cv-o1?Vo7@{vnqB}Aw$Fd4<1T(scm6@zfa|5ze zHT9;m&-NFZ1@_23Ohv>;R+@lixs3i~PfwWo>? zt4hDBy)FF}_WnDCCAjh&Noh+N^$5kXIR&2Xh>ksw0pYD4{e%zuavJnCL zt|zvmjm#XLzTDCWbqEbeAfQ8)P$z#uMyWp zoOm_A!TVgAUeD{-?RCoFsq1E4sSNh+Sopn1yNOzi=23o{0k5+iR#wt9+8xe+g7IiY z=2sAsza-@-B`)&n*nig^__VyJ;fJR8T5-ZX?C#6Skw@|~g(FT$J6<--W6NPj-6)ou43lNZ!;Yf#V4q9QfCHlsbsc+(jZQoYBZ z5r2O0R|_UhGjnZ_WCjk)=gBsO*UDSZ*&F}wbK>wno|6$Z=J}Ni9-|R@w`s&cnHwqf zmyJGNfg=^)h}KxvRAbR^`jOuL)0Ecf9jbHh#9Z(_eeNbN!3vPwcZU&SRLq2jg3k zU7I`>9?MaD?Ymdd9`$XdVdJuNQ-8ctm{@Qf?=NJ8SJ8}cV;xkFH6eyf8HfyMF3(il zRE%{2`+=(Z*9jqGe~>r%+YyQR?TGBy1?&0147wV&_C!r<--AFaCJyU|^A$jO{42ZY z+9d|PL3*?5lo~CBx^}Md4PNfl9_-dWv@c;jTe)T4 z-S=D(XNGh+<-Pj-I3iHC55knUsJup<`Iwx?a5%%2jA9NQ{q9k=0e3TknG}N%{E2-rR-QXFq(!cP3!1>V!$o zz_Q}i+AEEkd5~CoST#fy`M6%kTK!;aGX1goFtsK9b;x~`;-F{J{!k4ehAWOpXv@Pw z@{P}P-Eo`kqM_9Lp^jaw=^iWfYdrCs7j$YjxTZU5Ij1X_X@+2%N=1V1D|_Uor1TX; zoBhnd@*8r& zH|lIokmOyK`Jl9^&EW}Z@HWZ{KXDI{{uS{PR1BN+UU||4ec{jbh+XgVv5OTyOzt!9 zVF?6Qo%z)4?Ft?0@^3uDi}KzYmwelfHzUNE3cgzyEYSEJNj-(-b?4C`o5L-<%r8&kGOWysyPZSMLBqWxg9PEqBbh z+zU)fXg-sHp|-oNxK;+ob2uU0zn_q9FHxsxIfFOT(8dp)sL}FlP5A+QrXKf@T&?U+ zcQS5!wSI4vPPmmV_vrWW`l9hY*Lt;jNrOF7C)FwT++D&bx_+G)99VXtH@wG>Me5{%4!V43Ln121O4VyY zBzQl(jcdWSeA8oM(`a+hlqUIZ`bsD?0j~Q=x*Ab~#g?F^rN@x08r$We^bbL{3x(9& z$2Qh^FJOz7S3s*_8!2v88L?Qt{x$mK%;pVlX2M6|>-n!xx@E7synzyFfS7yOC~)8r zU%dMM@w!yyOHY}}>2pIRbQ=Yd>*Za~z?DO?q|;Vq<-23ft5<~{4sYoXvZ+Y}FfIgB}mX;}KI?$~5`BLjBG zEm0^q9$5K%kFNZx!s761)B8+)E?1RwiTvwXqZDoUqvxVB*~9aUQU%lBq8QD6N+TEJ ztzs~^B$DvZ_R#(hBwZ=Zp}ap4foTTq5tk&|@sRC%rI+{(wRKh7v5;Ekv7Ed_4M84O zH)d7e*i_7?IY21Kvh&rfwar3qLE}|&BxgT+#U|G8AaOo5Q6lL0$gw6HLH&pU3)Q>8 z@zs&OG4YIk3M`U&XCoy67?^_1tPi)$D{r_Gn#KKubE&X$ZZQOT&PthIk7ocgU)OAzmQ`2_Fho* zMpXp@Z_w1UHrc~|jyc4#SIA`8tE&v_>gsaWZk*gTY+`J5arONWI`>vB(uDEEiHmAKz@AeqL-=~cE0pZQ z9Z5=TkCeX24@7XxlK$>~S3IwIct(9w%J0pCnrlS~SxRYM6=}C(Uu@PotA{F!Vj0BV zJkxrGS|492;nJfaJR82r{9=&iPN&oBryFq@xZ5s@B-`Ipvj%Eq&;4DG!yb#3>}{yG>-zaI?Ige`%|{gh?GHSFY`@LJ6N)_pH@ z#5YB{geKE#{X7hD(X1t&zp_Oe!?sAV=qYA^WgCyk#_mycd6eql^g{dy33ok z6Szv=nahW7y7C1JEXKcJNx1UVEcd(%A`4g3R=|gyFDSTRI(u&Phisf~@grFHt3uP+ z2 zT83_^xAn(%gq^brE+nJ>*bv?JI{;w08X`U6Nq6M<)P$F@ z%Fv9ul^$lq{4+Z%i{MzR{sCh21_F_@{Dw@4=Y?x7NFWQ7^s#SKIs0%h)>g556SHHT_7>5>Ns=ycKFecytshN+7|Tk&0C zh05~{4O_UdB3w@uC~a+YKxqpZXQT>)1#U;qo|#bATpEEba)v%`J0Q#N(;(@)+WBRw z$t!IXS_Vev1r-x8|APc zucQe&Kf?xJ-?rfzvJC;`G?5>373lh4p&^wUl=0_jDvA3r`yz4MbXo^<5p;s{%7c4F zwnMs_l;tZ_epw#6XzI4^hL~ZPT}s1(Hs2LX%8HzFuG)EBuV`vaSDXNqt?3q z-fq-gUeuAL%O^6$p8wtp*bpTYoMQJ9Y;7Fbe8T&3vyqHlJM3gUlx69-Z^jFcJg0_E z+Elqve#hYGF7l%bpgp}Ay;ODc9Tt0Kt9Kw?thE70o$(;yDqDZa>2q5iuiMz#^IuzD zDt$L=aN)bpP!ygqwvL0l-R77{(0gAnbdJjraR+7q&DF&U_OUspK%XJ z!Kl?kW`ZsNqeomw3ERe?Az9_v#(oz?Et$u=2pK`eI`xdf8mI5l$!~^k0Ru*8Rs>i|goV z?*@5rb@)ep=fkg973VJJ>$eB|R;ra>_;wuQ?AqQb?0oG5=YA>r@loQdn%9tFr>djf zg40B`?RZT?{tE_Jj@Ed1q7k&FDu~W4Z@%`fpjEnh{(N3q#e-_0-m_ycVF@-mU;!i| z0(H1x@F))r{F=-RzBA5#3^$VAn%_ZV4U<#PDSZXLtaUeM=y)<&~+aH5mX16p%L7xet-E$SLe zj+VO+g&-rP1jp=S)%F<0oxMR{tf=huDqQc@yO!rd#SMwp$=2%8oCX%{i?wANro&vf zCVRDv_~tDQENe6mtLGovJ%%hBBUZWY%~wpb!qbE_p7k9g>VpZkJEx7v!1{Xs^0_(t z0x|}(%51_&c30O~s#300laoT@?xDj*AF-V)1q`2kav7PtN;9O%mua%6xtDp&5+sKe zoaGnAyKjnitK+N20|rKKMkKQlT7+*^AFKJK3V{BeskIpA%N1X5L7? zQA|q*hjRcY{`y|=LT8D!9(}t7U;C)DK_iDny(7RUbtZp;obK(j8;$3SI!w6DA0M#J zyQ%EC?4|7^&G1d3T!#g(>v&q`4rwnUEJjy>$r}#4W)HJ&&c--@p&z~4xEwls$UOLq z*RXPIK5hffouSXN2NN<%lDSiBD>xbAII`*iXEkF>ef)DG!ej9L@(S|+#g_z&z* zVFqT${F@@Fp6Id{&%xye?s8nb?J(d(`mZbFI9#eNssEB4n;+?>W+c`>dS#?=r0(MPeBynNcG{Zf|S`&j5oE=7l+JKrMRQPB-UW*utcZ4c($?C(C z>US6|(!);Tna=iWjwU=-jcW*k_sv|N@E_1=E0>t@E9b75PenJE9TBNJESaZQMChQQ zGY56)H&XURPZiP8)BF@FC(g4bYnE*2+jUmOJH~OgTLS|FdzYj>JPi4vaA9{ctW>*; ze=G-zd*sxtlA3HC&vip|SGB1%H+%JIMA;EWVp!qFd==-B&3`W|v{!b(Kt`y{oMDgTs^%8GmgV z%4lTvO3q%C_0PAQJ>ADV7AE=M{gg>&!dk zTyuBqH(mo2F-l0as(=^Ze2m^2Ig-ulM1V}kP-Q4 zXwH||mW;_+sF$Y+yYA*`yzm?3O+Tu>gG%sNL3@_jJWg5B>mL+Vf@Xz{t}1R?xQ3{bEMZmzy%#g1)v|wf ze3Fw}zx-$XEV0Ye~hofF@G^>s7u_>~AP5S5 z%PIQZ#mUmCkrkbL>$JIh3k>?=%sJm?D;bx!hewM?kIi8x=8CUkdt9VR6-NB8T&FT% z&u<|sAEI!5@==$)^Fyi~h$FA|sq=D;{!rl%Pnll|w3|*1(kUIX$j+I19fa;nXqC2( z9w`ygdb?S*>n%6rJ;duZj9kv6+_JwVOziC1%7-t+4HYs2&eoGXJKbO|Cppf1{dUim zUvPS@YPl8Ys2e^j{~S$1m7(V{eaiLk2o}yWl#9j6M2=@*>%I2%H~m{?MP={Zb1r@8 zHEVfkwj?-NCO8~8i^fS`bH%*KEOxS#(h2vAl!50?^F~_GqFgBs{W$MNVkH{csshrX z>4rg%S~xednyu8Kk4w$ixQjvNP|~|TeVM)IEydg76({>!3RyW9rqlrRQW);cd~}ph zenCXRr@2iXx9;I9wLrai@#gE3>wZC^vdB)0dABcNNd)XY#)x=-DE`itaNYA~|E6@1 z$#af3MqZTk34>ly63!%^woGm|$t|XKt{wU&)v(*e-AC-`Z+>N#OrgNrs`@LHH3290#QN_{mV7+$k3(3CH z*$jf+*bBwp*8H3fwuX&yS*MZ+9*C2R%;i$T!80=P4 zbslB(A)Zg!peS=iu{cEpTn}?-e-*+Sn)lBTw|IG8a$yriRW<}cBgVeCMxhf)ulp-L zP3Z~`3<5o0FklV!qs^9MG9bH}3dm3KP+^sXxHJfUJMRuY${f9zy9o_yJ(V-0u6u@A z_94?}O5a`g&V0ljpT|h8k1}1PTd)OrbMC6sh_iBCn?m<~NCSo= zhT0V{dF*yQ^(i|Ai352(SW<}R+A_R#9hLeM8jn@5PGlmpEjnTiNGFs1p^IxoJE>?c z1_9=&e*HL~c|S@t0yJI`#|SSBbQkz62ij?rsHL(w?~JL`OKN!5F*0|MjIpVLfFfhg zIOVq4eBd3g(}+}FJ{;%H%7NR${Mw~2En#JgzZM%@(0*shC3;`2M)%5SvB1V- z&om=3;LxSN1Jnl7J=Z#y9YMaF^RG}pcV2v96lQE^q9aMCUMR#VIpo_E&CO7fnNvK(WLj?&G;Z^w95x$&=q$Ks&0Z)n0Fh z-ZD?l)0mXkc*#$*D`4U3rvR4M5}++(a=MU^kUMT|cpBN9pCTpUl=#)WMn|PnagY9B zZMi>Hwn}j!Ze3xKwIf5&1$Q$T!3x|JySSK-Myn{eHjtY>yM18}z0(&=b)%0xtfbN2 zXW0u1Yy0lYh60Dj*na2>uywsj5W6v$;$R2AUbMeJ#`Kv{18DwvZJH2P5>+9WVRr1y{gYYH=W-=yJ=tCG)@=$&pG_|kA z-;Ye^s5Snm(qf&&vNRZ2L8q|6q9QG>^{m{rwYKPo<)!DaE$ziE%#S@ay{V#GsdROM zxVcOw=R~d^sB!Yi3$j0Pc&)tCToUjCqnEc~_^y(}J8QZ)qjo7Oo!PutkB)*q2h^A z;yn^_f@Q#JlzO{&HSoRWQH6WeX_V+#Zr)QznoNZ@QZBJ0fJdtj9l*zkFe_(&p)gCl z$V@s_=v@|(P4mPdQkkmXz_E?bo1X-q?>9u9=Xdr%lArV{At*ni=99f)_>q}w8B9dfps_*G@jw8TM-S=}?Y!W6;#96SuB+4|4?3aN0ScF8{hbT{9t z>RoxiL#Hy1$7(r@l?Nk!)>CVc^}{~QQgX7)$BuwIlg@Fk4_aH*3hJ9w|&XH~Fvm0Gj0E zy%f8rz>JVOFY%UlmHxLO&>zqSUWp?X$3H@q%I&dfRStByOLT{A$d7EuzVAw?4WYWx z0!fSwS)SmD;*Tt0^{s-AYFb-?rsP6CtcyUKLJ4mv(A7E+Uhcm@YsijO@X5nl1c&h> zPd)+CQX^|K2-%U{@q&r7h=TSIo)oJyWxL0*`GDUbAdzB+XHr{gL$2JLKJz|7-_5y& zs+Z<)L-*iPOnTeQt93xaiTVD|h>G*iZ-QNU3MkZJWibI?fmYgA>Ejhx(#SUal zN>T-USmTyR7H{5@{|tEk_~tLt&kspD23Uy(92YmNMo9hyd`}p#;08jO693`vAcq9I znCJg|l>YKj|MrjnqP?)ZC*gKpxKj=g3Y^*5e}-v4^yC*==u4o4+%t)1NwM4}FX+m^ zT~rYBjpzU8f1KDq><@xIs~i9)Hvd1yp5S>2FwtW~mf6k&MT{Q72ea#yC?OheNE#c- zivKjsFaHX2G6FjItIiuib8!l=?h()&H8BJJi~ez9$~S@yNB`~Z^0)WJO?Tujd+=OB z4!$c8W6HvOo7Y^;(ZGKm@e~9&^uV4d5CgrP{&CbZAbhJIuxt@1(FFf(uC}*)Jm-5A z5CA^z**^}b0pxQ?LO{XiKz7%C|G-23zvuP)L4hlio%^k=gM)`9N@o>~3!lo7oT0n? z{J4fRhlT|k#rfW>KIP7#EwB0U3n@LfEm^X7UnE8LX-c!dYLd`=)pR-gjWH9?0Y(J# zj@c9~jQUnRtMiZvBbpywK0Jutwc2{VGVA%D0=K554vEVNd*VxQ<4d>$FA>o)GH%1- zLo-y#mqM8e{^70A^!$)-*rRQdr@Zv|Za}1%Mo8AeZ@>QUNLv58y=j?`;7d&WkR%Bg zC;2$|Ft!=#Zk)V~|GfXzvwwC^Hk~+(N4I+AGJZ2HJ~HE*hTeH09ylZ2c7sLW$xem~ z{~NsEzy8M$9T7ZpljQ7a8sD%}zF~#!LW(YfCS-~~CxW=5PM%}>nk>)$zvF=Y>p(hO zeoV_iEs96SBHjexe)@47+SMe^x$hO!Te|L0%!DKTMQZu~2z2$|T-aee$QMO&8DIXS zL|)aJ0y19j?m=v#$pJiG1&jFe9`F8}fA2p*fT?IBKl7?IN%+~P>nr0%NQ8n~ClZU9 zZ3T;o$us}>*KngpzAXjh+maL3|EFR3J`9iK>DS~ST0C*Ns(6Wi_EX+GgHMb9`~|NMM>8WM-UqRK^z#Fv)AP?-)n!cXw< zZiA?G{_adaAV|g2lw!iKO@l+q)Le3?&H55Bz{5@B-O@U}Y;MvT;UR=dubuT*@Y-!w z_Sxz7G*J|_WH~%jSX^{3)-4w-mdE8UF5^*~mUN5aMe5B_B_jBE!GR=lRAkP+pgNXYAON@a1{f}S22HZT1 zUsS)M+@N|`=!g#VC;t8Kj>Jjap}&U9{j+mJ#n~f3al5_=Yeh_uSRXY+0+00rB~TH9 zlqfbJs9wX1C^@n!Vcn|0&GzA5rXrXMbNTqML#wGQXr6h+0`=7cWa)! zno!&pMHquZhP8w4i@7{CAKJ-adYhR+LT8jXFg=|;R{w3bQ#oh`~pof+?s1i~|Lisi*s~#PHvMKW``ZoJ!39&U;?a-m2EoR>^ ziOnl(zdy$R$8@k9i~L^8p?|$NkHKe5`SdupS(GL~vF&{4L@1zuCCG)GGIASH-kUs# ze@RgPbAp^Wyf-M_5(7YUg$l1dLtcLtmX^~!>dK^~er0xulMJNDSTw#7^=bKxHShoPX+q~pDe44>i&ksFufzk+$luOSj;eHlJF?I12 z@@afs*$DpbLA)IGEBgNOY2LS;nd?doTB<$R=LQ(HYH}p(CNzU&e#+KoHx>%`oX%ZJ zuoLRaRAUh$wp$2( zX8vy9Q?ob(sWT532IBPzQ^{&4QS?-!*nUCe>(|Z5RRwlQE~VX#Da&EEFFr!bu0hw% z>i~7wFu|@(uYvMlaBB~i`$lQ|0hZhZI)(=;^Ebdr6ur3zrhb-FgN%B+fhB%x;44#z zv~r>rXJTz_T<2&&dkz+{Et|bak6heq<1$%1;)%2T3znMznB?HKP1`AMYH(#@2CJpe z4B4}MA-d%bso&lSs*7T$&bdJ+YVSTpT%eu2?ife{Sb>-mzg(z`f_q%IsBNY8z`J!8 zu-l~8p$I+}0V-#NxYNsvwZ0Mm94kL86219A)KUba6K-=>D1qpY%dZExVqce=mRmMO z224U6_9i~>WpQWn3!6v?Vz&x;1P-%)=Py6E~s~Q`kfYi<4ty83=kkvleBvU# zlmg4R1lcqUkgrDarfqPj^UTfFDL9^~O3wcM@F}k~Qht``4B9DxeALWsFZ5A|XfMyk{MHv%?MDdWfTlyFq=f+JRF|e<=heszrMCT;xxRTabHrc0~ zJWID6rdY|06Pa1KG=%2#q6;lAF;evO30Pxq5Fm(imFzCBd8JtE;`UV}V9$_T&>@BN zTDHn2)Z5>*BljJ<%A|r zF(*!nH$J1_HYH^JCIq3(v#ES5B2rP#El*i?F+-x@#R~#UCgQPSM z+o*>>@D?|i~405W#wQ4(gpBy8I9hRjEtW`#=_g=f$ONSl_MNH!z+iVB3R}6 z-disp5P#r5dca^E*n$0~!PQ*mY(8?${KixldT=N)bNqoWs>8)~MSJc+-m=!x3Q8q@ zaENPs@`}?K(XDhGeY7`KSFMD1K2l<63TI!sg3!yzgL;EcJ2{>v+PTf@PX9mVo8JnE zCnZWqH1%Z&!v|raI3afR0Qj^y2MPi9u6PpRijjE&KAw+4ZyLs5}rH9n`Eo zX4drZR3%gmslmiTE=yMcgIl~xNQW=@^V893@cjCKUGy-c(#p}kS52AGf0QWZB}*^9 zk%*s}%29&R7+nz$#63$B{{hV#F;&T{rgy@>+Nod&tWRi#rZghnNWkB+h%mpADNA6z z3p2S!ce{D0vFPc0{D_6*Tt@6=V`Zwpl*Kg{FDJ(v^r#Hgk_{1~tKa>IalhIrWYIzO zz1YgS5voi*Wj@49ez{QJ&3naTJOY+v4hr4sujvv|oo0Q9c>h_`blcCq$t=isx9)sT zEg2%ZW9C7P`#mfz2Z8MSVWZ1_@iRQRa$U)n&r$SbG%equ=|bt(&^30roVFIrrN`(z z5SKH%@wCXz>}>~jDO{2B6_^FqUC3UPN;8OD0d;Gq;GHj_-s|)p>jwVAkeB)!|6`c&=*R*L~3)ZiFQhv&vcS(*fTqb#+LeAl`E&W8ei= zSp~C%P5lF5=n-kYOkXT@=x`G-qLV-9;f^Rtc!77vDbrcAPF-ERz7#P^4n}4v{AZc@ zUtiy0FY>75c?J}bZ&x`GJC)*Ly;>`12Ft|M{Nfor9p)ARCbgE-h0I_|W%oZOA4?WVe0DU*HG`8Sv#hFwnML@o@W(YP1( z-%@c!-t~L?r>1qj*`F6$2}tX{j=$Xh_BzAUJwsLHkoH~K66c=W)?2ulQ<+@mf74{S zsfhFpLWO-vR($pwC1rlSD|2H%As8deyZMBP1b7XhGzobhUeIvX?m2km-al*_na=O& zhqp2GePW)+Ct7}#e85e&!?Y8boOQ;r@tv0-Sp%OFMH7?Mx!rl<9~BV=ga|HbeX$>9 zh+u}-R(uYI<6XfqE;5M1GxOfl7x(V)CW%iX>E8UohYY$mJgoBf1v0P!C4Ol{CoOJP zKu3Ha1yNfdD_!E(S(iFT<9ik5(%#?qM@nnao$iEvh6fEVo6mE2aivgoJ#Tx?Bi!w= z;3eZ4+v`2+OF>IaBPwhVs_(WyoQ2p*wHadEcjnU#RLMi()(o!iC2T?hSq7Vx9*k^o zkA+ekt!$Dne;PTvvku;eQ#pn02&?GkouA2dTP+hV^8HjJ&djbv<(oXQ^1fn;%hW^O zomtlsde5D-4!nZ8Z@?UbvZ-2?K|CpP-=b&ol8VfJP9pfiVBtN8QE$Vm9^&L1e1*WeR}0Z9yLRu zSS5yNm!_ezW8PuLqKE@$w;{_%@(d&3&bpAf<(A}$S5R-)1(R0ykCe5^x_J-!3!f%m z!D`d2$Gu?0KN)i(_5pgPRS8q|BDAN&a8{2mso$1wc0=h{>`?z2RXC*7&0%<@D!j-jt>XOJE>C3oxOe?U*+h(uS%uOv zNp*Clk?bTxIWew!VqZj-U!%ii)$bs73n1~i8x#v-vy^|+;wX|v8iK_Xf3~rnTo?Z| z?xrL1I3wfB@}tF&1r`aBBW40nS|MJin0oHs6atwqw(vohRv}`I)7*x8y-<&$?Xf38CJ;U=#f2hu$ee+tg zXBu^GQ5>JKF3oz4`&EjocOS=JXTg8YlQ5kutoQ2j-pStrt@pP;quOD>ivLeZyr-sBuZ`|BWYNdUFgTHQ(FdFl;UAc7dTSQmB@E)zYI?h}wFyW?RB zw_1sFJ1=VY_}U*4b~N^{5f)4xQRFeVQqbYDwWza%VOWmNwDJFySC@XvtC(Rr`P$8H zMv>FFYKIQ-6R8JrlKr|H3sb%rsk8&)lvq!azsatjhL8qimF8c%gB?ux0-Z7=En~1g zvj!`~gLCb);{FU4AI=fAkv}QhhMP%nI#PYfW<$3RK?}eu+@(G(bYb;o@J}37NUw@S z@VR6|RsGAP6<2N${WV~m72NNfd5IP_GBg~cod>M$6g-a7+`i zl$xAzk4y0hC`rg;Tp*f)=V~NkN^o0vRcg@Z0B7Bc_}Ix(IY%Ov%KOlN8Sq$Ok*dAr zaFPzE-s-q$5WXT5&6H9*wH88W?l{d2EnFf+c8|FggHtYx4TB)MUO#X&VX`CXH&2XsEl zfooQIxHXXwiP_N(ry*iGoRIo%eRXu}3hC^qeOHP+VnZ}=*~|8J&Esn}UeEr{=v|i$ znR@6R|Fg?_GnV%IR}x(vz9VLw!f@@Ff6+7z`@S-@lzz}=>rFq7ZFgPDRXSe4X+uY( z-E-4=`{$3641s`38?w86!Xw5Jw9;_UX#Z&iBzK2q#^~)(inv+7S-8I5jCtV@-_3~_nPYDJ z5{-W6RLwLic(N(x=WK>8NV3IUoqCb`s2S{EZBRHd`7*PLur>VzP07n5{om2S|}}SFsg#_UO`DXh7^-`X+$vP1dF&!g z<|l$_y3mRLB~og|B;77}Q2Mu5Ce9a8N$fbMZ3cO8MUS&Bm{>v;VD)q5($2~uL?**w z=pO@$k{5cU6My3s^F;F&a zJ%%&oa4R%260T>6!ORg+egrSLL9TRyTgJ8++C2))0y^P@qwN zNF)Wv-TNG?l+zg7;mmlh`~C=UaWWHkz{GGWO|r& zF@yg(l3ia!I_nct3Ej!c%0k!^&AZtR^9i~cdPgHX`^$N;+WeUPi8kXmk&0&%l|&nm zSlBW6PWXWvN8I5(U$a`Ai($lHjb*aAQN}CZSI}X<*wXbN+y&oyt7H5@;xf9{r~nm# z)GVY(&0<>6_^T%7EHW~+s4d%az*c{R$_>u5OtmdcYZT`dnsK!3Kf0AXph&8T7(NSxqBRJ1t#oV*D^^5sT`RC{l}0@Ujh+YYd@U zCLUaP!k|6?O^#`3(3gZF4JFmn%$?s~-gJ^x8YNfvqmarFPYOWD_SXF_eAlD9q<2L+AybUruc4M9^1 zfo4jk(*qPH7D`Ft!Jmu7j0qE$zt2X0{mD)vTx_NFkdi(oY}F2s0NG~jdxobzO9DS`k}u=#3M|roob}(1 zFc!rS9U6C=rNWM`c=x|ofD>Lb=blU#+gTH<_c7Q~?2Bh)I+ipd-N?r^-ul-?!YzTc zFGZZFPsO=X@oz}I_p0^75>c588GQkKOLO1t_3$& zC7T-TL%ke=%Ui%hPz`m2k}A~8Xv14gN5{POOrcv6X7SykTgiSVHo&vre=~Z{qo~h5 zd&6nvjTA}v;-e0Sh;>xQ$`PDF0d&pH?USfao3_0*fv?-BBZZQH4&UBnz94gdeu6Ou zCDEA6rWgv1)}m2KpHTmj$^&nA!*8q6DY&GKo<#|7q`SH);NYm%HO`heFR}aEWz#>- zW;3#2y9?qW+GQ`b75)u9!$qxct0d7U{x-<2LtXswmH8tYBryeu__BdU_tv>N0J{#o zaRzxB(q3)wrZEkBe-iPoRQ0<`=IN)+x6sNo(6OC(oE?EW!efDck$CCt7F-AKE*KbR zHG*5|{bQ#o`uV?&!V_}}Xt906YXCTJ2Hf@o_~fO#m4>oJb;#~A^rq7J?ussCBhL!>2(G$c@(I3H>1+kav?w7xP2QSHrH>02xwoG}LT*uhJ zDK?lXxL9(yj;n|K!}ju&*cJ*?r;e!v6Kce$S4~fkO_tLBsgYDN20{?>HX^{YGK_i3 zUpnFtgs7l5-AcA@e`YMIQjyG*3A@Zcv0li46N+{@ZC=~7IH34jZEU(QL*~jAVb&tj z+SvKJf5bD*q^CGJ59sU@W$w7}*Nge^3uEQ#u~g_M)-uo}kwOuts2jxZxn(G~@mCZl?K! z@}dMb+3{P)xH_p?4dU*Et-_qDyw#Voa6_4c9s2u4wNggy8_KS=#V_@8Lp>Uyo2}pI z7qz%^t7PSJ@TA0?dl4vy(|pcV!#teZqUTCm4FHFf3c@q6S^5^&&~XZh@qI^t5x zvFy=s(J6KjL*rw|gS}{Uw6Z*Taz2kxWMAK{tvDe&)Q5d)%M=y&EahKoB)ab_w7D_{ z6~pV$9g52XO5KXBYqBBY7U8`!KfhZe?n_ahm6-H`lXp8n&En1h?DD5}ODjtdY2Oyv zxPw?SSh*|&OS+c71AF-Dcz#>Mii@T_`TU0PqFWE?ftB52#I{+@(9mWw%dEcF$X4Jq zv%9Fvl6GuuL$5;|U7BbRg|ES*AO74PzhWJ z8g`njcH6%?UyM^DC5#2yRBuK5*LXwkUtIGK+0ySn;Z>_5E;r)hx*Fc@2U)fzLM)_* z?d^C#p6HO#M&;K7@6IBM@b-j_%y9A%%eNhy80pZrp_eAx?`o`=Kl4uL9ly2DU*rRK zVELx}kWrHDbZNP$Z$31cq1N%;Ma5PsMDYWJ#1euE z#eJjMC_?gS!s2@J>o zYnwpwZuLfA;SD9%g&r1z*@XoQ5A@fw?vOT&w38*tn)&7Iwt$$&2-On0i{gq*wV#y# znfWI2+W;|UU(le}vulv+oV5Vkyb%Epz4}WQX8A`>5LJBd@99bLr*&68b_S}nxcIa$ z7xb*|1rK^1NB;m1cjdM1voS!TjJ9j#H(x{6(&Q zM~$|oX5aT4IteT%4qP7xC9Gcrrv-R)6P_SgW;HmF(;?R=y_?(;g&3&E5ON z<8S07w4BJBPmT_M{zHSxb!zTkdMC3>cb*&DOKR18)=isOE;mHXSzYUxJ*K|e4C25^ zYJMSrp%nCybe|;N!{3Q#J>F?x5KD-pSxUQ4harNai{@myrH;|La}`nX=0d}MKW&AA z|LK|&mMtBD*BL^7o7Q~uBOcq>vD%SuBsSdWWNIXqacKITV|ziRn&o)==Gd1c7xg^RtMRQn?>+Yl zG9>&!YB9g8O?MlKU#2+MjbTY5UXAyem+KHAUTZ464lI{p05l0g0jcV@(tbSD zA>J!YRHr*VC3n(cy`=Ic06CM!_h$HsydTU_=8swPiV^KBV<77fV)XZ?x*(w8mgn?t zg5hUG3G&?D@laZ%=t95E)1FjrU6UNy#Pkb`d8J`Iaq)dsl8-J}LQJ#Fue|PXYJnMO z?tDZ*$2gm}U`;gr22+YxUU`SwWMaAd>fg0ejS=PJrI>)NGf!#p%?IeqK$6=hbc8kU zqJZ6dKC~9SKHqxFd3c^8)@>Dp#m7gF-KnOig63&Vc+s{9UxtR?ankTPkG6M$JT7rg zrg!+d`mf&qif@^dKg;T_x~iCV)pOf!eyqB3;?4wnx^U6XP{edA;~Z#g+rnwZJZ86Fe(pGAZ1u zwU*X&lM+}ylucgg-)yy<{Q3>*oSvasp79O?`*$!HQk!#QVcLAz89;Q22?5+AtSXRu z^E5Qiwxvvy#+q_YibQ+IDL42xWe;F%=*AMug@+|}r@O=)v3Zsu2J3Em6s4UNjpa6( zt0Y^MpL|a6?TkSX<__pfEBK!(hi8C0spH7Z5SnhBtA)|LlFVbzzr;W>Rk1~V5$1h& ztT@Xz7t1!RGsoFp29|M13ICU=NkfEFAr4-(Q9IS2AIS|xJpcE3_om`GNW1yRNFbn|_p2j9?3@ie z-4@zJ?Yq4ZP4npqoG7t_DjbUZd8 zYQu;#ZP73>Dt%rhn7i!7H--C7E)mR|`}{Gw#v6>(k7>*!0w&dDnPaKe%~CIiqf!!b zmNo~(BF+$mj;f7wy{!Fto$yw9NocZXm60OO@SgC~O^dF7-?8aXL-f1!e-Uoo&1LRT zJAShM$PEZ>c=b4T)%W)GDagyq?^IM&=pz1KIz?uxJJig}(VvDysrxpg42FmEIoeD? z4)FStO4Fnw) zF0N>}+?4ack+T?WjdaEkO|2F9(a*;%Y%przBM-VkR~JLczKZBs|Fh@dr;xl)Yt8Bw+sQaxA7*<3@1pyH$=@O8Rp<6n1Xep7-kvyhmoolV;$yx6nbK3$bus)DX;(st|4dAUM zT5AW|#7uxq>Go-l?qW1!uBSAw$18Q7C`P%B)dpIh^COPXW5;tW*7|{~AOABoV!?S` z%t@8~A9n1m8GwUkI z%j;BQ`!5ODKl^WtuFZKV;9pgh*YPHv{@r{48RCwqDRQvX6bS&;-gd(&nu`oq$i^Fk z&6V2^mV;wh9$b2~u@%p=OW$Obd~dINYCn}(uI{1@5t%421q+g<+v89sxFgYKt6c+5 ztV0p6byDxzmPZsos(s$@hgk=d86Nd2k&(Y0l4+d^ChETG0j3C9h@7m?&Ss0^QASdU&|;`-9f=zx9V^vScYmj5nPzqdl57##JjfhUlS5 zJIh>Ac6J!l@|DQam437dW+*2{3UzysVe8XjHjr}Ca0CLcxG z#&j6s7qWkLDdIuLY*z6Jis+!6;(<-|EREgNFlwG{|K;7?my|gjB=Dr@HNA1%^9aun z#`nceV;`M0yu#dy!yd#c?5BOp44p4+;%po`1ft7I{ZYtKD|hK3`a)=g7~Xfh$elK+X6C zFY^;C=2NJ2!(AngE6bO! zMK-`r0zw2`%0(=o#x`C#`}|*x&c+S9^MjvEIHC_=`vs*V#PALrdu;Wk)fe^cm}$P@ zo@RZ~9cw-EvWKqzuArpO7JZ^6{-17Nf7;DBlReI)iD^|s+q8csM1Ec?k)a(M8|_Lr z^fChOXYMfFqk0F#YJ#qe)jTxL(&q@Pv~IuW-H&#%sdL5NqXn5~(X2Lhu+}|3nMG1N zu37P9kb%Ua|A3P{vQH~9V@s+L5t5(^WO4CL-FGGAqLt2TL(68*40OGD?GlxE9ocXG zm&*H}K`HzG$O(sN7RN3D zSyZgs+d@N<)J`L&=3bo?FB(c#{mnr`$1~O?%f-B#x3+P5jqSK+rS%X4PHa%lQ2Ft| zwdz&oJ5{Lhd3XRP^WNc~q|ky8Q6JHlrFB6hiUopJ&UUDpj7sE}(V zB!Bj~_h5FIh)**yC=^h6-@MK1b|^LS&1?N^tDED%P%LWh@e+vUx7xLSxcv`Eso!{XSh40yOQ$IEYx>Bv^{nC32AGBWHcK6qmC z*=p~ID`0kx%l(8~=oDg6@2WmJhA@ee*+O|gS+a|RdzcKIY(lh^dG_aY>$Tg-`0TM| zqyR1*F*P_+nC(t^bR%L1UB!8{8)da*L00Pu^L*~&Zod^1=B9#|W?QFYnpt3ewzS`# z9)uB{?6J>r+U7`>`prSdeWH)N zAWJWLZtp}1n%(Kn}tjp(2w^mmTP~((7)8D^N z+>=aB0I7}yso`{|FSN)>s>^hOnok5cc%jtXBV4w7V;h-L-6=*ON6ZooC zEYm+S)e%7kvDvdmmnXv|!4h&qNOfUHQv|I911{Ypn-3}14^8FU98DBrqH4Uxb1W|I zB`mRlL5)-<^_=UXLx*rN(pCq;cR3|TA-C=#uzAP$Y|#Lmub>x{2qJqhXr;u9THoC# zBTI)8&GNQ45q+^=kYqy-ui(>Q8BG_FR0YOYiVE;vC5qXn)th5)rZ-pEkNldN32P5C z5VKzG(#=v@sTR7)*R3YWSqG)3~^ zg;U6Yi_lDu&oJ$lRJj3L3cqhS6lQcap=(Fd+hX8j;U7>-VftW`aJ2=vT@R$P_&1dm z)|`mqANil6B|avao9XoH8OVpqZ{@lNI^0vvHVyDnrPGI0BwFxSxoOJ7VZ82Z8-U>F?=|dzb+rU{3}5IhY7qi?k9(p)CD+)q^1!BdcA?^ zjCux8CUdms2E*1{KITZ2k*00-E!R%fcqk40ioy4Qd0j!{GMWcT-l>3A4tXVUab9Qb zX3TYCk(3)iNS08FYV8M?Y?q{34PTVF?Lpi^`9gy-6%$6#<)0;`J;YbF?_%*A)>Y0P zDC^Ji9worJux1Cbv;(rGpC9y=8Pgmc;)oof4%)x!l341j(9^SdvS@b&Qqh&ytHF#; zo7=_pTH)?+{&DpY3O5sX~BhUUwMj1`h@G6GtYN_+6?x_rj;04BsA}9OZAl1J zr<(V~z*4TN*^qKgXv~%9IGU2Vk&2P!4EtpfzTx`1Luv6_IEp*&s9Ez{kWXTGD`77B zb%RsRT8Idkg2V~RKcD2P7Y&n*Qw8FJsiTGhEE^?iK1NGOziT6mnfpsdiVr{8W2(B- zE%u$n8g@uNO)PbLch6WB`IVV)(zM+w8Bysy$%mkpA5q}s1}A9UcdU3Mq2wkBf}SVI zb39{?Aa$#drJuZMs-ceXc>Sm%#hbbKx!yX3sW8x@9P}~a??sdF>!x4>t-l5n@<-bJ zLFCDDJ9^Mh_(nPSo@&iR{=#pBnS|d8GgZG!<9z8%!K1>N{331fQOFuL|B)V!rA7}& z>r{h&gEclsKuHmWE3osAFf?a>rx=)MrK6z{KN29?eydeP=HlzZ3*#@It$+wm@5*b6 zc#e%%63g(}0tJN$tR~4kGUv4uN1 z;EmZkXk%Kw$vG@v;K>?rLH7iQPil#mF0un~t5(hp>SY3<~^R(;qrJ6uzeMKFN9E=wT!w$=9+SMX7>s)xc?c;drJvzD5w87bQ z^HrNzt*Axy)FJDU+nBmI-q2(6Oiqo)ze>6 z@SS@@dY*LnF4a69N|C!)FyiI#Xi${RBB2;Wq`%_=?gR}QOC1=uv#}I_PYx7nZ89k| zT$(&!OUc5dhKArti(pxHE{C*%&b0O@wcI6f4E>2a(*5(z)$tsH_OTbG<6hxsC`NU2 zW+bQTNe-j-;+8#W#*fC>$BC{)j8-q}gp1j{gmx(o zkO;!}*E0vYOx{u34Uj8CKmjOPhZu>7YWsF}?v+u+sA4_13p>(Ts~8sF2c6>UH?K&rqt5$T{a+WsgdBxupwxW2e~&RrzRz6A2-x=TV7F*UKx?9=xCW^PrU zSdRIVqXBTxOBnJq3Qa=jZ(}A+b&^%z_pE6!i}Gq_`{!IV)p8af{FdA6aH=AY7ODH1`RmGB$1Fz508k0I5MrHDn4B!8$Mg3)>S^bs80H1pE z;BrtL-ub|JR2YhT<7}SLFj@^X%fM#HmK{Plm!o7T7iFp9dQVSO`JUuO4ST}f0wLIV z5G(8?Xd}y31-UxU4Lb?nBio~v{}>EZho_(bUY}xaG8J~$w-WYP5%M45j#jh~yh@7l zQe8G~IO=wp6MWUrnAH*uPGoP2J3E(y` zD{otvmE>?6YoOboLB9O?hm1Q-DtO35&i73d#4WOZ+>wV345#*qH_e~=jm;fyF^a8^ zcB{@Vqc+-aoJ2)OIDu#{`62m*CPDjLL+Za=J;bJncXo&U2kjuVPpZY}pL7D(wBZWM z0DENql(>wbk+7fVXi-3z#mOCLUxD41d>l;hZXJcc*TARro2x?ktQrrp_+XWfi z_8{BCXq2PN@Uenk0D~X^dijEx8Ic8y1>l-}p`%4@4C zVz}tutUDW5J@X<=aj4;#2q);fjnB?blA$@W54A2A+SA+Rp=Uq4Tb0jQHC`D7%p)xo z@X}~nqSP?jD8tBEg7fmpR-9_^*LL1Yd%RK=ZldIotK@6>!7nK%re=t@IA>^u5S<>O zo}>QL&$ip86Zh_cV#_``EEes(Ze)MzH4pT^Mog5Ut?VEE-CXbYpcFkUNf^~Khw*nX z{JT`=+ME}yl{P^0kxU$thjtSMef zwcGA#0LsPwPE#M2VTv*@O$3ODHS%LSv37_AFE9>~lJt_xJak&bF-t^Cm8{m$uRJ8QstsC$$7RUBh>2w?dXX+<&^=+qA65 ztWhS?tLY_A_YJ=y>8C0ZyBtq6j5KGbrar=6c@*DgzKTh$T#*o*QhcO&KxVY+&!U!G zb2zn3$%Ol==y~X${_tM{9uG2CowCI{D5gvuVQTPM~r&=TMJ}C=FG(@`^b0M`^R6YDJdk8uad8dcF_!^ z#W%A_R3X4NTef1nz4fN$aTOvU9c@thM6r zbt8)Zf@E86>Ytf{uqU|g+I%#*TYnIG^$eP<{&V3G&?$Df*(=_V@U(OxV5k=hhvzxS zx2hFJgiQn}1v1BtfrqWC3v*-yI*CI0+k^N+Au>NwkDO3E^v{2eyM;MOwk(CL+w~!0 zLJan6Kin$R$}0w@GeoqjMLih*a%54h6Bp@m{OdOLW`yEL_BCS8rP(YtQoDFTjpP)u zIl`!eKd&eD4NPia!lqY2RnazW&-8xq$h3VPT!_K+i;$k)-#&p=V1|ih=u5QYUuK0& zQs11LuVRYulTuTX=$sYt=oHYkvkbmO1yUjk@|d+k7{YoY1)hUFTXYjLpq?A!l}UqV4C3&iA`f*Z)gpZ z0WS2i=ECxLyCBhPNjrpR$}Lc6Aw8_Lef(}s^teeo>*+Uui30QppK~kjOHLxHXS#^j z4U-;)qPt*mfm|>m#;4W_CWOq}Sjt2@cZ+bh&?&_DEyA=i19+1U`Rq4Uug@djH{m`v z$TBB4KT5Xgw8bYe4mh7LkXaP5>iY$*rk_3I%!#l1Uru0(kdC{zySa8fs(Az8f%<0q zUmjasI5NLd7igB?QVq`1e!9R&j!LCIAnQhJx7}i<8F*KXzw_g0e))l<7dv0k)80MI zGQPgh8a}IwEJ-<$Yue#pE9F|D=&=>&nT5Zy9rtL(8+F}%Zr=EzJa`-Ce>6$tZ!?@3JY75E!Z#_xZH`!ST&<=% zc3Y~PLL6@uS@#3N{;mq)InNKf+D+fFwEMVy!y91XGOcic_gji(t5_Qp%Rb&=5#yP? zt@p(k`*Rp5piTLs9>SV=F4k)mw$VX$Yj@7$&h((*%CMB2#T9Jq_n{6|lXH7jJ2A7y zD52+0-ux2cKDp@bk$_9)e52*jFMsGgCp;EJ)2MV05aUqCHzf=&72tZmj_aEVw&vek z><|E`iFWiy{PJFo7Cy0RK`)Ed=CWR|FRT}K3AQExG5|3>-FU6Z5Qa-;R@O=Pl;q7= z*kJBatDzQ8Ge^`(MHV*B#F)xC2a+*?%H+XP#p=VQ1|TowRHDn*j{76v`9#po5!4h_ zdDr+%8{M7w>9+CdzRv~!pyN2v3|bO)Hx4BG_c%!l-RqcV&Hh(^W4g(RJT*JA66h!H z^?qBg4B%EfzI09V!!NiJL1WD)CK0(xyl*u8*4t#T?N_3iDktaogjx*ak;|!)k#l7B zpQj?w(q(FOyCXZ2Bk#sPQ5>R30@^y5BtB*hma_rmE;M93!a1blA=;?ET%BVSNR%EbaeRp!-=*HdyNp<4=E&C;tS5}R6 zeQpQDZ|glQeyurZQC_)K!{<~LdiRRqHHQ*QobB#G4%UI~VjXS)KZRCET8MD+PFGK| z$j1~5Z_1LgQw_SGFwj1mW}|<0k#E@{)LXmUa))7N1AW$2AH%ZZvggew5tMaaf-iD8 z9Yo113<8Qf!nW}u{YR?*a9GD{NpYH`du1y5qkDHk&eV(fc-!d_Onkb)`YMC@>+=xjXzXk=`%Cp*kO?=Fhxkzm$ncGAq z0Aol1!gVVb=MUcfU0mq^@dH;?{rrV*o=*UTZjK!U;R^V{_8ipSMy{xkF@3QXm_xm; zZ)SelXcl<7+QQz;c063Kzb%EBsa&EIHmYC6fUVGlT%51-@% zLQQELEpkEgL$`#96)za4UC15t@2Kki_*9{S4xG)~%fxD8b2lb4$?Hr$2_Lr+A-xd^ zfwOz2n{Jnyx?NvKLbut9KI=Ngor^srq1ht#+f+IlApc&>JA3o6<^i*0%>eOqC|?UM z4zuVh{Lc?uznts-!1GT827L{S04A*pPX2s@0$5h=@G_$N)n3zH-oU)uw|$RypZC4L zF3jqbJ*O1&-4hVs>*sR@)>pHVA~g(n=EfX$>ao1)ZQzNVMv`MzAw1AM^W2%!B%?Y4 zS$$LC@(ae5)V6yfMd#bx?p7TQ^YD4lu?_cyYU5uoDJ{+Sxm!2F2nJL90~vaC?~pbE>^t{*s>fYznG=obVP0Y@ra!_!#%=d&CVKFf zdi9olKfh0`%8fAY$ZCkKUfT@Wq%LFFs{PU>1q&u}Py(y=id`pX@&wyGu z^IYqX@tG^Fc>QTck7O0Dwvbu1babzfYwQNY%tf?Tq(!0^uwDvaU6Sfsuz--6!%Git9R&GE5nzu@9FFK_?qylu3a1%cf#xgVVJOlmmJsS71f(3s)OD+8IFV%++iT8lW z=zli=y9D)Hi6Ck$)8`&3aex_p-ueNkTATMqb~K)*p(KOwrQcN9&Dhko$zUP(hcYGH z!?FHe-FVF4_kG3se?5d77`EZxpCk$09rb9=(i^Uj!S0u1w7$oPlu z?7t+2|LJ|7VO+}|sA&JEUF*-p@xRUA`!|2|CFb?VmXP`1MZEt%-srdZ0P7ms`N94w z=lQ1*|KnWCsWCK~e(Ge@`$dZh7$~f<*-G(d=l#>%{=*;tH^eWnoKf-XS)Hze{U-q5eG@4UhVJE%?+qv&gZgfMP(= z3Zy5SVGOtZ5ZU!0V>^nSws7zl7mt3c-sGCbQ8GHyuHilt+@dj#q3z1xb;Fs!AiE%<+xv?#oB9#=W;RV4Bme*Phk4EK@F9G0X3N{N}myLjT%TFrk_#kADw z(C9b@arrAYx&g$3_FxGkXLHtdQZzId%B|kPVbZ} zfi0BZTM27o&*Xjrg4_3g4aqPIW0kI)c$uMA)7kR^j9hq{OMJk&F?l17fh_knWzpFJ zRqjkyT%knZ?e5Ga^d3*1@=>xbVsinyxmTPQfzUOuew;l3%4@2sD3~pTl)l~E%s_tj zndB*8fsYg3RsVQ9-*tz!0?HR2IrJ#vg;`LG&KCNrT!3n@_4E5sIZ-)??6%L%@il78&+3A2TKzS ztu+=?E0t3w|8*O}zqi7#zpgFZ_%9YNE=2coDpOK(#vurLP&x>an+PmR$aueCXVyBC z*L^~;(cm4meP}e1$u2jsP2bx-jAf5~!`M_hLmSh|BSC)xmrSpzwK6Vkdu%*&JDt5= z2uIc0g-@o0ag3~u&+BbbmJfZ5R>R18;~3COaL!HUl} z>tFJo4S3*#Hw{)kg?Xw5JJuq<7=g|GndJTQS+*2bonN@WE?IC5`T1P!H0w2@*sDYb zqJeW3v`;axFkEb7mr1ayFj2jIty?n9v!dRmM|(3#lgA^t)$zb}2ALyvbw+|)8jRZ} zZNgZRLvoWd^3#Bg=bsQwywi_hkCR{dcxB<5=w|falfZI~-;ws|Un1>JIKcfYV7qdY z?dIA0*oLuH-W@F*ezkeLYdDtxJ566s;!T!D7@*V@8Vv~Zv)g!Yz+y-u;_1F2JA@Qw zM05Fqf7mGCm3kcDWSlKRH`R!jwo3y$CTGsAkGb0STlw;V(S~ba_o)ON2l}Mu!TZ%b z7SWn+_=Nt^!ut{G!eX5I!=z=jUdmRoZmz?ho=-b-hVtI>khi^}!1${cCT{)8LtpXs-&gZsjK0Xr) zxfeTM&JM7Bv>>I8>V`&`!3@!&zlx%pM0uqO+N%JeSYPft*GTWC$3+Zq$A^Ue(vFI2 zeu3f7;$0lcPdo=sCc!Rk!C<~J%RRH5J5=odUH$Vx^IG^*e@7B;$e&93s>7iR#Jex& zQ6f3)uJcdvkT?rf_{i5stus9Vb>3>5yR-Rrf}OL>ob6MD?Li>AAG|lPyXp!vb0_SK zbGhpb;x1i1Y$TK;T?@{3c4_?hA}2>Hz{k^!U@r5Tkwl&rv%zSpjLo&VM-!P4XnoMW z>FAZv%M4HR^Mb|l>3fHWnl0{$H&1<7=>!q$>vCcr&n`izA64U$j!p3yGiP=Hin%i)Qs^3d?%y?9`&NWX(6;z^sInt`ro^zLbig zxU2=CZ(Qps%Tq^=IZc(hEg5=ZBKwKD^9JEKuIQfY7G^{~Npr&61bs6pU=b5> zB-@x2{C<~U3NqvB*_OMr!p}T>!7l=7JuscWlkLA0zBccYzhFT?(53Bup#nZh@-CUJ zyvXz6;~;kS4S#r*%UL7wbrO*uF6KW@;U3ti^#3K~rLz^EqQS!=z zL`abt;==;A-Jml{hbq1J>pxJp8*6x z>$^l##Dy1OTPMs|Oq_R)D$Z?bGiB;iF&^?(772K&;`46$ ztU2mHe^H(Dso*WHZpsu}_`88EUBXVz*G_A|0-ItRIW;{gY*Ba;SX+pA=S+{iLery8 z>L=!K%LJ}rnZfx>%F_Kt75e9xM%&s9U zz~-SFS+rN<)x^VG{~eTgtqCjD=$Dz}n6x7tuQzElEDg4(_{^5FD!gMcy{UhW*i3;c zs*>HMmCQYbuI^EU6c`w3l5l=FVjJYIh$e=Z6Rrw%egxD2IA>DQ92-M{6j5Vp-S+@9 zQt8G<%_^UX&N&@ZYJah%a2vgztFrR{$&d zRYt_Y#u!Gx>$F=f^YkrjOsL0V9h!R4)&(xj$P+C>O?AE*AwG-o;_BaJU^P3jCR4gV zaHptb@2e($NX&Vk?%ELd2oU;gJA-IBIA!{f3pTk=(E-Y~T@Hp3Rt&r$eFH!LRKLp)u9HTyk;JI2@1G7N6zJj3i$Yphv|GyiHMiI1L*NwD#?;@ zUF1vMzCyD!-10tZyk&9}ArypU*QuzT#d6!_%B^3gZfdXjj%ihX`fXq$H>cG8K0@*Fd2n(L>0Plq%t}8 zCnBk_6TL=5o7=&s9rpx=URX6}g);5kn^HoLI570xXIMs}F6G3|v46~gu7Xw!AY=6NArhZ3c34X#%y^Kw)%2T&T>iyzuHNQs270(*T&P0e`=op`A_R1gG z>bRUtfsZTqsYay%krC!oia#&t|64izRqh()-wOHPdF=l}`CPOu&$!>hvV7UGm{3z4 z;Qbzdr$dcEvRh-(cSDdYhS5VLW{hWYa%7zw(50F5i!-fE9{RBXU8^seQ^g6nft;Fr zshbYQCRYBm%xCLvxm}LC@k!-ckoshGLrd%o-g*Ij=Dl-6bPYJLR+ z!42kD<=5J@w{1y*;}67yZ+RosXf$gS{RDLgZ>H`n3H@%5Mi%fH)bK>>ACyKb>#v)e{$MiG2hEYbQYB8^s_PcsI0;tN zDo7`AuyChm_p$o%*m(HG_;x)2Tlrf6T)QI0eUHGMortSgQl&H_`^jo?TdY zifhIA_!^cvSxPvK?+0Tl7AspfMf|>0$u-|todm}mr(hTHN>ClN7^Fpe}|s0 ze@m!N|H2?y1Y$;#ssq)(&v?M4Z*uExMZ5Mp0Cx|}vefkkDb7Nkv+d(7bCJ-Wk&I!l z)AO3|89`!Z&&8O0tZPM1uGG=(+u?S?-Z?i-9P4aAa&cU~#~~}E+j)r+Plv^Aj9^c@ zVI{^{^JL?ulY}#wT^04CWKNMWO+sG{GTypSuzQ-l^>Fl@i3bbJ^-he%iBN@>m_h27YN#T z=Au8~dwqFn2SX}`m?xIs^s&-9yk<}aXC{>237tmmiVcxg|CF&@@fkuh(PPPPU0?9zBGwWVhIxq=YYh{8<~&*~al#81O01UnDL4cfHe0J@fkVz!Y^`Jt18ZUvi= zD0711Q}-=*(Xo?^Irc%p$7hovtgMV`WVL%tMnxcU}VP9X>h$Nf5y)|&0k96|7kZO7z(uamhglhMc*k zpEv}xWa2iCzhSf_;ojp~oOeabOxG4?D3wX(1b$QB zHmo>x5tf}fcGwoL)@a)%Z#Au3>XfODo@1kt0+re8rO*H|=`9)kyy~cqXArU9RCTHe z0%M6}9$h_1_5&Dfp$Rdmr_ql=el`IX-rB6{(0;Z8;~~G-fiKV06Njsiyw7xXvdc&ox1=(j>(^6senFiqzL& zJTO{+DRuRy(a|Wkt@h=3Sb=%A1r)vIRj?=DWO#uLw%=_}P~yI-bU5k``F!}1XJS&D zpvEiyu|M&GQOJEf<<8Aif<}3Gv7R_9`Qnk+mF@%_ zxs$3o#>k9>@8N;_Lr81DBVC)HCDtPr&++a_lu%)$y?^}T=El2e z@-?5v*50M**@Nkw-Gk}Y85DfH+w*uL37!OnK;JbA;kf)X3Z%@q10K#&wk-sfiWYsC z?$RA7nNQ9b_!D?r9-(o%Z&&>!&-!VlP0e>IUBDB{U((7d;;Ld(ZL| z5>$5B#V$$}HFTJtcat4tE%)jgt;O)NL@$;d<4;sLe-v(J7bxzF_01M5DGRRD73xeGxHT(AmSuJ7s+lpC)ijp2qrhG>_`0ERWN)YYnK6@F1+pZjtMK_wiC^@? ztIR^Ky~?i0ldkcYJa*BHx@hvIj#y#i+3>}MTXfk*$;+arjmxU>=4(Xyrz{crujMDQ zz!PSj+%6O=(pG+ondGDn?Ey`=(TUTE;Wf&Yx91cWMj`(Fdn?atzdse5(f!n;%l6Uu z(l+k#%)1P^0~wdv`Qm=XzDZ@U!?Il|YJFzAzNIXAIdi{lPgu^ zobET%WP{g(G#wj2xy@vbnF~vl9*Qg~4u-j2hD@aO z#d1v#pCERIKU=K|GdG!UT=*XJmV7f);|Pw97gVTV>*YQq>6*}f&7@jN@~WG##;%XnC_@*tpzN+6#_?ABu(bWJD44)b)F4LXj3Zm8yCXQ?yTZIGoh zl*32CXG~t5l z6@bU0kF9i5X_I_ot+G@3rY0WBxL=*6rR~RuM1st9wjBsrmVojd)%zP5SR$wW zJ7yBA8l!#hQL0ZCe7i|K>Wnw`S%2+C(wNHeXy)`rFabGpFhWRRiYr)CzE4&$qd6~x zhTha=Y?WL2SD-wOK>{mI3NtPiNo~`{^BUvKzDT}wRM`Y(OSA=}brjblyvWoe!jzv{ z2=AHEO2u2a_hsYXhd(A)w9VqVGO>ozH*GT0Gj2!JAE1f4>8_e_=}r=OE$x(%i$)ld z3MZUNM*Ylt*!fA2-hJu^S_%sqR|@K{okpo?gyN}?gk%}NAnX&UXzn!^)_3yT89@Z} zO&e@tJUux2%xhE5ksw>6^$YKIlbh7^cby`7JsN#XLwN~)rsg8vhu0~KcV4C4_Wx1i zeYP{U!GDu_TFV=o=did8MQaeP*Why-JowEp=#a1JKv{v1=MnF>)uPO zAL}bs!kg4>t8*?0r8i%))H{412x6@=o%t}Vm#*+d_>D8)_U}1muop&j7$$t3>#7ui zpAlpy+3PRH9#oCq_V@Rk?d>_pE^2Z*ANdX%c3CfSvCn%L+J%UXe^{D1`D%>rs_?|yfT()$uKy|dyQsqAa(y+Xh9Re4ASQyF z)Nt*TJLuTtDpOs8DUi>{#|8&xi`EQ`r6rAz4rI=Iv$x^IGeyl!tDU%YyHrD|@MWb7 zX@pXsn`H$(X=C%%6NSqtO8LG#1Vl>JyiV%2zgv0`Fm%iCuymR2yhLp;z`nzF9H9X72WO;_s^ojM|j|{U6!`hSuL~O5AnzzENSV(=aOAXEKKNQ_w6^&k+34C+}BmQPO z!b9*~+d>B4K6kzNZqW9a4BcWUg}1oeu9TTcf$N%l#2xwgH}c9gxJpqMT?aE(Lzl4O zzG^o6ZWI|DI4eheB!*`c`2(XAF=zDUGZWN!J$20ii=ohk3s(r>!2kA`uRZ(v2fBvu z--o|DfBzm2L$ZaOX<71jJeg;F_A*s!MPIBjNto~j-z`anV_x#7SJkF;VHb*qG1Hj~ zGJZKlQ7dNLZ>2tkzI+o&i3fIg37YVdq!em;^W#B`FSf$DYxt1kaYZ*bK3?SR86U`+ zF5z-+L04u#yYw~kQb36F?Tfx(?#y`4zF(l9C`Y#R#;wUs7lt z<;7XLXw#`F@4U=?gzOm;Ym~2+87i2YYX?K%0zd8IGYG!+iL%G>g~m;hcri|{_H4z> zU$L%e=10SsGrEc!w{LBq)ax?JCx?#V%k83+>uy)f?bF;-bwsiI^Vde$UKC6lSN!i}heJZ~E(w`u{ZqAiSEo%iJ=l+ zASP@S70da(R(cog!eDy};^u103jsSFi_&oJr~jBo{WsC=HHfZ`lR^7*b;*VUU>G+|rBdl-^VA6)LR3An3vCDpz}iWWWn z7V+$SrwO3DyZILz|+t;Bx)wYvAdA1RHZga=~ zM}zkr9BXTFIsMvVIvHXG2k5aZM`xXD-ZJ_mf1d5z+|nJFAFM<+{*(pkzpCQs%-dq$ zhgY9dYd3kWVjWy8#yPJ{*MB2l)4$|M`SczyO2B8tI`aHz^M2bSTaJ?rNyjFYn9qGz znR5N)dDu^xtlwfhhKO1z63S!9D!6j>gW_!58~0N@%f)j^$TQSs?{VffH#LO^gsD^w zkD}!X@xYTM@aRxMH%0rhrU$*={nDTR2o+N#^i*^+Yt8q!8ji0jVQEjWF7^*Fvk5xh zqI~1|DSfw=h5k#Zz`A%QF5VgTL_--LPf}b?qvLKj-fv`l+V!BCDy&pF6Z~#iC=vvD zp+FD^BD+aFiRD8VY_e5Ojv2O1{Z??1C>~#hCuVm-Md9kgQK;|(9f}}}an?;=_?;`$ zFvu@1`%Q-&(c#zNgCfvP!W788oPjFFICv@BZ6N+k@#0TZF2uGelge+TD7^{D3X7AQ z-AaML`F#9tlG=}Jc`YV`kGJ$~8r)qkWszCz@&&Csv@~OHS3%Jx}+83*?B z@6+818Rgj<-+eA)A6|@xuq_w{-wzO_6(CbY&qO;2D%2=5e1A#0mf$(_v`haR*Q@MV zqG084QEtJJ6jNb!L&S14*27_psLlrLG&W0f?o>MGn0)yYv%>BkVrDw8 zB--NJmWyuktNVREjYEzfD1S|UC##WXDNBEyRGPi))%5B<#W_bwqxwzi_|#L5gPy}F z)_x_NSvVsOrMK8CwxmyggrKR$XE&)|GwdFG)lKoXc450n7nKnrgT1u-|nAKZ|fE?`l zWjVYUmkI5DAqYxz?Xq_TGv}0ZqtcVaRJdMtsS*rtQS-Iakr9n;rK91CvDmx((XMvo z4@C8ze_WbXTn6xyY*6_$!@jh)c%Nxr86dR zn)L&Qoy)g${lS>PpmOWYC)pQ0br-x{y2vHnmZgGXr!ewNr}azSvHe*Lrur)%!m+9e z7%4a{J<#IyuZG74==NC4JxKKXaHWENUFQP%an44DTawcyy85^LAbo74r73cA$MV6J z<~zRm2B32<<;<1nd~&uJ7bVvCh2ODeeYK;(`^>3Afa-^xF?n@yG%q4*QPGFC)X(+m zAfnfagNoyub{eM3scCxjlv7_LQtFB|-EKm#N2Ibq$LQm^*{WCG)mZ;(48+ zeDoSolin{6%KD;n_azlAEG+w&^=Dw^2Q#3Ei8=X&5XjZ3NSpdlt&)MyazyVciFuSj z5$_Gzs97HezGylO2`7WT1K`GPgA3|VPvGnFgn#!b;<9CB=}{bsJ6N3_01Ux>Ff*&M zk6~aaqkc+xKJQ+*GN{ecTraMAWtvY02kX;;@cP(_=#1mXJD3peWfD-`WIu3Ijq>?o z{~tHhp8)2<|2Hl@xQ8W=kM}4vOz)51|D}faElDhFo!YdpO4;dhc&F$wo>e)jJ?N{&&JS^f32MkAJ zG6J!jv!i9brsmsPxF`8bWy zxpjjNz8~*S9a_9{IE~a{!_hE|E-{quym85C=cMp_pT{XYO=IkuD5n+u&^5C-Nihn8 zD8I0~#qe>=T`wJtNIjfhaeIaEdVcuDaVo(a6|JSEX0LbBO=#C-oLFR3_p{oP3s!yf z;_n8-g?rOn<<8P~e7;2Q>c*7mRg;t%n)0wKS9>dDW#pCom)m>UC3@(Vn8X0&w^vsF zG^xHTW3m40I<4Zk`?|yMEfer)5Cx!s>W^CgP7@{W;c1rG$onDesh-fcCm-@I3fwe^ zmcLZ{m3F&6YJO~P`SxyTk9v`0@tqOZT6Jc|6VsmJZKU+#_ZkSU_+(`H8`sor9XgY( z4$~CZT6xB@nSxiO^OltN{!3PMfC9uvbkM&g)dvFa3#LQH+J;50mee?M(Cax0LC$0K zhGsfDSz8&PwMs7eQ=Qw5ZS~DhhA(Z=2oK*5x>(4iuXrZ^YYcI<|NG1iSD_%x8rjJm z){zGtSEMs~WwiYMA~3Hq{Lza_#;m>($nTHcR>`M^IN~&2Px}c<5bU{y1W>yeHM3_;Nt5M-3$qRw+=Bkp3O%E&R_7YY!1t|j1rd_CMjjb z##oPy)YR;JLBH~RYQSGpm~rF%tkPLa=(R9UMBAn#YHIc~ccsv;8~j`8N6&M7doZU< zlCPgVXBjF6Wzj2r!(ZxxG=)8fRE8I^*8+dr(EqtCd8=VWE=Xt&96V1I94vYF^OOJa z{i}NEvn5s^Vn^E)jS|ac={lcFa8b?Iz6FKe(DA+HMX^lzU8nXu&4*Q&uAYdsN`1G< z_$gHCzrc&dee$2ppbBPWJh7;Bp80>|LVbpT^#1I;5;YQt?PbOKGoCAAv2BT#_B;pI zNTj^LuaWK--tlM5i>F9V7LQT`33fIA5>YUFqKttsjTop3ySpC6W=0*XDsu4rJ{M36 z)@*H?ju{r17;<^02^Qw}YYP64Tlu%E(K#w|gw{Uu#|RZ^|Tl-~B&c8f5$M{8>eQ0`SS?8R}DcRk0&KN-4~7a0Jk-xWfCd7kvce(OQSp|bxMI$ulfA@b=fkp z0TGq<*fw_HO?AJH-R8_cf8(8WU@Kwbf8R>!sdN^?99b@tpIm{sq8xti#)8tBEBeJT zz?<(gSY$k`R9eO_Ss8m?-J~HuSYmCG80<1(peP!cRGW~8%>&;b^|tqr148J}&;C25e5(JRr|0aeXLc5M)`#r{S)_f?Q+T|7K~q+i zLUGfcSE979SGMI|C@5xt4irTm@hxcbaEom1Ranwk`%Y>%>IFitKx8Tu%yidC61BGq zGCy;ki;H#WAA-k$Krc{6E1^*;f7-Z*xYbi^svUoP`*_V^ur$=kw#jdI>RX7Qdfzs1 ztZK`+6_N$RRrNG@dq|o(-46nz9WC4AE-oM9c}E?*Kz8T@Ztu8nmX?;~&R^xb-GLQB z6_{bpY};QD>$3u;QhKe-(iMKzk-VbA`ZWZ{VCRgHz$@@7>-|&YI}o_VMZWk3(s~2Cj8*{1wNNBp{D~Sfvms|MD=y1eu;_ebgcgH5^Kd(-~#%gd?MIBH;!G@ zPRH5h+ct0qwR{_JaBX|A%-nK$s`Z5>v>PmtA-`?nEycq>v~z~%uq=oz;OZ7H+$Gg( zwZAbLK3z~)OlKJ5^UrtXswz3gOUC@MXwMxv>$}qJI&7Oe#Lsu>QY6_sp6r51!gs|3 zdn)N7ei`6uZ239jd=(Om4^}_txA9eFHSE`|+J7i3n4Y#fQieR89xH6~g~OoC>{hlr zwY5b#s<~F%7yC9A4C8J?N-pxEpzh^EFGE6zy!>RRo&vLu|5g_NPteQ5|Ae*XKf#(& zIytO>A@VBiar{h%wPc?e!s|CD#wS&GHoN@eO!v32Re8t*+^Thi>(?NyGgyzik+KUtjLP3MM*O zXSZp0TnAQq*}3kPZ`eYk_{;|b82Jkip60wA8UYO4 zZ@Id%fX+P#HifS{df;Xy^VnAvjjMxpA-pUhk82p0dI3fPv8?XVZ z0t^c$oGKCw%EqmMo#rC`9oXu`=cpyO4OiMvGx%@Z9I@kl_P>S5|233qw*1V(DG&di zh0o^)urp?GI!mqgX|s&>6dTvIfwyB#Fan;ZM}cWGV%L9vd)yK@<}f!j)1SRAq<4dr zs(h(&d$^q0$T^Vgi>z*4jw)ldqWcvyZptbd$+rEh>fSO5abRCy~+gY+_{Ns1u|hrkl@> z#U7-j9vli`7K{V7Te85;#H`7GebnQ{0&q1f<9s2tITd})eZKF*PNVO38!(IElaP>5 z4pO<+kNH5ysg4eaa3>73g3h6CxYstj;7m#i^s9HMh!b zd$}+1@XJmRj~_p}f+X`P-z~BKLmcEL z!`TfL+2>d}=vkH;M%i@f!=Y!c%CC)jnARzi_n8lHixaK#tStuK47u_To*E78Xc7_U8EP1Ubnwzs zAbaD)8DO`xvM?lEc8itI1#CL9g&GqnVGv3|9_>#*(EWfVo}A0Ag}t9(`lk92|M=mMf8z6F@;Zq41C#&aVgA)$_Ck}VH53!0n@`3ul%ObY2s`!Xt{g+g>F6r&p9jQ0!^eTmM zco^6`xxb$X_MIS88$?v2&Fe(q>%R;$qbj-L({+RkpOTM^)ZAQ_k`W&dtf~<#g(yAbt`C0RRqDZVWFOc z(Zd6~d6-Q$@RJ?1Z=$Z$jeYDc4Lm7NJyyIMJ6d5IodWMs&Oa^u57-vFz@H#Z@jFN# zpjefcrDTD_86ndXV9^|HPy^jUhJl3ALMIz`5gW~Nz89X9)JIF@jzTAaFM5>t*dxv| zJ|-q|yetAjRe~Y=!=Z{b{O#OipFZJM z;=fxsdW=3GLFl1;9_*!|fbKN3?J2_V80jg(|3e6%uG$|TI&xRnaFS&2{fx-wRIh~? zw&fuRS#_Mj1rEPLbDZ5Rmb=ax8lA6$Gydk7E8L5mhL*pi4no601+G$J(y)L{qSbLF z4Bbzwyura`D6scgiP_}Qs}7W8^(Q@J)2>X=-3GG8=n%oi!TTC%(#B->oUFyB?a`&F zR-Uxir5&hh)(_e13@84<1~UG+fg~yZAaCdE-;R2W^Mid^k*TP_($C>oqw`=3u`lhr z^?WRk1`r9ClTrR;qf_ba1@qh>r^hmYF|tgCRYhvZ?DkR1z~}o~k+bRU1X^CVnU3SI zy0lvO+y!8mwQfB$x4(mzDYs~iQX#GKuhcEp7NW)h`$(>@7FK|pyn(GkB)Dlgy<#JS z{9}+vvU_gyxf@=CvS&IoVcv^1cnjiiW$bbb>MC%THe3NVVDWFdq$Rb&#%u0dH*U@( zy8<4QJFX@1bMPIIy-n*kZN)B>w+r08caIn9)f&noJOwx#WN(&T`>YMS7atf6pa;?? zzGxneIDo7#&MDG&yb16>x?z>(yJldfKvx}wgh;jl284)JMb>57+owrXQtG zx{r2#)d3(F*nNstZ_%YGvinxdmgN&u#9&Ehy=Gq_R)PnL3=Oq!U!XT)UhAj7 zAUpMqT%L?RC6`Vanebp`mv&v<_59q_%PnW!LOaQhy>0XTg|PIGpjugwY6@x~sJ=dU zjwkhz*%LdmLuqSQS3KEUs&RMC5Et1fPffZNB(=93-^69LKd0dz6~rKY{l}=9)Q=BL zv11}7?Eq(>-EW@=DdV7^qM;<*{^GxDzqE9WxhWI{E#b% zoWq-}XMS|~@Aay0`-}L?$py~RUSQRUa_k{#cU9{sQ@*Hh^AVc;TscVswhX8F?|%k1 z@22@L!3vYHO;@`d4=UZy)G*syjd!wNz7gS^_aqD#ZSxd_kX}3XisZG2nQ0l1>9iBJ zJ+pAc;;&H1(fMA_)*j%*$kVj(QE(wcr^T_{720LI1c=F4M$+Fd%km_D@=E%D@Jh6O zwt*ly&>gOJVz<*d=6`cCjzpkUUjm4F4H$w+ZiDbz&FQlN3D`ZhHq2tGsB*6so^QkB zD^v4JWe92rd*(T|9nti3u;vjWF;0z}w=p-?M-<8c!DPC(w7=Ko=QLRQ3cffpq@GJh z+g;taPQ+nnKpv!Z;kp9$_*pFl+TqhEE84ex4Y~=h{`t?zO^#FUVu*cWK zskzWvrtq`|P~lY-yt?_UE#M%f!aeE|0lBdDJl6 zu?eK~58qSI4HPAhI@QVs1&iF{Wol)&yiFMbS4ga{mzwJ@jGvmC;?ny{BOS^rssh{_ z8S41@X%7M3CLXw@4oN+ftmV;*vv}v8f)$&0`G&mc4ycpB z)+^&W#M-^yj#i2S;6I)5!A9bK36d^@b8%Kl3~KiBPGd-NJh3YeUsNSn9Fi)`dM{I% z-d*RS8tZ`!bpbM&O15~cDJhdCzjfY4^uSyx{zdI9@A;&aB)`2)ygprU`H|=ymxJeP zOZEWp@wAp%4ixAX-W4!E_zy7p*iVe6p^&dv5C8i4r{>zne~{h_J(rX0kU8$Z?@MMB zZo*_oT>Ev$YTfN?2*q^{V&Yqmk@vnq8weDT{+p5?>eNTU{wy;Z~avRS9co6A08_5-S9KX*w%UD@ZRvN-&0Ju-)PcYx4FHpdjxyv430 z^4)T#BhfLE{x*8K?F!L@{v?ZptA&|C%3=D;60_++4q6afdC7H;g+Z-PNQgb&*1Iyg z((%pxnu*Bk1iH8v^Vc4}qq@Y$SJA8tefJnSG0_tyKE^9atcc$A!ojLrC70?aPk-D3 zb!^$J-t;MK5ikFs#13WsSYg|{aAEuw zWiTK%RtZ>~B{?dQQ*mWuqXt8g@Pa2>%SE&?a0*uqRgHP!d-nDCz6m?P;YNVNUOw}c zL+l31ipZWx55R+7^WXDHD}GmTgf?GDxHY~tnun)2T=}%{KM+O#NZWFvc@L8l+Aa!K z-mkZ_p7-QSIMxWofhb!+blwBL=WSriatfet2A!&4Dm{WG(}_hOM5abRWA{VL_ZNp^ zLYbr(tITT*ZDWq%0G0H>Od;9=hJ_k4hDw*(w2g^?Z}@;Q?qcJ%qhxI$ zoAJjY4J^}|+U~>QlOJy>8g^INfn+q?0=TiC())@97OZ)U>OhNY$;SV1#q; zqlE0f)ezuNQ1rV#mleHp>s8Ks1fNIjW;2-0&|T-VwTYM6H}KmfnBoke6rTYV%ET!! zYjY|I3ZLZwByM%@@0?`IW@hH&mBmrq>HcnAzqL1Mt(S6B%oM*~2iD8Cg5*VPZY!#F z$6Ws0ED*ZE^RoR{u{Oy#NXG0*QyZBuB#XNS)s85#i}_y0=U(V?8z|&p?H6$~l*g{i z0@cGRemjr>3JpBRo)i_^yXp5<2Xw!|gOMP~Q^HljU~D5l%5jc$b9R=Xd~@v-#SYr; zPExP~^~F(fONoy+%{h>3IM4)u^W2Xd$fD>)oU}}Ot4G-(vz_i$Zo*(%K>;XtLXeJ8 zd{1sTO`@r=`i7M9^KG74sVbiu)yijN<4 zkx#QZH7~Mlwv>d}{(a77u=c0A%JVNT)#kuwph(!_7H=bI-<)@8f0A!V8u1}@lB@b| zhAM^=KK%p1#o76%j6SyvcEpAhdY&WD}Y za3)q_RlvkKqi#DHC&2M0irC(*+`Mb4DNOJ91&GMvFBSCgh zGOUhE16leS6bmeBKU~;n%@sD5RnpRmOaYd}FML%0zDOfG3K$Qr5oyjpy%EgeS=zIp zFiVFdntgruL`;dJ<2LBX&T)*l{_ylKU>sL zQ-rw3?CwlOYfOO}wxV~d<&ZSFQh;YgJ#fWGvon4hPQ5&PTXG~EA~~QpT}MSYw(UFn zO3cCIjgY3$C?L8JF-PfxuY+?oZ9D} zUjfeaCfTdC2a4W`D++_Y-T&DyaIdHGc?8#A95nlNAWZaPkifAp#IeTl6oFnJ_tlf=^Glge)>6>GS z0TG~hY&Z58wM>3R(96}C1NPjp=17T80>~4g(FKk3Gu8H3b#ga>f*Ael?gMP`Uh9ke zyzbW=)^Ne1KI9nm1!@?Srrr|m3fSvn-7(P2t0Us#B&0RUusgN3O>qm zP2rf-ciB2Rh^?IB)c4*czjevY-6ek~>rDZ_)FM%1(6}={>6z7{q;@FSCS4DYL0#-U zc)od?A@HIhLDjn9Z$HrIME0y>a<)FLi z&LJQmiAaEi1xC0V<{a^Q4$xeMg$15K<~t2}*nAOkQ2|>tH)tMSBy>SZ2;CgpijaL? zj1BAg4v5m2iUT#O<9zo@WuG#zn_SY!u^s_{gvdHoq_2p+hjxBwc4|+F zMf7<3iGge&FT1utJO&Z`#n|uwe-g3JM|@CApA3qE&%yCJM^qX zxxAKm+cV5jYH9S_JV5uhD)XRpUZwb_N|8E*L3SKW*PiKD1E{rZ(Kzel;t^TbUQ{E3#098tVvO^8QxujhhZ1)e( z&o_4q+;vy-fIJR@Ohp%683IRe+IPz_xq!{9h;BsWXvFS_ciLO;+LtuwiG2T<{r;Gw zDrnodK2-WJ=qQ_5M<;T3E_^;cQ_sk?^Oc8Vf)cloLUq zjv8@o=SS26wieV%`pSNOjl0Fo7l!xXy=a%oEbtzxlID7I+bJ2O9$?80Z<6l<3S}Op z%ssyVz4nZoe!g2vo=h!1JLi>0c6CSiIm0O=g+H-FPJK9pnyqyN4%Q#4Oc6ttJ-Hfb z3PsJ6qeF^^X}jBY9wgwLnvBV;PTetj?qR?w5-au&*K(b9$bJEk$B5bg4QSHg1w={) z${pb03*>$6=+;V!j@9hq6h&~xjGA>$loyDq{E(tgVO$lUWkm8sFAcj-#C3MB)p<+G zv6bszRsas->K zA6s_x{N6j>I8i>a{#&{-;Xs(bZE{$kFCY+r&P6M6+a85Q&}v2|8q87=cL$~(pM!4^ zN}5j7tAAGLLGFV5+q|b>rCCX_UCAo3LN+*Y?>S4enOr;uC5a1o2GXu%U7YUy`}dnI zK*MAPq*XYQ%sv7<%-3ux^ZV~@eMt>2x*T@sLuT}E>8)c z7heJ;7Xvbs# z*oKf*@|6nT31A!-8>(Pq`FJRJ!*d0CTuLnZ5c*ITXA^Rw!fRX3s}^j!^b_NMoBkQ} zB4@R4csRJe{J46S4M^Ds0*9}B2MN(22Nh9a@iv6EcNjhWt`Ef*PE7|V4)+VtA*o+w zI#oE0j|tovu+?V<{Q!WxQi6*h=X=Up&Rxk(aaD02T37}ly#4uc&ZwFG|9K@Y=y0oJ z0xHCXFKg)Pow}#9?{PqivfXzhUHdUqpnIierRNT;G;~pz-fsf{ib7CS-<@cr$#?Lv z-aXW!Iq}Ql?K3V%DtCaOaK)+q$J6B4wOnLU2JC2(&6l>%k({@Jd_^L7AJx}^99Mc< z6)qC_!1dr2a^6m_xg1Y$@dVhVOB>e>>m_wT_PvZB0bPIB%m9V-9?O~hCicbDL+mzf z_v%13i5Ij3AkiX-g?qN6wUrgQqxw!;13B|v-Tv&87NQ>Nv zLirT*2WZ0luJs87*22%-itT;C*MOMMMb5m!J_mc_R?*X)zm4}CnlwL~b9(=zljku} zK2;+A4B5@IKF5PaOi{F#RtF zoc~HH{#9X4H~HBRApWn9nVbMUdGe$jtk^^uf4^C=7Jg>0Llz1Tu7I+Gi>yxSoBH}>U)*G+T_$&>CR{cD5XO3!Mm(IGGvd@iOy+B$Hz)r7D*g@*2mBB!NZ{i5 zKsbB<2~)^#&mDQhp?Vb2Q*L<&G+nUBak%@K_Usga_gAsHU?BXrAD+y`pSdyYU+?Fi zHzE8#J2mi&YM5?vL<9&LuRw5t)!TvIfrcW7v;T6Fj3?2~cvLnAkK$gLO(fst%X9)W ze|=V0HjcbmX#VK61U+xZsgSVVPcy0IuGxH}6_;#Vk3D(fs_H2%^S)2FaLRv1255Q~ z`V-|q(P9SDMExpZIz>aKbFt095XCeBHN_(grl{G=-XH6c5(J$9b$Ze!Z@67IYfUC% z5X7v&q37c9g7GEVQtcM5Y;PJGRjW|7cHsx`u+FsVFgG{LaBj7%hE4w7FK|}uN`8&H z-@Q)vYmPO;Zg5yb+*wOMVKLD~b{0HK|Ml&n*H@avD^Zm9t18s|AGMzFXG$nO7&yOJ zt?#}huq|TG{RUeU-PDm{K*^u3@|Wu^ke3|UdN2IJ4jc%ir6^fysP6QwWU-tWU24*} z@9NasYF07la9p)Z!1%I<@2ak|*!UZc(gv3E^1aJZQGBn%bo;mi1#k|Wliur9YPCL? zvx)MWda&UemF4_Kc>8I`{3=tw{`iNVC~yh?zg$&b013P*qA2VNcEnvmnC|r&R%MU?$-9|x(nY@<5eG#U@j z+|x4R>(wub+wM?0F#hb(UcYR_)tU8>xYgeFabKQlOxjXcfe@tH=!uHm#+{h$oAlz@ zcULu*YH~%cHfeJwjlb{@|1McpyffDKAR@$L=%eGha&Iol@O97eh#{8#QK!0O?;T=; z*HEm9*&=1{hH=agL>Bf}I4-L|-iLud&y4@h;`Wa%8k2T?hXib&t7`Kd5zjjc~UWq#) zt72;>Q%;nXPe*q-{5a&sgfJo z%w1FvwwNPBo-X+9s;Bz~Co@M3<_?ztI1 z!cw%GRY!#9%khl1a9}}jlJ4xWdrBEqQHV@^b(oiY>G<@ zgct`4!<4K>#Hgv9f9o!NNHxw%u>sRtCgs`U>hCd!!>94G8{Y{!BJ!ETk*-%x)NwS1 z=`eM$k2yu01~{QWrB)s(Vb^}crEbJ1aj>ZgX73Iel5mPnt7h6-i`6@HsN98QB|!vQ zM#LUSQ=dFZ--?m0f3+JAP}q`B?i`Ewy_Fb* z_)De!A*XwMfi*B5R&|`XA!a`c^)%mH&cZDpiPo2t@^M_AWpi7%A2V9J21%=!?BlV= z4kH*Unoqt8u72;b!n4OvLsdifzC-#a%#l?sN)dQiW>uIX4)@jMBtM}obC8F_dcJtdJ))jX-Vo2t~98d4ZpZE-ckRyFKwYSohHH6 zQ$k`nG5?BgODXU6dD~3R4u^-gy|1Hnuh^R2 z$Do{70wKjjk%!OVTaw;a;kHfvR}WZCGtOGiEyxs;D)CVEX=i2d@4~%C>)xE6uAS=W z8mC#$?WZ+@I-+#N4-|G6)nCH=j|3>yy-;vYWg)D?R_YsE~^x? zx+)5@w=9l_QYaMQtBuv_B;&^0wMPqgiA8o$9HFp)Khb1{LVksHyjhA<^eT3{S1otX zsZDaV4%y^o#BPzADFW1d_!7h1*|cFKoI?J()2-SBlDW^)n98NDWb1a9Bc}&#*K?oL zw8)s88(d9gna02DlJN+!U?(}Gx4g``(eatT>5_~XT%odZZ0XQ0>U&L%ANpi<%9S(> zNnwy+4PB?fVjO&H$6v5AF3*1ZSYEOT=^i&?ztdUi^3+pcJKXaegYbpu!rUBxCn)nS z3LWB!STIZVU9HjgC)Hc}h13y{?fCH9E>SGm+#5LgG-X^|xaw;6eE%|P6BgfkvI^yh za`E>7^#8i&g|^)RSn|H70k$}!3$wL!Ua@Jvv59wjMb7Hv)%cZ!zVQjJzs@FB$j4dn z|I*HK;){A1Hc5Q2_L!>QBsKvjt!cduW_3;);3KznH_|2QdJB&f*6yqtz=I|PS^Uj_LPEL;k{vXamB1$Z`th1 z8j-kDsG7{<4LgQ-PJbE!xxKVCf2A99%kSkI_MK&|+;E}ouY$^JL4|l1(%SNNLv?)G zcy`Fy`(tID)5@!!i@9lfv$Hvy-xg{3w?pS1Zc|&wx-v|MIT(nU2W(K&13sg)s6CbX%^v74p6$Nxf8DpR|TUV_(Q5?r!6jtDSh< z`X?xd{NxZN0h?wJ#mOWj^63VCkaBU>KQ*|@IVN9qz}vxZ_^PJ8Uz zZlf_FeNPEKedZTYgnCIHlg&lRZuOMmGN}ZCt9ZEr*$&VVF#y4%k#3E3M5R@u^i1Db zUs>(%HE(14yYdd6*TUw?MP;@taw&;XyjYBXwzX2G))O%1R(N=xgiCb*tNYHt)1)Lk zO>l<@c6B__J4$-rHXsdA)tIlyu&E+xpKKj26C=Y^8^GuxK!rO!vloH+F*m>rVgzZe z1;1(D*?7)qFU7NaVQ4vUPLo~Lk)1};uWqw>r+G;*nJ`sE!t+?Q@aLZT;G6o~YgZM& z7cuySPYr}B*3#Z~s(V!ax$}N(R>l5DD4LvtI^9--!PpA?=6zFM=T?EbL43;twv{?0 zcQzItUqDCuBbAqPLwpHkp-4DE(fbN}roE2w9eIAjExARb-&;Q0yC z(;lV{oFX;}BP>vZe7Cr*$XzNgscMVA;A@|k^ML14eD$+dH)qfETb#o8KLt)ZeR>TB z7!lE~?w)w1QKWvynr&&D^i0^^u%22g^cHvvaCCb3lG=fH_A|uSo-^m1gKeu3YlusO zKYHC~y}IkQHrC#E;Kp}XD&{Jt!q-lKpN1N@q#m-mfofk*F4i6|9T|o@ z_aoGI$8dF{=@6oZUv?N*+E%K#qZ4hx_z zpR?g8>d503*u|a=Cw7(%jiuNp87V$QA>HAkXDY3g&{sp#I<(-(KSIJhny5&kRcGqB?4v3aFGHzgiN*Xax5 znoT4>rU;uq1r@eRMMI|N``v~dQhE+njRW`u+8ppx1b&=6v30;C;#m?V5}({HDl$h) z0gaN?Z9Syiry%=TMwJFg38Zkxadw%_Ecx8x$ef$wa?_k9JN@-H+;ismrh3}FiI$R& zsK@(;hT+Rw&#ty8oESmuo?D0haO;(D44cba@UTFvvRfjJ#pP8?-@IvO>c5ED z8mZZPTGMFkRO}X}mOEj+ycVK=MZ98rr;!(mF(T*BX=w*?LAQYf*T zpVs_NPWvTy2(ytgOz@);uVCJ9L$ki;u_yyAOs+L7{pyDix5sk9NrLm9Z0lrPx8ELd zYc_LRlIWGvKv!fQ-sw5euG4i$AkFVKz)EZ1_2e}hp(JX*6gy}IVO2-bsyEVGWE0CN z$=NA0WMZ1sNv}J?K-+Dt;Nin?ZQC!!6EU#6>&?XHo8UR4c8c)?evff|V@3S3qbx0+ z7^g)iD3VE6`(Smvao_Jrz$3g*S?eUl5x}Ka8_BGTcU{YDlibQ}o5XwcYPEd?pzgg3 zhk<*)kw<#iD3xM%Kzpmy*1AA1BI-&B@zHcTGJ3&kD z=iln^?9Lu(LwN>TF|s4JY-*Lx`3SjoTMa(lv-#`r#u{SCmQ*BWcc4gHX4R6(%l`{J z{BR03lD|-eDp)7`{k6Zsb&f_dF}U!rZ0H_tuVjDuj_3Q($H!I85u^Q$kun>nb+MNe zNIZn8DX5u`;yW_L3FI^vgf+0yhoWrJ$HUVPuEuq`uO80!`{~Dk1W#T zW=TJ3)EH}Ov)HT$F=F#Unlw0rOPtDfGGA95YE_WTL|d52oV74E>5-DrU)t8s2jGK~NJ#IrtX4mC6i<6EqR~piyB1Yl@pUvVM1^b9Gu~2AX+k zl+>knVp#*)V!*wgCforBh&Z{G>x^RXe4f>DdKy|1i*^G&7Y`&89<)8%UaKUiAWZuY zX<=}1q&@EW+JORH~OgR=2t(5 z#jT;|(!HAI@{(G?5FX2NV*WfGb#O}etDqFnX`K=H4yN{YMT$wu8+q6f+8W9mf0=^| zvwL3Gdt}F=CD(99dXg&Gc?@%kex?tu-Idyd3++h1G~{#^wUkE6PkjcaciEr>Ze3!= zdtP9|cIBI|u_ z?~pOI>9>itEE3CalV1hfND8&Xk(*N5UblE`bS z9t?gxc%IC*j8j06stauRB@rh~xS_y;*JtvXT7^^BhvAFo?vB+MN`}9_h8oD0`d-w` zATv1cm7vA!73=X5+6lKYt+kCtEkWtJgdp_KlB*M`dp)g!Bld-z)Q(X{QIY06ur>U> zkSFU`opc%syglmLbey=@oyNWmeXQ?v3!EDi{|sN~ooae0YQU2*8rpKN*FaUfbTXIl z;G|`61bi(BrVuKlm{1+O|FNFmsCT8m@Rnq$|G6_4KM1J);89U51=wg>Nf>0x$?O>T zzJi<2S*0)A%DJZoWqm8fbapydwLATQ}Kf2I+Hn1^kA132ykUP9Sa-j{8~{x-GeVSVC+=ZcizMX??hu<;uiNv$)DY&_jU(7I%eQ8g!-y>6qA=m zjq}=qikf@NTiA>maW5HT{D~J2gSJ!}n#PbCYVj0JLGB>yEz*#APaSMw=DR04UMZ(& z)s26Zq%?HTRLoz4hEg7yz!9nY__GOyhfLo4`n;J1VcF-%y zVw$!pdT|?)PO`rvtCWI&p=u2?O70=K%vGj-3|-7Y*z@-`sVP0<9;*3^O06w!TB*lcpzYU-0}?r$SWJGD}IuDLX= zVu`8lk2!lcJa=q;1)TTy?%LfJSv5U3fj4LUOA?-=p1cfSgJlTG9>t}-k9^_dxW}_n zGvhB0v=+pTshdE&$)H20hm}r0O;itl(-uT+r1(|G>3mVK+mj}Btjh(^`>|(3-2@J> z=*7is6jt`?Cj~_Ms!crKx0;H~al^>z!ZF>LIfg`WCwn@?4jf{ZNfn&+hTVbFY|jN{ zx%17dtLGTszRI=xp=pA@uBJ45`vj!AVSBGr$+*_{{lspMG1YI0FA%p~T0i(*A53B9 z?AUx|yYbrKnt3Oy|C=jnEe+}wLe1>Vi{&9Jwmn&QQG1~locpuOyaGuw&)gl&Z7OzRlzoQ z6Wk=`g}!P)A3wL|<%E%|6X8h{_!q8YIc`QR661nG*oJYPA4Xp=R7Z9)rS5S6baMiO z2qILAx|*16pG`)qair1jt+K^7&n21sx`JX_jVIl=HW;KM`ne-&3j!jpuRpuql3UfH z;X*F^{g2Ytytdch1YP)XBD$pkl8U0eOC4nq3-vFr^*n3Z&xmuwsy87EQhC5I%vyJa z1{t&~K(WBHd$bXIwKPcw@N&}!F_fKqhM9u5l%kAofEAiW^6_nPtWy?e+BnNSG3t~G zI@Y>O&HV_+M4CK>QQsCjMV~ zHv{>|$DEHP1u!x(x~6zuaewIOc>QX${f|#!8T*`2$C@u!E^q>=g;DG6p_?{kWy(+| z>cZ{tnz)){!o``+1=XfooQJdKd8g++Pzu5jMpwH!gsUCsm=t{oh1W>w!wKBmU{>sT zr4N<5G9E%s+x~;~tENA3ijfUWTSe_YU^`9N~P-z$&OV4A9W{9(b z$rG;%?v47h z38X3FsVfWf#?}x0MHuxxguvC2y2c0Qb9yB(-xixJVBerk7dJFv8Dmb-o>t~2WPoZx zG36I~FP_O00CUAkNQ)_bV$T{HXKl>423%zx>9p7Ny{rGa{ljl}0XyLq*q)gNCW&0D zl-}px7+%xvRB~FBz&PpWoqJt?f|!OQ24%Z^MCH=!?>9~2&x*4NNo&xWU#@Fa{!7j-PDte;4n*wXf8pbp;Y3O)J_*|MKBr^RyaBA2@4 zWbFrjyy2})NJRd2QL(?1o7x%UJxS%X)jb30$LZr@|7E_uoe6mu!zZfXgaAhJK{!oQ zD$KR!1FYy*mvpUjba3-fI*n^`7c%jfq8K$wdO!AZN362QFPM~@tR6A$Fszf0B1c%o z^Baxmd+I_?m;sSIZ5|ez$6!>r+mL56Y^r#yfw=sjMFV?YN~NXtXT+o4o5g9Y!albE z8MZXQfpLGw06k2th4va1c`jfd(g2(ylPi z2%{$XFnf3UxvRU-7@nq*ELU6-DbomFiCRo~?z&H^7Hs>O)q6K?Lq0^`U|z7eeQ>Le zgH%*0$`-*mj`Npoyh@6b^dBecq$jWnkC= zT3Tgy8HDsFO)&GnN+!Dc(a*SNJ60Uw~j= zsX>lpe+*uKhqRcR&!0~eN8t%qxgVuB9lr1!TKSxo4wgHYWZ7heGgnA_hgLg^f3Di!gU+8OecSUW+dXukQJ`+8g4fCreS3isSaND z<}pNONgN`NOGdtGrLY_pbKqrIn=^0UbW+l=2K34%$^dRsu={OUe{~KjXP2t28qwjw zeUCd~2l+rQ=)o9xjwL)QEvc%Hwg|^0iy%Vo`;PWwCrerysLvOPI$3DMUhBs23M39P z)bcx`*ozD}IjuWYDm@W?DFg*~)Cx3_L`YpmV)LQtwPHE<7sa8ATO*a6Ob5^VdB02X z*TSsBk#DYVca5L+Pa^p2Y!@x%L1K@%x-Pu#VfTw@Y1$h$P3fNF#d;c-N%E9~X>y=Y z^loIgf2=^0{6I^MC*e3{BSYvJ=YYd`^%eH|b@c__W+1S^vDFQS$9lK58YT2jCmhSy z;4@A-Q!^j=F=YDH#XAm$(Yr5PeO;z|4bloTt9)zif2fjtHn7Y1_ZF}Y3v%TErvLxa zp^j@%rP;M_8(-d!kT5-`?AnO-zOO5%_jCIDk#jVQX9hO+NIoOG6Ik}uy@~kFPre<8 zT=|5X9;yp4Y^>v)6KjCQqz`%0dlR?6lie<>RJt5h5_J;CPATX0O^}rEiY-dNUPV3h z!wJ6ljx%Y`BGVe{9^ep&A1-s}n9{l*p~M5dxgL14qlw8z<+Rj1-o1C|I{QfNhQeR9 ztoU6fvwXPTKcC<`94N*u`}NiGfhW~KnZIi4$5K05oLPEWY2fPz_n3mq58fAbOouJc zY!DEmx@!)@6>UhVPSutj9-zGzDi0tR!`Rx%LhrO+%MVCcJLR-MNn&;3n8jH6JkPq` z0~GVZ#@twX?K`+E^y-&`bA1TnRNopArVESOky6I8>hD~Gxguil)24Qg6;|CX#fMtD>H zxXrZS3J&G&`>!ohN+Z8H>}+IS>0xmwy_)l3@J!I99NY?{*Wthk{}a)gC+>1^1;?Kd zI>=OOdzkSh<(c+W1s73ge%$r<;F4SVyXhW|u zba#f-TwN_$CrLuVcxHy?N|9Q{OwSnDxMlV}i32eRa)yX>;E^Vc-f$wil2|@noZoZE zty7^1rjpoym$Ec{)>>=03r6EHY$6D_ZXna0cB>#gMqIDheD61mAS^(gs_eU7*ii*5 z+$hdWU1B4U-o?VYj$nN^UXA>i%k_+jN1o$PW;ip^c?FPu!qoJ%mpuH}BtF9bKid8~ zs;PEqAIGhTA_yoVU__-$l`4d)^eRP>E=Y$ErG!qzf+9+lUZwXAp+^w~q)81B2ns@I z2|@?~Liz6Sv~$k;`260_S?});U5m9z*!$jd&pmU^Tyw!BKZMeVKb|$BJ%qV+G+0MQ z=U`C4#UpMhnsa5aZvZUX59!}e;=W%3yyTdyRJzDAo9c)SSK9heMOiiRt9S)Yu zMZSBL19(t<6u-c&0&HkL?sna-2Baro{*o_UWYLKe|&zD+Yi(a8K)C6iYc}Jk8~=F2^RMl$e)Z z+Q>Xkpi|MR7WzsJffd{W(~mIpe(S3CQUvt8 zeF<$8F@PvcvQ?egENc0HYNf%G6owF$LQee1t)O67 z@0jY6+W=&nYQvw18$X#Uf0N_{HcGO44pM-ZrjDI)TIvoOKEd80Zqb(wT8DhJA3B#j z@t9sr*cm*EdyaeWw3bZw)Z1wb<2CN^EWG96^M>_7yVPb06l_l-_;8-HLm?ssH|jpt zmyf%TIx|?&UTj#JHwpc!N9w6J1~?vrKuzUq;cT2(u~t3q{kJa++ z{E)FqT#ELpZn2nr^z{-}sgxiIt#CMLXlb_5N=`4)8eKAZJ0r8W<@Uw#>80-VP`^I^ zX2FbB^O)|M&fHe)5f3ZS*NtN9BR7}Tul6n05QKQjvqcm+7WcyY$>h4)L;W7>^*EM5 z)ymG=C?~zMywmceQCaIU;aPki9+4LaQ&UuzZqHjn86Lf<&EntXOXdeL%%6}EjntSk zv8|LxC+bgWi!_{M6e*~}{s7F=dISjh7;->On&AC9R7`Ksi5o>*q$A)me*6PB#g`{v zZwk{6Sx#(#S0ZHme%zILm92rVnTTuOU0$3@wzq;0JZ|k7KXNK>ai!8_y&jdT*2@qW*su) z*8RD{5}O}6*#xX}_iKDAQg-o;_5)gOYSv(A6s*pqqy-ZcYm?>fS{BR@lk%Qg!6PcP zSEiyaHE?2hQtAU;f4Dem0WG8R_^C!NA#QAS|jn4$FKtxk> zy4g&0q#+I3VESDhFk`GHo-LY&a4{=b@pEVHfxV&udo>O}uvc%-$nX(qi=(`0!qgeO zc`4QwUmz#U5N5R)(~x;-oAJ?VuYz+#8EXU-;i5XhG692JZ!a7}Xv6{uK035g0nH!^ z6g29wR-=T;q3d5Bov+!g>&4>pVq1EXVa2fa4mNYu2bnWrNa)P43NIlut}SAfMs_lv83;$q{>)2TflaC_(Pb+2Wp3< zMX`?kMX_|Con^OXC!UIMK34=4uen_Ia9?iL-y4YN8`R|FEf8rJ>21ES%m$qs=KA7m zE4Z)M$3&iR!`7!iKWvN_6$5dXf3%S5_M4T}Q+7ImNU_GB!TSx@MKF=y+}i{b?PwRt z7ALZvj*2p)<4o>`0=5W#F3cxF&^GCT8k5PxO$nDB6jW$rDv1aF*Y||FfHqk-PCtg4Cin%b5*R^r;Q5#n3@E)Q063;-#wUCc%L$uSTO^2 zGC^OFZcysgd6=s>Q{w6;X7|SM=v4zWy1Yjgt-nbZ*hnm7RWDzOykcLQuIWWB0pikr z5qf!n0oT3R{-pQy^QF{=y={#8CAV*%8OoZjvM=OwNhYo(2D*)JY}^4)9eZKEj`7D( zbHLKbPBKHOBJSCUz`-l+$G+@nfeBNEA2_5SHL>wt(`aZ{*y2nsPJI5;>xbV=d;7x) zJWhF81!&h9J1?~}k5IH|I+z7X2Wg*nQhuu*CL|cUF|e^6T1y%)FlrRYO+y82zT3Gq z^>V@lKplf7YC&kH&SJw0^fq%UkP}s+c^4xCaN{t>P$1)NSx2N*4!1TE? zEzlWji_GOdSdh1xWF_eCem*KzZM>Hq{e4X{uyOM!grohf|K4-CDTx0wx!%lLcGOmB zeeW66Qq9SS)t#@yXTvA#u*TbB^Im(BwCBWv&7OmdA-=A8?FJ;kE0;{Vjn2q?*>WS< zVGaeo3#YqZsvw!@Ker+=3az9HO0a$6phI7xv!6v!qoSsk)?69$0FvaYpoHB-hd4gg zNar*;ZTiV8!p*i}ktAtdHMtwAY6>C?+Y$0>X)6<_mCA%;$I}~E(!?QZ3_iO>Wk1@a zrUXo?pAuQ-nIX^_1QaHW-MbAGz0VJ%L7#Kb)%kXdVtkE`~AG3Y3jdL_S(U4X(T3)yik@)fVHvEq4~EUZitd)&n;2OoI0z#J(#0j1CcC%! zOE`<>TP75fGFig9Pmj~7+@+Os$&T~z!nHM;AGMuxv6zd^*PEyHF2Bh&{_bfH!pumg z4+cJSJ!^bcAP?j=6u!Wnk)QRx#J%in(GszvDYa3fj2~+Jy}6X;H>t-l zPk$|Q9B~hKF^f#|N%aP=CA(iLC8ax?mz8Mf?Vt8DNBh0?Xi3&3FD*6?9`o;w*c7nt zV)#)tPNp&7Uh^zblx}xG#T8eYtq5XE>W!1>7OU1%_BGQr{A{iPwFg(QatYiU)7~9+ zSajPgj}>Xy z;+-&7cFj9uI3)tz!STxvRY9p;OAp_VcdHCF9y7oyS+(d6wd#M3N0;$fVueQ?ImT4g zGDH{HWuAm8Ne2tJ&J8@nFjO!IYcwT0e`%8N{XWRVHY!}? z=07(*@=|o&?uQ&Cgv$qd218AXCEn*S4`*X%`;sAZh?%vx^9#%oodXttPkWwz3^RA43<+Ma)#t1rMDd5e;e50ekmmf4w{RHQ6 z($f6SJNQL@?R=ZEkZ%sAb^qo?4jqx;zK7fV%Jp%75hI%*M3bt#tN^0s$QuZ?Tx&$>T9);n`!tDbH%UH0FeCa z@V|oM!^dt?XOqTnd!RISW+S+2?>qHNn2ngKi2A#Jbu-x+*5g{;&cn@#0dKA0ZpH!% zthprlQCZ&fYJtZk{CH{w(Ng%mR|`tYI*&B(&22Igi7y+bcVrpuL+b=|t~oD<2d-ZZ?*17ZE$zWx zrW{(77S*QZ;eUI9S5xoYH<_ukW;q1W=k>i(mVOz)V3TwHJ8}ic*-yaFx1=j~83$yy z$d+#y_e=Ujg-dy-tC6iu0c9zQU6HX|9`7BUEG)pf zE}EFt=r)CW1K;UHlS{c0Cskj$Hsz`#vpZgdacz9b8`s$I8d*9n1d=q65^~+c=SENZ z(3TzZqG^c>+4p~E;WW@#CM<|P-PZ%{RbPCSH$J6wXaFw2B>#R!uJX;h-AvZf*)6ZHDT+L;B%A z@DnvaTFqaQZk%816zLPTuuOOY61SApwN-pNv^xWGs@siyl9Pknh(o@xtu-?iRfDYD z;gtR!^c{LU2VFhErIdEUY{X1w06Eo#zB9d+=aZ;eKVKa2=ZpJ~H~N+TU^@d$b-=cN zbGDc1Z(dec`@VdQuQPn4*Ar5(g+=?N0*YG|SBH1IOU;FQjCT~Ed7+)G^gmA>K0FKw z#!s5fJe#ccD1s@Qr^O9bkUXK*qyA;|zzgJ0;wy`^E+cqO2x1-~R#O61LQ0by%$BvL zi(hbFCxzLBjrY(D94?>K?0zW_$x@g#6bKKSV5?0n;Vca>Wh@Jt?(R(QUd~9hiRkuv z)nk=8H75`~{{Of zj+KXTcO5Dh$}0v6gZZpu1o;nM%gn-YHhkd+`%lsG={!7g6AO-aeA{1!U%D;j5{0PQ zOeS!ynl3Qh$SO51&bPQOyiCwCe)}#90?2?Z2*S@?xtyIfQnp-~(f9NXc7>;M?od=; zOZ3pn{@t;gmi;m)X9Vl>e|I6`5hz_)(zIS*)Mk_L(UJx4?EB9m6)j_XhX%pp zMhz{l(C3eRXWX`FfzL+wph~zt_N&PTY`0(IUhZjJH|09G?yTxF=LgV%q zOD)6(G;*X_b=H(J_C~n)f(Ga6ra-75gw9j>81B#o#|O5Hw=)*h%(Rv;5Xjxgs;tM6 zab71hwqG}UxHw{$+XTg$5G03o9HaFNb*|?>5!mr9^+_%vZysE091U{rX27{XMRE>=f6hfQB_x*Lh zJx&twa^1lRT?O3ISX?pZr2s9P^dP~+EcdJ+&vWzxK5ek8;P9WYJEL2bt7lHUJZ;%& zb{9S&$7mL}5!U7>_wf@n@{)}H1^TN5WZY2OiQR$oY>8QUvQ>7*x{JI|W&A>OdsY1o z*50P<%2@8}Y~56$>$H32owPPu<~z?gFHY0*@6FupjJ<1FdXmPe#FeqK0VuPK?T!{5mk1Z>FrV z3ouj@=oJ&cR29=PvZTJ;Z(;yXZrK&P|Ej_AeNPN65a&>HMfOXbcThY1cDY=MF&SyJ zh_T}BbikgQR2^@cth*xQ#V!yHI8moDzkFLMGcD6(GtRvYEF%g*jyo0I<;+tT$M6&l zk&<_5X=`zGl|qi5MsldUm$FGhcpx0NDwk;EMuHxN2JYmvVXQbqQ3`^!zHm$X2Z2?% zc?mACxsJW?+ff$+=R2X#g?BtYIXEXG&GqNH=Ul>y&4n|ed%6Kbdp%)^kQYmuUIrx0 zkL>JI-W+UfIyZJIg7=wsjx(O?GCv zu3ot z2^*XDGu_!I2{C)V<#2uZ#X0}wSUt{C=ug48%(sJJI*vMJ$O%LIX9DSLe=i~vduwq= zE$q&xbmLyHOuL%cCP_C7wK9#OGET4Yg%Y=LOQ`(U2j5Eib^^~Ozno@-!C9K;YBsbM zy=>oAyKd2Ty4_H=SrL5OHY27IeX?^IXz)@ix)+(r_EJQw`_Wxp z@}txhY1Ed;3k~4^*6_x}eix0dTo?7FufdGuhZSx6M5Py5k5#0hkjd7IDEriMY~y6a z);h)qSr91n)ny==u%2iWDy}Dld4#f%0^qfTT{}q=L(>xr&F1X4HRX_Y>kUHS6JLumT-cMgYNh++V%y|6t*jZFbk~J zu~@IOdiHy~vG2cA8_%^h-?&Fo9U%xbpbs{Wb%ejKRr>IYd9g(Kc$D6d?k)r=xD$@t z-&0gp;}h+qdiIvQW6(<+e47QJ8hC!cGw!8VKJD#|Ryu1A#Lpc$L!*Z7bS>yBWsQ7b z(8cYtspMkcrm*Fwa$7@N%`LRgWO-I>U)iG1Eh7axnL#4Cr&+m57ZDfsLAQ?Hj6Ty; zfaDn83y+4xZh{G1^0)TG8Sl4=z`+(fB>JW3G1dLM+3i-aQNmlN?cSMGU#W=IR_byhu@Jo<4G}ZtS3-O*ReO0Tuti|{ z+1_;@i!Sl(p2TNH4=&WczZ>&Tw~XQBE*k@7u%6=ICKbV$-*CEvru^S=I^&bo#o*fm zJdRjNoc55A(__VH#=ypWsvC8kD~{{aYFF8uUqrWhc{TrdJ}zCB@$~VSsvn7%Gh1GP zpKDrZ+n41J-Jjn$k}U$dJOF+Z8E55c98fgAkh#~V$3$xh1aO^GV`pJHydK z$*G%A!K&+yTF*f?xB#*fc>7Zxa44!_Dyo|McyW=4;7F^1Mknx2T_AY0r)^zll<1AL zM`{l0A!0|MC0O{1<5gdb!kP?YVv|1lLoxb>$rNXC>dMk3KeMU(jph;QB|XbK;xil? z7Q4j?Wl3p0Qb^ zBbQ^(jIwERaTliU9KQ0Le43Fynx&?}R(I^9nS{K`QTa=cj>rq`=*-;hMos1aAm`| z!VBNFa#EJ^CanO`B0N8T5HuG^l}za{UuLxDtL$#3U#E$I%(RA2!Gf&!T35XET6k#T zf&J+q06h!N?`yPPS(%kZZZee&Jlh3_oE(j!4XGxBUpXPRX)g66kpW}Yq02BCHm7#i z?35|Ey-|WaE%#5=cwU3dgmBp1vOXnAuXFU#-m>S8X^#`Lf2Csn^`8|! z#OtJ+{#kMXJI})}^iQ1$>8Q2k@|qS8G_)Hpi=WaT1u$c?>$u2)2K;5GVT4+e(dfxf z_C2pJ3J6YJ{2mC=SuoQ?p!CJWy;KxmE?kI|{%(fZ(2--*YZ9(3)~MW8h4-P z8Yj!s?UlpU>&bpaiLh5o9ZVIZ^BO?r*y)y$g>Y?%I|t*I$V#tvEl$0wZCYTh^6+JC zGYj8m9U#)|VZ(6&bwS8@_*wfOOe{%kIm+vA2pH0E^-KjUD8_hB(|Jt{b8UE|dw!mN ztYT?&dM?mj;&SJ)s)V;+FH@0vl7K>BwN@M91Ao)rjJN2}!e%P#(E9D5+c7;33R#vs ztU8C{ADK0oFJgAI7po1kG)w};8t2s&EKTdQ4f13olZ7^=inZSEj(JQjH|QQ6Gxb`( z(%&DklZ|~V75$1OP%Q@c(kT~^>a%ao=Ub%IV6>&{dl6@-L&xd#(Q;q>X3epa9<68u z3vsEWuNk2c*0MXU<)@&N>S@)V3Cv&2@>7+b(g~gn_sm44k0R}tn#(OkS+*#y$~Vgv5n>vWvX!Yg^L9d$jrVbUW9(;4(>x+T$DvBV`6r6Q^AQfdK*9r6bjD8NT!j8}po9IwADu?oO_gUXwb z!PIA?zObx%Kax(%b+%6T&XPZKUetbl^vlIqm#(&bd7ewEWA!5KsjqY!xr0g{P0nZqU|vb-*A_Rw%I0x8!4b13p=KZ@ zo1Q!ZpA2vHI)mh(^=_IL)>@G@Fb%4GLz3HzXXY#bt;uM@+JZ}|shwou3T1}1#tOKH z%<0I?LG76+bl7rvsj)cXV_IUxa?)x&UE1;e8MLRU>i}JBEy2OMT@QUl5@^>Sf&sI% zv+Qyo=!J#1-Zz+jqQZ(v`|@dieJdeI#&Z)mfQvQ zUPg52^`uR2f`Aj6azom5K>`t8i1>+QA`N`nUaMv9}wXUFzRUlk42D=3jsH zM%M5qN=nWi68J@K-&O8ukSe8CWL4q5MQ0(xXf(8>Rbqt=tGSin`>u=-e~Gu!`3qWQ z8H38%3zpiQlvtwHgdKbHgEMesQ&DYzPu;o?o!{>DmN0p4Siv%dZpS{Ti1 zo;g8QvOaZ5hK8EY0saQNP!&OyG!#-AR@)#0+c7A(8BZ z%JQRaaT!xlJ1QV?Q~en>J|6`%H##d)lVT@czI}rnqML*W8|EwEFNF%4F?3j`CRYz4 zXa({WHt?Ao)(wXpbJj+2_34zJfF?HLo?92o#@44&;}Si#4(knOf6q87!|tEnsr^VD z+wLh|aj)|MP{g1KU?o^vL#lnmV@y#rvwR$hj24jkz)-?UpDwPX0_ogsPDpu(Y>O9& zZ;FS^&lmE6>}sOFMws3O?%N+bnoOle9s9@!-JX|0V(W2(t|Gbgu0^aDBI-_5S= z?k|)kPD(|Ld+qM`m!xhOj7Z4i-!0v^RL8%yAA9PdlzZ zmfeKlvCR2ECp>!sEWVIzVn&)A?J@&eLEOhb@GC!qC!HBHow_Xl#jMK6Av0VOP%YLO zf;}XyudIZ-Obc_L_e0Z6Suf69~73 zX7R}RZ|0Zv2aeryYM&1`;3_+d4_cPUGIEF$5yHc?bePfUxPXQ5jR)Z6HhhrIB6~m= zIt5o8TW4CvCFYKb20Bs$`fE z7sJlG?me|UBf<4*!nkKg1PhUFuNgB|>%3HiCvvU$@G%}sWOjX*dhzV|Sj9beFp9}v zsAB1DBQj2DH+vy-*F)*r_S|aQQ7Abp=hOPQQ5h;EaoTJxnujfBXX~PS^><% zZ1nDfu-BPp!4hyuJe$|8_@ws}ZE|JjaRRAk(&I>_BR7+f$a-63)?r|b@i?DCDaXtv zXcLWjg`ohw_SG^v8`4R{JwCY)pIqKv+KYh8)&LEy?cuOdVsAu5ocdzQ3$9%Y&}z9d zK*-0&_uhMDp+JKQWu$KnwEGm8vpO`XBH8g~*}4q${4-vTZ_7W>5_P$u_1y0r2Ks5< z{Ey%OU@$#W#yYrx2p%)8kgV=Aw9jq}f5M7`$pAgn+bOPJ;+UVY-&a0a6|MnbN`n4x z)}B70u@a3o%W8sWy%j42h7nA#hWKH;)QG$@TZ79}uaCJ{S>E2f$&3hqjZ*?P26`K( zU$+>}A3uN_IFH;Eznr})5STTx$d3&pY5~D_syrdO#Y@gB5$5)?&Jt6*Nht&w-M-w6 z4H1km9TtJ z4#>W(YKc&y;l|IFJ6A7fr@$U|G1<#B7qwj9;#^dZ1t?}RCYaq z5D01yJp^KC_&6w7liXKT4tO=2T%lvU_8jipcJA|0S7iV&s;Zqo%HJ z&LBCa=hB9lx}tha`1_GUj*3Jh_;g_Ou){+4^{Mr=P#r90G+t+5Ga01v`7~i|^ua5l zUq~2=mtoH-r9Am1td2IxMvQsBhUrIm5_iWsy67nkh14lukkMXNqr?0p=`+t(OdhkOwj zRN*$fU5#Y6t=%ix+)+d0MrgD7hCkbLHKx)Xz4>Jmgjpv*a)`-G$G82%%}z>PZcAwt z_3N3i0qKk9DAC61>VNqyTc415;fImC{d9X>A?6i`Mk!qZw}-3p`Q9N8^ZlOZ>sAA2 zf9t9B>8UFXsr?^D@Ll5qYSkhdM{a5>cGzu8ZXKrm(578liYK@f~(8Dvm$D4+h*peP>hs?DO zIm&1`=V_37s@|x%30Hn>-~uaDpEfZWV8*XN3IWm^j-G8$_a1SarATC~p#eH8mH`*OZobI?j8o$8`mIOPoRq=| zgeHxn$)&%W_wO)4d=ctXgw}NG7H)C2_oZKI2T(vGM1fzrHib0S!v=F?j9xst&i1BK z!M~IzMQTp+{K-8>AZCk_IN^xxW>UIppB}{0!<>914KC>)+2Pq5Wv}+|;J7P7p0z=T4z=cuAE8EL)>h=UrnR2=QY{Rx=C!p|9`haK5Ku*}1u(Y^nINQL#*e zuUYbr4A34%NxJL3Hu^#j^)+so3tcQhxs`E!Bjk0iP|?Ar%*wmgL zOiFL&Iyls|S=Tc1rHNY%OG(DonG-%%G~l3dx(e$fH@`TYFw5mlHwg@Tw0*)%vy1Ly zbNDM?^jA8z(ETmrPP%<-nQ;JpMab69D62sgYqlD@2Js46ACy_n8f6#gK6PPrOSNX6 z4ZNF~Vr_n{pKi{C+oKiHtkwJE&t66%vJ7`oLKHr3Du5lI?ZN9>h;Nf#pPNWz{N=KkKd_#%0lVnMwi`t z1oR|6IrBzmlnlAv=iE&LnDcA}Ryl#3ZWi%dY$C_g9C)IsMLenh^SU6~4I7A97A(r; z3|t7t6a&4pQdZv{N1np<1$Nc(`D?1dt_@cV2|5fd03k^;nOB3oDW`3^Vnby(YAIT( zg?CjC>542q;sFHUb$y3zo%pNYYEzUzlUHZmI6wz49 z7b&4#eP7frPc~%~4!CdwnOt#CQQa>-YM?SNJYyp`{MgX>;;DY>EBa|xndSUMEU{obU3VZqdU>!aXhq1IC74PQhGxls^khGlNF ztdZYU3}-xCy&$Km(FG(1pB||>3JRnnQ|W-Wy0wdP+{3(slOiKOxZsZ5v?LEHT$+v5 z17t1oXo?I*NW8PJ7a6_e5Kq1Xkqu$TX(N5?rdpqzE+Vg!_BPS}=b2yk3rTQOnj5LO z0EM=juei{@bw=&;HcJVS5su1%{&Z?8C@ z?OgUg%tcQ^V-(T&LSVXyBNg#%o|fz00BqJZrQFQz+!0)P{D8i)w)6J)xPgW~RgCBL zc%YuAxebK(I&y~($w0GOn)KaOQ9lK$fi3fTRBL?0LjC>3Vk=tMPQCyX(J9!nE+yErZKUuv&bU1$0LnxlVt_aJ+j)PXj?)%*M6r%1 z;5yIyj$^uO;ZLZZ8QH+y4y9X#jL~p;?PcwuGdS-MMJsc-|>4;@ztnZ^6 z<{t@{H9{Iimz;P;uKAmi({@Y*E@$8KU=@>lrEVvw3*M16tM*`hi%_4KZffOF z5;GGcIrvVFDw7ERGuq{HaA-rCi9!d#Z3r;G^sUoC99(Ml#)+*%|JseK^Xf$~Fld)M z3}4@taNG zpoSaZ%D|1-?4W%Fj93F5#PjCAFFjX~RUY_Yc8S30wWWZ|T|21lCdPOZ=D13IG#Te` zZ;yAv?la}xvh+g`Yd(IzC+x7+{ep9^liAlL$J9%2a-uB$OWYJHr4K4%8lfVx7o59b3$E0m1-dc_7HmV#4XSST<{c`Wg`e&FHhu z+53Q)_{5Y-2RKb|Emm3L$j!`4DE&Utx+R-t*9SA-7UEgF3$tzHn0XWv%u>jZ-Yt5m&ORE+gcxSWjm4if3cg8#V75W=WC*Y;ZtSl z&z~<99KO;#+cHfx^$9d*bk=LxF3|`GJJ6T`vs5c=#Rg+W6K!s#7 z2=JcXlg_dZV~TX%&v+51b+%>u3;UQr$JY~f_Bt+D{3Ov=?dCqMjT{N|5I$3&u!nii z@w)!-6+ExxhIGFxxB#voO8Tjp4Y<(l#v6-%ZgM#2f&ucuWYettnFh4;_7-elyN3Z& zZ1bVV%@SKviyGXBA1 zmg+CV=gA;_nfZ+S-u8qLKaZWz?YY^07ucwNE6#^c&&Bw;0eleLH;DGWle}xOt+&-w z5L4oQO*XfD-L(HpJ_O<3E+6Yqd3m7X%dr#)Vsx_1Ym`sg8C<7AH~1)%F%94HR97mk zw0+>^aMNqVd6U_ZrgJ@qpQ@})zweZr#d#)?1S zX$f^b>#{63NIUMbz0`zHHLPI-(U$xp4FRj)Iz7Ya1+O(U%sG1VNebM$mjuMLyf|#_ zQ!u2A*&Iu(qxD4!&N*U_EB%Rb4WcvHbaYvAO~8VXML(q!0^N|#+yb-FK*H`I9i zgzMYUfQ&H!vgi(B)kEL1{jrFVjFCBFVXl^0)ro5|y%5>sSKr${iL*f$<^_$M`4}V`y%`iU<1Z$S^+B2~s;Jb_<*}Wc4$@ zt;KomJl5+qQ1RC2wzM_o-*(zp^GFD9T>x^3xTm$<5Iryz_c|`Wf~&q4R*}=Q7vzfd z661iO_{T#DAgDOd z(U;1|u1mL_&Yy?vxbE%EMem0DIlY?NsN`f1L?!o9Is+~?R!QLmlhq_WuCC_vcW+L4hlxSOM=IlcRQ z4|a+tv-atvZUTUBr@G7Gf}RF7q=IMQyAhn}0Sf;{CDYc>9zkSze;t$j;&oNbKHoNo zpwrjqnWBdg{RroP5H*{n!a;sh`A+}z~ zL~RK@460gwbZK4&s;c}ois6-InNW#i$f+Dn+V&MAWGy=z{UjlMme2gEXbY#%Mv zZWfhMQ!-7y2vo0k@eg8I>|VVCtw~P0ZBR}y3wN`q@ydsKhhu~q5yg*GCO%y^m&>*y zuSZXM9IZ;?EDcg+6r1xKoMT7G8J^9<%8^F3zXL2YTpwO)F%NfPh3c8|IV8!10F25BH z;1qw~ORUI6>0!(Tz#srL?ymiqWB227C~~Ww(~>NnkUIx-6nz6lGEh2-E`RmG1YLEZ zy%$v_SpJbp%ZJ8t`<{XEZdYFn3MkP^Kov%MGlH&3p89Z}^}?fxR{(nrN@1^^{Na`R zaKwA6dQ3DwUEFC`n?hdm&~`N$KXdFl!DbK>xlNuDi3$2bXQ`2-Hw6|>NB+D;n{h$D zgKWB~Fw`8*kZb5_CQ26LJC6!~HvU0pq+6PKiJmP7G!8qv{RX3h#ORwMI^R@&g8-Y{ z_Jp&VrD1RFIi&tkwzO-Y(QhrmBc%@%Vqz}EZJyv)?RW|7GJk8ZZUXVkYBia5dm)@| zpkF4GKh2qxa$8sWW5ZiCU9H<2_zgwnK8AjeD_(eoOqj*SIj_ zIp_sj#F&QdlGiJHslFFP?4qNM2htUh0-$+XjUfw4Vh6@fZ@3FoWa3MCPCWDRT0TyS zbww)Uy`9l3nkbYmGu5&2tqn2l`!6io142y7Qh=By&1}*=#5d>>ioXfNY>Ap|($5Rq zqs{uc!V8596=tRw)eP00nbp z&rMf*Np-V&DWT6fs8jY(4Ie-onqrKy1y7NB|Aq+oCr~xH# zvb;;yn|x4}{e6UGqi8=7sF&=)weh(9y>5fEK;M!9O3v81zz5*&H@)eYmmtXSaq8_i zwqZ-=@5c|*o$hI0lvgWFLkJGC2>aqZqK6n%UvZUv9Qck_hgH>hfG&yO&czSwajw2k zj&6AfhsL(q?Z0hfBdp!+a}26dkgB^`9pkXe&K@(zzfL=!G|_)+BKBkLQwvmJOrca+ z2RasFNk|-f^vss-qM&Ft5)76qA-kb>I zv$sG$YP|?0dN(`UAJ^2EI|`ZT9NHwysf@5Kck{_7)|nr@X$fV&!CI4|Z%zSZ_*XWi zZNMu<=;p9;r|hen5KYOm|K983ik8ostpDf$*+Bu&Q~9eOrejTcs7>qYwxr(ST|Ai zf~p^=@FV9-2mD`0$K66~#C-x1THzg?&U}L#wd^HmNu%`4)K^I8A1nS=G-?;ieV2JiOt_b2eEJWq zRs&NQ1aYz+dTXka0ay?C-HMYWlb%a#pVpXqX{s#8`Wmn7%lDTSKX{?R{cXjl%B?1? z%9UO%HGIef;O}Jljjefa7@SV-3xE-)ZY{SY1vQjZ2f4_Ry*^=itkkE=gf(sEj=_zef zb6guf&|;AYHBmCKq^~e}r-=kVdHY887sm*2zIak_8i$+pv;Wc#zixV%pzylNr6lu~ z>5#_cwB5Lgj#>?G^iFu|nHDafwQO!y1S(vtg-{RJ$h_*Z{@O&4gGyuR3#~!kVb)bH znMK8CwzW}`SuY$l(dB)P}saJvu{|@Y$9Ti+N#i^*M0omYU|36Bd=trC$5oX`hi5V zI;lS_Ni5n2xm3nKjIr`gg)2o=xubg~4?|znH)m*9-rXeAizh zGp;&0!Nn-Walz|R;l^G=olKxRA@SX|YqO>TNDiW(|DI`tKd8yX5;A|Q@nx2)_+E1T4d*(`02&kWT9ym;PR6+0a+3DpS5QEL15N=X zyI(}oe!&!kgB6(R3xw0Z@zs-|wIM(4&%Y z=dA!Yqpr(2>uQ;ziRrW6YV=~irgW}fq{R6(#}EJ2iy|R>f4|+mn}tHho2=Zp7QMG{ zPeXR6l+(T&7`HX%R%aVN{()>s8WnJ=^9n(}4#n`@$GE4o{Ai3i#ooAdr(mCqjvB$y z@DBrYyb*Ny2)YbmfsMBiv$R=F@E+yjnPB&QTV$%>R>1y8+(du(F}AqeW)_R}_Sp)t zsCnK*_?}JX1bXnFE);H8S^W95|B_>kNz@rKaW^{#g^0%6n(UG|_4a zhZ+bgCCteTd-emkYpl;SK<&Zhm`Zo&pZMdBzm^=4b?`c)fc)6lF4tix2Cxw-1BZD- zy|E`up8vx>UcW$DZ;5|j@2>zSXq{t<)E)9;19*gwQp<=O3EO`)91{QhY}w9`r8GpuCW#|4J%PhyZnebw&d794jB2jD0p4|;9FV# z^;`cwM*g>t{f}Sy1&QvTKlC?OZy%@J&y~MeVMCGc+d|LUbthbp*3GjUtF&8dr4jDv z0r=cn@j@l#hCLL;c%@U<{=sTfXsE6Lr3Ra4KwELsw>jvkVq*0ZWtxv>zo*4GaIbw9J2@bt$e; zn3eay2dXs(sKQYGdqBzVKn>H`d#MERWvIZ=r8?K>|3P>CkNXeCT%p`L?B8zv-K`g2 zVvQ{qhTQ-KTz5cFe#B9y;QSw(_+M`AZ%pRjt~TGI5LR)@f08+GX8~dHkmE-K!T;xX z8EnZC&0zrSF|4~wMgb^%JMwjZ7x7X^gf6D1+0w5XKAznt4I@Ap41d?g$(3Az6Yog{=r z+<&M`6Ne1#et!E9q{A$1DWG=AH+w(zY)YV z`w3Z!z%*g5hagYUlVvTHzVmD3jpA+pb|Evcl);=o+=|G6pgfVT8oKqRz9)kU-B z?;rY?t3kOi4?5_e;yLmM+8)fx^bk1GwG*#;V|*RnjJ$RJZ_EXJU@p4K#ohYFQrve^*k5OZ>T$`b01s*Ma#i!Vg-4IX+Ru+?Q2%F}KkLX@ zZtX}w7QMe{lEoqA3Lw2{|9`h4oDWB>Uv!o>;;-$JN%0{eRAu%(Jw5InS6&R=7A!dR zUzlLC;eiRN{$+xK^5ME7T@ktIx90_wHXWKDowB@MrH_P=mqm%qwUjO91Ms)JB5 z|4OiamTKqMcem^SMSD^BUH~w;*phZQyzoCA$iJNgfB(0?Hc|u-MuPq_n!o?O|DRv= z|KH&Qf4KDj^%MS&xaoiW$$x)RBPGh)?fAc5p5eS>+`1{efJ|_dr;K11Fc1fjpjVzA8!Xz{CVP^{yb3g{y?^dsNed0lT2T-lwP>>Ttbs?z(HXA zx32o1yw){J$|9owhtDeNJ9+wo2t}$~h@v5V`r?hRMf)dtQZBOxuIr~)WwicR41uNIN=4SEjRQ-!9&QTTfxvGrcBb`oH<|uSCj568m!=)^_oqG29~j_ZW^` ze4I(`pH|@)-afdL8Hf-QfLZy!Q6q*Oj)%~Jt{GeU<{0nu8RmARdSk{M1oLnF7Y>Hn zvzg+7U;hz6m&#MZuylyuwO<4HzuYg12#SvF>(oEI`C`=(H^q)>r@KOm0RcW9MLB0 z0_2J>S2^^>0W#r4C$|5uy*H1idfnf~3z;&NB4yl}qR1RFMkOMpY;$DDTr6Z*B$|*! zLShM#gfg?RlFCfxd8o{5ndh}wp7*ES-sil2-@VWCbzaY(zw=M$>^l1Fd$``$`?{|C zzD54=x`+R?OQVrj&0~fBS`_2EPKR?INl)~7!@zyAyXg0_Tkko&N1l+q{#l3lhxPVz zzyjTn(GmRlcnxw<3yhA^YX7BhW#aXWw#Cn;rsMz4-jtnB{daSa`*W=Ld4pk!{}?Oh zq5;9><>kac4gqUmi0Iy+x_`dcQcO&NSivI65%UYwc`IB)7vH{yaLE(=7@EfIi2>h=63BT0aqClqc zlCJ#cdYPNwH4y$*D7$#y_jR7}cAj zHyR|W&n6^dn3a|y8eYifO$^Enpf1TTeLdB8%y*^7ZZwdF`??&4gLX+K&;fG>QN0*= zluaw(eaQ8mX8%sU)ysxv0*+Y)&yE`9YJ{Ko{(xyQm*qb!tpD*7pSqul}27TFM_ElG9~U*_iXJZ@5lrm*{PaU2g!vftBHE>wO;o z?qx;1rtbt(41DLc>hS0`UByPnXn+i(nUe8DCX?%cBTwmA1#fw~6n~?RQ%3vTUpM-X z7YP2%)S=HH`@;GBpY(z6f@WY6SI>Q5X4vV%c#GUw^sMjT>))t?*Z40d98Xb-)#uSE zolJ{IB(2K>0eS;QUL^*klcg)Yj)>p_itFfb5HnLdXHmafXzIly_D)VutC$pzepRsk z&lm9h54IiHiy+SLhyKE5S_7luFsY`wm0PO=UZHezw2_D$P;1><-%QJ=_DLi@9ZSI@ zViH>Cnx$y=#)CU<&jm3oL6Uj)p-^H(ZIlqros%OGr>r;D=FJCpixsPy`^|q=#XtGy zdnf*W+8zE7RQ_UId|!8isxUctdaCKe{h2K9YTsFBG5w~Lp)K;Nq;10i=I`vdwY%_l zv2&%I(4H+@)P4BbYGKfKwszKkRnd+vYMb=ztKrXnbm31dGsNus-Fr0J2DiBqbzZhd zLnnRZb0~Yg;%wRB)j{7JzvT|YX&wKyDxPyV|ILvOLyKr=5ezC86J~41;$6iiQN4!> zepHZV77JTkG!rC1J?pWJckVsrGZtrOJ=0$@J;oltr|>Z8wSI~<415fuY`5ygjHiAJ zwL%NMjybhUlWn2SgP5W&bKhDfMFOv5PI;l<#(ed{Qk!1=B|Rkp8ZJ~s`=rdPcMYUj z91hx`3xoco>Ey$_8xv<#A~k;F6EB;)Q%0J>lKlx9-Jb98nwy-mPO*mUZDX0)rEmN; zm%DPWxBQlqsHePcJ!o4r&K&V@3e3%AvaXcSz(nuoN)2YW=dhuyt)0W=#k|v zi&fJuGcPj<=e&+9CR*<9d0ubLeWwe&M`bMDd^5_f_rYfg`NF&Qu}|cv84vQGKYxA! zq{%uH*Kz@Xjtf|Y9*%*V+;Qk3xB6p)^7+=ad*#1BIP7$j0&RZt^Zwi>1^ijr_~mdc zujuB8kji{s+_k(s_s-@9s%;jQC3=K^$iR~izq+Dvs=@(J6Zy&1DE>H|!b~jP+}aPG zEnz9!h8h#wp(WQl{GF!mjh^`D%T3B8C-LLIxoU1MWLj*V18~E9>~SMa7u6%mkcrD0k~~f{f!cubHPJzVFvrg>;vKDs`2X z4dUe8ZQnbQPqIE3O%xLReeCNg=H(cv!2iHn{%&uT|Getc`Cm=Jk-+!Ys_X326UMrF z!dB*5&hU62w(7o0>;6QEIFVvQ-dOinhPxE(Vn5JJB?ofc{1`TP^jbxPSZ%PD8AwsQJSlnWh!t)<(VZBxt zGk703B%U!>Tf!jilK*1-&@Btht?zUQ%QM?Iu?P^qa)w%>Qom@>$MHUX4@)Z_^X~2S z3a{yYuKdG)t^9N+>|lz?kM;()T< zeu7nd&TxLC@zcR55Af)Ju;__9aN~=(OMzQQW{7PHWFeUK8?o{=FcAwqdO>3JAzyL9$IcQ=KJb2G7D)8 z`mOJ)xkZsuSTyygx-w|1bG3OYGe-=Bd8=9~QB#kO%Y(`BE~Q#&=i{#&jaLyrr?@7H z4~B1DY?zR;U1D?Z?~z7LTyPzsGYNO_?R5+GSg&OEAAUrh@4oi9=lmK5M_$g^d~zYX zg6QDA)N*Fi<}Rxln!9I{3uISG|BYL-nGzMV6(mBkD#o^*O-k~C7OXIp(Gbz>i5p53 z>Ix&4eA@#anV*iD?Y~;qun*k61iorc_b|gD07gSHu?q+AGr8odP!hj8Yns+Y60-Fw zq@L!H%$$lklP=F;i&9hur~leF@A`gmvz{jdVYW8`G<>n5oYNm0P>QnwbCoo1-yaJ3 zcn|0^Ai^Ec7T>b<6&}_T*G_jI&)B<@Z!|yX*rjegsfFr^mXz|2aD^#h*VISH`)*9b zwG1!;F5ZNB>#>qbMHgHN(Ob!)&4-KpY^k2X6c5CE{OYJC z#b!`GUz60VPCu5&I3Dq~yN)V&-z7`#?F~@DA8zTt50I<>;-Mb*W~FlC^F69M7w&WH zui_`!pAI>mP<&oteI#mhb9HDn%~`c3Su=Db2DkA_u6c7DtmmhUa{cejcpt`B>v-9m zSpgety;7Ln4E3aJ?&I#eSoo8p;*wWC>{ut_b4JruHZzL)!-8`+SNeS0Y^m;l%&!;2 z)_})1+n5b8J^ASizZWJAooA_ml8v{qcy~2WE~Fm5iXF9FcJ+A^oE0eyGC6=;{o;hs z-T}>NxT>Iy?KcQRd^zNL)C4YXwOmakr_e} zZZnGjfIj#Z*kui&CvV_$V48D8D4n2$nm~ep6fyMtqeqWk2E`Cc{eJsr?c(oHS+)7+ zJx18S?lD52{y^0^=pY4KA_!- z<&E9i9NLPumW(S}VTT8h^-j(Q$MQZ7zZm=BK63hOh&tF0mao<6j0s_n`{~E@;Iv~n zU9qIv9ze{(cwwTEQ}7|VYlewuX8aqjs|bFmPKAwN%ne$u*Uj`c!4IST?#;^!O5pie zHOV0FoNZ@lmYfkZEKvCi4eh9$jt6JV1u-jk+n1;7VZ;!RaiQmS#6{P2m2bhUBPK({ zXE9w=RNc$;UaiqwJwhOs_c`S~DylymenK7EmxZl9?CLE{rr+Zq7@eA`;`Eey>pZiR zdoS)im7Ka+H*xke2S2_2ISKwB-wqbo5H38lA8~t~s`Xg@WNECRyTa+H)3K%w*Mcv& zY1KAm&7V*fbu9278nOup7~0I69`Hx=F#F@N6^>}}=Gj%!kH^d8tl@8sw^57%6IMZ@ z;7&&)9D4&^X?*W8P#PH@@AyV?5ABf#^jx#(K<+&j)U*3O%Tro$k~UVQJSzUCiQZ!z ztV|DkZa4@e44-Nz4fAwStIIld^FqR6MecKS^V9UivkGFp7-x@J{zmk5hvEj9an-cw9Z5NXvd4FMO2AWf^POdtHB8~nae3HpQi?%LA7(qV;(L^#K`0^Q*6YYW z3B9KyJ(8rI3m3|mCx#u6_?SPx*f6z4p4|$eW=GLE+8o5B;lxTt_Sz3sl+%5E zDr(rosdGT+Xv)?Dyx6*%BQJiVvFhe*EKx^1s`~fB(;3 znIJ<$()-uY42LVuMCc|c#Nj)P*ipB=x<6B^hjYsRhA?>3B9M*adiGwtOP_zKJ&>?C z-W;C!l8R7lv+W0$IamnJ16^!a(eo6EL2Lq-8-bS z4y^r|2fp(&s!eS}A=V=1D(}epI3|XEzsT9-B%P{zj5g5krD1rGm3LfG5gx>TeDb%0 zQ}48MNmfPa#C$#u=MM&~DFvV9;9+-)L5mf0!Vu14$xO>G?!?VxYq;b{ z>|wTiO9gZo)`K1cX_Z;fg$g;G=GrDWJiV}!g72FPXCMpwlDN+}`d}b&8D9@tV#Bga zU*UOfzq!hO`*Zs*4s|_t$BCNe@g@ChaE1=YZ6J$3!V5*qfYP6L#RNEE&KlB_h!+sY0s_B3=H@v z>|zXj{YdPf@!Hyt$E==3d4^C~6pzqMo`JztbEd7t>5)BgFX;3F0w$kTl_Fw{asG8n zPmJ)b#gQnzI!$&n>q2(aK>DH*y*aLlpr5YGC4cjKFlw?@pjdvilrd>-=i8@Tx-R*N zg=XEn51pHZwmG`{T#?b8Vg74$MRo?*(lL6RTS6RtlshgEaZw+$H2l6{Ae*i_x#9|xxN@NGm*A*J-Jat3-VbZOcK`vS+c@>cnC~wjKLE+0t!`i zJmfhV@7eNIf=jCR;#jiUT}hj|REECmG>X^|YD(#{cOcmoVXpE`q2H0TgjkZ%^mzNn zheoB|hEgALx_*YzsHCK1hJ#ba;iPnauivvZePzAoh+XeQf}O{O}tL=QfnVHJ9YpLD)JBlTN#6qq=G|Nwlg-KLvv|x z(z|yQkSfL~&$g$dSdX!IVpAu7qCra3FdG%0JV=A3^v&5#f4Y0+?s7iq;X+(^Xo5r) zcXhyh>Fcjva+%%fGHaiGqNGByZ#Ega3q4!?mGq4lU<;bwsN7}Wd{mv6UqJNG_S!1p z`JW@IPQf>*I)1}CUp;AwmVrZ3!??gn%CyON><^;Y|Tl~fjop0?|a)P>-jb()M&?HaR$ggn*sFt1! zz)SYFZGw>+3Hp@Omj#Guj2R=PzpWFa^lA3R4Eh~4J3^a>M}xd-p+NY{XIU)+nWQ`q!D@+S~{;;+F=EkjDS8{8WeG;UES_f%eyb%s_xqW z{47AN#Go*yb4f!1oAoUH=W3j(Yy4ukl5iWiu zB<0X9Y!8%%?y|>puguc=^6ZdBZ-L`@uL|@N>6_~)nw0HOTZ9K|{-(;y&nfjM#PY#y z|N6ZA8bv%Btct~E;$x6?BdM8Qvd7UfW@s2KEH0*?NNP(hJGv>EPhQP_W!Ib-S@+T3 zVj{h1Wv^$P#I<{Lm+eW%E2G|#UTwFS?&J9h_q+b>TRt$&t=Q4hwPbygZi&;NX5tjq zes)wMfN<3-VIT=I!($34Tb2aA;8)hWCkUKTLVDq-J@xQLC>ZS;Q`+fwcCm?>3fm>=QfKoJgQ;^^Yzf|P63rVZ7>yvu5%AEMz79A zr-DD2>osf6N41wCOmKVc*3aH7cpopkb{+yxh(z4bqtE)g`t-l~XTlW7f{LK~wb$wY zzS8|)e^TIYEXe1?sh*{#+PYU2rgM zFwj4xi{(BZ*9z^Lx#K`%T@GD$f&`Qful4d&2k<`DHo*=gSA7KjXVh~U$8+loqvzWU zs`E%8z0E7*aAT7H2@0UqQF)PmIZ8m0w8W(yaH)z%zccPaGe3KZ&;0#;GQ^%fOm1$@ z2Cy@-=WeSHM*i?Ct)E-rJTHIk2YZYwq=YXPIvHK^SenqE94yD;dZJUP)kO@mj{*ro zh<-7;NBzwY&1t*UM!RpjP;;J-4$Ab01AMX?is|0{?v@BO|0$5p=C5JgO2&80KaZZ= zUF!9D`DEthOhGBTCLJFll5`xvdR03=?iF`AW0&+>!|GM7(hAJRh|t0(R=pFojE7_& zREU1|en?c-m^OTI?1bFk@|l00<(jY^u*|}#)I-#GtpU1QApjlhb@VbXzXL9qCRsA_ z!*FHL=$p)JH?)vWEHA*R!YE{WZxEH&wU!iNbEkLau>+=QvApP&jF6^2dnb+u!*B z>b&r^^zLH~dp4teu}=NP`b&?+Ed1AB9W~CcaS2%u>lL>Zdb{sjQYyoBF{W{Ok12za z;XAv-;GwU@y~=S9pLcEmp{5aka%T`td~ALUj85$;*t3*B$#AV|4imkIFE`L@a{*hPm=^4a8%f2 zk!01IOw?;JsI=z1U#f)eWcT}Ptyv~=TT7me!|&HJZ*zu=qNSaC+p|b_3mm)ZU@vZn7|wKL#)tw1Zi$v*SpYq+czKwC zPoGr6^(;vpBDEtROS3VuLO){0#C=j2gI`Z4#omnQ!2J{8{R(y|#)CsQ172=b8*CBJnQtl~o#gs6IvWt5< zOZ&ar66_Y}+1iOJe4eMfIdp0(=RKse`tJURt7QHbCQb+dN9s>eIYV&4q;?QB zqoi6)Y~3YMgl8b8;D*|lQmrZ;1>rKc(YlAFaSDvg%ad)@&f*Q;r{xnBZzK(?6#1Zm zhd;E}k6mo%il%Cg-OW@2=1NZSv-`|E&$yL*iam~PbOmh2Y%Ij!P#yv6^%mbN43nvN zQ1>lKlU_eSiLq)@+iF?HsmS!FjZdrO`D`KfIBT)HP@-*RKqCFrR6qdHxyl?foiaNJ zvW}}EtG_?^BY=lBvCjH~DOcfgGigLhC&|@ttK#o7a9A1cUgj7jCRY^cllB+XL)e0w zway1zQOzJFvQ+j28GRjcL=93&JCztU*0q5nBax6Bm$6>6vTDZ+)pl;PeE)=XW{SXw zs6ZC3cPD*U2*rVk4{3ua!{~zfFN&BZ+_OP_^?EZ3oc}2(E1#twB4EjPC7J1Zr(l=q~J_aDzU$c#M*oAp+pEWO<0ZQgkuLT&XB~YxNc&LBGPTZPuDKT?Dg*0#$2jRnKmb zH($Zem9RP_>l7h!`Aw{a?W!E~bVSD?*w2#Z-x%hI_4M>)d*^bf=U0M~YKj%jF8&z3 z%)VrpWjxvmLX&8oZKE3=0m6NYAHl~K0wl3$F!G1gMY2@E#L;645m#t}X($0JKXVF? zgfKOzKv?ZO9;CV0@lT!|7NJA1r_G{|nH1L4i{qR+WoVNFS+El^?a+={ZxWZ?IxZW2 z4!(M)IZnkSxT!2D=)vN`LMo`pt@-ws5Gk_3iGd7I+7T>7M4 zlza-u&Fxm^t5`$Xj!8DBbMbGHNwf9W9xJh(v^K`y-Q?Ek(xha)oOwlN7GD=FdVcP- zUC4F7?e%MVWTdg?D8j=Y(Ws;um>p}3f0R`|?h^(&-dPuyY$el?X5lXOXk>2K+RfKJ)lOD@;4hBf#4WXIU#nPMRxEMs%Kl=JlLI})jQN`7o~(AS?k35| z>tFc0o=|MFRR5IEX>P-`tCd4piVE2;svu$3rPfe7GBR>D<@BTe!eLc((KqXtpi3hZ z*f3gz<@>>eqt~x~U|LY}TYI-O(NZ<^BA>IsqLqs+i18rmD{r{jUPf+LD3{hYHlKr` zR_9eeBxdAEUXQI3Q}}91!yse~Sw-unK9}80bIeU^bBM4A{fy294ZF~H+BXDQOmgxx zuk#ADL^gR-+g*L&cG8y#32ZrLmwb7fZImY{$a1Z34Kw+|PHTQCM?4|7f`(|?r#U+f zR(ylQ9Ci#0$o**lJKCj{G<_;gI1r(sqh;M}EBE!v%LIWztH@=*8p^60gmL0aug2DM zOHJ85ovnP=W!IbmwO_C}8iR4o>>m%~kV5(zTeGbnEhX-Dm%@qM0`$sDgq}94j3un= z+T6F9?t>=|DWG1PJT`vdh9$5J(H?bMAC0L8iLgIg>f2t1;ED2b2-FN5RjQ2B zm5xo>$(3+Z=hO|)`6^?W>+*hZcs--M(@@VO*tWTGAFxB0;u`t-9i{v zrnW*>#kcOGZ3%25=VS}{+}?n@l)UO7GQ{toGGN$lq4&7AS!GLj6h*C$13a;^)B1X? zF+tG~^`XoRmBz6->Ww{WUe=x8;t`Ox)2wS81^g%%njb{N%oiDjd6^2Q_dV$=NyaWC zX&2O&cmr*PaihqSm(1MXI4igZB{)}Yxs<{V868*ID#0s&KGkCM=+7xpQGxu1;v7wR zABPl*ck_|l;C0Vg8WX}i{o?cboNTJZhEKo3`lwF{upxoB)@iXL~XS!*sLL zRT0*W=qB%H{LNG`2j9HW!zeF=T5NnQ=rCegSXdZ=_)+v!M`pZ_;_my7Bm1Gu=lgd1 zKcwJ;99BZV*ARN_Xa$+*SOSgMXIwZELF2wPymR*GZo{J?s%LLdCmYzkwM+mty+Qin z0{5;U=7aomLtI-W%?9w2{j$-0)SyY<8I`t4!*(&SAs}SNCUvd3pvznos?v>^#T9ri z9kNP|KKRb8Q8;{v>LchFl5rX22EJ^*?N+B7prLYV)~RPO2Vaj>rK#!_XxU*B@IfMw zMRfm`eGnHpscY@ZyiB3RK&gj-oLrYocLf^DahxC$xi5J|I~SY(rj6q(fq8_Nl=Ei@F871;>16L2x_rkG;~*DK6Sf06R7xg;d9cI2_W-05jM(wR*i zJJmdu8)mp9XbfT&-Z^%@eF8zq>iib9maTQnmOb@rjPFdzcV>;QdsD{)$Q!ZBrGP8M zgcR;sBrgPnHP#GZM|3P^X>*gN9zpvh26NZgSC>Mn}7zd-vq3lQZE0l^m?6=+|?=fu8%e&v*H(~HSY z54IPD&V8jG3ANSKBl#Ce+$U~)aNf2m^d83AN+~jH&rci)s5N;B{?U>3_4O|Vh;lyS zXfS_H!B@-uyv>NGjGeff0LuTdxd>p=#Xl$xP7Ke=I`kmz3njNW(` zmrY}brK1+vyGuPEgG=>!9HOum@er>)UcSuLIvI0>eN#@;jyews#|RygbIH$~kvc%T z{ykKDWClCy+4rYeY0?$<1+t+sMdrXEQ>AWOTiLRE@zVcGK>OePw++>wz5RPCbzN$8 zpp1gxmAF}vlZc?+Xn{*HAWOj){yl#msxXK^a4MRuy3hQxTEC^1acGZ#W78pn_u)7! z))S$y8b?w1&SQez?;%Ry4dcj4UcgezfB-NSo~jljaqUa-+!L3tc>DmiXoFy<63Q{5 z0vR!HfC-I&SA!N}WpnD@QyJ-|(M(R<%6`QrMR!q6mq`1@ZV@xr4qAN`IPsv@AZ<-F zVezO|;t{cu59dJ&i0<$251xh#pZk8Vat2!<&uo9f&dJGi3|j(Puw!qbg%)FQaJjOK zaq$hy2Az|{@;u8b>MAeD<|1nJp;MDI>d9qi;^>t>y)X!g{%DFBa0U#SPQgqr)jzGj zvuk&%Q%~-UbRv4i$SBwPA$$|`Zg`oAm1-5TK6r_HR+TdOQecC3d&*%dw@kVpRn9%g z+BTR}06keK*L<@VX~-fpf_xr>1Z0Ondp4+Ey}Lxo2W{39IBn3fp6KCbL9Dh|V&9l{ zLsA7@IT$k!j5uy1_vw*KU*0&XlHt-f(k=bV0 zFV4(;IjPj@AI&izClv9`!Y^lsfp0gEx4MZF(N^%&D+(P zXHzHLBF&VZmA1vbbJl@0Tab-U-YQ(3?pwv~pPk)Wd%P6^C|aGlo*nSU2yJyRZ+kfr z)gwQd%os2kJa$KQe~EkJ?G0X&8*}MAv@J#~DB0k5+W!6(6TDl)Ry2%`XnW@1*J{ZY zwp~5}bkyo7lXicyf67zj*xo_iv-yy>P!Xg;L6EK1%xfU}wY{lXvg%pW!#V1wbQDYi znRIpYVLZuFan~)3VYZsbr zjJSbLUmQFQ$P@EAkOe(-M@`e0>r$G|lhOK^zK+=D;dzuNM}TK$?3cIO6Ao|RX60r1 zE_{)VHuwd7@axuJO#+h+&D4~t9@ns=o#h-Q-|A2RZUa)GUemipI2BN5fxYeTYT)t1 zH1QH&Z+_QJI&ekWzV%Z>*~>gzD)t2ZYb%7F1&ATY6l*0T+R(xM3ck0=3VsL*`mHZW z;JCpnt|+#6fMjBAp!2v0zY~0>#m))ZVI-7VwE5%=_Y?SQO2)rBpSawO^~&+SBInY7 zpns&ZO}_bcb^I?Ydn^`d^)y4-N`Blw2J=zc968Lf(i>Ia*qfgb#-%80y9l*y1Xx>G zhIV1H6ss8|YA?RG`W&k^z3c8d`l+th`iH?|mbl+B_oB%!iBF!felmIX4Rw;w>dY~S zI34%E_!BewTtw*C4J|8Zc9$yqIvhON8W&Gs&@A+an!b?2vsvRe+g-YolDY=AMuWrC z)d?{C_azuriveSdLeAS{gLYN+y-va80tgZVKJDu5_{GODGeGX|O~agCRXw#eu+@ek zTkBP$sVVP4KT{TWUdJvaH&l6%m{8=@BWbdKm4ku%5r@fMq`xv*w|iLynZIk7NI8B(&7^O=WAty$9Il1nSWqBK!p>1RMk4dF z1+z)8gFc$A2Q-kt@kyEsknXaw%@uja%D62G$fBwW%D68r7<}MBbYF)Hh|~fheENV| z64er?hPoHgS}_sEsn3>hWa56=aAXfQ)XrAP!?hNoSC3Pr9!?OBA<<8KE}YZwhJz3T z+t#Dbih=lA@(6bq+(4Es@1!tQ>%e=Aht7@GMW5e7+3G#;`(QQr-gCNF(tBa#tX&+W zBo3brd(d7~?DbI~&x^LX!Kb%elQB4qN5veK7M(}2RrJ*;37CwP*Y*QRysEsZJCZKN z(*;veJKwvOvL8g22xDGRQEiYho|fuIW&xdf3KqXmk8y71&04}{l~i7$rz>Q<5V#2y zujronf(9HjWD_?hbRO9z+Ne5!50rr>TtQMfguD+BsItD0bS_#%(z+(%tgqkJ=BiCg zvgQ*a+p2D{^$OFsI3qo55mIm(CM4X)bi98&k|T_zVi4O6mSvvV*5T*iUBTgpWH?mE zZ_gvU-47}Q@9b8Th74i>^ZSUs{x7cB`58bO4nw@{YTds~_HBO1pRM`6K*Y?eK|Uwg z_ad@f|2G`*fBa>tdyvn$ufxl)#>)@03jvXFDkR+n={yhozHM%ye?_c@myn-3eh&G$ zKnnyV0wW+zrmd^1d!mIQzZB&CR6uQ8JrJv(KnBI;YR+dncKsZ22OKK93{VFMQ6Q@H zX=$N8VkJ)QaNlZeq3>D&czg=~Kdr{+GOwUOq(yC&rw+w{K!8#-DZcU2X|Sv(i(T{O z>seoYo*ZEYdFt(s%7qsIfWg>9V3Jz$BDX^!zOOsYV{Ro8FlU*)tLXZ2z z&;s3&Wmg(8Ng9Oksr^7{fnp<)s2nSG*0cGPD9gS;^FtpeT4wA-vUM<2!q;ESC6?Rj3XGyz$Ov1Is_1XGva zkQvh3>(Xp((@tV3uz)NwSsv(gSm5a(L29)L+Nc2um*>5e; z>>*L+=#c_LCwBZ?XVx+MJ(G%IO*1AJ_I%=1 zwB2voYsJoZZi2^b$?SK?9Ufs{X5fn}n?-J)K>Pyyy8ysmeq!Pr#v8jeEEE`#M2p|< zmc0El2eLCTf(rIFDm^D!UB&-=EEWPc071%WQK~&_muTe>3>5)b2f-A&{;eF z4SZPXZ=py!2dl0YrUC`_|Or#iVUe zixU&i#dqwWUDMJ~Gx!5igd2ZO7n*`Eoq`WCIuS_2YG8u(%!0N>@S7}`ffGC9K5&Gp zsOV7v2gklvu8Ws1MLO?YGy(x*^6SeWX{YWBkkUdfS43*=Jk1`V-NDk=zpBEi;KN|m zjIqzG5_%5*JfwsKBtcC7#96I!W?-(<#I+Vmg=xHrONoH3;BA90yuw{!jq-}un}j1W z_N^Mo8K?}CqU?Y7u<1|;xX4B!c?_`r_VfVN%*?wB4RgI-0Pzx5WEQtof#h{v2H09l zmlJdc2c-0PI?k|xYGS9*S5FSG%`<0J3&;~ zrT!(1<1idoH&QS7h2{OST*4a=PAMj4nG_*8X?|!$Jv@xR68TQu`_x|GGGI*2%*`ps zyAGlu3nN1}8hRFngVHrr+*Y2q97aKnUVu1I?f1{217i?;J%{K9MHQfJ1A5_%Tp{Ov z5(0lN@oYeZ1pROPW_ztB8i8MoQr6Xos?NSubbjqq+5VFHH<(f^)S6>fR$zN{E-!I7gL2`f4*mj?ZSv zcHrCb9iz}--$#ipI=3pd-MKNg&ab^yN;MCuEK(3l- zI8V7S1g%OcLTzxZZynuNzMMeKPgfKknCz?Lt&WJ?`#Xm$|2 zY6e%2B^tc_y=8g1iaJ}e5xR=d-sZ$am72cJT`G`PKE)y9VDL`C#mtVf#b`G{5{gYM z2UBuC%|OgdIxJw%A}tlj^x@T&{&A{5vQqe?d_W4EULr4z#)fBQbj;%+)r!YSC}>7#8Hau z>D;`$1e0zsa;re{IOgr`J;pvheU!%-XZ4xBwz07zIW8du%-}SqtzL*1U%zq6?!Y{| z3QMvY3zw=hJGw^BxtAZRSPUEuuiN}y?y^dnBtkcGjHtcg2}v5cpeIPspjQOp0899; z=BzWvCJ2Z1v&9`$l$zlrkS>US6|tAdpaW{ynYNHFfRFuyRgk3>8kyKQhF_5(50Y2| z`m5?$C}xf5ww>R=TbHX$0{Qa+`F%Ys!v3UXkl}$_ohJxNrPL$mH&dk&`aS!cvz?#* zM0X6##VkqjayMypGqF{ClA!8Y@;yY-tdG)DJ_6Co(0#OY?kB;?Vmh#h*o_~j zQA93(fs~?4A9NeSyV=E1cm@0!$W!;1auJ+FkxQJw`6$Yiz$oMeQ;VRIye#1XYru}g zoTlRBz=PN_aO#qAgR|mLe~$Z;opz8tk9EB@YHPSUeBb*{UeLZVZpy{5Je zv3wwt9`7LT>O@S@PSOO;UFlo%HHRO^4o42=K$!hBgz~37)BAQyeK?8i1U Date: Fri, 28 Nov 2025 22:06:33 +0800 Subject: [PATCH 02/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index d811633..7166de6 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -1,13 +1,13 @@ --- layout: post -title: "Blaming Hanging and Complicated GPU Kernels Down To The Source Code" +title: "Tracing Hanging and Complicated GPU Kernels Down To The Source Code" author: "Kaichao You (vLLM)" image: /assets/logos/vllm-logo-text-light.png --- -Several months ago, we wrote a blog post about [CUDA Core Dump: An Effective Tool to Debug Memory Access Issues and Beyond](https://blog.vllm.ai/2025/08/11/cuda-debugging.html), which introduced a powerful tool to debug illegal memory access issues in CUDA kernels. That blog post itself is a huge milestone for debugging GPU kernels, as it can faithfully trace down the exact GPU kernel that caused the issue. Prior to this, due to the asynchronous nature of GPU kernels, people often have no idea which kernel caused the issue, and the error message is often misleading. +Several months ago, we published a blog post about [CUDA Core Dump: An Effective Tool to Debug Memory Access Issues and Beyond](https://blog.vllm.ai/2025/08/11/cuda-debugging.html), introducing a powerful technique for debugging illegal memory access issues in CUDA kernels. This represented a significant milestone in GPU kernel debugging, as it enables developers to pinpoint the exact kernel responsible for a failure. Previously, due to the asynchronous nature of GPU execution, identifying the problematic kernel was nearly impossible, and error messages were often misleading. -As more and more people are trying out the CUDA core dump technique, people also want to get fine-grained information about the GPU kernel, such as the exact line of code that caused the issue, so that they can fix the issue quickly. In this blog post, we will fill in a missing piece of how to find hanging kernels first, and then proceed to explain how to blame the problematic kernel down to the source code. +As adoption of the CUDA core dump technique has grown, developers have expressed a need for more granular information—specifically, the exact line of source code that triggered the issue. In this blog post, we address this gap by first covering how to identify hanging kernels, then demonstrating how to trace problematic kernels back to their source code. ## How to find hanging kernels From ca51b3c44b34dae960ea94e12e406da0b2ca0eab Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:09:29 +0800 Subject: [PATCH 03/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 7166de6..8046a49 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -11,11 +11,11 @@ As adoption of the CUDA core dump technique has grown, developers have expressed ## How to find hanging kernels -GPUs are becoming more and more powerful, the computation power is increasing exponentially. However, the memory bandwidth is not increasing as fast. As a result, the memory access patterns are becoming more and more complicated. In more recent years, flagship datacenter GPUs start to introduce asynchronous memory access patterns, with complicated synchronization required when implementing high-performance kernels. Such synchronization is easily prone to race conditions and deadlocks, especially in a complicated codebase. +GPU computational power has been increasing exponentially, but memory bandwidth has not kept pace. This imbalance has led to increasingly complex memory access patterns. In recent years, flagship datacenter GPUs have introduced asynchronous memory access patterns that require sophisticated synchronization when implementing high-performance kernels. These synchronization mechanisms are prone to race conditions and deadlocks, particularly in complex codebases. -When a GPU kernel hangs, the program will typically freeze or become unresponsive (even hitting Ctrl-C cannot stop it). One solution is to just kill the process. However, this is not a very effective way to debug the issue, as it does not provide any information about the root cause of the issue. People have to blindly guess the root cause of the issue, bisecting code changes and running tests until they find the root cause. +When a GPU kernel hangs, the program typically freezes or becomes unresponsive—even pressing Ctrl-C cannot stop it. The most straightforward solution is to kill the process, but this approach provides no information about the root cause. Developers are left to guess blindly, bisecting code changes and running tests iteratively until they identify the issue. -Can we do better? It turns out we can. There is a feature inside cuda driver called `user induced GPU core dump generation`: the cuda driver will open some pipes in the operating system, and we as users can trigger a core dump by writing to these pipes. When the core dump is triggered, the cuda driver will dump the GPU state to core dump files, so that we can inspect the core dump to know what's happening inside the GPU, and most importantly, which GPU kernel is hanging. +Fortunately, there is a better way. The CUDA driver includes a feature called `user induced GPU core dump generation`: the driver opens pipes in the operating system that allow users to trigger a core dump by writing to them. When triggered, the CUDA driver dumps the GPU state to core dump files, enabling inspection of what's happening inside the GPU and, most importantly, identifying which GPU kernel is hanging. Here is a simple example of a conditional hanging kernel: @@ -88,7 +88,7 @@ x = x + 2 torch.cuda.synchronize() ``` -Directly executing the code will hang forever. We can enable the `user induced GPU core dump generation` to debug the issue: +Directly executing the code will hang forever. We can enable the user induced GPU core dump generation to debug the issue: ```bash CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ From 8bbd332a1aca6c1198bc0a25aaae7a28ea380814 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:12:07 +0800 Subject: [PATCH 04/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 8046a49..ead3701 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -90,15 +90,6 @@ torch.cuda.synchronize() Directly executing the code will hang forever. We can enable the user induced GPU core dump generation to debug the issue: -```bash -CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ -CUDA_COREDUMP_SHOW_PROGRESS=1 \ -CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \ -CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" -``` - -Then, we can run the code and trigger the core dump: - ```bash CUDA_ENABLE_USER_TRIGGERED_COREDUMP=1 \ CUDA_COREDUMP_PIPE="/tmp/cuda_coredump_pipe_%h.%p.%t" \ @@ -109,13 +100,13 @@ CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \ python conditional_hang.py ``` -While the code is running forever, and we suspect it is hanging in the `conditional_hang_kernel`, we can trigger the core dump by writing to the pipe: +While the code is running forever, and we suspect it is hanging somewhere, we can trigger the CUDA core dump by writing to the pipe: ```bash dd if=/dev/zero bs=1M count=1 > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276 ``` -Here we write 1MB of zeros to the pipe, which will trigger the core dump. Simple `echo aaa > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276` might not work due to the buffering of the pipe. +Here we write 1MB of zeros to the pipe, which will trigger the CUDA core dump. Simple `echo aaa > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276` might not work due to the buffering of the pipe. After we trigger the core dump, in the original terminal where we run the `python conditional_hang.py`, we will see the progress of the core dump: @@ -140,7 +131,7 @@ Opening GPU coredump: /tmp/cuda_coredump_hostname.3000837.1764236276 Excitingly, we can not only exactly locate the kernel `conditional_hang_kernel`, but also the exact line of code that the kernel is hanging at. This is a huge improvement over the previous situation where we have no idea which kernel is hanging, not to mention the exact line of code that caused the hanging. -One slightly annoying thing is that the core dump pipe's path is dynamically generated by the cuda driver, and it is not easy to find out. We can properly use `CUDA_COREDUMP_PIPE` environment variable to specify the path of the core dump pipe, so that we can find it easily by looking at the file descriptors of the process: +One slightly annoying thing is that the core dump pipe's path is dynamically generated by the cuda driver, and it is not easy to find out. We can properly use `CUDA_COREDUMP_PIPE` environment variable to specify the template path of the core dump pipe, so that we can find it easily by looking at the file descriptors of the process: ```bash $ ls /proc/3037675/fd/ -alth | grep /tmp/cuda_coredump_pipe_ From 312ae9a3c58d9b4eea1ddc30e5b49df00e95054c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:15:37 +0800 Subject: [PATCH 05/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index ead3701..893e44b 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -17,7 +17,7 @@ When a GPU kernel hangs, the program typically freezes or becomes unresponsive Fortunately, there is a better way. The CUDA driver includes a feature called `user induced GPU core dump generation`: the driver opens pipes in the operating system that allow users to trigger a core dump by writing to them. When triggered, the CUDA driver dumps the GPU state to core dump files, enabling inspection of what's happening inside the GPU and, most importantly, identifying which GPU kernel is hanging. -Here is a simple example of a conditional hanging kernel: +Consider a simple example of a conditional hanging kernel: ```python # save as conditional_hang.py @@ -88,7 +88,7 @@ x = x + 2 torch.cuda.synchronize() ``` -Directly executing the code will hang forever. We can enable the user induced GPU core dump generation to debug the issue: +Executing this code will hang indefinitely. To debug the issue, we can enable user-induced GPU core dump generation: ```bash CUDA_ENABLE_USER_TRIGGERED_COREDUMP=1 \ @@ -100,15 +100,15 @@ CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \ python conditional_hang.py ``` -While the code is running forever, and we suspect it is hanging somewhere, we can trigger the CUDA core dump by writing to the pipe: +While the code is running indefinitely, we can trigger a CUDA core dump by writing to the pipe: ```bash dd if=/dev/zero bs=1M count=1 > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276 ``` -Here we write 1MB of zeros to the pipe, which will trigger the CUDA core dump. Simple `echo aaa > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276` might not work due to the buffering of the pipe. +We write 1MB of zeros to the pipe to trigger the CUDA core dump. Note that a simple `echo` command might not work due to pipe buffering. -After we trigger the core dump, in the original terminal where we run the `python conditional_hang.py`, we will see the progress of the core dump: +After triggering the core dump, the original terminal running `python conditional_hang.py` will display the core dump progress: ```text [01:39:15.256278] coredump: Writing ELF file to /tmp/cuda_coredump_hostname.3000837.1764236276 @@ -120,7 +120,7 @@ After we trigger the core dump, in the original terminal where we run the `pytho [01:39:15.292128] coredump: All done (took 00s) ``` -Then we can use `cuda-gdb` to open the core dump file, and see exactly where the kernel is hanging: +We can then use `cuda-gdb` to open the core dump file and see exactly where the kernel is hanging: ```text Opening GPU coredump: /tmp/cuda_coredump_hostname.3000837.1764236276 @@ -129,9 +129,9 @@ Opening GPU coredump: /tmp/cuda_coredump_hostname.3000837.1764236276 31 tl.store(x_ptr + offs, x, mask=mask) ``` -Excitingly, we can not only exactly locate the kernel `conditional_hang_kernel`, but also the exact line of code that the kernel is hanging at. This is a huge improvement over the previous situation where we have no idea which kernel is hanging, not to mention the exact line of code that caused the hanging. +This approach allows us to not only identify the hanging kernel (`conditional_hang_kernel`) but also pinpoint the exact line of code where it hangs. This represents a significant improvement over the previous situation, where identifying the problematic kernel was impossible, let alone the specific line causing the hang. -One slightly annoying thing is that the core dump pipe's path is dynamically generated by the cuda driver, and it is not easy to find out. We can properly use `CUDA_COREDUMP_PIPE` environment variable to specify the template path of the core dump pipe, so that we can find it easily by looking at the file descriptors of the process: +One minor inconvenience is that the core dump pipe's path is dynamically generated by the CUDA driver, making it difficult to locate. We can address this by using the `CUDA_COREDUMP_PIPE` environment variable to specify a template path for the core dump pipe, allowing us to find it easily by inspecting the process's file descriptors: ```bash $ ls /proc/3037675/fd/ -alth | grep /tmp/cuda_coredump_pipe_ From 43d5b8541b46b8f2c5e499c18b936ae67bc3f925 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:22:58 +0800 Subject: [PATCH 06/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 893e44b..7c439d6 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -140,7 +140,7 @@ lr-x------ 1 user user 64 Nov 27 01:50 98 -> /tmp/cuda_coredump_pipe_hostname.30 ## How to trace down the source code of a complicated kernel -In the previous [blogpost](https://blog.vllm.ai/2025/08/11/cuda-debugging.html), we mentioned that compiling with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable will embed line information into the compiled binary, so that we can trace down the exact line of code that caused the issue. After discussing and debugging several real-world issues, we find that the default way of showing line information in `cuda-gdb` is imperfect: +In the previous blog post, we mentioned that compiling with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable will embed line information into the compiled binary, so that we can trace down the exact line of code that caused the issue. After some discussion and debugging several real-world issues, we find that the default way of showing line information in `cuda-gdb` is imperfect: 1. For some complicated kernels, `cuda-gdb` will fail to find the correct line of code that caused the issue, even if the line information is embedded into the compiled binary. 2. Even if `cuda-gdb` can find the correct line of code, it will only show the last line of code after compiler inlining the code, which might not be the actual line of code that caused the issue. C++ code heavily relies on inlining to remove runtime function calling overhead, and we need the full inline stack of the code to understand the issue. @@ -177,7 +177,7 @@ index = torch.ones(10, device="cuda", dtype=torch.int32) + 100 print(data[index]) ``` -Run the code with PyTorch >= 2.9.0 (to be specific, make sure it includes [this commit](https://github.com/pytorch/pytorch/commit/dae7710bf2561e9e8a8dc76fd30c68e25bd755b8), otherwise you will see an error like `RuntimeError: The specified pointer resides on host memory and is not registered with any CUDA device.`), and you will hit an illegal memory access issue. +Run the code with PyTorch >= 2.9.0 (to be specific, make sure it includes [this commit](https://github.com/pytorch/pytorch/commit/dae7710bf2561e9e8a8dc76fd30c68e25bd755b8), otherwise you will see an error like `RuntimeError: The specified pointer resides on host memory and is not registered with any CUDA device.`), and you will hit an illegal memory access error. First, let's run with CUDA core dump enabled: @@ -205,15 +205,7 @@ Opening GPU coredump: /tmp/cuda_coredump_flow-matic.3756036.1764250282 [Current focus set to CUDA kernel 0, grid 4, block (0,0,0), thread (0,0,0), device 0, sm 124, warp 3, lane 0] CUDA Exception: Warp Illegal Address -The exception was triggered at PC 0x7ff533bb91d0 void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, lon -g)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorI -teratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native: -:gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef): -:{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{l -ambda(int)#1}) (IndexKernel.cu:118 in _ZZN2at6native16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIterator -BaseEN3c108ArrayRefIlEES9_EUlPcPKclE_EEvS6_S9_S9_RKT_bENKUliE_clEi inlined from IndexKernel.cu:37) +The exception was triggered at PC 0x7ff533bb91d0 ... #0 void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at ::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef< long>, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayR @@ -233,7 +225,7 @@ Next, inside `cuda-gdb`, we can use `info symbol $errorpc` to get more informati void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}) + 11472 in section .text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ of /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn ``` -This gives us more information about the location of the error. `cuda-gdb` will unpack the compiled library, and `/tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn` is a cubin file that contains the `index_elementwise_kernel`. The error is happening at the `0x7ff533bb91d0` location in the cubin file. We can use `nvdisasm` to disassemble the cubin file, and see exactly which line of code is causing the issue: +This gives us more information about the location of the error. `cuda-gdb` will unpack the compiled binary file, and `/tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn` is a cubin file that contains the `index_elementwise_kernel`. The error is happening at the `0x7ff533bb91d0` location in the cubin file. We can use `nvdisasm` to disassemble the cubin file, and see exactly which line of code is causing the issue: ```bash $ nvdisasm -ndf -c -gi /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > output.txt From 21f28a2f6bb3f77f85fc3987892266491fcfe0ff Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:34:14 +0800 Subject: [PATCH 07/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 43 ++++++++++---------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 7c439d6..82a6ba1 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -140,12 +140,13 @@ lr-x------ 1 user user 64 Nov 27 01:50 98 -> /tmp/cuda_coredump_pipe_hostname.30 ## How to trace down the source code of a complicated kernel -In the previous blog post, we mentioned that compiling with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable will embed line information into the compiled binary, so that we can trace down the exact line of code that caused the issue. After some discussion and debugging several real-world issues, we find that the default way of showing line information in `cuda-gdb` is imperfect: +In the previous blog post, we mentioned that compiling with the `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable embeds line information into the compiled binary, enabling us to trace down the exact line of code that caused the issue. After discussing and debugging several real-world issues, we found that the default way `cuda-gdb` displays line information is imperfect: -1. For some complicated kernels, `cuda-gdb` will fail to find the correct line of code that caused the issue, even if the line information is embedded into the compiled binary. -2. Even if `cuda-gdb` can find the correct line of code, it will only show the last line of code after compiler inlining the code, which might not be the actual line of code that caused the issue. C++ code heavily relies on inlining to remove runtime function calling overhead, and we need the full inline stack of the code to understand the issue. +1. For some complex kernels, `cuda-gdb` fails to find the correct line of code that caused the issue, even when line information is embedded in the compiled binary. -Let's take a concrete example to illustrate the issue. Here is a simple Python script that can cause an illegal memory access issue: +2. Even when `cuda-gdb` can find the correct line of code, it only shows the last line after compiler inlining, which may not be the actual line that caused the issue. Since C++ code heavily relies on inlining to remove runtime function call overhead, we need the full inline stack to understand the issue. + +Let's illustrate this with a concrete example. The following Python script demonstrates an illegal memory access issue: ```python # save as illegal_memory_access.py @@ -177,9 +178,9 @@ index = torch.ones(10, device="cuda", dtype=torch.int32) + 100 print(data[index]) ``` -Run the code with PyTorch >= 2.9.0 (to be specific, make sure it includes [this commit](https://github.com/pytorch/pytorch/commit/dae7710bf2561e9e8a8dc76fd30c68e25bd755b8), otherwise you will see an error like `RuntimeError: The specified pointer resides on host memory and is not registered with any CUDA device.`), and you will hit an illegal memory access error. +Run this code with PyTorch >= 2.9.0 (specifically, ensure it includes [this commit](https://github.com/pytorch/pytorch/commit/dae7710bf2561e9e8a8dc76fd30c68e25bd755b8); otherwise you will see an error like `RuntimeError: The specified pointer resides on host memory and is not registered with any CUDA device.`). This will trigger an illegal memory access error. -First, let's run with CUDA core dump enabled: +First, let's run the code with CUDA core dump enabled: ```bash CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ @@ -189,15 +190,15 @@ CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \ python illegal_memory_access.py ``` -The core dump progress will explicitly show the kernel that caused the issue: +The core dump progress will explicitly identify the kernel that caused the issue: ```text _ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ ``` -From the kernel name, we can see that the issue is caused by the `index_elementwise_kernel` in PyTorch. To locate the exact line of code that caused the issue, we need to build PyTorch from source with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, and then run the code again. +From the kernel name, we can see that the issue is caused by PyTorch's `index_elementwise_kernel`. To locate the exact line of code that caused the issue, we need to build PyTorch from source with the `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, then run the code again. -When the compiled GPU kernel has line information embedded, we can use `cuda-gdb` to open the core dump file, and see exactly which line of code caused the issue: +When the compiled GPU kernel has line information embedded, we can use `cuda-gdb` to open the core dump file and see exactly which line of code caused the issue: ```text (cuda-gdb) target cudacore /tmp/cuda_coredump_flow-matic.3756036.1764250282 @@ -218,14 +219,14 @@ _18TensorIteratorBaseEN3c108ArrayRefIlEES8_ENKUlPcPKclE_clES9_SB_l inlined from 203 *reinterpret_cast(out_data) = *reinterpret_cast(in_data + offset); ``` -Next, inside `cuda-gdb`, we can use `info symbol $errorpc` to get more information about the location of the error: +Next, within `cuda-gdb`, we can use `info symbol $errorpc` to get more information about the error location: ```text (cuda-gdb) info symbol $errorpc void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native::gpu_index_kernel >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef, at::native::index_kernel_impl >(at::TensorIteratorBase&, c10::ArrayRef, c10::ArrayRef)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}) + 11472 in section .text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ of /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn ``` -This gives us more information about the location of the error. `cuda-gdb` will unpack the compiled binary file, and `/tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn` is a cubin file that contains the `index_elementwise_kernel`. The error is happening at the `0x7ff533bb91d0` location in the cubin file. We can use `nvdisasm` to disassemble the cubin file, and see exactly which line of code is causing the issue: +This provides more information about the error location. `cuda-gdb` unpacks the compiled binary file, and `/tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn` is a cubin file containing the `index_elementwise_kernel`. The error occurs at location `0x7ff533bb91d0` in the cubin file. We can use `nvdisasm` to disassemble the cubin file and see exactly which line of code is causing the issue: ```bash $ nvdisasm -ndf -c -gi /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > output.txt @@ -243,16 +244,16 @@ $ grep -C20 7ff533bb91d0 output.txt ... ``` -Now we can see the full inline stack of the code that caused the issue. What `cuda-gdb` shows by default, is only the last inline expansion. +Now we can see the full inline stack of the code that caused the issue. By default, `cuda-gdb` only shows the last inline expansion. -A bit explanation about the command: +A brief explanation of the command: - `-ndf`: Disable dataflow analyzer after disassembly. - `-c`: Only print code sections. - `-gi`: Annotate disassembly with source line information obtained from .debug_line section along with function inlining info, if present. -- `-C20`: a `grep` argument showing the 20 lines of context around the founded Program Counter number `7ff533bb91d0` . +- `-C20`: a `grep` argument showing 20 lines of context around the found Program Counter address `7ff533bb91d0`. -In case the cubin file contains multiple kernels with the same Program Counter number, i.e. `grep` shows multiple matches, then we need to further filter the information: +If the cubin file contains multiple kernels with the same Program Counter address (i.e., `grep` shows multiple matches), we need to further filter the information: ```bash $ cuobjdump -elf /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > elf.txt @@ -275,9 +276,9 @@ $ grep -C20 7ff533bb91d0 output.txt ... ``` -The main difference is to get the cuda function index (the `-fun` argument) from `cuobjdump`, by searching the function's elf section, which is `26a` in this case. +The main difference is obtaining the CUDA function index (the `-fun` argument) from `cuobjdump` by searching the function's ELF section, which is `26a` in this case. -Note that this is a simplified example to showcase the usage. Real-world kernels can be much more complicated. For example, here is a complicated inline case: +Note that this is a simplified example to demonstrate the technique. Real-world kernels can be much more complex. For example, here is a complex inline case: ```text //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158 @@ -296,7 +297,7 @@ Note that this is a simplified example to showcase the usage. Real-world kernels /*7eebf5e9eb90*/ MOV R34, R26 ; ``` -In this case, the code to blame is: +In this case, the problematic code is:

@@ -305,11 +306,11 @@ In this case, the code to blame is: A line of poisoned code in the attention kernel.

-The faulty source code calls some cutlass functions, and the function it lives in also gets inlined by upper-level caller. In this case, we find that `cuda-gdb` cannot correctly associate the line. In fact, it does not show any line information around the error location. But even if it shows the correct line, it will only show the last inline frame, which is `File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158`, an internal inline expansion of the cutlass function, still useless to debug the underlying issue. +The faulty source code calls some CUTLASS functions, and the function containing it also gets inlined by an upper-level caller. In this case, `cuda-gdb` cannot correctly associate the line. In fact, it does not show any line information around the error location. Even when it shows the correct line, it only displays the last inline frame, which is `File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158`—an internal inline expansion of the CUTLASS function that is still unhelpful for debugging the underlying issue. -With the approach outlined above, we can uncover the full inline chain of the source code, and carefully check them one by one to see which line is guilty of the error. +With the approach outlined above, we can uncover the full inline chain of the source code and carefully examine each frame to identify which line is responsible for the error. -Warning: to get the max benefit out of CUDA core dump, line information is crucial. It is recommended to compile with `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, as this will transparently apply to all the compiled kernels, without having to dive deep into the compilation script to find the right place to add the flag. However, the flag is so transparent, that if you use some compilation caching mechanism such as `ccache`, the `ccache` will directly ignore the flag and reuse previous compiled results without actual compilation. When compiling from source, please make sure to disable the compilation caching mechanism. +**Warning:** To maximize the benefit of CUDA core dumps, line information is crucial. It is recommended to compile with the `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, as this transparently applies to all compiled kernels without needing to modify compilation scripts. However, this transparency means that if you use a compilation caching mechanism such as `ccache`, it may ignore the flag and reuse previously compiled results without actual compilation. When compiling from source, ensure that the compilation caching mechanism is disabled. If you use Just-In-Time compilation, please consult the documentation of your Just-In-Time compilation tool to see how to add line information. ## Conclusion From ec41bcc51beb59c14d0d38b15cef5a8e72609059 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 28 Nov 2025 22:39:06 +0800 Subject: [PATCH 08/11] update Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 82a6ba1..bc2b9e4 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -314,10 +314,10 @@ With the approach outlined above, we can uncover the full inline chain of the so ## Conclusion -This blog post introduced two advanced debugging techniques for CUDA kernels. The first one is to find hanging kernels using user-triggered core dump, and the second one is to trace down the source code of a complicated kernel via tracing down the line information embedded in the compiled binary. These techniques are powerful tools to debug complicated issues in CUDA kernels, and are especially useful for debugging illegal memory access issues. +This blog post introduced two advanced debugging techniques for CUDA kernels. The first technique uses user-triggered core dumps to identify hanging kernels, while the second traces complex kernels back to their source code by leveraging line information embedded in the compiled binary. These techniques are powerful tools for debugging complex issues in CUDA kernels, especially illegal memory access problems. -The vLLM project aims to provide easy, fast, and cheap LLM serving for everyone, and easy debugging is also an important aspect. We will continue to share more debugging tips and techniques in the future, to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). +The vLLM project aims to provide easy, fast, and affordable LLM serving for everyone, and accessible debugging is an important aspect of this mission. We will continue to share more debugging tips and techniques in the future to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). # Acknowledgement -We would like to thank Ze Long and Sandarbh Jain from NVIDIA for their helpful discussions. Chao Hong from Moonshot AI helped providing the motivating example. +We would like to thank Ze Long and Sandarbh Jain from NVIDIA for their helpful discussions. Chao Hong from Moonshot AI helped provide the motivating example. From d06e19fdc4b95d3e51ece06a4aa1abeca3afa934 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 2 Dec 2025 20:59:18 +0800 Subject: [PATCH 09/11] add explanation about ctrl-c Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index bc2b9e4..7126985 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -15,6 +15,8 @@ GPU computational power has been increasing exponentially, but memory bandwidth When a GPU kernel hangs, the program typically freezes or becomes unresponsive—even pressing Ctrl-C cannot stop it. The most straightforward solution is to kill the process, but this approach provides no information about the root cause. Developers are left to guess blindly, bisecting code changes and running tests iteratively until they identify the issue. +> Side note on why pressing Ctrl-C doesn't work: pressing Ctrl-C sends a SIGINT signal to the process. If the process is running Python code, the SIGINT signal is caught by the Python interpreter, which turns it into a KeyboardInterrupt exception and queues the exception to be handled after the process returns to run Python code. However, if the process is running a CUDA kernel and waiting for the GPU to finish, it is waiting for the low-level CUDA API to return, while no Python code is running, so the KeyboardInterrupt exception cannot be raised. In the following `conditional_hang.py` example, if you want to terminate the process via Ctrl-C, you need to add `import signal; signal.signal(signal.SIGINT, signal.SIG_DFL)` at the beginning of the script so that Python interpreter does not catch the SIGINT signal, then Ctrl-C can successfully terminate the process. The downside is Python interpreter will not be able to show the error stack when it is stopped by Ctrl-C. + Fortunately, there is a better way. The CUDA driver includes a feature called `user induced GPU core dump generation`: the driver opens pipes in the operating system that allow users to trigger a core dump by writing to them. When triggered, the CUDA driver dumps the GPU state to core dump files, enabling inspection of what's happening inside the GPU and, most importantly, identifying which GPU kernel is hanging. Consider a simple example of a conditional hanging kernel: From a9d9404962e8edd6e3712163df8486d47ec5443d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 2 Dec 2025 21:01:59 +0800 Subject: [PATCH 10/11] add explanation about ctrl-c Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 7126985..b902ac6 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -15,7 +15,8 @@ GPU computational power has been increasing exponentially, but memory bandwidth When a GPU kernel hangs, the program typically freezes or becomes unresponsive—even pressing Ctrl-C cannot stop it. The most straightforward solution is to kill the process, but this approach provides no information about the root cause. Developers are left to guess blindly, bisecting code changes and running tests iteratively until they identify the issue. -> Side note on why pressing Ctrl-C doesn't work: pressing Ctrl-C sends a SIGINT signal to the process. If the process is running Python code, the SIGINT signal is caught by the Python interpreter, which turns it into a KeyboardInterrupt exception and queues the exception to be handled after the process returns to run Python code. However, if the process is running a CUDA kernel and waiting for the GPU to finish, it is waiting for the low-level CUDA API to return, while no Python code is running, so the KeyboardInterrupt exception cannot be raised. In the following `conditional_hang.py` example, if you want to terminate the process via Ctrl-C, you need to add `import signal; signal.signal(signal.SIGINT, signal.SIG_DFL)` at the beginning of the script so that Python interpreter does not catch the SIGINT signal, then Ctrl-C can successfully terminate the process. The downside is Python interpreter will not be able to show the error stack when it is stopped by Ctrl-C. +> [!NOTE] +> Why pressing Ctrl-C doesn't stop the process when a CUDA kernel is hanging? Pressing Ctrl-C sends a SIGINT signal to the process. If the process is running Python code, the SIGINT signal is caught by the Python interpreter, which turns it into a KeyboardInterrupt exception and queues the exception to be handled after the process returns to run Python code. However, if the process is running a CUDA kernel and waiting for the GPU to finish, it is waiting for the low-level CUDA API to return, while no Python code is running, so the KeyboardInterrupt exception cannot be raised. In the following `conditional_hang.py` example, if you want to terminate the process via Ctrl-C, you need to add `import signal; signal.signal(signal.SIGINT, signal.SIG_DFL)` at the beginning of the script so that Python interpreter does not catch the SIGINT signal, then Ctrl-C can successfully terminate the process. The downside is Python interpreter will not be able to show the error stack when it is stopped by Ctrl-C. Fortunately, there is a better way. The CUDA driver includes a feature called `user induced GPU core dump generation`: the driver opens pipes in the operating system that allow users to trigger a core dump by writing to them. When triggered, the CUDA driver dumps the GPU state to core dump files, enabling inspection of what's happening inside the GPU and, most importantly, identifying which GPU kernel is hanging. From 5f38a1e307c542717c84f0779f692b978772e8e6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 2 Dec 2025 21:25:51 +0800 Subject: [PATCH 11/11] add lucas Signed-off-by: youkaichao --- _posts/2025-11-27-improved-cuda-debugging.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index b902ac6..b5333e9 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -317,10 +317,10 @@ With the approach outlined above, we can uncover the full inline chain of the so ## Conclusion -This blog post introduced two advanced debugging techniques for CUDA kernels. The first technique uses user-triggered core dumps to identify hanging kernels, while the second traces complex kernels back to their source code by leveraging line information embedded in the compiled binary. These techniques are powerful tools for debugging complex issues in CUDA kernels, especially illegal memory access problems. +This blog post introduced two advanced debugging techniques for CUDA kernels. The first technique uses user-triggered core dumps to identify hanging kernels, while the second traces complex kernels back to their source code by leveraging line information embedded in the compiled binary. These techniques are powerful tools for debugging complex issues in CUDA kernels, especially illegal memory access problems. Using both in tandem we were able to recently debug [a hard-to-reproduce and tricky hang in the CUTLASS MLA attention backend](https://github.com/vllm-project/vllm/pull/26026), which actually stemmed from the upstream CUTLASS code example and has since been fixed in [v4.3.0](https://github.com/NVIDIA/cutlass/commit/b1d6e2c9b334dfa811e4183dfbd02419249e4b52). The vLLM project aims to provide easy, fast, and affordable LLM serving for everyone, and accessible debugging is an important aspect of this mission. We will continue to share more debugging tips and techniques in the future to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). # Acknowledgement -We would like to thank Ze Long and Sandarbh Jain from NVIDIA for their helpful discussions. Chao Hong from Moonshot AI helped provide the motivating example. +We would like to thank Ze Long and Sandarbh Jain from NVIDIA for their helpful discussions. Chao Hong from Moonshot AI helped provide the motivating example. Lucas Wilkinson from Red Hat helped polishing the draft.