[
  {
    "path": "README.md",
    "content": "# 2D Wave Simulation on the GPU\n\nThis repository contains a lightweight 2D wave simulator that runs on the GPU using the CuPy library (this most likely requires an NVIDIA GPU). It can be used for 2D light and sound simulations.\nA simple visualizer shows the field and its intensity on the screen and can write a movie file of each to disk. The goal is to provide a fast, easy-to-use, yet flexible wave simulator.\n\n<div style=\"display: flex;\">\n    <img src=\"images/simulation_1.jpg\" alt=\"Example Image 1\" width=\"49%\">\n    <img src=\"images/simulation_2.jpg\" alt=\"Example Image 2\" width=\"49%\">\n</div>\n\n### Update 06.04.2025\n\n* Scene objects can now draw to a visualization layer (most of them do not do so yet; contributions are welcome)!\n* Example 4 now shows a two-mirror optical cavity and how standing waves emerge.\n* Added new line sources (LineSource)\n* Added a refractive index polygon object (StaticRefractiveIndexPolygon)\n* Added a refractive index box object (StaticRefractiveIndexBox)\n* Fixed some issues with the examples\n\n<div style=\"display: flex;\">\n    <img src=\"images/optical_cavity.jpg\" alt=\"Example 4 - Optical Cavity with Standing Waves\" width=\"50%\">\n</div>\n\n### Update 01.04.2024\n\n* Refactored the code to support a more flexible scene description. A simulation scene now consists of a list of objects that add their contribution to the fields.\nThey can be combined to build complex and time-dependent simulations. 
The refactoring also made the core simulation code even simpler.\n* Added a few new custom colormaps that work well for wave simulations.\n* Added new examples, which should make it easier to understand how to use the program and how to set up your own simulations: [examples](wave_sim2d/examples).\n\n<div style=\"display: flex;\">\n    <img src=\"images/simulation_3.jpg\" alt=\"Example Image 3\" width=\"45%\">\n    <img src=\"images/simulation_4.jpg\" alt=\"Example Image 4\" width=\"45%\">\n</div>\n\nThe old image-based scene description is still available as a scene object. You can continue to use the convenience of image editing software and create simulations\nwithout much programming.\n\n### Image Scene Description Usage ###\n\nWhen using the 'StaticImageScene' class, simulation scenes can be given as an 8-bit RGB image with the following channel semantics:\n* Red:   The refractive index times 100 (for a refractive index of 1.5, use the value 150)\n* Green: Each pixel with a green value above 0 is a sinusoidal wave source. The green value defines its frequency.\n* Blue:  Absorption field. Larger values correspond to stronger dampening of the waves; use graduated transitions to avoid reflections.\n\nWARNING: Do not use anti-aliasing for the green channel! The intermediate shades it produces are interpreted as different source frequencies, which yields incorrect results.\n\n<div style=\"display: flex;\">\n    <img src=\"images/source_antialiasing.png\" alt=\"Example Image 5\" width=\"50%\">\n</div>\n\n### Recommended Installation ###\n\n1. Install Python and the PyCharm IDE\n2. Clone the project to your hard disk\n3. Open the folder as a project using PyCharm\n4. If prompted to install requirements, accept (or install them manually using `pip install -r requirements.txt`)\n5. Right-click one of the examples in wave_sim2d/examples and select Run\n\nNOTE: If you have issues installing the `cupy` library:\n1. Make sure you have the `nvidia-cuda-toolkit` installed. 
\nYou can check this by running `nvcc --version`.\n2. In the *requirements.txt* file, replace `cupy` with `cupy-cuda[version-number]x`, \n   where the version number is the CUDA major version displayed by `nvcc --version` (for example `cupy-cuda11x` for CUDA 11).\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy\r\nopencv-python\r\nmatplotlib\r\ncupy"
  },
  {
    "path": "wave_sim2d/__init__.py",
    "content": ""
  },
  {
    "path": "wave_sim2d/develop_tests.py",
    "content": "import wave_visualizer as vis\r\nimport wave_simulation as sim\r\nimport numpy as np\r\nimport cv2\r\nimport math\r\nimport json\r\nfrom scene_objects.static_dampening import StaticDampening\r\nfrom scene_objects.static_refractive_index import StaticRefractiveIndex\r\nfrom scene_objects.static_image_scene import StaticImageScene\r\nfrom scene_objects.source import PointSource, ModulatorSmoothSquare, ModulatorDiscreteSignal\r\n\r\n\r\ndef build_example_scene1(scene_image):\r\n    \"\"\"\r\n    This example uses the old image scene description. See 'StaticImageScene' for more information.\r\n    \"\"\"\r\n    scene_objects = [StaticImageScene(scene_image)]\r\n    return scene_objects\r\n\r\n\r\ndef build_example_scene2(width, height):\r\n    \"\"\"\r\n    In this example, a new scene is created from scratch and a few emitters are placed manually.\r\n    One of the emitters uses an amplitude modulation object to change its brightness over time.\r\n    \"\"\"\r\n    objects = []\r\n\r\n    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)\r\n    # However, a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')\r\n    objects.append(StaticDampening(np.ones((height, width)), 48))\r\n\r\n    # add a constant refractive index field\r\n    objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))\r\n\r\n    # add a simple point source\r\n    objects.append(PointSource(200, 250, 0.19, 5))\r\n\r\n    # add a point source with an amplitude modulator\r\n    amplitude_modulator = ModulatorDiscreteSignal(np.random.randint(2, size=64), 0.0006)\r\n    objects.append(PointSource(200, 350, 0.19, 5, amp_modulator=amplitude_modulator))\r\n\r\n    return objects\r\n\r\n\r\ndef simulate(scene_image_fn, num_iterations,\r\n             simulation_steps_per_frame, write_videos,\r\n             field_colormap, intensity_colormap,\r\n             
background_image_fn=None):\r\n    # reset random number generator\r\n    np.random.seed(0)\r\n\r\n    # load scene image\r\n    scene_image = cv2.cvtColor(cv2.imread(scene_image_fn), cv2.COLOR_BGR2RGB)\r\n\r\n    background_image = None\r\n    if background_image_fn is not None:\r\n        background_image = cv2.imread(background_image_fn)\r\n        background_image = cv2.resize(background_image, (scene_image.shape[1], scene_image.shape[0]))\r\n\r\n    # create simulator and visualizer objects\r\n    simulator = sim.WaveSimulator2D(scene_image.shape[1], scene_image.shape[0])\r\n    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)\r\n\r\n    # build simulation scene\r\n    simulator.scene_objects = build_example_scene2(scene_image.shape[1], scene_image.shape[0])\r\n\r\n    # create video writers\r\n    if write_videos:\r\n        video_writer1 = cv2.VideoWriter('simulation_field.avi', cv2.VideoWriter_fourcc(*'FFV1'),\r\n                                       60, (scene_image.shape[1], scene_image.shape[0]))\r\n        video_writer2 = cv2.VideoWriter('simulation_intensity.avi', cv2.VideoWriter_fourcc(*'FFV1'),\r\n                                       60, (scene_image.shape[1], scene_image.shape[0]))\r\n\r\n    # run simulation\r\n    for i in range(num_iterations):\r\n        simulator.update_scene()\r\n        simulator.update_field()\r\n        visualizer.update(simulator)\r\n\r\n        if i % simulation_steps_per_frame == 0:\r\n            frame_int = visualizer.render_intensity(1.0)\r\n            frame_field = visualizer.render_field(1.0)\r\n\r\n            if background_image is not None:\r\n                frame_int = cv2.add(background_image, frame_int)\r\n                frame_field = cv2.add(background_image, frame_field)\r\n\r\n            # frame_int = cv2.pyrDown(frame_int)\r\n            # frame_field = cv2.pyrDown(frame_field)\r\n            cv2.imshow(\"Wave Simulation\", frame_field) 
#cv2.resize(frame_int, dsize=(1024, 1024)))\r\n            cv2.waitKey(1)\r\n\r\n            if write_videos:\r\n                video_writer1.write(frame_field)\r\n                video_writer2.write(frame_int)\r\n\r\n        if i % 128 == 0:\r\n            print(f'{int((i+1)/num_iterations*100)}%')\r\n\r\n    # finalize the video files\r\n    if write_videos:\r\n        video_writer1.release()\r\n        video_writer2.release()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    print('This file contains tests for development and you may not be able to run it without errors')\r\n    print('Please take a look at the provided examples')\r\n\r\n    # increase simulation_steps_per_frame to better utilize the GPU\r\n    # good colormaps for the field: RdBu[invert=True], colormap_wave1, colormap_wave2, colormap_wave4, icefire\r\n    simulate('../example_data/scene_lens_doubleslit.png',\r\n             20000,\r\n             simulation_steps_per_frame=16,\r\n             write_videos=True,\r\n             field_colormap=vis.get_colormap_lut('colormap_wave4', invert=False, black_level=-0.05),\r\n#             field_colormap=vis.get_colormap_lut('RdBu', invert=True, make_symmetric=True),\r\n             intensity_colormap=vis.get_colormap_lut('afmhot', invert=False, black_level=0.0),\r\n             background_image_fn=None)\r\n\r\n"
  },
  {
    "path": "wave_sim2d/examples/example0.py",
    "content": "import sys\r\nimport os\r\n# make the repository root importable so that the 'wave_sim2d' package can be found\r\nsys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa\r\n\r\nimport cv2\r\nimport wave_sim2d.wave_visualizer as vis\r\nimport wave_sim2d.wave_simulation as sim\r\nfrom wave_sim2d.scene_objects.source import *\r\nfrom wave_sim2d.scene_objects.static_refractive_index import *\r\n\r\n\r\ndef build_scene():\r\n    \"\"\"\r\n    This example creates the simplest possible simulation using a single emitter.\r\n    \"\"\"\r\n    width = 512\r\n    height = 512\r\n    objects = [PointSource(200, 256, 0.1, 5)]\r\n    # objects.append(StaticRefractiveIndexPolygon([[400, 255], [300, 200], [300, 300]], 1.5))\r\n    # objects = [LineSource((200, 265), (250, 105), 0.2, 0.5)]\r\n\r\n    return objects, width, height\r\n\r\n\r\ndef main():\r\n    # create colormaps\r\n    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)\r\n    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)\r\n\r\n    # build simulation scene\r\n    scene_objects, w, h = build_scene()\r\n\r\n    # create simulator and visualizer objects\r\n    simulator = sim.WaveSimulator2D(w, h, scene_objects)\r\n    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)\r\n\r\n    # run simulation\r\n    for i in range(1000):\r\n        simulator.update_scene()\r\n        simulator.update_field()\r\n        visualizer.update(simulator)\r\n\r\n        # show field\r\n        frame_field = visualizer.render_field(1.0)\r\n        cv2.imshow(\"Wave Simulation Field\", frame_field)\r\n\r\n        # show intensity\r\n        # frame_int = visualizer.render_intensity(1.0)\r\n        # cv2.imshow(\"Wave Simulation Intensity\", frame_int)\r\n\r\n        cv2.waitKey(1)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n"
  },
  {
    "path": "wave_sim2d/examples/example1.py",
    "content": "import sys\r\nimport os\r\n# make the repository root importable so that the 'wave_sim2d' package can be found\r\nsys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa\r\n\r\nimport numpy as np\r\nimport cv2\r\n\r\nimport wave_sim2d.wave_visualizer as vis\r\nimport wave_sim2d.wave_simulation as sim\r\nfrom wave_sim2d.scene_objects.static_image_scene import StaticImageScene\r\n\r\n\r\ndef build_scene(scene_image_path):\r\n    \"\"\"\r\n    This example uses the 'old' image scene description. See 'StaticImageScene' for more information.\r\n    \"\"\"\r\n    # load scene image (cv2.imread returns None if the file cannot be read)\r\n    scene_image = cv2.imread(scene_image_path)\r\n    if scene_image is None:\r\n        raise FileNotFoundError(f'Could not load scene image: {scene_image_path}')\r\n    scene_image = cv2.cvtColor(scene_image, cv2.COLOR_BGR2RGB)\r\n\r\n    # create the scene object list with a 'StaticImageScene' entry as the only scene object\r\n    # more scene objects can be added to the list to build more complex scenes\r\n    scene_objects = [StaticImageScene(scene_image, source_fequency_scale=2.0)]\r\n\r\n    return scene_objects, scene_image.shape[1], scene_image.shape[0]\r\n\r\n\r\ndef main():\r\n    # Set scene image path. 
The image encodes the refractive index, dampening and emitters in its color channels.\r\n    # See the 'static_image_scene.StaticImageScene' class for a more detailed description.\r\n    # Please take a look at the image to understand what is happening in the simulation.\r\n    # The path is resolved relative to this file, so the example works from any working directory.\r\n    scene_image_path = os.path.join(os.path.dirname(__file__), '../../example_data/scene_lens_doubleslit.png')\r\n\r\n    # create colormaps\r\n    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)\r\n    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)\r\n\r\n    # reset random number generator\r\n    np.random.seed(0)\r\n\r\n    # build simulation scene\r\n    scene_objects, w, h = build_scene(scene_image_path)\r\n\r\n    # create simulator and visualizer objects\r\n    simulator = sim.WaveSimulator2D(w, h, scene_objects)\r\n    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)\r\n\r\n    # run simulation\r\n    for i in range(2000):\r\n        simulator.update_scene()\r\n        simulator.update_field()\r\n        visualizer.update(simulator)\r\n\r\n        # visualize every N frames\r\n        if (i % 4) == 0:\r\n            # show field\r\n            frame_field = visualizer.render_field(1.0)\r\n            cv2.imshow(\"Wave Simulation Field\", frame_field)\r\n\r\n            # show intensity\r\n            # frame_int = visualizer.render_intensity(1.0)\r\n            # cv2.imshow(\"Wave Simulation Intensity\", frame_int)\r\n\r\n        cv2.waitKey(1)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n"
  },
  {
    "path": "wave_sim2d/examples/example2.py",
    "content": "import sys\r\nimport os\r\n# make the repository root importable so that the 'wave_sim2d' package can be found\r\nsys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa\r\n\r\nimport numpy as np\r\nimport cv2\r\nimport wave_sim2d.wave_visualizer as vis\r\nimport wave_sim2d.wave_simulation as sim\r\nfrom wave_sim2d.scene_objects.static_dampening import StaticDampening\r\nfrom wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex\r\nfrom wave_sim2d.scene_objects.source import PointSource, ModulatorSmoothSquare\r\n\r\n\r\ndef build_scene():\r\n    \"\"\"\r\n    In this example, a new scene is created from scratch and a few emitters are placed manually.\r\n    One of the emitters uses an amplitude modulation object to change its brightness over time.\r\n    \"\"\"\r\n    width = 600\r\n    height = 600\r\n    objects = []\r\n\r\n    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)\r\n    # However, a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')\r\n    objects.append(StaticDampening(np.ones((height, width)), 32))\r\n\r\n    # add a constant refractive index field\r\n    objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))\r\n\r\n    # add a simple point source\r\n    objects.append(PointSource(200, 220, 0.2, 8))\r\n\r\n    # add a point source with an amplitude modulator\r\n    amplitude_modulator = ModulatorSmoothSquare(0.025, 0.0, smoothness=0.5)\r\n    objects.append(PointSource(200, 380, 0.2, 8, amp_modulator=amplitude_modulator))\r\n\r\n    return objects, width, height\r\n\r\n\r\ndef main():\r\n    # create colormaps\r\n    field_colormap = vis.get_colormap_lut('colormap_wave4', invert=False, black_level=-0.05)\r\n    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)\r\n\r\n    # reset random number generator\r\n    np.random.seed(0)\r\n\r\n    # build simulation scene\r\n    scene_objects, w, h = build_scene()\r\n\r\n    # create simulator and 
visualizer objects\r\n    simulator = sim.WaveSimulator2D(w, h, scene_objects)\r\n    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)\r\n\r\n    # run simulation\r\n    for i in range(2000):\r\n        simulator.update_scene()\r\n        simulator.update_field()\r\n        visualizer.update(simulator)\r\n\r\n        # visualize every N frames\r\n        if (i % 2) == 0:\r\n            # show field\r\n            frame_field = visualizer.render_field(1.0)\r\n            cv2.imshow(\"Wave Simulation Field\", frame_field)\r\n\r\n            # show intensity\r\n            # frame_int = visualizer.render_intensity(1.0)\r\n            # cv2.imshow(\"Wave Simulation Intensity\", frame_int)\r\n\r\n        cv2.waitKey(1)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n"
  },
  {
    "path": "wave_sim2d/examples/example3.py",
    "content": "import sys\r\nimport os\r\n# make the repository root importable so that the 'wave_sim2d' package can be found\r\nsys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa\r\n\r\nimport numpy as np\r\nimport cupy as cp\r\nimport math\r\nimport cv2\r\n\r\nimport wave_sim2d.wave_visualizer as vis\r\nimport wave_sim2d.wave_simulation as sim\r\nfrom wave_sim2d.scene_objects.static_dampening import StaticDampening\r\nfrom wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex\r\n\r\n\r\ndef gaussian_kernel(size, sigma):\r\n    \"\"\"\r\n    Creates a normalized 2D gaussian kernel with side length `size` and a standard deviation of `sigma`.\r\n    \"\"\"\r\n    ax = np.linspace(-(size - 1) / 2., (size - 1) / 2., size)\r\n    gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))\r\n    kernel = np.outer(gauss, gauss)\r\n    return kernel / np.sum(kernel)\r\n\r\n\r\nclass MovingCharge(sim.SceneObject):\r\n    \"\"\"\r\n    Implements a moving charge scene object: a smooth, gaussian-shaped disturbance that is added to the field\r\n    while it oscillates around a center position.\r\n    :param x: center position x.\r\n    :param y: center position y.\r\n    :param frequency: motion frequency\r\n    :param amplitude: amplitude of the vertical motion (the slower horizontal sweep has a fixed amplitude of 200 pixels)\r\n    \"\"\"\r\n    def __init__(self, x, y, frequency, amplitude):\r\n        self.x = x\r\n        self.y = y\r\n        self.frequency = frequency\r\n        self.amplitude = amplitude\r\n        self.size = 11\r\n\r\n        # create a smooth source shape\r\n        self.source_array = cp.array(gaussian_kernel(self.size, self.size/3))\r\n\r\n    def render(self, field, wave_speed_field, dampening_field):\r\n        # no changes to the refractive index or dampening field required for this class\r\n        pass\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        pass\r\n\r\n    def update_field(self, field, t):\r\n        fade_in = 
math.sin(min(t*0.1, math.pi/2))\r\n\r\n        # write the moving charge to the field\r\n        x = self.x + math.sin(self.frequency * t*0.05)*200\r\n        y = self.y + math.sin(self.frequency * 
t)*self.amplitude\r\n\r\n        # slice indices must be integers, so round the position to the nearest pixel\r\n        x, y = int(round(x)), int(round(y))\r\n\r\n        # copy source shape to current position into field\r\n        wh = self.source_array.shape[1]//2\r\n        hh = self.source_array.shape[0]//2\r\n        field[y-hh:y+hh+1, x-wh:x+wh+1] += self.source_array * fade_in * 0.25\r\n\r\n\r\ndef build_scene():\r\n    \"\"\"\r\n    In this example, a custom scene object is implemented and used to simulate a moving field disturbance.\r\n    \"\"\"\r\n    width = 600\r\n    height = 600\r\n    objects = []\r\n\r\n    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)\r\n    # However, a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')\r\n    objects.append(StaticDampening(np.ones((height, width)), 64))\r\n\r\n    # add a constant refractive index field\r\n    objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))\r\n\r\n    # add the moving charge\r\n    objects.append(MovingCharge(300, 300, 0.1, 10))\r\n\r\n    return objects, width, height\r\n\r\n\r\ndef main():\r\n    # create colormaps\r\n    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)\r\n    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)\r\n\r\n    # reset random number generator\r\n    np.random.seed(0)\r\n\r\n    # build simulation scene\r\n    scene_objects, w, h = build_scene()\r\n\r\n    # create simulator and visualizer objects\r\n    simulator = sim.WaveSimulator2D(w, h, scene_objects)\r\n    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)\r\n\r\n    # run simulation\r\n    for i in range(8000):\r\n        simulator.update_scene()\r\n        simulator.update_field()\r\n        visualizer.update(simulator)\r\n\r\n        # visualize every N frames\r\n        if (i % 2) == 0:\r\n            # show field\r\n            frame_field = 
visualizer.render_field(1.0)\r\n            
cv2.imshow(\"Wave Simulation Field\", frame_field)\r\n\r\n            # show intensity\r\n            # frame_int = visualizer.render_intensity(1.0)\r\n            # cv2.imshow(\"Wave Simulation Intensity\", frame_int)\r\n\r\n        cv2.waitKey(1)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n"
  },
  {
    "path": "wave_sim2d/examples/example4.py",
    "content": "import sys\r\nimport os\r\n# make the repository root importable so that the 'wave_sim2d' package can be found\r\nsys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa\r\n\r\nimport cv2\r\nimport numpy as np\r\nimport cupy as cp\r\nimport wave_sim2d.wave_visualizer as vis\r\nimport wave_sim2d.wave_simulation as sim\r\nfrom wave_sim2d.scene_objects.source import *\r\nfrom wave_sim2d.scene_objects.static_refractive_index import *\r\nfrom wave_sim2d.scene_objects.static_dampening import *\r\n\r\n\r\ndef build_scene():\r\n    \"\"\"\r\n    This example creates a Fabry-Perot cavity and shows the standing waves that form inside it.\r\n    \"\"\"\r\n    width = 768\r\n    height = 512\r\n    objects = []\r\n\r\n    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)\r\n    # However, a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')\r\n    objects.append(StaticDampening(np.ones((height, width)), 48))\r\n\r\n    # add two refractive index boxes that act as the cavity mirrors\r\n    objects.append(StaticRefractiveIndexBox((50, height//2), (50, int(height*0.8)), 0.0, 100.0))\r\n    objects.append(StaticRefractiveIndexBox((width-180, height//2), (40, int(height*0.8)), 0.0, 10.0))\r\n\r\n    # add a vertical line source (the commented-out line is a lower-frequency alternative)\r\n    # objects.append(LineSource((77, height//2-140), (77, height//2+140), 0.0215, amplitude=0.5))\r\n    objects.append(LineSource((77, height//2-140), (77, height//2+140), 0.1003, amplitude=0.3))\r\n\r\n    return objects, width, height\r\n\r\n\r\ndef show_field(field, brightness_scale):\r\n    # optional helper (not used in main): shows the raw field as a single-channel grayscale image\r\n    gray = (cp.clip(field*brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)\r\n    img = gray.get()\r\n    cv2.imshow(\"Strain Simulation Field\", img)\r\n\r\n\r\ndef main():\r\n    write_videos = False\r\n    write_video_frame_every = 2\r\n\r\n    # create colormaps\r\n    field_colormap = 
vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)\r\n    intensity_colormap = 
vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)\r\n\r\n    # build simulation scene\r\n    scene_objects, w, h = build_scene()\r\n\r\n    # create simulator and visualizer objects\r\n    simulator = sim.WaveSimulator2D(w, h, scene_objects)\r\n    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)\r\n\r\n    # optionally create video writers\r\n    if write_videos:\r\n        video_writer1 = cv2.VideoWriter('simulation_field.avi', cv2.VideoWriter_fourcc(*'FFV1'), 60, (w, h))\r\n        video_writer2 = cv2.VideoWriter('simulation_intensity.avi', cv2.VideoWriter_fourcc(*'FFV1'), 60, (w, h))\r\n\r\n    # run simulation\r\n    for i in range(100000):\r\n        simulator.update_scene()\r\n        simulator.update_field()\r\n        visualizer.update(simulator)\r\n\r\n        # show field\r\n        frame_field = visualizer.render_field(1.0)\r\n        cv2.imshow(\"Wave Simulation Field\", frame_field)\r\n\r\n        # show intensity\r\n        frame_int = visualizer.render_intensity(1.0)\r\n        # cv2.imshow(\"Wave Simulation Intensity\", frame_int)\r\n\r\n        if write_videos and (i % write_video_frame_every) == 0:\r\n            video_writer1.write(frame_field)\r\n            video_writer2.write(frame_int)\r\n\r\n        cv2.waitKey(1)\r\n\r\n    # finalize the video files\r\n    if write_videos:\r\n        video_writer1.release()\r\n        video_writer2.release()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n\r\n"
  },
  {
    "path": "wave_sim2d/main.py",
    "content": "if __name__ == \"__main__\":\r\n    print('Please run one of the examples from the wave_sim2d/examples folder.')\r\n"
  },
  {
    "path": "wave_sim2d/scene_objects/source.py",
    "content": "from wave_sim2d.wave_simulation import SceneObject\r\nimport cupy as cp\r\nimport numpy as np\r\nimport math\r\n\r\n\r\nclass PointSource(SceneObject):\r\n    \"\"\"\r\n    Implements a point source scene object. The amplitude can be optionally modulated using a modulator object.\r\n    :param x: source position x.\r\n    :param y: source position y.\r\n    :param frequency: emitting frequency.\r\n    :param amplitude: emitting amplitude. If an amplitude modulator is given, its output is multiplied by this value.\r\n    :param phase: emitter phase (in radians)\r\n    :param amp_modulator: optional amplitude modulator. This can be used to change the amplitude of the source\r\n                          over time.\r\n    \"\"\"\r\n    def __init__(self, x, y, frequency, amplitude=1.0, phase=0, amp_modulator=None):\r\n        self.x = x\r\n        self.y = y\r\n        self.frequency = frequency\r\n        self.amplitude = amplitude\r\n        self.phase = phase\r\n        self.amplitude_modulator = amp_modulator\r\n\r\n    def set_amplitude_modulator(self, func):\r\n        self.amplitude_modulator = func\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        pass\r\n\r\n    def update_field(self, field, t):\r\n        if self.amplitude_modulator is not None:\r\n            amplitude = self.amplitude_modulator(t) * self.amplitude\r\n        else:\r\n            amplitude = self.amplitude\r\n\r\n        v = cp.sin(self.phase + self.frequency * t) * amplitude\r\n        field[self.y, self.x] = v\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object to the image \"\"\"\r\n        pass\r\n\r\n\r\nclass LineSource(SceneObject):\r\n    \"\"\"\r\n    Implements a line source scene object. 
The amplitude can be optionally modulated using a modulator object.\r\n    The source emits along a line defined by a start and end point.\r\n    :param start: starting (x, y) coordinates of the line as a tuple.\r\n    :param end: ending (x, y) coordinates of the line as a tuple.\r\n    :param frequency: emitting frequency.\r\n    :param amplitude: emitting amplitude. If an amplitude modulator is given, its output is multiplied by this value.\r\n    :param phase: emitter phase (in radians)\r\n    :param amp_modulator: optional amplitude modulator. This can be used to change the amplitude of the source\r\n                          over time.\r\n    \"\"\"\r\n    def __init__(self, start, end, frequency, amplitude=1.0, phase=0, amp_modulator=None):\r\n        self.start = start\r\n        self.end = end\r\n        self.frequency = frequency\r\n        self.amplitude = amplitude\r\n        self.phase = phase\r\n        self.amplitude_modulator = amp_modulator\r\n\r\n    def set_amplitude_modulator(self, func):\r\n        self.amplitude_modulator = func\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        pass\r\n\r\n    def update_field(self, field, t):\r\n        if self.amplitude_modulator is not None:\r\n            amplitude = self.amplitude_modulator(t) * self.amplitude\r\n        else:\r\n            amplitude = self.amplitude\r\n\r\n        v = cp.sin(self.phase + self.frequency * t) * amplitude\r\n\r\n        # Sample one point per pixel of line length (length computed with NumPy, coordinates generated with CuPy)\r\n        x1, y1 = self.start\r\n        x2, y2 = self.end\r\n\r\n        distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)\r\n        num_points = int(distance) + 1\r\n\r\n        if num_points > 0:\r\n            x_coords = cp.linspace(x1, x2, num_points).round().astype(int)\r\n            y_coords = cp.linspace(y1, y2, num_points).round().astype(int)\r\n\r\n            # Create boolean masks for valid indices\r\n            valid_x = (x_coords >= 0) & (x_coords < field.shape[1])\r\n 
           valid_y = (y_coords >= 0) & (y_coords < field.shape[0])\r\n            valid_indices = valid_x & valid_y\r\n\r\n            # Use these valid indices to update the field directly\r\n            valid_y_coords = y_coords[valid_indices]\r\n            valid_x_coords = x_coords[valid_indices]\r\n            field[valid_y_coords, valid_x_coords] = v\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object to the image \"\"\"\r\n        pass\r\n\r\n# --- Modulators -------------------------------------------------------------------------------------------------------\r\n\r\n\r\nclass ModulatorSmoothSquare:\r\n    \"\"\"\r\n    A modulator that creates a smoothed square wave\r\n    \"\"\"\r\n    def __init__(self, frequency, phase, smoothness=0.5):\r\n        self.frequency = frequency\r\n        self.phase = phase\r\n        self.smoothness = min(max(smoothness, 1e-4), 1.0)\r\n\r\n    def __call__(self, t):\r\n        s = math.pow(self.smoothness, 4.0)\r\n        a = (0.5 / math.atan(1.0/s)) * math.atan(math.sin(t * self.frequency + self.phase) / s)+0.5\r\n        return a\r\n\r\n\r\nclass ModulatorDiscreteSignal:\r\n    \"\"\"\r\n    A modulator that creates a smoothed binary signal\r\n    \"\"\"\r\n    def __init__(self, signal_array, time_factor, transition_slope=8.0):\r\n        self.signal_array = signal_array\r\n        self.time_factor = time_factor\r\n        self.transition_slope = transition_slope\r\n\r\n    def __call__(self, t):\r\n        def smooth_step(t):\r\n            return t * t * (3 - 2 * t)\r\n\r\n        # Wrap around the position if it's outside the array range\r\n        sl = len(self.signal_array)\r\n        t = math.fmod(t*self.time_factor, sl)\r\n\r\n        # Find the indices of the neighboring values\r\n        index_low = int(t)\r\n        index_high = (index_low + 1) % sl\r\n\r\n        # Calculate the interpolation factor\r\n        tf = (t - index_low)\r\n     
   tf = max(0.0, min(1.0, (tf-0.5)*self.transition_slope+0.5))\r\n\r\n        # Use smooth step to interpolate between neighboring values\r\n        l = smooth_step(tf)\r\n        interpolated_value = (1 - l) * self.signal_array[index_low] + l * self.signal_array[index_high]\r\n\r\n        return interpolated_value\r\n"
  },
  {
    "path": "wave_sim2d/scene_objects/static_dampening.py",
    "content": "from wave_sim2d.wave_simulation import SceneObject\r\nimport cupy as cp\r\nimport numpy as np\r\n\r\n\r\nclass StaticDampening(SceneObject):\r\n    \"\"\"\r\n    Implements a static dampening field that overwrites the dampening of the entire domain.\r\n    Therefore, use this as the base layer of your scene.\r\n    \"\"\"\r\n\r\n    def __init__(self, dampening_field, border_thickness):\r\n        \"\"\"\r\n        Creates a static dampening field object\r\n        @param dampening_field: A NxM array with dampening factors (1.0 equals no dampening) of the same size as the simulation domain.\r\n        @param border_thickness: Thickness (in pixels) of the graded absorbing layer added at the boundaries to reduce reflections.\r\n        \"\"\"\r\n        w = dampening_field.shape[1]\r\n        h = dampening_field.shape[0]\r\n\r\n        # copy the dampening factors to the GPU as float32 and clamp them to [0, 1]\r\n        self.d = cp.clip(cp.array(dampening_field, dtype=cp.float32), 0.0, 1.0)\r\n\r\n        # apply border dampening\r\n        for i in range(border_thickness):\r\n            v = (i / border_thickness) ** 0.5\r\n            self.d[i, i:w - i] = v\r\n            self.d[-(1 + i), i:w - i] = v\r\n            self.d[i:h - i, i] = v\r\n            self.d[i:h - i, -(1 + i)] = v\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        assert (dampening_field.shape == self.d.shape)\r\n\r\n        # overwrite existing dampening field\r\n        dampening_field[:] = self.d\r\n\r\n    def update_field(self, field: cp.ndarray, t):\r\n        pass\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object to the image \"\"\"\r\n        pass\r\n"
  },
  {
    "path": "wave_sim2d/scene_objects/static_image_scene.py",
    "content": "from wave_sim2d.wave_simulation import SceneObject\r\n\r\nimport numpy as np\r\nimport cupy as cp\r\nfrom wave_sim2d.scene_objects.static_dampening import StaticDampening\r\nfrom wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex\r\n\r\n\r\nclass StaticImageScene(SceneObject):\r\n    \"\"\"\r\n    Implements a static scene, where the RGB channels of the input image encode the refractive index, the dampening and the sources.\r\n    This class makes it possible to create scenes with an image editor.\r\n    \"\"\"\r\n    def __init__(self, scene_image, source_amplitude=1.0, source_fequency_scale=1.0):\r\n        \"\"\"\r\n        Loads the scene from an image description.\r\n        The simulation scene is given as an 8-bit RGB image with the following channel semantics:\r\n            * Red:   The refractive index times 100 (for refractive index 1.5 you would use value 150)\r\n            * Green: Each pixel with a green value above 0 is a sinusoidal wave source. The green value\r\n                     defines its frequency. WARNING: Do not use anti-aliasing for the green channel!\r\n            * Blue:  Absorption field. Larger values correspond to stronger dampening of the waves;\r\n                     use graduated transitions to avoid reflections.\r\n        \"\"\"\r\n        # Set the opacity of source pixels to incoming waves. 
If the opacity is 0.0,\r\n        # the field will be completely overwritten by the source term;\r\n        # a nonzero value (e.g. 0.5) allows for antialiasing of sources to work\r\n        self.source_opacity = 0.9\r\n\r\n        # set refractive index field\r\n        self.refractive_index = StaticRefractiveIndex(scene_image[:, :, 0] / 100)\r\n\r\n        # set absorber field\r\n        self.dampening = StaticDampening(1.0 - scene_image[:, :, 2] / 255, border_thickness=48)\r\n\r\n        # set sources, each entry describes a source with the following parameters:\r\n        # (x, y, phase, amplitude, frequency)\r\n        sources_pos = np.flip(np.argwhere(scene_image[:, :, 1] > 0), axis=1)\r\n        phase_amplitude_freq = np.tile(np.array([0, source_amplitude, 0.3]), (sources_pos.shape[0], 1))\r\n        self.sources = np.concatenate((sources_pos, phase_amplitude_freq), axis=1)\r\n\r\n        # set source frequency to channel value\r\n        self.sources[:, 4] = scene_image[sources_pos[:, 1], sources_pos[:, 0], 1] / 255 * 0.5 * source_fequency_scale\r\n        self.sources = cp.array(self.sources).astype(cp.float32)\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        \"\"\"\r\n        Renders the static dampening and refractive index fields of the scene.\r\n        \"\"\"\r\n        self.dampening.render(field, wave_speed_field, dampening_field)\r\n        self.refractive_index.render(field, wave_speed_field, dampening_field)\r\n\r\n    def update_field(self, field: cp.ndarray, t):\r\n        # Update the sources in the simulation field based on their properties.\r\n        v = cp.sin(self.sources[:, 2]+self.sources[:, 4]*t)*self.sources[:, 3]\r\n        coords = self.sources[:, 0:2].astype(cp.int32)\r\n\r\n        o = self.source_opacity\r\n        field[coords[:, 1], coords[:, 0]] = field[coords[:, 1], coords[:, 0]]*o + v*(1.0-o)\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object 
to the image \"\"\"\r\n        pass\r\n"
  },
  {
    "path": "wave_sim2d/scene_objects/static_refractive_index.py",
    "content": "from wave_sim2d.wave_simulation import SceneObject\r\nimport cupy as cp\r\nimport numpy as np\r\nimport cv2\r\n\r\n\r\nclass StaticRefractiveIndex(SceneObject):\r\n    \"\"\"\r\n    Implements a static refractive index field that overwrites the wave speed of the entire domain.\r\n    Use this as the base layer in your scene.\r\n    \"\"\"\r\n\r\n    def __init__(self, refractive_index_field):\r\n        \"\"\"\r\n        Creates a static refractive index field object\r\n        :param refractive_index_field: The refractive index field, same size as the simulation domain.\r\n                                       Note that values below 0.9 are clipped to prevent the simulation\r\n                                       from becoming unstable\r\n        \"\"\"\r\n        shape = refractive_index_field.shape\r\n        self.c = cp.ones((shape[0], shape[1]), dtype=cp.float32)\r\n        self.c = 1.0/cp.clip(cp.array(refractive_index_field), 0.9, 10.0)\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        assert (wave_speed_field.shape == self.c.shape)\r\n        wave_speed_field[:] = self.c\r\n\r\n    def update_field(self, field: cp.ndarray, t):\r\n        pass\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object to the image \"\"\"\r\n        pass\r\n\r\n\r\nclass StaticRefractiveIndexPolygon(SceneObject):\r\n    \"\"\"\r\n    Draws a static polygon with a given refractive index into the wave_speed_field using an\r\n    anti-aliased mask and indexing. 
Caches the pixel coordinates and mask values.\r\n    \"\"\"\r\n\r\n    def __init__(self, vertices, refractive_index):\r\n        \"\"\"\r\n        Initializes the StaticRefractiveIndexPolygon.\r\n\r\n        Args:\r\n            vertices (list or np.ndarray): A list or array of (x, y) coordinates defining the polygon.\r\n            refractive_index (float): The refractive index of the polygon. Values are clamped to [0.9, 10.0].\r\n        \"\"\"\r\n        self.vertices = np.array(vertices, dtype=np.float32)\r\n        self.refractive_index = min(max(refractive_index, 0.9), 10.0)\r\n        self._cached_coords = None\r\n        self._cached_mask_values = None\r\n        self._cached_field_shape = (0, 0)\r\n\r\n    def _create_polygon_data(self, field_shape):\r\n        \"\"\"\r\n        Creates and caches the pixel coordinates and anti-aliased mask values for the polygon.\r\n\r\n        Args:\r\n            field_shape (tuple): The shape (rows, cols) of the simulation field.\r\n\r\n        Returns:\r\n            tuple: A tuple containing:\r\n                - coords (tuple of cp.ndarray): (y_coordinates, x_coordinates) of the polygon pixels within the field.\r\n                - mask_values (cp.ndarray): Corresponding anti-aliased mask values (0.0 to 1.0).\r\n        \"\"\"\r\n        if self._cached_coords is not None and self._cached_field_shape == field_shape:\r\n            return self._cached_coords, self._cached_mask_values\r\n\r\n        rows, cols = field_shape\r\n\r\n        # Find the bounding box of the polygon\r\n        min_x = np.min(self.vertices[:, 0])\r\n        max_x = np.max(self.vertices[:, 0])\r\n        min_y = np.min(self.vertices[:, 1])\r\n        max_y = np.max(self.vertices[:, 1])\r\n\r\n        mask_width = int(np.ceil(max_x - min_x)) + 1\r\n        mask_height = int(np.ceil(max_y - min_y)) + 1\r\n        offset_x = int(np.floor(min_x))\r\n        offset_y = int(np.floor(min_y))\r\n\r\n        # Create the mask\r\n        mask = 
np.zeros((mask_height, mask_width), dtype=np.float32)\r\n        translated_vertices = self.vertices - [offset_x, offset_y]\r\n        translated_vertices_cv = np.round(translated_vertices).astype(np.int32)\r\n        cv2.fillPoly(mask, [translated_vertices_cv], 1.0, lineType=cv2.LINE_AA)\r\n\r\n        # Get coordinates and mask values of non-black pixels\r\n        coords_y, coords_x = np.where(mask > 0)\r\n        mask_values = mask[coords_y, coords_x]\r\n\r\n        # Adjust coordinates to the position in the main field\r\n        global_coords_y = coords_y + offset_y\r\n        global_coords_x = coords_x + offset_x\r\n\r\n        # Perform out-of-bounds check here\r\n        in_bounds = (global_coords_y >= 0) & (global_coords_y < rows) & \\\r\n                    (global_coords_x >= 0) & (global_coords_x < cols)\r\n\r\n        valid_global_y = global_coords_y[in_bounds]\r\n        valid_global_x = global_coords_x[in_bounds]\r\n        valid_mask_values = mask_values[in_bounds]\r\n\r\n        self._cached_coords = (cp.array(valid_global_y), cp.array(valid_global_x))\r\n        self._cached_mask_values = cp.array(valid_mask_values, dtype=cp.float32)\r\n        self._cached_field_shape = field_shape\r\n        return self._cached_coords, self._cached_mask_values\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        coords, mask_values = self._create_polygon_data(wave_speed_field.shape)\r\n\r\n        # Use advanced indexing to update the field and perform alpha blending\r\n        bg_wave_speed = wave_speed_field[coords[0], coords[1]]\r\n        wave_speed_field[coords[0], coords[1]] = (bg_wave_speed * (1.0 - mask_values) +\r\n                                                  mask_values / self.refractive_index)\r\n\r\n    def update_field(self, field: cp.ndarray, t):\r\n        pass\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        vertices = 
np.round(self.vertices).astype(np.int32)\r\n        cv2.fillPoly(image, [vertices], (60, 60, 60), lineType=cv2.LINE_AA)\r\n\r\n\r\nclass StaticRefractiveIndexBox(StaticRefractiveIndexPolygon):\r\n    \"\"\"\r\n    Draws a static rotated box with a given refractive index into the wave_speed_field by\r\n    inheriting from StaticRefractiveIndexPolygon.\r\n    \"\"\"\r\n\r\n    def __init__(self, center: tuple, box_size: tuple, box_angle_rad: float, refractive_index: float):\r\n        \"\"\"\r\n        Initializes the StaticRefractiveIndexBox.\r\n\r\n        Args:\r\n            center (tuple): A tuple (center_x, center_y) representing the box's center.\r\n            box_size (tuple): A tuple (width, height) representing the box's dimensions.\r\n            box_angle_rad (float): The rotation angle of the box in radians (counter-clockwise).\r\n            refractive_index (float): The refractive index of the box. Values are clamped to [0.9, 10.0].\r\n        \"\"\"\r\n        self.center = center\r\n        self.box_size = box_size\r\n        self.box_angle_rad = box_angle_rad\r\n        refractive_index = min(max(refractive_index, 0.9), 10.0)\r\n\r\n        # Unpack center and box size\r\n        center_x, center_y = self.center\r\n        width, height = self.box_size\r\n\r\n        # Calculate the vertices of the rotated box\r\n        half_width = width / 2\r\n        half_height = height / 2\r\n        local_vertices = np.array([[-half_width, -half_height],\r\n                                   [half_width, -half_height],\r\n                                   [half_width, half_height],\r\n                                   [-half_width, half_height]], dtype=np.float32)\r\n\r\n        # Create the rotation matrix\r\n        rotation_matrix = cv2.getRotationMatrix2D((0, 0), np.rad2deg(self.box_angle_rad), 1)\r\n\r\n        # Rotate the local vertices\r\n        rotated_vertices = cv2.transform(np.array([local_vertices]), rotation_matrix)[0]\r\n\r\n        # 
Translate the rotated vertices to the center\r\n        translated_vertices = rotated_vertices + [center_x, center_y]\r\n\r\n        # Initialize the parent class (StaticRefractiveIndexPolygon) with the vertices\r\n        super().__init__(translated_vertices, refractive_index)\r\n"
  },
  {
    "path": "wave_sim2d/scene_objects/strain_refractive_index.py",
    "content": "from wave_sim2d.wave_simulation import SceneObject\r\nimport cupy as cp\r\nimport cupyx.scipy.signal\r\nimport numpy as np\r\n\r\nclass StrainRefractiveIndex(SceneObject):\r\n    \"\"\"\r\n    Implements a dynamic refractive index field that depends linearly on the strain of the current field.\r\n    The refractive index within the entire domain is overwritten.\r\n    \"\"\"\r\n\r\n    def __init__(self, refractive_index_offset, coupling_constant):\r\n        \"\"\"\r\n        Creates a strain refractive index field object\r\n        :param refractive_index_offset: refractive index at zero strain\r\n        :param coupling_constant: coupling constant between the strain and the refractive index\r\n        \"\"\"\r\n        self.coupling_constant = coupling_constant\r\n        self.refractive_index_offset = refractive_index_offset\r\n\r\n        self.du_dx_kernel = cp.array([[-1, 0.0, 1]])\r\n        self.du_dy_kernel = cp.array([[-1], [0.0], [1]])\r\n\r\n        self.strain_field = None\r\n\r\n    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):\r\n        # compute strain as the magnitude of the finite-difference gradient of the field\r\n        du_dx = cupyx.scipy.signal.convolve2d(field, self.du_dx_kernel, mode='same', boundary='fill')\r\n        du_dy = cupyx.scipy.signal.convolve2d(field, self.du_dy_kernel, mode='same', boundary='fill')\r\n\r\n        self.strain_field = cp.sqrt(du_dx**2 + du_dy**2)\r\n\r\n        # compute refractive index from strain\r\n        refractive_index_field = self.refractive_index_offset + self.strain_field*self.coupling_constant\r\n\r\n        # assign wave speed using the refractive index from above (clipped to [0.9, 10.0] for stability)\r\n        wave_speed_field[:] = 1.0/cp.clip(cp.array(refractive_index_field), 0.9, 10.0)\r\n\r\n    def update_field(self, field: cp.ndarray, t):\r\n        pass\r\n\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object to the image \"\"\"\r\n        pass\r\n"
  },
  {
    "path": "wave_sim2d/wave_simulation.py",
    "content": "import cupy\r\nimport numpy as np\r\nimport cupy as cp\r\nimport cupyx.scipy.signal\r\nfrom abc import ABC, abstractmethod\r\n\r\n\r\nclass SceneObject(ABC):\r\n    \"\"\"\r\n    Interface for simulation scene objects. A scene object is anything that defines or modifies the simulation scene,\r\n    for example light sources, absorbers or regions with a specific refractive index. Each frame, scene objects can change\r\n    the simulated field and draw their contribution to the wave speed field and the dampening field.\r\n    \"\"\"\r\n\r\n    @abstractmethod\r\n    def render(self, field: cupy.ndarray, wave_speed_field: cupy.ndarray, dampening_field: cupy.ndarray):\r\n        \"\"\" renders the scene object's contribution to the wave speed field and dampening field \"\"\"\r\n        pass\r\n\r\n    @abstractmethod\r\n    def update_field(self, field: cupy.ndarray, t):\r\n        \"\"\" performs updates to the field itself, e.g. for adding sources \"\"\"\r\n        pass\r\n\r\n    @abstractmethod\r\n    def render_visualization(self, image: np.ndarray):\r\n        \"\"\" renders a visualization of the scene object to the image \"\"\"\r\n        pass\r\n\r\n\r\nclass WaveSimulator2D:\r\n    \"\"\"\r\n    Simulates the 2D wave equation.\r\n    The system uses units in which the wave speed is 1.0 pixel/timestep,\r\n    so source frequencies should be adjusted accordingly.\r\n    \"\"\"\r\n    def __init__(self, w, h, scene_objects, initial_field=None):\r\n        \"\"\"\r\n        Initialize the 2D wave simulator.\r\n        @param w: Width of the simulation grid.\r\n        @param h: Height of the simulation grid.\r\n        @param scene_objects: List of SceneObject instances that make up the scene (None for an empty scene).\r\n        @param initial_field: Optional (h, w) array with the initial field values.\r\n        \"\"\"\r\n        self.global_dampening = 1.0\r\n        self.c = cp.ones((h, w), dtype=cp.float32)                      # wave speed field (from refractive indices)\r\n        self.d = cp.ones((h, w), dtype=cp.float32)                      # dampening field\r\n        self.u = cp.zeros((h, w), dtype=cp.float32)                     # field values\r\n        
self.u_prev = cp.zeros((h, w), dtype=cp.float32)                # field values of prev frame\r\n\r\n        if initial_field is not None:\r\n            assert w == initial_field.shape[1] and h == initial_field.shape[0], 'width/height of initial field invalid'\r\n            self.u[:] = initial_field\r\n            self.u_prev[:] = initial_field\r\n\r\n        # Define Laplacian kernel\r\n        self.laplacian_kernel = cp.array([[0.066, 0.184, 0.066],\r\n                                          [0.184, -1.0, 0.184],\r\n                                          [0.066, 0.184, 0.066]])\r\n\r\n        # self.laplacian_kernel = cp.array([[0.05, 0.2, 0.05],\r\n        #                           [0.2, -1.0, 0.2],\r\n        #                           [0.05, 0.2, 0.05]])\r\n\r\n        # self.laplacian_kernel = cp.array([[0.103, 0.147, 0.103],\r\n        #                                   [0.147, -1.0, 0.147],\r\n        #                                   [0.103, 0.147, 0.103]])\r\n\r\n        self.t = 0.0\r\n        self.dt = 1.0\r\n\r\n        self.scene_objects = scene_objects if scene_objects is not None else []\r\n\r\n    def reset_time(self):\r\n        \"\"\"\r\n        Reset the simulation time to zero.\r\n        \"\"\"\r\n        self.t = 0.0\r\n\r\n    def update_field(self):\r\n        \"\"\"\r\n        Update the simulation field based on the wave equation.\r\n        \"\"\"\r\n        # calculate laplacian using convolution\r\n        laplacian = cupyx.scipy.signal.convolve2d(self.u, self.laplacian_kernel, mode='same', boundary='fill')\r\n\r\n        # update field\r\n        v = (self.u - self.u_prev) * self.d * self.global_dampening\r\n        r = (self.u + v + laplacian * (self.c * self.dt)**2)\r\n\r\n        self.u_prev[:] = self.u\r\n        self.u[:] = r\r\n\r\n        self.t += self.dt\r\n\r\n    def update_scene(self):\r\n        # clear wave speed field and dampening field\r\n        self.c.fill(1.0)\r\n        self.d.fill(1.0)\r\n\r\n        
for obj in self.scene_objects:\r\n            obj.render(self.u, self.c, self.d)\r\n\r\n        for obj in self.scene_objects:\r\n            obj.update_field(self.u, self.t)\r\n\r\n    def get_field(self):\r\n        \"\"\"\r\n        Get the current state of the simulation field.\r\n        @return: A 2D array representing the simulation field.\r\n        \"\"\"\r\n        return self.u\r\n\r\n    def render_visualization(self, image=None):\r\n        # create an empty RGB image if none is given, then let each scene object draw into it\r\n        if image is None:\r\n            image = np.zeros((self.c.shape[0], self.c.shape[1], 3), dtype=np.uint8)\r\n\r\n        for obj in self.scene_objects:\r\n            obj.render_visualization(image)\r\n\r\n        return image\r\n\r\n\r\n"
  },
  {
    "path": "wave_sim2d/wave_visualizer.py",
    "content": "import numpy as np\r\nimport cupy as cp\r\nimport cv2\r\nimport matplotlib.pyplot\r\n\r\ncolormap_icefire = [[179, 224, 216], [178, 223, 216], [176, 222, 215], [175, 221, 215], [173, 219, 214], [171, 218, 214], [169, 217, 214], [167, 215, 213], [165, 214, 213], [162, 212, 212], [160, 210, 212], [157, 209, 211], [154, 207, 211], [151, 205, 210], [148, 203, 210], [146, 201, 209], [143, 199, 209], [140, 198, 208], [137, 196, 208], [134, 194, 208], [131, 192, 207], [128, 190, 207], [125, 188, 207], [122, 187, 207], [119, 185, 206], [116, 183, 206], [113, 181, 206], [110, 179, 206], [108, 177, 206], [105, 176, 205], [102, 174, 205], [99, 172, 205], [97, 170, 205], [94, 168, 205], [91, 166, 205], [89, 164, 205], [86, 162, 205], [84, 161, 205], [82, 159, 205], [79, 157, 205], [77, 155, 205], [75, 153, 206], [73, 151, 206], [71, 149, 206], [69, 147, 206], [68, 145, 206], [66, 143, 206], [65, 140, 206], [64, 138, 206], [63, 136, 206], [62, 134, 206], [61, 132, 206], [61, 130, 205], [61, 127, 205], [60, 125, 205], [60, 123, 204], [60, 121, 203], [60, 118, 203], [61, 116, 202], [61, 114, 201], [61, 112, 200], [62, 109, 198], [62, 107, 197], [63, 105, 195], [64, 103, 194], [65, 100, 192], [65, 98, 190], [66, 96, 187], [67, 94, 185], [67, 92, 183], [68, 90, 180], [68, 88, 177], [69, 86, 174], [69, 85, 171], [69, 83, 168], [70, 81, 165], [70, 79, 162], [70, 78, 158], [69, 76, 155], [69, 75, 151], [69, 73, 148], [68, 72, 144], [68, 70, 141], [67, 69, 137], [66, 67, 134], [66, 66, 130], [65, 65, 127], [64, 63, 123], [63, 62, 120], [62, 61, 116], [61, 60, 113], [60, 59, 109], [59, 57, 106], [58, 56, 103], [57, 55, 99], [55, 54, 96], [54, 53, 93], [53, 52, 90], [52, 50, 87], [51, 49, 84], [50, 48, 81], [48, 47, 78], [47, 46, 75], [46, 45, 72], [45, 44, 70], [44, 43, 67], [43, 42, 65], [42, 41, 62], [41, 40, 60], [40, 39, 57], [39, 38, 55], [38, 37, 53], [37, 37, 51], [37, 36, 49], [36, 35, 47], [35, 35, 45], [35, 34, 44], [34, 33, 42], [34, 33, 41], [33, 32, 39], 
[33, 32, 38], [33, 32, 37], [33, 31, 36], [33, 31, 35], [33, 31, 35], [34, 30, 34], [34, 30, 33], [34, 30, 33], [35, 30, 32], [36, 30, 32], [36, 30, 32], [37, 30, 32], [38, 30, 32], [39, 30, 32], [40, 30, 32], [41, 30, 32], [42, 30, 33], [44, 31, 33], [46, 31, 34], [47, 31, 34], [49, 31, 35], [51, 32, 35], [53, 32, 36], [55, 32, 37], [57, 33, 38], [59, 33, 38], [61, 33, 39], [63, 34, 40], [65, 34, 41], [67, 35, 42], [70, 35, 43], [72, 36, 44], [74, 36, 45], [77, 37, 46], [79, 37, 47], [82, 38, 48], [84, 38, 49], [87, 39, 50], [90, 39, 51], [92, 40, 52], [95, 40, 53], [98, 40, 54], [100, 41, 55], [103, 41, 56], [106, 42, 57], [109, 42, 58], [111, 42, 59], [114, 43, 60], [117, 43, 60], [120, 43, 61], [123, 44, 62], [126, 44, 63], [129, 44, 63], [131, 44, 64], [134, 45, 64], [137, 45, 65], [140, 45, 65], [143, 46, 65], [146, 46, 65], [149, 46, 66], [152, 47, 66], [155, 47, 66], [158, 48, 66], [160, 48, 66], [163, 49, 65], [166, 49, 65], [169, 50, 65], [172, 51, 64], [174, 52, 64], [177, 53, 63], [180, 54, 63], [182, 55, 62], [185, 56, 62], [187, 57, 61], [190, 58, 61], [192, 60, 60], [195, 61, 59], [197, 63, 59], [199, 65, 58], [201, 66, 57], [203, 68, 57], [206, 70, 56], [208, 72, 55], [209, 74, 55], [211, 76, 54], [213, 78, 54], [215, 81, 54], [217, 83, 53], [218, 85, 53], [220, 88, 53], [221, 90, 53], [223, 93, 54], [224, 95, 54], [225, 98, 55], [227, 101, 55], [228, 103, 56], [229, 106, 57], [230, 109, 58], [231, 111, 60], [232, 114, 61], [233, 117, 62], [234, 120, 64], [235, 123, 66], [236, 125, 68], [237, 128, 70], [237, 131, 73], [238, 134, 75], [239, 137, 78], [240, 139, 80], [240, 142, 83], [241, 145, 86], [242, 148, 89], [242, 151, 93], [243, 153, 96], [243, 156, 99], [244, 159, 103], [245, 162, 106], [245, 165, 110], [246, 167, 113], [246, 170, 117], [247, 173, 120], [247, 176, 124], [248, 178, 127], [248, 181, 131], [249, 184, 134], [249, 186, 138], [250, 188, 141], [250, 190, 144], [251, 192, 147], [251, 194, 149], [251, 196, 152], [252, 198, 154], [252, 
200, 156], [252, 201, 158], [253, 203, 160]]\r\ncolormap_wave1 = [[255, 255, 255], [254, 254, 253], [254, 253, 252], [253, 252, 250], [253, 250, 248], [252, 249, 246], [252, 248, 244], [251, 246, 242], [251, 245, 240], [250, 243, 237], [250, 242, 235], [249, 240, 232], [248, 238, 230], [248, 237, 227], [247, 235, 224], [247, 233, 221], [246, 231, 218], [245, 229, 215], [245, 227, 212], [244, 225, 209], [243, 223, 206], [242, 221, 203], [242, 219, 200], [241, 217, 196], [240, 215, 193], [239, 213, 190], [239, 211, 186], [238, 208, 183], [237, 206, 179], [236, 204, 176], [235, 202, 172], [234, 199, 169], [233, 197, 165], [232, 195, 162], [231, 192, 158], [230, 190, 155], [230, 188, 151], [228, 185, 148], [227, 183, 144], [226, 181, 141], [225, 178, 137], [224, 176, 134], [223, 174, 130], [222, 171, 127], [221, 169, 124], [219, 167, 120], [218, 164, 117], [217, 162, 114], [216, 160, 111], [214, 157, 108], [213, 155, 105], [212, 153, 102], [210, 151, 99], [209, 149, 96], [207, 146, 93], [206, 144, 91], [204, 142, 88], [203, 140, 85], [201, 138, 83], [199, 136, 80], [198, 134, 78], [196, 132, 76], [194, 130, 74], [193, 128, 72], [191, 127, 70], [189, 125, 69], [187, 123, 67], [185, 121, 65], [183, 119, 63], [180, 117, 62], [178, 115, 60], [176, 113, 58], [173, 111, 56], [170, 109, 55], [168, 107, 53], [165, 105, 52], [162, 102, 50], [159, 100, 49], [156, 98, 47], [154, 96, 46], [151, 94, 44], [148, 92, 43], [144, 90, 41], [141, 87, 40], [138, 85, 39], [135, 83, 37], [132, 81, 36], [129, 79, 35], [125, 76, 34], [122, 74, 33], [119, 72, 31], [115, 70, 30], [112, 68, 29], [109, 66, 28], [105, 64, 27], [102, 62, 27], [99, 60, 26], [96, 58, 25], [93, 56, 25], [89, 54, 25], [86, 52, 25], [83, 51, 25], [80, 49, 25], [77, 47, 25], [74, 45, 25], [71, 44, 25], [68, 42, 25], [65, 41, 25], [62, 39, 25], [60, 38, 25], [57, 37, 25], [54, 35, 25], [52, 34, 25], [49, 33, 25], [47, 32, 25], [45, 31, 25], [43, 30, 25], [40, 29, 25], [39, 28, 25], [37, 28, 25], [35, 27, 25], [33, 27, 25], 
[32, 26, 25], [30, 26, 25], [29, 25, 25], [28, 25, 25], [27, 25, 25], [26, 25, 25], [26, 25, 26], [26, 26, 27], [26, 26, 28], [26, 26, 30], [26, 27, 31], [26, 27, 33], [26, 28, 34], [26, 29, 36], [26, 30, 38], [26, 31, 40], [26, 32, 42], [26, 33, 44], [26, 34, 47], [26, 35, 49], [26, 37, 51], [26, 38, 54], [26, 40, 56], [26, 41, 59], [26, 43, 62], [26, 44, 64], [26, 46, 67], [26, 48, 70], [26, 50, 73], [27, 51, 76], [28, 53, 79], [28, 55, 82], [29, 57, 85], [30, 59, 88], [31, 61, 91], [32, 64, 94], [33, 66, 97], [35, 68, 101], [36, 70, 104], [37, 72, 107], [38, 74, 110], [40, 77, 113], [41, 79, 117], [42, 81, 120], [44, 84, 123], [45, 86, 126], [47, 88, 130], [48, 91, 133], [50, 93, 136], [51, 95, 139], [53, 98, 142], [54, 100, 145], [56, 102, 148], [58, 104, 151], [59, 107, 154], [61, 109, 157], [63, 111, 160], [64, 114, 163], [66, 116, 165], [68, 118, 168], [70, 120, 171], [71, 122, 173], [73, 125, 176], [75, 127, 178], [77, 129, 181], [78, 131, 183], [80, 133, 185], [82, 135, 187], [84, 136, 189], [86, 138, 191], [87, 140, 193], [89, 142, 194], [91, 144, 196], [93, 146, 198], [96, 147, 199], [98, 149, 201], [100, 151, 203], [103, 153, 204], [105, 155, 206], [108, 157, 207], [110, 160, 209], [113, 162, 210], [116, 164, 212], [118, 166, 213], [121, 168, 214], [124, 170, 216], [127, 172, 217], [130, 174, 218], [133, 176, 219], [136, 178, 221], [139, 180, 222], [142, 183, 223], [145, 185, 224], [148, 187, 225], [152, 189, 226], [155, 191, 227], [158, 193, 228], [161, 195, 230], [164, 197, 231], [168, 200, 231], [171, 202, 232], [174, 204, 233], [177, 206, 234], [180, 208, 235], [183, 210, 236], [187, 212, 237], [190, 214, 238], [193, 216, 239], [196, 218, 239], [199, 220, 240], [202, 222, 241], [205, 223, 242], [208, 225, 242], [211, 227, 243], [214, 229, 244], [217, 231, 245], [219, 232, 245], [222, 234, 246], [225, 236, 247], [227, 237, 247], [230, 239, 248], [232, 240, 248], [234, 242, 249], [237, 243, 250], [239, 245, 250], [241, 246, 251], [243, 247, 251], 
[245, 249, 252], [247, 250, 252], [249, 251, 253], [251, 252, 253], [252, 253, 254], [254, 254, 254]]\r\ncolormap_wave2 = [[255, 255, 255], [253, 254, 254], [252, 254, 253], [250, 253, 252], [248, 253, 252], [246, 252, 251], [244, 252, 250], [242, 251, 249], [240, 251, 247], [237, 250, 246], [235, 250, 245], [232, 249, 244], [230, 248, 243], [227, 248, 242], [224, 247, 240], [221, 247, 239], [218, 246, 238], [215, 245, 236], [212, 245, 235], [209, 244, 234], [206, 243, 232], [203, 242, 231], [200, 242, 229], [196, 241, 228], [193, 240, 226], [190, 239, 225], [186, 239, 223], [183, 238, 222], [179, 237, 220], [176, 236, 218], [172, 235, 217], [169, 234, 215], [165, 233, 213], [162, 232, 212], [158, 231, 210], [155, 231, 208], [151, 230, 206], [148, 228, 205], [144, 227, 203], [141, 226, 201], [137, 225, 199], [134, 224, 198], [130, 223, 196], [127, 222, 194], [124, 221, 192], [120, 219, 190], [117, 218, 188], [114, 217, 187], [111, 216, 185], [108, 214, 183], [105, 213, 181], [102, 212, 179], [99, 210, 177], [96, 209, 176], [93, 207, 174], [91, 206, 172], [88, 204, 170], [85, 203, 168], [83, 201, 166], [80, 199, 164], [78, 198, 163], [76, 196, 161], [74, 194, 159], [72, 193, 157], [70, 191, 155], [69, 189, 153], [67, 187, 152], [65, 185, 149], [63, 183, 147], [62, 180, 145], [60, 178, 143], [59, 175, 141], [57, 173, 138], [56, 170, 136], [54, 167, 133], [53, 164, 131], [52, 161, 128], [50, 158, 125], [49, 155, 123], [48, 152, 120], [47, 149, 117], [46, 146, 115], [45, 143, 112], [43, 140, 109], [42, 136, 106], [41, 133, 103], [40, 130, 101], [39, 126, 98], [39, 123, 95], [38, 119, 92], [37, 116, 89], [36, 112, 86], [35, 109, 84], [35, 106, 81], [34, 102, 78], [33, 99, 75], [33, 95, 73], [32, 92, 70], [31, 89, 67], [31, 86, 65], [30, 82, 62], [30, 79, 60], [29, 76, 57], [29, 73, 55], [28, 70, 52], [28, 67, 50], [28, 64, 48], [27, 61, 46], [27, 58, 44], [27, 55, 42], [26, 53, 40], [26, 50, 38], [26, 48, 37], [26, 46, 35], [26, 43, 34], [26, 41, 32], [26, 39, 31], [26, 
37, 30], [26, 35, 29], [26, 34, 28], [26, 32, 27], [26, 31, 26], [26, 29, 26], [26, 28, 25], [26, 27, 25], [26, 27, 25], [26, 26, 25], [26, 25, 25], [26, 25, 26], [26, 25, 26], [26, 25, 27], [26, 25, 27], [26, 25, 28], [27, 25, 30], [27, 25, 31], [27, 26, 32], [28, 27, 34], [28, 27, 35], [28, 28, 37], [29, 29, 39], [29, 30, 41], [30, 31, 43], [30, 32, 45], [31, 34, 48], [31, 35, 50], [32, 36, 53], [32, 38, 55], [33, 40, 58], [33, 41, 61], [34, 43, 64], [35, 45, 66], [36, 47, 69], [36, 49, 73], [37, 51, 76], [38, 53, 79], [39, 55, 82], [39, 57, 85], [40, 59, 89], [41, 61, 92], [42, 64, 95], [43, 66, 99], [44, 68, 102], [45, 71, 105], [46, 73, 109], [47, 76, 112], [48, 78, 116], [49, 80, 119], [50, 83, 122], [52, 86, 126], [53, 88, 129], [54, 91, 133], [55, 93, 136], [57, 96, 139], [58, 98, 143], [59, 100, 146], [60, 103, 149], [62, 105, 152], [63, 108, 155], [65, 110, 158], [66, 113, 161], [68, 115, 164], [69, 117, 167], [71, 120, 170], [72, 122, 173], [74, 124, 175], [75, 126, 178], [77, 128, 180], [79, 131, 183], [80, 133, 185], [82, 134, 187], [84, 136, 189], [86, 138, 191], [87, 140, 193], [89, 142, 194], [91, 144, 196], [93, 146, 198], [96, 147, 199], [98, 149, 201], [100, 151, 203], [103, 153, 204], [105, 155, 206], [108, 157, 207], [110, 160, 209], [113, 162, 210], [116, 164, 212], [118, 166, 213], [121, 168, 214], [124, 170, 216], [127, 172, 217], [130, 174, 218], [133, 176, 219], [136, 178, 221], [139, 180, 222], [142, 183, 223], [145, 185, 224], [148, 187, 225], [152, 189, 226], [155, 191, 227], [158, 193, 228], [161, 195, 230], [164, 197, 231], [168, 200, 231], [171, 202, 232], [174, 204, 233], [177, 206, 234], [180, 208, 235], [183, 210, 236], [187, 212, 237], [190, 214, 238], [193, 216, 239], [196, 218, 239], [199, 220, 240], [202, 222, 241], [205, 223, 242], [208, 225, 242], [211, 227, 243], [214, 229, 244], [217, 231, 245], [219, 232, 245], [222, 234, 246], [225, 236, 247], [227, 237, 247], [230, 239, 248], [232, 240, 248], [234, 242, 249], [237, 243, 
250], [239, 245, 250], [241, 246, 251], [243, 247, 251], [245, 249, 252], [247, 250, 252], [249, 251, 253], [251, 252, 253], [252, 253, 254], [254, 254, 254]]\r\ncolormap_wave3 = [[253, 203, 160], [252, 201, 158], [252, 200, 156], [252, 198, 154], [251, 196, 152], [251, 194, 149], [251, 192, 147], [250, 190, 145], [250, 189, 142], [249, 187, 139], [249, 185, 135], [248, 182, 132], [248, 179, 129], [247, 177, 125], [247, 175, 122], [247, 172, 119], [246, 168, 115], [246, 166, 112], [245, 164, 108], [245, 161, 105], [244, 158, 102], [243, 155, 98], [243, 152, 95], [242, 150, 92], [242, 147, 88], [241, 145, 86], [240, 142, 83], [240, 139, 80], [239, 137, 78], [238, 134, 75], [237, 131, 73], [237, 129, 70], [236, 126, 68], [235, 124, 67], [234, 121, 65], [233, 118, 63], [232, 115, 61], [231, 112, 61], [230, 110, 59], [229, 107, 57], [228, 104, 56], [228, 102, 55], [226, 100, 55], [224, 97, 55], [224, 94, 54], [222, 92, 54], [221, 89, 53], [219, 87, 53], [218, 84, 53], [217, 83, 53], [215, 80, 54], [213, 78, 54], [211, 76, 54], [209, 74, 55], [208, 72, 55], [206, 70, 56], [203, 68, 57], [201, 66, 57], [199, 65, 58], [198, 64, 59], [196, 62, 59], [193, 60, 60], [191, 59, 61], [188, 57, 61], [186, 56, 62], [183, 55, 62], [181, 55, 63], [179, 54, 63], [176, 53, 63], [173, 52, 64], [171, 51, 64], [168, 50, 65], [165, 49, 65], [162, 49, 65], [159, 48, 66], [157, 48, 66], [155, 47, 66], [152, 47, 66], [149, 46, 66], [146, 46, 65], [143, 46, 65], [140, 45, 65], [138, 45, 65], [135, 45, 64], [132, 44, 64], [130, 44, 63], [127, 44, 63], [124, 44, 62], [121, 43, 61], [118, 43, 60], [115, 43, 60], [112, 42, 60], [110, 42, 59], [108, 42, 58], [105, 42, 57], [102, 41, 56], [99, 41, 55], [97, 40, 54], [94, 40, 53], [91, 40, 52], [89, 39, 51], [86, 39, 50], [84, 38, 49], [82, 38, 48], [79, 37, 47], [77, 37, 46], [74, 36, 45], [72, 36, 44], [70, 35, 43], [68, 35, 42], [65, 34, 41], [64, 34, 40], [62, 33, 39], [60, 33, 38], [58, 33, 38], [56, 32, 38], [54, 32, 36], [52, 32, 35], [50, 
32, 35], [48, 31, 35], [47, 31, 34], [45, 31, 34], [43, 31, 33], [42, 30, 33], [41, 30, 32], [40, 30, 32], [39, 30, 32], [38, 30, 32], [37, 30, 32], [36, 30, 32], [36, 30, 32], [37, 30, 32], [38, 30, 32], [39, 30, 32], [40, 30, 32], [41, 30, 32], [42, 30, 33], [44, 31, 33], [46, 31, 34], [47, 31, 34], [49, 31, 35], [51, 32, 35], [53, 32, 36], [55, 32, 37], [57, 33, 38], [59, 33, 38], [61, 33, 39], [63, 34, 40], [65, 34, 41], [67, 35, 42], [70, 35, 43], [72, 36, 44], [74, 36, 45], [77, 37, 46], [79, 37, 47], [82, 38, 48], [84, 38, 49], [87, 39, 50], [90, 39, 51], [92, 40, 52], [95, 40, 53], [98, 40, 54], [100, 41, 55], [103, 41, 56], [106, 42, 57], [109, 42, 58], [111, 42, 59], [114, 43, 60], [117, 43, 60], [120, 43, 61], [123, 44, 62], [126, 44, 63], [129, 44, 63], [131, 44, 64], [134, 45, 64], [137, 45, 65], [140, 45, 65], [143, 46, 65], [146, 46, 65], [149, 46, 66], [152, 47, 66], [155, 47, 66], [158, 48, 66], [160, 48, 66], [163, 49, 65], [166, 49, 65], [169, 50, 65], [172, 51, 64], [174, 52, 64], [177, 53, 63], [180, 54, 63], [182, 55, 62], [185, 56, 62], [187, 57, 61], [190, 58, 61], [192, 60, 60], [195, 61, 59], [197, 63, 59], [199, 65, 58], [201, 66, 57], [203, 68, 57], [206, 70, 56], [208, 72, 55], [209, 74, 55], [211, 76, 54], [213, 78, 54], [215, 81, 54], [217, 83, 53], [218, 85, 53], [220, 88, 53], [221, 90, 53], [223, 93, 54], [224, 95, 54], [225, 98, 55], [227, 101, 55], [228, 103, 56], [229, 106, 57], [230, 109, 58], [231, 111, 60], [232, 114, 61], [233, 117, 62], [234, 120, 64], [235, 123, 66], [236, 125, 68], [237, 128, 70], [237, 131, 73], [238, 134, 75], [239, 137, 78], [240, 139, 80], [240, 142, 83], [241, 145, 86], [242, 148, 89], [242, 151, 93], [243, 153, 96], [243, 156, 99], [244, 159, 103], [245, 162, 106], [245, 165, 110], [246, 167, 113], [246, 170, 117], [247, 173, 120], [247, 176, 124], [248, 178, 127], [248, 181, 131], [249, 184, 134], [249, 186, 138], [250, 188, 141], [250, 190, 144], [251, 192, 147], [251, 194, 149], [251, 196, 152], 
[252, 198, 154], [252, 200, 156], [252, 201, 158], [253, 203, 160]]\r\ncolormap_wave4 = [[246, 230, 183], [246, 229, 182], [246, 227, 180], [246, 226, 178], [246, 224, 176], [245, 222, 173], [245, 219, 170], [244, 217, 167], [244, 214, 163], [244, 211, 160], [243, 209, 156], [243, 206, 152], [242, 203, 148], [242, 200, 144], [241, 196, 140], [241, 193, 136], [241, 190, 132], [240, 186, 128], [240, 183, 124], [239, 180, 120], [239, 176, 116], [238, 173, 112], [238, 170, 108], [237, 166, 104], [237, 163, 100], [236, 160, 97], [236, 156, 93], [236, 153, 90], [235, 150, 87], [235, 147, 84], [235, 144, 81], [234, 140, 78], [234, 137, 76], [234, 134, 74], [234, 131, 71], [233, 127, 69], [233, 124, 67], [233, 121, 65], [233, 118, 64], [232, 115, 62], [232, 112, 61], [232, 109, 60], [232, 106, 59], [232, 103, 58], [232, 101, 58], [232, 98, 57], [232, 95, 57], [231, 93, 57], [230, 90, 57], [230, 88, 57], [229, 85, 57], [227, 83, 57], [226, 81, 57], [224, 78, 57], [222, 76, 58], [220, 74, 58], [217, 72, 59], [215, 70, 59], [212, 67, 60], [210, 65, 60], [207, 63, 60], [204, 62, 61], [201, 60, 61], [199, 58, 61], [196, 56, 61], [193, 55, 61], [189, 53, 61], [186, 52, 62], [183, 50, 61], [180, 49, 61], [176, 48, 61], [173, 46, 61], [170, 45, 61], [166, 44, 61], [163, 43, 61], [159, 42, 60], [156, 40, 60], [152, 39, 60], [149, 39, 59], [146, 38, 58], [142, 37, 58], [139, 36, 57], [135, 35, 56], [132, 34, 55], [128, 34, 54], [125, 33, 53], [122, 32, 52], [118, 31, 51], [115, 31, 50], [111, 30, 49], [108, 29, 48], [105, 28, 46], [101, 28, 45], [98, 27, 44], [94, 26, 43], [91, 25, 41], [88, 24, 40], [85, 24, 38], [82, 23, 37], [78, 22, 35], [75, 21, 34], [72, 20, 32], [69, 19, 31], [66, 18, 29], [63, 18, 28], [60, 16, 26], [57, 16, 25], [55, 15, 24], [52, 14, 22], [49, 13, 21], [46, 12, 19], [44, 11, 18], [41, 11, 17], [39, 10, 16], [36, 9, 14], [34, 8, 13], [31, 8, 12], [29, 7, 11], [27, 7, 10], [25, 6, 10], [23, 5, 9], [21, 5, 8], [19, 4, 7], [17, 4, 7], [16, 4, 6], [14, 3, 6], 
[13, 3, 5], [13, 3, 5], [12, 3, 5], [12, 3, 5], [12, 3, 5], [12, 3, 5], [13, 3, 5], [14, 3, 5], [15, 3, 6], [16, 4, 6], [18, 4, 7], [20, 4, 7], [21, 5, 8], [23, 5, 9], [26, 6, 10], [28, 7, 11], [30, 7, 12], [33, 8, 13], [35, 9, 14], [38, 9, 15], [40, 10, 17], [43, 11, 18], [46, 12, 19], [48, 13, 21], [51, 14, 22], [54, 15, 24], [57, 16, 25], [60, 17, 27], [63, 18, 28], [67, 19, 30], [70, 20, 31], [73, 20, 33], [76, 21, 34], [80, 22, 36], [83, 23, 37], [86, 24, 39], [90, 25, 40], [93, 26, 42], [96, 27, 43], [100, 27, 44], [103, 28, 46], [107, 29, 47], [110, 30, 48], [114, 30, 50], [117, 31, 51], [121, 32, 52], [124, 33, 53], [128, 34, 54], [131, 34, 55], [135, 35, 56], [138, 36, 57], [142, 37, 57], [146, 38, 58], [149, 39, 59], [153, 39, 59], [156, 40, 60], [160, 42, 60], [164, 43, 61], [167, 44, 61], [171, 45, 61], [174, 46, 61], [178, 48, 62], [181, 49, 62], [184, 51, 62], [188, 52, 62], [191, 54, 62], [194, 56, 61], [197, 57, 61], [200, 59, 61], [203, 61, 61], [206, 63, 60], [209, 65, 60], [212, 67, 59], [215, 69, 59], [217, 72, 59], [220, 74, 58], [222, 76, 58], [224, 78, 57], [226, 81, 57], [227, 83, 57], [229, 86, 57], [230, 88, 57], [230, 91, 57], [231, 94, 57], [232, 97, 57], [232, 99, 57], [232, 102, 58], [232, 105, 59], [232, 108, 60], [232, 111, 61], [232, 114, 62], [232, 117, 63], [233, 120, 65], [233, 123, 66], [233, 127, 68], [233, 130, 71], [234, 133, 73], [234, 137, 75], [234, 140, 78], [235, 143, 81], [235, 147, 84], [235, 150, 87], [236, 153, 90], [236, 157, 94], [236, 160, 97], [237, 164, 101], [237, 167, 105], [238, 170, 109], [238, 174, 113], [239, 178, 117], [239, 181, 121], [240, 184, 126], [240, 188, 130], [241, 192, 134], [241, 195, 138], [242, 198, 142], [242, 202, 147], [243, 205, 151], [243, 208, 155], [244, 211, 159], [244, 214, 162], [244, 216, 166], [245, 219, 169], [245, 221, 173], [246, 224, 175], [246, 226, 178], [246, 227, 180], [246, 229, 182], [246, 230, 183]]\r\n\r\n\r\nclass WaveVisualizer:\r\n    def __init__(self, 
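field_colormap, intensity_colormap):\r\n        # colormaps are uint8 RGB lookup tables on the GPU (see get_colormap_lut) or None for grayscale output\r\n        self.field_colormap = field_colormap\r\n        self.intensity_colormap = intensity_colormap\r\n        self.intensity = None\r\n        # weight of the previous intensity in the exponential moving average (closer to 1.0 = smoother, slower response)\r\n        self.intensity_exp_average_factor = 0.98\r\n        self.field = None\r\n        self.visualization_image = None\r\n\r\n    def update(self, wave_sim):\r\n        self.field = wave_sim.get_field()\r\n\r\n        if self.intensity is None:\r\n            self.intensity = cp.zeros_like(self.field)\r\n\r\n        # intensity is an exponential moving average of the squared field\r\n        t = self.intensity_exp_average_factor\r\n        self.intensity = self.intensity*t + (self.field**2)*(1.0-t)\r\n        # overlay image drawn by the scene objects\r\n        self.visualization_image = wave_sim.render_visualization()\r\n\r\n    def render_intensity(self, brightness_scale=1.0, exp=0.5, overlay_visualization=True):\r\n        # apply intensity**exp (exp < 1 compresses the dynamic range), clip to [0, 1] and quantize to LUT indices 0..254\r\n        gray = (cp.clip((self.intensity**exp)*brightness_scale, 0.0, 1.0) * 254.0).astype(np.uint8)\r\n        # apply the colormap on the GPU, copy to host memory and convert to the BGR layout expected by OpenCV\r\n        if self.intensity_colormap is not None:\r\n            img = cv2.cvtColor(self.intensity_colormap[gray].get(), cv2.COLOR_RGB2BGR)\r\n        else:\r\n            img = cv2.cvtColor(gray.get(), cv2.COLOR_GRAY2BGR)\r\n        if overlay_visualization and self.visualization_image is not None:\r\n            img = cv2.add(img, self.visualization_image)\r\n        return img\r\n\r\n    def render_field(self, brightness_scale=1.0, overlay_visualization=True):\r\n        # map the signed field from [-1, 1] to LUT indices 0..254 (a zero field maps to 127)\r\n        gray = (cp.clip(self.field*brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)\r\n        if self.field_colormap is not None:\r\n            img = cv2.cvtColor(self.field_colormap[gray].get(), cv2.COLOR_RGB2BGR)\r\n        else:\r\n            img = cv2.cvtColor(gray.get(), cv2.COLOR_GRAY2BGR)\r\n        if overlay_visualization and self.visualization_image is not None:\r\n            img = cv2.add(img, self.visualization_image)\r\n        return img\r\n\r\n\r\ndef get_colormap_lut(name, invert, black_level=0.0, make_symmetric=False):\r\n    # returns a uint8 RGB lookup table as a CuPy array; colors are kept as floats in [0, 1] until the final conversion\r\n    if name == 'icefire': color_values = np.array(colormap_icefire) / 255\r\n    elif name == 'colormap_wave1': color_values = np.array(colormap_wave1) / 255\r\n    elif name == 'colormap_wave2': color_values = np.array(colormap_wave2) / 255\r\n    elif name == 'colormap_wave3': color_values = np.array(colormap_wave3) / 255\r\n    elif name == 'colormap_wave4': color_values = np.array(colormap_wave4) / 255\r\n    else:\r\n        colormap = matplotlib.pyplot.get_cmap(name)\r\n        # matplotlib returns RGBA values; keep only RGB like the built-in colormaps\r\n        color_values = colormap(np.linspace(0, 1, 255))[:, :3]\r\n\r\n    if invert:\r\n        color_values = 1.0-color_values\r\n\r\n    if make_symmetric:\r\n        # mirror a compressed copy of the colormap around entry 127, so that field values\r\n        # of equal magnitude but opposite sign get the same color\r\n        src = color_values.copy()\r\n        color_values[255:126:-1, :] = src[0:255:2, :]\r\n        color_values[0:128, :] = src[0:255:2, :]\r\n\r\n    # raise the darkest color to black_level; the values are in [0, 1], so clip to that range\r\n    color_values = np.clip(color_values*(1.0-black_level)+black_level, 0.0, 1.0)\r\n\r\n    return cp.asarray((color_values*255).astype(np.uint8))"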
  }
]