Repository: ikarth/wfc_2019f Branch: master Commit: 3a937fed1393 Files: 45 Total size: 180.0 KB Directory structure: gitextract_a9o9q6ic/ ├── .github/ │ └── workflows/ │ └── python-package.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── doc/ │ ├── conf.py │ ├── dot/ │ │ ├── chain.dot │ │ ├── dependency.dot │ │ └── design.dot │ └── index.rst ├── images/ │ └── samples/ │ ├── Castle/ │ │ └── data.xml │ ├── Circles/ │ │ └── data.xml │ ├── Circuit/ │ │ └── data.xml │ ├── Knots/ │ │ └── data.xml │ ├── Rooms/ │ │ └── data.xml │ └── Summer/ │ └── data.xml ├── pyproject.toml ├── requirements.txt ├── samples.xml ├── samples_cats.xml ├── samples_original.xml ├── samples_reference.xml ├── samples_reference_continue.xml ├── samples_reference_nohogs.xml ├── samples_test.xml ├── samples_test_ground.xml ├── samples_test_vis.xml ├── setup.cfg ├── setup.py ├── tests/ │ ├── __init__.py │ ├── conftest.py │ ├── test_wfc_adjacency.py │ ├── test_wfc_patterns.py │ ├── test_wfc_solver.py │ └── test_wfc_tiles.py ├── wfc/ │ ├── __init__.py │ ├── py.typed │ ├── wfc_adjacency.py │ ├── wfc_control.py │ ├── wfc_patterns.py │ ├── wfc_solver.py │ ├── wfc_tiles.py │ ├── wfc_utilities.py │ └── wfc_visualize.py └── wfc_run.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/python-package.yml ================================================ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions name: Python package on: push: pull_request: defaults: run: shell: bash jobs: mypy: runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: Install Python dependencies run: pip install mypy -r requirements.txt - name: Mypy uses: liskin/gh-problem-matcher-wrap@v1 with: linters: mypy 
run: mypy --show-column-numbers . pytest: runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: Install Python dependencies run: pip install -r requirements.txt - name: Test with pytest run: pytest wfc_run: runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: 3.x - name: Build package run: pip install --editable . - name: Run all experiments run: | python ./wfc_run.py -e simple -s samples.xml python ./wfc_run.py -e choice -s samples.xml python ./wfc_run.py -e choices -s samples.xml python ./wfc_run.py -e heuristic -s samples.xml python ./wfc_run.py -e backtracking -s samples.xml python ./wfc_run.py -e backtracking_heuristic -s samples.xml - name: Package output folder run: tar -cf wfc-output.tar --format=ustar output/ - uses: actions/upload-artifact@v2 with: name: wfc-output path: wfc-output.tar retention-days: 7 ================================================ FILE: .gitignore ================================================ __pycache__ /output/* /build logs/ *.egg-info/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Isaac Karth Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ include wfc/py.typed ================================================ FILE: README.md ================================================ # wfc_2019f This is my research implementation of WaveFunctionCollapse in Python. It has two goals: * Make it easier to understand how the algorithm operates * Provide a testbed for experimenting with alternate heuristics and features For more general-purpose WFC information, the original reference repository remains the best resource: https://github.com/mxgmn/WaveFunctionCollapse ## Installing ``` git clone https://github.com/ikarth/wfc_2019f.git cd wfc_2019f conda create -n wfc2019 python=3.10 conda activate wfc2019 pip install -r requirements.txt python wfc_run.py -s samples_reference.xml ``` ## Running WFC If you want direct control over running WFC, call `wfc_control.execute_wfc()`. The arguments it accepts are: - `filename=None`: name of a sample image in `./images/samples/`, without the `.png` extension (it is loaded from `./images/samples/<filename>.png`). This is mostly for internal use and should be left as `None`; set `image` instead.
- `tile_size=1`: size of the tiles it uses (1 is fine for pixel images, larger is for things like a Super Metroid map) - `pattern_width=2`: size of the patterns; usually 2 or 3, because larger patterns make the solver slower - `rotations=8`: how many reflections and/or rotations to use with the patterns - `output_size=(48, 48)`: how big the output image is - `ground=None`: which patterns should be placed along the bottom-most line - `attempt_limit=10`: stop after this many tries - `output_periodic=True`: the output wraps at the edges - `input_periodic=True`: the input wraps at the edges - `loc_heuristic="entropy"`: what location heuristic to use; `entropy` is the original WFC behavior. The heuristics that are implemented are `lexical`, `hilbert`, `spiral`, `entropy`, `anti-entropy`, `simple`, `random`, but when in doubt stick with `entropy`. - `choice_heuristic="weighted"`: what choice heuristic to use; `weighted` is the original WFC behavior, other options are `random`, `rarest`, and `lexical`. - `visualize=False`: write intermediate images to disk? Requires `filename`. - `global_constraint=False`: what global constraint to use. Currently the only one implemented is `allpatterns`. - `backtracking=False`: do we use backtracking if we run into a contradiction? - `log_filename="log"`: what should the log file be named? - `logging=False`: should we write to a log file? Requires `filename`. - `log_stats_to_output=None`: optional callback taking `(stats, filename)`, such as the function returned by `wfc_control.make_log_stats()` - `image`: an array of pixel data, typically of shape `(height, width, 3)` (RGB) ## Test ``` pytest ``` ## Documentation ``` python setup.py build_sphinx ``` On Linux, the documentation can be displayed with: ``` xdg-open build/sphinx/index.html ``` ================================================ FILE: doc/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options.
For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # # import os # import sys # sys.path.insert(0, os.path.abspath('.')) # -- Project information ----------------------------------------------------- project = 'wfc_python' copyright = '2020, Isaac Karth' author = 'Isaac Karth' # The full version, including alpha/beta/rc tags release = '0.1' # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.graphviz', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [] # type: ignore[var-annotated] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
html_static_path = ['static'] ================================================ FILE: doc/dot/chain.dot ================================================ digraph { read_xml_command -> import_image -> make_tile_catalog -> make_pattern_catalog -> make_adjacency_matrix -> solve_constraint_problem -> output_solution_image make_tile_catalog -> output_solution_image make_tile_catalog -> instrumentation [color=gray] make_pattern_catalog -> instrumentation [color=gray] make_adjacency_matrix -> instrumentation [color=gray] solve_constraint_problem -> instrumentation [color=gray] output_solution_image -> visualization [color=gray] make_tile_catalog -> visualization [color=gray] make_pattern_catalog -> visualization [color=gray] make_adjacency_matrix -> visualization [color=gray] solve_constraint_problem -> visualization [color=gray] visualization [color=gray, fontcolor=gray] instrumentation [color=gray, fontcolor=gray] visualization -> make_tile_catalog [color=magenta] } ================================================ FILE: doc/dot/dependency.dot ================================================ digraph { wfc_run -> wfc_control wfc_control -> wfc_utilities wfc_control -> wfc_solver wfc_solver -> numpy wfc_tiles -> numpy wfc_patterns -> numpy wfc_tiles -> wfc_utilities wfc_control -> wfc_tiles wfc_control -> wfc_patterns wfc_patterns -> wfc_utilities wfc_tiles -> imageio wfc_control -> wfc_adjacency wfc_control -> wfc_visualize wfc_visualize -> matplotlib wfc_visualize -> wfc_utilities wfc_adjacency -> wfc_utilities wfc_adjacency -> numpy wfc_control -> wfc_instrumentation implemented [style=filled, fillcolor=gray] partial [style=filled, fillcolor=cyan] unimplemented [style=filled, fillcolor=firebrick] wfc_run wfc_control [] wfc_solver numpy [color=gray, fontcolor=gray] wfc_tiles wfc_patterns [style=filled, fillcolor=cyan] wfc_utilities imageio [color=gray, fontcolor=gray] wfc_adjacency wfc_visualize [style=filled, fillcolor=cyan] matplotlib [color=gray, fontcolor=gray] 
wfc_instrumentation [style=filled, fillcolor=firebrick] label="Modules in WFC 19f" } ================================================ FILE: doc/dot/design.dot ================================================ digraph { things_to_implement [label="{Things that aren't implemented yet|Intermediate visualization|timing and profiling|performance statistics|outputting images|most heuristics|removing ground patterns|rotated patterns}", shape=record, fillcolor="cyan", style=filled] read_data [label="Read data from XML", fillcolor="cyan", shape=box, style=filled] read_data -> input_data input_data [shape=record, label="XML"] input_data -> execute_wfc solver [label="Solver", shape=house] solver -> make_wave make_wave -> remove_patterns remove_patterns [label="Remove ground patterns", fillcolor="cyan", style=filled] input_data -> remove_patterns input_data -> solver remove_patterns -> solver_run [headport=n] subgraph cluster_solver_run { label="wfc_solver.py" solver_run [label="solver.run()"] solver_observe [label="solver.observe()"] solver_propagate [label="solver.propagate()"] solver_on_backtrack [label="solver.onBacktrack()", shape=invhouse] solver_on_choice [label="solver.onChoice()", shape=invhouse] on_choice [label="onChoice()", shape=note] on_backtrack [label="onBacktrack()", shape=note, fillcolor="cyan", style=filled] solver_if_backtracking [label="if backtracking", shape=diamond] pattern_heuristic [label="pattern heuristic", shape=note] location_heuristic [label="location heuristic", shape=note] {rank=same pattern_heuristic location_heuristic} solver_run -> solver_check_feasible solver_check_feasible -> solver_propagate solver_propagate -> solver_observe solver_observe -> pattern_heuristic solver_observe -> location_heuristic solver_observe -> solver_on_choice solver_on_choice -> on_choice solver_recurse -> except_contradictions [color=red] solver_on_choice -> solver_if_finished solver_recurse -> solver_run [headport=n, tailport=w] solver_if_finished [shape=diamond] 
solver_if_finished -> solver_recurse [splines=polyline, dir=both, arrowhead=dotvee, arrowtail=dot, tailport=s, headport=n, color="black:green:black"] except_contradictions -> solver_if_backtracking solver_if_backtracking -> solver_on_backtrack [label="Yes"] solver_on_backtrack -> on_backtrack on_backtrack -> solver_run [headport=n] solver_if_backtracking -> cant_solve [splines=curved, label="No", dir=both, arrowhead=dotvee, arrowtail=dot, tailport=e, headport=ne, color="grey"] } solver_if_finished -> solver_solution [tailport=w, color="black:blue:black"] execute_wfc [shape=invhouse, fillcolor="cyan", style=filled] execute_wfc -> import_image import_image [shape=box] import_image -> make_tile_catalog subgraph cluster_tile_py { label="wfc_tiles.py" make_tile_catalog -> image_to_tiles } image_to_tiles -> tile_catalog tile_catalog [label="Tile Catalog|{dictionary of tiles|image in tile IDs|set of tiles|frequency of tile occurance}", shape=record] subgraph cluster_patterns { label="wfc_patterns.py" tile_catalog -> make_pattern_catalog {rank=same unique_patterns_2d rotate_or_reflect} make_pattern_catalog -> unique_patterns_2d -> rotate_or_reflect -> unique_patterns_2d make_pattern_catalog [fillcolor="cyan", style=filled] rotate_or_reflect [fillcolor="cyan", style=filled] } unique_patterns_2d -> pattern_catalog pattern_catalog [label="Pattern Catalog|{dictionary of patterns|ordered list of pattern weights|ordered list of pattern contents}", shape=record] pattern_catalog -> extract_adjacency subgraph cluster_adjacency { extract_adjacency -> is_valid_overlap } extract_adjacency -> adjacency_relations adjacency_relations [label="{Adjacency Relations|tuples of (edge,pattern,pattern)}", shape=record] adjacency_relations -> combine_inputs combine_inputs -> adjacency_matrix adjacency_matrix [label="{Adjacency Matrix|boolean matrix of pattern x pattern x direction}", shape=record] adjacency_matrix -> solver pattern_catalog -> solver cant_solve [label="Can't Solve", shape=box] 
solver_solution [shape=record, label="Solution|grid of pattern IDs"] solver_solution -> visualizer visualizer -> output_image output_image [shape=box, label="Output Image", style=filled, fillcolor=cyan] pattern_catalog -> visualizer tile_catalog -> visualizer visualizer [fillcolor=cyan, style=filled] } ================================================ FILE: doc/index.rst ================================================ Documentation ============= Module dependencies ------------------- .. graphviz:: dot/dependency.dot Design ------ .. graphviz:: dot/design.dot Chain ----- .. graphviz:: dot/chain.dot ================================================ FILE: images/samples/Castle/data.xml ================================================ <<<<<<< HEAD ======= >>>>>>> 1fe4f1f60ebd57b99ca7148fb003edefa7979d94 ================================================ FILE: images/samples/Circles/data.xml ================================================ <<<<<<< HEAD ======= >>>>>>> 1fe4f1f60ebd57b99ca7148fb003edefa7979d94 ================================================ FILE: images/samples/Circuit/data.xml ================================================ <<<<<<< HEAD ======= >>>>>>> 1fe4f1f60ebd57b99ca7148fb003edefa7979d94 ================================================ FILE: images/samples/Knots/data.xml ================================================ <<<<<<< HEAD ======= >>>>>>> 1fe4f1f60ebd57b99ca7148fb003edefa7979d94 ================================================ FILE: images/samples/Rooms/data.xml ================================================ <<<<<<< HEAD ======= >>>>>>> 1fe4f1f60ebd57b99ca7148fb003edefa7979d94 ================================================ FILE: images/samples/Summer/data.xml ================================================ <<<<<<< HEAD ======= >>>>>>> 1fe4f1f60ebd57b99ca7148fb003edefa7979d94 ================================================ FILE: pyproject.toml ================================================ [project] name = "wfc_python" version = 
"0.0.0" description = "Implementation of wave function collapse in Python." readme = "README.md" requires-python = ">=3.5" license = {file = "LICENSE"} keywords = ["sample", "wfc", "wave function collapse"] authors = [{name = "Isaac Karth", email = "isaac@isaackarth.com"}] classifiers = [ "Development Status :: 3 - Alpha", "Topic :: Utilities", "License :: OSI Approved :: MIT License", ] dependencies = [ "hilbertcurve", "imageio", "matplotlib", "numpy", "scipy" ] [project.optional-dependencies] tests = ["pytest"] docs = ["sphinx"] [project.urls] homepage = "https://github.com/ikarth/wfc_python" [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" ================================================ FILE: requirements.txt ================================================ numpy hilbertcurve>=2 imageio matplotlib scipy # Testing pytest types-setuptools # Documentation. sphinx ================================================ FILE: samples.xml ================================================ ================================================ FILE: samples_cats.xml ================================================ ================================================ FILE: samples_original.xml ================================================ ================================================ FILE: samples_reference.xml ================================================ ================================================ FILE: samples_reference_continue.xml ================================================ ================================================ FILE: samples_reference_nohogs.xml ================================================ ================================================ FILE: samples_test.xml ================================================ ================================================ FILE: samples_test_ground.xml ================================================ ================================================ FILE: samples_test_vis.xml 
================================================ ================================================ FILE: setup.cfg ================================================ [metadata] name = wfc_python version = 0.0.0 [options] packages = wfc include_package_data = True install_requires = hilbertcurve>=2 imageio matplotlib numpy scipy [options.package_data] wfc = py.typed ================================================ FILE: setup.py ================================================ #!/usr/bin/env python import setuptools if __name__ == "__main__": setuptools.setup() ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/conftest.py ================================================ from __future__ import annotations import os.path import pytest PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) class Resources: def get_image(self, image: str) -> str: return os.path.join(PROJECT_ROOT, "images", image) @pytest.fixture(scope="session") def resources() -> Resources: return Resources() ================================================ FILE: tests/test_wfc_adjacency.py ================================================ """Convert input data to adjacency information""" from __future__ import annotations import imageio # type: ignore from tests.conftest import Resources from wfc import wfc_tiles from wfc import wfc_patterns from wfc import wfc_adjacency def test_adjacency_extraction(resources: Resources) -> None: # TODO: generalize this to more than the four cardinal directions direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)])) filename = resources.get_image("samples/Red Maze.png") img = imageio.imread(filename) tile_size = 1 pattern_width = 2 periodic = False _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(img, tile_size) pattern_catalog, _pattern_weights, _pattern_list, pattern_grid = 
wfc_patterns.make_pattern_catalog( tile_grid, pattern_width, periodic ) adjacency_relations = wfc_adjacency.adjacency_extraction( pattern_grid, pattern_catalog, direction_offsets ) assert ((0, -1), -6150964001204120324, -4042134092912931260) in adjacency_relations assert ((-1, 0), -4042134092912931260, 3069048847358774683) in adjacency_relations assert ((1, 0), -3950451988873469076, -3950451988873469076) in adjacency_relations assert ((-1, 0), -3950451988873469076, -3950451988873469076) in adjacency_relations assert ((0, 1), -3950451988873469076, 3336256675067683735) in adjacency_relations assert ( not ((0, -1), -3950451988873469076, -3950451988873469076) in adjacency_relations ) assert ( not ((0, 1), -3950451988873469076, -3950451988873469076) in adjacency_relations ) ================================================ FILE: tests/test_wfc_patterns.py ================================================ from __future__ import annotations import imageio # type: ignore import numpy as np from tests.conftest import Resources from wfc import wfc_patterns from wfc import wfc_tiles def test_unique_patterns_2d(resources: Resources) -> None: filename = resources.get_image("samples/Red Maze.png") img = imageio.imread(filename) tile_size = 1 pattern_width = 2 _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(img, tile_size) _patterns_in_grid, pattern_contents_list, patch_codes = wfc_patterns.unique_patterns_2d( tile_grid, pattern_width, True ) assert patch_codes[1][2] == 4867810695119132864 assert pattern_contents_list[7][1][1] == 8253868773529191888 def test_make_pattern_catalog(resources: Resources) -> None: filename = resources.get_image("samples/Red Maze.png") img = imageio.imread(filename) tile_size = 1 pattern_width = 2 _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(img, tile_size) pattern_catalog, pattern_weights, pattern_list, _pattern_grid = wfc_patterns.make_pattern_catalog( tile_grid, pattern_width ) 
assert pattern_weights[-6150964001204120324] == 1 assert pattern_list[3] == 2800765426490226432 assert pattern_catalog[5177878755649963747][0][1] == -8754995591521426669 def test_pattern_to_tile(resources: Resources) -> None: filename = resources.get_image("samples/Red Maze.png") img = imageio.imread(filename) tile_size = 1 pattern_width = 2 _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(img, tile_size) pattern_catalog, _pattern_weights, _pattern_list, pattern_grid = wfc_patterns.make_pattern_catalog( tile_grid, pattern_width ) new_tile_grid = wfc_patterns.pattern_grid_to_tiles(pattern_grid, pattern_catalog) assert np.array_equal(tile_grid, new_tile_grid) ================================================ FILE: tests/test_wfc_solver.py ================================================ from __future__ import annotations from typing import Any, Dict, List, Set, Tuple from numpy.typing import NDArray import imageio # type: ignore import numpy import numpy as np from tests.conftest import Resources from wfc import wfc_solver from wfc import wfc_tiles from wfc import wfc_patterns from wfc import wfc_adjacency def test_makeWave() -> None: wave = wfc_solver.makeWave(3, 10, 20, ground=[-1]) # print(wave) # print(wave.sum()) # print((2*10*19) + (1*10*1)) assert wave.sum() == (2 * 10 * 19) + (1 * 10 * 1) assert wave[2, 5, 19] == True assert wave[1, 5, 19] == False def test_entropyLocationHeuristic() -> None: wave = numpy.ones((5, 3, 4), dtype=bool) # everthing is possible wave[1:, 0, 0] = False # first cell is fully observed wave[4, :, 2] = False preferences: NDArray[np.float_] = numpy.ones((3, 4), dtype=np.float_) * 0.5 preferences[1, 2] = 0.3 preferences[1, 1] = 0.1 heu = wfc_solver.makeEntropyLocationHeuristic(preferences) result = heu(wave) assert (1, 2) == result def test_observe() -> None: my_wave = numpy.ones((5, 3, 4), dtype=np.bool_) my_wave[0, 1, 2] = False def locHeu(wave: NDArray[np.bool_]) -> Tuple[int, int]: assert 
numpy.array_equal(wave, my_wave) return 1, 2 def patHeu(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int: assert numpy.array_equal(weights, my_wave[:, 1, 2]) return 3 assert wfc_solver.observe(my_wave, locationHeuristic=locHeu, patternHeuristic=patHeu) == ( 3, 1, 2, ) def test_propagate() -> None: wave = numpy.ones((3, 3, 4), dtype=bool) adjLists = {} # checkerboard #0/#1 or solid fill #2 adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [ [1], [0], [2], ] wave[:, 0, 0] = False wave[0, 0, 0] = True adj = wfc_solver.makeAdj(adjLists) wfc_solver.propagate(wave, adj, periodic=False) expected_result = numpy.array( [ [ [True, False, True, False], [False, True, False, True], [True, False, True, False], ], [ [False, True, False, True], [True, False, True, False], [False, True, False, True], ], [ [False, False, False, False], [False, False, False, False], [False, False, False, False], ], ] ) assert numpy.array_equal(wave, expected_result) def test_run() -> None: wave = wfc_solver.makeWave(3, 3, 4) adjLists = {} adjLists[(+1, 0)] = adjLists[(-1, 0)] = adjLists[(0, +1)] = adjLists[(0, -1)] = [ [1], [0], [2], ] adj = wfc_solver.makeAdj(adjLists) first_result = wfc_solver.run( wave.copy(), adj, locationHeuristic=wfc_solver.lexicalLocationHeuristic, patternHeuristic=wfc_solver.lexicalPatternHeuristic, periodic=False, ) expected_first_result = numpy.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]) assert numpy.array_equal(first_result, expected_first_result) event_log: List[Any] = [] def onChoice(pattern: int, i: int, j: int) -> None: event_log.append((pattern, i, j)) def onBacktrack() -> None: event_log.append("backtrack") second_result = wfc_solver.run( wave.copy(), adj, locationHeuristic=wfc_solver.lexicalLocationHeuristic, patternHeuristic=wfc_solver.lexicalPatternHeuristic, periodic=True, backtracking=True, onChoice=onChoice, onBacktrack=onBacktrack, ) expected_second_result = numpy.array([[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]) 
assert numpy.array_equal(second_result, expected_second_result) print(event_log) assert event_log == [(0, 0, 0), "backtrack", (2, 0, 0)] class Infeasible(Exception): pass def explode(wave: NDArray[np.bool_]) -> bool: if wave.sum() < 20: raise Infeasible return False try: result = wfc_solver.run( wave.copy(), adj, locationHeuristic=wfc_solver.lexicalLocationHeuristic, patternHeuristic=wfc_solver.lexicalPatternHeuristic, periodic=True, backtracking=True, checkFeasible=explode, ) print(result) happy = False except wfc_solver.Contradiction: happy = True assert happy ================================================ FILE: tests/test_wfc_tiles.py ================================================ """Breaks an image into consituant tiles.""" from __future__ import annotations import imageio # type: ignore from tests.conftest import Resources from wfc import wfc_tiles def test_image_to_tile(resources: Resources) -> None: filename = resources.get_image("samples/Red Maze.png") img = imageio.imread(filename) tiles = wfc_tiles.image_to_tiles(img, 1) assert tiles[2][2][0][0][0] == 255 assert tiles[2][2][0][0][1] == 0 def test_make_tile_catalog(resources: Resources) -> None: filename = resources.get_image("samples/Red Maze.png") img = imageio.imread(filename) print(img) tc, tg, cl, ut = wfc_tiles.make_tile_catalog(img, 1) print("tile catalog") print(tc) print("tile grid") print(tg) print("code list") print(cl) print("unique tiles") print(ut) assert ut[1][0] == 7 ================================================ FILE: wfc/__init__.py ================================================ ================================================ FILE: wfc/py.typed ================================================ ================================================ FILE: wfc/wfc_adjacency.py ================================================ """Convert input data to adjacency information""" from __future__ import annotations from typing import Dict, List, Tuple import numpy as np from numpy.typing import 
NDArray def adjacency_extraction( pattern_grid: NDArray[np.int64], pattern_catalog: Dict[int, NDArray[np.int64]], direction_offsets: List[Tuple[int, Tuple[int, int]]], pattern_size: Tuple[int, int] = (2, 2), ) -> List[Tuple[Tuple[int, int], int, int]]: """Takes a pattern grid and returns a list of all of the legal adjacencies found in it.""" def is_valid_overlap_xy(adjacency_direction: Tuple[int, int], pattern_1: int, pattern_2: int) -> bool: """Given a direction and two patterns, find the overlap of the two patterns and return True if the intersection matches.""" dimensions = (1, 0) not_a_number = -1 # TODO: can probably speed this up by using the right slices, rather than rolling the whole pattern... shifted = np.roll( np.pad( pattern_catalog[pattern_2], max(pattern_size), mode="constant", constant_values=not_a_number, ), adjacency_direction, dimensions, ) compare = shifted[ pattern_size[0] : pattern_size[0] + pattern_size[0], pattern_size[1] : pattern_size[1] + pattern_size[1], ] left = max(0, 0, +adjacency_direction[0]) right = min(pattern_size[0], pattern_size[0] + adjacency_direction[0]) top = max(0, 0 + adjacency_direction[1]) bottom = min(pattern_size[1], pattern_size[1] + adjacency_direction[1]) a = pattern_catalog[pattern_1][top:bottom, left:right] b = compare[top:bottom, left:right] res = np.array_equal(a, b) return res pattern_list = list(pattern_catalog.keys()) legal = [] for pattern_1 in pattern_list: for pattern_2 in pattern_list: for _direction_index, direction in direction_offsets: if is_valid_overlap_xy(direction, pattern_1, pattern_2): legal.append((direction, pattern_1, pattern_2)) return legal ================================================ FILE: wfc/wfc_control.py ================================================ from __future__ import annotations import datetime from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple from .wfc_tiles import make_tile_catalog from .wfc_patterns import ( pattern_grid_to_tiles, 
make_pattern_catalog_with_rotations, ) from .wfc_adjacency import adjacency_extraction from .wfc_solver import ( run, makeWave, makeAdj, lexicalLocationHeuristic, lexicalPatternHeuristic, makeWeightedPatternHeuristic, Contradiction, StopEarly, makeEntropyLocationHeuristic, make_global_use_all_patterns, makeRandomLocationHeuristic, makeRandomPatternHeuristic, TimedOut, simpleLocationHeuristic, makeSpiralLocationHeuristic, makeHilbertLocationHeuristic, makeAntiEntropyLocationHeuristic, makeRarestPatternHeuristic, ) from .wfc_visualize import ( figure_list_of_tiles, figure_false_color_tile_grid, figure_pattern_catalog, render_tiles_to_output, figure_adjacencies, make_solver_visualizers, make_solver_loggers, tile_grid_to_image, ) import imageio # type: ignore import numpy as np import time import logging from numpy.typing import NDArray logger = logging.getLogger(__name__) def visualize_tiles(unique_tiles, tile_catalog, tile_grid): if False: figure_list_of_tiles(unique_tiles, tile_catalog) figure_false_color_tile_grid(tile_grid) def visualize_patterns(pattern_catalog, tile_catalog, pattern_weights, pattern_width): if False: figure_pattern_catalog( pattern_catalog, tile_catalog, pattern_weights, pattern_width ) def make_log_stats() -> Callable[[Dict[str, Any], str], None]: log_line = 0 def log_stats(stats: Dict[str, Any], filename: str) -> None: nonlocal log_line if stats: log_line += 1 with open(filename, "a", encoding="utf_8") as logf: if log_line < 2: for s in stats.keys(): print(str(s), end="\t", file=logf) print("", file=logf) for s in stats.keys(): print(str(stats[s]), end="\t", file=logf) print("", file=logf) return log_stats def execute_wfc( filename: Optional[str] = None, tile_size: int = 1, pattern_width: int = 2, rotations: int = 8, output_size: Tuple[int, int] = (48, 48), ground: Optional[int] = None, attempt_limit: int = 10, output_periodic: bool = True, input_periodic: bool = True, loc_heuristic: Literal["lexical", "hilbert", "spiral", "entropy", 
"anti-entropy", "simple", "random"] = "entropy", choice_heuristic: Literal["lexical", "rarest", "weighted", "random"] = "weighted", visualize: bool = False, global_constraint: Literal[False, "allpatterns"] = False, backtracking: bool = False, log_filename: str = "log", logging: bool = False, global_constraints: None = None, log_stats_to_output: Optional[Callable[[Dict[str, Any], str], None]] = None, *, image: Optional[NDArray[np.integer]] = None, ) -> NDArray[np.integer]: timecode = datetime.datetime.now().isoformat().replace(":", ".") time_begin = time.perf_counter() output_destination = r"./output/" input_folder = r"./images/samples/" rotations -= 1 # change to zero-based input_stats = { "filename": str(filename), "tile_size": tile_size, "pattern_width": pattern_width, "rotations": rotations, "output_size": output_size, "ground": ground, "attempt_limit": attempt_limit, "output_periodic": output_periodic, "input_periodic": input_periodic, "location heuristic": loc_heuristic, "choice heuristic": choice_heuristic, "global constraint": global_constraint, "backtracking": backtracking, } # Load the image if filename: if image is not None: raise TypeError("Only filename or image can be provided, not both.") image = imageio.imread(input_folder + filename + ".png")[:, :, :3] # TODO: handle alpha channels if image is None: raise TypeError("An image must be given.") # TODO: generalize this to more than the four cardinal directions direction_offsets = list(enumerate([(0, -1), (1, 0), (0, 1), (-1, 0)])) tile_catalog, tile_grid, _code_list, _unique_tiles = make_tile_catalog(image, tile_size) ( pattern_catalog, pattern_weights, pattern_list, pattern_grid, ) = make_pattern_catalog_with_rotations( tile_grid, pattern_width, input_is_periodic=input_periodic, rotations=rotations ) logger.debug("pattern catalog") # visualize_tiles(unique_tiles, tile_catalog, tile_grid) # visualize_patterns(pattern_catalog, tile_catalog, pattern_weights, pattern_width) # 
figure_list_of_tiles(unique_tiles, tile_catalog, output_filename=f"visualization/tilelist_{filename}_{timecode}") # figure_false_color_tile_grid(tile_grid, output_filename=f"visualization/tile_falsecolor_{filename}_{timecode}") if visualize and filename: figure_pattern_catalog( pattern_catalog, tile_catalog, pattern_weights, pattern_width, output_filename=f"visualization/pattern_catalog_{filename}_{timecode}", ) logger.debug("profiling adjacency relations") if False: import pprofile # type: ignore profiler = pprofile.Profile() with profiler: adjacency_relations = adjacency_extraction( pattern_grid, pattern_catalog, direction_offsets, [pattern_width, pattern_width], ) profiler.dump_stats(f"logs/profile_adj_{filename}_{timecode}.txt") else: adjacency_relations = adjacency_extraction( pattern_grid, pattern_catalog, direction_offsets, (pattern_width, pattern_width), ) logger.debug("adjacency_relations") if visualize: figure_adjacencies( adjacency_relations, direction_offsets, tile_catalog, pattern_catalog, pattern_width, [tile_size, tile_size], output_filename=f"visualization/adjacency_{filename}_{timecode}_A", ) # figure_adjacencies(adjacency_relations, direction_offsets, tile_catalog, pattern_catalog, pattern_width, [tile_size, tile_size], output_filename=f"visualization/adjacency_{filename}_{timecode}_B", render_b_first=True) logger.debug(f"output size: {output_size}\noutput periodic: {output_periodic}") number_of_patterns = len(pattern_weights) logger.debug(f"# patterns: {number_of_patterns}") decode_patterns = dict(enumerate(pattern_list)) encode_patterns = {x: i for i, x in enumerate(pattern_list)} _encode_directions = {j: i for i, j in direction_offsets} adjacency_list: Dict[Tuple[int, int], List[Set[int]]] = {} for _, adjacency in direction_offsets: adjacency_list[adjacency] = [set() for _ in pattern_weights] # logger.debug(adjacency_list) for adjacency, pattern1, pattern2 in adjacency_relations: # logger.debug(adjacency) # 
logger.debug(decode_patterns[pattern1]) adjacency_list[adjacency][encode_patterns[pattern1]].add(encode_patterns[pattern2]) logger.debug(f"adjacency: {len(adjacency_list)}") time_adjacency = time.perf_counter() ### Ground ### ground_list: Optional[NDArray[np.int64]] = None if ground: ground_list = np.vectorize(lambda x: encode_patterns[x])( pattern_grid.flat[(ground - 1) :] ) if ground_list is None or ground_list.size == 0: ground_list = None if ground_list is not None: ground_catalog = { encode_patterns[k]: v for k, v in pattern_catalog.items() if encode_patterns[k] in ground_list } if visualize: figure_pattern_catalog( ground_catalog, tile_catalog, pattern_weights, pattern_width, output_filename=f"visualization/patterns_ground_{filename}_{timecode}", ) wave = makeWave( number_of_patterns, output_size[0], output_size[1], ground=ground_list ) adjacency_matrix = makeAdj(adjacency_list) ### Heuristics ### encoded_weights: NDArray[np.float64] = np.zeros((number_of_patterns), dtype=np.float64) for w_id, w_val in pattern_weights.items(): encoded_weights[encode_patterns[w_id]] = w_val choice_random_weighting: NDArray[np.float64] = np.random.random_sample(wave.shape[1:]) * 0.1 pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int] = lexicalPatternHeuristic if choice_heuristic == "rarest": pattern_heuristic = makeRarestPatternHeuristic(encoded_weights) if choice_heuristic == "weighted": pattern_heuristic = makeWeightedPatternHeuristic(encoded_weights) if choice_heuristic == "random": pattern_heuristic = makeRandomPatternHeuristic(encoded_weights) logger.debug(loc_heuristic) location_heuristic: Callable[[NDArray[np.bool_]], Tuple[int, int]] = lexicalLocationHeuristic if loc_heuristic == "anti-entropy": location_heuristic = makeAntiEntropyLocationHeuristic(choice_random_weighting) if loc_heuristic == "entropy": location_heuristic = makeEntropyLocationHeuristic(choice_random_weighting) if loc_heuristic == "random": location_heuristic = 
makeRandomLocationHeuristic(choice_random_weighting) if loc_heuristic == "simple": location_heuristic = simpleLocationHeuristic if loc_heuristic == "spiral": location_heuristic = makeSpiralLocationHeuristic(choice_random_weighting) if loc_heuristic == "hilbert": location_heuristic = makeHilbertLocationHeuristic(choice_random_weighting) ### Visualization ### ( visualize_choice, visualize_wave, visualize_backtracking, visualize_propagate, visualize_final, visualize_after, ) = (None, None, None, None, None, None) if filename and visualize: ( visualize_choice, visualize_wave, visualize_backtracking, visualize_propagate, visualize_final, visualize_after, ) = make_solver_visualizers( f"{filename}_{timecode}", wave, decode_patterns=decode_patterns, pattern_catalog=pattern_catalog, tile_catalog=tile_catalog, tile_size=[tile_size, tile_size], ) if filename and logging: ( visualize_choice, visualize_wave, visualize_backtracking, visualize_propagate, visualize_final, visualize_after, ) = make_solver_loggers(f"{filename}_{timecode}", input_stats.copy()) if filename and logging and visualize: vis = make_solver_visualizers( f"{filename}_{timecode}", wave, decode_patterns=decode_patterns, pattern_catalog=pattern_catalog, tile_catalog=tile_catalog, tile_size=[tile_size, tile_size], ) log = make_solver_loggers(f"{filename}_{timecode}", input_stats.copy()) def visfunc(idx: int): def vf(*args, **kwargs): if vis[idx]: vis[idx](*args, **kwargs) if log[idx]: return log[idx](*args, **kwargs) return vf ( visualize_choice, visualize_wave, visualize_backtracking, visualize_propagate, visualize_final, visualize_after, ) = [visfunc(x) for x in range(len(vis))] ### Global Constraints ### active_global_constraint = lambda wave: True if global_constraint == "allpatterns": active_global_constraint = make_global_use_all_patterns() logger.debug(active_global_constraint) combined_constraints = [active_global_constraint] def combinedConstraints(wave: NDArray[np.bool_]) -> bool: return all(fn(wave) 
for fn in combined_constraints) ### Solving ### time_solve_start = None time_solve_end = None solution_tile_grid = None logger.debug("solving...") attempts = 0 while attempts < attempt_limit: attempts += 1 time_solve_start = time.perf_counter() stats = {} # profiler = pprofile.Profile() # with profiler: # with PyCallGraph(output=GraphvizOutput(output_file=f"visualization/pycallgraph_{filename}_{timecode}.png")): try: solution = run( wave.copy(), adjacency_matrix, locationHeuristic=location_heuristic, patternHeuristic=pattern_heuristic, periodic=output_periodic, backtracking=backtracking, onChoice=visualize_choice, onBacktrack=visualize_backtracking, onObserve=visualize_wave, onPropagate=visualize_propagate, onFinal=visualize_final, checkFeasible=combinedConstraints, ) if visualize_after: stats = visualize_after() # logger.debug(solution) # logger.debug(stats) solution_as_ids = np.vectorize(lambda x: decode_patterns[x])(solution) solution_tile_grid = pattern_grid_to_tiles( solution_as_ids, pattern_catalog ) logger.debug("Solution:") # logger.debug(solution_tile_grid) if filename: render_tiles_to_output( solution_tile_grid, tile_catalog, (tile_size, tile_size), output_destination + filename + "_" + timecode + ".png", ) time_solve_end = time.perf_counter() stats.update({"outcome": "success"}) except StopEarly: logger.debug("Skipping...") stats.update({"outcome": "skipped"}) raise except TimedOut: logger.debug("Timed Out") if visualize_after: stats = visualize_after() stats.update({"outcome": "timed_out"}) except Contradiction as exc: logger.warning(f"Contradiction: {exc}") if visualize_after: stats = visualize_after() stats.update({"outcome": "contradiction"}) finally: # profiler.dump_stats(f"logs/profile_{filename}_{timecode}.txt") outstats = {} outstats.update(input_stats) solve_duration = time.perf_counter() - time_solve_start if time_solve_end is not None: solve_duration = time_solve_end - time_solve_start adjacency_duration = time_solve_start - time_adjacency 
outstats.update( { "attempts": attempts, "time_start": time_begin, "time_adjacency": time_adjacency, "adjacency_duration": adjacency_duration, "time solve start": time_solve_start, "time solve end": time_solve_end, "solve duration": solve_duration, "pattern count": number_of_patterns, } ) outstats.update(stats) if log_stats_to_output is not None: log_stats_to_output(outstats, output_destination + log_filename + ".tsv") if solution_tile_grid is not None: return tile_grid_to_image(solution_tile_grid, tile_catalog, (tile_size, tile_size)) raise TimedOut("Attempt limit exceeded.") ================================================ FILE: wfc/wfc_patterns.py ================================================ "Extract patterns from grids of tiles." from __future__ import annotations import logging from typing import Any, Dict, Mapping, Optional, Tuple from .wfc_utilities import hash_downto from collections import Counter import numpy as np from numpy.typing import NDArray logger = logging.getLogger(__name__) def unique_patterns_2d(agrid: NDArray[np.int64], ksize: int, periodic_input: bool) -> Tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]: assert ksize >= 1 if periodic_input: agrid = np.pad( agrid, ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))), mode="wrap", ) else: # TODO: implement non-wrapped image handling # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None) agrid = np.pad( agrid, ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))), mode="wrap", ) patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided( agrid, ( agrid.shape[0] - ksize + 1, agrid.shape[1] - ksize + 1, ksize, ksize, *agrid.shape[2:], ), agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:], writeable=False, ) patch_codes = hash_downto(patches, 2) uc, ui = np.unique(patch_codes, return_index=True) locs = np.unravel_index(ui, patch_codes.shape) up: NDArray[np.int64] = patches[locs[0], 
locs[1]] ids: NDArray[np.int64] = np.vectorize({code: ind for ind, code in enumerate(uc)}.get)(patch_codes) return ids, up, patch_codes def unique_patterns_brute_force(grid, size, periodic_input): padded_grid = np.pad( grid, ((0, size - 1), (0, size - 1), *(((0, 0),) * (len(grid.shape) - 2))), mode="wrap", ) patches = [] for x in range(grid.shape[0]): row_patches = [] for y in range(grid.shape[1]): row_patches.append( np.ndarray.tolist(padded_grid[x : x + size, y : y + size]) ) patches.append(row_patches) patches = np.array(patches) patch_codes = hash_downto(patches, 2) uc, ui = np.unique(patch_codes, return_index=True) locs = np.unravel_index(ui, patch_codes.shape) up = patches[locs[0], locs[1]] ids = np.vectorize({c: i for i, c in enumerate(uc)}.get)(patch_codes) return ids, up def make_pattern_catalog( tile_grid: NDArray[np.int64], pattern_width: int, input_is_periodic: bool = True ) -> Tuple[Dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]: """Returns a pattern catalog (dictionary of pattern hashes to consituent tiles), an ordered list of pattern weights, and an ordered list of pattern contents.""" _patterns_in_grid, pattern_contents_list, patch_codes = unique_patterns_2d( tile_grid, pattern_width, input_is_periodic ) dict_of_pattern_contents: Dict[int, NDArray[np.int64]] = {} for pat_idx in range(pattern_contents_list.shape[0]): p_hash = hash_downto(pattern_contents_list[pat_idx], 0) dict_of_pattern_contents.update( {p_hash.item(): pattern_contents_list[pat_idx]} ) pattern_frequency = Counter(hash_downto(pattern_contents_list, 1)) return ( dict_of_pattern_contents, pattern_frequency, hash_downto(pattern_contents_list, 1), patch_codes, ) def identity_grid(grid): """Do nothing to the grid""" # return np.array([[7,5,5,5],[5,0,0,0],[5,0,1,0],[5,0,0,0]]) return grid def reflect_grid(grid): """Reflect the grid left/right""" return np.fliplr(grid) def rotate_grid(grid): """Rotate the grid""" return np.rot90(grid, axes=(1, 0)) def 
make_pattern_catalog_with_rotations( tile_grid: NDArray[np.int64], pattern_width: int, rotations: int = 7, input_is_periodic: bool = True ) -> Tuple[Dict[int, NDArray[np.int64]], Counter, NDArray[np.int64], NDArray[np.int64]]: rotated_tile_grid = tile_grid.copy() merged_dict_of_pattern_contents: Dict[int, NDArray[np.int64]] = {} merged_pattern_frequency: Counter = Counter() merged_pattern_contents_list: Optional[NDArray[np.int64]] = None merged_patch_codes: Optional[NDArray[np.int64]] = None def _make_catalog() -> None: nonlocal rotated_tile_grid, merged_dict_of_pattern_contents, merged_pattern_contents_list, merged_pattern_frequency, merged_patch_codes ( dict_of_pattern_contents, pattern_frequency, pattern_contents_list, patch_codes, ) = make_pattern_catalog(rotated_tile_grid, pattern_width, input_is_periodic) merged_dict_of_pattern_contents.update(dict_of_pattern_contents) merged_pattern_frequency.update(pattern_frequency) if merged_pattern_contents_list is None: merged_pattern_contents_list = pattern_contents_list.copy() else: merged_pattern_contents_list = np.unique( np.concatenate((merged_pattern_contents_list, pattern_contents_list)) ) if merged_patch_codes is None: merged_patch_codes = patch_codes.copy() counter = 0 grid_ops = [ identity_grid, reflect_grid, rotate_grid, reflect_grid, rotate_grid, reflect_grid, rotate_grid, reflect_grid, ] while counter <= (rotations): # logger.debug(rotated_tile_grid.shape) # logger.debug(np.array_equiv(reflect_grid(rotated_tile_grid.copy()), rotate_grid(rotated_tile_grid.copy()))) # logger.debug(counter) # logger.debug(grid_ops[counter].__name__) rotated_tile_grid = grid_ops[counter](rotated_tile_grid.copy()) # logger.debug(rotated_tile_grid) # logger.debug("---") _make_catalog() counter += 1 # assert False assert merged_pattern_contents_list is not None assert merged_patch_codes is not None return ( merged_dict_of_pattern_contents, merged_pattern_frequency, merged_pattern_contents_list, merged_patch_codes, ) def 
pattern_grid_to_tiles( pattern_grid: NDArray[np.int64], pattern_catalog: Mapping[int, NDArray[np.int64]] ) -> NDArray[np.int64]: anchor_x = 0 anchor_y = 0 def pattern_to_tile(pattern: int) -> Any: # if isinstance(pattern, list): # ptrns = [] # for p in pattern: # logger.debug(p) # ptrns.push(pattern_to_tile(p)) # logger.debug(ptrns) # assert False # return ptrns return pattern_catalog[pattern][anchor_x][anchor_y] return np.vectorize(pattern_to_tile)(pattern_grid) ================================================ FILE: wfc/wfc_solver.py ================================================ from __future__ import annotations import logging from typing import Any, Callable, Collection, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple, TypeVar from scipy import sparse # type: ignore import numpy import numpy as np import sys import math import itertools from numpy.typing import NBitBase, NDArray from hilbertcurve.hilbertcurve import HilbertCurve # type: ignore logger = logging.getLogger(__name__) T = TypeVar("T", bound=NBitBase) class Contradiction(Exception): """Solving could not proceed without backtracking/restarting.""" pass class TimedOut(Exception): """Solve timed out.""" pass class StopEarly(Exception): """Aborting solve early.""" pass class Solver: """WFC Solver which can hold wave and backtracking state.""" def __init__( self, *, wave: NDArray[np.bool_], adj: Mapping[Tuple[int, int], NDArray[numpy.bool_]], periodic: bool = False, backtracking: bool = False, on_backtrack: Optional[Callable[[], None]] = None, on_choice: Optional[Callable[[int, int, int], None]] = None, on_observe: Optional[Callable[[NDArray[numpy.bool_]], None]] = None, on_propagate: Optional[Callable[[NDArray[numpy.bool_]], None]] = None, check_feasible: Optional[Callable[[NDArray[numpy.bool_]], bool]] = None ) -> None: self.wave = wave self.adj = adj self.periodic = periodic self.backtracking = backtracking self.history: List[NDArray[np.bool_]] = [] # An undo history for backtracking. 
self.on_backtrack = on_backtrack self.on_choice = on_choice self.on_observe = on_observe self.on_propagate = on_propagate self.check_feasible = check_feasible @property def is_solved(self) -> bool: """Is True if the wave has been fully resolved.""" return self.wave.sum() == self.wave.shape[1] * self.wave.shape[2] and (self.wave.sum(axis=0) == 1).all() def solve_next( self, location_heuristic: Callable[[NDArray[numpy.bool_]], Tuple[int, int]], pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], ) -> bool: """Attempt to collapse one wave. Returns True if no more steps remain.""" if self.is_solved: return True if self.check_feasible and not self.check_feasible(self.wave): raise Contradiction("Not feasible.") if self.backtracking: self.history.append(self.wave.copy()) propagate(self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate) try: pattern, i, j = observe(self.wave, location_heuristic, pattern_heuristic) if self.on_choice: self.on_choice(pattern, i, j) self.wave[:, i, j] = False self.wave[pattern, i, j] = True if self.on_observe: self.on_observe(self.wave) propagate(self.wave, self.adj, periodic=self.periodic, onPropagate=self.on_propagate) return False # Assume there is remaining steps, if not then the next call will return True. 
except Contradiction: if not self.backtracking: raise if not self.history: raise Contradiction("Every permutation has been attempted.") if self.on_backtrack: self.on_backtrack() self.wave = self.history.pop() self.wave[pattern, i, j] = False return False def solve( self, location_heuristic: Callable[[NDArray[numpy.bool_]], Tuple[int, int]], pattern_heuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], ) -> NDArray[np.int64]: """Attempts to solve all waves and returns the solution.""" while not self.solve_next(location_heuristic=location_heuristic, pattern_heuristic=pattern_heuristic): pass return numpy.argmax(self.wave, axis=0) def makeWave(n: int, w: int, h: int, ground: Optional[Iterable[int]] = None) -> NDArray[numpy.bool_]: wave: NDArray[numpy.bool_] = numpy.ones((n, w, h), dtype=numpy.bool_) if ground is not None: wave[:, :, h - 1] = False for g in ground: wave[g, :,] = False wave[g, :, h - 1] = True # logger.debug(wave) # for i in range(wave.shape[0]): # logger.debug(wave[i]) return wave def makeAdj( adjLists: Mapping[Tuple[int, int], Collection[Iterable[int]]] ) -> Dict[Tuple[int, int], NDArray[numpy.bool_]]: adjMatrices = {} # logger.debug(adjLists) num_patterns = len(list(adjLists.values())[0]) for d in adjLists: m = numpy.zeros((num_patterns, num_patterns), dtype=bool) for i, js in enumerate(adjLists[d]): # logger.debug(js) for j in js: m[i, j] = 1 adjMatrices[d] = sparse.csr_matrix(m) return adjMatrices ###################################### # Location Heuristics def makeRandomLocationHeuristic(preferences: NDArray[np.floating[Any]]) -> Callable[[NDArray[np.bool_]], Tuple[int, int]]: def randomLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where(unresolved_cell_mask, preferences, numpy.inf) row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) return row.item(), col.item() return randomLocationHeuristic def 
makeEntropyLocationHeuristic(preferences: NDArray[np.floating[Any]]) -> Callable[[NDArray[np.bool_]], Tuple[int, int]]: def entropyLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where( unresolved_cell_mask, preferences + numpy.count_nonzero(wave, axis=0), numpy.inf, ) row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) return row.item(), col.item() return entropyLocationHeuristic def makeAntiEntropyLocationHeuristic( preferences: NDArray[np.floating[Any]] ) -> Callable[[NDArray[np.bool_]], Tuple[int, int]]: def antiEntropyLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where( unresolved_cell_mask, preferences + numpy.count_nonzero(wave, axis=0), -numpy.inf, ) row, col = numpy.unravel_index(numpy.argmax(cell_weights), cell_weights.shape) return row.item(), col.item() return antiEntropyLocationHeuristic def spiral_transforms() -> Iterator[Tuple[int, int]]: for N in itertools.count(start=1): if N % 2 == 0: yield (0, 1) # right for _ in range(N): yield (1, 0) # down for _ in range(N): yield (0, -1) # left else: yield (0, -1) # left for _ in range(N): yield (-1, 0) # up for _ in range(N): yield (0, 1) # right def spiral_coords(x: int, y: int) -> Iterator[Tuple[int, int]]: yield x, y for transform in spiral_transforms(): x += transform[0] y += transform[1] yield x, y def fill_with_curve(arr: NDArray[np.floating[T]], curve_gen: Iterable[Iterable[int]]) -> NDArray[np.floating[T]]: arr_len = numpy.prod(arr.shape) fill = 0 for coord in curve_gen: # logger.debug(fill, idx, coord) if fill < arr_len: try: arr[tuple(coord)] = fill / arr_len fill += 1 except IndexError: pass else: break # logger.debug(arr) return arr def makeSpiralLocationHeuristic(preferences: NDArray[np.floating[Any]]) -> Callable[[NDArray[np.bool_]], Tuple[int, int]]: # 
https://stackoverflow.com/a/23707273/5562922 spiral_gen = ( sc for sc in spiral_coords(preferences.shape[0] // 2, preferences.shape[1] // 2) ) cell_order = fill_with_curve(preferences, spiral_gen) def spiralLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf) row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) return row.item(), col.item() return spiralLocationHeuristic def makeHilbertLocationHeuristic(preferences: NDArray[np.floating[Any]]) -> Callable[[NDArray[np.bool_]], Tuple[int, int]]: curve_size = math.ceil(math.sqrt(max(preferences.shape[0], preferences.shape[1]))) logger.debug(curve_size) curve_size = 4 h_curve = HilbertCurve(curve_size, 2) h_coords = (h_curve.point_from_distance(i) for i in itertools.count()) cell_order = fill_with_curve(preferences, h_coords) # logger.debug(cell_order) def hilbertLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where(unresolved_cell_mask, cell_order, numpy.inf) row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) return row.item(), col.item() return hilbertLocationHeuristic def simpleLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where( unresolved_cell_mask, numpy.count_nonzero(wave, axis=0), numpy.inf ) row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) return row.item(), col.item() def lexicalLocationHeuristic(wave: NDArray[np.bool_]) -> Tuple[int, int]: unresolved_cell_mask = numpy.count_nonzero(wave, axis=0) > 1 cell_weights = numpy.where(unresolved_cell_mask, 1.0, numpy.inf) row, col = numpy.unravel_index(numpy.argmin(cell_weights), cell_weights.shape) return row.item(), col.item() 
##################################### # Pattern Heuristics def lexicalPatternHeuristic(weights: NDArray[np.bool_], wave: NDArray[np.bool_]) -> int: return numpy.nonzero(weights)[0][0].item() def makeWeightedPatternHeuristic(weights: NDArray[np.floating[Any]]): num_of_patterns = len(weights) def weightedPatternHeuristic(wave: NDArray[np.bool_], _: NDArray[np.bool_]) -> int: # TODO: there's maybe a faster, more controlled way to do this sampling... weighted_wave: NDArray[np.floating[Any]] = weights * wave weighted_wave /= weighted_wave.sum() result = numpy.random.choice(num_of_patterns, p=weighted_wave) return result return weightedPatternHeuristic def makeRarestPatternHeuristic(weights: NDArray[np.floating[Any]]) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]: """Return a function that chooses the rarest (currently least-used) pattern.""" def weightedPatternHeuristic(wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]) -> int: logger.debug(total_wave.shape) # [logger.debug(e) for e in wave] wave_sums = numpy.sum(total_wave, (1, 2)) # logger.debug(wave_sums) selected_pattern = numpy.random.choice( numpy.where(wave_sums == wave_sums.max())[0] ) return selected_pattern return weightedPatternHeuristic def makeMostCommonPatternHeuristic( weights: NDArray[np.floating[Any]] ) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]: """Return a function that chooses the most common (currently most-used) pattern.""" def weightedPatternHeuristic(wave: NDArray[np.bool_], total_wave: NDArray[np.bool_]) -> int: logger.debug(total_wave.shape) # [logger.debug(e) for e in wave] wave_sums = numpy.sum(total_wave, (1, 2)) selected_pattern = numpy.random.choice( numpy.where(wave_sums == wave_sums.min())[0] ) return selected_pattern return weightedPatternHeuristic def makeRandomPatternHeuristic(weights: NDArray[np.floating[Any]]) -> Callable[[NDArray[np.bool_], NDArray[np.bool_]], int]: num_of_patterns = len(weights) def randomPatternHeuristic(wave: NDArray[np.bool_], _: 
NDArray[np.bool_]) -> int: # TODO: there's maybe a faster, more controlled way to do this sampling... weighted_wave = 1.0 * wave weighted_wave /= weighted_wave.sum() result = numpy.random.choice(num_of_patterns, p=weighted_wave) return result return randomPatternHeuristic ###################################### # Global Constraints def make_global_use_all_patterns() -> Callable[[NDArray[np.bool_]], bool]: def global_use_all_patterns(wave: NDArray[np.bool_]) -> bool: """Returns true if at least one instance of each pattern is still possible.""" return numpy.all(numpy.any(wave, axis=(1, 2))).item() return global_use_all_patterns ##################################### # Solver def propagate( wave: NDArray[np.bool_], adj: Mapping[Tuple[int, int], NDArray[numpy.bool_]], periodic: bool = False, onPropagate: Optional[Callable[[NDArray[numpy.bool_]], None]] = None, ) -> None: """Completely probagate any newly collapsed waves to all areas.""" last_count = wave.sum() while True: supports = {} if periodic: padded = numpy.pad(wave, ((0, 0), (1, 1), (1, 1)), mode="wrap") else: padded = numpy.pad( wave, ((0, 0), (1, 1), (1, 1)), mode="constant", constant_values=True ) # adj is the list of adjacencies. For each direction d in adjacency, # check which patterns are still valid... for d in adj: dx, dy = d # padded[] is a version of the adjacency matrix with the values wrapped around # shifted[] is the padded version with the values shifted over in one direction # because my code stores the directions as relative (x,y) coordinates, we can find # the adjacent cell for each direction by simply shifting the matrix in that direction, # which allows for arbitrary adjacency directions. This is somewhat excessive, but elegant. 
shifted = padded[ :, 1 + dx : 1 + wave.shape[1] + dx, 1 + dy : 1 + wave.shape[2] + dy ] # logger.debug(f"shifted: {shifted.shape} | adj[d]: {adj[d].shape} | d: {d}") # raise StopEarly # supports[d] = numpy.einsum('pwh,pq->qwh', shifted, adj[d]) > 0 # The adjacency matrix is a boolean matrix, indexed by the direction and the two patterns. # If the value for (direction, pattern1, pattern2) is True, then this is a valid adjacency. # This gives us a rapid way to compare: True is 1, False is 0, so multiplying the matrices # gives us the adjacency compatibility. supports[d] = (adj[d] @ shifted.reshape(shifted.shape[0], -1)).reshape( shifted.shape ) > 0 # supports[d] = ( <- for each cell in the matrix # adj[d] <- the adjacency matrix [sliced by the direction d] # @ <- Matrix multiplication # shifted.reshape(shifted.shape[0], -1)) <- change the shape of the shifted matrix to 2-dimensions, to make the matrix multiplication easier # .reshape( <- reshape our matrix-multiplied result... # shifted.shape) <- ...to match the original shape of the shifted matrix # > 0 <- is not false # multiply the wave matrix by the support matrix to find which patterns are still in the domain for d in adj: wave *= supports[d] if wave.sum() == last_count: break # No changes since the last loop, changed waves have been fully propagated. 
last_count = wave.sum() if onPropagate: onPropagate(wave) if (wave.sum(axis=0) == 0).any(): raise Contradiction("Wave is in a contradictory state and can not be solved.") def observe( wave: NDArray[np.bool_], locationHeuristic: Callable[[NDArray[np.bool_]], Tuple[int, int]], patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], ) -> Tuple[int, int, int]: """Return the next best wave to collapse based on the provided heuristics.""" i, j = locationHeuristic(wave) pattern = patternHeuristic(wave[:, i, j], wave) return pattern, i, j def run( wave: NDArray[np.bool_], adj: Mapping[Tuple[int, int], NDArray[numpy.bool_]], locationHeuristic: Callable[[NDArray[numpy.bool_]], Tuple[int, int]], patternHeuristic: Callable[[NDArray[np.bool_], NDArray[np.bool_]], int], periodic: bool = False, backtracking: bool = False, onBacktrack: Optional[Callable[[], None]] = None, onChoice: Optional[Callable[[int, int, int], None]] = None, onObserve: Optional[Callable[[NDArray[numpy.bool_]], None]] = None, onPropagate: Optional[Callable[[NDArray[numpy.bool_]], None]] = None, checkFeasible: Optional[Callable[[NDArray[numpy.bool_]], bool]] = None, onFinal: Optional[Callable[[NDArray[numpy.bool_]], None]] = None, depth: int = 0, depth_limit: Optional[int] = None, ) -> NDArray[numpy.int64]: solver = Solver( wave=wave, adj=adj, periodic=periodic, backtracking=backtracking, on_backtrack=onBacktrack, on_choice=onChoice, on_observe=onObserve, on_propagate=onPropagate, check_feasible=checkFeasible ) while not solver.solve_next(location_heuristic=locationHeuristic, pattern_heuristic=patternHeuristic): pass if onFinal: onFinal(solver.wave) return numpy.argmax(solver.wave, axis=0) ================================================ FILE: wfc/wfc_tiles.py ================================================ """Breaks an image into consituant tiles.""" from __future__ import annotations from typing import Dict, Tuple import numpy as np from numpy.typing import NDArray from .wfc_utilities 
import hash_downto def image_to_tiles(img: NDArray[np.integer], tile_size: int) -> NDArray[np.integer]: """ Takes an images, divides it into tiles, return an array of tiles. """ padding_argument = [(0, 0), (0, 0), (0, 0)] for input_dim in [0, 1]: padding_argument[input_dim] = ( 0, (tile_size - img.shape[input_dim]) % tile_size, ) img = np.pad(img, padding_argument, mode="constant") tiles = img.reshape( ( img.shape[0] // tile_size, tile_size, img.shape[1] // tile_size, tile_size, img.shape[2], ) ).swapaxes(1, 2) return tiles def make_tile_catalog( image_data: NDArray[np.integer], tile_size: int ) -> Tuple[Dict[int, NDArray[np.integer]], NDArray[np.int64], NDArray[np.int64], Tuple[NDArray[np.int64], NDArray[np.int64]]]: """ Takes an image and tile size and returns the following: tile_catalog is a dictionary tiles, with the hashed ID as the key tile_grid is the original image, expressed in terms of hashed tile IDs code_list is the original image, expressed in terms of hashed tile IDs and reduced to one dimension unique_tiles is the set of tiles, plus the frequency of their occurrence """ channels = image_data.shape[2] # Number of color channels in the image tiles = image_to_tiles(image_data, tile_size) tile_list: NDArray[np.integer] = tiles.reshape((tiles.shape[0] * tiles.shape[1], tile_size, tile_size, channels)) code_list: NDArray[np.int64] = hash_downto(tiles, 2).reshape((tiles.shape[0] * tiles.shape[1])) tile_grid: NDArray[np.int64] = hash_downto(tiles, 2) unique_tiles: Tuple[NDArray[np.int64], NDArray[np.int64]] = np.unique(tile_grid, return_counts=True) tile_catalog: Dict[int, NDArray[np.integer]] = {} for i, j in enumerate(code_list): tile_catalog[j] = tile_list[i] return tile_catalog, tile_grid, code_list, unique_tiles def tiles_to_images(tile_grid, tile_catalog): return ================================================ FILE: wfc/wfc_utilities.py ================================================ """Utility data and functions for WFC""" from __future__ import 
annotations import collections import logging from typing import Any import numpy as np from numpy.typing import NDArray logger = logging.getLogger(__name__) CoordXY = collections.namedtuple("CoordXY", ["x", "y"]) CoordRC = collections.namedtuple("CoordRC", ["row", "column"]) def hash_downto(a: NDArray[np.integer], rank: int, seed: Any=0) -> NDArray[np.int64]: state = np.random.RandomState(seed) assert rank < len(a.shape) # logger.debug((np.prod(a.shape[:rank]),-1)) # logger.debug(np.array([np.prod(a.shape[:rank]),-1], dtype=np.int64).dtype) u: NDArray[np.integer] = a.reshape((np.prod(a.shape[:rank], dtype=np.int64), -1)) # u = a.reshape((np.prod(a.shape[:rank]),-1)) v = state.randint(1 - (1 << 63), 1 << 63, np.prod(a.shape[rank:]), dtype=np.int64) return np.asarray(np.inner(u, v).reshape(a.shape[:rank]), dtype=np.int64) try: import google.colab # type: ignore IN_COLAB = True except: IN_COLAB = False def load_visualizer(wfc_ns): if IN_COLAB: from google.colab import files # type: ignore uploaded = files.upload() for fn in uploaded.keys(): logger.debug( 'User uploaded file "{name}" with length {length} bytes'.format( name=fn, length=len(uploaded[fn]) ) ) else: import matplotlib # type: ignore import matplotlib.pylab # type: ignore from matplotlib.pyplot import figure, subplot, title, matshow # type: ignore wfc_ns.img_filename = f"images/{wfc_ns.img_filename}" return wfc_ns def find_pattern_center(wfc_ns): # wfc_ns.pattern_center = (math.floor((wfc_ns.pattern_width - 1) / 2), math.floor((wfc_ns.pattern_width - 1) / 2)) wfc_ns.pattern_center = (0, 0) return wfc_ns ================================================ FILE: wfc/wfc_visualize.py ================================================ "Visualize the patterns into tiles and so on." 
from __future__ import annotations import logging import math import pathlib import itertools from typing import Dict, Tuple import imageio # type: ignore import matplotlib # type: ignore import struct import matplotlib.pyplot as plt # type: ignore import numpy as np from numpy.typing import NDArray from .wfc_patterns import pattern_grid_to_tiles logger = logging.getLogger(__name__) ## Helper functions RGB_CHANNELS = 3 def rgb_to_int(rgb_in): """"Takes RGB triple, returns integer representation.""" return struct.unpack( "I", struct.pack("<" + "B" * 4, *(rgb_in + [0] * (4 - len(rgb_in)))) )[0] def int_to_rgb(val): """Convert hashed int to RGB values""" return [x for x in val.to_bytes(RGB_CHANNELS, "little")] WFC_PARTIAL_BLANK = np.nan def tile_to_image(tile, tile_catalog, tile_size, visualize=False): """ Takes a single tile and returns the pixel image representation. """ new_img = np.zeros((tile_size[0], tile_size[1], 3), dtype=np.int64) for u in range(tile_size[0]): for v in range(tile_size[1]): ## If we want to display a partial pattern, it is helpful to ## be able to show empty cells. Therefore, in visualize mode, ## we use -1 as a magic number for a non-existant tile. pixel = [200, 0, 200] if (visualize) and ((-1 == tile) or (WFC_PARTIAL_BLANK == tile)): if 0 == (u + v) % 2: pixel = [255, 0, 255] else: if (visualize) and -2 == tile: pixel = [0, 255, 255] else: pixel = tile_catalog[tile][u, v] new_img[u, v] = pixel return new_img def argmax_unique(arr, axis): """Return a mask so that we can exclude the nonunique maximums, i.e. 
the nodes that aren't completely resolved""" arrm = np.argmax(arr, axis) arrs = np.sum(arr, axis) nonunique_mask = np.ma.make_mask((arrs == 1) is False) uni_argmax = np.ma.masked_array(arrm, mask=nonunique_mask, fill_value=-1) return uni_argmax, nonunique_mask def make_solver_loggers(filename, stats={}): counter_choices = 0 counter_wave = 0 counter_backtracks = 0 counter_propagate = 0 def choice_count(pattern, i, j, wave=None): nonlocal counter_choices counter_choices += 1 def wave_count(wave): nonlocal counter_wave counter_wave += 1 def backtrack_count() -> None: nonlocal counter_backtracks counter_backtracks += 1 def propagate_count(wave): nonlocal counter_propagate counter_propagate += 1 def final_count(wave): logger.info( f"{filename}: choices: {counter_choices}, wave:{counter_wave}, backtracks: {counter_backtracks}, propagations: {counter_propagate}" ) stats.update( { "choices": counter_choices, "wave": counter_wave, "backtracks": counter_backtracks, "propagations": counter_propagate, } ) return stats def report_count(): stats.update( { "choices": counter_choices, "wave": counter_wave, "backtracks": counter_backtracks, "propagations": counter_propagate, } ) return stats return ( choice_count, wave_count, backtrack_count, propagate_count, final_count, report_count, ) def make_solver_visualizers( filename: str, wave: NDArray[np.bool_], decode_patterns=None, pattern_catalog=None, tile_catalog=None, tile_size=[1, 1], ): """Construct visualizers for displaying the intermediate solver status""" logger.debug(wave.shape) pattern_total_count = wave.shape[0] resolution_order = np.full( wave.shape[1:], np.nan ) # pattern_wave = when was this resolved? backtracking_order = np.full( wave.shape[1:], np.nan ) # on which iternation was this resolved? pattern_solution = np.full(wave.shape[1:], np.nan) # what is the resolved result? resolution_method = np.zeros( wave.shape[1:] ) # did we set this via observation or propagation? 
choice_count = 0 vis_count = 0 backtracking_count = 0 max_choices = math.floor((wave.shape[1] * wave.shape[2]) / 3) output_individual_visualizations = False tile_wave = np.zeros(wave.shape, dtype=np.int64) for i in range(wave.shape[0]): local_solution_as_ids = np.full(wave.shape[1:], decode_patterns[i]) local_solution_tile_grid = pattern_grid_to_tiles( local_solution_as_ids, pattern_catalog ) tile_wave[i] = local_solution_tile_grid def choice_vis(pattern, i, j, wave=None): nonlocal choice_count nonlocal resolution_order nonlocal resolution_method choice_count += 1 resolution_order[i][j] = choice_count pattern_solution[i][j] = pattern resolution_method[i][j] = 2 if output_individual_visualizations: figure_solver_data( f"visualization/{filename}_choice_{choice_count}.png", "order of resolution", resolution_order, 0, max_choices, "gist_ncar", ) figure_solver_data( f"visualization/{filename}_solution_{choice_count}.png", "chosen pattern", pattern_solution, 0, pattern_total_count, "viridis", ) figure_solver_data( f"visualization/{filename}_resolution_{choice_count}.png", "resolution method", resolution_method, 0, 2, "inferno", ) if wave: _assigned_patterns, nonunique_mask = argmax_unique(wave, 0) resolved_by_propagation = ( np.ma.mask_or(nonunique_mask, resolution_method != 0) == 0 ) resolution_method[resolved_by_propagation] = 1 resolution_order[resolved_by_propagation] = choice_count if output_individual_visualizations: figure_solver_data( f"visualization/{filename}_wave_{choice_count}.png", "patterns remaining", np.count_nonzero(wave > 0, axis=0), 0, wave.shape[0], "plasma", ) def wave_vis(wave): nonlocal vis_count nonlocal resolution_method nonlocal resolution_order vis_count += 1 pattern_left_count = np.count_nonzero(wave > 0, axis=0) # assigned_patterns, nonunique_mask = argmax_unique(wave, 0) resolved_by_propagation = ( np.ma.mask_or(pattern_left_count > 1, resolution_method != 0) != 1 ) # logger.debug(resolved_by_propagation) 
resolution_method[resolved_by_propagation] = 1 resolution_order[resolved_by_propagation] = choice_count backtracking_order[resolved_by_propagation] = backtracking_count if output_individual_visualizations: figure_wave_patterns(filename, pattern_left_count, pattern_total_count) figure_solver_data( f"visualization/{filename}_wave_patterns_{choice_count}.png", "patterns remaining", pattern_left_count, 0, pattern_total_count, "magma", ) if decode_patterns and pattern_catalog and tile_catalog: solution_as_ids = np.vectorize(lambda x: decode_patterns[x])( np.argmax(wave, 0) ) solution_tile_grid = pattern_grid_to_tiles(solution_as_ids, pattern_catalog) if output_individual_visualizations: figure_solver_data( f"visualization/{filename}_tiles_assigned_{choice_count}.png", "tiles assigned", solution_tile_grid, 0, pattern_total_count, "plasma", ) img = tile_grid_to_image(solution_tile_grid.T, tile_catalog, tile_size) masked_tile_wave: np.ma.MaskedArray = np.ma.MaskedArray( data=tile_wave, mask=(wave == False), dtype=np.int64 ) masked_img = tile_grid_to_average( np.transpose(masked_tile_wave, (0, 2, 1)), tile_catalog, tile_size ) if output_individual_visualizations: figure_solver_image( f"visualization/{filename}_solution_partial_{choice_count}.png", "solved_tiles", img.astype(np.uint8), ) imageio.imwrite( f"visualization/{filename}_solution_partial_img_{choice_count}.png", img.astype(np.uint8), ) fig_list = [ # {"title": "resolved by propagation", "data": resolved_by_propagation.T, "vmin": 0, "vmax": 2, "cmap": "inferno", "datatype":"figure"}, { "title": "order of resolution", "data": resolution_order.T, "vmin": 0, "vmax": max_choices / 4, "cmap": "hsv", "datatype": "figure", }, { "title": "chosen pattern", "data": pattern_solution.T, "vmin": 0, "vmax": pattern_total_count, "cmap": "viridis", "datatype": "figure", }, { "title": "resolution method", "data": resolution_method.T, "vmin": 0, "vmax": 2, "cmap": "magma", "datatype": "figure", }, { "title": "patterns remaining", 
"data": pattern_left_count.T, "vmin": 0, "vmax": pattern_total_count, "cmap": "viridis", "datatype": "figure", }, { "title": "tiles assigned", "data": solution_tile_grid.T, "vmin": None, "vmax": None, "cmap": "prism", "datatype": "figure", }, { "title": "solved tiles", "data": masked_img.astype(np.uint8), "datatype": "image", }, ] figure_unified( "Solver Readout", f"visualization/{filename}_readout_{choice_count:03}_{vis_count:03}.png", fig_list, ) def backtrack_vis() -> None: nonlocal vis_count nonlocal pattern_solution nonlocal backtracking_count backtracking_count += 1 vis_count += 1 pattern_solution = np.full(wave.shape[1:], -1) return choice_vis, wave_vis, backtrack_vis, None, wave_vis, None def figure_unified(figure_name_overall, filename, data): matfig, axs = plt.subplots( 1, len(data), sharey="row", gridspec_kw={"hspace": 0, "wspace": 0} ) for idx, _data_obj in enumerate(data): if "image" == data[idx]["datatype"]: axs[idx].imshow(data[idx]["data"], interpolation="nearest") else: axs[idx].matshow( data[idx]["data"], vmin=data[idx]["vmin"], vmax=data[idx]["vmax"], cmap=data[idx]["cmap"], ) axs[idx].get_xaxis().set_visible(False) axs[idx].get_yaxis().set_visible(False) axs[idx].label_outer() plt.savefig(filename, bbox_inches="tight", pad_inches=0, dpi=600) plt.close(fig=matfig) plt.close("all") vis_count = 0 def visualize_solver(wave): pattern_left_count = np.count_nonzero(wave > 0, axis=0) pattern_total_count = wave.shape[0] figure_wave_patterns(pattern_left_count, pattern_total_count) def make_figure_solver_image(plot_title, img): visfig = plt.figure(figsize=(4, 4), edgecolor="k", frameon=True) plt.imshow(img, interpolation="nearest") plt.title(plot_title) plt.grid(None) plt.grid(None) an_ax = plt.gca() an_ax.get_xaxis().set_visible(False) an_ax.get_yaxis().set_visible(False) return visfig def figure_solver_image(filename, plot_title, img): visfig = make_figure_solver_image(plot_title, img) plt.savefig(filename, bbox_inches="tight", pad_inches=0) 
plt.close(fig=visfig) plt.close("all") def make_figure_solver_data(plot_title, data, min_count, max_count, cmap_name): visfig = plt.figure(figsize=(4, 4), edgecolor="k", frameon=True) plt.title(plot_title) plt.matshow(data, vmin=min_count, vmax=max_count, cmap=cmap_name) plt.grid(None) plt.grid(None) ax = plt.gca() ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) return visfig def figure_solver_data(filename, plot_title, data, min_count, max_count, cmap_name): visfig = make_figure_solver_data(plot_title, data, min_count, max_count, cmap_name) plt.savefig(filename, bbox_inches="tight", pad_inches=0) plt.close(fig=visfig) plt.close("all") def figure_wave_patterns(filename, pattern_left_count, max_count): global vis_count vis_count += 1 visfig = plt.figure(figsize=(4, 4), edgecolor="k", frameon=True) plt.title("wave") plt.matshow(pattern_left_count, vmin=0, vmax=max_count, cmap="plasma") plt.grid(None) plt.grid(None) plt.savefig(f"{filename}_wave_patterns_{vis_count}.png") plt.close(fig=visfig) def tile_grid_to_average( tile_grid: np.ma.MaskedArray, tile_catalog: Dict[int, NDArray[np.int64]], tile_size: Tuple[int, int], color_channels: int = 3, ) -> NDArray[np.int64]: """ Takes a masked array of tile grid stacks and transforms it into an image, taking the average colors of the tiles in tile_catalog. 
""" new_img = np.zeros( ( tile_grid.shape[1] * tile_size[0], tile_grid.shape[2] * tile_size[1], color_channels, ), dtype=np.int64, ) for i in range(tile_grid.shape[1]): for j in range(tile_grid.shape[2]): tile_stack = tile_grid[:, i, j] for u in range(tile_size[0]): for v in range(tile_size[1]): pixel = [200, 0, 200] pixel_list = np.array( [ tile_catalog[t][u, v] for t in tile_stack[tile_stack.mask == False] ], dtype=np.int64, ) pixel = np.mean(pixel_list, axis=0) # TODO: will need to change if using an image with more than 3 channels new_img[(i * tile_size[0]) + u, (j * tile_size[1]) + v] = np.resize( pixel, new_img[(i * tile_size[0]) + u, (j * tile_size[1]) + v].shape, ) return new_img def tile_grid_to_image( tile_grid: NDArray[np.int64], tile_catalog: Dict[int, NDArray[np.integer]], tile_size: Tuple[int, int], visualize: bool = False, partial: bool = False, color_channels: int = 3, ) -> NDArray[np.integer]: """ Takes a tile_grid and transforms it into an image, using the information in tile_catalog. We use tile_size to figure out the size the new image should be, and visualize for displaying partial tile patterns. """ tile_dtype = next(iter(tile_catalog.values())).dtype new_img = np.zeros( ( tile_grid.shape[0] * tile_size[0], tile_grid.shape[1] * tile_size[1], color_channels, ), dtype=tile_dtype, ) if partial and (len(tile_grid.shape)) > 2: # TODO: implement rendering partially completed solution # Call tile_grid_to_average() instead. assert False else: for i in range(tile_grid.shape[0]): for j in range(tile_grid.shape[1]): tile = tile_grid[i, j] for u in range(tile_size[0]): for v in range(tile_size[1]): pixel = [200, 0, 200] ## If we want to display a partial pattern, it is helpful to ## be able to show empty cells. Therefore, in visualize mode, ## we use -1 as a magic number for a non-existant tile. 
if visualize and ((-1 == tile) or (-2 == tile)): if -1 == tile: if 0 == (i + j) % 2: pixel = [255, 0, 255] if -2 == tile: pixel = [0, 255, 255] else: pixel = tile_catalog[tile][u, v] # TODO: will need to change if using an image with more than 3 channels new_img[ (i * tile_size[0]) + u, (j * tile_size[1]) + v ] = np.resize( pixel, new_img[ (i * tile_size[0]) + u, (j * tile_size[1]) + v ].shape, ) return new_img def figure_list_of_tiles(unique_tiles, tile_catalog, output_filename="list_of_tiles"): plt.figure(figsize=(4, 4), edgecolor="k", frameon=True) plt.title("Extracted Tiles") s = math.ceil(math.sqrt(len(unique_tiles))) + 1 for i, tcode in enumerate(unique_tiles[0]): sp = plt.subplot(s, s, i + 1).imshow(tile_catalog[tcode]) sp.axes.tick_params(labelleft=False, labelbottom=False, length=0) plt.title(f"{i}\n{tcode}", fontsize=10) sp.axes.grid(False) fp = pathlib.Path(output_filename + ".pdf") plt.savefig(fp, bbox_inches="tight") plt.close() def figure_false_color_tile_grid(tile_grid, output_filename="./false_color_tiles"): figure_plot = plt.matshow( tile_grid, cmap="gist_ncar", extent=(0, tile_grid.shape[1], tile_grid.shape[0], 0), ) plt.title("False Color Map of Tiles in Input Image") figure_plot.axes.grid(None) plt.savefig(output_filename + ".png", bbox_inches="tight") plt.close() def figure_tile_grid(tile_grid, tile_catalog, tile_size): tile_grid_to_image(tile_grid, tile_catalog, tile_size) def render_pattern(render_pattern, tile_catalog): """Turn a pattern into an image""" rp_iter = np.nditer(render_pattern, flags=["multi_index"]) output = np.zeros(render_pattern.shape + (3,), dtype=np.uint32) while not rp_iter.finished: # Note that this truncates images with more than 3 channels down to just the channels in the output. # If we want to have alpha channels, we'll need a different way to handle this. 
output[rp_iter.multi_index] = np.resize( tile_catalog[render_pattern[rp_iter.multi_index]], output[rp_iter.multi_index].shape, ) rp_iter.iternext() return output def figure_pattern_catalog( pattern_catalog, tile_catalog, pattern_weights, pattern_width, output_filename="pattern_catalog", ): s_columns = 24 // min(24, pattern_width) s_rows = 1 + (int(len(pattern_catalog)) // s_columns) _fig = plt.figure(figsize=(s_columns, s_rows * 1.5)) plt.title("Extracted Patterns") counter = 0 for i, _tcode in pattern_catalog.items(): pat_cat = pattern_catalog[i] ptr = render_pattern(pat_cat, tile_catalog).astype(np.uint8) sp = plt.subplot(s_rows, s_columns, counter + 1) spi = sp.imshow(ptr) spi.axes.xaxis.set_label_text(f"({pattern_weights[i]})") sp.set_title(f"{counter}\n{i}", fontsize=3) spi.axes.tick_params( labelleft=False, labelbottom=False, left=False, bottom=False ) spi.axes.grid(False) counter += 1 plt.savefig(output_filename + "_patterns.pdf", bbox_inches="tight") plt.close() def render_tiles_to_output( tile_grid: NDArray[np.int64], tile_catalog: Dict[int, NDArray[np.integer]], tile_size: Tuple[int, int], output_filename: str, ) -> None: img = tile_grid_to_image(tile_grid.T, tile_catalog, tile_size) imageio.imwrite(output_filename, img.astype(np.uint8)) def blit(destination, sprite, upper_left, layer=False, check=False): """ Blits one multidimensional array into another numpy array. 
""" lower_right = [ ((a + b) if ((a + b) < c) else c) for a, b, c in zip(upper_left, sprite.shape, destination.shape) ] if min(lower_right) < 0: return for i_index, i in enumerate(range(upper_left[0], lower_right[0])): for j_index, j in enumerate(range(upper_left[1], lower_right[1])): if (i >= 0) and (j >= 0): if len(destination.shape) > 2: destination[i, j, layer] = sprite[i_index, j_index] else: if check: if ( (destination[i, j] == sprite[i_index, j_index]) or (destination[i, j] == -1) or {sprite[i_index, j_index] == -1} ): destination[i, j] = sprite[i_index, j_index] else: logger.error( "mismatch: destination[{i},{j}] = {destination[i, j]}, sprite[{i_index}, {j_index}] = {sprite[i_index, j_index]}" ) else: destination[i, j] = sprite[i_index, j_index] return destination class InvalidAdjacency(Exception): """The combination of patterns and offsets results in pattern combinations that don't match.""" pass def validate_adjacency( pattern_a, pattern_b, preview_size, upper_left_of_center, adj_rel ): preview_adj_a_first = np.full((preview_size, preview_size), -1, dtype=np.int64) preview_adj_b_first = np.full((preview_size, preview_size), -1, dtype=np.int64) blit( preview_adj_b_first, pattern_b, ( upper_left_of_center[1] + adj_rel[0][1], upper_left_of_center[0] + adj_rel[0][0], ), check=True, ) blit(preview_adj_b_first, pattern_a, upper_left_of_center, check=True) blit(preview_adj_a_first, pattern_a, upper_left_of_center, check=True) blit( preview_adj_a_first, pattern_b, ( upper_left_of_center[1] + adj_rel[0][1], upper_left_of_center[0] + adj_rel[0][0], ), check=True, ) if not np.array_equiv(preview_adj_a_first, preview_adj_b_first): logger.debug(adj_rel) logger.debug(pattern_a) logger.debug(pattern_b) logger.debug(preview_adj_a_first) logger.debug(preview_adj_b_first) raise InvalidAdjacency def figure_adjacencies( adjacency_relations_list, adjacency_directions, tile_catalog, patterns, pattern_width, tile_size, output_filename="adjacency", render_b_first=False, ): # 
try: adjacency_directions_list = list(dict(adjacency_directions).values()) _figadj = plt.figure( figsize=(12, 1 + len(adjacency_relations_list[:64])), edgecolor="b" ) plt.title("Adjacencies") max_offset = max( [abs(x) for x in list(itertools.chain.from_iterable(adjacency_directions_list))] ) for i, adj_rel in enumerate(adjacency_relations_list[:64]): preview_size = pattern_width + max_offset * 2 preview_adj = np.full((preview_size, preview_size), -1, dtype=np.int64) upper_left_of_center = [max_offset, max_offset] pattern_a = patterns[adj_rel[1]] pattern_b = patterns[adj_rel[2]] validate_adjacency( pattern_a, pattern_b, preview_size, upper_left_of_center, adj_rel ) if render_b_first: blit( preview_adj, pattern_b, ( upper_left_of_center[1] + adj_rel[0][1], upper_left_of_center[0] + adj_rel[0][0], ), check=True, ) blit(preview_adj, pattern_a, upper_left_of_center, check=True) else: blit(preview_adj, pattern_a, upper_left_of_center, check=True) blit( preview_adj, pattern_b, ( upper_left_of_center[1] + adj_rel[0][1], upper_left_of_center[0] + adj_rel[0][0], ), check=True, ) ptr = tile_grid_to_image( preview_adj, tile_catalog, tile_size, visualize=True ).astype(np.uint8) subp = plt.subplot(math.ceil(len(adjacency_relations_list[:64]) / 4), 4, i + 1) spi = subp.imshow(ptr) spi.axes.tick_params( left=False, bottom=False, labelleft=False, labelbottom=False ) plt.title( f"{i}:\n({adj_rel[1]} +\n{adj_rel[2]})\n by {adj_rel[0]}", fontsize=10 ) indicator_rect = matplotlib.patches.Rectangle( (upper_left_of_center[1] - 0.51, upper_left_of_center[0] - 0.51), pattern_width, pattern_width, Fill=False, edgecolor="b", linewidth=3.0, linestyle=":", ) spi.axes.add_artist(indicator_rect) spi.axes.grid(False) plt.savefig(output_filename + "_adjacency.pdf", bbox_inches="tight") plt.close() # except ValueError as e: # logger.exception(e) ================================================ FILE: wfc_run.py ================================================ # -*- coding: utf-8 -*- """Base code to 
load commands from xml and run them.""" from __future__ import annotations import argparse import datetime import logging from typing import List, Literal, TypedDict, Union import wfc.wfc_control as wfc_control import xml.etree.ElementTree as ET import os class RunInstructions(TypedDict): loc: Literal["lexical", "hilbert", "spiral", "entropy", "anti-entropy", "simple", "random"] choice: Literal["lexical", "rarest", "weighted", "random"] backtracking: bool global_constraint: Literal[False, "allpatterns"] def string2bool(strn: Union[bool, str]) -> bool: if isinstance(strn, bool): return strn return strn.lower() in ["true"] def run_default(run_experiment: str = "simple", samples: str = "samples_reference.xml") -> None: log_filename = f"log_{datetime.datetime.now().isoformat()}".replace(":", ".") xdoc = ET.ElementTree(file=samples) default_allowed_attempts = 10 default_backtracking = str(False) log_stats_to_output = wfc_control.make_log_stats() for xnode in xdoc.getroot(): name = xnode.get("name", "NAME") if "overlapping" == xnode.tag: # seed = 3262 tile_size = int(xnode.get("tile_size", 1)) # seed for random generation, can be any number tile_size = int(xnode.get("tile_size", 1)) # size of tile, in pixels pattern_width = int(xnode.get("N", 2)) # Size of the patterns we want. # 2x2 is the minimum, larger scales get slower fast. symmetry = int(xnode.get("symmetry", 8)) ground = int(xnode.get("ground", 0)) periodic_input = string2bool( xnode.get("periodic", "False") ) # Does the input wrap? periodic_output = string2bool( xnode.get("periodic", "False") ) # Do we want the output to wrap? generated_size = (int(xnode.get("width", 48)), int(xnode.get("height", 48))) screenshots = int( xnode.get("screenshots", 1) ) # Number of times to run the algorithm, will produce this many distinct outputs iteration_limit = int( xnode.get("iteration_limit", 0) ) # After this many iterations, time out. 0 = never time out. 
allowed_attempts = int( xnode.get("allowed_attempts", default_allowed_attempts) ) # Give up after this many contradictions backtracking = string2bool(xnode.get("backtracking", default_backtracking)) visualize_experiment = False run_instructions: List[RunInstructions] = [ # simple { "loc": "entropy", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, } ] # run_instructions = [{"loc": "entropy", "choice": "weighted", "backtracking": True, "global_constraint": "allpatterns"}] if run_experiment == "choice": run_instructions = [ { "loc": "lexical", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "entropy", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "random", "choice": "weighted", "backtracking": False, "global_constraint": False, }, { "loc": "lexical", "choice": "random", "backtracking": backtracking, "global_constraint": False, }, { "loc": "entropy", "choice": "random", "backtracking": backtracking, "global_constraint": False, }, { "loc": "random", "choice": "random", "backtracking": False, "global_constraint": False, }, { "loc": "lexical", "choice": "weighted", "backtracking": True, "global_constraint": False, }, { "loc": "entropy", "choice": "weighted", "backtracking": True, "global_constraint": False, }, { "loc": "lexical", "choice": "weighted", "backtracking": True, "global_constraint": "allpatterns", }, { "loc": "entropy", "choice": "weighted", "backtracking": True, "global_constraint": "allpatterns", }, { "loc": "lexical", "choice": "weighted", "backtracking": False, "global_constraint": "allpatterns", }, { "loc": "entropy", "choice": "weighted", "backtracking": False, "global_constraint": "allpatterns", }, ] if run_experiment == "heuristic": run_instructions = [ { "loc": "hilbert", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "spiral", "choice": "weighted", "backtracking": backtracking, 
"global_constraint": False, }, { "loc": "entropy", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "anti-entropy", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "lexical", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "simple", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, { "loc": "random", "choice": "weighted", "backtracking": backtracking, "global_constraint": False, }, ] if run_experiment == "backtracking": run_instructions = [ { "loc": "entropy", "choice": "weighted", "backtracking": True, "global_constraint": "allpatterns", }, { "loc": "entropy", "choice": "weighted", "backtracking": False, "global_constraint": "allpatterns", }, { "loc": "entropy", "choice": "weighted", "backtracking": True, "global_constraint": False, }, { "loc": "entropy", "choice": "weighted", "backtracking": False, "global_constraint": False, }, ] if run_experiment == "backtracking_heuristic": run_instructions = [ { "loc": "lexical", "choice": "weighted", "backtracking": True, "global_constraint": "allpatterns", }, { "loc": "lexical", "choice": "weighted", "backtracking": False, "global_constraint": "allpatterns", }, { "loc": "lexical", "choice": "weighted", "backtracking": True, "global_constraint": False, }, { "loc": "lexical", "choice": "weighted", "backtracking": False, "global_constraint": False, }, { "loc": "random", "choice": "weighted", "backtracking": True, "global_constraint": "allpatterns", }, { "loc": "random", "choice": "weighted", "backtracking": False, "global_constraint": "allpatterns", }, { "loc": "random", "choice": "weighted", "backtracking": True, "global_constraint": False, }, { "loc": "random", "choice": "weighted", "backtracking": False, "global_constraint": False, }, ] if run_experiment == "choices": run_instructions = [ { "loc": "entropy", "choice": "rarest", "backtracking": False, 
"global_constraint": False, }, { "loc": "entropy", "choice": "weighted", "backtracking": False, "global_constraint": False, }, { "loc": "entropy", "choice": "random", "backtracking": False, "global_constraint": False, }, ] for experiment in run_instructions: for x in range(screenshots): print(f"-: {name} > {x}") try: solution = wfc_control.execute_wfc( name, tile_size=tile_size, pattern_width=pattern_width, rotations=symmetry, output_size=generated_size, ground=ground, attempt_limit=allowed_attempts, output_periodic=periodic_output, input_periodic=periodic_input, loc_heuristic=experiment["loc"], choice_heuristic=experiment["choice"], backtracking=experiment["backtracking"], global_constraint=experiment["global_constraint"], log_filename=log_filename, log_stats_to_output=log_stats_to_output, visualize=visualize_experiment, logging=True, ) print(solution) except Exception as exc: print(f"Skipped because: {exc}") if False: # These are included for my colab experiments, remove them if you're not me os.system( 'cp -rf "/content/wfc/output/*.tsv" "/content/drive/My Drive/wfc_exper/2"' ) os.system( 'cp -r "/content/wfc/output" "/content/drive/My Drive/wfc_exper/2"' ) def main() -> None: logging.basicConfig(level=logging.DEBUG) parser = argparse.ArgumentParser( description="Geneates examples from bundled samples which will be saved to the output/ directory.", ) parser.add_argument( "-e", "--experiment", type=str, default="simple", choices=["simple", "choice", "choices", "heuristic", "backtracking", "backtracking_heuristic"], help="Which experiment to run, defaults to simple.", ) parser.add_argument( "-s", "--samples", type=str, required=True, metavar="XML_FILE", default="samples_reference.xml", help="An XML file with input data. If unsure then use '-s samples_reference.xml'", ) args = parser.parse_args() run_default(run_experiment=args.experiment, samples=args.samples) if __name__ == "__main__": main()