Repository: google/CommonLoopUtils
Branch: main
Commit: 5150136b245e
Files: 52
Total size: 1.2 MB
Directory structure:
gitextract_g9bj2g2j/
├── .github/
│ └── workflows/
│ ├── build.yml
│ └── python-publish.yml
├── AUTHORS
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── clu/
│ ├── __init__.py
│ ├── asynclib.py
│ ├── asynclib_test.py
│ ├── checkpoint.py
│ ├── checkpoint_test.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── dataset_iterator.py
│ │ └── dataset_iterator_test.py
│ ├── deterministic_data.py
│ ├── deterministic_data_test.py
│ ├── internal/
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── metric_writers/
│ │ ├── __init__.py
│ │ ├── async_writer.py
│ │ ├── async_writer_test.py
│ │ ├── interface.py
│ │ ├── logging_writer.py
│ │ ├── logging_writer_test.py
│ │ ├── multi_writer.py
│ │ ├── multi_writer_test.py
│ │ ├── summary_writer.py
│ │ ├── tf/
│ │ │ ├── __init__.py
│ │ │ ├── summary_writer.py
│ │ │ └── summary_writer_test.py
│ │ ├── torch_tensorboard_writer.py
│ │ ├── torch_tensorboard_writer_test.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── metrics.py
│ ├── metrics_test.py
│ ├── parameter_overview.py
│ ├── parameter_overview_test.py
│ ├── periodic_actions.py
│ ├── periodic_actions_test.py
│ ├── platform/
│ │ ├── __init__.py
│ │ ├── interface.py
│ │ └── local.py
│ ├── preprocess_spec.py
│ ├── preprocess_spec_test.py
│ ├── profiler.py
│ ├── run_pytest.google.sh
│ └── values.py
├── clu_synopsis.ipynb
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/build.yml
================================================
# This workflow will install Python dependencies, run tests and lint.
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Build
on:
  push:
    branches:
      - main
      - 'test_*'
  pull_request:
    branches:
      - main
jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11']
    steps:
    - name: Cancel previous
      uses: styfle/cancel-workflow-action@0.8.0
      with:
        access_token: ${{ github.token }}
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install .
        pip install .[test]
    - name: Test with pytest and generate coverage report
      run: |
        pytest .
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v1
      with:
        file: ./coverage.xml
================================================
FILE: .github/workflows/python-publish.yml
================================================
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Upload Python Package
on:
  release:
    types: [created]
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: __token__
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*
================================================
FILE: AUTHORS
================================================
# This is the list of Common Loop Utils significant contributors.
#
# This does not necessarily list everyone who has contributed code,
# especially since many employees of one corporation may be contributing.
# To see the full list of contributors, see the revision history in
# source control.
Google LLC
================================================
FILE: CHANGELOG.md
================================================
# Changelog
## v0.0.1-alpha.1
Initial PyPI release.
Current list of modules:
- `clu.checkpoint`
- `clu.deterministic_training`
- `clu.metric_writers`
- `clu.periodic_actions`
- `clu.platform`
- `clu.profiler`
## v0.0.1-alpha.2
- Adds `metrics` module and some minor changes.
## v0.0.1a3
- Added `metric_writers.TorchTensorboardWriter`
## v0.0.2
- Added preprocess_spec.
- Improvements to periodic_actions.
## v0.0.3
- `metric_writers`: Lets `SummaryWriter` write nested dictionaries.
- `internal`: Adds `async.Pool`.
- `preprocess_spec`: Support nested dictionaries.
- `profile`: Use JAX profiler APIs instead of TF profiler APIs.
## v0.0.4
`deterministic_data`
- Support non-positive input value for pad_up_to_batches.
- Support padding dataset when data dimension is unknown.
- Support TFDS specs in get_read_instruction_for_host.
- Allow changing drop_remainder for batching.
- Add RemainderOptions in deterministic_data.
`metric_writers`
- Support multiple writers in metric_writers.ensure_flushes.
`metrics`
- Makes internal.flatten_dict() work with ConfigDicts.
- Forwards mask model output to metrics created via `Metric.from_output()`.
- Forwards mask model output to metrics created via `Metric.from_fun()`.
- Added `Collections.unreplicate()`, `Collections.create()`.
`periodic_actions`
- Formats long time strings in '{days}d{hours}h{mins}m' format.
`preprocess_spec`
- Make feature description of features in PreprocessFn more compact.
- Better type check in `preprocess_spec.get_all_ops()`.
Documentation:
- Added `clu_synopsis.ipynb` Colab
## v0.0.5
- Log error instead of failing when `profiler.start()` raises an exception.
- Makes `periodic_actions.ProgressUpdate` show total number of steps.
- Makes `AsyncWriter` non-blocking wrt JAX async computations.
- Adds `clu_synopsis.ipynb` Colab as initial documentation.
- Restore Checkpoint without providing the state
- Makes `PreprocessFn` addable.
- Allow n-dimensional arrays (and masks) to be passed to Metrics.Average().
- Support slicing `PreprocessFn`.
## v0.0.6
- Makes `deterministic_data` work with `tfds>4.4.0` and `tfds<=4.4.0`.
This will be the last release supporting Python 3.6.
## v0.0.7
- Moves `clu.internal.asynclib` to `clu.asynclib`.
- Adds methods for writing raw tensors and audio to `MetricWriter`.
- Adds `clu.values` to annotate arrays with a modality.
- Adds `clu.data.DatasetIterator` - a generic interface between input
pipelines and training loops.
- Fixes various issues with `clu.metrics`.
This will be the last release supporting Python 3.7.
## v0.0.9
- Fix pytype failures related to teaching pytype about NumPy scalar types.
- Fix a couple of docstring typos.
- Updates README and clu_synopsis.ipynb.
Last release before dropping support for Python 3.8 and 3.9
## v0.0.10
- `clu.parameter_overview` now supports JAX global arrays.
- Various small fixes in `clu.metrics` module.
- Removed some tensorflow dependencies.
## v0.0.11
- Removes numpy version pin
- Adds sharding annotations, dtype, total bytes to `parameter_overview`
- Makes `clu.metrics.Std` support same shapes as `clu.metrics.Average`
## v0.0.12
- Switch from `jax.tree_map` (deprecated since JAX 0.4.26) to
`jax.tree_util.tree_map`.
- Improvements to parameter overview.
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
At this time we are focused on supporting research done by Google Research and
are not accepting patches.
You are, however, free to start a fork of the project for your own purposes, as
permitted by the license.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
to see your current agreements on file or
to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# CLU - Common Loop Utils
This repository contains common functionality for writing ML training loops. The
goal is to make training loops short and readable (by moving common tasks to
small libraries) without removing the flexibility required for research.
To get started, check out this Colab:
https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb
If you're looking for usage examples, see:
https://github.com/google/flax/tree/main/examples
You can also find answers to common questions about CLU on the Flax GitHub
Discussions page:
https://github.com/google/flax/discussions
Note: At this point we are not accepting contributions. Please fork the
repository if you want to extend the libraries for your use case.
================================================
FILE: clu/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: clu/asynclib.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for async function calls."""
import collections
import concurrent.futures
import functools
import sys
import threading
from typing import Callable, List, Optional
from absl import logging
class AsyncError(Exception):
  """An exception that wraps another exception that occurred asynchronously."""
class Pool:
  """Pool for wrapping functions to be executed asynchronously.
  Synopsis:
    import time
    from clu import asynclib
    pool = asynclib.Pool()
    @pool
    def fn():
      time.sleep(1)
    future = fn()
    print(future.result())
    fn()  # This could re-raise an exception from the first execution.
    print(pool.queue_length)  # Prints "1" while one function is in flight.
    pool.join()  # This could re-raise an exception from the second execution.
  """
  def __init__(self, thread_name_prefix: str = "",
               max_workers: Optional[int] = None):
    """Creates a new pool that decorates functions for async execution.
    Args:
      thread_name_prefix: See documentation of `ThreadPoolExecutor`.
      max_workers: See documentation of `ThreadPoolExecutor`. The default `None`
        optimizes for parallelizability using the number of CPU cores. If you
        specify `max_workers=1`, the async calls are executed in the same
        order they were scheduled.
    """
    self._pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=max_workers, thread_name_prefix=thread_name_prefix)
    self._max_workers = max_workers
    self._thread_name_prefix = thread_name_prefix
    self._errors = collections.deque()
    self._errors_mutex = threading.Lock()
    self._queue_length = 0
  def _reraise(self) -> None:
    # Check and pop under the same lock so that a concurrent `clear_errors()`
    # cannot empty the deque between the check and `popleft()`.
    with self._errors_mutex:
      if not self._errors:
        return
      exc_info = self._errors.popleft()
    exc = exc_info[1].with_traceback(exc_info[2])
    raise AsyncError(f"Error '{exc}' occurred ASYNCHRONOUSLY.") from exc
  def close(self) -> None:
    """Closes this pool and re-raises a pending exception (if any)."""
    self._pool.shutdown(wait=True)
    self._reraise()
  def join(self) -> None:
    """Blocks until all functions are processed.
    The pool can be used to schedule more functions after calling this function.
    Only the first pending exception is re-raised; any further pending
    exceptions are kept and re-raised by later calls.
    Side-effect:
      If any of the functions raised an exception, then the first of these
      exceptions is reraised.
    """
    self._pool.shutdown(wait=True)
    self._pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=self._max_workers,
        thread_name_prefix=self._thread_name_prefix)
    self._reraise()
  @property
  def queue_length(self) -> int:
    """Returns the number of functions that have not returned yet."""
    return self._queue_length
  @property
  def has_errors(self) -> bool:
    """Returns True if there are any pending errors."""
    return bool(self._errors)
  def clear_errors(self) -> List[Exception]:
    """Clears all pending errors and returns them as a (possibly empty) list."""
    with self._errors_mutex:
      errors, self._errors = self._errors, collections.deque()
    return list(errors)
  def __call__(self, fn: Callable):  # pylint: disable=g-bare-generic
    """Returns an async version of fn.
    The function will be executed by this class's ThreadPoolExecutor. Any errors
    are stored and re-raised the next time any function wrapped by this pool is
    called.
    Note that even if there was a previous error, the function is still
    scheduled upon re-execution of the wrapper returned by this function.
    Args:
      fn: Function to be wrapped.
    Returns:
      An async version of `fn`. The return value of that async version will be
      a future (unless an exception was re-raised).
    """
    def inner(*args, **kwargs):
      def trap_errors(*args, **kwargs):
        try:
          return fn(*args, **kwargs)
        except Exception:
          with self._errors_mutex:
            self._errors.append(sys.exc_info())
          logging.exception("Error in producer thread for %s",
                            self._thread_name_prefix)
          raise
        finally:
          self._queue_length -= 1
      self._queue_length += 1
      if not self.has_errors:
        return self._pool.submit(trap_errors, *args, **kwargs)
      self._pool.submit(trap_errors, *args, **kwargs)
      self._reraise()
    # `getattr` avoids an AttributeError for objects without `__name__`.
    if isinstance(getattr(fn, "__name__", None), str):
      # Regular function.
      return functools.wraps(fn)(inner)
    # Mock or another weird function that fails with functools.wraps().
    return inner
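To see the deferred error handling of `Pool.__call__` and `Pool.join` in isolation, here is a minimal, stand-alone sketch of the same pattern built only on `concurrent.futures`. `DeferredErrorPool`, `fail` and `answer` are hypothetical names, and this is a simplification rather than CLU's implementation: it wraps errors in a plain `RuntimeError` instead of `asynclib.AsyncError`, and it neither tracks the queue length nor logs errors.

```python
import concurrent.futures
import threading
from typing import Any, Callable, List, Optional


class DeferredErrorPool:
  """Hypothetical, simplified sketch of a pool that re-raises errors later.

  Not CLU's `asynclib.Pool`: it wraps errors in RuntimeError, and it neither
  tracks the queue length nor logs errors.
  """

  def __init__(self, max_workers: Optional[int] = None):
    self._max_workers = max_workers
    self._executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=max_workers)
    self._errors: List[BaseException] = []
    self._lock = threading.Lock()

  def _reraise_first_error(self) -> None:
    # Pop the oldest recorded error under the lock and re-raise it wrapped,
    # keeping the original exception as __cause__.
    with self._lock:
      if not self._errors:
        return
      error = self._errors.pop(0)
    raise RuntimeError(f"Error {error!r} occurred asynchronously.") from error

  def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
    def trap_errors(*args, **kwargs):
      try:
        return fn(*args, **kwargs)
      except Exception as e:
        # Record the error before the future is marked as failed.
        with self._lock:
          self._errors.append(e)
        raise

    def inner(*args, **kwargs):
      with self._lock:
        had_errors = bool(self._errors)
      # The call is scheduled even if an earlier call failed.
      future = self._executor.submit(trap_errors, *args, **kwargs)
      if had_errors:
        self._reraise_first_error()
      return future

    return inner

  def join(self) -> None:
    # Wait for all scheduled calls, start a fresh executor so the pool stays
    # usable, then surface the first recorded error (if any).
    self._executor.shutdown(wait=True)
    self._executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=self._max_workers)
    self._reraise_first_error()


pool = DeferredErrorPool(max_workers=1)


@pool
def fail():
  raise ValueError("boom")


@pool
def answer():
  return 42


fail().exception()  # Block until the failing call has finished.
try:
  answer()  # Still scheduled, but the earlier ValueError surfaces here.
except RuntimeError as e:
  print(type(e.__cause__).__name__)  # ValueError
print(answer().result())  # 42
```

Checking for pending errors before submitting, as `Pool` does, means a call never reports its own failure; the error surfaces on the next wrapped call or on `join()`.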
================================================
FILE: clu/asynclib_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for clu.asynclib."""
from unittest import mock
from absl.testing import absltest
from clu import asynclib
class AsyncWriterTest(absltest.TestCase):
  def test_async_execution(self):
    pool = asynclib.Pool()
    counter = 0
    @pool
    def fn(counter_increment, return_value):
      nonlocal counter
      counter += counter_increment
      return return_value
    future = fn(1, return_value=2)
    # Wait for the call to finish before checking its side effect.
    self.assertEqual(future.result(), 2)
    self.assertEqual(counter, 1)
  def test_reraise(self):
    pool = asynclib.Pool()
    @pool
    def error():
      raise ValueError("test")
    # Wait until the failing call has finished so the error is recorded.
    error().exception()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(asynclib.AsyncError, "test"):
      pool.join()
    self.assertFalse(pool.has_errors)
    @pool
    def noop():
      ...
    error().exception()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(asynclib.AsyncError, "test"):
      noop()
    self.assertFalse(pool.has_errors)
    pool.join()
  @mock.patch("concurrent.futures.ThreadPoolExecutor")
  def test_queue_length(self, executor_mock):
    pool_mock = mock.Mock()
    in_flight = []
    def execute_one():
      in_flight.pop(0)()
    def submit(fn, *args, **kwargs):
      in_flight.append(lambda: fn(*args, **kwargs))
    pool_mock.submit = submit
    executor_mock.return_value = pool_mock
    pool = asynclib.Pool()
    @pool
    def noop():
      ...
    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
    noop()
    self.assertEqual(pool.queue_length, 2)
    execute_one()
    self.assertEqual(pool.queue_length, 1)
    execute_one()
    self.assertEqual(pool.queue_length, 0)
  @mock.patch("concurrent.futures.ThreadPoolExecutor")
  def test_flush(self, executor_mock):
    pool_mock = mock.Mock()
    pool_mock._in_flight = None
    def execute_one():
      pool_mock._in_flight.pop(0)()
    def submit(fn, *args, **kwargs):
      pool_mock._in_flight.append(lambda: fn(*args, **kwargs))
    def create_pool(max_workers, thread_name_prefix):
      del max_workers
      del thread_name_prefix
      pool_mock._in_flight = []
      return pool_mock
    def shutdown(wait=False):
      if wait:
        while pool_mock._in_flight:
          execute_one()
      pool_mock._in_flight = None
    pool_mock.submit = submit
    executor_mock.side_effect = create_pool
    pool_mock.shutdown.side_effect = shutdown
    pool = asynclib.Pool()
    @pool
    def noop():
      ...
    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
    noop()
    pool.join()
    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
if __name__ == "__main__":
  absltest.main()
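The `clu/checkpoint.py` listing below depends on two regular expressions with named groups: `SCHEME_RE`, used by `safe_normpath()`, and `CheckpointInfo.CHECKPOINT_REGEX`. The sketch reuses both patterns to show why the URL scheme is split off before calling `os.path.normpath()`, which would otherwise collapse `gs://` into `gs:/`, and how a checkpoint prefix is split into its base name and number. `parse_checkpoint` is a hypothetical stand-in for `CheckpointInfo.from_path()`, and the printed paths assume POSIX path semantics.

```python
import os
import re
from typing import Tuple

# The two patterns from clu/checkpoint.py, with their group names spelled out.
SCHEME_RE = re.compile("^(?P<scheme>[a-z][a-z0-9.+-]+://)?(?P<path>.*)", re.I)
CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"


def safe_normpath(path: str) -> str:
  """Normalizes only the part after an optional URL scheme such as gs://."""
  d = SCHEME_RE.match(path).groupdict()
  return (d["scheme"] or "") + os.path.normpath(d["path"])


def parse_checkpoint(checkpoint: str) -> Tuple[str, int]:
  """Hypothetical helper mirroring CheckpointInfo.from_path()."""
  m = re.match(CHECKPOINT_REGEX, checkpoint)
  if m is None:
    raise RuntimeError(f"Invalid checkpoint format: {checkpoint}")
  return m.group("prefix"), int(m.group("number"))


print(os.path.normpath("gs://bucket//runs/./exp1/"))  # gs:/bucket/runs/exp1
print(safe_normpath("gs://bucket//runs/./exp1/"))     # gs://bucket/runs/exp1
print(parse_checkpoint("/tmp/run-3/ckpt-12"))         # ('/tmp/run-3/ckpt', 12)
```

Because `re.match` anchors only at the start and `.*` is greedy, the parser picks the last `-<digits>` in the path, so a directory name such as `run-3` does not confuse it.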
================================================
FILE: clu/checkpoint.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple checkpointing library for TF2/Flax.
The class `Checkpoint` is a simple wrapper around `tf.train.Checkpoint` that
also stores a `flax.struct.dataclass` instance in the same directory.
Synopsis:
from clu import checkpoint
import flax
@flax.struct.dataclass
class TrainState:
optimizer: flax.optim.Optimizer
step: int
ds = load_tf_dataset()
ds_iter = iter(ds)
ckpt = checkpoint.MultihostCheckpoint(base_directory, dict(ds_iter=ds_iter))
optimizer = create_flax_optimizer()
state = TrainState(optimizer=optimizer, step=0)
state = ckpt.restore_or_initialize(state) # Also restores `ds_iter`.
initial_step = int(state.step) + 1
# Need to replicate all data when training with multiple accelerators.
state = flax.jax_utils.replicate(state)
for step in range(initial_step, steps + 1):
state = update_step(state, next(ds_iter))
ckpt.save(flax.jax_utils.unreplicate(state))
Loading the model e.g. in a Colab:
from clu import checkpoint
import flax
from . import mnist_lib
state_dict = checkpoint.load_state_dict(base_directory)
params = state_dict['optimizer']['target']['params']
module = mnist_lib.MyArchitecture.partial(num_classes=10)
model = flax.deprecated.nn.Model(module, params)
"""
import collections
import os
import re
from typing import Any, Dict, Optional, TypeVar
from absl import logging
from clu.internal import utils
import flax
import jax
import tensorflow as tf
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
T = TypeVar("T")
SCHEME_RE = re.compile("^(?P[a-z][a-z0-9.+-]+://)?(?P.*)", re.I)
def safe_normpath(path: str) -> str:
"""Normalizes path safely to get around `gfile.glob()` limitations."""
d = SCHEME_RE.match(path).groupdict() # pytype: disable=attribute-error # re-none
return (d["scheme"] or "") + os.path.normpath(d["path"])
def load_state_dict(base_directory) -> Dict[str, Any]:
"""Restores `state` as dictionary from the latest checkpoint.
Synopsis:
data = checkpoint.load_state_dict(base_directory)
params = data['optimizer']['target']['params']
module = mnist_lib.MyArchitecture.partial(num_classes=10)
model = flax.deprecated.nn.Model(module, params)
Args:
base_directory: Directory from which the checkpoints should be restored. See
`Checkpoint.__init__()`.
Returns:
The deserialized Flax data, as a dictionary.
Raises:
FileNotFoundError: If there is no checkpoint to restore.
"""
return Checkpoint(base_directory).load_state(state=None)
class CheckpointInfo(
collections.namedtuple("CheckpointInfo", ("prefix", "number"))):
"""Helper class to parse a TensorFlow checkpoint path."""
CHECKPOINT_REGEX = r"^(?P.*)-(?P\d+)"
@classmethod
def initialize(cls, base_directory, checkpoint_name: str) -> "CheckpointInfo":
"""Creates a first CheckpointInfo (number=1)."""
return cls(f"{base_directory}/{checkpoint_name}", 1)
@classmethod
def from_path(cls, checkpoint: str) -> "CheckpointInfo":
"""Parses a checkpoint.
Args:
checkpoint: A checkpoint prefix, as can be found in the
`.latest_checkpoint` property of a `tf.train.CheckpointManager`.
Returns:
An instance of `CheckpointInfo` that represents `checkpoint`.
"""
m = re.match(cls.CHECKPOINT_REGEX, checkpoint)
if m is None:
RuntimeError(f"Invalid checkpoint format: {checkpoint}")
d = m.groupdict() # pytype: disable=attribute-error
return cls(d["prefix"], int(d["number"]))
def increment(self) -> "CheckpointInfo":
"""Returns a new CheckpointInfo with `number` increased by one."""
return CheckpointInfo(self.prefix, self.number + 1)
def __str__(self):
"""Does the opposite of `.from_path()`."""
return f"{self.prefix}-{self.number}"
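`CheckpointInfo` relies on the greedy `.*` in `CHECKPOINT_REGEX`. Because of it, only the final `-{number}` is split off, even when the directory or checkpoint name itself contains hyphens. The following self-contained sketch uses a hypothetical `CkptInfo` class that mirrors `from_path()`, `increment()` and `__str__()`. It also shows the `_replace(number=...)` call that `.current_checkpoint` uses.

```python
import collections
import re

CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"


class CkptInfo(collections.namedtuple("CkptInfo", ("prefix", "number"))):
  """Hypothetical stand-in for `CheckpointInfo`: a "{prefix}-{number}" path."""

  @classmethod
  def from_path(cls, checkpoint: str) -> "CkptInfo":
    m = re.match(CHECKPOINT_REGEX, checkpoint)
    if m is None:
      # The explicit `raise` matters: constructing the exception does nothing.
      raise RuntimeError(f"Invalid checkpoint format: {checkpoint}")
    return cls(m.group("prefix"), int(m.group("number")))

  def increment(self) -> "CkptInfo":
    return CkptInfo(self.prefix, self.number + 1)

  def __str__(self):
    return f"{self.prefix}-{self.number}"


info = CkptInfo.from_path("/tmp/run-a/ckpt-7")
print(info.prefix, info.number)  # /tmp/run-a/ckpt 7
print(info.increment())          # /tmp/run-a/ckpt-8
print(info._replace(number=3))   # /tmp/run-a/ckpt-3
```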
class Checkpoint:
"""A utility class for storing and loading TF2/Flax checkpoints.
Both the state of a `tf.data.Dataset` iterator and a `flax.struct.dataclass`
are stored on disk in the following files:
- {directory}/checkpoint
- {directory}/ckpt-{number}.index
- {directory}/ckpt-{number}.data@*
- {directory}/ckpt-{number}.flax
Where {number} starts at 1 and is then incremented by 1 for every new
checkpoint. The last file is the `flax.struct.dataclass`, serialized in
MessagePack format. The other files are explained in more detail in the TensorFlow
documentation:
https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
"""
def __init__(self,
base_directory: str,
tf_state: Optional[Dict[str, Any]] = None,
*,
max_to_keep: int = 5,
checkpoint_name: str = "ckpt"):
"""Initializes a Checkpoint with a dictionary of TensorFlow Trackables.
Args:
base_directory: Directory under which the checkpoints will be stored. Use
a different base_directory in every task.
tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
example a dataset iterator.
max_to_keep: Number of checkpoints to keep in the directory. If there are
more checkpoints than specified by this number, then the oldest
checkpoints are removed.
checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
"""
if tf_state is None:
tf_state = dict()
base_directory = safe_normpath(base_directory)
self.base_directory = base_directory
self.max_to_keep = max_to_keep
self.checkpoint_name = checkpoint_name
self.tf_checkpoint = tf.train.Checkpoint(**tf_state)
self.tf_checkpoint_manager = tf.train.CheckpointManager(
self.tf_checkpoint,
base_directory,
max_to_keep=max_to_keep,
checkpoint_name=checkpoint_name)
self.restored_from = None
def get_latest_checkpoint_to_restore_from(self):
"""Returns the latest checkpoint to restore from.
In the current implementation, this method simply returns the attribute
`latest_checkpoint`.
Subclasses can override this method to provide an alternative checkpoint to
restore from, for example for synchronization across multiple checkpoint
directories.
"""
return self.latest_checkpoint
@property
def latest_checkpoint(self) -> Optional[str]:
"""Latest checkpoint, see `tf.train.CheckpointManager.latest_checkpoint`.
Returns:
A string to the latest checkpoint. Note that this string is path-like but
it does not really describe a file, but rather a set of files that are
constructed from this string, by appending different file extensions. The
returned value is `None` if there is no previously stored checkpoint in
`base_directory` specified to `__init__()`.
"""
return self.tf_checkpoint_manager.latest_checkpoint
@property
def current_checkpoint(self) -> Optional[str]:
"""Returns current checkpoint.
Note that right after instance creation this is `None` if `base_directory`
contains no checkpoint, and otherwise points to "ckpt-0", which does not
actually exist. After the first save (either via `.save()` or via
`.restore_or_initialize()`) it will point to "ckpt-1". When the checkpoint
is loaded from a specific checkpoint (via `.restore(state, checkpoint)`)
then this property can be different from `.latest_checkpoint`.
Returns:
A string referring to the current checkpoint, or `None`. See
`.latest_checkpoint` for a description of the format.
"""
latest_checkpoint = self.latest_checkpoint
if latest_checkpoint is None:
return None
checkpoint_info = CheckpointInfo.from_path(latest_checkpoint)
number = self.tf_checkpoint.save_counter.numpy()
return str(checkpoint_info._replace(number=number))
def _flax_path(self, checkpoint: str) -> str:
return "{}.flax".format(checkpoint)
def _next_checkpoint(self, checkpoint: Optional[str]) -> str:
if checkpoint is None:
return str(
CheckpointInfo.initialize(self.base_directory, self.checkpoint_name))
return str(CheckpointInfo.from_path(checkpoint).increment())
def _checkpoint_number(self, checkpoint: Optional[str]) -> Optional[int]:
if checkpoint is None:
return None
return CheckpointInfo.from_path(checkpoint).number
def _delete_future_checkpoints(self):
"""Deletes checkpoints that are newer than the currently loaded checkpoint.
This happens when the checkpoint was initialized from a checkpoint that was
not the latest checkpoint (e.g. when recovering from a pre-emption in a
`MultihostCheckpoint` where some workers finished writing their checkpoints
and others didn't).
"""
checkpoint = self.current_checkpoint
while True:
checkpoint = self._next_checkpoint(checkpoint)
paths = tf.io.gfile.glob(f"{checkpoint}.*")
if not paths:
break
for path in paths:
logging.info("Cleaning up future checkpoint file '%s'", path)
tf.io.gfile.remove(path)
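The cleanup loop above walks forward from the current checkpoint number and stops at the first number that has no files. Below is a minimal local-filesystem sketch of the same loop. It assumes plain `glob`/`os.remove` in place of `tf.io.gfile` and uses a hypothetical `delete_future_checkpoints(directory, name, current)` signature. Like the original, it leaves any checkpoints that follow a gap in the numbering untouched.

```python
import glob
import os
import tempfile


def delete_future_checkpoints(directory: str, name: str, current: int) -> list:
  """Removes files of checkpoints current+1, current+2, ... until a gap."""
  removed = []
  number = current
  while True:
    number += 1
    # Matches e.g. "ckpt-2.index", "ckpt-2.flax", but not "ckpt-20.index".
    paths = glob.glob(os.path.join(directory, f"{name}-{number}.*"))
    if not paths:
      break
    for path in sorted(paths):
      os.remove(path)
      removed.append(os.path.basename(path))
  return removed


workdir = tempfile.mkdtemp()
for n in range(1, 4):
  for ext in ("index", "flax"):
    open(os.path.join(workdir, f"ckpt-{n}.{ext}"), "w").close()

print(delete_future_checkpoints(workdir, "ckpt", current=1))
# ['ckpt-2.flax', 'ckpt-2.index', 'ckpt-3.flax', 'ckpt-3.index']
print(sorted(os.listdir(workdir)))  # ['ckpt-1.flax', 'ckpt-1.index']
```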
@utils.logged_with("Checkpoint.save()")
def save(self, state) -> str:
"""Saves a new checkpoint in the directory.
Note that if the checkpoint was restored from an earlier checkpoint than the
latest available, then saving the checkpoint will overwrite and/or delete any
checkpoints later than the restored one.
For example, if there are checkpoints `(1, 2, 3)` and then checkpoint `1`
is restored, then calling `.save()` on that restored checkpoint will result
in `2` being overwritten and `3` being deleted.
This overwriting/deleting behavior allows for seamless integration with
`MultihostCheckpoint` after pre-emption (i.e. one of the workers might have
stored one more checkpoint, but that checkpoint is only available on that
one worker and must be overwritten when the training continues).
After such an overwrite, the attributes `.current_checkpoint` and
`.latest_checkpoint` will point to the newly written checkpoint (in the above case
`2`), but the list `.tf_checkpoint_manager.checkpoints` might be out of sync
and should not be used.
Args:
state: Flax checkpoint to be stored.
Returns:
The checkpoint identifier ({base_directory}/ckpt-{number}).
"""
self._delete_future_checkpoints()
next_checkpoint = self._next_checkpoint(self.current_checkpoint)
flax_path = self._flax_path(next_checkpoint)
logging.info("Storing next checkpoint '%s'", next_checkpoint)
if not tf.io.gfile.exists(self.base_directory):
tf.io.gfile.makedirs(self.base_directory)
with tf.io.gfile.GFile(flax_path, "wb") as f:
f.write(flax.serialization.to_bytes(state))
checkpoints_before_save = set(self.tf_checkpoint_manager.checkpoints)
# Write TensorFlow data last. This way TensorFlow checkpoint generation
# logic will make sure to only commit checkpoints if they complete
# successfully. A previously written `flax_path` would then simply be
# overwritten next time.
self.tf_checkpoint_manager.save()
# Clean up stale Flax checkpoints. TensorFlow automatically removes all but
# the newest `max_to_keep` checkpoints, so we do the same for the Flax files.
stale_checkpoints = checkpoints_before_save - set(
self.tf_checkpoint_manager.checkpoints)
for checkpoint in stale_checkpoints:
if tf.io.gfile.exists(self._flax_path(checkpoint)):
tf.io.gfile.remove(self._flax_path(checkpoint))
assert self.current_checkpoint == next_checkpoint, (
"Expected next_checkpoint to match .current_checkpoint: "
f"{next_checkpoint} != {self.current_checkpoint}")
return self.current_checkpoint
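The ordering in `save()` can be sketched without TensorFlow: write the Flax file first, commit the TensorFlow checkpoint last, then prune the Flax files whose TensorFlow counterpart was dropped. This is a minimal illustration under stated assumptions. `ToyManager` is a hypothetical stand-in for `tf.train.CheckpointManager` that only tracks checkpoint names and keeps the newest `max_to_keep` of them. "Files" are modeled as a Python set.

```python
import collections


class ToyManager:
  """Hypothetical stand-in for `tf.train.CheckpointManager` (names only)."""

  def __init__(self, max_to_keep: int):
    self.max_to_keep = max_to_keep
    self.save_counter = 0
    self.checkpoints = collections.deque()  # Oldest first.

  def save(self) -> str:
    self.save_counter += 1
    self.checkpoints.append(f"ckpt-{self.save_counter}")
    # Forget the oldest names beyond `max_to_keep`.
    while len(self.checkpoints) > self.max_to_keep:
      self.checkpoints.popleft()
    return self.checkpoints[-1]


def save_with_sidecar(manager: ToyManager, files: set) -> str:
  """Writes "<name>.flax" first, commits the manager, then prunes sidecars."""
  next_name = f"ckpt-{manager.save_counter + 1}"
  files.add(f"{next_name}.flax")  # Sidecar first; overwritten if commit fails.
  before = set(manager.checkpoints)
  committed = manager.save()  # The "commit" happens last.
  for stale in before - set(manager.checkpoints):
    files.discard(f"{stale}.flax")
  return committed


manager, files = ToyManager(max_to_keep=2), set()
for _ in range(3):
  save_with_sidecar(manager, files)
print(list(manager.checkpoints))  # ['ckpt-2', 'ckpt-3']
print(sorted(files))              # ['ckpt-2.flax', 'ckpt-3.flax']
```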
@utils.logged_with("Checkpoint.restore_or_initialize()")
def restore_or_initialize(self, state: T) -> T:
"""Restores from the latest checkpoint, or creates a first checkpoint.
Args:
state: A data structure to be stored or to serve as a template. If the
checkpoint is restored (and not initialized), then the fields of `state`
must match the data previously stored. See
`flax.serialization.from_state_dict()` for details.
Returns:
The restored `state` object. Note that all TensorFlow `Trackable`s in
`tf_state` (see `__init__()`) are also updated.
"""
checkpoint = self.get_latest_checkpoint_to_restore_from()
if checkpoint is not None:
return self.restore(state, checkpoint)
logging.info("Storing initial version.")
self.save(state)
return state
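The restore-or-initialize pattern above can be illustrated in isolation. The sketch below is an analogue built on stated assumptions, not CLU's implementation. It stores a plain `dict` as JSON at a single hypothetical `path` instead of Flax MessagePack plus TensorFlow state. It also approximates the template check of `flax.serialization.from_state_dict()` by comparing key sets.

```python
import json
import os
import tempfile


def restore_or_initialize(path: str, state: dict) -> dict:
  """Returns the stored state if `path` exists, else stores `state` first."""
  if os.path.exists(path):
    with open(path) as f:
      restored = json.load(f)
    # `state` only serves as a template here: the fields must match.
    if set(restored) != set(state):
      raise ValueError(f"Field mismatch: {sorted(restored)} vs {sorted(state)}")
    return restored
  # Nothing to restore: persist the initial version and return it unchanged.
  with open(path, "w") as f:
    json.dump(state, f)
  return state


path = os.path.join(tempfile.mkdtemp(), "state.json")
print(restore_or_initialize(path, {"step": 1}))  # {'step': 1} (initialized)
print(restore_or_initialize(path, {"step": 0}))  # {'step': 1} (restored)
```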
def restore_dict(self, checkpoint: Optional[str] = None) -> Dict[str, Any]:
"""Restores last checkpoint and returns `state` as dictionary.
The only difference between this method and `.restore()` is that it always
passes `state=None`, so the restored state is returned as a dictionary.
Args:
checkpoint: Checkpoint name that should be restored. Defaults to latest
available checkpoint. See `.latest_checkpoint` for a description of the
format of this string.
Returns:
The restored state as a dictionary. Note that all TensorFlow `Trackable`s in
`tf_state` (see `__init__()`) are also updated.
Raises:
FileNotFoundError: If specified checkpoint does not exist, or if there
is no checkpoint to restore in case no checkpoint was specified.
"""
return self.restore(state=None, checkpoint=checkpoint)
def _checkpoint_or_latest(self, checkpoint: Optional[str] = None) -> str:
if checkpoint is None:
checkpoint = self.get_latest_checkpoint_to_restore_from()
if checkpoint is None:
raise FileNotFoundError(f"No checkpoint found at {self.base_directory}")
return checkpoint
def load_state(self,
state: Optional[T],
checkpoint: Optional[str] = None) -> T:
"""Restores Flax state from the latest checkpoint.
As opposed to `.restore()`, this function only reads the Flax checkpoint and
does not read the (potentially very large) TensorFlow state.
Args:
state: Template data structure that will serve as a template for the
returned state. If the loaded data does not match that template, then an
exception is raised. It's also possible to specify `state=None`, in
which case a dictionary will be returned. See
`flax.serialization.from_state_dict()` for details.
checkpoint: Checkpoint name that should be restored. Defaults to latest
available checkpoint. See `.latest_checkpoint` for a description of the
format of this string.
Returns:
The restored `state` object. Unlike `.restore()`, this does not update the
TensorFlow `Trackable`s in `tf_state` (see `__init__()`).
Raises:
FileNotFoundError: If specified checkpoint does not exist, or if there
is no checkpoint to restore in case no checkpoint was specified.
"""
flax_path = self._flax_path(self._checkpoint_or_latest(checkpoint))
if not tf.io.gfile.exists(flax_path):
raise FileNotFoundError(f"Checkpoint {checkpoint} does not exist")
with tf.io.gfile.GFile(flax_path, "rb") as f:
return flax.serialization.from_bytes(state, f.read())
def restore(self,
state: Optional[T],
checkpoint: Optional[str] = None) -> T:
"""Restores from `checkpoint`, or from the latest checkpoint by default.
Similar to `restore_or_initialize()`, but raises a `FileNotFoundError` if
there is no checkpoint.
Args:
state: Template data structure that will serve as a template for the
returned state. If the loaded data does not match that template, then an
exception is raised. It's also possible to specify `state=None`, in
which case a dictionary will be returned. See
`flax.serialization.from_state_dict()` for details.
checkpoint: Checkpoint name that should be restored. Defaults to latest
available checkpoint. See `.latest_checkpoint` for a description of the
format of this string.
Returns:
The restored `state` object. Note that all TensorFlow `Trackable`s in
`tf_state` (see `__init__()`) are also updated.
Raises:
FileNotFoundError: If specified checkpoint does not exist, or if there
is no checkpoint to restore in case no checkpoint was specified.
"""
checkpoint = self._checkpoint_or_latest(checkpoint)
logging.info("Restoring checkpoint: %s", checkpoint)
state = self.load_state(state, checkpoint)
self.tf_checkpoint.restore(checkpoint)
logging.info("Restored save_counter=%d restored_checkpoint=%s",
self.tf_checkpoint.save_counter.numpy(),
checkpoint)
self.restored_from = checkpoint
return state
class MultihostCheckpoint(Checkpoint):
"""A subclass of `Checkpoint` that synchronizes between multiple JAX hosts.
If training is split across multiple hosts, then the following race condition
can occur: if a host is pre-empted while writing a checkpoint, then the other
hosts will only be restarted with a small delay, and at that point they
probably already have finished writing their checkpoint. Upon restart, the
host that was interrupted while writing the checkpoint will load the latest
fully written checkpoint, which will be out of sync with the other hosts that
successfully wrote one more checkpoint.
This class also allows specifying a `multihost_base_directory` that is
identical for all hosts and will be used to derive a host-specific directory.
"""
def __init__(self,
multihost_base_directory: str,
tf_state: Optional[Dict[str, Any]] = None,
*,
host_id: Optional[int] = None,
max_to_keep: int = 5,
checkpoint_name: str = "ckpt"):
"""Initializes a MultihostCheckpoint with a dict of TensorFlow Trackables.
Args:
multihost_base_directory: Directory that will be used to construct a
host-specific `base_directory` under which the checkpoints will be
stored. Usually a directory *within* the work unit's working directory (e.g.
`f"{workdir}/checkpoints"`). One directory per host will be created at
the same level as this base directory labeled
`f"{multihost_base_directory}-{host_id}"`.
tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
example a dataset iterator.
host_id: Host ID used to construct the `base_directory`. Taken from
`jax.process_index()` if not specified.
max_to_keep: Number of checkpoints to keep in the directory. If there are
more checkpoints than specified by this number, then the oldest
checkpoints are removed.
checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
"""
if max_to_keep < 2:
raise ValueError("Requires multiple checkpoints (max_to_keep>=2).")
multihost_base_directory = multihost_base_directory.rstrip("/")
self.multihost_base_directory = multihost_base_directory
if host_id is None:
host_id = jax.process_index()
base_directory = f"{multihost_base_directory}-{host_id}"
super().__init__(
base_directory,
tf_state,
max_to_keep=max_to_keep,
checkpoint_name=checkpoint_name)
@utils.logged_with(
"MultihostCheckpoint.get_latest_checkpoint_to_restore_from()")
def get_latest_checkpoint_to_restore_from(self) -> Optional[str]:
"""Returns the latest checkpoint available on all hosts."""
base_directory_glob = f"{self.multihost_base_directory}-*"
base_directories = tf.io.gfile.glob(base_directory_glob)
if self.base_directory not in base_directories:
logging.info("%s not in %s", self.base_directory, base_directories)
return None
checkpoints = {}
common_numbers = None
all_numbers = set()
for base_directory in base_directories:
checkpoint_manager = tf.train.CheckpointManager(
tf.train.Checkpoint(),
base_directory,
max_to_keep=self.max_to_keep,
checkpoint_name=self.checkpoint_name)
numbers = [
CheckpointInfo.from_path(checkpoint).number
for checkpoint in checkpoint_manager.checkpoints
]
checkpoints[base_directory] = dict(
zip(numbers, checkpoint_manager.checkpoints))
numbers = set(numbers)
if common_numbers is None:
common_numbers = numbers
else:
common_numbers &= numbers
all_numbers |= numbers
logging.info(
"Checked checkpoint base_directories: %s - common_numbers=%s "
"- exclusive_numbers=%s", base_directories, common_numbers,
all_numbers.difference(common_numbers))
if not common_numbers:
return None
highest_number = sorted(common_numbers)[-1]
return checkpoints[self.base_directory][highest_number]
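At its core, `get_latest_checkpoint_to_restore_from()` is a set intersection. Only checkpoint numbers present on every host are safe to restore, and the highest of them is chosen. Below is a standalone sketch of just that step. `latest_common_checkpoint` is a hypothetical helper that takes the already-listed numbers per host. The real method also returns `None` when this host's own directory is missing from the glob.

```python
from typing import Dict, Iterable, Optional


def latest_common_checkpoint(
    numbers_per_host: Dict[int, Iterable[int]]) -> Optional[int]:
  """Returns the highest checkpoint number present on every host, or None."""
  common = None
  for numbers in numbers_per_host.values():
    numbers = set(numbers)
    common = numbers if common is None else common & numbers
  if not common:  # No hosts, or no number shared by all of them.
    return None
  return max(common)


# Host 1 wrote checkpoint 3 before a pre-emption; host 0 did not.
print(latest_common_checkpoint({0: [1, 2], 1: [1, 2, 3]}))  # 2
print(latest_common_checkpoint({0: [2], 1: [1]}))           # None
```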
================================================
FILE: clu/checkpoint_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for clu.checkpoint."""
import os
import tempfile
from unittest import mock
from clu import checkpoint
import flax
import tensorflow as tf
def _make_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
features = dict(x=inputs, y=labels)
return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2)
@flax.struct.dataclass
class TrainState:
step: int
@flax.struct.dataclass
class TrainStateExtended:
step: int
name: str
class NotTrainState:
pass
def _checkpoint_number(path):
*parts, number = path.split("-")
del parts
return int(number)
class CheckpointTest(tf.test.TestCase):
def test_safe_normpath(self):
self.assertEqual(checkpoint.safe_normpath("./test_dir"), "test_dir")
self.assertEqual(checkpoint.safe_normpath(".//test_dir"), "test_dir")
self.assertEqual(checkpoint.safe_normpath("gs://test_dir"), "gs://test_dir")
self.assertEqual(
checkpoint.safe_normpath("gs://test_dir/"), "gs://test_dir")
def test_initialize_mkdir(self):
base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir)
self.assertIsNone(ckpt.current_checkpoint)
self.assertIsNone(ckpt.latest_checkpoint)
self.assertFalse(os.path.isdir(base_dir))
state = ckpt.restore_or_initialize(state)
self.assertIsNotNone(ckpt.latest_checkpoint)
self.assertEqual(ckpt.latest_checkpoint, ckpt.current_checkpoint)
self.assertTrue(os.path.isdir(base_dir))
def test_restores_flax_state(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=2)
# Initializes.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
state = TrainState(step=2)
# Stores step=2.
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
state = TrainState(step=0)
# Restores step=2.
state = ckpt.restore(state)
self.assertEqual(state.step, 2)
state = TrainState(step=3)
# Stores step=3
path2 = ckpt.save(state)
self.assertEqual(_checkpoint_number(path2), 3)
state = TrainState(step=0)
# Restores step=2.
state = ckpt.restore(state, path)
self.assertEqual(state.step, 2)
def test_load_state_dict(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir)
# Initializes.
state = ckpt.restore_or_initialize(state)
# Load via load_state_dict().
flax_dict = checkpoint.load_state_dict(base_dir)
self.assertEqual(flax_dict, dict(step=1))
with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
checkpoint.load_state_dict(tempfile.mkdtemp())
def test_fails_when_restoring_subset(self):
base_dir = tempfile.mkdtemp()
state = TrainStateExtended(step=1, name="test")
ckpt = checkpoint.Checkpoint(base_dir)
# Initializes with TrainStateExtended.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
# Restores with TrainState.
with self.assertRaisesRegex(ValueError, r"^Unknown field"):
state = ckpt.restore_or_initialize(state)
def test_fails_when_restoring_superset(self):
base_dir = tempfile.mkdtemp()
ckpt = checkpoint.Checkpoint(base_dir)
state = TrainState(step=0)
# Initializes with TrainState.
state = ckpt.restore_or_initialize(state)
state = TrainStateExtended(step=1, name="test")
# Restores with TrainStateExtended.
with self.assertRaisesRegex(ValueError, r"^Missing field"):
state = ckpt.restore_or_initialize(state)
def test_restores_tf_state(self):
base_dir = tempfile.mkdtemp()
ds_iter = iter(_make_dataset())
ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
features0 = next(ds_iter) # Advance iterator by one.
del features0
state = TrainState(step=1)
# Initialize at features1.
state = ckpt.restore_or_initialize(state)
features1 = next(ds_iter)
features2 = next(ds_iter)
self.assertNotAllEqual(features1["x"], features2["x"])
self.assertNotAllEqual(features1["y"], features2["y"])
# Restore at features1.
state = ckpt.restore_or_initialize(state)
features1_restored = next(ds_iter)
self.assertAllEqual(features1["x"], features1_restored["x"])
self.assertAllEqual(features1["y"], features1_restored["y"])
# Save at features2.
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
features2 = next(ds_iter)
features3 = next(ds_iter)
self.assertNotAllEqual(features2["x"], features3["x"])
self.assertNotAllEqual(features2["y"], features3["y"])
# Restore at features2.
state = ckpt.restore_or_initialize(state)
features2_restored = next(ds_iter)
self.assertAllEqual(features2["x"], features2_restored["x"])
self.assertAllEqual(features2["y"], features2_restored["y"])
# Restore at features2 as dictionary.
state = ckpt.restore_dict()
features2_restored = next(ds_iter)
self.assertAllEqual(features2["x"], features2_restored["x"])
self.assertAllEqual(features2["y"], features2_restored["y"])
def test_restore_flax_alone(self):
base_dir = tempfile.mkdtemp()
ds_iter = iter(_make_dataset())
ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
state = TrainState(step=1)
# Initializes.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
ckpt = checkpoint.Checkpoint(base_dir)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
def test_restore_dict(self):
base_dir = tempfile.mkdtemp()
ds_iter = iter(_make_dataset())
ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
with self.assertRaisesRegex(FileNotFoundError, r"No checkpoint found at"):
ckpt.restore_dict()
with self.assertRaisesRegex(FileNotFoundError,
r"Checkpoint invalid does not exist"):
ckpt.restore_dict(checkpoint="invalid")
state = TrainState(step=1)
ckpt.save(state)
state_dict = ckpt.restore_dict()
self.assertEqual(state_dict, dict(step=1))
first_checkpoint = ckpt.latest_checkpoint
new_state = TrainState(step=2)
ckpt.save(new_state)
self.assertEqual(
ckpt.restore_dict(checkpoint=first_checkpoint),
dict(step=1))
self.assertEqual(ckpt.restore_dict(), dict(step=2))
self.assertEqual(
ckpt.restore_dict(checkpoint=ckpt.latest_checkpoint),
dict(step=2))
def test_ignores_incomplete_checkpoint(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir)
# Initializes.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
state = TrainState(step=2)
# Failed save: step=2 is stored, but the TensorFlow checkpoint fails.
ckpt.tf_checkpoint_manager.save = None
with self.assertRaisesRegex(TypeError,
r"'NoneType' object is not callable"):
ckpt.save(state)
files = os.listdir(base_dir)
self.assertIn("ckpt-2.flax", files)
self.assertNotIn("ckpt-2.index", files)
ckpt = checkpoint.Checkpoint(base_dir)
state = TrainState(step=0)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
# Stores step=2.
state = TrainState(step=2)
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
files = os.listdir(base_dir)
self.assertIn("ckpt-2.flax", files)
self.assertIn("ckpt-2.index", files)
state = TrainState(step=0)
# Restores step=2.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 2)
def test_max_to_keep(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
state = ckpt.restore_or_initialize(state)
files1 = os.listdir(base_dir)
state = TrainState(step=2)
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
files2 = os.listdir(base_dir)
self.assertEqual(len(files1), len(files2))
self.assertNotEqual(files1, files2)
def test_checkpoint_name(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, checkpoint_name="test")
path = ckpt.save(state)
self.assertIn("test", path)
def test_fails_if_not_registered(self):
base_dir = tempfile.mkdtemp()
not_state = NotTrainState()
ckpt = checkpoint.Checkpoint(base_dir)
with self.assertRaisesRegex(TypeError, r"serialize"):
ckpt.restore_or_initialize(not_state)
def test_overwrite(self):
base_dir = tempfile.mkdtemp()
tf_step = tf.Variable(1)
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
# Initialize step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
self.assertEqual(tf_step.numpy(), 1)
checkpoint_info = checkpoint.CheckpointInfo.from_path(
ckpt.current_checkpoint)
# Stores steps 2, 3, 4, 5
for _ in range(4):
tf_step.assign_add(1)
state = state.replace(step=state.step + 1)
ckpt.save(state)
latest_checkpoint = str(checkpoint_info._replace(number=5))
self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
# Restores at step=1
ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
state = ckpt.restore(state, checkpoint=str(checkpoint_info))
self.assertEqual(state.step, 1)
self.assertEqual(tf_step.numpy(), 1)
self.assertNotEqual(ckpt.current_checkpoint, ckpt.latest_checkpoint)
self.assertEqual(ckpt.current_checkpoint, str(checkpoint_info))
self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
# Overwrites step=2, deletes 3, 4, 5.
tf_step.assign_add(1)
state = state.replace(step=state.step + 1)
ckpt.save(state)
latest_checkpoint = str(checkpoint_info._replace(number=2))
self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
class MultihostCheckpointTest(tf.test.TestCase):
@mock.patch("jax.process_index")
def test_initialize_mkdir(self, process_index_mock):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
process_index_mock.return_value = 0
base_dir = f"{multihost_base_dir}-0"
ckpt = checkpoint.MultihostCheckpoint(multihost_base_dir)
self.assertIsNone(ckpt.latest_checkpoint)
self.assertFalse(os.path.isdir(base_dir))
state = ckpt.restore_or_initialize(state)
self.assertIsNotNone(ckpt.latest_checkpoint)
self.assertTrue(os.path.isdir(base_dir))
@mock.patch("jax.process_index")
def test_synchronize_multiple_hosts(self, process_index_mock):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
process_index_mock.return_value = 0
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
process_index_mock.return_value = 1
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
# Initialize both at step=1.
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
# Update both at step=2.
state_0 = state_0.replace(step=2)
ckpt_0.save(state_0)
state_1 = state_1.replace(step=2)
ckpt_1.save(state_1)
# Update ckpt_1 at step=3.
state_1 = state_1.replace(step=3)
ckpt_1.save(state_1)
# Reload both at step=2.
process_index_mock.return_value = 0
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
process_index_mock.return_value = 1
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
self.assertEqual(ckpt_0.latest_checkpoint,
ckpt_0.get_latest_checkpoint_to_restore_from())
self.assertNotEqual(ckpt_1.latest_checkpoint,
ckpt_1.get_latest_checkpoint_to_restore_from())
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
self.assertEqual(state_0.step, 2)
self.assertEqual(state_1.step, 2)
def test_preemption(self):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
state0 = state.replace(step=0)
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
# Initialize both at step=1.
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
self.assertEqual(state_0.step, 1)
self.assertEqual(state_1.step, 1)
# Restore both at step=1.
state_0 = ckpt_0.restore_or_initialize(state0)
state_1 = ckpt_1.restore_or_initialize(state0)
self.assertEqual(state_0.step, 1)
self.assertEqual(state_1.step, 1)
# Update only ckpt_0 to step=2.
state_0 = state_0.replace(step=2)
ckpt_0.save(state_0)
# Load both checkpoints at last common step=1.
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
self.assertEqual(state_0.step, 1)
self.assertEqual(state_1.step, 1)
# Store both at step=2.
state_0 = state_0.replace(step=2)
state_1 = state_1.replace(step=2)
ckpt_0.save(state_0)
ckpt_1.save(state_1)
# Restore both at step=2.
state_0 = ckpt_0.restore_or_initialize(state0)
state_1 = ckpt_1.restore_or_initialize(state0)
self.assertEqual(state_0.step, 2)
self.assertEqual(state_1.step, 2)
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/data/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DatasetIterator is an interface for input pipelines."""
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
from clu.data.dataset_iterator import (
Array,
ArraySpec,
DatasetIterator,
Element,
ElementSpec,
TfDatasetIterator,
PeekableDatasetIterator,
PyTree,
)
================================================
FILE: clu/data/dataset_iterator.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for dataset iterators.
This module provides the DatasetIterator interface. The intention is that
several frameworks providing datasets can implement this interface without
knowing anything about the framework used for the model and the training loop.
Likewise, training loops can expect a DatasetIterator object and do not need
to care about the specifics of the input pipelines.
This module does not depend on TensorFlow. The interface is generic and users
don't have to use `tf.data` to construct a DatasetIterator. However, if they
use `tf.data` they can simply wrap their `tf.data.Dataset` object with
`TfDatasetIterator` to satisfy the interface.
"""
from __future__ import annotations
import abc
import collections.abc
import concurrent.futures
import dataclasses
import os
import threading
import typing
from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union
from absl import logging
from clu import asynclib
from etils import epath
import jax.numpy as jnp # Just for type checking.
import numpy as np
import numpy.typing as npt
Array = Union[np.ndarray, jnp.ndarray]
# Sizes of dimensions, None means the dimension size is unknown.
Shape = Tuple[Optional[int], ...]
@dataclasses.dataclass(frozen=True)
class ArraySpec:
"""Describes an array via its dtype and shape."""
dtype: npt.DTypeLike
shape: Shape
def __repr__(self):
return f"ArraySpec(dtype={np.dtype(self.dtype).name}, shape={self.shape})"
def __str__(self):
return f"{np.dtype(self.dtype).name}{list(self.shape)}"
# Elements are PyTrees with NumPy/JAX arrays.
# Anything can be a PyTree (it's either a container or leaf). We define
# PyTree[T] as a PyTree where all leaves are of type T.
# See https://jax.readthedocs.io/en/latest/pytrees.html.
L = TypeVar("L") # pylint: disable=invalid-name
PyTree = Union[L, Sequence["PyTree[L]"], Mapping[str, "PyTree[L]"]]
Element = PyTree[Array]
ElementSpec = PyTree[ArraySpec]
class DatasetIterator(collections.abc.Iterator): # pytype: disable=ignored-abstractmethod
"""Generic interface for iterating over a dataset.
This does not support __getitem__ since it cannot be implemented efficiently
for many datasets. However, datasets should allow starting the iterator from
an arbitrary position.
The element_spec property helps consumers to validate the input without
reading data. This is similar to `tf.data.Dataset.element_spec`.
Subclasses may decide not to read/write checkpoints if their state is
sufficiently tracked externally (e.g. input pipelines that can be correctly
restarted from the step number).
"""
def get_next(self) -> Element:
"""Returns the next element."""
logging.error(
"DatasetIterator.get_next() is deprecated. Please use next().")
# Subclasses should implement __next__() and remove calls to get_next().
return next(self)
def reset(self):
"""Resets the iterator back to the beginning."""
raise NotImplementedError
@property
@abc.abstractmethod
def element_spec(self) -> ElementSpec:
"""Returns the spec of the elements."""
raise NotImplementedError()
def save(self, filename: epath.Path):
"""Saves the state of the iterator to a file.
This should only handle this iterator - not iterators in other processes.
Args:
filename: Name of the checkpoint.
"""
raise NotImplementedError()
def restore(self, filename: epath.Path):
"""Restores the iterator from a file (if available).
This should only handle this iterator - not iterators in other processes.
Args:
filename: Name of the checkpoint.
"""
raise NotImplementedError()
def load(self, filename: epath.Path):
logging.error("DatasetIterator.load() is deprecated. Please use restore().")
return self.restore(filename)
class TfDatasetIterator(DatasetIterator):
"""DatasetIterator for wrapping a `tf.data.Dataset`."""
def __init__(self, dataset, *, checkpoint: bool):
"""Wraps `tf.data.Dataset` object into the `DatasetIterator` interface.
Warning: Do not wrap this iterator to do asynchronous prefetching if you
use `checkpoint=True`. tf.data iterators must be saved synchronously.
Args:
dataset: The dataset to wrap. Its elements must be flat dictionaries of
tensors. Elements are converted to NumPy arrays but no additional
prefetching is done. tf.data should automatically prefetch elements (to
CPU memory).
checkpoint: Whether to checkpoint the dataset iterator object.
Checkpointing dataset iterators is required for handling job
pre-emptions but, depending on your input pipeline, can result in very
large checkpoints. If set to False, save(), restore() and load() are
no-ops.
"""
try:
# Since this is the only class in this module using TF we only import
# tensorflow if needed.
if typing.TYPE_CHECKING:
tf = Any
else:
import tensorflow as tf # pylint: disable=g-import-not-at-top
except ImportError as e:
raise RuntimeError("When using TfDatasetIterator your binary must "
"depend on //third_party/py/tensorflow.") from e
self._tf = tf
if not isinstance(dataset, tf.data.Dataset):
raise ValueError("`dataset` must be an instance of `tf.data.Dataset` "
f"but got {type(dataset)}.")
self._dataset = dataset
self._checkpoint = checkpoint
assert self.element_spec # Verify element spec.
self.iterator = iter(dataset)
self._ckpt = tf.train.Checkpoint(ds=self.iterator)
def get_next(self) -> Element:
return next(self)
def __next__(self) -> Element:
return {k: np.asarray(v) for k, v in next(self.iterator).items()}
def reset(self):
self.iterator = iter(self._dataset)
self._ckpt = self._tf.train.Checkpoint(ds=self.iterator)
@property
def element_spec(self) -> ElementSpec:
element_spec = self._dataset.element_spec
if not isinstance(element_spec, dict):
raise ValueError("Dataset elements must be flat dictionaries but got "
f"{element_spec}.")
invalid_features = [
k for k, v in element_spec.items()
if not isinstance(v, self._tf.TensorSpec)
]
if invalid_features:
raise ValueError(f"Features {invalid_features} are not tensors. Dataset "
"elements must be flat dictionaries of tensors.")
return {
k: ArraySpec(dtype=v.dtype.as_numpy_dtype, shape=tuple(v.shape))
for k, v in element_spec.items()
}
def save(self, filename: epath.Path):
if self._checkpoint:
self._ckpt.write(os.fspath(filename))
def restore(self, filename: epath.Path):
if self._checkpoint:
self._ckpt.read(os.fspath(filename)).assert_consumed()
class PeekableDatasetIterator(DatasetIterator):
"""Wraps a DatasetIterator to provide a peek() method.
This allows looking at the next element, which can be useful in two scenarios:
a) Get the structure of elements if the element_spec property is not
supported.
b) Request the next element without consuming it. This is especially handy to
trigger reading of the first element while the model is being initialized.
Example use case:
>>> pool = clu.asynclib.Pool()
>>> @pool
... def warmup_input_pipeline():
...   train_iter.peek()
>>> first_batch_ready = warmup_input_pipeline()
>>> # Do other stuff...
>>> first_batch_ready.result()  # Wait for the input pipeline to be ready.
"""
def __init__(self, it: DatasetIterator):
self._it = it
# Mutex for self._it.
self._mutex = threading.Lock()
self._peek: Optional[Element] = None
self._pool = None
self._peek_future = None
def __next__(self) -> Element:
with self._mutex:
if self._peek is None:
return next(self._it)
peek = self._peek
self._peek = None
return peek
def reset(self):
with self._mutex:
self._it.reset()
self._peek = None
self._pool = None
self._peek_future = None
@property
def element_spec(self) -> ElementSpec:
return self._it.element_spec
def peek(self) -> Element:
"""Returns the next element without consuming it.
This will get the next element from the underlying iterator. The element
is stored and returned on the next call of __next__().
Returns:
The next element.
"""
if self._peek is None:
self._peek = next(self)
return self._peek
def peek_async(self) -> concurrent.futures.Future[Element]:
"""Same as peek() but returns the Future of the element.
Users can call this to warm up the iterator.
Returns:
Future with the next element. The element is also kept and returned on the
next call of __next__().
"""
with self._mutex:
if self._peek_future is None:
if self._pool is None:
self._pool = asynclib.Pool(max_workers=1)
self._peek_future = self._pool(self.peek)()
return self._peek_future
def save(self, filename: epath.Path):
with self._mutex:
self._it.save(filename)
def restore(self, filename: epath.Path):
with self._mutex:
self._it.restore(filename)
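The interface above does not require TensorFlow. Below is a minimal, stdlib-only sketch of a non-TF implementation, assuming the elements are held in an in-memory list of dictionaries. The class name `InMemoryDatasetIterator` is hypothetical and not part of CLU; to stay self-contained it subclasses `collections.abc.Iterator` instead of `DatasetIterator`, and its `element_spec` returns Python type names where a real implementation would return `ArraySpec` values. The iterator can start at an arbitrary position, and its checkpoint holds only the read position, written as JSON.

```python
import collections.abc
import json
import pathlib
import tempfile
from typing import Any, Dict, List


class InMemoryDatasetIterator(collections.abc.Iterator):
  """Hypothetical DatasetIterator-style iterator over a list (not CLU API)."""

  def __init__(self, elements: List[Dict[str, Any]], start_index: int = 0):
    self._elements = elements
    # Iteration can start at an arbitrary position.
    self._index = start_index

  def __next__(self) -> Dict[str, Any]:
    if self._index >= len(self._elements):
      raise StopIteration
    element = self._elements[self._index]
    self._index += 1
    return element

  def reset(self) -> None:
    """Moves the iterator back to the beginning."""
    self._index = 0

  @property
  def element_spec(self) -> Dict[str, str]:
    # Stand-in for a dict of ArraySpec: feature name -> Python type name,
    # read from the first element without advancing the iterator.
    return {k: type(v).__name__ for k, v in self._elements[0].items()}

  def save(self, filename: pathlib.Path) -> None:
    # Only the read position is checkpointed, not the data itself.
    filename.write_text(json.dumps({"index": self._index}))

  def restore(self, filename: pathlib.Path) -> None:
    self._index = json.loads(filename.read_text())["index"]


# Usage: save after two elements, then restore into a fresh iterator.
elements = [{"id": i, "value": i * i} for i in range(5)]
it = InMemoryDatasetIterator(elements)
next(it)
next(it)
ckpt = pathlib.Path(tempfile.mkdtemp()) / "iterator.json"
it.save(ckpt)
restored = InMemoryDatasetIterator(elements)
restored.restore(ckpt)
print(next(restored))  # {'id': 2, 'value': 4}
```

Checkpointing only the position keeps checkpoints tiny; it assumes the underlying list is identical after a restart, which is the kind of externally reproducible state the `DatasetIterator` docstring mentions.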
================================================
FILE: clu/data/dataset_iterator_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset_iterator."""
import itertools
import pathlib
import tempfile
from absl.testing import parameterized
from clu.data import dataset_iterator
import numpy as np
import tensorflow as tf
INDEX = "_index"
class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase):
def _create_iterator(self, start_index: int, checkpoint: bool = True):
"""Create an iterator over some prime numbers with index."""
primes = tf.constant([2, 3, 5, 7, 11, 13, 17, 19, 23, 29])
ds = tf.data.Dataset.range(start_index, 10)
ds = ds.map(lambda i: {INDEX: i, "prime": primes[i]})
# Remove index 1 and 3.
ds = ds.filter(lambda x: tf.logical_and(x["prime"] != 3, x["prime"] != 7))
ds = ds.batch(2, drop_remainder=True)
return dataset_iterator.TfDatasetIterator(ds, checkpoint=checkpoint)
def test_tf_iterator(self):
it = self._create_iterator(0)
self.assertEqual(
it.element_spec, {
INDEX: dataset_iterator.ArraySpec(np.int64, (2,)),
"prime": dataset_iterator.ArraySpec(np.int32, (2,))
})
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
it.reset()
# Iterator starts from the beginning.
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
def test_tf_iterator_save_and_load(self):
it = self._create_iterator(0)
next(it)
next(it)
next(it)
work_dir = pathlib.Path(tempfile.mkdtemp())
filename = work_dir / "ckpt"
it.save(filename)
self.assertTrue((work_dir / "ckpt.index").exists())
it = self._create_iterator(0)
# Iterator is at the beginning (batch 1).
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
it.load(filename)
# Iterator is at the end (batch 4).
self.assertEqual(next(it), {INDEX: [8, 9], "prime": [23, 29]})
def test_tf_iterator_save_and_load_no_checkpoint(self):
it = self._create_iterator(0, checkpoint=False)
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
work_dir = pathlib.Path(tempfile.mkdtemp())
filename = work_dir / "ckpt"
it.save(filename) # Should be a no-op and not create a checkpoint.
self.assertFalse((work_dir / "ckpt.index").exists())
it = self._create_iterator(0, checkpoint=False)
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
it.restore(filename) # Should be a no-op, iterator just continues.
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
def test_peekable_dataset_iterator(self):
it = self._create_iterator(0)
it = dataset_iterator.PeekableDatasetIterator(it)
self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
@parameterized.parameters(itertools.product([True, False], [True, False]))
def test_peekable_dataset_iterator_async(self, wait: bool, peek_first: bool):
it = self._create_iterator(0)
it = dataset_iterator.PeekableDatasetIterator(it)
future = it.peek_async()
self.assertIsNone(it._peek)
if wait:
future.result()
self.assertIsNotNone(it._peek)
if peek_first:
self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
if __name__ == "__main__":
tf.test.main()
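The tests above pin down the peek contract: `peek()` returns the next element without consuming it, and `peek_async()` fetches it on a background thread. A minimal, stdlib-only sketch of that contract, assuming `concurrent.futures.ThreadPoolExecutor` as a stand-in for `clu.asynclib.Pool`; the class `PeekableIterator` is hypothetical. Unlike `PeekableDatasetIterator`, it tracks the peeked element with a separate flag (so a `None` element is not mistaken for "nothing peeked") and clears the cached future once the peeked element is consumed.

```python
import concurrent.futures
import threading
from typing import Any, Iterator, Optional


class PeekableIterator:
  """Hypothetical stdlib-only analogue of PeekableDatasetIterator."""

  def __init__(self, it: Iterator[Any]):
    self._it = it
    self._mutex = threading.Lock()  # Guards all fields below.
    self._peek: Any = None
    self._has_peek = False  # Distinguishes "nothing peeked" from a None element.
    self._pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
    self._peek_future: Optional[concurrent.futures.Future] = None

  def __iter__(self) -> "PeekableIterator":
    return self

  def __next__(self) -> Any:
    with self._mutex:
      if not self._has_peek:
        return next(self._it)
      # Hand out the stored element and forget it (and its future).
      element, self._peek = self._peek, None
      self._has_peek = False
      self._peek_future = None
      return element

  def peek(self) -> Any:
    """Returns the next element without consuming it."""
    with self._mutex:
      if not self._has_peek:
        self._peek = next(self._it)
        self._has_peek = True
      return self._peek

  def peek_async(self) -> concurrent.futures.Future:
    """Runs peek() on a single background thread, e.g. to warm up input."""
    with self._mutex:
      if self._peek_future is None:
        if self._pool is None:
          self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # The worker blocks on the mutex until this method releases it.
        self._peek_future = self._pool.submit(self.peek)
      return self._peek_future


it = PeekableIterator(iter([10, 20, 30]))
future = it.peek_async()  # Starts fetching 10 in the background.
print(future.result(), next(it), next(it))  # 10 10 20
```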
================================================
FILE: clu/deterministic_data.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Helper functions for building deterministic tf.data input pipelines.
The function `create_dataset()` makes it easy to build a `tf.data` based input
pipeline that allows for completely reproducible results based on a single
initial random seed. The caller must take care to create a unique initial seed
on every host that is then passed to `create_dataset()`, where further random
keys are derived for shuffling and for every individual example. Inside
`preprocess_fn`, the per-example key is exposed as the special feature "rng"
and can be used to implement stateless preprocessing functions; the feature is
removed again after preprocessing.
The function `get_read_instruction_for_host()` (deprecated in favor of
`tfds.split_for_jax_process()`) makes it easy to split a dataset evenly between
multiple hosts in an SPMD setup with multiple machines. Within a single host,
every batch is usually distributed to all the attached accelerators (the first
value of the `batch_dims` argument to `create_dataset()`).
Finally, the function `create_distributed_dataset()` is intended to be used in
conjunction with a `tf.distribute.Strategy`.
Synopsis for deterministic training with multiple hosts:
import jax
from clu import deterministic_data
import tensorflow_datasets as tfds
rng = jax.random.PRNGKey(42) # Global RNG (e.g. from config)
rng = jax.random.fold_in(rng, jax.process_index()) # Derive RNG for this host.
dataset_builder = tfds.builder(...)
split = deterministic_data.get_read_instruction_for_host(
"train", dataset_info=dataset_builder.info)
ds = deterministic_data.create_dataset(
dataset_builder,
split=split,
rng=rng
)
ds_iter = iter(ds)
for _ in range(num_train_steps):
batch = jax.tree_util.tree_map(lambda x: x._numpy(), next(ds_iter))
# (training step)
"""
import enum
import functools
import operator
from typing import Callable, Dict, Optional, Sequence, Union
from absl import logging
import jax
import jax.numpy as jnp
import numpy as np
from packaging import version
import tensorflow as tf
import tensorflow_datasets as tfds
import typing_extensions
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
Tensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
Features = Dict[str, Tensor]
AUTOTUNE = tf.data.experimental.AUTOTUNE
_use_split_info = version.parse("4.4.0") < version.parse(
tfds.version.__version__)
class DatasetBuilder(typing_extensions.Protocol):
"""Protocol for dataset builders (subset of tfds.core.DatasetBuilder)."""
def as_dataset(
self, split: Union[str, tfds.core.ReadInstruction], shuffle_files: bool,
read_config: tfds.ReadConfig,
decoders: Optional[Dict[str, tfds.decode.Decoder]]) -> tf.data.Dataset:
...
class RemainderOptions(enum.Enum):
"""How to handle examples not divisible by number of processes.
Possible values:
- DROP: The examples that remain after dividing evenly by the process count
are dropped. Every host receives the same number of examples.
- BALANCE_ON_PROCESSES: The remaining examples are distributed one per
process, starting from process 0. For example, if there are 4 processes
and 7 examples, then processes 0, 1, 2 will have 2 examples, and process 3
will have 1 example.
- ON_FIRST_PROCESS: All remaining examples are assigned to process 0.
"""
DROP = 0
BALANCE_ON_PROCESSES = 1
ON_FIRST_PROCESS = 2
def _shard_read_instruction(
absolute_instruction,
*,
split_infos: Dict[str, Union[int, tfds.core.SplitInfo]],
host_id: int,
host_count: int,
remainder_options: RemainderOptions,
) -> tfds.core.ReadInstruction:
"""Shards a single ReadInstruction. See get_read_instruction_for_host()."""
start = absolute_instruction.from_ or 0
if _use_split_info:
end = absolute_instruction.to or (
split_infos[absolute_instruction.splitname].num_examples) # pytype: disable=attribute-error
else:
end = absolute_instruction.to or split_infos[absolute_instruction.splitname]
assert end >= start, f"start={start}, end={end}"
num_examples = end - start
examples_per_host = num_examples // host_count
shard_start = start + examples_per_host * host_id
shard_end = start + examples_per_host * (host_id + 1)
# Handle remaining examples.
num_unused_examples = num_examples % host_count
assert num_unused_examples >= 0, num_unused_examples
assert num_unused_examples < host_count, num_unused_examples
if num_unused_examples > 0:
if remainder_options == RemainderOptions.DROP:
logging.warning("Dropping %d examples of %d examples (host count: %d).",
num_unused_examples, num_examples, host_count)
elif remainder_options == RemainderOptions.BALANCE_ON_PROCESSES:
shard_start += min(host_id, num_unused_examples)
shard_end += min(host_id + 1, num_unused_examples)
elif remainder_options == RemainderOptions.ON_FIRST_PROCESS:
shard_end += num_unused_examples
if host_id > 0:
shard_start += num_unused_examples
else:
raise ValueError(f"Invalid remainder_options: {remainder_options}")
return tfds.core.ReadInstruction(
absolute_instruction.splitname,
from_=shard_start,
to=shard_end,
unit="abs")
_DEPRECATE_MSG = """
`get_read_instruction_for_host` is DEPRECATED.
Migration instructions: Use `tfds.split_for_jax_process`, which is simpler
and natively supported by TFDS.
```
split = tfds.split_for_jax_process('train[75%:]', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
```
See: https://www.tensorflow.org/datasets/splits#tfdseven_splits_multi-host_training
"""
def get_read_instruction_for_host(
split: str,
num_examples: Optional[int] = None,
*,
dataset_info: Optional[tfds.core.DatasetInfo] = None,
host_id: Optional[int] = None,
host_count: Optional[int] = None,
drop_remainder: Optional[bool] = None,
remainder_options: RemainderOptions = RemainderOptions.DROP
) -> tfds.core.ReadInstruction:
"""Returns a `ReadInstruction` of the data ranges for this host.
`get_read_instruction_for_host` is DEPRECATED. Please use
`tfds.split_for_jax_process` or `tfds.even_split`. See:
https://www.tensorflow.org/datasets/splits#tfdseven_splits_multi-host_training
In a distributed setting, all hosts usually get the same number of examples.
With the default `RemainderOptions.DROP`, this excludes a few (< host_count)
examples; `RemainderOptions.BALANCE_ON_PROCESSES` instead distributes the
remaining examples to the hosts with the lowest ids (see `RemainderOptions`).
Assuming a single epoch, the number of batches (e.g. for
`create_dataset(pad_up_to_batches)`) can be computed by:
batches = int(np.ceil(num_examples / global_batch_size))
Args:
split: Name of the dataset split to use or TFDS spec (e.g.
`train[:800]+validation[:100]`). If you use the spec you must pass
dataset_info. For specs with multiple splits each split is sharded
independently of the other splits.
num_examples: Deprecated - use dataset_info instead. Number of examples of
the split.
dataset_info: TFDS dataset info; used to get the number of examples per
split.
host_id: Optional, host index in [0, host_count). Defaults to
`jax.process_index()`.
host_count: Optional, number of hosts. Defaults to `jax.process_count()`.
drop_remainder: Deprecated - use remainder_options instead.
remainder_options: The options to handle the remaining examples.
"""
logging.warning(_DEPRECATE_MSG)
if num_examples is not None:
logging.warning(
"`num_examples` is deprecated. Please pass `dataset_info` instead.")
if drop_remainder is not None:
remainder_options = (
RemainderOptions.DROP
if drop_remainder else RemainderOptions.BALANCE_ON_PROCESSES)
logging.warning(
"`drop_remainder` is deprecated. Please pass `remainder_options` "
"instead. `remainder_options` is reset with %s.", remainder_options)
if dataset_info is None:
if any(special in split for special in ["[", "]", "+"]):
raise ValueError(
f"Sharding split {split} requires passing `dataset_info`.")
if host_id is None:
host_id = jax.process_index()
if host_count is None:
host_count = jax.process_count()
if host_id < 0 or host_id >= host_count or host_count < 1:
raise ValueError(
f"Invalid combination of host_id ({host_id}) and host_count "
f"({host_count}).")
if _use_split_info:
if dataset_info is None:
split_infos = {
split: tfds.core.SplitInfo(
name=split,
shard_lengths=[num_examples],
num_bytes=0,
),
}
else:
split_infos = dataset_info.splits
else:
if dataset_info is None:
split_infos = {split: num_examples}
else:
split_infos = {k: v.num_examples for k, v in dataset_info.splits.items()}
read_instruction = tfds.core.ReadInstruction.from_spec(split)
sharded_read_instructions = []
for ri in read_instruction.to_absolute(split_infos):
sharded_read_instructions.append(
_shard_read_instruction(
ri,
split_infos=split_infos,
host_id=host_id,
host_count=host_count,
remainder_options=remainder_options))
return functools.reduce(operator.add, sharded_read_instructions)
def _preprocess_with_per_example_rng(ds: tf.data.Dataset,
preprocess_fn: Callable[[Features],
Features], *,
rng: jnp.ndarray) -> tf.data.Dataset:
"""Maps `ds` using the preprocess_fn and a deterministic RNG per example.
Args:
ds: Dataset containing Python dictionary with the features. The 'rng'
feature should not exist.
preprocess_fn: Preprocessing function that takes a Python dictionary of
tensors and returns a Python dictionary of tensors. The function should be
convertible into a TF graph.
rng: Base RNG to use. Per example RNGs will be derived from this by folding
in the example index.
Returns:
The dataset mapped by the `preprocess_fn`.
"""
def _fn(example_index: int, features: Features) -> Features:
example_index = tf.cast(example_index, tf.int32)
features["rng"] = tf.random.experimental.stateless_fold_in(
tf.cast(rng, tf.int64), example_index)
processed = preprocess_fn(features)
if isinstance(processed, dict) and "rng" in processed:
del processed["rng"]
return processed
return ds.enumerate().map(_fn, num_parallel_calls=AUTOTUNE)
def pad_dataset(dataset: tf.data.Dataset,
*,
batch_dims: Sequence[int],
pad_up_to_batches: Optional[int] = None,
cardinality: Optional[int] = None):
"""Adds padding to a dataset.
Args:
dataset: The dataset to be padded.
batch_dims: List of sizes of batch dimensions. Multiple batch dimensions can
be used to provide inputs for multiple devices. E.g.
[jax.local_device_count(), batch_size // jax.device_count()].
pad_up_to_batches: Set this option to process the entire dataset. When set,
the dataset is first padded to the specified number of batches. A new
feature called "mask" is added to every batch. This feature is set to
`True` for every example that comes from `dataset`, and to `False` for
every example that is padded to get to the specified number of batches.
Note that `dataset` must not yield more than `pad_up_to_batches`
(possibly partial) batches. If `None`, it is derived from `batch_dims` and
`cardinality` as `ceil(cardinality / prod(batch_dims))`, i.e. the
smallest number of batches that holds all examples. Note that
`cardinality` is what you pass in, not necessarily the original full
dataset size if you decide to shard it per host.
cardinality: Number of examples in the dataset. Only needed when the
cardinality cannot be retrieved via `ds.cardinality()` (e.g. because of
using `ds.filter()`).
Returns:
The padded dataset, with the added feature "mask" that is set to `True` for
examples from the original `dataset` and to `False` for padded examples.
"""
if not isinstance(dataset.element_spec, dict):
raise ValueError("The dataset must have dictionary elements.")
if cardinality is None:
cardinality = dataset.cardinality()
if cardinality == tf.data.UNKNOWN_CARDINALITY:
raise ValueError(
"Cannot determine dataset cardinality. This can happen when you use "
"a `.filter()` on the dataset. Please provide the cardinality as an "
"argument to `create_dataset()`.")
if "mask" in dataset.element_spec:
raise ValueError("Dataset already contains a feature named \"mask\".")
if pad_up_to_batches is None:
pad_up_to_batches = int(np.ceil(cardinality / np.prod(batch_dims)))
filler_element = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
filler_element["mask"] = [False]
filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
dataset = dataset.map(
lambda features: dict(mask=True, **features), num_parallel_calls=AUTOTUNE)
padding = pad_up_to_batches * np.prod(batch_dims) - int(cardinality)
assert padding >= 0, (
f"Invalid padding={padding} (batch_dims={batch_dims}, cardinality="
f"{cardinality}, pad_up_to_batches={pad_up_to_batches})")
return dataset.concatenate(filler_dataset.repeat(padding))
def create_dataset(dataset_builder: DatasetBuilder,
*,
split: Union[str, tfds.core.ReadInstruction],
batch_dims: Sequence[int] = (),
rng: Union[None, jnp.ndarray, tf.Tensor] = None,
filter_fn: Optional[Callable[[Features], bool]] = None,
preprocess_fn: Optional[Callable[[Features],
Features]] = None,
decoders: Optional[Dict[str, tfds.decode.Decoder]] = None,
cache: bool = False,
num_epochs: Optional[int] = None,
shuffle: bool = True,
shuffle_buffer_size: int = 10_000,
prefetch_size: int = 4,
pad_up_to_batches: Optional[Union[int, str]] = None,
cardinality: Optional[int] = None,
drop_remainder: bool = True) -> tf.data.Dataset:
"""Creates standard input pipeline (shuffle, preprocess, batch).
Args:
dataset_builder: Dataset builder object with an as_dataset() method. E.g.
instance of `tfds.core.DatasetBuilder` as returned by `tfds.builder(...)`.
split: Specifies which split of the data to load. Passed on to
`tfds.DatasetBuilder.as_dataset()`. See also the
[split API guide](https://www.tensorflow.org/datasets/splits). In a multi
host setup, this parameter can conveniently be generated by
`tfds.split_for_jax_process()` (or by the deprecated
`get_read_instruction_for_host()`).
batch_dims: List of sizes of batch dimensions. Multiple batch dimensions can
be used to provide inputs for multiple devices. E.g.
[jax.local_device_count(), batch_size // jax.device_count()].
rng: A JAX PRNG key or a tf.Tensor with a TF stateless seed, used for
seeding shuffle operations and preprocessing ops. Must be set if
shuffling.
filter_fn: Optional function to filter the decoded examples. This happens
before the preprocessing.
preprocess_fn: Function for preprocessing individual examples (which should
be Python dictionaries of tensors). If `rng` is set, every example passed
to this function contains an additional "rng" feature with a per-example
stateless seed; this feature is removed from the output.
decoders: Optional dictionary of decoders passed to as_dataset().
cache: Whether to cache the unprocessed dataset in memory.
num_epochs: Number of epochs for which to repeat the dataset. None to repeat
forever.
shuffle: Whether to shuffle the dataset (both on file and example level).
shuffle_buffer_size: Number of examples in the shuffle buffer.
prefetch_size: The number of elements in the final dataset to prefetch in
the background. This should be a small (say <10) positive integer or
tf.data.experimental.AUTOTUNE.
pad_up_to_batches: Set this option to process the entire dataset.
- If set to an integer, the dataset is first padded to the specified
number of batches. A new feature called "mask" is added to every batch.
This feature is set to `True` for every example that comes from
`dataset_builder`, and to `False` for every example that is padded. Note
that the specified `dataset_builder` and `split` must not result in more
than `pad_up_to_batches` (possibly partial) batches.
- If set to "auto", it is derived from `batch_dims` and `cardinality` as
`ceil(cardinality / prod(batch_dims))`, i.e. the smallest number of
batches that holds all examples.
- If `None`, the dataset won't be padded.
cardinality: Number of examples in the dataset. Only needed when
`pad_up_to_batches` is specified and the cardinality cannot be retrieved
via `ds.cardinality()` (e.g. because of `ds.filter()`).
drop_remainder: Whether to drop remainders when batching.
Returns:
The dataset with preprocessed and batched examples.
"""
rng_available = rng is not None
if not rng_available and shuffle:
raise ValueError("Please set 'rng' when shuffling.")
if rng_available:
if isinstance(rng, tf.Tensor):
rngs = [x.numpy() for x in tf.random.experimental.stateless_split(rng, 3)]
else:
rngs = list(jax.random.key_data(jax.random.split(rng, 3)))
else:
rngs = 3 * [[None, None]]
dataset_options = tf.data.Options()
dataset_options.experimental_optimization.map_parallelization = True
dataset_options.threading.private_threadpool_size = 48
dataset_options.threading.max_intra_op_parallelism = 1
read_config = tfds.ReadConfig(
shuffle_seed=rngs.pop()[0], options=dataset_options)
ds = dataset_builder.as_dataset(
split=split,
shuffle_files=shuffle,
read_config=read_config,
decoders=decoders)
if filter_fn is not None:
ds = ds.filter(filter_fn)
if cache:
ds = ds.cache()
if shuffle:
ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0])
ds = ds.repeat(num_epochs)
if preprocess_fn is not None:
if rng_available:
ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop())
else:
ds = ds.map(preprocess_fn, num_parallel_calls=AUTOTUNE)
if pad_up_to_batches is not None:
assert isinstance(pad_up_to_batches, int) or pad_up_to_batches == "auto"
ds = pad_dataset(
ds,
batch_dims=batch_dims,
pad_up_to_batches=(None if pad_up_to_batches == "auto" else
pad_up_to_batches),
cardinality=cardinality)
if batch_dims:
for batch_size in reversed(batch_dims):
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds.prefetch(prefetch_size)
StrOrReadInstruction = Union[str, tfds.core.ReadInstruction]
def create_distributed_dataset(
dataset_builder,
*,
split: Union[StrOrReadInstruction, Callable[[int, int],
StrOrReadInstruction]],
global_batch_size: int,
strategy: tf.distribute.Strategy,
rng: Optional[tf.Tensor] = None,
filter_fn: Optional[Callable[[Features], bool]] = None,
preprocess_fn: Optional[Callable[[Features], Features]] = None,
decoders: Optional[Dict[str, tfds.decode.Decoder]] = None,
cache: bool = False,
num_epochs: Optional[int] = None,
shuffle: bool = True,
shuffle_buffer_size: int = 10_000,
prefetch_size: int = 4,
pad_up_to_batches: Optional[int] = None,
cardinality: Optional[int] = None,
drop_remainder: bool = True) -> tf.data.Dataset:
"""Creates standard input pipeline (shuffle, preprocess, batch).
Args:
dataset_builder: Dataset builder object with an as_dataset() method. E.g.
instance of `tfds.core.DatasetBuilder` as returned by `tfds.builder(...)`.
split: Split name to use, will be passed to as_dataset(). To read different
data chunks on different input pipelines, pass a callable that accepts the
input pipeline ID and the number of input pipelines (as given by the
`tf.distribute.InputContext`) and returns a split name or ReadInstruction.
global_batch_size: Global batch size for all input pipelines together.
strategy: Distribution strategy for distributing the dataset.
rng: A tf.Tensor with a stateless random key to seed shuffle operations and
preprocessing ops. Must be set if `shuffle` is True. The key for each input
pipeline is derived by folding in its input pipeline ID.
filter_fn: Optional function to filter the decoded examples. This happens
before the preprocessing.
preprocess_fn: Function for preprocessing individual examples (which should
be Python dictionaries of tensors).
decoders: Optional dictionary of decoders passed to as_dataset().
cache: Whether to cache the unprocessed dataset in memory.
num_epochs: Number of epochs for which to repeat the dataset. None to repeat
forever.
shuffle: Whether to shuffle the dataset (both on the file and example
level).
shuffle_buffer_size: Number of examples in the shuffle buffer.
prefetch_size: The number of elements in the final dataset to prefetch in
the background. This should be a small (say <10) positive integer or
tf.data.experimental.AUTOTUNE.
pad_up_to_batches: Set this option to process the entire dataset. When set,
the dataset is first padded to the specified number of batches. A new
feature called "mask" is added to every batch. This feature is set to
`True` for every example that comes from `dataset_builder`, and to `False`
for every example that is padded to get to the specified number of
batches. Note that the specified `dataset_builder` and `split` must not
provide more than `pad_up_to_batches` (possibly partial) batches.
cardinality: Number of examples in the dataset. Only needed when
`pad_up_to_batches` is specified and the cardinality cannot be retrieved
via `ds.cardinality()` (e.g. because of `ds.filter()`).
drop_remainder: Whether to drop remainders when batching.
Returns:
The dataset with preprocessed and batched examples.
"""
def dataset_fn(input_context: tf.distribute.InputContext):
"""Returns the dataset for a single worker."""
logging.info("dataset_fn(input_context=%s)", input_context)
if rng is None:
local_rng = None
else:
local_rng = tf.random.experimental.stateless_fold_in(
rng, input_context.input_pipeline_id)
if callable(split):
local_split = split(input_context.input_pipeline_id,
input_context.num_input_pipelines)
else:
local_split = split
per_replica_batch_size = input_context.get_per_replica_batch_size(
global_batch_size)
return create_dataset(
dataset_builder=dataset_builder,
split=local_split,
batch_dims=[per_replica_batch_size],
rng=local_rng,
filter_fn=filter_fn,
preprocess_fn=preprocess_fn,
decoders=decoders,
cache=cache,
num_epochs=num_epochs,
shuffle=shuffle,
shuffle_buffer_size=shuffle_buffer_size,
prefetch_size=prefetch_size,
pad_up_to_batches=pad_up_to_batches,
cardinality=cardinality,
drop_remainder=drop_remainder)
return strategy.distribute_datasets_from_function(dataset_fn)
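The per-host ranges computed by `_shard_read_instruction()` and the padding added by `pad_dataset()` come down to integer arithmetic. Below is a stdlib-only sketch that mirrors that arithmetic for a single absolute range `[start, end)`, assuming the remainder handling is selected by the `RemainderOptions` member name as a plain string. `shard_range` and `num_padding_examples` are hypothetical helpers, not part of `clu.deterministic_data`, and no TFDS `ReadInstruction` objects are built.

```python
import math
from typing import Sequence, Tuple


def shard_range(start: int, end: int, host_id: int, host_count: int,
                remainder: str = "DROP") -> Tuple[int, int]:
  """Returns the [shard_start, shard_end) example range for one host."""
  num_examples = end - start
  per_host = num_examples // host_count
  shard_start = start + per_host * host_id
  shard_end = start + per_host * (host_id + 1)
  unused = num_examples % host_count  # Always in [0, host_count).
  if unused > 0:
    if remainder == "DROP":
      pass  # The last `unused` examples are not read by any host.
    elif remainder == "BALANCE_ON_PROCESSES":
      # Hosts 0 .. unused-1 each read one extra example.
      shard_start += min(host_id, unused)
      shard_end += min(host_id + 1, unused)
    elif remainder == "ON_FIRST_PROCESS":
      # Host 0 reads all extra examples; all other ranges shift right.
      shard_end += unused
      if host_id > 0:
        shard_start += unused
    else:
      raise ValueError(f"Invalid remainder: {remainder}")
  return shard_start, shard_end


def num_padding_examples(cardinality: int, batch_dims: Sequence[int]) -> int:
  """Filler examples added when pad_up_to_batches is derived ("auto")."""
  examples_per_batch = math.prod(batch_dims)
  pad_up_to_batches = math.ceil(cardinality / examples_per_batch)
  return pad_up_to_batches * examples_per_batch - cardinality


for option in ("DROP", "BALANCE_ON_PROCESSES", "ON_FIRST_PROCESS"):
  print(option, [shard_range(0, 7, h, 4, option) for h in range(4)])
# DROP [(0, 1), (1, 2), (2, 3), (3, 4)]
# BALANCE_ON_PROCESSES [(0, 2), (2, 4), (4, 6), (6, 7)]
# ON_FIRST_PROCESS [(0, 4), (4, 5), (5, 6), (6, 7)]
print(num_padding_examples(10, [2, 4]))  # 6
```

The 7-example, 4-host case reproduces the `BALANCE_ON_PROCESSES` example from the `RemainderOptions` docstring. For padding, 10 examples with `batch_dims=[2, 4]` need `ceil(10 / 8) = 2` batches, i.e. 16 slots, so 6 filler examples with `mask=False` are appended.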
================================================
FILE: clu/deterministic_data_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the deterministic_data module."""
import dataclasses
import itertools
import math
from typing import Dict
from unittest import mock
from absl.testing import parameterized
from clu import deterministic_data
import jax
from packaging import version
import tensorflow as tf
import tensorflow_datasets as tfds
_use_split_info = version.parse("4.4.0") < version.parse(
tfds.version.__version__)
@dataclasses.dataclass
class MyDatasetBuilder:
name2len: Dict[str, int] # Number of examples per split.
def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: bool,
read_config: tfds.ReadConfig, decoders) -> tf.data.Dataset:
del shuffle_files, read_config, decoders
if _use_split_info:
split_infos = {
k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0)
for k, v in self.name2len.items()
}
instructions = split.to_absolute(split_infos)
else:
instructions = split.to_absolute(self.name2len)
assert len(instructions) == 1
from_ = instructions[0].from_ or 0
to = instructions[0].to or self.name2len[instructions[0].splitname]
return tf.data.Dataset.range(from_, to).map(lambda i: {"index": i})
@dataclasses.dataclass
class FakeDatasetInfo:
train_size: int = 9
test_size: int = 8
@property
def splits(self):
return {
"train": tfds.core.SplitInfo("train", [self.train_size], 0),
"test": tfds.core.SplitInfo("test", [self.test_size], 0)
}
class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase):
"""Tests for deterministic_data module."""
@parameterized.parameters(
(9, 0, 1, True, "test[0:9]"),
(9, 0, 2, True, "test[0:4]"),
(9, 1, 2, True, "test[4:8]"), # Last example gets dropped.
(9, 0, 3, True, "test[0:3]"),
(9, 1, 3, True, "test[3:6]"),
(9, 2, 3, True, "test[6:9]"),
(9, 0, 1, False, "test[0:9]"),
(9, 0, 2, False, "test[0:5]"), # First host gets an extra example.
(9, 1, 2, False, "test[5:9]"),
(8, 0, 3, False, "test[0:3]"), # First 2 hosts get 1 extra example each.
(8, 1, 3, False, "test[3:6]"),
(8, 2, 3, False, "test[6:8]"),
)
def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
host_id: int,
host_count: int,
drop_remainder: bool,
expected_spec: str):
expected = tfds.core.ReadInstruction.from_spec(expected_spec)
actual = deterministic_data.get_read_instruction_for_host(
"test",
num_examples,
host_id=host_id,
host_count=host_count,
drop_remainder=drop_remainder)
if _use_split_info:
split_infos = {
"test": tfds.core.SplitInfo(
name="test",
shard_lengths=[9],
num_bytes=0,
)}
else:
split_infos = {"test": 9}
self.assertEqual(
expected.to_absolute(split_infos), actual.to_absolute(split_infos))
@parameterized.parameters(
# host_id, host_count, drop_remainder, spec, expected_spec_for_host
# train split has 9 examples.
(0, 1, True, "train", "train[0:9]"),
(0, 2, True, "train", "train[0:4]"),
(1, 2, True, "train", "train[4:8]"), # Last example gets dropped.
(0, 3, True, "train", "train[0:3]"),
(1, 3, True, "train", "train[3:6]"),
(2, 3, True, "train", "train[6:9]"),
(0, 1, False, "train", "train[0:9]"),
(0, 2, False, "train", "train[0:5]"), # First host gets an extra example.
(1, 2, False, "train", "train[5:9]"),
# test split has 8 examples.
(0, 3, False, "test", "test[0:3]"), # First 2 hosts get 1 extra example each.
(1, 3, False, "test", "test[3:6]"),
(2, 3, False, "test", "test[6:8]"),
# Subsplits.
(0, 2, True, "train[:50%]", "train[0:2]"),
(1, 2, True, "train[:50%]", "train[2:4]"),
(0, 2, True, "train[3:7]", "train[3:5]"),
(1, 2, True, "train[3:7]", "train[5:7]"),
(0, 2, True, "train[3:8]", "train[3:5]"), # Last example gets dropped.
(1, 2, True, "train[3:8]", "train[5:7]"),
# 2 splits.
(0, 2, True, "train[3:7]+test", "train[3:5]+test[0:4]"),
(1, 2, True, "train[3:7]+test", "train[5:7]+test[4:8]"),
# First host gets an extra example.
(0, 2, False, "train[3:8]+test[:5]", "train[3:6]+test[0:3]"),
(1, 2, False, "train[3:8]+test[:5]", "train[6:8]+test[3:5]"),
)
def test_get_read_instruction_for_host(self, host_id: int, host_count: int,
drop_remainder: bool, spec: str,
expected_spec_for_host: str):
actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
spec,
dataset_info=FakeDatasetInfo(),
host_id=host_id,
host_count=host_count,
drop_remainder=drop_remainder)
expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
expected_spec_for_host)
self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
@parameterized.parameters(
# host_id, host_count, balance_remainder, spec, expected_spec_for_host
# test split has 10 examples.
(0, 1, True, "test", "test[0:10]"),
(0, 1, False, "test", "test[0:10]"),
(0, 4, True, "test", "test[0:3]"),
(1, 4, True, "test", "test[3:6]"),
(2, 4, True, "test", "test[6:8]"),
(3, 4, True, "test", "test[8:10]"),
(0, 4, False, "test", "test[0:4]"),
(1, 4, False, "test", "test[4:6]"),
(2, 4, False, "test", "test[6:8]"),
(3, 4, False, "test", "test[8:10]"),
)
def test_get_read_instruction_balance_remainder(self, host_id: int,
host_count: int,
balance_remainder: bool,
spec: str,
expected_spec_for_host: str):
actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
spec,
dataset_info=FakeDatasetInfo(test_size=10),
host_id=host_id,
host_count=host_count,
remainder_options=deterministic_data.RemainderOptions
.BALANCE_ON_PROCESSES if balance_remainder else
deterministic_data.RemainderOptions.ON_FIRST_PROCESS)
expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
expected_spec_for_host)
self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
@parameterized.parameters(
(0, 0), # No hosts.
(1, 1), # Only one host (host_id is zero-based).
(-1, 1), # Negative host_id.
(5, 2), # host_id bigger than number of hosts.
)
def test_get_read_instruction_for_host_fails(self, host_id: int,
host_count: int):
with self.assertRaises(ValueError):
deterministic_data.get_read_instruction_for_host(
"test", 11, host_id=host_id, host_count=host_count)
def test_preprocess_with_per_example_rng(self):
def preprocess_fn(features):
features["b"] = tf.random.stateless_uniform([], features["rng"])
return features
rng = jax.random.PRNGKey(42)
ds_in = tf.data.Dataset.from_tensor_slices({"a": [37.2, 31.2, 39.0]})
ds_out = deterministic_data._preprocess_with_per_example_rng(
ds_in, preprocess_fn, rng=rng)
self.assertAllClose([
{
"a": 37.2,
"b": 0.79542184
},
{
"a": 31.2,
"b": 0.45482683
},
{
"a": 39.0,
"b": 0.85335636
},
], list(ds_out))
@parameterized.parameters(*itertools.product([2, "auto"], [True, False]))
def test_create_dataset_padding(self, pad_up_to_batches, cardinality):
dataset_builder = mock.Mock()
dataset = tf.data.Dataset.from_tensor_slices(
dict(x=tf.ones((12, 10)), y=tf.ones(12)))
dataset_builder.as_dataset.return_value = dataset
batch_dims = (2, 5)
ds = deterministic_data.create_dataset(
dataset_builder,
split="(ignored)",
batch_dims=batch_dims,
num_epochs=1,
shuffle=False,
pad_up_to_batches=pad_up_to_batches,
cardinality=12 if cardinality else None,
)
ds_iter = iter(ds)
self.assertAllClose(
dict(
x=tf.ones((2, 5, 10)),
y=tf.ones((2, 5)),
mask=tf.ones((2, 5), bool),
), next(ds_iter))
self.assertAllClose(
dict(
x=tf.reshape(
tf.concat([tf.ones(
(2, 10)), tf.zeros((8, 10))], axis=0), (2, 5, 10)),
y=tf.reshape(tf.concat([tf.ones(2), tf.zeros(8)], axis=0), (2, 5)),
mask=tf.reshape(
tf.concat(
[tf.ones(2, bool), tf.zeros(8, bool)], axis=0), (2, 5)),
), next(ds_iter))
with self.assertRaises(StopIteration):
next(ds_iter)
def test_create_dataset_padding_raises_error_cardinality(self):
dataset_builder = mock.Mock()
dataset = tf.data.Dataset.from_tensor_slices(
dict(x=tf.ones((12, 10)), y=tf.ones(12)))
dataset = dataset.filter(lambda x: True)
dataset_builder.as_dataset.return_value = dataset
batch_dims = (2, 5)
with self.assertRaisesRegex(
ValueError,
r"^Cannot determine dataset cardinality."):
deterministic_data.create_dataset(
dataset_builder,
split="(ignored)",
batch_dims=batch_dims,
num_epochs=1,
shuffle=False,
pad_up_to_batches=2,
cardinality=None,
)
def test_pad_dataset(self):
dataset = tf.data.Dataset.from_tensor_slices(
dict(x=tf.ones((12, 10)), y=tf.ones(12)))
padded_dataset = deterministic_data.pad_dataset(
dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
self.assertAllClose(
dict(
x=tf.concat([tf.ones(
(12, 10)), tf.zeros((8, 10))], axis=0),
y=tf.concat([tf.ones(12), tf.zeros(8)], axis=0),
mask=tf.concat(
[tf.ones(12, bool), tf.zeros(8, bool)], axis=0)),
next(iter(padded_dataset.batch(20))))
def test_pad_nested_dataset(self):
dataset = tf.data.Dataset.from_tensor_slices(
{"x": {"z": (tf.ones((12, 10)), tf.ones(12))},
"y": tf.ones((12, 4))})
def expected(*dims):
return tf.concat([tf.ones((12,) + dims), tf.zeros((8,) + dims)], axis=0)
padded_dataset = deterministic_data.pad_dataset(
dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
self.assertAllClose(
{"x": {"z": (expected(10), expected())},
"y": expected(4),
"mask": tf.concat([tf.ones(12, bool), tf.zeros(8, bool)], axis=0)},
next(iter(padded_dataset.batch(20))))
@parameterized.parameters(*itertools.product(range(20), range(1, 4)))
def test_same_cardinality_on_all_hosts(self, num_examples: int,
host_count: int):
builder = MyDatasetBuilder({"train": num_examples})
cardinalities = []
for host_id in range(host_count):
split = deterministic_data.get_read_instruction_for_host(
split="train",
num_examples=num_examples,
host_id=host_id,
host_count=host_count,
drop_remainder=True)
ds = deterministic_data.create_dataset(
builder, split=split, batch_dims=[2], shuffle=False, num_epochs=1)
cardinalities.append(ds.cardinality().numpy().item())
self.assertLen(set(cardinalities), 1)
@parameterized.parameters(*itertools.product(range(20), range(1, 4)))
def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
host_count: int):
builder = MyDatasetBuilder({"train": num_examples})
# All hosts should have the same number of batches.
batch_size = 2
pad_up_to_batches = int(math.ceil(num_examples / (batch_size * host_count)))
assert pad_up_to_batches * batch_size * host_count >= num_examples
cardinalities = []
for host_id in range(host_count):
split = deterministic_data.get_read_instruction_for_host(
split="train",
num_examples=num_examples,
host_id=host_id,
host_count=host_count,
drop_remainder=False)
ds = deterministic_data.create_dataset(
builder,
split=split,
batch_dims=[batch_size],
shuffle=False,
num_epochs=1,
pad_up_to_batches=pad_up_to_batches)
cardinalities.append(ds.cardinality().numpy().item())
self.assertLen(set(cardinalities), 1)
if __name__ == "__main__":
tf.test.main()
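The parameterized cases above pin down how `get_read_instruction_for_host` shards a split across hosts. A minimal sketch that reproduces those expected ranges, assuming a split of `num_examples` examples starting at index 0. `host_range` and its `remainder` strings are hypothetical names; they loosely correspond to `drop_remainder=True` ("drop"), `drop_remainder=False` or `RemainderOptions.BALANCE_ON_PROCESSES` ("balance"), and `RemainderOptions.ON_FIRST_PROCESS` ("first").

```python
from typing import Tuple


def host_range(num_examples: int, host_id: int, host_count: int,
               remainder: str = "balance") -> Tuple[int, int]:
  """Returns the [start, end) range of examples read by `host_id`.

  Hypothetical helper. `remainder` controls the num_examples % host_count
  leftover examples:
    "drop":    leftovers are not read by any host.
    "balance": the first hosts each read one extra example.
    "first":   host 0 reads all leftovers.
  """
  if host_count <= 0 or not 0 <= host_id < host_count:
    raise ValueError(f"Invalid host_id={host_id} for host_count={host_count}.")
  per_host, leftover = divmod(num_examples, host_count)
  if remainder == "drop":
    start, size = host_id * per_host, per_host
  elif remainder == "balance":
    start = host_id * per_host + min(host_id, leftover)
    size = per_host + (1 if host_id < leftover else 0)
  elif remainder == "first":
    start = host_id * per_host + (leftover if host_id > 0 else 0)
    size = per_host + (leftover if host_id == 0 else 0)
  else:
    raise ValueError(f"Unknown remainder option: {remainder!r}")
  return start, start + size
```

For subsplits such as `train[3:7]`, the tests suggest the same arithmetic is applied to the subsplit's length and then shifted by its start offset (giving `train[3:5]` and `train[5:7]` for two hosts), and each part of `train[3:7]+test` is sharded independently.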
================================================
FILE: clu/internal/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: clu/internal/utils.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Small utilities used by CLU libraries."""
import contextlib
import sys
import time
from typing import Any, List, Mapping, Tuple, Union
from absl import logging
import jax.numpy as jnp
import numpy as np
import wrapt
@contextlib.contextmanager
def log_activity(activity_name: str):
"""Logs `activity_name` and timing information (or exception)."""
t0 = time.time()
logging.info("%s ...", activity_name)
try:
yield
finally:
dt = time.time() - t0
exc, *_ = sys.exc_info()
if exc is not None:
logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt,
exc.__name__)
else:
logging.info("%s finished after %.2fs.", activity_name, dt)
def logged_with(activity_name: str):
"""Returns a decorator wrapping a function with `log_activity()`."""
@wrapt.decorator
def decorator(wrapped, instance, args, kwargs):
del instance # not used
with log_activity(activity_name):
return wrapped(*args, **kwargs)
return decorator
def check_param(value, *, ndim=None, dtype=jnp.float32):
"""Raises a `ValueError` if `value` does not match ndim/dtype.
Args:
value: Value to be tested.
ndim: Expected dimensions.
dtype: Expected dtype.
Raises:
A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value`
is not an instance of `np.ndarray` or `jnp.ndarray`.
"""
if not isinstance(value, (np.ndarray, jnp.ndarray)):
raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}")
if ndim is not None and value.ndim != ndim:
raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
if dtype is not None and value.dtype != dtype:
raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}")
def flatten_dict(
d: Mapping[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[str, Union[int, float, str]]]:
"""Returns a sequence of flattened (k, v) pairs for tfsummary.hparams().
Args:
d: A dict-like object that has an `.items()` method.
prefix: Prefix to add to keys in `d`.
Returns:
Sequence of (k, v) pairs where k is the flattened key with individual
subkeys separated by dots. `None` values are replaced by the empty string.
"""
ret = []
for k, v in d.items():
# Note `ml_collections.ConfigDict` is not (yet) a `Mapping`.
if isinstance(v, Mapping) or hasattr(v, "items"):
ret += flatten_dict(v, prefix + (k,))
elif isinstance(v, (list, tuple)):
ret += flatten_dict({str(idx): value for idx, value in enumerate(v)},
prefix + (k,))
else:
ret.append((".".join(prefix + (k,)), v if v is not None else ""))
return ret
================================================
FILE: clu/internal/utils_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from clu.internal import utils
import jax.numpy as jnp
import ml_collections
class TestError(BaseException):
__test__ = False
pass
class HelpersTest(absltest.TestCase):
def test_log_activity(
self,
):
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
pass
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^INFO:absl:test_activity finished after \d+.\d\ds.$")
def test_log_activity_fails(
self,
):
with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
raise TestError()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")
def test_logged_with(self):
@utils.logged_with("test_activity")
def test():
pass
with self.assertLogs() as logs:
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^INFO:absl:test_activity finished after \d+.\d\ds.$")
def test_logged_with_fails(self):
@utils.logged_with("test_activity")
def test():
raise TestError()
with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")
def test_check_param(self):
a = jnp.array(0.)
with self.assertRaisesRegex(ValueError, r"^Expected np.array or jnp.array"):
utils.check_param(None, ndim=1)
with self.assertRaisesRegex(ValueError, r"^Expected ndim"):
utils.check_param(a, ndim=1)
with self.assertRaisesRegex(ValueError, r"^Expected dtype"):
utils.check_param(a, ndim=0, dtype=jnp.int32)
utils.check_param(a, ndim=0) # should work
utils.check_param(a, ndim=0, dtype=jnp.float32) # should also work
def test_flatten_dict(self):
self.assertEqual(
utils.flatten_dict(
ml_collections.ConfigDict({
"x": 1,
"y": None,
"z": ml_collections.ConfigDict({
"a": "bc",
})
})), [("x", 1), ("y", ""), ("z.a", "bc")])
if __name__ == "__main__":
absltest.main()
================================================
FILE: clu/metric_writers/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metric writers write ML model outputs during model training and evaluation.
This module introduces the MetricWriter interface. MetricWriters allow users
to write out metrics about ML models during training and evaluation (e.g. loss,
accuracy).
There is a MetricWriter implementation for each back end (e.g. TensorFlow
summaries) and classes that work on top of other MetricWriters to
write to multiple writers at once or write asynchronously.
Note: The current interface might not contain write() methods for all possible
data types. We are open to extending the interface to other data types
(e.g. audio).
Usage:
writer = MyMetricWriterImplementation()
# Before training.
writer.write_hparams({"learning_rate": 0.001, "batch_size": 64})
# Start training loop.
for step in range(num_train_steps):
loss = train_step()
if step % 50 == 0:
writer.write_scalars(step, {"loss": loss})
accuracy = evaluate()
writer.write_scalars(step, {"accuracy": accuracy})
# Make sure all values were written.
writer.flush() # or use metric_writers.ensure_flushes() context.
"""
# pylint: disable=unused-import
# pylint: disable=g-importing-member
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.async_writer import ensure_flushes
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.utils import create_default_writer
from clu.metric_writers.utils import write_values
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
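To make the usage loop in the module docstring concrete without a real back end, the following sketch records calls in memory. `InMemoryWriter` is hypothetical: it implements only `write_hparams`, `write_scalars` and `flush`, does not subclass `MetricWriter`, and buffers scalars until `flush()` purely to illustrate why the docstring asks for a final flush. Real writers may buffer differently.

```python
from typing import Any, Dict, List, Mapping, Tuple


class InMemoryWriter:
  """Hypothetical writer that records metrics in Python lists.

  Implements only write_hparams, write_scalars and flush; it does not
  subclass clu.metric_writers.MetricWriter.
  """

  def __init__(self):
    self.hparams: Dict[str, Any] = {}
    self.scalars: List[Tuple[int, str, float]] = []  # (step, name, value)
    self._pending: List[Tuple[int, str, float]] = []

  def write_hparams(self, hparams: Mapping[str, Any]) -> None:
    self.hparams.update(hparams)

  def write_scalars(self, step: int, scalars: Mapping[str, float]) -> None:
    # Buffered: values only show up in `self.scalars` after flush().
    self._pending.extend((step, name, float(v)) for name, v in scalars.items())

  def flush(self) -> None:
    self.scalars.extend(self._pending)
    self._pending.clear()


# The docstring's usage loop, with a stand-in loss instead of train_step().
writer = InMemoryWriter()
writer.write_hparams({"learning_rate": 0.001, "batch_size": 64})
for step in range(101):
  loss = 1.0 / (step + 1)
  if step % 50 == 0:
    writer.write_scalars(step, {"loss": loss})
writer.flush()
```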
================================================
FILE: clu/metric_writers/async_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter that writes metrics in a separate thread.
- The order of the write calls is preserved.
- Users need to call `flush()` or use the `ensure_flushes()` context to make sure
that all metrics have been written.
- Errors while writing in the background thread will be re-raised in the main
thread on the next write_*() call.
"""
from collections.abc import Mapping, Sequence
import contextlib
from typing import Any, Optional
from clu import asynclib
from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import wrapt
Array = interface.Array
Scalar = interface.Scalar
@wrapt.decorator
def _wrap_exceptions(wrapped, instance, args, kwargs):
del instance
try:
return wrapped(*args, **kwargs)
except asynclib.AsyncError as e:
raise asynclib.AsyncError(
"Consider re-running the code without AsyncWriter (e.g. creating a "
"writer using "
"`clu.metric_writers.create_default_writer(asynchronous=False)`)"
) from e
class AsyncWriter(interface.MetricWriter):
"""MetricWriter that performs write operations in a separate thread.
All write operations will be executed in a background thread. If an exception
occurs in the background thread, it will be re-raised on the main thread on the
next call of one of the write_* methods.
Use num_workers > 1 at your own risk: if the underlying writer is not
thread-safe or does not expect out-of-order events, this can cause problems.
If num_workers is None, the ThreadPool will use `os.cpu_count()`
threads.
"""
def __init__(self,
writer: interface.MetricWriter,
*,
num_workers: Optional[int] = 1):
super().__init__()
self._writer = writer
# By default, we have a thread pool with a single worker to ensure that
# calls to the function are run in order (but in a background thread).
self._num_workers = num_workers
self._pool = asynclib.Pool(
thread_name_prefix="AsyncWriter", max_workers=num_workers)
@_wrap_exceptions
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
self._pool(self._writer.write_summaries)(
step=step, values=values, metadata=metadata)
@_wrap_exceptions
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
self._pool(self._writer.write_scalars)(step=step, scalars=scalars)
@_wrap_exceptions
def write_images(self, step: int, images: Mapping[str, Array]):
self._pool(self._writer.write_images)(step=step, images=images)
@_wrap_exceptions
def write_videos(self, step: int, videos: Mapping[str, Array]):
self._pool(self._writer.write_videos)(step=step, videos=videos)
@_wrap_exceptions
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
self._pool(self._writer.write_audios)(
step=step, audios=audios, sample_rate=sample_rate)
@_wrap_exceptions
def write_texts(self, step: int, texts: Mapping[str, str]):
self._pool(self._writer.write_texts)(step=step, texts=texts)
@_wrap_exceptions
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
self._pool(self._writer.write_histograms)(
step=step, arrays=arrays, num_buckets=num_buckets)
@_wrap_exceptions
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
self._pool(self._writer.write_pointcloud)(
step=step,
point_clouds=point_clouds,
point_colors=point_colors,
configs=configs,
)
@_wrap_exceptions
def write_hparams(self, hparams: Mapping[str, Any]):
self._pool(self._writer.write_hparams)(hparams=hparams)
def flush(self):
try:
self._pool.join()
finally:
self._writer.flush()
def close(self):
try:
self.flush()
finally:
self._writer.close()
class AsyncMultiWriter(multi_writer.MultiWriter):
"""AsyncMultiWriter writes to multiple writers, each wrapped in an AsyncWriter."""
def __init__(self,
writers: Sequence[interface.MetricWriter],
*,
num_workers: Optional[int] = 1):
super().__init__([AsyncWriter(w, num_workers=num_workers) for w in writers])
@contextlib.contextmanager
def ensure_flushes(*writers: interface.MetricWriter):
"""Context manager which ensures that one or more writers are flushed."""
try:
# The caller should not need to use the yielded value, but we yield
# the first writer to stay backwards compatible for a single writer.
yield writers[0]
finally:
for writer in writers:
writer.flush()
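The docstrings above describe three properties of `AsyncWriter`: calls run in a background thread, a single worker preserves their order, and background errors surface on the main thread later. A minimal sketch of that pattern using the standard library's `concurrent.futures.ThreadPoolExecutor(max_workers=1)`. `OrderedBackgroundCaller` is a hypothetical class, not `clu.asynclib.Pool`; it wraps errors in `RuntimeError` rather than `asynclib.AsyncError` and assumes all calls come from a single thread.

```python
import concurrent.futures
from typing import Any, Callable, List


class OrderedBackgroundCaller:
  """Runs calls on one background thread, in submission order.

  Hypothetical sketch, not clu.asynclib.Pool. An error from a background call
  that has already finished is re-raised (wrapped in RuntimeError) on the
  caller's thread by the next submit() or join(). Use from a single thread.
  """

  def __init__(self):
    # A single worker thread executes the queued calls strictly in order.
    self._executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=1, thread_name_prefix="OrderedBackgroundCaller")
    self._futures: List[concurrent.futures.Future] = []

  def _raise_finished_errors(self) -> None:
    # Keep unfinished futures; remember the first error among finished ones.
    pending, error = [], None
    for future in self._futures:
      if not future.done():
        pending.append(future)
      elif error is None and future.exception() is not None:
        error = future.exception()
    self._futures = pending
    if error is not None:
      raise RuntimeError("A background call failed.") from error

  def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
    self._raise_finished_errors()
    self._futures.append(self._executor.submit(fn, *args, **kwargs))

  def join(self) -> None:
    # Block until everything submitted so far has run, then surface errors.
    concurrent.futures.wait(self._futures)
    self._raise_finished_errors()

  def close(self) -> None:
    try:
      self.join()
    finally:
      self._executor.shutdown(wait=True)
```

Keeping `max_workers=1` is what makes ordering trivial; with more workers, calls could finish out of order, which is the risk the `AsyncWriter` docstring warns about for `num_workers > 1`.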
================================================
FILE: clu/metric_writers/async_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for AsyncWriter."""
import time
from unittest import mock
from clu import asynclib
from clu.metric_writers import async_writer
from clu.metric_writers import interface
import numpy as np
import tensorflow as tf
class AsyncWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.sync_writer = mock.create_autospec(interface.MetricWriter)
self.writer = async_writer.AsyncWriter(self.sync_writer)
def test_write_summaries_async(self):
self.writer.write_summaries(
11,
{"a": np.eye(3, dtype=np.uint8),
"b": np.eye(2, dtype=np.float32)},
{"a": np.ones((2, 3)).tobytes()})
self.writer.flush()
self.sync_writer.write_summaries.assert_called_with(
step=11,
values={"a": mock.ANY, "b": mock.ANY},
metadata={"a": mock.ANY})
def test_write_scalars_async(self):
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
self.writer.write_scalars(2, {"a": 5, "b": 0.007})
self.writer.flush()
self.sync_writer.write_scalars.assert_has_calls([
mock.call(step=0, scalars={
"a": 3,
"b": 0.15
}),
mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})
])
def test_write_images(self):
images = np.zeros((2, 28, 28, 3))
self.writer.write_images(4, {"input_images": images})
self.writer.flush()
self.sync_writer.write_images.assert_called_with(4,
{"input_images": mock.ANY})
def test_write_videos(self):
videos = np.zeros((2, 4, 28, 28, 3))
self.writer.write_videos(4, {"input_videos": videos})
self.writer.flush()
self.sync_writer.write_videos.assert_called_with(4,
{"input_videos": mock.ANY})
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
self.sync_writer.write_pointcloud.assert_called_with(
step=0,
point_clouds={"pcd": mock.ANY},
point_colors={"pcd": mock.ANY},
configs={"config": mock.ANY},
)
def test_write_texts(self):
self.writer.write_texts(4, {"samples": "bla"})
self.writer.flush()
self.sync_writer.write_texts.assert_called_with(4, {"samples": "bla"})
def test_ensure_flushes(self):
with async_writer.ensure_flushes(self.writer) as writer:
writer.write_scalars(0, {"a": 3, "b": 0.15})
writer.write_scalars(2, {"a": 5, "b": 0.007})
self.sync_writer.write_scalars.assert_has_calls([
mock.call(step=0, scalars={
"a": 3,
"b": 0.15
}),
mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})
])
self.sync_writer.flush.assert_called_once()
def test_ensure_flushes_with_multiple_writers(self):
sync_writer1 = mock.create_autospec(interface.MetricWriter)
writer1 = async_writer.AsyncWriter(sync_writer1)
sync_writer2 = mock.create_autospec(interface.MetricWriter)
writer2 = async_writer.AsyncWriter(sync_writer2)
with async_writer.ensure_flushes(writer1, writer2):
writer1.write_scalars(0, {"a": 3, "b": 0.15})
writer2.write_scalars(2, {"a": 5, "b": 0.007})
sync_writer1.write_scalars.assert_has_calls(
[mock.call(step=0, scalars={
"a": 3,
"b": 0.15
})])
sync_writer2.write_scalars.assert_has_calls(
[mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})])
sync_writer1.flush.assert_called_once()
sync_writer2.flush.assert_called_once()
def test_flush_before_close(self):
self.writer.close()
self.sync_writer.flush.assert_called()
self.sync_writer.close.assert_called()
def test_reraises_exception(self):
self.sync_writer.write_scalars.side_effect = ValueError("foo")
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
time.sleep(0.1)
with self.assertRaisesRegex(asynclib.AsyncError, "Consider re-running"):
self.writer.write_scalars(2, {"a": 5, "b": 0.007})
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/metric_writers/interface.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for unifying the reporting of model metrics across logging formats.
This library provides a MetricWriter for each logging format (SummaryWriter,
LoggingWriter, etc.) and composite MetricWriters that add support for asynchronous
logging or writing to multiple formats.
"""
import abc
from collections.abc import Mapping
from typing import Any, Optional, Union
import jax.numpy as jnp
import numpy as np
Array = Union[np.ndarray, jnp.ndarray]
Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]
class MetricWriter(abc.ABC):
"""MetricWriter interface."""
@abc.abstractmethod
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
"""Saves an arbitrary tensor summary.
Useful when working with custom plugins or constructing a summary directly.
Args:
step: Step at which the values occurred.
values: Mapping from tensor keys to tensors.
metadata: Optional SummaryMetadata, as a proto or serialized bytes.
Note that markdown formatting is rendered by tensorboard.
"""
@abc.abstractmethod
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
"""Write scalar values for the step.
Consecutive calls to this method can provide different sets of scalars.
Repeated writes for the same metric at the same step are not allowed.
Args:
step: Step at which the scalar values occurred.
scalars: Mapping from metric name to value.
"""
@abc.abstractmethod
def write_images(self, step: int, images: Mapping[str, Array]):
"""Write images for the step.
Consecutive calls to this method can provide different sets of images.
Repeated writes for the same image key at the same step are not allowed.
Warning: Not all MetricWriter implementations support writing images!
Args:
step: Step at which the images occurred.
images: Mapping from image key to images. Images should have the shape [N,
H, W, C] or [H, W, C], where H is the height, W is the width and C the
number of channels (1 or 3). N is the number of images that will be
written. Image dimensions can differ between different image keys but
not between different steps for the same image key.
"""
@abc.abstractmethod
def write_videos(self, step: int, videos: Mapping[str, Array]):
"""Write videos for the step.
Warning: Logging only.
Not all MetricWriter implementations support writing videos!
Consecutive calls to this method can provide different sets of videos.
Repeated writes for the same video key at the same step are not allowed.
Args:
step: Step at which the videos occurred.
videos: Mapping from video key to videos. Videos should have the shape
[N, T, H, W, C] or [T, H, W, C], where T is time, H is the height,
W is the width and C the number of channels (1 or 3). N is the number
of videos that will be written. Video dimensions can differ between
different video keys but not between different steps for the same
video key.
"""
@abc.abstractmethod
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""Write audios for the step.
Consecutive calls to this method can provide different sets of audios.
Repeated writes for the same audio key at the same step are not allowed.
Warning: Not all MetricWriter implementations support writing audios!
Args:
step: Step at which the audios occurred.
audios: Mapping from audio key to audios. Audios should have the shape
[N, T, C], where T is the time length and C the number of channels
(1 = mono, 2 = stereo, >= 3 = surround; not all writers support every
channel count). N is the number of audios that will be written.
Audio dimensions can differ between different audio keys but not between
different steps for the same audio key. Values should be floating-point
values in [-1, +1].
sample_rate: Sample rate for the audios.
"""
@abc.abstractmethod
def write_texts(self, step: int, texts: Mapping[str, str]):
"""Writes text snippets for the step.
Warning: Not all MetricWriter implementations support writing text!
Args:
step: Step at which the text snippets occurred.
texts: Mapping from name to text snippet.
"""
@abc.abstractmethod
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
"""Writes histograms for the step.
Consecutive calls to this method can provide different sets of arrays.
Repeated writes for the same metric at the same step are not allowed.
Warning: Not all MetricWriter implementations support writing histograms!
Args:
step: Step at which the arrays were generated.
arrays: Mapping from name to arrays to summarize.
num_buckets: Optional mapping from array name to the number of buckets
used to create that array's histogram. The default number of buckets
depends on the particular implementation of the MetricWriter.
"""
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
"""Writes point cloud summaries.
Args:
step: Step at which the point cloud was generated.
point_clouds: Mapping from point cloud key to an [N, 3] array of point
coordinates.
point_colors: Mapping from point cloud key to an [N, 3] array of point
colors.
configs: A dictionary of configuration options for the point cloud.
"""
raise NotImplementedError()
@abc.abstractmethod
def write_hparams(self, hparams: Mapping[str, Any]):
"""Write hyperparameters.
Do not call twice.
Args:
hparams: Flat mapping from hyper parameter name to value.
"""
@abc.abstractmethod
def flush(self):
"""Tells the MetricWriter to write out any cached values."""
@abc.abstractmethod
def close(self):
"""Flushes and closes the MetricWriter.
Calling any method on MetricWriter after MetricWriter.close()
is undefined behavior.
"""
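To make the write contract above concrete, here is a minimal sketch of a hypothetical in-memory writer. It does not subclass `MetricWriter` (so it runs without CLU installed) and only models two documented rules: a metric may be written at most once per step, and image arrays may be either `[H, W, C]` or `[N, H, W, C]`. Raising after `close()` is a choice made by this sketch; the interface only says such calls are undefined behavior. `InMemoryScalarWriter` and `batched_image_shape` are illustrative names, not CLU APIs.

```python
from collections.abc import Mapping


class InMemoryScalarWriter:
  """Hypothetical writer (not part of CLU) that records scalars in a dict."""

  def __init__(self) -> None:
    self.values: dict[tuple[int, str], float] = {}  # (step, name) -> value
    self.closed = False

  def write_scalars(self, step: int, scalars: Mapping[str, float]) -> None:
    if self.closed:
      # The interface calls this undefined behavior; this sketch fails loudly.
      raise RuntimeError("write_scalars() called after close()")
    for name, value in scalars.items():
      # Different calls may use different names, but each (step, name)
      # pair may be written only once.
      if (step, name) in self.values:
        raise ValueError(f"{name!r} was already written at step {step}")
      self.values[(step, name)] = value

  def flush(self) -> None:
    pass  # Nothing is buffered, so there is nothing to write out.

  def close(self) -> None:
    self.flush()
    self.closed = True


def batched_image_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
  """Returns an image shape as [N, H, W, C], adding N=1 for [H, W, C]."""
  if len(shape) == 3:
    return (1, *shape)
  if len(shape) == 4:
    return tuple(shape)
  raise ValueError(f"expected [H, W, C] or [N, H, W, C], got {shape}")
```

The shape helper mirrors what `SummaryWriter.write_images` does with `value[None]` for 3-D inputs.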
================================================
FILE: clu/metric_writers/logging_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter that writes all values to INFO log."""
from collections.abc import Mapping
from typing import Any, Optional
from absl import logging
from clu.metric_writers import interface
import numpy as np
Array = interface.Array
Scalar = interface.Scalar
class LoggingWriter(interface.MetricWriter):
"""MetricWriter that writes all values to INFO log."""
def __init__(self, collection: Optional[str] = None):
if collection:
self._collection_str = f" collection={collection}"
else:
self._collection_str = ""
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
logging.info("[%d]%s Got raw tensors: %s.", step, self._collection_str,
{k: v.shape for k, v in values.items()})
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
values = [
f"{k}={v:.6g}" if isinstance(v, float) else f"{k}={v}"
for k, v in sorted(scalars.items())
]
logging.info("[%d]%s %s", step, self._collection_str, ", ".join(values))
def write_images(self, step: int, images: Mapping[str, Array]):
logging.info("[%d]%s Got images: %s.", step, self._collection_str,
{k: v.shape for k, v in images.items()})
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.info("[%d]%s Got videos: %s.", step, self._collection_str,
{k: v.shape for k, v in videos.items()})
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
logging.info("[%d]%s Got audios: %s.", step, self._collection_str,
{k: v.shape for k, v in audios.items()})
def write_texts(self, step: int, texts: Mapping[str, str]):
logging.info("[%d]%s Got texts: %s.", step, self._collection_str, texts)
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
num_buckets = num_buckets or {}
for key, value in arrays.items():
histo, bins = _compute_histogram_as_tf(
np.asarray(value), num_buckets=num_buckets.get(key))
if histo is not None:
logging.info("[%d]%s Histogram for %r = {%s}", step,
self._collection_str, key,
_get_histogram_as_string(histo, bins))
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Any] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
logging.info(
"[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
step,
self._collection_str,
{k: v.shape for k, v in point_clouds.items()},
(
{k: v.shape for k, v in point_colors.items()}
if point_colors is not None
else None
),
configs,
)
def write_hparams(self, hparams: Mapping[str, Any]):
logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)
def flush(self):
logging.flush()
def close(self):
self.flush()
def _compute_histogram_as_tf(
array: np.ndarray,
num_buckets: Optional[int] = None
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""Compute the histogram of the input array as TF would do.
Args:
array: Input data. The histogram is computed over the flattened array.
num_buckets: The number of equal-width bins used to create the histogram.
Returns:
histo: A numpy array with the values of the histogram.
bins: A numpy array with the bin edges (its length is len(histo) + 1).
If the histogram cannot be built because the array is empty, returns
(None, None).
"""
# See DEFAULT_BUCKET_COUNT in tensorboard/plugins/histogram/summary_v2.py
num_buckets = num_buckets or 30
if num_buckets < 2:
logging.log_first_n(logging.WARNING,
"num_buckets was automatically changed from %d to 2", 1,
num_buckets)
num_buckets = 2
if array.size == 0:
return None, None
range_max = np.max(array)
range_min = np.min(array)
if np.isclose(range_max, range_min, rtol=1e-5, atol=1e-8):
histo = np.asarray([array.size], dtype=np.int64)
bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64)
else:
histo, bins = np.histogram(
array, bins=num_buckets, range=(range_min, range_max))
bins = np.asarray(bins, dtype=np.float64)
return histo, bins
def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray):
# First items are right-open (i.e. [a, b)).
items = [
f"[{bins[i]:.3g}, {bins[i+1]:.3g}): {count}"
for i, count in enumerate(histo[:-1])
]
# Last item is right-closed (i.e. [a, b]).
items.append(f"[{bins[-2]:.3g}, {bins[-1]:.3g}]: {histo[-1]}")
return ", ".join(items)
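`_compute_histogram_as_tf` relies on `np.histogram`. The pure-Python sketch below approximates the same bucketing so the edge cases are easy to see: a default of 30 buckets, at least two buckets, a width-1 bucket centred on the value when the range collapses, and right-open buckets except the last. The names `equal_width_histogram` and `format_histogram` are hypothetical, and a value that lands exactly on an interior edge may be assigned differently than NumPy would because of floating-point rounding.

```python
from collections.abc import Sequence
from typing import Optional


def equal_width_histogram(values: Sequence[float],
                          num_buckets: Optional[int] = None):
  """Approximates _compute_histogram_as_tf without NumPy (sketch)."""
  num_buckets = num_buckets or 30  # Same default as the original.
  num_buckets = max(num_buckets, 2)  # At least two buckets.
  if not values:
    return None, None
  lo, hi = min(values), max(values)
  # Same tolerance test as np.isclose(hi, lo, rtol=1e-5, atol=1e-8).
  if abs(hi - lo) <= 1e-8 + 1e-5 * abs(lo):
    # Degenerate range: one bucket of width 1 around the value.
    return [len(values)], [hi - 0.5, hi + 0.5]
  width = (hi - lo) / num_buckets
  edges = [lo + i * width for i in range(num_buckets)] + [hi]
  counts = [0] * num_buckets
  for v in values:
    # Buckets are [a, b); the last bucket also receives the maximum value.
    counts[min(int((v - lo) / width), num_buckets - 1)] += 1
  return counts, edges


def format_histogram(counts, edges) -> str:
  """Same string layout as _get_histogram_as_string."""
  items = [f"[{edges[i]:.3g}, {edges[i + 1]:.3g}): {c}"
           for i, c in enumerate(counts[:-1])]
  items.append(f"[{edges[-2]:.3g}, {edges[-1]:.3g}]: {counts[-1]}")
  return ", ".join(items)
```

The first two test inputs reproduce the strings expected in `logging_writer_test.py`.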
================================================
FILE: clu/metric_writers/logging_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the LoggingWriter."""
from clu.metric_writers import logging_writer
import numpy as np
import tensorflow as tf
class LoggingWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.writer = logging_writer.LoggingWriter()
def test_write_scalars(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
self.writer.write_scalars(2, {"a": 0.0000005, "b": 0.007})
self.assertEqual(
logs.output,
["INFO:absl:[0] a=3, b=0.15", "INFO:absl:[2] a=5e-07, b=0.007"])
def test_write_images(self):
images = np.zeros((2, 28, 28, 3))
with self.assertLogs(level="INFO") as logs:
self.writer.write_images(4, {"input_images": images})
self.assertEqual(
logs.output,
["INFO:absl:[4] Got images: {'input_images': (2, 28, 28, 3)}."])
def test_write_videos(self):
videos = np.zeros((2, 4, 28, 28, 3))
with self.assertLogs(level="INFO") as logs:
self.writer.write_videos(4, {"input_videos": videos})
self.assertEqual(
logs.output,
["INFO:absl:[4] Got videos: {'input_videos': (2, 4, 28, 28, 3)}."])
def test_write_texts(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_texts(4, {"samples": "bla"})
self.assertEqual(
logs.output,
["INFO:absl:[4] Got texts: {'samples': 'bla'}."])
def test_write_histogram(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_histograms(
step=4,
arrays={
"a": np.asarray([-0.1, 0.1, 0.3]),
"b": np.arange(31),
"c": np.asarray([0.1, 0.1, 0.1, 0.1, 0.1]),
},
num_buckets={
"a": 2,
"c": 1
})
# Note: There are 31 distinct values [0, 1, ..., 30], and 30 buckets by
# default. Last bucket gets 2 values.
expected_histo_b = ", ".join([f"[{i}, {i + 1}): 1" for i in range(29)] +
["[29, 30]: 2"])
self.assertEqual(logs.output, [
"INFO:absl:[4] Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
f"INFO:absl:[4] Histogram for 'b' = {{{expected_histo_b}}}",
"WARNING:absl:num_buckets was automatically changed from 1 to 2",
"INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
])
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
with self.assertLogs(level="INFO") as logs:
self.writer.write_pointcloud(
step=4,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"configs": config},
)
self.assertEqual(
logs.output,
[
"INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
" point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
" {'material': 'PointCloudMaterial', 'size': 0.09}}."
],
)
def test_write_hparams(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})
self.assertEqual(logs.output, [
"INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size': 128}"
])
def test_collection(self):
writer = logging_writer.LoggingWriter(collection="train")
with self.assertLogs(level="INFO") as logs:
writer.write_scalars(0, {"a": 3, "b": 0.15})
writer.write_images(4, {"input_images": np.zeros((2, 28, 28, 3))})
writer.write_texts(4, {"samples": "bla"})
writer.write_histograms(
step=4,
arrays={
"a": np.asarray([-0.1, 0.1, 0.3]),
},
num_buckets={
"a": 2,
})
writer.write_hparams({"learning_rate": 0.1})
self.assertEqual(logs.output, [
"INFO:absl:[0] collection=train a=3, b=0.15",
"INFO:absl:[4] collection=train Got images: {'input_images': (2, 28, 28, 3)}.",
"INFO:absl:[4] collection=train Got texts: {'samples': 'bla'}.",
"INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
"INFO:absl:[Hyperparameters] collection=train {'learning_rate': 0.1}",
])
if __name__ == "__main__":
tf.test.main()
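The expected strings in `test_write_scalars` follow from the expression in `LoggingWriter.write_scalars`: keys are sorted, Python floats use the `.6g` format, and everything else goes through `str()`. A standalone copy of that expression, under the hypothetical name `format_scalars`:

```python
from collections.abc import Mapping


def format_scalars(scalars: Mapping[str, object]) -> str:
  """Builds the same message body as LoggingWriter.write_scalars (sketch)."""
  parts = [
      # Only instances of float get the compact .6g format.
      f"{k}={v:.6g}" if isinstance(v, float) else f"{k}={v}"
      for k, v in sorted(scalars.items())
  ]
  return ", ".join(parts)
```

Because the check is `isinstance(v, float)`, `np.float64` values (a `float` subclass) get `.6g` formatting, while `np.float32` values and arrays are printed via `str()` instead.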
================================================
FILE: clu/metric_writers/multi_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter that writes to multiple MetricWriters."""
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from clu.metric_writers import interface
Array = interface.Array
Scalar = interface.Scalar
class MultiWriter(interface.MetricWriter):
"""MetricWriter that writes to multiple writers at once."""
def __init__(self, writers: Sequence[interface.MetricWriter]):
self._writers = tuple(writers)
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
for w in self._writers:
w.write_summaries(step, values, metadata)
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
for w in self._writers:
w.write_scalars(step, scalars)
def write_images(self, step: int, images: Mapping[str, Array]):
for w in self._writers:
w.write_images(step, images)
def write_videos(self, step: int, videos: Mapping[str, Array]):
for w in self._writers:
w.write_videos(step, videos)
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
for w in self._writers:
w.write_audios(step, audios, sample_rate=sample_rate)
def write_texts(self, step: int, texts: Mapping[str, str]):
for w in self._writers:
w.write_texts(step, texts)
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
for w in self._writers:
w.write_histograms(step, arrays, num_buckets)
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
for w in self._writers:
w.write_pointcloud(
step, point_clouds, point_colors=point_colors, configs=configs
)
def write_hparams(self, hparams: Mapping[str, Any]):
for w in self._writers:
w.write_hparams(hparams)
def flush(self):
for w in self._writers:
w.flush()
def close(self):
for w in self._writers:
w.close()
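`MultiWriter` forwards each call to its writers sequentially, in the order they were passed in. The sketch below reproduces that fan-out pattern with hypothetical `RecordingWriter` and `FanOutWriter` classes (not part of CLU) so the call ordering can be checked without TensorFlow or mocks.

```python
from collections.abc import Mapping, Sequence


class RecordingWriter:
  """Hypothetical stand-in for a MetricWriter that records its calls."""

  def __init__(self) -> None:
    self.calls: list[tuple] = []

  def write_scalars(self, step: int, scalars: Mapping[str, float]) -> None:
    self.calls.append(("write_scalars", step, dict(scalars)))

  def flush(self) -> None:
    self.calls.append(("flush",))


class FanOutWriter:
  """Forwards every call, in order, to each wrapped writer."""

  def __init__(self, writers: Sequence[RecordingWriter]) -> None:
    # Copy into a tuple so later changes to the caller's list have no effect.
    self._writers = tuple(writers)

  def write_scalars(self, step: int, scalars: Mapping[str, float]) -> None:
    for w in self._writers:
      w.write_scalars(step, scalars)

  def flush(self) -> None:
    for w in self._writers:
      w.flush()
```

As in `MultiWriter`, the loop has no exception handling: if one writer raises, the writers after it never receive that call.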
================================================
FILE: clu/metric_writers/multi_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for MultiWriter."""
from unittest import mock
from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import numpy as np
import tensorflow as tf
class MultiWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.writers = [
mock.create_autospec(interface.MetricWriter),
mock.create_autospec(interface.MetricWriter)
]
self.writer = multi_writer.MultiWriter(self.writers)
def test_write_scalars(self):
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
self.writer.write_scalars(2, {"a": 5, "b": 0.007})
self.writer.flush()
for w in self.writers:
w.write_scalars.assert_has_calls([
mock.call(step=0, scalars={
"a": 3,
"b": 0.15
}),
mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})
])
w.flush.assert_called()
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
for w in self.writers:
w.write_pointcloud.assert_called_with(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
w.flush.assert_called()
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/metric_writers/summary_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter for writing to TF summary files."""
# pylint: disable=unused-import
from .tf.summary_writer import SummaryWriter
================================================
FILE: clu/metric_writers/tf/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package __init__ file."""
================================================
FILE: clu/metric_writers/tf/summary_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter for writing to TF summary files.
Only works in eager mode. Does not work for PyTorch code; use
TorchTensorboardWriter instead.
"""
from collections.abc import Mapping
from typing import Any, Optional
from absl import logging
from clu.internal import utils
from clu.metric_writers import interface
from etils import epy
import tensorflow as tf
with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from tensorboard.plugins.hparams import api as hparams_api
from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long
# pylint: enable=g-import-not-at-top
Array = interface.Array
Scalar = interface.Scalar
class SummaryWriter(interface.MetricWriter):
"""MetricWriter that writes TF summary files."""
def __init__(self, logdir: str):
super().__init__()
self._summary_writer = tf.summary.create_file_writer(logdir)
def write_summaries(
self,
step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None,
):
with self._summary_writer.as_default():
for key, value in values.items():
md = metadata.get(key) if metadata is not None else None
tf.summary.write(key, value, step=step, metadata=md)
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
with self._summary_writer.as_default():
for key, value in scalars.items():
tf.summary.scalar(key, value, step=step)
def write_images(self, step: int, images: Mapping[str, Array]):
with self._summary_writer.as_default():
for key, value in images.items():
if len(value.shape) == 3:
value = value[None]
tf.summary.image(key, value, step=step, max_outputs=value.shape[0])
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.log_first_n(
logging.WARNING,
"SummaryWriter does not support writing videos.", 1)
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
with self._summary_writer.as_default():
for key, value in audios.items():
tf.summary.audio(key, value, sample_rate=sample_rate, step=step,
max_outputs=value.shape[0])
def write_texts(self, step: int, texts: Mapping[str, str]):
with self._summary_writer.as_default():
for key, value in texts.items():
tf.summary.text(key, value, step=step)
def write_histograms(
self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None,
):
with self._summary_writer.as_default():
for key, value in arrays.items():
buckets = None if num_buckets is None else num_buckets.get(key)
tf.summary.histogram(key, value, step=step, buckets=buckets)
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
with self._summary_writer.as_default():
for key, vertices in point_clouds.items():
colors = None if point_colors is None else point_colors.get(key)
config = None if configs is None else configs.get(key)
mesh_summary.mesh(
key,
vertices=vertices,
colors=colors,
step=step,
config_dict=config,
)
def write_hparams(self, hparams: Mapping[str, Any]):
with self._summary_writer.as_default():
hparams_api.hparams(dict(utils.flatten_dict(hparams)))
def flush(self):
self._summary_writer.flush()
def close(self):
self._summary_writer.close()
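`SummaryWriter.write_hparams` flattens the config with `utils.flatten_dict` before handing it to the hparams plugin. That helper lives in `clu/internal/utils.py` and is not shown here, so the sketch below assumes only what `test_hparams_nested` checks: nested mappings and lists/tuples become dot-joined keys such as `"subconfig.list.0"`. `flatten_hparams` is a hypothetical name.

```python
from collections.abc import Mapping
from typing import Any


def flatten_hparams(value: Any, prefix: str = "") -> dict[str, Any]:
  """Flattens nested mappings and lists/tuples into dot-joined keys (sketch)."""
  if isinstance(value, Mapping):
    items = value.items()
  elif isinstance(value, (list, tuple)):
    items = enumerate(value)  # Sequence positions become "0", "1", ...
  else:
    return {prefix: value}  # Leaf value: strings are not split into characters.
  flat: dict[str, Any] = {}
  for key, child in items:
    name = f"{prefix}.{key}" if prefix else str(key)
    flat.update(flatten_hparams(child, name))
  return flat
```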
================================================
FILE: clu/metric_writers/tf/summary_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SummaryWriter."""
import collections
import os
from clu.metric_writers.tf import summary_writer
import numpy as np
import tensorflow as tf
from tensorboard.plugins.hparams import plugin_data_pb2
def _load_summaries_data(logdir):
"""Loads raw summaries data from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
metadata = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
data[event.step][value.tag] = tf.make_ndarray(value.tensor)
if value.HasField("metadata"):
metadata[event.step][value.tag] = value.metadata.SerializeToString()
return data, metadata
def _load_histograms_data(logdir):
"""Loads tensor summaries from events in a logdir."""
# Note: new versions of histograms don't use the HistogramProto type, but
# they are written as tensors representing the bounds and counts of buckets,
# with plugin_name = "histogram".
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = {}
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
current_steps, current_tensors = data.get(value.tag, ([], []))
data[value.tag] = (current_steps + [event.step],
current_tensors + [tf.make_ndarray(value.tensor)])
return {
tag: (np.stack(steps), np.stack(tensors))
for tag, (steps, tensors) in data.items()
}
def _load_scalars_data(logdir: str):
"""Loads scalar summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
data[event.step][value.tag] = tf.make_ndarray(value.tensor).flat[0]
return data
def _load_pointcloud_data(logdir: str):
"""Loads pointcloud summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
if value.metadata.plugin_data.plugin_name == "mesh":
if "config" not in value.tag:
data[event.step][value.tag] = tf.make_ndarray(value.tensor)
else:
data[event.step][value.tag] = value.metadata.plugin_data.content
return data
def _load_hparams(logdir: str):
"""Loads hparams summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
# data = collections.defaultdict(dict)
hparams = []
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
if value.metadata.plugin_data.plugin_name == "hparams":
hparams.append(plugin_data_pb2.HParamsPluginData.FromString(
value.metadata.plugin_data.content))
return hparams
class SummaryWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.logdir = self.get_temp_dir()
self.writer = summary_writer.SummaryWriter(self.logdir)
def test_write_summaries(self):
self.writer.write_summaries(
11,
{"a": np.eye(3, dtype=np.uint8),
"b": np.eye(2, dtype=np.float32)},
{"a": np.ones((2, 3)).tobytes()})
self.writer.flush()
data, metadata = _load_summaries_data(self.logdir)
self.assertAllClose(
data[11],
{"a": np.eye(3, dtype=np.uint8), "b": np.eye(2, dtype=np.float32)})
self.assertIn("a", metadata[11])
def test_write_scalar(self):
self.writer.write_scalars(11, {"a": 0.6, "b": 15})
self.writer.write_scalars(20, {"a": 0.8, "b": 12})
self.writer.flush()
data = _load_scalars_data(self.logdir)
self.assertAllClose(data[11], {"a": 0.6, "b": 15})
self.assertAllClose(data[20], {"a": 0.8, "b": 12})
def test_write_histograms(self):
self.writer.write_histograms(
0, {
"a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
"b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
}, num_buckets={"a": 2, "b": 2})
self.writer.write_histograms(
2, {
"a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
"b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
}, num_buckets={"a": 2, "b": 2})
self.writer.flush()
data = _load_histograms_data(self.logdir)
# In the histograms, each tuple represents
# (bucket_min, bucket_max, bucket_count), where bucket_min is inclusive and
# bucket_max is exclusive (except the last bucket_max which is inclusive).
expected_histograms_a = [
# Step 0.
[(0.1, 0.4, 3), (0.4, 0.7, 2)],
# Step 2.
[(-0.1, 0.2, 2), (0.2, 0.5, 3)],
]
self.assertAllClose(data["a"], ([0, 2], expected_histograms_a))
expected_histograms_b = [
# Step 0.
[(-0.1, 0.15, 1), (0.15, 0.4, 4)],
# Step 2.
[(0.0, 0.35, 4), (0.35, 0.7, 1)],
]
self.assertAllClose(data["b"], ([0, 2], expected_histograms_b))
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
data = _load_pointcloud_data(self.logdir)
self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds)
self.assertAllClose(data[0]["pcd_COLOR"], point_colors)
def test_hparams(self):
self.writer.write_hparams(dict(batch_size=512, num_epochs=90))
hparams = _load_hparams(self.logdir)
self.assertLen(hparams, 1)
hparams_dict = hparams[0].session_start_info.hparams
self.assertLen(hparams_dict, 2)
self.assertEqual(512, hparams_dict["batch_size"].number_value)
self.assertEqual(90, hparams_dict["num_epochs"].number_value)
def test_hparams_nested(self):
config = {
"list": [1, 2],
"tuple": (3, 4),
"subconfig": {
"value": "a",
"list": [10, 20],
},
}
self.writer.write_hparams(config)
hparams = _load_hparams(self.logdir)
self.assertLen(hparams, 1)
hparams_dict = hparams[0].session_start_info.hparams
self.assertLen(hparams_dict, 7)
self.assertEqual(1, hparams_dict["list.0"].number_value)
self.assertEqual(2, hparams_dict["list.1"].number_value)
self.assertEqual(3, hparams_dict["tuple.0"].number_value)
self.assertEqual(4, hparams_dict["tuple.1"].number_value)
self.assertEqual("a", hparams_dict["subconfig.value"].string_value)
self.assertEqual(10, hparams_dict["subconfig.list.0"].number_value)
self.assertEqual(20, hparams_dict["subconfig.list.1"].number_value)
if __name__ == "__main__":
tf.test.main()
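The comment in `test_write_histograms` describes each decoded histogram as rows of `(bucket_min, bucket_max, bucket_count)`, with every bucket right-open except the last. The sketch below counts values into such rows with `bisect`. The bucket edges are passed in by hand to match the expected values in the test; this is not how TensorBoard computes its edges, and `histogram_rows` is a hypothetical name.

```python
import bisect
from collections.abc import Sequence


def histogram_rows(values: Sequence[float], edges: Sequence[float]):
  """Returns [(bucket_min, bucket_max, count), ...] for the given edges."""
  counts = [0] * (len(edges) - 1)
  for v in values:
    if v < edges[0] or v > edges[-1]:
      continue  # Outside the histogram range; this sketch drops the value.
    # bisect_right puts a value equal to an edge into the bucket starting
    # there, i.e. buckets are [min, max); min() folds the top edge into the
    # last bucket, which makes that bucket closed on the right.
    idx = min(bisect.bisect_right(edges, v) - 1, len(counts) - 1)
    counts[idx] += 1
  return [(edges[i], edges[i + 1], c) for i, c in enumerate(counts)]
```

With the step-0 and step-2 inputs for `"a"` and the edges from the expected output, it yields the same rows the test asserts.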
================================================
FILE: clu/metric_writers/torch_tensorboard_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter for PyTorch summary files.
Use this writer for PyTorch-based code.
"""
from collections.abc import Mapping
from typing import Any, Optional
from absl import logging
from clu.metric_writers import interface
from torch.utils import tensorboard
Array = interface.Array
Scalar = interface.Scalar
class TorchTensorboardWriter(interface.MetricWriter):
"""MetricWriter that writes Pytorch summary files."""
def __init__(self, logdir: str):
super().__init__()
self._writer = tensorboard.SummaryWriter(log_dir=logdir)
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
logging.log_first_n(
logging.WARNING,
"TorchTensorboardWriter does not support writing raw summaries.", 1)
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
for key, value in scalars.items():
self._writer.add_scalar(key, value, global_step=step)
def write_images(self, step: int, images: Mapping[str, Array]):
for key, value in images.items():
self._writer.add_image(key, value, global_step=step, dataformats="HWC")
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.log_first_n(
logging.WARNING,
"TorchTensorboardWriter does not support writing videos.", 1)
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
for key, value in audios.items():
self._writer.add_audio(
key, value, global_step=step, sample_rate=sample_rate)
def write_texts(self, step: int, texts: Mapping[str, str]):
raise NotImplementedError(
"TorchTensorboardWriter does not support writing texts."
)
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
for tag, values in arrays.items():
bins = None if num_buckets is None else num_buckets.get(tag)
self._writer.add_histogram(
tag, values, global_step=step, bins="auto", max_bins=bins)
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
logging.log_first_n(
logging.WARNING,
"TorchTensorboardWriter does not support writing point clouds.",
1,
)
def write_hparams(self, hparams: Mapping[str, Any]):
self._writer.add_hparams(hparams, {})
def flush(self):
self._writer.flush()
def close(self):
self._writer.close()
================================================
FILE: clu/metric_writers/torch_tensorboard_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TorchTensorboardWriter."""
import collections
import os
from typing import Any, Dict
from clu.metric_writers import torch_tensorboard_writer
import numpy as np
import tensorflow as tf
def _load_scalars_data(logdir: str):
"""Loads scalar summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
data[event.step][value.tag] = value.simple_value
return data
def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]:
"""Loads histogram summaries from events in a logdir.
Args:
logdir: Directory containing the event files.
Returns:
A nested mapping of the form step -> tag -> histogram proto.
"""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = {}
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
if event.step not in data:
data[event.step] = {}
step_data = {}
for value in event.summary.value:
step_data[value.tag] = value.histo
data[event.step].update(step_data)
return data
class TorchTensorboardWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.logdir = self.get_temp_dir()
self.writer = torch_tensorboard_writer.TorchTensorboardWriter(self.logdir)
def test_write_scalar(self):
self.writer.write_scalars(11, {"a": 0.6, "b": 15})
self.writer.write_scalars(20, {"a": 0.8, "b": 12})
self.writer.flush()
data = _load_scalars_data(self.logdir)
self.assertAllClose(data[11], {"a": 0.6, "b": 15})
self.assertAllClose(data[20], {"a": 0.8, "b": 12})
def test_write_histograms(self):
self.writer.write_histograms(
0, {
"a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
"b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
}, num_buckets={"a": 2, "b": 2})
self.writer.write_histograms(
2, {
"a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
"b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
}, num_buckets={"a": 2, "b": 2})
self.writer.flush()
data = _load_histograms_data(self.logdir)
self.assertNear(data[0]["a"].min, 0.1, 0.001)
self.assertNear(data[0]["a"].max, 0.7, 0.001)
self.assertNear(data[0]["b"].min, -0.1, 0.001)
self.assertNear(data[0]["b"].max, 0.4, 0.001)
self.assertNear(data[2]["a"].min, -0.1, 0.001)
self.assertNear(data[2]["a"].max, 0.5, 0.001)
self.assertNear(data[2]["b"].min, 0.0, 0.001)
self.assertNear(data[2]["b"].max, 0.7, 0.001)
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/metric_writers/utils.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a generic write interface.
The write helper accepts a MetricWriter object and a Mapping[str,
clu.metrics.Metric], and automatically writes to the appropriate typed write
method of the writer depending on the type of the metric.
"""
# pylint: disable=g-importing-member
import collections
import getpass
import os
import re
from typing import Any, List, Mapping, Optional, Tuple, Union
from absl import flags
from absl import logging
from clu import values
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from etils import epath
import jax.numpy as jnp
import numpy as np
FLAGS = flags.FLAGS
def _is_scalar(value: Any) -> bool:
if isinstance(value, values.Scalar) or isinstance(
value, (int, float, np.number)
):
return True
if isinstance(value, (np.ndarray, jnp.ndarray)):
return value.ndim == 0 or value.size <= 1
return False
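`_is_scalar` treats Python numbers, NumPy scalars, and any array with at most one element as a scalar, so a shape-`(1,)` array is routed to `write_scalars`. A NumPy-only sketch of the same rule, assuming we can drop the `clu.values.Scalar` and `jnp.ndarray` branches (the name `is_scalar_like` is hypothetical):

```python
import numpy as np


def is_scalar_like(value) -> bool:
  """Mirrors the non-JAX branches of `_is_scalar` (sketch)."""
  if isinstance(value, (int, float, np.number)):
    return True
  if isinstance(value, np.ndarray):
    # 0-d arrays and one-element arrays such as shape (1,) or (1, 1) count.
    return value.ndim == 0 or value.size <= 1
  return False


print(is_scalar_like(3.0), is_scalar_like(np.array([5])), is_scalar_like(np.zeros(3)))
```

Note that under this rule an empty array (size 0) also counts as scalar-like.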
def write_values(
writer: MetricWriter,
step: int,
metrics: Mapping[
str, Union[values.Value, values.ArrayType, values.ScalarType]
],
):
"""Writes all provided metrics.
Allows providing a mapping of name to Value object, where each Value
specifies a type. The appropriate write method can then be called depending
on the type.
Args:
writer: MetricWriter object
step: Step at which the arrays were generated.
metrics: Mapping from name to a clu.values.Value object or to a raw scalar
(a Python/NumPy number or an array with at most one element).
"""
writes = collections.defaultdict(dict)
histogram_num_buckets = collections.defaultdict(int)
for k, v in metrics.items():
if isinstance(v, values.Summary):
writes[
(writer.write_summaries, frozenset({"metadata": v.metadata}.items()))
][k] = v.value
elif _is_scalar(v):
if isinstance(v, values.Scalar):
writes[(writer.write_scalars, frozenset())][k] = v.value
else:
writes[(writer.write_scalars, frozenset())][k] = v
elif isinstance(v, values.Image):
writes[(writer.write_images, frozenset())][k] = v.value
elif isinstance(v, values.Text):
writes[(writer.write_texts, frozenset())][k] = v.value
elif isinstance(v, values.HyperParam):
writes[(writer.write_hparams, frozenset())][k] = v.value
elif isinstance(v, values.Histogram):
writes[(writer.write_histograms, frozenset())][k] = v.value
histogram_num_buckets[k] = v.num_buckets
elif isinstance(v, values.Audio):
writes[(
writer.write_audios,
frozenset({"sample_rate": v.sample_rate}.items()),
)][k] = v.value
else:
raise ValueError(f"Metric {k!r} has unsupported value: {v!r}")
for (fn, extra_args), vals in writes.items():
if fn == writer.write_histograms:
# for write_histograms, the num_buckets arg is a Dict indexed by name
writer.write_histograms(step, vals, num_buckets=histogram_num_buckets)
else:
fn(step, vals, **dict(extra_args))
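`write_values` batches metrics under a key of (write method, frozen extra keyword arguments). Metrics that share a method and identical extra arguments are written in one call, while, for example, two audios with different sample rates produce two separate `write_audios` calls. A standalone sketch of that grouping, where the `(method name, kwargs, value)` tuples are a hypothetical stand-in for `clu.values` objects:

```python
import collections
from typing import Any


def group_writes(metrics: dict[str, tuple[str, dict[str, Any], Any]]):
  """Groups metrics by (write method name, frozen extra kwargs) (sketch)."""
  writes = collections.defaultdict(dict)
  for name, (method, extra_kwargs, value) in metrics.items():
    # frozenset makes the kwargs hashable so they can be part of the key.
    writes[(method, frozenset(extra_kwargs.items()))][name] = value
  return dict(writes)


metrics = {
    "loss": ("write_scalars", {}, 0.5),
    "acc": ("write_scalars", {}, 0.9),
    "audio": ("write_audios", {"sample_rate": 10}, [1, 5]),
    "audio2": ("write_audios", {"sample_rate": 12}, [1, 5]),
}
for (method, extra), vals in group_writes(metrics).items():
  print(method, dict(extra), vals)
```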
def create_default_writer(
logdir: Optional[epath.PathLike] = None,
*,
just_logging: bool = False,
asynchronous: bool = True,
collection: Optional[str] = None,
) -> MultiWriter:
"""Creates the default writer for the platform.
On most platforms this will create a MultiWriter that writes to multiple back
ends (logging, TF summaries etc.).
Args:
logdir: Logging dir to use for TF summary files. If None, the returned
writer will not write TF summary files.
just_logging: If True only use a LoggingWriter. This is useful in multi-host
setups when only the first host should write metrics and all other hosts
should only write to their own logs.
asynchronous: If True return an AsyncMultiWriter to not block when writing
metrics.
collection: If provided, all metrics are written to this collection (or
grouping). It is passed to the LoggingWriter and, when `logdir` is set,
appended to `logdir` as a subdirectory for the TF summary files.
Returns:
A `MetricWriter` according to the platform and arguments.
"""
if just_logging:
if asynchronous:
return AsyncMultiWriter([LoggingWriter(collection=collection)])
else:
return MultiWriter([LoggingWriter(collection=collection)])
writers = [LoggingWriter(collection=collection)]
if logdir is not None:
logdir = epath.Path(logdir)
if collection is not None:
logdir /= collection
writers.append(SummaryWriter(os.fspath(logdir)))
if asynchronous:
return AsyncMultiWriter(writers)
return MultiWriter(writers)
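Unless `just_logging` is set, the summary-file directory is `logdir`, with `collection` appended as a subdirectory when given. A minimal sketch of that path logic, using the standard-library `pathlib` in place of `etils.epath` (a substitution made so the example runs without `etils`; `summary_dir` is a hypothetical helper):

```python
import os
import pathlib
from typing import Optional


def summary_dir(logdir: Optional[str], collection: Optional[str]) -> Optional[str]:
  """Returns the directory that would receive TF summary files (sketch)."""
  if logdir is None:
    return None  # Only the LoggingWriter would be created.
  path = pathlib.Path(logdir)
  if collection is not None:
    path /= collection  # One subdirectory per collection.
  return os.fspath(path)


print(summary_dir("/tmp/run", "eval"))
```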
================================================
FILE: clu/metric_writers/utils_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for metric_writers.utils."""
# pylint: disable=g-importing-member
import itertools
from typing import Any
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from clu import values
from clu.metric_writers import utils
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
import clu.metrics
import flax.struct
import jax.numpy as jnp
import tensorflow as tf
@flax.struct.dataclass
class HistogramMetric(clu.metrics.Metric):
value: jnp.ndarray
num_buckets: int
def compute_value(self):
return values.Histogram(self.value, self.num_buckets)
@flax.struct.dataclass
class ImageMetric(clu.metrics.Metric):
value: jnp.ndarray
def compute_value(self):
return values.Image(self.value)
@flax.struct.dataclass
class AudioMetric(clu.metrics.Metric):
value: jnp.ndarray
sample_rate: int
def compute_value(self):
return values.Audio(self.value, self.sample_rate)
@flax.struct.dataclass
class TextMetric(clu.metrics.Metric):
value: str
def compute_value(self):
return values.Text(self.value)
@flax.struct.dataclass
class HyperParamMetric(clu.metrics.Metric):
value: float
def compute_value(self):
return values.HyperParam(self.value)
@flax.struct.dataclass
class SummaryMetric(clu.metrics.Metric):
value: jnp.ndarray
metadata: Any
def compute_value(self):
return values.Summary(self.value, self.metadata)
def _to_summary(metrics):
return {k: v.value for k, v in metrics.items()}
def _to_list_of_dicts(d):
return [{k: v} for k, v in d.items()]
class ONEOF(object):
"""ONEOF(options_list) matches any value contained in options_list."""
def __init__(self, container):
if not hasattr(container, "__contains__"):
raise TypeError(f"{container!r} is not a container")
if not container:
raise ValueError(f"{container!r} is empty")
self._c = container
def __eq__(self, o):
return o in self._c
def __ne__(self, o):
return o not in self._c
def __repr__(self):
return "<ONEOF({})>".format(",".join(repr(i) for i in self._c))
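`ONEOF` works because `unittest.mock` compares recorded call arguments with `==`, so any object with a custom `__eq__` can act as a matcher. A standalone sketch of the same idea (the class name `OneOf` is hypothetical) that lets an assertion accept any of several payloads:

```python
from unittest import mock


class OneOf:
  """Equality matcher that accepts any value found in `options` (sketch)."""

  def __init__(self, options):
    self._options = options

  def __eq__(self, other):
    return other in self._options

  def __ne__(self, other):
    return other not in self._options

  def __repr__(self):
    return f"<OneOf({self._options!r})>"


writer = mock.Mock()
writer.write_audios(3, {"audio2": [1, 5]}, sample_rate=12)
# Passes: the recorded payload and sample rate are each among the options.
writer.write_audios.assert_called_with(
    3,
    OneOf([{"audio": [1, 5]}, {"audio2": [1, 5]}]),
    sample_rate=OneOf([10, 12]),
)
```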
class MetricWriterTest(tf.test.TestCase, parameterized.TestCase):
def test_write(self):
writer = mock.Mock(spec_set=MetricWriter)
step = 3
num_buckets = 4
sample_rate = 10
scalar_metrics = {
"loss": clu.metrics.Average.from_model_output(jnp.asarray([1, 2, 3])),
"accuracy": clu.metrics.LastValue.from_model_output(jnp.asarray([5])),
}
image_metrics = {
"image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])),
}
histogram_metrics = {
"hist": HistogramMetric(
value=jnp.asarray([7, 8]), num_buckets=num_buckets
),
"hist2": HistogramMetric(
value=jnp.asarray([9, 10]), num_buckets=num_buckets
),
}
audio_metrics = {
"audio": AudioMetric(
value=jnp.asarray([1, 5]), sample_rate=sample_rate
),
"audio2": AudioMetric(
value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2
),
}
text_metrics = {
"text": TextMetric(value="hello"),
}
hparam_metrics = {
"lr": HyperParamMetric(value=0.01),
}
summary_metrics = {
"summary": SummaryMetric(
value=jnp.asarray([2, 3, 10]), metadata="some info"
),
"summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
}
metrics = {
**scalar_metrics,
**image_metrics,
**histogram_metrics,
**audio_metrics,
**text_metrics,
**hparam_metrics,
**summary_metrics,
}
metrics = {k: m.compute_value() for k, m in metrics.items()}
utils.write_values(writer, step, metrics)
writer.write_scalars.assert_called_once_with(
step, {k: m.compute() for k, m in scalar_metrics.items()}
)
writer.write_images.assert_called_once_with(
step, _to_summary(image_metrics)
)
writer.write_histograms.assert_called_once_with(
step,
_to_summary(histogram_metrics),
num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()},
)
writer.write_audios.assert_called_with(
step,
ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))),
sample_rate=ONEOF([sample_rate, sample_rate + 2]),
)
writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics))
writer.write_hparams.assert_called_once_with(
step, _to_summary(hparam_metrics)
)
writer.write_summaries.assert_called_with(
step,
ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))),
metadata=ONEOF(["some info", 5]),
)
def test_create_default_writer_summary_writer_is_added(self):
writer = utils.create_default_writer(
logdir=self.get_temp_dir(), asynchronous=False
)
self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers))
writer = utils.create_default_writer(logdir=None, asynchronous=False)
self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))
if __name__ == "__main__":
absltest.main()
================================================
FILE: clu/metrics.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functional metric computation library.
This library defines a functional metric computation interface `Metric` that
relies on metrics accumulating intermediate values (in a possibly distributed
manner), and then computes the final metric value from these intermediate
values. Note that most metrics can be created via `Average.from_fun()`. See also
`CollectingMetric` that collects all model outputs with a given name and lends
itself to metric computation in Python.
Some common metrics, such as accuracy and loss average/standard deviation, and a
`Collection` with the same interface, are also provided.
The "model output" is a dictionary of values with unique keys that all have a
specific meaning (such as `loss`, `logits`, and `labels`) and every metric
depends on at least one such model output by name. These outputs are usually
expected to be instances of `jnp.ndarray`.
Synopsis:
# Note: Metrics do *not* work with `from __future__ import annotations`
from clu import metrics
import flax
import jax
@flax.struct.dataclass # required for jax.tree_*
class MyCollection(metrics.Collection):
accuracy: metrics.Accuracy
loss: metrics.Average.from_output("loss")
loss_std: metrics.Std.from_output("loss")
@jax.pmap
def eval_step(variables, metrics, inputs, labels):
loss, logits = get_loss_and_logits(variables, inputs, labels)
return metrics.merge(MyCollection.gather_from_model_output(
loss=loss, logits=logits, labels=labels))
def evaluate(variables_p, test_ds):
metrics = flax.jax_utils.replicate(MyCollection.empty())
for inputs, labels in test_ds:
metrics = eval_step(variables_p, metrics, inputs, labels)
return metrics.unreplicate().compute()
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
import inspect
from typing import Any, TypeVar, Protocol
from absl import logging
from clu.internal import utils
import clu.values
import flax
import jax
import jax.numpy as jnp
import numpy as np
Array = jax.Array
ArrayLike = jax.typing.ArrayLike
class FromFunCallable(Protocol):
"""The type of functions that can be passed to `Metric.from_fun()`."""
def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
"""Returns the argument/arguments passed to the base from_model_output()."""
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
"""Raises a `ValueError` if shapes of `a` and `b` don't match."""
if a.shape != b.shape:
raise ValueError(f"Expected same shape: {a.shape} != {b.shape}")
M = TypeVar("M", bound="Metric")
class Metric:
"""Interface for computing metrics from intermediate values.
Refer to `Collection` for computing multiple metrics at the same time.
Synopsis:
import jax.numpy as jnp
import flax
@flax.struct.dataclass
class Average(Metric):
total: jnp.ndarray
count: jnp.ndarray
@classmethod
def from_model_output(cls, value: jnp.ndarray, **_) -> Metric:
return cls(total=value.sum(), count=value.size)
def merge(self, other: Metric) -> Metric:
return type(self)(
total=self.total + other.total,
count=self.count + other.count,
)
def compute(self):
return self.total / self.count
average = None
for value in data:
update = Average.from_model_output(value)
average = update if average is None else average.merge(update)
print(average.compute())
"""
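The synopsis above relies on `flax.struct.dataclass` and JAX arrays. The framework-free sketch below swaps in a frozen `dataclasses.dataclass` over NumPy values (an assumption made so it runs without JAX) to walk through the same contract: `empty()` is the identity for `merge()`, per-batch updates are merged, and `compute()` runs once at the end.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class Average:
  """Plain-NumPy stand-in for a CLU-style averaging metric (sketch)."""
  total: float
  count: int

  @classmethod
  def empty(cls) -> "Average":
    return cls(total=0.0, count=0)

  @classmethod
  def from_model_output(cls, value: np.ndarray) -> "Average":
    value = np.asarray(value, dtype=np.float64)
    return cls(total=float(value.sum()), count=int(value.size))

  def merge(self, other: "Average") -> "Average":
    return Average(total=self.total + other.total, count=self.count + other.count)

  def compute(self) -> float:
    return self.total / self.count


average = Average.empty()
for batch in ([1.0, 2.0], [3.0, 4.0, 5.0]):
  average = average.merge(Average.from_model_output(np.array(batch)))
print(average.compute())  # 3.0
```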
@classmethod
def from_model_output(cls: type[M], *args, **kwargs) -> M:
"""Creates a `Metric` from model outputs."""
raise NotImplementedError("Must override from_model_output()")
def merge(self: M, other: M) -> M:
"""Returns `Metric` that is the accumulation of `self` and `other`.
Args:
other: A `Metric` whose intermediate values should be accumulated onto the
values of `self`. Note that in a distributed setting, `other` will
typically be the output of a `jax.lax` parallel operator and thus have a
dimension added to the dataclass returned by `.from_model_output()`.
Returns:
A new `Metric` that accumulates the value from both `self` and `other`.
"""
raise NotImplementedError("Must override merge()")
# The variant of `merge()` called inside `reduce()`. While `merge()` and
# `_reduce_merge()` will be the same in many cases, there are exceptions:
# see `LastValue` for an example of a `Metric` which aggregates values
# differently over training steps compared to how it aggregates them over
# accelerators.
#
# `_reduce_merge()` must be associative[1], otherwise we would get
# different results when using different devices.
# [1] https://en.wikipedia.org/wiki/Associative_property
def _reduce_merge(self: M, other: M) -> M:
return self.merge(other)
def compute(self) -> jnp.ndarray:
"""Computes final metrics from intermediate values."""
raise NotImplementedError("Must override compute()")
@classmethod
def empty(cls: type[M]) -> M:
"""Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
raise NotImplementedError("Must override empty()")
def compute_value(self) -> clu.values.Value:
"""Wraps compute() and returns a values.Value."""
return clu.values.Scalar(self.compute())
def reduce(self: M) -> M:
"""Reduces the metric along its first axis by calling `_reduce_merge()`.
This function's primary use case is to aggregate metrics collected across
multiple devices, rather than "merging" metrics across multiple steps.
In many cases these have the same semantics (such as `Average`), but in
some they differ: for `LastValue`, reduction across devices averages the
per-device values, while merging across steps keeps the last value.
See `Collection.reduce` for usage patterns.
Returns:
reduced metric.
"""
def reduce_step(reduced: M, metric: M) -> tuple[M, None]:
# pylint: disable-next=protected-access
return reduced._reduce_merge(metric), None
# Avoid degraded performance under the new jax.pmap.
# Only use the sharding path for concrete sharded arrays, not tracers.
def _is_concrete_sharded(x):
if isinstance(x, jax.core.Tracer):
return False
if not hasattr(x, "addressable_shards"):
return False
shards = x.addressable_shards
if not shards:
return False
# Only use sharding path when shards have shape (1, ...) from pmap
return shards[0].data.ndim > 0 and shards[0].data.shape[0] == 1
leaves = jax.tree_util.tree_leaves(self)
use_sharding_path = leaves and _is_concrete_sharded(leaves[0])
if use_sharding_path:
def get_first(x):
return x.addressable_shards[0].data.squeeze(0)
def get_remainder(x):
shards = x.addressable_shards
if len(shards) <= 1:
shape = shards[0].data.squeeze(0).shape
return jnp.empty((0,) + shape, dtype=shards[0].data.dtype)
return jnp.stack([s.data.squeeze(0) for s in shards[1:]], axis=0)
first = jax.tree_util.tree_map(get_first, self)
remainder = jax.tree_util.tree_map(get_remainder, self)
else:
first = jax.tree_util.tree_map(lambda x: x[0], self)
remainder = jax.tree_util.tree_map(lambda x: x[1:], self)
# According to b/160868467#comment4, usage of `jax.lax.scan` does not add a
# significant computational cost for simple metrics where e.g. `jnp.sum`
# could be used instead.
return jax.lax.scan(reduce_step, first, remainder)[0] # pytype: disable=wrong-arg-types # lax-types
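`reduce()` splits every leaf into the slice at index 0 and the remainder, then folds the remainder into the first slice with `_reduce_merge()`. The same left fold is sketched below with `functools.reduce` over NumPy arrays in place of `jax.lax.scan` (a substitution for illustration; `reduce_leading_axis` is a hypothetical helper). Because the merge here is associative, the result does not depend on how the device axis is ordered or grouped.

```python
import functools

import numpy as np


def reduce_leading_axis(totals: np.ndarray, counts: np.ndarray):
  """Folds per-device (total, count) pairs along axis 0 (sketch)."""
  def merge(acc, item):
    return (acc[0] + item[0], acc[1] + item[1])
  first = (totals[0], counts[0])
  remainder = zip(totals[1:], counts[1:])
  return functools.reduce(merge, remainder, first)


# Four "devices", each holding a partial sum and a count.
totals = np.array([2.0, 4.0, 6.0, 8.0])
counts = np.array([1, 2, 3, 4])
total, count = reduce_leading_axis(totals, counts)
print(total / count)  # 2.0
```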
@classmethod
def from_fun(cls, fun: FromFunCallable): # No way to annotate return type
"""Calls `cls.from_model_output` with the return value from `fun`.
Returns a `Metric` derived from `cls` whose `.from_model_output` (1) calls
`fun` with keyword arguments from `model_output` and (2) supplies the output
of `fun` to `cls.from_model_output`.
If the return value of `fun` is a `Mapping`, then it will be expanded to
create keyword arguments for `cls.from_model_output`. Otherwise, the output
of `fun` is supplied as a single argument to `cls.from_model_output`.
Note that the model output "mask" will also be forwarded to the metric, but
only if it has the same first dimension as the value returned by `fun` (or
the first value in the `Mapping` returned by `fun`). This allows metrics
created by this function to be used both with values that exist per-example,
as well as with values that only exist per batch.
NOTE: If `fun` returns a `Mapping` with key "mask", then this mask will
override a "mask" key passed to `from_model_output`. This allows
`fun` to read custom mask fields from `model_output`.
Example:
```
def get_head1(head1_loss, head1_mask, **_):
return dict(loss=head1_loss, mask=head1_mask)
@flax.struct.dataclass
class MultiHeadMetrics(metrics.Collection):
head1_loss: metrics.Average.from_output("loss").from_fun(get_head1)
...
ms = MultiHeadMetrics.single_from_model_output(
head1_loss=..., head1_mask=..., ...)
```
Args:
fun: Function to be applied to model output.
Returns:
A `Metric` derived from `cls` that calls `.from_model_output()` with
the output returned by `fun` when called with keyword arguments from
`model_output`.
"""
@flax.struct.dataclass
class FromFun(cls):
"""Wrapper Metric class that collects output after applying `fun`."""
@classmethod
def from_model_output(cls: type[M], **model_output) -> M:
mask = model_output.get("mask")
output = fun(**model_output)
if isinstance(output, Mapping) and "mask" in output:
output = dict(output)
# pop mask to avoid multiple arg error later.
output_mask = output.pop("mask", None)
mask = output_mask
# Ignore the mask if its first dimension doesn't match that of the
# output of `fun`.
if mask is not None:
if isinstance(output, Mapping):
first_output = next(iter(output.values()))
else:
first_output = output
if (first_output.shape or [0])[0] != mask.shape[0]:
logging.warning(
"Ignoring mask for fun(**model output) because of shape "
"mismatch: output.shape=%s vs. mask.shape=%s",
first_output.shape, mask.shape)
mask = None
if isinstance(output, Mapping):
return super().from_model_output(**output, mask=mask)
else:
return super().from_model_output(output, mask=mask)
return FromFun
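The wrapper returned by `from_fun` applies `fun`, lets a `"mask"` key in a returned `Mapping` override the incoming mask, drops the mask on a first-dimension mismatch, and then forwards either keyword arguments or a single positional value. A framework-free sketch of that control flow, where `call_with_fun` and `masked_mean` are hypothetical and `get_head1` comes from the docstring example:

```python
from typing import Any, Callable, Mapping, Optional

import numpy as np


def call_with_fun(fun: Callable[..., Any], target: Callable[..., Any], **model_output):
  """Mirrors FromFun.from_model_output: apply `fun`, pick a mask, forward (sketch)."""
  mask: Optional[np.ndarray] = model_output.get("mask")
  output = fun(**model_output)
  if isinstance(output, Mapping) and "mask" in output:
    output = dict(output)
    mask = output.pop("mask")  # A mask returned by `fun` wins.
  if mask is not None:
    first = next(iter(output.values())) if isinstance(output, Mapping) else output
    if (np.shape(first) or (0,))[0] != mask.shape[0]:
      mask = None  # First-dimension mismatch: ignore the mask.
  if isinstance(output, Mapping):
    return target(**output, mask=mask)
  return target(output, mask=mask)


def get_head1(head1_loss, head1_mask, **_):
  return dict(loss=head1_loss, mask=head1_mask)


def masked_mean(loss, mask=None):
  return float(np.mean(loss) if mask is None else np.sum(loss * mask) / np.sum(mask))


result = call_with_fun(get_head1, masked_mean,
                       head1_loss=np.array([1.0, 3.0, 100.0]),
                       head1_mask=np.array([1.0, 1.0, 0.0]))
print(result)  # 2.0
```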
@classmethod
def from_output(cls, name: str): # No way to annotate return type
"""Calls `cls.from_model_output` with model output named `name`.
Synopsis:
@flax.struct.dataclass
class Metrics(Collection):
loss: Average.from_output('loss')
Note that the model output "mask" will also be forwarded to the metric, but
only if it has the same first dimension as the model output specified by
`name`. This allows metrics created by this function to be used both with
named outputs that exist per-example and with model outputs that only exist
per batch (as "loss" often does).
Args:
name: Name of the model output that should be passed as first argument to
`cls.from_model_output()`.
Returns:
A `Metric` derived from `cls` that calls `.from_model_output()` with the
model output specified by `name` as its first argument.
"""
@flax.struct.dataclass
class FromOutput(cls):
"""Wrapper Metric class that collects output named `name`."""
@classmethod
def from_model_output(cls: type[M], **model_output) -> M:
output = jnp.array(model_output[name])
mask = model_output.get("mask")
if mask is not None and (output.shape or [0])[0] != mask.shape[0]:
logging.warning(
"Ignoring mask for model output '%s' because of shape mismatch: "
"output.shape=%s vs. mask.shape=%s", name, output.shape,
mask.shape)
mask = None
return super().from_model_output(output, mask=mask)
return FromOutput
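The shape check in `from_output` is what lets a per-batch scalar such as `loss` coexist with a per-example `mask`: a 0-d output falls back to a first dimension of 0, which never matches a non-empty mask, so the mask is dropped. A NumPy sketch of just that check (`effective_mask` is a hypothetical name):

```python
import numpy as np


def effective_mask(output, mask):
  """Returns `mask` if its first dim matches `output`, else None (sketch)."""
  output = np.asarray(output)
  # A 0-d output has shape (); fall back to a first dimension of 0.
  if mask is not None and (output.shape or (0,))[0] != mask.shape[0]:
    return None
  return mask


mask = np.array([1, 1, 0])
print(effective_mask(np.array([0.1, 0.2, 0.3]), mask))  # per-example output: mask kept
print(effective_mask(np.float32(0.5), mask))            # per-batch scalar: mask dropped
```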
@flax.struct.dataclass
class CollectingMetric(Metric):
"""A special metric that collects model outputs.
This metric can NOT be used inside JIT-compiled eval steps (like the pattern
described in the pydoc of this module). Instead, you will need to call
`.merge()` in the Python evaluation loop that calls the compiled evaluation
step. Metric accumulation happens on the host memory. For an efficient use
of this metric that interleaves JAX computation with Python execution, see the
async snippet below.
This metric transfers arrays to host memory (converting to `np.ndarray`) for
later use in computations on CPU. The references to individual arrays are
stored in tuples, and a final call to `.compute()` concatenates these arrays.
If not needed, this final copy can be avoided by overriding `.compute()`.
Note though that these metrics use much more memory and compute somewhat more
slowly.
Also note that `mask` output is not applied automatically. Rather it should
be collected and used in the final computation from the collected data.
Example: computing average precision with `sklearn`:
@flax.struct.dataclass
class AveragePrecision(
metrics.CollectingMetric.from_outputs(("labels", "logits"))):
def compute(self):
values = super().compute()
return sklearn.metrics.average_precision_score(
values["labels"], values["logits"][:, 1])
Note that this metric causes a sync barrier when the data is transferred to
the host. But this can be avoided by using `asynclib`:
from clu import asynclib
def evaluate(params):
pool = asynclib.Pool()
@pool
def copy_to_host(update):
return jax.tree_util.tree_map(np.asarray, update)
futures = []
for batch in eval_ds:
futures.append(copy_to_host(eval_step(params, batch)))
ms = MyCollection.empty()
for future in futures:
ms = ms.merge(future.result())
return ms.compute()
"""
values: dict[str, tuple[np.ndarray, ...]]
@classmethod
def empty(cls) -> CollectingMetric:
return cls(values={})
def merge(self, other: CollectingMetric) -> CollectingMetric:
values = {
name: (*value, *other.values[name])
for name, value in self.values.items()
}
if any(
isinstance(vv, jax.core.Tracer) for v in values.values() for vv in v): # pylint: disable=g-complex-comprehension
raise RuntimeError(
"Tracer detected! CollectingMetric cannot be JIT compiled.")
if other.values and not self.values:
return other
if self.values and not other.values:
return self
return type(self)(jax.tree_util.tree_map(np.asarray, values))
def reduce(self) -> CollectingMetric:
# Note that this is usually called from inside a `pmap()` via
# `Collection.gather_from_model_output()` so we concatenate using jnp.
return type(self)(
{name: jnp.concatenate(values) for name, values in self.values.items()}) # pytype: disable=wrong-arg-types # jnp-types
def compute(self): # No return type annotation, so subclasses can override
return {k: np.concatenate(v) for k, v in self.values.items()}
@classmethod
def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]:
"""Returns a metric class that collects all model outputs named `names`."""
@flax.struct.dataclass
class FromOutputs(cls): # pylint:disable=missing-class-docstring
@classmethod
def from_model_output(cls: type[M], **model_output) -> M:
def make_array(value):
if value is None:
value = jnp.nan
value = jnp.array(value)
# Can't jnp.concatenate() scalars, promote to shape=(1,) in that case.
return value[None] if value.ndim == 0 else value
return cls({name: (make_array(model_output[name]),) for name in names})
return FromOutputs
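`CollectingMetric` keeps one tuple of host arrays per output name, appends to those tuples on `merge()`, and concatenates only in `compute()`; scalars are promoted to shape `(1,)` and `None` becomes NaN so that concatenation always works. A NumPy sketch of that lifecycle, assuming plain dicts in place of the dataclass (the helper names are hypothetical):

```python
import numpy as np


def make_array(value):
  """Promotes scalars to shape (1,) so they can be concatenated; None becomes NaN."""
  value = np.asarray(np.nan if value is None else value)
  return value[None] if value.ndim == 0 else value


def merge(a: dict, b: dict) -> dict:
  """Appends b's collected arrays onto a's, keeping one tuple per output name."""
  if not a:
    return b
  if not b:
    return a
  return {name: (*arrays, *b[name]) for name, arrays in a.items()}


collected = {}
for labels, loss in (([0, 1], 0.5), ([1, 1, 0], 0.25)):
  update = {"labels": (make_array(labels),), "loss": (make_array(loss),)}
  collected = merge(collected, update)
# The single concatenation happens at the end, as in `.compute()`.
result = {k: np.concatenate(v) for k, v in collected.items()}
print(result["labels"], result["loss"])
```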
@flax.struct.dataclass
class _ReductionCounter(Metric):
"""Pseudo metric that keeps track of the total number of `.merge()`."""
value: jnp.ndarray
@classmethod
def empty(cls) -> _ReductionCounter:
return cls(value=jnp.array(1, jnp.int32))
def merge(self, other: _ReductionCounter) -> _ReductionCounter:
return _ReductionCounter(self.value + other.value)
def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
ndim = reduction_counter.value.ndim
if ndim != 0:
raise ValueError(
f"Collection is still replicated (ndim={ndim}). Maybe you forgot to "
f"call a flax.jax_utils.unreplicate() or a Collection.reduce()?")
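`_ReductionCounter.empty()` starts as a 0-d `int32` array, so a leading device axis left over from replication shows up as `ndim != 0`, which is what the check above detects. A NumPy sketch of the same test, where `np.stack` stands in for replicating onto 8 devices (an illustrative assumption; `check_not_replicated` is a hypothetical name):

```python
import numpy as np


def check_not_replicated(counter_value: np.ndarray) -> None:
  """Raises if the reduction counter still has a leading device axis (sketch)."""
  if counter_value.ndim != 0:
    raise ValueError(f"Collection is still replicated (ndim={counter_value.ndim}).")


check_not_replicated(np.array(1, np.int32))         # Unreplicated: passes.
replicated = np.stack([np.array(1, np.int32)] * 8)  # Shape (8,), as after replication.
try:
  check_not_replicated(replicated)
except ValueError as e:
  print(e)
```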
C = TypeVar("C", bound="Collection")
@flax.struct.dataclass
class Collection:
"""Updates a collection of `Metric` from model outputs.
Refer to the module documentation for a complete example.
Synopsis:
@flax.struct.dataclass
class Metrics(Collection):
accuracy: Accuracy
metrics = None
for inputs, labels in data:
logits = model(inputs)
update = Metrics.single_from_model_output(logits=logits, labels=labels)
metrics = update if metrics is None else metrics.merge(update)
print(metrics.compute())
"""
_reduction_counter: _ReductionCounter
@classmethod
def create(cls, **metrics: type[Metric]) -> type[Collection]:
"""Handy short-cut to define a `Collection` inline.
Instead of declaring a `Collection` dataclass:
@flax.struct.dataclass
class MyMetrics(metrics.Collection):
accuracy: metrics.Accuracy
You can use this function to generate it dynamically:
MyMetrics = metrics.Collection.create(accuracy=metrics.Accuracy)
To simultaneously create the class and initialize an instance use
`Collection.create_collection` instead.
Args:
**metrics: Names and metric classes to include in the collection.
Returns:
A subclass of Collection with fields defined by provided `metrics`.
"""
return flax.struct.dataclass(
type("_InlineCollection", (Collection,), {"__annotations__": metrics}))
@classmethod
def create_collection(cls, **metrics: Metric) -> Collection:
"""Creates a custom collection object from the given metric instances.
This object will be an instance of a custom subclass of `Collection` with
every entry in `**metrics` declared as a dataclass field. For example:
my_metrics = metrics.Collection.create_collection(
accuracy=metrics.Accuracy(0, 0))
is equivalent to:
@flax.struct.dataclass
class MyMetrics(metrics.Collection):
accuracy: metrics.Accuracy
my_metrics = MyMetrics(metrics._ReductionCounter(jnp.array(1)),
accuracy=metrics.Accuracy(0, 0))
Args:
**metrics: Metric instances to incorporate into this object.
Returns:
An instance of `Collection` initialized with the provided `metrics`.
"""
collection_class = cls.create(**{k: type(v) for k, v in metrics.items()})
counter = _ReductionCounter(jnp.array(1, dtype=jnp.int32))
return collection_class(_reduction_counter=counter, **metrics)
@classmethod
def empty(cls: type[C]) -> C:
return cls(
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
**{
metric_name: metric.empty()
for metric_name, metric
in inspect.get_annotations(cls, eval_str=True).items()
})
@classmethod
def _from_model_output(cls: type[C], **kwargs) -> C:
"""Creates a `Collection` from model outputs."""
return cls(
_reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)),
**{
metric_name: metric.from_model_output(**kwargs)
for metric_name, metric
in inspect.get_annotations(cls, eval_str=True).items()
})
@classmethod
def single_from_model_output(cls: type[C], **kwargs) -> C:
"""Creates a `Collection` from model outputs.
Note: This function should only be called when metrics are collected in a
non-distributed setting (i.e. outside a `pmap()`).
Args:
**kwargs: Model outputs used by individual metrics.
Returns:
A metric collection from provided `kwargs` model outputs.
"""
return cls._from_model_output(**kwargs)
@classmethod
def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs) -> C:
"""Creates a `Collection` from model outputs in a distributed setting.
Args:
axis_name: Name of the axis along which the values are to be gathered.
Should be the same as the `axis_name` argument to the `pmap()`.
**kwargs: Model outputs used by individual metrics.
Returns:
A metric collection from provided `kwargs` model outputs that contains
metrics for all devices across all hosts.
"""
return jax.lax.all_gather(
cls._from_model_output(**kwargs), axis_name=axis_name).reduce()
def merge(self: C, other: C) -> C:
"""Returns `Collection` that is the accumulation of `self` and `other`."""
return type(self)(**{
metric_name: metric.merge(getattr(other, metric_name))
for metric_name, metric in vars(self).items()
})
def reduce(self: C) -> C:
"""Reduces the collection by calling `Metric.reduce()` on each metric.
The primary use case is to reduce a collection that was gathered from
multiple devices into a single collection. For instance, inside `pmap()`:
```
col = jax.lax.all_gather(col, axis_name='foo').reduce()
```
or, if computed directly from model_outputs:
```
col = col.merge(col.gather_from_model_output(**outputs))
```
will sync collections across all devices to create a replicated collection
that includes statistics from all devices.
Outside `pmap()`, this collection can then be safely unreplicated using
`collection.unreplicate()`.
If `collection.unreplicate()` is called without gathering, it will only
contain the statistics from the first device, which is rarely the desired
behavior.
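A minimal pure-Python sketch of why gathering matters, assuming an
`Average`-like metric whose state is a `(total, count)` pair. The helpers
`reduce_states`, `unreplicate` and `mean` are hypothetical stand-ins, not
CLU or Flax APIs:

```python
# Hypothetical per-device (total, count) states of an Average-like metric,
# stacked along a leading device axis as jax.lax.all_gather would return them.
per_device = [(3.0, 2), (1.0, 2)]


def reduce_states(stacked):
    # Combine all device states by summing each field over the device axis.
    total = sum(t for t, _ in stacked)
    count = sum(c for _, c in stacked)
    return total, count


def unreplicate(replicated):
    # Keep only the first device's copy, as flax.jax_utils.unreplicate does.
    return replicated[0]


def mean(state):
    total, count = state
    return total / count


# Gathered and reduced inside pmap: every device holds the global state.
synced = [reduce_states(per_device)] * len(per_device)
print(mean(unreplicate(synced)))      # 1.0 (all 4 examples)
# Not gathered: unreplicate() silently drops the second device.
print(mean(unreplicate(per_device)))  # 1.5 (device 0 only)
```
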
Returns:
Reduced collection.
"""
return type(self)(**{
metric_name: metric.reduce()
for metric_name, metric in vars(self).items()
})
def compute(self) -> dict[str, jnp.ndarray]:
"""Returns a dictionary mapping metric field name to `Metric.compute()`."""
_check_reduction_counter_ndim(self._reduction_counter)
return {
metric_name: metric.compute()
for metric_name, metric in vars(self).items()
if metric_name != "_reduction_counter"
}
def compute_values(self) -> dict[str, clu.values.Value]:
"""Computes metrics and returns them as clu.values.Value."""
_check_reduction_counter_ndim(self._reduction_counter)
return {
metric_name: metric.compute_value()
for metric_name, metric in vars(self).items()
if metric_name != "_reduction_counter"
}
def unreplicate(self: C) -> C:
"""Short-hand for `flax.jax_utils.unreplicate(self)`.
For this function to return correct values, the collection should first be
gathered and `reduce`d inside `pmap()`, either with `gather_from_model_output`
or with `jax.lax.all_gather` followed by `reduce`. See `Collection.reduce`
for details.
Returns:
Unreplicated collection.
"""
return flax.jax_utils.unreplicate(self)
# Sentinel to make LastValue.__init__ support tree manipulations that use None.
_default = object()
@flax.struct.dataclass
class LastValue(Metric):
"""Keeps the last average global batch value.
This is useful to log values such as learning rate and losses during training.
This class mirrors `Average` because it needs to maintain total/count in
cases where a batch is distributed across multiple devices and needs to be
averaged later. However, it does not inherit from `Average`, to preserve
backward compatibility for code that checks `isinstance(metric, Average)`.
For backward compatibility this class can also be initialized as if the
constructor were `__init__(value)`.
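A minimal pure-Python sketch of the two combination rules, assuming a
hypothetical `LastValueSketch` class (not part of CLU): sequential `merge`
keeps the latest state, while cross-device reduction pools `total` and
`count`. This is why two devices reporting learning rates of 0.02 and 0.01
reduce to roughly 0.015.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class LastValueSketch:
    # Hypothetical stand-in for LastValue's (total, count) state.
    total: float
    count: int

    @property
    def value(self):
        # NaN for an empty state, mirroring the 0/0 division in LastValue.
        return self.total / self.count if self.count else math.nan

    def merge(self, other):
        # Sequential updates (e.g. one per training step): keep the latest.
        return other

    def reduce_merge(self, other):
        # Cross-device reduction: pool totals and counts so the result is
        # the average over the whole distributed batch.
        return LastValueSketch(self.total + other.total,
                               self.count + other.count)


a = LastValueSketch(total=0.02, count=1)
b = LastValueSketch(total=0.01, count=1)
print(a.merge(b).value)         # 0.01
print(a.reduce_merge(b).value)  # approximately 0.015
```
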
"""
total: jnp.ndarray
count: jnp.ndarray
def __init__( # pytype: disable=missing-parameter # jnp-array
self,
total: jnp.ndarray | _default = _default,
count: jnp.ndarray | _default = _default,
value: jnp.ndarray | _default = _default,
):
"""Backward compatibility constructor.
It is intended to be constructed as __init__(total, count). When doing so
the arguments are assigned as instance attributes without extra operations.
For backward compatibility it also supports __init__(value) code paths.
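The argument handling can be sketched in plain Python with the same
sentinel-default pattern; `_DEFAULT` and `resolve_total_count` are
hypothetical names used only for illustration. Using a dedicated sentinel
object instead of None lets tree manipulations pass None as a leaf without
triggering the default logic.

```python
_DEFAULT = object()  # Hypothetical sentinel; None remains a legal leaf value.


def resolve_total_count(total=_DEFAULT, count=_DEFAULT, value=_DEFAULT):
    # Plain-Python mirror of the argument handling in LastValue.__init__.
    count = count if count is not _DEFAULT else 1
    if (value is _DEFAULT) == (total is _DEFAULT):
        raise ValueError("Exactly one of 'total' and 'value' should be passed.")
    if total is _DEFAULT:
        total = value * count
    return total, count


print(resolve_total_count(value=2.0, count=3))  # (6.0, 3)
print(resolve_total_count(None, None))          # (None, None)
```
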
Args:
total: Total value.
count: Count of examples, 1 if not provided.
value: Average value. If provided, `total` is set to `value * count`.
Exactly one of `total` and `value` must be passed.
"""
# Note: This code should not use None to detect a default argument, and it
# should avoid doing any logic when it is being called by tree utilities.
# Tree manipulations may replace the leaves with other values, such as
# shapes, sharding information or even None.
# Per https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html
# classes should provide a static create() method, but here we overload
# the constructor for backward compatibility when it was LastValue(value).
count = count if count is not _default else jnp.array(1, dtype=jnp.int32)
if (value is _default) == (total is _default):
raise ValueError(
"Exactly one of 'total' and 'value' should be passed. "
f"Got {total}, {value}"
)
if total is _default:
total = value * count
object.__setattr__(self, "total", total)
object.__setattr__(self, "count", count)
@classmethod
def empty(cls) -> LastValue:
return cls(jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
@classmethod
def from_model_output(
cls, value: jnp.ndarray, mask: jnp.ndarray | None = None, **_
) -> LastValue:
if mask is None:
mask = jnp.ones((value.shape or [()])[0])
return cls(
total=jnp.where(mask, value, jnp.zeros_like(value)).sum(),
count=mask.sum().astype(jnp.int32),
)
def merge(self, other: LastValue) -> LastValue:
_assert_same_shape(self.value, other.value)
return other
def _reduce_merge(self, other: LastValue) -> LastValue:
# We need to average during reduction.
_assert_same_shape(self.total, other.total)
return type(self)(
total=self.total + other.total,
count=self.count + other.count,
)
@property
def value(self) -> jnp.ndarray:
# Explicitly allow NaN division as it is part of normal computation here.
with jax.debug_nans(False):
return self.total / self.count
def compute(self) -> Any:
return self.value
def _broadcast_masks(values: jnp.ndarray, mask: jnp.ndarray | None):
"""Checks and broadcasts mask for aggregating values."""
if values.ndim == 0:
values = values[None]
if mask is None:
mask = jnp.ones_like(values)
# Leading dimensions of mask and values must match.
if mask.shape[0] != values.shape[0]:
raise ValueError(
"Argument `mask` must have the same leading dimension as `values`. "
f"Received mask of dimension {mask.shape} "
f"and values of dimension {values.shape}."
)
# Broadcast mask to the same number of dimensions as values.
if mask.ndim < values.ndim:
mask = jnp.expand_dims(mask, axis=tuple(np.arange(mask.ndim, values.ndim)))
mask = mask.astype(bool)
utils.check_param(mask, dtype=bool, ndim=values.ndim)
return values, mask
@flax.struct.dataclass
class Average(Metric):
"""Computes the average of a scalar or a batch of tensors.
Supports the following types of masks:
- A one-dimensional mask with the same leading dimension as the values, or
- A multi-dimensional mask with the exact same dimensions as the values.
This allows the use of per-example masks for examples in a batch, as well as
per-target masks for targets for examples in a batch.
The result is always a scalar.
See also documentation of `Metric`.
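A pure-Python sketch of these masking rules, using a hypothetical helper
`masked_average` on nested lists instead of arrays. A per-row mask is
broadcast across each row, while a per-element mask selects individual
entries:

```python
def masked_average(values, mask=None):
    # values: a list of rows. mask: None, one bool per row (broadcast over the
    # row, like a 1-D mask), or one bool per element (same shape as values).
    # Unlike Average.compute, this raises ZeroDivisionError if all is masked.
    total, count = 0.0, 0
    for i, row in enumerate(values):
        for j, x in enumerate(row):
            if mask is None:
                keep = True
            elif isinstance(mask[i], (list, tuple)):
                keep = mask[i][j]
            else:
                keep = mask[i]
            if keep:
                total += x
                count += 1
    return total / count


values = [[1, 2], [2, 3], [3, 4]]
print(masked_average(values))                       # 2.5
print(masked_average(values, [False, True, True]))  # 3.0
print(masked_average(values, [[False, True], [True, True], [True, True]]))  # 2.8
```
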
"""
total: jnp.ndarray
count: jnp.ndarray
@classmethod
def empty(cls) -> Average:
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
@classmethod
def from_model_output(
cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
) -> Average:
values, mask = _broadcast_masks(values, mask)
return cls(
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
count=jnp.where(
mask,
jnp.ones_like(values, dtype=jnp.int32),
jnp.zeros_like(values, dtype=jnp.int32),
).sum(),
)
def merge(self, other: Average) -> Average:
_assert_same_shape(self.total, other.total)
return type(self)(
total=self.total + other.total,
count=self.count + other.count,
)
def compute(self) -> Any:
return self.total / self.count
@flax.struct.dataclass
class Std(Metric):
"""Computes the standard deviation of a scalar or a batch of tensors.
The result is always a single scalar. See also the documentation of `Average`
for the mask handling.
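A pure-Python sketch of why `Std` can merge partial results: it keeps only
`total`, `sum_of_squares` and `count`, which simply add across batches, and
recovers the population standard deviation (the `ddof=0` convention of
`jnp.std`) as `sqrt(sum_of_squares / count - mean**2)`. The helper names are
hypothetical:

```python
import math
import statistics


def accumulate(values):
    # The three running sums that Std keeps: total, sum_of_squares, count.
    return sum(values), sum(v * v for v in values), len(values)


def merge(a, b):
    # Merging two states adds each field, as Std.merge does.
    return tuple(x + y for x, y in zip(a, b))


def std_from_sums(total, sum_of_squares, count):
    mean = total / count
    variance = sum_of_squares / count - mean ** 2
    # Round-off can make the estimate slightly negative; clip it at zero.
    variance = max(variance, 0.0)
    return math.sqrt(variance)


state = merge(accumulate([1.0, 2.0, 3.0]), accumulate([4.0]))
print(std_from_sums(*state))                    # approximately 1.118
print(statistics.pstdev([1.0, 2.0, 3.0, 4.0]))  # same value, computed directly
```
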
"""
total: jnp.ndarray
sum_of_squares: jnp.ndarray
count: jnp.ndarray
@classmethod
def empty(cls) -> Std:
return cls(
total=jnp.array(0, jnp.float32),
sum_of_squares=jnp.array(0, jnp.float32),
count=jnp.array(0, jnp.int32))
@classmethod
def from_model_output(
cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
) -> Std:
values, mask = _broadcast_masks(values, mask)
return cls(
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
sum_of_squares=jnp.where(mask, values**2, jnp.zeros_like(values)).sum(),
count=jnp.where(
mask,
jnp.ones_like(values, dtype=jnp.int32),
jnp.zeros_like(values, dtype=jnp.int32),
).sum(),
)
def merge(self, other: Std) -> Std:
_assert_same_shape(self.total, other.total)
return type(self)(
total=self.total + other.total,
sum_of_squares=self.sum_of_squares + other.sum_of_squares,
count=self.count + other.count,
)
def compute(self) -> Any:
# var(X) = 1/N \sum_i (x_i - mean)^2
# = 1/N \sum_i (x_i^2 - 2 x_i mean + mean^2)
# = 1/N ( \sum_i x_i^2 - 2 mean \sum_i x_i + N * mean^2 )
# = 1/N ( \sum_i x_i^2 - 2 mean N mean + N * mean^2 )
# = 1/N ( \sum_i x_i^2 - N * mean^2 )
# = \sum_i x_i^2 / N - mean^2
mean = self.total / self.count
variance = self.sum_of_squares / self.count - mean**2
# Mathematically the variance can never be negative, but the estimate above
# can come out slightly negative due to floating-point round-off.
variance = jnp.clip(variance, min=0.0)
return variance**.5
@flax.struct.dataclass
class Accuracy(Average):
"""Computes the accuracy from model outputs `logits` and `labels`.
`labels` is expected to be of dtype=int32 and to have 0 <= ndim <= 2, and
`logits` is expected to have ndim = labels.ndim + 1.
See also documentation of `Metric`.
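A pure-Python sketch of the computation, with hypothetical helpers `argmax`
and `accuracy` operating on nested lists: an example counts as correct when
the index of its largest logit equals its label, and the (optionally masked)
average of these 0/1 values is the accuracy. The inputs are toy values.

```python
def argmax(row):
    # Index of the first maximum, matching jnp.argmax tie-breaking.
    return max(range(len(row)), key=lambda k: row[k])


def accuracy(logits, labels, mask=None):
    # Fraction of (unmasked) examples whose top-scoring class equals the label.
    correct = [float(argmax(row) == label) for row, label in zip(logits, labels)]
    if mask is None:
        mask = [True] * len(correct)
    kept = [c for c, m in zip(correct, mask) if m]
    return sum(kept) / len(kept)


logits = [[1.0, 0.0], [0.0, 1.0], [1.0, 2.0], [3.0, 4.0]]
labels = [0, 0, 1, 1]
print(accuracy(logits, labels))                              # 0.75
print(accuracy(logits, labels, [False, True, True, False]))  # 0.5
```
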
"""
@classmethod
def from_model_output(
cls, *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs
) -> Accuracy:
if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32:
raise ValueError(
f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}=="
f"labels.ndim+1={labels.ndim + 1}")
metric = super().from_model_output(
values=(logits.argmax(axis=-1) == labels).astype(jnp.float32), **kwargs
)
return cls(**vars(metric)) # cls(metrics) doesn't work for a dataclass
================================================
FILE: clu/metrics_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for clu.metrics."""
import functools
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import chex
from clu import asynclib
from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
@flax.struct.dataclass
class CollectingMetricAccuracy(
metrics.CollectingMetric.from_outputs(("logits", "labels"))):
def compute(self):
values = super().compute()
logits = values["logits"]
labels = values["labels"]
assert logits.ndim == 2, logits.shape
assert labels.ndim == 1, labels.shape
return (logits.argmax(axis=-1) == labels).mean()
@flax.struct.dataclass
class Collection(metrics.Collection):
train_accuracy: metrics.Accuracy
learning_rate: metrics.LastValue.from_output("learning_rate")
@flax.struct.dataclass
class CollectionMixed(metrics.Collection):
collecting_metric_accuracy: CollectingMetricAccuracy
train_accuracy: metrics.Accuracy
class MetricsTest(parameterized.TestCase):
def setUp(self):
super().setUp()
# Clear the trace counter
chex.clear_trace_counter()
# Two batches of model output.
self.model_outputs = (
dict(
logits=jnp.array([[1., 0.], [0., 1.]]),
labels=jnp.array([0, 0]),
example_loss=jnp.array([0, 4.2]),
learning_rate=0.02,
loss=jnp.array(4.2),
),
dict(
logits=jnp.array([[1., 2.], [3., 4.]]),
labels=jnp.array([1, 1]),
example_loss=jnp.array([1.7, 0]),
learning_rate=0.01,
loss=jnp.array(1.7),
),
)
masks = (
jnp.array([False, True]),
jnp.array([True, False]),
)
self.model_outputs_masked = tuple(
dict(mask=mask, **model_output)
for mask, model_output in zip(masks, self.model_outputs))
self.count = 4
self.count_masked = 2
self.results = {
"train_accuracy": 0.75,
"learning_rate": 0.01,
}
self.results_masked = {
"train_accuracy": 0.5,
"learning_rate": 0.01,
}
self.results_gather = {
"train_accuracy": 0.75,
"learning_rate": 0.015, # Gathering averages distributed batches.
}
self.results_gather_masked = {
"train_accuracy": 0.5,
"learning_rate": 0.015, # Gathering averages distributed batches.
}
# Stack all values, e.g. for use as the input of a pmap().
self.model_outputs_stacked = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), *self.model_outputs
)
self.model_outputs_masked_stacked = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), *self.model_outputs_masked
)
def make_compute_metric(self, metric_class, reduce, jit=True):
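"""Returns a function, optionally jitted, to compute metrics.
Args:
metric_class: Metric class to instantiate.
reduce: If `True`, stacks the per-output metrics and combines them with
`Metric.reduce()`; otherwise merges them one by one, starting from
`metric_class.empty()`.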
jit: Whether the returned function should be jitted.
Returns:
A function that takes `model_outputs` (list of dictionaries of values) as
an input and returns the value from `metric.compute()`.
"""
def compute_metric(model_outputs):
if reduce:
metric_list = [
metric_class.from_model_output(**model_output)
for model_output in model_outputs
]
metric_stacked = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), *metric_list
)
metric = metric_stacked.reduce()
else:
metric = metric_class.empty()
for model_output in model_outputs:
update = metric_class.from_model_output(**model_output)
metric = metric.merge(update)
return metric.compute()
if jit:
compute_metric = jax.jit(compute_metric)
return compute_metric
def test_metric_last_value_reduce(self):
metric1 = metrics.LastValue.from_model_output(jnp.array([1, 2]))
metric2 = metrics.LastValue.from_model_output(jnp.array([3, 4]))
metric3 = metrics.LastValue.from_model_output(jnp.array([3, 4]),
jnp.array([0, 0]))
metric12 = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), metric1, metric2
)
metric21 = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), metric2, metric1
)
self.assertEqual(metric12.reduce().value, 2.5)
chex.assert_trees_all_equal(metric12.reduce().compute(),
metric21.reduce().compute())
metric13 = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), metric1, metric3
)
metric31 = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), metric3, metric1
)
self.assertEqual(metric13.reduce().value, 1.5)
chex.assert_trees_all_equal(metric13.reduce().compute(),
metric31.reduce().compute())
def test_metric_last_value(self):
metric0 = metrics.LastValue.from_model_output(jnp.array([]))
metric1 = metrics.LastValue.from_model_output(jnp.array([1, 2]))
metric2 = metrics.LastValue.from_model_output(jnp.array([3, 4]))
np.testing.assert_equal(metric0.value, jnp.array(np.nan))
with jax.debug_nans(True):
# Verify that metrics is computable even under strict NaN checking.
_ = metric0.value
metric01 = metric0.merge(metric1)
self.assertEqual(metric01.value, 1.5)
metric12 = metric1.merge(metric2)
self.assertEqual(metric1.value, 1.5)
self.assertEqual(metric12.value, 3.5)
chex.assert_trees_all_equal(metric12.compute(), metric2.compute())
def test_metric_last_value_legacy_kwarg_value(self):
metric = metrics.LastValue(value=2.0)
self.assertEqual(metric.total, 2.0)
metric = metrics.LastValue(value=2.0, count=3)
self.assertEqual(metric.total, 6.0)
def test_metric_last_value_tree_manipulation(self):
# Test mapping leaves to other non array values (e.g.: None).
metric = metrics.LastValue(value=2.0)
metric = jax.tree_util.tree_map(lambda x: None, metric)
self.assertIsNone(metric.total)
self.assertIsNone(metric.count)
metric = metrics.LastValue(value=2.0, count=3)
metric = jax.tree_util.tree_map(lambda x: None, metric)
self.assertIsNone(metric.total)
self.assertIsNone(metric.count)
metric = metrics.LastValue(2.0)
metric = jax.tree_util.tree_map(lambda x: None, metric)
self.assertIsNone(metric.total)
self.assertIsNone(metric.count)
def test_from_fun_with_single_output(self):
def accuracy(*, logits, labels, **_):
return (logits.argmax(axis=-1) == labels).astype(jnp.float32)
chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Average.from_fun(accuracy),
reduce=False)(self.model_outputs), self.results["train_accuracy"])
chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Average.from_fun(accuracy),
reduce=False)(self.model_outputs_masked),
self.results_masked["train_accuracy"])
def test_from_fun_with_mapping_output(self):
# This tests .from_fun() with a function that returns a mapping. Accuracy
# accepts logits and labels already, so this function just passes them
# along. (This isn't needed in real code that uses Accuracy, just to test
# `from_fun`.)
def make_accuracy_args_map(*, logits, labels, **_):
return dict(logits=logits, labels=labels)
chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Accuracy.from_fun(make_accuracy_args_map),
reduce=False)(self.model_outputs), self.results["train_accuracy"])
chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Accuracy.from_fun(make_accuracy_args_map),
reduce=False)(self.model_outputs_masked),
self.results_masked["train_accuracy"])
@parameterized.named_parameters(
("0d_values_no_mask", 1, None, 1.),
("1d_values_no_mask", [1, 2, 3], None, 2.),
("1d_values_1d_mask", [1, 2, 3], [True, True, False], 1.5),
("2d_values_no_mask", [[1, 2], [2, 3], [3, 4]], None, 2.5),
("2d_values_1d_mask", [[1, 2], [2, 3], [3, 4]], [False, True, True], 3.),
("2d_values_2d_mask", [[1, 2], [2, 3], [3, 4]],
[[False, True], [True, True], [True, True]], 2.8),
("3d_values_no_mask", [[[1, 2], [2, 3]], [[2, 1], [3, 4]],
[[3, 1], [4, 1]]], None, 2.25),
("3d_values_1d_mask", [[[1, 2], [2, 3]], [[2, 1], [3, 4]],
[[3, 1], [4, 1]]], [False, True, True], 2.375),
)
def test_average_masked(self, values, mask, expected_result):
values = jnp.asarray(values)
if mask is not None:
mask = jnp.asarray(mask)
chex.assert_trees_all_close(
metrics.Average.from_model_output(values, mask=mask).compute(),
expected_result)
def rename_mask(**kwargs):
return dict(my_loss=kwargs["values"], mask=kwargs["my_mask"])
chex.assert_trees_all_close(
(metrics.Average
.from_output("my_loss")
.from_fun(rename_mask)
.from_model_output(values=values, my_mask=mask)
.compute()),
expected_result)
@parameterized.named_parameters(
("Average", metrics.Average),
("Std", metrics.Std),
("LastValue", metrics.LastValue),
)
def test_merge_asserts_shape(self, metric_cls):
metric1 = metric_cls.from_model_output(jnp.arange(3.))
metric2 = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), metric1, metric1
)
with self.assertRaisesRegex(ValueError, r"^Expected same shape"):
metric1.merge(metric2)
@parameterized.named_parameters(
("", False),
("_reduce", True),
)
def test_accuracy(self, reduce):
chex.assert_trees_all_close(
self.make_compute_metric(metrics.Accuracy, reduce)(self.model_outputs),
self.results["train_accuracy"])
def test_last_value_asserts_shape(self):
metric1 = metrics.LastValue.from_model_output(jnp.arange(3.))
metric2 = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), metric1, metric1
)
with self.assertRaisesRegex(ValueError, r"^Expected same shape"):
metric1.merge(metric2)
@parameterized.named_parameters(
("", False),
("_reduce", True),
)
def test_loss_average(self, reduce):
chex.assert_trees_all_close(
self.make_compute_metric(metrics.Average.from_output("loss"),
reduce)(self.model_outputs_masked),
self.model_outputs_stacked["loss"].mean())
chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Average.from_output("example_loss"),
reduce)(self.model_outputs_masked),
self.model_outputs_stacked["loss"].mean())
@parameterized.named_parameters(
("", False),
("_reduce", True),
)
def test_loss_std(self, reduce):
chex.assert_trees_all_close(
self.make_compute_metric(metrics.Std.from_output("loss"),
reduce)(self.model_outputs_masked),
self.model_outputs_stacked["loss"].std(),
atol=1e-4)
chex.assert_trees_all_close(
self.make_compute_metric(
metrics.Std.from_output("example_loss"),
reduce)(self.model_outputs_masked),
self.model_outputs_stacked["loss"].std(),
atol=1e-4)
def test_collection_create(self):
collection = metrics.Collection.create(accuracy=metrics.Accuracy)
chex.assert_trees_all_close(
collection.single_from_model_output(
logits=jnp.array([[-1., 1.], [1., -1.]]),
labels=jnp.array([0, 0]), # i.e. 1st incorrect, 2nd correct
).compute(),
{"accuracy": 0.5})
def test_collection_create_custom_mask(self):
def with_head1(logits, labels, mask, head1_mask, **_):
return dict(logits=logits, labels=labels, mask=head1_mask & mask)
def with_head2(logits, labels, mask, head2_mask, **_):
return dict(logits=logits, labels=labels, mask=head2_mask & mask)
collection = metrics.Collection.create(
head1_accuracy=metrics.Accuracy.from_fun(with_head1),
head2_accuracy=metrics.Accuracy.from_fun(with_head2)
)
chex.assert_trees_all_close(
collection.single_from_model_output(
logits=jnp.array([[-1.0, 1.0], [1.0, -1.0]]),
labels=jnp.array([0, 0]), # i.e. 1st incorrect, 2nd correct
mask=jnp.array([True, True]),
head1_mask=jnp.array([True, False]), # ignore the 2nd.
head2_mask=jnp.array([False, True]), # ignore the 1st.
).compute(),
{"head1_accuracy": 0.0, "head2_accuracy": 1.0},
)
def test_collection_create_collection(self):
collection = metrics.Collection.create_collection(
accuracy=metrics.Accuracy.from_model_output(
logits=jnp.array([[-1., 1.], [1., -1.]]),
labels=jnp.array([0, 0])), # i.e. 1st incorrect, 2nd correct)
loss=metrics.Average.from_model_output(jnp.array([0, 1, 2])))
chex.assert_trees_all_close(collection.compute(), {
"accuracy": 0.5,
"loss": 1
})
chex.assert_trees_all_close(
{k: v.value for k, v in collection.compute_values().items()}, {
"accuracy": 0.5,
"loss": 1
})
@parameterized.named_parameters(
("", False),
("_masked", True),
)
def test_collection_single(self, masked):
@jax.jit
def compute_collection(model_outputs):
collection = Collection.empty()
for model_output in model_outputs:
update = Collection.single_from_model_output(**model_output)
collection = collection.merge(update)
return collection
model_outputs = self.model_outputs_masked if masked else self.model_outputs
collection = compute_collection(model_outputs)
chex.assert_trees_all_close(
collection.compute(), self.results_masked if masked else self.results
)
self.assertEqual(
collection.train_accuracy.count,
self.count_masked if masked else self.count,
)
@parameterized.named_parameters(
("", False),
("_masked", True),
)
@mock.patch("jax.lax.all_gather")
def test_collection_gather(self, masked, all_gather_mock):
model_outputs = self.model_outputs_masked if masked else self.model_outputs
collections = [
Collection.single_from_model_output(**model_output)
for model_output in (model_outputs)
]
all_gather_mock.return_value = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), *collections
)
def compute_collection(model_outputs):
collection = Collection.gather_from_model_output(**model_outputs[0])
return collection.compute()
observed = jax.jit(compute_collection)(model_outputs)
expectation = self.results_gather_masked if masked else self.results_gather
chex.assert_trees_all_close(observed, expectation)
@parameterized.named_parameters(
("", False),
("_masked", True),
)
def test_collection_gather_pmap(self, masked):
@functools.partial(jax.pmap, axis_name="batch")
def compute_collection(model_outputs):
return Collection.gather_from_model_output(**model_outputs)
if jax.local_device_count() > 1:
chex.assert_trees_all_close(
compute_collection(
self.model_outputs_masked_stacked if masked
else self.model_outputs_stacked).unreplicate().compute(),
self.results_gather_masked if masked else self.results_gather)
def test_collection_asserts_replication(self):
collections = [
Collection.single_from_model_output(**model_output)
for model_output in self.model_outputs
]
collection = jax.tree_util.tree_map(
lambda *args: jnp.stack(args), *collections
)
with self.assertRaisesRegex(ValueError, r"^Collection is still replicated"):
collection.compute()
def test_collecting_metric(self):
metric_class = metrics.CollectingMetric.from_outputs(("logits", "loss"))
logits = np.concatenate(
[model_output["logits"] for model_output in self.model_outputs])
loss = np.array(
[model_output["loss"] for model_output in self.model_outputs])
result = self.make_compute_metric(
metric_class, reduce=False, jit=False)(
self.model_outputs)
chex.assert_trees_all_close(result, {
"logits": logits,
"loss": loss,
})
def test_collecting_metric_reduce(self):
metric_class = metrics.CollectingMetric.from_outputs(("value",))
metric = jax.jit(metric_class.from_model_output)(value=jnp.ones([8, 2, 4]))
reduced = metric.reduce()
chex.assert_trees_all_close(reduced.compute(), {"value": np.ones([16, 4])})
def test_collecting_metric_async(self):
pool = asynclib.Pool()
@pool
def copy_to_host(update):
return jax.tree_util.tree_map(np.asarray, update)
futures = []
from_model_output = jax.jit(CollectingMetricAccuracy.from_model_output)
for model_output in self.model_outputs:
futures.append(copy_to_host(from_model_output(**model_output)))
metric = CollectingMetricAccuracy.empty()
for future in futures:
metric = metric.merge(future.result())
result = metric.compute()
chex.assert_trees_all_close(result, 0.75)
def test_collecting_metric_tracer(self):
metric_class = metrics.CollectingMetric.from_outputs(("logits",))
with self.assertRaisesRegex(RuntimeError, r"^Tracer detected!"):
_ = self.make_compute_metric(
metric_class, reduce=False, jit=True)(
self.model_outputs)
def test_collection_mixed_async(self):
metric = CollectionMixed.empty()
pool = asynclib.Pool()
@pool
def merge(update):
nonlocal metric
metric = metric.merge(update)
for model_output in self.model_outputs:
merge(jax.jit(CollectionMixed.single_from_model_output)(**model_output))
pool.join()
result = metric.compute()
chex.assert_trees_all_close(result, {
"collecting_metric_accuracy": 0.75,
"train_accuracy": 0.75,
})
def test_metric_empty_types_doesnt_cause_retrace(self):
@jax.jit
@chex.assert_max_traces(n=1)
def merge_collection(model_output, collection):
update = Collection.single_from_model_output(**model_output)
return collection.merge(update)
# Metric will be initialized with a strong type
# Can only use non-collecting metrics as the shape of collecting
# metrics changes every iteration.
collection = Collection.empty()
for model_output in self.model_outputs:
# The merged metric _should not_ have weak types
# If it does have a weak type the second call will cause a re-trace
collection = merge_collection(model_output, collection)
@parameterized.product(
value_mask_pair=[
(1, None),
([1, 2, 3], None),
([1, 2, 3], [True, True, False]),
([[1, 2], [2, 3], [3, 4]], None),
([[1, 2], [2, 3], [3, 4]], [False, True, True]),
(
[[1, 2], [2, 3], [3, 4]],
[[False, True], [True, True], [True, True]],
),
([[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], None),
(
[[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]],
[False, True, True],
),
],
metric_np_equivalent_pair=[
(metrics.Average, jnp.mean),
(metrics.Std, jnp.std),
],
)
def test_tensor_aggregation_metrics_with_masks(
self, value_mask_pair, metric_np_equivalent_pair
):
values, mask = value_mask_pair
metric, np_equivalent = metric_np_equivalent_pair
values = jnp.asarray(values)
masked = values
if mask is not None:
mask = jnp.asarray(mask)
masked = values[mask]
expected = np_equivalent(masked)
result = metric.from_model_output(values, mask=mask).compute()
# A looser tolerance is needed for the lower-precision jitted version.
chex.assert_trees_all_close(result, expected, atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
absltest.main()
================================================
FILE: clu/parameter_overview.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for creating and logging JAX variable overviews."""
from collections.abc import Callable, Mapping, Sequence
import dataclasses
from typing import Any
from absl import logging
import flax
import jax
import jax.numpy as jnp
import numpy as np
_ParamsContainer = dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]]
@dataclasses.dataclass
class _ParamRow:
name: str
shape: tuple[int, ...]
dtype: str
size: int
@dataclasses.dataclass
class _ParamRowWithSharding(_ParamRow):
sharding: tuple[int | None, ...] | str
@dataclasses.dataclass
class _ParamRowWithStats(_ParamRow):
mean: float
std: float
@dataclasses.dataclass
class _ParamRowWithStatsAndSharding(_ParamRowWithStats):
sharding: tuple[int | None, ...] | str
@jax.jit
def _mean_std_jit(x):
return jax.tree_util.tree_map(jnp.mean, x), jax.tree_util.tree_map(jnp.std, x)
def _mean_std(x):
mean = jax.tree_util.tree_map(lambda x: x.mean(), x)
std = jax.tree_util.tree_map(lambda x: x.std(), x)
return mean, std
def flatten_dict(
input_dict: dict[str, Any], *, prefix: str = "", delimiter: str = "/"
) -> dict[str, Any]:
"""Flattens the keys of a nested dictionary."""
output_dict = {}
for key, value in input_dict.items():
nested_key = f"{prefix}{delimiter}{key}" if prefix else key
if isinstance(value, (dict, flax.core.FrozenDict)):
output_dict.update(
flatten_dict(value, prefix=nested_key, delimiter=delimiter)
)
else:
output_dict[nested_key] = value
return output_dict
def _count_parameters(params: _ParamsContainer) -> int:
"""Returns the count of variables for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(np.prod(v.shape) for v in params.values() if v is not None)
def _parameters_size(params: _ParamsContainer) -> int:
"""Returns total size (bytes) for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(
np.prod(v.shape) * v.dtype.itemsize
for v in params.values()
if v is not None
)
def count_parameters(params: _ParamsContainer) -> int:
"""Returns the count of variables for the module or parameter dictionary."""
return _count_parameters(params)
def _make_row(name, value) -> _ParamRow:
if value is None:
return _ParamRow(
name=name,
shape=(),
dtype="",
size=0,
)
return _ParamRow(
name=name,
shape=value.shape,
dtype=str(value.dtype),
size=int(np.prod(value.shape)),
)
def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
row = _make_row(name, value)
if hasattr(value, "sharding"):
if hasattr(value.sharding, "spec"):
sharding = tuple(value.sharding.spec)
else:
sharding = str(value.sharding)
else:
sharding = ()
return _ParamRowWithSharding(**dataclasses.asdict(row), sharding=sharding)
def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
row = _make_row(name, value)
mean = mean or 0.0
std = std or 0.0
return _ParamRowWithStats(
**dataclasses.asdict(row),
mean=float(jax.device_get(mean)),
std=float(jax.device_get(std)),
)
def _make_row_with_stats_and_sharding(
name, value, mean, std
) -> _ParamRowWithStatsAndSharding:
row = _make_row_with_sharding(name, value)
return _ParamRowWithStatsAndSharding(
**dataclasses.asdict(row),
mean=float(jax.device_get(mean)),
std=float(jax.device_get(std)),
)
def _get_parameter_rows(
params: _ParamsContainer,
*,
include_stats: bool | str = False,
) -> list[_ParamRow]:
"""Returns information about parameters as a list of dataclass rows.
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add a column with the sharding of the
variable. If the string "global", params are sharded global arrays and
this function assumes it is called on every host, i.e. can use
collectives. The sharding of the variables is also added as a column.
Returns:
A list of `_ParamRow`, `_ParamRowWithStats`, `_ParamRowWithSharding`, or
`_ParamRowWithStatsAndSharding`, depending on the passed value of
`include_stats`.
"""
if not isinstance(params, (dict, flax.core.FrozenDict)):
raise ValueError(
f"Expected `params` to be a dictionary but got {type(params)}"
)
params = flatten_dict(params)
if params:
names, values = map(list, tuple(zip(*sorted(params.items()))))
else:
names, values = [], []
match include_stats:
case False:
return jax.tree_util.tree_map(_make_row, names, values)
case True:
mean_and_std = _mean_std(values)
return jax.tree_util.tree_map(
_make_row_with_stats, names, values, *mean_and_std
)
case "global":
mean_and_std = _mean_std_jit(values)
return jax.tree_util.tree_map(
_make_row_with_stats_and_sharding, names, values, *mean_and_std
)
case "sharding":
return jax.tree_util.tree_map(_make_row_with_sharding, names, values)
case _:
raise ValueError(f"Unknown `include_stats`: {include_stats}")
def _default_table_value_formatter(value):
"""Formats ints with "," between thousands and floats to 3 digits."""
if isinstance(value, bool):
return str(value)
elif isinstance(value, int):
return "{:,}".format(value)
elif isinstance(value, float):
return "{:.3}".format(value)
else:
return str(value)
def make_table(
rows: list[Any],
*,
column_names: Sequence[str] | None = None,
value_formatter: Callable[[Any], str] = _default_table_value_formatter,
max_lines: int | None = None,
) -> str:
"""Renders a list of rows to a table.
Args:
rows: List of dataclass instances of a single type (e.g. `ParamRow`).
column_names: List of columns that should be included in the output. If
not provided, then the columns are taken from the fields of the first row.
value_formatter: Callable used to format cell values.
max_lines: Don't render a table longer than this.
Returns:
A string representation of the table in the form:
+---------+---------+
| Col1    | Col2    |
+---------+---------+
| value11 | value12 |
| value21 | value22 |
+---------+---------+
"""
if any(not dataclasses.is_dataclass(row) for row in rows):
raise ValueError("Expected `rows` to be list of dataclasses")
if len(set(map(type, rows))) > 1:
raise ValueError("Expected elements of `rows` be of same type.")
class Column:
def __init__(self, name, values):
self.name = name.capitalize()
self.values = values
self.width = max(len(v) for v in values + [name])
if column_names is None:
if not rows:
return "(empty table)"
column_names = [field.name for field in dataclasses.fields(rows[0])]
columns = [
Column(name, [value_formatter(getattr(row, name)) for row in rows])
for name in column_names
]
var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns)
sep_line_format = var_line_format.replace(" ", "-").replace("|", "+")
header = var_line_format.replace(">", "<").format(*[c.name for c in columns])
separator = sep_line_format.format(*["" for c in columns])
lines = [separator, header, separator]
for i in range(len(rows)):
if max_lines and len(lines) >= max_lines - 3:
lines.append("[...]")
break
lines.append(var_line_format.format(*[c.values[i] for c in columns]))
lines.append(separator)
return "\n".join(lines)
def _get_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool | str = True,
max_lines: int | None = None,
) -> str:
"""See get_parameter_overview()."""
if include_stats is True and isinstance(params, (dict, flax.core.FrozenDict)): # pylint: disable=g-bool-id-comparison
params = jax.device_get(params) # A no-op if already numpy array.
rows = _get_parameter_rows(params, include_stats=include_stats)
RowType = { # pylint: disable=invalid-name
False: _ParamRow,
True: _ParamRowWithStats,
"global": _ParamRowWithStatsAndSharding,
"sharding": _ParamRowWithSharding,
}[include_stats]
# Pass in `column_names` to enable rendering empty tables.
column_names = [field.name for field in dataclasses.fields(RowType)]
table = make_table(rows, max_lines=max_lines, column_names=column_names)
total_weights = _count_parameters(params)
total_size = _parameters_size(params)
return table + f"\nTotal: {total_weights:,} -- {total_size:,} bytes"
def get_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool | str = True,
max_lines: int | None = None,
) -> str:
"""Returns a string with a table of variable names, shapes, and sizes.
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add a column with the sharding of the
variable. If the string "global", params are sharded global arrays and
this function assumes it is called on every host, i.e. can use
collectives. The sharding of the variables is also added as a column.
max_lines: If not `None`, truncate the table (marking the cut with "[...]")
so that it stays within roughly this many lines.
Returns:
A string with a table like the following (the exact columns depend on
`include_stats`):
+----------------+---------------+------------+
| Name           | Shape         | Size       |
+----------------+---------------+------------+
| FC_1/weights:0 | (63612, 1024) | 65,138,688 |
| FC_1/biases:0  | (1024,)       | 1,024      |
| FC_2/weights:0 | (1024, 32)    | 32,768     |
| FC_2/biases:0  | (32,)         | 32         |
+----------------+---------------+------------+
Total: 65,172,512
"""
return _get_parameter_overview(
params, include_stats=include_stats, max_lines=max_lines
)
def _log_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool | str = True,
max_lines: int | None = None,
msg: str | None = None,
jax_logging_process: int | None = None,
):
"""See log_parameter_overview()."""
table = _get_parameter_overview(
params, include_stats=include_stats, max_lines=max_lines
)
if jax_logging_process is None or jax_logging_process == jax.process_index():
lines = [msg] if msg else []
lines += table.split("\n")
# The table can be too large to fit into one log entry.
for i in range(0, len(lines), 80):
logging.info("\n%s", "\n".join(lines[i : i + 80]))
def log_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool | str = True,
max_lines: int | None = None,
msg: str | None = None,
jax_logging_process: int | None = None,
):
"""Writes a table with variable names and shapes to the INFO log.
See get_parameter_overview for details.
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable. If
the string "sharding", add a column with the sharding of the variable. If
the string "global", params are sharded global arrays and this function
assumes it is called on every host, i.e. can use collectives.
max_lines: If not `None`, truncate the table (marking the cut with "[...]")
so that it stays within roughly this many lines.
msg: Message to be logged before the overview.
jax_logging_process: Which JAX process ID should do the logging. None = all.
Use this to avoid logspam when include_stats="global".
"""
_log_parameter_overview(
params,
include_stats=include_stats,
max_lines=max_lines,
msg=msg,
jax_logging_process=jax_logging_process,
)
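The layout produced by `make_table` comes from two steps. First, every cell is padded to the widest entry in its column. Second, the row template is reused as the separator by swapping spaces for dashes and `|` for `+`. Below is a standalone sketch of that technique. The `Row` dataclass and `render` helper are hypothetical. Shape tuples stand in for arrays, so no JAX is needed, and the dtype, statistics, and byte-size columns are omitted.

```python
import dataclasses
import math
from typing import Any


def flatten_dict(d: dict[str, Any], prefix: str = "", delimiter: str = "/") -> dict[str, Any]:
    """Joins nested keys with `delimiter` (same idea as `flatten_dict` above)."""
    out = {}
    for key, value in d.items():
        nested_key = f"{prefix}{delimiter}{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten_dict(value, prefix=nested_key, delimiter=delimiter))
        else:
            out[nested_key] = value
    return out


@dataclasses.dataclass
class Row:  # Hypothetical stand-in for `_ParamRow` without the dtype column.
    name: str
    shape: tuple[int, ...]
    size: int


def render(rows: list[Row]) -> str:
    """Pads each column to its widest cell, following `make_table`."""
    names = [f.name for f in dataclasses.fields(Row)]
    # Ints get thousands separators; everything else is formatted with str().
    cells = [
        [f"{v:,}" if isinstance(v, int) else str(v)
         for v in (getattr(r, n) for n in names)]
        for r in rows
    ]
    widths = [max([len(n)] + [len(c[i]) for c in cells])
              for i, n in enumerate(names)]
    # " {: <w} |" pads with spaces; swapping " " -> "-" and "|" -> "+" turns the
    # same template into the separator (its fill character becomes "-").
    line = "|" + "".join(f" {{: <{w}}} |" for w in widths)
    sep = line.replace(" ", "-").replace("|", "+").format(*([""] * len(widths)))
    lines = [sep, line.format(*[n.capitalize() for n in names]), sep]
    lines += [line.format(*c) for c in cells]
    lines.append(sep)
    return "\n".join(lines)


shapes = {
    "FC_1": {"weights:0": (63612, 1024), "biases:0": (1024,)},
    "FC_2": {"weights:0": (1024, 32), "biases:0": (32,)},
}
# Sort by flattened name, as `_get_parameter_rows` does.
rows = [Row(name, shape, math.prod(shape))
        for name, shape in sorted(flatten_dict(shapes).items())]
table = render(rows) + f"\nTotal: {sum(r.size for r in rows):,}"
print(table)
```

Because rows are sorted by their flattened names, `FC_1/biases:0` is listed before `FC_1/weights:0`.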
================================================
FILE: clu/parameter_overview_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for parameter overviews."""
from absl.testing import absltest
from clu import parameter_overview
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np
EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+
| Name | Shape | Dtype | Size | Mean | Std |
+------+-------+-------+------+------+-----+
+------+-------+-------+------+------+-----+
Total: 0 -- 0 bytes"""
FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+---------+------+
| Name        | Shape        | Dtype   | Size |
+-------------+--------------+---------+------+
| conv/bias   | (2,)         | float32 | 2    |
| conv/kernel | (3, 3, 3, 2) | float32 | 54   |
+-------------+--------------+---------+------+
Total: 56 -- 224 bytes"""
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+
| Name        | Shape        | Dtype   | Size | Sharding |
+-------------+--------------+---------+------+----------+
| conv/bias   | (2,)         | float32 | 2    | ()       |
| conv/kernel | (3, 3, 3, 2) | float32 | 54   | ()       |
+-------------+--------------+---------+------+----------+
Total: 56 -- 224 bytes"""
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+
| Name        | Shape        | Dtype   | Size | Mean | Std |
+-------------+--------------+---------+------+------+-----+
| conv/bias   | (2,)         | float32 | 2    | 1.0  | 0.0 |
| conv/kernel | (3, 3, 3, 2) | float32 | 54   | 1.0  | 0.0 |
+-------------+--------------+---------+------+------+-----+
Total: 56 -- 224 bytes"""
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+
| Name        | Shape        | Dtype   | Size | Mean | Std | Sharding |
+-------------+--------------+---------+------+------+-----+----------+
| conv/bias   | (2,)         | float32 | 2    | 1.0  | 0.0 | ()       |
| conv/kernel | (3, 3, 3, 2) | float32 | 54   | 1.0  | 0.0 | ()       |
+-------------+--------------+---------+------+------+-----+----------+
Total: 56 -- 224 bytes"""
FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+
| Name               | Shape        | Dtype   | Size | Mean | Std |
+--------------------+--------------+---------+------+------+-----+
| params/conv/bias   | (2,)         | float32 | 2    | 1.0  | 0.0 |
| params/conv/kernel | (3, 3, 3, 2) | float32 | 54   | 1.0  | 0.0 |
+--------------------+--------------+---------+------+------+-----+
Total: 56 -- 224 bytes"""
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Conv(features=2, kernel_size=(3, 3), name="conv")(x)
class JaxParameterOverviewTest(absltest.TestCase):
def test_count_parameters_empty(self):
self.assertEqual(0, parameter_overview.count_parameters({}))
def test_count_parameters(self):
rng = jax.random.PRNGKey(42)
# Weights of a 2D convolution with 2 filters.
variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
# 3 * 3*3 * 2 + 2 (bias) = 56 parameters
self.assertEqual(56,
parameter_overview.count_parameters(variables["params"]))
def test_get_parameter_overview_empty(self):
self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
parameter_overview.get_parameter_overview({}))
self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
parameter_overview.get_parameter_overview({"a": {}}))
def test_get_parameter_overview(self):
rng = jax.random.PRNGKey(42)
# Weights of a 2D convolution with 2 filters.
variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
variables = jax.tree_util.tree_map(jnp.ones_like, variables)
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW,
parameter_overview.get_parameter_overview(
variables["params"], include_stats=False))
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS,
parameter_overview.get_parameter_overview(variables["params"]))
self.assertEqual(
FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS,
parameter_overview.get_parameter_overview(variables))
# Add sharding with PartitionSpecs.
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d")
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
variables = jax.jit(lambda x: x, out_shardings=sharding)(variables)
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING,
parameter_overview.get_parameter_overview(
variables["params"], include_stats="sharding"))
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING,
parameter_overview.get_parameter_overview(
variables["params"], include_stats="global"))
def test_get_parameter_overview_shape_dtype_struct(self):
variables_shape_dtype_struct = jax.eval_shape(
lambda: CNN().init(jax.random.PRNGKey(42), jnp.zeros((2, 5, 5, 3))))
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW,
parameter_overview.get_parameter_overview(
variables_shape_dtype_struct["params"], include_stats=False))
def test_printing_bool(self):
self.assertEqual(
parameter_overview._default_table_value_formatter(True), "True")
self.assertEqual(
parameter_overview._default_table_value_formatter(False), "False")
if __name__ == "__main__":
absltest.main()
================================================
FILE: clu/periodic_actions.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PeriodicActions execute small actions periodically in the training loop."""
import abc
import collections
import concurrent.futures
import contextlib
import functools
import os
import time
from typing import Callable, Iterable, Optional, Sequence
from absl import logging
from clu import asynclib
from clu import metric_writers
from clu import platform
from clu import profiler
from etils import epath
import jax
import jax.numpy as jnp
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
MetricWriter = metric_writers.MetricWriter
@jax.jit
def _squareit(x):
"""Minimal jitted function used as a dispatch barrier in `ReportProgress.timed()`."""
return x**2
def _format_secs(secs: float):
"""Formats seconds like 123456.7 to strings like "1d10h17m"."""
s = ""
days = int(secs / (3600 * 24))
secs -= days * 3600 * 24
if days:
s += f"{days}d"
hours = int(secs / 3600)
secs -= hours * 3600
if hours:
s += f"{hours}h"
mins = int(secs / 60)
s += f"{mins}m"
return s
class PeriodicAction(abc.ABC):
"""Abstract base class for periodic actions.
The idea is that the user creates periodic actions and calls them after
each training step. The base class triggers at fixed step/time intervals,
but subclasses can override `_should_trigger()` to change this behavior.
Subclasses must implement `_apply()` to perform the action.
"""
def __init__(self,
*,
every_steps: Optional[int] = None,
every_secs: Optional[float] = None,
on_steps: Optional[Iterable[int]] = None):
"""Creates an action that triggers periodically.
Args:
every_steps: If the current step is divisible by `every_steps`, then an
action is triggered.
every_secs: If no action has triggered for specified `every_secs`, then
an action is triggered. Note that the previous action might have been
triggered by `every_steps` or by `every_secs`.
on_steps: If the current step is included in this set, then an action is
triggered.
"""
self._every_steps = every_steps
self._every_secs = every_secs
self._on_steps = set(on_steps or [])
# Step and timestamp for the last time the action triggered.
self._previous_step: int = None
self._previous_time: float = None
# Just for checking that __call__() was called every step.
self._last_step: int = None
def _init_and_check(self, step: int, t: float):
"""Initializes and checks it was called at every step."""
if self._previous_step is None:
self._previous_step = step
self._previous_time = t
self._last_step = step
elif self._every_steps is not None and step - self._last_step != 1:
raise ValueError(f"PeriodicAction must be called after every step once "
f"(every_steps={self._every_steps}, "
f"previous_step={self._previous_step}, step={step}).")
else:
self._last_step = step
def _should_trigger(self, step: int, t: float) -> bool:
"""Return whether the action should trigger this step."""
if self._every_steps is not None and step % self._every_steps == 0:
return True
if (self._every_secs is not None and
t - self._previous_time > self._every_secs):
return True
if step in self._on_steps:
return True
return False
def _after_apply(self, step: int, t: float):
"""Called after each time the action triggered."""
self._previous_step = step
self._previous_time = t
def __call__(self, step: int, t: Optional[float] = None) -> bool:
"""Method to call the hook after every training step.
Args:
step: Current step.
t: Optional timestamp. Will use `time.monotonic()` if not specified.
Returns:
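True if the action triggered, False otherwise. Note that the first
invocation cannot trigger through `every_secs`, since no time has elapsed
yet; it can still trigger through `every_steps` or `on_steps`.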
"""
if t is None:
t = time.monotonic()
self._init_and_check(step, t)
if self._should_trigger(step, t):
self._apply(step, t)
self._after_apply(step, t)
return True
return False
@abc.abstractmethod
def _apply(self, step: int, t: float):
pass
class ReportProgress(PeriodicAction):
"""This hook will set the progress note on the work unit."""
def __init__(self,
*,
num_train_steps: Optional[int] = None,
writer: Optional[MetricWriter] = None,
every_steps: Optional[int] = None,
every_secs: Optional[float] = 60.0,
on_steps: Optional[Iterable[int]] = None):
"""Creates a new ReportProgress hook.
Reports a progress summary via `platform.work_unit().set_notes()` and, if
`writer` is given, writes some additional metrics:
- "uptime": seconds since this hook was created
- "steps_per_sec": point estimate of steps/sec
Args:
num_train_steps: The total number of training steps for training.
writer: Optional MetricWriter to report steps_per_sec measurement. This is
an estimate; for precise values use Xprof.
every_steps: How often to report the progress in number of training steps.
every_secs: How often to report progress as time interval.
on_steps: Report the progress on these training steps.
"""
on_steps = set(on_steps or [])
if num_train_steps is not None:
on_steps.add(num_train_steps)
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps)
# Check for negative values, e.g. tf.data.UNKNOWN/INFINITE_CARDINALITY.
if num_train_steps is not None and num_train_steps < 0:
num_train_steps = None
self._num_train_steps = num_train_steps
self._writer = writer
self._time_per_part = collections.defaultdict(float)
self._t0 = time.monotonic()
# Using max_workers=1 guarantees that the start/stop measurements submitted
# by `timed()` happen sequentially.
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self._persistent_notes = ""
def set_persistent_notes(self, message: str):
"""Sets the persistent notes for this work unit (not overwritten by the periodic action)."""
self._persistent_notes = message
def _should_trigger(self, step: int, t: float) -> bool:
# Note: step == self._previous_step is only True on the first step.
return step != self._previous_step and super()._should_trigger(step, t)
def _apply(self, step: int, t: float):
steps_per_sec = (step - self._previous_step) / (t - self._previous_time)
message = f"{steps_per_sec:.1f} steps/s"
if self._num_train_steps:
eta_seconds = (self._num_train_steps - step) / steps_per_sec
message += (f", {100 * step / self._num_train_steps:.1f}% "
f"({step}/{self._num_train_steps}), "
f"ETA: {_format_secs(eta_seconds)}")
if self._time_per_part:
total = time.monotonic() - self._t0
message += " ({} : {})".format(_format_secs(total), ", ".join(
f"{100 * dt / total:.1f}% {name}"
for name, dt in sorted(self._time_per_part.items())))
# This should be relatively cheap so we can do it in the same main thread.
if self._persistent_notes:
message = f"{self._persistent_notes}\n{message}"
platform.work_unit().set_notes(message)
if self._writer is not None:
self._writer.write_scalars(step, {"steps_per_sec": steps_per_sec})
self._writer.write_scalars(step, {"uptime": time.monotonic() - self._t0})
@contextlib.contextmanager
def timed(self, name: str, wait_jax_async_dispatch: bool = True):
# pylint: disable=g-doc-return-or-yield
"""Measures time spent in a named part of the training loop.
The reported progress will break down the total time into the different
parts spent inside blocks.
Example:
report_progress = periodic_actions.ReportProgress()
for step, batch in enumerate(train_iter):
params = train_step(params, batch)
report_progress(step + 1)
if (step + 1) % eval_every_steps == 0:
with report_progress.timed("eval"):
evaluate()
The above example would result in the progress being reported as something
like "20% @2000 ... (5 min : 10% eval)" - assuming that evaluation takes 10%
of the entire time in this case.
Args:
name: Name of the part to be measured.
wait_jax_async_dispatch: When set to `True`, JAX async dispatch queue will
be emptied by creating a new computation and waiting for its completion.
This makes sure that previous computations (e.g. the last train step)
have actually finished. The same is done before the time is measured.
Note that this wait happens in a different thread that is only used for
measuring start/stop time of timed parts. In other words, the measured
timings reflect the start/stop of the JAX computations within the
measured part: the timer is started when the last computation before the
block has finished, and the timer is stopped when the last computation
from within the block has finished. Note that because JAX executes these
operations asynchronously, the measured time might overlap with non-JAX
computations outside the measured block.
When set to `False`, then the measured time is of the Python statements
within the block.
If there are no expensive JAX computations enqueued in JAX's async
dispatch queue, then both measurements are identical.
"""
# pylint: enable=g-doc-return-or-yield
if not wait_jax_async_dispatch:
# Easy case, just measure walltime.
start = time.monotonic()
yield
self._time_per_part[name] += time.monotonic() - start
return
def start_measurement(barrier: jax.Array) -> float:
barrier.block_until_ready()
return time.monotonic()
def stop_measurement(
start_future: concurrent.futures.Future[float], barrier: jax.Array
):
barrier.block_until_ready()
self._time_per_part[name] += time.monotonic() - start_future.result()
# Call _squareit on this thread so that it is guaranteed to be dispatched
# to the TPU before any computations inside `yield`.
start_future = self._executor.submit(
start_measurement, barrier=_squareit(jnp.array(0.0))
)
yield
# Same pattern: _squareit is dispatched after any programs dispatched from
# within `yield` and before any programs following this method. The time
# difference between the completion of the first _squareit and this one
# is the time the TPU spent executing programs dispatched from within
# `yield`.
self._executor.submit(
stop_measurement,
start_future=start_future,
barrier=_squareit(jnp.array(0.0)),
)
class Profile(PeriodicAction):
"""This hook calls profiler.start()/stop() every time it triggers.
"""
def __init__(
self,
*,
logdir: epath.PathLike,
num_profile_steps: Optional[int] = 5,
profile_duration_ms: Optional[int] = 3_000,
first_profile: int = 10,
every_steps: Optional[int] = None,
every_secs: Optional[float] = 3600.0,
on_steps: Optional[Iterable[int]] = None,
artifact_name: str = "[{step}] Profile",
):
"""Initializes a new periodic profiler action.
Args:
logdir: Where the profile should be stored (required for
`tf.profiler.experimental`).
num_profile_steps: Over how many steps the profile should be taken. Note
that when specifying both num_profile_steps and profile_duration_ms then
both conditions will be fulfilled.
profile_duration_ms: Minimum duration of profile.
first_profile: First step at which a profile is started.
every_steps: See `PeriodicAction.__init__()`.
every_secs: See `PeriodicAction.__init__()`.
on_steps: See `PeriodicAction.__init__()`.
artifact_name: Name of the artifact to record.
"""
if not num_profile_steps and not profile_duration_ms:
raise ValueError(
"Must specify num_profile_steps and/or profile_duration_ms.")
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps
)
self._num_profile_steps = num_profile_steps
self._first_profile = first_profile
self._profile_duration_ms = profile_duration_ms
self._session_running = False
self._session_started = None
self._logdir = os.fspath(logdir)
self._artifact_name = artifact_name
def _should_trigger(self, step: int, t: float) -> bool:
if self._session_running:
# If a session is running we only check if we should stop it.
dt = t - self._session_started
cond = (not self._profile_duration_ms or
dt * 1e3 >= self._profile_duration_ms)
cond &= (not self._num_profile_steps or
step >= self._previous_step + self._num_profile_steps)
if cond:
self._end_session(profiler.stop())
return False
# Allow triggering at `self._first_profile` step.
return super()._should_trigger(step, t) or step == self._first_profile
def _apply(self, step: int, t: float):
del step, t # Unused.
self._start_session()
def _start_session(self):
try:
profiler.start(logdir=self._logdir)
self._session_running = True
self._session_started = time.monotonic()
except Exception as e: # pylint: disable=broad-except
logging.exception("Could not start profiling: %s", e)
def _end_session(self, url: Optional[str]):
platform.work_unit().create_artifact(
platform.ArtifactType.URL,
url,
description=self._artifact_name.format(step=self._previous_step))
self._session_running = False
self._session_started = None
class ProfileAllHosts(PeriodicAction):
"""This hook calls profiler.collect() every time it triggers.
"""
def __init__(self,
*,
logdir: str,
hosts: Optional[Sequence[str]] = None,
profile_duration_ms: int = 3_000,
first_profile: int = 10,
every_steps: Optional[int] = None,
every_secs: Optional[float] = 3600.0,
on_steps: Optional[Iterable[int]] = None):
"""Initializes a new periodic profiler action.
Args:
logdir: Where the profile should be stored (required for
`tf.profiler.experimental`).
hosts: Addresses of the hosts. If omitted will default to the current job.
profile_duration_ms: Duration of profile.
first_profile: First step at which a profile is started.
every_steps: See `PeriodicAction.__init__()`.
every_secs: See `PeriodicAction.__init__()`.
on_steps: See `PeriodicAction.__init__()`.
"""
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps
)
self._hosts = hosts
self._first_profile = first_profile
self._profile_duration_ms = profile_duration_ms
self._logdir = logdir
def _should_trigger(self, step: int, t: float) -> bool:
return super()._should_trigger(step, t) or step == self._first_profile
def _apply(self, step: int, t: float):
del step, t # Unused.
self._start_session()
def _start_session(self):
profiler.collect(
logdir=self._logdir,
# Callback is executed asynchronously, so bind `self._previous_step`
callback=functools.partial(self._end_session, step=self._previous_step),
hosts=self._hosts,
duration_ms=self._profile_duration_ms,
)
def _end_session(self, url: Optional[str], *, step: int):
platform.work_unit().create_artifact(
platform.ArtifactType.URL,
url,
description=f"[{step}] Profile",
)
class PeriodicCallback(PeriodicAction):
"""This hook calls a callback function each time it triggers."""
def __init__(self,
*,
every_steps: Optional[int] = None,
every_secs: Optional[float] = None,
on_steps: Optional[Iterable[int]] = None,
callback_fn: Callable,
execute_async: bool = False,
pass_step_and_time: bool = True):
"""Initializes a new periodic Callback action.
Args:
every_steps: See `PeriodicAction.__init__()`.
every_secs: See `PeriodicAction.__init__()`.
on_steps: See `PeriodicAction.__init__()`.
callback_fn: A callback function. If `pass_step_and_time` is True, it must
accept `step` and `t` as keyword arguments. Any extra keyword arguments
given to `__call__()` are forwarded to it.
execute_async: If True, wraps the callback into an async call.
pass_step_and_time: If True, `step` and `t` are passed to the callback.
"""
super().__init__(
every_steps=every_steps, every_secs=every_secs, on_steps=on_steps)
self._cb_results = collections.deque(maxlen=1)
self.pass_step_and_time = pass_step_and_time
if execute_async:
logging.info("Callback will be executed asynchronously. "
"Errors are raised when they become available.")
self._cb_fn = asynclib.Pool(callback_fn.__name__)(callback_fn)
else:
self._cb_fn = callback_fn
def __call__(self, step: int, t: Optional[float] = None, **kwargs) -> bool:
if t is None:
t = time.monotonic()
self._init_and_check(step, t)
if self._should_trigger(step, t):
# Additional arguments to the callback are passed here through **kwargs.
self._apply(step, t, **kwargs)
self._after_apply(step, t)
return True
return False
def get_last_callback_result(self):
"""Returns the result of the last callback invocation.
If `execute_async` is True, this is a future rather than the plain value.
"""
return self._cb_results[0]
def _apply(self, step, t, **kwargs):
if self.pass_step_and_time:
result = self._cb_fn(step=step, t=t, **kwargs)
else:
result = self._cb_fn(**kwargs)
self._cb_results.append(result)
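A simplified, standard-library sketch of how `PeriodicCallback` forwards extra keyword arguments and keeps only the latest result. `SimplePeriodicCallback` is hypothetical: it assumes a trigger on multiples of `every_steps` and omits the `every_secs`/`on_steps` logic, async execution, and the step-continuity checks that `PeriodicAction` performs.

```python
import collections
import time
from typing import Any, Callable, Optional


class SimplePeriodicCallback:
  """Hypothetical, simplified stand-in for PeriodicCallback (steps only)."""

  def __init__(self, *, every_steps: int, callback_fn: Callable[..., Any],
               pass_step_and_time: bool = True):
    self._every_steps = every_steps
    self._cb_fn = callback_fn
    self._pass_step_and_time = pass_step_and_time
    # Keep only the most recent result, like `_cb_results` above.
    self._cb_results = collections.deque(maxlen=1)

  def __call__(self, step: int, t: Optional[float] = None, **kwargs) -> bool:
    if t is None:
      t = time.monotonic()
    # Assumption: trigger on multiples of `every_steps`.
    if step % self._every_steps != 0:
      return False
    if self._pass_step_and_time:
      result = self._cb_fn(step=step, t=t, **kwargs)
    else:
      result = self._cb_fn(**kwargs)
    self._cb_results.append(result)
    return True

  def get_last_callback_result(self):
    return self._cb_results[0]


seen = []


def record(step, t, remainder):
  del t  # Unused.
  seen.append((step, remainder))
  return step


hook = SimplePeriodicCallback(every_steps=2, callback_fn=record)
for step in range(1, 10):
  hook(step, 3.0, remainder=step % 3)

print(seen)  # [(2, 2), (4, 1), (6, 0), (8, 2)]
print(hook.get_last_callback_result())  # 8
```

With `every_steps=2`, the sketch produces the same step/remainder pairs that `PeriodicCallbackTest.test_every_steps` expects.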
================================================
FILE: clu/periodic_actions_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for periodic actions."""
import tempfile
import time
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from clu import periodic_actions
class ReportProgressTest(parameterized.TestCase):
def test_every_steps(self):
hook = periodic_actions.ReportProgress(
every_steps=4, every_secs=None, num_train_steps=10
)
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(hook(1, t))
t += 0.11
self.assertFalse(hook(2, t))
t += 0.13
self.assertFalse(hook(3, t))
t += 0.12
self.assertTrue(hook(4, t))
# We did 1 step every 0.12s => 8.333 steps/s.
self.assertEqual(
logs.output,
[
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
" ETA: 0m"
],
)
def test_every_secs(self):
hook = periodic_actions.ReportProgress(
every_steps=None, every_secs=0.3, num_train_steps=10
)
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(hook(1, t))
t += 0.11
self.assertFalse(hook(2, t))
t += 0.13
self.assertFalse(hook(3, t))
t += 0.12
self.assertTrue(hook(4, t))
# We did 1 step every 0.12s => 8.333 steps/s.
self.assertEqual(
logs.output,
[
"INFO:absl:Setting work unit notes: 8.3 steps/s, 40.0% (4/10),"
" ETA: 0m"
],
)
def test_without_num_train_steps(self):
report = periodic_actions.ReportProgress(every_steps=2)
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(report(1, t))
self.assertTrue(report(2, t + 0.12))
# We did 1 step in 0.12s => 8.333 steps/s.
self.assertEqual(
logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
)
def test_with_persistent_notes(self):
report = periodic_actions.ReportProgress(every_steps=2)
report.set_persistent_notes("Hello world")
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(report(1, t))
self.assertTrue(report(2, t + 0.12))
# We did 1 step in 0.12s => 8.333 steps/s.
self.assertEqual(
logs.output,
["INFO:absl:Setting work unit notes: Hello world\n8.3 steps/s"],
)
def test_unknown_cardinality(self):
report = periodic_actions.ReportProgress(every_steps=2)
t = time.monotonic()
with self.assertLogs(level="INFO") as logs:
self.assertFalse(report(1, t))
self.assertTrue(report(2, t + 0.12))
# We did 1 step in 0.12s => 8.333 steps/s.
self.assertEqual(
logs.output, ["INFO:absl:Setting work unit notes: 8.3 steps/s"]
)
def test_called_every_step(self):
hook = periodic_actions.ReportProgress(every_steps=3, num_train_steps=10)
t = time.monotonic()
with self.assertRaisesRegex(
ValueError, "PeriodicAction must be called after every step"
):
hook(1, t)
hook(11, t) # Raises exception.
@parameterized.named_parameters(
("_nowait", False),
("_wait", True),
)
@mock.patch("time.monotonic")
def test_named(self, wait_jax_async_dispatch, mock_time):
mock_time.return_value = 0
hook = periodic_actions.ReportProgress(
every_steps=1, every_secs=None, num_train_steps=10
)
def _wait():
# Here we depend on hook._executor=ThreadPoolExecutor(max_workers=1)
hook._executor.submit(lambda: None).result()
self.assertFalse(hook(1)) # Never triggers on first execution.
with hook.timed("test1", wait_jax_async_dispatch):
_wait()
mock_time.return_value = 1
_wait()
with hook.timed("test2", wait_jax_async_dispatch):
_wait()
mock_time.return_value = 2
_wait()
with hook.timed("test1", wait_jax_async_dispatch):
_wait()
mock_time.return_value = 3
_wait()
mock_time.return_value = 4
with self.assertLogs(level="INFO") as logs:
self.assertTrue(hook(2))
self.assertEqual(
logs.output,
[
"INFO:absl:Setting work unit notes: 0.2 steps/s, 20.0% (2/10), ETA:"
" 0m (0m : 50.0% test1, 25.0% test2)"
],
)
@mock.patch("time.monotonic")
def test_write_metrics(self, time_mock):
time_mock.return_value = 0
writer_mock = mock.Mock()
hook = periodic_actions.ReportProgress(
every_steps=2, every_secs=None, writer=writer_mock
)
time_mock.return_value = 1
hook(1)
time_mock.return_value = 2
hook(2)
self.assertEqual(
writer_mock.write_scalars.mock_calls,
[
mock.call(2, {"steps_per_sec": 1}),
mock.call(2, {"uptime": 2}),
],
)
class DummyProfilerSession:
"""Dummy Profiler that records the steps at which sessions started/ended."""
def __init__(self):
self.step = None
self.start_session_call_steps = []
self.end_session_call_steps = []
def start_session(self):
self.start_session_call_steps.append(self.step)
def end_session_and_get_url(self, tag):
del tag
self.end_session_call_steps.append(self.step)
class ProfileTest(absltest.TestCase):
@mock.patch.object(periodic_actions, "profiler", autospec=True)
@mock.patch("time.monotonic")
def test_every_steps(self, mock_time, mock_profiler):
start_steps = []
stop_steps = []
step = 0
def add_start_step(logdir):
del logdir # unused
start_steps.append(step)
def add_stop_step():
stop_steps.append(step)
mock_profiler.start.side_effect = add_start_step
mock_profiler.stop.side_effect = add_stop_step
hook = periodic_actions.Profile(
logdir=tempfile.mkdtemp(),
num_profile_steps=2,
profile_duration_ms=2_000,
first_profile=3,
every_steps=7,
)
for step in range(1, 18):
mock_time.return_value = step - 0.5 if step == 9 else step
hook(step)
self.assertEqual([3, 7, 14], start_steps)
# Note: profiling 7..10 instead of 7..9 because 7..9 took only 1.5 seconds.
self.assertEqual([5, 10, 16], stop_steps)
class ProfileAllHostsTest(absltest.TestCase):
@mock.patch.object(periodic_actions, "profiler", autospec=True)
def test_every_steps(self, mock_profiler):
start_steps = []
step = 0
def profile_collect(logdir, callback, hosts, duration_ms):
del logdir, callback, hosts, duration_ms # unused
start_steps.append(step)
mock_profiler.collect.side_effect = profile_collect
hook = periodic_actions.ProfileAllHosts(
logdir=tempfile.mkdtemp(),
profile_duration_ms=2_000,
first_profile=3,
every_steps=7,
)
for step in range(1, 18):
hook(step)
self.assertEqual([3, 7, 14], start_steps)
class PeriodicCallbackTest(absltest.TestCase):
def test_every_steps(self):
callback = mock.Mock()
hook = periodic_actions.PeriodicCallback(
every_steps=2, callback_fn=callback
)
for step in range(1, 10):
hook(step, 3, remainder=step % 3)
expected_calls = [
mock.call(remainder=2, step=2, t=3),
mock.call(remainder=1, step=4, t=3),
mock.call(remainder=0, step=6, t=3),
mock.call(remainder=2, step=8, t=3),
]
self.assertListEqual(expected_calls, callback.call_args_list)
@mock.patch("time.monotonic")
def test_every_secs(self, mock_time):
callback = mock.Mock()
hook = periodic_actions.PeriodicCallback(every_secs=2, callback_fn=callback)
for step in range(1, 10):
mock_time.return_value = float(step)
hook(step, remainder=step % 5)
# Note: time will be initialized at 1 so hook runs at steps 4 & 7.
expected_calls = [
mock.call(remainder=4, step=4, t=4.0),
mock.call(remainder=2, step=7, t=7.0),
]
self.assertListEqual(expected_calls, callback.call_args_list)
def test_on_steps(self):
callback = mock.Mock()
hook = periodic_actions.PeriodicCallback(on_steps=[8], callback_fn=callback)
for step in range(1, 10):
hook(step, remainder=step % 3)
callback.assert_called_once_with(remainder=2, step=8, t=mock.ANY)
def test_async_execution(self):
out = []
def cb(step, t):
del t
out.append(step)
hook = periodic_actions.PeriodicCallback(
every_steps=1, callback_fn=cb, execute_async=True
)
hook(0)
hook(1)
hook(2)
hook(3)
# Block till all the hooks have finished.
hook.get_last_callback_result().result()
# Check order of execution is preserved.
self.assertListEqual(out, [0, 1, 2, 3])
def test_error_async_is_forwarded(self):
def cb(step, t):
del step
del t
raise Exception
hook = periodic_actions.PeriodicCallback(
every_steps=1, callback_fn=cb, execute_async=True
)
hook(0)
with self.assertRaises(Exception):
hook(1)
def test_function_without_step_and_time(self):
# This must be used with pass_step_and_time=False.
def cb():
return 5
hook = periodic_actions.PeriodicCallback(
every_steps=1, callback_fn=cb, pass_step_and_time=False
)
hook(0)
hook(1)
self.assertEqual(hook.get_last_callback_result(), 5)
if __name__ == "__main__":
absltest.main()
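Several tests above patch `time.monotonic` so that throughput figures such as `8.3 steps/s` are deterministic. The following is a minimal sketch of that technique applied to a hypothetical `ThroughputMeter` class (not part of CLU). The patch takes effect because the code looks up `time.monotonic` on the module each time it is called.

```python
import time
from unittest import mock


class ThroughputMeter:
  """Hypothetical helper: steps per second between successive updates."""

  def __init__(self):
    self._last_step = None
    self._last_time = None

  def update(self, step: int):
    # `time.monotonic` is looked up at call time, so mock.patch can replace it.
    now = time.monotonic()
    rate = None
    if self._last_step is not None:
      rate = (step - self._last_step) / (now - self._last_time)
    self._last_step, self._last_time = step, now
    return rate


with mock.patch("time.monotonic") as mock_time:
  meter = ThroughputMeter()
  mock_time.return_value = 1.0
  first = meter.update(1)   # No rate on the first call.
  mock_time.return_value = 2.0
  second = meter.update(3)  # 2 steps in 1 second.

print(first, second)  # None 2.0
```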
================================================
FILE: clu/platform/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Methods for interacting with the experiment platform.
Use cases include informing the platform of the experiment status and providing
a platform independent interface for interactions.
"""
import threading
from clu.platform.interface import ArtifactType
from clu.platform.interface import WorkUnit
from clu.platform.local import LocalWorkUnit
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
_work_unit = None
_work_unit_lock = threading.Lock()
def work_unit() -> WorkUnit:
"""Gets the global work unit for this experiment trial."""
global _work_unit
if _work_unit is None:
with _work_unit_lock:
# Re-check under the lock: another thread may have created it meanwhile.
if _work_unit is None:
_work_unit = LocalWorkUnit()
return _work_unit
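`work_unit()` lazily creates a process-wide `LocalWorkUnit` guarded by a lock, which is the double-checked locking pattern. Below is a self-contained sketch of that pattern, using a hypothetical `FakeWorkUnit` in place of `LocalWorkUnit`. The second `None` check inside the lock prevents two threads that both passed the first check from each constructing an instance.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class FakeWorkUnit:
  """Hypothetical stand-in for LocalWorkUnit."""


_instance = None
_instance_lock = threading.Lock()


def get_instance() -> FakeWorkUnit:
  global _instance
  # Fast path: once the instance exists, no lock is taken.
  if _instance is None:
    with _instance_lock:
      # Re-check under the lock: another thread may have created it while
      # this one was waiting.
      if _instance is None:
        _instance = FakeWorkUnit()
  return _instance


with ThreadPoolExecutor(max_workers=8) as pool:
  results = list(pool.map(lambda _: get_instance(), range(100)))

print(len({id(r) for r in results}))  # 1
```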
================================================
FILE: clu/platform/interface.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for work units."""
import abc
import enum
from typing import Any
class ArtifactType(enum.Enum):
# A URL for dashboards, etc.
URL = 1
# File path.
FILE = 2
# Directory path.
DIRECTORY = 3
class WorkUnit(abc.ABC):
"""A work unit represents a single trial in an experiment.
Experiments usually have multiple work units with different
hyperparameters. Each work unit can have multiple jobs (training,
evaluation, etc.), and jobs can have multiple tasks when training
is distributed across multiple machines.
"""
@property
@abc.abstractmethod
def experiment_id(self):
"""ID of the experiment of the work unit."""
@property
@abc.abstractmethod
def id(self):
"""Unique identifier for the work unit."""
@property
def name(self):
"""Returns the name of the work unit as <XID>/<WID>.
XID is the ID of the experiment and WID is the number of the work unit
within the experiment.
Returns:
The work unit name, e.g. 12345/1.
"""
return f"{self.experiment_id}/{self.id}"
@abc.abstractmethod
def set_notes(self, msg: str):
"""Sets the notes for this work unit. These are displayed in the UI."""
@abc.abstractmethod
def set_task_status(self, msg: str):
"""Sets the status string for this task."""
@abc.abstractmethod
def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
description: str):
"""Creates an artifact entry for the work unit."""
================================================
FILE: clu/platform/local.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation for platform functionality when running locally."""
from typing import Any
from absl import logging
from clu.platform import interface
WorkUnit = interface.WorkUnit
ArtifactType = interface.ArtifactType
class LocalWorkUnit(WorkUnit):
"""Dummy work unit for running locally."""
@property
def experiment_id(self):
"""ID of the experiment of the work unit."""
return -1
@property
def id(self):
"""Unique identifier for the work unit."""
return -1
def set_notes(self, msg: str):
"""Sets the notes for this work unit."""
logging.info("Setting work unit notes: %s", msg)
def set_task_status(self, msg: str):
"""Sets the status string for this task."""
logging.info("Setting task status: %s", msg)
def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
description: str):
"""Creates an artifact entry for the work unit."""
logging.info("Created artifact %s of type %s and value %s.", description,
artifact_type, artifact)
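`LocalWorkUnit` only logs what it receives. In tests, it can be more convenient to record the calls instead. The sketch below re-declares trimmed copies of `ArtifactType` and `WorkUnit` so that it runs without CLU installed; `RecordingWorkUnit` is hypothetical. The sketch also shows that the base class derives `name` from `experiment_id` and `id`.

```python
import abc
import enum
from typing import Any, List, Tuple


class ArtifactType(enum.Enum):
  # Trimmed copy of clu.platform.interface.ArtifactType.
  URL = 1
  FILE = 2
  DIRECTORY = 3


class WorkUnit(abc.ABC):
  """Trimmed copy of clu.platform.interface.WorkUnit."""

  @property
  @abc.abstractmethod
  def experiment_id(self):
    """ID of the experiment of the work unit."""

  @property
  @abc.abstractmethod
  def id(self):
    """Unique identifier for the work unit."""

  @property
  def name(self):
    return f"{self.experiment_id}/{self.id}"

  @abc.abstractmethod
  def set_notes(self, msg: str):
    """Sets the notes for this work unit."""

  @abc.abstractmethod
  def set_task_status(self, msg: str):
    """Sets the status string for this task."""

  @abc.abstractmethod
  def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
                      description: str):
    """Creates an artifact entry for the work unit."""


class RecordingWorkUnit(WorkUnit):
  """Hypothetical work unit that stores calls in memory instead of logging."""

  def __init__(self, experiment_id: int, wid: int):
    self._experiment_id = experiment_id
    self._id = wid
    self.notes: List[str] = []
    self.statuses: List[str] = []
    self.artifacts: List[Tuple[ArtifactType, Any, str]] = []

  @property
  def experiment_id(self):
    return self._experiment_id

  @property
  def id(self):
    return self._id

  def set_notes(self, msg: str):
    self.notes.append(msg)

  def set_task_status(self, msg: str):
    self.statuses.append(msg)

  def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
                      description: str):
    self.artifacts.append((artifact_type, artifact, description))


wu = RecordingWorkUnit(experiment_id=12345, wid=1)
wu.set_notes("8.3 steps/s")
wu.create_artifact(ArtifactType.URL, "http://example.com/profile",
                   "[3] Profile")
print(wu.name)   # 12345/1
print(wu.notes)  # ['8.3 steps/s']
```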
================================================
FILE: clu/preprocess_spec.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for parsing a preprocessing spec.
A preprocessing spec is a list of preprocessing ops separated by '|' that can be
applied sequentially as a preprocessing function. The preprocessing ops are
provided as input and must implement the PreprocessOp protocol. While not
strictly required, we also recommend annotating preprocess ops as dataclasses.
By convention the preprocessing function operates on dictionaries of features.
Each op can change the dictionary by modifying, adding or removing dictionary
entries. Dictionary entries should be tensors, keys should be strings.
(For common data types we recommend using the feature keys used in TFDS.)
Example spec: 'fn1|fn2(3)|fn3(keyword=5)'
This will construct the following preprocessing function:
def preprocess_fn(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
features = fn1(features)
features = fn2(features, 3)
features = fn3(features, keyword=5)
return features
See preprocess_spec_test.py for some simple examples.
"""
import abc
import ast
import dataclasses
import inspect
import re
import sys
from typing import Dict, List, Sequence, Tuple, Type, TypeVar, Union
from absl import logging
from flax import traverse_util
import jax.numpy as jnp
import tensorflow as tf
import typing_extensions
from typing_extensions import Protocol
# Feature dictionary. Arbitrary nested dictionary with string keys and
# tf.Tensor as leaves.
Tensor = Union[tf.Tensor, tf.RaggedTensor, tf.SparseTensor]
# TFDS allows for nested `Features` ...
Features = Dict[str, Union[Tensor, "Features"]]
# ... but it's usually a better idea NOT to nest them. Also better for PyType.
FlatFeatures = Dict[str, Tensor]
D = TypeVar("D", FlatFeatures, tf.data.Dataset)
# Feature name for the random seed for tf.random.stateless_* ops. By
# convention, ops should split off their random seed and keep the SEED_KEY
# feature:
# ```
# features[SEED_KEY], seed = tf.unstack(
# tf.random.experimental.stateless_split(features[SEED_KEY]))
# ```
SEED_KEY = "_seed"
# Regex that finds upper case characters.
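_CAMEL_CASE_RGX = re.compile(r"(?<!^)(?=[A-Z])")
@typing_extensions.runtime_checkable
class PreprocessOp(Protocol):
def __call__(self, features: Features) -> Features: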
"""Applies the preprocessing op to the features."""
# Deprecated. Please use `grain.tensorflow.MapTransform`.
class MapTransform(abc.ABC):
"""Base class for transformations of single elements.
This class implements the PreprocessOp interface and also:
- Limits the features to a flat dictionary (instead of an arbitrary nested
dictionary).
- Provides a convenient implementation of `__call__` that can automatically
apply the single transformation to a single example (`FlatFeatures`) or a
`tf.data.Dataset`. The latter is convenient for SeqIO users migrating to
preprocess_spec.py. For multiple transformations, we still recommend using
the `PreprocessFn` class.
- Requires subclasses to be dataclasses.
"""
def __new__(cls, *args, **kwargs):
del args, kwargs
# Check that our subclass instance is a dataclass. We cannot do this with
# `__init_subclass__` because the dataclasses.dataclass decorator wraps
# the intermediate class which is a subclass of MapTransform but not a
# dataclass.
if not dataclasses.is_dataclass(cls):
raise ValueError(
f"Class {cls} is not a dataclass. We strongly recommend annotating "
"transformations with `@dataclasses.dataclass(frozen=True)`.")
return super().__new__(cls)
def __call__(self, features: D) -> D:
"""Applies the transformation to the features or the dataset."""
logging.warning("clu.preprocess_spec.MapTransform is deprecated. Please "
"switch to grain.tensorflow.MapTransform.")
if isinstance(features, tf.data.Dataset):
return features.map(self._transform, num_parallel_calls=tf.data.AUTOTUNE)
return self._transform(features)
@abc.abstractmethod
def _transform(self, features: FlatFeatures) -> FlatFeatures:
"""Transforms the features."""
# Deprecated. Please use `grain.tensorflow.RandomMapTransform`.
class RandomMapTransform(MapTransform, abc.ABC):
"""Base class for random transformations of single elements.
We require all random transformations to use stateless random operations (e.g.
`tf.random.stateless_uniform()`) and respect the provided random seed. The
user can expect the random seed to be unique for the element.
If multiple random seeds are required the user can split the seed into N
new seeds:
```
seeds = tf.unstack(tf.random.experimental.stateless_split(seed, N))
```
"""
def __call__(self, features: D) -> D:
logging.warning("clu.preprocess_spec.RandomMapTransform is deprecated. "
"Please switch to grain.tensorflow.RandomMapTransform.")
if isinstance(features, tf.data.Dataset):
return features.map(self, num_parallel_calls=tf.data.AUTOTUNE)
next_seed, seed = tf.unstack(
tf.random.experimental.stateless_split(features.pop(SEED_KEY)))
features = self._transform(features, seed)
features[SEED_KEY] = next_seed
return features
@abc.abstractmethod
def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures: # pytype: disable=signature-mismatch # overriding-parameter-count-checks
"""Transforms the features only using stateless random ops."""
# Deprecated. Please use `grain.tensorflow.FilterTransform`.
class FilterTransform(abc.ABC):
def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
logging.warning("clu.preprocess_spec.FilterTransform is deprecated. Please "
"switch to grain.tensorflow.FilterTransform.")
return dataset.filter(self._predicate)
@abc.abstractmethod
def _predicate(self, features: FlatFeatures) -> tf.Tensor:
"""Returns True if the element should be kept."""
def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]:
"""Helper to return all preprocess ops in a module.
Modules that define processing ops can simply define:
all_ops = lambda: preprocess_spec.get_all_ops(__name__)
all_ops() will then return a list with all dataclasses implementing the
PreprocessOp protocol.
Args:
module_name: Name of the module. The module must already be imported.
Returns:
List of tuples of process ops. The first tuple element is the class name
converted to snake case (MyAwesomeTransform => my_awesome_transform) and
the second element is the class.
"""
def is_op(x):
return (inspect.isclass(x) and dataclasses.is_dataclass(x) and
issubclass(x, PreprocessOp))
op_name = lambda n: _CAMEL_CASE_RGX.sub("_", n).lower()
members = inspect.getmembers(sys.modules[module_name])
return [(op_name(name), op) for name, op in members if is_op(op)]
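`get_all_ops` names each op by converting its class name to snake case with `_CAMEL_CASE_RGX`. A common pattern that performs the documented conversion (`MyAwesomeTransform` => `my_awesome_transform`) is `(?<!^)(?=[A-Z])`. It matches the empty position before every upper-case letter except at the start of the string. The sketch below assumes that pattern.

```python
import re

# Assumed pattern: an empty match before every upper-case letter except at
# the very start of the string.
_CAMEL_CASE_RGX = re.compile(r"(?<!^)(?=[A-Z])")


def op_name(class_name: str) -> str:
  """Converts a CamelCase class name to a snake_case op name."""
  return _CAMEL_CASE_RGX.sub("_", class_name).lower()


print(op_name("MyAwesomeTransform"))  # my_awesome_transform
print(op_name("AddRandomInteger"))    # add_random_integer
print(op_name("HTTPServer"))          # h_t_t_p_server
```

With this pattern, runs of capitals are split letter by letter, so acronym-heavy class names produce awkward op names.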
def _jax_supported_tf_types():
types = [
x for _, x in inspect.getmembers(tf.dtypes)
if isinstance(x, tf.dtypes.DType) and hasattr(jnp, x.name)
]
# bool is called bool_ in jax and won't be found by the expression above.
return types + [tf.bool]
@dataclasses.dataclass
class OnlyJaxTypes:
"""Removes all features whose types are not supported by JAX.
This filters dense tensors by dtype and removes sparse and ragged tensors.
The latter don't have an equivalent in JAX.
Attributes:
types: List of allowed types. Defaults to all TF types that have an
equivalent in jax.numpy.
"""
types: List[tf.dtypes.DType] = dataclasses.field(
default_factory=_jax_supported_tf_types)
def __call__(self, features: Features) -> Features:
features = traverse_util.flatten_dict(features)
for name in list(features):
dtype = features[name].dtype
if dtype not in self.types:
del features[name]
logging.warning(
"Removing feature %r because dtype %s is not supported in JAX.",
name, dtype)
elif isinstance(features[name], tf.SparseTensor):
del features[name]
logging.warning(
"Removing feature %r because sparse tensors are not "
"supported in JAX.", name)
elif isinstance(features[name], tf.RaggedTensor):
del features[name]
logging.warning(
"Removing feature %r because ragged tensors are not supported in "
"JAX.", name)
features = traverse_util.unflatten_dict(features)
return features # pytype: disable=bad-return-type
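`OnlyJaxTypes` flattens the nested feature dict, deletes unsupported leaves, and then unflattens the result. The sketch below reproduces that flow using simplified standard-library stand-ins for `flax.traverse_util.flatten_dict` and `unflatten_dict` (tuple keys, empty sub-dicts dropped). An arbitrary predicate replaces the dtype, sparse, and ragged checks.

```python
from typing import Any, Callable, Dict, Tuple

FlatDict = Dict[Tuple[str, ...], Any]


def flatten_dict(d: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> FlatDict:
  """Maps nested dicts to {(k1, k2, ...): leaf}; empty sub-dicts vanish."""
  flat: FlatDict = {}
  for k, v in d.items():
    if isinstance(v, dict):
      flat.update(flatten_dict(v, prefix + (k,)))
    else:
      flat[prefix + (k,)] = v
  return flat


def unflatten_dict(flat: FlatDict) -> Dict[str, Any]:
  """Rebuilds the nested dict from tuple keys."""
  nested: Dict[str, Any] = {}
  for path, v in flat.items():
    node = nested
    for k in path[:-1]:
      node = node.setdefault(k, {})
    node[path[-1]] = v
  return nested


def keep_supported(features: Dict[str, Any],
                   is_supported: Callable[[Any], bool]) -> Dict[str, Any]:
  flat = flatten_dict(features)
  # Iterate over a copy of the keys because entries are deleted in the loop.
  for path in list(flat):
    if not is_supported(flat[path]):
      del flat[path]
  return unflatten_dict(flat)


features = {"image": 2, "nested": {"allowed": 2, "not_allowed": "bla"}}
print(keep_supported(features, lambda v: isinstance(v, int)))
# {'image': 2, 'nested': {'allowed': 2}}
```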
@dataclasses.dataclass
class PreprocessFn:
"""Chain of preprocessing ops combined into a single preprocessing function.
Attributes:
ops: List of feature transformations. Transformations will be applied in the
given order.
only_jax_types: If True will add the `OnlyJaxTypes` transformation at the
end.
"""
ops: Sequence[PreprocessOp]
only_jax_types: bool
def __call__(self, features: Features) -> Features:
"""Sequentially applies all `self.ops` and returns the result."""
logging.info("Features before preprocessing: %s",
_describe_features(features))
features = features.copy()
for op in self.ops:
features = op(features)
logging.info("Features after op %s:\n%s", op,
_describe_features(features))
logging.info("Features after preprocessing: %s",
_describe_features(features))
if self.only_jax_types:
features = OnlyJaxTypes()(features)
return features
def __add__(self, other: "PreprocessFn") -> "PreprocessFn":
"""Concatenates two `PreprocessFn` instances."""
if not isinstance(other, PreprocessFn):
raise ValueError("Can only add other instances of `PreprocessFn`.")
return PreprocessFn(
ops=tuple(self.ops) + tuple(other.ops),
only_jax_types=self.only_jax_types or other.only_jax_types,
)
def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn":
"""Returns a `PreprocessFn` of the sliced ops."""
return PreprocessFn(
ops=self.ops[op_index]
if isinstance(op_index, slice) else [self.ops[op_index]],
only_jax_types=self.only_jax_types,
)
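The hypothetical `Chain` class below is a stripped-down illustration of how `PreprocessFn` composes ops. Ops run in order on a shallow copy of the input, `+` concatenates op sequences, and indexing with an int still yields a chain containing one op. Logging and the `OnlyJaxTypes` step are omitted, and plain integers stand in for tensors.

```python
import dataclasses
from typing import Any, Callable, Dict, Sequence, Union

Features = Dict[str, Any]


@dataclasses.dataclass
class Chain:
  """Hypothetical stand-in for PreprocessFn without logging or JAX filtering."""
  ops: Sequence[Callable[[Features], Features]]

  def __call__(self, features: Features) -> Features:
    # Shallow copy so ops that assign keys do not modify the caller's dict.
    features = features.copy()
    for op in self.ops:
      features = op(features)
    return features

  def __add__(self, other: "Chain") -> "Chain":
    return Chain(ops=tuple(self.ops) + tuple(other.ops))

  def __getitem__(self, index: Union[int, slice]) -> "Chain":
    # An int index still returns a Chain, wrapping the single op in a list.
    ops = self.ops[index] if isinstance(index, slice) else [self.ops[index]]
    return Chain(ops=ops)


def add_one(f: Features) -> Features:
  return {k: v + 1 for k, v in f.items()}


def double(f: Features) -> Features:
  return {k: v * 2 for k, v in f.items()}


fn = Chain(ops=(add_one, double))
print(fn({"x": 1}))         # {'x': 4}
print((fn + fn)({"x": 1}))  # {'x': 10}
print(fn[1]({"x": 1}))      # {'x': 2}
print(fn[::-1]({"x": 1}))   # {'x': 3}
```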
def _get_op_class(
expr: ast.expr,
available_ops: Dict[str, Type[PreprocessOp]]) -> Type[PreprocessOp]:
"""Gets the preprocess op class from the given expression."""
if isinstance(expr, ast.Call):
fn_name = expr.func.id
elif isinstance(expr, ast.Name):
fn_name = expr.id
else:
raise ValueError(
f"Could not parse function name from expression: {expr!r}.")
if fn_name in available_ops:
return available_ops[fn_name]
raise ValueError(
f"'{fn_name}' is not available (available ops: {list(available_ops)}).")
def _parse_single_preprocess_op(
spec: str, available_ops: Dict[str, Type[PreprocessOp]]) -> PreprocessOp:
"""Parses the spec for a single preprocess op.
The op can be just the op name, or the op name followed by any arguments
(both positional and keyword) to the op.
See the test cases for some valid examples.
Args:
spec: String specifying a single preprocessing operation.
available_ops: Available preprocessing ops.
Returns:
The PreprocessOp corresponding to the spec.
"""
try:
expr = ast.parse(spec, mode="eval").body # pytype: disable=attribute-error
except SyntaxError as e:
raise ValueError(f"{spec!r} is not a valid preprocess op spec.") from e
op_class = _get_op_class(expr, available_ops) # pytype: disable=wrong-arg-types
# Simple case without arguments.
if isinstance(expr, ast.Name):
return op_class()
assert isinstance(expr, ast.Call)
args = [ast.literal_eval(arg) for arg in expr.args]
kwargs = {kv.arg: ast.literal_eval(kv.value) for kv in expr.keywords}
if not args:
return op_class(**kwargs)
# Translate positional arguments into keyword arguments.
available_arg_names = [f.name for f in dataclasses.fields(op_class)]
for i, arg in enumerate(args):
name = available_arg_names[i]
if name in kwargs:
raise ValueError(
f"Argument {name} to {op_class} given both as positional argument "
f"(value: {arg}) and keyword argument (value: {kwargs[name]}).")
kwargs[name] = arg
return op_class(**kwargs)
def parse(spec: str,
available_ops: List[Tuple[str, Type[PreprocessOp]]],
*,
only_jax_types: bool = True) -> PreprocessFn:
"""Parses a preprocess spec; a '|' separated list of preprocess ops."""
available_ops = dict(available_ops)
if not spec.strip():
ops = []
else:
ops = [
_parse_single_preprocess_op(s, available_ops) for s in spec.split("|")
]
return PreprocessFn(ops, only_jax_types=only_jax_types)
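The parsing approach above can be reduced to a standard-library sketch. Each `|`-separated piece is parsed with `ast.parse(..., mode="eval")`, only literal arguments are evaluated with `ast.literal_eval`, and positional arguments are mapped onto dataclass fields in declaration order. `Rescale` and `Crop` are hypothetical ops. Unlike the original, this version raises a `ValueError` for surplus positional arguments rather than an `IndexError`.

```python
import ast
import dataclasses
from typing import Any, Dict, List, Type


@dataclasses.dataclass(frozen=True)
class Rescale:
  """Hypothetical op with one field."""
  scale: int = 1


@dataclasses.dataclass(frozen=True)
class Crop:
  """Hypothetical op with two fields, to show positional-to-field mapping."""
  height: int = 0
  width: int = 0


def parse_op(spec: str, available: Dict[str, Type[Any]]) -> Any:
  expr = ast.parse(spec, mode="eval").body
  if isinstance(expr, ast.Name):  # Bare name, e.g. "rescale".
    return available[expr.id]()
  if not (isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name)):
    raise ValueError(f"Unsupported spec: {spec!r}")
  op_class = available[expr.func.id]
  # Only literals are evaluated, so a spec cannot run arbitrary code.
  kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expr.keywords}
  field_names = [f.name for f in dataclasses.fields(op_class)]
  if len(expr.args) > len(field_names):
    raise ValueError(f"Too many positional arguments in {spec!r}")
  # Positional arguments fill dataclass fields in declaration order.
  for name, arg in zip(field_names, expr.args):
    if name in kwargs:
      raise ValueError(f"{name!r} given both positionally and by keyword")
    kwargs[name] = ast.literal_eval(arg)
  return op_class(**kwargs)


def parse(spec: str, available: Dict[str, Type[Any]]) -> List[Any]:
  if not spec.strip():
    return []
  return [parse_op(s, available) for s in spec.split("|")]


ops = {"rescale": Rescale, "crop": Crop}
print(parse("rescale|rescale(3)|crop(4, width=5)", ops))
# [Rescale(scale=1), Rescale(scale=3), Crop(height=4, width=5)]
```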
def _describe_features(features: Features) -> str:
description = {}
for k, v in features.items():
if isinstance(v, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
description[k] = f"{v.dtype.name}{v.shape}"
elif isinstance(v, dict):
description[k] = _describe_features(v)
else:
description[k] = f"Unsupported type {type(v)} at feature '{k}'."
return str(description)
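`MapTransform.__new__` refuses to instantiate subclasses that are not dataclasses. The check lives in `__new__` because `__init_subclass__` runs before the `@dataclasses.dataclass` decorator is applied. Below is a standard-library sketch of the same check, using a hypothetical `EnforcedDataclassTransform` base class and plain dicts instead of TensorFlow tensors.

```python
import abc
import dataclasses


class EnforcedDataclassTransform(abc.ABC):
  """Hypothetical base class mirroring MapTransform's dataclass check."""

  def __new__(cls, *args, **kwargs):
    del args, kwargs  # Handled by the dataclass-generated __init__ instead.
    if not dataclasses.is_dataclass(cls):
      raise ValueError(f"Class {cls} is not a dataclass.")
    return super().__new__(cls)

  def __call__(self, features: dict) -> dict:
    return self._transform(features)

  @abc.abstractmethod
  def _transform(self, features: dict) -> dict:
    """Transforms the features."""


@dataclasses.dataclass(frozen=True)
class Scale(EnforcedDataclassTransform):
  factor: int = 2

  def _transform(self, features: dict) -> dict:
    return {k: v * self.factor for k, v in features.items()}


class NotADataclass(EnforcedDataclassTransform):

  def _transform(self, features: dict) -> dict:
    return features


print(Scale(3)({"x": 2}))  # {'x': 6}
try:
  NotADataclass()
  rejected = False
except ValueError:
  rejected = True
print(rejected)  # True
```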
================================================
FILE: clu/preprocess_spec_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from absl import logging
from absl.testing import parameterized
from clu import preprocess_spec
import tensorflow as tf
Features = preprocess_spec.Features
SEED_KEY = preprocess_spec.SEED_KEY
@dataclasses.dataclass(frozen=True)
class ToFloat:
def __call__(self, features: Features) -> Features:
return {k: tf.cast(v, tf.float32) / 255.0 for k, v in features.items()}
@dataclasses.dataclass(frozen=True)
class Rescale:
scale: int = 1
def __call__(self, features: Features) -> Features:
features["image"] *= self.scale
features["segmentation_mask"] *= self.scale
return features
@dataclasses.dataclass(frozen=True)
class AddRandomInteger(preprocess_spec.RandomMapTransform):
def _transform(self, features, seed):
features["x"] = tf.random.stateless_uniform([], seed)
return features
all_ops = lambda: preprocess_spec.get_all_ops(__name__)
class PreprocessSpecTest(parameterized.TestCase, tf.test.TestCase):
"""Tests for parsing preprocessing op spec."""
def test_no_arguments(self):
op = preprocess_spec._parse_single_preprocess_op("rescale", dict(all_ops()))
logging.info("op: %r", op)
self.assertEqual(str(op), "Rescale(scale=1)")
def test_positional_argument(self):
op = preprocess_spec._parse_single_preprocess_op("rescale(2)",
dict(all_ops()))
logging.info("op: %r", op)
self.assertEqual(str(op), "Rescale(scale=2)")
def test_keyword_argument(self):
op = preprocess_spec._parse_single_preprocess_op("rescale(scale=3)",
dict(all_ops()))
logging.info("op: %r", op)
self.assertEqual(str(op), "Rescale(scale=3)")
def test_invalid_op_name(self):
with self.assertRaisesRegex(
ValueError, r"'does_not_exist' is not available \(available ops: "
r"\['add_random_integer', 'rescale', 'to_float'\]\)."):
preprocess_spec._parse_single_preprocess_op("does_not_exist",
dict(all_ops()))
def test_invalid_spec(self):
with self.assertRaisesRegex(
ValueError, r"'rescale\)' is not a valid preprocess op spec."):
preprocess_spec._parse_single_preprocess_op("rescale)", dict(all_ops()))
def test_pos_and_kw_arg(self):
with self.assertRaisesRegex(
ValueError,
r"Rescale'> given both as positional argument \(value: 2\) and keyword "
r"argument \(value: 3\)."):
preprocess_spec._parse_single_preprocess_op("rescale(2, scale=3)",
dict(all_ops()))
def test_parsing_empty_string(self):
preprocess_fn = preprocess_spec.parse("", all_ops())
self.assertEqual(
str(preprocess_fn), "PreprocessFn(ops=[], only_jax_types=True)")
def test_multi_op_spec(self):
preprocess_fn = preprocess_spec.parse("to_float|rescale(3)", all_ops())
    logging.info("preprocess_fn: %r", preprocess_fn)
    self.assertEqual(str(preprocess_fn.ops), "[ToFloat(), Rescale(scale=3)]")

  def test_two_tensors(self):
    preprocess_fn = preprocess_spec.parse("rescale(scale=7)", all_ops())
    x = {"image": tf.constant(3), "segmentation_mask": tf.constant(2)}
    y = preprocess_fn(x)
    self.assertEqual(y, {
        "image": tf.constant(21),
        "segmentation_mask": tf.constant(14),
    })

  def test_only_jax_types(self):
    preprocess_fn = preprocess_spec.parse("", all_ops())
    x = {
        "image": tf.constant(2),
        # Strings are not supported.
        "label": tf.constant("bla"),
        # Sparse tensors are not supported.
        "foo": tf.sparse.eye(4),
        # Ragged tensors are not supported.
        "bar": tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]]),
    }
    y = preprocess_fn(x)
    self.assertEqual(y, {"image": tf.constant(2)})

  def test_only_jax_types_nested_inputs(self):
    preprocess_fn = preprocess_spec.parse("", all_ops())
    x = {
        "nested": {
            "not_allowed": tf.constant("bla"),
            "allowed": tf.constant(2),
        }
    }
    y = preprocess_fn(x)
    self.assertEqual(y, {"nested": {"allowed": tf.constant(2)}})

  def test_not_only_jax_types(self):
    preprocess_fn = preprocess_spec.parse("", all_ops(), only_jax_types=False)
    x = {"image": tf.constant(2), "label": tf.constant("bla")}
    y = preprocess_fn(x)
    self.assertEqual(y, x)

  def test_add_preprocess_fn(self):
    op1 = ToFloat()
    op2 = ToFloat()
    op3 = ToFloat()
    fn1 = preprocess_spec.PreprocessFn(ops=(op1, op2), only_jax_types=False)
    fn2 = preprocess_spec.PreprocessFn(ops=(op3,), only_jax_types=True)
    fn12 = fn1 + fn2
    # Note: `+` is not supported on Sequence[PreprocessOp]; need to use `list`.
    self.assertSequenceEqual(fn12.ops, list(fn1.ops) + list(fn2.ops))
    self.assertTrue(fn12.only_jax_types)

  def test_slice_preprocess_fn(self):
    op1 = ToFloat()
    op2 = Rescale()
    op3 = ToFloat()
    fn = preprocess_spec.PreprocessFn(ops=(op1, op2, op3), only_jax_types=True)
    self.assertEqual(fn[:-1].ops, (op1, op2))
    self.assertTrue(fn[:-1].only_jax_types)
    self.assertEqual(fn[1].ops, [op2])
    self.assertTrue(fn[1].only_jax_types)

  def test_random_map_transform(self):
    ds = tf.data.Dataset.from_tensor_slices(
        {SEED_KEY: [[1, 2], [3, 4], [1, 2]]})
    ds = ds.map(AddRandomInteger())
    actual = list(ds)
    print("actual:", actual)
    expect = [
        # Random number was generated and random seed changed.
        {
            "x": 0.8838011,
            SEED_KEY: [1105988140, 1738052849]
        },
        {
            "x": 0.33396423,
            SEED_KEY: [-1860230133, -671226999]
        },
        # Same random seed as first element creates same outcome.
        {
            "x": 0.8838011,
            SEED_KEY: [1105988140, 1738052849]
        },
    ]
    self.assertAllClose(actual, expect)


if __name__ == "__main__":
  tf.test.main()
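The `+` and slicing behavior exercised by `test_add_preprocess_fn` and `test_slice_preprocess_fn` can be illustrated with a small, self-contained sketch. `MiniPreprocessFn`, `double` and `add_one` are hypothetical stand-ins, not CLU's `PreprocessFn` or its ops: the sketch omits spec parsing and the `only_jax_types` filtering. It also assumes that combining two functions ORs their `only_jax_types` flags. That assumption matches the test above (False + True gives True), but it is not taken from the library source.

```python
import dataclasses
from typing import Callable, Sequence, Union

# Hypothetical op type: any callable mapping a features dict to a features dict.
PreprocessOp = Callable[[dict], dict]


@dataclasses.dataclass(frozen=True)
class MiniPreprocessFn:
  """Hypothetical stand-in for clu.preprocess_spec.PreprocessFn (sketch only)."""
  ops: Sequence[PreprocessOp]
  only_jax_types: bool = True

  def __call__(self, features: dict) -> dict:
    # Apply the ops in order; filtering of non-JAX types is omitted here.
    for op in self.ops:
      features = op(features)
    return features

  def __add__(self, other: "MiniPreprocessFn") -> "MiniPreprocessFn":
    # Convert to lists first: `+` is not defined between arbitrary Sequences
    # (e.g. a tuple and a list).
    return MiniPreprocessFn(
        ops=list(self.ops) + list(other.ops),
        # Assumption: the combined function filters if either part does.
        only_jax_types=self.only_jax_types or other.only_jax_types)

  def __getitem__(self, index: Union[int, slice]) -> "MiniPreprocessFn":
    # An int selects a single op (wrapped in a list); a slice keeps the
    # sequence type of `ops` (a tuple slice stays a tuple).
    ops = [self.ops[index]] if isinstance(index, int) else self.ops[index]
    return MiniPreprocessFn(ops=ops, only_jax_types=self.only_jax_types)


def double(features: dict) -> dict:
  return {k: v * 2 for k, v in features.items()}


def add_one(features: dict) -> dict:
  return {k: v + 1 for k, v in features.items()}


fn = MiniPreprocessFn(ops=(double, add_one))
print(fn({"image": 3}))        # {'image': 7}
print(fn[:1]({"image": 3}))    # {'image': 6}
print(fn[1].ops == [add_one])  # True
```

Wrapping an integer index in a list mirrors the `[op2]` expectation in `test_slice_preprocess_fn`. A slice returns whatever sequence type `ops` already has, which matches the `(op1, op2)` expectation.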
================================================
FILE: clu/profiler.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Methods for triggering a profiler for accelerators.

Where results are stored depends on the platform (e.g. TensorBoard).
"""
from collections.abc import Callable, Sequence
import threading
from typing import Optional, Protocol

from absl import logging
import jax


def start(logdir: str, options=None):
  """Starts profiling."""
  if options is not None:
    raise NotImplementedError(
        "'options' not supported by clu.profiler.start(). Please file an issue "
        "at https://github.com/google/jax/issues requesting profiler option "
        "support if you need this feature.")
  if logdir is None:
    raise ValueError("Must specify logdir where profile should be written!")
  jax.profiler.start_trace(logdir)


def stop() -> Optional[str]:
  """Stops profiling."""
  jax.profiler.stop_trace()


CollectCallback = Callable[[Optional[str]], None]


def collect(logdir: str,
            callback: CollectCallback,
            hosts: Optional[Sequence[str]] = None,
            duration_ms: int = 3_000):
  """Calls start() followed by stop() after specified duration."""
  del hosts  # not used.
  start(logdir)

  def timer_cb():
    stop()
    callback(None)

  threading.Timer(duration_ms / 1e3, timer_cb).start()
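`collect()` does not block. It starts the trace, then runs `stop()` and the callback on a `threading.Timer` thread once `duration_ms` has elapsed. The sketch below reproduces that control flow so it can run without JAX or an accelerator. `collect_after`, `start_fn` and `stop_fn` are hypothetical names standing in for `start(logdir)` and `stop()`. One deliberate difference from `clu.profiler.collect`: the sketch returns the `Timer`, so a caller (or a test) can `join()` it.

```python
import threading
from typing import Callable, Optional


def collect_after(duration_ms: int,
                  start_fn: Callable[[], None],
                  stop_fn: Callable[[], None],
                  callback: Callable[[Optional[str]], None]) -> threading.Timer:
  """Sketch of collect(): start now, stop on a background timer thread."""
  start_fn()

  def timer_cb():
    # Runs on the timer thread after `duration_ms` milliseconds.
    stop_fn()
    # Like clu.profiler.collect, report no profile location.
    callback(None)

  timer = threading.Timer(duration_ms / 1e3, timer_cb)
  timer.start()
  return timer


events = []
timer = collect_after(
    10,
    start_fn=lambda: events.append("start"),
    stop_fn=lambda: events.append("stop"),
    callback=lambda path: events.append(("callback", path)))
timer.join()  # Wait for the timer thread to run timer_cb and exit.
print(events)  # ['start', 'stop', ('callback', None)]
```

Because the callback runs on another thread, code that needs the profile to be finished must wait for the callback rather than for `collect()` to return.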
================================================
FILE: clu/run_pytest.google.sh
================================================
#!/bin/bash
set -e -x
CLU_DST="${CLU_DST:-/tmp/clu}"
CLU_ENV="${CLU_ENV:-/tmp/clu_env}"
copybara third_party/py/clu/copy.bara.sky local .. \
--folder-dir="${CLU_DST}" --ignore-noop
# Note: we're reusing the environment if it already exists.
mkdir -p "${CLU_ENV}"
cd "${CLU_ENV}"
python3 -m virtualenv .
. bin/activate
cd "${CLU_DST}"
pip install . .[test]
pytest
================================================
FILE: clu/values.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines available types for use by Metrics when written.

A Metric should return one of the following types when compute() is called.
"""

import dataclasses
from typing import Any, Union, Protocol, runtime_checkable

import jax.numpy as jnp
import numpy as np

ArrayType = Union[np.ndarray, jnp.ndarray]
ScalarType = Union[int, float, np.number, np.ndarray, jnp.ndarray]


@runtime_checkable
class Value(Protocol):
  """Class defining available metric computation return values.

  Types mirror those available in MetricWriter. See
  clu/metric_writers/interface.py
  """
  value: Any


@dataclasses.dataclass
class Summary(Value):
  value: ArrayType
  metadata: Any


@dataclasses.dataclass
class Scalar(Value):
  value: ScalarType


@dataclasses.dataclass
class Image(Value):
  """Image type.

  Mapping from image key to images. Images should have the shape [N, H, W, C] or
  [H, W, C], where H is the height, W is the width and C the
  number of channels (1 or 3). N is the number of images that will be
  written. Image dimensions can differ between different image keys but
  not between different steps for the same image key.
  """
  value: ArrayType


@dataclasses.dataclass
class Audio(Value):
  """Audio type.

  Mapping from audio key to audios. Audios should have the shape [N, T, C],
  where T is the time length and C the number of channels (1 -
  mono, 2 - stereo, >= 3 - surround; not all writers support any number of
  channels). N is the number of audios that will be written. Audio
  dimensions can differ between different audio keys but not between
  different steps for the same audio key. Values should be floating-point
  values in [-1, +1].
  """
  value: ArrayType
  sample_rate: int


@dataclasses.dataclass
class Text(Value):
  value: str


@dataclasses.dataclass
class Histogram(Value):
  # value must be an array of counts (integers)
  value: ArrayType
  num_buckets: int


@dataclasses.dataclass
class HyperParam(Value):
  """The name of the hyperparameter should be handled outside this class.

  Value should correspond to a single hyperparameter, while a Mapping[str,
  HyperParam] (name to HyperParam) is maintained independently.
  """
  value: Any
================================================
FILE: clu_synopsis.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "4CldxEhqQac_"
},
"source": [
"# CLU - Common Loop Utils\n",
"\n",
"\n",
"\n",
"\n",
" pip install clu\n",
"\n",
"https://github.com/google/CommonLoopUtils\n",
"\n",
"This package is usually used with\n",
"[JAX](https://github.com/google/jax)\n",
"/\n",
"[Flax](https://github.com/google/flax)\n",
"ML projects, but it can also be used with other ML frameworks.\n",
"\n",
"JAX/Flax are designed for flexibility and the user remains in control of the\n",
"training loop. Writing and maintaining your own training loop gives lots of\n",
"flexibility but also quickly leads to a non-trivial amount of code that is\n",
"repeated in every project (and usually forked from an\n",
"[example](https://flax.readthedocs.io/en/latest/examples.html)\n",
"in the first place).\n",
"\n",
"`clu` provides small independent helpers to make the training loop shorter and\n",
"easier to read, while keeping maximum flexibility.\n",
"\n",
"**This Colab** walks you through the different modules of `clu` with simple\n",
"example code for showcasing the important concepts and to be pasted into your\n",
"training loop to get started using `clu`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6CxpJUPTd1wD"
},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "sTS7uCbgizD8"
},
"outputs": [],
"source": [
"!pip install -q clu\n",
"# Alternatively, install latest version directly from Github:\n",
"# !pip install -q git+https://github.com/google/CommonLoopUtils"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6wmDlgxtdsQ7",
"outputId": "6cdd353c-165e-4b88-efc4-8dd46b01c331"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Name: clu\n",
"Version: 0.0.8\n",
"Summary: Set of libraries for ML training loops in JAX.\n",
"Home-page: http://github.com/google/CommonLoopUtils\n",
"Author: Common Loop Utils Authors\n",
"Author-email: no-reply@google.com\n",
"License: Apache 2.0\n",
"Location: /usr/local/lib/python3.9/dist-packages\n",
"Requires: absl-py, etils, flax, jax, jaxlib, ml-collections, numpy, packaging, typing-extensions, wrapt\n",
"Required-by: \n"
]
}
],
"source": [
"!pip show clu"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "WrGls5mON3Qr"
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "-kx9lHndyN8G"
},
"outputs": [],
"source": [
"import chex\n",
"chex.set_n_cpu_devices(2) # Simulate 2 local devices in a CPU Colab runtime."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QsJvFY1Xjfw-"
},
"source": [
"### `clu.metric_writers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C83tJrZnzTns"
},
"source": [
"The module [`metric_writers`] provides a simple [interface] to write time series\n",
"metrics in a unified way.\n",
"\n",
"Metric writers provided:\n",
"\n",
"- `SummaryWriter`: Uses `tf.summary` to write summary files. For display in\n",
" TensorBoard.\n",
"- `LoggingWriter`: Simply writes values to the INFO log. This obviously only\n",
" supports data types that can be converted to text but is still helpful for\n",
" seeing the training progress on the command line.\n",
"- `TorchTensorboardWriter`: Uses `torch.utils.tensorboard` to write summary\n",
" files. Use this writer for PyTorch-based code.\n",
"\n",
"Additionally, we provide metric writers to combine multiple metric writers\n",
"(`MultiWriter`) and to move the write operation to a background thread\n",
"(`AsyncWriter`).\n",
"\n",
"[`metric_writers`]: https://github.com/google/CommonLoopUtils/blob/master/clu/metric_writers/__init__.py\n",
"[interface]: https://github.com/google/CommonLoopUtils/blob/master/clu/metric_writers/interface.py\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "aBhPGVLKz8c6"
},
"outputs": [],
"source": [
"from absl import logging\n",
"logging.set_verbosity(logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "uO_ezNP216XM"
},
"outputs": [],
"source": [
"logdir = './metrics'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jt1eLAfNz_wr",
"outputId": "70668139-45cc-416a-ff3d-d8d99aece0a8"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:[0] loss=1.000000\n",
"INFO:absl:[1] loss=0.900000\n",
"INFO:absl:[2] loss=0.810000\n",
"INFO:absl:[3] loss=0.729000\n",
"INFO:absl:[4] loss=0.656100\n",
"INFO:absl:[5] loss=0.590490\n",
"INFO:absl:[6] loss=0.531441\n",
"INFO:absl:[7] loss=0.478297\n",
"INFO:absl:[8] loss=0.430467\n",
"INFO:absl:[9] loss=0.387420\n"
]
}
],
"source": [
"from clu import metric_writers\n",
"\n",
"# Handy shortcut to create an async logging/TensorBoard writer.\n",
"writer = metric_writers.create_default_writer(logdir)\n",
"for step in range(10):\n",
" writer.write_scalars(step, dict(loss=0.9**step))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9hEe8xAO0M-1"
},
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=./metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VyTm-xlM0Zvz"
},
"source": [
"### `clu.periodic_actions`\n",
"\n",
"[`periodic_actions`] are simple helpers that let you perform actions in the\n",
"training loop at regular intervals. Currently we support:\n",
"\n",
"- `PeriodicAction`, `PeriodicCallback`: To implement your own actions.\n",
"- `Profile`: To create TensorBoard compatible profiles.\n",
"- `ReportProgress`: To continuously print progress status updates.\n",
"\n",
"[`periodic_actions`]: https://github.com/google/CommonLoopUtils/blob/master/clu/periodic_actions.py"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vsKd1Frm1cRN",
"outputId": "da99a807-b861-464a-f13c-8fe76049b749"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Setting work unit notes: 165469.8 steps/s, 10.0% (10/100), ETA: 0m\n",
"INFO:absl:[10] steps_per_sec=165469.768737\n",
"INFO:absl:[10] uptime=0.002897\n",
"WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
"INFO:absl:Setting work unit notes: 560.2 steps/s, 20.0% (20/100), ETA: 0m\n",
"INFO:absl:Setting work unit notes: 4487.2 steps/s, 30.0% (30/100), ETA: 0m\n",
"INFO:absl:Setting work unit notes: 1004.3 steps/s, 40.0% (40/100), ETA: 0m\n",
"INFO:absl:[20] steps_per_sec=560.175008\n",
"INFO:absl:[20] uptime=0.020087\n",
"INFO:absl:[30] steps_per_sec=4487.220620\n",
"INFO:absl:[30] uptime=0.030033\n",
"INFO:absl:[40] steps_per_sec=1004.251398\n",
"INFO:absl:[40] uptime=0.049151\n",
"INFO:absl:Setting work unit notes: 523.4 steps/s, 50.0% (50/100), ETA: 0m\n",
"INFO:absl:Setting work unit notes: 146.7 steps/s, 60.0% (60/100), ETA: 0m\n",
"INFO:absl:Setting work unit notes: 975.2 steps/s, 70.0% (70/100), ETA: 0m\n",
"INFO:absl:Setting work unit notes: 758.4 steps/s, 80.0% (80/100), ETA: 0m\n",
"INFO:absl:Setting work unit notes: 1590.8 steps/s, 90.0% (90/100), ETA: 0m\n",
"INFO:absl:[50] steps_per_sec=523.352102\n",
"INFO:absl:[50] uptime=0.117317\n",
"INFO:absl:[60] steps_per_sec=146.689455\n",
"INFO:absl:[60] uptime=0.127566\n"
]
}
],
"source": [
"from clu import periodic_actions\n",
"\n",
"total_steps = 100\n",
"hooks = [\n",
" # Outputs progress via metric writer (in this case logs & TensorBoard).\n",
" periodic_actions.ReportProgress(\n",
" num_train_steps=total_steps,\n",
" every_steps=10, writer=writer),\n",
" periodic_actions.Profile(logdir=logdir)\n",
"]\n",
"\n",
"for step in range(total_steps):\n",
"  for hook in hooks:\n",
"    hook(step)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uB7-MTls1-No",
"outputId": "bf1bd0b3-48e0-475c-81df-6dc9ebb0ce6a"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:[70] steps_per_sec=975.188475\n",
"INFO:absl:[70] uptime=0.140739\n",
"INFO:absl:[80] steps_per_sec=758.410526\n",
"INFO:absl:[80] uptime=0.147027\n",
"INFO:absl:[90] steps_per_sec=1590.824886\n",
"INFO:absl:[90] uptime=0.159763\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"total 8.0K\n",
"-rw-r--r-- 1 root root 1.9K Apr 25 07:46 events.out.tfevents.1682408760.c6ce21f3d054.192.0.v2\n",
"-rw-r--r-- 1 root root 1.9K Apr 25 07:49 events.out.tfevents.1682408989.c6ce21f3d054.1936.0.v2\n"
]
}
],
"source": [
"# If you click on \"refresh\" in the TensorBoard above, you'll now see a new\n",
"# \"steps_per_sec\" metric...\n",
"!ls -lh metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ueom-uBWLbeQ"
},
"source": [
"### `clu.metrics`\n",
"\n",
"The [`metrics`] module provides a framework for functional metric computation.\n",
"Note that this module does **not** include the actual metric definitions (other\n",
"than `metrics.Accuracy` that is provided for demonstration purposes), but\n",
"rather provides abstractions that can be used to compute metrics in a\n",
"distributed environment.\n",
"\n",
"This section is a bit longer than the previous sections and walks you through\n",
"the following parts:\n",
"\n",
"1. How `metrics.Metric` is computed, and defining \"averageable\" metrics.\n",
"2. Using `metrics.Collection` to compute several metrics at once.\n",
"3. Aggregating in an evaluation step that is transformed by `pmap()`.\n",
"4. Defining a new metric with custom aggregation (i.e. not \"averageable\").\n",
"\n",
"\n",
"[`metrics`]: https://github.com/google/CommonLoopUtils/blob/master/clu/metrics.py"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Bti6-K5ZDZHP",
"outputId": "701eb451-8d32-4b11-e389-4599c41c581c"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Array(0.75, dtype=float32)"
]
},
"metadata": {},
"execution_count": 11
}
],
"source": [
"from clu import metrics\n",
"import flax\n",
"\n",
"# Metrics are computed in three steps:\n",
"\n",
"# 1. Compute intermediate values from model outputs\n",
"accuracy_batch1 = metrics.Accuracy.from_model_output(\n",
" logits=jnp.array([[-1., 1.], [1., -1.]]),\n",
" labels=jnp.array([0, 0]), # i.e. 1st incorrect, 2nd correct\n",
")\n",
"accuracy_batch2 = metrics.Accuracy.from_model_output(\n",
" logits=jnp.array([[-1., 1.], [1., -1.]]),\n",
" labels=jnp.array([1, 0]), # i.e. both correct\n",
")\n",
"\n",
"# 2. Intermediate values are aggregated\n",
"accuracy = accuracy_batch1\n",
"accuracy = accuracy.merge(accuracy_batch2)\n",
"\n",
"# 3. Final metrics are computed from aggregated intermediate values:\n",
"accuracy.compute()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ga8J-z6KmMBg",
"outputId": "42dd6d88-d40e-46ea-b4f0-cba649784924"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Array(2.2, dtype=float32)"
]
},
"metadata": {},
"execution_count": 12
}
],
"source": [
"# It's easy to define your own metrics if they are \"averageable\":\n",
"\n",
"AverageLoss = metrics.Average.from_output('loss')\n",
"\n",
"AverageLoss.from_model_output(\n",
" loss=jnp.array([1.1, 3.3])\n",
").compute()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PsxIu5FfnJmW",
"outputId": "a5ecc19c-6772-4ca6-f442-52cde0bc1f57"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Array(2.1999998, dtype=float32)"
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"# You can provide a function to derive the value to be averaged:\n",
"\n",
"# Note that our metric only uses the model output named \"loss\". There can be an\n",
"# arbitrary number of additional model outputs that we don't need here (**_).\n",
"AverageSquaredLoss = metrics.Average.from_fun(lambda loss, **_: loss**2)\n",
"\n",
"AverageSquaredLoss.from_model_output(\n",
" loss=jnp.array([1.1**.5, 3.3**.5])\n",
").compute()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zuXskYw_ooKQ",
"outputId": "916be6e3-916a-4070-d00f-388ca29a16f7"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'loss': Array(2.2, dtype=float32), 'accuracy': Array(0.75, dtype=float32)}"
]
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"# Usually you would want to compute a collection of metrics from model outputs:\n",
"\n",
"@flax.struct.dataclass # <-- required for JAX transformations\n",
"class MyMetrics(metrics.Collection):\n",
" loss : metrics.Average.from_output('loss')\n",
" accuracy : metrics.Accuracy\n",
"\n",
"\n",
"# 1. Compute intermediate values from model outputs\n",
"my_metrics_batch1 = MyMetrics.single_from_model_output(\n",
" logits=jnp.array([[-1., 1.], [1., -1.]]),\n",
" labels=jnp.array([0, 0]), # i.e. 1st incorrect, 2nd correct\n",
" loss=jnp.array([3.3, 2.2]),\n",
")\n",
"my_metrics_batch2 = MyMetrics.single_from_model_output(\n",
" logits=jnp.array([[-1., 1.], [1., -1.]]),\n",
" labels=jnp.array([1, 0]), # i.e. both correct\n",
" loss=jnp.array([2.2, 1.1]),\n",
")\n",
"\n",
"# 2. Intermediate values are aggregated\n",
"my_metrics = my_metrics_batch1.merge(my_metrics_batch2)\n",
"\n",
"# 3. Final metrics are computed from aggregated intermediate values:\n",
"my_metrics.compute()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8v7vxEGGqArt",
"outputId": "b4d77b53-1032-4db8-c275-65aefb558c67"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'loss': Array(2.2, dtype=float32), 'accuracy': Array(0.75, dtype=float32)}"
]
},
"metadata": {},
"execution_count": 15
}
],
"source": [
"# Often you want to compute these metrics inside a pmap(). The framework\n",
"# provides the handy `Collection.gather_from_model_output` that will first\n",
"# compute the intermediate values, then call `jax.lax.all_gather()` to gather\n",
"# the intermediate values from all the devices (in a multi-host setup that's\n",
"# all the devices in the mesh, not only the local devices), and then reduce them\n",
"# by calling `Metric.merge()` in a `jax.lax.scan()` loop.\n",
"\n",
"# Sounds complicated? Using it is actually surprisingly simple:\n",
"\n",
"def fake_model(params, batch):\n",
" del params # Fake.\n",
" return batch\n",
"\n",
"def eval_step(my_metrics, params, batch):\n",
" model_outputs = fake_model(params, batch)\n",
" # IMPORTANT: If you called `.single_from_model_output()` here, then all values\n",
" # from devices after the first device would be ignored for the metric\n",
" # computation.\n",
" return my_metrics.merge(MyMetrics.gather_from_model_output(**model_outputs))\n",
"\n",
"eval_step_p = jax.pmap(eval_step, axis_name='batch')\n",
"\n",
"my_metrics = flax.jax_utils.replicate(MyMetrics.empty())\n",
"\n",
"for batch in [\n",
" # Single batch of data pmapped on two devices in parallel.\n",
" dict(\n",
" logits=jnp.array([\n",
" # Batch for device 1\n",
" [[-1., 1.], [1., -1.]],\n",
" # Batch for device 2\n",
" [[-1., 1.], [1., -1.]],\n",
" ]),\n",
" labels=jnp.array([\n",
" # Batch for device 1\n",
" [0, 0],\n",
" # Batch for device 2\n",
" [1, 0],\n",
" ]),\n",
" loss=jnp.array([\n",
" # Batch for device 1\n",
" [3.3, 2.2],\n",
" # Batch for device 2\n",
" [2.2, 1.1],\n",
" ]),\n",
" ),\n",
"]:\n",
" my_metrics = eval_step_p(my_metrics, None, batch)\n",
"\n",
"# Note that up to this point all inputs/outputs to `eval_step_p()` are\n",
"# replicated such that their leading dimension == number of local devices\n",
"# (== 2 in this Colab, because of `chex.set_n_cpu_devices(2)` above).\n",
"my_metrics.unreplicate().compute()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fPKbjXgVOxnN",
"outputId": "edc422ce-d8d2-4450-e7fe-439e1b0dbe9c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Note that not calling `.unreplicate()` raises an error: Collection is still replicated (ndim=1). Maybe you forgot to call a flax.jax_utils.unreplicate() or a Collections.reduce()?\n"
]
}
],
"source": [
"try:\n",
" my_metrics.compute()\n",
" raise RuntimeError('Expected ValueError!')\n",
"except ValueError as e:\n",
" print('Note that not calling `.unreplicate()` raises an error:', e)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8v0gvSp3m9uW",
"outputId": "58e3e3a9-7e62-422e-b09a-f93023b2edec"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Array(0.6666667, dtype=float32)"
]
},
"metadata": {},
"execution_count": 17
}
],
"source": [
"# You can also provide your own aggregation logic:\n",
"\n",
"@flax.struct.dataclass\n",
"class Precision(metrics.Metric):\n",
"  \"\"\"Computes the precision from model outputs `logits` and `labels`.\"\"\"\n",
"\n",
"  true_positives: jnp.ndarray\n",
"  pred_positives: jnp.ndarray\n",
"\n",
"  @classmethod\n",
"  def from_model_output(cls, *, logits: jnp.ndarray, labels: jnp.ndarray,\n",
"                        **_) -> metrics.Metric:\n",
"    assert logits.shape[-1] == 2, \"Expected binary logits.\"\n",
"    preds = logits.argmax(axis=-1)\n",
"    return cls(\n",
"        true_positives=((preds == 1) & (labels == 1)).sum(),\n",
"        pred_positives=(preds == 1).sum(),\n",
"    )\n",
"\n",
"  def merge(self, other: metrics.Metric) -> metrics.Metric:\n",
"    # Note that for precision we cannot average metric values because the\n",
"    # denominator of the metric value is pred_positives and not every batch of\n",
"    # examples has the same number of pred_positives (as opposed to e.g.\n",
"    # accuracy, where every batch has the same number of examples).\n",
"    return type(self)(\n",
"        true_positives=self.true_positives + other.true_positives,\n",
"        pred_positives=self.pred_positives + other.pred_positives,\n",
"    )\n",
"\n",
"  def compute(self):\n",
"    return self.true_positives / self.pred_positives\n",
"\n",
"\n",
"Precision.from_model_output(\n",
"    # 1 TP, 1 FN -- 1 pred_positive -- precision = 1.0\n",
"    logits=jnp.array([[-1., 1.], [1., -1.]]),\n",
"    labels=jnp.array([1, 1]),  # i.e. 1st correct, 2nd incorrect\n",
").merge(\n",
"    Precision.from_model_output(\n",
"        # 1 TP, 1 FP -- 2 pred_positives -- precision = 0.5\n",
"        logits=jnp.array([[-1., 1.], [-1., 1.]]),\n",
"        labels=jnp.array([1, 0]),  # i.e. 1st correct, 2nd incorrect\n",
"    )\n",
").compute()\n",
"\n",
"# If one incorrectly used metrics.Average to aggregate the metric, the final\n",
"# value would be 0.75 because both batches have the same weight in terms of\n",
"# examples. But the second batch contains 2 pred_positives and should thus be\n",
"# weighted 2x, resulting in the correct (1 + 1) / (1 + 2) == 0.67"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nx6H0936Q3N0"
},
"source": [
"### `clu.deterministic_data`\n",
"\n",
"The [`deterministic_data`] module sets up a [`tf.data.Dataset`] with useful\n",
"features:\n",
"\n",
"- Specify split by name or [ReadInstruction].\n",
"- Reproducibly generate unique random keys for every batch and preprocessing\n",
" operation. This makes it possible to train deterministically, achieving\n",
" exactly the same results when starting with the same seeds, even in a\n",
" multihost setup.\n",
"- Multiple levels of batch dimensions with support for completing partial\n",
" batches with filler values.\n",
"- Predefined preprocessing operations that can be configured and chained with\n",
" a configuration string. See [`clu.preprocess_spec`] below.\n",
"- Shard your dataset across multiple processes. This is necessary to train\n",
" efficiently with multiple VMs.\n",
"\n",
"\n",
"This works well with [`tensorflow_datasets`], which can download and prepare\n",
"hundreds of datasets.\n",
"\n",
"[`tf.data.Dataset`]: https://www.tensorflow.org/api_docs/python/tf/data/Dataset\n",
"[`tensorflow_datasets`]: https://www.tensorflow.org/datasets\n",
"[`deterministic_data`]: https://github.com/google/CommonLoopUtils/blob/master/clu/deterministic_data.py\n",
"[ReadInstruction]: https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/core/tfrecords_reader.py\n",
"[`clu.preprocess_spec`]: #scrollTo=049FEFBq-9i2\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "jDePkT8faFNU",
"outputId": "2b07ce84-35c4-4bd6-c796-4283d93884ac",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 18
}
],
"source": [
"import packaging.version\n",
"\n",
"packaging.version.parse('3.7.2') < packaging.version.parse('3.7.3-dev')\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tYZjDvy6-VaG",
"outputId": "6e6f0fb4-4490-4a68-df4c-cddf0b66a274"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Load dataset info from /root/tensorflow_datasets/tf_flowers/3.0.1\n",
"INFO:absl:Fields info.[splits, supervised_keys, module_name] from disk and from code do not match. Keeping the one from code.\n",
"INFO:absl:Reusing dataset tf_flowers (/root/tensorflow_datasets/tf_flowers/3.0.1)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tfds.core.DatasetInfo(\n",
" name='tf_flowers',\n",
" full_name='tf_flowers/3.0.1',\n",
" description=\"\"\"\n",
" A large set of images of flowers\n",
" \"\"\",\n",
" homepage='https://www.tensorflow.org/tutorials/load_data/images',\n",
" data_path='/root/tensorflow_datasets/tf_flowers/3.0.1',\n",
" file_format=tfrecord,\n",
" download_size=218.21 MiB,\n",
" dataset_size=221.83 MiB,\n",
" features=FeaturesDict({\n",
" 'image': Image(shape=(None, None, 3), dtype=uint8),\n",
" 'label': ClassLabel(shape=(), dtype=int64, num_classes=5),\n",
" }),\n",
" supervised_keys=('image', 'label'),\n",
" disable_shuffling=False,\n",
" splits={\n",
" 'train': ,\n",
" },\n",
" citation=\"\"\"@ONLINE {tfflowers,\n",
" author = \"The TensorFlow Team\",\n",
" title = \"Flowers\",\n",
" month = \"jan\",\n",
" year = \"2019\",\n",
" url = \"http://download.tensorflow.org/example_images/flower_photos.tgz\" }\"\"\",\n",
")"
]
},
"metadata": {},
"execution_count": 19
}
],
"source": [
"# First fetch \"tf_flowers\" dataset.\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"dataset_builder = tfds.builder('tf_flowers')\n",
"dataset_builder.download_and_prepare()\n",
"dataset_builder.info"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1qwhub-19vdT",
"outputId": "01f95423-8877-419b-ba46-88d616729edb"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Constructing tf.data.Dataset tf_flowers for split _EvenSplit(split='train', index=0, count=1, drop_remainder=False), from /root/tensorflow_datasets/tf_flowers/3.0.1\n"
]
}
],
"source": [
"from clu import deterministic_data\n",
"import jax\n",
"\n",
"# In a multi-host setup this would split the entire dataset evenly across hosts.\n",
"# Colab runtime is always a single host, so this would not really be needed.\n",
"train_split = tfds.split_for_jax_process('train')\n",
"\n",
"def preprocess_fn(features):\n",
" # Minimalistic preprocessing function to make images have the same dimensions\n",
" # so they can be batched as dense tensors.\n",
" features['image'] = tf.image.resize(features['image'], [224, 224])\n",
" return features\n",
"\n",
"batch_size = 128\n",
"train_ds = deterministic_data.create_dataset(\n",
" dataset_builder,\n",
" split=train_split,\n",
" # This RNG key will be used to derive all randomness in shuffling, data\n",
" # preprocessing etc.\n",
" rng=jax.random.PRNGKey(0),\n",
" shuffle_buffer_size=100,\n",
" # Local device count depends on the runtime: e.g. 8 on a TPU runtime, 1 on\n",
" # a plain CPU runtime, and 2 here because of chex.set_n_cpu_devices(2).\n",
" batch_dims=[jax.local_device_count(), batch_size // jax.device_count()],\n",
" num_epochs=42,\n",
" preprocess_fn=preprocess_fn,\n",
" shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TxrnWqp__TuM",
"outputId": "09c1db07-4ee9-4373-9020-bedc0620ab07"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"dict_keys(['image', 'label'])"
]
},
"metadata": {},
"execution_count": 21
}
],
"source": [
"batch = next(iter(train_ds))\n",
"batch.keys()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SWv59_4w_U0N",
"outputId": "26b628bd-1d7a-4eac-f926-a204adb8e1a1"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([2, 64, 224, 224, 3])"
]
},
"metadata": {},
"execution_count": 22
}
],
"source": [
"# local devices, per device batch size, height, width, channels\n",
"batch['image'].shape"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 437
},
"id": "k7_XF8OJIOCl",
"outputId": "9ea8496e-5343-4c29-b110-c9aaa5d809c0"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAawAAAGkCAYAAABtmxHBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9abCt2VkeCD5r+qa99xnuvTlKykQMQkwSRhiRtqkomzSSTFNgiDKo6LBM01DGkqONwFEmus0Q0RHyEIHddsjwo22EO8JgoAoTBlvdIJAIQBJGFsZGoJKEREqZeTPz3nuGPXzTGvrHet61MxGglC2UOqr9RmScm+fs4fvWt4Z3eN7nUSmlhIMd7GAHO9jBPsVNP9cXcLCDHexgBzvYs7HDgXWwgx3sYAe7EnY4sA52sIMd7GBXwg4H1sEOdrCDHexK2OHAOtjBDnawg10JOxxYBzvYwQ52sCthhwPrYAc72MEOdiXscGAd7GAHO9jBroQdDqyDHexgBzvYlbDDgXWwgx3sYAe7EnYlD6w3vvGN+IzP+Aw0TYOXv/zl+PVf//Xn+pI+5e37v//7oZR6xn8vfvGLy9+HYcBrX/taXL9+HcvlEt/wDd+AJ5544jm84k8d++Vf/mV8zdd8De6//34opfBv/s2/ecbfU0r43u/9Xtx3331o2xYPP/ww3ve+9z3jNXfu3ME3f/M34+joCCcnJ/jWb/1WbDabT+JdfGrYxxrLv/bX/tpHzdNXvvKVz3jNYSyzveENb8Cf/tN/GqvVCnfffTe+7uu+Du9973uf8Zpns64feeQRfPVXfzW6rsPdd9+Nv/23/za895/MW3nWduUOrH/9r/81Xv/61+P7vu/78B//43/ES1/6UrziFa/Ak08++Vxf2qe8fcEXfAEef/zx8t+v/MqvlL9953d+J/7tv/23+Mmf/Em87W1vw2OPPYav//qvfw6v9lPHttstXvrSl+KNb3zjH/r3f/AP/gH+yT/5J/jhH/5hvPOd78RiscArXvEKDMNQXvPN3/zN+O3f/m38/M//PH72Z38Wv/zLv4xv//Zv/2TdwqeMfayxBIBXvvKVz5inP/ZjP/aMvx/GMtvb3vY2vPa1r8U73vEO/PzP/zzmecZXfdVXYbvdltd8rHUdQsBXf/VXY5om/Nqv/Rp+9Ed/FG9605vwvd/7vc/FLX1sS1fMvuzLviy99rWvLf8fQkj3339/esMb3vAcXtWnvn3f931feulLX/qH/u38/Dw559JP/uRPlt/9zu/8TgKQ3v72t3+SrvBqGID00z/90+X/Y4zp3nvvTf/wH/7D8rvz8/NU13X6sR/7sZRSSu95z3sSgPQf/sN/KK/59//+3yelVHr00Uc/adf+qWZ/cCxTSuk1r3lN+tqv/do/8j2Hsfyj7cknn0wA0tve9raU0rNb1//u3/27pLVON2/eLK/5oR/6oXR0dJTGcfzk3sCzsCsVYU3ThHe96114+OGHy++01nj44Yfx9re//Tm8sqth73vf+3D//ffjMz/zM/HN3/zNeOSRRwAA73rXuzDP8zPG9cUvfjEeeOCBw7h+DPvgBz+ImzdvPmPsjo+P8fKXv7yM3dvf/nacnJzgS7/0S8trHn74YWit8c53vvOTfs2f6vbWt74Vd999Nz73cz8X3/Ed34Hbt2+Xvx3G8o+2i4sLAMC1a9cAPLt1/fa3vx1f9EVfhHvuuae85hWveAUuLy/x27/925/Eq392dqUOrFu3biGE8IzBBYB77rkHN2/efI6u6mrYy1/+crzpTW/Cm9/8ZvzQD/0QPvjBD+IrvuIrsF6vcfPmTVRVhZOTk2e85zCuH9tkfP64OXnz5k3cfffdz/i7tRbXrl07jO8fsFe+8pX4l//yX+Itb3kL/v7f//t429vehle96lUIIQA4jOUfZTFG/K2/9bfwZ//sn8UXfuEXAsCzWtc3b978Q+eu/O1TzexzfQEH++TYq171qvLvl7
zkJXj5y1+OBx98ED/xEz+Btm2fwys72MH29k3f9E3l31/0RV+El7zkJfisz/osvPWtb8VXfuVXPodX9qltr33ta/Ff/st/eUZd+tPRrlSEdePGDRhjPgrl8sQTT+Dee+99jq7qatrJyQle9KIX4f3vfz/uvfdeTNOE8/PzZ7zmMK4f22R8/rg5ee+9934UKMh7jzt37hzG92PYZ37mZ+LGjRt4//vfD+Awln+Yve51r8PP/uzP4pd+6Zfw/Oc/v/z+2azre++99w+du/K3TzW7UgdWVVV42ctehre85S3ldzFGvOUtb8FDDz30HF7Z1bPNZoMPfOADuO+++/Cyl70MzrlnjOt73/tePPLII4dx/Rj2whe+EPfee+8zxu7y8hLvfOc7y9g99NBDOD8/x7ve9a7yml/8xV9EjBEvf/nLP+nXfJXsIx/5CG7fvo377rsPwGEsn24pJbzuda/DT//0T+MXf/EX8cIXvvAZf3826/qhhx7Cf/7P//kZTsDP//zP4+joCJ//+Z//ybmRj8eea9THx2s//uM/nuq6Tm9605vSe97znvTt3/7t6eTk5Bkol4N9tH3Xd31Xeutb35o++MEPpl/91V9NDz/8cLpx40Z68sknU0op/fW//tfTAw88kH7xF38x/cZv/EZ66KGH0kMPPfQcX/Wnhq3X6/Tud787vfvd704A0g/+4A+md7/73en3f//3U0op/b2/9/fSyclJ+pmf+Zn0W7/1W+lrv/Zr0wtf+MLU9335jFe+8pXpT/2pP5Xe+c53pl/5lV9Jn/M5n5Ne/epXP1e39JzZHzeW6/U6ffd3f3d6+9vfnj74wQ+mX/iFX0hf8iVfkj7ncz4nDcNQPuMwltm+4zu+Ix0fH6e3vvWt6fHHHy//7Xa78pqPta699+kLv/AL01d91Vel3/zN30xvfvOb01133ZW+53u+57m4pY9pV+7ASimlf/pP/2l64IEHUlVV6cu+7MvSO97xjuf6kj7l7Ru/8RvTfffdl6qqSs973vPSN37jN6b3v//95e9936e/8Tf+Rjo9PU1d16W//Jf/cnr88cefwyv+1LFf+qVfSgA+6r/XvOY1KaUMbf+7f/fvpnvuuSfVdZ2+8iu/Mr33ve99xmfcvn07vfrVr07L5TIdHR2lb/mWb0nr9fo5uJvn1v64sdztdumrvuqr0l133ZWcc+nBBx9M3/Zt3/ZRzuhhLLP9YeMIIP3Ij/xIec2zWdcf+tCH0qte9arUtm26ceNG+q7v+q40z/Mn+W6enamUUvpkR3UHO9jBDnawg328dqVqWAc72MEOdrD/49rhwDrYwQ52sINdCTscWAc72MEOdrArYYcD62AHO9jBDnYl7HBgHexgBzvYwa6EHQ6sgx3sYAc72JWwK3tgjeOI7//+78c4js/1pVx5O4zlJ84OY/mJs8NYfuLs02Usr2wf1uXlJY6Pj3FxcYGjo6Pn+nKutB3G8hNnh7H8xNlhLD9x9ukyls9phHWQuj/YwQ52sIM9W3vODqyD1P3BDnawgx3s47HnTA/rB3/wB/Ft3/Zt+JZv+RYAwA//8A/j537u5/Av/sW/wN/5O3/nj31vjBGPPvoogBzqHuy/zWQMD2P5326HsfzE2WEsP3H2qTyWKSWs12vcf//90PqPj6GekxrWNE3oug4/9VM/ha/7uq8rv3/Na16D8/Nz/MzP/MwzXj+O4zOKhY8++uinJvX9wQ52sIMd7L/KPvzhDz9Dz+sPs+ckwvrjpO5/93d/96Ne/4Y3vAE/8AM/8FG/f9k33Y/muMIQtgCAKXkAQISGNQsAgFENAMCZCvPcAwBGn38mNSGG/B7MWYLbwcHGPCw6VnwdMKcpvwz5dcoCxmZvwFYmf1fTwan8bzXnA9anCdHE/F4doGL2DyyzsVMf0A87AEDlHACgrSsETz8ivxWuqhChAABB57/NKiDSI6lMfu84TFhvsheVeK2dqxBTvi5n8z0ZUyP5/B4d8lg19gSuHj
kcawDAMOwwhTxe0Pn7YRxS4ntj/tyVc4DN1xXjDADY7XYIIf8u+XydOjmM9jY/Lt9cZRsYXl/XrvL42Arj4Hm/HAqroFK+J5Mix28HP+bvW1E5+dq161BT/nuja7SmAwD0m5HXpwCONVx+1sO8xeX6FgDg7Pwsj0GcwalQnnHnKlg+B+vyhdVthdrmv4v/F41GMvnNE5+5R0Lg2EzzkH8XJiSd76ly+XPrmNCY/Nmzy89LBSDxPivkazdwGOd8n7bOz7BdLBBi/rx5zq9PEQicR5bX6aoKnmO4nfK4hE2A6fM1uCjzxCGYwOcgz8uDt45FlQv4XbXEYrnMr0sBH3rk9/JYbvNYxpVC4v1NnNRDGKH479rk67JKl+cdAn/OEZE3YFUeU1c7NG1e24Zrzo8ewy6vpVW34Jg6zD3n9Mi1nhSMrfNncw4mkxD4rKfAeZKAuslzypk8HtMEDAPXF/ebSinU+RJgOFYxDRjyI4ZUXhbLDhPHHIrrFhZLne8jTnmPiVpBK8Xv4LMMAZM8V37EnACttnxPALTlOOQXOCxQ2fxMnFvyAjUuLp7Kn821aVUFqx2HhusqbrGd831GmTOpQaVq/juPi0oGivuCttwfdAK4r6bthuMSMee34s64430kvPB5DyBMEe/50d/DarXCx7LnLCX48dj3fM/34PWvf335/8vLS7zgBS/AGEeYqKCq/PCN2oeTEXnSpZgHzkeVBxKAbfNPBQ1wYbo2j2aDCpyvsDyQAgI4D1Hz8FFOIfKzZYN2dkDFhxsVJ1+YAH6O0wYmcVNSeYIsdIJVfLicLNoFwMlEzZ/dpy0UF2bi501xRgIPS5UfdlABnvcpFcrZhnLwaW46XW2wOs4buY353uMwYOZ7PSe/goZWshrlM1RezQA0OKbOQPHAKtFwjLCabwr55zRMMMjfV1X5O+qqQr/Nk3iiYxAQoVf5702VF/Q0egxc1IELWdUGDXdPxU1v22+LI3J7vAPHDaGp8v3WbQtd8forbsZjwDDlsY5c2woabZ2fU93kedJCw3Chn23yob4dt6h4oMnrnKn348W/NcZAiR8S8zzw0WJO+XuD4rVED897mXgfWiV0qzxuXZU3i76fMHETro7y50XrEeRg42GHZJD4vMo1OcBwUkxTHvtkE6KT65vLGCRunoHPeoeU1w6AbcrvXWrgcsw79G5c447JG1U45tc6DR9kveSfFopHPxB9/t2MBGXyb02ik4MEH3mwRNnwgcRDPXGsUhixWOQbCHSynnzyCRgecqdH1wAAi26Jkd83+okDAoAbrp65hgEYI04aD3+zAejUcTjgk0XktSo+1zkZqDZfa8vnUDmHEx6A4Jp66qnbuDXnQ6dt88TzRiPyMBynfB8+zOVAcE2eB9ZomJjvaQwT+jG/dgr5nho3QHONidORkkGs6DjwRPVqRmUb/p1zIoxwXFeGzr8KCihnPp9hrQDuW4PfceznMs+9ztfi6gaKa8mB61EbHB8dwY9cy0pmwx9tz8mB9fFK3dd1jbquP+r3Vd3kaIOjM/r8APImz92annsMgKPHW0uUoQ0sJ/vS5YdSo8XEQY5qLt/lIickRyzphGHMrxumLX+5RXL09mQBpgTZBZxroOkdW0ZGi6aCYSQUwA3TBHg+PMMdxiPuvSB6NHH0GCY5kLnp+YTERSOLIkWDxuXNunF5wTTWopLZF7InNY8DvMl/TxzTqGYk65/+cUiIUDIenEIeFpEevecGDBuLJymRRbQRZuLGa/ICbZsGyfM5MdrwYULF/dbRfa10DWfz98mznsYRMDx8+HBSTFA81CvXwPNaNxzfzTyj4hxIfb633bhFTyfDdhznoKG5YSVutj4Bhs9O9s5xHDHNMuZ8vQYCN0PFaM5YA8tDouHvuqrFyPnRc9zGMGAK8vybMn6JTo7nmHqzj1q2MW9WehzhGPXK6ZSiKhuR52Y2D0PZdGYe/nCpvC7yUFbRQMsa4ibkzBJV2XjzL4e4w45zfp
p7NEeyteRrnkNA5GcbiRC1RmTUMNERmf0Ma+RQ4k+fYDmnDdewTQZhx02dh1MIO3Rdx3+ncnktIzFHZyIgYZzy/JEIptEOmodOYnZjRITifNQc+9oeA+g5HjwEYsgbDABorhXl89rPgw0AWOgaDZ2ny8u8ZwQfufqBmZ+hoTFyb9nyOqOKaBnlJfku7zFLpmb2mD3vpc5jUGmHRAdZ5uowjCXiHOZ8wFijoTmPIsM334+ouZdVc/6OxrWwNdcYr9nHGbPKn6d1nr86BVSakVid1/jSLdGafO+o6fBN65xtYzbk2dhzghI8SN0f7GAHO9jBPl57zlKCr3/96/Ga17wGX/qlX4ov+7Ivwz/+x/8Y2+22oAafjdUrg6BHeKa5ZqZ05jBD0buUXLBPsXh4kruttSsedLSMAJRGYjQwe6ZFjIKS2oTklKcRw5C9pHlmWiEozPRGJIUxqySlHShoKMkRS4juB0R63Yl+yzyFUocwjCgqY2BVxWvIn2dChJY0FtMsaU6SvYChP9LqBTrNXDY9xanvMUwXvIbsIamQoCzrGcVJn6DF46W3poKFEY+XSR0/BniOl6RRalVjnlmHYiRorcPCZO/rqMpRrdUOYPqkpzcalIcEahLNVcZhphdfamNIsKxDSQppnkbEIb+5bpdwjilPzoXNZgMfec/0b61ROGVKZtoxtTVM4FdD0TvXZp8ilZx7XdUAP8cx/RNjgue8CKwRQSlYRke6lojXQXNcK4mgPQpayrFmoF2HihHksOl5fSMSV7DUZRU0OqY+Za7GkEptrR/znB38GlbSoo41FBMRtEQN+XMNNDTnkeFgNEbjeCGpT873qcf6kimhYUC14Fgy6t75ody7Y5rYKYMkaXyuixhDmW+Sbo56X3vtWtamjcNWIgVJQSuDmfOiYhq5WRyVf48SsI0jAidXzUi8bupSUoicsxoJhu81zPBEGCSWdKdB7ncGvKS19tceGMFWTBX0ccSju9y2c3GRsxrWOZiSF8330Zq0T7lzrQStAEbdUerbIUGbfDGVSmj4HI8WJwCANAF9zwiW6zXsJhjWAxese1uj0bIEIDU9l2osSupOskIKimMkqdR+3GJC/reuOY+TxSLk917vrgMA7qvvQsfIf0PMwXa6wOP9GeL87HF/z9mB9Y3f+I146qmn8L3f+724efMmvviLvxhvfvObPwqIcbCDHexgBzsY8ByDLl73utfhda973X/1+wO2iAjwUYrVUljU0Eby5IL0279unMV1f1rOnh6XVaogkhLdV5dsiWCGnt73ble8R0GLRV2V9wia0McALfWZFDEx3+sLKCPCGAFO7GsOKmXPyvAnJl1qF5E1G+NrNEneS483eWjWFeTzqrlCIOTKS40oDpiImhTkXWUrGNaGIlGTycYSASgj6EkNR+CEYU4+waJi9OZ4nSnMGFgz0SwiG1Xh2jJ7yR2LsHMMaPicRi/tCxa7Pl/fRX/J71JQzLXLNafgMUs4wEhgngdoem1ezVgQKVU3LB5PCTMjEmn7WLgalhHaxcAIZjtDVbynNv9UWmPkPKoJ1KmbGmAEIwGD93MBfgSJ/FNCpLc6Sd1oUvLWAiJwyaGyAnohYKNpULGGIei/+cJDEdBhOZbRKMwSFjJ60NbA0yMewx7UInVhAVCEEASbUwAs2rgyt0yU+TljGPMzCazFpRgwMcIehgmK9V/vWRPRHq5itGj2lQiJxkudqbKYk0S/HFOlS/3WNlIPrOAIEKoZFSjbQDN6k8gIxmLg+AqABQYls8KgBUOcoSUKkQjQOdiaNV2+fpx2iCGvEc2oyoUKdSSASQtQa4ZDHqMFEZzTHHDZZzCKlyyEMah4WQ2fQ4pjibAFcBSsxRQkA8NsChI61pzauoJjlLRKgow2iG3+nB1rYpfw0HxmXHIIaYaapFbLeW4dlozUVyvOWe8xCyqV119Zt98vueYq3ZT6NDyjyzFiw3V1a52xC0/NTwFHsaydZ2NXAiX4R1kIEd
CpIKX8LGk7BSNQS07IytiPKtjpZEpqaTtL4XAP4pDisA8egWmCfpADa4ace3XLTaPqYJh+Ujwg9DiW9JlOuiAVBUKblC/12sjdrjYOtc0LRYqX8xTQEzEoKUQkW2D0NVNHs5rL4VrwwQmYeH+8JdjaoCXUNXGjtroFZnkvxyN6aAIKeOZARVdSFZaINWtqOC3oLh4WfoDjRuMdJ7XWqJhe8/zAkGIpCiuBMo8zxvOcctmlvMirhcVixY2hEtCHLkgzQVkGnZAgjogvGx+40Gs4aIF9Kpk7vqTBHJFVJ8fHiLK56j0yU8YyMAVilHma08EDVVmYxI1enIXZl2q1rNGAfbFdFr5KCeC87BYEtcxTSUVazonKdIhMHdWQ56/h+TwFsAOVIfpAhkwDgDFNceqk0B4nhRQllc6/aYNZADhacmoTzi/OeK3ceOt2n6qyVUE8aj6nztTQZg8gAnKaXtapbPTWKEBaJoyg1DSipNiJIk3QcDy8lACdkoc4LeJchZSQyiMUAEIoyMcx7lFxknrWnLNd05T2h+LweVXGpqIz0xiDllspQcLwUUNNGfIv4BCrgSMezFxKqHWNipuBlgNJzXlvA5CirM0Ghh+uVN6D+n6Al9YU5cr83e1yyu1au8SSjoyWsoMzxQlLdNK3w4hp8uVzgJyqFLTxlvNTGQXQkTJREL4aacrXODD1HVVC4GF4zvTfE5sNBkmRM405K42ltog64RaeHSnvlWVrP9jBDnawg/0fy650hBW9gqubAqyoq33qTRxj8SJqY1GzaG0kXWhreHoy0nSslCrHeMCufF6IUuSXRhZbQBzTmL+/W1Qw7GewjHJ01Gj5nkVVQxGiHep9+D8T0is9JlqZ4o2WimxEgZ5KSBGmUCDskgqxGtBgGlTAGZ0tfThK0ok2ASWFR0i/O4ViVBAYzSW/K8XgQA/JRwBMT2ktaZ4aSUAq9PQqOMAI7J0RjVXoCZxYb3MfU+VciZIl1Eo+lGbSiakXaw00b0oKxwZ7T1zSNj5qTHzuPmyhdtIgveFXJAT6uI7RcUgzPNORlaUHXdeYeT09I845jpijNJbKJFOwjArkvUarj+oTsiGg9FbSe53ihA3TdZ735pQGOI8UK/zzFFBxThnx5psG3ks6XMALdYmYB/ZjhTgi8JqrmlFfZQuARU2M8GYFE6UQT9h6NPuMg/T0+BkMdNARqNJVC0z8joQRiuMvKdekbUn3St+ZMgouydjwO1KCqyQ9yHQAgMQUdel9VB5K78caHJnExSE9fiHEp0Xb4LhMmLnek/QJagVB2GgBhyhV/h0ZMZoKsJImluhn3r9XoscxJuiUswF+l+edhUdLoIsiEGMKKIQAkSm9ql6gEoi9dE/P+zRmLRF0jOgJox/GobQzKK5/bwsfAuYd78nWiI7vGRjxhFAay0sfnk2YFNElAyH9dQ0wypYUc4IHnKQN8t/WflOa0q20YmgFLCR6y3PLjAlLY5AT0Xkv+Fh2iLAOdrCDHexgV8KudISllYWCQcVirtPitcylQ1waAx2AmgCBzmU4cuOW6AmJtdLwV6uS7++lrqU0EmsedKCxqGtYRg2JnpHRZg9N5t+iVgVefrw8gTHi4Yr36NGbDC/fEmY+w2Om1zVLPcomOHHPpSCuHSbCVgUSb1sHi+zBCBtE5RQqLWAQKb7PiP6ZkOJVZbFqc+O2UOAM/SXW21u8BoI0bCierHivkx4RCGQAC7xOqQI5L+X8ZLAhI8KGTZHG7wrTQUdorq40Gk3YO5k44HypOUlNIUZf8umW713aJXoCXcYwYk0qnVLXiKHAvJeMLiNiqRGOQv/kpxJZz7z+OUYQQ4Gm0GeFQlUkHf5zipinPf0SkJ14AYvsGG0MPmDHCLukBaxDZO3nMuQo3/sAx1qCYwOm1wqlX1laJ1LcY9I5psnPpbbWcq3YoDENrE1MhFAHA8HJK6lrWoeaWQHL6HzWDovlCQDg5Pg0f27bYkcann6n5KuRpGZboB
2AkfVi9pG/1PFUbdEs6IFLU+w0QXE+Fth42o9lS0quFIFegFCFRDUWsJVhdKBMRCytBqwfKbUHdkhWwM8lEtPsTdE2QEuyQLIgSWOM0mgvbTQ1lIA3NPWnpr4wcFSMLFXYl42lJcKggWWtCIT4+3lEmKUxN/+ps1WJdOZp2n8fwTRz3I+R0FglF7BlNCO0bqbrUDPrJHVln0bZZkpGZ4wBnnNF2lWcc3DcFJ0wDlUeSQlFlQC2DIwlDVclz7KGjamwgzwbu9IHVjIqI8QKtkCK/XPpFDfCyaYTFFd1I3QVqioUKIoHkqscJqZIBLCmnYKMqcyjtjWo+YCSoG50AKKg4gRBqArDRtVUUDwkJHyflS9hdtDC4zcVSimhY3JWQ8kEk89rK0T2LSVOXFe16Hh/Xi469CjVfi1AhxlzkM7+vNHskkM758M8MTW422xLsdQc7dMjwuIgm2OYQ0EdegJTKutQM7UlSD4/77EgppUU065shqYR4IbFPHBz4pjFFDExzSU8binGUlB2LEpX1gKjFME1Non9ckLXZWJZZEGcgGBheXgJ0i+mWHppfJCN18IwrRZGQSxqKOlj4ebjZ4+RKdKCilMobBB+FoYSlLyZ5WdE7Ddw4Y1TxpbU0VxYFQxAmp0gh1Sc4CRFzg1Mu66gzqRXLWwHxE1+hi03qxqu0Fhp7A8VcZSGmQg3jGVTTGSFGadQUqVBB1ihD6LDYmePrt5zc+axShilj5BzdtF2qPjZgQhD62w5bCRtnlIsh7QgKn1KpQSwk3URPSqOfy3rFKE8KAFdAbFQfDm7BxHM5RoI8FAGgXm2SdJ1ycNLz5hcFBKcUEwRHJKsxrChU8R5bG0FLZuKAHuSLyl0ToncI0cHSb4WpsKR0I7Fpjil0nc4RV/onEwt7CwJmkhAAT+5pkJkuWFHRK7feVQEs8g1eB8Kc4Y4Nk3VFsBR4Jg6m5CC9LLu0dKRFEwNBCFpoU3KTtaztENK8GAHO9jBDnYl7EpHWJd+RFW54j0K/1mKSYIWWOHSUsDIX+4YBc1xgwHsqShEw25ffE3CTYeSQkqFPFShEq+VnsgcIhJ7NMRLq9GgYcd4vx4QBTTgyBVmNhhi9mpGeqhjmIoHaAkfT8pCC1RYSU+ThZYwW1I+JrHkCgSGCrvJYxRGXyW9Lb5cy8DU5515QH9GL5ms3WM/wbCXQ9JJkx5K6kD4zSqvYAWbLqkEHTDGPB4D72dKoXjn1gnEuipRksDVNSKCPCfxhgfAe+l34tjrPfHoKD0/xhcv1FpXACKeSAFlLBIjLE/v26IqrA2STdJuz4sY+x3HPqLucgpq3s5lSIUtv9JCnFqjYurOWImCZmwZzQoixjkDJ8zxnIQhhvL3lv0sxuqSPvPSI6dS6QEMMj+T2oMyhKFfq336it+xSxNmjmVDcEADiwXbI6qSpYkYSxo5X/uQpsJiP++k6D+XqFGbfF95PPY9XF3H3j1eyjxMMEISo/ctHTMzA0IxVxtXwDZdK6zpA2ZG+VsytKPR0LxwPwhQyJfWECGedjCw7JsqwBOnyz0FRuJK+cKaL559UKr0t03CARrjntiXmR0VApQAVwg8MK4u6gDChKOdKpmfINkPQ1AUrwHIQBfZrTUzLCpaLAlgUG27zwzwWlNKUBJtMVMzD8hAGgCaYIloUTIvwvJtGg3bPDPVbgJQC4n1JL2btqhMTEwXRj9l0s2nfWzl1T71Jy0KxsNZi+iffUrwEGEd7GAHO9jBroRd6Qirn3qg2bOwq1JnqsASDLRgWaHgCWXezrnoaPQIr0RjiHx1SRd+vLZm/QUBo8iUJGGMUHBOcsYs3DoFeJGSyJ6P7h0Uax3nuw3ACMI1vBY7wwuAgfno4CMm1o2sFlCDQbJPa+AD0NUNOkZgAgHHNBVdHYn8lA2Ifq+NBOR8s3h5kfn3FGZ4gaQL47M1GbqKfYe9t1Nhlp8JWqlNi4Z5cmE/GOcJAwEWk3AdGl3kLxojmkUNZnqmfi
sNjECi92tKJ7yBJqecRCWuqkodQqKvDG8WKLxFxZhTakDKqn0DL6kOrLEF8i8NyMaZEuEKi/zkR0jbp/C4IezbBWrqeVW2LdyKE4uh0xTRkm3DWmmONQVAUhjEY2YpyGPEelnaPy8BuoSUShZAgEdIFhg5HqwPpZCQCPIRsI+xDdqOtQTqqZgpy6wAwMBrj86Xf4/yvcbAS72Fc6LvR7QN5VtshZ4AjPOz3GBsugrHXE9LqSXVsejDCTx/43skYXwh1ro2dakHLen1x6rBthfuUIIvlINinUxaWLapzyEf9gwW0VUwzFZIBiYBhZlEPq+CK4wuQZrnnSoRhbDGxHksbCbgeuicg+N3SHtBVWlUbR4jxUglIObaN1AanGcTEY00m1PrbxwKkcGiPsn327bAQDh9nIuummJEr7TBXJQMGMFWq8JJeTGc5zGaLhCs7G8io6IKs440XKeUCq+k4vr3c1903Bo+m3lMeyCJBPYhFf1AaVvwSAhIZf95NnalDywVPFSIZTMRpoAQQ0ndCWWNMwZKEDWUYvBTj8jNR8JSEwMs6ZBEDybNESFJ0RL8WSHyddELlYspCKOK5I/zlOD7PZOAFJQFZVfPC7SUNhEUzSZe4HK8k69RenQQMEt/vEwCVFguhV4n/3K3XmPa5fsTja9t7AuNlBawgVd7RQSzF/WrGxJotsf5+pLDbpMXzcjFU+lUNhrIppz29DkltYVQdJVqEZqsHTwP1I7XXCmHLVlKZvaGGKWK+GYlOkVLXSa3KxuvK4ilSVaJUnAQuZJYEG8N0yM+Bqwp7zBzBSyPLLQVmRc5uPQeEFFobAK86A4VCiFTEKoFLlhZiHCESMCEGFHxGtpKBPCAKYjSnyA5U6FC2ImIXkpFKG+v86KQkjBT8GDzETMRcEKMbBMQCS0rDDBJF0LZOslzS5ikB1HmubGYBVAiPVNQMCzcazpWDRrctbobALA66nCGLBIYpvxzpzY4F/0lPpuuqiU7VVJzcQrl2YkES+UMaqHdIpK2sgYt2VOOmCasdYcdL3zJdKyq9J4NhBI7UQGqytdiRfzRhJLShDDYuA4uMP3Lz62cLiwkSphn4lxQf7KW7z85LbRktyZqg82hOJGSv0vRl1KGSK2M0wQo6ZEiUMj7giIUJHJlalxSaDXA74Fa/DydNMDnozjQy3aFplvyd0Ir40uaVhhsUtqTHgem/6dpRi1IS+zBQ/RJC9WXda70AoKpQxVDSePLqRMTYJoacRIUyce2Q0rwYAc72MEOdiXsSkdY3gdoRGiRhg6SzhgLSELUiJvaoRF85rwv5kfxkuiV9HNCpSUVJa8HOvFqCdPWtiresjA7WNjCbCf9Oz55JEYItXNwy2cyPyyxKGkaoed31hb5+J5Q4jkEVEwZ9pPwGY7YGELIuxwRhbrGdif6BxSY9B6GhV/FFIdWEZYRnXg+2jigIYDBCGtBhFqS1HKWQrpFh5z6EgKL87Ev41/aKoxG04ja7x6UMCTh4BN3PpXeElHJrW1Temk00xrKToWZZO9qBSh6fU+/n46e7DCM5bXS/zNMCZsN0z9USR29gmkJahEi2aDKc5RStlWxpGRBQIQzDRqmAhU97MFH9AQUTHEv4Ch5a4G8m6BQ0attlnmsNuOuRBx4WhrYF4kbFux1VaDw8hxiCIiiYG33sjVRINZ8OC6pIljYM4pLxsJzDsZCnFshFk5E6XfURSJEPO0UUGTnFSxaevHXeC2tSwUWP27z3DJThGWaq5oFtj4XcJH0XA3zWKRBJN3d1RXAiGKxzN81aItLEiZbph8tTOkpbKOIf86YIH2X5IDUCok9jaaINTZwNj/PIrGjAgbKinhGPyrFEhVWMm4whYxb6f0cktaKEmwYjUCWF4F32zk+TcxRWHw6NJxbksberbe4GPOY1p0tgbfIHemoSlbJU25nuz3Dts+RpkDeqyqL4QIleESKCpHfM5PLM/qxqDJLL5VNad9vyNS3axaFQUSyX6ZKJW
2+4Tzo/YzjxWlJfz4bO0RYBzvYwQ52sCthVzrCglEIRsEW6Cl/egU/S+REuYxoS0NdQ+iuGnpEwT8XiQJXGK0DgQBNNOiY7zcire5syUcXJuc4FxdFCpV6meDFe1QRk5Ooh96ZnmEIu4U0yGJGJc2GBG+4qGCUgCPy9637HrfTebluIMvJbwl+EBmKha0L+7sWuRUdi0RA4flDLMXjgdDp2ZjSLG1bYUtoccNdA7Dv2K8uz3A5Zc9zEKYI7aCl0Mqm3jBNmOl9CTBCpRkzBTQd4eFwujCISKd9imEvZRCexrsoYAph40YsQJEQ58JKrgjKqGKFI50j0h3hz+OdCHvCaJxjZZxGIhQ+JmFEUWgaQntjHsumXqBqRJo9X+swjiWKKjVH5wq/YywyDRY13yuRrsW+8dJ2jArnAX4en/E6bfUe9iwdximWiFNECj1iiYQsCzQqONSsa4xSMEumMHpIY6jRzX52SGSp5gzvBqCFQTwprEMGWKy3MzSh8IrCnJWuAXrYXthbVFUg1sYIf7kvUGiJFHrviySQ9OV244y6zA/y2lVrDNJAL9Ijri5gqzWlgYZhQM/7lPWslIblXJUG3GgjYiW8h4zYpwk9G+RHzveQptIWI7Wix89uFen4USJilwpTunAXYPZQIm1fmHhqDAPXCG/Yuqo8k573MQ0egcztKi4K04XA3oGEmWMz671Iq2a4uqRag0oKnnvdwDpjSvs9zDMqNCkWsIpICQVnMLFmKk3zrn4a4KvMpxnKCo8pf7Ub0Q+XCNK/8CzsSh9Y3fERlFX7HqPS+a6KtlEhQPK+aNtIcTAGVfR6LIlzq64p7BL9nZwqQYwwghwzQgmUyhfK5tiPY2HYbPldrjUI7OEZ5n2qUiiSNtMd6IkotgLeUGBmrPRKIblCzSLkDJUzGLmo71xmkEbVdBikCU2oUqIu3ejycVobmFpoeEiSGuaiI2UF6GCAgXQycoDrkLBkStBw5R01y9ILJFXrSadCzSOUNRMKsKkgIIEJganNWTFFlwysSBjwkF2oZt9zF/dIKMXPDkIE62f0vJbJe1RSWJe+JFQ47nJ6BUSzeR9Rk+miZRrINhZz4rObKKcRfUndLGoCQJoAVwmKUDanEZaHnJCWZtYHIkslVW1qGBbRpR9uVoDuSJzK79pu58KIcnJMdKV1GHdCjivSM09nySDiS8enaa3xIPIBlrIWNT9PQcFxPiUjz20oTCeDyMakqchoLLtMO1RXDUZu5PNwCe2E9UCcpxp2kb+nUBAlB4J0EXuCYKYNKj5bkdZwrsZWaL84d3ZeoaOqcV3nuTgN53Acr56MDeudQiBSStCOUwxYLvJ7Oh5OVumyQUv6d5w3mNkoZmpJx0/oSUjcTwJG6st6Cdxbep9Q0cGTHi2lAcfeJ5FfGee+qEovV+x38qZQaQlDTFQKE5/xKGTZ2PfLpXEuoA3lRBomYYpC+0UNt6YuKWghzB6HvTM0Mb1X17YgcbcEXVW2Li2W4NwKPpb+QHmutXPFkRmor3e5uUBNFOGS86CaNfr+/ONSHD6kBA92sIMd7GBXwq52hGUj5slDzcKTxrRTdIiFR008cY2Jp70ixX/V7lNfEiUZa+BJONpRZRaTx8Wc4bS1k9Rasxeso/e6iVMh1dT0hhtboyKDRQ0FK5EJP+f25jY22+zBCMeeWrZIwhCghCQzYhYPVxxBp0pvQ6TnNky70pvld6KsPJV0mRDAJpOwkTQLvengx8IAsKJH6Vwq0g47pt6e6i+xJSfaSS0ikBHaCWEnPWOtCsddARFoB7B/aZg5znULzYK9H6UPq4IjwMWJEqsfi4K0ojcXlYUXSLRIrDRtASNotUEi20bpA7E1Jj7alqmNRtc4PclErp5pW6i5sAdcsj+papaoSSTaM1nW2BquEi+Rzx9bOEbgNdsfjoPBStKNjAqCttiwcD4PZMEIIwKJaROjxhuuKWkutsxg9GMR9ZtqRhFTwMC8pMjWNJWGE6wKL9
OYuvDfVVOGRm/6EUYyA1V+rv00YWTB3hRBxVg+RyRloq3gmvxeryO2EwmHmbrtFoDjPAPn+7yLsEzJSYZjVgYj13HDZ7MyGobFe5HTmOeImaCSkffbuw7TkD9bejbmMBbZDoH+d9phSYLomnuH8h6a/ZKWkd12ngt3ZSVyNt6UNDi4RhRQeu58ibBjybaUDB2GvSwOx3JEhKbkcCTUfpxmbBitSnrP2KrwAdZ8RhGAYxnh7PwJLBaiAs7yRqpKy0IgGXjrFGor/KRcBFUNx3nZSIcDNNIkxOH5+paLIzTkiJRyg48jNFs5jMynObOmAEDk/F1YXeSJhOdzAYc6ae4R0trxx9shwjrYwQ52sINdCbvSEdZumGCdKjBvga/WtsY80FNnJ2TTAdqJ18WaTApIjACGQfLDCRXF1TqRf9eAEol3odPQLZQSqen8q8ouoFk4PepO8i+DLxDhiIAl4c+pz286aU5RRcmPs+YwalRscpRrmTCURtDOCow3oGcueb0hh2FdI4ksNj2pebKlqVDH7IEmq/acXqM0SDuA9cBdzwL0rAr8/Njma5pCRDgnMwEhr6lWRQ5dROCmNBfobEfPTCmgi5LHZ94/jKU5sV7k+11ai0pE/cJecE66+QNh99EZJCbyoxPmjhmLMT+7qjPQotct/IJ6LhBy7/bwZz7uUvPY7fYs8onXUtsKR23OwR/R0R7HHvMFo1V+SWsq+PzRRareq4CBRZuaoBvtAwwjYbVhHWzymFgzXXLcjDXYrsmkLeKIjYWt9lEDAFRRww7iqdNDDrmtA8jRIJA96ImR6Z3LHGFpV5Vax9wz6ksBrhKAhdCsGygZS0KfG1jUIs+hPDALA4NEeyOunbIVgmP05K2nsNvk2uBilce0O2kw1fnvW8UaEQDNeuuy4mdEDSMSMtzGBhVgGeVJHVKNOwy8T2FYaOslItecRPR251EPhL+zlqlTxJoAhVGIZNK6NPBHiWSsLi0CIpOjYsDUCBZf2FQUlDDmSJnZWDhGgP02P7cwJVRGamuseUZbBCtjkfbRRSqkqTQqqY9xDft5Luz2lTDEqKrsjSNldwCLjnO6kaxSDBjZIL8gOMOafZO+1Iut1oVJJPJvZ9OdAnSq2SaxrDrEXsAlzDJ1HWan4KcA4AzPxq70gdXULaL2sEyHyMaQgkLihisSIXMfwDYGAdvAq333tSI8I3kUds5SpHcORhB1BZ3WFI0eAXt01mHJIW1mynMMOzx5mR9Gs2rhOOkW7M7vlULFA7QR6YmooEjBoAotSoRt8sSRuHieBnRWZB5Es8gUElch6Ry0hi7d+QzLZw9pbklkLWjqYzgrPSFMMfZTUTJqiY6rVQUvgIkgKdfMbAGgaAht512RPA2J/WsJcAL742KyPqLmNS+5MSUfcMENWtKJs58KO4MQ6CKaPXmmoDtVgtdMAybsFamFSssoPmhAE8kVfcREpJ3QSHk/l/48R4qkuZ8Q2vz3hrtO7AMs70lQjgaqkKOKqG1Cwo7FbU/6JN2bpzlXnG+q3Y85++YwBwwkmvVEBi6qSs4XWCKt3BTgmFv2an84iSKyzL8YEzwdh6YiyKSuCuWSOIFRJ9TibAiwwFtozn1Jw+7Gvuh+peRRy3N6GgNEyz4uYToxg8F8RsTrZT4g1dZjukaWB65XWI2Oi9dxXtbaIbIXyNMx89OA2glTg0h/qJLWFW0p0yYoAohmokR99HBE6wr4xkaHjmtbgCfruEXUQuS8pzGy/BKC9qDHAM3U4iQUaA2QqHAtaMyUFAL3LS5XuK6FYbkh9XzmYyjgHckxqkoXtKyBLjR0WgiutSrlClO0vlpEzlu5d6UBVQljCtOnVsHTIRBQ5zyHp6H+ON/1ntJMZN3GfsA4UPeLEOPGVqi4N94gUOeu1V2IU8Q0evwanp0dUoIHO9jBDnawK2FXOsJSPvczKHq3heDR1KgqUS1lekzFfR8LPe6qSVlAESj9BcZYGHoFthbwwD4FItFUmO
cimih/a3wqRLeeKZEAX+L/uIiYRdiMnvsUGiiSsy4YQSVlMGmJVtgDoR1mRpKzkJ3pCq3wd5GFYpq3qOkSGXbxK6dhPfuvtrwW71EzrBfuMes8HMlDlQjSaVVSngHCq+dQnQrhsIgOBsyRDADSX1U3hQtxy2iutU2RWEgU2Vsag2OB2jI9uZknDF4AIiw8uwaGHn1HWLrRGj29uYnpncpZBHrTc0SBDS9W2bNzpipepnDsbdZrrDe3ed1kdKhcIVbWQlDbb9Cvc6rKCJExTEkfF5aP3RqK4X3FqNEri57RT5QovtLwMkcZwR45DcWIU5Hg9WK7LqwBLdOZy5MlKkmrboVzbh99pijq1xmSDABWIkVEzCKdIs8wmbKWLMEtQScoSPsDU686Po3VVEhwJwwEilROFdkTkV1x0WGkRy9sFXVl0UiUJKKp/YyO0HuBZ2/jXAr2MyPJulGIbCXwRjgpE7Tw40Xhv+sxsJeOwsTowxqOEVgkwCYuEwbhDSSfZdgqaK4r4SNMdsrkjAC8tACk/Rxzap/luW5y6mtO+5aHksnhvrMedwU0csTnWtctHDMXAt5SURXByK2AjBKgCR6xCVAcw5bCqMmpkhES4UUNA8tIsil8nLbwdgaWUJSrCihDEZSTUip9kCIHBG1Kj6VhJO56C7vNYzNxfQ1VgGa6Ufrdlg4IuwFOAFnPwg4R1sEOdrCDHexK2JWOsCw0hjlgwYJhTS8do0HHAq+d8++Gocd2kga3fNsnqxXGxMa28ZwfmmAIzhD27hTHwmYhLOvDboSn92v3zdzYkpm9JWhCdzVOTvK1DHrGhc/e+Tnl31O/wJIQ4hU9LMAgidy40PiHp3F20dN2VYXAIn5gHnkcp8IcLlHB0rkiyyHRg1IJiRFbzxrapGYEduorqcV0VWnWLXULpUptTaJRF2sYka8nfXNd1agi5RmY316trqFm1DKZXLforCu1xC1Z1FPs8LzrLwAAXDu6K795ArZ9bi/QgtM2Hn7MjODb/g7fG7Fa5Ghq0TQluhAeyCHOkHKAsJ8YPeYHCBSouKktasLnkzB5xwaJYdQFI4ZWO7QEbzT09pdGQ7G2puW9MKUJWne8lkZhIuBHJAaWrcWC+OL1bfL89ROaFUFFR/yMNhVuQnF4lXPl+UvkHMZQWGCKhgVCka+P9Li1NUXOXTNC2Y495l6kWvbKBwIbFwaNak9sjzRPe9mWRpg6Zji+QGrIqVGoT/Jntow8urZBwwhBOELjeLmXnyFzw4iAyLqMOOgnsSvipoO0CKQZdSU1IkYPkSKpyCwVQFYRMFHqthSq9CEz9gNFxqVxGqYRphZmD7YTpPdVmt2tqYDEfYmRYtB9aZLtyLFpjIIm8Ou6zq8/cUcIBFOJyGJKwAUh+zFK47DF0epG/p2dAIIkVpXIn4zYiGBoERadADbpLyhE2lR1UaEYuXYjXMlE2Uog6jMmwuOleV43DSpG7ZL9qJIqtT8hazB1C83nZVifv1yP2FyuMX8cbO1X+sDqxwmuq9H+wQOrsmgU43+hO7oTMPOAqTipGr0ogz2QZqUP23IgSMov+YjIw05C+tn7smiF7l91NVJHVB83rm5hUWsBPGj03JwUNzjECL3gxtaQBcOPiMLeUch5R0CYLgjkSWHGTiRThHLHGRj5bKZe7OBLUbVQSyGVTvfC3gtAZH0Cd0Bd7TveKx5mMQGehV0hrqzaDq2T/hD2kOgIJZRMRV7EoBZeVy5aZTR6gjNAHa4Te4LPue+zAQCf94LPBwA0aokPP/UhAMDjFx8GAFz4O9jx2V1uLzlkEYbUNta2RZH4gq/bjpvS47NUS465KRRZUXSn0lwoa0Q3CVohsUj++HlOlSyNLQfgMZ+/dhFGJDG8UHd5JHEwBLMSUWCmwkDQOoOFKAmv8nUexxaJac6JNEZD2O1lKiRd2LWou+wgVdxkx/WIcZc3op59gh6+pIyFQSPpAqQs2lxTPxQyWAGe1EFhZhp2ZtrT1A
bClNQPo9TkC7ltsl0hgR45ZwfjoY6IJuMmm7QcA0BDzaguNZjloOW1ep+wFCVf0Q5LDtvdeb4ubqyVU3BVnlNH9TVeU4WLlJ/djkSs1uuiYSeoSLMK8EUuRsoNfr9rEnUYh4gwCOuNOHINJie9oATBJFMAER2dp6U38FsqJj+Vr/nc3iqUcS3TbLa2hV3CFuVxjcWKCFjrMBGoYQu/T4SWPs6idDwiJB58PNyjChg5R7ejsLNUewLvILpqfs+2I3okKolAe0mbKhfR8t6vLbPj2LnjwrIy0pE6nz0ukUo/4LOxQ0rwYAc72MEOdiXsSkdY67HHsllgFI4toec3DqoRhVWmrJLBtZr9FeRQGy63BSDgeeoPCBkoAZT0gp9DSXEk6YHRFhU91IYd74vuCJ6u5Yb8W3F7gbtZkD9RDlt6hVH6SY5MkWoIJKj0es5U/kDp8G+SgmKPV6THNum5QPlHEmIemQZ3uZN8PVShDf0Wa6b9RoI4xtgX0ljP8KBPsfQiJeGhU6l4/kuJzmZgQw98Q57EuTFg9kcUShCmXRGHFH7GzXrEllGvpErizsMTsn3aZi/YzBGXT53na72Wx/J5D9yFjc/e8mOXLGSPExwh8yudmSqSj2gIoU5j2qd/OgIn9FRSvIo9WvMYSs+KkMamlBCCkMUSTm26ktYZeU+N0aWwHxh5DL7HrEThmJGbTYggEwNh7TZWWLl83SuKZmodMa/XfC/79Y4XBSRxSUFAWI0o3Ir8LlNVJYWjhDewToVPMdLzHcdNIb0VZhcVDYL0HRFOreGwpHRNx/Xj+xGKrRC2eNehKOtuxxmBkcRs87hcO7mWyaGBIpOCxhX1XpHGGIZN6WUUkIeKBi3Vqd2SGRE/4pTzvBJo/TwWcIH47EYZWAKgVj5H0887eT42Oo/hY9NjAJDZRrhnNEzXNgsLJdFRIVPeFYaWmhe9qtvCLqIh5LYVFsf7awWy8OXyKK/7RS3CkFvsyKN46yKntNeIBZhwei2/vkUDL0TMAs83AZ7k14gaux3BQARimE6hYrYlEcQzhRmzl/JBHoPL1CMICEnkTJp9n9485s+11sJYiSAl4lUFuFTETl1XBGW3jInGeSoikV7UytMEb3bwB3mRgx3sYAc72KebXekIS6kKymtMa3oe9Kt05QqDcMsoo6uqAuec6D1uNrsiYSBeS+VqBM8csCXVPlTxWqXgHqe0Z8KgHPc0BHjJI9MDnXcB52wMXdYWmo2DDT2V5BKiQL/77IkN0654+cLabIzNkFoAo2T5nYFhHe1IisJDwt02e8T3re7J11DNON/cAgDcnnMT8+0US75dZDesVzC8F11cmYSJhdg1gSBOOXj+bkeI9+gucW2Rox/HMQq9xsx8+0wgy/k8oOHfu0Ao7ZAQdhx/sqPPKeCJJ54AAHz4id8HAHzOZ39WYRl49InsGe/6vnA6+k3+uagXWDDq9ToUxo+GMh7H2mJLVoFJWDSMRyD82AoP2hxgpE6hWXeJY2lkXrL5+3SxKGJ9t/i5s2uwWp4AAFasjRg/YSJEWLIC2gBtQ7aKZb6+2rW49VS+l61/PF9TpTCJXDrH0jm7hytzftZVjYmRzLzL3+Xg0DSMUESA1NW4Qxj/yPps27aFu87yM1wysITiV4wKKw8MXF8jazt2WUN5YfBOmBlxLFyOEI5ch4ns5hXXpOnqvaDlJNoUyyKMKhDq0QENx0i0WrpqCcNMSeJ7x37eRxTCiBIDxnMqAJDZoesMjuvreRwW+TseVU9iJAu7isIoAygWaGZG6bthiyrJ+uO4VBZHOj/jJT+31h228Ty/jLXaqIAVmcot72MNjUWXI2xNHk/021JDtqRLsYMrMHphkhlVxAWZZsxsEFgcDpwf1jg0WiI51uV7g3GbP3O3oWLE2YTEvfPkOhlirqWi4iAs9l1To2FEPIl0DtIzRFcBwCoDxTU+cW5tpwEThUIV67jGzUhpKIKkz8au9IHVtQ2sSiKzAyvCTa
ba5xiCdOdrzDOVOklkOgYPgVdJ4dbUGp4PfOIkVSmjiIA9mwYc0JBMU8LpcduXImfHnqo5AmsWukc1wQlhKhdF3Eyll2kkmmyax6JzJKF3ZZpCUWSFcDSFwqIQOQiznvAk0XKKB/Rd4RT32oy0u16fAACemM+x4wE08D4N+qxRBEBQBPM8YEsww5bIsVW3KN3tJW3a9+gFedfkTUohIkzP1HCqdELHvqRW8o+1w2ilL4WLRAUoglD6OR+2uw9uC/Gs9Jf4OJd0qBIxKl0hcMMaMRV6G3B8nVdwwkLAAzV1Co4MB0s+u5gCPJ2DkejOOPsivXGDSMql1xjoBFmmIhcn13D9KCO4OrIfhGmNNVOyAw+4VOtCQZQIMqgxFAmLjZbUXMLE3wUh8x00Ku5eDQ+GRttCwGrokHR1hYrPtRS4bYMxcjMjUrZqOhghTqY+mfYBmiiuSImPk9USIjRxHqnw21Q4YZp7qhpMo/QEEQHr6qLyOxUNOl82Spkfdd0VItwNQTJRRwTSNSX22q10C+tFq4yyISaU5z6UMdjDFy/WeXN/5NGP4AXX7s7XLWnHpDDw4EvSWBlmJJFH4Xqd5qGsOZlDldJoKXWykLkfG9whAEQE3cLs4XZEJfLAr6LC9eN8yIHO7G4342LDtK84CdHs9cu4DnvvMRNfZSuLqhX2bP6IQGL6Ogyi8pz14ABgIPpT9RpGUL/sO0tzKEhccfgSiBDF/hAzzqEW9OUkxMQDJv470NkJXhWpHNGWcylgnjcI0wF0cbCDHexgB/s0sysdYRmb5RWEnLX0GiiFwJB0JBxWGVcqxBIyx+SQGLV4ATyEqZBaOknbhVhSAkHt+QWFhFQi2ln30PTYBnqOoUrFW/V2QhBBtS3TcCrCOVGcJbS3ioVENSYBfvhSRE+J4Af/NIE8esneeNzc5FTa5VaEI5+Pz76W04PPO8mR1tIf49Y2pwfvkECzT2qfipAxTQmJqQihKHRWF69rxTRGWG8xM5oCucIWqxVagiSE7WGKM6pKxpCpIViMDVOG5zma2w09DNV2pSN/8hM8x2N5nD3aJjWoWnqhfDbHywUUI+J+8ugZJc1MY7oIVOy6F2c/TAnHhINberLJA6NA8AkE8PCYISS17GPzvrRHnJJ8NVm7T6kICCI5+CnPj0uCVXzokcbcW9b2Z7wPXTz/JWH+Ic2IIp/DFGic92rVLYEd3exK9KlIjFtVVVHqlehgnmc4roPEdKePm6Km3BFBo6NCYKpamnXUtRo15+wRo7SQ5tKDZoLDopM5zWtOUyHPTbz+eZoLI0YnrAvGoKcHHhTbI1yEZEJciZYHqCQsGYTshx0Mn7tlajCGfQptJuDp0sy4PZ0DADZUJB9ihFdyrfn1HqmkxrXwKTanSCKgKOORdEkJ32aa1SeHke0inkCWOM+wAjNnhHLcdTgmIfZwwWh+Mxf5lvFpMHLH9N5EeZbNxQb+hMCak6ookV9IKlgDhgAXx+xTXTso7nnNUf68B67fKO0Ra6btxmlbWgikJcVojZpzKvF+vffwY57LWugbZ4XAMoiMpdIai1okgeQZOUocJYgsz8eyQ4R1sIMd7GAHuxJ2pSOscdzBuBqW3pl0h++GAZH55YZQ3CwbzibLio2SNiGwYVXkvX2cioy98O+1TVu6viXbOs0TRtaSUtEmTIhJir3SYa6LiGTSCTPrAFJmrE1Voiglnr0CIq9BvLgY5pKbFsaJunaIzGc7KV7Ps5CWo+d37dQWiXLuSupH9QorelA98+p1v8ElIaxyLQgTLN09V0mjaSjM7MRoYJFqDNKYzVl1uliW5lmRZen9iK6SSIFjNU2oeZ9LaYpOMSfcsYd2j/DlPZG8dcfdAi3hwxKVmCbBLtjFPwEDOQsnwn6j92hZC9NkVTDjjIoR0Urqla7GjtIfI2sizfIYG9aQPnT7UQDA3asT3H9vjmAVQQi3hw3OyZhQn2Sofu0WMISr143wzA3o2awpXIY2JiwYIX
ymwJ+dKkKWl5Poojt0HSNNw7qbj5hE/I9etfczDL3kjuwGehvKXJTxRfLo2JOwILR7Gnus6ZFLXfPS9Ii8lpiE46+XcgtU2nNzztIQ7ic4eucjWwVCmIskiXPCB6owczzAtbRo6zKPVow42+CghXyEEeA11SIQwj5Lq0ZMRUpGuPF2asJjhXGdNa8qQbE9wgoDPmKpAxYexbqBdqzV9GxhmGbAC4gHHCtXxFcXvG9Tt4U7sZbm5LbG1lAiiKCQYbFBqIQwgFFaZTHzuW+4XsfJYzgT+ZsRFaPjLetHMU5oiiyLyMtETByvhmN5erIqdfjt7Rzt+zEUGIClOOlxu8QRgVXnd3KdfLfZwHETWPDznGngGNL3ZN9IKhUyBMNrUcqi7U7gbQQgUid/vF3pAyt6IKQEz+3fc/LNYUCQtIgl4MFbcK+DeRpVSzJCYCmF4KcBBPggmropSBZhLQh+LA9ZMe1oaoMoEhuzgAdQqPjTNJcOdyGMrHWDgSs9MsUR1H6Dt7KQTYTmpq55rSag9AmJ5MQEICphlWBa6cRgXecJcUbqo9BUYFsVttzIJ7+GD9IFL3pGM7SgWoikHOZdYTJQQl1TJRxJvxbTPEud+3MA4JwTd4tJGGQKVY6rDCI3RRUEiehLCklkLcYYioTF5Tmf68UOK6L/Gm4Mu80WCwJsVFOha3OK0gs69HINQ4dmcZwdmhSnkoqCIKB8KNRXW27uUxqwJTnqkgXqha9wjcwqkorCEDAxldrMoknksGC61Ev/0Tyi4dwaqWO1O7/E2Xn+vttV3hjQKAxM3c2cB3XbojKimUOiWH8H51M++DQpmipVo6tF94tgmjhA3K9GVLRhClvBxENjmEd4prZkHsypL/InigdTMAE7poTrus3M1AAmFvvn6Q5quVYjNEEKiRtb4CbrUsQ8yuSiI2pQ0ITN6QkAYFUdY9jmMVpf7vXtjk4z4q5i/npz+ykMTJF1S8qHtAY7AlwmWdcBhdFDSLRDiBiJ+hT2k6PjU1TU7tryMO4RC8JUeo1SGJD43E/JsOGgiwq0KA5vJ4/NkFNqF0zRR/RAYabhflG10GO+/rvrBwAAL/qMl6Ln4fPk2U1MEMLZfM0XU1+0+EaWB5wBaumh4jAPOhai79VJnp91qjGxX8oyXVfHCDXK2hAC3gjPOSApUGV0OaydEdfcF8ds5rgYpci88+xRgoeU4MEOdrCDHexK2JWOsKxpME4JA1kcxHEwKpYKYKDkhbGuCKBJp70PqaRZxONJMWGkF6EWIp3gi8SGF2/ZtMWLE1XgGVMBYgj8Pc4zenpVOiQsj4QvLnu1Q++BIGmCnNbR1gD07JWR8H7GQJaEkdHIdjujbYQzMd/vyXKFngVPieLGZsRTMXvdTxHaO022eN1g9IAapag68zviuFdlnpmGU0qV9I7moPdzX/j0pEcrqqp4RJpepE0jzjVTh0xPdW1TmCd6AZuYVFJCUjhujCk9HzK+/cVQ0oRH9+T+s3Ec0V/So4ypyG00HKvJVhgZDeiBkWQVcM4+oXmXx6pyFgPHplvl93qlCyz7eeok/62pEBn5iZhkZRQi58ru4hwAsFUavYh0sg+rShqWnv3RMnvut3uFNenlHrX5eTW2LQAGJKZe3AJKgBizRNozvBHPmGSkdQVFT/f8IhMFD6OH4TM8EUVZ6zAwyttshWQ4wbJYDqai+3ko9JNaZGjgcUnCWYtQiHxDTxh3ikXKRwA7qCyCbEEiuxI0LLn/PKPk9cVlATXUFdlA7BHOiYi5ybEMqwqBXI6OefHYWkTOxyBR8i4iEPQkgpZd3RVA1I6R8Rg8vLSScP13/QBeFmZGYnZ1hN06R0eeYBo1z5iMgC0YbewsNNPmpys+wwoIM5kkpJ9QVYJ3gKXcR5ojDOV2Hrj+IADgix58KZ56PPcovuv2JTayRxGcMcQtKjJvCN/pHCcMXE8t1b3j0kKRCLJlv56eDAIjUwFTbM4voeb82VIigdI4JwhEM/
NwsjRYcC+r+CyNBqou35RA9p1rAGWg5QuehV3pA0upGrthW5o/G9LOdNaU3LOk8GKMJa0gst0xRRjmtduWWi5B5XoRgDuXOR2jtS75XseUj3N7hmNJ0aQQ0TjpY2FDqo971c+mgqlz7WLke87jCMPGUkGuVcYgUbvLSJNiGAqCj+A09JNHVcnhlH/XVaqku4RG6oM3H8XC5u/YcazGWeOMaeOGKc2qqVGzUdZAvj8WnanQS8OqQ00lXEFq2TCVzTgpooY8cI3otfvYB2R1hSePma7hBnax2WLNnhvph1u0LQwhRokOhIkBDXu3pDa5m5tCWnqtPeUgKNwa8sY8JY+RRKjS/L06XWEaiIxkCnS3GdAPctDmj5kwguscO/ZhdW2Nu+/KaRPdCyI04jzczON6IelThSQpElF4HfuCvpQmZgRgx7TTbiNO1gjXsUBzzGbori7y8A3H3kaFaRKUmzwbhWsu9/VI0/NCNRjZEySHJ+oaigfR3USOVnWDW+f5sA6s3YU0lYZ8kX83FeCJrptIDpz0Xn/NOQvLtF7Pz3HNUVHglTWZTEDFWo4oKJhoUWnRS8uXullvSsr4kmz9MWmcr1l31sKAXiPwGSbWdiu1b/bvN0QC+wmeKUipybjKIjAVLETYXmko7g8iaf/E5g4wEOVGTbZFVcGSbkxqdrVzaIi0nNds+N5MOCEjuzTNV0nB0hFxlThtCVpSb5NQiE045v7w4I38vD7rnntwdpbnnVm0OGLjedtwHdw2GEj7BEnDQSFwbEaXx+/WxVNIQlEnzcl+KqhDK+rcISKJPpfdO/MzUZBSYggpYRB0Mx24xuxVBBZUuK6qBlGpohf2bOyQEjzYwQ52sINdCbvSEVaKgJ8TmAGDKt3jGkaLIi4BGSFAmlGECFKlVMAWiRGRUqn0UonnlhDRz6ITI2g8v9frYSRgooehN+rpjfbjnpCzrSswO4DLyxwW78xcqICMJy3RHGGYnloQstTohJaklq2kR3REoOc0MxK4Na3L/UlkF4OFlfcwpdbUdfFgRSV1iqmkyiLDuOAVEtFJMzW8XNSlKYuZS1S2KWwAI9Mx57sdan7JfctceL5/sULFFMBNScGNO0SCBo4WOeWzahbwLNhPUw4FTTWjY7+T9aRbWi1hGHEkerdd2+2L3yqUZ6EFnagDdiR8TWRfgE/QG0YQjNSrqikEx0aADlNAYt9JWgjQYcYTjFw29NyXRyssxXunplFVm4LmHLb59RcS5mIP9jm6toCXMI+IsOsq4ZQp3LbKEd5uGPDEJqeiRqbFF3qFI6bNHKNa3XuM/BwnCNhm3yd0eZGjlm4VChuBJShkHgfMlCbRTDEO04yeaaymaN2okiacxi1WjoS5rWQUXJmrA8fe67msP1cLSjBhFBVlIV1ZGBiqck9MuZ1deAxc+KJqfLoeCqjBM1qytSnMHyNZH6aoMRMYtNWS+r4sfYTyHBrbFGXqYRQKL4OlUEsJ/dDZBk6kNQiSME7jhKm5W+c52kdv0LIkoJgC7YeIjWiPJdG2Azrm1x3BEDokgACyO5ssrfNbvzvgfWc5Iv7w7nah1Vqscjquu+s6zI5pf66NCgED+/2CgH38DvvMLMcPo7RmFbJc5QHsmD0RViFlcLzkvkSGmKZblHm02xAlaGMB7zSM/JfNCjAaszkoDh/sYAc72ME+zexKR1jr7QY+pYJCviSB6lR7HHfiBbGmMIXCrGBZbA5hQmStaMI+99wxkhGOsnkO8PSwRS7B1TUSPQOpiSVMhX9wDuKlL7HZkRduGjGT9WJHdolVPILAOnfr7C1tVELNnP4s8Ouk0TESKsCDZIt44bTdCzlaelode2lMa3A5s/udbBkn43Vcd5lPbcVen83mKdzh69YCR04GO5EmYI8RLBCZY1+wFtMpC8tamPS2AUBPT/ZOJdyEI6Y7+b3XSAS6Wp3g4iJHnPe7zL+ndcItlT3BoWXePOwwUHDvfMz1RacXcISUR0bBl7tQ+n
AWi64QEktfl089BtZ+tAgDBsCR964hL5yPI9JOYPaMPJLBdZKVCqPEZt0XQUhPFeeECYb1vSPJ3eu2AHru0PO88B6J11pLK4NPRZk4yJx1DpbMHxc2v/eJuMZ6RQaWjUC7bQGkrNhL1yHgiP042y7Xt85TQGA0cmkyV+Pl5hKGEazAoM2cynzasW61Q0CUwj7vu6kaOH7e0F9iNjlytMJvN8Wc5cAevKOrCn4ifx7XWtcaLNnKIWw13lalx1L6l6pJlYa/HYv9F9YBBNhYPstxCrAmz9El8vU3ymNy/Dz23PlxwCxqwLVE7AbgehHV6+P2HgRyGG62eb1OaVfaTyLrOfev7sexypHw+bjmWO1w2+Y5fSG8gFsPUUCUdptYT7gQFWjGFJ21cFz3/+XsPfmaB4s1a0Rn4Um4LRlhbp7n9xw1cATgSA02Vg6RY2OER9VPCOwj08wqVcumZBo2twhCUQGzF/aZfH2rZYsj9jR2BHj44NEzWzSA+5NpYKjA/OBJBo285MVfituP3cY4TADejmdjV/rASkNCCkDLBSrYf2wiRrKIVxJaV6bo5lSzUBwl+Cjdh9xkG1ukwKVPwTiHmZuTADwSAozA4qRzGKmkygyH1qkKk1AzDVPhcRI6qRgnWE5EAYFNk0cIgqhiY6WtEaUjuHT06SJlrpmKalHBkSneiXpwGpFER9xJ83Qq32F4EN04uYHFmBf3nZDTRLfOLotSc2PJdt6tCtWLYgrUVICcCwKptJUpcvM9D/cwBxiSb566vKBdUxfgxywTfdgCTE+sCIjpZ1OanHfCMO8iEtFfZ4IIVYCj0il0Qs2DY0VH5KTSaFd5LC+IbNv0Qyn8d3zvOI6Y2EgthWHdrNBQlnwUReftjBWpijqmcEaMGAj2OKdHNTuNiRvMjmniGaoQwEZuymG+hJJGaqK2hhBQiR6ZUOXETZkfnpvG7Gd4IcnlNWlXA4lq3Dw04jQjCts5p1WYgVEcOAIjautgRK2WqWE/bIuKwJIN2nrW2JwLQMhg5mTuBT1ZBVjSG1U8AGPKhLAA4OlYzlHhwgvlFXvthm1pXk9MVfquLsSqYGOrrtuylgr9UwoogDZJkUZVEK2itDvMAbMgG5kO1whYtvn+ri9zivOFi7twc3MOAHjqMh8+26hxRHDRDSIuP3d5P1qmwUeyot88v4meB8MtQcq1rtCWiXr3HGecqPzs7uf3TzuHM58di4kp8Gsu4ohgqqlti5r0xFSpchHK7dnS+UFoyB7fMTVnXQvP/arsO8rA0nnRR5z7U0RNR9xKf1X08ExHX7D/a/YKxG7A84ALc8SNGycAgCOCRj5y/hQ++OEPYB4PKcGDHexgBzvYp5ld6QjrtDpG8BGKp7nQ9QzbobiNAi2tUwPLwr70Xmnlim5OJJU+TAXDdEzkZ4QQUNGT1UZIOi0SQ2qJMrTRkAykMAZog/Le0I9IlF1oCoRdoXLP7AWbZ4+JqaMpClBAl8hJQBxGKXQkZW0JDqjrGo4RjKQn9Dxh5L0nevi5zYm9QEtGWItjeJuL5F0gYKB3OGVaoaY3vTxeFvaJyx3JasO6qJEKK0DQCTuJwDimq+US8zpfqxRz4/oSidHR7SGnWS7GC9R8dicqe6qdWeHu4/zeW8gpwYttj1nGmt8PBHiRofAT7rEnecwDiWmTxZIAhpn3Nqhpz6KxztElkkVF2Y5awl9tMBBkE9mX1LZHWJFuxhP2e6c/x47gjJHprtENiFFel8dDVQGe/T8j7yOahIb9RJEw/k2IpSVBmALGYSiyMsK+oMcBEyOTCzquEzwSo/KB6SdvHCp62HuF5YDJC5sGPW1XQfEaFFM+ZkqwzDgsSAmlZqAnhLqtWlQESYxMMW/SJVaMJKR9A0kVUIOwpI7jiIGqzYHMCEulsFzkqEJUmC7nHUamDC0jz2AV5iRtG0KmOpfWFfHOfQwFIDQwSg5QMIysowBefISuJNvCtP1uA1EfOaFUz8peQ0uUj2Nq1u
8mhJb9gcfsD1QR05pSOYTnn6xWqI8kBc303UXAvUf52XzTw58BALh1NuJ/+43zfEecO1ZN8GwbaNsKWniBeYE++gJNL8rJsy/AlMRxM1WNFIXdg2nJpAvJ9nKVx6CeFVacv5Ggm6c2T+CCQAzDa762ugvXdN4/nryT73fTTwCZOH7/5vsAAHdu3cFmc1b6FZ+NHSKsgx3sYAc72JWwKx1hHekVtAUiayw7QsWxTcWrcvQS4HWhwhDGAGs1NNkPFIu5QRlMeyo5AMDkI1pCNxdUPk0W2M3sak/ifasSYYkuQYypEI8apYS8QWj50Lh9F71jbt/bCnEU3QsJ2VCaCactc/sIRYJjscjezVG3KpxtcWA3vw+YxMuXhuRocMSC/oL1KOsTdoy6RN9g2Va4cZwjnNPTnHtuVytEm2/g5llWxH3szlRAKiM9phASougksAFa6QDWa3HJOpk/uw3bsbmTisOb1KMfRdYgj/l9qyMcrbK3KmKN/WZXmmeTsJZUgGKEMIwztizOX0gxQ08YDZs52TZQaZUjc+yL+I07QcdGb8c546cRa8LsBcYNHYsy7djv6whGEcJMT1ZrWyRaOkYgSkese2ksJhS7MlAuP0OWAHDue3zwMjeJBtYjz7abQkh77TSPi21cUYi+w3pgmnclworshLamhuN1OTa9OqOKrE3wey5JGWthQ6jqrnTSVmQ0aG0Fzei3a9tSZhVJnN4PRahP6ipKxwKmkNaQFIFIdoeWgIy7qhqOjCS3NqwpqoAkSrec+8ZUpVF25JzonIMTcmxGFkgoBNeRrS6mrvdgm1mg7iMGipw+vs7Q9LN5i6MmR2J3UY5mtbxWmGGe2DwCAPjA+RNoqDgsvJxu5VBRQbhb5M/o6hU6NoRfXJDzMFj8mRdnQNRXvuzLAQB3+vfgiYt8LW/59fwZd+oJcxTYeNqTymphJOkhfJGVFcLpPVPOjuwoSZlCsj2QyzP4WBq9PTM72nbQHOs0MeLdagQ+ryUoOVMtoDlnWmY9ggOmMY/H5TaP5VMXd3Df808RhDvyWdiVPrDCJkHFBMV6YsNw1aItBwPrz5j7CC1pM1Lg1FZBkXmZewtmpwp7+iAHknZ7oAPTQCGGkjooOkVTKGzutdDPKBS5aPW0ArtsOrCqENhaonIaYzDzEBGWaqSIkeSsW/axzCGgIeGsUF1UY8AJD7GKk8srhYEM3lhLz1hC28hEYX/HxRpPhrxoJD3mLFDVpIzipuJ3MxTbOxqmSBtdQzMdZq1sjhqTEFuSRXt3OWDg94kqbNptcMxF62phXm/Q83C9w34tu3sSjgz6QnflnEVDlKCXhJFKOBGF3QjMrACfqXxvyc3wvIaJwI6oPLQRtgih7RkROQe8oAnjhCBgBaIhh92EIGrWBL9MSqGVFGqbF/Ki1gA3QDJHwSGVZ6wV+4CixkzqIO+I7oqxjAfKHIrQotZKVGxqKvTcjCdRup48ghKF2PxsWteg5hxbeo59ZWHpxMg8Pe8HTAQICIIwKY1ohDyW4IumxeldTE87l0EzACo6G607KohdxXRijB5B6M2FLLVqsOIhtySY6sg5bLf52Qnzw/XTEyii3SqynhhTY2ZjYGCdoLFV0fGSvr4UEgzTk7qQXkcSAqPQQCUVMCYhfM6ftw4ens9d1murl0UtfKcEeHAJvyYtFT20rjtCJJvFyemqPAct5+hl/ttn32Xx5V+cN/WufWceK/sh/JWHmWI0GfTzU79mcUb0X9VatEQ3luc+7lBJH5ew8USPQGRnL8gIoCg8eB6AU9zrU/U9D/flUSkFSG/jrA1WzQkA4JSpe3WpsWFaN3GtNMbAB5mXBCuZhBQCkjjlz8IOKcGDHexgBzvYlbArHWGN6xH9MMEG6WNhasu5fVjPn8EDmt5+WtKdcAaBDBce0qMT4IURg4qsWjvMkMI04cjBI9BjFJjxOM+YJJ/YChxdFxmNqqlKKkIrgcmnkoZRVoAfCpYeoHDPRQR46nipE8
okRAsrCqx8lH6YkQgvdS2jjMqUFKPwJCJ6bHmt20jgxDTjksrLloXbu5sjVASkrC8zIGK9G5DYG2UXeTxWdQNDiK1hxOkTsCbwQPrF+t0WW8KChYRzVbdITP9Y+lButtgyjbhj2uip8RxqoFQLIxprOlynJtSWBLuX2wuYiZ69alATPi8F/qiBIHpNhBnPUVHHYr8oXArQ9FaL9IQGtCgJi6THnAoQY5AMaEpQBJUsCJy5BlvaKARrkIyFZmpL+hqmOcATQNQyQnG1K71jSsib04SZnu4FSY0x2kLOO5N4dhii8HmgbkUWRBe4t/QOYs58ggCKJMt6jhjYR4ggjKwtNK/1kgV3RI0FI4kpzNhN5Gpk75PSurRtiC6V0TUappFq8su1zRGSyveSOPZzbRGoVbUkfP/o+j2lB8xzzrSmgWFqWSgdLQDMIl3DuQVb8rla1KPHqXDsib9vtAMIV5cEix9mnFG+JbBdZdtPRX16rfJaSuoCg7SNcO5v/BqGJYqGpLApDRiYRu4YYT/04ojP/9yc/h1C1lyb1zM+61qOYL7pL+VrPhsUfuqdBI0MEZZk0tJ6E6YeEMAJJ3XlVFEX1pxHSqUC85efVu3VlgeOVT9vkS6lb5HZqqrF9fYEAHBsqck2aQxBJJqY+l5ojJyFPdeK6gesb+8Qp0OEdbCDHexgB/s0sysdYTWuwThFRHqURezT2iL9kZgPV3MqDZJjENFGhUivZiJrd9ChjIrwCwa1r1dZ5u59mDHR85cOdKdq9BQJ7Fm4161FS8+naQwSm1hn1iP6YYKzUsPg9cUAEfwVOOo0jsW7kIKtsytERhpSNN14j7DJYIZLMlerxpW/OyOiiAkXY258vKSH56u6UJVbAWd0CxytGKmxXrWd1jjrM1xVuOSOzDXoWSJddumbCjULwTu61+e7Hq0SiQiyCJxcgxJhTKq4tr7HTrj/hOcseRiJplMe09PmFNePs9rvWmdP9lZfYdhlr61uj3FKGRCJsG2TsKZUy8VtspCMPXh7qPm6RaP3cHYv1xAKs8lpm0Eo5uQUM2uItymGudlcILGorfi7Si3QkmsuUdnXmojE+bhl3t+qWFodHIEnK9ugngWAQ8DJHArf24aR1jRN0ALzL/BslLBBak/R6sJX6Et5zkOJZA45AKvaY8cWgcI0XzdFWPCSIIgnxttYki1cISAmcm82BODMoVyEROA1GhyRf+5kmZ9h445xZ/ogAODJ6RwAsDM1qhtkRmfzujYN+jVrToy0VnWNBVsheoHGz0MhFBCGf6Q9q0Qo9cMtFGuwso/MEZiCCKhyEroJA7MGAu1XXpUWDW8JdDEjEqMMEZbdrHss+L1JU+05VAW08MK78+te/qUex0s+k93T5glrifdcz1HcX/3LCe/+0H0AgEce34AkK3Bkn1Fx2MsnawEX6dJyY1iXdVrDSAuPqJ4ni0qYdY65ZwyA5RxsWetsbF24V7c7EUA1ENYFL4K2YSgCqSP3Wmsjog8FcPNs7EofWOpY4ejGCmmUlAY35WRLSgXsO7K6LT0cptp3ggtErwxsHFjqBTZcyePgserIZGB5WCBgnYQIkj1VKSB1gpDLnzFcDlgwPXV8fARNafRb7Fr38w6hpGbIiIBUNiLZzFKY92hCFkZ1NIjcKEduYioFjGQk2HECnUzNHklFuqZoG0ySciOhaEhbKC5+Q+DJRb/DfJxfd3TCdEyosSGpJjN5OJ/PCwihaAyZZVafBYpkRN0s0BGxsaDGVEi7QkY7MEWXmhotC8aDyvc7Rw/FwWqZr7hROdygw3C0zN+7TA63LvKirl1X5AyETmbGFpcEBfRUq/XwZdFqbsZQCySCQYTVIsRUerwkZ/KC+z8D12/kQvjjt58EAPzvH/odJCrIFkarypb0cUPy3qgSRjoRFWmCZh8KeYqoJN/enBW1YAG62FTh6JR6aEwnD9vL4gxVdITc0hSiVtFhc6nCcEG00pKsIRboRVuOKthDGIuUhC
VDgtEazE7jbgJyzi4vsePhpTqgqkSTK79wGwMSD6pGpOhTg4q/q8hWUqUJlhRQQow7YES0soays7O9s4MnkKiqJQUGbAjQ2cXstIWY0ILPk4CNo7rFkkg/cYDWk0fk5h64uY/TADKxlbmTXEI/539f+n3KT/E7klyLV0ii0E2hqmEEROgqNkzbzyPMNo/vF35xfq4v/bxp72jzmuplgqfCMtvs8Fn3t/i//MV8v//PN1pcNHSMFtzftgYT9zrweQzJAAQV1QQZKVMXQmU1Mc1qzrBj+thzztY1sDziXKBTOswT1qSe8uxBi0ohNaSeoiM6rj2iEkJtOvMqwmiFpA8pwYMd7GAHO9inmV3pCKu5YWENYCM9Py9SoMC024MtAMCoGkt6/oYke7MesOZpH1ncNtbBkj1gZqpsu/OFFHJimN/VDRQ9XokKRj+hI/dbYJjbr7c4u5mjqdP766JiTNFdpCntyXMZ8UwxFiiuIwDEKRRXfctC9uADEiMX6eFyVsMSQn7E6OauuisAEekt28UITZYMbeiBGl8iToH2PjVd4Pcucq/VcU3evWlCZKqvpoidm2ckvqcywu2oCjeZhJGmajD2IleQ7/f29jZ2TBMM9KqPlielyK9mkZ7QBTq/DZe8Po2JAnlO+kBOTqHNCQDg4ryHYhTaHq/4TAY8djPDhrfsvTIN0IjqKhkgqqrFyMh7oLesnMVMoMvFnJ+rO7eYGJmMvObGteiZdtTkMGyqFS777I0+eisX7i/VhMBI05JRouoTlETbBIfsxkGoLbFhrmB53KFhdLlkPjk9je0hcM5GnaCZ6pO+QyhbhPdqAQJVVnhYse7z+A7zWFoNZuZFp3EHkqNgRWLZpqoxE9IddYQmGWwjyqJqBRJ6oCYLQuXaouh7tsvMJdZEjAQ/iHDkxfoCRpg16gw8MJVDx7xZ5Lr5wFOPYOuldYFEtt0RLCPYnswZo4qwwulIWPg0jwWA4SWVgQgrEY4TaZIEzfGtmQZseg/FdLQk7k3S2DmS8n4w3++9y+vYqSwNctGS4zTU+ILT/B0Pfzmfa4eSJRGgSG0sNMFPE/v2xjThFV+R19pHnoz40Z96FQDgmx9+AQBgvfkIfuM//O8AgN88+y0AgD6O0NQi0sw8BJsQSFa8oGDoE2cGmmK1dyOnvi82j2G7YA/iseyHvoB8FKO4oCN2IjI7CDG4QkeF45bz3Pcj/OQhdK7Pxg4R1sEOdrCDHexK2JWOsLpjh2m3RSVcZ/R4tr6Hrwnj7qRo6dExBy/AicvLGfM5IwrCNFVj4ekeVKxVmJ3Cmszbj4YcbZysutKUFyFNdrpEPMJVBxcL28Odj9zC9XtOAAALRj9TsyhNiWGUJltfIrRCHO9MgcSPjEx8mAreVnL7rTMYCe/2bOCddYuZ3uOOEOXJWhjWBVaLfE3r4QKBkY4AodeIeKTP3m9NFoE4AzWbNpdJmnY12ka4x7LnZgMwsEYkDO4r54pEywVz3xsdMLAJ07Ew7hYdAmtrcSfN2gp6mV+3Nflze78pEOYjJdLrNWqyk+sYoelhR0htYsTE6PmE8O1UpVJ0F345j3OoWqJyea5difiWjLq2Fzdx+9YTAADV5utXzgKsF06Us9nOCWvWpNas3ZwHjbQhpyOjEgsD17H4Tab3k+UCWz67gfexrTwC6z0nZGZvj65hQ4j1GEQiZkbge6JEU1bh2l0n+d4GaXWIsGQF8Unw+Xuh0oEMGn7WsGyZgDAtdApWGBbSGgPryiJOunBLBEaBjq9LzmBHdo8dnweCh1B994oApnmGIYjicY7zcXeKpXDwsXZ6Mewg6BgBy7RaQwtfKFEJQ+VxPuYIe2TNqwoBLd8jXrxFgmFkFa1kCkwBdtgNiQUGjcWKNZ0NW2FShVu/nZ//933XnwEAvKD7s/ifv/ef5vu9L9c6rx9NeOmX5Nf9mT8t87cvvJwsH0LPASkw8md9K01bGORx+av/Y4Ux5HX6nv9yPwDgvrtfiM
95YR6jm4wuH9+9H+aU87LlPfUbxG3+vruaDH754tMHsdnlsf5f/nqO2H7zg/fgZ3898wB+ZPNYvj6l0LK+HwnZv325xeVG2EzyVzi3n8vCrVrZDnrpEKaIx3CGZ2NX+sDSmHNKkP8vHd677SVmbj7Gia5QgmdILR3+o52hBflEwli126N6TillcLJsMDM1t13nSXq+3qI6JlLmtOMFGZG2KjIi1+86wcQJhl0Qwgc0lA1YWA/vc5gt+csOBjU3u4aSDDFGGKHNYR/Q2Thg4gK9NGS/qDNrAwBsU/7cjTVIQahXSJZZreAKwIIF5R5FRqXi97vKFeqpkroMqTAPCElvaxroozwOl0ToxXHEhjpXotdVmQYdD9cU8/c+cLzCbR5eQjaavIfmYeHCiuMXEUkTo1xODV3sNtgxj6VIkBrOe1iTv9e3BnPMi+EW+79unT8Gt+ABxOK7d8COiFGenYhxhhcQCKWiVbRYkW3jiIeEbWtsiIa7EIXrKpV72fBzQwiYmG5u7sq6VPdojZm9TLtNfl49dggEuCxJs3WjO8LFeQYSnHHzXuuI0QhJch7f9qiBIfDHMB1jIzD3RBae5bGYoHF0LV+DgLTOz9do+Awrku+qMWCzveR4EGVZLTCzX2+7zddUtw0UHbMw7OBJC2U9n6fxhZFi4sE1BI8xyVyVMZrLPBdyEW1s2eRENXxKc0HzbQguqhtbDk2jhXzalM9papHTqBDZJ7YjsCqGiGj2PY9A1nVSgn4hg8YYFSJTai33ljM/4lxSX3z+T/5uxHf/lbzG/69fS+mf5u1o78+p4P/P//sk38edFb7tNY/y70Q2JwPNA1JkUjwALTBAljaGXQPFudC5Ad/yP74DAPCnXpk/74VfpHD/Z+bxuPv5+XNufaDFiqjEif1hF09E3K/yQXX9+p8DAFTd8/B/+6u/BgB46PN/Od+7fhA//kv5EkYCT1aLDi1pxi5kT9gkULYObMNCbQFHUFMlP6sOKlTwJgLP8sA6pAQPdrCDHexgV8KudIQ17jZwymKm0uy4EeLACNEAiOQU08bCJyHtIkNFBVTk3Uv0fIOKUAQwLI6yF+G0LSwa4y57oJt+h54w2fFcxB8VHCHClcg+VBrxOr3D0wozC8kjC9S7cYBixFdL71iIWLKX4miRowsFg2kQVgamTEIsYXZPkIk7UaUPYzMxytieYSKPmqjg9pMv0YWGEN0u0BG4UJGXzKWIyO/VBRwCbBh5TPRA7z09Qsto9cOPZALQs+05lsdMDy6zd3s2XOABfvYpASqrk2PoMyrxUhQvaGASrtpWZCRsxkwDOO4yOWgdB+w22bOLjBgHNeO2zZ5/7ydgyOCIVghA7Q4NU24j58RRdx3HVVYS7hpJgW0wECK+23HsZ0DTo4wiJlgZNEyRTkzRrfsNKuGGZPps3o24uc1pGxE7XCyWWDB6iETi9PMWltFMx6ikaxsck3VheuwjAIBtmDHwmZxzTm8udqVFI5FOo7UdTpgOF54+7VWBe9/qz3iPOwzk06s8U9Z+xIb3LgKXq+UKhoKcA7MDYegLr+G4W6MS0US6xLOzmAiImBlNJ9gi4TPMkm6cSg+lZbr7ZHmCthWuyXxd4zjveTYZBA3JYyIRtrQPHC8Uarr5HX+euCXuO8tjfkbmlNtxwMg03ERQyzj1wtkMxTUX4TAyQumZTqlahcisjCMg4sF7K3zjN3Cer97KsQS+5svzPb3q/v85/+7oH0M3OdL97f90zsGKuC9n9bBaUQYnJKDK12A6lhCsxcAxvzsA19j+8ff/Zk53/t9/ZAMt3JyLPLeRtvjIY3w4F7kV43/67z4DX/uq/Lt77/09AMDzT38VbZ2jQcfxuHN2gY88RuHJe/NHYKpRxzz3rwvBch1hpAeRW1/nHBzBLInMI1oDUz8jHORFDnawgx3sYJ9udqUjrCl4JGeQWFwWSO5ydQwvfGo8ze2gYZn/BhmTQwgIrPd48oKFViHR+76osydz2l
RFHr46Y2HWBxwxl7w4zfWUoU7YIeeFO3rzSc0Ix/nCbl9c7gXhmGNvmiUuL3M0MBJk0NgajcieMMe+PduU3HVFb/6uRuN0lT2nu46yy3N64248ciszBfzeo+/N1+Jn1ItcN7pkoXjqB3i2SHcrijXWBpWiECDrVnHsYeiVt4S/K5Vg+N6BY7qr9xDsRBb4atbwbCZtKGP+Bfc/gEsRueNz25ydQzNaNWzGPL5+jDXrYxs2+Rqb4ES8jkCX487h9CiPv27z6y93a/TbcwDAerfBohGOw/ze+0+PUc/5nqZ19u6uuQ733sh5/EhQy1gtcXtHYEi1l/RI/LybOkcm68sNThIlWPg87FTBEJxxQhHI0dVYLwh/JrfjrfkMT/k9swaQoylpDfAk/PvIracKy4BEJcpHTAQojKwpLbXFiiCkps7P3Ji6NKLahjVAaGz5vcLCfnx0hAsCa27fyXPStRZtSwALo5ZhtyvrSrgRK1NBSY0rajTMGrRcN7N1uHOen8nE3oq66WCE5I4N8iqGEiEq1rqaagHHaxSW8N1uh4XO1wo+Q3szYFnnrMHyen4ex+01BH623ub3LkKCeZLRJ+tRq6MK7ohRN+eW0hGJYCth5+hMhR3DRuvyszbNDNewReP9+edf/YYHcff9eX5sWd/yW+D4lJmQB34o37YP0FSMeNFL2WwdG6SeLRhjvkePiM5l8IarGN6Y/4zpMs+J9/7WGR58IH/3X2LN7Hdu/U186E6O6D/wO/8xX8t/egr/52/I1/03XpfH9MHnfwBuyM+7CGCqFlNPppz8bXj00cexvpPnwvMfZOuK36ES+STWwmN3HS1Bb8IhOocR8yjCksw+dQm2tlBCavks7EofWGdhQJwvC3miZc9CqxZQ1M8JLDabKcIwbFdRFDsNluy5sUyj+TpibNgDUzNd0J/jyZl9ItyE2q7CsejP8AxaugY1+40mFqPdwuBix657jNj1OaReEeE0+oTldcqdcPGMmx5PUamz7vMjOrIr3Hctp8EWTDE1RwkvuPZ8AMAXvehl+bquHeNf/f8ycWZNTqPrN1boSJszE3RhHGDZ15WY3umnATNJgEdJd3qPBdNJjkCMOmicxLygtLBHdAaRaRNFxGVnW9zLjaOlWu3m8TMMzBdtBhaWky1jKCAZGxQeWOWUxY7sIk+d3cLZOqcpTMv0b+1KMX1BYlQfAgzRWq3KirUA0HFzXaSEBd/jeYgpBQxMVUnqyMLBs6P/qSczqqtta2hScm25mWx2W9RMI9+9JHpy4TCQmeDWRS6CT9qgXuZxu/eY9DmXT+LR8zy31pQmWTqDawRbtAR23Dx7CrdvbfjZ7LnTBo6SOiORfqPx0HQEHPtsrDGl30kcpcViVQBHccopqe20w4av89Rcq61Dt8ifIyjV7cUOl6RF6o7YQ+ZaOGHiQItWxlAQhpseapSeQYJpdF0OJUeqH4M9NVpkr9RmNwDc9IVSKSaFpx7Pc2F4Im+K90/34jNPXwgAuNbmTXkaRlz6cwDAjj1w4507GLlnxFWeW9V1C1+Tpo3US8YoJIKe9vysBqfXCIhialhbh3nI43B+lr/jS150gUWXvzfO+W/dSShq0KbOqdTGGQxEKAR+XjARdZM/R4AdEUvAZoSyx4fy9+o1blxn6vbz72DHvcLwsPgr/8Ob8b+87r8HANzL3s3/1080+OIvzClDLT2q64RRiMHJwLL0AyIRkobo1bs6h3tddoIq7mMRHk9Nea86afP+tDo+heFBf/lEvo/zXY+a5ZcV536sAa0Mkn32B9YhJXiwgx3sYAe7EnalIyy7ajFMEaoSUUJCaEdAC0mq8Ig1BpHoc8ViuNUOx+y6X9KTGcMO5wJQoGyFvfcGBoIoLpji2lxuEei1GCG3DQaRkh3iM6zP18W79X6EpjcxCCFqt8DAMLyIphlV+n4co66Fa3Ef1V1X9Fpt9GgpEufI2OCNwmMEPVzQo5xXAfMme6MbpgS7eoFjeqvTSP
44m7AbzgEAdzb583oknHTZW62Y2lLK4jojgFOqrvp+izO5p4bktifXCpmqUDvOR6owYqwv8ve6pkJkL40U66cnz/DgZ+fo8T6CM9ZPzJjW+Z5OTpgiahXukCXB23wtKRmAaSzoiB3TvgKm+f3dHXSiNFyd8HmsEZ7IkfCC0PkXXLsfA9MiFYvp07CBovTL6WmOALvVEoqe8+Xt7L0u2hojI84nRgItkkEbmapib9Y9zRFwTL7AkVhgraGZ9r24ID9f28Gc5Gu4zXFbqFA4/U4FJLOq4TlXRyfkzBOc8BkybdvZHYw8G6aa4xRhmIpsmD3QMIhM4Unm5nh1hJbRW70gGa028JPIrgAzMxECgvA+YMH5Uy0kkkgYGa1quQYkjFKw5/bkoJD45bOX1CGwW5MomdyEn33v5+GEUjLnj+T5/sTuUehrJGxlqwC6CoHZAlCqJy0mDOw33JE8ehqGouIbmeLsY4Ah76VjGWHeBdx8f/64P/9ncrT6ws9+Eo5ZChGJ7acJQnt6xIxCmjUagm2iiCMqDS/3SRJh5S6RmIZILHdYEzGzdeWeuxqsn8iTYfX7+Z6+4OQ9+Lf/9j0c12xmrTH8DvsamaJz90YYAsMS24G8rgAh/85TGg99wQvx3/+pPL6/8fjv5Nc3RzBcS1FAQWHGzPUg/JiVr+C5zwnGQoWI5cIg4RBhHexgBzvYwT7N7EpHWHUysLZBRYBDK/Bg4wuEvV0RMn5UYSDEfdNTfiN5eHrvcyfszRpuICu5yZHF3c0JtixK3mQT6ONVxLrPXlwkaKI1rnjxkbUP567B0a1y/g7WQ44QiDaFn88xUfVPOLVUUMUjOj7K0co91TE07+k2eejWt85wzCgQDRmYb3SoKQW/YL3q5mPn8BDPid9rPBKZCxrhP5uMKFcUgbu6MlC8lzUjMRVndCJRwFrG5eYOtoXDLn+EmyscM1IQ0cGhHzGzXrGos8fdwaGjxLtnA2msPM7oQS9qylE0Na53JwCAamJtxE9FZNGz6L9atWjJCL+2ESPvZUeQgYJCT0YPQ6YDC43EaACMpm77bVG+Y8AJXTl0Nj8TEaCsTMKSUXbL9od+M+D2JUEljDZG28OpXIg3xAUv7DFWvOeWjc8KCZFsK4HXbrcJxzsKd/IhtssGdc3nzqb40+UKgbXaM9bkzqanyvUtJCL2O5xd5L9viQt3usGKrP5Sspm8L3XXlpFlvbB48N5cr7hGkMPjT97Go9Je0FRgoIY0CS7c4oQCnws+//OwxY7PbMusxrgeMAhQSpqAO4eOoquuJVhmM6Dm3z/j7iyxcbc1ePcHckTx7t/LHHr6KOAe5Dl4n8uceMera6WV42wgcGY4R2LG5DrbG+7cuYP5gtFMw/rnwoBBDy6435w9OuP5BAN9x7dkdoj77h0QWA/CIKwfRQ/yaQKes/RCw7JJ2ISw50Bl1KWhkBiNSqO3jwqKciBKTVh0OeJf8/lXdkQ1cP4QsBPmCEdwBkvXSAFIbIsxLWuEzYw5SaNy/nH39cfwOS/K4davfSBnCupFQEMFAIkKL8dbpeH6hJHsslrhgtkdO7BdwS3gLgG+7VnZlT6wIhKq2hWUii5KmrYQXrqiY5PKpilpuTQH7LZEGlFOobUtaobeDXfe+c4mN+AAqDip6l0CW0fQcENNYypMEpah8MLaQgmlpgjFiZOIGWhdQrOVJ0ZNGlUqvBBB4a5usGAvyq2ncpHz5vk5HrmZwQCPXOZ0RmoNzphaWi7zxnpjHnGH/TVUOsGdNOKMu/CpyavoRr1EZMpIZKCgUfSXRqbHzqeIyD4ykbqYAQQCKyQNW6FBzX6eJGSuwWMmQGBBZ8HVDS5ElZUkskYbgCq65xdCWmqxujsj+UbS7UxhAzhS9LDPYzenIhUyzqGk5hQ3uKaq8ucDmPgQhziXsdZMIe6GHRwRd3ffRb2mVVdos3ZzHnuogJFsCyJNoisHzc1V5BNsZYsKtCbgJWhf2BEC55gGIMLQmq
CQVbfEhnI2W7JIVMdLtKQeUxy3PqyhuA6E6kknC0MHo5devpBAQhIEalLVbolIyRHNNVA7W9IwKxIdnyyO4cCdl2hR17QwdGi8UXALPhNeSxoGXJJQdycHuIlF3TtKnsgnnEAONs6P5BCZXjuns6lCwv0n+QCK7H1799nv4kM+H0ApPy7YE4t+mZ/NYz6DkS7mXVH0vkxCJJ3guZGK1tdRdTc2nJfjhshBq9B01Ldi7+DdQ43/x9/I7/miL+RhMCVMBBWVA8kVdRGAqMMEFJJnYbVI0/75CxhJR11otYJ/2v7A5+BsB+/zHnByJCeRQ9gJ+EzS4gpSJdGicG73Fynoz+QBTV4otmvifb8/4jd+Mz+7R2/nPej+44gm5uc0kBJsjiMqUXRmqnqx7KDZt3hZ2ILOYZU59GEd7GAHO9jBPv3sSkdYrnFwtUIUtVUvXlCCEb4yRhYpFjFdJOm4jsDIVIRwrcUa0BQfm6jw6UOE1dL7kl2fZlJQSuCtZFjwHoG9QxWJUfU8lfeaPqBiiBMYaWkAlnIb5B2F74DRCIcYC7Z+wFNr9iXxd+2913H+ePbyn3gq/4wGOLmLkOiKrBCwRV6kIh9Zb31JvQRGVdvYF7JNEQTUCtBePH96oMYAIhtBAl2HGo3O3tQx00SnusWSaZMdU6HrfkQ/ZS9ZBN2c9tgy3bkl8ER7XSrFmsX8mCyOTik0x+vbxREnzUm+XxF8vLzArTXlSsIEx2imYrRdqT2RsABiYkhlvK5dzymhk7jEwOcp0WW7WsJJIVnSLDHCB0nh5Wt2rkJFyL+L0lMVEZge3ilyLBpXWBKEqy+kVL73lICXpq6Kt+3ZbzalHfxOendYQA8ovU0iKVIpBSUcfHyWkzVIjfAV5kjceV3m6umKABYVi/fdcIIGTDi7zNHKht816gTVcC2FqaSewf61cTtiYjbDMiNi2xaaPrPdSe9eg+u8nqUm96DyuGDqdmAUZwDMLv/ukn/bOgV9Vx6H+5Yneaywhas5f9gj6RuNmZHzJHNMVyj8tik/m7pN6LgOHUlAG6sRdvk7TqY8ft/2DUs8/Odzmr5m9sDAwTNMEiVxFTMgLA8sf6dVaSBNjDT0pEomJgj8fvAlTSvgF1tZ9GScUeluDFTrVidMqbYWA+VTwLVhgKIxJFkUZjPzeyUVmQAr85YLUW2v44sfzP8+IwPMza1Bs2DKPnEfnBIc99DIdTHNUwH0rJb5uW77Adq5kmF4NnaIsA52sIMd7GBXwj7hEdb3f//34wd+4Aee8bvP/dzPxe/+7u8CAIZhwHd913fhx3/8xzGOI17xilfgn/2zf4Z77rnn4/8yowC977YPlMy2RiNZqVHQo5w9VHEpWCvS+6J2oDc0+rEwTG/FNUoRDaulItCnnEEtjcOsiRkdoemxi0zHOAKOcPTGROhGpKXp2SeFBWU5BIixtQN2rBGJWOOt6QyajZSR6nltu0JLaO3unPcUJuwIMzVKivkOjScYhHIQKgQ0xJCbcv2AZuRkpNY2x1IVdSpf4KJucMQivuYU2mxH1I1EW/QUYyhNogtyCR7pJfqaMhmUy4hTxJKNzYaRz+VmB8+i0qrl30ZAjxKBkakgBHRknljRVXxqewc9AQ8T5sLakYS2W3kY4W/kTySLU5ujmesuF69XaHE255qI51glExHYXG0EYBMSOl5jS8l4BY0lOzPHmOfCejfCk+9t5LP0VQVlnnl9PkVMhaeOMve1KsX2ji0Yd+6s4XsW9IXJwjRITCUIY4StTHk2RXVgGLFlZHfCdoDGaHg2L6/4vJQzhU3DS3NyErY9YCIQYDcNmCRSnGesCeWfWZ9VxkCzjUILn+GoYQg4cmtGWMHAMcyX2o9qVVnjjhF901gEjm9aCAilguIgTRxz+B2skrGU7vQJPaOylvXbSlXQZLn3ZEC3sGgbMk4ESsScBVQqz7ev/+/yGP0Pr3oKtcsZBNVzTwgGluGKZHtUiKKcgjDJs3la1MDwPC
VVkBXcbnLEI4T1DLH8kKCFXSY9AcPo2a/zPVlnoYLwcDLC0jOSAKuetvtHCd/4fXEGZmauAtuAnn+3x3e8JkdHL/pP+XU/+b+2RWx2qnMEexZT4Tgs0ZkHrl0jHyT30m7pMcUEPwYAT+LZ2J9ISvALvuAL8Au/8Av7L7H7r/nO7/xO/NzP/Rx+8id/EsfHx3jd616Hr//6r8ev/uqvftzfc7m+RLdwpaAvKBUfAUuS15kHSB88SoGYG2FEkl8V1gflEjypWc5J6RJCQkcQh1qxaA67Ty1y0lhrUEuRmQ8xphGe4A3jAMfwf5SZ2zRwBB8IWeow7ymKAlMSZ2FTtIWEtmnGgImyFvqEm0AI2DCtZlnkRJUQO6bm+n2R2RWwSn6Z1nHfQFbO6gDPxVMK5KZC4PXP3DS2ZsCOKYn1ZU6PXEs19HHuS6lY5I67DRx1vzSBLqddg/uuZYdl5/JY/v6tp7AhcmnWomTsimZVIljCRg8QSLDdkgZmO8KJvpmyZczBXq+ACap+JiindQtcJyr0GlOMek4wpEsaePjPeoQWCByf1xQmGI5lkM1gDDgmK8DqKG9s5+YOLog2HQZhFNmiIsuKAEFgLCo2WE0hb4SNXpUU9ZY/3WShxvyeG3W+5rtP7sXIizgnPVVMM5IckCLj4QMSD7T1Lr/OHK0QREZDCxDIYENGj92lqOpG1GQrMLXM9xmWu54zDSIPN/neduFwbZWRZTNZSLa3eiT+uxu4riYFxTFHt5+fQgUmFGlWJazZ9weRQkHEKChTJWkxjYY9bzKnx3kuh3ot694PkJ28UEK1J4WqLOwEIZvwVX8mP89Xf112Zu45GaB4nz3nWFwnKPlaAfOkfBAAOXUP5INClbXG9RVDOajkHNFalUMn8I9j79HclVlyK6MA5DStFlnx9QaO7CJg+lIvIxJBIwICDGG/7IVQOAHwQt3FsZw3t6AXGWzxfPa03VAWT95m6vsGn42bEbm2dwSALRYOR2QVWXC+P/LY47h1505Jez4b+xM5sKy1uPfeez/q9xcXF/jn//yf41/9q3+Fv/AX/gIA4Ed+5EfweZ/3eXjHO96BL//yL/+TuJyDHexgBzvYp4H9iRxY73vf+3D//fejaRo89NBDeMMb3oAHHngA73rXuzDPMx5++OHy2he/+MV44IEH8Pa3v/2PPLDGccQ4juX/Ly8pdDf00DaVkF9EUmNMhfxWwAYpqeKlCXhAqQjLlJWo5TbKIjH90PdS3AbGKb9uuSKXVuWgGFnNXpp0FDp2/hu6VzoazBuSb8a9DEFIQloa0bMvRWJ+gw4n/LcUy/txLsrEIBMDEEphN0kU0hhEppbSIu1fzvv09LgutxMs+41WOWCDaRS04GmjgC50Aav0jDgx7TCJmCTdw61JUPQHd1JoHYDks2e6JKBhihGO6ckqkXQ3OoyEABfWjxSxnijVwQKvrxc4ZrHcSnpvNghMO10ywhq3MxSfXdUmNIVFgVHrFGD47ARA0TQGoDDmJQUftdaYFoT0Q1LLF1jY7HX3jGRnTEWVV9QCTXQ4Yip40WX49WJxAkewwmMXWbF17C+hWayueW8VDByJhuGlPwkIhFbHSwJ/zKKkia+f5u+43l3HJSMP35HxQk8YCT7aEYzifUQjasrsgRuUhzcUeqQ0jYZBw/6pFddI3N2B70XSRfMeNfxcGoQQ6J13pHQ4XlS46zSnhC4kCrrwRZW5ptedRl3GUpS6fZiLgKOkpYfdFiOfseXrukZB89+zKHbDYkNyXGFT0VZDL/PnBc43YxJCkJaDk/x5y3ugGcl/Zhbdxcv/vMfXP5zTVw8+j+t+DYy8990d7jFaF8FTx6xApVVprZF0Z0x4WjajDN8fTHQgzAlDkL2M/VpKAfWfz38PCqp7a37PlNOxwe/fL3B1k57GK8FALD/yfb+XfAfVQsBEBxoNJEaf7/6tPH6//VTA1uesQfUUGX2MwcCSTHQC3AgYd3lOpV4IuC/h5+0zYP
ofyz7hoIuXv/zleNOb3oQ3v/nN+KEf+iF88IMfxFd8xVdgvV7j5s2bqKoKJycnz3jPPffcg5s3b/6Rn/mGN7wBx8fH5b8XvOAFn+jLPtjBDnawg32K2yc8wnrVq15V/v2Sl7wEL3/5y/Hggw/iJ37iJ9BSFv7jte/5nu/B61//+vL/l5eXeMELXoAQEoZhKIV6SxfFOQslUUaSPKxBxQJrw9pOgi9S2iuydzsFzGSAWB3nnPs6RJyz5hCYxz9aLrCgVIMEJfOU4FgQb8ni4NBgZKTQVRUii5CmylHiE/EMMyHwljBkN9fo2DBq6am4bsZW06OTCDHEInLXEUIdqrkUxEXocVF1sORZvFfnPLILCTML7E2zx7UqesY1oxLUGjOjLYlyvZrRiwgmvTQVHIxIqrA+tBknBDZ6rug9tqtTdIRMS7v/OSbcJLt9T9j3Zhox8vo8I9Tx7BIj2fWPrDzLGkdsqD5a5YjmbDrDTCi8T1MRjJuScNSpAqeWSHIdZ+wMGR0oEqmSRWSNYxb593nCRP7GguGoFUY2oPas4+mhzpyGACZ6kMpqGBFkZKSenIZhjVMPAmtOWJDZemR0efvmLSjW4gwx4wEeHaO4RO972GzQs0G3Z90tLAwS57yd8701k0UtzPjCG2iAwIhCGnWP6hanhnNZ6rhpxBmfl7D/O9cVbjoDg1miLf69a4DI65IG6aZzqPnsas5LhRqBxJMX5IjcrNdwx2wlgNSrPJYnuT5qiEdfRAUn9VghCZgM7Mz1woWxrBYwZG/ZXlAVwShEyr5XJo/R85zFPdfyM/lzL8t7159+yQdw1yJHqfK8krFI55SGYTSCo4hxy5qURC/WlP1G9qUYUwHTSOST4tO4/wplSIIfnwn/rpyFCv8BALB74gF0lPWR/ctojcD0SBTuQhf3bDbSnB6VENsg8LqSSnvYuwAxmhV+6defBwD4/76LIJQbF7j9RJ7z28fy8zW1RneDmRAqC2zHLT78kXxdx4zsrdO4dtrCT2l/MR/D/sT7sE5OTvCiF70I73//+/EX/+JfxDRNOD8/f0aU9cQTT/yhNS+xuq5R1/VH/b6yBkYrjFRElaSh9nOhxFcsVLZNXcLdSdgokFBJ3MvFtvUJYDqmO86LqNl5NJJm4WJK1V5aQ5o51HaGFSQUK626rnBE+ZB7jq+h5qbz4Q9/OH/t4NB6svJKNm6cMEVZ/HliHNUGI3W8zihHkKyGpdZWS3YANxoExvASvicYNEQnLkm905wmTOz1aQj60M5BsS9Nc3GoORXy2CeYKj2bRnQLUgGR6qeZgBikKLxPB2yJWPQVWRBaVRR9Jb0Tgy6AiduimTTNqHhoSx/VOG1xe8j3PtP5OamPUTGVZiApLosjpiCr9qioI9+hjMdgFBa1UNbkMVfDAEtl40kQX9rgWiJwIuUxijHh9pw36+s8LLRSkl3BWhCB8Uk8Pmc5iHZHcmTdYEFWCdPx4HAO4OEwXOZ72+3WpadN0GQXt2cck1R4QWLfzbxGb9Z8Nvn7H98+hoGOSuCaUaEqdE6nlAqxi2P0RBhuPOeTSvBWHDw+y+ix+wOyMTBLOEHjsbeqR0BV5blgtMHEz9yeU9Ij1ZhI3TRzHUab4LnDO4KHjA6I2/w6bfJn1xczEtfnRCJZvbQ4Wub1WYtjcD7DEdWwItiqbRs0pLyqmGp01mErhNSknFGTwylT7v/TX8ob7+d/Xocbd+c5eKO9yevbFoXx5mkgCEVqrOZuHkh1wpLyPuNW+p4UpkGIhHlKmaeVMvIPIcHIvxNUstYFlSpq6kl5xC0pqObHoEbpGSPyWSsk7oPNEZ3SNBXwkUzaEFNJR8q1QO3Tb0sewhs7YjJkF0l5Hnd2hKKOmNsSWVot4NjjVw5etyhzeSupyFoDjSIN3CWejf2J92FtNht84AMfwH333YeXvexlcM7hLW95S/n7e9/7XjzyyCN46KGH/qQv5WAHO9jBDnaF7RMeYX33d383vuZrvg
YPPvggHnvsMXzf930fjDF49atfjePjY3zrt34rXv/61+PatWs4OjrC3/ybfxMPPfTQfxVCMEaSQqo/0FU9JRBxjpaQ+pTsPsVEj1EA6gDAYAPDPMOSM21kukDHCo7pOunbqWwC+DmRaRs1qtJjIuAKoyNaQsq30xaXJAOd6JFVc1e4tATsEdNUGDq2Y/bm0xywYwg2FEXWCp6P0NBjtGkPKomClpgSJqasAlMSugIaYWJgYTTFyPAc6M+yd5t8gD3J0UDHqKX3M7AWhgghVHRQzCuMTCtNs4emV+iZkxguPbaiKkyuw66qYfgxlaRyoi6d9pq5CX20RGKaS/qiztQWO0aDEpUMesBI3j1lFuhY2L9R5e+7PW/h2UsluOFK74v9Stw4rRAEqMHX+zTB04u/vaMIn61REzjTiaK0qzAxl7ILG47bBuPMdFLYM57UjAzRSMpNIwgamRdoTxuYVtR789/GMSEQWBEYqfh5hmZkXTNlbWIEWOiW+1bNEnMt/q9A3j12vD7HMVcpIQYhZWUa9siiZbQ1kJAwwBfQkNYGx4Qwn5D8uNopBGYiNFPg0SgELtTYCeXEjEgwSy3ErrrFlmtsYvqvaRtUTONbDsigFbpWImHeJ1Lp4VJmn46Tns0gDLAD8Jf+T3kS/vm/mH9ePH6B6xTkrKtM+tqfhZIiS0/jQaxqSd3nnyHNBTRijKQGQwEmFbx62P9bnobGfg5qtX+9vFc97Q0iLxP1APS6vB/IkZOw2MheBRVhJFsr9xFymjy/TvYOIJDdQ7O/sbMbHJvfBAAsbZb+2WiL5piyMRyDe47vQWTa99ZZXiPjLkFTOFfubXt5CbVMzy2s/SMf+Qhe/epX4/bt27jrrrvw5/7cn8M73vEO3HVXDiX/0T/6R9Ba4xu+4Rue0Th8sIMd7GAHO9gfZ5/wA+vHf/zH/9i/N02DN77xjXjjG9/43/xdPuWmXsFfKml6i3PJt6sg/GdKCNdLkRMqlqggsmY0+ISabBCGL+yMRs3mRKmBRjVgJORYhBybZoFVnb34yIjHjBGKtYInt3dwZ8weR0WmiLq7hi3lSTyZoWECRpARgfD3PsyIwx6WC+Qa3MD3CDT0qFlCSTMkPahKaQRpoBaIePIZFgsAvURTgCfjwPZ8L7txbXUCADhlXnrRLBDpiQeOwXaeCszfs+k0mlSYytmEj3E94g6veWakeP/qFEtGA63KzbthAPoNWRm2AqBIsGwmjYxQ5zQURgdhmk9VwKVAZyePSmplbHI+6qoCcFEEZMADVvOZSPQwzRgY+qWGvHDWlXqhNGhXbg9XbjgXnWuLrEnvpQF6godcF7nxdh41r68yAnRRCGxTUIyClm0Hw+brncipaA1U0ujJWoc1pVFfsgEuTmU+9WR6H5PHINyA272AaE8JiMhIsW5c0aEQRpHlQiMII/9S5HRsEVz02wkNFSNvXMvrQV2EwoTRsyZaVQGJ8PLE59lPmxK1N2SmaOoWw/aZ7Ruq0phZE5m20mKxZ7iRrc2HobDYDE+DT3fMVniynX/m8+/GS16cG94vL3LEcHztKbSrHFltWLfSfg+Oghd+0VDWkiIvqJ+AodSr8suHMUIuTyIalfb1I12g5QZKMjQSVelYgDUSQhkLKAmcZoUotXQBntQoMH/Zj8Ztwiz91owGq8ruo0XOo4SASMj/ZmSd3Oy5SINh/XMc0XKunJ5kkNqqqzCfM+tBkM+wHqEJ2qlX+Xt3HoizQfw4YO1XmvwWJtOHlActkiLWwhL0kPgQg50hsXBBiCklcw4C7hkmD8+FcDfpghbNUuqcGFjAH8ZdORHkwaquhTvKi1r1eVPpzy7QX5wDANbTGjMX/77QeoFxe8lrJc2OVZl1AoARqdikCnFmIyG6brELcmBJ6G/2JzIRcK1rIWQ6OyL9+jCXi0hMlWKK0IHonpb9MVUDkCkgMdV4tDwqshcX29yzdCdsyqJtCcQwxhRmjQ3lNyYVcRc3iSNesvMBRj
ISHNPJR/TcSMdzrsroi9SJp0prqgIcQQiyQceUJOGH3cUAzxSloJ5OrzU4OsqpI2UktRnRUjcpMl131m9woTNQo+ehWJsWbcz3t1jyPoNBPwglV77P5CMCr9Vy3Grb7FVX6T3N44SJ0iBtJXQ9umxix+TrapYdEpGbhVTV1FBCgUPF6SkB3nBs6jwKjQsImpOnwOh6hCh9f6QGgwOssG6IVA/g+J6ywUUNRB7umgwKtoYiu8x6uo3InpuSqj6tCw1IIKog2ljSm57jtttcFhLaRCdG17owxyjpbbIKQQirZ2Gt6Io6smzAVdvCUK9JVHDHaUCg/Hjd5fWz60dsb2aE85d8Rs/v/1GAMjUyt/7/7P15zHXbXR4IPmutPZ7hnb7xfnf2vdf2tTEeMB7CjG2MIRWmpOICKmlIgZQuUFBaoiutRJHotKJKIoWQUiXdKbUydKgeKgoKSppAQxIIgwMmBjz7zsM3vsN5z7DHNfQf6/mt816bItdSuqs+cpZkv989wz57r7323r/f83t+zxOUTT1Uco7gNEYyXrulBMBbuSMR2th0FwRn5drz23uB2MJ4XID/kgpNSA8qdeGBNZKrYIaQ+syEvZpNFEoGORLQeAsEIc+w/6/tbVIEEr+uotpCizXn3lbAv38h/sbP/2bsI5xfLvDQJLI1WyqUnB8aHO0/BQB4UHoR6xP07tW4/2RZzocpzpoAnzxX/sNjJ367G7uxG7uxG/fFuK8zLB9GeO+T9pciHKMLlRxFEYRosdUN1EqEYh2c6AoyOhxGi25DynbFtLzIUqjT0cX1pF3CaqGrMjIqVijpKFsxomnWZ2iOJQzSmBEiCXEzaMM5KmaIdUnlB5VhCEJgEPp1D8c+oj0qLewXc2xs3NCa9FzjM/Rix8H00U00vBF1D0bsRYZCMhJCeN3QQmIYz14qVBprCraO1AWrshwF6c8tMzelC2QkPUzKCOtpAK2nqCV7yMop4BlNn2xi9rJ0LTzP4dAzU1lbFDZG7zWzvSLfh5/G767Avq3BITAkD0kEeYQljXvT+kTf3Sesm3clBmaVUtSu8hJTRuKguOmot/Pa8Lz3tkdHiOQSM61112JFmA5SfA8AmGVXzACyvIZm/4+I1dYwiQwg9h3BWKxI39fs5wN0Or6J9FSFAgNVPhoKto52TDCyZWTfzAIC2wukjaMGoLm2FLOIqS5xZY+2Irx+unFAxd63jPvnR52cbt0g6iEBNT83qQ1avt64RZw3b9K67BNEqmEoNC0w2xQFskKcl+N+9eMAx+xMoMZZnsOP4o5LyErnyNhbFphl6kKhoJFmRxLB4AZsmKntUYnj/PYav/bbzwAAPvxNkWxlpluR4kDURQOJHZFKDP3WHFbWmgZQk0Qj5JCRPXrAlvCAsFW9EG1H76MQNbAtQah8m20JP8JlolIB2NYlyDATg0YT4LlmpIdrOgfMIYkyhF7dkMAnSPulh06ahSP7DrWa4JEbcS7f/lQ8R8+/WuGsIxlrP2arnekhXUgqLOJx1OfwtBjqubbLosTlqoDdZVi7sRu7sRu78Ydt3NcZVuYVoFRSGw+pS1vBifIyiyNe+a0FNeTzMZIAAC+4rgMyUVSnRUWfbXUMBype2N7D5cTJiRO7bMBpH4u0ZcfsJvRoiA/nzmA6ig0EI8tJwJSRaSE0Xq8wsDZRcJ/rUidl9IL1hWLUGEdmDcJqUAYl3wep38umQc56wKRgk7DRyBhZSyG7LuqkMN2y7tPoEZZWJ2t20J91DsaINQnrM3WF3JC2yki26zq4Xgr24PECK07nhtnBuBkgwvJieTKpcpT5a4vItjAAm0AtVQk27YA1lbRLFvOVDcn2ZDrNU21lQur5tJhiym77kRR160ecjzFCDHxtgwEg1b1gfcsolZRSwBrK5nyNNWnDZR4/l2c5ChajDesqyiko0vxLsV2p5ymKzkjxbts1lmy4XTH71W5M9cwpM9lc1+ChISctfIoCpRhBMmLvXA/KRq
JntVYHmyxaPNddkU1whcoTgWvjxK8QWHhpROfRbwvoUg8O1qERs8EwYJzFuWxZ/1p1W9JAsvlRBsGK3l5clwf1EXi5YEVV+n4cIsMKQMk6WmYtViQLYYiTcOPyHPs0+LQsOq/6FbwRAgPnuTDIqZ/o2QTedw4b97txn5t4DRsL6X1BwZqdY7sEsFUz74aAkhY9YoAafIDhfUSUO4zaKsNcHJI4Sc0rqMQjS/WqEPQFA0efPi/tO30XUlZW05ImGJ/2Ub5rlJKpFKEZlLWGJtpiJfu1PjlY9DUzrbMC3/xEXLeXvjfWrf7m/+MUv/Z8rGc1J3QWWFS4e34HAEDxGcznOTqaeYoLQzEETCv3ZdWw7usH1ryaYQgjWmFNEeIoS70lNfBp5oOG5cVBNAsm08n9VPw8p3mOgylhGOl3adrUV1DxIXWpLjHyBikkiAwaitBFIzeDQgHsY7IOGGjHkNUsdGc6rc4N7VGCT9cnPFdhPZ+h0sJsoxxOcKh5E5aTPhhAc3vCFjvvu9Q7VLEXIvNAEPFQPqznxT4yYaKJAoA9x7klpCk2Lj3g+RuGF0c21UmEWHlhHTbIINI9cT+1ATx7h0piIbrTybNKi5Bx5tDlJJRQU6cuPRRvgJpXcjbk2LRiKxPXwayqcekKVShMjrETxowI3haYH8QbW8NC/J31XSzofSVSOk4pZKJ6whtqpTRq7kPPICE3BjM+RER6p1AeeS19LIRgvEtQVV4KfFom77GMN4saJTwf/oHwozYBhg+big+9HA6O512Lb1alkHHtjZQGst2WMRr49O8zl/oHa8K/xhQotfgnxf1buQ4LChgPXNvj0KRrqCaknnkkwkPvBvTC2BRIDTkqPsArRi/W+6S2ofxWXFp6wALZpHbjE3lKyDlZaREYYEjQWZQDXFhy29xXt95KixEeQ+YAgTJ5vNNywFe/jXDjnH2TJyOMrAWyKxHifgNIPnh5rZCJ+DGJD3YMST1n4GsmB0ZhBG55X2kINJibbVAaRoFFfWTAIBIi5HPyAPQeKATG5yRZvy1/eLEusQpOlEb43WH0EJqSkGa1BrJKWIRxp+8tLTT7qt71cNze939I46X/Ma6P22ei6FOjppLLE4+9BQDQ2AGfffWTcb+Exaw8qumM6ir38HrGDhLcjd3Yjd3Yjfti3NcZ1n65j14NcCQeOKpD5KqAkn4CsckYAdUlMBAAYOoMRoga4vBa5difxuhW9PzGdQPN6PKIne9ZVeNszS7uVgr4OkGQKb/PDXKanQWrMDLD8hSFVINJVigtI3anPFwmfFZGpbVGxqgbA6PgEFDlUsSP761tl7aXE7qqCuB8tQAAWBH2VQUKKwV4QmXZFIOVAisjewd4wnoVBMbKUNByIpPqqu6RiRMvySGlGRIt2A9b6HXOzPSAlNf9ao6RoefJEPfzLKzQExIUJY6q8CnzzIVurLMkapsx/josZ7hy5SoAYAwe6xUj8YwqI2bEGbXueiUGlBbnon7AqS+yHDkzJ9GANCEkWrlk2NkkQ8FMOGSy7sZkvinkHGQehpCnJjnAZxZ6FJUSUUbQuDKN5BwNEnEKDSs9YRIa+z5F5aKm4HRI2bSYbOrRoZIMlscRnEYjCisk1TS5SUoXU0K9WWFSv9Eo1jhIcnYYL8CdRtaqrhD42YbHFpRCoJu1Ijhn3YA1C/CtZI/1DIdF7OeZUiXDtyM8synR2zTGoJ4SpuO1dra6g+6EbSdBjCg9nKityPwWwBHn95UX4mvvvAG8/90xyrcNs67CoDvn/gtE57di1yaTzC5L6jKjCB1rn7KfIIldYZICjuCAWqmU/Qj8V2Q6QdqjtAK4bUaU2vWCiia0iCIjWS09cYT1nBg7pssQo/WJWCFkNWeR1EwkI+4GYHEvHujVaMCNzrQIzOj3iUY8dKjwtW94AADw5Fs+BAB4z3veCoVHAACzydMAgH/yC/8Cn/xMzLD8JN6vTTXBVJfJoPX1jF2GtRu7sRu7sRv3xb
ivM6xsZbH/wGUYKg+criOFerbOkz1Gx3CoXw3IJcOa8vsmg3SY5lL3KRTOaS0+4fRkzuKQ6uDXSVow3mBNde1FkAZdHVvcgdREe7g/g+F2zpcNBhbnLS0sDqoJPJt5RSdv7fut5h8jpGHsoJREq6TMtxbsn0XLBmKf+2S4KMXZAI8NZclHRqqFnuEK6ecHtHYI2RQN978bha4MHGQHcdqo9O592GLxJAI420P122ZYIKp2T/J4nGuaDzrncHAUt3OV2oSzrMCKtZ2F1G6GEorVdynSW+vhB1KeeZCZCihqakOyDjIvDzAvo/r/gBFOLeIx1TGyO+82OPGieM92BuNQazH147E5gyAadsJlzguM/LeQUcbcA7m0CDBTz/NU6PZGiDMT1JN9nof42rpZo+SaEA3AZt1gvYjz8eBBrKdO8j2QXYw11VJ84QFSzSW7zDqVbDQCz02ZT6GltkLS0HozpMhfvCL76YjTimrzbGIfg0pr0PO7m6HHvIprRlH9ZHPeISdVf1ZNMZPGXV5/x9igl0Zq0tsHOyatxsZLn8etZElxcP0g7r8pkXEfujKu37NgUdCqxXNibi3PMIgyDBulOxXg2SBbMNOaqQF32fZQknzzzscNHiIN3ZwzMy5tag0wrCUaAJmQWtacU+vivQRImZNSCp5pjdSArPXSy5/o7yFDugtLhmWBREYxVdyX+WUdWzSA1LLhlUkNxk6PGHiNi5xlkSMVww2p/6PzsKLqzmtNqwDkbDZnjV5VKmlbbqjeUYYSoY/X7DGNYF9eNvi6r30SAPD+978JAHB0dIRmcQ0A8Duf/gIA4PMv/BrClPYyvDcufcDLq8/8L6sl+P/PsVotMLm0h4JuwHUfL+6sNyh5M/SE2YIb4Ccs8E5JHsg6WL6v2bHvQo6OqgUTEx9S+/UUZR7fb7iCB2cxagp78oaTIzKyAMAmHxq79bTpLDxT/IGLrlHbrvaBF6+FgyHEVBQieNliJFSRE5pz/YgTwpJrCqzWV/agePMUlYnKhWQXIp5FjRnQi3Aq7QFa2+HYULqHzMiyzHF1KtBdnI9V02FBaNFK8bg3sLygiiDSQCXEFkmq/qHcqjx0UnA3Cis+KBu6Gg/BIxgREmURfKMASvSIMG6daxSE7TLefJRxUE08jkxHF2kAyLgmjB+TZ5Tv2RsTfJLLyji/Q28x8ka/5o2wcCMmvEmLrYzzFstFnA/Fh0R1kKOUGwL/zosJCt6VBH7sNhusB+4DlTPK/UOAUNlyyu3qCo5rr3SRoVWaEpo3oqWLN/xl3yR5Iun/qVQGQ1h3cUxZrGUL8n4wER+rMKKlpcea/WQ2M4nY0dv4ngoqMogABGJ0Zaa2sKgasKJj8YbrxAWNTlRK2J9mjYfQQzMGAd24wfkQmXs5yRlj3iPMRXaN66jvsU+R6jUX2fFGYRBbIbJJTRiRCzzPNea6HO15PP9v24vw4wfePSSrGVkSCoDwGESyyGmd7ppCbnM29m8CW6Hm0SmURVKwja/1XsQ+0gM4bva1QNfoAkxqzuI1OiDBocqKz1oOdRp/tzzMoKT/ig8ABw9N/phLSiEqMXKTXFO7hUvLKXtCpwbljA/IU/bk9Tle6aJb/Ceei/v8yp1P4Tu/PcJ/l6/Ez6+PV7j9fHxQ/fN/+2sAgN+5+dvISWZBYjmPGNQETgUAW/blHzR2kOBu7MZu7MZu3Bfjvs6wNr7HcrWGXbMwPsZworYVplnMCmqxaZht0JfstGb2EEIGy8hZ4AqjhqTZTxQAVZnBMjI+EQgDPbpSaPKk3BqVek0acSg+P0UlNghFBmpGYrmMKg+d2wBSCBdNsVxhwmg1oxUDbIPaM2Jmc0gOg056xmQbnUvFb8N4xDUNPLNGoWyvtMerjGq6IkbDbpphs2ZEzIK9gcZA3HF0QjzwKEVIjZDD4KK2HQAcsqC9N5tgQ528hpnRiB63GXU3jCJnWYXBvpZ6rFyAEYo150KPCqPAukr6iXwyzZyoeN
z5aHHuYlHYOZcgyimdladljQ3VMRYuRtpKK+yRBCICpstmTBm4I5lCFRpTIYNQ6LgZevTs4tdanHHL5Fyci21JH7BaMoNpRUFjSASXoo4w28NXHsJl6q290EWjz5Vbo9TUtuT8FlWJkQtK0z5GKYMsYyZGYofyHp6Zn2bmkcOlnjvJIkfjYYgnGhJdfPAYmKlL4840n6KS3/DSQqGTXt2mb3DWLeLxcf5mukZGgpDQzNfDCMNsvKYepFEemrqIgVmjbUc4tq4EpiiZMxDbzPZezMiyQkPtxyy555pVvUPBcy2w3ThozMaYnX3gHXG7Tx2uwG6M1HsVxhxK1GzEKVippKkp5AcEj4FZnrRdGKXQNGIiyTaEysDxOhS2jFYmkS4C/5qgEzFlWFG1pg3IeQ4z7kt33mAkvHdkyqRT6HnsPgRoKwQXUeUJiVIvzBlTaYSeBIyOyM/goHjdidSoNyV+/fPxoH/+Y3HdfeN7r+KRR2LG37H39OTOCW6+Gvuwbp08F3+3HjGhisq9ZTyvufdwumIGv8uwdmM3dmM3duMP0bivM6why7CyFiVrSYpRn7cZeupg1UK7rg16dryLjpdHCc9IpmVEqPSIahajlpFEhsXaoicuvBHouQRygsp7jMKGRiWbj0Bu9Pl5jyVl1g9nM9RsMrZN/G479ilzEXJBVRcoLHXlmLGNfRuN+AAUlH+eTCocJO8CNhgGJ24QqW506gc0/K7nb3itATYGVntUZi8L9KwfLJoY5bYe6EhMWIrF9dkawyBFWio+5wE1a0D54UH8W1cwHdsKGMENg8Wa0eXIzKJTOhFJXGpDCOl485xK6bnBwELw6ERrsU+EArGD8doi+EX8bjFJpnR9L5RuhcAao1iPqFJjwuxIkSgQkMGKMgSj1yrPoJg1GGoO6lFjwnC7qGgdrmeoTcyIpDH79N4pTs/ifiXPxrpCRYV/nYmteEA95fYaaklmFiHnOWGEPYQimXRaZrzOhaQ5KUsjwGPkWnX78XqoD4vUALuiRqR3HSpueyqK6p3HuGKRXstc5bCSlQlBRZlU5xmC39p8JPNCC0WkAYz6/ajhe2m05vnIChQsjEoTfoEMA+uYIzMAlxcYiSpk1GecVRYtL3dBOIzzSdFhYGP25sziXVfjufmGd8cMu/Y+Nc1S1ASh3yqyd1Rn6dtt83KRS5M1IIem51KkCnA3Re1GWh22ChaizuO9T/eMlPnAI4xSHOYtOlgIryOXL2cG6prQ7gNGabcQBR6t4BtpChckxmNDo8epZH5ZSAQnIYP4ISQCyYyFvJfujPin//rn48dMNHB873vfhMPLcdJfeSbO+c2XFPos1gYvz/l3PMSSLStCZFqOA6q2/0+HdKGnE6hqkorpcuBWBUAgEsoZde0GDZlolpCfyrd9CuCid04ntk7fS39Ui+Wa7Dmexb3LE8z34sqeCaeiaWHddnECgMEWJvS9hiKra1pG+GflNrDi+Cp/PWB4c5BelB4GZysW4tsIJ9bZHgbedBqutE3fYiId+NLrM1rMuPBLURsoc8x4g65585kEh5w3QGGa6dEnCCLwih6b5NcLz4t2hEqstHtUyVj3Paz4L0nRGsCUIrTy4PVhC1mk4nbwKbBQJLLo3CLjA1LLnHYjnAgNJxFijzk1ayYTDZC9JhfjOGwhHBCmQgY0vJAq9tdU8wp1L3RIwkpFASd9fJSHGtcd3CDqGEKgcai4kBRv/sWQ45D+QAXnSpksqVqsSRR5rt+g4AN8LxAqQ8CoyQRlP5YqKigqUmSQ8+DQ9PFzM5KG8iJDS8KB9PXkRZ56pXxitgIlvSkK6ZuzOZQWaw2qwvQW/Si+Wgx68gqZPKSyaivaKioZqsc4CmmDa3s6TQ9cT9pp13ncondXxut0MttDwWtN1Fv8AHj2Au5f4Tryx1iRzDKQwec2NjHk5JwXfY73vTlu+4F99pGNgOeasXLsakztlK1AZtYhK16rFOGzAHKyYNhn2VqfpLR8vx
W1zb74jmtDkmZKohc+wKegQ/oX49wAQENYsT8tMLtKiPTIIHC/mpOecwWo5Ios2zYQkaolrVCqXKPgmheZNo+txUxxEK+5Fz+5xqUqPoD+xHcfAADe9cYS7Wlc0yfH8Tws+hq/98qLAIBuw3PZZ1isYnBQkbS0zGoYtyEtV3bwDx47SHA3dmM3dmM37otxX2dY3gPeO2SMaiZXaQ8xBIxrFrdXEeLa2BX8RArn/L7yKEiFrgjlDEOPkTR1x8jM4YKYrnT2wyc33ZbRSWdCCuZzZg97xSRhkKH3ieIuIp1d06PzktaLGG1IQr6GmVFV7SGwX+tsJVDlkMQtB6ZEzdjBM2nY534drgIKUYugVuAIYLGImdopDSarLE/uuAzwUGUVdE+VBL62Xx+g4/6tqcWnvUPDvqQzR7di7ZNeoChojCFgj5CGeFN67xKNV4gWRuskxCkScMobFJw/yQ4UMlh+p2WIulZDpB8jtjNMmY0dJGFajbvUqTtjVgMPzCmIKi0MeQZU4ipL2LYIJXJGiJWLhftZBQQp6CcHa4uuj/NgRatxVuFhqoEWuWg/WiyUEECIBnQbdISjpzpGtJO83Bb5mUG345iMMQUOnxQZNDOFOde2znTKLrpB1EgCSuKcNeFrZXQigEgfoKmKNC/SWDb2GyTESiQgsgyeGX1e5JgQLrVNlt53nTSIMfut85Qh9swa2nWHDect0Mm7HhyuTK/EOWdfX79ssUA8dkMMvM8Aw4x5RvPKoXcAERhLItFDswzf+O64K6KYMpqQlCuUpb1MvkHgXE/343GYqYPh+RwWcjyA74W6zjaDdYBnRiSiu1VtUoYlJrFKRcIKgKT24twWHkx09Ayo52y9YRtHe1rBL6I6Rz8bxBUHGWn5TgU4moLmRpCdDJZEnQ0dvZs+AGyfmR6yNQj72CzZqzbciofZFfi6d8T+qm//QJS/KPQEx6/GYz976QQAUE++BofX4zp/9Tzu3+1XP4WGF7KxYhaZwQebVDZez9hlWLuxG7uxG7txX4z7OsNCN6L3a3SO9hzUAHSjTxpbhlHkdJLDTUVfLH7d2gygaZ9osSm9gVAsdZCWcYPpTLq9SWm1LUAL92FCtQRlMRrR0YtfnewVmJcx2uhWDYaGnfisAdRZjuUgNQlpcM3Qs5n0fBlx36KcJHuMrJJal8fpOlJEhcqaaYeRnY9LNmg+NNlPth1WUenA9diQDDKwjnDeDTgItDxn5Dl4hQW/U1A5ej+rccj6SMUlpEObVDw2rFuFzGOa5WlfAWD0ChtG+ZXaGmCWPPZaLEqcRU96udSeapMjo01KzyK9qU3KkhW1AkeMW4NBFzAG0f4TVX+HnmoRU56osqhRkmI9Z3ZsRoeSVPmBKYXbqKSjOOF5rWZTFNzeyTpGmdAWI+s3GTO2/etHOGK3rqKF/PnpBv48/ntvFjO2o2IPA61Lzs6pADEABbNpsWuvyyzpFDZsH6gBHFHOO2c9arFcJ9NMUa1AbmD4uTn36WR9kpTZA+dsYhyukJxzwGul8SGZ+4mquzIZAmtd1WSGKc9nSTUKG1romvvKZnc3jjBscs9mdB0IForanNKmEMYN3MAiURGRCZdZbNi52zSLeGz5AKkIlWW8pvavGLTMymtmS+99cB+PXYpzLkyLjQZqqqjASv0b6Q5ZlGzGrUt0EbRBQ7eDWQ5kEC0+Utm9wkB+vKi8wHnMmamJHqSzSLqnolGoFKAh1w1RF++hSBrTVPuYHmrQLxbulYBABKk8oF6lzlJtVfZ1qDyKTJRXmBFPHLI9XmSHtEIye6imDwIA2nvRPuSd73J4P6KqxdmLV7hdYHP7dwAAkyLS28e5wjOfeAEA8PFXYgNxcxhwqXyE+0XyWL/CqiuoQs8D+Q+M+/qBtW47HOgSfkWJnDFeCJu2QSBEpmqKas40TC2FaWJM45iYQyNVGlo7bskPvFFPkeOQrreOPT0e2PYxkLQQvE+kC9NLL9cIZPGmM9o+wSFdR4
+pymLFQngp/R8wqSP+dBNPZNue4GgW+2/m+1FS6V6zRmB6n9WE/PZrrDsWtwmfLFRI1g4b/u5qHDAQBnB0o1Uzj4WQhQhn1KGAc4S7+F6ZFyj5e+Kca5THyGK/EkZi55OvVpUskAIoWoBafqu1ieE3nxCSKGpsxCZBbDBCSHBGw4tW7xmI8IQn5JRZnZx1h2aDjA/NhcBYXuHKURTHfZD9euM44OVNvBMNPIelqhAWPBEk3aDSqbg9UsQz5HUi7cw1lUL7Ee0m3hQ7Pig736bz6rkGfV1i6qWvLu7gtetzvLKMx3eXDyKbKQxyR2vjmniw3sfMUPWCbLvWrgFKZA1GiCBIN3Itth9qwECILKN1yvXiMg4pc3bCIOqe3eBOiNCxIYuxDFvBVvFMsmHAyABI2RH7ZIeJX9Mwmuj6DcBxOyEoKBJXHFmnxlvU7Kvz7NczyGHpqNwQms3qKj18DfsTh9U5Oq7zCZ8RplRJ8Llo4+c++N4NrtXs9UqonoY17PXiDdUEIPRyAOwPXPkkbF2JQXXYMo8FSp1eNxhO4sanhfQ4BdheoOw4+nHbzyXWL7VWsFIm4Of0CIR79MsjlJedn8MTmldGAYQlGypTZLVFyWBaHtbeecDFHzyTvroHMxQHkfU3Es63YwGl4v10sWZwsqcx54RtTuK95e5miZt34+ceeexhAMAv/eZP499+8t/Ez2m6C8819nkTMDzgs95FKRG1gwR3Yzd2Yzd24w/ZuK8zrJnPsYcKNYuQd16J3dWjclCkxFYU0gxjgMmF1MD3MoOCkIQjpNYHC+8l4xFhUQNjReWBcJGpMKWWmUziFQe0jMSWhOU24xoLIYFrj6NJhDRqwicueEy5X7nYFRigIOwjwvt26PHcWRT3zQzl+fMC+48eAAA6miwOaJJ5nVExSl9OdCJxSL/LZV2jp6bfivvq8gBFfUHp9h/WLfQYI6KSxBQ1zdDmhEMZpbebFiFFySz0Oi2amiilZ6U2OCBMO1eESm2HDbPklpF9XpQIFbX/StGUAzSj71rsGXKVzO6qgjYp0wNcI3R0dnKG9UbgMEKHeYGSkEvJ83lQTNEz8lwwum0yBXMYz3FxFDO/xvW4OxzH7zDELr1HzXM3EbKM13CEw6T5ZrLJMNBluWljNrdGC13R3JJsoLvNBms2A80OSAGfz5LGZX8a98WPNTRXiN8ccxslxiBwkigZFCgSUSNG0E2zTqaOZC3j/flDeNO1xwEgZYLPt2f4TBevq7aJ3+3siI4uz8mOZpphYBa9Dhb3Tkl7Z7aHpkvnZ1LF/Z/V++mYN9y2s1aSVXgtWpNIF4IjBJoFlzIdgS9PlxYoxVonHu+69bC81t4UXTDwyIMh2QmJHQi06HRsRWgVcCHFIYKR5SiZpQoRy3YhtdQIf8D5gEtvFIWTeIwv/O4aYcV7D9tW9nQO24guI/sTTZ96x5KYrgYsDUg9M6zQAhA+TAhQ0rPF9TsMAYH3tZyZsNEKlv1r1GFGPu0Q5J4SDuIxeYXz9rPxA3t/FgBQ4iGc3f2X8fgW/y7u6/kKTz76RgDAPR+v19/61KeAdVwfsylJUl3AsIrneOQ1cro8gzUeftxlWLuxG7uxG7vxh2zc1xlWsXFQasCKXe3H92LBezDA9BJx8mLbyKmIR4slxuHBBIf7MZK1NH/0qz6pMOuRkbGpMNeiWsCoKRgcTWMtqZBsad3hXMVIpTNSw2rF/w6lNpixcD0J0nR6jD1G+6Il6PoOa+qpWSlbDB4idZZNYgT1wLVD7F0mpfRmfNcrj3oWt+fFyC9v4RgyVyQt5LYA6/BQbONXo4G0BIu1dvABlnRrsdg46dfJ/iAk1YUAxUwnY02vDAYVC96imm8bh9l+DO0u096kcR2GVayTSDNmDw1XCCVamjF7ZJ2oX8T5K3INw4iSAgW4bnI8zW0v5yVu+xjtb5i1dkpjZIawgWjwBbAvF4aq0r7KoESRQtoCfAZDjbiB21BuxJy1E6EPqz
yHYq3GMAzWWY6RGUnLoseZXSd7lIM6ZmTjGHDjxhviayzGnC832BRxX09JFQ9twIqRbB5iRnl0+AjOaZ2yWcWMvLEt1J6Qq8VmAsm8dMMseT56PF7G2tP1yzEdefz8DOsX4/u/07BmFwz2qKPomEJ749GOJDKEEQtmzD5IF75DcSBEKCITOppkAoCuaNsTAjIqhHgvVjcdelFWcdS6dEPqsxByy95RDcMMrDJClw8IREe++V2sEU5aEZt/TcguMyTXob9QmxKXBdf51BwumVimNHJBaljo7ZYOi5fjd/qN6HIa7B/FmuMZ0ZIxG5LIgNS1QkCyCMqlTcZogDqaIkphNFBSrWIIFiMV+cUuJtc6HYA4KZhiiyYVvBdMp0ewVB/RrP2v2pdhOCOTSwcAgLp/AC+9/EI8pkXc3sPzq5gcxZ395V9/BgCw8BkG1qI3NiIJY9OnFpei4lxlAe0Yks3N6xn39wMr0xhsg9MV2ViUs0EGTGZkDhYiTZKjFUHaNU/Y2KAkG85qQhxDl7rMK8JPmQrJ9Tb3IiiaJ4jJEQc4CwMWOW9IPAtOB2RamIUKLSWPHNP72kyhRayST0qtPYJI//PmPx0cSjLHpMB+NG2h2c/TnHG7JVBPWQgXdt16QF/IU5OsrtpgwwfMiliOcgGQYjVk/2rklTC44nubrk0uv4pVdVMXCEFYjvFzOQrkvGl3FC89XW5wKKQGcX3QCkoYS3SRtTpPN9KeMKsdLKa8rOVzowE27AUT5tXcBSzoiFpP9/AACQB3ydbaLM4TqWRgD09nTOrPk0p8qTQmVNaY8gHurYPnueupiFB1AZ77uqaM1DgLGAmb5lwTk7pIUG9ZMNhpFUZD92MSLIpxwIyizOo4nptZmGEibMh53O6JfxVrWpzsTw/i3E8qHDAAWm9iENAMIzSjJum9yoJJLrk5IZl7RwZLrsEJITXUFWbsX6tIyKjrKS5difvfI0LRp+tbGFZcn6VHQXJj4EOxneoUrAXO1ao/waSgNQ+DvtKPCQoUdZEQBggmKKjXgIC+paQUr7Wj/T3MeIdfUu3h+KTFA1RWeepGnN+jOsO44XliQJipbe+TeNlB6dTvZ0TA2Gl4Ib+4LZwYEkswvtSugSmD0pY2OuXcoXGxLwmiUuVVsp8BCSVat0lWKzl/K53YdSr1LBr0TnqagGLCQIu3wc6OcMIVkuvfKWw2hCX54MjPJnjl5QUAYDKPc2SqgIZQZXVI4lfncPjU/x4AsL4iFkw9Xr37KgDgl5/7VwCAW2pAU8X5H8VNOagktTZh39/lbA6fWTgTsNiJ3+7GbuzGbuzGH6ZxX2dY5jBDbx0KRpQSMU6nFXKqVORGtP00FOG8dkFbjUWLZrWI352yFyIDDCPZ6YxmjZsOC4rVSvhSV3MMrcBcMYxZNGcYGO1bQhhlvs32TFVCuxhlit6jb8cUJQm9dTafY7LP3yYluvEdBlLSxaF4PG6TaOGEc7IagJFZjxRahyHAES7ID+JrB3t7KAtmmlwGxipYgQTY3zX6gAmtHyrCN3OTwRLaWtNocPAjjBR4J1ROUB6afU57JkJIY25xLOaFpGeHcYBjaiecD+969IO0K8TXtDWoSHQQTbnluMEZ90Fo8KrQcGxx2PMGe7MIlw08zu5sgJEiNDnxPXyCBwvCLJNpnjLdY/ZD9cMIw76qKfv+cpfj3q0FAOD2hpYY+wX8jOKtk/j3Sj7F9SrS3nP20mnrsaayw8BIfGpLDM9HeOXBK5EqfP3oUXSMeM0Qo/QyXMEbnng6/h5/d3F+L5GL5syMg6rRiMsvVTWm2iAnmeVBQpG3C4tbY9zOVcKsIzymzNSfoI6cn05QcA4cjRr7Zki6d7YLUEoMFOOa2StrKEbWA1tDguvSWpaETvkRNuf6ZobloGDYK2YKZuDWpv5BLbYcqsfITDJw0UzNBN/97tjC8MQklgwWLy9Aw+QUsmsFaP6H2Hw470CAA1oUpUPS800QnnUBXlQcmI
nN6wxrUtgLah32w4DlMn75GrsfCmdwerNP+wAAegpAkB9m0NoEFHJs/P0QQjJjDBromfkJjF1OFDReq9U4dgolnR4JzuD47m0UeTy3oYv7vz5f4eAy18DpPwYA1P4Qx4v3AADO+0jyORtv4mOffSX+OyziHJgNcpKKckK0hckw5dzU7GnsWx/te3aki93Yjd3Yjd34wzbu6wzLToE6m6BmDWZD2q8pCxixlhfu7uhQU0dPU1FgGFp4RpzSu2as3hIrRMW/1xhYR+mYWXQTh4yKCGLbPck0rhSx2GiZMbjgk51CpxyWNkakkjUclFPs5TH7qBguHe7vp3+PbHDsZy0WNNLrJEPpltgQq5f6y/60TtboGUkLp2OLbhF/N1BhY459aDb9zqmJV5Y17tJeZMUCusWIXotneDzOeZ5hTmpywZrIaXMPjvUbLY2SysN7FsS5T5eOruL2Ju6DNMIOfoAVPbghZkY6qFTX2JvQCNEXia6cVCvCkOoaI+sXZzrDmmr95vwurpG0Mc2olD7Lkq1FR2PGwfdYuXjMr8kEeHwbhqMbv6XWT2axjjOGDOesTR5TdcGMPaas6QVSp1dnLabMBkup/aFDLo3IpNh3qHG6FGX2+PegvozLVcwUGvNpAMCTT74T73rnOwEAv/xLvwgAuNu9gprowoS1kW7IcLqI57WlKnaxV2OfhXGUMWs59Raf6WP2drSI74UyR8nasGaUvnZLjNRglKy0zRQ61ivP1y0MWxLmzN6mWYGGNia93dYrOq5Hw+LPtNRJW0+yW+tDalAfWQsd7ZgU4YukVqJgR0bxB/GYjrIe77geafkHVL8oM5XsXYSO7rWCEg47a1POXTBV5TlSOkBR8cWQUBJCSHY3onThWgvySFBzDs5OgZEEjHMKoly7OoHZE+WP+Jp3KqVvXkgfIaRauZAucgCcDmQ5UIj2Kde0RdhqdIo7vTLo2Ajc0ax19mCNLIvnu6MCS1AW/Zq1bRouvnD3n+Ezt/7vAIDrl94CAJhPH8TvfjpmWC+8GLMuU/aoea/orNjQeIxcjyMz40XfwPrwZdHa7+sHls4NptMJ9CACtvH10SThAfQUEc0yixkL+2oai8hN16AN8WQ4Crfa1qGnJH57O66M/f05KkJQFW80o9Kp21+ebCEEOGlg4uftOCRH1y70STJKLso8Bw4piJr5pHiJnoVnSxaWm9okyVR19E1yHj1FfjM+oIssR06Fi4Dtg3cU11XCCmf3FhipxuEp31KZEQ22TC8AQBmwEdYZeYrNpsUhWWLiaquzHO3FhzSimCqodOFoOfLQ7DJKF0/OMeHYlbJwhOF6UhedA2pCWlP2M5Uo0NEWVsRhtR6xVwgLLx63zidopE+oa+BclJa5cUBViNqgZTH9mCoavesx8uGqeSdcrntoVuJryjDtVTUKzq9IaZ01K/SH8XzVUrnvRxTi/EwlhlXbIqfD8ZWZnEMFb+P2mhWhoUxhxvfH43jeHnv7Go88EG+8b3kT+/XKd+KXPxH9iV5ZxofY/o06eYsJVI2QIzdx/8UtWZdlYqfdG+KcFtMJnh8iTNu9+ntxzvYvoeC6OybBYm1b7OcRU+tJDhnzAHFicZlGTtakIrzugkNIwQHhKTcmJg8RWhRZBkNZLVGXyFSGkn1LLa+loWsSuUexx2xwFg2JGJ2Lb/6xpwze/ca4nYNLZA6qAc2SDxZOkTEBmYhPQ15DglfLifx+D8uHnOKN12QGRnzTRFEiePT0ETs7YdDZa0y4toRNaF0Hiq1gXDHQ67euwMJcNHp7v5EHl8pNIpVZ7zCQSCT1gSzf2pNYPhS8HZMnV8EeqXr+ZrgQRW2zcIf7tUJP4owjhPi7n9tDth8fVHUZoejPvvx5/O55hK87Qt9H5QTD8FqXahuAjuc9iN9cpnF4NIMbAu5ggdczdpDgbuzGbuzGbtwX477OsEwwUFYhZxyShD11QO9F2JXqC6HHyGppzt4hHSwMIwCBgRDM1r31jB
BiVqA6iGHLjIV2O1Fo2RPSsD/FZBVUGT83UKhyFWz6XIBDXUpvV4w8M+0wEL5a9hL1+WQSJ5CmtRYlI7qM+7o/mWBKjUO2tsDmCk5LdhdfO8hrGMr9S4a3XLdYURx3ZPQV1j0MiRjSr6WNwcCCcs+/XWixYlFYsXqtsyIRUqy4xWmVKLhCa7e9RU4XXUOR0zLzKRI33G7bj7HRBNu+uQI5NLvppd9mdB6egsRB2gdyA8VzbTKHgRqHmzVxkbpEy4L+CTM/b1wisPRse+jGEQfzeL4fuhTFPg8xh2FGuiA0uzi/iTGLx1fs87tnPVaMsAv2FT14rcJbniDMwgj79NMdzo5Z2ec6uPFwwBvfFPf1zZfj9r75a17ArI4iozfvXgcA/MpvPICf/5VfAgAc3Ij7WaspLKHjYS0QeYmjvRhBT/e4UIzC3YbqKMyqhqbFhln+K8yqx5MWJbOCDXXmTK1hRICZZCTvPQzX5+F0HzkhDjWK4PQIU7+2V23ohgTFKyIOgwYUM5iB2nhFXWM2jSwFQ03BsW0Awm/GSx9jiYHU754Z1BtveDx4dZtdAIDac8ipgK2d2HyopHSR+rDctlQg2WGZA3QXgSWsfH5u0ZIcsUc311wjKX+4JFrrErmEpsA4vzlc0HmkWWPwCZ7MRCEGJumUIkgT2bbVROkt7T3BiGPsy4w7IZ9T8Pw9NYje5phg2vM7ghZ5KKI8t16OaEo3PIpHrkVyRn3l8wCAzZ0XoEmyukREobIqEZxGnn+lgQkRrozvTXONeVnAyj6+jrHLsHZjN3ZjN3bjvhj3dYZVmAqZzuFY2Nu0rD35Hr3UfqQJEQ4D1aRzKggXhUJFsNiwjjDUQE/18oMiRhPI9NYvhLi5zgBPhYjAgvulvX1MWEdZaKqiu4CBmYQ2GooFbk3ChrUDXu1jg+cZGypNBsy9RCvMxMYy0T8zhn3l1KQaSzmL250XFQbOx6lYXaiAvfkBj7ng5wrs1/E7G6qcrzYdgmWtjtzYQ7OXrFDOqVeISiWyRc5iuVEKOYuqDUPUthuhk0NI3Kd7y3O4jpGkZLpZQGD9K3X7e4uOJos96fQzVJjSSmbGaK4r1tiQONMwAh3HBSA6dMWIjOek4zE17QoNyTMrMYLMFBTrEC0L90NmUdKuvWviXJpJBsXtnSxvx22sTsVpJil/ZCNQMNP5isdidvu9H76Exx+Jx/S7X2Azaafxxkfjdx97LM7LOx+2KMtYS3j3O0W5fIPlrbj/z396AQD4+Gd+CzkTppaGgKNugLB5zb4U0wKWc3TA7PtKvp8yWEUduaUdk8GoYeax0A4dMxMxpDzI9iUZBLkvmOgaGb9b1vPUML5iO0AYLUqes1xqkgoYnbgfCFHAJwNKz7qmLSqogsaXJDDNqh49iUEZa17tImDdxH9//dV4vO952GHgyVH9tjCUsw4kVHIfAsilSHUj7QEwi7JsYdGVhmcdeE2L+cVCwUrTL2+peR1gki89EQy9rY95IgkeIaV0OmUaOmVRAhr5MUAkczLu9DCOYhgB6QsGts3LWiV2fEJbXFCp5l4xy3eLz6CTuhyJFkVlcdbE6/n5V2K2fP0NX4knn4oKKHfv/QwA4IW7x9icxzUtmbP2KhndemaepsoxssXIKSkcOizaRSK+vJ5xXz+wdKUxmhEdHxxLCnsOKgCEiRSfMCH0qdgn72UmR1aKxE/8m2cGJRdzdokXTDdgyYK84g3dDEVSragKXoCTaZIRCnJBaJXYhuM4oGv4uihZ2AFLwpJSlCxNnpopMj4EfO/RkVwlMEU7DukBWpCSlOUZPG9ehRRaYeF48xSx3/nBNMGbdxaLuC9+hO/j7xZsxTeqRM4HQc4eM299utl5kanJXOpPqcXHyAZozvWE/TPTvIIjUWNCDZk11jgnO7Hlg8SGbUf/ygoRJKCmHcWUUll71RR7JJzQiBlts8Saoqsu09jwxq2oAKK9QXMa50PzOEpTQ5EocE54pAkBtHBCww
r0YlrhLh/gn3khMuqCC7g8j8d3jXYOT13P8dZH43y9+21x/p56eolPfiHeeJ/5dIT13vW2l/G2r4rsqofouVZZhbu8uQp0CHWGOnJGsOLD+OO/vcGjb3kMAJLHUYslRh4vBSrg3IjTsxgUqbtk5R0OAL3bFG+sB1rDEWM648PJZAYF2XAl6ZjWjli34qYsShB5gnOVB2r2afmKEB1CgrInQpIpZ2jo+7bg9ppxgCb8Jr2Uk2mGgeek2xDiV/tJLumUDNhx7fAk1/R//eH4G2++1KGjsoMQAKpjA8U1I/CZUkhPE7l9KpVIogh8WPRLk1huPfuspnsqPfkC12/o8q3U2hbB2/ZaaVG1cIn8Ir+stU/vW95HQjTOituTjdgMGe8jwSmoXBQwtr+VUEQ5TuuT7lPPnqtQqiRzpVkS6Dxwi3483eoxAMCjN6ZQvIc+90K83/zupzp4BnjClXHwyRl8SqjfawXP+ZKyhAoebrRwuz6s3diN3diN3fjDNu7rDGuxOofSIQmmBhaCM5NBSQE+hSUejtGBZuiTaZ3osgJXQBs4pkSLjJGb7mBIU5ccvPAKFfuDDkg88LnGhjDFkj0n1luU7P/KR0Az1BEljJUd0TDyCyQX6CyDMq+1FxnCkEIYiUjsaJFPCMNQJcE2FiDEV1HLTM/mGDlJLYugujboKAq76BcAotVFQcjSUqGgC32yqRBSSzAegyaJgpFbDg9FvmxBvrEKBhl7kfbqGO7XWYV1KzAcIYTRpGNSjNKL0sCxiDySOj2ELTKrtbjqFsgd6ezMCjMzYsN+o0nZoOlJTKEG38Z6rAn79uRi39gzuL53yP2J89G4dbI9eYFZze1wioo78fgDcf9nOuBD74tz/VVPxsjz6r7H9atsqWAvlR2XeOR6LGC/953xczeuOxwQIlGNpKgF9q9xHmyEItfDceq1u3QUCRQPHV3BW5jFXb8es7Rf+9QCz9LAbzaX3qUOJ23MEBtmnjrLsK7ifJyTfHMpqxMasBaKdAHsma0INAB0YcTYxe1JVuqUQUiEmRaZiBNz7StdQDNbFP3OShWYVFSwYGPUsDhNmYtj/9rBvkdFN+mG2W3TKKypDdqJXt2Q46vfFrf3DW/jmrmToa0lS+H89h5Iup3xpQvIXPoLRBgvzhezF28xcpKmR3FebAhbSyJmP2MzJsWM1w71Rf+l4PmD8rtaA4rA5Nag0YFIdMoK+1EhCyJSrVJqqAX1UBoqCAU/pF/vea0tOrnWPCqRyuFtbrEE7tCksdyLpByYX8Dnn4/f+cSzEV04bjaoed/tKCAOkyGnweuUJQgPne4PTuRDbIAaDdQYgER5+YPHLsPajd3Yjd3Yjfti3NcZ1vn5BpNJCZ8aEQX7zjEloUCor966RB8NQfDrACt22IyQgnMpghEtszEEZKLzn6IghWpCnJ6P/bVv0dDqoCOV3eQK+4Vwzj0Co0wplq7HIRV5JRqd1EUqlnbMCsd8iFVqACObegdnMVHxOKUW0LUBBe0URMG7mBxA0zZk08Va0Xl7ipZ1uW4Qam+2VfwgnVY7hQk7sqeMpjuzwpkTw0VmdgYpO6tItTcmIDCD6YTe3PU4F3V9sUL3Fr5lBGuSnwIMsyAhmcwUMGezrmGtbbP2AJu+PbPqdsyhxUbFa1xhMWekUeUwtmhO4zm5VkW6+o1yDzdfilnxHn/j295ziMUm/vb/+3cMt1Hig0/HRuTv/D5GqMbhxrVYYDqcxwwKeAlKjDE3tF2xGkdV/O7XvD9mCiob8K//PzGS/Yqvjp+/Vu/DKjaEM1PMMgs1xHl98ijO6Y/98BJHD8TMqp5FksbdpccXTqJqwYr1lHHsUkuHYl3oRPU4HaQRnE2smwFVxtqvGANqYEP6fqCm5IgBLdsU5nU83jE49Lz+lK6g2QksjavGBHg2abesQ3loFGwOJ1CAmVY466XZlaajNschm/0ds0L0LSzNGq00pw4GexPWkJU0KbtUc9Kcg5CbC1mIZP
bbOk+qa4VtPiRIh1cB1R5rhPvczwA41oM8s5Z+Y7eUfXUxqyLZwm3vQSmjkwzPKXgvyhrxtbwAePnBis1IGKGwrXWJoopOKMT2WIQJ79wWiZrTWqk3AzoWyAXpOF4A/Sauozc9HRGk1fkLOFvE6+WVZVw7K7VKjfuK321Vn5ASsWIp8xpT3h8U9zP0DqMaYHUAIASVP3jc1w+sIi9RmAKD3DQbKgXYkCDBDEJosOi4Ipw4e/oBIyGrQqwMTEhFQZ9S5ipt72IVs+dDQlQkkAUMwkqUnq8sS35ZCO41FwMATJEhKLFYiL+XQ2EkdGN5Y1CZh6IPVuCNXNkAlLzZEDYrizIpXPQCSbQOvA+hWVO0tlkmOtScKg55VWOq4s1zWhzEOWgNvLjKkmqUjQFOYEIV58DmY3pg9BAHVZ/sGRwfXH3fp5vdhHfFQmXJtgVD0qJBRkirJqQzQYZSX+IckJHYnSBd6Ya+aOfHaOYMGEKGol4AiOcbAM4bh2kbb7R/9o9HJuhe5fA//HTc18Mr7H161z6uXiJU1ZBUcfUAH/0j8aHz1Bt5/gOQZXF71sbPdV0Dml4nHyDvB3hhSOkIqemsxmkTH56/8mtx/r7jI3PURcPjJFw8emhCYA/txdceeuhzUGSj8t6Odz85x+dejDeYu1QbsAgwvJHnOu5nCBabJVm1tEtRjccoAtKXCbkWk+Rq61NRX6EQGw32RRmd4fyE9h2HBoprZhjlYbLCSEmujA8YZwwsb1SBcHOdAZq9b5Ykjbm5hKMyklQs4diNehETklRG9vVdGzO880kqhAycX+OQNaJIwTu+ColBKSN4pKeS3NwVLjyo0nsKlbCLeDErn8MPcZ9XC7IebUi+erI+lVLbHi9RfbiwG/JY0/hSKC8PGiXPjeO9CmWAMcI29MlOSA4tQ/gSuHF0AOZxLexditf6stFJOeY2A7lXX53j6tX4wDo8jMzA45dbnJ/Hf790J8oxnQ8WuaUEGEVtz8cOXkt5hsFCDSiRxsu2AUSWFRT4XeL1jB0kuBu7sRu7sRv3xbivMyxTK1hvE7xWCce79+gYPVblxapq/JcTyT7n4Ayp6UKqyH2CCaeE28wkSxlFErfs+0QPNSI2m5sUhSYoRJtE94XXyASeYDV3lk+hSKaQCMltRnTUgbOkkk/medISFH5INgKlhHECQWYFHKmuaxJANusRlooOg42pvAsjck5Eyb9zPcWT194EALh08CAAYLXqcHwzwk3rRXRJ9a1KVH7DNH+jVmilX2fDaMl5lCymlyzYqyKHodCt6M1NQoWyp52K9OUMHiVFiKepv8PgbL2I54EwYNd61D7uw40H41w9eCPHr386TuZzty0uvSUe8+IO92VxBd/7kXjMf/LDz8ftrXMcMnN10/gbTz+hMCc0+kPfw/N16Q4emgsMKmFrhTC+EM8J9z8f1lASRYuSQRWgiA4r8XZwA97/1V8LAPiVX4iabGG8iZa6gvlEaONT+MC+IxEKXgZMDuJmNA1GHzwaMaMg7V2qUExdjTkhbWl/2GwaFOxZ8iQA5VWWXKgLEllqe4BqP0baIy80ZzYYVNyXs+NICgk+4J2PvxEA8MwXnoELdE8Wp+ALmpqeULCpSoxEQMQ6x2iFnKmpokGm6wIcHXPHnrp8XQfDg8+oyvG+Gz2+ObpfwIgRaZla8lKPlHNjomBLrd87wNsveg1I9wxBRhAC+hXhzg2zvXVAz/0XIew826pQJJ48wlZ/dLu53wcSdFsCCD9vbYZ2zflg+0atM3hm4MhibyiwNbkMSqU2BSHEQCsMhO5WpzEL9WpEz2z2Cy89AgBo+kfw1ofi+wMJZC67gpsn8Vq6dRrvCV0ecLqh/iczrMHrZHFUi6pJ0FhzbkTRIy9zIAOs2Sld7MZu7MZu7MYfsnFfZ1i6dChQQDtacYxSJxmThqDoh1WZwR6bcFsi0711MIwalaguqIBAnTpF8sBqXMExs8oZiVQokjGgFJuddwhJz0
vqOAYZ6y15XqRipPzVxRQZNdqWVHZomxF2LfbgpODXCqYUkgIjFbhU97JUJ+9HDUVKemAqttE9Otqve5rj6Tyk7MdIvW0c0LAwfp11hKODeaphhFfiNtbHZzCcSwo5o/ctHLPBtdsa0qVuAcZGLhh4zlGa5xFJD7KkzYHKASYNoKckTroNFm3M3iwV+utuD299a5yrd72DGd7JI3jnN8btNNMb+Be/ES28Dx6PBIVvemOO7/lgxOCnVBkx8xrf+E1SsGcDb7GAZzH9rY/aNOciS24litceVkWar5Z1ZDRcFzM/z5pNXowpihc9umAHPERK+ge+KW6jP/fIHqIFuUTfxqMnCtAxU9GTkOo8XtTCyw3WbGg/PadZpzpAkTGD4Tn3TQ/NGu18E/+eqwCTUYnlXtzB1ZlDuBnP+/4+CSX1GuWMse4y/r20v4cf/I4/CQD4Wz/538GxMftgOuUclAjM7htKMdhCw3Pt9ST+5AooeT1pqr2s12e4OSziMS1uxflTGsrFekrl47z9uT+7Qd1EUtHQkwSRG6hJTLFcqgupFKon3T1/gWwhf6G27wsl3gNBRPD5pmsd6C8Jw9aUfrhA7rpAl5dt6wupQsqw0nshNf9KPW1YA+2KSvUkAh3sBWAiX9peayopxgPSd560BBVgrOhZkvijLF55MV7IL96KpIonv+IhWB9rtYH3pbtnK3zmxRcBABYx+8rLkM6h4s4WLks7Y0mw2PQemoSfmmtiUAar9Rruy9ASvK8fWMo61HUpwhXI2KWdG4+BF3AQR04oFNL3w5URmgBIUZs3wN7rpN5g5lsvl/N1vDtV3OD+fomqIrFDfGpCDs8ptTWhEAQMIoEyOhQQBY44jA0oqbZRSx9TsLgsabIoPzgLyxtlegAWQGDBufBk6miTrElGzocLJWr6a4mih/M2XVBrLshVs8bxabypPHvnZQDAYw8+hDlhP08oKisrOCfKA3HhrnyT5lx4KS5X6Mj2kAt+bBtkuSgsxIVbDYeY6bhfB9f5oAkbdPy9joFIjRGWcllnvCnX5Rpf+7445295OH73rN7giYcW8f39M3zkK+M+CLvrgSuHwPCFuI+bSDgx5SpZvmRyt/AusSUVr/hm6VFmnEM+3LUZUBD6SN38o0dZsL9KFE+DB5WeEikqKw2cjqK21x4nIzQgXfCKjFCLFgWbZdRpFIIt8hFDiDdrx76phyYab38wro/nnuG6NJvkVXaXrNN7a49Xnz0AAEwZ9KiZgTqN2/m6D7wBAPDerw144HKEfw734sPgv/9nJ3j2HuWBeP0Ug0X3Ypzzv/SdP4z/7n/6aQDAcsGHdbm17yhKUXmwsAKx8xz70QJ1PI/lvihnHMMfU9x3Fuf5hbHD0Qtxrf4f/0z8+6Zpg+44HnMhrCYfUi/VaBjUwaAceR0Km2JEemAIM88olYINXoZQUNCiagF+Ltt+aZQHmwrbzyW1iu2XEonDXyBJiIJFMOhIPTZ78YdnlcLIOa/4IEfukgKHugA1irwa9Jg88QYyTK1bp95NRfbvq6sez6yjbchDD0WH60fnd6G53o4//wIAYO3fgFMXIeAZrdRcWUalDAAFiT2blccgzGP28Om8iH4niOo+cU5HBFi83h4sYAcJ7sZu7MZu7MZ9Mu7rDMsPFpN5hnwSI94gdE9dY2SWcUaH19ABWsQjmd+UABqqC5yLm68FcrZwOUbcPcKWzc70tR97rDbcD+7P5YPDBAlsqATQtF0Sv50WE+Ts/O7YrzWEAQuqVIyECQ/qKQ7mNGHs2T81NMhpaDdjlleWBVaS6VCwrKgrrDLp09ramuhMNL0YdXuFzDArYza3bld48W6Mpp+7Ff++cOdVHO5H6FN091zmMVD1tKPGm3NAx4yN/e6Aj3ph8XfjS1Z7tFJ0L0Q49xQ6xMzpxuET8RhVgKeIq6dp4/kq4KynbmMe53cfDtk6zuWlaezIf/K9N5G31AVsBsxof3B0OR7HZBLgi6hq4QoWj3
vAMINVhF5dqKCo6KEIO5fVAEftypqZgu097CgQLg/dAgPhZoF/tLoQWYscnHVQmUvvA1HI1F+0lQUwtoCWXrvuFt8LMIdCmIgvHfcKllmUP49/F77AwdV4vMvn4jF+8Cs1/ssfjcc+6mjG90u/ehdPXI9z9JE/Fu0j5vv3kPdxnZgQ3/u1Tx3iVW67mMb1eXl2hF/+V/8KAPC+t74D3//93wcA+B/+8f8IAHjm/A701XjuDq/Gc10Gj+OTKCAswrPTosJhxTYLak6u2mMcG2pNbiK9vT0p8AN/PJ6nr7hEod5XRkwP6e7NbD8zCn6It7mM5KJssBi5LuU8ZHmWEAfL7KDt/dbJQ/g1NkDxvIrjMRTgCA9u7T6Q+qISYUPhNdYlQBSqTbkRr80Ai4x9bHuE78ZbFh37r/aucVFMM4C9aH3YanlWbANpuxzNikgP3bRDDqipCIPH8/rsCxond+N8/WcfiGQrvSnxyp1P8Dvx924dL9DSKqWi2HbTWaDl/ZTnTU8VQKd0e8HrpD2P8PvAlp0iy3F06SjN3esZ9/UDay/XaM7PUI3xQtijbbsLSP400oDnvU3K0BkfKlVWYehZiCBzxmxRQmyWWwZOJVk4V/BquYLjg2FGJXTfjfANG5X5oAmNgyHbKd/P4Ph427Bp81w3aOTmJErPVqFi8juheGSdTVGw4VJdkK4xvAg9j6PXDQIfuCWxbJNnyAjNjSLbDIN5yT6LWaxNdGEJdBFi6ri90I+w3FchYY7KQ3N7dT2RzcGJBTpvrCGEVBvMyVKrTYYFH7xntFmfHGjYKu7XzTFCkXvmUmL/nZ9QcPOlAyzvxO98w/vjBfFDP3CGRw55I6pYQ1sHhEW8GNt1D8V+nexazXmbAT5CXoP+LQCADRaTLN4UnTA58y6dr0xJvSKkBs6hEdYTEruO/dHIzdbLaMsWC1DiRQRhhjoQEU7yPwiASeyw+FJmJnBSPBGhUqj0u1ThQuYtPCV19utYjzjLKtwhq+u7vyP2sU1xjrvLPw4A+OMfjlJPf+Tr/wXC+CwAoJIa25DBMuB67lacn899fAHbXI3zQZnw4toc12YHAIB/8iv/Ev/bP/vnAADf+oFvAQD8nX/2j+AH8a+K2z4+voOWIruTOZvTqwIlBXMVr6/FWY1XuFausEfu4dU+Mu7Pmz9MNptRaM7id9tF/I08KBzsS3GHNU54BC1QGetCg03QnMD1waoE3QqzLTPb3qb0Se23YgM8X0ZfYB1e+Lg8IGVcZAlKf3FmtixHfyfuZ/uyQzgkO3TOJvx19OcCADNDgv9azrPtMxSMZCxrmJ1SsPTiu30vHtMzX6hw+Vqch3rySQDAWfcS2MqGJ5/6owCAf/J7v4rP3YxzPpvFtVi4be1sxcDRhvxCl7PYGCgY9mmKClRlMsxVBis079cxdpDgbuzGbuzGbtwX477OsGA9ytykYuX5krCB7TEyunUs9LoM8NKnQPpZ6D00w5p9FiKLsoDOtxIjcQSUfLanloExoKCNxlyaa8aQIARDllLdadTsWapUjWUb91EgJJs5GLESYOF5uVnClIzsDyJ8ovMc9+iivCJTLtdIEgcjBW9z7aBzadTiH+OiPQGAkZkfnILh8R0RUp1efgA1+6U265gtaYxQjPyXLoZc67aFyGTUdFg2RqEWh2DCnlpp1IQvp5aqFsgRVhG6G0eBBnM0Kh7TuY3svcJeQb6KGUI+xH350DcM+Pbvioy6B/djxLhnAkIbf/fOHWa3FVBuYrRXGo3Ak2a0nM8NFMQlWuC4SxjYT6IJZ2i9tWoYVsx+SwfbCsuR85tlMIR/lKhGDMBIx2Gxl8hKgDZRkLDbhws9Mlyr1nlxqYG0B242DSZM87NZnMtmHWCYheReXHwtFhQ1vjsI89Xjm94Xsyh7Fnfm8uMfwNu/mdDQ5BcBAH14FnuEOUfOqRstSAhFh4cAADeuXsNv/ObnAACzq/HNvYMHsMfeq997+R5+6Vd/GQDwrV
/77QCAjy6/Bb/2278Zj+XVCA2FtkVJm4+C5JJeOby0JnuNsN3aeVBfGstb7Jt8JscP/M343bN7cVIPHrHJqbfkBNa5SZF/OeU8lwpeSB78vL3gLiwcGRXMlixE8oVXAeRGgYkKoFU6T8ISVTDo2VgZtolzGpJNaR3lzwDA03tLW0ARnl7c4ZotMxw9QvFhCgFnmUcm6h3BJgUf8dBSSmGwgmZwM9c0mhChu888E99bDUu891HN+Y3n4WQ14OpeZGE++2pkCz7z8otYUy2mkiUbNHpCgo14smFEVsp1JYSTgEymn/enwa1wp2++LD+sXYa1G7uxG7uxG/fFuK8zLOMiBb1h4aCzQlDooIlR5xTX9IVHy/4PiURyk2PCngCjpKheIGeItW5jWDJuOqhGahncrqlQsMbiztlDslem31uvYwRv2wHTesrvFDihUoOIVmqrYdlXExhhKatgWdxsWEty1uO8YY1IjCizLOkBehYd9LSCY4a4Zp/FyvfborCYU6oMro2RbnUev/vA9BCP7NGKg8fRdi3u9jErPGO01BovDiY4b8W9eYTJtlkqANS6RBXiv7OB+zcqHMpc70d6dhgs2obF3GnFubBQLv7uN319nLP/6k9s4IQ+TgPHpuySW3Gg/0bZ5/CaUbrR4O5AqXi8dmihs1grK0idV/MKjjqA50IfzgfUh4xmXaz3DeerhNmbTNQXLEQisOA+2M6jO2fvHtM0oxQgvWeMvrMa2wiT9OAhbNUKBN4vKwW7JDmnYwYbJsjyLQoARGWVS4dxX+diidFmMMv477d9JJ7Xb/lmhwfnMUsa1WcAAFMA7px6lo5Ei0bBcC2++mKkNDv9KDLqTx5QQPVo/wHMeK77oPAbv/MJAMCbn3gXAOArnnoPnnkpZta/9rFfi8eUaah5nMxWom8LlNP42h7JPnuHwLOvLOK2X4lz+m1/xOGwjsor55yP5acdQMLBjPYs5Rzwx8waWYrFqJJlTkkC0LTKMTJD6NbM9nq77bmSepRGksoRlwxlPEpef0Lc6DoPLUwjIR5dSCSEIGHCtk4mLs6m10l9ZmCflbruUdBQdrgn7TslTo/ZZ3o5oJpJnxbXGyzKmlk59yUUFT77TPyhzz8b1/sjj70Jj1yPdPZuE0kwzitM5/H6/Je//hwA4Hi1Qc46meeEjMqA7XxopV0oaAQr2pWsB2NMKewg9y8f+1FTj9vrGPf1A+tyfYSzdYMT2nCLQnS5V6QTVbGi3diADYuRQsTIywlKQlriWtx0HQoWU63YO7schpBLZchgyqqtL00j/k42WV8LaWHwA6zc4KoMw8CHIbdtvYIhtlRRHLKaFsh4Y7bs9XHeYZ8XVyC7r6xKaOIS4h00qICGjcgDL56hDamQLL1GqtBoeDW+PMQLvy97vNNE6OiQluSNzbASAV7CaN4gXfBC/xtUQEYbditir1ohEPsQ8oIKIalmJ4dUAEVLJiJvyrNsiutH8Ub08OPxRrdwFZY3I9vtypX43cNqgZ5ElwkJJXNjcSJq2CEk91MhzIy2Q8BrLxLnT3ByQv8owijVXOMLn4nH98CNuA+zacDYsZBtCR2OSI2ZjlJE1ofENs3Ev7xS26K7QMdWRftzbJljcZ54g+FuDqMXwiBMuvA92LuOwN+tFZJ2z6Ina8sOuPFo3PiHviGyMC/POoBN06LkXl6wZh97YkguQAT0y4MIEX3m7i3c5Y3t/U98GADw6I3Hce/5lwAAR3sGJ2Q9/Oyv/DwA4Ds+8H248kjs9fG/8nsAgPXxCvtkCOVV/JEyODiqw1dUh59P9/DsK3FdfMvbYvD33/63t7H5PKfViHC1SWLL9oQ3z02/haUEKitKgOzgwOswZIDjv6n+BDdsH1QSOGQwmNJdWk7IpmvREQoWOK4ss6SqnohyIdlwbVffGFKwo3oScnoNK7BjrAhAX/IYA+8FjFGc63H5KO5Ln3VbkQFiuLbtEgPZsiwx9gU+99wi/ttHSbAHjq4jIE
7mmnDs1SuXsKGo+CdvxvvDonUIbP4VrlqoAJ/JA5LDBlgrivxxTEuFgg+7nvfhpo8CBuH1CbXHeXj9H92N3diN3diN3fhfbtzXGdZ+tg81rbHuYyF+ZHG9rAtMRYWC3NJh9MknitZQaGyPllHJqhUfqwblnPRLytRUWYkM0htyEL9clGgb0mnF3mTdoZIoWbKIicICFC1VHguqH9yhUkNlFPboqzVj30mls+SKLJJQCgqKEYrQrk2mkuSRFHGHfoBmtjVjJpZ5k6SiLBkAQSmoOn63Y6Z14pf4PPuNiiUdkzeAnTJDmcdwvmhbaEJbFaHXtXUpHXCEu6zuERg5K+ltKyy0FfkaigxrlSR8RIWkQAnnRJqH52g5YDKhzBJJKcMtn6SShBizXnmokd5ihxaS4FiBXPIcXW9fM28m28P+ITvxmf7cPr2Kf/EvIxnk274rRpmPVz0MZb8k6lYXNmSl6D6NMB4ABArd+kwjkGiimZVrGyDtg44RuTGAYuOfiLRqnbyzU2OX6wuM51RvYeie5QHXL0fU4G1viNTzw8MW3/Kh2Lt1nQSLceVgp4SlmSX7LkfoRG9KshKFgb/8uy/FuX+uKyGXwcB+wkrneObVmGGVwSPsxzX9hePnAQAf+41/i49+6I8BAOrjGMX/k3/+zzDlejxgxtNXwEDMdSBB5M6nzvHtT8ft/eRfiufDP7uV7LIiHeYCDO2CUktm5+EIyQltfXQ92KYFBvtoXAfpLknwn7+Q9QqgYB3603htSIalcqAsJCPmcfQ2ZUIXfEPSkExL++06AskXNgCWN5KcRJFJphCE+CEEDwCGljp5ERJDaKCtDLxJpLKsjl966ZbG7dsRpcj2BKp8FXZk36dhm0F1HZ96MWbgnzyO8PlpP2KvJsTHddl7IHCt7Mk9Nyi5xcJw7qvMJBd1K2QUBDgYoh2vrxdrl2Htxm7sxm7sxn0x7usMa9UNcMri0l4sMnuGKnbssGlYYRX82isIq9mRQNGuLoiRCtV9kmPgc1yyM6cVOhbxey1FTIeGTY/dOkZ9OgCa9RKRzB9g4UhDr/MQpTQAjAvWvaYKlhi8WC1oADmj6ETOUBk2bBj13HatNMBoyifxUIM9UtNF824fs1RTO+O+rF0HVVyQYAAw9D0+2cRIfHPMDKRTODKx2bSckb2gdaoL1Qy1QusQJFRk+FgajakYaYqmIDRqdjtO8wjQV8OITGqO85jJni0PESZxzt9wJe77G2cdPKPQcB6PzTYjHIsTYnbYQKFgo2eZawQqzXYkx1RzwIkAqOgHVhtkRuo274xzdO+dePfXxUbaowdiDQuqhxZ/GkaKbtxG76JDp3OV6jI+kTOAkc2YYjODLGxdApkxGLPttxRave4UlNDeOffr4SqKltT/PFqTvHqnwkNXY7H8//S/ixs5vLbA0RXWeWgBoUuNVtyMUw3Fg/wK2At2FJokpJZZ6dlqxEOPRor7e973/vix0eJzz0RNxHp/Ds/sUoz8Fneew/5LsWXhu975VgDAre55vPTJWM+a0J6jDwaK5qEnzGT8aYmf+NvUjrwTJ7A9rTFnI+2SYqJFH9AzPZLMZAw+OQkbER0dfGrw9Rd0/qSp2/Ak+hC2Jo2yxny44CQc/yLEVgRgK24Lp+CZuSoh2gQSb4AtgtIoDIv4b2nWzmaAPmQ9i3x53/iUNooNUakMbBDEZNuMnnO9WW1gmaKP7Pp/9oUBWsda5Ae/4d0AgMcefBF3b0ZixbSOKIrtzvGbn73N7cQ15vxd9KIGk8lUjukhwtsNlIvXBIC0ZscADKynteO2QzszGS1XJCf7g8d9/cD6/OI2qkLjoYN4Q6WmKro+JK2XjuT/ssxR8ET3fBAp71IabioWKvMcGycuqKIqPSTZ7IELRHUZSl7dhrbdOlOJsbjZkF2mFPZNTLevuBozFlY1HyCd8VsVhZwsMO2SMkRZEirLCnQitUSli1pVCVrY0O8qGI1JHh/gM0Mli2IGX8SFmL
PzXeEUDb3AZCOq96goXzWhWnu+P0kyMeMqfr7wGQoK8JZkxYUQMNASfjLhw0flUFRJGCB9LwYTOjTXvMCqscWm5HxVUXrngcriQ++PD8/3P80u/Y2C7YXdEIOFHDpBOFLIriuPnHdhr3RijAaed++GxBwMIkWvDhB8PE9ZFllzb3v7Ag4nnKJ4vnwTYP2W5QTE573cxAQG0sOWtBP4QbcMAJ+Jeo8P+synPZcboMq2UJVc+KUPyWlW8c5Q1k8i483krPssAOB/+lc5bp7GDf3oO6K6xd5Rl7AtR/hGuwKBz+BBbpRQCYY9ZcAyDApXH+OOkcU63nL4+u//CADgK5+OD/d/+w/+b0kcuZ+WoBhDski3+TnG1QsAgCtHUWXkXe95F26fx/m994V4wzR2BuviGl3SV+1rvspjyr46gZ03pUc9yNOcosFeJXURCQK0URiTJxT3PwCDfW1zlN9uOgVFbgwIDCwyLZCfTxBfErXFBQULWQdj2DoyCCTdbx14RVlnbDTGhmQriv3a3KFbxteWy3ggVy4Dh1GQBpkwMpTeyn8pAxBqXy/o/N2OGIs4l4OOQcxLJy9AUdGl9L8ev+tLbMi03JvGtbXpenzhmPeKDQWMR4t+wzUo10+OxJpseG9zLmAtpBZOSJFvvcAuqn0oN6Z5fD1jBwnuxm7sxm7sxn0x7usMa7MZMVqN1SxGABM+f23IUDLEmihJk0s0JDxINN8HBcOUWbHjPs8cpvyupfy+ciOu0fLgqli8Ng6bSoRnSSLwIwY2KIlSwSwrUBOKMK3HvokZ01NXHwEALIYGvY/R48gMSqFHyaxgL6UCDiIDaLEtbvZdPBYhEejKpL6kLFCYVpfImRHNaNzT+Spp03Us5mqtk7VGQWHMSVZDEwZdEwK1Nqo7AFulCz3xMMxmCynMFgodCSnrVjAuhS33Je7L4VBixeN94uHYD/J97/5neMOlqGtoxf5iHlLvixetNYSkbiAis4VWMOL8nHkEiBssI7wLEJ70RaF38JJqkJ7vzPMpdUpwiwIGy3URZLuAxH6Z6AG6AC8EF/5UvqdRiGUGMx2PkIruQaxH/FZtQRgjo3FbkgGNuMazM4yEf0E9yAcufw32b/ClS3SIdiGVtNmaB3vqUVNrsAFTrWLEOMZzcve2KJkEGPaJzUZRLSkQCG0uX46/0XiHrIrMmsIW6KbxYPIVoeB5h1smQoKujbDvtctP4rG3R9jvE82Cc9XDZ3HjPX/36QdHZPx3xj4rE1r0THVE+1PnOrVPSCT+muxXSAs2OnMDW2+5YF1yhhbB1k3vU6ZW0wWZpi/c1wsZlmRn29OfPMpKwuI6uAQFj42QwANmkQOR2h/61uCMOn+a0OpsHhIc1xM9mh7Y5EfnlU6eYufSb5YBSuxYbsY1c3LS4V3viOfp2tW4oM6Pl5gRhnW8sHr0CFQ2OX/p3wIACqOTsHXH38jmGh0z0s5Jr6pOa16uOeddImqUSRVEoSg0/BBwnnDxP3jsMqzd2I3d2I3duC/GfZ1hHYUCnXNYjPG5v2Jjra4VSnbqF1I0tWNqisxrFiLzrUuxhehzaehMIgWqNKjY0AgAR0laO6DbLOL2WBzLywpMFBKeP1c1plR20FalTEc6xvNQYmCtqaGNxug2yKWZlFH36PpkiaC4X711yV1WSz3I5wCLm5rzMRYuEUM2DI16NPC5WFgw6i9M0l5suI2+OYdhhpXotGWBwMzPirpFXSdjuI7Osr0dYMXdmYUmHbSw0GFZ8Gm8Qd7EqPtdj8as6qQ7w6/8bJzrH/zf8LjHESPdcYUx7MMFZ2IWy50z6VxnyqdGZVEX0NZs6w/ifot7cKzfiW2E0WWqV4k7sg8mOddKcqYcLjhWc8fUlkuREPo8pGK1Zw0l2G2RPznYum2NIOM66Kc60ZpDLYotL2NzGvevPIp/v+uPrjDQrHFOE9HCAJZyBO1SCDkalsowBbPCfq3QkxS0fyXuwLQc4Nfxu1/3VXFi/u
Spxu/8xr+Jc7SMOnMPP/4Y8heiKWYFhXs21kk6toPUSmNVScYZM6jlao3r12MUP712EF9bvoSeM3b1Ujz/3/j2DPs8ljXtMo5mGTZkptTM/EblXpP1xDlNQhOJYBXcBW1AyQACUqYuLbCZDml7mum7t25LrCCioHRImZWQZcYR6V4gF06hQ6pdkfsE7z3URMQBhDKuce0htmVIRg4H28W5tCRaqTzAUpUgLxw6ksn2rvF3qwxnRH8+/omYCRfFBJdjyR8Tno/T24vkOl7WBwCAuycTBIokrE7i+Zo+YFCycbhVopyzdWUWxZYy22phCgkpU0DB+cqTxilQ5ZrX0+vLsO7vB1Y9xTJs0NPLamBqWhZZ6jto+EBq+x6aM5vzrjHxW2dMY6TfoYKZkJkjWXtm0PBGf+s03lDNqLGgwob8Vp3nmImzr4lp91TVqJi3a5QYhAzCdvqimqDkGdQsvnadgeND+N5ZvPAXwwal2NZfptiuAhRtA/boRVNohaY757ZltRi4RELgg3AMSRJqJHnETzWyVIX26XNyE9aS00Oli2ZlRXR3e1MfSOLoRgsvasClQGEKhjeLgZ5hG9fi+iY+tO1ZZLv9zMe+Ab3j03rxMQARystZ2JUbjVIhXRStCIaeARNaikywZV8Js1FZDScPoiQB4pLfV077GZ3twTGIAPtTgtsDqji/WvSEXEgPm+RMq7Z26aK0AIskuppcXjwSrBNyYVQFZGLvQXfZbrCYk+0oPVcb1+J8FXsQr03j/h1MfxWupf8T50O7AJnKgj1hWTUiI2Q18vPOFkAWIbrLhIvMCGRcs8fH9Ji7NcHj1+Pvnq8i0/OhR96arHDOm3NkNq7RntDsRE8woYLteRe/+++fexmHD74ZAPD0E18JAPitzyxw8mokYnzkq+LvPnawQaDckEglKe0xsKdtKuoW3idBWgk6nAMC7wuCiwYf4ITVwnHRq0wxmCgMEsFCbsAhUwledRdIF/KaLHelVDrvI8kqucrQnBOG57KqpoCphRFIWLlwKGoGJ9zw0AUoBnjCAtQWUBSXDuOYjOgU2ambocDZOuKNx/ciqeWBh6/g2mWq/zDYVQaY1HJ8lOG6ewWeHif5JcKx+6CUCpAJwu89cmG8yvWjAwoJbnlfVeqCyke3JV+40SWll9czdpDgbuzGbuzGbtwX477OsMaJQhVKZBIpbChoWXqgYurNx/oIjYI5eu1J4QxbfbkpSQa1qhMd+aZaxO8Gj55R111mBSVq5KSKBzZ19MsOM4q3Hs5ixpOFLPWvhKCSsZlADV3uEtHBBinw59CMWgb2k202QEnIqiZtHCXQi5UHs7R122E5iDqG0KoHoJBtE4rsK5iOUBsjtzCGJLqbso3KQKVqKeObPCTChmW21HgHaXNxzLS63sIwPasJIeYmh22ot0f69dOX9vHhb4oU7Pd+VSQP3H35Q3jfh6I4KxZRLNV1MVsA0uHEDElUJph19d4L0xmqu+ACK5YtF6jJSf8sZKnnTY5tHBfwmWR0Ex76owjFp+NrjJaVj5ASgKSWoE3UTQQAyzd1ZqAyEUnlOQpb2Ek+nzlANG0t53KSVRHDBrBZsP8vbOAq2qTw/K+WI6apv06idIuqJCwK0YUEBtqkSP9RVjvkhJCUj9dDt+4TTHg18mFQXH4I2sZMbG8/YlvTPINSMZpfFSvkYpbaUm2lHFDztVUTUYq7569iQfWDr3rz1wEAFmuN5cu/AAB472Mx07pcOQwb0YPktTJ4GBfX1DDyPOTbTEjo43YICep1YuPh3bYfiuvce5fOnZBlMmUSUUDY6M4oOH9x8eGLWBf8fFBQggJIKcJUaDgHBV2t8wKpl0oL1qjLdK6LPKbGWQa0gmZUci0XyEkgg98iCYvjOC+b2R4+/wWx8Im/sb+nMKFG41rOf3EFSsX17amOc7wsUHAR1g/EbbjCwnEfCZigclFJBwAGaRAzATmvG51yIpXUduQwi6JArgxbH5JP+R84dh
nWbuzGbuzGbtwX477OsNwkR62KhKdmLXH3TqMl3duTKqrqHE6MEhkR1KbApBLbZmYtrU8R24qq0ao0UFRAbrttk+Lh3gEAoCR11KFFTxr3cRupws556WFGUU4xzbeW8kA0PZMivhAUdA4UTFeO5rQ0L6aAUNhZaPdzi8DvtKwp9cMIkde2LO6sx3NY0lUlmlNZhpJNgqIAEHILV0o9I+6fHVUiWwRmXwNGSFlblD1857FeMXoUjTptki12RuuJ9UmHcDdu/Pv/aHzth7/1Hg4fiDWPMxbx3//0/xVvY52kv8kMwOlUGJcisQ2AOCmWzCKuXtfQJB4EH5s4ga2ZZ5GHpN4gdQPnkbLa0Md6mlMGZkaLGaksmxcg8tJCo0eI0TgAaMlk1QV18FTDctDMSMVBXE+LLcee2/P9AJH0q0jY6dc9Vvzd516Ir73xXQ4Zaw6KmUw5MVDT1zY2j+2WOi+/6/0e6gPW8To2fi42bDQAzs8Z2XuNaS207EiDr46ewO/8alS1+FPfx9reqGF5rQ2jgqH9y5StJH2/xquvRlLG3SZmZffOFjg/i1n0ow/FBuRq8jg++NVPAgDecRSzx9A3GJiRVjRLHdUGJdsPtI57rfy4TaeFTOEAayXbim/FhIh1FCf1r+1XJVkarUv1W6myuLD9t5JzrhUUtyOuCcGFZPaa8R9d04NgEOpS2i5MWkdSt9TBJIKN4vU12jFpFwayjLrcIifdPht1+reVunz5ID777Atxm9QpffjhA8z3YiZ89zieY1PcwGUiQrfPuCbKHHdefZH7yN9VW43GGe+hR3oOzxvcKe2PGuXTPc3wOswKjcDrIGXBdVT//GLnhD9o3NcPrP1sClUorLnSRsqx2DGgXcXJk051NWg0vKmvCXtslEuOuDlhIO8cej6wwAtBjzkUu+onvHvnukgLIyNEV+/vwQ6Ea1bClOvRUq7ADB08U/gpt22HDCtpKmKPlDaA50V2fT9CZHN1hIaQxr0xPgzHYpkYfAKzlUElh2NJ1RvTwbKAannnMlCJfJKTkqaLDKP0+HBOTZYhJ5y3tvHmHZoOFQkFIrO0zhymFNMdSObQ2mBDT6jTZdy/Rx84wo//aNzeH/26yFyanDt87m48tv/nP4oUp7/8Y89AL2j3QfWFLgwJEjBJecCkYrn4U833PDSJDOtVwIY3CWGG5Xr7sLGERddNQMYLqqh5XmcamvOVh8hmC/omKvYbqYLBSzsmGClZ/agt9GEIpZrSwPEmO8oNM/TgtCYX19zrpNoyEHvRmwz1PK6dt76Jd41Kw8y4vqmw4rMAk0shngHLZkAtPXIUPA26TxJf+UDPuE7DbuLETbhOrbPw7M165dnHAQDPPncXehIfJg1VELryBDduxO+e/M4UyylhxDxeh4fFAe6SGtfW8XN5PsP5nRigfPoLvwUAuPzQ1+BNZAc+uceuJ68wOgk22Ps4AlNQxQFklHidPK3Sw98Dlg+R0Uqws4VwQ3r6ANLSBksbjGC3sk6y8NzWLgTyULno6SQ/HxQCA6WcxJNVY8HThGkSnlXQQYLEeO2Z0CMnrGvINtaFQSZiymSJ9igBWo7YXmGQBzJLH8vzJ3Cy/BQA4OhSvI88/tQjSQ3kypWn4j5kVzHaeB5m0yibdXDlNlbn8ZyIMlBeZ/C8XjRtkiZFBcsHVsaDU3lIgsPpRGiNoAQvJxvWOAxD92X5Ye0gwd3Yjd3Yjd24L8Z9nWGt9Dn2TIWa2MfI3iFnfaJVBnH7HD0qZkcdC59tsACzkd6RCuwLzKoDAMBeHUNzN5ljw8K0WGeosUfFQvesiBG3dwoz0a2rYjR54gtcZtRyVGa4ufokAKAe4+8ewSEEKfzGfZ3qFrW0JWzi5zb7FeaE2qaI+3faF9gQuxu5r/W+QUs4J6xZIG1t6pHSh7QRKHXCNrIQqclBdfA+QgM+owmkP8c4MAslZ3haKFj2XAid+6AIuEvI0LKLfzgGChZ238bej7/251t85W
OLuC+3mclcrfHP/2X893Ov0AF64rFZEJ5gGFyoC2Z4IjaqQyqIS29VexpQEML13YA9ts5Z6f9CiTxE2O/mMxF2zEvgygPbYjsAjGuLjJlJVr0QX4OHIvS1eDj+3vw2YKLzBkZGm2NpoJjxhUKK24DasCeHxB7nHVjnBlEbdKdT1EXMYOxxPK+TS1t9QX0U10F1qOCl0E2VDqWQOPPhlAX3BdCTcFIS3suUx0hoeb1kltYD5+zXmpTUXQRgCOvlVbQPefT6ZUweehMA4BPPxHWyd/0VXN6PoqpudhOGcgtXHoqvTcxtbAZCfCR0WOtSL95zN6MI7tFj78NDB4/FOXQfj9/V22vDsg9vqq+gLyN5QzIUC6AWGxuu7b7b0sszwsiuN1Bmu34ARJULoaSTRFV7JPULyZydR2ro89Lz5RTghUzFc+7cFjJ28ZrbnDY4oIC09PBZM2AkRFCQCDLahDonssxQW/AyxaWnCBOHGe59gWLcpsFttnxMmDl9/t+do74WFXVyZvlPXJ/i7iKeh9lcNEcfgaY2p6jZfPbz59jYRdzXMd4T1nad1qrwr+7dPUfZxv+Ycb1XBwYD7781v3Apn2LTxez4nLhiuwaML0lr35EudmM3dmM3duMP0fiyM6xf/uVfxl//638dH//4x3Hr1i3803/6T/Gd3/md6f0QAv7yX/7L+Ht/7+9hsVjga77ma/B3/s7fwVNPPZU+c3p6ih/90R/Fz/7sz0Jrje/5nu/B3/pbfwszFv5e72gXG5S1QskMQLMmM2oHiKK2KJHXBQZGaeOakbQHahIU6iwC2JeOjnDj+oMAgEERG89z3FvGutHdk0i/HtUALYoTVFTWaGAYyTZD3N60vY4f/+E/DwC4fbLCf/O3fihu8o3xO3u+Qq1i+tHTic7nE4xl/G1nYzSU+x62F9v6GKnYiYdjUdMRy96bPoSpjxHx3fULAIDj4LCfMXPJ4zZa1aGu43Yy0vebZYagYmZomPWZsYZhYV+zURPdEUrOa87fCm2JoYn/3q9jUfctb3kCb3k4ztef+q6YgrxRj1i8EKPGvVgWgpponJ9Glfb8KokxrcGM59Uq1tUQkqL5RdQ7vSb0dcTaABCj1YJR4Yy1xrbvsWniNqv9eEzzA4OGx1Iy/C0rJIuIwAwmKzMMfVxbExF6v5TDUUVe32ER3LlUS1BUrh6WHkoadxO0b9JrG2at0ysFbkV+Aq5epmxB3iMrFvG7UmtxJWwn4ookglgF21OZhMra2oVkXrm8JRmZhydVXOxZCuUx3d9wTmX+NAYe2/Ubkdc+2X8z7t6Ov3fzPO7oE4ceb3xD3NfTk5fwyOW4zi5NSQpabXB4GNOfm/fitbRqNqkuc7aOme4nP/Wr+OA7echS5xlCUqkQ1ZIQmm2TNt/MKpUahkexcA/bhuBkxnghU5c1ExCSqaNoNsJus3vpjYVSW1sRobo7v1UhkSwtROcYYGuroQIStVsyN6i0GciqdnqrBmGIqswONXDAGhaRjnEIOHo8zvOwWGFWxYypnEfSyhde+ncoj+Jnr1yO1+FmfQ++j/eCO20kvLTzLWA5uwABAABJREFUGyioi9mbg/gb+rHUWOzZChHCVtGjY6uDsQH5JG5vchDbfMx+QAuuIyIhhcoQ8njBnNFKph0c5lW9VQ55HePLfmBtNhu8/e1vxw/+4A/iu7/7u7/k/b/21/4afuqnfgr/4B/8Azz++OP4S3/pL+HDH/4wPv3pT6OimOz3fd/34datW/iFX/gFjOOIH/iBH8AP//AP46d/+qe/rH3x5woGJWrDiSIBQOUaI82vehoQOT+i4E14MoufVxbIePIzUZvYB4ZZ/O6SN5qpAqaFfC6+thg3WLJdvWTxt4bCitIxjsSNH/yeW3jv2yK54LMvZygJD7y0ib/xQFYjVyRodIv4G2GSHnw6iw+VCQoMPJaWxfQmGJyr+H4xjfv34uYVHDLV90N8cMx9C0sqZT/G3hZvHBoqJzgXC+
PTYgrfx0UlNhlKeWTCBBQTISjUJR2YCZmtlz3Mcdznr/7KaGXw4//VHE89HAVPNT1/7C2Leh7nQM0JXfoBD0zfCACYPcynmP+9bZu/sLGC3TK0IP8IiY538cElIqjTaQ7HYvD5hn0sE+DSFcKWJFV4rZHRtiMv48FrA/St6M4QMoZGTu+uwAccSg1PyDNQXcK0AZr9d8IW80bBJwiKn/dArhioacofTDu8/Gr8vQfeyQBp0Fs2Ga9aZ/2270t835yGSsoK/FtuLTNEHcIPAUbu1rxWyixDz7u2sN3yzKGisMrvfSyey3/3yQ1euhsDkAffFAMNUxuEnsyx6QEuU1KoRnwQdX2PdRO33TAIsNonZQXDg1qefhaWqhZaiFAWW5UKYcrZzda2A6JwE5I81zhsHyDJSVgeJtFMhC8i/ZX35S/8hfaqhEWFtC/yEAvugv+TtFKpaH0DAO2KpAWFBOFqcShWSIxhJcLTRidF3zY59ir0JPY0d+OP9V0DXsK4dAV44pFvAQD88i8dAACON4uk+PI03adVOyamZccvL856PHwtJhTrFTFVdQeHJdmyDBI6IO3XjNjr/GCOigvEU1mncysMK8pNkVXdBIOCsnazjD1ftkEddGIUvp7xZT+wPvKRj+AjH/nI7/teCAE/+ZM/ib/4F/8ivuM7vgMA8A//4T/EtWvX8DM/8zP46Ec/is985jP4uZ/7Ofzmb/4m3v3uaCD2t//238a3fdu34W/8jb+BGzdufLm7tBu7sRu7sRv/CYz/qKSL559/Hrdv38YHP/jB9Nr+/j7e+9734td//dfx0Y9+FL/+67+Og4OD9LACgA9+8IPQWuNjH/sYvuu7vutLttv3Pfq+T/+9XEbYqgyHKMIhMmy19YBIH/ZCAJAoeOywl9H+gBH+MFpsSHXf6Ph3aRu8dBIzIj3E7R6EGooacOdrUuPhQHEMGBN/o7EV+kUMAb/3u+Nr3/udx6hm/xAA8NQb3o3v//aY9fyffzHu3+XrORBo29HHbMn3Hi2p7q6M2U9WFRhprdCyIWMcOxTMxNSKVh17B3ipj8aH03mEaExXos9lPghPtQ6rVvw4WLD3HmPHbdPqpMoMGn6ua2JGVJZrOO5zPka6bLAKX/eV8XPf/y2xOP+kfhnDJ+N3BsJ7k4OA6g3xRPUSIrcjbj7/CQDA294cjzcHMDCzGxhZ5tBJveM1xVehMPM/lQK8VK1hoJmBVSI8GwJ6wlzVkegGBiixlWHbw9D4pDEpMNyAHrl+LUyU+QGeLRWeCaI+BnAiGKWoDeTbcF+MDX1ISimF6Af2G8zmMUultilGe45MSBxOFCwGZEqyS1E/sNKGl9anHQHPYzK8SPLcbmU++COjdbDcnrnACxeD5TvncU6XqzNUM/GXiFl/UQElWzUeuvIoblyKc3n3c78Rtxc8ep6HDvH6c8rAiygy4fyr8x6XpoSlKUittUrZYk9iQa4CDLe37BjtX2phJZMUSNCplKiLQoVW286frQag2jI1LmRk6XOSkG25Sts0328tTCTDUhoAjRaHjcBi298TaBPmQj/fVuwQ0llD0AW2Vhi54CzX36TSqA/iwiwOB6ybeC5++xORor5oWlwq4/3myUcj+UWNGxRECIxoq4YJNOJ1nHT9GoenrkWWx8mSFkF2TAdaEd4rjEFPyHDB+0MzrKB5zcq8rNctjrgwHziKcMSB24PpLWzwAMQT5Q8e/1FJF7dvR0vla9euveb1a9eupfdu376Nq1evvub9LMtwdHSUPvPF46/+1b+K/f399L+H6Zm0G7uxG7uxG//pjPuC1v4X/sJfwJ//838+/fdyucTDDz+MS/sPIdcGgdGZB7HdscF6iIXdlnTNqtLJdkEpasX5ET1NDlv6ArjRiAMAZuwoX24GNKQA32vi582lHFM2OAbaadw9XeF9b4mRyn/+rfFzR7nH8iwWpvev3MN3fmOMQv7Rz7Jjf5lDiRliOIj7qnLUbLI7p+7hvfWYGm
h7YvsmD1gv4z7O9gQrvgOMVKq3MVLtzAItazEDi75jq/DEYVTKvhxigKH8As/2MTq7fRyPyWQGzr+2CDAZgNUibu8qi+p/5MkNvvePxYl73xOkAt9UAJtSkdFU8rqBY/OnZl3I1R6f+0zc3rufjjp9gIbPWDeUZtDfD+pW2zcu1MVTZOeDS7UtUUTwfgSTWWS1GC8qKOLtLYO9flRJxSRnA2+WbZXWRU8tmAAvFiZM9nXY2lm4lpH9OG5rGKydKuMRqMyvufCGAXj0DXFDdv1OHsdvQWc0QJTmZO+Tkoi4YgYfkoq4WF0EBQRp1hS1+IBk4T6y9qTNCEPKvuN6KXQFJkFoSQrxTY0rezFin6pY57h7fIy7+mUAwBsevozbz/42AGA2OwAAnNw9xoTX04LZbdsFZLRw35C49OADHo9HQQ1kSqjiDoaIyYYIQJ4DnidiTW7Uw5VCs2QNi2QUbwNCKkBJ3chviRiyTtxryRYAMyyeY3FuUNiWVtPwCkrSbbEzCkjz5kSHYLpFfoQaH/T2NyRl9zakDEvUalBoBNbl8niLways4Jv4/rnbw6c/E9sAFqu4nb39Q+zP42RmbAEIeYczZqTNEK+H6w88jj3WvdebOJkH2SF0vYj70MbPBzug55xLZrzBKrYHAVgTFRpgUbCvSBLJrhnREJ0qZiSHqAyFV8kw8/WM/6gPrOvXYwH2zp07eOCBB9Lrd+7cwTve8Y70mbt3777me9ZanJ6epu9/8SjLEqX4UF0YRZZjaDdJoBSFLOoVTiWXZqc49it0hBgKwoSFQSoiZuKC4bJ0Y8spWdS6EZ5Mrpo3MF1qZEzNubbQNhbveFP83LUDQldGoeDN2rkXcX0/3oje8nA81mdeqDBnYbTiSa6mAbOKi9OyIN8v8CiVMK49EhfN7768xmnzBgBAQymoZXMPc96hF6tIsFhmGvtHVCQgVFL0Fb7+WrxZXDmPD7bG38YbH4wL9t9/Nm7jhVdHWLLdxMNL9xZzFli/6rG43T/1TR5PH5IhyUS5qANycXImWUVPAnSQGzNv3hODD/2xqKJw9SiujeDvQahtJa/o3l94YiXZI/m/C2+FsBW1vSCR1FIKIC81chFHXZE56nwiRwTe/LOQoWPB3PNuNr2sBeGDJ7nBVukeB1mleqLhL/MGeYfbWCEJBA+ELI0H5B6WbnAAikk8n95EuxXjVgnaFNhGKSTq27gW2NYkVp/AhJMSKKluIHBi29I3DEBO0oVWGiNJGYGsQ6UHWLJXb9+LT/mTeyvcmMa7Zk7Y/LTZ4KVFhIJzt4ZvYuAz8nrxWiUItZUeJKuTtU5G+Ox6bnG5sK/ZV2tdemCIvJYyGp5P7mzK+dAaPZUYtkK2ITEnhJAROYHCNuRLXm8fROlpBlH9eg0OmJyr5RWntoQOIV1oJNhZRtBRoBfYrgOltxtK5BALuG4rPg0A6BLCCFF3azcOZyfxN377eY2zkwi1tfS+m80rTGvpeaPgdK6Q5ZFFaMSOBBsonPOYKLa9PkVH4sSMJLV136HhNaso4VZUOrlxi/xX50ySwdNc1CZXGAkTniwW8YPG4CCrYN3rf2D9R4UEH3/8cVy/fh2/+Iu/mF5bLpf42Mc+hve///0AgPe///1YLBb4+Mc/nj7zS7/0S/De473vfe9/zN3Zjd3Yjd3YjT9E48vOsNbrNZ555pn0388//zw+8YlP4OjoCI888gh+7Md+DH/lr/wVPPXUU4nWfuPGjdSr9fTTT+Nbv/Vb8UM/9EP4u3/372IcR/zIj/wIPvrRj37ZDMGuO8PQrlAkiigj0Mwm07aOZmHeuCSImnoqtErRtxFXYFfCUjfQkKKeaWAy42v8wuDb1MdixR342OH6LJIGxDE2ZBrznD0OCghVzGDoQoLVeZ4Kym4V9+H42MPkzH74u9/yrgYf/kCMgh5+PNbwnnn5nThvPgQA+Plf/hUAwJj9azx0LZ7W01fjNn7j02u8+Kr0eMT33vugxrf/kW
cBAPvU+/NDg/mj8Zhuvyd+7rMvaJzbGNIJfKbVCoc0cntgL0bdb74aUJ9T849pXNjLEsSUTRjhFQFa7B6YqtRW4T//zufiHN0VqxgNwwzAEh4NPmzhE3XxzwW1Ag6BIkYLaOoeSr8TtE9KAmIq5wF4NkdN5rKeHASNTB6APZAzeuwYVhejQiHnkCGgM4Dm/ghcpAygmTBbZpxhMClTMxKeWwOHGCVnWVRGURZQTuRbtr08IyPxdiFahvYCzMX2h8EDJOeIKkiWbZU/tJBIbA4n/XWMZZUCzkm2WBOOq+YKiqwG6QPzV0u0hHqXt55Fxkao83OhcWfomUIkx2Zl4Ej9L31cqweqQy6wZXJU3KbJpqJDtAspOzqKbZNoG4dNvPwShhcuMCeCZKjbpCtpjQYftqa3F7Kl5CCSPh/wxQiWcz5lWHJDzbTGmhCeZORBBWiiQSGxNC7CkszYe41AxZQ5kRpkKokkS6K4XAbcuhPhteN1hgW1NztCpEePFHj88Zgd3XoxZuyXLteop/F6NoTwFmfPQyPCzSvqRm66Q6zYP1bNt2LQQpwRpV5dqGSfUimB18vkCC7qM9WkSgjGig10TWvRF0NqoXg948t+YP3Wb/0Wvumbvin9t9SW/vSf/tP4+3//7+PHf/zHsdls8MM//MNYLBb42q/9Wvzcz/1c6sECgH/8j/8xfuRHfgQf+MAHUuPwT/3UT325u7Ibu7Ebu7Eb/wmNL/uB9Y3f+I2p+/v3G0op/MRP/AR+4id+4n/2M0dHR192k/DvN7qxiTUlAU/ZWj6dTBLYeb6IkcW4HqCYFXgKYbVhTFp4hjisGQNG6eJmIbic1yhJ6AjM0rrg0TGCXd+J2/3KKwpvf8tWDR0AAnQCtr0KGKi2sKKE+DL0yBkS6XXchxsP5rj+YNyHJ27EWOWPfmuPt78nHl9dfj0A4IHD74I3PwMAeN+TsahqC4N6KnTapwEAP/dvPof/w988jdt+IAYOf+KbPd72ZHytWMY6hFYGgTpfDz8Yj+0db/Zo2ZxsSKENakRF87wNFQ9UC2j/2nqgO84QZgTmD7c1A4lQsyRJDlxCpOJb6Vtcaygx3+QXtN5GozIuVq+2pawtldmoC2oFSYdwqzQh0bdWSOtHkziBXKGSjJqZuht9UiGQeoQeFAIp+pY1URM8REQckuVrDcXjK1nb67sCjrR2UwhlXyH3MaoNVDoBAMcDSIaP3sCydkWhAmRaI2cmOYrG5qhA3gdqmv9lGZCzXtUyS6qMgdLU4+yo7FEZ3I6lUCypnLE/qzDSYDAnYePqpcvoFy8CANandyGnPWcz7JB5DIHW7MwUNTQsa2ViezN0AT31AjOxBtLb7MjwWrcBANGCfWbvZ68EjMMXpUScsfgSs8GwJU5IVqPdBSWM1BG8VXq/aCWTmokvNBhvVTR4j3EGPjkeyCILadOyVE08uLjpUTIsQAlEIJlxrlPtchWTbyytRbn/FQCARycW90juOmPX/z6O8fADsc+iuUeVl6HDqotZu1axWb8aHkEY473AcU3M8ss4CfHeqSQzLgwMbYIcG+Td4JFTrcQKkgGbsmQvdbyQw/N8DaLqHjT6foD//2WG9b+qMTjUB3M42st6MeKdVKl/puBKaoYGG57IfiDUVxrklfhg6bTNCdvpg4p3geEkQ9XzRM0W8eNVwOI4bvsBPuz+zEdrvPmN8TdykcwpMjhClhk81lTP2CxF1aDCah1fe+vDEc/4oT8DfOVXxhuV2Dxcv+aTyobX8eG0aV5ANf0XAICHD+L2XNkCc7IXj+K+fPhDUzz7+Xgs4tL6oXdtQLNaZJ4SLHOPkYtOrDqmo0POYqm4tPoS8LwZr84pFLsocGOf9iNMpkM7ImfvWOBDIMuy5B0krsBu2BaAjWAJmU+YlfyuMluVgd93iYftXy9+PEZvH3gJ1tkW06X1RqstS8/xfCBTsHLhObn5bNUWxPXYAtGzRDYOwCBH6Eg4INlDFT4xBsM5P9
9ZZCLJk8gINjkOS+AVDBLOKTJLcA5ZTudiSvj0rYbvZF+pnDBxCebqyCoVGBKIsE58MSTih/iOeVjc2cSnrEiHXdrPodv4ncP9+N4VGLx8MxJmaigMnJyRMlZGK6wI063cFm6UB33PB9FqzNCT9DIxAnMCnoFiwVv9EAJmZMu5JRlrKyRCRLKsCOn/0poI2AYvUu83IWxJFwk6VMkGpkjPQZUeRAmx9EjrUogdbeOTsC51mGGykKDlLaS9bc7ywsLsPIoDvs/vqlwl4d/1Mn5u6QOyedyJxSsKJ6fxorxDSP6dh1NcmXOSpnGDp90nMa5iMKSLqG5Rz1foezJ2dXxIHZ89g42J1/OKgrhd3iPjmpLrtbYGE/qyLAkJD92IilC1b0nOUCVGzusEcc1MawPbNWS1CkXzDx478dvd2I3d2I3duC/GfZ1hPbJ3iMG16LP4dO7EfbcZYNh7cbAfu7Ufnj2K000kLdxdRH02Gxxyfk6K75lToHAFVtTfC8MSa9p8NCQwr1rgGtPiP/Md8fPf8iGFqhDYJoZGox3hGvZZHWTYZ+j8CKG+3/7UgDc9RnWMP7nidjwODoSKS/hEqQTXqCGqB4TNLYy0AAhzIU4AhtGzy6N524PXDP7rH4r7tbwZf/dq2QNrRpKcNwVAM6rVIn7aaZhcaKsxrMq8QsPvdozWYK9gCBESMkyTXOYRCIHJShusRsnQVJPb691WecBdULIQGEb6hYIKX8xgvxhAb+nGHimCDWr7bylWS+sScAFidFtKspMufRv/BwDVVNRIPfz4RUKnJiQcVEg5YcjhGF0q0q/1GACSAgLPkTYWjvNgqlggL2cTDEPsDdDsY3O2R8ksVbIDO241JL0IxTYDxA80oyJKoQ06wnmS4Q9tQFELdKT5G2PKPGWuLDx6Cv9qS0FnGzAw5bj+ZOzhO7n7AvxwzskMW3sPZoVTZdAS2fA8AUVWJEWEJoi+oIYRzBXS86MT7C8OvKa0mOzFY7p7k/s6vBamA9g3JS+F7Vsyh3LOvd++lrK0ENI8yLA+fEkflnfb7F0TAu1ai1JISoQEgwmCYqZ5ts5DGj8luw0OKJixC3Ts1DZjk+M4XwXcu/0JAMAnfkPjlFqjB2wleOzB6yio26cpdmnMy3jg+tcBAPb3vh0AsGn/eerTVEVksNxc/QaaLJ5PUepwZUA9jRd0OWGf1QA0FKdes+HMuIA5W4OwJ/IyBQId4Qf65JhyAp8rhNEDaPB6xi7D2o3d2I3d2I37YtzXGdaDj1zHF+59AWc+RhYLgvt953GQxcj/wcuxgflNTzyNTRsx3s+++FkAwK2Tm+jb+J0pWz7nZYGxF6/pGFGqskOYxAjAURNvbz3Bn/72uL3v+874+cOJZTUY0DTAU9YjtAydGoXLRzFCfOotcXtfv7T4cz8Wo6CveCpu7/J+BrtmoURU5Osx1UfsWcSZ90KDjuoM/iDuf1Z4OEYw2vC9Ebh2mXN2lWoPtkSo2DBKS3BYD9OzSE6b+1AElEeM8libsk2A6uIx3ziMx9P3ZzDkgEuN2WgPI7WdUbr4x6Tplsgo49Y0MRipGWFLiBATTrXNui6Uq75EwV1BJXM95/2Whnyh2Th9P/2GSgVzJ+Z+9VYxQ0wYQwgwUoMhuSE2f27rI3H/LSznwzjRywM8s56kC6g9LNeZD7EIXk5yrBHrlDUz3gyAcIXzIFl3SCmFNM9Op1sCiOY62NwGWn53wubu8hBoNqIKwaw20whybExRMqOxJmW+P2GGjxGPfPVb4lzRabS5eYxWmB86gC7uKJQQOzQ6xsfSaK+dh7dCKyclOyAZGkq2kpUBA1MSsbnfvwJQug4bloOd3ZIpVFoUKtWuXpNhubSr8TW3zbDkJPqAtGaELGEv6AbqC4vQcGGqiwtTMiIhUOgLFHe9bQKXLF4ckeCVsMah+bnOeniuT8u5WKwcnr8ZU/alnUATCXnTkzEVf/iqwfHxr8W56SPS1ISXoPd+M742xH
tQ3/02rlyNP37r5SibN2wseuoLIiziH2dh2aJj2TjcYcQ5a2ZLZsEzHTBl9isWO/2mSYofDXVhx3ODG0+8gdfJMV7PuK8fWHexws2wxHlGfxVCCFpr6JzECS6ktuuwWsUHW8+brdFZss6QHF1XBWpWOvNTSgjNjhDm8bv6ON5cvuu9Gb7vP+ODas6UuVHIaZKk5GHhgIrnHcHB2i0cAgBvfcMU3/g1h/FtqgMsb1tosvUKWp0Y77E65wVwOy7cg+kIzwXttAisjrBM4RWJInYcUc/4QCD2aXONQPdZuamYHjBrQlB8uJdXPXoj3eq8UXY5QBhgOhFPrQbGCtElzr3LXIIWtTCBqgAhDmmB8CxinwmQbhAOF9L/CxDMF3sRXRyipRvUhb6ZcOGBxve11snBVl14csmNzaktBCk3wKGVGzhQ84nc84bpXZaYe3ktcKJH4DF7kgjiNJNIIOw0HQAKLzvE8990LTJxCD6KN59wCtx6lkLIhNHqCsgnvPGxd9Arnaw1hiUf9GsDLSq6Ik8EoBDoi3RGc+GBFagN1K812nPCO+u4Do7e9iD0jQMAwKuvPA8AeGS2h8/cjUzPQTlokRliL93pyqGl+rCj5JbKahgGB4FklLbboGH/jyaj0SokBY5pRE2hNHASn+lboVu/VaG4CB55EY0NEnRs4Td5IAS3PdcXH2wCCfoL60lBjo2kCmxZwaIykudha0kiD58MX8IS1EEnlpxY+jirknOxGEp3PTDyWy3X0/HS4y4JGGrPoKFc2uOPxO/sV2foV1Eua+ziualzhcWd34nHomPJ4HCSo1/x/ufouXfqEaS/lcSegBEgMcTyPrLpOjC2xeUrMSouKo/jJlJLO8LAagT2aWtyWMSTuDn1mF+bYky2Rf/hsYMEd2M3dmM3duO+GPd1hvXp4xdxz6/Qkf48UqTswBUoqJ22Oon9Bb977xSn5/HfA2nc1V6FKTMiod+ejytoJU6szNHzLIlpvuF6fO2//I4OV9hsIpYMxSSgXcb3yzJGItUcAAVskfW494W4nTv8+83v2YNpGVmfkW6fOeSTmJbpIn53sVhgdSducy45Q+YSdDAONEM0AQW1xqxorZUKWtrMCf8UQQEUYPWEp2xlMZBIsuHfy0UOxbkp6MQbco2+5pyLuWM3QV7TcI/ZSG8ATSq+SEGqQqXisSbHV61DMiDcitYC5ovCqYsEC/l7sSieHId9eM12UgFe4EQbvgT+0Xr729nAfpGVSuddEb5U8wB2BkBafuAAO0ivmNCgbdIzFGUM4wwMM6wg4qZLDUcoZfIoRRi9h6b5ZnMSm258A2TsI5wTjrH9iJEkCr2J57Lpo01I3EHugHUXJlbWdsDIlGKyR7jN+y/RhTs9MVjR6frwOvueHrmM8z4iDpvTmObc0wpeeqQcoEh19szKrj38eKL3H5/EYv5YIaXFGbOvo3mOioK4gRqcQ+cTqeXgcozS797agP6pEAeh1xBwhMrut9R1mZZwAUcWGxXv3JZ0wWkL6kKGxdd0UIkEJPR2Fba9fj3JNHVpAFLxxe7F5FvFFBnKq0S6EEgw+C0akDhBHhj54opZztlGoecneuWhec+5cYVZY3OCkr1vmYn3B+sVOk5EVUeYELnFmsokL70SD3izcrj8CNVMJswADVDWsqqJGowZDvMDAMAbjt4U35p4fO7VKGKdsTUo5D2U3Av2qCp0GHBmX4UVyZXXMXYZ1m7sxm7sxm7cF+O+zrCafgltAjStt0s2wk4dUNHbwfYxQl32KzSWdvKVSN8r5JJlkD7auB4NM4rJZWryrRscMXv4Lz4So4G9iUtF0NRkWQIqE7096uANAzqC03tK4bCOEeKf/p6osv6m976afk+z/lEUAJw0LbOGtcqwzzpbfcBoKcO2QCwWFjmg6QQotQ6tQqI6CxafZQDYca7Z2Tgag5I01JCkIEpU+SIeJzH7MA4xcwTQMZu69/k1KjF3k4L7UYbiMs3iqjh/nQvQnG
vVS/jqtk3CEuWGL61XKXyRgAFixJX03i6IHLzGUO/CNoEYZUrdIP1GtqUaC/HDTBWEnaskttMOlELD/CF5r4BvRfFcajLbLEVC9xB8yqwUm3CLwiWCjihiuCKDVWwcJzvgeAV0rF1U87hTZWXgNvEcjueisGLSufOsMwRv4JLyOWurAHJ+LmeWvOlHeMu1ytrYYlniU3fjdVO94bH4d1bg5Wdi/cM38fq6qSw6ZuV7WYn1eVy/e9TEvHrlCnJG9q8ex1qctw49a71S071eecyYjUtpQ2dAfRD/LQr/Z8dRnSIeH/++pp08pPeEYJHU0KGS1qRkRs5+6doCLoheXMzInTSHsz57YTuCHsAo5NJtLLR2+fCF7fkxpIZh0Q9Evr2PXCR7dFzg5xs2WXc6aWL2fYOveJpzPY81Ij0s4MBslTVx7abYm8YJ1gUFBkyXyF1L3kND1mOfjgyzq7F14axZoGMm1vHkhKAQiBrcvhn1CH3mkMu541xlE4POsi5+EM+1DhkWq2M48/tM/P/MuK8fWIX30EqjZs/T4Sw+DI7GCnViYzGFNXliqjk+DPplC0uli5F5+4AAx2LvcYgT3KPFU2X83HvfEE/i3rUBgQ+70sUT32JA4N1a8YTmClAC0TUapYsX69ue/j0AgN+E5ISbSo8KUOmMk9WnRkxk26yfezOFJtPHtPFz1UGB9ZoEEUImZQY4yjV1wjEZPUqyzQp5sFkHzQoqUSeshxY191+RZDA2OQoW0zW9nG48kGE4IdGlY79WbtOFIt5QRuut+LDc0T0SY1Aek5rzcHFotb2A01TpLdnioijpRXuR7b1mq+ggNzF5Lutg0sO8n5DU8CBQso8sUA6neQUYGwYjvHFMih7WCSs0fs6uM4wbmSNhmWj08jCf8AGXG6j9eO7OCOEePeoSk3J8Lu7g1Ws1HKl3vY8nVgWFgQ/9hg/8PW1QUPpIZI8G55KKhgQEuQKy5JtGu51KwcoNnOyuF+843OojPP3YowfxtbsvoSSsrjm7p63HI0WcrPVxC0e5nts8/w8a4DoZE1em8Tq9t9kgkNAz4fX14DxHpuJkNyQPHVwH6v24vZeeJ0ttBDTloZJBmAHwRfe+gC9l/wEX1CqEGXpRreL3xZ0YnHqdCDOiwBK/R9KNyJ+oAE0o2IkklL/QA8h5Hm3YijwIm3H7z+Rp1vcKDU/iisFpZ4HS896jenzF2+LGJ1RYLrMBZzZCtqMc29ihXbMZkA/U8nJASxudV5+N5IxDXIenJ56QS6bZBEUdz/HAC2Id1mgI+x13kenXboCK0PMRg+eqMrjEB2nO3sK5n6E5XcFajy9Q7Pk/NHaQ4G7sxm7sxm7cF+O+zrAm9R5QKfSMYDrCHaehwczESECTDnu2djjdbGmtQNQ30xshFMTXlCkRRPuNNhMu30f9aEwvZg/F7e15hcAwXdJjfbYlCojNhIFBwbig9WNSftA10/Z8CU8B0Y42H0WrU+SvmEZP9xQ6ISkEiXiaLdxAdYnurE+6Z0YIEZVDRlLA3gX4JFDMT0sytwJczd4z9mhNVUBYx4jNNnQhHkag4VwS4lpsLMJZ3JBAOoUxyV4gsGdJB5/gguYml5+2+L3fiPPxrg+z9eB2QPna+m4kaUiIxQKucxlAJ2ZHuCJXCiO2QqcSCCdBUWg4FroLmSy9taso7rFofVdhxXmTunA+3VLXaeKKMNpkMEhRCBStw5447nG7YxhQyu8R3tN1jo6sgemcBfSXgCqtRzryNhuUkYeBipF7czPD2cvxu1ILH7MxXdXs3oAbNYpEM5f+OmDImFkxW94MDuxSwEBo68wF3Myj7c9/8Z4/AwD45L/7WXzslf8XgG2P2Y3rD0Jxv+4cvwLHgPna9fi7r6zv4Fr9bgDA44++FQBw8xP/Bg01CZ+6Hr/w+IMB+49T5YHFfj9UuPlivA5aimkoC7RNXL85CVbOBigRVpCerzEk2vVWtNsn6cdOzC69vkCm4Gs5kgajFY
hROWSybJl1GWToeW+R98rcbfvhRLnDhETiIWcFfpk6RBIJqfABihBpT5WcZmNxj5nky7dihtqdtVjwmB64DNwgHJBXcZJu3b2d+r4c7zEZiqRtWhSEh9uQjEfvngtSo6CZlhWI1+bBNEPrFwCAzTKer/2QoZP5om7owQSwFBC/w3JNZ6fIxriWD/fjetq7fAPDpAeGEaD49X9o7DKs3diN3diN3bgvxn2dYWFewUxK1NJNf0667DBiwwh6j5jrjXqCkhHDuonhzf50jskkFmjY34bbp+fYmx0AACpGw+2sw/IFURdgkVt5WFHFZmHeK8CItlsVo45N5zFhiFcr4B7rQM+9GiOUBx8NmFCzsKD2WK6yRJOWWkue55gdstZ0IUrvVzFKFgXvQl8QDk+q6CrZoYvlBMoALarqLAUo6MQRlkgRIzD0EfO2hL5LAA6i4hB/7PqYJaq7I/16tB5qHV/LhRGhC7TPxt8IxLIxtciZ+rUr8Z/PkVmRcxfjP5XID3Js3npoJY3NLHy7LfUYYZs5bdUPkLIeqUe4MaRal7TYAkDOc1tI82ehUm1tYMamPKDK+KLYaYTWoBtFM5EZ+AQwpB5L5un8ACUWNxdLeqzCtSvWbFaAO6M2oWVds23w0D4p6U4URzQGzpeYBGoNjKPQ7ll/c15EWZCRfl1ag5ZsgOw0vjbdfwBPPBXpz/+X//6/ie/Zk9RwCx1n65WbN7E+jRvMNJBzDeT7sWC/sRqLTaTt712KGcK169fxuc9GhffLj8Xr8PqVDoGKGe2K1+t5g4ZrT2qhblDIxAxRVEaUhmcqPPB4235bBxL0I88utDgI+cH5rXKFZLe4MC6QeOQ7NtHaHTTXglDYkWGrDi/nw8bam/w7/vBWtzPjGspmAWvWlBru/KA8GmZiZ+uIdLQuR6vitXTlskbGmvBicZrmav7/Ze9PYm5d0rNQ8ImIr1vd3+9+79Nm50ynTdrm+lL4cgFzy82tQoClwiVqQCNQDWDiARITJCQkhGCCmCAmFEgwq0JVroHBsq+BwsbYaafTTmd3mn12v/929etroqlBPG98a++TkCdLINUPK6Rz1r9X8zUR8UW8zfM+TxXHe0lPcr1ZQAvojJPbWY81r3HGiBRuzTHM4zhNDpmH7jKcvR9d8ClZT6ojncRtG5Kw6qxCPoxRmUakczqFgh74YjmN9zkaoes62PaTFw5f6w1r0a7w4N6tvsp/SXe7vkTG+qUij4vizckIh0zsb1bx4dib7GNAjYI1J99g7wCBq/WYGhovlgo38/h3t+KgHJXIiIBRgo4be1iinhx3gcIYaBKP+tpizE3iM0T0FGULSyp+ISh11sIysdolthuDUqQ6pAq+9mgXsnnxgSlzKIFvCDmrzeCE3og1VypHeiITOk4HaAl9kIKqvepA7AkGEpJUCh3DcJ5qxL5ZYpFILeMBJxYwDH041gtt2haKfT0mmMN1wBfflaQ20Z25xmrKhLxsrC706r0CqwgOSr0qTWC9TxuR99v1WfKN0DNqCHJMAUJ6knLpLiTKHZP37BdSN+K8kMsGVLwXXQnlDtBd8jcSGlQBEDqnFLrySQVYiHMVAM+/P3oW4ZgndxcYcWOpSI/jQo5NE7+3kk0Rqg83clNpuy7NUcV6pq4DSrnoEBehug0Y5gROcJ632mN6+U0AQFnGzeX+yQDPqYN2QYMp1wpl0oALKA9i30wJ7GltB9SR/WB8HO/pzr038fzDOEFuE1A0yR0Un6uGkj+zC492JfOc7C1th5wblufz4IJJNVcdx8hb9ErCyWBRPROKhKwdEjNF2os1UlhPnq+wRf+UNjGFRGqrZJ6Y8LHV1VnVs1k0/fES+paAjWxi4JWENGM/NzlwNYvHXnJt8IWGkfXk7bs4ZKi6sXFDK3OFERFawr5Wa5fglwJK9JnGxWbCz7kWjad4zrhly/XpePAmbpDq7mSflCODFZY6OgpXnOd118IM4jXsDYg6bjzAcGLHE1/WS1
xtGrguPZjfte1Cgru2a7u2a7t2Ldq19rBU6HBx9hKaVpdiCOZwsgffRCt0sbjidzfJPdZi9akW1rJOi9xpZpKh6aLFOd/EOISbe/ypP0uYJjne1tMaA0LOFdEBrXOplD1IvYvtUoW9MRqa9AjtGcNFJwEZU54VrSHfhVQj1TFpObvYII8GKjJhjdiq2C8krOd7TjwJJxp0vQUo+V/do8qllsObkCx7vxaoOxJjgqIc8KbeIMsO4nHGbwIApvo91Jexr8e0eJUFaOwlKP5EZwDDpb4lPNsr1HW05gQE0XUtTCkWtMRoHDy9R02yVAPdK8nSUHNWw/O9/5Q4dvZaSEhrwNBil9CQtSFZdEm12PkETZ6uBPzioSves1j4CwWsGRZDf60634oFIXIOGi1gGp5fh3Sgz/wgrdeXAe1Catr4/dymazGdeN0t8sR+QjBCB2T0eq1coFEI9Mo6etAuAzQBR54d9MGTS7TLOEGOxyzjUH19TVvzmrMcg0RX6KCIEJnTiu/CGqAFPl9FN2NycIIvfCpa7J85+QgAUNgGlrIsq1k8R70EbLMNIY/jYSXCwUGyrmf00FvhP4Gub1P7Jf6+LUkP+VyAClD9/EnClg7JIxbQhVLb9VfsS91fQ6oTawHH+/CtRBR670KAPbVVPRiIfbpsDU7P44NfkzyxLS1yPiM39kfIKDW0ErJOC8yXwrPK61LRGwZ68FbQDqfnMYSXcR6bLsN8Fp/J+RWPd2MfdyjXNM7i6+XSYiGhfc7zTbdIYz0ohbXEYsU5WpJd6GqzwNJ235Pi8M7D2rVd27Vd27Vr0a61h7Ws5zhdTKFZzX3nJCZ4h1UBKxIRNFQ2AwUnchas3vTNJrErrwkZLaoSBZnez1mt/4c/X+MP/8FoFeS0kJQJCEoSmQQgDIABzT2z6asQA60lqzQCYa1g7N/lOskoNFfRclpfdanAs4CIAHp4wqgd3aqsColxgjXMsJ1NjA9ijvgKAMX/tBQBZ+grardi8prFpB3zacb3ll/GeLlCQED0jtY1ccZbzPdK0AMTDTMSTDdh8LMOrDmFoWEWtEFLq89SVt6YFoV4A/TwtDapqj5Y4fjrPaskWR409BasXdIQ0jEhqJTfcVus7WLpJXtPI/1YWBSU6hki7twhU8C6g6UkjfDaFU4nmHdH9nG30EmYLxB8ER1FegXCXh+QPAVDEMpB5eFZYdoxj1dvAEXQSyEeZwh9TkdAAcgxu+IcYxL/5u0Ktus9XADQxqBlsXFDq/m8MdC8qUrFfO9quYZjrnFMt6pda6zbOLCT470kTVHTI/YosSHTzHQVQQHjw318+u0Icf7crTi5c7/AmnnZxSXvo0HKc3jRqNApCJHmftii5k+PgEZfEJzcpR68Ebby/eH119B7WDLHMqVSwbCSvLAPyStLr1n/WAnQwnYBXtYPiX4ElfLG4mg0KwdD1pAN3eWLTY4zPmprKRx2Hu/cjzd/NCmTGGaqqQkKHcdbVB2CVqlAPjcyTzweP2b0yceoUoURrOGDyjqVTbvEJde6Kxasz9c1qtvR2zo85HUtOnSQ9aNHPHl6oTWxAZebJapR9TGCgP9cu9YblvEaIXisOIlPN/FBqIOG4XTJGYtSukgURIF1EdPVAjVnrhnGcJf3AMuTUHLR+D//r8CQ1fc1ByobBOQkDWVkC2hdUsw1nOnZ0KBlrVRdtyi5iI24ibWLgA1RePVC0GJb1DEiQxEUcin2kXN4m74niskWPYhCwhPeAEby6wOpNdkin5UFX3m4K4ZKhOppqGGI4Asla36GBVYk6r26/FZ8r1Op/kokL1AaeAnr8ak1zotCgZRSwTiPPdLE1J2EuzJ41tLZLj4QGc6QDWKYwqXFR6UVJmkhqX5BQuj7wydkXkDOuE4K12whB8XKUap/uLdDh5YrSztd8z5UCs0E0UCyDoo3mtGIcdajJmqSChsoJjlSeDAp3QKBJ+y0bLwKNUEIDdkNVm0Nw5WvIgAkCwUsmS
tU6OlzLs7j39w7cevmEJmKE1fkVNzaoeGK4Aiqueoq5ER/nT5lTNoFDFivJ5pbOTxqkdhQFQL16Do747WM0dio9L1uZjzfFG8NSclEpJltA9YMpbYiQuvwMVSnDkjsHUlRRPV/b49beM0QUXrLuEnvKSg+pxJmVa9Qq9CA0L0RLGCO4AFtZNOXeeDhiOJ4BYkqennJWAz9tUoNZxsSAnlNxNHlusSCG75NFlqGtw/JaqFb1BsyU7QiB2JS/FLAKK6LRjIAGM6P1bLAdBYXiEzmYrAwOevcWLflNhvMidasL7jZKY3JXa6xrAXNdcCaOn2ytpR5jop1Yles4WpboCz6lMQnabuQ4K7t2q7t2q5di3atPaxi4THUKiX0uzpaIJsAFGLVcvvOOo8xIeyC+283DpstAksAKKFRU8L0j34+WgTftzdHSWtpIUVXwaIV61ZAELXqD8RQnRt2KWFfaSCnGeUaYYoIUPR0RHUXOiSrUDw2pQEwjCQhLpUD7rWksC76JHSyDVsF1RH6TyvStT5V2EuII+jeOhMxxhAylIexD1cMBxTuAZzUoOFhvLccycuQWhg/t9Dk3RO1XOU1ciFlZcTBKwBk9KiIlq3bAuuLLwIALj/8NADg4O7/G0eDOMYS9oLXW1IS8tol61sBySQWcT/vkGb+duhoG/UORE9MrFmBUOe5ShISInDnfECWIOk6XYsX4k961XmukdElDiLLXBhsxJulpZkhwDDsqIaEeF8GZCINIrV5KvSEv/TEIu7oVdCNsxYTqkYPJPzklsnDFfJVn2XIeb9Lhq7OFr1MjSUmW6kKfvPqPK4Kj5YyOvn4BhYMIzrW9SgdEvabArWo1zOEEUOaJL6s1wYs00nhs+RBYRv8AmQsemoY4s8yjSBEw+n7vZes+oMgvEaca0wve5NKHkJI4o9JesR8PIKldTw30HtYnQ/o+Fx3nC+2Qx/324bay7GFctIYUGEFVxzX83lIdVji4Y2HGm8fRw9LhTVaFmNKdKd2LjGvpOY1gu/BOAAwvRigbmPOwLNDWrxEqPiAcl3SWYXhOC5sE6Yx8sKjZB3pnBEHVQMF+V1Bvs1inGHE2izREg3ZAn7m0lr3SdrOw9q1Xdu1Xdu1a9GutYflFwEnd/exYuy0I7FdZhUci+NamqPDQYkBuQFnJYt7B4OeU4xxf9M5PD+L1sEP3++9GzGIRBhSZ0DLnFNJGoTcBDgJTgsRnvPJu4BiQSEAxQJet+mZxUUlvtuS6zbovSqThkuupo8RayselO/ZpgWmrVWPa6W1pywSflsAACoDskNhXGfupHSwvGZH1IUqhpCb0rT6vXHwwjxAK9M2HiD4JEkxZBmMyGzz3jr0VnQl3VZnePJrrNJ/TEDMjxvs07LzwrThkVjYhXlAd0gAi8aGJGU/qmSc0EtEbEukp06NLyH9b9t763NmJkGSdWLaz614HgaBLNYZC77VWIEK5OnJ81OLeho/F07MvT1gKP1FLy43Gq6WZH+8oaNDoGXOT5jhu+Bi7gWAJwuFh8XBPq9BmP67TWKOb2lxFxro1vF8V/Sw120DtxZhRh7XdbAiBCrAjryEZplHORjhjKR/nhEFo7uISgKwJrBn4BwOTMxrGbrbtjaJb0+iBj70Y9PnHFPpePI4nAsJICQy9m7LgxGmi5jXgvw6nsP335MSEPiYd2Rn8lo8Ag+kE8LCJU5PCX94p9AyIiE4kdD1uSuVrnPLY9ty8fOSbCZk3j+bd2gJhe+4Bh0fZTgeMV/V2eS8Ja3OLiTYvuTigkIqByh5/R+d5zhvY3Kq8Wseb46CBJWHKoY9jKlAtZgkpLk/GkATGKKvJO/ZQI/5LLbyTAaojfBeEnh0UGCxWiYBy0/SrvWGpco9lIN9LNekIiFsqBwO0BABIyQDtijB8g40MrlGDhOOQEm2jNI3+Ma34xdPcsJyhqqXZ+Bge2UTkq9HYwUIi6QRTSUf0gPVQKUNS5LuutCwXBBEcVR7pAVXQB
UFkOISoiwasjyhtYyo+OoMXja5Up5A1bNeSPgxV4mVAfyezwKUcBCVpOspPAKLt/o6sQsEZsQzSWiHngIn40TXhUorjE7MAzYpjwqaKdQZ1FDUiuN4zF/cwuOPuIitvwoA2Ftb3JQNH2ydQytSDbzHyvWRl+kCmM3iP27fjufbHyORwMrgWLuFumBfKtXX3EhzLqT6GgnhhVwnBmGVQpEGFHuFp1xNqAC1xwMy2tJceGRCsTCW8QC8Fz01zuPgE/hL6vpyZFB8+Ft+b7kEskw2Y16filRGAFBwLB0cOq6ajveb1R0aUjyReQfe1XAiK0NjxwWfFlyhNFt6j3ISb7i1Gzgv84y1WxpouSI3DOxMnMcNkAqFmnCbDeAFESibE3qkamo6JMqjFMLzIYEL0oblfD9XBNRg+/orGXOrerBFLgZLQAoJKm68yvkEHdRbLChedNAE6Wf7zWKbhkm9tqmo/hJS88FDwMTTdezTZb1OqGTBXh0d5CiJ5LOtheXCn0AMDmkd6fhA+MIhkw2cIdyHlw4XzRXvPb5mCAmwNiG9XVUMMJ/FjW1Kg9WZQwwGMUwYeLwSQ+wTDCJyOxcvLzBdx2v1YiAPTWTmaQO2jfD/XNuFBHdt13Zt13btWrRr7WENyn3kfgS9iru+0OFn+SBxv4ksaZsbvKTpYbxAy2sclhHOWQ6iRbDuGrxzKyaADw6iFVHovkJd0SI3BskVluRl6HxvxdEVNmUPKVchIBD8oMdU3VwFdJIoTiQEqq/KT/UfgOO9iC3ivII1Ev7hvRsPQ/4uTS8pbHyyyiQkFEoFS0/Hi9eiAWVFqbl36QXqiln0qtrLOpHaTjiDOm0SyEDieybLU71UT1bbC1VWNEcL7aAP4wWuaVH+7m+N8YxyIRtyle0vGrzJujkaabBdHwaS0GBQOnlby7XDOdHYBav4ByOdvKQEiVeqT+iL4a56fkG9ZdqJ5W+25GMTT6EMYtZCUUpGSa3cXs840a4E6OBQ0usasy4gQMEyXiceVghASdJSimjj7KlP12xY+zYamxQNkJAU0NcH2cQKsu2N0EO1KoEWnpzHvreN78Nitg+9ZlI3Jx5F0+HdTx0DAKa+S6AYEVDNfEBDNhnh5xvYGodWIPXxOOuVT2ARaQEfc0IA37+3XYfX102F9FliNWF0wbX9vadhswH5NusxIghCgDoZXaJcAUGiLJwUZZHB0pUQnklr+/q2JGuydX3fMQiW5lWGFR/UpyvxkDJ0JcPwnPw3BzaxqNi2ThGaBNnXKq1Dcr/BBnAK4iXJip9fdXApnBvPW6FAyaCrkOSOqhI1IemXG77qFTJO8IzXXA4HmJgYRlwQTLXUDisZZE7QolOwtk2gn0/Sdh7Wru3aru3arl2Ldq09rLzR2A8jKBwCACzjqgfFHgphi+5ijPwKG8w2MYY6pgk9MYPE33ZAJoj5oz38H38selj37ggpm02x2CwViPbwcUNPK/iQZBy6GX87DChGTHIC6GqRV48WiF1fJFBG8qZMH4NXKeu/lVxmdacZOGipVmeOIrtt0JKv7smT+P13b4XEBuAkr5KHVKgoxzWIBcpA72W4GgChyyS3gK5d4tvLcikfUMnzE0/Lt92WpxY/K6BQMNbtLZk9vMOTrz8AAMyXBwCA3/76EuWQeQ0mA+ZtQMvjZZJjsX0+QNygNigEgeIahUCfriZQoG104tuTH78OVQZeZTrYzj30kGmygrRt4miUFGBW+QQ5NweEdh+WiWlCbmBwG7BkYbe0NPOgUr5HKOVa1zOcCIN8rn3yJCrK0w/HOZYrih1SVFArkyINwtofAGQ6TgaBq0dQRTzeR+e8po1LnonArn0AvABrhLuvBAb78SE6ny7hmLiRQId3BmB/BXb+gdpgn3mXVvJknd7CsUtOEVsVvvxkq9JbPCzrkSDbyUvzSGOTZEHarQNJvlL3XigCgUTOJa9H5HaqQifpEhEEjc8R+yjxC/Zet6wd3ve3lIqPt5yLxN
qvM5yTVeQFhRAbt4AdxPEqOHcO8yZhPby16Vrl3rVS6Ni/0h25Uhgyd/3vv0Zv+qrF8ChGmIRUwYYWgzqeTzH579Ytxq2sPXH9mlYG4sgn9g6vsCIP6+Ui4gvqbpWYWhyxBs7GIvzvpXD4Wm9Yg3GGcVUgn0SX1HZxIZwMRzCUbG3BMNbaYUkiSOGGVFmBWmiECCUrrML//gdiDx4dxoFtVzEJCUQGAwDwbf93v/gBmuFGQ4VPF1xfX9EA7pSJZyIMtQ8oZBInfzf0DKxsDj0AI2NYMQsuPRRSo+E6hRVd+DlrIdRekyaLkJ9qrZFzciqZ1AubyHa9oAk7laRBpA7IFCptImuhXHI+LQw6yEYftuhi+nvRjKFKidY3vnUDjx7HHx+R5uXo3VN89E3KrRC51KocWpLfkkT2qq+v4nsb61NnjUfAjRvxPIWoMzdIsT6R21CmR6LprY1cpQWSr1v30YnOUQmoxHDAzWToe4YNI4CTLrGZFFzEiiHgpY5PNnelUXH8u8fxM6MVrKAteE17Y4OOuhHLZVzElusWXsA5UmsWHBopqRHqsDzACnwtRXIVRBzt6bTf3LNUi9SDHByNEgF4TfYqzBvWUtkandSg8TfKFEhMruzEW7nFgLuIzN+6C8heK5zSGglKm8ZDhY/RJsWFWq6VfRl640+IqWEDtOyk0kcmbCFGeQ7b943QcWWVhs8ktMln2QHK9EarnDdRfW3XXCm5ta1woezLErrUGqerOCkuNnweiw4d4qZyOIzvHYw8HIlwneuPI8NkEdAJk1UuqQyDNa/r2y/iNb+48rhVcG3kOhgGQE7BvHFObavTc2R8nkeUpmmVRkZ1ZFn7lt0KS4KyPI2hozyD4rogwB6rA6wBvA64SlDq/3zbhQR3bdd2bdd27Vq0a+1h+WOD9aiGEmFGelDnsxnqOavyGVPZz3MchRg2WRDCvmrmsAxPXZ1Hq+P/8AdneHtCgTxa0LYDqAcJ63r4ci/0JyECjYxY4oyhn9p6tETHm0tARw8ZBZVCddGHGySE1G0llF9XRo3vxd82q16lFoROX106mOoAAPDZzzPkY58nktcUbWkDLH15TwBFPQu9YKQYoEolVotKEs+Z6ssFJFHskKxR4U4zBigyITQUHHSAlfovwrj/t19a4Ud/NPbXW194BgB4svL46JReEr/XuoAsEa7R8gwmWfuJ2cEHeFF5HgH5TfahWJsBqa4m8EdB9Qn4FK4B0kBIJG/b69L8xWCcITDmOl/FDhxkQEGctBAsq7bDUPqXHdhdemR32F83GV5tFYyP4bX5jCCTPYe8YuxEyiRWLhEvr7ZYENLY0WRXSqd4jTB3+KZPxJut8POqIfEzXZ7GaxhOQgEWaK1TzZswdhTjAU6nkcB24wxU0m2m96vQe12M/93d75AJKErkUba8Gmla9yFZgZ57bNVkhf41/S2lJmoL/LDF4ye8jYkIwmxZ70If47ejJ+znzqYfadYdto1Hjq2LQIzAphKHrfsxycPqm/wt7CZTU+Ah5YeWhIKrMgB1/Pv+Ufze3kCjrbmOoF8/UojU9V5oGmMHPD6Nc+uKckGlmSOQ3SdyWwLGFTg+jKmWnOO+1pskBBpIYOutRsn6Ri9k4Ospaj5sd/ZixORAjdAtWS/LcEtXaCx9Ddd6PMIpPknbeVi7tmu7tmu7di3atfaw5mEKX29Q0eyaswjxfDXHmhb27RsnAIDB6BAZ47itoTDj8iqZaeeruHe/ca/BfslEJxPBRiEVBqpkIhlkwtCcCvpUSm6LWaUNACYqg/XIDHnPxszPQPWJ2C1WixTGF45Ax9wLEmEA2hyoWXk+jsoqKPIhGjIhqDpavDaE5KlVwrHYAmsKs9VzOZdBxT4S63vTqcQNaHhPtvGwZFgwzINp41LuxBGK7VQUrQS22KkDEBjLHpUx91iVe3jj+/6H+PfBbwMAHn6zhjXx+h0BBZlZIaONJbH5xvaFtGJ9FVkv6QAPlFIgLRlq12
d5X2Fjfx15sQWwSG+FreQ4+17nAdN1vMHFPP7ieKQAgRnTKh3kSN6KY4bfXoXE8O54I7OZgyHsPadk/flZjZz3MWZhdmYsJmQ5Lyf02GzAZkFmdkKiyzxDIXkyuoptB5QmWtBNGyeAUQFXvP4Fn5/WA2XO8RSxSK9SjmhAizwvMyzIFt6FHHkRcxxJDNN1aX6c0BW4XTgEasdQnQUmKRIgScn4rYLbPt/WexLSvH8VpAREIIVA9RMYIWw9a/ytRu+VycmyLAKg4hcYxUHkkwT6qIt3fdG62TpHUlyQ+9FbHoJ4gK6Ponjmgp7aMR5SfqglMKkNGidVfO+zR3HiFV2LThQB8q3oiUQPXA96Ea/ReIfLdQRMeBOBHXdHAUNWhx+W0SPqQodCctvrOD+GeYGDw/i5K7mmNSts+Lw37GDnAxxFPxfLmMvq2hZuTWIHlvSEYcBmfQn33wvTBUKHzXoDywdgTZdZ6ZDQVZLLa2wOXZFaKEzj61IhI32OZ01Ss+mSxtSaD7yGSkwCJb/nupDqKwq+1ynbo+ckWTs0MCMyRMw3aRP0PK/KslSxLw+l0QpaBlHYL1rAtgRRMOk6uO9R3iCCR7S+6gWGWyAPALCZjjRJ2Np4rQK5MhNCKzcFLJ+8+To+HOump0uiEkTsUyFqZdjJeptABmkRCEjnFU0rOCDjBh+mcYH7zA87/PyvR0jjprkHAPi9r/8HHN6IJzSj+L3haABNg6BhKKTdUliW2iAFoOKmboNGIDRSMQyLzKWok9ANIegUntqulUngtC0kZVoDpNZr4zBfym/FoNFoiEpMUSWjsZJ+oGRHleXwUypgk+2h3ASMSACxmFD1dQ5kZAFRDCdmeolRXHtAKSrUa6QFKxf0oXUI3G2EmaJtgc5cAQA0aVc0HKYMD28EVIHeeJG9xDufJFYmVJT1CNAEHvg2oEj8QAwjBoOcn98is+txvkBHy6IW+iyjoDOdjsmD4LWIG7z/eE1TDAmGV773yhdkE1M6sYBkiStpKyQMoZMCMmGQ4co/GBoU5atEtwEOzUxgtVtgDwkdy5RQfR3nNiOv5nsNUUsfrXqqJNF1W60V3nkr/uTNfa5LbT8HldFpjZLzea/S+TrW9ZVGYU6tuxkVncOyxh7ZLPZCnFB6WCcat4M2Gh9BGxQuLo7LGRG85x2uaEnnJ/G379z9LDaz+My+/DCG+s6v1ig4/uOjeLxirNCpNdGx1C35Lm0XEty1Xdu1Xdu1a9GutYe1yjrANQhMam8IH16GgNGIdTg54Zr1DItVtOIXhGljNESTkVjVUlG48FASCpTEMkICb4jlptWW1Qd5D8h8tFRMLZBhBTng+gwYMcYgFmrmMljS8zvJU1sAG8JkaaVDB4QHJOiNTgjaKoNhiJERHyxrIBzEv9dUjS1XXWKGEF64dhNgWBNWpHqWGgt6ceJ5TIqAMUNRGZO0tbcplGMZsrCqD5WJ1Wddz9nm6e0NVV/HYm7Gc33lNzo8uvz1eOx1DOEWE4UV0ShChqrCEkS4w1OjQDkPj3hdM5IgGyhoWrqjsUbTxZvOmCR3dpsVgx0HB5vE+mL7Tgq23vR1UG65z3u/wtCIUGicT48fBpSUzrg7EUsWqDgQKaSpux6yTe/HKaC9Q3JfnvfNz/bzaMUQ7tVlhpdnsY+ynHMj8zCcDUnEEi18LmEbWtzZCK6lt0q3/0Gl8XjNPhf+SG8TM4lnrZ/XALsclwwDlsoi43iVW/0v0iS6KmBJ3vv2zfibm/kMyyXLP4SxQXUJ2LRdZmASpx/vaAtQIK/WauTCo8ieU00Ow+creabKo+RaIXyPSw2oQEYHRA9A+SkYpUveVJEbEO2dmGeGQ5VIfi3D7N1Gw5MIOScRd6WBNefJkmKoI69RddFreZzFEO1XZgH1PCK1MpZd7A0dbk0Y4uV8DkonQIe1Q9R8FhULm4pWpXq5huvXdDNG08bz/YGDeHOn7gFqFTtiQcYLVQMtw8OW5zs4HKPj35
eUEnnybI4Z6/1urQ4AADf376Iu47pbMQw/qFY4ohTK5E48fzP2kfiz8dh5WLu2a7u2a7v231S71h7WUXWIcthbX2cXETPuluvEV1ZLfqk+w3pD2WZEWOfBvsFaxZ19UkYL42RgkEl2WSQscg1DyygIO3rmk8Uu8WhXB2gdLQspQ3UOsDFVEAuEm2jBbGipaNTIy63PAYQuwDJ7G44ZTz8q4W/G6/f0Hv26gyZEX2RNxkclXvwuLS3KVK/3PfZo7SuCTJwFFK3MiuwdjQsIjEePWHg9GSrkIZqUC7IfrOs23bvkArwNfWW/FGBqlQquy0MmmesC2V7sw1/+19HSPp0BNfu8KwgAGAxwsYif37kTx6tSFi3zlJZWa2OBDd2VaYK3Bxwwt1NV2xa4QHIBI1IoSrzL+Lt4kfFFqyjsB/T5CIUtLjkyWztrcPcu3SN6PPVlBkhFPzkCN1ohp/spDCXe+5SbDHwvyx2aOfNVA3pOOsOasGDlWKipAkSF0yWwgUKQgmDmX4oScPQaaoI5Zssljk7oAXiBrXucLeK8tOyEPDfpupwU5bYBJV0PKYrOS4uMbDG+0Vg7PgcUWSyDxfE4Wthvj17Ee7ItmloEUfnb4JM3IwNnHVKeUvKkIfS/kZZlfWpIHs08739kIaUkGmoseah4/TeygLolAIRSQ3kJFHw2HOH+DRpYArQELJNpn3K0kqdzxkNo2teEoXczjz1GBupVnBuVyXA5jn301SbOk8uXj+AJYGAaCTeGwNFIcntCOqAQWDjuNkssuI7sdzJOGab87pjSQIuLQyxnbwIAHtw/iJ+NznBVT+PnNnp2l5erRCywIcBitVlhMohe4OT2HQDAW4M9fPtZrG5/HmIE6/z901TgPyK/YDXSWFLjaD0VTtcMvkPiQPwk7VpvWLnN0PkuAQXuDCP55u18H4/PIsLs7IrZa6NQcnUd5rGHuukxPOFJI+oA7euQGAUaWbhaD8u6lIEs5FUGH6SGhwuDUVB0gTu64EYprC7ipDluFVoeW/PhNrlKcQ6KtEIZhewwDo3na3cY2QnihTHcqSw8NSe4PsD4BlOGjG5ws8tUwGYh5+AD2AUMuPFV8qo7lPy8IgQqtMDLJ9H9FzThqDKJFHZDaRSt+41Kk4LH1QGeD63IqQR0aFgP9+Xfije8si1WVKYNDJX4tUbDJPLekKhCpbCZxhvd8EluAtAy9CLft53CeMBFdit0lEBloQ+1he1V77Xw33aSXCup2/KpzkwQnGU1RFMwTEOr4/jtCi4TlgKGTZEhFzYQfk91BRSNoI6L1KQawhKUsSDzSAj9WLcEsgzHBpPj2LGdoPo2BvUi/p3R6PEALtlvc4Zvygq4dSsuYs/OCDxywAvWfVnqMLXQCa2X+rEDTm7Ehas85P0MW1ipD7QWDRkYsiyGTTttcJ+UYW8yTF83QEu9tJzPZqdNMnLkhD70Y7cdEkooQWE/CRUsGRZKCeHnHisaqhV1yR58ZoTqhIjMF/F79sJBc+6PJhzYCghOgFUM2zugUgQDpTlRJ4CIMOx00LA0TtaUt6lqYJQJ2IqyIeUBPiqiovav/j4N7qZBRgLrjqHcm3saNxiKTDIi1iFjOHGBJfYIS9zw+dMuYEDEVE2E3n982OCZncb+G/HZDQ0M162KEOTx3j4G+3FunexHCPLDl49xxrVqRAN5dHOEO6z7fMoaxGU9E0wRWhrUxpQY5PFaO87ffJEByiVU8Sdpu5Dgru3aru3arl2Ldq09rE0TsFgv0c6i6f/5W5FA9d3bbwEkiH3x5CEAwA8BTyCB1CJt2ktsTqPl8YOfi1bT7WObYPJC7DkcZDDhVfLQxgVUldQiMbygFCyTviLjoZsCgTIZwecAvbsBJSV0liX13BT2KBQME6yK1p7PLAwZJzoiJ2wG+A0twE28/sW0xTFDPeUhrcILwG5om5CH0HUBK1o6jiCIci8gs9GMm57TWpr6BGHP6W20mwArApWcQa5TMGLV0v
RVSqWSp5qObrXv8JWvxjfP5hyHDmhY6zUYxWMsFi3ygp5EFu93v8iThyhe1aJ2qFvhSWT4scxRpPiq6iHpyTxT0KzdE5aM79SCQ88QIhX+ISTLXrjf5vMl5kX0JEYHsU7lxn2NbspxuiBoxNokwphKGMqQCJHlWpbLGi0BCkbqp1SvqCwDYnKLjGy7m0Zg63kicQ2M0TZWpZsfjU06xsun9Kzov3itcc7ct/DgOb1VJ8jnIssD9vZJxErvd9NZ1AxZRg8vTowh5SpqZVASwu6X8YDzNpZDAEDGidR6hU6iEBLh2GK6EHaLGMLlewIlh00erNR81U1AMPHZOLzP5/nBGoPb8ccnn4738dFXclx9k55YE/vq6KCAJ2NDzeermGxQlRxXehHWajT0erIhz7FnMKcXclLFPricKXyDEjeHJ58DALwM9/EffzdGg5rVBwCA3KjE+TngGnNrrDDgGsSoOEIHqBDD5nmZwzI07lop3ixg6V1O+bz8R7/C/o2o8vxhHQf7VnaIvTLSwRh61k1nUXMOPiUM/mytYUfx5CWZsMedQ07x0gnrK40poVgreFFHj9fComUIabaM5y1q4GR/nJ6xT9J2Htau7dqu7dquXYt2rT2sfDyAtwtoWjDLddy514sFNK25gRTlBZPYmF0n1bseJaJH8T98Ke7+JwcNVkvu48yDDEcDtJKQZc5LG5VYL5Lceg50Ys3TkF2degxEhgAKBWP+nllhn3XQLERU9NS6wsMVkmyP34//JCyXXlzoAogVSTBdvzIY7MVhvbwUDjuk6k/hUNO+J3yY0UsL5wr1RkTWeB9Kp7+FkVpnvRclBataI3mmYgaXWV9I65ijchnwq/+Bwow0rDrbgyDEqmtahwn7Za8iK75zWC3j53N6WrONSqwFJJXG/l5IoBBr3ZaMSg9bFxkYMUYz08PyE3+g28qTiAyJ6nM5xH1gbwToPPbb4uyHAQCPn5Uo6lg0OSbjdlb0pPWSJ62tTXQKcrwsy9EJeSE9QWUcWrIalAVzhM5jcUnplRmfgVmD/ejsoaDHu954ZCU9BPaBbV1iIRGRSOczXLHkQ1j2NbLELC7S9dVIIRgy6NfxWnJdYmwO4nGwRmACuBAeRQ08Jr/gB5dx0t63GrakF0iIuPYZxA0NqahXp/krXtV24XBirTBtLytixNo3sHz+ag7w8yvAEORz5+04F+//TxUa5ohf/m48xAlGaChBn92Ma8awNDD0Bg9IROBshdk0jsOSue7JUYnDz7DEZRxfv/G/fQa/8d5zAMBbx58CADw7n+HDRxGEMuFYTz3gOP63h/FcJ1WfM3WyyLiAkp7YCgGs6kBGj7MLLuUG/+0Fi8P3DVAIhJxMM6MSN5n/F3HbzfIRPvhGBBVdTeO360GN0RuMJnHdWbdrTHw89n4T73PQ5hLISc9ZrRQsGVgcAScYeVRmnBhUPkm71htW42sEbTEYxYdxtY4P0fsvP8KcRLhjqrRmOiAn0a1Q5Jusw/Q8Hqt+GTvxYJBhTloUw35t5k0ipizlga99ItXM0ubTJ/DVKp53/VGLIy7aoexAwF3SddIVYIyEETkRy5CkOgR5Z1okf9hzk/AzIHAhWguJaOZQcREQxg6rFbKSISOCL1znemJdzq56naXNRlRaFTQCZEEThKFP15LqmTwS4ays8bkDhCU3G8bvPXzP4IqLq1BIrdcKfsAEtbBuVAUc60lyYTxtgdllfO+SC5jTwD7DiCUJgDNlk4zGNpDMJx4mlZAVadHbVhz+jhtW/yqLIvdWHDjgQDbA2a8CAL71H4HDLK4g++8IxDDHhuwCnZCkjjUCV6LBkIuK9hjtc0OjFtFm0QMTkDYQjYb2hVBRVUMDw7Cz4TwovMH5RZzTsoDs75meYYVzetMozBhaFCkRD59khYUZpTgqsGI4qePieLC3h4rXb8MQ3UJCqDRegsIL1vr8Th0ti3E3x5B91LLTixBS6DHQ0nBe9RuVbEjYGpLEsKJTYR
0BmlisOuy/EX/EPQKDG0DN55PYLLTO4canaVyJMaZb5MOIruyK2NFLFVKdoWHNUnVgkBHgNLaU9lkVqL2gQ2O//OrXJ5idkyFiegEAeHb+FJYxvprmTJOFxCrz5l7sl8MsJANTwrUKCk7SCLbDhpvNhM9Gnms8XcZzfziLVszNgwYdGXwkfHk5f4qc9Fw3JjGkPd5zKEsaYXUM/61nC2jGqCdFHENbOaxo8Cjq5VWqxJDyzRLedZsagZR3B8N4jKz08N71z+UnaLuQ4K7t2q7t2q5di3atPazp1RlGgwyrRbQENgQRrFHDMYTT0erQNsDQRBHZh6bReOetaMrcO2E8aeWR57GeIHlY0zZZcTQcYArAiaVDD0tlQM4kuL9g/cnc9TVc4xat0NnR2kTV9JYirSXje/42YYpAq9N5nIQG9yzWlwyVUBgwIGBj4/eGN8mqMQwALfXFTNgNYnIXACzEw/OpfqWXCrF9aEY8vD4Pj8A3Mw94Jp4lHGptgPHsh/14I1/+BWBNr2tGIIZDBstijNoLw4ZK8g2uo0e2dNiw/KDmNWXDPvTFyDCyEHqJkywkSQqJsnkH5FJ2kOqsehiyeF1uy/ATj837gE7GvYwnbDuLeUQkY0wo+QkCsivOswU9BuMSaMfxogwAQ1DA7DJ+78XzGm+9M+A5yNLRqSSP0RJk4pxGQf2L/IhWdeWTB+AYrtusC6ypPzLgNDe5T/UvtZCc1g5rRiFkUnrlU39kDF1qXWK+ljqKOKHLME0chrYqABPrNjb03soiR0X36JsMHb7tp/g8QQEbEhWG3MMIaCdNux7okuQytuegjJcOMPTKZIzWtcenHsRjv/EFhgZh4J+yHvI09u9iCbSiRP5O/Oz4dk/y/PRb8RjLl4Aj6EWXUwDAy87ByzNHluHVwqLmenPFa/8Pv3GBo9vx84NFLBSc2SXqgmFTelqtAm5NyJIiauU+oJb6qjS3DdbspGLj03xthTlFB/zr346f3yarRRUU1vToOxW9Km80PljEe/ngLNZU7U9u4d4fiCEacyu+nr8fkEuU9oIkyodAo4Txo2QfVDBEnRVJmbNGzXSKF6COd2i6Oq0bn6TtPKxd27Vd27VduxbtWntYqnFQZZ64x+Rmgte93Di5+EqTYVxFq88ySXjnhsX/6Y9FC+DdW/Gzbz0d4+bNGMctEc1mo00S+hPr1XdAmXJAtHhqBUNY+Pl7zEt0Kome2RVAjAfyG3wvqxCCFBPyvgKBEgBAqGgHoIs5UGyYGB+PgQHzX9bHA28WG3SEv18yP+e9hyf7gYAkFLSkJtBJwa3zCXYtkGIE3RNLMz+gsMWK7STvk4Qd4EQwL6gkGHh2Fj99+MxhSlizSEoUhcaS8W8IM4ntMBCCRAFndApW0nzM5xRGo2Fi35F5YDA2yCFSB30hs7BjKKUSwKFjqcMm+ORxinvhrE+wfRHZjAAR9hHvzQaVvFpDxvWs1Fi8ZMHzjB7gsEMj0Ps9ne6pIfvE5VX87XA4wKOH8dhvvhM7QYdGHBh0cq5Mg04e9o6ZsL9qMaO3J3mtzgeUo3jRBfNaXReSty/9e7726ILIvwuzS5s4Fg3zwZ1p4ckwmPF7rtOwXnglLYKJuR+Tx3mZuQFMxiT+KFrs769v4z49tZxEihulkkJPxr5y2IK1c9xMgeRiiXpMpwBP7yF4yUf1JR2bl/F1ug5YsdTAMpfVOotLRitKFvqjzdEwUTk/5XxaWGgRXJXnZtV79AuyOXgEdGTJOH2PEizNGg3n5TnXIOs9nKUcECHlY+3waTLDTFieseny5DWKoKL1HrWIoXqNoWD+6XX/37/lsWbk6P4B1xGdw7ETA4vIfVBQJEjcsPA6Wy5w+0ZcB+999j4AYHnjHk5P46KypCqtNyqtIzJwQXl4H3OcwwNyIWYKC+afN4u0uKE6yLZUNL97u9Yb1nLqEJSDl/BaJptTHslpAVQM0RQ+oGIV+pqItc/cBf7YF5gcbOLAfnDWor
iiNsxADut78INIH9iAWlxzXk97EaCY7TWyiRUBqhDmgQwYcLJQIsQ1OoVfEs2RAwLdfwlJTC89Lp4zdMCVa10BQ5KpDki6qlRIulmBOlztWieiTtkQkIVEVZR0gqB7PSFhNfA98CCFaLTqtYMSKkEn1oiWxy1VhoLX9/RRvOZV57ARXmAJ5XQdAsOEuuhr0oJIf3iRWOjQSLiIm6f2AXkpyX7qQHX9RuNUr88loUroAMPYRkEwSL3xCX0lBobSKoUTRbrB2j4ElfNau64HcaSw6DBgTdDInPVm2YEGBkW6PwAIxiaGixu3CbooLFareOIVkVXGmBSWFh02Z32ShlleEWzg0NcHTrYkRaYiy8J789jiOYr39vgiJHRik5SiFZyE1bkAa+MTQwXXNwQbkuZWGxxqH8NNOfVPdGiQS3i4iOd4r9zHXhfrf36wjug5FVwyltqCm7oKGHExXzHMVgZN0lSgE3anLqRNXTa23ChYGninHwrIoEPLfhB5nLbxCLLBX8XPvv0bGyzXNL4YBpyMXPqekAvHoWT4l4+D9Q48BaYkGdbFHHkXa0WXK8qHZGsoSg1lVbyRz9wIeFuoo2hZrRqPTEBBMteCgeZ1DbyCInjqV2msP2kLHJ2QkqlgOgQDHJSRVmlxHjed6fl5IvmtKpKANwHZnPOoi5vUOBvDHkRDZEDDLM8rKIb/nl+QXHx2mTpiTEvvs4d3cHMY33vOosx20CC0NVwXEEf/u7ddSHDXdm3Xdm3XrkW71h5W01hUKgOKGM5rOiYtsxxDWihDmqVH5Rhrhg5nC8KNDwLGFQlYV7H6e+PWaAmJZ1QDhdZoRdjwSmqqAnIheRWY80IlD+vWEa1RHRi/iDDS5jWLPdMBWSFYd8FTt0mhtyUv3PpSYUCs65BewfrS4oqCe36PiXgL1LToE7BgEKAFCi1hlhCSSqoAPLTqLRhJ7AYfeqaDZL32SfCkAOsDAq1VJQwLzqGIBjYePSbMv/HYWOHOi58Z5aHJXeiUkMOqFAcSLjsXtuD0QeqiPCp6WIb1SVZbjAhx77oeplwwfNY2ofcQk1elkoe4Lf7Xircq/26BjOO5ohqjMgETJseTE5cHMBrWh0qdh2Loq2nFA3c4Iknc+IAhN6xwQm/lg6/FviqgEFh7NqHn5J2HxEg3DHspGHSc5+KdrWuHhhYvS4cwHJQAhS0Xi+gNPTkDOoagLa8zUzqVLgxEKiLkyfvdSMlDAAy9x8x2cAR5FEkmu00P1IQh8Fk+xm+P3o7X4+L33m5OseHYVXRR9AboCBYp6BWGdUigF1G6VconJgxHLyPPdBrD85dxvDZtL2QpLC/GKGjWa4kyuWsbDAnVD1LOgr78QahObNuHBB2vb9OoVPayooc1a4Cc4V9NddV1ViPwwbq9H6/vc3c0xhBSW3rYcKn+y7ZSJqEwYKTBzDx+n7Pv37jYWTqsML+M3kwzJviltFCzeL7lBcN6dYdSYu2URdJZiWw/Tpb5LLrRL05fYjyJYacJQ9pFCWQliVGLqHv06EKh2RB8NIrv7esRQjcFAFR34vV9tHyClXUJpv9J2s7D2rVd27Vd27Vr0a61h7V/bDDY1ymXtKEY3CbolOcpymgRhOIgiTkqEo493ii8YBLg1qenAIDpwKOaU8CNgnmb2iUPpWKcuesC1mJhC6dZ18NqmzVhxoVGbXtWZ+I+ECTJOXeoa5qAtA6Hg4CS+YKCOadhFRAYPO+YN6jXAQMW8G1YQeqDhyNsNDB/h01fJCw8cy5ssQYIdD54hDz9GT9zPWeeBM+9DT2PW0plaSgePBOYrg6w9EKfPJP8QRAjDkT7Q5ksSXhL2iovDXJh1ZciYZgEwZcckAsBK3rMzOUjH2jYlqKOs01irO6YkArKQREs0BCFkhUaA3oIAkKp666HU7MLiiLHktxqY8qGqKxFwd+21HrvXF/83dLCtm0FI8dj3iKoDJdTIWskFLsJmM+E5Tp+NJiUWM7ib1b0PAblIC
XJHXNyWajQ0I0PxK37NiRrXxKHm02L1ZzCgsx5XSwGaMXdNpLf1JBpVEpReeMxEjCFMITrdYLq67yAJ2+cpyuTjyJYI15X/F7ZBlwxefIbe7djX14tcWsT8zti9A+NhqK3SJ1VrBCwJ8wlzKd1GRKViCgfQCmsCOhZJhBKX6QfkqSMQUPak5avw7HGaBTn0UqYyFd2qxqdpSRbZSjW9znYmn1zeUZxz1UHd0ARTgocrlyHNwZxrL//VjzIYeYQGkEFiaeoktco64TOOoAgrw/CBu/ffQsAMH4R+/Ss+Do03ctbw5i3qkYVrubxvY5rRqZVWhdWZAuyWYcV82MbAnXmeZtkj+YsJkatEMjCvuFCuC432DuK5wvVjfi1Ok+cg52N1dreN9gUqldR+ATtWm9Yw3EGU4SU4MsZEqqyCoYhlxUHLNjn0CUfYGp7/NYH+/iFX4+d/H/9s3EAfujdAucP46Ti2oNgAcPwVK5THKtf/KW+xwIDJqtlcm06j3xC9Nc4JDqns8fxvC+fW4z3Gc4RDSxloLSgyJjs3Qc2XKgcN6yjm8BkGD9/+YQULQukhV4FUVjtw5aSjHZ+C2whOAzTgykEDRk8elYIeShtX+0vuLpchd61J3rEjICaMsoL1gTNlg5OABNKascUaqH9YehC6YgejP0bP7uaKlguUoORfObgudBTpBUTBbTcGVrX01E11AfLS530hrKiR2utufHlorqKLOmfVbIhWQulZaFiX9mQFrRO2CAKjcOjuNj5WhbCBhSQTfVT52ceQeQqRJ9sPE4UWXsMS19dNCjIVjLjyjtbepRV/IKQvW4WG1S51AQxbJobrNavhju962V0ZAyfzVyS6hE9I+0NNMOmFUETo8ygYzLdgXpc1qdFGyFDy/obR4OwGAXkJGe1Kl5zgQ0mpG2ZMnb81ePb+OJZTN4PF/EzvV/gwguBdLyB40GObk5GGho+1qPf5QSg4AMoqJyAOq3v2SJkjtUb2xt1fF3MPV68jCE1o0P/meCS+r0uEROLMaY0IMLmZ89jx4wGE3RUn15t4kW9OQZ++EZ87y0as65RsCJhI8/cOqRwuWQQnAMecs7/3p1j6Cpu+kP1hNeXJ+omy5rHogOOKCeEQxJmr5doOafzI9ZSjfZxQYLb6jCO1x/8/p/CjcMYwn36+BEA4OHD9/Dw5TcBAOssplX2xgZNHQfl25fxPu8ePMDtm/EG9TqGoO9Xx/DrSzgf8FgExb5L24UEd23Xdm3Xdu1atGvtYa3mBpOyQiYMFrRQR5nGgESywu/XzmcwNd31PFpr81WB6TR2wcXXGCJYtImgNjPRlKlDneDDLekvQgAsI3nidOUGCEL4xQTo2jvcesCalRFwHnkucfqM1v5QYUyy2opSDG3d4Ip1IiWPUxU6sUWMmaS3mwbrM9a+zKO1VK6BQRW/p2l6Npnr2QBc/5o8LDFKVUhQ87Dlp/eEo733JTU+qZmQ6maEFcIGj4sZk8v8vvUZEmKa+gyd6q1bAU50zmKLqzbeRxOg6WX4IKwLQJ5JIpuhQwtoiibeuGuwqQW0EY/TBocVa8EEwLBYuZSILyhEmCuD3AhbhHhnCgMBUwgfoO77K6kbdx6eXkFiJvAhyV/UvKYH70xQ0QN/8ZzyC9MWh7EEBpqlCVlZoZV6wzL+dnTg4Cgv0cxFz6ZAw8ERZo/V2mHDuUrUMjR6lo8Q4psvlk0SqjR0IxQ8vBdgSpxjo8kQG9ZhdezzAz1I3upm5VAw5F3mDMkHB0+l70Jq6ELT8xTyvcf7J1D3o7f1Qx/EgrL57z9Fe5OgAWE/WXYJo0ROXZQbhcCSFUNvRDmf6v0ydn7nffKwE6muB4YENTlCyb3SsOxEAV2UWZbEGkXA1Qeg5RjXAnlXwAWf9SUfutmwxobMHu8OotfyR+9ZnFQk1hbZEN/X2kkpgAoehgCmDT3F81OLL68JJDm5jQHrm+Y+8hQeLSu89QZJbc/J2PHeh8gO6U0dxFBebg
xq3vNgL47b4c0b8Fx7DqrPAwC+9P1/Bl/4zP8EACgY716vZ/jlf/fzAID/1y/9EwDA2fxrcMMYIVhtoms3X19geusk/pYpg9bX0HWb6lg/SbvWG9Zi3kKXTWIvn7AQRK86gDmK8eAgvndwgAVrH1abGEPdG61x74ATw8YHqwljNOvIsr3cxBl0OM7hWDsipLVG5xgQcrWZMia86RK9UiFkuZWHYshquQCuTuPfbzyID0c+NmmBn8/i4uM6RPQjgIYL62rVYUR1WWGnn73wOH8az224mJWZh5SlOaKx2q2wu7Cn24Be6ZifOdsXaG5vFqnkKiEH+88lx2NdjLMDPcVU3QQwfYN1wxCRq5L+l+TGHAIyhphkU/Q+wAriSzZMpdOxG1a95nlfADscs4jVObQMgZT5EPDR2jBUAC5yoCgENceHdwjUrGPxG9bULB06FiULJddq0TPka8bkI2iQFElEENbNGi3VggdSZ2dU2kTKUZw7Z88tuuexb4ZSVK4MDPtoLbk9u0qb+v5xPODhscYlDRthe22aHF3L/BLrDlWmoRgabxiWy3SfBzy7iodYqPhdANAyN3SX3mt5TYumwVKYcBmyzHONQANEO40x52hBcmbVAGshd6Y+nPYKDfNeovGGJoO9/wYA4NGPvAkgogU/zYusWNf3sluAnNeQn/pOQTEJqhOHE1BLsTFJBCw8Chq0LUPRGhorhlqlQLcaAgV3+JrFg+2mE37dRJVWtwEtc+aeG/Wqtvj99+JvlgPpPw/awvgDUcQXB8aiTRpvNMKCRyd4U26KRgecz+Ia9QHHfDpz0HtxF3vTn2BzEjegTMV+e3dyhMPDaCTkVRwvt5xjhbgOKhpjWaF77bEQO9WufapVu1jF9VL5BlkuTO/xB0eHA3zp+34EAPDV3/oNAMCLF8/gWcOq8hizbNslgtmPn59fchxaLJtNogj7JG0XEty1Xdu1Xdu1a9GutYcFNFgvL1ERSlUcR3PPNR5r1gGYPYYpXI0LoiguGTa4e1TinXu0pk387NRvIDlJSXh3dQcnWjN0LVbzDnOGiTTd90FuUDJc0/GzwXgEFaJ7HFwOT3dgPImv6w6gKGeyNAyAjkglMWRDnWPzUbQGl27O6wIKegWaVeQNgM5Fq6r10WMrFVLNVfJgtiQ2hGFDqV5FVTworZE8neSmqf4L4iUZvZUoljAKDCYMd3oiHNebFoqINiFiDT6qNQM9yKTI+5CWSEW0wSfPVby5gJA8j0YQaaVOyerZ1QLEUCQ6psFIocwpA0ESz/FRjpLeljB1WNeTAcv3i1Kla9SZeIOAoytZ0jM2WscaPPTMFAghyXGEIEiADFUu9WuCJnTwpG8oxyQFXocUgp6+pKfdjLGZc/IwNJRrizCQkBZDuEgMX0nzS4cIKgGAbz+mxIfutbFUkLCuRUkOsoZhwPN5gwU95pxxuWKokPM5rMpxQoy2RKEYfbMHXXDS+HIEMMRbsu6rcB1mX4uggdWdAwDArR/7QTz6+X8NAPgfGUo/OShRk0aq7GQcFAwERMO6xKDgOOfFW4ZR8JzUEgL1SsPSmx2xxqhuPBzvUwBWMBERm/4B4HJpsaglLxBfT88cnkzj92aMFxemwBdvxGu4dyCoQ9+Ht1vObd2zisgzdbUAfuu9+Jtvncfj7d9Q+DzDnPfuPsAHiJ7Qm/tRzXhYbtA+jh7Ro+lHAIBBprHkmnFFpgs7LrG3F/v6PtF9zbTDo9//EACwzj+I53j7Myj3Y4jxcC+6iGE1xeMnEWyxJEimbAbISTNmGDbISuBuiG7X24dvAQAuwgLv50/g2gDsQBe7tmu7tmu79t9Su9Ye1t5hhmxQwNHCOp/FOPeeGaAY0kpykQer8S0yFxMbQs76clZgStACSDZ5uDdEBta2iDih6jnKWp7LAiCFVkr6Gx8SrDljLFuXG3T8TWdbjOlVaBE2DD4liQrWw2iEJGQo0PoCSMqcAh5wJZK3IlBxFQJCFz2rQpg4go
HdCBBD8fq2+ABT/lmj24Kaxxvu+1uEFJ21yVPree00QGXgVq6ldRhI/Zpcu+q9FlH7DSVg8Wog23e9lenJI9e4rheY5PeU6uvIxNPaeI9CeBkDYNSr1qr3ISlIy9icPXWJEaEsJR+lE654I3VFAQhSX5UUhxUKelPuKnq/ustTPnBNa380UPBCnEpF3nxo0dJDF5HNTAErekxZzRo+qxEk38J7e/ToCoZsJkHcZOtRkAi5FYLa0AGUDekICgo6oMzjeP3mixhKsPUKXsaQ/YvWwOwJ6S0latou1VyhZY5n1SWGiMmRSWzBKheQTAuTtHlkcDyMKBIzD11rj4ye7g0eY7/ex29exOt/zhq4H1IZPstykILksV47KJLPOnriTTPAWgAFOZ+BNkebtfwe57vuWTIopgyFEiuGHILwKCqfeD5B2ZuzJfCC+co17+fiDBDhIKae8dYNi8++wbksz6MuUhRFyJu7EGCX8TiPp/G+f/OpATEoUGDtYO0xkzlz9ggNy2K+cRU9nvvZMZppXPOmc2EUztFRZNJRtDGoGpPDeJF7Nt78xWKNmi79xXns01/4xf8Hzhbx2J99+/sBAPt6gve+9rX4PUo17+23KMbMU/PZHekTuFm81g9fRjSKKzd4+/YInQ/4xif0sK71hrUcGAyqLBUvWi4u69DBMTSgGeLooAAWAQ4ZRvn8wRK3juMD13QCvtigk/AJE7jBhqRkKizQ7cbDSp2WsKxDQfN7XihrdGL/QeOBjMnInNRRrXEJ1JBL3cxWmLBditS361FwKby3TSPUI21S5C7VovSQv1QQjLCltov0PRF0FTn24IBckuW6f09+I/LpRvWgDMFrZH3UMVECadNL2qcXv12Lyc3FBqw3BJyQuNPlH6vZfIVGSWrXtAKskMuGgKbbqqFBLFy1PJIg5UwWkJU9Mi5+vz+4hCddQNokhsMegONsv6EBUSOrY+ilIOrCFMCAwBDPa93UXUJG1kSQBRWQsb/W1DkrMkAzdCjKw2WVbI0U9mragFxYszlfOo+k9SVhwLrrII//k6to4LQeUDy2SnMmwNpXwQjwHvv7ET5gGJJ6+eI8oSKzJqBk8n48jIn2ae0SuEgCO1r1xM/SItiGiznHMOscHMfpA352trL4Bp87RvXxGSjcINiqoJGQZ+vENj/n+tBmDoXUEQpIQ8d7BYCW4cLGN5jz4TV8yCfeJhXdCyqcn54BVxuqC/O3wTmMieq7dyve79uHJQwptPySxkTWpkktNVwvNxrvr+N9vE+Mwzk20Pvc0HhND9sKhzTCly++jUsijYTE2WYzjPRBvM9xHIdMe4wIONojocFGb9Au4mJ25qLRb4sc+3fjDTREX7/81hP8fx7+MgDgW7d+J/bHEVBUsR/2jwlC0kO0rL+sJvF1f1hhcRnn2bPzK/ZVi+O2TLqCn6TtQoK7tmu7tmu7di3atfawXNPC2g4DJsT1KO7mK2cxY0GEaBaVpkqsAHdYL/LD725wckDrsaH0getSyCXV1vi+4jwwfGNboBM9J9F/MjqxIAgZps4VjMSnOv9qVT4APdYYMhyZ035YnQINSW3FAlQqpHoYsaZD2PI4kssT0p/ymbMBSqn0G6D3hoDe8/Chr7lK4AsAhirKIdEh9ewY6T512JIp4fl1f7zhWI6neqVm1X8vWdpM9rsQ0LBGRuRIGo1kGct9hNCHQ6UeDj4kD0uZnolEEuetDQg8toT/oBRM+SpJjG37y5JQqgKgEwUV9bWg031KP7dtX0tT0jMKyvVUT0nCJCSVD7lA62yiUjKZQJ49NCekUHRB69SX2zaqaFXJ2HRdHw7NWUvVug4vqdN1xVBoZxQ0xZ5MEG8PiRHD8LmphgOMKFZ2XEUo9Y2bQEfZCBcsWitsK3KtKsliIN0HYDhbxDtWUKjoXspnF9NLNNRzE/GrRRbw+7zuhzzuRyrgAaMtdxiKupF32JdaNpknxqVwv2ileb9VQymh5aBwIXQVpH/yJX
C5ZE0TibBPlworK2GP+DoZKbwTiSdw54T1obqGW9HDYmfM2oAZVbRn6zj+L9YeTzeEzEPAPiHpTY1IxzY8eBP37kWtqnZziavzGJrTOno8Tb5EYMgzUcVB4yYJfUcmhnvyTmFFuZAL1vXpwRiHd6MXfbni3M+usFhGl29J7xKnHe68EY9z92Yct8rVyFX8W4A4Vgco1oPcpOe22GxgO9urqn+CtvOwdm3Xdm3Xdu1atGvtYd1YBZyMC+yxmn5Oy/KZrREIOMhZwFiqEJNDALJxtBzK/RxZxqK4Lh5DwUPTWvJbnowk9l2CwwLF61BhZ6FoPYpsgVYqJbl854XLEouzeC2VyTFhYXHNpOT6PCC0IvpIa9nrCEXHlvgf0BcxSqdsF/ymwtsAsU3ES/L9W+knDoBhDiMnmEMpIPCeRPBR697j6I/nexFD8b4s0BFPP57Ih4AVxVx6mwgh8dAp4WVUCmK01vQyaxswkLyiyDn43uPUSTIECSRTlhm84Nl5Cdb75KFpAWLYEAU20Y9n23Y9V6Ik3bVK+UmbwB4uWcyayJi2tUmJVYhWvQHa1xAirQWWy+g9DMZi7avUN94LQ3AHQ09tvZJiYiRxxd7T7oU0Uy7Dhd4rJ1LEFwbffMzCViXXFyCdnqt+nsi45wRNHOwPktsqxdPHx/tY0fOYrqbJK/OEe5vCpLxdYlMJIV2r5LeU0km6RJMYd97N0thpiTiE/jhLli18UwU8Ze71Ph2AN5XCl6jufcRz2SaDI6mjSfo4qi9UF+/HKpzN+GxyDs1GwJQgiSs6fRsoSDbxcMzz39J495AeNvi9VcCKoIw114Tns4CzdfyG1+SFNB0Omc8e8cGYKg03ieUqByd34zluP8Cbhz8AADj96H3MmiiD2BFcUhgFy6L5DbkuF11AVh6y/6MXpFyGgqQLK3pkVivc2Iue3A+8E73oxa09rMmBOGUnzC8BiGrzBW/UdCgGjMowz29DjrI8AADcuB3PO1xP4VDDdh5fB2lBvku71htWpYFhVSb6HE9InW4shuT/4RqAHDbRvyw5CZ7P6rSxZKyB8aaAIltmesgtEhJRaqUyBRguRJmE95wXlpjEHqFVkMJ+2BbQsvHV3JzWHUom8VdTIqWmIaHcJGxmrXplA5X2MaiF6utOEqVS2GIu2NrsUkhuK7wm4VC5poC+NkfqoYqiRxGmxcf1NS0pdOWRFHj39vrPBDQii7bWIYljCflqCBqGu3FD8EKdu7SYpNCa7zc7YRVXHol13qj+ntJ99qdOWlnWAb6WMBEXti5uUACgZfwzn0KLEhJWmU2huZpJ8s0mAmWAHpSDHAj5qwv0ZhXgHDcdQght11NVrTc9MlDxWsTocV7DM3Hu2Kle60TsKn3fKpXYQAJ1qrwe4r2zuKovhOPI9B2bwskqjoX0KwBkrk11Wg2ZmK+WDWo+N61zWBNAkrGTiiMPLTLQqreUBLSTyGqdRSPADwIo7KpOG2ni9s5UD4oRYwcajs/9hobByzVQ05AVo2PZbVGUCZOEAjyRhTURvOfzDnOSkmuGthZ1gGUf0RbGQAWMiEq+E8uUcHjosObmejolq8mlR0dEcSjj5nRWa0wbuT7W8oWQFpqa0NUuV9ibRGqj20cxpHajBC6efQUA8Ojb34Cv444x2KOx6wKGPM/oIG52s1UNNYp/r1v21XyDSgxQ3pTWGUZFNOLfeCPyhD1flXixjlDFipvxraMMhqrMjnP1smnhiXLdN/Fc+0WGo1HcFCeUrJhOgbpZoBM19E/QdiHBXdu1Xdu1XbsW7Vp7WFd7OWaoYchuKRaoqkoYWkFGaqS0hq+ixWAoaVC5BoHW5Zq8e8OQIxPrW9RmO0ByviFZ8SpBsD1N/GB7RoFkofqo+ApEpoWEjWDoyDTARkIVW4lRgTB39DyC8in8ljysLXMjwQVc6Ou0EhAAKfQiHhG2IOc9TF6lkKcQ1Cqt0jWLNpN3SIAIAaZopaAY6iNKG6VSsSQAwG
plyLLQc91lvoSmd1mzhpHnGZQXLjyBhfetEuYFtpIXphLehyT6KCADY3IoMhg0LQv7WiFDH5UBrPsIOkaiJO/TtUlDsAoBPolJxt/lZV+3lfqXQs9dKQKOenc9pehcpyi+T2qoF64pvBS+16tt+r4cu8izyH2JHjIPpaBFPiUx3/dAoqPbUTVh1bTIy55nEQCMdbi6YhTCebNKwZCZPyf3n4dHY2MmRDphMp1hTY7Dh/UJAGC98aj5HEuDs8tsYghRvK+H5QRvkQm+YV3zfH2F5WmMWl5/6z4AYNmssFrNAQDDivWjjcWC0PmGDBDtQKMQYVmOPasDPGHvwsm3qHKYRYTvm67DllF7fp/Ny0cGdbbl+eJc3Tm6h4cfxO8sGRChKnDJjMST8wjOKG2BAQEux4cRnDEe51hy3tZkjwlQWJDxZbuM17ZcNch1fDdmZYwk88McuY/jsiOu90bhyI5gWw98wgjrWm9YKvAB5UOTc/U1qktFebfzElKJlaHf2GQTs8IKgAxFESe7HLCTvqnhiURM6rbGYMC+gsAHfrneYFgQsMHeoKcXGxywh6c0Gk6YMESXSArkiEwYwIuMA5IW8ehTY0L2G3yvddUzXuwYXxomQ3pR7m6EiQoqfa6vpssGY/3OdyXFhD6dtJt8VOypWdRxWX14HrBKEhzxb4112F6QI4cbeQMNM+LLesa+nUynvi4hv0VQCekVmH70O3RCgsBURsErFq19gGcaV7XxJTEsKrR8oWl6Iip4BC99S7JmQkLQySals6zvVUvyLSGBAZQTVKrr2UQEtKBC6n3aVdWVDVk7SXPalJYU6qsIZHkREQj4HaAGf6NC2oDSOgnhFbqj4JHkdtwOYkfWm+H86rTqdjdKhX57kjTgi8cWR1DoplaLDSBAJPmmVpGKCUAQBhOdpV4r6WMzWqPkC3dCeYvLswsoztuYf8s7hUUr5MPsgWo3SV9rMoj33weg5vMuz1xr85T2s5QmUV5hkGifZMb73kJRUNFdQBnicy9yMF1R4MOHz3i9dE7tJjkdWyoUt96nnlFJhU4yA8f5F9aSbMdz9OxxW7SneDCK76ibxxMshfqKjlTTrNCRqNeQ7PnG8U1s+Yxd1nHjas9aZKSFev4kAiIGxQAzEv8ek9FFG+D5csVzExk4PIbKYprWcbNe1S2+9+S7cdwi87SY45jOwZZrbNF4HMzGsOb7vby+v+1Tgnvb2972trdrYdc7wlIK3vlULBd+Ng2VvDPxmos8S165pPC0zqANI7EgHpeH0SyqFtEjK/MRNi17VugedtZiSYh7YJ5lvV0BBwcAgFHFouqixUfncQzDsoRhqoc8m0AeUorES8SDkFIl8jvnQurhkUI7gkppRImwtOo94fRzB2DR873tQNf5jwx9GkmwCM7usF+oPiJ7GYqvoJBRmPHpVfzcw8s+nSfSGa7UkCPSIUN37lPqdnI/enPFJIcV8U3yAoag4NmXYoVdOOyAKST1ZgAtnfguwLIon8s5hkOcLaMXbcmw0HY1BpxYYZnQ2qQIxnEdOWfTPEhqyLYdlLCjGJEo8UkFuod4h8SP532fipYIUeQZQuhZN4yR9JiCldRdUgjWPaxcuAmzHQLenbD7JRwOvAWcRK7YAXMIckVycLsRVmp50OnvElnCq511B3R8Xmqmi7ptmw6ZSTtc8Im80qDnVpT58CnCAg74XDVc8PlgmLILiucdDjIckUFCZGGWi4uobAxgwB6+TGt4PoCnTAO26MCWQbRKOP0KjIVXdCvPeEAj6quZ9GH2nJOBqcMNPM74frj4Tvw5nuS4eyuyPLhVhJ6v/Lpfy1TqPhxO0EpalBG7th0qjjnT8q6yWPPfWWUwIGhEgEtVl6c0+DAGU5gcBXzu9ZhW/c3353F8Z3NUVBCWiLNxCjdvRgh+wX7J0FkMCZ56QtXtuQbusl9rQq7O4BewomJOUIixW8x0jMSyTRzz6eUFtvk4tYx8EttHWHvb2972trdrYd
c6wjJKwQa/02Hfd9tKhCWeap7n8Cmn33uo4s1Jzck61xecyRVYlgPUG8pLiEMZPBbL6D2IbHfIgA29+MEk5tp9tcVjSgkcjixm4xddXe+BLO+bNPs/0VsVmZS2vzyB7irVgyNSg7EOvRfdo/xf4RyE7tm1Uz1K9TUMoabr3E50tjPmPmKLoy50DtFqeXIZ/3hZG5jBDtQZQMh01GEBEMixVjQlWsogbHycq9GbOY6LCH5ZkWNv6zs4L0wifQQivHaekVvru0TqpmDRMuc/5BimoxEuKGQoPH9Nu0XGmyuefZ7nCegidS3rLCqCbTybybfrGnrK6xToudHQfLwUeQbhXYpWdjkdE60h15HtLCrWZQKBGMHaFFHI55UKUZgUfeOtNipJcATWZJRWMGwRkChIhx44ERLIIQGmU50mQPXRmYzdhzQvKkHn+2M755IoZdD9hQqEvaDsiu88tBfwS8+2kcYoDeEqYDaNkdPVIkYrg8MjXK0iamAjdRptEQwVFpZXaaJnlLSfmZgxmeYFqnAAAGjmEWBR6w7lQO4/a5imQaHjfbCMRtdrjy0v6YrsEAtncTSgkCKbZ+f1BmES//2cciRPmi02bPR1DSVi4NHJg+/yfixkWQ9r1q0utvBzRjoF1+dhgdKQqWP5GErHENFyvQ0HBiUbpMEI0a/neP12rE19j/Dyc68QKPQIfn59tcL8OWtcOgItBl3Apw4jgMQ+iZ/75nvnWLE95f5bUSTUlwOsuvh8IWO0dzjAFUvXfk5Ox6XFanUJ18lT8fF2rTesJgCZydHyZTIStgSdQStS2uzsZfI4OsU0oDZJ0VVya23boMhkY+OLRmfIB7GY2rCXQxVFQnDJy+Lz9z8H5eN35hex4DqalLi8jAvjgzOHzwpiif0dYQSIPAyje5gAaFZ5BRG0CiF9Lkv9RoCQeyTpEWisicyTVKlv+hekyFGE0CMV00tUZUnOxKb+JEAnpouQPi/d6ZrAgyIDni3ii+HbH8UF3g09nGyuRGaaxsIJtJDja3WT8o52Ff8xf9/i4I34AN45jpv/95YfpMJz6gkyWUKq+aTSW8AxXWPdGlOSeJoQH7abR/fw8LvPeG7KJfg2Ed2GrCeSbR0L8NzEZtMpWpGO4U9dKoyIXqz5ctKhd3wSkCWoRNgqiEYFlUhZM+lP0ll6aYqMgzMh9UNJWtE532ujiTPhbAKSJAUVF+Cll03Obzw6uXbem05SXQBCEOSgRilQOplfY/rNbkcBOMgAM53AJUqyjhqQzK3i+HPVowRFx8ooBZUJICV+rhrP0NGxcJ56TGoDzfEbvtxPLp5h/lFMtYFp4Lt3x8io6Ltm+tepAboJN82a86s89I0hrz2OabuxMJRE4RShGTS4XBLsURPB11VYc7IvBhzffaAq4sbgLyi1UjmcUXn7cHiL5zqFY/p0exavsV41yO5yPibiTTapF3RAqrJsZdHQCei2FreJ5ps/n8efQeM1XtOYDNyr1RVq9nMFrrsONTYXcdwTbtBTp7F8/yEA4O3z+LmfeP0Gpg/iRHzqJ0ecj0u0NgI1Wt7gg5t3UG3js9uensf5XS5wMeX78k50MG4MBshRoOuLDh9r+5Tg3va2t73t7VrYtY6wlKK43Es5C61UL0Cnei7Blum6xBtYZshL9iwk1ooanfSqsIhZhb44b5k6VEahECU/eoLbzQqGLqUoqbZBwTJltWo0Ti7iwW9OxYsLWEuflqijGCBjeqjkWEZKJdj2SoqUxqARklGmjqxFShOJnIOrAKHjc1Qr7TaAaMRVTBf5xsNL1JWiAkCa+FsWgDvXn0PkUpaNxdsn8bunDVNbmdphW4ieqlN41U3avX283nrp8NHjmK6ZzmLklpdVlJNBn560ziNLaTFhTfAoWHTHsIIRUl6KdWJQYVAS4kwF1QjoYIqyZH+dVqg3MUW5kRRTAAZUqy2YBlq3dSIrLrmetnUNy1Sk5gXPylHqURMGCxssOoo+Okbnne8gXR
QjSUF7/0r69wWmFkn5wr/SrxdC3x8mfT2RIQQvzBuAnZSgHE+ltG9/jtADP3ajyB0U/04mMB0wMWbIsX2fvlb9x6CT9Eq8+OFoAiupEoIaVrbDxsdJGnB8hydnIG4Ccy7urbHY8KuGoIX28gzrWsAKFDg8mKS2jDn7ndatxYJw+4ophU9XU4yaGKFcMVq9WNYJXNQx5TcdFNgiRtst5UMOsgzjQ6ZFhzGNtr2y8IGRvfSsXc0Ty8ut44N4/tkEV4R/XzI1ePtoBncZ52A4GKDK4to7HMT167ZblARyPDiKQIvQ1nj6JAozbucROFEGBwauuGI0+Nprn8aMrByT+nsAgNMPT/EOIe73yeJxdCfHGed3hfjdsLjAg+MIxJjHxwfnG4OCz8HRXUajxUN8b3sKn5pcP96u9YZl1IsblSCmgtJ9T04i0izhhLldei8aCy35Lvak5IXZqSXIy0wDXh5MaXRsE5Fo3zDZpY1qfhVv3v17b+LOzTcAAO/+yq+hXjAtSbTNZBzAlD4INIsPtpTjZKMpFBpJlfALTWORc6NqSLmzXXockMSzWxKduBgkGXTNvGORe4APVy0Nty7ANUIAy5QakHS4tBwDQMZ5aCnn+psnAd8+o4prfG5gdCmtVOjIdu0K3RfIUi5y56UptZbGY3HJF7iND/7seNzXHwXxqXpF3PSm9h4d75e1DigGnHP26EwHadNZ10TmaZ2Y54VlvdAlRoMD8EDxb51F18Z7OxpF1NO6QarPDLlh2c4mstiSSLNhUaW1tyar+LrdJKSn7OTKAKbq0368pH4D8v1PWasCCNyty0rdUIWQkJQ5U+VGZ2neWjauvsDkL5sUdHJOUvYvhMT+LzVPhT5lHD9EqihJ4ar+s4lf17r+uVL9bmZET40O4XAyRcu0ZODv1q7t64WiczUJqQ9SajYb10LzWRvwAtvLNeorpsCO4zrIBjlapuY0yaoz5FgToeq28Ri3pxPcvBdT1L9RxxraWbXEiP1GVyWP4TWOqzjXE8QxTeoK802876cktw1eIeNDYg7itY2mASjj33P2LwZjMbt1EOeDjtLRcIQ3778Vx/DwcULB3r4Tpe2LLEO9Yu2d77fWAGfsM9tQ789ai4pjHJHI+/TRY9z/0hcAAP/b//5/AAC++9EHuLiKqfT5ZXTg6tBAj+N4Didc+wvg/HHcDBXLNV17Dv2MdWpSQ3VVgO1K+C6gd4t/a9unBPe2t73tbW/Xwq51hAVET7tPaUjaI4Ni6iOxWoQAI+wC9ETatk66RDk9i6IokNEL7SxpcUJIqZSCaUJvVSpMa0Yoz8/OYZhudBxTMR5jRLXPblBgS/TUgp3xmG9EXgk5EXU6D+BwkPHYbu1TekiAHQZAl1KV9ApHGifsfSFoCNZvXu3D6gDNxi7X9dGjqLcmQtzQLxL5buuAJVMlG4JW1oe30G0iuEE8XufbVHxP4ZTu/7lrSf9LmnVylVz6pL0V+n4dJTRcWZYUT5POmVZoCVpY+wY1vcuSvzPBo6BXa9fxd5kv0nlKylYYM4DhMbNMUsc1Knq4jnNgjEF4Ia8ZyUoHVI3NE/rFo6PEgmiptc4nVpEeBWoSrVPPLhF6It5daWoB2wjIzoeEWu1Rdn0AY7QAOwy0yKR0UoT3KXLdvWshzbmkXG1Kn6aM3849NUrDcwFZApOgd+ieREXbBRjf91/F8fVku0K9pIocXnqMpImraXBjEKOjdhO9+U0RMBHyWyI5u/kV3JrwNGZTbg2HuDmLf18xVf10foIVr2ZISrA8ZACJrTfL+Jy9XT/Hmw9iL9VgED8/agOOiISaMCUyDBVub0jhJX2EVYn3L+LFkVEJx+MARUJZ6f/qgsdEHgPOfb21WFP7Ckxn+mWDcEo05rbGTckarOPB11olurQnZJ5wWiEQXNRK5J8PMSK6ccD7cHN8mObrvZOYmn/r3m2MRqSvokTISN8FFnFcy/NIbtsOcmTj2MN1s4hz5csbWH4QQRwPbt
4AADTDFcqTE1gXEBXDPt72Edbe9ra3ve3tWtgPFGF9/etfxz/8h/8Q3/72tzEYDPD7ft/vw9/4G38Dn//859Nn6rrGn/tzfw5//+//fTRNg5/92Z/F3/7bfxu3SVoJAA8fPsRXvvIV/Jt/828wHo/xcz/3c/j617+OLPvBAj6tDZTSfQ8H3UxlMqiExZYcv+37OwQXrnpWgyCADK3gkn4HazvWphxwxmK+11rajhJ2viqr5P0Gektvv/MdFHRvRzenUCsWbEVhdziAZ9RVLunhKyBn/SMXQEamsE1ik8zdW5XYIgYURbx82kCRu9AZgY8HCPo5y6QfR6HkfHvWxja1xUZ46JC+ClfH78x5rufW4JIF7DCKUNrR0QRD9qotTokYyJD6oVIxrrP9weWnSlOdoqT0fSBFrVr3cOpeY6X35iWY08YkbrpiXGGzjB7gjH1T68tzTKTvhEVhY3qgjqyTrutekM/gJ1NEZ4UHLyh0XD8CM9faJBkKw3rQql1hXkdvv+HicQi9CrWIHXZdqnXlZX+dvXzHDrpB/pp6uUL6vdRGooK1qBATNKRyZImEVor+IfU09rI1vq+FvcAUw7lK/X097F5rlch2RanbaNUPV6LGgB1wFH8alSRk8kr6igzArIfUvIYqwxF71R5+SEHA1RoFmSTk2utW44rgl8kg/pwWfaCmc+Ew1FBrRr1XfLB9hzHvob4V5+hivkS1mAMA7hxH4MGw8thSnHLAZ2owy3B5EY/zjCKGw89sYQm2MlTgVguNlj1eqcS39piyjWZMGMnF5SUuV80Lc36FAEsF9LHJcNHEMRhhUUHA4JAAoWEEGTnVYcsa1nQYa7DbtUNLUt5yGOdtPLkNzWL0d9/7Zvzu6kZi2Zmfch0rk96JDcVrT+ZzOOpU3n5wAAC49cYNlOwfW7NVwLTAnWKITgWAgI2Psx9oh/h3/+7f4atf/Sp+6qd+CtZa/MW/+BfxMz/zM3j77bcxGsUJ+bN/9s/in/yTf4J/8A/+AWazGb72ta/hj/7RP4r/+B//Y7xo5/BH/sgfwZ07d/Cf/tN/wtOnT/En/sSfQJ7n+Ot//a//IMMBgkIIPqUsXGJeD+nllZ5tncPIIlV9ysQHpv1EsdXb9GC+YFIT1iJZnqEghVPfZxNgif7KmNMrM4e16MQUGYYT9i8wHYNpjpZ9ESMrhfYAxZuaEWFkfEj0UYqNdraNmjcAcE6F3XwyQsNdrpyyvwYGF1ROdVwXy1WLDYEadZu6jhEIxBAASOYVbCMLLf5uZRz8lClSaY5ttxiw6LoRxnVnd6igUuNWX5yXLJoBXmnF8Lu9VmxOzgyCsKHzEJHYVMnw41eDAmu9GB7M4NZx/oWstL5a4vi12OQ4vGCqDwGteCCq4RyoRPclirdd0AnQUWph/c8TjqQgOCDLBb+HtAhdCGjZ6yQqw0qFtAGKxpQKSMq/aXOKVxj/f+dXPU0TT6VC2rx0ck52cC5p1/BpgxlWRJV1Lm00SgAoIcAzBd03LO+Q30rq0oc0CKUj+hFA38BvVEr7ylhM0Dvj769D00HJSS7tg+odBjo0eabR8QXdMAVW1AGLBaWtSTVU3Z1hdRl/J1RUxjm0m3hNA37u9dEYjY8PR25lIx+i5c5W851w5T227PG6WcfruTs6wBM+Q0+5GdTeoub8P+FbdpZbHLG/qjvjmjjP0PDfFImArrIEepFnb6ZKTMd8d6QJ9DCEERcqx/YsrlvuC9i2LSzBG4Oj+F3lMxxmB5wPrlU4eDrYGe/71eUJbjKt9yNUSX78bI4q40baEklb1Ji9FkseUxedhcPLBXK+T8++9S4AoJwOoQj8+eZvfggAeO3mAA/evM8N+z18EvuBNqx/9s/+2Qv//ff+3t/DrVu38Ou//uv4A3/gD+Dq6gp/9+/+XfzSL/0S/tAf+kMAgF/8xV/Ej/zIj+BXf/VX8dM//dP45//8n+Ptt9/Gv/yX/xK3b9/GT/zET+Cv/bW/hj//5/
88/spf+SsoiuL7nXpve9vb3vb2u9z+fwJdXF1Fz+boKHZ0//qv/zq6rsMf/sN/OH3mC1/4Al5//XX8yq/8Cn76p38av/Irv4If+7EfeyFF+LM/+7P4yle+gm9+85v4yZ/8yVfO0zRNonoBgMVi0f8xhJ6gVKQbQl9llvRC8Dr1SoiHasoMmXkJFux6T7BXRs3RMp3kd7r9K3qAHb3SPFdoGmrWkBFjVAzhCkpZFEXqBl/Rc5o0Gi3h0c2IWlpZhpxY52LDIv3ZBbYn0XsLBElYByyZ7mjoGa2bNf7I//v/BQDIDmPKYrtu8O1f/c/x71exQG3bnTQLoe7BWuht75XH6/U7fW592qYkDLXq4rWZZoSS17GdxTEtny9SUd1nTL1m4RWpCyjsIDvQW2rxES/epT4sSQP6ABimtCRl1tkOJWmdTFEiJ9S424p4WJaoOgYSXTR1Sn2JJIYpdUozC0uKh4HOyIggace8AoykBPv0dC8l0kdGeVLylfWp0iUnlV/Xg3a86yerB2DwFFrtACHkXEgks25nLoWFQoAb2mRJXqSg2vN6uUokv5IajCTJkhaVsZgEkuhTh/3JYlqSvxdpjZ00oQQIChoGAovvI60s6VdVHLtJRMIStZoc+OjsCQBgY2JkVB5rNLKWpUerGGDEe+zZjuCshrcEVjHKCABaXtNoHJ/Dg4MpFgSNnF/OAQDbTqVrPWF6f9oOEux+zrTjdrPFkECGyc049sHG4Q7BFopR5NVBiTBnqpX9i/feOobPGOXzXv7oW2/AbdnnuIzP8EI1aCjVc/fGEW68ESHujz6I8zJvz7Hh+8Gt2K/pC/z48afiNQ8ogfTsDGfsD7XzeI7Pf2oACoMjn8YI60hXcPN4PGVjFqfrWriLeFG370b2jnuTQ6xJmzXI4t5w++AOPng/jutzxzFK+9Ef/xz+7Qe/BvsDUDP9tkEX3nv8mT/zZ/D7f//vx4/+aJRBfvbsGYqiSMzKYrdv38azZ8/SZ3Y3K/m7/O372de//nXMZrP0vwcPHvx2h723ve1tb3u7pvbbjrC++tWv4hvf+Ab+w3/4Dz/M8Xxf+wt/4S/g53/+59N/LxYLPHjwgB33gLjl4s1pbWCkOih0+Z1LTBeSL8+8hmIdQlLkdkeLw6BXYrVJrkSabD3yQrzpXrk1Cf7JcVUGQ9x6UVVopQGVHmi3sQikoXAEU3RBARSRxCB6mRdqg+ek6pe8NlyvZitKsqZUaOhVtfPo5RSmwi2yRTzfRq+wQQ/L9iKAWGbQotC8Uw8UxIbKBPygEvdit96kucyqONbBLHpQ9bJO4osioeAC8BICPEZcu9EWLVVbRNqjc8jKFzkirfOp3pIaUwOQU6wvc4AbM4KlF1kWBTpC/4d5nOd1F2AK8RrZ9tDVUFQdzGQ9WZvqKJYACjXSMJybLgFA+ohIBASDtSil8TlFbg5dAgb1datOFKlf4Ap8sV6ld+ZKvP7YwMvfST1KKRgt964HhYiKblEKjD9HQ+i/ZBeUMjtgi17U0bx0E7sA5JnIswSILyzqwQX8Tu1K6pAeSjgVRUhTa+Rs/tWMWqzRiRfRESSltMcFJXODgEuMwt0JxQY5pydnp3BsIVGUC3r8ZIEtGSKyQ2YyXsvhEJ8bATfcyi0UJTPaQgBYSGCaFSOy1XIOmwsWn/XejceQ9aUbU173PECRFqK8GdfV7XslDghI8pfxu28eHGFDwtltx5qc32BJmY+rbfzdfNXgDfauFFugYqNywefQbHM0jIpbZgCCtlggztv912NEdLoIWD1ku40QB2QXOCPLy+tvxXN84cE9ZDcYlT2JEPbLjxZYP4/Hm5BMd3Q0wNU2Rl1uGMcy7p7jHnkIm8mM59hguyO580nst7Vhfe1rX8Mv//Iv49//+3+P+/fvp9/fuXMHbdtiPp+/EGWdnJzgzp076TP/5b/8lxeOd0Kcv3zmZSvLEmVZvvL7zrfIoJNm0C7xhby8XEJ5AYXwE4lOkVdpw9
JF3+MiCqFaUjDwyJiK0lyYbevgyQZh2KMDmAS2WLfxZqu8RDEQxFSdgB850UnaZD1SSigPAqB4awKLucrq1DeVlE6znvQ2PbQFsGKPRHDxGIc3XsfsVlwkT08/isco8zRHLVkoQmjQJuACL8mhJ5yVDTzL4GXOmSJtbYucSIeCKZh8XGJzGedBkE1Kq8Qi/gKHkJwvvYV3GMG5WXiTwco9kfe4d4nSSLwAow06poSKrsBqGNfOAdGLBSwCWUwOSqKw5guUFdOE7LPyPqAT/SV5ebYN2vU8DpXHULZAVcXUh/TSaK3hJI0sQBxtkBfC9M7ivN2i5bz2LOUZNOGfaZPQPeVSeskrlSY24Vd6AF+aZ62DMBql6fXBpk0g5xwMhmNsN5J6l00qQ0hMx3xWQoB5yemwPqCQxe0DuiD9XHEMhXepVzBXcmwHbeQ54AaXGVTD+Ib3PF5nkKi9fEf6nzag2fLfPO7JiUNWxXl9cDteU+kVngsqjQAlPSlgmIarLwhKGOvEEGPJkmE3HhZCik1QQgcEImRzMqjowqPM4v3aMA+7KAtgxrUzpkNqPE6ZGh/w+T/UW9ynFO+IgKzWtMilZ5Tp0fPVJS5W8Xo35/EYh3WFnP1k55dLPL+Kz1rHzd9XFhnZaQRIooPFs4tIEGwJOMuNxd3j+A4bGiL+KuCKiNKrs0hg+5E3OKLO3xUd1S1ajI6oYMyU6+LpFQ6JHh7k8fPvfech2N4GTVRyHho8GB6g6zz+PyC7+8fYD5QSDCHga1/7Gv7RP/pH+Nf/+l/jrbfeeuHvv+f3/B7keY5/9a/+VfrdO++8g4cPH+LLX/4yAODLX/4y/vt//+94/vx5+sy/+Bf/AtPpFF/84hd/kOHsbW9729vefhfZDxRhffWrX8Uv/dIv4R//43+MyWSSak6z2QyDwQCz2Qx/6k/9Kfz8z/88jo6OMJ1O8af/9J/Gl7/8Zbn5zC4AADR9SURBVPz0T/80AOBnfuZn8MUvfhF//I//cfzNv/k38ezZM/ylv/SX8NWvfvX7RlG/lXln0Vif+pKKXCIt00thpFRJSBBlIQ/VWqeUhbjsQYUUYWWk2ndt3auk8stFkSUQh2VkkZssdfPLtQSPBMWuqlHiYKvJNpnlRYpmpL4egJ5vbYfxNMlLcCghIMlFSBEcpYYNwt/FAmlwPdxaSH6d7yM7pona1sFLlJcGhWSB6YXgPYKwS6TiOrAl0GRId344HqNZRq/LvhRFvGBK9W0DiSy3h0kb9qyZwqClLItw9xmTpfSatDfoTGNA3axNDmgrGmUlx5djRQ9xNouRkcnzxE1ZMLXVNh3Ep8t5jlwZtGvhEuT6UEhku4FppWCRergEKjwoSygyg1xdxtSs9+hlTV7gMpR7I4CNkNaCTGGe6QS6SJ9DP5c+sWX0oAbJRkRiWbXzb2A8GmMxZyrHSki8Q2JISH7M4ErkJ+AMnT6modM6c1bUlpGiaAnOcm1SL5h8OTcZNJ+dRjg94VO/pDCUXJw/QcfUdybr1wc8p+QLqCO19QYr8noOj+LfpndyFEwtNyeUuggZRipGBdssRh4Tk6NhZLcgTN8PMxySweTIxKzFUI2Biuuyi+mxxjfpnowPYjo+DA1O5+ccaxzm6WqLUpSEb8TPnXVLuLUwcDONhhxVxx4vRj6v3Z7hYBoh5SfnF7hi9Lnm+NWBR8HMkbYxgiqaIQ7Ij3nxnBJIFTB8LZ7bMorOsxxH7N3acK2+/+ghtrPIUuG4Fo6PjjCaMO1bxXs5Pz3HuIzjGo3jcR91wIa8o6/dfR0AcDlfYDlvUz/jJ7EfaMP6O3/n7wAA/uAf/IMv/P4Xf/EX8Sf/5J8EAPytv/W3oLXGH/tjf+yFxmExYwx++Zd/GV/5ylfw5S9/GaPRCD/3cz+Hv/pX/+oPMpS97W1ve9vb7zL7gTas7+sdv2RVVeEXfuEX8Au/8Av/w8+88cYb+Kf/9J/+IKf+vhYVdz0aMl9bFqrHk0OIOy
eeQAgqicWZxLGmk2fnU/0ACVablIx1LxfihdE5+CQ1Ip5UazuUhagUU2Ziu0HGXLxWPSCiLEQIQScWChmrVunMqaG5a5sXO0Yh/ymeNfPlRQZQoFJkK7abVWoLkGhEqb7J2dq+ZUCK6SlC3SkvSbQVdN9KoHbmT85RkIFgOBphQ2/UrclX53YPGF74EY+NV8wnwItNDAY6CExbpabkHoxgYISFYmhQkA+uEebrepsg6eODOB+DvITtpZfjD9Mzn6daTFEkAc2cNRFnOmSM7ttOInubIpJMAAVqF9Ev4IueC9FZqXmFFLpK4Bx8ePn2IwSfInqJ3OMl7ESpYJTPL0sDdJaZnbqXRFijXoFbeBczlY7nXV/fClLD3AFxqMQHaHr15sTM3q8tuXeZVsi40gW2XuRFqlkm6L9WsGyfmLD1Y3x4G28cR2+/YoNud1lju4oRYhKJrC22sqZYhwp1gyyXd0X83GyYYU0wTcu7NESGCa9zM4rnn5uQGrINPzcbjlKtqeL589piTVbyJetWmRmg4boL8XFA1hqoGzHKX/O5OZ8/Q9FKO04819Y2mHFeDo9jZPf63Zs4PIjM7BerDXTGqEzLsx5wQLCVqymueHgfxwcRZd0+i+PrsEab2DGijfIKFbMB9YQCk3aJb343SpMcU8LkjdfuYkYQyijJTjgszuZxjoyAh4Ar1gsfPIjR1/HNQzxf1klx4JPYtSa/9T6gKEoMiAhbr0jZH3wf8rPgqoJPtD69dLxKaEKTCFQ9GkETSlXaq9S30xOB9ggoeQl45xKBakIpwuzIPYS0YVVlTD/sImQE1QXv0wshWAEU2Bc2EQDI8l6aQNRIVdbLRSyJKjR4jnv3Yhg+quJ5/9P/+R/7tBofQOVM0uR5QaBIfiX7i/dpEr3IrYSQNq+Om/pgMMBgEh+YmulC19kXe63ipO78u9/EQupf4gvf+5SyFERdkVcppda/WB0c/51DpVTfaWC/zryGIsFtxus4Hk7wjEVjLz1QSiel3tqKtpVPpLGaqUOlbHq5+p1dJed9z9L992hE5p7rI1NF6nPqd2ufpl8mR+seuCLmPRLxcgIT+pCyuTu/6qVw0jHUDq1S/DmoBqiEZocoSuX7+9p3jPU0S8kBMn2/nlIm0aWJDpfyJt0H4gBgVA+OEhaEohz0qVRR/s40HKXbc2qyHY9up9ScJ7nsWm8xIkWZYR9TW1kMKX+xVVT+HVgYOhZCF+RCwMoTVSv6cNuADYFOLdFuaqKw5GZubUzveb2F5XvmikAQ04W0UT6/jL+rih4IJSjFST7E7OhenPMte6a2GiuhjCNzRuc75ERAlgTJnK3XeKIiWOH9xQlGg3jsimnETd2iJYAkJ8DCho/w8Ix9XJRHySoDTV0wLee15zhfxr6pO9PYdnTz9Qf47jxuWC2d0ycn57hYxPm9czuiDo9v3cNmSUVvspEc35hiQTDYr33jNwEAn/+JL0EdVlDdJ9+w9uS3e9vb3va2t2th1zrC0iqmIe7cjmHxehy9wsvLRcqlSATSOptAEgl0AQ1teo8eiJ6WeMlZz+aZurEllWOUgeJ3rSik5jk6pvDEac6LEh275TMTkte4EU/M5CkC26GXgJfISig4nE0HTVIRmU7evkDs8yJLZLGGwIOsMChJJNoRt14MBrBk5RAv125dH9XsTrQMS+bN9z1GKaIIHgUjUhFKbNoWBb3gjFGQa+yL0dsrJ/s+J5bI0vnk2Ul0mJU960JHGHmAhqZqsN808ASmeMKGu0VAWDa85ngf7h4e4ZxQ+JRS21GzNrnI1Hh0BMxoppCCVql9QiIybQwyEtcKQXHnbQJCCOEtdjANiUvQmFfALzH1JhGngF/wylwqpV4RNnUhpOgogXhCSJ+Tez4oh5hNY7pps2DOyvuUGs8YBSn0kbhEbLspwRB8YiQRDkAdXOJBFukUsyO+abg+snKATrgXeT+1URhzHZ0+jp77w9MLFHwmhSNwsw4YMB12926MiP
Jphasm3uNL9s0d3TpEQSmfbMS106xQUtDwLplMioXDo9N4r5+P+ExNcmQE7yxJ2LrcLlMkKelE7zWGnML2TEA/NQ5GEeZ9vonR2WBUYJQzUtyQF9DnuGLLxHQU52VWlljY+LvHi3iNHTqcPIpRkFmtcXsUufwc20ouVcDjFZkwDggGC+cwTcw+3TTx2Kt1h7qRKIfjzz02IJn1ZZyjqclw80ZM54GR5/n5FZ6fx4hNJXHQCRbLeH11G48xnU0xPWSfJuL8fXj6Ic6bJdz/FUwXe9vb3va2t739X2nXOsIy2QDBeywJnZaaknM+bcVGq+/zvaT5nhpRk/yCUq/UprRWiXW6a/vipHjVUiwvs7xnK6Cw4qAqsaUXPxwUERQBoCbjBHbOl3NcIbgEtgj0Ml3n0DMd9B5JSJFh/FkNBhhNYp3K0ousBoMk5NbSwx5OJ7g6Z50igSn6oqtKePXwKjNF2Kln7PxIUZeIEzYNqtEwjQEA2lXd12w+zhLcn/MMoORxBChgyhxOiuVsIM7yEiGT1oQ8RQPLZbze01LjaMGxrsn5OA1JLM+2IsGBFKknjkDXS3UIA4epsh2Wir4mk7gN6xhhd9aiIyS5I+DFObfDoC7gjJBClxTAulfvuVY7QXl/43rhRhEdCH1EKnyJOvSipRoCltAYspgu8+tbl0L6kIAu/T32GRullUmtDlFpU5g1GGn5NnnHJrHT9xFWVkoTbgWTxUgYfFZa30Kxhng+jyQD2Czh6OUv5+QI1EB2FL+zkChu4XF5Fv9+cs6fTyxu3mWUMYvnMlCoyI9Xshm3nXisuRaYNMBkWaNgxC6MEk3mUTPDEao4b6NJmepU9jKusVx3SWAyJ/vNqByiexyvaTGP9ait3WDEDEEuKgFdA8PrHQrqp97is3l81quboyRFsyKg5GgwwIZr/mwRx7eqHb50ECOx+4exdvbe9x7jg7N4bgZdUJMMju+jZ2SteH6+whFbDo4pSht8h8DamyYDx8X5Bo8XMeoKNeevAe6Ta3DLNTi/nKNtHdz/v2Dtv9PM2YDJZIz3vvM+ACQ9reFoiCHx/23XU80UiUJJXgwqbV6BKQwbfHpBdvzuoBokFvmOxdWQcGNIBLs+hESB41P60cMJbxIMiG9AJRRDeZ42VXkBdp2F57kNASAI0p/TZwkR0CPkuIazvEjqyGx9wrZrcMgxHpCouCzLVMW3CdgRkrTDLtVFQvHs7jOvAvySvIu8HDvrEqihZErHqCUsX9r9Rvh9FuwuclDSf1WJUuiTeB9q1/WpKNEY04AmpNHlGQrSZZUsyNtbOfzz+CB3y+g4TG46HLA7f8nUkTem77VjKtLbGjlfhrnodFVl3/vEgRuYlPaT29V5n5SQnZP0ZX+hMn9KhdQv1a8dlSYlZQFD70lIWs+FyIoB9OmT4AEnDBziCIX+HLKxWuvSMfsNUiUGEVG3DipLB5eNS+d5fx/UDkuJgJm8Q5DNP61ZlciFJSWosgKWpLxe0ol5QDUkiINIuE5vcchnXI7XbHzq3aI0E94Y38CwjPfa6vjLRWjQMAX2wUfxb9gGHI3Z50TGE6uBiSERroylWGPE9QSm1mzbouaOtiYptKoaWKLnZnfjOI8OJom4+wzcxELA957Hl7swtgyqDGUXj33BzXbVbHDvMG5inz6KaUUz7jCzJMwe53gEAjRUXNO38wJ3hvF5//bDCLBYtRb5rfid129FZiF7tYbbxE1J657Zo+B96MgUM56UmBHdPKSzo3yN0She3+wGU4zGodkyPcjrCDbH1Wm8zsulXJODqoYCav5Etk8J7m1ve9vb3q6FXesIS2UBTVfj9p0Y4rZUvDWmhSa0WzGaMkGnorx4jEVepFSOpEAUVEqvCIiga9o+dSHpjBASpLehd9sFs5OviT+W6zVKpg580NCEU8MLFF4nIlxJqagiRyY8fzV7pUJInrW0OwSlkkKvZJWqQmOzESgs1V7rK2yG0cOaVlEq4OboABc6Fm
xb4fhSBoEpT7fDayhhYSrcYwcm3Un6TCVGWhFCzJyHoudcUWrBjDScSJjQy9V6N6UlP1USqlSkBXCtgmZk6tP99SntVJAHzcFgG1isbjIEescTRuBbb2HvxTXTncR5GaLGASOOExKFHjYGGwFbcILn3RVuVpISFM5BjVyRrJgRR6EMTMLLcE5VHwFoeqiVzhMjiUSmMRrh9Ccczk5KUEnkFhJgJuc9yj1SikWiPWgkcI6hGrVVDot1fF4s5UUO2xx4EovlB3xUlqGDlQjXMXI3DjnX8YCpLagSTS09NwYGki2If851hkL4AqU3R/UqxSLzgqzsZUAce+Smh8jZkvCFgxgV/If3z3HFVOtgLYSzQMXU1+FxPPDN4xLb43g/L5yQ9xbpuZrKK9C18FQoPXkW18TsoMTrY8K9mfbqNiXmW7Y6MB263Vq0V8zkSHJjoOArtnJUMe142SlccD4kbVeogEKkbsgyEUKDE0Zbjqntt27cwANyLLbkFAy6gibU/fnZQ9wcx+N89vW3AAAH5gjnp+x/rOL1bv0KtyjiungeYeZVOULryNTCtX/31hBH7Ftcr+P4daUwpRCk3COrMmw3kmZmulNrDHicwREh9tkAjykyOzpimcAFnD1bp8zRJ7F9hLW3ve1tb3u7FnatI6yusxiNp9DCZsGCYJHrxIUWUsEne1GUELFmIOKLkvff1Wywiem9/52UdoLzvZxF4nhreqlvfr4ochjCPV3XwDvxouPUO3TIBJqeJMRDYrTOyDBvigyGHfNS1tC7jZcyPtfBUoZA+mld57Chd/P6rVgs3R7dxPcYkQgT9QskhrtM6okErp+Hfk76z/WNzVJrs1guYm68tWQvKICwoIfNlgNvfV8f2+l2TbgPqSPZDp3ta4hAlN9IzcTkUESWQ7PBdKs8Ms7/lhHFoBiCqX/YCZlJtMdrOubi35Um0LKDyqIn2Vk2rJY3sdhGD/zGUV+HlDVlEuUEUouDS3B0lcYqwI08NyhYC5PIqOnanrUDvfVBVl9TfFl23vu+/CdDUUYh51yXSSLGpNqklhpV16ER7r8kfGrg2NipJOqGSQAVkXIPPqT7lGdZz2ohQAwV4KSuyKNYnWFIvjpNsEenNTRlKkSWRa/nCJcxG3CD0dmPv34H334U19bijFF+p2AZaTre14vzJj0wx1mMUPJ6jQWfvxFlQ0ZThTvH5MCcxYzIo4+ucFrHe3x4FFk12uEcCzb/b8hGURRDjAmSeP4sPnuD2Ri/9/d+CgCwZt3yv73/LrYMu0XgYZgFVCOuo8QzakEMB8Yc82fufQlHjIS/cfXr8brbBcJSFAYyPLgTGSxmx1FB46N3n+Cdt2Ojb0MuxMm4woLiih1VGlReoiooaX8VI7Ls1m2Ac3RxMQcAbLs1HhBMwwQBNs0K+YBsICFG5yeXHbSJ97UkgOb8fJskf24cxLlaLC4xPcrhbMBTSp58nF3rDSurhoDSKeUj2kG2tdCkQxoIkWYHOCO9Qz16IOkIMeWj0ROKysvF6OwVWIC3PvVUyS4WtEu6SPKiyVSR3iBd3SWtpbR5OgfFtJPeKVpLIdwQFVdNR2i38cFMlFBQCQGXXmbKoyVaJwRBtg1SX4z0rCwuFgkdWBAaVNf1J6LfirZTlIfs2TKX8gmPuhH2iDiW2WyMy7mId72YGnzxyz060SdwQUggBAGZaKXTixJJWwkAH1BkGkG+w4cs8x41X5uXTHvU3QL3KVx0r4sb19v1KSa5yIUI00KOfBI/JylcU/To0F6HzaETqqXEfmISk0RNBgPXdNBC2pzAggG50BNB0Ku7GxYvFxHth52f8Zf8newV2GWrIHpyh1w49c9Zoe4FvIAggkZwW16SVMdD2iDT2g++vw/oN24jG5v3aQ7lb0HnMJSrAHXJLExS45Y04bAF2sdRf8mRJujzd19LmmZPdNwQSjhpD8LWzAEAjxYNRiSxBlPM3dr3cjwj6cfqoMs4/juH8ZmbjSb45gcRPf
foKqpLdLmDI8mr52bXdXUqKRTcwPMa+JGbnwUAPL6MKMD/fPkOJrdILsuUX+4t8kn8Ts3v1ssJcsqV3CDg4dZBDkfgz5wOxNMWOOT6tc7imx/EOZpTOqU5r6EGcX4PhgKw6XC5jpuD4znsFvAEQgkR75PvnWNN3apHi/i3ximsnkbHtygIdGkMipFoaMX5aJ3Cg3FM3U50vIfvPXsf6zamIN05gRhFh64s4X8AlOA+Jbi3ve1tb3u7FnatIyyngKv1MslZKKYQ1osrjEfRO9BU5HSuTd97QXU1EZMKDnfnBLpnPEg8f1IYR0iSDYn7TOcxotox711yjZUO2NKzLoiccF0HJ1FZ1hf4BWIRGGnlwyGKMdkZrMCfe9izeNNFnmG9iWG9ZeQyLAoUTJVcXcSw/dnjxylVYoSCAO7VCEupFM6kOCf0RKy9qGBPQrsLv04Kzbw3eZ7DTMjzyB4R9Kfo4dSqP6MQmQbvEsw8pQRVSEKEeud+SOq2UoBhWthNY0pC2QDH88wZ1jyuL3GD3uBnKT760YmC7iLUXQA7rmjhdZxf5+LftNZorAguMn3mAizXTOKLNBoDqumC89G2fVSbuPagU+S/ezdeIrBgSnhnugBkpv+O/NRGIZeInb8rc4OK7CeSvlPGwErfT4KlG2QSxRPkEHZTkakPzEPx3pgd/kYBD1kXEkdfKXIwWQlTxmjWMwryTqEsYlRQUgbj9N2HyAkUUCXJo7HG7SlTkCR4nS9rLMhgsqUMjfUaiiwVmqmrMBrBu3gPBV6+CMDT83gxK8LCP/PpN3F/SOj6sxgdPH/i4AgCOrhFWD2AtY/P9ThieWCUw0fk7DM63vN72RD6Kp7vkHIebgA0EpmyZ2lUGxywFPDZBzFSmc4GePe9GKkt+TlkFQLT4N4pzJfx/lw++4DzPMTdW3EOO8MoqVlDV2zh4XuntgoN16BIJW0WC4BR1HgaP3djMMV4FkFbLaOiq5XF5Vm89moW3303JhWOqZSu1vGev3FwgOWKfXVMO3algjkyfRPpJ7B9hLW3ve1tb3u7FnatI6ygAozZhXbzpw59QVk425RKjbnYKQinv+/Iakgu3ph+P5ffqd3jyTmSCJ1OpROxrrOpnpJlCttNdBVFXtt7n7x38UaV0bDCUycRhTEYH7J2YkXCu04eB9VMoDTgBUssbAS6hmvI1qwIed8sodgImlhBTN+cnFxnpXrP/vvWt/oIVaLU1DQd8Mrv2q7D6IDw3CuKHepeJiPximWAci9+t+ss8r5Yx7/twMH5l7reIDCSqRqPLMRlfsX5KLMCJVEXW3qeHzYNZstYxP9ff+wNAMDD1QW+9TA2eor447rcIiNIQZrOi6JATa/X7rROpGhVmoC9TQ3NImET1yTrp9LWoHWSmpEQSu3MvdqpUb0cde3+p8jYD4oqCUyqHVaLMoE9BHAEWJ6vdf2YcrYSuIpRZuvTMxdknEEn5ngNlaI2YW03SqPkeHIR5iyGKKgeYPkq8l2LomLEyazIh48fo+KzO2jJSN6VaS7Bte3WY9TncYwVa4DDQQUvkvcUWQwDjVILICaaHVZoiXR4lwwV5994iPEgXvsBefqO3xpi1cZnCKk05qFm8XMZL3I+3+I3v/UeAOAnPheV1L/8xd+Dx88iCELqPl3WwXF8E3Y7F22LipkhTdHD55cdHpPX0Apj/bTEIdEPpRlBEcDy+FGst53Or9CRW9GwvpSbgDuHB/HapahblDDCo3kVa3bHNyuYGaNkMonMqgp32KIjrQ7YltiyHjc5YAtL7rAmSEav4+ffuH0MczdGZycfxfH95uOHGB4dAJ+U+QbXfMPS2qMsC7hVfKm0ZCjIMp0IaYVFoihKdCSP7BU9Anq6ox3WgJdIRuF1n7YRyhqlU1Fd2CFc273yArFdlwAdeWH6gjP/nuVFr5gb+o1Pemik6K60Qk6F22obF/OmbnvqKSl4di4dPMlaNG1CeHkvIf8alr0SGaUTspDBCsUM/s
cWQujnK6WEQrpOtZNCzNnt79nObjuPasjeFmqMua1/dS8Mauc+CBLOwxBMI46D8gFBUqT8qu0soIURocVgQHkJPozjwxsw1LJSTbx3p63Ft87i+niLtFm/9ye/iCfz34jfJYmnCwpViC9ZYd0os7yfa77oDXR6WWcJrtcTuvZziaRLBgIBTJZDEzAjmmZxj9wBlSCuWaF1SpIivpcayQTRZbI+fc3nomlqaI6vZHraa8DyHPL8aEuABgBFBgheKa+X8j06T8+GVhGsET/L5yvLkJkXn7VyOIQhuq6WfS+41PO2Pv0ojrU7R3knzvkV+3/e/s4aFZ2EYcyu4ebEYchdhG16GBVZUiu+YirKeIPBlL154/iS32qPJa+94wQ/OW+g2H91NI2/+8wXBnj9OKbpWqZKt41Hyet8yrWznm9wefEofu4inv/Tn34N+fQg/s5ygLXDqCHoietvWJUouBGds1fqfNVByK2OZpHiKFMWuY/jOx4d4vZBpFqaktHl2x++hw0385KEvqu1wylZOQ4mdOTUMsnd3L4bN2bV1DghC0xuSQbcGjx9SsVkSJ+Yg2avWNNxfhFQEuQ1nEVUcl6pRDOnCFZbbhSOxhM6TPE6P872KcG97W1ve9vbtbBrHWH50MFZD8+0hHBxaZP1vVaiJGx6IUWJIrrOJgCDTkJ0/R6e+kAy3RNYiCcdgJZ/LykBUG+WyBhRiBifMSpJLASvUZUzHpsw0irvC9cJ6u6hxBsVNVpnEzy7INRVTcawdfTeGgq0eaVQcgwCNFFOo2TPVX0VP7e+2qA6YGFUSvEqAIqpgd2IMzGwvhp3pRaA4CElffG0oXfSqiL0aFW6JzmF9dym69sL5OM+vIycB6ASq4XZiWolPSXnKooCm1X09laVRTGNEemAQn9dt0WguquSdZIpnJDp4N9+8yEA4Ke+/CaGJEddkzGg0hWUiHnSw+6aNgFDBGxgoNMNlZq6yTJ4K/PFKdI6FZ0FUBLnlPedKVKPkIA1KU2okCInuTdaKZQpbU3mkaZDEI47gefnJmUIRqM+zZ0AGwKWCRaB0XFG0IT2KkXsIYhKdpYiJ6367+/GY9JTKIKbxXCIjt/ZCnjIAGVxAAA4e/I4jq8ImN6K9+vqGYUDP5oja+M5vvjZ+EyNpxqe6/zJk5jebVqFgvNQ8HVXtgqm4zNyxblv69Sr1pJE1o4zSGfIo+UcALB6d4u37sZIJidzxnrZoGWabnkaz2sv1gi8dx8xSrroFrh7J471kAwaWaew5XVsakbVbUBgL+VGx6zRGksEsosccj2vFwsYRsf5pETO99qIWZmjgcGkiNHR7PguAOA733qK509JFn6PpNDVGuWQasaHsYerPlHYXBIgNiGvqV0hH0j/KNtBgsW2FSCMZKEyBCEiZ0bnYrNBxrV6yUxXfpDDDgzci0mH39L2Edbe9ra3ve3tWti1jrCs7ZBpwFAS2oSYzM5MBrDQLt2YAS7JWohH7mybeN68MEkHkzxTsyMc+LJMuA8tmiZ6IENGWMFa6EIK2PQwgkpNzJ2zyNikK1FeZsoEuhAGcp2pPkpJ/HK9TPuQBfvhdIxaQBTsvs+0gSYCo2bEkBkNkatohbcuM6mGEWzvBauXcdI7db6wA3hQaVwSCaoUEe3GYQIfT6wVNqCro4dVsphfZwBSYNfXxpKXLuQRzsNLFC01GeV6xnPOWTUYod4Q0GEUOkYBRUm2/noN0WlnwIZMKwTWP56QK+7f/NpvYrMkZHcUI60sKChyyLUS2TVbZJzfmnBqaI2MkO0hH7Ntu4SXcYuvmJm0VqUeFbxP0Y9LXmvfOSzQ/2BUioRlDlzoUcIqNfL2wpw5wShlblAKTyLvg/Xhlbqhc75vyJYWC6WgmNXQCVC0w+hhVGLP0E4iRAunWIMl40FWjlBbab4lU/kwBzZkdCGr98QUaMj5qDfxc5/61E20bBEJefzdla1wSj7Dk1pqiTVmjKanzDxYX6NeSJYlftcYhc
mtGP1cUebgqq2RsR1ElwL77rBYxpp54Fguns1Rki9wQlBCuDXA5RWbw7U0Bm/w5Hlcl80qRj66c0luRUQqO++x4vMssKqLxQYmj9mRysj9UFht4viefHQGP2GWpYvjOxjmKZMgYYxqtqnNw5BE4MZ0ACvvh/hpDI9u4G7HDEh1xkOscHAzMn5YxPfY4ukWVwSpVGQoGeQK1sZ7d7WOc2CqMgnKgq0CX/pfv4j1QMFmAn35eLvWG1bwHibP+5chQ3SYYuetKWkWDcM0mxR/Q/DI2IclhW/vkR7QJCfkXAKySHrE+zYVnKWnRjmVUjgl03ad76DZ79DaLiH8hP2ga11iSUgpSJOlVJVcmzYZHBFDLReXKQrkI6KJnDSqqLT51kKVo9vUr6G5OKphkVB4mZP+H5fmq8dPoU9p7vaqcWACMtAIPQozMSz0AXwmaDGvUse+aExVsyHq5xy/7D3oU4cJ2GEtOhLrukxSUi6lTT032zwfYHp0i+dQyDiOhgS8frtFYCpViTZaYxMhrYAf5qtLFILcFPyNDVAEWxj2mgyrEldnUQk3CFlxnkPxOBmP663vdbOk7yWEXqVY5lv3IJ+cBWqtelonYSiJoAuOi+OzCAidAI7i74pMI+fmVPJ4uekpoaTXr2m71DPmU366Z/QQ5yT+B9e8bFzepRez0jo6SQAMj5MpBRHoMiRxRTZA6HZSqACKsoClDEUppMwYol7El3Gzjfft7heOoIn0XK/i5882ayy2XN+HOY+XoeL8H5I5Y5bnuDiJx7vaCnBKJ6YRQ8olXDnkBOfcuxXTgJPSp+c+jOIXZjdK1FtuFpTnOLhdQo2ZAltGwM7oIMeGqsJPSVBdGWBKR8qw7ymYCkDcPO1a+slqFLx3GR1gYzQa9sY9XZ6iGsTxDIZ0NpqA1SKeZ34eU5Vl8LjHHrAJQVwAUHOdX9BpzpxCJtRuK/YiIuBDpjeFlePsssaCVGueqXKMRwicV0HUhu0Q738Yv/vavbcAAF/81I/h8cVTWNFB+gS2TwnubW9729veroVd6wgrenEhcfCJOe+TRy+pK610gpRbtQODTuKLu71D8ThdYpEwCW9gJZzOChRMvVl6zXlZYsNQPqcXWRaD9J3cZInTb8PUUesa5IzAsiCpOZdgxSlFgwCTi8vJ22bylG5ULGSGLE/sHeK1GgBCZiHRTQ6VxPdUYqHoO977Pqx49pctqeSmmr/vAQA7XAtCzptUcJWClZQW53wym6E+pWTCbsuBHE6iG7/DGpH4A1WfMhRlYu/6Pqfg0ncSV6MPKRUkhfY8M6mnSEQ2dZlDS0TE6KazFhWJPUfkaatyjcs6/v31m7cBAFebFo9Pn/I4ZHHQLqWUUwuA932vXcoGRImGnemN4o7SNid9WH7nO/ycCn3fnERJXeehqHsh/TNdiBIocb7i55q6SYwku7f8JUHkyH7BKNM2jLTCFsVQJFMUCrYsNARq5CpHTlVh4Q3cepUUsCX7YbxL/nYnF2od/h8//r8AiAASAHj20WOsScpakRTWtSGp8o4pUjgYAzm/UyECZ+4dT3CTZYQnZLdoqgw1z+zZDzXyBjmfyYwpsPnzNTomzmZHTHHOAjT7r4TxBCpgckAJllmMUJq2wdXli2nuajZAqyhdouMzUI41SgIx3FX8nSodNElys4JRpG2AeGhMJ4eojpnuJZNFPTfYzmME9tGjGGkdDce48VaErgs59mW9xKZhNMg0ZpVbGKot6xUZPYZHuGQ6dAOCS4zFbbJpjMgUc35i0WzjfR0TzLFer3uS4iJ+9zfe+QayqaOa+iezfYS1t73tbW97uxZ2rSMswKNra+RslBPWCu/6uoB4/UqpVNzUqVCN9DsnNRnfRwUv9l+/yLpgsgwZvTSpS2RFgS29ks12y/PqFxp5hcEgy3X6bkFRPU1Rx9Z3O4jufqxFXr3wO1Xk2DKn7xkWdM6iohCcSYwMOWoWsrer6HFNRjM4MsfXAmTwGlrg1LtetR
TzxMKusPsuqEIiGKkv+VT/krpagIPlHJasKWmdoSRUtzmP86YL9Upg513ood/SlJuZVF8SAI3yNpXgPHxq3FXCpK5NL8jJOo7SCs1WWMk5Hyr0TdCMWqx3SXamHB3EzzcbPLj7JgDg/p0YYb3/6DHgn/G8MlYN3/URPwfTCy66PqpO5UJxPh0QRAUm0Vr2K7RnIwEUI3CJ5pzzSa5CODFDZ1FKTYlh7bbeJm7KfvWHdMLA6GVrXao/Sn1QO49CmvVth9GEwIobcY7a+RbFMIYDnvDsxgEN1+CYcGndtWhEHJC1osXlPMnKvPHGmwCAjz58jLPz6O1PpF4MiyHVDW5P2T7il1iRUeX5nOKOdobjESM/gqS6UQZQoNJfsNk29O0nHz2L0fLl0wZjysoMbhJgAYcxhQ0PlUQUGyyWrHUR3GA6jeGGPJrSEjOuIMWiLZt8F3aNYz4bh3coYqmHaBjqPq0jAOUgH+HofoyWsgA8Id/hzXGc56PZbTRTNiiTjHHVASdn/J2O1zavtzilCONiFcd8dHOGMTt0qgkzMFmG3MT5vTHI0jGqPDYHX53FsT6cn2A4EcIGvk/aBmbCaFbNAQCnz9bQz7ofSMDxWm9YudHorAXfOSgY43aNS/1VkjLJTV8ITpkG26ex+qK6Sy8E+ZuzLiEHw06/i/RaSXZKZ1rqytjWS46pxEbomLRJhfAELsiq9DJx6XgKyktvFDcsFdKJ5EH2vk2yANKD1tYbtIrIJ0mFmQotwSBscsdkcpg2LN/EsXaqThu4TSjA0GtXfIz0iKS2dolbw0vACaV16qUSxFrTthiTrqm5oEqrD338L4fYSeWF1LvhoelsiO6Y0n1hurVtAjgIG8VwNEy0SoIjaLsNnGht0alQqpcNkZybQ0jf6biTj4aHuDuN6ZAbh0cAgIutRXUaKWhqwV6FnlxY9pdMG/jUs9KnqpOumoAXtIFiT0siAA5IPVkuzXkvJZK0qDIDw/S1yN8gzyEPjoBbttuml95JqWHfIwZl/bVtUtEWVgoFDyXML12HSkAeh3HjuKgX6JiWFAez6Vx68Cpe5/b8CpsL0gNR0yq3Gd75xn8FAJw9p8xI5jG8EY89p+aazjxGkn5fxU1qbAzGZWSmeLyNG9x3HrXQb8U0V5fHjeZ0VeOKQAhQvmM2GMJRGXhOZ0ZPK2SkIGogMigDFExzgntBc7WCvaLKL4ETZTbEhI1ddsOeqTBAxxf2xTPSgGUFJm/EjWgwjnO2bFZoiVidX9FpOxphyL6+dtWiPmNfFSVR3nqzQsfSw9ENIviqCVZ8sM7P4vm6oNFRPVskjE4ezuFuMI38GgFRbYGOhNpTvp+qYcDzZXTMFuu4ng6OChyScNjzeLfu3MDwVpzrtovXkW8t7Om2T/N/AtunBPe2t73tbW/Xwq51hIUQYcYpIsolggpw0mtFr1AHi5cJG6zteyD6YyqI/ytAAWt3pUToUYaeM0+83KZrUdKTbQS6rRWqit5+3WBLbrIpUxZ5XiQZil7GI8OOxiRNw/novrUdoxDX4fhGDMe7LaMza7G4ZETH/hroHIqpCkvvMbOAordXMLVp4VEnqgn0P8Puf9CSJG6Pq04CfuljCl6Jgq3cG42cKSH5eNt1GHHeMhaW7apLgovJ6fe+h3Yjl1Mk8Iba6bmT+x9UP367wwcpQosd039t27fbC+Clyk0iIc6ZtlGuTSG6F7h3niFjuuvW628CAB5eXCaePCUsH77pmVd2+qtyAkREHTsEBwthYJFp1lAEMEh06fuWwZTA8wpoRf2Y96HICxSEdBfC8KF136/F9de0Dfqbt5sSZMqSc19kOXLhHySfX9t1yFVcn7oaYcD7ORzFSGa5XuKSPVJDns9aixnbMrzAs7/3GGEzj+MnG8SsylNE9+x55BdcO4UN+6BWTGdVFVAw+1ATGv/6W3cxey2yNzw7icSz759ewJXxu7ffjISsul7g8iLKd1DHEYOsSc+7pMpv3p7CDJ
iZYPQ9GE5TX5exccyD/CbObeQSLCdxbVyer8DXAt74dIzER4cFPiQZrGRYyukIzxfxg+vnMXoMnUt9dQWfn3azxfy78XODkOPmLCoOv/f0HQDAfL3C4aG0vUhkHXDCtJ8tJOoaomTmRcRcN6sVBgR5+WWcq8WqwZIR1gVBGse3DEjViTt8doc2R30RD/iE7SpqbfHml2Kke2sU5/zd//wu5s/DxyVuXrB9hLW3ve1tb3u7FnatIyzfxd3ZCTxbKtS6B2CkSrUKyJi/Ly3lElqbmk4ltx+07xtRRTpBhcQ1KHUfZLov4vMUBnmqTUiDZlEW6Pg5o4se2isAEe9Q14SzshE1y0wvjyLcdEaDva69fIcP8IyYhBm8bbvEcZdVcp01HNkAtHjfcIlNw9Kz21qbPMkdvHrPOKH6PyVGDFrQSJ54ilp39S+SiKVKDBubtfAfOhhCeyt63KtV13teydkPKQCQqNBkCk7gt8JCYtudJucMmp+VSCJYm0KTQmRIBiO0jdTP5LQ6gWPkOvKyTFFKR5mJx2dPUS9jIXy9jffyv7/zDs6v2KxJz1MpnZpnkyQONIzwI0rU530COKR6pDLwEjLJDwN41heF804rpLpQYtXQqmdf59rJMgPL5l9Z5m3bi5zugmoSSInFO+csupbRQNEz0dds6diuN/C8J6MbsW6hhzNo1uCkZlPmBiN+f8FIYnu5AlinfEawRK4CPvd5thBkcX1cLhYJeDOZxHvYLmtoMspPKe9+dPMWNtJITaDFwGb44IT1G9a97z04xvhTZG+YzwEAzWYLw7HeZqRYZDrN1yGZ18d6hm8//AAA8PRxvI5ikMMT0JFz7rfKYXwznmN4I66hs4vniX3m4JDX4VZYEgLereP5p8MKM/JaloyM2s0Kzz6M371zdIDDuzGS/NxnI29gNQw4OozjnjEMuljWaC/jWh2S5n48PoAKzPw8iPOmcA8L8o5ePo/z0S3XGAwF+EW2l9phRlkTee6bsMWWUjSzz8X7dXDvGPkwTtxqE+cou6kwHQ7hXcDiGxRz/Ri71huWVgZlmaUivmxcRmuYTFB4cZJCcNBUA5aCtnN917rmy6Bp+5dZzodJ7zAPSE+VNjodR14GmS7S+1nSd/W6SS84XRSwtWjaMLwvQiryG2E8UEBNRoekrGtM2nCzPD68mVLY8gWZ80HuLJLelKQEK6Ph5WXcbNL8Gaa5Sl77NC/QMN+0lZ41eXvHwaSfQfUMInFOkdgv9I4T0AOACCKA6vWtiIpCpuGNaEvx5T3QcNyMBciiPHr44s79ECSX9LPBAxmRl7Wz8EThSY+ODkBAz8oQz5EjgGlcvsSU7slgV1cRmFKNB5jO4kMdqGFhsoAFN6y33/5mnNPhOH3OSyrXBbgkWUMgSKaSs2Qg1AIK0ASA7G5SuaA+RaImJMSlMF5kpld+1qp/vDdcd5KSnJQFcjJAbFvZSGwiQhZ3RIdePVicgCzP4Xn/O7kOo1GIkq1r8d3vRC2oRx9FkMS6GGE0O+Z5iLIsM2yu5gCA+VmUrRiVExj2t4GEzvXiEvU6PmN3b8a01/zqA3TSZ8gerptHJSpB3HFTsdkIz6mIu94QeTd7C7qI5zvleatQIyd7CnEWuH//dbRUxZ5TPbhZNGjkfpLoNnvtTXzps1+I5918CwDw/PICpubzd0gE4XgAkEqtIVoX2w4T3veOCMKrzuGoiGlEZZiCq9doazpZlMTJ2ww11+CzswtUv/HrAIDf+7+8Ea/t4gm++24EYAQSCjetgm5533k8W2/xqU/Fe3P3Tpz70+dLPHwcwRTrRVz7WjvcyuKGNeEGvm02OPmQ9HACUgtAy/fka5+O1z4eFDj73sM4b5exH+7Tr7+J58Ucznp8hP+JN6yEWmodglcIRjYlqatoCEuQJywuQCfeM8vfuc71ulP0Hv0OcslJTcSpV/L4Xvn08ky/Qx9Q+MRm3dddwku/B6IGUXipORVB9fBjgeeH/iUinZxe9Z
8TWiJvQ/qd09JIi8SunWig4FMDtbxEnfcJYhp2qHleUAVMY8AL44sFRRm+zFt4pfwVXOg5HVMRJfT8eKletXO8nUgrpDkXpJxHEFh46jBW8HzQvfM94s6pdJx+rvvjBZlL+bnT0Lj7O1k/Saertf13+R3fuVRTE8oq73w6jhTwvPJQ6qXzuvDqfQivIi4D+vptklLzoZ//tE5Cv2ZkTbQezvD+E3YdbEhzmX660B9bxoT+vidqKIT0zLkQUiQvrQZOuZSdEH5JawAtv7M796u/GE5VSI6TTbpuO60r4iSG0K9vjqVtLbr0HTpAzvXN2pxf50ISoBQfzVqfzptQmKqvvabrsC4KtWIHIRv6OUrn9R5y8xJFlwvpO/27KMR7BiQVgLCT6Uj31Ye0JsLOGFs6ep0NfZO+vBOc6s+9My6Z145r21q3cy39NXn/fca/sy7kp4xV1rRtXWrfSPPWeTjr0+/DJyhmqfBJPvU7zB4/fowHDx783z2Mve1tb3vb2w/JHj16hPv37/+Wn7mWG5b3Hu+88w6++MUv4tGjR5hOp/93D+la22KxwIMHD/Zz+UOw/Vz+8Gw/lz88+508lyEELJdLvPbaa33v4P/ArmVKUGuNe/coBz2d/o67AdfV9nP5w7P9XP7wbD+XPzz7nTqXM9Z7P872sPa97W1ve9vbtbD9hrW3ve1tb3u7FnZtN6yyLPGX//JfRikqb3v7bdt+Ln94tp/LH57t5/KHZ/+zzOW1BF3sbW9729vefvfZtY2w9ra3ve1tb7+7bL9h7W1ve9vb3q6F7Tesve1tb3vb27Ww/Ya1t73tbW97uxa237D2tre97W1v18L2G9be9ra3ve3tWth+w9rb3va2t71dC9tvWHvb2972trdrYf9fRu1zumY0sKAAAAAASUVORK5CYII=\n"
},
"metadata": {}
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"plt.matshow(batch['image'][0, 0] / 255.);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "049FEFBq-9i2"
},
"source": [
    "### `clu.preprocess_spec`\n",
    "\n",
    "The module [`preprocess_spec`] lets you define preprocessing functions from\n",
    "simpler building blocks.\n",
    "\n",
    "The module provides a dataclass abstraction for operations that transform\n",
    "`Features` (a dictionary of TensorFlow tensors), and a parser that chains\n",
    "multiple of these preprocessing steps from a single string specification.\n",
    "This makes it easy to try different preprocessing pipelines by treating that\n",
    "string as a hyperparameter and sweeping over it.\n",
    "\n",
    "The module does **not** include any preprocessing functions. They need to be\n",
    "defined by the user.\n",
    "\n",
    "[`preprocess_spec`]: https://github.com/google/CommonLoopUtils/blob/master/clu/preprocess_spec.py\n",
"[`clu.deterministic_data`]: #scrollTo=Nx6H0936Q3N0"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 474
},
"id": "n_jlWJcSIZf_",
"outputId": "890f394f-a00d-4612-b40a-343899632765"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Constructing tf.data.Dataset tf_flowers for split train, from /root/tensorflow_datasets/tf_flowers/3.0.1\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmYAAAGkCAYAAACb5OmoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9W6xt2VXfjf5a732MMedcl32tvXddXHYZGxtjXICN/dXh5PB9wQcn4uRAkiNFiAeLREFJbB2Io0jwEEiejBIJJUQIIkUJeQockAgKCSiWIUacgAGDPxwcHC5lquy67utaa845xui9t3Yeeh9zzrWrbEA6UK7UbNKuXXuuscYco19b/7d/+zcxM2Nve9vb3va2t73tbW+vuLlX+gH2tre97W1ve9vb3vZWbO+Y7W1ve9vb3va2t719idjeMdvb3va2t73tbW97+xKxvWO2t73tbW9729ve9vYlYnvHbG9729ve9ra3ve3tS8T2jtne9ra3ve1tb3vb25eI7R2zve1tb3vb2972trcvEds7Znvb2972tre97W1vXyK2d8z2tre97W1ve9vb3r5EbO+Y7W1ve9vb3va2t719idir0jH74R/+Yd7whjcwm814z3vew6/92q+90o+0N+CXfumX+Ct/5a/w0EMPISL8h//wH8793Mz4vu/7Ph588EHm8znvfe97+b3f+71z19y+fZtv//Zv5/j4mIsXL/K3/tbf4uzs7M/xLV6b9uEPf5iv+7qv4+joiGvXrvGt3/qtfOYznzl3Td/3fOADH+DKlSscHh7y1//6X+f5558/d81TTz3FN3/zN7NYLLh27Rr/8B/+Q1JKf56v8pq0H/mRH+Ed73gHx8fHHB8f88QTT/BzP/dzm5/v++7VYz/wAz+AiPDd3/3dm8/2/ffasledY/YTP/ETfOhDH+L7v//7+c3f/E0ef/xx3ve+9/HCCy+80o/2mrflcsnjjz/OD//wD7/sz//pP/2n/NAP/RA/+qM/ysc//nEODg543/veR9/3m2u+/du/nd/5nd/hIx/5CD/7sz/LL/3SL/Gd3/mdf16v8Jq1j33sY3zgAx/gV3/1V/nIRz5CjJFv+qZvYrlcbq75+3//7/Mf/+N/5Cd/8if52Mc+xjPPPMNf+2t/bfPznDPf/M3fzDiO/Lf/9t/4d//u3/FjP/ZjfN/3fd8r8UqvKXvkkUf4gR/4AT7xiU/wG7/xG/zFv/gX+ZZv+RZ+53d+B9j33avFfv3Xf51/9a/+Fe94xzvOfb7vv9eY2avM3v3ud9sHPvCBzb9zzvbQQw/Zhz/84VfwqfZ2vwH20z/905t/q6rduHHD/tk/+2ebz+7evWtd19m///f/3szMPv3pTxtgv/7rv7655ud+7udMROzzn//8n9uz783shRdeMMA+9rGPmVnpq6Zp7Cd/8ic31/yP//E/DLBf+ZVfMTOz//yf/7M55+y5557bXPMjP/Ijdnx8bMMw/Pm+wN7s0qVL9q//9b/e992rxE5PT+3Nb36zfeQjH7Fv+IZvsO/6ru8ys/3cey3aqwoxG8eRT3ziE7z3ve/dfOac473vfS+/8iu/8go+2d7+OHvyySd57rnnzvXdhQsXeM973rPpu1/5lV/h4sWLvOtd79pc8973vhfnHB//+Mf/3J/5tWz37t0D4PLlywB84hOfIMZ4rv/e+ta38uijj57rv6/6qq/i+vXrm2ve9773cXJyskFu9vZnbzlnfvzHf5zlcskTTzyx77tXiX3gAx/gm7/5m8/1E+zn3mvRwiv9AH8au3nzJjnnc4MP4Pr16/zu7/7uK/RUe/uT2HPPPQfwsn03/ey5557j2rVr534eQuDy5cuba/b2Z2+qynd/93fz9V//9bz97W8HSt+0bcvFixfPXXt//71c/04/29ufrX3qU5/iiSeeoO97Dg8P+emf/mne9ra38clPfnLfd1/i9uM//uP85m/+Jr/+67/+kp/t595rz1
5Vjtne9ra3P3v7wAc+wH//7/+dX/7lX36lH2Vvfwp7y1vewic/+Unu3bvHT/3UT/H+97+fj33sY6/0Y+3tj7Gnn36a7/qu7+IjH/kIs9nslX6cvX0J2KsqlHn16lW89y/JRnn++ee5cePGK/RUe/uT2NQ/X6zvbty48ZIkjpQSt2/f3vfvn5N98IMf5Gd/9mf5xV/8RR555JHN5zdu3GAcR+7evXvu+vv77+X6d/rZ3v5srW1b3vSmN/HOd76TD3/4wzz++OP8i3/xL/Z99yVun/jEJ3jhhRf42q/9WkIIhBD42Mc+xg/90A8RQuD69ev7/nuN2avKMWvblne+85189KMf3Xymqnz0ox/liSeeeAWfbG9/nD322GPcuHHjXN+dnJzw8Y9/fNN3TzzxBHfv3uUTn/jE5ppf+IVfQFV5z3ve8+f+zK8lMzM++MEP8tM//dP8wi/8Ao899ti5n7/zne+kaZpz/feZz3yGp5566lz/fepTnzrnXH/kIx/h+PiYt73tbX8+L7K3jakqwzDs++5L3L7xG7+RT33qU3zyk5/c/HnXu97Ft3/7t2/+f99/rzF7pbMP/rT24z/+49Z1nf3Yj/2YffrTn7bv/M7vtIsXL57LRtnbK2Onp6f2W7/1W/Zbv/VbBtgP/uAP2m/91m/ZH/3RH5mZ2Q/8wA/YxYsX7Wd+5mfst3/7t+1bvuVb7LHHHrP1er25x1/6S3/JvuZrvsY+/vGP2y//8i/bm9/8Zvu2b/u2V+qVXjP2d//u37ULFy7Yf/2v/9WeffbZzZ/VarW55u/8nb9jjz76qP3CL/yC/cZv/IY98cQT9sQTT2x+nlKyt7/97fZN3/RN9slPftJ+/ud/3h544AH73u/93lfilV5T9j3f8z32sY99zJ588kn77d/+bfue7/keExH7L//lv5jZvu9ebbablWm277/Xmr3qHDMzs3/5L/+lPfroo9a2rb373e+2X/3VX32lH2lvZvaLv/iLBrzkz/vf/34zK5IZ/+gf/SO7fv26dV1n3/iN32if+cxnzt3j1q1b9m3f9m12eHhox8fH9h3f8R12enr6CrzNa8tert8A+7f/9t9urlmv1/b3/t7fs0uXLtlisbC/+lf/qj377LPn7vPZz37W/vJf/ss2n8/t6tWr9g/+wT+wGOOf89u89uxv/s2/aa9//eutbVt74IEH7Bu/8Rs3TpnZvu9ebXa/Y7bvv9eWiZnZK4PV7W1ve9vb3va2t73tbddeVRyzve1tb3vb2972trf/lW3vmO1tb3vb2972tre9fYnY3jHb2972tre97W1ve/sSsb1jtre97W1ve9vb3vb2JWJ7x2xve9vb3va2t73t7UvE9o7Z3va2t73tbW9729uXiL1qHbNhGPjH//gfMwzDK/0oe/tT2r7vXt22779Xt+3779Vt+/77X99etTpmJycnXLhwgXv37nF8fPxKP87e/hS277tXt+3779Vt+/57ddu+//7Xt1cUMfvhH/5h3vCGNzCbzXjPe97Dr/3ar72Sj7O3ve1tb3vb29729oraK+aY/cRP/AQf+tCH+P7v/35+8zd/k8cff5z3ve9954qw7m1ve9vb3va2t729liy8Ul/8gz/4g/ztv/23+Y7v+A4AfvRHf5T/9J/+E//m3/wbvud7vueL/q6q8vnPfx4osO7eXl029dm+716dtu+/V7ft++/Vbfv+e/WamXF6espDDz2Ec18YF3tFOGbjOLJYLPipn/opvvVbv3Xz+fvf/37u3r3Lz/zMz5y7fhiGc0THz3/+87ztbW/783rcve1tb3vb2972trf/v9jTTz/NI4888gV//oogZjdv3iTnzPXr1899fv36dX73d3/3Jdd/+MMf5p/8k3/yks//X//vR2hnjoMu0PkWgFZafNugNoAFjIEhK+uorPMaAFPB4xHpCdaQEqRomAmmkJMCkM2hCjHDMA4MOZFyxkzxvvizrkmEmXB0AIezQ46bSzT+kHlo6R
pPoCnXWUcjDrNMjCNqiSYYbWu08/J9XeeYNSNda8w6OJrDwhtOMykOpPpcThI5D6xz5sXTkWdOek7HAVXBuxkAY5qj2Zc/6siqmAlKAxZQKW3YupakGZcVSx6sZS2Rh0z4C/fOAHjrr92hfeaUTg3zYAiNbxExknMYuTwX5aYmipkhIrggqPObn4l4pOtQ8diLa5YvnrKEeodisf5p6h8DRkCBnm38XXauuwgsBAaDAVjXa9J999b6p0VYYXT1uc7E8AYm0Bm86Thw8S++m/S1b0cffH3twznh2TvEz/4+nLyInN1BV6f41RlutSaty+EhrwdczKQBRCBliKn8nabn0vLsVt8j13cZ6rumnef19Trq3wbMg+PAlbF1KMJBULJF3No4MUgOVMu9F/UdLzx6A8Rx++Qu2TIqEMfEkJWzWMZWb+DF8Z43PcbBjWvkJmFDQs9G/NldUhxLW/Q9mkZySqCKJmOMYJnN2BpyedYH/4//B3zVE4RB0bQkn96lUdDlvfpWuY6NAGpkHXHiMQ2kOmfdyYt435JNcGmJBEduFnhT3OKYoesACJevk970EOHWLfzv/A7jIiPtnOa45fS3/08APvWJz3A3wP/xf/0rzL/9b6H/8zPoz/9/IN2FlEpnATaOWEpYjORhTe4zaafvprGa65h0tS9bKX2WaqcN912j9bMAXOpmhOMDAJ578RaJMsYb4Kz+7XbGsANmOFo82ZXowZLMsj7LOI2vzbVsvvWNobzes2yfS3d+R+rYyoDiQPz2c8vM0M210cPs9dfgO/4CAC/caHCm+JzJpsSxZ0gjfezphyWptpiSUVXGcWQYBlJKoIaZMeQyBscUcRLxXnBO8F7AlUHlQ4uTsgIkEZrgyT3ktaM/HVmfjQwrWK/qyLozMKTIfCEcXQlcvrbgwrVjZsee7siRtXSSL6+KmZCykDOoNuSspKF8d7+KrE4HVic9pyc9/SoxDJEYM1+R4X+v3/k19+BwgPkAB4m6A8BaoKnfM5vPCN2csDhA2gXaHhC6A7RbILMOk/JbkiI6DFgeYRzRfkTUGC2RrIxTVSVpJmVlHSNjGliOa8YhErMSQt0b244QOlwTMBFyzqhFco5oKv2ThoGcIyN5sw5PfT5ZqH92161pXYIy9qdRJ/XPJaCjtEVYwNkbSpve/soF8qYLpAcP0FkZjUF6goygy7K+AJohl6WGIcMqwlkPZ2tY1gccIwwKpxk6gfd9Gt72/4U5dX+oG0ez6IgXPfgDmB0Rjq8QH3kU9+VfRuPKzdK9p3CfeQp+/0leuH3C+jRyD4g4IobVt53aJ1LmWhKhMeOtNNx4+9eUL7xxiTTr8I1wdvt53vCxX+Xo6IgvZq9YKPNPY9/7vd/Lhz70oc2/T05OeN3rXsfBQjg49My6QFsXEW8B7x1GIMgMc56YhXZMhFgGQ8wZzOGcESTgkxJGAxyGQ+vOkpOQVDgQpU+eMSf6cSDHEaQMGAmKbwzfOcIsI03ES0Jcgw9C68toCGJ4tbJZiiNnD5KIZEIdMK4TujZgPjNIRFMiirFw4LxiWhe3KJzEyK31wHOnA7dWY9kEzOGkDBhxDvOCdw2Nc2QMVQUTxHnqZbQ+kMxBzGj0YJ7g4JI5rq7LhL5ojoDQiSEC2aBRxWEkMaiLm5kiIhgZNUOCx+MhNEh1NZxGcjJyewBHC2ZjojtZF8erPtMI9ALOymQWgcbK4Jf6OZQFIDtoDQ4MZlZ+17NdDIOweVcojlcwGDA6INYJds0cg1O8Cm+66Lny/3wf7hueoL9yg+akbIvjH/wh3Wf/APf854hupDk7xdZL9OwMN2ZyXxe30ZDsSTmTrGzkkspzTOi1Z8fhkrpxW9l/vG03/unRpf7H6jWdGYf1h8cCC1NUDOchaNncU333TVusIy54upQwAXMw5IxF46De66ANpOQIQ+R4tWJMJ+h6jfWZENdQN884jKSU0FgWy6QgWt5DrMzFQKYHZvdustCRPD/Ej5Bjxu
UBG4rboDbiHIgFVBRCgfyxQO6K06LDEj+sMefw7YzUL3HxLt57/Kwh5TK+wnOfI108Qm+8jnC2xK2eI9ORY2LxFW8AYHjmGX7v2VOufsXbiWY0v/8/iacv0szLhjW1uk4dIpDUSAqDK2NsWjgb2Y5bV9t7Vv9f2fZ1ovSrq2Owk9JmPo5cdiWzbjULnI0JUZiL4My2jrtsB0SD0iKIGZHt4WNy/GDrmLVMBxEYU/n/ab0xBRGhxTbvQL1X3hyl6rshCGUsTWN1vhppcvnF5eGMRoB1z5jL5pb9tO5lZLNGCOoFpDh+IRqigmXDSxn13gIaijMoXjEB7wQXPM4p1S+jdQ2NBwmBwRmmDlWPTZMKWC8TXjNgaFRijDjn6LqWrvOT74lzjmyKCjR4LAU0Cxoht6Unm+DxrrR7ThlTxarTey8YJ7ncbGyUkIzGQ0jFMfC173z1sLtxZBY8beqwJgNC4wPatEjosMkpFsFM0TFhTlAnOPEMKZOsvGOy8hwDRhBhbaW/ejFGMXwNioWsNKI4yYh3JCugQ0qG5vJgwQyta+R06Ojr3/XrCHVaOCvv5SiHMaOsT029Tqz0gnOw0HLdHPAB+rZcNJs36KxDZ205EZsgKGBYbjdjUC2RXbmXSB3rGRi3ayTOEVxD5wwflNUjiaNWuDQaAbiq01jOrDXTyQEcXCW/8csJb389cilgTxUwgtUR8qa3g45w9w7PO8GrcVcUby1W3dXAdn0VKRPjkex4/fWHObj6QB2oM6xxoGvc2dmmX7+YvSKO2dWrV/He8/zzz5/7/Pnnn+fGjRsvub7rOrp6It415xLzpqNxrrjTlE3CSYv3DucNYYb4TG6UMS0AiLaujtWsLk6R0AlCQGyB1gkWfabBUBXmTcKnhHOB6AsSBxCalqaFWZOZtYGuAclK1IyM4LvqmLUe59Z4a3E09NmzHDIaM1rRN2kN1KFiJFVUB1qnzBrBWyanMhiWp4m7w5I7w8jZKpOSIgIiWlZ/oOtmOGlw4jAJiFkZ8KIIIyKlPXNOZKBtAt41pChgmc4HgtUhp9VpcJSFIYOmjIhWR7OiYXWGaJ3c4hQPiOnmhGHeQYzocAoHR7gHDpnHSBoSugNtuYr4RCsLgcFm05wWiLonYDuLyEBxSDaT1bZIwPQnUgb+iGw2srI5BR51iQe+7u0MT3w9YX6F9tk17qmnyzM99WnivefoGNGTJXq2JK97dNUjY8Yq6pSjYerJWchmRC3oSdx55izn/39DKLiPWDChT37rL5Rly6qjTV0wtWxYIuVab6B1M52aNY0jYi05KuaNxgWaNjBaxE3HYs0gcLZeceU0IOMJPvYQDR0HklVUYzByRQKzFX9Np++SvHmVETh59mkWd28TL83wzuP9jJxXG+TDkhakJifMDOeNHBNBIHTz0kahg/VpWfzCQXFy+lOkm7HuT5gN5RBhfY9+aoVffwXp0lXcIhDu3IH+Lnq1OIIPvPWNxNu/h1y7Sji5i50+g28MqS+S6iZFVlzOaNZyYqeMx8jWUUqcH2uuora7C6tRnDejbGQFgYLkYchKqg59CIGUUkHtTRCx0rZsN0Wph4+KfW/GUb5v6Fh1DJNBQpk5YalGdDDujKNAcdSiCudZLYpZ6R9XcWZli2CrgZ6t8Sd9uTwekL1ieSxIjGqBTy1jWcmpjglTVHJxBGrkwqw6hlZaTQS0RidCAnEZazzBebImfIW3nEFWR9N4nM/41ghNWTtcPb01jSclh6iiSYl9RpMCDmdCqJEWKScDvBNwLWTQCIlEqMi0mKBJiF2i6Rp8P+Kc4DzccY5n6hx6MRhHta2yg1iR68R2zvushDHiwoBvZhAUUUHVqkO+XdxM2fTNdChVzZv5n3M5ME/Io6mCKpZL5CLn6ZhXv5sWVMgGKadywKr7J6bnxoHWtSmzPWgoZU2W+rPEFj1DtnNjckZtZ3wqIAH6rvR1bmZk6Vhriew4dXgxfI1c5ViRqZQYc3HGhlTmWJ
9g1C0ynVTICOKM2C5YX1wyHmfSzY5AJNY36DQh7hr22NuwR96GvuFR9JELaLhHSmVf7z7zDHIxYPMDbFRSPaiIK+05tdDIdAASzjC6DNdkzvzB69CUcTNqRKxBxkha9/xJ7BVxzNq25Z3vfCcf/ehHNxwzVeWjH/0oH/zgB//k93GZWWOI5Q3KZRbBBEeHiOGdK3+b36AVzgXUKYZg6hAnZRPwC4TZxjFrVTESYzIg0DSJWevJGnC+LPJdZ3SN4Z3RtY7WN5gLkF3x7H0ZDG2Aw7YFa4iDI2nCBohJWdeV0pbGymdwhmoijgnThLNM1shQ4ebVKnE6rhiTkZPQKIQALrA5YYQccU3AeQNTRByN8yBlAbbaGBYhi2HikcZDEsQU5wJe60KpUiaoVIdgWmFcmWjTIjidmHDgg+C7Ft8EUmNIX6e1lQd1OeLTGjc7wN24Qrp5l+VpmRRST2JSoYaNz1IRhA2KVJ0VkzI5ywZWVovJ4bKdk9wUxpzCSR6DujBrVh5pMg9+7WPwDf93wmIGd2+TPvss8gd/UPr69DYxrSCONM+fkm1F7kcYEzYqVvdzUxhjRmsoz3TrFG6cJNu+h+680/3mrC6Oev5zs4J4lXYQzEFAUAdBrW7+hnMbkIscRw4WhxzOF2Qy2QqSEWSL5KkawQvDOLI+PYFhCVrCTWM0Ul20UyxDIGn9GyHVoPa0rgvFUb714vNcu3MTu/AA4Cvytz0xqmpBes0QczgEQzFN+Iq0ZAAtG5Ca4B3kcVUQwqGjj3fq2DFmd3r0d05xb3wbevkS42xGywKr82z20DUe/aoGmhlhnRhtwJwnqJUNSusmlRI5juQxk3NxTKaNakKYNpsN238nzqO0uW5W56m+ZROOwHpdwrXtvCVpTxAhmyLiUcvsbquujnmhDKCIMMrkgE2Hl/Ikm2cTwbSE7kfb9o+rEKwVmHOz0RZUWqsrtu3L8jt1gzZgSIR7y/Jt6YAxKy5HUk5kLU721A6bQ5tq+f6KsGp16lM2xgmNTUrCIAsWwLIjV8RdvNuSpk3Janiv4HJZ6wpATxorvaIFW2VyqutYDaNqnOGlpaU4/q7xdN6RXR2bIoiDLBHNZR10FshpYJgn2mWiaUZGrziXWTnl+VBe8sVWuO6NRR0Lu47L1I7RwKdIiImQtIQjU0LMFUrNRP3QgtChUpBFQDWhqqhOYEQqY0V147CVdjbMMjodihU0GV4V8QEzI2lGLW2ukeoUTmvldKBQOT/mpzEx/WzXpvGV2YY1p38jkGeePC/trmHGKB2jdbjckq0nG0h2qDqqP8+QlXEc0CwkNcYMQxYyhtSolCeAeVIreN9xeqjceuSMw1sNMxuQ+WG52fVrzN78OOOb30Z77WH8McgStO8I1x4GYHzrLdpf+W3ShTmhm8G6ZxApUQ9sS0mhRHUmxHoR4HB+Abc4xqpjJrnHaSaPPTrF2P8Ye8VCmR/60Id4//vfz7ve9S7e/e5388//+T9nuVxusjT3tre97W1ve9vb3l5r9oo5Zn/jb/wNXnzxRb7v+76P5557jq/+6q/m53/+51+SEPDFzDnDiU7HwvqZYJZQ83gJiIs42sLdCZWHYgHTABLxDQgdjWtp/AzBk11x01UdZi1dmxApJxkJAec6fCULzFuj84ajBYlYMrI6nAQa71jMCiJz0DUspCPnzCCZYexp2kRWR6pHjtOzDGHAuQL551TDAFkZstFXzFYzrNZGiuVEBbLhd7lK9GzMl7CDi4hFMI9zLVhbkMJ66lRfbpi08O5MAriCqrUVTTKTGjor4UXLJVqlUIi59Yg0na/FO5pZi+saEMFpxEINiyYQy7jG4XNE4og7XnBsx0BJ/04nAzGDuoIQpJ0T2C5wpPV031JDlA5E7Rz6ZLZFzSYexHQaxAVy5Qrc6ISH3/M67H3fgrzpLcj6BL17C3vxKWRVtPXi6g6yXkEfcZawwdARJBYuyhYOA0ub/yXV06Zum2r7DPUZ2UH0dh
E0lS0yWFsXh21Dn1DD2FLDExnRwk/aff/yD8WhNAKNb1A8IsKoWnkdBTnzzhg0sVxGbD2QXUEgU9p+Z1Qhq22QylRPkbv9M6FHJykx3L1Fp5nkHM5SQQA2yJriTXBWoTvREnrXTK5oEmqIC+X+MYGVMLpFpXUlGQbAjwadI5+cEJ4EiW8mHhyS3EHhhgK4lvmj18lO4PYSdzaQKxZkZGwKZWqCrFiOlXcFGUFriHF6R9t5V6lIpuw0RK5jwFkJMU8cmSksuYwFE1vMZ5WnZxvu0iY8rNvxkKVyd4wNSjmNsV24UjfXGT1GdiUJaELuXIW/dHrunb7T3dC57IRftx9jmuG0IAA2DOTWsJxQzaW/THDm8BJQ2eEpqNtAMaaGqpGzkWK5JppCKqF4M0UqIp+dItiWn9M4ME+MsY5fJThHDkZoSoMNzhAPLgs5KnkEjSXM57LR1DCl9y3qDHFgOHxweIPsAzlNuI+QszHMZrTzkbZvGIdESg7TzK162ecWjocGYzFCGLdI+RT2goIkxWyklAiqJURZEWzTaZWicgRKiLKs9JCn8bntyU34uCRdlevLaNaKf5Xwp5iSVZFcmJLZEomM7Sxe01ie/uQp6lCbfXftmcb+xCebxt7udZ4d9CxAmgXSvFBpUpgxNnOyaxE3J8WC7Il60IZUXZQxG32OrAcl5dIspf0ooVvKGuhEaJs5jbYMC8/yoRH9P9e0Adrrj5UHevz/hn7lI3BwjDqHCxHWiXBzQB8opPzxHV+J/v7v4P7wGXBGA9wNhYu9m0zmar82dVhfSI7Dg4tw4Ri1st40KpAVHUesX/MnsVeU/P/BD37wTxW6vN98MJwkVAVXB6l4B6Z4KQT3MokF58ImO6XDqHxQJCiNa2n9rDguKgS2I8q5QDfrCulTpHLFlFBX3oMWusrkTnkgDX1JNmgDB/OWRVdDntJBdqTRCJIYk2Otgia/hfBjJsYy1MVlnPcE3xBc8VCkfmdmTa5crZgzoyRyBqeFbFq/kGxGsAziEbTAvqKY+c2s8QiiRraxhDEUcNXJm4BoNciV22Xga/bqNC0nKFup4VTvUAxLqSz6WdFu6h9w6hDnChctDqhFugsXOMgHm3YY14mo5zf67Qa0DddMG5zsXLMb7pyWuJ3I2eaeqUlc1dI/r3vHA9hf/mbk0bfC6gz3R89jn3ua7vQ22hZewOAHfFwxrnqSRrrUIFnJsUb7atvnPG1gQqzE6lybcZenMTlgzrb8rEQNG+y8t23eT+p4nnZMmz5FNGPmyFYIvFMSgdgO58NAU2TsB7xzdPOOZnaAec+QT4HCP5KUiM6xTAmfjNHVhIS03QqybN9rIp/f75hN7T8AJy88y7Vxhc0voJoIFlCbsnlLR5c8EofkyTEzzGp42wtuNi/tncZC4m4aHJ7l6h7z2viWYsnKajx2cpP4rGd2+RG4OEOnh+uFOC7xtz8PNwfCsi8HlRTLil+fS3KqjsYUri1/T+Ed7nvniXxvtuWS7Y453RmUQcrYXQtljgIzhUaEtVnZDMxe0p5q5caJwiWMO31gnA8dTdd7Jl6c0cqWS5bEaAFxUp2B8x23OWjZ9mV0++MS0j0tG432PdELTnM5wGpZj8QcHk+SiWpihbtmrj5HddLUtkk9CqYCzpNzRkTwOHKy6tAURzY7wVtZs7s24Cs5PATHWMOKvnW0bQsDpNSXvswZTSWEa9OXUugeTjziHcEFGufJvmQcT5ay0Q2Z+aJjXI+MfSSmgVkM3G1KPz6N49G5cbi24phFNtm2GxfPqmNtpZMsa0mQsIxpYiJjmCbQjOWI5ojtOGBmWxdBVcmm5JoVa5UzuGX31l6z6sSq4DESiSQZq6ek3UPGxKWcDsQbfly9ZuLNYeczMO93G6dx4wEaGGZCnpWfxiYQuwbXzskaSNaiuYAblj1av3TIC8YUiawZNRfengMfPJYnLmSLuIacHakNCA3pYsO8Gblw8VF42zvLw7z1HehjB7QxM5
6ekM8ybdthD4TNQ88Xb0Df9dXIUz+PjwOtF1pKtv8unzMwZV0bTYZjZnB8kdS1hGXZNzKF/pNNuW+WfUF7VWRlfiELjeCDQVRyXWzEGYInBAfOcHQVdUg0buITKOYjmmYkF2kaYdZ5HB4dAhZCvb/SzRyHi4bgBeccTdMAiqsZkkezwEEbGCRh2jGOHSlluq7lwmHLPJSN31tLvx6JTvAW6FNLFzO9DRsSkFkkj+V06YLQek/bNATXQnKkuoiccQazQONGlqsV63iGGcyy0YSyAGaLiBoue0LwOCvupkjNdNKyRBiFwOSMwntTXwa83y7gqorTrRMh6urkrhvFlGkYykQR7ysXQlFROr9A45SpZeQgEGu2lRc6Sww60lRC6GKxYDmeMuTikUwOgFX+xzS4M9vEAEdB83ZRKdhyI7QiFlJP/0MDFwZ467U6Bf7C++hvvJlmfZfw9G3yZ/8ndu825CXptPBo5mcKYyjO81kmhbx5Dvwmo59BqQkJtkFa7t9gN/wM2GS6Tg5kvu8dtoud4XYSFnzdLX1JW8F8XfCd4IVCZmcHRMmU07MpsR8ZxxHpGiwITTNlgRkSC79pGMZ6Qi8LxS5nZNcpm5wC3fkbzi/yd2++yJX1EuYXMFeW6bxDslfL5Xdz2bHErPCfbIvkaeggD0jOjEPP3JVEmjb1jBOxd+HoTldkLpGdJzz/LEiC2UOEWIneoUP6m4y//9u4XsjxDqS4Ja1PpLy6YVreOma7fUrtr11n29m2v6a22l3IJ47jaCVjOAGTQqPGVAjuOWEYoe5+E1o1teku163u6+e9wJexCEQ1up35sz0MGEFchSq2Dztdt/s+u5/hiiMM4FTB+0pI327JHnduzJakgIlMbxVZKy8xZW76bGRzNA5EWvAV+8kjzjU1GQFSypi0OCc0zmNOcc7hPThXURTfYJYxzcWRTEa/Guj7EUsLYs0gdeLxoUHEYyKIC0gINE4wHzfP3sZA1zV085Zm1hBaRxgcc3XcrY7SCz7xfAPXWzhqSuZpRs9xzLKWxAA1IWnGW3HA3KQL4er6XB011YRoRvKWQ5YmKQmkZGaakXMuZH7dumQbB3saSFYYUZlcnAa2+yeyRfh2D5O7PLLdQ69R+ZP1AODkPPIqVI7jhPi1oJ1HujIGtHHYrCV0HeMaxLUQDI1CNE+qa1i2ESWRGOlTRhWaxuFc2O5lKoh6tIXUNqydMFsseOCC4r78/wJv+uryUDdaVFp48RZ+uQRxpMWM9vAitrpXvm+RkC97B+51v87sxTN82zJbDyxrW+Rtc22Q50fCjIcfeIh8/TLhLG1g8ykTVzIE23VXv7C9qh0zcYmRBnUT9Q5am9M0grcSnkm6ZsShriHUPN5gAdXAwEjWUCDz1OEIKCvaOmgOuo7Dw0DoWhonHHQB5wsk3Fb0bR4cc4ED54nREYPDh8xs7ll0M0QLZBt7uDDv0NZx6s440USbPH4IxLHqX0lA/F1a3zCbLTiYH+Bdg/cd4hqGviIMYUFs14ypx/xtdKWY9DgfCaGcYLuwKOn7WtLRs0aCbyvhNxToCUgaCwKVPYqQpYQfAoZVR4l2Riv3StgzgaSM8wVgkJbtOiyQyXiTgutK0cSK0m+S791YNlsTkNBg4km+o8mJVFOoF1fnHPUr4jJxZjubgrqyiMj2+0LdS3oge6akwu0GUjcvPzk+Vl69i/CwgPzFJwDo3/wW5hHSs09hf3QKt29iecUYe6QpznW+8QC0c1iP+JMT3LOfYzw7xftcJBFqpCzsTFoQjJJcnWSLXmXdLoCTmWxP0uecG1f+pwGw8v5BwCot3CyQQ4OLGVPoxfAqzBCWbBMh+hzRvidlJWbFksKYSA6kq05e53GmtDZuSOciJQtqgu1h64xNp2QmeQc53/ap9t2ds1vo2RntoZDbruTp16xMcS2WejSv8aFF8CXJwCVkrK3oXEV2E0LRZhpiYjbzKIe4KTM4BLS7i5cEKZJyprn1IlnmxJ
rhOes8eXGErk5p7p2VN0gB8z1h7HHV2eg1kYt/WDJ+fZF12HU+8867bpZc2zqs078b27ZXOR0YWQvqVXONWA09XeMK+l2TKeYmnIpN05WZlbG+NJiZ0Lsi4+EMTGSTnKGUMT/1xVi/e8BKhi+FAuCpCJXTGgpli8ruOGhiVa6GIgeyFsfg4cp4u4wbexjoMF+dau824z7PWqxKFTkcMa1xZ5GzRhg1YpKRIPg6uJat0I1rkjmc94gWKYvkPKaJWVPWJR8DuIw1sB6NLrTITDnph4KSAePcaA6MrFoyKscimTEMa07XMy7OJu2N4hBYAKvRkcZ5vHdlEwcsOkaXaV0kSIeTBu9DPQgl5hO6I/BcJ7yhgyvrjPfQ5a0WI1SZFQWLqYRVVfF5JBv4zQLHJuEpVOJ+pjjBRXqlDJykCUwwyziDkB1RYjlkGZskoSqDiUNKeFHSFk/b8aZ2D5MbJ8xkkyCgdRxMSVpT2FuQelgo102ofWsFmW0BnUE+SPTzMnP8vGFuDesQ8O2Az2Cjx0nGed1ER5SOnBIaZ6hXomSaHApdqSoRmBiaI/Pk8GcRHwLXLz/M7Kvn6PUvw107Li8zV5o/fJrx7AQ/ZkQbGAyNq02YPHze0b/pQcLjj8OnXySMy818OSe/RPG/HgYeu3iD9u1fSXKBbGvamumapMVZGVPZ339Ef3l7VTtmI0qbe0KAQBWYFGi8x2mDipVWM+opoS7eEmgaB64jj+sivCcBTRnnhYN5aZbjgzkXDhvmbYd3cNC1iFOSJdqKTM0ctDV7J6US1mtaYzb3dM2s5MQDo1MkOeKYaLtAOwRCSOVgVPkXWUdCaOm6OfPZIV27oGk6mjDDrMXVpX4eRlKasRrPGPNI0pExprKw+J2QAYaaFgHBGsdxFgBF1W2uy4W1hGo5PZk4RrVNuryIlEUk1NBIPTo7EbSGzcp1pblVC9fDpIRIbIfBIEYJy7pQtNam01RULNQdqg0cHx/Tr2/T6zblPNalYaKYZJ3gdsGbEfI5mtfm72khATaSB28GDt7zteSv+wulHeYt+fZdwu0ey5+Dy5eR5joiAedLiNVdvIxdOIBhjbt5G3vqczQvvAgvPgenJ7hQHOxmjNAbg8K6boRt3fS2EaOCOga2obGJD7Y78Xf3yGljd/V/NijKxAGcUBRjwyyZfm/qa031xM32HqpFew1ASXRBitBGlSzBtojjJtuV8+jY/eHiqa2n6wdNnN16nstXX49vHLhEqMj0JlPfrIQovRVHS3bQ0YoUFETQoc4Xj8J5xAtSSX0xDTjfgCQ0DgTvGfoR9+LvES6+rjzreIiGiGeFNp4hCAecMC57NOdN+MSSMdQwVHLCmA2rWa7TOJv6zk39IxvFmo2ZbJ27qfXOhd9rI0YrSIBBnYtlk9tt1nzu/+0ckrFr0xiZfmfKDNzdWNTuy7JkGoc7D7f90SZcVeaU0uQinVNuVugN07siVsNjhSsoVQvHmzJ6IbUecsYlxWImmaKu9GETDaNl1Agp0kptV++JzrOuC0CbcmFcmiuosTjwQhOKPBBAE4QUHHhfw4WQx0zsB4Z+xSqWQ3iIka5pCyVGgLYIMOMFqamBzhveCyE4fBBC42maQNM0mAa6KpczpMzzwXiqg+M5MJTtZ5eGOumA9ZoIsUfigGtmhbdXeaJlbBQENyvklMgpklHUdJttaYZamR/JCmXFLNe1fbtObF2Csibfv05s+p8tyruhXMjLHEJqNGJ3uAjb/WDSOJt0/OYNxDmkgw7fVo5Z18LBnNa3OJdJCskb1oTS9pWiINIQzKO+QWLAq5G8YxShysyV9xVw5rDseEQcj6WLuMs3GK/OkbaOr9sruHeG61eQCy3Jesg6bFAuJ4H5LYe+6S3YtV/g6LMedUZUPSe42wML4OFwwOFb3kB2ifa52/DgRaxKUpll4hCRcYVLu7/9he1V7ZgNKeLzyKJp6CqCRYV5hYCZ4p3QeEcjDXlyNAj1tJ5ogqcLDU
48SZV557h4seidXT045PiwJVi5x6zzmCiZwMRlnzmhEUWsIaWAukhoMl3b0IUOq5B0VMXcSK8wtK6GjkZU04YrYGTadsZ8dshidlSdswO8a1HzhOpte8tkHWhWDSZaUC8ZcB6kooLO+cIbkVTg6EogLdFQh9qWjWLm6p+yDZgqg5VUZAATV/hFlcUpuYYNs+HOjSDZ/NdUtxIPWk6hQNXa2q76E5dInMNVryXlHmk9Tdfi1yOucs0mwdTJVLbKyx319MYWpYFKyqRc5ygq+FdFePBtr2N456NYLNyq+TMDMSraXqV5y0N4WnTekRHsrHoOroHDY+yCxw4j/uBR3IWnkflnyfc+D+uSvOCHNc3pmnC2wg+w1on462hqm+ZK7HbsbGZTf+yGk2zriML2/3c3+hLGKLxBs63jcP/ZLEMVOdYNFM/0VRPffSxjSHB47zHNG6RRZZuIMYU2bXre6V47DsH0fgHHOit3Pvckl77scaRZYDJsxkEJcxSEYkxaKkdUhVZfk1RSSmCF+C0ulMEoBq7QFuKyhJubxpN8W2RwrCf1a6SZ0Qw9+d7N0g4zgy5j4y2StoR2wZjvEFImpsww1tBVRTnWwCgFn5zaYXezk+lv2To/BuecmymkPrVX0O1mOLVp3EGpcv0db+cR1I0TJWVTmMKZG4dqx6Z/xhp6mZ5zcuYyNfxedTJ2nfhdm15j45hJeYdWBa/TmlqTiSr/ibpJFq5q4TOVRze8lOSjdkjFAal6MEOlarhhJPqWNrtS9cRnolcsC0EDjU2LTi4OrrkyLkTwUkS9J+e6CZBDkd0g1pNGAotKHnqGKnIcx5G2naHiESfbg6xIOUhA1UwLNK2nC56m8YQQCCEQBZqhTKLRIi9I4vcXcJQdMihHCk3ahsCcOZwoXhU3DjAMuC4hYyLjcPXgb7mssyklYi56Y9mqXtlugk89EE/tr8Jm/Z4Og9MYdGwTCSbHS++bwyLbw2IJeW/5jrtnAcfWwbeK0E29E6TqnU3vPIOzRWC9aJFZ2WPzbIbNOxrzmPdILn3amFBWjno3SWTfEDgguVgkQwxGK1UagFK9hoy6OV1wvCU2XHm+ZfRXkYMZwaqG2B89T7zX07ZgKZI14i3RqN8AGzkN5M9F8vWG/IYHWTz/WUQdy7FnAdzcaYMHgesPPoQ+eAWGFUlWyHOKXr5UnmvRlH7xDpvUnf8Ye1U7ZkUgD1JwxOmEkY1sSiOZpgv4pmUeOnANqcYMUpayqWRoxG3CIN3Mc/lCywPVMbs0m3E07zAzQnA0QdCqTh0qJydIzXLThhiKAkxoHE1oaEOz2Ul9qKeaYPjgMCcMKTLEobIPisBk13XMZjO62QGL+UVm3QGaHWqOri3eoGkPWp5xtMx8XDOkU1wA5yYnwlGon6meblyZuJYwc5tJ5ilomCh1Ugm9GuuUNhkx0viNxhFImYDsTOCdmI33hVNSSldZEYNUtqEmKbRos6I2LSqlbBFGMyEHychNx+LyARduGnmIxSmrCwubxa0sErESqYN4spX8ulQd4t4yMyslOa4uHEcXOtpHrpMffoTGZuhzRYl5feEGzSOPEC4cMx4dE1Zr9HSJe+YWdrdMwxR66Brc8UXc/ID80AINAbGIhIQ/Ke1l/QrfHNLOM74/YXZvzdkQOUM3YYVpYx7F4WtvGJs96pzdt99udYLqCpmgCiuXf+9u4v6+f6sqSWoke8e5mJ7LWUGK8GwcM8vb++wij7A9SW/4LPchOLub+u3nn+aN6zOsOyKbblZsxeFFCjdUQM2wbOeEiUHx9V8FWWwKKqOCYKRYSzdJS/bzQuJOig5rvA8MUdFVEbRuLhm+e4Cc5qTbLzI7uMCJP2CeT0r71BjuYBCdo8/KkDlP4J/e674NbUIbXs52mvwcMjr1zwhFQ3DHAS4u6xaF2OGh08tLx8bL2aQxNzl4u79Sfv88KrcJR08f1l+YxkCu874xm8D+glZYph22DoMilUPlGWsokyioN0
Iumaeua2icEIbKrwJWXpFcjhbOObwqWYXkyxraVIHZ1G6/Gy26Y0FccZbShLQorgUJUlFfQXNZSFJK6LrytA6KKLirLxSyYCnjpClZDJT54IMQQtg4aE3T4JsGlYhrt3paa0vcNPhcUrrD0s5HS9gsz7ga6lN8GpFxgCESun4T5oQyX3PljGXL5Fp+afoZQM5pKxSrBUnb1Xy0nXUC2VI6gmx5ZLuDYpeOsMlot53P6t8bjcX6c7/zOZT7T25IC6QO+rmn7wI6K3tZnLVkFKLiNWE5V260EEwJ9eGjOJi0NRuPqsflOr6m78URkrAk8jZxvGV5hNxu0WtHzJoOVuXgzM07peQbypgT3rWIxpIBXjXRXGjIsaFdOdKXvwf57E38507xsUQ/pqmxAB71C+TRh1HfwEEgPnyM/P4pzUk59Ks/KtUCpGVbJ+CL26vaMQuuQcyTY2CsIoANDS548J5ZO6frZoSmI5tnnBZcAacCzpFcKKTOoBwvZty4suCBCyVl9riZcdCU05APhVCa1SNWkw6AgOEwmiYhVs4VTYDgi9O22URDRKyUzFGEmJQxZlLKG66AeKFtW2azBYv5IfPZgrY5QFUw9YQK02VryamUPuniQNfOymlvR4rSpJxgnBQJBauelGZBNW3JnoSaDm/ll9RQVw6XY51h6kB8mYCNc2TNhb9Wj1ybNGotC4rk6X7UU+x2AylyFoUYJjWbxsiMKFI9DYfDGqM7mHFRheHFu8RUeBW7SJCjDHOPI4liXmlTUfSfnLwbzvHwjSPaR47R1z2MXrmKNQfI8WXc0WPopUcA8DcuYEfH5BBo6RlO7+FeuId78jns7p3ynBc65KCDg0PyYo7rrpB9RFf3aG7fxqUyEaOOJCe4rtTE8/MTmpMTZqc9p3UVWVKQDC9aJvkOQnZ/AsP047xz3RSSLd9XUrhdZaFO4SlfAj3nwrop13AB5wnqkzOYHaRktChNCATzjJqnxLmtkyc7ToqVMTKFUXefWYAspe/W/Zq7zz7FpcuPAC3mq+MfGnQYy6Lsqop5LptyrnUrpYrmqhlay4rhQ3Hw8lYNfhgGtAnMm3azccWhx+uIjzXDs/Ho4iLWXEb1RfoXnmJ2+UFS8AyrSI1IsRTodes0Zyl8wSmJBEDq4J+y2nbbZSsHsv3bwaYRJ+GKSZFcKAriJchdVv80xYk2jlr9XdsiktO+OqF20/iA6mibbBzcCSmZfjaNvfsRk3OO2c7zT9/VKgX9rXHobkh0/UCwhphzIdBLEUpNGOt6r8E5Wi2LiQXjgngumkOzkuqAPskNJ5q42WaWztMmCOrIXuh9YiqHs7CKtm6qFhSUt22KQj/ASnqaxhGbDM5ItSbwJjuzjq80RjSOSPBIAFUpYX8X8bKD0OHx3hfUrOvoukjbDgRLVA46FkDUcdYoz87gIMJcIWTwk1pCzpt29gqSMhbXLIa+oIyTsHLOBb1JQyH2a0RTRb0rVJQ0o5qrgzaSNBNFp+V8W3fXtmu1UNaPzdg555mXQ9+UVLJBh3cc9d1DclVs2oQuNyIetg15NgFWC2E8aBiD0Ncs/Rw8Fku1zUUfi5wHxiipZPfWAT2XUp3z1BSPlEOjOUQ9fipfNR2McubRs7s88uJDJHdEe/FC4RCeVsSs8choWMqQIwSHjUaO4yZLN4SWRlv0gQPCGx+nf/P/gPWnuNTD7Z3s9OvA4dUbcOUqri2RqHB8keaROZzcKxedjqWyRJb7p9QXtD8Zrra3ve1tb3vb2972trc/c3tVI2YLvyCEUOL1VoVVmxnzWWDWLFi0c2ZtRwgtWc+HjAoRUxEahJFZ23Dh0HP5aMblRSHtLULDvHElZTfUAt5aqPJ+c3pIeGEj1OpQEI+I3xT0hkIc1VTImatx4KwfGGLacjIoSJH4QDtbMJ8d0bQLmtCWZxSPq+E5JZQM0BwJcUnbzmjblpi36kkFJctbYVx8yaxRJaVMnAjOAnhFncNjeFPMhByEsfIctG
twlaSfsQ0xwSjomJ9GkVHqMBp4X0pdbXhEE+xuqZwevAOK5hFZ8V6I9ZrWCy5FVKA7nHExH5NevFtC17JDvGaLFBQJICMjPNgYD1wo79e+/iHC274aHn4j7uACrvHo7Arx4iHp+BG67kK52fqE+OwSpzAON3FPP437/HPYM88X4hXgw2Xk5AK5TaSUaY+UMD/CrlyDyy9i98oJyQ0r4rhGx0SYz3HHC8Jxy8WTJfPbJXR6+yxiVmQTtuHAL8DxuQ9Bm5CL6SQcrfDKSlPX0BfbxIFN2JFaQmnn3+e4S1CRhIKA+cYTFJzLpS4m26BXqohOrmOoKh9s7jvdc+L3TXUdn/ns/+TSmx8nhI40ldYJARs9Sim4LKa1hu1uI5TKn6X8VEGQRYqYqWUjVFkacYk+ZZIknGtQcxVZSPipbNvZPcyeJB8/wOx4wb27z+OWd1gcX+De3Z6paMpICWf62uJqpcSVybbBdrPatLZ+EjuXKDH12a7ZFLLdCREVHqLhayavt03y9FY/jh0y9oTI7Yald1CNDfpK5QHtxlL/BLajoLAjrbFFqoMTwlBQ+gdO1rTaY2lG1IyJLwiXgJA2GYEZwbuWIYzMTXhkBTfWHhsdzaosJKdZuekSL0R4vjFOxEjiaKPR1vUKCtrid0RWLSviPK0P5KpR0HiPNGUs4xOoI6WM5gbRIqQMMKaBZmwL8mot0niyJXp1zHfnnhS0pmk8XVekM9q2ZZES62aS1SipFtHDzRkcKBzV9lzUx839VFaoIpQ5ktMaGddkB77GPAti1qMpksaBmBMy8URrXDulRCIRNZE0MxLJE/F/Z+BN/ESRLbr6cqHMDR/yvt/VnfF2P0d14ryG87dCKdGjOIPVkac/aIjzGeuulsJDcUNEG8HFvE1gyoVeMyXSZKkRAck04sjiyaGEM319mJQda4NOjDecGfOzOcvjYw4utFhQZFkLlIvDxRXiEiHnkhbtFKfDRv7F9Kw8w71DePiY5uvfh53cZvH8k9xMtlkjrgYh3LhOPDogpB4dM0E6uHJErrp2elJZepI3HLY/zl7Vjtm8mTFrA2KOttauPGgXLGaBrjkg+IZZKAKtyRXHBMqgMs2VzGs0Ho4OhItz46iBg9p48+BoWsXlIu6qUhrXyyawVET/xMA8ahGRjDNXsqp2+sCR6POaPmXO+hUnqxXrcaiZZuVeIgEfZjRhRtfNSlammxEk4JxDpqLPeNYoXdcRhgZpHaFriYPbrKZjlRUoC5YrGm3mN1o3cZLLsIj5SAhC592G9JkFtKozy2xWakrqWDJeZCuUZ8ZmBpfftVrYXKtTaNX53M0J0k3WFJXEGmjQGt+3UIixWCn+ejjr6Gcdy/WwEcuk9sAgoKLMDC4IXDkyLr7xDbRv/opy0WNvwl73COnyAzh3iGdGJNK4Q9Lzp8TTzwPgeyWsexjvomlNePZ55NaLcPc2bj4lllxAzBHOMu7sDO0TYeagOyBdu45UDRzpT2iWp4XPpyt8DLjFjHS9pTso4ehrz54S7vXc1FK/cLOf2hSG3hnosoXpVeriZyXsCNVZVcM5t+F3mKu1QHccM633mVLcHVs+yC6tLWnJBnPOERpD0iSEybmU+YkYPIU77k9igCqvYaWfGjwv3PwjHnnuSS48+OWbUKY6wfuSkJPzWHKEzbA0hWdBU0ItF9V85xBPGcNWqw/UZJb5wRx3uiTGiJeWSABnDGtIY3VlB8P1LxLWS+TiRWbdEbfu3MVdfQBmM+yshDxibeOk23eewpM7GsebuVDYAtsG2A25T9Nkk9kmJfFFke1njlIjs4rLWA0pich2vjA5XyVMvfn3fe0+Od1lE9UtN+/8ZbzMr77EppDpdK2ncIYAwqqEh68/f8qVe2ecpYakmeQ9yXtWXmg8RS+Eols4dg0aHNcG4Q3rxPVlZL42pMb5z8w46hyXBQ474w865VbIxDpPXA1dhZxL3WNxUGtMai4JI1Nou5WqyxgiTeuJMZNjJk
cF3WrpxWEk9mtcKHWTAZILBNx2XbLaF6L4IDStJ8wCbdsSVPE1e7M1cGPheqwb5QVROg/rFq50tX9Owa/AD4AaEiOpXyF0Rdes8onJiuaExpEcSwIAWgqVT8+eVYkUtYBIJmKbMObkUJ8fO5W0Px2ozoUxpY7VbXWLifawe586BMsh3LbjbdfHK/QKoIH1Rc/6eMbqoKVfzOgnvnSmcMsMUioghTeQWspCq0efXGbUjKck4Xjf4cTTuIxUweGybrU8loQ3Li8S5QB/cFwGbL+EW5VjdteBT8TlsiR3TM8Sx3IdEEzQJuCefIZ06OGRN5Le/NXY736eK8PAsr7kxaNryEMPIY2R75zhBtAu4xqQiw+Utm5WjH3EltuKO3+cvaods0Xb0rWexvmNkOusmbHoAo3MoHG0LpRSL7ZFuZwD56WIVqoSAizmjvlcCGKEuvo2HoKrWWmSyYxEGwonJE/p+T14kFi4D40XnJuhJmRVci6+taY1MY6MKbMeB1b9mhjjRtkaoA2epulwocGFGW0zI7huE1PfOGY5kwSC81VQ0SNNQBJMEhcxxepAOkx9KdFTc/FzTuRKiMgxMcrAbObxXmhEyE6LunJ9Lq0K3AEhp4JcWGWDBt1uRk4cztnmNC+5VBDIWbecHKEKSW7LCjkBVGl8mSBjGkuRXzyyTNDMCfMG1w/ndxGBWMtHXfBwrTvg0tc+jnvjo8RHSzHaeP0h2gvXaA6vAIHsPd3Sw4tL2meegdNnAehPzmhaj7o13Yun2Au3QQe0MZg4JoPghgTdCvEO6RP0DiOiTVsWAcC1c6Lz5Do2lQR9pPUeusJfTI/MabrbXHrhlBdsu+G/nJltN/fp37tcpmwlMzMUBd6NyON02tx1+vzUB7bdaIUtIjNVHUilbAbeS0E4XUGxtkTbrUOwyQw0zvHjHEU81VMcRc0wxIFbzzzJ8Y03Ie3Urh422Zcj3lHka1SZtM5yzqAZ7wtynM22fCorxGCAIWkROFZhSAqhI+c1zran/Du5PNjxakV/a8Xl48tID888+yIPXL1Kqhv/aRUWzlUiY3KIjR00Tzjn8MiOw7Tpv+k5YSMeXOttl5/v9quxcQyEilCZvaQYdKX1lN812Wyk5xz8nZtPR8ndjXMaQ+JexqN+GdtIglCdezOojtnRC/eQcMKFGIhmZHEMwXHqS7UUmZc3mHWB5WHh5X55H3jsnnLpLHHxLJHulXvdtkjXKs/NyxfORuPZAE/P4MRN6wdYKAXMpwxFVQXHZk0EaHxAUxFmblxJDsqj1cL0jqaWgUpjZD0O5NHTeKETh09FpDvWxpfK1RMRnDO8N5rG07YNOWd8RYFmWrTlYlSSF07mynNNETOrSfMcWskpkAzNCC0JGZVAjzqhmQ7rJS0T00zWhKqRxrh9X6gpXolEcdAmiYz7nanJJgR+82dnTEwY8CR4LZVXOZVeg/NO2XQ/mQ5p7PKoysLlvCtI2aJhFTxDF8hViy7mkjwXhxEZivySVd6tN8jV6dI4oikSzEgI2TkOXUNrsJZykFIp3/24Lnh0vMbgZnjfYOuEcYrUAxe2IOmAkXDmyXEgq+BkW00hrwe8a8is8P/zADm6iH35u4hv/k0u9J/lSjmDc3RwFbt6ET8uyRohBwwlBUfjSoIeF2c0p0vysCBUkfs/zl7VjtkszDgIjhBmzGbVMWsb5m7GojlApd+UUIlRNyVgGvEMpgxpQEJk0QWOgnAggrNYoWjAhkIa9kpODhvXNG5gGBOjTg7XSBoDOvc4l/HNnJwc3pRsRWcMYBgzsV9zNsLJMpLWGWeKuBFfHRLfhVK3TGrab11oppTsDZk5l5SDrBGvRgB8KiepWKHYrrEiIKquCH2msejg5I5s881pa6h15kYU545xMzgkkGhYzcrw6GfKAUqoav2at+GMBJvFJm/jO6UkixPyOheJjLq4uR2HAraLRM55i9r4gMulfpt4yNoz80JrRRajAtLFYcnl/W9cbLnyVV
9O+qrHyceHm9NKu7iCay+WFK6o+LEnR4HxDmI9biibQTfcI968S2uJdHaCGyKZDjPBudLXLp3h1pHcgnaGy5lheZOWntaDNmXS9UeXkf4Mv7xFjJHiNwjRMk7L4jDzLd3Fi5w5yM+ecpuCKmFTyLiYUiapyJZMO23uU7uPVqqMhqxF8V/ZKHEL20ne1SSJGoUsDtNOiR7Y6lytR2M4yFxqO7RfE+tNJt0qq9dN+lhTmGwKX07PPjkPziBJRmh58skneexNzyEHD5Y+8nehaRlnc0Jc42ORf2nCAWnKWBawbDgpenviPM41iBdCHkn1ZB0TBN+gsiZbxDlPK0csfcRC3TpSxDy8kKEVOLl3mxZY4Xj25l2uXLpcnv9sibMC9gjFkWmnDWgHObs/WcPVdt1ovu1kyE0O8Shl/jTiNptwL6UvvSnBwares93eGpMJ5TRGqZUv6pcXZ2l77eQ8Biso83SfyWEPNTXU1915F+HbDXtOodPRymHQ1IgonUBbdQMuf/J53AycJpZt8bXxwoXQsFh4motlAN09LAesh7ojbowDs9PbHNzNNPeE8WaZi4dZaRbg5hAGoZlD6ErbPdnBva48aQBcLJVMxkYJPuBdKIfPKkuwaDOaWrqZ0K9PoFVSL6zOYHEkHCymNlG8KaSIRUfGM4gnt4INtToAvpRxUsXhmIcZsTHiPHKpd6yqdMhAwPniwLVk1Bz3vHFkmdNLZa1/cWYcuMTxAP0I9yj3tbTGtIFUq894X5JBcsZ2auCqJXKNHSSUEdvImGTd6jfuOl0ik1D3ZoRs6r6eozdYGUzRirxOpCQqsTNuHEVOKVL0Lmc1lWWq3Qr1sLmA1TXl7sXMcuFYHx1w5jvWp6WvnUtkC5hv6Mk4MbCMt0yfxqItBsX11MipFqTTdQHvtGTIuqlqyZw32gn/29llzD2CC8dw1KKNR87ukq0ctEiKBEOiYLMZ2rU0yZCY0XoIT7pET29j84A8/1nGTwe6tzyEvPk96PO3uHJadqHVvbss+lM0CF6F3HqCtGhyxCqw1jiPHF4o9ZXv3AWe4o+zV7Vj1oWGtvW0oaNryrLTNoHOBYJzqDTlFJqLpzAtYLmK81kNEzgHUgr1Fc+8Lakzlg2VzJBCKT6eMlkS63hSkDKKR68RMi3BK06MmBoiI4JunLxhiJytVtw5U07OlvR9z5AoGVM6yXhUSYtc4OkibZFRk5quXJ9fB5KVs5H5XGp3No6gnlw9oCKNAaKK2BRSrHyEKKRY0TctiEjORkpFEd4HjwZPmhS2D+dY05DjUJAYOc8z2GSAKRtU0uqGMZ2uJjuXIXb/ya32j1jVOppg/BJDYA7cgiolDLnWani8gaOv/AqWX/kGmsuXaC5chAtXyjVhVoSDZcTGxDgMNEtFl2eksxNkWY8+qxNYnRHjmrRaEXyLaUSyblYyGXvIEVehd7eKNOsV0mRs1mEXDgFohqtIPCPlAejBima6xzY7Z7YETpkfHXLYJ4aTNTEBoRRKnhC0Xc7GFFLaNubO/0ppP6UiLrYNgU2T3JtuEK8Nqmnn6zGWNPpS4zONkRQcLnhcLbQ80QGcFR6b30F4TGQjabH7XJMwW3HAR9LqHk8/9Xs89BXFeVY3BzeUTCsfChqriVR5hgBeSuw1pYSJ4SloseXCZ4w1o09zIjQNXddhORHHvmSsiSC1YLWrDpqvGapTcek2KQnl2VsvAsURcb7wgJrgcGYbfaxdKzIUFV2mtqfttMMOqjZxvrZSGduNLFOctQlh0wkBk53PXvLtvGSO3W+7uniwDV1Z3bWnzfsLYWYbB89DIm+c8V5hWR/o5u3EYCDqaExpgeCNWTNyfAB9mY6ky7A4POZSWHEpRm4sPeGWcOupU8aqvdnMyoser2G+hNmRQx6oIa2xZIwDDC31qZVF1iJ4XVNNN2uEb/BNKafnQuUnipFiWeuGvtzXzQTNDnJGx7GslTmXbEtXxctdg5ew4+gIwXsaX5yDSTC5VK
UqWXgF3ynI+d2ZIFoOb9YpuB6zkfAspFqXy+uaJrZb4d6moalC6SJSRGMkn+MYpjrmktU/cE6b7Fyk0nb73zbor9v5+fQO00FwmucbTmstkp6kfHZgW9HV3fVqaKE/BrvQsDoIDJ1j1Th6txOiFEdPJgxaDiuhACfJlByVWMPDmUTWRLShhHIbRZpAp8YwpeBr4vVDx0PpAlggLlra2QzvPHrrDL1TnKlmflAHkmIygpSMaw2GW5T+8YcO+8Onafwhyd2j+fxTjIct/trDcPU63C1f2d28hf7+M8Q3Poy2M+a+IbeObKFkHwPqSwa5n3eExWsAMesax6zr6MKMrimLbltrTAZTzIVCmrSSTpyqk5RyKRweK/bvfKZQXAxM0VwcsxgzOWVWGQZdsRrXJAaSrTaIk2WPpUwaR0JTFglJPUELapYqgjX2idvLU26eKSerkT5mxlHR7OjqKuLVo1EZx5EUI8lFvIxYrWo2WSaRLWGWSsTdGeYUkYmeWxfTabM2JVV4PedS8kU3R2sP5hANaK6kUnGoD/RViT8vFqj3lWeyXeWnjWbaLe53GiaOyw4lr1RjuO/iDd2skllFc002KOESL4Lzgdki88Aq8WL9vQHlbQ0cf/Wbkbe/m4OHrrFaHCOHl5CuhBWzOdxqQOMSGyI2RNw6w8ldOLmLLYvEhYxLvA2k1ZJ2SNisLSdksaq9Vh/arBxJB2B5Assz8oFHL86QRakQEIZjuHuMnt7FxeJoiyWcuXIAAIwioirecfHqZZB7xFtnxKzn0AphG4IS2Qkfbpt9QwhP2Dl0BUqORfVD6Kh8tGlDrg6D7KA7zqxC+oCVZJWws7KHikSkIW/IvtNGPen0b6UkavhHpKbeF1J8Smue+uynufqGt5dnDAeIUxqEKAGzBOrBRayKiYp4xGViHJBSHbvwrkQ2PDQo0gE+B5pQXl7NSu3a2SFa56Jv2nIAs4FlzKyB5OBYPSuvrOuON6fK6lD6fJJrOWfT2N1xSF/uGnNbzg9QnIgaB970Y71/Nja76VSLc4PI7fTFpirEF3DKavOU0I/YJrRpO888hUTPhUbvsw2vriaFlFB3mQKxUif67IkaSa60kxfB5cJPpIfjZbmue0E4u3TK1WPPBUvoiefklrAeZSMKl9To+kBwRtcn3Lo0XHNRWM2NZd21nqegRTErLoKIknxBG6c6shJaQjBCk4vmmE8kEYaU6Ve2QcxSFGQsJYokjvgwEtqOoJlUtc3adkbnKaK4YjiUxpVEgPZgzlCdCFNP8EZjDZmRnEvCU249J1r2qUUQDtsGc6c044h7BgwhqGGxJ1fy54EYhAZzLeYyufISJ206YFO7MVOQ0YlXdv+wmHiqm/Ex/blvvO2OkUGEaEVyZ9JN1FrxhZ1DmWOSg7VNTbp8xRMfcKQLjuHijNMDz91G6eOIq55cckp2DSH7UiFFhdYJ3hw5CilPB54i0p5zoHfG8djgxRdZEF820EvZ8UY7oLUDtDkoc96kOF5xgL4S+6UnhYjkiIvgQkm6Kx71FBLIRbqjX5H8yCxn3LMX4c0PIA+/ieZmcfL05DmG3/td5pfm2GFDXK/xly7RHh4yVqE97zwaGoQ5dnTAn8Re1Y6Z90IbPLO2JVQ4s3HQeIfLxSEhp3IK0oTWwuMpZdJQaqapDKQspNSTfYNVIViAIY7knDmJI+t8j7urMyKxFDd3JXRK9uSYWA8DoSsnYMtStGlyJMVad285crNfcftMWQ7CmB0xKaoQphClDcTYMsae9bCsEK3SWMCs3RKhK4icGMkWUYsgqX5es0CdoGplQVYtujx5SoCoJW0oDpoPDd6FjXCjWUa9MdQdNnUN0TmaKoqYtdT5mzRqJjPdTnwTNvUpyz3r/7idE5oIpqXSgLBFzNQyVB0k5z00HudbutUaZVv0+Y0BHnrbY7h3/m8MNx7CHV6jPeqgO0Qo/RNigrMldvcOsh5pc6ktJ3fv4k/uouvT2o8jorlw6dqOPD
tEnEeyohUG1K4IyjpxWDLMRiyPuBgIquiUgSoecQ3eH4BPtN5BGlAbEL+N6ZoqGiOtn3F8dMByHVmuBka2/KPdBXbaaCebTr6TVlFky4OaLgu7aIsVtfitM1XQg6k8S+mSoqHUOSAnxlRq8E1ZUpvi0OyESjfoz0tRvQlB0vodzpXfO73zIi889wcAXH/0axBXCmCL+JJ4kAv6vEk20OqI5QTiyckqgleeeSLwlnI0hqrQhBabG0M6Lc7xpJvmHTnHgijWZx6BexUNml5AK4o7FyFPm1N9r5fzhaaNbpevd+6H1TIVqRRAbMdR2sm6POchvfQwdC5pY/r/l3moiRt0jui9hXzu/6JzGZ6bcbfz3FPCR5ZSJm1RZ/jcKldNq8qYbAVtExR4DWh64eLK8C8mrBFW68wwlMzuCXFqwoLsSk3EkXu4OHL1BUEHOLmy0wcZXjyiZNyJIrmIsHqTrdiyOBrvaUNi1jUMTSK6gRSNYa3EyjHz655Rary2EUIbaHKiUWWcNOoUpCsC2pYnjUql8R4/g7CqhP0oiJXEg4inTwN4mHVzfE26ygKr1iNihOGMcRzJ90rCy5BH0mYQeVRD0boUxSQXp0yLUzr1S3KF/L5x0r4Igjr1qbnyTlm2Y8d0exBUYFUPHJGdChU1EhJKN9MDBzUbPGMMR3W8XIZ40ZMWgeVhw50jz5lXLMZtTVeZIUyEfqlOv9KYYdmjVRQ+Y8RknElilh0imVWjzC1Rhw2HrfBlyzlrOaS9cJlwcJHsGqKLZFciAVAiTjmmUmVABMaAa4BxwKrWYY4rxAVi3xOCQ5sGuXsbbl/BPfYVhBdL0ti4uou+eJv05GcIDz+Iv/gAsrjA2B2gFyrSujZc0+CaI9zh4RfvmGqvases8Y6mDTXFuXzmkOJgOA9pLCR/26IGUOPzOZPjwKA9ZyvjdOGZWVt5N+VmwXmSZk5WA2fplDvLe4w6cHQYqHVtcepIcY2qY+gTib5wwbKS0oqxLx09rjK3cyqI2TKw7juGBOISsZaKyMlomxnj2DO2K/rQIKIYpVivq+hB1MyYIzH3JB0q10DxbpvgoFLI0ZOqvuqWLDpJDgCoV8RD25VwcOFplMVnIv/nroE2FMepIllQNyER3FSMmryRORC3dSLcDtIgdQEXcTWBoJy85D4EzcyQ1uG7ppwYccy6C9jZHd5Ur3vkjQ/hvvLr4MHHcFfmxMVlulmZzFZPR+7sDLl1G713C8ZYqNk2wnpZ+BwVRZE4wqiob7HLl+H6JWSEfLLCJiHcpilFlcUQLQV2LRssB+zOCVbLu5AyNDNsfoDPA84C5oqY5SaMayAJJDmSjhhwfHxIn404bBfmbJTklWljPod41GsofJCJR9Sw3VRFSi3X0tm1kDA1PCHFjc+6vX6SRfEGOcLoi/ypr15X3gnlCVs+0ibscR/sUg7jO86HlXD3kJSn/vB3ALh8/S3Mm0A2qdnHxWFQS2xo7zZl9xrOqsbAJCwpk4YF+NAgCFlzCTGJA1mShxFpJkHbUgsWJ7SFk72pUelzISsDjLWm6MwHYo4M7Oh23wct7Qp3bsLIOxkdu0RrZOvE6n3O9oR5W21LBbJtC9HvonHOXvIY58Le07WJKcNONuGr7U04Fx99OcdwMzZwxLr5Tk7olCE91FPZxjmd6KZSULbJkhkMsBzA6gEaUWYtdDUL/KC7TGoOEAbO0pyT4ZRmuMODd6tzVGNqOUHfwNnciKYMKG2Fg6dx48XR1PqZbehofSxjelTyoPTLVW2jEW8eN3cEc6RcynM12Wh2IGPnHIFQRG2xTcmpYEZb11TXtkwJGWhEfBkPdA1dzU7NGKfJ41wg5IDpKYQ1w23o1dDNgplIDKWAtroyL62E3Df9K9sQZrYtun5uzE232wlj54raTslFU/9P4yfdN5enUOYmUxgwNQ4IGKnMgUvQXy3XnR1n+sMWuzxntXD0nUNDUS2ZpGtUFM0T5ScwavksY0guFUCgzIEhZx
KOwYoAtnhH5wypkZ0HCDysF2i4QD46QhZz/OEC9IR8tsJqVrbOjKbPiBUednQDWov0TNUnJCs+eXK/pvWe2C7wyztwcwmvewP21RXtf+HzhJMBf7bGmiPS9UdwDz2CLGa067oH5R5t2lL3czYF2b+4vaods67xRVfI8mYgmy9qZX4qzYGVlPkdvbCYMzFGYowMKbNcGfeWkWAjKStn/YY1hZI5XZ6w1DNO16flbuECs+oktTiwGU4zgyXGVWQ1RlIeGYYlQ9X5GXrj3mjcOc2slkYaXSHcewg1rOi1FGEe45p+bAgBXJhjNgNVvExZi5kxr4k6MOpQEbOimh7q4hDPrb62OTVvlNPrplG4F4G2DSy6WUlF947WecaK0I2d52Debmq0uVBOqy8lvJSJ7oyqiC5YTdXfhENqVpUyZVNtF5BNDU0BgsO1Da5tQBo0GZjj4TZw6fpFAPLjXwePvIV06RJydJWuu8hgK3zKyEkhetrNO3DvNtafghYqeZMGNI2Y6OakGEaDmAmzBe7KZXj9MdxZo+MaWZWLgnOUku9jUeAetSh2j0tyWmLHhT/gFgfYxaMij+OE2J9geUBchZ2gLDoxomPCckm+mM3nXEwLToZxAhhqhuXWsXm5cNpEzo5WQlbsIGawXYwbV8nrdRG3Sv6eUutLg9k5xCsapARtPh/qmpCbLfqyJZ/r9qON+XqSN6MS4427LzwDwL2bn6O78TBjKFlyTWhBwdDNQmkiuEo6JikuFG5kSnnrzNW2KCOvlhrDE9oOUtyGpJ0ni8N7R0NDoxG1ErYRq1p9FCRABFKKzCoGmyvClXc2uHN6a1QHtYacSjvYSzbITcbcfb9b+rHOl+k7OP+7m2vtPOj1MgDYuZC3sK1ZCWWz20Vm7/+O+99rKpttFEdAHEzVCxRBTFk7o0aat3Mrn0eAEyDOsc6ZGVNSz5yLR2VHb+fHuNkcsys0fYbuHmepoTu7w7VV5LTedw3c7SAvFG3KHApaS3dNySrB481ovWcQVyQugiMlsBGGobSOhQQNmM+oCtmBDIE2QlpUJNyKxlzjC83Bqwc84h0HUlAVKOhrqqEBLwGLHnLizDnc5NkHYRiFPmfCsdCkhM8jPZnVPRjrwTIm5VI7ctwITW4IsSVbobdsQuC6dajvtykBYPffm5qtWkvZcR4dNSn6fRMaloC0kyRU6rcW2oQAx6RyEL8E+QqMl8rd1geOk6NAutBgoSgLBO/KONwsYlL2wWTEsSQjRJSmeKAbxwwETZ6gjjOXCBJY9IHcebSW+7oUO+bxEOQisT1gDEVrjrMR6fuSUQ+4vi+RKnGIF8SDBofbgRlFhEgPNqJRaNY9Od+DW8+Tn76IvO7Lyr1e/3p47gw3dnD4OvThN6GXDmgOF3CzyHNon8khYNmT3Z/M5fqTiWrsbW9729ve9ra3ve3tz9xe1YhZqMryMcdNMd2UjDFlhEzwRX9lNGOIRl8zQcZojFlJSUmDcnpmND5iccmq2WaxRc0kHThZ36PPPYP2LLrAhYNDUhW0da6B3KDplJSUXhP31muWY88Qe/qxfOewhrOVsBwdeTQsjyRS4epUjNinDm9LQnDljzecN0wySkZSjbcrxLxmzD1qPcaAuETjDZuIl7mgUy8hIlfp8rYGZSwU7bcQWoJvaEJAXKBpOmKNgY1tgIMOdSVUjNoWHbFSNB4Kj0iknMxL6KmcyO4P15hRa2bunPqFknwBiAtI8KXmKUDKWFQkjhy+6UH6r3oHAN3rv4b8yDXC7ACskE/CmPGrAWo6sy7vosO6pMEPPRqHkhE5plLketICciBaUuatH2G1RIZyjfdTdkZX+GaWISXCdJrTiF+OaNUhkKMj5MIFmuMLSDsj3XoByRVJmDgM40AaBtKypwsdTgo6dLiYc/UoYcsS3j7VRIRNeA92kZEJJSqI6FjHxg7YsrkeCmrVaWnrwhcxopSo/4SwpnqcnjSrvASyJbIWeY5NXVReGiZ5OW6Sh3
MIk0jhKCIQK//yc0/+d44vX8PaObY+LdpvziOuxdVUPc25JoIEUu5x6go/TgQRCE2dG7mMLRFQTZhTmllD0hkbuMIAa0q4xMUNWpRtOyY3z0pBDVqEw9BwNw1btXS2Yb8pnCtskcRNMsROiG+yDUF7FxmrKNOE+r2kfXfa2VVOzy6hexct25XwqPxnpM7T+++5i7SovRQ92/Sn1FDihLDats5nqmhck13lHuqGk7QbLc0eTOHMtMrAGI2DbJ486QXOGqSZIc0FZk3mgdTSjTPWrmPR3+LBGoVIZ5G+g3Zm3GuNYVZ4iM5KyBUK3SVgtK0nBKVtMk0TC6cr5002pAuU6i6AdEVLS52R1THWhkytomp0IeEl0EqLaFvC80FoK9nJFLzVbGZvJBFShNa32LTYB4dzgSEnvAOxY7IExu6E1PSku3VdGkomddKGA4WORFPSOc6JHG+4vC+DmurOnN21ifLQsOWJFu5vCReuMe5SwtXjzk0bSmLMDOga6DykSxCvQDoWVge1gsOFwNlhQ+6EedcSTArNwGnhIwKtEyQ7xIRVFKKmKoBex12exmCJiScb8EkZfGRtwjwquY7QBQKpgcML+PkC3wZQI2tEglDzVPD9QGzbwoNUxeVJvN2wWjvVUgQxlIbeRebjEq+ZfOcPcL+bEf+68lyPvR3/3B3Wv3eH9rm7NMzAefo7d+iqDBYHs0KijCPSvQaKmE88DttZwbJVUr8Kzkf6pKyTshrTxkmKuWjLKA7NgfVy4K71WMocdkKoHT2mTLKee0tjHY0sirsoqM4mfVmijJg6Qi3hmrOwXEVOVj2rOND35cFOV4blA4ZkNVYR62yQjQZT6jOu62lGTxsK5B6akslXMri2MpPRikAeknHeaNRKBtxU3YDq9NREQgVwgtTCsE0FS51vSrZSLSMVfIv4ItBoXeW0tRntAr7xWM7kyMYRLk7CdtJOmYPuZXeU+nmdIFI3NBEhubLJQtHumRSSNSV0KGGH7o1X0Le+k9kb31VucO0CfnHA0B7S2RoGj9zqYVhhZ4XUn2OB/V1W/JCQ1ZrYGt2giLLpa2sDeRwhRvT2HdTOStmfXkrlA4DZATmUsumWZLPxiyguR6Su8hnDz2aEWQeuA/G4lGAYGdZnta8H0jAShwFU6boWGyPhoOX6A1cwu13a/uy0hH5qKGHSM7s/bDFtqPd/fm5zd2XCBylDb7QqfeILLQ6KY4c4RlNCLpyLiWBuO9/bwEb6YfIEp31iyy0soYnpGYxtGHXXnnnqD3n0rV/PwcUZ6gQ1TzTD4WkmIv5YdPqCK6eNaLESoh1x6HFVrNaZbMp/pVTFHRy08/lWlmbMpTFcybo1F0GhEWPULR/VKIkmycNJ1hoO5dxGqJR3t/rek6OU2Ibvd8VhNwkbbOfJrrjwJBI7Ocb19ve1rWzC1V/IyZrafAplnXO0djf06SbT4Wjn4pds8DZxkoqr5WU35GJV3kNLck69l2ObrAMlrHlYP2uqMzfzcOhkU3HBwikcXIY24meBkC4xo+VuKo6o658DIA6QlxAWDo6F23VtE5FNclNZ/xy+Dbgm4dumhMGliGUPk+aoN3JIOGe0PmBeihitDogVR7CIHBvaKME1mCu0DbKQ0U0B7LlrIAmiroRU24YcjMZ3G+6oc0LwDf0hmB/xzHDdAYPLZIl0bZmQy7tgZ6CjYh6a1uGGQkOYhIxlJ145jRWxLxAq37muqf3X2U4lh/rTzjvanLhd3WplkzRLB8wEDmewOOjojwf6izBehHQQOF2U9fL2UcN41DBf+HLYNmGWBA2BsU5sj6cdPYNBKy3rdWaMI1EEJ5mplpdHS/k1g+wDZI8SWLpMVzXfruQF5uaMXYtrGpquIWtG1z2kxGhl5fFmNDLDYi6OWJNx3QyNeVOFx5zgR0XsgLE7ZRh6mhRBloTTAH9YDuvrx64RvuJ1NCd38b/3GcZHH8I98Rbc6RpLVev0eAEhIK7Bta8Bjtk6ryvq4c
jVIRmiw2ssjsUopX7WKJytEv3EJzCPzyMHXpDOiGlGv1KcM2IukwYoXANr6PsVZ6OSTLhyeU7WNU5KdsVB6BAyJFDXcDoMDGlkOfasR7h3UjeJ7ErJI5dJIdfSOKX0SqpLl2rEjR1tB1FbhhxoU0MIAUm64dEkzWiOlBJJhcRdFhI24jZBKJloSr2mEv2dYU5IlYQxazq8BELw+AAqica1BHEkKxNsPIL2oJw+vHok501G3uQ0AEUckMI/2CXzlzbf/Uf9SwqqJ2IsFJZNRSE1oUGw6PEh4WcO/4a3kh9/B83r3wiVi6LNgtQ7mlFJznDDKRrBzpYFogQsjzTqEOfJRHQ4ZTZ6shjStEid1LlPNM1xcYCSEW4mCIItjokH5bl8Az5mAkbOgs/KOJSyMCn4Td1Ad+su5jrS9au4Bx7EhRZb9/Di3U2a+HC2oh9WeIW1K0Tmw/kCyz1Ns+Bq5dGtdIDVyHqqNcp2o7Wd1bjwKo1BC3ImO/0xCdG62t6twiClfJMlJbeQ60rgY2VWSkG5sgoeIVEQjsmxyDsb+G5m4C75fxLCnTYKX/scmxyQisjawHNPf4o3XXo3XHiIKC/SyQGSzljXLJvWGeQ1ySK59fgkuLHMgewPSBXyC97jndCPK1zw5Kg05gtyVhfFe+MdRIyu7ehXEadlzAab+DW2ea8gZQ71wOco+lwztjpZ05RT2eqhFfL/SzNrbbfNpFYpki2qHeq1G45gdXbPyWxYcfx3s1/HHcRtF5l0VtCOtVV0pN4v7Mw/X52nDc+z3jNJQUXKcwm9GdkZM4VlrVuazTja4dGN9X6y0wbG+U3GGSyntrFSe7SZBQ6OGlwoLRHXZ2R9mm5+HT24iM48rptz4C8R4wBVs/LG8DT0IzbmIoXTz7jTBMZWOQpVpqhm5xKVxme8G2k7T9N6xjFtHd7ByqHACRZyyW6W0tZUBXeRyGD3yHGgbY4gtARJqBNmeeSB6t5cHAzWxqoTRIUWxzI0EB336sK41Ewfeugc+eCY1droVsqC69w9aOgvVHf2+TV2s6c7zbRmXOoDvnXkUZjXNWHAEIrsyzS+RoPBlbE7Lb1DHZDzXJxiqxzDFjiu1ywARy48OVfQ7T+q47IWb6DtoD0AO4Z7Bwk5BH8MdtByNj/gZFbdvAPKgTMJ4pWudThfuGKNTQoAkcxYNOJCwrUZTDlZRrJB8BOy1pJSJIVMFMNlRVDuqvJltR2+Mj+ArBvC3JN9IDYeP4w4C4TFIY0vh90cA9oq+BGvWubvOOCkyGoBpe6yPwa9Sxc9GSXmM5qlodoxZf27I4+/+uXYl5+x+u3fRT55TEvCL0b04bfUBpvRPP0s5DP6TVbHF7dXt2M2xIIC0SATTokVEcqsJHOMyVivjWTdZoCOY0JzEXb1+Fr81VivipMX3DY0l1Km10C2GRmjTw0na2inDpwV3ZyZD4xZWY+ZfkgMAwyjMNa0paxCE6Q8Z40DFgShFBUHyFbSsNuhpW/WhBCIsSU2Dc5t9Zo2JZCcw3uhpDh4vMmmFEmKsQroAqF8pbOAWMAkbNCDUuLGFSdR5Nwfq9lpwyKQjxZ478j1WYum2XnEbEINJpvClM6dd8ymz3ezN/sMYTrRaIGYk2lBCbsW92WvQ1/3Rrh0baNbY87hKAr8ZEGjIWNEYt7AMpIdmhKMho651H2MZ5j3RcR2ItlPGWu5UpzbBgkd0i2Qbla/b6qZqYgOmEUsp1IgWq04ywBn99DZHH94AJeONrpbMSX6WCtB6FaU8VACIq7U+fOeOI7kmsp2MD9kPdyjz3kzfndDX5uPZLuBvwQ1m1Cn6ii5DA1aSmxRNtNJo3EMrlSMmByoXEJSwjbENvX1ue+gbOYTcgdl/kzp+wYbUczNM1bHUQz+6LO/y4NvfBsPXL5MXB+h401clI1zowImDqxQGHBKTiW05LxudPnaEEhpRMRjFO2qKZae0har896Xw5tZISWLZ5kGzLbaTFsnuBaGr20y6b
btdMXGdttI75sL9xN6N4XFp2t2nFqr3z+FR3dtcrCm+06fvcRk6xQbWxRt1xHcSNfc16HnkNYJEdy5bDdTdHrv3dtsHFG2vwtbJL/EF0pbt77IHbkpNK9KOrmD4JBFwh9dKGW2gpTM2jpxLugcHRS3TPi1ZxiMYV6OEK4+XRMcwTzijbZt6LqGfj3ivMdC3oybnAXJYMnQsaLDvmYmDgX5aBCy+Bqh8Zg65r7jgMDrR+P6qoyv2WogrRO5b8gHLTPnaFSJLrOsBdjP5h23VbklxtILeeZJrbDshPbEOPRVGV8blgu4ejJw+Q7ge8LQEoIgVUqiTQknDnXKmRXi/lAbfjc7d8rKLk6n0RnMBQ4ofwCOu1L7c7BMnzIPrRN9AJ2DHlZgYGHoEbhjyItMc9Aydg0n85bVpSNyW66buYzhSRYQ73De0boGdYFpZmdRRKwcKDTjg9JkoQmlTNM41QN1BjkTRWuCTuY0O7IoRzVkONd5caK1HJ6dCMQBd7KEs/Wm3qzMfElO0UqnEI+MI84bOmnRpZKeI2IYDict4jKaV+ThjBDuAdA9+SRJHiJcezOHN07If/Q/sE8rvOfrcBeKbsgw3ERvPUV75y6zTR7zF7dXtWN2Nias9eXkXWOLlkvHiUGkQXOpmTeO603h7jgqWItmj5OWECYUSkijQ9rtKheTI+biuJlrSKllPcC9ugmP44j3mSaMLGPidFyxzpExK2OWTXjOh5asrkgsmJEtF8kFVcapPpfArPMMaWCeR2IeiRpLdQHvNlllk56mRwgWUEr4x2nYOG/iynXOQQiUoIG5stpIi7PiZIRQ9Mtc5Z6V8GLJfpscs7NZi145wgdP7lNxtOpJbNJLm6ykUVdtMj85bts+K/wyNnBCkc9wEBQvkzhugziKDlVy5JjRxQFxdgQyL0QVKEW8LRNiLtk7fSxFaPsVVAfIp4SNEVsv0dWSMPbYsIK2hZ36o+J8yRa1XD5r58jiCFkc4bqCjppvyjhLI5IKVO00l+oAUWFTpmtAXUNo5sW5uXOXeHKPlIcSKqSERVQhuCks4UljhCoEPNXb9uppXUNTS1alHVRkg6LUv6dMvylMtgkhThdWGEDqojxQ9IMYwVVYLXblJhuJEy2n70wpXzSFPCZEZnIIJucryw6fiG0G2CYbTLaO2Y6ryXJ1jxeee5qLl65g3QIdSnkdNxbkM4nDxOO0voaj1sZrcGxlYsqhwmGWUOdwDlLu8fjNASKEgFiLcwbmCBKKIjzn/ZMgUuZ9/XQjMWC2CccGCiI51QqckKwJ6ZrG+G5/TXNCarttqjFMf8v2d1/OZNf7m37nCzhWu0WrJ8dqc+mOwz197xer2Spa5TGs+EazncdQ2YZfJ6QMtvfbddgCWw08p+UwmlMiT6WphpFojuFWxp8OdOuhHHKcERRmdf6Lzbg4jsx6oVk5lkNmlRMr8VjNfm68g2x4L3RtIM06VrMR33rcIBtR66RWKkkEwEuZY2I1i3zSrgloFrCAmjBYpkuZ4yg81AtHJ4Wm4NcDi15oc8dwt8cFT+c967mxrrc6W0XaIDSHLScO+tYQD8sLc+be0bcVMZs1HNwWHvE9V8eBbD2tc7R4dF2rytR4dFIhWeGbriuxb2QrRDuzgvROlVs6g4VVx6x25JEXFl1DlMCQFQlnXAhw6yIsL9U+PIS0KIiZzIVlaOgPOk7nHXYww9URJVpoB605kID4hiZ4ZGfdVStZmQmKhI1lsje6kOmHyDgVKK+IeDSls0RywgkOp45rs4LtzvIBhEOka0rIPRdeMVprUk8CxpJpaMGKxqeTDFkxysG89PVQSl14IDk0Gz4Lmg2xFS7dKdedjPCHSnrTVfSNr8POVjTLEQ6ukyvXTn7jt7HPP4ctVxse9R9nr2rHbExGXPaIpI0H7qU4F5oywzig6jBrGKJtxOpEPH0/kFJZpsxnnAelKSWK6lZX9HoMMYdHEAKSPSk6+orupLEHHcg+MmTlpF8zJi
VliKNh06R2pY5fqU1mpdTQ5JhNnBaB0JSaYFETaqlqN5X6mdMuWIRXAXEYDrOAatEe2yKHxcSVk54QsNwg0mIubHRkvGylBs79V6SUwQGWrWO8cEjXNBhDcbg2EPn5gWZaRUttGx7RHQhnQsqmU7gZdXKUkhgA0QtN56tgpUOWK9LdM0Lf4+dWeFuUUFgmkiVjcY1frUn9PYhn+FgL245rWC2x1T3o7yDDaZHyMFfkKmr6sncB8WWBC74W2e06pJsX4ub0prmcpmzsSf36/8fev8Ta0mT3ndhvRURm7r3P4z6+Z1WxqlikpCJFkXQ3bZFlS3LrwSbUbnkgGg0QsAAbggEDlgaiNNFI4ohDjaiZII0EQoJnlt1GS23DsCC6ZdotkVSTLFL1/t73u/fcc87eOzMi1vJgReTe51bxIZtqu9xfArfqfufukzszMjJirf/6r/+fwcBqJi8HQltEQi4Eu3EJl48+xI4H7PYlqrpqwxnCaIGNRIxATAnT7GivCnlpZR0VphCYohOFaWP2bTAKTYqgo1J9Y7YWwNJK2hYwcTPlHU53XOopSArBqA3hFHOUqkibo3LSp+rPcEW0womDtCJm9tAepqM1PYvvwUe/na985Vf4zPf+MJe7Ad1vWTi4vhxASlhMUAqmeSUxJomUmrFGHK/Vjc7n2ZAYMDs1eMRG/h/HkWzKMh/9XQqBihAlUE3PUCiPVKwHlOZX23XGwIPV9nb5vfTPnSFkry7Fwonb1YPW/sEqJ30yocnanJ3gHP06R7DgIcq1Br8tYOzPqT229Rlaf0bfCXF75Zr7dZv5PEvGynPKwflj67V1VPcsMOzH2jABTTJHqHpmIF0XjmaQj6TDwpKPhP0FcRo98W6b55Irg0S2C7A3bu6U/aXxgQp3XT5FCiE40pZUmErgYjtynBL56GLG/aK0KJYjoaFnROeB9GQ3hICKEcUIUqksDGXh0b2wu8tcvPC5enVfuToEOApzrqSUsMvIRYjsN36udDliY2SYB67yjruriXy5ARHiZsI6nzgOvLkYn74bGEfnEsdpJCy2liY9WBAu1d89C3CLB8932Lq+b81Lzkk8QLtE2GJcCTxptfnLTWIcIpshcSGB23rkMhbGJ3D/RkMXLwJ1A8tFIm4SL4YNdjURhsQQoS5dlsbdCqK5qb1IYEgTMfg+sc5ZE6Ay4fdR5sJmCBySknN71rUSJFKyNo/MSDVjyPBGJ9nrBHFD3Y1epi9KiEbYDpTNdNIa3C9wsUWCIwdmFSkZ4rhC2MGUWiBG16CsdW4ddY2XdnApjLwdSc9fou9t4LOfZvgilH/9VeT/+S+Qo1vOTV//t77X1ONJpff3OL6rA7NCwFRZcm6+hC1jbt6QUpVapHUGxpMit2rTMXNCYRiVOCQgUC10viESIyFC0JEhCGaBWhLH3v4GCIppZa5wmBcOx8phFmrZINXogjpaUyvpuVBmqUatrvx/JtXCUqAUdQHcWn2xEq+nh9Ws1e8vRNfRUR1IOhLDRBBfHMRc311wRGpkAJlQS4gNhEYqktbZ2lm/ZtpQM3OTYzxQWi536BgJlbUHwQA7K5kL4m4LxqnpoO3CPV48L7n0TUbUN/zOc7Jx8KA1HxF2TDFht7dwPCKU1ZAzmu/2EgOWYitFFFJRVuO9wx67fwmHW8LxnjAf0Wnj6JgW7Mxd3Sy08ltw1fkQmy6bP+tQlVoVywt63DuvxRRy0ztrDzJKpM4vKPuXYCPjOBKqd1fWlimGtCFJ9puPwZ81htbKUpTjsQX+BhIHtkMC/Pu/k+aUvw/++QUPGF61dXSUy1iCI3UX6tP45Wk6E0ovG/qz6/+fxcu9fcHoZYHzoEJ5qI4Pp0BlDcI5EdI7KXwbhKzGR8/e5733v85nP/sZ0pgIaSR23Tdx1EO1QpmRFFpQ713LvYEmCIySnABu1jo3o3tctmDOrbG8tCkiEAODRDZWYJkfBJb+5f5/vYuya3H1f+pBT0cMuyL6q4hmP2
oL9DoH7xwpi3TUSb7jMz4/7OzPGnB1NLWdtzbkJLTrtODvKe0a5dx5oJ2o//erwVpH4Lr9z6v2X+dls1ed1x7ci5zmheJz71gqNGeWpWTyYpQKQaqXKpcDabPFYuLQStIXcyGEyng0nhwi37MXjgdYFmPetLUy+p8gHgAMJbLdjOy2E8djJrfv1OzXIbWtadobqQSafiQSG32jgQBkNqVyva9sPpx54wO/y8fPId8bL+eMqrGLwOSal5t2XePVQLweGB5dMGQlVqHWgc31wMejUNqb9roF3h4nduM9ItX3m7uF22Uh1/XtI6isXfHBYGvCHndh6HzI8+D80pxDuAUeb4THzZsqjgOymYjTBkV4mh5zlw5cPNnz7Knf3+E6sGxHjuNAiBN1OzLiKF4xY2geg1ICkkayuXhvDBNhGBij0IVHnbJRmsuHIGUhJ6EYHpw11MKye0ajQg3Goi4ySxauOgplydW0x4EhDEhd0P0BvT3AsRK0NQhp9QQIt8QrSyHU6uXv5iBi2cEdolehQmz7VQXNhVwdHaVAvXxEeHchzjPhc4/Qz+/YfPm/Jn/N9yD79FMsBUo+YsfTM/vdju/qwCzFibRJKEf2i5eRDvORpXjL7fU0MaYLalWChZVjkhfn0JTiZcIoEQ1GjLVlb7GdP5ASZLsiMYNUqibu9pm5kUvNFBQOByf9LyWjNWAqKGHlfKl5WVGL17FrdQL5Uk5eeTEKtZi7SBXPKlRP2cXa5ZUEz1WdW+btvg4PS5PBiGEk+sqLMEIYMEskhsa/6auurSc+mboHh/XbMqtxQ7naYhcT6rdLwQnFFU7CsGcbWe9eW8nf50hL+28VVo5ZQFYfzWjufadFYaPYbsQON+h8QPWINaeEsTgZnzSgMUHakiwhVVb9B5ldXFByaeOYfMOt1csvPRsmNIV0RUwJrYBlqoSyEgUJRZx/UCqSC1oWNB8IautOrjWj5UBhgXEDeUJig/K3F+s91rJnQRlNuJtn9z2ulftDIZ+Nl4gyDlu0KdC77dDDDf+cP9TtmTpq8oD3h67dlUOAjcLRWFvOhyaFcpSO6jRBVYMcjJ7eSLPc6RypFXnh4Sbcy5c9SOyb+3mH56HJqgjwm7/xf+eNN9/kerOhpBFJbaE09yZUXMV+RBBV75pLw1qCqKGSTQlBWGrxsqUoqotzDfFSZkgDMUZqbCX8EAlEt3br/NEmjkx7L9Kr9cI2zrFzV9o41FeezXdaik1wOsBZINTHsd9vRxlfXaRPz7PxIs+Cux4YGScHARr1oJ+/l2Y7p6vlN2uX6O8UEPZSemfJjGeYnfWBOltZ+u8EzpGzk9p8wt+/YsqhLGiTKphn41B7Z2slLHumUkg1Y2Hg2HhH25rRWilHGI+B15bAi4Px4dG4vW73FgxBia0L0IoylYHdbsdhX9Y5Qc2+tpVKrBE0+HwQbwDzcUpEEYYwuB9oLVzWwHZZOOwL9+/5uT5+DncWUfOu02eAHGB7xgCt72XkAq5ePzLczsT9gfmNayw95UJhbtSC163wGStswswoC0Ez799nnnNCPh8TGBFUXGA6KIz4e5wEtm1SjbTEBS9rjhi7ANtxXC3NmDawu0K2F4jC5TTwaLPn4nGAR77u3j2N5M2IhpEgEwklF6glIYMgzd1Ag1sXTubOLS7J5N6lXT1AaiA0TwmtBYbIJhm1KEOC1MbLdPH30ZSqrcKoC0EG0qHN+s0GpgGdku8JhwPz7S3x2QviMa8vTpmiyyeJtr2uMoSIxYg1kWAJQiygoo0SFAhDojIAGWnVizIUJ+jmhfKVryE3TwmfeQu+sGX/zlcBuPjgI9htccbV7y8wezWZ++T45Pjk+OT45Pjk+OT45Pjk+P/S8d2NmKUdwxDYmJFblqtU9nnPzd1Ljge42j1hE65d+K3drpk50b8Uz3Cz8wZsKAj2IMONIyTZEAchSkWLkZdM7iBKqeTi9k6lC4fRuV/qLW80ZE2cL+Y8pUDWQi5nWbIaIU
ItDqObxbUry0nyLTPvBGcgiBvcigWCDattUwqTlzGkegHKRhTvahPRlXyJhNa0dp7nOq+im/fWKBwvJvTRlphOyF1TIzx1Zarf/jmK0k73bd2aZq282f9ttFP5oxY0JEaJ1AB1GKh3t6TbG+JhQbufhw0uMTK3LlzNiBZsmQmLWzLJvCfk2bsnCVgaQcQ7tyyeOHmtfNERQ0lbLG281NmyapkzWMRyRbQidcHyghRtpPY2XmXG8gL1SBH3vmTYEjYbxiazUmsFyy0TdFNiM5gX5aCnbDi08qOYOUbqfRInDbE2nu3y0YZoFJyo3hGVftTovy8A0aUztpVVkFOrn2dpbXPSWgV7+bQTe72wesrsFEdFCi6j0B77iWx+hpKp+Xe0V4N7GikZePbh1/jgg29y8fZn3b6kSSPY4cCIYilQdaBKIFhuHXJ2moN5IRuExh2seMt/N42H9v6EwLTduMhoQ9gVI4S0CrH2+wg4KjZyksRYTbJtbRJsVkyc4SJ+nPO61p/r6ffXwx7+teI8rvNxXkVFO9r1ynedl0XVGnAsZ1xAO5Uora1Xeq6p8Lsc5YyTJubk/349sSHp34mq9rBj1CG5/qOYvJxksEoemTSx5PVODdGZelwIMq0c4GwVE0MOFe6U6VJ4/Rh4fc7cH/1ct1PEUEJsa2YFLLHNgYtZqQ35WFqJisbTjURCTFgKhEasDLEyxEiKPt9GEy7EMBa2uXLwPhVeKCxUdu15HNr/353dkwLDPVzdz1x/NBM+NTLPM3uUx9uBcWySILeZR4cDm7kwHgy9hecBbs6eYxHjURWwQLLqvprAJvg71jsuN3ijyraN/1bgaoxMaUC6VdDFNeH6KUwXBGCUS7bbPY9SYpQbAGpQakptTmXCrCxE8uAUiqF314ZEHCZGQnu3XCszdPFEIK68PajVsAp5MFItjDUwNvQtSaGUwiyVsUasQMG5pqHzCraXsNtgYmjJWF6IVol45/XS2q1rqSTU0VIpCEqKgRITtcUIUQJBj2h1D19aBcziiKS8cl/HlxkpFfncF0ifTXBzC/uAvfEWm8YnlF//NeLLiuwGiv7+ELPv6sAsDiPD4OS9oo0YTWGuA0t2mYTb/Q0ahW0MjNEne4yRQzm6RIEoqgLVuwhDCCtBryRDtDIGIbVOxaLm+3QLkmpWyqIc6+gEQWmeYDJ6GaqpQedaSLqA1bZZ2srBOm95L9knaC29NBbWYCG1YKqqw6u9HR4LCK4k3QnOfrX+J3rRACEiITTSfF/YnRp8LhJL+5fcIqUSAnmK6MXEMIEWsNI2BAvUbvxq0Bo5G/+n3RenTtJzrSRbd23Iav1dBV0oaSSSEBQR16Phw4+ITz9FHJ1UWYfBOUe5QDmi5UDIB3S5pc5OzozlHuoetUxpzyMkF4gVC4Qe47XdM4XohM9h5+rjErDSTObnI9iI1GaAq5lgbpgsBNRaVybZA5hsqN1D3BC3O2Sa6FtsjBGrwTvSLBDiyGE+cmg1zLUk1XZlLRknlLuwZLDTpriKhJ4FBf7HVqI1OD9CRhiOeCdvNKQ04/POJ8QYzTWPYhQv6TTRssIpCOml1E7v6EFhwai9maSV49p0WoO0LpnQOWZdv2to9/GbX/4V3rr+FLvdhrLvHM3KGIUQI0tSam38RRSyrhzGJWfGNDhFYRhacOWl+dLnqQha1YWVo/MI81LQc30XwGpdg4oeRDrX7izQ4VQ+HuQ09q82X72amPSS3jmP6/x31mE7VQvXo7aTBLwrWl4JvuHEA/MEqLkE8GqQ1JJU4ztHVK8c/fn18vhoAZHTu99FkM+PwCulWpwiuiac6mthKbZypmo949qd3XPGSJaRtsYtVhgtQob6sjBdRJ48GXlzD/tbz5zvrwZqEyKPKbhMvRnjJrDZKcvSuMkKykwO6uukuJSShZEhdv6iC3RHBlS8FDhKougdBEMb6W7ae8d8f2cm/H2LnBpjDvSATRjuDfmthasXH/PWXrl5a8drTT
vt+v5I1SPbZwv2wvjm4nMtIdy2gX0H4yCVJ+0dywmkeGC2qyeOWeDUmVnFg/4xetNYaPtU2F3C7gqLEyKBpFdcpR1PFK5ayTDNM3kwKoUclFAnX7t1oaaI5WZGL1uGEClxZopTo0EEiKfGMpEBQakoZShYMZaYGOLCNA0MraQ7JqEuC0f1wNO3xozERNJujLqDzYahGDEv6N0ecqbqgump/LhT0BRYqnqDXcme3OkpYJQYoHnoio2N3ZAISUGEpYmJ25TY3H6L8MEG+UNfoP7QF0nyGH2jMM2PfA6++Brx688Y5/jfja7MTZrYDpGIm2UDGJl5GSibDbFm5sXQsnDUuzVQGMIFYRooVZFFUSqmE6EEYhBSi9I9sJ/YjbAdt5RjZr8spDgy176tbBiHidvDLVWMlBaiBmJIxM1utSuKeNRSQuPnREcjqtralSlEtlKpjVumqoSWTUgwtC0QSUEsNPucQBQhEVEGam8TF0NlQESxxncws0aWDm3hAAhNfd+vqxOHzGBsC+VcCzcXGz735DFFfGJnmul6qGvQ5eO/ijWvi3LBMym/ydNO03sOOsk8NBQyRCEtR2oaEbYtOFPizddZ5u+HY1tMKTDuUHy1ctkC1wuSjo7kTFbnDsaQMBGiTi6hOMipK1Pcnid+6op8ORKGC0otjLphXXE5IuWGUGZX6beEBkWConlhaK9T1UjVSpWAzVAnIwRhjCej47y7QEyRw0CshcP+ljnXFXnqm7GKE5I1Fi6mHakqejw+UMPRFiQpHjgczIOtFVXrz0PaZj160CbqwddYjakt8gc8aLjOoGocN8Kh8/XOkr1uHdQtlzqR284256771X9N2u905KwLmO7bvyUCBOWb732Vf/v+b/ADn/4iSyMl7/b3SJ4pKRIGoZSXbIJie2U/ZGLjVg4ykKsSGu9MBLQuqMiKJtcZYpooJZOmLaEulLIwDIN3Yx09EI+Blc9Cey4X4soo5+Nf5GQ8fq7rfd4g3RX1O3pmYg8aB+AU9PSAe8LPW88+J3aSJAm4zENu33WuAK94EJ3Bm1KCz0kziF02wIwxGFFZEdHcryOe5o0ncpWlbfx74MpAYyU37ZaIuUYiy3ofgRPXsb9BGRfWTe35B/MmnHkuawNKttPcrQ3t04byKpVuO7Jn5MDChcBOI/HeuL6rvL3Ax21RfXQz8e7jiOjCLkbitKVm2O02ZBXmxk2meANZFSUkwSQwDIO/3y0hvigjuh2wKMSYmCXwYt7zeLfjw7eNq7lL9EC5M4b23CoeGGdOic3c7m/h9P6WZ5Wr4zOu3r0jPfHv3Abl8ka5/Sjz4qWjy/cEDkHJPTFTeGHexLMFtiUgQXmt0+ekP8fTenxNS4SCUIdImhxXs9AaxNIFJhOEymY/MxwH1wUDrm4qixh3FzODDRzLnVsbyQWhxnWyWlDfj8Loe1XwJEDrQOxd1W0flJyYhgmthZgPjJpY6sKmZfr3g2KxYntYUmbQSJANt4vysd33UQUeIWVmkQWze4YXd6SXs1vgtQWsSGKkIpsAB2VUON6+JMaRqaFcOk1ojHCYYQdpjMi+UOKI7R7Bnc/BoUCYhHL4FnxdiPdCeH3BNhvs6VP/zBd/GHn5f6O+f4ekS/h9aJl9Vwdmnum5BtfQNryNbdluZy9NLEe3ZNCEWJc0ZEWsoGWTJmAV1UARWcUJEwFR32hSShArKQhFfDHp51qWxUscFde3idGzmqWuabCmQJ2XhpA5IqaLoYXejIRQKZdQTKnmpbklZ5aSGS0xdFG+6qJ8pmDV0GqoicPE7TNBkgcua/kytu7LplG21ht7+UDp0rEm/uNuV6QI8xQ4PprYAtpWl9p+/Vz76BwF611s6Txws1PmHuVh8LD2EDShXUQRq16Cqpl68zHcvyS93jLYNNANQWuuaHbNsl6SBL/vEJJ/b0gQA/OxeHODnvzTsii6Gdi99ojw+jXxJhFuB2yua7MBYfauxWNhyXuGmlHV1gncERwopV
KLNu07cV24XKnHwyocqyGi44hZYf74nmPJzGdehiuA0TbDkyVTK7eeySice+ZZ39Q6QmWvBAjnTRitS/icHI55aSSIL9wFV8Q/csqyz4+TCGtDj2z1I3hgEyX9xtr1WgtwwAMgVahi5BY4fO1rX+Wzr3+GKXn4dhgSusykEDkiZHP9mRBDk5Q5KyiakMuCIYzjiIRIsd5u58ReU2nvmY+2o80R0wqt+1mbHVVo91ntVFrsw5VaOX6WZnHFqcx5qis++Kufu1/t2XieFzn6u9MJ+ef/1suRvQw64MmQhtOY9p/3QKuoF2liONEWBKMgHJqiP2fBkJ0h+SZ6mpMibA12GBs7aWQVYAmLK0yc3U/khH61L129GQVIEhgkPJBPkTYufQx6c4K/B6f1ZqEQxBtVoirp3hheHtjdb3j9yp/h+8c9YzZmE2zYEINysQkcxZiKcLVt3YhFvUu7JuI4EIdEkOTK81vfJocc0cbVGAkgicPVlqzGxsBea2jLiyPD0bDi977FAzDj2xt2KtaaMPym8h1s6sz4fF0MuVdYZrhTV/OPqkR7WBYtbS11PTBlbOMb5LSWnFuEBXwN8bXkpH/psHf1qkAscLtH93uGULhck/5CGIwwQonKqIai3rWedJXeqZgDHRJcIF28quPVnRbQh0BIFRndImscYZwTOWafG21fn4aRYZhJA8wKgZlNHSimvJT84AYVGNVBiBoNSwIzhKUlmHOByy0gCBNBvDHMinjjGBBkoqRAyMHBjAAWnJ5UTZBWai6yQHEJj3B/S5BvwP4D7PkOufJgN11dYW99lnr8KnrTg8jf/fiuDsyk2bCGEDxwAiZGLuoOnZTFondAWsLKuBblKw5bxhiZS/GOMDVU/E3Srolk6m2zIxC81JaGQFRW43E1mEt2nlB1yQXPTATCyUYJ8QCnFSCca6aBqErTVcUsuLF6VnIpLLWQNVPKQlkSuWVuUd0jUFVdEsAiYhFhcLE8IAUliroNkUTfeB9onLVJGgwkICE0Q2g5SXq0IKKWzHGaOL52yWaaKLfzabG1hzwZaf+zimoKDbZ+uANZ24w7z8z07DyqLjobmnhrcPBeb17C82fYW07mqFMglAO2P6L7PXI8rubq1hWczYihyW+YYFUIoy/QWnHvxXZNUgzuFRsy5T7C8UgoCs0rT+YC+UgsmVAFVFu3ICCCtRso1Vhq9uy0LEStYPfErNShG6IHNCXKElnUVmSkl1TPERLw0qdIE+4VIbxau7LT//WseD3n+fNp6Cjx7N/iWXDdArlgMIl7aA7RS64n8/NvL9UpTa6DhwjZuXTG2tUuJ/QMeinTEx7D+R0ffPgtvvHBN/n+t77fx3R3ScyLPyMJWEhUUUqoWM1oR/WS68KVYmgIK1JNqSytlLF2EovTIWqGcZgoRNCFkyiIrkGWtOgq4IFLv69k/ix6OTMC4ytjIzwMrOA011/tyly9NFswZpwC7X7051YD3CuMTRqk69jBCYmJAIOQ1JExE0/soHXpmZBbyc1jMyEZbLAV0exl025XtSHyGtVtrPqzbXmv60A+7MwU8c7Ldqo1yItt3e56ame5wRrE9Htd55SdBRrt3xcgVmNYjM2Nsn0uvH7p79njsPDhaNxOhoQd29GpKUOEYTNy3bqkRSu1ZsKcIY0wJlJKjnLHnrwG795LwigDFkf224HncWA7jFR1+kS6XZheVuYbv8aJs2SlXXwHJL1j2db7rgJphunQgxs4IszifsKIJzD9PT8/uu+q4aX31Mawf2f/Ny/FNycLdfEU623TuUCcETmAzJQPPuawHIip8rSpFXxocDsW8mTcp9woOS7xFDSduLYBlMAUnJgcxOk0kRP6FkMiihJH93wGGKfAkoVxCEyjf+dmSmxzYlsL+wWO1ZC8UEV4X842YxHPlvYZDgtaK2aKVCPklpgV33/VCmoLVRfUEtR7qFM7VaWm4HtZc6EJ0RdOESFdXPtwHT6m5spGCshL7O6OsEzYfI
kcXmvj/jb25vciZSEt34Tbc/fY73x8dwdm0Ut1wFq7jVFIITLFCWKkpkK0RBdhhb54eLZ2bDB9UMNQNMmqnRTUCKVwGArjMboummXMMtYCoEW9VCYFFjPUAik7wbGMgrVNPRxnSoaslazVJ7IKeeXKQGNPo1kpdaHWkdpER/NxWdGOGC7Wa/Q0N4K5LEBq5RoVI4lSJDY3BEDFlwipp3ZlcbNY6xFVb082Y2lQl5hn3PXpI7jYoB/1UuJp8wBWJfhz1EzES2LxLCZsAs/+O+GUBa8t/QFCcmkKjQuhtbwHFerhBVRHR1Q36Dwj+wO23CN19rJn8IAVPLh2dChixY2GQ9PosAD0UqZWUgY+PBBuhTqNjthRkUMzJz8cKOWA1MKQixNa8aYFrbrKsdRaydWDtmpKXVweJRrElsnLMFG0Ms8zORjaymavoktt7cRwB4pc9YEI6vlnzpHKcyHXc5RGeCU+x5HL7qc54teR2olj1w+rjYx+dq71eTb0ptAI4u3nK3rXvnf92jN01D/o72OhITCNC/blb/wmbz35LADbzZZyGKg3dwxpIsUNZajrYqvSy/yBKk7irtUTnBBc16w2mYVSiqMew8QwjKBGSEbOC9XKKtviLhanm0ltLCOnYkQPiPo4HNq/d9QHThtuP4KdNkvaOK9Dc/47bQw9mO7otf9QYS1FZ7XVgqcfA8IRQ4cBy5kNHSl58DHM9FSqNchN18yfhx+zwJ24SvyNAFK5VqFiK1eVCtvqCJxixNjKj0ZLR/tGHIlaSMENuSX5eRY9lfly8ICx35I1FPY8GQTneC3WpFxEmNWY9xBeZjY3/qEnKXC5yXyoA3ejp8YXwUuR22GgTj5vct2wLN6IE8YJ3bj0w5gGtEtcpIjEwDZFQpogjogYLy9HXhsmtCEy+sZL0kcVvTmVdCdOf4dXqgwtGQ4tKZs5oeCjeSsZnZOMI6PVToHZuT7ZAGwaD7WjlT3mOg9oO0pn4utUmZsFFHeECvWwcMyF8vGelxg1ChctMHsSEvvNgk2ZvIVlcpbYYI6exb44rOh8dJI+rVGNSGoi4SE0mo9kQpjIVRmisBnc63JomdEQhWkMjC4txt0eanaw450mJk7NMIzYqNR98SaseSEslZqLa2wCKUXQSAiCMfgsNrzk36WpaM4wBGou1G6JWCFIRLZ+/XG5oR4VDTOaXRA+aMLqfOZJOxE+/ZTw2mfhvsAHX+P3Or6rA7MhRqTBsXGt50SGMDCEiqaBHGaKuD9katGBEih4GcqJ4C4Sq0jj4Zw2dVPjMFeMlwxBqEtmznuO2bkJ94cD+1zQ7EReMXFemSkx+XkB7xIhtbKLv+rFXESxL4AhqDsXAFoqWhbqspCHhYXxtFnH6qRNaZNePTiLNqxvfBBcJLb5fS0OmXkZ0U5IUUcCai3uRdhyuRACxxY9TQbHklkeXaKXl8Rw41wX44F9TF80H+BjPUNuu1ffYHopuSc5MTnpF1q5TVsJpZFz05RQEcLhHmm1VAlGxA3dRWdEj2g5gp54bybi0LOTa0hxaIFwIKWI9rKVGJJaWTtuSNGgDpALUm59TMs9LILNipYD1kjlXbD4cHQkb8muZ1fN3I+yuEYSm3IKqLV4d2dVH1utp42Y0yLaOVlzyagJRdXHjDMC+hmi0oOkjp6ce1e216OZM/vG6eLmJ42uhGfbMXhXZzXn2/Tw4zzI60dHW7K8gpjBA+X5cHZfnCMkrZyYW7OAGYzDwAcfvcc7z94D4PNvvNb8TY24zAybDccsTIBF9+UAKC0YIw2oOs1gGIQhDe6LhwdmtVRiSKimNUkJEcpcETuhrSHGFmC0zRkv8XfO1NxGJjYEcObMj/BsjHqA1c/xbb02/XOvoGMdGelrUkf6+vhZtWYLFV1QtIshh4iZss8ZL9pEIkqoQmg7Z8ZjYJNTibq0d3qCU8eyBaIZt1JB4LEKKW6YdT6rWxrHCEkFNR+DR1PkchjYjcPKAa5WyHXgMkQuhh
ERYS4Z0kxjh3hXrblY6nn3MfjatXaoNiSqcgqEhxmml0Z87id7vIHryRhj5MX+nkkm5lCIKRFDIG58g51QxuyBmQ0JmbZYUoYhYqnRVuJIGAMSvJs0DiMpDBAyx7gQLr0sGl7bIk8z0ze1uWp4wLRw2nATD5lG/fo3je/XDWNmaY0n5sF9bgl2eAWVtHYSwTuIpwCjfrtzBLi2WaSJ0Qresd6sz2pV7JjJVTnMC9F2lDhQsxJbFeIqCU+2Rtkac1I+DobGRAzF3UVaJCjR9co86IzEMDSP2pHYfIdjEIKVNgDKECJD66AcgqwJfQiu4j+MMGWoY0vwZ+UbeW7zZmEgEhVKCpAiCW/gqyjaM6Wq5GVPHAdSStRiWKnIcmBoBuw2RBe4jhnLmVILNSQXWG9lbIAwTpSDuId09cJ0mDMWbt20FEgywsdbwtVTwhsV+P/zwCwEV9OnhpP1ER6Zp8Yn0hQpFqkmZ6Uy72rr0I01IlRM4nIPPdAwo+JSBcelUqOwzEeWZc+xkUZv93v2s29yg/iCeEQItXAR/CUGyLQAsHkkurdgxOQURIQUsFod4UJB3UNuPmai1XXFtliRwRhSaFF/8oYCETqsFsSVjMV8pRXNXq5FkKCrUGgMqfFnDLPSXndt8G+DfofAHYX9bqK+9gSxb7XrYO3QakN/+j87bUbGmR9g9gVhVeto64aGFSgBmpF0r1uoYlqoIRBefkzcN8JuvvaAzyrU4sbnZSHUvAZ+4poPuBF145WNfuES0mrvAiAhUCQQ0wghIseMHvaURuq1UojHjC1ul4W5/GEphcPhwKH1y5eyULVQzJFJF89V1CqlBW+53LK/vyfvjxyXZeXq+G2fOHKO4xpzsWYRxsrrO8++HwTCcuJEnSNmPjFOiNxa1uQUMPfFXUVYDBahIYMPS0q9xNe5QUVO4qPngVlH7Tx4kdUx4Gy6NCTP1qDSSe2VbPDlr/8mAE+ufpSrzSV1OrB//pKLaYDgzRul1nWb0lrcv1ZALFKppGRgJ7qDWcWKrFIZYt6cERuS2q+/trZps5OPZhEfiN48s1I1OQXVud38cHaf/bn1bsSTFMTDQ89+uDZWnH++Ba+dezYCe4GjVExZS0R3VjgAU/JydA7KaJFBy8mtowXpIsLBHjYjJFiRw6OV9lwHqmauMPbh4O9eu9YrAi+rkak8joFPX73G66+/zvbpI4bdFtrmafs77o8zowrRoNTKsCwss1Eal3OpuXWensanB2DniN85f3FpZLShGOEWt24Dtju4HheuEjwPI/O05ZAqo7V+9cYVmky53O6cXhEDwzCisRKjYL3EJol58MrEKC6nMaQIIZAlEa98YMfDI+zxgXA5E25OJe5z/mFo87yaz/2IB8e1tfGcN4VUaQmwp9FNCPvh5OlrQqKhZgaTOPe4Vw/6r6T2/aij+1UL0polilWO5Y5DMQiJbdw6J7ku9C617Sg83QTKBpaNcm9wHKvzy+xM/0mtdd0KgnN9YxyaqGtznunvUfDPp+Si0FHcdnBodnhDit4IlyAssEugEliy8i7O2/pYb3krL9gxE7V4xUOrcweDr6QAuswQFyxcYKJkKkEyogGrGx+rNHgq07q2tVQXz7XQnIXadY8bLCby0d+/TuyTkDHxfarcvIfMgfrmp8ibHb+f47s7MGsSEFhcy5SulN/MyaOwGUc0RgrG0nb+rLVF7opkR840CEPwDscTC1WR6IrFRqZapVphKTPz4otIzoV5PtsIo1HEvfMGOWng1CCUJVPUqNUXSy3m99AV7/ESlZlhJVNqJOeZmQnVhbm2zrMBxJxEGaPrkJnKmpUAZCpRtJUJrWXQEIkQw5qJSOsOc8RQKRbcON3K6v8XMI4C+81Ifvs1csR5cRVIAV3tqU4bRt+MgHZP7e89qzuTqQgi7ivahx0P1JqVGaBrh124u0FuHcGyJ9eQB0T7tiyrbcYKSUvEWkOHWStnmaHqBrad6B1j8sUiunacVggswL
xuZLZUynJE56Nz12LEMHJr0Dg0dI+uUi1KCNOqwGDzzLEFb/fLwt3xQHacdt0U/X5PK25XSV8357PFeA3MzsaaPg/P0Kvzse9JQNKTR6IZHNpCc0BpFrIcMAqRZJGJyh5dS3B9b+8BR+dEnRPj1/9uF9j9Jl4N5OEUwPeflaqMAd77+B0Avv7sbb7w+mfYbLbUjXdoprghs7DkeX1nI+53KaV4hV8hZ+eA9mTAx6KVcKqX2Y3g6JkkQpPVqfXYTisrh6xzpPoR2263tHuYWnDqG+ApuO5E9t5heY44viph0Y917HhYFu1di1G8dKrBfSoVuG1ryZ3AD731iP/g0Zv8y1//Mi+rUaKj4rGjai2IN3V/1EFBJLB4n/PaOUeFKcC1Fl4b4AfGwI4tdYJvHXxT3KNESUyh8NnPfJbPvP29pKdvwdWFq9F3WYrDgYt8iywVFqUuR4Yls81K4caf/bG4fI5xspY6m2Pn5fTYynaLwlGMRGA4KKOfinEL1yNcDgsS98ybmX0MqGV2miH6RjzFhG02zMWdI7ZDQmNwznH7QgmJTUjEAYbQKjBJPDke0mqInvZX2KPnhMcz4UZahaY+eI7nkiERD6ZGIKME/P2EhyXl1T6rPfv+857QCN7xukXYmjE2asA5d69/ZqTl8NWobV8CyBp5uVT2CuPlBYNmKIHAsiax463xZBRsN5DHwv0VVCnMm8JslbGBA1OTeQq0xrQQ1ual0xEIMhCbTEVKI6m5ckQJLtkDbb8WH78IY4JajN0EHzQA4d/u3+WtMlP0SDoe0ZvncHuLLQtSILTAUmeFcURnqDFiNiFhQGrASkvehgEIWPBGObNCre4DHWkEaSCEHRI3HOtMqRCDslQlREgtRtiGSCwfInFgfHzF7+cIv/dHPjk+OT45Pjk+OT45Pjk+OT45/ts4vqsRMxEhyojIke6srdXoqfI0REShWoJgaFP+FQvuUi+FKEbG2+a9LBa8TZgWpY8DQwDEWFpkrnoieouClUZABYJWzzyjcIz1pGafq6M0tJJHDYRmrN45LaLKsIkrUTnnmWUZQb3GvVRnJWzMBW+HVpPSxrYwszUjSykRY3WuWajOwQsnsdrztjrVQq6FRau3KGulWHH4GghZqQlyjOibj1k2ic2+MCrcq66k/XOfvQc8GTlDd6R/p382iINWXaeof0hbetwh/VoLEgQOt9i9Z+l2OEIj3VsxJBuiDt3HhoQRElSXDQkh+NzIlRAiGcNSI+OPnp3F5FpnWueeu5K0d9EoGkdUKrbcYmGDtW6kXAul6Z5Yrc4BafVGEYFayfPRhYPbmEsTcZXqDKyAghpqtnY2dakA//Z2Ln3oThGCZ8EdSYEzpIFXiObrEHv3nOLivoeWMR9wnaqEMRsYXaI4Yba4mwQQeqNG+9ORNzm73v79HelbPSTNS9md19TRpF6mWhsg1JFfgK+8/x5Pdk95PSTGx5fcfficx9trMqkZ0TfUNkViiWtnsSLU44xOQrPdbM0ThopRtLbnXajZnBYQp3b9i68L8rCE7BW8ho60ay3i97MlcqAy2IkXdkbDOuvy/N2P3lm76njJ6eerVIb4c68VjkF4KcZFG9M//fZn+Ikf/WHm/Qvefjnzz7/1dd7pSiDtM8czVKY/T2/AcA7dWw1FfYzymiTSrnKdRmyaWJ5+gSf//R/mM7/yywD8yq/9Blorn/+e7+VTP/hF5K23yJtHDJtL4rLAwTsW7SJS94m4VHTO2LwhLjPbWtYGjpoqh7t7d6HghPye5CE6Pm6nxoU2VosJR4VGAUZeOmo2jCBD4X48MA4TWGY0JTWyYJTAJg4MQ0SrMIaEDFCiIU2EcYmJC0nuIjAkEu5awBQY4obS3u047bDdBru6w91eQsNcT3MhcOqYTLRuXoQnrUx5eucDgjGEQErO9Txm44ix78/RTmLN/j4GBipD45Gt6HZ7P4Od5m0gUqtynJtPqWUO6nppVZVNKM6TCwu1ccfifeViFH
gUOU5ws3FZr70pC2WlZQxRTvI+oZmBB1YxYvAKTgja1mullNl5ajGR0rh+LrT7DxWY2vZVjN028kGrhH3j7n2+VAqyC9iLPTx7gd6+xBawWqhL53sbkUOjJjj5v1rFMlgruaddQiRgsf2plaJGYiCIrY1EWQyVuGrwSfB3SVajWtjrke0E8jyies4s/J2P7+rArOqRYh6gWHuAx5IxNaaYoF6RvCJPGnSt3Bet1KAQjSFssRiI6u30i7KqIE8psh0SkwqLFZJ5mccI7Bsx9q610sXanoPvreShUpMx9tJRLSwSQANWlKjBS6sBvxZgGAamIbMJsIkJsrHEQJBAzZlNm4BLKCxJiFVAJmJQhqiMMTCIQ/Nlf09E2ITkJaioSFRCSG490iBbDQtFFy+3qBLKgGlBLLOVxq2yhGS4jXcsb0Ref3qF3D1fyathDXjBqnkpTE4cGc6qw+ksYOjk3qpGbZseOHlVxHlmYiCNeLTYyFhm7MVv+ecOb1E2IzIG0jJhcSYc91jx8jS0cre40rsEacrfiWKFoMUJ/YDIBgs7qozO5JLBjesxVFoLdShYXLz8GybmekQkUk3JxwPW5BhkiIQ4EEsk6EwpGQu+4Ja2M2tIRFOmFLHNjmkakGVhORy4LYf1RY80OQCa4rfaqtH7euPHxGngo/09S4WPo5u7G1ACq3MAABmSueRlEbcyqUvghRrP2ioy03hUMTAMiUMu2OCNJHFeEfy1G/d84e/dtW06UOxh2amvVV1A91yMNbTP9Evt5xxb/fGjZ9/g/cfX7N74ApK3bDcvWeSOTbhgv9lS7nzsBwJjCBwOBzcqFyc0L6FijfVVi0ulhFAhQYzdZBnu5z1zI0K7DIWymLqUDD04PgUEty0gvrBIlYp3PDvvq1vhVPNyWCwnaQvkVBbtQdp5k0YUP29M8NJg1+pWIcHe4El1/a4FYTHXbPrxy5Ef+9xnAHj66BEf/cqv8dvf+AbXr13xn37hj/J/+Mqv803RtRFH8Xc3F7+fJcCtVN6s8MMEXm+B0gEoWnhxgPdsYR8zhw//NT+0ibz9Qz8KwI9Ol+zfeZcnf+zHiF94k2V8RJyeYHVGlyOxRcW1zqR9QXMmqJFwmsXFxQXWfe6yB0MlH6haSerBRw9WesA7CmsGGBCKGfdt/Lezj+rlM+Nik3i8KVxtZj64UDYHiJe+RufWNX8ZjboRLo4DixaOx4Xr3TUhQWwty0kyIbldnQXI0XW2UkhURl5sfd5cjVdcX15hVx96KbTkb6MV9PKlCDyaAm9uNjyOG9gNLPvKcvCxNxLTNDFsBg9MopBr4f5wx8194znPB16actQebLU1pgV9m/advflgautCHCBKcXWANglfqvAS764dFe6jwADVIktqnlOqDPPE05sDuoX9LbzcVm51aXITst6jiEAcvSHGIkEmpwq0/VOCNRkgp/CgMMhE0HugIi17i8lIIfBk2LJnZp6NrRkvqLQ8iv9qfof/2e3HDI/e5ni8YZBInEa07rF80qxcglBTQqw6FSoNGBHVI7QOz7iM1Fi9hB5gYISlYkEpUhlaNzjTiE5X2MUeyUcXhscToK4puKmV4/HIyHP09sDv5/iuDsyOOSNjdDQiN7f30lpfg3sihtDq1GbE2AnvTjL0jBoSFVQ82w4D0+TBTQrR0TFSQy6cp9RV+eG00XRib3XOPkGUlFkRBgEXTcUz3CqKBbemGFpQmZpocg9wCAGrgZLFkb/2tAa1lv40Z4DkCFoKBmtXliNEEpU0OMvXzqCT3CN3U0qTYSjVSGuvMxwbEWyIMBZlibB/vOV2N7CTpk8WbNXvWql558EXjU/RztlbtzvSIm38VM9QtYBbzTROhdbeOtYaJ2bnmMlSCGN1L0Q1yAVdZszK2kmg7U3xlntrzRCOrFmtZ8iTEYPruWnTJHOnBD09695Q0ARlwflL+/2ew/FIbW1l49D4FEkQDQRapiiKld6d1m25hMvra4Yxkl/ecjjsG9
m8IWZ2Iov38RKDR7uJt7/XNb42T58yvPM+3/rKl5HW4CRNWNg4SWu4Fm+lhNZ5WWFflb2dyT9wOvp9rzoePETD6Nd1Hl2cHa8+4z5HjBMi1I+ewa/IUPtMb3GPwJff/QbXuyfEkEgyke4cCQshIKO/HPMyr99Xa8WiuHBvPmkKmhlVC2Yuf1Brcb6ZuhRJJ0v3jkADcpNs6Ije+T16N7atzQ9d762BNt7BqZ1zZw/EVF/t0HyANgY4FOcL1R7QF+MiuGxF1kQIhe8x+NJ14o3v/TSl7bDv/vJv8M3lgCX4xvNbPhNf8hNvvM3/49k7vN+fL1CKd3UuRPZa+UKAHwgJscJH7WLGIBxMOJiSA6RqbGLi1//1v+b4/AUAn/uBL5BuD2AHePQaY3oES6Ec70kfP8f2/s5yf6Dc3TbEocvx1OZH6/89TBObMlGkVQ1QQlsfzoNXFdZuxvOjOzGAm4TYbSFdwGZnyOHIYVyIJZBCZWozX2XwBHuMRKsMmy3bMLGJrFzbgYFDCryY2roWIyGJ+9dKhKGp51+DXT2CqwuG6zvyx6dn2i81AU9D5I1H1zy9vma32fg8ZqBMlWXX1gmMYRyZthvntJqh0djMj7i486Ty5d0d0/6Gjw97AjCsmFVfC07fmfBO7LUS0VQEuvtMoWvaeQd4NnW7qlpXjpmqYqUyVONicdunTRWGlgRb91imuoxVBCEhYSDGgRQ8mAVvlnNlAGsNAeGB6O05YjYMA8X2TKoQIxa9OrVp1/41eckH+494o7zGNo6eEBVDLDQHnvYcB3+f/aE0eZgg3izY1rx5nhm2AxIyQaKvX6G0pNg178CLMRaEsNn6z0olGBw70dR/DYkVyUvzxfi9j+/qwOyw7AnjhOTK3CLYWo1A9IBLBCQhURBk3TglGTEuhBSJBqqRkHwznaYtQ8e3TV1ZP5pPYG37vT8l/4jpWi6S6ptmbkFayNorrBRxkrupL85FPQsPXQQTGKmk5HCCmlBrotTQOndwhAw3o3X3JCUFZYwDMZZWJmoZZPA/MbrjJKaYeolUHgi+eoAgDU8MURmitzH3DjqxCha418xhN5IvN76BBu9mWwUS1YMq45VSpZ02nxUx4dQogJ2COfCJLqEHZe3fKQSNUIV6+xyAeH+EaSIWQ2qhzntCXlz3p+181vTEkEY8bS8/lj1b6ozwNKApIEEI6gHc2jG6dni2DFDcn5Uq5FpZlsXHtZVPJUXiMDCQCFWcFFxLG8cOORljGhg3GzabHaCuaF9qK5+tj+f03e3PVuCNJ0+5+tSbPm8+9Rk+Ex+x/+gDbl7e+Omxlfi7ZulymrphAc3CLcb9ebmjTQ2VUzDW9ajgrBR9dnm9tPY7LTn9GnpQ00s1ayNo+8V0dg5tN9sJ9IsZH+1f8tsffp3tpz7PRRowyWg5oEmQhsjkRrjVAGXJbrUlgdI6MMHfiarajNSTu4S0ALyarqWybNrG0Debft3nna6l3494YLac/fyu3d+1NXmEFrD1hKWfp7zyrHuJqyhsJfBMKk/OosGorgC/scJ/lOB7xwlUef/XvsrL9iDn4Oe5C3BRAr/27Fu8tdnyOYlcNEjzG7VwK1DN15i3Bf6QQoyFI6epeos3AFXxBoGbAKMWhpD4zW99xT/z/vv84Pd9H/Kt9+FHE/lyC88+Ij3/GD58n+OzD/0WX2YO9Zai/r5o34hhRZwxY4yJkCaCKSLF3TT027uM10C/le3UfIznXkZWCLdwcQnXl7DZLxyuMkMdiaWyaVGdtQ53GVzPsACXNfHpFHlSekYcmTO8EwvvTZUwePkzT5FQI7V6YDZfwXJ9xfj4kvTaHfrxqaM0tX3jerPjtevHPHn6iGG7owQhE5nCSMiFmFqqpMWDxe0GGydKHEhWiJvKZvTvm4JLAWstlGVhgFXxP9gJwe6doP0PkTUo6+jOkT6HvYlt7NlSWw/7oItWpsV4tMB1hosKSRUNSm2rif/N5ZqIDpAMIbofcQ+4kO
Y3jEs/mdNNvPQ5+BqNa45OQwIZCCmSTLlDGUy4aMHUV8aZ3757jyflD2MWCNvRbffuCnGc3EQUEMsUs5XCYBLOKjvtJlURRmIs1JhcUqR7VPeJhe/BRoBhR45HkEOz4WINzLK12CBnsp6no7/z8V0dmM3HmRDVbXK60GQWRycUCEqwCJIIIsTW9hyjZ8pRhapCHBIxwJAmNpsdsbXo1lw8mw6+uWpr/Q345AIQcY83wwNBa3IEhrmKQ7tWixA7stSCd2/Rd6TMrwtvzRZQi6gFlhrQYqQoTJ1bIX6PMYhrgAUlBpomTVuQcCsgCS6jETW2DdpQKtrV7DnxAGLTjgFFdWHoIoBW0BiZLXPcbdFPv0b8tW+59ov4hg++GZ5n/z4OzUmg/6zJETgnrgVv9hBV6Qbv1uKrzougLlAm9O4jANLL57AL7r2WZzjeE+pDdKt9qfPr4kDAA8lgLmVicVgH34ahdcQ2Pkh7Xid7J+8KWmom58xRYTFvx5ZgzjcBJLhRegoDVkC0oHUGraseVhQhDgPTMGJzJlvmeDyekMx26GkgfayA3WbDozffJDQfSQQ21ztee3zJ+3c3HNQNn4ud8ZFwDk4CYhE0Cy/VuKXZLfUvbAuUf7cHNVS3s0pSyW3DO19ezp0f1nPgMa/YKfB69SP9HElOQduKfhjrPAEvsQ7A1559g8urLd+3e0oZjc3hQAmbNbmx7jMrgRIqqm5+PhddbWdCCI2DVlddwVorxZRcK0U7WuFJVw2eKPWGe+NhAGYia5l2OQtce2C2NdeU0mBkPUMR2+f6+K0dzdLGLQzcaGYb5EQbMLgJgUGV/0TgzQg39zP3FgnIiu4cDGYCuig5KBsNfP1+jw7wtHWwv21wF+G+KG8I/IANZMl8oM53im1TVCve5dsufKf+PKIZF+1c7+mezVe/yhuPnjD8q19j96UfY7m/Qz6+Id/ccGjI2uF2z8v5pnH7YuvSdo/isa2pKbj+FGaEODT5B0GtPEBaT2uDj2YwD5J7wzg0+YgZpgNc7mFaMrd5ppQdKoqlRg+pwmUaWOqRMQh3c+GRCN9nabUFemQTdRIea+axFW4wjgykHJAprYLc93Vke7ljenxBeBQZtpVwgOsQuHryFICrJ68zbbZIGlhShGkibCa4OzLPC7mX1CQwWkusgyd9WquDCpOvz5tL5VJnjnXmoM8JxRikaRSeVSu0vVcBVlkoMw8cOkdtaXPapYVovFwhkUht1jvKawwKQxGui7GtxmSFmYK12VooKO4YIs1TGClewVivyrvmjV6doElRnPGhcc60kdBwSaiVRfdYUhY7uU98M8Bv7D/gj2uivvEUO95gzwJFCqOl1RfZ6ozE5K450ioD1pDAVuef0JVHhsj6OVXXGj3pgEYKgoVAmK6ppVLq4mPdPWlr60avkOt3WAy/w/HdHZgtMzFCqrJasnRfK6KgTYrAsFYeOy3MMUayhkacjqQYGYbEEBOpZd9GoIbQsvcBy4u3Z589mBj8T1GgZefW6hZd3wlAkriIaYPk3S7F/cpCi2w01jXbCRq8HJtjI/rqiTjaNF5S8KBMMKJUopzMUAIe+FkLTnXV67JVvmH9b+v/rcTkpV5B2bdFa1crEaVa5cUGbn7wLa5/eYe8f+vZU79Hj+kebtR2siMBTuiafHsJZyVGh/VXabfug5wVaibcNZXqu+ewH8G22DwT8oytGclZw3HzaDM8KNZaQZsaeWvT1uD6NDRBxNjDyvPWbjPni+Tc/FEDlULFpTPi0DR3holx3DDFwW2RggfZ4Avt6ZqCZ7HHmVwOzlPDk8YezFgbqx6wCLC9uEQuLk5cktvnhLwwRZjUOUGdiGZn2mHSYJowG7eL8aE4GtKlIMC/59ylISKN++WlQG0P90zCan2eJ/Sin8uvofPNe2Dp93HSV6ripdX+mR78uAXQaRwiwjEbv/3+u1x+z44nMTLGkaC++AFYSuQihJqIYmhZVvStWzKpQK5t+2mad7
XW9U9vvHDfB2uBnq3yFefX5QigrWU1g1V+r6v9L7gQbh+zDlgrZ80RcNKJa+c+WPagW+0UCIrwKCp/DuEzBN49VA5JkFKdEdntg6qwx9iS2IeCVmVL4uNS+EpLDt6QwKeK8Q7G9yLck9njz3FPS/NxCZCCz8tRPZAf8PLtoVEfdgb7PPNBPmD/1f+ZL14mxjJw/OgjDi9vuG18qGfHG+7ygVJ0XTeiOV+su09McWAzjkQJJImkJExSvDRV69p4sXAW9J9VAc7RitJuYzgIw9G4OGY+yAUtBlFbiQtis1dLITJIZCrwaF/5dBaejJ4ApSGiKFdBuYzKt5LwQRg4xokynrLSuW7Ryy3y6BE8uuDytZeEFyNPNq9x/dhteoarRywxUEYhjM5B1ps9dx99yM39kbujB2ZhnLi+uOS6FqalEOOeIo5myWoxqMQhMY4bNE3EsjC01qXv5MUqwqrZUowViYRTo4Xzg82TybZmrVQAUXd8UGMogSur7IAhKFlaANa/B7y6cAa/25nzjKq3Nyi2VjdWL2sLDC1xVhswMiMDpVYQZawjt5aJ0d+Oi2J8+fA+3OyZPvcGh29+w3lDKDbPWJPBGYYLCk7m789cTbEolEaBGUrFSvaA0TvoCGlASyUXXd+NIXklTYaB7eUT9nVhXhYHV86CMGkL23z2PH6347s6MCtLpSTfuGtDOYpWQmr8HnFESVppq0fyJgULddXwQqxxgBxhSn3zjMIQkwdapVBSZg5H52q1VTcKpCbjbMFLQNLRHj0FGwiNHC1to20biRmd6dSzmFEGBgakDhjJfc8GWfWHqrZOHzEv8fV+tjN4oqNg0ryO7AyKlmDERqpUBUQ9EAnWTI69RNmRo6oFSRlM2NfK4e1HHF+/Irx7+4DAfd6Rd/bTlawLrDBv506dNYk9OKqdeBBmjgRqcXhcGn9SDs+w/BiJ3rkaz9COHrSIBCR4VqaKlxSbhpWGdBZdhAbXu7acBVkN7NePVEXb5l00YzL6C21eHk7txZ/GLeOwccP7UJE4oBIQGRiGlgG2riNEsOXIshycE9Xu95xLg/mLWnC+0vbyERXB9od2j5mc9wRRJnzOFU+yW4dnO42AlcBhUT4GXtAEYeVUturB04NnIoq1+b6iFZwFGZzxrk77I91KR+RhKdbRv9M3nN9v/2lHkjoCNZqjgAOR25d7vv7iI6bdU9KwZbvsT1pPmwnTSl32oM6hcb4glB6Yabe1MsieqKkpuS5Uret8tUbU7+XKdXvR0zg86HjjFFT6vfjF3+OaUlV9rCOyzrtz0/fzsbV2/wUn+b9oH/qhEPlJNcZQeVaVxMhQFu7FP3tom8Fd25HvKKj6+B0oFIOL1gRxYwqi/IcWUDM+jF6SmggsousatSRHWZM5P24YhKLGUnUV2lU1ombCzT33u5nf/KX/ku/7zA9Rbl/y/Plznt/dAU4/eYmfp3cH9kR1dRDRjFh1WskwuQl2LK4HdjwlX/29rM39BTtNv/5c5vbfMsMmw6MFLhdtFFQ7bZ5SKcXYxEiMkesS2dwv7BhIQ/vMLiMFHlVfI7YIjxj4cJf4ECitZlhHKBc78qNr0tPHXL5Z2MYrtuF1bPIgrwwDwzggZMrNDc8/eIfb5895eVg42gmRTQzo/ZF0eUCmLWMKqLhwek/VSzWWxS2BokUSIzDTVf/7OK2JfdOHbH7lZE5m9Fn62uvrnzR/tr4eAo1HppgERKU5T7ibjHs09w7UglmmWIGKc6yD67rF1RfO516tvo72d7Nq14HsiNnYtCOFXF2hLcnAnGZCcK7d01J4/+KGuxff4urJxBQ25Lgh1+oacY2HHqfERiJlyc0BpT3fGFbumOYCm3avBKq4A0oIHmv0z5Ezak2gehiJ4w6N940r0uZnbWj72XP9vY7v6sCsFnGyfw1og6Zyqd6hk5LztcQQc8ptseaxaItHc8FWgr+38TaD6HDa1AGiRZY4OD9JBp/ZHcqUgIgypiaMrKxmwc
QT+T+2dLi3HFda0CbGpqNv0Re8MUxYnDAdMEsYAQthbScvVhoPxv9Y5y5VOann65l6vJlLSvTYzQlq7XOGWW0lBWvncqmBbfbrqqM1yx5X1j8+2jF/9jU2v/mOBxadLH22q5/QsIfSDg8Cjnb0AO386IFtP6+ZofkhekS+RUslBDdJHgleTlRZMzKamjsEgimI2y5ZEDQFQi8/Gi1jUyQOeB7aglw9/b9WXTlJRbyrS614dt+6eYc4eHBXbUVnQ+M69rJISiOEQC6Fw3zPYTmSW4julu1+BDyr7VmspMSwGdkQV+sTTTAMI+MU2A6ezGkLKM8DrKpQZ0fKnpmXMOyVz/TAqT9Io7aSh2IhPvxcf079757lnCFObZM8+56+ab76vbUhqOHBuU7z6GSyrWwivLz9BncXUNMVec5r2W0kQRo5piOGMMaEluwbfy9Rmj+TYkrNCym5VVo2F1nu199LYvnsepe20Z3bEPd7krN79STM72ZPZRRzZffgSYaLVp4sqODkM5r7mEZhU42PDT7Tmhv+4xwwFj5IMJlRWFiauj8WGXtCYoWD+TI1aDM1F9iGQKm91ARvp8iUK+8JXLVoO5mywT0r4dRx3rv6luzlo4lTaWYIvu49KUfSdsuXn73P9WbHxTxwOMzsm1OKmqfHnXNV8WSjnM8pYKOtWUMiEl0xfjCItWJN9kBaoOiP62F618e0P79hCUzHymsLvDsbd6rMobJtzTgxamviiIQ0cJF9fOVQoAtHB4OdER7BozkyPRZ3YQkzh90Fx76kDkbe7lgevc72zRcMzyqXdWQ4bilTQ4CSIfMty0fP+Pi993j/cMuNnOZXfw8mMvPxnnsyVg5urC4DmtLa4DDnzH2eOcxHNDvHbGwBf+9sB1YP126PVtt8zbB2ZXppnhNq1SSeutSSj7snvh2JN3V60DgM5JhO7gZSgAXTPe4aAcUCwkSVTjkRRI1avAHH6UO5eUXX9Z0VIjEOTAmWRqtYTBhHIZoHXJcmvLyofGt5h+9/8Zht2jI8eYP5xYdYLkirtMxyZLt7xDD4xVv1d15iWPd8zKlIIUUIkaoBNW/Wk2RrybPaAqrUpWCmEDYMu8fU40ukrRKllRKMhxSQ3+0Iv/dHPjk+OT45Pjk+OT45Pjk+OT45/ts4/sARs7/9t/82P/dzP/fgZ1/84hf59V//dQCOxyN//a//dX7xF3+ReZ75qZ/6Kf7u3/27vPXWW//O31VqIefYSl2dxFkgKzEIF2nrRF8q1TJqp4KJCMQUSMNArbJmKfGMVyQ48hJCQGpEJDqxO46r9ZEERxaimNOVhEbEF4o85Fclad1iCiUCKKG6zhLAVANLDMQUnJkj0eUlgnPR5nb9244kScdRGqJktiJyTqI0zBRVaZnHCYpO0X+vmttTmYCFJoraQPJGYaSIcJELFgMvQ+HFxcTrn3rE9XaiLstqPrwCZnJCSwS+TeC0I8e9YxPpSF47T0PLGsB1KnU0WGIF+6rzMGhaYqq9LMWDBNqawm0QNz8yza5PI2FFRwGkegdejT5uouq2TZ2jpc6DoCrFFDUlREfKkkXG5IXdlLzVm1waBy+65RORYfByZ4oDWSvLcWafDxxrde9R4+S5ts5Wmtek8zuWfHTJlJbdLQYbFdKYGFNEtJcSvJTdfdNrawp4Fp1bNtZW+rHz8ltDetRRVetf04nAXWSZbz/MWL0XfcD8h71U2QDlU/nTTh/rgqmvnrcPQxEXUzWMy52wGWDZPHcNvePEpglkxuzlknkzYsCUFXmeWc4QUHDEFPVGDm0IMaH5EK6ImVCkIXwNfiz4/OsrSQAIskrG2Nnne+fZHtcIk+rvEhjL2vF8hjiaI4fddzRh7Elcx8L/oHWUzrimVy1wTIF9UbRJpGTqKtHheovOCVQa6mewq7YaZG/E/3sPLtlz9kCCnfSvDOfY9eEL4v9T6onCvQG27XPDs4XLa+Gdj77KZ+PbzPPsHorQXA15wCXM9LXXj4
ojcAFHbXKtpBAIMZJSYmh197w4ltnXl44InR9rg0lRhgUuMmzmwvNaOKpQejviUAhpQ10UGRLDXBmGiBzbwwCYC/JiwZ4rvC1smHg7QUyRj2NCm3fTPB7QYYtePiK8/ph4eWC49U7ZqTVd8eyG24+f8d7tMz6YZ24FchypZXnwDkwCWIY5s5QDKQyMyakRS7dRyplDrRy0YBgTYW0iWddY2nqAI7dirHM7n43TunbbGWrG2TvNWVenmlcnBEJMpGEkpUDbWpwD3bT9qh69YqSGhEKn7A+E9vOC1kgphZzzA9SsT7pgiWSCkByd08qOkVx8TZ2PyjHCV9IH/CA/wF3O7K4fs330lPL8fbaNf3nUA1ouXYIjRCoFbdWNteCkyrIsbNIWiSMizVIstMpJQ8yiJBecLYWlGmm3ZdxesZ8PyNosYY0j+51pO9/p+PdSyvyhH/oh/uk//aenL0mnr/lrf+2v8U/+yT/hH//jf8yjR4/4K3/lr/AX/+Jf5J//83/+7/w9QQxToUigT60kI2UxDjIjsjSVe3wj7cMe/aUPRaFWhlbKQNsEbJOxlEZSHFxCITQ+ShVc+ZEGDScjmDCJkSqU7HXrEOLqkVaDt74fvSXFJ1g1ZDgt8s8XuIwTqiMLypAUMQEVogZCbZPDRlyXjFZKjOSlkoKuJP9jzV5WNW+SyKVgtLJaSFiX+g4jieL+gtIMaCUyF2Pat0VkKSwRJo0MKXA7wc0XnnL1dMPlV2ea/ipDFRYRRlPugUggoQ/EITt5upfMlJPm0Hkncg8YFA9ULAQsesDY6QlhvyAzaNozzLl1XA4oup4sNH4grT26AlYLIv6S9y5WoiAhuu6OLQQVb98vM9IcEKRWUCVrZqZATN7dyYjIWSt1WdAanT/UagbD4GT0vg2rZYoWZl2o6gIvSxPaXYV5OZV4QhuvwMgUJvLxQDh8DMDGjmAzw53Pmd3RSxOzeQfq2nElcEMgV0XCiZweZL2sFoT5M0saCKo4zSRAqStBe29+nakVTHvwWM1IbYut1FXRf2gvi9qp3NfnfS9rdZ5W5x31v9PufQC2E+x2xuUjuB7vuNpUPq6vET9uCcltgs2Gy+tL7mKGu4WZGwLqRtdACUrM1d9pmrcmnSLwcPFsMkxr+fXQgrM+pjuBzMgsM08QFlyRPcqp2WAGPlK4iDAU48DAAeVSKot4UA1wZOBoCyLCPcbjGjAKf6IOfKE9oBdtQ5gNZnOPX8z5K1VOHJZujt0f7WGAMYt367bPfK5GolSXxjBzvmwb58DDcnoyuGzzEXNKxgB0S+Yn7Z2NBloL32uRb1jlm/aSJSVicW3IGo6rDyTtO3pXdz8m8/cdUcZYCeLq9K5J5dwl/11bGyhEOJfbWxuIonozjEXjNQIfV+VxPvCivCTU16mD/8asB8ZlQOOGGCvDMLCTERWlFudyaj4gRyXeVOaXM+M8ciFPWXTDpzYZu/bvfKGReTRsuyHtHnN3/ZzdxwupZG4/dNmQw80NH+z3fFCVWwIHlFoWhhYsrcGU+TOd1RsUKjMlzy7b2MZtNp9j/d4PKAeBK/Nk5nS0bs3g81RLC/Rbt3D/vvXtldI4tjOhyEqdwAKFwn0txDSgYyVvNuzYshk33GpLD/QerUIqIzVmZk3+kKyeEVyEUA2tFV0q83JkqZViiVxSaxzA92oxJG7ZIJhdcx1n5jAzWxNV30RKPPLL+j5/5uXHXLLj7viMwQqpDtSND9hFvuYohWgwSWIvQHDOaRhaM5hCsEyehXEQNCl19iYEiQExnxOWC5gQZaA0qlIcR9LFJYeXnQvpnw3h7CH9Hse/l8AspcTbb7/9bT+/ubnh7/29v8c//If/kD/zZ/4MAH//7/99fvAHf5Bf+qVf4id+4if+nb6nVhApvpE1lEBUqFohF2b2ToCPzs+yBh0ohusYuCm52QYzl9SoktcAoYYBbCSXI4gQhkTUoemt+AMMkhhCJRCJoo4ymV+byVlTsOABA4
1/ElrmkVi7QCMTQSIiAdVILka1jGBMBWrvwqkV1YhpbGTJ4MiOKqWpZ9daqUtlyacul2NZSBYYY6J3lRhKlexES6ER+DNGZN8ywKQBE0WjwKzkpOyf7Di88YjhnZt1gx8wKN5BJjgf6NWa+tqx+cpibJyCEaNfR3sO4HV8czSyE9VzzoTlCJsR1YVqhdi5fx0d1ZbhlHpC7lIkpoikiMXelSmnuYJzAqu53MLaqNGsuEpt3Ty0Nm+chFOb1EUpJ2P6kCISI6aJwIkzURrC9zt1T5//WMVRkkkEayKINs9obhtGuUfKPXezC4aeOH66/j5tfBdsRU86p6ujnMDajRrPOIwV86YWCafrDW3z7wT/sz92Zj9zjsTRUNEe3K2bAR7ES2NO9qDgnGuTAmx3wu4JXDzxwOxiJ4zjQnrzA27rGwAc7mbe0ML+bmB4JNzrM0Z2LNyceI9qVOtCmidBWWn3ul6znURle5IgwIWcNrxggWoz1wE+0sgTCn/0asNUKo/HS//Mdsu83LLZ3wHGs5r5isF75kjT3HZh1QVJeAAC3Abljyt8P5m5RS93wVGfkSZzcNYCXeTEy+zPpI/fNnugpgKvr5uDsSDsMI4NWelB0rnUYdfEGtv1boN3/wLUxpmMJbMf4HqIbI+V41H57JuX/Dfv3THoxSqzIi2AW3lI7Tt7UAdNkDoEkgRHFDVTs5JLpdSyVgX6dZ6jvg/mGye3Ca1gWUkFhqzIXLClkpteWJENRZQYChDdrisqNUY65ByIaM0sdzfkuz2RHSElhpS5fu2al9nXkkMU5hS4HwJ6MTFd7lhsz/2Lj3j+zKV+bg9HntOEgjm5pfR38nx5dKaWvxMGnU68JrTZTolDHxOVU7f1apwOq84kOBBYOHFY+3f1BMpw5ChWJ8Cvr74IUQJhSJRpYL8V6mZH2G6RcWBoHMYuBRXCQjYlVJd1Ep2w3ulJas05yqwzRTNLXXydLYXSOIDdSmxq8lUahTgIG0uMneMclEMIfFMK77z3nD/MBrvZE/dHjEBqFgFFtDXG+T2lGMkE1zltyVSpvsUbzlE+TbmGCFpzsjBvvnOXQ9dF69csrUtfc1s79Pcdl/37Ccy+/OUv8+lPf5rNZsOXvvQlfv7nf57Pfe5z/PIv/zI5Z/7cn/tz62d/4Ad+gM997nP8i3/xL37HwGyeZ+b5RLd9+fIl4CVAEa9vSZ/KZlgtZMtYfdnUmaOr+nbhOy1Ualt4fVMXMdQiSz7Z5giCkpvOkrjtkwlYairFMKUtVQLRPOMWKa40nkCrnja89kJ0iNnUmn1NILSyaGB0jMlGArEFXMG7kTSsPou1qfQXc2KhYaBG1kJtnSc5a1vEoBSlaBfRhFIK0lbwXI1SKypOfk1VyLkSycyjX9dlaUTPFIjVGyxun+64+dxTht/4Os25yZ0LohM8g/om20nHa9DVH1N7lucL0ErYlY6SncPqp41+/Z2SCSX7zI+9A7U60mWn81uzYxBw+YiYICQsJix1+QpxHan1YipSi3cltaC/1sqcszdSKE0Dy8c/xHgqgQdvOAkhMIxTa7cWrC04Pge1lRqiI5X9PtsYreXavkgKKOL6aXd3DHkmFs9MbblD6pFD015YSw3tXEsfVyJLk398NR48l+PQXlNsoqpFehOLUjqy1j4ytyww0mB9exiKr9fRfmeVxbCmiI+/ZwFdPQPTAGPyDHPq2p4b2Dwytk9g+xQuHsF2A8NQeTsEfnN5DsBtHUhfERivscNMKcYNd0xi68Yv5pthR8d6iTW2kPK8meHcqQLcZuneYGy4whEw2XCrR/573/cZ/sKf/U8Ybt6nfP0jZOfCn/HpJbz3DeZv/QYsLxmz8e7H8H8FftUacR9fjLfFS58hwI8K/EkCNegqmPo4B+5a4FxUWTht5OeI0RosC74+NZeTEOBJ7U0JihoMRKJ6AbfglIvKK13RfQzSSC4LtoGhwNhlYCJc7EC1skkRlsqQYbuDly+O60wYOA
VmZqcAvEuMrPPGQvMkXlCFJVfmUsl26pqNcuomPO/sPd/8uhyJAjbDdoZtUS6WwqEoJbQuvLptAXoFBjQqS6zkkKgN4osyoCwQC/F2wb4hkG4ZLhbGq5Engws+5zpxM/qG75qZzUVlPnDfBJBvgWM4lROtNguuHiCfPceljVdHzRVWvTbaMzs1x/R32M/X5wd4eViCnN7rdu6ZU/dzpyNE88REVQk1NmCjf8YBhxqF+03gcHkNF1eMFzvqpEzaJGrzkRAzwZz8r/hcD1KI0mUwgndOV6iSKeLaZ5646iqvpBVEAoejEmPTN7Pge3EDNkaLXGjkdpP5rfEF33984jSdJAiR0LFAyyvtxVqVLIWwSu70ezRznbPS1mqVgNaHslRCRLUCC0FcrLoQqSZuPQW+oBWXvrHfZ2T2Bx6Y/fiP/zj/4B/8A774xS/y7rvv8nM/93P8yT/5J/nVX/1V3nvvPcZx5PHjxw9+56233uK99977Hc/58z//89/GWwOHDUNQsEJom2JU78CstZKXo3MSGBGLK8+h1CPgs98wXGm4gg2U0iwiAIIHcNu0WQMabd1CQ+9QGxOYEEohq+9WZoWlNiucfiraxpTaw2krp4uWdrHaEZEJYUAtoAiqgRAitXjbLcCcKzlXagmUCEpFa8VKXRGzJVef7NXN13tZVnFPzNgevdZKFd/K3S8zIotASNRmyVRSWn/XondA3m4HXnzuNXY74eK+LY4DaHQOnfc1nngkHbVZ5/5ZgKYtO+8Lq6hz8TjbxM1OJb4uyWZLRsqRWnctmonUuvgzWrP903LvciiNSSiu8n8qffjnrDbd97xAyc7XakFsb+UudtLb6Yebxsf172LeedlN0ZE2vo0zYQISA2nYss9HtAUKq2RAO29Halx3WMlVub29QfLYljqwkhsfB4bk3LFUfTHv/CI/p7UEw1rX4Gl817HvyIbR8EBxRNS8xNz5SarNlllauVZZFbzPciQ6fxBOfMJ+T5s2B41CEpgGmEbYbAPjTomTm08DDDsYH8HwCKZHMF0k0sYgVDYmfH7nd/CvivL1uys+9dEB0cxzG1yzSArbdmGDuX176WgLba626zq3hSqyUmm81CqRIdRV3XyUwIflyJ+dAj/5v/xfY++/Q/nl/5x6yN5yj7+DpjBMEKbA3oxPP4b/7JlLYfxyQ3mvTBEL7Aj8ISv8aQVi4lld2PR5Jspszn87tGfbeUT1LJBUTuXhLhSkAhcKpc2bY7u3LEa0E1rYS8p9TqzcGInEsjCmxKLGdVFs8oQ5GgzLhhKPhCEwZEU/vuPtNy548fIerf5uTIFVgb4H66FHFOs76zInUd06rtbKIXupP4utvsihBQjnFaIeZL/q6VqAmoUxG7sM41I41krtIVCwFRmLGBqEfRTmBJs1gxBsGAnbHWNW6m2hfnAgvQaPNx8TN46QzhfCfDGQojCqsa1OBbhZDuugJoREJNXi3NH2LHtZvx8mQpXWydrmZ08YchuHpaG/PRjvzz3ThIDbfN6cRb5mHrj171zphWfJ6UlTrDYOcv+XlqBtBvL1hnx1Tbp8xHY3McfjWvpdgCoz6hKsruVoYKbUnowSoe1T2bwnuie8wrDOidokiSrGWM2rB5iDKI1WNJjwuCbe11t+6/E9f+Ibz9kOA3MKDGWPNqBhsF4M987oECPRglcJYgdJlJxzEw0P3lUf3IEiiCsl+Bh58CjmnZpaxDnjMRCbNIosB1DF6rlw1O9+/IEHZn/+z//59e8/8iM/wo//+I/z+c9/nn/0j/4R2+32/61z/s2/+Tf52Z/92fW/X758yWc/+1mqOkE7GiD+9pi6Xre2slM1/xOGE2Jmpq54YVDtDrENKg4U13qgSXRidUBJBFt84jS0Y0iBboiMuRZWkIiWmYRRpBKjeL25b7QGEs9U8qNnADF6GzBAChtgg9VArS6CqWbEaORq5IZMlQlKNkoVcnU1ejdmrqve1FIVzQ77l+ZaUGojogfnpA
HMNVNqXQVCRYUSKzEspOjbgQ6+QJhVdPAVNOfK/JmnHN685vrZzXqPS/BSTEdZQnXE5nzhhIYccpIZ0LOFuYE1D6U12iZa7ITuWDkiZUaK22IIAxZduWgVnjSwJg7qTgje0OHlj5OkiLTGANqiQMlIqVALWvoG6yirhEQwZRgGD9ZaIHju6xbEFcMtukuEIcxlWXXWzAwL0WH5FpT1xRUeIl7ag9L290N2LorGHkS5we8YI8MAB8ncH3wgRU7bVm05dUdt+3GulN4S6hW9iNLKyrWVJR77leVn4BbAvqH5ZuGBXCfZp3AKwvp9RVi5eKvRfIBpAxdXwsU2MFwa0xWM14I0BvpwEUnXQryMhF0kjZEavHwsSRhacnNVCt843DEdJsY7LyfUHHlprCXLiVOpZ5VekYbuns3RriNmnILLSmWjwjfbubYo/9PtxP/wz/4p8q/+n4j/xX+BzEJNuo7hGMA2zXLqxggKxxpIG+U/PgQ+ah/8wDwouEiFP1XgIHC0TG2bM/g13psHV32DblqhD0p57ZbaJu2EfxHY1hM3SRr/sZuq1/UeWd0HHpxLXDfyMhcfwxHS1h9mWZQwCWOGeagghizw+Ji52EZuW/KmCmMcWNMEM+cKixF6eSsEJCQX7jXlWCvHhg7Ws5sc+pzlVNLsz+2UaLSfAXMx0iyExRhzYVmOjJtGkJVKrsY2BEKsWIrMMTInQ9vGH0XABixdohvF5lv0diZ81dheCeXCKzmPDWZGRI2LfeUqjFQVNiEwtrl/VH93YrvezqvsMkrn5ceeMPiUCxSUIvLAhWNFfeUU3PmacrLMU2Nd99ScctKV/vvh6GuXZa9UU+fbxhM9wwiUZNjFhuW1K3h0xXZ3SdyMVFNi9yGkcSBRQhWqtsYaKaxUICK1ejDjFoRKVkNXzbTW3NYkW0LKVBOK9epXWB+ySGgCwTt+6/Ked/UjvvBigkOh5IUaHa3cpp17QktwAeJV8/PkNGAhkPMMFrzpLwZCylBaQt7mqhEgOIjiAtyKmpDSxDB6TKLbg0vgSCZpOUXBv8vx713H7PHjx/yRP/JH+K3f+i1+8id/kmVZePHixQPU7P333/+OnLR+TNPENE3f9nNVxTBU5AEPyBoUKo2Pc8xHJ/qmvnF2FCYQCahUzGYwL4H1kmGxipLIOTfOjXdMJkuYNcQpDs0+RQjBNZ9CyN4NJ6eOIdOzBaTNpRAgDIkxtE69MFJqpKhbxtjZpNPiAogAtQSWbCwFhujIRS0enEnbcZdc0eKm2aV4WTUQCO5kfTaGpXHfHA7xTi4lJlktiUwCmhIikRIUVLBjZnnjmpsvPOXRV9ygeLcotQWfNXp20xeb86nYVdI7IrZmbB1Vg7Vbrf84ykmkr59LlxmWpQlOTm1gY0O/2vZaGnIptprNi9U1OOshkGtOmWPmtaI1YzVjeVnLw7VWagvmJKT1ZY6kVa8MnHcYUkTSQJo2ro9WPXjsKFvp3bIm5LOMt76yIfab7eTeKHDUQpyV2Ly8tikyRRiGSBgCxxB4tswsyhr0QQ8y7MGpezPG+c96OcQbEawt08qwDWy/z7fDj+5nytGthmJowZ00y7J+j3J6xmucbJ1T5gggwLSFiydwdW0Mu0rcwfQGTE8NLtq7sRsJFyNsIkwBjYlsylIqRTI3rTV4870buHvBex/MXN4OvD1PHNKBuWcAQGlIyzmpvwfG54FZtyESWnBqsLHIR1R+rPnp/unvecLlf/oz6PP3qf/wf0t4lNhLJimkXiZPimXQLGgx5gpZlcsI15PyHzWWxn8usIjxE0W4H+FOhadFSQFetoVkwZiDJ0Cpsd5ru5++gfuE9g2W9m9T4yWNnObXBV6KHfHSbm/MjDgCFtZTuWtJkcRYKhuJpGu418LQiE4SR7IdiBkn1AdHwnZL5vFu5O7QumargIWGEvnbL97qhnQUQhIZnBuqyqzKgVMjzHmwRXs2QU6I0vkzHDglHsFAjkKajWlWJGdKS2
T93W6WO9EDkXmIHFJtzWXexR9CZKkR0ohsEnZ/YHlnjzw1hsnXwctolMPIJkw8OQhUQSVxvbvg5ej0g8NckSBrFUDMr3XldvXnaLaW+MHHrDZ0syPhyqmRpxorGt47LteGEKCq75cdLetjsyZiZ0mZgdtxGa3kF9p5KvM4MlzvWK4vsKsrxt0lNgihHJAVhfSOx4ihHKi6MFuz+pMW7BIxO1VrrO1RvoTImrxprZhVpBjHoMzBUchBZNUL1BDRELmQxAu947efLDz9t3ds72eWYKTFAZe6mdZzS4iolbXKcqLcuDOL/7yxiYMvaLXW9SUKcUKK04HM1Eunfa/vKGQY0BjR0vVH/38gMLu7u+O3f/u3+Ut/6S/xYz/2YwzDwD/7Z/+Mn/7pnwbgN37jN/j617/Ol770pX/nc2ttL7WdtpuiGSNTqjFIoJLRmimW3dQcJzgbLhgXZOsQpO4xTZiOlNJRp0xlJsUNErxGbY0kPgwN/atQ6oKFjDA0xMFFZ+F8Q5KmqN8bASAOgThEknRh0skDwyZyK8GlO2IwJJysSFQjORdyLiyhEIjkApZPk6E05KcLa5oKMYZVSLV3gC9LWUu8DtnGltoHNi0nFQ1o8G7OVbhW4eVOqF98m4v/+psApPf8e0SNWSCUxsE5G4eOhnUaUzj7+Wm++srSX5I1/7KHi+5yVKb5AFogjh6Mr0SIDjV3LqE0QV//Nn8WgdpLmAJYRbRgmglFsVIdfm4XUs0oerLyqI3ZGKP7Yg6NU5DSiIyJMIyElFAKar4QdMuPSqGYsiyVbLoukj34ejWpMhx5HHDC92C1ZaMetA4xEqJSipIXw1p9sdopMOtjtz6L9j2voiP9GZh5l6iTAwLbp5HrH3J4Pj6fWd6HfNceivjvdO5Zf15G2yh6VtsuIAVoFCx2j+H6ddheQZhArmB8E6bXI3bhH5LNFrYbdIzU5IH/jAc5N/nIvgUIN7olf/GS24+/STjOPPuWEc2fwb7d+SIu7TBy2gR9c/MOzXNv0doDHHwc7oLxgxr4n/yx7/EP/dhPwaf/MPl//79jHowJZcS984YzCPhQYTkYZOEgxnYQQgk8T5UfaOP1AfANges4sl8WHhGYg5KUlUs1IwwN7fBp2VGMs838lSPgm/DOHPne9GYH9eeQ1Mtapc2JaD42cf19W/0OL7lg87/5X7DXDRffeJd68y0AxrsP4Gv/hlIDA0oZQQ/CNg48DsK7rQs82tCEY9t8WC3Kzuy+1Dl1Yu7k0bsOu/By37TWZh47lUeRE7Ls/+lJhbUxmJbAsCi7xQjLkWPjQ2VLJG0CwzJgAY4xc4yRpcFXdYAwGqGKc1Q3W2zO6OHA8JU948bFQOpklF3idXnKdDDqHgbdYvJiRfk2BttqbbyDJ8ff4b3vjTCpoWh1TZ1ORy+/r8087Z12NMhWjumiPvdVrAVmztE7P2Nowbzi31WpSHAT79pW6yoCu5Hy6IJyscEePfKXORbM4ipE7ZZzSpBKZUGJmC2YLlgDIwaZCCGCBYYQqGJoUqSCoejaULU0+D1wEKUGZRuSU1LWieqjlVgw2/Ll1yN/eHcgvlzIKTA1xxgNRzQM3oAQwro4SYwruijBzehyzsTWrS8izlcv2igxbW2IiVwzYkIKyYGf2my/gDRdEBoNyuX/f2+i2R94YPY3/sbf4C/8hb/A5z//ed555x3+1t/6W8QY+Zmf+RkePXrEX/7Lf5mf/dmf5enTp1xfX/NX/+pf5Utf+tK/c0cm0Exto5P4ur6Suqd91t5R5R6PqNeHAcxik40IVDmCDZQaGmF7oVsklSJUSQzJGkxZ0OLK8aHB2yRBa3QFeAFwUuK6ZvS1udWHWnJITM5FSsOwImZjGMEqpt4K3/IjkECgrqWyWis5C8tSGKNz3kqp1FJXL0YlOEm91Fa+DNSq7vMZI0sjANcSKOobagoJwojZgDGioaupRypCkEgw8S5DMW7VSF94k8NrrRX+w4wU18oheD
a0tKXk98N5XINY7MHSY8BqjRQSoh3BAtQ5GpISpSy+sEhCeqNHO4d3nLYIIfqiIY2c69/hC5JZRay4wXWpWD15lHavyM4T85JLcAKzyKptF0JgGCbCNDh8X4xcC3PJ6+8WK2QTjtmzKA9+4FWCaIt3SGcIIv3vLZpSUYoVYoW7e+Pm3lvo14BjPZfT21ef0rNN4Fwd3D+bUMu+yAQBIulq4OILjlxffSNgi5GPAcuNpBvMjZNfeZ7Ssl9rCGoKMI2Jqwt/jldXwsUVDBeGjRCvYXw8MD2KlK3PrbDZYZtL6hixoVKbz2Gpwke3A7dNwOL9vWLX19z+B2/A7cfITWV64U02SxuJoq10JB0J8oGSFcPxo7YNzqeGP58fUeUn/8Pvh8//Ub+3KWK/+A+5efe3eDwmJCuDmdtirV3UIIsHZweMS0CK8VIrVmBu9a0/mpXXKySZvVQYCkUDH2OUVraOYswFYmkdnd8xqD4dPTA+BrisruY/tc8n9QDVsNXyq+CJwXCGmCG+8QdVLv9HP8XyZ/4C07vfJGy3pPpjPqa3L9l89l/Bv/y/wOEZJYgjO2KM5ZQUQyNdCxi1oR2emKz6dnrqlC16Krl1ZGcNIs7QnfV+X0WcWlFvpbFlRWaIiyFLXr0yq0YKhkoB21AD5FDIROb2UhxCIUYjhNIqAQMyjUQ5oh8WyuULAKbLhevdxJQFilAPnqwbgU3r1LtefJ0XE45EMt7BPsjDzmDh1Awg0mkC9m2gS+fh9te6tudY7dQgcd6BWcy5heV3TMx8falWiY1rpq3bUoYE2xG53MB2g+0uCJsJjQe8Ma4hjALZMsEUkwNqrmagVgjtCYnAYJu2t5WmC3YCNujd9Q08yNVRN7dBSmgcILbrChsSkSVWjIlvjYXDdSF/80i17YnnXGeCuB1eGga3Z4uN4L9ygN2lpTc79X08hEA5MROpmglJvHPVBMMrVRLDmtBZdCTPYsfpO9b5Ox9/4IHZN7/5TX7mZ36GZ8+e8cYbb/An/sSf4Jd+6Zd44w1vZ/87f+fvEELgp3/6px8IzH5yfHJ8cnxyfHJ8cnxyfHL8d/34Aw/MfvEXf/F3/ffNZsMv/MIv8Au/8Av/H39XjEoaHMkq0vS7olJmhZKpTaMmJm9vLS07UqugCS3J5R1qpspIiKCW1mQkDQ6Eo9lJ4yrUfESpRPWYOUpkCkoiEYNwv9+jNMd5cZkM8LxtyU6I3wZhEwxNM5v4hOvhkV9XKJhusDqTc3XEa4hoXVqW1ro3YySEiJrfUyKAjUQJaBOhxZLniSGsyu9DigSJlGzkBuFXc47ZFEbGtEUYMUkkEUJD1YIFpgoaFZVAtASjMOSFcv02L37kswC8/v6/4fodz4WGRVka+L1WujgjKdvDyafyUCKC9rlALzN61py1rOWAYxFC3iPB0T0RQcfG+Wvl6FrmU6ehBNTVhdFxi8VxNdMN6p2ttBZxqTNRXO9qJfU3/Em1ME0TCSMmaR2YgaGhb8OwxdJARkhaKMuBfNyTizK36xISpWYWOXIwXTN804dSBazfbavvXcaR114qSwh5UZ4dE7dVeYEy4+WocDb2Ik3w9axs5fd+eg6ZntM5tTg0PR9M2X7+gvJ607Z7Q9E9qFby+7iPaUM3+nPt3XzKqYMu4o3M11dO9gcYLg2ugAmGAWwLdaOoXBCaevEyDcyjUqMgw8R9jDxfCvel8KImbptR4XhQdJnZPbnm+Y/cI7d7Lv4buLg7gQyd8BzNuIKV5J8xNsHFPMHLxgeBHV4afkuUn/pj3wOf/gLtMbL8l/9Hbn7zy4waWFAOYmwiyAjHVsrQLBzMuBMI0c//XL20bmIcO0QaA5ui3Jhn2UEri+jJ9YLTuyTi5b0HzgGc8cLMEcCKv2tDI/3vqp116ToSS4hM6ijAzrz16RpWpCgrSITv+R//Z8T/1f8c++g5dg9LGRg+/FX/0J2Qdz9M+VNvMf2bf8L49W+SoxLrxBAqu+TrzYyTw6
s5d1FxlNp5RX6TR3Ee6CrVQmv6kVbWs9P19zKc0P7dfI3tXZl7ijcNtQ4JRdnlSMjGpiws921gNzNJd1AjIR4ZwwUvtxMf3i9cNYH062kkHRRLG6RWdkVguMAscNzdUZ57rWz3rYnhqaLcUo7KUAJajBqUTZN2qDJz199hzURcvLhiD5pxksBWXMg3dPK9PmyMWm2BWylU8FKps2Ztnat7pCFrxiL+fYKjox2tjA2N612zC5W47EkyUdpGIkNAH+9ge0kdL2BzgQ4jVo9Q8/ocC9LkJu4IYQQqEiowk1fBveB8M4MwVijeXDUEJcd8QsNicIeZVBhLdFSwqjf7zU3AdRBynBhkINjAt6bK808/4VO/fSAeDtA6g7NNDFldlaEqmLhA+zhgrS56LBnTwJDAhsQSnbLCfIeglI4KkqhFSXHD3Dhs3UlnvcWSGceRueYT9PZ7HN/VJua0DjvVE2G/Fi81Va2E2nle1so9rQRl5twG7fwjh1drUUQGkGH9CjNBg1K0AIGQ/PuyegCUwkAYA7qYm103q6WuGXZ+tCpaY3oK22lkGgaGocsG1MaXixxLn5QLQSrIiRe2S6GVb73zsjR9FdGylspO5NHQygLWtMH85ewQbS22Gq6HkBrHzKUfzjsW1+DkvMQB7LcQv9fttJ4/+m22H2TsqA80x867fh4QzRuMbjwkLj8ooeCLbSeSn5fmFON8kKMENAQ4KzeG4Mx0SU7kVIwUJpTQbKta43R1lesgRpCKqaFaWhOGf8dSSzO8D2spPOAlsWGYToKCQ2gabG7hlHMm50zRytzKsFYz92XhoOWBrEifJ6dSxamBpHNoojkn6FBO914MXmY3rl762Hd+VG+M+f2tCWs3n5dLnRtppgzXA+GybSyvQ7qFuvfuzHpjhL2X7brioOHBzSQeqAwTDBcwXgGXhXLZntsOhkuQHegA5QrCbiBdKHPryiyjkIdEleTUA3N+yDAEHsWyvhs3FGwQppuF8vbrvPulO97gY+Zfgavm4DXg/JojNPNzbwE4tg62oU3Ae3PleH8cyp98+xF8+nuwRwn7l24xd/fbX+UuwBM17rA18N1UVqu1G/wdGA2sCh9iEI2hdR408XlaKkUVSGor7+u86tyfDdY4gv05y2keQA/MToG5tKAlhYCtCarPFdHqzwmoRLYtwDy2N/Fp3PL6H/nj6N/4a4TdRLavIf/2HdLLis7NhuxwzzJ/na3Be+Mf5+JNI734JooxbOC6jf2HUtaN2xOuU+JjelpZ1nfCTuvIKrdyNlfXMv05P8tOL1D/URdzrgXIRmpis6V4wLjozrmeWqg6eFArlUM07qK/s4cQ2MDqXLAkIZkQNDAMA4eDb8x37z5js1z4Gjp7h2E4FCyXdS2JuNByPLuZ3oCyJqPtcxOtacNjiAdk/QeHnSV09Hl7fn6jtGB4PhumB7SRnhy33/Xkyrtju9iDBYEhwpTQwffLk43hSDSn5nhHsGIW0QYEqNpZ4xUgMyKpUX8FNKwam3Ca16AQvF9Uoq46kkvek9skEIVhFGqM5HDkWCd+c5v4vliI9/fIcO2fazzvWguWsweCw+Cl7M7bHZwyk+dM1EqUSNXqoIgNlNxkYmRwWkzr6FRV56TJqXHBsjntJU4eI/w+ju/qwMz1ToLXdluvd63KUgpVC6KVgBLUUMnUrl4igkklKE2jxfVszCKYrigXklpHX0DNXxkJgpqSW73dmtdkCUpBqcnQAXI2NJ5QiU7qDuA0tGRM08BuTHTHKlGQppMVFvW2bRZSrKR4gt9ilFVkz8ydDmqt3uXW6+gtMBQzhOblJ0K3flw7PKvrWrn6YONJxUhq/Cm/SYFVWiKsYyjivpPL97TA7FOPefq19xkayTJyWmRe9csMbWOpfdXljGPSPrN2yHBawIyHgV5ZMkOpaHJipWohaEB6u7U1JX+EHFqgJu7eIMS1M9D3CANzbpm0/0f1xAtrBM6UEmFIDHFEnNLqWVK380g+ds
GMmjOHw4G7455FK0uDPpaysNe8kpr7Ncgrq+naNWgnnonhyNRe+xj4ON7J/4u9P/uRLcvS/LDf2sM5ZuZ+54jIoTKzMrOqq2siq5vsbrKbarUoNZtqoiFCgEjpRQ8SJECvgl4IPegPEaEXPggQBGh8oABOgihRrR7YqqquYlZnVWVlVmZGRkbcuIO7m9k5e++19LD2PmY3Mmt4oYAA8wCO69fd/Nixc/aw1re+9X2OeI3Netz/dxZefvK4ov9s91zwxcE3f+f3Tc8n9xUC0tNEer+yVCg7o96CPICd2XTm9hNMEfYR4k6YHmV2jwO7x8J0EHLwDrWQYJ1BJ09qlwnmXULmA2+6kFkhUGqXIYkZFUHCAkHZzxPt7V3/LM6FuUsVCZlnX7jhk7/xwHNpxN/u4r4naEEwU8QqgcQOqFSKwdx3t1hBRXiJ8ZcNvvqlxyxPniF/9EPqt/8QgDtrzBFOwUjqQelBAnelcer3QXsX+ANO4j/3IOmE88Ru+r2/TniKKU2E4T7A1Wsu6IZv7tf8onGOgbpKcPu3ZsYM7l4xnjs+9vb4tT0FojR2wGo3ZPPN59Ff/GXC/+J/SXv0Acsf/lPmP/gu65uPKT/6AbsOL95NhV1pvPn+j/mONp7nX+aDZ9/nJhQ4Hyjdy3Q/dckE86uV/hlFun8pnSd8FWiNsTh4ohfpkks38wgwBF/zrhsDtuAMD9CkGHmFqYB03cezNnZUzAqNiWyKBONuVu76+nwXCrchMKkHvi0IJQizCMQJeqd+fX2klkoUIcVArUo7VaQuLjSKr3+TeNBVuMz9a7RsPOOD9eQGugeqf7/xQY1NC208/5HEjQTZx5TfhUoXlLWfXB+ES1DW6J2iGEK3J+x3UnKCKdFSoFnCRmejRkZYYV3bqxnkvm9W8+BMBm4rhaALIo2oM6IuFh8Q99rsSgpRBZrQWsGCYSiWlKZ1C8y0nlkRTiERY0GXPd+6mfkrH+x48uOXqHnj0o3C0iohxS05SGlCrWFdz6q0SkqRUJoLmQcBCUhO1FYZLHLfMxQTIcRMqysEfH8Ze3GtXrkK0cVm/xzH5zowa1pRDVizLTttrW6dmdpqtxMpII1rSyST6LIJqpiEKx2dRqnDsNRHbmtGjG7Y4qIHumm6VCuU5exaL1YJWgnm2mrD4ggcoQqxx1bBEYTdnJmmSAo+smKDYErpkX+eQHVlSkaeMrl72URppDhtZtm1KGrVmxf6vbEegHm6LajF3nHTEbPeX12Lt21rC11rLSAyJtZlytrVQhqCI3YpJex4RG8c+rj75hc4/d5L8hsntIdeShnllHH0ubAhAdv3f0Iy0frmE3i3zGeAFm9wGNeo1VHQ1IPK2hpBhLVVYpqgy3hE6ejYFv1ptwYxbK0Et0zgWki2NvMuoDSCYiOJEGWgixeEMahRlpWHN2+5v7/ntC5UGmtHzM401qv7w0BT7UrQk8tmst07MyT4Aj0QDc+EvWRxrbsVx72+uo8DafmpERo+XpuNBd3b6pt524A+hmXyP8yPJsIXIsRKvGnoE6hnD2a61qtbsgRx+7F9Jjy7IT2dSLeJNIUts4Yz53r2ZhFt3APBJl7XHQ3vylR1+6ycXeImBiMj1KAkW5G+gK8R4lrINM5rZSeZJ196zN2/3ijzp/4ZfxvSnTtd1AaZSurP4SwMVxZuydRYeGPwK4+CS3ecj+hvf49PBuqTvCS0U1fsT/2e3enlFk8WuBcvUTaBHUJRc0K+wb5v6K+DI55bqRnrXaHvUgHGhgpenhybq7LFBv6z5oifiUsC3XTxiIF8pCiUapv6fiUwWSehpwfS458DYPc/+Z9jzzL5e7+P/db/C77zI6Ku8EffIexf+Hiw5ptJDNSX3+cH80yJX+ZR+pS2nDj39XmKeQvctyOG3tzj1yW9dGxoD0D806cuYTBCVekd16NUt9lkcSWfxKVrMRqOiDQjFu8apwdmS/PxYNsaDyqVhxx4nXw/eBKVRzGSRJiLup
uCGS0GSkzI7OO5Phx58/INk1amLmoqxUGCsVb59fuzHpvwCLbauHD8dzsu5dvxzK+RtsClhBnl3bkPl8RPuZTxRzl7SGtcxbHv/DtQPN3SMw+ckwiEgIlQtdCq0iqXzkP8prtbTkeOehBeG1eqBcWpDmQS0SsQBCwEUhCmHsi45aKv0SkYqzUk9O76MfdLRdcTk0yQEicpfC8e+KOvPOGvfvv7lNXvxENQpCwc8o1XSYKv5xEuPswWqK15sLy6gHfc72nAUurm/BGiS5AI3iygtfkeAVsThK5nEJcy+vNomMHnPDAbfWfeQec3fS1L9yx01IzgIrQmut10rSA2OvgcBSN42CUqVy2zARIEc4G5GAKiLoaYtuzOKNZYxXlbqxU3f4094BhyDGbEiPtsZZhmYZcTU47uXgDkEDxTUCVPsD8EymKujyZhCBwTE4TgIrmm3rmpqt4YPvgqW8AR2Bw7zZXvbfgdQTdrHUr2F0QMZVPYfufn/ZwbdBsirZPD6l/4Mq9+67vcfv+Nd68OjstnH5u9+237zK8HyrahN/3fzfeuvy4Jjm7RZUUQR1LaJZgyuihs544h4l3VophWz3AAekAXFZf8qBXTiphtorDgQVeMjrhpL4nnGIkIob9OSqGWynL/wJtXr3hYTlRpVIzS0beTtk0eY3RFNWELED57bEhavy8rkEZHWcf9Rhf2eFq+oX82IvZvP4tgXh/j/Q2XCDHzrvKHQ+HY7+ty4yWMaTLSk8J8brSTkNbE1KUris4EMSQL6fGO3YsbdjeZaSdMUSjmAb2cX2OnT1hMeFMjH6KUMzxOmSfVx13OiTRHQgJrK9YiSSIpzLR2ZE5jAzzRaIR6Zpcn7kV5IXvmrwbav9bRyvffUP8z4EfZg7FQKD2QmgyOfdFdo7Kvvsl95cUey4n0re/wg4c3tP5+tZa+oTsC9gzjrLaJvgIcRTl2zmkQ9z1Vcd7QwYT7vnW+pmuN2aWDzsZzu3om4xjPOFzNlaF8E3C0YwLO1vqj97myzeRmWzlth2++BdiTaLVy+Nov+T390hewv/+fsJ4L9o//Y+aHwnJaSC9P3D31s61r43zI3C93tLt7jg8/4m7/AZ+EL0L+kCn0rjoDrJeuhO7B6uXykVYGrkhU4ghKCoEoHWHr642ZIdrQngRfuKzOx4TLPujJn2sGUiEU50+GOtxgujVTRyilGRLcL/Nl58c9ovGUzNTR1qmbqhcBC9HL7IBKpmjA1sp6XsgxsCPSRDee8/CrdQkMNj0zeiA+cuLUP1u4+iyfrVGOtWEEZUOgNthFmJg+toJd6CDR2Mrg10nxddw8vm/o9pZiShw6XqpUXWk10KoHXuOVai5JJah7Rwen1VwH5mqtf37F4uw81BAImggxMI/RGj0oCzl5Sbo7Qtg7DObOeW0LhD1qhTf1xG8/v+FXn0zkH3gpp+4eEyx4Za0aN/tMwGWexsWlGDiLkFMkloqWiuSGpOBlzk4jsbF3qhKjS3CYetlXbABBLuirje6y82cfn+vALCDEEDAZW5xH7KUUzIrjCNF1wK43evV4DguOmESJSBPnW5luSvwmDS1KzF4KlBA9WLuKekWMmAxK9YCwLcRmZOuoVR/Aw9CV4GjmYUpMc3bEZejxBZCmpNjY5UrdKWqNKEpMSuxEyJSNEL1sJ+rND55f69buO2aulwq2q73cvO1mdHueHpAogRy8zDfg2kDs0GzEJF4JxBolR2Qsah884f5rL6j/xR3yqW5Z/Geh+cFt0KuffZY3MW7xgOKH5tb1omGGi8zWwtAVCyEhQWmdiOlWRUrIqSMHEcWIaohdoW3Nr8BUnW+mwxrkomM2yi2DUxDpStGh68OVwWEUtBTW49G9LdvCYo1mxtk6p6V/5iGma3JBPa55IvTFc6BcQ6Oo2IXLVcV+Qu38OpQei3y7+n7c43B90/s9FUZG7XISwYAMn8rCg3aHjcPEPk6kLLSnAS0BeXBOzfADjYvPO90l5MmB9G
RH2GeIjjIs3ZYlnoR2C2/1Ld9/WPnOnRDvjW/ulNuBJosQMp5oWCQy0cwoxXgTqqt0A9Iaao1CJaowx8xJhBdq2FefA/DwaMfD7iXnf1JJ33YNwAXX3HvhDxqAs3hj/BelsbvJmAof/+H3eKPwokNTb4G9eSlwBNWxj9PSH8C9Dg0uYTHjhDEZPDIv6LwZeZPBMwssIpylOe2AzwRnHQ2ZGBwll0Oo/dnK1QM1hNWUKIEY3UQ6NTbdJ22wT4JVT6Bqcq/OLJUaH5P7s9ZP/wvWlx8hH31IXe7JxWkfZYJw9xqAVyEQ7gKl3btTykPigZf88fqUQzuA+PMJtV4SVnGOn6JboueHUxEuHFfXV9y4Oxt/LCDFCF17Tq/G/GgkSvSAzbz0qeJIWWo9MOsUGGnNqRXS0TtrYJ7EfTr5qvMsNO6TkKNA8PdMzaNs1YaNbCdNBN2jtVDX6pQIa8QcNl/UYTy+URX6Ix4Uj3Fcaw1ePdifOEZQNhwEtuHSE77rP9X+PsIFebvwdum+vGxrtF19gSOVIuKi303RVmhtctvCrpkJ0Krg7iDekCGM5oMNRANcAFxEqMGt2SRGMK9EbA4ISUgW+t4du5RFHzeD/xw9qH/NmX1VplBgPfGt25k/+NITfuHDTwA4LAttzv0ZNFJxP1aJYVu7WnWrps01aKkspyPT7S3zYU+9c+pEaw2C4PaJF0mrgFzoThpoimuTbg7Bf/rxuQ7MfEJcT0e6+Kn0aN0Dpy0IGIPQemYxAQS0iX+pd1lKGtpWRmuFFszhypBcBTjgsBue8ZlVWlBKdM2qFkFT19kaR+hlIoFdgH12+4iIEYeHJz4QYqzs50qhsbpYFyF4nR5AQgHpisldHRvremE6yP/WSb7hyhOxv7axBRs+IoFexvMFMiEhkbq6+QhEhLiVMcdnEck07fyeeeLtNz/g/uuvOHz6agsQrtFbE0fSjHeDrmvj3neCr6tv9Kf8fl1XDsUJxcJQb7aNYCxyMRQ3SUic+/UEQrcm8fMX11SqhVrOHmy2+o5umQnb588huuijBOiB/yCtWim0UljaSqVy1srR1kvZkm7/07vHssCfse6+cw9bv3dD0Tte3Ztxz2M/57Wa9TYOPvMeny0hjwBRGNm0Z94v65GVJwDsH92SDnvqmllVKCVS93SSm19ZKg0TIc6Z+dEj9o/27OZMSMG1fRYvLbZ54mSZT8+Bu/bA/d0Cx5XTWWk3PfumoeraQG6n5FqFp+XIm1I5Ly7ueYoLscIkO85S2UklhcLrHHmiXhZ99N6B6W8Kb778mvL3F+rvwPzK7+knAu+NJM/g+wn+5hThycz5xyc+US8/rrJuD+sBR5w24egeNJ8HxwwQAov1rlXxwKrhKv5j+XpcfUF+QNkq41fPDnpQ1tG2WZxEfRRvjjFxBB0cPTWMkOBQ1S1tEr2Zpm+cOOJT+/XlqhgRjQ30LfIrXsq0779kuv8RYXnFHHaclwX2T2CpWL/vEg98eveG40Ph2FYKR9blPezmRNpnpLsglEVZ4iXQcCsbD7jCsNlBiKJbWdFwW5wYnWg9tMeq+esCBcU2q6lr14LRCQwXFD42iCuEEi7NUuuJ1hqLJFZbqBbJTKSorL1G+ioXPokekISmiELT7ulJvXByJVMJlOKKmEV9z5nLpYHrmkt2vZ79BJDdg7IRwI3A6qei6niiFft/1OydNQfenfvOWR1B69VrOnI3AryNx9cvLkgfS+L8W62GVW9uMk3UrljQmj9BQdx6q0dZqpfAzKsFXjCNtWBZfK53b2GJA0VNJHE+bcCIIXsHN5G6jJXQUDGsGvfcsZsO3ITAq0X47S8/5wu/7+sNpxPoDUQlhci5FRrCLHlbDLVW54ibc4hbK/1DO5G/XZM5B++aQM4TtbrzT+g/L2pOu6og8b8CHDNRQYt7Yg4+lBBJcXKosqsGqxqtXi0GHSmwXgNvTV2clUycJraekR4wWSiuEjxQtxC32W6tsbQFLBJMnU
8TIVlyQdpRXuob5RD+S1GQCFJhE0MV0FCIQZliZbbKQ69PSagXXltdyDF1xWT3sDRV51VsO7ijaSaudzD4ZjICjU6kcVP2Ecz0GSmu8j/EakMvsY6gxAMzXyK0BBdlBCxk7r/6jE+/8ZTwrdfMR2clVLlksNI3g7HpDKRoLABwhfRcfT+ChdEtOF5nWh3ODuJNGEtx1KtnLrV2/tBakH3u1iKpIy+6QdfWHJ0MWtB6RkrrnbWXwMyfo2xdmemKtCxXbWHVGktZOZaF++XE0dZtcRuf0bvh+me8ypKtE8WvBR/NPIgbgRVdhHKssNeOAdIRtms/zLHYj5LHTwv8ro/BcdtMFMw3vZf1xNxL4GmKSDoQSmZfA3ExTsGfyb4vzEuNWAqknNkdDsy3O1KO1GCUJNTogVLRxrEI9VQ5kHgW3vDjcuTD+4VnvaGCnEnJW9O1rCiNY125X+8Ib3STy5HlSOMAaeYwG5UzpMh+uoVzl9QhcPv4PZ782oGH55/y/a/eUX5TOXw3sD40jp3/ctDAq9h4/hS4veHuW29ZNDOnsulerkCUxN4axbrZdF8nNoIwLgRbxIfbjkAMyhGYFJ71AR0IvEY5GTzRyMvQrtCFsb55mW42uEFcZLePn8yFi7SPE++9/5wXX3iPN6eFN28/5knOvPnBx5dkRITS/Hw3prQEoRqLCk+efUD60mMfXz/8PvbjDzmfT+RXr9k9KOd8A0ui9IEbrZLKkaW9pganZZT2CcfzzNNwy6P4FIDT/jXrubm3MMFLq+ZyRoPbKqJEajdbv4xWn3ujpYXN+9ddkFsXiIWpBwLw7lgfYKIoSAFZjdLnfy0Lay0skinaut1fYg51Q/Luo/JJWtmbcFiVvNIb7BNFdCufkhJJ9izHRKOwdomSpD8ZjF0nSts8/UwSOjif13/3047BZxbr6wi8g6QbbHxCsRF4jTS5n78nF9drhXFJBgGnBfV1TzFqMbQ1Ly9apPX5X9VwU/JImHpVpyeKI0ANw2c4gFGcipKyd69adQtBgOCoaSJQFyB6IizJttKBAw8NCzfU+gCnHW0qrKfK775/w7/4xM81HU9eESmVm737e+ZgXmbuNzHnzKKro2GiWBDm7KK0asbUfTDP5UhTR1elN3056nbZg4zVUd5mbOWxP+P4XAdmKg6dVr0Q7eYcEXWroX1InIceU7ggA1V84LZqKK3PXldvsiBOEgekul3SnCZSzg510jOBwe9pFdVKa4WzNXL19ve3qdLqpbW6ilCl+9EFh7LFViTsKa1LENiZmMQDTXEO2ZyFUlZ/nnEQSE6oBWoVTF2FX604kbcHU6VCVCUHV/WvEqnaLTaSsd77uQJCa9URK3NoP5mQYrpEU0EgZS/jmZNwLRho5KAFPfv1v7GCxBvuvnbLo1/aM/1/j4CQrCMD+HsYrrMzMr/H5kDLO9V3uQAGnRayoUTDNbWKsKibzIcgtOkWlnVrh/f36/dnSkhypX+NgukZqm7t66IC2rBSSUWgujJ/k7BlbbaenLMYZwiRGBLuLlqJBrV7ai6lcFyOHI9Hjnb2LiguQRXA3LksER8PGzLSoap3yo92QQECEIerxVV6vS3cV5muI15ejgTf4MfP4bJwW7gEgt796cVlMaEmRxjehgB5R1F/1st6QGMmWvZSd1TmnVukWTcojjpjKqR5wqaJVZwKIGLYuRB2Ps/isnIoEZ0m1kmI0ZBz4Y4jL7tMfXu0I1hjKp4tL0SO64oujTtbOR+7dUsRTM6E5BvCLh6IROLxDN2ZYerm9JZmnvzCN9h/8MCrb37I6bfuuP194fUf+ns+JOXnT5B2YOcHPv34nkyhqXeh0Z9fpvKAB0qlJx4L3qkJHoxp73CZgMnMJWC42EIB3KHb+E7m6MM+V54V4UG8o+zZzQPv7284HJ6R5glyRFXJ056UAuvs23B+8R6y/wB7esP+5uf4Ut7D6SXP//H/kx/8o38IwP1i7GSm2MKKkKp3u80KMU2UH/
2Bn+vD76CvP+VQI1ZOmBrx9Y/4uAaY5v6slZYD9ychWsPqnhROVFm5PylffOyfMtY9ry14WUcbOQQWbZjoNk4rKxCc3G3O9YriN1JFN0mDPRGCsIaGmpdzMSh2sSFL9NJnR6etT5aiEFdjt/j9er12SogZ1YJrsIWKYkyzv99pZ7wKJ94LkTUqscBco6Mg2ahbFhlQy8jhCeUY0HZPa/BGLuidwGYaP47B9br+8bCXav1zDBrDNRI2GqO2IK8vnBsqtyXFlwDw4r/p3pljcxzNCCOgu0bwRlJ5OyX2k3BuZyozmo60msAS0VYGGmFBoPkIL9ZozTaqxEDMGqOz2DjHQpJAbKvrJ46NG09aAp6xJgmIZCRErJ03H+1zPSMq7NqRR+0R0d6i88whvsd368o//gsfAPAvfPQ9nh3vuH32HkUCtRmPF7DHM6F3za7NOd26VioRpVGrks+F/e3M63JpEIwKhtC0uhxVaIQBTwPJIjH6XnIqJ/48x5+Pifaz42fHz46fHT87fnb87PjZ8bPjv/Tjc42YDYkCuer3HaWmlBJrde0VNXOOxVUqEsBrxeqK/taq/21Vpk72nHaJadpxezO7NASB1RprXVh6GnLUToqnkk2waFhwRChmtiwqYM5h67W4tjZKbKToGSE4KTUYNBGaNcwqWYwpB6IYtZfKkhSqra7Ab9k7RZtiValbs4GCeNdKw7yTxUZxXzcOVjPdyJNsZUrp97GXKUJ+V1Q1OGIWJXFOSpt6CWI5UafI3Td/jjcf3pN//zvIvV302xgt6x3KNi/7HO3CiYJeERxV1f7Mxo/g3YyyqaNe0iHvphW00rhwzHRol+ClW2tsrWzh6vzqOiK0WnvTgJfBh1Cgl33DNs6kNRDDtLFo2crWixZO9cxDOblOmXTEz9jKSHYFW9ll+L6jKcaf8LM/qRQ5iL4j6w42SvZXfzfu8/V7fuaESuuYWSMWuA9Qb4VyIxyGu0ENJElEyUj3tpPoNALt9yh30uvGIfJ2HZwz6cbi4GKOVhMheXmrmaPSD/cnPt337mGLaIPdDFM1dBKOdeVhPXM6LZS1bc+ILvQIkRBc949irmkBkM2bhpoRKjx9+pT3/tITXn3jx3z67Ze8+C/8Zf/TX/u3+eF/+pvk7/9jyvfuqZx5neB9bdszOHREI+HOAQY8iI/zfX+7u37v5+DyFcMXdcLRsZFDv+VSavokwG2o/Mv/o/8Z9S/MHP4vvwdA+8bXCaug7+0I54Va7pkPE+XuLeS9fwEtGu3+gWl6ipwa9WkkfO8N06PnHH7Buy3rt77LWRdu8BLpjOtbBYN4yKT7NwCcP/0YihDXhbasxBpYbEHDjrtOgp52s9M90G5e3/pYFBqV+y4em1LqvM+wlSZDcwaVjblFpDI6xcedvqxLQwrHzHo5rbsGyLtkdfrneqfLcKBW6uT/UZKuTTlrZa/uj6u1wWTo1rUFMiXWXLiPlftoBG1QItoCIV8EuX1CB0gzpBOtZLSVn+CCGhcawvgaVI1rWscoiG2NAuNzDpT76rWj7DjcLK6lMbgaX9dTXq+KmRv5394tBrA+6AABAABJREFUr4575/chYsn3thqV0hQL7pJiCGFb3BJYpFmANjpv2cqZ0FFC8XVRS8GqgGRUK4l8QRhD7FpiXkkKYYIQqWIEHfqkhnUKT+GECyQqWu7IMvP77/nc+IUXicP3X3K3PGIXAynM3IVAWO95nJ/2m3nkWBRCJCWjhkKrhaiJ1hqpi4+GeaasjqapmgsxhtidLMbmEvr9dLeLP8/xOQ/MelnEbCO9DgNcpdeMR/lG2HZDt1nyHax1/S7TQIqXoA7gcLPjsL/l+e1TYprRBquurPXIfZ+sIXjzQLG3JG2sASTCbYFzhLv+HHLnuCUBaVBK4xgLkstmMWTaWFVYQqMGxZJ3cIp6Oaps5a6FpoGkEHXqg8Cn1+bukrws4XVzdz4wXMPDtPVpDqqVmOLGkwop9iYAI8YL+T9J2H
SEwMtdVZQseSux7gVKazy8eJ/jX6rod16T/j+vuijkZSlwbaUOqeulTDmg+cGVGhyH0U5+zcuALjApehVxS+f16FZqDnRdmdFqz7vn3mQ1mstniBpx8A4/O4muZEICEMWbDJb1TK0X9fljXblfHjhb2QyYx2ccx7ZAj7WcKw4ZvNOZNV6gXBolfiJY6384xsjMCMDeFSjd3usz3JKf1i4v/ZprguWZsT7NPM0+Wvc6keKESAYTWhRvx4/zNkZaqT0P2LIm54GYC4iO7mdCo8ZAk8DaGq26DElbC5+89m2llIVlNfY7YdqtWFIWKyxaOC6Ftbe71U4TSdY5J6Igym6X0NA32BaJEiE2mp5YKpD33H7lm+wff4X/4Zf+Tf+MX/7nif/gd7gVWD56oOEdmFltC6bMYAkuKtuauqiv+Wa4UZIDWzt+APC35gYPysa5bvCSVcD41QDv/1v/Fvlf+3Xy8gj+3q/55/vP/+9M90r7+j9P/MIt+vAajkfSNMPXv840NAj/+A8gZuqbhfjV94mnO/S73yUsR57fPAXgxdeFH//o+5TjA3u8KzUhnDEenwrHO+9iC+cHst0QJdCqul0cgbCDc/NPWVc3mM4IBUGkbfW6Kgt3zc3oP8iZaBUJgovMePBsYluCIJIRWS+1PS7zNG4jE1roi23wxq2NFC8Xqq308TxkKbacz9ym6pL4+fiz2PmzzbuyyQnpOkWWI6ekfJJXbhJMNKQaRLfoi4MELoaGjORKmA/QKmttBNHt/Tb7Od7VHgx9uowyZ7NLoDR+3rgEcsDWHDDKl9fc0xI2Gu12O9+hM1ytOeNciqDBtr6wccqtoeKQqckoUSnSqAhiQwxZNioDLaAtgEZXNFBQJw+9s34NHqU1aNII0ohMSJCeyLFxnsMUe8DuJJLYKtLbjGOXQkoGNSy0Eom1kcNCO0e+fePP8S/+wlNe/PAT3q6vYJq53Rlv2onbuqMEL1GaGiqBZg4sNLzBIeaEFSF2l5cQI6UXfQXpHcQjyewB45/VUvtTjs93YPbOdu2HaqVoo3RxUMQ2r8i4zULpN76PVDWiBKYY2U0Tu16zPuxuePzoETf7x6R4wDRQtXBud1edm4nSIisrrSoWGylHclOaGtPYkzp6FmPv3mrCuTYknJn7w4rNKNWcTB8LidpZkd6d1BubOLdGaSdiUzKVSWJ3MZCN+0IxdysI4l0vXVhWuxhvaxfHvBiFEMV7GmJAgjlypiOQvQQ2jhhtDCYSsnWf1GmGtWJNWL72Bcpf+4vE3/n7tKMNmzLvTOn3ZIggNkZn2HZJ4wFvr9tQIN4VTKzmoo2eBYbeWHBB+1whPvXJ2+nYzQVwaWA920IrofVuTvXgwOLI8PyCpJ8r4ItDrZVaV9Z1pbXWpRLhtCwclzMFtq8xN8Nn5uX4XOOe6NXPt2Pw7Qbv7ApdvL5fCltnGrwrQnl1O/uf/MkLxFj0hS7zIIHze4o+SqS+UKYUSMkJuUJyDkzn5JV+octwXxDXoFKB1RrNKiYQN6JJ41xXSqusTakqSMyYVu6OHdmtZ85rZb8LzDuwXKmh0YJSS6FukWz0DFX1ykPQKBbI4udK2mgIhB0mGUxJtvBwavyPf/w32H3yCIBj/Yiv/c3/Lvqrv8yr//W/5z6CAUoQtEfAk3S+UpBtg814sDVm2KwuiluDB4kTzkdTjFcM5B9+3oz3v/5lwr/wG+T3v0n91b+GrBNnuyU9dpcE/dJXKMtEpHC6vWX/JNG+9fuURze0dib/g//cP/M/+zb2xffcM1bPnH/3d7CXH3MqlcNTJ/WftXD79BmfNsGWM1DJ0aAGws2OeHzo46gS2oqVQrSODERjXY4sxb2WHsriZCEXo8Gs0RSaGi1WHjqiGefIlHyjQ7xVXfo6fHleniCGehmj140om16g9BWkJ0zSuWV6FXkM3a4RtI1/o7lkxjhyzkzTRM4ZUu4yFbZZsAHEHFhn4W6q3E3G4+iyJ601rOomVbLJfI
REmGek7pGm1HJm2GGN+WxAM0+SRGDuCdP45INTVsdWxYVfNi5/BGUD4RrB6NDCGwK8Ed97ftrUv153VGx7r+uXbhS6/USLRpXGGmJH3Aw1R8tG9QICpgFVIVpwkdhweSbj820NLoYHdpgT6WPuzR50pYBAIBKaUtax34Xt/YIkojRaaASNvg+0RtlFwkOjnP1G/O4XD3z1ufD4x2+Y5/eZd6sDIZo4d1WDRESDi8+v6zLqTt5xWYTYwZvSGtWGbEdH93uDQBnrWxCsNZ8L9m688icdn+vAbBwR2QihxAClIzJRCK1jA1fwaTSDoUWlviDEAFPKzHliPznkeTM/4tH8hJvb58zpBizRtHBqO0InvCuZopmqC2+rIu3sHXFJyBg3fYwW6xx6GXzGQGvGeV09QARSi5wCqFW0rEiohCSElHq3XS+prY1SGtZOzFZ8AIshGjfUxm1YFFPPAIOGHlhCre2yQATbNtkYxVHDGJgS7psJV2XMUZK6bPEtemcXAJPfv9iUunvE6dd/kfDXvwv/8YcXJGhym4sovmBcd63GrVGiBzBXq8KA+K8DM+vZI9UDTwmu5yUWtu6XZO6fKcQuE9LACmb9no5YXQ1rrmvW2trlRlxyZPP2M9s+uarysJzdl7Up2nTrAj7XwqJ1WxwV/yyflaUYx4gp4lV6+s5L7RLADoXv0Xk5/n7oOK3mY8ybBXwVHue6Ruw+Kyr5mbfz50BgCdCqsnwZ0m7aFkFLfsGBjqgSKN0eja6wfd42Vg/GGgWkYDQkCnNfyGpdOZ/PLKVQm+u9m0Z/Dl3I8aEay1I4zYFpL4SdYqn5/NCrTyeO812XnEe3c5J9vw+5d90VsEZtmaMkfm4RHn/3BdY8ILlpL2l/9a+Rv/zf4f1nf5Hdv/fv8gd/9B0e0iVJMusOF1flzZcRZoVDHy0VY7VKtsszvsdoGb7Q4MtfdVmK6V/6F+Ebv0oWR1nqiy8hu0b41rcdAQPyt3+E/fyvEpdC+O1vUe4+hnLCaBz+8UsevusenuF8Yv74I9rPv+DNf/BPCcvC63VFHz8j3fVg6tWPaClAjjyEmZtTxRpEjPLeLfnl9wEox4LFgB3PaFVKEU5r42yhewjDgzb24nrokYRJ25KODQH3J0OKeEdy9G7MbMFt5UYXuARMk7utXKFMDded3FBuvZQ77YrwryNC4dJ0Y7yLMon24Kwr8+93Ow7TTM5emq/BVxtpRpp8LdEpwxQ45srr1HgUjUTC6pmpBUJfU2OMHqoEIUomzTuSOVZUurTD0GLzpbMjqebVnCayiXZvYqy2FXn6vbhC168Ss3FcB2bbcbVObP/v17DduxEocSl5jmVpGlNsjr7uiziqpIpp6SoHVymfRKc1WALz9i75TGAi1u9BD85UGgRztwcJnjgBGiJB4qYv2VrbjNBDf02ME0GUEAqlBuboRJ7z2djbjkdd0O27Ufi9b8z8+odn2v0r7ndfIO0ec14qVhyU2e+7n3ZLrFWJEjEplLYSW7xIOwX/ebMxTqPruUVj7coHc+zNcnb1sP6M43MemLnIJyFsmyfQNY8UCUpEur/l5YYkxCFhiYg0ogg5R3ZT5DDv2E/eAbXLt+zyY/a7x+zyDSF47XuqmTDgbSawGbOFYmAnoK6sUpnnxE0f1qdW0OhBYgg9KzSPpM/dEN1qZrXClNwmhQApCZImN2QfHDMz1taozQihElQZtrGba4HlXsp1zy7VFWuJLmO2lTJjFFIO5BxJ2csKOQkxhi0DSPGiuj0O6/ysGhoy+8QIaqSgJIMoGXv/Bae/9RdJv/kJ9mMPWmZR38y6BIMjMh64Xhv3Xk/foWQ9ArNrIdWquDisNRBHxrwb6HKdbg3SS7qqRFNEKtYnOoA03cQStbYOybd3dcxMNtjczG2QnMUnnMvKuYvaLlbfQctGZ+VPkxa8jj9HyfZ66m5VwPFauZQ9xj0aWbH1fycGGumWPtv5zBfGTU6Dq5LCVVDoyatQUFYipw
Ttq9Kvrd+LID2D9bGBBKIGqtpFSqJU1ladkxGMnJQYGyFASIFxR0pdOC1LRx5t0xTUBtp3psUaiynrEplrJBUl7CBNxi5fnsvgswFdMNlc8HJNLKn7FMaCSGSyPaEFiiineOTnf/g15EPh1fO3ADwt3onMdz6C3/hXefTvPOe9f+d/xUevf7wlBy16CW4dm6PBbevlaxlzEVJwxLwZzA3e20defO2bTL/+G4Rv/AW/3qbU+zP65BHx+c+T21vkP/1/k/7hP+ShZ/I3Lz8i/PY/4mHakb/5nvOYHhb2n37M6c2HvLn3a7/JM20qlO/9Efd3K+UQOZfC1w8HPvnw+/0+NJZVwTKlB9yPzcdNWY6U+84fI2BWXDSzQmnGfSmsvXwNvWxWSi9/24a2iDjasd+ej5CDXiGcON+PCydnKMRH8e7Nd5TUVUkDoVfvKK+ml8Ciz5RrVH0kRqOAP6RkJAn7fRdM3s1MKfcSd4TQsCCkIJfnGAI2B+os3OWV15MxxYIsDSmR0Mv8iDrvLRgQkJRIeWZpjbWLUNdWN/TLaRHX5U3bAqprvtk1KnaNjo/AaisA9QDMLssgwNYJ3OTiqTsCteug9rOYjtC7j3t3esuBczDWCOumoRFQMaJdwZV4UGYUaBkTReJ1SHl5h4sYu6PdtRlF7CKArR6oteYlxmrFaUtXOmIxZGI0UjNKV01voZHXSONEHIKvNfK7P3/Dl/6osv/hp9w8fs48Z2IQ95UbZ1RXYtAYMZp73pbCnCasr0shTh6YtQVMiTmMR7XxmgdK9q6I8p9+fK4DM5/0ASxcyR6M7Mm6p55r2De5oB3BcKV/fNEIIZFzIk6ZlC8cs5QmYtoRmYgxEWMCEiHmLZLHJmiBGs4UcRh3Ob9hbWDiCBTAQZIHZqESYyTLRI4zSGTtWPOxBtZWkRwgJQ+Mkn8GE/PgAjiI0NqKQwUC4gEEdjFcKXZGe81LFA8CrVDNW/pH5p5yJKVInqKXpqI6ehZ0cF57xaEvueZmtNIx8Vwb1jPKbO5dFkNkr4JNAf7iNwl//UPKf+DkZTkOBzy2Nu8RDNTLWv0uu+tq8RhK3uCZdTPDVvXmjTh5QwfAFky1DWIywKzznkJGQttKmabVJ1tTzLSX5VyTbEysEbC4c0f1SVsKtXXdsuYB9hkXEj37pW/8susFz/rCOVS2LxplvFNztPFaOsdkjF/e1UQrXBbtjHPMBl9lu1/WkUkuC/n19Vy/Z1eHIyosX4DlS8ZhrSx7z/j3euMbz4ZMeSio2rYmlVIKp+VMaYUUjSnBlDuXsxn36uU5dKWUwroWWmkeSBsQE1aGWGUPtMwgeJNLDo6Kk9IWmLngr2zc09ZGItT6nQGzCUNZ6tllM3LG8p6vfbqD5Y5H6ohZuhP4z38PnuxoxYhf+xpP/vZf5dP/43/I0sWTdiao2rbhGVCiMDXbkuNTcC7lF8m8/8UvsP/iB+gv/wLhxZdo4UDrMis2CTE9Juzfg8Nj5Nu/R/vjf8b0o99id/NFP1deSW//mGSZST9l3d+S704c1zcc395t6PVrKs+j8P1m3DLz5m7hcYLjq4941cepqGuqhSlS7458DDzriP7u7o63HQHeh4nTcoZimAiaAnX1xPDYP2MB9y7to6eNMSquJTh1nblFR4BF9zcdiL5sTSMYBBFCdCkQxRHqoq4neWnGgaKudzaI7trH7TWKpPg8GxIR0RzEibvA0xvnvtmcujeykEMi9aTfRBhhn6WOFs2JY4ZP58KchWAVKWkLLHPOEKKvcyFAmCAZITda9Ht/atXLkX0+TlzWiCEPNO7rSMjGPA39/599Df182uf5TxOhDVw4r7mvq+1qjRhJ5HW+NpTj5h503MfKOTSOoXLOs3vESMTNvAPaXTGaBbAIJLRBiIpoI0rdFvgGSLPNLip0jpZ/1gtnWnryJyGRYmWaDFi9ea2rvcboot8LiZQVrYVIYJojD+VM6nyaWZ7wfZn57V9NxDdnfu7hY3
S3Zzc/ZaULJpeZJF3JPybWsiA96C+lkNae5E3JkT18rwnmjh1bbAKsa+k8QaX+SWWKzxyf78AMCGJbWRLYOnfc0zCQmiDqEoRcZbCuZB8IqVunBumCsHWDmTch0ZQ9mIqun+JNAv3WmXd0HsOZm7Viq+Mo9axUGvuhrxaDZ33JNYJy2LFPe2LM3JdLyUOC+3sWK+xCYA4TWSIttM2gOEYnXFK6ZVD1jhiaUXW0GfnkdUNfeleq9MCikfokmyYX7sw5M8XkOjJp8M36PY3xKji7uv/BmIgce3qXo2cMwWZUlHMybh+9QP7VX6P8kz/2y/r+iTW4Z51ZJ+VaH4jXkPz1W13B9yPLhLGAOddLVBHxxg9vXtDtb4PgooimtLYiLWDRBS0vIoAQUKo1UHV+DD6ZxkZgneCqqrRakbayrKeetV16mxoXtEx97/FAwy4bhvFuoDY4XaNUydVCPKay9chriNCOc1W6ByCOzAxNug1l639/6fDqc8feLY+M6xjBpAhMJoQvTyxPV6Z1YZk9o2yl0ayhSfs4CR2pF3r1kVLPLMuZ0iote1ArJKx1LuTqQV4MXUG81I5Ohm7rEjC7lCRNvAO3NMOqEhpoi7RmxLF5irxrM2Vexj41ZdftVlJwzSkz575Ig7YsvH934zSD5abfjNfo93+L8LuBuAvw0Y959Atf56YZx7Fy2lVzR99cinpgfts3n28+esTTr32J/PNfRN574h097QktRsoUmZIj9HHaQ56xFy9oj54j/+QPyN//Dqs+ED780F9zf0+LkZpB7t7w8PpjFitEy/y4XXhtzyL8fnPk/Q0Lq8Btnvj49f0WTL3PxOObAx/eve6BvnAyY68gbSF3dOdYTu5Usfhm+VCbI8E/Ecx3I/BeVozSkwm1jU9YuwuHr9HR56B5l+7Ga0A2lXjo5zPbkLNhc6cWN2HZKpex212S/H5Zn8fiSfvgtYYI0y5z6ALGb5NsArUBV/MfqPiAsiwFwpzRKbFm4yEZx9DYq5HWleH1SRiobeprbgDJkCqtrzcjmVIuCvyxJ2vYJUn1yk5HqLt+YRB60uiv2ZT5B1rGJaG7PnyN85+n8Rp5tyw6mif8moTQG1aEQcHxMt1ZKktQltDYV/cvVe20gY3sHjAJBEtOndGIxOi+pn2fks7TGIK23nUb3S8zpAu/L7o37toSaZr6vpyw05F1WMD1AG0FshYkCrkmVlmpceVgPs9KXWmnzLe++ZxfeflDHn7rU3aPn3PMRpaOoJ6Nm91Yv6N3l+Pd3K3VzbqJ4CZYIQRar9SVppunNrDpYIoa2jb88089PteBWROhEam2svYBfxRjbZEY9lDOfYMx71wcyuyiEFqnLAspZbQJK8rSLlILJr7wU4SYA5IiIUZQV/sG2Ocd7J5wLCdsUuQQWKuxSqXVe0oH1vIc2JtnZXFK7OY9U74h2o7YDX5zPfNWAmstSIrssregG4E5T1eEduOwy8R4RttCDcoi6qK1/bl7+RJitwrRmDBNYEISY+7cipwz027uMOsGgIBE8hbsKkQv7QUCEgW1StECpTEN2QMLzt2ILjQaWqDNSvmVr3D7d38dgIf/zT/gUA2/+5WMLyyrXFTLwReQgfBsLe7mwp0bei5GDYHj+obntfizEiVUqMPTMZrbXolgIWMEVxU3JdTqbgnAGkHPjVaUVcKGfnj5sm/6oXfntApVUU0gkaIPnNt6KZ+I+7eJ9r4c7SVYu3CMwmW9f4fzMcoVdvWzEZCNPUL7wj1KGdbP13BC88i+t19efTvOtYlZ2mUhH9cyNgvMkZdXv+JlxWMR4voGgEV2PsA0Qpg9622ObNRuDG8FjtWzxbwGtCoPuxWikFojBM9OcwpoLayteBUpKTF54hRllAMSMYBxBq1eQg25cyNtSyJC7tzRIF5i7Rv6PipldIs2z7ajJswid3nly3cHpo8eUbkjnzry+emPmFPwUpYesD/+MfbFJ3zj7/w3efMf/t/6Td1RORHwhq
IHvGP+K/Oer/2Vv+av+bnnrDkhecdDUm6mLxOev6C8d0A4oH2Rl7WgSyTs3qf88A/Y/+Z/hN4tSJnRB7/vpRgLmUNZ+KgUzqrsCLyOjQeM2x56v21eZp17sHZL4JO1cTK2LvA1KR+f7zYh0RnjpcDXM9S3jQePwclEghqrKqteAoFVLg0mEz5PkwknjB1e5qoKuxgZs+OmOBJVaEyAlIJKQ+IleGu4SLFY804+87SnAKJ1Sy6KrNv1wAVV8sHt354RGi6aK3in/KEBs2BfMHa3vk7skrC0TBOjZMgSWa1gBZ7jG/rZAm+ILBFkTryeT8ihUndg5TVPH5776+qBtPNuQJOZFCJm7gIz7/xc99oodfE11+gbv5DNNtTK70X//FdRcLN3E7bW14LhaONrzcVTdOv8FrgSt2c2f0Zi7kABVxUJcZQ34/ZfEVhyDypNWCxTSCSLqJywNRFCpoWeiQJT2LHqHRYi2bxqg0UkFgdUYLMEM9+SnRNu6jSMUFmHtZ4JUSZ23ZbLZCaGxQPLvpAuD0o9C7sCVTMxZW+4a8ZeZ5bNQP6OlHbY64l/+C895299/Cn7H79k/uJTtrradKaRSXFHayfCnGkSCPUBzg/OzQT28y26SzxoZT1D6vto1UaIzldbdSHNE1POrK9e8ec5/nwFz58dPzt+dvzs+Nnxs+Nnx8+Onx3/pR+fa8SsWqO0Sm0rtevpWKsEPM0ahPWAonYpNbn+X3IeQ8qAi02WUjmdFo69XHO7W6j7SlUoTRCLSNcM23hH5pBlShPTNDG3md1uR6kzZxaXrACUQJonUgrs55kpH8jpgDARwsW3siw7h/kj5LQnihtm5xQ2KFnVu0FiSKAzpR4hnLFW0EHTUFziJ4jLO/T6X46JeUrsOlE1Tpl5mkkpEqKji6BIuHDaZHhBbuU1T9lEjaor13zGEJynFkM3JdYz8+2B+Hf+kl/7b38H/rNP8cKJ87BWYLIrAipsvDKxCyeqwtYU4f8EijWHlXWQeh0O2tAo9aw7KJBhlsgQ53Cx2f6uXbdodHgV1Z49Xepi2gxrvaNVGxR1ra7ht9dfVwxWc7RslCYHCjYy4euy4tAmu+Z2jOOz5YhB7K12hRRgW9v5IPZ/9jzjzo5rsI5EvlM+5lIqbgQEZZlAvwzHcybbHQdzhGFtBWTFdKGVRKDS1FiqczAASiu4AG+lqXetLaeGomS7WN2ICVoDWAYJhFC9dE7dxuA4RCIxGNK7n5r63BodxEMXK8Zu+9O1046qSG/7TQi1KqEUaI21NfL9Hq3AbqK+/RiA6f5E3e+Iz26gNcIcOf/wI8KXvriVS0+cmES86aE/15/bf42f+x/8PdZ07PehInpDvQ3cPDqgH3yN86/981iITD96S/qhlynbsrpUyKc/Rr7zm3A+00Rpb88bcviqGgeBNwpNtXtzKqdeej9dQaSqLtthwMnUS7ZX4+5c68ZrNCBK9BJThLvTkeEN3fAu6mskdyuXbWPwMn67dvPWYOJ8xk4ZMINQnbBl4hC/uFTBQIU0NKeIjPcx16Rs6n6242h2RWLv7+9luWvBVHtHKiMYriO3j0z7A6dhKRUjUxJ2c2Y3J+JaCUHIKbKMdRBxj+LdTD1kbBaWBHcZDkWZ7B6A3Ar7JSFhonXekwVH+0fnqe8/pXd9s3FN72ETh4WLOPU12j0Qz2u5jDLWS/X1sl7dg63z/OphCX49qhfR2fEcZVsbHKU84a8Z9+EU4HVqHKVSWmVS1/DCIlZ1KwU4dhZp1lyHLDhf2qVSHJWuepnf/pyEJBnCRCBuXeDS3yNNaeu4xoI3SvSrX1vgtDSs3LuESeuIax8gg4eOKpOurOXMD44T/+y/dcP0v33Di9dH9LGf664kajBM3zAd9siaCO0VBOXchNBv/nk9Ms1CzjNlXVGtnQJ1sTTMTdDTyrIU0jttsn/y8bkOzAztSu+uPQUQtJF7J88UJycBj9JkH5khJGLyOrVbJAbApRHWtXI8+g
S739+zPzywKwuSvQVFcNP0tfrAKrZQtRCJ5Dgxp4nDfEDrilhBbdnec5omDvOO3TwzhR0p7jvHwgeaTo2lrX6dMZDTjhSc+zVdK0sDiepKye3EWsBomJZtspbOr0ySXCrDAkhkmpxPtutdpSFFckzkEDdzclfhb5sjgYhz+QzdaqRNiwuyotuM9x7FiEhwrpMYKUQkNNI3XgDw+O/9C7z5nf+Q7JWZjfNhV5wJ6cHHthn0hWRwIQZ/pPZgrF7B/BqEGIITcMEFI9eK1rp1V+bYRQprQ3pQ1Wp17pgpqzZvtggJ0wuXRlWhuZsErSHVjXurVu/Q7Ne19o7T0q97qFsP03L4SfL9Fpxtwe9P/m6UOIaa9zuBbH/dCGR/GhR+/bpxjM31+nd+v91ku/xSpr5onE9Qb4TWPUiLVYJURAsSlq3EOsqZ/nxKL0v0r17etba6p1wnTWrz0pVJIwRDYkaiB1K1JzZmiWDSE43sZPG2MXS2xTqKeJd1cH86zxrcNSN2KkNrlWIFLRBa4Ngau2NC4+zC3e3iI1vOC5wz9nAm65F8OMD9G3Y3vqHXtwuK0WJgbd4d/fP/9r/J+YOJ8MYD1LT7IvPj59Rf+jLE96gvEoeP36B/9M+wP/wutcuLFIukc8HSAfvkn9Hu3lJRzm8Ku24QezRINE7igcKKUQLU5s/wNErl4g+1dhrAar37j8smPARwp84zCj04kBhYSmW5jol7cHHNSRoB0RhHYzP/7NirqltzhuuCgZkT+4M1F78252P5a6qvIwZqDVX3wizOFL4Eg3JJUCoXftY1b3J81kGK3ynYDPIkkXd7Qg/Mpv3EtM/sp8DtLpNzdAkENVov4aXiuo2WI+sc0UPieFuIC6QK0jv6DmWhtZl98s9ZQ0NNWOUiAC0xETWiVTdqQjVfQ9rV68Z83+6z4R2HXPazz+qajc+96R6Oc3F5doZv/snepT6MznDRS8do7c/2OMrkuO9naY28OBggYZzYkB48B5wv5jzP7HsKXe5uMyevvqbGRi3eWFXV5TKEzBCmCBY9UJPhHhJIMdGdigHYNTicV6x4gFRrT0PMCIirFOCNJfftxG2eaPeN3/65zBf/thH//T/kK49/GfDEbgnifeMNqmVqnLF4Rg4Tx04/ONQTEiZPNORMUwUmJEZ0vF9MrGX19SqnqzbbP/n4fAdmHd0IkU0MFZQYfaF3ZXFlkPNGYJOmHTnPbtRbGrU4f6JUpZSFt8PgO3m7bKnwqD7myeEpKc6IXTqDalOWpdKKT2JRYxdn2nwAq/TmDXLywGeaduzzgSiZKe4cCerK+SUlsgRaiEiamKYdSWammJmnvCFTQYRqFcsVq855ULwRoPTlNq6ujiwavAU8ZKJMTF2rZx4q6CEQYrwSkO0rOmybYrIVsQjilFHDoJPkXbtnbJ6RgPnmKZEYYE4gk3Ez+zlv/vqvcP/f+H2O/+c/YNef2JDAiFcLSZTL+JX+u7FIdZF3gihJuixBX/ibNbCRJ/vJRNUX9+Dtym6t5BnU4PdobbQuFqi9Y2CgSuNZl1adCyOxLzLesblaY0E3yYSxkI1r3zTHrjaMLZuVdzjPPyFAO84zFtSxCY2Nhv60hAvHZOqb8HWWDZeNy22/+KnHRqAGHoD4lydeJWN5OBHLzLGLtM5SwCoWCxZSRzUjteqGXldz86EojZxcesDwZoEpRNdJwRGzEN1gXpISYiMmQ7Jh79g+B0yj6x2pksw7uHTAw/1DOsfMEbMQggcGLTDEl7UIRQvFXC/tvjQOZUfSjB0/JfcMQdOM1E+w00KoM/LwKbQFnsPPfeMXAfjot36HUwZWNyD/C3/pb1B+8deIDz8gf/PnfdzcfJEaj9iLD1hOE+nv/wPat38bO35KevMx3L3tAyeT8hNKfkr++A+RcyFr4yMg9nXEAhy7TMCx866KXQSat3HVRStE+sYsjiabXvS+rDeKoGOcuxUXErvGmD/HEAOld8MMA+2ReFyjNuP9Qw
jupGHdsUQ31iK1ub6X0WhaaTa8Dmxbn4N5NWQ08jTz5pCGslx9Rq6SlIEqafD5NIKSYS1Ug/TOdbAb0MeRsJ8JvStzfzOTDhM3OfFonwgK61qpS93I21GtoyEZnSLnm0h+NlNC4bW6eTvA8iAsa6XoCaQiLYNM2PQuQiTiyYh2fvDKBfUbrxrfj3np/3aAgcvxU0UoRkI7ElmukLUeAGpH2EbTyHgv6Unejktzy6sOE38clHsaoRTmGiH3pFJbb9Lx+5DwDtdobmIeSYgGBNmSJJPQuzQDa1/JmoKpuJRT8AR76lWjy73zRrYgEzn6++3yjv1uh7ZCrSu1Hv1+qldXrtFWtcJ9gyeTcP8m8a1fzszfvePJ9x0x2P3c+5zXM7chgxU0zZg8IZzOTOGedQPfXFNUoXfwNkfXCWi47EGxV2h+ogTyJxyf68AsNPNSiciGhpSm5JhcYZ+MBM/QQtRNTHA3H0jzjiCJZTWWc0M4g1SWdeHU9ah4eIUG41gqp/aU2k7s8i1TPGzdIq7+7rCKVJcXmCVicUecdYN/pyjkKTJNOw+OZGaKE80uOlk1VZeuEAjTCMwmcshMed46KYXGxIzRkJrISTY/yGEpUcNKWRWThOd5mRgTecrkNHEIvU08CATfGFPIxN6paF2hHTxbEateTvEak7spiHqp1i4lohgTSaTnMUqOM3EHlvuG/oVb5v/2r7L+ox9Qf3DekLDEZcwq747fQUR/t6jl5VwRpZj6pmvmGW33Du0Xv51soHvD3kS4BF2td196aRq0e6xW1U1Ec5RwB3Km1jhbc3kMczI0dE02vQrI7CcXTfrnHr8bXUmffd3QJUIuJY7aA8br7tTRIp/wjHSgb9f3cSy2oy0/XL2fXZ2rAROJ+qRy/tUZlkSUheMqPE29lFkr2Iq2jMbYFcUbxZTaZRaKeaJzk10rTyhMW425gvgYjCE6MTgLVYWYIGYlRiV1ZLdqpDanJ3gA0fWCVHupYkBF9ATDvFFF3PpHo6Fnv65a3Tt2qYo1461WHtdb0BXVFcpo/lFyWainT4ntFm1H1odPieXMo69/GYC73/4djutETYUX1Xj2b/z3qI9fYM8ecf/8FoDZhPRGaN/+kPL2j0i/87vo+ho+fcny8AoprhfW7o1UPyHsP0CyoueKtMqdXJGzQ2TtYdB9D1s3JPY6MrNeWrXgncXBicl2oTdfxhjj1vkJmsDSLgFcs0uAs5W67d1gYPDtFU8M1u3kLs/TeqTkaIjjMM3OBJ1wuYGLkLMqTj0xc2HiVhkmcsol0Lj4QXpHpV1d46Vjuf+/mS9dCeLTgD7PxCd7plsfg/Nt5rBPPD/smVLCVq8AZJlYezXGkiexu5g5pZmH2z0LCeKMyZk1+R2LGfSNUtczURpp41LEiyaighH6tQuLGIt5wDsacujXPhLE0awzbvpIWfqw314PvgZsidu4F3KFLF79/TXaZsEXm9V06yyvwF2E72c/+ydRMV15rN6YtgsBggvIEuTSeW6VKcSuQefVK0z9sw95q16IFLyjXJuXn808CM49GBwC6JIcRDD10jtANt/X5zxxu5tprXBeT4T1RNXmRPxeUu1X5l7TYcfCkRsNfHh+xvO/JRz+D67x9+LuEfvDU9b6wL4VcjhT8o76+BGnj16xG+sPzRO/mLanoDTQeLkPElwbr16VTP6M43MdmKlpR0MqS+mKyqpIElKKSKNLQXgJb977JNztDuQ8I73D5sxCIBDEKLqy9Db+V3f3vD2tTLsHXpyeUXThZveEKe2ZgkPgIpEgwyesbdn6FCfQinQXgZSMnCDFTM4zWTI5RpJJh1yhSCXn7Er1afJSpszMaWae85YxpNCFQ3XFukZTDc6vqb12f9Q7ghmt4uVSAjEGL1umROpyH9eBmQv0ObIXEVpfkForvR0YhNQDl9q5H3XrWpReXsFcqUyCEVMitTPauyTXSdn/2lc5/Ru/yPrv/lN2XKD1z1ROtn8HR6X2Msk2tt
UIGMuyUFolJ4i1y3D0Mww0qVUP3oKC7PbdPuWiIt5aY20V7cKGvXDr5c3+jiqX92/qCFrRRuEiljve87PB5U+M3YFiXSFa8tkdc5yn/1u5dGldaw+N9824NtGY1JvI9J9yHdtGdvX/AHyKkn7jgD15TLlb2E+PsGYoc78uc/SjFUwDUVzyoCLbBrvUI03PYO6VF4N5VzOg9YKGRRIxTJAKqQamJEw5kFMk5Y6+lUtwEK17LKK00sjpIk5aTZ2DEwIm5jxScduk8UAaQmlCXRtajXMtPLYbihTadIMuzjEzvcfuToT7j2nlx6yHPYfbifPDJzD5WvL4Kx/w+ntvOFbj+Td/AV58iRAWODxiH3pn8PEV+uEd8fiW8OHvIG9fET75Idy/pZVl66QOC6i8paiiS6EdKzHA2bx8BBDVs/Ej1kvXLpq6dS9viGxnSPYATZQueHoViPWJt22MBIzGinC3lm2zj02J8RIYDEFT42oc99cKgOoWI1pH7LZymoQuIdEwc75msEDbQj66K0MPylR7CdPemQPgc6DZBUHm6nc/UeanB7c3IO9F9FlGbidk77Pl5nbi6c3E4xxJEigoVYOjIN1OCnW0L4QAMSDz7N3lUQipUPuDXHZKFVg+hakYE4U4rdRzviDgwe3JzhgnjHMPZttn5uo7yF8f/59FuyOXhAsuSdtIxDZ6y/hb85K/cBkPo8M7dZRy6fc648j5qwm+35GpoxaelEooK0kykxRS9C5713/0kzVT71AnYDZtz9tgWyM8oXBJjciCr7D+FXC5KnB5E4ZVoHNrsF7yHPp3uxRp00yplfOcWdZMqwURo9Wyia+bNU4Z9u2EFUUORjlWfvO9yAf/uu/Xh//Tj4m7Z7R9hFU4nAtBClUETU+xo8/IJRbyNLlVXegC77qiEdC+Es9+p6u4nuqf5/hcB2alVQiVtZw3XZEYhSiuJh2lMU2J/X7Hfr9nv3d9ot18S0gJUyEHI4WuZyXmmj0nJ+2+vj9zOi9YvOOjJ2+4O595/vQZj/ePOGRve87RLZxaMUcRpBuqmyESmVMP4FIlCD1wy6ScAVdKHorKSQIxToiuPeBLTNOe/bzjMO+IcbQhu09ZbZEWFAuFXVTEhNZ1WFpRRyWsOWEaIcXEHBJJ8sYfCzFATB6QDb/M4NQclTGQuwSCgoXqi0urzsuylXyl6WStoRK85C4RpBBJ5C4vsiwrNy8OvPnbv4j9/R9w/u1XZC58hp84ZKjhe9nmWhlfuuJ3XVfaWpjUSxYEQUfpRJr7iKq6Jo+pE/jFeYhD961qo3WOmeHwk6kTrFUvzRlGRM39Rqu2LZN3johf1wiaNiPlq4X0GpmCy8K7lXNhI/hDP09H8MbmpHjGXT9zzmS+kMarxfyadzIQOb36nV69B1zQkFNQ1r/8nPs6YeWITBOS29ZEoq2L72rFAkRphKioBJYul3Fu98Tu92oEJEZiCgR8gxoLuKhfVZRAjuIcyDlQ9wldesBVqgtWaqRqQxGCZKKCaPOxTkdaQoCIi1XSyca9nO3POlFbD/aqcq6KEMkloCnAsZP2y1vimzNSViwogcxyLITjW0rfHfPjR0ReogkeffMvg76F2w98E/++B3jBjuhH36Wc35I+fUN79QPymx/COVJbQ84uG9IKLFJJObM7P/Bxc2unJHDsA+WRhzQ84OrwGTjLFS+oT6KovckBowUHbIaP7EYLH+jpFnT5z4+tcTJj6mhFNFdgN70ExyPo+uwxgoHxL9L1wca4iZDEwAJF1dc6Dx0viLmtNA3UThEo2BaQGfyEKr6Gq8BFLsHHeI3R+VTA/ATkWSI82bEeIr1wQDoEbubITXIiXZTEqg2l0Toio7X52hLdyzPl2aURciZPsHRhYrtRYrpDC5xeryQTQjGwioRLSXftbhaLwZGLfRRcoWFX+/iYyyvvzuuIj5FtzzcfBwOcbp/5+xHIDfeQKu/K68T+2iU68f+TGV7ewOthXVILdVnR1ZCSiF
HJaSbSWA2ajsqOkfDGDWmpo/Eu2GujhUHCxkGNUtGuYaidrmBX2a3SeaLiFxrE3QYuJXAPqnNK7PLEMmXKkmhxxZps3Dc1I1WhxcI5CHo2nu7uCXeP+K2vdM7hXz2jv/ldds+fcz48hnjk5nSEKujNDW/7sz6sFaRg4qLzEhOl3UPhYmm4Vi/f1/YOv/dPOz7Xgdmqq3cG1rotunOeu8imayGlLMy7zH6/53DopYXp1rkpCtEK2qDNjWKV3XrL1EseEpSlrhzP8PZ44v74Hd578WM+ePY+zw5PANcxO+wOTOkJrazEZE48bC42F7rBWOicl8F3i1GcW2NsWUHsqscO6weIEynPzNMNu3km9x03Aas0QvI6toRKC4rWtmXfeZqpLXiXiEEKiSlF5piI4tYVgJvFpkDsYn7+5RmKXJH6tUKTBQuRKua2RbVhcuqOCHj3Jj4BgkQMpTYlhMzSkbxpibTwhkdf/oBP/s43kH/2iri4ps7IbsNnBu+YiyuXjiSACSEHf4HVBrVh+4kcLuFblIYlFxjWc6OWiobipZV2FZiZos5+8AyLgMnFlw3oKtNuKdQMt3Aai5u920k1Nq/BvQlcNg64QgjH5siFaHtdtq3979q1Bx2XDWecC/GALJsHaNKJ2tL5FQDeD8lFhNJ+MjBTQCzDVwqnn3+EnVbiHKkoGdnGqpV2EdZtlRyk8wJlS5KanknBzY2bBrJ4x7E1IcRMkJE1u/J7RRFppGzMc6LswA6+G9QaIMC6Nrdl0UC4IhVvC7j5jnMtNB3MyO2yoS+qaHVXADW4r16SJuwQu/ObCKyLcTgXSq3EDPLJJ4TbG2LeY83Lj/nJI9ZgPKpw+MW/wOn1h+x3B+zwCJXO0XpzRzzfY9yzlMruvLBaJpYTtj4QTz1JUkGDofaG5Ww8dP6PJDj1i38G3PVSbdKesPSyvNgF+Rgo2Nof6th4q9i2YY9uvEvZT4niPpiVTtbC59yqPq62wJ530d7r47Pj1LEw/6l2DqtZpHWDdTED2V3N+0ZRo9AuQZlckol29T6Kl0c3xPcqcBzvT0cSpyDMt4n0OKE3e2SXyLtuJxcDMRhzioSQWAqk1V0OXL3enU2igOREzDNTA0JAUyYdvoR18r+qEtMn6P3HlALne59oibZ5RdYuQlzxzvQzl7IhvFumHJqD46hdo2xbL7nQQcDnfdVL5/e2Ltml+UOki673+7VZTfe1NiYh7oX7WXm9g1c7eMjjNd6k9lAquxp40tz60FSdAGzDhSO5/mUxVzPQ6OswgXaVhkdxpYPQr9gUWjVaNWyS7Vl7w5D/bwvGpNsyAn6GSJRACpE5Z07Bkap6xS/DnAbVGpwnYxeFpTaehIUP7x8B8Pt/a2H++J7p5YEbnnKUFYkHqhUWe8u+Q5tVFasLhImUHOBYrNHKQuqgTFMlxoQFofxXwStTa6PJINX5jY/BsztB0JBIecc83XA4PGY3+00PMpEsubL5boepP1prsIQz+9iNbXPkIcPUPHO4vwctD5zuT3zy1NG3J7cveLR7wotpYUqBcDZyisSipHTpUJmnGQu+TBVbCM0zcCGx9Gz/Tldel5dgM9EylEbYewk0xkzo/B7sSFJHAao40rCLMxYDrXPMLFWUe7T55JMQkZDcAiQqx75ATBKYoqfcEjImkSY+AUa5M1ZfgU84/F7WQmuFQmUn60USxFrfDD3AjJIwC+jdmcgbfw3KeTlg6QnP//IX+d7fmLn5TxZKiMw9SIrmC3EWNmHIBR+shcy+U1VvxYgaqdKYy0tKe+ol3SfzhsZoNdK0J8SZZi990YkRa43SLpM14zyAxZRaGyJdvTxerFasVCT7ZwpEVjEKzi8bpQP6v0N9fOPe9AXwp5UVN+5XX6tcAqI/ai7fjyBv7ZvUCPKW4P6Le2BvjpSYuYdjCWzSLsk8SMumW9ZdcA7QeJMKnCi8/Cs72k44nSO0zNQKu/SE1J91jjfetdaLxsM6p1BZu/drUENRWo
hum9R5GCFFhLDJI0iInFunAcTMjkibjGkHtfakpd0zt9KBsYiuJ+4r7IMRD97DBc7BMqLzInH7MJVGpWFhZOknyg52Z+VeVxLwOr2mhTNtaaTs5Yz88AlnVawqqVYmibTzmZPBdB58oufc5sB6eA/ee0R685KH/RtuWmMaDSkv79DZaC9fsT++IqzeQllDYMo3tG7vZGYsp0pqEGahVmE1Q6ptwchJYDblZI40tCDM3RIqBrgZyBRwDkZWvNyDQ0rRLslNwrs6EVeVlwBTpz8sQd8RHc3mCcdQovcA/lI6HccYryJeFlP1kuooNxUzDsxkPfJghSaRWJUQzmiPDtaWkFBozbak4/r8XCUaxqUL2flng1bSHzV0eygh7AyeFuTZY3SfOe9W4s5PfksgS0BFiVRyyBzyDW/PZ6zTZFprJFwNny5XNMXMze6G3CYO73viHx9WTgLLurKe37DeX+7LIL3LuF4uwatwqRysV9efTTo66J/r+JlFxP8mMdFcTNsuKOHCZf3YX7/elMN4f3NeIPjYkAj1hbHcGPVGsNk4h8tjiMsKbwJpNxGPD65ASyEn86Cs+T7lVBC3azrr0RtwVKghbkLmAVdDcCcDZW2rd3ZLZTJj6h3SE0qiImH1fMECQbL/nQwkfEhVRKa0J1KYU6HkOygrrVuSqDnAHs2pcc2g7RtnW9j1zOaj1zc8+Xtg/7uX5PIe8+2Bur4mGhzaYfNrPeaPyVqRdiJpoJgSdUdjJXdFhrMabe02f5+tVf8Jx+c6MGutOmMQj7qvDxFxzbD9nsPhwG6+YZ7G0HR9lKEsP2sG2dFaYb+buVn9dQ/LA/vphJpzVJYKp4fAclLujw5lHp8s3O9+zPH2fW73O0fQwgG1SrS2WcWIFmYTJOBq82o0K2hrrMWbDbQu0LwcEzoJfcv6A5u6OZoQUSIZ7d2nqhBD8/ZkYAoTS8isUWkNavBAKYYAktkNhNEmssyEEB1WFl/iAlB6V0md8DKmeoNDa42yrGhdqbtKqb0RQnFcJY7n09BWOE1ndv2HPnAL8+kefR7Z/9d+iVff+h2eftio/TXRmreFX20iezyrjDT6WspkMGGkppS3L9k9fh+ZJtK8J2VfJJUzWh+QWYm7TKlu+j46eq/J/+/A5qq0oZ80sqNuomzmrdhFfWsweIeMr1ccz6Et9hPdmFz+Xztqhl6Cu+11XT+u9tLVdSn3nfGOT+Ywyjj27vv003ug1P/vfDljZuLUt4JTgOMN8GtfQNbEPMEaE/uyd5S3q1mbgTYjdN6H4gbwa6uUYQSshRwDKbiSN9oI0YgdPhzEfpOeGJknCK0G9rOgzeUEwEvz92XxElgpGI3QUelhlQU4D04DZhkZka4oIcr2rKMIWf1nk0SyKd+xV2g7sWNC+3Udj0ceTvdIq9xKoNTGsyCE7Bw3AO7vuXky0z69o/3x9+B25uajT7G3R3fGAOx0YqIi92eW05FIQ9eVoA3UqEvZ7mkK3qxQlsqKc32ESxBeER/ziEumDJK/vNvRG5FOp/DRNI2yMRfErJoHTU1ta75xVDd4J+QYiz0YSz1hSNbL5XKFtPa/HRvKJUmRdygK3mDj6EdtboFWTYhmW5m8aqEFTxKWcV1yIfIP9O06aNvKmnYp9wMb/UExpmcwP03ER7csN4k475A0qiNO45AAEiPWIWgLuu0tEiMxR2LxDuI579inHbe7W/Yy0SmFHJ5Gdg+3NC0c7wunduT8kSdP1wlXg3dKtAJEgltbjftF6CVOfWftGF2oALIHk8pS/D12MXIsbZvnowI509eIPg5WwsAxt67MsIPwHqQvwP4W7ifjkIRnBV7f+6uX1Ti1Rq0rcm7UctM9aTuNhVFh8DFWa8A0dIV/8a7nPjdLHlxJsOrdsG5tZBzXi89LwzjojjkHortz9q5el24CMC29cuaezSFASNJ505c1eDRHVoMpOHqZVFhKJaSuRNAqfzg/5ubvZB7/73/E4/gEm2+ZeUN5CC
xdZUDPgVNrJJTTJCQR1hygBZZhmh68OFwtuLzXn+P4qbSenx0/O352/Oz42fGz42fHz46fHf//Pz7XiFnssg7S4XjwckCplZDjxi+b5z0p7oixQ6zNXz60jqbJO0bWNTOlzNxJ9lkiM87byOAdOsGzyeODR8ylrjzElePz7/P00TMeH57SzIhUJnOOD4BoJnT/vhiCX7OZa64MxMxWzxxMnADZs13EkGADHHSEwXIvUEz+exVqMkpv2d4FYQnR6+tUF97tXJtM20zYpzQzpRkNniUO4FHE2PeUuYjR1AjqpPdTK5x0xcrCPtQNmci5kbMxM3vjgLrqu1pj6QTaFiJST6znB5Zc2f/Kl3j16z/i9OHH7LociK2NuRN6S+86GxlkRjn0LHlvsIvCrYDev4X7VyxffJ/0ZIaOeiynSAozlBVtwmQTSzsyWrIv1CQnVTuXwRDteJJZ58V49tp6w8Fq5R0OTJF3ScnXxzuSFT9lHA+Qa5QzkAvvp5ltZdHmD8bFb+1dncKMZ8bXfqOfNSgf37oWVBgFLgrG0p/72eD8a5H61UfoKyNOkdtpIthECLbpQlppm/5aCGBirHXhXM8sm1HvSgyRGCaCVbw/S73ckiJBumqqBVIMtObckYoR9hHFG2AAkgaaRlLnr6xUglbnuWlluzDBNetUsdi8i8uMJBf+VRQhqdKSMGnkViu/Wz+Bm4C9Maw3CVlMnOGiBShKWQ29uvPn9TXvv/+c5eVHhD/4A+xXvska3kIq5D6ZVCrtdCS8fkt8ODEtZ0prfg/XuvFCsY6AWYPqz/eIE7sH+rGYcdM5U8Po2jpaFpHNODmad6xVMYI6WrKKOU+s/13pVAG1Cx/Jy+7iJaI+WivON/3sZhG4oHSfHfPbeDNBRS4obXXpjBASBOVcfQQmjRvSovSuZzrRXS4Ik3cWXiHb/aIHSKxXX+MaAhFJjd0XhPTenvUmsuwiLQrz1gSlSFCIDQsBiUroIsVbB6E2NxPvHMaUMjfTLTdxR5z36H7cocTt/gZQjmvh/vxDJi2sryOn86Ub0bp+mwUb1EgS9k4TgHfAux6dayVnVAuHm8CjF/6qmxfGfHA0ajka9dSYjnA+Q1u8EQRAg5ekD617Yko3nQ8QfcgTn4G8D/JC4GBME8xq5PuAdKFGW6CdK+tqLGqcl0qcpLt/NBb1tb5ootpEqUbWjOHjqjA6UXpFASNUd2GxZt7EZsbZCqUvqsWMFeWggTkkJsmOvqoysMGmi/9fO6IeIKXoCg0pdioRXirtCG8DlgapGRJh2hA6Yb03/vgXE/G/fqb8P15y+/zL6DwxTSunDhVPHLDz6t6kas7BTYmYzSt69KpHDAR1rvm7HVk//fhcB2bS1eiHmjR4+QxRkIlpDux2O6a896BMOtEzmHeBBCcKhqyITpSUyTmTewA3hcgUEoTG1LpmjtV3SknerJixlwVbX6OlYuWWKQSYE3HXFfZnJTCRO8k+EbAAhbJNfJFOkjTr4nRd20dbN9vuZVF8sSEEgk7k0LDk3LHN3sl2BFFXX7YuclgzFifv3IyjR9+89RvxhgMzhIALcfXGBb/LXmoaEUERtAVO53VT9Z4mZa3GUoQo5rwtM1orLL2hYkozxZSjNh7uFrAb0r/8Fdbfu2P+XifPpsCpKrsAZ4xs/sn3+Aazv+JDTDSmAGU9sr55OR4Servv1+4TQtcjurwmtBU0dW0m3Vqoq3pAv7ZK64bDGgKt2qYYXVG0FFS4yGTw08U22Z7Uu4XHn1aGtBH4deby4J74e7I1F7hti3WDYi9vwqWEtONKKqP/3Hk3/rNhbRWBM+qCS61wSoWhCNAOcP7rzzgfF1YTHrNnJ9mNgltC82C/NOhdmcEiKkoV174bnMNA7WWDRhh6ZKZEcwHiUQ6UnLxdPtK7lpWkCaWRezE7NKGYcLKuBdUKpv5VSiV0pwcxp7RXU6KKCyiHgIS1azSBRm82WKOREZ7lzA/CHfcv4PHHjTaahJ
4+YzrfufWPOqfxzbGyqy5LBNCkIW8fSDdw9zv/lMdf/gJlbtSH1+TB0Twdqe2Na6RpoZWFiFFroa5ti5RMvVtP7KKefxa6QIkfC+NZ9iSi18CcZ2XvrPnJvDsvArM5FzJy1T2My8rI1aj0AKh17tFVydD8vVc612y8dz/G9Y7O4nf18y7piCIuIB0yQmWxBuJ8sevJ0eTSUFPpa2L/fNc0na17WS7j/Z1mFgOjkR5BfA94MrEeAmUvxCkQ+5uK9TJri32t63Z+IXjAdvW+lyapmXk+sIuZw/4ROpIImzncQNxnbu3Em3pknV9ijzPrxz34vAOpvWSpl8AzjiR0lIJD1+kzCCaINOL78IWvB774Zb+gw60y7QwV4/QAD6/hfAw8NF//557wpty4FXiGR2OruVXZWb3BBEBmCI+B55F4CMwJTvcNWRuxE/FFjHVV7h+Uaa/E+ejdtq25UG7nmFWdPZnVTE6RnSWWIJykbsFUUiM0H5u1eJOQO0HgVm6dP+Zbkdsm7uLMPmZPRMyuArN6CcysgnT3mki3aLuMiWEbZipoM9a65fL9M0Y0wtuXhe/8zT3Hj97yK3+wYrcTOlfSsVNzYoJQOalyaMJRXPh6Fbso/2u4aGNeCQz/acfnOjBTdR0zu2oPMjOauJ1LniZizoQUoU8ywEUoTZAYPXhgogUlxswUMyn1CZYmppSIofYsTnwDvupysuCZx3IHd62i7TXWzuynW+BA7oNZw0qw3damLKMV/WrRGosDOjgz6s0DopeMDt+kDEXMic6NSAja1ZD9mIKwhsAUhEUbtNY3ckWjbsr/lkCSdyP6xO9UYbvgAmLevbYarE2QrnDaiKiKyxjgbdJrhUWK6+cE75iRAOfiQZfVRjA4qnB+q9w/vObuWePFv/IN4stv+X0/J46hMPX3VYQddkGFxmekc2uykHZ7LBj7+n3qj39MOD8BIMcD5QE4V5LCUhQNjda6z2DfQUqrHJczpa0QhGZCpGvzbEwMQ9tQXGoUXBhyxTeRTV9IxvP7aWHY5Ri/1c71GQRmg3e6JJ2se9l0mvkGNUKkgAepBxucPPsJtGy8X4VuA1a5pzJF32gGyHX+55TlGzvWV29p0w3FVnYhItGFSkejipiHpabm7fEhIOLjZnRIa/ddNbNuTyVd0iXQcL0/gBhdU3CKCdTQSSg1YqFtYwsLaBMSk6v+H5WTrkirlFWZxjxr0JojvBXpjUCQIpyG6HBytDBJJDThWQl8mle+tzvyG/kRa28KiYcbkkQqvtCelurerivspmEVI9y9vOfFl5/yw+98yOM//pD84jFtLejwo335BrhHyokU3LVDhn5EgNolQZr6eiKI875kWDBduvRW6HZltiGsyS5czGtJhIAhCjn0XLWPgbHkOOJlW7NKkMHb8sEzhk/oPxoBmeq7qNT18VlE2Daf3R6Eo7TqyEIMHkqeca7oxrvs32yyM9hFx+wKTb4WZR48tIH4jaBydCzuv5AI72Xq44kyB2QW76DvPKCmgdaEFkGo3jRi3rwlcnlDdzSRjhIn0m7P7bwjHvYbAjynG3JO1EV4Fn6OIgs/yo12e3J9K6DOIEcoJ5DmVRI1YxW829gGN9mRrqaJfGg8fqF89Tfga99ofPBFv66cA6qB9dQ4vxFu9vCqRvaz8uQgPHvq15VvXBNz0szSGuU+c2yN+1I3lLtqoU0NeSTIXgg1EZY7dtI2cei7DOcV0gnmO0j5noYQdhWTcHE3EEhhIoo76ASEnUWswrlbuzX1z1zxh6fax2Kgy0v5qXQprAo1V1pY0ZjJMfc9r/O9OoImcpEjNiuEnt5uDQfiY0W6G4QENn7j1rypxo0En/vnSPlX9vzxDx74ij6h7HbU2ruyzz22QKkEamvENFFRivkKnSz5c5RhSvVnH5/rwKx2kjxcOs8EIZoSopFS1yezYTc0kCnpwpOBhEPtoVsmhRA2cdKIm9YKkX1uHjnVxLLKRko2a77Rqw9WewBiYd
U3WDwxNQ8Q9rrnYBXVRDDPVNXU0bMQ3/lcFqyXPR3ZC12McZCJJ3v39RA2S6qx2YnMBFGf2aZQFRPrn/WKkGvqm0RfaFqXHlCU4T+o2jtmzFi6cXypjSONnTmxE0CbuL2VKaYLNboxel4a2u/9W6C1wHJunM6N43HldHxD/PUvs3v9zwFw/Pe/RQiBVxlSUZIZM8Lshdttk5rnSIpCTIH55hHp9hHr7XOUTGh9aB8X7P4NsZXuO+4oXTWlNts86dbmXpra09M2etnMtm6xqmwCiYXAWby7cVjiXDqhL0hDkAsx/Ro5uG402AK0q69rxGxYKY2M3SUDZCNBT8CNvFvKrLDtxHZ1rsalbV7FXH9tFe4f+6uO/+ojjrYSUwG7Qe3EKUw8khtO4e4y7tuZUctUFSwEGitqjgqNz2UhYgRaR4NiCk4KbrZJb0jMBMlMOW8L45gfI7GpRJ60ALLQWuOkhfNyxOoZM9skOliVmAOKdOeDhunAHUfGL+QckZChCJNWDlR+f/qE33j6FPuoh7y7A2k6UM4PaJRNWDUCpYtCt6MHU4/OjjQ8fPvb3Nz+Elil6kDMXmPHV7R2ZKaQUrpcr4Xur9eRLPPnokA075rM4t+Dl51XPJAZCGgGcu+uHbIHw45H8FLWKPlpv7f0cww0DPr4FR+bY8yN8Rav/j/kVobQ8fasxfXSrg/rIqPbEYSijRgjmYRK2YLN6y1L8ETQJSRGYsQ7Bt8R12Qbbzk6RpvYhgoGIN9A+mCiPJ6xRztsP/vYCwajo1xXVF0mxwOxhkmFYFvyNvwmBspWAcuR3eNbZHdgH4aTxR7LkTw94rY+RT/4AtIqH+8/5V5eArDuztQjtKO6Y8xq6Eq3fdYtqZLk76FT5dHPwQe/CN/4dfjqVyIvHvWOxaRQhfu3wqczWDSe7wppD0+eHHj64j0Abh49Jec9Wo1lXbl/gENVDkvlbR/Pr8sDSzsjsXkzxJ37XcZctxbc1OWCaoG7s5IfOiqlYKH1QAhSbOS9kSIkbWiK3QZKtmd6pFKqXUkzBX9qo6V9GzveLIcoBUFCoXZ5im1d1U5PkeqVpuFKI04DSqNyEKFVtoRbeqAWuuvNONcbrdzuZvS448dfqLRfX7j9R/c8erZniv6s67TSaiTXlQepTE2ou0gocQvCxHRLevRqHv1px+c6MNMGGv0G22i9F/Wur6ik5PasRvO0ZJTmYiTGRCS5qGrtdf5ghCjkvmHscqLtZtaSmMJKyIvrtIhd5BiA6k0Y3pW4Cu2usVRnRBz6xnLYzTyeEk2TbxZcNFnCsFqK7g+mRFIKpNCVXbr34/AgFMlegpDeXdicKJIsbi36p1xpWZEJJhX2q3AADmLsCTwNTwHvyowt0ZrLMNQg/oVw7iOo9IwxqTI3z+ylKbFUmOMmtllW8U6qVGnZKMlVl2tdt46ldU1wF5jOK0+WyG2B01l4cvOGF//KL/k1fftTlj/8kDfamK8QgxvgIHDoopDzJCQzpuBq4toULYkp7xiKlrY2JgtoFY7LmaUaJkfUhGpC7dfeWt26NM10K/MspXDq971YtxoCls4tG1/v0AauyipDLsOftL3zmnFc83Q+G5gpPq4iUPtOPLTTxpHN2/0nvJtW+0Jj+u65xnWt4m3/Kbji+F1sHP+Gl37bNx6T3hbiYUJKo7AQqlFiRtuJYSq8qnZVeS/FqRq1nVFbNmTXUxrZ7quK8+Wivnv9IkIO2WVskviji32D2vn7Hagss3DTBNtXdD2xt5ljO6O60PpOLKl7nqqS+30g9PJgclwpR5+DmYkz/rsnNvG7u4/4u8+/Tviot9+nifzkCWm9807Q/Q5OC1M2zoNfV8BS4OWPXvPo/cD9jz/i8Oo9Qg6ce3fqTAVZsNMDJwp5ngiI88z0Ctk0d4Mai3joY+0MzD2VOqMsGM9NmMW5gbu+2l8HRtLHTujzZzU2G7KxCQ4nYblCoap5Zl+xzWLsbKN06YHiZnPE5e9GIDdwYr
sKtFTVS+H9MzWjl7MTFpzjM7onxzVMXObWGMMG7zhsuBD0Za4NRG/8nY9BiF+E8hzqfiLd7AiHHSlPXi3o99W62n9orrTlWobWxYpHQB8QdX/eFIQ1CCVG2CWe396yiz6HSgusSUh5h5x2hMe33IYvEm88gAN4/fye0/2Z0/2ZtlTsXImrEjV61UdH8CewM25ewJe+kfnyN2547xcTX/jiDc93Xtie9YFVF2y/8jZ6cDonePL4GU/f/xJP3vsCAE8ePyYl0Lqg5cSn58yyNNJxoS29K/uknE/VA3ttTifJguwE65oaWQQthga4S8bOMwOmDozG7hAwRWWKlcOsXlaPRpVAyFA7dHheV9Dqe1jsAIplF2mXuHHRVLSrAvQ9aQAIoptageDSG2oKOrQNtQdGRuqvm4Lx8JmEIsau3TbOFYykSj0X9rvA+Sj8+K/NtO+/5hdfNvIjpztw95aafIyspQ54kxgnVJw7vupKbIDkjRbzZx2f78AM2eBL7QM5BDbSpkikNc+mg6ybJUkOmSQTgriCPo2qlaoeaY8m4jkGbMqEmNGaiTojWpB0Zj11+40VztqVlyOYGGXt5NwUsUcdpdNX1LajVkHNxW1bT1FHlC4iSAxOTu7ZQAqQwigV+kMtUrptEhfJBzPX4eno220LoJG9TMwh8DQK7zV4WjO3smPqchLRArooizVOoiypsaZIiUrtxIN7lAcVivoydoxCDS5WmQgbCbxWZcXV9dUUWnORVlvZ93LNzctE/rHxolXm1VCbeLBMsh/x3td+HoBf/7t/nd/79/4juP+Uc4K5RjKNGTjMsN/7/ZhSZG8uKuzJb2M+HdH6sPl8mpYu22GYCWJu36Uhexl4a9F3T7VaK0VX4nRAVamtbEjl2jefFde1wy5B0LXTxpjszVzLa/zqGhH47M/s6kuvfj9QGrggGrVveiOgy8BBAk6XZ+OXXW+S71wbxkJCmnI2Y/3AkL/91H9f9hymgIZAiZXSlF1pHOcjmNLzjG647Bei6pyRWj1cHdZhal2o4FqGpLpps0RxzmL/kBJdES2lhOHIr1nbpCuoRpkz6dyY80TcTezrxNKy613Z5Y4Oe7RmdTPqFonMXepjZ4GmhdYiJQhiytMQ+MP0ijePKx/se3JzB3G3Z384wKu3hPkGGkQ5c+r6ArcJpDfqPGnKpzv45Ht/zPtf+zLx3gOzpmfWeuImQKlQX90TZl8DTC9I6Hhe18doLNlz2UkUmESYgEVs4yVdB0MCHuH1dXE4Agy0FDz4WejNG+q8UoPOYbzw1VbcDq19JhnZ3gc2S6ChyxX6nJCemWzPRxwBmxh6gsFFqK8++9rPO841grLBL6ubpszlnOP+DdRs/CwfYPeFSL0VLGd0mtnlCSG7K0gXk2iMBLNQTd3f0sR9Dkd5QQH1aw2xl+F05WyNZSdb8GkqtCisK+g+Mz15zLKe2D0RvhAfA/D+B09ZzsqrVw+8fn3P8aFQC2ANC6cNzbGWyPvGF7+a+fovPOWLz1/w+P1EfmJodsmmGp9h8pZj/JBPXik/XBPPDpVntzc8evGM9148AeDFoxuawEPZocszDpMSSmHZnTksHuQ9vjPUFtbzStEFjRGsEPrnBrAYIDZ/vhWWM8QZiBAyWxWL3FwySnw9naxTbaJQ+slWdQmls3ggNHDesa8N+pHEBNLQwtaEVFpD1Eg94ZKgtFr8mQGIOwr4UHl3HQqJd2QznL5y+b2ZAzVtnlip7Beww8SrvyF8//9a+OpD319CAknYPCNrZYmRWEFT3Mr3zVZH74ao75/j+FwHZiLCWrX7fvnhxre9NNeUti7UsIMYyFe8g7WdiVqR1jgdC6fjSinufWi9szHPCWUioqhFUhFCiUQLvG1u22JSua2+EQ95lpygqPDpqtz0ksWTdWKtlZ1GVhViEDCXeKQjMXO+IayGxsBtSmQBUmYRQVtj7nW306OZpN5tmVLl9lQ4LIAk7rsYYlAvkaRy5tESeVZmbs
ueQ9pxE/fknjIYhdYWzrLyECsPyTgGWM3Qjla8iMIxGPch8DZOaIzUNBGCsCQIffAfzVgjnGRlFwTTQDHhmcHtG3/N/o9f8pX6hCfLRDoeydPCGuG4PGNd/4nf99u/yjf/yi8z/4P/N1qM2py0esgwR8g928qqSM60KbrxtbklcNCEdG5SrT4mJDrJs+BKXqU1VhUu1iDO8SgGFiYkZ0q30ChXm9rYPEW8nFl6l6ZypT3TN8gIm5H4T2MWGFzKjZ/ZXHoD1OZv1wzMAim4z2AT28Q9XwC3Vl0wsr9fwjf01S6bqPYxujN4SSW03un2d/fQF+90LEiaoXkxdwLuzvfEEplzpKqPryyJGjzwtqi0VtBWeoDqz3oyJeUnNEksNRND7jtuoAQnC0NHK3VFQ0BbIkT3ud0l4dTnj0VPlHQOnE6QmTjJTJJIsT2VkaW79hRNySSgQlIsZHZ93FhTaAENwqwNyc4Duc9v+E+efYf//vwNf92bN5ACU37CLJXyFo4vbri7/5jDzt9vaSsijuy+fQtPnxc++egl9f29oy1ALI1dDTwkYd9uMFkJ1ROpim7d1oteyh2Gi71qEB7UeNqfYosO/s8WyFLZtyuUNcB+UBoQHnpTlCcIPh6jvEvUD3ggkbpuWR7lU7sgrUeMncE5XDhmI1AaScO4bhlzpP+imRKDbI0e1EgIldAKuxiZLHFm9TW5X9iklwqWbBPF/yN45yI4om0MfT8v7Zc+fzZO3gegz24430yEw8w+7KjnwrmspF1yoW5wTll0CyhRI+dAEkOSl5IBzqYQE0EC0YzcKlobRZXzcWGae6NXEsw6Z9cipVZ0SojNTI+7VmNRHj2KPH96w8PpOW/vz9w9nPn0YeW0zuznXuZvSno88YWvT3zhfeWL7xW+9PgFtzfzlgye1jc8MHO3nPne65f8k08q/6Ld8o0vP+FmuuXx5HNbIiBKqjM1KftlcgecHLDVH9hqkbNmSsnY4mFxjVCHTQTe8NKilzJzhXWCuIAl2Gfn/IKXMncRbiVykwxMKVYhBGrnaB4XgbPzQsmhI3VQakVS3ChKoUVSy67+708MoWL1QjUJvRBqUihaoXnDmqq6DuRYx4MDJyLutxnAETVsM1ePlsm1sbOFwyQ0EnLO2M8/4vf/xg/Z/Qd+rqfPDmjMrCLkmwPTm5dEc16yWlf+77pskxiTZLDBDv6Tj893YNYV2L3Ud4lEvQMjsq5nhARhIkujdeNxc5IXtbnX47I2zuXM+XzyvxlQf04++buFj6VASSspRHbj1kUIpbLYpRtJ1aHQ1uB06gu4GefywMEOqFaso30mtmUFMYr7ZebcOXHeNRoNaq3E/p7FFvar8fz/x96fxVqXbfd92G/MZq3dnOZr6quO9xZZl73oiLREm1YkJFJERKYQQZYUxAxkw5ENKbBDxIYfDBiwDFgwYMDwgyG/GMiLbUBGHvXIQLERGEkEWTLgjpJtkSZp3qbuvVVfc5q991pzzjHyMOZae391L0ndRAlw41qFU+c75+xm7bVmM8Z//Mf/X06Mpwd2r2d2h5GxJWJ1JOxweECOlfFR2RW4MWGIA1FGOAakB0qUCZke2Nkj+zxx2jQOG6VE5a56FqX7K8bdFWOMbOyOZMYuZO7DyDePn9G6WWZV58XNYWKa4Noq7w+FZ6/h9r9/BcBHnzbera+4uhvYHbZYgnLbONqGw6L/y9/muX6J4d0f41uf/rcQYJg8w8/CulAmcT5IMHojSMUqaNWVc7dYR7XW3NtTvWtQNTg3igVZi17/X7w2TWimFHNLIWCV0+iNnyzK40sZZ9mkLjeVBd1avQMvjoXgny5+XsoyCxV6sV6poZeE+gabgdu+IL2DcNtLLAtS1jiXri5qqQguwbAFPpPAN39aGf/gO2zU50bJwkYaLTQnP0uDXso1jcT+IYopapWG0kqlteJq71HW5pnQhBCNWhME90tUyygzQfdrN1XRQqwJs4hFIecN1ktJa8MObo
C8CEfG6CKfKQ3EIqsIaAqOOKNuGybZ3QBCiN7J3O+Mdm5kUN+Uxfzvn4wT8blfi/jVxjw3xmHLPI6YzVipPHv+AXa6A2B68ynSvIwmCpTAnCqvf/1b3H7l3T7HGrWdSHNksjvGlJmmeY18VjJ+uBgjtgT75u4NfXCNCBMuHZAFjnYOtNz7sCMtC0/WzsCS86M4I0zLWDO7GLPmJbuL3y3JAnrmrS3cspUX3x+4SHdIbyDo+AeLbpAHb9ZljsKa1ME5OWmwcj/X9+eMJrf19x6IDT46qDins5hhDkwxvBfQ6wG5umG8ekaIW1RPZJmhFMjOi1CUUh39jTH6+LPqDJheORDxLl8XLPWOiqYuqqwXn9EvtnhpLgRHnUKkxBtC8CQ8ZIimZIz8RMjHkeExUl/u2ejVWlIbWkN2kfF2w/554ObFlusnH7C7Tuc17nHgdIwkORI58Hh/5LMxcWyCRS/RAl3uo3qnakuICTkKI3Et80Uxcgqk6E1cps29OLngaOGsIBfjhTJDOEEYPBFeLOxCGBAxxgE2G0MwQmscS1mvQwqNnJSAMU/Hzr+OqBpWZfVYJQVagJwMCR4cI2+LhIduMN+wtaPWB62uQds6pgJdAkb6SPIK0NI0ghnNXEHgZJU07hgkcyqV/Q8/5e/+xmsAfuQ3hKfv7jh9/Z4kB+om89npnmtu6OEGYiOtVE7W1nvxux3hd3/IF8cXxxfHF8cXxxfHF8cXxxfH/y+O72vEzMyRMcxJ/+DZ2dLtNZUZtSPNlEE3DJ1PYLajhYI1oVnjdDhxejgxnU7UMq018pwCWWCORjAnzYtFpEVqJ+zX4KWu2F12V82Z3gzwcPDzup8aV9sjRU9MJZJizxz1HPHXNjGEiPZO0ND/HwgdvfCU5VkTnhzg2eOGq1fC0wflZoJy923a6TUAQztRHu8Jp0LSwbPguAEbqaWSD13U9s0b5pefUu8/pbZHSDObjbEZA7sOu9e9cbpNPNxcs9k9YdzcsAkn3rRPaWGDbv26Po73nB4PxLDl+SbwYWu8d7dj99Vv8+5r/4w/fBjQX7ujfKtQwxVkIWwit5vCphu+T/ENQ8zcbkba7ppyuGcIvYQjspYyXCxEoMt7aG2dU1XRVaxJkVrQ0ijVzeVra66R1Bw5AVASVZt3OhrQGrOaf/XxNkknGfcOOOgdar30dIlMQeeZydlG6RIwWxqO1Kks/jp27vBcXn/VMfPB7ahHCIxqPOvj9IkY+35eIraaPaudydn96Uy43tkdYBtl86ffYbj9kFq8NJ9rYNSEZaGqQSoQE9N84HTB5Zqrm73TJTJw7i75sqs5CaojZhVrIyqRZhMmTwFFe8bfWmO2iRaNFBuKyxKs8AyOQIeoxAR5EKKfLtocEVvEXCPeECKGlyW6OGxY7Mj6lWjRyEmoLXaLrsC17PiN9JLX171cmxIlZiQb+eqatHkk1MapQB6cxJ22WwY1Hh9PpAB3byovnu351rceuf6mz8W6FVIzcoUHq47MVKed2AWi5evYublDl7EjrJY5Ww08hMZR1RtjOlqbcFujtUMRXW2aFkTXJTTO77e87fI+C6orph1Z6/dHbS2H08ej1/P5HY+lnA9gfX1eSvuKEVCSLcjwufz5edOaZe40zuOa/m9vGvCSpi/BBgPEL/WO2HdvaTdXjNc35KsrLLrwp5CgBU69W3zMEYkD0i2R6tQQLWSJq0l2EheEXZCwRqNqYe7WbCva0j+nhAQpYzG5gLceV22rhhFSQFIm0O3yJHOdHI225Y4fJ0oyZonMKVPjllNoiBljv7ohZWK8YjM858mTR4b8NX7zk5mPX534aDqynVzaYRSQVLAWCLpjSIkmgdkK0uUyghpRlBQbQWoH/pTU5x1AykaMS2Wo87Sqz8VaoPa1V1WxNmNWsFaIARKRoIXQebuCd8KXUglaXIrFjGKCtXDR0Q4pxi7HY7270TCxs1Vhl+yho24ii9SJrpUsPy9vQGoCNRoh9lJ/O0
/GIo0YKoQNjUxgotpMnTJTMOZ/0NeIr92/YXz1PkNWaszU6UBm5FgeGcxLQHOONCqmlfl3mzT9+L4OzBbTY1sdol0kLqdEa43JJuZWmdvMxgoWllqZcbIZq0Ix5fh4oBwnpDZCMOLaQWQ0lF0IKAEbjBrVneuXJWcKhAoP2TtVogmuoK8IcOy7+uvHmWfXRrWyCuE5WVFWkNWsEfMAIbicxwUsmwy2/XFXr+/54LPIu69gdxDSw0B7dcReFna9dGqnE/XNIxxPYEK112j7GqEFKMbp5RsA6lzQojQTqnm3oYLDxN2MmgB5X3n6/MT+vZdcv3/FZn9F1oBdJ97svgZAvn6Gju+w4dv8cNvxg3fvkf6Hb/P8zcT1Gx+Q8rUT0/9QmU9wCg9nA+/xNTleA5By4iS/RpAbruuG6XhPTmAibjy74NtNISdXEZfgi72Zk8nrwupvaK20Wvx7M2qr1KYUdY0q8HJWUV9cTQSzwtQqE23dFBe+VuvdhQFWc/KFr3N5tGVYXpQ2L6elGav0gOkSkJ2/oG9SPcAalsBMIgOFJ/3FrsVLVnPXN2ty7k5b+Ed9whDNuMc7/aY/kYm/74cYp4FN9jG/yY12su7XNzuBPwaKGA9zY+4npigxQu7xUw59Y+eizTAOKEdaGwnSQM05fVIIZLR0vTPzBbK1Rg0uW5Lz1ktK/eRtCYlD8a5P0d7MEbosTt/O1YBA6AG3qTcZSE9z/Do0QhD3R6T7+zXhOm34mh74lSuXNPiHb694mB6pxxOBSOoyMZL3hNG5icdTRJpTegUPuLalka9G7j99AODpR3uaRpodyRppwbX8PDk4l7T9hJ34vHghZgWN52DlmQlvggfYN8B9f1zAE5Zl3KwVoIuX/vz4VHpSsIzl/vso4qXM5bU6xzF8/rm8nWws82H591LCNzs3LkmMSDOqVd+k5fw6q19j/91l6X/l0V285xKozbiYdQCKKOkFjB964Dw8uWa8vWJzdYWkyKkdmLSQyIiMSO+clejl29oqqq5zmCSwH/AkgV66s7C+v2ijtUrpUkaLL3JmkW2AJgELgRoCqRmtcyY1AMkT7qkVDrV5Q01ohJSoy/qVfV1+rI1Jt8wakVAwgstHACkaOSU3Xt8O7PfGb30y8dVvHfj41T2bYbmPTkwPmtkEQ+vJ1/15pvaJrbUS1bmIgvau5t6U1qdPCEvg5eMiDYEgupLpbUm4qnlD2DwzEwhJgUiwuvKqhQLaaG3uwrABbY3SAuc0A6IEioqX6wUsCCGIUxe6U490gmOShIdA2lUN/FxXbVd1/q2Yzy9vSPE5NSxrVy2UFojhgax7ypSJ24Eoxv2nrwkdRPjGz8L2//6Sd8ueaXjDkEeG6ZGTKfte0q1BV3pK+R8D+T+Lk6EbeM0ZyEMgZ8G0q5DrjGohJiUta7d6tl6bUGZjOp6gFsbs9g0LYqatYUHR7CrDO0ldI63SOo/GrDFLJYpRm3McWifH5nDWmnpzKLx8rFxdHdjmgdIGTJXWzqtyjNklPswXWRFH9GiNK4s86ZPnw5eVdx+V3acTvDzCS4FXSnh9Yr73Tp1y/Cr1zSP5BNTekq7g4aZy7ISBuXmQ0cKZNxLwTHwJBKsqeq/IPdhvKsM7r3j6wSvGJ7DfPOFbT93P437+OjdXI1/e/CC3r+DJZ9/k3Tcnhm9H9Ot+7oevn5hOhuVALbryRuoJBvPMTjryE8Md1lyvyYpANsLiwusPdCkE8YxZug1UUEWXBbDzn1QdUdNavbNLoWlzxXFcp807unzh1TZTPTRZN4xF5f9SIuAyg7+MupaAbWmMW/60PDT0BcYREqF17apq3pW22Ae1jky0vnsJEJtyizdlAFwt45Xz5tU4b7oritI3/Bl4/J8F7E/9HnIaCbGRTp1rExqhTm7xZRBbpC1ae6Gt+lAiHlSAGwFLD8xCCOTOfSvW+SwWWWS3A3THjrp2UnsE23C9wRmtBVP1JoC+KWrv9G1t9jnNwmMJmK
T1uqrpGvAu98iIKLLq8gVwPlAM/pN0cRSJbJLwd594QPUH3nnB8NkrF7WcA601whidG9VV/W/efY/D6zsGE7LNHB4Lrx9OXO83fNKv/O1dI47CHJU4Byw6v8rQtat3HRvm/Jclvh3x01vGw2DOD2z4+jeoMoujXJedjbV/TuvXYAly7GI8LKiccPHEfggX/DE7P25BY5fX+vyxjHPtDDIR9WRo2RVjWE3F3ZVE1jVg3RO/83T6vX070ViaD0r/rMEgPoPhoxt2L279ej3fsn16S9rumFukWPJAU3tne3/Taa7dYL25JVtLGMYRJffqSOiTWUPysSOVWmemOjNrI62OFy7B0jgHZ01cUqG2c9Iv6sHvVPo+oN5IZRdi6ENwx4vWAnNLmI3YFNwecElaTEhkchjZjzue3l7xq/WBr39yz1e/uWG/7Wth3JLFTYdsPFBa4FSV+9PEafYkvDVnNopGkiTmHJAMIRnkMz8uJGDqiD8ubq50/lfsMteSMEnUqrRwYtbWZaUqi2Bh05nGhIgxdT5fw43bUA+OwbnCVTIpuwRWjIFxyKQxkPJ51FkrhJBItTLj68NSaVqvl0LSXu2K/VY2CFVpyc/rpJVhEqJUIplchVkmTuXEoMLjfV9/3km8+YMn5P9mXNUtU5whRLTNRPFrKjYSooF4I9vfy/F9HZhJ8s0iqLj1Ai6hEEUQNTTMvvFao2lGSyd6i1sw1KbME7Q6M8ZAHhJB4qqJVM1o2rjSDFFoydmOIonUM/KRABqZh4LOYM3QaJh2HnkPIsoMDwfj4XjHLiUGGRC862TpIIwxE5NQZ9eIEmtkS2xr5fmx8e6dn//TV4WrbynyG6+onx3g0EjHgty9xD79+vp+ofhGf6QjJPgkUsCWz9iv5aJ51RMIjG7b0w8vg7j6u30b9E1gGAMvXrwm3TgROnz0hCdPR57aV7mdhSeHCC8D5auJx0/9MW1uBHHl59IxpIpiEea1hmEkIqKKYL45mRCruaRQH7UhBLAF7u6TU13XTXvUYLXQtPnvVLGuw9asfy0lQ2selJlRRVG8hXtR9YeenduZ4LwQtO182uuxlGyW4GhByy474haooPYdcPHELJytdyT0chSgQchq7Gk8B257YJP1vEEtgreOerrIalnKj9J1qn4qcvgnP+Bm/5SdVhI7jjsP6ENNpOBZsUigSCbXjA3S50wP6Ds72+xclvWfL7xFW0V1wDwX7fIaM2aj2xP1q9FKQ8W31xAiGnyDjCGt8i9WhdrU0U4t3Q+z9i5sVmFVs4Bp78Yy/D0kehbfL74FAQn9ngTEApFIi5GdJb564yhKfb5hm/a0TWNqjZzeQduJ2BKnR6+l3rx7xTs/cMubr32D+vgKS8W9+KZK7kjFJ68O/MCLkSCJ2SakuLSMt+lfBCRKb/bokZSdbbbK+vka17agCcqm32/pod7lYRfjTPqYvSTPXwZD6yEu3xPNWIRelkrRcrg4cT/FBQ3u7/H5iMpFXGpfdejK7LGPBUc+Fy/MyxLlcv4rSidn6ZhzYBmpNGKEh6aknbD98i3jl5+xe/eZX9PbDVzt0CF6CTMGctjSrCB6lmwpRv93dlHoGD0YKjAvXX/mKA29ASAKqDVqaZRWqB3JT72MbL3kph2ReQhuW+YfspFPjSS52wJFzIQxJKaQCR0qthCJ2qjN5TeaRqiBaHkNgmt7pLWGmRAlMYZEKfCNb93zW9/Y8uyJz6FhSGyGTNUj98VFS05T5e7hxOHkK85UG03NXXDyFg2RuUbGTWXc+DsO48yUjRZ9XrZmnYjvAbl00UoJGQmu3TXPikqhiLlbyuI00Aq1lt4wMHYR9LNMxtJsFEQQmnffJ2NIwbU+BYZFqDoYFmT13A0W/Nqql1kX9F0Biz78xr62Fv8IpH4P4zxRicyzJ30SC/X4hlmNEtNKZYoPB6Yv7Tj9wRP8XwtPauYhHNlW4Wh+Tb3pLtDM1nX9dzu+vwOzZRFJAx3NJCWht1ygWn
tJ0P/W2nnK1+o2PIYRk7dHxxhAWRV7iwpTU8JQGUImijCKK5mXrr5fJTCJkAAbxK1tBFrtg+pCaqGZy3TUdnTrhggh6gr9puTWDZMaVSpJK3sdeXqaeedl4cVnnrE8+axgv/Ga+tXXhLsDvHrk9PqBWttFxnre4CfOKMo56FosUmz97zLIqJw9+taM1rcMf96sxFlp9/DiiY/4q/rA5rpxEwv7OcHhiunlI/opHOazOOvIks17VhkJtMU3Dt9km527gRQY8ejD9Nz8ZM0DWFQIzbVtzLxcsCJmtdJMqbXRuqGlsWR6Z5TQNZKUJt75WIHJnFd2WVZcsvmle3y5Nt6mfR6bCzKxvM9S2rncyJb7tEAaJheimgta0Z+bcR7RJhhPgHcTbMv5vEyEuSNiC5Ig5ndslfuIYC/gm//4FfnpC3S6p+R3uEoXWMs2o6cMFCRuiM0B+I1EDlWQrke3lKmixzjrZ65L+x7QWkDlBC7kAVLBdiCzB2y9nOFm8v11LRDjTIgDISSsI2amgTYLpRRqdYRTVSmtrkiiX1DfbIN5+cMFK50TtIiJhuDG1KIu0hxJblydEtaE13hg9umt8P5mw315TTZgsycdG2WeGJJvPuVgtFQoYSSNN6TpRG1wP1duekTzTYzb+5nNRnxjaZmjFUduL3OK5ctsLUdnurhwf8wJ40aFe3xc7hAeL7rJ1vnZA4LZXIB4GSdwDnzqMiaXsdiDiWVwLvywS4S4Y4zemcl52LD8uwdZrWuiiYW3LKWCGjF4oCydBLnYri0vtsi6LNdk+csyrhfqx4yX/wYDBth+ec/Vx8/Zvv8u6d0n/gLjHo3Fx2EyUkxoFIQRQiIus7uXkJspVY1IJHY7obdKvICIJxA5CjOKNXdDWfYXp6m0dX9aOgenNnAqvWQ4N7JUUmy9u7yX1nXjQUUPNlKKbvyOcJqVuSpzc+P7RQjd7eEKBHONzpCRBq9fw9c/PfDiWz5Wd8OO2+uJg80cSqDVxDwX5qnSukWSqBFiIocNGp06oiXQtoWpmxQftpVjbpwCqHiy3AxyDGCJzlDEJFPVaKqdq9soIjzSOPZBNVdDm4BGRxlVaa13X9viCAuOkmVSCsQkPhJUWGze/NylJ+rVea4W3J6t2WojBp4AbYLPqSKuJGB9eVoRs+D873ByTcy4uSXJSIxHHss9Jl5haAGOFYaf2jB9+inH/zQzbiqtDSRXu2RXZ3dL0HP15Xc7vq8DsxgzQWDIYQ3MYnRNkhgSk07EmEgpdYTjYoXxRn+CuMVEiL4smsR1M6hE5hZJ0xFQ8hDR3O0yFrpK9EhdAuQcwRSLvlOEC0w+IMzFmKpSbOpm6wZS16ww5wwnV8tXDJqyQ7g+VrYvH9l92290++9fYb91B986ML++Q09tXSgvA6tFCHVZ8M8K9bLqPi3H8jy5+H6J7iwE+MUSaQnyrolMrzssfZzJ1zPjMGCzIfNn2MPMnRq5t5xPprxB103BX9tDm6WVWDpSIH0LCK7FTUIIpqsoLM0I6nIZAlh3ZdBSKaXnJk1d9qJY1zSjw5j6FrfncuFfOF4zb+uAVXwimp1lBxYkTI1V8BUuFgFhFX3lvOedPeDMg2PszA271Ida0LJrfNEccVHT59vIpr/haeronnXZhgXNMEdglrFqWfj0n7jFfuwZ+XVAnmeSKo/5yKbruR0Ftvtbryy2GYqSbEOOG1JpSOo6ecUX5HiB6PnHsLXEWc05ZWbFSzAmfhVti3G2GFsCM2vNN2yZSbESY2aZ2NoibRLK1KgtXpCCZYE+/Xa7dP3aLh8lYWYESWv5MQc/tyjOBzVxuQ3zfnze9AH2aid8sN/BEXaSeGgQLVGY2AzOV9WqnMpE3t+AKUM8S2fUPptuI3w6wUfRBXof1CAGrOpaor4ch8tSpcs4sbOW3tRcHuWuj8/r5XlyUVJfrgvOOVxI/0ticFnKXObO5eOXub2+Fuf7u6wLYrKWJJfXWhEzWb8huEdhXozazc
DCKgy+8Ol4i1HUE6bL63KB9i2PPPWHNIXt88z+K8/Yv/ec3e0tp2vnmG3SlhYaQib2oIxYkDQjSRjqqtGDSHTiuXppO1pyonwPwoM4QXypvKAuAdFal+NZLcoWq0BbP7OqEusBm7oo7DQT0ojY4DZ5MZGG3NcHWcf9ghyVqhzmmeM0M+0CZarsxo44YxQzmkWaBUwHUgpMk/LyzYlvfNPR3as0oDJxkJlPH4z5we97lETuc2OIkZR62V/dNYZtwGpg9u2Hw2PjuFGOyWj1DJB4IhVWeaEQMybuPTyXB1DhsTUerPG4kOyL044ww5pgoquXtPtldumdEIkpr4HugnS6jdKiT1iZS+2kfz8nVU8W1/HZj4onzQFP6NextYy5phz1DVvbIgaP5Q2bITOfjOt45Z7KwJQ2LsvTdlz94ff49rd+g+f/5RXT4E1CAMGct+cOvX9vx/d3YCZuWyQ2k6NPnhRc7VdN2I1bgiTGuCGlYfV0RMQFLMmYGEMeyDmTJHrE342aM12HpiRmJgz30qqtrJvwIBnNlWZuBpuHQKzCpMZ8cXWziww5pGojKoFsCdFAXmoeVnjcjeRp5s1UuR4r28Mb9q/hg7uJ9FvOfbHf/Az99IH22ZG5GSWCIVR1Ajh41rkgZeW8by25zHcNxN5mpvsmvfx73SA4L9wi8IiuQqdMbvB8FypBinOa1B9/7PyLKTjX7TIQ8qxY3woEu4KaE8vp+jrmfLPL0e1dN5PzkoIgc0C0rRD4icbcLdOC9UBVdQ2U1gCof56FgH+k28/IOShd7IREztzBlc/0uXIP/fUudaQuL/ElTe6x35ejgZLJvcjbT5ErPMMbNbKl8qWfEBqNV/+1P6YQqB2huDa474FSCYFZlXkpqf1zO+5+bEt+bJziA7fzFSE9ko6J2nxk7GNGsjLKDiEytUeSBHLbsOGRm9E/zKN2FfklwIwgLRBIzF1AMeJcHpVNL9F5qQHch3XTeZqlNeZaqarUWnu3l6Nm2sWeWxlpx0CU4H6dmkhAboG7+UgaeheeJYxIacoWv1cpRAz3tAWcZ1QTBMUkEGTo5SkllcYQ/Pz/q5vf4qeuv8L265myEfJj5XEI5JNQu+2MxMTObpFd5vHwQNs8gfaaVtyAHJwMHg3uK+wVklSyhpWAfL7XQhEXOY54s8eiOTdfjHmTRgrw0OBpNJ40+IQlkfEHPgYYmnQu1xlBu9x8Fo9ANThibDs8peYJ2OrMgK3lxpWj9jlRvstxXfCiRV4QbnFxYKDf3yORgAVh6HNyvHiNIs4BWtCyBW1b1oVlPs49ZLQtXP3QMzYvnnJzvSNdjZxCF+S1BwhCI2PBvMMwJidKIaTs1yvGDHjAOOBrhGlAcyB0xwhV5waSI0wzKSnTfGCqh94R3sne0jwglUAOkZRH0rihxIj2caMUmuVVHyyJkYqhTRmSMfX7fepBS6hwnI1Xp4knh5GbvfKoD+tFtxYprTI1RXlD3ijMiVpd0xJgfr3htIOHPPNwd+LxtZGHxJAi+9ED2V3YsQ1bxgFURuY5oTpxnCb2gyPJu1xImxMygDyw6n6Cdk6XJ3lNhTkM3HMkK9RaKL2E31jAD++QproftHNAAnOrDHlHTOcyZUqVrWSv7yTQKO7n2Y3oTUBSpE6RpsVF5WMgYQR9Oym2dtbD1L6Gzw36EsEmBVJVHkNla8omZGiVGIwSCmrdbQAjmcLrxnQLX/qTH/DffeMbPPsMNj2org1/Q4S6whq/8/H5Rp3f9fhP/pP/hD/xJ/4EH374ISLCX/2rf/Wtv5sZ/+q/+q/ywQcfsN1u+fmf/3n+7t/9u2895uXLl/zZP/tnubm54cmTJ/wz/8w/w8PDw/d6KmeRyTisBuRCXIVZxWAcBoZh46WQDmvW5pBpMUW7ZZPzBPzvpoKp93GF4N1+atatSvx91vcOkTEEtpvEuBHSCHkw8uCk6Msv8OyutMJcDxQ9orT1/dQiEiqihV0Qto8ntp/dcf0wkV
8e0V//FP31T5k+u6fcu++jW1rgtW87Iy4n841zkV9Y8rdKz+j5rrHE+T5+7o/fLbgAVuG+ZSBVg9KUUrugqMAh+GJbOgLty2lEF1aMnNGzpbtxi6NmC9q0ktgDpOxfS4n6LUNwrV208pKN4pvMikZdfIA1s4K3MvTFd+8yeFyRR/GNqtp5o1uEXS8BELPz+61f/TUWjhvAIImIF/wibuq8oHYb3CN0ViFR+aEP4frdSPy17jzgrRrr+5768wRAA3oD8y9dMf/SFQ8/svFSRTOkqnOxmmeo0WcOKUQSaR3fMUYvA/Ymi+Ve58EdLpzwv3xe7egBtNY/f+f5BjyoMrPVLH6uhbk6X8w5R4Zao9QTj6cHXj+84k3/ur9/zf3hDcfjI60VvG2joxuW1vdsTahl4TF5V6eXMroQbqmUUmhaele3myC31rwsFRO5CrkKv6mfwdMtaXsF2ojZO0AtBiduq5FScIRcjf3tE26evMPu5omjiOrXOJkgCV4Vn//WusROB0bceBsqjv5GfJ4+mgf2Ax6YzbqMCeFqqRBkYRxC32TOLhSyKNpffHnH3Hm8ui1cR23l/Jjv5VjHvl2M9wt0a0HpS1+LTKSTFyK5th6uL9i4/xfM6SQNWb1Cj9/lSzDGDE8/Gnnvo6c8ee8ZcnNL246Mo395YJ4ZcmaTRjZ5w5g3bPKW7bhlu71mu71mM+7ZjVfsN9fsxiu2myvGYc922JHTZv2KMcFinN05TK1Zt/hq69h+y/6nr5gxDgTxL+m8R/OKHCJeddE0oOIAQwrKNmbGlGkKh7lxOFVOoi5War6HNYxZKsdWOJYJCOQcMFHqVJmnxjw1plo4nk48PkzcPx44lcpcGybn/Szm1AV0E4HILgVuhsj1JnC9i1zvIvt9ZL+NDCPEYbmngpJWwr9JohIpqhxn5XGqPEyVh9I4zI3TVDlNlWlunoiJEEMgikvu+Lw0z6ZD95PO4tI84p6YEZ93LvECNF/PjMUn0/FdkfgWWibBK2v2uejHjG4i700QVXtTWJdYWhHRi9faaQEZKJvA4bFyv7vlh/6JK95EeGzCYxNe0vgkQCPw+PeImX3PiNnj4yM//dM/zT/9T//T/Ok//ae/4+//5r/5b/KX//Jf5t//9/99Pv74Y/7iX/yL/LE/9sf423/7b7PZeF32z/7ZP8s3vvEN/tpf+2uUUvhzf+7P8Rf+wl/gP/wP/8Pv9XQIoftJ9nBY8PqyBCEHJ97l4O3Hre+EiyeiivtmzqV7apkv0G1pEujbcjFFgnPLkEiU5Fw2PChYSqpNYS5Ki4oFiI21E0QUOiqLaaBqpFkiEbAepqs9EsjMNrEJiau7E8NnR7b1Gl4eqV/vwevdkTJ7ma30AMEXwwvNLc6BxXqtkE6u5IxyXR6fg9E+b0lnfC6SN4f3L5Xue+zlaFTPvsXe5rcUAaX1gMylQebW1tfeAE8Fxgwl+XuOE1SXBaI3/RCST2BCcBK6GoK3uy91nYX4WbuUxBo0sQSfF0GdnPktFVu5XssEXgLEBTVIbz377R8u3+cykF2m5YIEKG4l5cGUh6xDhE514La/z3OM9zaRJz8uHP67xnyCRTU9SEVCIDUl9et3z4b52Yn2fxz4zR/2bLi9KU56DV4qmFGC+Lador9WSsbGAtrRZTPt98oV/cfFwSr4htl6Jrrw7LQT7+nXygNxWQOFpTSjqitu644KjpbVOjPXmWNplKpoH6jBKklHNwGWHaqFFtyOxSxinSOjLVKLQo0eFJpbcaEN6d5HYdFw6Kuw9ADpaIpZY+yD9df1m9w/LVzvrpGHT4kxkOOIbFbfD1KO1GK0uZA2A+UUuXn2AalU7N7naySjw8x86giYwWwOGwc7z9Em57Jw7WhWNGFDJC0FvwBCZItSUIr5vM52Lu2BrzdLSfE84nyNWBt+Pje/l/GKLWWXfu1hdaVYHvPbJXWyPJY+XxakvTOvk0A15+QGFWqENAhWbG26ms0TvoowSw/oOK
PsS6NHHuD2B7e8+PEPefbeM4abkXoVKaNXPsDNxnNMxCFBEmKSrh8mazACjpgJHnC3Xp6s4uMn9wYUdAYaLbYVCDCrPeFvb3HMlsAsBAcKQnCtsqXshiXUjKqNaLGvPYqFRKMS+saRRUjmDTEShBiTI1XiRPjlBrk1lfPchpTZ7QcOryfXFVu4qNo4FeOkM6djQVsiaU8QZQkenfqTc8IssWcgcETHgPr2zW67YRweyaHxACQL3nRj7piycAArwqEK2gqpVGrxzsvJ4LRodwaBrjOo6jI2zTxpMbPV9HuTlixw2ed9jIo1FpKrrcReQ5t3eYmELndyVse7bFJZ9gRP0G0NurQPOFHn9MXOSfGSv6yPKzQSiROFcZjg9Z78la9w9ed+k0//T28AeF8CA8oJeHx7x/htj+85MPuFX/gFfuEXfuG7/s3M+Lf/7X+bf+Vf+Vf4k3/yTwLwH/wH/wHvvfcef/Wv/lV+8Rd/kb/zd/4Ov/zLv8zf/Jt/k5/92Z8F4N/5d/4d/vgf/+P8W//Wv8WHH374vZ7SF8cXxxfHF8cXxxfHF8cXx/9fHH9fOWa//uu/zieffMLP//zPr7+7vb3l537u5/jrf/2v84u/+Iv89b/+13ny5MkalAH8/M//PCEE/sbf+Bv8qT/1p77jdadpYprOjaZ3dy69INa9EuVCYLKLAAUL7tsXImhEFELPonQ6UXtL7SnU3hiQ0NjQWdFesw5x6RbrXTNdLFBCJi2GyOLIWsJlCSQWagSJcCogS9NPx/RV3Ruw6YGihaAbQu+F1yqYOER9E4ztsSB394QWqJ/eo/e9jX86I081OL+o2pkfRf/bipYZnej8vR3fraxx+avPAU79Wlxkyx01czBzIarKipYZi37YyEDjtiNTL26uuXr+HK72aEyEUmiPD8TTPWhZYciQAyGPbmzbgK6PI9ZWjSzVM6K4lFaWsuPlZ1zOuV48buHTrY85J2Qruva7XatLhOx8Fc5/XxDGFgRFUDU26kgZeBmzJrhO8M5PN+KcmL9ppBRh7DynKbLJBZGBUzDqQ2P+kRPtn3vKbz4bOb3uUiWyYd8CY0gMEknm3awhntWzkzlFIOAEoUZzKRFcRXtYOB+pMSvYTNdUcvxLq60SF9QGkqEF6uxIw+IZr3puvnDwqqMUta7fp6qUDgHnMENISO9bbbqI6TbKXIlpaRIAs0gtoEX9qyMKcUHfurbD2osq4lm0NE5zY9sh0oM88Cvjt/hHbq9JLwdsbt69NY7k3uEQEBgCsUEYMzpfEaIyPP8S++m/AeDxNENx5fWXzU3nWy9TCmcE6LKjdhF0beY40bJQe4nTmSoReJiNbbhAbxf0uo/H2hHgVZsRWeVZVnRtuQ8XY3N1I+9H5Fx6/+2OVSx6QbguGgqmxYxaIiaFOQQqgl2DPME1sTq5vM2QD45ae5lV8GIZbAfY7fs5/UDkyccvuP7oBfn2irBNyEaQLCvMnWMkjQNxSMTcS3bJux4lh7UpKca8ynioKk2MWZzqsvAoRQJqJ2KLhOT8phA6yqZK6xIQi4zLippJr5stXyyoGkTNmDmVxUQgekfior2FBiRG9lvhxTuZD59tuR2VqyRs03KPGzHD4ygMObDfjuy2hZQqrVbnUwKnZgylMFf39nTxWK8ELXuFqhPnx5RJYcNOEkpjnisPfe5uc2I7DuRUCOksQ2Sdr7H4utbm60cpjdi8KWhqvSy/Oj4YMYh72aaIxODm79H1yhZNtGEYiNHlmkJwr1VjQeqWtf7cFGDmVJqF2iSf4+68Rc0R3x8vN7TmJ4f292pddDJKcgS+X7EWE1KObEOD4QoLleNne977vT/Emz/uFK7X/5cTqQqvk3ET0qUu1G97/H0NzD755BMA3nvvvbd+/957761/++STT3j33XffPomUePbs2fqYzx//xr/xb/Cv/Wv/2nf5i3a+l09cYC2PqBqmDq96K7TQ+ee02dDidf0p+ESsZSJn53/IokejQC
dyBpauLodUpZcfoyiETNJAMUFRNyPuvJF5mV/SJRUaWGs0LdRqDBKx3Ee8FfQUiCGzmWeuHirhfibMB/S3XtNOvkudcALsRG/x1rMoabkYfG8RnjiXnBZ9rcvjLS7v50qanz8uS5pvvd1FWcbLdH5XZhb1NC9rZYSK6964bUzlHeD5Bx8AcPOlj2hPn6Hba8btBmqF168YXr1kfrhjbQ8KvaNGlaaza91I6cLC/bq3M9/FN/LzuX7H57JzB6pyDqyWz/T5oPN3uNRvHSurQM5E5svr3RCCBEJr7IEUYdMD/xQbV0/h+Y9DeA75bxq6A7O23uv0pFAL2NSQx8bjH4X5z32Jr4cb7l89MNuuv+kR0W0vAYCpUudGpJFy/5TJxUArfXPpJTVCNyeXrlGks9sniQe/c/PGCGugupRr3FrJ+sVt1ZhroUlwDaleljATWhO0iZf5m1MCjtOFbEIoSHTtszhY5+R4mdXMzhpF1cAW6oL4phCFmMJaTjXxu7V0aQqBahWh9gZPv7A7jP969wn/yIufIXx9T2sHrFZcy9DfsJSJMGy8RKvCdndNazNsE+PtUwBOp1dQIjE1jsBDgGs9j5eFVhB6C3/rPMy2BG6s8SyzeFPBNX4rTrM3Fwz0AO+ig2bhei0CrNLn3bItvCXJIucxv1R6v9v0Nzn//jJQC7z9HFmyjv6YPmPJKKMYNYDWiDyD4QcUnUE872R7HylTQ2eIR4jFCOJ+4+M7sH/eOwhfPGH33g55Eqi3A+xuXApDGjF3cU8yMQdyvvgeo3fTp8jQieopDQjZ1xJTqnh3X4u9XNavqVhEoqw8Y2TGKL0MvwRmHqyEYLhock/quWgQMEUuOH5mztYN2ghRsF6eazGyGQLvvDvylfcyHz1L3G6UbTDS0qlrjSFE2jbxuB+42Q9shxHiibkKjye/K2+OFRtOHGmoOX3HrL0VRPo1U6IExpRdX1AmTB8w63dRjgypsRmNcYT50Tp5WrHay4g45xPrTT3mTV+n2pvR1gRCnFpAIA55LS2nNCJEF1wHNnlAQsNa8UQQN5kv6sEZOLd5kcoBekDmZWT/97Jn+9hcrMpWKs5iqdjPfeHOuhOEAdFpFxrWkvSogxuy656kG3Sv0A6EV0/40v/iYwD+q89+hZv/FPZm3F128fwOx/dFV+a//C//y/yL/+K/uP58d3fHl7/8ZWCpDUdvucUDM/eeVObqAnKxL2+th/KtCtoCtVZOUWhNKWFmk0Pf3HtwZ4qrgldcp9qhMLPmkBgQQyZGJVtmK4GTFs+KBvV2+X7XW9B18SrNKBUwZUgwdrRPZOZEZCQzTCfSqTCejPhqpn16WNGwKXRyP7iIIWd5jGVgfZ7/ZGaf+/nta3wZbH23Y1ms4cwli31gLy+1kNtdyLIrpNl5AviTjSKBaplgjb01rkLj3S+9Q/pRH8j6Qz8G2ycQIm3MRAL56Qew/Qbx/g576IhpO9HqRJuPiAaEhkpjEa4EMNEz0oW85Wf3+WNpjCjGKmbp3Wjn7G4lUss5aPj8KwrfGfgu3fPL70OX3XDivRF6V6WPB7c5Adi923jxo/D0Q2j/DyhToF35WN8fupYWSn2j3N805j//Po//qx/ms29NvLz7NhoSpfYsXaJ3QUkgBnVCK+Z2ND1ICVpJQZibayaVhttWhYEUhdh6x9UkwAEJUIqtAo7WWHkgzu3wtoIQknMgi+t6p5BJy3U1ccKtQtNAq7JyY9axlaCGxWHDoWcX1GQl8IOPf8/eA1ikVaWUhkpceaEheGNQWO5WMG9uUPfNW8bIUwY+ubrj+C5sdldwnEgp0cwIvaOUudFqQQi0Q2F3dUXcPKE9viK8+Mjv08M95VF40MYmePft7mJOrFNDz+Nj6aIU9aBsQdWO4mM0EThFc2QHt+kSaW9xSqs426eYrweBsxyG/92vwWXnsHxu3EJH7vrfPj/Wl5+Xpp1lri9o5NInunBfK4E9jYISgyDvCvpRxI4V64GZ3R
lWIHTWvx284hCvA9sfGBie+9q7vdox3F7B/objdkscvcqR43kNlxhJKRGzy1Hk5FpYbjEUibJIYaSuTya+K8foyWQDsQv1Ka3ORev7RDClNif+F10CknOwE8T6+7j90sKZUlVadX5oaxGLjvbE/n3haUlO3NxmvvxO5ONn8GKj7IfAEDoRHucwjkGR7ZbDdeNb+we2uRCjO9s8PPp5vT4KMiqS+4CwgsUBiRDzgoT7fDBRihaEmVInSpkopQdmvaHFpTUqRXwfarWuDTYAtRoyGCaRuS0Cyo6kruO8c8yQ3kyX3NM2hkgKLnECzk0UCbQY1kFXzaNalfNEEnW+8QLHe2L52yAMyxg21mu5vPaaeCtIaGtMoRpoFsgLQq8wDNcudJtPpPkaHWYe5RHp6/MP/m9/il87/Aryt0GHi8nwOxx/XwOz999/H4BvfvObfNDRj+Xnn/mZn1kf861vfeut59Vaefny5fr8zx9Lh83nD7UKDGjPAADEPIQyWL26VCpIXgmolcBJhbkaj48TdJSrDI0hJlLHrot4mSfEQmzR225xFC70m53C4JY29O7LFr2MinX7Eb/TVaD0zWeuE8cESmCoYSWXpmDUXHg2Hdk+TOSTcvUA9vU7poezZ2PrmW/pwVbDM2mP6T93jS7+/T234F4cazxzGWR9/jHL17KA96Bta3TJzgVBVLLNbARur/bc7PdsvvwT2Ls/5Of8/IdIT94h5uSirw1kqwTLxO1rbPPaX+z4AIcHICN6wOYTGioWK9bJ7KYTWhc0wAOrePEZPt/gcImsLfJY9RLe7p9P+7WQ73It3ur6tIvrwQUKYedqUZYuRpx9A95t4en7vrg9+8HA03eN9mu+AUsuhBPYbsvUd7LTtwP3v0+Z//l/mPlH3ufwm6/5TGa2V094ePwE68KQ+jBg4tpvSRcrEsOKe7j6/QnMG+9UnlVQywQZoXc862IWXpQYT9Sk2Mlb412k0dbAzILv/C7X0LubTEmSXKNoDczsQpsoYCthup1RHfESUZVCqyeSRUKLRHP9v7XJhoBp9DWgOVLeKkh0v1vwwCyI9UaEXjYKLvUR7Fzy3OQtZTjwm7cP/MTNNfbqU5IktBSHNYHcEm5IneAwI2bEvKHlBMMtAMP+Ke30kql5QFYNDuJNGpfjqsK5K9O6HIlfEVK/DofWmy5M3UEAODZlE+LZC9CH16oFVpfr15GK9rn3XSo4y37mXXKwCPK2bnO1Nrrx2yDOvI2iaU9pTC7mv4FJcKP0UCkfDjx8GMlvhNa9U8u1MT54YK6zMD84KhP3kXq9p+782h+fXiNPbkm7PSG6fU7KQo4Za4uxYyIGt7ob0tBLYtHHX4qILjpZGSygIr2RrBKty14sKvymFBJKgjD08qdfWLXaN++OmFnFLK16eguAsPi1NvV9r1CoWWjRy6smbkUUunr+sE08vx346Fb4YAs32VF0YvQmJyCGgUyj5cr1HjbbO662M3kQTo9w7KWzh5LZM7ANMG5OtAIxK2EQj3wAElQaUztRqlc3tLbeheyPEQskCyRx/0wRuoG5EmtbkcPUIGifX51KIwSn1YRl7IWOaGXGIZFzZEijJwwhn5uFulm5lYhw0WQRAm25DhrBzEedeTrh793Lw5djtY/3NSExum5af0DoMiZhmR+2lrkNOZeaU6KIcRMr2ra0obGtIyc1SkcYr+qeL//jP8V//1d+g/1/+/idE+e7HH9fA7OPP/6Y999/n//oP/qP1kDs7u6Ov/E3/gb/7D/7zwLwB/7AH+D169f8Z//Zf8bv//2/H4D/+D/+j1FVfu7nfu57ej+v0y/9YG8fIkKtlcPhgFGJYe8q4cBjKRxOE9NUuLt7ICcXh9XWaENeA6WcIEnAgrrSeM9j3BTZ72iOsYtXehCXa+5cNukTtWfpgHrdzmU3LKAWHD1bhTZPDClxe6fs7x/YHDJXd436zYd1YfPPbW+V1dxYe1nU+0Bee6o4+9zh5QyBtwbpd6
Tu3+W49Nu77GpZ/ubnz+rRB+byCyIOpSyboka2UrgOcHWV2X35I9IPfMzp3fcZnz/3199l2iYQNnuiiZP12oTd3gJC0PP5anNeFNXQWTEmkLyWmkUqJu0tvtf6GT53Hb5DzoIemF18vuWzrp1D3yUZ+3xJFzsHcgvMrhd/X0pTV2MkZePpDyjv/rg/8Poa9BvG4Ri4H5Uk3q2qr448dG3M8M9/gPyvfx97jTx+8pqYlHengZc8kPPIp4++GGyC8FhngqZu7NvHpgjW58axFax0FCEEYhgIKRJiBTuhzf9WSyDlI9omCArm7g1UUBYEuCF5Q2DZhDIxJUhKNrvg7vmiJzEgGgg1dnFYY1h0ZqRSq4EVYkvOHYU+985dcDHmTl+IaINV8jucuShmhvTSrAccRmuVkEbQaVVUH+IVZvd8Lb/mJ27ew6InfUt3M5zRpJSdx1JKYbRAHLZo84U5jE/Qzbc9oG6u1n+PdQcMofZAsBpsehnzKI5kTeKL9Nh3soRSu8jmqI6+zUA2O3PmOCPE0VyLT+zME1vG3pJALKjWdyC/fZII37mRfb4gI7BKbihcdOp62akbRriuYBoQZjTC4xPh7jZz1eLKdGo3xnQwymHCJmEalHvgfoTTJjFsfeBvNzu20biVmSeMbFSImp1v2Hc2CZkYMzkO5DSSY14DMwkB66VML00FQl+yNQnRglNj+txwLphfMZFl3e+eySoXXZkepLVuA+PVCu0BSEfCJDh3qbUeYCkiyqyBbLJyOfe7kWdXA883xlUSBw1CgRBWLmeImWiVHArjGMibLfvtkTEH3giUxUatBZS+V6XqHLkcsXRex6o0CoFTOYEZW1PadETrEVv2KT1hMiPMXhIMAUUpTYlVV/UD1W4tGCLgPLKFVrR0lCPiwWgYyMn5a8MQPckT1gWztYaESOg8MukUFoW1Jh/MKRoSe0AsXdamSyJdrvlRhNIrOt8NUFs6tcPFz0EElS7q0nX5kkxA4Tg8Z2iKtJljGNmktNpqnSRwbYmv/JMf8F//n38V/ovvfL/PH99zYPbw8MCv/uqvrj//+q//Ov/5f/6f8+zZMz766CP+hX/hX+Bf/9f/dX70R390lcv48MMP+cf+sX8MgJ/8yZ/kH/1H/1H+/J//8/y7/+6/SymFX/qlX+IXf/EXv/eOzKAYhcZwXjFMKUGINaBWkRSQnN0YNjnq5sNIeXM88GqKhNMD+w1c2Yia0LmSHNVJtqFFBjkQh4x2YmLovANrisiGE4+0rkJ/MqWGCrkHJgAR6mRo8MCqliM7BiQeCdVv4BgjNw+wr0eu7xv5FWymiftv2irqCufGX+NMGA66oGW6/k363xYRS/+y9XlwLk8ssU7gXIaQi99BL3304GQZOPHiuRm30LhDuWLhxXib8L6f1xOUzQjb58+IH/wo4Ud+nPbhO8TtO1jyZ4Ts1lT6eKAWJZugp4KFjOYBXcWERyIH55GFTM0bcohMHJlbWk8whSOll7GXcnLAM6JygZQtiIAuO1X/fJcNAkvAJnL2s/w8etATRL8m6mK1AwYk5n7lI7giugobjG2Da2uMH8DTHw5cBT+B+bVxBB4Hpc3AATTC6Y+/i/6p/wkAw498xHBXuJ/vySlwqo1KJM0b4uHArpcyTkXYDIm5VKacqTohppge17scSkDngTSODOOezbghmTJEwyS7gjcQBmXXMrVMpAz10JjmHoP3qHSThLk2b/EXIRccfTAl5Lhm1oiXH6IF0IBKRHMkBKPZ0go/EpMgCskapoVGQLMwVlmD5ySBISbEPFvu3tBOzrdFryauljfLgh1jZDtXUtywn10pfQ6FvW35b8Zv8Eff+UHi5gl6/ARJmbl3JbRkDMWQo6JXO8I8gRktD2gX9kz7GzbHPZvHxx7sGxm4w5s7tl2aZCln1whDgxNC6vOt9dL82AfcI9Klapwgf2/qaGbfDAo+f5fKsnQEeLH8WsbzkjQkIEjo+o3Nn7skN7it06KPJ5
zLlm+hwwrWz31ZU5Yy064/7h54r8wYkXFjfMOMl3ULu8pd13VI0bDpSNmPmFXut3C4gnpfuH2YeW/j67g8zuTtRB03nIJCFratISdBN53LpdXnUjCwRjCXTzJx78y4lKRVfAyKq9eLVtSU1uZVyR7Hz/wcg1KopGwcSqHUidabxpYGlhR9Aw8hkGLG/WK7WK2Ezmcr1Jo5VYWkxGJoDOv6dRMq+zEwxkxIJ0jBz48uiA7MqmgCGSLxULmWHcP2gWEn7lI7+Htej4nrIbLdCmm8xbRgOZAG0K4RU0V5nA8oG6yCpUgprpepfWIXbczFaBoYgrLJylzALEGr3QcXtA3IHMk5EizjltPm5PmlRJkGcoiEATabjTdliBCyEINgS5OdBsq0cASFViqBSAyCdkHrYhOhlzpDEqR1FFkiEtLK7xN8XKZOlZiCz4mBzivvc6Z2BC2IA+RNGyKRaIb0YL2R2MQBKRMmCW0ZrHBqSlhsmwDTkayBn/xf/hR/87/4FX6343sOzP7W3/pb/JE/8kfWnxfu1z/1T/1T/Hv/3r/Hv/Qv/Us8Pj7yF/7CX+D169f8oT/0h/jlX/7lVcMM4K/8lb/CL/3SL/FH/+gfJYTAn/kzf4a//Jf/8vd6KqCeadda1w/iNgxKMFnLDVHBtHUfP7DgAnVZgNA4nuiCmIViR9csAcJgbFImlMJMJQ2Zq2FkMw6UnuVarOQoBB0JcuhE4y1WI43ZU2R8Uxg3XlNLBqllxCIFYe6bcEyBPUc2p8Z4NzE+BKZPiwvIcu64hHPHIHzOcqX/bov07NXWx6yX7eK5b3HEOAd/AVYPRKBvnGdOylL+uAwYW6dNDwQ0KLMZowlPMHKvRMfbPcOLD4nvfwwffoz84Mfw4l3iuGHx8tHTBFPFpiNSFQ0BrRVp3uwRhp7l5gQxdRsth8QNQ/KILJu+FiQ1RE9OiOYcaH4+oFqu4yqUyblLc7k+K8eMM4rw+WMJ/AAeBcSEORhGdV8/HCKv6mHyDfDkKTz/Abj+GFJSjuqf8THgXn+d1tj+0I75j/1e+IkfY9u5Y/Lpa47R5Xqll+iMs+XNgtJFnHMiNKIUYof4gwiLzU2MkSYTMQ6QjSiJKEZIDSyRelA8pD0tKZtkPIYH5/TImeMFoEE9Mek6c4rbo0iAqraWMgkehIk6rDNCN4qvZ3snhdpFcdVsTULEApI7Nwi8HhE9QAnmDRXxkqzdP2MKqQfP0rNmwZIRTCi9jNR61vxGHvnm7o53djvaY+/265NjkAFiYzalhkhKkVIKcZuwvipZmGnba0J4JBu8sXM5/Si2EvvhohsML70vwZP0z2dmnIADxqA+1k4LEmu2jtUqEJefl/KlsN6HZTwvgdYc4aYZTdxbdWysKuhzdJ/WNVO7GPSLb21/+bWRYZk/y7FSBXpZc4tyiMZhG9kp1Dqv603MiTq6Nl1LjTBAGMGOcJwnHo6d+rEZSdOG7bYwJSG1LSU0kOakRCBkv7cBIXWnmCgOhYhdqIwKvcx4XvSkk9KXEqV1oWLVjp4tXZfYW+T/RWTWDds7QitCk+K0GhaU1vXQTAWxhFhktsIYA8Pg93u7GdiOG3IOhO5mYDE5arPQCkJAAsxy6r65kEYf8ykFNh11HsbIJg8MKRGjodlRN5GzzmBTReeKhYlanBNqZeIwn7g7ed3m7nDkMBtFZRXKrgqhecONXei5BXMcdE7R564ENMraDRvEddO2ITtPKznXzE3LY3cK8SB/FcQOmWEYQAzRxmJ0ZCGgrRBjPFt+NX9yEFmrtYuLxQoVX4zRJdFYLO18HPRxGSOBgEgiLVZx0mhUxCLJArmXTFW8PA1wpZkYIk0ym3yOg36n43sOzP7wH/7Db3VwfP4QEf7SX/pL/KW/9Jd+28c8e/bs/y0x2c8f3gGmiDask/GaNqo2v3jDSJTIGBOVTF+XyZuRzayw23OyyH1r1HnmePIyTIjnGjmiDM0IpXKSR9J+JMfEsLreKxICgw
QGtm4RAVQRSou0pRUUJZfMSRpNgqsjEwlVuZ78Me8U5ctp4MM58fyQuDmc4C6xp7LlXFJbMl/nP/lNdHLvec2MncS8kiyXa9YHnFwsoOXia+lI/Dx/ag3K+nssiJxcdAW0/pwBb2suAa40sttVwove4/7+R4QPfw/y/o+i73+AvPc+4fo5LRTokiimijwekfsDobi1RkOJkryM2TNFiZEanPfUzBeVKMlPMi+IZoXYsNCQdrY68o8oLp3A24K7S2C2GopfXHfrm5xxDtC+GxS+cG3mAMmM1Cu6l8G19drmfgvXvwe2z0C38JlCXsyOtXe0/Xjm4fd/zPwTP05OO4a7T6k98RjySK5GESN24UvEO660l17AgwltRp2VoxWiLN2Ndkb4cibmDcEiWQayJGIQkjSItpbwNhqgBcoMQ27E6CjTErgCiDoBRbQHPn0hF6AWI+UuVUAkhqVraikRGKJOcgYoakxWKVqRYCzu8AGff2s1XtQRuL7hRGAMzodZ+CopxpX349IOhsSIilGL4j2OfsQQOOiJv3P1Cf/z63dp38594/dPOXS5hJmCivPbpFXQkaWDtdY7ar5ie3uLvnrjgUm/Dm65ZH08djTLHDWjsXrK9tzNnTTEESwzn/fH/rNwFsi8dPtQFr6ZW4lddibXPn3H/vq5d7iKsCYRNOEogvfEeofkd+OrCp4EL4nfQq+4nB5q/plvCNS98foqsG1KEV0DZ62NYZNpVmgRxMX2CRHKaaIcfXObdxPzpNTqhvCFkYx7oS4cLe2dh8AqCutdecG7YC/QsCVINz0HZK01tPVNuCq6Ek17ObNTQy4FZlfvzNZWzpIjs3rRWObd9L0YiYgHjjJEhiExjv7EPIQ1sZCQIAoWvRy4zNkQgisAqDGp0QK9mgQpZMY+zzbZSevbNJJyoI4bN183WYPPqTV0dku7aaokOVIOE4eHiTcPviY9TsZUcckaA1zj2+OC7iri18toQbEgJBLSYA7m1946EgbUFJFxICWnF6Q+DoIGFifpgDuUWDMsivMDpSceoXfWijFN6t2atVFKwWojWnAKwkogc59cvcwwnI2xdncve+SSZJvhXbaI348l6Mr9fKNRI/43Akni6j9qGUIQApmQP88C/+7H/yd88C+OL44vji+OL44vji+OL44vjr+Px/eFXMZvd9SizK2XQBZhuTZjCeeVmWAhYZIZNlt6KdpJj819LivCYI3HQ2CeZ07HyrgkzAonnSljr47YDDyQW0Q6H0osEy1SKAwhE8NItNn982RkUQ0yOSEISRsjcGOJJxJ4asKH1aPo9yTylekpT07KrgiJB8reiO8eSLKoEUFJhiYvlWhV4twI6sbh0iP5SnhLn8btKtzTTQ1KzwJLU1L1TDeac6KcbxXQcH7ukn0vnV4Rh3brhZUS9BIKFTPYEdjeVtLzW8KHP+Cv9aWvYO9/RHvnA3jnBdxewRCx4x30zkChYPMj9nCHTJWwdMoOA+aaCv5aQEgZyxkpGWzm1EoXiOxZbsgdVUuoFKKdkb230DN5++ele21pBPB7eCZPvyUR+NsAyIHlehrSQEhM/ZmzwkaEQWD3wUB6d6YBj0c4PolMT3rG/NELpo/e4zhc84ByfPkNBok83T0npc5hWJorZF67kMzaygmR1V7JT7Rqw2Yvr0TxjuTQIcMYI0P0LrackpOmOweLntUDxJqQDPNQ2QyVMReOodArl37tO/S/XA3T0M/JPJvuY9WtccQ5YFawaN39syAs5OyGJUE66iO148Od1xku7l4IzmZ2WYNKSr3LsyOtC3LSE2CnP5h4uU6N1vkvm9owgUGVvzN+gz909YI47rB2oFeRCUQGgWQFMUcCajmRbCBtnKTepkCbMrurHY+v3jD2sbYgrkv3/NDH5SKCjPhnjR0NW46TOB9mJHjXMj5Wsxd/1rFacNpEW8fo23IZCyq+lEsb3lyQbGls8eswm/La1BsRzLuIUye1v4U3dYRMYbVlWn63HC
ou9WNAvUrcbSKhFOdH9fvT6ryWl2NOpGpuMJ+N6dgop06qPs6k6cBUE1vAwoTERIgBWySIQgK0bw9Ll6m54K7IOj69bJk6NWM54UDDO/sBrLXuf8r6nCCCdMrAWehU13KmXAighq6dtvygUVAxQhLSEEljYrcZ2IyZ7eiP2yRAXGajSiQP0pUB2ko/aE2ZtPJYCsc6M7VKURfZDuECgQpOjh9jYjMMyC5xPx051BPzIjdV1Wkj1pjmCTgyPRw5PirHDvdP6kKxtXMJJfhnMxWCGbpQUkqFGLxjM7pESV1qLb10miQwqpBDIESIwVFuVaMhFwiWIdYwC9CgzM3XOdG1pKtUqgrzfHSJnNmwZkQL5JC8hL3eVdZ5Ikrnfp/pO8G85yMu5X/t908cMcsLhXlM3YQ+ksUbNxYK0bK+5aQglRgSQ/xOdYnvdnxfB2bTocIWorlWEUASZegkf4oTyRky47hn6OXOea4kEiG5vswYjZiUN3eFqUDpNUPFVs/LMUK1itZHrAVK12ZvRIbqEOUQjBQDwh6zDXOaVg0Va4E2FHYaeNbgA828L1vel5EP1SPBZy2wmW7I5YTaSBk28FEmoh4RLDYC+ZacNr5wVkMOJ2Q6Ues9rXo5cBO83GJ90RB1TkOsDTFj6ETV1lyZ+TQ3alGawVzVtd2W8olfSg/OLhZza20tbQIEAhVfwPcG10Ni92LA3vsK8Us/5A/60se05+/C02vsyS3tak8zIx8j1rvY5HjAHh7Q1y8Jk2GbDWEYaKouAth1csTEV/7o3VEWIrXO1FqYVxVu0KXfy85ddH37+o4xtQRei+bZd3Rlci4jfzevQbjgJHDeaDOgkllVjgMQArEqr4eZKHDawcP7AX74A/KzHwXg1Tjw6fxNeLiHmogEZJcoNlGmvqXHhOUdQdvaEaZWsK5AvpQpWimEkLvqvnaCtxCTEdcuI9cHsr7TxhAYUmRIAhJXD804dDXxMrMdTmwGF1oUYxW+9MCikdRFmQ2hVr8pWhs2DP0uBBLel5664POYjNTCGmBLE7CuUbQEojVACCTiWqZo6hpl3mUnNHMz7PFC6wzwZK53Ny41eyVRqYQePCeUmcxeM18dXvPt2yMfjDvK4cRx5eSJ81OrEkuDmAmhQatYX701RK7SDuwRS4Ghb/Q1+qZwbsRxsv+ErV2Nte9huYdAYrbqgk123uSKeEl2LcmLPzcbq9CswVuelyy/w7s/yxJA9Dn+pnOavhlhqv67BN1Fw9ZgjItxPnAOzM4dmj2b64+v5t2Zp03g0dzTGBPXIMT9iyebiTkQZPSNT5UoimnhsY/7OAU28yNiW8a0YbctXI2BIQzefAVoEEL0kr2K96qLKtr9HReSnJhLYruxu/u2zrVSi66NQ1QXULXWnUUW8dj+Sdd5thqZ63pfvVSZiL146dp+gZADaSOMO2G3j+w3G67GLQsl+3rnHCwVpRIxqqcqUteAcZqVwzxxfzxxKpVSpvO5CSuBfimcphDY55Hdds8wJPQI951GcpgLeqokBJ0rk83MxTVBD325ORYo9UL8uX8PvZNk1eMtCoMSLKIFWhfMCyYe5eDi8mOMEIQY+jXplAtb20xgSSdSFCQYc3H3Agm2Nl203nRwmg60AtLCSre55JiG0Pz1+7AUzmXLRXLG+udLATR3zrVEYkoMw8DQS5I5OsCTSaTghCJRQcKZOkH2ZHAcBzZDF/v+XY7v68DMjoYWJ1Uvk2K5FrUpQ4SYBoZxy3a7W+u7s8zknJEsDNmQoVDjzEylPU5MdUEaOs9CYT/BEAyNMyU+MiwkvhDQmLmWyA1O8ptTpTLzEBIqPuCbzVy3xjMSP8iGL4ctH7ZrnpctV1MPJKtgkmlhi13tsdAImwR542rquSMM2yvfeDTC3AhTgeM9PLzx70Crj2gtzncxX4zcn0hBjc1pYWj7IpSG2W1wWiWWwlQasZz5PbYgTdIpXnrmoK2cIvGqvWlgm5Td8wE+/DL6pR9Bfu
Ajv2fPP4CrKyxvnDxSA8ldxrGHPvNfP8LL14S7hz5hGqqFYDvvi2rn7G6xArGe52qITE2Z+0Kj2mhENHiDwKJ317e081hakIP+t2pnMdkVMeNzDQB2RgbeHpidiyfiqAXeWSkyreylIDA39Qzyx+Czf2jP8fl7hNsfpGWQDu/Op5eMYc+dPpB5JFiiTMYxGSF0vQxmxF6S5ZqqjdYqqhWkeebXr9fcJpKekTMPsq03dMj6O4sj21Y7edlNnIfggZMt/L6xAspURsY8MObEGHF+Sx9aEoyoAcRtTVaVk1UPqAddna+DNSxI58h1pGyRBAj+wr4hajeXD5AioZ77/SUIZs2RK3XNtqk1hssAtTkXK8RFOkZQc726ORi5b9YtGIXAjsB9eOBXnr7ig90eu1f/XABDoEWhToqVQtwPRHGNqbqKSw8Mceb+oSH7PeP9I6duyRPs7A5iakBAoyOsC//zbUq6j7sZN0Re/rbwIZd8PNCDui5XsfC+FuQb3h7/qSPEfYnh0ZzrCPBGfZwPdubBhf4CZ5V1/z4t79FjZ5eYsfXew7lh6FvXwqEYO7xhpS5IC4LkwCZ4whVGqFtls1OCwcvXfQM+CqU+MoYn3G7gyb5xvQlEhGnpatZKTtpFQiu1zljMEAJiaRUmF1FMI2rC3BrTXDjNE1OZab11W6pitaKtUqonLyEE7+5WvQgQKq0j9wvxP4SASeRSSy8NkXEM7Daw3waut5mrceBmt2Hc+HuOo6PJLbhhvWrD2gmlcuqdwcdD5WGqPNSZQ6s0q4QQSVmAwjR3pLwOCJExDezylifjjiGOHAXuZr8Ox9moM2QqViqn5h2XZWZFzGb1QRkS7tgg5/sfYI1uWjNC8+9ZOkW/ZwbS9+KYkt97GhA7x88lMVTP40bUENQ7vAOds6poa8xz75BulZhklb+R5or/qxBvT94WmSORHrjXc0VoQejCUh1J/nFidPQ1pcQ4JMbR16crGbt0ht9fUW8sSQJj7zaQYWAzDtxst4z/YwjMwkODJ746LJrXTSLV1PW8kkfg22FkTPkilYdoiZBhGlx/jCg9y2vULi9Qq8O1zWAyfLFskOvMIlN9a5F3wsgLrnkyb1xtOZxoEvg2aV3pLAofHEe+1DI/0va8r3v20xYOA9TelyWRlkd0iORnO9KtwSYg6QYxg+2NP+56D02QU0OLwnyC6UB+/Yr05g4APX6NWApWjwStmM5INS9bLVZV+KZtAYacGZqyUaWUQmvGqS+UtTROpxNTaavF0TKQfXv2o4MQ7ALcPHtC+sEvYV/+Cumjj2nvuEaZ7m/IsoGq1PtXxPs7QhHaq2+jLz8FIL38FHn9KXp6xCRgWpG6QaRngf28TFsvSyzlV6FacMmTPqFbbTRTlOTlzB7wLJnUQsZdSjlL0HXZXPFWV6aA9zX2ReMiOL0s6ziJ2+g64cxibC78IUPrkPpPw+Mf/TL2pWfolNlzT7bGOPi9fmNbZhIbG71lXB7QeETtGZseQGhsFM1Ym2itd2JKI/XEVNZaknbtH98sqi7+suKIFL3lPJ18A9NGMD1nmyIrkXro2ec4HxlyZjNktpuETnX1yotmpODltaL+e41ex5t1Zlj1+6JLNSAkiWhovdbnPoHg5Umx2gOwymydDN0ConH1gqQXp8RwQUzFSyl1Xs9LNdIQAp6pe4dbz6BjXDs0fKuoTAiDRv7L3Wf8T9+5Jb1KDF3PTQKQxT39qoFVmmbAxaoByvYWnb/JfGjEzZ4hDsyvPyPoOcAHD7YqbtVzGey3CzSr9t7rAtwLbHFUrPUkYdkhrUuXT51MsTQA2Pkh3mndh/BmGcDqXZMvxYMzv6JnhFjxJyzI+ep3yvn72oi0IMwXSc/SDLAFvnYNxxk2a1edv1iNgRc3G65TdA2yGSjK43Vla1C6ZMvxNGFtwxAr10PhdhjYj42YlNwJ4XPDnQBCBS0u1WCOtCBh3ROsgZlSmjCXymE6cTidOE7T6nOX2oxopW
kvU9I79QRaKzQ9VyHWrkDzDDZ0Av/agJIDGxJXm8jVGLkZEjcpsh0zV5uBYdP9mtPSXtCvp1ZqLRQ5MVXfg46l8Vgbh9qo/VqmlNhsBUmNw8Hn+eExUJ5ExBLJEhsZiIOwjydyN80KKCKNUo/MbWKehWkWpsmYLuZFyj36CM2DFhWPx4w1AZLqVkZSGzZGr0AEIRPWBCDjFa/YxaBVldqdcSqpN2L4yArB/UkNBVEknK3YoJej1e3XyuRCuWLiwezFAyWcqQTroX591zVcfb5U9Rgg9y7XYdiQc2Lc9BQ7JrJkT9RUepDpSR9bv9fXQ+bJ1Z5n13uyXaZYv/3xfR2YybFRDpU4sk6w0hqpKXmIhNENa/fjwJAT2ieFxAha2dXIMRpZAikkz/50wszRluMRhMQ4VU47KAleTCAopyvf4N8NW35PibzbnvGkGCEqLY5MOnET77ntNzCkDR/Llvcs8H7bMOjoL3gMsGhuDZnEE8qww/Zu6mpbheEJMMHgHqP6ZE9o5unpaaaWI+EwIrbAvxDnN9BOjsm26uWjMiOleTVtFfBq3f8tEvIWJHVRSLjqPI2gRj1OTA9HHk9HTvPEoVZmO/vwAVSJjDSus3L17Bnh/Y9p776H7G6QnmEs+jJMhTp/Rp0rwwnCp1/HXr/2+/PwBn289/MleUDZhNLLbWfp8uqT3mxdCBXBwrkNvqpR1VzQN2QUJfTtbQ3KLg6FM5csnDvMoCNkBk1s7Vxdx+LFaxjnzTYtaIU4P2dekfktN+FI/INb7P0XpAo3uy3D1g2/pY/V602Cx8LrvfDwMDPXgVQzQ6zoyTshi41M0YhasV7CCCEQk6w2J+DQfpsdBdO++5t4h9hZOVfQOrsnXeexhI4MEGQ9L6kuLxFjXtvyt2NmHupamxO6lYoZVtUFHSWhUQlV2SzyArFh0TNniZDomkPazp56CCllR241eHmuRQ/WdVj9DIMFUCPF2DXaHApqvO0H6PezI4dLWdIqmZHQN5Yqxtbgrs3sW+I30j2vn0RuN4l8WrpmC83Ur0OZsGki7fZMtRB7AjHunjCVz7i6esKJEy0IISfSVDnC6iIymzEF2PXbEYiI1Lc4jWq2orRT6M3Hy3g1VgjL8MBp4a8ZvaT5XcashS5ybMKM8S05C9eCj9klWFy/+ia2BmRLYtMDtqWM/9bmx7kcOiB8khul+fiYW0Wuuj7h1Zb3rkZux4HNZsM4uwTDm+ORq5o5Pvh1/dqnB1Ia2AzCflT2ObHfROIo5Ohocm0BkUDTSJl6ib8ZYoLJvJbmrRcU5grTXHg4HXg4PHKaZ2KX3hhsJkmjqXPNYufZ+bpzFjl2hF8x2vo7ESEk5ycBpBQIRLY5sk2RnUR2IZHHwDAkcl6ym6W85+xAa4pE54/WXvotVAqNSQtVfW7HGNjunJP20K/Xw6MxnZQyG610JEgSSRIDXbQ3RSRXjq1xbDDPUOdeaFm4o+raac28aBODVwfMvBS85IGm6gmTBqxVv15RkJRWOzYzY46wGxJKQLWh1SjFqJedp10GKMbAXEtPEGVFJMG7rUsp7osr59d/q/uSc2KyzIw1N7e3y/w1eFCm6qXnGDM5Z4ahy3UAW1K3kcqrJ+fSubtQp8LW2NwErm8zUj4/I7778X0dmL1/f8XrozENJ3LqiNkAUzKuhkyQDeOwd1FYEbbdEFXyAnc2wgbCMWJJOdXMbt7xMHn2MGthrpU5Ozn+iQUOqqSa+cHJN8Wfy0/4B/JX2AwveFYeKYc32DDQhh2zvmbMvoA/LS/40knYtsRQA/YA7WGP1Ix2+6BWhZgePeoO14TdFTJm4mZLS0/R3mKciZRhxHImpEJ4SEgo2DDSRi+x1qun8PolSR+Q1qinQpiUUIzWkTSAeur8iiESniZ0/wTZXEOyVd24RiGUxu40s308Ue/ecLx7Sf3sFS+Bh34/ihk7gye3N4QPXtCePIXNc8r1FgldFLIpdb6D+yOb13fI3QMcH+H0Eh
78mtbHB8JUvaVaDBFF7USa89uaO3VGWwErKEoLjh7WOq9xpwaXQUA8u2rqgeREYGPuNQgLWuET4mAuhOmivcKimjVRmWFtoV42pmWqXU4m7T/P+IY1axcw7JP2djzy9OcTN7/wg2zDwDYPZMnuNDEYpb9qbSMHJq5kwpJhZYapYLahLnyIMnuGFt2DzhGzwVVD5OAEYqCpZ6AiQlPX5mtRMJEVTZpb5dB8cTf1oN92I4Tkvn/9ulZmms4orhs05MyQMjlOMPRAh4ZJoBT30YuSqKLU6hvHlDoq0Pk22zETJDAOg89NIrEHSQVXsHel9gm0up5Z3JBDhUU8lkCzGdWZIQxIiAgFWlzdDVTBaoQckWDUao6iza4VpcuHnJS5TkQTHmplqid++b3f4H/3az/A4aHP2cMb9sMNrR7QODM8JGw092PspabKa2R4huwr8X4ixz3H3R2pczmX8CbiRs8H8bJLa9W5Ypw3u4STo49ilAaPuGvEIPDKYOyR14jxKI5GGm11/7jclJYxLOoSE5+acBChiBfcljl0UvdvrXYumwpdGd3efr1ki4Za56UGR+OW5KbIEsQZ39CGPAqveER2O4bub/Kl/Y53NgNX20wKjXE7EDSyy0pS5UUvzX31TcamiZtwx5Nhy2aYudm/x24bmZZGggk0Ju5PhTgOlBo5luYlSTuTuI1ErcZU1LXSTkfeHE+UVsmzn1fNgYFC0kIUFzAeYmKThFM7UwaUhomuzTdrI4oFp28AFhOJwCYru9wYsqFRiTIQc2W/7QKswQiiEDbUfIQQSaZIMUJ/v5iUWLS3fijbIbKVE1e7Pc9vjdNrr6J84zPhnev3eDrObDaVd1ohEtnkxNhv5C77vVMNnGqmHSMnOVKDYiu3tzfQjZ54HdSDMNHApLZKuyQJWFVSGlwDLno5MthM6em8hcxekkts6AHNiSbCrIppW9euIRpRKnFwqe5AoKpxbLoGYdKUZIbZucAvZsTm1IfFFcerI+IJY+v0i+Vz9XE6971iEhhGl18Zh8SYhG2ObBZv4TSyzxsGiU7xSeYy62ZM6vtZHBJJKptQ387gf4fj+zow05cn8nsDx52uojtRnZArQ2KIAyA94g2rU30IvcYfhOMshACNxmnecZyPbI7+uFPz7VHNF6AyQx3hPQq/9+QT//cMP8IH47tIuaZNGXklSH1J/qDxQ7s95cFfa7SZPbeEU4Y7QR8Fm41AW4nXIW2wcU/YPYHtNeHqmjYk991L49pV1mIihBEhgSXEeiZYK9LT3Lw/0Y6P2MHtLayB1UadC3UuxP44q0rRShoCtR0J2ZCbHePNFdaFXMmBUBQpDXk8Ed+8gTd72tWO9sm3kI5xJ1OuEuSbK9g/wcaRMDbXyqq9hHia4PEEbx5pb95g93fodCKdHtBjh9OnyVM0cy80k9YbDo6dc9Yz09qc4N6aawxVc/6QnuH0hX/WtHNXJDJZZUSo6Mofy5yN372EtTQ9GG3prMUn7hH3Kt3zdvmyW/zS1i2vreXR264Cd9W5Uzc/e+Cdf+5HefHiBXHKbNJIFs8jm+pqdjw150RIcKOjk0IpXsJdAjNi11+y7kcXHOYPcel4PAurLlkv4J1PRftuvUR5gVphmg6U7YZaJw/ALBEknImFvfyXU2TMxpCFYciM43bl7URT5tqo1csTJkIT6UhCZi6+4eUUUEuYpVWHzTXIhLnjNqsvvehKpg4BTBrDsMW6IG9rCpKICFjoqJMXqG3NnJ1E7QKuymIp5RfFvPMXwIxJK6V77kaD/+r6wN2zxu7THtDnwctaMZHmwR0+OmywjkEzSIE8jkjZUWtlM26ZdkK8f6At+krN9eQOBrnZ2hFmnPXvFDujYz1Ymjhn+uUiyHOdMhd9XtCySz9NXR6HJ1cV5djRumRvOwSYXXAr+c5DltImrKTqfgnfEpsVgY2BboEf/lE2oVIfJiDxpPNvnm6uef7kmutdJgVjSJnQjLZJpGHg0JHW/+5rn3
CvCU1XWCvsxw37PHGTtkzZ7+cjTnMZU3RPxX5NSyt+n5bzpFFmZZobx+PEw+M9h4dHpqZcdcQ51OSd79qQoMTQr2lPhqSPe0fvlwYAv1EBcSHypVMvpi6uilsqxehNNlq8LNhXlhYDLZobqc+O2J2kehWjByTKIvTq75dDZDsK++2G/a6Sxw40TJU3b+642zeeHfeUGaxOYJXFw1MAZCIKZNkwh1NHgtqKPC1cL5JQtRGyItU6p1SwvjC10rBYCalh0e2TROJbvN2gSmkzWQzVjDVvQqtFiJ1+AhByYMhb5lBgCISk1EmZ57pSMqSXYdvcaK2sjUK+VtraENYaTv7v5ctlTATOicvSXT4kGHIkJa8KjDmwGTObHjFuhw1XQw/MkhAHbxoCXd+vBGOTlSHp25D173B8Xwdm8bOZq8PI4VbQ3jpnqTmUHKK3MKsLW5qEtVoTQmJI4vyXuHMJADNKmTmcDuwOjjodp6PDs0Asrjz+HvAPGPw++xIAP7F7D+QFnDLTFEl3Svv0m1BecfPBBitdRDOYpxivB+qds6QDCYt2lt4Y9tjmOeyfOJ/s+hrLAUlufrsYSIs2V7snEFJGx63bT00FDgsR4BrGeyyOiAxIO0ExdK60UtFjt+DQhkYjy8ZF+JKg1wO8+wTZP+3nFVHruO7DgfD6NfHVjrjd82wYib/1DQC205HdTghX19RhgGAIM2E+rar+PE7I3QP6+gF5vIPjEaknOBZC77aSMiPN+UzWmZotRKI23ziXxbR1qFobWr3T5szt6Kgatv7sYyBQLRDdyIuOSHOqviGegGSJAaF03OocfLlS/2Jr1Tpi0Mwn8bKfN9TlRBBOGBucl6YUnvwZf9CP/B9+jvdf3DAUSMm9/CLO/SvNugQGTE1JrRDjFUZgao1H6+Y4fUN3WQg8Mw+G4CVMs9YXFB+DuSXm5Ni8qG8WNMWCrf6QHgFVajtR9EC1gakO5CYEyZzZsQUJlSEmxnTNNhnTGCjbtvL4ko40Kx6YCxBcnNTM7YSGvnnG5K38IZxLr8FAkyHTUr5WVM0V0umNDQhIQGJG4mKtUwnJvTSTRGIXotRWVyKxqlJViSpOChfzLsNOVly8WKt6Z/KhNko1kkRepxP/6Zce+CP/Qz+vUwI1dMhelq+T86vC2/IJanjHdh6YppmqgSLeNVYu/NCEHoTZOQDSy5JkH2eL8LHgArOBnjz2x0Xxnxd0a0kwL/eF0N8j9NLoEUcIepPkdza99L9dJqt2PnU/V85BHHDu4F5+FuEKY/7wOU9/6g+RP/stHscnzJtXfOmdWwA+enHFe0823GyzoyQSCNqwEsmbxLF33714/jV+8xsnji2SojeaRDswkAi9jKQb4TQXspjzJUmEmBxVJxM7GbW2wlRcVPU4H5gOjxyPB4oq47jt16sRpBBwcebl8yy8sfV+VxeXteZr0ZpoSHhL1sUIfYw11CoBJabuZtOWwMwTmSYnjBkLlapK0MKpi1Afa2GuSq2GmZdCt1vhZrPh9sa4vnHk5vXLxsNh5uGQPAg9HmljRvXciRyoRI0kMpvQqONI7Yh3q0vZt4FELChIIA6RMBVs6VBdbnZRNCnWGs0azQLBhNrL1wAhqCeC3RwzLJ3XmlxAtiyVsERII0NoXtpXRXRC2rwmUqZKK/Mq7qt9navVO2trf1zTnsTaeYB7A4N0EVnITqklZcg5MIyBcYiMm8xmM7Dbe4ywi5mrIZNTIOZIzsIgDZGZ0ifQROB6q2yGdjFTfufj+zow425m87Kxez7w0L3RrHvkBVzzKEoiLEHaOhi8xp+DMMQNAkxtYr/dstls2HY+VE7BJRrwhfFZg582+EemH+BH+XI/iStsGqklkEskEmn3kXAqKMa86zew7lGUeg8yRWIeIEU0Qet6R+HqKXL9HNlfw26PXV2512cIPZPv5x8jNCdXYoYFgyEhQyJ0X7QyDTDsCfkKxSU+pDgcrfO0as1YAAkDUbaEkLEg6BiQm5
G494WS7YDE5Mjj1YQNWy8DxYwMgeu+QOyObwjbEcYNRiXVA3avaA3ebw/I4QhvHgkPB5fFmE/QJmRSmLv7ZG8yMDKmbhsdvL3Q3QxWPlTnOVijaqVZW6VB1h1jkYzo8PLSC1gCiASmeQluGlfJgbokFWFAVBgIxHTeYFNfd0ZwH0cLLPo5sauNJxqz1k789/e72Td+5i/+MD/7v/kHANjYCXkT2ewqQV35HpIT1aWuUH8MYJuMTo15iIw5MlclSSJnH6dj9O0ydDXrRWFcxFGzsG4ixpAipWgPhAKousdrOs8NCY1qJ0o90HRDqQdqi6Ska/BpLsNNDDDmwJiETU5Mw5Zp1vU+LkGZJEcrFy+iFM4onZhnmKvZMG6xJGIMQ1cIr1C1SwRIQ6x4GTokaksr2duV3fsm2L8jyQPdflqqgllA8aYDVJ0U3nkyq94bwoQxGxAiOWTeKcL/8903/L53n/h9fN0VxGOiEBAabZ7dQqlnzJLyQgAjDZmQvGwKwbvSOgI0mZf6FnX8jS38MztzzDgHPQv3rOENAIGVp+6lRjzAWj76EjTJxfOlv1/qwVSz4Ag8F4hZL6cmOz8P/HmX6JksXLeLx6znvJSL+nk9/uTv5csf/INMbLi/fp+H4Tf48Lm/4wfP97y43nIzJlKXvZBWsBYIGR5nn2dffrHh1796x2998pKf/fg5tJkxXXnHoCy+mwMwe6dkFcxGUjjPy7oE4U1dLb425vlEqRO1zWiFOS6IjHlQFqoDrIskRfTvl5qRy9cCMMs6N1fSFK31zmxzDmoLymNQiimlzyFplZAKUSeqnhA5Qn0gaOGxI86PZeY4C8ciBAYIG642idsr4cmNcXPr+8vD60cOh8rrhyOPD4XD6ZF8HJmlYLbg/T5YIsbQ98kYfX0KsSOC1pw6sgb6vs4sA2sJ6D1478ly6Q1TzL4mLPmdRIrgvpZyJIgQNK3r9jLC5hI4FSMHA3GPbKOAzWifP7W6woA1dcChV05qadTmzghAt188j9zYO0ZFbO0PjL60Mw7CbuNd58MwMAwD4ziy65ShfR7ZdfmMNIgHb1I8re8XYhbjaoTNaLTvYF1+9+PvLXz74vji+OL44vji+OL44vji+OL4//rxfY2YGQrfnLj64JrDpus+ZcgtuH7X2sRnqxozgBKJMZGHSAobsMB2njkNU9dk6h0XeUBbZaye5X0pwj+YAr9n/pitLSK2e+QE4RSJhwLzI0khfpbQa4XnHW15EDQlKMFLJtmwlNGckJ2XDOP1C7h6ju222Dgiww5JwblWvUsNemnCnFsVmnerNPNep0X/ibwhjDskbzBcX8VaQGejlLbW0ht0LzUhSHLi7piRYQNj7+bbDoS8ccbnMLrSczZiFE61El84uTTVDcQdcXsFckTmgfYYsOO3CB2StuMEhyN2OiKnE3KakFLR5jpQfk2bd/L4q7qWlTn07wbCZwSraHNvtFaoNmEUtJO/oYsC92KM0jpilr2rqVWevtsRs5++4eXpEZsbu+eZjw8zL78Od99sTMc+kEJARXnHPMv6NCnzyRW5j9YorXt9AhahjUJM8NJg/5PwB//kT/KkX/g6N/bvXHumOgdHQS3QglGJLioMSAjU1ii1MiRjHAJzdd/LsZsTZwKGdeX0rvHVkWP3wuycLzEkiUuIoEhcsnbQDoWlXq4qZWIuj5xaZtsitUVU44WP9RmVi+Lt+TklxuxcGfAyOYGOKHTtsBSJIgwxsjTZxsUjLyViysTgTR7RgNYRs6DUWKihECx6+bb356ucW8FijqSOci/8HjG66veZ83XOmIN3OjrzzceVLdwdUDFkyARNJBJXjHx2o/zqxz43fuY3M/ObOwjCJiZ/7dqQKqtnYBgcvUcXtNcFNE+nGZNI7PBBEGU22KTIQ20UIH+O0bWctvR/L6hVwVHcJR+v1s3gLxAs6U9cy4q9DB87V+0AzBgj55Kl32tZS0/LoZ1HJutjvrO5YHmGyVmkObRu7f
4P/cN85dmXeVPvOYbGHIz3br8OwIvbW56MwtUYyWkRIa5YO2HBeHbygfPuu1c83T3ym1+952uvbnhxc0+pkUkjm+ookURDtdKqYi0QQ3D+llZvpuoNVaUW5tKYS6XMR+b5SCuuVzbFM/8qSCXRHHUNXkv27m5bu3vVKnT03rowiyNrkZgXY3thao17K2xOgXGYQCIiA0M1UtcVIxopN4ZhIqdH0ANZneU69zXuWAoPJ+OxRrJEpmzsdwNPjpGHa+HprXP3Xm6OlMPM3TFwdyo8lMA4TUxyQvv7WfEyXIxCihHj5GVZCytdQNTRKR97iWDe2CPJXL7pAiVsJtAadY40M6ASgq7IVGu2urCUaKTgFP2M8+5C1zGcamO+m0n7SohGs5ljO3JoRx57h/RcOsfOIqU5YlZro1VHyZaGyKqs81H6RHBttPP4DQGGDOOYGcfMdrtlN256fJAZBkfMtsPAbjOyGQLDEJ1rG4zIQOli7yczxqzkpTPm7+H4vg7MGtBeTYzf3jN0j+xDdhi4lIJGpbYZtCyPBqC2E7NCZuMwbRLyEIlBGLJ3qQBshozWzJHC+yf4sTzwE/UFL6oRag/MTgZpgxweqceKTkdinGkWsXth6LuPHAuyjSQEyRFLEUkjYbwijJ3LNT5hHvfEYYvmSCChtRFdI2DtiNPZxf9C6SJrQTo0nFw8EYhpJKQRjRmzoU+g6LB55cxPigHvITBsABkTNu6xfIt1gqNuMpIzGrLzmMqA6Q6pN+S7Azx5BoCVE4QNjAEooBNSNsTpAWqXFzjN6GmC44weT8Sp9tJlw3rgLM25Er6rF4yISYO4RZub1IMr1DdzIdSmXd7hoowAvl+rtb6AdgVuRixNxFZ5/pP+Gff/+y9z+wMbxuv3mW8y9U3h9I2vU77+GXp3Fr6sG+UQK6cGv/ErA3/7b73k6187cAqgTdbznwUeTZmCd/XsH+DxN068/9M9aHEOK6cQGVogJC85OZtQiH3DaK1BhJiMrJExC2XMZCJjOm8Y1rz8p95y5PejU7UXVfwQDWsVCc11nSxgQWjVaMtzoneO1jozlcDp9Mgxe9k0p0TsErkxV7zoGBEGYijE6Av6YmkWNJOTMbtoG2qBIJGQhCCBTU+AxjSSQ+6LsBeZnK0SoAs0I5UQGyk0Umy01NwKJwAqXs4HIokoQtAu7GlGiJEWZi5YW31shc7/MrQHTYad7bwwYhgZJBBjt17bDjwdR/7Wl7188lPPE/ra5QquNgPt3llEWsvanSbmLgu1TH4954mYBy8xBn9tALXCpEqubRVqdQrUmU1/qWzCwuUybzjRCzbXBAwEgulbZVDn03Q+rjjXp3Uu2qN42UWM1boMvIT5+e1Ee+l0DdcuAr5VV5C3+W3gTTb55ordz/wMm6fXRD6AOBP1kau9Bxq3+xu2YWIc+voZAyqu2TDViX3Xh9rfZF68G/i1Xy38l7/6yEcvRq5OR0IObJdmkKEntGogCaPSVD1wKlCs653NE9NUmabGqRyobe4G5nAsS6cUpFi9IUkX1wh6GVPfShibvV3Q9cAsndXgLVBr5b5MRPMGjVKVNCUkH5HurNGikbKx3xb2qZCsEGIhrgknHObG40E5qqHJ0BQZhi03e7g9wbMbD1I/u058dircH+H144nXh8Y2KjWcqGXRFLRV58u131pvnzhzWjUKWqBYQEjdMqwh4gLRS9c8Iog173pukUKFoMTYPGADzKrLZIhSsxCH4CKylny97TVwd3OZGbUQsmJMzOVEqzOt7y3zybfDKj1p7wGZtbPsBfjYDBdl5vPAZWmaJST/d+4lzO1mw263YzeODCmz6ZShzWZkvxnYDpFhhJyEQZRIpfXHyFQYshGT82b/Xo7v68CsAHlujJ80ti/8A78eoc2BQqO2Ru0bNlL9C+eYVK3MbSbg8gqrWjrqvozANida9qj/RTR+errlS/UZqVVoXfn/ZoKjEE6PtHrEavFMIYGUETn2x8XRSc9xA3
nEhv8Xe3/2a9uW53dCn99oZrPWbk53u4jICEc403ZmOtLhdCtjO11SYVSGMsIGSiXBC5aQgBcEEn8ADzwiEAIVMgIVohBSIQqjcmGXq0w57Szb5TTpynRkH31zm3PuObtba805R/Pj4TfmXPtGpu14ARTyndLW3Xef1c455hi/8f19mx76S/x4jYzG5Spxh4jHFUXnBVcq5IKoZZfJuuM/ZTQvuJzNGNS1iTvNrKvyhmyIR0IAF430WGWzwQDjEHQhIrFDhw69GOBixI0X6GXjmO07nItIqZQ52fsFk6z43QBXTwBYlgfjC3mhisMvFdIdTJPZPACcLD5K5wWZEqSCSwrVUdsNJlWpElBnfDEthYrgNJnrduPuZMUWVc2msvkB7sD5aMUxlm1XgrKUSlDH6ZVda//1V5SsXPzUJeP1W+Rn75I/+9x2jGrk2WnJ3D+cqG8+IB0+Yvf8s1zoE07TtznUmUOTk9eTecXNahlqQuHjRbi/fUXXfdne7+Ke6D2+j5TFyhFcRRA6J8xt4pqkIg5yy4XzwRFj21G6c2FWVoq2AtjCsHoprePeiLZmCFkB0bIZs8q2hlihZrvNhXk5MS+98W7C2Dhi4MpoM55MiF/wIeO9+SLFYDv04k6GtjhPiB7nIkQxlamwFSSr8ab4iI+9ub2rFU3BXdg51RnPYvxQUYoIzjdbAj2jXBI8LhtCZm7r9t292s5+HV+Ox8hZ49xVT2l+bWCk31Cb8hlD8Wof2fnIR0/tPvvNn+j40ssRjoV5AHUdNS9ULcRGLE8pGc+tuZSLKEtKzGI80ZWQvzS0sogh9LmNcY+ekSg1G4v6qCqS5syXqVvRZnwxOWf9nZ/+O9AvU6OZ0vixmfIPHnn9XNqimeRsl3HmUp2Lsq0MflTVRSD+5B/g7S/8OK/jhOyfgLunk0tc/xYAfZ/Zh8gQxIpqb2IRfGYojsu9FfTPrwOfeW/kG9+Y+NXfuuMnPr/n4vKKsaucvHGmLCNYWgEeyLUyL8JSHDkLubn1z8vMvCTmxa6R5VxawV4aH9dTyVLbAu9sjGoj9Ett6web2KioFf3rbKv1rJ3xYh2KVOFhWhhiZyk06YALjtIW8OwA59gNlYshMXbFOha5ba6AaalMU6G4hrxXu5fGQbkcHc+u7R56ch15/Xrm1cPMd159xPWzkcvc0Y+Z2gzTHYLS2xrG0TinYgjSOSMU84oUpag3bqtTitSmJFmLLsvERZQiyVCspOeiFnC+4P1C9JUgY/NLhGVZWLI2/zYoYp0Rlype21qeCkGFrrkV2JpvQoJcW2xUssH8CY8y15Tu63htInrvhNDc+vs+ILHQD5Hd2LPf77jYj2ZW33cMzbHgYhi5HCO7KISoxACRSvB1M/tdRU9eCok1Huuff/xIF2ZtSkXfLLhDK1p2hTRls72oM+MwsCwTMTikbWH7vifVhM8C6ciSF5aykPKMo25k6RAcMXguFuG9nHhbdlylDPNzuGruXcfPQD3AdMBPN7hpwk3mWYUGltIg9W5v1XIMEHaUOODGC/z+AvYG92k34tMRrc3gUxy6zNSaLZy9hXz75My/K5tsuqqgwZRLNFuKWgqugohHXKA035dUKkWFvt34ffQm2447chhw+z1uPyKXA3nct89uSJmZAhl6J6VY0RQq0jf00GWkPBgkTwcJNN2T04K0HZlfErIs6DJTloU6JzQVYNgQM10jBELdkBD7TomSku14ORdm0GqE3BINVrQN25GpqokGsMliCYmAEurAzfs2gYfvLehl5eblHVO/47p/YFoSk3qmVjA+HD7m9e37fPT6lvtj4bV+l8v972Xv9hyXefNEk94mnpCVi1woA7iD8P1vT9y2KJLrCcbxxFEGSp1R9Thz1zF1rKyf3zYXSzZ5f5VWyPBJE0VaQsJqbrm27nPO24JhJ6kSYmvnpNoUUI+QjyQWR1XXXWdu+aMt9mwVb2pq5rRiaKzrcC4jTgkN2l0Lrq7zdMOI8z3qHW
hiwFSOAF0c2O32XF4/Ybe7wrUiJudMPVpRLKmS6ozMGbHRzBqkssZEgRUpeMGrGLFYC1rA+8HUnGBYnATM8b/gpCmfs42Xrceq0WwBqkcJeBeoDQW4arDgr/yE5zO/sWP/6oY5VrjY4w8PRrpvBZ5Zugi5VgqKC9ZqcUMHR90m8KW1k5fWnqzeMZdKxAxZsW9rxdMjKoJv4oAi54IqKefWLE0EsLWhz8dK/h9VGxYqRLX3SXquqOzeOb/HqtB8fKzt1fVYbWcU2aLyQnXs/8gf5Z3nn+WwfJud37GEEy4KrqVddMPHhFVp7MQoFGLilhgDQ/P4erYPfOadHW+/feK7Xz/wS7/2mqthx+CE9MTOROdoZHLzjDvlymFSlhxYilv1Gcx5JuVsP2mxVl21mKhmm4YXpTht4pFPGhbro5V/vQe3v4MVU4+c/6O3cPG5wjwVHvzM2O8IvuI0bQz6iiepsEyFaT8xjplAgjQTm/+lEdzVcqGrUFOm1IBIpgueywsDB5492fH+mHlzf8eb+wdevqmoP3HlzMx8HQ8iCa3mNyjOTHGds9kJrMvig1gXOJsFTxFLzlGRzXmGRimQotQuU6s5+lsExfm8dF3Be6HqYmPFd0gMqGez3slpYVkmqjg6L20DUClayHIuYotYYk+pWO71ct48rGPQi4k/xEBzuzdkFQU2/7i+Y9zDxcWO/cXAfj+y34/0Plj8XOuqddHTxUAXK9ELXiqdNMeHNoeXaFFfaKHk323L8zuPH+nCDKBSkJroPrRd+vWVcDM9MN13TJcn5nmg9ntKtVQAAK8LMStLTRxzZV4yp2XmuJxIaSKuN0+M+NDTxZnPlp4rImEKTMMDXbbB7tI9SCKXW5jvoE6kWHBuRPuIDOsAzAR21NY6DN0Iuz358hJpqkyp5nRfpkxYMnJ3i59vqTpDWXCNB1BcR3Cl7Twi4jw+2va1Nk4O6YSy4IMizlPF47qebjcSUiU21U9RgbAgLuHjgNs9QXd7tO9a1hokJ6QqSKrI6QjHB9ztA3p3xKVEbqolm0wGshaoiVwyUhZs5W8Tld1OllnpPOKdsXuWzNauaaQDi5gBVwveKYt6qpzjNcy/rbbioBkYZksHdC0exdWCk9amEkwhN2fURY4+c3m0a336xTuWJz3z9Qcs1wNpviSlSqqJ23QA4KObO17eFm5uA3fHDN1CeEd5euG5nTvCmpWXlLDA4pxl9i0w+cpv/faH/Jl725mmPnJfIjK9JiXzGIteqGq+cqfWR3qYC3PKTDWTajH0hNI8JtoGAk9dFI2Cpto4NckKKs1btNZSCr3zFM1MWpia7UKRM8yPVnbOaH659GjtKGtboRTjSQAlZbKoZXM2WxrjfvXkhhQd0oGcQPsr3hoHno49yfcgC4cpkeY2/VzvuXjxjM9dvOCi31uOrSiTzjy013q49dQ3wgFPdkc4mKv4gUzAo+siNVYqnqqBuWSKD3SxJ3Oka3iQE2EumbEtlmjHIrYgdno2FLakCI+GiGtnfnBCLN2G0L3/VuLrPxX58jcgzQthDLhxRNK0bRqWIuy9cMTMNutUOajxBLNXclMjG7YKHcrOAbUye5grxPZ+XuzTWYfbI1QOovSqoA6/5QLCopUem+TXgHEF2pCnVmtTKmcErFPdHPzXcncrQBpCJrAFo6+PWVuW63ut/mgrmpvWNAhXePanfo50Udl/2HNzEfDpgacl4xoy7U+CiCP39iIhZ1Qqi+spoTDs7PN87q1ryuJ5+eXK/d0Nv/qrH3AdYNIXfP5t2yy+/fyC3J34eFHe3HlKDVQ8pfYgwm3rlC9LYk6ZlAVqMM/HpNSipDVqSQO4zviw4tAQ8E4JTvGuUBr9wArJTNW0ljKId+yjMDXEefLS7EbsPeZJmU7K/hLE9dveQBr3ZF4Shwzz0rVkhZko7Xw5Jfg9OwZGPFkTRS6JPhOj0rdK6dk4cjm+xstCOsDt7Uw/XoBX4tBORJhJmi0OyXWklC3wvb
DFIxV1FE0tbxZr+5Kblxq4lR5XHSqVJOBzpQutEHK6Ua28wlI8Wj3XLjDEgQCU5syjbc2LwXFIakaO3hGiKauzgyW28ZkhZjhlKIlNsrzYUrMmJNFrs8wAfATfOVPyeqFv9J3LC8flPnK9c+z7QO87ggQzjvaOpawxintqSPjo6USJLqGyGFe4fcfeB1K2zs4PxA78M48f6cJshftTruR7G1juKOQR6mmic0IfImPf47xS3Zpv11k7Q8z9f1pmjtNEXiZUddvVOBcIvmPUzB6hz+ByhSWjbh3I5sZMKqxEHfHWDnHBU9odpsFZ4zpE6Eb8sIdxT+5HXKvSS6pIzjCdKIcH5O5jdC3M8mzvjVEmLPQ5gDNvJIk96sPGyXHTjE6Jutjux/5oHj7oGqUNvldcHHFDT73u0BfPqJdv4Xp3Bl1ztlikwxEOR+rhAT0cqKcDkhZqC5ElJ+PTUBAUNFNLaZE87W5NxWwUcoVsZp665M0KBED07DkntdmBVCu+bKe6GjnquTBbZdnbd22fvoVxK267KaJEdszck1hSa3N+I6HfOHF7lan7keXtQs6FKRdeHkzc8L2bG17ePHA8ZJZZiZcnnsWF/fM99aOXhIbIiofSGcqjCrHaAvX6wxM3HxpC599SxrAQi9iuLxVSUmuJqHBo5+MwZx6mmVwsd84MVAX3qNmkWgzQLImU7CfnhVyTIQDFro9FthiHq6gDZ7yHWHXb5QZam5G11bdy0Aq5pI3Q6n3jTyjG58Lah96f80dFwcfAW1dXfP7qmmeXe6TvqMDHhztetqy/FxfXfOHZ27zdX3NJYNx1SFA0BT5o4/mkR6Qmbk4Jd8oU75ladIsb7yjSWp7V+hKViegSUiu19MSiZ4fvKgxYkkH2FZFMrIWULa+PttO11E1rrUgVnArVR2s5tjF4VTy/8hPw+97pSd++p8RLUk62MVojmYqydJ7svfFhoiOqxQylkpnX2wwz4Y1qftk3aiBnETi1XnOnBiIVaW38dr0yVog9trJYnfqlIQL5UUuT9ndaYbblYNqfWgnbriNnPlnlk4/fXGna34/A0B7vHYQqvKlK3+7n689/nud/7A/zIDNTvOPm+AHd/JqnV8myqMB8oHIwFNPbSSgNCY+uJ3b2Yi+uLsiT5/d+duHl52e++vGB3/rtN2if0cXQt2V5QPvA3Vw4zB3OXYLrSLWwFE+Z2pyaMpoqaVKWRSnqyGqeYXNrgUc6K0qgBaIDvlIfI9dwnofqGVUTMfPntU2OM+6cc44qa0akmTFHf87HXM3mgrMcgZoKwUGtHYsYGDFExWOCNy09xI5SYRBP5wXnG0WkC/T9SBceOKXM6ZQ4TgkZwuZjKBobglvJZUK1xQ9pR0sqJGcQH3HejKNdGPDBgU9GP2ntR7PtaWNKH/2Us6UGCt5rM+ANhlp5M7PWGrYN4zIviAjzkvEnB1rIznwea4uoyMnoI42STbJuJyobA6ONVbXotxDwMeI7xUXHfux5cmlr48WFY985+jg0j0mzBBHvENedr2NDR+dq36EqeOcQVaSVV5WZio2b/Gid++cdP9KFmWA8s6kkuGsL7E2gjDCFeZsYnRSyXrKv6+S9p/gFjzJXa9Msi/HM7EQ3VKAotTquFs+ldoyLIiUTUmUbpeEEsbeWIq0oE4dUMSJ+KwYJnaEcIaJdj457pN9D7KlrEaEYmvNwj9y8hrvXsLyx98oL0i6qzwktne3svSCxp4Y96k0ZA+AylGkyoUCtiLNKX33AyYCEpgSLPdJ75OkOeett9Nk76H4PErcCyCU1Y9jbO3hzg7u/Rx/u4HiilgUpZ3RKKKZ4KhmpCUrGVaGuPYOccLnarjNXJOXNRPBcjZ15CnDmAqm4T7QQaq3UYkpLU2zW30H+/+Rr2O9VEydnXEBagfHwOjN8Ew5PF+arDxjJnLznthReN5Tr5vWJ+9vEYYEFx/6olCcT1+88I371+zTLLWapJJcQheScKSx7+OijzIcvm4
nuc8/udCASyMnGINWCtHOFQxu7b+bM/WT/lmuhFHBippCrO7iZKGOGliU1nyEb26c8Ma/cFzCkUsxR2/tK0IxW3fLdonN4V5shZntdlWbYmDfPoEJnBXOTBlotZgvSmlpQlkTtLvnssyf8nqcvuNpFy6yNHVeXkb4VjG/1I2/vHU864ToMdFVZ9ETZJT7XFLHf7ZSLzvFMPWkJLEnNnFgW9N4RW/JHFwvRBRwDTvqWjegshi2sQoJ1jDhi9USpZF1IvjO1VqtSV9NJ1FkbpFqL0FO20HdR4eVzx698OfIHv7WQ5oIPpqSdWgzHckoEV4k+Mp3uKTXhcRxz5VDqlhixPFo8gppxbFUj8qe1DVPaWu1M/GJcQXvcBHSrG7zoVpg5DH2Y2muvBdhm/ipnd/5Vjfn47mkCtq0wM/5aQ8ke1Q8r50yB2XkzTsWTqLi2aXnxR/8wx/dGPvr+r/Hh3Tf4+P4jnsRE3tWmsoVQK86baKfUtbAJFpaNiaIAfH9J3Xf83ucLd78n8d1vJj74YKb73szomsJbB9ywZ3GFREFkwAdvKRqL4ttmOiVzjK/VtjxZzQtrzstm2ptdaWN7HR/W4gsMeKazyKLNQ4+NrqG1AFsB5BofzIVoBqnONgVSA84Fa6FjHQacGSZnDSCZKIkikUyL36tpu15GQehMfCbB/B/juuvydF1H9B33S+Z08kyTo1vcBg44b4kyuS5M8wRyifeO0DYTYPSFRQWJTZntIq44UrYCd626PI1vqGwqSEobM4/4agDR+VYcmiG2CXEcaWvXZlCl5EoRRYMgwVOykNfCrJrysjaif66tKPMQOqFrgopupRioXb8YPcMQ2Y8dF3tDwq73njFAH0d611lRJhZ159yaEWPq+iVm28g6xxiUqhAlGEIGzFpZclM+/8uiylxvmnhsg+FWkEs4BSNrl3wPZMtRbDdULY4ULBh2rosNoGrIUq1WkAHUJEj1DNlUPmG2q+6qIquMLS82Epo9gQYH6pHiUR/QNUDaBQjerOa7gRJtZ6PizpbxOcFxhvsD9faGcHiN5Ht8malp2cKVpVS0BlOviRVb6iZDzNaBLBFJiZosVEi8Q2LAaQ/Vby0p+oF6NSDvfgb33heR3TN73SItUAjqYYI3N8jrG9zdHXJ4QKYTJU1Iqbi6ukYXQ8nyYrmEpSB1McHBKsfOxYxxU4FcqdnQM7vnG1QufuP5meTcirLHhZf9W+NToahUVjfnlfgObEXCetRacSileIo4pNHj6knJ33HUzyr3NxMp3jE5x11W7o+2oD+cKqfsKFnoqgdfWbhl99ZbXA2RqTneb073YVWO2u+3d8rLD215vPyJ5xwOLwka0NNiBPFiHKQpF+4XO6fHXJiq0GsrPnE4b/9d22naMnqyFrMPqZWkmSknTnneNhrOt3GoDpUVB2le5G0MWkEGwQW863GuA/VodZ+IujouJ/wjx/NcDV085oW5FeoZR40911cjbz+94HLn2Xso3Z4uZYZ6aecie96Klzy9uKTzEEUJ7KlL4nBh7/H5WtDdxIdDRlg4VMcdHrd01NHjm8HsIhXvFkKp9HhcySQRXLbJGbBF2gu7Cr5GXOgoam3icsYqTEnXbPGl2m5Yc0Kct0IayCoME/zST3ne+SXP1TduOb7zFK/eTGuBQ5npDp7d9ROSWlxMdcKCFV9Te8NjO5cDNq+tOau1FV5gRVqPFUC5rXWC/a4KYbP6UJJYgRfUUanbXLkS9V17fcHW0pUT5hqNoK4Fqjb9n56J/Stv7fGxtkQB5sYvSlQWJzxvf9/9kZ/gV199jen93+TVw/eZ08Lu6cDtdOTqZPeG+EpmsoKwCjjBh64pgK2lDBBcz2XveHt8wpfeTXzucwdefTRzewMvB3utfhB2VxEZfROKTOBBXWRJykq/POXElGHKymkRDnMhLWUVvdv3q2ZEm6vlMZpLht0X4nQrLB9vDjf0zPlNkAI0I2SziJFSkJbnHEIkhoHQyOW+Fk
ulEOOHKhmRxqtqBrO5HEkUgqi1EVNmUSFVU3Wv83gRQ3X6ELl3wvHUMx09+507Q0oaqG4kaWDOnoJlfFap+GbFFL1xjosq6iqVzmybamnm5Y/Gma48RjOVlU9Iee3fog/0XSD6gPcekYj3Dq2OnJftsWsbfdaCV4cvim+JLfb9wFVr/a9pLKuZ/G6MXO0NYey7gMdvBrC+g2GIXOwil6MBKfteicHRxZ4QIsG1612lMXNa3FK2KC+v4L2wlIKXTBGLSQQ45MxchKKB+9MPR/7/Qf7mp8enx6fHp8enx6fHp8enx6fH/5+OH2nELGGy8VyVfqVZ3mbkqaf0hQmraKs7mDle41YtsZoRXHAUipH9tRiXWt3m1dQHR0C4ECUy4uZsqEUuG2Lmq/kj2VMFdcEQBjzOd6hfW5k90vW4rqf0HS5Gqgha8tl89XhCDyfq3QG5u4PpHskPSJ1xadm2zaU6VI84dRhV2KHyAN4hzuDt4hvqVGZUq/XUu0ilotohq+3BlUfffo/y2Z+kPn8b3wcER5XZiJYAd3fo6xv05jU8HPDLhCwLNS84lc1HptZkBIJqxGzNhZLsvOoat7RkSlrMGiMrlLa75DH/xVhqrkWesPIz+J1KKFNc1qY6O+9QVyXiirjZc53ZjogQNXDEbYq4sQbmDzLlpSN9r9LFwl1f+EjLZmBYJsUn4wpoZ22JxMzFpXJ9ecWHN/YdxQsxZxZnnmBk0AWmWnj9rVsAPvdH3yUdPNSFdJpaK90Qr+OSeGjB8EmtBVKblxYu4JrwY83TXPl5c8lMZSGVmSktzGmhVt3USFE8uA5VxTulOCFIR24CDLuGlYDQ+Y7gOnyLd5caqEWoq5cRaeNL1GJcj0ULzgX6zsQsaVc5dRbbNV4OPHvScRUhSQ85WqA9cLnrEc2oLuzCBSUGkia6qly0/NR5F+AUOcWeN+FE1MIX5pk8HQmnHemyxYKFSPQXvBl3TK5jtxyJPnOqHde5BQ+LmkGv6NZi6+kRmRHOLW9jrwsqSqlKtp4Ms+qm5nNViYty18M/+WOXfOWb30VPI8EL4b6pQAWm+cjN0rOIMucJUWvxVZVt3FexOa04a2vWauPTcebknBoK5hRDfhqEVWnRTI+UmNZ2BdkwJkMV1kn/MeD1qOHGGva+UQbav3jOZH/790+2PIGNLycY13JCOIbCW89tXvr25x0ffP2X6O9ecpzvSSlzjHfU/o7Pj208XAR8tgB6gqe6gGpGqqGZKmfUKfaVJ/uet6973vvMyK/85zec7jz3zaXo9YWSSAwakODIzFQiSGQuE0tqSGsWlmpE8SUXNBeccww+Etp9Fn3AS0A35L6Z1NT6iUgmHgWYf9KSxREabSWIeZpFF0FMRlhKxUmHDwNDy+eECmKWN6tpdHLQkTf0bVoqS02GjqlAShzmShBnPNwV3SkF8WYoLs4xL3A6KafZoc3I1xUzES+x4ro9JRUT46hj2a52IPhA9ULOFj0nRBbN+C5QV0+0tf+tZlDssHHrm4IT2MLBO796GAZyhUGCcbK2vruHanzcUi37Umul5soPiM7PI9lDiMI4BK4uBp5cGZVpP444HM731FqZ0gEflN0wshvsPHQxmWl2NzAMPV3X0M7aeNGrsKcXuiIEjG8nLrPI1ARuTc0/FabkWGoy0dgPcfxIF2YZ6ykbh3UtbhR/UPK++QDltWy6pVFaWMYT49CxG7rmmyJEbwHXnki/2hLHQtbCZX9Nf/CQHsjBSOe+tZo0J6S0G7OROYuaQR4uGFEQg87VR0rsIA4mAsDakq4twvVwQu7vcXe3yMMtOt+DHqAs1DSfW3g1mF2BuiZfwcz9xLP25rSL1kYtxVqkYmoW7Yy4WFt6t3/3Be7HPgtvvUveP4Ho8TGTtBJv7+0xb27QN6+p93cwTzA3X7K8WHunnVitCdFknKOquMYj05zOIebr77lQc8tMq9oQ4qaG87RAb2OKrsT9NTvtcSZd1Uwp2toenyzK7F
iLt9aTEEHVchBV6va5NHrSknEvK/lV5M31zG0vHCPkYrfJLJBiW8id0HuDxn0/0z+5QF/dtXEzWXGvxu9Lah3m3sEH334NwE/cC6mM5DlxOlbmU+M5auWUEqfV9kQVHxyu6wjVEaOjlGJmso/atKrKVBYryOaJJc+kWhDn6KRNugh+dacHUzV5R1VP1jVvbsFVh3NhM+R14o33Ih1omzKkNEUZFK0457kcBy6HC6TNuvfjgXtv+XJEJXSOrvd4gUsdcLGlM8TERY6MunAMd/RR6HoPTxPzwe7FyxvlRTnx5cM9f+Krr/jxf3zg8v2J0yExjpGHppzLAuPOc3pvz6/8+BN+8XMXfOivuPKR2DZcpVbUO6Sr9MxQYekUX6GKbF5edl6LtdHVGaGYRC2OpZU1u+I41ozkzNd/354XPz3yuV98Rbnes0aGBgfz6Ujxwug9b45HwJlqGfCtJxjQda9iAeVq49o9amUuYi3PSPubPLLO1jOPzPqMyoSVWmvrUUU2QvVagEVtYgIRclU6pRHaH5VrLeuxtmJ1s4Zo/74KATzwAAwoIp4bLTxfoH7FPMq+6t7n9Ftfw+dEEGdKyHmi7xxPe3u1i9gzuh4vapkWomYfI9ZoXj834RLXCRe7E08u9zx7NrAfA8ebytSK/vv7Pb43PlIYPNWJsSgUUk0UzkVeFxXvhCEEuBBqMe7Z1JTB0VkubIhnbpS1H8N2D8Ljeaj9YPOZqGv8JFNhex/xPpLE7vslW2FY1BFaoRSj4J2imB+nc47ZebSc8C0PFFc5JfMczGKVujsmQu9sTK+0jmqFWdfLJny7myphjuS9iSV2w8iwD/Sh0BfbABqZ3THHFpp+Wlhma/tHGUmrdU2tLN2R3JIZarPatwRZK8yMsXI2vfbeE0Jn5tIVcyWoHhWPaj1v/rB5e9JCnytlFkrIJG2bX6y+3Xiapq1gHIQnlx1PLgauW/D4OO6I0uE7U+aelsCSjucePxBDT4w2d3X9SBd928RVtLTMXuBBTwYjxID3QvGVXDNzyczF2KPzBHNW5py4O/JDHT/Shdls+zIqlbwyKJIgR6WchBoUrTCf4HRU5mgnavAwBk/N5gouiKkbJaA14NtpGVxhYeYiDkQBTQ8QHFL8ppCUVMwwJTaisHe4GlAxYictaFpjpMYON+xgGKxoU0FyQRu3oh6P+PuPkYfX+OMNdbmjcsKVBGl6pHZrhNF1t9B8dUTMFgOw6KTYUcVBMe+norVNUAO89cwe99kv4t56jg4dUUypJyXhDoX6wYcAuDd3uLsb9PSApgVNGUkJaURz2qIYyNRaqLlQS8WXxifLBV318ktuRVlBCmg2E8JERc9GM5imxXgbKrVZ39huZV1ZHu9Ka2lu23V1/3/M9zB5wIqAOCLJC0UToaGVp+goBeLHsEue248SD/vK8Wow4QbNssBViq9UX3B9pUep4WP80z3jxpE7Ub0QW8j6EmxSrwE++O4NAMfbI1zuOaQj8xKYZs+8GBo6V0/KrXBvrJ4SQFxDIksmtOu9fcdcyGXZPMfM90wJfbeFq3sF7zojGbciOGhtfKJWfHqPLAvBy8o+ayRnjxOPyOrx1SHNN83hGYeei37PLvTE2gwfnyZOLjJcXdIPDiGTxBED7H3PkG0xuHee/mI0pWI+UH/ze5z+8dfR25lLDDk4fe2GL/zSa774Gyc+erPwXYyDNSKkh7ShOlfAPYXw7Tv+C//wjj/+buRX/+gFf/1nP8P9riHhKlxOkIZK9cpYM32OFO8QLw2JhsZeNG5dW6RSBp0LuRWfBz1SUqZLntNF4D//wxfE37jh+alyXGvYZcaXmXyEi90loo77fGrGsKbGxG5jU5RB4xwpXm2DuRnHKpxo3Npqj83uETn/kSJWsc0Eeka6suojZNqoRaG9Z2qFRW6/r2w7YzcKvhWKj3jcWxErTaW5V7hv/7/kyuIdTz4z8GErzG4+/oDTR69IVEbXW8rCUni133HdCO/Rw7
MLIZRKCIawdK75zVE3UjXjBeAZR8fYH+hiZbfP3L9W7puiYjxWxmNH3wniFNeNoBY6FXzPrvlR+c0/0Aot7wLLkjkcTtyunYMKfa/EKDhvRTvFbb5YKwb52EfwE0IkZ4a5AC5GfIgQIi5ESlGct//3obO1A+g7RwxqanpaoFzbJPgm4NLgKA9CWZq9T8kIM6qBUeu2UfLicNHTDeZjuZwcy+KoXLLfvw3As+dXXFw4ggdqJfi+CX/M6w3g9u6Om5s75qMiviO4YMbMVXFdj183XEtpwqBHI67aOPGtY2NpCMGKVjHetHMdVX0rau1xqVRKFiQr4oScTa6i+knkFzUubRDoB3hy3fP0oud67LkYW2E27MyWI0QIjj2Oh6Mjp2W7ht4HPANefZsHBa2eTLI1qEm8y3RPSZUpeB58JMTZkoVyYlnauk4kZ0+pFuH3wxw/0oWZxZBU4iP+vKAwm4+JiF2k6KFznq4N9jEG+hDxGs1lvJnB1WYNsBLvvHh6Hwgu4rABLiJmArpZZFvhUcQ+jbh2i3pzX16jNaTr8cMeN+zM9d8HM0pNGVnOcUXL/WvkeItMByhHYEJLRvPM2lQIZaFmZwqYnNDqQaO1WVcsN2VyWKirGnNlsIaA2+3Iz4yOq9cXZLejXyribkhuhMXBx2/Ir98AEO/ucIcHJM/UaoawsiTzCgu6mQDWmtFqprdSa9MvJ1ttWrtTc0GaXcZalJVUSc7k1fYgxTcSv6uGbKmTR7vQhnJtrQKjMK8RKfqI/G8Fq024liGpVBJdGcjOMYbVrRu09pRXM1eHieXJni4lQgoEbze0OodFiiREM37o6BVUFrh8yujMkPeB12QguoFaZmo0NMZVePjIPNGW2yPj/gkuDcw5sZRErX0TLjhWZaDzuWUpSmtL2HcT/+h8AWUtdqviWjxMrm2n3nbprkIIHVqiufBT8ZwQD7oSe5eexIMVZs2AMTjZdrgbKpAAaYYSDvZ9x1tXV6aqbAiDd5HZCw9xsLQCb61UkWw71lbwXu3e0NdXuL/2W/BXv853/8ENb+7g91935CZSeXmT+KDAvRgxdsTQoQNGlj+styNwiecW4Vtknn2Q+P3//hve+c03/Pxf/jwAX/viEy5qR3dM1AtTuXrtyC4h4j7hsr9uBgydNsuMlEFZFVczvnq62fMy3fO9q5ndz3T87H9yRx127VacCHXhMM1kMVJ+wjZJcykbeTnrIxI/QlLdJug1HknVCtIkMLbiaNbt424tz4AVY2ltXao9d215QluCmgJgUSv2PE1U0LApWAvFunmVlTYGHvO419bUTft8EuG2QgqRD/7U55h+zM6F3h2Zkid1Hl0KdYSwVN5/+UBsyFM/GE1gHwKjc9aqzQWt1nr2q2OovgG5BB/x4igp0TXU7TTZay2pUlOkFmuNejcSZMAH8HHkuluV+0p0juikZTQ6lizMoxIaEL4kIQah77wJHfPa0jKLpfV+fKzIfKzMFPGb2tJ5b9QT7yFEQhXC2DPuL+h3e/pm2h2CEmMldqv1ghKTkOnRdm8khd4Xqq9MpfIwJzqXjaYhZStcnDN1ZdcFxr5jmnYEP3C9e8K717ZRf/HkkosrU28Gv7O5VT1ZK6e5ARu7l7jguXuzMJ/AYXN+9pkQArM7q5rPNhmCNlseq9POSKUTS9eoCKlYykOuhZTyJrKZpoW0ZEJ1qIOqFZcd3itLK1CXbJtPEWtIjWPk8qLnYt+z7yL7Zgo99h19HAixw/WOWZVKZp4E1+ytqnoraJvyv+SAuNJymCt5tVnKhYeycK9qRb2fQRcTEa5pN17JVRF65jUD9V9w/EgXZhOFHW2yWuuO2tpGVZlHuBhht4Pnu8iTC1OBjeMFbughOIJ3eD8afJoM+A/NPC7huaw9u9iT5cRpv+NKOhyZ1JySQ62gB0IWQ8WkVdfRU0OEaH3tPIzw5Cn+4hqGnUnx5wzLjL+3liE3HxIeTmafoIvZTdRGIlFPc69o0UoZwR
YMLZXa1I/attYu3YF4NNouTL3HxR437qnXV7i9fS6HQF2s27lA9/Fryt1ryuGB7tUH9oYtj0yq4IvDiUBs3mOtJQlGhdBSrdlvvTKDsNOCayuGVNfak5VUlVkhiUfb7gdst2+7Y4P+VVbXrsonXLVrQZsj/bq30qJbm+XxsXo+IQ4hUr0Sg2eqbQKUzBAWlgT9dyP7z3c8jQPH4lnaDe294mtHDiO9P1reWqx0IeH7jO7aLuqww7sjSqHuld4U7/QXgTdtB/Hx1z7iC1/8DBJHurBQQgYcToUhuM10NOuEl4KvihcLxlYnZoi4LuglUTXh8mJWF81PwXuzFvCtvR2Dx9Ph4mqKAC7swPnNfFm7guqAq2ooQvU4icaZdB1hM5k8krSHWghkdl3kWT/yXEZi6y3Ikz0TjisypbPFpGqx0PaxQy6+DkD37/9D3P/2O3z0HyQO0vEq2PD5xiFtrv4H55ikEu3p3FW4EM9D4/k0wSWzCJMWVME54Y0TXqsy/qby5/7X3wbg3b9y4h/8kbd4WjtidKhfWLodC5coJ9LYYpxmWqv7AucGEouhrLpD5tbKCBXvHN/ZHUkPB0458Z98ceLZV2fe/aB5Cg4drw8TlzjK4TWRSvVQcmLEcXxUAIW18KKSHXTVUR7dG3usCH0QcNKR6vK4A7NN6IplXxYxqugiTXhXzyrDat0nJsysVik4Z2raUuv2WovADuGqcTSbUJXIucg7VEMrJ6yVmRf4LsrFly8pX36P/mj30E1aiL0htEsshti6kcO98P6aRjIK3lWWK1NBXpRCJ4VkBLMtM5jJQ3nDUg+8Ka/w6cQQHD5WckNtTykwqXBZdnS6ZwgDqat0u4Fdt2PAekvR2wakc0rnjNuaqnLoQHpbD25PihTlMgSKOGYtZvnjHcE51ia4SosdWs1HqqVPiNZtE+vE1JGTCn3s6bod1xdXPB92jEPksiF5qc6IKjuJ1GheXzoKpwT1uIZkV+NDxUjKM7FYugElM3W1bexgiJWLsTKPylUv3IbK7jIwDoHLxml79+LK0hN2l+wuB+Zpoao3xXXfPBjFkxZhmV9xmG/pfU8NTRKpZQM2UlUi4E4w7dUYLN72nGcvSrMlya5nx4iXSEGYayYvmeXQHjcXZk3UUokenFOqmqfbsJove2UqNv2LhyEK+y4yDB1DN3Ax2Aa7G3vEC12IRtNQqD4Rh2LqeOB4Cvidw7nIaVZcPiKi1JJxtW4F4+AHal5Iy8mU7r4nq+KkbrskwTcVruJ+OIrZj3ZhtkLpjvNaLGIcJedh11uP+WLwDH3P0BLhh76nH/rW5/fU4i3TzAW86+icDVItBWHGqRVbThrs6tnicM7k9NqQmYAiTQggG2Lm9wO620HfQ/BoqUipVuZPBnnq8URdjtTlhMt5ExmoWATSxjguFUqlZmthaa5osTii3MgomtU+RzbvtBACnSqSZ3xOLAeLlPJBUFcNdZtmysMBff3aDGSz3YhVF+MqqOnosypSZ7Qu1spd0bBiHmaqxSahXBrJL1NbQaK5osmI7qUqpTSPLDnzBF0p+FopFNtpecygr1g7dsveEzvHJl9mK+5+Nx+z9XAATn/HbhYviEpz4U7E6hi7kV3n8cHGzVxnRCwsHO/oo0c0kUshDAu+uWj2p0jRQBGhJGgBCqQMvk1Ir96/5/NzsVa6yOZvBA5XBG1pCiWZfYtl5p3Hmp2zViRVI8IaJ2wdk8ZsEfz22tH3hv42Uy5VbWiZ21AQ0YqLSqgQxDyFQghmsugi/pFjZK0LSmltNsVLpe+FfbRCsF4M9KVwj0BqooWUkYun8J1/iP8f/E0AXv8dkFNk6uFVSdwm2EWYkrZmn7nVp9YuK9Wu41EKxdlXzev9jznji7QWXLGe3oOD1w35eO/fuuVP/3cSv/jnfoxpLugu0smRXCYG7ahq93/InrlkPszw3zw+4/MfRv73V7/Eb44Lz5vbeFwG7o4Z7xN3COlw5BQd/+EfEv7CK1
v0d7eVGnpu88yAoWG7Iryh8iCVh0do2LqoLR5cgaXZXKxjfrXPyBVuWczNvz3pMfHZOGpGMVUMSUCE9Ljlhrb4JRodpBVajRf1uOmiqiwUrvCoFnqs5bly2nYucFszNxia+Rqo7/U8+TN/gDS4LfPw+cWe+9OEr46qgRAs6/QwLxzubR559XJh7wpFC1ptsyKds2I2yGbk3EUrXN4cXvPy5jWHZaIfPP1QOT4YWnE87ShF8cHMlEOI9GPPxUXgcnSEFmIeWpTYEMznSiksFbpBcKe2sekVlzsG6TZdVJ6NXC940HNW5hopVopxw8RZh2CN6XFt00oXGPzA09DzpOvZd4GxCwxt4+/VGZdMM955vBSqFIZOSS3g28+AZHx0hOJJKaCSmhebnikwrU07DB3j2DHuPMEL+z6yb8VnFwO7sef6ck8/RNxu1+KNMg+pxRVFpYpFtR2mI0zeQtXFDKV1XRuDt5zR0CjOYiOvegiroW3jNaZSWZK1qylW+JTFeLcAy5JJKRNcG7/VXsw1P0aA4symqEYYvDDGyBADF7Gz79Jc/WMQpOsJnTf6T/HGCU9xg4CnkvEpsahdX+/smtUyG/DQ1uISE87NlDyhRJwvJDK1pg2wWPwRL45OPGX63delHzx+5Asz45hx9r3y4KPQ9cowwH6M7EfbHY1tgR3DwBh6XHBkgp1kJ/ShowvX+BYVU+oRpzPSshfVtYpPhTW3slZzAVfYqkNpMHWNEWloi+xGdByQPlLEG/dqycjxBEcrgORwguVkYeTNiNX4VBVqRtpgkFI3P7Ca7PeaoSAbmpSquefX5CwHLUYCijt1cOpxD40QrhmWI4WKnCb0/g45HOlzRlt579qij3MG4FWLPKol4Za0hY9rSQgVqcYtqCUjuSC1bJ5OtVgaQM3awhKs5aaeLSuT1q5TqTitUIr55dTVq+wxf+xciK1+XZ9QbmIFjFasOdMg5/V561FxxqGiogsM2nEZd+y8gLM2TF7bpU5R580JXjwuOi6fFPpdm0xvPDl7cp1tx9h+5lLZN27F+9+8od4eqaOcSfbO4ZuObg34luyoVcileWpVEGe+YdssUtXYJ45zu7213I3X0tqUoSeG0NouspGJi+g2mdoexNTIXiJd7Iixtx/fEVYPKTpyUVK1qKdjqcxqKE8ZVq5gJOXCk2Xh1ifcEdJbI+Wf/Mcs/+1fgN9e+SPKFCs3s/GtxHmWDB2F3crbQ3F6zpHssd/XIqyeT4UR5htpvqV7UapnbDPGR6eFF/8W/Nnh+/ydP/mMJQ/EfCQPlSx7/MkKqvtcueh/gv9+epsf+7/8x/h3fo7/yYPjf/Hlj/nHLwxx7pcTO+l48yAkPbD3wsV95Tc/Exi/bAv6n/vPEilXdlghM6EMKA/eCqc1h+NBTNUYFcYiLI28nx7x8Jfzr9wD15w5yytHzc6XmAJUrYAKAlHb1medp9pzosARZcZc+xFraz4ORKcVjSdfGAqAY8EWOYBDzRw9HAvc4Xj9tOetP/9Fls8KYdhtbHnfV67cjqU6koIjMaUJrZm5LVp3t8rLXhFXCd4UtL5CDUIpGfHNV3B+YMknXt7f8OrujikX4ujYX3qW1t9dms+U3S6WxjKOPU8uI0/HirSB48VilXa9GSzXKuSaGaKnW4OtBw+5wxN4WAoFz1Qyc648dp4ycc3qO1ZtDqq2xqzcqug8nQ8M+x2D8zzpep7guPSOXqBfr7JUUp6Z9EQfeqoPuJBwTugbMt17Q5FMgW3xT1Uz6grVLbi2yRvGbGru1DPPl3SnE6KVzgmxcXP23cDVfsf1RccYBpwLFCyhop/PRWWtlcN04PXNSw4Zi69a1bzrRkMsbUOArkLy7X5UtlHssDmrVuPN1WrpJsuSyEsil/NcL7WhkUWpruKd/X3lwjpvQhuqCWnEV3zvuOgDu77DjzY/d9HTIXgPhUwumSCBpcYtUeV4OjCfTJkp0hi5WtCc8E7PkVnlhHMLQqbkgguZIr
op4gGSVmu1e08+/UtQmCVotgp1g9S9gASli7AbYTf2jN1I9AOCLVAOb+Z+4khFEPGE6Bi6PZ27XDc+eFmIwRExU05xHeoihbptYY2MrlRnNhGiYkaIoUO7gDZem/qIhEDx3ubFXJB5hsMJbegVpyNSF1xNLfOu2sVVQ6KkrupHK3A0GZG+5GqxOVU2Q9W0WU14ChXf3OD16Kl3oA2KrQ8RHx2aZiRndJlxwSHOUZqySRRD/8RTs8HxWhak2GeSldhvSgT7vRQzmM2mzNxIgKWiRsujtMiMUm3Xv97YUitl45mZ4keqb8qcM8q1cq4eqzRXBeZqqqjt79JEIk4tdNerOdvX1XSY3NAmB8URsqcLPZ2HpakaY+cpaTaD0BAsuigEqs90+wXfNQ6gihE+Y7GYjmpC81JBG3x299HM6dU9+XMd5sVrO8GqGIl2bUm70OD+ymrP79Um/HVBNnf6ShVrY1oNbZwaEb+JBMR7nA8t+sThMDWsq5Xcdp1RTXbhRVo2XGc5ns7jHrVPnfREJxRJLLUwlYFJOo7OmWQfCHNGgnJ6KCAnyltP6f/63yH/G/+IkOCjwe6NyS0spfAgcAxQkqnvZh+4aJPbNZVDVVryKr14xmZxk+TcUltRH31UrDiFS+CmPUYDvMqFF//zj/m5/2nkb/+BiJQAGR7cxEFto/TH57f417//U+jf/LvMv/BV+O/+a9SbZ/yP/o/f4B/8tBVv/7uv7Hg1zlw7oS+epDANhaengX/6hwx52+s9v/+XDxyP8EQdy3qPtAVsVWUq2lBBYWr8spM7F5hghZtFzJgFQVITCqgYYX/tlGRRTljxmqGZnjT0rD1mRXJR4ylZWLo9zsHWMnTArn3cXKA6uKuW+3nfFNm3Hsie1xTS54Qv/sWfpP6+57gCWRK73jbFMu4YQiaWSCqZKS2t2BIjpgGZkbtjZLcLXCyOXad04iwGSWd0MbxoZIsAAQAASURBVEahyoHDcsPrh1se5kwRR+gqcVBca4tqTRznheNJuLoSxCv9IIyDZz94YpsvndhCP0QIzubopXqCLzhpG9ToqFlNCKLCHJTS2WY1ibLpYrS0UfrJuakEYdVmeW/qT6+eKMJu7Bio9F7ovSe6dUMiHMvEshwh7iBE+jau11zUfRfIgyW1zM5eW1yx1pnzxJXXJgo5kYYTwy7R9RVR/YT9Ux89Qx8Zg2MMJh4LKgQHK/t6LspFLVxfPeXyyRWH6Z4yV6NIiJgKG0in6axeXQsZadzvT1D2LSknq0G8uRbmeSbPM7khZiW3+U4LIqHNbbaWr1BxcAKdEovNCbNkcEqMnl3vCWvweOcZMIPwWQtKtg5Eqrh2gWq1GKgQrHMwV0tVKXUhOllNFXCqBGfFWl4KmhLFFQ7zsqVWoB2OipfyLwf5f/2K+mhHiROI0Eel6zqG2NGFSAgRv94V6hoXy+HFoc7hg6MLPcF1mwtyDANDnJucdzF1nusRl42Qhe201k9jju/W5pIQ8KGjxmaoE3aWL+ZCy4tM6HSino64yRYCTUc0L2hZ0LpAMTK9qBEQt9ycjLUwiyFRhj4ZwXALEVDLKRNn2ZUVT02VenAUD3JqLvV9oEaPZlOkOECjZxELWYZW2DRY1tVqyF0qUBKqi4kQwPLB1ribWlq7taBr9QVG9i9KtprS3KlpX63dYL7QCO5KrWt7zgjwv9thbeWKRe38wMCXdTdrz9cGTTv1hoSuLSKt1jKpHp2U5WQWEEMXeGhFeDcOLBtXxKSWzhkKMeyU3XO7W4fvmlfY7BKasnEetEXotEX4eFp4+P4t4+feJRe1zDmk8RjqubJQS3dwWhux3xYXat1aFF7sebm292qtUafmLi7tPUXNCy2ESHTeFgsrBwhtUawOPD1BHE4sqmXNwDRnmVZ0pY4qlU5AQkffXdjGxsdH+XgWczO7xGGYePs/+E8J/8YvkRO8FMgNaS01GrleErnF8VCUUykc2o29U0dCObR7XUV5gidoIT
VeFhgiNa//L4CYQ/lCYd8eUzMkCt8Pji//L1/ymf9Z4BtP9zw5eHJ8HxntXPyxv/E10v/9P2VaEgOgf+P/Sfxv/Vd5uPsuf/IXvgXAV377lr/6s4X/6Pe+4Km7JGTYh4iPmRdt3HzrZz37PvD8F295/1AZxJOktOJbSWuMkmpzN7eiymGbz6jnYmq2r4RXIajywLlbsLUiOZP5CzAIeCdI2YwmtmNVV87YG2oVKkKWurVI+0ZjSCg7F0g1owKv3GLXA2MzvIrCk59+yh/+S3+E8LmeE96Q3l4ZGiHcK6hXgiuoK+YrWApxgNrailkDtY7MszBNyhzhJDa3ec1bXmsMyt1h5u44MWWHj5E4ZPq9JwwNrTgmplPmlBKJai7weyOED31mWOkTHhyFzkHwNo+HCjKDa+uBn4U5V9ICy6L0QViCeQyKK+0KrKiPtTJzKW3adnhV8iqi8Yo6IRK5CJ7LIbJzCbcJBdqJdTCXiel4y+InLroB6kDhPNc7hM4L0VUcjWeqe4IveD8R4rrpEsvCTMJ+gnF8TXrIJpBYVZLBRHJjs5JwrmeVjPSthbyrjtkJT9IDL549M28uzWRA5oq0iKRympmX41bMi7POSAiu8VxNEOSdzam1cVK0FkpNlGLebHZO1yL3vCbROkSyiUEUJ9B7ez/NhWVZeHCJqy5w0cQl0Xk0Cplq6klNTDlxXJJZKGGek5KVUgtVKiVVW1sECoVWqxOqo+scXVByTS11ZeHhdGBe45fihFNwRcg/ZCvT/Ysf8unx6fHp8enx6fHp8enx6fHp8f+L40caMVt3h+rOrsslOFxnocYxxk2t0oVIv0KZobn4SsR7q8a997jgiT5u7aZYe7o00PnOZHW+tx8nmwUBdTE0qGbb2oq3UPIQoOuQoTk473b2vNbTkpTQ04zOB2RpHLM6o8UyOzVba5TSaLnFzEqhccy2H+NWmdGntSTA9jhOgaKGmAmoCJoTOs9IU5+ImqFfIy/hcTAVaye2XU2t1dClZhxba0bSYu3YkDlbL7NppKVqC8BuHIK2vStFSQ1wWxpaVhpHaOPKVGvhSq3mNF8rDr9ZgWy7yUb0riibIzM/sONsZ0ONDb79v6GmsPYWBEfFHNLTKaG3GZKyuwhby6eIN4+jHM0EVA31cFR2Y8/VOzYiPwiFcirNisCQBxFrcS6Nazdp4uPv3fNj6bNnzyMHrhnrruotpbSdWrVwcrUdssoZPyzYLrOKOW/RWpkdjU/GGSmWrb3ZkK1G9F6Vem7lZ3kaUmbiAe9ce611F+iRlPFO6bqRJ+OeZ77jCQ637uTdwG265S4cid/8kPA//kfMqeN9bJzW5sRfXGJ2oEW4YOCOiXtgdMpNe62uSUNWv6yp1qbGFKLqdo0ENi8ws65bcx/Ou9BAR5CFoVQ+ehX4s//2h3zrf/gUnwdq7Tg0+K1/WZEl0feeu1k5fP2rXPyvvs+zL30e3rHPfvzmif/e9y5590+84d/7uec87QpLGpDxhDShx7N6yasvXyJErr96Q3mZmTEeTG3ZpHYvwnHlxmGoX1ZTKDePAWNQKHg1xdvH2CT+gxO5byibAk8VXJXGdjrPm85JU9+a9YbTle9WiXo27oyAb/fabc048UxamJIQm0XMXXzgnX/lKX/sv/EV/OU1OXS82wXufSUsxVrxQPUFSiRg6R670rNfHMusLM3IORHxBEpWTifhIQpQGGLB1UjOzUR7XniYAksdIThi8MSoDIOnb8radEhM08QpLahLDJ1wPXRcjIGxW+jraktj6rngjA9kKEw0AUIb0FULGUfOleAguIj3QnBnXiesaI51DlYvs5IrHX7r6bnoqMHhxTPsdlwMjh2FiODc2eIi18SpzNzMt0QCebxiyWkzgLbHVMsnrsnOqmaCO+GDN6FSQ8NDMELHsEuMS2LoA/dvZm6naXP1ryQTC/WjTcouGo0GR9dQz514Sqic8o7DkydM08StE+7qPWVYOLZelvce1PiNyb
fOljd7nZVrZ7wwJfjCUiY8Nt8HzP5D3Fnpaj7jbRxr3Vr8jwAzilpoufPglsrx4Z6PTwOXV1c8dRft3sjMtZKVzcVfayanmVNTw9+fJmQRuqDE2DVOjIB6G8ttLTqobB6AWjCuZJo5nmaOS6tJ1G46LZBO/FDHj3RhZgBrpa3LADhXLWLGr14qQnCeXRcYelug+s5tPCopiy10awfMKbEFyHZ1IPhM1w34sOCigOtaUPhatIDkxeJccraL7DzSBdxuMKIbUPqx2RgoNS/4aUanI5zmLYyWPBtEaqx4Uzg2kreoOxcWqlYMrYVSq+EetzLWxUiwMbTyt1xVK8pWYq+raHVNqYjFnTTFy7qgu6LN08VIrdLamVqT8RM2+N0I9vZC9jlFVy7UCkGblUFpRVmGzRxzy9FVgVJJTk1MgLncV83Nhb1u5+GTEU3FlKiyLkErEZeNfLx+EqXxuNrTE62ViKCLUu8T6faBcN3TN4XUXAvBe6IORv4PhRDMOyy4kYvndteFXggH46pkMTKuq0KVuo2zxSkvv3vP9esjS5mYlhPBqxnwVG0aOSiayLqYCEILJTdV6KpIbd97i2dqggiHteiDhCYoMD6G+QmZQ5WJWuwcbIWgetDUOJWyCQVMyelQab58PiFFEafGl4kdu9ARpDwap4nrbsddrVz8H/4Jh6/DgYUL4H1gCrY58NmK8zkoNU9EhCgWtbaO59tWcEZtTvXALVawBc6dX2k/HnvN9ap72PhXlYVLhQs8hcyTvw8/+ycnfv4POa4f9jwrbwDYv0yogzkXDu01vv3mDd/5x2/4/Z9/B4Dr9xKvPn7Nv/b3PL/+5DXf+cp71rqrg5k7A/5YoIfv/4n3OLy15+rvf58nHyRzeg/GiQGIWPzTvbcxGrOdygk9B2lv300p2u45bUax/M4WyMoPi6KUxsVbEwl8MzGLazHoxe51tcVzXRyG2ophmkVGNC7bsxp5E4wf+4X/+hf4Y3/hJ9EXl2af4Hv2Enj7siOcTtxoi9YqgpsrngxSmbVjXgLLbOHhALMKVSPOK6k4Dot9oOwg4Ci52dJgFgolOHwHoSs4l4lR6EZ7TBcq07zwME04lIuuchULF9FU+1KbOEvsHHvnNiUwWiAqvo3C4APBV0oUQhRCFnx2+CA46XCtiBDJVDFDW6VQNVM1mVCojYneB6IvVBVCFwmxcuUiXguLLFuA/JSP3E63fO/1S6Z55snuirefXdPHgaHRZFQLS8nMy4lSLOUlePMsc0HXugwfO3z09BoZFs/TS+HVyzd879UrPvv6BQBfWt4i1dnWn6C2eDiPWzfvwK534BypDOSrS3KaiQXqYeGB+02Ato5XYxetmQ323CBrK1PxoeC8WBJMtq2fl4oXs42x18koYqkISCvu2CgWdh5ss5KqrTkVmNPM6f6em+Fjxvael90OSZAlsBRBl0qstrm7aQ4Jp4eFUyqMXc8lztawbOaypZ7jsI5+sQgvmRDMyH1OJ46T8rAWYblhLMXM7n+Y40e+MKNNZuu6m72YO39wNkCq4kWJXhlj8x4LfpPKNeP3trtJVJ+RbYcRiH1PnKOl3gePc4/c9cEQkdx2VasTfXC4vkfGEW3eKRojXoyoX3NB0oLMM3U+IqtnVVroaqsyi/W0ZYWSWhEGmHqxFTwrf2mVt69r4kr4rawg1tkiQkrdqiBD5jLBe7RmVMRugeDNMAqbn3xLFzCn/wJ1oZaCc3ouetrv58y4hmJVHpnCNk5ZQ8lq+6weNg82uxYmqHBYIaK1NvSjbP4wpRTjIKhu6MTvPKw8taKtTS7q2vu5R0ICc9V2GPdMJmW+eyBMe/rrS3ulms2QWALeO5DF4lI00Lme6xf2+S+uE9NDpZYZ7705s2cFKtpQ2yyZ1y+PvPnwFi5OzMuB4jGiaJHN3dwm9WWLKPFI45T5R+qnSjYTDTtHG6fM0LVNtcSKJEor0Nq4aYU9GMnWaW0o9HlniDq0us2iw+gWHhGlqJ
iDe4gggdpEKtIpuVOef/8V7v964B5TEk5tXIbcvKZoxXo2rlNBGTQAmdCmqCOZUU2NGV3kVDNHZ6pQI763cd82B+4xm6oVdD2rSWfmog3sSyfcV+Ur/8GBf/xTV3wYb/nT330NQPdr8K3KZnExo4wIdyi//m1LxfjKi0su1PFKlD/9n32Hf+fH9+zeu6QuI765g2ucGQ8Ll8NI+lLPR08cD7/whqdfv0Hns+P+KI6xgm8FabVbiMQ56zLa19mQ5hG7d3qF0LiGcJ4XVsWqBt2QzLBqD2ohtEK3wzZtG19ZHxVm7bz2avfqXh1PU2V6b+FP/pWfAuCdP/cHyOMlHR2DwnjRE/uKlhkdRqQlOGhKuNnhMSPOicJxKsyuMjfYc3GBhyXafIMjV2UqGc3QEam1VRoxc6ozc14Qp8ToiDHT955xtC+Sxonbu8yh+WHt+8ouHBncyBADSxMvmJ2HFXbURKrJooZcJra8Yw2wqCOpEKMnFI8kcEFMJONa+kwj/Ztjvm2XxSklOnxLjIlV6FU4FuMWinPsrkbG6pirI7fN1H2OuOg5zhPvv3zFm3DPlO4Y+z2XO5uXog+UkljyQiq2udbqEekJsacfTVW+H4wDOowJH++5v1W+s3vg+x9/zLc+MM/KP3h8h2l6i9PpxPi0a940HUY4bCIIMtFldk656jueDj2nLnAYeoL3G/l/kQnvzQQ+VEWCQfHqPbGtxd7bmFVXiTHYHFUy6i2Xd+X3qbOIrJwTzntD1uATRSDVONKj2PpSAiQHenjg6B0ftg328TKxpyNrYFqEtCglJbIU5rIaR1duTydOS2JKmbEVk751gUJ7rayJpS6clnucGEcxl0JKm4sUD3PjUBdYPWf/RcePdGEWccxURvX4Nrm5WCkyI84iFQ7THX4U9mHHqS0qbhE0VULIzFqZlxPiBpYZLp5GmngTRRhnpe6D5WZJtQXIR6S5wZd8IKbZvLb0RNSK6kIee8quJ3qDT9dYIXKiSwt1SuiyIPmIy7brjFogTThZQBYzAnFqQa01bd/RO4fiyKUa8vS4S9cO4VzwmNrKcCJD35qTOTYZW9yHWgsWI3i7whbvYg01aRYEQq15EyM4dRtJGO9R31pxS4ZiqiDkHJmjK5TXPuRjc8wV5SxeqFIslka1+QCJtW31HIBrlhCt6KwCEszugjPa1x7RrqaydTTFdpq1taSF3mjwcUFEWaYDOk8s/ppjau1o15sPVNCNvCqL0skTDsuBy51d67fem5i+HZjYMQ8PhFSRUMiiaCu49ovnzfLAy6/f4n7/9/DJcb/c09XnXMXKSvdW55A8oTrYne4C2vX46IihLSrFoeJZKJjJo3nt9J0jxNBc/sH5PdF1OHrLftS5FdsOXSNGlsTDBfTe0w+OEEYG71niiWdzZW6xJh7L6xQnzH7mVXrgY/+U3dBv7tnMidj1pJ//DT5445h7h5vhRGUBBllbROdNhW/oTV79y9p/ozMy9oVTsiYugVfS2n1yRsxRI8O3zgexteuqwK6de9duhWu10TEBu19X/sL/6wP+N3/hGXprL/W6CPcNe71pCtag9vjcFoa/e3fizzzdM72+5wuv4I/85i3/9N23eVqUQ/vsF67j1GUmFp4Ux+17zyh/fuCjX3fIP3mN/8je79gC5Nd7JWHfoX90yyw0I1qsUOqAV+3vb6uysy4fJyx6rh4sW3MnUEfPVJTrVqH5wRaw5x3wAM4cHVisCcHQbmwFJlWWtrV5mjL5x4C/8vv43J/6vfZaLwYu2TFLpZPIEDw7H3FdYMmJ2Cpn9ZF6XQiLtfZzLdzqwqSOKduCfkwQhkSSVng4GERBMyn57R4KZEMIq4dyJISEdjMaMhcXdn3uY2WnHXe3M/Nywpd7nHYMsUeYGFckv3qK2PxS1ExPq3hqdfRr0e8LsYBqh3OBSxEWJ7weInG+I7bCrPeGsNSCKdlX9XgqEO075j7DLNQ8c7dU3tQdl/EJ3W6kk37L8N0dMy
+un/HOix/jg/ff8J2XH5GXzOWLmWkxZfCTuLeEh75DFTIdotD3zwhPnrNvXZvLfqATz26cGPuOMhVe3px4dfMd3rxvm5Fff//7fP75nuuhUu/2jHsoi8PpsPkh1mqFX/CWJznuHVcXA8cOfJeYSjNpPR7JVDQBY7O8KbYWbfZWakhuxdG7RG2Ukloy4urWrai5MOVEVFuzzDTbmY3SuogEC25Xb8IBX+0mOvhCnU4sL2/sHjoUTrsd3vcscyYXx7wU7g+nc+chzfSu4mtlOcwk5xnG3rpttXJMdu5nX3BzRqqtoZNkkjrmVDdxRtCmHJczPeBfdPxIF2Y0s4OibGqWIsYz67uIOmFKlXJzT6nC0wtbyIZhMJ8ob+HEWROKhUCfNG/O370bmNeCTwLiIt551JUNJao5gSTE900pOVH9Hhd30O1a29NaZxTQlCmnCZ0mmE+4lAy1gmZouw5YMSf9Wls2V6W0CUk3mwjOP5xRM9p/H7c1LNOw/aUq6s43hu06DHkErD3aHtceBNoWjJopydyPwV5nixypAtStFFpfX2vdBKUr72xtX54VtbJ9Ym1f6jFLbI1WEs47//WbPo5k0tby+mcZzJ5jm6ztKe09Pab+3CbRCfKd4qYZ6Uyen4tSXQ/OtxZxz0TGM7OXEd92/Pt3nnP1+ZHTd75mrdfoyUvdDBYBUwqmypuXD/RvLcRglhg5T2hIdC2uKHQRF4QyF1xrEeGEII6x2k64SMJF8MnGvxMzgvW+I7pIWCOZnCk1ZUNgPKlCyoVTiwqZsnI6Vi5CIKnS7Y4U95QyV45eSfOdXWsRaja0Mp4SD9P3+e1QWd6+4kWLuXoyjjDd8vBXv4a6TDpBe3bTgrYxSMuDfHSdVkx6ncdqtZa0b2Mnt98ThgA9buFXzkHX65wdlS1Wiqp0YkVvroUgtkP+iX9H+K98TvjOVTOiZgIHN40mcFBwKGNTL4L5d33z5oF3rkZub0788V+84xs/uyCXey4aZ2p2Mxp6xiVzCEqXC9p70h/7DA+f79Fffmmf8dcy8Vbp2h1UEQJChs1iQ7FCreN8j+/a38MFXHyujcFBid5QyNO3HKETht4RSqLZcvHWJQxtmppeCelBmQ9CdxB2i2x0gEN7z97BuINf2cEv/1zkv/hjL/hMO8GXFfZXB7q8owtC7x1RKs5Xumj0ETv1QuwcPgI5c8rZ1JYtLNvGqSFTx2wpBMF5EKNQ4CJ9Q1uK7um80HeFpQCa2I+FZczkU+PtDh3aZUqqvL49kOpTJCvz4Z5LP278q+oqKoY6Z5FWLNiGbm7KwJJAssW7uSAQPd5boXGUaNFsYEhzs1Gi2obQ+L4JvwaeS6DzZjkj1VGyp86KXFV2vrBv5yK4F0gMTGXi+6+f8ur+Dd/94IZnFPyz9pguEPuBg2aQjugGuosLri/f5r3LF7w9Gg/wOkaCF+Y00bmeuxfCZz974P0P7+jafX348Ht8591rri+veTE6tNzh6Joqvm1iJVsRGgqXvTDnkcP+RLcPOK8NJTTe7iMhpc3f3hnXegVJfPPZdIp4IQgWn6UBamJZla5Yvz43G6WKrQdewtlAF7Mniv2AZGccCQo5w3yakDWvLJvVVAwjKZvJ+bQU5sOJvLT3W2akVBLKUgp1rizzxK6LOF8prUOTdSGgZhyu2Yp6bbYybVKquQESjWf2wxw/0oWZAjgL5l0nGw1QOo/rI4ij4jlNkMoD02Q3zn7f03UDLnSUKjivKB6vEdFlc+zdj3ueqifyghBC4y4ZR2fd10otsMxI9EgKKCczCPIDIsNWmDkXbSs6JTOUPR5w04Sk2cxmsTYSUk063AoaLVhbs2abmFgLs1bgsOJBK+fu0bkBPlHdwMYHOxs9NaI+nNsg2IK2Gr6auZ8hYbXY51DNBHHUWliDrbW5rOv2wVoygT5qlT0uyh59DGvVnRdTwVn7rLVgxd7A2m/rlyvZPrsqdqs2uxJ1m41J1bZUf6JYWx/NhhyuMI
qIWGF6X5g/Bnd7hzafhSpCcWFbQERH/CjEmgi1wztrSX/uS7+HnVzh7l/zndsTkwPxAUrduELJKXOGu5cHLl979NnEkgXmIw8uczHYWN2VAYlGvh3cgLgeAYIE3Nqaqw7nMp1EsgTUexBPDCMh9IQGAa/wv/EVrRBWtYDg1V7nmMzYM/WVOCt37sQ7/VPCLCiJfauY7vOEiGf2jhlFp1s++M4Nx8OeH3tufJX0rif+jX+K/gNLByhiAoKltQvOnK9PFmXrhkLWy0Jb5zhzxxToq1k4PB73gvGv0Ecddqwdt+ZNdrS2XHOj79W69snBn/g/f8jxL9jjbgZ4PTezlQBjbhYUwraBigrfXpRnnbnpjB+d+PFffsXX/kvX7FsI5qSFKEKXlRwroTpqAn+o0HekP/4uAIfPL+hvfMTT7yi7N2wbjUqzs2DlkTkylWP7/659v7dewDu/1y7Q8CSjHaTQIW9n6utCPRXKAt3b7bWejoz7zCyJu6iUj2A/KTsPdVdZO4ZhgPunPd9765J/1C386rKQvln5qQ8f+JnPNzRG4dnus5yWSghKkIKT5nivibBN0ErnM+iRovcsZaFzgRp6VI3fI84h2bKEnYt0YaQLStUZ1CxS7INd08tAqY6UgXJgPwh5fOChcd/C0OG6I/Iw8PHHR17eHvjC80vQSk3D2awtOKoqWQspZ0o1vnBKlu0LWMZsDY2xIXjvDM3DIy5a5iMAi3UlmrHsZvWwIsltVFaxDVquihcYvXDpdlx5j7QrPnbKfrwgli+SvjRxf3vHd3/7FYfXd3zYTml8MfJ8UXrnOHiIT0cud5/jrd0T3h4vedpMznctAWbX74nhgttj5Z0XB569+x30fYOJ99Mdh9uPuT294HKX6Iq1RKFQ6jnwy9VEJ4V9EHIXOV0O3D6JuFhIc/Meq46qzfS18VlFrOMRWosyRPC+oM6822y/Y/5qWj1+JUSSEfEUyWb0XYFqtkB++1TmtajY67vgQGbEtXm3FYw5zdwS8K5SMtTimVPmdJq3uSnVwpxm0EhN1k4/+UquCSdlExxINNGc1eBCabGIj2wnmWs1sZP+8DYYP9KFWcEWemm+OwDSC9pZlEyhWEvLdeSSOJzm9rxEWBI+DKSUkViJYeRyuKSWwsNpLZQmHnwklxdm3OkdVN8KBXf+EGmhhhYfU65QiRRR47mshVHVFsG0wDQj0wGZD7BM1LIqJAtrFqShZNp4H0YwZyN4n8/BVl/9LuenskaxPPoMuhZn54lyJYzbYv2IWF/Or6ptoqJmtGS8rHv6ytZb1LoVYNJ4crUUtFj2HjSRwg8gfFbD1a3QKK2SkxWN07WIagrUVaFWV5K/IX4WtOw/sciLUzPgXNEZOZ80Rc7ChXZKRIyTJQ9KeQP543t41s773hRhLgREVvGI0BHRLralEq53mfDOZzl+7vfwwf37pGLRQVrPyQw4IVEphwr3A/XpgVxAs7KOUoC8TPiu0O127JyjV4fPQomO5REj3Ks3JFDUCMw+EN1AcJ6w1aPZ2hCYyrasLWAslgQgFY9mOMw3PBsv2E/XHOWBMlRCqPhsKN2ojlQWalF89qRJeUgPTHJEqn0D/+yKz/yt75LEzIGrT7isjLQW5OPrdP4qGxK0Fu/rUbC2QMbal1EcC2WLY8JGQUOQrZiXVtQLbOkT0TksXQJ6MU+wGASpBf2250/97XZv+8BTPVGAQxaSM+Ruqmf/wn0RblHev5147gLfcpnP/OL3ef/P/p6GINvnjFlRH+icUsqCX5TZwZIDJbfYmWu4+6MXTD8x8ezDTPgAru/tZ9/4KlkxL8FquZmDOrxA7CvLAHcN7S97O4czhf5K+e4tnDzEsePqqb1fd10JfYLieG+M6NXMPMCrUbnZB17trZX44dXbHK/f5nCIpMP3eOfNb3D4CP7mb36LP/yHnwDwQuE+vea6v0Jktg6DB6RQct4QRh+USIJY0JpwNREy9E6g8S+DA50dU7WeamnxTc
4HgnhCaxlKEKLrqSoccyYVR9d74pDIvvmTEYnRU2vh7qbw6ibxcLznWi9J+URc53FvaNlSJ7LO1DJTcyEvCU2P5vraNdV4NV7rOieL3wjh9p/aRDmW9xlUiEm25AxpBergLWpp3EUuLyO7MCO+p7jWk/aVfVW+cP0c98UfY1q+QHlzQ3+buTw0IUG/UCI8vXhC3+0Yhj3vjhe8GEaedD1D43yF4InBNY6s57NP3yMvJ55+/IzjZKjti1jZ5zeU44ekp0rNPaGzDWxdu0frxrQqnfNc94G76ukvoLqZaVkZ7qEFjp8RcFiL2pVjBj5YALwXRxAhaTXnTe8JwW3vuXWWVNs645HqN760a5w0pxW81QCIcYGdsBlfp6zmKVkn45dlU9oucyZXG1s5F9vwpYU8OxDzUczVkhLG1pIudWmJG45ag234mwq3tILeVeOo1qrb+vgvOn6kCzPFCrMgdTOYjKbtxoujYjeG1mph2Otil4SUba+91ITzhd2g7PsBIVAbQ28pmew7jtPHlDoi0rUJxm1u6lRPmTIaZpwusJzQGhC34B81lHWNYFoSMp+M+L9M6DIhzTBRi+0wrUpwjdxfcFie3A96p64RLP9cdPRR+2wtg1TLJ6C1Cmd4WVb58A8UZqobOgVWFMlWyLRCx63QfWuHlnMUylpArohZka1D2hC0Fr8E+GYuu0Z/CGzQN1I3Pvr6Ny+m2tzgcc4RQ485bPY9aCyGxv8o5/Kgtl6jqCeeMt0J7l9m9D1rZbr+kqQngga6MCJloRfowiUaKz430Yi8IbnK9Vuf4+l3rpjv31h8SGVLlcBp41IU9CHCcm4ra4UWukB11dzYB+NOeQQtnpICLqwCgYpqj8Na4eKNN7gWmZvVgmZyddbObC3bImJE4YaqSYUuCR1P0QQ3bxU+W17wfLIFrwYrPj/QRKXHnY649DE5LryeT+iHBRr3Ra7h4hsHtFQe+gE5JcCUU8oj5FPOBdm20TjvG9ZhaorC9pSkpsisjbux7WC3OAS2Fj/tuStoI1aXEVjd75VQIPsdPk64b7Y25XJCI3gdqHnaWqa9uk0RiypRlI9OyourEaf36EdH3vmte978+DN7SE644MlVcCVTCiy7gTm9QYeMn1s82jQhWXh1Gbh/z7Fjz+upMB6F/cGu4W6udFMhToVYCssdEDzD1QXhuYe3G09rn5GauHlzx8vXjq+/sYL/eVx4r937Lx9mvgncuMp0OfPmCby6HDjuRtR1uNoyImtkFOVKEgSYuOY4Trz/q/f8/D/9JgA//Re+zG4u+MGKMnwB1/ioLVXDzn3AacU5Rb0jOkFdwTmPb66dixOWIeKSQ7OCeJwLhBgJrse3lmHwFVygSKUrE6lUJDvckPAtaYAOwrCn9onjofDmXjkuicPdxPCMTWVcs7KUB5ZyoNaFkmc0LaSUqKnxwlTRGnH0aDPPLqVyUkPuV96Ub5uuLEpSQ0t6Veoj65rOZS6kEFMiFOHZLDw9KHEwVTnr55cOX2euxoUv6RPqez+JfOlj7n/5ezyRluE7GcIktyee9m9x3b3gehRG74kEOjFUM7rYDF0rTgrPxwumJxe8+/ZTvvaRXethgBfdkX75CJYdpd8T1Ns2p4158a3ZXhWRyth7xgx9UMRlUptTe9/jJSLVIqRy22SbBU9DzLzgvWX2em0iPm0pLZUtW9SLQ2puc7+Jn0x8FM/0kAaxh0ZL8b51uJzlEa9Tg7qOTjtSLuSlmgo/mRn2alAuKgSUVJVZKk6EUK2LtYhuHNN+7bQApa74NhvVCGzTrKK2aeaHO35YZO3T49Pj0+PT49Pj0+PT49Pj0+P/y8ePNGLWwBhEhBQbzNpZ5INKIafSQuDX8Ov2vGKGnlUzi874YJLoZVroAL+q2ETwokzHW2oR1EXzLaNuLSkhoEXIxyOxTMg8mV8UFjdCXh/nkGWizrPtjOcTpAktMzT41Gky6bRKs8AwuLxWM0PYvJoaCLJRuX4XeNQaVI
94Nhu1rEFWsvmEmMhAZUOZrCX5SbTMrRyu5qe1cr+8O1tOSK1mSFEtr3LNX2/aBTv31a7Z6mG2NiLbVdre38xls3HbrCdyDuded9/auHDUpsKzVyqqZ5Zla7mu6Jy9hTXLqtRPWnTYu+BFcAm6BXgAvW8Q+vVM5UjKgVAj0TlC1xHDQJUHHE2xOBwJ3T2Xz9/j7bd+jJvjA6mczCtp9R7DIppKOrHcjfRzh+/NP8zVddxC9mYd4ZsLb3YVasQtbfAD6oTsAl0RoGwZcgK4WthEFSi1OmpxhmbWimKGiXULkBfSDPPuDvrITz18gX/ze19B/tZfwx0OUAzKe/Uzia/+kYFf+1LHxzMMx8D1UTgume80L6CP6vf4whtD/OoyW2vdOXKt7Vyfx2YVNiXlY87Z+hht4+UkNmlZALauQ3hD2NZjNVctnG1juvaYUgyF7W2oIV6oRXnwR1wZcC+uAEgHgcNLXjIRHVSNCIlZlbiONXFEhVkcb0rmIjpulsr46x/w6qeNzBWOlRQcoRgfRcYdJSs+KRG3KVCzm6iSYcoM4hn6Qth1cDlwbK2TZdjRDT0Rhy9K9IG+u+Ld6y/wlttRxc79zfIxD8ePeRO+x8//5td5eV/xCl0P333Z9uPRcVczb3YCFxB9xwt69icIrmzh6gDzoDgfSd2I766R4Y4xKX/rF74PwFd+5h3+/Gef4xelHztcCMbVKmWz0bHztSClIGQ6Dzvv8L7SRWVeleLieVOE4qwl1XeBvu/xQXESaIJoQhBSnQjS0YeB7DMpCOMuc3VtaOXN3Yn55OguK4f7hQ/eP3H8qUsogs6JhJ2vqoUpHzilA1kz5GzZiCXRLNiYslI0EtxAIJCKckqZKRVqKWcOp+gjgVbz5VJwvmyora8LcVm4OFoc2tPFcVXsgSIV3/Itvbuk6lNcvGBP4ieeB/Qz7/Irv/kRscHqb3UDTiIn7fDa88xdchX3dHHEhc7U/sCgEdHR+KfS4dxrdp3jnX7H15rY6OQLu71n1xX84Y6y25GL2hfY0Giluoy6RClHUsqQEoP3jPthdSlGl2I51rryJet2P2/5vQLOmYec1AxyNsUuVTeqifdmzeMqTaRhvF1xYRtbpsj3eBfMINsriGuILCb4A2LoiepYfDYeJMnI+hRSPWNVSzL3gsDKPzTv0FIqNLGU9h1erGujWinY2uj1LDaqTUxieXc/XC/zR7owe7yg10bQTr0ydhb8ejqY9NnJIwI5ABaOXVDmKRN7uBqVNCXEBUJceUe1LYTWaHHO2ovidOu3CwadllRMxr3c43S2Qm6pEKzf7mRAlgSnCT2dYJ4hJ6QkpKzh5DMi0Yqz1tqreu6rP+btrz8/eKzDav0n95hjtr2ObkIClOaKb4albn3dR+7GTh+1Mqlbu1JZi8KVN3Um62vRlkpgC+HKA2p2LptFwrmdJdt3NZ6QBQVbMaHN6DQ0EukP9nTNHLYq1gaQcw9rVWE+Powgbny09XyV9n3ECbUYg9Bl6Bd4uLfHTHc3lGsh5Y5YIm64QKqnlBO9d5RgLYNJbhH3IZdPf5L33vsiH758n9P9kSKyEYRLSeAhpcTyUJHTFaF/RcWyCVeyd/XSJoNiHDQXaW5uZt0CBud7C2S3M7Ge7HYN/XqKBVcLOS+ULJa3p44pC9Ns75hKxunMkhfyvPAXb/4g7j/8VZa//WuYV7sdT/+R8HP/tvKn/5Tjl/6VyC98MfK9MNIVR703j6/Xfs/1bErMUsyUdK6V4QeuoMi5kHKcNxPl0e9mMqycBPZqG5/s4ZFAul3v80XeqHWyOuF/sj/aQiosENoJP/77fhb+3L8OF0/tcd//Dun1Ld/8B/83Pnr1MQ5l8bZZ2zxI2zfptXJ/OLHvO3Ys+K99hN6a0i162yioKCF0TL1jPtxj5hiCX+N0XSDnewqVJUdct9Atds/5ub3P7OnmgBs7aufJdQYcU39BHZ5Qyn37+hM9lWfPYNd/HbztVQ4KXzu21+
oqzxWeeaX4wJPsiL3jFD2OjsRqvio4LZz8PZdlQMfIsYyU+cjd+3bP/rW/91V++i/9QeJ0iYsDfYggC1VyU0CfhUtRMq4ITh3VKQRp5ts2nrN6pjlRaiBIpXNKLxnvWztzLVqkUkrGq2f0EelGDkUZi/LOU7tAp4cb3OnESRNlXvje92/42od73rkK+OAIbVdbQuFhOXHKybiXLYO4KrDYxnlZlFQmnHSoKKfZ8TDDtBh/1DX1s8cERFIVVwtdKfQlcXkoXOhq2n2C+3vi7YSEhV2KUD087KArUL7dBkVE/RW1u0DcHbvdwuevK+mtK25/ywQOQzownxJxr5SPPuThcGScPk//hS/i93tiU8SK8xB2RuyqiWH0hBmu1G28sZs5caqBJ31H7xeoCdGEkQian1vTUKtkMomH6cByOBCLcnFxgV8res04xWYNZy6RDrvua34vGM+1ihW0tSYg4JynaNpKOecaBaF1VFVca136Lb9XsfnVuw7nhC64VrglS2hoxefY9YTQsywLoSuEuKDOM9VKPrVWJIr6iKuCL5Vcsil1xThjzUuYAwvBOaI3rmdpgETWs2H6ulZKhVA2dOCfe/zIF2ZV28lorv4aCsVVjsvE6cFTpWxmcKt+Q8QI9VkrbRwQQt8WTUeuq21DsXDphtQ457boppkz4pQ1E71H54U03RLyAVcdrnjyamBIh6YFnWfytOCXGVKLM1rlxSVT5Ae60PKDmrXfWZB9whbjn3GutBVXdt50sxepWhs3zGM8fmnRRistvyFTcBYAPPoAVep2k7mVw1bN5XhF9coj0qPyiE+0fcdHHLL1Ma3AW60vqpZNHLA+9vFzVjTtd4lptsKyLdTyz7gv1tey8KeCx25AKTDd2xMO9wn2J0rZobpjAi58pHMJUVjWSxc8Md5z8sqzF5/l+vKa18ePyCpoXhE6kymkCvOxUKfBznM1hfGaf6soWgpJCrUE41a4QnIVStsBihI0UX1HUYufklU04c68BhWHK0rNiZytsFtq4ThXpoZyZc1cuMiNV/7ib8Gle4fj+/8uUOiHgO9s3N+LWpD2z1d+5udnfvIrM3/3X33gF74E9026OcwH9I3Vhw/RE5JZlBxRemhYxe+omzdjVNGzkhKsUJvE7CHWXXijSG/f8Txqz68rYnwRbbuAxyatzbqP4XKP/rk/b+bSH9uCd5hu2d295ot//L/M8NW/z/e/9VtclI4vhoVX7bVfOSjVCotcm50H8PS7Jz78pj3q9qff4nIqzN6hQZjTQomVEB253jM38c9pcuQ5UHXhQRY8I11yBBfQoSEavUdI7GfoamAGtCZO9w/U2BHF+JB7f8JrosaeLz1/h+/uvk8+BgZJlAYv9l7p3MhJJ4Ylc4qVKkJXvC2ijctVg7DEytV8zYOvdOOAW3pcvxCTzV3/+O9N/Ec//hH/5lcqzpvfXnWFnE+kkjZStVJIHGHxOB0oBHISat1zavPKXY7c3p9I1bEb9vSd0EVHiILzrhUGZmGEE3wM7LvBOIApQyzQm0jlxW4hPgt8cLqn7pSbY+ZXv/mSzzyZ+LH4lJhawdjBQzoyldnUdamS8wJa6NtISlOllEjRE3NdeEieeQ7kHAkSt8LM0KACVfFa8TURy0y3KH1tJud5ort5oH95S9FC2L1ETzPy3MM+U+TUxrXDDyM8ubA5+j5xdfct3nUTHBvad0y4LkI4UJ498NGu8NVv/L/5zB//ObrhT3Dxzlv2nn4guw4fMlLuKeXIMU34WuiCzSWH+cTtg/KOH4i9R50nSID1h9ZdAao0P00ppGUhT9n4Y/0j81gVijPVuHeC836b5+xebzYqNROjp9juE/HY+XSruMLWedeKs1wb71rN9AiMh+bEoxJwXuhiZ6kKzqLjhoY6j+OezvWc/MQYMkcfcThSKqS5qTJzwakni1KDva7UQk1mSp83kRuGcOII3iFENFVqqZS2MhdXqKr4Avmx4umfc/xIF2YDNrkee+jM2x
MF5gr1bk/mhFTHshQQ8F0bWNkk66GqFQ2lndwu4qTDrzL6Tug7SD6Tc7YdmlO0doRGvBSxyWxOZoKoakpKlw+UfEDzEwBcvqGeJigLwgKabSdSzk7t4HFa0JKoJbXFtV0iPZcyfv2ian8u27PPx2pOuQowV6K+UkxZuWZ9tkFVOVdO4iwdYM3mRATFgUQMZ2EjSTr4hFng2kUUbe1KtZ+yFRqsiU/2+e0Nmmqz3WA0M1h7VxwRRzAhhLAVX06KtQ1UqZpxWNtTHiUCfiIi6tH7e/sS7R0wY8GWeynRkZIwqHIUaMlBVBUOPvFU7km+J5aBpRZc10GvSDWEhIOjdh+QUyXurnn3Cz/O/c1LXqYTt2KIhhbIk53AnB9I90+4eu8ChgemxNbKFOAhwvPkmfxMCJFBRnyJFikFuGCqYqcdFeG+nlBN7FxslP71Oiaqi2RvLuoue2apVFJLp4S9FKbLgc/fVf7ybzyl/CHIp1uiF05kxmMba6WHcWZ+Cy5OUH8Z/vyvwp/6suff/a/Zu37wwtEvCUS4lMKhA5ISNZDJv6P9+PhYLYdWz7rVRHIq8FIqI0YmFmljfb2+TtCqhNYU9Q6i6iek6kkcUc3wuaswo+yvdnD3En7mD6E/8ZN2Ln7jF3n4P/1V7r76fZ597rPUV694km/Z7RzP2u3z9KbyYVU+bPQFlcoz4CXw5DfeB+CjP/AWVWAZzMP/WA7sDkpdZnISTlNDZE4HQl1sE7MoWhbmYItUf7CT1S2eMPaULrZ8XKXORz66uuMLx2eMXWvD6owuJ7QWLp50XHbK7TGh1TI6AWQRuC+UqCwd7HLAFaFEIyuvaR35WOgud9A5njMy+8z+vqfwwGEFZ+8qf/1vfJuf+gx8ue/g4T10d2RJ36MukaMe20gUSEqfPTEvTMUz1R2JHQ3I40CARen6gnMz6hK+i3ifEPcAjcwOmZ1kMgtTSORQ6bpIzpnYvLuunis53bH/7DNyvcF95zW33yh84xLyDM/eamhS35O7HUcJ3C+35PJA1ERZZpaHZiaeOpSIEplzZUkQdEeQHUtXmJtatORI0UAUz04TT+aZi+R40WEUFmA43dPfvKG8umG4n4kfPXDYveFir5AFf2wbxftqiO+TDp50MArpG0fqLxeeHWxbMg+Jyc/gIR8rz15cIOHEL/8//j2Od/dc/aW/bPfGlW0Y/z/s/Vmsbdt2lgl+rfc+ijnnWmuvXZz63utbuiJsR4QJDAREQhpcZRJCNik5REoUEkhI5iGREiVI+WChFJmCF+ABQqFMZGVCPGRKhAQiDAhkrABzAzuNr8Hlrc6959xz9tn1WmvOUfSi5UPrY8y59znXRQoeHJwhrb32msWYY47RR++t/e1v/1/Yo5uZq8cPKIcbBhfpq37X46HlehjJ08hw++OcdRfExlNcQyMW7Dq9sPvJPaUNCfpzDnrN4/GATjNnwY7rKiRcTrVzP4Azp5Yev5L6NQrFe3Iwa6Vm09AQGG8GHBGtgT9FDKOYBWkKwkTgHO+bVVKnOLHjRPFiAAoCoXOc73r6WkfuQ4Mr1q07zRZM5qzM/bxKfVwPB3PSSFatSaUw52MhaQXmC7TB0xShdy3SFuYQmZLjcLDxUKwyTha48x+Djpn1eIHsBO3qLB9gjAmNGY2lctAqYlLbbUvKFBZVeRNUn+cZ2eQKtVa4WSFU7SU9WUQKR66QQa3W5UYQXPbo+BQdMrq5QmpOX/IAcTbz8tmQM50nNM3rKlwo+CyrlZRWDS6tnK6vt33Q+nb62Cm/TCsaUU6DLpUqIPj86xZkrJx+tlarqOVPOKJhNQZCLasqtZSpJwN5MSzPJ7vMLJIGC1eg6mytJ/2kmxSeOxerebAuOBwVPq/P18V7PR9f5zQ+50BQ4bUSwWVHU9OjzayMxTGnQpsy0oD4xeQ4H8+pJnKeKfIM1Z7b53e59cprPB5+lS7a2Bp9puQFyczkIeLzBu9uaOT5QFYSzAw0PjD7FqcNrf
f4Ko6pPlCcZXljnJnTZCKIQUjOVzFgK2WoCj5biT+nyBxHrqcTiQvv2esN//XPd+TmJUp5Qnj7EVOGTRfYLymfm/CD3YOjBGIL+5y4/bnMD3/Nzv5//79v7R5VJcxwCyuIPJOECquMx3qtlhi6Xq8PGvKl7m/p4FSqttmyj3KCCmOB3ZLSVLYDs9SOYjw5eNCZZ4cbdh/7BOHjn2EaTIq/fzJTupd559G/4p1HD7h35yV8M6FNIlZO3p1b0A4t7X7mXYFJEjsH1wXufPkZAOHpnv0Z9IMjhEzBk3Qk51SR4CXJWRIcLON2BcnGT4l1gLqUzOvWOXBm11RiZhj3pF1cs3QRwQWPS46z8x3nFxtubobnAPhFTHlJ4HLOJOeQ4mtpsV6fHInTQD4rXFNoSqBvtjzzj6ngCDrAO18q/MOfvOLW773kjdtvE6MyxBt2WRlqqfxKenSfzbzb70zqIs7s0w2r7HFw7LZnZsIdCs5NRhkRMZ0qqZxcJnKJkANt2tjh+gnXKL5KXBw2LXKr42JqaLY9L529zPz4QPOm4yYe8JNl9N2Z0F00tAHcwXEYlLkIKQWmiiaPY2JKHvAEv6H3jq1XVEY2zhuiBUTnEVfY5sjF6LjwgXMHzSC0o42J/sEj2rcfE95+iDzaU65M3nKesLl/McCeKvUjzMxhJimMCm6CTV8rQH2ibcB3DYdhwE97gldeuRR+4Z/+j6se5Xf8kT/E7dfOuEVEDgej1JSMBM9+tPv63Sd7nsYtQ3K86jpCs0XDHbKcVz0zUA2VB30GTmkbQbqOqyny5GpY+XFtV0ipGJ9Pq8akyHMJmZUmPc4pXddUeR+PtB2xyQwVHT2t9qC2fiXNRmN6YQU0r3EPTmjblq6z4KmtlbOmaUwHMiuIN7rNVpnmyH6qXpkxUdIBPNbZXsQ6R6Vyo+s06D0gincKJII25p1KhrY6IEzH9eU/CrmMiAXi3W0HFTGTzlspshpO55zXwGDREpVFF6uYVEBKiWmaDJFxznwQASempeJb821TkWPL8FrXNujaayFrg48j5eked+8x7Dd4fQLYpOfHRSpjhnlGU0JyWjNTI/qrCaeqyVYsx3qievB+flldyNzJY8uCdcrZsc+gkv/d8YUu14BNME/Eqkd21KWwF9YAcaVwVW5QOf1QtfNa1DhfqTwvAFrUbsal5LRmHvZs/cuOXE7KmVI5fvb9jwidq/yIZZHxzpE0Pxdo/RrAzMkRYPZSBdSbgGvJhVZb+mxZ7maGfaliklnJeYLgKaSK/lWOTJkoDJTmPt5/movdllc+9g08enyfMhv/KiIUp9VHtDA/G3HTGf7sCU3O5Bq1RLFy6qSR4Eeca8nOkYIj19JCmhWXlcN8YJ72Zt/VWgY5p6N2mi8NOXnmeeaQCyUVpnnk2XTFrXqSos+89DTzO97sKN/6Buln/hUSoTTm8xbrZ06SaAqkAbJTgnNMIjwo4B7azv7A/zDTXTk2WPC9b5USYadWkjzixKxcr9Uya7kmyzXUY2PAMp6W0WKU/AV7tdclUTq1sqXDtLEW9G3h33ZkuqLMrqG5+zLh7l00zvDI0J344GuEbccDYMvA/PhtbvrANwZBq+GkE2hT4pNboc3K1Wz+lQV47d1Kun77Ic++7SX6FFAP3puVjGqmxNncQ+Ao/lzPx5yM/D7L0ipT74OUcMGC7IItfvv9nsOdxHltQAG3GtBvtz0Xtza8d3+ozTn2ilKoqvtVOyxnUhIEj/qCX1v/Ic4DKTU0JELo6bbndNc9c2cIQ8wN4xj5qc8+46N3PN/1LdAPMHnHk3xjFm3A1RyZbhzJO5rgCHQmSYDQt4bIvNQHzreezgsqGcKAl0yjgtMNcSH41FGgRHAj4kbEGcKx6F8FgVf6wtzv0MuGNL7D7otXbN6G7cNE/3aVY7gH+pon3IJLv4OYeWe65v5hwA825kd1jDhcCGxdx8Z1qDfvxySKLpxPhTbtORsL5y6z85mePX6MtPvKO/
zaffjyffJ9KI+gXNugib7qXS1Bcc1RU4IcazOLN2R4yjZO05VdwwalexwJTyHdEeaD8loDb/3jfwzA05v7/Pb/+vvwH71L0yQkWQPQPMwMg13HYQjs0xlRd6TogR7hdeClVT/SzLkUT0Nhxvsduel5NGUePB7wVYOt7RNxAp2Nl6nIEruecMc83gc7p02gCQGHx2XH4TCvDWk5J0OrxJL6nHT1SnYLHwEHqoxZkJyJ2bHzni54Gi+0VSevDd44axJsvm/NV3s7BfrOXtOMnhlHSabmr07BO3IuFS20T7RGvIJrDIkOCL7pcTIhNbueYi2+L0Tm38D2mw7MfvInf5K/8lf+Cj/zMz/DO++8w9/7e3+PP/yH//D6/B//43+cH/uxH3vuPd/7vd/Lj//4j69/P378mD/7Z/8sf//v/32cc/zQD/0Qf+2v/TXOzs5+U8eSO5AzkEvF7eo3bjxZ1UQkNVM1Ck2Jd1yi72N2Os1WUtnvR+JlpN9B29vACg24ziF9ILtCqd6PxR397BRXldSt01MOijx7F64/A83VGloroNOITiMuzUiOUE1u3cK/yLEGRubhIMVKsEv34Qdpcp12XZ4G44G6SNUFcA0ka21ztX6SskZ59stVV4ETWxs721b6OykLrnzd08hHK6Kxmp7bkrJkSamiZYsGW5Fj0Lz6ZVM7miSAmBK9iODqxHdEWOwGM/DgZGEWXYNUV0njyzEWqU1i9hfL0u/W/VrxtJDxydNEf9S/ihCmjD93CJmYDsTGDkCKo9QvOZeJVK4ofBX13wKMXJzd5u4rH2e6fgTA/hCJzrh4WZV0E9HDhuZiA3KgrRp4U3HmBFGUKWWEieiEOQpD5cfYJKPs8w0+R7bOuDglJgYphKoP1fjCWDLTNJggZyrcTCNzHJDK2H3qHP/Fl4XsOtLwgOFf/E9WhsswlbQGSqVYcJkUhpLpNLOr5fKxXqf2VwvJFwLQ49hEx7VmYj31q+XzC8GYLKhZzQfWC7tknbzwevQ51G3h5jcn13pTIC9JVTEU1SnsQ6FPwvYP/1fkW4Xyzjv0X7PguVw9IMXMS+E21+kJhcx7c+H8yvOxVw2tvB5ntttCSfANHt6L1u0LsK1Z9ce+8ph3fts9IolR6iorpTqJxHXgO4eR4p3a4lOU5JRQjPtoYz+TUkKi6dRFzG5IDweGHCm6qefEoXXhaRrP+VlP0yzzSL2jtTbWZE+Dos4WPK9G71gtrERJaWYcHf0moBqRIPTthpuDBRrZR9oGrp9k/tnPPsb3HZ++K5QG9hQq1ZY0KleHTNRM8AeaxhCXtoNNNaUMwWgkZ0FxrpBCxPtCo6CSKLUKYQmkUpgpbk9UU6cXGkJTPWnPhTb2tERmnZBxpBdP87Cw7c7JlbbSJyWMhXjhOGwsHLmThXDwPBkq0tI1bPoWvKMXR+8CXeiQ0JDSSI0FuSwJd3Pg4iZzp2m45ZQwX3N2FXGPDI3VN5+R34L0NJCjIpoJDmJFeNOiaVsrDFoHt5fAlKMhpXW+bLxdq1HhcLCbww/K5ZnQ31Iu6r3xzk/+HP/83Ud85ru/i9e+5SNsJ+UwRJ7cf8J4U7lVQ8twCIg7Q8OG4jxO7iHu4xwpESOiESk3uLDHxYL6LVNxjFNZudyhaoqqk1q1su75Ugpp4TlKwFebrbYR2tAQ1JO9icMvRZNcxcvXdaMUYp6ZNSInnHC0kCXgi6fJiYIhWN671Tg9OFv9nHN0PuAbB2Vk2zm63j5wM1lnaMypal3W9axOPm6dCI16ZyCOvUdoaHxmrgt25z3qM+og/IcyMd/v93zHd3wHf/JP/kl+8Ad/8ANf833f93387b/9t9e/u6577vk/+kf/KO+88w7/5J/8E2KM/Ik/8Sf403/6T/N3/+7f/U0dizuH7g74S2zmBTRVuQOv5AzFGUm0lLwqfy8inyIwJxtu+wH248CmRDOdw947aUIbJTtDvUQa1CmyTFo489DUQiyKP8zI08fI/gna9p
Sq/ltcwMURTQdyGnEpWv29qAVigKYEsngZLmltXiOyJbg5immerlw1ADs9P/W1Dta6j1kQKWt4oyflXrXJzhC5U7PwusMa5J52fp6id+s+Vqi5HvMLAWXh+cd0DTprxsnRS+0okSHrgr3eE1VA1TmHVyOKqioev3LmFp4aHFGX5ZDs93IenJUQnN2AzgkuKd3g2DZWBNv4PcMwWyeXzOQizEXRDK40KzI1xIl5uiZPb3LeGEF2UxwvvfYGV4+/CsA4vm2SD6FmxDEyXSvb13cUP62TvERqE4UnxcJEwrtMlAmd5/UcOQdRZ3q/ZNjZJjMn5Lrw56RMeWCYJg7TxDwnrlMmTAndWaCRtPDbH50Tb59Tfu7/S5kGihjHQjgG/1GFg1eCmlelOltMZpR9TUZ2zjNJJoo1iSyI1cbChqMwsj5fklTen4R4jmNb5VgKt47LF/we3HFMOqz80Kon1lK4X94nDpcK4VP3yG9+AX99jf7DHye/+l/Y+NpcIg++TAmWXzXO0ZTCm7Nye2/nq+/MxXJKJgXyUis43/JkPzPWI/7Eo4nPzUphQl1HWSgMYohYqDzSXAq+eCTnKjejlg/B6mNr3lSFHGfAMdfkMw8j1/MMuwWJaK0jzXlC4zg739D2Qk5HsrQ6JcWCi55GHElsnrMnTQm//hfVwhATbits8bRO6LY7+n1VQD/MzMAknq98Ufm5swP+W3qaLlJcoVpXsi/CPJlET+Mzux10TcuZb9kuq13INKGwaT1NA9ErkAg5kzShVWZY0mwCzcnmdEmCVwUpq+TRK+4Ntk2DvhKJN1+gkxmap1yOI8SJ3dVtG/epJRKQdsOZdHSxsJsSd7LjbmMoZPItExtSafGlxeeAZo+EhmdNItQmjl5nwrynHRx+ti5nuX5A/zBy82ULZPWr4G8WfDcbH7ccS9YLYlbUuJTLGBAiLYKvHb1gc/ZBEzlUVGkWNoPST4pa8z8AL106rr74Nb5U/gXPHnw7l6+dEXNifz0x1wufk2OYQLXDdRvU7VDugLvFOvNqtDXBXYE+xrtM2+zYbM4JviXpSso1mRnnjBtcrQbNOWEBCuwcOGes0Na3EC0xyDlzKme00ByslLkgZol5beqzbGvSiPiCjw1TCqRiAsVujaYwDprzlGBemykLTau0VZEh+Ix3ZkivqsxzREpFK/Uo/+Q5Ea4W+x6oNaT4quvSVo5ZKsdOzV9v+00HZt///d/P93//9/+ar+m6jldfffUDn/vFX/xFfvzHf5x//a//Nb/9t/92AP7G3/gb/MAP/AB/9a/+VV5//fXf8LH42xDuCW4jq52HaW7V4a7BuiJVSYsPEHZyVKkWDVaaizPc7Ad2cc82243oJDDEzBwK2TtSiTZgnTvWycWbUbUqJdVyyc0MV+/gwiVSFylHh+QZmSdI80rwN2Zgvcq5gDPum6ns53VAnm72t3wg70xe+P38Gys6pHCKrwnYYFdWD05d8bL6GjGVGKnInFDlDE4Cs9Pj+zUocdSPW49l/f8Cb2Mm3c4/X4ZdS5kLcVQM8SoqtvCv3ULGH7LvpoC+D1FkQWnW4yxI5RYuMLukglwnzkYrsVwC47hHx5Gy60kZppRJudC4hlxswRhjYowDeX5AvHhC19wi3bzL2XbDvZc+AcB89YSb68PqnVlKYbweOc87fHOg1Mmt9ZN1AeVFi2+27iZ5vtPVBwFXyOqYsIlLguKCJ9dvPpeZeU7Mc8LNmTwnNBW6JMR6Hl4ahE8/EXznuPmVrxKyZ1/3EDmOq4QSs2X0HgjZoVnIlJVc3nhnyYxas8UeK824AsYaen40FGoDixzR4NNxdJp42EJlYzA4pVnilvqk02MwZ+VMXTltrqKz2RVeeumSITie/b1/wZ0f+J3kNz7B4Zf+FQAXd34bQ7pFM/4KDdZ9Ger4+9JjKyN9+2sd10xseogRSp/ZTPCZ846fvTZeVfs0cRmVq3OHmwoHLZQYq85VWYMuMxWxxMIVcwiIaD
V2rl88HM+XydIYOkYu7OdxtT5q2hbnQtV/Es42Le02cNjHdczb+ROa5GlwOC/MJdYEY7E7gyLeOojzY3S6QJtzQmhpQ0O3sbkyxpm0D5Azce9488vCdtPy8j1HGw5UtQ+ucZQp0yn41qMY6nTW9uyWUpMD7xI+CCEo3htionk2SZiqvZNLIuaJmGCee2LV5vNc0KitP7c3GzYCbiiEXc/0jXfYzL9E8+67oJnwzDpnS3DkBw3sL+le21G6DkkXdHOirf3D2Tlm54li/okRcyhpnLDLAR2tm7e5uSZcTXRX0OsztnOhf7zn8OaB6Z06BgdLTpLMTCoVES3M+jxSbFN9ACkUrIw2YQHaCmiKVSIo4IvD1b2lazifoa372rtCm2AfH8PubZ4dbpG7BiZoagC0lUDwLWiLaIsZBW9BG8rK+Q0gTW02ukcTDmybHRebczbbhsNDi8JzrumvKygNaCGIs07Tk4m9FENxnQqijpQy85yYUjqx8ivHbJol+U/EPK+0HBHzsN5PCdWIk55h45hT85w7jnNVm0wU5xKiGecrS2m50dRsB0VMgMjAAf2Atc2S4oVmJBWy966hqfJJKUzmCfob7MgETubGf4/bT/zET/Dyyy/zTd/0TfyZP/NnePTo0frcT/3UT3F5ebkGZQB/4A/8AZxzfPazn/3A/U3TxNXV1XM/H24fbh9uH24fbh9uH24fbv9L2/69k/+/7/u+jx/8wR/kE5/4BF/4whf4i3/xL/L93//9/NRP/RTee959911efvnl5w8iBO7cucO77777gfv8y3/5L/OjP/qj73u8uQ3hHLTxx3KayWaS1Ord+6E+lqBbFCKWelaqRPZs3S/jYWIcB4bq9df7M9N6cglpe5ImSkl4Lccic/AUHEVbhNlKT4drePQMwrsQ7wLg/TVkjxxGJEaD4FWrCoYdvKh1QC6elM/9/BoQ6KlMxtfrxrVIX63V94WUbEkACkdemD11ir9VSLlyeNYy5knH5fpbj79fLLee7s0wk4po4VcMRcR0oZZMZSExBzGPwmNxMgCpwuMBXHqft6ZXM9B2tTz8/tNox2bW3uUImWfrkuXRgP+aneHzi4ZD77meRqS0qGTGeSa4grpMUcsUYzEQNOU9I+9w1n+aSYQmRy4urHRyde8VtvuvMCZrTimaGK9uYLhHe7k1EjqwDYWbHI1gqsaPypoxd9jq86eKqsP7TPZKxIQ7zU3u2M2ac2aeStVoUlIxMqtmeFQMFfiWBz1nc8vw9BExGhdxQcpmjmNtcOYdt3C4RmuBMBS1KmwfcrTypxOSmB9lX2w/HTDWmryUhQdpY2Yt1XPkoQH1+2gVEj6WPZ1CpSeZMK/Y/hedy1aFhrLuq6XFO5BXNsRvfYP5n/07+rsXzD/7q3Sv3qM5N1bO9bufw11+BsItxjRyRgbX4srMTd3Xwyu4vHMB4xWtFPZdIF4nms5xVgfb/QeJlx8/49HlbWQauNKRdk7EYuXmVFEBFcvmHcaBM36RjdqlTG6zm0HbKtYN7pySS+JmHNbXNc5XtXO7fzbbju22Y+/iCgoUHCkWJEHQBrwgaUaXMaNLFSIjEhBNME5IOKN4JUhP09i5St0NjJk8KULm8VN4882CpI5+d8AtChdkNMJsJQ3aZPccXlZNuc4JnljniGIeiqUw60TOhvICDLEwx5EYlTldkaIQ8hk7f87Om3bXq9058dEjzjZbJAy0n/0Fhre+jKYGVwpapTCaIKTHM+XdtwmXirtnTWWHrbAbKhVnA2nnmbae0Qs3U0JzYYujC5F5Me++mfBXI/2TmY0mmkGRdyfSm2KdgNh8E4GhImBg92uuE/KRw2xSQMs8rUDAk6sPtF0flguKUEgCfd1HmGBbl9WQ4ezcIdczg/883fgxyr1LNqPy6mgn/ywHzgm0oSekBpUOlRbE46r4Dss3cFvgDNFAKw27pqHrHcUv5VhfX7dAXYIT+1mpCsVEzU0YVhACJaW1MWURJl+ddk
7oMrlE5jyvfqdBnHVDlsg0F3yAq4Nju2/oO09bkbXQO+NPukIhkYvp7KV4rEJIUcRpvdus8a+Irk1Hy3pcxPjUtpQVvCuEaozcOZuBSp+swoCVR38j27/3wOyHf/iH1/9/27d9G9/+7d/Opz71KX7iJ36C7/7u7/7/a59/4S/8Bf7cn/tz699XV1d89KMfxZ2BttbxUGp7cdBinYA5c3NtAvt9F/C+rG3py+LsTVqFXDzkTJoK43hgWG7EEIgayJuCtK7KQdTFZOFAhUB2nkzAyWy8pRHS1Xvk/qWVmGjuAx0yjRBz5ZYtZPoFn6VOijbyVpFVjlDpsn09+YwTTuKvub3Y4Vntro0QXN+/GEIvMajW0u9SCVWF/Fxb5kmwVmtNZkny/mN1WIAoGIy9fP5xP0e7JAvOwmohtDSUigjeNRRfzc6LlfpQXaFyk0WRSgr/AMLbc59Z6wF4Ah4XHCkeyO9aSWr7sQsufMesCfUtCSVp5TqokqodTiFQqCXosKfrGubtlrjPbGrp5/zey1w+SDy6/qo1RORIHAaInqZp18V6I4Eh2YTrnAN1ZF14QrX0qw7VYmPKWVmyqFrJR+XYjohpeJVcmKKyz5lhtijy3cqFvPMAsg/cfPUrTEUYQkETzHXZWBwJUjFTbFfFbL0YJzNkrIkFGLBAKajQFCWLBVAjyuiOAdhydEtZe7lCXy+IPt1yfV1YW1Rs/HYIbTUNbupxLCWdLDMuC7e/+TuI7z5lEHCHA09m5Y04UyppJLSZ4Z1fQsptXN2HKzO5hbNor3n7ZubyQsyk2St9TIQWGCOvVi2JJ4fMS197zM995BZ9mhjjjEvJAn8VxB/nCEHwWRESLhj/JutRPiXlbO/1tbuyqIXfCsN0MPkIsMYYb9ITjWvo+5Z+u0HcHinLvVFIiVVayG2atfTp1C3mBmiy4CC7DQdRmjyz9YFue8Gm1ijPyn1u2kgZoUkQR3i2v+G9R55X/GbVhtw0nszAJBBlsclzldRtF8jnbEETFsD4OuEoMylPa2CWpz2ktpKzZhg72nLOrf6Cu1s797euRqT/ZtR9jenH/y77f3Ofg5zz6iGy1wO91BC76aHAXDL5wUz5cqE5g4stXF/WE7EDd7vFv3JBd/c2rduhCbbSE3mK1FKmv4k0j0c2DxNnMsMV3HxtdfGx6wic8sCNHqJrY8sxMDv+OKj8WRNFSSf76mtDl57cE2tiUn/7B2Z4fvYS6Ncmhvhl3Dd9ggvf8Skscr7RDf2caTUQUiDrXPU0G2TZkyiZbNdFGhRHwNMI+AbrugG8a81uzM2mgKCWyC0JJhyvuXOOprHyuxmUu9W2ycazmZ3bulTPjRZKSeSFLuS9GZ2TiDExDIoPyqZ1tL7gq9RPQGgah0ohqwX3wzBwOIyksZbvi6JqqgySC0UjWQ3Ucd6S7/XaiTUFBnE0HppgFA5Xjz8Go+Ugx47oX2/7Dy6X8clPfpJ79+7x+c9/nu/+7u/m1Vdf5b333nvuNSklHj9+/HV5aV3Xva+BAMA1Qk5Q0kyVKGPKUFLDeIhcT54cM61P1nq98OkVfAu+sZZikUyOMCS4OQw0nS2ebXDcPwt8UreMm8JERNOeFMI6SJuklNChuy0cZtQXihdkmOj290k1Tw9NR3EJciLrRGFEJCGSWcLorC1CQiUjlGovhAUaTlbG83Mek0usIUdyOxxvZHeSYRhp2jy9VjX4esNkOfpJrjZJ6zpYmx30yEFb1NlfXCqL2k9UC/JU9HkhUTG+0cJPg2CkyZNlWEQQb7YcIkC2SVnb1hSYl6BLAalKcmLTmnXFiJ1XqHwLU7v5+iEZQLFgTgtKYCZi4TbEwfYVDzNOlLbzKI7YHNgkZRQ4yESY7bjmmLiK4PA8efLz3Lr923BFGJlohsod29zl1hvX7L94zn6+Zmg9Mgzsrw+cy5Z+a4P1EDO7bmbMgmbzMwwZ5pxqzo2p+/sAPiPFfDUzidkletninY3nUjKHNqMHmN
MVVyXTDc584moy/JnHB/IUOTy6RoExGWqVqx3TQiIIYoFOUkuMHIJkZWgg1BXjkVhPzi0tHCz3qZwZs506LKNrzT5tOLt6bVN9LetYsc9dRIsFVumMRcYwL+NPhI1qHf9Ki6PUALXL4F+/A/sb3HvvcvuiR2jZ7694rCNnZ3a+dA5sbxVuP97zmBalGCclbQjOltVnvvDOe5GPvtaAzuQcKCRUhDs1EtyOkJ9eMQ8H9lNgnhJtVhoxvbBUJ6YlGSqloM4CIicJf5raiFByxhWlCYFRC0pL8Zm5TKtI60Za2nCbG3dDcULolctdy/0QmJfgbYTUwE2EW9rixJAPQmBSJUVzEWh8S/Qtbh6R4hlzIYSO1meaxvbl21swXbFrZ8RDaSEhlHjBMDrOdhbSuyYz9Wdsyx71jixb8qz0Gtn1Ngi9OzC2G3oHoxRKM4BmkiRSUYbqETl4b/NoTmgI3A136fKnOb97h7tP6kzo7pJv3mb/d/4K6VfvU3Z3Obt+yqObzKAntjmyX+ezqDArlMc2GDdVLGA+g3Q5o+88xL++5+yNO+j5CNfPaK5HNrUZxz/Zk/aJ9gB6A2nvmescslzHJUaraT6sd7P9edKDcRTkBkpFymzOs8carEHAYZWGVq0xSozZtc71oqB7OG/gFo53NELzBW598tvYT3ZEW/X4+TW2Zx3eNfjyhNjNFAJt5dBGcQxENtLgRZljJM97UkogOzTVBod2gBxxsaPoRHYwS6YNz/BiSKvN1q25kkhr3Ky2IN7hvaCVazvnQs52PZxUnqh6fAn4hbylEaXQ+kwSVwMu4Yn3BCdIZexPMXPTdTRNgyvKMEwc9iM5R0oVCXY+43OBkilqTjCdQJa6mix6icHGSwFc8bSzkf+Ty8y+OmdkheDJklee66+3/QcPzN566y0ePXrEa6+9BsDv+l2/i6dPn/IzP/MzfOd3ficA/+yf/TNKKXzXd33Xb2rfSxWwlKPVwRRBYyRG8MVMY0XBNw5Xu9MUs3bQEwFUqnTBPEYO+6pb1QTaBobNQGnO0d6RJ2t5XaP50OElWDfGInZaMqRIniays33pnEyxOGckm4cYNXNYrGK0JITM0STcgigj2Ov7OtU+aHsRbVhubIftY0Uk1iqlwRRr8MYRGl8zE7F81SDdF0qXJxnAKRlVsJtHTw+KBSkz9EpZypZulbuw97tKorQLY1IZvnZNHl/nMamKEMJ6PFKioUmLaF3NRZfu1FNE5sUzp6pr5p5wOLXMUDv7kq5TQzFaT2wLoTF9O4qJy871OsakxBnSlGnlAXNJ7NQ083Rrr7mYrynnd3l26xn6cCSlaMjOwxvupTu4Cwv828Zzrc8oPqHJEaRUVwd/Yjpv3UDdDHM7MaG4DI1ravZaSdVlyxmJ3IxofpWLYWbPE8L2Fu1gYdIruqO9mdFy7JIUzGVjCaSoY6TUa0yBuY6tfYSmLgVZsvljUoOlejUWVOB0X6fDRF8YL/D1rpntqwFCfVMrJsDZOiFnOF8+0xU2NeHafMNL+K4hfeFrxL7HvfEy8Sv3TcVChVQX0WYjXD+zDlzRhuyr7ZIU9lW086wEHkji1cHRNELJVnIuRUn1ht0AL7+dcM+ecaUtG7VOvIJWwvJxrEqxsqWoofmmk6nr93Ollr4loFjnMGKowX4aOSRbDG6XDi+OjWuZmkBoGzab3jrFKlRT9T4pKcEU2UhPMS8cSkqExfuxmCRQ9s7I0CFT1PwFXRXtDCHgfDEBTm/uC3lUbvqZbmrYTTWZ6jx9DrjQISKkeOBRLmyzYxstAtr6Bo0zSU2wW7LiY8LHQo5KrgnQzWyq8t3YcX4QXDjj4taGi2GDdm/Yd3zvi0z/z/8z8WduKHGDnx8xjI69gpeGeekgXMZgDfwLlQdeID+tz+2hfQq8B/JoID+4D3cfQAdOO3yFxPxVpNyAv4b0DDhkXLJ7aUG5VtmH+uHLNFo+YLCf6lPaLS+cLgZrVQUz+Rax+0GwRKY5ub
+kgB8gS+ElBX2W2T08cE/s3EtyXD28gceJ2O8J73Y0r/wr6DzqP1Y/MbBDKTyj8BQhs5+VobQMKVHdnWhih3cefCSpwUVSvK13tcznWwNdurClbVvapsEp9G1HCGFNwpfvvt4pdU3Mmlgyf4eY0GtokDjTOGUa9gze8VgtHgBgHkl9oA0dTiHNiWmcmaajYT3FE/NAttKPLScORNWQ6HpOb2b7KoVM669omoMFYVnxYwVSAgQKXYDutBr8a2y/6cDs5uaGz3/+8+vfX/rSl/g3/+bfcOfOHe7cucOP/uiP8kM/9EO8+uqrfOELX+DP//k/z6c//Wm+93u/F4Bv+ZZv4fu+7/v4U3/qT/G3/tbfIsbIj/zIj/DDP/zDv6mOTLBSmGpZOshtqzeTp3aOmysD3hmsCRXm9d6U19tMwtUOP8x8vF6/nNTEG8tAbBJlI6RrRygeqerG9B1yHXChxbUdrsympp0zLpvSP1g1SYtCLkg0cb9F3V/0iO6s320JxD7AtPu5MvXXWbHeXwaq58jVzsqT6Oo0IDsNzPS5VdG0lRzHOv/X29aW5vo6B0fpDbByjYpxA5wzrszJRGMLhr7vW0jlJ6zdmTgD79UhrnbQuIxmM+c+/RLG1DkxX9dj4PmB5895cy4AclVfLXcDeu7xZy3SOxpRYnZIFDQ5pkqImJOd67nAw+kpbn5A2t6hmR6u0igHTOn69r17XF09NaPk4Ll6PPDqdOQUdl0HskX3BtsvJD9x+egOoAmRCMHRiCAlkcV8XVtp6aqVCmHDYc4020IfHbnz3O5uo+GC2wfT7vqkiVmtXC0cpGwLilTUi3oNCwbjD96CI1WD+GeO6EDEAjM5CcySgnOyljO+3uaU5yVb9AMQWli7sMG4NaMHspV7vELX9pxfdLSfsPkl3+yZHj1EG4+/fYf29iXz2+8xYUKdT2+sJP3q3ZasEe8dbYqMpSLFLq7HEdRQxEc3hVdv1bNSq8zLPXLZtOjbM/ceX/HWrUu2KkhjyuNZE0sHsdNCKgmKIX3RG8IQnCP7yhMsBZFCI1bBM3pFxKVE2u+ZquVPiDtSsC7YEFq8b9iebei6wHio6NVyDhPMc+J2aEjSEdU0pJYooiDm1BAC4pU2KFpmYENT7XeapiEEM3JOtXW7RNhPid3kKYuRrHr67kAKmZAcuUxcjYmHV3BrUYwPA5u5waXEXEtE27EgU6FJ0NY60oU4GD0X2TNNnlul4+6V0m5fhfxlANKP/WVuPntD0JZNHBiv4DpbKHyjcam6GV/xdF5kLUSsg67M4GehuRH0uuCfJOQOtDsI/bwu6rIHngrpKaQbXaUVEjwXgFnCe7J/jnPSi/PrOlUpdc57/vnlWEMdk4uOpeeYMPsFaRtNrLbJEHrYtDdsz4wLHYaB/q1HNF+aCe4a+dqb8C0DfPKXSL6i/XwcK8Q/pI8PGa6t3Lzf79nfTDTB5IVc68l5Tzh/ihs805hAI07cYrtJs/Fsti3bztO1LV1oCOrYbDa0jV8pKQtNrRZuTJ9TjAcmJ12ZOMEFS3bybIb0h8MBV5TrGpi5jSfGBi8TXoSSlHmemYeBGCtZIycIHi25mhAe1yWRE4q58zT7zOEGvngOF+eJ11yh855DhfHnaSaLOQeEF6ywv972mw7Mfvqnf5rf//t///r3wv36Y3/sj/E3/+bf5HOf+xw/9mM/xtOnT3n99df5nu/5Hv7SX/pLz5Ui/87f+Tv8yI/8CN/93d+9Csz+9b/+13+zh2JokhwzP6gK1g6a1uF8qeKbQu+grXdh8ELwHaoOHWe8syJYCHnlWAAUMY2vgyZmmYibQmqMUL6c4dIEnHeINqhrScVUkUtOuDjja8nDsoWIZkWSmhdmimhOHAlklaCuoEbbxqaM54OzU77Zr0Obev58UW94OeqwnBL1132/ELCcvuZFpOMFZtiKln29hRSWxc0CsiUwO/1+8gERk4jgF4+1JTBbuTJHPt7y3lV6Q619vK6Vz7
WiL4Hmc5+zio/aiZodzPeq9Ma9AOctzbYjhQI5E4JH1Lw1l3J6XhbwFvQAj5q3eb1/nXMvPK7Bm0s9mYGzszPOX7rH/uHAnAtymChjoKnjy0mkCR1tP6LekgaXweVlfBwz5kn3tDHQqKNpCq3zbOSMtk6UClwmxyO2tH2GIOzOXuf3+U/xe9+y87X7WOTRF/4BufGmtJ4SoWk5pEQ5SR1ikyxrzUurvWckIXIk4h/URm8U47mc5k6c3LPCsbQDR8RV4bmF6oO8NUvdf18/s8V01RArtUobuP3KS5TXetL1MwDCYWIiw60dbd/C3gyYb/DINNNWNPyd64lXLxu+9CBy4eFphoT5ym6qTtZcFC8t74wzr26CRXbOAvNqDsBFG7geZj7ybObnLjNJxIzUxWQScoX7UzIRTV+5l77GeDWrs3NTkRDRBqcmqTM7KCVxFQ88K1WqgIyI2du4ugp2vafrj/MbRfEFNME0Jnt96xnnjDhWlfZFzsM7j2vEEJFYQCG0NTDr2opcz+t9ZtSvwjwK42Q76+eGuHWm10agcQGdlP2jmfu1BDa7wGvF43RDyS0kj45CMyouKWG2c9GX96B5hSZPXObbbImkiye06Yvs/99/GYDhZx/ik4frxPWInW/JzN4huZDrol5KWcviVFRKKryeTyArE3st6AT+MTSDw/eeZhvXAZ4OkAel7A2BLRx5u8sYXvbo6/+XaffrJrzPT7Mf+PzCpfXokWpyWuFQK3s6lJhsbjof4OwmrjzUDk93UNzPX6Ffvk/67E8j/6vfg/6vz/GfqE4DFxPJnRN4yD6PPE0T7w0Pee/hffzc4Ov903UenzNBtowocSpIDlA6xFmy2LU7Nt2Gjd8Qgse7AAHatsX7U47ZSQUGW+MXbbQjF9kRnEdCMtQ3elrvUM24khgPde6dW1JUglhCT8lkZiIJbWwmkqAQZeUKLzQfdbUCV69HP2WGAKU0lGt4epN41hXONoWzyjFVNdpUcLrqU/562286MPt9v+/3fV3iOcA/+kf/6Nfdx507d37TYrIftC01dVdOuQJ28ZpgAqmth04cvmS6Ovs3Xoxc2zS4rkcxtW7fKk2rdO2iXGz2G3MuTG4mN5nUWDqiNbLRRnChQbW1CJtASbNl+/OIik2UzjnUByiCJkVTqjpmp3DfkfgPnJQz8/OZ2wuI1SLiesQDK+L0AZdpfZ++/7EXAYwFxFoRJj0ulss8sZRjnnufCOLtuiylZneyL3uNr8HZUYX59P0ez6KIvgjMruTghehdobvleSPHC8/D31JLm/XmXlDB9fnTI6/XVJdunEJsQW/XQPCiwZ13yKbHudEUqkXJeNLsGJZBWI8ziBBbRacH6A6yNPQLH6Jx7IONyTvnt3DXB57ePERjYnqcufzkue1KntBJW5m/ESlmlF3cMSBTV4w/JsqemSyejes4cxv6/mzV0ynzxHWXcHtHcA5cy+u7V/lvnn4ahgf2mukdfGkYotJuOrwr3OhMbMCVQqglCE09WSMqyjMpnNWFOHPspFxO7SywLW5V/l6v9DKhYn22WhsCVl6NHMfNB3UbL4jDAKsR0YUIrSo+eLqUuXjjHu5sx3D9lO3BOFNjKIRDpt04pjbTvPuMm+uBKyyofqVCDG+PyuvbQOsjczGvzWsMaVjKyHunbNWzx/NoTNw5r4PLGRIFENpEGeD1pwn5lCIzqOa1qzLnhXBcjhWqRU5JrZt2dbYQj3eKutqNvHQl60yc9jyd9uu5Cd7jfEMQj8PTBGG7CYQajETNSL038yFTDpH2wqMl0Yqnxj+I82QF7xvaxuHFISGgc14RsxAa2maL+gMhQ3GOJIWSMuOUuR5tqdkOIG1D2xSKTzjNEAtPoiPW4DOnwEtn0JbCnRG6wdEdhBCtKaHUkqGGgM5PKU5wm4R0ifDmFzn8wi+TfvYxAIdn4K+EPgUGbyK4moVDLtwC4tIRy7HEbr91RaFWZ7r6eMTiGD+Di4Wyh3zt1nknx0
yZl2ndrtuLnO9FLBtkTQRPxbiXe+eDuvHDC/N/luW4dU2e2/rR1jBg24KcOcwiKc0QxpbzQZgqGlZCYDMp+m/eJbGH+y1P/7t/wPm+hx+0ecS/+nN0m9eJzYF4GHkcH/ALX/gCT98bufBnpMYSIKcJSm/NI901Y0jk5HAO2tYSm81G6DvoQkJq16fgV4/qNdF2HJvN6ncRZ9/F1ZumEeg8hOCYMrS7hniY6aVBphmtydR+LlAC3ks9a4nsCkkyY+VfppIZZyVHyy5tHhKyKE4syALIHbgETiPqLX8KsyX9Y22oujyzmEObQNf8xsTMfkt7ZSomPij5uMDWxjVT7u2gE7MRlwmaene4rASf2J1teeXiDPUWfheXTUSuprpN1+B8YdDAoDOpiWjfUlphmbUUgbZFcqD4ANm8v4ImNCe0epGUImjOiDorlaoR/MvS3sgS/NTumvWxzFIdWe5urf+sMdxJVrRMACvJcFncTm5wffHv9R9WI/Ii7w+mnnv/B2zHEmddRJxNaqeB3LIUL4a24qwQdRp4HV/ljyVOOQZgqxQGFrh6Kc8HYyfHZPZM1SGg8qC+Xil24fLBIlVQ0B6aTTWibxy66SidA/FIEVoRohR8UPqaDk0hMxfwqrhOmPN7tH1gkg3aWaDus+feJLznOza+Zdhc0Mw3iI7sHx7YZJOUKd0eL4Gz0eOKx8WRUGyBWmZc12SKwjhd8NA94zEFCT1Ne8Em9GupbJIdZxrw+QklHrjxDd+57+HhNbNV72j3DzlPiZ0Ih2FiFuUmCLkorQqHWg6ILtKoci5CLqCaSN6Cs77OKrm6aiw8tdProy8MwjW5qNfFLRlqff609LwEdhl7TdVlBqwZYSOOWArbzZbN7TPS/prdYSCP1QPTFbbbC/P7fOcB+u6BN52BQKMKT2qXVxH40j7x8YvAzz0pXDSFJ1HpgKHWrYIzMr2XhvemzO3zYEkfx3tRqxjo7fcKHREvvfFJnayWYjYojvd5xsqQmu0e0oXSWpdgIeOdENW8/ISIxpnDaOMrlUIgEIOhzM45QgP9LtA0tkANQ/XIBVLxDMNEe2Fdzv6kW1Q9SFK8OIJrEHU0zpMqfQCWUqYlp5JsJXMKlEJKmaHywmLyhLGQXCRrMlHebK4RzWiLYvcEus2WS7nFvbmjOTTIvkbhzlOq76a4hlSeodMjNAmMvwrvDDSPIb1zC4CL8YpnKfHEJaZszhQbgY02HIjHJqj6s/j4niZvy+VZyoNax2eoYzLngp8FWay19Pg6Q8JOLPBO5h1LdI8P2Hz7AnXlAxKT91XDXlgDlmN12DGupcy6Z1WrEDUoYfb40tBWuD/S4orA1YhrWqbmnH6+4s2/+4/5xGuv2I6++x6S30W18Gy6z5e+9Ct87SuPCKnBycRckaKUzFLP944u9oTQUXKL81s2jXHaNmFD51q8NFTXvxNAQo/KAN6+3xKYiVoy7zw09Qt2rWPTNzbW28ycI40G/OCscz2vAQBxnineo1qYciJRmCjM9d6fDPgmp4WbbNQNQyJ1Paex2DEFwFeXFhroW+GingffJtrO7vV+aZH9dbbfILD24fbh9uH24fbh9uH24fbh9uH2H3r7LY2YWb3fSiTuNC1AaYLBua0TXMlIOJbqSlI2W8fleeCle2dIG9BA9SVLLB5x0WWyFFLjGDUy+ZG57Uid0h4WwmEwBiWCFqGoEAhQEmWKlLWUCeK2FlEXsTpHFeg8RQlM5FXWbr9TMVdZ0iZ/RBOK2PdaxDaXf42kv4Dv9v7lFL3IFVvfdYKWHTuGjluuxyjCCnYU3p/BvUisdw4WaVdr/feGNIoB7c5VvbO1HG0Z5qk4rapJdpwiZvaYfXl54WiXv5ViBNETp2zP+7//+j41Hg8Eio9oD1Lb+BMKnZVpGwWcENSyfQmZvupfHcy2k9YJoltUrmmaCXd2C3djJRaXIIdEP3Vot6E9T3TpDB0T/mbgcqxoRb/FzSNhPhDmCTfOdMnjaV
d+T5CqyZMmXi/nPGDmCSZ70NDQiKEHt84uuLncEDf3mR4/ZeaKT08BNq/Q/s7/zE7Am98Kf/B3883TGc8++6vMX33A7nxLPjwmHp6yr/DO45//HI/9gacx0zbQRJPdyBTi4kVbz+ksNeOs58wVreP1dMS+f1vQnPXvF35nbOz3PN951njBpcLulQvceGAar/HjwFyP/azfMaOU6ysYhK8MBzON9tboMJyMhXdV+bgGLpqRJ8nGZeKU+2bCkTvneJZhmJWdB1cErdqKvrPHbr0Hr19PvLfbsZmbqpIjqwabR1ZTe1XjEgqVtla1u8QLUgyR8l7JOuM0MHmYU+TxdA3AlDPnSWizIzpH0zSIV7ptwDWLlEQtBwsUdQyHGV/MB02Lrje2eekK4jKeFilGsHausBSu27alDSZ3kEPGq6zl15iVWM/FlJQmNTQ5oJoYZ6N2XM7wsdG+4zeo57Wm57Lpadw5ms4geuPnqkOnWm7iGRwe0D5KzE8KfphwB9jPDr+3ctq1WmPKXBvCOiCro2ANLqcAxjKX2ol+kSLCylNd+WCclN31aJ23TFsnlcnnxu36eTz/miJUz9fnIX1Xj7Ny2ylOkHTE2jzWaLtg0O7kWAPPy2UsWmmrV+1hJO1nFjHxsG0pIZHTASGQDo9p9A6bR1/hrf/PPwfgtbv/JYdveMJhu+Hffv7n+eUvvMN+Etreo3MkFJuXXG6AG8QVkkTEKxoU3zh2nZ35XedpfTI9RGmtVF8bl0LwNItNVyvV4qmuX5WsJ0aqs89rILRC4xt82xKjI2kmxmwk//oe7x1ljuQcySgpRmbNzHrsmkWsI3qhPxfniLWh6JQfG7xQopJne51rBNeYekCsptxNY/ZjbRDa9j+CUmZZJhU4BmaI4f8OLtqeprG/PS1aO/TSNHF+0XHv7gWXZ4F22+P71kypU2RM1pkxkYgOJilMcWZiZvaJHARZyNnqydpYu7hrSMsoKfbPIrYJjlKM06G1/UdKtt/1JhSVek9WjpOe3N35SEXTcCxjLpPDKfT+QdvKF9O6nw/gLizndNme4/WcQOoncVk9+XUB4lgmXcqX3lnw7Bd/S4SMtd9bF5KVM9dgtD5m4nzFOnjked7YKcfRuGVaOWrx5Pm6kJlOSFX+N07CQpp/7hTUssLSRBBKQ/SOVMncAKMmgrd3d7kwOWdcOC2mJVbPg3M2SWxCi089RZSD/wpn/aeZn1aeQzjgxoYmjBxUCb3j4vycFA/4Z89o3rUFtkPZzBPhwYhPM/5mpCuOVjKhtSvU7ho0FOYMrrvgnkYe0fA0dFx3F5x3JlXzDZev8ZHdjq/6j/NPf+VzXH31ho9dfAq+8Xegd20ylc9Hytked+c2Zz/wX+JfegM++hHiw4dsvvxV7mQ7/o/evmD+h/+Un/1//T94rzzgqS+EDOd4burIaU/Gk6LPhc4vBv0ftJ12Zb64nY71TS3LABAgpsx529LfPUcf3qdzhbwVgrMmiJQdMu1x88w7z5Q95khwA2Tc6i0qQZnHxLu+59UAX42mLB45Njgkqb6EmokI11Nm14JmXb3xChA6YfdMef1p5Eu3lEtanBa05BN+pZH5zUOz+pCK8VWOXCMla8GXjAtC1EwolesZI7FK/aSUrAGmbdDocN6DtwXC1cVOvJVLHSCzkg4jOrVWGsrZpA6AWCyVWhqpXOVFhcat2opd09AE04aay0hTvYmTgC+RXMuUh9nhO09XiiXPvRIm+ORj+JYr29frXccbcce231B2W+b2Fq5pLYCdrKkKwEeFJ4I+CzSSmG8CfpjpBvOcBOMFPuXYEQyYxyXQEkwA9WRQLfOo1ODI/l9J3KtoECuZ3y3j9GQ8LuV7ZdGcPFGMXz7qhaBP3bGMB8ckOqj1mbVOCLVDdXYe8YlSG4kmtaOKJ11hydl7i4I74XKefp4vcJ6VEB1lsvOQuSZ1YoFRChALaXrKtg187edNjeHL/6Pn2X96yeGs8IXP32d/SITLHb
lNbLot00JODBHXwXjwlFjWBgtcS9saK7RvAk0IeGfrZwgB0Rroty1NTYrb3KEpoxLIOZNSsWaVEAg1IXaNQxpoWk9wnl4aDvOA38yksdBUPoDLnuwCUKogc0azlfXd0i0aBOfV6BjF1B+ssVCs9F4j+ByUkOzxuQFxytZD77DYA+ikpXetjXc/8hvZfmsHZgnUFxovpCWjxERJ2x7CpuFeE+g3NdqtnbA3G8+tV8+5ddFy9/ZtQtNDW4ipMMZiEw4guSflmc45hvQMZMZpi4YZ55ebYIdvJua4oTRnkPa4CHiHlEJYyDseNNTgsCwkTTlBtqBowqtUhSNQcdY1pUoRJS0ZbGbtECn5FC077osaYCzyFhgAV88RzxH71/PJ8TE9efL0NcYPYtUF896hC4HWUxWOTX1e0bVRYjEeB2d8AqOgViJaSzl1Qy9qJEtL1e1mQAjijSC9fEdv7W8lKeRCI45ZEuLrwWA8MVtK7FsVUaIcJ9QlbFaWTDPb+RfIJCTBUDvdZnfJUJQ+nRH9NVlmrtRRpMVrYawkbs/ALoOTHc12Q54Do9+Tzh3dw/P6gSM33QafFHdL4Gqm6QKbbcfd+wf43K8AcBFeYvvoKfnpgYurkS5toOnpAvS107mZz+kvHYfkGG4G2vNbNMGxvREOr71Cf/ujAPznb3yG2yHzSbcjt8o/vXnC0PZs371C/vZ/C0B6+Fny9/5vkZ/7e/gxUV66g0rAf+O3UX7ndzEvM9fDh7g/9Lv5rm/9Br763/53/MqbP8uNgJOOUhY5BstCfYYJEzWeiiEYUo7jzWGCnoJl/g02pnuet2TKdTVbUYF6/Z6hvFSJw1OJ7IDt7R1ydYO6iaINvr9cbwDJE0yFZ3vrHL0GxnVM5LX93idl4+Dp/oaNdLziJqQIkyjTYh+UlSyBjCf6mWdT4NUuEQXKrTqcC5y1gXSTuXwQkU8KUzSO3uxlbRLIyXGNM8utyU6E5WZHNLmUUp0vxORTsiOLx+lMZuJptID+kDLZtagKyS3m3gHfjbR9XagzuBmkh0JmjhvyVGh3kDrPXDXRUIdKS87KpFds5BYl93QbpdK98O0lrjvQ9Z4yOGIoEA31i40lvAD74YzNVph9T+pnwgSfuPZ8LGZeq9y3b9rfwflLUniN7O8Q3IawPSenROYpYZkIxyvS4PETlMOIltlUZ8QSIxvQsKkB0IgJD5s+oT35ojzFEv+uSaJdmfVfOfnJJ+P1dCt6DMoSxwT3NNFdmpCWAM/X+2Ih6y8uFa0IfdsT+g2uaSle6GkBZaoAAuOIGyemMh65ckvy7Y4Jr9ZAX2u/mW/MYgo/4muXpDrBkVF1DBzYqJCaFlFPH42j+eCnfpGHX2x5/OoZerGhbTM311fIRcOBGd9Wrp04NHlymBmISGzotEVx5Nr92MiWnXOUxjqSG99SSsK5QNt27PraUa6OlMxyrxTTCyylwQFdDYCcg042hM4hyULqbdczq2cskbxolIWG3s1MOeOcuW1otmBoaeoRPGRoKcwUtKkC2bPxyxb7sCZDceBbpY+W6DRYUrWOJV/3VwJps+c3kpb+lg7MFtToVC5DNeME+t7z8t0z7oRA0zpyKfgavF34Dbdfu81Lt885225wvgOfyFkJMww1E8mTEomMqozOcTVd85I/UNyZ+TlBnd0UgpWO1LWo8+QkVQjy+DIlIuJrRIUFFUURPXYDPt8mWFuBayC3dmpVePVFAvspqrSQ99dM+2S3RVg7jU4fo2Z+5YXJ6sXNs5DjQdeFoi62Qm2vXwj++hzi5U6KjksDQJGl9fn0WKUeshFBnT6vdQYVqFNniJk604/TY9BpX+yIJS72Qafn4vR02+mv5VINBIRZlS5VOF0dczY4HjJu7vDtTOoiiYSfDJ4/E+FSIo1GbuUzShbO5ifkcM0Uqup6nnE1+8spoUG4xHG7BO6mht2Xrbuuz5nukOjemmmuoWWgPRs4211wdlbJ0vccDHfZ7gZwymFseP
XiI1xtOyZf2OzsjG+9I77yMjx7wHddnPG1b/12xi978md/gvLv/qUd1ysN7Re/BvuBeP2Q/PhLtJe34cnb6E/9E9rXqtZgG8i3XmVKyht/5Pv5yHv/FZ/7B3+frzz84rpSHRJQifgHTM/MDFxMuvY0KBaslT+upZZAJK1ZgQARZVffM2CLWIMhZqHu7bHAJXDrY6+R98/wviclCHOkVGJvHGemFLnOdkweQxdUbKEd683R1OGzB55ULHzBHBYesVX8MpMYLeGqJIbZVMJX+RSBUDsr71wrIRbmWo5KxYjIAGkcKFnxuarx1NJKkaO4dCh2n3lvgWrRDFqsFFuUXOvIU44klIDgQyX/h0AIgbYunCFYswZqVk+lFEoSpBQyaUXySv2cGAd826NOSXlPSFuappatpHCxveTJ1X0GVy3Qaqk0l7ImzuM4Mc+em8FxHjwvHTKvzoVP5JaPJEtaWtmR2wtcd4aGHvBosY51p6DREqVy84y8v0YPB8o0Q6zlOjnqUiatjRRUJEuPiBe8MFecPH46PZ4wSOy8cSTXnzYPnG5y8rye/Ly4z+V14eSnAbZ9FUPe7Oj6Lb7bIk1DceY6ghTaqrnVNAcmPyD7whznIyqmUFRXxQKo812dB1faRtHVi5JcUIlWMifBOJCKkmLBLYjcs4mY9oYEhw5cRkIhpmimbYuvZYaYM0lHYpzszpbGVAZWSyYBaax7OPS0bUvKnuBbunbDZrNdjz2lRImJSRJuypQciJLXi7ARR2iEzgfEObwYyCFF8C9YjDnxNG5pUJpPJHpq9ScILgsOR3FpFdxWqW429SR7VwOzDHNTp74OGg+uduB5EbwIwTnaKinz623/CwnMrJwFNgk6B/0mcHkRuNvu8F4oUmiXEsXWc/naLV4+vyCEllICxQklC40HXye3hJUnRo0MvuM6F6KbSX2B8zraDwLzORIOqPk8YdCWA/IaADkHWTMoVX9I1snmFCkyKpQ9tnbt6MIPqFt5vpR5PB9H23Hlg8tACxL24kRSgKWz5xQ5e+41ciwvidaWZWXlvhXPGhipO3FzlMAavdXeJvO8dPa4cyDpeLNigdZCMlolE7KVfU4aNa0zR232TYuFQy2V1q+EaM2MFwQRVg7dcn3WyUxssgpM66NyqNmd65hSQ3I9s8yEUChzpB0C5+rXTt187ejlLpf9R9i6HVkH/GHEucxTrBuplKd48TxVx9kM3gVuaeFOgjtRuPXA9rV798B2hOlQr40ThkeK212xuWsRkGt6sm8IN4oLA81Zxt3bcOfuR8m7wj5e1fHxGM2OcStcDGd8592X2f3bL+J/7qcokwnMuvQNuC/9Cno94XWPtJm8H2FWxAfisyd2LnZb2H8Br8JwM7DZXvDtP/B74H/Y8+aN7atRGLUYj0eq4HMBQkPRTKjjZmKxfVLOsZKi1EBoUT9cArFHVNHM5ToSCD4RarvWJmc2H3mFvOlo9olZd/ThCopfLVmYMzGb/lmppcIk4NXQsGXSVThxLMhcFo8nr53Ly2sEmCRzrp4rMk+j8Fpn7gtgyt/bUuha4SMPoZ/N05RsvLKmTsPBNTid1+SoVK0tX47jc6hWJp0KjQpSYHaZUkCyEgdD6K/mCRXF42qZyBGCp2162m1FdluPRHMSEDHXk3Ef2JyDSF7FhBNqWo1U94o2o7lQ8hlda8de8kjftLT9FnEBNB+7TQurY0eaZoZ5h44NrzZbPjJe85GkfJQtd/KFvb45Rzb30NAjboOomB9iLvis6Fz9Lccb3Dzh5hkmpUSrJqQM0wL2UeUk9P0VgSI8Z5GznPf3Jbx1VjWn3ufLmLlyZVO17WEdl8eAa5F12emRwxQBxBxJPEon0ATYtTvatqepAUnbbXD9BtdsoGlQ7xEXEAphssDMhx7vboyGsX9GThFf5013gtyAxUwO8xL2pQZn5WQNytm6SzXjSJBnYspM5ZiMyGhWdc31TDpkQic0fbYu25IoVXKi5ERKE2MaiG
myOV8LWmazbwLmVIjZ0dLTtB1t04FEQmjp25ZtbxIdzgkxTaQ50paGSRNTdrgS8VVuwXeBs7YluIagVfUaJUlDJhEX5FMLSRVF8GrIrlTLw2XMh1wI0lVjdG+K/wXUZ5JkfB1NQWvVqiKUVVu5lv3r+VKrGuH1Bd3Pr7/9lg7MlpLcGu0vjwt0faDtlH7b0IcWHwpNjSD6Wx23b59xa7OhuEDOgaSFnIQcGkiWrYxERp1gHoheiEVM76Sfoa8lqb3CjYPQ2o9rURpUh+ejHGyxERTEV6NvqaTRBTZ39bss2cQR8tYjycQEC/U42XwQsKXL+0+RMj0J1uT9rz2VyigngaCevHCZkJCTBbK+0NX3gQU3cqorVkUHBUHUH2HOtUazhnFrECVLy3318TTUrRwlLbIVcQXjmDnnVt/BZb8LKrdEpAprQ8WLm4GYtkplJkQNodB9DcyaHTvtyakjby9xcaDLynZQmsOEu6kY0METROjjRAgNYRDK/gm7XaSbLNRo55adc4TQcNNk2nni1jhzsd9zZyjs9vVrZFYnikFakmRUM9ONI8Wn9pruKzQ+Mactu90ZGgf0yTPk5hLutIRKkivDM4bPPyRdvop2yre82qAfnRnf/hJN5Qq5wzXD1SOaMeOaQspCexXx/Y7p7hb/0L5jfCcS+o4inrBp4L1HlJc/wjf/776H9H83jcL7ZBpnRPIBuEXDPoDPIw1Hkv0WQ5huffqb+NT/8f/C/f/b/5XPf+GneQrMlZ69IfIqgcuX7+HbhvHqijfHpzyJCVUIddLdEpCXXqabEoN3+MM1xJGJgBzs2MehsE9wo3b7HrAyl8eI/ctiHeuwEQdPaqlpI1b6XJIF8wmtDhEoSTzPouO17liEFa0IRyO8fKXcGjLP7u5giJBLRRCh84EzhCSRJIVc1Hz45IjshFzRM1eDypPJHxWmuljvhwNaEj54JNk9GFww7lBrb/KNQ7xJrVDMmiYOcFbMaWAhSxdVXEgW5OTINCU6HyiaCLWm4z30rjeEIzQQp5pUyUp1oN5f89RwJ7W8NCY+Onk+Ujy3SgCxeyO7Fhd2RGx+dATQSM7JtCFrg4Pe3MAYkVnXsqkU6qJr2+L9e6oRpiKUF7LaJXhagjZZkk9q089y/LBqatlpU7wr9Plot7PIkSCeIqaJlUpkisJYJ0hRk9MJAp13bLuedtMS2jv4tsVtLCCRTY82LamvzjLBU7TDlYRMFoR7sSakNk6kNDNTrL0ql+ekkVYm4wIOwMqplUVLz4Ekc7tQjUxxJhfIsZDWGkahcdDMhf3Nnu7WOZvgKB6KZGK1uZryDTFNjHOygFCMaao5ESvFZ5wnvG+QVmk3ViNcZUck0NTAP+MR16IidKmgxZHiEvDUJLwN9KEh+A6KlW3FOQiFSYVQx3MqatQJzURNeO/JpZDJazJVCiSxpCZ4s95zFYVwEllQL6mgwOgcGguTh11ZjOWPY2v5KSd6nb/W9ls7MJOF3HgSHLhaZfQGXnW7jl3b0wQIYpPW2XnHxSaw6wLFdcwZYvFkB1kbSm3VaRLkEUppSbonU5jKSAwZbltglm8i/sEBKR3Od6jrUGkqInacwK3ikFCvFfES88XTY5nCJoQlQJE1rtMCuZhPH7BaJS2TyOn2voCjPrBmiy+ibHJ82ZIJou9XYv/AD5EjcrHs27l69Mv3cBUdY0Gwqsm4cyfHpCvCBTbRqNV+7YhEEDKqJ6rl9TNWwmxRU+3X5zXN1BmhYjlXGQsoI8cJuH6V9bs7BfWmNbWdPPtH9oV7DZzR4wZBuoCPiW4oNM8C/irgrmoAd5NoYot3mdA8w+fMvrmiu/02w3AbgDvDDcE3+Jh5IkIcI+fjyO0x082JxcJP1MygVawDT7NxaCY817U77ey997gtSnd2j5zO8XnH9PmHSD7Q7d6gNJZ9j+PbhMvb5PEGkQ5HYfpIA+UZfb3gcZ5x1wPjnOjPPFqEKQ5sSfhHE7EKmLablulJJlxeEJ
60SHuLKX+ZfrjNN//u3wPAW//ynxsaANw0nrPs2PnCCJzpju/4jt8GwPYz30b55m9GXv8Y6cEv88r/6f/AnX/7BR7/+D8EtfDtlc98FC5bePTUFu7Y8PG+4/rtt9C3HpHe+1W7RijbN94gfemXaecRH2dubgquTMzVu+5ZhIfAlRhal6nG1RznlNNxvhQfnpG5rANp4RWNdUF3uZL1neMqq+kg1SHtFSYB7wphD/f2mZtXA2EqlBLIFTrUpCRiDbwiBV1L/MuIXlwlKGb0XKqGEtkQgiVJPcwTpExpwon+HzjnccHmwdA55Mb2nTP4oujsKFNE+4TUoEtKMcSqeJSBEoVcNszzFbmzAOLy4g45Kpt+R+ha5snGrGJ2PVqzCy2enByvJMfrs/B6bHgtb+iKX+Ek5xtqmzyuZJwIuZjHsE4D5aYiwPsZHRJEWflDav1U6/y12CA91yB1AomtIaOezIXuGAi/yAZSMcS+EaVVpXcWkG12F7Q7Cyz9pqM4S1tVQWfTtDzMaeWF5ZzwmmjF9N+azQa/2YHfIm2HbCvna7NF+h7teqRpkeDx4tEUcWNFzFXxKeHnHU2aUSJNSWsTybJZg5ydnAUxLEWRUpCl/KhUdQFLN3Iu5raQLJAEkKq46ucJbqBJZ2xpKa7q09WLkRhJWk3XPWhwaMyUHEkvBGZNmtnUClIpae24X8rpyxjeND0xFJoSaXNhILM0GTfFU5ynaVpLqtXjtVgCXxRXsxiHJUNlPuAaT8iu6pa5FUAwXqXhYkEC3lny4Sg0Ck09rpQStImspiww1+7ccoJyz6I0FAQh/ccQmK3lvHI8n0s8ICJW6O09bdfSN6zt203jaYMpaItXU0FXYbGlcJUckiQzlZkpZWKa2eM45EwMA5zXlfPlDr7QICUgLlCkATwqDcJ8giDVGz4XCsmCjBf9OVbxiWoYXqMly8afD65OhRBPt+OEImsmv2yZI1S/Bl4nQSEnaNlzh/XC/rUeZakI2srTWIPkuhB4Zwhi1tVvrEZtR0V+EeNALAqawHOeU8BiLi6Vc7dQd1ETCFQtUBKuZOucKZm0yGXkZEKxi4vA1/liqzWnE+uOzQmH0hwKlU9N/2Tm1htbXFFcVLrhNcLTB/j714SnhfapHVf3CJpDwZNBJ1xTeV533yFXPtEwPUOlJarS6ch4GLg9Fc7GQpx0Deh7YEJJukiImPxEcZG5jp+b0XFrHGncFf5Ww/XVSHujdDFBn2j+s7v2mdv3GO9+lEcPHnC5+yjnn75LfvuKtJ8IC3lrPhBvMtsOxoPN4JvGk55G8hBxtd076kzXO8rhEaXfEHfK5iA8e/KYzd2PAPCf3LnDTz9+TC8tG1VEIikW/tPPfCcX3/MHKa+b0TTnW+T1e/Cr7+F+9SHTIdB913fyyn/y7XBtjgS8fI8y75HP/jT+l34F+oCfhcuPfJzhmz+FvvspAA6f/Z+Q6yui7AkpkwhEJmSqnDcMIYtYILUEZAlbuKSWq063RfrhysNZlbA4nV5z/XvSQodjT+ZRglc3FU2KyiTQFrNEffU68bUMvQu4RimLT28BPxdDxHOiOIGULbBxx88qGKrXFXOAaERwrsG1Sqk371QSOUecBsqJl6BzlRCDdY05qZQghZIzaczkQXG9HIMpiiVUWvB+ptEOUianA3N1U/C3X4Uys9tsaPuO/XVF+lioA/V+VCuP3U09d73jUjt2JdR6cUW5S4uWQqAx9fgcrcSWIzIfyFdP6nnNaCqkCFOyuScnmz3XkqEez5me/MARTeP0sROESeVkXmBBuaw0vwHOGuFss6Xbbkh3X6fZVleMtqWIUJZgYE5IipzPmW1ZLPpmPAXnPLiW0m6Qpke8Q0IHC7eq3VLaLa7rcb5BnaAh13mqlqQ3G9w40EwtZWopqUWnaGbbx8N/ruNT5CjQXDThKxqhucoLOesQjlooYnNurKh6csnEhr2SdMDFRFc6ogYLQmo78txkXLCABifG9VOlpM
w8VuWDaaBpGsY4s0uJVCKlZIrOqObjGFRFa7d+KBCDJweT6Vn8Tp1rbMw4R+c7OhwlRzKCtMfATDWipeAaTxOciSlH0OKgWEikJRCankasQ9WRERGaIGSEWe34B81o59jGzNzCNlpHLJkqMw3n0bGVQJMdffyPIDBbrH5OpRaWYC2XYoGQV0Jn7dxS63qhM50U9UqSWJGiGfW2yKeq6h3zyJwmDuOBIU08LQ03kpnnGc2PAHD3XoPdjjxfmSm3D+BqkObcSoRUame5Kmi2m1afv0iiBeeOUhlalgzmCIVS91PUIPlcy5pC5UAulUFdSNTCaYt31iNydLotE9MSbJ0+thBplgDM1YecGLl2WciWb7OYyap48AGRsgahRR3OhRVRU2zga5E127dyvVszW6l/W/akVoagonPOmcK+eJwo82idSaF+i/SBuOIRITyxwavdVIpTjxDpcDQ50txYZnrvC0+485E3AKEZJ8K7if7xY9xbj8nvFPwT+47dkxYZMlO+BqdIU+g2Ac7e4qy3KK+XiRxactsw5SvOh4HzpwfapzPz/qivJFQDcXQNKj3gyzHs3o8Rnzd4dhSUTduiUchXiv67m1VL79Z3XfDm439LF7a8/eArvHtxwVnzDh/x5u8HEFxmTqDW8Ecf4HrIDGpk8cq75qKBqAWnHTor/f2npJuZbbeh8V8C4I2P3OMXHj/m4ApaEgfn+b3f+T3k/81/hV68jKtyH9y7C9cz+QtfJORruOnh534B7l4wn1vEGDYdrm8Yv+lTdLm2Drx3QNuOzZffJdYO1fNv+13M432aaEri8+E9mgL7bAEZWNlSMT01xWKCCPRFyV4Y6nhpl6GvFqwNCk/FOuZivc+sjcHG0izgi424+xleW1f9gs+e4jJzFl65mvG9ENR0D1feSVRaJzTFId4TYmYGojvakDWqVKoq6uz/JZjkTuOMKwdmJJ9Kptce52y+c96MyBdOTts1ODdBNM5rzkocI3GCTp35+ILpRMrGoh7AS8CrouLXZgMRod/2nOctu92GZ4+clYbUSkrTek4Ld/LI5dBzL2y5yHtCKtWxvS5HwZI5skMlU1IkJ2fdh/tn+L1dSRkjZcrMkxIjtZOa9ZpS751Yf6/WRfo8T2zdltzxBO1QjuioYInSuYdb247+bEd7eRc5uyRcvmR+QGBBExBUkATMMxojOsz4Gmg4FBcE2hZte7zrwAUT5AgBqo0avkVCg4QGfGNzYYrkFNeo0iErlcM584tUVbyzIvlC2tfKTV4CM8FOgqTCUsMT5xA1R5pS18IxFyZ1jPX6iwrMmeAE3wfyPBDKhqCBLiht9YeesDkEL8Yx9NloIjlzONg17A8bXGgJ7cQwTTRNIKaJOY/MeWKuCOOckjWBOePlFa+oV5pZ1xIlgFNzp2h9oA8eoikEONV1Ui0FNGc635DKjMyKSAYCUl8UQs9lf0HvPa6BQxwoZPqmIascK2HOGo/cNLBJBvROufLO6iiagmkr0gbm9BsLuX5j4duH24fbh9uH24fbh9uH24fbh9t/8O23NGKW1aqPzh2rX6UYilMKlGjQKBRDziopU5xF8EUSh3kmqZI5UIoSc2Y/Wv17GCfmeWQ4RA4uceVHDnkkHnSN+Hd3M+wC7qohBWfkxdBQsjOdtXqsmjmKCBaD9lMp+FPOVK2rw/Nl2ly/16kejmV8R9Pa5XOWC7qiXnJUpF40dV7Ej1bS/8n7XkTUpP6zZAorv+9kf/Y5UksXaqxpBN+ElcAuxa86Uc+TbwtLi6QuqZ0qRRxeTJuupIx4R1MNcNtNTxBHydmQspyrl6BSbSsJxVE4ivjCEVlcyixwUtLQJbMPnLFBGGj2dlwXP/82Fx+9R75b6G+eEd7eE+7fUL4qpKcdeajw9vUVEUdE6Br7LpGOwiPKLevKlCYStoHdbmMem49vSPsD41h1vioOmcj0GOG8r1eiqNBQCCcI6jgeaC4/SpoK7WYH20B0e8LTGfdFG6vpjc
zFywNfyQecu0S+5nl898Crl9A+qidisPPzGLgssJ+NpN8hTFGPmZwXGJRUBkNhnwF3PJQD5VHtknr9Fr/z05/hX37x8ySBz7zyzYT/5o8QHidIE+m28e04P6e89e8IX/qfyeWC7tBy+FhD96tv4rp7ALjP79Hrh7RuQ3rtNcI+kS9mmi+/S3rtDk3FSIZffoT/yMdhvoW+9y/pvPGnXK4lBirZX4/lrkX6wGGSHMvrGqy6NmdrUEhqSvIviWOstIc22/nJVSxwUtvfowyajghwQ0B9Jkvg1j7RlEIrdo/41tWxGihSyJrJKRODoJFayrSLnZ01ILlcqRkRklO0ZFpk5QHNlPVzF8QsOE/jA20lS7etGTl7UWKV6CgpMw9Kk8qq1RSw79IG4+yYI0mm8f0qcqqa6ftzLjSy2W2NUK3FOLQLkx4QUV6LM3cHx92+5yINuHGCqYN5EYczFxWSUDQimiB5yjhTpj1hrjytFCEVNFUumR7nuFUawQgQK9dsnTuXMuYJf/aUq7zQRepO7Hwh7BrHrV3DxcUWubxNufMS+dZLsNkitaYmIhWp8qZHMo2kOBK6zZH2IWJejW2Lth3ONzgJqLNS3Kp06pZKQ0aYUC34GPFzpsSlLFpOJjB7vW8ckoqJmtenvFTeY7HOVQl1fUkJrSiXJxgXTCOZTCiFkkxYt1u8U6MyRGiTcnm24700Itk4XY1XQr03vLM1rxSjYJhHZybnzFwFh8dxpOk2TNPEPM/Mc0NMIzHNpDwTY202ihGcp9dsw6QooSQm8rq2hBjMN9c5vHO04vGNw6tx15aOSzpHjqDSmPOEZlMR8A2Nt1l2u7nNrXbDdtPgApR9YYqj3S80q8QVnVV78JmzYWbsPG5QJmtFXr9jqx4XmrUb9dfbfssHZuKhU12VQVwj5CA8Pkxc5T13dCC6HUJLqKWyxsFhHnCuJ6Y9U4pMOVFQpjjzbG+B2fXhwM3NDRI9jQpD0zK5kSiR8U3jVuy2M/n1TH6nJ+gNpQmkdgfDHqd78olmiccm9xVyr9pDawkQIcVlDhOkWtdQ+UXuuUDmGGSdBhjL2FNsklnKjGtXkjw3FxmxuAZ37uS9i1YZdd/PfTas8h3KscS6vFBCQLxHpMGJIxV/1E1bSq0IIrWIUzKSImM9AhWPJDVQ2SfmONIGz6a/pD3bEhZirFNSijiNNC7RNJkpR6Y5HbkJTiGXWh6t14BjILraVD13eJGN8zTNREMmLDPbFchP/ALnZ0JiRt7tGfYz+UbReVx1q0yzq5DEIG0T450BIT6yVv/kgTDRhBHXQYkFP1kA7ypPzs6Tkf03WKNIEF07qpaD3gA3w4HztiBnPbl3JDJd7q29+8oCs/DVnrvffsGVO/B0esKhePax8JXPdHz0aZWI6TLewTBbiQSndAUOKMVbVyJAVCENSvYwHoRtG/Ap0s0OObfrOL/7NV569Q2+4+wO/2p4xEe/7w+h27sUP6LNHjc+sX39ytt0n/tV5n1Cmxvm6/v4B+CyIl/7nB1XntFnT5E2wPAM6W6TVXH7a8LHXmP/6CkAu8aTyOitS8o3fhf5F36O8WZkT2GuQcvCu0k1XmgxEn/GunBPtEnxtfQzKPR14R85ZkkZT0PmBqVTR6KwkcC1Jn6p3lDf6mGfZrYOOs3cehrZaaTsevwhc+6sbJUkcx0KrvV47WhKZKxsCLeYOavD1bK+qN0fSRxb70gq+BpVznFkUoVQ0Kh1ceosCK3HFcM1oVGGAWue8aBEynhGHmfcpdWt42xBSZSJLXdw3tFGpYSMZBtbD3Xi9eYeUs54/e5rfGH3NvPVE3YKczF1f4BOCy+1wsbdsBtaxKnxw6aA1uWolYy4ApKNIzdn3DjhD9fIzUSuAZxmkFS1yhxVb80EpI9yJvoc180tE+YL9AytelRuva527Vs5dtftUO62EM62xDsvo5cvE+69RLi8RepbdCGX+xbUkeKE5IQLLYwNKvNa3lLx+KbF+RYNG7
TdkJpaviwZt1Bg1q59gTzjciY5D6Lra/KwR6aRMo5onmsTlAXLp3xIkVqmrWMasdzZeyH45biszqlSKLmwF4ejwTGTmoVa0hLjxKYIlxcj13tQMklC7dxcuIIwOyje6N5JoQ0NKUXmqr1xPYDvAzIGNmNjGmT1/XMZuKkcC01CK47YZlwRmuwZVUjFsa8Xe0fAOQi5Ydd0OGdC5Z6OyvgGIOhsdmJFcQnmUqU8aDmvlIiXzuCO62iaHXmTGeVAeZLIxRP9tKr6n7UwpxbxQnA2R9zojDjrxgXonDDFCX8t7Jes79fZfksHZq0uIo2Qq+CRcfgFxsyjBxO3dwOXZ3u63hO93dANkOdEiQM+JWLJzEWJOXE9HHhybRPSzX4ixkzxwq40RJ258ZmbkHh2ZcHb7iv36d2r5D5QBiGLM46BD+gLHLIla4u1c0j1iPqB1cHrPQewaqKoKNXhZN1P1mPmt0xCnuP/Wf4vR0TsNFtca+QcUbLlueWD1ljq+agFVxeoRaNn3dZss4Z5ekL6XzdnN7/WSac+mvHGXaAih1LIokAm+IbNZsvm4hy/6XC1L93EhDNkJTaFmRlPxLlMs2iiORNPTBXJWALQNZA9ObyFe9EqtGQ2KmxxhHqDNU8hDZEbZ9fPX40rhyWKcZfgyGkpcnISK89vyZdKRXF8ysTJxsDi/nR6Sl09/+0SOOgxGG9OrkeJiTI+opx/Cpc8XdeSXSHPM/l+RUjuD+TpFV658zGGx2/y5PHIW35D+HjHN/7PdmQ3GRDHVgo3qvRlGYMV0KxDekqF4jHldc2Ij0gCaRxjtOBTnBKnR7zyyo4/2H8r/I5vg9xQ2OOnkfjVd+y8Urh5613aw0TnFZ0HSv8Y8Y5D7fDaPn5MvHrC6APdowOHT9/Dv/mY/PJLxGfv4O4/BmBuGuJmQ9jf0H304+itM/b/+idJT0fGZVzX67fmE85kak7P+/GqvRi0W7PAIto5kukWpKVemwPWifq0Ep2iFwKKFMhSaPbQT8BFS8ozOdrI2bTKPDkmCqHdkKZEI0JxuqLM4rzdz7nySBN4PM4FclLm2lxyOByIdfFzzpGrKOnaAIDpmDWdJ4RprTJAJs6ZnGWFjIITyIWueBt/ruB9MGumZXzPic43xI1y9/wWd7fnvHfz1OawBkK9Hy8ULvLIbmpoRfAo+QbcKMjJalRKMakYVyOJXMySJ8dVUEv1WA0w3li1PjqZK0+DsA+6tiso9sI0tXCwTu/HzkO73eDPdoTLS7hzCZe34GyLLpwwQMWhxSNNj0sZcYN1PmpZ5zh8MH5ZaCihRi6N6dZp1hWJFFVK1atT8ajLuFysa3I5D7kQYyTNkRwLJWfrMKzf4wMHNjYHiTcUs6xcaENEtfLUnArJienkrTqTptkmUdm0gTYrMSttDvjiKBV2HNT4VqqsaJVWBHWpTJSYmKaJrp1J80hsZ5xYZ+Y0TStiRoYmOFR9Rb/y+nMqCbWMb/stBGcNA4W8WjIVDbhSKENkziNFYdJMG5R+a4HE2dkZt7Wh3ZzxzM241CH9gTkPFsjW79j3QlMybkooHldyDfIyqQbqI8KUMzEf9RZ/ve23dGC2b4QzlBGsgwnYZrPw2avw3qNCu72GprCfD9zqDWnZdg1BMh5rb51LZoiJIc082x94emWB2WEuaBFTaS8b5iA80QNP3cRLrZWkbt4b6M8jrvWUNqCTqf+L71DfQbWnsSHvSGrkdbunDE9fggNV69wSWZoa7PFTo3I4QvWJOgHVzO8UJF0mlueCLuoEdjJxLQvP4vG9HtYC9fM88rYcT9ATROHkuIJAERPHsACwFonckm3JWi/QxYZJzVZGagAUVEguU8g4hLbt2W4uKOct0gR8XVhijowy4NpE0EQXZ4omOsmrd+pqAn9yLuB5K6bl92KF0qvQOsXnRBDHRu028VrQ+ei1OWPaPqMIY/3bzo9NPOaEYK/PKhR3JIQ6tXPosi3yXhfD4ZPImFp+QJBakl
40lIB1cpAa4efhES6/geDRLBRn8E8YqwjoLz7A/3JD/ztf5uXXv4F9/ir7d2cef/IVhmJyEzQwj9ZOfqiSJaWBkEwwdl50frAqTaeZ3c6OaTw4um2mr34yiR35nUdszs/Z/NAfJkaYz4RmPlAevoe7es/29exAOx5ovKC5kLXAzcyUI92FkXGnQ6QjoFdXzJeZ7nM33LTC5kmBJzt6Z7hGvDhjUwJyfUN0V4S25+L2J3n45JdYPCdKRaAXv8JUynPOEy8GYqebxwKn87rAJqds9HkF+IyhcI8qWH7TCBc1y2oEyhXcnYV93xJzXu+zkhyTd9UAuRBQW+Trbxs4bh3LJWecC6haUGbXxX6Ph4EhzqbY7oM14ohYx90ybqqkEOJMY0oVLUKaZ9Io+NpB5jw4KTTZBp84j/iWgGcJZ8s4o95x4be8dHaLW5fn3L9vKH/miExfFridE7c0stVqiD40yFy7uOuBibgqo7MEKJaIaS5rgFDykcqxNDad0jGWa3iqUbbOc3JEzdftNJpTrOHGwbY+vt22NLsdabcj73ZwfgbbHvoOwhZxiySIacC5WAOtMoPWtK16LGsIpODBOcS7WjYuCMkipjoX5gKos4anUnBFEDzuRNdFVCk1SNF87GKss/Bz47pwrHYUsYSwCOv4UZ9BastYKbhs6vnAKiQuJePVE28iuxZ6deQitNISylFrJCeItVThqahkNLunpXs4p2INJEXJ2QzFvZcamA3Eaaif2VCcSV8Ibg3s7LvK8XdZEg+HQwne41whqFuFnKMkXHbkYSDlibkIU1F8SGtDkguBW2e3aPuecVTC3NLtGob9gabrqWwatn0hqcdJQVyiLZ6uyXhJjIv/qELKhZJGYnnRwOuDt9/SgdmghTNsseorRFhcRoojq+cwCA8e7AmSyGMmXdjFnDpHHxytD4ySmFPmZhwYYmR/mNkPVbm4dnhsgpIyaPCMJXFVrojdywDMgzA+eoLz57jQ4F1Lcg3qGwgdWtuQS57IWtbJwMaVqQ8vuJFQdYrkhCPG+5GxzLHVOS8vWrYlIVv+1IqYLe/heVRs2d/KMTsJ0JaXrGXMk32bsOYLk2Cx9nTRmmFVuKfoURQW59bJ8rgC2t37QarIbfBsNh3hvKd0wTgKtbvGO4HgkUPCzXvyzY0B6UrlmtVz6E4mYbVPSpTnJEiWMmensEHZsAROZZUzQIynpFhWPDkxEdCT4NpeduyCjeti8X7pkvU7qq9Bg4WQp1l6o8YlWjR6G1fL4Vo7jTiWaHTKhMMVcr5hisV0F30kbCwLnN+C9l+/hXzMc/eTL1HeeIVf/MUv8vi1j/HgwgKzlwfrYBzUlvhRltKHEE94WSNwhnDL2eJ4M8F573CdUqo47jQO7BxMn/400vaUNNN7Qc4D+kvXhJrx37z1gF3TMRz2eFG6s5ZhuCI4ZfpaNaZve8owoG3D/PAZrjtHsufmvUdsbyWugpU82q5n/9Z9mpunNM4TS8JtX6E//xpPr5/auT+Z1Jd7aWmgfg5xhrUM5upiRkWvc118vBpa2qkwOVMHL1qYHRzqBX+aC5eNo/hi8ieTcC8lDq4QvLAJljDOAo0rxAZ0PFi3siq5mMfrOrYqmmHd6BhPKAuuPd6R0zCyn0YrgzbBAjq3WKNVyoBTXOtsIczFusCDULKSJ49Mlfu2sVJqEmicmvaTdHh1K7d3jDNXkvhoOGPcbtmcn5EcyGSB3cL5PFfYFgt2AorTQJYaePrFXL0x9wDxmCC3dbouCvUvdlIul+yUQ3b6fJbnfy+J2iqHAYbil+W+1fU17qS03fQN2vWwPUN358hmC5se7VucptUyTjUjaYRxJN/coM+ewrgnuQ2+6r55H8hZa9JrlR/VGTQbMhjXA6vSQjVBK45Su/ZPB2uullqpFLOnU5P9EY5B8Yui4VLv7cU7087hMTDLxdC35GvAf1JCccB4NSJjpA/CvmQET04FibXCkARRxyR55e8Vanf50qVfDA
ktJZFjIc4J7RwxJ4ZhYq7BjcNkOpBQk9OjMPvpb+NdKyJLAlK5y571YmtRikaQSMwmoFscECIFA1JSStx0jj4MTHlP6wPS7bgeDqQScbWsG5rWvGzU43xDypmYjVO9JBBZA2NKVikr/xGUMv0ETetAhDDZYLjeYcTVUXBBmG6Ua1cIcyIdrMTypM80rcOFlqIzqWTmKTGVRJwLcx1Y3jeEkGnbwDhmtjimNjH4kZvaxnu7P2eaIm2ZTYE5tJTQkKVD3YZS3eSLn9AIp/eTEfqPN5fUeVX0eNPAMag45Y8pVfpCfm1I/vSxJThDnw8M1ozz5DOeM1OoE9oyiS26Tct7l892Up9MxXzCNFU/y5MamCREfQ1OjwfrNa3lmixCwdEBu27D5mIDZ46m9UxSSDWrafNM2F+R33uPNBgJ3eRrj6K9HkMpsrOgy7wNTbQ1iXJ6ugLGJ+mAbTEidYMtvnZcR+2rWIpl6OX5IBqMuzTrMvHreo6W7PT0vBaM4O/1WDqBY6lssXtHdF3gXIU1V10ijGgrg5DHhN/2JotQElNuKMvO2gsOP/sOu0/cJ28yZ69d8Kmd8i8PT3j0MZOleOnfJjTMXBWlw7PXjDcNTxocudbdE7Arglfh0WDct/NdYj/AfGUfd+flhjQn8jf+52y6LeJ7hgdPKY/u02vD4ctfBWCLkL3Htz1+mBnee0TTeKR3bCrQIL0ybwPlpqEMB678nnQdCdIzbUam+28BsGs64s2e5Avib9NdXDBPb3P73l3evX5m59Q7xmy6RK5eI2DVrPqge0gw4eMitrAsOHiDXeutKGO9rwIWvC2X+n6Bj4tx9JwHjcLFPiNB6F0g1kTDzTNFGyQmGu+N5ySekuUkCK9hexHUwZQnK286QxGW+SSNA4c8r36ygjc5BQl4X0MNp7gOQutJByNRi6tWcZND6oG5zpTk1ZkHYYOjEUcojlwXKI0TN2mi2d1mWzrOdhtc45HZjmhZaC7UhLsbZ8h4cYILAY1iKBBGa3AihhZKXsWnTbT7VBrJrscSQC+Jl754DfWYCFqSVCWEToK6xc7udNpz2FwRFhWPYOVHvz0zsn/bImLnvuS0lnUlRhiv4foR5ckjeHaNm5Xiekpn95nbJprNuYWAzrw9tTg0JchqMhtQaTGLaLZC5X6dLiSlGJJkDgMFzYVGwIlDyGvwLFLPg+M5r+SCUP5/7P3bry/ZltcHfsacMyJ+l3XZO++nzjlVrgsUFIWMoemuArcfLAQPfgOp34AHP6ECycYPCMsPliy7ZP8BqF8s+wkhI4GQQEJcBFhA0UhI7W4oLlXU5dwy82Tmzr0uv98vIuZl9MOYMyLWyn2qDnRb7iNySpl7rd+KX1xmzMsY3/Ed31EPErVkKlyVnMpQxAqMt85RZ5GPcsmk1zPde46cJhOToyyGoGStIejKB6/97GStnUwx+Yw0z6SpMPtEFscYE5cpEqsci8cx+2icPFViLstzL4anlEVSSZzig0elkEqykG1DK2ImlwmRmaKRpEofhH0vSNUnmzXyCSeOyfjCu+MVncJ+vuM0vkab3plmvK+6qH6wxIakjCEuSTa5FLK3MmzxqdvwPdsPtGEmYjpVmguXloUHxIrSOEk4HFPyPJwKsRZ9lW4CV5X0y7SW0KByg1qSQGe8gOSt811KPPaRyyExXaq6scA4O9wlE3AEP5i3FzzFD5Rki1umAxdtLjXaBE/h9GYgtVDmtm09wTUr8xmf4lnbepLb37fGVNn+V9HxJfxZDwrPrtHqxbVMttakVOQBK8PjnGw8u9XKbBk6Fqqwj10VywSbxEGFXd+zPxhalgYofsY7cMl26/TqM/J3PkYeYrXAgCo02agcInaPvTM0MgMRE4nMbDhmYl8PVP2qtrKLIVRgobvUEKvakc3YNWPSLjprWUIGLczcFv0n77UulK2/l75cowEVAbSQiqvX9K5Fg2tICitUnPIZ33m09MYL8oIrBzKvretvhPFVz/D3T1
yuRlz4gB+5ivyDb3+Hb7xlyNSP1CSFWSwTM9DGqKGMrS+sqk/hHhgJvNcl4hnuR7iyKD/n84Tonv4rv4XU78lO6McHeP0affURu6oXWNQhdDXZwNH5gTzO5DHS17p7p1cnoBCA0UXi/YVI4fpm4PGSCLWcz/23vknOmTjNvJWFy6cfkzgzT/D1r1oB9m9950MO4riviF3vHOds595m6mk15BccqqI1Xg3dArhSC4svzhT1P2VBgD8SZU6KH0yTLFM4virs+4G+nznVMJcfhDIl4pwI0lP8vCC5q5OmFs6qRRDPKJ3YfItxTXpxFMYUSangBuO84sz48dLK3BixMXTB0IwMpVhZmjILvhJ3XZlwTnG+ZyfBkOTZ8h1d9VpknrmczkwvlIBy2/Xsup7Zm8NU66ZzjJassnfgXEE6M/ZEdCF9ewkU8Samqq4S1T0qRlMpb1B52paoy+srfGJsaUU/21L0pioPUv/Q5p0D2vbgugHtejT05kS2UGvKkCNajSmZRni8h7tXcP8KOU9IFFwZrYwUkE8nuBkpxwOlG6BGW4gzmtyy+KrzqHe1LxRJiohHitJEe1OayTmSUiLWcKYX6xRRc0DB1nGpA1RYbTvNK4exOLWqK8Ut6vWmxr/2lRZQn5EM+bsF974i0wW9v+CuBXY1XCuZkhTXaIGlhoc3qGfba3OOkCuCpoXLNDLPJkht5yqkWJjFSiiVVmeasqy7LevT5kABpxQKWTK5PReQiKgmoxSpRTL6zrFzLCoJicSlnAj+SHe4pdvvSSLsU+ASo01kYIqZvhqbnQ8E7yllYvCeUvUCxzThXcE7o+h8P+0H2jD7kQgfDcq+QK6pMy+AUx9QScRcSEXtpTiPNvLf2bL5pmmmeOWwl8qNgiKOXE1wV1PTVS0FOFOYJTLuMnfFFtOJbNXkg9J5Qbw3Fyt0qOtp9d8QuyerTr+GJltCADwNt7mNcdYMpraGNMMssXqH28hgPQXP1pwlhKnwdPPZnL9lYpbNCZKsoRyhCZ6y1PBr53cYmuSK4jfpTjaBqoBh4weUteC63UervAlBCnsPu14IhwB7M3STjHQlIQ8GycTPXsEUCcGMpvkEfe3fRsnxUj08lbVAPNWG080E0BVdC22Bd+s7gipOXlpI2JEw6Dxh9fkav2dByCqHcfuc2+Z1NdqaQ5zBFrzNNQOmPq8VCXQO/EYGxblaLu444IcdboDkHN4PtgnXkkwlv8Zf3/D4yw/kvxe5TN/i7oOJ+CryryuR5v8s5ul2wCQFV43RSMv0qtes93p24EikGe5UcE7Z11qM+TOl/P7fY0KaE+igSJjI95+RP/42u2jzcVahTw5XhEkKWUb2g8PFSKoZDnr2qA+M5cy5TOQMu92OU7zH3ym7t626QZ7u+fCTT8DtOd703N3d8eKDn+T82b8g1+Sfr3/1h/jVb33LkifEOGZtHDjWvm+o9TZZJrAiLmDjpJVs2gMnbA4ltbI9AKcCDyK8XRR1jkzh+vVYS86UGlSHPE8QLSMvB6zOaL0RXUI/sgqkyooCqCpZE65ymCjKOE+UlPHqajjM16LmfX0+RUKh6we6EKx0TV2E8lQosfZKMQS8p6PD4RFSDXWlCu2UWRnPJyNrl8LB9wx94FRsnLeSX302En1XrLKG4hEZbPBv6+q5CnGICdniBBc6K3XU6u42rio8Qb/aZ7AaaE+2w+9VLPdJ0yUMuJQFwqMu4IIZjV7cSs0YI9TEC73co/ev0NevkNMFNxVKMtHrqtGKzjPECXfZU/oO6Qac31Vel6OtTMW5upcIgsOVgLqMpkjJsY6bmRKTccyq2Hq2lH/L4G7rOGtGcuuzUtoHDX1TSyzwmVRMJFiSkfcbzzGWQumMQzt+XAi/Veli5nKJSN+ZE2CDYgmVbh1lycVCx9SEN2fIU0wTPgckKufzI+M8kponGAqxZFyVm1AMITM6T50/mkmaQBLGUM6o12VuxGo5pzKRy8QYZ0O2PfRO6Vn5lx
ciThNd2PPi+ILQD1wyVnbrUbm/2P5/8JHUB3ZdR+cc3jucE/o+4Oo4HSehxEiOrpZBaFLX37v9QBtm/5dPPf/LMXOjwqEaUz4IV9qxK445KqIzqbunOEExV15zTygBpANf2Lke0cKUTDJD62ZQspCc4rLBoXNX8FKYSdx19mKiKLdyQF97si9GOu88LvSI66waACBYunTRvJLwpSFnbYGxrKKGmDVDqEHw21Bm2qBbyrqmfQEle7ZCNcOL7XEVYWoctFyPbx5S2KAIrTUzajW5aop5huQLWqomuqu4eQtHFKsZtqSAA5apGRaORhBl10E/CLqz/hQ8ffVOY5Wfz2mCBLF6ZCpQ6uqzyGAYcIQIdFjqfMKSFLqNoQGVaF+96q4tKDxdw5tRbL5YM/o2L8Yex/SvKtyiz09C1c6qxmDrS1dzqdq523sTqV5tfa5mtLdws3MW4pTrD3C6Q3aOyfXsy0CRjB9uAejOZ8T3PDjh9K8mvus+5ddezMwj3B+rp9h3DFOkD8JdKlxjYUyH0qGLdppTGPGMJXMTIKaOkcjtDjjbQX0XcD/9O5juvkG3+1G6y0TMZ0QDrjsQv/MrAMSbW/rdQI4zh67n8XQmiljmX6rOVDcRVHiYhBx7VC/M44l9vyftA2MdSw+f3HNK8NWvvOCz+ws75zl/+M85P54ZK6H6sNvzlffe519992NQCM4RSzFkSmQJbaqs86o1e0+rG5RFauKNcqjop1dI3i0cOjLce+Wd4pDimShcnR+J8wVf1DIggfmSkGiF32c34d1Ap8ZfbM4iWJZeybqS2lWtUgChzjsoOXGZpoqEFJxUPTPXMYRa07FzFSJ2TOIMFXO2GqU5U2Ijs4OTKmMgYgr6omRXiA31TMI8XdCY0FLouo4QwpLQtNugVDbeqxyM9BB6kLjw0PCttqfQnDURAW9VQ3ITWKvtTcGhrRP75PPNsqMbW3AFyGUxWoCldqTdg0dcQHxnjlhJuDxzScp+nNGajczljvz4Gn14JMSCZCHHmah5qYAicyJPI3IZ8IOFREvxxCB4DUgT6eh7tO+hG3DS4ymMTpB5Rmut3DhONSMzIcWSy9S1/WEdqx5blxo14wmfsvGhKqKrKgZqrNTXRUtPWaMI86vI/iEhFEoxrTCd10mjCskJoVrO63ut76dy42KauVwuSOcRp5wuj8xzXpKNnCuUklZjvEZclEyL1aoUM+CdQDXa1KlRUNSSI8D4YzHOXMbZ9tsAPtSKznVojZJJRdjvdrx3e40POx4mz3f7a2Tnuat8VZn3qGb66vR57403OnSUwfb+sVgYxzmrLPL9tO+PifZl+7J92b5sX7Yv25fty/Zl+9+9/UAjZv/hx5m/+DXY9UqqhFbtBa+ZU4ErCcyaOI+OXcjsawHZg+/Am+RnuPKmEJyFg7PsiWmBPEeyFNL8QHADx1KQWTjvzszHzwH4bPqcw82e69eRccwMscNdApPzMAT8ZFk4uYwUnZFK9K7lxiw8tgmBOdaQprARiIUV9dKNRb0Jf25l+FsNv4X3sgnJGAy89uOWr1beYNC3EKiTZ7IZUBVb7KIRC/+G4nHiKASrjameVnhckoIrCxlZyZQSkVyWIrmDgy4IafD4qxuD0N2Mlj2kM6G6UXoH0yjMTulqyE2CEYsXEGtBDKUSX5WSzWvsWGUzHIYMeq3Co84csRYybs/bEMLWZ+2YJRzZXkFFw4q+2UNdjq3IpKkp2Ehoni3t2tVjlVRDyOrofVl4O+KBF2/R3fx7pKsO748cwsAUwOkOV7mVabiiL559rxT1hO+8R7h7xfu7M//sR2xE/doPR376lzwxZ3Y4ZlcIBUQ9BxItWUwEkmb2wKRClMhRICSTkwDo9p7una8xnB3l9S8xhUD47Nu4T7+Je4jEqm6+K6bqHvqO6XJBRHgcLxz2PaXGfg5Dz8Pj5+x7mIKCP5DnmbvxgZ27scw34HNNHK4OjOczaZwZ+j3fenggO1nCKXeffMzh6p
oX+4HPL5MVzK4cROPrVfSwhqmd2ELZaUUKpFA570RMCHl0AEKnygXoSnkiVPudEvgxMs5HFCE8JPzr15zcwPh4X/v0jDgh+YLTDoqSe4+fQWbr01m81UZMkUedUUyDKhLx0uPqC4p+Ykwj5zTzMg9ohk6FJAntbNSH4ohS4DjjXge6aNRkX2BkIj4aajMcXlDcp+Su41HhSgemeeYiZVVADzOn88h9HNlRyGHHrivsCsyFRZ39WqFP0AfQnAiXkXlfcMWTa4jVi6+IuoUTxQVwfV0HCqGzMTF1nhwElzKKWmF6WhSi3hbr/F3mn9qapYCrg6KI1UPNG96rYCh6qAtvdI7gemSMhF2mdFAk0emIZg/nWkvyNCHnGafOwsu5oDHhvaxInibIhXIpuJgQZziikfUFkVrnYzjidnvoE8WdLRyZOkqaiKPV3c1pJpNQgURBpCNqNDETtwqTS2nPa2irB+P1OigbvRdxzjIygeyNQldgQZzA3uHswT1E8scdvHMizifk0tVnqQEpL2inpMdK3Shq0iwbTCinhHSOqUyE1DE/3nN6vMOLo2tVBEIG2SNksjpUHB4bJ1FqSFcKQrGwpRsIXvFeiHHmPJ6INfQbcyFPI5esjNlznDO9CnpzJBXr91R6XLkhHN6G/oZu8HTuxOB7NO+4r4Mq3V24OewIWkAHohTTQ/M7TlWRoXcnxhA5HHumy28aQ1/G7Q9s+8pF+OkRvnFU3q2y63NnRbMHZ/TuThwyJ8a7E6dKdOgOB479juPQc106C+M5R+x6cknkqj2mkgjFFhnySAmQveNOL7xOrwF41DMxCNO1Qz9JHPJsPCffEb1HKgFQXAd0lJwpsuZmbJJdFvK/feGLVIjy7N9mJCyG1haa3vxcakxUWflhjfS+6Pm0sOj2e8vJWIkJG74N2OYly8/gixKl2OahFpQLFkQEDJKGhGpAiwN15OwIZsbZ+xHFeQ99B85bWReFkmY0zpRzJf+nYuTNUp+n1DIguoYXtwKLTZDSbfrqSfLC5udcVu5d64dFO05WI60Z0V+YbvpF3sv3am+IdK5/07X727GlVFZSvfmu7/DdgPaPuP27aC8UCl6ucE4otdxKDgcu+wnxVzAJx7evOFx5dpcPGaoB9Mu/HX7sVzPHDCeUicrbpFhB7cYVqf+pWBhkSXQo65OU45E4j2hypgX46hX6rW+i6cz580ek7tYBoOhSrmS325FKZoozNZpGSILvBqZpouAoKTHXMhk5Zx5reZqYM3sXGB/PZFE+vaSFwNxaKsrpcmFO89KnlC+GLRuhfyvG3CLzy2bvwJVKvsdKgUl5mqyjCrFl59bwqCtKmiOzkzXZyHkGZw5j1FwTDTKlZMoiIhXsAj4RtJjR42JNCMkLt9I0YSOpJIqYwr7JCLhVYNabwrs4xe0CeRIcSnYQslUPANCUiUnxksl54uzVkgq8p0mP5aJMKXJOI70zXTTvPVTB7BYCD8UouK7kSthXiAWRYGKy2POo9pXe4FEPdD2lPyD7a7Sm6naaIH5OmRL7Ok8W46z1PU+WLvvXKb5SH9q4WIp668pdc6pGmO8qHcXVBARxaCm2SKSMlEyeZspoFAt/NhV+n4qFg1Ox8khTWc9NjcCpoG6yaiciCNHK+amdS7vJDLNhD86j4kjqSXEiTsZVulwuFsrMeZWRUBvQrvZ564uWfYpdehmfayWL6iDW7MaSV5rI1tld9o2ipNcJd8w4SUgI6FD71GGZs4ssSMvo1+XiWudvTIlSThSfmXkklVglL+o1Sy2jmD24jBM1GSE8qLlAsSgxT2RxNSSraJUyyjkv60ucTez1Eiv3z2GF3zPMdU0qquyHPcf9gaurK4YgFDlyOBzY9WtprfMEuy4xhgvlonh1uL4DDynXZEM6eg+4jhz+HSD/71T4v37i+L+/m3i/lSxxpsEi2KLjxEjFc1bOlZwZQqAPMISBQTzFC9E5plJIWcl1tXG+ozBTUgeYJtBcCvMMB2
eT4n3/GQ/561zd3HDeJy5jQhK4ucepJ9WYchn2aImUrLbQ5vJGQ+u5nMITQ2ljTLXvLYvQ8519wxHb/mnx2NrvujHunt3Pk7ZZ2RpC1D7eKi/PgKfgRfFVDDOVsmwYWielqlk+ri4kRQvBmUfTOcV3HbLfQ+iqOnUkp4zOmfJgAz5mQzRE7bqhbqRbg7YtxkVXoUVPI3E/bYsEiK5I1RcQM1lLOql7c98v2U6b359v+ttj27m35bNWfSWWbM5G3m0vyzWZhf0evx9QHdEyI3nAi0O1kKczejFhMRd6rvMtF33EuQu7qwMvv/oW0+tr3jl/A4BvvLzn0YNPjkIhNkDEFzOSaielUsWEdSXAo0IpK9k4XN8y5MwcCz7OlA+/Qfrmr8FuYDxd8L098eBbiZ2Md555nnk4nxiGQBs4xQtpLpwuIwVhjpnihPNlJjPyWI2IPgTO40zMEYWlIJrHLehOEuU8T0t5NHTNvJTNu5SqTL/lJTmpSS71mFnsHEs1jCpeOwkLqqYKl5KZlrxDJcygpZjcQzOmimVTemq6f3G4bAp3TUC7bVZF8ibrV/FBgIw0tX/fkdJMLFUWxBtBUVyg1WGVRqLvhe7Qo5cA80zpYK8wN6mC4hB3oFcTUjZekOLFGXcISFqYpgvnOHEzDJZEEwLqrW/2dcE5CuwcuGxCoAWHmxUvnrmulb6J4HoHPiDS21zaT+jxbaglv3w/01Mo+VN2l0LO9ppmnq6pDc3eJq40Z6qtpW6zGGh1Ko1+J4v2mOv3qA91/TNZCT+DywmdZnSqhuz5AlNEYkRipsSEJtsX2kIgyyJR1z9YOFMpFnLLyvQD0h+RfgdhQMQzOaXExDxXnvNoCvmllGrkl5oQJk8c/5atuhihBZ6XGhEFyVRB842ThSzrYBZdk9EU9FOlf5HpmhRtt/a9w9YhxZxnr2asreuhVd1IqVDCRBoTE2eSWoLAsj4mXSpDNEfJOWepKE2MN2fGkpg1WqKKOlwpNg9yWqpiTPOFONe1rdgSE8Tm2rSIL3ccuj1Dt2PXD3QButxxHI4ch91iOH12xpIEip2z0yNOhIMqpfIvB+/Q3oSPUvx3oFbmHYXf/V3l9ifgcrTh12MkfCnQBVtcc7SsklSLCs/nR+Yy8OJGmXcDwXlKyZzGmdeXC9HXUFkXUTfTHW1x9hGmGciej6sb8vb+I77q7njneIN72XN6nOhECSIE15FrUdTYzbjcQ0iUPIOLVTBx3cAXY0DWzbzNna2htRzTvJb68daoayHMevg6wDfnhtWAeILcbT0pthsvi9f//O+tOUwjLGgxDS8nIPOKUEkwr7/qESnZ4HuFvqX6e4E+4Pd7UjDkU7IV5WUulFPNBMtreBZXQ4blizaqGWZAtoXH6Uqgf06ybBv18zJWUHXMWn8JSzp5u8YTm/qZgbiQib+HgdayMrf3/OT+Nz97aibmUEM/+wHpepgEHS9I3kO+oPlTyuNpIeO6XSClA4onzuBfF65/5Irp3zvylVe2mv5v83f47N1Pefktu1CX4c7DC4ULQleh1gKmecWqB6VSw1b1GePrB3YPnxPOHrk8cvrwV9nLRAkHckqcZnNuDi/eQZwjTjMiyowwx4hKZrc33ad5mkgpE/odod9xms6Maeai4EtaUKDedVzIJLHaebNkM3TL6n1HEcZiiSBO5IngrM236uTx9P2p2T4LUghmoJo2kxlrzRjLPK3EcQEe0LrZqKEY0QbjgqJIJSCLWo3X4vD1fltijFNBxYQ2vRZCb/IWITiyFtyyKRbmnIhNCNd7cE3DrJmHdp9D5xj2HTeHax7TZ5RSw/nVyEupsGOPc4WhSjUUzYamrJAGeZo5jRdKcHQqdF1HDkIoyq4e1mPzLmAioIriUrGC3mLjWaWHfkDDDtf1No8cSLlGb17icqWtxBMaHJ33DHefUh4j0ixm1vfQ1qz2LpcsdF3XzAUBrsd4MbRs8A4fbA13vkODX4RwLbM8QUq4ebRMS4
A5wpSQYiHMPNlar65qkIHpe1Uegwq1bFypRHhdkOIiM5JmZD5AsCSBRK4aXna3OaVFPqJJRjjEkkSw4208i6kC1MXqezqV2P2UJ2N5FcjeVoRxBfTR4S8ZJ6Ppe+3aOo7JqZS6ltb11rLmm6Mhlt2bkom+ZmViZk4wJ6VWZKMPIEUQlxDpkeKJJJSIVDeplMCcHeN8YciBOYMXC20mzYwVwTrPF6ZoCLaXZqx6vPaUzsbg0O253l+x7wdLZPFKH3qOuwM3uxv21Vif8oXPH6HLmatdsjqv2Dys/iauq46nCmN4vju9uf1AG2avgd9yFn7i4vm1vS2DvZoCceeFgwraeabQUea0TIrHCzxMI5+lif58psOxcwEQUlZmZ+cKnTLsAlOeyOpwCr0UnAbuLvaSv5Ue+C3uIz6/eofblz35w0IUAZ/RYUCSSRV4iWgYKCGiKVTujL2k7ebewomNN+bYIFKyHtc2iidq1pt3vg1LbodCm1Tu2bFbw6x9YdE4xK5TtA7k5waErscVsc1qiplCIkh4EjIoKF5bEW6tN2/IZisx5L1Hh73JPLhgyEJxuARlzJSpfOH+mgH73AAqYAtURXJEzIDL+hTpY9Pn2z7YGkXb+qTbECc8Qzq3p9GN8bi5P6phve3LN7VcF9GG8BWBDqFzgq8hFhd8NT4EiRFX7kmnC9PjK0IsBMy4uVyE4C7IsEdLT/wo4t97jf/JgR9++30APnzd8as/Ebn91iuywM51vJLIexkuG9hBeGp4WraredXN095fJvJ4z/zqzP4yceUz0UXuXz8Qdnv2dSN+PJ3xvqN3jpwjl5RMLPMyLZuPFjgcDtzd3RHTDJ3ncYxMwG5OzBUN853jLs4cDz13Z0vX967aQMu7KcT6jrJWHbI3vDtYOZnQMjB1Ccm0ZiixceuslgdPZFaajMwZQXGEmodt4ZWIqyMpiMN7C406VXywUJJsy35ghlwIPd45kpwIM1Vdol9kaqTYhjdXl8HCgpZR6Cq3L7gO1N7XWy9f8kP7I98Ins9PnzCrMtRrzvPMIR7JPuPxUApBhVEduW6Kog6dC+dphuFAp8rQ7aogbabbwPF98Dg18pIGRVLA7Q+43Qvrr/4lZTfgwtFExJSqjJ/w+RapoauS90i3x4WOPoBzrxgeZs6z8tjeTV0f02IUuIomFVrGMzxDSlE6rVzXrrOi5ACuA9cjYlwx0WzZ0rmQzydKRaZlmijR1MTLnKy+pwKSyBXd0WwLgxaxkj3FshqjmsPXsh4zCmVEU0ZcwLlQnWuTtQBI2UoyFTXpFa0OsVWREXwd+U2YpYXnW+1gLet4XtQCKt8j18LwecMlXjjHavOmzBDOmY4JXzx+rKH5AKHXBapbspyf7T0527vNWU2mQw0xTEk3UZ9gFAY1HbQian1GpvGXiyamPBKTI6ceDZAkkTRSiMxVPHbUmUs0QeVewEuP0JM1WJYwsB+OXO2v6fsdLgjeCSEE9sOBq8Mtt3sbq6GLnB8TDxlIhb0frbbxMLCrNZ2DGygukrUsUjW/WfuBNsyMLqn88L3yqzf1s2CweKAgOdLtHWEI9J0iVZcnhshUlDkr8/1IEBiDgDiK84sQ3VXqGFxAZCCVkTGbYOMuTEx1wHw6wrfCd7jt3ufq+D7cdoznjE7F9Fmad9X1lFRQ11P8DvWZ4qZWEg3YTA6eLuqtnMUW9doiaE/W7Xau556QPP13q+y/hbr1Dd9pmmZeniUetOPqv61Mk6hJQEox1W7VsEweqQEmFWewudom4l1a0uVl6NH9Ee16C7+o4NShcyKf5mXR0nY9KqdGtxtsRQVKK9Ox8ntgRSCe90HryyCrIbYNZZpK9tPySlqhlSd9p2/++XmTzSK33vUXWwsfG8iqOO+oIBFFivV1P1OmQi73pNMD6XHGJY/W7IPZ7dHyyOX+Nem4Z3pIpF8v9IdCvrJjfmLX82s/9i76C6+Zp0LnsnnFQE9Z+sLEWHVZcAGrcSksxcL34r
icPqfPM3r/wOX1HbvO8c71NXfjvNQN9H1HyXCZRnsWbzyo0+NEq4PnfDDRyThxOj1wShcugHjQXPB1EbyfZrw3zlPyhqqGArNYJQeolRtqX85YOLzbwmObzVrYhL0w1G3Wp+80i435hpJ5THfuSVJPgZMq3gn9qh9KQJbC1kVWDpgURyZRnOK8rDUQnY35rhp5vXZ0iCEKoVucm5hqhYI0m2EnlXPkvCnYYw6Qc0JKid3bO37rV34LV7dHfuVb/5yPX3934QuOccSnYx3oVlDbVYyp5KbxZYTqFDM5R5wWOmfyAx1WTYPaz93xgCQh3SnFO3S4Qq/ewl+9tIN2N+iuh+4IwRuHMRSQBPoC1BAsr4XSfQ7e9KNC2KHdJ7j7M1xshp6p4exSjYBnvILnTqYKBLUi9L14wtBDqDBkCKtDWTKSPKoZjROMZ6jaVhpnM9aKFV3XnGp0JK0OcYaSlZRNlmKuqFLjyD2hlSiWIEWqFrhU59DmT6o8qlb3oRUJX5zSVlHh2Xh27aFlXXdcM8zUjKKsrTzVBtlt876uj37OyKPQudlqsVeUq5vBu61gh507a6ZpinrnSMXCxrNmtGQzTq2LkbbbaCCVbCUOJaM+k4koFs5s95XLiZQ6K/FUHCJVjljWsoCxJKakSLYwZice8QMzHa4aTjfDFfthIATTAMSBc4He7dgPR45HMzh23ac8ZDjjCFPBT4k+ZJJTcnWA/FBwaJ3fv4knXtsPtGHWsuqGMVMrtyA5oXSoSwxvHeh3HUgmzomWUjZPFke3QWz/JVHGnI3UWTvvVnZc7W7o+xcoM5/ND3z37jPmcVr0wDyeb46f8UOPH3G1v+LmdiC/VuQs9NEvuiV5GIxbFjo095ScUV+gxLWIOSsytQ2DPEfMYIWc2/eev+4nhsPyPdl4IPqFY5+HRLc1MsvWQ3p2re1+1oyzrNX1y5kg665nE1/RYmKWvvINgiT6KrEtux3u6hr1YSHzaxHKPKLny9L3RvJ3OLFV9wnEXidhy9L6nm1j3DbexJNFhI1hprY5t3qi29eiz67zJornFlXbZrc+f73Pqz602zTjwN6nc1bsGmrISTMlT2g6G9dlnvDZkWJZQpmx95T9Fef7R9IAzIp+W4k58/H+DoDH24J76yWPX7ui/9f3FCn0OB4o7DbOwHIvWH/n0jY2Weo6Du++TyhQXn1KeTgRHy9cmOlTz4xnapt6qN6yRoIPlvE3T3jvuT9buPPtt99mTolSLHyhWhi8cJ8tQWHhOonNbVczSL3AZbPBgA3LUrMwS/0OzaiXLxrJ7f2mYsrgcTN7HLbxdzVbJ6st9qHIxrmy4y8FulKz4arO17LpAzElM0qLkYhVlRw6c0ArOhJ9wUtGMD7MkZ5d8HRe0WFgqAMvRfBdR5utIoJ3juB0CdUGZ7SBlJVHnRkOe368/2GiPiBEPr97BcA8PpKnK/w+2PulgPP4IsypJSUULmXmPEemNJGl4IMSPBwcHGpndsEzXF8bj+lhAneF7l6Qr28IewsPlcEj/Z7Sm2HmFFuXg40vakRDSsF5E3nVoUePe9zxwDB8gt7buOnOJ85ROS/G9FpjcYuwt/EsrNU/+s4TdnvYmyGoBsvZd4uSNUKK6Dzip9lCmGAlmXJCU7Si69nqQUqt0ACQU0PGlGnJzrd/t6Lj2/EnKCbU2Crr2lG5InjNEVhR7YqUNcMfnnBWt+v2ghxuHPesz+5F1n5rX/b1XPGk7PaFYR7pxxr5MFoVrqOWcWqaklstMwcV5Uu1JmzMoBji52rtqKSFTsS05DyoJJIUlExfJ2zwiqvnWs5f+8JLWCpxdHiyKwxi5dmCC+A7sgR2FR19sTswDAN93yPBLNfge4bhwPFwy8trE7R+sf8Wr3wkeyGK5dWXAmNSXKVE+WCOkUhGn+zs37v9QBtmtt07Ys48NpLwnEg7kA5OXSZLZudh2IEOdowGsEQ1x5wtno5CECHhiLXuJhp4cf0Wx/CCXbfjpp
/g6jWfffxdxpMtWucy8apTPot3HPM9O72huIlxOACF0uCd6MjiKK6juA71sxFJc6E0eLtOlrbxN3Sq6FPErPEl2ka/NSae87+2P2+L87ZpsaAdm3O01ibkc4Nsi+rkzUaWoYpt2vrlBFLJrAVlQYpxzCwjKYEoIpkuQN/bJuV3e3S3s0VYM5TKqYgX0nh+Up2gLRJtg2wlhLbJDdvwYasRt/CinvVRrn3SZCwy67lyey/1d795L7+ZUdYMt+/FMWttXYSffq99thjNG9iyaMLnQpmVzs9GOJ5AcqTksGQfTjIxjDM7t+fTzzPOZ2Ka0Q979jvr+/Mvvyb9aOLbbzl+66/WUFARXgPvCVWiGR5r/2wf5wv25ItrdqXj/N3PIFtG2iXOPNx/iuuumKoh7vsdDk8uhXE8cZ4z8zxT5nnJAn0cL1bFQYqpaydDHBwwerjdbBRghphD6MQZ56waaWD+mXjbJDMmqlkyDM/ezXbjaihGUhOsbEkJUgqzwtEZNyxWKZaALPwurZNnLqAl4wu1sLjUuWMXDiEQCIQkSI6EcE0OAknR5npKNkkaH0jesfOeAzscEe07+qp7oF6QobcQH1LDn67yumwSlJIsy9Qpr+MDD9OFr+6OvPv2WzA+cMkmxzCfItN8wU87Umd1D3AejQWp2bxJMgnhdDkzzh7tHaEThk7YubUQuJRMf3tg0B4+vVB2L5Crl6TjNRwMV9Od2K4egnHMmuVfM0tXZeUZ9NbWkJ0jHQOlH3B49t1HAAw5IjKbIG/aroX1fjYTraGRVjNXGEJniFk1GOl6kq9YYbGSVCVG8jQS5rQIvlqJpmLGZDHx1ZxAPaQms5BqUif2X0tYmOvv2zVE2IbhVw7kmj7mNyt7RbcMoyHj8JsQ5NahXx5dn62DmwVo4ebpunY2A9Y5M6C6orhJ6COEnJbIlMyC9or05tCIt/nqyiob0rK6c5XQaCXFRCyc37hoPgS6wSMuWjJICaR8ITMjvjkaJkCs3oSAnQRz/iXQOU9fja7gPHSRA0DwSBbbrLrAUDlmh67Di8M31XFVnHP03Y7d7rAgZm8dj3x7d7JxKYZZzrmAzkhvb62bO3zowAlanpEgv0d7HpX6sn3Zvmxfti/bl+3L9mX7sv0f1H6gETODbzO/2EM4mY35+U7py0zBs5szijDliHer5+CDwYtxVvZOiQmkeCuZkAthb93y6nwh3FzxteF3wB6+6hI/2j/wr/p3+JUHkxeYzq95jHs+Zcd7ceSRK7q+cJAJH/cMxTCGU7jg+0KiQA6Qg4U9akmX588FzWpusPUGLaFq0TxzfZ6Q19l8xhvCnZvw5sojWDNAgUVo075v6MISLmRFi5oP0NKwPVh5Di2EEIilI4T2NK7yvhJCphPh0A30XaTsK95xc4UOHb6r8FuEME+WYZhl8dzEO1Iuy710NR5h5FS7d618iCbXUYSFgGd8B/s8U7Mu67O07MttkfGGjLl6vSWzCfO2ddM37fhUda209sv2Xa/h8LUtYYPGt0PIxURMlYpeeiMFdzUUqGNG5UzXd5SYoGScs/Tu7IVUHz4n5eIKU3mgoKTsQIVwHjnUm5H+wOf/+jM+esvz4wPki9C5zAUDAhbKgErVI1pDwFFsnDSQWB4e0PFCuNozvfqUWc94BlL/gjLPlCpulctI1g51O2LODEMgz4XXp9cMnYWRTudH3vttP8706o6HX3ugE89eMxdv/XupyE0QR5FSWQtKJK9isJuwVYy6IBMur2Ek0XVR7KQKDrNmk+X6Di4t2UDsb6nCh0lq5uE2ZCZ27t457qTwPvD6JqPHa9w4W8F5ILmAlsQudAyjkkl0s0dcIjY0OZneWXYdRy/csCcpHIcjNxqYWpKNJHKI9DKgOILfUbSz2oCt9qMNJLwqu3Lio/yaH+/eIgflMLzF2/vPAHi8/4zL5YHdrec8T/TSk8bCSFjDZXHm8fzI/PqRH7sZOBLwuSMOiUO3Iq3XMtDv3yHfHo
iffEifD8T9nrA/gL+2fggdIeztPkNHxoFGnAjiHLozEjdzoRwd+B45XxFKj9zsUCzpC6AfYX+6J19Gap3sZT0ddZ17WidfK5HmOyUcduT+Ghns7mXYIV2PdF3N6BghnujGCyWdcDWFUBLkltKhDpIiDlzKlLKuS23NaYlXqa6lX4iA6Lo2b5GUBQmryVONuF+0FqrXXFGzxu/djGUMnetZk6ds3NgfW6Z7oygs90KNFFARNGd7xzAqPhaOxfG6RpzcQdn30HmTNupdW2cLuyobUaJSvEeSUhjXRDQF6cLCh+wHR3AdLvSodyCRIRnEqe0ldo5dEHZ9xEmm80e860k6EpmXcw1hz/WQ8UHptCe6nqSeLgf2FTFzO0dykGNBJZN7R/amNXf0O657G6s3777LVx4eeLgf2Q97uo4qAyOMY5Xw6T2dmBD0s2pi37P9QBtmR+A+wEd71srxg3J0nk4KQ+jwXnFS6J1fuVqlIOrwQZlTYr/fI3OGmOnFUUW2uUwzH7+643f/5Ev23cBjvrC76vkteHZqnf6hF1IO3I0nHvrIjRS6rucUFO2UoQrkDHScw5mShCIBlR34TCknEmth1pmV12STRJe/NaOhbTRLewbPP/vY/ibrByvnbF0gtsetHKfNGRajcGPT6dMQqKNpOa0LQKnp26vga5NYMIX1gElrDMNAd7TsQTkcyP1Q1fft2BQn8jiTSl7uV6tB2eywJ9o6rP22hAHUDK4WnlyyX2vffSG8WQ238ry/eRq6fNZTT1rQNTQ9b/rWY4ToTuWN317rb+oT2YY1VK2LMKlZDyYmqimjKUNy5Ay5+FUTKcMcL8zeVSK0UCLkZAWuwYiw715f82195HHf080zWtZQVFtXJvTJJvFcqgUgziN5vuBKJlRjQAUjax9MpsY+DCY9IgURx+Vy4TydcF3gUs3+vgjnDz8lns5IH3icLwhwpcK06T+lGDdGWCRLWihyTVJYw0WNR9jGwravoTl/65jxz/6+jptVNsRCPWUdd0aPIkjBO7P25HpH33dIgdD4YyT6AikHhq5wyRnvHII3w4/mDGWK9/huwEvHrhu4HTw36hmrJBBd5NxbIfFYIilFKzCeC27Jfrbi2Jb5m5nyIw/zyK70pF441vk47O6ZU0SSR1SZ84UUlbEIserHpaKcJ4+MoxVjDw4nQueFXVD2tTNCKngEf7yGlxfceY/rDkh/hLYp+gH1HudNc02KcbRKUlxm0V8jWMY2+wPiAr4vMI/QKa5rzsiIfJbZy2SyJtNKy/CsOnft3e6xfeXgdwz9gbDbU4Z6vT6gwaE4SkkmHBsjxNlC9fUdaa4x02LcT83FwnWVU9bGTTNAGq+sGWWtDjKsDrDfhBvdG5eMFmAExHQh3bIG1rlXE1R0PXJZL7echOpzLnJCnpbg0q5fnbiFoqJ4hZIKoThcHYNezbnpKs8seIzbnfzqleZ1LVO/GoLiBIcs9INQk1bs3ddksMalq4uq88VAF9/Rh0Dng+VKaE+fe/paI7bvd+xzhARBAy53aPE4txqCrlafsCQKtbCqCF03sNvtuDrY3Hjr+iWvro7oXBj6HfvBmUhzLEut2SgRRMiURUvtN2s/0IaZB/7lrfDhUXm/xpmv9ntuhj3uKnP0Pd5nHI7BB6Ru6WlOJi0gQDcQo3ECghcKgtaJf06Jf/mvf53f9WPf4idvfgfvXK74NHW8kkwfXgDwnjsTJTLtPHd95m0VDvOO8/2I2wuhFmH2E6jbU1wyrpkvaCcUdYtWS1ssGirzPDOnLfRR1kxN+KKRAE8RHK3HPCnDtPnOc6Jp+z08+xy+mBhQNscrlVdWN8NI1eoRXbSilFQRJ8uS66QwiMMFj+5t4uShB+8oVcsJUWQa4RIXTobdw2p8LTwsYSk+3u67LT7bDbhtoNs+bhu16tp/jWf2pvab8cXazW0AvDf+rZ1Lnr0vqCjb5lm219Uq/6ICxZlQZ6lFgUuBmCzjqW3quYjJUGSINP
FGB1nwdcwXdbzsd9wNPY9v91zffYaqcW7OGHIHlW+3GYPbvaLdu8uKmyPT6QGZI3GKdFcHcueY80yMDVHqEO+Y08w8z8x5IjNyjhekkuO7XCgPZyRn+t0ByRE0EdSI3e09NrSubManB5JbF/IouiCibSwkWEqltX6Our63Lbq6fZVtzLTr1h5drguQRLjCqgI4ZxnB5Z0dfXDIEAilZoszM2Shc4G5K0T1RDyedZPtgp2963qGYSDowM3xirf6wJXCVHlOxQec9wze5paIab05Fwi1fF0IAScm/uG0MOorTiSuwzWX4XN2O0Mrj8c9dw+Ry2WC0BFjIsZIKp7SCki3cT6OpBjJpcM52HWO65CpQQj2UfEa4HCLfxkppcd3VxTfLbp8+B0aBstAVEVSRFvR7pwXxK84T3G98eakQ0JBh50p9DeJCzVkqcuRrpxwmtEqq7hl+8z1fe2AK+fY769whxs47NDKhaTzxgsW8KlAirgYYZ7QZIkAdmPGLTMIy7TaSk00aWNoyy2LAtPGYdxmZTYjSev9PeEaL3ffXAo2n5hGxXYfyZtzqat7iD5NgliMMJ4ahH5jvInUMk8KgsnHaIX8+pgJVZU6JBMT3gfHHAq+CBIx2Yu6iOdSDbMqTrvl0zrnTH8P6Hzl/LnBHJOSQApOWDX+nOJ8ofMd3plx5r0nFujDwBCq4Gs3cygRlUJgYIw9Ev2T67VezUrViMPE6p2j7wcOB5PBevf2LT6/vSaPM84P7Hae4AtlSlT9X0Yic3KkErlM/w4YZgL8P2+VrhfeqkS7427PcX9Ft6+LXDCkbBd6uvoCS7AyKPMlcX+JpGzWOQ4rbdE2FjyvP3/kF/7FP2T307f8xO1PcOOu+Xw8MTsj/43+NWd3z7Hfc3KJB5d48c4N+kqJqcDBuvh0Hw3edh3FtUQAyBKsBhyGLq16X/W/DZTd1v5SH/4JDN/6RJ5+2AyyxThr53iDkdZCkW2Kb/NH2gIBLMrZgAlm1osJqwco9R5ErZ5eq8Um4m3TUvO8egeDB9c5tGU/7QY7V7EJ6ONMPj9S5mjp5C2DWlnqvTW17uewoW7uZUE+6ka69C8r8rEgK+1nni6S2/7YGiYtzfx5a0tm2L5A1gUyYqniGTNq35RNvfQ7LGWCzMu0N+SKh2hJEuSMqoU/cynMOZOqdxpLIRW/eOcFRxATHm1ion5W3Osz7wo8vHPkxYefMT+aIOiDyKJttfSv2o01Y/hp6DcjeUZKZJ4ulq00nnGHF6SYl/B213XM84hDGYY9l9PI43gxoz7aZjdKYRj2tnDGRCeBSCa27NuNk9I2o6ZVZ2HKsrkvniDQy3uuKFvazIf2bwtjP0elG0pbalZmM7aDwLnpTAEHFO+hUOid4/MPDgQvSPZLKHMokHolRbFsxm7PY0qWkVyzH6WryFwdAyKOfbfjpgvcOL9oMMZopas67xFvIWvvPd77Bb30LhiCLYBmJn3gtc58vbvGb8JIx92ex/MDj/HE4G+YMpQSIK/ZorkkfJ6QkknFVOh9gF0QDgLHOv93TghJIOzIVwfyvSP4AfV91c7CtLrspIY+TRNczjBdKJqXxBc3dOAD2oG4ZGtMMeK7SM3w1p48TrgxQ37FXs+oXtD4FB0VVsNst+9xN1fozRXp0C1hUdd5xOoMoVUKw6UJ5pEcI9JqSS5irxlfTSRY6RFtzCV5Gh1J2kKRNVO4jj1fWDQi35SxbWParCytnkXDz3JbqFifN2NyLiuitrZmjDWHsIVPn2eRb53SjrruR0PMailTugJ777gaHHPNqJynQhS3OGWaAS0UyUvWKMqyD7axGpyV/3JaatZ3sem2WY9ErJqEiCXiSC3b5CvCJnWfDeLY+YHsEyodZE8pBZF1duds2eopJUoQXPFoRQiD8+x7AxGO/YEXV3se7u3ZKI5+HxAflnV7mkbSXMyZufw7kJU5efila3jf7dlVaeldf6DrOvbiOHTQ9xDw9M7RbZ82Fy46o9
nhRbgbJy4xGlKS2qDocKr883/9TTr9x3z8o5/zQ++85GZwdFcGZY5Tx9UZ3G1HlCpwd63sPjgyn+8XBed49KTHMwlFfW8aRSjZlUXrLKd5gbVbdo4CyFPUpsX4mzp9PcT+3aAurS3I2WbTfFLeaYuCycbA2GxQ2/O1c7jtyXm62bVQZlLFaURyzdQRtZRnzKsTZ3w/H4Qy1BcUOryoFbdNVtqEcSQWrYtYhdLLWpLJsDgrh6NsNtVmEMnqjbYU8FXit26wm8Wm8eme2FObcyo8Eaj9gsCsru8OqkzJG6wulaele2hI1LNj3OZd2P3qEixzGHqWYjbuWBWqnGNmLlBrvjMVyBQiSnaWyRVL1daplrM4T8gwvC58/jKT37+C0yPn+sBz9Vq2Idhm0LbKFc1ISlogz8T5TOc8lERJCRdj7UNbpMbpRIyR0HWknAgh0HUD5ynS1wdOCOc4se8HpjgbGhc8jmzyVu1e3IoCNAkCwaQuttDjalBt5oI+RYybQ9TJOra1cgZba7Vnk5iB3cppZcwYbNe7KcAAZYT9wRE/uDH00fmlSLZ6e6+XnHAY4pJSouRogqRA8I7sClSxWy0m2bOTHfv9wLHW4pvSQI5K3/dmzKZ14/JuRcx85c2ICClf+Gx+zY/t36LvdnSdZSMOfU9wnpgyec7EuVg/SFgFgBU66QgSajWEjO8ch84Ms1p9i73v8OpRFdR5EIcTEw9tm6wvINH0v0gZmUbkMiKzqc6t3e+QnUM6Z4iZCk4C4rqK44C4AR8nm1s+0PtPcDLDY8andUjs7PXwoofhuCPcXMP1Ebcb0N7QN6mGKrFAyWiaSXFE0gQl0UghWkVPRdfyZAsa1JaAVkycOrZkNcq2Y1Jg0WhcoiRba7KOU1v3bPSq6eksNYa3mfdOG69Nl2ttEbPWGgWg2/IBNm1rILYfXQSfdVnPQoZDNxCvAjOCemHqEheUWBelEo0HnoUFVbSlUp8YXWB+p+pIRknZBFuf3HMdS95b9rHRCaQ60Wtxei8Bz0wRK85XCqSUECkLPSTnTMrRkP3i6NUE6C3c2dP3BiIM/Z79fk/XecYxk5IZqN45Wp2v6CKXMTJPme+zItMPtmH2jRv46DrwstvR7Spi5gZC6HB9x7UThkFwZSIAvnqTOI/zgssdXgLlMfE4F5wqLitdhTO9dsyacWfll/7Vv+BbH/06hxdX7PY37I5mMQ/yyNUQmeaZbt8xpZEHOeHeuca/HrjUAtIhduSzJ1eTRbHCxhlPqaGFnDIZ00LZti9wn1hRILBfnntSb/Kutmr1C1K0nWANMXs+8Z+1Rn5vocFndtxyf2kBERS09r0G89ZUjMwrBe8KKlYbDUCqtILLIzJm/GViupyZqNesD9FQxUaqz6ybY1uN1tTy1Xho/zYyd2sNOdny5p4buN8rfPm9EDO7ti4GV2uheaZqPy/2sS7Z2fbd+h6b98py30pLOhc1lXdNhaS2QMdi8gxzWUnv9tzmweeSKOLweHuuphCDol2HT+DvL3x2gHd3cL5An1dhXV36eN1cCiwcQrDFzsrUFHJKhGD6PntVk7GoaNjQKynNpJwJ/Z7DbscUjygzuR7jgxDHC8PQIcGT80xWK30mutEos7dhn4kuhllhtYu3nnn7YWuMtSZUw6tAk+Jr0jXtuFTfY6rvcfuO21jZAXuBLA5SwX91z+XFDo9xyJq+kjoovfFtvBfmeGGcJyu9VIUvpSJf4nsGcaRsSukiHh969m5FIVNg4dqM2RtfC0erPB5CMIRBQVVwOfGQPmdyyrE/cu5NP/+q2zHsejjNVe5AuZQZRybXLSS7ASneBKWrkdEHzy54jhI51DV1cD2enpyUkt1ShkkUqyUJqJtwxVkiS3XMXEpoms1Yqy1qQejwfmfyHc5RfDX0cuWrdYKWF4jL4MH5xOAmJD8QzlClpugEdl3geNwTrl4g+z3Sd8Z7a1pz1blwmiBNaBwpcbRQa0lLzcaVV1aqBltFwz
YUi1LHXFuXGhrbHMa2LgmbcONmiG6dwoYQy8JH1cWwMd25dTzW6OqTJICtI7tck0oBYC16vo0OIE/8HCroiktajRjwOXMIO/TYMxW1Ul0+UtLE5VIRYG8ZBo0j3FrjneUKksRc8NGqZGSUnEotG7VdoVtSi9Rkr4JqzWDIDtEGDnhs1nmyemKK5KyEsBqDWQs5R0Oeh56iCaerueQrmNL3O4Zhz67rmTpb6+JUyB5SsSQVjTNlLpSZp/Hz36D9QBtm/9u7wrC75uh7usEMpSt/pPMFho59MRFK5xWvednIkIQLwbSTXKSblUPZU0bHOU6kGiOXQfEotwWiV9zpwun1iW9P30a9CRhe7aHfC4eXtzy+OOBJdL2pah9e7jnXgXWIE/L6QLyMiGSyK2SXiaKLYZYkIzovAqkBFvX5bWv8nsVTgaeoF2u4s3221bB5Yq/JuigsRbTrv+ENmww8NUB08z1lDX9q+1+NA3VtAmUTWhRM+NeJFWx2fj1pmWcLF8Qz5XFG706Mj5dFO6uJiRYsjLmUO2qIoK7GwVbPrIUOWp9u+/Z5H7+Jt7d8vhicT43SJ99pC1c1qkLtp/aYvRgpNmRqksMzI3Dzw5vsvcy6kolkq8mXG9JXicZ5JboDS63PdTMotWaoLNJQWgolCLv+wPHhgc/9zO1bnuHbmYwu64qFDNfEBGkGpOqCmE1TRHIiOIc6GGcLV8yXkXB1ZFfJuPM8UkoxYyKeiDNotuSQhTNWLIwR87w4LqU0wrU8MbADldfDimht58MTnuGGb8bmmHZcYA3vNNS1lbMBFi29raZgEStfs68nvXH2rrN6urmQfvya877HnUa60Le7M+fEebIGdmkgXEZDA4unc1XjS6sJXEAojGViSpFUMjNmRAEMbsANwr7bWTmmxThfk0YMRTC+UC6Frjgu+TPOjHR+R1c3n0PfcRh60mkiVFRqjqPN5Tr7sxjK1SuWFYjgHQyi7MWxC7XUTTdQtK+Qokfol1qm1FJrhckMpmICrTaYIxozmhOuPkwoyUJg3uGHHRIcqkJx3oorAoWA5JulLjEIThxd+QQNZ7qpFmr3QtjtcVcv4foW2e+hH2C3M2SP6nSoQrb6lSbkbGPc/muhTLVyTdUT1rrotmQiWOfgNku0/b19vgzHOrdaKbCWiLN1vLdcsoYQtXfcIgxGQdH1eqxrWGlZ4Lo6oE5WJ7yK3y8XaGvbGrER03ZL0Ndq4y4XOj8QDnvG4si90ncTaS48nOyLLiRKkUYlXvmpLqAqzNUQd3OqToSFieeUSLks8xcgaSJpWkj7Vr2gkJOYoO/Cd7CaQVmVKWYuMRGT0nVpmRtWizSSUiKlGZ+NX2bibQ5p3LehZz8c2fV7Yq/knCkqaEqkfAHgkpQxm3H+vfjKz9sPtGH2q2/3/JDc0FOIlaCpQ2AnE4MGdr0wdC30l1eYstYYck7wvRCGgFwMWtj73hYC4JxmGBwpBLxTE+rsA7fdQCqVxOeVdO75bnzNtx5ekw9nrt4WdnKg8wOXWiTO72EYIEkxb5CeTCaR14LIPpnicps8unpQiZV38IXMsObF88VQ5TJhN8YLrAtE+9MySfU5jbQe8+z45rGVjSu3Da025KehGSvqbLH8XnQpKC6lgBZyxXkzE0KkxDOMGZlnclzr3oWNOVSqUeMU46i0Ekybe2/gTlsE239fSIisq9SWQ7cl2/5GffI9wDLEwU5hr1aiqG98Il87zimT5tVQoL67ZzDkcn5dveM1U09IsRFpq0es9nsLjcM6jgqrgaiqZE1IXQrMg81IVK4pnPHc3xTe+bYwiS4GSCuj0rJ0t1ljrV+CCnGcKCnWEIMZWOfzI2HoCAdTz7b1rjeETSOaLWTgXaafbV6PeSJ4IcZIjoW9C5U1lmpJo7oR1R8bh2fbn1sUsjkpbRzYZlSlY+ozSt0oFyN8Y9i295FhRWmLhdJTDRO1MkTXallpqp4uR/LXjqS+I5xmOh8WczdXFNl7b4
WTw2RGe9Y1E1EVJ5U5UxIP04mH8ZFHPxB2HTeVsL/rO25cR+gCl2ylmLy3IubtXI0v01Wun0e4xAdOeeLtLfE6eFMv70yQOPsOpcMRlhC4d5ldS8vGQan8HlV24tn5mhEXBtQFRIWgHUV6SkV4pG6cOSaY47JeS1FcUlLK1fipa6pKRYU8qgEZBA3ODIpG4pYe/Az9Fd1RyNnVRVWstmYto1ScUnY78vURd3UN+wPse/JhtyB5RQoaC+RkVTbibGhInA0tW8pilAWacvUenT6jTjxzlNu8apVftghuHVpvLIfXxmnjVi4cLVou5sbhZf3+lkv7PDGpZWOKsxqRmbIi+qx7jdZXYeupotnmSle5XB2Bodux3x/oiycNivee8ZIJrUSFz1Vk1sj+ue69TZA81QzPGCOx65BsjsWcErEkqoas3Y9aJYBSlJIt4UKKjZuUzECz11MLvJfMnJR5jswZdqU8McyaUZbzbh2LdX6G6miEEEx0ttsxuhEpSimOqHHhhc5zRUuLVOPwOQzwxfYDbZiN1y+4Sgl84dpblsRx6riSPTI9cjh0BAac6yk6IhUCyqqM84mcE5dzxzxFSpwJpVCCw9XcbjcroezRMOE6x9WwN361QplsoXECbifcTDc8fDbzK69PcP9tdu/33BwHLq+rFIYKaZjRzlMu0Wp9hUyZPVYfB7qijPL4BCURVs9lG2JsE5ol22yNocPGUNP1d9lsMipPZ+NiB0gz9JR5811X72OrnN9g9mWxaffaPBiebmL2jMEMBJcpQUhacOooE4STGbt5uLd07+gQes7lBOLoS+EsUCXrGDIckl0/YoTNtjhtPZPG7Wj2TvNCtxIKUvu3oSfz9rubPmqeIlQ0pf67Jch2QK/2795B73v6ENgfj+x2pomk0ZHyhYve4xPM40SKsdYhfIrgZAy9FKnvp7Rr283HmMl59TiVVnj4qfedAfH2gtdnN4V80bYJCj4V9iUylx1yN/L4nmd4N+M+Wfs0AAMmVXHxcJOEKMJBYaoLz2eXmR8iQYS7nKzUTQ/FXRPUMVePstX7ElGG7gohVku+p7uxJer8+sI5ZyQndv2Ox5goWmrouTxBhycMtfOllm2T9Z3DZmPRNSwTZUUZFp6mmHGXqpHl63tuWXRgSJhXYS7KWZSjmiFWgOvaV7sCZYBLybw4wP1vfQnzTAgDjkKsRovGQq+J4nqK3CM5kf1ASZlQUYh9vycK5N0NafqcbnRcYuFhvuPwMHCqa9xOI67rUZ/piqDJFNBdyTSqFD6QnOOcMn3y7DmRS8/HfIOvht9BF+7tGfsz192OzwRKvuDmjt4HUlfwySbj4I1APQzKQKSoMFK4peNQEn0NWXq5AhxCZxUqcsHvd+QpIcEqDXTJxoQUsWPVOF0LqF6tmFQmmEFcnTRyRdAeDUqutBV3yaCCeEfpe+R4QNPbhuT2Hfr6tb1z5wldj+wHdOfIO28E7iJoJSdrHBEKqhMynQlxhJjIKSKy1gMuJYEWqy0slqykapQjtxmnlmxlVlVRlsoi8NQAU1kdg7bGOH16LlgdcKGFIVkyMW1c63JcczJahr/fONSpevOdKpMYW6+XlRJRgJLX0GwzlUMWNDpiBS3cnBl0YHe8IZTCHALid5xPM6Gzd518YqSYoVp0Qec8ihPB1UoAJBjP9wQKMWMJTV6WhCiwPk4IEUOQ42wc5hwLpupS36MqaEfOE+M4Mc8F2FG8LshhVBDNFLsTAzhKYcbjnF8yQQfZ03UD3a7DOW82uYvMaWKsty54TFNupZ/8Zu1NBviX7cv2Zfuyfdm+bF+2L9uX7f+A9gONmHUu4F3hprthl28AuA6CDwnpDrhsqrs+KEmdKUhjIQOJnhSVNM2kKVpsGDGCZ6MQeys03PmAw0jqxTnCIjVsvLBcIn0XuHJ7pnnmu68f+X/Pv45/t+MnXvwQAOeHO+IuIkPCO8hzYNaB3IErFU7PGasXRrWwn0balvCJrl4KZZUB2E
bmfIOjK1JGQ1wamvW94nOs6NHqaVW4fQOlN8+mlKffox7rN5B8C3cBzCQ6hCJiIRpnxOXihFITBFR2+OEF0neUSyEcEvt3hHC5p3985JIt9KPAY7vf+vtWnBF4knXaeGYtpNe8vXauhk6VzfO3sOe2/7VeL2PE7iNwwMQpwTKDD/uBYegY+ndxNx2pq9yE+ow5PhDmwOF8y4UTxSdcKkRfiDWrsl3zC6HrUrlTzQPWqh1XGsFfa1bgU8SseaULCriE+CyPDEwFPgvMmvHqeTnC3WPhcjvw4rOJcx0AJbNkdrlk4feEkfFbmHyOiTl2zAWmbIr1XSf44UAaZLl/hSe8EBGhCx3BuyXstht2XKYTzgk5WUgkp/SF8CmsiNh2DPxmheGXMbDt6xrmEVbkYXZwVWQVJi2G0HUC76iNhxdUwn8dMTtMKuN8ihz/vWte/9AtMp0I0mJ/te+94LEMys51TKGwO80kv6O/eQHAS39FUs+l93S+Yz6/Zr+DMXfcTWeGy2sAOneLT2fe2cEUPBOFmK0iQuvTIIGOQCjG81PXIyhTvBD7xFBJ73HoGfo9OkCeQF2m33UMm+BYH4Te7cE5kzioQro9joN27Csq6IcdKt7WjawLf0tzoSUIZWURawVn4fuc0WR8rjWN9gKlg2LisVIiZTdAd4CqY5bLhJNiwrC9R3OPHI5IvEEFfA27qSriHTL0Rvh3jpaO3qqGqCiaE2WK6BwtY7xK1uScaeJc2rIxlUpir+iSrmOtEf7XyqX1Gs/GaeP+tjmb5SnvCxqnbOmtNbNxOed60u14X+gym/mxPRf12t8rqel500o6c3V/9GIh3V23w1FwweG8cL0/cNjZatl3ieyUVAStqGu7UVW1zG5YqnMUUVIWUrEEH9dt1hFVNClzmphzZNZI0cI4ZeJseoIAOQWm6ULMpWZfWvIZRRfZIEn2t1JS1YUsS6UFEdMDBOOSd2FHCL1FlpzDdx3zhrdLUYskuZXL95u1H2jD7JivSf6RYdfh2kTsBw5+YsDhksP10LsC6igtbqWFkmEeBVI29WjWTTjX4yR4Qh8YsLovUjLeF7KPVNk0guuZY0eRRF8GOgJd9jwU5Vc++ZD9S1vcbt7uOD0q3S7Su7kSLo0knFs2ot/bBCeZZkqN7cNT3lMLN7YJ2LIIt4ZX+71NYqgwuDwNz7Xn3rY2udvg2BqIWyh9u2EBS2q2wCI9ILJWAqAdK1DI5CqEKqHH7V5QgoWjXX8DV+8CPSWMeD/Q3b7NYbxj991X3H/yIQAPOlnJoxrCa/D9tspBM0gLW/K7RyUvvAzqd7YyGgFXM2TX5p1JW/RqPLcbtIYrheP1Cw6379iYuH5BOBzQLuD2L4EZmSYzl2qmjkyQxQRTdy7CnOpCosyyLoZdaZw6Xd61UzOMqoSUJR2xvtdUww7bLNT2HpsZ0N7N81FQRJDsiCQ6J+wn4fToeHhHuL6CYjkvxLqgaw2vzGrvYUaZliySgusPzOXCnEGcw4eBMY6Ew3HRwCrF0tpTzcQy48w4Lk0movOB6CyLsFSOaOJ5/nJ9Tl03sO372yaCvOk7W15h65FmhNv4tcLkro4NO6eJfrwQeFcdhULAjHS33J1xWctjof/dX+d8fWCYR3ZDb7IRbYN1Ac2FrKZFJ5opmpHuBYfO5saQD7hwZNd7rlxH1iNX/YTmiTFlTqc7AHrv0NOBfvic/e1XCQipGHGfxldzAUJH8Z7Om8ailMjj9IrLVVoKOt9VNfUdHdlDdI6ZjA8OX2UDXFAkOfZDz845ghMCmR44uo6DlYymY6C4Di0BzbJIdziUXAVatRL1TVbHHDbJGSkZ1bx4g2EOFL1Q3Gtk9uh8gN0L5Ci4vjktq+GbBaTrkN0e0i0iDq1ctJIzlIx0HjqHeKmhxgItnBYTxAniaNyylHBtvKolULTBJGVDodgYZVuKxdbpe9K261cdm0umuK7jczGmNk50o444C7quizLrvtHW7IVz+4axL9
WhfU76h6f31p6nOapky9oFLHFlyvShI3hFvCeIYxx2HHY2tobQMTtTIihVtkVVvuBMFWrSSgYt1fnzWrNqpb5DywSfR88UTB4r5sRlilymtDifcz4TJ6rgtjLnRPBV3qTywrQaY5b5mVDC4jSKsFALQgjswhVd2CFOCcFRXERJNXwKOZcloeKNQuNvaD/QhlnfHxnLI9e5oKUqPUsguGx6IyoEEg5v/KGqwl+SkQpTUks/l4CIaQ+fp8TcMlS8gHcM4YC4BDIZfyCH1VAJVvG+Sw439Lhe8NKRNPHZZeZXP/02AD/+1gf0N3v0QSn9hJs78xh1zSrB9VACebFwClpj009LKrlKdGSp79gMoa3X34j8Xp7Mz8oxW/txmxwALEKbbbY21KalW0OdhG1h2CwUzxcI9KmRZ/wG45ZpAbfvYf+Scv0B/toWb65ekq/fRV2PXp3x8yOaMjm+zf7q5ZrE8em3cSWTZRW1XXSANs/nWUnbdl9lyXJamq4bucN4UlL7riqxsNNK5Ad2TrkpMNzcsnv3K/h330dfvGXPeLyi7A64LhCzotMFfzkRphNlrItIEnzncPt7KB2TCJMU4sbIav0cKyIDTxfR6tx9wUhPuiKDrX4erN7xdqCYobxdvAsuZ8R5iiaywpB7HscTr96Gq5N9r2mWiVhSRmYVpmxoUtLCQyxoEVJSUkimtD1dOMQO5+um7hzFlYXwq3WBVJRQc9C8eDrfU8g4Z9lazZDWjUey5TS2xb31zcLN3hheZTM2WtLWlmO2NdhUFVdgJjO1MS9wVHgLzx64dyZR0oxgAB0UcR6/K4Tf83XmeWYfLOsxK6SKtJQi5MoPLAWGECh+oqhw5Y2buO9ecLV7m5dXO8AzOs8xZfL+nvLwObGS2afDREmB16fP2d98gEcIqGWWsZL9QwhI7xAyY0oEMuf4mjMjb9eEKvGOYdix21/zerxHS2IfOnYHk8gAcGR06NldDQze0YmY4r6YtEBf12dxPep7JHso3jiyueAolFRrTVYFaXHGyRFNUJI5qibFbuM/Z8gTUk44CnoZYV/QOCOHitCFPVBqBm+xJAMRQ9TCnnKoTrFmk7xwmDxGsHJVCGiVcNd4QaYJiTULs1hCmRYrM/VcQ7IZZM2oerKG8xu350LhX/i7rkhe0yFb36yh5q1pE1aXNRqSNg6zZ+WrPRGSXTxunpD/26mfZxgqhjS5SrJ3qRBPF4swOWXG46XQh54+bKpPuFgrYizuIeCq9E7ds6ngSi4WbcE61+iHdd2ogrCX+Y5hDvisEJXzKTOPfgFciobq8PlqeFVjTFe5DH32mWqmCY+I+GWtCl5wrjMJmyBoFkpJ5Diu1Q027/BNAuJvaj/Qhtnx0HE3ZUoO3IpBo6qZQcHrnj5o1ZEpaPFWww2Yo+mgSJVRd8EGbSzKeZytJBCGkBRVpB/p3AHRPbnMqD7ixbw7J0p2HWWvzFq47XqOaun7vu94/cpe8C9//gm/ZfcO4RrkpieN4EaxTKTqGkSZ0eIM1sVcLcdaPHyZ8M9c/mUTf+JhrIT9ZnRs53r7uRlYzzMqtxNxK9Yq8jQEJZvrF9aB1ww5USt83Sa8FX02OQfvoTvcwO375Bfv4W/sHbqbl3B8h9IfkHxG4xmmhEyR0u/pJ1sob8oMrz5mVH2SffgEMdt83p7lef8BSzHh1lc9MIiFKK/qCrQTIzkf9weOxyv0xdtwc4u+8xXKOx8gNxZOD0NnE1szISpcDObOJZJqSn1WwZ1G5PPP+Pz+wuM4cqmLeEv2gA2ip6vR0e5+U27Oxhtr9m7afGcrJrx9Ry3svOJxdrDHV+9aUXHsomN3UT67hv5Y7+i+LAiAKzVjuBrurvbXGfi1jz/n6xJAE8UJsevpneBF1oUSQE18wzuHK7WmX8r0tSZdiAFmDEmqcZxmNDVy9PaZ2nNtnY7tW2/G2vKZrGikbo5ZohF1mr5QsMqidtQNcFB4ocpUBSUbwjbUAlZ+P3O+JM
KPHXn48Vs4f0bnrZSQdx5fOyzOkPCos/CJDB6NIzGecTXb8uX1V7jZv8M7hx2vxpkXpwt7PzEHz135hLm+7HHOSP4u57u3OL03mZC1s75zC3HZc0DYqWPSTE7KMHjGfObz+DnvdO/UvvHsup6973jVGTl/R2G/H3B1EA4uMNFzOPSEINAJgcLeOQbX07ka6C8m/ipJkOIpzqEx0bliYUqAWFFEp/YGKkompaAlm74ZoOWEpIyLYnIaMlHGR5iukJONG716C+et5JdXXXSvxDs02Of2+gu+joriFfXBNv8ckVQNs2lE5tHU/hvzXev6JwtmhBZZpDIWabPn44117D43vrZOtFSjaCkcLrXCiazzeCH/b4wt4YsE8obgtTUgY+tx+xzWNWV77SadsTUqmrHZPmqPIYBrpdayEB8mJBeLeuDI3uF9h6sSUa7WgkUUkVIR8bV/VmK/Zds/WQO1rnF1zKcipFg4+wvd5ZHiFJ2U+eLIsafrasKe6wiaGdXClWtNTKMmQTNsbfzZclONTbGSaqsTWXDOhJtLRfznOJO2oYrW/47vOyz8b2SY/fzP/zx/6S/9Jf7Fv/gX7Pd7ft/v+3389//9f89P/uRPLseM48h/8V/8F/yFv/AXmKaJP/SH/hB/7s/9Od5///3lmG984xv8iT/xJ/g7f+fvcHV1xR//43+cn//5n19KgHy/7aXb8akOfOonboe6QBAg9/hsnZezQboxF+aqTzbPpWbwZYq3Vb2IhWemWIh1K++BOcykY0fXKVIyEoVO9hRpajIzLsw43VFQU+8OsNsfcQSGOjsuH93zq599h7d1x+0Lj4wZGWfLQKqZTSmbqFUWRxO3KgskvRpe6xJgrS5fy2bTPmuZglr/XfS+2BhhtT3RcqoHbTe7hpq1+pRmdMkCD8PGo2rXr38aWaUDVFuxXqEbOoaX78DL9+D6QL42VECurvFXV/jdFVmvIEUYJ/zDA3NKhPffA+AwnZnGR6bzI05g1Krxo18MEWz7KLmVp7eVklCtxaaBA469Fq6A66ZafjjSX79NeOtduH2L+OKIf/Ee7u0fohxfkKt2knFXIkIhnS9W51Nr1YLZspH6zz7i8uFH3H1+4qH20VyNGy1P+zJt+6/+25Ae+9nGiNZ/F4OMTQiEajTJZkFtSOqzvjI0LAEdV9lxJnKcHZ+PhVc3dmcvH+FS4BJglyy8mwUeWLl2SeCTEd7aVQchdPira/JlouAX/ohUYdBWkqfVUXWipFpI1pAeK52Wshm9Wvs6s8p4bDePrQOy3URKnRxaf9bNQdsQykILkNXIHYCDCId6tlt1BAuOcwErQl5P10KZh53jw7vC7e/9Gp9woVNDrlQSKg6/IZ4UFWaNpslUIpozY3rkUvlXXXB0knG7nrfpcbcvSXniu6df5+50XoRcw3wiXx4gBeZ5pBNBxVXxzcaPsR03iemRldJ01Cbu42t0MDmT3h84yQMHevbe47TKc4pjqELbvS+Idhy7AF1h7hRP5liEvfS4ul5q6MFZtqNU3q9pg6UVAk4FqXL3xfka6k5mDDVdM4Bc0GQZ9T4mtEwwKjqe0KFqQ04zYbfH76/AGzJISVa5YOiQuY251cqRmqJo0YqCq5xW8gx5RlNcC5UDbjHTN2IUmzXoCYdsg9Y2o247/7ZGR/t9uxYsDptuIhWyZpQjZsS1+f8bCV+3yzyvIKOYEdFodlvkrN1746e1sGg7xGVBmg5ogXQeKVNif33FJM6KnLuqCQYU0aW8mHNrqLBswsJ2n1IzGquWnOqyjzX5uBSVuc/ki0NyIroZnwolmWYpNVvUeyF5R54ycTHMLHN5icaoMXXX9cQMyFYlo42XrEomE/PM5TIxjZBqFmgbNtLVELJ+EWX8Xu3fyBL6e3/v7/FzP/dz/N7f+3tJKfFf/pf/JX/wD/5BfvEXf5Hj0TyU//w//8/5a3/tr/EX/+Jf5Pb2lj/5J/8kf/gP/2H+wT/4B/YgOfOf/Cf/CR
988AH/8B/+Qz788EP+2B/7Y3Rdx3/33/13/ya3g+/2vJS3OY2v0Ep+Ce4lCcG5SEqeIgXRTNSZXF9M073KkoguMJVE0kLKpv4yz42EWgjiuL4xXTKhw0lGQqTV1SqiiAykDo5+j8uQ0sw1A91+R1drkUSvfPIrrxkvD0x03FwV/DmSo0Nis75tU1ri6bKBVjeogNbddNl4ZROmqcfM2CQOuhLjG4K1XQi2G/3WsFvQhs3vzdBzyxefb+nWtnIPLTxYH7F6bKbXdNhfsXvxNly/hRxvyHsbQ64bQIoJbbvBxIAL0J3wXUD3NWx9c6S/vsWPF/OkWQ2RN8Hs29BVrobEdkMfMIRsELjWwi50XO2PHF6U04MuAAEAAElEQVS8AMC/9R689Q7cvo3ur+heXMH1C7h5G4YDoRpm4rD6eeNMiPfkeIbpNf71R8g3vwVA/PgjXk+Jj7f3put73k7MICxCq08Ms/pzrniXqx7lE42izXGFlcy7NTq2xoxJwZk4Zp8C9OBzpM+Bq6h8vqtI8Y2wv1eiPN0EJre5xwJ3Ap+WxFXVMPFhzyWPFq6vFlCguuTOUfImCQC41DCSE08fOmIV8nxCaNbt3HhqWK17kiyMtAUx0LXvv9eCuUXWpL6Llyo81hPvtSy1TlOBA1bs3QPZVc6U8/AOXP79D7i7u+NGOhIZQkGz2ySrOGLKjPOFOU3kNJO8kNOF89lU+OP1xMgdj1PPO/5AfPuaj18/cv/xJ5zvM7KzWbsLM/OYcP7EeJ44dDtD6XxYDODivVEhvKAzSEyMc0IG5ZIfFqfrZn/L+PhduusD7mxFma/6I1ehZ1ejC85nbuWIH3okzGgodJOyz0qQsHDJRKxY9EJS9Bi3J2c0VSZfBlFHkQKVawdqavq6Ka0TPSUnNI+UPCMpG1dsjpSxrl5zpgxH9HALuz3ShwUddiJIaCPE8DIJ5iiIFCTPlJRXJC8nNEYjiGvju6n9txHI3rYtv3NLUXjzylmNoM0ftzaRSUPok7UYWHTHWrO1wGphOtmgbw2JkroO8EXD8ImRKKtxBixalI03t4RoN/eizVoDfBF0SuQ54p3NnUXeooYKHQUnxSozLHfV7ne9GzMwayKIlopSGiKZ6uYyz4rzmeIyJQrZF3oNBAIOh9awuyalOGXKkaQmdeIb+bpes1RUTBymAVjDlWZUQsv/K9ozp4nzNHO6JPLkyNGhORBCHc91uOfthvqbtH8jw+yv//W//uT3//l//p957733+Cf/5J/wH/1H/xF3d3f8j//j/8if//N/nv/4P/6PAfif/qf/id/+2387/+gf/SN+5md+hr/xN/4Gv/iLv8jf+lt/i/fff5/f9bt+F//Nf/Pf8Gf+zJ/hv/6v/2v6Wpfs+2mPV57DfOTx/jWxGlNyTPTOM5Ri2RYozhVwCXGtvIsDCZRsZWDO40jMjpwzIn6t40VmlJH7O2Xvr/FBCJ3Dy25BzBIzEjzoTJgLcwbdD8wUOilMdTZ0t57DD7/k4Zc/Jr86AXuO+0B8gKB1s2MiV5JhpkDlkS3eVX3uVpi5eQxt49huwtuSG4tsEdDG35b4nZtB8D28q2Xibzwm2Vx3e66mTt02aq1GxVbYECybZXfYE66u0cMRH27IVYtOCeQ5ot2FMthEMElpxR87eLRZMXthf3PFeLfnPD4uHLitgdlaqwrQRBrNrlytCI+yR7kRuHLCy52nv7olvPwKvGVob3n5Lrx8G3d7ixyPlKuXuKsb9LAjefA15OHnGcYT7vEC3/2I+bMP6T/+ddy3foX7z2yD/a7CqYbEZte8QLun4gq5drZsY9jtWfgiOtrCt0uIoXnNG4O+heq2793LF40zW/iFvjguwYj4LipXV1d8Uj4H4NtvCb/1XrmeDSW7r4kRRQ09a736oJmTgx2Cz6DF4YY945wJpS3Ofik+3DZgsAWtiZw68fR9z/w4E0IgpkTnlDmvz2nnqn3yrM+2osPKFm18ujm9MXtze6xY6P
K+nbf2e6lktiIOr1aoqK9Q6xll+PH3+OY7ARlPXPqeEEeCs40lt2y+IuScmeJMLKbFVAR0GnmohtncKbcame/v+GyY+TCd+M7dLzF+fmIogRJNGy6e4X4+c3Ge6+mC63e44CnOo8vYEhMBdT2zKKgl5AxdRyoz42zn2nXXOOfonOeKwEUTWRwdjqsavh/DSD8GpLdC6R2W4RsE1LsFpaMEcGIJVwtxvpBzQiv0IQmkK4h6Q7ZgQUmoxpnd/2zirwk0Z6SMSImQFJmt7/v8mnSZKGOG4y3+sEc6t4SiGlhp6JgliJlRQ+U6WrklgJIzUgql5CfOA7jlMzuXrJmZz8bXm9p2yW0Rjq1BV6luX6g2soB8rIZbgUUHLOma0blcaLNob2W4l3VZK1om7bCKDunTefCmSjJtX2jcKqdQYiLFQnCm92eadGnpU4finHG1zKB0T/qxGWrNIFd1xiurXJ2SHamuF3E2Q3JWIUtBQ6Z42PlECHlBubxT5hKZ8sScG/1ACOLoQ8sytnqbS1F053CurVPrU4s4YoxcxpEYBXMzhcE7wt6AoMhsiHCLMX8f7f8rjtnd3R0Ab71lpOd/8k/+CTFG/sAf+APLMb/tt/02fviHf5hf+IVf4Gd+5mf4hV/4BX7n7/ydT0Kbf+gP/SH+xJ/4E/yzf/bP+A/+g//gC9eZpolpmpbf7+9tWbxyhfmqp5yUzyf77C15h1IyxQWGPJN7pTgLD2r1amaF+7lwyjNTLGQtJvBZ4JLsBdt1oQxC1xV894p3bm85+CPOO7pqMfdyIGfwZLS33d9rNvXfOdMq2p9UuH1noJtuOJ0+5aO7E+/0PYd9IT/aMWEWUlC6YhtJpCcxo6z8IVjruwkrD6BlojXvyDXvhqcISiOILptUm+x1xm+h74UIKutm1ngJy+KwQeugGmQb1KL5QE3+YQJeaOHQOw5vfQ36W6SH0skyGiUn/OWEPnxO1w24wxWpd3gGSjnhGutVPMkVgov0wL2y1JlcNmhWNCphnk7JDSUp7OpxN/W/Fx5uX16TX7xPvn1JfvsrhBcfWJ++eJt0c0O+uiLsdsh7X0W9g5zoxhN5tDEaz6+R734XPvqI8OEvs/v2N7l89E1eXwoPbXFT09qfMY5EWbzJUpM21k4tVHX59rushHXauNj83hbORVy4/htkNdjbu3lORrUkkoLSc/KZq0tBDzu8JnKXuK3pyL8eHb/0lcK734YDHZ+XREQZgNftXC5zXeA7E7ykUFxkV0Z8FkrwyyCZSyY4T/GB8+WOfReQXHDOE2o8oDjYH68YYyKlM1CYsj2M6lPAolANTurYL/XvsvbX1iBryGlzTJYNSNdxXur3QnFMlKXcklZv2BcbW5eS2dfrvqjj+TtaePg/vcf57kLIHYWRfRdgglKmxmTgMU7M8cx8yUwxcRlhcoIrhVd3vwrAt84/wtvv//tkLVwur/js9be43J+4v9wRHx84XDeEvqfH4bKS5gvk26WKgLasc+/AO4LrmMqIE49jICSl2yU+VRvPP8w1/XEgnE8MuwP99IroHUkyx2Otd1p6UtgxdOAkMmlmLmdEHS4KQ3U+dRjIrrNSSK4jZGX2kW5SXFXtnV3Exd6kdKibspgrIhukNYWAaIfLPSVFNAdcTmiZcdQ6hQyIy0iZ8eWMpAPsrtB+j3bdUk5FvK+h/WqMYcZDcQVPQwVnig84cRaGLdUAI1a4e+UYtnG3NV6+kIizmXdOViMq83RjFizyETCqhq9rxVbioiHhzYlupP4tRaUl/jiFola5I2EUil07l6w8S3HgvC5I29YBV2UR8na67jG51PIeYHSiWdBzZoqZ6EbO+siZywY+ciC9KfyHWM9blnqra83j6ryKicaG1oe5kKppMInxs32CyTk0FNgHur2ifiZX5YbHLIzTGU0Rl2HSRHYFLwnXYEEXsDqsgmsd4ABXcF7wqYbmdUbJDN4xeE/n9mRN4CBX728Ie87jCcEKmX8/7Tk/8P
tupRT+s//sP+P3//7fz0//9E8D8NFHH9H3PS9q6Ke1999/n48++mg5ZmuUtb+3v72p/fzP/zy3t7fLf1//+tf/bW/7y/Zl+7J92b5sX7Yv25ft/2/bvzVi9nM/93P803/6T/n7f//v///yft7Y/uyf/bP86T/9p5ff7+/v+frXv07nhSwmYDeOxjHLZab3OyCSG8GwGEkvVY0S4w2UBW4uCOLDAiFozcwoGaYxMZ8h9Zm4n8g+4N1AqRpMIgVCwRXzQjpv4nMhOJyHVjbHd46zJI7vDuzPtzz+8j2fnxP+4PGVt0Oya6oqI4Us85PyRmlDZX6OeLH53Y6wtmTs6TM0a/MlZRPKfMM5tq15dQt0/uz6AotsRvt7k7Jof+8dHPdHut0OgqelqzS5IGKknEdKGRF3gTihHYgfcGWC6n1L1bzoek9/Nm8yVSg+VpejpJYOLgbTZ2Wqz7EDXtZLvgNc3dzg331JfvkCf/0V/O1b8NZ7yPGFHXR9i786oscj0u/QkInicKngxxPu0cJN+vqEfPgR8mu/yPydX2F8/cDjuXAHTEv6NyjZHPbfKM6xeY/P30mRN3y4fRdvgM395jvfC1U3zkg2CQBRimZcKXQCL24s3Pzq9SP/6h1PuRRuX0VugE8q1tZQyFLglbexNSf4vXNgLjMc96S7CVdZC3GKhGRZ0DHZdb0WxnEkVI5mcVBcwAfhcYxLFm6TF1mQrS1KKCtiuqUCLAgG6/zYco23sjKweq8dUKQwqHCqcPIu2/pxBkK2sG+s3znXAT3/+MDdVz9gOs24bkLmTCo7/NCRtTCOBtHPMXLKM6d0IaaJOUWEQNcF5tF0Sr717X/GbUx0wwHGCw/jiXM8kYkUicbdA4L0eO9QlBit7qSrxOomqupqr5Va+FnEk9XefugcWasQk3pC2NHqBNqXMlrRNgC3O9Llni7MmNLbhiskQlMEVGriRiV6inNoMfHYNa1cQRMqhr5s0f5ta5nPqWmbVfkEybJmX5aEegvbFSxpgSwWNk09ftjV8RIQcRuKuQ0mt0l1byEtk3Fg0S6TjczCMob0i6iY8sUQ+/PWJC22zbPhmG1IaFuB2RY9aTJALay5ZObz9F7a7y3C0P6m2nrg6T7wHC1rEYg2hxqFBTZInoIrik6RMl9QX6yeZvZIK8WujkwkkwmlhToF1UIpa/ZjlhpaZp2/LTS7INypQIEpgfORVIzo33eOPqy0CES4zBdmIAfB5Y6iwqUU9vVk16XUMK6HJZRp/3mKjTmoSQwm0+GlwzuPUBAZCKGuXV5JJVCk0M3fH2L2b2WY/ck/+Sf5q3/1r/K//q//K1/72teWzz/44APmeeb169dPULOPP/6YDz74YDnmH//jf/zkfB9//PHytze1YRgYhuELnzuUvu/Z9XvOVfH6Ekd6fyBrhTzF162wDW3o6AguGxGWTBat4QxHIS0DWS0xktM50XUjw04IrsGqVV/JKd2QbRA4scLo3uOd2MLRQovOE4Njuk70P3zEX4S7b76i83BVRZHiONLJwcQJqeETpXJCtovbm9sTEnT9txGzIzYJg65QePtOO34J/Wy+u/a1tRo9WhaC54tNC602SD3W89Ulnj0163E44PoBug6cx4vg6o6YY4ZpxqcJ1TMyPuK7RNnfwHiiTLZJaancvBAIDkIxQ/i04Xc0DbdoWwIBb9liwFvAW1Vs8/rtl8j771Pe+wC5eUG++YrJX1y9QHaWlMDhSDnucbuB4kz2VeJIPj2SH+5wn3xm9/PhR/DL/5z0jf8Xl7uR+2w8rAlnBFZAKSTbI6xKw/dhnG2N8S/8Tb54YHsPCzfw+7lG7S8hEwi2UdT761Lmtib5vN0/4qPw//jdwu/7e0qJcBDlop5WMzALaLY+/2aXOeSZn3x1T//VA7tz4lxT5rNUfaLLhHOByzyx8x25QJnNaClObYw4Z6KQWBRqrCGeZblr860O6O0Y3oZ6W8JMCzElVumA5/3kqSEk4FZNRD
c1brl3uOLM4C2Zvs7T4QinW7uZT/7Dn+B8cAwPNicu6thHM2zGOFNmG8dzmnmcRx7nMzHNxBkkdPheFuJOnh/47O47ONfXzDPHw3RHiWc6yWgdX7lE+lIoDnKaySUagRnBNSaAM/I1pXJpFmmHQN87pGbExqz0biCEgO8HvBfmNJu4c50/+6tbNHlcuSO28zTnV8oi4q2aTK4jRSgOVY/TZLp1jVtIMQpIyWglPMnGMlg4ZrJKp1CFSVFnHK8l+lRqkfuElAuaCxKN0C/DgFaeo/QD2nUU52tCvJgAclmNLmnaWmKmj/GelJJNUqENwhbCbK1xbu2e20M85YYBi+7k89ZClYKFWr2s9Xnb+FwkdqpF1tb6xqsFG9+ZNeTZ1u7ny0JzZNjMhSKbZCNd95WtbMYyv5rAbAadC+k04ayiuIUwCwvFp8aPcb6snIKc10SDJfSrZF1rIS9caVanqlSPSNXoEbkIwzDgu0C36+mrNEqMCSlKVKuy4qKagQ8kafdelus33TLHKjWjLaNcMzAZ/5AOL84qHLhVyLmQiMGRs+C/zxjlv5Fhpqr8qT/1p/jLf/kv83f/7t/lR3/0R5/8/ff8nt9D13X87b/9t/kjf+SPAPAv/+W/5Bvf+AY/+7M/C8DP/uzP8t/+t/8t3/3ud3nvPZM9+Jt/829yc3PDT/3UT/2b3A4eT+g9u37H63tbRC7zxFUQxJkmiqVlQxFZOsrVvxUnSNfjSsKr0InQZQeV2F9KIcXIROGchbtLJHQT3nX0jeflwQWxrA5X7HeCeZdSyBXdKZqskr2M7Ice/0M9afR88skJPZrX1h2LeZNTR7jPOIk80vCxhr3Vosz61Fhqkz9vPlNY9blYUbNmnNkDbCanCFmflgh53hpHbUnV5tnEXjZ2FtXorSjqAehVGPojEswoQz0lZYzJCzrOuDHiNJJjRMsIIUJMlDzRSAWWXCCEviMEoZutLFDerBKdWq5iFFtQQi00/Rbwbr/n8K5JAvivfUB++z38zbv423eIx2s4HinDEYJBO2Hf44eAeqHkhJ884fRAuX9FubtDv2sOhv/Vf878zV/m9HrkUeFRzDB16JLokZq3+1x0rfXts9Xyud7R93pHbUNoXEG3+d6GsvEboqK20JYq3qt4AkGVeUrs6ii8+uDI1x5PfLzv+JU/suerf+HErXaoj1zqzflm8KC8jD2/wsz+8Z637wce+8S+iiYdD3um84w4xzg+IqpcUrQU+kWSQMgzuKFbkS5ZjarlGZsR2sZ1Q82eIWFtvrTyXA25XDzx9i7qeQdgEMGLci6W6GCHFubKickBdsnM0sMeful3Wjm2j3/0fbrzJ/S+IxbHuUAgM02ZGFdtu3PMXMaZcS7kklEGzuPIFY4umGNapkzeO/rjjlxOxPmRcXrE5YmACWGCZYb3NZGolLIY12bgrLuD977KZhRSKjhxqDicz/SdjdXxYSLsLZuz3w1WDmueOOd5EffsjtdI9Oh8Ypqlji+TxHD4VeahKEWUkiMkh/Mep2YoNe6bq46BlFYVQUyDTc1gaoaS88EyOZ1DxRtnLlcjrb2/+rMhXxmnU83oU4TVGCw5I7pHug4q4qfNSq/QtFZoSdUkaVTFvp9B3IrSlcKStfiEHS+bH+t4a0NtW9ll2xYnuP7sGyrW5nb7fPMdwTb2JKuDDCzJo804a4/WqgGAZXi3c8miFahPDM0nSDMrMtf6qEU+SgFJynx5UvflSRM1Y01KVfzXYlnmS3LF1iCv83gzr92mT+srtnlddAFJ9rueQ+9x1fOYp4iQcVlwuY4Np7ii+GbQbxI8DKmrBnxR1OXl98t8YpofmefZSlFJwePwriPXTgtO6Sow4/73MMx+7ud+jj//5/88f+Wv/BWur68XTtjt7S37/Z7b21v+0//0P+VP/+k/zVtvvcXNzQ1/6k/9KX72Z3+Wn/mZnwHgD/7BP8hP/dRP8Uf/6B/lf/gf/gc++ugj/qv/6r/i537u59
6Iiv1GzfseHwL7sMPXWmz350fe3iklC11QclaSFKLIUitLfI/rFd/tGYLp7sjO4WfFhblWm4ecCvM8E4l0riMmx3lO7PppGYlKIKSOfuhRF1EtqGR8JQ42yzrOlpWRvXAqkf2N0H29Yywdn1f3+/DCI3mm797mED16OaEumh4MZRn8iY1hpOskeY52LenM8kUPjjf8rpimzCIjsGQsrddqkhvbVO0nm3sjirbjn35MBwydJ+x3eN/ZX9QGfKlZFzrNLG53ymicIU62QLkMqcmeVDg5eEII9LMZsjZZq3fUQhPe0LrrUngbuLp9wfD+++i7JqKp77xHd/s23L6DXt/i+wF2B6TbVQ+ZNbMsRiQW8vmCP7/G3b1CP/0I+fYvWX9++5cYX3/GQ4EHJ4w1jCKbwqJb3bStN/0mNMxtjmnZpW+yqJ4jl0+Ms83fF+i/optbhGjN/DIhYS2Cy0KQgIwXQo26h8OOH3pv4PLqFd/6ylv8o//bhZ/6Xwo/nh33oc0fe74HKaCR5OBfOvjRjz4hDB2xq8c56LwDdSSUmC4MvmMqkX0LlRWT8ICC7wI5zk+QiO2m1pIaCqtB1oyz9tyNZN02lqKrsbz0Zz1X+28Q0+0zhbe6WWMq/VHAecejL7zsYPya41u/572lHzKeixP2TpGcGEXQaEKyrYrAwzQxjQVNjjE7pmkiDIFEoVyqHFB55P58R9REcInz5QHNM16UUhKpabQWIbtATObsNRFMt9kZpFimmxepyvjWQd5XQ6qKaE/5HnKhlIT3nmEYeLxcOI8j57p2vdsdyAGSdpQcLMNdHUHVdL6WdaGqpqcCKmjJZnyXzeZv4QcrOVZKDa+5ivBsBqsGECNpq5urcRbQMq8CoE82djVrIVuGl0pZiO0VxgEyaBVYLomihdAkFLIZXKaV4EA8Vts4WwLPxiBp+rMNnc3VoHhumyxOrq6/A8v9N+37VlWjhSa3xthWuZ863oOu43cxinkqC9PGfYFFWLYlwJgKvy73XcFD+56yVJxpWe710uRNjUhNWBLcWIhJKA2NcmUh1IsWSiroLJSOapS1fXODKtTObHptWvu13e/y7M6QM+eEoQsch8BV7xjcms3rMK3SUgpBvSWEie3dbRCqlAouCFrDGaK5JqOk/w97f/ZzTZal92G/tYeIOOe8wzfkVFnzXM3q7upusckuUW23TUqyZUgwQOnCMGTSgAGD/wb/AV41eGXwxjBgXRgwTNAXhiWRsmnZlqhukmaTVdVdXWNmfvkN73TOiYg9LF/svePE+36ZVcULy044d+LLd4oTJ07E3muv9axnPasQ/IEpXXO7v+IwHlCzwZiSOQOLrX1gnTfEPBA14+z84Cl89PjXcsz+7t/9uwD8wR/8wb3f/72/9/f4m3/zbwLwd/7O38EYw1//63/9nsBsG9Za/v7f//v8rb/1t/jud7/Lbrfjb/yNv8Hf/tt/+1/nUsq5/IDmia5zDH1JsVzf3hIvIhhhniLRJKJRgtFShQM4b+mGngtxpNnQp0SfYDMntl1HqCWEIUZC6IlhJBY6A5KVMSWilORcrz0pGp52vpTaWoOm4iSqnkp0nXOIZg6hpE3vbMQ/ho1cMKeCAO2nG8IEb36uByy3P7iBWIVuWXRoT0ZOWPV/vL85LVUynKDek0k7aTqtHbk6/YpgJ6dKyns6V3UhLpV+9857Ggv8vkbk6vHOObKxpx1TajuTpmMUAqRIlljataRcNY7uamTT7kMiUlppGVNYC6amn6Np0XcxzX2Ac+Cpdbx5+Rj9zFvoO+/C44KYmfMncHFJvryA7RaRLdIPZNstxluTkg9zEbScA+Z4Rbq7gWc/Q37yffRHPwDg+OJD9qmgZDHrCva3pKYs3VIOv4x00u57e05631GDExK0Hg+Rs/VzeVhJ+9EOodT/coH1ncMmoauOmekUGTo+962vcXh+xbOvGv7ofx65/N9YNrVK6uDBhMLluzbKJhv2mvkp8HQK5FDEdr
OxSMzs+o6ZTDbCbQoYgWPriyhCzhNki7ribKfMvUCifPD7n6JtPu17uB/AAIukSOL+XG/HtjSmEUUqp6y1vNlLqTQ29NhYWgttvgz/4t/6MlfvPALA7gPZbXAmcDAzfhaOAi5bVIRjTRnup4Kqa3aEFMkmMcbEmd3iunLMFF7yap8I6bKmcTI2Baw1GO+w9UGGOXF0qUT3qvdK/hepCGkOQBHJ3LgNKRcnOU+ZoxS5jBhfcfMqcDjeodnivceYkfk4c1PTsBoVGSx4Qx6LxpfVKh6wRulUUEwVjbVkVWwuQp4LL0yqE5VKJCj5/oOR6pIs6n1iEWPBOjRH1LqTzmSudeFy4k5RETqCkBv63vpwpq5kUZwtaap0km8pzmPFr8SWSk7rUAnroswFLSuIlPCQf8aD+bdkGKinXh2x8MvkFOg2u7rmmC3+dj45ccqpOhNOweA6S7IOTIDSt7mmDNeVyimftBTbZ0u0vsKrYF9OBZeqAskyjqGkGasURUF1X8cHl96UKw+wsalPaOX9NarK0poqGzC53AuxwsZ7dj3sOuhtJlTiccwFIAkoobY+xMSSban3NLWMW9Mzq7IhqpGYZ6ZY1sYxvuLF1UvG44QzPc75k2db85bWGqwt4IF194gXHzv+tVOZv2wMw8Af/uEf8od/+Icfe8wXv/hF/sE/+Af/Om/90deDwXZC54RhKKTkF3fPuLl9xe7RtpTT+yKTYTqHrYic7T3OgZcN7ChdAWIgzImwS4SK1swhMmdlvDOlN1/Ir6WU1GidwIK4Ek1q6skhkfJUCOpUYjOZOM/MajjMMx0Tn3nrgm0+B+BZPPLiVeDy7Ah/4Qw7Dcw/OjBWaCPk5vEvN+A1aYTlT6vF1/5RjxVdwdVUB29BGE6pzCZQsmIDnJCy+m+Nxt0jm2tDGooTuOZDlEavUiL0FobGREpNYHJC8kzOgTxPyDxBOiAYkhVsK/cnFp0kU6ISa4oTO3JC0zopulMb4Hwz8Oj8CbzzDjx6hFw+wV4Wx0wuH5PPdrDdYvoNpEJfFpWFHyMpwXyHHA9wnDDHV/DqffLPfgg/+3PSyw/rvMlMFGd6/Vii6LIkNReDkk40xNfv4Wp8VKpjWY7VAberZ7E+17ply2vnlPvvX09HaeAeEDEYrXpTWZC6NrqYOEyB4a03+fy5oXuR+eNvvuQf/S8Sv/O/Lk/7i9eZlwLiDJukzKaQs6+9IcXMef1EdzkgSTncTTgjzKGgUp1QORzgF5J1hmRPiLDeRxyWfoAUB2o9n9cK/220gKRtTsIpza8tCKnnzVkJpqSfW2BzyDBbSJp4ROTiKXz4l97kR7/5BWKdN6OdcdniRRmT5SiOnVqiJKJOjK0jSc5EYwhS8BBTN/9Jlabu2PWeLIGQDuw2W8YMdiprSSx408j5kUkNXfJLY/glPbSgMaamMi3GOKz6Ii6cMmkU5irQcxw/5NXzW8J4Ry9bssk4awkhcBcqzzMreXDYqXR0SDlgNRYytC3rvTzIKjSrBiOurjFTuWirG68tlaZLSm0JHRePpD58a1HjKHkph3E9mirKnRWMFtTD6ELZqJAYpZkbpDkVqY15Rp0t12kLFaZFtpLLtVb+ClksYhxiLTmkZX01x0Xr9a/Thuu59zDAerhe4TSX20s/CnVb6Fr1HEZOtBW7OnxNQWlzXmQdlBeUTOr121ycnaWQYY2YUYtvVte5RrEAclLinJnuZmKYS/9RCoqWajpGxYAI2WjJmuQiHqs1ZXyyZ+UdctNfa85s+yB1NL/ZCHiX2djEYBNOI2N9zxAzFkvMhpAi3pgSGFhZJLWcWgxyzykTCsVgDiOHY+E5f3j1Y56/elUaofdzddj78hyqI7rooVmhLc9fNj7RvTJTyvSDwTpdOBiJxO31FZvdgOl9ccjOetxgMZu68XsDRsk6YSxETYQUSaGQUI9zi2Bn5pS5IzPnAyklYnSFxFeT0YOzeGeqcavesevJJjHHkVjRsJQTBw7YAa
Z95NU009mJNy4CT7/wqBxjPR/8q5+yT9e89eVHXN485sWzmat9ZCuWrjVWxRTIV1jQrV/kMjc0YCGBm9I6BsoElpaiWSFvjdgPFeaWmubS+7ylew7hA4NSOBW6GAjqV9v5YtBEFqdMyUhNi9hcIHyTU2nwHAPEmWQN0WZsa4cVAllnMjMqscnMFD5QNVadwiOBx9sdw+M3cG++xfToEf35JVw8Jp8Xp5hNXwibSuGszJGcC7egtcyRNKKHW7h9hRyPpOuX5Pf/FP3ZD+HFNWEq82umOIfHdv8AJS9kVSicKc1Cloz91UCz1/l87bY/MNay+v3a0VujZw9RsvWeCA3hU7xkrGaC6kKcBlB7geeOLt/x9M03uOlHvvj+HX/+ZOb//jfLye7+t/ClD+AQIDplE2GmOF63CF17s3kuhRspI62PCaXlUzNQSfX0fU73Ys6Wliw/rNMyJ7Sy/W315RS0rAIYgYXrtBxVd1Dl5Ly9X6fgRuHtaLkzibcv4fIPHvOPf/83wXdoeg5Ar6WR++0s9HJB2ERk9EwycxyPpDbvjSc7JcyCFUc0irWCNYKrc94kATNgpPRjzQZ6B6qZMCdc1WrqnCc5i0554ZnlnGt/0rZhyOKYiXHEY8JsLHEOZNmitUXeOL3k6uaKw/XE1h3oLzYgBomnSnEvBh08403ZKSUnDOX6jROoqut4txQoiNQUpTaCvSy/bxPSVKespFn1PvorGTVF3R6pLadMB6qlpRMgJqGihdAvuXhXtaLeKEUYGArXKZZG6mIc+ITaSvJvjO3m1IqQjUGkivXKuibx/tAH6+zjxmIX6rFrKtKaj7ZkIOT0okxzupojW4KLyImP1t5DHtiKhoitkbymqdgcpAX9y6vXcHrvJUhqr2mBehZyyMz7mTgeSa6EMykaVBt6lZlyYEqZ3vgSYOSC18fVubS2bsrNaV/dr3v2rl5/54R+cAy9pbOpFJS1u2oMzJmcwYhjg5QuBS1nC/g6Xwu1Y803y8zzyKFWSb///Ge8ur7F0KEaSFqKDJxziG1zsHDoXA2EfpXxyXbMTMAkj2x7+pvSoe8yPOJKbrjgXfwmc3l5znY70PWWbqjk/8GQzMwcR7LOpJRIRBKJZCNdVY/13nAYJzpzwTwJmgqCth9nplV/hb6fUT+Q1DPOytDPeO9QzlFTkLwYRobxjjtxbHLiM77jJgUyWy4vHwHQbRJqjnzw3isuu+e8+e1zzM88+U8ih9VmNJBLyyUFj8EiHEnMQFcnX9L8usO2ijLCvd+VoxqZukV4D/enxttpgaqI4B9AdY2DBqdFP3Pqn6iASTtUDTHP+NxBVNQFpKYyS+W7Q1LERUtrXG1CopsiQlUAzhETM5p7jGR6e8eQa5RYH8+Zt1yeP2L39C3yG28wnZ1hz8+I5xewOytyHYDkiIQRcsYEJZkeyBhJ5HojdJxwhxH2B/LtK8zzn2BvXqLTnkknprro5pRIqouRaw6SsDbSikFra7DTvVtKzVf3Ptcb15wHeB3lor0X3ENDH56ryZ20hubt2taPUWkpaoPJjmAyzmSyWkx1Is5vI4Mm8mFCd2e8uXnKeHbN1W3gJ5tyzH/9Nyx3/2ni6/91potwdI6YDKozBviwTujBTJzVlKcrdXRkBEciLSK00NmBOU2oKq70RacXy0HTa/N8fZ/ax4+nP594fZR5HvQkc9Fe3YzjRT5tux54j4ytC+g3L+B4k3i0g6/+jz/PP/79bxLONrh5ZBM39dojMRfHIeSA2wujh5QNxmyJUt7JGM+gE9rDKEfmBF2KlS9THTOfyTawMQY3OGRKxE4wbOmCIlUAOPuCClmfSVpkXbpsSjVnnSFbB9d+xlqHS4l5N3E+nyNH5Sgere2KMnfk5IjzxMs5s+1GNtozJ4e0xtB65Il9xDOTkTyD7LGaSTzGGMcQih3Mj7dYInP2SHeOhrHIaThwTexV3IkAvsgWFGTFLk4cS1/DbArSIcmiJiP0aFPhdhml8N
hESmIvFTiu1Ancmye2OkEGDQnJinH2vlSNWNR0iCnImnhHjhZxjjyf7Bc08dmTY/XQQbuHgNe1v1Rp11+b+sZCQ7bayVYvFVBbguBCMynvbZstXNmUddpyia8frP0iQFszyTUabB0YoKYCc0EEbaWKZOqxOTUhDA6V572/GjnuLebCkMNEViVKseHRRjyWaD0pToWPqLkUHMgpkGwOXwo1bVuNWlxdvAG8KUV8LhenfHNm6dSQdOC6zRsbuZUDKNhkSa4jOOGpM0hfuwOZLZvtOcPmjN6cYbFonggpMqWZ915+D4Af/eynzDO4aBDvSteL3tF3G1zl0Ea5I4hwjBHjfrUo/BPtmE0pgenx2dBXvSPTZfbhhjm+4rJ/m92mZ7sZ6DqHG2qk1Clqha63TMESYySbQEyBGHWpLx5QjHHEnBi9EufiPIScSZWAjjGEaJkZcSoMTnBW6DvHxpuFOGqyJXdbOslkMr16cq66OL5c18Vuh//ml3lmBl6Nz3h8+Yin33mH25c/5OoDlnRGy1InUyKORgRd++JrpOAeVN3y8Sur0BbcQ6dskfCRU2Qi7X0UDK9PsnWbkI8aBSafyTHWyqwEKaBSWp2Ue1XatLRPoSmXnnh5QjXh6vsatLZCKcdZb9jVtJAbyoaxubxk8/QpPHkDc/kYN2xJnaXzPVghhVoJGiNZBOMj2SkqscKJFmme03gg392Qr1+Rbl7i7+7Qwx0yBVLUpbpu1pKKW5gU7f49TEHoyRC2eyN6kvhYfrdCLOFk6B+ONbn/41Kiy+vX51591Xvve2oYnDNkzQvBuRt6VG/YH29h7GHwPDp/wruSmfeFO9YdhLv/CH7wFyxv/yeJR8fIXb0vd5zmSJfgOBimMeNrkYtQ0LW16HZME9q2plzu06hpKXxpo/Fr4LQOAqf1sN4si/SAwaOMlJThcfXaXT2jJbE1liknzjr4tbNyzJME4V3w//Gv8Q//0td5qVuM7ul7z9KG0TiYEyIOmw1WLb21ZLFEDPcelQN1CkZIYSKHghbGeiOcEYw1GCn6f9krPR2iG3IXsLV595ntsNmQ8hHVgDWKcyWIs1WmpPCjPCoecZYYE8cY6IwAI1NtCNhliGlekMnxmDAylWuvhTFqCiKAEbIKQZWkxb7hPLmib7bvS1rIOMRZJBs0lpTWqWpWT3MPUE2Lk3PiebX5WZE34xCTUVNybw2VQxXJ1WGrDrCVyjmT1pekzLBmVXK7BgpMVHneBY0ygjoDzoF14DziYylMkdaLCDSvUuirtbo4Gsv/7o+Hy7bJ17RDT/fk/npXPXHPQBcOmuipMGBtw5XKqazIWuOPWSoXjfq+maohdn+Nab2WdbDY/MWTM1VeFKeZ/fWR/okl5kRKafEOG5qbagd2KbkFqiTZfWc2nxDDonHGvV6eUDi9KookwzGU7iCydXhrscuDTEg0iGaOJmHMiDPnRPH4OufdNmOsYhCsA2PLdU9z5Gr/gp8//zEA7z2fiUe47LoVVSJALhXJAC5mXC77WpxPGYFfNH7F4s1Px6fj0/Hp+HR8Oj4dn45Px/+nxycaMUvc4d3AmBXfFzxp8AOHSQiS8M7grSvVENbiW6WeoWiPaeH52DxV7lLRhm63RURwLjPFCbIniAWZCSksyr+alXlO+JgxfaFa5hRQAsZ3+JYmNJZOPUkKmdwYMOLprF+grm7Xcbnb0HVbfvyTyJUEHn+uw397Q5qOpKty3LEwM2pPzYIGRFPSd0sFTg2LMiwSC3mVw2nRT+MtLVU6eiKpLrHkKmppZFK7+v06bXSPYKr3kZ9laEZTqHh4ghyRLIvKds6RnGZsjEXjJkV0ihgSaKBV6iQtHRxEE9kItvNcmIwRj91eAOCePEWfPiU+eoRuzxDfI05KGBlCIfQDOYXS/aFLqMt439f7VWQNAPJhj7l9ibu5Qu5eobfXcDiQ5okY81Kc0bSxTlHu6aPLR3xt6M1aMbv9fUlFyCkFt/77eiyp0F+Ali
3XI/e/P3E5TheulbSxPEOVqrsGknLpDpADUzii3jN0j3jjLDGl0gHhuYkMdx3jb0S+90W4/N/D5//Y1jpamGvS44McOJ8yu03HPAY2qngDx6zLmhUDJp2IyWRwxhTdIznN5zaPW+PnxofxnOQwckvVUP7NUoo1VA17TrI0CTgY6HMCB9teebqDzTm8WxWThy8OfPA/+y7/xRfe5GbuOU8zwfeoJpKt5tUkrFEIBk+Hz5be2yJvkWZsk2PBYrItavhAFs9RlZTmJWWjRulMOcYZTzcrAcvQ7bCS2dVUJrYrckKhpE9zjpXzZXG+q/fUkdVQPp7DhoiOgfEYudYJtyvP5wuPOy5c4kDpZRhmmEwimYDWxIF4xyCCGEsQYUqZLiliDSqO6Mu98ENXOKPGIcZifYfOFio5Gij1QNIQsbo+CyGooGGLuOdicSgcs4JwY4vwaznI1Wq6hCKUjKirSLYsxkofVsZwQnMWwdTGY3Oe7HrocuVFFv6l1gmmsSFJRSKizcmPG8Jp3a6zHPCAQ8XJlq+LgdaUk1ZlK7Kq6FxlMZycimaa3Y9S5F7a+y3mIddr0gfV/Xq6xvXnavZp+XtWcoR5DFy9vOHRmx2zzczzqcAup1T/lQssGnFSXru+R0rZr9v9XN54ZbxWNi0kOB4T1/vIbR+xqhwqd3yaI2I8iUBvy7yy4hi855GtYuKdq2icFOqMCjFGDvMdL25+zk8/KIjZ9a2DGBlkJOq2VnFmEhO53tQ8TsScyCETxkao+MXjE+2Y7TUzWUOPI5vygd2mpAsMe1QzMSdCyjjnqtNVJmdJIRm82Gr1DWoKedTUc82xqF6rdZje4G3EOMgmFT0tammwNcgMXjzeGjSXh2iMLCWzrnO42bCTYqicM/TiS+GAPzmCnROevvsW8zYzvXzGjRywX77EvkyMf1r1u45lcnpqIRJlQQVK5SfcT2sCC+S8TufAqcw5r35+WOlWru1Uien05IA95E20NFR50etORAZCmolpLGnEoOBS2ZTqVWTNaCwLVlKAoGjIiGRE86L4fxIhKjwU4w3JWWy3Q86KY5YuL8nnF7A9Q4ZNEf4jEKcR0r6UzVPTp75DUyTZgNgiCStZoDpm9nAHt9dw9xK5eYUeb5DpWNKyWUj1k2Z0MW7r23NPdoT7ELzy0dWvLfXQTrSu5FqnJOH1Z/FRY33MRzmO+uB5tXcwWIyGBZ43MeI04+dMDJE5g/VbdijnfXHMPhxfMY4zfRR2Z3D4jzv+5Cszn/0/wfYOYtXJOvSGY8psYuDCFF/dmQ5y4XQCSFacgF85imPOUFMxq845p2o0OW0+TWAYVhtZTf94LdWbR3JN29TnaJVLC26Gbgdf+VJm2BnePffMv/5VAP7Vv/1b/FH3BK4jT81zpmHHYDIxh8ppgk4zjuJ4dNrhxZWKS5QY54XzhVhMcmhUEkrKI973JJPJq0RSpuSnrLVYUSLCZtgxuIyvTLrkBtxmxzweSWKYwkyQQMqhpJLqfTEieO/JGkvbMjHEAGEqDbwB3FuP2G5CqdyWk1OQUuL2pjxruuKYGeNQscSkzLk5RIbY1a2mt6Q5gCmisWQpKUF7mniZeG+tiFKI+3XiL4FfLinHwq0q1kkkF05oLX9bUqKN3C3ln9R5sKRF17O9EkJbum65FlsrTIWSwkyZbCO4BHYunwNIEu6n+OC1nz9urNPwHzWWAJoH611l4YOJCNYoKVe+mpzO3YTJM8IiRqGnoKUJ2JpaUbDWOWvr5qHYRarnbo7aUr1ZJabilDhcTWyOgdgVKRetXFWjpryfEWIq4rS5ado9MKCNtwss8iD6wKCZVkWa4XiIvLwe+cAWgeWbGkyNszJrJKbSZi6ZDYP10EN2haOQTYdzhs4VTb84B6Ywc71/n58//xHPXlTh+MmSZ9inyI29RTQyxh7ELpXIfRaSZMZs+BV7mH+yHbN5miFSNtSqUr11HVM3cEgTU5jpUsRmxeaMr1GRybUaSApbyx
hTjIeyVHK1ISIY7TBmxpiMs0pnDaEuQmsDrleiGXCdK+0YJENWJGVc5To4YzFDXwX5LM57ArkoE/eFGGvFYM52PNaO7dtf4fmTR8yPXzGNz9h8KRLlFQDjzxN6B8zgUpuh1al5EPksn6P9W4xffdl64q+O/yiHyyJYdOFAPXQ8oFV5rnwm7pdslwqeVPhkubJIY4m+T3yiXCOjBKk4aSlEjFFSPjlmBsXYUuGIKfc1qS2k/kePy8nOL4tD1g2lrJ6MnQJhPKJhXjiAIoJ6B2EAewQ7FKMdc9lIAN3f4e5ekW6vyYdbbJzRGMixfowmRMnJEDeybrH1skTyiwO7jlC575S1n9d8g7j+w0eMxH2EE7hXybZEldyfD/feT9vfhGWmaGrkDgBCKpwpr5YxBMJ0A2Scs+yGcu+7n17xk71ycW7pX0XONpnxu/Djb3eY/3Tm6X9VTz1mZms4ukzwcJ7ATzPZWIZ68VYSaHGuOhFQQyCVTUZf1+VTvY/oNjkYODlu7R40Rxpg0tP93yj0czHOX/+c4XN/7S3y2x0v3/iL/JOvfBaAZwQujkLYTkSzxekZQsDYDkxtJ1UkobHW47IprX6kR4zSWUeqNiIbi8kOCaXYJeSENx2EhGkdRHJhyxX9wsjcW1wa6P0Gb4Cp7D5b2eLtluukHELgMEUuPSROPWlFDd5YvID2hjhHZp3Jx5Exgo7luu4U/FBRrLzaFMncXV2XH/yGLUIvFmtLD9ApQSKQ1ULlfNLZFU9KEWNLhaN1p0BDWXb1MgfTYmvuBXmqS/UgUJy9XOZtmw9WbAn0FvJ6rsYrs1RZ0tDV08KR6mFktGhT1aO0Op/JVdvlPMSIutIZASgip9xHlD5qrbXfw/2A697f6wdvn7+dd101j5pFtmld4d1Q4zWHtVXU62rOr0n2eVXoInW9NUds3cp0PZQVYs19ZE0ViHC4HtnsIynBFE6IWYpNULY1Hly2snKd7br09Jwf8mHb19b30+QSyIVZud3PvKxkzyk3d8ejOeEVZoENjqHbYOSCz2+/CMBFX05mTSbH4shOYeTF7U/48PnPmI7tnha07zjD9eFIjMfS+szu6OqcT9aTUMa5HPerjE+0Y3Ydbzky8kj9Uo2I77FxS55G9uOePp7RKSWCq166OINTA6YoN2eUKUfGeWacwgJdJ4ratVDV331p69B33VIRFBQwib7zJfVpLM4anCsITrNi1haULMZY+tV5h6/Vhq4SY3vXId5y6Xdo2tBtN0y7S3y+wLgNqS/vee33pJ9PmJegORNTac6rKK3oY52SalI8DR1TWCzCQqbktPjvGUhOToSg98qtpf7vAZL8WupzndYr0VckSmmsXEowEzmfGreIQk7F8BUFbUFxpAQ5y9LTL5u6pVYUR1ULybPzmKGka3KtuiRNqMlEFL+/RsYJk+aFeq0AoUN8wlmH2AQxkkOEsWx2etyTDjeld2eY0FjQh+KUnaQDIrqUkNMMjp4MZbunzTlYO9Ptb8v9WyFEunoWHzV09bzv/W79+hWK1MbawW5E43JMKjW/UiF9PV3LHAN0gpHS9idOR2IOdH5YUAg1PXp35MNj5N0nEFyE53DGzPTvCi//Wq1+/kcR//9U8t5zs0nkbWZnYR4TLe4YrIWYSv8/LVuBNbWRsuq9aL6N/GAenjYyIcppc5qlaPaNWjpF3NYHodbwtst87tczf/C//HWmt57wT8+/ws+Gi6Ua+fPpknQGh5xwvkf3d2RxOOexNVVrmWsTcYdJFk2pdB+xijhHR3PoternFaJx52a8RE6S8xTUGGXKE5oTVjuCRu7GwFN/Tt+VzcD1G6IYpuS4m2bGOWGMw2v5ClXh3xg0Fe0yoyDesOkG+s5zW53BF9d7LnOonUxaIUiJwPZ3pdBjzEovgs8WJ7agZTkxpbmg+HU94jKaI9Y4mtK6YKueVX0+plQsqzrIK7J0q3Jc9BzNkuo8zWVLlkRr5dOq+6S2Zrt3rMgiHtvWVm
mM3lJYlPR9cyQlcxL+M4vGlYqg4krhAac13RwTXX1/H5rjntMFq7XeHBK5X6XZeryubUTBJFt/0BZoci+dCSwitc2GNIdKODl07Vpb5qSlMdc2Y0HGuO8spwdfc65BXoa7myPD9YGUHXNOhLmcLMYiS5RS1bzU+hm0KO4v59KTDV2nMlfZ6PIyXSOqmRhgnKozXR1nJ1JSp7YU/pwhHEX4vf6r/E/PvgnAf+O/x4sqGA2GOQZuxue8uPqA6+s9plWUb9vytIwxYTJ4AsYGpEncEJk1cJzn1rTml45PtGMWx0hOhmQN6its6ASvgdFO7Pe37C4uGXIiqqmVH4WvEKNhFuUwzezHO+6ON4zzxBRORt4YwCSMJjBKSqVFhwBSN3wNhhAzZ5JqasDgrWC9wTi7QOq267DJ4giEHFCErmuN0MtjcM4VOYKdx+VzzmbYdYn+q5/HbTqiLXhJcC/Y9y/Z90d4BW6f6eca4dR701MXXoOj5bRRvWYf5JSeaCgOPBCT1VMqc3H4HkRO7TzrtKc84L5lisxJkkjU4pwZTRjTn8CZWhKfk0LOaCo1UkktqgYrbdoqmgIiJZWRYsIwIfEIobSw0dkVg0pC4lzaZRzv0GnEaF4MeFIDEkqDY7FkPUJOuJgKqQZIxz0yHsjjkTgeQFPhk6S2AchyXxp/ya2srqjec8ya7MhDh2xtdA2nh9UM8sdVXC7PVe//bj3Wz6w56vfS1sKD1xe3uPVZlPqQvBqMzsQcGMNAEIcqjHPkeqzotQ6YPJEm4ceSeArIJDg7sJmPOFfTbv9Dx+EvJuY/Cjz9F7B9AfQQO7itEOFdSAwCW+NIOS5IYm6VWu3zrT6nWRHkRl1H9XpvgzIquGrpPYqrsidPQs8Tc+BL/9Hv8V/+7tf44PAEZy54kwmG0t+WoLgc8d0lczjAxUB3BLWnOe+9L42uZ1s0scTRuwFrBdF4H+FMhpgnrJ3wncUnh0mOXAVTk42YlAnMpDiznZSDJuI0sD0f2F2UVIz2jomZ3XZL7kp15IJUtRkmHut7rC/VaBElqGIiXPRDqVAHbq8nzs4yYkr/wZyKcyZOOB5uALi6nXHncZFTmVJkTomjjgTNyLYiZpLJMWLFlnWpWmyk5cQdMwal8hlF0Gy416twQZ1PTlVDXJT7jlrOaUlNLgFKjVZUM1or9Yy2FFq9O42zBuR8WpGlb6QgtZfiMkReW2ttrBGdhYZQF+vDpWw4OUXts63H2knK7foorZOMyBKQNxt8z76s1rjoqdqyBebt/Ms1CScdM3Mf0WunUjllZxoNYvmbnr6Ox8B+P6E2M5pErPMkVwkApTQpXz+jtc+9fo92zod2UOW+4ypSuIYpgRGLbfMlT0SFkKGzBrUe2z3mL7/xWT6zLbqj/9ntc/yb75bnoYl9eMUHr37CsxfPOdxGfH3zsw5Sp4yHxDQV9aVsBJMiqdZ3q+8Z08Rhnkn3Y4OPHZ9ox+wY90SdseLwTWXXOmKy2Oy4vbvm7PCIfneOYVNSIEAKhkDmkAIfXt1ys3/F9d0VIScyFlOh61YRvfHFUIRYpBs6V9qSAGQpDkFpda5YMagpPSA1RVx1InqxJOmxnSdOd0RNbL3HGFscOGoqE+E4jWwGh+87bLhj6DruLs949I0vAnAnHvUWlStGd2BykXCA/liMJlSyM0BzyhaviwXBgVNkBvcdgoc/N5j54VgbjrWdMlDTjLDzp/RHb8BvLNaD2LJ4Wom9WZVQFyMipCxEzcQUcWJQlQVZa4vX2NPrlBlNM6Y6UxLnyjPJNV0i6LzHxAiaazIMchREHDprJRp3kGc0RrSmh3QayfOh6J3F4lTmUFC9DEsvuOaUtVu+cMdWqQbDSam/3deHkW27wYuxWznXa5L/vUhc7j8T4eOkL193ztv7tb9JhVOLo5wQPfVT7OaZbpsZRch54Bg6Yiyx+z7UtbFzvHOpfPDimtuj4b27jJkE0SP9O/DkcXnT3bPIlIUfvGX43m
Xmu296vhEc/HTC/8uiBfb+T/YLauuAwRSUUuBeS6a8uqemIrwiYOXUkiciRD0VEkSjzNWY3wCPKwfrG8OBd/7dz/GDv/hN3PiEIc9s5chtb9g2h14g7zw5J4zZ0uWM+JE5zcvGXURcXdHbyg68o7ce6wTRkwlOlH6LkQ1TnrFqceKwzqON95oiVgvyq6IcGcniuBlf8rw7Y1dRuktncDnT7wY8G5yxkBXjzILQi/WIdRjn0eoc5ZwQK+x2O+ZYnKnb/JxJc5VXMEW0te7oU10bL29u0bd2hauJZa5IeCSSBEzjmImiKWGqFoNSZCsKelbuqTGOrLF2bNPlgRZ9LV0WQfv74pStvp6QI1nQL1jZqJobWzvFSqFXJG1teGrwXYsSNBlESyu5wn899V9kJR2yIDardGEbSxcOTsHwrzJeC7CkzBcoqbuMki1YShuhlqY1Jzf8NRvebJRZB+2r79s1LlkROf3+PvBXHKtCtz6dYHH6MkzHmcPhgHYdk1XGSsQPcyLkU/oTChq7NIpvIEl9n5i0kPLbfVndw3ZvI9DVIKJo4QlWpCVW0DgTs0E1MzvLfrPha5df4N0vXPKf6T8H4Hvvv+Drj95inmcyidvjS17dvs/t3QS552woNiKMhomZWSBFGDN4ByoTpnKTU0qMaVoydr/K+EQ7Ztdp5lYOPGGHkRLBDs6xcY6XByH7yMubl5itx5gzvK2Cj8Fzezdydxj56Ysf8/7Nc27GA9Zs2LgNu+p0nZ33dBtHjoWHZl1B24wZsL7Mgotdz2bquJ5ekfKRfSgkWYzBOyHVPngW2A6CMCDmMVMMjHHG94KpUZuqYG8NGu8w8h5bf4nxbzGFKzYEzjflPZ+8e4lNiniH2b3g8ONb7l5m9p3ypN4bKxYXM/trPfHBtKZ+ODVhbr0kG3ehtVGywNyMWz1nM3cNDbOr36+/LsVRDqwtMlebepKxgzgkum4mdpmJQK9SxWuXJYbiyPkIOqKi4AdymjGyIeXawSEDaslhAhlLR4DRkl1Eh6LMbJzFZAMhFxHGusGpESSZxRm0oqjW6jURJAU0KjlktCIHmqgK0fUe54kkJRU260kLqFW1lvRZqcy6x4XhY9CxZoT0oeGrKeEHN3p9vrx6xrlGjg/7PmZdpQJW514XcTx8lqleVyKXVFauBjU5NPb0/hGPNmeI7bhxQjQdb4aStjoPHfPnr+m3P+eHH35A/PDA8ZjJHQx7ODur86sX9JXyzgQvjzA+FS5/6+u8NUWeXZdU2We/95jv/yc/5fJ9wyOEuxzpKB0rNvU+twtP6NIOywNei0hoqC6qRznUz90phOwIRPbA14H/zufq8/0P3uHPf/87jMfEeXqfYfM2+xzopw6p+fTOdOhR6RQKGSUzm1Qrjsvny5WeJ77DZIOlOEEljVh6RpbrMsw5QU6IAd9t2YXM2AVibZRrszCLYnRCJDPnHo/Hh4nD1U8I9Z764TEX1iNmQL2l6zt8N1B0AAth37pEjDNHAs5Y1I+YKaFmx5vbJxxduffzz0aS73BDIu57rDkQTClq6OpE+ekPv8/t1z5LmGaQTBgTm5gRHYi9pa/6asQjTmc0bQBLtoqxgsnpFGik1lhNkOxRAkZjRcPaagBhKjc2VU5vCyp0tTZMqtSUGriLqel6xehpLaNFnV2zlmbVuSDnIrLoK2IdYmv3gVxoLqoRMYrIcXHSUi7VwmvV/wdANA3EQ0+ouFNQkXvIOmvHY5XZYLWGVcHWtkJJUtFjzBQ9L8uSdms9NK1WNJ8TAtekGkN1wBq/bN1vee1WrJ2hUJ2/EiSdrt1pCZ5cLEU7x7sJ02VGr2jlKBQQI5S9bIz1fmlrk3rKonCyWzmdUphZWFpSiVZ+mS220nWGrjN4k8g2s2+FXlgusudIIuiWx7t3+Ctf/w2efOZN/uGf/R8B+L99/5/zePMWj7oLDiKM4Y7D8RVMkc9sL5nr5ndrruhT0Rh9dQd3ATa1py6HqmvpwuJEP8
wyfdz4RDtmYrqCqCD4GtVYb3DOwAz5eOD6w0AmMt6d82Kofdq2PTfzxPvPn/P+e8959uqaQ5w4P+94tHlE2lX1SBIX5gwrAn4mZ8F1PcbB2baU1fZ2SzaFP3YMM8e5thHKAadSSMCAyIauG3D0DM7hYyCoAY3kBdaNZAzzNJMNxC30xpEJJDsTq+J93k24t3oG/5h5sOjQoVczhMBce9dtO0cfM+bZzN37ylT8FAZKmvNlDS9cLj8PrFCHagJr8uGU8qEubNvQrgJ1t44ltmZuF36DL7/beGgBs+/BdAYZDK5zeOMocV1iYf/Xr9YYsjiMiagWqDtrxtRpa9QiOSOpMeyKEKadBTmUY9RYtKc0Nq4aH5pqOjpr4bKVR115g4U7YjJoSuSoaFX0JoykEMgxklJCUjEALfJ86NRAjeLXSOUD1PKjKltfG69ZdU4GnQfE24qYrRvd31P11/un+kUjP0h6r1M3dkyYy4jkPUPaYPszzG6L+p5hQW2V0PWcPd7xZrfjn81/yk/Dkc7CbltSCeUDOAKBI4p0YDUzv7xiK8PCHZ1+fc/lec/N/3ni8E/g0QxBFY9wjWLrp3JQpTjKCMBAWjYiKKjYEcOkmQOgRL4M/HsdnP37Az/9730bgJ+99QU4jjySO9hecBOv2Pqz6jjVi88ZMQYjhlw5Tzln5hSRtqFjSrAWFUtp2WaWMEeX+a5SzuWco6MjBKXXiawe2+SlU5EuzikhEkATmiOD8Vx0PZe+BKiX1rP1G6z1lVzfocYyal54e6qKQ+gwBdGrXnjXOx5fnPOyOuG+g3GayRbEZWws91U5zecP3/+QZ3e32DEW0d9alXkQQQfPcVPe9OymVGNLKvSQJu8AhkTbONfwflldBQ1vk/d15KEheA8XU0l1n9K45FTfrSDjzXYVCRYt0VfKVbi6BoCmq88nljSKKa50OQZyyoSYF95xG6IPbcH973/VdfjaZ22fbRVMRc2YxuK/d19fH2vkLLfrqr9YmktVu9GKZpoEUHlNKa5YuKw1KHSU6s6HmRULaIQ0RXSCKevSOSfMsfKfE+TqJLbPdYpt6qPVhbO3TssuKH/dpDop/dK33tMPHXPv6ETZVmHlfZi46oVeO4buXf7yV/8qv/e1b/De2Q959qy0Ufve9+947P4Vjx8/JpvAz69+zvObAxu3xWeDr6n5OcBmOBB0ZgyBu6wcpjJtG73ZzBFs2Qt/1fEwe/Xp+HR8Oj4dn45Px6fj0/Hp+P/S+EQjZv3wGDUbkghdDX2MdzjfgybCOHF9uOPF8cj2bIfvSuSzJ/PB9TVX1zfc3Sb2B7A97DaKweIbyqUWjQabwfsSKSTNzHmi7x8D8Lh/E5k9m3zJq/kOjkfm6Rak6Ay15urOeHrn8aYH9XTOc4yWMdwSa/PrJKlW800oASGTvSvtVEzC18bpZ/OIGEiDkt7oUXvO5jOJLmfsvqYp4rGkU9/e0J1P6M8ycgt5VPbAbhXctVJtW793vC7Z4BCMUYwH6UBcTZdt3amyKeYFFRIL2ZeTWVeOB8gDzC4z+pngKjE3V8mSfAqVhFyr7+qvFLK1pboql8hH6UFD6c2Xq4YRGcKMO9RUJgohoN6D87Wgw56kFGIVOoyVP9UEVLWkME3Op7ZNYULjhKbaTmouEXmLKtfVkGsEDU4RnqwivQec5PsVRat7/3Fjrcf1kX/nQbqT1/XteHAMnNKd6yGNa1bf85gCMWXCIRA6C85zJjuUgU1tfbbtZrJVwmaDy8rxcMtd+Am3OTPtoa85iPE6oMkhBqyJeAcxZo6SyFqraz+MiJs5fndi/KLl8EPD5vsB+7KkLdvnulh91khBdjylV2vNBjICh5xRHG9j+Df7mb/4CP7kf3DOv/h3fp+f1znx6PASMY5DNszHK574zzBpLOh8LcQpAriupMIp8gohKSkl8pJyV3IUNBhUDJ0flhZqURINi0jk0tTeCT2eJBBlA5KYK3qTk6PgXBFNgewz0zThTc
eu79m2Cm/jGJwvpfumvG9Q5XivesQgWto0iRak7jhNhcMTZ876cleHriOMc6EC2IzVQvBTJ8sG8vz99/jTZ+/zuQiTJqZUjpusJZ/1hKGm+Z4fkJRKX1yfQDtEFG2l1lDxcwPEQuAml2RZAiQj2nQMH87jxjE78QnbZxXRmu8v1aeLyrA2QnhAYyi2KCc054KuaeEqtffT2IoJBHIkzTNxngkhkV8H8n71sVpvrQDg40bjhd2rklQhaStq+QjY6qNGNT5rXlmo1yKyKg6gImb1nFH1nt7lorkoJ15qO32516ABpv1E2ETGaDjW+TyHordGfRxLNat8FGJWqRUnOuK6J31B0kyhRiSrOA8X0uPdUPeucrJ9NrwVt3wwnPN7f+G/y1//jd+Fd47c/fwVN88L+V9v4I//+fd5580L3n5n4L3n7xFH4bHfsdt59pVXci5nWKfYTkluRgWmBCGcshbJnZLvv2ou8xPtmJ0NjzBuKAKwVavEGk/X9Vj1jNFxN0cO44y+Cot2ys145C6A88LNATQJHYrGwi/oKuTZG48XizGKMx41ChoIISyEoGF3xmb3hCEEXLxC/HVx9EyHcw5X9c5cg+M14a3HWY+KEkJHSKV6I1eOU8gzTiLJCwFbHDzNDDXl8TwH7sKR62liIjCdKd3Qc765xI6FR2dur4oI6GSIJmMuJ/KNEl9BvgFXU5uaKuyqrPTJBINjpumFFcMmDmwHdhCkL/wFUkRaE8+tkE2pthIyzpQ0Z+ihFpWhgyXZxGwg+kSiaTT5RVPoYYVTSWNmxNoq7uuX35d8gUG1uJcqJQXZJC6o5GCSR12HMRa1rhoAXcjsoorVmoqKiaxFQoBEUSoHiDOkGU0Rsp5k2DhxNWDJJpzI/+s9Yu2grdKap8/6+jxvvLFVNXx5LqvvRU7X8LCMfX3Kdk0fxUFrxzbyfDmfVgJvOWurVJuTY55mwi1MW4PNgpnBmw1DVZ+/TD2TUYINXJ4NPHnznCe3A/H6wDzBs7mmM0YYbEKtYDswSYlJ2Vslp2IoTT/D7cRmP3C7DTz/buBz/yPL2+GCs9sd1z8rfKhXP9hz9aPIfHUyjB3g3QCxzAmH44kVPm8C397A7Rvwt3qw/il/fRt587YEN+LOMV4R1yGhZ/Qjdjaoh5ZsSFo6UxhjKqE9FjmMrLTGoiEWuoLRjowl5RFjPUYyRjLi2hysfW/V4tXRO0M2kSwzOVU7khwZS6wl/8bBPEUwMDjBV+9gIaar4r3HisOIsLWeY7VJU20MnitJ2loHTFhvICfOaoeAwVn2lbfoe1scKynps03N19zc3vAvf/JDnnzmXdI0lUpMGUhuIO2GYjuBoAFbCwOIEXEezbI8qzLfU5mbyhKglslvCverzuiW3jqtAUErWWrdU1IMi0it1LYRokWSIzex3SrdoykvTpnJWgsCqo6hUFLH1PWYI2GeCSGUiv8Hi3ed6l2cqF8QTKnc/zyv/f1BKnRdSZmqHfP1HGUdF4Ozti/3UmR6P3W4XDPr9X9yzJqzuPSdXQWWwul5PKz+FAWbIc3KPCVmUaYmSZWrM5Z18ZNVTrzMh45ZlpPtXL42O2rAWoOxmaEXdpset+kxg8ebouAA0OsZ+8u3+d0v/A5/4zt/he6L8OHNB7y8U67H8oaHAPn5Hf/kX/4JX01PmOLIBVse91vM1nB8WfmqXWBjOmZGNn25T9YLdwclNkkNEYyRep3/f+CYnW/PMOIxzlIpZnTW0/ktOz/w0k3c3dxxNSVSFua5IVOlYe44a32wFskRIhhRfHXyBucZrMcyoVHwXhGbiWoZ5+pMOdiev8EmWrpxhzFbXtKTucV3jq669i7PxJwRCaWBqnWIlNL5WMuxU9W0CDFwGCPbjcHbHqUDUXLtf2KdoJLIcWTeT8wEthePuHxsGKRgBjl3xPFIup1QRmQb6A6Ce5qxV5CvyrJOe+Aa8gx93ZAlK0kC25UTYaWQRj3gjOJ8kQQJ9rTsi0GIJf
Kx1IjfMlHE/ADGnBiVUkmqmSChcGWargdwIuoW589ag9ViYgrhs0oHNO0gW5hFhlgMZsqYtkHJXFZszpioqPVkO4NUkeHFgBfDQMqY2ERwy6ZLNd6kGlXHgMRqtKRWEHEyIvdUy6vl+kiJi5UVXJOEq8TOvUPauIfKfZxjt0Ld7v2tGrO8/qO+TmGD+w5ciVQLSVvqpv4o7LiKz3Bz4KnvYXNJsOfYrWfX1XkqF3QeXLjjaBzdbmDbwxDhkOGqkla8CnlW1CmbzqAzhGPEbjYcpDhThzQzWcjbiHuZkOcd+9xx+Irlq999h99++vVyXa7n7oORH/7JS376ww8xOXGxcxy3HeFn5Vz/5P/yPld7+BML/yvnCSaTp8RfORx5xziuNoWnlfsNczqQJWEkEXWsDlJeCOFWPGK6ogkmRRQ151y5Zg3lmkr7IFXAoDpWTqUvtqbaiCiOZCJWBecEJBPtTIcjpxKMxGQIddKILU5nNjAYy5nr6Vvj5Fw2QycGbxxObCEgG0tugraYQoa3ju12y/52BCMEzWzOOkZbxWo7z3wcwVscJbAh1kCzro3ewE9+8hM+fPsJMc9YzYBj6jxDtyFVO3EkcU7t6EEobWvUF/J+3fpVS8SjpErkL5NVHmiWrZ2YIqdQxL01r+Zz9WAUKc8gF1RdGmLW1lDzQFTRXFrGaVaM5vtNuZfDlKiZlCNJT1WqD69ruQxeX2PrQC2vftdkbNrIq2MXJP7BojWrtS3aCoXut4NaZGX0dM8WP3HlQC7V/Kv3j7LifnE6TwsEF5219TW13+dStEJQYlCyFYJZnVzL0Sq57B3xvkMGp62h8XLL5ylMTalevRWwTrkwhjPnGDYWzpTBKc703E1FyH24+AK/+ZV/k3//27/JW2+d8fJ4wzQ7pinx4lALXibYAlcvb3jvJ5Gnjx0X52ecDR1305FBCsfcbg12EkKInA8BY2acKDHCsWEDqyj846SOHo5PtGN21u+wpkeMK2r7ANbjfU/vB0x8gatCpaVSptyVKZf6P2OEHCCniOQimq2aoLZSEZFSCeOFznn6LteUghArFBtkIpjE+e4SjCWEHSkMzPKKoTdsatrNxEiyV2iOzDmgKTFH5TgHpqlsZHMcUQIhBWYT2UTFpVwqa4xBaiozmpFoZrJkQs7ghGEj9BtlqKRk7w1zdlxtAvHoyWrpzgz20Ux6a8S8qsb5ZWZ6H/ydATwmhUKoNzDVxeM78INjc+4Zzgc2Fx3deY8fDNIJtNRCOjDPR6b5yGEqKsdjSDCV6hyA0cCUYa4E6WwyxghpJfAitGqboiVU9Gia0corZK1uEurJqsRqcFWF0CriQsKZqaTibHEYs0mINRixNIVwyYUMLFnRXInGKZNSXNo2kQoxWKuOU6rpg1bR+hCiahpCTbPsowidD0nA7fu16Gwztuv3WOsTpQeLXVfv/zA9uj7/R6U1753KnPYukVULF+DOZkK3YTSWD6Zrnn34PQ5B8ecXbC/Kmd8833GWd+wSRJuRbsD2G9QcUC0OWnlTS9IIubRBu52UfRDUJQ5TXbPBkI4WFxK9gWDL89KQObwauaY6Uxth86bny+884uy34Ppq5O42EEzHO18p5i7eWP7JP0tsM8gc2HUDmcTd/pbs4bIvxT+H5Ojslmwto00QZ2IqNAlNod4vjzWpaBBWlrNqJKaKrNcHYgSQhBJAAoiUFkjWlwANSM6RZl+dOkOWjMfTpw6tEiTJWWakKLJbEKPYzuO7jm7w+MYwNlIkMMTSaZHhyZoJKFPTcxKDF8/GemY7lRRYErJR5pwWGR9vHbuhx/seFcPhEJfJFesz7Dq4eXXFz25ecOdGhiwcFea+vO7oqiMbZzY5Izpj4lRRKClVy23+5kyrcM25oNlKEdZtiFWbq6XdUl3DNcWoFYE5rQetjpvW9zJojvW1Ld+Ul/WfGxSeYlnbKzhPtbTLKtphSsrFkRO9T72AU+yTHqzxXzRaRfc9h2n14pZ2XK
NcwikNGBWkthizvAbivXZ9LbhMK6ezIWZrxyxxv9fsOnhs42H6tWV8pT6P2ha5VLtXz6M9Ak0lJ6k530fKFo91hZpRY+1WXVGvy1rwXrG2I1lBXeTcZTZiCN0lbzwqbdS++dbv8O98498gvznxs/GW7aHDRsvd/gD72kbNwxjhbB+x+w5z0WM3G9RautDjfXHMJjXIbJjJnDmQfEeyE1uTaGhRiBlMrhJSHxUGvz4+0Y7Zpt8W7Rhrls3H+FJ957xwtunobi3xrogdtgeY6wbg/WnFzAFu7hLX+wOPqzjpuZxhnJaeWZ3Be0sIBucyvqZrJCsp3BHdGeKEze6MSGaO0PdCX9NuNkOWjjDvOYZr5jgzRyWEQKzWLcVI0qk4ZjpyGDObTuiTQ8RhXYuGEzEdSCRUEn7rGc483sxL0+duKDm2FCZGMsdOmLaOnXUMuuFwWXKZ+SyQbESvM6RpWQiug21f9Y4Gjzsf8OcbuvMeu+vw275sBEYIFcm7CYbr25HbO9CxVHU+FsccdUGdOjzMiuqWEDviXBeaPZEFtHgC9edSC1q4ZYoxhtyiD8CIA+2RXHgvplZvNscspQzzjFcFJ8SYi+6ZcWDyCZlSSn/O6twJQk4RU9OW9QGRciqbohaF/1gNZXO+2nV93PK755y1VIDed5rWm1TjXCxGsBqiX1bJua7+XI5t8391nvbtcvp64bp6rS7vWdoEAfTdBhvhxY+v+d7La549skhIxK4nVV7lm48y/tEFb+3OeXN4RGmR2JENHCP01ZIfU6QWzJJG5eWdstsd6edxSYHNJPbmiOsFMQ7NQnfuStWu9RhbBUhlJB4sU7TMN2eEg0PTkeFW6J5eAvDbX8r87I//jCQdgRmNE76Hq1eBOG3pPlNSeLs5kthynSL9cSRI29Ahtn6aGks6MnqcKXNLiISciE21PkewDotBTXHM1BakylvFVQOesiVSUpWaSgASNKBphq6kdFMcmPOEmhERw5FEdqakI1detamN0EtVsMWKxeQqhVNRz851DL5QLkKYSyrVGJIpPLHG3fHW8Giz47wfyJ3h/RH243SaS9R02qz8/PlL9o8CKkJ2jrDtCB3cpAIfdPMduyiY3GN1rmtAKFj8CR4RjSUTnCgNxiUVlFFZGo6L0fvVmJVj1tT4gQU510ZcojRG1+bs1YCwNSGR5pxV5Aw9CdoWVL04bglFxVQuIeWZ6um92jpqFdvL+tL73ysf/fP6WZbG5VJT5SeUaj2WNmpAzlolbhYQsJyHcoDKihdLkbhowV1cn6eeu4mTtzgqcrJLRsA3Zwles3mG0sZPcw0uM+T5tBebXN4zt5RDuyi9f+0LYlY91wJ+KrKkCMFZoXegRKwfOO937PrHaP+E84sv8xuf+W0A/q3P/jrTG8LL/YFu6pjtzPXLK+ZooCtBWZQ9muCwj4wBJHWkpBys4dzuiKmh3B5rM7br6clMRnEquDwhpqx/25Az+IWp6vX4RDtmvnLBELsImJaSbsG5hpAKzkvtnVWOMbWkNh5hzmWP1gy3d8rV7cj1/gaA880FvsucW0/MR8JRCGTUhUXHiJTxaSZMB/BCt+nZqMMcBjyyqIibPjDkJxzVMM8jpQ9cKpy4aigNlpQMMSTCcSKMGdkNJZWoedGQ2rgeg5CYUR+RwSJ9sWK+nivHG/ZT5NntnlfjSIog2mHdBvUbnn6+TMD95UR8dMfxdmQMER2VLnVc+gtsbe9iOw+DI20cc2+hc5ihQ/qeaz2UFAhw+3LP4WYubSkc5B5GU9KL7fkkcWh06LRlvnXsj4pPmW5IJ5JtruXepog8CpbSM0GBdELMTEkHG/VIFESmQhReIXA5Vf0ejYjOqGgNkiNZCn+vjdaapkSbBkkJzWmRPcg1gm8B9dqZgpPzZDmhWAuc/zFDH/xr6cNfiHj/iuF344qtXya/4O8PR+J+O5hMabMDcJwtOpxj9IjXmTNzjgxCtgJVNHXzShlf7vnRW0f03Y4v6j
nzxQU/efwhw08jY3vcHdjU0m+ecJh59uE1my1cVoH9ZIokjmbFi2HohG0vdLLhkA6MocznnT0DyVgz0/cRP0cwHdHCvn6Sb3zrTd79oz/j//XTmc25oFHpA1ztA/N0zdvbdwCYjUOPAUkQthvCIZFzWBBcKGnzOSVUPL4TOieoiYjkZXONKZUNypbtLulMxhVHA7esDWdL70ptUKUqXfIk15OrsxtcTxcdIZY0vqQAqVyXuoCtCsY9tqREk5TiCFVsFrxaumq6NJc5PaVI1zvyVfnDFALZCrEixUPXc3G+41yUjOdmH8lXRR+rSW+kAH1nuL45oD1M1mCNw2569h3YSv1IcWSawIcNtnb7ULWVNF7pFSnhyIhK5Y4WF0OkIbarWZtPztlHzmXJxZGrRPKiJ9gETPM91Odhf+HSfSQsCvSqiaS5SEVgUEm1ZVM+pTNXozk0H7Vc9SO+zyfzd/8jCCfeWPv3IDhb+zVphbqtnb304Ofm4GVz0khcEDNZXYfe55g1SZ7WnUq5X4ywvKw5U7K+9y1tvLIq927A6UPdQ87q70yljqyvte0bnTP0neHMwlk/4OyG2V3ymae/xm+88dt87bNfAODF+YFwiAx5R/CBeb4pfbVzR1fX2Xky3KbMMcGz45E3wo55PzKbnikXzjjAoAasJXuHoaOGDwTNi3RV6qV0leD+M/tF4xPtmGGUzl4UAnZNP/ZG6c0tXTew7S0XG8/LceQQwPQ1ZRAKb0Bd5VXZmrJJws0L5eVZyTN39s8J+ZxDPmPoFGvA2D0bH9FQlnRMRyYmzHiLBkuQkZAF486xzmPbHZaAM4muO+PMKOP05yS9QbpuYcbbFHiVA8c5kOJMyI45zTibSu89V3apUa9wCCkaXGc595F+nDC9ZdaiwxKnyNXtzHir6J0hHJT9ICSdkD5yMCXfrr3j4q2nPHnH4lC8L2Jk2Vi6XCZpjoqzHt95MIloDAfreTkHbl7dcZhq+6PtQHdxgWTHnDPRZLLLhATxUMOGOBNQ3DiyCzN3k7JRy3B2RjyWZ2hNh+oRq4UPoiZjzIwVQ85mMcC9VZSepDPiE5htiRZTWNIKWTOzFmPsw4RjwvgtYg0YXbSHWkrEkDEi5PlY2m/llXp3FjQpsRJWm6Fb64hBMUCdFgPZiuBMPaAde09mSE7pDqE4Q+uIWICwjnKlRbr1ffU+d6EZ5HaORcxW76cgWmS8fm15XdXyTsV+Zi3RqXMwVmPTuZGtfBF/ZrD8lE1/RpwtypGjFDTl0HdIdHCMmKAMTy54u8u88+o5f6ovsJUrSNWDsh56PyOUgEkP7Sqh84WsLRnEBnYbT0+iT5l86EiuXFfQEecNxjpsGPAp4Z3DG4v0Zf08fiJ88XMX/LMf3CBWOSbY9DDfwp/+6Ibv/Na3AHh2c43tz+nlgMaJ88FwPGbiyrpmhBBBdUYRchYGZ8v8qzfW9r4grSS6DkSOHCNYK1jXLTZCUCylS4WKIAi+s0xRWLoCWYfmorU3E5FOkdvAZBLP0pFHvtZ/5cxlzswo4fYlcfcmB7chjkeGviCC8ejJfcfjfuDlB88JZBJHTBQ+vH4PXwsO3ukcm87QGwe64Xwz4UUxGa7rJBsEbg8HhmDx4wXbNMPmDS78JbnruAsl2J38LWd5y3A4YLcbiILmvoiEtr7ClPlGnnAmVUJ+QaggLxpyEou3tfSy1BKcqgrUat7i7BVvwCzd23OhFqQVUb3OwbJGCx0iIUQ1C/ctk9HWiklKpejSLB1HqsF6NiXp3+gL+SNsBCunrdGsbLMZa6eIkzPVOKL2gQhtCVcr8ZwSOLaeuHCyA5aT3VFWdIi1g7dKGa6PTeb+caaiZcu16uk62nVJ9RaNFF0xQnHmp237I0tHmnLtxUlrtjCvylMlgcVydUhkU9arFVCr1G2KTU5casbvduTNjuH8c3zh6bf41ju/zptP3uBZKFmiu6vA0+GMpAHNI/sRxjQAGVc5rc/J7HzhLE6HPc
f5kptZ4GbP5LZsN2WdnWHxVui7jtE4nBq2aSJ5QzorczAE4eZuqvfUnj7ULxifbMdMBGcjvXcLt8IEg/Udxve4rkOcoEbqBlOmjNqSWQsKtYl8mUyqTCO8etGqERMhJDaHPednhWvlvHLIE7lGpt1wie8GBlWSCEGEFAVHjzFbvBQHyPuBabpiHzNTmJglcX08chtfLhGAhMjN9R1p3mPckZgsKW3QLIVL0ngaUkrbe29JYukciCZimhA9FTgcjxPjODPPSohCmorK+GwsRqo3rxknHd4aNoNns+1Lmsp3JaJl5UwYs0Qs8zEwzjA6g70oBQfdbsPZ7hFDf07fbxCnqORSPr8vE/7F8/eYj3vuNvDsmeKC5+6Q2emMtc2YanGcpHBkSh+ClupYuyzFSbPWFgkNpbRMIWNqj1LNuTgwofDmMmBlglQa2y4Kk7D0vyubQoHGsuopaku1ioj7RnV5/cP5+Uuioxb1LudqqcUVkmUpBkrWCF314D7u9A9//9BB+6hj5cHP916vULea5STpjQO67ejDO/QcGM2IP+uYg8Xopl57xhpbDKwx+N6xyZ7trsf0kF+VmRUFjIXsBTsofigcj9L+9nRFxtR706aDL3woKq+w3LuWeirOS9cNCJ6+t/ihrMXduefXf+Ob/F//6J9yvJ2WjcEr/OmfPSPXatFtv2UKASdd6fQxKx5DTHFJpxstacOUlJgSKuBQ1CysZjRXwVKni0MdUuAQRzQ4ugXtL8IQYoqMhRchWkfvO8Yqg2GMWf5hC7dlnjM2HrmZb3jPlMAsEJimI1k7xCV8EOz2CcZ6pLWvMw6LJ2Vh2xv8PjCrwUgHeVrazg3uDS56iBvBT47u3BJ7CBFqZxpcKPYxHCOME7O1XPZbdn0H1nFX07ouBO504kx6tgSMGjTPmCyYthZ1nUIUSrfWcjsL4tIcsWo3RapTVh02Ke2qqL9ZF+OUKsvCEmMVdOVcn5NSm1IXGYe0qiQo6vfFEVz2EpFCKVlFQqXn5mntfNQ23DKrDelqqcOHqMqaa9pGaXy2/kwV4eLkGIX1OerXSA3uHtqOFhGujm+/anYp51Mqs73W1h+UkyMhNVXajlPuUz1q2+LT/ckLOFxkM+SU8lxTNpKBJLk4eXJK3Xp7em9rHX3/hM355/jS4y/x1Xd/jc9+5kv48y0v5j0aqzPld4iz2AwxZKbjhBiL+I5m4IaSEy72eU4c9jNHn5A+MTi7PBjXOQTBOkcv1HZeOyLKWDWinA3ENBdpkHX58S8Yn2jHLKUD2YHwZGkEroxYq4jLDNawM56NQL+Uu5TRUlJqih0wpk6qBPsCmGGMkmVi2B+YZqEbFNN53DGQ7IcAWNch5oytRLIIE8UB9Dg28ZKc3wBgo4/QpExj5Or2huurK17dXXMXXy0aX5J6YrCIejpbnIwcPTl50iwLpF6cMsdm6IgInU2gkRBPiuTH5DnsJ1IEzab0BhszFstgTeFmUYya6QzOG3xnsUYxkrEmEysvRFWqI2MQ5wgx8vL2jpu7IxpmXG3G7sViu46z7WVJAztDzonRveJYicRiHtPlnvM3DHaTEeMxdESNdLY6gqFstGWBm5NuEFqN8MmsiFikvq7oF3W1SuhU4ZViImWYtKSwpRF+VStnobh4qolGODZqaf7ZYkRS5YzoyQiWa6A4hSsjes9hq1auyWjAffJuMidnr/FZG6oVAfLrC3VtWFvKZP2eH7f82wbQAuCH6Y7TUeU3qb532TBOXQz8+ZbdG+D3Hh0fcxuvGXqltxYzlat3NmCk8P1mzaXqsvdc7gZMf9IVy1qi30GEwZY0uPFK4+BRrzEnQY0iVkmSSaIlQEr5pKWXQTFI5SY6Y0kRBj9gWu/D7Pj6N77AZ975Ht8/THhTELozA3/y/Wccbgvi1++2zPOI0a7Mp6wMrmdKEFvVr/EY7whGSRrJORFr82SRlporokYpJUIuCHAWQZMhhyO5OkqIRxGskdJ0ORs62zHb8j
mgrFeRqryec0F1ojIycTXdNGUH9mnkDEMMllEitr/h858deGN3UbxgwBiHF4f4joM61G2x/cTkZ/Z2x2Db5mM5vzxHusd0Q+Lx3UjnHdnGZeJMAtnBWXCE8ci87dhtNuz6jmyF6yq944hcMHJOxy6PdAgue0TcAhNp5XmhNRhos7W2XZJ8f7a2UY4qO7os3liBXZbm5y0SSrlQFypKd49HlnNJWebSZzEuucx84pTVVdSqstdFCG1drS+zcbDy6ufmyy0BWvvdAwRbqG3d2jpde00UuxObDarHmNU52nU0LTKFe/18185g+zk9/N3KmUt6QgJVavC4CirXxfUqJ2ROaiSr6XQfFnSuGaWazvRS1npLsTZOr7EgHURbEEankCulqTt/l7ff/Qa//vnf5Cuf/TJvXb4B1rDPI8YY/FAQc6MOMEXuJATmOZaAxfc0x8xoTW0rEJTDGJj3mRgy0vkSGFGuofdS6FNG6VXIdPg8sKnp2t6bUqgUxyVg+GXjF9FfPh2fjk/Hp+PT8en4dHw6Ph3/LY5PNGKW8wFVZWJHkOI1O1OQn84Jxjt8V0i34pTYCOFaYV1bUpkpgtGqHK9FwwTg5qZEwbsBYlLsAN0QMU4X7qLRDzDhks1uROyWTEfUjBPDIe6ZasXBdprwRKbpyPXVHT//4Iq7cM2or9AaWXfuEVYHnPF0ncOIRcVUAT6pKb2Sr3fO4b2DmLC1hjoxESqksT+UKiknrlThZCWEjJPE1EPfytz7jt73dCIlKtfi3YcQiBVV05xLVRelCukwjRzGufQDNCyq3ZGIGrCDoR88nXFYLVFOrjy0x488nTlwngKDTEw3RzR44iwLQK+znqQhTElFi5FKtIXGrlBArMGKI4tUnbdNhcO1Pp+iwt70H4MBUsFUyIpdYjdZolzNkDQtMPuaH9KQqSin6PYhWgX3gr/ltQ32hxpBt3OsOGutWOAeGseKdwH3KrbW46Q5dhqy/O90LQ9f8zCIawmedRBbInpFF2HlS7q3e/yHAc0Dh+mOQzjw5HwofD9AjAWFkCJzSESUs87z9PyS7dbxsiZHQgJnCq/NoQxOML0gVqma0MRUFMedKfyaKURUJ4Yh4qxZIlhjHMY4ZOEr2YrGGPpq7gTH0zd2fPHzj/hXP7jGVrDGCLz68MD3flzSgX/521/g1pbOH46ObDK9iWR7ysWI8YjxdE4IaSbGmZAihliq6CioLTmU4pVoyB6suoq2GEyV1HFai12kfPViSWLqum/afgY1Qk7KHBIaC9l4TpEgiVe3lcuVZzoM+zHxcn+E/hWyfZszv8NUqDLnXHhWxrDpPBvpmUNGukRvHF0t9LiLA9fuKzx+4136rWU7WZz/CV18wVTvgyu9qCEqMSeiWDa+ZzCGIIlbKXawE+VaJi504iIHnFjQngL1NUQTyBHVtCBmTc1fsmKaOOmDCX1CrbRolgFZta5sauqyVlvWxuf3KikRsgoZKWlM5B6XM6eqUr8YhPJerXL4VDm4WueNs8UJiWL1fWaFHskJOb/3uarhaL+27Zg1spZPqFbm/sa+TmW2va8d1xCyBfGqnzfJCUlbH1s+YEX3pVnN0/WdEvj1UGFJSUrLUq1QtXZvW7aqoJTl2bam8u1WGwOTg95CF0pB0OGNN/nS578EwO9//Tf4y9/8Bm8/eYfufAADIQSG3HPWbYq0EjClCVEIOXA3TiiFCuVtd0LVq12Gcl3HEJknJUtm0GlpJK2qqOvZOIurDN4kjoGhqLEDSQ05w3SYHtydjx+faMdMw5HJT6Q54OwTAC410YnDy5bOO3pr6UzhHTyET7PWikwocgyaWemlEgPs9yWFFbPgRmVznnFO6LpygwcTGPQZMQvD8ARkR0wQiEzmwGyKl7d3N5B7oh64Oh55vg9cH18S5XbJ3W93hje7hLGlY0DRPJpQDmA6zFrryxQxvZwFjQY1qToTZWJNkxCDYLMt+l86M88Bg7JLmVY5Z6XDqS1E72wxrk
OMK5tB7TIdM6gzhBwZw8x8nNGY6LJlJGFpDcOlOBhWSB0Yb7EGrLf0rkzSresZ1HERJi61Z/5JIB07RpRdE8qRwqUpq7rqllmQVHgoJ4hdiwaUMRg8qCligxGyqanMIgJHJteSbNBY0tSW05wwcuIwtUXZhP/XrY8CxWg1Z6zxw5rTBicDtThlcso+tGrN0mHh9NpKp3mND9bSoGuj19IQetpvThVTp0s9FRh8hC1oTle7vntjlSo+8UVqWqFuQm6I6HZgdzHwOCfO7nqurq95mSKxlv1t1WGS5UhidxgZ5yOXmx0X/QWPt1ueU5yIOcFsYUTZGhis0pnS+P567ZgGSFHRWKo4sy+pCOsdtlVoLx+hpJdSUowYwhjoavDmiwAY3/7GF/iH//hHmOpnqQU9Kv/VP/8RAL//na/gpUd9wkcPPnIQS298rbAENRZrHGIcyThmDHfpSIqBVFNgpgZGc5oIKvgsdGKgt5iUCKk4LWJcqTRXQxFTLjxIa92KemArMb18NollU96HwK4DV2Vp7m5uEWt4NSbee35HMoHPfuGOrz5S+lYYI01qKbPzlmuZIQude0zffxbZvgVAePs75C99hjfe+Sxvv3lBN1zyX/74TxmfvSoPAvC5+mWp6gH6gcu+iH9HZsY6QyeT2GjgiUYmTWwoDljOAdemnTbB6DL5C39sHV00Z+r0vI0WKYtlAS9Vf+019btUvAPR+85OEQWuacy29jNF/7J6D0nLXlBkNVjSpaKUamRO97Vd32InuL8Ul1RisyWrtfyw4rI899OLP6p7QHPqlsMWG3k/GHz4b7mO5ijJyb4tjt5HBJ7L51slm6FVZ64uMNd75lc2Sk+6aa2CNEspDBDKvcxZ0aQn+o6AkeL85BGOw4ZHn/kCv/u17/BXf+23AfjON77Eo7fOuIl7Yp7L/u49nTvDm44x1M4fWemt4+o4czPNOD/QW2U2DuNqVxlbA+ZUHNlxnpg6RyDhnTLZsq8fJZCs4HpH35uyP7kenKera3EOyvEQii3/5bz/co2/2mH/vzmOx1fMmsnuhq5+4q05owe8sXjr6MWWn7WRicu96SiG5N4CSMVRq3S1UoabhP2+8A0GpbQr6WE8lJMdnTIOezbziPUJbw0hZmbNpDxxV0rLmOSamLYcjtfsx8DddOTqZl+ckMYxs3vy0BVSOlR+QyakSGe7xYNvyJkxGWsdmhMxKGpz6R4ATIFCdsylHURzOo7HyP4YuOwbr8WQ1ZJyJmWLmo6UC2+iaYEp5edxHBnHGTB439F5jxNhsy2TebPpceIL8zQLThydMYw6crEthHBvEmey5SwktjbjvtQTdEJ+umWq5fmdtQgGtUU6oPSrM0jO94zOehhjwIHIrqA6VfTW5AwmI2Yka/E1E2U/Wa8Tiy5iiFCOW6LJ1fs0Q7VEtw/nECenaHHWtFYQrdApXb3uowx3kwRZo1fNULbzNifvhGi1Fz+4Rx+DmDVD/lqngtVnOJ2iiFY2aQeTC6Pn/FHPED196Nluz9gfr7k51K4Y3hdY2nk0ZZIEzMZzeXHBu48e87NNqZK6HQt5T2O5KJvNwgla7q0KMUOq4r7WglXDHAJJ9d5HK1yihGpiHKdSTW065sq/8nPmaBNf+/LneevpOc8+uKUzwiRKp/BP/8WPAbg5jFUtvqA0qpneOjyOUJHDcg8tznqyWLyWaskphUX5P6kpPM8EKUVSEqIeyWrIOEwt0S89cQFMRUAKcmaMWcReBVuq0JMs6I0KTDNMErC5VlzORSvtMCfmWIsN5hlM4a6V+5TAgneOUTqC2eK6GTl/B//u7/L4m78HwDtf+xbfuAy843e89egJX/v2jv/iT/8f/KN/8cdoLbYOARgg54gwsOsGzp0H0zGmcRHkzq7j2uy5MpE3JLAxHq8BViiTcdXWqaw0LPLidDU+mFkXfNzzUlYdPRYcvnQSOPHMyiRv4rGaS4uprEKk/MuUyswFoa82IdWgalkeytJveF
k8nByapifW1imrn5tTtgR1FTlH7p3qXgs3W39e2yVTfx+ExcFdO2Tta1qdN1W7rtwP8jIndK19rodIWKsttO2zrO9Huw31XGJWAV59jOuK8kRxgqwCRhYB3VWhLjkJc0ocEHYXb/Klr3yLP/jO7/Lvfe03+NpnC49bHzn2ZiaETEoW5wYQixgFq7glKi4yDNM8E7PhfDhnsAem44m/nLTootkK6MQ5cAyZEGeEhK2os2MmGKHfOrrOYY3QlY0IVyvYRT1WQbL8yv1UP9GO2Ycv32OzteAtT7oqeNQPiDPYwWE7j3iD6Sx5zqT6YJIoqXrEkqq3b+IpGqoQltZa4pjhmCAFUM2EDYvitSi4fmK3maGf8TZhvYVo2ccjUyiedbATNkxM+zv2h4k5HEr15hK1Qp4y4zjR95YQFBtLC5Z5Nuw6e4+wn1KpLBJRSIYQE8kkYm3dMh4zVhxqPDCXKEuEeVL2h8T1UNvmdD3YxMZ6IpakhTjrjJBaY3hK9BhyAmuwfgNicd2Wc7+h78qm0blSIYkXkihBMxbBuI5aoY8h0BmHtwa6Cf+5DXEc8bfj0gHBGw9VYLM0uivpKKmVpA3nar3HpH42K7YgGHRIKvNBU8BaT3IJcliUs4VifJoxyfVZ8sBAtmPh5Mi1v91zBh783NKYH5XifDjaudpXI5wseCUXL3IbcrqGxWfh5CB+VLT9cCzGvr3mY67r/jVquxwAnN2gTvAmkN3IISRct2WjE+GuHBvGREoz3oPMkXG+I7nH9L3njYsdu7OK2t4mUtRSEp/B0qOSSTLVrspALNFziCXoKPpziWEOzDGsrq8gTWKUrvO8SgdMsvT9gJquHmMZk+Xs7Iy33nzC+89u8c4SU2Tw8OLPrwH4pz/4Gd/+3Lsc5oRqxBjHzgxkI0vv17lu6AbBGIvtHL0EcopMrSVTVnI2pGzJqaB8RiMhGzbisb5Wbrui5eaywRi7SoOZWmYPWUz5lwuKv6AauUjk5NLIs6SAtbSiGdwEQZljkZGR2oYj5+KIOC0t7TrZckywvfwKv/kbf4Vf++ZvAvD222/yuRCZTaTL53RPvsS3Pv9V/nOrbFuqWcu1iwhWd+z8hp3zBBxTuFkEOWfvwSeuzcwVI2fGsZUeSWbRyBNNIJ4sitXaxHwVJMoJ611V4a6cNj2hNoYT6X9N8G9OXquSy7UCO2kRj81iiOSCTtbdtAlLrxeeaQuppuDgfsC1OGWc1hqs0KiVU7ZOFy6/a8anfl06gayCPGpAmbQ6SvIRr6ch60KkpL+Tlt63uf78cBSdshKQNT+5PJ8T4h45US9OtvQUTGahVJW3oLeiUHrvPQpaVqpgtTi+9YCWJIpiyNbjn7zFb337O/yHf+H3+L1vfB0+u+WZL0ZiEzOSDY4N2Ix1js57LELOcaE7uN6WClyE8+0jdv0ZTkaCjNweC5ASA3iR0k9VgJiIKsQR7nSmJsIwBqJx7M437Pod3iiYHsMp5Z5zJszxXtHaLxufaMfs+YuXXAbwOyHGy/LL4RHoGdgtnevYuo6d6/BEXJswqUwEK0ByYCLG1gVV20JA8QmMQNSeECZSrSbxCcYqVpsF7JA4297g7AYjG8R2BDJjCoyxHGgQtlaxfUecb4l2QroBTaFsQOWtmcJE6gwxjYTsCGFgnmdi9CgnCeEQIzHOiLpSLZiUaFg4OWGObMwGMZYcIiEkRBw5BeZJeTUWVMNtekxt+cRcZEB6K7jeMzSIwVhC1fzxw0DnN6Ce3faSXd9hbIPyEuIdris9/+44MqmrJdTlwna9pXeOnQouGeSpQ24E3svoy9YsOJNJhTtkPVkSiiI2sZAboBgkY2rKxyBKQQztCZKW6DG2bKpZwkciYOX5nHR+1imItD6ovnadvnzonJXr5155+VK12c63Gk3IsWmTaT1Ry+quJTWX89f3z6f9abm2h6jYWqz2oVlYI3jtnqw1zrS+1xK96wnJG33AMeA08WRnsEzMsXCurK1IUc
q1khJymJnTgeQyvvc8utjy9kVBUV+8mrkdS5ulZCyKQ2LE6MlAxUTr0lOcsymTgN08Mc7T0j0j5kSMCecM1glJI0okZcOxWnl3mEi9YXAD77z1hH/+/R+hSUkBbA/htpzrj7/3I37n3XfJo2BcwJsi/ZBVSNI4XxmNgZRzbZtmcOpKenPhqxSOY84QY2mcbSWSjUfmyNAMOHqSZMBgjZSWP1oaEgGILRWnDcWeqzPkpba8kRZ8GqwKG98TNx06Fb6NqmLbBuW6kibFwmTQ9Jy9RN58+jn+4le/yre2RbR3RuHyEZ3JEAdwPW9evM2UlW2dMGkwpJDZaMamno31DMaR6DhW3h3AhCV3mTsJ3OaRI55Lu8PZjNSuz1lD5RWYEjnX2bikp1vMoom1TEpBf01ByRbUSCk8jdPrF0ct5wUxi1oqMWNWck3Htn6YJ8dJF7vQ1lSzD/4XpKhOn+Dk3Dw8/B7NZvX79c+L/Wjmb+WssQrK1nZpjehnCk8zUBExhNicrtW1LkGfloBsnfpsF2I5BbONHWioQWWbg/W1qR4nDfTMekrdyklKo7VLtgLGlv65y7MeerYXT/ir3/nv8x/81m/x7W+9y3GXmMIr7Fxs/dzvAENSQ+8ELxNoImiHqqer6gGkSPJFONl1l/TWMKXAft5ze9gvNyJrxmTIFR7MgNbWyUu1qIxEY7m83XHcOrKTQp0RQ6z75+F44G5/S5jzUjX9y8Yn2jG7ygGZHG8OnJr8yo5+W5qU3zqD63zhOSUl1KhtskqIJToXVzYAnevkquKD0DZMwZrppBCdhLjXpT3kLYBY/HDNdnPJ1t+Q1HA0kbs0EesOe9ZPpdhgc8kb0xP8ZNhvRo75QK6RPDkjyXCYSu+9PhtCvGMWW1o/LJt0pM/CsxBRucFbwzTN+HDSr+nEE7JyMx4JWSEl5pAYttDJjG336zgRbcdkE1m0CPYN7h7hmJyYQ2QYBoZug9CxO3tEvzmnxywtXoByDhIpTcxEcvZswkR/Xt7vERYvSuxLU+zLx+dc3xrkzcjhp8VB3aSIR5DsahStpd2SNYW30xxnSjpLnCwolnNCEk/MFUH1CZMDbs5LCXgzkmlF3FicED1JSbyWIuSUomzE//b3xtWAU5TcnK6GZK2h/gUQa30CH6Tj2sYiy/9OaZQlEq3HtBYpsb5nIx8vqv2r1zeDuna61n0816kWW++p0Rqk1FQiwHaeiHLgGMHIhuFyw/Off8B5N+CHEiRN0/v4NFcE6wybB2ye0F7xmy1P33oTgM+9nPmeHJgBTZa7NLPJihlZSLYYMCKlEXIskXw4Ki+uLLdvhNo0uxB6B3o0CFM6suk8Nx/ecXP2iF3lHe3pIQau3YGvf+Hz/Of/+J/Rd6W/6mgyzX7/q3/2iuu/NtF1PcFGsvYIlr7vCNXJk5zIFGdRs+Bw7IwnWcdc27EFE9BcnLmslqiWOXgG0zEER5qqc+AVsRaDxYupyvSluKXNh5AS0RTHS/yWM2fY+xs2nec4D3S+rEU/RZLZcLbZchmFH20TO7V45zC2Sdd4ZDSkNGFsxG0u4ep9xj7zlUePoKJvXXcOsyduFGcDyBWXl+/QuyeEcFWuSzJnBsIMzkSe7i5R12G44SZOTFNZ23dOMabjpZ35bEzscyJeNi3Gel2HSOwtRuei/k+hNFidKb0yy3NMAqIVWVdYOoaUBpvlhiUoLd2AZDFJUYoAcNTafxNQzcScibkEgYhBTSl2yU3GRw2ZiVy7Okg+rZuZ1zfTxQ60NSenYK3ZoiXYMg3RWq1NTo4cnNb7omlWHaCynvVki6RQd5bXrs5xH7Q5BX0LbUHv27Rmx0wVtW3v117nRPDVgSs9Su/zzVjJUWnj8K6OUSm/TxR0TCq44JPixLKvQf+j7QX/9re/y3/4O/8G73ztETfulrt9QsxA16hH05HimCWMejR1kKBzFucMJ8JaKfQ7375J1znSfM14Bdd3L7irncdFKcQ2KfQCpaBosd
7DVlySomJikV7Zj3vUb+h9YtKRfSxz/uXdK168CtzdGubpV8tlfqIds6gTU8hM4wlNsi5izYB3A96X6sbedzg/4uZTCqxtWHk9kUzhrlT+fEV6Stqzid5pBA3FqSvXAJoTj3rLz3mP+IbS9z3jOHJ7iCRfELNgHGc2crGduXh8gZgdcgSXDFnKwo/hjl3vETuSUiZESgoklki6sy1KL7o/qhBCJsZMDJAD3LU2ELMSZuVwjOyPmZTA2VJ9apxDa3QaQ2CeR8R4wGKdYo0yY4rQKgBKjpHO91hr6fzApt/ijadzjs4P9bosUWfUK2oTsyhZlN4ZdkvbDMFacFaxJpN6YGPQs4FgKxlcFCcN464kaCyh1hO1PnitubG2nk9Sv6pdyOBiHRhPch4xcykCWZye+xhSSxEsaQc9VV+2+QAnA9pSlS1CXUe9S0TY5paejPR6KLroCy1goJ4MuOMUKa+oMQtqx+r9Xzs5qwMeHHsvKl85eWb1tyby2D6zZcX5mGGeA7fHmWe3N7y8fUViQozBt4rIfgtJCSkgdqbzJTVljWO72fDW+QUA85NbXtnIbZiZ08xOQHoYPbiGTKfivFopDmNOpQggx8wclFBD2FbhKjkTYulZmZxwM+9JNYKNJpMnsJsNl29c8OTNS8LPXpVq6wxSffqf/+wDvv+DZzze9dylyK43XA4XgOBW+lcOJZvWZjsV4n9K5NroPFehSsEi4ouKfSWaa87LflH5BgsqmXMmW0jeoFXrbFDHI+1BdmAn8pAgd6TgGbOlq0ix7YRoOrwd2F32vCGZzbBDrcVUs2+yBWsZXE8fO/LUs2HH4Yc/4AN9yecffaZclyRCf8ttSjySLSYNfGZ4l8vPPeHVy5flmInSBs8r1pc2c66K4uaQmWuaEuu4lsylNVxPR948O2ccMuf97hRVxoyqwUjZyUuRDyCnlCZw0iajoJLUlH/p+pGX51O/oanNr4VpW+u2st5lWQGqpVgo1ecAVPT15IC0KzF1mazFYB+uMVk5Zcv7te8b0vUgCFwf22wT8Lr9WgWIwgohv2/eyq2FBTELnALQh4VLzZ60tKLRVfFcveZSIV+rWTnZjoUXK/WzSS2aakFlXjl3q0A3KPS57K2zgHrDeV+CvG9+9gv81te+Am97XuqeNCWc3eJyUSaAwtXOpmR/kmbE2DoPDc6deizHGMmS6awraU4p+14IgcO+of0n7ndzJhs6m1vapH5JMhPTzDQpViNhhpnMPn4AwI9/fscHzyvKv54Ev2B8wh2zQBbPOGei1nSgHQl5g6pUx6yjH4qyvZFTlZTUBLnYumG2XdPcXwC6+n2DXFFDrjnmNGfyCD9PSs4zpJf0vSeZTMi6pPlsNIgxWJMYtgnhHPFbYg7ldcA4eugnOitFviJBj2dwPYhlrhZcs0GzYJIlTJFYydNThJtDrX46ZKaxIHtzKJ/BW4fU+7JEPzkzz3O5TutwEZIREpHjVDkfRmvxQGmIPPiBTbfBmQ5jLZu+dTfoKZIKgWhmIjNRMltrOauM1MGCdRlnE1DSBzKAXAzMFTgMJpFz6TOarSnQucjCJTsx1FtaojQREjmlNbVuCpI6cAPOjSQ3kTUs/MDXxsqxaTIR64qpFqmuDRncd5Ae/j21S3547Or1S2VUPUA5LcxUX+t4cH5YpVhOjmIz+OvUyZrcvzb8ud5K8xF/b+haa8YMze8tDsKYjty8esXN3YEPXzxnPB7IeWZ0Sq7Vj1vniN3Afj5gO1sCJDVFAsYEzi8fA/DWGyOfMYn57nlJfyRDDgbrM7aun1xhQWugM8oclBxB50zOsiB5rcm0amAMM2OamFGOxzuCVo9r2NBl5S7s6fpHfPaNJ/zwvRtMJbQ1/tjxduSP/+lP+e5vv8MHr/ZcnnuGp+cMfeGwUZ+Lq/dpzpmgGVFTBYvrU0+FqyTiECsYJxgTSkoxC6aJHGepm2B5mccgaolq6Cn39FBbs3WuY8tA7ma87Ar5Phg6Xx
ZRnz1z9njT0xnH1idEbM0U1FSm8bRetPMwwe2Rvp949cF/wz/8l/87/id/qUKHuuH67gXbeI45fxe84YuPdpztHnNdMwfdFu4y7FLGWsu53+LEFX5WLMK6UJTWbwfLMzPxjut4JTOf2ZX1KnNZmBoDOlWuIEUyQWubozUXsolNy8JqWs3ifHK+RDOtRVtxqpSosfDJ2hpSKX10tfTQzBVBSzW1XF5s7onBighKSbd5TutkuT5Oa7Ch5osPvjr8oWD1mie6Lgp4aGde4yY8eN+HozUvb30vldP3a7mMFvi1jEBD/O69bbUzJU1bXLGkp7UA1I4fFegonf7q/roSAK423YrS2bLOrYfcebzZ8a1HnwXg60/e4lpu+bMXf8ZlvGQ3nHO2sXhrWfpWU1Om4kp635e/I1qf4WnPs77DdqY8x0mJeebq6iX7u3KetUAvFMcshQipOqRrm5kT83jN5M6x6Yyor9inA8+urgB4+QEcbkCzWwqCftn4ZZzfT8en49Px6fh0fDo+HZ+OT8d/S+MTjZiFCLJNTDkwTo1An0BS7QlnsU7obKkYbJ1PDCV9nFboWIsS1kMrNiu5ePtmOfKU5ksKc4T3x1LlMac7tjtLv7F0G+jMOQAb1zH05xgxGJsYthNizpjjgFAqQQYr3KRXWGMQtBDgnUOsLZVfp5wHphQ/ojExTQUbngNc1XZS820hA5cgpUSeRgtB2ZqErRG4UUofrwQmCTFmki0pyLlGsM6A7WoJulH84NkOm5IScUo/FIhhM2wKyZ7ErBNBD6hJdDayq7piG1+qz3op6QI1Cl5JW4dc1ibt712zUUNnDNkaslmhZA9GI/JKVowpaQaMlPQmoLbD2A7rSuscm0MhdObXn/eakOs4pSRbdNTQr3UacY12LfIVq+i4pSheSzOuIuCoLD3sjN6vFm2ctxZRrwVt13yVFoHfq9biwff1a+Oosbo+fXAMtMj3JASgylKhdnv1gpc58GI/M04TBoMTVzgUFZjaSSA6S5AdWTqM2dL1TxHn2ZzvObwod/zJ08d8WQNdnpjyTOyVJDPnzpFq+snFUk1tReitgCvE91RFh9s1GlGMZlKOpBSYQmCMSjgmrqteWAodF85zG/c8HpTPn5/z/cEhU0CAChQzOPizHz7jd/7CY2I6cDcduJn3mGSoMoaYrFUiwhZiei3EURVME4VVg+RUeDq2KzB9AmddkcJoDFaVWlKfwSa8GExWHLZWoYLrO9ymx0+e3nfELOX8LhN9aXkF4FWR5OisozM9Q5cwSUkhkltLJinEpqQw58DsRvZ+gv0N//U//D/wpV1Bwt/8zG/w5HiJPXsEB6CfkXxk6D3TyqbmJfWjXAxbnHNESRiEXIslpjlxFHhO4EMXeNvMHF1gvjQMoaLc+wkd0wn51XSSmVlDuI3Ej64W8/1koBjQmuYuFY15KQAoZP5qx9EFGVNVYs4llZxP6dNE6ZFZWLTl/RvVYL3Em414SEFY81CbyCuc1vQCsD5AsNc2iNXrHo62/j/2Zz2dL1FpOHJ/72soWkPkW9Vke3273nW6dn09DZUDClpmalZKV2hhXvFpayZGTaEr+B3kJFzYDV986/N87fHbAIzzkT/68Q/42v4p/ssbdp0vRTgurWp0wTsP6nBWcM5hTKE9xFhknsr7Z3ox2NoqKVoYpxuev3q5tGNso0poghY+WUHkWASsY015j4eRPT1xPjLmF1wdZ95/v5xjfyjyS4q8vg98zPhEO2ZzgmQnksKrQzG6U4qoy7jKcXDO0XnL4BzOFNzdF4H7hY8Cq0msp82scJfqQoxa0koCJdFXj6m/yzO890yZLLxjMqqJTiy7R2WXOj8/57zrcQLez9AdwQh23mAqUX1A6CZhH665TjdYDQxq8cnTTZ6uGmdrFOeEwRmMgRQLBLw/nvp86rE4n1kKr0tQUlac1Zq2LedKhZqBtYFgDZZIZxzRKP9v9v472JYsO+/EftulOea6d5+reuW62q
HRDq4JxwaGGAEcmglxOAwaiQRDCpLCNKkgGYrAkAFqaIKEACk0lAmRCimCHAUIzYgjcECCIAlHAiDh0d2o7q7uri5fz7/r7zGZuZ3+2Dvz5L2vGoYTE1KHKiPeO/eckyft3ivX+ta3vmXotZMy2VFEpJFMZhPmWzU6GrxqmNa9Y6ZQuiAGxcoKlA9E5SiUYiI2jlmlApVU2chpfA1NFZhc3U738OVjnPMEAmRnKwiyPpnYlKWTYfH+XgxekNwYJOHTg0gWSG0QvkXZMKToxuZ8bGR6su7YeUsPic37cVojGdq0Uy/jIE7LaHsybgzlYKTiZj8RBsHFsdHuOWeDcRfJqF5wrtj8PdyzfnxeOuY+ZTrm1F0q8LxgbIPIgUyE2A/8o452dUJQEi8lyhSJu9F1qDy2OgmFLpmZitrMiXqGLHappGbfdElMECjUOdvbEl9JDrsTSr8iyBZbWMp8r9vMoYwu4rQkeImTnigkkk3QImRIlAUhMSpJCFtvaQOsVil4aw9bzqShKGG6Y3j62lVE8UVEmwWK84UtdOTg4Iz7D06Y75csuhXnfoHsIjo7GjIkkVchDS4IXOg5b4EwDIBE4FdKpVR/mnAbLtmIVmCDRQaN9gIlJK0QNBoo0sWqWkOQhlCUhLqmbQ3Sd6BaygrKXLlQiABWUGhDFQ1FJVEyEJ3H9xWeQmBiJEhBTZUqC6Mmoji8fchP/tyPA/D0ew/50P5H2NMVM22Y6B0WtCwP7tKrFPkGqMBLT7tYYYRMjllo8d4SXLK9jU1C141w3KfheeU5D2vcltxENocgFp7YRWJUOV0pEPn7Qa0/JAkCEXtJmSx3Opq0A48vBGS2DWmeSwIen2diyLwyF5NQdyRVhrvMK0tzIuaASwyE+4GmMJowg0aX2Mwzz0XHbKwnNjhMY0dpFOT12/hyy9t9N7ivl8EGkcRdfU+H4KKdGv895rRe3MjGmez3I8VGA2zQLNPpORvlhjbRd2AZjl1kWpGC2Rwshm054/n9W2xf2eMwpirJs8US3dWcyjnnVzr2tgXEPJLzPFKqxOgaos/8xtT1JRWHJIkpACkVxhhwPvEGo+dsdcrh0dGgamBGBvECTXG4Nvlc0yOKZg0qdixXJ6wDnC0Vi2XmcXfkZ5b77fplX9mOWfAZaZCwahP5v+k6fBHRmdynpaBQGqPlIDD7OCqwebhFRo5ZvisCCSrxrKQiqX6HzSCUUqB0jpbPBd08IusCY2bUVbrDu/WUWTlFSZs8+NghzYJCFKiQULWkIl7RRYtvGtZ2zflZw5aZYIXH9BG4jCgt0FIghUgETZc0yjJdLfFWBAPK1SsCRp0KCcbKNdGD0yBUQEaLKypcYEi0R+FzRWDShprODLOZpsLgCBRFMg1F0aGVwUVFEVM7KaEVhYoU+eJPCkGhwciAjAIXBLYShNKir0wBWMwK1p1j4nJ5uxS5qlEyirU2IqyZVyQBqWVy5nqijhQIqZGqBKkRUiHlxmsaE1rHg6NvgdQTa/uvemenN1ieDU/D5q24uDF6w+9643rJSRocsVGkKRiR/+PGceqN6LhYADYNjOG3o0f2No6c2Bxrj9j1wpGJx/e44xctSKWoiopSBaQQyJi4O9HluTibY6Jkbkq2JlNKPUWKEuUNld0imFTs8UY8Qc8iN7f22WlmtO0x5yKyDueDyJ8pIr5L+9cEnJZIm8aAChtDKVIJGVqoIWIWyrGIniqkIpvVesGj1lFPSnbnW+xe2WWnrDg8taC6gYMlpGe1WvH6W2d8ZH8fHxoae0ZoumH8qJBI9IWqsEFgI7gQaZ3F2g35XymVkGsRMUrgskClD4EuI3nWW1zwmOCIUeFzOYHOjl06Jg26QKgSpVqMNkQrKHRCmHohWo2kVJlX2iqqUlDk5ssbhz4hkEiFlBEjYBoFKx1ZBc+d27cBWK87Dp5+mScPblBOt4h1zSuf+WVWj16nzgPHykwUl7BaLli1DVf0DNc41q7B9WKbQuCtY4
3nQHoOhGNZChbKUdXpHPUkIMqI7yISlaZkJvwKIQYyPiEOVct9ZfPl9mKiR2fCRgIjxpjQr7Cx/yHGVCjifVL/z6+91hsk+xezUfC5ej+ILIwqHp8jQxu1PoATDAKjfeAHI37p6Df9xgZ0O27W75e3I/f3vxtv8/KP+kbmsb+sl9bt34+dr3TWo+saRx9eOg7fFwhJCIbUUUOnf4PSUV7HKBJBzwhqVbLLhFs7T7F75QqPmmOOV6cAKGeIznDn+JTZoxOuX73O3rxEeY/SKTowukxzTCmUEpm0H+kFh8XgwAmUlsQQkcHhQ8NidcpivRjOJUaGQrEY+9qyBHDEAD4bXy0MBEtwka7rsFHiBLQthAwhGwKo1Knjy8Kdl5avaMcsRjhfgZzB0qd04KprkvGXAa1TlGq0pNCpPdD4twPBOfYkRpJKeF+VmQeS1gGtwBTJNdBGbB5QWZW0bVKJsooRlmDnhqANfdhsOkcxNSgp6YIgOE+IDUKeI3NZvRKGSKDyEyaNw51b/BmEqSQUkja3PxEhQX6DzEJMaIKz2UgASibSa3ropg890HYCokRkFd1kzCTaghAB7yNl6VBGDwRILZMkhSkU9cQwqQumpWIiJDGWCJMfitJCXCFjgZKRQgiiimjjMUXaWFVopAgpkvICCDgTsAX4aUIFwnaFPV+l8uk8oSL5oTvC5qNg4yBni6d6GYPBAghidsiULPCyyLL/YVPcwcgxGpxyHjfy+d+FVCUbB21Mxu9TFYJNtVL/EOnXCf04vBQlv93k7e2fZ1TdxGb9HhkbHhSjJY6vWYwXUhBjjTXyMfROpLj0MBACZHawnUhwvipUcn5DKr4wuhzQAEOB1BFRaabVNrtii3INZ92SNw8e8PP3PwNAd3bMh3Zv8MzsOlW5w+FS8qo4Zx3P6ZabsRtlMpAyJOesMCm1Z3SNzJXNCJ3QUSIqKLSBspLE445e1aWcKGJraZeWh+sls+3Iszdv8uhghVYQcpVXQjIFtw/WPHuwwpcG25zhfEOTESBCUtI3UuNjcswQGufcUJCgJCgkRkmEkNkOaXxux9S6ZLsab9CuyPQFnWgPRIwPQ8GKkTVVJQguEjuIxYpgA6LsCHIzgIUWaCXRFNR1DTEiqwKhFaJH+xAJTY4C66FTBRZNYQOddFifHdnzFnf/TfT6GBMUenuKOnnIrN7jUB2ly+BJLXUUBNtxcnbMLXYJ0WHxg0aZEIKoJFYLToLnruhYTQoW0bJlMkI/V4gjixUeRVJpFzIh5QqJzZVXfQGTEiLL6OT0ZAjD/Bf0/XUhRkEIPr+Cj2FUaRcuOGQ+5AArio0IbU6bJscud/N4m6BlGK95DA1OVrxI/h8jZJdpEhfI+HFkG0af/VbL+DfjpadE+HiRHjEct8j0itH+xq9DJWU+Dsemetwxolho8ErgVMRrQAmCSqhjLz+lpEBrkIWidBXvuvEURT3nweqY08UxesjaSJz3nC5OOT47Zrlc0G6VaFOjssyKQiCiTcFGTofFkEhBQoihSlgphfcWLQuUBhvXrNdrVs16g3Za0OVItkiMbOfoMVQUIaU7FUShgQLEkig8IrtXQQsQIdnFMST5myxf0Y5ZYaBtoC2hdQnyXLXn+KCQKrXxMSpFzkapTaPjGIbiviGFFZNR0VpgqnTVtQmgYG7AFJJKC2LwKBU3UUQe4G1ftZlrp91C4FYroks8jcatqe0KU1S51N7TujUhnoBOpSAz8yS6U6mJuZmjZINtO9pVxNdQ922UehchDzzvIfpIsOmhBSBJjofSOQrzaax2Ls18rTZKdyJGnPPJSPhI03WoQqFz9G2MpqwrqrqmqgqqQlIoTy0NUpUEmbk70dP6hkjMDxVPxKOVR+RtSRURPuBxhGBx3rHG0UlFkx3PMClwpkkdGGJEsXGEA2w0G7wfIuDetQghJCOezUhfPSWVSa2upM6IGqnFyshojmF7wSbdcEFTjFFEGS8a3gsVV6OVe8QrjreR3w
xcDjaGVPTnSTZyIiFnPVJ3OQXRb3dIW+bXt6PlifH6+TguBL3jL/O1EzmEl1IgTQoiSuPT2DKBRnmaLnVdkCPoTy1XxLnGG01d7bATZsSTNS8dvsFPf/HXePXNRMJ4cnfOdHeLWm4xK2YsmyVlI6mBs2zlW591B2MGgCVUhWEyq6irLbRJYrVS9b1eE59Sa0lVaXZ1QdOn8DDMG1i2gcZZhIg8/exTfOrl20S7HlBH60EbxcOTFYcHJdWNOa5dE7xlnRGgmPPsMkps8Fk+J/W2LE1ybApTYpRGiUQ9kMgBnbHBIbKTV9iGwlUYXyCdQyhw9CKn6di3zIwres4iGA5aSdu0SKMRWuO8JQ6wq0RJhYxJsFc7hw2BLvqhs4k0GlkYjCpAJ0HYRQQZPB0tiyaNwu3ZdeaT68yNZFJIbLXD9vaEydYpRw+yYyYTIuICiNBxcHzIan2Txjap8jx7ltYmCRCpFafLllMdWZUSrwQ2p5r0dpI3IjtBQghEjqLSmN3AGpvgNA7BWQo289+IzfeDrSAFbKEPtQBk6vyS7WRqdp4qa/uuEj6mHpuXG5P3gdiF+TZynPq52oMB5Pf9rvtVw8gmjAPEy1Xg/TYuO17j6mqXf3whsBrZmQwkDaDEsE02gZ/PNuRygNqfe28vkzZiWkmKjWOWnLSUNpWFRBiBzFxjnQnfhZGYUmJKxRO7TxMLzVFzQrtaY2ygzc+NqCI7UWB9h22WHBw8pKg814qbVDkok96xNStxztF1NgMQqd+sMQUmR2YRj5RgRIltPYvFKYcnxywWiwHRVHJjo2NGy3peuZS5NSMgtUdLKNQOgYBQa5z3KVDRLh9X2qbOgUvLb718ZTtmErzWyNahc8qw6zqcX6PKOYE1TkusWCPEEtULk8YkCOgi0KUBaAooSijrSJERoEJCVSsqBciIkgJZGIT0makAycxHdJEmtaCHqxesFpqjkwyflhonYXtaYd2K8+aMEC0qQuuSYybqu0Szj3aRWjqOpWMVA46UH/c5/RHFEi86OgXBBlSAhRPYLg6TQvaieD10HrLRAZbrjmkGGLSyGCMRoqB1BV7KpAJeFWz1OflCI4H5RLNfGGZNC6ZkOdNcK2foLIWxXjtKwNqWA+DctdRY1BSmORVTiog1bVIR9yEjSx7PCZ3eSdd9tkVrH3JewBRPFTQRTyv8oGcEyeiGrLgegiMiCc6jhB7KsY1KfUCtdWm2CYlSBuXsEAFBb5wkQ7sn0sQchIW5CGQN+pVxEyX2jpJllAIQPem4Hy35fshNBD1GyvqUZe829xpJfe+65AeIoex/fHC9kbxsSMdHPnDnsifXow5DuxUBPiT6VxCgs5iZlwbhLcUsDZxuUiArWLYt1jqMtPhCUKkSmfUmZtLgjOTqXk2tj1guU+PqX3r1N3jp5UdMsvnpFivuHp1gKKFYcPfkDkuOWOkBcKZbge9AVaCsRHQBObPslYrtMjLJPMeZMBhVEnSHFB6lJbqdIgpBm1H1fTOBaYW4us2VcpdYbHHrXZIbTz3DvVccxqT56AMUQdC5jnvnsL9zhtm9AquAW6eZdubPmNQl2sKyWXIWV4Q4ZV7to2RKsSiZggKNSh0mtIHoabRk1bWELETpWkOrFhRSoH3HEpl69Xmw+aEyn+3w0bhLa+f82p7g9utvIZo1XbQsg6W3AI1T7Mor6UEQW6ZC4x2c+pYrMXX9iJ1j1ULpYauc06hzptUezp6jvGOq9gCYzmfsbhmuz7dwsUIXASaap65c5c3bdwAo1o4upmbmXkeOl6ecrxYE4ym85ERlrUMsJgjOMFB1PNBLrK6YmavIfLO9PUGWBehTRBMQocaJgNQ+NQ/PuTKtYmpK7mzS2cpE8jQHeq8nccYQqX2dJ+KDH1A1P5JilyrddI9NYrPEtF+/mb2OjQ5iGM3tjdVIKbANoiWG9x0bOxDZpC2HdKoYTNsF5+syOtZ3J5GPzfONLekzQBfWERfTkjneSs5tTwsZnZ
8crTfmofYoeyQ5YjImm6NJwa7P3FGX09vBJKBgqaAVAuUjPiMIalqiPVyv3k+xVXD38BG2bTFBEIXG23RcMihcWaB14KA5QT+4Q1Ua5pNdJkWiwJRaEUjtliRzUAu8niG1zq0BM/e1KZhOJUqeYpcrls0pi/WKrmHwmlLArRB4ipika9I986gIhU7H7wJEJXHRosoCRY3wgrr2CO2Ha+1cBoL+/0HHTCsISiAbOGr6iqvAWnQ0oUVFmMcZtZgxCSuqTJ5fF2vI7V10lbZT1VCaxLGt9Gb7pYkUIqNHGqQMqQ1Qn+6kb3Cu8NJjpMBFSeM83dJxdpCQvEJKfNOxXBuECQS3xro1xih0TvNFG3GhQMgKpURS6dZuaJkUBu5LasPUWUdrI00LbZty333l6bhZqswpIC8Z2mL0AsQiY+jRB1zsEEpQO4XwIZXCQerL2SlEo8FKnAuEbs0EhQ8LihydahVYhIZFXOOdJLaRlXdUlSbkwgsvZWrTQ0DGDidaFqsFS+voulMAJjIidUhNn4VKj5o+7LiwbHgDfboz4rNTJIZj997jM3IgZEEUDVG4pI2Ur0OKZsOGa5Vf336v+RqP/4mRMe0dsd5wj5C3AeWKF430ZePb779vWdIbcB8his1xD8sosh0M5+VId7TvmP8Ql77sAbPesKcRHgnRURhQk2R1dVScHh1y4D0rA3FaImLS9+pTqY0p2a8LnqwL5o3gxKz4tw+/yKdeeou6EVid9nC6Ctw5fEjlUyPg1eqcZXS4aSpogVRx3ASYthnxrlJvRoBCldRZ5Lg0NUYaLBEpCpQskEJTV46+8V4lpwhZszN7gv1qn6KesVdVfPj9HXfu3CO2yTGLAVzoUF3Jet0hnCHaDhlN6vUKzHWN7Vq6NmCtxXYN0Re40OBVbuYei6TdZoqUZzapPZzxkUIEXD6uVbPGKkGUhmAkQcpUVdhLogOV0Uymc4yCyfEjjmXJOi6xNmDXlqLvFSMLTmvHtAhcsWUSm9UTqlAScrhR6pJ5WXNPeZwSOClpYscqWrSBnWlCR3WpqKRiWhpsMHQmwnTK1evXqOqEVDbqnDJTP4KH9WLJYrGArdx7N0uiaZMct2lQnHuJVhVzWVPWU0zuhVvZJVaUSfRaAiGkueMS505lIxd8S1+ZOQ4+4jDA+wAr5vTkhkPWV2X2aUofE1HchzA02u5/03Paxih1T/rvEXPHJrU6ODskJGnghPL4nCR/NvTWFZeKhi6jbFykRPS/HzYbN8hdj8AN7aniRfRtjLz7PkCDDbUiXkTqxvsYc1kVG5vRieSQQRrqQkBroJEpoBQy2RyTEafCKa7OnqWaKR4eHmDbFh9Tf9IYPSoj9JUyFELgxRp32vKgDRTljMnuPjtZqFpKB6GldWdEeQXlJEZ3zI1EYTht0rO41gUyTvDWsVqec362YtUcsziD8QMgxoyVChAyDN1bhNrYdKMT+qdMbsNGok6le5Kvab1GWg/WD9fmt1q+oh2zSsEs2uHGAxz7gGgFVaVoC8eydphaIyYBlw3l1KYBpCvYqsEYqFIbPEolhsomJQOlSjRoqTJvS4JUAjmUbQh8TjqLjEjpGCmVwvuIXaWbc/pgzfl6jVkoTBUplEMoR5QBR9/o3EE4RcmIVIaiUKAVPrpUNdQLUSKJDjob6YKgaSNds0nzpIPPRyeSTRciNWUNMZXyN9mAtx4a6wm+QWrFpNYUcspMFVS5BYJv4ZyGg7MT9q5OmamKSYyYJiCq7gJPa9EsOT57xCrAykqKKFmva1qVoxoHDkuIS5AtK2s4Ook0rkV2fbgikaGlcIqSCkVKHwMXSJyDwQyBKGIuhEg3oU9bCyGJQiBQRKGJwoDQRGGJ4qJ3M44EvdigSW+XEozxYnXk+NGQfp/+6JG3/vPxteqjYT/67WUx297QCrgg+tg7UP3x9che7E+CUUTfgwd9dH8hd3n5xDYvKaWaK9FihFrisuDr8njB2vskDiz7+5Gdw4x8TH
TB1nxKYRQKy+3lAZ985WWWp6m/re/SzixwsFpwc9JgRIX0kbaVdIIB8m1DDix8juQrqFFURYFWApOJoZUyGGXwsUXGPA40zGcTjM1RupPEyYzdnSe4Wd9CCMGV+Zyv/8Ccn/jlT+HWDwHQssCLDrm0LJYtkhnapW4TIiNThYDgEw1AC4kOksN2idBT6oHgn8SRtYxIoRBeo3SEaBEisPZp3K+9xYkWE1tKY5ClwWuSM5IHjouBRkOsDWZRsg4T6vWabad5hKDJecpCVwRSOVwx28YGiRIVWhcDchCjSylYKTEBCiSlUKnISESqIh3/JFhm1hFbS1VO0KagKwKTnS10WWzGbdgQv9t1x3q9pp7VmbfY87Q8hRQ4JQhRcoUZu2rCRBdDz1B8mYp09BIX18goIaqB0jCkI3vl/9yWaUh3+pHgbNw4Yv3vQtxw0foncQipkjoEiFHkf9l5k5v508+hfl73qNHYIRojVZ7klHWXnJw+DTgEdmM7Ejcomr+Uxuy3eWEZpUR7H1XJkXM4mv+9vQGGCs3+OPptD0g+Fx3RIXswshHE5ER0/e8jg31uiLS52LZRpIIsBUJDkTMo+/VT7My3OV4fc75YARK0pBMCiaDOSHGpJFJAtIa2c8hKcHJyxPHDA86u7AJQLFfMpWAt59TBMpluUZsCayMrZcjay3hTpqyZXdN1HZ1tePTofnLM+vPN56FlROsUULicBlF9ASAkHqepMbpCm0nKhAiLlAnYgbSu1Rat1nSXnjlfbvmKdsy0htpBqwRVNvLx7BTvW87xTMyMq+qAcyUwomaqEiJjJwk52pOCiUk8EqUEWpL4aDm801IglADtkDqtI3EIkfg2AMiIjBEpUqugEByBpFovo8Jl0uvZylN4RdUJ1oVDzgKzmaHwKR0JIEqNcGuCSFyosvB4IXL1aarEBAg2ywVYgesUnXV4nxyvHjZxOslkJG5L4ttJKTcVjJk/0jlPa8F2nqqOlEZTG0FVCOYkZ8pKi/UtJycn3H5oKKcTyqJmHjSuWXOe8VnnFb4Bt/Y8WJ1wtDhntyxR5U1muYefCJoQ1zjf0Vg4aTwPz9aIVUedh+N55ygClN7iVIEjIqRKzW8vlRwngnaqnI1REHOKYuBrRehZ40FIIioRnjMCMY5gx/IYMXtVMY6QqPF+2RivoQn5yECOOwb0Bu2xNkpxY3j7zy+0+Mv7HPtRl+U78hBMKtsxPcT7B4eIG57ccFw87tiNl+Ghc8l5EyJFiudZjyUIN7RAsfkBWWrT19ABMNtSzGpB2QaOOeUX33qJB/c7dktJ4wMmw/pepLT30jXsFBpNZN1IVsERe/0rF9AaKqVS82oBRquUJqLBh5SmFLJEaQXO4+gIOKJSFFWFqRKqZheeNRFbVmzNrjIVBbJUfPi5PW7dfILXHn4RgFIGug609JwcrtD6KSol8RJi7q3ibaAUilAIOu/Q0tBEMM6yjumYXKiJYYKME0wUSNGxlgVBydy+LPOvWoenQ8QVupCUfkooBEoLVEbouug5lY6iNpTTLa7Pr3FjVbMnJa/4Mx6sTwDQaEpZs1VdQVcTpkoTRI2uocgpgeA6bAMzPWXpHyG1whhFYQxOgM7OrhKSqAoqUzOrtmkrwwJFvdVRTVJq+0yRq9XSOPFtx7ptKWI18LYgoV5lUByJDq3gKVFwQ2kK52CR04qdBgdKlgTREL0FKbKD5nG51ZUSyfamuXSxzdUFyQy/QS+GuRDjBe2y1MCcjNbk9zFlQy60Vxsh4xfsw2gyybiZtz47ZWOqA2wQ8LHqf++QPTZn3y4w5HHbdfmYHou9Ljl+IV87P/p9/93YVuSfbuzZ6Pz6b7NQSerZmx3ZRiXAxEawGoIUWCGQOnClvgbAzvYVztwJhyenSZUgSLwXKG2S3JUqhr00scM6R60NJQHbLXl4fIfdu2leF/oWZmeXWgSUrCjEBK8rvDihaM6RxSQf/wrrDYvFgtY2LNtj3rz7cEg3Qs6mye
SEFTq3akz4SyoqzOcoVYVWU3QxoSjrpN6gDdYHbPZJjIZONRgBDR0bosqXX36bwNo7yzvLO8s7yzvLO8s7yzvLO8v/0MtXNGIWAVeAa6DJru4iBFyQzFtHMIqmiJhKMJUGUycEyKgl2wIqH5nPtnJEmsppjcnVeyQIUkqJEiGXvEcEKQUxYNUya+tEi9CBGBwiZI0cLwkxIzM+4lcaFwTMNHqrwGiNFg6RK5ZCKxAhEMSaoARapygxCplLdPuYSwKa4DVdC64FYurBN4Q4BsrasDWZUMisoRRS9WXXdcgcUa6CJchAUCFpvemEEEgcTRZFiz6hIcZqlocd9+bn6KpGTEomncZ2SbfKdRY6x7preev+XV6++xb705KGU4R4Ju2v2EZEsD6yXDU8PG+4d3rKpI0YkSDpJoIoS1rhccHTLjIrRia0q688RUrEiEwXRZYciWGoTkvNiANJrTsSlED4AqE8MgSE72VDgLBJSX45Pkh/9WED+Xs26v0wio5zNNxXS/bb7gevz7fLfZlUKYzSqCM0r09tDr07h/820bsc/fYykvd2y7iKNLLRMvMkmaEk/RZp23S91Fbaz9omDask8Jlgwzrz0Ka1w3enWLPFi/cP+fzrBxQeyjrgGzAZAbYxsuoiD8+O2JGRznY4LzlvUrFIf3yFh7IomUwkrV2gKo3RgqA8bZZQcFgqXYELOOdwwRNReO+ZZMQsNh3RR9rlivW+4lq9SxCOeRF4z80rvPyZrKdFoAVkCSeHK04Xlqef3OVscYrOiLlUCp3TNVVV4SKY1SLJWdiEvHsniMohdJtyItJgCLhMCndZgqTtHF3rsAEQhl01AS+JQiNlkfdXEIoSLQzzcptZUWH2JVcnW0i3z87pcVpPatayZlbs8Z55yXZZgJqwQCAzp7XDJNFN21BomduECKRQRB8wfbPz0uAnNeX8CjEYRDBUs23KdsVkkjhmQaW0bo/YhM5yvlgxbbdwHtTQFkBjY6qge1Jp3i0MT0SJblZwkhBGTiKhseiQZqKXIovEOqQf8Up7LDkjtknwuU9vbsZNooEkKQznfeqR2SNled64mNKdPs9jF+JGFHY0Z3rawSDMnNGlMQYixGhuxowYcRHBGuR2xCadOeadXpinlxC0x7ij4vHPxpqI480MyduY6BqBeEHqx+XfBsnQbLxP275dsUEkpSx7ObJOw7ofXzJi4ybNuQgpTbmlrrN/NT0PFv6Iw7MT2ugQVCgEOioQCiM3aFkXPTZ4RBsQ21PwEtFEDg8XHG4lTujTT2pc1GyZinq+g6xhcf4adbhGVUmiTQhw6yJWr1g0Dctuyb3Dl3nwwA6VqJDsnZAxPQ+NQKtU6R8lg4YogMo+Q1nWmKpOPUGdQ3uBySRYa0ukMglKFmt+O4jZV7RjJgCLRIdImfHThTjjqGvYrffYcpLV6XW28IjJHYp1uurvL67QSEsoYaucUhQaYxRlWaJ1seH5qNQKJQSGVi/EkCsAs9YREKXAtysQEe9TtZ/1jqbzg9CpVwG3AOEFV8sdnprsM6s0wqyJJnHfrD+DmLhnupCpkjAWKJVSpDbzbUTUCCRGFOioMcJhRTZI+dhnU8X21oS96Tzxb6wnukjbuCQ7kJ1Bb/zgVGiRqme8t4RgsZkH0HYdXoDammHrkpOuQ52vaFXBfjll0jsAruWsOebO8R3u3b/Pm2+dcEfB+fqM00VKgd3cvUmhkqDuqjnnwfGa4zPFDpHtMjuehUTVGmcCK++YakX0muhz5WXPoRqNhUAcWuD4KFK7GUhE+R6aFwlOD1KDUEQxaoA7smzJgboouQGPFwSMlf/jaL3eARsqHfN/F1IP+Z8fpRcut1Pp99VXYQ3pTi6u21eG9amTISWbt9kXCvTqACFuqqnSeWxUu/vqq14ktz/nmK9xzBbDxsQncjIZcRcC0XuiEkwmOWWwPGURBPfFGb929wS90GxJRyOgUGowbtp6nI8cdyuWrkbGpKzetgyUgXoWKS
wU2rC/vYXtSryB+awkKrD56tvgCVHhvMBZ8Nbjuoi0lhCSE2F1zUTNKZaeg/Mjrsy32Q6arjnlw889wb/I/MSAR6sCLztkgDfeeMTXfegGKshh/ktVkCRhAqqq8MJgzvK1zwUvWgeU0iihkCKgxIrCqXRfBEM1ovQR14V004vUtaQqJhhjmJg0ZyozRemCQpVsF1OeuH6N9dEjfFWyPdmGWbr2ZhXpqJhv3eB9Oy2zSlCVU05Dwd02HdfZeaQ2mrZI6Rkhc5PlEMDJgcMpZWRlHQsfmYiYAlcdETpQVmaYCzJRwSDTM1arBmtT8/SYqx9diKytZRoCX+0K3hsLDBGWS3iYnEoOBUIE4mqN71JaKgqFdy0hN4kf5lDYyGSEsRxGXkJ22kIUF75PTthGxyyGvhX6RvHfES5ISQxzNr8fgqC4qabuF0GWkYibrh0DRYCNLeipDL0dGI57nFp8G4fogu0bfX/ZRl2O+cZOXO+UeTbp00G6J14878DFCvX+8/H2fT4An6v0G5PSmTam4p311DMv9rm18wyNSMH8o8Uhrgl4o5A2pQGVUihhIERsXzWrACWZ6ZI2WlywbHeeGEoeLU4BeHBwyhP7T6O2Je72l2jbFnX0Mkdhm+mHvovtKj1nWxx+uSAEz9qe8dbdN2lbgfdxqAIvct2JMQKjRZLI6K+xhDJTc5SBohToQlIUBUiN0mCy/QGIcokXAo3ARwWMyGxfZvmKdsxiBaILuBK6Jgu5KsUHFrfQb2zzdX/0XZj/26f49Kuf5cqzU66t0pVa+IqtwmBp2Z1NKEtDYRR1XWNMOUwQpVSOIlP7lBg9LoZExg99NU/Ax8gqmvRZ6/DeJ69ZtvguHZcLFqaSybTi6vYut+ZX2KpKQtHS5PL8kzb17EJEpAalIyI6hHF4IQkh97eUEa01k2LGtHCsC08TbRI9zNH37rzmytaMeT1J/fxEwMuY+lsiib5vYRVolEWqLJzoI9ZanO+odap2aaPntF1ho2BXG6y1nB+viWuJrdds5caBPq45XJzx+sMH3Ds54WxB4pyt1pwvXgfg4fUHbNW7aDnHR8nZmcPbgraoOBOpamZSVPjSUxaaphEoKSh90p4JQgxEYiEV0XuQcmScXY5oRzBRumpEMvqoCoK2CO8JudozBDc4KAFwIm4M0yX0agQuDUKNGxrxxV5z42rJC2NXbIi2vcPUt2waE3R7Dy+OjkfEDbGXfAx97zpx4ZwvGvXQ/270EAiQpAZGxyYGpzTt1+cNO0mGRaAjsnCCtYRGRLzwxFKgS43OpPHluqFRildOTjg7hnlhaAO0HcTgNw+fLiEXCxNYRccsJnkW20LMXSXmESYVzKZT9sot9GzOWrTUtUQY8DGPZ9fSdGs62+BtS2gtbg3BC4TLpYHyKtd238V8uUXoAtE2hKoC4fnge57k2v4+AKuDuyAFbZOqtu++dYSzMJ9M6ZrMHxMy9WEVHic0CxUpq4htBPiMmNkCnMQrhQyKoApCiAgcOhpy3TFCRAQeEwr29JybxZRZNUObGj2ZAaT3usCYgtn2FteaPVot2XKSO905r64OALhKzbysQQXMdsHWzKPNOUptsTpIdnB15vEiIXJBa4KWoCRRSZoYWeVrasOa5vSAg919npyXTKYVTVwAnmmdJUGMhG5TuaYirNsW13YE4YhhM8+EhJui4Kt9zRN6Kz2575/B/cN0f841Umdk00ZkrQkxVXdK4tAWLMhUPulHVZm9ntmmt+ZFZ62XmenFZIeeujFJ7/TO2cD/EjC0WsvzZDy3YeSE9XNq5Kz1c2wIcr6MLei/ejvE7O3o4uE3+TsFgglNvLy7sfbhUOU92l/PdxuOI79qNs6iYGRL8iedSKBAGFUsegOtSojZWgmmdcWV7WuoKTw6vgvAqvEIIt6V1LJAmYL+ikWfpE0AjCwpRQETi2sb5vMtIg67XtCu03Ojc5aj0xOev3aD11/8JO86+wD1lX2uv6fh8PjTdFc/kI
7LLnG+IwTH0dEbPLp/RHCR4BI4AVCqVAxoCpWqLoVPHHIiUmx02LQxGGOSTpopQRXoKCDqfvpDF7L/kLoD/XaW35Fj9v3f//38yI/8CF/4wheo65pv/uZv5gd+4Ad43/veN6zz7d/+7fzsz/7shd/9uT/35/j7f//vD+/ffPNNvud7vod//a//NbPZjO/+7u/m+7//+wdl3t/uMguSZQzMtOCgSbDn+ybP8K1PP83P3v23PPFewxN/8+t4+b/415yEiJxnIbqmQdcGrWZM64SYFUpSlRV1UQ8Ng4XO7T+EQ0qZhGRDpAtyA2/nlIkRCmlKTHA0sUtoW6WJOXpQtEgCV3a2mE1LfKEQZclUKspcSWXKbe74JT4GpJJo40EITBmSIF5fVpJVjEtTUmlDYeRww8ucWpgUispI6iKxL70WeCmz+KrGdjkV4xtkANkLmbvUvHzdtsyyDsuV6RaOCWVZc63eZ6ssKUyBLgyxczxc5X5mqyXniwV3759y79ByvoZyDWsFd+/kXnmNZWfWUegzjJoTvKE2GofgrE2Rj1IQjaWTFa1MZPCla5ipapCPAFJ/NpITJvrKy+iyUc2aSD7kwj6JjxIpDEGTpA+Mpw9rIj6VzJPMgh1FqmOHaBwh9oY0cLF6KpCds8sOExdRrwvFAHGz/uWqyUSszVA6G+LtQMYNm2O5HCGP99lvV4zeXyYJx5y3kBJEFEMZfanAK/C5Gul8FTn1ka6WeJN1mrqOutKD/m8XBOet4+RYMlWBoCxdB7IBVzKMWeeSmGtbJ8RLUKdWRzERpwGKAJN5zZWdbSa6YrYzpXXndHJFlA0+pvvYOkPbGTq7wrsG71pc03G2dPgsRLk9m7G7/TT7ky1ijGwFQaM7ZLNiWiluPZkaJ3/x8BEei4npupweWB4+POBdN/ZZ9gKBIlJVJSq0LH1EFYYZkqVzqXUT0K48K1OAmOHKAiMktiqRIiDcGpWFn5QokXJNHWt2Jzs8v7XHxEzw0uCzY6brlC7xRqInFbe2b3BXBKq1Q65OeeX2awD88sEBW7OrPP/k+7lx42m2C0VjTrB6zTQTqkuhiLpAMkW5IyoMUzQWxzJ2iIxW6K4lnnseXHnEblVQxxnrbo1UgTqT/5XW0HWpMjek69J1HY1t8Nohhl6mHmMUT6O5QUU5ncEaeLSA80UeECW+c6igkTIhed73zqsYzRlB9GPnayM0u2nRldHP7Ij54C+Q/nvR0I2MRvo+xjA4YH3as59j/dxO20/2yEcukvTFRYRJiJGB4KIzNUbDxzbk7dZNgZsYjv/y99A7gI97gENrqC/zu/4c+5/3advLwZxnY4P7QFKQrkGjwfbViNk5XGqwE8XV6hbbOxWPzh7QrhJ6JaWkDQWFLKi0wQIueArJhSIuJSQzVXDYLlHFnMMjeN+VZ3nfM+/js4/eAuCNgyXve7rGuYqrz1aYh19k6+ueori5hb37v+fR0X8GgN55L75t6JoVt++8wtHBkrZNtrDss6c5SNYy0ZnI45cAImxkhqSoULLC6BplKpTuNyDxWflAxVRpjvMDmvhbLb8jT+hnf/Zn+cQnPsE3fMM34Jzjr/7Vv8p3fud38uKLLzKdTof1/syf+TP8zb/5N4f3fWoD0gT4/b//93Pjxg1+4Rd+gXv37vGn/tSfwhjD3/k7f+d3cjjcst/Ii+EXQMNWTByT26ev8GD/F3jfhyoevvg6ry3eINxccfLonB2RjvFUOaK3lKJAaY9WiqrQlEqiVRxKYYUoUIVi2S2wLtB6IBhU0Ki+VxaRiRHgQ0JrqgJfVljbprRKvp+2KpLkhlFUhaGKgokuKEVJJLnW1im0O8JzilKpLQvGY4pIWUhMrt7y0SSNM+mRRmPKAk/SRCr7npTKoKID75CiQmiJJqBURFLStCm1OBMzrJe08owYPNYGnFe0bsXaJ0ep3NtiT21T1BN26zmTskRPNUbMEMFylGepPl1ytgy8da4ISyhWsEIw6SJZBJ
3mWHDWKnTh2dpyzIsJEzWlUhqVJTU8gVZGVq6j2+5Q9y2VnLGmw4jNpPA+GVFBmkCElqg1AgZtKE8fuvpULZYV5IOUGXrPXAGxcZQSisRF8VeyYYqjFOVojo2V9vtouU8t9hH1WBajd8QU2QnkorTFUAIf+9Tk5uFwGRmTYsR5iWl7w2ZGKcsxR2RInYrHUzCO5AgFkgBwocB6TZg67h2ltQ+WsNiBlQ8sbJJBaVxLXSUB5bQTSWM6og9Dmit26RwnLnEJARoBEyTCC07Ckj0TMThUhD6TIRToSlJODVrW7JQ1yggeeEnbdTRFmhvHnBK8p7Et62WD7SSPVhaxtByJNJ7bnY73yYqb85tJBNJ4DtoVx2evoJoFt7KI7hdixIeA70AagQmRz714xle//xm6NtsI1SFMJAqNCRWVO8XIKUVw+CZtp7HbiPpZyu2rFJMZ1XwHFc+YRI/wltnkBIB6fcZpc8bMd9yclUy0Zm++iykdD3MOOcgqaSmqCVGccH3XMI3bHIYTVnbB4WFybu596VXuuZf43Kf+HS8++Fq++w98kA9UnilPcNyle7iO25RFQSimaJVQAVcq1qVDrgJNl/Zpo2QlW6Zn57R7C05XJ1iRnPdJTmX2nT2sSg2gF8DUCmQjEJMCm2VWTIRdIXlWRd5ltpkVNU1zRnH/DLnII7EImBhwUUO5hS8UcXmOjpIYJG2vHh+zFxhJivy98v/IiYiK7GR5pAItBMGG9JyNMsn2ADFLXccoshC3ApG7ofRUABj0CgdnKs/ZPtAaL3E093tH6UKKcvPxgL73gZIYrdfP5tQOimGlofdjf2y5yhIpkCEOtIYx/7TXHOsFqzMtdJNaJTlYiX+2sTtctikkm5mEZXXK2JQgg0RlEvDSwLoSNKVgv7zCja05b57f4eh8gWgymqwFMmjURHPUdEzChLY8pAvbeNMgVJmPvUOvC6bVNbbsUzz5ga/i6b09/tz7FKdfnQKp//cLLYvWUYqa6XO3uHr9U8jJlNv/+C4vqVs89f5HaYyG93CyWHL33qd46fWXOTioaNcN0UDIbky9hplK4zsKRxfD0KIwXYMeK2yALZRIPkGtDFKURCRtTtcGX6BDQNtuCFp/q+V35Jj9y3/5Ly+8/4f/8B9y7do1fv3Xf52Pf/zjw+eTyYQbN2687TZ+4id+ghdffJGf+qmf4vr163z0ox/lb/2tv8X3fu/38tf/+l9Pedrf5vK/+FPfzf/mH53y+vKLNPlBvHNFcuvrFat1w6e/9JAHh4bGbuHPzwjZWzWtQMg5pZJEL4k6aZEFIkJJRE7FOG9Ztyvs2qTvBEBACk8v0aylQCpBCJooTe6zGfAyDag+cnPOQUyaKat2xbIsuIbAlBXHORo8Wwba1SQpRIcWrRsKEZlpQSGhi73khCPGiNKBqjLUdZkeIKFNDboB6xzOVsTCJPkAIQheI2VAG8c0F0K44IEOERVCBJwXNE1D3UlWMRn5+aqgvLbFznyGoKSYbrMtK7QIrKsAJ8lQLnGcdud0ixWLJawjmC7SFRtSZZFa32FdSpsKJTFGE6LGZ76XKztCseZ8LbBigi6gWfaCsJ7YpzJjJCCRMcePI/7Ixp5cLKcHklyGVEQpB8csKkVw7mI6gYsRJmzSEX1qYiD/i02aoG8wPnDTRpDZgML1cD9cLA4YGerhdeQw9ss4ynak1FHI27ocrY8fBr2jNji3PM4dCYAmqZwLAZ0HZBgeHgDdBNoqc1NUorM2FmYjT08biYki6Rb5JBAbpEK1ftAd7E/SERAeOgeNctgYCLlcHUAXMJtOkFEyqTRaaqbVNRpveLg8wGXkU6vIync0ncUj6GwSFT1bd0xF7rF4eoS1p9h6hbVrJhK0XbJcHmFXJ+zu5iiiEhTnIpGbbaSo4NXX3uT4+L3oKuZ7q9BFgXOplD+K1LBc15FJRpOevHmNp/dvcH3rJjtX9oilQoRbaLkCVlxtrgBwumg4XRxRhHOuTA3SLr
CcIpVA5Pp8qTyFqIg+oE3NNEra6oRmteDA3ufhSUoRLTpHAdDAC7/8Sf4vR2/yP/mur+FWaZiGJ9J1aCzCNRg6lE69O9erFtf53Mqm1wIUdM5hVCI6a10QbQfIIeguioK1W6eOZx4QkWAdnbNEUQwITxSwIyq2iVSiSD1Wz5eEVYvsC3G8SR1ItEEZnSOdNNviKLyJQuLItJKREGyUYjPfvM89cSPBxdyIPDnczm0oKSEkVM1Fjyfg8ZtekqN51Adv4/nYOzqXF8EIMRdcSGX2Z3HB2WOEao2cvAucWrEJGntjcRmND7lx+7gXMGzs1eXgMISLKc5ASts+1jT98vtsRyyWQgjaGGmmgTBJTleQlkWM7JQ1ezeucz84DhdLmtYTTF8M4phWJYX2NF3DaWyp2i1iAcopnuSDAPy+r/oA//ytf8bBUvFffN0f42N/YIdDecLZg1+iDsnhes/tfYqvusWj1Rk7syXyiSn25IPMnv5mnl29wekX85z98CFnp0e8dP813nh1yeK8paxAt4rJOl2J0xL2qpQ+F1JiUDkVmcdFFns3pLElhEg8UqUGR77PAApTIn1Aq4pS/fak//97ccxOT1MEure3d+Hzf/SP/hE/9EM/xI0bN/iDf/AP8tf+2l8bJvAv/uIv8qEPfYjr168P63/Xd30X3/M938PnPvc5vuZrvuax/bRtS9tuOkydnSXyXLN9m//gG76D0LyHm7MPAbBcfZby2pKTh3d57vk99p9+gjv/asnD9ZKrZfJgsftIuaDVUyYoYlDpUsiKzivCMt0c6x0+CjS9ocp926Lc6F3FSAyKojDpBkmAgBIOlVNsAFEH1lHSuIZld86yczw6OyeeB5xN5ybXLUjBetUiuzO29wQxGEIbEcqjs/6QWzlC7EB2KGMHhC+EQMjcsSBK2s6zlA7KVKVFUEihKQqB7XKLFNHinaTrJMRAkAGjQaJQuQJv4c+YbBVcubKP0hHXOmwBeiqJbUuzSFHnYnHG2ckx7qyjXadqvSqmh7fN12teGHRRIkWFddBaz1o1lGqSG8KTHLlpRJ45FlHhZUBwTgwFEIdoRfbpvRzWqiiIflOdle5PD/uLgegbYnIQotK58SwE3GAcI5cM52gsDrwzuTFoA1es/05sdInSRWawhoMm0ji6FlxA4QIjI8vj6ck+CBYji9w/QERMvqbn8TTqcA6jCF9eNrbDfuPgtFkgyoANEDO3oi0DYVtiRMlWVeOcY+E7ZFng8rgpKoF1ASVSKsK7kPhBXare0pfOOcSUurQxcXwcKSqHpClU6UQKV0ogMJTVNjvR0ISONmtuGRWxnaezjoDC+3Rf8ZqzrOgfts+5d/Qm85VhNpkhTMHZ+pyDh/fw/pzySnKoru5u82h1jmpbvEiE4PMTx1tvnfKBD+wAcHq2JnhFZxWda2lsl3QCVQSTnMV6Inly/wq3dm4wn8+JdaRzFZ0sIdZMJ9tpvXLFpKjx3TFd6ZFRU24pJgZ8kSvKilTp7aNLFWxKMakLgmhY21N2cy7GTktW5y1VgEIKbr95xj/9mV/ho+/+Gr5qL+3vht5hS095pA5xMeRig4Ky7bCxwzbJLimlKIqCsqyQQg8keoGhnqYUa1lXrMQpRUwIk1JgraXrHIryQkrqipfsOZjLGhzogyU07QBNCQlRaYQpkEVBCBalJUHYZCNyANd5T4iBIHrHpZ9kmyBMxI2N8DFXZuZK7ZS6TNc1BJf5wskp652ksR5hX4Ud2VRS90U2nosTLcbkCA0pP0aBExcDtHG7tT6AvZxuHGzQpUCx/05c2nbP//KjwxJiE/z1dqsPPAfbxeZ9iJlbNjqP0WEM+zYkR87UGr03web2iM6dUE1gf/cWPi65fXhMaDuUKHC52mBeTJlUE+LSErRjR9YEs8JFxwcmH+dPfd3HAPj417+Xbz1u+aP/5/+O8Nx91iev4uRDtp96k7syC8fXiqfFy5x98gUm3yIQRQtrmHzzDSarn6T5x2
lunK8+xNnygNt3vsjrD1sOzhTBe25UnkU+0ZmBoi7z89ujpUpjpgcPesHkEIc0pxhFrkIITN+GUHpQBd50tOpyGPz2y7+3YxZC4C/+xb/It3zLt/DBD35w+PxP/Ik/wTPPPMMTTzzBCy+8wPd+7/fyxS9+kR/5kR8B4P79+xecMmB4f//+/bfd1/d///fzN/7G33j84Lszrj26z0f+0PvY23sPAHderDi8/TPc3JesJDT6kI//3ieY3n2VF30yIrI8x9nItFrhgkbGEhci1oN1Gte3NRGGsiyJdAylcEKkar78IEieskGaSIwekR+rIsrh77TTSO3mrJWlXTQcHt/jzeUhbefZ6p2WsqTTMBeBaWWJ0rCzs5MEMNGILF+hlCHEJW23YLlacrawrNcJxeslCFLD8g7JCiMqikKiJalKMUJuSMC6i1gnIRqcd4gIrkstVSZkVe+p4+jBEm3e4sa1SLQCP285OQ6crY+5fZIilruHj3h0fMZxG1j79JCNBvZiajYPsNKW4Fp0UdA5QaRDmYYoHTo3fI+FgmlFNC3H60BLQKoG6Q2SgB6zbqNIfLKYibtD2qEXkA2DoxaIIGQyviESYhabBRCSMKrE7A3UY2Xuoz962kif8uy/6413b2hV/nuMcvWdAS5u9GJKdfhqVDU5nDcbI65GBrd3LPuHhWSDiIlLD44+HepH6c5+aYZdJfitM5EzC4Oy4jwwfeI61WQPX02IMXLuWvZmGlZHAHTqlOOmQ7h0LKsGYuNxHqTanKNAggoonUSjk8hzUiUNPWKmknyFKSIuSLoYsVim9R7XjGURU5AoZEcTGtbSI1VAKUehA16GoeBleXCPR5M32NqfMK92WKwFX3jtM9w/eIuyjrgcAF3Zm/Hg3jJdN53S4FrBFz53l+ffnZybtonY0OJETAhdJ1J1aozE3NHjZPGAE/uA69UV1lKghcGZs5Qe8QKZ0XdRRoxVuFhzZD1mYtifQdV1iVgMHArFOkpkEfCdSFXU1lGheHK2y3u2U+GCWXleXj+gWULRRKa6g5Xn3sER+1vpuCphQVQ0WiKUROsipWSERmGxPWGfSERTmYoYJd7HXC0sqTPyXk5KMAm5dSI5/M45rLUUqAF9K4Rgx0XmXlLrAjyI4yXS2s1IFYqoS4QxCK0IXZs4vjnoldkRa4NPvtyA7mSCNZu0phZ6QKA8Ao8Y0LH0L8+lkB62Hp8QcEaO2NugXL3TMw7Gxj1vI2lOKXq0LDdDvzTPegdoLJvB6HX899jx6rldeXc8Zk6kGJCz3l44NjZiQOjY9Prt1+kdSB83lIyeJtFvf3waEpnaZM2vM9uac2hS54xuCTd2niJuz3nt8EWsUzgf6aJnqhMXui4qFl1DoQti8CxZ0zQCYyY8v3+Tb/n6dGSvyH/Bhz/4HH/lf/5xbu/e5xu/9ds4+MzP4dQeV2Ta1vU/NGU6O8R/9l9Rnb8Xrr8Le/YK5aufpalex3wk2YjXXgmcxNc4uv9evppb/O5bgcmz7+ez936ZV89eA2B/WmHKdEwy9ZxAkoCXQMRm4XgrPEUZiM7jVHL8lZSILBWSrpcDYREqpFz+b2P593bMPvGJT/DZz36Wf/tv/+2Fz//sn/2zw98f+tCHuHnzJt/xHd/BK6+8wvPPP//vta+/8lf+Cn/5L//l4f3Z2RlPPfXUv9+Bv7O8s7yzvLO8s7yzvLO8s/z/6PLv5Zj9+T//5/mxH/sxfu7nfo5bt279puv+rt/1uwB4+eWXef7557lx4wa/8iu/cmGdBw8eAHxZXlpZlpRl+djnT99cIZ4R3PrgM4RVrui5+qscdG8yPf8wxfU/wIeq+/BE5Bt/8D/mJ//6HQD+/r3/NU2pKdodpPYIGRDS4WziaEmRkCJjJNEJhNa5lB0gIOUmRSlVRKlUdq2URIiBiknEb5COGBHeMplUbNsJy+UCfeJpGjjICnx3Y8vawl4NuxM47yzvrlbMp5oQNMH1Jeepyav1Het2Rd
NEbE+SzvtrS4dRgc7CIggqK6kqCAi6NnKySjt9eHLG4ekKa1uij4gCRCUhKGzm7dVxCquK43trOntAVXlKVxJWa87acw4zYnb08BHnxw1hlcjgNkIoMkk1h3u+kzS+pYo1nZe0vkHowN60osxoTBAaVSp8HTlslzgjKIWkiB4VIjqHuUUMmfTbw/kxo5QMIV1fHu9j37g4RcchprWJ/T7lUCreR8x29L7f5IDy5DRg/5sYudBKZRwJ93ITjFIVj0lWjPYx1jwLeexcyHzm7alRJHxBkqMHabkYeWc62KYqc7S9y1XcvcBsICaJkt2CUzpU3ng136N48lmme9cwdU0pNZ2OxLDm0f20h7uLYxZtgubaNtCtQbhcRerH6dsUiRpNagdWFkyaiImOtkcFlUktgnSk8WsKJ1nEmj01Z9vM0Znz4cWaQpSEAAsPhfaUKmB0GPp8yrDk9bde4OThbUxZIDnDyAZEga80vkuFMZMrNdPScGzXVEHTesdUw+uvPeDOWzeBxK0SviOYQOctbSPpXMQgKfJdXawe8cU7n0JUlpu7TzMPuwi1AqkxQRHyoAi+ReLRpFZuSyvwCig1Qm/umkAgVUAojTaGsqiY1RMmU8P1rWQnD85BH4EvEh/QbIHc8zT2AN/0Yk2OUgiu+IpzYyiKAi1TWy0p5TA+fJREBIWZIDC5oCbRLIq+eXxdDRITPnMcow+JrxZTH1GAWmjmsWXmJbpQ+NbB+RqcR/bodRQIY1CFSTxQm7EkIQh9nh4gp3SjT+mkvicmbKgMLgZC9EmeJUZCjFjn6LwjxDDIMfRpzF4cuueN9jystM0NJyyhZ2LIQISRcbgw50TSBExrjwxEn37NNmLMPLrAJe1fR8cwzNVLadEx12zgmY0+HyPxsEH8xsifFylr4Efrju0UXJQBUTkboYpt1PY2TdVyenIMQFFsMdve53Nnr3F+1qGYYvQ2hYhMs4i7iVCVEmsjT/mv5muvPEkT7/ATD1/i/U/uUO8l+Zf3f2SXN19/lT/2e76DX/13n4LXX+TOF3+Zzx+9wbt/dwJ8pktD/f5TzvWLVFNJiF+F2/kMpT7DuikHZ78GwOGDHT5774A/0v1u/sQ/+6Ow8xEIC379H/9P+U9+ICFmYqtK6XMvkCFAtJvq3BjoMh+ylEleyroWoapEY5CJ19hXnbfRJ/klwSBe/1stvyPHLMbIX/gLf4F/8k/+Cf/m3/wbnnvuud/yN5/+9KcBuHkzGbJv+qZv4m//7b/Nw4cPuXYt9cv6yZ/8Sba2tvjABz7wOzkcWgQP36j4pf/mNT747e8F4NrNOS+/5Nh/8kn2n7kJfBRoWb/8Cv/P//J/B8DiG1q2PlCwbE4IxRW60KJdRAqL1gatE9RfYNBRo0KNljJVoUiAEXcsCvAKrRN/In3GUNbd8x6UBmcCZRtQpUYWkkor1sqz7u1RC9jEpT9fpwfL9m7DzgS2imIojBA25sxqNqDCp+IFG+mpeEo0zKcGgcR2DU3nKW1JZyPnp47Xj9LkOT5YE52g0IqIxzvwTuKtHFIZsSvwkxqWBccry7VbgbmsWTYNbx4fcH6cHbPzc86yvkGhGWZ8CBGTG8PXZQX4nI+PNIs1R7Eldo56klMsUWDCjFauaaOnUDucOonBZWJ7T/5PaY3Yw8wxXZcLukWktIUbNTKOQmaib8APFU9Z5wyfofycFhk9B3rD1hu63vGJPC6XMeZpqEsOGVx0hGI23oKNgzU2hoMzKDZOmRhtoycWx5j8zHHKQTBKecQ84ccpzUuv/d9jo74WkerqDOfWnMTE56i3rzC/+QxXrlynrg2TcpKcpvYUkVOGL7evgBUEDd0iEtqUGjUBnItkXwqRizWkTN0AKlPgdESzYjGMZ6hVQfACG5csO5Bum3lcUUZBqXOxDYKq1rhg6dYNpTZo0VDIgJlkjiYerddEe8B6YdGlR00j1c4Nirogq2oQO829/YqDszNsFs+SAZp1x5e+lDS3nn
/XNVTp6bzlvHWsV562C5hCp04YgPBrzhYPeO2NSFytuHHlFmY6Z1JkLl3etnWW1iXZFhlaYtdx3jkmk4J15qvYECi1BNshREFwyUE76Ra89PA17izuAbAMh6lrwwSKKVzdU5SlZ7K9GYNGCIRa0mmVKRlJUFsIkZqlZ2NiraAwc6QwSFWBUoiQhGNVlvCp62oY816mcSqix1uH8KBzM+paamo6JlGhhMauLKJNmoIx83f6Y0FJIm7UwUIkTcne+ZKB0IF3iYIRRs7Z0B1AiqRXlguEXO6B6XKxQMjX3oXUicFnx7KXRgujeZ2OgU01Y3Zghu+4uIQ8mTbz9/E0Vp8e9XFzX/LPNn/3DuLouwv7Gh3v5T0INsffpykl2WHM3/XpTEhcsaGaM27Wv8wz67uPpDeKWCnqPcGb/gifxd5vPnOT+/qY05MjQjljsbRUakJZG5xID72liPiouPPglD//Hf85f/xjU5r63/GxV5+kPf/nLLa+CYA57+PpInJqfpmPfP0uj1Zv8J4P3uW/+4nPUB9nuZkgmT5zC721w+HpF5jVX0+532DP7vP6vS/y5r1ULf7W4a8y/ewW3/qHd2H3VeA1ovw8+0+/wre+LwUav9FaoqwgCKIUm04yQPBxKEBzKhUYdV2H0gmIca4jCLA+a515CF5mNYf/ARyzT3ziE/zwD/8wP/qjP8p8Ph84Ydvb29R1zSuvvMIP//AP8/t+3+/jypUrvPDCC/ylv/SX+PjHP86HP/xhAL7zO7+TD3zgA/zJP/kn+cEf/EHu37/P933f9/GJT3zibVGx32y5/zDwnt/ztTz8jTf4zK/9OAC3PrrEeYl7cAf4NbqH1+l4gdP7ryO+NTkQZQ3tYUtbTlmdHeSKSwVSI6UcJrXWEmMMSgtKrSm1SpVmxgzNVUtjKHSJchVGGIwSCClQMiE4ui9pVgITS7xSbFVbrLYsZ6s15/YYnZ51GJse0FGmAXB8DHdvB65eEexMPTpXjBQ5epbZQUzUt2Qkcj9xzhcO5wPTiUSppGN/ugosFo6Ts4bVMq2oY4GRBQKHdR1egXUCFyu8TfwRuTdjur+LaRXF2vGR2Q2ev3KD40eGB/YN2lUagKU11LHDSRABKhJCYiM4kRsPqwpFCRGMSoRfvy65Z08pFmmyzkpHZRpWjWDbTXh6a594uuB8vSbYMKA2goAJqZFxiAEVQ4rkRxWYSeE7DuKRMUaiTNVXLnhCNnEOgRcyfc6IZ8HFCHXsDLl40YnpjWXP2eiNWeJzpSKEyzpFvdEbOCKXouA4+kd28nqPayOiOTrGkCU68no9EXjY9Mj499u/TFoezi9/5g34vZqyvsLrB18C4MkrW1x/5nluzPepK0NZVmgTadaHdLlFl35UIf0ica+cQ/rMz4hpfPSInwTQYFSaJ0ILCiWzllpap9KGqSkISuE7y9KeI9pzrCswphoi0YRiemRGfdJ89gQTCTZdpbkpKYpAZE1RVICmEIaiVtRGD07E3Efu3Jrz+hsPCT7xo7yPRB1483bi0e3uVcy3C5asOV61rNeSruvopMCv0xWsKokqO86OT7jrXwM65uEpQlVgCkkrerHnDusDXmiEhqZdcLzo0LHE5cAmKI0hkdaFMFRC0jjBqoUv3T/k9TeSTZYadrcgKNjahbpMchFHzSNWs1ydGiLEQOuTAGZwHrzDRUdUEpsHa2sjdTXNmk0lQQViTL/pK8+m0ynKCEQbU+cZB8JHfGfx1g0dI5SAIgjKoCAo5LIleocIdkAYFC5XHoasJQUuRqRWhCiGAFhHQdd1+JgEZWNuQB5CHJw5gSXSi2enOd+j5957Qh7lQfhkH0TqEHCBqM/FV9g4Z8O8GTkqYycrjObf2/HLem5Z76C9HTW8R9V656y3BZc3Gy9+PBQAjDlqvU2T8aJTNgSfZBvQc+b6c7l07GM+6wpHIc5ZnlhuN8dMbyTOeFO2vPbgdRAl3WpNVW8x1SVES5EDDe8lbeP4H936Rj
6wc0DxewVduMof/7YzPv2rhkV4AYC5ezfr+CTHB2fsPX8Hee8R0x3L5AqsV+m5vrNX8oUvOfauTjn3JzTr2+zNJA/XL/P6l1bcSwXL2PMTnnzC8fQf+Ro4/yz3/+t/xerWLt32FT74jbcB+NLPg48dPqaxSAyQszOpIDlngHTEOUdrLapztG2LUEmM1vrMQ3eBkAMB95jr/PbL78gx+3t/7+8B8O3f/u0XPv8H/+Af8Kf/9J+mKAp+6qd+ir/7d/8uy+WSp556ij/8h/8w3/d93zesq5Tix37sx/ie7/kevumbvonpdMp3f/d3X9A9++0u8s3IB/7YDsenn+SwSTfwpV9z6OeWPPXss0DHz/+f/jmfv/sS7/5P58zem6ZI0wkOHnYctoIYkihiYyNNlwdbnknep3+mhEpDqbMacAV1jr63JzNmkynz+ZzZZEJdVlSlSbpcUtInoYRSTLCcS4cSjpmFnWlBtyqIWYjyMECx1FjnkCo5WfceRPYeWqZFw2SSyurbao1zHUIIyrKkqsG7iO0CdhTCNW1gsVxSViBFQddVLBtB13mES4Nmq9qm0obOLrAaQnYKWwfPuIRyPr29z/NP3EIYwerkjKqQ3Lxe80zxLPfbB6zadPwag/cPeCQ9XZP0jLoAwoihckqaSC1nWNsihEtpxk6w8g3LdXJ2G9WxNV2y7CZM1TV2997LDm/x4OXXkSJQ5SqYIkZEiKma3qcCgB4N26SQJZv2KzGXhztccDlqTus5EfESvM8yEX3kGDcpygtSFvEi4T/Gxwm0FyPoOPp/s4wNobz0+fiX4/cSBvX/C8socu5/9zvlKsTsVPf1L0EAtcZWgsmt6yxOk2O2rgU7166yX1+jqgpMUaG1w1eSk0Wqmi7LmuAlMjh8BBEFykbWBmZsHmRaCtAxKWxnaYiBLJ6fVIUo0QiiCqxbT9etcecHrOsp06pG9L0rhcS5rOoeBZHsQASB7rUWVcTMBab1yWhWE8xsyt7OFtJoquzAyd3ArWtbfGpaEBfdAGsGAcvcReT49ARVbHPu15wsVqxXIJ0j+JaiToHm4iSy7iTV/oRTc447ep2bMRCrGbIq8HWeG65Jkg7MELGmEwWtjXSriMownjIVDoGWOt0s7ymLCdLXHBw4DhKQx9YUqgmUe3BlC1DgPCxby3GdHmSn02e5MtvFNCsIcXC0iqLAFIFuhA3X9QyldPL4oksBoZSDk1RVFWVRE9pVyioEIAS8DVjrCTnmDiEgg0ip5ygRTUhdAZwl5PSQci4X9OTkUU8bERo0qNBX4Epk0+DWYUDRfAwpaOkfnL0EBxEXcq/MEJJcRtgk8XrNf58tdk9PuOzs9MuARI3TnP3fYjSv2VQ+i9H6/W8i2THKSNwGhbq4zti5uly5Ldg4bpePd4ysjwsZhoKDeNFWjR3SsTbi212Dfl9GwWmz5uBeg96eU+4nMv6Xzu5Q+IpOT1PQJSqUN1jjURlN7oTj3vEhn3juOb7hE78LVxwRX/g5HqlfYf9ZqMtUsEfzr5HmFvPtBTRXMRxTuYrtvRIb3gSgbQte+lzHxz5a0nqo6xK7eoPD8yV3HsDDzHZarha8/8Mfh/IB7RuHVE98PWfilMXhm5hMd/DdklCkak2JQPbjy4O1YHJA4mxAm2Rv2rZFa40uAwg1ZJya2BGJWOGxvZL7b7H8jlOZv9ny1FNPPab6/3bLM888w4//+I//Tnb9tsuheZkf+sEv8cRHdtneTZP187/2KtutwHzTA5Yv19x58yW+IF5DvrbP3m6q+nvw6oLJFtx7C5Scs7ArvIsEl3gkPnO5fJcqC32X0iwi1z8rDeU0Tfj5zgnT2Qlb2xN2tqbsbFVMKsW01NRVOfTUklJCWSPlBFM6yrJkanaZTo/YyecTjuFIOaq2Qqkm3R1vuP1my809wcMqWd0dOyF4S8RClJRCUdZgXJd4OMDZyuM7TbcOHCwCbdcxDwZHS1XIFIYCxrQ8OTHM5T732nMeyT
NCG7lS1rz7yhyAj169zkdvPIeYLTi9esby/CGvn624eWXOnZdWLDJCsr9fI+U2nJxz1Fo6lYtQ1nEYaZ31zGaOqVQ4u0vAEsolcSWQvVp/hEVY8W3X/xjf+rFn+ch7Ot76+XMe3FOs1+dMckXPqotIpahCRK0sSmuss0gh6GNPH5IhDoiEZPmIEjJVOAeBzw6ji4HgkwHvhR4hG5/893hKDTOhT03Gt3GeRtFwEFxYobfRgswRY4OuDV+MFpP30RvpPhq+sIwi9X4fPQ8EskGPI4P7+G6G322ROro1EeRUYuclu1e3mO09CcArp2d8rGtZmRVuGrlZakpRUqh97m0/ytegpg0W5SSFj3SlwMrIVKTqPVVmXpgPFAJUFShiReEr2jXIKLFVumilC0yooZ1yP7aEpcc2R7yiG5byKnvzq2mfwmP9MT6eEWRLUB5BR1FuWvnszSsKLTBFklEomLE73+bqdI6VfujPp1vPU9ct13bnvH5ySBlAVFAaoEk8tNX5nMVew8FywfLYo5qAE4ZGgzlNUHggwN6cbrulawPzU8mxeIS+4piECYiEhk2VIFaSRlqW7Sl23XHaKGQxZyvfqCkOLUpUSJVfrVKELlDrKUopHp7me53a7LLtweax1yzgxGncdtpYKQPCO1QZQCUhTWEiCk3nGmKuFu2iQtgOiATVYWNAeJ2uk0nzbLsoiZpB0DRo6KSk9R7beUyTU4tRor3ES0OYC2ILymW+UU6B4Sy062RwRQlBErROQsqxZJ2zFSoKJp1lvTqlcQ7pFdGl9nPCZrpDz80LkVQ/KHM1skUKOzTYjj0sNJoQfSqyR9/Gjs0g3Jzfj+kKl1OOQ/A23vbIkYLcClJAZNNto9+uiGm+9Eh4r4t4IVAc7Sud6aZDwTh47L9PvLYUMI9TmWOxWUFqpzSFXAkfh88tIIWgjIrGOwwVWjdcvbHD/VV6Tp02Fm9mBNulrjjS44UF7zgJ6ZkRO7he1/xvf+O/5X+2+Gp49QX8w4oj+VF2pnc5vZ/pLe9u8bv32YqpW/JxdcwZD4mF5E6KFZl8uGN1NuWNgxYbS973zBIpWjqvefDI8jCXSa5WsD2bE5litteU720I8Zy7XzjnzhtpRMQiEEPKRjnrECHJLVmXrlWvF+plQAqDFh4lV0SvcUuIZWTZp0caw3PqCu3iVd7su/f8FstXdK/M+Y05D1523H3jmPqp5A6rfcVnP+N4/Ys/gb/3NdinFnzs3U8S20Nm02Qo1zEROys857Eh5v6KIUZUCOTnBa0BbyFEgbdpAMscDfapYruERA2zLOw5rJasC8GyEpRlQVnl6M4oziaCSblNrSRGOyaFw8Sa/pFvJjA5h0Y0mAjBQCSweASvvN5QmRMAhLE5PWbQKumYCaswKqJVT0qEVkjWa0dUMI0FJsLWZJfoJa1K2mNFZdBFx9X5hNLOkcvAkV+xfWOP9zzzbLrOakZbPeS5p/e47t/Fg6M3OT5a8Qtv3OHuw1dZd+m4tuYTijIynRmitJy3CfZtQsqzA6yJOAlaFJQocBHjNOu4HtJ8xgmumj2+60Nfx7MfvcfhS4/47Cuv4q84FoeOsu8tqGLqQ+agloY2Q80MMS9ZEDAkacoYMvEfulwa3+dPvA80IuBFEiizJIDEsnHIHBvDJtgYMLhIoO2N5oU0Bhcdqctk+yHiHjmF/bZg4/h9OacsXHpQ9PsbO5Z99Kzyw6Yvhx9vK8YkldHk/SDB7U1YVhqmiivvStXQry3usOrOUfNttpykM5F5KXDeD8ffRY+OiXvpSORqFZLkRJTge5HmKtELihK0AakjSsdEcM8XUE8m2HpCdIIdKTmyLWexoF4IluYY4dP8n9UTfBQIYYAmGVc5IfgOk0vVddFQmJC6Z4iSiZbMdwRVVVPGyCJ3xXDCM53M2d2a8IY43Ih/eoYH+unJEjH3rNqWpomozmODw4aAatLIkaYjKs3xdsWk3iEEC90ZR0cRf1VwPW9sXUikCN
j1Cc3aseosYWkoTM2kTAMhlDIhgTld7yWUqmCqp1T1dJAXER7cEo401DHZLduBEw6fZXcMmkqVtLLMKV+DwGT+zKbBtzYl2/M6SVdARiNzANPTPsoCrTUdG4kIQUDEVAQV8o10OLro05g0Cso8DoVE5N6c6UAt+EBUkSglwhikkkCFznxC7QWhqFI61Tm8t/jgiTIMHXSMLujbNPkQSMnLiBMxz+8eMcvIVXaAem3Cx4IfMXKAsq/So2P9un0g5vMk7KVpLkzPuPn+4pLbrsXNtvo5GvO2Bu7aGK0bBYZ9CnJA/XqHsM8A0L/GIXAcig/iRftWiPR8i2wI+yG41MkjRFrhIErOZcNsep3VdMmD88TlMqrkPCyo9B6V1jTC4bzFhECvs9plfbDzk0f8i1/8Nb79+QPO9lqevf5HMFd/F597NfHCb/+Tn+Bjf/KrWK0KtudTaj2hsWt2t0p+5VF6rt/sYNUtefgAHhw7PvzcQ566vqZtO87PocndBuzSo6sZAoso7xBuXkUcLzi0B5yvksMovaETFhOSWIYV6XxFzJJGmchrlKCQaZ10L7IEcvS0GXnfd3v8J9/yH9OaFf/3f/oDl2/42y5f0Y5ZJyZ8+FvmrN09Pn+QtFPOO8/kCnzqk4c05z/D+7/ta2jsKW8dBF58NV2oxTHs7gssU1AthVc4It4FNJC7GtGJhJhJnUazklBVMJ0KtuY5FbilqSearWqCUR6pPDI6fONYNR2r3GYkAWYBrQ8wekJdC9Rkymy2w8plg+QbQiUgE6MVIKXH24JHdzve2s3O57ZAC02zTrlrIxWyqJh4RWvP8rbAxQ4CbJuaPTnHNJFtM6OqNfcy0TYWEGaWdnrKzBie261Rd9bMjOXZJ5IxnTaRO+cHxPZ1nt2CLSEIBLZWHZNdySwmQ2kKST0pmCGIocVHjzCwdIKYc2PL8zVLWVJNJ+iyoMXSLBU2FhiVJoVSOxwcH8D+v+Rs+T7uHj+i3X/A0V1PK6Hs3RoNsROo6ChEwTpYjExFDL2OkYvJMQvZMIeQCP+2b0bvN+tZkVKasKnKfFvHbGQQez2hcaQJOSUcYWiNdgnN6g1/Po1he3nVC5pj/Wd99Dv01Oy/zA7d5ZYw/cPgsWoqsQEGxtWh/XHVpJY6Cui2od2rCJMJdqJ417veBcC/+4VXeeH+K3zTE8+hfKQUJl87h8woReuTTpTMulcOQSVAyogVoPOByeyoKQVSbjTboofdnEee1DW7pqYTgoXRBBEofEJ8CAafyzdj6THa0HQRFyxBtXR+ha4CZb7QRgeUaVAFGCkoS0lhpiijER3UZTLgZzagveb6/hzns0EOqUiwzPmp5aJBHsEiOLp1QHQRJxtkLJDdJp12cnjCq9YhvKW6uUcRW6S1nIglbO8CsK32OFcrbHdO10lO7RLLhGm9T8hzTKAzr1QRM0NRq4JpPWNvvofuGycDroHTAEdturZCQijg1CU70sSQ9OAiaKlQFChlUEIiRaTNKqpRKqo6OV5CKGLs8pjZhB11XVPXNR1nqD4QSNEuIThCHuWtjLTRYqOEQhJqSRQSQ3IaAXTX5XyRQ8QAukQaTYwq8dx6xywIvDYopQe6gscRYhxSmzYYYkw9MmOwxDzvfQDHRt2+d07Grz2VYTwVe0R74GH113q0ThAXEbGh9+2lOXq5WCDd34vzcVyt3QdmY+eu387lSs1xX8zL++z/udFvR2bqwr5ltlsCMaTmIBXBaFIfZBHATOaoZ0ruLw5p80VdRUltCroo8G1LUaYLZmUYWuY56ykrzc7VOT/03/4r3vW/2qa7861cOy2Q1wUf+j3/YdrW7S/x5ktforo65aw54HpluXd4QqG3Oc0H/YWXDAfnFjWVvP5A86nPn7O7d8Dtu5qzlcNlJDxGSdg+Avd/5ad/9KeJTz7Jx9//bTjX0eabEpUdnHtIBS0xJh9AS4HMxsRogdEyKTKk7DxCRjwem4ugnpzs8TXv/iomT9zjZ3
5u/jZ3/fHlK9oxq658iTlfy/L0LR49TFPj4ACC1/xXP+LYqz3/+Xc94I0vnvC5Lyy4fTtXTSrwjSLoJVJUyJjr8bRP3AbVk0sllQoonyJ3baCewGyq2ZknSHJvu2Q6NczNDCE8EUfwFu/TjbFZx8LhCZ3DVJI1a07OO8pyhamus7u9B4AQD1h6j9GwCoYYLIUEWUSaFdxNvieTsGBrUhNceiLHCNEHjFaUmSRcVpZGSXwTUKHh6vwGhdF4luxvdXQ2V4LVFU51rKoFlXHUesKzoqDQJ9z3aYcffv45HiwMd2+37D5zgqmhvrbPldM5N5+8SpvbT9UIShMIxiYycTzHOYcWkpXL1U8RTmNqv7Kjk7Gfmymrieb8PE3oerbmoRX8/Z/9Sf7an9nizfgWTSmJVxXdRHN4nBANSUFRlKxci3IBTHpU+cwjgYQ4+r6Bsfd4BG10ECJeRNbZpHoiIoAXSRa0r77qGHHHxBiLY2OxR5Hz6O0FdO1tSfdwQe172OzldeOI0zLa1oXUyOO7eBwNE5d4aWKT2hxLgqS/DQ0WsW/wpaacTHClodrdBuD5vRv87Cd/lT/01d/IU2WViluQuCiHKkNCzBF//xgXCJV6LCoytyyfoAhQKk2t60Tkj0lt+1ylo5mYSGUCjVS0jxqMV7hCooXH+SVaJ+fG+xaT2yUpaiQRozpacY7JnqBSHl0EtK6BKYgKF2TqdxnF0PcxakN71vH0k/tMZwK3jpu0cLYRnfO0q4glcTyFI4231uKyY1PUkrVtefjgnMXZOVE+z60rO0xsoFq3xCLZiKpcQDB0zmPFCQ0eHQsiBVKmed0XJ8l8cwUBrwSmKLky32aS28kFl7hWoYF1k+6xKUBOoM1yGd6CDDIFdiKipURrjVQJTcsNUHCZcN/vuy/7D64ZEOeiKCgn9UZoNSbHLHpP7EI6HmDtLW30SXxTK0QtU9QaR+IuXUvsGjA1UVREJMGUCQWVRerPBclDlgKFQAmBDSF1jQixbw5AryMUM/k6EnIxUhIEHlcjXuBcyY1jNQQx8cu8cvF1/GZIJYrHg6Che8xo4op40WGTbKgIl+kSF/r4XjoW2Eh+MNpPz4XtJTp6RO0yS0mM1tcwOLqQuH3OB5RUCT0tJNdvPcNnxKscrR0uOy1dJSh8jZWWSqQLGryjC93Qg1PENPbqecHPffKEf/dLt/gPnz3iX/z4P+B//A0v89Iv/jwArl0j312wOG7wsUBfEdx7ILG0g+369Oci7QoWHbx1HPj1F9d87ddP+NTnTzhoCkydHCVfBd46/TQn/9UJn3rLYE5OOTp8iYqPcBZ+Ol3bTiJVSNxDKVLxVgw4CaVWFDl40yoNX6UFQkukjhuOc24TN51Jmukh3ekv8L73fSPwKX6r5SvaMTtbvcR7n9vl4eIRx6fps9/4dcHshufVO7Bt4Oj4EaKULKyizciUnqe0pdFzlChxQqFipFAWp3yu1IISibGWpmkoSkE90Uymmt15zf5uIiXub0+oK0V0PnnKrqXrUoWTUmogx/ro6Hwi6WutKJmho8WYFmXSw66TO0ztIYsW9NoTdYGNHYW3NAq6XLV4UnokDQqFjFOkEAgZkQiK7CRpaXFdQBSpoulRd5v96VNQFhwWZ+ybxLc7WDsKIeC8w9YtauqJlaIrPK8vUvJeriomZU1tr3N4XlGXD5n7Jde2FdfaXZou7bOMgiZEWK4pdYHSJaenR4TODwrbSJjKgkoZhCmIE4UvHfpU8FS2VA+9Y8t4fuWT5/wf/ut/wze/Z49feuUuu3IbW9eEszQVK+8QLmlgaRGRQSDwWZ+oj47SRPE+Yr0ghEAnfUpXxkCXy55tSL0CXBS0xCHafEwV+1IKYWg2PBqXvaZRYDPBLivr98uFikq4IH0x3u4Y4erTFONtXkbL+qbpl6u8YsyfjYxxX10K6ZgbUmq+nYN6cotuu4JJgSoUNj/4P/re9/HLP/Pj/OgXP8X/8uO/Hx86dKiwaG
I+Iy0kQkEUqdcrRKIWBBkpBZtejNKjZEKqC9lX3aXG50U+8XVsWLoFq04QnGUdDCLMOVO71K4ZeGGFSFVeuoe9vcIYgWwcMje/jrFPq0kKOUFIjRQGIyoEni4D2K11NI3j6rUdnrx1jVe/8ABpwIVNSxZhI+3a4oXP5G0NdEhR0OR1zo8agiQV89xdcOxeYO/pgmefvMrzOzfYv70FwHZV88S8JnQn+Kix4QzkHkrr4VrBpiUZIhU2xBBRUrA92cIUOSAJG8dCuiR5YZuItnC2yMTrNuDmjiBCKqAZbVuLnlQLPqrkLI8qnmNMmmC9pyB1atsUtEKHVNlMBNc6XOsQuem7dZ61DKxFxMsAZarCJYjNw982iKYlmgap5iAlspoQZQnKjLSg1sNYNwg6UtVsFyIySxVI4bKOYhh6ZvaUBk9q09SP+3FvTB9GTtVobvmLb4flMtI1Tm2KdCC5lRXD/vrA7cL2R7YFNscg4iaIE5fWiWxsUs9D7RE2D4MsTQyb8xujb73jCJuiot4hdUJQxZh6wPSX3Qc0EosmRMn1J25ysrvg0VGLKyEU+RmE4YTIXAgKqVi7deoYYf3gOOu6ImhHWC3Zumr4wR96gd/zf/wY7/9Dv8oLnz/gC6+msXr3UPChLQ96hay2aIojXr1rCcIOKPG9o1T5XS0ExyeOT3/miFff2KEzgU4FsjAA5RZ89lcbbplv44Af53oXeem1l3hytr1JgYtAlNAKCD6iXAJArIRQCKa59ZnSCeXXWiJNKk5yvktcxwx+sOV5ZD7N6YO7TG9cuzx03nZ5rLDrneWd5Z3lneWd5Z3lneWd5Z3l/zvLVzRidv9gyexrWpRwnOfmfvcWkeIePHgAi4nm7p0VagKHJyXCbAh61ZbmiXKf83XLkXUEFylkgdJQF6ms3ngwcclkK/Gy6lnB1rxkd2fC/nZaZ6cuMVpkRCjg7Jq1XNB1DcIl9AYgekelBd47TKGZTXYxZaCoNDGkiJkYadeB1p5yGgLWWYIXBBUpZEo9ALRWsFoHNCF1KpC5V6eTQyWVKkBrQbuOmO1UMHD/7A26U81TT+wxa1K4VTrwiyXsCYKoOPMdsox4YXkkE3Rwdbng+Wf3uXk64eSVbWw02K23WBtHObnOznbKm09lyTpGbLlMEbfSCAHL5Tltj8cXmnK2zXQ6Y24mlASsd6hJw2lmLu+2gqWCm9drfvKn7/Hmq6c8f22L228do0VHXaSbPbUSH1zixSiDdC5VloVI6Pv8jXriWR9yA2PHOno6sZHVcCTYv4f3x5IXQ8XS23A2BiTr0ti8zDnrUxJDtecYseJi+vKxFCQj5Ctu3odL61xuhC4vfX55GXgoAnzf4J1Il7ftr5b4vRnd3pSVCVyTBpVTV9v723z9/nP86E//NP/RN3yMj5rrtNEi+waDANGDzLVw0pNVR5GqJ4j74YTLAgwOFTqCsLShAeWpctaqaDzn3nHqHSXQCs17n/h6Pv7cN/CF00/R2E8CUBZTOqfw4QgXznFOYP0KpUkNsEnFPA6ojKQwBuHBLx0ojY2eVa7e6lpLlAKpIs/cusKrLz5I19anYhHIhQxdIpRbEYhegBes2jW+v6kucbyqQhFl4HAReeOXWu5cvc2bzx7x7isJvT72B6zf+wTbhcYsPSK0FDJxWESe14gNKR9CEnv1gVJEJpMKUWUbd56u8zodDo1LRG88HGd9tTamzIHPeohSglBJ0sioEpWlA1rviPma+BiwLgnjbgpt2Gia5XP2bFThvXVon2ANF6GRgqVxeOHRShELSfR+BPt0xK5BujSXHRLKSdIpEoq+Bwvdiug9JgRKEmcuxoh3cZDJEFIg+16zPvXR9CFiY0pT+RHS5DIi1XPMxqjVsIzmr8jvRT7/Qfg2XrQLCfGKjxmJyEU7MUbP4mgDkcvIfEYuR8hZQvLFIPfRz/vejuXRcsGejSU++qWv1O7/+YyWCTaUC5EhtvPQcfXaE9RP7/Kps8/jpa
FVEp81bqJXiDpdzMa2ScxXmMQBzM8DER3eOqQybE/nPHqr5ad/+kW+5z+7yudfusO196bs1eFrr/Lqa5Kr+0/QsOD+mwecPITTA0PMLGClJURFGy3zouLkUcMLL6xQFYTgmW2n43Je8OrpmhvfYmm/9BxvPPoioYOH4YxMC6UtU9FMCP1YDkhEUmQoS6qcTi9KQSENRaERhaILlhgsRIHKUGWcK9blglf8GfZt8dbHl69ox+zRseZgdcids3N+44VskETEN+AcTGaRf/aTcPUJsKcdJktJtDi6yvKNT1/j3ptn/LpvOV00TLSGuqDMULkOEWsUlY7UlWY+LdndmbO/u83uPBnTaSHRJI2cGD1WeGSskup1aBMzEoCADAozrdFVRdASaWq2ZrtMdEplTnVBDAoVHJ5zzk4iXQCrBVWIyXiRNMa6jqTpQ4tHIURBFBGR81ttByZEgsuV5zXs7UZcJ3HrM6xIx2+JiLKicYGwTpVQzTIgLcRs5Guzg1jcYOe6ZqI6Pvuy5f7uKaG8zWK9QpvE75nXu2gZaFSJcgYplsmAYIi5CrQqZ+zWu2wVE0qpUyuXmaM7DzQxteaywaIcKCPYvhn5/KsnPLzruLpbgmoxIavPC8NubTi2DoRjLjQyREQIg/cUvUupzBjpcNiQKrJaPF2AXmfOERNsnQ3h4IxxybCN1hkbsCElwKYgAEaGb7xC/nPgsYwMOWS/5ss4U7Bx7sYpjXFpvxSb47uQp8hn26dHBr4JYih6CCRBYHsF4o0pxWzOZGeLQyOxKrUXAVA68HUf/hC/9pP/lP/m53+G9/1Hf4S67FChoWszby+6TAD3qdpUQKUFhcoCoBmvDxKMgUIlvhNSpMo9JZjUac6qaUlpCiZrwbGQTLd2+IPPfYyvftfX8fCVA46XLwOphF9KcDbiSK193MrQeIHMVZnRuERpkhWVvgJrjW8isZCc+pbzVXZI2hYta86WJ1RVYDY1iMwbGcRQFURL4i4FT+tCKkTwQJH35zJXx3gKLREhMpFw/hBePF9x/6lUxXbWKh4GwVddr7gxDxRdskVGgOzbSUkyAV9k7lFKc5ZGMammlFWuPBNd4iSpxDOL+eZGD13WkDJRUCtDjDXnuRF0wKNVKgIIIh1X55fEzEXoKzJdnlcbjgKYskCLJK1g+55lMBTfANgYWeBYKEcTPTMpEIWgc36wz56G0LUIm6QVYhRIUyTvHZV0z/LRiugxeHRwKDxSBIJwhD7QGJUqiiw85gYnLA4xRJD5IZydszBOG/4mz9LLzcT718vE/ss80nxIF9KOsOGjXigk4qKWWOjJAqO0aAoY42Y74u0dzKGzQV4n5Nu0EeTN8VPeTik29kX2nqBSdN5RTCt23nudV8VdTpYdyy2BlVt0mddnlabsfOJKSpBBYK3HyTCIpAufjlJNZgQh0LuSX375Dr/37h0WZ2eUWQDv2p7CRphNK+z6kNe/EDhZQFgIZPLduHYzsDgMnDYQXMO6hc990aEMbG1F5vNkcO7e9Xzt7z3h//H/+uf8yf/0j/Pzv/IazUHHpNic9yzHGIXL8oUhGVajFaVOovMAlSkoZYFWEo/HOYf1HTFoYuaYajVl0UjWIrBulo8PhLdZvqIds7u3A7/4ydf41c+2HCYbwkyBtCCn4ArPw1NwHQRpaIuEAJUW2lPYe3/gCXeT17tTztdLlFRElXRLAPAeZ2CuDLUpmU+m7G5ts7e9y840GbJKC1T0WL/A+0iLhKDxvqB1BiVzyG8USyvQ1lIKzawy7M+usV1fwWQHqKqqzIk6o7MOu+pweGSIdAIm+UHmo6DziTAtCofSkUJUrGzLuk0XIobURkIDTZt4Q1d3DVIGvGu4d5z0YU47j+oEe2ZCGSNTrehcYLVUrO8ks/WoOuP6ozPe0iVX5iecvHqbn3rhLte/+oSr82tMJ9kxm+7huhVVqZFTg2GSjHiniNlDKc2MuthiVhfUyiQR0rimaXfpTELCbsdTih
hpJUxlpN5/kkcPHnK2XHCzmFJP0rU/OelQMeJkRMcOQ4HwFSpu+FfBR8L/h73/CrYlS+87sd9y6bY57npTvqq9b7gGmgRAgCAxBCboRxwwpFBoyJiRYh4mRPmQQtIDJYXcOHEkkWNIilSQHHJICiRhAiDQ3WwA3WjfXabLXVPXHrdtmuX0sNbe59zq5rBfpJiO6Iy4de7dJyszd2aub33r+/4mWgIRFz0DHitUdgKALK+EjWcYjEGcBcJNIgNnATRyDr+Vg/f5uLsJrh4oOMObeM5WwiGvcoPI1aN3BffzxzrvMiBzJqjPJXqb6ps4d67NfJKo/PGJY58XnAwkDOL5RFQC9mpJf2VE3J8i96eIKqbqS8ZyqeC5eGWfTz/9Hn79t3+bP/OpH+dj4hJr26HzOx9kQlxpklivKKAuSozqsTESQq6ieJskE6RGyoIQJdaBF4qpyhpfhWEUIAbDcTRcunSF6Y7ki+/8NpUcc2Ga7NxW7cuglyA1UhmcXIOJSe5mI5OFwnae4Ft2paIRe4z1RWIfOVqdcjhP5BL6BY1WLIdTVutTRlVNv7IIzjFHI3gfCSTV+aGP+CHrzmU7qUKBMhIXAtoFigpmElQJlYJZsvDl6+IhVQW33YT1XsN4MkYrhRJs6fgJUybThKoiMki8EgijaUxNlfE9gxmSZEUAX0BUKfEYljDvUoxYDy2WQAglZHJBUsNPL+gmbQnBIkTMyVVOeDypG3AuMSuKIom+ep/f26QplhT303vTR0snoJU2KfZLldiZZPcVSDidIdk0qZCqcyiNLEoEArEpcwmRyB1CoGJERYfEEaPD5czZbbVZ0su9Gashv/PbvC3kZC2PC39uHH0HTvO7VLefANmzvTzyrdweK7zrGFvCQR7DG3zpNlH6LknfeekcOPv7+QXdpmp+HhN3lsBxJq+T//1uTJMg/d7kayyU2oqXD8FTGsPVF27Sji3ffvwADLhosLonyNQBCr6lROCUAGcT6x2dvHJz8ua1xJQV0ktk6BjrfYajZ1ndW7IIc3weRJ01XKpBFndxpz2ik3ztq4Lnbp7JIFcmCcEfrlPC1trI8drRVJKygsP7KYN7/4cmvHlyj9/9YsF/+88dU5aKuyfw7E2FzVX8AFRSUQ2eAYENCctphEQjtuQyrRRaKCSCwfX0Q0ffW4gl7ZDuahk9cuwphop1cZ6/+y/fvq8Ts9e/nsCkb70J0ubHM4mUU80V4Wj2JD/yY4J73/K8djqw0Xab7gjWneKt+V3+xHs/zfVHF/iav0/vW3bUQWL4AV4WNE4SZMQXgqIo0sQiGzB5H93QeIdkQIQRUhj6csVgxwyiY5UD0nIdwGsaMWY02acpJ4yrMQeTMUVeKYoC2HFY12E7z+HiEXOf4op2sM4z6jTE1MoK0PceoTrWUTBbdqxym2JYpsFlFYQBjh7AtBLsTgbGVcTqtF9xCK6PrNolF/drykoRZaRwI5Y+JUond7/N1/0pfTnlcal5fPqYk9UdwlHBgVA019J1rZWjDQVaacIkoMTAJFzE9o7FegZAH1dQ70JZUxdTKglCFgx+wTqTCEZ9SVg6FCWzqsTIjssXDjh9dIhTji7T5Rd6RuUldJ5KlTx0ngu6R4ZImYOOlJLgJD4H6hLJED1OpFZO7tBsQf4b1W5HrjzFsxXnpjqGeFKcdbuiPPdvkROuTdtzC/A/t8rd7iMkG++1eO64sBGdPDv/VqLj/Ko6f+bzMUU409lS8uzcG6FKkStlfT6mJxEeF93NngABAABJREFUNpv9YEN/cURx5SLywhQvDTvS4E7m6EyMobJ0zZgf+viHeOuf3eO/+urv8/Gf/Hn82tL5NPG3/YCwBZKBKsJKSMy4Y5rv/YNVOqfpoXWAqtAYzNqxEp6mhkkW66/LhigVsYoUIlI6OLSwMHv8kb1PcLdNk8Fvzr/KWPR431Arg1WG3lSMKrYzbMDjezjp5rTuW+zoF9nXGjeHW4eHxD6N2VAqTo4f0y
1XiOMl48owBEEzRLJEEV2ERlq0j9ghJSydB+8gi4PTSSh9QAArUtWqqiViCDQGhnxd3QyO1oJrYg/akmJUM24mlLpEiXTfvRPIwiOkREpN7xyFSN62U1VTjtJJOwU7I5iK9A4se/BrOBZsq1eD7JMfplEgPLKoacSIJT0Oh8wTsRsCQ/TI3DRyIibD9RDoNh3WomKiC0RR4fsBoUE4kmpY9KgM/g9F5FR5BtHRFwFOewiSQpTIIUv9xAo1LIjrPRhJdBTE4PFlgZSauDhNu4mCqDXee2TwmChobGCwliEvNZwzECVCKJyMBGGJMbXV5bsg+9tFjeBM8oMnpTA242371834PIdv2IzN80D7dydmLub9AgTObKDOxxTYLAqfLJ9/h63PuQCxkeHZVMO+W7FvsxB0nJPQOfd9C6mQwaNJtno+gnJ+G0cccOH6NcKzNd949DKtg6422EIRvMaFlExJowhEau/xIiJ7i1QDQhfI7DwjYwmioqoqGmmYrRr+0s//JQ6u/WXuvnqHvC5jJ3ooNcvHHdEZLl+3iK9eJB7eRNzPUlkvzJlcloz7U05DhS57Dh8Fdg4CtYNvvpxkYp774cg/+E/h+Zf2qesJetxjKkVZwDSPxUpDcJpCCpbeYftk1+VjIpAJv/G3jYiiRcXIuh2wbaDtW6RyDDknebxaE/pL9NHSh/8fKP//N21brwRvvOKYz84skqQILOcRNySW0WIBs3VBrAZ8jsym9GjnOD484Z0feswDeQSzgDmo8d1A0WRWSYx0RiG0Q48UTMZMypLJBGq1C4AQgVCtEXFK1APrU0NY7eP6R8ToWefItRwC01JQVxWjpmJc1jSmpNElRfbms/3AaFwx7QzjiaFuFGXr08Sr2PrNFWXCj0UXsRaGVYdzlvXSYzemzzqxdbUSiAKWy8idex3muRpp2o1gN80Uhi5F78eu45KRTGrFaN3T2HTCo/4d3rz9OocDXLm4y/21wZvADXMZLVco0km973LiIlGypq5LgltT+T2m2RD50fwxJ+sFBzv7jBpNESW6VHTAjk2WU/NesLIWHwcKJRmZkrJsCN5wuu746Q9cSc//ldscPhhoqwq7PGR3MFRFjYls1bNNjIkqTxKX3Hg1Wp5kXNpzQW0TbCM80VbYKvmfT67y/7/xoIO8/7uTuXdt5+A5bOxgNvtLzlbM51sgkrOE8N2Qlc11ybz63Yh8xvCu1TZnGmUAC5la9puqoHvWUFy5zPTGPuJgD70zphUkQUUj8FkZUlpB5xZMJjUfeeoFfu23Psvrn/gRnnIe36eE3sZAK5LUjBfp/RVFktUYvMfmu1+qxL4UPmBFRGiNKgwjEdgdp7EodZpthAyYKJj3LTN6dosx2Je59eALABSNQOsRLalqo9qAEQvqpmGIKWF0EYRM4+Px48ccrWGmlqgAdC0+2x8t1wPtySltt2DtHdXOBUzb0cd2O6OlVlakNIqid7QRypigA5sHr2V6z4oAEw2LCFZGlIE+q0UAqAEePZ5xcuMSu1VBDBrhk37X5oVIGlmBpKG1MeuWSKEwxlDXqZq8Mz5hfydhTU0F9RIWR5KlC8hccXI2tX59GLa+olprjDFIrbdVwWEYcD63aeymqpaYmRu2qBKaqqrOzMPPvfQhBNxGLkc61gQ6kRhuQwTlRBKN9TnAhQEfPMp2eNeT1q0SpUxi0JpsWB8GRLQIHEoGtAItBFpEdH7rPQFLzOxricoK7k6EbWUc0vvp4lm7b2sEsPl73r7Dhzx+5/h+QmPwyV2fgDactSvzczzXitxsafw/edJ47k/6jk9W0s7/frMA3Ox3Hh+7ZXnC1tM5EvHBp1gjU/wA8DIJxALsXtpl9MxVbvVvcm9p0aXCSehFIEpJlauvxMAQLdY7hA94leVeQtyuWoRKibeLgVgITk+OWVz+ZT5y4zkefgbqvXSoC41AlwODLPFHPXGt2K0e85H3/kleH/4GAIe34Uc/9GP4l1b86j/7LMtlWi
jNjwVhd5f64ASAr7wSWC3gF//0MwzxIUaXhNiCUlRZXb4wGj9IAgLrHW2XCiEqla9x+al7n+ZcryLduqXrwVkLUhKzk8VqtcDZjYzGd5sNvnP7vk7MujZp0pSlRmUvtr5P0girBZSnkt/5PU9dDZQFVPmmSGPw0tKtLPeOOg5XAVcWzGzPtNFbayBTT7kwKPxU0vgpl4KimO4ngcMiyWVMokYXz7BePmTwR0y0wgeHKR2z0zVuSMe6qA2NKZkoRRUFjTZUpkRqhchAQu8itdGMa0UzVphSofFJDqOJm1NSGCi0JOIJFoKD4CJKCKpcKtIeohAMQqJsRKvIw0MQRcdT186CSWehtclvcrmK+Jnn6g5MRw6Tk6mjPnLYSxZvB169c8Kxl+xeqWmqQKHctmdjM3g/ogkUaFNQVJq6azjIPp/L5Zrj0zkPxoeMjGakCtwgkMoyrtOLPKkL1q0mBstIV+ypPWq5z24x4kPXX+TDO68C8Ldv36M/iOyeOnxfs54otHUUPpLtBzEh+ejFTVJGSsqGc38gBeXzgP8nKmL5p+BJOrrKn8UcwN5tesy5Y2yrXOci7/mVNZwlS+fPvR3GeRLR+QTnzxHPfbY5fErMxFaJenPczXfcfAcZNFY45PXcVry5T/nMPvrqLna3Qo9qailprSV6RZdLRQ2R1fqEemJ474sv8CufeYV/+Puf599978cRWZYmOE/nHcUWvxMQQlLEwFqcCTkrBU6mRE5GiZYSpRRloZiOU6KhTap0CQfSw+PFnEO75kpxkd9d3eczj5N596VLDi07ZOeJNqCVoGgqQj/b6kFZm+UQomJwCnrPSnSMpUZ6Sch2K2HwBC8gCHobaKqCqxducu/x7UROyM/Vx0TOKcRZlUwJcDkmxVIzip66jKgx1DYlc0FA0GfJwShK1ouBR6drbu7u4UNF6ATOet5thyelTHIsIiKIGWdWsjNN2NFwABf3U1LmcvVkWAZGVZKQAHC+Q8YAmQQTSUbnUaTEdZNktX3Hqk+ajL0ISXbIB4LzbFyUlJKowmBMiRPJ7cHHgBYCZy2+yDHOOlYCVsoxyEinBE0UyKLBu4S/UbkjoJxN5X65SRokXkRU3PzbQXAIHEZGCumppKcibMd1Ig94+igQKAzJ0sfFBO/YGOTYeNb68+fG8bsxZvJdP//rptl3e2N+ty2wSZASmN/F79znu+mMnW9lbhKzJ8hA5xaT/7JNwjauyJxxmpDal7VKcIkGRcDTEsmcOMb7+7TTgTcfP6KXAiqNJSCjeqLNP0SH9w4ZAtF5gpGoXMrbaAVKoTBoHJJh0Agl+f/8/lf5A3/wo0QEX38zHWs6BD74YTh8qLkz73k08+yWJR//9AGfv5/mny+87njlv/gVfviF9/EXf+nf4/O/9/f4p1+8zbNXFdNqj3A9Lcze+daaj3zseS7uj4Ce46M24V8rTdjgQkODKiS9H2itpmoEqpN4b3FBbN9nJSTOOYSPDMNA3wes85kck/Y5On5E16akzG9bJ//12/d1Yta7iHbgcawyxizGZO+iVGIzHi8FL1ySFOtImKeAZGLSIloMjtsPVqy95+pBhSgtHY79jZJ9A2FHUoQpo0lF30xA1pjo2WvTrbt0eo1rf7jgQrzG737927y2+AJWrBiGASMkZY66ZSFRjUZWGoxAZCVuJSRVFmKJKtAPIqnXG4HbqHiLiDFQb4K+ThOVEIpgPZ1KK1+D3I5IP/ikT0bA2khVCVofefQwsWx2L2eNohDxXtAPHlnAw3uSdhG4eS2yw9k9LTx0bUri1ngOLgacsxhVMnQJjJ+UtBUhSGK0CGnRRY+uBEUmS0y6SyxO7nP4+ISpVsR6QnACbTxCpOOU2uOEJeCpVaSpFZNyl7K+iCkecGeZ9osHz1KJtzmdr2mmBav5jBBKyhBxOcKVMaAISSBQnNksDeIsGYNcMdskZ/EsaXhCjyj/5TxODM6SrScUtnNAVzzZxvhuopDnddK2djb5d+
rc/pvAe6Zotbmws+tT4qy16mMC+28GuRUgokSQgPGWgFAOfWOM/HDqRzc3LmEvjwm7NcFIYqEoZMGQRXo3J/Mygh1YLBZMdqf8xI3n+fXf+Cyf2tnFt6llUPuQ2aiCgkgnQMVAKXKV7Jy+ko0ki5gQsIMnBEdZJQwnQCk1g5BgHd4JHq7nzJcLXDmwigWnR1lYVa+5tlvR+pLB9nRxgYyO0sStvRMBelvivUGJElMcoJhAMDi1JqoU6EcBWiPRK4OyFjOVvO/FD+Jx3D18c/vQlFYoJWgKmAeJrkfUVcPOblpJSeNwdKjo8ayIk5ZwnBIDERUie/NqI4lD4P7tBzw82GV/LHEuYkPcqq4XEQQGGRRxAzjEI4TAmJK9aU5kL8KFvfQuPZgnlyMfYTQCk1nNLi4RIqKlRimFziKygYCUkmqTTHnPaui2lS/vU2brfNIDAzAiVdqKosjqYmnzBIT3+Fz6jQoWeFbGs1IDu0alXqsqQNX5f0qYPdl1hL5FxWQkjRAJt7hlvaSkUuJReAoRcUQqwpmvbfSpZS9gwOGjTI4WPNmijDy5aNm2H8+Pf56skicm5mYB9N1qZO/azo99cS5+xHPYr3f9L+ftlDb7CpH8NLdt0e9yqm3iGBNTE1IceHfyJkWumOVfVEAJVDHfXjyF0qxLTzlOmZncG/Pa6tsctxBqw1rCsJGRjhEXM/g/J2UqSvqNkHcUiaUtUlRSSp1V62mZFJJ/9tmH/LlPvM6j12reOE7/3/HxwMlcEXbWXLqiYagY7054/ubT7NRpvi6iQ8maX//Wy/T/Rc3/+H/2l1nX/3M+/1tvs7MDOe/n2Ao+9Ow+fhD0Q8aTUlI0E1Ab/csAXuGdyI42BgX0bZ9YvSG9q4Pr8B68HxjckOaX4IgEhgwZmM1mOKsoygql1t/laX3n9n2dmEklKIo0+TRNCiLLZWLlCB3pY2S/AjsTrIMnE7zYG0kWq8jSw5v3b9EOHc9ozVArrsdqa3wbVYvaLZgu4ZkX3sef/PkDXv7cfQ51xcVHxwBcObjIs5cH7KnArTXBW3olOW4D1nvqJmFfopE4YREFyFLgZWAIFoffnq9sSmRX4p2it4GoBFSpQqBlAgoDFFpQaIUQYLWn0KCkQZSCDfhlkMkzzghN5zyDjUxHyWdvPYtnLC8t6PvIuodSg7eK07llDVzPKyQdYTEPzAfQusLR4eyaNi6x7hJ9OE3fUZY4URNQoBQ+SoQ0VEVkyLPwzniHft2z6mfMB0dd9UyUIXi1FbcslEMbSecdIi5AlZRjhxaKt46a1PMBiksFI3uV/vpjVsdzmmHEwg8MMRIyeMyJgCY5O5ADX0+apDbtTHgSUBs4s2GKPLk63rQgRQ5o51sK54Ped6j053N8h3DsuVXyRvtcxrOV7mY/KXLFK57DhWxOli9Sn2tLJAHQJzEyguQNmLBogkHB1R+Z4q7t0x6k9vD4g09hiwm+AKM8MTgiSbpBa79V67eDJ6Lpuo5YwUduPs1rv3uPv/P5z/BTzz4DQO2SsKx3ESlFer7qbIW/qWqUQVDYmHBa3mOjAxlpjNgGcGNKBicQUROl4nQ5Zzlf4K9IWh+wbQp4rx0e84Fnr9CG1LCVRlBaRavrrXhkJUtUOcbIEh8bGn0REyqiTQmQlPmOSQkhIvGI6OmGEy5duMgL3fP0MQX0xdFjAgkQXdcFNQXj3cu8cO0aT19KVeKyMXRSMFsveXzyFnP7AGWPaVcgUFvZkDY6lIeTB0u+efsBV/ducNVavDiTdZBIRCBLQChiDEQtEBKMNIxHeSW/C80I+jX0Fgabkt9Cg8jjx8pAlAohFFIppJTbakc4V6Kx1rLsV9h8LdY7hAs4Fwj5fXBabhOzZCcnCDKPE3eGa1NBMpeWhbb0wmLLVKUISkIWx0YNSFURh4CwPcF2KBHxcUBIvfVYVSpN8FGIDNIXFECdFek3mxOJ5DPESId/Alh/xu88YzCer1BFzmE0OYMTeEDE1I
KUxLNFEmex4QksGmfJ3vazfJ5NQug4IwHE8//ju7eclJ0H/2+27xAmFXFbiffxyd8Lsqo/iX0JSSLCkCpnSglMY4iNxHpHmWWRHuhj7vYLBi3xpUrJiEjXr2PY4vNMFHgSkztIgQoRHz3S6GyKnlrgDoFGMFhPUxY8Wq75h//8bT557SUW4h4Ad+4+4mu3IjevRz78nOekXtLpMQ8f/xbf/ErKuFyQqGnL1Ah+9Wtfgr/8f+W/+2//ad6+/Z8gvWOVJVu60NM0NXu7DaYqiUFSjDymCUg2wsQlq7VDKY3WgcIIqJKxuR8iLhcjVqGHaLB2wOEIUqYCBX4rXrxarXBDQdM02+LDv2r7gcDsD7YfbD/YfrD9YPvB9oPtB9t/Q7bv64qZ9RHjQRmomrQKXK9T33o8jcgY0N6w7iyDhoPdxGzqvWC+aCkKzenDHucDxUjTyEBRQG9SBt5QULmGnfdd5kOfVTzz05YP/6lP8vf+3ltUL2bmWfUy66N9Pv+Vx9x+eMSCI0KlKOsJUjjKvAp0tNQmaRKJ6InS4aVj7VuK3DptRiNG/YST1QTlSyopKRsIbWpRZXsujEleXcJFvIm4wlPoEi00MessdSqgpGftA64oUV2HloG6KBl6gekzndy7xFBSgn6AGD3tAtw74EYbtkHA9QnPx+BxGm4ETd/XHM9PENmpuygnBOGJogKj6F0kxJJCiK28SB0Nsg3cX1g651gFy6SUqCC3UgxNJTgY1/jQEd1AOyywfsbFK++h0VNO7n0QgHHzGLcjsAOoxZJh1+O8oCdgc+Ww12n1V8Sz9uBmdbphTW0+C+d+dx7ku1lxinOr2U2L43xh7AmT8vx7HTnTPhPfZQEcwcvEpNxU5yJnuKMonrzuzap9c87Ndtb6TCt4Gc+uaVu9O1cxJAbcHrxdzNHa8MyzlwEor96gFgVWWlwc6JYznLMJeB3EtrW68h2gEMFSOoHZ3eWHnnqJX3nl8xzki4muR5BYilLEpFogwOqkM7j5Hj6m1gzBI2JMVmaFRpsek4HEUWREn0gPolu3HM9PiYUgxoLdy6k6deuOYb4eMdHH3NcrYjAsa4mWBxQutSgLWeKHESFqymqMsQ1YgS0F0jas8p1tRaBGM5iIU7CYPWQ5nHL90g2O20MA3HpOux7wMVJVmrGqefE9H+DTLz3LhYyPK5spsjQc9Yc8Wh1w9OgOrxSvcnrvMavZsDVX70jYnujhrbsPePzMCS9Mr2emcL6nMW6rWlLKXIlKVkpSSkZlOqcoQQtY2YQjHRzZhoit3EQQmoiGmCqszoWEHRNJFsVmEkeUkq7vGWzHQKB3Hdgs3JwZ5UqCkhpdFlsJmK2LVIhbQd7gPE7BvBpoVYctC1wtUUaiN0Bh3yHKGuFahG0ThiI4RPDIrCeVTloStQSp0/cXkVJJgjrD8rgADalaNuTKzqYafh4+dr6itW0x5n221cpz4/CMoBO/A+LwL9vOC7pu4sMWY8Z3b2e+u7W5qao/UakXPGHPttnn3TI+6l3Xp+MZUShzralISYFWsH+wT3mt4fFwRONHdLllc6gfMVspaMCKSOdFvkepJbAlQYnkWxpDQGiBjoooJdqU6CyNIoTIMS7S25JKCcZVx2dvVfzSzwo+8YfeC8DHPrrDf/C3b/Pa7cjQDcwdXNyVvHr3TY5yEUo0gcooKiRmx/K517/Ayf/+mzz/3hcYH0y4vbydvqPfYTyWNFXN7s4ldi5EmnGBqSH0KY9oVwakInQCEyXjKtJpS4iK6AMxs+wGB0oqHIEo0thJjqxx20JeLTtWi8DooiF+18bzd27f14nZqIGmEQhKIimIhAGKIqDrSBhg2Wn2ppaJgapKQ/KdE1g5uLATePk00vuBKBX7zZQlUMi030RK7nb3+e/8kZ/mxvBR/ur/8J/yS/+7CfH0Jm2RxIfee13hm7tMi0vsXgq8/OaMa5ev43vQTYkN6UE7bShVEoOURhBl0iES0eKHlOTJUrOzO2
Jwl7m+mHM6P6Uf9Sx9IMYkMQCgdWo7SBUpy5IoBLUZozD4/GJpoWlUi3A9nfeM9yuE7SlERJkC1vlg/kwtOhCROmIi0CvmQ6YE64TLU7VAecuogv2yRKwlq7DaMrPcoNCFIOqIc0W6Xi2pJ3vsFWnijJWgQdG5Ba1fo30A71BGEfNUPRIlF0TyHOuWhvXQM1vcYzy6z6WDj9PN0/1azRWj/RHWtbT+AsPhY8zIMhDIl84QoAoJI6fDRp9M4EV8AmO2Bf7mf2+kJ84H062SPjlBOtfyOA8S3hg5bwL8JtieD6bnt037ZMPmUvAEkeB8PN3MdepcMNYker8kMSxl/mwrfnnuAOuY8EBt6YjXDG1T4SQ8vZeO3Iz2kMLSxUjoO6RReO8QuiA6SchYISEUNgYKHOs+IKuBG9cvcvnePl998246VjkwMjDzIExy0tBaoYzHhIQzAxIe0CTspIkefGapGYnIvpZdCIggcSpSSEURHLdnD+m6JcKu8E062PJI8+praz78/JgaiXBrglCgdiiyMHGFxvoSGQyNatDBIKVhGXt8URJiCrouKkoh8VLSaEE/RB6sbvPS7gdpqtTWGY9qCJEWDwp2mpKLk4ZrzzzFMzsXAJhOp3TFwIVVwzOrHd4e7RCl4SR8lbeHh8Tu7CX0QiBlJLaOtu0h4882rTkhRAblJ+mHJCidPlMSCpmua6AhsE4C1S4xMK1PbaotbEkn/4gE9D8zKdc6/f1MsyHSDS1tv8ai6IeeMCR1c5E1iGIsk5F5WSKUJEZ/TqMvblmZkUAUioUY6GWg1wFXC6pRg2w34M01wYlERmiXiPUaOfSIkHTbwkb3SCmkMkQlQQqUTG4HpRJbUoJL7Iwt6SdEWJCS4ESeOBuDmxb7RnoG3iXsKs6gBk8kYeLJifSJpOtcorWBM+RH/QSb0osnQfznVf3hSWbl+Wt74icb4pDA5Gzx/GVuF5h5/4qzduYo71jFs9hxYWdC9cIFlg/muNBwlMlnR8NA1DVr2dH5jiFUSCkxJHD/sNWs8xACRYAYJV5AmefBM8DnhigS0TI5s0y04LVHj/j26Ut86pn7APyJX7jES+/9k/yt3/pthv5r7LgOEXpm7R4xJ2aXpxJRgy8s0Qn2peTle2vW81f5oU99EhvT4tOyZG/3Evs7e9h4j3IcUdrQDoZyA9XQgTCUMEQaoVCTyMK1tIMF45BbZxmyoYnA+bS4CT5r/uV72vc9i3nH5LJEfC9YRL7PE7NKVUg63GCJG3YQKVgWYoc7j+6jhoH5Q9jZh1nGHfVdpJxI2gBHxx2jBtpCcKp3+ZmDj/H55ZcBODUz6jBwYbFm+smSq7P38Mt/41Xe/5PvYX94BoD/4H/1Wf71f/sd3vcRz8MvtAyrgt2dCyitsaHF5PVZf7qGfkE5KtFlSQyewfUMpWa50WBrBTuVZne6x/XdSxye3uW4X2GHAQZPm5+ptRYZA4UwVFVDUVaMq31kMAxtivJaDii1RIklS9kyLsdUfkq7HPA2bKm8pdglDi39ukU1SU+nLCB2SXMJoChA+RT8CxPZaWBkDC5WmAAuX7+zA81YAYrOR6gaypHiwu4VpiIB1rwdqA4k92YN/bBCKYVVgUpV2yhUFJGJ9KyaCuEMg12z7N/h0enb1PUe+1ebfLum1KEnlAWzsaRp9+iG44SFyW+20CAKiFbgXUTl6s15w2I4C4qbwLgRIz2fTW2UuiUprpzHgjyhGXQuGFueDPZPUO83u8cnP9skZ5stvGufFIAFG25IAVt5gA2jU+bE7QyuD0KB1NAWgfUY9m7scu1HXqKbKtoy1xEqi/E1LkCgwxQl3dCDTGKhMldHlS1xoaPRkntti/dLLk12+eFn38dvvPUlAOp+oBKSlQpoIYk+UESBEVAh8PlYwkZqBY2R2OBouzVDsJTI7Y3sg8NEg5VQEKik5H53ynqxoDaOIuO0iqngG1/4Imu1zwefeYHTAmyouFBVFNnTzKCxUYOskD
14C83OCBkaVm5BsVncuB5tQPeKSahp9ZrHy/u8tPd+6jq9zzs7E4auZdF7vLdc3tuh3m2pSkmxl5KkYq+iEpLJ+Dr9cJmy3GfwgYftIcvVMasH6bqkk2mhoCIXimTWraXJYqGbhFiAEkQRcNGhRKL0J5IF6Cz1QZgS/XrLGPYu/dTqjDkqdUTIZIsklMQojdaJCCARqHxOi6fruoSVCZp1vyYMFpVJA5AwZIVKGDOpFNGGc1ZBkbipmMmAC4p1DImJKyK+1Ki6TFRygJCSraiTXpqyHcPQUziPl4ENakpqSdQSIWVOXAJKRJTw2wVMpTU+OqoQGZPYv11I9ueWc1Xb82Pt3Njz5xZYIv/ZLNTOMx7/VfPtdzNEf3eVbht7zu23WSDCWUyI5363+fx8JT2SdBw3JKLNKSVZFkOkg5RkPJmADCemBAop6Fyk8IHRjV2KfoQYKk7L5MzyaA5t0bGqSrpVRxM1NoJTqaK7jTchseG9Figf8FpTS5V06TZM2xhRETo8tTY4NyBpEOEh/+kvH/LTP/kSAKv5N3nvtV/mI+/d46/+/SUf+8glptFxevqALCmGEIHFSjESknIkaGee6VXJ4ZHlW79/m//t//onAfiP/vGv0dR77O/XHK0GZkeavd2aINbM50U+VqTrenSYUlU1shzoVYeqNLoTRLcRoZd5UWSwvkeERG6IQqQElCSDtVr2+I1bxvewfV8nZo9XHTemI2bdKqHjgUkTcK3h+ZMXGN3YQ44Dj+/fo1+uYT/fzFhwoYg87izBwmIFot/hD+59mh8Jz3LtfR8G4DfufJ7F+G2GozGPes3Nn/w47Z3AVW2ZfCIF7xfuNvzOP3iRZ/4XMNk7odjtaMZXeb7Z4VicQm4HrB5d4MHDNxl9+JSAx1pL1w304RRjdgAY+hMWkwdoOcaPYTTaY/fkiD62rGSqBgIcd4GJgUIGJqOaXT1hv2yQaBa5dVqZGcNsTEvgxmgKQ1Llbsc98qFlkYd1Z2e0vafvwVhFM3FJaLXxTDcMPJN0gYyKTBu40mhMtYOSA72PjLNtRnCBbt1SaYOuK5rpATcuXOGA0dbCotsteGzXFJOacVtigqMODdGEraClx6GiYdKMcG5GdyoJHZwMd9FDyaWdp9KzNp45HupdpmvHorlPtIoweGImeA0AWmL7iG5hr9HYhcNJEEHhNurmSHqS554AZG77bERnIQU6SQ7UAaIQ2wlTb6I2aRV1NvwEAgExbBmacL7yJhhITNknWpTnJgONRCZKRVLiBorsYUf+bPPHc3aNHnCjMzbvugEqWLtAV8JKCy7slOjGM9YpidCrgbaArm+R1hOlwihD33c0xmzbW00ZaDtN6w7YZY2VK6xccvOpS7zvNAXT12dfZyyXnChNNzh0hA6XFssqojZVDQ1ms0wPhkW3ZtW3WH+Ay+J9/RKsaqG1tFoyEx17ds5h/4D9ap9JlZL1/brm7mTMZ7+4oBoecuP5XVxU6f31mYjjJAJNozWn3YogFftYirJgdrpiNKRjSeOQeo01mrVZMwHuHz+iv2LZzcqXjzBEISk9zJeWpyrBSAhOuhMOuAhAGVLbMZYGWUh2XMmVbp+XVtdYrU/51uptAOYxIAbQA8wbz26RfAWd1BRhw6J2qCy1ISJEnzSgEA4fA0VIlekqvkS10pzKu6wHmHUgO3BT2M33XYaIlB3OXGRkShaiwBiDcREtDLpMDS7bt6yHFYv1koBm2a8JAQojaWwG4kdDVemk9ygNIgSE8TgPEk+Zs5t59CA9sxiZYencCmkGoqnxJg1a5UZIOcEqhxk8Zr3AdXOi6xKbPZdag5YoDE5XuEIjbDIYU8ow0lnklCQm68it4ghTUjt/xZmQrCUviN4FHRhIjOLNkJSwZTOeXywN55M0nvxdJGvnndvnfAvzvG7apiK23VUIHPGJ5HAzaYtzP1WWyxUkJvQgUquyeFc1TxMpYkrmlMgMzFKi+kwGEylZqiWs3AkT/zrVXk
F0A/cfpaV6qAQulvRtgKFEaIFyHhNEEjTPGWOUCukEUiqcGZCU2CiSokBOM60NSNFjlGGIc2pRMdczdqcjfvtbt/nmK9cB+OgHLnLPvsGLF09pRk/zf/w/3+JyBT/76etcupDuxGmIqU68jDDW9Aaaw4jYF3zz/h2+8aW3AfhTf+xj9DYwqhsezU9ACmTUCKtBpEm2n9cpMPlAtRvxRUTaglG0aOUYstxHdAZ0es9j0AgESmV3jnzPV8WK2Tue8icuE3md72X7vk7MzAC9W3GjFjzOZfBOw6c/8lHe+LXHFO6Il566zOiFH2XVHvL20RsAHPeW0XjN4lCwWkTUjuSgWPHnfukTfObLM37iky8CsPuU59dfHRii4T1TRzm+xYVLgl/95RMuxCMA/tS/81F+43/wDl/+0stc/9Dr1F9tuba7yzXRoNuBg6d+OJ1TH1A8dRvR1KzmHTYGOtfhuznBpDf5kesZHRfU5W62eWlpyoamGrBdt2XvdQ60DZhCoEaKsq6oRiNk77IfHFhZ0OuI6C/QO4EtBlwoGal9yist+iS9Nm8/6FkMiRDVDw63gvEYmgjrJ5Z3iWIlCk81rbAhMl8NiDJis9K7kOB6i24DB5MJB5OCRksqP6LPrebD/j6Hq9vs7kiMNgzesdSRPeWRG/sqpwlSUhtPW0gmoxF2qHB2oO8WnJIYsZN6hCk03thkB6NLbD3GjyIxt628sNAHnIiJZCcd1y5q7h86BJ5prjutcChEbnnEM8zYuQC8qQBsEqzz2lL+XX3K86vX1MD5TrxIOnxkI0qvOGtHbI8jUgK+9YCXqVWj8koTUpvZCFL1x0CsEmM5KEEYRXyRNcp2ayAQ+w4rLLZU2FIxo2MY0j21tuFCcY2mLLAqtaCiL9MKMYRthURoycg6ejcgZED0MGiLlJ6b11PL4Mg9YBlWTIbAvBSEIWIamFQJU1fl5Ll1WZ7ERVwY6Ie0soxe0m0U43XAW4+LPUezlm7V4xW0yxPC5csUeTLY3a1Yr3eZH93lN756iz+6M+XgwjTNVHZDhZeY0hC6EmUtWi4R9oDB9zg3IEx6AiImtpj2EqKkl7BarDhczriQW5lNPaUsj2nbFumg6wZGoxGC1IoHCCEJr4bBIaRHFordZsqNCxc5mV9kmKUJ77Y85PjIYz04GSiNRitBjB4X0qIyxCxUKQUiCmJsUUKCsEh1hj+zTlLFHaRvMe4o2ZUKEB5GZW4/CgNUDN2a3tpkneQcqMSw1JvS2gBd17HuVkRR0LVrYpTISiVzcTLGTClKodCIZBGcy0GKswRIB2gKTRc9rVPYoPClQo3E2U6qAG1QZUW0K/zQw9AShh5ZlMScDAYiUmmkNihlQFqkTEbyG3ybEQmHV8aUlDkS5mxMSoBm5yjOUoiEMZXxCfHX8zI4wFZV/4nPeLLSFt/1O/muz/5l26b6dSarE7etSnHuD7xbUy0t3DQJEqD0WTVsA8kTOUnTIRUnC6kQ3iND2MoDGSWRMRJcpJ2veevrS4brl3hwMsflLksXA9Z1xKgJumApewpTgJA4a4lZB9BEiMKDtxg5RiuNkhCGmBargNIGqRxdbzHeM5ORSeiJRY0YVvxf/l9vA/DX/6P3MLn/gDv+Fj//Ex/itW/d4pXX4Ru33+GZGyl5+83fv0tzMdIJwb5xVEGzLiTGDly4KPibv/y7APxJ/eN84NOWopxzOjuiXUjMrsOuFHaV+rXWrjEGppMxPrSsV0twgug8pSno17k1bwxRwEBiTnsJISSXALlZeYqBew/uEIcPYcrvLeX6vk7MdJ0MvTGR5TK9vQdP7fHU05f5u4vf4sHrkvLBI56dfo1PvL/g4OkURJrGE92YO37NUBlUb7k3zNm5cMrHP7HLmoQf++FPTtCTp5CDZzn+HR67Vxg99QIf/sQe7/3Y0+kiwhj1oZf5yudm/OhLFaa4xSVVsiNL+voqZvKxtFv4FXb3NadLSfQ91lm6tcVHx2qZJkWNpw
sR4hxIE39ldhnXksEd0g050NsEqG4KldoOUqMxCBEYZ4ZA1CMWUaAKhbORhfUczwJ4Q2EkexeTPMLdVYvxa0ybRBldgKGH5kxsPAPTI72PdBZW657BH1OpEq0LyjoFeimS0nOFZmfXYgaP1p5T7nHvOOGO7h6/gQtLpmaH0pgExlUO70GptEKPQCkNviiodQW1wiqHLSy1MYjcFunbFUpBXSpCXREnu/QxJCp/LnN5sSaoDl2AcikYP7AOLaCXgnlub0nOVs2aRLHftCrPtw4UCQe0aS8JIRDEJ6jo59sKcvtHbL9b+hm3LUei3K50VVIa2bbmtYSiFBgV0SrRy3WMiOC3GC0dBW6IFA3EiSSWEqoqiT/WYHPlYxiP6YaerhcMwdEbgdWK3ll6ewKAW8DETKiqiqLQiBiIpsDqnr5vkXmyjsLTCM0gXKr4eU3XBoaiZX8/JS3PL57jlXDCsj9FkISOKwHj2uCUp8zWZ4NIFigmKAiW3jqiEnincRuV1gEG71n4OQ8XCxZLmPqe6HuGwVNPU0Adn9QIBVfLim+edvzaV27xb3zyo2AMIgurlsLgBw1SU+kSnKUXDqcgDD1Gp/tVRok1mh6PjZEYFaH33Dp5wI2nP5T2KRpMUaA0SAuL0wVRBIyS2+AaAmgp80QVsDolR5d2p9y8eIFumcb7ys9ZLls6l99DrdFJqZIht048ERcDOupUuYkqv6PpnSRrSLngkWpCHS0XylPuKc/QJ6P4ypg8Xg3eKVReDwkliVKk5HtLLIAYArbrWa4XhGBYdz1CaApVbzOUlBAGKpP0zLxIFSslMsA7a5XI1NzCek0nKjpV0CsPtThTex1KoiqgrAldi+971GpN7FpEM9767gqpiaZASI2I8oyIIwRKbaAaIEMil/iMV9gkSRpBx6ZKn76K5axFCHyH8r/nSWKQfNfPNDbOwfPIa9pNW3ezD2fx5d1QBcFZjsqmSn++fRrOQPubc28rYCZpXJoixw75rsQsJDKUtiJhbmWKd1lXNRGH8vWtWs/sHRA7Sx6etizz2Hcy6df5KAlaUgtJdJZBQEHSywQwThCrBiUVwkl0kVwZZFRnpCvX4YIlRkUvC5S2WKc4Xa25Oqn5td9Jchn/p/944Bd/6g2+9rrlhRtT3v/cCFkMrFfHxKv7AEwPNHbtqCaROAi08qjS4bxBSotKYYlXXznloz9m8OaLfP2L0Jgd7ApEqJivkgxOt4Z6Ylm3j2nGko41g5OICG3bU+YFSRgkg0tzjhSCJPecOiEqR4DKFLzz8Db0n0JvBtu/Yvu+Tsyu3oR2BuODETczrWRvbPnyK9/kfoSnLgeCAy7C1/uB+p30wty4Yqml5ODt6/zC8z/DuOv4zOk/4e9+9av84kd/hpl5CwDr3+G9z3o6/xS3T76FLd+gUJfZ37/AZ/7pNwH45J9q2Xnx2zw3afi1f2Sxu7DDinV8lsmzz2IfbgTljhmvaxbdkGxWujW0EGSJ61IwFacN63FLzB6VRk0o1IjdpiI4z8w/BmBhA22AhXfsx0h0EteBRm9Vy4UTQENZWApjWbcFLqwgzPFql5P5BhviGNUSQvKji0HRdYFORoo6BfCATaKkMQlVzo4dwjhG9YD0JTL3yqLz+Ci4ZEbIqHD9ikfdwFvd2zw+TAMsuDm7OxUr6yCMERgKeqxvMpA5rb611JRUVMWAJ9nX1KKk0WO0SyPs5PExo1IjJMgCyqKmLJOJtBvOgrerFH7w6CFh54prkmEFzEKq5QNFrwjaY/wm8CZu0waLlg4GIWuEncXNjKk4F4w3gyrXXZBAISJaPhn0k6Bn0lqT8gwPK1LOkO5FAVIlnJwqQMuYV7ls3S4EkmHpoZFMLl4iCklblTgNoYnIzIAyombuDf3SslrOoVujo2anmNCtsrbVYs7cnODcGGOK5AsnQWlNHM4BmQMoGdEG2tajVIH1FuEdIZvRX7p6iW54ltnqqzRDYKWTBQ9KMyoFVZnOuZJgtM
BggAGXV5xuUPisni2DwwVD27ash57Ow8nJCStlma7arZ7WWE9oRgWn9Yj9heXw1Tm/vXOLP/qhjxPW6eprUxDNhBgFoQ60qxJdQK9VwihlwJDwgYBkCBGPwKMoRODR6nirUVQVE6qiRikwEuanCxbtmtqoRGQgg4OVRElBiI7gAlF7yrpgMqnZO0ilw71lwemsTQuvrP0lQ8TbAZuBNDFGgk/i2YRNZSi1TpQy24qmJSC1YiLHTJo9muIQt04+lz7j8YRMSZeWEl2UmLJArVWyfyMQwhmDzDnHrFsSgmI9eIwsqUrL4DZiogNDFKiqQJSGoGXCaYZ0zq0cqpagKqaqQNdjgtHEU4EvK+ROvu8tCFmCcYSixLsB0y6J/QJvx8iMo4tAkCJBCnJ2E2MEEbauBYrUd6xFxEq2WnYuj92Qs645kRmZURjPqlKBJ9X4RS5ZbVnT5z7fjo3zjYZcXdvEke9WNdtUwb7b7/S77URiWtRJzhIuIxN2sNBgyoQhVGWyQDXyjI2pAigP2qaWphh8WuzFc2LPPiJCYnd3MVJJycnDDmuhLzbV2LQ4DUJiAlTBEGWg92tidsGAFLeCkLjgqYoSY3S283IUeafBWby3GK1xMSBWHSsspQrMrcWM0xz0v/wrr/MrvxP58R+Dqwe/y8WrgW++UvLRD1xivJcWZYv1Dl/7yhF7QLeOTK4ZTLC0c4/HMG3Swz/kNb7y+2Nudl/js79R8BN/6IBu/Q6+N1tR2MPlnPkJ7JaaSa+JZqDvIrUaIVAMGa9eeJHISmzeCUHM79Gmem0KxdHiMe28QG5ElP8V2/d1YjZWBXFkeZsV9ZX0oNU08OrnXuZAQZjAxR0oaoPCcmmcAs2sVzw8afljPyr4Ey/9OPN7j3jvqkQVGui4lpW/WzHjWN6jNyueffq9tOsWEZeEm547f2cPgP0X7jC5Bu90X+MbvzpgnoZ7H6m4eeN9DHHg+ChV3y5PC5Zzw0QV+Dhh7taURNr1kGwqALV6hj7OKcsFptBoVVJWkrJRwCWCT9n8ql/Te1gPYIPHBU9vwUZFyA1+ERxeaOoGBhRu7ZHyhFB0LIcC36cJr3B7DHHBUC7xQBEDZjBYC0WVrquSChEES+sQhSQMoKJi5RwxDLice0pMkgp8NOfR/im4gdPlkleP7mzFcfcaQd8OyKbExY4YOmJvGUpBGNJ9r3SZKlESpIngPMoodpuLSD9mZBLbrV1Zoh0QyiOlQGpBWZbYtWPI0SgWhig8Q9cm5ecAJzJw/eemrH6jZXGYJhYvFdp5tJT0MaRyPqkaJs/FxpCbk46EK9vEy/Or2hjPKmU6J246V8I2LRYpodABJRMgWWtBELnlqc4SM1GAKCL1KCVs4FFC451jyOj/ED1LBdVUsfPsU3hbwChZCwljEbkqqEKB7NZEZ3AraNcDw9JSHzTE3L5TZdIHX3UrpLUoKdFaoY1Eao3LFZJKKpyKKCdQiJRwSYWMbD0wJ+MRNy88x4PVPR6ePsRG6KxkKZJlVrbB5FiDEAoRU5XG+UhE4Z0kxA3FIa3jTa9wHgYJ3XzJvfaEGzpQ+HQw1UzYH+3zsLAU6pixgC986xYXd/b40evvAaB3cyoTqMKYVYjQ2CS+a3sKrfAZr9rPhlS9yxXadQBjDGt3Sp8rU9N6ysloRHGiGYxj1a55fHSIDxbXp4ERfYOLiugj3q/xQ8cyzmhDhyw90+zze3E6Zn1pTj9Elq7CEog+qe33GWBq/YBUmhAEEkUUXTLzlgIlCwqdFclDwlXtlAXj8S57zQnKeVobsdsVwICWlhBMrryf1X028hvp74lwtO5WWCfobKBQnrKsqYa0Iu7sQF2l4wit8DlhDCISc5UPEiu3LEtGYkwUBUFoLJ5eQWNyahMVUlcEbzFVjV2sCesFsl2ln/osxkVCFnlNCVqUEVw8E7TVAh0khYJKhDS+QuoMAFuCg+LMjqnfVrZExns9CRjLRastwQaypdN50BeZfBPPsK
pCiDPhXvFka1JskkGZmXvnKlji3Hk2F6vlWSwx2QnGlIKiiigFxSi1MjUJLws5KRtS1UxkmxHlyQr/bE6Y3AC8IISIkZr5sUWOxSZtpXMQZYnSijIohtiipEb4zLzM5X6LoHAaVZVYLF4ohCiIWAaX3SecQwqNiJEQB9ZS0i01U+1pjafK7/21gxGf+b0TukXJld3AsxcFvy2XPJg9zR/+yPsB8OsFX//KEcYoZPSs55aDy+C7gAuBUKVxNtqzvHn3K7x1b8yyDdy767g4usSyvcfDeTrf0WNYOOiNY7V01BOVSwgeEQMyl3dtTLIYIQpsjNjoiVISUdj8rKNQHK0PeXx/TXXxIt/L9n2dmL19RzLej0yi5PB++iodAxdeDNy+I7nxvODadc/bL3varmB8kLXOgme+hs+9eZt/7998g9XpdW66isWko9h3/P3/Wzq+nDzHx3/+ADHXPLLXGV2UMP4WZtTzzI/PALh/csLN8ZxyXOD9QNQvcO2Fn8MPkvuPj1CrNwFo9mvmxYi6X7FGU+qSMDhU0Mg6JSQf+fGP8vvf+ja6VpSVoixrisomORBfMfQJi/IwvImwEB2IGAjS0QlHQUXMyzsvLYUGqUqik0S3RqrAop9x92hOrVMrc3fnOVaLU7y/jYhLXIwU9YD3FT5jZBpdI5zG9QvcEFBGIYJh3QqcGM7o20qhpMYdtnz9m29SN2mSjmGg3ksThi8lzoFza4RKdjN9BM0Cn20urC8IWKQICNmBSBWZejTBMGF3lJJirOKdW29jkBRFxOqBopwgC4cqsp4bLb2IuAoKBc1UcPo4Mr3a8uyffpZX/056Pu1qoPAQbEDmFkcqSbONWjEmLahNjNwA7AUpSdu0FjYrZ0EKnipAHaDUApGzvFAFGEFU6bqEilsWlRCbJCxh/4o6VepkFZEa1oOjLFVqAQGyLOmWA9KUiJvPIIYK0Qxob+mGGd2QLJKC1oROIIREaYWPgtlyRZw2tHni3Nvb5dJ4n3XvWLaO3qeJL5IA7N6md0IoTVCpeliqgmXXopVAo3Ab+x1amr09bp4+xclwSjjtCYNiEJJaSWqd2SwKOh8Twy4qlPLYIPFBInKGqqUB71G+xtklKM+oUBwvjxnigmmdcG0Pq5qR2ePq7op+XnPKmnIF//x3v4L+dHq/PnH5aViviE2XVv1Y1g5mQ4cUIVkOkSbw1ntab/Eh+XgWRmHFmplN9/RCNaUwDVoXSOMoXOTe7VuczB6zM0rJYmlr4tohlIPQ03YrFsMR6+6UyJpJnjCu7e7h45KVnXE414TMkAzObytmzjkKo7IUhSB4iVCCiEQLTZnhAD4EBr/m8m7DZb/L6cVd7rtjTofIMi+A1kOHDz2FHjHYVbL0Umf4rM2bHkLAW0dvB2wU9C5RTHrb0ef3obcd1pYYZKoEx4wiCskDVeTenMFQVRIdG6BEKIUXLU52qcQDiQYeC4QvELomyB7ftqjVDKoqDSZA+pB0GINP1c0YIEbkE07dEikiWiRs5qbWHbaJz9mY3WBKTwSZAZ92erfy1HmWZMj/3piBn9/OA/zT4c6u67w8B+dxYDE8CYUQKYZImSvnIh1TqdQlgPR3baAoI2UFphDoMnWGTYyojZGFzThWl6tZWeohK71s70vwEGPqFvSrgV5r/OC2AnAWgZEG5X1K1qSnHSxFVPjC0A/5HheKUil8tFRmhOsiPkaEDITo8j3UGNMQfUdcD4TgqFTF4Dqk9vQZyuCHEy5O4Xe/3tPsWP7if+sCf+5PBr506x2efd/PArDsLlH5NykLhb4c+chzCjPR/B6R+f2OnQvpnBeugAkDjx8Iiqnj5HRgWloez3pOFulBnJwmrUVTKZT3BBvYG9cQbMIqZ/cJj8D7mCqBLmylfoKIhPySKVXR+SW33n6HTzxzk+9lk//qXX6w/WD7wfaD7QfbD7YfbD/YfrD9/2P7vq6Y/dCHf4lvfuO3qcxrXMhmX3fuKn7oxUs8fOoUJX
uqANYqljpQZqNjTaAYw5fegP/wH/xdfmjnv4cp93nPj32d+dErTItUTfqxH/GcrsY8vDXjqUsND3/vReInCvamF9m5/i8AUItHLE4e88Y3PcMN+B/9mf8DE3ODt4ffJ84kF7IoX2UuY3SPkQ1GzRHGE6TDxsD0YsJMvfDBC7x+/22EKShKgzYNRZXYit7XBJeMpu+7+/TrDkIk2AFCi6qTArXOchkhWHShiJkeVQgHQnCybjk6clgS4PjqrmGyc4Fipjm1DzkZHjDroTQdbV7gFR4aVYA0+NATg8EG6J1IXm+ZreNVoKgD0kYOH6wIIlJXcGFP0K+yCqAEUzTI9YqmiQkMGQuUCBi9lUBExM0qMdOrpcDHwGhcUGVf1MtXr3FyOKOdnVCYImkwUWKKElWkqgBtS4wCKwTORMS4YKcUvPWqR166y0t/+jkA3voXM5bHj1BdZFimlazPJnZPtBHyJjN+ZrNKPq9ptqmeadKKtyxgPIF6JHG5DdNVEJvEopRVui/Rg+2SLMp2VV6DqCFIxeiCYMAihWJpA3qSnvUyWI51pFWOpy5fRq4MwpwgLcRQYnOvOUqyWnwC7hID625BIw2dy0LIEi7vX2BlI/7wlMUqeSSK6JJG2qZ1Yy1eRHwUW6C4zeyuoU/vw6pxlMby/N5NjtcnvL5+nVV0eDQjqVlvwC8y0lmBLyNKlBTG44aCIJKsBYASFTH0DK7EOkFdwMWDKZNSMbhjVJkAwBfMlPVogqw0lRmBXVMUgn4R+ae/+3kAxp+qeLGaMmuPUEHSzlesC8PSOWRwDFnUNQpJrwSdSH6hhVTJD1K0HHWJsHO9PsDoiroaM7drygJOHz7g1jtvUE7SfYiFoDIN2kScHGiHFoY16/6U1i8wdboPF3anCHkZ5zQL7YlKo5RhCGxbc84NgEk4R+EhaJyIiFyNqYpcmUbQditU0FydNoRrN1AuUvctfcYtHXcDi95xsUrt4gSa33hmnqEht/6ZzhNF+l3veqwdzqqLzuOtS5WUDXM3nr0vG8C+EhKGnt4tOOqP2WlhfLhiV2vYtK2LiuA1MlQI4ZBVSbQDLOaI0iThOUAEDTa5ApA1okS+bpVrQCn+Je28QHKfkOkjVHiy+rURaZbAqYDuXWSAbRjI1TGXgfkbYejzFbMtQUBAiKmC8kRHVJw71jmCQZnjXr6FKJWq50InhxuhE5FB5X9DqpZpA4VJsUbpVCEs8vG2mogRZGasRhG31bIYz55PDJGN1JaUSe0pTFzSKXPpSFqZ7XceosMIg4o9QiZPWV/kZ21KQvCENrCad5gqIkyOjLkl4EPyNZXKMgiJ7yVWOoSqqOawFBuTcVAt7E3hV/95YLZ8xEffC8W65+tf/dsAVE2DGimKakBNYTQOLFvLlecTC3VnN79epWHoLMenA0ZeZN2/zd37gaMTydEifflumYgUwXmGAZpxzbqLhK5jZ7fCx3Qvuoz/9N6m+SAAIjFkMwICS09Qjre+/Qqf/IMv8r1s39eJ2S/82Z/l5pWX+I3f+H/g9pI+yFMHkhpBP2rZj5KjJaznlraFRcpr2K0gzCS2CPyjf/E6P/QXX6F//OO8+sWKh8df4Uc/nm7eW1+6xbWfkbz16C7XP3pEiM9y563XmD3zOS5e3gUglJK7b8LnXnP8gR/9Sd5/ZZ+XH36eUk4Z5rcYX0m3OHSn1L5l5ZJY5jD0hNixsDNevPpDAKwfzwliRlMVlEVFoIAYqKoSKQdCSCPxkt1nsTjB+jXd0BKHNaYe8FEybMjTEXzQeL9ExJYiOJxztDagvGTIydTj5WuMLlXs7F+GMjC0jvbRIb6FxYatJVaIWqBMAtpHUdF1LVY4ipDMagFkIZHGge2TU0Cpqb1guXC0mVU2jgWKwM4EtOowOiC9RMgCnQd08ANCFslcOeok3hdafLQo7XHZ/npv7wJXLl/jrdkChKcoClS0KJ2AsABOKIgV1kdObcdC9Dx3CUpd8urn19Q/fQrAB/
7oR/jqP3mFxfIOrYrsB4NrLa49KysrkkzFJkGJnLUgAlm8EXJLD1SMNDpS70NzUyIPAv1Z7klhFEoahqZDKYVvI8NhoJ2nc0Fic0oX0IWj87B7RTFfBfaevs6D0yT4GCc1q5MVS2vpJxXlIEBYog7I0qBsjuAxJMNnqZFS42Ok79YIPyFscGgIalOgyoLTVc9itcI5h/AWIQQmW/AMrt0C4XwMKJUEfIXUW7CNF4rAivGli9xcPsXt5dvEzqFtIFSRpsrAftMzhHQc4eV2YpdCbyeJdGWGHgNRMiphMq24Oh5RVoIiswBnukQoz76Z8mA0p3xc09qW3UJxPE9T7K/87m9S/+hPsY/kpLvNSe9pjxX74wmlEHS5TTGQ7KIKqfAy0EWoCoUYLPM+TRhaSsqypmnGhMUjtIi4leWN268hdnJiFiNXp5cwlaAXHauupV0sOJ094jSeslek1dvOaMLFcIB0FUv/EE+yqwkhbBMz7z0iREJwuChROiBEJMokhltlPbcYDfOlpx9Hdi8J1NWraKW4N3vEYbshCGi8LPBDwKgkZBtjttARgiILZG6SNRVTcpOuJ1k0bXqCioQ1FFqBSsbiUSgknhDP7GmigOg93XrOg+U91Knn4EEkHFxM4CeA2iFcgcITbJFkM9Ytcb0krkpsvhdGVClh8y4lZPm5xXiGMcsRBUFM7MP82ab1KPOYXWdQPTEJ6xrgWMCchB/bjPHNJmHL8NzgxDaJV+TscwIJZ5fP94RjwBncbAthUKSObmk2MVWkTDILFUodMRkfsSH4mSInZSYtCmX+noWEIpxhzKJI2MFAemwygB8y5s7F7bVsv4OUOB8IHpyBLksLmSiIweGMQCPpfKCUGh9TO9Nm94nVPRBc4odf/BFufqzgpH2Tr33ry7x194jReArAuJng3Zq+O2UYDJKCIDtCgCGErYTSyg+0AsYRrl6UvH4U2L97md1p5M2X0zX/kR99mo9+6g7zZcLTvnprYDyF63sGc92yzrWBcJwSr4cPPHVdI2XF7HTN8SzQu80cFIllSsqnBw0eyZ13Trl+qWQVPY6NdIjFegs+EzJI7OC0aM2yLi5SlJp7d95mfvK9mZh/Xydm7XzGsx+v+JT+M/yjf/g3APjk3h2+Ee/xzHskp2tBf5gzV2e2IOjlGuZHgR2h+Oo7gqPwJj/7qR+iF/997v+Df4vPPfqHAKzePuU3f/UNLuxJ/sUXJO/58E0mtma++BLXr70PAHu64ktv9ly6sMdzF3+cVw7vcjAZ8+qdNzGPvkR5OemYHa6PWdlI50/p+xXBOY6JtM7w0sWEmXo8W+H6iBxLjCoQssAUnlKOKYwgjtPkc6W7xvxyz92jNaetZ79eYewx0Y/olllSox4Rio5Rr5k5eDQsaddr/NqgdL/FWnRWc7w8pdw7oCluskuB3xUcdsfEVXqxHkvFoFZMKgVqhFVjpA74weOLimqUJyChk5CLlkgxUCiJ0JLSBewyvfCd0BgFhbJooxFSEFWgl5Eis8C0BMeAQKGkREVJ1BXQYx1bb842toz29mmqPTp3nIDFDiqpWW+ZiBLrLDFYpIdVB7cONS/s9rix4gu/9QiA8c+9xfDMHuLkLnsDLJcW2cLywRnItsjLxKZIq+smamJwDBaU1PiMRilfjLRr4BQowY0FXQxU8QwX4ktwlaeYekqT7kt/3Cffu1ZwcpIrn2WH99CMYDzAo84z6xTGHxPW6WBHpyv01KSV2+oYUezj2kBoI2roCKt0XVIkrSqUpFCO1nliq9FxoNswYoeIL0ZcqAyz04q2vkS0d1kNisEPBJEqrZ1P7EWtddKJMhUeS9/3jLIv6sJaxkLi5JrrF65x6egmh+u36PoSMfJs2As69qxsZMXAviiTDVMhUUhkZhB6DJ4a7xStVkzGY65O9jDTMeVkhMnvc3HaIvQIU6w5GI85aQ5ZHMNce+pM918fO37la1/gpz7ycaqTktCesBBLGgoa6syjhVJECi
+BhsI4RLQEuUZKwfE6MaSXesW42mVV7LCrJPNeUbjI8vCExw9uAbBfN9RFpDEBJwbmqxW3H73J6w/v0GO5spOfz1ixO5qyVzZc9xZTV5g64UaXeWGzCpEdC0J5ehUoJRROYimQ0Z5Z6xQjelezXibgNNOSp59+GvN4xOgwAa+LuqQrSmKxzhVQj48DUUa0rgjJVRUhDQORUjhEYZhKxdoFnDjDvq0i7MsCrUp0kcZeGRQtUES/regNQ88k7FAUmp1OM8iCE9uyPvXQZD0D6RGyxtWSqCAOHmSHW56gFdguJ3llwKwtRMWQY4+Umig9cWMd5mKyZ9Jii6nSCqoYERuNDFLSZUkM0hE5icmk7VlO3vqYjP9sxnltAPtbX1rOti0+9dw/zlecA2cYtPPWSrUi+ZzmGKfLyFCCLlPi5irw2dN5U3AuZdIrEyHFJRMzIcClSpHKUM44gOhVig2rSFizTdI217pxJ4kROpIsRozZGSyk98arGiFlwl4VkcY5uqCQjUJYwaU++Qj82V/4n/AHPv2LvP9nTnjz6J9wMjvg8PDH+PXf/Aaf/dwX0lhcnNI0Ncgd5LBCElFC0vkBqWrskBKZMCSZj7XS1JccH7l0mZtPjXmzO+L5JlVc/sAn3scXvv7Pea2FvasDVVmyXPW89jVLJeFuTlBfeMnSzgSlm1JNVvheceIS2N/k96asFTp6xhNFiD2H9zyVNChjcUOFHbJwvLPIkHUvc3KvjSTGkHzQAERNzZqH7Vsc3p3zvWzf14lZXewxO3ybp27u8PO/8G8A8M63/y737rxNqQPLQaKTXVfSC0tzMKbULOaOInhmK8M3Xm3513/2Cm/8lb/NPQ7Z80mIjmsHfPvrjtcuW64cBEIx47nnJ4iqogxZBbk94a0HhnryEodO8YGdXVaLiL37gJd2LpNxtsxZ085Oub/sefPxIQ/7Y46Oeq5ceJaiSKvcB0dvMakrRqWklBqEoCxqxs0E4TUit6SOmxlTu0uzOCX0jsPDY4IXFPX+1p7CrXt8dCxiybxbc9IvaWOP8AOVgTJXJWZB07kVnX3A2Fzg0vgyKkZ6NdB2ieDQD44ugBocTdPjdYVBUyuFCRGTI0vQkSgchbeYGBBoglJJEHPj+bfyLKMlBpHK2kpS6GwqniONUmIbKZI3YCREh3eKgMRl4OhIwGhUU47GrB6doFRDIQa06tAmq5arjijASMFEgrAwtI7XIrx4c8Lq9VMAPvvZ13nm6g5+eIYbHz3llS/NCLuBQYDfdEWXmuZA8LC0XFOKITjKEnYFzB86mp1s57Fr6UIk7AoGHanHkXIMoVbozHQVRYQi4pXhdCKJtuDg5gvcu3OP1VNLulka+CczqMcwPwU1lvg+oKeeB4/WfPswJRB7O3CjlhRWMl+3NJVGW4lVEi8CwWex3dCiY01Vj5idJI5VGHq8r6k2E5nysDhm5RvasiD4nlLvsYorahsYcl+n75KYbF02aK2RgCNgrWPIenuVqmgHCKXH7Fd85OKzvN7ep5OAriiyOnvZwDALdO3AMiaAcKEqtNYYlVbWxhhkJ/AYqtEeO5carh1cYjLeoywmxHW6X6Mdw76bEIeO6XrMzt4es8WKICH2WXJCwMN7S35ffItPPPMU/WyNEB637vBl0gUEUEZSmoK6Lola0LuEmlbK0A+pYrboWwpZIY1GVzWy7+iE53g2wL30bo3G30ablnK5Q+8sK3fEN9+5w6P5EmMMO5uFRh2pTMRUFTf2LrIsJdSawgsW2YC971tkPUIGnQXSdEoSSBXtokjvYFVVdDHSDw5TjGlMQyE0oysV10wiLqyEQQnwsUaoLnlROlA+gow4ziQ6TBS5xQlBCUT2A+zyYtcNFhc8Rmu0KhAq9/pJEhabqrOIgt5FjNYIVLJ7ip6u97S5zV/rKWiFEBFRlIiywrc6EYaGASnz+xw90Q0o4RIgO73k6XxyU630acLM1cDkCSqQMqBD3Op3BXEmEi3zdYpcOd
vs05HcAtrc4rTxjHEpxTlFi3MZ2ibhkTFVYM8XzDZtUEFupcqkzSkqMFXufKiAqgVeRUwBFFBpQSEFOlcF88fIFFKQIcU5GUAMEIfMrh0koZOEThB6d+blK861nGELXveAVfl7ijMaqBCpOpqI0gIRPLulorceyjEPsg7oL3/rL1Pvv8F73/MpXnj2E3xz/S2a0T3+/J/9BH/0p5I7yF/7m/+M3/r8y4ymhqglEU+0PrmEhHZrdaVFYmI3lefmdMS1m/vISjLqGn7z5W8BMP9PHvBcd4Xh9iGLWlFc7rn3AG69bLh2xbJ6nO7+Nx5rSmcxpaO1ChEjwkHhQG5cC0RANzCeGBYnHTHAzkXLMJQgOmwWOaZLD3JjGO9lxFpHEIJhK2Y3oETADQN37t7he9m+rxOzazcaxs3Hefnb3+Cl930UgG71CrOvvU1BKiH6ssSueoyLuQEGrlOIVcCpwLRWfPELX0Bynaf/yDd59W9+mcfjtOI/CvdwZcFrXw3cmsCXv1Lxvg8u+NGfNpyGhDF5vFxh/QGR5/jES09hg2XmWoZ3XqN/4aPcvptMWA/XS954eJcv3X2ZN+/fp19oymtX+bn3v8Tjk3SsW2+fcKnYQ+YwpqSkUJpCGRAlTZ0qa/vlPWZmzKLaI9gVfvAsVpaS2XZSjH7EuktUerfuiK0j9hb6ZOXjcwCvQ00MgpPlCUOtmeobjJorHLiBk0lKBENrIUraIaALz3pYURtNiBFVKKqMDRMqscGkSCVdpXWyTBEOJXKpyEr8EtZKUVQGJSSiAaElIZfv/cbvRMXk7yiB4JMyue+wYVMOHtBNzfRgwurQEBQo7TCVQHVhe02qLImDZ+4da6DsJQOBXns++akbAHzli4+4e7Tgys5FrDHsPSdZnAZ2RMWdXGGQ3rFTTViWFjv12BmsK7h8oImVo5umez+UELXADpohWNoyVcqUT1pqkHS7CmlodEkQI5r9G1y48kFevOr5m3//b/BgN+0nF4KTdWS/qBD0PCokB8DOKFBlPMStU8GlqxfZK9Lz12OTWL8IWm+oR9kJz0ZkrJmvB6RpiP2SfrlmNLqO3vioMHD0+JDbvuXlu494//4NmqqhFIrYTGiG/N64I2zvkDEZXnsZiUrhnMBlnM/IRawG5wdGRnP1xk0+fPo8d/UDQjQUOt3XsoLZSWTVOZSKDFFQF5rRaMKoShZDRS0x7ZogCya7+7x44wrPXrnOZHefcTlm3T5IX5EVkkAzHrE/2eN+PUMpQXBnTglepUns228fIhi4WVeoPuHmnPHo/D6XVcFo0PS2wOtkUtz1Dqkj6zYtWg7bBS+MLqKNoawbxGLNyENYOLoHaZ87YolyPePRHkvbcTKccOvRCTZoLoxHaJGST0SJN5pRWTPZr7kvSlSj0Z2GvEJv+3VqZ0aV5sUgU0IhErWuqlKlddqMOIqKZe+xQ6CMGiElzWSPfZVerhO/w9iUVHjWNiVKMmuyiwhya4Qd8qdJDiXkG2ltzypbdNlhwDlHrUsKY9LkLiObXvSmZZjcHFaMi4ogCwYEax1YrT0LmQZH3QSEqQkin7UY8EVFbFtc16KyHIN0mug8zvdEEbZ2SjEmEVtIelwxsgWBivxTIpAiUm4WgPk/KoJBUGRcWhUjk013gWSAviLhzzqStVPI92vT6tyMpE1VLN3BzXny4pMk1aE2chgy4cRUJYg1xKzLGTTEJhIV+EIgdaoCGhEx+X02IlWTzADKQrTZHcQLRCdg457Ra0IHvo/ZKzkZyp0nsaaWbNwK3wYBVqZFc8ytbSFkFlBNrOylFqxCYKRK5Fqiy/Q+f6s95S/9Z/8+/+H/+6/w5//UL/CpT3+Ai1crjrvHnPanAPyFf+tf4z0feJq/9rd+hWHlqJoSZRXR+wyqy++P9hS14MaFEc8++xSTKyWHRzNW7cD+TpqvP/PlY74s4LIThK81HL4AJT3RWE5XsHeQ3pvX7lkuSc3BU57jVcuoSuLDYeNAQs
I0TycS23d0a8lomvQc+94TBWRLavQGo7cRZI9JEiVJZWwlpkEKoorcuv1tvpft+zoxm81aphcGrq0vE2TStloea8IafA2qivQzR7sCIcHkyiLCs24D49pgRh23b2ksn2Fyac0f/OMf4m6bhtZXbr3B8s4KHNy7DU50vHUbbs8C419MT+bNeyX3Tz/Gz/zQFYI6YbFS+NMTFsslrzx8THv0NgC3ZnN+79aXefvBkngIyz3Ln/3EDd5/Y8R/+ZXfA6DUz+CkpM+Ot0IKvJD03qFVgW7SaN0bXeTBbI12EYcnBEm3iETr6NdplJ34GbFYEXXNxEkKrwheIFyqGDQ6JTcyQvACZwVrjtF1YKymXB5dJAwn+Xa9wzAEooPQe4JaE2WVLHQKexZ8o044XAFRCmRWV4xRofKg1lEShET2Er/QOCmJShHjGY4mhJBB6kkGUmmJHCLereiHU3qdWmBD1yAKQz2tmNQ1x/YUdI1QwxYEPag1QxvYLUpGTcEqWB4s1qyP4BW/4ErGCv7Ip57nq6/eZiE7qnCB6TNrDt9YMpnCIrlv8a/9uY+zfHvB6vGCwwqemzas3cCycpTPNzBK17XsW5gOrJcWeog6Ddyo2S6ZSxHRMWD6gclogllE5jszPvbch/lj7/9D/Pu//usAxD2NXFpmtuOZUcPN61OiWxGnC96To33/uuSrDx/ys8/sY5TAR4HXybarlAVxlAKlWAqWcU1RGMqyZr1e0rUtRT1iL0fno7bljeUDfvPVVzhte575RMNFuUMvFafWUZcZDyUi7XpNsC5NukoxuJjxNOlYffSMdU2HZ1i3DOMpP3L9fcwXPYMYqDb+thUcqiQ+mwQ/A0VVsTeeMh6n8xVjyWK9QpUFk52SmxevcWX3MqIuUBHKrMNm18cMUeBsidKGslBMxjssV6dboLeWAucFqoQ7D+e40YrLOzu0bmDAETOOrlKakSkIVYX3GjeERHQRoHLVZtbNqXYLqqqhKkdo+Zi1lVQ6ojJY+uiRJ8TH7F+YMajAzAf2ioZRvcuV6UWu7qXYtTttKIuCsampphLnK3ShaJxG5QXX2g2s3UARzdYZIMZkYy6UpMiWL9N6REQy7zsWfctBaAhe0GsQOiXXTgV6epJGeaq4RZMMxI1IWnUAnZS46AlRohFZzt9hrUgG96SktveOvaKhLkqQIuGTNopgG02nwROswInIylp6NN2opO+ga3OAVh5hVAorEigaVDUmDKukYRBTtwJfELwn+mQXlhLVAMGfJUJbPzUgnBEZRIzomGIVZJ/ZbXKVSmcRgRbQbd6bmCbMhlT92pADAuckMUgVNHfuM7fRShNn15MQk7nSIkFkElDUEUrocxVVVKlqJasE2Nf5OJqtNjaFg8KCbkEPgugTXC8OkdBJZK6YRZeKFQkDKnMnIkv/nutQxJA+8zKduxfglULkGE5MDhSb/p0hoiixoaUYaUSXzrcbJ1x4/ioPl7f43/ytv8+1f/pf8dHnd/nDP/VBrj2bbJTevP15rl0P/E//nV/gr/7nn+Pu/VN2dwSuJemd5VKrCnBhp+KpZw64efMyc/uA4DxSwSyDx+qDEcujNbqIXFn19Ld2Yarw4R7tXDDJvm0HhSKoHapJBf0yCd8KSRvDtnJYmkBVKU6OoGmgrgR9G/HC4dyZEK3M1VjvUnItdCT6NFeKDMPoo6MwgFTce3SL72X7vk7MFqsVVAJlYL9Jk+L++CXkSiJHCqUs7drTrqCucqkVGJzDFKkds1vBG0eOePKPuXNnzN/6zw3v/ZHUWvzgSzd48PAN5JuGzluGwVDEit/8hwv2YqpCnJ4+w+7V5/jhD17nuLXsVZ6je9/itTfuoR4f82a2IvrarUfsj/f55N4H+Uw85v0vFPzitYv8ta/+Y/S9HQCqaUnQEufzqlVGjI9YPIKOkEUt/UQRi0CIPRFPqUcIKnQscRthRbdAKEuPx4eksVL6yIW8qhxlrZlT1gwRBBplB2r/kFHZI4
sLTCZJc2WuBk4XxwwuUoTIlBKDwseApd+W7oco8U5gUGgFiMigHCqmdiSAMAodBSZI1EoQBDhpkI3Fl+n5aK2QhBRIRQRFEtpkwdBXtJl5uo4TSjnG47F+wOMYSUOnNTKbpqu6RLqCvl1TmJIbN67zwl5HXe7zzpuO23dS2f1jLzV84rn38juvvsV4PeXqpRF1vQbh+NBHUkL8yU9f4q9//U3KnTFh5KguOZ69OmZwDcenl7G5ve3EPXox0GuIs9z6kRKPZ2M/GASYGGiHAbVe0+z0lN3A/cWSH/3DP8frL6eS99+/9Rq7RZpATvs1L02mlBef53G/ZFGkQf5eafn6A8+j1oM2yCgT00l4NHJLD1oFj6oU7nSZPE0ldH2PDZorF1Nl6itf/B0W9jFHtx/jLx7QtwN97dFGUww964z5qHWFaGA1XyRyQH6+McYt7ggbGDWeUhUIZ+lcy8HVy3zUP8eseAeVHS52TEdZOuxGYl1GVF0yahqm441wrEHrAtUUTHcFl6YTxvUYZxzYgS63a609Yo1ChssgC3YmNdPpLsGdMsuFVhVStUQh0EFx/9RTVAMTrfEIRMZzBWEotYF6RDf0CLvGSIN3gSZX1VbrEzw2V4oaCmOYGYuMUIb8sHvN6jDSIKl3Kg6MYnJ1zMFoys54h9FurgpWDaUyGKkQpWTSSaJRmLKgyD6fM+9ZDAO7vkEoqIQEIRF4ohBUJiVd06pBmYL1asVsvaRZj/HKEPUaTWrD9naF0BW9vAAqJFKIVvmPQGw9f1J7WUqZhWdBK8lgA32umK27FS5YlBY0VZ2qqOJMZHUTI9zgqdhDeMVxP6P3cFwqFrqkW2ankRCJFyIy6uQJagKyGhH6MaJfIjaOBF1yJBEIYpREn4gLMbJ10QibNgnnmdPp/zEbFgCAzInOubYeMSbdr3PJW0FyAvBCYEkivjFXSjYEByfOWpxbzFY67VlrMwtXSwnWgCtTEtaaVNG1mxxIp8TNZxCb8NlqKUiK3E4zAxQ96FXCk3kHsY9EB3EI23YgGZgeM33QZQaT4OxLbgRwHTCoxNK2CqwUW+KCCFmoV8S8oJYoIpUaITqNyDCZNqwoFwNP6wvML7X0tuCzX1/yO1/8HD/5Eykx+/GffAofpsz9mr/wFz7Of/bXP8cbbybmo111RNICe3wged9TF3jxpWcIWEJrEdbS2eRRC1DMVjQeTiPEqudadwT6vTx77ZDXXnP0bYrjxrQE7ZlOp5wsBjQtMgZMXpwAjHcsznqKAupRYLAifS8v8V5s47jMHrQbXToRwIfkknNWcYYYBUpF5t0R38v2fZ2YnbavI1bPIEXB7n56Oh983/v5B58JDC4QOhhs8l2MraTPA0cSUA6CgTZKlibwe1/7DLV/npsvvsRilibr00eKT77nfXz5C7/H6R4sTyximax1vvCN9JZOxxP++C9cYjLWPHq85qh9yOtvfJMvvXWbQUVef5Qm65PH8OJPFvy7T3+YP/9jl5jv3+OvfeM3efONI14YJzVga47RagdNkSx4VMC5QBUtzneslimgPli/ybJ7wDAsiFZRVzsYvYeqKmy2w7naNahiRnA9ul+B7ZA+ELUg2IjMpdhKp5VJCB0ejRaCZpgRTMuuugTAvt5hVlnW8RTfg+4ChkCgY5CJBQupTWFDpqln8dNkbyO3g1pqRSVM8mmzkX4W8H3L+EBQlOkZGmMQUiFFxOORUiIJ2MHTrhfImKooczdnr7xAsJ6T9RpdC6IcQCURTgBtJAhFlIYH8zmvPzxEaMHNpx7z0z/1MVT5QQDefucO02bNc/ca/PFDrjz9AXoeMm8DT783gUuvminXJobeL/BiQjVdMdrzuPUuR4cqYeMAryt66xh8hysEPsRUOfBsWydSgwzJI9P2Kw6HFQsd6ZxlVhT89M/8HABv/D9v86oZqJynGAr6laS8eMzNqwesVok9/E75Kh/xgXceHTIbWp6ZFEjA65AwE7na4qWjtx1KRI
qiwFrJ4B3z2YqhSnIT3aLj9smK9cqyM+4pgyWqNdJBqfS2nUEMVKZiMB1taxEypklUnHks9m3HYeHYr/YodUFwLatq4JkLT/NACLpZbl3pFaawdBZsrpaasqA0BU2d2+SmIESJKTX7U8W0FuhSAh5kS2cTqPZk8YjDOdysr6NNTTvsMhmv6dua0zZjk3xIJAwXESiikdw/XTIZTYlB4DdAOhUxpkjq3jZVdqVP2JyqSAszIXq6sGRSTSlNTVmW1GubqhabibPQ+N4SrWRS1OxdmDDdP+BgPKasK2KTFoJCjyi8QShJKCXaOawWqLKibtO9emRXqQLmpsk8Pr9QMiS31Y3x+KisKOuG+fyY1eGMWVVDM0JpTZl9JGspqGRECI9XGpCYoCiEQhQan1uWgZgqc1IShUfiUCo5c7T5ns5WS7osn1EVGq01dkgJnPDhzG4sOxYIAqt+oO0cj3rB3eA46NO1D/RoZ1GiIkhFRBNMRTAjhHfImIOXY1st2wCtN1ZMuQOW2IcbqA/n23bJnHybewa2BuIStmr7hsRshJRw2ZjA/yGmpGwD4D8zacsirWJzls1/UgXqPCtTiFRJtwZ6A14LoomsBLhNTg9Ipei8pyAB+r2H6AKZiIxyoHqBWUVYRwaXKkw5C8veI5uGaiSIgBMBF3ICIMFtKpq5KhgV9Dq1awctsVJuF/0xRqKMW0mVoAPRr/GqpE+eHQCMMzX1nmtxds1IBkbNhLDb83f++dsAfOYrt/gzf/xpPvnhj/J47vmlf/NT/J3/8tt88xt32dkTKJXerwuX93jx6etMRobHfYfFJOuw3jLdKACbgodhIA6Sw2Vg6eDpeIenrr3Ig4NvMVvnOXsXJnFAxyTB1LZ9Ikr0EZPVeE0Nwwp2DwpW7YALEq0FCoO2IoH4SPE8iJied34/NkXJuGFnOIgevAh4fyb/8V+3fV8nZrPTh1zcv4SUJY+PU/VgfCNhTWZH6YVtomThLXaQjLJ5r6UnFlAKgdeBUsC3Xt7h/R+4RbtzxLVn08Tvujkj/RR/5Oc/yv/9P/4qe5OINQ6tFfs7Sf/q4v5VXnqm4sHh24RgGMsLfO537vC5N5ZoCTrbFRVR8Po7D/gXH/gKP/3xn+ON04Fv/MqC3cWYeUjX7IvIgYpEuYMwRcJLOAs2ooVj2abWol31mOCRCLwuGe0esFs9RWX2sdlFYGgfIINm1C+IwWOVQ7nAOkSsh2qzUBSgo8XZGWsUCIkOyerJVOnlGwdBEwW2KOiHIVU2YjJtXQ0wzgFcekvnAr2xDNEhfUAh6JRD+KyVFBRCm2TTEjzeeoa1JSqBrtKALstIqcUWN2OkQCtB31tWy1OUT6soZyRFMeH/y95/B9uW5Xed4Ge5vfex1z7v8qU3leVNpiwqSVXIG0CoRUsC1EyPQJoGMQyjHoZA0EKMOkBNd4gaJpBo0SAEUksgFYVMFSpVqVS+Kiu9fT7fu+/6e9x2y8wfa51z70vZmomYmIrIFXHz5rv33HP2XnvttX+/7+/7+36NXqVyjl4IONVB554iccwqrxkIRYbAGckgz9iwLU98ZsyTX/gIb/+ymEW9+bE3IKaCt92/wstfmPLIvY/SrVpeee5TnB5EFKLc2WW1alFdxX455pjtcXBlxG49o60dZR4fGHXVQhPLICUxe4oG02JhaWQtBBsIwhMKSzPd43q5xdpgiZPlGvmFyCe8/+7TPHfjEkobcJadzTHXRmPufeOQR4rYGXyQXaU929C91HL79ovIk48hhaYNEhsCMmXDxgecMQQ3A5VHFM+22EnLeC1e61lVM53AyFseHPZhNmYnbLCqT9ExA8pUNtDKEoLDFDmzuqRxbURUpMYkgSVrGybTFuMzVG+AlJKqHmOGy5z0Z9lOThZjv0+eNbSTNvrlSY1RGVmR002l01aDDwplcgaZJddTnJmBc9R2Susin2vSzLh9+4ClE7dZ7d9Ft7PCyuoIb08ymUa/1omrqWvoFRIhJL
QtpYNRXXGCKAEQF7QiKIm1kcclRYaWFnSGSkalmXZMw5gTxXE6nR6d7hC/PSEIaIjroWuWqOqWncmErNIsqxXyvKDT7VMUBSRemFCGTHYwsS2PmZ2gEsd0kCd9snrEQTkl2BaXRfG7QAyGvQCTAudeVlB0O+wS2B9ZxLhkaApyV6CTpEaHDj2pkTQ4r1OjjUj2WEeCHCEgBdxCRZsxKyK/cG6jNq4mTMsJdmmJPI+aglJKNAGHv+O9nLBIMlo0plGUBHYyydTE+38qpvTtAGUKJAofJOgMUfQQ3uGb+ZO4JLiInITWQkifEzgi13+IZPkQgyufypqCwwdgAqQW+oMZMUAzR97KBWIzEIdlyrm3pgyHQZdQAqkNQchYGkxRqRMskCk3fxcZqLVnZjyNElQaZiKklQOTEF9WEflP0sGgBuEUpo3vpVuBrkGVDl/7JGtyVN4j6fKlLx98lFMSh4HpouKbzsMnsn0loTWJcpKgN5nE1wJxrfQbD7nEegnWEZI+x1QppHAoX6O8pA2Bwkzw45YLJ+K+u7lb8z+97wqPvfM63/mdb0aHs/zX3/1mfvZfbfLiiw0X74lo8pvuvZfVM6vUfgR1PJ/JZIIqBeMq3huXXi0jXSR4dIjWe2q75W1vei9jtcuNjcj3Xs1XWB0ajAwo2UFrT97fg11Jp5vO0UOeCQINbS3IemC9iIoCJnFTYIHezrl6zsfvQcoFOmoMSDTOt4co9B8z5B//ktfH6+P18fp4fbw+Xh+vj9fH/y/GlzRiFkLA2g2wln/xL6IuitWX6Q4kzaanH6IafWwf9jgTszuTx0whyAjynhjAv/6tMY/NHA+e3WO2FSPy46e6lOMrDFcM7/6aY/xvv7DDXfdAZymDEHlh9731LgSW0iv6+TIbr77M7z5xA2mh46CcZ1EusLMBH3j6Gu+//S85dTzje97wrfzyhz/D9Z3IQzvZnVEUU2gzVOhE4+OyoSbQyR1NiOUMFSzGKEw+JBOnOXH8QU6ZCyz7AUFGVK3sDZnut5jKMw0H1G2g10DWROPaLKVkrQICKB8YBIvz0eC3EQbRRCRP2BwlLV4F+lZT0uLaFu/iAipUEr8Usc1Yeou00AqH1xobHNLOM0VPKaNKt5YSjUYRqCY15TSmbb1BLDOjImQupYxt/c4xqcdkIvmn1S0d0+Pc6iqikHhbRS6hcMikqSp7BtdkNHVN6xTOWoYqUNwFvpa8+NtxInY/+Bne+OASd3/1GVaUpxzt8K6v/m6y+gCTLBD2bmyzZqDQhk7VsmxLxkFSz1qGXrAznpfKS6RqmCLRRI6HR+Ccp0lprBJQZw4hoF8PqYuKnRu3ue3XON7dg2txrX700hV0VzMVLa7R1Ftjhncd4+sufC/lNJbTnrr6K9x3Bu6xisn+NbyyCJNBJXFIQhGzU72zT3DQK3pszWps8CghsdOSSer6q2yLCuCw7Nkpvf49PDfd5kB1eajbpd9JxPGqpWmbKD6qFcEmLB/QiVwulEDXlklTIbuG4z5nZjyVLRl2j7O6cgqAA3GTTjGh1i3BSaRQaCHJCoNOyBQqRMFSkZPjCGJGK0qUU9S2pGoSYlaX7G7vc8W9yNrdZ8nyPrkRDJdOsLIXywhlXeMlzBqP0p6Blkwaz7W9Xc6fPE2Wx/kKSJwPSBnRn9DEzj8nFG7OTbINe7NdTg5OopRCm4zMKKaNo5zFbHqpp+joAdXYsXvds0JF3pnSLQYYLemk8r/KoMgVBTmZKMhzi9WKXGYM88Nev3EzpWkqVC9qyCEEMkiEUIu5jzIfHZCC662jmViWhx2WVI8ilT9HtoyoSraCno7jOSqNIyBdWJhRa6loiH6AShuMdFRNiD6J6dgbZyMPTwZyk0U6AioKvwZoE1IcvEPoDlJqFIqOMXSNQGQFrow37bR29JyPivcIvIxmkCrvgW3xLpYhbL2PCA7pHQSbvDI9/oio6kLYNpWX5kgRQkTj7CNdmfPOSi9IYq+RPCSOvC
YI0v3BoVZZgEIelq5EliF0kdosDU5IXJINmXMxPY4QYjdpJi1KB0rlyVVAa7cwwG6EoKodtY77duFBBoH2GuUTrtJ6bGmRpU/dnrE5Yc4nmxdZ5wjO/LiFig0Bh0XYQ76UFXGeWiWxSmKVWHDVFCEpZ0gEEldkuNAghEY5hUguIk41lM5hZJcgPcJZbFOAsYwnSd5GK5aPdfn858ZcfeFzfOt3Xubh+x/iv/lL38PP/4ffop8QpvvuWqJW+7R1TXCSZv+A0jZs1C1LdaQC/S//55/kx3/h/8n40h43x7dY7QZeDTPEuuKxU1/Jp/TnADh2TOHHNZN6irMSm2lGASoZGKykeXDQH0p29h3SCIIwZLkntILMWPDJxFw6fBvwQhKCQgpQWkXkdq6RJyweF9Vt+JONL+nAbHdyDXVD8Ib7BnzVl10E4Jf/94a14UW2d19hOo0tsMpEMn2ZFntPaDyWAnDSkHVa7Mlj/PLP7fLeb2l44NGo81O/aml9w8H2mAv3Ce67J3DpJrzp7HFO3xUfKl/56EmcHVH7XdbyAR9++hJZD4anDLbyiAS72klc7FdvTdi6DH/2T51k+QHB3fef41MvR0Jgb9QwWJqgXBd8xdK4xqWFPXOOZpK0wHpD7r37HpYo2ToYcWr5IsfCeQauZiXJYFSmxygTHOhXkLcsPasQ5TayPuROQNQMDMl3xIUoCWABJ9uFgKGVs3iTpu4TRKQtWRO7jJRLG3jhCcpj8qgSXgaYuhZbgRPJ8scd0JQBKRWZ6TLIe6z1MqrWIOpYTjMh4DEYXaC1xYsanRdIOYUpbJaxJCX6y1iRsd67SCE6uGzGKHOI0MfOjXSZIJPLb7eWBOuxAvoKusLTie5biNKw98KYT19+lrMXYf+Zz/HAA29D6hy/GyHwQmccNw7hMjp6Sr6jaIIhH49YN51Fr/UBQ0q5jw82ko8FhODIAqRqDV0JRSvwBBqzz2wbptzmadND91vuLuL6ojNgx+9zVhpK0XJ8UPDmc2e47x3H+fiLsXw/+WDgssroLNecXWuoxhXWaYQJDCmY6BjA2ZUu8mCELCXr3QF7s4JpU2KnjmoUOVouDNFZiao1tVqmyYfcL1YYu8CGnXI8iyKge3JGYz11dYCSAWE9wdUU2tMmrlBwIZKGnYWmpVaabqtxomTPbNElTv6aOMdyZ4+RiNYmXtZoaejqLAaYgHMzGmHxcoyTLdZK/Lik8i2t22GWSmpuUtKiuHbzNhfPzVjtLDEbnqLTbRGpeWZvOsY3dSxFeZgIj9fQjuHJay/wVfe8HYg2UKaQBN9SVg19KahcwIWavowlVl1pRAPeO3IzpFCBXLvIl7NJLEEr/CBwX9njTz/4rTxdPsfmS1s4r2nOCNaPBFPDWpEph5AtXZ8R6gLXF3QSUaqQkv1ZyXhWs9rrUZsWI1QsF/oSncpWutujWyyRZRmzWcmk7JJnZ+hXjuFyLP2EvqOc1SwFQdCRg+RjOxlBS9JtjfWB1jkGyqBl7LjGZCAtOnFtprMRE2sp69gAgJZkLTRaIaRdPNQDkp43dGqFVQW6iCrtrmgZ61hiLa9n1M0E5TUq7xGcjdIQpqDRGYQYrEuRo3yDbWsgcp5C2yKBeh5tyEPS/2FgEXlzqaK5GIvyUaIeqBSyzJX/AzEoExzy0VQK4qQClZxLhNbRbF11UTrHhaitpoSMzUxAIMn/BAuyoVYlUlp6SEpX0U8c3W0X2POSsgyRGyqi9mTPCbpz6Z1WUjSRh3s0DIvl2bCQKon3d7SlCgGMjQEcIjqazE+yRTHTDm+g7njKjqIVCp+aiHSKYp3ydKxEyIB0Gc47bCFwqdVVB41E4rVFCYVRBoulbYAUVDZ4RDNleVkxnjl+9md3+dr3fo4/9RU3+XNf/RDXd2JkVucVsvJMq8D+bIv9cp/yoMWXgs9fjQ4oz2x8gk/9h3/ONASe//SL/Oi/+Ek++cGX+J2P/Co/8Be+me2PfACAFbFEo1p2tn
t87Tu/nNVzL/G+n9pldZnEWwVZt1Qi0IRIeZKhRVlNluU01hOSxEPe5GgBbQjULspisKAEzJUIIs3DSo+d16j/mPElHZidPjagLPd48eYrvPnxSNA+d/4r+eBHBrz4yk1CKNEaXKsoCkuWtGGCE0wmkC1BZ7llZww//K1fzgPv/Zv85D/5YaT+LABGz7Aqp9e1hN1AdgK+7NQaY87w+LsiMb5nHLemFV09oB2NmY4Fb3jTfSBqyrFlMo2Z/GxnSpDQXYELleJDL26wevYpVobnOX/+AgB9M0H7LrmRLHlPbTQ9M8B7h/MV/SI+FPVuh/d+7b2snV/h1371STI95UTRpTM6Ta+IQUSzbFBuTNgS6FBhpcf7jKIOKNtGKQti+7SwEILEO2h8oCVuQp10Q1sCFoHzMfNtgFqBawSzOpB3U7ThQebRjkKpuAk1LtC2kird+M6DMh2G/R7dbj/y2bBoCcIlErQbRlXuUOPaBjKPMhavHELlzFITxMHOdfb7M86uPMBdd1/kmSemFKcNQVp02ti6uaYVjmUt8ZmKzgEemj1JXvmFl+nymkXlEDbBPq+YTl6kevhD3P/wI2w+m3hhuzt0gierMnZbyHQTs8dK0UWwltCDAstMW3IFrkhcFButVrqJm9DXilwJHBZlBYW2XG8Dk+ker1yT6E4KdruBQR096TJTsLdTceXVlyhnlo/+3ofiOfYC2kv298CIGnt6D1EXBJsjpIc50mpBGg3a05aO1kdNnqqxTKq40ZSNx1lQeeCui2d4631vZLXo8cTOVbbrKZPE3fMIct0hhEDTTmMAUHu89QuytXWB1ntqW2OUpJvlEFqkFjTOIhIvrDAnGWTXyYsxoVYE68j6kk7eRaU3E2HeDRgi8VwIHA7nS6q2oqwiX61pA00d8JOKva1bHLuwzFJ/jdF0l+PLscHh+PIqV2/fItMSaT2ujR6DFIrrN1ueGl4G4KvOPwR4TNsgTIGuG5TsUAhPmJPZradpS5wOFEWXvLuEzwvcrFoQzUXuWJ4O+eav/Are8t/+PcS/+av8wu+9gDxznKXNGnu8TfPgIavYrvdpdY+gK1SoWSGjl7pAl4Nhw86YtDV935J7gxee1lsyxMLOqpcVDLo9OkWPujpgZ/8W0/pehidP45Nicts0ZFkH65uIbkpFoQyZVFHuIkUyLoTkK5k4U4gFH23Oq7KNoyxrqgBCKoJWOCnBRQR23ukopMYIBYUhlzl9KbA6Q5kOapy8X6sYJHvXIq0FJXBWoIRAmAKpE6JpMnzrkDL6LXrvo5RDYIFy+YSCJfCLkDhiC3FVDsc8MPNH/nau8A+H7ynT+0txKAwrJJgk4yPzHJV3QfdQppirhRHwi67H4D3ex8qBxOCkwvuSmfdI1KIBpVZQeU/t4mdMgSZ4PGYOVsZAKwUELslfxE+bNx8cYmJzLpTnEEET3Pl7JzxWxa7MWsXOeUUM1uP8BHRw6ERo99ZjvMQTr0NIiyL4OU9R4EPUN7REYeK5bEmQUZhaeOhn0LqM3/ovJS+9dJnH3vUqD977AAC5uZtr0wO2Zwc0e9uMtxq2DgKTEvqDOBH/5H/8d/zar/4b/toP/WX+3Ld+Lx/9uX/G3/uffoJ/+D98mO/59g69LO6Dl64fcPykopN1WTuxz9Nf2KLTd2gDbdrfBl3YG8dqh1AeKRUiRJGTTLmo0UfUd5MShI/+pCCPiBunwN6p5EUqjyja/dHjSzowe/qZ53jXV97LK69cWXS7nT9X821/5iKz6h384n/4CEU3Q4iWqkyRPqBMi9ZgejlF0dDTmubWCzz2F1r+7NYP8+//zfsAOH7/J3n21ZrTJ3O6subaFbjwZsfg5HEeOhcRjdlEYIVDh8Bkd5/VpVUeebRPqEuqqqG28eGz8fIm1zducPbYPQzY57+8sMO1KyUPnwxsjFJAUirUbsHKccnS6oDM98DNMN0uSncQqVwzDDM2njvg1NtO0T2/xZUbn+PuN72VvM1hP5
ZPZtsfR4x36e1PyasewgeqdkRTWcQkltggdosEC/hov+GY3+iH89wGgQ2xxGlCBIaUhEYFChUf+AD4qKnTtgFvAkanLNV72vR+Timk0mgjMMahgyBTAtUO0Ec0a4TNwIEXEwSOpnYIOqBHix746a5na2OLteETfP3b7+fC7Ye4PX6FXkcvNJh8kKyYDjZUuLJGtxCmEEpP30E/SSjk2wFdCfJaUYiM6pWSSz/7mzzyV74RkaK36c1rNDNNrkqMXUaPK4Kw+EqggmduSVkJwcwLykJSp/72YBRG5ORpNzU+NgcgA1YGCm3JdM44ROX8Z9ttAM6cy9m/ArYGJxxdDSfPHGd3f4/dMmaKhQLRNDRtxkbVMPUHDFWBb1s87aJ8opoOrR0htEFnsQTprWcymVFMIqoWMEhZs5Ivc8+xcxxbW2Gtv8IDUpLvbbKRuoqE8xgdtaeqeoIkxE4voF+kuRcFVV0ipWZW1ZR1jSpygg+Epllc70yvsCxPYeQtWlHRyQo6vS4qKNRcz8kLjJZ0hYqWWEoRQovzFS60CzPt1nm804S64crllzhx7DRGDHFaLh64y6srXB9tUjVRYdy7iIp0rKOjDM9ciyX8k90Obz91L/sdg5zU9LrLKGEJtNi0BpVoceWYxrcUokCrHqbbpzOqmM2lGtoeF5cM93/1t3LMf5TB9V/j2NpDlLOWWjnqWUoidMuGnLB/sMem0ix1M/qFpMijeCfAcdPlCbfJXjnhWD1AdXrxISFF7DJMN27HZAy7PTqdPmHa5fbuPs9dfZblZc/K8rwMA13RR7WCTEa1fivBS4FBUiQHBKMUsxAN65UDdIgoiTpsEAghUFUN3gYKFF1VMJMa4S1WtAuWuZegdIaRkk7eoZANopMhhKFNOoAuV7Q3An7P420NqiCI1IBhDN4kxMx3cdKC0lhbRVHSEJs0jz7+QjjslExc+kNy/NE6HkcCOnFIlneLN+IIHhXRMy1END8PsXMVIFMGpQwyy2M5c+7pJtziYe2cQziPsx4VWlofhXCnto6WYGmxti5QiejlqTzMRES56uBp0mu09+jgCEdQwMXXIVi2GPOfq1TTDPMTJZYvGwWlhqkGL3JKrbFKUiQkLEiBVQKRKgClB4tAChl9VtPF9nikUOBDxPJaTxNaXGsXsaJIc+YcCKuRecOxgeHCuXM8+fwuH/nsUwB81Tu2uf/4Ei/vXuPylZrZvsDWAj+T2CTP0zsFz16Hv/WjP8uHPvoF/i/f+1f5sb/+P+D0X+FDv/Nx7lq5F4CP3niW42vLyLDLb//68/hizPIq2GlMnuPiKqin1YISM79uEoFUetHqWye9vSwkN4QQEUScxbt51yqxccKHO56rf9T4kg7MfulXnuGFy5f55m8+zytX4wNjb9Lj/osP8u3f+hgvv/Iizzy7STGIPIEqcRiMt5gs0DQN/VqSFy23bpf87qc/ytd899tYX/8RAP7Hf/iPmYYv8MztOprEanj2FcW3fcVJljrxwtwsN9BhGSVv8/kXn+D4qYc4kT2Ms5uU42nsugJ++/rHKS4ccOr0vcj9mxTZDp964QoPH7/AoB8vw+PvfIyl3R4XMs0GW7RrEulj+/xMCQYuvlcdfpdKjvmd37jB9hMv02w/x2jtP9JpL7D7ud8DwM+eoF+skE2X0OM9xF6F2gwUBxAqv7CecC4FZ0duzOTGiku6UhmxM2ne1dSEuICtBJEpZLvYuuhYaDOwxmO0RProTadi7MlUGWzwTMc71FWgV3RQgw55N0MncUxEi3ctroHgOmTGo/U+Sigy5TEpws5Nl43bM5575nneePfz3PPoo1RPjXH1Dl0VyzWZsZSuxjuBCgpRQZ5azgunKFJ52IwgtwFjHUqWdI1Af96z9W8+xPo3PxyPq1hhaXSJPF/Bti2+7aKFJa8C0kvqRIeqNFhvsFLTygapugSdI2SOSbeccw6vBGiBMxV9kzNsZ+x0OlijU3kGKjemVtAVGVNRsxRAWN
jY32CagreOhCY4cuMY7cPuuGK936OtoLZh4S0qaHGtQ5mcLMuQSVqhnMyoUut9p7tEa2eYToeVTp8gAwd54NT6aTqhQ03UV9vcaelKCyInz4YEN6NTaA7GzaKNLcsKTOKhNeWM6XRGJ88RLkRfzYQCZSZwfOk8K1sb3AivUEhBhw5BHRabolE6dDUUuUYplbgbJdY1BJ+sgVyN94o8CLa2N7m1fYNzx+7HGIPP4v2/cmyd5d0Nbm/tzrU2aRvoGOgrz14KqD535TL39o8jdIdsqLEHgr5R2NoxTvUtZQVtcMxmM5ZEF+UlPQr2e6BS2WLSTFi/uILvdXjyA/+IZ5o+Z46tsiOmyELSJgPVSe2xLaj8JGsdw6rv0A09ykZQpxs00xolJXt1TRZyvPfRdUEk1DuFEUYK+nkPnXUxQlMUkpvb13jxquAeG9H+zvIAm6+Qa0PjS7wj9vpbFwOTFDnH7sr58JFHmhTrF/yrAHVdU9YzOqbAKBElFVzsSl6UMoPFEhiYDsZolgqNkjm97pAm8ehE19A2HtF6OEgWSyiQHhckKtmtOd8BXcevRmGJ8Z/zhwHUYl/jkGfmiPvcUcX7xZijbeFOBf/0q8X3EJIj0UJAl4WzhCap+ksRddG0jvPmIq85vmYuDNyivMAEgQmK3Eu0kNEvFpDB4cShEK6VitJ6Gh8WEhdNcBgfr8n8jgmCQ8QsDUE65xSMzX+vjgScDrAy0BhoNEyNohEiCpHPO0pVIPNRdqYWEIIgSBXlcrxYyGhEvl+UpBAi4L3F+VjWnseLUsqkixaYSotCcuzkKg/cf4KvXbnAz/9mtEf83z9wi3e/YcRb3noPB9uXufxSSacIqE6A1JXpbMn6qgSp+NVf/zxXLv/f+ev/zX/PX/lz/x3vX/sUm89GAln96WcZ7yqWTyqwOZtTRz3LMdQk0JPb2zHIV0pEPqlwGB1AtChpCKnyobUm+OiNIUMS1/Y2XZvU4SnjXP2+LOCPGF/SgZlX8J9/eUZHbfB13xgJgFs3HdfUDU6fWeG/+u538dP/7P3sTsAUgjo9fOpaMOhEj0cjHMsDePb6iO8YrHDz5dt8+bdGJe5i8Hf5sf/ux9kdfCo+0BTIE6s8evEYO9PIyZG6RbKBLaf821/6LP/Hv3aa8+tDZq2he0zQTcjNk6uvYoaBe88vs9dMOd3vcH2v5ObuHr3+MgBveCTnLeosT9zeoTduuVWXnFs9j2hqCuHopbLbin4DOwcvIMptThS32b0FV3/9/Zy0OfWVfQBED0KvQc0OEHsT3JZA7AWKNhIb25QZaH+4eS3IsSQOQprnOdF1nmk2sNgsfB1JuhAJt57ohdeqqMOGBNMDkyyZFIYRnllZ0k5gUlTUvqRQM/pt3HBF3jLEoEJkgyAVUikU04jWza1B8hZtJFvXpjz/7DMMv2KFC2dOsXF1H5E4ZqJ1iBCJp4UL0AayJp5E5iTZ3G+uDeQ+tqNLD1kuUT1H9bs1ZfE8AMv3rGLFADOpEaxifYvLDT3vCa2jmhtuB0krFIQCsi4iK2i0AZUh053vRAxsrRZIOlgC+50cpGNmJKGO6OhzsuDU2ZP4a3uM6pq6AjuueP7qC9RVTEZWjcRqTwfFCMe43EOttgQqhOzgUxQuZE7HFFgRW7mdiyKrtnGEOZKnFTkFu6HEZY6OKGidxnYFS2s97hfRwmpvvyGIA/JOF2Uz2kYiaKhMoGmSzEq/T24KvLeoogDncc6Rm2iBINIKC6ZluLzOhdV72D24hROC3Au8FAtNNGstCEumPbkyCBmwvsHRoh3MTWJDcPjGIYOiaSxXr1/h1PpZVnTBpJe09bTkRH+ZydYudbJSMW2UB8A5hgkpmo4tv3ftZd55/CIqk2gBVglae2iS7VEIoTgoS9aHQ4pel16njxiDSNSJg3YLs/J2mmrCC9ub7Bx/GF0qljoy+nH2IkVBhYxe1sGcWKOXW07IIa72XHINdSq06SxjIAs2DqZMJw15z2
BFCnxEbLgAyPOcQdGl6PSZdAQ9oQmNwI0tNvmwyt4A2YVKlrjGJnJ6KlGGsJj7EMLidypEpQDhQSWdwfm93zQVM9mwrDN0rnAqoKyMIq1pj2mb6PBhjabXUeSmgFyT9fr0egmVaAR26nCVRVcOV7dkpsA7S8Ah5hZwdMDUCFsTlMGLBh9cDATT3iW5E/kK4TAocxwGlkf9LY8iY3AYoM3NveN1Z6F5RkKz5qitdw7hWpSzBGT0SRIpaJmLjvoQXQuCRxD5YyoEjJAYacjTQz0XAR19pnBETlbtAw0em9ZEmNdVCYs9esGru/O0FucaYCGOmyhji3NtdXw+NFpRZmYhlTFH3oFFOdIHDyiEDHgfz8e3KTNTMq4bH6KOXXJnAJBJQiSg4nwHRx4k+bHA2eNDetKwObM8/rZ7AOjop/jcqyXXD17ma995L3ef2eCX/uOUdlIzWJ57ASvy4GnrlrVTcPXGJn/zJ36Y737PN/J9f/n7eaL4DADdX4VyMmWpXaJ1t7F1i/MtRQ5VSrhGk5bl5GTnXJzeoCxBVISgEMTnWaYkXkk8ApzDWkvrHS5Y/DxD9ZLgPATNnBz0x40v6cBsqVCUq47f+s0xZ87H8sPp04Ybm9sgTtNfhm//rvv5D7/8PAdj6PbjYi9nHpMDbazhn+jA0y9usb3zDGtLD3PtpfjAeNe71/ip/+XH+Yv/p/8Dne7LzFp47E33cG5thSsHkXhdNGOctly5cp1bN+DWrZs8+kiJ3x/SlYEzx2I031nus3r8BPevPcAzo5zVfcELtz/K0+NXeSyLGewnn7/OQ191N53ZNlLf5nNXrrBy8jgns3Wy4JGDPQCWByW9Ucu47XG9rVibQud2yWxSUsx5BzNFuGXRBxOYgi0VOLvoyFGpS3KR8RFh5bmH21EIX6aNbJ5JzjdaE2KXkFhYXcUSaK1hpogiuRkUXkU8F2idpBGSmTVUNllu1DXKVgsBwFlbcvLUOsOlnKBKfBUNjWs7Jfjm0BBZejITGE3h5vVtbm1ucbHTp5s57H4qNVUl0tXgKmQ5pWhBW4OatZimJVHRyAUYqVHOYdpAM3OI3NBXLeXvRv7Sis9Y8QI/MWSFp60sQRsqaSC0uPmcBqiFxKORWQfR6eLyDKeiThzEINhqQRs81mqcLZFFj7puybKGgyLO14U3vovu1pSnL9+KULkKNGVFvb9HdZDKd0vghQIZlaprO0LqLsEvkxtD7eLxy8JDKXDCRQ7L/AK3bm4OEJX7haNsAqX3CCy5tVQCOp2Ck8txga0MhuyPxuhM09qGIAXBtnSNY5LQPkePPM8pS0un08G3DXVdU2R5XFNp7VkahGk5vX6cW/sX2TjYQIqozzXnq7S+RciGzDi0VAQqHFMCFiNY8KGk1NjGJpcwwa3r19k5d5N7zz9ImbhVwmrWV49x6/YGk+kMbSLBdy8EuhpECvJyn/PSrR1WdM4jw7sIGUxtFGmWKaiUSJogOCinuGVJ0euie5oVATtpd3Vjjx6sI/s9GinoT69TrXwZZncG57us5DGT76oOajBE5gotx4iBRkqHHyvauTm5z1lXmpfHr3J5vyBfuh8jimgaruQC1VBGMej3WV06xs5Wj0L7yK9y7QK+miuTB+vIpEIphTAaoRW51OQLlStPExySiMp4IdBCIMShPpkSAWcbGu/ItKTQahEkyhQ0zPeZsqkJStN2BD7v0l3uo4zBJrfwzsoQ24ypq4psCuG2XWxEwsFCdj0UCFPg6hwrNRZJCO6w9Dg/enEkGDkSlN1R7uQPH/MmiDlHLYS0X6Y/VDImu8bGT27blqAtoq0RCKSN6KEO2YLPZeecOIhm7anjQKWSaCcJXnVQaGGRiWbShqQv5qFKKroGT5vshBalzNcEkXfMB3eWaOWR/4+8MqgMVLnGSUUrYmHUzAMz73Eh0CKQqYXVh4DzHuHn7LZ5+0H0SXXWRl6djM4FLN7KE1zkuhYdyT1rQ5Z6jtuzA+pwwG4ZFQ
voatYOBGXd8G9//Vm+7OE1/sZ/O+Dfvz9w7Uo8+uUlx1hCNwcaKApL6wr++c/9Jza2LH/xz38dAPdf6HP10gT8kJycZrqNSNJku7EIgTEpiHch2WbFINO5FqFapOikuZfEBkwfGzqCjQm7d4cnaS14gUIhlQEmf8Rqi+NLOjCrWke3pyiWHJ/9ZEQPwluPMzx+wNWb1+mvaS7e3+e973mAX/uPz+Pm/lZZNCEdFuBUcq7fVzx57Qn+1BsewKQ78fKzO1x455D/23//I7zvn/5TxoOX+Oq3vBmtSmwdEbOm3abon+TqixP6Pc3NW7vM6hF5VpCpjMrHGt7O7FXeupTRu+sUZ2zO983u4jt6b+Lf3f445Yl4XK0Yc+3qPqMKitUlHqFLTx8wmK7iVItN3W5l73lO9DbpLQ2ZSks5gl4bSyfDtBakdbg9YKoRC/o+i8BsHtzMl0/8eVh0HHHkNYvsMsH3c/FFCRQLRgH4JkAZYfC+homLAoVjK9CJ7OAaT+kDso4clejVGY2scfHhMxpXmO42lczRylC1hjxvGE09vp5ibXyd9knA0gZ2d0dMNvcR53fJcQs7HCOjVIaqPF1v8LMWU7dkDWStQNp4lnFDiXNkgBJYqS1CCEzi99SfM/Q6I4TMaMSEXqnwOXR0hhAq3YzgGkdlG8gkcpijZQGmQ6PkQmAWJbFaULqWSXqo5VszuoMuVlnaczFTPN+/i5sbX0AaiawCUwvTg4Zy4xoiiUjbvqaeWkxfkueeWX0bKxxKKZRskDZuIl5CsLtYHVBCRQHVELBNi0t8iCahEsqqmElqT65imQUhkb04Xyf6y+yP9rB+BkERZAcVHIhAG3u9KMuSLJt3VTqU0bTOUzU1xhhINkNeJ9PyYYfzq+eZVZZSBTLkQpzTCoswkVIQ7b5qXJjhXE1dTZgmHZLRNDYvOBwSST0JvHr1CvddfDjKRxCrdesnTrC6s8HGlau0SSVUBigb6KWStAwNysHnt29yoTjNUOdoqciyjBBisOCEJzSOsixxLiCFwSAJvQI9i+c3aRzC53RXNGL1Prae/hQn7lmn3d6DWtNNXUkr+Sp22KEnKoTu4r3DqoiEztLN2AdOdvuA5dXpBhfKc+jMYESUlJkTwo2EQafg2Mo6l+WAwAxjBE2wzHyy1bI1uYOOLGizWFbKfOSXYVTkbhB/bnzi2bhYuooP1MNdQgkFzmMbj/RRINgLgXSBJoTFg8aK+D61FfSFY7m/hs88vdWMIq2VWijEumQ89fT3FdxuooBuiILVfq7SqTTC5ARpCDrDSYMLbUSD53vSvGwnDknv8zKfOwIpvRZdOjq8P/KaVCmQR148V8uf8/ucD+gQkwqZdkch4r02T4RdhInAB4RUCC9RXqCJXq5Feq8uihyPkhEJtCkoq4U/LG8TRdIXweL8uAS/L0hdIIfcWea16S8bGagyqHJJaRRlJtDeI4KnTWRi4WOlxQloRcB4cSgH4gNzadUw7wsNAevmga3Ai4CYN5a4mJ4JA/1lQWcIdQOb25eo/JjN/XiMe3vQVYo8wGpH8skXd3hxE979FctcvhgTz4990rGSGWrZkoXo8amoWD4Jv/iLv8EpGV1e3v1l38D/duk3KEyGDAfkDpqQcTBpmEySJVPX4mM1k0xFjqtrRURgg0fMJaKswweLdY7Wt7TOYm2TGlHmE+wBSaYLlM75kwRmrwvMvj5eH6+P18fr4/Xx+nh9/P/J+JJGzDZ24fwZh2lJkSh8+EOSx7+yT3e4w54LZPo4d52XfO1XPcT7f/NFAKR0BKdwwtGTgr1JTndQ8bFPf5pv/fLvYns/6mTlWZ9LVzd4/E+tked/lX//wQ9wZrVhZ3qLrtoFIOsaJvUNsuwEq8uSm6/e5vbWDY4vr9JgubUbp/j5l5/j3Y+9ldLP+PI3nmXKiKaRnGzu59cvbQKwLO7i5OkpD6mLPP85Q/+s4O0XllA3Kj74/Ec5dzLG0Zkb0eQZTu8jDhrMFH
IPZszC0DV3MJ2Bxy6ywTmzJxpNHKZ8R5DlO7qOxJHfew5RsrlViYnKRyT5nugx56HvYot1IQJNA8FadBLkm9qWTuWQdZSnkBmoNtA46PVTRiYc49JzUFcgBb2OpNcVWCSugpAEX0Nj8TYQpGZvd0S5dx1//jhLvT4HZeIATh3dRqIagZp5igZUqZG1BXfo99ZNjcxNOvsizVEeoDON2VG3PcDXCjpdemEA7CNchhYdhBbUSSercQ3SNZg6I/gcZTO81ahw2BkodIHPFNQlU1eirSAPXVZEw6hzksGxyOXqN3BzYOgsLZGNZ9hCsbM7o38y466TsQT+yt4mw2Xw2pNlmvG0xHkBUtK2Bwg7TFeyRhqNkh6lBUqZmM16d2jA7B0dnyG1Z1KNsT7yirQO0JT45Ot4cqXDlZ2ccTVGq5bGeowxSCcQeXyvclIhdZRwqJqSXr+D1oq6baOnYyrFTKsZ3gnyfIkTK8cYtRab5QhfH/rz0SKVQ5uAkp5Ag3Ul1lmqekLqcufgYI7kAN6TBdi6ucnoYJ/VlbX5NCAzz/H1E1y/eZ2yjKbMhYOpWNDVyFQgKGgqw0d3X+Lbi7dQSCDrYGwyTTexFAOQSUXHZ+RZHy2WsUWkV/gdcA0UynHszP38+su/wpnHWurVYzQ391FviIhZp7uGNBZvSnpNhyAtMxHoBM9mElUNONZ7K+Ryme29BnemxVsbu9uEZi4+qJSik0fJDD0oaGeQ6QIZCoJNZuFBgQLrS0KbpTJ2iObWrUPYI0WwxOuK5ZyAaz22dbj0Gukl3kMlFEFnaJMjtEGIGuUh9Z9E1L11iEGPzE659YVXuO8b34LUB+TpcTQbGrISyl5A9DVCRYs4RMAIRZvK21pqtMqwSiN1hjSGVkaD+jkw7V7zfSGZEe7sWjzKMRNHYacjvzuKMvmEms07O2PjwZzKINLy85EQ7j1CRmmcefk+yo6k93MB6RUSgxEOHRxZQgWzENDEMjIEWhEbQergmDf9epkO5ggq6I+UNY+OOXI455aRXjtvLmhV9Omsk1SGEFFrMQhHAoliN6mM6KewMgJ/PiRuIsjEq26SfptrLcGbVO4MHKGq4UiNq7miu9TBt5qDgw32bcXmAaTGc1QwbNuWvoFOX7LU80z24Nd+fZ83PBqfi1/5ODz9XEs2NtCJAuihieXfbKj4wJMfBeDvf8V7eNObr7K8mnNz63P0jWaqGnYOWEChwoLIJXhPUBCcjxi8misFJm9WG6sNjWuwvo3NOKk7eIGqBoOSGZkZUKg+sMkfN76kA7NMKpR15MuSqU1tf8UWH/6o5+v+tKY5mHHb7NHJJcvnJ7zhDZHU//nP36aqQa/AwU7ghK5Yy+DVmy0f+ty/5rEHvxWA3dkuw27OjYOac+e7nDlvuGttnQ9+5hf46GcvAdAtBMvmPi5fyhGZ5vqVGa+8cBv94CVUFti4HetN0/09SldTbV/iwmoBb/Z87qVXOXb8Nt86iBvzNeG5+lTGU2ef4m1fdZqnPlVysvcSg9PL6Oevc/3F+MTQ9w4pMkUYvcD4JvQtFJVG1/awZDCBnDu7jywRQvfcyZs4qvUTRyT/zm/8+a8E8+Aufu9E6uYCdlUKVCaZmai+vRxg5OBsI7hdxYMYZCU9V1MAUwvWxRbwPMDevLvTSlzV8tIBdMlZPjbj2DIc031CBuMy3hRhDE0JVbDMbsP4lduMzm+wvrTCsZV4VNPNfYQtkd6jgmMwldSVRfg7yxGz1GUWVbMDbdrYFGBVnIlKGLKQYWxAlCW2NyB3IExJQKPS7MugyUVBNmsw+5Fe1wqLDH6hf1O6iipohBYMrWK/7WPbKf0HTsLJVbpZ1OXbLF9CDBTSBEwBonbUFvxmxfEH4no+EWa8Wk4YmIxR27DtRlhh6YTI9VOpkaCUUPS7tNMKdInSGS4IpG/wqTsVWdPmAdVKfGtxSmDyftSaKwImxLXa72vOHxvyXD3CTXUs0+Q9Ct+ylhbXrXyXJt
R0swwlCtrKY3KN1rGskyVPTekUVowYa8lweYXzYR0dFDZ35MmpWakBIhPkUtPLl1HCk7shrT7A0dLU6T6rgTauRQBy0BPYun6V1cFyfI2t2Z1O6Zouq0urXG230Ro6NiZr84eZt6ClpDQtW3sHPDm4ytvWzlO6hkEWE8EqOErZ0GlbpjqwIvrc6EnWvWRrXgJrQWpB6xXDwSq7AS5//gs89u3fzRO/+guMRvsAnFs9RfAtgha6FlkZjCqpwhif9rdesUKmHOeXhzy9dY3NvU2yroLhkBx9KD8goW8yji2tsNRd4Va4FbkwoqFOUh+ZEwhrsMoipKIJbeRxobAiLHh7rW+REmzbMG0NjBXWz2ht8g4lBgcOR90EcqMpih4qqGj6HQ7LgY1U6Hab7gvbaNXhmd98hVOP3MX97xmiqnj/dGuY9QO99YaNl2cclyU5A6hbQJAR5z6IEqcUujNA1DVSTAlK0NjDwEwL7mgGWCSeiddxdA9Y8LL+kJqmmP/nCF+NRO/QQJn2L5O1WNeQiR4WTyY1wVmC0Ni0ulpn8T7u1yLEBiUVQMpAETyZnHfXWjIfGwRaASJ4Sg01kjAXaa0t00j1oggCmV5LOHLMf8DQsIjcmhRx1jlMcphmmlmumKhE3g8BnxJPGwIuSIRP3DIRxW2DD6BVggLie0sXEErS+CY2p6QK3zx5EgEw0O2AcCP2qyk2OFyVIyYtzdyN3rZ0haaaWQ6s5ZiGtQF0CnjldpyHlVXNO97c8NJzLbcPoNfX1M6CFZTDwMb1+Mxuyh7nz3QRjeCgv8TSygFbY8tsAiqRjmsPAY9RqZuVEGV52hYpM0JdplMMtN7RWr/o1J3P6Tzob7xivbvMIDtOX60Bz//BF+S11+ZPOt73vvfxvve9jytXrgDwyCOP8Hf/7t/lG77hGwCoqoq/+Tf/Jr/wC79AXde8973v5Z/9s3/GiRMnFu9x7do1fvAHf5Df/u3fpt/v8/3f//38xE/8RMygv9jROsYz6K9n2LQxO6GZlY7f+eCEL/t62N0bM+xFUujJc1Fg8sJ+y4svjGjrqOWT511kDlk246Mff5J3XozK3+NmQKDllDrOx7/wDOfOCW7dvMGP/eMnqQ6SLRA13/SuEY/c+zgvP/8MtoYvPPUka+c9fmb41GdejcdqHPvVDv7Vy1wbHHDs+Dvpns9o/ZgT98caubpl8bOaX/1/fZKPfXXNwa1dbr9oOX5+yL3DDpNpJAnr5+6HRw6w1y/RedXhZxazZ7HycMHLObeC+GXTlydxDNLiuYNjFubk1oRIpWmWR74LIscsl5H8r1Ukv0JSvxYenUXOgW+hJ6ASgcLFh0G/lizVsGpjk8DUQNOJG2h6XiCEpG08eQCaloMtcDMY9SYYoEiiaOVEse9aihLWNiFc3aLe2EafFQgbN5GsychmFj1pCGMINooZRpPmw415fhPZcHi+c05em+ZKzzmKUiKVQlYxS1ZOIDIWQZfygTYEFB49nhCEpj0wBFsvjK3DQJIdG1ALj/YFWpYMjg+ZDXfodk9jZRQKtqJCSugUKRs3Ghs87qBkvBkRmXe9/c185OlnuHGwF+1ibINvPXtoer5PJVMjhCuxXkQeSJAUeU4nN1jfUqeNRuYBpTvYqmZS1bStQ7mAVIEgWpo2Bnm1kKx3VhiYCXtB0CvGTOoZuT6BzyNh149qrDIEk8VF4tvFM6JtW3wiRGm6ECxOBypfU/SXGQ4H2MYhk7CQlDFoyLsaG2KW2u3mzGbgW0+dAv+mitfUunidvPV43bJ7e4fJ2bhH5HnOcDikVYoTa8fZ2NvGOqhxERE+AjNIH+gpgRWBV/Y3OZn1WS4EtYzrq6aDkR0qW+OmI2xnCWNzZLegiCBXapzok+V9emaJ5eMFL3zo43zH9/wFXlg9w7NPPAPA3SfuYir36fod+rJLbRx75S7jumbq497Vrh7DIFgtuowne9zc2+b4sTU61hNMRGcgdkcHKciyjF6vR5YVQF
Qg93MB0BARY+cc+IAMKnaQBYURHbrJ5aHQFaNqhhUhyiIIj0djW7m4hnPxQ4dAGkluIu8y+EhKn2tIBhGYNJJOJfia/+qtrL1Z8fQvfZDdjy1x/1dEbcj1tz9IfqHH9CN7mH2YrRT4nT2MCmhXxGYeoDUSoSRBanRmqLVBKAOyPgwEj/DK5gHVfJ87yr8K6T/zSy+ORDT2KJyW3mMhYjtP5sJhF3vrHdp7GmfR3tN6hxCB4Bx+YckUFtfA+9ipGnzk0CkhFiKlGo2W0WoJ4ZJQa6DyfpE4tyIG4i697xwpW3RoHjnPOVo2/30k/gva1HjRJvJ/aRROa0JIWmRY5uGCEC4xi2MTiPUhBiUh4EU4fH6kcyJ4tI38RxvS56YmFRFScm8dvokC5jZAO60JLXTS8moctNYSQV7F3kagdyJK6LR1TA6qieTWBO59KwyuaJ57TtEbWso9QU8E9tLetXNbkBuNbzKCsyz1lynqKa6FIl1E4cG3MaEIOvZsBQGthRDsQkqq9TGhn3fCSpIYb+qiBdDa0Mn6nFhdpyMH/EnGFxUNnT17ln/0j/4R9913HyEEfu7nfo5v+7Zv4/Of/zyPPPIIf+Nv/A3+03/6T/ziL/4iS0tL/NAP/RDf+Z3fycc+9jEgEoC/6Zu+iZMnT/J7v/d73Lp1i+/7vu/DGMM//If/8Is5FABkDS2SyayiYDt9xjK9fsa1Gy0rTyvuvi+agXW7GbJInpRnVrl9q8JVlrwPVszo9Vc5GMPedccHPxq1U9782EPc2q5ZyRU77QFvuPsMP/mvf5PxgeHCeiL/WsVHPvYywow5ffYUm8+MuHJ5yo3bt1garvPys1cA6AB7k8AWz/PUzpAHOoqbB4q7TpyhV18F4Ddu3OQTL85obgd2P3KJN55yjHcDertmcmaJu85dBOB8725e+J//HcgCvX1Ab5wxqzwyt/iknTRfLD4cBmVHO5KO3qzhyNdRbZ/FJsVcmyeicLmM3nCFglyBmXs3mRAlRQyQGcppi1QRDauSkvUUx1oL0xaqOiF6BRgdIXQAp6LWTW7ABIFVimriULOch46v8cBdbwTA7k/4+BeeYKOc0AlQXbKUl67BwKCSXEZWVWTjQH7Q4sdRqFUw7y4/1GGOZYX4N447k8yFSvicAB0cIji60kQ0REjwCj8XMvKx3NNOPaLaR80U3XxAVVmcS4q2K5qgMvKOonQHYLuEUzsoucawOMbWbAeAttV4sY/saQqT0Y4arFJoAfv7SewVx3d8zVfxnz/1QZ69NmW39DS1Y0UaNo2na1NpwSrapokEdqOQUtK2LY6As/FhpzuaPLT0jEO4CtU2+HJKmwcwniDnpN3YXHBsRTGZWdqZpmMkNkxZ9zERm+YHjKYjQgj0O/2onRYgWBdLZiFecGUE1nrsrAGjMMZTeMPx7glEQt+EjB2A1jZoU+CFoqr3oWmw1jKOuQ22gkzEOEGqDB1qhAowran2Ynm7c/Y0bfBkwNrKGv2iz3g6IZMySjmki1+p2G3YCRJBYHM05YXsJm9eXUEm+yBNhrUW5SwyWNA5y/kqO9n1WPYk3isn1s/SUJPrnNVj6/z20zd48bMf5763vZXP/vpvA/DUC89RLd/ET19kZXUZ7buMyj0OtiV1FptBdpcucEIYTg+WkLbl2sE291dn6TZdglTopF0TvAdd0C9yBr0+MjO0tqaTzpG03mMA4lEyooNaBnIj0CossjctFZnWBB+w7by0JXA2IFIAIRV44SEojDH0ej3yrKAVSXgz1RKF9ywt59z91ecYuxvI0S3EMc2l210+9rdiqelr/upNLtx/EfGZA1RygFgJGbYKYHr4OXk6CIJQCKNQeYEuClSbR2mVVF5uw5FyZeJkzBOu18plHO3ePPKfO0RaFzIUR3/GndQP61Kw6xyEWOqWUuK9O1I6TUF0CHgX51qEgHQCFQQ6Nb0o7ciEIROxOcsBpQ+U8tBiME/n2YjDxq35ubjD04g/587gLB
AJ/NVcg9HAtNDMCkWtJcFG4VpxlI4e4tkKZAw4F3MjEMHfGXSSSpdGJvu+EKsVc81dE58XTgmaGhqbLKMs5IioHQa0dQzOMimiVE3w7G3B8nEWnfXTXU13peHKVTi7arn3Ic8znzcsd9vYPJOacW5PZpxbW2ZUek6uH+P6rRGFVnjnol0UqdM2QPAiIoHpeFsHTevnywv84VoQYr6uYqA6/5uBkaggyVQXk3ye/7jxRQVm3/It33LHv3/8x3+c973vfXziE5/g7Nmz/MzP/Aw///M/z7vf/W4A/uW//Jc89NBDfOITn+Cxxx7jN3/zN3n22Wf54Ac/yIkTJ3jzm9/MP/gH/4C//bf/Nn/v7/29RQfXn3TUbezyWytEbKcCXNky8fsUXcHzT7ccO95FLk0RZcC6GLXUdo/H3/Uwv/1bn6WzrLl11XLXRUe1aUE3fPCjEWq8sH6alXOnOJjU2GbGeFfyqU9cYkkIDlKg4QTkxQof+Z0dHnnzhOVBj/FBw8aVMebigPE4biJdA209hVmfT10ecfb8Ve46dZZy0vLzH4th+vs/2uH8ccf991xlb9dxeqXLQVmipiXVTsWlj38OgJU/XcO1a4yed3R7MJk1yACqkomLEPfB+eYzz5zmQZrncEM6euOG12xAR4ckKv5nQVCEqPjfM2CMIJsLmCbdJqkDqhebpb0TdF1gkiKgboCBg2UrGNu4wbgOFDm4pBvTqBiolU20gxJSkZeOb3vgLN/43r+DWooHurn9aYpBzXOfeonJE/ssT6B6eZvpWo9hkuGX1QR7UKFKT6gOoX2fNt7DTeWQNzLn0aWXzi0wYxbrA8FFkU0rS0KQSGtioLfYdUG7Dt5qbDlB+hrdUciyQi+MbT1NUSLXDe1on96JgqkZ0Vfn0W4DW+/Hz/SB3CzRyScoHTuNgvboRtDzsctw49UNTt99jL/03q/hf/7oJ9n6whY5ATqaonQ0iRyitMAYTWsr6rqMD2Zl8K6hrhOq1pdMsMxEzc54h1k9Ymm4zKxtkS2o1BHbtJuIvENPDegM9qgYE0pNqxpCKv2eWD1J2zqmzYwss/TzHiQrmiChTlZReVBoFahdizMB50s2d6Ysmw6DpRjkqdBSZDkCaJoGkWukbsCWVL5hkqRWbB2lTzygsQgfCBpU6xlvxeStf/I4WglUntPv9lhfWWU8neAEGHHIuLQi0HowLomFhsDz4x3WjOZiJwZm0lbYpsEJSeUtS1iEiJ2bJk8emL5haXUNQ4XJJdoqdobw4d/8IN/2N95A0Z1TLD7LhrnC1euXsIOWC70hK12NafqYE6m7+/SU491VhkXBWrHEtb1Ndkb7DDsDZNE9tIMh4HEoAUu9PnnRo5lOk2VT6h4k8rGUVIggo1yGCMlRocYmY7+ARQi3uCcCUcHeC4+apzYyluGk10it6He6ZEUHlwIEuwhIoOotMbt5g+c+u8v7P3+dr/uehzi/us7gB94FwOXNPZ689Dz3XFxh92CbFfFG9o0hywO6OqBVMQoXoRulK4JE6gxlCnTepWmaCJkSy9FHAxFBSlrTHnCU73TUHeC1umaL/xV3JraOyJ+bVxMAaucxPhCsjd3e1oHyh0EfkZPlQpKX8A7njvRrao9Igax2kKtAJiRaCryLZcoyfQEUQEdE5Gx+dY9KghwNPufH7okInyPJGyWkqMyg6hjKPKMMAR+ayEoWIdkKzZ8ZEiVEIsbc+eAQRyYvpJNWwUcUNx2nMfE1JhMoHTAEZnPh75yYxPlUAoXoIiBj4KZEoJt30FlFOZUsrcb3quuGjQ04dRfc2Ck4c7wivNHz5GegtxQdJwBevHqdt7/hApee/T2y9Q6ETVybSpFzio+AIGVEAMXRORQJaT6yblKQppVAImi9Bwl6zqewDdZWTKodZuzwJxn/H3PMnHP84i/+ItPplMcff5zPfvaztG3L133d1y1e8+CDD3L+/Hk+/vGP89hjj/Hxj3+cRx999I7S5nvf+15+8A
d/kGeeeYa3vOUtf+Bn1XVNPVeHBUZzw2UZF67zCpkePrLJqWyFdjWNtDz/jOft7+gwbmYRkyTayJy+t8uD99/Dy5uvkEkY2QOOHesgsh4yoRqf/O0X+c7vOcuL4+usrhhevbKL25P0TwoO0o3/Td/wDm5/Zp0rNz/LU5/Y5MyJFcbNlOuv7HDmwkmm8/JcLgltTdZd59Kre7x4bYc/+8a7+X98NPDBT8TjOgl8/TsU96kl/tW/nbHxas3xxtPtSnqtx1/eA+C5f/EJsqyhaAv2bjp6QmAIlAQK5jdPuBMJO5Ih+XAnMja/UV/LrTiKGkkRyf4Z8YbOJBSdeGPpOVLUjYsWCVlX4kWUcZAuWjUBdC10IdohNdCT0E5jWVTMs5A8kpcz62iVRll417klvuW7/joPftk3EUZPAmAujxnNIMdw019DP38DfbllvHqd/vFY+hF1C5Mp7SS1eB85uUVGyWGJQ4iYC2YAIfqfGTW3ptFoGYNfTbQZEenvpAsLQUEVAriAaWJ2VbsJjZpQW4dPd7gQXbKVnPHuDN0oXHGZcVXSKXdoZgUhcStyEbAdRbej0FrSyWQSmYRBQpM8cOPGNU4//mb+zrd8Fz+1/3Ncrzd5VJ1ibAvQiZPnSqQWCC0QSiC1QGqDcA02BWY0Cqqajm+obl7h9vXnkHLGgW/pmg59osqx84LAPh0R6LWBie3jbYNvK6qkGL+c91hdXsfv79LUnkZYjIwWJo212DmfwxWYTGHdlNHM0u0u4+ualy4/jzwbj10fM/QyRe4jn801MzJladspu1XN7n5KlNJzQhogWJSSSGEiV2xvH4D2YES+toT1FmU06+vrvLp9i1Hd0heCnLmECguuVUCiRCBUcM3WDNKDhbpBBcVQaKblhPqYQGeKTqeHTgr1g55nsLIcNY+0Y7Y7ZmkIn/z8Tb7u2iZL52Kjx8bzG9wY7/DZ5yfYELgyrDi/AsYYOpP4XhfveTMP5kvkec7xlRO88OJlbu3tcHJlHdlUaDmXZ4hzoJVgmHfpDoa4ep8sK5BzVG1BWZAIkSFkjguaxoIPApUM5JVSERU0IGSGFB7vFUK6xZ4KHiFScIcmNwWq0FFpX4TFg00GQV1usnn6Hu5dHfGGcw3bL19iNr3C1/ylKGegNy10B1Srlr1XR0wKR7EqKXYi1G6I97aTFSLEYxVolC4weR+dSvAA2HaRlC68S9MNP3cKSEe/KANCIsPP/zHfCNNmOqeGzH/1WoJ96zxt2yK0Q3mfGqcECHmo3ecs3nlkkgGZS2kEomzL3EDdBE/m435rvKAVUUeuCTBNUWVPQhVIqFriFh85fPeavX6+5znic6HKYZowkVmuqIqMNtO41kYfyIQJinmN/8gJi8BCGkIksrJnXk6PAUxIqJIQccKVBp06xrSOumZiHjGHgEaSGUPwgjZBWD5ESRadBTQC60syqRDakeyoWTsm8FaxecsyOFVxYwvuOiG4dS5n53qI5EXg+RdeZPl7/hy6a+kVa2Tdyyx1uwg1PdRzVKCFj/IzmkWmLlLzwkJKaj6vId5P7uhDM408O02vu0zZjhlPbv3+F/wB44sOzJ566ikef/xxqqqi3+/zK7/yKzz88MM88cQTZFnG8vLyHa8/ceIEGxuRC7OxsXFHUDb//fx3f9j4iZ/4CX7sx37s9/3cyUj+ttaSHHho7B4y5ChRJ5JhhysvCs7f28H5SPowmWZr6xLf8J4/zT//X2/QOHDbGd3TAxo1YjnZLe2/epMnf/d3yd+0iuloPvvMcwRtGdfQSejO1719jcmFe/hX/+qA02cKXr1+nTyTbLzs4fGGpBNKWXpaIzF2TCjh3338Kjd2e/zqp+C0XALghnqJbK9PONnltIKXfxfUGhSrntW7gRPxIbX32WgMXKW6ZeZjj4giMNd6nXMI5lnT3FJpfsP+Pn4Fv38fOvpdBtAiLIIyk0GWg+6IRZbhiwh5e+/BxE1Yl0mROZVPOm2g52DgYWhjF1
xbEcVo08EXCqRyLHVASUdoHV/96MPc9bZH6U/ez3QvturIjqZ7/ixLZUl3/4CmmZC9MMK+XONFDN7z2sEoQHN4fna+QYs7M2Mh5kTegA4JLQuHHLpMCrSMZs9aCQgaLRQiBIKtET6pz9uWYC2hdYS2xk4DTXAIKSPBGhCipNf2CXWNL7YYVy1i/YEokCrGkB7qKrRkqsHJClMY8gNofA0mMKvjei5MB0LG9uYWb3zoXn7w2/4MH/vClHY9p+O6NLPYQexEjgujqFeVZchZi0Cl8FIurrUVlrra56UrO/zax34F/6zm6sEGvU6XU8uxs3G9s8Rg3XD3uTfR9SfQ+/v4oUNVgirh/MJ06PeGVHXNrJwutNKs9TRNizPxM2ehiYLHUuFsS3AVWVYwqUueufwsAGfFScSwhGDRRoOCcjxhNJuysVczTptz8IKQhRiQBY/CAD28VIREFhxv7bK+vkpQgk53yPp6y/G1E4xv3MCrsOhQ9URR1oikBnIbtflujg8411uJa8J3qVxA4mCyy5mxJe/mFKqLThl6f1nRP7aGcuDsjNnBGCkk18aOy5/4PU5+83fEg3+h4OLqCS6v3uBglLG/09COAdkixhHFf9fbtwkr99DRBetLq7S15cbeFvfV58nqhjoFXSozBEBrRcdkdHsDypFBKUWed9K1VhFRFIFA7CzzOBARnQiJgG59TRsajM/iE8uLaEkTRPw3pM7DIlrayIJeN6fTKZK59hGkqM5YLe+jujzg+cmE/uBeynCTqQl84YmnAJhOZ9yujhPO9JHHzrNzYMmrBjNqyVROmxwvtNKLbkAvQOocbRwq76FTYNY2DrxfVAriwaZjgTtKi3MUDQ6Rs6OvP/rPeWAmf/+vkUQ+lAoe6+JqClJFZGy+tpzDOxdNx4PHeheDtiAQae1CTCaNB5M4q20IOCVphSeBxJQilsvnCJqABb/4aKBwFDELQA3YDKoCZqnCNs4NU53RIAlSIJGI4BbuD4efQOLGRY6cSB8QOMIxS0FmEOC1jNxWHxuHzZyPm4IyKQVSpeRfGxSKVhLLiAAKvPXIoKlFS1GAzAK5PgyutzYlS8csbWW4fbXl1Bm4sad4+L6Kj80KzEaEBUfldSYzA7qmUwxROsr2CRWFZQHqEOdrbjsWZETGXus9Gnzkn80DdJGONwQWyUHLNjM/w40njEZj/iTjiw7MHnjgAZ544gkODg74pV/6Jb7/+7+f3/md3/li3+aLGj/6oz/Kj/zIjyz+PRqNOHfuHL1CUTUON4M6Rfwur/GTQC0sRT5AaUtwS1x7ZYd7HooPu+GyZXdvg+GjLd/3XT/AeGQYb1/jcvkxVDZlt4xvZlzFJ373Zd6zchfliXW+8Pxtsh7ICu57VwwoJ+2U7to1/tR7TvOff22LleOr7G7tMduqme7vcPHsMgDPvbiPqMbUMiPrSK483fLzz+xwcfUC23UssdxTGC6Y41zffYGTJxy961BUUG9DuFejVuN2MADKW6BKMEJTYpkXgeeEUHfka+FzyZEI/zVzfNRuZB6RicN7grS/RgTJREFYWYAs/KIDzhpioNIGvGsRLspNlM6TpaCrU0PmQPtAx8HQg5+Ay0E3hx/fakUrHBmO3TG00qN3P8luVXMzdRBOx2NWJoawehF/n8buPY3deQq5MSWkcgcHRLJQIna/VnBRLjL5+GVEampI550rgVmom4MWEkVAeIcIAqFit5Jz7cK01jlH29aE4PBt5JAEBK3wi6Cw3WsQ23uIjiUMJL6NXXTtaJPKO/Tam+NnZhXtdEwmAqqb4VWNFDFrzKqEvjWK2kPZTNnausbafSd46yyjul2S6ciTAlAhp1P0KWe7tK1F6ywq81djfMpMNYZRPSJIwXQCH/v0JrMsbvAqTPhMG0vzQWmq3HLXhc/y7kfeyHr2KL4ZIjUon5T/JWiv6WY5WBu5bMHi4//RzOJxCd2AC2RFn0LnNNUYFyyF6bA7i5/XXr3CyYsguyVVYy
k6FmcF+7OKzS2o0+XOsoAqoDCD6KfpNLrtMpOSPGUQ070x622gNxgipEEoxfmT++zv3KJtHfO+zDrym9E+BusyiVHPZoEbs30ATnUzmlGLaz1OSGrb0rUCZcxC5PjEibMUw3WkPGC0d4XN3RbR0ajC8sH/8hH+66/78vh5+ZBVeYJH7z/O51/apK1gNIvNCC41z1zb3qa6G7pCsdbv0VGG61sb7Iz3We6vIrvz1F7QekcmJT2T0e322BYqGk/rGJjNA0eQ+FBjXQx855ZMcxkQ27bRDcTLROgPeB9x9rBwZY6ixSqXaNOlrw3dfgelDIjyUGaizPjLf/H/igmf5J/841/lbRcukL9F8erBmM2oZsT5U0NMO+Hg1T1ujHqcLSaUjaEsltirAysJybNujvVHRDMIg9A5Ou+h5uu5rWMS4P0CBVt03Yo7m3+O8sdeS+WYj9hIIHAhlbjCoZzQfGigci3SO5xrkanEFRwLfp/zjuBj04XzNv47RDRG2ENyvFJx/SkvkwxFlNyojpQyqxSUaUlybnlNdeQ1xz//nZAwKWIX5iyPbzbtaiZZbOLwSiBdRFRDsEe4iTLOY4jB0jxxR8xJ7yK9Lhx289sUtCVITyXaSyYhD4JCGXQmcUohtMCQI62dg/1oGbvlkS1SgHMSaz1agJv7VuaOjRuwsupZWoK9bcnywLKl4Z0P1nxsNx7X5qvXmI0V6yfWqSdjFDnrSxqTTUmmJRSD9CwQsbw5D2p9SLZa8+fGEd6Lf+1Ep7Gzd8BkeoAM4Fr+ROOLDsyyLOPee6NL+9ve9jY+/elP80//6T/lz//5P0/TNOzv79+Bmt2+fZuTJ08CcPLkST71qU/d8X63b99e/O4PG3mek+d/AGnORz2yas+xth43m9J6QgNmSSPUFKkyss6Ig60+N67Fm/XifXBrx/Dci5/kq7/iezmuz/K//vKnENNNvDGLqLgUcLAFv/nrW9z/FYG9g8DyWqQvPHA2vtdo2jAdX+au8wXveOspnnxuhlASGwJPf34H1UuGuxL2bUHXKOzeCHJBh8Ctcpdx6ig92N7l5pvPwqikOwSzCuYGdLtQNJZZBA+o9qMEAGTI0GAjDRNLoF1k+4dI2TwgCykj/CODsiNjnhxIUlbjQUiB1hJlAkJ7gg7Mo8IgiUfiJK5xtHV6qNUBkxakSJ0u3keErNtG/kmZQzd5gbYT8EJR5I6sUFT9wI1py6XdnNoMqPYjAjS2Pex4l97BDUJPEC6eZ7p5A7NdE64nzmGIKvCCQ0T6aNYz35jnLe8mxFLtIijTGpPKQ0pIlBTRTsg5lPCE4HE2Wg0tNJNsiARkFWjd/BoE6gBzn4S6cbRbO5iLPbbaGlkNCdMJV588YHO2y9mvj+t5eOpRXCvYyUt6A89e2I0lWSUpUx9BaAKDvMPsYMLOsR0m+1uI4YPktwTBC3q5WVxnSo2WObnxNG2LJPLm5krc2BrlWqoqdq+2FpSBXi5RwS+M2pupRVbw6qfhZ64+yXvebvnytXdSDpbxaTGVVYUmoELkMs2aCi88RqqoH5cCuLZRtFhqu4vp9BFkVFVJK6dU846QaUUxmpB1DEHCdDpmPCk5mDXs7y7eCt2Boq8ZZqt0ewPaIDBNl4729NN7lWXJbDRlZWUNrR2drMv60jG6gy6T/fEiITHpQa58zKbbLN4oqoJXk8TFsOjQlxmjFkZtxUGYsa7XMZ1soSl43z1vw8hlWja49OJnuLkBnY7FCnh6E7Y+H7syVy48wPYXxjx0/4OMxwc8M6tp9iGrYJoyikubO+y7wJKRnOz1WF9a5ebeDbZ3dzg1PEGnNwTAK48MEmMM3aKgyLuJgE4MljjUm/IItNQYmaFVB+kzZOiQ6dhBlpua1iiEt3GfETbuF6JFiHkgaFEalJEonZORRXeHNI8LVN4VtPV/wZzZ4Tt/5A1cuXKZ3e2K08uKF67Hk9wdWcqVGXmpqKTnZTtGZKcpVIbBLCy/dDAIAlLGLn
CUxnuDznJk4isbYxCJhI8QOBHi/gN3BGkBQBwpAf4RwZkNYbG/zrll8/+HeL/XzmKcxbUWoaLPqnCH5WPvfeSYJW0v5xw+BNT8fkzIiwgydvoJAdIRksq+JTAn9zQy8rZaEVEwzSEq9vu6T0NMsgKgckWTO0oNbRbvjaaT4bRK5XsgxKAsRq5zApZMXa6xg1d4eyQoO4KYJdRJBWhl7L7sGMGgl9PrJ2khFVBBUaiMgkCpJVZG31hItBCik4U3QAtBK7x2sXPSygWy21Zxn9m87lg9B73csz/K0aEm68B3fWN8z//0WxUvv3SdM3cd46VXL7G2fIp271V63bjfQeS69Y3AyBARUOZcRRkfhOm4tDgEPubIGT6BGSkRVDIiyjZA+1pk4A8Z8o9/yR89vPfUdc3b3vY2jDF86EMfWvzuhRde4Nq1azz++OMAPP744zz11FNsbh4KrP3Wb/0Ww+GQhx9++P/bQ3l9vD5eH6+P18fr4/Xx+viSHl8UYvajP/qjfMM3fAPnz59nPB7z8z//83z4wx/mN37jN1haWuIHfuAH+JEf+RFWV1cZDof88A//MI8//jiPPfYYAO95z3t4+OGH+d7v/V5+8id/ko2NDf7O3/k7/LW/9tf+YETsjxmNBuMdO7vwzlOR9HX30l382m+/gBYNFy+cpucnhNBDrdfcfiXmR0NdIMyI7fKA25de5aq9QW1KXDiFcLt0ZMxFZipn2m3ZHkkOnrmB7wl2p4GTJ+DkmXgM7cEBM2MQe44D9SQXH7zAs884VG+DjVGLS6aoug/9oksWYJT1qH1A5gJZ1SyndsbdXfjw73yOr3kU6rLDqV7J1EkKrynahvpaqhluxrZrFRqCiNyyihARmyNE1Tlnao4AzHkHYa6ueORnR5n+ijtHIBKhJaBkQCkXNctULAXOy3NagFdRjyKUoBpwTjALYZHx160g1FD4QEHM8IbOIFW78OYLfYEaNKzmMBGCfuFoTcFN78mbA6aprNNOJ/QmU2bjCrmzSbcMLJ1Yoj03w1/aS69xaJUEc/1hW/uiapvOOwuRNGuIJcuulEgpI0qWUAEhQtJwihMiRI61Hte2+OAWpcxARAGtBS/BJg6gA5qUv1oFcgriqketgziXc+33bmBHLU0jefbTWwC8+/sVVdthMOgxWVYsL61Rb23T6kOqcrUzRk2WqbRnf3ePbqaYVZdZzd+Er3aZ2tSUkHdojEHommAdeSEXhGcx95q0HpdLlpYL1qxjy7Xs73jaVDZoEjq6kkn2pWd4UlJNPS9ce4VHjz3KqjzOJIvXp5k2THxD61rKuqXygmA9Wa6w3kViB5EHJW3MtCvvcK6hriZoedj9dMPuUYxPcs+pLlJVdITilpgymSm2blv8fCdTim7nFOvrp+nnSxjdjby9yuGnST3fB/a3djh58W583mUymxGU5vhglf3xeNEynwUQGloNjY0dskUBUsM0IRrXDna4e3AWXznk1DKqdllavcCtfEhnKXJH77rvjQQ/o529zJMf/hTTYUYQDT0Le334wsciFeQb/taXMXrxVbL6OI/e9xaubX2Cas5QT5n81Y3L7NUzzmaBbr7MqRPrbF+7wuW9LS6enrGePDzzfLBQ8lemYKVQNEJRNg3BpnmoDFWvRdEigkzyGQGho5gxk8SHlAlzl5K2HWONxkpLbaHQcT1nvkCYnF7o0cs7ZNJTrAyZBk+/FuymTejU+S0uv/wf+dy/LxiuBc6/qcvw7oZbmy0mlWFtOcMAYgXW7jLc+MKEN51oaMWUWvfJ+sm5oHRkSfNCobB4Cm1ofIHIEvxeOKT1KGshhEUFARI/7DUcLDhEwOZfczPu8Jpawxxtd+mPFqhgKm9iW7S2KKsxylP5KJkB4IMHF02v52A1zlPrgAjx5xCRMukdGZJOgFIKjIslzznyX4WEmhEf6KmPMu7j4fAh36RjDSGq7c+WHKWEWSEZJ4PYGQKsI5MB6V3UIgsCH8Ti2CNKqFI5ts
HKaOSt2kAIYsGjk8S1Ky3oIJAmsLLa5dSJZQYJobPWYkWgU/TwfoZ2lsZ6PA3O1QvR7hAsTSoBCe3Iykh/8R0fXU4AUwasEzQhcPN5uPsh6Oia0QRuW003j8/1e98C2VqXotvDqCEZU8gLuh042I/zkOV15ElnmtZaXAHNJFIa2joQRFqDooVwqBsqhMSF+KCZe3PbxFNzljs76v6I8UUFZpubm3zf930ft27dYmlpiTe+8Y38xm/8Bl//9V8PwE/91E8hpeTP/Jk/c4fA7HwopXj/+9/PD/7gD/L444/T6/X4/u//fv7+3//7X8xhHL5fMHhjCT7wwksxAvqh734Hn3/mFps7Y3bHU06dz9irKoaui12OhPAXbsG97ziG2R9za+tJ7jn7Tu4+do6t68/RqJpmDvGHmr7WKD1mNHagAlkDJ07k0CwDMBI76HKG7dzNpz62z+BkyaneOqWDXjBsm0jQ7xoYyh5tJqnliMLn+FKQmx4ybZQndODGc3t03rFGseTZ/F0YTjy+aRi9As3W4cYgUklScFie/H0l7iNBlz/ysy92CDhUyuZIwCcjr0zqeft9QDoQLqAV+CTMqQIkZQRcHVAeciHIQwzOgmtZnQnKzeRs0AO9IvA2IIynVBDqEa1XZMIQmsj0bqZ76OkIN5viR2OmOyW9nSndEMg7MYJoy5LGHnIdBIe8kHn5ElKnKTEoy5REq8gfUCq2hUMSlkWlSZQ4XxNcFDt1zi06epw/LNe2HJaTw5FuA+egHAWmsxn5qmDzhT1mNy1uDaxTrCaNjjyDsq3p5esMig32V5bRuyOqpmHu5VxNt6maZTKnGe1PsIWhdV2EENSVQ8lUprA1zjYIoZDSYetmUdayKXjLTZ+VE11OnjjHeVZ5ZTTmwzyNGlswYFKA3QqP9XHz6SjB7pbl2VtP8WXZCjKJRVspkW3cxINUmDb2sU6aNt678w4167F15LDY1uNDS11VBFfTSQr706ri009d5dSZ4zzYyRn5wGxs2d6yeBs17wDWhmusr5xgdWmdYWeZXncJoQx15SF1c4+VZlROGI/HdFeGdHJNnhsG/RXy/DpzqyjhFDh3uP6JZc0sV3TS07SyNdvNFqeX1mlmU+qDA5qzNUpmLKfOn6WzZxBWcfnqB/nEJ2aIgUBXULUw6MEXXo7r+fGrr3L84mm2Ll3ixBvWePSuNZ4OO9RWUieVd/fqDjcnG7xh6W6WdMXaYAknYWdnh1E5pp0bw1sXxVdDQGaCXrfLoNdnOr3NpIqEvKWiR5EM3zu5RMqwEDx1zkUbJIgWTC5gjEpBSkAEiRbykKPlowiqyBylnnAiX+OsWsIUgWoW0OneqBrFbneflYcHrGXHML6HDFMGKrCV1uB+HW+xQfCMVyfIbpcb7jZ1r0HXHcwsBl0yKHwQFEgsAec91lsQYJIxvGsrdJYh2wbn7CL4mnNvFwy5+Xlw+H3e+LEgoAOHUjvpdWlfdOHwvSI/VURpHO9QzuKbSM+Yr3kXPDa4yNFKBHoIuDY2CCyM0x0IKaOcidDItNPXMjZQx/eKx1urlFSG2AQ2jwHmtKZFN6YE35FUmafRgqroUpskmK5EEnAWtC6QSY+zUSQ3HMn6vbWE1NggRTJnF/EpM2/SX3DdBExl4FzP8MjZE5w5u06VGqV2RvvYsqWQnka3BBuwwuGdwOOw86OWieuVQTePkmTCAg302/iB+5nBVBaZZVRNw5WrGefubdAHoMuWnd04YUZ4yG6wemwJhaE3PEbf32Q4yHk1PYOGY5gWAYIl7wBeo4UgNC2FlrRJwDyowyaByFH0hxxtcfQaptfxJxtfVGD2Mz/zM3/k74ui4Kd/+qf56Z/+6T/0NRcuXOADH/jAF/Oxf+iY7LXkS9DLDTuX40T9zAc+DF2JOiiw5ZSt8ZCT3S6ln6E78XSLgeJCvk73ZIMdWSa7+3RcD9H08M5SJ90GqeIicLlhb1bTsxLd9SwNYW8r8pzIO7TeIvMN3v
GGu/mN/3KJwcMVYgmmdcPasRggLOkuHdmnFiM6wx7BSUwekMLGbicgExXLVrG5MeZtbwkUDxpMLRDPNNguhPTknwdJgbQYjgRe801jQYt9DVdiruwfxJEscAEfxW8ODn3g0t8IIRCJnCpTPR3u3JCEEEk0EYIUBB2wGoQViASrGT8XRAwMUut0RwBeUiTEqZxAux1Tvf4SzCT46R46dNBWIMcxkPWjA+q9ffR+hdkZ42/tIfZK3GZFs5cQQRs3KynFom9cckiSnat1d0g6bRKMVigdUELGbqEUAUkhIvroYmu993UKyizOsgjMWiJKNg/IWg43qHk0pZLej7OB6nagvmHprcKslbiJ4/jpiLaE6RiTCbJuh2tesr7UY7OjUJVAJ8/QrYOSuyqBFjnT8QgokLIl6EjUVak7Q4eIUOJD5JY50CoRqeeBmRC8/ZGznLvwJs7bs5zaGfNiOeH2cy8jFAvNHp8WimtAGo0dw9PXr/DAykWGw/vT/SMI0iK8pd/p4lXLpJwwq0oyrRYBHN7iXKBpWkJoCMFRNy1tUzFJKu+NrLm6N+YDHxxz/L3nsXSYTgz748hhMcnLcXlwjLPrZzm9doal/iq9/hIoTTlzhG4MzLpZzuT6FXZ3d1k+tkKnm7PkejTTE/RuD5imwF9LibRu4SRgHZH/gqNIN8DYBTabESeXlxgOC+qyoXQtWnj6JuqTrZ07ScgmfOrXP8CNBrI1hR1HTlbHCLZG8Tq+9JmP87Zv+C6uPXuDckdy75n7qErLJlOmu/H6TLe2eHnrOl914iEK0XJq+Tid/oCD/X12Z2PaJBTsbYMQGQ5H1tEMij6D3oBtd4MyMZAlCp1UV0OwqfMuEbyFQqS1KqRCJF9VHzy+jVmHfA22rrWma3oU5iQ9DHmRoYTBBuill44njlevGY6fdpjsgFqPcHWL85YmNbM0QtOOWtqx5KBxZMGx07UEWZKrvYXoqG4GiBg54UUUM9VS4FEIHe9skXdQTYNumojShrDoSrewCCLgkGs7v1fnGo9zVC2E3//veeIrj+yDLaBloHZ20dHovAMpIlJGTGja1ADgvUckfSzvA8EFlEsT1sbqhLBRBFiKqD8XRKBJH1ipyIeWMu5nKa5dHE/SHF8ko62C0JXMtGfUMYw7mlkW56uVHkREqoWMVYJ5orKYJ6KTThBRekenQHA+lQvwOsRqileQdTXnj69w/6klsr5nYxJ51ZkJdESORuBCQytdCn49tW1o512OCXWyQDWFogBrJTkGkpRW61ukihFQ1gFXNWxvFtx9V8vOLcfMxXM8vSyYHhgGyxcx6pPY/RznD8i1Q07ivH/Vuy+ydGrIE59+nv3xDJ8HnLUICbVNMhrEPfW11ZcouXRYeXLp2OcVpj/J+JL2yvyh7/hKfu5XP0HTaekkFY7ZBsxqS7bqEKrLysoaxreMQ0Gnihdw2eSoqWflYoeQZ5S7FaboYOQa04MRejXOnu8UDI8VbE3GSCC3nuVlzdccfyeXbkctLd2fkndbpjtj3vbWglsv9Nk7OGBpacDmVsNaqrGoUFE12wzVMvc9eIbbVza4tTXjZM8hyli20r5COU97IDG2ZeWYQd1vKS+DfVmg2nm2dTgHR0n+/simsUDRwiH533NkMzkSrc1vu4VvXMoUFyXKELWIIHJo59o0EGHquRq8kLF0F2R8LxdiW7dzh8a1SsbgLCN2lxZGUguPNW6BeogGqhGoQWyL6Vno1iW6BekUKnVlmlGNGtWwO0bvTlB7E+ReQzhgkSZKEYNEXMz3DIc3zVyTDaIGkJFgtIxCgckWJaJkKaCTES2Lej0e69oFmuD94XVZXA8Oy8aBaN/hUwbo5hONpN70mCVQraIqPV1hOH0qBmazUNLB0Tn2EGsbq9zaeYI8D4RJWGTMtvTs3LzB8Mz9NFWNDQ29fh8vPJnpMrHJmj0ZioNEeIvWGilNkjuJB9+UDaErObF2kp64wN255Z5rL/P00y+Tb0f1bU
ioYwNlgCktmZbcvu55+vhzvLM4H+fUGGqtofZopaiMxVVRxHVa13SPdLvOVf0b57G2oa1LbNNSNbHDoaxKpBB84fMVvyQu8ZUPXWRzv6VqYps9iQoxWFrn2OoZjq2eZnmwQr83JCjNKJ/hs87i2JdmY2azGd5ZMm3o5l1Wlo9xbPUY25ODdA8FjFIY4sOiFTAL4KrDkq6X0DRwqzng9OpJ6jZQWge25vSpqNbfXz3P/ta/5bMfO6BeAmkt3kc7s7oOdFfjKvzM5z/H277mG1k+v8b+zQ2OP3SK+89XBHmJ/ST/cnPbc+32LuOHKzIjWM37DJeXuXntgFv7O0yrOF8r1pIZnTqzLUZAN+vQisDkSPDWVhUoTV3XWBtRyxBERC9TsG6tpbU1XkVplSj1ANKIxX2tpSTPMwYqQ8gKkWewVNAvNPvisOzW6wpu3mjJhxKXj8jJCb4guIpylHw3l1smM5hahy41e1std/XBOMum2EMm6seyNwivF8kCUiBkRNHmSRLCoLMCnXVRbYtMtmDz4Ozofkk4Qv4/spfeoX0ICzHqo/CHPPLPmsgPJ9Roa1AmhyBobX0n+d8HnPMpMAuLTlgnWPhgBudRXiSvX4UCKikwQSzQt1LBVEPRCLwKlPbwwS4WZxtHELE0HzqSymTMCs2kCJSpU8Wi8UEjZYuSDViPEAGFwC7I/3Ev9N7F44aF3AhH5kEIhSMGlsc7gdPH+xQDSeWrhZSEcgFEy347wbYTaudpraBpJXXbkpZg7N3wMbgxmUaEJATlBbNkk/0mfRefG19DhZZcKIRx7N9q2OkLBifh8vUYog67S3z7138tq8evYY4pOv5e7u08wOTgY7zxdLRQPP/AHq/crPj2P38PH//kizz5JGSFp5oGvCsIqRIW/J3rAOL+MpcBAQ7N38UiL/9jx5d0YNanJHQ1ciCYJYHMR04Znt/Q5G6PYd8QhKURXXJpaEzctJqJ55Z8hZe2NF/++F342+uYnqbbD8wm2aK8M9E1J86ssv+URYymnD7heec9HR7ceg904wyPxWfpjdaZ+E1Utst3/em7+ZV//wwdUWF1iy/jYj5oK2Zjz4k1A7N9jJly973HWXeSG5di9lDMPF0PftuRG01VW9ytQH+YEzb8QrhvvlXM/d/mLdFBHPInFuXLcGegcBSqf+1Y7GV/AKw/10BzRDTIJfheqcPb3kkIGrQUBAuhjgdkCo108ciMiMhS10ItY9CnETQEUsc2GsGskoQp6Dxmir1Rg5jsovMlsja+sNMo/EwQxjViXKFHDfIgcRoSomFDlKuYlyplmhsVUhdmOvZMCbSKOmVKRCRJpOBMzEuZIeo9eTw+2FSuDARP5GAsrs5hVh55f4IE9t9xBZ2InpqS2L1aWoedQG/JkvdSOTBYdGgYlTkP3n8/W6++FCdZqkU3UjmD/Rsj7KMlyjn2Dg5wXsKKQ6scMRcBDQ1aawrTpaptNJ3XGUobXAr6bePYcgKjB3Q6q8i64e7l87RFTnWzppuyzsq1ESlQAmsDIfO0U8FTlza461jc3E4NzyO0wWpN3c4ofUtQYHKDK2fMiPdjp8gQIpZzmqaiLEvKssTbhlmT+J7TuNa1gN/7jKfZfQmzXNCWSUg1KfEPVtZYXzvFyvJxlvpLDPtLCGPQ+QhrkvRGM2O9PcbN7Ws05QxT5Lig0Jnh+MopXn71CgB1aw8lFFRs2Q+RgkOTdmPjMwrbsHswonNsiW7RoxGxk/L0gw/G1+SaJ37rZ3lmC/QwLgJVQF3F954Dhzd3Pdeff5r1+x/mpZcvo8MqJ04dZ78q2Z9FlHh3DP9v9v48SLYsv+/DPme7Sy6VWVWvXr291+npnn0GM5wZQMJCAAQpkWAQpEhaDgdF0Q6Kkk1b4eUPkqGg5XCEJZtWhEjJlmSYlMwFNAmRBEGRIAGQAAgMOBjM0tMzvXe/fal6teV67z2b/zgnM+s1BstfcnREn47Xr17Vrcyb557ld36/73Lv0U
Pm3ZTdsmK7HLC3s8edd9/hwaMHHD19AsD2YAuhBV4JlIz0TMl4uEXdH9D5lWinxXULZK8ixGR5pZTODgAiMWdJBy4pIwKV9NqApPAu1nuSVMmXU28N2EIihyWDwRaWAlmUiGl6jpM6YoTi3sGAF55pCaGlsy39QrGVtSFff9vQ27csOiic4+BQ8OpDzXbdIc8EVd7pdFwQg2bgiyTgGgXOuZzZW92mQKgSU9ZYa1FhBsGv73u1Xua4bKNtFjfVh43W2TkcGU+28/tzB9mcPNK6FuVKNAU+C8lCkvpw3iWbplw6TnpgHusk3q+ylQEZkkSPkqCQmUEt14dlJ1Np06qENVtlyjygtUDmE7ZC0OlI7KdSZmcKulLQaYU7J3ERRZL5iIS1oCqwhnSstMvSKZ0sJaGQ0T+xxgkP0YAq4PlLFdsjxfHijJPGMsn4xXbWIZRjoeb4uUv7TATrAsGdSwCIFOykhJhDy8TADqHlUcZx/5d/5j/jP//Vv8CX/uW3KHYFVsFgILn9tuOlMTx9Kd3/cTPnH371L3JJj9jaE9x/+5f5aP/7cRcLXsuaj//wa4KtfsQ/usuFp4aoNyyTaaSSoISlW6kM6A2caDVG1hJMq8GRbdAivxG//Zu193Vg9pf/xleYa/h9n7/Ei/1PA/DP7/8rTH1GGwcsp0dc2/04z/Su87Vbv8w0l3SmzTHNPc+/esvxkU8KrhjF2aMJW70+B5PA9CwFSsPLir6A3Z0hH/3sACUf86FmxOHpv8LM0gl2ZytgRjN2Btu8+2tzjLvJJ0YDvv5wQk9AtzoZjOHxwyOee3GHo8dweHDKsx+u+MxLH+X+Lya6fM8rtomUZ4rJoYUDkN/QzJctmmS9sWrn0+4rMcTzTNxw7ucrkdl1OY3vHJhBynRtFp90cVKdSouYJ1GKO5fwRl6AWHll6pjLXcnAnOgpFHR9j9lUrdBesWwDsYvoGCglLO3mhFEFQU1gMY+YPtQ1uMee4uQB1d4WcUVpnnv8rMO3FtF65AJim8qXKy0qT5oghUgp90AKygqyGXt+TyMjSkqMUiilUCIFZFGIc5nEZMqbSpcd3m+C1JDQBekzxk1fJcHFcwv5RmZqTarwURBnEdlLnVz1IiKX07eAu51getSiLzxgf/cqd+48wPt312bbwcDkMcyOJtS7BYdLx2x+BHsW7wR6hR/xlhAEZd1DzKdAyFkzic2YDxUEl+orjIoR/bJGDA0v7l/nE0+/xK/d/zpumg8FHSyBQakocCxsInTcvdnwjUuvALD94gUMBlOULEKLsgIXBUvnkJG1DZSWag1mThpwnra1uLaladLAWXQOEwuUblFR8/KDjsuTBl1XaOGQ2f6oGgwZDkaMRtuMh2NG1QCKgoijzQ/BdmPmbkY9L1guZgxGW2hVIntw+cI1LgxTCfLB2UPkOS0ZEdJBgoq15Q8IigB0cLg85NlhH286tkLJxY9/Lr3f8mf4Z//D6xxXMIzgdQJsr9T5yzwmuj68/e2v8n3f9d044YkTzejSBS6NHU22nJtO3uV0eofj+YTrdZ9+0ePC1jbKaE6PTzg8SxvLpZ0LlNrgCwWFQhtJVRUYXdF1Kcjr4hKsopkrZCHzYSMFCDEHCnmEpWeeBUSJ6VoZz3llZkeMEBbUoyFvfeWrmG7O/rWneee1+4RxntseJBIj9pFeUQ6ndLplvghsXUzXxLctD+7BjasF01lH1Yu8+sDx9GdAny4ZxrT26rKH84roLLWsUFHmMQ5SndvahEKZGl05lLfIbrEG6Lu4uew7HWDjBgGxDt5WJc709Dd/rysPMmVJJNA5i7RLTAQv9bpQYYPPwVmyKEsbd0ASsCFis4dVUIHgUrQSvUNIQREUTjis3Nx8BBZaIo1Hu6zoAIgQzgVqkc6A72t8LWk0NKWiE2YdMErhErIremQUCZLiAwKxftbEmIhOgSyfkqoIkpQhW+8tMgkVD3qSUa+HNoJFd8asbZhmzPHJ2ZJlY5l5GO
iEqZUyZzyjQOfPGIQElcSrC0inmeiQRjOOnwTg87/nw/zEv/t/5k/8e3+Wn/vl25idEqmSP+07X2/57u9LrzV57Piv/uHX+N3PXuOjF8d84+hVfur+t+ka1gHX9jCyfxWiVxRlx/bFltk0wZs679F6tRs36/V9JUYrxSqQzLjQVZUhz/ffSXtfB2aXXoL564Kv/+oj9Iu/CsDcTRg1nlOmWF1zfOctBlIzL+fUTRoy9+ae9kCjhON67/uSkrg4RoYtzqYNZU6NikmAyvLFz9+gKm7y5i9b9p65wmRngn4zGaKbCRizpHm4wH8Djh55GHkuJE1TFnkExgD1hQmTR6fcekcwkGPO3l0y2zlkq5cyB1X0DOoCP+twrxi4ZWHp1pN/9UwbNun0ldXI+ZMKnMOccS7QynvMKiiAJyqa+UOnv87N+YSVysFGExMrqHRpb/IOCp1LUl6mBd2lxUwUIAsBXaRY9WmErvPJbH2RmSpA7Q1dnhW2DVR9qK1kMg1ED34iqe/fpayu4jLGrHd8Rpgs8IsWuQxU080kWJsAk/AKIm5KmZqUeSliKidBKsUoKVBCJBamCCBEtlJZ9dUGGO1d8lEMIWUAIxuV7VVmMZAJEnGlp8Ya1Lva1tPZNz1AcQJyCMOdPmW+YNnMODkwjMuCxdF9Qm+LuthlaB5x2qUMkIow8R1HD864NrpIoTTz5RmdW6LCk6bVAQkhlTSDTeNOCLERvmw9u8UlhuU2Wmv6WzVPXXmKH3rpszy695BH7ySHjhBSABy9y6KyaXyFSeDVd24D8OH9x1zc3gOpKESNMRJ8YBk7UJqYN5/W2uS7FwXW5jJZlLQtLOf5cCDA2pZeZTCqo7GaeRPo+RaTszUAdV1TVxW9XsVg0KPSPaIx9LoeOltYuS3Htp8yX24hZNK66g36dFay2+1zZecKAI/OHhIzELsN6TNWIuPM1sG6WwtHtfOGWXdKG6cMdy6x/+x1AF75xT/HKw+g2AI3ywytkDbwNC7SuK9qeOf2Pb5/OmX7+lVOD+e88KlrhEuRLht3H89ucbM74d7ZES/uXEcLyXZ/SD3o003OuPM46ULeuHSVUdVDUqNkiSks/V5NrzdgnpXzg2zpoqBbavqiIuIy2zqND5N3ES0NMjZAwnGtMj9hRUcDYvRYa9GhJMwbdp/9CP/8//k36UV4+nOfZXI/jZv57IDlLDJbHHJl9/O8+vbP4LYiVV3QhXSA2Ls+4Y0vw9OXLdu7krNJ4PCx4N1JpB6DeJhKzb4IWFtCE7Ad1MogYg4s1wbsCWhvEURTossKGRwiU2/PlyvXsI94LjiLG9Hu9Zq6WkdhvTaouPEDDSFvzhFiiPjOUgq9LvlBCsysTwQAH13OqIfEqozQdWlh9EoQXNIvw8tkkZSdDlav5XOZvZWRUiXBWZ/hIiGk8QogDNgS/KhkWsFESuYKWuHW/pZSkgRUY4Ao13ZKqbybD2XJSDJ/UyCwECQKkb692rBipCxhd2iQRtLEFisd1s3o2pDnPhxPYLKErgfDgaI2guQ+IZCr06cIKSsroQrQCkkRNceHnj/w+T8AwJVPHfPo4JT/6D/6D9j///w4/+CnbjJfwHjcMj8TvPpKtj581mGswqkpU9dhbWS4C4tT6Of9SCnQXY8YFwTp2d6teOvVBqlSX7vsutMTPBGYwXk8dl4kfMrgSpPm1m/Mt/7G9r4OzOI12HkgeTDx/PRXUgr/Ux+qWYws8SyBRU9o+Jmf/Fk+/yM71P0UHZw9knjh8BE+//xnefO05ds3/wFCNtT93polWS3AR8+lvuLOrz3kWXeRQknKkyO2sxrq2aTFvB1Z3oELLexue048XPKamVQssvSG76B9p6IdP4BJwY6qqeKYb//cV+n5DP43FhUlkh7zdxfIU0BCHZIYrSUtXCtpwTXW4dzf6w0//sZA7Twj6QlCQP5bsAncgthcJDmfdRP5pJeCIBlIuxYJlLqqo6sSfCWwRMoOTM6qNTYSGp
CloO5J7DzQNREd7NpHsnURv4R+MaBqlzhvsccg799Djk9RJ2lhriYLxHxJu2yJM4dp0uKYaOJh/ZmMyEyp3Cdq9T0pMGK1IKUgTAmRgKh4vnPiOYFLYhZUXAW+5/v0fDDs4ga3Ete/n3pdsjESTq8qMW2gVyhE9oF6+M6btItPsTU45kHX4oNn5/JFjh7fZbFMm3Wr0nuency5uHRoU+LdETZYeuWYhTtbv6cQgmXXrXFlLoYnEKntouVYlxT9ETEUKFmh6wEfu/Ycn/7Qc7x8kuoG73pHvUilBS9gYGDmBb0aslg/B8dH7F66iuwivWigUonZFyVHAXy2p3K2IwiHd5GudTjniQG6BmbL7JUpEoh40TlMAdXScQxol0Vlsxp8v+5RFyW6SPZDWmtQJXVZr71YK1sxGg9ZdjtYu0TISFmWeBkY9PpcupACs+re14kh4rxCaofK7Cr8ZvNRKvW/t8ml43B6ymx6yKXPfh+RtwH4hZ/7ElMD1kKlEy5NlUALRkpcDoqH/QEP7s95dPOb7N14iVd++V+yqz/JcvuY3izV+fRwm/Ke4870EQvRMCgqdgZbDLaGPD58zP2DFACdzM7Y721hjKKKA5RSVFVFXfRY5pN8YxfE2GGFYRpOEQq8d7mkqdbEEK0NUuocwAuEjGvC0QoYHrynbZec+sdMF3NevHaFj/3J7+cv/i//D3z2U1e48WLaanoXPszidMBbd44YDms++cz38w9+5Z/Ru9AyO8tED6vYvui5/yjy3T9geOftDiMi77xWMn6xRWbZA92C1IJaaWLT4Z2jkCIHCDnTEiRd8NgQiFIhjEE5jQop431+vVxlyzbZ7vSnPVfKDKRgejW/VwdYxSb7tsrKr/jbBJfA8l3cVCwC2OjpgiNE/8RW7Yh0uTbmfTIuT9ZXEikUQXqk2/iPdiopz9Q2VTYalQ5NK1+HmE/0rpKEGkKlmBeeOZKldNjoEokCAIOXioBDCktnHQFJJOLzKTp4sZYQsSGZ2Uef2JkCVi5dCJnY0lu1IYoJk65NzNlGELNtQbDQ+kgnYLoEU0qMTKVxYC29kQSNk8UTCqzsGCiFPdP8wT/8cQDmPObxtGW0NeAP/b5P8JEPKX7i793jG9+e0Osrbt5La8nOGIYXNM3yjMOLBaedYdtYYg/iJI3TunB0YoFwiq3aUPc6vAPpJUUQqBU9PW7A/lKKTbDsIfrNip8Y/r/zwEz+tld80D5oH7QP2gftg/ZB+6B90P5Hae/rjNmFqWDrM4GzXwSTGPpM7jVsb0d8T1I2AdkPfO/vfombv/YmH/k3UyhfOaCAk0cVP/7P/hY/+OIPcnA2Sd5+0TLOAe31xrA7DoTJAvMILn/oAoN+zaVK4cfpGDKeaJq5JNoTzB7IMdSHII4dl0LgLAvpnVaO2aJh8gqMrhgKL5BEgqoRMWU0TLmFcYre4gTfpWyE1tB1q8g7nxTZlC5Xp7dIAt/bcye4eO7nqwh8dRrk3Clw3TZVO+AciDP/3QkoSfpV3RLsAFiyFvcsdMQq0D3QwwQYNgJcnwSqyX1fFBCKXBY0AjePWD3AZzB4YRTIisVcUJQ7EC2T42OKt2/Rv3aEXaR+kE0kLCzRL4miwyIpspXIGvuxKimwkcdQ5IO+2pxglZBorRFSYglIneUBAqicOYsh5FKmI4SICom2Hr5DX64xJxnLEtYduTlFrcqbHUlCRORnfvGFqzw8TP5b4u7TXHjqCq65y/b4OY4fPsCVC/Z2xhw/vp+uSX680Ix47MfsIDiIgo5jYr0N85R2N1FhwxxTwLwxKNlDx0OM7kGRyRm24/ARzJaSOjRY5VDK0+sJPvTUHg/vpczN8fyMk6Q+wJZKDM3gI/WNHf7ix38AgEERuNd2zOoBpSpp3QJHR9GrKdqWJtdro/W4Ntk36UIxWcxou5YQoVjpTgcotUgnThGJEooAThVE3yXFTFIJtAtLpOiSXIg2lChmURBzyqyoIqUt6ZkLtPGAQn
t0qKmaDlc2XNxKGLOoh3TxjM47XMgn4xJKBzJjUZY2leeFAuUFxycNi60TLr3wYR68/g8AuPlYoLZLhqcNjQHjDdJ5QhVQLhBFym1ENKKOvPHyt/ieP/Hd6FBx59EJvSt9dJH6fbt/idvyJncmc9puxrDe4drgApcvXuPh/XssjxLG7NbRA64N9rhQj4nR44qC7arPVq/mQQZFdELR9y1zjpnJkn4YgOwIQqeiWpVWALco6PUqurZliSB0qdwfQyDYLCVjCtxsSTu3dF5y8+gtXix6vPDFz/F3/vqvsZu6lF4Jn/jYNjvXr3LbWT7/3PP88Urz3/zTf0y1PUxjcGq5cgHO7nuO3w289Jzm5m3HvG155TV4kIGhn/INWta0oeUGAeFKBr4H0eNyVnBZRZQLFD7QRk+kQiooZMTHJb08Z1uRsu0W1pZNKw3CtQ5Yho6sIAsry0RBwloVeW5nyCFzNqQA0bRQ1OtkSTIuD4QY8UiiCMSYGJ02y2EAWJe8J2MUmBipvGdJICJoclatEImh2eqIyAwn2abqRoEguPQclyog+j1cUTA1liUKGwLGF2tJHSc8xnYU3tOtdoCQMmIiL5ZKCKJMZDQfPKEDpUOynBMCYTPuUHn6FSzDjONGUvoOQaRzkTbXfTubdC2rDH0R3ibGtyyJQaByTrPQufqQRVqrJTSlJ/QiRZmchE4fP8vUvYXtCho3wAzH/N4/cMDJQ7j9UKw3wfuPav7YFz7CeKfPy+9+iSvXLNVI4LvILLNTF7Ykti1We4rgWU57GBZ4lTxoTc74ySISjUCESGcjQkKdbb+6Kl0jZpFoJBrw+rfPlsH7PDC7+46kf0HQf9pRpzmNP61xBx3qkuPoDC5MJ1x57ipf+QWNnaVyp6hSWaE/bvjH/+wX+dD2NWQ0dLMJIxHZyqCCoY9c1j2a1x6w3RNcuzZkZ3eXnUJyOk0AVDERHL0tKLZPCHVyI4jXwe/B6aOAyWW+nTZN1KYB0VgW/SEuKpj1QKf7GmtJr7MURXowpWPt07XSoIGNivwqOFuXKeM5Ro44xxQ5d935kuaqrcXx4rmF5FyEFs69hiMtAj4kZfvYblic2oAsQJYQC/AKdAlKFegq59P7JTFovBMIp4ilhArcRK5xFUEKZC3ZGtTMCo1zHbpqMI87ikfHDBfputgE3MKhWvDzZO/LORzd+vOd67tVoCbz1zqLCQsh1n+UlGu/RwmZ+gAh+syeirmv4hNs1VXblC3zszj373jumazuR0VohEDHQFHULLdusfxa+gDDj36EY3dG23WUdUN/q0It5qjekF4vbdZnYY6PAtwxdfgcVtZY/ybdrEGb3toQuW0s0feI8RTBAu+WFKaPYo7KGCYbWpSTVNT44BDKUJiaXm/IVn/MaC/p0lx5dEaMSSoiIJElNG3gmZeu8EM/9FEAri+GfMlP+Kenj7BBIU06ObhgkdozzAHJmZpjpUiipSiEl9gu0r1nnK58AgWpTOI9tK1lUEG/TGBcUUBfa5AC6SPBWAIJR+azV6YLCl1U9Pt9jBmhCk0RSs78GX0Bo51tAC7t7nPz5Ayj0+FDtmnMC7l53qKC0OUDkFGEzlGYPXau1PzsT/xk+nwxwkJjY4XvGsrCoo1Kc0qAWsnN+A5TSW7dv8sXZ1OK3ZrT+6eMrm7TM0nqY2ewxcWdMc3ZMcfdkp06si1LRqMRtTIssrzI6eNDTi9N2LJjilCijMEozag/WAMrk3xLYNHOcMsjgm4wZUXwls61NMt0X10zwzYdwSdQuPc+RS8qErN4cQieYC2Pz45o3JJxGDGbzXnuxiW2dxPrGaBpI7/ypRP01054+KHHvFZW/P4/+Fl+4Pte5Je+9BoAeqjQC09/CF1T0b80oz6Dyali5gvKOn3G24WkvnPCfneRBzqyP9BMj+eYTiLFMr9fD03C+wSV6q9akMyy/ab0FM/9Wc3n8+LQsDlIxbghFa1kMgJxfZ1d/a5Iv+NioCSizzlYxx
gTvoyAy+vIiqjliKtzLDaAdzGBxoJIDixaIrLwLEDTBXRmYyqRDgvbRbphESMqmyA0BVRbI3zpmdszWrmF98nc3Wfoh4whrWsiruXLpASBwrkNVlUgiSK5zWiVtfBI5TtxLpiKXrCcR2IIqJwtaBuYL9IHXGSvZHTW+ZIJ5GF9pPP+CWNwKQROxWQYEgps6FDA8ckbADw+VbTTUx5Pj1FlSzOviM1T/J4/1PATPzEnS6cxXTbMJmfsXS9YTisWc8ueESxON1jo0LXECL0qHbpmc0vTJaeh4Ege0aQAXNqUgCi0xMrA3EElQczzNRrKIlCUsNVT3Hpit/jO7X0dmJ22nrMHQAFb+QH2di1y6ijNgC9+4nNMJ2+zvy0wZsjxg6QXFhwsl7BlNHdOLfeP7qAwhOkZAxnYyTXyS0bQe9RhTx+z//Q+F/fGlNJSuIKtvLUWOjA58AyODXpsmUcQI1B96I8l8TTXymeRuS45dZbl/Ybx0wuEvIgKgdaPAdgzW0T/gOFAs91XFIctwmgWnSNxCVbhRUg4MPJiEn9jMLDCnQV4YtKvNoMVNuJ8AMP513gPBs3lLwoiFknnAm2XQd8beRtUCaIHsa4QVYWsK4QaJRV1QHhN9BJpA6IDWUZoBNFAtcgaX9bjpKUeDKiGZQoCo8KenVDePlwLq9qZxy0txTwmOYXVbYtNALQKysS5P4qsp6YESm1YZSJTaqSUG5R+jGtg/ErteqXqv1qkV+bw5zF9sMH9rRWgz//sXOdGJDoGGhTjTy558C9g9FyyMQvG4o4OGWz1WTy6yXDrAuahRfQMZmVhNYMoBc2ko/KOnWuf5K3X3sZ2DbrWiMznLfoQ5i0m1Gi5BTEgxZTADJvB0L6DAsmo6HNi54BEKUO/HLLd22Z8YS/1/cXbFMOOg4nAt4aq8YxioH+hxF3OW1QJgxOwiynLuaAv+3RO49qCIozosjaXEkk3yVtHQCFFQfBLfNxYMhHB5ZO7imkBD0AbI1orejnAHktBU6bsoCJmRnFivq1kTxASYwrqqo+SFmkUqJqT01OmdkavSv319Pgqb+s3EI5k/aNTxtVGyPEINoOspUqAbtPCU889z+niZd56K2U0G12i6EBYogPnS7RxuDbpQakcOEsRKKuKB4dzJnfvMNjf5uDmIR/+/HWMTqnD3XqbD+0HlkLz0FqelYJ+WXF96wJfqUvahyljdvfwPg9OD9nb2qFnhzgDSiiGvT4iB2ZTN2dLg2vmnHkoSs9ImoTxDgK1EkMWMmUrfJKKiVl7a0XSA5LuWddxNvGctRP2rMESKITnwsUeB3cTbre/JamHJYtDz+vffsilp3f4a3/zNXrFSWJykzBSKqYg+Pt/8Hfx+q1fYHvsCT5iW1hmO5ymDw8vhjROwhy9W7MMc3YfjJFteoa2skQhcdERSbprMUaEMqALyIbonDuYnheGXtkYwTltwvzv88GZF5sAzuW115Ot2SJ0RJRrN9I7rNjucR0Arg7Ojo0obBsz83sVFUrywrKR9xUyEpRMB1ME0gc6pTDB4yS0o3TlYqsPOzVHHHHKAEtABE9QIZlzA8LFROSIkd8sueNjsu5a1yVECsqkFtgQ0atvF0lIPDbQuHXdAO8FTQbudS7LSKwqG1HgfcQ5T5exprAKDpPWZOchhC4dmq3mlbf+DgBv2x2++6UP07SPOFkEdGkZDsc4U/HFH5jzs38/vdZ8FvnGN97m6nMLGq+Zzkq2BpbFBK6v9FBPgSgpZBJjf/jIoqscMEnWGGApU3rVRsDlAFkKnNPIvG/omaC4qOjvGFSxrln9lu19HZgVNehOMe/gLNPX551lJxqm78z40U8f8c1aMB/3uHjDcnIz/+J1TXvmGJYOoRSnswW16tO1Cy6YATtd2jAG1lJ0DZUUXNy6gGoUuusQZ1PCw6TV1J0u2DmDo4NA5VOqvluC7MG2qWjyST4aw3LRUpcDgrSU0hMl6J0hVz6T6rB3vvYaZlbRKyWjEb
S6RUS3ZhKuJr7KafdVZWxF6V6d9vK3U7AgN9muJ2QbVtfF3yQ4OxfcrKQyyKl+F3PGzKYAZ7VCiAJE2UPUfRjU6MEY1RuDMHi3Sosr8CCtTzPMBNABZxTKZKbUpINOIUJJr1dCZfBFSTdxxMMTCpHrW60nWAiNQLa/uUbM+ZPtKlsmc9pdnqP7p/5Ip0AZBTEGIj7rBqXMWWJ8pU56IuA9l6EM+d/rfuQ9JI11J+fyWAwY+ui9OScHMFbbyM+kCye3TgjhId6MqVTJzXfusD26xsODg7WThZkCOjCZweLgdW688BmeufG7OZk8oO3N0axUvW3K9glw0aeFtIloVaJW8gLSYkKgX/Q4WjZEIjJ6iqJgf7DDhdFuGhNPXWU4PUGaGSfLSHFmGTioehXL7ZQJm5cVQ6m4VPV5053QlzWlNjRMaeIclym0WimEKpkuktK3KA3KaHTb0YVNf62yusJDyB0ZY+To2HN4fJRuX3mKzuOEoBUwypmMED0xpwCiFESlKeoaJQPKaApZsD3a5c7xEVXecJ7ZvsZwPGRxMkV5UDVUUdDOI1Wx2n1A+ogwkkIYKtvysc98im994+9wlsk6mkAXPCpKlAqYItuNIROYPlM8o/OoXoU3C+6//QaXv/A93P7qbSZnS1SdMmb7F64z0juc+gJURQcMipL93jbFYMA0j6uzoxNuPn7Is3vX2LIOE0qUMmzXg7Uv8XR5RtApjXE4eURtBJXU2GAgKkzeYUvVp5PLJ1iFQiRHjDWr0Xusd8xmnmnb4u2CprMI27Gz3WNxkgKz2SKwWywRu9DMFY/vHbP30Uh//Glk9WkAHt56lZm7hRfw8z//NW48L9jpGdzMouuakPvroD2h2i5pOoFxgdor2r5n2T9he54Pgk4RpMTlTLeIKfsklCKqgriy1hFhXcIMJOZdF7NTRm7nZYfyo09A/3NrJaS1UiDSGhA2RkYhBlT+l8jahqsA0ImNDEfKuIn8WhHrySLWkRATY1xISchgfG00rXdIGbEhMixAlYFep1CDiL2QUma93TEMF5xOFixFlQ6a0eVT/WqNi9kTM5VPESnoDp41OQvI62BEKIHNjHehUqeEFQu00ATpsTImQldORcYgcFm7xtvMYgVimT5nB9iQIDO/AQTfpTXURE3bSpTqqMfpR//d3/kGyx+Y8PRTiqX2LJs+SrRMj/Z4/mqf6fcktviv/iqcdpGhHzNZHkDhmE8AJdnNtW2BomscBEUUnqNDMDUEm3QBV9I7ZV7fo0jqBNKm+29ay4cvPQXAtaeu8MrrX+JSr2Db9Pn1ddj9m7f3dWAmKompPJUXZCIlzOAwekbjih//6W/Sf27AJz93mYt7V3nn/hkAo57mduvpb0nCLPLo0SFPOU3dOkZRs5UjoIExlE3HVn9IISz1TFAfN7g330RkFWE3c4yUB1HjJ3NUJRFGIwOUwwssTKqxxkFDKPsM5AAl5gRf0h+MOTucY+8m6Y09ETBqhLCPGfUNix2BfFwk0bAQqPOItzHQ5QksYD0ZQvyNpbJV2n1NAT8fKOR2XhoDNqXNVUsDT+DFSuU5lQBiAGOgrDJ7q+4h6yGyN4JBH7E1gnJA8O2aGRSCBDTCSVQLQQS8SEECNvWpKj2VSnX62AqE0RSVQdRj7HSJyebXvoUuCHAKiUQT1p/zfKy5oq6vyIcqpnKUOPcZV2XMSAqUZBSEGLNuWT7K+7ixZVllys718+rrVbC2+v75EskTPnv52ekIS+YUjSA+EnTbE9y9tIicGQjtjEJ19IY7tIcPaTyUqmXlIlApzRLHUsDpvbtMH7/B9asvchIVk9kZRqV0ou80pujRtacIbZEGhCxQaggyMzw7i2s7pC6JokDGQAyKuhyyP77CtZ2rABzP7zNQju3BmDvThm7RYu+cUFUVOluM1WrEaCvy1HifbxycMnMdfVkiJTRFpO/SoUWj6WJDFAuEEBhjMEYjVbd+Zqt+W+0NLuM5tE6lm9e+khiQv/
TCN/jEjY9CTHLMbWepizothTkwc9njTxtDIdP81Eox2h7z4FAiijTPRtt77PV3uT2dokwKCJsY6Ze5vESqMFWlYkmgnbd87Pp1Rvvwi7/wZU4zGKnQgJREmdiAQqdyCNIQQ9aEAiwREQK6V/Hw3l0+OtrjZlVzfPsI9Xzq07La5srwKc6aDq16OCmIomBcD9keXeCheTd9nlnHzcMHnF6fcGHXUoWA1oa6qhj1tgC4e3aXudY43zFdTDiWJaNBjdSGIgvNAkhjkt6dtekgFHLWSbBW3RdZQmPZPmbReew80vhIITXEOZcupgXg9j3LyUIy6qdguQnw+usnXLn8JV748McA+NALX+CtN/Z5dPBNHp+eUjyI9GvFaKhYLhvGWx8C4GZzwqmC0a7iYBEopoELvQpXNdjc9/1lQLmEjxIiaXJJn3ChXRSIfCCJPpkKBcHa9WSdITt34Do/Bp9cP55cM20+4EXOYUwBkVdmmUVwV/mTjs21FmjzdS4mWaLgszNATLZTjo0kCD5l3oJI7MxCKQQBMbhA8WwP3ctZOq154CYcIen8AmXBSlB+swd4PDFCEQROpqxYzNF3PHcYWv0RQmBkxDqIUVLqgl4v1QP7fUNkQRQN0kWW1tJacDaQ42FChutECdqTMsoBupgOkOrcQh5CTKVOD9Y7OiS9ASzaxEQ+O4Wz+ZzJbESsA0Ee4q1ByQsMy22+8LvSxv7qq/d5PCk5fHhCTzmOG0njZUrR5vuaNw7RwI2iQGioTEMXkli5hPWBpAPKmKswMakQWJF8Rv0ivdj/4n/yx/iLf+nrhC1J3/zOJGbf14HZrAmIPghZUm2lTV13cOwCjw4bXngaPvTCJU6XkeJiySLXfK/0Gn70e7+PN7/5dU7VgsnxIVL2KeaenrWbsqir6IsFI7nNCMXwdIZ7+RbtmweI+RgA0Qo6O6cwI5pFh+oUg/FljKoRokevGAHgwxmhLKmdQqseVilkKxmFmrN3U/lhbHosVUu30JS2pNq+gNgRxKbFWUuVQZVNbJIi82q655KPF5vgYxWQfccy53vaKkiT58tu56p5gjwZWb1QQGZ8li4EpkqYHFmNoC6hHCCqbSgqXAGx7a1FaKVMFr+xU+n1y1RmKq0j5PqQ6AtCtAS6JMpDQMoOYWrUokPl2mnpJF3QhKApKVA4OuGeCH4gLaBSpMmzwpalP+9JFeZrV6KwIQSC85tMWMaWPdG3v0kZEzYki1X2MsrNg1g/k/y7NZJuFujXkamF7u2UkRVXxuhKonxk0SzYLivunx3TK2pkVkAXShM7h5NwdNxx+vBtdrd32Nu/hHvg0HnHUEQaP6HrOqSokTJSFBYhO4r1BgXLkxlRlxhREWOHw0AU9OptLg0uAXB3tIMvLENVsLUL88mS2vSoKsEi+zbNxBw9qHlhf8ylRxWHB0tqNWRoLuBETb9IC9fZfEanWqIx6EYTvU84v3NrWMyVE5nTFCFsHqYXEZdL2f/8v/85rt14hn//e/4wBEtRD+loscFis+K9JeBE1rJTBikEFIKRG+CFpMn+nIPegO3eDu/Gm/TMRq/OtaDKnJGxnoX3dFpQNPDx7/oujqdf5vWbC+ZZ0l92NVLM6EyL7CK2U8QYkdog3abEGoJbe9Q+On6MDJ7q0pjTewdcev5pAI67Gdf3LjOkSOQNrQnaMCz67GxtE3MwFeYtp4+PeXB2wrVuydD1cTqitWZUJ2zi7Sg5WSxpfaBdtDxyDxkMttgaJpxZ06aNbNlOcT4SZbrv4JJMjAhxg00SMomNtjPadsncLPHCsFhAWAa0Ts/60r7m3iPHolUMd6A9FWAN9+8tmbZfBeDqpfsMt/pcL57lzu0HvPLyMS99xlOPJA/vLdjeTdnDq/sfZnp8xu7ODe42R8yHhwyu9NBtw8ki4+PqPmruqLxEB5ns1MQq+yQI+QiXJCF8hhY8GWSt2irAWuN9RSYAnM+EszmUidUviJSByzaUmzHNxkfX5eeu2W
DbyD+LPiW01mVYnyKBVZbfu4AuJSYEikKAkqjqAtUzN9DPlVR5Ps6nUx5NPMcIZPQQEk5MCNZSMiHfZ8LAxeQ0IAQhCmI+QPgQ12M2+oCSiuA8harYGe6xvZ32g6oWOHtG56Y426BtQDrPUmz6iAxJCH5TFYoyaa9JKZ7osBCgCUkEVosU4Gm3TaXSeBjpIy4MP0PRP8TFMxp3gm379OseOzsjFi4FcM+9AK+8bJk1Iz51vcfPvfw2UxW4OKqSowlQGM98Cr2eYNYpruxL3rrlKcv0WdUqc+xTBp4QaXN638iILAW3DhOU4a//X/437EhDuXUVVUjg4DcOrve093VgtqUVduLRlV9veG2E7VrRNp5Xj+D6g1uYjxTsbLu1/9jjo8APPz/j9jtDmJzhraYqOozz7JjAVs6+abFgZ6CovWXYlbhbhyxfPaBY9HErYF9zzFD0CW1DzyiqYpeq3CWaHlEIimKVKroAUjIwloDGqRYpCjoxZ2t8GQA7n6KVorNDxJGidoFOOPr9AhzZHwywJdp7VHBoF1AhGUX7GNaZonU2jfcEZt8hUJNi8401Nus9C1Nkk3VSMYF567JEmxpl0kQUZgBFSTRDkH2QFRBRpndO8EcjfEZPq5isbgoPLBGZpScqCC4gnUF6iehySVEoDBUhn4aNV1SURBWIpSQES4zLFFCtRFXJIP+0XqVFVUq0zOWE9XE4CRiGXCIOkTWWJmTjvWyRuV5EvxOR4okOE0BMJQtgLda4/nn+fQE0hGSIu0hd1yQ4JIOtSHdZ0u/3cUND1ziG05amEhiTyRKlIOaay/EcHtx/l8HFq7htx0jtbVh/fomIyd3AOYciAe6jT0EBpD46e/wQm2keIni8DNguEqNktz8G4PrWPvf8kiharowGqO1t9kZjmq0Sn4MDZEWJ4fmdHf61vUv81NGb2NjQLyo6UaLz4r6yYDG6IpqAbTq01Ggl0ZnV3Dmb4MUZ5CPEKkOcomSThbjtDP7+X/2rXLy4x7/10g/Q2SnClISQgOvp/QIuxJQZ1QYjFaYCQ0k9HNFMT4AkVntt/ylevvtVVADVg9ACEmwuZQaXREyFj4wrzYuf+zRvfftvc7h0kDPJZycdrgn4UiRtMOVQMemXKRVZRaBRC3T0GGM4XUxpj48YXd3j/tduc8VeB2DZnnHcHHFltMewLohaIo2iVxj2tneoMo21jVOK0zlvHd7jpRvPsjMcQZlwiftbqRz9RtHnZDIh2Aac53TRcTifoHSVxkUe3EGJZFPnLDGI9CdGggeXbcE0abx3y8B8eoLr9TmendI5R6V6uJCC4ko6nr5keHRsaecCYSpms46qD+08zbP79x8hHziKCnZ2rtB0Y17+9VM+/bk+n/nEJa5fSev420dnLKNgZhT3haLxDvfSFR6cfAv1KBfBaolrG4adpAoKQURKiXcBpyRxHZgJvtNUPneWWv+b+CQ8QpCZlKtsGJtsehAbfG7uovSsxWYdeW+J9EkUUhIrlR5kiOgQCFHRCBKwMd99jAlyH7WESrP73DW2X9yiG1RU2c7rsX3EfNIRfINMVbqk+YZYM00JPh8oV44lIpUxz62pZD2/FbY4EBBKoo2hN6gYDNOep3TAmBLVwTJEjPJo3VIEtzn0+8xriCmhHfLX5+8hv8kab2aUwHYQYmTpTpikQhhXdgrKrSVFscszu1e5ddfS9gS1kZxOXufxSUre1KVA6cCsNXzfx65Q9Ab8/K9+E1026GwLVpwpJtHj44JmUXDtIty+l7TnFOlZrJ61y8GtMjLhytpIoRXLYbrmbd2jnEgG3tEfrk0Tf8v2vg7MmoWHHnTCUueH1qCJS0fZgwvAP/1Gx8PpTT73A9uoOi2U4qzh7/6TlzmrLb0etMslSrf0DZSdo5+pvAPtMXMPwxJ3NEHdlFTtDrPljNik0VBHCNZTa4/UI6riMsKUUBiC7mcQFkRTEN2cGA1SFxTtFmiP7u3jJ3kjKGpsM6MsxjDzKB
mRYslw2EP5yDQDVSvqtLFal7IfbUt0Au8tOq8AQcS14OF6zv1mHRmf+GtTBtxwDRJoXqbATAcwhaIqDFqMURmYHHRBFFsI0QcqJH1ErIlmkc28QSRnSGKQyGA2720GhJD6VCuPWCaGHr5JAFFRYeQxqIImezKVMWKEgTJghWLZCXTwaRPOp7tIWPvM6RWwH5lNZjfH3CSNsTkJ+rBK14tNeTgvGqvANsLafWFVquTcz863FcbsO7XVIr/ibBVzQ7/K2Ypasd+r6Q1q4q7BN45quqQzjp1Rmr73jieYvCg4rTk8OeHi4R12hiOaZkFv5boAECUiSGzbJkAxUJUlOmc0ezWcHT/A+hm4hphB6423KOEZ5jl0pX+JdjnBuwMW4YxqMOJC0YetITk5Qic0bbD0tzRP719g8drXiaeWeusSZa2wboPvK2VBWRcsoqRznn5UBC9xJo3apm3pnMd1Nqlqr3ZEUsyfnc/YHgATz1/9f/0VLvz5y/zoxU+ztA6hJDZnDtq2ZekDtYpElcdu9PR6Q/bH+9xapvkYq4IbO1fo92u62ZIKqJWgITLP7sqlUagOuqXnIzeeYffyPj/709/G1xDb9NrNfE6QyS/WUlCKjqqoWfhkw7M6aAglkN4z2upzfDZleu8u+5/6NO/88i8xy6LKZa9jvjyk61W0gwF17GG8pyk0l/sjTD9lD9xjSWg67jx+xMnklEvjfXQlKbRmWKUBURQVB7M5oZtilx2TBRzNHlNR0zcVyBWGqSZqS1huvB7jahPNY9ZHT4gilcdax+l0hvOeZXiMiA0iy50WwmKdZXcEx9NIw5JylPxeQx6PzjoGQ5ifwezohGsvOMp6wBvfWPJHf2zAOAcki7Ml6F28f0RZLpm1NXHvCo97LxMuZXyfh0675NQQ1FpE2noLXiGyzMpqDr+XvZ2nTPo7bsbbas6uMvDvhX6sMmubg3BiFJ/HWPhVcHYu654NYzYQiTzMZcg2TyGigkwBQT5MFdrgQ0CpZL823NtFXttC7I8wcRdVpszNO/fu8mjSEbKdWhlCIlmcqxfqfHg8h9XPQdlmbVyRPtIhNzElpZZgBEF3LEnzR7guCRNrhZdJuFZIhVKBMs/rQokkAhw9ComMMXuApgVzvfbmSkXsIPqId4lcPF/KtU5yMP8AADx/SURBVH3VxUswtY5nRheplOTG/id42LzBweNfZz4fcjxJh4PBlkIpz2xxwOGjyCdvPEc5aHjl3uuUWVR5HjuU0TjvENKzu23YGsLhxNPXoHOKNKiIRCJsihw7CYUB33p0L0OP+jNGp0Pc4yXy4kr297du7+vAzMhUxrQ+gkydXjhHVILpPLJTweCi5pt35+hfm9PbTZNVHCl2L0KUEkoobEBHS+16DFAUcpFfK6W+ywDmUcPieIlqNN41SJFVy4PGlZZSjKhHO9ieAdNH6gGyqAh54kcRQPeQ0eN8Q1Qu0Z1dm2jcQPQKxHZKSXU+2bb0CoSD1npMTgvIoLDWgmwy9guqTkAHbS7DZLmXdZCwwjCs7ETWhBpYLx6r6bnCZOVEUfKUBPoRygBFoTBVH1GPMXWRUP+QsEq6TxB9oi/w3YoZV6JlHpChwLUenECoRL8KbYeoDDGmTSW0LVppghTgJcot8O2CKAq8DhTZWiuUReY1BaQI1Lpg2QDBZQsZksaVALRBSAVCoUibiA5iI1UQQyoJe0EMEiEDPrgnFtz3SmLAZkE+H4ytA7eYAmR//rpzryfy/1bU+HQ6FwjtCGX6Zi0Fpn+ZwbjE15qDIiLqY4QM1HXaoHpCsvSBaEC0DjmBo+Wr9Nx1tgbPEdqUfitUSVEqls0cU/awfo6QnqI/wB+nuWGZcfp4ymTaYHUkug7lI/1CEvBIlbItW7NDtge73D87o5kvWLYnaFNyYWE4cQmIfzlsM6q2cfMGr0pO7k958/CMwUcDqKcY+YTvelyfsRV6yFlk2gPlIoMo6JuWZQ6mfFviTI9uPudocQLW0n
oQQdAr5FrvaB6hX9eM3mn5L/+z/xt7/7s/xydG1xG2THoXQIeiFYFCghIFhdRI7fFxztX+JV5tk3ftvDhmd7TDxfoGD7rXk4WSTsDfwtV5fC3pdI+2WfDS57/A4f1f4+bDAl8J2mUG+XqZbJxETU+AMCXeC8qywNqWmDcpETS+rsB2eCN5fNhwsQYVLdOEdsD1BMfH76L7A67Pa+phRVuAtiV79UW26pS9PrVvo2tDe3DK/bNjriyX7BUDUIoyS9fsDLe5K0rsEpqocNZzdtQwMKdEuY3x6TPWfsZcGGTokCl9ss48i0zdC64AqbHWc7ac0tV7dHaBFGUKXnMZeeEEXkqMDlzfH3Dz9gwqjS5d0tcBCJ7TGSgNIS55/TXYvyLoDwtefe01WKaA5Jt3LS8965H+IicngvHYUA563EGxezUB3u+fJBsdHwwLLDXJk9gDCk+TSWMBQZcg+5SsmJQkF5AVYoQNkH8Fi4AUVA2I6/Vhbet0bl0NRDTnMvQ5w6bymrDiUGXt/HOBWUyuJQ5E59EmBbBlm7TEADqZ+fqyh60HsDum2hujdJ8L+xe5+da/BODObYebl8TYUgmSTZILyZpunTkEIiifNSjx66yfzMxNHUUiHsiU5S0AoyNb/YJSWuwszX+tCxyKxs9oXXKGkVLTrw1Nl+ZG1zkKk3wEogs4mfYoIWKq5OQhYb0kuoDNfDFRSLQIKKFo89x46vnnuVw+hWsNrr9Hv3eP9p7l4Ljj5HHD/m4Kdx6841icCaTZQQ12OTg4YHRVcPMX4dKF1Ke9YcnBouXgFDo8tvJs15LlHPp9yXyarmvbFKh6I6gWEd+DZc4o9nNacLEEUU+5eSuy/dLK4Ou3br+B9PBB+6B90D5oH7QP2gftg/ZB+/9Pe19nzK5depY7R+9SVgrX5ZKEbCmLiLUw9WBw7PYF796OGJVOzM8NFU1jKS8UVCeBgQzULmC8A+FY+aZql6L2MG2waPyiJTaaYN1aTVnEiEcRyxJRloiqIFYVoigJZYHQKxmCiJQFInhUhOgbRAgQPHJlLhY9sagI1qXjIol55kXMeim5ru0cUUYqVSU9FZlOIkiFbFNmKvg5grDGMBgSJTtDdDYYMvHEX+sm2HitlflrI6BSgl5ZUVUVpijxuocwqTAvVUmUFVGUCFFCNMSgQTjcCsTgSN8LAt8JlAMjeiCythAgVE78hwQ4TrIWOl0jNayvCwTlUSpkV2mPMQlIGny6xkUPMpEOZC6BSKUSToaEb4KUvpcxpKOOTECzJEMSNsbj4UlpjDX8QWxOmufb+X+uyh2r6qmP+fXf2++CXG5Jv10UmqqqKHo9bCXQlWHWLDBbNbkiRdkrCG2LdRElBGc2Up8o5qfvwtYLmPymroDQeKRWaCTGGOQyn9bzabiUcDI94Gz6gP5gh/liiRQahyBGhckkgbp3ka1myqQ5pPPH2DBHdpGuszRuJToacH6JDQ1TZrjFKa/evsVTe1tc37uE2ErswKu9C7TlBNUXtG1kqiPROx5M5vgqPbPhsKIyFWZ3xHDR5+TolLhYQAfT0LKV+08JOOyW9C6MWbxxj//0v/1/8Bf+xJ/mefMMKutfCT+DNuBFQ5BHyPIiynkWNrI1qBjWieBw2JzynL7A3vgp7py8Tg3MNPgKZk0SMO1FaMKCG7XhmU+8wNd+/f/OvLOgSkSugTkCPkIRk9ZTUILYeogqz9uUffcuYoNHCkkhFUdnRwxMjR4OeXQ/ZT13LwyZNUvMo0MGakgpBmxLRaUL+mXJcJhZ4MYgo8TOF9w7ecyLbctF6zHSsJXVuEfDEbKqaX0SjfUOlsslk9mUSg/WJStrLU3T4JDYAlorUD4ioyUjBlDCEbuaaAMhSJa2xUXHbLFk2QlEzpjLYolxiUk73LvOp7ZP+OpXHlKPJCJTPM8cbEuNrhy2TSXTyXHHVs/jyj637qa+1xEOT2fsXbI8uhe5dBH87Ij53DDYT5nDBw/fYNSH+SJQW+
h1sJPzUpoN5tNnsGcTJT4qIpZKQec32atVhluQ1k+VZFbzzwRqhSVlRRQQOXsec+YsrtcEn78+X6mIJILOed7e6rUcaR32Eay3CKXX4tieQHARZyKi30NtjxADhS8uMpG3ePmNxFietAErW0pt8M7S5VKmALq83ggBJpcNnUo6fSLDV1YrmpdJe01L0DGivWbY304lcj9D5yxqYZZILQmzBbYRBBtYWDC6Qufqj/ILOutwQJHlNs6Xjs8zQVes9jZCYRVoiMryxp30C3/oC7sIcYKwzzIej5nPHrK79wwP55ZyXxJiyuQ9PgUtIo/Ojin9NmKgePPwHvNjaJqUjRU9R60TRKKVoJSkpxWXByVedbASog1AC8tWE0tLuQCvNSo6bH6w+8PP88WXnuPml/4Ve7svAv+I3669rwOz7/+hj/I3/9a7dAtH0BlTFGG+SLRhJ6CWEiMCVhp8RvF4I1HaUw3G+MWEXecYdJ5SBLR3FKvArItIKRAzj+sa4tzhGg/Br/WVpAClKnTVQ5Q1sqiIVUksCqLSyBXbTQrajG8QQhPRyOAgio1Cv0gb6wodFaVAS0XEgdqAcYXWaCeQToEUCKGQsgOh1rgQsXRI15HkNVNrcjBwPmBYg/1XAVrMAp5sBkd97k+lFb2qpq77mKKH1QVRrsq1KRgjlhDKJCYrVNI7yjiA6EwGV0iwkegiMkoiFpH1toL0SW+IkADzGc0hlAblEfn90AHlQxYh1wjRYZB4J/HnqJlCepTUIGXChyiVTI5DElNM/SDw53BwEp8xZhtz4tXC4MWabLVhqp6LbFegX1a4iHxJfC8WJW42hs3NppVyFUzVpqCu+8iyIupI1Sux0lMasYpPETI9YynSuPcBFieK2dlNluUBJqTA2ZoWJQTKlCgcSqk17V2dE75c+sBbb3+LT3z8czhnEcrjosSoCpnxfabU9MshQ7PNKUcsmiWBJUpOma60azx4G5g6y8F0Smg6jh82vHn7gItbl7E2lcpkYdBR0GrHoNPoYR8RFQtd4Lv0WnVt6Oig8+wOh1zsb/PwdML9g0eUvmWlaFLIkqExTMual64/w6tfu8V/Yv4af/7f+EMU5ikgYSXd8oC2rFk0np5Y0JaC4CVExcVeCiLeWUaEhmv7u7x6z9AIi2rBNTCSSWh3FASvHB7w/Z/6BEX/kHe+cYYTaYyHJgfEdmXbld0llMFLjwgkOQm/IqpIZMyaW1pzePQQWijGIw4Oknn8vhux6DpODw7oDXcZ969gracykUFVsr+T7usbpkBYj3Id9w8fcbSY8tz2ZYxU1JmVOe6PqHp9jvOAlhKaDk4nZ2hRU1XpOuccMiZWY2gjrktjVLuaIgdcod2i2AI/9zy48xrD6W2qeg/tNb1+n65Npat2sURlvN1Zc8bnv/fDXNvf4+/+1Dcp82ZX14JF6ygt1D4zGrWgcZ57jyf0B+lC5xfMZ1AeWboltHT86q9/naI/YplB3DdL2B4KBsvIcAn7AmjFGh+2ElEVrARhk9VQjOmAh9joI67WSBnT93QOooRIgd2q/LSCZ7lc3lyxwM3KXiy/l3dJCsPmNS5Zg6eyZJEPSiEGOtLBuBOKUgKFwYqYoDiQiTkp2BODCj8uOW0s9f4Jv/K1f8HhUfoEpoJSg4yW0MKiSQQAL8R6DAoBXooUcDuIQiU9xxXmJX+udNjNJTsDvUogxZwYWvSq5IlDi0AtK5CBJja0HcwWDWXeFwu1AfybYkOm2Oh752A3l4YL0uFItYkRXBl44+00+beQvDs948J4hKCksQJEQb/qY6pj7t5PkkCPHkNRak4mHtqIl5JHd2dJsiM/48alrTS0eatqAtXIQFkwX0rqzMrulhOcD+xEkcqsEnzlGGA4zXIZ373zEf6jv/yn+XP/5h/kh7av8Zf47dv7OjC7ckPzqc+8wK9++XVMnmHWZ5aMSyC8zhpQLaW2xPxpbx5ZnnkmBRpWRsYWithgcFQhUq0mawBJRFqJtz
ZZMYSQxf7SNdFIjK6QqkRInTM8CcskolzzqCUSIVyiyYsUdEUfs+xFblEiQsiyFQqETMrgSiTW3IolJUQC0wuPplzrivkYEsAUkK5K1kku4kXYMF3yZzvPDQlik9HRGV9mIIdJaTIU5GyZMZRliTI1QpVIKddekgSIDqQVrAECwSCbCD6TIKJaK3yL6BPl3ncI6xLFjTwZRWLdpL7REN1at2utCSJSsCaVx5uIEAUiiqQRpDLDc4VXUNnfL3qiSODZBOpYHc/8OkgjboD/Pp7De8QnmZhrTbj4JKvqfFsHaOd+P3fVE9iz9RAQJCuV/IC0Emhj0KpHWQh6g5ayPwApqEcpsNHVaSIvSHA+YjzYaWAxm7PcPqUi4W2kFwghU0DW2U1fr+8onchFENy69S4f+8gnCCEghMMLgRESMnZPqkBd99ka7jPuZnTNEu9mTJsZ0+YUgOWyZaiHONFHdH2KYgsm8K03bvLMtR12hymr0bqaKPt07SlGlRRG4/Yqrlc9Tk8eAKBUpKcqrPG0QdCr+lztFVy+MOL+4QNizh6M6pphXbP93HU+NLrEd528yF9785f4P/7kf8uf/b4fAWBr/AKtkyzaE2QIFErTkzVCO5aLUy5WiSV97/QxcVRwffs626Mhd46O2Y4Ff/rFP81nxp8G4CvHv87yV/4Lnv3QsxwcfZ3704Cue7Sdpc2eUsFLlBZEkQ4MKkp6ClyQBCXW4OuQUiaIAKUQLJsZ3jrKXp9HN+8A0M081jrmkyNO3beg2EOJJJBbK8P+dgrMTH+IPzuhFJHJo0PuHt3nxb2ryErhs57buDdg2N9KTgAykQ8IkdkiIDhhaLMUhpQoKemEgiBQJPyVaJY82/tBAF742FPcbH6Vh/fu89rXHnPLzYnxFXxhqHTHIGc+e1swnbs0/wJM2j5/6A9+DmGG/NW//ysAXJARVYN1BT0tslxDgyKx7TUJp/NdH3mex/OH/PpXZ1gjkdrz4OYZ29tDDrNlXoPgrIbJIDCooZvDzlxQLqGCteCrzkBzLdMc0EKnYEt4VtrYIme8n2Cs5+CkLsWaJBllIg+lgCtfIwShiOvgrohJNqtrgBb8yl2ETCjI81KSqx0KtPQYKbDO5QhCrZ9P8Gn9b2zDyeyMrWHFW2//Ku++fYBr0vw3eklVaoJ3tC4HmlLgQ1hXiVCbgFKFNChdTEK4q8ObCGmNlRLQAl04ojxjMm8pK0Phs6h6kPgGVOjh1JLFssG10LaRNm8CvTpZ+Rmd/hYxER0smaW5Wmsz2cRakLXAKUfw6fdPDtLhrV04nn7qw+zubFGyxbAa0ZSHDPoFb9+Z8fB+fo4eRjsO2wh0T9OGyJXhU1x67hZFBuwvFoGOiO0ALWnbgJcKqj4XezfW9m66uMeM25xpT68BKwxBWVxrmaQ4kPJjilf/q/8VF75nj5/+yV/kd9Le14HZN799m/0bQ6pXWFv+aBWZt4GgSfRa3WJtChhEmUKNufc8ngQuXD6jE56tFrxfooSjsBGz1tyKCegYJNal7I2PPnmmrQDbWqFNL40qpVJ2Z6WELERmIQIIyqDBW0KU6KgJzhHagMxGwNon8UslEoVZSpUyM6ZIZc/MFxbeEaJHqEAMiT2i8ZiocDmjIX2FI2CCx8SYgjQ22bI1PTq3EDdJH0UqW2b8eQrKgKowqYRpSqRWoDVKmY2mTRDEKBFOErUiWo0MGufC2kBEBAHOI3EQHdF2hLZB+Y31EdmIPLGaBCJIfFDnMlT5FJXLjUIlpmFUMX0dDTIDqmPmqgeRq5Sk4GRVulyxRVm9fvREKbKQ5kblH/KinF5uHZBFNsHXSstoDfzNwe57y5WrlnKCm+zb+vtBrAeYFhKhDWUxwJSGYd9S9Ed41zHY3gFgvHfK6YMOEwSOSBHAupb5qSDszWhSHYI6JFXvZMTu1wKRQrAuI0cBwcLpySGRLklsaJ2CVa+QeW6oAFIb6t6A3W4b1x3zeNGxtA2T0zMAllcaog
jUAYa6pKwrkLB4AG/cvcvzF5NYbaGv4lUfpZcoWVEpyZIOr44R8RSArhMM+jsU/QrVNITY0qs19bjPxQsjumzCfm00ohj22d/e49qlfYq5548PIj/+C/+Yv3KSrFt+7Lt/L+PxNU5lR+E6Fs4zG4/o+QGnXYvNJ7hhoWnDjH5R0y+u8MmLz2JO5/y+lyTX/nAiQWz/3Gf44r/+v+fN4gGvf/XbeAkiCkJUuNV4DgIVBUFlQU8X8YUkhIgKHpPrgSKkjG5UglgXONekEmgpmS6zcv6kQZaR2WLGwcFDTtii/Phn2a4GaJXcGQDG2xc4nZzhsfjZGW/ffocPX3oaVSl6OaOxU9XsbA0RdUGckLW9IHqYLxuMzlZe2mRf3JZCBOYdqZRWGR4sfgaAT+7+O+zLZ6DQFKce3DGz2YyzCUwmpzw6yuQsCRcu1gizhNhSF9vY+TZ/5Md+mDcP0mZ989arFJMZS6M48wnSUFsQZUAaOMtr3J3Hh7z0UsHbb8JkESiUYjKH3ctbzE4epf5yAVVq7CCilOBRCDgNRQXDQqIyP0PNI8aCCgqJosQhZKAJrJ0gIK4ZmYos5loKtIn4flyXfqPMmZ+MG0kHsIiszlUrAhgHQguUjnTLFKhtNuRViTQjN2qBrSDoiBIKrWV2bYDGdQQZU8l5PuPs0PNITXn0+IhmXtE2CR2ve4BIjjMupuKG9JFgU+AHIKRA5ZOmV8kAPeSFaq3oj4ToE9xFCoKItL5l2UEXHEKksRqUJFqFbabcn3c8fgzSKUwRiNm9IURJCIF+lbwyV4dXnddTmRdVFwTBB1ogLiPCp7Eky1TOBPjyy7f4kRsXsPIutgUdQhK/XUS+/Y0Ft2+m+5cO+mM4vDtj2jg6a7ncv8h4fJOHR2ku9nTKlEUjmM8DfjFgIF/AyGtc6r9IZdKTutV8hTfUHdQDmEhQRlD4mm6xzX/4xT8GwA9+QvBTf+sBn/niR6k+MYb/76v8du39HZi9fJtqKNi/XHF4L8+wKNEmdaqL0FciqfVqQ9tmpfQqcHoEs6llqDUmJEsRKZIK8mryrPS9nPdoJ3HBpXJW2AQ2Mjt3R6MhMwmTvk0gyoQRy7eF6ESy1EiFMmJI2ACRmUHeBWJsE1szRoLQiYKuTTpJ5yBPdCYx+GJ6/xgjOnq8dxQ5QHWFQIYCbS2FD3g8JSTWH3F9/36zd6TFJkIhUiB2HmNWaUldlJSmSirgUiOEwEtDzJ56Ap3S3x6ElYgcOUWSUCuASMZvxNCAbfFNQ7SJQ7qxe0m4MhEiMkRC/M6hTcpqrTJj+TSnUt+EkAINubomPRQQGhEjMWlcr416U1QmgZQ5C94T/JMZrfNSGQE2WLNVJuw92a/z3zpv2fTEz2WKt1dBZwCil6nGASilUGVF2evTSUFVD+gPx0xnj9cCs+PdHqNRR3u0pBUkyrZwNGeaOJ/T1ZkJ5CwmiFQylylzppQC79beb0KAEZGDw3s8ntxNuDJbAxKvW1x+1j40uNgAgbqsGfbGTN2CNpwwmyY81EFzyMj3KTqPlAs6O00rzhm8/u0DXrpxC4CPmjFdUWJ0idMJZ1N0DYfNQ9osoRJCwbIx9PojdsoapyAqyTJ2bFcDev30vKuhYOfCgF4JNi6oL4z5+KUbfHT/Br/29lvpGvfTfOEjH2Nr9wWWZcdd69kLA9r+CW3b0WVJkGraQ0TJeLfm6u7TfM/OFd7+56/zl//zv8Fnvvk/APA9P/QSy6OO3Rd2eeetxwSzTXSOGCWr4lYIAekDoTCgUsZJOgdCEYk4k0tSCEIMmBjRIeG9rLXIuBHHnc1btozGxoDwljv33uLB9ad56dLTDE3Bdp2yI+PBFo+ipzZJRuf+4X0eHD/gyvaIvl45Mxi2B1uoYQ8OjxLuMR/QnIsscvkxRg9B4aTkcROIlUK6AGrJWYJ78U9/+W/wic98jtmkxOCYdw5VGvojydZoTLtM998sNEeTU6SvUeoIIx6yvf0jVJXj08
8lbJ+6PuZHdvf4a3/r79IsBEWvpesEcyUSM12nN333+IzlV+DqZSgrgV9aimEfKy3tLGXMVCC5YvQEthRUEqazwJWqoJCC+cO0qZsAAwXReQYmYpcBHcAJgVofzAQqH7pDKgYQqkisJLEOSbYRUEWaYzYRWFOGCXBmswDImOyIZBuROlV3lAWRWX7rrJwEW0LoRaq+IZaSTlXUSmLjilEasjdnpFks6A6WHB0fYOqKx0dzRJ1fykCUkRAUQXtquTZbWQcCKza+lwneYQUQYt4XVxWbVOKvCsFWKYiVYekskWRSL3O1wlqPd56uCxycQtNA4QW1iGRXMITWyNhhVFaWygHgau21edH0uS+JiVWPKRChY77YlJi/+foR3/vD3+Lw7m3GHxIUSrOcGK7vP8u/9q/P+Dt3302fsQaF4uysZT4FMxTIIHhmuMvLBwmHtne1ZLtrCcvIbAq75VX23Ge4NPoY+zuX6XwSiZ3dCVQyIPsCtSgIA4FvnuK//1//WT7+w+m1Dopvsf8DfR5/ec6n/73fC3/hb/Pbtfd1YNb5Bcd3lgz7mt29tDCfPLapHCfBWkGwCm0cuvJIk56gbUAheXQXLu9vwIVVlMQYkjo0CfDoRbJhSMaqm7LVeoopmRSks3eY9I7oLNEqhBTrjEzCKgmi75BuiWjmxOUM1TXErInkXYeKDi9YG2qjNDGahFdbCwqmwCXKVWpZQkgbrc5kgxgDhZZ4LfFBJhC7iJicsl8FXefFDVeZsipCJTf4i1JJCl1S6BKl1FpxetVvq/7CtWBl+rl3qbwkFTp2+FwWwXpk8ARvCV2L6FLwFZXc4NxyPSdlpQIiCET02b9y894ykhEZYY2VigpCFOtAA0wWy/Q5gxVwQiRcm/DrU2B6qIEYPESH9yFpDIUnVbgdrCVH1qVN8eSYWAVp/j3X/lYtmagkQr5fecEAWmtMVaLLgiWBgKQebdMszgj5zkYjw+X9MffPGmIXsRG0FwQrOZ09Zr9eCat6jDZIoZBObcZL59bgfwkoJXh48Ii37r7OznCXvhqjlSFisPlo3YUlIbSE0BBFxJQVvaLC+xJn04n5/sN76FCxowsmzQLn3Lo2s3wMr91Ki9vFrceMzQWkKimkROuAdAVClMzXWeIG4ydYo+j3d+lixCAZ1iUxgApp3DdB4KYL2qJAWIvpOkbXdvn00x/htWmqZRxMJ7zxytuMr8PTTz3FFTWglUuQksePb9PmgKqqJU3rKeUOn33xwzz8xS9z9O4j9n/gBl978DIApz/b8v2//4/z7Vu/zNFxym4tQ0B4wVoPOqbnW+S6licSpEQKA9ZTrg4kWZrBi5jA6MHTLhsCS0JMAcR81tAfS2QUDAoNbsZicULbtvT7fbZ7KTDbH+/wpjaAp5CRxdmUg5MjFosZW7mEHJVkd7jNaDjmsbqzzvCucJFNswnMSl0iXUelBIvWo4Sk6STFdrr3yTLw69/4KrsX9tnr36COz3PSPaKLdwl2s8bJUlGoPq5rmZ/CT/zNn8NFzX/4J/8sH3s6WTL9g3/+9/jMj/4+/o3f+3v4G3/35/mH//i/43AZOXWwtfTrXasIMIs1nVgy2okEB+PBkNnsjJNl6q9GCwoB1scEnagkjYjsXR3yeD6hnaXxNdACt4RSR0ajgDoGewSImLPuaT0sVDq8GtL+UEASSjasVIOgjKlik9cBkTORgszpytM7OogL0qG7zBpi85Q5WhVagkyBhBxo9LCirAoKYegLjcvCscsm2UtFH5nOl/hFRA8G2MUZnUs4a4BgQ4KbBInWnsILvE77msqZKUEC9kcAlzwyVzi61ZokTKAuBFsDxU4tmajI7OQcxCOfAzsh8V7QZTd4ATjhaTMWFsDogJEgVdp7kCkQ7kRM5u25772L+cDv8UWSOvE5mdDPZfLDwzk3b9+ljVfZrd9lNLrMzs5VXGcZ9T1Xnkqvdf8daLpEdrFWMR462onlQzvPcb9JwdS08wxKmM0BC0
IXFGZIb6ukjY9YuHSonLanFCEy82DHLb0Onrne5ws/Nudelw6CRS/y+c9e5yv/6C26d36F30l7Xwdmp2dzKgPzaUdZpo/S25bEaaBt0iBDpSxXIWUSgwGEAaLALgKqs3QhY3PWQMM0SMuQJph3ufzgV4tW2kQhbbw+hpT1cQHpAsI6hMyqOHmwRgnSBWK7hPkZYXJMmJ+io2VlreR9h9BbyBgRWhBkTpMT05NaZd9M4vskkc3VhFIJj5YzGiF6NGCkxIkEbHciqRTLzdq2Br0qUhxYZJyZiZtrSqUplMZIgxYalTTjEURksOtJHR2JTek7IjJjsCS6a9eWH7jMtgwhlSmjRomEe1gB4aPwqfAaAkLG7GG5goWKNdttlZHYAEQjYdXhOTBb6QTFECGGFLQR06mP8/1HOq0Fhw8WYhKEPC/2uFLkXgVh783jnZMo2nwW8tdx8+/fqq3YRyuVayklqijopMd6mDYLvJDY4KlXorB9xfa24WQ8oDyYEWQkdJJA4NH8kCsX8rgPHlFWxGjXZVetFCqzACGxM9s2smgcD48e4VzHXk8w6NdY6xEyC5g6TXQe76e0cYKNc1T0FKLEZrbE/duPuH+nQfYUy/uB6aRBuZxtDvDat5JFysWLt/hCb4CQFaUySOnwukdt9qj0FIBJd4Z1x6i5Y9TfpTZ97HKBlR5lIjIjImUjaHSHWizRqk5f9wQff+oav377aQB+6d63GY0dl5YLulPLSTykqyzDoyEnB6cEm8Ah/adfAFnQH46489o3eff4Tb73f/spTo49bn8MwOx0wv5zIw6+/AqTAFot8F4QrUsOF6uxqFPU462jE4FBVUJQtNFh88DxQgEhwQNCSNgvZ4nC40NilFtr0dUWVbEkuDkD39DOz+jaJV1V0qtSKuLy7gXqfg9/ekZpBG7RcPj4MdPFkt1xKsNKVbLdG7Hb2+LtfKveAZmF5laOF8sOUebDzlIirCfqjFtdpGhkaCK2m3LnQURfXFCYq9hFxLmCbu5o2mxN0feI2CFEzW5vi7kM/KX/+me49eq3+JP/9v88jWen+Jdf+XX+Z3/w9/Af/Nn/KT/8g5/gz//H/wmvPTjE2W4td6aDpBksWZ6U6J2Arzzt8oxa72CzRVeooD8e0RyfIpRk0gS2x0PM3oCFPeLkQrquawRyHtjdHtJtDbhwxXH6ziEnh9DL2Z1Kw0CptCbFmMzUexpfFwzqQNBpnlnlCEVCt6AUwXpclwOclSmGAJHXdGkEtCkAlFXyilyTgiqB2SpR4z6i3ydqxcjU9DGER+n5TCfQdTGvS4ogHc1yRhsUovaEzFSr6pLFosUFSzVMODidYTlrLc3sMhBFwnlFB+hke7ZaBw0wKCRbQ8NgIDg5XeAsa+/k9f5Z9WmaRWZ6ChYuInXEpAR8eobSUVbJ81aSDswr/EhIxZV8X3l8ku6tsh6loRGbzGFfKnrhaZ56+rP0a0NvuMvp7E3ivEeY7WPb2+m1dArIaANNZ3B2ho01u2rMU2lq8PVbir09x8wqpBRMlyfca19H1BPKsGR2lrJvj3kLqW7wR57/3bS7R/zK9Cu46ZQT+1VOVSLs3H294PJLF3krvoJ4O7/Bb9Pe14FZkOAsSA0zmyZFT0mGA4nUKTjrWtCVTurfOfouSolUHtXB/G7kbAKX6sSLwUI+2tKqVG8XOY1qSFTdVsY8q6DoNKEMWCWwvkN5Q87/QnQIm9LuClIueXZKODlEzE+QyymdbdagdxEDvlggtMlljwLrEjVahA6vcyASaqRShCgRWRoweewF4upDqpztEwK0RliPDCFlxUineEgByQozUcSULSvz93qrfUUJhBYgAlJKbExq+sHaZNeSNyBpA/iGsOwguHRCjQbpLH4Nspd4RDLiDZ6oSrwWxOjWUYs4lwXDB6JI1kAxJL+2ENORLOJy6j0muj8J3LXKXqabyoKxQuVgMSDxaYYLCTkTEZ0jBpewFV4SRVgvjms2VnrTtUCviO
fsVs6BxFbMTRs3Qd3qMnfu6xWRIC2ocV2eRqcxnQZ0ZNFZmE1x3Yz55JDZ2UOClsQssyJ0zeiSZe/EcPY4gaW98YS2xJ955ovT9FquAKYUpk/XdUy7FoSikCLpZJAX1SKgXcFkOqWqDKP+CU3wyCgJLgUts64lth3LZs7SLXDOYpSmlAWjnLUZdCOOzxreudMhZy2lMwy3YWHTFFlO0v1/7dvv8PTONpd2CmaypIehVw/YHl+i6dKGrpzi2B0SXcOsOWUw0HgTklhzKIgxBXBR1GivqURHDDOsM7Q2Mry0xw9/7NMA3Dm9z83Dx7y0vQ3thDAt6cQxEzElzE95mDN+/XbGaGtA1JHbv/D3ufGFIcJbQvMG1SCNG61HHDaPuXl3SmdLvB2hwxFtqBG51KyCRYeCKFO2WhWSSIfzEaQjrPAOokNg8L5D6QopJbGIqM5jqpwRbEA3Ad07xYQ5qvXMT25xe3Gffr9PTyU64jOj6/R2LrKYLQGPb+bcfnSHe8slT+c+3R7uYPsDBjs7FGUP1V8gLMggQYYskQBWajoXkGFzJI0BSgcqO1R0oUCbEZaCo86yq05hcYJUu/T6NdMuRUpjD9ZKhFkyj0uUEHz8uSE//6W73D/7L9K4CfDarbss7h/z5dOX0UXFn/n3f4wf/6//Hu/cPaDJE7ILkf5c0AXHcunZUjWqiNh2wiwTgz/Rf4E/8qMv8H/6v/4M+mKHinB5sEfljrjv0hoPpFJkAYNByf6lCzzz/Ij2QyNe/fJbHGfQuNNQKZ/c5sYlg7JCxAptLlFs72DGGQ4Qb3Pv5pt4A+OxxM081ina4BHZ7SKKkBZaBUYrXOXoIsQmVXtW2KpQGcrnt1hoTa+3i+yVyPoyC3/EweldAOYIBAUytnjpUZ1EmIiKnuigy+XmrmkJKILyidAWAl6CL8BlhoPKuLfYRUyUSAGdC4SgKIoMBTKwtaXZGitsN2PZgPPJ5i56jS7yuiQlpqeYeI/xUNYKLTw9oMjVK9OLGJMOBDpjGkPscDGpCLS5H7ouBZFCgvFpjURD6QKrmH//mQEvvHQJqpZj+4hwtKBX7zGpHnH5wtPc2E+4w9PD2zjroYBT3/C8AoFhYk/Zv5DSahePb1G2cKw9agD+sIXTI2Kcc2dxTJtxdFKVXFSf4T/+Uz9G/G7Jy69/iL/84/+Iqfg2bpbWpN29H+bS7g2O93+Zf/LXf2fgf/nbX/JB+6B90D5oH7QP2gftg/ZB+x+jva8zZoU2aO/pQthouriAFJKqEggZcS344FKmKH9agaAoJD0daM8SScCFjJcS52rlIQHZz4O7VxIHMdearHdo2yJtibMa1ypU9CmTFdQTfoy0iSEVzo6RixmhW+Jck3BNpCg5th1CgyxKVFEjdYlUJjEzcxqlUy1CF6zKetGTSoFhI/kQ/arcmfwhhRDrzNjqwJE/0jprZhJUIpU1I2sxUaP0GlsmZDo1CxKXOVqL9xmLYlMeKHiZDNejh1jQuZaQ31HIBKDwKum3perNhkYOEGXOeEWfrZJ8zoCpnNXKZeQQiN7ivV9nyeJ5s93V62VWqljXGtn8+5yA4eq6dWmUJ0H9UbDxxjyXDVuLVJ5/T1Zj7dx1fOe2KouuTklJUiV9bdsldn7GooxMlhNOzw7ouknCZGSghouAidSDkrLWLLsu4eF8oGsjbT4yF0bl8oBPJVKjsW0E5BqbKIVCBJ9OlDHSOctsPiF6kIXMCBQ4WxwSO0fTNCzsksZ2KBuShMnK5kp7TCWJp6korIRAiQ2pRuVxf3LP8823b1GpHnUoYWtA6zzjC7uoPGmNrugODa2zhGiQUmOMwruI1usENiBxHrrW0+iANp6CAmscH3o2sUC/8OgT/Myv/wu+fucew8EuFQX4PqKU6DISJwk4PmtvMVLfw8nBLea9U07VgNdvfpXlsqHcSqf7/evP841vfYsHDwPeaGQ8SSQgkbKzQMIQyZiN4tMsjDHjT5Um+4DjHBAjko
CSEkR6TjZsRpaLARcsUkWUERhhkMHTLOY0XUO/SuKx48GQ3a0xU3GLIFIldXJ2wsODh8z3U/1uK46pTMXOYIu6rpmIBeQxLt6DixQ52+dDwv4okeAZq/UtkHC4RRlYLALSSD780jVeee0NtNzhwvgaAMvmId46BCmDEgKcnU65cnXAu+8k0khnQVRv8M7jE95+9z63br7D/jPb/K4vXoFfOOXtZerXxdTRLMGKVOZvm8CoLphOF2THH/7dP/PHufno/81zn5AMd4e88mXL9p7i7DRnelbIfkmCwwwV5eUhcn/AqHeVZ5Th5FcSi+5oCr6E3gi2dnpMm469p8ZUuxfQxZCilxixhpqweMTZcsJg8CHMcMFjbtJNSSbXpP6tysSAboJHFNAFKIuE0bLZqLPcr/GDknI4hsHzjLZ3KYct//JLb/DocJHXC4FQAY1kuggMS5F8cGUynPfdRoojCo8o0/5lJOhK0c79WurDdyBtxAfoQqAwCroEjVkNh60+XL84ZjiAewczGpfGZWUEw6JGlivtzhpFn+16hh1MYe6REnTFGu8tVUTJJCUjpUSrxGhWymcYymYQivWYy/u108QQMFk+YLEQOHnM7Xt3Ef0BplIMt0rODh6g7JDnn7kCwIMHR7x9Oqco4eRsynzhkjaZUmuS2tUrA968M2OvSBIrrx8d0l+0fHj7R/jkFz7Dz9/6BQDeeeuEf/sPPMt89ys0xzOe2lL8O//W93Pkf4ndfspeh/qIgh7f9bkd/sZPHvA7ae/rwKydW3S/wC46oloNeIHrAv2+ZNCXeAPzNlEt14uI9ZQDTV9GTCcooyRGhwtJqHOtWyOTH1fXRrJFXArKZMYsASEk1ezQLXEy4ogIb1BOgxabDTp4xHIJ8yliPsW3C7y1SUwyvOeDSWDZIbVFGo3WaROVWf1O6JqoFEJKgjDpFzyZ3ZcnYQ5gJJkFhkgemCQB2SI++XZrRiY5QBMiMfZIzEAt02IsYkASsmcea+FYACGSYrv0AmLIQZUjkLTb0otloHNU2SRYIBFZS+1chJNJBQlrtvpwLuHDVhtVcNlgNzyJ43oP0j5kRHMC10dUzASBVU0mXxP9So073ZPPbKe1lg4b8sfqka3KkeeFYr9TIPbe762kPojnBCzXrxdWEDma+RxOH2J0ZDI7Zj45geAo6j7ZhhXlFW3sqPoaU2kWXYcQKXC1FhbLFJiZXp/WWVxokl6QZB3UrgkdUqCjwFlLhQEpcKHFhgbbdHS5DHa6PMQ3HU3XsuharHf0dEVP99ZkkMZNmHaWtg1UHowxaC3RRSJ8ZJY70wm88sY9ru/u8kL/AnZpaGKLlSBFAviMti7iZcV0OUFGaJ0lSoWSkkhYK4QHFDEqnFc4F7De47xHCMvWIPX0Z597gZffeZdXDt/g4s7bmEqwEwTOT1iEikVeAJrZFh8K79K7/jy9F6/xt1+5y40bgtHWhRxFQTnc5Z13X2MpJb7zqNgSYz8fDtLrSJkkSYgBIQRGycR+FBIpA15uRkoIIZWSRML7aa3x3rPaOV3wOBugckgT0NIgo8d3CzrfJX9ZoF/3uDje5V0pcT75RS5nU24/uMPxC88AsOsdxhgub11gOBpz/8ERQkAXAyaSQ3AQImQoQDrGredXzDtl+pREIYlWY0Xk1AqeuvEC43LML3/1WwzriwAsFwEjD5Da0czTeltqQ7PsqHppQC+OLK++/ICfffFnGA++yNnjr3Ha3qY2gevPj3nr19LmtrVlcNIyP4PlHHaGmhACR8eeG9dTEP79P3qJP/Wn7qALxY3nKu6/VWLDIWcTv9YShLTEBOWwRUPcFjR1pOoN2O1rLjZJS+/RW6ccBuj3QNeR43tz2mrGU5da3JUbkH1ktdhmeKPl8N5rmGv71MWMrrvJ42UiVUHSSyvHhm5imXcJd7V00NikDCC38n0Nt7B6j+HONcbXn2d4QfHtb/4iL3/zdews4yqFISoHHuoiE9FWmmoxfTZIZU
Bdgq5TcFRI6JaegYFFxuShTe4USw9N4xzKCKKM9DK788b+Fld3aub2hMal4FIKAVIj5YChSQLAylQE6emKiIsdot/guyQkm6uW+JiYnEoqZMY2ClPQ6XT4WDWxYq4ntAq6AIlJ2Mvcp7fvnHDz5lssDDSt5P7kNV60H2bZnmEnkeDSiK5qz2CgOTMOh2C0tcOyaTCqwGRXjMpPGBaS25PA/ZsQTiqCUOhLF/i+P/oi1w/TGPybPzVlf7dgcGPK0ektKrnFjQtbdHwvsfxy6tL6iLNZx4c+Jtjd34e79/jt2vs6MIsemmWHkAK3El+1EaWgaRI4tS40faPo2oSRAbDB44PD+kQNlloRMmMsCjYWPBmL5CN5M2etnB9zoOG9x7oW3SVckY8R6QzCqUx7zpRmH4jNBLdYEJo5rrM4HxPx4NwOrlF4H9PrdwGpuoQ5MKx1crRcIrUiKpmU9aUionHItfzFStU+EnJGLAnXJqOnTVuB/nVMRrZGpAlrpHhCDV5EiInSQ/SO4CySQOhyZiy9afq8IS/e3hGDTuzRvGFEpSFKoimSREeMhBATXel82mxlyRRcekgxpFOb2ARe0ecsWQ5wwmrm8mRwFmPk/H+CJMaYheny8/HEuNItC5usWNws3qss2Qr8v3qHNfPyibvPv3MuYHxv/L26xRUBI2ms5cA1d8VickZ3cAcTJRM7o2tbdGkoSoFaaQ9FaJzFyw5VbDTbiBLrLE2bMVPOoWN2kJAJaJvERdVax8wYQ7toaZdd0myS0LoO2U5paWnbFOQ1y462bWnbli64xCIsUiAR8wI4tXMOZ0tmbUkpeigtkDJnK2G9eYgIk8PIa7dvsz++yM5Oger3UuCYBW2N0dS9AhsLYmtz1i8ilUxsz5wy094nUeAQcT7ivcNaS2EkTT6NXL10kc89/xHe/dJbfOWdu5TjIR9TmlMMPWGobMKGfP75izzz6Y+h5e/m6h/f5af/wh9l0khGI48q2/x8C+7evctMBEotIFTYGJLayTpmiYTsJCGlRKmkYRZDRMjkLgJ5fguFEIEQXM4KmrzopM6y1tK5iFARhCNKTXQNXbugc0t8nos9U7E/3kP3+3THp4kU1HU8PDzg/lnKTF3avsy2LrjQ32V7vIPU7xLV+QPKZqDG6JMgLjFluVeswdVVUiCFRBQC5z3VuCC6JS9ev8iyOODtLFXy3M4erhnSuCnOK7quw1qLs9BkrFA9UDycw8/+5Ff43t9f07v8cdqzX6Od3+fRAgZZl6ILDl3BFhrnHMu2YTAYIAUsm5T1/Cv/zX/KrXcV21cDZ8ctly+NCOEBrUvTX60/B2gtqesItWPhF1RFjRwEXvz0DQDu+zMencFyOKQbV5zeF7zxxgPs1S24MqIszgCYxJbHW0ccPH5Ee9rSzRYc3CPJPK9laQT9akC7aJjFjqHewnaLpCcmWmTIFE9n2O1fYnD5aaoLgncefo2f/Re/zGzuKdXKmWGRnG58WO9fQqdsVFSRDPmiqKEYCGQdCSrNO6HBt5u52Lae2hZ01mIzgWypIqanuLqXsKPPX9lnPIqcHTuihJmHYdhHtgsWWMbb6cWGvZiIVH7AoJaMXctiOUus/Jyhny+XuCYzUYVDqQItJLYsKJp2vQ6GdNZP1YWQGKiEJTEkTTqAdqlYzofcObrLWaPYG40BWNoTvNOU9Yo2G3DeYSqYzZdM5gpdl+kgN83VBR2IKhCcZD4JGNNgpeKkuMvdyZDjszRYd68f86j4JxzNv4vCPyKaY7b7z9AUgXlMGbpB4yi044UXBiDv8jtp78vAbLXp2gaUkfh2rQmPyYQAoWE2BTFwlD2Nd6z1sGxIju8FMFlEZrTMPcwi9Fe7JIAAmQMzmTflJmZxvnyJdC26FZTOUTpL3bQURmGMQopIyF4xwnuib3DLJb5xScU4JgXtJ9l9KxpKvgWXgqfQsdZEK4RDSIfQbARW0QRpWPsISEPrPNYHXPDYkIpQlgRAX4VAgf
TZUtkuZg9HgVcSn3M5HWBCpHUeI9yaRIAXeOtYA+hzVougkhZYCMSQ7mnF+kGmQmjp6kxqCKAs0qwmTSo1E1NQJkJMpV4fCL7Lv5/vPZC+H5NgalKt93i/Af8HNmSAuNocc5AnQliXpEOIBOcRWU5ErIPzJ7Njqz8+92GuPq0Zm+TvuVWpM26+d17gd5V5W13iAEPExSRCyjows8SDhyhXMaFlGR2u34PSUWRdvs7BdGnxywYr3JrJFDL4dzJNgH1Tz6Gu0UoSfaDL1/pAysIAEYmz0C1amEfmS0vz/2vvXn7bqKI4jn/n5bHjR9LQNiVNimhBlHaJ2LAFlqz5D5D4E1j1T+FPKCzKBlZIiAVIoEqFRmoozatJm1cdj+153HtZ3AktlVDVVT3V77O1F2Mde3zmnnvPyTNGYYRNnpbKJ5klz/2KXFEZCB1FVFKFBVU9R/JRdsTG7hPyyVm6SUxRGD8OK3HEhWNyenI18JuO764fsbSwznViqmKRbhIR1qtJVZFT5BlFnuOKAGsgcRZC4/sllafJug+IMSllVWGcpSgLbNyjrEse5xhwffUtLj94mz/vrPPr3Q0Wr6bM2yXycp+sXn3rdN/j+ORjkiBh76AicwHuEDrxIUuL/k9xNKrY3TliNIH50GJch9KNsVX4dBXPGH/cn4oojKhK3yPPGL9B2/7b78w/wPkGy4YyTyimlnxqqeo2JZNJwXgMSdsymVpSoByP2X90wNneEZ3AX1dQQS/qkLTaTEq/kSCsYP/ggK2H/mn/Uu+YpNsnyB3tVp8gSLCmpHSOyPr7JPjffmF8ecvWAzocoR/2dvpUGeLbA1UZoeuS2JjAOqpRSdo+j8v/BuDocAuikGluGY8r5rrQbkVE7QWK4rSEXHIuilk7Csi++5lPPrrC4ycZCQOOHu0z1/Wrb4tnOuzs7NBKSt/WpgjJs4r2HBSZPyD04/f3GReQDqEctpmMDhkeF1QGAuMnxIC/B7biFv12h7zMCU4sadsRBRVF3fSxu9AjKCccVpZ+mtK+cpGHj7f5bWOXM1dWWKhHSk3tkM18j4Mo5e69A0Z7htYUVi50GdbH0yfGMLYdTBCRRT2qZIHMHBPkOWXhiPFxTHtg2o6iGrN9f4Obt27x8D70WimTyv/OOi5kXJT+UFfdpNbUJyufGS1MO/UrTTbw45iYQBkEnEwi3NBf1+XOO3z56edsj3YZjQ0/rH3LX+MD5ro9Lp33febOzPUxQcbUGMoJLE/f5avPvuDN5R2+/vEbbm/771e0dI43FjqEzpeaiQM6cYA1FaP64S2OYkxofVIURoTWEbuQyNVTb8Jn/g/rxCwIHEkEYexLtqefLwwNu9sx3f4KtqpwbsjG5hqFfUJQdkmTesWsHxCE/nR+YQLKErJySqtIiObqMqyJyIEys8Shvz/n1nJn/Rfe/2OKqZc0L7U/5Ke1e9xMb/PBVYiSMf3e77iTCWXuy/d7qSOvHsDCMe6Cvxc/X9V5XuBe9I4ZtLW1xerq6qu+DBEREZGXsrm5ycrKyv++3sjEzFrL2toa165dY3Nzk8Fg8KovSV7CcDhkdXVVsWsoxa/ZFL9mU/yayznHyckJy8vL/2nU/rxGljLDMOTiRb+5czAY6MvZUIpdsyl+zab4NZvi10zz8/MvfI/6mImIiIjMCCVmIiIiIjOisYlZmqbcuHGD9HRMvTSGYtdsil+zKX7Npvi9/hq5+V9ERETkddTYFTMRERGR140SMxEREZEZocRMREREZEYoMRMRERGZEUrMRERERGaEEjMRERGRGaHETERERGRGKDETERERmRH/ABiCotvykpd8AAAAAElFTkSuQmCC\n"
},
"metadata": {}
}
],
"source": [
"img_orig = next(iter(dataset_builder.as_dataset('train')))['image']\n",
"plt.matshow(img_orig.numpy() / 255.);"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 454
},
"id": "VJGKHo1lQYAX",
"outputId": "a31862ab-2919-4e04-eeee-6d43c7babf5d"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"([224, 224, 3], 0.0, 1.0)"
]
},
"metadata": {},
"execution_count": 25
},
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAawAAAGkCAYAAABtmxHBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9abB1yVkeCj6ZuYY9nOk73/zVXKXSPIEE5WJoXyNZQ9MEGNzGajpCprkQxpI7jMARJqLN8Eu2bwR22FeG7g4b2dFhZHAE5hpsxQUJCQOSbIQwIKRCJdVc9c1n2sOaMrN/5PPmyr3PUalkC5UOd70RVft8a6+9Vq7MXJnv8LzPq7z3HoMMMsgggwzyNS76xW7AIIMMMsggg7wQGTasQQYZZJBBToUMG9YggwwyyCCnQoYNa5BBBhlkkFMhw4Y1yCCDDDLIqZBhwxpkkEEGGeRUyLBhDTLIIIMMcipk2LAGGWSQQQY5FTJsWIMMMsggg5wKGTasQQYZZJBBToWcyg3rfe97H+69916MRiM89NBD+C//5b+82E36mpef+qmfglJq5b+Xv/zl8fuqqvCud70LZ8+excbGBr7ne74H165dexFb/LUjv/Vbv4Xv+I7vwJUrV6CUwr//9/9+5XvvPX7iJ34Cly9fxng8xpvf/GZ87nOfWznn9u3b+L7v+z5sbW1hZ2cHP/ADP4DZbPZVfIqvDflSffk3/sbfODZP3/a2t62cM/RlkPe+9734hm/4BmxubuLChQv4ru/6LjzyyCMr57yQ9/rJJ5/Et3/7t2MymeDChQv4u3/376Lruq/mo7xgOXUb1r/9t/8W73nPe/CTP/mT+P3f/3287nWvw1vf+lZcv379xW7a17y86lWvwnPPPRf/++3f/u343Y/8yI/gP/yH/4Bf+qVfwkc/+lE8++yz+O7v/u4XsbVfOzKfz/G6170O73vf+078/h/9o3+Ef/pP/yl+7ud+Dp/4xCcwnU7x1re+FVVVxXO+7/u+D5/+9Kfx67/+6/jVX/1V/NZv/RZ+6Id+6Kv1CF8z8qX6EgDe9ra3rczTX/iFX1j5fujLIB/96Efxrne9Cx//+Mfx67/+62jbFm95y1swn8/jOV/qvbbW4tu//dvRNA1+93d/F//qX/0rvP/978dP/MRPvBiP9KXFnzL5xm/8Rv+ud70r/tta669cueLf+973voit+tqXn/zJn/Sve93rTvxuf3/f53nuf+mXfike+8xnPuMB+I997GNfpRaeDgHgf/mXfzn+2znnL1265P+X/+V/icf29/d9WZb+F37hF7z33v/Jn/yJB+D/63/9r/Gc//Sf/pNXSvlnnnnmq9b2rzVZ70vvvX/nO9/pv/M7v/OL/mboyy8u169f9wD8Rz/6Ue/9C3uv/+N//I9ea+2vXr0az/nZn/1Zv7W15eu6/uo+wAuQU2VhNU2DT37yk3jzm98cj2mt8eY3vxkf+9jHXsSWnQ753Oc+hytXruD+++/H933f9+HJJ58EAHzyk59E27Yr/fryl78cd99999CvX0Iee+wxXL16daXvtre38dBDD8W++9jHPoadnR288Y1vjOe8+c1vhtYan/jEJ77qbf5al4985CO4cOECXvayl+GHf/iHcevWrfjd0JdfXA4ODgAAu7u7AF7Ye/2xj30Mr3nNa3Dx4sV4zlvf+lYcHh7i05/+9Fex9S9MTtWGdfPmTVhrVzoXAC5evIirV6++SK06HfLQQw/h/e9/Pz74wQ/iZ3/2Z/HYY4/hW7/1W3F0dISrV6+iKArs7Oys/Gbo1y8t0j/PNyevXr2KCxcurHyfZRl2d3eH/l2Tt73tbfjX//pf40Mf+hD+4T/8h/joRz+Kt7/97bDWAhj68ouJcw5/5+/8HXzzN38zXv3qVwPAC3qvr169euLcle++1iR7sRswyFdH3v72t8e/X/va1+Khhx7CPffcg1/8xV/EeDx+EVs2yCC9/PW//tfj36
95zWvw2te+Fg888AA+8pGP4E1vetOL2LKvbXnXu96FP/7jP16JS/95lFNlYZ07dw7GmGMol2vXruHSpUsvUqtOp+zs7OClL30pHn30UVy6dAlN02B/f3/lnKFfv7RI/zzfnLx06dIxUFDXdbh9+/bQv19C7r//fpw7dw6PPvoogKEvT5J3v/vd+NVf/VX85m/+Ju688854/IW815cuXTpx7sp3X2tyqjasoijwhje8AR/60IfiMeccPvShD+Hhhx9+EVt2+mQ2m+Hzn/88Ll++jDe84Q3I83ylXx955BE8+eSTQ79+Cbnvvvtw6dKllb47PDzEJz7xidh3Dz/8MPb39/HJT34ynvPhD38Yzjk89NBDX/U2nyZ5+umncevWLVy+fBnA0JepeO/x7ne/G7/8y7+MD3/4w7jvvvtWvn8h7/XDDz+MP/qjP1pRAn79138dW1tbeOUrX/nVeZAvR15s1MeXKx/4wAd8WZb+/e9/v/+TP/kT/0M/9EN+Z2dnBeUyyHH50R/9Uf+Rj3zEP/bYY/53fud3/Jvf/GZ/7tw5f/36de+993/zb/5Nf/fdd/sPf/jD/vd+7/f8ww8/7B9++OEXudVfG3J0dOQ/9alP+U996lMegP+Zn/kZ/6lPfco/8cQT3nvv/8E/+Ad+Z2fH/8qv/Ir/wz/8Q/+d3/md/r777vPL5TJe421ve5v/uq/7Ov+JT3zC//Zv/7Z/8MEH/Tve8Y4X65FeNHm+vjw6OvI/9mM/5j/2sY/5xx57zP/Gb/yG//qv/3r/4IMP+qqq4jWGvgzywz/8w357e9t/5CMf8c8991z8b7FYxHO+1HvddZ1/9atf7d/ylrf4P/iDP/Af/OAH/fnz5/2P//iPvxiP9CXl1G1Y3nv/z/7ZP/N33323L4rCf+M3fqP/+Mc//mI36Wtevvd7v9dfvnzZF0Xh77jjDv+93/u9/tFHH43fL5dL/7f+1t/yZ86c8ZPJxP+Vv/JX/HPPPfcitvhrR37zN3/TAzj23zvf+U7vfYC2//2///f9xYsXfVmW/k1vepN/5JFHVq5x69Yt/453vMNvbGz4ra0t//3f//3+6OjoRXiaF1eery8Xi4V/y1ve4s+fP+/zPPf33HOP/8Ef/MFjyujQl0FO6kcA/ud//ufjOS/kvX788cf929/+dj8ej/25c+f8j/7oj/q2bb/KT/PCRHnv/VfbqhtkkEEGGWSQL1dOVQxrkEEGGWSQ/+PKsGENMsgggwxyKmTYsAYZZJBBBjkVMmxYgwwyyCCDnAoZNqxBBhlkkEFOhQwb1iCDDDLIIKdCTu2GVdc1fuqnfgp1Xb/YTTn1MvTlV06GvvzKydCXXzn589KXpzYP6/DwENvb2zg4OMDW1taL3ZxTLUNffuVk6MuvnAx9+ZWTPy99+aJaWEOp+0EGGWSQQV6ovGgb1lDqfpBBBhlkkC9HXrR6WD/zMz+DH/zBH8T3f//3AwB+7ud+Dr/2a7+Gf/kv/yX+3t/7e8/7W+ccnnnmGQDB1B3kf0ykD4e+/B+XoS+/cjL05VdOvpb70nuPo6MjXLlyBVo/vw31osSwmqbBZDLBv/t3/w7f9V3fFY+/853vxP7+Pn7lV35l5fy6rleChc8888zXJvX9IIMMMsgg/13y1FNPrdTzOkleFAvr+Urdf/aznz12/nvf+1789E//9LHj/79//UqocYt5XQEAjuYHAIC9ww7P3AznzI9GAABVKGye2QEAnD9zNwDgwtkL2NoIx4oiVN09nB/g2u3HAAC3D8LnfPEs6rYFANTLcN3lLEPXTAEAnQ9awfbGGWxNtgEA5Sjc9y3dNl7z//4oACA/WEKFSt/QJnxmqvfLmiJ8eg/wklCmP6bUWgeMC5hRaENDxSnXFvlfeQcAwL7mVeG303NQ+7Nw38efDMcOrgI2PBNm4Tt3+zrMYhHu+9JQW8fnHvbZoCyYKy8J1zUGeC6Uz/YqdIi5sIv2IDTCPPdE+O7wGrTtwrUX4b
y27dDw4T5zfS/0LwB5NOkLh35y7vLzHACpjTziiQcO2OMxDg1scp2xAkpeqGDfewfUa79p0B9r+NmhF/n7DW/6Xkz4rxEL4/luAfjwBF01D31gNByfs+nCjTtnkVE/XCyPwnPWc9SOfVSFcTisu9guaZMHMOXfbfKc61NCAzDJb+Q8l/wNhD6X/nXJp/zt185P/+6Sdt3m5yUAE/5d4ngfynXTYwusji2S68qzgM8jbd3k51ZmcOWuewEA2fQsAOBg0eHmPBR6vNGGufjbdwJ/9GC4Urkb3s2z2QS5Cu/n7DA81XPP7eHmtbB+bIfLYXohg9kKPeyz0ILR5jZ2NkOrJ0V4YtMWqOfhOvu3w8jVzgA+jFRn+f6MNYouXO/i9TBPikeeBh4PvX1PHe5x58RgvAzXrn14jiVsnBMy/g79mKR/y/ca/ThW/NzD6nsChPGS/jdneP79wLP3hRnXXgi9vhhNoNoytMeFN7FrMhgVRrdQYWR9d4i6Dc/XmfDdneMcF4swDj/29OvDdS++EdnuBg7rCvf8o/8XNjdldL+4vGguwS9HfvzHfxzvec974r8PDw9x1113YVTkaNHggIvs3iJMjP2ZQt3ytTV8dYyCZcc6Fc73egmVh473JnSmKVoU4zDkE8vXzACKb5Ljoa5TUNx1xioM4nRjgq1pmGjlOAx25kbYmISdqDhaIuMskRfYqPAf0A9GavIqn5zPN9jw5VFFBsU2dln4NH/xLyF7+avD9xt3hOt5B9WEyWRtEx/ELbm4cjPuTIZMhZuoWpYcDYzCRLKT8JKjcdB5Hv6ec8kqC+CBB8N5XEZ1U8HuB82hmofXZGIyaPZHPQ199OR8HvtDFql00+E+jk0AE/bVWPffaf54ziGvFdByVSwB5PxNkXSs9LUs7jUAPlFcUGv04ySLwP5Tn8c59u90wWXAeLhleD6dhbngfAuXhSv6LNylbluoJvRrMQrfdWqKipt5x3G1db9VZmx77YEptZfGh0nUoJ8r8jwqeSYRj9VNDnwunXyP5HgqNvneJm2RjXKcfMqGVSTtSTcg+Y0snhb92EZFBAqed5TznerfA3m2LW2wwf4dT3cAABvjDBslr7f3pwCAV85mWIRpjmvUj0d5hrzIeY/wg6PDGkdF2Bxc6+OnzC3rqLC2Cyy5aGcmtGZkMhRl+HuD78jYGijOhTnXIjsCchXuy1sg7y6hap8DABw8F8Z9xwETE+bERhcakDrK4jqh+uuk4ySKzRj9nF6wN+fwcS5s8POMASYEDs5ZA7K5WGJ5PjzLaDeMzrY36Mbh7rkLV17OcyyqMKKeCpqCiQMlirkyCls6/HZrtAMAcGfOArmHt2FWqmMa+XF5UTasL7fUfVmWKMvy2PFCKzRO4WAWNqCb+2ExmM8LtJbDx00KXmNZBw122YQBWNYbmPBvSy23cwsomkE5eyfLAJOqugCU8ci4muQ6DGxucmhO4oyLVGsKqNL0v5Xff4mxUX7t3zqdqOHHzmkoH9qdvf7rwuerH4Ybhw1GJgIOK2D/kMfCdDXWRssqWjfKACb0s78RLACMSnhROfnaqNbDOk4yPpBZLOEvhQ0yu/uBcPqzz2Bx9BQAoOLi7scFJtzgL54Nb8mt5QIz52MXyef6ZtEl38sLnGsgd+EJjFyjRNx1KgCe1o94v1PNU0ThZASSX/u8fu1JvOS13wAAsNy0ja1jYx23BGs9tA3PLBtXrjQatkE5av2jMZbc7BRfaCggk4WIn1oBt7hRySJkknbJRpROK/k77cvUkl3fnFTyfXp+3OSSTmuT79fvK/dMxSTXSTdUqkCJNXs8QqF9/xvZwDfGBorWrOWnPnMB27gAAKjmwe6+d7nArbA84KgNP16OAc2dyPAlz3KFLAtj0rVUCBYWXjQkflTLBkYHS0zJDp7n0Jxj59nSrFawjMcsuHEd+Q4NB29xdsznKaI1cuhDQ289Z3EXV/qc64l3FhXnjvRjnfzdKqBk18k6sYHeIq
05ufdc39dTdn65C7Rcdqs7wv1u707RboZ2j6nMj5RCx+crqCir0qGuQyvmVctnUnBxQMNvq07hPFUaR/XEFQWUncO71I5/fnlRUIJDqftBBhlkkEG+XHnRXILvec978M53vhNvfOMb8Y3f+I34J//kn2A+n0fU4AsSbdA0HosqbOc0tNBZHVRShHgFADhn0TTBaljWYYev2yVa+peVpnbuLDL+Vqwkbfp9XWJKWWYAFbrP0MzPtImmkaYV1IwLYFzEtqRab7ymHEusr2MWgO41XEO121kPRfeaej03+gsX4KkF6SpoLq6qgSpYn6oTd4cD6Ep1mTxfFp9J3CKqLOEnob8U76tsC9+0K23H7VtQl6+EYxuM7TmNiq7FJV0bme8flN2MrdEIC7rFRNdKjNGoEVrEUBEETKQsUPLMmfSPXo3LVOw4uY7B8XFI4zep5bFuYc1ne2gZkypGwamSLWexQY79q5SGbYKFZTlptPfR7WF0b2do6QgbrlGYDE0XWpuMzLHYw/iE9in01rvMp9SilOsp1R9cd8fG78HpvHa95KerGq+MzQkwrtQCkyd36K1FeSaNk6211IUOAKMyh6F3obkdXGp5PoIu6e3Iwxy8vBjjQS4MV+mLvH7GIhqzfICyVMiKcLDme2NrwNX0IPC81jUo8/D9kk9i/Ah3+PAevnzBWXawgLPhvSm5JtwuFJ5ix17bCO1sLihkPpg3ugnudZvdwuhqeL9GjGuNrEelw30X7N8m6TfpO6B3r24BuDgOPVxHC7HFAc0uxcDV8gqgzoXzZruh36rzU4y5Lmgl3oAcRRfa09IL1RqPEX3tB/Pwrre+heM7l/G8WaNxznMdcQX7VAGLCroWJ/GXlhdtw/re7/1e3LhxAz/xEz+Bq1ev4vWvfz0++MEPHgNiDDLIIIMMMgjwIoMu3v3ud+Pd7373f/fvrx7expGbY9kKCisc98iiNdK5YEF1XQdL7bfeCDt609XoGNPJc7GmgJyxiczSgtImavR5DCQZQAdNQQKpRhsomgCilSzLHH6LvtvEcorpBh5Y1/fTRIMYI/DhlunRvMzRNbQeGHiGH8NSo1M1Iw2LOVAR/cf0ADs7hKbvH3zeAEBhnIoNVBvbEXQhaq5qZxH9B1q3rj6CfjYkfbuSoIpmgVZ83aJ9w6CjtdfQ0Z0ji5p2ilKTZ7fJZ7SYRdv3CiOJo/G8UQe07F/vgIa/kTBk/5SrgAKJF6X3c2L8yHyyLa5eDUjL+y/cw0YYdBxvzdiU0RaeT2Vi4FLD8TxHBJn2GkXB/mpDCzOdwdMe1KmlQ0lRYqJNrwAnEksoPsfaec4ft85W5HmSXU6KifnkN+75fx4RhE3y++d5HQD0YyvWm7MODefvgjGgos1Qnj8PAMgZX5rYER6Yhx57ktbPTHnk4lVgUGwyGSEf0cKac82oFXQVvle0IrxScF1ojROrCw0u0btz7kZoy/YBYG+Eg/kkzIPtiYLm3zoLa8JyksOeD6jDpQ3XvbezmB7uh3YROlzZCtucjEdcIGbore4t34OLpuyjbQXsEiLrCTKZFy1yxqvaEHJGc0FhuRksvv2zRD6Oyn5B5fpWFCUM18mG8cC2cygJ+Mp16OejxRKWM2TUhWOHncY5AtFUET6td0BTwTUvnN/wVKAEv5g8u3eEpa9QN+KOCcedV+jYYS3hYk3n4mLXcJOyru0XBi3IJAXHQcna0D066SYBWmhfwnPDypxAYfq2yQvYGY1uI0yCtLNlUdEecH5thfE9uiZ1x9AjgGbC9ikHww1LABTmfBb9Zt5HDD5QhxfJzYjqm92AbsIxr8MEcmYEZfjMRKz5soybr6Wbyi3nAWgAxEWjcB3ck18ITWa/+LaD54OMNDdFrdERDNIIEtECIwaXG3s8ANsmn9FtJ+4/r5Cb8I+JYExarGC7iSTuXYI+cY3Jtf2q65E/7TeqpD2PPf4ZAMB9BJe0SvVuOG7Q3rcwvEvNzT3Pe+BQBDAoICMoY8
mD2mQo2OcNJ226SaQoRplTqWtuPTCd/jY9lroRsfa3P+E8EYvnB3mkiLWTNkVRAlp/XHFw6DfnFNAR2x8VBw9HFUSC/VX1HHJBwXIeu9ri7CLc5cHD8ONnO6Dmy2TK0PejkcHGhJDtA7rKOw90bKGVjcvEl1OxobmdYed2GJWLT4U2TW94NAQj5JuhLeW0QH2BKgZ/u395F7fG4b6j3R0AwJXRETY83026nx0cRlwnBAdy3fcb1g56lKYALc6VGhsMR5hJcF9v6xaji+FXh1fCcyx2Rqg2w69lc1XIox/UcZTyfISCGpRS/YZULcPfvBUOZh0sX5yW72bmPba4+XpBG0NBewvtv8ZBF4MMMsgggwzy5cqptrBu7FvUrkZbU+Ph4yil4WkNtOIGbBOrh4ACa3u/SNS4MwPnxRVIbUP1VkZG8ximhJK/xW2jfA+6oCbitEYnLkGXaJr8w/vjx1bg7+LmNIChauoE1l61yEaEiC6ChdWiRk6t0BE0opyJfjE1J5S9qqOa7yQPAipx9fT6t+oEsEFtv67hlvSBtNRoW4928RS7ZpP9oZATmGAJI66tjWNzuGj5PMA4C9rXgm0Rqwrotfk0yBxBAeiRLKJZzn2f3zPXvRtRXFE5jucqtcm10yTbdSCGA7C3F1yfSybljfIShkm/ltqisi6mEFgvGrmBk0Q+AjG6rosADMU5Y1SGgses7W2jdTBFi97NmVpaJ7njBKySQtPXIewnSYKRWekXuZ5YQ14l1pI/7oL0CVIjTSY+yaI7SeK8FPCA7eCYPzir5dsZxgd03TG/clI3MEQp3EVL69zS4+oZAQUQ6DLKsEmw0H4ZwFn1ooUXEI2MIXIoJh0L2GrSdSholXVfCCPy9KGCkvlxk4nDZoHNq0zS5/uoc4PqSpi5FzkpN/MGhi7jJwhkGHkFJ/mNHJAMq/mE4h6WPKzNSRktq2wruB3PjjJMzod0ov2AkUI9niCjReotXZYTg4w5ap7rTpaXKOj+k3HrOoeaHqkRfZIKHRwtrI4TxLoWBZEffpf3UAbt3j7aRt7MLy2DhTXIIIMMMsipkFNtYd0+bKAKoOcx4heqA6ipdwxgNI3vkzFbUTdN7y8XFgmte8vKS6yoiMnBmpqvs1nvv4+R8d6jLxZWa4DmTLCwJom2n6qUcm9RvlMtYkUhzwkusAJ5LmEJcFD7t8Jn28IzC93TkvSug2IMyzDm5JyDNzxPLEqn0YkWR9yvy0YANSAlmqzt4AlFlf7r2gqeVsaScUO1cQbFNGiPlQ33bZsWFf3yAjc3ugcmSOC4OSGmlFpYUes3ffePqKEWNg4/5hmiep5aWOusIh1WwQzynV07bwHAkJ/r+o1nAQBXsgKOz6QysYwAoeYRBoCmWcJLPFAsZ++QU7vNC8ZQqqqfFBSFOPzxORRWmSSA1QTdlOYqWvTJ86yDWr5Y8vRJVqaMjUTlUhCHS3HvJ11vLUaVtuukmFj6d/RcQKFlgwQU7eGBhrFcMrs0cMhm4e+tWejnS8sWt/kiKk64fFRgSmqgyTRYWF19GNeWTIcn1ZlGJiQ6/JxahYzggj3mVsytj1b+jB3cWY/sWjjvzILUZsaiUQEZfTcHc1o1UIy33eA1cu1xliiPnL1VKmCXHTdRPZPLJuPsRV5AS9L/dgCjTLMt7DDDuC1IDpAraAK0KlLUbSiFLAtrnlihWuegIwSaI2LbDkumA5SkXsp1hznBboo4gAYdtPTIKHy6roVqm+jBeSFyqjespikA36IkQMAwoxwuIdXhytbZ3o3RSSTd+/gCmMSfIS+Hkax1XcLE6dc7XxQH0sfL9W/pykK4GSaBLhDfdGFfQCSiAST26NQqjRAQNjp5QQ1pYlw2ink96nrIRVHzFnZM5wBfXtVVsMxBAxFVFh1Mzs1ONnynBPgItc3AaJ7DE1nouFArm7LYcWPzHZgWA8sltSwLgIuwIWLOwmPJ/KQ0kG6c8JERXemPO6pa9PRA8rUu+36LOT
W2B0sUAJbsI8uVMqU0ks+TXIJIjqWLusyAzzzy3wAAd3z9X4zAGcPVrGvbwCYCwHMbsN5D5cJ6IpqIji6XOB5N3bNWJ643ua9s6jbZ1AVlmfzkRFkBP/DzJIome8J56Zxexwmle9SKm/uEe8ecquQCMgdd0vupsiBvn5H3ubNo1miLHFKlg0oWNHICk6b7obfumjl8noCIJbnSivEE4yldfJvBqVYvlhFtmlGBUyaLmFbDTUwZh6agIrsRrpft9c8tY+MRFB4AAF/HySdv4RwHcUdy9J5e4ClhUGOnXfd934iCNk2evUgALIWsCSaHFtabcovfbWKHimVJurSDrAaYXwXDHuzGIa8UgIkPoKPCXnBTdEWOMg8PUHLjGhUK3REZdcR9nSk4L0hsrhnzBfRyCd0OLsFBBhlkkEH+nMmptrBap2Gsi6S2hrkVPlExRZHt2mBlAUAnzAe2hSOPlZfcJevgOwkeikuwRKYInIhWlIKniS6WllIejm3x0RJwWJxhGDRVW5PEGnEPnaStphpFRH8y3d+5NlomOAjuKV9XMMzJ8vHhO6hoBvIeOoNXq6BoqxX8JtWp+4MrQd9WsRN1Q8b1enHses4qtAKnl+exHTxzjDAJml69uIaWv03BJmJQlOxz5R3WbrFCRhuHwfceYcnRyZLfjBywh9X7efRQd7lOCrBIIezrlkJQXsOPb9wObp2ldyiY5yLAmNp3GEnOGFmsndOwK1cnQMHJvKWnwJiYm7Vo6XpFb/kLZ1yHVXcpELj5/Nqzpc+RAi3Wj1msslCs/za1eNzaMa16a6BTq/0KBE173WYeI/0NwUr+5LZG61KOWhvThNJ2xrbyPZx5xHetPAx9f2HeYuswzOV9utQ3R5uY0MLaZBpKfVRgQbZxST0wRY4RLSshoV4YjeW5ML83Lu+HBuzbGIIQSdMBZLz8ocPZPwrufHrjcGvPYSbnyVgCkNK2kmc1Rj9eqeUfmVPyAl6SNxWt93qJkj7U8Uboj1sbHRSBUFkWQBoOHo4vlJZVSGkoT15MLhNlXqIogzVV0E1VFn3eas3wgG9zzDIhyZWHr6CqGqpNIVbPL4OFNcgggwwyyKmQU21hdc7DdzqCHXIpH6J7fTTGNzqgo+rZkBmj7ho0ZBdoa4IqfB4hmWBCsFE5DBnGhcldeZ1Egmmdmb4tnqpjaz0WW+Ha1hhk1LFS62E9hpAnFmIq8ixWeP5UB08Aht/bD8fmR8D0TDhGDjXVtlCSlCpkiAAcNfo+XqJgJauT9M6+qgFmoutGoOwuWm/Cmdh2Do3A5BmANvtHcOOgrQqzeWVdz/St+u6LbZB2nvD8Dn05hcgpqBDVTJcEZeTPUvVxMcc2u0QVl/NSQEequMdIozCdeKDD6nM+c/NZXCG4JGONLKsz1JolXYQrLcvQ0LQrWH2gqZveSmEU32RF9PcL/F1Y+VfahL6fBHhQyBdrchJ0fZ2tIrW6pE0pQ75YpY3vOf1MYgGIRdEmVpIsMBPEsE0EityRfO8Z7PBd8lrxry6xLqWkR1+EZDVmF58pjq/HUsAqrFm1ebPC7oXQn48yUbYrOpScozsbBAptHEEzDpnl4R0u8ixCwHNaXTYvMbsUiPnO3kkY+mNHyJar7UtTKWJM1AP2ILTaMu61VCqy1fhkkNJ4KxDSOLJkMkQ2kHhtBU/AlGYpofbaTYwzglCmjN+ddbLURXIACx8Tho2QCagMRth9JGVm5DEahdlX5uH8yajAZBwadptrqW8a7Jecw3zXddOh8Rbdl5E4fKo3LNt2sLmCot3sDR+8biJjgo0bR/87RzO165rItlAtWKfImD56T/M3UwVyblgNAQ/hzZD8JboEkypO8jq1CmiYQe9HBSxJXtdfrPB7Ppfr3VudjLHrF3hpgtEOaFmHyYcXyjcLWG4wgtrLXAPPO8ZuUDq2VZhAlc6Q8e32V3mNTEOzoFBXhw0ra5roGpW+bNsWtTBhEPXj53PkBIDU7OdW+egUk2
c0QKRwikwXJyy6qUtlHQgA9H1ZKCAXRcUARkAXtv9tCiAAVkEXJ6LmEgSngEylLMJnHv8Mrrw2kA8vm+AmzLMJrCa4RIiVbQsr7lfHOeE9KoJZJhPWUMsK1AKSYUu11tBubQyTv2UzaZFsHHh+Wd+wVoAT6A924t7z0gcKtQAEXA8ESMEU0q9pIce5Wj227XXc/HPexAIg2Cy6mBQUJrx6yTpiy6ZN0IGrzxPaGCSHj+0XXtrt2y3OE4XnN0NrFhOHXb4Hk3F4lybTCUCSZ08FY2QMIKg4HjPFBDMCZtylg3DfzaNIQyFISmGlSNusVN8fMuLa+9jXUQdLNiRhtFhHu/bMbezLtgEO9gEA9Sz01myvhiLc9MJG+PWN/QWu7lDp5xoaPogYppKgVYZMk7pJkJK5RcHNa8R+KQwwGoXfjtn5TQvc4FoFurs776HqFqobXIKDDDLIIIP8OZNTbWEt6iVGSqGinil0+NY6OKnEy3Od77UC8R25rkG9DJZEJfjL0vXmLsTCyqGF4FasKd9GLUlgpJ238MI5Ri3HeYu5FJ+cToA9WlgJXDm6LxIfh/AGRnYDB+GgjPBxnzk40QCZIGGWM2gv1YUJR++qxDcqN0asmCyFAx1U7CSplmo6D1sH3U+zr2y3hGKuWsVqo4u6QiXkrQJqcS1AK6+hlbbs2mOavQWwEEstavEJM0ly3rGS61mv+WfJpyBxGyhYlVrAq5p4G887nvelcbzcvE7Ok2M39m5iTtfsiHOnUDoS4jr2c7NcxnHqYpA7QysVn4kKytBzE7oE/i4i7c+Q5CCxMTVWXXzAquWUSnod+fe6Bas8jrFk1CBxKRDzypR3kXbP+dW5zBMTxgy+G0bBu9X3tITCvbt0r10IwB/VVujIFziWVJLDQ6A5BIAI4jkpsSu1nOc8NrntcWY/XG9jl8UT6wrdKMz5kp/T6Qasogcjli3xAN9nKfg4KsdwBDVkZ7fD+WduwF+ntcL7ZwrHeBIz30O/IxAI/XsgotCPg8zzMVbBLXExl2NtjXr/NgBgFv3wO5FlY2cvXPDcgcf1ET1SdBca6+IkEAahLMthyAcouA7oDgXBJzn9w1mmkNGKm7JRLgeeZS6mSl5y4yzM13oBx0EGGWSQQQb5cuVUW1jzpgWg47YrGdkdHFrhC5QYkFLQEsgU66ursWDsCkpUEAMtcG8GvgtdIOMxgXNrpWPw0FDd6JxDR3aDjoGmERxm9O22ZzZhngoQVh9ZMfyxuIH3fcxNgqpeAbZZPVE79Ox/pKpXB3tAZKEgOMO5xMCSynUqWiZRk1IKKmYJUpO1S9hIKcCk46aJpcWXtLCqeglLeyUTtvauxeG1wFt2eyFcez6ZdIwtJXGGdesrFYveIopKtQYyyREXvAgUFuykJXxMSjWJNoo1qyFluhBtP8FzrLR43XiwtsGTtwLo+AHGoZaHB7C5ACtCH3W2i+ASsbB0kUW1WywsZVSEuHuCLlrvjsXYNFYtdSBYiuXaeSdZWAk6f4UBX54ptTbX+yWbZvj6aUhE/dMbt+IVhX8yhdNHGLf2kamDJOG4ryhwjX2zpGfiyvmzuPcV3xCueC6Ah/T8CH4WEAmO3HpTVWK8YIFMidn649ZvOoZiYW0dAmNSSOycC3HZa9sLLGg5bWZke5hMsOQ4jRh38ZmFJtiiYOmOIjdQAns/E9qMcznwiMQcw41N0hb5LJK25slclDEUgIpPQBfCGThBH//S6JOr5cVuqgozej32mTqxM5pG4NhGML5w+bbBM0x4vjmlpwMOI0lAjuCsrE/hkSKxZoRCyAEYc8yzDIbWVs6HmoyAx+b7AADHWLiaH8Et5nCySL8AOdUblm2AIziMOE0X7DDtHMTKVKwxYxDKcQCIL1ZT11jkJChVZGfIs5jFLRPEKB2RQZokrsYgEuJGgIVr0TjZvOiW62pUJuQ2dOc340QVN0aadyKSor9ijpZCXE0iqtD1C5WgBXFwC75i9V6p8tsld5
CNV4XSHGn7baYjU4OIg0bmQt84ugG75gCWm1LdEWXpOhg2ULxX7XIZ3V1S4dWghGLftAKMwWqQXzqh5KQXZhKPZHFVyelSAYLfHTqPvdj+FBTTo85iXaXkWaUNduW3QdK6VLLx+eSHjzzzJADgrpe8KvRLWyEvmLvHqxhjYOWZOzrz9Ci6/XzMCQzVmsP3RKJ19fOS1MqcqJP2RzdmsmNFpSjRCKKykLis0nvJpl6SzeE7X/JyTARdeyuADG57izr5UcpiAgAPQOEKJ8ZGGW6+PHcH7r28AwD4k0/+YTjv9a+HfyCUbTGGbCuH1+E5Lz2hd2MA24vgErS3wu7TOn/MpZmjH0PJaZp3QH4z/D3hoj07s8B0HN7TCRfZUV5EJOCYyocrHVQu9aGYk5RpZCyZYadBYdHbJayqV/ogBU6MOGDb2vQbFce8zA3oZcMBKaZudR1YvS7O5zwZpQ4rHDzheZsOR1Q2jziPSrSwVKqLo9Bblw4NdjkdrzGc0CnfkzFrIQNX8d5C2KyRQwt9FZVdbTRyrsVdSbBMp/CsIqxEKMsO9+FmR5F8+4XI4BIcZJBBBhnkVMiptrC0A3wNNLScdC78VQ6WLgbh91MZkJcsJSD5Wo0SXldkRDm01sLZoC2NCuYp6AJGyouISye38FIKQ6j0bQXrV5ERWlm09OV1FwSQmoApdM/AEbUHhWMWgE5cAnJephABJFLCIjt8OmrqRtx/TRVLeogGbZ2DEjcHn8N0Sygw013ybOBgCWEQpV+pAjUBHdWcep/RyDRLnXTCV+jh2F9SLmNy5hzaw6CVLxdB4/JYZUQAgO3pBNu7Iej+uadC2ZIaCZdgtDZVdLl0LF9wGz10PjMZvFT3ld/641ZICmEXLdgiIZrlifnab0T29gKX4173CgDAubJAxg6raZ0XSqOLPIqE8TeIACFDSLHWJn7vaI06uGPWgwdWyoWk3wF97ljm++dIyYPXXdGdwjErroHCJV7n4TsvhOtdu44/pVXzjbuXAQAfvvVM774CwKHAyzhQl+Cx5Ly8yimzWHwer778F8J5994b+ujBB+G3wjX9LRLELmr4A2rnAkzRGlNWwm5G4YJVNY/sFxEofYJ1eQRgl57M7TAVUVctav54bkJvjXMV875ysluM8wzIw3kVrT5XZND8fj4N7/jZnU3oIlzcJgzF23wf7tgNLtVLk020nNQVz9vY2kRGzsk5SasnN2/g6Xnogyxhikkd+D3TRWhf7YFDAZ/wc6FS93v4Y/fI4yLH5HOcbx59jmUerapRXFOU5HP6LpZZie9P5jFlyKBrhE/T44YhAGsZLOPMttDGfFlW02BhDTLIIIMMcirkVFtYlpaIaA95IxoqItOEJCZmRqPQq6qWswYVVcE869XvjBqKHFOwfRCRfuvOBi5CAPCiBdsusF2gZ3/3yqGltTXfGWNHVASBACeBEpto8euJsdb3//BZ/10sMEfrspsfICO7uhRcdPUycoW5WNtBxWCTimXYNaR2gnQVHGLRwb6AoEVDaGxNEzWfmNg3XvD3XQUJJuYsNDmebqA6OFh5tvRvmZCXLl/CmZe/GgBw9UbQ5uuqin3Ux5s8OFw4oubZAj2gwDm4NGAj/dY/3hdtSwrzTr9fZyZx6OHnj954AgAw3ToDN+c4sF+qJdByHHwMZiZtU0J26dEyJtnxmWySwJtC+9eh+A7AIf/eFO06ec6UL3Ed/l673roUiPWdyuObWUZ+cTUEfJ5bNqg5QW7OwrEHofFZsRqhsMt2X+Z5Bw5oeUfR7BU8Pv0HnwQAvPGu+9kHBRxx2+ZasLDqx57Ack7QRYxH9vNyxKA/ujqWtmmSebJuNbYAljQ5J/vhszxaYnE+9OIWGdyVKpETuj4hpOHlbhzLDn2OTCYuz6AZ3+smwXKyZ8+g2Ho6tJ8Aj7EyuGM3WKkXLoeSItbkGFvegwCmcmsDdhQstYJJ+DlGWLafD+0SUBVWC3fGrB1J0LZ92kOcJ4G8kv
8In5PaY4f9IVUirHJwwuojVqbJI2u92EVeZbE/DGNYRhmMS/YHMQS2sVhkYY16dh4CiPeaDBY+gZ19aTndGxaC2SqsCxFskqEnYOVnVgRTFegRep3L0NXh2HgqAcaeiFUCn9r5aB4b0o/kPoMVUlPfL+h55KqRmlUWDadLdX4Limn8GpLHcgLowvcvWUrhYpLv43k80Qi7xNEtgIgq74TwsobiSh9fXpNFEIIgfnxWxA1LGCyM97GEiWxibdOgkkCpAFnMCDnzVwQhaZs5DL8vuajYqon5WvF5k79LvhzbFy+imIZN7txWcFPeqKq+fITsuwAso+lCctuqdDF2oeIygEwYL7w/toDLtdLPVKLLMvlNmkcmT/TI1bCoPHDmG6FZysUzcN96xHIVQiPWdj1yUO5iXYea2f+d79kt1tucAizSsPUhG7bDPqqS71LKpbjxqt4JKsqf9O83Z8BsIcgxH+8rbuzbS7o4M49z/M0T3uOlnFPPUGEpYdDJ8/FuygM586seezJs9Hf+t09j/MavC+2iorL33DO4ObvJZ+f7ozXGmQAJwo3zvMBYSsiwgc73zynzt0E/XhOyzE6WDRYEKHiWy8jKTXAfwgY78cGsxDZd31MfrvK4AmrmYS3G4W7t7hYmZxhGuBlufOXiReycvxTuQTYNNZ2gvh42fcuXOYePSo4AmKZb29g52g73oFs5T9YEhyQ3Ssbdp/XSVOy3zIo7j+CSDjjDF6sUQAQaWGHyEeo53QEEYEUUq+8VQkEQZlmGsgzgEydlY7oFlhyTJ+qwPt159j54c1yheD4ZXIKDDDLIIIOcCjnVFpZWCGraqoULrxHdRALxVMpHFSujRqltEeHvOnK2eRia+mIdaPS5Vlq0Dblmcsw4QGdyjO4R1SGjhXX78gaaTVbxPGIugjoeOE+15VSbj6SW4m5M/B1e0slmMyixsMwmL9hEIEYsOqk0HF2GKpY11jFnSZ1gZlgWWqubus8toUU0KicY0e23ELJXpSPPmOTItfMZrBD2JiAIeTbJ6UBRwB8GDXuU9LlIVCZbYI9WslgWne/56Hzab/w8qXzISS4/k/wtnw7HNUKbXFvy/x6fH+BleegPRddxl2WRNUDcya13sUyN1jJGKqYDpBaRtDsloT1JO7XJ90BwgaVAErmePHPDJL0cQMeBfzP73LUaFd8DsdQqn3gpqKUfdcCYvXAHPPbEvcmbLGGPkcAqBTR8h4SE+pk//gReMgn9dnAj5LZdm+3jxlqF6wJAy3dyiywU47zEmMlISxvg7wt/3OGUgmlygg12lh1mtLDE0nG2Q04X2DaJc+9vCxRbdIHR7eVHBk9PQxs66bedbeQXg2dg93a4xpkzl4FJOOYV88+e/AKuXQ8IkJovxN3uIjZpuXZ0SXatjyVOSrrcM9+uWFiyLojzo3NJ7lz0X9v4cvsE2cWVIoZNOt3b9LFMkfMRyCNAIWdtdM2K56fISoy8vOPCTJ2hcsGcfeoouHrN+LVhDr9wj+BgYQ0yyCCDDHI65FRbWIUJ2mZUFJJk0vVgulM+OnmNZkKiLaAkfiClJ+Ai51w85lU8z1MbddZFC0uJOQcHrVbvrKyDIlx5MS5RnQsa1pSlP5AkBKe8YOvB/rQooYhz/W+I+0DWAZ5QWM/ArUpRBlr81wpKElUFrq51NFNVTJ72sTSJpQ+6s22MZ0h8q8zH0JL9Tk1Q6yIGxKVM/NHyMPLQ2WS8ooUlDNjWA9S6GfMPimGvFAIAmkbhqsDakw7q3fn9bIglZ9I+POHvkwAYKcBC4mhp5MmtHXv69uM4dz4kwG4uaHUXYzQ6jI2hJdY5G8ekoeVpjEEtjCnybEgY0NmYlYKWSfskzjOT59b9OyKglRbHLbbaA99KjXhX4q5JikDHibJULgbK4/xUKrbrDgDMx42sDC4BjaR9Kq9OJmzsqHDtsT8Obd0P4z+rF6EQI1a5HQ3n5YT9psZbKBn4z8ixqRLuSiS/FQurWIZnOl
85PMUGdvQ82LyLVs2GJOHvN1Cz8Pc5XvilhUIdXmvc4GRtdnagL58HAJwh6AKjERRTXG4/9hgA4LH9G1hyUKQ45SE89EaIASmWNVl2DtU8WCiOC51TfRpIhh7oJSw5Dfq5auIE8fFEFV0NOs6PMhcmHAfPGJ0nOYDzLRzTdsQK7do2ekyELEFrgwnfe83x8E2OzXH4/mrNeOR8CTPZgulSn8fzy6nesLQGlAnlJID+hUjZI8TlkCd8PJIXY10WERa1mLA9wUlEGrrOxXyuWKHYu7hYGy7uxuqeW5bnN3CYckD3SoX2cgic2ieCu8Ohd42lCLj1hVIjWYR5sOv6SSruUG171GJEHSqdLNbSQNWzbojrU/erio2MqC7a7J4bVutd3PxzyfrP83ieUFtlWYmyDEvWfBk26FmzjCSpERiDXiTbKHd9BeNi1LNzRPcen23eedwWIIO4RwEo12sBPd+vjt8f2/yThqSKw/p5HseBDh6J64W/KLHEbBLQkOYozKmp11hshAVoa48uP9dCc8NqCVZpbRddrm38PO7q65I2S0M71bvBuRbDqB6ccxJQROZfqRTu5nswk3wm5SMQo5ESNUmnUO1C6X1kb2hc74KS8Rr5vr/kNy5RNiVDcXxk8awOi/lkSZCBd3HTkU+Dnoooi0TTPjIwyLxMN6zUnSQLecGH26mBjkAXqZenpojMMCO65lA54GZwN5oq9NXFyQaunwtPsE/wRTfdgWGNrPzzYaPRbY29JwMjyp/eDu//ns7RsTUCHrm1OEDHUhxSh2vRtDiUunSCfPb9Am7Q97nkwNXolXifILUkLNAjgfsBLQqWKcoAo2SOhhHrXBFzXqWAVts6NEQKd9zYvPPIhQyYjWrGHqMubGLXCjqX2wbZ7gVkQ8XhQQYZZJBB/rzJqbawsiz8R7qtHqZtfFQapJCfMcCiobsA4rJyEfauJB3dmei+klwjdG2sLivlPKqmQ86gq/Gk3IeKcPWWDchmDSzdBN0I2Lt3FwCw8Tvh0lolMHW2v4VfAV4Aqzk3o8THIZaVsFB4BWipVGekZJyJRQSVVH+0ri/IJ2a7Q8wj0zzP2zay7gojRu07aE9whBAKdw1aqY5Lra4oMzhq7DXhsMqrJA+nfzb5W4hzu6M9ZHsBPKKZzzRygdgWCKS2ALCPvgNXNWmxiBW8FfdlDxEXWQFi+FVzJUNCtstPi1XiUvDfdDJjm9HrC3cA57cD8e/BYdC0y/0RzJng6mn26d6BRWOpcfo+T6mLbszkmfgPsVBaAJ6tEZea0QobQgLNx9loLa6LbiosL97Esjyeavi3QEU3t8y1Gfr+qhILNPabpIA44KzrXZGysAgTQ5H014Y8kwPGnCxb4vp0FlM+0z6E5UOv8OYBwWqMQAJa+4V26CKNA/sDPYdgahlK+5d8L8oOmLBcSafDILq2RUYPQcmyGs7V8ALl/9MwhvmkxsUzYQY8Swu6Gefw2zsAgDoLlvbs8cfw2O3gLL3Osa5sewzYs3RAwYKnitbL3B8vyHmEft4Z9NNX1sMavRUlK4pzFrpbmVVYokVLtMqUQC2rOkAsKx+et+5aaP6tnIBBGlRtxfsShNIpgOk/I64tWxkApgvcHPN6t/ehl3uovgyX4GBhDTLIIIMMcirkVFtYZV7AZ10sfxE1ad1DPCMHnDHQLBsABg4biwgokE2+sw6dXYVzuqZBRU3H8ru6a6OK5yFJng1yBqYtYctOO4zFCmocDu7YAQCcpbpZNr1mJHEvj+PJqQ6Jls8/lEqC7rY/zzGx0NO6UcpHNouYQNx2sVijYYa0y0wPeRUzybYREi+sFt77nqOR5znXxYTrnByF8C0aWqbe97rRSfEjOSZMEM3hEdQyWFizZZ/Z75OYDhBAAX7tgjq9h++5BqM1iKSESIy1HLf8gOMviEJvNdDARmEAIrFx5q7wee5eYLwVWnlVBaBFdrOBOSTnnKQZwAf2AazGxlKLDliF00vpiQwKS5546Y
4rAIBvvnwBIyIKfEvOi+c+j8cPw68/HONfvZ25RTX8AedjnEwa0CaxpzTGKqcVbFSpGNsEUEBFsMia0QoAGDHONNUWBXtYFSzQ2ClcYpHLq8Joj+PjYNAnymrq3c62qGvGWaV8S3LvNOYYS7RI+2ugrCSGyJgMPArxehA4ZU2GjA9lydiS/WmHyT0hhnXmQrBQbhQaehxGatEFT8fVgz3cYmPE6ut8byVJF1Xoj6VlY9KyN0B4B6LVpZO4ffK5zj/prINjx8W1pcgw2wztzzdD3HBUzpApiVkHDsOmy5F7stG7MCKNa/rYaydVGPoosVbC/WgwKcJLMmOM/Rm3j7urNrLwvBA51RvWpCzRaYUuZyVcAR6oGCuNi5lRBkaFDvPELllnYokQ6bOmsRFlI+UvuqpBQ9qnuhWKI4fOCgBDGBR0NJWNRD6Niy4X6z1uX94CAFzmml5YRKLeiMYDjrkE04Fa2bASoAnANItGXEziFnORwt/IRuMRECsAvJbPHoanYjJHG12BjRBjOh9zOHKiirI8Q8YSBoYLjne2R75JeQZ1GNucAkvk+RasTHyUK2gGngXoUmjEEhYCGEiwNCsIwrjhOxddYzFejB7AIJInf8tXKd2NHM3gMeXYbW6FY+Ntj8lOODa9O3xuXDHIJqHdl86G8fiTxwzuuBrcLLfZCVsJ0ksWn5M2rC55JmHscMpH5Mpf/b/938Nz/G//H7RHrBPFiZxvKlzkeD2+H67xaQ+M+HR/mf3bQCFnu2Z83irZadJaU37tWOl7V+WqwhA+UtJgqTDbZBoTKXUhpNG5gz8M7RdWGd3188Mn8z0uuDIn6gqLmgoPeytFoEaWlKTd8t24BqbcsPYEZBTxlsCSwKq69LFmXEbFzD27xPSp4Oq7eDEgA2dnplGxyKlNdp2LgAdxcYrTPukqVCcgKmus1r5CclwOSt/IeS2SfEpRhl0PYBGlsxnnaM4EounJ5k5oq2/Q0iUo6F/vZn19M6KyGm9hRVsWZd31tfaUKP9KQ2dhU5xwwn9+sof78gJm7V18PhlcgoMMMsggg5wKOdUWVpkXUPAwLAegE5eP1CGMWl2Wo+AOb9ug+7RNF/kFG7IlNI1HQ5XcU0O1nUJFSGkj2fCdR55T64qmcAdXkYuPRI/jkQ+aMABvgIoZ8fW5YFrbp+bRJRQLEbpV7jpgVcOOZn4Shy7TwHJDvS2iEDIoQ+tI8shUXz5AwB5Kqb44IV15aFvYVtx6vJzO+4KWkaPMxBwqnbHSblNHnkdPf+FJ1mPqAq1pkc2WFSZ8+DF51bZch9uLfoxF1sloU548hV7LFGSyzTWIwI15c13q50q0W+FELgSIkQNnLxAoEJRpbNwJFGdDf+SXqH1vF6ikrAmf/eDVc+TXBazCY66HdKc5VaI5x9yrpHniOu6Uwve89CXhfh/7YLjG0/vRVU2jFvoAqEmy+xY+2yPwOEtmjWkEc/hohewnbsqUTUEk9q/0j+95En1SSFGtfQJ9ysRO52A2VHyWcJ6Dop9zZzMcWyxcJF0V+0DDxXQAR09BZRvMvViLQfIT2pqCL6Sd2VKhrFcJolvnIoiq5uRpM40px86Vwf3nZzeRPRbcZtt37QMAzsJhm46OEclhS90DRUzyysc28NP5/p1s0rQGfh/XtOQ503Ix4ki16bPzW+ttvLZjhzQ7G8DuDgBgOg4eoHl9OwHlLNkfBpkUeLQMq7QqlleK4DFn4Jnk1caiuSpwlQIoWJjzs5N9fNvB9MuqODxYWIMMMsggg5wKOdUWltZF0LRiqXpqBOgJHQQUUBQFChVUzmXTM6tFzaMTYEGGlpmZxE3AdyoWd+toYWmvovrbdZLp3UAxGiKEyAoaUubcwaOl6Xfr/gBv3356HtVPicGlMYzUkkjjGfLdsWRiAK5a8t4S+NaxQwSIAQV44UKMrAW6L9Mu7O9t2zNcRIp0He8r5bMzU0BLdjs1KQcVmByACDduvD8WPwB6i0LasqxrbE
+pmZIye9T1AJf1YHJ6bKVon5IDQEZtXl1QqD/PuFgSkD9Je5OgtyTF7pwBzl+mBRMqRWDzQQ19dosnMD46zjBnv91gLM6+xuDGpwIUevc64fm+j0MK7Nuij7ekn7HP+df33H837nrT/wUAUP1//9fwZelRMqjeMCnWLvtYFIcGr6iBS9T853zIkXc4kHjaOvkgVvs3cvrxMwu3jtKPp7Q5SVBmFYHi4TcA10PhS7f3FG/s43m7jGFd0w5aUiYSm03ST2q6GtrOxTiaiEFqmYZPh95qFAslWypMlnIdme/9m1jTm1IZD8O/PaHuPjdonwwW1vSZPQDAFe8xJRO8okUxyTRyYcxgWyTOBfTvegZEAFDqVVkHxKTJ361LEsuT0+Q+0fvgOyie2Ekn7G4CZ8J6lPEdVk0GqWngEB6kti20CZ6hPMJCsljUMcZdbRsLxh4S1TIyGTLGuDS5EK+OGxzUeziy6z6XLy6nesMKZCZtZDBAsmEJA4ThH0VRYMTsa9mQoFwkqZWIfF23qJnIYKV8aefRRuSgBC9VfAGi96xzsUKpK6QKZxY3Cw2Niq/P/kvDatf+56f6fJIkDyfdqETk2DrVTPqdB2CJ5iuMiufLpt6Xw1LRRSIuK296aqZYRqDteheUEF46xNorwiyQZTmyUZjEgtpsuhYtd/0lQSsdjr94KYuHuGus95D6UAJuOZr3rqaT0HPSB8qvXo9J+ci3w5hsvhI4epYKyJwgE/SbflzgVA9mmXLHOnfWYTO828joEpxcKuB3wobVbQU3UTuyqLlJP3cQ+ur2HQX8q8OSWny4z7Najzmn7t9Y+yrJ13sdCZTvfu3r0f6X3wMAGNJ8ZI55MEDgWgJwuwU2+D4secFvUMA8MleE8w4St550sMFxVGcqMpZOARMBPaFfWOTZcvQb1vZrvyWc91e/B+azYcPKHv8v4fzf+xDmfBfHklumdF+CJcm5E3qgiu9p7Vx0h2Vrn6mEDYt1mmTNqIGyovstEhP76D6r6VJfaoOayZ0CFnBFCTCvLn8iENmegYee8e4uzLup0dhmbmfH+1c9hnWlr9bfcQ+slNQJbU8QgX6VdkskRRnL88qG5ckG4s9sxJwxL7XAjIFiKEBxw+p8jYaNkHpYWqm4xkoOn7M1LDesOfNbvSlgpLQScy0bk+NocoiZPWlFO1kGl+AggwwyyCCnQk61hVUaA+/NCpwdCEaCFHItqQ1tFtNYIqSuqOVYCxjJ1xDgQRndWKIZua6B9cJnR03WqFiaRKrDKmWQmTHbQlCCMzHXwyMLfH0A5vcH/HMz/X1k9GNI7orCcfdfcBPw+SDX6wcwZWTwUiBRC1Q4i9agF6iqUoSxA5b5EcYpOKmiLM+EnjNR3Il5blAyCJ3TxVnkY1hqwa4O96ibFpYqe2UDXL1BbylEvUolrAVyKCkbc3UvXGPvBE1Modc8RYPX/U9XcrzKK+E5i7squDupDT7K86w/nuvjAdIYYvsMr3EGMKRqUGS1QD5GM6W7mVrkopziOhkAbgoW/6jDwdcFl0r5R/sAgK3n1iD6fKZouSTun02Owze/+n4AQLO/h9uf/M8AgCkf3mdAzcGecVLUOrAnAEisaoVKcvLEIENwLQGr5VTWLSzjE/eg9H1yLFcauRdLgsc04Fngc/Tjfydc+8nr8Atyau58XTjxwevAp/4AAFCwEm+WdTA0k60kWDqFzkueUBfbH9lgkvb3rlR+KmAhUHNervQOGZOj8jnXhB0HRU/DYhzmzl7e4gr9qiYLdxllE3Tkx3NPhXk+LkvYJdcUqeybZShZumjKxcOhH38ZmymS/k86f80xgQL9PD+Ein6Z6OZMPATy28Y3KNmvGHN9mG7Bs1KyxwFv0vt5dExqbdH6YG0ZpgaZTEVKFc21VBkXc2Jz3qtuW3gWNBVu0C432Lu4gVn7fxSXYObgnY/IIXnxoPoBz9k5o3GBTKhIyG2kZjWMbvjbMPmM1jF/xYsb0Lmkyq/kTP
UoG8nhMgYRPSfJuEr1XvfUnJ2fDdiw27sjZNfSbAzWKlpbsByOo4RSSY9JzpU0WkFFIt8YfVI+xgMkqbfrWkjZZs1FwNoO1gqtCzfrMoeRfBLJrypMnzNSh1ewbhvMuGjPpYJu0s70BZTjsqgsHbBHtNi+kO6iR0M+r58qubbmswJAvsUX9HyLLOTZYkFlwd3qc7wEnVhsKIwv0N1xkfGgHcBy86qYh6U3NI4Yb2nodrbIkVNZGjHbZmE65Dtht7v5LSRQ/dUamsMvFaNGAGqOXZ48732bjA2eCzkzzW//USQ6RZJfJxuRxHO8A+YJhRIAbNrevdolc2x9rrrk79QNu+7yS3iT0XkX6ZdKcakpg3u/612hPVuhA/2NTwOHdLFXoZz8c80d2J58Ljx7FmZFARdddyJKqZgllX4T8+aS9+fYe5PMHSFg7jyQ04XqqPDV3sa4bcaN5iB3aDj3R+KCy3TMQVzuhQXfPOGi6w1HYYC1dTFOKbGrFNEqm4tstkC/+aw8+9pnOM+vvFvA6nqTIhGd5E1RYfXjHJ5xpT7PrYjnaYkbuz5JP9M9nZhiEnFslwaUWA9G3KwNKtbTyxnT81rjs5s5lm3a0ueXwSU4yCCDDDLIqZBTbWF1tobXXY9260s9RZqgrKQ2NFbIqHkWhDNleYciD7u+yYUOX8dsdKEc0arPedLievEqamcC4lCFhhKAhY56ExDdjToGKFveY/+eXZz/zDPxTCC4g0QJTHNgUjcX1v5ODY6uFUYKWp7W9qUE0hN1HzgFAOVdX21VrtH2FpYEoHWWxarNUnHYZwqgJXR0GCiB5s0Sc9LcCEqtS54pttcfd3dUCFVsAWCRaMvrrsOTDK0VBhAkOW07ZCE5k0PdS3AMv+3OBgJUoK+/VZ4pMb4ruPCKbbZB7eOIPt5DJmnN1AaWLS1ruoLH2kCbVf+adw0sXUH61eG6t692yD5O+hqi+rahegYDzrdaezy4TdDQQbCdDvb3outQzi+9QrVG9bSPnuUjWk4JkvIkyz0FAkRNPlHpZT4JYjEtG21cQIgCwIhXGu1chPqugGj0v/+74XLPPAp17QbbFVpx8+ZTmJfBPXV3c4v30LBSe06ojZWC8JGJp1j5vqaUiiAj27vIpO3Jc4pXxgHIxZNOgusGLqH1CpPjdmFxyAm85fqJaYXebBGuPnvqOkpJBq3pm+26FTJguX9a0woINcTW8zBT1pX0mHhiat+HQcT9r5IBTQFbQsslKEDkWSTrjkAzr/uSJFwfrPex9lVlwxzMYSKZeMa1ryhMXDMUx8gbG6tx1xWpypDhT7Z2IsjthchgYQ0yyCCDDHIq5FRbWI2r0aCCp8Yr/IGqT33CaMRcnlLHUiMFVa7JVEE7ARKEY8b0xK7CYSfEsQBi7ge8itpexGhYD9EBVGKJpfnmYnlZxo1uv/IK7P8eLKyo3MT/9eKR8OMlx0TS3CzXrfn7gWhyCKFv+De1VsLHXdsC1IykBIF1LTpaXSpaZD0Ho2jYqu1QHYRclIPDfQBAjRYL9u9JPHSppbhuAFa9wh4LCKbw91ROsrJSq0E08GoajlYTg+zeYOFMp4TdHmjkSwJlqCmWlzcxuhyiMWbEeOUcWC6DBfl5atgbTYld6sxjydGxDjnzdArRvn2HgkXxijJc1367QnMrWBL2T8J5N9BFDbyhdps7j8u7ZEl59PHwqQzayAOIeJ7IXPpC9ZDnszx2G/3LP0cv62S1aR5Teky0fMn100hys5Ljcub4Ja+F//2PAgC63/1PoV03j2CbgFzp2EfN/i3MiRffJajBYx5bK94UlcRg++B1ILkGEGOsgI9egxX2k9VHQuMBhrOjheW97wmR+d7czhvs8X3ZZWeNW4eOnhMnsaxFi6YL82SqhXfPxUkv9089BmJ9GRyHtSsct7DS8jdp7l76nOtwBoueLSRd11rCzzuaac77WNnYCkbA99WMHasRO13CM26r2QejTKPLxCoL5y99n9PWSq
HHTuMx5dGqk97gk+VUb1hGWcDV/UPIwqYBYh8wGQkdCKJfpMxJmzS26DgCWUZ6J+PgpGZUJK3senbyuLImS6cs6M7AxeQ4cREZKCGZVTpOREEnLe8+h+YMk29vcXH3PeIx5nr1j7dSJjzNyQDCRG8Z3ESCXpRy1/IqaJPDsl2yovuuS/4mGMVaOJnYUitL93kxsRfqFhWJaxdk2T6yTaxbJQtq4olaBYpgVerkb3k2jeM0TOkGmFIqSbvaADkBANxSJIUdb2OLrNr1FjeTfYNuwdwzdn526QzMbgDHdFxNmgVwdRnO++xz4fMVjcWOlkUu/NagALEnqJrgPnGug+PW3XGl2tndwuw76UptwobffaHvkAknwK7SyMrQ1qevhtwl5UxPiJt0wVLQsgKmSDo6BvE9UETQgiSI9v0mgf8JEqofAeglbOwT/iC3a+hPniskr/logvbaE+HW9ZxNyIHZHgDgGmd12ywwp5JznVRfxh9GQEp09SkXcwBVpGTzMMzxkXxI7xQ0adXSZGbZHCp5vxRANjUoKmsODlaqHHBcZ4XHLfoOz4v73/r+vZEk2qzEciHksX2ye6q4AezvEwBWJ4mMTUrUnDLopwqF/HM9ZJCGFmT2aACWoCirRVnXMb9U0ItexyUlUryZrIWJ6EBBTxrkueS50V1oGmhSgUmope4aGFSRZPyFyOASHGSQQQYZ5FTIqbawOnSBWFbMT6obOrH5JV9IextNZgFajEYt5tTfM2KLvW+jKeyo6cHaHnXhheLIR9i7uA59UrNCqvgak/XuM2NiHpdcrjozxf59IaJ/hq6hDkgqhR6Xk46lMGTRlnw0/X10BfrEDSAUKYpWoXIWniAJSwJdb7tIzSRWVWGyCNSIAdmmwYJlTebMXVmiO0biWqA3TtM6Vus61on5PzjuEkwtrPS7lTwinnCddYnU+BLGZehzsyTjQKlghOSX7qnJ2V0UW8HWkHpIizEwuylzJQAGbs1anC/oWiaza20rHFYB4ly1UrnVx5IUmha9V2NsvyxQZhR/JYzN1f/tCP5R0XTZ+DEgtWkX7MAMFi2fOhd3C3oXapcQqE6T4DwAbEJhTzRsdpZ1/bwTQt5N9Dlh0o+josSZszvh2vPg9tKHi5WcsZx/t7RWzUsvQj37mXBwKc++ja5jH1VhpszqeQ8eGIe+nI4y2FasHhllAxllSajoYOO7ljFXqu08FN/8RnIo/fF5BACGk7XjfFdtg0oqb0s9t0zhGtePK+yYsvMwEgKg+8/qLFq6giJPcylT975as6y+GMBKr30GAAXvgZNTEuK7kYxhrKsl1pTysGTgaI0QeWcRgJH1/t/oXhcLy3cNnBa3P9cJkyMrCHXnD7KiBJifKVbrsp2hNKbPe30BMlhYgwwyyCCDnAo51RZW29VQ2sd4laKD3iXBDq96fkEJFIIaQV54GAZYhY7C+iWcE2AEYe1AtKYyqkO11/BaLDBxHvvIaiHlEJRWMaBmlI7tGjMpdm419l8WtP3t3wsWVhosTUk602CqyDoQQwNoaCEattVmBbQlE6fveQMVY1JiVcE6eMLZdSPJwqoP12nxVWcRwirtWyxrHM3DPZYJ84C0WUAEKXGqSJo8mfrfU+aH9edO8p+jpH76LNHFIqGujEM7hdZMFGcBynLLwrlwzJEkWU0m8IwbWRfmjipLjPIA2BipYGFd35vhCmNhuUB06xL1MmiUS1oPbdegIN9iTgtcW4ucfv7Ja+8KbZpex9GvhmvP/jhcbif3WM7Eskv7za/0m0Ufw+p4Yobe4pH4R53Yr1J09Kx2yEk0fM+ZHQDAeGMDXkrJ0OLJL98BfeW+0JY8QNAPPvTLeO5PQ4zKwcdxidV5j56GfjbQiqgZCzQ2FRZS7JMejsY3sYLwgnDwe6YjHLVrc9q7+G6rPrDcxzjFGlQmkvumBRzT+QiE+JFglfKKseu2jcwfli92livczEP7Jd2irhRGJswZK/VoyimaLFifHQENYx
yPKaUxXRGLfu6n4KFj5yXpG8BxUEZqdUXig6QNEmdyvkanOLektJLz8GT3aYUswfuVkkYA0KCF8I5mLDmSqwyGWP2cIJRCO2QsA2W43mxaB28a+GP+lS8up3rD8gimbiHYAaaPd8moNFLjZuSgBSkXczospkR1SXnvxi1jxwvaznaAUqtMDdZn8HQZSsVNpbKIlBGXmTFZzFUyWkf/S8uFEo3D3qvvBQBc2H4cAJAduuiaiRM3QXqdtICnk9mmsxMIoAm2VdwASuuYp5Uu/sLW7kjf4byLoAtBRSqloHheRWTBom1x1PT0S9J2md8pWGIdoaWSt/Ykd2f6ksdtKHnj1dpvw/D3b1YkCN0Ov97SOXIim8TfYfISYMC+69IAvwT2SfJb5NCZbGLhEgcHc1zbG7M9XNhch31uWNUyzB3nXATvGNP3go5u63Ds7lfdi5s7YTK/5fOvAQBc+M9/jObJqwCAJReDDds/syzGDXoGdeFPGaHvQ1mob6DfvLZ3ggPwtX/vbyH/g0CVhK2gRHnVQjHnBkXYqF1TQZU74RgXvY0Ld2LyePitbhZ9zhMfU+9d60EtrCbQ1g77NV3QfIcdEAe0E5e8LpOaVnQDetfTV53AnyLvn9U6uhGFaqjF8VxAG2+O+JJXroOLbn+6TzODhrmd+1mY+9utg+LirgtRCAt4ssAI80v6vqZL9Lr7r0GyziTfye/lu873465wfDFPXYKpYivHNBHUlfYR+GCF1NZpeDJY2Pje296NzwFp2w6KdG5Ckl1kOoJecjLAlF5DL2dsA9dN1DBN1vOBvQAZXIKDDDLIIIOcCjnVFlZejNH5GcYC401UFbHMK7q2llkF00kQmq4+46Gp6TY0+Wf1HCMh3XRiHuuoFcTqwVAxOBvBGWYUrSkJchqVIVOioeiYO1JI0pjWWF6+GNr/dZfDb37rmWOlLlLC0dQ1uKbwwABoRZ+K11Dw1PwldqyUTghx+Z21MQ9LRbYE36cLxJpEgCXsfbkMVtWsrbFggFpg0CfqTX7VEgJWwRexVlbycKmW+UI1rNSd6IVw4BxZKPICI0Jw4YOd0WmD1okfWSzxvp6XjTdWaGK5lXCkrTs8eyNYIU0XrptvWizYE7VAhmF6F45YYpmCpgXreKzxDg9cCpbVX8JfBQBUVxVmnw41o8QyskhSkBJIs+RVlZx3Gg5jfn/I73YAbD5wLwDg/Lf9n0Nb7no9uuvBB6kF0nzlAvzTT4ZbjIMlph7/Q7gttnmP5LW39nHuYiiZs/fMM+g4FyRPrzm4gZLu5pbBfGs7LGnJN7afyWIdd3QdOT+KrkBJtlSq6z0NAm+33bH3wSkT6sGht5JSF/RKKoSwgcitTA6Vrf7WGIOmCHe+UQTL6aIHMr43Ba0SmBx6FCzSlpa2TyrrppbR+jGPky2sdWvJAissHicBkk4Cb0SZhplUGxsBOvFd9AbOrnoXtLPHIPhwiYdJPAU6h+H7lYmR3PpQagmIJOTOd1Cug1r3Mz6PDBbWIIMMMsggp0JOtYU1nWzhsKqgpeBbLrGWHoopVte8qZCx0mkFgdV2MAyAidZ6tKjIWAFMjMQqNFQkp2OXWY1GfK+8f54bZDTtcn4araPVpVWf/y/OfTMeIafZU33LK8N3//kZWGorWSwcCRRr2nSaSJsGZyXYnga9RAsVUd7EypOeVqZyLlLPWykzgh7CKkwXzlrMCSRoqDXP6wptwhcYHxOrkvriT4pXRSsyscTSDP40oPx80iuAKkLD/S5h/EYhY/8aoURxBnM+Sy391jVo+DSaqAVnK9QE6tjIcu9xaz9YU5K0XTYeviTgJ6JWzLEApFYqmkmOiZWVtXjw2rnwi6t7AIDNv/Yu5EfhvOv/+6+F66reGhAdtdJALkAB38OSBTlMwg7c/a0PY/L6bwrtf9k3hGvs19DPBeCPPker/7Gn4PcZ13rqC+HzscdgCWtfXgvf1cUEfhHsNzcaIVuEPso4Z/LbR7Bk7vd1iG
HVTYeacHWBv2vo2O7e/PESKoOVpG7fVxw2Rire9u9X9BpAH4uPdjiecG/R95vh4jEZj6AZa+6TbRXA2M9eGcb6SAMFx11+6zVgCLBRLE1SzY6OWTwafZxP0hHatXbJZwrUkPOeLyE/hcen30UPDcuLtMpHT4KDpO2ongRBKqAmS0gs6eMQU3Xi2mjyaJW5GOfv3zUtgDSr4NCsxJu/lJzqDevM5jk0foGuDi/KhIgu5Tpwj4iIo7ZtsCBJraGdmhkVS0LLQu1shdqTBt/KwqQgZWtlbNA16ITVQMpRFMmGlcvGZaIrDUDv5uKhTBdxEWlfcW+4x9kCxfXQBpmcdfLbmALljyMHPZKy2EItlZUQ4lqk37lVJ6O3Npa2FsCJ04hEljpSDHnU/O2CLALLZhlfnvVSIcltVzYsnHBshbGDf8vnRnJeDCInfZC+nH0A22Mh3teJ5Mr0OWWCduy8RksFZJ8MBXNtUZIRtyhFUVmi4jM7Rv1tBzRVePqW9T6mAIoAoMNIcq+MgZfNiXOnbR08KRZKG1xutW5w/gn+3Ya5PbIZJt/3NwEA27/7MQDAwdHtWH5EdJTCJRt8RHdqnNsIi+fFN7099Ns9L4EjPVRE2f7yB2CfCe6/7ibn34N3ILsVcqX2n34knK+A5RduAgAOOWLn8hyHLUupu35ZKTkS9cEhcqL6JKfqsGpizlUE6yLriYsJYMqgoIU+jJM/MzpSLpnIet1v0sKmYJ1CJzmIiKcdc1unG8dkQoaNooz5lJEoWmugDPc74tzYy1tM51R8BXSBDIoLc1aGsTxaLNCs5Rzl6Mlv0/dnffNpk7amm5gAT4zvyW9TFOlJ+DuZMy014CpL3890W+SzE4SGrIbC8b6U8Wqp6HWZinlrQh7snIt5n5kO/dsoCweb3PNLy+ASHGSQQQYZ5FTIqbawdnbOo3IL3GZwVpGNYDPT0Lm4CYM5rvQYt2g1RK64rERO9dLRehh3Dk2Es/e5SFIUUSw25VTUZEu6DYoyixaWuCQzo2LFXuf7nJFccmR0n7NlWT5Cf9vL4T/whwBiMU8kBXhjoTyD45pY+J73o4tL5eih6cLK4Vp4ugQVNR54G32owt7RAZH8Vko3NLYNxR4BNAK+8DbCqNM2R0lMwPX8qpOCwxq9FiwW224CYRewgU7u88WuZy9R2yOSoLVWDGYglqFQaMh0MWM+WZbZmIvSdVKBuUIr5SIEpKGzaG2J5akrHy065Gmqg2icvITzsER0LFiV2asSZ44IEfb7AADze/8VfiOYbNtvCK7jw4/+dizIKH3eoIeub5wJ0PQ7H7gX2UvDb+yFe0JbyhxmJ4B83JOPhx88899gjzg/Dm+H8/70MOaTHbCsiVHAUWR2CO28vXcrglCcsthjG85yJGzn4ARswTa3HlhGYI1Az21vVVgBZzSRjSV6K1QCbIq0MB4t52Ur77APZUKAVSCDzA+xSsYe4BKAMxvhfbhukOQbCoehjy7BehSO3Swb7HQsmdHQOit1zF8DSXxbnWEWiyIGGfmkaCmk/xLwg4DHTnCRB3c94vWkLyW80al+rZB3IwOQabGsQo/MTBfRWN4LB2AW575jKCX3WSx4m+K6jLgCJaXAayimiJRMAWg7C1MJiEPCJRbG51CSvPoCZLCwBhlkkEEGORVyqi2s7fEZVNM5jshn5lz43BhlyKjVbkyCnz5TG7D0sc9s0N2LrEROv6qhP9dPDJZ1OK8Rtvbax2KNuRJNO4/+3JKM8FmexyCz+N8LE5V4KPiY5KipbercReivpTPbvenVaH8pWFjG9tBk8XGnWkaakAuQEVqgpwxya78ZYyc+KdroRUUl3FbDxyCpaKPBEmD7BH7d9omGjcQKcNxf7k/4hz/h+5M82B2OB4xzfzz+Bb9qWQGrwWYDYPFA+KZim+vmCI3hvNAs82H7wpdLlkgfW4WZMBwwyJKbOlqXvbKvoSAlxlk6wdros4/Js0ZFC13oQzrn4Miioandms
qinAkLQYgf+b3n4IoQC9m8fCmcDxWBLhFnY4D7XhUg8eX9V8I1zm9DnXkgfH8u8BYa6+F3wnXwewHE4eYL+GtkZ2D7Dm/eiHNBIPEafRKuPM9B1WLKeE9t2z4SQvO3qTwMrdB53YNz+phkOGbRW05arC5n48FoTasYlY1WV2BX57WtzOM+rViwL84jcjBKxzUKmNI0nWwUvJeG5X0lDja1WWRvqEuxsDqc5bowWoYqAKU2UC5feQ5VlKiWYazFCk7fmTQhOEoCxFiHwq+XPZRrxbgm+gVePgsAJb8/4LFG6aRCBftcaWixaiP5p4lFctNEZBPHhO1QGiW9NmU5iWfO+F75Ix7xIc1HHX+rv6ic6g2rzEpMxxsYM99hwUxqUxTYnARX4OYkuFEUJrGypZVJk09Q8CUTkIRXOQqiKPYR3CJd1fWmt+5RNBOigKa8V5HnUHH0wqfS/UDCI6kCzHImmYahK1M8G8WDl3D0ut3wLL8f2pCW20gJL1NKHiDcS+rLCFtFhj4IKog5o1x80WUj8s7GzH5BGqbGeiduUWth254JAwiuqAiYWEOuASczXZwk8q42OL6hmRNOTPNTRFzSFgugujeccVSFDbwaV5jXQkVDxFoLNER6SV6JazzmfJEFMDAuAUdXieHcUaqLDAwq5qrpCKyQtzxFjArjRQMLggOh6vDH1Cl4QeUchPF3Ix1rvIlrDfAx35B4Itxx4RUovunrAQBdF+bn+NIGlv+nbwMAZLe5aHzuUXRH4X3Jn/18eO6jBrNWXMLhwgvnQRxJX54j7XDOZwtgbtt4ONaj4oa/qIAi2ZTSz/Ak7IPkWNxTnIvsNKI5egX4yFzDr3zPWNORHixF3EVkGxA3+oTXFXpK4MQ0vM9lZlCOBMglFG8KBV1gjuCM2cjjNn2zW8v98NjLCTRpq1wsz5zHh5INv0PvCpT+bZP+iO49AFUCsJBjIk3yt4CuCgQ6KCC4cUO/9OsDdSIcwUKTXinRwmAg9eFIyaVHEP67LqG883Ebkd1MIcuF/Dt8jlRfe64oyYhTt4FE3K+/wV9cBpfgIIMMMsggp0JOtYUFB5SmwOaYmkwXbM08z7A1DQHnaRFcPh45loSwV9Q8y2KKcR5M1oL5WIXvYOugFVjHDPXxLBYzM0KboDJMJ8GymxRSJDLr3T+Cf1cOOgGkOuHdon7koaLVVlBL05Mxtv+vDwEADn4/VGfN0GuKqWa17jar0Qe1I/BABw40oM+p6Jom5lVFV1rTRSuqlfIBOkPnxGIjQ4hzsLQ8JKBdo9fydGJRxvYlx06CtYvIs6Wa8bq2mR4DgDXvzopLcA6guiNMcwlK102Lhq01hvlT1qCLnGhSTsXFitRiOXuX98wKjNIro2Ej4bDkrgBqnYdQGRRmNUDtGhddjB3NpGlVRnet8F92izkwC065bBJ6YmNrA4uDYCW1bNPlv/ZdcCWtwrOBTNfd9wqMPvu50P7f+91wvVkN5ckNeP0aAGBxWGEZc7jCfWsApE1eZSGBnBc+c/TjX6gEHk3XeF27YzmDbu1vuXYEF0kuldfxDMvx8BqxVI6cZ30X84kigCFpVwrEiW611GrZpeUs3pmNEXbIBiFlYxx6N31H66vZMbi1Ha6+QeYXV3kUXF4dx7wFooVt6f2Y43j6RupdEO+M832e1klu9fR9iaCb5DnFyMsAzHnRm/xu0baYEDylaLJr1SEXS9JLGSIgo+szpz3Yqg4dJ8GYz5shi94bcYvnmcFkFNbpjUmYx9V8BpsUx30hMlhYgwwyyCCDnAo51RZW1zho67FBK8mWYQcvcodRSetHvnNAIfDzPOggo3KKkqUixgROeFh4ptW3PuhprrOoF9S0qWmZvETJGNY4Z6kKkycQT/GlN7FEdggUC/ZXwBsahloLyReQjzwm3/RSAMCNrV8P7To8zpPmk79j4m1yMGq53sPGIpN9vEqvWSa26+J5vU2YZMGLNuocalohyxiLSUgcTlCY1i3BtM
2pRSRaZqd67V3iDEVyXsowv36/FNZelcB8GqzshqXqq7rDksnhyIKF5VwWASRgWZm8ADJayiopwpnTP5/ltDhzBa+kJLvA1l2ECveM7yomLEcgS9eircL4Lxg/0nUOxeRkeTZd3YIgijuyv1+45y584Y9CMu/Gdoh5unteDneGc1S4M288juy3g6WOZ0KJD+wdQucB1o5Z6JeZdSFJHkDNcZ2jBwHEmKLvNf8UyJAWCRTOCafEMnErTA7yuW5FG6ho3cW5YAGfEQAlyfAwfRxTUjBsF8ERaRLuekK7XfsbAEoNuHMc162wJuxslTjLVJOu4nVrCycsNFxHDnfHqK+QwZ3FKXGrxqgJHh+dBYvNOhdjRDXf+ZlPnlPalMz9lCPypJiu/F0lzyLj0Po+dlUl78htmltPs1xQ0dXIpKRO9FJ4lLQMa9cXehVmdsMx0qovDhJ5UjODjJa1khQiZCi5Tm5Owpp8WGaYVfO+usQLkFO9YcEHpNCINXU6blJl3qIsQueM6Aa0FhFgUZREsBQTjOgynI5pSGsv5V1Q+7AIoAMO2zD5nJaAu0FJV6AMhMoMMrOKDLK2jQALr1QEKXRcFMcqgxb3VBzcGv5caGP21nvDtX/p0WPklx4nI+xk0gmIYOJ9ZOiIG1FnE3+OUFD1OWituKLg0PlkkUDIj6lcX/OKXXRsQ11pa9LQ9TanC0ibHIsVovmZVmw9aVNcvycA+JeX6AyD6AIocQbLRvJJwoaldIHWCoMF6zVlBkbIT4lwy3QZ+6tgrl/4ZJCf7Rojj3lwPmUKVQJ6CX25dB7kZ0bNjSurC3SivNDFpG/fQseyHHqHitJoB4ou3nP3vyR8N8mhpqTYeOTZ8PmFj0A9+YXw7LfpCJpbWP14aPNhuO7c9QuRoOjaZLRWuvwEJSGSxqKvRya5dK2P6WgRcZsiQY8VbkK6qQR8LYCoUCll4AiskGOt62IWV+peS+cUsOqKjHlKE8BfJNp3K7x7G2fGOMMaYDMvte8c6pYbVsbQwfYU9jI3IpZL2V9aTGZkzJH2wURFQBB6R2vtks91VguD47RkefJ3CroQsejno7ynhwCe4At1lWvQPW2FMevDSS5rC42GSpon4s+7Ft7L2xgunMHHSth9tWcNLa5FrmnKqrgBbpBEeTQuMDs6XEPxPL8MLsFBBhlkkEFOhZxqC2vZVFDGRg02F0p702dfC7mtbn1f5VU4rbIRNqiNboyDhgHdoWZlzCmD0mgc6jk1D6ospckx0nI/ahN5DiOktzFfwaJlPpR1NgIXJPcGzkWXhqL7RKOJfGv5d74KADD7d49GOGsq64FsjV5h6erEsqMbS9g7gnXFtvAire2SQLc0r8/NEhJXZ110vQhTQOdXLRuRdVegStssv0Wv6UbNMrmegC1SCyt1jxzjWFO9Zrl4467QRMIQrOAs0MhBwtpN5rHshGmCkHeVI8vCHPCCKfcu5tqVZRj/yRSYlQGaUC3FBaaghVMvAgVMzM9TsXSNRStlSgi66FqFTIrnzZi0crCEN6F99pD5dVMfOfa2Hgi5V6os4Z4J1YrdMwFoYa7dgL4dSpOoI7ZltoSiZ6JmkVPle6tmmZhQkblEDiRurOieTllIdG9VzjjfatXn/6RuwPUprZL7xPH0gOLcayQdxOtoWUnOVeXdMWuqQz/PIqdx8r0sgKMzgD4bNH89ZQ7n2GDK9aMbcWw6j6bhBRk6yMZTuPNhnlQNLfLlHpaPsBKvlGpRGeb0VpDIBIdJH8S0Dd+3TyxUjf49kP5pE0aMDkApP/f997EQLAdqvwCe2uS16YHp6iUylsURUEhrdc852JEk13lY8XcIWwUsOqnQLh3s+7QNSXvQSkdAkgCPJkWJLDNYKZv8JWSwsAYZZJBBBjkVcqotrNtHt7A51qHwICKCFlmW9zBqCQRrjVZ83uK3VhlGzMQe89P5JeCDrjKyIfZRZx1UFo5pxj6KLEMuScTCH1gYqHV2Z6iYwFtZi4
bZ3rkm8FY5aKH0p63QtksYWmXjK6HMxDP3GUy+wJgJdTKlAvsDsJqEWUpmPa0HazzM2aBW+XbEzxaWnHkrScBSQqTrEzStFJ6LhddUtIhEo3UqSRhOABEnBYqjrz7RDiOzRvIpmrFojpnvr50qZetw/8BkEJ5jfu8EismOBdVMk+UxqbeLpcErdB2TavlUXmcRbKGFc9K5mDYwIoP7eOox2gjjLuVWbGsjX2BUPFWa1c+gtVewEM7K8Dm3dYRsR+uxapF7xtsItdb5LJZJye8MrBX29n60utAxLWM5h6XFpnhMw6EjCGXZK8Ynsnun1nH675VjfpU1XwA/S7EaPXr+RHkmJL9J2rDOvt+ht7BavktatTGZWFIwOu+PW9voLY50Lq5bNcUZDb0V1gAzYYzSKJQ0M0b0UMyVi8nahjGssiiRj0MsvLDb4b6NxuJpWro0krXvYqxJLKcG/byNEHQct7A8VmO54XqrlREK4fITwItPrGKGNee7wMFUbs4SQdUCinyBGXvE2AxaS3oP52LnYbV4gaRdZXyHGuGKTPweUpoISvd9riSNp0SR5SzdJIkIzy+nesN68sbjuHxuFxOSSxqdEtSuugvaxqNquYDHt8PD0J1nOCGV7SsTC71Ippt+8DgQTuXwNJ8zMh/kJo8LW0TAZRqeK5ZtOlQMoqsyTBbrEclnJbA/XxzGCZHxBcC334vmnwVGgohD9L2bQM4vAQgZCuaBJUG3c6jtUA1Wc5L6jREsmR98dPn1MzxWZ/W+N/WN9CkgmMWURWDdsD/J5bOySMnBBBW1EnhmW0auf8bUBbJ+jzR4fTjmyzMu4ws8ji5XHV0WQt7atFUCtugVgj6vTvJT+ldmxDIT00mG6Qb7ZsHPbh4Jc63vNzvZsLxsKlpBkw1CQAl75ghK6gjxtL26wlTqtNHHpLXGxoTo1tl+OPbsKKLYjGy8sz04UktlglLzdaTeaeLiklawJVoQPrr6UlaIyG0g/YPEjewR55GgDn0yE3T6uYZUTQ71rmOfEC8L4tbphC2CY4jV+SOf6oRjx3L8ziu4LVIyjcK3mbZQhv2WCV1T7zbT/ByXU2xOzgAANuhCbDRw88l9AMDsEUHW9fM2zQ3rN9deJZC2CljC5b2be1OQnLanpxoDMCeoFj7so9D38ZsdYMwbHi353ldVXCdLLWWNxuj4AgoJcdcpGAFecSO3zsMSuLRkuOOwWsCwv4rkZbcENcmEynQWELex/tmXlsElOMgggwwyyKmQU21hPXH1aegSODcN7q4NassKkxgYVwIjbyt4L5YENV7lIzmnUO5bFMhosRWk38/zCiUtrIo6T+aBnBpvSZh8YcaRXFZRgzYOqGnRNV2HIyn+R2N/0XrscBhEm26rFoqOsHwR3AqTb7gP+yZYWMI9VyTgAtE8pgrYFK2X3IpdO0O5cUe4h1AZKB3dLC6WFPExJ0IsD+98X/qBLqvK2aiVN1GTRi/PE0NN3U7R5ZNonit5PTyWAmlPAl2Inikad62A+WvILuA1piO6cznu3tsIJNFaXFdzdLSwykLcdR2MEug6LXCM0NICFzDC5maBdsHg/Dz89mBZwwo9Cuedc6qv3sx5p3VfasZSHb1mjyCEKp7pGXOHiI7JlmLFVTi3Qwv8T0M+ln1ZjqwMv3GLMP76aA4lhTnJo6ma3tXTJuMlGj0V8xUQRLT6cdxd69Fz3aUoZZk7LfyJsfVIEYjVz/B3H8zP854tApC0EL/S5grH0yN86qqObUrSI+iiKC6PcTiVlBWOjbHwhqCByBHq4hwVq3tSTrExDmtQthH8bZvlGM/d/0w48XZgEtm/4aMr2iUPHi2/xNoU/r7xNIz15Zd4jDgo873w2/ltwC7Yfg3sCJ6K1/YFoO/mNR+g12jqUdxkG2bhvKpusGBunx7JO67Q0oWe013YKd0z4Mg60XoIeXbFsju31RF5E4EtWmLKNWgJNJP3DMohz03/0r8AGSysQQYZZJBBToV8xS
2sn/qpn8JP//RPrxx72ctehs9+9rMAgKqq8KM/+qP4wAc+gLqu8da3vhX//J//c1y8ePHLvtfRbY8b4+vIzjJBTwXtRmMSA3uZ6p2oShjIxWfqbdR+FTVpA41MUaNgUFXpDEYHAEZGNgTfGrREeYyofSmje/443tX7PunRWcAzI3TJ69SNxpJWV0xERANFrPOcnwc10L4hlIbAJ0NsyjsbNUlhZR57oIhZtdQ3bzwHL1ZDRuaPxQFcRwuM6nLnXGxLXzak16oc2e4bb6NWKyzhJwXrgeePa6VgiXUgRoOedTrG5HByjCKW/GOn1x5YfMsOAGDWLXGWYyd936CGpbanRItHhwjzj8UJASFGLzmfTF5GJmqxiJ0HbBVsjeWCVks3j1yUlggV7/ty7QXnjjeI9SAM29eWLRYMNGTkxMzzcQR77DNOdmYM+HmIyx585tMAgK3z5/uOOAiavfJd5EfspLCh7TknZbr0PbDKprLCkr/27ywZ/5MM6xjzTY6lyeYnXVNECv0FC1oKC4ZjC+VjmkfKnLHOGuH8cas8Tcwtz/MeZ0foprRCIH1ke8ARo8Va+x7SzTHMswm2JsHCyssdAMBoexc7r74XALB/tB+eW1Vwt7gGSR4yEoYYDoT1KlpW939T+Pa+BwHNzt57Lpx/c66wMQn9srWRYZsehBk7+KDr4Gkx4QqBDhrImRYhFt3+YYfbB4HfryWjSGVLOHlmsgWVKscROQcl271pFayV9yX0al1XOJyF1GhfhBEZmbxn0BciBXSAcoisBi9A/kxcgq961avwG7/xG/1Nsv42P/IjP4Jf+7Vfwy/90i9he3sb7373u/Hd3/3d+J3f+Z0v+z5tA+zt18iLfQDAmKu29dtEnvRmu9GpMSnfuYiokgwOAxUXMan1o5VCriXoyhenadEVgp6j20D76J4Sqk/vPRpxvXQODX1nOb9vO4UlzXFVhvMy32HJQV0swvm3bi2x/ZcCman/ZJhcM+Mx4YoqG1YBoGD0vpiGDTybZOgWghgLv1WLORoBpLCURW0tKm5YnbzyJo/un5ZvWYMe05PmvayHfBWOL2JpEDzW/1nLJ5HrMjMuPlu62fnkvKiScNVb5BpzUuWg7tAKSpCozs5W8LIh083m0ESUnueL2lqgIIBBwANFNkJJl7FLUtq6zdDyHULumvkhDuZ74ftW3NIOHZ8+kyVTOZhc5grzU1SFm9thHC49LUpTjgV7XVxuW42HI73OUgcGi52DG+iktM3TIfcqsw1aaWysGu1jx0n/1egXcgGqKPQu2f47HBtYjx4AlLpuU2hhLAMi1zvBHZa6E6OnSOv4I0FtLp2Nm1yKJuySv0Xs2mf63eguEmLvjqFLQetybtgmsrwI6tHDxQ0rUpvlOSY7AYY3HtFFmxmcvy/44w7oErxVP42aa4a44/wyaRAfyJbAnW8I37/mTeEmd13yqMhCLPzJZe5x5c5wv7PnL8FR4dk/CD1SLmpca8L7LoTI5sihIOGz3G++APb3wvrQEXbsFTAdk7pJwiaZRsUX/kiqt3vEUjRSN8t7h4Yo6JkAj4oysl44K6EZGbHniSGsyZ/JhpVlGS5dunTs+MHBAf7Fv/gX+Df/5t/g274t1Of5+Z//ebziFa/Axz/+cfyFv/AX/iyaM8gggwwyyJ8D+TPZsD73uc/hypUrGI1GePjhh/He974Xd999Nz75yU+ibVu8+c1vjue+/OUvx913342PfexjX3TDqusadd2XMDw8DFoDdCgMd3M/fLc12Q/nN7sxqA4pFWEUtEDYCefNtIqukpqQyxHyPktbQBU6g5Xg7Cj89kzjcJYaxUUf3AG61VhSa6lo0c1gI2zTWA9IEUlqc20FLLUUDqRLwNXoeMPiVjj/fOUx2Q6Q+M2XBvfp43/6ZAyOy+e0BMalWJVkdlAjGDrWVB3ciXVVY0moc0erse16OLuN/aej+1J45WY4TigKJLxnJ+RKiSgc13Q732vYKePFmH/HwH
4Cf0/ztoRtY8nPo28qomZf5g6NI5uFEl7ANkL5paCldUsYWsqi2Fvn4v3EfaZ1hoxkpgJRbzMHTxh9Owmfs+kUByxT0zGdouu6GGz3dIPozGAsUHO+jlPn8OSZfQDAZXK85ZubWLRBCybSGpUDFElZR1sckds3UNPpVpL/spofRZi6PK/3IT0B6C2sJXoOTIFIZwooff+9iFg36TjIYpIWIJSxrgEQH7DiRo7w8lWvbrhmkosmblopibOwdhUej1WXoMzP1MpPXYdCENtdootxe4qCABfJ22thUYs1IMAek0VGmpwWQ5MZTEmYuz0Ka8FMexTbwfq58NL7AQBza3Bt51b4+zZZVY7amM7S8BknVyzu/Zbw23tfGyy3OyeHuH0U1r1nyWR7YTTFlZc+CAC4eOkiHCuqjw75xHtz7O3TxVeHuTBaeKgxvQWSDrAEjki9oZmjlRUdyo3w7Fv0vTaZRk12j8NaGHt8X22bg+nhYRnKEBYXrTxygooEvAXXMSzzwi2srzjo4qGHHsL73/9+fPCDH8TP/uzP4rHHHsO3fuu34ujoCFevXkVRFNjZ2Vn5zcWLF3H16tUves33vve92N7ejv/dddddX+lmDzLIIIMM8jUuX3EL6+1vf3v8+7WvfS0eeugh3HPPPfjFX/xFjCXI9GXKj//4j+M973lP/Pfh4SHuuusuGB3YLUgKgaqWBLcjdDZoJtb1oGhJ6s0FkGFM1NiccJ7lPRdDSfV6CoPLDJDkjKjcawvc0QRtanwUNLN63mAvC9rtUcEk4EzjtlgtRuG6sH+zDVXnsaAaUlFfzVyFs08GffSumwzstgrLMgQyX/m2bw7nPfYMFK89obY0HWlMCz6nJJBWSzhaGQIptd7DMk5Wu54/rKVGKZqlgUHt+9hVaGcCG0YvMW5xgracyrr23SJJHE5iHmI1SmwkTU5OIcpLWsxLQUi89SzGhcDaLSwtyTqT2KSDVWJJ9vEK4YSUYL/zumf8KAS/3xciFJDMqHBwOf39jNdmoxI52fwdy9R4eDhRQ6m1lqaEjs4AHX/7+TzEvx7KQzpCMd3A+PZ+OJGZ1E29jNbPmB14+/EnsHMuaOfCYZk5A8+xTtkmJEYkoIs6sabEIsvhv+g4AqtQbBknq/tcUIn9dfBgOLZncE+ukyaTrwMwvFeR+cNosUb7X+fJb9fnB0441gEY7Ya/mx0m+hdjmBhI49zpFOacO5IAq7RaKTYJAK1tsZDkb3ZgC0BNSWjA5O47HziH8xfCjZ99dh8AsLdXoWLhWU9Px90vz/HAS0Lpl41z4S1o8yPMdCgN86ckH3jl7gbOnw9MOOfObOCoCl6UDc7tDV1ipw3n3lyGtaPSbV/6RSg7Go8lzeeCnxvGI+cbOqaLY+w9KmH1oaW+cK5Pj+HkUtrE+G6Ef9kOOSStRKxWBwUf2/NC5M88D2tnZwcvfelL8eijj+Iv/+W/jKZpsL+/v2JlXbt27cSYl0hZljHXKZVREVJTBAh4UIdeWjZtLI/hfD9dC9IrCYVJpjVafi+bk1cWG3Uwrc/dojuxHqEhaejOYTh20W5g6sOmaPbp0qtvY2ccNpo9onNmhcIGczOMB2q+DNe4UC46BccNa0rY2YWZxYOf3wcAPLAf7jHqNA4n4e9s63EAwCsuXMLetZDrMY3us971KaS22tVAw4AyAQDWdxG1I+AL6/PeRcM8ita76HKTfB2HhJKJnwrHGSyOw1zIiLDm/mvRI8dkscoVcHHNJbhAv0BGVx16XoD9sLYj292Gl+rIdolZIy45bkjQsEpcFiSSVR6F5NCReaSzGRz7UtxnrW1ihWBTCGFo/7AskYRcFZGMueFTaQ1oun+47iJTWSSKLSJhcoYny+A6UsXd/G2JMdGOC5bMyQxgl6zxJUrbch/+AudbJ5RPPvpmZEyUWh0TIIzBSFAP7OgiURJk43JQEZQhSD2N5Dx3MtjmpBIy6yCaFaBOzFVSUHw3hA2kRZ8Hl7qe18
EePml/Cshwd9GdS7fdNJuilfItI7pyu4boUcDl7D+VoYvzRBrQom7FvS6+YxPzyDzHq9jawpQv6vRMUHafu7GHZ/dYFonPe+6eKS7fGcb63JngTarrGSqier9wK2xc912YosxZa0v56KIeM7lsIy8wlrlne/q3CJwWVKIFKr6AI86jjS0PRi2wS+Lf1rXYpyvdCC2WtWCqWkTUOqegJW9RwG62iRR6fUVyB+/911bF4dlshs9//vO4fPky3vCGNyDPc3zoQx+K3z/yyCN48skn8fDDD/9ZN2WQQQYZZJBTLF9xC+vHfuzH8B3f8R2455578Oyzz+Inf/InYYzBO97xDmxvb+MHfuAH8J73vAe7u7vY2trC3/7bfxsPP/zwfxdC0JRUErhjL8jiMG+XqEn82lq6ZZyDorZUsKCfgYaRYC4vcmFmceezNOFv0fVy83Gow8A4UZDJIDNnortI3wjmdvvUF1DaAC8+txm+2zpvMLlCjffslZhzIzyEh1jC0RV0N6uQ3vvoAq9jSM99KlhQ42IDmzvhOkfjfQDAdPMSzEEo0lcwkl2oflDFjWLbDs5L/hgBFHUVLYWOz9FYi05y1XiNxvsYJBdXTpV8f5LrJXXBqLVjUL1llWq8oi2LBrXlgXP8W6wui+NAjrRqrfu+M+H8sUHGMg+da9HRxTenK6XtFNwaOWeuFfKMsHe6GG1n4BLGDyAUxVwopghQk1VKQ/E8gRxnRsPzOjkPZlA9UTMdWbnJ4KUKScbxUg4HRFZ0DIJrU6DYDu4hS1efMpvQHMOG47qxlcPdOmBf0oJul1BSzSZJmFt3zXoF5LSsFvy2wKrrLvyudxNG5DlWARiiRUvuo1K9hZ66GE9yN65f2ysXYeVZ5Cbsrey0fRHQk1jg8j5EgNAIyO7cAQCYHfoG8wxSIVoKauos7wmaWUm4dGUkv3WmT3GRtIFoLWgNT5exE5JsbaDoWh5PwxozVQZnN8NvF1WYV67MoKahfZ1Ut/YbGOWX2NQnAQBP3WwwI4GxaQzAauhFFuZJrnpLKDNkdoGV+qOQbKMui2QVsRRP2zp4VrKViumlMTBe8hf5vG0VGTEaTmSjNAp2knfiLuz/Fni7Vpp5WC+iS/Dpp5/GO97xDty6dQvnz5/Ht3zLt+DjH/84zp8PSa//+B//Y2it8T3f8z0ricODDDLIIIMM8nzyFd+wPvCBDzzv96PRCO973/vwvve973/4XnkeNAIvfm1qRvuLCo0TNvSgMnROR9+tBG6V1hHGu9UE7fbeZytcfpp2w5NBs5g+U6G+HoLgah6uW3V7UNTia2q0TWtRJXEeAIBvsLlLFvCXzIA7gs+8PhNU51uXJrgwDW18463g17705A3oR8JvZoyjLdQBFO8zKsJ5nX4CmwvJuqeG7ywUh1VKVsM7aCkhwvQA29ZoG9HOCfrwDg1jLJZWydL1rBYCS041WtFa05IIqYYc41pJgF/AGyuMB+w3w/tuwUdOxAhlV6sMBiLzENpD/s0PAAAmjUHjxXrUULScFlQf51WYOwBQSpxP+ZhxL0AMjTIWrQSvZzuFSohShAWjnPT82oREQ9mYZK6UEAP27gCxvhRMtM4Y3oT1OlJIXN0ltFjpmMguEPt8MkFXBaizxOxgHQ5uhiD+DgsS5tbA+lWYsU+StWMc0gMjSSym1rvpV+OUcv5JsYSUUULGR6zj1AJLz1+/tkeyKMVx9+h4pjDYpHHUk4A/K4AeuQfnVn6nQXkhWONjxrAq26Ijy0MEwRigjf1KS6Uw0BKXEXCO7VDH1A9yRcJHyyoC8JsWnizmS76PrVYwYzJKcE7MOofG0TrzZPHRHiVj8FuboYe+8NwM1/eCR8cUE4xUsKw6Egw2VRWDr5JqAgNoBgINg0/OAXl8FvEoaHSMd9f0TBS5B5xUeBAP1hK2k7QYKYzr0DKJ2dHKLLSJzCAxhIbVckMvRE41+W0xAnylohuLIDDszRrMq/DSbhZ0o1gV69dIeR
EFj5J5UVf2w4S79EwF/eng/lPPMtfgiSfR7oUgOEsm4cj1L4Lkp6SZ9jJBMgDqNifBpyqcuRou8MAdwS15wW7gwY2QV3X3tbAhlY9rHD1HxgkBPDiA+wsMcyoypSJxahz1DMf9LN5HW9+yyqjrugj2iIAC36KT8gGCHIM9VsMnvd3KIsW/11068QQE0lQbr43+GQWdxuc5r3qmi5mc53vARsy3UcDR39gBAEwzKQejIm+ShUIprhK6Y2e1j8884oqqlILgvyIxrnGRssvJgtRYWLpuve3Pt3xBBYXpfNdTC3Hj8s71KSdx4e1BMrK6K29Q8Ni1C+EHF4sxGoI8ciXt08jPhbIx9fXgGp4vm0hCe5azse58RAeKcuf8KiITCLXV1BqYokQ/v2N+Go4Da9LNJxXpA+39CqJQrnfSXIlgBSGudi6SMcv9VnK4TrhvF89TkfnDMD9x8tLzKC7uhGOTsKlXrkLHdUHcaMuqg6fSp4ikbWx47wBAGUG91ahJWSQ0Zpn3kSGmkXwtZaOSaK0gbFSsji6gn65TsC7MZdVIlewWhpwjO2Sw+eynb+PJq2Fd2pgaLDlnhPpof3GE1sozsd7VyCMvw4hmLHGkVA9Oi+OaTQCy+7TCfmMtLN2EHYt8dQ6wwiAs9bichxeiZoK3RmWBUczDImhJaazCdb60DOS3gwwyyCCDnAo5/RaW8zHDWveWNxa0Qpr2LL/TkQ9QKmNq2+HcUdCI73wiaB35H92C/0xgl2yfDOa2Wy4iZFO0/Qq9XlAn5S1S9gZgVetULaCfCl/s3g5m+z1W4dIkBFEnt4O2t3isxaJdAz+gd0GJZWITOKiUHElVXeGU9HDo6IKIQIvW9lx4AktF79IQq2qB4+wBJnn29NOtHUvzaXzyx3oRO4V+Im5QK72oe4CAPG/KoCBlTdorgPumu3i/YLUWxQJZI5WfS4BulTFZStS8L5XXF6zsxy5WSYWNNxSrq3MWYOVfcRc1toGjC6eiBt06i46/6VG8acVhIUnWER4dXdbGoKBmeo01I4qyBEqS+PJ5YD3UKGjbigSl7eIALJiL5YxBc+MgtfPEcrb+OMdeDqDhA0+SY2tdBSQAinQBWXH5RRCNj0+7XuajwHE3skefphALRnoPFU1DF9u83n6dXFvaUsPHi2/cHZ5q+747gDNXwnkZc9VMBhiSZydmo42uyLB2WG9QaHHDkxBZebRSaVpANXl/X6aHYr9VyOtwvyKGJQw00ygKoiGW3qHmAtJ00n8WTkrdMF3i9u0Ojz0T1rk7zu+g0mGQl1W49tHhMjKblCPJ0cqwSSDPaMI3W7veO8KH9yqHsDVWws6jHObyHrSCYVcx7CLufEDHStgCQjJQsXioPIeGhrVJgdgXIIOFNcgggwwyyKmQU21haQ1kuYY1iYmAEL6o6ccVjQfaRtZ4gXZnzuLMQdBKNp4IcSv7mafQff46AMARiFGh9+PLZ1qOu/0SPlixJAr0CM7JjMnLfzRHuUno/SxYXbNli4bq2RK9pSV2gZQEz5METgGPeOsjBFsgxd46dOKHZiyrbXy02FIrSbTVNG4VLboEOJEG4EXWra4Oq2UqkBxPf6s9QKUPm/zB7lSjPiLcXmIeKtHSmaB74/+xi3xGOO0GodhZBkyDNdLVLXJqhQWTz41ZJNp73xYJOPeQ9z64JtyKnW1hBftLS6coOlgfbJG66oPW0SKR052KVQOEHUDrDFTUY+xEa0CTB/L2iFbaeIpCWOTZ+txkqBiT1IWwyBzAc6Bu0C1w7wRoY3ykH7H1OKRGX4Rx6wQIehzfJA5pkmMp0GIdMp/Gq1IOwHXG9ZXfSGjPe3gBAzC4lnJ8p/dIPRJA8BCMGNvZfEWI922fP4d6usH7EiSRjaFJ81BqAQcYtJEaRAA0ZQ9mYlRXeRc9FzYy/vdJHcJb2VXzPp5Ja6ocFdHqFjBH23WYsdDmEQekLF1Mt/Cca20HPP1cAERcu2MfR2SCP7gdWldmBT
Y558cluUR1jhkHd2szPO+NvO4Z6BlPc17HmGPVhXssG4sjJ6VX+I5YEy1ixViWNgYmE6AR+J2OAeiOsXOlFJxdTbX4UnKqN6yNcYY5rHhoUMmLD8BaKQtBMxoFOqlbxRyB0i5w5hZN10dC8LJ97jYqmrsEAWLp+o1KgAcnIaVU/B/iQDXJ9+kCLq6XfAkckLM/VsxFj9ISaps2Ka7Rbwz9NpBJOQiFWFfHix/Ia2ipEJyQn56Esoout+RT/hYShHSxiL9VSQA9ueB6UL1Gv/HKZjsBMOaxK5fYB5VDFZF0Lt7fcbFYPBwG++DuEhsqvFDaSqXdEYQYZZRN0Onwfcnen47Cy86uCb+xJtImyZzxALxsIhJoV4g5NzXdO76p0HEH1Y3QMeXQRMwseP7EFRF1KIueQwfXifsykj7F8byRhXmZTy6jzoJS5TMBbHiMTECMtgyGm/EtHC3Dwwlopa77qsYpaGEdtVeiVw6Eisihp2s6SuZ0LCXCge2AOMitDwAO6UMgvJPd2ruRUimJpLl7Mo8y5fu/+Zm6h09Smubo2zq5FMZk60rYsPLNEZaCeTWs7G08yiz0ZUZwTqYLjLhhCUgnQ5HszJKT1KJtqRBG2iEf3IwANJkuvPaxNtqID5l3gMr79z5cBLi1oHuvEwJrh1rc+ayzVmQay0VYmRb7SxwU4Tf7B+F6ZzY2UJKMd7Oku1E57JP8dmMc3ImmqOHa1bUFqkAl2wO/WzYtGoIpOgGjKBdZNKy4tjOFQgvjB6+oARU3c85fr0O+olmfBV9cBpfgIIMMMsggp0JOt4U1KuB0DZAXTkwBo3voekuX4Ag+unAkS3tzr8bGPtkNHg2abH3UxFyqJtEo16ukKiTsDCeZKGva5LqkgASB5cupLXolrq+mqyLIQjTnXYPIkmDYBXnRV0KVNijX9doNFR/rj4MkUklLgKyX/rA4QdPxJ19HjqVuILEeZPJtAtgSV+A5WlB/5Pr2R/eUQns5/OPqXwuao+s8LNXuRiDF2keizVrraCVJv4yKGLtPQBW2h35LRn7efy+uZetdrEJbNUGPnzcWwng4csHVpN1GhFNbkjF2TQ9nFk1bKRtZIXqry8f5VlnCh3dH0I/TKqfGrrs2dk65EZLRcncJLQs3EieAZYueiUPcPIlbVwAspddYMiA+FbeXciicwPO9HDoGYXcJiKP3BfSSugSPWeeJnOQdCmwVMoMI6c96VgZ5J6vkb+m/0UTh8qsDQ8TWxW0AQDspkSEBswAwIw3DpLy8EBaSMoIWGtJfaG8iu4gxhGB5hU7mh+0ZL2Q8NcEGWuW9K028JdpHYMdYC/uNToBQdGdroCIcXFIFshxoiOhYNjWO6mB51S0tOqVjdeycJJcjpTCbhHVwcyMcKwpg2crco8tSZVjQDVuzlMii8jjk83UssaKhYKQwo7yvGZIFgpZ9Quro+aI560Ie4gBrH2SQQQYZ5M+bnG4La1yiy220sGIipAvF9wCgjRnZGXwd/nYMfE4OKxgm9dpb5JmzfaJfRe2qQii9ACBqCUEvWPP7IglSv0ClweNkXr6OuoRAQFv4OFh3kVNu8567I6tbeyuQD6p2CUMNXDLt0S2jVtNFze144mUKukih7Ovl609KEE0fNw2Ci0aUBufFUjvDY1sauHgHtXcyd7hcRd440RIXrsPy74Vf7ZPJZOxHKAqhiBC4bCgdEy7kY1xBjIbxKMR10v6wFhGsoDk/utaCSPk4n5zzMcDeErwzr9qYYKxp1mjXoOO8FJ42aw1s02uXAOCMim0VTdMrFZlapLTLM9sN7pLzaFYXxsOxoJ5Qd+R+FxtFsLAWHIibDtiRkAL73CXAiWjEo09sFrCB971VljoP1tM3LPo5o9RqcrBIygn5xSRlTEk9DooxEcvYpNoFugP+fcS2oI+pbXHctl45wu5LA41/tkmwQqFRsCdKpgrkZYaccZ6CCdpZNoolMwzH0lsPKDLXsO+1UXE8rWU82vsANA
Ai84z3Gl38nvPcI84dYUaZbJS4cjbAkLbIlF7mDjO6Vqas91GWOTpafoumRSPxM8VYbsJ9PykC6GJcaOyXYa3bGDF+lynMJbYtTDdOYcFrC8HAogVmMjiZQPvz/jmZilEURfw+5tbAwjlJHAY/PZTK4vO/EDnVG9aoKDHJOji+3EJaaVtEYtqGGejednBMViq4Ck0PGuirYfI1zDWYA6j9ahC06sc9ER9fynUv4PPJOgjB4Xjwu0UPNJAzC3hcHjHT/Q3fEJ7z4l2x5lL2xBPhOa4/C0eCSsuyGlrZWJMpVtjFyUwB6yhBd8J5J4IuEkmn33ptI0AF9yyACSf19jmPMy/j+X/Ie0x8rFklKR+H//M5PLG5AwCoDwMUapysjm3NfJGsQ1Nws3AOji4UQ7dI6Uto5qw0gktpEf2rQqVkLVA1BDNIDolXEdDDKYPFEsg06yaR3qfUDl56TlZRZ3qSUW5ceWYirUSkDHMAuPgUvMaT5+a4mzlXmCfqBN9+YWLAaAvjs4E5pX025BFWPTlQHBuDfn6nx8TNHXOcDJBJXl+SCrWu2LRI8vQSd7MsMCe9GypRXp7P1VMDKMTtT74u/VIg26NiuUcX80HP0LFzX+iPnQd2gbNhYW6mwW2qjYusN3kZ5sR4XCIrhBGFVZ7NGF3cuIk67mzMthRUn9aIteXaTipZhyIs8pxAQCR2srHFRVvH2mMdmVruvyvDg+dDWy5NwgM51cGxrtbuVmhflmeYLUO/3JxV6EqOgJaNq/f7ZmyrRQ1LN7PixjsZAQeCwuQiats2Ioolj2zZ9eOdKcmzKlAI/RIrcY/KCQzbEFGbrot0ZGkNGKVNX6blBcjgEhxkkEEGGeRUyOm2sLISDdpYmK/lZ60cWqowVRssi6acIKemu8Mg+GjewD8TAqdp3pGALUQD9X7VIgJO1hj9FzmeyrqG4JL8FflthtQFGf7YHeU4/9JghrhXvB4AYM5eijBvrQWSuwF3EODPOAgattcOXtPColaV8sGlsq45p66e1CV4IsBizW3q04fisRI+ljXYCjFwXHo1MD6U+8qJGZYsF3PwN18CALj5xrNojsIzdbTObKuiK03yo+q2g6bGW3UOjiDsLA8WSmYtFCHJkqZnOxzLh/E+QHkBYCI1GbyHiyUTeheHuKNbYbewdXy5WlpOzmpIVU1hCnBeIRNSUDFhtItEqDm/u35mBkzYfrK4aGRwBNM4pkaUZ8/Bn7sz/Pa5m+xzRCBRWr05hbgDwcJKi3TKd5FExfXfrTNKdEAs9Kn8cQtcrZZmDN8l/3w+i70GsC35iLvhJs39OextMixwHmUHQDenNX1HcJurizvoCEgRq3uSm0jBIB6Kspwgp4WVEepuVNlbKJm4bRt0VgiOhWRaxUrdAr5wzkUwgrjMrEdf6JFz1Wc+LgAFSXBfdnEULasp+Q9bb3CG/CPbW+HZRqXBAX10B7XB5lTyeggQyXsS5Zqktca2kU9UkOeFMVDkp+waydNsYwpBJ+7L8DDhmZivlec5JqMRryOuVIVYoZTzs3M2vqfaST0dtYrGeQEyWFiDDDLIIIOcCjnVFlaWBc61SRseo2MZa+d7nr2Y5Nm1sQTHDkvaT45atM8GbTW1KERrFI2x8CqyS6QW1gtVDJ4P7p3Cy/MkXrVkAPYS424X7rsP9pVfF87b3Q0/OH8RWhzM5xn0X1Z9DYmKfvfFEqD2A8ZuXNKqNK62zs+Was491P54bMrjOPuBSx40jaFsTMK/Lr0ytHlnG9i/Rui0+Oz3OzT/z2BR+m8NFhZu74lSG4PERddg0zMaTQv7aBnTQqG0QpYTak7LyVkgL8K4C29g2/q+YF0u2f515I0bi7rpXEzwFS0z0y00Az0Nk7WNK6FFAxcfvddwUqrc9hZWLHuRS2n2Fpqv5ogWwIE6ADYDL6Y5EEusTwhXwo+pcyhCtf2IjAJLgLgEcOagVX1yr0iaEBzTN+wq5yOAldL2qYWV0lmuW0xaqc
gI7pPz1iW13uS8hfcxUbXe4RifL2E4j9wGrYI9hz0WPvUbIZ5SbWzgEh90k/G+zBRQfB+KLMydMh/HAp7CUq5UEbEDGf9wtmcijxaWziPkX1IebBKzkfFV0MfAGVAWYEx0azO05a7NHNNcgAy0lkyGku0fk6VjMspxgwPQuAIZf4NCspKBluvHnCkYZdvAS8kJxuW06uDZxpoZ0l3nYnqHjLzRpn/vlcT7MmxIeRQpf6Nc75GKr42L8XMtPJpKB6IDvHA51RtWS0KiRnJMuCYrDyzZCzOav6N2gYkOE2Kb2eGjWxaLm9zYeE2PpERIfKF9XHDlULpgpyi/dVDFOnn+ukmbQWPBK+wkZ17ms2zfc28476H/Cfa+B8O1GeT0s2X04XgBChRjGNL6QCrilltoGeTXMTs/5tWfiN4SGib4vh/S56iwekyhDy7HjUvpyBoRc64UcP4Cn/d8+MG89lhwQ1DMDWn+51eh+J9eFf4+3OOvHbKK+SR0hVilsYyIUFYC7iqAruDJ9i42uWHl9EW2doFSXHx0hVRdn8sm8ymDRs5+82W/6ORUJsZ0K7UTEytby2TwvoXPVsua5CaLbpFW3I/eQBNdpRNS1aklKICjcmRquF1S0j5DgmIPZJysbsJNygNM10E5DW7ixfIoIv1q8bh6oO49kLHpskAveV6hcKwKdbphyTTp0Lv4AguM9JeABlzcIOX9yrBK9xWeN1Ee+V3jAU90xiFdpPvZCJZMyUKhtqwcbhFVt5izjtS8Qk1ak4qURllt0Y6ZiymEw1AwWhgpepeVUGTFUife90wtsZqvQ8d1Rj6ddbEOm1zDoYgbWsv3tvUu+tILKj2T3AAZwWKKiF9vAKISJ0QBjqYmViTeKnNsjqm45XzXC4/KhpCHrvleWYdK8gjpnnROocxFwRfAUQ3bhE1fsayJylxEMgrQYjwqYTJxpfJ5O4OOg1eKnuw7eILBOqpAOsvgl6qvOfcCZHAJDjLIIIMMcirkVFtY8AalLjBiaQUw52pqFSpq3Yqaqm88tgi1PMsE9cl+t+LGAIK1wVSPldIKoqHKeRq9RpkCFVI4OBCUJ2HCOMkq6+BiFn9LGPfZUiG/n7kjr/u28Byv+nro8+fDb/eYgHLtZnQFSZ6VybKQAg/ASS6XzqCEHDXnZ9309nryTNGSTLTh1O0jnyeVheiJZAmIgEMp3jB+d+4ScD4YitFKW0KhuS+0q3nTa0IzX3YX/CK4MaTvldYRZKCFY7F1mB+FXs/o/lAayMhaYJyJQfScJMmjfAuTgtc21Ii9iykjKqr9qmdTECaRLId4jka5WHmI5SAq29vqwpUomjGMiS68gt8ZkyEXzrnILuGhOWk6Bqg70+LaVpjfu9RyM9j4G7GWTNtCZm43CWiEEY7AGRM59grgGPmxRT9O4kLc9j2PZgq0iL9JrGqZH+k8Eph86Xqr0grziE9ZTHqJaR6JxSb3PpoSjGI7eHG1ShXOoo1elrlw7C2XaMjz17gx799Aa7EegmTaIOf1JAcSKoNb0/6d9bDyziXfec5LgbVba6ElTUG4PVUPaxd3oYaBpgtvTNh6nue92StlOrRB00j/sXlGYVTQ2i9zjPlu25EQZQI1wTgNJ7BqFtifhVmwT77CNmEucYJC6mzvvowuco9MqpmL98AYlHRbi8Ow9QqW4A1JJdHaRM5AK6VYOhvIuofyIoMMMsggg/x5k1NuYXmUZoKSRdg8tVHlqqgSn6Hf+t5O4XWLsNu/ZBk0z81uBr8rWhCD5oWJMPmikQJ4NvqtXXTUe3TRFxwOLdrA7A70fvg0sBxg3gQXSJAWgIEkkRJuf8c21KtfHX7/8vBprlyJCYugpeAPDqAkW1qcxc7DS+lufjYO6OiXt0Z4xgzUWspyCiSRdrc4bmGlDPTrvw+/Dc+W+z4+t0vVaHq/QlMwzjBlX3zdHbh9OSS7HhGae2FZw4wYgxMmAOX7chDUUJ0HlizFkEksM/MoyFJdZAVGfG
aB0yN3WLJM+KiQJMpFLIfeJzga0M2PTuJCeRZLJwhvXOFdDxGnWe6tisG8yBWoDVQsDcK2ahM1cRNLr3sUrcwP3ss7fGEzJEufJVCgaxuMeGdJ2XBtDVOKJSFFWxCtlbQIZ+oFCPfoT5+J8eKPxys71VtGK5yT/DtYRDIH+Oww2OfZkiw8SpDukQkluY4MV5q8frTVF02M5UdogptcIxPwy1EYsOV8ibpjKgcBR8rkMXYZ+SqVis8k4BylTGKF0vrqfCQlSJn3paSPWCid7ZBnMn95D6NhpYwHLaeizLG5GcZrXCTwcfIZRvCKd5hLLI6pE867GEvKtMIG2Swc36vDeoGZFF+sGfes9rF/GOa+5J+3FmAIPNKy+rZDx4FyUh1AKSgCMQznXZYb5AJmkonkVRxEJzFWE1gtgFDwFACWVRMmQ/fCLaxTvmEZTDMNZcMGpEg5AgecY8+/1oZBfJ3bxn1V+HvU0XTd3kH5dauQAzM6B3Q01/f3AQD14iYmDBjKhPRdB8fJUJIuIVs2KKpwbMFAcO2SfC6F6HPJ4rEegbi9wfyP+14J/UBwjdnLAdfVbW/G9mARfJr25vW4SWMjLE6uaWPehwTzO98HWIXzwCa1b49tqsmxNA/rJJdQGnxfB29oZPCKmz43gdmGw+w+BpK//pUAgIPJGIt9Vnwmn1A9mqPo2B8Q9gAbK/9GaiMH1JGQk+AGrdDJSqMUSgamR3TN5VBYjAJAQ3JItFpE0IPU9/FO9ZRMQk0AE8/L6H8awaPlZikVYjuLiP6SDS7UwJL+p2tT674cS+J96gP78saXeGwjgE++pTwHADisDqPbWssiWliATA3TLMyJw9zAxBpE4fwG/Qblks1MmiDgjCX68Y+bsuqBOO3axiV/ryML95TFDQFyyIl+lRQZ4Ea6hkqF6tlnDghgcSoB+QiSc3MCvxtacntJwtblEopouI1p+JyORqhlgdYCKHFohe2Bs1qhr4bbtL1rzVIZljmhoKHYS6LERNca+g3QwPQb1SS0eXMrx8XtkCe2M6VSqfryQhnn/qJqsT8P61vViAqRvnUaU7rfppsBZDTzLW4dkO2GLDCLZY0jDu6iL1y9EhIBAsglPguJca0OQCoAKMWNrbP4fELy7H0gIAcQXfht16GzpMkjSrGq6uDa9y98wxpcgoMMMsggg5wKOdUW1o7KcIffQE2X4DPU57Yd8BoW1HujDbkrl2dnkFFj7+hqMncWwCS4jjzdKHq0Ab+gvsHSnfrqU7DzUIVYES6NroGmmT1aUkvPG+Tk88qq0KZF5YBGrJte+lwTjxFdQqN7QhkE//LXwF4MbAWSL4JZBRzSgrx2K9xvfx9eINFRpfGxcm4k2FQ6QntrsUyUgY+68/F2pflY3QnH3Nr5aQ6XSI1GuGPRbZC14i+eg737XgCAIYln03W9JqaC9bhoM5Q5K0PTOrO2hVbCxcbM/M5HjbiRIpW1Qs7iitZZ5OzfXitUmI6CtT1hwHiUAx1V+giwcV0ErnROXJGqd/EZAT9kUNScRzQ5FsZFt46u6aqGi+wYTczH6WBdH1gHghZJpRaZ7c2MZ0Zh/O00FCLMDk0srSI+y6ws4WK+Edu5uQW1t8f2h2M1Eo67BGCTuoIBYIYEiCOf/ribOB371EqS69xIwBsnVRmW81O3ZPq9tOGG8Bx3Hg3NvK1peP8mmwW6klbBMrh6F8sFpnlo5XnOwc0xMBOPihGYeQUveUyRrNaj4njOlwRxVA20sEFIFVCtoAT2TndXZzuUUjRTWB+0wYg5S1vMITszKXB2k3ORqHSnu8ic0dow5ofzOQ64prTCM1ka1LS2uq7EyIS37RzBNs9UC1Rt8CQ4sj3PK48ZDbRIVpwpGN2HKICQJyjuPFmCOoSUC6DnJgR6F6lwbDqXWFiccM67yNcqI2tdC+tytG6wsAYZZJBBBvlzJqfawnpFt4kHFlPUEktgwtw4H+ON86C13NEEbUMfjeCz4NvFKFhd2V
YOxeRKTMIxNRrDHdK5u0V+PgsoS7XEM9nOLqGp6Ug8wuQFNkoG+3mrbddieRCshsOjGY7oC5e4loLC9oQW3wOvCMcu3QFLrVEvQ1uavc8jf45p/E8/Fpqy3INT4TljMjF0rGwXgR1ORUuhEW1I5fAqtCLCr5FkpvPfnUq06LXEz3C/XlJYc3g2oCMDevb6cKy5/w5sM8lVys6bzKClObBg/5Rdg8Yx1kX2CKV6YIqhRVPVTbyhJCm3rULDUjLettEiEoZuBYOSEOCCVtBkZLCgbz+X5NmuRcvf1JIQ6iw0YwUC8YXvQkVB9MwUoZ+l1AgtJy9xLAgZO6yz0SKS9HQFHwEYMRTnO8wIdT46yzZfM/CM1eqqj6eI0dWMGMscbUDngazREaTTord0hFXcKbdSLgQAFiokGQOJxeP7GFW0tJL0jTYBZdyQ66CXdesc6OeWRn+daL37PtXkllgClYXaDkfvPBtiQGdGE/htwqnpjPiTxw4xKRy/D8+5sa1i8rdYmV1ToyNjjjPC/QgsyIZ+MKPFNq8xspIcLHFBFT0EYvk7Z2MsVJt+zk5G4R5b5C3cznJMx2TdyHtwkRduQkJelm6JOWM/seDjxKDhu76YK3jG3gsydYzVGAWJFg/I4rKoFJaSqiEWtvaQajw9b2TqqaFXBgq5Eq+NvPk6Qvlrya1QGnkmXpG+1kPKNCNinYvFdl+InOoN6zXuHtx1tEQ3CruDRXCVXWl3cfn/z96fBtu2ZGeh2JeZs1ndbk9/z+2rbjWi1HeoIZAeskHY79EZIxvbBDgQz1g4MI7AJgKCCBmbMPBsED8gwOEAHOAXjzACwTPiCYEQTVFSlVQqVdWt5lbd/p5+t6uZTTb+kd/ImXvtVVenHPKPLVZG3Lv2mWuu2eTMmTnGN77xjT4+tHBMRlW7M4h8cgC5vWswu3GwW7rTYbIPLaoLnPTcdIbAAHZh4ysYFhZ+wQG0oMKCLlE8E0kSZjcugLo2qCiRtPPoCOdvxcXmHn+zCMD+9QMAgL/xTDzO7gxKKpiexAya6u17wHGsc2QfxWvQ1qecK9VQEkgZWOad+CBsoj6JAVshFmiVWD0DmSKkvyVfpwkDU01glDwzRYaawbBQCfRTANgXOOb3R4jz+t4OJkbq9RBuaS36Mi7qDRdov2rR8uU31QAEFFTvGBO+CYWGo7RDQ6mKDgZLKwtMl6BPrYVJ6dNEXzHHZFQWcNWwKAFAsB6ePdJSPXbZdhjVQuIgpKlUglpF1kdjCS8B9lRl2CW4MQEqzqU8FMV7si5AceUQiTHV9lhxwvrlmycAgO96VWPkBL5iNe2gUs0URSOswBOYKZmD3TydX9QspJSNQya/lKlMCKwri47GQAoJmcmSxkUAHvFf55zY2rBZzkmMg5x1uE4AyJU1TmlUNG2Hm3wOz7D41c64QEmllHYVr/bzbxsccGzNRpGscnt/ilUXj35KA2npSywJ3UsKlLUFFoThjs8pkr1s4CWfj8arUQaCJiYxYrihJhgpeEoXqLjjjIuoLoCyjn/viZqKHqFlNWNDkWSlh1m+otG2M64w5nHuPWpxdByv5wVew7gw2KlpCHKB9H4MyzdUKq9XI6CTl5aq260HPMklhSVWqR0cTe3AcR76Fr0oyNAYnlYFyvRg2ZduUAsahoyCChZqS7rYtm3btm3btt9o7Up7WNf668DREqWKXstHb0Y4buQNzGn0iNycHogOMKPoyejd6Mlg/1bSYFO0+pUZQYt22n7cXx2dQY9JG39Cy75z6OnVhIYW+cinIG7xTAyMh4P9pEZRXD/GdBy7/PoXvwwAqFZLFAf78XpY8VS5Huo4Ej7C/ehN2SfvQZ2dxO9pPaJX8IHXr6JP5AMS3V4qhbreJ3q2CFpGZQLm8IjnlOXFILOgJfdJrJs5gCzDh9uGMh9iQR/A4Nb/IO750ne+AADYUTuoaZ31ksdU9Ohs3O+0ZfmFkFnxQSzPAEPrUnQBtbHJU3RC3dYBji
5C28/ROXrAAiwph4Lez5iQ4Hg8QU/Ud7WUUhEBnhBUIKznfQfF/CYRq9VBwQaBLQXWMcltKCuOLd9BEyI13M8Hl/rXS8aTDpmMCnXyrE3B7c9di/fzfUWVdAMLwqfO+6H4JzXs6vEUHTXlDMdJH0LyhIUaH8IAVefQoJCoB3AHsCmYPuyfQ33niQgTW4/L5UxC+t9FyvygmMKmAD2O/T97KcLmbvE27uzE9/OZm/Fzd2RQSD4UPyd7j2GpajEjXHtQOAgG1hLuWnU9OvF0rQgiK8ypmHFOPUvvAqrq4ntT6IyQktPaJYVFoG+jB0TEiFKEghICEbUMG9/CMD1G0jP6EFIVaNHs253WmM7i2Do7X+D8PI6Llu8Qgk/JKxWRiaIokmhzymMrNEwlIrvsvz4ktQpJC/CqSBW6EzVdG3gKcIeOCEZRpffUUUnGdhYNYclA9MO5HsGH9Kyepm09rG3btm3btm27Eu1Ke1ihm0GvFPy70WocSeIoKtgjBvZ9tEBCXQKziGFjl97P/nWA2G6gRa48kuZVmEZLWo8n8OMY64Iiot8FuJaWAhUxqvE00cvDtfhbdfc5eFLP1d4JDM3HfWLj9b23U2E+xSJr4ewIOKYCw4PoYanTJwA1wMJKVMkNQhjiMnEj4CX+Ibpm3qfvfRBrHqmkgIMkTKqUSCnKCSaEpLEm5nQPDHRvCSwrlZUXiX88+5/t4tv/dFRcvzGL5JBaVdAkIwjVXquApo/9+liSv51BKf0Wt0Brl8o8GCOZ9joVVBST3BhAUTeyc3O0fexLy7GgVEgxhwnVAabjMRpqq7VMV9AFYBgoThRxPQSPRcmg1CaqbgNYJRF8k8xBzRhGVZYoUkmS+F3wbUpOTZ5iCBCbXZRVVt4DpPc/mLAo6e4YFd2kXko2OIcgHhtVXkxdoz8jAUNLnMxdKMgJnrFbo5dbNcQk88Kc679dV3CXbsiJOhdJ43F/t7bNheF5586+f+YOAOCDL38vAODhya/iA8/Gq3juVowb71QBykpxxXjkm4cBj89ILqK3WmiXysWcMQ4NH9K7JHp/q9Zi1cTxKJ/wJt3bqODVa5XKvHsvlO3Bw5LyIlopdHzGncRJtUfD9ypw3BVqBRcYZyIt/Wy1wIqVZceMtx/sjHB4Lcbv733lFGfn8cqWnCe8ahMxrJKE37pCwXvvNNNtlBPOUEpUVwGpBI61kjzfo+c9dbwWFRQ0x5m2QuxymOAikgDfwDG23pNi31sL2/ukqPE07UovWFgWMG0L9YjisYfC0CnSAoQJg42TXRT7hAJJiMBkd6AkiRvt7FBPSspzmBqKUj9etlkN20p+VTxGrYukcaIkqWJnBj0hZbAcQUtexHEkiMyqAD0irMfcEfvQAad8Qc4jNKjOTxFI3giNSFHVUEbcf052Dok51BLj8qGDJ3bkOIAD3CA3w6lk/yWD/nr8/kMs3fDgNeDkHm+Fo+U2gAccxCKSGrxHTxZWxxfi+/74h/Hi9bgYS76TskWC1wqBEEOJSRlf0HEtdKUCNVcJgVRc6BKMoRmErkuFjoF4kb1xzqeAfdudo7WRuGJ9fIZGqQTdVTzHuKpSPSSZfJRWKAgZlqWUUFAoBI4k5GaUwogrUKkloG2BIBNa4LmKQaYnK0STYs4cOz4M5BeZ8J3WKEjyUISk3n2uxN13T2I/SFXjrk85NYYLeVAmVVZOM5N3KZem4cSkkalLZOeXCVoUKkK2X55/JT/1uEjQkP3WpyWPi+opWNsnwWwBUB/+CADgw3ejOsr1a8Dzd74KALi5G42O2UgjeBKhCGM9e3uCL78Z4TxHw6UPO4mAYynbpEOACsL0IzO0a9G2EUrtSWRx3oDDHJZGUQVAdIudPHPvE4VJsZ91WWBFuPGEem6nqxWaUbyGgsLJY7tEZeKYLSUPq1uBikrQXLAOd8a4yQXr3a8c42ger/FsRYKIWQDsj8II3TTTYTLy/g1CzQKV970bbGAOAGe6VAKlF5UMq1
DwOBOWQumtwzENAU0i08otcLZc8jdc1PuArnNJ9uxp2hYS3LZt27Zt27Yr0a60h6WWPfrVctAhO+LtlEDBfIfAgLeZ3EAYx+BsYDkS5TQUTQVDy8ct26RgIcHrYEoY6rM50HMKJpWcSGndpQJIqgjj/fg5HUFNSJm3DuqAuV2H0cvzykLxWiXqb5YL+HO69YvzYRsDsaLIGoJJKhohUUstWnpTknXvg0+5FLlAmxFvkTzem98/xfU/Fqv7mueeAwAsX38P3TvRy1uQUPKwmeO/+/vx7wdvnMRzNQHnkq9BbsP5/RbTbyeER8/JYBDGlMqjc28lFp08LOUKFGrIS4onGXT+RDwAqoemlynam60DGENG0y2xXMUcpJZVV8fVOBUYlGKXRaExIjwsBemgAc3guJA86qKGTviJUIUVFD0rKWFSFoPVKoilKQwKgVolZSUDxoLk+HmfSrR4gROLElpEfJmj95mXAm59ghTmUfQYXNelvC5N2Gm+OE95eknEWeuU6pDr/iWvRnRMA4Zr4bebvKqAAcrzuCieK229SnXuYUnT2baU6ROA3Y9Fbc1wK0KD1/s5buzFd2PGk0xGBp7e8bSJZ7t2o8InPhk9jveoV3nz9gKBXo+QJAIKeJIC+i4+11WzTFCgJdTYOoNWQgaCK2fskZSPFVy6gaTtqXTSJHx4Et/1aT1GKWVD6LrtNBb7JQtQUrB32VsslvGAMwpEX5tWuHkQEYzRyODRWTzmo7M43kO9gjyt9L4EB8/rkb7qg4ERwedEGkJCCBwRjLb0qHm8JqWcFKhKEWiOh1h1Do6KQOWY6hxdA8vQicDm3gFtC7ith7Vt27Zt27Ztv9Halfaw9HIB1SxTXAAtPZl6BFSkoY/3AADFziEcyzIo6sypdoWQtAG57K9aQEp2JLJBMyQB0lK1XiVlgppJckU5RtghUYMqGW73IGnOOcwBwbMlkdOOUxKrosK0Xy2gmaiIJa2l1QqayY5eal4UdSJTCFW077qkzSXB/GgoicksvacGyrYkeb7eYPzuAwDA4YvRwxq/fBvL5yPhZP7grfj57hL73/RhAEDz8HMAgDPdpuTDEU/16MEpitHLcRsd067xMAzsinI1tEcf3IVtRhVJgQHJU3RZ5jy9KgzemVoOgRXBxa3t0XQRO+/oTY+KSUp4lSRxYwyKQkrVxx9PZuNEyhFjuiwrTGZMn2Bc03UWwQuBJI6jUpmBPCBlKEyBUsqFSMHNUKUw6rCfhnEDqSSevwbo3UvxvIfPKLiduK1phC0RKTEA4BnL7GyPjtat5Dc3IQwKJpLcmREe8piSeGVSeNGGIYUhp6gncoxswOYY1qZ4VU608GsbCyjc+Ni3xn6QPh9NMKVKzXgUB56pilT4ckrl87s3apgixos/+YX4Ln3gmUNMVPTOBHho+oCG71fTiYrKAk1STIl32vYenQiceKklowfvUhKHQxhKEaW4pYZmiY0zCgfMG4eCKQfy/pwvApoDakOK+xEsnOWJiWCUkwKHe3G+mc1qvHcU417vPYzXcHgTCUkQ71hrn4hEim6+0gpeXjVNf1trBCYtCzW96TwIsqCizue0KqCMxLWErp4lGJdC8/dJ2cZyXmo6oGkibeBp25VesFzzGKo7G+o/idioLhEYJFeTONn63b30kgUGJ/XxY4SWpAYyV7RWCBwYuhYGIRAIsxVGKsAWKRCvpaJoEWCmcYFUXLhUXcFJDYbVEooCtrJIoSwSswitsP8aIOUsSPVQhcCJOyTG3xAolkM4b+FE1iXBfyGJc8r+AYNck7Cjmjca2HfITnx4xJ9qPGrii/7aW5Gx+NaDBuPbzCMJ8tK2qQqwKDd85StHsIQfXCfMIAdFuYeOk8Cq69Dw3qQvVYEEUyTR3cYO+R9W1CsUek5xSUHBD2lM1hawTl4U9pW1sAmGowGBMnFtQOjt7v4BCq60K+bcVQf7eOFmhKV2ObZWfolHT+I1PCEhRq0adLyIWuCikUHPax
GoGqpPE9tQcLhEyQVZEbL0poYihDvipOfHGg9YquX2Z1jfq64Te1WEc7UpsOBYkHh76AblAVKC8EAB4xzjQ4RwW6mSLPaACkn+KRfGFXJGgcvQjcbFBQqIzzUncsjnOkxoEDD7SCxTvaJyivF9um5FdpovQ8IyNcWNX769h1c+GkkX/+EX3gEAfOCFHXzsg/E9PWEZoOPzHj0VHc4pMbTs2sRgEwOo63q4Il61vF7K6CTKLOMupPrKA5GhLjTEjFkyd2m5dJjtEurlWrFYrmA5LoWYVOsWYxrhYhxBjzDhHHRtp8S778Yn8PiY7+asAGqK4/pBmFbGfKImaZcIJImw5xPREor4unEDfCzSYmOoxCKWUABajwWNb9Wnw6FncTTpo76N0Q3/dSxYW0hw27Zt27Zt265Eu9IeFmwXI3eEbRSDf6oo4GnBFlN6WKMxJE07nEcvwj1+D76NcIGiZQxvkh5cwdyrUGpgRYq4QFKmhKmY6yWilZMR/E1S5qfR/lPeAwv6+menCGfRbYfAkq5PuRICfQXnECTLnMyOKPdP6DAMdoZUOg1i4XmbrtEn8Vuf0tpDpttlaKkb2k3NEwX9RYqkvvJu/O3ODt4mBf+r9+N3p2cOtw+yyCliXpHQ+0X27NH9BU7OSRseMVjbd2jpylAgBKfLBr3oHwpUpkuIPeUS3Nmik+A3PS3be/TsD0Ozv3RBkFeEoBLkYGkmu8ykE1UFo3Xqmx3mxX30mbuY7pE2PD8BAOzuHuKF/ZjPd21MnUoUGJPg/9Yj0v1NgZ7lp8sx7dIwg9ZCj5a8HSQhVnmsRls4BuKN5FcpDU0r2dHLL5zHZ39T9LoOP/E4Hq8u4S3zEqmC0kHB0PNYStkKDC9/l+VKSc+k4ohhII0IUGDDYOl+rbIgueoF1v7OSRVrDl08Lo8pCO/+nVso7kbC1OnxZwEAs3CGciTkKL7/OsAF8WbiM3zu+gG+6ZU4bj/38ajF+YufeQvjMr6n57zRlb2WvO5uRY+87VLKhOW4bPsOPeFal8reZO9XUpTxaTwl0oVWSWFDbrK3HlrKmbBTC1PCtuLesIhp4dL4DmEQcRaR2Xo8SqVrzuekzDdlKrjYiySKqjFidfJ0igIwZZHOAwBO2cGbyR6YAD6FwLWmSISkwPtwnR9ytyhY2RVh8FIT+hEziHIx3F+rbT2sbdu2bdu2bbsS7Up7WMo5wGioimQEI9HIEqipCD6e8bsyleoAdfrUkwdQbbRMNQOG6HSKifmanlGlY4QQQ1BVFQU0VRLMiFbwjUPg1vPxNwWvpW2Bx/TiHh8hnEcPK4i6urcIQvjgNajeIgixQgIrXZci5j6xTFSy4iSoGkJIMSxR+g7eD3EythBCUkToxL51BfBFJh9+W8wW7g9XeO84WuynJ/G4q86gZ9B6yj4/dz2UknIrsT0+cnhAM7ljAqRdNKnE+Dnv7eGigxVP0gtBpU/3K7EnFxyWLPPSilmmNSQSUpVDzW8hSRTaRPo/hlIHXW9TcrB4qL236HkNO3sxmP+BOzewT6222WG8z7EZ49ZePN+u0IIVsGgYS+B5j3uNkuQdCW4XdUDFiHnLWKIaKVh6xzWD+IUaCjiKJW50oqCglzFoFd67zWuoOF50jYbJov0p6fxtB01L+5x9eq6AUnLm5bjpbCmnFD1CitUI+7hX2feZd5bHsy7oBbLlqhjym0209i5dVzzi3sc+hLfP3gAA3Hs3knzu7jcJhQi8EesAQ3UUKYa5P97DN9yOHvHhjViE9fU3j/DCTXo9JC1Y1SAwcTyw1Hvf9smbaqwUVHQpHSD3oAxTHIIVenvuYfEZFsWQdC5x96BQyLzF8aSrAl6mZkXdzVClMkY9Y3Y2GAQpXVIOogWLOePsTZFK4XQ2Ejs8ahQkiVVSkqgOCJ7XU7PywWIlDPeBnu8yT7iKx6iLcuhzxlhXWKX3yvZyEJ2mMkkHcoF0+/Xg5vu0K71gIfg4yBiYRlICqKFGZAzygc
E5YMHcpqMTAIA5fYKii3+HTmrcAJokBK8XPN4oRcQD4TtTl1CBJUdmHOjPfxjq4Eb8DUkO7nyF8PABz3eaSpYo+sfBdimiG7oU2UWgiobndYXepUEgC5bUVgIw1FQKIcESUh/Ke5exl2QAqYyIwcA+POwb8RwLBnDPcY5jSr64nnWdVIGWENgh2ZAPjldQBZVGOAMu5h6np1S/qLjQnM/R8p6OCLPOe4+xVAMWJic6eBIUhJDR2A4rKZ3BuyjKCkZmQKppFMZD8+Uu9QiK8q5yvM626R2R+mCd99AUsy1ZCfnmjR3c2mW5EM7MRo+xQ/hkXMUxNl05NMyN8iNhAVrsLmmAEAasVYO2jgZUKSQTa1ASoqwk4m0qFGQbpgldq0HCifs3nU5Ely98Q7zOFz5vYVjwqPOEvouQqjG37LkGwCoNH8KTISShWyE0ODWQC3K5JVm8cuZfmW1bZ/r5oBIzs872Cxd3Q48hB2zJVdF9+ACvvvUZAMDpE1bYKle41cUFuUrkgTIluMnyXhYlntmL0P7zz8f39XO/1OHdB/EODqWCsVnBUyR3wU5Y2oAljQ6BuKpiYK8m0daQsXz8sGCJkZgkqbRJC1uRFiygJJmiZrmUoDsA0RjuSYY4XwQsQ7ywipJmy17BciwEhbQYLhZkOy4LeNYK7Dm2vQooGUIZsQSTCRqd3MuEC9yZghfLQQg2frgZgQELUyCw/0X5JagmzVVSIqjofVqoZCEMAOqRgvs6pJm2kOC2bdu2bdu2XYl2tT0sRGgOjjp1JEv4okaoo7Ukrje6Huo0Wpz6KEJ04fwY6CNE55ijo1xAsEK2FdhpAk+1ChHRUoWGEtHaO9GrwgsfhJ1Ey0jTMtKPjxAe0ypczGPiAYAgJAPbQDHXQtOa802bziPQYHAhWdisB4gi4FLxsxBC0jMT6M26wcNKEGIICUKQPKwQAmxkAKN5L257Mmqx4n6OkdZaq6Q4Mb0e6cE4fZLIIFIGoQvAycN4/R2xstPTHgtqip020du0wSOwf8dKdPpssuyW1Hubt0t0hDtHVB4pTQVPS1ys1s6HRKPXuki5LyK6G4KC5XEESr22u4tdlokO4+g1lhOD2U78zXVP6j8cSkcveUQ/Y7fH3jvxnr/xLEIv/5OffAvFVwjDcDiZscbnvus2AOCfflOkxu+Px6kvA8dMMCrpJwrk5uHgGFRvyBXufQFFK/iz3x6v+e6/eStZy47alLPdHTxZMq+PrQiD3p+ke2gMXpR8BgyVaXPoZoWLzauBEl1iIGOkKilRWpmHGQ60bluvMAi1VhxHbzzT4PGX/n28Vno6Z3qFl25QfUQT6TBIuXRiiqtyB4e7cfzcuh3fzU/Nj3B0FMf0aI8KLBOFlaijiIpDobE3pfoElUQWqw4TajlKJWGoQVTYp4rDQyVdw4spTImClHhwTHYWCf6bMb9P6w6KbKAmiAboCqdLyd2LF3ra+CTUG5RPRRMXZFMc9xNcZ9HK6STuN3aA4jhSfslzAAX7uiGRaG4MhIIjxCSFTNSHyJUKCs4LfEkSildJDLpoGSaoLAjADNXWDbAzNUz7eTrmxdbD2rZt27Zt27Yr0a60hxWcgyoVghQkI8081GNoej8pU3LVwkkBxDOWuV+eAJaJsm4gXfiG1HMSAIJaAC2tbyV4v0cx4zle+EA8xnQHWoLpJ8z0fvddqNNI8gjNHLqR8vWke7sOXmJXTPhEa1MMKzBSGXqPTmQKaNkpeCjRLhPr1bshsVhIGs4lLcGkKcgjAINenfIOoGLClAb56tgh7FOvTIq7lT3GNLvMdVqoX23ArAI0kpxqgFNS4dVe9KDOmwJL7tASQ1eqS9aoFHVUqksqBDZR2HtYmu4jkmpKXQOMwZVMw6+VRkvTPiqzC3EhuzD20YR4/jN7NzCmpbhiSkRVKXj29V4Rn3+lOlRv3AcAzH/q1XjeLz9C92Z8nh/+eHzWby1c8h4k3nMIj+d/8j0AwB/+YnRl//nvfw
5+FGMskqA7tQqe3qwkb3rv4Olai5JJ6wbdy4fUYHzjZoODh6T5WyYTdyOc8+9O6NSZRzRQeIbYk/hjOsR9kX2ns3sayT2GwUY2GIgTst8m/UGPgUZvUqwLWPHBmxdjvzwsV3jEMjtScmZlFN6+FZ/JhKSLqXYomFis5MijEmMWX60KSaj3WCyIPqzo/YwnSSPykMUiK6VRlSJrES/wwZM+FTEUjyY4lxFOGD/O3jlppigTCUI+TVWjqqPnN6rj/DUaFamPSqIvPfawZDLuKb3lumow4UM0hcGYEjPHp/GaJ6ObeO5OVKyZTaXcR4mOMbD36kisevDuWYo/es47uioQllLZQb5E8qI0n4NWRfIWhfLeWYcgCcPintkhNinjbVQD169NYLuAoe7D+7crvWCpvgFG48FP5SQWdvagmH+VRtJqBXDhAKES9CsEJwsDM7M7JEFZ30uVWQUl5AeeQ1cV1C7PsUMmYu+hnxBifBJhR/3wHkC5/xh15AtFSCp0DopPOqla9B20sI1EEsYF9FIRVSbjkEnahGxxkjwsyTcKPi1i8ougBjDGJIkFjYr4T/2Et3Z3iqaIi42wi3TZoOQMo3dkoayT5FIlArClwjGFcw8/HEV1A+oktluzwmrAEloCsYQxvTbwKYFqUMEIXFwNf1uYcRLRlSquTum0IBkVg90AoPhiwQYE4nSiFPDszj72OIksmadSVAYuiOwWF/Kf/0Wc/9nXAABv//JJvDdtcMRZ7InIYeHiBA4AJ1B4wH7f+UJcQL7vH76Jf/+/jvJVKgjLtUjnTeUq+oC0PLBbVO9guaIdz+OY/vlvdPid/5w1lAhdurlNFYJl9mnDRcIEEF8VudYFB1YZ1MA2zfaXRSw/bM7+k2P2Wf6fiAEnGaPsmAIdzgGccdvou+/G43kFp0R8mobNucZXHsQzXtujEkOtMRLGrbS+R69ifyhCy0oDK6pZeKpbjM0M9ZQVicfxu3EBTJjnlBaQosIxO2dSiuCwS4axwKveR+Mx/pjGh1YJti5J2NnZ2cOMteKMCPdWNTStv4b3s+jKBIOvlvFZn897FJQ+GtfADoW3503c7/rOHl46jLX/Dg8PITe/7EUGLZ7j/Oz1JJitaShppdJYkKnDYggpiIxYQJXyv5Y0uLtVn1i/iu+NCUNupDz/slbYm43Qtx5Pu2BtIcFt27Zt27ZtuxLtantYwUIpgyBVfiUQuLsDjEmSoOWrlg1ASNBTDFX1fYLUFN3k0Lvk6VDOL2Zji1gBrf1xmA1F25bUB1ycAlTRCO9E3bKwPBoUNpyCI/Qo2oWh6wYeuKSCuw6e1+NS/kc6TFK/KPzgOaVKpwhDaYjsu7BGzogH4n4ilgqdqgW7s/i5V89wxnyznpT+YuSgKGdRjEXsdYRWRU9S4J2gNB6+Gz3OV1IZjxIVRelKnnfVten6pZSFcy7lkyXLXAFaAtgMXtfFKFluUu6j8wGmJq3ZaxSSJ8J8F6WaRAcWGGinMjgU6JNJXNbq5OmU/+LfAADO/xe/jCPCSY/oRvjgwIK+GUkCSVBULNS5CslqPeO2/tUW3/LfRVWRX/ifRvimMG3SY9Q6eu9zF/Cfn9wCAPw3+vM8R4k51TTmLj6Hhzc9vstHH0Uo1LANRuzXB+zLTqW6khc8IxkTch89wlB0Mru3fm045d6Syo8ZBq9KHtOIf1XwGws9nu6Tdv3hSGYalyP01PVMHpZzuP8wekzvXeNYqF0qZCn3NlIej+cR+uqoLluNgHOeWCqEj6oaOyzVcTiJVz8qPMi5gRVx5rJCzXJAJdNfTs46hFY8Sd6HdQPCwXFugIQGzYjO3JrNsEtCh1QwLoJPhUUNPZ6yAEwp5I34XWc79HwPq0pjNqXWZBP3253U2Gdo5JmDSL4oqwJzFo/0JO88ePwY95cn8frTu6sG75djVWPIKRTlD2uRlGs6wontqku5k0U55KIxXRYLhhZGhcGsrlLKxdO0K71gheBjNVJJ+GPlXk
wnCMw7CE0cmfp8Hll6ABQZetp6pFKsGTQoOK5rhWUX0Cfl8PhdBR9ZhgDU4zihY3kMnJ4AAEpZuHyXYijeISo+AikfK3QNtDCLhIHYWziRESI02NuBrSV1jqz30GtJwj6ElB8i0kze+8TM8shgikvsL4UgNaioWL1bzTAy8f4cIaag+6Q6v3uNi5gyWMpLzeN2zuPRO2e8jyT6kyA6L1BuGJIKlUhNoU8zfYIS1FCK3BiZpEYopFJwwss72MBJBQVqwriVMLR8SLGhpTCXFGClqrATCacW/kuMV/0vPxk/z4Ajvl9SV6iHT30odznGIFeUEmYDUp/Li6cDUPxUNHhe+K44nh5+xKAXYWXWY/rBxTfhYz/1ZQDA3n48y//1N63ghBXXxcl77IB/8R3xfn/Lv12m88vzlzwri+GxyyLa4CLEJ9uqtYWtVcA5txXZYib35NUgkpqVX0vHWRH/NV4l6+aYs+ITo1D9dsLH+1LLrsIB4bDWCQR6gjmR9vtH8cCHux47RiDqeAVnzWO8cxzjhg3H02xP43zOeBaxTV0YXNuP4/zGWGpRWeyI3cuHabRGTZHtPlN3l/Cz9FyAS++aJOgXRmMssSvGq66NSuwmgyv+drk6gWFuVknZt3ERUFGaKVU5CBaqiE90NgP2DuP13z+jULMK2BvHbddm8g6MMWVOZGMjTHj9+gHuPzgBgFTBQRkN5kyjErg+DONIMZZsHeD48i4Zn7fOp3dXDC9VqFQJWy7faI96ZNL8+DRtCwlu27Zt27Zt25VoV9zDkuAhg59kd6lyBC81X7jqh/NzgNaqciLcahHWAvve9gj0bnpxdS1gIeoShN4M4I4p6yQlO0xIFYxTNVLnE7MG1iXRW6kurGw3MPiEEdg6BOJ/otDUWaCX/Cs3eFBOFCzkHCFcgv9C8MnbGnJghv3EI3JQKVcGtB7rcgRDeRhDMoUNXSIhjHfiBY5MgYV4tVYEXgPmpwzEUvXB2pBcDsmG17qA82L7C5vRDuXIIPCdSvkuSsgXRYkR2aGSU1X4gJIeWKGLFOguxPpGjVKqrlJc1BY1lpLTIh7s4jF2/vjPAQCOn8TfnkElD0GURJZKY599ueA1lErjGr+XnKUFBo8jp8CsOH4++F9Hgsrbf2YGRw/37uPYb9/19z4Ndy8y5Z79zm8EAPyXP/UJ/K3vp7XPHJ62Dnj8kXhPD9+Lx93/cotyzZv2GatPaAq51yVGbx8Gr1GciGX2G8lYVBgULHJvSz5D9r2ovC7gccozPqbtvP+Dz+Hge1+IuxFDqkd1ErPxUg289fAkoZyv4uei0Viw+nAfSDjxCzymHBroiddTi7OTeDcCqepCYWcWT7I/jmO6NApTQloidBsQkic0JzQ8qhQWavCsgAgTC8wteY6qAGoqREz4rkynGqURxCH29JOTxzAVa7eRuToyGiOpuye6WChQcu4bjw1mM5aVMSLU7ZN6xqSS6uJV8mp3p/Farx9eRzmJwsCjsRxvjEZHFzZhI3rIN5NPF0IKk7Sc2/p+yKFM8wmGEiw1PcVeBTg4uK8DEtx6WNu2bdu2bdt2JdqV9rDgLaA0PK0W0Q9EVSeatJQNCatzBOaiJPq4c8nDEjp16B2cHQKKABBrrRGLtUJ5V/DUJjSCVdd6oIi3UgoCWQTaQVHdIYi96t3gEQkl2oYkGtkJ0x2A6Ehqya/yLhVhSzR1pbK8GvHE/BCvSiK5KkVTB69LJV2w/nyIM5T8cUnSQu+7dByxSsu6SXGvkMU1lowH9mf0WuAGj1Qs0OCRPCt+qkyRTkQ/lR4EO3UQenuJQo+G+wRQaIeK1qoxCoYelmHMzEOhYFXhvUn0TA7NCDPmcz30J/F8/+xzWPxCjMGdsIc67ZNSgBSO9MpjX6xo3nsbPHb4G/EsPAZ1CbEpK8SKugDgPh/HxK0vP8Z7z8ZfT96OVm745ddxSit/8sU4zl953OCP/bN4n//n/1kkKEwwwkjHMfb2b4m0cPP4TR
THRCMwtDyeFZtCMo2F/o6hInGymsOgNTjlw+4x0JaNjzE8OSIQnRL5W/KsegCNVFH+5thL3/A//kbUB5GQ0GXeimbMpJoQZqimWHZULnHxt/MFcFYztYIBmK4PWPYsaMicu8kMCERZ5vQKRrXC3iweZ0bR60Ir1PRmJNezg0YnpUYanqswKAQOyPQ7RWlGqgcZreFJV9+ZxhjVzrhDyYDqkk/i8fIJAvPJZjbu5wKgmdykIWPbp2rm49JgtivpHfEaFrZPxVJNLRlzIxSM7+4yJeXW4T7u3NqPX1NI8bQwKbYp2SCqAMpUAVs8zzZpA0ruW4BLDztlI4aQ4tPSVQYBZ4sz9O0GQtjXaFd7wepbwPlUadgwkddVVVJy1mTwYbnM5I6k1lSfMBrJbQreQokEUpZ/IE1qPQFIuVkiDYSgk/svU4MyOh0geAswFyQwpVKrLCFG6GRhYBsJM7BHBifx5e1cgE4Qmrzc+XFkcfIJqhiSGWU4DvceVxpO6mQj9GcL1Adx+uloGJgwQcnyqDWht2p6jmLBiVwESH1ISvBP3oskFHttiYKdqEXROXTprXZklqhSJ2bWkPsRUnlvIzWEgkZIOkDc5k3apgud4IsEzRqdSqmPGAQfmxJG6mFRab//B1/BKfttzlNYD7S0hgohXwTghPeZExQkaVaaRmY3sAUAFX8749j6TT+zwKt/KAbLD+7H8duHgIZQ5buvfgkA8Mqdfdx47wQA8KHPRdjrnW+5CVD13TIv7t3f8Tye+Sdvxm2ssFtZldhaIh7sstEroyMgY9xl1y1Q4IRfdmGAvgoVQEHx9L4YNSi8S2tKoNqPO3z//+Y7AQA7z+9hJGKwVAQ/6luo5qKU1t5Y45TqvT3ZeqtO4ZRG01gxGdvO4Yo9AEDJxaweARUXhNNlXJx2a4/99L1ULg8wWmBwLhKVSzWoSj7sotIwosIuqu0K8EGuWRYYg1LysMgMPJgAjsbrkY3mzJtP7uOdR1FZ/vnbkRk6rkZYMqQhAtxFqVKliLqaYMaE/OkkvmtvPHiIZUeDRwQNSsCwNPAOoc9be1O8eCvWB1vcP+H+Pk1LklQMpVHwXSpKYfW2ySAvtMCiNhGOKqk2YFTKMytJnNEhYLlcoO/WBsb7tC0kuG3btm3btm1Xol1xD8sDwQG0VsKEQEVhANa+CixhgfY8BQdVSC5PcqOUfIaQ8g7yUGDG/I7/Dj4F5wPNSWMGSEsknILTiUIdnIUj2UJLkLZUyZtK1YO9uxScD9k1iNyJdRalSPr7wQtx4mF5gddUuvBBgNSnv+XTD7eXrrk7O4HGdd4T77OoYUR7hTbPtdu7ePCA4sKEKYK18ISM7r8VCQWT6RkULbExrVGPbsgZo7WsjBm8XsFmg0sWntYDJBjoWSUCilMpt8MYjSB1hlJ5lpDyb1Y0D21VJ6+xYlXo5WfsAAWy8zsMXq+QEAwGeO2AnxYDMUGaQkZn5+cICjWf7FRUMH6lQ03Lf++NeOSHfjiftK8+OsPLDN5/zy9F+v0/+c13oXgxleTCPHuIe7+V+YP/LuZ8VecBY3nGPF4un5TT8+U+pol+DdB5Q5mxKmRMzxSgJaWQX5cD1yYd+1YF3Po/fDMA4MZHYw5apSvsz0QgljlGdgI9Ea2feJSRAUb0ak4tL8IU6ClfpShg7bBET/m1qop9Wo88qnG81hOSL6Zliwmff1kM8HRgmZeepKBCBVQkYtS8+brUKCiIK+MzVhyWuWCA66W7WuZjmr0ZdljWZM4S3FqX+MJXvgIA6Hje23uHiZwg91jrGUY70QO7trOHuore1O1bEcb+7GsPcP8kljZ6pXkm9ltRJ0UVJVJgxuDaNG7bI+lCFybBifLelNAoRX+N91QUGTolHnuh0YlodxiQEIHsK8kjCxa9c6nsyNO0rYe1bdu2bdu2bVeiXWkPSzkLuBZgDCtVGYaHboVswVIQ3SqV8Uh2XwhJdy/pfvkh4TN5NLgYJAcYj8i8mngMwC
drXn7sBg/K9kO5zSQHkdkM6Xh+o5eXdhsEBIeYVEZ0ELJFuq7Mw0qxtcxnk/iWUmXaTzQM3XyBngnDPWkE2hj0Vii28XjXnzvAzmcjzd+FaD32oU8R+0fvRavv4PnTQUBYFAOqIeYUIEK2VSpX0sq9oUVNj6IQ2joqiPito5fcdgFn7DkdCjhkVYwR4yCaD2h+Fq/5jW4PL/N6wiejSslZ75PnJJ/AMC5Mtk2e04TX2oTBMxmEUTMdNX5WKqTKv0bYKucB3/tzpF3zLCVcEqRN49IFLOv4r7uP4vOqW5WK63VUPgjOAR/bBwA8nMQ7Mf/yCcanF6899jBje3J9yDwsuo+ja0BD6bcxSyNb5XEjho0wmwBNFJeAPYn3VNuQ6P3MAca/+z6D3/vRSBa5vsfnpQrslCKwStp4r1GwYz29EG/98O4Q4ehDjUJimxLDVAcoTYz9VPVxvJZxD02B2+UTSdBvgX7J65PiiUXyrFMSuA+J/FDRMxoXBoZxLUmOdS6LTYugc1Wn8icNPZBlKHFjPx7n1jJe+8svvoxPfuZXAQBv3I/jcxwKGClxpCiWO7uF29eiDuXtyRTLOj7QF1+MaMZrr93HW+++DgBYfCB6sPWkBEJEooQkMa409nbiQ5nSje5sl1RKZAAHpRJhQhCKohjuU8oeGY0hA11QHF2ld040qH0AvLPwX4eHdbUXLNsCbjVUGjaiOg6EOV8PpsPrrkuTcGLyOJ/YgbKQOH85V8bhcrA8zvey4/CpEvOOu1mXJm3f28SQEzqZt3pg6SVIcCB85IumMPgE8dNBD2W4hbHmA1TK/xgARX1J1sIjCMEisUL0MNBISWwetgj2JPaDY62vUEKNqF5Ol/+5V27hyS/FWWxxT/I3NBwpRKeP42RQz3tYDtyW+j77sxqlIaeOorYlygQn1pL3VGiUJRUAzFCPR1ZwgR8WdoBFVeMRZvy9yFooi57w5rKLL/nnXvtllB+OsMnBP/4S7zfApkrTwlgcelBln0noNhs7dm0/n5HwZBFTCqgTUYMwiwK+4ediH37lI4SuMHBGhsca8BZJFIfs55dePcKjb3+OX5PI4jsUCxIxnon9fPw7DWY/FwP7s/vD8dZV2OdAYjveeTneyfWPBJyQeXf8qpBlgJ0PUMy4aJPyw+GSpIwdYHEnTrj/7E68lk8vVvjdPr6n18n+8zqgYl0wRVyxLjUqLndLG5/XpCzRk5Rh+V6sUKEU1RN5IuU1jFi3bFWRGTptUEw4Z9yL53pwvIDvJX8pjnOvgJ7zQi+q6V2HYON9iEJDURhULA8vkrHKD9JMabwEYKgZFZ/XRCtcYy7VbDfu+R0vfBi/8Fysm/bmV6JKx9HuIe6oeF7LBe7mzrO4SyvicDzGrI5P6s6tuDBfu1bCH0dI8HQVF74DN4bQ+iQUMNYO11kHa/eQFYfbNk1rwmL1SieYsCoHyF3oGWU1qNtKVXRZ6HUwKGW5EQUNraDN4jIb533aFhLctm3btm3btivRrrSH5ZcLoFul1XzQkvMp/0qxlEjolhDfUzLQ4UIiYOTqEGHNw1r/O+7kMyq8eDc+5eGkHCfrB5fI98PB07UO/rDAk84pWCEIZOdPGeGpvEiuasH9vYMXCqvkjoVMwDRztHLVCwBQ3iNIeQHxCh+5KOALoC+j9th4WsGIsitLceyrFe6++CEAwGv3o+Zdp1WysHpCOaEx6FkSoZWCTL7F4WGkHkv+TB80tJT2oNJGYQpUUlZEAtDepaBw8kpdCdsTdBt52Dk9uil1HINCx6B34Di535/hq4/jcT72+RMAgNUmUXBzGVOV/S1NelIswB6D15Vb2LlXJgdZp5KXSqE/i1uf/UJ8Rc+VxfnamNBQOCcD5Bap2Ld/8R6Ovo3wD+u5WZRwUqZGqkJPgbd/gNWU78Vj3LgP7MxpLTM3pupCEgNuDlmOYidgcRp75Cu8ub2Jhp0wFUJr2OfiMX/5w/EcD3b3cXItjg/3KCp2HL72ZX
z6YRxbv6mPnsLB7AC6EBINSVJao+IT6EhHr00A06awDNEjWnYKoBdSCaxY9piMSGGvo9daThbQOkLUJX3Ke096NCwrNCU5y6kOrYvIgG3ib7vGwtO7lD6Ft8MzoQfSe5/qlgk8WTqFiWEZI247GBUpr0AV8T5eOjjEb//WbwAA/Pzr0UM6OLOJlDO+vQ8AuD2ZYMYyJVU9wojfv3Q75t8dPLuP60X0TP0qYrTe7SdlIMX8RRMCDiYUzp3FY6za5VB3L6TRCkMPthAhXj3MQUIuKYyBFMSSqU+FIqnTJDKHKaKCSQiIOjC/dtt6WNu2bdu2bdt2JdqV9rDccgksz1P8Jn12HdSSJUSYbOf6FRRX/UTtDqTFIzlECCFTmM5aijmkLSEl4yUpgF4lCnUQ/DpGFtPBxbsTRQwFP1DSpWijV6kCr1j2DlnWOI9XeAvPcyeNr+waxQPUKvPEkuZZ/v+BTi8EBU0Pqpw7+FOSLliGeKJmqMXSokdW4QQHt789fk8l6pUdqgY70Wqcl1DRmUrOZWcyAgstT98qeAayFWNTygyeovR95zpAAv9SVM4rKCaVtrXHdy2j+vftn/k0AOD08Ayf+o5oZc5JEOmOzvHqV+LFvvCE27y+9Nw1sOaXxiRR6V4hFrQYPCbxtNSG5Fmd/VaIGCoE9IxMu44EAJwl91jnGpD8bSvaig+XUCSmiPq/6TUcA/Y6xPgGGocFL7D+aPwufMMIE2pUzBiTKTtgfxap04fXoxc0N4/wxV/9IgDgF45iTOn6jse7j+P1vXsjwD5Hi53lLWo1ShnIZUULf6Lxz/7NVwEAv/87o1egdlYIHMteigpqBc2/RxxvU+OHcc6xqpxGRT3Rkn1QlT066gpWVWSFjGZzlLN43fUOS5Tcb9EsSMCYLdlFi5R469q4rW8cmo5CBeL19QodUy9S5WT4VNEZKTYZcMD3YLSMx7jWBIjPXJSxfytd4nufjV7y8ZQEixbwK8Z8TUzy3alHKXZWq2lSpDicxj442N/FSEcPbWTjc3f2DCjoWUn6ibbQzFOYUK/QaQ+zVpQyJu7zvRddUzXEwluZ27ROqMdQvLJEKnLJY5SFjpU2nMfTelhXesFCZxFWZ6ncRird0RfAgiI4Ag3aFspLboBAZW4QhZX+9VnOFVvOscu3yUKjnBu2Sh6G4FNwQz5UJsMksvtAVtOK23oXLixUca+h+aRqYQcChh5kSFPaRwYJpvywDPocyABy88N5jMgxtYA9ZimGmxFGsW4PWu0DGAgDbdFi73pcia7vxiqnx8szeEJQViStzndh9uIEIgoL3g9lVBxhLNcrGE7hsqBquAH3k3QQZ9GxtEbHl6NpW/QU4D3vOnzzz1EA+R99BgBwSwd84Bvjvv/oj8SX99XRCP27ETYxqQ6aH8pk8FMj76OL3wGDIoYPF9UigDjG9Nq2ApfHGxTw3A/9nvj3cx8GADz5xX+DL3/638W+Sfce0vHmq3iPE4MkGSawqSlLLPWS10wVh5GCS4LEhAmLApaCs34vwr9treGvfQAAUO8/CwBYnb8OdRAn/GMXP49WwHusoVbfAnY5kbZp0qsS4WhcchKe1njji/Fav3Ac+/7aoUryP47czD44VJI7KUSHuoTiWGlXlEjSASy6i3osFak1lqwuvUNSQjc7xM3rHMtHcSy++cYxzuZxYZ7uxms6xxxnVLNxHeH6tseKOXJgJeOTZZkIRDLJw3sYzgtTGg771kIfx/NOWjJkbQ00zH1ScXEJo9u4sRfv6ZtuxUV2/qUlGrIS7VfjIu9vX0d1I+ZIajNKC8FoFPtoR5egwhqogYvglhjMqng8rzqcLeN1Sd00PQI0maAC9WmloBJlkMa1HpjKonSiY+E6HptGpDap4rdm7mhZVCjLAsY7AA/xNG0LCW7btm3btm3blWhX2sPSUHDNHIUdqgUDgGodLEVmDa1vuKyUyAWiwkUP62lbhA6ZVyBwonLpQJLXpXy+baA5eCV5Jy
p5W+I8uMzLy5UuxLpM3mCsqJbvEam2a+6g2kSnHjat7a4ufBgPNKf0/CzhVXQo1jjWwRgECs7evBHp4W8+fhsrWqjiMfaLIllxBS03H0JSj5Aqlb3SqThgKMTiblGIgGYxeKriTa0IDS1Wc9TME/nBL1bo34okEE08zpcB+Fw84e/5P8Z7qv/AAp/6ML1johOtUTGfBoO3ZDf1ZRjub5nlw0nLocEELWbkl+R1Cd13PIL+0Mfi+QiH7k33cftm7Ff15D77xeF+Jo4MAKUNmL0Wv29Zsbct1ZC/RK9lufIJqW7Z97NeoR+zr/ncNEoszmOH6N3424nqcXMWvWkzplXtQyrZ4pxHq6is4ZK2A2zN8j5F9HSqeoyCUND/++eiAsfHblXYYZrKqj2Jv3QNVoKe2OidNf0UxzZa6gKm7EyBuiYyYOR9MKhFcJZ5mk01wZ39eP1ns3iOd56c4/NvRgq5GkVorik7nFEwu+N8onuLfknPzlEYeTWCDxH6VNQKROhRMSwwZhHZae8wefgk3vujeF47+VWYW3wAzM1SNw+g+B7cJuR6760W1YPY/8dn0fV5dfUYN2//IQBAPauginP2TfwsEHBMEo1UYy9MBSO1Wpi/6HWBTlJNqBBUjYqhorqMaa2H/NIgSVchebqa739Qfqi8nKX5aAmDKKk+XmE6rVAWT5+ItfWwtm3btm3btu1KtCvtYVkXKZeKAVHXMMi8amNSMZCKCQbnkmUcRPEiYAgK8pg6XPa2DAZPJ/dGJDHT09rQWcmOVCLaD0FJYIiwC3nA+yC8jxSPsv5ySYeAIf6UZP8BaCkvkuJWbtApDEJG0QPlP7u3RMXmj12+lR5K6YCS3dUwMOuwgnWk+dZUFlh5WB+t5Vt3XwQAzL78K1gqKeoWj9EvPPZIW25ILfcqJmQCwHIUP2d+FJXuMXgry+4cmtbySJK/lUoEjIYK/b0/h5/GDvyt796BP7nPfdkvSwU/IQHgJJ7jh/+fHuoPSPCeSZ0hpHIgfQqmX6a158NFvJWdbL+c3i4jwaTnlSmX8NtggHD7TtxyEOOBj//p/wuGg+IG1QhK46HO47b7vOZDANc/G+NBn/1I9BSMAtBISfi4n20XSUndMwbUVh2qBWM2ZA+43uOU8Y3Aop4qVKioEj6jR7acO4BUeG8B1UvwXtxkD6ykIGN8/rNyikLFhNZ//28i1f1L3z/DS2bKa43P7ay3AEk0JV+S077Agu/fSAp0Fi2KQooYDoouEylnQpd9XBnsXo+xoX1SxO1bDR59PnpR98YkjNzYxRkJG6es+tDMeywXkhQbvb0SE4yJAgSOz9IEHFBt58DHUTRbnWD2ZlRRKb8Y7+38k+eoBVkhH8aPNXqWCll+iVR7Y9CPSbHngHnw+Ffwy8/8KwDAb/6Dh5jRW22XJ/F43uH+uSQ+s9LCaB/ecFxQ+dIYDT2K7/MjktWm5QjnmlADU0188DE+BcDT8zdlgTHLDtklU4hcl2XSMwkcKgk7iOumjUFRFvCCfD1Fu9ILlvEBynqE8/ik9Ty6z6FxUJRkAisOw3bxTQJiWj4AeD8w5BLMNhw/Z4blf3PPS7hdgE1BSSFQBGpUAHHxWidvxDIEfAGyPBvhZGx2lrnIeqSZT2UYVGIMJggxZESM7BAbYFDJuQjEp4w1mBBqrVj9VPUeasLFhpOG9Q1cEV/C/Z04ye5fO8AxXx5ZsOyiRREiNFOR3dB6oOni81oIlDuaQPWEhnoqZ3QNHMklLVUJgi+wIvNqJc+8W2FGHaOd8V10Dz4ff88Jf6lCKkxlBW489/i2/5ZLdkfIEiG9IEnqN2TPZM2oyP/M4b9Nxo78XWY/qsSYqSZQ1yOc178XJ/KFMXh4FBeichav6mYZcFvyfuSAFrj71Xhzv0gR1z14BKpHePaVzgZ6JzlazqMkvF716Y7RM5exEwNIG9TMv9vZic+oXSxFgQhFABzfNdWLsouHYy6Tl5Iuk70E4Z0cxev5l//xDD
9ckJFn4v7LZYfzeZzsKjF2/AJT5n09e8AFq+xQIF5rqcSgATSvu6TBNSo0RlxYnh9HxZY7N0fY+wrDBxyL+MgUO9fiWH1E2aT7x0vMeU8ieHtzXCRiguH7P+uWOKCm2FTg0YcPUX72dQBA+EK8ltXpwI+TencWHq1Ii/E7bxxqGqczPo+D54DXf+r/E8/x4iG+7bs+xAPFYzdNjyP2W+dYd0wfImjmwwlrWhmUZTRKHi7jtsLUCGXsfyNzmVaQGUyqeE+qAhNW/O5qEqvCYpga+Vvrs7rCHHs2BFRVkWz8p2lbSHDbtm3btm3brkS70h4WQoBfttDHMVsepydxswVAmFBL/o/r4QUepPWng7/seWCA63T23eU8rAH9k0hkpKfTZSYBIFpeOa7HP7MDCT1evCoLXKK1O+SBevHYhtwzodN6pRJtVGvJMfMwhM2ShZI5iIOtHdKFiW5h0QEVz1E1oqDRo/PkvPoB3mtDLImwqyMl+tYzL+O9d94EALS0ePtlD2P3AQB1ySCxBzomZYmHZe0c50xJWPlo6e1pn6BDTxq8hcfpMn5/uqJeofP42H0K4i6OsGK5kCb3YHnHLT2KHQ2UMe6Pnv27G3SWDsD9cVEtJPbHZWdrk0qKwQD/pbGlBsjTcuP0mz4Ad/I2AKB8PRIBRr5Az2f4ziIe8eZOgWocr/9lOkQPPLC3kLSHuN9cu3THXgqH6qFCsBBLrEOqtt11Us5Dw0HykuIxJrpIXrx4WCcPlvDsX+N0Bnmz/zTQ0ysWvT1T6QTdFQy8/8tfPMZLd+kNTOO9LRcex+RnV1X0R3ZmY+ywmtCMmrDjykKXJIioeP2d14keb4OU9imxtxPHxzNlhAZtb7HzkGK7oiihPBZ34rVeIwS2Oi1RTXh9LC8yNjVg4vEm9EzHp2e4xjlguoh09fJL76D7PGFrQtFG+SQuTMcezkcRYABYEcFQHikFp4uHw54dQh6/8n//h3D4z+N25mGdPDrHihqSMLt8DjMo/WGekcLgeAiN6MmvmDNhUAyVhkXQGwGW0POYc8yoqjGlh3VGmr8PIeWyytzWuTZ5eVJtvbBFHPw5Q+nXaFsPa9u2bdu2bduuRPu6Payf//mfx1/+y38Zn/rUp3Dv3j385E/+JH737/7d6fsQAv78n//z+Nt/+2/j5OQE3/d934e/8Tf+Bl555ZW0z9HREf7En/gT+Kf/9J9Ca43f9/t+H/7aX/trmM1mX9e1KMSAsaIisT+NsaygSqCLq70UbYRzAJNXkwZgyOJLv8Z5clJG+sEA1Mb9VBaHSpaluuDKJApzdsJM2vB9r2Gw6BlQ1kiK0aLGrpRKJQ4kKz2EAC0JnEln8HIyNOCTO0CmKso2YFfF5zIV9RC0WElZljbuuOgaLFfR09013wsAuH7tBg74TFd9xMidc0DLpE4Wi1NdJ48GK02ViX6RXBnFwO2k0onVIBT1vm/QMFF2zkTYohzhux4xtvKrvwrmdKY4lEWK4WOX25pw2av18ANJIvvc5EWtJxG77O9U4zDzpkTVwoTBY6sZn7GP7qP4F387XvMD7jk/TeeVWMfJQmHGrwu6brcrg3uMseySBNGUAS1p4XJDptDQVAjp7ZC4LkX4JAXEOg3HkdKINzquUpHOnV2JjQxEoaI3KBhbs5KyEWyCJOZCQjAFJuPoHi1JnX/4LvDpL8b97tyJ93HaKXQNvV6WGbm1Px6ShKkcPhu7IdnYMcbZ97BM+vW831JVeL56CQAwvs5ePeyx++UY6xw9pJfpV9gZ78fjULN+4krMWZ6jcxQztBqeWo7KxX6ezheoVvECq3txfmo+s4JhuRXLp9mGYVymWDYCHISsMBCPFqKyzh9MHiWRF1wrz/H2P/9FAMDJt8Z761YewQrRgWomah+BSf8hRGq/1jWqIsbWdqbxO+8dChn8Ug8kk3RJMW4YaJLOBCXxWfqOzHPW9uj4kjv2UVh5rLoKXff0tPave8FaLBb45m/+ZvyRP/
JH8Ht/7++99P1f+kt/CT/xEz+Bv/t3/y5eeukl/Lk/9+fw23/7b8fnP/95jEZxcP/BP/gHce/ePfzMz/wM+r7HH/7Dfxg/+qM/in/wD/7B13UtPhigB8KT6M4qQoOqqABCS1KtNnifVoR8wQrrlMCs5TDgJbKEV9nqM3xklUa4XwYhYViUVBaJf5ocMJNdkSxORg0LlhF8Rxt4gVzSAPKJlCGDS8NtIA8EyDQrLw96iwkn/70X4jGW3mIl9SMKmcxWWBLyMawnNNYahzcj2+3RaVywvHOwlIMor1O2x9k0aVpCKg5tYleWLBjVqkxxQjP3atXDtbxPlto4NR2eO2XdoXsP0LKDpaZVnxFOpA/6MJBUhJDgswUmvbvZ33k9LLu2LWCYwElIRBeGemqyYGkAFV/D8gORrGLffIQg8kqPYt5OsNcGVidZVg8WDns3OWlyEkDjUHHw3T6Ni8Abswotn5eUtfAqssOAATp0XiXYSbaFoBLsLDl1xaxMsj47rKPkFeA4Zow1Q00uyS2CT9cdaPhgeoB6FE0GVTDHb+Hx+S+J8UWYe2xBkRp0Kc+qwKgQAVZWCtYqldZpLPPN+h5LXveKLMed8EE8N4nSR6Oavz3qYZeEzRexD+pHR7A0gu1zZJVONA5JYJlX8fN82WDCmnyB5YzqJ3NMT0lg+BLZe4+QSnEI0UIhW7BEwBq5PNwAScs8kYwnC+wRHizqgEefjjByRzmnnaWC5UgrIHDcFDIKBZoNSqMo4jjaH0tpFY800rX0sxoYynItqoCX3EnONzYjl8ltdK5DIYLTUuV7aXE6H6Fv//+4YP3wD/8wfviHf3jjdyEE/NW/+lfxZ//sn8Xv+l2/CwDw9/7e38OtW7fwj//xP8aP/MiP4NVXX8VP//RP4xd/8RfxHd/xHQCAv/7X/zp+5+/8nfgrf+Wv4Jlnnvl6L2nbtm3btm3b/hNov66ki9dffx3379/HD/3QD6Vte3t7+O7v/m58/OMfx4/8yI/g4x//OPb399NiBQA/9EM/BK01PvGJT+D3/J7fc+m4bduibdv077MzwkuqgsIC7iR6VuE45nToagwldPaUD5W7qdm2Ne9mk7eT52FleyaCwuBNqUSTF+HLCF0JVXw4UMqz8gN5Q4RfVQZQ6gt/5RlYhPrEs0o/UQkKNEbKNLikt5igQ6gL50nnTQUhxXMrYN6OfXnwndEK86ZCw5yrXvJ6QpdgUMnRqZXB3s1YiG7y+lsAgK5fwdGSnbD07MivIMloQqv3QGIoKF5763XCw0wxQCorelZLWmqnpkVBTHPe2KQHmIgWYfhbXgCtgJ4PQvavwuAJzTOT8aLGx2a6esDgTcsxHC5W8gXkiTJtgPDTqrmPJVUXKnqX7vTddL5Oxf1XViXR4AT16YA71OK781ZEHl7/4GEiUaR8vUInLUdVSH6dgmVfi5grvEs4uKQeGD0USpxOWcxQKzhBLjoHzdycopQikiFB50s+673gMamjFl5VRcq+mgNPjuN5jk9o9ZsCQQlxgv2rVKKrqyCqHDrqTSKSdgCgbxZwwmZYxuNdGx3iJh9oqFj082c/jumC3gVzr6wZwd6PY79+LR535xbQ3aVaxUf53ew26OeiO4nvxeT+EpMT0vITmWdQ8ZM3T+Fi6Rr5NGujrMMwdmT/GinNCepdYFnGo5+9GnO9nj24hl2mgdSBxTWDg5bRl0r/WogwrSjEhEpBceSKHqAJGMabjO2ihkqT2YA5yLSUECffwxFm1izMuWoCjs7O0HdPz2v/dSVd3L8f83Bu3bp1YfutW7fSd/fv38fNmzcvfF8UBQ4PD9M+6+0v/sW/iL29vfTfc8899+t52du2bdu2bdt2BdqVoLX/mT/zZ/Cn/tSfSv8+OzuLi9Z4DOAUnolyxWlU/FWTw5htjah7B8QSGz4pRw+aV9Jk5e5wueWrevLSFFKBNlEtcCErSp/HsMQJcsMBQnZQsZjEO9sUOwEKqKTFxVhWUQ0epL
BXtU5WjUklR3SKH+ikeLE5cJYsVNJ4beiThapZlt7owdpbkJTbthbntLROdfR4u2CTMvfuQaS6P3z8Hua0PK9V+wCAsRqjJ5JfBCYGWw+RYlC08L1zWHHbxDDwXVp0EuBoBsZIMSde7lyyaqWfV7jo9QDxuTc0FMdC0sjii/KMe1z2rOyG59Vl38tLtgmpL4zG5FZUHMDjiBBM9mY4PZdig7E/plUH1TBxW1Q+dMCDk3hXz5A9EgDsV/FiivuRzv/w2cOkV1nRql66Hj6VveE9qj5drZQm0bpMhUMXJDAFPQEYwxpPs4J+THtYdcAei246xjhX3QqhjDFLSbFoVYmyYvqJeAJ1mzzT5TKOnYObLSpSp4VrbW2HqojPvWVMaYkOYJKuE/UTBGgmyF4rXozH27kJ9Th6hmd/7X8fj2cmWJG8I6QWrZYpvtQyi0M/BhTp78VJ9Ap3P9YDTIaeHpNWf+Zh6VlZAXtwufXZ37k4gcwkOXEnoQLcUWUxUd0CN4/5/RscR+NrmLMoZM3nZdon8GW8IKuEMGNTQdOWsWTlFcyI6AlVUjqUKIXQRQ/amBq1xBLpdVvvhvdGrtUbILTyD37ncX6+gO2entb+67pg3b4d4Z8HDx7gzp07afuDBw/wLd/yLWmfhw8vSslba3F0dJR+v97qukZd15e2G1PBmEHaIzQyNc0H6EiytL1PsFNIGJxKPuvTit8mt90PeVG5SsalBQtD/sSF3KvE+MtIGZztVBiOrVM9GQ2lZcEa8EQJg6pMYkoWqqKI03LX2ySvMhA3hukzv/e0IPP1cd7A73IQTwkRjBVCz4A4x2DbA0sOvK6Lb47WBjXhpoPrEfo5fnIf7WmEdwpey+50Cuvii55KjGV9KQy4TrUYGVZY1SR2KIMgk9mMMjbOY7TiwovsmUhuUBjyocRAacOQB5PLJ20q77LeVzps/n4ddsxzrgpOILO712Eq9rVQvvZ20JEwYSWQbX2CcOW41gNPOMHczcBIkaB69iux78++tcGeEYyGzz/oocRNuhGflEt67meCSlDvYkWqgA9pbI3I8itLA01I1nUOExo3Qu2xroCjlaY56QU0KFj9VuqqwQCcOzFvCGPaEqOJTKTxy3urU7zQRObelJZG5RzGXDRHS+YLtcCUUOCUArV77XW4f/G34vV9MYYXcOKwErge6TYvP/8G0BGtRUXZozB/iPI6+4tsRjwG7Aku9G8+Fv3Fr+K9rX3m34fs7yKDmmVsFQDA890k2XrvsYOiaoh+mzlXZ1+B+dbPxb/1N8frCg8APrtzvjdVNUqroasZgrE+iftOmIs2Ho0w5rbRKPazhkoLlbxzAX4QC+fkZkxA2yzh+rwX3r/9ukKCL730Em7fvo2f/dmfTdvOzs7wiU98At/zPd8DAPie7/kenJyc4FOf+lTa51/9q38F7z2++7u/+9fzcrZt27Zt27btN1D7uj2s+XyO1157Lf379ddfx6c//WkcHh7i+eefx5/8k38Sf+Ev/AW88soridb+zDPPpFytj370o/gdv+N34I/+0T+Kv/k3/yb6vseP/diP4Ud+5Ee+boagMiW8Ngi0unRHfr9apuhskAQfb7PKxLRoQ3S+YxtW+U3r/Tqt3YfBM0qB4Ix6mv9AIDqFi57V+sE3/DQpWARtkrWi1PrVDBWMjTHJ25IjKaUS2UJnnljG+xi2CbQlwVIotHcIRR1Ei7aoiqEycRPP21mAKCJOFIO+5cs49dESn+7E39azMZaN6AASBqo8JhQcFYp6gE4WWU+Fgspp1Mywn1KhQPUegWoPu/svAAD+1MMP4Lz9mXj9xqR7X7Lze+2T11vRk7AYVE/E6xohg2zWmRZZ/yHznDYpXMhzLTMSx961CJGa3QruNEJ3ikK34XSJBT0nzVyqSa0xpsrAKc88UkZQOKRkM+3AtCTMjuT8NuHSAos7O0gyD7mBHj4IOUPKZFhYjrfTdp7uzSjR7xNBWwNNZ8V3gwcmBIZghp
yEgh6l1hZlHb2ekrRwpZqkSbha0mNvK4xJXRcyzeq4w9tTQl9F9PJMP8P4PH4/odrHrf4YNWJf79J7NO3rePJTPx/Px7yodqXQrKWpuBAuwbhBYXiwvN/yywHjJ/FHQgRpHwNB9DO5uwfSoPm1xB0uv+HDtjwlQjaWAWng6jiccHjco6TeYvHvIv/d3vsy8L8jB+CVCY97hJNVfMcePIw8glqPUdUcK6I4ctxCEeGYspL0qKpRES2o0zPMUKM0Z3npgQQdaqOB1dN7V8D/DwvWJz/5SfzgD/5g+rfElv7QH/pD+Dt/5+/gT//pP43FYoEf/dEfxcnJCb7/+78fP/3TP51ysADg7//9v48f+7Efw2/7bb8tJQ7/xE/8xNd7Kdu2bdu2bdv2n1D7uhesH/iBH3jfZFulFH78x38cP/7jP/419zk8PPy6k4Q3tqJGKEqgj1acp35g5BQzxiKBYp/zxyVuFYYy9+/DRchS6NbKi1zcL8eZ03f+4vfrCXW5NSLncBk5Q4tahVYDVzQrRb3ubZmMrJ5UMPSwj2xTGT07Y8SnC05BX+WA6wzO7tGSNR4lLeu+kJgYUma871nqu/gIJizGt2Cc6WB2HU8eR4q7ps5ZfeBwyPLfmum9QYd4bgBn7MPTELBTRk9tNIoW3sI2qFhi/JXxXQDAhxbXsGSMzekCKwZ7E/vWlVF5AVRuR/SS1x+9zqj/F2KOeXex/zYRKuR84vyMAIz4DKtr8T7C+aNUPUDIQPMnJzgWxRRa6bPSoOL4FYu9U8CIcaEnTAy+FoY4IMXJoQuHQE8tdUEmkS33bZ2HJanFkMhgg4flzZ81dCmUgqZ2XsH7mYxLKM8ihy7AU/U7UCOwCCl8hlFFtf6ihIFY5yRk4Dwpz7f0tpdLhTGlSRwJA27Z4r134z18lASbW4sakyPGrhkP9F7DueiJmeozAIDVf/xHaB+RPEA1ilNtIezqKkvgXaec6+y9kRBgvQAc+zfwXQvNMJouxKnWBpnG5XGXt5ywk5epASKtXeJZRmUlaxqmHCx8QneK/xADb+3DJ/D/zS/F7/9YfOfa8QhvHb0BADh7yPI+NWAdv2d8eaGb9OwmY9Ec1FAiRkDvXWtAiRsoc5tWYLYFJjJ3aAXV20E96CnalWAJfs1mSnhlUqVT34m4pkqirIoTUwhZpD2tKoMUkYyGgMuBUY/LHZXDe3l/r69TKpsIVfrf2jYZaOnhDtBdkAVGGwSVayusQYOyn1LQPKDOv+ZnTuaQ/S5MtmsIqS08zIT1biYU1S36VMNnUTLIXUUlh3i8mKdiRjOYIuJSpaZQ6HQP5VlcYEquRPuqRkVWWsUXvQoVirFAefHzDdVAVTGiPKZc1OHsDh6fvw4A+J4F5Wde+AjG3xfz/HZefYjyLEJZrolYyZOjI8xXpxdut8MAtdiMTLMOmwYMz/3Sgp81gwFOlLFTao3pXrxuxzwhtWqghFD0OJJVHi6bBEsuOePMw/DCi8hsB5sm/Ef8wbUaaZYd8eKvty2sJsJBoWZdKDB9aRizfiAISc4M9KB+sZC6cz4kQ0ogwdGsHEgtUFgsaXgcSCfolMNTCpyoDUCjpOZEqLSCkfITJJw0nYKUsFtJDbUG2D+K319/GPt03+9BnzHfyAmRBSgexRBGeByVIPCOA4vz4oxPyfmBBGSzBWv9ffYYxkl6h32EQYGLKig5A1D2Xx8rKvvNhW1r0KHFAFXn7FPBrfKyN4oWkm48lNAI59R5CQWOfvLfAgCq/97LAICTwwqf/VSUpSr4nuq6gicLRSA/ZQpMCeFOZFsYTDgRyS1LlZiWKU1QK5Rl/H48ihc1KQMaq9CbAGDIs32/thW/3bZt27Zt27Yr0a60h6V1ESvuiqcjAp/KDEs7I7jKD66Oz3nobCJAG8JlqyqH8vLmkxs1bFuncLjsHwGZ5cTfKj38XhwmozUKeo1CBQ7QQ0KXnFap9Bt1YXv81Bd2J+FEDdDQJWgrIOWCpTIjWs
GScuyY3zPyIWnJwTCQapCIE6LnVk/rlMNTEmoqxxVmYwZsH0Uz98bZDJNHEXaoT1nxVAeM9qL92FFH4HpZ4a0yelGjnecBAL/11gv4h/QaPnISt+mPfgvsja8CAG795h+CefGjAID+q7H8yYtW4ct/LcZM33rnV+J5g4LDRdjMZ335fqjFpgB6bk1LQoYPHjs392K3tdGbcqMqsXbOjyPk1vqQBGctD37UATvkCocsFcPTrj4X89oAhsb0mOb3S08avHU7ntdTc1CrqIASrwvDNiGmiGh0qYfSKizf4q2DoiVuGEAvR0WaTXwAOlr0xT73Uyp5YEmzUQGBz66m6kkoFQp6K6K6sWwsektSEUuJXGuA3/Q4brvLnLywdwBLYVfTSm5mi/BePB8LGMMchwSliRekEIk3wEXE4TL9XCXNvwtQ39r+OYS3icKezzGCm4jnNi6Amv0rkGTb+wTjpvc1I/vk73MthJ2lhWU5FkV0yawaGFZPfv2f/AcAwDvfUOP+56Mqij+gt1TUKJgj5zm4vAZGI75/lajplCgLIU/xc1KlyUc0QrUqYOiNFyPCyGUN0wB9lg/7a7Wth7Vt27Zt27ZtV6JdaQ/LlyOEYpQ09jUp7EENSt9iPmoAIVy0eQJUUm7vM0t6PZ8y/zuxm/2GL3HZw/LZdwq5ZzWA1HrNbAjKQzFqKfThoAw8Kd3ByfWHpJghytYhhOTGSWKoyUDxwARiBQUn3lZmbQ5xrZbX79EwO1go1NpoLNkThnh0MXEYuUiTNT5a8+3IQUn2OymvmJTYY0xs9uUYCL62e4Dx6zHutbuKlvZ0prDDQo8tgzE7U42qj+eoXvgAAODbn38Zc6o82F8iYP9/+z/BfyTuV3zuX8MfkkL+3/8fxnvaGeOj/9v/VeyP/+r/AgB45713oclwEC/DIiTCRO7BrlvOBTbELjGkO1SSAhAcChmj/EVR1cAixrPmjN8tMDyH9Gxsj1JKpPMKlhooSD9vlEvH7Zg4ukMr/da9Dq/dppVMc97aQRgkv+iwlnfhM73NJeNujVfQ1UX/vBxb6CRZH9BSlWOP43MVekBUTLroWVd+AsPAXDmW2KRBS69L+m+xAlpq4dVFvIYPLxU+0kfiSqVu8d5uoBjtx7+76DHodgnMSSDqpXDr8B5LDOgcm4kzl5GVgeouHtGgDzI0h1zBZvi8gLwg5uZKmGmHqMV09wBKWNV8/88XZ3BnJ/E4orgfkKW6JLADmu+L1TYxcBxTf8oAGAa27v+3sRzJG28ewtyI2/opKweMWwTGnIS0VPkSgdqWE85F0GrwsOrofe1Md2B09LAbeurOFqh4vJpELKM1ZqMJOuUBHONp2pVesFBUgKnT4pR0GZ2DEooUXWHnfVqcLlD5ZAG5SCC81NKEtAH+S1Bjtk0YYvnxVP4bkXXKKIipsmfQiXGThPqVzhQCuAh7NcCNiQFpknCmCOOqgJR7lhZMDLlW8mLl1ye5SwWAmmUZJJDdao3CXmQgXrdj7LFm0MzFBavGOVqe5FQC7lC4SQbnzQfxZZr+6ycYvcsM+8CKstc1rtn4Athr8T72oXBz76XYLzXFUA9n+M5mHwCwCHGSsp//tyhVrL/mnrwHrOJ2/Nes/Prsc7CjeI0f/c9+IO73s/8a79x7A+x0AHEiMuwHgYEshslO+i1vLvuU35Y8YD0ZAYRNQ8dXr+mwoDAtOa7QGKAqqdvVBeBsDTpxYTDC5LnNO4Daw9BUt737OAwEADF2ul74F6J7eiGPMGQLVmLKdfGvxvfY00PeHxDVaJh6hS4AthdVmWFh80oIUCLhs8CoiAoou8QvVVFAC8OCF9M2Lc6X8ftneb5XrMY1MjrUKBokSo8RGALQ/PRnRwgL1nETkklQiViRGx8bEP5LzLyAAcLL89jWf6swTK6bjiuL1EQDMxJO9vavAQBGu9cQaiHJEMqragSSUObz0+F4ifkxwNdCbgnOQ1HEWLGIll8t0Ev9qSexQ9S9Gn
pPJK/i+9c1K1hqSrUUPQ4YpWvwNJRG1TSqYgAYVfF93Z3tJPZoyblj1QYUXOzGrNQ8MhW88yg2mgqb2xYS3LZt27Zt27Yr0a60h6WKGijqVLAweUtwUMQnvCR/eD+kYWUsiE3lRTaFAFMehvw7XPS2ZB85nM2OmxlBl46TEzrEC9LKJA8riFagVsN1JRFcnwX8BxdRoEAJboes7GCeh7XeFDJYhJa7CYBe0ENwkSyhSoOaKgS7pywvcdxip45ey9gxzyY0MF38vqN5q6BwnVn114/jMaq3HSw9tqUIBTdL7OyfxH+0N3hci7FmXpeNsNLqnS/BsGDd7j7Fj8+P0b/xZQBA0TXwq2jPhkNa1W/10AwQe0Iur3zXN+LoJ98EEAVTgZhSYEU0mP0xCsCNb/72+Nv3Yn7P64/exIjP7u5+tPpd1+G1JQVs2aXV3WfgpNggy+W4xmPFHB7xsJYYvLecHHDEv8sUdUciiggl+7jTuDOS3Kf43YuPAHF/NGGx2pSoC6oy0OvyLkMIcs+IY8vyt8uuweFMgu7xs6pKGAbiQ+vQsvqz9izw6C0Cy0o4sdJNi5I6dDU9T1ONoJhXaegpOOvgm/jcX2D+14vOoGBqQ6ioeqKKpGyjJCdzfg5NAoYoT/T2cupK/ndePXqT55RIFEr6angm40KQBwPHd62hxzjvh/6dUNtxd2eK8SwqnJQHzAHY2YGhRqMST0YXGK9iv7QNBXadvUi6kL8THOTh6WEZjqi+b1NJIEUvtJr3GBE9KfmMzl2DZX/M/pI8QY+OqUNLettjU8BwbIm6yXgyQiBELZ6W8xa6ittmJNgUuoAeK7Rq62Ft27Zt27Zt22+wdsU9rArKjBMmnkupeyMaggPXNsVxUlwCyevqNyQO5y0Pe6X91kNiuJwZH+IJ43nDsG8qcx0yTFwSoLUZsse1aMApIJgLv4XyKSiVilIiDFoXScUjK1md9r98b7n1mPQIfQEVHQncDkKqUJgsGUC9T43A9yqM2Im1OQEA9IfvoJ9HK7hnImrZWdxYMojLqH+wg6cgZJa+AxbvREWMXUViuN6B+8yr8TjPxKTHrt6Bq6O3pF9mn/Ue7nHE+Z3xsFQZryTeo+7BM2ZiDAsIeocbz0Va/NvvRE9rZQx2GYO7vR/11178L34fwjd/CzsuXuzh3/p/4MYHoxeY6lDYMV48jknT3a98EgAwvXsXeC8maDZncXz2S4cHjK2wm7FC5r3z4WgAZ/zHfraNDPb0Ih8H4A7NUEnFHJ8CBT1KzVIRdVEn+rNjBrFHSP2v6KWrCijJ7BGSTtu2gI5eTaKoFwpFLSd2g5ZnK/IRHo6ehg5MQO4mcLTY92ckUIxHWFGyUJQbgnfY7aP1/lIb7/SGG0Hx3RbdQ+1dqtzgl/E5uLMVVMs4Nq/J5XG5i918YZtHRmq48K4zrsT9pjqkGNzsWvSSQj1KSIddxScxWi2gmfg8prJHfXAdioU7cRhjWNjbhyJ9XB6I9g4lSRcVE+TNao5gBU0Zrl9IY947KM/+6Ond9haGaQMS5y+aFUbskCk93hPfoCPZSrJpvHLoVpgufAAA/7hJREFU2+jlLagHuudcUvNPaTlGoyqlYCTfya4Z9E5F6b8YoSwCzNcRw7rSCxbKEkFXCKylk6NjwuURokAkHlz0mYO/zPhy6iKzD7gon5Lvb9e2bWIBhezCcndWvtdhgOnAwRJgksKFTpFUlUXHJY08xySFkOEgpUG05ECpgE76ww9ssrU58cLfKWPfeRRHrLlEEkRxbjF9jzDAF+JiMPuqQzFnCYMivpSr2+/Bkoix0CKWe4Jr5/HlEZvigiqArJNwWBB2POBL0pwXGP0qX6IPvgEAOPq2FZp7cXLaUQ/StfdzMsIKoOTb3JzGmVAbj2qXL/AhjZe3T/A8J83XOUGX8Lj74jcAAJ77kf95vLAXnwO+EBc0W8X9bvyXfxw4jIu5+9QvxHN89Q3sTPYBAO1u/AyLewikHVo+w3
k/LDp5va51KbACw+Pus/GU8nR4vFMMQsgNJ5K2i1AmAJiarMK+REX4bAVhz7l0Yptm7RAVKQDUxL1a1yeikxCFTKFQsT+UH0gXls8wFBYKUn4i3rG2NTrWoCoPYzmi6XSKoye850QocjggBHWTEl9Vr6Cs5P2xr1yLwIlZzU/iZ2vREnKlmAqsx6Uq1Dmcf+FzzSiNZIr4rx3e+7W9HYzuMAfwelx0OmOScHGxjOP3sF3BcJFQNeHM0Q4USQuYERIc70GReedZ2duMxyhZvqNi7S3dxtp0sY+yv+X6nYOSB0kCThcselqrUkrGGA9N4s+IhfWqqgNJpmna8fBoltHYWBJybbsWFUkUlguc8y6tcon4ZRRKHqggA3o6msDYPi2qT9O2kOC2bdu2bdu2XYl2pT2sUBbRKylYJZXbvR/01gIhAo1BOy9VsPUDBJWK4gVcCsiWuEj9Bv+9Kc8ipzXLNvnNJg+rwOBhBS2B7AKB1G+pLhygkkqFY5BcBSTFATmL0ho1czgKlltx3QoVf7va5FZlbZ3WvoMC9eNofd164wQAUN07w/gzsdqq/Qx76b0WDZPZCgqe+q+eYrofvbI9qlbUfQPzJFrYpbt8MQK9lABaod8yuD4yU3gG89UvRdkC9dGHOGmigsXZrWjJfqQbIM/eAkuBG3nwsQeMQFXvRnUJNdaYXouW37WdCNe0LuDuf/EH4hUeRkKJn0ygvvjLsY8Ono3bVnPoM4rZkisenn0eiqSMehK/691qsLo5PtsweFibPN02eZwDRCbkjJg3Ff8Uj8EAibpeku7RBY+KPy6Y96QrjYoVKwt5BxQSrhPkQTgMaiXc1sGh4NQhhSir0qCq6fGENsFh3Yoexa5LnpD2Qgpx8PTExhPm8OzNEr4l9GylPO6wsOBN0r1N0wKECekoIOgOvqX3Tuo3mg49+7zJ+ko8evn0aoDSsu69hEJoANPkWVH14bm7wLMx3SLMiEIoDcUx5ucRAdDtHkBvJEzjmNDFGJ7eFEqhsuvkNuoV8/Z8GMSshXSlNUSLcROBzPc9NOE3J6k/vk9ixgVd9ZHScPSYSqYK1FU/ADrJc3NYUSNyQa9xuWpQckx1Nn7Xdm0ii8hvNTy0wMwieFwUGGkN4y5qpL5f23pY27Zt27Zt23Yl2pX2sPyogikKWJZLz0jng4VN28j4MJQqz4DrHJuWtr4tJ2HkdPR1JyX/d/7b3Eq75FMoJDUIrcUaKgevS2IFADQT+TreSeV1Km44IrV3vH89eZDGU72+tCno2qrBsl+/v9x6ERrsuFKYCBX6n38aAFCOO/RRqg+WhevOu6jQACCpgOulRUttwJL6YaoAwiIPXcv5Ln4qAEoUw2m5+2kJQ6WO/r0TAMBNP8P9LqpwP2aQpy0NxIc96wYtP/EuKyisSPhwpNNPjIVaREv4G25Eb+q10XWoFz8Uz60jgcJ/+VcQHkf6vKfCymi1C/e56HUpJlnqD34Ufhmt/JbeQ9XUcM/FmFj3qU8AAJpgn8qjz78XMkVx4RkOfvxXOcCfoetRBYUZlQ6sFsWIAgvRg6OFG6yDkbQHCYg4P0Td2fet69MmUYsvlQaqVbrOgmPQLqjMX7VQdCVH9JirAATE30wYm3r22h38iv5svD/GNqZtwPUq3vWkpSJDo1N8V2LTymqYBeOook5uQ0oxEW+qxeWE4LzcTiJiDI5OaiMA1ziRlIxX+WdegL57O/7G0FsKBqFZ8eAkfYwsDNUgwBQQN9mHTjCPJFt7oCMxRchUqyVCw/itkByyoqNRyYd/S0g8Y3lJ7LoLGorjQ+abCYBaEY2RNBoFdCkBPX4WhUHbU5WFY3u+OsWEhJ7OxT5fdQ1GUi6I9+490KTYeuzpsa6gCg/tnt5vutILlq4LOGMAc6EGJzx8krcX9p91g8tvLuSx8Ht+9rg8mDN92hT4tmEoQ5FXm12HCfPxrrLvdbYx7ZtkcVQGBfJ4IcByMKX8FG1R8t7HB/EFqGYjeM
vMdOY9KN2hZF7aiJDV0l8uV+AAMI0EYgJMg8Mud+xflQUTYOmjtEjNkclbpUB1SP1fLPniKaSKuHkphk3qAVpqQlEoVrl9OKlk+y6D/acjzMhI+uX34pT0PRro+aAWGBhhouzQ+gDH7yeEL+EMPNlrkyJCMx/7vt8GJ0HjR5EJoL76DjRZjoWKC5zXDvphVNNYLSLEqLWHccwtOo/X3x/soRK23nORZNC+/valXJ9Nwf7c2Mn3lxdYnmWLkEqNvEBlgaACdvi8VofxFxMXcM4V3EiNLg0ozk6e9DjvQpIHEyh6sVoNlLCs5poQOorCJai9Y02r0ilIJmEhQqxFkYzHitD2zf1DTEo+Y+b/7ATgoGXV6VNOqHMNM+VCkKiNiC86AE9Cifcp6/ICXL/J2Ly0LduQk1/Gkzg+Skp+6RvX4HeiwRhkLnIGhuQCLwuXNomdGkYsfzQyCIS+RYknOIWg8jchEij6VmSOmCMJtU5fjOdz8ulTrpU3A0N6vUd07zGiPIr3UilaoVljlQUf0vkasgT7ZoV+zHmBjM+ub5MQbmAerHMOgfigzqDNqjT4OjgXW0hw27Zt27Zt265Gu9IeVggOqqjgaRGLO+6tzXIRZN+Brp7TFNY9rJxMIc3holcm+8k5cg8lWXGZt5FgBzX8PunRhaivF48jP0KitWc3Cy+0UKk8qxzGoxi8NbsM2NaAJ4XVMEBqXIOa2IZc3yS7Twn6d4gwDQBQPxWl86hobVux3AA0tLTmGdzis7/j9V2EXIAovin9n5dVSDBGRtMVMkAg/KDcEl4KB1KLT//Hd3H3fxQD3vhszHF6MgLqc0KfAFh8GGM33OeOWLPSzQrQvcCWtAQ/+AoC+eDqUYQB3dvvJfKOXKs9PYMXbUV6Avq1e2hYlbWcRugonDew80gWGU2eidv0u4OYKYZrWTf3NxEyZNd8m0UAOTIo6vjtSAXcIfHgXVrS9bJEzXIPq1RJRMd8v+zmnNIJNhXobbFYpnSRIB6oBlQhpCadvB5LEkRoDQoz5AoCUdVCE3dqOD5v7uxD87ooK4kdAAd01avE7S8gwrAqsQMwuEVCuvoaHtZ6vwVsyMkKl+HyGkA1i6Qcv7vHC9xBoKBzKkezahHaOG4DKzWragYnuUjSb7ZP6hxeYCEYaH8RYVEhwDrRYuQ5grpwb7kKD0CEkV6qjPPgBjFjxXmkP2sxJs2+o8qh6nxCIeQsPviUnylCws569Lz+hp5k17awhATlGQEhy1vlcZWC0epCRfRfq209rG3btm3btm27Eu1Ke1idb7BX1HAmWje+iKu675aZFRI/LQYrfpPKQw7Xritd5IbbJotMmkNGk10HxCHJy7x2fhYBMJKtDlFW1onDLFqCAQGaQXRPK62AxmwvRptKarH1tUN5Fi0dd3zMi/FJmV2qQow10PEGVtk9CUFBYlgjP1hYDc2bJkRF7rw/XHZPQ9xtsOxyGn8eGwTPmQK7klUfMkWPhbh2GoZubSAFuPnXb2H0Q5GK+0Fq6N17YYRnP7vkOVTSyhNPtlMBN3lhS1J7p1MHR764fjHS1ctqhPYJvbvT2EtF28Mx3tKfR8u5GBkUJM6IhXpml/AsMNiQYj1ePQNfsaDhDu+3qlNwPneqN5Fz5Bn6rE9lnImavMLwHKSDSw1cX8Sx88RIafsq6QAadrrxJhX6U/xxl3kZck2rZpksbfFutCqSh1WUGo6Dywvjoa1gJhLQjL+tTAHNZONTBuxvznZRsGSNIBi7AMZWbomenSkQ+Bwk7hZy3zMhK0PMJo9Nr3tTmxKH1/8GYl9qJvCqnehhqXokIjQAyU1YPIZ7+Hr8/iQOrDC9jlILhV2UIBqgFVUeeulmUKuR9BzvPSzjUb0b3iBP9ZZc3ECaC0gBLS9VK5y/NDfZ0w61ZjkhPjetBlUFeYQKIfW/xNH6xqHl9S/oqnetRVcynUgP5w2FjC1evQa88kl952na1V6wlIcyVRL29CxH6pVOnSDP1o
VhUcoD/JsWrHXihOwLXBzo6wPc4nIeVv693vCbPkSYDBgy+1XwMccC0Q2P2wDNiZfIFWZlhWo/Li2+ioOm8C26996Jv1kQjtEZwYGjpdZAzZuRmkDAsGBJ+YN8AuyzvlxnrLXZPan8Uxa2bJJNkISW+x2MCfkssmu2NAMq1FBUWxAR2dW9Ev2/ewMA8NK3xCf38zc73MVwkrzeFxCPu0o5eVy4zgIqYoaaMky2GKF8EqFAHMfPEFzKgxFooz9boSXZwrBnzk/PUO9EmEWYjYtH72DKassNRUYPD/YxpzbTkpOLwnDvidCa9bnkY1lc7vN8XLacaXoD7D4i446My6N6ELPVRqoHu3RCLQk0GU9RICnr+hTYV/JbbeClqnRZQHgGgiNrO4YhTXCsRIg3klMAYHke2Wdm/yamdRyFIoQ8bYGpPK9MAUYEop0QRbwb6sJlq//6O7cJEszh+g22ZnoeZYGkTBGkngoAJQsVhZ1x/CiNGTXnojFvkoyR6+MiZsopghSaE3JLadIkJIxA23foKdXRiaBtNqflNb4SKuoCAnF1J4ZeCAMpQy6+BwyVgQMXV9wo4KWWlVxeCAkOl7pp3jp0lNxa8qH3vUMnFZ+ZoxWCH4g8MsiUh1d++PdTtC0kuG3btm3btm1Xol1pD8vWgDUGiu41FC0fVcDTL8hp6GJR5J7WuofVZ39LC3i6IK3HQHVfz5qX7/Xab8oQxTiBga4O5ZAwBgkeA5BMsore43ikgAmz/Wnd6rMzBOaiiOHSdwDRn0QUqMPgWYk3BQylMPISCyl/JVn2KuW3ddl+68FtYOg38QouqH1kKiNCOZfvCzNcs54RehkpwFCAl4+8HB9i/i9ijtQbxPRem3j8Vl5Eo1y6CKkkXIWBji8oVaeBGXn56rkICeplB8+8k3AcYT3nFApRcaCQaaUDAi37ec+id85CMZdmNInX/+DhO7hxLZaSqEgQOjs5wf5OxAdPTk9SH6zDcHmTZ5MXk2SRWSyzcXeSDmKwcyx3LFBOC0PrvSRU48PwdJIYtEaSrtR0z7uuS+NSvJxSVUl7rq4NFpJHKLKXrYYhHCalUHp49MwPC8yB89ZhLNR/QQA8UKXkSXpQuhrccSFd6JAGuCK9XCmV4NJNkF96hze5VRuahoKSarscoAoBQWC98zgWw9FDaJKeAqn9yjWpwKRanvAYkyR0HRjSwGgMwwrLllWwbdPA9UK6EIhTpffqgo+ShUMkFcFyD+tDmpsE6gs2wCyYh1UzLNEWCP7iHKp1JFnIPQNA267Qdiz5Q53B3jlYctU1BYq98oPANZ+lUw4wSFWwn6ZtPaxt27Zt27ZtuxLtSntYfhfwp2qwpgqa3aGGou2fky7Wrdbc49kU9ksWWWbmhg0WjfzWYfCsEo00jz1cSBKOnz2Q6KM2ZJngYsmkYo0h0Z8LelhqZKKeIoZkPLQtPGFoSchz6qKCBBCDoZrBHYmheQweVqouEi7H5SzC5f5SG/owrMWzcJGxnfpNAVryd/mlUTExHAD0/vNxYz1BT7JF1TJb3hxizvjM449Hb+jkTg9lBldXdPFy7zHFefi5q5Ao3dUuFQzuvwn1MKpo6IYYP/RAXaaCuDVAydiQGNqmqtAsJWGACiZFifN59HRCHy3xed/hOhNlpdhdH4bAuPSbVdErRnbteawrxfwxeM7v0WM61B4VFUeWVOJw/TKxOJIXbAp4cXsldoKhSGDBs7SuSWruKdZlon4eAJiJT+QZSf7uFxaBHsJKCDbBYiVCgCuqf9v2guo7EJXmJd1C0fVzxqf4mfSIhoa4eUlhvNApYTlHTjY5VOuIyYbsAmgoBHpYWoglrU3p6Uae73yZ3FRPTU8FhSDl7XvxXs4GhYqaicg7u/CaRA2eo29WUQUdgOdTtwgp0X/TDYVsvpHfWnXZm1Q+wD/keW5xzLZ6EDjhu2nCEEN0jGc2/QqagE4v95mRKNKcG4B+CMzFa9EB2gDBbHoam9uVXrCaPY
X2LYuag9RV8vLUcIoDB0P+0XpoT+PygnUhl0p2zAOa8qkuLobrv81ZhWmyzibwRLoAkty/sMBCMFDGXthReQOVpHaYZzUqU+BfRpVbrGClFpEIAJjsBpOQ5TBIVPbVpVwO5AvVsG0jr2dtkt3UNrGxXBgm3Hy+LBl8VxNODKWBJiToqVCxmtTQPqoMzOIHKvUWzkigKOfDIE+QsBrywmzW964UiRnCIsfH0F99HQDQnkVSRbWzD8e+rnl958tTBK7wCW5xHh3hnJb1uEamwKpZpu+B2JdLVpLNSmJfXuizdzpnuyZyRjbeZNc5D3fNhFTKRRQKChUSO1CqKSPYdCKl00gGFOtYyZbQJ8gnDdosn0ZVGaLNa+2aFXrJc3NxYrPBJxix4WTXuy4xLuXmqzAYad4Oy4rA4FILTukSoYqwWpjEwVBNVhidPkrHAeI4SO929kKuzwWb8iaNMUmkOiQ5DwuwtIadx+erVm1aoCGQYAgIHBO+oYizNgNsV7Cf+z4pZrRc9JaLOSzJFin/M2S5TWGNQAKBBHlP7FLnh2cSshfRPuI1UtYmdBhy8WQODQNL05JM07QLdOeU5ApCGlIIiW1K1iFUMsh7jh2rAe09rN84m2xsW0hw27Zt27Zt265Eu9Ielp0AiwlQnUjQlZ5WPYXvo1XrFXX1MmgrhwbN2jaP9y8bkgzdcPE3wEXYMbcZ1inKeVNqEOgtRCgy2MGalRIL3sPQrS+Fwj6dJkp/YMDTn3WJfi6tQkYlFyszs+IFKsvpvvl99Bu2rfdHbtmne8v+3hTwzr3VdQhMacBMmQ1G0V/VOSgfg/P2LEIr0zDDCnHbjWdeAQDcPevxZPYuAGA8DxcKHkq7lEPngZ5Qz4gej398H+5BLKOyIiQ42j1IGncryUXxPnH0GxEZ7h3OiA9qUotHaoDcWlF7ADAXpYMNHuomKDWPUW8iusj4XWSVYEuyY0SvUivAJChtyL0S4eQghQb1kAoheYIBFlbGLGE5rcvUB8VYwzBXyVOA2XUWRl/juUlDtx4Vz70SSrTvUbFckIgkzACUqSMImzsFT9KD5v6oxgDLWmA/ajXqVmO0iuNjcsJrCWteVN6RWV/GNJO4Q8kdqmoERaFeEaZGb6GllA/fQzTdgA/z+QYXEtQq73UfHDqBlsECo00DlHHst+KlLxbJs/dSRRjqAhok40Ic9Zy9I4SHgA2FZwOAh4TNb3Kg7ASYidDQs/3YhJ7f2gZ2yesXjVM7kDJSlXelEUh/X/GzCx3gQkI0nqZtPaxt27Zt27ZtuxLtSntYJ+US+9MxrAR4i2iVhHKOIEXRxMPCYFnkwdWLYduL3kO+33pANv8+90rWvSgPbIzpJGs5AD0tk0JIFTrA88qE8qrgUIjaNSOtejyDo7WqJFl+YZOlnli/Gb6dMG83UM3zROr1+8zvPbfmNnmS697ZpvvetM2Hy95ZWRjomorPLISnzs/QHseiiKaN/dMVE1iKBU7O4hV+4PYH8fbL0Vr94KOTpJW4R/07FQb3MhVADAo9yzdMTx4AAPST++gDaRn03perFQy9pBXVqRftEgUp25696YwC5Qwx5f7nqk9B8txqFQ82J7rkgX9gTclA4otqeIZyjPzvZdpdJwq7eIeFHlIcJAylVJEC5poxlLoCCl6/6N8FhORhleJhFWUiCJmqwPM3bwEA3njwWvzeeRhLwgy5UTY4dHYgLgBRe7BmLFGurwpR5zCeXLzGAqpk0HIcE7RVPQMMY0PXY/pAUAYln9MOa+KEub30jrtcxTx7b6SJQkxRjqJYQbwI/tgC1JAU8gg6C0d1ddCjsK4H83NTrLZzFrx1dIGxTH8EXTDWKYSRPgxJxylWpGHZ514NRJ3c07qUehMGOnsigwEI83gvBVVy9L5DUV18K+P4ZKyWJ2vaNmkJtinHRQ99KbFdOASOzIapH33fQMEP8dCnaFd6wXrQ3MPz+x9Aw3oVmgMdZhDE9YZ5NM6lB5QvXOtQlc
fmyXh9IQrI2IPZAHi/HC6V/S3nsyCkhKH+D1SdgpVKAt4OMFIOgFVXQz1JL41iiQLb+QuTnFxXWkyGuP5QISIL2G9asGQCzF/yNJTVxf2/1t+b+iDfZ73PTaEREouNjKqzd9A/OeFv4iTVo0E/ipOr/dIxAGBvd4wvfTTWJ3r+4yfpORW8GpddmFxfj2GV6OdRoaA8foSaV9tIyQsF9HzhpILxYr7EmPBJ08YX/mw5T4aU1DvrrJWSUjCZ2kBSL+K15KQQaTlvJl/sctHmeI+DmLHkUrVaDbBeDgkaKV0Rn7AyIY07yc2qQ4WSUFRFI1Apg57bxuxcY8oED7Yq4Fs/+i0AgJWLz+T09AhG8tw4+xsENGQQOEoz2b5HzcVJYMBKAQXhYXfO+5gdQM9uxOufHsTPagwUhI8P4oKli0kSgJ31lMB6+x10LD8s74Md6EsX3lGBTcdpMa4HSaisrHUgsQaEBEPfI5AJKCoYyru0YHX8bH323Pnp21X6jafV2TmdFippDgEuLQjD+5lDxkLYSwtzDofyUwfAsKJzzUrC1SqgmFwEzoMaFqokfdU7tDQ6EptQFXAijhyEKOKSYdTZBffv4MzAYHyatoUEt23btm3btu1KtCvtYd1b3MfZ+AUYpvmPSNN1ZQ1fiiAuPa1+MVCO+ftNVPdN1r7CZc8DyOC1bP/0W5ovZgM5I/+Nyq5BrFat3JBXJfpyIaAi1V3ICKEsoSVrnLCH9S4FOk0YPAqBB5Pw5ToGh82wKdSgIZh/t368vN+wYb9LN75h3/xTaYUg5RRawizLOTSt5ZYqEt1klKji7q34/J/4E4TnWDyx1FB0YcQjUplnkpNuSipOaObS2NNjnC/j3/04EgbKsoanmbxshSwREOiteAbJQ/BJGSQpE3gkDyv3xKUUjYzPHkMhzQTXZX3TZb9Nf8t4y44jXb3ywLSK/VGJ7qbvkpKIlJnQqoBW8YgCPytjMHPxXSoI0eowSUF1TW+o0CFBeH1wmFJH8QMvvAgA+MwXjmBJZimkiKFW8KR8L+nqLLsVNMuQTHh9tVEoD/fj9awIve5eh54y42w0TZ+BOn+gDp4upzAsaNqTBj9yHofvRTKNiL1an2lNZk6EaGuOpCr4dAqw/Il4WK5toSQ1QXCxvk8uh5fcMTd4Vix4jQaDRyzzgAKgiLYIgcLBJMUOpP108gvzuSdc3Cl+nzxJXEJgFIaiqmM+j2qpUo0hIb8YhawSOo/nQwppCMyJIiSNTik/Y+EwMYIayaeCCuESmvB+bethbdu2bdu2bduVaFfaw3rogHO1woRSDWMpT21KeCk5Qlw9KJ0SDH3mHazTs3P68CaPKP9uPT7jsNmByGNn6zGs3KspaLH1PqBUF5MTC+VRE/vX9ARQKICxq7CKsTrv/cbrT54QN6osip+TKhKxIotNuQ3eVLr+dc8tO57ObtSve1rZMXJNvDzAFWiOGi1Wq8vK3Mde6xbnyZOR5MjpvQLmJJIzXBVVwYGhJEqNy8/TYlCaKBoJDq/QEJ/vmXAbxs1QqoHPZrWaw1GVuprGxNVirpKytSSG5t6PNIfLcSgbLlvBGpsTuNfjFjojbMh3cziYKVXVhaZdlqi4QyHFAr2Co7krIhK7usKU6uoTvl8uVCilNEmeiB6G3jyiBuPewV78nE2wJL3c8UFoV6YE6hU9k5PlIqELrJiDcVAoSeLQJzzH7j7CmEnCVLhR1Xgo5ijfBY9wsM/r4nNYtdg5j9eyOo0J4U03xJLyWK1MkDW9SzWZIQhrRKQgulVGtqDX7UN6KSSJ1roA8oeSN9dhOK/0nsblxPF8ZsmRn5TUC1zeM38kl44yHNuoIdY15rMp2kFFQ+KfysREXwBJd9KHoSCk9EdZFTClvrBNwUEVEh9lwro2KAsDW/wnQro4KUucG4s9uuidkvo+JYIsVJqQoFrA8+Gmiri4/HJv6pALJIMN2/Jj5IFM2S9npK3nZBk1TGKSI9OHkCYWueZCAa
NRvBe9F9lRoTJIxCJm2OcMM1locpZg3lx+jVhbsLKFZn2SVRheBFmwlLr8crhs9a83XMD7kTN8CEMlVr5Evu8ROLkKUWFpVykw3VN5uFw43KBk1fm4wP6iu3AOv/Z3uk+SKeBkfweIDA+P532AI3NT2FFeh0GWhsdrVUg1xRoRisWw2Ejum8WwkOYL0Xp/5JBgzsZMYzDrzPXJzhZAf0BIUIhJwWMkQXLBp7RCx8lVnuu4HmOfLLx9LouNCyio9tCnKrghlSlBWOHUngIArnOxO9zZw+sPIwxXdHEct22LM07wZ6t4DYvVEpqG2ZiXOoWC2Y3ECr3P8413gJJQYCKDFPBU0VCJaVMhEJ4sRH7D9XBnUf1i1scFq+0ukx80hrpw41H8S40nCJThkhImqm+hCMkLwSNaevGeghsIHgn+52eXnU8eYYFNsF1IUP8wdsOFd3d9TOtsTOTwnxw7r7wurSJEWnQBIxEzzqjUWkgyUqcr+DQ/aFovZWFQpGdCIxsqQYsV2TR1YVBXo8RcfZq2hQS3bdu2bdu27Uq0K+1hYaRxPvNY7dBim8alvlxUqSqslYJ/RkNYoSrzLNaJAptgOwCX9Lc2id/m9PHcismPt+7RAYO1lfKsENBTaUCgnloFFKxWayfxflXmy4dlm46VdPnE8sm8n8GDuUx08Nl15Z7Wuv2jsEEhQF38Pv8q37YBGbxgNSUdNB+S5aWpKOGsRUs9up6c8c7aAVIJAlM5XGtjHy2u7WLy6PGF+8zPd6HonVjChHds26GgXmDDHgm9R8eA/ooU9qVtMIUE4qWOS4GSWnLn2bnWiTq9Uugycox8pu9Tx2RjJjvGJUgwu6eJEDcKBXcnegg1UwC0ByaiichZoHUOhaYHQ9imqmtcZ2XdQ3bgvLMoE+RDgo8uUWQFDRfhCQDg5eoD8Vqm01Q9N/AZNu086V6CnnPftqjpju/ycFMdoGsyAPZ4EeUUEIKF5Fxai0DoVkRyvSlj+geAsFtJVwIvxGc3PY1jwzYnaEh1z/t+Vyj2O6TOj+uhcKOo0HQdQA87eVjODkUOhfAQLnpW8rmehxfCZe9HXUiQ4em/hoeVe2eDSklsRcjytLLvZFvbE7ZrAyb8EdFVeKNgO4GPZcwOkKBAuWVhkifvhTSmfCrWKXmkZVFhVIwQijyL8P3b1sPatm3btm3btivRrrSHVYxGOBt1cLeuAwDsE1rGRYAjnm6JoTtdwtKuEY+hw2XcN8d4kW1bJyjk3llO3MiD5HKu3Lu4dOzM6irEC3IWnp6hWC2VAQzV6D1jWcFZKFp2lioNNosbJWb8hvhGnuSc09XTPXHHvPBlfu3rHlPuPUorsw1yj5t08DYdzwefYgCiv+ecR0dvq+sYt/IKncQSJDJkAUMF6eX1UfKyRTy7DuFCUjUg6h1iCkcr3VsHq2K/BrohIXgouuiSou28QkO9wxFjHWUwaOitCN03j2HJvXdrVrJcU25tp/7aEEBfj6MqAD2vaw9ynQHN89FLqvheFFqjMfE+RyRTrPo+UbVTifmgMGNhwessrFg2TaLH97zqwhQpcRgqYOliDGtCNYqqqlHQa5P0jd4bKHq1kjTvnIUhhXxG92BqdHq2jk6VNnVS1ghW0h86KJIpAp+hHo8BeslSFkTBoLjJvrn1XrzO5WvY7WI8S57XCMB0Qu/tMPZfGNdQ4l2SKBL6FcAYVhA6evCJ+p/PGevvXI/Lyjpf24sgmpJiWUNxyhwdycfEekzdYEiZSPFnDHOPZkC1csCMyfL7M3IE+oAz8UJ7iWHlcVQ+S62hwkXdwxCGGJxh/FMpDw2fyElP0670gtXBokGL9nrsiNUu3fdaw80lOE9BXF2mF1nEZnOxWnnYm+CiTS2f3DfV2Vn/BOLA2HTM9UGsg4UhBml4/Up7qFQ3hiwb10Kx5lKXlatIwds1+Cm/J4eL5BPZlhiUG67/aVvIFrv1BdxsgCfzhV6aQ0gThyUTsgtASwhJ2FZdGA
gYCaZ0CkpYTPMWGPOLM5/ubR22dEAqcxB6Tnbw8FwgFWVqrG3Q84TjUZzIz1ejtOw4gWhDgJFcFD9coBBr0iQWLvd1wGBkJGMhZGVFNvw2byKLs8P9Fw5Y3YnMUlGR8NpjVFFGipN71/UwXLwSgcIHaE4T0zHz1MwIFRcvSzWHyPiSBUHBWio+EJof1ROM+BtHkkdrO4QgwXmyeqHSwjbmDU+LCooUUKkFZaABPhudBGcbaEokecK1NjgYQoLgBKx0gZTf+PwHAQDF6gT7S7IY57H/RkZjuhcXKjMjwaMeJZKPIjnHdksoGiyi7OC9S8y7lBcXLo63eD+ZwcgvN03KubGbL4T5/LWeQ5kbh/n7l5id2XuYjs1JYeQNdso4vvtZfEYr77A443MlScb5bDyKdJQL6DkuhJTjfUgvqBgaWoVoID1tuWdsIcFt27Zt27ZtuyLtSntYX333bUzv9nhmFAOiNXM+dmYd+icMCjK7vdMaLt0u83tw2brNVQTEYteZJZtDSeu/9dn3YrGU2Q4Km6GedYsoANASECekqVVIVVTRMsDbz+GOo1V4Jt5BICyIAQYIaiCaJG8qXPSs5DrywpPrLYedLsFTgwGVzlUDGAk9Ov/t2vE2tT4AmtaZeDQ2xFp58r3cT/KOad0qp+EZ2J+dAav9eBGjiPigxeWgtgYS3JhyTTxgeQ0VA/aLtgPovbUswlcag575Wp0TCggQ7DqZIlyAh+T61738PN0ib3llaLn2dQ/LY6Bi1xwvXQsc78UnULHjnOowIgGjqHi/K4VKtB04UK3v0VFrUMgN1+opNHOfVt0i3Zs05QHF8bukckZdlpiO42+WZ/HYnfWpj4KoQgSFQjwr+SxHCI4wkielXAeA+XKKtJvQ90DHN5hwnfY9PBO69IQqH4UCppHEoW9EzUk1/yBGhBOvhRMAQDmuoQ9iuEGNWTi0riBPR7wqNA2UFVVbem7epYeTE7XWvSC/Ydum/LqLHpa0kGDpfD5KY0Jl41sNn2r9QNkgonAOalVgSg8LTAtYosXjR9Q2lYTI/IXm7LfqepRMNZKik731qUq1aCKGEOCcH6pcP0Xbeljbtm3btm3bdiXalfaw5sc9ftXfwwu39wEAu1Qxr8d+8B4YaHUwcJJgKhZP5mXknlVquWfETZtW+DzYub5fHi/JjymxIofBI8kLJZZCsaYlo3xI0WC3oDJ0v0pqEDklev2hBly2xDclGgKXrbjccsvjPjmNWj4lZiaxh4OiwpTUY88S4quuS97IJrJK8r5cSAm1ojpu/RCvSgrymRWZihN6D0XLbroqcHxIGq0IE/iL3iIQPWEvSZGMf8TCkuJtxe/q0QgNCS4SLC6KAivGThoG10qY5OnmhJ2kIJKTKTZ8boohrsfd8u9yHbqZ/E2+wCQAlkSGSspCBJ1KgxRGrHSFUl8kKHTNAieM7SzpUe4fTKCTWoGQOAwU1S+UVhBZ8lMfyRdlUWPEtIwTN+e1VpCRNKJ8vQGSHyqpAqNikvYz4mEFJEXzwDSD4AEx44PEHq1DOOf7IsmslYEicgHS5fXhs/A3I8V9xEnBVyXCbozbhYlUSKjhqW2pqFGo2iYps0tsLGc6bUrzSM9QXY5XbkqPyePfOYqTj4V18kb+XiXSRZGlvcj+IUuf4L0bAGOWa6p2InGmQIdqzAKqosuphqsQL871PdqWBBsynRw8QlKWj1donYfzHq5/eg/rSi9YIzXC6szhc/ZNAMBz11nRtOpgOHt2fGRWFYNsS1YXJK8ZBVyEaPIBZta25YvQJnH8fCDJb032uwQTKKwBlfH8w6IqUImDF8l+kgKM2kFL1crdUXyxjpt5mrjTIrwBfrwAIWSL53qeWc4+yl8eueZdfu6YCnsUI53u8DmMa2hef7eIA93PTwESRfyGlzpdk0PK0ZGyBV12TxeYjWv4mYdPN1L0BSwZhZ4KCrYNF1QA4vEUek5ykkrThICaMJYE2n3Gdiq1pO5XWHKR6y
X4rvUlhZBcckn6/AJzNINw1+8pz69b2x3AALleQ2ZUcU6uro8BsrVSBVhtkhSYY86f1iX2p/HZ2ToecWlP0RIiPaUQ8N50it0dwrS8yQDAsD966+EZvT92UVHi2XIfszoSF5w/AgDUxTTd/NjQ2NQBlrjUjo/HK4tZunuROVIeUFw4RElCBZ0Ek1NlX7eEiusjVCWjdjaQabiQh+kE6nqEB0HFk0IpKMpthEIW44wBx2eu+zbBWpJ3hHAZ/s1bTrpZb/mmRJbAME5yuDBnna7nUCKH6eV46nKeps8Md0H6vPco2P/TKQknGGMyjrXighFYVCUR5WR0WoumESiQZy6zOZYvdNv3KEKBVnD+p2hbSHDbtm3btm3brkS70h6Wqg1mGOGYAdY3z+Lq/9zOLuqKxdNoYlgNeMIdUgJCwV6A/YCLQfzcc5Ic/nUoDNgQ7My+L3DR21rP/8nFSvPfN7ho/QYXoEb78TcVxW9nd2F8/Psac7PUF7+Ix3Y+nJDn3AQJWgxeA8AcHqG1b6C8ihU/AXDIv6/tRqLL7jMfQLEf/S09iZ++XcK3xzxgvIJRu0IXoocl56qyPkhWn70oBgtEL2Nd781usFABnw5UaI+aQf5eeN5tSL8naoY2z4ciXNR5l3KLxBJXNiQRZemZoihSldw2+24952oT4cTjokUNROdwE5kilb/IUAH57W1eyzg7mpFCid/7ImqiCxXvwymHxuU+PWBMjVkZn6cikWm3OoRC9IjOCA0en01Q7zTpGmIroLRU4tUJUjxro+LFeHoLVUkvSq7BACVLjRhSnSeFgaFa8YzPYVxNEALfwCB6hYAXdQnxumAG95jvuGkCXE+vgMKbKtxNJUKkM31dw+zGUR2kQGffQo0II3I/DQ9F+M+1mX6g5K+JrIXHUCw1Qy02oTHrbRPpJocEL8KFQU53ibzj8t8I6eLXOLdcq/EBoaGnK+VgQsDOJM4CpYkMJqt9StuQ33p4WCH3WKGwIwlvdyxh0rQBwXVoVlsPa9u2bdu2bdt+g7Ur7WGNSoXJaIKSfs+b59GaO6jGCDNq8T3mSh8Ay6RI64ayfRJU3xS8zC1ksV5yTysPnMo5pKVyGuFi7EGOmQdi1z00l/0txdFMPYE/fB4AUFx/Jh5r/1mo3XjPbhnptztNj9VXPgsAWGBIlF0/byRdMFjKbRq4VExNA9jjDvu80cOqxC7pvmOWQg937kJRdMyuouVZHD+AO4lnLsoYfPc+oFlLns2JIuIxBVyM6QHRu+rWtl1QyIf8NkD7IdO+stFKPjqIv957PNzfKiNBWPbHnInKjfNJ+aFmbETpMpmrYll6a1EKSYE07s7ZS0QeD1xKzPbhYgxRvlu3OfNA/EhiAQq4zr/FU4zKCfSseJ2z3/ISJiXjVDxIY5vkRZcFVWG8w840lp2/dfACAOCsX8BRXd2dvQEAWLZLzJenPF5ELbTWKJmUbEqFlp7OvIveWblfpkTrng97YgwmjBFKMumkKlL8S1Q5jBkB4mHxKStnEYTCHgblBInVCfnC2RaKniFYZNE7BbNHb0oo6t5Dkarvp/vc1kFRFkIKdLq+AVgwVEsZe+8Gz0rGb7hMpvhaTb3Pv+XvPAUnvSsY5h4PbCxQu+5hKTWMARlwm3RRTefRn8WYpSihLH1AzcK4olZhNIYS9wmdUSk+nYod+ZD4KG0Xn8PZ6hRVW6NtnsbvvHjvV7KNygK6AmZkoklg9N0Hj/AslTMr5pgYAD2DuTYbEk8VGP0a36cFK2zYL5t48zyrdcZQDrnJw2hwcUIDgPrgJvy1m3G/G7E2kDm8g9BHCEefcAJ58RyzB28BAFbnJwA2S1AFZBNlNtBlPyrgYA/ATfl7P77k07svwD//Utz48sfifexMoWWKfhwnKXd+hGBZd+jBfQDA0dkC1CW4AI9umqDXJ/wcrs2fx/q95dCbBVATnmByPsZVQoxSpdUesTo0ADycxzOPtIZiwFnEnDwUkDL1Bf5xKD
jpq2Z4wuuklk00sU0Gy/rfwEVW57XsPgWa7aSOVQBqTiZzqeZ7p0KxImyjJVgO9CIxxRvvVnNMyAi7TYNkcj7GqY8GyD0yUkerE8zPIyRU78YFTmuVzqsR1Q4AoPHn6WJLlmgpKDFWu4ApiTqakfuyNqky8FQLq28EEceVxSlYm1iComodlBlYeiLX5BooTojBxmv2/TKx+ooZIVDlklKsrsa85BqBNZwST6tvE9lDkZ2oAqA4PoQBF3xIEFg+PjcRuqTlhtc6/Jdv89k29T7jKN+GbMFS6n32k2N3Hv1pXFhKPocSCkaEf6XOlXKp6rTPoMEk8hyGbUKiWrFS9/H5HHXXoWs3za6b2xYS3LZt27Zt27Yr0a60h6VKjbIsUEjFYRY2fHLyBJrh+Vu7cfWujj28ZO/z9yGES9Z5/nf+uR5A32AsI2Az9CZWQYnL7j9w2XKyyMgdtFonN+4g7MVAuGdBPTOu4BgE1aSG+r0d1AfR7lb0sHLrPL/fdQ1Dpwa4aZ/brpcVbtyKdN/ipY/Ejbefh37uuXi8mxGmVKMCfkGyBwv1qflj4LUvAwDuncbv5hiUOIqsE9f7N9c6zAkx+ffAZkJJvD961i5A09qetMwD2rEoT+J+kgKi1HA9D2gePgcNxWx/oemaUiFoEV0VVYOQnpN8hkzx4oLnt2Za57T8gfSBSyZ4nl4gJOM6DH932TEMRXftfryW0/Nj7BlW4E2dpdF0QsWnlqBfoiPhoKEW4PX9GR6ex7/PzqNHUaHBeBK9lZ29O/HedDHAZhoAvR9NFyAoQNODqcxArd9jfpOQVorKJKt8zOlJmxIq1fcRWMAlL0pKuniFRDmXTjVd9KjiTYkyxRk8r88dxOvHbDpA5EJ1RwElItQyGptF8s4ShT74jS7Tuje1kR+Ey5B2TrDYlBKT09pzdZn181w4n5AuMhclJ4Nceod6wLNkUUlvyigFxT2Vykc1G2FRFZCek88GtyC4ixVJb26OXcy2Hta2bdu2bdu2/cZrV9rDKqsKcAE9zQZL7Lu8O8Oj92LZgBEVmqcGKJh9LxacRbgkyX8hoZafFhmlmJ8XEj7ZLtBWM8sn7+R1RQe1YVuP4TcVr7/YO4Sv9uN1ibJ1cw5QQVuE5nRVotyJdndJa997d8nqcrhIXQeAKijsc49nWB97/9ZtmJe/OR7nxVfiOW7fBm5Hz8qNGRhfnECdRAJI//ZX43Ff/TTeOYmWeCPnzbwHUX1w4XIfbKIAb8Lpc2ckDxUlFQqv4NiZMxW9jLd253jxKH5PDQToAIx5hFP+9k4YkoMb+oAjhUukC20AI9mT0qcqU+leu758W8DluNymZNL8N4WoXm9ACEoA9YwFEr8xesbNqkUpFHdep3UBK9LChSDh+w7nLvaIbeLnu90C947ejQenF79ctni0jB7ztaT+rqHpD1S6Qk/ygxA6Fv0qlQ3ZoQTHmTLY3Zuxv+JxitJABX9hm1JlusEUE/Eu1YvREpfTwzNJ1r63UOIIOdLa+w6FZ+FG5lYYfwuqZpJwUv4ApJSHlvM6m1HX/XBNWaxGPt/Ps8qp6TmJAtjsRWhc9sQ0hrjr14rpSnekVAiFjeIA6+PIW8C3RBVSoMzCMxVCKhFojeT9Sl/lHqd49N4OHhbzz+FWDsXIou+e3sO60gtWrQxMqZLbnNzPqUH9bJy0Hx5FMsLtsUZ9xoGdlCB0Co3nENN6jDxks2JeG3OTBJI0uZT1bPT1wRhwWcHAYiAcSN6DqneH8qF88uH0AfQ0wqChkKvxKYhe80V2HpdEd3NIUIbLIQKeIURz8Nzd+N1zH0F4IUKB4e6z3PEa3MF+vL/TuEj5oyOoz38yfv/5XwIA3H94gmWSYwmpX9blYtb7I+5/GabIXyyVbdvEspKpvNMeY5ammJh4xKLSeGM/PvEbR/HXK4RUdoZxdpzC47oYOXxBV30/TDa8gWAdvFSm5aQ3Xy7SgnzhWhOTKn4aXGZDOpUFrdfuHxgC7Y
UaFFFSbaMAVMQ2n3xHfF7dUqOl9FFJUsK8XWFJNueqEQahxptPIoR7dz8udqdHD3A+j+9Q30YCxWxaZ5V1h7wjpWWBUVB8O4yLBtexbzCbStXjOKZd5zGZcFwUceEqzRIrynhpL4K3CoGsP03yhQs2Qb3UIoYqQuoxqVnmjIFKMK3k0g3VgBUepd+GXdJZWNNMF0VSrkjwHwyU5OaJio5HIn7IDB18tnghddGlxUtn24QHWWCYW/KcO/lbFqkLuaC5IchtFhAp4+F8GkngNt8/XaOczwFadISZ59qo5WA4MK/PhQGslndcQcNzW16ZvE+lwePGXnmMlf+6FqwtJLht27Zt27ZtV6JdaQ9Lq4CyUIkuKxXTbGFQPR89hcXbEQaYLzu4I5aKcAIJ+suF8jacJ3fvcy07yX3JSRPr8M56sHQd+sp/n2eoC618Z0LIryqhRZmCrro6nwMr+tczipZ2i2SClwyCl9ZjsWadxfyluHGXV/NsWWHnxZfjcV7+ULy2ux8CbkRvC9cihTlMZ/A+9oQ/O477vfE2wq/+RwDA8b1otZ54lbQcJaVgPc9L2ia67/tBKZssrU2QW4XBvSzpDdw81PjZvbjH5BeSyZhyRhru/6oO+B5WklUNoUETEqFAPF1tXXKJNBUU2jDAoBegyjWrO+CiBQ5cVO/Ix6VYyxVk/AKG/rlO1G7g7CBuO7sRvW/tTtG38TctvcymaXHaRlivoQVdVBXmZxFKf/O9XwUAOOew6KNnVZLMYVBAsZSL8wIbqSG/x/sMzmOlZrVCbSLqUbACsF+eQvP7yc5e3M81CM1FMM3DwUh+lTgyXT+ITMreziRFimxr8raTd+AVJF8riCd2fgbBOlIF3PF0SGEQEkEIyZOU72KO0eBpArigVpL7D2rD5/pYNtm2RLpQw9/53CLbckLSprEl750PuJQ3ulGBwyOqRAPJNQraJ89Z3pUoys3zSe6Vyt7nDNlJKXJ0PVvVY8cFqcP5VO1KL1gGCqbwKUFSxHB6Z3FaczJ5Pi4rj07OcZ1zv0jI6EWPdu2YOWyXu9apOq+wbbJFbADjMliHnzmEaDC49fnElbPg1q9hTFFbmCKx0kS+XDctPGEadFSxbuaDAjkryurWJuRGBningCmv4TYhjt2XXgRe+XD8zd24YLkbtxAmEa4pJpxogkPBOlzuSVyczKu/jJN3HwIATqTyKACVRnP8yKeTTcaBtE0LVx43zOsEvV+sICYQx74xTAjeuzFGqeJC/9aH4r0/8yWfSVTFz3Nv8PZJXJDLUeyDvXoXLeuRSVVV7R0KeaBMrOxxmTGqNkB9uZEj+/fIDJvM0BBGoMq2dfpi0mYF4J1vj1Dgso1xmlpXCcIT9ex512NBKNA6QltoY00pAD0HTN/N0XCMFSL6msVxBF5XSg/yPwZwpC0Ko7IctVB2P14PE8zdgwewxE3HO5EBu5w/gjSB3LT3qfS8rOba2cFIoDEU8hEi7E9TIBASTCqtGJKRNGmiCssoZQ6khSgoDUVZoqSz5LK4TIrTDJO2z79bH5BqWDg2vQcXYlP8W5irRmfzjEDC2X4mXIwJ81KHxS2DotfnqDy/MS2KDlCCN69j0bID4kLkBAZNJYzVJVmqkF2XDHSrQnQ2Qr6kv3/bQoLbtm3btm3bdiXalfawlDIodIAWa0kk7zuPhqt2dZPB99sVjs8iSFNfi16L+8IqMZLywPcmjyeXO5H91uG93MNKOU6ZVWVwOWCae1g5+ULgHzMScNBH9xtAWInoZkiVVYXvplyfzJqiGJQHxKKUfJ2xCniGlufhCy/EjR/8EIo78e9wM8J/ZrqHUFBJRDy81QrhMXOt3vhiPO7rn8cprfJFMh9VqpuzDoV9rfZrfb8OqVyoGr1GaJD9tNT44c2XVYmP3oge08dvxydfP2wwO+V+POASDl+iWsRdMRl1AXYrlk5AP2BM89YTztoEs+Qkkxz+3ZT3l+4ps8glKG+zbY1cFj2Z+jbw1r
dFJZRkcRuf4FBxtU+bFo5VmVuRGDIF2iayOo9O+XxhoSEyVywL4R0M4Tg5r9Y6jfNcgaGgV+Ncg/kqdnB6Xz1wRvWMO1LGpTFQZFEI3Ilgknkunp3OK9WKaxdcKiGUxkAoochKDGR8Bq03sHd8gr4UpZdQGki9DVEIgfeJXyGqG1DFEFrIuCjrzzWsjUv5TGM585yTtyXbsr8L+QxrSMOaQ5SPwXSbuYfF/XOVmcQgdIAnSUXYfRhn3qzU5gt5Xw8HFmJH8rDUQOiQVhqN3XGBXm89rG3btm3btm37DdautIdlvcJoXAGOhANSkE2p0BGfnxfRarrxgSkeP44EjOtUAOjvG+A4ru4Sy8pjIrmHlRkPl7blxR3XPbEuXPQGNpUnkZY/DMm1cYK7ew8lFmUqf9AlAVDN/BkYBSM0U5psZWZ+ied2AI3rN6hg8dyL8ae3n0e4Hj0rNYseSCjqpCTgTqP1rU9PgfuvxwN96VcAAKvjY6x47IHIohJVfF0bbb1tIqNkdTZTW8f785SDTftFz5RxCrocZfCYPRPv/XmWv/jc71riu/6+/ILHCMAJj1lJddtuhZKephieWgFegvjMU9lEYc6FejflwCRPXQ3j7ELh0DUP0gI45TjZo0tz9puvYb5DQVmSc2xwGLuKv437LToLK7p8SogxBSYs7XHeRO3HUTVJeTiGJJ7OWtTuIm1ZZSJ1SmuULA0ixvNiscT996KH5Vdx/JpC44QIgWK81cEnQofRRBe0yQbD0CGpf1NCYW7usykNSKFNiriqcoQgRQPlAjUSguEZo1QqAI7vl+Ro9X5wFdJxCwR5XxNR5OsnXYgTpzftry/Hv3MPK2/5eS95WBgyEfI4+/p76X26dViOI9R598pgdPCSj5ZpCeb5aAA9bnGEeYNFobBTOXRfRwzrSi9YhTGUZooDe0XWU+daOAajH7Iq6J07U5QfinkWDZl1ex/Yw7ufjIFpqfWUEyw2dWNauDDUc6oztzeXyJHjpQcVLk9im44NDG6/kpfCuwTJJTjG9XBM/vSKy0WpEQJHmrC6AjDh8UiOw7P7hygJBRouXOHgAIqkAQlQu66Bkczbs6P4+egePBXhw7vvAABaH1IS7jAZ+/RyXWBNvs9ClPdFDoe9X1s/XL6/hceIVyQBfhXqxHi782wUeX0zHOFXvv//y96fBluSXOeB4Oce693enpkvszIra99QBaCwECjuJEgCoNQUlxaH0zSJlDhSDwfgmASZSaJMJtNmQlOStUwmo6RWSw2RPWJrmRGHIsUmCYIEQRIbUWCB2KqAWnN9mfnWu8bqPj/O5x7x7ruVqOJIZkrpulnVfRk3boRHhHv4Od/5znfkCG/6TfmsdBMPdgtXVOQISVzw57FyHnd9wHE4sP33ou9d8y8au3ANPgH57IfAJllAm2fl8w++440woZAkNCGdulrxL9RR4WDMAIWH8ygyG2kEhosdXw02sAiYmEYRdQRh2EBCrtKuVV6WKgxDZAzYVxQ6rbXCzi1hIGJCaZ4kRc6VOek4zq3xjEcdOTZeI2rmID87vzCBrNf5bHhtPHPQQYLQsX/5e6hP163xxqdT5F7CSTF3TIgZLXYJAKigIWrAEUAWvz/mx3IbElxExFg0DtowYdsAbocm3Gd7TAGyEB0jhrxKP40Fakomlcz0NWkHNXPjChoVVlgXxzppbANHu20awLyaUxRrdBOF4GvO8KYtIcFlW7ZlW7ZluyPaHe1h2bDGrLBY7cjSnVC6KJ/s+wz7DRIGomSA80/I37eefQEAsPXGHupn5Fgjer0S3JYV38FZtgVttS0eF1huw4nzFGXT2q9EgyZELZNpHg6rWludNYS6hKHn5CoooA4kpwSAcvWa8tJL1igCgF2d+zySTSpUpPfdj/K0FA5Ra6s8v4WiJI+DAaEjGKoLBCPxYOvdHQQHQmHPSTzI1EmLrW03uYHWzkVrfz9v5R37rgWFeAiK3y2SwGnDtUBT+TfgPejODD
ARj7S/vQYAONPp4IvvkG0bY/n19tMKOc/kFFF0lWPFpwjQQ4E5UR4lCICQdPHCCYbipIXdphS3YZv5HK5w7jcAcC4ATpPrbv/UEwCAeqOL7phlRVxulq1RBjH7JT8IQoOVlHmLdCMjUyFO5bmn9KriEKiJP0S0roMghCJ7qFaNl5ZyUAc6RJEIcaEzlfOOEoOKJI+jiRx7Pah9DaoNevTXbOkhQaWp8hL1YByORDRFBdOG0MH7Ya2BdvWyHLxXWy+1pBy8Hqaw/NufS9UwnnNOUgi0p9b7OaWU91Y9xBhHAJVmTNZ4WK61x+XtxvmiVA20uuSmpKey14uh9PnzAo0TZNBActW8G4cGeY3QFGWm0AnqlRI1gXCXE1qbGsrlqLXh0Dk1jTbpInHz2VgkaTtT9Wu3pYe1bMu2bMu2bHdEu6M9rFleIikAS0psJ2HV1ypBh0txzIJjYRhg64wQCiqWGSjHU4T3U8HgOapgoAmcomXZz+PD7RiEt5rscbqy+87d5ADHS2q4Uyy2L2itlC7yWQEVA//OfDGVTwg1ORUZbOH3cz3rJwFWWF6ic44agWdOQw1YpsQJok6nPuCsK1qqOvIZ73b/Fj9vUhkAqJh4mdvj1+yud74FwDH68/x+t4vx/eGa8sEfF/cIdYDKl9aQ+7u9uoWD4hIA4PC9TtDVYONTchQXn7MGQCjfp9S3s80mT4lHDV/ipC2c3I6Bus/5+2bVibCAqFpwIK3IY8OjPYXhH38MAPDJJyXRO0GBoC9aeIbeQ1gm6NHDKhnHqWuDinG8nB7WLDtC5YRkyd0PVYiIOn9pV37bS7qerOCKWIZRCLDqskUAwyTtaSb3uaiByJGBeFGzrPDBjoSeWg2Fwt0Qzl0bpl5pRDtPRmvY+aCIrRs90RaH3KluuJiT1QpegdQRBbRqvHcGiaQcET1sd7yq8h6YpTA1otQnjCPIXVdOjOU2QWhR3GoRQtB+77j3iE/QVg0Boz2H2mQKP/bcbTEtj7R1bO+ptT0x5k8c3JK5npxa9fPdMI5etQSYnfBIG31qhxpd7M2lZWSlwbg0KMp5v/PV2x29YIVKIUKI2glxxnI5cRJixVPzWG8nSZB2ZICde+wBAMDeS88hfGIFADC7LISCOmuq0Lp7XSx4jR7LlVmwrT2V2rDO/A1XaB5kezA7NldVkUxRGqjKRUudunYJRdUL62a5yeGHpMvX6MaI+uuyH1mAdmUVmpWaqxnzTsoCOmY+DEdVoGLYjMP+QBYsfXATYO2rypUix8nWZkAtCijruc/WLTh23/y2144ctJ6Y9RiIgwRtZRC5Fyr7P+hsoZfcAAAMR7L4D/+o9rV6tp6Ro92KFHLOQlf5N4BCxm2pe+nZZgy063o5Mk2b6Tc/jrRtGIgORjEaWBF+CB68l/l13/ZGPP1NbwcAdItDOV7YQ0ZiQp/PMIkSpFqea13JZ9VasGY0cHQ9RWkde44LidZQZOutshJvkPZQGVnCndqHbb32Ah0gYMdLVm+e1TXOd2X+Xd+XbWVtMWGtrdSrvofIamcI8HmlEZA7i8DhYhFsOCfAawE3Em3rje/vqxtpWnuShB9lyvi3rJcdQrtxYasNlKu66xapOPVkJSeCbPxRbm+EtdXT2+1EDtfcb9xlUNN4Yb2s9oIVO2Zg26huGUXzsHRtm5yy4a7M9dWsj5rzRvm8OHhppUXX0W6eDMhbXxQWN/ZnXjbstbQlJLhsy7Zsy7Zsd0S7oz2sThohTJImz4lLfC9JkBA6CEKBQpI0QUTv4XRnDQAQdgJcu/UV+fGj4vZOv1KhN1LHjtfOn1mUG9S2ghbZCu0M9leDydA6Rzt/ovYaaga1ywmBoxLXMM6dJo24rkoE2qkQ8HhhBLUmHhYG4lEGUQwwZ6gaHgIAwqoCeI+cgoZFjJoEhWBfiBZmdAjFXLCqbu7RvEV5LMek9d0iD2veOmybt/N07mN/tyxU1dqm/TbltQHhvdbKEwjqTKxHFYXo1A
KlvXDE3D1jMfp+OeqMpZg7n7O46U7jUKBCoe8scJ4qgoZR7jk1fZ63Qtv3rX1t89WWu1rjDd9Mn+7bpdzL7z7yNsQuR6Yn6i2xCWAIfSliNHGZQFNOOWRl325de4WBSSkeZadOYZnPV7PveWDRJ5wcKoEGO0Ef16gvOKXFHUcKmp5dqJWn+Vf0zmcIYDdJn29d8HD/UPrD30ZBhMyNKbdfJwKOnLhvQ3hQoYPrmryppmyINGta52sVhdL0sIyfkAbKyzOc9GsciUNpLTlngC/pY8MY1sGXnOyLnmu7tVMi5jdqe9LDglUnUtG0srdFKdr5fJ4MZlqU8wW/8b81zW0Y7cn8D8cFco43Y5oX4bygdY3mt16ExDbzwKn92NLgYJijKm53p463pYe1bMu2bMu2bHdEe90e1sc+9jH8vb/39/D000/j+vXr+Pmf/3l87/d+r//+R3/0R/EzP/Mzx37z7ne/G7/yK7/i/72/v4+f+ImfwC/+4i9Ca40f+IEfwD/8h/8QfaorvObOhxqh0ohiEgTIw4zDBBET/jSDx3EUewLGap9eRtpF+HbZdrQrHtbN6XWMXxZ7JGH4qL3+t0kX87GVRVphbZLBohCMwskMd9HHozeg6OmYurEeHWW3Mt6SsS65t9ae/u7ophrGJ+V7j62YwrjYxfjQH9cwETnInXIDEIzF6rYjifOpfIK6ckFXOdwidfK26nS7NMK8Vdj2OhdZULcLRit7ctsxmi5MQ012MRFT+4dW0kMoJyPUFQk4B7L3UWDRyXk13yqf03sA87uK+7kAtEFNi9Mp4Fctf9AHpRd4g4v03iq0qs/SWzr3QIrgT74NAPD5Fan8nAQJNGnhmaNzFxU6nA+BU0GxMQJWWw7oVaVVgdxSW5NedVIlqEioLzmG8plGpcQLPbUiaRA2TmBm4i35KgAqQOBo3mhoz32mlRgDZMZVq4X/7ZhjL7Iu8TrAhJ5h5uJpsYLi/bRuwOvIe0zKiUW2kwuckjsWeDqqIcS0Z6WjDzhSk7LWuwrKB1BVq+Kw26S8WojzABeiLQsciTaFvd2jeQ+rhvUpLN4TU8cTh/Xcb7Agjlov8N4McALhqNG8l0ZHRFiGGQoSsMwx2ICnW0C0aF/zIgp+2eaIvYb2uhesyWSCN73pTfjTf/pP4/u///sX7vOe97wHH/rQh/y/k+R43csf/uEfxvXr1/HhD38YZVniT/2pP4U/+2f/LH7u537udfXFwsKqEiqQ47tJEoUd1NZl6nPxiROEToKGv++HCcJzEsm+8TZhWR0dlDjKhFyQ75ABNWtqX7UZNu2XJnAcAnOtPZAWMYLacJhnmmlPkPLlby2MDxTbFkZW+TLd/K0KfQ6EO54xBVBKcNvlWUErKKe9QoUQYxSsc/ld2N8YaBIszEQW9TqboebAXVRTxz3tRZBgOzi8KO/Evxbsyfm9yHD4Wk2heclpJ4hcVtAkYsxKeZFPjEJuJEdpJZTEk5u3cl91dWVbzr56N/A8E+/WWctpcMlg/2n3YmN1YzQ0Mc8CbAW825N3fr6Wrd+cIzQ3+M6H8OX1+wGwXAiAIIhQOVpiRcjPzlAal0PlCDQROlw4nHBvWRfIuGBFmQy2NIpRcN6gku/yeoayEPmqUf8cAGA7TrDJ3D1X2kfrAMrleoWxV2hxcHg/6SO3u7xHbhwbzKbM7XPsVKswZU5h6arWRgpOcdYvXCpo0U2bReIYpsV/+vlyrIw9FzS3wFg0b1pfNbhRl3EivzYEbOUMH7fAtV75qr0AHm+L3g/Hmj35pzd2bGPEunO9GqnJLdshTjJQYTwHaREaeTy8wR1mE3k24XiKGQ0ZFwqoWvPUvXeUauWPtm7LvIkQukn+2hHB179gvfe978V73/ve2+6TJAm2t7cXfvflL38Zv/Irv4Lf+73fw9veJhbjP/pH/wjf/d3fjb//9/8+zp0793q7tGzLtmzLtmz/DbT/LKSLj370ozh9+jTW19fx7d/+7fjbf/
tvY3NTAsaf+MQnsLa25hcrAPiO7/gOaK3xqU99Ct/3fd934nh5nvuieQAwHIqlHyQDzMwUMWntDv7rhgXiSGC/KSGuUltMKXgZjgXa2lq5C9qKx7G2xYz7x86imIh1uVvKeTYzi4Tu+PSAVhUaS8Y5Q234wW2bp3vPEywsGkky8kNQd4GCup9xKhZNgcrrdLpQq7W2CezTAlR1CEVliroWr0rXIeqp/B3MXBkSCxM4YgXPb2pY4zQWeI7CoKYenHYUdmu9Mero120Fi7b3M+9htWn8DTGiuR/tgprzhlelFmxcsGkeenVIivM4y6pCpDcAAJvpmvQ/TXEhFtLFyr2ybfjsl3DrQHo0GPBgMdDjGfMn5SG96Y+dx5VPi1e298/EO6/L5lm7AjEWABMI0HFKJjDew3LXkUNhk5Hpb/wuGccvf/M9WKOmW+DUKGyKmDBmVDtBXuO9As8ijlIoujqObBCr0FfWjRLxMpM6gnHqHCTkWGuhmas2HV2V+7MZY4tFGJOuXF2hAmh6e4U2ACnnQUQyyGYPE6dFqVwu1QwRB9J+RhUXa3yB0oowg0YFWzsFCY7zwDYUdyfZYAMoEjAa9KP0v21cAQtX/8JBuUopX9bEq7yYuuF2c+yrKgRiQvIcacZUUOo4GaE9B1Trc55wtIiEVeG4MgwgDqZ7hi5XTquGWBWak8dsk3faIsrHsgBa55zvq8uiCelqzcZTVOFxIpE2OKE3bOvGO/Pki5Z4r+9fqBBF6lWrkC9q/8lJF+95z3vwsz/7s/jIRz6Cn/qpn8Jv/dZv4b3vfS9qDpqdnR2cpiSQa2EYYmNjAzs7OwuP+cEPfhCrq6v+vwsXLvyn7vayLduyLduy/Rfe/pN7WD/0Qz/k/37iiSfwxje+Effffz8++tGP4l3vetcf6pg/+ZM/iQ984AP+38PhEBcuXMDZ1bM4qqcwVii2MbH9OOqiF4m1HIZO/WKCwogVt0crLIwTBLRBLL2N9EyM9FEpgJewpAAmY0T0ZIIv0eO5btGl+eO8ghQnMeUEzU2OQkC7YBitpUBZ72HpJg8RlHyD6bjfBqgd7ZZ2l1YNlbhkwNvWEQJa3c68MrZCMJF7pDpkFKSF75jz3OqylvIJQEtDrfbqF5b057oqvfW1KKbn2qvR2r9WgqH7rWsez1+A8X+tXOK2petiGbaoYSvxdfq86ba3jd6q9HL7jIydYjrBb48k7cERBWYzjcJZuIy1YFrBPiR/H7yHJT0+qdC7JfuNeP4EjUfq9AVDAEO/Ta7midjiwfcIIvGl73uTfJllSAbyQJ03Faja6+lZ3ZBICld+htZ+UMfoB+3oBGBUjZAkpU5NFZQqhSkY96IWZ61yX8Bxg57YepSiE/d4XtlvrxV0jKF90m+SUh2jk2L3FgekU2EvGw/hpVuStJ1XlVdwLwdy7ErlUE7Lz7skCnYuKqKUaSUMnxxkPpYFBTuXn2JN7Yu5Ku5n6iYrtqaXCVX7bZrSI7aqUJUt9XIcj8se68OCba/lOwNBNgB41fP5RPr5jJsax0skuW1l62937BNxZdU6Hi99Ns0wpqqCSxauWx5WO1Ro2hMPOFa81DnG3UShk6YotQV8caLbt//seVj33Xcftra28Pzzz+Nd73oXtre3cfPmzWP7VFWF/f39V417JUlygrgBANtr9yCo90G0q5Fh0hpxJJcW8kW+N5xhRjkeQ/hhJQ3Q56RNvUs/hF2RF/PqGwgR1V1EZDN1eMdUNwdLBiFxPIa6gQJDB6nBgKQtpF1AcyFyE89q5RMTQvc0OoDpkYmWcCKE1svgWHUyeGy56tU6hmGFVVcxNFQV1IwDYigLlq0rycUCpAIrAFTGV+e1Tj2krHzJUcMFy9R1M2DZ5UWyM8fagvy1drsdiWI+AP1qh1EL/m6RmDz5IjcaBQViS96jnuliPeUYI5x18dxpfP7KiwCA8ZDjozJQvF8BjzyuTaNWe0FGwOW1Co9k8rAVq2rc+MIEmSCHfg
FPlPJj5Y2UI6i2gc99/d3SL46NKAJyVj+e0UTSiFHxbeHYf3Vdo3L5dYTWIhN6lqB/ywU1AmKknST196DMKeFE1mkdZqhZV6Tn5JqgkfBlHfBejHWT2wSl/LwLaJitJnFTP4xGmyqAiH344ssiSD04E8FwFlU96XNpCyS8JuXmwLH8Oh4YVYtRwNfyosFomy+Uw6NsUzfNsQRtVXrpM1U1r3cHuVpeUV3lKHIHoc51qdVebYzfbmq0WcnOSPTGn7q9wdZenBwsXtlmm1dTwcmFQNsWZM9XwWSSY0J828GKxhz/229zkKE7oEXDbuTJkq5GL+mgVLeb/XP9es17/iHblStXsLe3h7NnzwIAnnrqKRweHuLpp5/2+/zGb/wGjDF4xzve8Z+7O8u2bMu2bMt2h7bX7WGNx2M8//zz/t8vvfQSnnnmGWxsbGBjYwN/42/8DfzAD/wAtre38cILL+Av/sW/iAceeADvfve7AQCPPvoo3vOe9+DP/Jk/g3/6T/8pyrLE+9//fvzQD/3Q62YIbm3dg7gc4DpXbofgqTpHTfhPKzEJLEKUhLuKWryNvAzQC2kF094oTYbZ9BAA0NmSvLDtrbMop3KrikC8w6RTQa3J8cqrpHjeBEIn6UdrvmNZ8RdAEgBph4HTuCFOOLFdl0pVBEBGy2/GXJRCFYhJpnDCnUpZT7cNHNFCBagI0xjdCt3S2jbOHVUaiGkpBo6m28roL5zrXwKVgwSpKFAet/x4g08Ieh6z/hbAeZ4U0Nq5baHOW55fy8Na1NoQJG8RuvUAKqMXncgzrtMueik9Tt6PXi/FCp/TK8TtqhIICYuYjNCVAaZObSF2QekQ9Ttl7H3nO98AAEiPND7xazJ3EoI1s1zhs//ndQDA/04veV8Z/N/Z582+eH25gjd1Z6WMXw2FIBAUIAocjFXD0BtxSibi98tzj8jsiRF4erwhLDBGgngmY2ZKOYQwADQVJXocT6GxiAgRxOyzDhQcoBREcSMgSzM9jZT3tkLChPmk8Gb5iy+/BAB4aPMCnB2tOzI3M1UjZQ6Xq7YNG8NR3b1pb41/2H7sWNsahw71wEnI0FgYl3/lc4ysV3S1LlnIwsOdTu+zrEtUzitrVc91zmyblDDvEVmc9Lza29q6gK4kkfusW5fxauiCr4Bum23zBWrb52sfx1c25pdlblFErTQAUE1jTktw0fUqKE/ySTmeNtIIwUDB5LfzE4+31+1hfeYzn8GTTz6JJ598EgDwgQ98AE8++ST+2l/7awiCAH/wB3+A7/me78FDDz2EH/uxH8Nb3/pW/PZv//YxSO9f/at/hUceeQTvete78N3f/d34xm/8Rvyzf/bPXm9Xlm3Zlm3Zlu2/ofa6Paxv/dZvbQUvT7Zf/dVf/ZrH2NjYeN1Jwota2ImxGl/AjJ5Tj9ZtaGeo6kMAjTU6zXKM6V2UTJgcFRbdlNn5NCNKnfnEy82+rOfdXg2syq26vi+fql5BtCbHqRjzKrWFKRw1lsdLFfqbgsWvnVvBgBX3Itasz/I9HB5JXGniYmEzIKd5ljM4XBsDWx23xaxWXnTaU8q18nUorFMZqCqUxhV4FAs10iEcZ0C5IDgUTDuaCsay6GE5U6o2xwtUwv/6eF/aSdG+z62/zat9geNe16spBLiv3NeLKMLQDbXWWcZ5AEzpzj579VkAwPDWS9g4IxTyu9bXAQC1tlBUgVAu8lxrr7O3P5Uzb0xzTBkipIwjohRe8X50KDZtut7BQ39UCD1Xr0owawMhLn9OvPbRDflxN9KYFfKbAUuFxFZhxrgMQ4mY1jVCmsHeYzclCiqoh0593OaAlt/G9LDSKELFRF/nenaqFFks28YstlhrIGD5+jh2LJ0A2nlTvPklamh6YGkQeeu8cgm6YYC1rrsWmQ/X92c+AXV8JC7sXj5C6ejljCmOsgx9Q6IU0y6UChrPigNZnrP7u6EDuQRfr1ah0LgBnjGg/N+udIZp1ZP3HpaxjWfFgV
UZ2/Sh9W68XQn6r9XaJAn36ZP0W+7Q14p7zauotD2s+e/m++oVfJwjWwGGnWicWuWv2W+zLW/LH6shl8XOS4s0VuMIxTxD4zbtjha/rcoZ0v46Vos1AEBMPC7WfUwmMoAmpWTpW9MoHZR8ac9mJSxlmhyhwZgcNpRHmgyYsxIAGeGYQ74tVBxifVV+mzIXZRKNUIxlvw6zb9J+H3qTTIuNLtSa/J3zBXLz5j7GhGRSlo9IoKCprWGnAlnN9ixSV0vA+9kaTV0fV6U18OUPNEur6HLaSCgR6ot05geYqvjy0YGXXPLyM1VTc6uu3QQ9qdgQ4KRg67HWgvzmF5g2/IfWttu11/O9hz78JIqQT5hmcVnguCyNML4lOVQ31+X+3Xf2bvQ7hJRr4frV1iDlhe4fyX25nNyCIk5TEi5MdYCAZTk8HJMBlqQGR4LQSPDoXfKMf29fXtqhMhiS5DFYo6TSpEJNkdDckSps6Rl1Ue2qAVeonKSRdYQMDaOdLJL0LwkjlGRIajICO0UHGXPRkplsy9GS4eG5QmgfxXeVrGMEiLlgWdWMn8K6nDCD7XVhPvZICrp6bd8z3pzoynA4hSFhKusE3HaIDUpohYqWgYpga8fqc0aWgnvazcLREsRtsQSdYoZnI1jdklxqvW3dIuY+ywK1cbJkJLpAwRi3YDlSyMm6U3bub3eK+bG8yAg8tmC1vjvW1bnfV80a7JtFM0/buVeLjD1P7miTKub1gVsn9oZh+0Jb8z4kJJyy5HA/6SDpnQOCCsArC3rw6n1atmVbtmVbtmX7L7rd0R5WkQ+BJIAm9OUDykEB2xEv6nAk3PMgVkAm5sGUHlZRhghd5VRH07UKoVOZIBRWV4fYpwDsjGUOTAEkoXhb/XWxkDfftILEQXTUcYtNByFhpSpROCDd99YtgQHzqINo6x45NjnlhTU4mgh8uTOV853Ncmx0SWGnRayt9aUCtHaWZUPvN6yIam0HzjQqmS9SlBViB68wvlgb6/GGwLj98xZUIudaVDohwHHFD9caSrl8ti3BNjljnohh2j9utfnfmtb5Fh3PmsaIbgLxBeKupFB04n35TqeoSBfPh/RG7tI4f1oKXn7p2Vv+2hwy5jyeg6MCK1TCcFDYoAt0CLUVY3rvgUJEhkOXxTOjoIMHHloFAHzq9wUmrDRw44aMt9WQnlhUed3II8KTtVUoOJYrwmjdJPACrM7T6kY1ZlQ96QUDHg+IuJ9hYkwUh413TlKFDWrUhMMO6EWcNRba5XqlJCpUgObY7gQBKldktHBoxhHO9uVaUpYpiTQwchb7TPZbmdZIeG+QyuehGmFGGHQQuazHxg9x3k2gygaW4neqlShknSfWqiR8DHxzVOyW2zIPfVfGoqaHaxxcqHRDsXfahNo23mPrTPNeUFsdZ34OAMdRiHkChcbx+TDn1BzzptrfVXOXrlue2CKilMtBVDWaEieeUNKo3nhntQbKgmOQsHQSN4VwY8LXUX8D955+EtmsAPApvJa29LCWbdmWbdmW7Y5od7SHNZzeQhzkPis/CiQIFKoQI3oSw1wsyxuTmyinYp1lVCyv6g58+XRGj+Mo8omUmpZUUVYYj+mVZUw+zTV6tBQ07ZhuHGFtXWJU/b5YzSpI4cwRrTVGh9KfGc30/pkLOLUuUlNJVzzEmcmwc+Vl6f81MS3H4wgVLevQKdFr7dXBnTKFgvHbvOR7FfmAeEk7LS+tj3u5BFgVhr7on3Gki7r2npXjY7Sg7GPtVXM0W/sv+u1Cj+w2x5r/e95WVnNbPY7vEq7XCsQd8ZyCgLqUHcDWjC+5AphRiJWBeAXultYZoCTMA8rpIQmak/u4fmJhqS7h7p9QrJmoTjp6v9fDI489AACIUtHqq4saN3cYz7JMwIVBRJJERE+wrJu0htKVe9EWVjuP2CnvK+REC8aljP04Vl5tJeJrII1iBPS23HErBZRMOr++L17mIOkgpiZ/wsKh6dopaKZTKBV45KImKy
gMFTY5JzRjZyZViBiSci+iYpIh7cuzccUmD+ocY7roG5wDus78mD4WqHEuQCs/ojUS5A9lvRfqY1m28uwCF8syxniSgRv7pbWCRKCJa4n34+aL8+ZOxrDaf7c/b+d1HSNQtJJ/3Tm854fGw2kffH7+2dbvfTGWRXHlFqHDk6jqFsrivMf2xbVCiYZzrV3+JODzWlmV8jhvuedb8cT9b8Vk8tpULoA7fMHaOXwJidmApVhlSfmUtc4Wjo6kbMH1HSlpsF9chYZAESGD4crGqGsnOeECgjHS2L1hKMNU1ZiOXSa77J9lBhlzqjoMEsdRiNCpRhCOqaoMmguMgcLuSF5EBV8mPa0xGKyy3wLXTHAIxZy0tQ15wZi9CsUeJXkcK0ep1huSH1b58g1h5FiCJUCWleI9KnODwGX5czaGdZPHUjs2UI0TMkwVTk689qLThuvmmUivFghe5OrPT+QFsdxXXSQXBbpdeYa430X3lDzH6WGX32WIw/DYfgVqrBPGClKyNccWCe95l0QLnbaYWy6XztSolCMFtF8HhFw5lRMVY+Muef59ShGND2rs78r4nU1dPqFCaOUF3mE5HdSVh5sdEaOoSw/nunysvK7hqoYElYynXp0g4HFCOIMv8vW3/Iuytsho/O0cyoJlrEE+I7lkXV4273h8y8OICGME1BbLOlSo0ECwejcAYGNNDMs0/QMUSmDQ0omt5AXSRJ6JY6/uo8QhjcLtmtWRTdksTj4HCvDiuPN6QZg3bFT7H7DGNizcVmkSJ4dUeWagReWNumaUuTFj23geW3scz5f7aEPki4gY7anilUL4GbXmUomTpIz2YtiGNt3C5/ZXrWP6fllvy6MlBnKMCeh+6ysJuwXVeOKpL2dThxE2z0oJp69/43sAAE889EaEUeSlxl5LW0KCy7Zsy7Zsy3ZHtDvaw7p1cAlrgUZFX7jIxIuYjKYYToTUcHn3EAAwrPYw6Il5cLpDTEcZWMIrvjJuqBC6wCk9jxy1T7B3pRGyfIqsdCKkYqkGNoSCIzoQgqssFOGRWZajJj23ctny2sIxNZIeyQ9hD6dJ6d3Ipf/V1i1kVFHtuGR/BTRVV52+XWMZOehF1TUsLVOlxSKurEHhkvdblqILsLZhQBekbedvLPJ+FilTzNtOpgVjtL2kNvThjrvI8lzkUbWhlKY1cI37vYNC4k6A/qaMgXTKnKCjfUT0mFUlR5plOS6wLM5qR+7lbp2hpMnpENdOAgz5TJwEYzUzWOmJTRyEDirT/gpd2gJKi2RVjn2ayirTgwwlCULXrh8CAM5d2ERsZEwlTug4CKDpJbltR/kRDD0FS8h3klWgc4Ywkj7FdYkwIKFDORgw9H87z6IuDVjT01cMvr57AztDwtNHcq+efOitnj4eaiDmPSpZ5idIH0R291MAgLseE0t78OlP4ejm78vBXRZFXaNPEpBhKslQVzhg7mROBZvEFAjc68tBedZ6r8sL2LY51t49aNW/gNvUyrmiy1Ab68uPeJShhTi4yiMK1g++toM1702186JuN87blJBjfWyRLdy/28een0OLztemx/vj4uQcUmg8q3ZaiC9N0oJRPOzoENcaHoWo6WrddfEJ/Mlv/n4AwNveLBB4ripAhdDla1+Glh7Wsi3bsi3bst0R7Y72sPLZCHFlvDDzaColCqZmgv2huCP7I/mcFgZhJHEBRRVoY6wvVBe0bBpXssEl2c5Kg7o8How2FhhPGNdap1VqtFdKr1gIr0YNU8p+WVYgSZmYGa0BALrJwHs4TuE6DiOsMrluoyOWZfBQB9Mbsm3NyQEGYYMl+1hWE1B28QgbdxE4Dyt3Py6RzyUBGjQ4tGsFjuuQARLcXWSl2TlsHAonirO9Gq3de2etY8x7cYta+/Dt4x6Lcc3F4bXS6A3kOXXp1WKssX8o8cWYHnNVFegMJBZzdk08hRdM5i3rwGs7NmUt6ORDFUA2cAUy6bHD+sTSLBdPN1UJ8ly23Xu3qGC89P
yuv5jnXpTE5vPnt3zsshs6j155j9+5xrNqhhmJFdYVY5wWnoKvqJ3Z6Rg0hUAVDxFC8XgVAx2mBgon3O8SZvMKo4xFUw3Hkw1gaccHYYSa3lsyuF/u39f9GXzjGyWG9ejmXQCAN77t7bj8+/SwvNsQY42kp5xx4JnW2NUsIumU4+sCgXVyb83YbzJbeb2mSRz2Y6WlzqDa3pl3wOSPyhiULszrFGesOkE4aHsj/hz2VbybuW0GrRIg7ctgaxdyDB3S0ZqPbtcSDemiPV/bWoSuzaMer5Y43BYBAeQ94QkbLdjDhxJb0MqMnbnr/scBAP/je/+v+M63PAYAOIxZ9WGmEIUpguB2M/x4u6MXrOF4BltlCKnocEQ2U4ACGf8uCWOoWqHkzCsp11SUFgUncocLRFVVfsTkhGVyqxC4LCOHlSmFCdUSDllaIO0kXnkgYEBeBQo1MYQoTRHE8uJbJWMqSZQPLjvpHR1FiFLZltp1AEB4bwz1JYE5zT730wqKzCytmRdj64aH4eDCMBKtIAABc9VUEHgJmnbNHDM38doECz/ZWhOqDf/NkynaDK12xv4iWSf3vqpb518UeL5dmydkzH/hvw9DRKwa60pnxHEHAceMM2KqfOqry57ZYO5QeLMFzVAFw0z9G4FpRygqjREFdsuGJog4lldDXsqLvuqEyErp8ZnTQkbQofbVnb/yrCxY3/WOx6E4lgfMzYqVxswFrAnlpXGKLCcc7fLASosyY/XpRPq0MrC+XpZ7ExooX6bG8C1UtOrG1JRAs5VCJ3RlSFpEBUf2sQFC/n7l7BMAgD/6+Jtw99oG+yr37Q0XH8FH5tC6wHQx4II8ocTUTAfYD+XvIYWrt8IKlvPZcpHVpmhkgjxXtFmw3IOztoEMbfvTNgsVIMounnTB45XGHlsw5BytUhytsbaQNDT3cj8GIbYGepu4BBzPuXK7ha1tFZpFpL0ozs+hRTC9WrCtTc5Qrc44ERUPgTZpbr7PpQqwefY+AMCf/a7vAQB83RsvYj9hWRzjDHOFONYoo9cO9C0hwWVbtmVbtmW7I9od7WFd3xvi7rNHsFScGFdifUklS4HeVmOxjMeoAVqoGXNSqlojrwUmTAkvqKpGRlWAunB5WLGHTxx1vq4sIpobU6bhz7opwoxeXkw6tGksj0Gnh35fgvgrHemXCuD98aNaoJwks1in4GjYk0e0emYFs9MsQPkVwoQqgGa/jMvHsjW0cdYlzxxoIKJuHD0trY+azHl+tr2atoXn8zUW7NcO9M5vW+QZLfKc2lbmogB1W/Ns3sLSrf0WBapVC5b0sKMpvRXd6ThvpIZinpPKxtwvgKau5Pq6eMSnBhEmJANkPEY6aeBGZ4HmucXB0Hkk1GI0ta8aXeXyDKezEoru2akNetOB8tb29cuHcl+U9bl2EYkWWiufm+eC250gxlBR6QQONdCoCnpgjhxSW8/OUWjGSwP7ksSjU8ShU4Yh/R0Wm10Z30e5060MfEFQmBqKupfJplzT3XEPiCkHEsrFnd44j9r1gd6Ntgop586EMPbEWuxTJ3TE+We7gPKYNvxnQ2unGgXMMQ/B7ziHdykD73HalqdVezULQoI6QA1XyocqH7bxsNrKFPPjvFoApbfnV3seuN1069/z6EcbEmzPl0XzqT2vfNkTd2x7cn+NFnnLfbfAPWvP+9ARyZIOfvDrvg0A8NQbxdMaYoqo6By7ukgpmMr4nM/X0pYe1rIt27It27LdEe2O9rD2dwyuntlBhzT1o6l4KPVqhFMrYuucztYAAFGuYC154dbFsoCydjpkTj07RJbTas3FEpjMSmRj+Xs0ke+UAkIGuivquWV5jpgWbJaTpKGUJ3SkSYpBV6zMXirB7yCMUUVM8KSWYKQNEpcpTmtU9UNYUrFzJdcZIvC09rY156xVX1YhUlAuUM/Ygw4COHPJW1rqJB7dDty6tsiaW2TVtY/jWoWTHlEb719I953ny7eaXfC3wVxp7vm+lA
YjjpWDsSSu1ipHyARwZ/EGgUXE57C+Kh7Wxa0Bvny0zwPSS0oAxz0IWp3IZvKPkvEoA6AgAadksP9gOELJ/foD8bq7cYwJYzUHQyGCXLp2E5pJwmfXpC9RGCJinMd6DnLtS2s0McWoUWdgoMNW8CXtXbFRGwWIeOdWKecxSzoIuu634k2ldYBel31gAUwdNISNREdQtYzv8dWX5FatDVCEh3KPSOm/q38RJUlMPd4jHQaI6anN3L3SwD5v7Iiq8/VaADXhBeYNqcWVvG/U2G2LCOFclIYk4wkOthmpXpm95bE576u28AVX23Nl3olbNC7blQoWeUntbYuONx9LnqfQz5cNaR+7Hety3uDtkBCD5pa4tBZjmxC+a6UFXBGJLJR7dc/2XXjTG4RsMwpkHId135NyXMy0QoA4CL0Cxmtpd/SCde2yxdr6AVbW5W2hUpeDkSJO5HFsbKwBAJKsi2kmlxvF8rKKjPb1gXJPhVEoyIAqKLi5d2QwPpAXgysNFWnt6wO5hSEvMhTM4amZvDKrDHqUXIrCCH1m8SeJLJCD/hpqMqAyJVBUJ6q8ioIrmVLHBmpdfpNxYesHsahdwK9bsEo1ZRIczKIUFyggiOQYUZQgZyC7arEuFhEX5uKsxxawRZCE/15hPt0FBRazBN1AXHS89jHm+3cMenEIqD15jnbLyin2rl4GAAx3pBZVoSvEZLY59YgojhGS9Rkx+e3us2dwrZZFxL3Y4sCipKpFFLg+WCcu0qiGGIOMgypjPtF0fISQdZ+CWMbxyqCPKQkgLtfvq89fw+opMXbWKNzbDzVCXn3OZ62s9qw/d0OCMEUUN98DgDbaEw9CPtnEBoho2DgB2n6UgsRCFIUsWD2dIuFGy4WtqoCQYstGKdhA5tiNS78DANgd/i6slX23eo8AAC6spFCEMqccXCtKIWUfZrWDIg2GkYzCG2S1lGsBLI1Ie1M+NRRql+fmFhOl4PO0wGYbmNC2Vi4nFlx6JRmgdiWJ+KaurfVQa9vwcuOsnce0aLGYh/8W5U8tggbb7QTCieNkpvY8XET8WLRtfo4rANaJALX0lVR70kIEjN2PVntixDy2eRbP770oP1mTnKtT6yv+Ny7XLwwjhKFGGC6apYvbEhJctmVbtmVbtjui3dEe1vAQeOFSiYu0KbbPigW60usjoohVtOLUBrpItASAx9RGm9rMQzMuuSIMlefYkpWMgwPAsBpw5KjiMAgYfFcUKC2LAgUL9FWJC7g3bkuURlhdoZ4hhXPXVrsw1Dhk7UekUY4e9Qw7jmYcWlQ9+bui1wVtYWj9uDysFrLR0rCzgNM4JLFARzHodIG341WFaReRKeYJEbVq8kPaEEhThqDZfz6rvm01Va3PReKh/tpa37Vj6a6fbZttHnLJDsc4vOIK7rn7ZoGU5ThY1kKHA6Sd0wCA9Y708v71GnuFwIgzdSD7JVlTlJmfaahQOjIDx1jQClFnJA8cjCvUcjicZn7g6moPN2/eOna9V67sYmVLnt0+z2/Cjte/q2kGF6XxXot23lQYehjc5RFaA6+J57xCoxTihHqFzIVKk67XIQxTVz4kQqJJsGBZnroyCFy5EjTpAsVYJtEv/Ma/xNc/LEoH8fkLsp+dwrLSt5uGVVEiISRoSKrIiwp77P+tQG7wtFdjnfmPuNVU+1Ut6rpcZ93+Bz8FPgTgS4QoNGPUeVXG1qhdAUp+VtY0Y7+VNzVv+S+CuV/Nw1qUm9U+TvvTd5b9XQTdL4LVj5Eu5o65CJa0gNcDbB/D30InGxnCpxU9ekaea4YcL74kCMaFbSFd2A3jUyZCEsCiIBC2y3yy5m3a0sNatmVbtmVbtjui3dEeVpQCk0Og3BaboUPcfWOwgg69kIIU9a7towPBWCeHEivaHx9hlokluEIrIYoUtLMUXUnyKRDSNNJhY8FZn2Ao5ypzjZzHK4j7h0r5oHa3l2B9jQm8zO7udQ0qUuqr3Fm3Fv1EzpNGxM5VBdPlea
jqXU8MrPP4VONTeMuJlqU28Hxb5cpHBIm3tl2ludqe9GbatNU2xXbewypafx/D2J3V2tq2SAHAWXGlC/Tak95egJMxqXb/XGv3r51c6e/LzCIsSNtOqAZSzZDTit7oisfb76wi1JLo3ac3fTXaxb0XxJI8HDs18Sug/B0qVy4jUihdCQsXX9LwivCWHogpges3pKxIeJbe3FrfB/tdjOra9UM88Rb5/igX0sesihHVTmGecZ+qRkmyQuASarWF0U6NRfpSVAUqJkjXhp3XCqErYcL0hyTuoaYXVNPLTMMQkRMnhNNJVM2z0QESR6nnxt97+hMYkWp+7+gZAMCly88hoWwI+U2YTsZ+rGZVxv5VGLOv10m6GCY1zq7QO2Z8C0XgH7abD8aYpl9tKXWfues8LduMUaeDWNeo6iYlARCqe+WTk3mu5qi3jRXVwAml9FcjMC0a04v+0d4+H8NaRLpoe1i34TKh1oDhI3YedjuvxOloJmmEu3qiXNLpy5i5OdxDnz86OBBo6t67akRaYpgRSRZhqGDq0nu5r6Xd0QvW+hoQA6hnfAlwewwgohrArHT0rUMkqSgJdKk2Mdo7QjVjRn/HwQbaIwhEbRDUJ5k1NYDCsZPI8tMIMJ0x14svjU4SIUzk6fb7CVa78sRjx4zRGSxHccJNcVyjQ2jG1TZCXSPne8VuyLVVt2rPQGykmVTj1vuaVsaTLpwigg5iKCeiSpUMW5/MuWq39uCfhzvq1gKzKPDcXuBMC0px+7nJ087Wvx1Q4PJirLUn4Iz28drH8IH2QCFwOBiZTYGNoDgdNMWI1+NNVEdybz526UsAgJ2jV/Cdd4l6wx4XnYP6Cli1wyOvkTIIyF4NNPNPggQB73DacXXYapCciD1WmT61dc5LLVlCYQeTEqMD+d5VST4wFgFrwVm+5SsboOZN7MbMmwoVakcoqEn2qCbIKzHgYr7wdWUQU4VipSc1qaaTGUoqFLhaaYEKkFDlI/HXpj3rENCoCRmGVI2eZAWuXpNAfDSSkj/VZALN+69YZ86WOfKCKjVOc81YVAzMXyMrchQDRZ9jP2rqfs0nK1ljENCoqzkfjDEtAVnrv2u4Stzf1KhdnTEniGttIwzMc4TN6Y6N94Wki7n3yKsZcMDx/Ra1NpMPtjEOFkkztdVs5lubdOEWhFIBlSsNwsXJho10WxLJH/2ojy0KRN8Y7vHEGlOycPcPZVuWn0PakfGmafzZWsHUJ5mHt2tLSHDZlm3Zlm3Z7oh2R3tYGwMIcYC87NlILLPh9MgXkJsVQkGG2UcUOQ0r55HFoOAAdCv47nJVXPkN2FZOAk2eum6qjA6Y8W6CCDNiGyMKmqaJQrdLqzUOEJKSvtITgkheTr1FVFmxRqOwRhS4DHtSz+sChRYLO6YgaxntI7VOZNfXFPFWucvYN3XlA/GuDIkKAg8PtjPf540diyZ7/3bB3DaZog0h2tb37hi25UUBAifquf3acEYbwmiQT+v/vyhgfNzbOg4khp0OKub1FISpwiBAwpSEiK5uVytcJv39Nz/7DADgQr+L8gyLPxZOp7IRvXXlWQILrA/kuXdYmjgMUgR8himp84inWCGUYmJ5rqfPnkbSFRSgmog3VRQ1btLD2l4X72U8K3zeUcYO5DbCRk9o7xFTNiKtkRDTdrBcnk+RF1P2mWPIGqSReIX39uX8w8NdTJRYyQWfToIGio4clFeXWOFYTHWMDqnpIVM16grYXJWipJuDAa+3RNCR+2pcPlld4WDocuNYBLW2cAyhXaYDVHEHFU3/OCLkauqWbiBTOgBYeoZuPijYZhA68gVMM378H8pDqO7TWOOh9vYc8HC4mjsGjntOi+C/RTRzzP09TyTy2+zxf7Pb7Otro8rXOCl+W6omr0s7GdXQPwaksYyts1t3YZelnAw1VfMgxQpXlp29HQDAzf270FsRT6xPQlyR10AQwbyOPKylh7Vsy7Zsy7Zsd0S7oz2sflcouQUtqMNdwU
1vru1hVolll5ey+iexRhpIkmgYyEofhMov2TXjSJWpUNBLmrlSEYvy2lSDibMmI4K6RknPauCKI1YGIcs8JBVgjXiBLsZmbYVhLhZlxrLjpqOgnFoBg9KVmmFayX4BI/xTXaHvSkS4uFArcuSiWcbU0L7yGi1GaIBxBqXIFGhy0duXudDrmsfG2x5WW0/NtKw94HjQt33cRQHj+aaAE4nI7XO0rc3j3ha9MccxSbrYP5JxUVOlPxykvqzINj1YVYzwkZc+DwCY7YkrvlsZvLIjVuNeLuMpSwHm+XqMPwqADv/RZ2JwGnSQMRjtlFU6cYEeVTTWBg/L+U+dxelzdwMArj2/zxtUYkiVla3axZwscupPTiei4jIpYqzEVEV31xuG6ND1yznOZ7MphlN6P+4GwqLLMir3d8RL+9hOB7tEMPLZlPcxREoV+4SEDVMDITUMaw0fHy3Zh7jbxypp8QOqZBThDCnvdcZrgwUOjw6l+6v0ksIAAUdVTs+4m6wizjj2XYFMW6J2ydK+zIjxaSrtKiSqRV0HhOLvgv9tKrv/3s1DU52IB7XnQzsu6wgW7bjVfOzXKpws34OT498u2CbX2dpn7jiLftOec14jcMFxy9aB3TOsAkAzdrWWyPgojMFoynHBC451AHJ8sLMjhKJLO/fi/rtJYKF+a40UnUgjfh1q7Xf0ghWmAXRQI3W1pzJ5YtevHOEwkzvWoUKBWgtQKYEHdeDUJgKfGe8CraYEMkYvs6kcLwT8yHAPItGNQGnOETkd5XDFjHvcUVehr811WE1xV7km50NT/mT36AoA4IDwZD9ZQR4IU62whwCAUWlxwGA0HHRYZViFg154U1rqEm14xPi6AA76jLyvb1XzwvLkh9be8+u1RQMXtCftPGHjWG4Ljn82PTm+rZ175YZxe1KeYPyhud52Pxcyt7g4ZZMJCs3n7pguFohIAFhlPtQ4P8Lvv/CKHI+dOZiWXtbJVaSeVQoZB0Pf9bkLdAibxMxV6oYJhqRBKp5ro9uF1fKr86ekEu92fwv33StlRa4+/wdyX4raGzShIw8o43OpAr6oj2YTnGJCnHusoQ4BkiN6hG32p1OMxvKiSUDiRKw83FW56tdBHyuFLDA3+ITrIG5ElEOWXYkSKPe91oiceC4/0zTAgPer5xaYQCHpy4Sx9sA9BozHcn+7q04stUZI2LLP0hTrYRcxB1dB+L+yu1D83nMvjPFGmq9CXNVwiV9thpoDBZ3ZVtW2JYjL+9Iy6ha96NtGW5tB667tBBHDHofV54/bsiVO5DRGqjW27cnfLyJ0VGhUcdqlUNwcc4XVoQHaA5i6VSIAUhoiWytSu+3y0XUvvOxKsKyEITQXL6catLOzg4OxLFRRKJBwFFt0wxh12K7Wdfu2hASXbdmWbdmW7Y5od7SHFUcpwqBs1B7ou2ZHgCOumnNipa1UgLM9jBaPJ4ktjCFEQ/jG1CEKmhauwGGkgbAn2waE4wIdoCJ1/WAsn0dTg5WBmC0pcxiSQKOayveXb+zj3ClxpRUpzNk0x/BILI8vXHmBfXkAg0AC1GUlv71yUGD/iEHygnpvqFD4PBjmE0F5S8yrONgGvvRinkrBME9HaUdCKW4bAG57LSeULlp/t62geeux7U3Nkyra+y9CYevWsduB5fkgctuSPdb447KYoU55TJqriQrQYcpBSv/xc9dfRskinUxFQgGLo4zjh/lrRxPrYeEuUxiULhslESPPN4jWUBVibxunB9lJMSIyEHQEjlvrncEb7r0HAPDbnqpsMTyU36ZUCjDlFB1avDnz9QpbeQ0+SzmCSMcoHYmHuYqFGeFgLNfR0WLxhohR887POnLczZXTOHNafnu5EKp7VgGbfRnH2701+e1KCKtlv46KEDCw7rQ6Yx3BUEdzJRWPzYYWfQe/omkTQkxxRTJK1eQyXiAqsGkscETSi+NfwzTECo53peC1Pn1xR2P8QGtvc+dweVaVAVizEqWHC28PW7vv2h
5We+zPQ4JtwkaL63GCAGJbf7tWzf17/thtmN74d0JzojaZqcbxbXmoMGNKjYMHSx1ge1VyEIeVoFV5bRAQ5k6oEAOlMWVOUJ9lno6O9nBzRyrCb62tAQDStItSGVTqtedhLT2sZVu2ZVu2Zbsj2h3tYXU7qwiiUsrao8Gey8ICtM7OpYK1DpIMdSmkBesk70NAEduvKa2tEMLSpFA8ntLAyrpYCqcHYh2qwmA8EkvQkzSiCpGrL0H67axS6NIi3p+WeGEkVm0VyHHM7ABX9iRe8ZkvflX6Xx3CPvgm9kEs1Jd29lAdSn8GoVxb3lfI9sU6iWhlWhU0SbXuRintS547K7JWgAlc8rLz0gr/o0VZ94s8oYVqFa39FtHa56m/Vp20HrU6KTFWtbbNB5iBxcXzjsUDuLXWTVzBWdOJseiRo72biY7f5y7toevwfsZQcmNwYyzxlnMdRy8HYsYundrKSi9Fv+cqAVDVQimUBe8/zeNIBQj5PMckNYRnUlzcFK+nUTLRONij58HxVOocAeNkKy5JeL9qAv+h+6xdlAosWIDa1phNWUZnIOfoBj0kLHgadcTzObO6gfG27PdAQA9raHDXxlkAwBs25byHuo8r+zK/Ql17fr/vi1WYcVvJJxV1E3RJ6Q9bD2/CebVSCznKGg1NosmbtFx7L8+BHTkfpoy8mObJuxIrysIn0Hviejv51yUEG9Pa1swVr9buyUq3p6a7pl13cDy2uygdxDUfv11EVUczpp2XUWLxnFykx+ni7YuIGO34MxnsyGJg5kgjPEgU9dEdrAMAdvYl3aM2IRK+RwIuJ0Vd+goAJc9WTjLc2hevLGH6hh7votx7BeUkw2ttd/SC1Tu1DqNqWOagaMhAN1WJtZU1AMAjK8K20v0MN4qXAAClZT2psPSaSzVZQFEYo88behTxtRYB2xuEaziRq0kNy7yd6YywUVR5+KGkGmocDHBAeagOKlzfkYcWZC43ZIivXJOHf32HeQyjaz5HZq1zHgBwcGTQU/ISGyeyX38QYEa2RcexEm1DnGhgBeNzkRqShIblQLOEbYyaNblb3K/E4gVrftK2J4dri5QuDBbU47HtfjW/PcFwWrAfcJLtpFrbjjG4HCSoASJ9yIn59KFhOG1f4ot3NCybl4SrGVUAh8wZOpey3MsM2KAU12pfntE9p1aBnjyUisvjrMhQufIYnKR1FmJ97R4AQJd5RbUqcP6sGDTdLmugTafIuKAFZOMlQQhFoVhnJ6XpEVDT4KodlhMJyQYAyX3Q6CCgkbOdSJ9PdTex0l2T/Sh+u711GvuEpZ8/vMFzxOiv8vtzJFLAYueANyuIYLkCGS64U5Nj70AMgfIuKTlhQ4secxS9EQNgxvlsXKXmusIaF7k3QvqqDqbAdemPI7naWkG7+nAub8o0i46D/6yxLbKFIyE1BAyn4GSMhYVbqAgd4iTUJ/ez6b876jypaBFU/rXysOa/a//dnl/HGIi3ycMyC/raXgxnjjUdWFBACBMeb2v9LA4mwoydVa6OYBeK7xHHbLZKIU5d6SV5OHUVYEgma0H5uvKzf4B7t3roTm+nr3O8LSHBZVu2ZVu2Zbsj2h3tYXVVjCDWmJF1kQxc+YMcW+sCbSgGwdeCCFkkXtIu6eFxEiJOxfaIqYIRRdrTkROXo9PtoMP8g5SR9CoNUZOwcTQhxDidwlAKY8baJKdXQ2z2RbT07t4ZrLLScEEP69ruEC9flaA802IQ18ALz4uVv7FKuCbeRI/lRTItx57FCjPSs13RuaDlBind+FpNPglhD9uQLRQ9LOjAZ/Q7m+dreVhtGvp84LZtAS6C6BZZeG1r1KFEt9NBW3SO9t8aCpXLSyPecTgGRvy7ouBEWZaoKUJ8NHVjwsChFU7BoigBVr2A5pVqCzgp2K3NNQDAmY3TGBp6FFY8o0mxjzwX735Ganw+tDh/WuC1e7viTStlkBG+3mB14avjETI6/Pt78tvNQReamoSmIIkDIa
qpdHCaS1+SwTaCrlxbh5WMe0WEZCx/378iGpun+uvIWRHbUGvx9NoGyqmgArvPfwEA8InPP4N3PP4WAMDFi/cCAEJMEGgKSKseOr76MHOlqgr28BAAMC7kE/kKeqTPOxfLWIuS1+IIR1rVOE9R3lOJzGGzO4WeML2AZCsVrsJ4sgXvi4UvtGnbuVc8X+3zHWtP6XfbalM3pUakl8eIDH5cqpOogUJDH/cIwKsgBK/Fw2pDea82vxYdZ35bm7h03PPj/SCGWwXeUYftMyWiH2LvurykKqYPJFGIgsLEAVN5giCEZapREAqM/I7zTyJPnPcrxzvzSI2V8xoYvXa/aelhLduyLduyLdsd0e5oD2stGCCONSoGrV0pgKpbeEqxs9LjaIB6Il5XVYg2WhzXcAaeM5GCAIjJSU+oUBBH1kctAxIUdKhhWPCvx2zh8HCMilbhzClO6xHOr5P40V3FBqnLE8bRZkWFXZIpXPE/EwIFgeMRdQrXtyPEPHdFjvVRlWPGLPEpbfyuNd7sasrOW5/s75MjlWkShr2+oPZm4yLLrR0bm/eY2laht9xaWfy+ars9GZtq4+/+uPa49+Y+20rw89satY92kNl6rzOnV7U/A6Z0iXJe7zSvffDCkGYbavjsfRcJrwEU9KIrNfXnC2PZL6FHv5J0oCGkgUMrntZBfuDjM0N62LPRGFktbtyA8ZysnOHSrecAAKt9Od7lCq4KDF54aR8AcPbt98CyDI0iUSAOugghXvypjQcBABfPP4QwdakQ4rn1e6cxGV6T86YyS1YHCnt0JTXd0ZVeggPS/W8MReHj6osv4T/syG/7q28DALzr3reiroTqnoYhLD36DE4hIsSUCe/GlVapDPpM0nakIGjrUQrUTqpD4S6iARtM+q/2jhAxl8CpaiDqAPQG3fEs0Gj/VS7GAinYCXjkAQBqDhpDwpS1tSdlLErG9dsWjFWgiSu2X7Lz+7VJTbfb9mpIQjtR2W9reXvz6SeLSBdt5MSVElGpxogajadJsLkxuolZ6d6TTCGqi4bBxP2fOP0ors2ESPb/ePv/BQDw6EMVsoEIJIyYDtS9N0E9HPhqG6+l3dELVhTGSKMIlnfZs2fiArtTeUk8NxVpkBcrhYoMvSnfUhunSgRka6kOg4R25GHCpjprjpySS2XCYCNSxIQOHcurLAMY5h84Ad1sNIYKZJtSGiHlkEwmL53D0QHyQ5nIhc/1gS9JollDpqiMK4qMPOXEwwwjZo13OclC2GbwmgYmNMffuyQ6cFIzaGpUcAJCaA9616zCQgWLeUgQrYncnmXzi92iLIz2ed3+wQI24bHAsyMUtPJYajSLl5u0eR9Qa3J/05gs0bh1JFdHzCpXqBe5ab4qeaLc1UhSAEl6XoxWI8YWGaWTzMmE1aioD1XxYQ7HOW4dykTeTQVe0z2DG7eEiLN2iooSNvB35OrlQwBA/C0BMuI2uSujoyyinvThwXNSp+iRsxdRkFlYsEpyL13FPokVOiJ7q2fQ6xCao4UTBhF04iA6CkkHAaqJjNlf/MgzcrxvOYU3rpzj95WfOyGhubooEYQ07ALHngR6/eNsSG2AunSLnHU3FWeYOJewurc+2IMbOZZwfhCHMCz5Y5SDyo0XqTY++crfyqaUUFX7unamVV6k4o6Lxnl74Vo0lhuDsWnzcGK9YNui+bAI7m73oZ1LtYi5u4j40W5Oxi0iWaJY6UKn8rxjGizD62OUVp7DoCOfk2KGhPVxTiUXAQB/+dv+O/yzz/wbAMAb75Uzr556Dtdm4iic2ec4uG8K3T0PTVLaa2lLSHDZlm3Zlm3Z7oh2R3tYQRQiiBS8Q0z3PtYJUIsV99kvCbRy62iEQeTgGlpSYYQNSt4r5spUJsdoIpbAdCZeSyc1yHOxKGdUOe13Em+djRiFryqN2lnkzHHJMoPr127xOJuomKW0cyCQyuWdGzik0gFj/Vg1wJTqGc7y2R1m6KRiHdtQYMU8CbGrHK1Yft
yFbeUquSCzbRQueO9qa5qqp04vTYcnoLkax7XQ5F6dJEnULa+r7UEtgkrmaeivZlGepOc329zA1Qu8rnZwu4Kv0YhDd6JBgHMPi27fWeaVVLqEyS4BAKYjwlkFYEi68DlNUAhZdTqKudEUiIg7RuSN57bGelfG1ulYoLKsHGISygGTaMa+Fti5/hUAwFc7EqCu9vZxY18glbDLTgfal8k4YhXXyaRGznQMp7RhqhrWyrF3c7meiToHxRwzVywvSC2CRKDDPQbLH4hLVCy9M3GXBiDmXX9gTa7nuX4Hu7fEawxL6culvV3c0yMFPAh8rlhEGLuoK1hm+TgFjhpAlyQPHfMhZTVqUqYdLNDRGpskVoQkWOnh1BfcBD1FhBo6cOQBIiKtkiMeFbAG2pXl4e2tLLyH5XLkSmM8Yed23lSbmFTPD9pWa4/z9vEWQd/zP19EVmr/XbxKv+a9tna6yHGonfd3nQo73UPctXEPAOCr+/KuKmqFkM+z5EDSiKD4/j2/JuN886Fn8WP3Cikn35RnNJrUOHePkM/C9FPSlyxFffhJ1GOvYPg129LDWrZlW7ZlW7Y7ot3RHpaCI0Ew1sTtVlsMuhKbSAuqBg+Bl6npl7rSz0GBu88xSGu6/LQoWNa7ZICxKoB+In7GEZxlHGJCNfdXrolHNh4XXjXdMrZg6wizQwbOrxxgS4xUPE/Z/Zs396GcgdHiySoT89xiyewVE/QZX6hXGGxOgQNa+5u0GPvWQLexegAatiEUuFOZJqDsyx9Y7bXknGVZolGuWJSw2M7On491tX/TturmExfbQWZnoGrVSsy0zXcnKLktc7N9jvaxGdbAdECmRX+A+x5+HACwviEPxIQlPvuceLD5jng3WdZ4JJ5KrKyPV/Vo2Ud2itAlYbND43zkyQVrqXhxRRxgTMXqJOBnDAzHrwAAPvcFxpKSCbqkx/d5vKQXYXZIb4rxo1dePsDaFmMJTGbOS4seyQXPXXkaALC5OsCpU0L86TCFoSozr16eOxp5bFHRYw/ooWgVIiAyUbH4ZJCUoEA6Yo5no8cImGLRVQliFqh01P+irhGQMKFJf4dWiBhDdMyYSgOxYwiRfNHREVboSSomettp7mOvmv2z2rbeAXwnWI2ydLGpip/Ga296Cru1yAmPlK6UiDXHYk3uc1Gct03AwNxv5klL7e8qnCRTtMdvO6V23utq79eeewuVLvjZLqp67JP3K9mSBzsaX0F0j3hbo13nTa0i5eAPCFu8of8O/Pdvkuf6VfIGth7uoUe4aHpT4rNfffk3sarEwzp736H0zwKDjSls9NrV2u/oBSuvZ0ihPeSjGeDVQYBeKTdxdUUC38GtIdx9cdDb5csWhyPWGOqSfIEW2YKPOZsBB0OBcqqe3LJxpnDjJiscH7Dch4lQE1osK5Y5QA/pQOoTnar6eDISt/nZo9+X7zPjc3xcLk9lLRISOjRfKkVZ4NIN6evGjNBgMcDZ3j0AgKNQoJleUSEkNBpy4gcwTZVf/lHWtWcMFpy0JeArF7uyKwYnhWvVAtJF1dqvvSAdL2pyfHFyrW5t8wo9dgFU2Fop27kori2CSmoAlVOguigv7Xw9xL33PwYAOL0ui4kKMrxw64vs5PNyTYXx7ErXryAEGGNGHLvaVkDqSmbwRTicHSJnTbZO4krApL7USMiFSEcW3YjSYiXL39gA3R7zB0ljXd1MMaX4MZErPP/iLTzalRyqvYmcK8tLTJjQNxwJsefFlz4HHQmbb4NVkGfVFIWDuwoZO7t5BavlfiSE48IgwoxsoGdeeUnuC3KsCHqJFdZTmekRAt6sAs0AqfkUs9IijFZ5TOpY6SlCp9TBvK8qa4Rray6kaUejSyhQZW4xK2E1YXNfL8P4AaFJ7NBRhZJ5aQ4DL6saOnRwuYMOK794VS48ALMQrltkhM0vAouMsPb37Tk1D3O1iR2LFrFFx61b/3AG6CL1mQVIJQAg49XkpSw6yWAdXz18EQBQkH
HU7UQNO5BQ+Xc+fBHv+lPyXO95USDoAhPkhbzz4nNyjH6+i30mNYYHlOOqdrG+nSAsl0oXy7Zsy7Zsy/ZfWbujPawXb72Mu6JTWB+wICNN3zjQCHqyaq8PxELd7AQYj+lxEOqbjICXrot3dJFU4KoqkNAiThL5bjKusbsntsokkzwWU9UYjkl4cNTyMIAz5spa+nKhcw5vf+h+AICaTfGGu8QifsNVcbf39/dxmAukGDnzJwR8agkty0rXOGBV1pws0G60itMXxVOorn8WADCeZug5t6B2eSe20VM7Rtl11iW9QlhUTvDSwYRYUP7AnrQAxTs73trWUNs6PEGmwGLihZr7d5tcsUjo1u/fImIYALYje6d3i3c7jA9wZltUJU6TSBDpKdb5t8NAjQGIDoGOAMJA1FAA8eQB8Tg7FGWN6TkdjEYYjsVj6lLppDYGLaa29ClNkQxYbJQQWNrfwPr6KrfJXd0+M8AN5l+5693fH2M4Fghn90DGZTWpAUJzNcVh1zu7GNwUYWVTyvhDXKBmteCCcPJkYtBliQjrWCZWI1ByjsvXmZ6RA6ySAzoy2B+OML1LLipQyl9fTAhpkpfoMgfRiy1DISThI2UZimIy895ZQQ+rEweIWCRQlc59KGFBM9/BhUHgBSMdiSNCB9XBHq+F5As0uV62hTg4lRevfoHF3pRr8+OzvV/bS/LzoDUuWccTtT1JfzfAsbxF9137e3f+hkS1gAjV+r5d/HERoUMTAn6J4Y3k68/DjKi5Sg3TyCbQJStvsyjtN/7g4zDDzwEANk4JOUNbi7h7kecVFGrQi7EzEm/ry5+Xzjzy0CnMptcxmy2a/Yvb0sNatmVbtmVbtjui3dEe1nMvXcWsGmNrTcy8tRUJ4Pa7KTpUXO8l8t2gr9Gj+5NNxKoLA4PLL4kXtXN+FwAQ1cZTPDdSUnK7FgcHTJ6jsnBc1ohjFqnryHHfsLaKV0iSCBL57smzZ/GN9wjgP61fweXJswCAK6TOn9pIsT+V89yiSkadATm9o1WWsLDVAIel9NEpVPzRR74P3/KtYpl+7EXZNt6rMSCmnxDYNjCNdhopw6W1vihdVTvyRe09q0WJkovKhrQtUNX6jds2H3CusCD+9CpabPP7LaL71gv2a3uFCoAZ0PM+JQGX3YN9TEnHRles/tVggIg1QupWwrXL/HdxRh0BrCCCiB6KVgoJfb6EsdPZZB8v3hKLMuwxqTsoYbXjydPbDyzWqAeYUOF8vbeCU33p1+xI+nnhTBef5o1lnjqKydTHrib7TuFcY5azhA3kGMW5GY4ORaWiyzG9GqawsbjqRyM5xs1iBXfTEw9o4leBQkKSRE4iUT0G1ngZI95naIXEqUfE2j9Q5RQnqgyxUz5nIj0K7euKBI7WbkXHDgBKxjZUrX06vEldgrFGQDJFnYkVb5O4yeAO5V0QxwoqkMB/zrlZm8rr3lnT+OfGxZ9dQjIWkxrmx2CbDNSUkjyZtHuMmOQJUSdjv7r1m/bYnydxHDt262+XhtKmtbc7vTBm5lIIOnLfLg9vYcp75N5zlS5QUItyhWjWp7/yMbzrbhlHyfb3AAB2XvgdnHlCYmGWXnIRjrGzwzhZwaT9nQpn1vMmofs1tDt6wRqEHahZiTFvYjY6BAAkHaDfkxdHTMyiG/cxYB5TwXIgOgCGe/IwrtwSuGW7O/BqBUkkEzXWGRz/KCWZYsuk2GDl1KOOvFSi9RnuU3LeEReBi5sDHEaiWnDhrgIHFLWtlHxGicXqqvQxo4jktARGM3kBnek7eBJQXGgjXtND9+/hxi05zv7KEa+3QpdDkvFwEgQZEDeOCVWjZHA7c6yoqkbmxC85wHP+BywWoV00YYIF35vWd4uCx/PQi0IDm7VPqOb2WwSVtPtXApidkmeiKI6MUYyq5jOj/I+Nak8QcBJIJYDIQSmc5UkCRE7VghV2NSy6fdaRInZo6gIVoY7xUIyTtdV+w0rjuaIEiBKZ8CklwQarMT
qJvDjKSO7++mDQEF14U/O8xt5NWWzGY9a2qitEXCQyI6UgdrfXsE1q7D4XxzQ9j0kl8+GIL/zOJIUxriJSk7fXY75Wh6VOhjcOsSPrHxx/YrBawhBaVCqEVknrKEBlFFZpUDrYrg19uRfl0DbGi6tpVegaBR+AdWK/SkM5LbOyKQOkeN+0kucRVEDocssogVXZAkal/n7JttozY9sM2flSIovID+2x78kUXyMvcVFrz4f6Nr9dNP/aRlr7c95g1ACME7rlkUIAJY3c9K41AMBedgPoO6Ylz2tK1I6lycn5//3Ib+OxH5awxMPdPyLfrb6EwwMxzHssx2RNgUu35DeKIYiDIscb7p1inL0aFeRkW0KCy7Zsy7Zsy3ZHtDvaw7p4Zg3druSUAMCMbITxbIacOVcqFotgsHIWg6l4IYcjQhfWesv5xWsUZDyXAbQUnXnSSSIEFLqNWHzswmAT6yusEEv4b9YZY8NVb81I1z2d4TqVCc5sX8XGGaF73nVBYMLhYYYZrdqM9SOSzGDGYPuhgwaTCAmL62kj1/Ghz/w8/tg73woAyE/LeYemwBpVBoasopyqwAt/Vs5qLUtkplEDAICxrbxV6NS9ShyHGABadXPej0VjLZsFBtNtg9atjbej8QKLLawTlOLWDzIAJT3YsCdW9fmtM/jKDfF633jvwzyuaoLuLXYHUZHG20Obwu5o7RqDlN4FRXAjEyBE5b8HAB0WvoJ0QDHXOMoRxdwvluC21TEC7qcIcfXSjoci2/Dp9EieXU61lbosUZHwU5QC2H36976AjYEoD2wSRltbrZBxv7EVD7CoL0CzX64QooJFyJyqDXqRBwByASRQU14wDAycQEWgNCIeJ3DKE5VGRE/Hlf7QOkBNmM55qO3q05aQYFZrZNQQVB3nv2tY6nYiJySYWtgOywqFYtmrrGz0DF36hqlR1o3ALQAYU6P0HlbjgXiliAXe/m3HapuY1JoX815SG11oruwkmtGGGNvbytY293c7b9Lt644tUDshb/eMrULN0MNVj0IpKNtjv4k8VNanwCScGB97+jre/14RWb72wv8BABgfXEJ9v3j+I7JLbu6VuCwOP6bEkeu0wjfshZiMlx7Wsi3bsi3bsv1X1u5oD2tzfR1xXGE6OR4azcoaivZIj5ZW0guxlkvM6fqRWBF5Hngpg2xIUsXKDNpIsJpGIiKdoq7ELAhYjuQoOgIrmqObMXB8WCKPaLklcrwvjF/Cm3pigRwclOgmhwCAc9viYa0kBiYRM7WgRbl365a3ynvE5ONu4mnP06Hs9+xXhnjDOaGS3pqI1xXECrtU4Oi7cuG1gTJNuXEAKCuLgt7bzFL30DbaaYypL/Sw2vbQMWo6/+GTbBfs1yZJtDUF59Uv6tY/jmH7ePXm9qvQxN2yNYX8jFjdCQkH999zNz78WUkD+N6nvlWObWJRBAHg5EoibT3pwqmjhBpIXNKvu87AIquYRJ4fAgAmJsa5wZMAgKJ2CZj+yqBd/Cso4HQeVeCU2UNoFpOcsQxJGAc4dUqSeo9uyvgt69ojCe652VqBm1DN5I/xuMS/+fjHAQDf+MS9ci+2FLY3ZZwbqmWHUeDjt/5uqqYQYZ9JzO1nndFaPto3ntSQpBYBj+O8KegQGq4ignxXmyYxOyY5Sqkmflpm9DzzAGOmELgQW6AVFElKlvqeegCYnlyTojI8qgOfLuKe16y0CDQJHY4IYoxXh29iWa3YVHM7bpvAi7nv5redQAPm/gZk3ixKuJ8ncdRo6PGVPZng3y4s2T6Hi1K6h5jrCJv3SAmRzw0Feaj7KY74DNfpIs5sCeUU7UnE6K+F+MwVeQ7rF38DAPCJl3Nsx/IGUatyjBdeyEFwCkeudNJ+judfUZhNX7uHdUcvWBuDVQRhCdC9LyrmDRiLkC+nTn8NALC5fgowFPvcl/2u3Cx8hnjlcrOmFknISqdc7KANYpcXwVFfBXuYTfkiJFusGmuMCduUJFBkQ4tz94
jCQnm9hxvr4hcHZKSdPnUK00gWJVfCZFaU3uXeIFtsrdOF5WJzgxM02kjwCx+V7PK7WRU00SXWQunXHpO5VGkQumA/J3lhak+2GJMdVcBiNke6WBR4XpSdDyyG/ea/M1hMpgjm9l/0YjiW19XKU3HHbldJ9kSRc12o0/KiryiJEnQjjK8J4/KVfWEPrGyd94u5659SQOTqqjm1jBCIycqoeUYVWUR8MVeEXlWnj//h8R8GAPzyK/9Afqs1SleF2AmtKqDmWyUmucFMgdxJPM0oJ2QNNtdJTNiRBcvUFobKDznroOQzi8r1izeumypc/aps+/BQFs8sn+Ebv+4+AEDI3yaBhtbHX8NKaTgCX4eQKurmRenqiU0mDU80ULHjVSDy1aybatwVr72sKp8fGDllFyhf9bauHIRnMQyoBuLQemWhuNrZ3OVjWYBBfk+lHO4hJuwX+KrbFQrbwGGAEJMK2yxU7vOEygsWL1C3W7TaEPU8uajNrl00b9ritfMyS/NEixMi1a1jt41E3y8+JNuNcLguR5o5lFVZRHwOOUkStjK+yJdlTau038cnn5FF7ru+5QYAYPOuDNlMnsOVm/JsvvyiltoxAGyrlPjnv1ggX5Iulm3Zlm3Zlu2/tnZHe1irvRRQFmUuHs6IlStNnaB2YrDMrj+1dha9WDyTvUOxBG4dHKAoG2oyAOQlEGkxM9J4i9sOvVlDBjB65wOESjy1qzeo4xfEWGd0djqS84+ulTi4R34cmyP8n78n0ep77pdjr6ycwipLoVSrJE4Mp77/Gz0JxK92ByjoDl7Th3KdqkAnFpjzK8xQvztNcDAigUSJddM1KRJXTZWR8cIWmFmnIUjvEkBBY6ct+D9v2VVoqqm61obrotb+i8Q3501JNf892zwlt/3TtgLAPKU4b+03vThA57QQXTIqVIxR4smzFwEAH3v2GQDAI+9Y9dBcyIPbEOiFrrglPSJrEdFzdbqLVofodah0UdPbPruFW7OXAQCnBo8CAGbmGe+qGYoWVzWQ0Z01zE/pm22MRvLs9g4FLg5rQKvi2P1oK3EUPMZs6lnjSLuORmy9VTth6ZRPPnsVW5vS13suCBwUaeU0UGE41lSgAWom9tOevy90jEDkErO8RkHvvWc7HvZrKOyVh99c2Zu6sjBMrXDFUCOtvMZlzfGZmxLjgL9xUGMAP1jqgnluxkCRcITYKXNaxI6+TSTB2NJD3qgIx9rjdHbguIflPkOc9KaUarzZRUShY/BfS+HCbZuHyNvQ4SKtzvY52nmT87lbsCc9rAhAwIec8xnd9cBZ/Mb0BQBAxvyqSneQkNRSeVHgFpGIyiSqrHC0Lwoxewcv8ZpKDLryTsyvyjG++qLB1mn2hV78ZGJx+ZbxWQmvpS09rGVbtmVbtmW7I9od7WGlcQfKJrgVStB4SovhaGpxpiNxi5WOWNdn+msoGTfa4VL/3JU90MHyJdA1gFnmYjsSKzgclz6ZlHAudo+Ae+8Sm4Yi1JhNZ7iLyZ8Bo8MHe1fw6Rc+AgDY7EbY2RdP6IGL/I22XmEhXRW7b/VwFdeHjHURkx8kfVQsK7KyIttmwwqzFbnm3qEcw/ZiDPfEuunTS+uoHCuMnEeufEjd0FudgbMoSbhNu0Vrm2ttz2g+qbcdKHatXU7B7d+OX7WPMW9lajQD1sUeA9tYxHnrs3qc+md3n4VlUqpmWQ7YCg88KA/go1+U4onVm5/CrKD7TEWHEDU6A3qkDNpkOaCo/BCQEDHoaXTTDvtF1YoZcMRY4pt7Usrk6dmn0WHSd8rs48ICDAfgpV3Gl6LTyI5k45B9Hh8dIGAR0ZBqD1XRkJk9w7uCjzk5EsEMQMLSNB3e1dmhxbSQcRRWTAyOOrDGBescGcF6L6TPsuiqB5wmp2HM/IejmUXOapc2SBFyTDtKua6NLzZZOQ+rzlHS1RiwVItB4624ciCqjjCiPmLFC00QIeAzqYxQqJVRMI4hktLT0pGP/SY8bqcqMO
ZIsxxRlapgVKNwARzX3WsrWJxQobAnx3l7vzZZou1ZuW2tkM6J1vaq5mO184QNPfd9O17lzhED0K6KQ58iB+cMxlJRBwXJFHVdedKZ84w1CmjLWKOW8T7orOB73/7d3E/Gb4prqOj1Ur8BYfUgiucEXdp4XM6xc+0AN2+WqF57dZE7e8EK4g3U4RFwyElIVlSsNQYUsB0kcsc6ceoXnfVVEjLSABlVL1xZnjhWyJgXMB5Tvsc0+ThpIsP0lasl1jYpE+OqzBqNm3z7nF9ljtPsFj7+rNS+GvQ7SNZkdYtYW8hUY7hytp2uLLKD1TGu7EtwfmplPx0N0OWLcn0gx5gVQ6yH8vfGhXsAAO843cHTh88AAPYnVDDIlA8up44tiNrDIu5FX6BZvPykbLEf2gvRIlbfQnLEgv3mF6o2LNL+nF+w2uoXbuK3F9mZ+/VAQT9AmOvcFhQnpiv9YkqLOiUpZlcgt8v7O6iITWQcKEYBIaUuCifXFALGUdsInww6CYLYwXryeXM8xCoXp9+79NtyX7YCWMLNCWvdpErD0GoajeV55UcvYaV2ckjy3fBghClzmvrpGgDg0BzANQeAJQEA8g1IrMNKt+k/C/eiWwPXj+RF/wbWXqsrX0kN1q0aykKzZEqPShfbp4FVsQNxS7gryK/bhnCCysOSQdQwA7OC95wLb21qD+0mFN2FUv7BuwrApigx4gEzPuKejmEVwWeyHK0pPStRufpkJodWlFMjjp3AIOF8OCBDNoQ+Vj0bOE44akN8iwgW89+9Ggljnl5gF5yjvU97kZpfABcxB+f74uaiU2yJrELJcbTKXNAvja7DpFy42fM4CJDXjmwhPdSBFoFhACFnb6U0vlLI+H6wEgPt+h5w6rR8f0jD65H71/HM516RYx8K2ectbziNp5/9AurXsWAtIcFlW7ZlW7ZluyPaHe1hvfP+x/DZFz4FYwnl0PIdpEkjTEtechRGqAKxQ9LEVeBTTnsTPXpYSaxgaIm5wpCDnsZEHbegi1IKQALwntYkM8hvEnphUb6NoMLBoRznlRtTvOmtYoW4woxlMYKxroidWJmdbopBR0gZV24KQeR0kkDR8u8REjKBwRq9h82VRwAA2yu70KcFDpswX+tWPoMhfDKwLv8ECz0s97fXBbSLrcJ5a7Bt7bXLKcxrolU4macV4KRlushjs2iO1ygGqEaUl8Hc4K0XoB+W8iFmvY8ecYl8Ih5zbUpMZ5JA9MApsTI/+vmnoZ1iAp+NtkDEXjCtDrlqIC1Dr6Xb1UhoeVb0li4f7APMq/vSjngAT2waBPScnPfT0xbjnDlGTLsAujBOEUNJP5UNoXjeh+8V7bbfe/Z3fbHRbkzILOlhk/laIFSugwm6U3myGe9mCoVrl8WLnz0i5y9L49ks2lFnlIaiW7s6kP6dP9ewxnf3eB09g9oyrUSvI+BvHIEijiKvWdjLXUXvwtP7Ha1daQ3rKP91AwkekXQxizjKghiWeWtgjqSpMgROpYLn17ZGQB8mJczZg4FL3dz1FbYXk4vqOY++TRVvK1Pczutqt/Z8cedalA4yP78KnJyHbQ8qxMn5EgJwI8p9JgBKwsN2Q1Coq5MZRn3nozuYtUTN8Wv5XlVBgNApodCTLU2N3/qs0Nrfckbu3K8/rfDog1TU6cpx77vvLD757O8DAD72KSFn/PD3/DFsn7uOsjAQ/ZSv3ZYe1rIt27It27LdEe2O9rBOr9ao8gpDJv3GFFuziYKmF1ISv661QeIC46xfYCPlasmhyzuRxCFqqlUMaB2oyEAR67aZ5ncWsyGJDNQrHE0BRcWML1Kz7YFTAGvrodAhCpa1mDHZsa73YbTYPzVp6J1UYWtViCGXbkggc1hP0WdwImViM1QFFcjBu6fEWvqDfYveFgsVUsNwOM2hmeFp6aWFtsHO20oWJ1SncTKxcVFS7yKts3pBMLrEcQzefQYLsifdn44mH7Z+01CPrbemV94p3mvx2D3oPygagU
VqAMauQpd8CuN1Gy+Q8v7RZ57Bg9tOaYLn102SI/OuYWv4Eu4Z71Iv1ojoHbs46a2jISqOqTwjcaKsvcq1TzRGih41BINUPKM+TiNnGZuypIoDLAzVNO67IGoVL934EkZH8vx7VKEYbN+Lb31CPDDTk/H7wvXPQ136AgDg+hHJI6bC7Lp4JldJnX+zbZ62duJ5Snmztkciw8YGMBzz2vgg4sigpAejdAw9nx1uLYb0ak+5EjdlgZJeashYF0INNefqqBo4Shj/ip3Sgoamh6UYY0M+g3I/IklDB9rrALpx1NMafXoNPSInt2ztB7qLq1U4iRAsSpqf93zmW3teLEpyX6TCPv/bdoyqfX69YN60CRauYkOXzyPZSHCUys5XQiFBzMIAOX/lYoRx3XjbnshSG4Rk9LiUjrqqUXJclkcyjjthD19+UQbIW94i86wsr+HaZXmGATv1v/+HX8VP/Pi3IZuV+DD+44KrP9nu6AXrqy+/jN3DHXTJAgvW1gAAKiwR8Q1oKcGSVVNsDCRfYNCVz7U0Qs7gN/VMEScRwMUuIctLVxaBlRdDNqVkUlUg4AKiZg5IM15Idkao5Pm8yStBVAPM09k9kGh12q1gGBHN+Di66SruPiUL1uGhsAUzU2LNEUm4uq71upjmsijFqRP7fQj5BsVsV27xc4KaigjUO0W3bNhYbVFbNzHctjbs0G7zEF7DV2v2b8N6aG2bb22IsS1PE7b+dp96br8CQOGq/HblOTxw3yZWz4i6yMSMMD6UhxFoeYHn1RSVEz0lxJSaGK/sXPf9AYQs6Ehz7g1joKAJzYXE9eq4mcBOTLmaTlEwb6p3eg0AcHVnB6tbMpFrwtIdu424EsgyiNe4LcWRLwMjYzvVGiXFjEsaNqc3tzEbybiMOjKAH33DE/imt4rQbbwu4/fB3Q189DPS56tPy8JVlE0gfoe4njGmyZ9yTTeLsCu3k4QaGaFPljuCquBJEMo2cKnlPCyLHJNMIMqcn5PZBIb3LXFViAPtCRuKx6jLGuOU84rGZNkNEbHmlZ+8sxEc5cznCYaJz30LOU56ocaYRsca9zsEPEzoWhuuW7ToeAOulYd1u7boeO051zbu2tD4/Lb2p8+valmR5EciRbNgDbqydeXh0xjN5J3xPGG4MrWY8V47+LeyxquQuG1BECF06jlcxMo6QJfhjwkdgp/8sw/if/5/yzm+8tLzAICz22O4YtGDLhe77Aj/9v/4ddTVa7h5bEtIcNmWbdmWbdnuiHZHe1hbGzvI6wzbZwUCK1bFcivHR+ikTqiTGmv1FLUVe+OurbvkczDAwaG4s87CMHWFblcgujQS4kM9NejQdD6YiVXS769hPGFweCqWtsqmQEhtNOIQw6lFh6bsSseiy8Jx+dRl+08QsEhjlUj/1je3kbIQ4JUdcbNDZB42ibTstzXo4wY9tRsHXwUA3HPqKcxG9Aw35Try/S5GiVjnmskoxgCsanEss/8ErLfAegxa1lwbGpyHL9ow4QLk4rgCAP/hBmQMoOPOh+Y7T4XnAasImAqzH6eeEOJJuQZ0uvLr6XTqywU7a16pCKGW575Pa/OBc+fxya8wj8SdxEqKBACflxQE1gk/ICPU1w00Zs4jotfVgcVoeCi/ofrq5z75PL7pXRd4HHmua4gRkVauKBuRRh3UxMWGhJbjQKPP9If9mRBxVgarSFjaxqFim5sad90rkGF/XTZubnVxNBbP7vdfeQ4AMLtaeuq6qh3RQrVkFzhOVCNDENNDhemgMnLfXA5NouDh9SC0COnVRGQ1aWswncpvpjOWARqPEFIxoU8dxSCKkM+Ok1/qssaU1n7uXKd+AhWzQrdLHcvGqEVhGIoDREUBdODKmZBAExj02K++K7eipHgk8CowdwsuXASRO9SsLTY7jxq8mh7g/DaFBiJ3OoQajXfhzp+imZtt9MHNm74CqD2LDt3W9QdOQV+XM93g+2vYjaCIHDmmmTUlDFGIyLjKwxEU54
PmDYm1RcR340f+QCDfH/kfBvih/07G98/++mV2sHSvRu9pxT2Fm9kEZulhLduyLduyLdt/be2O9rBm4x7StRke2X5I/l0Lnr/zKYtwS6zlvBDL8uaoQG3FAggZ5N7or6OrxVol0xbDrMTWQNbxs701AMAk0pjQLDh/TszIoE5x86ZYijdqiQFkBwYpZd1tR/aPQiChhX1qoBAlrjQ3LZVpjaAvluLmuniKFzprmJKC310jIdVM4PIknZuxNujjkPGZ3RsvyzWFdyOkTl2PKh+TlZswLrm6YP9qi3VanjWp+pW1noDhPRnbKHP7uIBtrEE3gI6rWtC69ep8x63NeSw+QBMQd+TaFI2lmLaO4RJfwxVuWwMm/PG9d4tH3El7KAyvpKw97u6YE2EEjEp3ACEcnDm7gv4rkmw8DYR2OyksagZwfCHHCJ5KfjiR8ZTUK5iNSXXPxIsY6Rw3J1cAACtMit2HxRe+KDHJex+XGGXHdAAWJ/TlAnUtrBgAobv6OIBmvO3GocQm7+6c9eSGCdVZOmGFWZTzb/G0u1vruO9eofk/et9FAMDHd78KTeJEj557rQKfGuJHQh34uKy28kR6xf0w9g8AAMx1RncAaNf/qIuYdXgiDpQoTjDK5F4PxzJPj7IpVrhfuiKelo4i7y16RQxTY8IRNzI8YWJhXDUFeqtBXgKzIftN1Y0whIlcUjK1IoMQXU3dQ/kGfduMMx4BJZp54BCAGs2Y163vFilduNaOQ81rCVqcVLgo3Rdo5oXEkhlL4pcGAOUiEQLgK8eF4DEIFCKeKOB8iDq7yFaZ4lDIFcxyhS6vRtPbya2kUgBAHRJJgoblRCh5PK0tco7B33uenvPwLM5sfQkA8LuflitYqw5BKUpUrsJAoBEW1hOaXku7oxesF784xsr5XVzcFEyoimXh6laXsFvIy2J8JAP86OAW9iiL1IvW5ACqwoDSRyVfNNOxQbjCgDiroMbIEJ3icJ7KpD2a1DhzSl56Nwm3ICoxoVQ+UxxQWMBSIkklGodjmUh5yUBwAHRmMggudvg4LHDp6FkAwOqq9OtorAFOstBKX3pxjtWBTNbpRL4bjQ7QC6SPDhYNewNMVliVmVVoa1jErv7PLflMhYsm/eY9NmhgiTYhwrW2hFPjrttj+7d/E6MZdC7PLQoAlxqXEgtJYZE6QQkO6HpFIWBG/nSd6gzdSMSJAeTE8g6ym+jV8mLu9ztQZIqOCrkHQRRDk7jgxGMnRYG7twWu2x9JRn5Q1+i5Nxr7d5QBioSDjPWf6jLANJdrnrDEzfgox2wsLKykI+NzY7CKT33xZQDAxXtkAVGhguaCasheDXSNgqobrmJvYDUqp85AZuAbNh9FQiLOZMxFqtuBcdIBvq5Fjc2e9OGR8wKHX3rlMi5N5R54hp6tfUXqyudjlVCKhB5X9sN0kBg5nmX161gFcOH+Ip+hJqlBE3oLggCzqbzQRmMZg9NZjhVKqAU0nmIdNDJpjqhoLUouNjPCp7anoZljqZwRWA4REOJ1K4LRCpolTtzLVkEh4eDrEhLsWWDdETBamF+u/J/Sz5ax1oYE51sb5GpDiG3RZtfaRh+wOC9RSEicp1zF0hhInQGnABZFR8rksdhYP58MWcKXvnId1/kcM5dPWFnMyDx2laKr0ngh5BBiTMRag2RpKIY5yqyGZo5czLy5X/7UBG98VN6J3/EOGSe/8Ks72OqLAb13KPOin1hMIw2j2wGF27clJLhsy7Zsy7Zsd0S7oz2sIrqGg0mOPvOYclb27QRfBOjuliy/Os0q0PjFkRb6slZdbPTP8PsGBjLOtKDFm8QRSiXn2KAlsj/NYbV4N4qeTDeZYkroMKfFHUUKNWGW8dji2vU97ivemwqA7Y78bRmU/vzhp3D5ptCPVzs8R9BDSVHWgNhgJ+6gTw8xilxJkRSG3lbaIQGg18GExI68ooZhZHBICzw65P2sjDf9XPD4WNmCdiDYW7/yGeIkVNJVjd0Uuu
rH2joOBCi1hzgFyIhF6qr9GovQ5QQ5tZKNVQyZW2T6cg9m0JhQSsQy6l9kI4xGAputrmwiIhHC8uIslE97KEv5blqUOHtKIMUXd8TqPyx2EbKzSZ/9OqoRMTnLldMoixglx9v+VDyO4RTILMWMOwL/pWkEFHKXvvCckGW2H7oHHVqwJnFwbeCFWL0XrBUq3llHWoiijvewMJRtRV0g8nRkN15qxDTF7zotKR3bZzrYucaqsF6gsUZlnIoH/HcOIXAej1EBNhK5RxHzAHWgPPMjDBQCMlM0PRljKhQzuZb9oVjYealQVU4QlxZ+msBqp5IhzSpAk6xUpjJXqsTCEpFQkw77NUI4JqBXEePXASz75Sn7uhHCHdD1X61EJBgAiJfgFgQam2/zii6LUj8W0d9fjVrgqen8DBd8FysgdSkpFJRNu8222AAppWsidjqtgZBwv3O29/cqHHSIooTuegJ0jVOu4JiwpqmATbKP1gqacQH3HrFBipxEF02i0P/8M8/iB/+ojIu7t+Wqzp6KsbFJrcFb8vzDGgiiuino+hraHb1g9S6+gK/87gTVNwn8UzAfaj0KcGCpFs0XkplpmD2BYeoVeVnEnR5OrcvdOhqLQO3Vae5Lx2ccAKpWqMngijruTV1i4oRpcy4+ySFCN+FzeVAJLGoHs2TA5IjwFOShQUfQlA56eUsWzd+/+hWc6nHQUROoNhVywhwh80/CMPbMrM1V1tcKTuHWVF7WmotrmsY+4a9gDaRZXII5qdh+ixzv8FMZOo6ZxXvcdtTbSbt+cWqxp7zQJj+TCB4L7FAc1kZoqsa6YrRdIFnj37qZYBP2JV6R+zs4/RDqPqEqLc/aTgtkh/LbaiL3tttNUPGFdTgeIqYyuhNiNXXloTbLV4KxFmlfvr9n6x4AwI3hLqaMNa65voYKihdQEu+vigCW8lqK5b5nCriayzPeVPfL58omuqHAjZ/+wssAgK+79zGskukXM0aU1RVU6eKKhueqUVr3InfVdzV6PTJCd2VM7+7vwRQyLh3TsK5mmDAyQxF1nN3s48UN5uHIJtRliYLVu2tW0daB9WxCl/dUAFgnxLjBhXxmLLRyR7IIQy/1LR9KIaeC98HogOeIMcnJ7CU2q+MQhm8wL5EVAP1E4Hfr+hDUAGPDOmSlgigCxodyjyZyvaqqcayENAAY48dZh+OgZyxYDAHn+HLPAEz9FUmrcDIXMGh932bFzifhL4rf6rm/3fFcc5Wuk6gppkz7Ft1es/DGecP6DfgYggoe1nMaU6rQsAkh7YpwcxTAsAyF4j1PEAAc+25/qBS1zY8dD2EBEq0REBosRgV+7v8jv/0zPzLhtXfwhodE9Pb3f1+Yg8ZarGxgKX67bMu2bMu2bP/1tTvaw7KqRLryOOI18ZwOv/o5AEA3XUGHrkfMgHYQaDz4pjcCAF7Zlf0GK11owkW7h+cAANemL/lIfEaVjLiMEUROKlbMHKUtLu1KFnfcE7JHp6yR4RoAwDjGgAqgWBvLVhlKel6Tiq53EODKJbHEj8YCF0ZxgdXzhEBKKiOouvFwaP0qlUOxjx1a2hsrmxjdEisoZ72mOOkiTMQyHeVyvGIAdFiB15wV7yC+BsxuirkU0pAqLU6U9HDbgRaBAidLGaxGCrZLj2mNtyM5fhxAGH+WdKcqdYklGrnrH0kLK/c8jJAQ1NGIqhSz2j/D/SOy9tIUj5w5zW0lSlf2guamrVo5Yw77ssaXF9nekvHUv/Z5X8k3JlRZqwoWjkhA79cGCAlLa6pW6HCIw4l4Ej1X06yzhT69xVvX5Jl/+Nkv4I8/+mYAQEpSTVbOkJUyPly110ld+Uq9ET2LmcmQuPwleoCXXvwqjjLxtoIRcVY9xXAm7MSKrMi7Nzdx6YIwZG3C8VlbFKw35dRZrFUtNQqSKqoMF7bk73vPrMl598cY07MLdQzlGG3eJLZeTHVG8ouxFTL+XdD7inTg2aiunlcE5euHVYRjczuFcnAomY
Yq6sA4SHAkc0kjhCHCYR3b0VivwOFqh/VCC5LXfNXtbQATjuXDBaw+5xkde4m2GBbzjEDgZC2tAM38cvMmiOQ/oGGnRh14lh2LkKOTAB3HwgQQ0tNxCIYG/EB3JJrZtELF21YxZBDWFpZSVoqeU640OvT8NfmTeW68clASyDgvpiMETtWEjIxZWeNoX0785Wflu6976xTn72PdNd77ex5VOLUVoswtnvuV1+ZmLT2sZVu2ZVu2Zbsj2uvysD74wQ/i3//7f49nn30WnU4HX//1X4+f+qmfwsMPP+z3ybIMf+Ev/AX863/9r5HnOd797nfjH//jf4wz1HYDgEuXLuHHf/zH8Zu/+Zvo9/v4kR/5EXzwgx9scO/X2D7/pSm+/zt/GMPiEADQc4XadAJNK9gocRWCToSNLSoFFGJ5JmkfKQO3pzfEIv/KrcsoWGYiYv6BRQLFLH/LEgqwNYakzFstcYnzq49iVoqVecjAcqAsuqxCbGyMWcE8FxoUVWJQTMUyOaRpt7ZhcXRIHTqC1N1egIC6coop41qFiEJn5cu1Rx2N02eEunyFBfqiJEZEKzqYyvXOjEbAsiiv7Emf3/Rdd+OVj+wAAMb7cp06g+e4e00/dZLqHumWuCgx9sH9wMRVHO0x8N2vwXQcH3uKYsAyFpJIKA511ME+n11IT+vM2ipUTUyc16G0hmFQeELCAypga1VIAZXNcWv3Bg8qva2q0v3ZKHrUBhlN4dWBdPrMYBtXa0mP6HhqtPIWc+zo4CqEVvKbWSXWaD8FzgwoxhyIF5F0OkhJogmUeF/PfeUVPMfndYbx0VuHU2hHL6YHWGpAM5YUMNdvWIzQceQdmp77167jKzekivITLNlh0hwlc5/2Z+J9ra/38PhdQuMfutielXsDNN6IUYEIBQKIebysqNElkeHxeyUuUYev4LBwUgZBS/yWnlarTIkXD64z1O58jgYfRDgZhVcopuI9v3JwSa7tFoBICCSg6LVKu7AsYYIjiePasAfLuJzzUKVRD5L3tKMBZuahcioqFpg5b4/fHeCkl9QumeMGlEJDYTfq+Hft3yqeGwASJlCpxCJwSYguhtVrqvfydYKOAuk65Lu4eHLV+nSizTxHUQAFd4w4Y2tlUfFEIeecjWKUY/GIHnngXgDAtcOv4HDi7ps8t9yUXsA4YGXqsq7AKYtf/E15Ht/yth5u9iQ369Td8l0QAVDF4ryAV2mva4X4rd/6Lbzvfe/D29/+dlRVhb/yV/4Kvuu7vgtf+tKX0GPCyp//838e//E//kf8u3/377C6uor3v//9+P7v/3787u/+rtycusYf+SN/BNvb2/j4xz+O69ev40/+yT+JKIrwd/7O33k93cGtSYJudxVHI4E7ThXyEsirHBmhtCkhjk6nj6NcBn1MscwkStFjReJT6/KCW9/s43Aqb9SsoprxaIzuKl8+uRxvOp16Jt1hxno8mxobGxJgP6DA596sgE7khZWGqwDZNaojfnmSxD7QSQ8ddlJhvC//mPRZy6dT++B4l4PFQCHky9olehpTY3VTajxdJU4URwEix3ZzgfsCuEl5qLvJGnrxYA+4V14Cui8LQ7EL0B7woprdUBhIAFBTtiU5XyOT24/OmnyONi2spF4gWOcLqRuguskk0APpezauMCArriJcUK8p7PKlskrBU0SAZZ8dfas2BjGDvRUnU44aOpLnut6LMDqUl/SE0GBRZ1C8bxGj2lVdobQud0g+L2xexOUrQsZRvH8WFhmfhKvEG0DDcPKXNRMl1ztYX5cbkRC30VGMtRWSB/gWC43Cx74sjNDvvF+g5RsHOzi9LhC1q6AbmAhxRAkvRtUP8j2c68qz7hDuLEcVXnzlywCA8xty80OV4ZXrkkj/7E1ZgJ84fREXzgr0+fKMJB4LTHmPDD/L0CJ2CanEqazpAFPp/8ULMl4yq31SvNEZjFdNlx8rHaNi7ptLpA9MjbJ26rl8YcYRAsJ+DgQ3dYWIb9zJTI
5xdKDljQ34cuG2GgCcu8E+SyF3bKM8z7FvlcCfci/BvgDUZAVTs5AB4PD1teMqADSLfI6WsseJEidaaxGblyoLADhks0P43A4A43Kd+ZlGDXzpOBCJBhKvVq1QcnWtXXnAVkKur75QA+WCCncBiWYVT3K63MBf+r/9LADg7Nd/HADw5eefwP/yz38ZAJAx383mxl97xasrc4uaV3h2k2MiXcHZgRhI998l5LJqoJHNgNJph7+G9roWrF/5lV859u9/+S//JU6fPo2nn34a3/zN34yjoyP8i3/xL/BzP/dz+PZv/3YAwIc+9CE8+uij+OQnP4l3vvOd+LVf+zV86Utfwq//+q/jzJkzePOb34y/9bf+Fv7SX/pL+Ot//a97K27Zlm3Zlm3Zlq3d/v8iXRwx436DltzTTz+NsizxHd/xHX6fRx55BHfffTc+8YlP4J3vfCc+8YlP4IknnjgGEb773e/Gj//4j+OLX/winnzyyRPnyfMced4sw8OhWFHj4h5EnQGOLokVnCpZ7C6PbuC5GwIJPHtV4KB3PHkRh2PZ1o/p3YQhVnviWRnSazdX13HzilzXzi2xRgO1hulYLIaMkdlRNUHC6HCXt3Fa3sBmT9znzTWRWbp+/WVMRvKbfliiS9gzYIBdJRYxFRO0chI9AYoj+Q3FOdAd6Mbl55qutG0IEcQxKlth0Hf5E3KdcdDx9ZoczzXVQEaL7IilKezNIc5tib106qJAVy9/aQSWUMKAsGKn12S3m30Ga88pjF1eB6nnVVcjccpCTBipVrsYhGKVX7omltb+dWCvcjRz+e3N8Rg1c4cej5xFGHoevasNFSGE086smdaQmhpmLO7eSHWQUMeppGc6LCsPOyUUGbZ1ATAnaEYT++yZMwhuifWuCZvFyQRjKkQ4ryqOUkTO+6Efet+Fu7G+LjCzJhV4fWWADUKVHcoVZFmOnesynl/ZFI/eZhkq1ktLSOPupily5iFU9A5G2QHSrqR0JJR/Go9nePklIaTcc5eopeBWB595ScSRD1kH7L41g/4p8UJXSQuP4wBTzoPA4Z4G/l45YzKKEtS0yLc68izfdLGH0Uyu3ajYY8baeTKoEdLDcYFzrYEZERBfeThKPDmmwdQAQ0/T0PuaFBUKVhXX7Jc2XVQud485jcaGCFyBLXc4a33+oIOGtVK+CnTKPvdxvBq3O4QzqR0RY4qT8kqtVMVjTc99xgEQzZUFrruAchR2MjFS0yi/ePq6RcMKGYUwU6Yx1E5eq+mDczILDdS8l65ESKgsAuKIIUkXu/o6nrv8/wIAfPP2dwMA8mqCv/GTPwgA+It//WcAiLemCFtXJKsp23iG998t4yNKYvybj34KAPBgJfNxHzNcvmZRl/Me36u3PzTpwhiDP/fn/hy+4Ru+AY8//jgAYGdnB3EcY411qVw7c+YMdnZ2/D7txcp9775b1D74wQ9idXXV/3fhwoU/bLeXbdmWbdmW7Q5tf2gP633vex++8IUv4Hd+53f+U/ZnYfvJn/xJfOADH/D/Hg6HuHDhAlT6AJLEYnzzZQDA8xRr/OzLz+JXP/t5AIDtUuh2rYNnnxertbtBi0wHCBjc7vfES9yM1zCiZ3J0RC9I7yEf0wuqXCJyhZCWTp/ew2w6wXjAqqpdCaSPetdQzsQ+y/KZL1QXO2sO1le1dSUZDICgogV+RCWGI43EFVdjQDZSxluFJRNl82oo1RkBDBhX3M/2EJGi6pJngzrx4PHRSK7trW84gytXJPa3OiBdenuM6YhiwKtCLx8nu6hpifXOkFK+YpARi5/RlD1QxuPtfcZ7VuNzeNtDXw8A+NznJPZ4aMZY4/dOwWK9M8OXbkhfk7vFXAvCCAnjj72B9KUupwhJPJgN5d5f6F3EFRZj/MgLL+M7HnwLAEAx1tHvb3ldO0e7zrWCoWVa0LIfDFI8vC6kAihJV+gkwOSQ1ij7vLa6jj774wgAj997H05tiCEW0xs0GGG1L/t1OkxbyH
Kfg/m5lySh8pHNLnIfe6UX1ElQG/k7Z6xukh2hRxJHSvmD6HCGbEc8tueeFQ8r7Me4dSDPdWsgsbE0Wfe/PUtKfieNMKaHVRhXWiKCcdRqV3Qy6WI4E2/wPE33aLCOMJFrS1TtSSPOio8UULgYoWMAKIMxVTuKsuSxE2rLAdY5RjVQMbF1Sv9mpICM8eQuvWTUXf+Ma3pYgbGomELiSnZU1jTej89vsL5kBgVH0EdbpJbkEVgvkut4EUMcL9ED/tv97fyHAAoJXOK7bIs6jTdVk6xkBkDoVGDoqSYW6DCQFjKbOaoVMOS8GWsYL8tB0WAcL6YKCHmnpr6ji48qGMychiefW4AUP/1r/x4A8OmXnwYAvPfd92HznHT2R37k6wAA//yffxyjI6peODcusLjrrFzMWQpS39i7iaOZ9OIzezI+z+72EAy0sFu8D3v79odasN7//vfjl37pl/Cxj30M58+f99u3t7dRFAUODw+PeVk3btzA9va23+fTn/70sePduHHDf7eoJUnSSNC02qMPP4isHOPKpZcBAJ/aFQ/tUy++jCqTBehtT8gN+4MrX0JqpXR4TvgptkBOup7mCz0Mw6bMNhW96yLESEsQ15U2DyuLVa+9Ig8sMymSicCTq6m8GPTgDG7mAi32S4OuuwzmLNRVAk3JJRO4AHWAMJYXc1QwAH2gUbNMuFmhy69EPBUAqswJiu4hT+XvgHCCDgJEhE1C5rMMTILti1vsN+EsleEMCRsFh3qwdhP3PE7pK7L7ym6Ex+6T34xmonC+V1zHJGHOGImUJmwCxWPm2YRlAct7+G2PPgUA+N+ufBgF80AeuluOdxBMsElYb+KIJVYhgCNGcEaHYyi+aKZchC5sbeNjz30RAPDCzhG+7d43y33lrY/TLirmBJnKMcisf0kPAqdkUeHRs0KEuF7Ixa+mexiSIOKYUBsrq1ihCHHUk9fZvafPYsBK2CWrAk/NLSjWMltfE2hwNDpCwIle5HL+nVGOwar0r+ZYXE27XhD3cCbjoKpzhHz5OLZgjl10qblz9UUZB2cu9vEQa8Dde1ooWtsb59EnFtVxJKQkRk5YdUp4dFVpf98iGju9tIuDoYzzwsnBBBUylmeugw0EPKZi/6JQ+6C8g7EDbTEjqy+jwdVLOx6+ctOrKg3AXMa9gnl4aYTikBUNNshAjToIuiS1HHIQzjKP+1lPbbWo5xTCrZHqAkBTlSBVzWu0cILOqlV3ivut4/hCBQiU6P5216Ot9blWtSsg0AUyDuWc7wYb+fJriDm209wi5fzTDgacWhgqqxgDD6G2Ja38QuWuIwAqZ+XSgKiU9fXDXFhC6RCrPXmen3vpBQDAs//ri/iOp+4BADz8ZnlXf9/3vhH/8mc/xd/wGJ0Yjz0s60Lk8herApZQu0t9HOc5etFZeb64itfSXhckaK3F+9//fvz8z/88fuM3fgP3slCca29961sRRRE+8pGP+G3PPfccLl26hKeekpfTU089hc9//vO4efOm3+fDH/4wVlZW8Nhjj72e7izbsi3bsi3bf0PtdXlY73vf+/BzP/dz+IVf+AUMBgMfc1pdXUWn08Hq6ip+7Md+DB/4wAewsbGBlZUV/MRP/ASeeuopvPOd7wQAfNd3fRcee+wx/Ik/8Sfwd//u38XOzg7+6l/9q3jf+9630Iu6XXvonjUcHu7jU18QWvDTz0t/oijF//KePwIA+PiafPexz13CG85L7KtIxPsqUwNDgdIhqfGTYsfF3rHSFysh7ypcqF0wV6AQXVVNpVDWQ6iQ+fyqbikWw2a8gvWumFDFrERcUOyTZmZQaVROFIMWkoJCSPHQDq3WbM9gSnr8+obTWFO+mmpFAcrJ+ACzDss4sHKr1jUi523xuJNRhZe/JBDU298pgfv7Hn4bji6LNTW6cggAWNMGdz0p1tQzn2bdr6SLM/fJ1V9+joQGHGLGEi156Cjn8PpsMbGVw3yEIa3pN79NxsSpD38UJT0e50
Hdd/d5lJX04douxTIfjWBJObf02EqTI6I17/KwrArx5RckNy5QHRgl/UpIkrA68GNtVJADrDUKUqadvt1a0EOfMOj6RDyUm8l13OCYGVCUt9vpoEv13oTivFuDEDHN5CHLn9w8uoQzq3LNG0QgrlzXUMxfoogAbk0z3EfYzLiKw3EHlpCx88iCMELNvJkOFS9iDYQux481ulaDHi7efQ8A4BwFftXKOhKCWy7mrdME1ViuvcqdthygnbfN8/e7PVximZxDEq/SsI+EJ7a6Ed6NA+d1WQQtWjkAhKHCEc83Ze7jWhz5fK3QeVq2qXB7wNIkl+oQo5k8mzXT5HDVCWUgQvHOVZn5sdJWoXVwo/Mea9twPBxNO7UnqwFr26R3OKp7gUZaz9fNUidFb+uw8S6dNzWJgR2HQrhcqRrYcP0jzBdOgNjp+jq2R92qIAODymlHtvgyrj8V3/QzDRReq9EhNQaGJCqneFGhQpc/zlj7KsAqPvRLkkv1g0YQp3d+3dvw1NdLHu5v/668azfPBLjvbiEc7daHcty6wipv0g7vyKXDGo+FOQL3AnwN7XUtWP/kn/wTAMC3fuu3Htv+oQ99CD/6oz8KAPgH/+AfQGuNH/iBHziWOOxaEAT4pV/6Jfz4j/84nnrqKfR6PfzIj/wI/ubf/JuvpyvLtmzLtmzL9t9Ye10Llpfnv01L0xQ//dM/jZ/+6Z9+1X0uXryIX/7lX349p17YsnwXcbCCjz4j3tH4QEzU3mCMz3Rltb8yYcHEWwo3V8UDi0lGCOoOAlYonU4P5TtYhKzUe8+5JwAASRFjNn4GABCRWmzsDBPnPfhCsWNMWRHVGLHwgmiGDq3lqYEvkOeS/KKwxIiJoK58SK6BlPGzwtl7lcHRLisTn2WicWgQuQrHTMAcD29hP6IaMgkeVoWISQbocltfaxR9OfZv/aZ4Iwe3DvGWeyT2tkKa9j33vxtVKOoCZ+kFTe0Ik5dokc3Ews6qzFumzmK3RqFwquOlSyoc4wrjH/eeEq8lHSTIqHv44gtiub3r7qegu0Ia6IzEQ1KxRj0mlTl0MccIhlpyhpb2YZ5jOhGrcOP0Km6WQsA41xNLsCwyX7TQqaErHfhSIiNa/b2kg1ks42d7IPD3le6z0LStFYXeOt0OotDFCOlpRUNYI+MsK8UzvXrtKrYvyj1cX5fnu7myjis35HunYGEKiyG9glMMgptA+zIfLglcBSFmjJ92SC6CBXKWiAgqOf/Ngxnu1Y4wIbGuKI3Qj8RDKejFpWmCkNb3jF6msaswrIisnYfV6WHMVIy9kfTz4lYHHaZOaN0O6DPmq5QvP2J8xcIalsjAkOrqp1e3vCfnlfRVUwIlqOVZj1ONkfPoOPa1SqGomWl5nXVxBFCv0NKtsralrs4/ytbfbhyr1t+OaKHRvDQdvb1CU2ZHkbygk46XWq/Zd6MBQ6L8LJI+74cGN5wn7mjypZTeAACVMwaYa2i+cJxHkgO+4KqFadRnWnx6d52uUvcs1Ki0I5A4D0sjNk5ZhVdc1xiz6nnstqkZTq/I3//+P0iKUFF+Ft/0zjcAAL76gsyzRx7YRtjje+mWHGM6ymBqiU+PqKg/KybYz1ZQVwbALl5Lu6PFb3Ob45Url31+TUyWTVYC/9uvC7Ple75JBG9NvIn9CRlwVJsIJxGskzYiVHfX+ftR8GV3/+abAADb2Qy7fHHtsSS9nR2iTyjHCcXaAui70iQO9wj2UbiE/NLACTW4THDVqzxEMnQyPJnFmEodrgrtamcAPSFM5NQ5bQKt2rWBgcP966iVvAxXwkcBAEVYAy4hO20WsZSLyD1M5y++MMTO87z2e2TiP/zYD+LZT8rCcRclD4raIDiSF8fpiSwmM60x5ax1AWNjLLTrHmGqYprj+cuyAN4XCYvuepZ72aE1lj1/4+PfhH/z74SldG6bi3xZQjvSC5VA1WgPKRfXgDV9xqOhZ/+NrIaqhUE3Jisy1hpjvsTc3Na29q
oXWe0MiBIJjQjdkXu60TmFIGC+lJP1SRL/Yna4Xm1K5FRWmeSyqA9HMxxO5G/HFrzr7Hnc2N/neUnmUMDlPRlnD54WGDvQkS93HtJwDOsQGRVVEr9Qwi/+65uy7c3RW3HjFSE2rW0Iqeb0Sog+F1xLBYsojJDyZXbE53qmrryaiivN3ukNMHPpPzMqimAFJUvhhKbjWXCecKA0KkdqcHOuCqCZ/zNi3lkYBgi5MHt1BmORsLxPzEUxXAsxmZIsxGcZxwlCGi8V70dtAXi5KfcJn7tnW58OmGqrUbi/3Yuybv3txk6Khm3qmJToDaBIaql53UoZVIYLtBbDdhZk6PD9NaPxdCsX1REAqGuyLQsFxarBjnhg0OQttl0JtWDh4rqHPFEotVNP4VgNAU2GtfL5nBbG5YdGkT+gE9QeUJbnwx++hooMz+95t9QjzEyCrHTVpcVwH05KbA0ktPD3/8f3AwA+8pu/jV/9jU95du5raUvx22VbtmVbtmW7I9od7WFppbB7Y4zHnhAr9HCfluzNI4QU7Pz8TcJZdz2MohboMNKOJGE9VNWlV/K93/AGfPyaQGoXVyVwuJXOkKZUEHtR1viijGAnYi35unWF8uUoMld8UFlftXRaN1VBZ4TKNCw0rdqC+SdHJTzcse7o0mkHYU0hXyZsKGWgeJ0OQxgNM9w4Ekr3tz32DrkfRwcIWJE4YSR7IwpBIwj1gWy70LcYkIRQf57Cs9/5RQx4vhJi9U9yoN4X12nNUHMwsKDQCLpOiaMENh0URRuwKGvsHIn7/5UbL8s5UgVLzCIbiVeyP7sJssdxkxn8pR0J5xdNTk0Qa+91uTYcz2DobZ09fxe+4+FvAAB86hY1zGYGmoBOyP5llfbaa0VBizfL0CH0OInkvqzGFxDFIi7rIMhe2oWzxYPA0eQV8pKKLDO5piy3uHlLYJMH7xbvvTxT4NmvUpGipeZy62DG+yDHuDDY8J5OzDIfgdIo6TX2SOdGEqOcyZjYsLLtj//Vn8Zf+8vfDAA4c5cQibZWLEZGPLucRJwV9LFCyPgW6ep1XfuUg5DPcKU7QMzfHPB5DdJ1HFAr0OoKIa34iF5caQ1U7XKQiC4EAZwLPibhxWjtk5Scg5CowJcJcqkCOkmQWfG2iqtyDyJTeU9XUUtShQlcDWFXSPUYJMhPi5NFSxUW76db3wNCDiH/BnHCcdUZQKcr/K2jc9eoasddl/2GFogJ8Tlh7GEpSjQAUDiiCpQXFHZagHWrfxaNR9WO3Lg/WdUIWdzAta6yura1H8uKSaFWK4S8Usvz1qr2qQEl2WVxCPzyr8t74TtrGQtPveMpfOUGRcBvyXe7+xYvXpXc2DNn5H78vb/9P2GS/SWURY1f/NIBXktbeljLtmzLtmzLdke0O9rDCgCkUYi3v/XtAIDJSCzGz3769/HARQnMPHdJsPtvfyLEdF+sy4tdSdpMV9cQsnxHzITE/Vf2kKyJV9aN5XjFpSOow08CALZGcsuKKVDsM3jJmJIpraeouzhlZYGI2yIL0KBHn16ZDUQnDAA0U+xn1uJg4ApP0nKPCh8w5ybYMoRlMrHyZNsCl18Qq7x4XKyWfnIKVcaEZlrnnbKCZm5lSALIVgX0GHDrEvTe+ZnfwLn3ClmhEkMVq+UqppSETugGrSYapxgDctacTrvoEoN3BemmcQ3FGMYVxnPW1zVu3BCvYNATC3pn/xpSHvuQah+TeojEJXrTdiyryscNDPffH2YIaD0+eOYits8JkeRh3udnd16Cto56L5Z9CTRpFaTdT6YzrHSlPzkLKnbj00gCsexTElkCBK6aOPrdJgG9MGJxVpVLEg9w6SUp+nn/+Sd48wP01uS5jnZY2sHAx5KeYVL8A2/cQsLzrTHJvSjHKGfyUDqpxAfSTh+TodzXh0/LcVenv4jtLYld5VO5z+WwxGUmQ2epXPfF7bNYo2fylaEoe6iqdk6tr2Gy0ulikIrnf5N6m88fXMLqKYnLRQ
i8l2R5HWkQog4cQUgOp7XxXkPGhOXAagQs5eP0KIyGJ7UMqBDS7fVQUjOzIh/cHDUF7BX1QnXcQ8WinxVjJcY0BAtHAa/RSvRF0zzt3atkLFZcd3KXCYklSRRDu/HkSBe2hinkOlQp+w1MAU0v2XktEwVMeLwZx2luDNK5xGDT6hdaFHzfbEO2mPAZ5kEC6z1YEkQsQHF+hK5iZNX022kTVkUO424cz1tWGqsbRCkyiUn/4od/D4/eK/G7ly7JtRX7Cpbksn/7S8/IuaK/j7/6gfdjPJ7hFz/0/5zv/cJ2Ry9YGiUOpzfwyAPfBAAI+aK5cv0WLmwLnPfMSy8DAIzN8d+/VxTkYwrL7nT20WPAPuJkGo+eQfWi/Obqi/8rACD4wgH6HZncyQ4Dx8+X6BzwpdeSkHEDyH0a0+Rt5GiUH0zmYBHriRgBF8+yA9SEZoaEz4oiRLAuOTSzkkHyLELCiRww1ygONUqO0meffQYA8HWPfjd2nxdWj7auuqxCTKwy4iI1mFj0ufjGnDzh5yt0HhcodWNCSSLTQc5seRcoHluLU67MQ0f6VKcDKF5HyTywKiqhCV8eUQC4SmKEZCRZwiKXbr3kYZGI5zicTnHGJ/G4yaYRMSjs5m6eKyQMvnc7CWbc977Twkp8Zf8IJqJQbiKwTVZPEHBipgzs12XWiKPyj7ib4Exf8tbGrOxrA3gxWApdIAg0ajJGrV+wKhwdystzwmJg3U4H21QXuXmThJ4aUIStDw7kuV7b30PC+6u0g1eBioH4mC/qXtzDPmttrZ8Vo+2TH/kXWFsREdKQNzW3MbprIjt1oUvI1yTQgSNBMCBfWFjWknFwTC/pIEpZWTmQsXH1yiX0V0QVxHaMX5TcDTRaeSNBexJCQ1zJHQSJyjMOPBxX1QgcJEiDZjBYgeqRYHGLD2lkfR28yrE+kj6UFybmC9jYE1BfjRbJAyebbS1s8xSB0HrRCC8EG1jTQIdOsNloT7ZyuXSx1YhcqRvPirQ+VypzRCxrfRXo9oJ1DP5r5V8BspCWnBQZF6xJFLYEcfkcKgMLJwvHc5iydZ1uoa9bkBzHEQwubIuh8uRjoqLya5/YxSuX5Hnef34NAPDx60fQNFRXYzneL/7mb6HXX0GRv7Zqw81Zl23Zlm3Zlm3Z/gtvd7SHtbt3GS+8eBVvfZus8FsssXDxwt2475wQMbodCZDvTA5RkvGw1mVxxZ0XcWFNvLOgL6t+f3+IoTkEAGQfkwD5YGS99ZLIJtgDeFJAiz16jP4K0LNufU9HAonj1YbwHpZT07RJgDGhtAOSM0ZVhtIITFNp6f99yTZSeillSbFPBUSkn7/8VYEBn3rzFF3Sh2PCD0ExQ5cV6JIx6coF0KUNE+cOx9TIflnOu3WWAqtF5Es7ZLS+NvIKM1Jjg65Y9kWvC0tPLHO02jJDQejTankeFx98HF+8KvqSJdMLru9c96ZiRI8iK4dAILkcrpJ0iBCGcJ1TvTO18oHimaoQ8V6rvtzgM6un8AoVSwJCNGk4RE4vyZUcMYF4RQCQOu8xqnDftngmn3tlyL5oFKxqm7qkPJXD+KKa9FZyA0PX++ZNIfbce88jOEXPOeS9sqiRuVQIYo2fu/wKvu2C5LvUEUu6BCFyByfTy+z2Eo9PRRsChe5kV2G6QjkOWN52s3cayZZ4XVsDOe/1wxqa8OqUhRB3x3vY7ApaUZO73et20BvIM+4fytioZ6XX4Ktt6Qteuny5yCqUdLsCTiarledgV3yGKgQUvfHAKTKg8RTcM0w2VxCxonNxnmSPyzMod+PcR9xFzRItJeG6CpXPr2pr7TkPq02wmG+1aijxrmk0+oMRKd5RmSOkZ62psKIsYDjGFF8e2lrE7E2P/dO2Qs575ISkC2VPeIAGx709O/c9AORO4YLQTh6GsGSJKe+9NXPHlQixpgI8PNiQVQIPE8rZkkjjwmnxYK8PBXEYbJa48rKgC6c3BBp86q0dfPIzVFFJieIEJf
7tL/wyTL3oTi9uSw9r2ZZt2ZZt2e6Idkd7WDeuzLC7ewjFJDwbMOgbjXHuwUcAAH/5LUKw+I97v4N9qmZvbIincN85jQ3S2nOWgNhevYY9ameNmHq+NoRPxKZzA4XGc/KYcOtvd2PbFkE7S77DWIyaWQwcJdb9tlI4pIdzmVJ3GACgjt7eoQTak4MbGBRMkKQ1nxcFIh57f1/2Kya30HPF86hqEGUWMSO7XVLt41xDEYx3cfa4rhHtyjlW6EVU/RlS1vBeYZC+mBbImIkfadlWxQOUxO+nZIrsZzX6pF3nF8VyH3RO41nuN2YJgurmLeS8e7on/ZzkNxGwzLauXFl0IT0AjRJLVVWwtDOLrPAlWJxK/Om1Vbx0LeE2FiVUCY4qxikdQ0Qp5FT3iOhpjYsxNjfFuxjc3OK9UhgxgZusZlg1Q0ViyogJmEXRPP+rL4u6yGOPvBmbW3If0p7ct6PhyJdpj+mSXxsOMSsclVj6l4YRhi4p9v/X3pfEWJJd150XEf/HnzLz51RT1tDFbnaTTbbczZkUSciyRJkLGza9EEDAa69seOE9F1545ZVhwDtvbIA0aMCSZREgZVoS6SapJkX2RLK6WVPXkJWV459jfM+Le+6LyKwk0TQaEFOKC1T9n//Hj3jx4g13OPdcJtu2oxZCvg+MnC/cfB7dhcTJHGHyG0vnEPcJEGG8rG0COGr+XVpIu+N9bK6Rf5CfdTttrK7I+UYPCLSxGVK2JTYxQk3wVsaLdqTD1yMdbOmEeBDwKAhjQ5hQWRcYIzZAzphPzHhUb7UDsyL3t7jIdIBlg3Ku8TZl1YjhNJmY8VSL4gmOwKL2/iS8vS7WHY97ybVqib4KiLLGVxHQVcG60o/RgNZU6IxnZB/QQxGFRTVv+DwSb69Xa0wdiq/3oJ8DYjUmHHAzxi6zSGDs9ftzxqCqcUnLDwalApP4nTDVV5B6ABgOQ28l32SR0DsPCrSYgPzgsaxBV6+28L73y2c792k5OwuY5PRKl79EzvSGNRoB+3sjHBwI5dKM5LCT8S4urUiXDj4mAeXgcAXtSDa0c0NZ7JLFJroDQQQevPNDAEB2oYtiR9xFPdL4R0cOjJ/7hxzgSRP8tAFe4HgQtMf3bf46NtXvYx1xmcVjkoG2uICbLsCP4LoydF+7dYTlJVLaDGWlzGcWOasij5nLMz/axqX2UNrK0hSt1KHP+1OCTeOsd5WoZAD6rOdlQla/dSse7RiSUWBWhujM5XphKm1JE3j2A8u8nShvYZnuiXDjOgDg0M7QYd2V+UgG+GoE7M85klkv6HA6Q7DMCc++anVjhHmF/gKEBkprVS2SGUrP1EBqq5UuWmyPYXmGuB0LXAqA4+x11lWFi/iwk8UMPdaounL+Go9PfSXkWEmGjfV5PwuWBSlTIOIiPN0V5SksSoRs3waJdkfjiVdeVJwxeONQEHmfuvCMfNiOkZM5IeSuHHcHKAlLdXQ/XrjyAdz+jrjGI4Ibup9bQshyMF2taeUWcGzLZl82qYPxHAXBTCHdwO1WG8OVobTBl36JfA5P4RYoazWZpC0lDJOVdDPO08KzHHiKoaiNMNJcJXkJHBDSVbz/irggrz9/CSE33HSVbrO4QMepW42vQYSQ7tyAG31uFh4lWF/kTzJdnIYWdFWzjs1/hQ3kPLF1FlbJrJ0Sylq/CejOFrkWYq3DpQqag3cJL7hh5aYCcumifbIt9bVJ26+lS+ashVMGAp4AAGd146pvrnqSwLsvdRO2qFCVfOzo9AMcjAQstEMwWzI2mGvV5lU57727GTYv0x29Id/tHgG2gM9/fDfSuAQbaaSRRho5E3KmLax01sZ0ZPGLd14FAGSsijbNRhjtSFb1RZZVuH7uAIfbgpj4+p+LVVV2HmFBzrz2Y7oLBteA23JcZ58ad35cwwJEMztZZbSuKNSPU60gghSFA4CeQrYjoCQUWvO18sJhSNaDVfoBxklVaThnUcd70xIxi9jNyPHVKR1GtE
zaJJ6c3b6P3paWOCGf3rxAlPD+ajHPk26REEDGv5TENywtwgWtFsKgOzBwdMe09qVnykmJtMMbOE/AhjVoM28DS6JJ7xwcoK3lFhjUnR/O8dwHxQL7qztS8mQWzZEr5xmtoFYrhjFi2SmhbTZLELAAZpJmCFmULiSc3pQGawNxFc/oJnYuQBCK37dk4DkMQp+rUlDbjFzbl1s4vy55J4WzCJmb1SX4wVkpnAkACVkcUFQkroZq63j/EI65XufXxTV46+F9z03ngScAHjDPMNsgwCYyaNGSzGgFtcMY5NBFzJInw/4mXv/GfwIAvPj7v8d+GaPd2mebpd+m+Q560dPSloHA/V9553WkuXzWsgqrD7DKCssmlHHljEWglpMtvHWhzuVW2IEjgGS2YFFS44ATJUfCVoiWuu7UvWYMwj0CjV4Sl/Ct//kjrP4j4cpMmfs2KjPE6tJSMtfAIIiV9aLNzyqt3luAOJ3V4iQcwJ5yXI5qjmccJ4UtENC9rhD2wjkU7rh1Y0qLNi17TXkJTYCSv1XOxgzwnKSnWVj1dlWwdWBOC2vWUevS+dy3TPlMXehz3rTTSzhvWulYdaaysLT0S1BkOGKYpJjqs3T+2TGLA1c2gF2WRbl8nhZxEeKoLGF+DbOpsbAaaaSRRho5E3KmLay37/4QUauDt25LADtf0AcclPjxIynbvDSUY9f7PXzt7Z8BAG5+XzTVT7z/Lt7ZlgDJFz7yDwAAe//xNcQPqFFMGLh1pdda6qWwT2bGuxPv9TsFMHQA9Njjy1rhPa6SOWc+FuOwRkaCi9ReHofwsPeUsOZOBE+pMX4kr7/7oU/j5q6wKWwvBNY+/+lDhCtiUXQntIgOShCBe8x6PC2L3weSnQZkLdqExAcKfQ1yzBTQERwBAOJogJRxkja1eGNDhAzkdQluKG2KdpfviRleLBK8/6ok+t5hOZLDUYEO4xA2pwad5YgZG1Gfe54t0OP5XLEAFmKR5IRY59ZhfVksogfbTJiMLFbb0kcPDsQCHy4NYRiTSHlvQTvCbE5mcWYJd+IVJE7LyhAEEXYALZ0xYbK5q+IoAYNwi70j9J8VEMc5gi/isAVDAEOmJPBwWKRi2d0eiQfgqeVLCAuCGsgwvxyvgQYHLpyTRM5ee4Dv/F+xHr/0b14EALz65psolv8aALDBeNTOzhQXmHC/yXIlj0f7SFiYM2aCcRxGWBmQz1KTto1BPR1XLT/lrgxN4PkAc46d0gRwSoBH1EKr1UJHk5I5Cl1hsfUBian1ViQp+vXX5rj/774JAPidz34YANC1JS6R91LT9R1Sb21H5CFEFCDVhGYeXbek6jGsk9+fZokZ48OfyGlaFEUB42NFvF9nPWef1VQL5xARoBB5YEkkCdQApoxhLQCkJyyserJzHYChczdpATOmkMzIjuKsqyoL0FsRIPCxNU3VsbAVyzxJFcLC+ftUXFKW1bwxCnUPgGkd+QFgOgECupfGTLu5crXEo5dD8hOehgB4Us70hnXrF8BwxeHRLbE7W1xA+r0Mdx/QzH5RMvG/9arDG2/I9x+7Jq6Q9U4bnalMxjtf/R4AYOX2IVg9wnPIGFSbU91ldlrexmkZBfUNq88/CAhD1AeCjiZ56QR1WKInbZ2vKxMgkzUEZpXnALDgivaHL3wMAPClL/1LfP/lrwEAXt7+FgAgvDNG8ows+vGYdEGzqqXHMudPtDmAlOMAqrIFKEuUIYl/dcGEQ5TKglA8JmozyhAw78edk8/aSyUyotfSTMAycRii2yFohAkt/SzCwaEssl/+4t8HAHz9f/wEkQbOCTlrtUOkE+Zh0eWTJClczIXh4CGmmbB85GMO99yC1RHAFB1MJwE6rJPWHvM+SoeIiDVP9mmBki7BKVk5dvceoUeXZ5hrPksOxxIxh+Oqrz1Zsbo+9w4wfFZcn32COZYGfYxZydcvTtZ5N+GNkQzQy53zcEqcSmLkKIoQk2pp7ZxsPq3OHBOOGS158tYbP8
Nf3JMx32MO4rWlAT76uZcAABc3xQ04TzMcsi3La3I+axx6LFfTZptDU1ZB/yDytbNKbuClzeGc3rvmm1X/K7YlCFvo8jko64IzwDiVefrKD2SzvvZbW7j4PqFau8NK45uhwbg/lLbu0RfVMkJFAiBqy3mDdgeW9b6OAaa40tcZIzxjzSkT228WrsayYtXdVSJQ0IK/hkPuNzH5teQgEaXJAzutwLu+FHc0N9WGpYCcOvltWbuO/nbWBmbUXmZtZbhJavNdAUJ1ZUPEWOdpmDTHzAUV5VysG5atcku7zIOd5Yk/m64dFhYDMlyQOQxBHOL6MyXKAnj9Ft6VNC7BRhpppJFGzoScaQvrfVefw97+G9hlstKFZxhotxZz5vp8+1WBAv/5D9ex1hON95lNcRv84vuP8THxOqFDGvxkZpGI1wHqXKi7Bk5zF9S1tJOBzxBVJ3cMQCUEnSVFXzjRAlG5TxZHFkskpB1SjRsWwIixe80rSkPAEWzxxX/6ZQDA1Qtz3PmguEg+/Knb0pY//Qmyd8TK6I2ZyZ4dDzhr++sAEUC0uTZV+7YvGQGECsHVPKAiRZCI5ZGM5HlkxqIY0MI6pDuu9xhZKUwRuSOfX2RB+j606WLMnPXZ9JtkZPidj38IkWH+Vel9ZR6iHoTqLizh6F59dPtt/PHL/wUAMGImy9WVizh/Qchiuws5X2ZLXyF2lbx748nIFx1MCU1H3gK0hAif8u3JLTy9JNp+pClcyDBlztPuHrVqAwShulXJ37dIEJC3rxVLJ2xunsPu0REAoOdLO1QW7phje98lsAQXHLKK68XzF9AmyGB4ngwV+DlIwYe7f/UT6SPXweGB9P/bN+X1wWCG3lMCcPncRQE0GBNiZyxz40IigIdOv48uLawewRntbFEBKGA8Z57mUpnIIGceX8swT8GG0Jq+yjkZBG10CUKxOhjLEPNU3Atbl2VyPrp9G701+ewgltd9s4xpKJbTEt1rxoZwWvW4TeBPZ4BgKmM0qRPKqiWhl0VlWZ3mOjw5f4DKM5HlOaLiuDVdOoui0FIjDDtYh0Bz6HixGM7zLmaajxU4D8BQblKL42uPtjFlv017wLgrR2vRxgB4wlx0znkuzEBh+db68aYVio11vg4svesIQ6DNBSnkMyyDBGojaopI1DZaRxMtprocjkqcWw1Q5L/MN/WkNBZWI4000kgjZ0LOtIX1ic+ew//+syUcHUicathlUG+RQZm3/tsfSXmR1d4QW+uEGSs/3G0LUpihuy76yeQvgPgE01gdYFF/fTcw2ACVVtAOgZiabkj2hjI2PntcS5AHJdDXdvGCawEAfuZZNLpAxmS9YUsSKh/u7qFfiJZ5+UOfAgCkP3oE85bEsIJR1diTCZABKg0xrr22qGG1fME365nNC8YW8ixFyeKPBa2+zBgsmJTsDiV2kp9PERLtYefSGGMMImKxFeIalg6mkNaMpqLhf/zTL2Lne4X/HgDiuAtL7sc2y28sigJTPti98Rz/+evfkXYFqj22sCzGAv7ZR74gv42uqKGLPi2UZD7z3GrKcVjmOeIeGRamRMTEXdy8KVn+156T/giCAPsTeT+hxd7vAN3OUNqfiYVVIEAxE8tj+aJYRNcuXcPdO3K+VFkhTFVWRgEDD9MRegXbOhZr7tr5y+gFjOVuiBV57+Y3fEXAb/3pNwAAn/wn/xy/9X653rd/ILyGDx853Lwv8aDokzIHluMu7u3JHPrw0+wD5zyIpksgRlTkCDWb3DnkHING40K2ziVBoIsFnDKwaOpBu4ellWXfhwBQphH+4Qt/AAD4k//6HwAAg49m+OktGRdRSzwn226BCVk7liMBsnTKVmWaMvUg6i4jah9J/xPIchoLe/2z+ryup6wAx60bnT9JmSPQ9AhNqM4Lb8kUBMmUZen7SJlAotL4RUOtqlkg/wCPvfJt1DZ4QBitoHEXGPWkRQUnVgjjIfPa96Wtxaf1vLW/Q1tFGzWnW2PAgzgGyD
+plSCiUFjc9TcAYJ1DyeslBCFFPWC2sCjfPVn72d6wVlfuYWNzGY8eSHB+e5tMC+stZCndMOzY3cUhbtyXRfP9PRmkG314yqLiTW4aUyCAgi3ku9wT/RxH5ZykQjmWGa95Cq7asMLQICRCxkdOW9WJSi7uQQqpRYKqmvGgBBZHPM9YTr7Uc1helvdvHyoaq40JM887rLAcXT+P8NuyEOVEHxrU0IGaA+OqZulrNwq8m06D/rbM4bh1Jwk3otz6HBStm7twDgnbH80Y5E5i7L4m7Xt0R57b1uc/g96SvG8baXNmjRajxWgkN97p30HLCYlrh5MkTdpot6oKvABgigyG7qd2YJCSQJhYBKQ2x94Nef+/IGjSP/zYFaRaWqGWS5NZrSBM1F4BlKzj1SFJ7mi2h8lcvr8Czb1KsT+R32ZUNDbXO7hEdo/AiRurXwAzsnssXbsGAFhdWkeXrpxyRmJXV7nI1CX8cHqA57g5HU2kfwetAboRXV+BLPyvv/JtdBgZf108w/iDS6uI94W+7LUNATLsTAvc3ZZnk9MltbW6gZsP32Jf80aWl33tsA7dmLmrKtmWMGixppVWdm5HA8QRnexcMK0rfV5ayNIT7XYXHaL5NNgfBgEmuTyna78v1x0fHWGX8yC+IP33TpjhgO1ag9x7WJiKvYUu43bc8+5XZKQbQ0VqW0fN/ioWm7x2nH7PsyItckSk0mrxgZWl9QpQWUcTUpEKOBFbJgJI9VXU8rHU1afzK6y1r0QFEEmpuM/jAGmsdFQKoDCeAaVCmRiPNFHqKOuc38R0fRt0A2xtyLgNqcQuRS3M2cG68bRMTVEJKiCZAlJS1tUbtIDxKCRi8t2hBBuXYCONNNJII2dCzrSFNUsXmOEAg55ocfuHons8fWXocy804zq2IQ7fEd9MVwroolgAEdPH3R0NGFuvUaSnBAJ9FjmedKmdpONS8RZW4CoiVr4aU2FiSRqBzFZaeY+fdQGs0bU035cfbG45zFj0cZ/32bcJ3EgOTB6K5twrZ2hTY3f7hW+/NldzK1qouA47zL1oh4Gv3luRVFoUzEsqS9UYHWjUepLOElV2fnqPbsBLEcbviBZ/dEGC79e/GGKVuUCDwQoAYOfoAMmB3Mec5VIPDg6wGrGkBEEpNjBV8ohWuc0KDNblfp8fnsfLd8Udms8V+g0Q4YyDHV4jGyHsyLVnJLxNC+fzwqy6riL4e5+yBHORTjApxSX3Ur7Frgqwf6AmgrysrlzB1gVC2PsC7CgO55geilXZIugjiCIMCO2ezrV0SsXBR9pF2DzHdMDeZipGJ2pjleS8Lher6//80asw/E0ohgdaJsV8Llj3Z67K8ZODxzii+29Wym8vrK3jlZ+KRTknFDwMIji63pZ68jzupXOU5Jws+zWHeaDMCTmsuqJIClwa48tsGFo/rXaAtWV5DmrlBYMFfvQTsbCGQ7mB7iBAcUiXGzX2YslinvF5hizZYiuGFrUewihG0JXzRFP5cWLtE3lMdb3/NGBV/VW/97mZzlUmR6n5hhZFqXD2CnzhFPDA9rXgvLdA+yxFBWs/zYNWArBKdMtww6TfQaZsG0pga60HwnirqrSwauUVlStPp71WUb9+aQ1PXSPp8UQ8K622gyErz1RzEFF5n1o+jxGeYUPPt7cNnL9a4teoLtJYWI000kgjjZwNOdsWltvD9u0FOp3j/s/NeBkBg8wRHf9tE2IpU91ftvTlYQR8T0uByzcGeCImVYeP1rWqJ2JXDl4980oMjlsy6tu13loxMHomhUQXlbWlAda1FtBnkvDhAf3fY2CoBxSEB89TJAcSD+rsirbp7m0jINRcS3nXaucdS2xWrsMu2SPaUeALCyqjs3MOJUENBeNDma1iAOpjL0wVB1xMpQfnb5ag8oslxhuMTdHpiEWxMpQ4x+7BAY5GosVt5cLZN5nOMfRxHLKEZyVCJqJGLCOR5wU+/YKgKn77+u9h73vfBQC89bM35PtUmNMBwF
JtvbX9Fp596jPSN1pQL+7iiKau03IVRYE5rZ6SGvR8NkZiBWDxze8KWOIff/Yp7OwSGML4zLUL13FtSyD9y0MBPCyWp7iV0OIg7nd5pYe1FbHAtg/lGZrUev47LT/fM8CeZRJxWwAWicuwRqb1wkn//fidwrPrL9OyfOuVV7AyfB4A8IEtSRLefXiE24/F+t2bky9xMMSCGvR4MfF9AFrgw8EQAPCzPIEpdXxUcAXn54HxTPCaQGycq2K9oSb3LmE1lr7U+EeWONg+reOBWERpYTFNpF8nOgb7Je7HYg2uB9IvkT0HU2psjYM/DBEw4TlSEstk8URMum5hncbqflrsmo4RpM4iVjZ0grwKB6Q+YVgLJTrPeuF0DSok3g3UStwHACkYvQejfm2HKnY10ddejJznsRw0gakY4y00qdfp4/T+pVbt/pQ65SPXz8MuKRxfvFqlm2LOyZRrkjjgTSFd5kobeheBIya+yErkeYv3rz38q+VMb1jp4QSffHYZP78rA9yQrscWE1zfkOD8WxGRRIspYhKElmPpzY0rgP1LunoIeKgjg+rUSycH7i+jbTkZqKyXF3G2msBafwahg+WOUZCixVmHliKC+ONzbYM5M8VZUQLpyGBFiTOJzIsmFuEBXTc7ZADZGcORK4VrAIw7jvsAxB3Y4aRoEWERhqEHWwS8OVsUKDgJtVZOgSdBKLmDz2PSnCWMHNp9ujRXuXq6OTbWXpDrte7KeR0w3Vdfj9z3aHSIy+oOKzVYH3mPYBhqWQqLcEnu6sL1D+LD948AAN9/WTasMq3ALEqi+8bt23hu6xNynwSZBK3QaxgLAg467RApF/AZKYsWkwUO6SK5IWlMeHFjBAURGibfbZ2/jisXxSW4PJRN5Sge4YiMHlkq1+h3ezi/KdWC33xHNsBWaPxiMtc+NwCY+xasyb0v8gJXr8pGtH3nq9LOZcCxK/V5vfLjH+FLX5YN2kCqOF+/voFbpH3aJinwc6sDRHTX7Rwd8By5R370mZiTwKJkNeuizDxZsPrJi6JAwYVNqX5sbhFxoHdiBWc4lNTCQq1a3AJ2duWBdbhxmbKFBV28Ez7/R7sGz56XPzaMtLVrlxByXTA6+AEgkP5qdVh+KFlU45avdeRdfXM6mZPpasqfblhzW6LLisMgOCe3JXJF12oVX1t6YIXjxhWWztfXUoLaReg86EKRg3WGjQBARlaZKYlux7FDqfRLykxcWoRKecUbMMb42Ikq0oGplPRLZNhZXo2wM5/w4tJL0+QIKRG0SaagpWqd0WfoHLDWF3fi/ZkoYXEY4GC7hGsqDjfSSCONNPK3Tc60hXV+EWJrPcLhm7LDP6bGMDpKENG+HgxET7jc7QFvyveGwXBzZH21Ws08xzEKT5HS1N7zyzqPV90loJ95Uxi1gKyrAo5qQdmo0kZ8KYlBgA7bs0SVLUCFLQiptY4mgKH202IuWteuImOeQzAWV1Nw5HyOh7a/A48F8JZWbMQFCFTsFmHQ8hBWhXaXNkWpGfG1fjkt8181tkID6DlQsmjioEeN0haIWgIA6MUEDJg7mJCVIyck/qjcg9mia4O8gWEZeTh1i9q+cw4p2RQGg3V84Jy4vNKS3x/knhPNUDW+e2eE/b8n0P8VMl1Mk9K7vnJClKPQeGjylDyUk/EUB8xvU2v0j//yLgLmwLSYt3VucwsbG2LNrK2KthmZEIcXxOW5NxFWlvZgFeur8lmb5LLIM180T8dB7gy6LNa5zwS7wlhceeElAMA3v/qvAUiaTEHjU13f+wvrgSmzB9IvVy8+hfNMNXh8JOCLj5zbQL8nfflgR3L9ZospYjIYL3nXWgc5gRpJtoCOArVqrC19EN8zSlhvd/u8ruU4xozegJCQ/XmZ4WgmA/2pQExF6zIYeiRoZCKdW7zOfL5roUycdTNBm3V5tGxJWZY+B6lNV3QUjWCKCuIOyNg+SXB9kl1C70fvQxfUFECSy/wzTF7KnUNGz0ROd3JRFs
gU9VRWBRyVScZbqIFDQiaclGPWOE8/ijIEyiX5TUJUThaE3v2q4KjIOe+mV3Jha50HFXnXNwoPInua4Kij5BA7R3JPR/R+LIrMr186tgIDtJQTm4M1aBl86Lrk0n3qgsyzr/3JtxH9miZTY2E10kgjjTRyJuRMW1iDcY6739nHoNBETwYMD6bYfURW8pFoMsMNC0vty4qbHvYhUDIgoAl/x2JO+uoqS+E0eKt+Bhz3Ket3qqUlJZDVsw0BtGBQ5voHNacVByq1KOkyxtShQ7NgkbGMx8KgzThJ/75o593lLooDBq1n5KibV9B1lRBPslq0wyp2FSmrRS0FXq3QIitA5LcHl5zGXl/nYvNnKYCAbe4xYJZOcuRkveivCCw8wuvIqLKNHzPYPyyQO9Gc1aqyCBForCNQa7nEhVi4/fpxD89eEaDD+64LEOOtvZtQ4nBNM1gclrhx7zYA4OMfEAurFXQQYMr2U7vNiqq8Op/bfFYiYVyxw4Dy3UOHq4zjtLakh5eWVrC2JpbkClnF8zTB+ia5CxOB33e7PWzQwlrlcYeLxx7W7tm6rfXB8hHjAkW0wPkPikX5Z18RayksgVQfirLUtwE3X7D/aZVcu4anL/0UAPCYsazQvIQhYeY7e2LFTxdTtGk1dljGJY67yNzMt6UXstijxjJM4OM8GouzzvnEYQUhhDC4/4ZYcpubVwEAu9ktn1RvcgGXTMwDdJbluY9uyKRaXQfe2pPjtlkdYG0yQZseh5Wy66+lLBqWBShbnR4wVY5IkXqR1jpsvah9DxxPwtdu7gKYkdHFtZRNI0DCGFbqIeAlEi0/4ivBOg9wqlt2Crpg7VSEWY3NPwaSZcauiJzK4WAYu3JahBGmlp6iJzf+DpWVowwqRvY+LbbEzjEj0eq9HZIEFMDGsIpTAUAQWD+vAk0mz4DPfFJixJ//gvz21bdfxc9/vutjae9GzvSGZf66xKUMWKIpPaJtOt0OMP+FTP71UoLbj9+4gy1uWOURF/6DzOcdlVy+HfInqu4WeIIv8lS0kEO1Uak5bV2NMsUa5AystjkTrHUwXIGsrGXo5ADHMxzbHEcOdiQtWij7RdZHb8RaRXeEXie6uI6lIyKpxrKKhtlxMl5AHrynjAr01fj8j0jdBbUNq152ob5R1V+BJzcuvR4gEz6i66DTltZsPxxjeUCACPN6eu0+jpjXc7Qn99hdi1DQLWltdUV1WfpqrgAmoez4JuqivyLff/YFIQW+c+MWyhGfg1YccQa7LElhiUBcinowfXkoYzJOpIs5CpYVybmDLOYAU4swZ42IgRFkFwCsEok26HbRptsybsvi2e/2sbo2BACM5qQkitsYEH23OZTN7HB/D4Y5TfrcrAUStj8mvVOSj1FYqYe2w2cUhgCJJ9BiYlyn18XDm+IjX5De6aPr17DEHK7dkVBiuCjA+lD64N49QZQcTscYrsq8UsaLbtzHvOBvZiMEMV23SpgcxZ4c1fIZOgOfz5Ukc/bzGE/97ovSH9/77wCAD//2ZWzvyHWunbsGAHj9B/cwoztU0amb50PcuC9/3KbreC0YYcgkNMMK210Tese/0lwFcYxooZunfFenZNMxnZ7yGVDNq/qG1dYwA91/1hkkdCenrhq/Gdui+X3OhgiN1o8jw0pYAZiSsLqG5nPmvQBzljHXKjqZWyCkeqNcPVlRoGReWqHxAVfle/nQQQR0VZMNRWGc5C0UMzmAIFJMEmCJ1w2JFHGuCnPo83VphE9+XpTIVkdcgv/qX7yEf/vvvwtbONy6s8C7kTO5YWkC4KIQpKTyraWKrCkdHBErOVm7TeH8cQsuwjNbbTZ+0KAakAofzVENxLT23RNoITyZjBug0oincB7RpCCqUC8KeKvFFXJvADBnmwtbbRILjWGUDrPC37ScN81Q0uJUKpSo1i59LVBNMh9TcA4t9m2s0NfarqNxqMTJP+0HQArM6fv6q1caa8cx5xdzqv2LpESLDOhpqr595/
so47MME4vFgkg0ZYa3OdJU+f5oaRcOCamSJuMZphP5TUqeKFs4Dw1XuiPnnN+IEm6UQZYhpZac6zXSEjmPU8vYlhXUXG/UFpXGXPA+Z9MFJmNyDdJ0noxnmBJOOJ9JOwOb+PdKMVbmzhfS9JeyFTrNL2ZJjslYYfI80FTvFYlW5s73W1YoPD9DpptAqXG6OdKF9kHp2zmdkAaNyNt0kWPBopQuytGmdZzwt1laeotUx1GZu+oz9ul8liAiBt+nTCSlf7YJn2GRO8/+rbpLkTm/2agnY1E4zDiIAy3/birPwVS9BtZirusKuy3B8SR44N1vWAsAdCTU6I6AlO/T2m8VWahzKrfOVyrQWHHpAC5l4HRACl/3Eql1yPhsC3JalXB+gGhtK1c4SWpG1W8oXcVKr7FpI2MYAHJesAyr56XfuRI+uV4XElvUzs0229JhprDZUHp4MS9gixqsXzXiXyHGvZujfsPk/v37uHLlyt90MxpppJFGGnmP5N69e7h8+fKvPOZMbljWWty4cQPPP/887t27h+Xl5b/pJp1pGY/HuHLlStOX74E0ffneSdOX7538Jvelcw6TyQSXLl3yLuRfJmfSJRgEAba2JDi/vLz8G/cAzqo0ffneSdOX7500ffneyW9qX66srLyr4xpYeyONNNJII2dCmg2rkUYaaaSRMyFndsOK4xhf+cpXPKy2kf9/afryvZOmL987afryvZO/LX15JkEXjTTSSCON/N2TM2thNdJII4008ndLmg2rkUYaaaSRMyHNhtVII4000siZkGbDaqSRRhpp5ExIs2E10kgjjTRyJqTZsBpppJFGGjkT0mxYjTTSSCONnAlpNqxGGmmkkUbOhPw/b/DHoTl0SGQAAAAASUVORK5CYII=\n"
},
"metadata": {}
}
],
"source": [
"# Simple example of an image feature transformation:\n",
"img = tf.cast(img_orig, float) / 255.0\n",
"img = tf.image.resize(img, [256, 256])\n",
"img = tf.image.central_crop(img, 224 / 256)\n",
"plt.matshow(img.numpy())\n",
"\n",
"img.shape.as_list(), img.numpy().min(), img.numpy().max()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"id": "Y2y2pikmV1ei",
"outputId": "a6cfff70-2227-4857-b670-4c7c35544759"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
""
],
"application/javascript": [
"\n",
" ((filepath) => {{\n",
" if (!google.colab.kernel.accessAllowed) {{\n",
" return;\n",
" }}\n",
" google.colab.files.view(filepath);\n",
" }})(\"/usr/local/lib/python3.9/dist-packages/clu/preprocess_spec.py\")"
]
},
"metadata": {}
}
],
"source": [
"from clu import metrics\n",
"from google.colab import files\n",
"files.view(metrics.__file__.replace('metrics', 'preprocess_spec'))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d8szCvldN-F1",
"outputId": "5bf83cc3-7506-41b3-9421-c95cc7053efa"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[('central_crop', __main__.CentralCrop),\n",
" ('onehot', __main__.Onehot),\n",
" ('to_float', __main__.ToFloat)]"
]
},
"metadata": {},
"execution_count": 27
}
],
"source": [
"# Preprocessing specs are defined as dataclasses.\n",
"\n",
"# This is more verbose the first time an op is defined, but it lets you build\n",
"# up a library of preprocessing ops over time and then construct complex\n",
"# preprocessing pipelines from a single string.\n",
"\n",
"import dataclasses\n",
"from typing import Optional\n",
"\n",
"from clu import preprocess_spec\n",
"\n",
"Features = preprocess_spec.Features\n",
"\n",
"@dataclasses.dataclass\n",
"class ToFloat:\n",
" name: str = \"image\"\n",
" def __call__(self, features: Features) -> Features:\n",
" return {\n",
" k: tf.cast(v, tf.float32) / 255.0 if k == self.name else v\n",
" for k, v in features.items()\n",
" }\n",
"\n",
"@dataclasses.dataclass\n",
"class CentralCrop:\n",
" size: int\n",
" crop_factor: float = 0.875\n",
" name: str = \"image\"\n",
"\n",
" def resize(self, img):\n",
" resize_sz = int(self.size / self.crop_factor)\n",
" img = tf.image.resize(img, [resize_sz] * 2)\n",
" return tf.image.central_crop(img, self.crop_factor)\n",
"\n",
" def __call__(self, features: Features) -> Features:\n",
" return {\n",
" k: self.resize(v) if k == self.name else v\n",
" for k, v in features.items()\n",
" }\n",
"\n",
"@dataclasses.dataclass\n",
"class Onehot:\n",
" num_classes: int\n",
" name: str\n",
" def __call__(self, features: Features) -> Features:\n",
" return {\n",
" k: tf.one_hot(v, self.num_classes) if k == self.name else v\n",
" for k, v in features.items()\n",
" }\n",
"\n",
"# Handy helper that auto-discovers the preprocessing ops defined in a module.\n",
"preprocess_spec.get_all_ops(__name__)"
]
},
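{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick worked check of the `CentralCrop` arithmetic, in plain Python without TensorFlow. With `size=224` and the default `crop_factor=0.875`, the image is first resized to `int(224 / 0.875)` pixels per side. Keeping the central `crop_factor` fraction of that image then recovers the same 256 to 224 crop as the manual example above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same formula as CentralCrop.resize(), evaluated by hand.\n",
"size, crop_factor = 224, 0.875\n",
"resize_sz = int(size / crop_factor)\n",
"print(resize_sz)  # 256\n",
"# Keeping the central 87.5% of each side of a 256x256 image leaves 224x224.\n",
"crop_sz = int(resize_sz * crop_factor)\n",
"print(crop_sz)  # 224"
]
},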
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "GEfLg-z7Z3m2"
},
"outputs": [],
"source": [
"# A preprocessing function is a chain of operations:\n",
"preprocess_fn = preprocess_spec.PreprocessFn([\n",
" ToFloat(), CentralCrop(224), Onehot(5, 'label')],\n",
" # (TensorFlow has types that do not exist in JAX. In our case we don't have\n",
" # any, so this flag does not change anything.)\n",
" only_jax_types=True)"
]
},
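{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, applying the chain means calling each op on the output of the previous one. The sketch below illustrates this with a hypothetical `apply_ops()` helper and toy ops that work on plain Python numbers. It is not clu's implementation, and it omits the per-op logging and the `only_jax_types` handling that `PreprocessFn` provides."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (not part of clu): apply ops left to right.\n",
"import functools\n",
"\n",
"def apply_ops(ops, features):\n",
"  return functools.reduce(lambda feats, op: op(feats), ops, features)\n",
"\n",
"# Toy ops that, like the ops above, map a features dict to a new dict.\n",
"def double_x(features):\n",
"  return {k: v * 2 if k == 'x' else v for k, v in features.items()}\n",
"\n",
"def add_one_y(features):\n",
"  return {k: v + 1 if k == 'y' else v for k, v in features.items()}\n",
"\n",
"result = apply_ops([double_x, add_one_y], {'x': 3, 'y': 10})\n",
"print(result)  # {'x': 6, 'y': 11}"
]
},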
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"id": "pFulvbfzJsSC"
},
"outputs": [],
"source": [
"# It's often useful to specify the entire `PreprocessFn` as a string config:\n",
"preprocess_fn = preprocess_spec.parse(\n",
" spec=\"to_float()|central_crop(224)|onehot(5, name='label')\",\n",
" available_ops=dict(preprocess_spec.get_all_ops(__name__)),\n",
")"
]
},
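{
"cell_type": "markdown",
"metadata": {},
"source": [
"To show what the string format encodes, here is a rough sketch using a hypothetical `parse_spec_sketch()` helper. It is not clu's parser. It splits a spec on `|` and uses Python's `ast` module to read each op's name and literal arguments. It assumes no argument value contains a `|` character."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"\n",
"def parse_spec_sketch(spec):\n",
"  # Hypothetical illustration only, not clu's parser.\n",
"  ops = []\n",
"  for part in spec.split('|'):\n",
"    call = ast.parse(part.strip(), mode='eval').body  # an ast.Call node\n",
"    name = call.func.id\n",
"    args = [ast.literal_eval(a) for a in call.args]\n",
"    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}\n",
"    ops.append((name, args, kwargs))\n",
"  return ops\n",
"\n",
"parsed = parse_spec_sketch(\"to_float()|central_crop(224)|onehot(5, name='label')\")\n",
"print(parsed)"
]
},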
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6xLNluL4KGM6",
"outputId": "3ba41a15-bab9-4247-891d-b4d548c21abd"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Constructing tf.data.Dataset tf_flowers for split train, from /root/tensorflow_datasets/tf_flowers/3.0.1\n",
"INFO:absl:Features before preprocessing: {'image': 'uint8[None, None, 3]', 'label': 'int64[]'}\n",
"INFO:absl:Features after op ToFloat(name='image'):\n",
"{'image': 'float32[None, None, 3]', 'label': 'int64[]'}\n",
"INFO:absl:Features after op CentralCrop(size=224, crop_factor=0.875, name='image'):\n",
"{'image': 'float32[224, 224, 3]', 'label': 'int64[]'}\n",
"INFO:absl:Features after op Onehot(num_classes=5, name='label'):\n",
"{'image': 'float32[224, 224, 3]', 'label': 'float32[5]'}\n",
"INFO:absl:Features after preprocessing: {'image': 'float32[224, 224, 3]', 'label': 'float32[5]'}\n"
]
}
],
"source": [
"# While we can `.map()` the preprocessing function directly, we would usually\n",
"# pass it to `deterministic_data.create_dataset()` instead.\n",
"batch = next(iter(dataset_builder.as_dataset('train').map(preprocess_fn)))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "S05K4qPzKLNt",
"outputId": "b95f53ad-783c-4bbc-c9c2-c3d3b5c75f6a"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"([224, 224, 3], 1.0)"
]
},
"metadata": {},
"execution_count": 31
}
],
"source": [
"batch['image'].shape.as_list(), batch['image'].numpy().max()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-8pqb5PtMyYX",
"outputId": "497982f8-3b30-4159-d95d-3cd436d4d8bd"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 32
}
],
"source": [
"batch['label']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TMtQ9j9YOvt4"
},
"source": [
"### Complete evaluation\n",
"\n",
"When evaluating on multiple devices, every device requires the same batch size.\n",
"Thus, it's not straightforward to evaluate a complete dataset if the number of\n",
"examples in that dataset is not divisible by the global batch size (the\n",
"per-device batch size times the number of devices).\n",
"\n",
"`clu` provides a simple recipe for this:\n",
"\n",
"1. Pad incomplete batches, adding a boolean \"mask\" feature that indicates\n",
" whether an example should be considered when computing metrics.\n",
"2. Take \"mask\" into consideration when computing metrics."
]
},
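{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two steps above can be illustrated with a small NumPy sketch. The helpers `pad_batch()` and `masked_accuracy()` are hypothetical and only demonstrate the idea. They are not clu APIs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def pad_batch(labels, batch_size):\n",
"  # Hypothetical: pad an incomplete batch to `batch_size` and return a\n",
"  # boolean mask that is True for real examples and False for padding.\n",
"  n = len(labels)\n",
"  padded = np.zeros(batch_size, dtype=labels.dtype)\n",
"  padded[:n] = labels\n",
"  mask = np.arange(batch_size) < n\n",
"  return padded, mask\n",
"\n",
"def masked_accuracy(predictions, labels, mask):\n",
"  # Hypothetical: only positions where mask is True are counted.\n",
"  correct = (predictions == labels) & mask\n",
"  return correct.sum() / mask.sum()\n",
"\n",
"labels = np.array([1, 2, 3])  # the last batch only has 3 real examples\n",
"padded_labels, mask = pad_batch(labels, batch_size=4)\n",
"predictions = np.array([1, 2, 0, 0])  # model output for all 4 slots\n",
"# The padded slot happens to 'match' (0 == 0) but is excluded by the mask,\n",
"# so accuracy is 2/3 rather than an inflated 3/4.\n",
"print(masked_accuracy(predictions, padded_labels, mask))"
]
},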
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VTSQCXaGSFra",
"outputId": "01a77470-6d16-4270-e141-c7257a5315be"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Load dataset info from /root/tensorflow_datasets/mnist/3.0.1\n",
"INFO:absl:Fields info.[citation, splits, supervised_keys, module_name] from disk and from code do not match. Keeping the one from code.\n",
"INFO:absl:Reusing dataset mnist (/root/tensorflow_datasets/mnist/3.0.1)\n"
]
}
],
"source": [
"dataset_builder = tfds.builder('mnist')\n",
"dataset_builder.download_and_prepare()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "erStmde9Qm6e",
"outputId": "d4649627-643a-454b-f8ed-8e6ae2fb936a"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"10000"
]
},
"metadata": {},
"execution_count": 34
}
],
"source": [
"num_test_examples = dataset_builder.info.splits['test'].num_examples\n",
"num_test_examples"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bKAMpo_JUOnm",
"outputId": "2e8f5230-ab4b-4be4-b66a-56a38ccc199d"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Constructing tf.data.Dataset mnist for split test, from /root/tensorflow_datasets/mnist/3.0.1\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Counter({2: 1032,\n",
" 0: 980,\n",
" 4: 982,\n",
" 8: 974,\n",
" 7: 1028,\n",
" 6: 958,\n",
" 3: 1010,\n",
" 1: 1135,\n",
" 9: 1009,\n",
" 5: 892})"
]
},
"metadata": {},
"execution_count": 35
}
],
"source": [
"# For reference: label distribution in the complete test set.\n",
"import collections\n",
"collections.Counter(next(iter(dataset_builder.as_dataset('test').batch(num_test_examples)))['label'].numpy())"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pevkYjCASXJK",
"outputId": "b3dd2424-bca5-436a-f71e-24685159c098"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Constructing tf.data.Dataset mnist for split test, from /root/tensorflow_datasets/mnist/3.0.1\n"
]
}
],
"source": [
"# Naive approach without padding.\n",
"batch_size = 128\n",
"local_batch_size = batch_size // jax.device_count()\n",
"test_ds = deterministic_data.create_dataset(\n",
" dataset_builder,\n",
" split='test',\n",
" batch_dims=[jax.local_device_count(), local_batch_size],\n",
" num_epochs=1,\n",
" shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "akwTsGHmSjtt",
"outputId": "b55d3914-1c3d-4755-841f-a69b3e34cd6a"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'accuracy': Array(0.09795673, dtype=float32)}"
]
},
"metadata": {},
"execution_count": 37
}
],
"source": [
"def get_logits(images):\n",
" # Always predict label=0\n",
" return jnp.tile(jax.nn.one_hot(0, 10)[None], [len(images), 1])\n",
"\n",
"def eval_step(batch):\n",
" logits = get_logits(batch['image'])\n",
" return metrics.Collection.create(\n",
" accuracy=metrics.Accuracy\n",
" ).gather_from_model_output(logits=logits, labels=batch['label'])\n",
"\n",
"eval_step_p = jax.pmap(eval_step, axis_name='batch')\n",
"\n",
"my_metrics = None\n",
"for batch in test_ds:\n",
" batch = jax.tree_util.tree_map(np.asarray, batch)\n",
" update = eval_step_p(batch).unreplicate()\n",
" my_metrics = update if my_metrics is None else my_metrics.merge(update)\n",
"\n",
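"# Note that the correct accuracy would be exactly 0.098 (980 of the 10000 test\n",
"# examples have label 0), but here the incomplete final batch is dropped.\n",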
"my_metrics.compute()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tD2cdCpSXJuI",
"outputId": "2e639b5f-00b1-4548-8190-68750df4fd99"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"78.125"
]
},
"metadata": {},
"execution_count": 38
}
],
"source": [
"# The difference is due to the fact that the number of examples in the test set\n",
"# is not divisible by the batch size.\n",
"num_test_examples / batch_size\n",
"# Note that in this case we could have chosen a batch size that both divides the\n",
"# number of examples in the test set and is divisible by the number of local\n",
"# devices (e.g. 200), but that is not possible in the general case, and for\n",
"# performance/memory reasons we might be constrained in our choice of batch size."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "gWZ4lwZQYPl3",
"outputId": "7f52a24f-3fd2-484d-fb0f-83a000010765"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"\"_EvenSplit(split='test[:10000]', index=0, count=1, drop_remainder=False)\""
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
}
},
"metadata": {},
"execution_count": 39
}
],
"source": [
"# For completeness of the example, let's pretend we evaluate on multiple hosts.\n",
"test_split = tfds.split_for_jax_process(f'test[:{num_test_examples}]')\n",
"str(test_split)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aj7lIAlaY1gS",
"outputId": "61df4079-366b-4147-d4a4-434ed2ba4ebd"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"112"
]
},
"metadata": {},
"execution_count": 40
}
],
"source": [
"# Compute how many batches we need to contain the entire test set.\n",
"global_batch_size = batch_size * jax.process_count()\n",
"pad_up_to_batches = int(np.ceil(num_test_examples / global_batch_size))\n",
"# This should result in 112 padded examples at the very end.\n",
"global_batch_size * pad_up_to_batches - num_test_examples"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IVjomDnuRQL1",
"outputId": "c7d94c57-aed0-4471-c3b1-141afbe5e42a"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"INFO:absl:Constructing tf.data.Dataset mnist for split _EvenSplit(split='test[:10000]', index=0, count=1, drop_remainder=False), from /root/tensorflow_datasets/mnist/3.0.1\n"
]
}
],
"source": [
"test_ds = deterministic_data.create_dataset(\n",
" dataset_builder,\n",
" split=test_split,\n",
" batch_dims=[jax.local_device_count(), local_batch_size],\n",
" num_epochs=1,\n",
" # Pad with masked examples instead of dropping incomplete final batch.\n",
" pad_up_to_batches=pad_up_to_batches,\n",
" shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jb6uwPHlQbmI",
"outputId": "da1fdfdb-dcf2-4666-d7e2-51bb91ac2580"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'accuracy': Array(0.098, dtype=float32)}"
]
},
"metadata": {},
"execution_count": 42
}
],
"source": [
"# Same cell as above, but this time with masking and the correct final accuracy.\n",
"\n",
"def eval_step(batch):\n",
" logits = get_logits(batch['image'])\n",
" return metrics.Collection.create(\n",
" accuracy=metrics.Accuracy\n",
" ).gather_from_model_output(\n",
" logits=logits,\n",
" labels=batch['label'],\n",
" # IMPORTANT: You must pass in the \"mask\" feature as an additional model\n",
" # output so `metrics.Accuracy` is aware of it. Otherwise the padded examples\n",
" # would silently be included in the metric computation, because\n",
" # `metrics.Accuracy` works both with and without a \"mask\". In the usual\n",
" # case you would implement the metrics yourself, and the code would fail if\n",
" # you forgot to specify the mask here.\n",
" mask=batch['mask'],\n",
" )\n",
"\n",
"eval_step_p = jax.pmap(eval_step, axis_name='batch')\n",
"\n",
"my_metrics = None\n",
"for batch in test_ds:\n",
" batch = jax.tree_util.tree_map(np.asarray, batch)\n",
" update = eval_step_p(batch).unreplicate()\n",
" my_metrics = update if my_metrics is None else my_metrics.merge(update)\n",
"\n",
"my_metrics.compute()"
]
}
],
"metadata": {
"colab": {
"name": "clu synopsis",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: setup.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""setup.py for Common Loop Utils.
Install for development:
  pip install -e ".[test]"
"""
import sys
from setuptools import find_packages
from setuptools import setup
if sys.version_info < (3, 10):
    sys.exit("Python < 3.10 not supported anymore!")
tests_require = [
    "pytest",
    "tensorflow",
    "tensorflow_datasets",
    "torch>=2.0.0",
]
setup(
    name="clu",
    version="0.0.12",
    description="Set of libraries for ML training loops in JAX.",
    author="Common Loop Utils Authors",
    author_email="no-reply@google.com",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    url="http://github.com/google/CommonLoopUtils",
    license="Apache 2.0",
    packages=find_packages(),
    include_package_data=True,
    install_requires=[
        "absl-py",
        "etils[epath,epy]",
        "flax",
        "jax",
        "jaxlib",
        "ml_collections",
        "numpy",
        "packaging",
        "typing_extensions",
        "wrapt",
    ],
    tests_require=tests_require,
    extras_require=dict(test=tests_require),
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="JAX machine learning",
)