[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nPlease provide us with the following to receive timely help:\n1. A minimum example to reproduce the bug. Keep your code as short as possible but still directly runnable.\n2. Model files, especially when the bug is only triggered on specific models.\n3. **Complete** outputs of the program when the bug is triggered. Please do **not** just include the last few lines. If it's very long, you can use [PasteBin](https://pastebin.com/) or upload to a file-sharing service.\n4. Detailed instructions to reproduce the problem. If you changed part of our tool, please rebase your changes to main branch and push your changes to a fork so we can investigate easier.\n\nWithout the above information, you might not be able to receive timely help from us.\n\n\n**System configuration:**\n - OS: [e.g. Ubuntu 22.04. Windows and MacOS are not supported.]\n - Python version: [e.g., Python 3.8]\n - Pytorch Version: [e.g., PyTorch 1.12]\n - Hardware: [e.g., RTX 4090]\n - Have you tried to reproduce the problem in a cleanly created conda/virtualenv environment using official installation instructions and the latest code on the main branch?: [Yes/No]\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".gitignore",
    "content": "tmp\nbuild\n__pycache__\n*.egg-info\ndist\n*.swp\n*.swo\n*.log\n.trace_graph\nVerified_ret*.npy\nVerified-acc*.npy\nvnn-comp_*.npz\n*.tar.gz\nverifier_log_*\n.vscode/\n*.pt\n.idea\n*.so\nrelease\n*.compiled\n.DS_Store\n*.out\n*.txt\nrelease\nrelease_abcrown\ncachier\nout.csv\nresults.csv\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Set the version of Python and other tools you might need\nbuild:\n  os: ubuntu-20.04\n  tools:\n    python: \"3.11\"\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n   configuration: doc/conf.py\n\n# Optionally declare the Python requirements required to build your docs\npython:\n   install:\n    - method: pip\n      path: .\n    - requirements: doc/requirements.txt"
  },
  {
    "path": "CONTRIBUTORS",
    "content": "Team leaders:\n* Faculty: Huan Zhang (huan@huan-zhang.com), UIUC\n* Student: Xiangru Zhong (xiangru4@illinois.edu), UIUC\n\nCurrent developers (* indicates members of VNN-COMP 2025 team):\n* \\*Duo Zhou (duozhou2@illinois.edu), UIUC\n* \\*Keyi Shen (keyis2@illinois.edu), UIUC (graduated, now at Georgia Tech)\n* \\*Hesun Chen (hesunc2@illinois.edu), UIUC\n* \\*Haoyu Li (haoyuli5@illinois.edu), UIUC\n* \\*Ruize Gao (ruizeg2@illinois.edu), UIUC\n* \\*Hao Cheng (haoc539@illinois.edu), UIUC\n* Zhouxing Shi (zhouxingshichn@gmail.com), UCLA/UC Riverside\n* Lei Huang (leih5@illinois.edu), UIUC\n* Taobo Liao (taobol2@illinois.edu), UIUC\n* Jorge Chavez (jorgejc2@illinois.edu), UIUC\n\nPast developers:\n* Hongji Xu (hx84@duke.edu), Duke University (intern with Prof. Huan Zhang)\n* Christopher Brix (brix@cs.rwth-aachen.de), RWTH Aachen University\n* Hao Chen (haoc8@illinois.edu), UIUC\n* Keyu Lu (keyulu2@illinois.edu), UIUC\n* Kaidi Xu (kx46@drexel.edu), Drexel University\n* Sanil Chawla (schawla7@illinois.edu), UIUC\n* Linyi Li (linyi2@illinois.edu), UIUC\n* Zhuolin Yang (zhuolin5@illinois.edu), UIUC\n* Zhuowen Yuan (realzhuowen@gmail.com), UIUC\n* Qirui Jin (qiruijin@umich.edu), University of Michigan\n* Shiqi Wang (sw3215@columbia.edu), Columbia University\n* Yihan Wang (yihanwang@ucla.edu), UCLA\n* Jinqi (Kathryn) Chen (jinqic@cs.cmu.edu), CMU\n\nauto_LiRPA is currently supported in part by the National Science Foundation (NSF; award 2331967, 2525287), the AI2050 program at Schmidt Science, the Virtual Institute for Scientific Software (VISS) at Georgia Tech, the University Research Program at Toyota Research Institute (TRI), and a Mathworks research award.\n\nThe team acknowledges the financial and advisory support from Prof. Zico Kolter (zkolter@cs.cmu.edu), Prof. Cho-Jui Hsieh (chohsieh@cs.ucla.edu), Prof. Suman Jana (suman@cs.columbia.edu), Prof. Bo Li (lbo@illinois.edu), and Prof. Xue Lin (xue.lin@northeastern.edu) during 2021 - 2023.\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (C) 2021-2025 The α,β-CROWN Team\nSee CONTRIBUTORS for the list of all contributors and their affiliations.\n    Team leaders: \n         Faculty: Huan Zhang <huan@huan-zhang.com> (UIUC)\n         Student: Xiangru Zhong <xiangru4@illinois.edu> (UIUC)\n    Current developers:\n                  Duo Zhou <duozhou2@illinois.edu> (UIUC)\n                  Keyi Shen <keyis2@illinois.edu> (UIUC/Georgia Tech)\n                  Hesun Chen <hesunc2@illinois.edu> (UIUC)\n                  Haoyu Li <haoyuli5@illinois.edu> (UIUC)\n                  Ruize Gao <ruizeg2@illinois.edu> (UIUC)\n                  Hao Cheng <haoc539@illinois.edu> (UIUC)\n                  Zhouxing Shi <zhouxingshichn@gmail.com> (UCLA/UC Riverside)\n                  Lei Huang <leih5@illinois.edu> (UIUC)\n                  Taobo Liao <taobol2@illinois.edu> (UIUC)\n                  Jorge Chavez <jorgejc2@illinois.edu> (UIUC)\n\nRedistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "# auto_LiRPA: Automatic Linear Relaxation based Perturbation Analysis for Neural Networks\n\n[![Documentation Status](https://readthedocs.org/projects/auto-lirpa/badge/?version=latest)](https://auto-lirpa.readthedocs.io/en/latest/?badge=latest)\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://PaperCode.cc/AutoLiRPA-Demo)\n[![Video Introduction](https://img.shields.io/badge/play-video-red.svg)](http://PaperCode.cc/AutoLiRPA-Video)\n[![BSD license](https://img.shields.io/badge/License-BSD-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)\n\n<p align=\"center\">\n<a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_lirpa_2.png\" width=\"45%\" height=\"45%\" float=\"left\"></a>\n<a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_lirpa_1.png\" width=\"45%\" height=\"45%\" float=\"right\"></a>\n</p>\n\n## What's New?\n- [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) (using `auto_LiRPA` as its core library) is the winner of [VNN-COMP 2025](https://sites.google.com/view/vnn2025) and is **ranked top-1** in all [scored benchmarks](https://github.com/VNN-COMP/vnncomp2025_results/blob/main/SCORING-SMALL-TOL/latex/main.pdf). (08/2025)\n- Bounding of computation graphs containing Jacobian operators now supports more nonlinear operators (e.g., ```tanh```, ```sigmoid```), enabling verification of [continuous-time Lyapunov stability](https://github.com/Verified-Intelligence/Two-Stage_Neural_Controller_Training). (12/2025)\n- [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) (using `auto_LiRPA` as its core library) is the winner of [VNN-COMP 2024](https://sites.google.com/view/vnn2024). 
Our tool is **ranked top-1** in all benchmarks (including 12 [regular track](https://github.com/ChristopherBrix/vnncomp2024_results/blob/main/SCORING/latex/results_regular_track.pdf) and 9 [extended track](https://github.com/ChristopherBrix/vnncomp2024_results/blob/main/SCORING/latex/results_extended_track.pdf) benchmarks). (08/2024)\n- The [INVPROP algorithm](https://arxiv.org/pdf/2302.01404.pdf) allows computing over-approximations of preimages (the set of inputs of an NN generating a given output set) and tightening bounds using output constraints. (03/2024)\n- Branch-and-bound support for non-ReLU and general nonlinearities ([GenBaB](https://arxiv.org/pdf/2405.21063)) with optimizable bounds (α-CROWN) for new nonlinear functions (sin, cos, GeLU). We achieve significant improvements on verifying neural networks with non-ReLU nonlinearities such as Transformers, LSTM, and [ML4ACOPF](https://github.com/AI4OPT/ml4acopf_benchmark). (09/2023)\n- [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) (using `auto_LiRPA` as its core library) **won** [VNN-COMP 2023](https://sites.google.com/view/vnn2023). (08/2023)\n- Bound computation for higher-order computational graphs to support bounding Jacobian, Jacobian-vector products, and [local Lipschitz constants](https://arxiv.org/abs/2210.07394). (11/2022)\n- Our neural network verification tool [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) (using `auto_LiRPA` as its core library) **won** [VNN-COMP 2022](https://sites.google.com/view/vnn2022). Our library supports the large CIFAR100, TinyImageNet and ImageNet models in VNN-COMP 2022. 
(09/2022)\n- Implementation of **general cutting planes** ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf)), support of more activation functions and improved performance and scalability. (09/2022)\n- Our neural network verification tool [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. (09/2021)\n- [Optimized CROWN/LiRPA](https://arxiv.org/pdf/2011.13824.pdf) bound (α-CROWN) for ReLU, **sigmoid**, **tanh**, and **maxpool** activation functions, which can significantly outperform regular CROWN bounds. See [simple_verification.py](examples/vision/simple_verification.py#L59) for an example. (07/31/2021)\n- Handle split constraints for ReLU neurons ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for complete verifiers. (07/31/2021)\n- A memory-efficient GPU implementation of backward (CROWN) bounds for\nconvolutional layers. (10/31/2020)\n- Certified defense models for downscaled ImageNet, TinyImageNet, CIFAR-10, LSTM/Transformer. (08/20/2020)\n- Added support for **complex vision models** including DenseNet, ResNeXt and WideResNet. (06/30/2020)\n- **Loss fusion**, a technique that reduces the training cost of tight LiRPA bounds\n(e.g., CROWN-IBP) to the same asymptotic complexity as IBP, making LiRPA-based certified\ndefense scalable to large datasets (e.g., TinyImageNet, downscaled ImageNet). (06/30/2020)\n- **Multi-GPU** support to scale LiRPA-based training to large models and datasets. (06/30/2020)\n- Initial release. 
(02/28/2020)\n\n## Introduction\n\n`auto_LiRPA` is a library for automatically deriving and computing bounds with\nlinear relaxation based perturbation analysis (LiRPA) (e.g.\n[CROWN](https://arxiv.org/pdf/1811.00866.pdf) and\n[DeepPoly](https://files.sri.inf.ethz.ch/website/papers/DeepPoly.pdf)) for\nneural networks, which is a useful tool for formal robustness verification. We\ngeneralize existing LiRPA algorithms for feed-forward neural networks to a\ngraph algorithm on general computational graphs, defined by PyTorch.\nAdditionally, our implementation is also automatically **differentiable**,\nallowing optimizing network parameters to shape the bounds into certain\nspecifications (e.g., certified defense). You can find [a video ▶️ introduction\nhere](http://PaperCode.cc/AutoLiRPA-Video).\n\nOur library supports the following algorithms:\n\n* Backward mode LiRPA bound propagation ([CROWN](https://arxiv.org/pdf/1811.00866.pdf)/[DeepPoly](https://files.sri.inf.ethz.ch/website/papers/DeepPoly.pdf))\n* Backward mode LiRPA bound propagation with optimized bounds ([α-CROWN](https://arxiv.org/pdf/2011.13824.pdf))\n* Backward mode LiRPA bound propagation with split constraints ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf) for ReLU, and [GenBaB](https://arxiv.org/pdf/2405.21063) for general nonlinear functions)\n* Generalized backward mode LiRPA bound propagation with general cutting plane constraints ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf))\n* Backward mode LiRPA bound propagation with bounds tightened using output constraints ([INVPROP](https://arxiv.org/pdf/2302.01404.pdf))\n* Generalized backward mode LiRPA bound propagation for higher-order computational graphs  ([Shi et al., 2022](https://arxiv.org/abs/2210.07394))\n* Forward mode LiRPA bound propagation ([Xu et al., 2020](https://arxiv.org/pdf/2002.12920))\n* Forward mode LiRPA bound propagation with optimized bounds (similar to [α-CROWN](https://arxiv.org/pdf/2011.13824.pdf))\n* Interval bound 
propagation ([IBP](https://arxiv.org/pdf/1810.12715.pdf))\n* Hybrid approaches, e.g., Forward+Backward, IBP+Backward ([CROWN-IBP](https://arxiv.org/pdf/1906.06316.pdf)), [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git))\n* MIP/LP formulation of neural networks\n\nOur library allows automatic bound derivation and computation for general\ncomputational graphs, in a manner similar to how gradients are obtained in modern\ndeep learning frameworks -- users only define the computation in a forward\npass, and `auto_LiRPA` traverses through the computational graph and derives\nbounds for any node on the graph.  With `auto_LiRPA` we free users from\nderiving and implementing LiRPA for most common tasks, and they can simply\napply LiRPA as a tool for their own applications.  This is especially useful\nfor users who are not experts in LiRPA and cannot derive these bounds manually\n(LiRPA is significantly more complicated than backpropagation).\n\n## Technical Background in 1 Minute\n\nDeep learning frameworks such as PyTorch represent neural networks (NN) as\na computational graph, where each mathematical operation is a node and edges\ndefine the flow of computation:\n\n<p align=\"center\">\n<a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_LiRPA_background_1.png\" width=\"80%\"></a>\n</p>\n\nNormally, the inputs of a computation graph (which defines an NN) are data and\nmodel weights, and PyTorch goes through the graph and produces the model prediction\n(a bunch of numbers):\n\n<p align=\"center\">\n<a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_LiRPA_background_2.png\" width=\"80%\"></a>\n</p>\n\nOur `auto_LiRPA` library conducts perturbation analysis on a computational\ngraph, where the input data and model weights are defined within 
some\nuser-defined ranges.  We get guaranteed output ranges (bounds):\n\n<p align=\"center\">\n<a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_LiRPA_background_3.png\" width=\"80%\"></a>\n</p>\n\n## Installation\n\nPython 3.11+ and PyTorch 2.0+ are required.\nIt is highly recommended to have a pre-installed PyTorch\nthat matches your system and our version requirement\n(see [PyTorch Get Started](https://pytorch.org/get-started)).\nThen you can install `auto_LiRPA` via:\n\n```bash\ngit clone https://github.com/Verified-Intelligence/auto_LiRPA\ncd auto_LiRPA\npip install .\n```\n\nIf you intend to modify this library, use `pip install -e .` instead.\n\n## Quick Start\n\nFirst define your computation as an `nn.Module` and wrap it using\n`auto_LiRPA.BoundedModule()`. Then, you can call the `compute_bounds` function\nto obtain certified lower and upper bounds under input perturbations:\n\n```python\nimport numpy as np\nfrom torch import nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm\n\n# Define computation as an nn.Module.\nclass MyModel(nn.Module):\n    def forward(self, x):\n        # Define your computation here.\n        ...\n\nmodel = MyModel()\nmy_input = load_a_batch_of_data()\n# Wrap the model with auto_LiRPA.\nmodel = BoundedModule(model, my_input)\n# Define perturbation. 
Here we add an Linf perturbation to the input data.\nptb = PerturbationLpNorm(norm=np.inf, eps=0.1)\n# Make the input a BoundedTensor with the pre-defined perturbation.\nmy_input = BoundedTensor(my_input, ptb)\n# Regular forward propagation using BoundedTensor works as usual.\nprediction = model(my_input)\n# Compute LiRPA bounds using the backward mode bound propagation (CROWN).\nlb, ub = model.compute_bounds(x=(my_input,), method=\"backward\")\n```\n\nCheck out\n[examples/vision/simple_verification.py](examples/vision/simple_verification.py)\nfor a complete but very basic example.\n\n<a href=\"http://PaperCode.cc/AutoLiRPA-Demo\"><img align=\"left\" width=64 height=64 src=\"https://colab.research.google.com/img/colab_favicon_256px.png\"></a>\nWe also provide a [Google Colab Demo](http://PaperCode.cc/AutoLiRPA-Demo) including an example of computing verification\nbounds for an 18-layer ResNet model on the CIFAR-10 dataset. Once the ResNet model\nis defined as usual in PyTorch, obtaining provable output bounds is as easy as\nobtaining gradients through autodiff. 
Bounds are efficiently computed on GPUs.\n\n## More Working Examples\n\nWe provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA`:\n\n* [Basic Bound Computation on a Toy Neural Network (simplest example)](examples/simple/toy.py)\n* [Basic Bound Computation with **Robustness Verification** of Neural Networks as an example](doc/src/examples.md#basic-bound-computation-and-robustness-verification-of-neural-networks)\n* [MIP/LP Formulation of Neural Networks](examples/simple/mip_lp_solver.py)\n* [Basic **Certified Adversarial Defense** Training](doc/src/examples.md#basic-certified-adversarial-defense-training)\n* [Large-scale Certified Defense Training on **ImageNet**](doc/src/examples.md#certified-adversarial-defense-on-downscaled-imagenet-and-tinyimagenet-with-loss-fusion)\n* [Certified Adversarial Defense Training on Sequence Data with **LSTM**](doc/src/examples.md#certified-adversarial-defense-training-for-lstm-on-mnist)\n* [Certifiably Robust Language Classifier using **Transformers**](doc/src/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm)\n* [Certified Robustness against **Model Weight Perturbations**](doc/src/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense)\n* [Bounding **Jacobian** and **local Lipschitz constants**](examples/vision/jacobian.py)\n* [Computing an Overapproximation of a Neural Network **Preimage**](examples/simple/invprop.py)\n\n`auto_LiRPA` has also been used in the following works:\n* [**α,β-CROWN for complete neural network verification**](https://github.com/Verified-Intelligence/alpha-beta-CROWN)\n* [**Fast certified robust training**](https://github.com/shizhouxing/Fast-Certified-Robust-Training)\n* [**Computing local Lipschitz constants**](https://github.com/shizhouxing/Local-Lipschitz-Constants)\n\n## Full Documentation\n\nFor more documentation, please refer to:\n\n* [Documentation homepage](https://auto-lirpa.readthedocs.io)\n* [API 
documentation](https://auto-lirpa.readthedocs.io/en/latest/api.html)\n* [Adding custom operators](https://auto-lirpa.readthedocs.io/en/latest/custom_op.html)\n* [Guide](https://auto-lirpa.readthedocs.io/en/latest/paper.html) for reproducing [our NeurIPS 2020 paper](https://arxiv.org/abs/2002.12920)\n\n## Publications\n\nPlease kindly cite our papers if you use the `auto_LiRPA` library. Full [BibTeX entries](doc/src/examples.md#bibtex-entries) can be found [here](doc/src/examples.md#bibtex-entries).\n\nThe general LiRPA based bound propagation algorithm was originally proposed in our paper:\n\n* [Automatic Perturbation Analysis for Scalable Certified Robustness and Beyond](https://arxiv.org/pdf/2002.12920).\nNeurIPS 2020.\nKaidi Xu\\*, Zhouxing Shi\\*, Huan Zhang\\*, Yihan Wang, Kai-Wei Chang, Minlie Huang, Bhavya Kailkhura, Xue Lin, Cho-Jui Hsieh (\\* Equal contribution)\n\nThe `auto_LiRPA` library is further extended to support:\n\n* Optimized bounds (α-CROWN):\n\n  [Fast and Complete: Enabling Complete Neural Network Verification with Rapid and Massively Parallel Incomplete Verifiers](https://arxiv.org/pdf/2011.13824.pdf). ICLR 2021. Kaidi Xu\\*, Huan Zhang\\*, Shiqi Wang, Yihan Wang, Suman Jana, Xue Lin and Cho-Jui Hsieh (\\* Equal contribution).\n\n* Split constraints (β-CROWN):\n\n  [Beta-CROWN: Efficient Bound Propagation with Per-neuron Split Constraints for Complete and Incomplete Neural Network Verification](https://arxiv.org/pdf/2103.06624.pdf). NeurIPS 2021. Shiqi Wang\\*, Huan Zhang\\*, Kaidi Xu\\*, Suman Jana, Xue Lin, Cho-Jui Hsieh and Zico Kolter (\\* Equal contribution).\n\n* General constraints (GCP-CROWN):\n\n  [GCP-CROWN: General Cutting Planes for Bound-Propagation-Based Neural Network Verification](https://arxiv.org/abs/2208.05740). 
Huan Zhang\\*, Shiqi Wang\\*, Kaidi Xu\\*, Linyi Li, Bo Li, Suman Jana, Cho-Jui Hsieh and Zico Kolter (\\* Equal contribution).\n\n* Higher-order computational graphs (Lipschitz constants and Jacobian):\n\n  [Efficiently Computing Local Lipschitz Constants of Neural Networks via Bound Propagation](https://arxiv.org/abs/2210.07394). NeurIPS 2022. Zhouxing Shi, Yihan Wang, Huan Zhang, Zico Kolter, Cho-Jui Hsieh.\n\n* Branch-and-bound for non-ReLU and general nonlinear functions (GenBaB):\n\n  [Neural Network Verification with Branch-and-Bound for General Nonlinearities](https://arxiv.org/pdf/2405.21063). TACAS 2025. Zhouxing Shi\\*, Qirui Jin\\*, Zico Kolter, Suman Jana, Cho-Jui Hsieh, Huan Zhang (\\* Equal contribution).\n\n* Tightening of bounds and preimage computation using the INVPROP algorithm:\n\n  [Provably Bounding Neural Network Preimages](https://arxiv.org/pdf/2302.01404.pdf). NeurIPS 2023. Suhas Kotha\\*, Christopher Brix\\*, Zico Kolter, Krishnamurthy (Dj) Dvijotham\\*\\*, Huan Zhang\\*\\* (\\* Equal contribution; \\*\\* Equal advising).\n\nCertified training (verification-aware training by optimizing bounds) using `auto_LiRPA` is improved with:\n\n* Much shorter warmup schedule and faster training:\n\n  [Fast Certified Robust Training with Short Warmup](https://arxiv.org/pdf/2103.17268.pdf). NeurIPS 2021. Zhouxing Shi\\*, Yihan Wang\\*, Huan Zhang, Jinfeng Yi and Cho-Jui Hsieh (\\* Equal contribution).\n\n* Training-time branch-and-bound:\n\n  [Certified Training with Branch-and-Bound: A Case Study on Lyapunov-stable Neural Control](https://arxiv.org/abs/2411.18235). 
Zhouxing Shi, Cho-Jui Hsieh, and Huan Zhang.\n\n\n## Developers and Copyright\n\nTeam leaders:\n* Faculty: Huan Zhang (huan@huan-zhang.com), UIUC\n* Student: Xiangru Zhong (xiangru4@illinois.edu), UIUC\n\nCurrent developers (* indicates members of VNN-COMP 2025 team):\n* \\*Duo Zhou (duozhou2@illinois.edu), UIUC\n* \\*Keyi Shen (keyis2@illinois.edu), UIUC (graduated, now at Georgia Tech)\n* \\*Hesun Chen (hesunc2@illinois.edu), UIUC\n* \\*Haoyu Li (haoyuli5@illinois.edu), UIUC\n* \\*Ruize Gao (ruizeg2@illinois.edu), UIUC\n* \\*Hao Cheng (haoc539@illinois.edu), UIUC\n* Zhouxing Shi (zhouxingshichn@gmail.com), UCLA/UC Riverside\n* Lei Huang (leih5@illinois.edu), UIUC\n* Taobo Liao (taobol2@illinois.edu), UIUC\n* Jorge Chavez (jorgejc2@illinois.edu), UIUC\n\nPast developers:\n* Hongji Xu (hx84@duke.edu), Duke University (intern with Prof. Huan Zhang)\n* Christopher Brix (brix@cs.rwth-aachen.de), RWTH Aachen University\n* Hao Chen (haoc8@illinois.edu), UIUC\n* Keyu Lu (keyulu2@illinois.edu), UIUC\n* Kaidi Xu (kx46@drexel.edu), Drexel University\n* Sanil Chawla (schawla7@illinois.edu), UIUC\n* Linyi Li (linyi2@illinois.edu), UIUC\n* Zhuolin Yang (zhuolin5@illinois.edu), UIUC\n* Zhuowen Yuan (realzhuowen@gmail.com), UIUC\n* Qirui Jin (qiruijin@umich.edu), University of Michigan\n* Shiqi Wang (sw3215@columbia.edu), Columbia University\n* Yihan Wang (yihanwang@ucla.edu), UCLA\n* Jinqi (Kathryn) Chen (jinqic@cs.cmu.edu), CMU\n\n`auto_LiRPA` is currently supported in part by the National Science Foundation (NSF; award 2331967, 2525287), the AI2050 program at Schmidt Science, the Virtual Institute for Scientific Software (VISS) at Georgia Tech, the University Research Program at Toyota Research Institute (TRI), and a Mathworks research award.\n\nWe thank the [commits](https://github.com/Verified-Intelligence/auto_LiRPA/commits) and [pull requests](https://github.com/Verified-Intelligence/auto_LiRPA/pulls) from community contributors.\n\nOur library is released under the BSD 
3-Clause license.\n"
  },
  {
    "path": "auto_LiRPA/__init__.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nfrom .bound_general import BoundedModule\nfrom .bound_multi_gpu import BoundDataParallel\nfrom .bounded_tensor import BoundedTensor, BoundedParameter\nfrom .perturbations import PerturbationLpNorm, PerturbationSynonym, PerturbationLinear\nfrom .wrapper import CrossEntropyWrapper, CrossEntropyWrapperMultiInput\nfrom .bound_op_map import register_custom_op, unregister_custom_op\n\n__version__ = '0.7.0'\n"
  },
  {
    "path": "auto_LiRPA/backward_bound.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport os\nimport torch\nfrom torch import Tensor\nfrom collections import deque\nfrom tqdm import tqdm\nfrom .patches import Patches\nfrom .utils import *\nfrom .bound_ops import *\nimport warnings\n\nfrom typing import TYPE_CHECKING, List, Optional\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef batched_backward(self: 'BoundedModule', node, C, unstable_idx, batch_size,\n                     bound_lower=True, bound_upper=True, return_A=None):\n    if return_A is None: return_A = self.return_A\n    output_shape = node.output_shape[1:]\n    dim = int(prod(output_shape))\n    if unstable_idx is None:\n        unstable_idx = torch.arange(dim, device=self.device)\n        dense = True\n    else:\n        dense = False\n    unstable_size = get_unstable_size(unstable_idx)\n    print(f'Batched CROWN: node {node}, unstable size {unstable_size}')\n    crown_batch_size = self.bound_opts['crown_batch_size']\n    auto_batch_size = AutoBatchSize(self.bound_opts['crown_batch_size'], self.device, vram_ratio=self.bound_opts['batched_crown_max_vram_ratio'])\n\n    ret = []\n    ret_A = {}  # if return_A, we will store A here\n    i = 0\n    torch.cuda.empty_cache()\n    with tqdm(total=unstable_size) as pbar:\n        while i < unstable_size:\n            crown_batch_size = auto_batch_size.batch_size\n            if isinstance(unstable_idx, tuple):\n                unstable_idx_batch = tuple(\n                    u[i : i + crown_batch_size]\n                    for u in unstable_idx\n                )\n                unstable_size_batch = len(unstable_idx_batch[0])\n            else:\n                unstable_idx_batch = unstable_idx[i : i + crown_batch_size]\n                unstable_size_batch = len(unstable_idx_batch)\n            auto_batch_size.record_actual_batch_size(unstable_size_batch)\n\n            if 
node.patches_start and node.mode == \"patches\":\n                assert C is None or C.type == 'Patches'\n                C_batch = Patches(shape=[\n                    unstable_size_batch, batch_size, *node.output_shape[1:-2], 1, 1],\n                    identity=1, unstable_idx=unstable_idx_batch,\n                    output_shape=[batch_size, *node.output_shape[1:]])\n            elif C.type == 'OneHot':\n                assert isinstance(node, (BoundLinear, BoundMatMul))\n                C_batch = OneHotC(\n                    [batch_size, unstable_size_batch, *node.output_shape[1:]],\n                    self.device, unstable_idx_batch, None)\n            else:\n                assert C is None or C.type == 'eye'\n                C_batch = torch.zeros([1, unstable_size_batch, dim], device=self.device)\n                C_batch[0, torch.arange(unstable_size_batch), unstable_idx_batch] = 1.0\n                C_batch = C_batch.expand(batch_size, -1, -1).view(\n                    batch_size, unstable_size_batch, *output_shape)\n            # overwrite return_A options to run backward general\n            ori_return_A_option = self.return_A\n            self.return_A = return_A\n\n            batch_ret = self.backward_general(\n                node, C_batch,\n                bound_lower=bound_lower, bound_upper=bound_upper,\n                average_A=False, need_A_only=False, unstable_idx=unstable_idx_batch)\n            ret.append(batch_ret[:2])\n\n            if len(batch_ret) > 2:\n                # A found, we merge A\n                batch_A = batch_ret[2]\n                ret_A = merge_A(node, batch_A, ret_A)\n\n            # restore return_A options\n            self.return_A = ori_return_A_option\n\n            pbar.update(unstable_size_batch)\n            i += unstable_size_batch\n            auto_batch_size.update()\n\n    if bound_lower:\n        lb = torch.cat([item[0].view(batch_size, -1) for item in ret], dim=1)\n        if dense:\n            # In 
this case, restore_sparse_bounds will not be called,\n            # so we restore the shape here.\n            lb = lb.reshape(batch_size, *output_shape)\n    else:\n        lb = None\n    if bound_upper:\n        ub = torch.cat([item[1].view(batch_size, -1) for item in ret], dim=1)\n        if dense:\n            # In this case, restore_sparse_bounds will not be called,\n            # so we restore the shape here.\n            ub = ub.reshape(batch_size, *output_shape)\n    else:\n        ub = None\n\n    if return_A:\n        return lb, ub, ret_A\n    else:\n        return lb, ub\n\n\ndef backward_general(\n    self: 'BoundedModule',\n    bound_node,\n    C,\n    start_backpropagation_at_node=None,\n    bound_lower=True,\n    bound_upper=True,\n    average_A=False,\n    need_A_only=False,\n    unstable_idx=None,\n    update_mask=None,\n    apply_output_constraints_to: Optional[List[str]] = None,\n    initial_As: Optional[dict] = None,\n    initial_lb: Optional[torch.Tensor] = None,\n    initial_ub: Optional[torch.Tensor] = None,\n):\n    use_beta_crown = self.bound_opts['optimize_bound_args']['enable_beta_crown']\n    tighten_input_bounds = (\n        self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n    )\n\n    if self.invprop_enabled():\n        self.invprop_init_infeasible_bounds(bound_node, C)\n    if bound_node.are_output_constraints_activated_for_layer(apply_output_constraints_to):\n        return self.backward_general_invprop(\n            initial_As=initial_As, initial_lb=initial_lb, initial_ub=initial_ub,\n            bound_node=bound_node, C=C,\n            start_backpropagation_at_node=start_backpropagation_at_node,\n            bound_lower=bound_lower, bound_upper=bound_upper,\n            average_A=average_A, need_A_only=need_A_only,\n            unstable_idx=unstable_idx, update_mask=update_mask\n        )\n\n    roots = self.roots()\n\n    if start_backpropagation_at_node is None:\n        # When output constraints are 
used, backward_general_with_output_constraint()\n        # adds additional layers at the end, performs the backpropagation through these,\n        # and then calls backward_general() on the output layer.\n        # In this case, the layer we start from (start_backpropagation_at_node) differs\n        # from the layer that should be bounded (bound_node)\n\n        # When output constraints are not used, the bounded node is the one where\n        # backpropagation starts.\n        start_backpropagation_at_node = bound_node\n\n    if self.verbose:\n        logger.debug(f'Bound backward from {start_backpropagation_at_node.__class__.__name__}({start_backpropagation_at_node.name}) '\n                     f'to bound {bound_node.__class__.__name__}({bound_node.name})')\n        if isinstance(C, BatchedCrownC):\n            logger.debug(f'  C: {C}')\n        elif C is not None:\n            logger.debug(f'  C: shape {C.shape}, type {type(C)}')\n    _print_time = bool(os.environ.get('AUTOLIRPA_PRINT_TIME', 0))\n\n    if isinstance(C, BatchedCrownC):\n        # If C is a str, use batched CROWN. 
If batched CROWN is not intended to\n        # be enabled, C must be an explicitly provided non-str object for this function.\n        if need_A_only or average_A:\n            raise ValueError(\n                'Batched CROWN is not compatible with '\n                f'need_A_only={need_A_only}, average_A={average_A}')\n        ret = self.batched_backward(\n            bound_node, C, unstable_idx,\n            batch_size=roots[0].value.shape[0],\n            bound_lower=bound_lower, bound_upper=bound_upper,\n        )\n        bound_node.lower, bound_node.upper = ret[:2]\n        return ret\n\n    for n in self.nodes():\n        n.lA = n.uA = None\n\n    degree_out = get_degrees(start_backpropagation_at_node)\n    C, batch_size, output_dim, output_shape = self._preprocess_C(C, bound_node)\n\n    if initial_As is None:\n        start_backpropagation_at_node.lA = C if bound_lower else None\n        start_backpropagation_at_node.uA = C if bound_upper else None\n    else:\n        for layer_name, (lA, uA) in initial_As.items():\n            self[layer_name].lA = lA\n            self[layer_name].uA = uA\n        assert start_backpropagation_at_node.lA is not None or start_backpropagation_at_node.uA is not None\n    if initial_lb is None:\n        lb = torch.tensor(0., device=self.device)\n    else:\n        lb = initial_lb\n    if initial_ub is None:\n        ub = torch.tensor(0., device=self.device)\n    else:\n        ub = initial_ub\n\n    # Save intermediate layer A matrices when required.\n    A_record = {}\n\n    queue = deque([start_backpropagation_at_node])\n    while len(queue) > 0:\n        l = queue.popleft()  # backward from l\n\n        if l.name in self.root_names:\n            continue\n\n        # A predecessor can only be processed after all of its successors are\n        # done, so decrement its out-degree and enqueue it once the count\n        # reaches zero.\n        for l_pre in l.inputs:\n            degree_out[l_pre.name] -= 1\n            if degree_out[l_pre.name] == 0:\n                queue.append(l_pre)\n\n        # 
Initially, l.lA or l.uA will be set to C for this node.\n        if l.lA is not None or l.uA is not None:\n            if self.verbose:\n                logger.debug(f'  Bound backward to {l} (out shape {l.output_shape})')\n                if l.lA is not None:\n                    logger.debug('    lA type %s shape %s',\n                                 type(l.lA), list(l.lA.shape))\n                if l.uA is not None:\n                    logger.debug('    uA type %s shape %s',\n                                 type(l.uA), list(l.uA.shape))\n\n            if _print_time:\n                start_time = time.time()\n\n            self.backward_from[l.name].append(bound_node)\n\n            if not l.perturbed:\n                if not hasattr(l, 'forward_value'):\n                    self.get_forward_value(l)\n                lb, ub = add_constant_node(lb, ub, l)\n                continue\n\n            if l.zero_uA_mtx and l.zero_lA_mtx:\n                # A matrices are all zero, no need to propagate.\n                continue\n\n            lA, uA = l.lA, l.uA\n            if (l.name != start_backpropagation_at_node.name and use_beta_crown\n                    and getattr(l, 'sparse_betas', None)):\n                lA, uA, lbias, ubias = self.beta_crown_backward_bound(\n                    l, lA, uA, start_node=start_backpropagation_at_node)\n                lb = lb + lbias\n                ub = ub + ubias\n\n            if isinstance(l, BoundOptimizableActivation):\n                # For other optimizable activation functions (TODO: unify with ReLU).\n                if bound_node.name != self.final_node_name:\n                    start_shape = bound_node.output_shape[1:]\n                else:\n                    start_shape = C.shape[0]\n                l.preserve_mask = update_mask\n            else:\n                start_shape = None\n            A, lower_b, upper_b = l.bound_backward(\n                lA, uA, *l.inputs,\n                
start_node=bound_node, unstable_idx=unstable_idx,\n                start_shape=start_shape)\n\n            # After propagation through this node, we delete its lA, uA variables.\n            if bound_node.name != self.final_name:\n                del l.lA, l.uA\n            if _print_time:\n                torch.cuda.synchronize()\n                time_elapsed = time.time() - start_time\n                if time_elapsed > 5e-3:\n                    print(l, time_elapsed)\n            if lb.ndim > 0 and type(lower_b) == Tensor and self.conv_mode == 'patches':\n                lb, ub, lower_b, upper_b = check_patch_biases(lb, ub, lower_b, upper_b)\n            lb = lb + lower_b\n            ub = ub + upper_b\n            if self.return_A and self.needed_A_dict and bound_node.name in self.needed_A_dict:\n                # FIXME remove [0][0] and [0][1]?\n                if len(self.needed_A_dict[bound_node.name]) == 0 or l.name in self.needed_A_dict[bound_node.name]:\n                    # A could be either patches (in this case we cannot transpose so directly return)\n                    # or matrix (in this case we transpose)\n                    A_record.update({\n                        l.name: {\n                            \"lA\": (\n                                A[0][0].detach() if isinstance(A[0][0], Patches)\n                                else A[0][0].transpose(0, 1).detach()\n                            ) if A[0][0] is not None else None,\n                            \"uA\": (\n                                A[0][1].detach() if isinstance(A[0][1], Patches)\n                                else A[0][1].transpose(0, 1).detach()\n                            ) if A[0][1] is not None else None,\n                            # When not used, lb or ub is tensor(0).\n                            \"lbias\": lb.transpose(0, 1).detach() if lb.ndim > 1 else None,\n                            \"ubias\": ub.transpose(0, 1).detach() if ub.ndim > 1 else None,\n            
                \"unstable_idx\": unstable_idx\n                    }})\n                # FIXME: solve conflict with the following case\n                self.A_dict.update({bound_node.name: A_record})\n                if need_A_only and set(self.needed_A_dict[bound_node.name]) == set(A_record.keys()):\n                    # We have collected all A matrices we need. We can return now!\n                    self.A_dict.update({bound_node.name: A_record})\n                    # Do not concretize to save time. We just need the A matrices.\n                    # return A matrix as a dict: {node_start.name: [A_lower, A_upper]}\n                    return None, None, self.A_dict\n\n            for i, l_pre in enumerate(l.inputs):\n                add_bound(l, l_pre, lA=A[i][0], uA=A[i][1])\n\n    if lb.ndim >= 2:\n        lb = lb.transpose(0, 1)\n    if ub.ndim >= 2:\n        ub = ub.transpose(0, 1)\n\n    # TODO merge into `concretize`\n    if (self.cut_used and getattr(self, 'cut_module', None) is not None\n            and self.cut_module.x_coeffs is not None):\n        # propagate input neuron in cut constraints\n        roots[0].lA, roots[0].uA = self.cut_module.input_cut(\n            bound_node, roots[0].lA, roots[0].uA, roots[0].lower.size()[1:], unstable_idx,\n            batch_mask=update_mask)\n\n    lb, ub = self.concretize_bounds(\n        node=bound_node,\n        lower=lb,\n        upper=ub,\n        concretize_mode='backward',\n        batch_size=batch_size,\n        output_dim=output_dim,\n        average_A=average_A,\n        clip_neuron_selection_value=self.clip_neuron_selection_value,\n        clip_neuron_selection_type=self.clip_neuron_selection_type\n    )\n\n    if self.return_A and self.needed_A_dict and bound_node.name in self.needed_A_dict:\n        save_root_A(\n            bound_node, A_record, self.A_dict, roots,\n            self.needed_A_dict[bound_node.name],\n            lb=lb, ub=ub, unstable_idx=unstable_idx)\n    for root in 
self.roots():\n        # These are saved for `save_root_A`. We do not need them afterwards.\n        root.lb = root.ub = None\n\n    if tighten_input_bounds and isinstance(bound_node, BoundInput):\n        # Straight-through trick: take the tighter of the existing and the newly\n        # computed input bounds, while keeping gradients flowing through lb/ub.\n        shape = bound_node.perturbation.x_L.shape\n        lb_reshaped = lb.reshape(shape)\n        bound_node.perturbation.x_L = lb_reshaped - lb_reshaped.detach() + torch.max(bound_node.perturbation.x_L.detach(), lb_reshaped.detach())\n        ub_reshaped = ub.reshape(shape)\n        bound_node.perturbation.x_U = ub_reshaped - ub_reshaped.detach() + torch.min(bound_node.perturbation.x_U.detach(), ub_reshaped.detach())\n\n    lb = lb.view(batch_size, *output_shape) if bound_lower else None\n    ub = ub.view(batch_size, *output_shape) if bound_upper else None\n\n    # TODO merge into `concretize`\n    if (self.cut_used and getattr(self, \"cut_module\", None) is not None\n            and self.cut_module.cut_bias is not None):\n        # Propagate the cut bias in cut constraints.\n        lb, ub = self.cut_module.bias_cut(bound_node, lb, ub, unstable_idx, batch_mask=update_mask)\n        if lb is not None and ub is not None and ((lb - ub) > 0).sum().item() > 0:\n            # Make sure there is no bug in cut constraint propagation.\n            print(f\"Warning: lb is larger than ub with diff: {(lb-ub)[(lb-ub)>0].max().item()}\")\n\n    if self.verbose:\n        logger.debug('')\n\n    if self.invprop_enabled():\n        lb, ub = self.invprop_check_infeasible_bounds(lb, ub)\n\n    if self.return_A:\n        if self.bound_opts['clip_in_alpha_crown'] and self.final_name in self.A_dict.keys():\n            for v in self.A_dict[self.final_name].values():\n                if v[\"lA\"] is not None:\n                    self.constraints_optimized = (v[\"lA\"], v[\"lbias\"])\n        return lb, ub, self.A_dict\n    else:\n        return lb, ub\n\n\ndef get_unstable_size(unstable_idx):\n    if isinstance(unstable_idx, tuple):\n        return unstable_idx[0].numel()\n    else:\n        
return unstable_idx.numel()\n\n\ndef check_optimized_variable_sparsity(self: 'BoundedModule', node):\n    alpha_sparsity = None  # Unknown; optimizable variables have not been created for this node.\n    for relu in self.relus:\n        # FIXME: this is hardcoded for ReLUs. Need to support other optimized nonlinear functions.\n        # alpha_lookup_idx is only created for sparse-spec alphas.\n        if relu.alpha_lookup_idx is not None and node.name in relu.alpha_lookup_idx:\n            if relu.alpha_lookup_idx[node.name] is not None:\n                # This node was created with sparse alpha.\n                alpha_sparsity = True\n            elif self.bound_opts['optimize_bound_args']['use_shared_alpha']:\n                # Shared alpha: the spec dimension is 1, so sparsity can be supported.\n                alpha_sparsity = True\n            else:\n                alpha_sparsity = False\n            break\n    return alpha_sparsity\n\n\ndef get_sparse_C(self: 'BoundedModule', node, ref_intermediate):\n    (sparse_intermediate_bounds,\n     ref_intermediate_lb, ref_intermediate_ub) = ref_intermediate\n    sparse_conv_intermediate_bounds = self.bound_opts.get('sparse_conv_intermediate_bounds', False)\n    minimum_sparsity = self.bound_opts.get('minimum_sparsity', 0.9)\n    crown_batch_size = self.bound_opts.get('crown_batch_size', 1e9)\n    dim = int(prod(node.output_shape[1:]))\n    batch_size = self.batch_size\n\n    reduced_dim = False  # Only partial neurons (unstable neurons) are bounded.\n    unstable_idx = None\n    unstable_size = np.inf\n    newC = None\n\n    alpha_is_sparse = self.check_optimized_variable_sparsity(node)\n\n    # NOTE: batched CROWN is so far only supported for some of the cases below.\n\n    # FIXME: C matrix shape incorrect for BoundParams.\n    if isinstance(node, (BoundLinear, BoundMatMul)) and int(\n            os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0:\n        if sparse_intermediate_bounds:\n            # If we 
are doing bound refinement and reference bounds are given,\n            # we only refine unstable neurons.\n            # Also, if we are checking against LP solver we will refine all\n            # neurons and do not use this optimization.\n            # For each batch element, we find the unstable neurons.\n            unstable_idx, unstable_size = self.get_unstable_locations(\n                ref_intermediate_lb, ref_intermediate_ub)\n            if unstable_size == 0:\n                # Do nothing, no bounds will be computed.\n                reduced_dim = True\n                unstable_idx = []\n            elif unstable_size > crown_batch_size:\n                # Create C in batched CROWN\n                newC = BatchedCrownC('OneHot')\n                reduced_dim = True\n            elif (((0 < unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or\n                   alpha_is_sparse) and\n                   len(node.output_shape) <= 2):\n                # When we already have sparse alpha for this layer, we always\n                # use sparse C. 
Otherwise we determine it by sparsity.\n                # Create an abstract C matrix, the unstable_idx are the non-zero\n                # elements in specifications for all batches.\n                # Shouldn't use OneHotC if the output is not a 1-d tensor.\n                newC = OneHotC(\n                    [batch_size, unstable_size, *node.output_shape[1:]],\n                    self.device, unstable_idx, None)\n                reduced_dim = True\n            else:\n                unstable_idx = None\n                del ref_intermediate_lb, ref_intermediate_ub\n        if not reduced_dim:\n            if dim > crown_batch_size:\n                newC = BatchedCrownC('eye')\n            else:\n                newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device)\n    elif node.patches_start and node.mode == \"patches\":\n        if sparse_intermediate_bounds:\n            unstable_idx, unstable_size = self.get_unstable_locations(\n                ref_intermediate_lb, ref_intermediate_ub, conv=True)\n            if unstable_size == 0:\n                # Do nothing, no bounds will be computed.\n                reduced_dim = True\n                unstable_idx = []\n            elif unstable_size > crown_batch_size:\n                # Create C in batched CROWN\n                newC = BatchedCrownC('Patches')\n                reduced_dim = True\n            # We sum over the channel direction, so need to multiply that.\n            elif (sparse_conv_intermediate_bounds\n                  and unstable_size <= minimum_sparsity * dim\n                  and alpha_is_sparse is None) or alpha_is_sparse:\n                # When we already have sparse alpha for this layer, we always\n                # use sparse C. 
Otherwise we determine it by sparsity.\n                # Create an abstract C matrix, the unstable_idx are the non-zero\n                # elements in specifications for all batches.\n                # The shape of patches is [unstable_size, batch, C, H, W].\n                newC = Patches(\n                    shape=[unstable_size, batch_size, *node.output_shape[1:-2],\n                           1, 1],\n                    identity=1, unstable_idx=unstable_idx,\n                    output_shape=[batch_size, *node.output_shape[1:]])\n                reduced_dim = True\n            else:\n                unstable_idx = None\n                del ref_intermediate_lb, ref_intermediate_ub\n        # Here we create an Identity Patches object\n        if not reduced_dim:\n            newC = Patches(\n                None, 1, 0, [node.output_shape[1], batch_size, *node.output_shape[2:],\n                *node.output_shape[1:-2], 1, 1], 1,\n                output_shape=[batch_size, *node.output_shape[1:]])\n    elif (isinstance(node, (BoundAdd, BoundSub)) and node.mode == \"patches\"\n        and len(node.output_shape) >= 4):\n        # FIXME: BoundAdd does not always have patches. 
Need to use a better way\n        # to determine patches mode.\n        # FIXME: We should not hardcode BoundAdd here!\n        if sparse_intermediate_bounds:\n            if crown_batch_size < 1e9:\n                warnings.warn('Batched CROWN is not supported in this case')\n            unstable_idx, unstable_size = self.get_unstable_locations(\n                ref_intermediate_lb, ref_intermediate_ub, conv=True)\n            if unstable_size == 0:\n                # Do nothing, no bounds will be computed.\n                reduced_dim = True\n                unstable_idx = []\n            elif (sparse_conv_intermediate_bounds\n                  and unstable_size <= minimum_sparsity * dim\n                  and alpha_is_sparse is None) or alpha_is_sparse:\n                # When we already have sparse alpha for this layer, we always\n                # use sparse C. Otherwise we determine it by sparsity.\n                num_channel = node.output_shape[-3]\n                # Identity patch size: (out_c, 1, 1, 1, out_c, 1, 1).\n                patches = (\n                    torch.eye(num_channel, device=self.device,\n                    dtype=list(self.parameters())[0].dtype)).view(\n                        num_channel, 1, 1, 1, num_channel, 1, 1)\n                # Expand to (out_c, 1, out_h, out_w, out_c, 1, 1).\n                patches = patches.expand(-1, 1, node.output_shape[-2],\n                                         node.output_shape[-1], -1, 1, 1)\n                patches = patches[unstable_idx[0], :,\n                                  unstable_idx[1], unstable_idx[2]]\n                # Expand with the batch dimension. 
Final shape\n                # (unstable_size, batch_size, out_c, 1, 1).\n                patches = patches.expand(-1, batch_size, -1, -1, -1)\n                newC = Patches(\n                    patches, 1, 0, patches.shape, unstable_idx=unstable_idx,\n                    output_shape=[batch_size, *node.output_shape[1:]])\n                reduced_dim = True\n            else:\n                unstable_idx = None\n                del ref_intermediate_lb, ref_intermediate_ub\n        if not reduced_dim:\n            num_channel = node.output_shape[-3]\n            # Identity patch size: (out_c, 1, 1, 1, out_c, 1, 1).\n            patches = (\n                torch.eye(num_channel, device=self.device,\n                dtype=list(self.parameters())[0].dtype)).view(\n                    num_channel, 1, 1, 1, num_channel, 1, 1)\n            # Expand to (out_c, batch, out_h, out_w, out_c, 1, 1).\n            patches = patches.expand(-1, batch_size, node.output_shape[-2],\n                                     node.output_shape[-1], -1, 1, 1)\n            newC = Patches(patches, 1, 0, patches.shape, output_shape=[\n                batch_size, *node.output_shape[1:]])\n    else:\n        if sparse_intermediate_bounds:\n            unstable_idx, unstable_size = self.get_unstable_locations(\n                ref_intermediate_lb, ref_intermediate_ub)\n            if unstable_size == 0:\n                # Do nothing, no bounds will be computed.\n                reduced_dim = True\n                unstable_idx = []\n            elif unstable_size > crown_batch_size:\n                # Create C in batched CROWN.\n                newC = BatchedCrownC('eye')\n                reduced_dim = True\n            elif (unstable_size <= minimum_sparsity * dim\n                  and alpha_is_sparse is None) or alpha_is_sparse:\n                newC = torch.zeros([1, unstable_size, dim], device=self.device)\n                # Fill the corresponding elements with 1.0.\n                newC[0, 
torch.arange(unstable_size), unstable_idx] = 1.0\n                newC = newC.expand(batch_size, -1, -1).view(\n                    batch_size, unstable_size, *node.output_shape[1:])\n                reduced_dim = True\n            else:\n                unstable_idx = None\n                del ref_intermediate_lb, ref_intermediate_ub\n        if not reduced_dim:\n            if dim > 1000:\n                warnings.warn(\n                    f\"Creating an identity matrix with size {dim}x{dim} for node {node}. \"\n                    \"This may indicate poor performance for bound computation. \"\n                    \"If you see this message on a small network please submit \"\n                    \"a bug report.\", stacklevel=2)\n            if dim > crown_batch_size:\n                newC = BatchedCrownC('eye')\n            else:\n                newC = torch.eye(dim, device=self.device).unsqueeze(0).expand(\n                    batch_size, -1, -1\n                ).view(batch_size, dim, *node.output_shape[1:])\n\n    return newC, reduced_dim, unstable_idx, unstable_size\n\n\ndef restore_sparse_bounds(self: 'BoundedModule', node, unstable_idx,\n                          unstable_size, ref_intermediate,\n                          new_lower=None, new_upper=None):\n    ref_intermediate_lb, ref_intermediate_ub = ref_intermediate[1:]\n    batch_size = self.batch_size\n    if unstable_size == 0:\n        # No unstable neurons. 
Skip the update.\n        node.lower = ref_intermediate_lb.detach().clone()\n        node.upper = ref_intermediate_ub.detach().clone()\n    else:\n        if new_lower is None:\n            new_lower = node.lower\n        if new_upper is None:\n            new_upper = node.upper\n        # If bounds were only computed for unstable neurons, scatter the\n        # results back into the reference bounds.\n        if isinstance(unstable_idx, tuple):\n            lower = ref_intermediate_lb.detach().clone()\n            upper = ref_intermediate_ub.detach().clone()\n            # For a conv layer with patches, unstable_idx is a tuple of 3 or 4\n            # indices, e.g., (C, H, W), locating the unstable neurons.\n            if len(unstable_idx) == 3:\n                lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_lower\n                upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_upper\n            elif len(unstable_idx) == 4:\n                lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], unstable_idx[3]] = new_lower\n                upper[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], unstable_idx[3]] = new_upper\n        else:\n            # Other layers.\n            lower = ref_intermediate_lb.detach().clone().reshape(batch_size, -1)\n            upper = ref_intermediate_ub.detach().clone().reshape(batch_size, -1)\n            lower[:, unstable_idx] = new_lower.view(batch_size, -1)\n            upper[:, unstable_idx] = new_upper.view(batch_size, -1)\n        node.lower = lower.view(batch_size, *node.output_shape[1:])\n        node.upper = upper.view(batch_size, *node.output_shape[1:])\n\n\ndef get_degrees(node_start):\n    if not isinstance(node_start, list):\n        node_start = [node_start]\n    degrees = {}\n    added = {}\n    queue = deque()\n    for node in node_start:\n        queue.append(node)\n        added[node.name] = True\n    while len(queue) > 0:\n        l = queue.popleft()\n        for l_pre in l.inputs:\n            
degrees[l_pre.name] = degrees.get(l_pre.name, 0) + 1\n            if not added.get(l_pre.name, False):\n                queue.append(l_pre)\n                added[l_pre.name] = True\n    return degrees\n\n\ndef _preprocess_C(self: 'BoundedModule', C, node):\n    if isinstance(C, Patches):\n        if C.unstable_idx is None:\n            # Patches have size (out_c, batch, out_h, out_w, c, h, w).\n            if len(C.shape) == 7:\n                out_c, batch_size, out_h, out_w = C.shape[:4]\n                output_dim = out_c * out_h * out_w\n            else:\n                out_dim, batch_size, out_c, out_h, out_w = C.shape[:5]\n                output_dim = out_dim * out_c * out_h * out_w\n        else:\n            # Patches have size (unstable_size, batch, c, h, w).\n            output_dim, batch_size = C.shape[:2]\n    else:\n        batch_size, output_dim = C.shape[:2]\n\n    # The C matrix specified by the user has shape (batch, spec)\n    # but internally we have (spec, batch) format.\n    if not isinstance(C, (eyeC, Patches, OneHotC)):\n        C = C.transpose(0, 1).reshape(\n            output_dim, batch_size, *node.output_shape[1:])\n    elif isinstance(C, eyeC):\n        C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:]))\n    elif isinstance(C, OneHotC):\n        C = C._replace(\n            shape=(C.shape[1], C.shape[0], *C.shape[2:]),\n            index=C.index.transpose(0,-1),\n            coeffs=None if C.coeffs is None else C.coeffs.transpose(0,-1))\n\n    if isinstance(C, Patches) and C.unstable_idx is not None:\n        # Sparse patches; the output shape is (unstable_size, ).\n        output_shape = [C.shape[0]]\n    elif prod(node.output_shape[1:]) != output_dim and not isinstance(C, Patches):\n        # For the output node, the shape of the bound follows C\n        # instead of the original output shape\n        #\n        # TODO Maybe don't set node.lower and node.upper in this case?\n        # Currently some codes still depend on 
node.lower and node.upper\n        output_shape = [-1]\n    else:\n        # Generally, the shape of the bounds matches the output shape of the node.\n        output_shape = node.output_shape[1:]\n\n    return C, batch_size, output_dim, output_shape\n\n\ndef addA(A1, A2):\n    \"\"\"Add two A matrices (each of them is either a Tensor or Patches).\"\"\"\n    if isinstance(A1, Patches):\n        return A1 + A2\n    elif isinstance(A2, Patches):\n        return A2 + A1\n    elif type(A1) == type(A2):\n        return A1 + A2\n    else:\n        raise NotImplementedError(\n            f'Unsupported types for A1 ({type(A1)}) and A2 ({type(A2)})')\n\n\ndef add_bound(node, node_pre, lA=None, uA=None):\n    \"\"\"\n    Propagate lA and uA to a preceding node.\n    @param node:        The current bounded node\n    @param node_pre:    An input of the current bounded node that needs lA, lbias, etc. back-propagated to it\n    @param lA:          lA matrix associated with the current bounded node\n    @param uA:          uA matrix associated with the current bounded node\n    @return:\n    \"\"\"\n\n    if lA is not None:\n        if node_pre.lA is None:\n            # First A added to this node.\n            node_pre.zero_lA_mtx = node.zero_backward_coeffs_l\n            node_pre.lA = lA\n        else:\n            node_pre.zero_lA_mtx = node_pre.zero_lA_mtx and node.zero_backward_coeffs_l\n            node_pre.lA = addA(node_pre.lA, lA)\n    if uA is not None:\n        if node_pre.uA is None:\n            # First A added to this node.\n            node_pre.zero_uA_mtx = node.zero_backward_coeffs_u\n            node_pre.uA = uA\n        else:\n            node_pre.zero_uA_mtx = node_pre.zero_uA_mtx and node.zero_backward_coeffs_u\n            node_pre.uA = addA(node_pre.uA, uA)\n\n\ndef add_constant_node(lb, ub, node):\n    new_lb = node.get_bias(node.lA, node.forward_value)\n    new_ub = node.get_bias(node.uA, node.forward_value)\n    if isinstance(lb, Tensor) and 
isinstance(new_lb, Tensor) and lb.ndim > 0 and lb.ndim != new_lb.ndim:\n        new_lb = new_lb.reshape(lb.shape)\n    if isinstance(ub, Tensor) and isinstance(new_ub, Tensor) and ub.ndim > 0 and ub.ndim != new_ub.ndim:\n        new_ub = new_ub.reshape(ub.shape)\n    lb = lb + new_lb # FIXME (09/16): shape for the bias of BoundConstant.\n    ub = ub + new_ub\n    return lb, ub\n\n\ndef save_root_A(node, A_record, A_dict, roots, needed_A_dict, lb, ub,\n                unstable_idx):\n    root_A_record = {}\n    for i in range(len(roots)):\n        if roots[i].lA is None and roots[i].uA is None:\n            continue\n        if roots[i].name in needed_A_dict:\n            if roots[i].lA is not None:\n                if isinstance(roots[i].lA, Patches):\n                    _lA = roots[i].lA.detach()\n                else:\n                    _lA = roots[i].lA.transpose(0, 1).detach()\n            else:\n                _lA = None\n            if roots[i].uA is not None:\n                if isinstance(roots[i].uA, Patches):\n                    _uA = roots[i].uA.detach()\n                else:\n                    _uA = roots[i].uA.transpose(0, 1).detach()\n            else:\n                _uA = None\n\n            # Include all the bias terms except the one concretized from the\n            # current root node.\n            lb_ = lb - roots[i].lb if (roots[i].lb is not None) else lb\n            ub_ = ub - roots[i].ub if (roots[i].ub is not None) else ub\n\n            root_A_record.update({roots[i].name: {\n                \"lA\": _lA,\n                \"uA\": _uA,\n                # When not used, lb or ub is tensor(0). 
They have been transposed above.\n                \"lbias\": lb_.detach() if lb_.ndim > 1 else None,\n                \"ubias\": ub_.detach() if ub_.ndim > 1 else None,\n                \"unstable_idx\": unstable_idx\n            }})\n\n    root_A_record.update(A_record)  # Merge into the existing A_record.\n    A_dict.update({node.name: root_A_record})\n\n\ndef select_unstable_idx(ref_intermediate_lb, ref_intermediate_ub, unstable_locs, max_crown_size):\n    \"\"\"When there are too many unstable neurons, only bound those\n    with the loosest reference bounds.\"\"\"\n    gap = (\n        ref_intermediate_ub[:, unstable_locs]\n        - ref_intermediate_lb[:, unstable_locs]).sum(dim=0)\n    indices = torch.argsort(gap, descending=True)\n    indices_selected = indices[:max_crown_size]\n    indices_selected, _ = torch.sort(indices_selected)\n    print(f'{len(indices_selected)}/{len(indices)} unstable neurons selected for CROWN')\n    return indices_selected\n\n\ndef get_unstable_locations(self: 'BoundedModule', ref_intermediate_lb,\n                           ref_intermediate_ub, conv=False, channel_only=False):\n    # FIXME (2023): This function should be a member function of the Bound object, since the\n    # definition of unstable neurons depends on the activation function.\n    max_crown_size = self.bound_opts.get('max_crown_size', int(1e9))\n    # For conv layers we only check the case where all neurons are active/inactive.\n    unstable_masks = torch.logical_and(ref_intermediate_lb < 0, ref_intermediate_ub > 0)\n    # For simplicity, merge unstable locations for all elements in this batch. TODO: use individual unstable masks.\n    # It has shape (H, W) indicating if a neuron is unstable/stable.\n    # TODO: so far we merge over the batch dimension to allow easier implementation.\n    if channel_only:\n        # Only keep channels with unstable neurons. 
Used for initializing alpha.\n        unstable_locs = unstable_masks.sum(dim=(0,2,3)).bool()\n        # Shape is consistent with linear layers: a list of unstable neuron channels (no batch dim).\n        unstable_idx = unstable_locs.nonzero().squeeze(1)\n    else:\n        if not conv and unstable_masks.ndim > 2:\n            # Flatten the conv layer shape.\n            unstable_masks = unstable_masks.reshape(unstable_masks.size(0), -1)\n            ref_intermediate_lb = ref_intermediate_lb.reshape(ref_intermediate_lb.size(0), -1)\n            ref_intermediate_ub = ref_intermediate_ub.reshape(ref_intermediate_ub.size(0), -1)\n        unstable_locs = unstable_masks.sum(dim=0).bool()\n        if conv:\n            # Now convert it to indices of these unstable neurons.\n            # These are locations (i,j) of unstable neurons.\n            unstable_idx = unstable_locs.nonzero(as_tuple=True)\n        else:\n            unstable_idx = unstable_locs.nonzero().squeeze(1)\n\n    unstable_size = get_unstable_size(unstable_idx)\n    if unstable_size > max_crown_size:\n        indices_selected = select_unstable_idx(\n            ref_intermediate_lb, ref_intermediate_ub, unstable_locs, max_crown_size)\n        if isinstance(unstable_idx, tuple):\n            unstable_idx = tuple(u[indices_selected] for u in unstable_idx)\n        else:\n            unstable_idx = unstable_idx[indices_selected]\n    unstable_size = get_unstable_size(unstable_idx)\n\n    return unstable_idx, unstable_size\n\n\ndef get_alpha_crown_start_nodes(\n        self: 'BoundedModule',\n        node,\n        c=None,\n        share_alphas=False,\n        final_node_name=None,\n    ):\n    \"\"\"\n    Given a layer \"node\", return a list of following nodes after this node whose bounds\n    will propagate through this node. 
Each element in the list is a tuple with 4 elements:\n    (following_node_name, following_node_shape, unstable_idx, is_final_node)\n    \"\"\"\n    # When use_full_conv_alpha is True, conv layers do not share alpha.\n    sparse_intermediate_bounds = self.bound_opts.get('sparse_intermediate_bounds', False)\n    use_full_conv_alpha_thresh = self.bound_opts.get('use_full_conv_alpha_thresh', 512)\n\n    start_nodes = []\n\n    for nj in self.backward_from[node.name]:  # Pre-activation layers.\n        unstable_idx = None\n        use_sparse_conv = None  # Whether a sparse-spec alpha is used for a conv output node. None for a non-conv output node.\n        use_full_conv_alpha = self.bound_opts.get('use_full_conv_alpha', False)\n\n        # Find the indices of unstable neurons, used to create sparse-feature alpha.\n        if (sparse_intermediate_bounds\n                and isinstance(node, BoundOptimizableActivation)\n                and nj.name != final_node_name and not share_alphas):\n            # Create sparse optimization variables for intermediate neurons.\n            # These are called \"sparse-spec\" alpha because we create alpha only for\n            # the intermediate or final output nodes whose bounds are needed.\n            # \"sparse-spec\" alpha makes sense only for piece-wise linear functions.\n            # For other intermediate nodes, there is no \"unstable\" or \"stable\" neuron.\n            # FIXME: whether a layer has unstable/stable neurons should be in the Bound obj.\n            # FIXME: get_unstable_locations should be a method of ReLU.\n            if len(nj.output_name) == 1 and isinstance(self[nj.output_name[0]], (BoundRelu, BoundSignMerge, BoundMaxPool)):\n                if ((isinstance(nj, (BoundLinear, BoundMatMul)))\n                        and int(os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0):\n                    # unstable_idx has shape [neuron_size_of_nj]. 
Batch dimension is reduced.\n                    unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper)\n                elif isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches':\n                    if nj.name in node.patch_size:\n                        # unstable_idx has shape [channel_size_of_nj]. Batch and spatial dimensions are reduced.\n                        unstable_idx, _ = self.get_unstable_locations(\n                            nj.lower, nj.upper, channel_only=not use_full_conv_alpha, conv=True)\n                        use_sparse_conv = False  # alpha is shared among channels. Sparse-spec alpha in hw dimension not used.\n                        if use_full_conv_alpha and unstable_idx[0].size(0) > use_full_conv_alpha_thresh:\n                            # Too many unstable neurons. Use shared alpha per channel.\n                            unstable_idx, _ = self.get_unstable_locations(\n                                nj.lower, nj.upper, channel_only=True, conv=True)\n                            use_full_conv_alpha = False\n                    else:\n                        # Matrix mode for conv layers. Although the bound propagation started in patches mode,\n                        # when the A matrix is propagated to this layer, it might become a dense matrix since patches\n                        # can become very large after many layers. In this case,\n                        # unstable_idx has shape [c_out * h_out * w_out]. Batch dimension is reduced.\n                        unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper)\n                        use_sparse_conv = True  # alpha is not shared among channels, and is sparse in the spec dimension.\n            else:\n                # FIXME: we should not check for fixed names here. 
Need to enable patches mode more generally.\n                if isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches':\n                    use_sparse_conv = False  # Sparse-spec alpha can never be used, because it is not a ReLU activation.\n\n        if nj.name == final_node_name:\n            # Final layer: always use the number of specs as the shape.\n            size_final = self[final_node_name].output_shape[1:] if c is None else c.size(1)\n            # The 4th element indicates that this start node is the final node,\n            # which may be utilized by operators that do not know the name of\n            # the final node.\n            start_nodes.append((final_node_name, size_final, None, True))\n            continue\n\n        if share_alphas:\n            # All intermediate neurons from the same layer share the same set of alphas.\n            output_shape = 1\n        elif isinstance(node, BoundOptimizableActivation) and node.patch_size and nj.name in node.patch_size:\n            # Patches mode. Use the output channel size as the spec size. This still shares some alpha, but is better than no sharing.\n            if use_full_conv_alpha:\n                # alphas are not shared among channels, so the spec dim shape is (c, h, w).\n                # The patch size is [out_ch, batch, out_h, out_w, in_ch, H, W]. We use out_ch as the output shape.\n                output_shape = node.patch_size[nj.name][0], node.patch_size[nj.name][2], node.patch_size[nj.name][3]\n            else:\n                # The spec dim is c only, and is shared among h, w.\n                output_shape = node.patch_size[nj.name][0]\n            assert not sparse_intermediate_bounds or use_sparse_conv is False  # Double check our assumption holds. 
If this fails, then we created wrong shapes for alpha.\n        else:\n            # Output is a linear layer (use_sparse_conv = None), or patches converted to a matrix (use_sparse_conv = True).\n            assert not sparse_intermediate_bounds or use_sparse_conv is not False  # Double check our assumption holds. If this fails, then we created wrong shapes for alpha.\n            output_shape = nj.lower.shape[1:]  # FIXME: for non-relu activations it's still expecting a prod.\n        start_nodes.append((nj.name, output_shape, unstable_idx, False))\n\n    return start_nodes\n\n\ndef merge_A(node, batch_A, ret_A):\n    for key0 in batch_A:\n        if key0 not in ret_A: ret_A[key0] = {}\n        for key1 in batch_A[key0]:\n            value = batch_A[key0][key1]\n            if key1 not in ret_A[key0]:\n                # create:\n                ret_A[key0].update({\n                    key1: {\n                        \"lA\": value[\"lA\"],\n                        \"uA\": value[\"uA\"],\n                        \"lbias\": value[\"lbias\"],\n                        \"ubias\": value[\"ubias\"],\n                        \"unstable_idx\": value[\"unstable_idx\"]\n                    }\n                })\n            elif key0 == node.name:\n                # merge:\n                # The batch splitting only happens for the current node, i.e.,\n                # for other nodes the returned lA should be the same across different batches,\n                # so there is no need to repeatedly merge them.\n                exist = ret_A[key0][key1]\n\n                if exist[\"unstable_idx\"] is not None:\n                    if isinstance(exist[\"unstable_idx\"], torch.Tensor):\n                        merged_unstable = torch.cat([\n                            exist[\"unstable_idx\"],\n                            value['unstable_idx']], dim=0)\n                    elif isinstance(exist[\"unstable_idx\"], tuple):\n                        if exist[\"unstable_idx\"]:\n                            merged_unstable = tuple([\n                                torch.cat([exist[\"unstable_idx\"][idx],\n                                           value['unstable_idx'][idx]], dim=0)\n                                for idx in range(len(exist['unstable_idx']))]\n                            )\n                        else:\n                            merged_unstable = None\n                    else:\n                        raise NotImplementedError(\n                            f'Unsupported type {type(exist[\"unstable_idx\"])}')\n                else:\n                    merged_unstable = None\n                merge_dict = {\"unstable_idx\": merged_unstable}\n                for name in [\"lA\", \"uA\"]:\n                    if exist[name] is not None:\n                        if isinstance(exist[name], torch.Tensor):\n                            # for matrix the spec dim is 1\n                            merge_dict[name] = torch.cat([exist[name], value[name]], dim=1)\n                        else:\n                            assert isinstance(exist[name], Patches)\n                            # for patches the spec dim is 0\n                            merge_dict[name] = exist[name].create_similar(\n                                torch.cat([exist[name].patches, value[name].patches], dim=0),\n                                unstable_idx=merged_unstable\n                            )\n                    else:\n                        merge_dict[name] = None\n                for name in [\"lbias\", \"ubias\"]:\n                    if exist[name] is not None:\n                        # for bias the spec dim is 1\n                        merge_dict[name] = torch.cat([exist[name], value[name]], dim=1)\n                    else:\n                        merge_dict[name] = None\n                ret_A[key0][key1] = merge_dict\n    return ret_A\n"
  },
  {
    "path": "auto_LiRPA/beta_crown.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom collections import OrderedDict\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom .patches import Patches, inplace_unfold\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\nclass SparseBeta:\n    def __init__(self, shape, bias=False, betas=None, device='cpu'):\n        self.device = device\n        self.val = torch.zeros(shape)\n        self.loc = torch.zeros(shape, dtype=torch.long, device=device)\n        self.sign = torch.zeros(shape, device=device)\n        self.bias = torch.zeros(shape, device=device) if bias else None\n        if betas:\n            for bi in range(len(betas)):\n                if betas[bi] is not None:\n                    self.val[bi, :len(betas[bi])] = betas[bi]\n        self.val = self.val.detach().to(\n            device, non_blocking=True).requires_grad_()\n\n    def apply_splits(self, history, key):\n        loc_numpy = np.zeros(self.loc.shape, dtype=np.int32)\n        sign_numpy = np.zeros(self.sign.shape)\n        if self.bias is not None:\n            bias_numpy = np.zeros(self.bias.shape)\n        for bi in range(len(history)):\n            # Add history splits. 
(layer, neuron) is the current decision.\n            split_locs, split_coeffs = history[bi][key][:2]\n            split_len = len(split_locs)\n            if split_len > 0:\n                sign_numpy[bi, :split_len] = split_coeffs\n                loc_numpy[bi, :split_len] = split_locs\n                if self.bias is not None:\n                    split_bias = history[bi][key][2]\n                    bias_numpy[bi, :split_len] = split_bias\n        self.loc.copy_(torch.from_numpy(loc_numpy), non_blocking=True)\n        self.sign.copy_(torch.from_numpy(sign_numpy), non_blocking=True)\n        if self.bias is not None:\n            self.bias.copy_(torch.from_numpy(bias_numpy), non_blocking=True)\n\ndef get_split_nodes(self: 'BoundedModule'):\n    self.split_nodes = []\n    self.split_activations = {}\n    splittable_activations = self.get_splittable_activations()\n    self._set_used_nodes(self[self.final_name])\n    for layer in self.layers_requiring_bounds:\n        split_activations_ = []\n        for activation_name in layer.output_name:\n            activation = self[activation_name]\n            if activation in splittable_activations:\n                split_activations_.append(\n                    (activation, activation.inputs.index(layer)))\n        if split_activations_:\n            if layer.lower is None and layer.upper is None:\n                continue\n            self.split_nodes.append(layer)\n            self.split_activations[layer.name] = split_activations_\n    return self.split_nodes, self.split_activations\n\n\ndef set_beta(self: 'BoundedModule', enable_opt_interm_bounds, parameters,\n             lr_beta, lr_cut_beta, cutter, dense_coeffs_mask):\n    \"\"\"\n    Set betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases\n    and best_biases.\n    \"\"\"\n    coeffs = None\n    betas = []\n    best_betas = OrderedDict()\n\n    # TODO compute only once\n    self.nodes_with_beta = []\n    for node in self.split_nodes:\n        if 
not hasattr(node, 'sparse_betas'):\n            continue\n        self.nodes_with_beta.append(node)\n        if enable_opt_interm_bounds:\n            for sparse_beta in node.sparse_betas.values():\n                if sparse_beta is not None:\n                    betas.append(sparse_beta.val)\n            best_betas[node.name] = {\n                beta_m: sparse_beta.val.detach().clone()\n                for beta_m, sparse_beta in node.sparse_betas.items()\n            }\n        else:\n            betas.append(node.sparse_betas[0].val)\n            best_betas[node.name] = node.sparse_betas[0].val.detach().clone()\n\n    # Beta has shape (batch, max_splits_per_layer)\n    parameters.append({\n        'params': [item for item in betas if item.numel() > 0],\n        'lr': lr_beta, 'batch_dim': 0})\n\n    if self.cut_used:\n        self.set_beta_cuts(parameters, lr_cut_beta, betas, best_betas, cutter)\n\n    return betas, best_betas, coeffs, dense_coeffs_mask\n\n\ndef set_beta_cuts(self: 'BoundedModule', parameters, lr_cut_beta, betas,\n                  best_betas, cutter):\n    # also need to optimize cut betas\n    parameters.append({'params': self.cut_beta_params,\n                        'lr': lr_cut_beta, 'batch_dim': 0})\n    betas += self.cut_beta_params\n    best_betas['cut'] = [beta.detach().clone() for beta in self.cut_beta_params]\n    if getattr(cutter, 'opt', False):\n        parameters.append(cutter.get_parameters())\n\n\ndef reset_beta(self: 'BoundedModule', node, shape, betas, bias=False,\n               start_nodes=None):\n    # Create only the non-zero beta. 
For each layer, it is padded to the maximal length.\n    # We create tensors on CPU first, and they will be transferred to GPU after initialization.\n    if self.bound_opts.get('enable_opt_interm_bounds', False):\n        node.sparse_betas = {\n            key: SparseBeta(\n                shape,\n                betas=[(betas[j][i] if betas[j] is not None else None)\n                        for j in range(len(betas))],\n                device=self.device, bias=bias,\n            ) for i, key in enumerate(start_nodes)\n        }\n    else:\n        node.sparse_betas = [SparseBeta(\n            shape, betas=betas, device=self.device, bias=bias)]\n\n\ndef beta_crown_backward_bound(self: 'BoundedModule', node, lA, uA, start_node=None):\n    \"\"\"Update A and bias with Beta-CROWN.\n\n    Must be explicitly called at the end of \"bound_backward\".\n    \"\"\"\n    # Regular Beta-CROWN with single-neuron splits.\n    # Each split constraint only has a single neuron (e.g., the second ReLU neuron > 0).\n    A = lA if lA is not None else uA\n    lbias = ubias = 0\n\n    def _bias_unsupported():\n        raise NotImplementedError('Bias for beta not supported in this case.')\n\n    if type(A) is Patches:\n        if not self.bound_opts.get('enable_opt_interm_bounds', False):\n            raise NotImplementedError('Sparse beta not supported in the patches mode')\n        if node.sparse_betas[start_node.name].bias is not None:\n            _bias_unsupported()\n        # expand sparse_beta to full beta\n        beta_values = (node.sparse_betas[start_node.name].val\n                       * node.sparse_betas[start_node.name].sign)\n        beta_indices = node.sparse_betas[start_node.name].loc\n        node.masked_beta = torch.zeros(2, *node.shape).reshape(2, -1).to(A.patches.dtype)\n        node.non_deter_scatter_add(\n            node.masked_beta, dim=1, index=beta_indices,\n            src=beta_values.to(node.masked_beta.dtype))\n        node.masked_beta = node.masked_beta.reshape(2, 
*node.shape)\n        # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W)\n        A_patches = A.patches\n        masked_beta_unfolded = inplace_unfold(\n            node.masked_beta, kernel_size=A_patches.shape[-2:],\n            padding=A.padding, stride=A.stride,\n            inserted_zeros=A.inserted_zeros, output_padding=A.output_padding)\n        if A.unstable_idx is not None:\n            masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4, 5)\n            # After selection, the shape is (unstable_size, batch, in_c, H, W).\n            masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]]\n        else:\n            # Add the spec (out_c) dimension.\n            masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0)\n        if node.alpha_beta_update_mask is not None:\n            masked_beta_unfolded = masked_beta_unfolded[node.alpha_beta_update_mask]\n        if uA is not None:\n            uA = uA.create_similar(uA.patches + masked_beta_unfolded)\n        if lA is not None:\n            lA = lA.create_similar(lA.patches - masked_beta_unfolded)\n    elif type(A) is Tensor:\n        if self.bound_opts.get('enable_opt_interm_bounds', False):\n            if node.sparse_betas[start_node.name].bias is not None:\n                _bias_unsupported()\n            # For matrix mode, beta is sparse.\n            beta_values = (\n                node.sparse_betas[start_node.name].val\n                * node.sparse_betas[start_node.name].sign\n            ).expand(A.size(0), -1, -1)\n            # node.single_beta_loc has shape [batch, max_single_split].\n            # Need to expand at the specs dimension.\n            beta_indices = (node.sparse_betas[start_node.name].loc\n                            .unsqueeze(0).expand(A.size(0), -1, -1))\n            beta_bias = node.sparse_betas[start_node.name].bias\n        else:\n            # For matrix mode, beta is sparse.\n            beta_values = (\n          
      node.sparse_betas[0].val * node.sparse_betas[0].sign\n            ).expand(A.size(0), -1, -1)\n            # self.single_beta_loc has shape [batch, max_single_split].\n            # Need to expand at the specs dimension.\n            beta_indices = node.sparse_betas[0].loc.unsqueeze(0).expand(A.size(0), -1, -1)\n            beta_bias = node.sparse_betas[0].bias\n        # For conv layer, the last dimension is flattened in indices.\n        beta_values = beta_values.to(A.dtype)\n        if beta_bias is not None:\n            beta_bias = beta_bias.expand(A.size(0), -1, -1)\n        if node.alpha_beta_update_mask is not None:\n            beta_indices = beta_indices[:, node.alpha_beta_update_mask]\n            beta_values = beta_values[:, node.alpha_beta_update_mask]\n            if beta_bias is not None:\n                beta_bias = beta_bias[:, node.alpha_beta_update_mask]\n        if uA is not None:\n            uA = node.non_deter_scatter_add(\n                uA.reshape(uA.size(0), uA.size(1), -1), dim=2,\n                index=beta_indices, src=beta_values).view(uA.size())\n        if lA is not None:\n            lA = node.non_deter_scatter_add(\n                lA.reshape(lA.size(0), lA.size(1), -1), dim=2,\n                index=beta_indices, src=beta_values.neg()).view(lA.size())\n        if beta_bias is not None:\n            bias = (beta_values * beta_bias).sum(dim=-1)\n            lbias = bias\n            ubias = -bias\n    else:\n        raise RuntimeError(f\"Unknown type {type(A)} for A\")\n\n    return lA, uA, lbias, ubias\n\n\ndef print_optimized_beta(acts):\n    masked_betas = []\n    for model in acts:\n        masked_betas.append(model.masked_beta)\n        if model.history_beta_used:\n            print(f'{model.name} history beta', model.new_history_beta.squeeze())\n        if model.split_beta_used:\n            print(f'{model.name} split beta:', model.split_beta.view(-1))\n            print(f'{model.name} bias:', model.split_bias)\n"
  },
  {
    "path": "auto_LiRPA/bound_general.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\nimport copy\nfrom typing import List\nimport numpy as np\nimport warnings\nfrom collections import OrderedDict, deque\n\nimport torch\nfrom torch.nn import Parameter\n\nfrom .bound_op_map import bound_op_map\nfrom .bound_ops import *\nfrom .bounded_tensor import BoundedTensor, BoundedParameter\nfrom .parse_graph import parse_module\nfrom .perturbations import *\nfrom .utils import *\nfrom .patches import Patches\nfrom .optimized_bounds import default_optimize_bound_args\n\n\nwarnings.simplefilter('once')\n\n\nclass BoundedModule(nn.Module):\n    \"\"\"Bounded module with support for automatically computing bounds.\n\n    Args:\n        model (nn.Module): The original model to be wrapped by BoundedModule.\n\n        global_input (tuple): A dummy input to the original model. The shape of\n        the dummy input should be consistent with the actual input to the model\n        except for the batch dimension.\n\n        bound_opts (dict): Options for bounds. 
See\n        `Bound Options <bound_opts.html>`_.\n\n        device (str or torch.device): Device of the bounded module.\n        If 'auto', the device will be automatically inferred from the device of\n        parameters in the original model or the dummy input.\n\n        custom_ops (dict): A dictionary of custom operators.\n        The dictionary maps operator names to their corresponding bound classes\n        (subclasses of `Bound`).\n\n    \"\"\"\n    def __init__(self, model, global_input, bound_opts=None,\n                device='auto', verbose=False, custom_ops=None):\n        super().__init__()\n        if isinstance(model, BoundedModule):\n            for key in model.__dict__.keys():\n                setattr(self, key, getattr(model, key))\n            return\n\n        self.ori_training = model.training\n\n        if bound_opts is None:\n            bound_opts = {}\n        # Default options.\n        default_bound_opts = {\n            'conv_mode': 'patches',\n            'sparse_intermediate_bounds': True,\n            'sparse_conv_intermediate_bounds': True,\n            'sparse_intermediate_bounds_with_ibp': True,\n            'sparse_features_alpha': True,\n            'sparse_spec_alpha': True,\n            'minimum_sparsity': 0.9,\n            'enable_opt_interm_bounds': False,\n            'crown_batch_size': np.inf,\n            'forward_refinement': False,\n            'forward_max_dim': int(1e9),\n            # Do not share alpha for conv layers.\n            'use_full_conv_alpha': True,\n            'disabled_optimization': [],\n            # Threshold for the number of unstable neurons for each layer to disable\n            #  use_full_conv_alpha.\n            'use_full_conv_alpha_thresh': 512,\n            'verbosity': 1 if verbose else 0,\n            'optimize_graph': {'optimizer': None},\n            'compare_crown_with_ibp': False,\n            # Whether to run an additional forward pass before computing bounds.\n            
'forward_before_compute_bounds': False,\n            'clip_in_alpha_crown': False,\n            # Whether to compute bounds for every node in the graph.\n            # (rather than only the nodes whose intermediate bounds are needed.)\n            'bound_every_node': False,\n        }\n        default_bound_opts.update(bound_opts)\n        self.bound_opts = default_bound_opts\n        optimize_bound_args = copy.deepcopy(default_optimize_bound_args)\n        optimize_bound_args.update(\n            self.bound_opts.get('optimize_bound_args', {}))\n        self.bound_opts.update({'optimize_bound_args': optimize_bound_args})\n\n        self.verbose = verbose\n        self.custom_ops = custom_ops if custom_ops is not None else {}\n        if device == 'auto':\n            try:\n                self.device = next(model.parameters()).device\n            except StopIteration:\n                # Model has no parameters. We use the device of input tensor.\n                if isinstance(global_input, torch.Tensor):\n                    self.device = global_input.device\n                elif isinstance(global_input, tuple):\n                    self.device = global_input[0].device\n                else:\n                    raise NotImplementedError( # pylint: disable=raise-missing-from\n                        'Unable to decide the device. 
Consider providing a '\n                        '`device` argument to `BoundedModule` explicitly.')\n        else:\n            self.device = device\n\n        self.global_input = tuple(unpack_inputs(global_input, device=self.device))\n        self.check_incompatible_nodes(model)\n\n        self.conv_mode = self.bound_opts.get('conv_mode', 'patches')\n        # Cached IBP results which may be reused\n        self.ibp_lower, self.ibp_upper = None, None\n\n        self.optimizable_activations = []\n        self.relus = []  # save relu layers for convenience\n        self.layers_with_constraint = []\n\n        state_dict_copy = copy.deepcopy(model.state_dict())\n        object.__setattr__(self, 'ori_state_dict', state_dict_copy)\n        model.to(self.device)\n        output = model(*self.global_input)\n        if not isinstance(output, torch.Tensor):\n            raise TypeError(\n                'Output of the model is expected to be a single torch.Tensor. '\n                f'Actual type: {type(output)}')\n        self.final_shape = output.shape\n        self.bound_opts.update({'final_shape': self.final_shape})\n        self._convert(model, self.global_input)\n        self._optimize_graph()\n        # Compute forward_value and mark perturbed nodes\n        self.forward(*self.global_input)\n        self._expand_jacobian()\n        self._check_patches_mode()\n\n        self.next_split_hint = []  # Split hints, used in beta optimization.\n        # Beta values for all intermediate bounds.\n        # Set to None (not used) by default.\n        self.best_intermediate_betas = None\n        # Initialization value for intermediate betas.\n        self.init_intermediate_betas = None\n        # whether using cut\n        self.cut_used = False\n        # a placeholder for cut timestamp, which would be a non-positive int\n        self.cut_timestamp = -1\n        # a placeholder to save the latest samplewise mask for\n        # pruning-in-iteration optimization\n        
self.last_update_preserve_mask = None\n        # If output constraints are used, it is possible that none of the possible\n        # inputs satisfy them. In this case, the lower bounds will be set to +inf,\n        # and the upper bounds to -inf.\n        self.infeasible_bounds = None\n        self.solver_model = None\n        # Needed for output constraints - the output layer should not use them\n        self.final_node().is_final_node = True\n        self.dynamic = False\n        # This is the topk ratio for half-naive, half-constrained concretization.\n        # Please check concretize_bounds.py for more details.\n        self.clip_neuron_selection_type = 'ratio'\n        self.clip_neuron_selection_value = -1.0\n        # A boolean tensor with shape (batchsize, ). It indicates if a batch is\n        # infeasible when concretizing with constraints.\n        # Always call the `init_infeasible_bounds_constraints` function to initialize it.\n        self.infeasible_bounds_constraints = None\n\n        # This is designed for clipping during alpha-CROWN.\n        # For each alpha-CROWN optimization iteration, the lA and lbias of the final layer\n        #   will be set as `constraints_optimized` for the next iteration.\n        # Please check backward_bound.py and optimized_bounds.py for more info.\n        self.constraints_optimized = None\n\n    def nodes(self) -> List[Bound]:\n        return self._modules.values()\n\n    def get_enabled_opt_act(self):\n        # Optimizable activations that are actually used and perturbed\n        return [\n            n for n in self.optimizable_activations\n            if n.used and n.perturbed and not getattr(n, 'is_linear_op', False)\n        ]\n\n    def get_optimizable_activations(self):\n        for node in self.nodes():\n            if (isinstance(node, BoundOptimizableActivation)\n                    and node.optimizable\n                    and len(getattr(node, 'requires_input_bounds', [])) > 0\n                    and node 
not in self.optimizable_activations):\n                disabled = False\n                for item in self.bound_opts.get('disable_optimization', []):\n                    if item.lower() in str(type(node)).lower():\n                        disabled = True\n                if disabled:\n                    logging.debug('Disabled optimization for %s', node)\n                    continue\n                # Membership was already checked in the condition above.\n                self.optimizable_activations.append(node)\n\n    def get_perturbed_optimizable_activations(self):\n        return [n for n in self.optimizable_activations if n.perturbed]\n\n    def get_splittable_activations(self):\n        \"\"\"Activation functions that can be split during branch and bound.\"\"\"\n        return [n for n in self.nodes() if n.perturbed and n.splittable and n.used]\n\n    def get_layers_requiring_bounds(self):\n        \"\"\"Layers whose intermediate bounds are required.\"\"\"\n        intermediate_layers = []\n        tighten_input_bounds = (\n            self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n        )\n        directly_optimize_layer_names = (\n            self.bound_opts['optimize_bound_args']['directly_optimize']\n        )\n        for node in self.nodes():\n            if node.name in directly_optimize_layer_names:\n                intermediate_layers.append(node)\n            if not node.used or not node.perturbed:\n                continue\n            for i in getattr(node, 'requires_input_bounds', []):\n                input_node = node.inputs[i]\n                if (input_node not in intermediate_layers\n                        and input_node.perturbed):\n                    # If not perturbed, it may not have the batch dimension,\n                    # so we do not include it; its bounds are unnecessary.\n                    intermediate_layers.append(input_node)\n            if (\n                node.name in self.layers_with_constraint\n   
             or (isinstance(node, BoundInput) and tighten_input_bounds)\n            ):\n                if node not in intermediate_layers:\n                    intermediate_layers.append(node)\n        return intermediate_layers\n\n    def check_incompatible_nodes(self, model):\n        \"\"\"Check whether the model has incompatible nodes that may make the\n        conversion inaccurate.\"\"\"\n        node_types = [type(m) for m in list(model.modules())]\n\n        if (torch.nn.Dropout in node_types\n                and torch.nn.BatchNorm1d in node_types\n                and any(input.shape[0] == 1 for input in self.global_input)):\n            # In fact, we just need the input that is involved in the\n            # dropout layer to have batch size larger than 1, but we don't know\n            # which one it is, so we check all of them.\n            print('We cannot support torch.nn.Dropout and torch.nn.BatchNorm1d '\n                  'at the same time!')\n            print('We suggest using another dummy input with batch size '\n                  'larger than 1 and setting the model to train() mode.')\n            return\n\n        if not self.ori_training and torch.nn.Dropout in node_types:\n            print('Dropout operation CANNOT be parsed during conversion when '\n                  'the model is in eval() mode!')\n            print('Setting the model to train() mode.')\n            self.ori_training = True\n\n        if self.ori_training and torch.nn.BatchNorm1d in node_types:\n            print('BatchNorm1d may raise an error during conversion when the '\n                  'model is in train() mode!')\n            print('Setting the model to eval() mode.')\n            self.ori_training = False\n\n    def non_deter_wrapper(self, op, *args, **kwargs):\n        \"\"\"Some operations are non-deterministic and deterministic mode will\n        fail. 
So we temporarily disable it.\"\"\"\n        if self.bound_opts.get('deterministic', False):\n            torch.use_deterministic_algorithms(False)\n        ret = op(*args, **kwargs)\n        if self.bound_opts.get('deterministic', False):\n            torch.use_deterministic_algorithms(True)\n        return ret\n\n    def non_deter_scatter_add(self, *args, **kwargs):\n        return self.non_deter_wrapper(torch.scatter_add, *args, **kwargs)\n\n    def non_deter_index_select(self, *args, **kwargs):\n        return self.non_deter_wrapper(torch.index_select, *args, **kwargs)\n\n    def set_bound_opts(self, new_opts):\n        for k, v in new_opts.items():\n            # Merge nested option dicts (e.g., 'optimize_bound_args').\n            if isinstance(v, dict):\n                self.bound_opts[k].update(v)\n            else:\n                self.bound_opts[k] = v\n\n    def set_gcp_relu_indicators(self, relu_layer_name, relu_indicators):\n        \"\"\"\n        Sets the GCP (Generalized Cutting Plane) relu indicators for\n        the specified ReLU layer by name.\n        Args:\n            relu_layer_name (str):\n                The name of the ReLU layer to update.\n            relu_indicators (torch.Tensor):\n                A tensor containing unstable relu indices or masks.\n        \"\"\"\n        # Search for the layer by name\n        for m in self.relus:\n            if m.name == relu_layer_name:\n                # Set the indicators for the found ReLU layer\n                m.gcp_unstable_relu_indicators = relu_indicators\n                return\n        # If not found, raise an error\n        raise ValueError(f'No ReLU layer found with name {relu_layer_name}')\n\n    @staticmethod\n    def _get_A_norm(A):\n        if not isinstance(A, (list, tuple)):\n            A = (A, )\n        norms = []\n        for aa in A:\n            if aa is not None:\n                if isinstance(aa, Patches):\n                    aa = aa.patches\n                
norms.append(aa.abs().sum().item())\n            else:\n                norms.append(None)\n        return norms\n\n    def __call__(self, *input, **kwargs):\n        if 'method_opt' in kwargs:\n            opt = kwargs['method_opt']\n            kwargs.pop('method_opt')\n        else:\n            opt = 'forward'\n        for kwarg in [\n            'disable_multi_gpu', 'no_replicas', 'get_property',\n            'node_class', 'att_name']:\n            if kwarg in kwargs:\n                kwargs.pop(kwarg)\n        if opt == 'compute_bounds':\n            return self.compute_bounds(**kwargs)\n        else:\n            return self.forward(*input, **kwargs)\n\n    def register_parameter(self, name, param):\n        r\"\"\"Adds a parameter to the module.\n\n        The parameter can be accessed as an attribute using given name.\n\n        Args:\n            name (string): name of the parameter. The parameter can be accessed\n                from this module using the given name\n            param (Parameter): parameter to be added to the module.\n        \"\"\"\n        if '_parameters' not in self.__dict__:\n            raise AttributeError(\n                'cannot assign parameter before Module.__init__() call')\n        elif not isinstance(name, str):\n            raise TypeError('parameter name should be a string. 
'\n                            f'Got {torch.typename(name)}')\n        elif name == '':\n            raise KeyError('parameter name can\\'t be empty string')\n        elif hasattr(self, name) and name not in self._parameters:\n            raise KeyError(f'attribute \"{name}\" already exists')\n\n        if param is None:\n            self._parameters[name] = None\n        elif not isinstance(param, Parameter):\n            raise TypeError(\n                f'cannot assign \"{torch.typename(param)}\" object to '\n                f'parameter \"{name}\" '\n                '(torch.nn.Parameter or None required)')\n        elif param.grad_fn:\n            raise ValueError(\n                f'Cannot assign non-leaf Tensor to parameter \"{name}\". Model '\n                f'parameters must be created explicitly. To express \"{name}\" '\n                'as a function of another Tensor, compute the value in '\n                'the forward() method.')\n        else:\n            self._parameters[name] = param\n\n    def _named_members(self,\n                       get_members_fn,\n                       prefix='',\n                       recurse=True,\n                       remove_duplicate: bool = True,\n                       **kwargs):  # pylint: disable=unused-argument\n        r\"\"\"Helper method for yielding various names + members of modules.\"\"\"\n        memo = set()\n        modules = self.named_modules(prefix=prefix) if recurse else [\n                                     (prefix, self)]\n        for module_prefix, module in modules:\n            members = get_members_fn(module)\n            for k, v in members:\n                if v is None or v in memo:\n                    continue\n                if remove_duplicate:\n                    memo.add(v)\n                name = module_prefix + ('.' 
if module_prefix else '') + k\n                # Translate name to ori_name\n                if name in self.node_name_map:\n                    name = self.node_name_map[name]\n                yield name, v\n\n    def train(self, mode=True):\n        super().train(mode)\n        for node in self.nodes():\n            node.train(mode=mode)\n\n    def eval(self):\n        super().eval()\n        for node in self.nodes():\n            node.eval()\n\n    def to(self, *args, **kwargs):\n        # Moves and/or casts some extra attributes beyond what PyTorch handles by default.\n        for node in self.nodes():\n            for attr in ['lower', 'upper', 'forward_value', 'd', 'lA',]:\n                if hasattr(node, attr):\n                    this_attr = getattr(node, attr)\n                    if isinstance(this_attr, torch.Tensor):\n                        this_attr = this_attr.to(*args, **kwargs)\n                        setattr(node, attr, this_attr)\n\n            if hasattr(node, 'interval'):\n                # construct new interval\n                this_attr = getattr(node, 'interval')\n                setattr(node, 'interval', (this_attr[0].to(\n                    *args, **kwargs), this_attr[1].to(*args, **kwargs)))\n\n        return super().to(*args, **kwargs)\n\n    def __getitem__(self, name):\n        module = self._modules[name]\n        # We never create modules that are None; the assert fixes type hints\n        assert module is not None\n        return module\n\n    def roots(self):\n        return [self[name] for name in self.root_names]\n\n    def final_node(self):\n        return self[self.final_name]\n\n    def get_forward_value(self, node):\n        \"\"\" Recursively get `forward_value` for `node` and its parent nodes\"\"\"\n        if getattr(node, 'forward_value', None) is not None:\n            return node.forward_value\n        inputs = [self.get_forward_value(inp) for inp in node.inputs]\n        for inp in node.inputs:\n            node.from_input = 
node.from_input or inp.from_input\n        node.input_shape = inputs[0].shape if len(inputs) > 0 else None\n        fv = node.forward(*inputs)\n        if isinstance(fv, (torch.Size, tuple)):\n            fv = torch.tensor(fv, device=self.device)\n        node.forward_value = fv\n        node.output_shape = fv.shape\n        # In most cases, the batch dimension is just the first dimension\n        # if the node depends on the input. Otherwise, if the node doesn't\n        # depend on the input, there is no batch dimension (default is -1).\n        node.batch_dim = 0 if node.from_input else node.batch_dim\n        # Unperturbed node but it is not a root node.\n        # Save forward_value to value. (Can be used in forward bounds.)\n        if not node.from_input and len(node.inputs) > 0:\n            node.value = node.forward_value\n        return fv\n\n    def forward(self, *x, final_node_name=None,\n                interm_bounds=None,\n                clear_forward_only=False,\n                reset_perturbed_nodes=True,\n                cache_bounds=False):\n        r\"\"\"Standard forward computation for the network.\n\n        Args:\n            x (tuple or None): Input to the model.\n\n            final_node_name (str, optional): The name of the final node in the\n            model. The value on the corresponding node will be returned.\n\n            clear_forward_only (bool, default `False`): If `True`, only\n            standard forward values stored on the nodes will be cleared.\n            Otherwise, bound information on the nodes will also be cleared.\n\n            reset_perturbed_nodes (bool, default `True`): Mark all nodes\n            affected by input perturbations as perturbed. 
When set to `True`, it may\n            accidentally clear all .perturbed properties for intermediate\n            nodes.\n\n        Returns:\n            output: The output of the model, or if `final_node_name` is not\n            `None`, return the value on the corresponding node instead.\n        \"\"\"\n        self.set_input(*x,\n                       interm_bounds=interm_bounds,\n                       clear_forward_only=clear_forward_only,\n                       reset_perturbed_nodes=reset_perturbed_nodes,\n                       cache_bounds=cache_bounds)\n        if final_node_name is None:\n            final_node_name = self.output_name[0]\n        return self.get_forward_value(self[final_node_name])\n\n    def _mark_perturbed_nodes(self, input):\n        \"\"\"Mark the graph nodes and determine which nodes need perturbation.\"\"\"\n        # Set some of the inputs as perturbed if they are bounded objects\n        any_perturbed = False\n        for name, index in zip(self.input_name, self.input_index):\n            if index is None:\n                continue\n            if isinstance(input[index], (BoundedTensor, BoundedParameter)):\n                self[name].perturbed = True\n                any_perturbed = True\n        # If none of the inputs is a bounded object, set all of them as perturbed\n        if not any_perturbed:\n            for name, index in zip(self.input_name, self.input_index):\n                if index is not None:\n                    self[name].perturbed = True\n\n        degree_in = {}\n        queue = deque()\n        relus = []\n        # Initially the queue contains all \"root\" nodes.\n        for key in self._modules.keys():\n            l = self[key]\n            degree_in[l.name] = len(l.inputs)\n            if degree_in[l.name] == 0:\n                queue.append(l)  # in_degree == 0 -> root node\n\n        while len(queue) > 0:\n            node = queue.popleft()\n            # We collect ReLU nodes here to ensure the list is 
sorted according to topological order.\n            if isinstance(node, BoundRelu):\n                relus.append(node)\n            # Obtain all output nodes, and add an output node to the queue\n            # once all of its input nodes have been visited.\n            # The initial \"perturbed\" property is set in BoundInput or\n            # BoundParams object, depending on ptb.\n            for name_next in node.output_name:\n                node_next = self[name_next]\n                if not node_next.never_perturbed:\n                    # The next node is perturbed if it is already perturbed,\n                    # or this node is perturbed.\n                    node_next.perturbed = node_next.perturbed or node.perturbed\n                degree_in[name_next] -= 1\n                # All inputs of this node have been visited;\n                # now put it in the queue.\n                if degree_in[name_next] == 0:\n                    queue.append(node_next)\n            node.update_requires_input_bounds()\n\n        self.relus = relus\n        self.get_optimizable_activations()\n        self.splittable_activations = self.get_splittable_activations()\n        self.perturbed_optimizable_activations = (\n            self.get_perturbed_optimizable_activations())\n        return\n\n    def _check_patches_mode(self):\n        \"\"\"Disable patches mode if there is no Conv node.\n\n        This is a workaround (before a more general patches mode is implemented)\n        to avoid issues related to patches nodes,\n        for complicated models without any Conv layer.\n        \"\"\"\n        has_conv = False\n        for node in self.nodes():\n            if isinstance(node, (BoundConv, BoundConvTranspose, BoundConv2dGrad)):\n                has_conv = True\n        if not has_conv and self.conv_mode == 'patches':\n            self.conv_mode = 'matrix'\n            for node in self.nodes():\n                if getattr(node, 'mode', None) == 'patches':\n                    
node.mode = 'matrix'\n\n    def _clear_and_set_new(\n        self,\n        interm_bounds,\n        clear_forward_only=False,\n        reset_perturbed_nodes=True,\n        cache_bounds=False,\n    ):\n        for l in self.nodes():\n            if hasattr(l, 'linear'):\n                # Deleting the attribute releases the stored linear bounds.\n                delattr(l, 'linear')\n\n            if hasattr(l, 'patch_size'):\n                l.patch_size = {}\n\n            if clear_forward_only:\n                if hasattr(l, 'forward_value'):\n                    delattr(l, 'forward_value')\n            else:\n                for attr in ['interval', 'forward_value', 'd',\n                             'lA', 'lower_d', 'upper_k']:\n                    if hasattr(l, attr):\n                        delattr(l, attr)\n                if cache_bounds:\n                    l.move_lower_and_upper_bounds_to_cache()\n                else:\n                    l.delete_lower_and_upper_bounds()\n\n            for attr in ['zero_backward_coeffs_l', 'zero_backward_coeffs_u',\n                         'zero_lA_mtx', 'zero_uA_mtx']:\n                setattr(l, attr, False)\n            # If an interval is given here, IBP/CROWN will start from this node\n            if interm_bounds is not None and l.name in interm_bounds.keys():\n                l.interval = tuple(interm_bounds[l.name][:2])\n                l.lower = interm_bounds[l.name][0]\n                l.upper = interm_bounds[l.name][1]\n                if l.lower is not None:\n                    l.lower = l.lower.detach().requires_grad_(False)\n                if l.upper is not None:\n                    l.upper = l.upper.detach().requires_grad_(False)\n            # Mark all nodes as non-perturbed except for weights.\n            if reset_perturbed_nodes:\n                if not hasattr(l, 'perturbation') or l.perturbation is None:\n                    l.perturbed = False\n\n 
           # Clear operator-specific attributes\n            l.clear()\n\n    def set_input(\n        self,\n        *x,\n        interm_bounds=None,\n        clear_forward_only=False,\n        reset_perturbed_nodes=True,\n        cache_bounds=False,\n    ):\n        self._clear_and_set_new(\n            interm_bounds=interm_bounds,\n            clear_forward_only=clear_forward_only,\n            reset_perturbed_nodes=reset_perturbed_nodes,\n            cache_bounds=cache_bounds,\n        )\n        inputs_unpacked = unpack_inputs(x)\n        for name, index in zip(self.input_name, self.input_index):\n            if index is None:\n                continue\n            node = self[name]\n            node.value = inputs_unpacked[index]\n            if isinstance(node.value, (BoundedTensor, BoundedParameter)):\n                node.perturbation = node.value.ptb\n            else:\n                node.perturbation = None\n        # Mark all perturbed nodes.\n        if reset_perturbed_nodes:\n            self._mark_perturbed_nodes(inputs_unpacked)\n\n    def _get_node_input(self, nodesOP, nodesIn, node):\n        ret = []\n        for i in range(len(node.inputs)):\n            for op in nodesOP:\n                if op.name == node.inputs[i]:\n                    ret.append(op.bound_node)\n                    break\n            if len(ret) == i + 1:\n                continue\n            for io in nodesIn:\n                if io.name == node.inputs[i]:\n                    ret.append(io.bound_node)\n                    break\n            if len(ret) <= i:\n                raise ValueError(f'cannot find inputs of node: {node.name}')\n        return ret\n\n    def _to(self, obj, dest, inplace=False):\n        \"\"\" Move all tensors in the object to a specified dest\n        (device or dtype). 
The inplace=True option is available for dict.\"\"\"\n        if obj is None:\n            return obj\n        elif isinstance(obj, torch.Tensor):\n            return obj.to(dest)\n        elif isinstance(obj, Patches):\n            return obj.patches.to(dest)\n        elif isinstance(obj, tuple):\n            return tuple([self._to(item, dest) for item in obj])\n        elif isinstance(obj, list):\n            return list([self._to(item, dest) for item in obj])\n        elif isinstance(obj, dict):\n            if inplace:\n                for k, v in obj.items():\n                    obj[k] = self._to(v, dest, inplace=True)\n                return obj\n            else:\n                return {k: self._to(v, dest) for k, v in obj.items()}\n        else:\n            raise NotImplementedError(type(obj))\n\n    def _convert_nodes(self, model, global_input):\n        r\"\"\"\n        Returns:\n            nodesOP (list): List of operator nodes\n            nodesIn (list): List of input nodes\n            nodesOut (list): List of output nodes\n            template (object): Template to specify the output format\n        \"\"\"\n        global_input_cpu = self._to(global_input, 'cpu')\n        if self.ori_training:\n            model.train()\n        else:\n            model.eval()\n        model.to('cpu')\n        nodesOP, nodesIn, nodesOut, template = parse_module(\n            model, global_input_cpu)\n        model.to(self.device)\n        for i in range(0, len(nodesIn)):\n            if nodesIn[i].param is not None:\n                nodesIn[i] = nodesIn[i]._replace(\n                    param=nodesIn[i].param.to(self.device))\n\n        # Convert input nodes and parameters.\n        attr = {'device': self.device}\n        for i, n in enumerate(nodesIn):\n            if n.input_index is not None:\n                nodesIn[i] = nodesIn[i]._replace(bound_node=BoundInput(\n                    ori_name=nodesIn[i].ori_name,\n                    
value=global_input[nodesIn[i].input_index],\n                    perturbation=nodesIn[i].perturbation,\n                    input_index=n.input_index, options=self.bound_opts,\n                    attr=attr))\n            else:\n                bound_class = BoundParams if isinstance(\n                    nodesIn[i].param, nn.Parameter) else BoundBuffers\n                nodesIn[i] = nodesIn[i]._replace(bound_node=bound_class(\n                    ori_name=nodesIn[i].ori_name, value=nodesIn[i].param,\n                    perturbation=nodesIn[i].perturbation, options=self.bound_opts,\n                    attr=attr))\n\n        unsupported_ops = []\n\n        # Convert other operation nodes.\n        for n in range(len(nodesOP)):\n            attr = nodesOP[n].attr\n            inputs = self._get_node_input(nodesOP, nodesIn, nodesOP[n])\n            try:\n                if nodesOP[n].op in self.custom_ops:\n                    op = self.custom_ops[nodesOP[n].op]\n                elif nodesOP[n].op in bound_op_map:\n                    op = bound_op_map[nodesOP[n].op]\n                elif nodesOP[n].op.startswith('aten::ATen'):\n                    op = globals()[f'BoundATen{attr[\"operator\"].capitalize()}']\n                elif nodesOP[n].op.startswith('onnx::'):\n                    op = globals()[f'Bound{nodesOP[n].op[6:]}']\n                else:\n                    raise KeyError\n            except (NameError, KeyError):\n                unsupported_ops.append(nodesOP[n])\n                logger.error('The node has an unsupported operation: %s',\n                             nodesOP[n])\n                continue\n            attr['device'] = self.device\n\n            # FIXME generalize\n            if (nodesOP[n].op == 'onnx::BatchNormalization'\n                    or getattr(op, 'TRAINING_FLAG', False)):\n                # The BatchNormalization node needs the model.training flag to set\n                # running means/vars. We set training=False to avoid 
wrongly\n                # updating running means/vars during the bound wrapper.\n                nodesOP[n] = nodesOP[n]._replace(bound_node=op(\n                    attr, inputs, nodesOP[n].output_index, self.bound_opts,\n                    False))\n            else:\n                nodesOP[n] = nodesOP[n]._replace(bound_node=op(\n                    attr, inputs, nodesOP[n].output_index, self.bound_opts))\n\n        if unsupported_ops:\n            logger.error('Unsupported operations:')\n            for n in unsupported_ops:\n                logger.error(f'Name: {n.op}, Attr: {n.attr}')\n            raise NotImplementedError('There are unsupported operations')\n\n        for node in nodesIn + nodesOP:\n            node.bound_node.name = node.name\n\n        nodes_dict = {}\n        for node in nodesOP + nodesIn:\n            nodes_dict[node.name] = node.bound_node\n        nodesOP = [n.bound_node for n in nodesOP]\n        nodesIn = [n.bound_node for n in nodesIn]\n        nodesOut = [nodes_dict[n] for n in nodesOut]\n\n        return nodesOP, nodesIn, nodesOut, template\n\n    def _build_graph(self, nodesOP, nodesIn, nodesOut, template):\n        # We assume that the original model has only one output node.\n        assert len(nodesOut) == 1\n        self.final_name = nodesOut[0].name\n        self.input_name, self.input_index, self.root_names = [], [], []\n        self.output_name = [n.name for n in nodesOut]\n        self.output_template = template\n        self._modules.clear()\n        for node in nodesIn:\n            self.add_input_node(node, index=node.input_index)\n        self.add_nodes(nodesOP)\n        if self.conv_mode == 'patches':\n            self.root_names: List[str] = [node.name for node in nodesIn]\n\n    def rename_nodes(self, nodesOP, nodesIn, rename_dict):\n        def rename(node):\n            node.name = rename_dict[node.name]\n            return node\n        for i in range(len(nodesOP)):\n            nodesOP[i] = 
rename(nodesOP[i])\n        for i in range(len(nodesIn)):\n            nodesIn[i] = rename(nodesIn[i])\n\n    def _split_complex(self, nodesOP, nodesIn):\n        finished = True\n        for n in range(len(nodesOP)):\n            if hasattr(nodesOP[n], 'complex') and nodesOP[n].complex:\n                complex_node = nodesOP[n]\n\n                finished = False\n                _nodesOP, _nodesIn, _nodesOut, _ = self._convert_nodes(\n                    nodesOP[n].model, nodesOP[n].input)\n                # assuming each supported complex operation only has one output\n                assert len(_nodesOut) == 1\n\n                name_base = nodesOP[n].name + '/split'\n                rename_dict = {}\n                for node in _nodesOP + _nodesIn:\n                    rename_dict[node.name] = name_base + node.name\n                num_inputs = len(nodesOP[n].inputs)\n                for i in range(num_inputs):\n                    rename_dict[_nodesIn[i].name] = nodesOP[n].input_name[i]\n                rename_dict[_nodesOP[-1].name] = nodesOP[n].name\n\n                self.rename_nodes(_nodesOP, _nodesIn, rename_dict)\n\n                output_name = _nodesOP[-1].name\n                # Any input node of some node within the complex node should be\n                # replaced with the complex node's corresponding input node.\n                for node in _nodesOP:\n                    for i in range(len(node.inputs)):\n                        if node.input_name[i] in nodesOP[n].input_name:\n                            index = nodesOP[n].input_name.index(\n                                node.input_name[i])\n                            node.inputs[i] = nodesOP[n].inputs[index]\n                # For any output node of this complex node,\n                # modify its input node.\n                for node in nodesOP:\n                    if output_name in node.input_name:\n                        index = node.input_name.index(output_name)\n                      
  node.inputs[index] = _nodesOP[-1]\n                # Mark where the nodes come from\n                for node in _nodesOP:\n                    node.from_complex_node = type(complex_node).__name__\n\n                nodesOP = nodesOP[:n] + _nodesOP + nodesOP[(n + 1):]\n                nodesIn = nodesIn + _nodesIn[num_inputs:]\n\n                break\n\n        return nodesOP, nodesIn, finished\n\n    def _get_node_name_map(self):\n        \"\"\"Build a dict with {ori_name: name, name: ori_name}\"\"\"\n        self.node_name_map = {}\n        for node in self.nodes():\n            if isinstance(node, (BoundInput, BoundParams)):\n                for p in list(node.named_parameters()):\n                    if node.ori_name not in self.node_name_map:\n                        name = f'{node.name}.{p[0]}'\n                        self.node_name_map[node.ori_name] = name\n                        self.node_name_map[name] = node.ori_name\n                for p in list(node.named_buffers()):\n                    if node.ori_name not in self.node_name_map:\n                        name = f'{node.name}.{p[0]}'\n                        self.node_name_map[node.ori_name] = name\n                        self.node_name_map[name] = node.ori_name\n\n    # Convert a PyTorch model to a model with bounds\n    def _convert(self, model, global_input):\n        if self.verbose:\n            logger.info('Converting the model...')\n\n        self.num_global_inputs = len(global_input)\n\n        nodesOP, nodesIn, nodesOut, template = self._convert_nodes(\n            model, global_input)\n        global_input = self._to(global_input, self.device)\n\n        while True:\n            self._build_graph(nodesOP, nodesIn, nodesOut, template)\n            self.forward(*global_input)  # running means/vars changed\n            nodesOP, nodesIn, finished = self._split_complex(nodesOP, nodesIn)\n            if finished:\n                break\n\n        self._get_node_name_map()\n\n        
ori_state_dict_mapped = OrderedDict()\n        for k, v in self.ori_state_dict.items():\n            if k in self.node_name_map:\n                ori_state_dict_mapped[self.node_name_map[k]] = v\n        self.load_state_dict(ori_state_dict_mapped)\n        if self.ori_training:\n            model.load_state_dict(self.ori_state_dict)\n        delattr(self, 'ori_state_dict')\n\n        # The name of the final node used in the last call to `compute_bounds`\n        self.last_final_node_name = None\n\n        if self.verbose:\n            logger.info('Model converted to support bounds')\n\n    def check_prior_bounds(self, node, C=None):\n        if node.prior_checked or not (node.used and node.perturbed):\n            return\n        if C is not None and isinstance(node, BoundConcat):\n            # If the last node is a BoundConcat, it's possible that only some of\n            # the input nodes of the BoundConcat are needed in the specification.\n            # In this case, we only check the bounds of the input nodes that are\n            # actually used in the specification. 
All other branches are\n            # considered not used, and their bounds are not checked.\n            # FIXME: In this case, node.used of some nodes may be incorrect.\n            offset = 0\n            assert isinstance(C, torch.Tensor) and C.ndim == 3\n            C = C.abs().sum(dim=[0, 1])\n            for node_input in node.inputs:\n                size = prod(node_input.output_shape[1:])\n                C_s = C[offset:offset+size].sum()\n                if (C_s != 0).any():\n                    self.check_prior_bounds(node_input)\n                offset += size\n        else:\n            for n in node.inputs:\n                self.check_prior_bounds(n)\n        tighten_input_bounds = (\n            self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n        )\n        directly_optimize_layer_names = (\n            self.bound_opts['optimize_bound_args']['directly_optimize']\n        )\n        bound_every_node = (\n            self.bound_opts['bound_every_node']\n        )\n        for i in range(len(node.inputs)):\n            if (\n                i in node.requires_input_bounds\n                or not node.inputs[i].perturbed\n                or node.inputs[i].name in self.layers_with_constraint\n                # allows tightening input bounds\n                or (isinstance(node.inputs[i], BoundInput) and tighten_input_bounds)\n                # layers whose optimization is forced\n                # (for consecutive layers introduced as part of invprop)\n                or node.inputs[i].name in directly_optimize_layer_names\n                or bound_every_node\n            ):\n                self.compute_intermediate_bounds(\n                    node.inputs[i], prior_checked=True)\n        node.prior_checked = True\n\n    def compute_intermediate_bounds(self, node: Bound, prior_checked=False):\n        tighten_input_bounds = (\n            self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n        )\n        
directly_optimize_layer_names = (\n            self.bound_opts['optimize_bound_args']['directly_optimize']\n        )\n        best_of_oc_and_no_oc = (\n            self.bound_opts['optimize_bound_args']['best_of_oc_and_no_oc']\n        )\n        if (\n            node.is_lower_bound_current()\n            and not (\n                isinstance(node, BoundInput) and tighten_input_bounds\n                or node.name in directly_optimize_layer_names\n            )\n        ):\n            if node.name in self.layers_with_constraint:\n                node.clamp_interim_bounds()\n            return\n\n        logger.debug(f'Getting the bounds of {node}')\n\n        if not prior_checked:\n            self.check_prior_bounds(node)\n\n        if not node.perturbed:\n            fv = self.get_forward_value(node)\n            node.interval = node.lower, node.upper = fv, fv\n            return\n\n        # FIXME check that weight perturbation is not affected\n        #      (from_input=True should be set for weights)\n        if not node.from_input and hasattr(node, 'forward_value'):\n            node.lower = node.upper = self.get_forward_value(node)\n            return\n\n        reference_bounds = self.reference_bounds\n\n        if self.use_forward:\n            # forward\n            node.lower, node.upper = self.forward_general(\n                node=node, concretize=True)\n        else:\n            # backward\n            if self.check_IBP_intermediate(node):\n                # Intermediate bounds for some operators are directly\n                # computed from their input nodes by IBP\n                # (such as BoundRelu, BoundNeg)\n                logger.debug('IBP propagation for intermediate bounds on %s', node)\n            # For the first linear layer, IBP can give the same tightness as CROWN.\n            elif not self.check_IBP_first_linear(node):\n                ref_intermediate = self.get_ref_intermediate_bounds(node)\n                sparse_C = 
self.get_sparse_C(node, ref_intermediate)\n                newC, reduced_dim, unstable_idx, unstable_size = sparse_C\n\n                # Special case for BoundRelu when sparse intermediate bounds are disabled\n                # Currently sparse intermediate bounds are restricted to ReLU models only\n                skip = False\n                if unstable_idx is None:\n                    if (len(node.output_name) == 1\n                            and isinstance(self[node.output_name[0]], BoundTwoPieceLinear)\n                            and node.name in self.reference_bounds):\n                        lower, upper = self.reference_bounds[node.name]\n                        fully_stable = torch.logical_or(lower>=0, upper<=0).all()\n                        if fully_stable:\n                            node.lower, node.upper = lower, upper\n                            skip = True\n                elif unstable_size == 0:\n                    skip = True\n\n                if not skip:\n                    apply_output_constraints_to = self.bound_opts[\n                        'optimize_bound_args']['apply_output_constraints_to']\n                    if self.return_A:\n                        node.lower, node.upper, _ = self.backward_general(\n                            node, newC, unstable_idx=unstable_idx,\n                            apply_output_constraints_to=apply_output_constraints_to)\n                    else:\n                        # Compute backward bounds only when there are unstable\n                        # neurons, or when we don't know which neurons are unstable.\n                        node.lower, node.upper = self.backward_general(\n                            node, newC, unstable_idx=unstable_idx,\n                            apply_output_constraints_to=apply_output_constraints_to)\n                    if torch.any((node.upper - node.lower).abs() > 1e10):\n                        if len(apply_output_constraints_to) > 0 and not 
best_of_oc_and_no_oc:\n                            warnings.warn('Very weak bounds detected. This can potentially be '\n                                'fixed by setting best_of_oc_and_no_oc=True.')\n\n                if reduced_dim:\n                    self.restore_sparse_bounds(\n                        node, unstable_idx, unstable_size, ref_intermediate)\n\n                if self.bound_opts['compare_crown_with_ibp']:\n                    node.lower, node.upper = self.compare_with_IBP(node, node.lower, node.upper)\n\n        # node.lower and node.upper (intermediate bounds) are computed in\n        # the above function. If we have bound references, we set them here\n        # to always obtain a better set of bounds.\n        if node.name in reference_bounds:\n            ref_bounds = reference_bounds[node.name]\n            # Initially, the reference bound and the computed bound can be\n            # exactly the same when intermediate layer beta is 0. This will\n            # prevent gradient flow. 
So we need a small guard here.\n            # Set the intermediate layer bounds using reference bounds,\n            # always choosing the tighter one.\n            # Assert no NaNs in reference bounds before using them\n            assert not torch.isnan(ref_bounds[0]).any(), (\n                f'NaN detected in reference lower bound of layer {node.name}')\n            node.lower = (torch.max(ref_bounds[0], node.lower).detach()\n                          - node.lower.detach() + node.lower)\n            assert not torch.isnan(ref_bounds[1]).any(), (\n                f'NaN detected in reference upper bound of layer {node.name}')\n            node.upper = (node.upper - (node.upper.detach()\n                          - torch.min(ref_bounds[1], node.upper).detach()))\n            # Also update bounds in node.linear (if exist)\n            if hasattr(node, 'linear'):\n                node.linear.lower = node.lower\n                node.linear.upper = node.upper\n            # Otherwise, we only use reference bounds to check which neurons\n            # are unstable.\n\n        # prior constraint bounds\n        if node.name in self.layers_with_constraint:\n            node.clamp_interim_bounds()\n        # FIXME (12/28): we should be consistent, and only use\n        # node.interval, do not use node.lower or node.upper!\n        node.interval = (node.lower, node.upper)\n\n    def get_ref_intermediate_bounds(self, node):\n        sparse_intermediate_bounds_with_ibp = self.bound_opts.get(\n            'sparse_intermediate_bounds_with_ibp', True)\n        # Sparse intermediate bounds can be enabled\n        # if aux_reference_bounds are given.\n        # (this is enabled for ReLU only, and not for other activations.)\n        sparse_intermediate_bounds = (self.bound_opts.get(\n            'sparse_intermediate_bounds', False)\n            and isinstance(self[node.output_name[0]], BoundRelu))\n\n        ref_intermediate_lb, ref_intermediate_ub = None, None\n        if 
sparse_intermediate_bounds:\n            if node.name not in self.aux_reference_bounds:\n                # If aux_reference_bounds are not available,\n                # we can use IBP to compute these bounds.\n                if sparse_intermediate_bounds_with_ibp:\n                    with torch.no_grad():\n                        # Get IBP bounds for this layer;\n                        # we set delete_bounds_after_use=True which does\n                        # not save extra intermediate bound tensors.\n                        ret_ibp = self.IBP_general(\n                            node=node, delete_bounds_after_use=True)\n                        ref_intermediate_lb = ret_ibp[0]\n                        ref_intermediate_ub = ret_ibp[1]\n                else:\n                    sparse_intermediate_bounds = False\n            else:\n                aux_bounds = self.aux_reference_bounds[node.name]\n                ref_intermediate_lb, ref_intermediate_ub = aux_bounds\n\n        return sparse_intermediate_bounds, ref_intermediate_lb, ref_intermediate_ub\n\n    def merge_A_dict(self, lA_dict, uA_dict):\n        merged_A = {}\n        for output_node_name in lA_dict:\n            merged_A[output_node_name] = {}\n            lA_dict_ = lA_dict[output_node_name]\n            uA_dict_ = uA_dict[output_node_name]\n            for input_node_name in lA_dict_:\n                merged_A[output_node_name][input_node_name] = {\n                    'lA': lA_dict_[input_node_name]['lA'],\n                    'uA': uA_dict_[input_node_name]['uA'],\n                    'lbias': lA_dict_[input_node_name]['lbias'],\n                    'ubias': uA_dict_[input_node_name]['ubias'],\n                }\n        return merged_A\n\n    def compute_bounds(\n            self, x=None, aux=None, C=None, method='backward', IBP=False,\n            forward=False, bound_lower=True, bound_upper=True, reuse_ibp=False,\n            reuse_alpha=False, return_A=False, needed_A_dict=None,\n         
   final_node_name=None, average_A=False,\n            interm_bounds=None, reference_bounds=None,\n            intermediate_constr=None, alpha_idx=None,\n            aux_reference_bounds=None, need_A_only=False,\n            cutter=None, decision_thresh=None,\n            update_mask=None, ibp_nodes=None, cache_bounds=False):\n        r\"\"\"Main function for computing bounds.\n\n        Args:\n            x (tuple or None): Input to the model. If it is None, the input\n            from the last `forward` or `compute_bounds` call is reused.\n            Otherwise, the number of elements in the tuple should be\n            equal to the number of input nodes in the model, and each element in\n            the tuple corresponds to the value for each input node respectively.\n            It should look similar to the `global_input` argument when used for\n            creating a `BoundedModule`.\n\n            aux (object, optional): Auxiliary information that can be passed to\n            `Perturbation` classes for initializing and concretizing bounds,\n            e.g., additional information for supporting synonym word substitution\n            perturbation.\n\n            C (Tensor): The specification matrix that can map the output of the\n            model with an additional linear layer. This is usually used for\n            mapping the logits output of the model to classification margins.\n\n            method (str): The main method for bound computation. 
Choices:\n                * `IBP`: purely use Interval Bound Propagation (IBP) bounds.\n                * `CROWN-IBP`: use IBP to compute intermediate bounds,\n                but use CROWN (backward mode LiRPA) to compute the bounds of the\n                final node.\n                * `CROWN`: purely use CROWN to compute bounds for intermediate\n                nodes and the final node.\n                * `Forward`: purely use forward mode LiRPA.\n                * `Forward+Backward`: use forward mode LiRPA for intermediate\n                nodes, but further use CROWN for the final node.\n                * `CROWN-Optimized` or `alpha-CROWN`: use CROWN, and also\n                optimize the linear relaxation parameters for activations.\n                * `forward-optimized`: use forward bounds with optimized linear\n                relaxation.\n                * `dynamic-forward`: use dynamic forward bound propagation where\n                new input variables may be dynamically introduced for\n                nonlinearities.\n                * `dynamic-forward+backward`: use dynamic forward mode for\n                intermediate nodes, but use CROWN for the final node.\n\n            IBP (bool, optional): If `True`, use IBP to compute the bounds of\n            intermediate nodes. It can be automatically set according to\n            `method`.\n\n            forward (bool, optional): If `True`, use the forward mode bound\n            propagation to compute the bounds of intermediate nodes. 
It can be\n            automatically set according to `method`.\n\n            bound_lower (bool, default `True`): If `True`, the lower bounds of\n            the output need to be computed.\n\n            bound_upper (bool, default `True`): If `True`, the upper bounds of\n            the output need to be computed.\n\n            reuse_ibp (bool, optional): If `True` and `method` is None, reuse\n            the previously saved IBP bounds.\n\n            final_node_name (str, optional): Set the final node in the\n            computational graph for bound computation. By default, the final\n            node of the originally built computational graph is used.\n\n            return_A (bool, optional): If `True`, return linear coefficients\n            in bound propagation (`A` tensors) with `needed_A_dict` set.\n\n            needed_A_dict (dict, optional): A dictionary specifying linear\n            coefficients (`A` tensors) that are needed and should be returned.\n            Each key in the dictionary is the name of a starting node in\n            backward bound propagation, with a list as the value for the key,\n            which specifies the names of the ending nodes in backward bound\n            propagation, and the linear coefficients of the starting node w.r.t.\n            the specified ending nodes are returned. By default, it is empty.\n\n            reuse_alpha (bool, optional): If `True`, reuse previously saved\n            alpha values when they are not being optimized.\n\n            decision_thresh (float, optional): In CROWN-optimized mode, we will\n            use this decision_thresh to dynamically optimize only those domains\n            whose bounds are <= the threshold.\n\n            interm_bounds: A dictionary of 2-element tuple/list\n            containing lower and upper bounds for intermediate layers.\n            The dictionary keys should include the names of the layers whose\n            bounds should be set without recomputation. 
The layer names can be\n            viewed by setting environment variable AUTOLIRPA_DEBUG=1.\n            The value of each dictionary element is (lower_bounds,\n            upper_bounds) where \"lower_bounds\" and \"upper_bounds\" are two\n            tensors with the same shape as the output shape of this layer. If\n            you only need to set intermediate layer bounds for certain layers,\n            then just include these layers' names in the dictionary.\n\n            reference_bounds: Format is similar to \"interm_bounds\".\n            However, these bounds are only used as a reference, and the bounds\n            for intermediate layers will still be computed (e.g., using CROWN,\n            IBP or other specified methods). The computed bounds will be\n            compared to \"reference_bounds\" and the tighter one between the two\n            will be used.\n\n            aux_reference_bounds: Format is similar to intermediate layer\n            bounds. However, these bounds are only used to determine which\n            neurons are stable and which neurons are unstable for ReLU networks.\n            Unstable neurons' intermediate layer bounds will be recomputed.\n\n            cache_bounds: If `True`, the currently set lower and upper bounds will not\n            be deleted, but cached for use by the INVPROP algorithm. This should not be\n            set by the user, but only in `_get_optimized_bounds`.\n\n        Returns:\n            bound (tuple): When `return_A` is `False`, return a tuple of\n            the computed lower bound and upper bound. When `return_A`\n            is `True`, return a tuple of lower bound, upper bound, and\n            `A` dictionary.\n        \"\"\"\n        # This method only prepares everything by setting all required parameters.\n        # The main logic is located in `_compute_bounds_main`. 
It may be called\n        # repeatedly for CROWN optimizations.\n        logger.debug(f'Compute bounds with {method}')\n        if needed_A_dict is None: needed_A_dict = {}\n        if not bound_lower and not bound_upper:\n            raise ValueError(\n                'At least one of bound_lower and bound_upper must be True')\n\n        # Several shortcuts.\n        compute_optimized = False\n        method = method.lower() if method is not None else method\n        if method == 'ibp':\n            # Pure IBP bounds.\n            method, IBP = None, True\n        elif method in ['ibp+backward', 'ibp+crown', 'crown-ibp']:\n            method, IBP = 'backward', True\n        elif method == 'crown':\n            method = 'backward'\n        elif method == 'forward':\n            forward = True\n            self.dynamic = False\n        elif method == 'dynamic-forward':\n            forward = True\n            self.dynamic = True\n        elif method == 'forward+backward' or method == 'forward+crown':\n            method, forward = 'backward', True\n        elif method == 'dynamic-forward+backward' or method == 'dynamic-forward+crown':\n            self.dynamic = True\n            method, forward = 'backward', True\n        elif method in ['crown-optimized', 'alpha-crown', 'forward-optimized']:\n            # Lower and upper bounds need two separate rounds of optimization.\n            if method == 'forward-optimized':\n                method = 'forward'\n            else:\n                method = 'backward'\n            compute_optimized = True\n\n        if reference_bounds is None:\n            reference_bounds = {}\n        if aux_reference_bounds is None:\n            aux_reference_bounds = {}\n\n        # If y in self.backward_node_pairs[x], then node y is visited when\n        # doing backward bound propagation starting from node x.\n        self.backward_from = dict([(node, []) for node in self._modules])\n\n        if not bound_lower and not bound_upper:\n  
          raise ValueError(\n                'At least one of bound_lower and bound_upper in compute_bounds '\n                'should be True')\n        A_dict = {} if return_A else None\n\n        if x is not None:\n            if isinstance(x, torch.Tensor):\n                x = (x,)\n            if self.bound_opts['forward_before_compute_bounds']:\n                self.forward(*x, interm_bounds=interm_bounds, cache_bounds=cache_bounds)\n            else:\n                self.set_input(*x, interm_bounds=interm_bounds, cache_bounds=cache_bounds)\n\n        roots = self.roots()\n        batch_size = roots[0].value.shape[0]\n        dim_in = 0\n\n        for i in range(len(roots)):\n            value = roots[i].forward()\n            if getattr(roots[i], 'perturbation', None) is not None:\n                ret_init = roots[i].perturbation.init(\n                    value, aux=aux, forward=forward)\n                roots[i].linear, roots[i].center, roots[i].aux = ret_init\n                # This input/parameter has perturbation.\n                # Create an interval object.\n                roots[i].interval = Interval(\n                    roots[i].linear.lower, roots[i].linear.upper,\n                    ptb=roots[i].perturbation)\n                if forward:\n                    roots[i].dim = roots[i].linear.lw.shape[1]\n                    dim_in += roots[i].dim\n\n            else:\n                # This input/parameter does not have perturbation.\n                # Use plain tuple defaulting to Linf perturbation.\n                roots[i].interval = (value, value)\n                roots[i].forward_value = roots[i].value = value\n                roots[i].center = roots[i].lower = roots[i].upper = value\n\n            roots[i].lower, roots[i].upper = roots[i].interval\n\n        if forward:\n            self.init_forward(roots, dim_in)\n\n        for n in self.nodes():\n            if isinstance(n, BoundRelu):\n                for node in n.inputs:\n            
        if hasattr(node, 'relu_followed'):\n                        node.relu_followed = True\n\n            # Inject update mask inside the activations\n            # update_mask: None or bool tensor([batch_size])\n            # If set to a tensor, only update the alpha and beta of selected\n            # element (with element=1).\n            n.alpha_beta_update_mask = update_mask\n\n        final = (self.final_node() if final_node_name is None\n                 else self[final_node_name])\n        # BFS to find out whether each node is used given the current final node\n        self._set_used_nodes(final)\n\n        self.use_forward = forward\n        self.batch_size = batch_size\n        self.dim_in = dim_in\n        self.return_A = return_A\n        self.A_dict = A_dict\n        self.needed_A_dict = needed_A_dict\n        self.intermediate_constr = intermediate_constr\n        self.reference_bounds = reference_bounds\n        self.aux_reference_bounds = aux_reference_bounds\n        self.final_node_name = final.name\n        self.ibp_nodes = ibp_nodes\n\n        if compute_optimized:\n            kwargs = dict(x=x, C=C, method=method, interm_bounds=interm_bounds,\n                reference_bounds=reference_bounds, return_A=return_A,\n                aux_reference_bounds=aux_reference_bounds,\n                needed_A_dict=needed_A_dict,\n                final_node_name=final_node_name,\n                cutter=cutter, decision_thresh=decision_thresh)\n            if bound_upper:\n                ret2 = self._get_optimized_bounds(bound_side='upper', **kwargs)\n            else:\n                ret2 = None\n            if bound_lower:\n                ret1 = self._get_optimized_bounds(bound_side='lower', **kwargs)\n            else:\n                ret1 = None\n            if bound_lower and bound_upper:\n                if return_A:\n                    # Needs to merge the A dictionary.\n                    return ret1[0], ret2[1], self.merge_A_dict(ret1[2], 
ret2[2])\n                else:\n                    return ret1[0], ret2[1]\n            elif bound_lower:\n                return ret1  # ret1[1] is None.\n            elif bound_upper:\n                return ret2  # ret2[0] is None.\n\n        return self._compute_bounds_main(C=C,\n                                         method=method,\n                                         IBP=IBP,\n                                         bound_lower=bound_lower,\n                                         bound_upper=bound_upper,\n                                         reuse_ibp=reuse_ibp,\n                                         reuse_alpha=reuse_alpha,\n                                         average_A=average_A,\n                                         alpha_idx=alpha_idx,\n                                         need_A_only=need_A_only,\n                                         update_mask=update_mask)\n\n    def save_intermediate(self, save_path=None):\n        r\"\"\"A function for saving intermediate bounds.\n\n        Please call this function after `compute_bounds`; otherwise it will output\n        IBP bounds by default.\n\n        Args:\n            save_path (str, default `None`): If `None`, the intermediate bounds\n            will not be saved; otherwise they will be saved at the designated path.\n\n        Returns:\n            save_dict (dict): Return a dictionary of lower and upper bounds, with\n            the key being the name of the layer.\n        \"\"\"\n        save_dict = OrderedDict()\n        for node in self.nodes():\n            if node.used and node.perturbed:\n                if not hasattr(node, 'interval'):\n                    ibp_lower, ibp_upper = self.IBP_general(node,\n                        delete_bounds_after_use=True)\n                    dim_output = int(prod(node.output_shape[1:]))\n                    C = torch.eye(dim_output, device=self.device).expand(\n                        self.batch_size, dim_output, dim_output)\n                   
 crown_lower, crown_upper = self.backward_general(node, C=C)\n                    save_dict[node.name] = (\n                        torch.max(crown_lower, ibp_lower),\n                        torch.min(crown_upper, ibp_upper))\n                else:\n                    save_dict[node.name] = (node.lower, node.upper)\n\n        if save_path is not None:\n            torch.save(save_dict, save_path)\n        return save_dict\n\n    def _compute_bounds_main(self, C=None, method='backward', IBP=False,\n            bound_lower=True, bound_upper=True, reuse_ibp=False,\n            reuse_alpha=False, average_A=False, alpha_idx=None,\n            need_A_only=False, update_mask=None):\n        \"\"\"The core implementation of compute_bounds.\n\n        Separated because compute_bounds may call _get_optimized_bounds which\n        repeatedly calls this method. Otherwise, the preprocessing done in\n        compute_bounds would be executed for each iteration.\n        \"\"\"\n\n        final = (self.final_node() if self.final_node_name is None\n                 else self[self.final_node_name])\n        logger.debug(f'Final node {final.__class__.__name__}({final.name})')\n\n        if IBP and method is None and reuse_ibp:\n            # directly return the previously saved ibp bounds\n            return self.ibp_lower, self.ibp_upper\n\n        if IBP:\n            self.ibp_lower, self.ibp_upper = self.IBP_general(node=final, C=C)\n\n        if method is None:\n            return self.ibp_lower, self.ibp_upper\n\n        # TODO: if compute_bounds is called with a method that causes alphas to be\n        # optimized, C will be allocated in each iteration. We could allocate it once\n        # in compute_bounds, but e.g. 
`IBP_general` and code in `_get_optimized_bounds`\n        # rely on the fact that it can be None\n        if C is None:\n            # C is an identity matrix by default\n            if final.output_shape is None:\n                raise ValueError(\n                    f'C is not specified while node {final} has no known output shape')\n            dim_output = int(prod(final.output_shape[1:]))\n            # TODO: use an eyeC object here.\n            C = torch.eye(dim_output, device=self.device).expand(\n                self.batch_size, dim_output, dim_output)\n\n        # Reuse previously saved alpha values,\n        # even if they are not optimized now\n        # This must be done here instead of `compute_bounds`, as other code might change\n        # it (e.g. `_get_optimized_bounds`)\n        if reuse_alpha:\n            self.opt_reuse()\n        else:\n            self.opt_no_reuse()\n\n        for node in self.nodes():\n            # All nodes may need to be recomputed\n            node.prior_checked = False\n\n        self.check_prior_bounds(final, C=C)\n\n        if method == 'backward':\n            apply_output_constraints_to = (\n                self.bound_opts['optimize_bound_args']['apply_output_constraints_to']\n            )\n            # This is for the final output bound.\n            # No need to pass in intermediate layer beta constraints.\n            ret = self.backward_general(\n                final, C,\n                bound_lower=bound_lower, bound_upper=bound_upper,\n                average_A=average_A, need_A_only=need_A_only,\n                unstable_idx=alpha_idx, update_mask=update_mask,\n                apply_output_constraints_to=apply_output_constraints_to)\n\n            if self.bound_opts['compare_crown_with_ibp']:\n                new_lower, new_upper = self.compare_with_IBP(final, lower=ret[0], upper=ret[1], C=C)\n                ret = (new_lower, new_upper) + ret[2:]\n\n            # FIXME when C is specified, lower and upper 
should not be saved to\n            # final.lower and final.upper, because they are not the bounds for\n            # the node.\n            final.lower, final.upper = ret[0], ret[1]\n\n            return ret\n        elif method == 'forward' or method == 'dynamic-forward':\n            return self.forward_general(C=C, node=final, concretize=True)\n        else:\n            raise NotImplementedError\n\n    def _set_used_nodes(self, final):\n        # By default, all *.used are initialized to False.\n        # We set the used nodes by BFS from the final node.\n        if final.name != self.last_final_node_name:\n            self.last_final_node_name = final.name\n            final.used = True\n            queue = deque([final])\n            while len(queue) > 0:\n                n = queue.popleft()\n                for n_pre in n.inputs:\n                    if not n_pre.used:\n                        n_pre.used = True\n                        queue.append(n_pre)\n        # Based on \"used\" and \"perturbed\" properties, find out which\n        # layer requires intermediate layer bounds.\n        self.layers_requiring_bounds = self.get_layers_requiring_bounds()\n\n    def init_infeasible_bounds_constraints(self, batchsize, device):\n        '''Simply initialize the infeasible bound record.'''\n        self.infeasible_bounds_constraints = torch.full((batchsize,), False, device=device)\n\n    from .interval_bound import (\n        IBP_general, _IBP_loss_fusion, check_IBP_intermediate,\n        check_IBP_first_linear, compare_with_IBP)\n    from .forward_bound import (\n        forward_general, forward_general_dynamic, forward_refinement, init_forward)\n    from .backward_bound import (\n        backward_general, get_sparse_C,\n        check_optimized_variable_sparsity, restore_sparse_bounds,\n        get_alpha_crown_start_nodes, get_unstable_locations, batched_backward,\n        _preprocess_C)\n    from .output_constraints import (\n        
backward_general_with_output_constraint, invprop_enabled,\n        backward_general_invprop, invprop_init_infeasible_bounds,\n        invprop_check_infeasible_bounds)\n    from .optimized_bounds import (\n        _get_optimized_bounds, init_alpha, update_best_beta,\n        opt_reuse, opt_no_reuse, _to_float64, _to_default_dtype)\n    from .beta_crown import (beta_crown_backward_bound, reset_beta, set_beta,\n                             set_beta_cuts, get_split_nodes)\n    from .jacobian import (compute_jacobian_bounds, _expand_jacobian)\n    from .optimize_graph import _optimize_graph\n    from .edit_graph import add_nodes, add_input_node, delete_node, replace_node\n    from .tools import visualize\n    from .concretize_bounds import (\n        concretize_bounds, concretize_root, backward_concretize, forward_concretize)\n\n\n    from .solver_module import (\n        build_solver_module, _build_solver_input, _build_solver_general,\n        _reset_solver_vars, _reset_solver_model)\n"
  },
  {
    "path": "auto_LiRPA/bound_multi_gpu.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom torch.nn import DataParallel\nfrom .perturbations import *\nfrom .bounded_tensor import BoundedTensor\nfrom itertools import chain\n\nclass BoundDataParallel(DataParallel):\n    # https://github.com/huanzhang12/CROWN-IBP/blob/master/bound_layers.py\n    # This is a customized DataParallel class for our project\n    def __init__(self, *inputs, **kwargs):\n        super(BoundDataParallel, self).__init__(*inputs, **kwargs)\n        self._replicas = None\n\n    # Override the forward method\n    def forward(self, *inputs, **kwargs):\n        disable_multi_gpu = False  # forward by a single GPU\n        no_replicas = False  # forward by multiple GPUs but without replicate\n        if \"disable_multi_gpu\" in kwargs:\n            disable_multi_gpu = kwargs[\"disable_multi_gpu\"]\n            kwargs.pop(\"disable_multi_gpu\")\n\n        if \"no_replicas\" in kwargs:\n            no_replicas = kwargs[\"no_replicas\"]\n            kwargs.pop(\"no_replicas\")\n\n        if not self.device_ids or disable_multi_gpu:\n            if kwargs.pop(\"get_property\", False):\n                return self.get_property(self, *inputs, **kwargs)\n            return self.module(*inputs, **kwargs)\n\n        if kwargs.pop(\"get_property\", False):\n            if self._replicas is None:\n                assert 0, 'please call IBP/CROWN before get_property'\n            if len(self.device_ids) == 1:\n                return self.get_property(self.module, **kwargs)\n            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)\n            kwargs = list(kwargs)\n            for i in range(len(kwargs)):\n                kwargs[i]['model'] = self._replicas[i]\n            outputs = self.parallel_apply([self.get_property] * len(kwargs), inputs, kwargs)\n            return self.gather(outputs, self.output_device)\n\n        # Only 
replicate during forward/IBP propagation. Not during interval bounds\n        # and CROWN-IBP bounds, since weights have not been updated. This saves 2/3\n        # of communication cost.\n        if not no_replicas:\n            if self._replicas is None:  # first time\n                self._replicas = self.replicate(self.module, self.device_ids)\n            elif kwargs.get(\"method_opt\", \"forward\") == \"forward\":\n                self._replicas = self.replicate(self.module, self.device_ids)\n            elif kwargs.get(\"x\") is not None and kwargs.get(\"IBP\") is True:  #\n                self._replicas = self.replicate(self.module, self.device_ids)\n            # Update the input nodes to the ones within each replica respectively\n            for bounded_module in self._replicas:\n                for node in bounded_module._modules.values():\n                    node.inputs = [bounded_module[name] for name in node.input_name]\n\n        for t in chain(self.module.parameters(), self.module.buffers()):\n            if t.device != self.src_device_obj:\n                raise RuntimeError(\"module must have its parameters and buffers \"\n                                   \"on device {} (device_ids[0]) but found one of \"\n                                   \"them on device: {}\".format(self.src_device_obj, t.device))\n\n        # TODO: can be done in parallel, only support same ptb for all inputs per forward/IBP propagation\n        if len(inputs) > 0 and hasattr(inputs[0], 'ptb') and inputs[0].ptb is not None:\n            # compute bounds without x\n            # inputs_scatter is a normal tensor, we need to assign ptb to it if inputs is a BoundedTensor\n            inputs_scatter, kwargs = self.scatter((inputs, inputs[0].ptb.x_L, inputs[0].ptb.x_U), kwargs,\n                                                  self.device_ids)\n            # inputs_scatter = inputs_scatter[0]\n            bounded_inputs = []\n            for input_s in inputs_scatter:  # GPU 
numbers\n                # FIXME other perturbations are not supported yet\n                assert isinstance(inputs[0].ptb, PerturbationLpNorm)\n                ptb = PerturbationLpNorm(norm=inputs[0].ptb.norm, eps=inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2])\n                input_s = list(input_s[0])\n                input_s[0] = BoundedTensor(input_s[0], ptb)\n                input_s = tuple(input_s)\n                bounded_inputs.append(input_s)\n\n            # bounded_inputs = tuple(bounded_inputs)\n        elif kwargs.get(\"x\") is not None and hasattr(kwargs.get(\"x\")[0], 'ptb') and kwargs.get(\"x\")[0].ptb is not None:\n            # compute bounds with x\n            # kwargs['x'] is a normal tensor, we need to assign ptb to it\n            x = kwargs.get(\"x\")[0]\n            bounded_inputs = []\n            inputs_scatter, kwargs = self.scatter((inputs, x.ptb.x_L, x.ptb.x_U), kwargs, self.device_ids)\n            for input_s, kw_s in zip(inputs_scatter, kwargs):  # GPU numbers\n                # FIXME other perturbations are not supported yet\n                assert isinstance(x.ptb, PerturbationLpNorm)\n                ptb = PerturbationLpNorm(norm=x.ptb.norm, eps=x.ptb.eps, x_L=input_s[1], x_U=input_s[2])\n                kw_s['x'] = list(kw_s['x'])\n                kw_s['x'][0] = BoundedTensor(kw_s['x'][0], ptb)\n                kw_s['x'] = (kw_s['x'])\n                bounded_inputs.append(tuple(input_s[0], ))\n        else:\n            # normal forward\n            inputs_scatter, kwargs = self.scatter(inputs, kwargs, self.device_ids)\n            bounded_inputs = inputs_scatter\n\n        if len(self.device_ids) == 1:\n            return self.module(*bounded_inputs[0], **kwargs[0])\n        outputs = self.parallel_apply(self._replicas[:len(bounded_inputs)], bounded_inputs, kwargs)\n        return self.gather(outputs, self.output_device)\n\n    @staticmethod\n    def get_property(model, node_class=None, att_name=None, node_name=None):\n  
      if node_name:\n            # Find node by name\n            # FIXME If we use `model.named_modules()`, the nodes have the\n            # `BoundedModule` type rather than bound nodes.\n            for node in model._modules.values():\n                if node.name == node_name:\n                    return getattr(node, att_name)\n        else:\n            # Find node by class\n            for _, node in model.named_modules():\n                # Find the Exp neuron in computational graph\n                if isinstance(node, node_class):\n                    return getattr(node, att_name)\n\n    def state_dict(self, destination=None, prefix='', keep_vars=False):\n        # add 'module.' here before each keys in self.module.state_dict() if needed\n        return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)\n\n    def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):\n        return self.module._named_members(get_members_fn, prefix, recurse, remove_duplicate)\n\n    def __getitem__(self, name):\n        return self.module[name]\n"
  },
  {
    "path": "auto_LiRPA/bound_op_map.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nfrom .bound_ops import *\n\nbound_op_map = {\n    'onnx::Gemm': BoundLinear,\n    'prim::Constant': BoundPrimConstant,\n    'grad::Concat': BoundConcatGrad,\n    'grad::Relu': BoundReluGrad,\n    'grad::Conv2d': BoundConv2dGrad,\n    'grad::Slice': BoundSliceGrad,\n    'grad::Sqr': BoundSqr,\n    'grad::jacobian': BoundJacobianOP,\n    'grad::Tanh': BoundTanhGrad,\n    'grad::Sigmoid': BoundSigmoidGrad,\n    'custom::Gelu': BoundGelu,\n    'onnx::Clip': BoundHardTanh\n}\n\ndef register_custom_op(op_name: str, bound_obj: Bound) -> None:\n    bound_op_map[op_name] = bound_obj\n\ndef unregister_custom_op(op_name: str) -> None:\n    bound_op_map.pop(op_name)\n"
  },
  {
    "path": "auto_LiRPA/bound_ops.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nfrom .operators import *\n"
  },
  {
    "path": "auto_LiRPA/bounded_tensor.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport copy\nimport torch.nn as nn\nfrom torch import Tensor\nimport torch._C as _C\n\n\nclass BoundedTensor(Tensor):\n    @staticmethod\n    # We need to override the __new__ method since Tensor is a C class\n    def __new__(cls, x, ptb=None, *args, **kwargs):\n        if isinstance(x, Tensor):\n            tensor = super().__new__(cls, [], *args, **kwargs)\n            tensor.data = x.data\n            tensor.requires_grad = x.requires_grad\n            return tensor\n        else:\n            return super().__new__(cls, x, *args, **kwargs)\n\n    def __init__(self, x, ptb=None):\n        self.ptb = ptb\n\n    def __repr__(self):\n        if hasattr(self, 'ptb') and self.ptb is not None:\n            return '<BoundedTensor: {}, {}>'.format(super().__repr__(), self.ptb.__repr__())\n        else:\n            return '<BoundedTensor: {}, no ptb>'.format(super().__repr__())\n\n    def clone(self, *args, **kwargs):\n        tensor = BoundedTensor(super().clone(*args, **kwargs), copy.deepcopy(self.ptb))\n        return tensor\n\n    def _func(self, func, *args, **kwargs):\n        temp = func(*args, **kwargs)\n        new_obj = BoundedTensor([], self.ptb)\n        new_obj.data = temp.data\n        new_obj.requires_grad = temp.requires_grad\n        return new_obj\n\n    # Copy to other devices with perturbation\n    def to(self, *args, **kwargs):\n        # FIXME add a general \"to\" function in perturbation class, not here.\n        if hasattr(self.ptb, 'x_L') and isinstance(self.ptb.x_L, Tensor):\n            self.ptb.x_L = self.ptb.x_L.to(*args, **kwargs)\n        if hasattr(self.ptb, 'x_U') and isinstance(self.ptb.x_U, Tensor):\n            self.ptb.x_U = self.ptb.x_U.to(*args, **kwargs)\n        if hasattr(self.ptb, 'eps') and isinstance(self.ptb.eps, Tensor):\n            self.ptb.eps = self.ptb.eps.to(*args, 
**kwargs)\n        return self._func(super().to, *args, **kwargs)\n\n    @classmethod\n    def _convert(cls, ret):\n        if cls is Tensor:\n            return ret\n\n        if isinstance(ret, Tensor):\n            if True:\n                # The current implementation does not seem to need non-leaf BoundedTensor\n                return ret\n            else:\n                # Enable this branch if non-leaf BoundedTensor should be kept\n                ret = ret.as_subclass(cls)\n\n        if isinstance(ret, tuple):\n            ret = tuple(cls._convert(r) for r in ret)\n\n        return ret\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        if kwargs is None:\n            kwargs = {}\n\n        if not all(issubclass(cls, t) for t in types):\n            return NotImplemented\n\n        with _C.DisableTorchFunction():\n            ret = func(*args, **kwargs)\n            return cls._convert(ret)\n\n\nclass BoundedParameter(nn.Parameter):\n    def __new__(cls, data, ptb, requires_grad=True):\n        return BoundedTensor._make_subclass(cls, data, requires_grad)\n\n    def __init__(self, data, ptb, requires_grad=True):\n        self.ptb = ptb\n        self.requires_grad = requires_grad\n\n    def __deepcopy__(self, memo):\n        if id(self) in memo:\n            return memo[id(self)]\n        else:\n            result = type(self)(self.data.clone(), self.ptb, self.requires_grad)\n            memo[id(self)] = result\n            return result\n\n    def __repr__(self):\n        return 'BoundedParameter containing:\\n{}\\n{}'.format(\n            self.data.__repr__(), self.ptb.__repr__())\n\n    def __reduce_ex__(self, proto):\n        raise NotImplementedError\n"
  },
  {
    "path": "auto_LiRPA/concretize_bounds.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport torch\n\nfrom .utils import eyeC\nfrom .bound_ops import *\nfrom .patches import Patches\nfrom .perturbations import PerturbationLpNorm\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef concretize_bounds(\n    self: 'BoundedModule',\n    node,\n    lower,\n    upper,\n    concretize_mode='backward',\n    # for `backward_concretize`\n    batch_size=None,\n    output_dim=None,\n    average_A=None,\n    # for `forward_concretize`\n    lw=None,\n    uw=None,\n    # common\n    clip_neuron_selection_value=-1.0,\n    clip_neuron_selection_type=\"ratio\"\n):\n    \"\"\"\n    If clip_neuron_selection_value >= 0, run an unconstrained/bounds-saving pass\n    then a top-K constrained pass; otherwise just one pass.\n    \"\"\"\n    # decide which underlying call to use\n    def _call_concretize(use_constraints, save_bounds=False, heuristic_indices=None):\n        if concretize_mode == 'backward':\n            # backward concretize signature\n            return backward_concretize(\n                self, batch_size, output_dim, lower, upper,\n                average_A=average_A,\n                node_start=node,\n                use_constraints=use_constraints,\n                save_bounds=save_bounds,\n                heuristic_indices=heuristic_indices,\n            )\n        elif concretize_mode == 'forward':\n            # forward_concretize signature\n            return forward_concretize(\n                self, lower, upper, lw, uw,\n                use_constraints=use_constraints,\n                save_bounds=save_bounds,\n                heuristic_indices=heuristic_indices,\n            )\n        else:\n            raise ValueError(f\"Unknown concretize mode: {concretize_mode}. 
\"\n                             \"Please use 'backward' or 'forward'.\")\n\n    use_constraints = True\n    save_bounds = False\n\n    # If clip_neuron_selection_value >= 0, heuristic score-based topk selection is enabled.\n    # And we will only apply constrained concretization on topk neurons based on their heuristics.\n    # In this case, we'll need to 1) concretize all neurons without any constraints to get a looser bound \n    #                                           --> This is for computing the heuristics\n    #                             2) concretize topk neurons with constraints.            \n    #                                           --> This is for getting tighter bounds for topk neurons.                \n    # In conclusion, if neuron_selection_value >= 0, use_consrtaints will be disabled first.\n    # But for the output node in the computational graph we will directly concretize all neurons..\n    if clip_neuron_selection_value >= 0 and node.name not in self.output_name:\n        use_constraints = False\n        # `output_activations` is the list of output activations from current pre-activation node.\n        # This output_activations is manually assigned outside of auto_lirpa. Please check \n        #       complete_verifier/input_split/batch_branch_and_bound.py for more info.\n        # If a node: \n        #       a) does not have any output_activation, and \n        #       b) heuristic topk selection is enabled, and\n        #       c) is not the output node in the computational graph\n        #  we will only compute naive bounds on it.\n        # Otherwise, we'll need to do both step 1) and 2). 
And to accelerate step 2), we will save the bounds from 1).\n\n        # If 1) this node has at least one output activation node, and\n        #    2) at least one neuron will be selected,\n        # we will need to concretize with constraints.\n        if node.output_activations is not None and clip_neuron_selection_value > 0:\n            save_bounds = True\n\n    # If heuristic topk selection is enabled, this is step 1).\n    new_lower, new_upper, has_constraints = _call_concretize(\n        use_constraints=use_constraints,\n        save_bounds=save_bounds,\n    )\n\n    # If heuristic topk selection is enabled, this if-branch is step 2).\n    if (has_constraints\n        and node.output_activations is not None\n        and clip_neuron_selection_value > 0\n        and node.name not in self.output_name):\n\n        score = 0.0\n        unstable_masks = False\n\n        # Loop through all the output activations to get a comprehensive unstable mask and heuristic score.\n        # This output_activations is manually assigned outside of auto_LiRPA.\n        # Please check complete_verifier/input_split/batch_branch_and_bound.py\n        for o_act_node in node.output_activations:\n            score = score + o_act_node.compute_bound_improvement_heuristics(new_lower, new_upper)\n            unstable_masks = unstable_masks | o_act_node.get_unstable_mask(new_lower, new_upper)\n        score = score.flatten(1)                        # shape: (Batchsize, Hidden_dim)\n        unstable_masks = unstable_masks.flatten(1)      # shape: (Batchsize, Hidden_dim)\n\n        # Only do the second concretization if there exist unstable neurons.\n        if unstable_masks.any():\n            max_unstable_size = unstable_masks.sum(dim=1).max()\n            heuristic_indices = None\n            # The K value in topk should be at least 1.\n            if clip_neuron_selection_type == \"ratio\":\n                K = max(int(max_unstable_size * clip_neuron_selection_value + 
0.5), 1)\n            else:\n                K = min(clip_neuron_selection_value, max_unstable_size)\n            _, heuristic_indices = torch.topk(score, k=K, dim=1, largest=True, sorted=False)\n            new_lower, new_upper, _ = _call_concretize(\n                use_constraints=True,\n                heuristic_indices=heuristic_indices\n            )\n        else:\n            # Aux bounds were stored previously; clear them now to avoid any confusion.\n            for root in self.roots():\n                if (hasattr(root, 'perturbation')\n                    and root.perturbation is not None\n                    and isinstance(root.perturbation, PerturbationLpNorm)):\n                    root.perturbation.clear_aux_bounds()\n\n    return new_lower, new_upper\n\n\ndef concretize_root(self, root, batch_size, output_dim,\n                    average_A=False, node_start=None, input_shape=None,\n                    use_constraints=False, heuristic_indices=None, save_bounds=False):\n    # The last three optional arguments are designed for heuristic-driven constrained concretization.\n    # use_constraints:      A flag controlling whether to enable constraints solving or not.\n    # heuristic_indices:    An index tensor that selects an EQUAL number of hidden neurons from each batch.\n    #                           Constrained solving will be further applied on these neurons. 
Shape (batchsize, n_h_neurons)\n    # save_bounds:          A flag determining whether to save naive bounds (to avoid redundant computation)\n\n    if average_A and isinstance(root, BoundParams):\n        lA = root.lA.mean(\n            node_start.batch_dim + 1, keepdim=True\n        ).expand(root.lA.shape) if (root.lA is not None) else None\n        uA = root.uA.mean(\n            node_start.batch_dim + 1, keepdim=True\n        ).expand(root.uA.shape) if (root.uA is not None) else None\n    else:\n        lA, uA = root.lA, root.uA\n    if not isinstance(root.lA, eyeC) and not isinstance(root.lA, Patches):\n        lA = root.lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if (lA is not None) else None\n    if not isinstance(root.uA, eyeC) and not isinstance(root.uA, Patches):\n        uA = root.uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if (uA is not None) else None\n    \n    has_constraints = False\n    if hasattr(root, 'perturbation') and root.perturbation is not None:\n\n        if isinstance(root.perturbation, PerturbationLpNorm):\n            # Enable / Disable constraints solving according to `use_constraints`\n            root.perturbation.constraints_enable = use_constraints\n            if root.perturbation.constraints is not None:\n                if self.infeasible_bounds_constraints is not None:\n                    root.perturbation.add_infeasible_batches(self.infeasible_bounds_constraints)\n                root.perturbation.add_objective_indices(heuristic_indices)\n                has_constraints = True\n\n        if isinstance(root, BoundParams):\n            # add batch_size dim for weights node\n            lb = root.perturbation.concretize(\n                root.center.unsqueeze(0), lA, sign=-1, aux=root.aux\n            ) if (lA is not None) else None\n            ub = root.perturbation.concretize(\n                root.center.unsqueeze(0), uA, sign=+1, aux=root.aux\n            ) if (uA is not None) else None\n\n        
else:\n            lb = root.perturbation.concretize(\n                root.center, lA, sign=-1, aux=root.aux\n            ) if lA is not None else None\n            ub = root.perturbation.concretize(\n                root.center, uA, sign=+1, aux=root.aux\n            ) if uA is not None else None\n\n        if (isinstance(root.perturbation, PerturbationLpNorm) \n            and root.perturbation.constraints is not None\n            and root.perturbation.sorted_out_batches[\"infeasible_batches\"] is not None):\n            if self.infeasible_bounds_constraints is not None:\n                self.infeasible_bounds_constraints = self.infeasible_bounds_constraints | root.perturbation.sorted_out_batches[\"infeasible_batches\"]\n            # else:\n            #     self.infeasible_bounds_constraints = root.perturbation.sorted_out_batches[\"infeasible_batches\"]\n\n        # If required, save current (naive) bounds to prevent redundant computation next time concretize on the same node\n        if isinstance(root.perturbation, PerturbationLpNorm) and root.perturbation.constraints is not None and save_bounds:\n            root.perturbation.add_aux_bounds(lb, ub)\n        elif isinstance(root.perturbation, PerturbationLpNorm):\n        # Otherwise, always clear_aux_bounds to prevent confusion\n            root.perturbation.clear_aux_bounds()\n\n    else:\n        fv = root.forward_value\n        if type(root) == BoundInput:\n            # Input node with a batch dimension\n            batch_size_ = batch_size\n        else:\n            # Parameter node without a batch dimension\n            batch_size_ = 1\n\n        def concretize_constant(A):\n            if isinstance(A, eyeC):\n                return fv.view(batch_size_, -1)\n            elif isinstance(A, Patches):\n                return A.matmul(fv, input_shape=input_shape)\n            elif type(root) == BoundInput:\n                return A.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1)\n            else:\n     
           return A.matmul(fv.view(-1, 1)).squeeze(-1)\n\n        lb = concretize_constant(lA) if (lA is not None) else None\n        ub = concretize_constant(uA) if (uA is not None) else None\n\n    return lb, ub, has_constraints\n\n\ndef backward_concretize(self, batch_size, output_dim, lb=None, ub=None,\n               average_A=False, node_start=None,\n               use_constraints=False, heuristic_indices=None, save_bounds=False):\n    # The last three optional arguments are designed for heuristic-driven constrained concretization.\n    # use_constraints:      A flag controlling whether to enable constraints solving or not.\n    # heuristic_indices:    An index tensor that selects an EQUAL number of hidden neurons from each batch.\n    #                           Constrained solving will be further applied on these neurons. Shape (batchsize, n_h_neurons)\n    # save_bounds:          A flag determining whether to save naive bounds (to avoid redundant computation)\n    roots = self.roots()\n    if isinstance(lb, torch.Tensor) and lb.ndim > 2:\n        lb = lb.reshape(lb.shape[0], -1)\n    if isinstance(ub, torch.Tensor) and ub.ndim > 2:\n        ub = ub.reshape(ub.shape[0], -1)\n\n    def add_b(b1, b2):\n        if b2 is None:\n            return b1\n        elif b1 is None:\n            return b2\n        # Check if b1 is a tensor and if all its elements are infinity\n        if torch.is_tensor(b1) and torch.isinf(b1).all():\n            return b1\n        # Check if b2 is a tensor and if all its elements are infinity\n        if torch.is_tensor(b2) and torch.isinf(b2).all():\n            return b2\n        else:\n            return b1 + b2\n\n    has_constraints = False\n    for root in roots:\n        root.lb = root.ub = None\n        if root.lA is None and root.uA is None:\n            continue\n        root.lb, root.ub, has_constraints_this_root = self.concretize_root(\n            root, batch_size, output_dim, average_A=average_A,\n            
node_start=node_start, input_shape=roots[0].center.shape,\n            use_constraints=use_constraints, heuristic_indices=heuristic_indices, save_bounds=save_bounds)\n\n        has_constraints = has_constraints | has_constraints_this_root\n\n        lb = add_b(lb, root.lb)\n        ub = add_b(ub, root.ub)\n\n    return lb, ub, has_constraints\n\n\ndef forward_concretize(self, lower, upper, lw, uw, use_constraints=False, heuristic_indices=None, save_bounds=False):\n    \"\"\"\n    Concretize function for forward bounds.\n\n    :param lower:                   Tensor. Intermediate layer lower bounds.\n    :param upper:                   Tensor. Intermediate layer upper bounds.\n    :param lw:                      Tensor. Intermediate layer lower A matrix.\n    :param uw:                      Tensor. Intermediate layer upper A matrix.\n    :param use_constraints:         bool. A flag controlling whether to enable constraints solving or not.\n        If the heuristic ratio is set, the first concretization run should disable constraints solving.\n    :param heuristic_indices:       Index Tensor. An index tensor that selects an **equal** number of hidden neurons from each batch.\n        Constrained solving will be further applied on these neurons. Shape (batchsize, n_h_neurons)\n    :param save_bounds:             bool. A flag controlling whether to save naive bounds.\n\n    :return res_lower:              Tensor. The lower bound tensor.\n    :return res_upper:              Tensor. The upper bound tensor.\n    :return has_constraints:        bool. 
Whether constraints have been stored.\n    \"\"\"\n    res_lower = 0.0\n    res_upper = 0.0\n    prev_dim_in = 0\n    has_constraints = False\n    roots = self.roots()\n    assert (lw.ndim > 1)\n    lA = lw.reshape(self.batch_size, self.dim_in, -1).transpose(1, 2)\n    uA = uw.reshape(self.batch_size, self.dim_in, -1).transpose(1, 2)\n    for root in roots:\n        if hasattr(root, 'perturbation') and root.perturbation is not None:\n            _lA = lA[:, :, prev_dim_in : (prev_dim_in + root.dim)]\n            _uA = uA[:, :, prev_dim_in : (prev_dim_in + root.dim)]\n\n            if isinstance(root.perturbation, PerturbationLpNorm):\n                root.perturbation.constraints_enable = use_constraints\n                if root.perturbation.constraints is not None:\n                    if self.infeasible_bounds_constraints is not None:\n                        root.perturbation.add_infeasible_batches(self.infeasible_bounds_constraints)\n                    root.perturbation.add_objective_indices(heuristic_indices)\n                    has_constraints = True\n\n            # Previously, concretized bounds were added directly to lower/upper.\n            # Now extract them first for reuse (e.g., in aux_bounds).\n            temp_lower = root.perturbation.concretize(\n                root.center, _lA, sign=-1, aux=root.aux\n                ).view(lower.shape)\n            temp_upper = root.perturbation.concretize(\n                root.center, _uA, sign=+1, aux=root.aux\n                ).view(upper.shape)\n\n            # Update infeasible_batches\n            if (isinstance(root.perturbation, PerturbationLpNorm)\n                and root.perturbation.constraints is not None\n                and root.perturbation.sorted_out_batches[\"infeasible_batches\"] is not None):\n                if self.infeasible_bounds_constraints is not None:\n                    self.infeasible_bounds_constraints = self.infeasible_bounds_constraints | 
root.perturbation.sorted_out_batches[\"infeasible_batches\"]\n                # else:\n                #     self.infeasible_bounds_constraints = root.perturbation.sorted_out_batches[\"infeasible_batches\"]\n\n            # If required, save the current (naive) bounds to prevent redundant computation the next time we concretize the same node\n            if isinstance(root.perturbation, PerturbationLpNorm) and root.perturbation.constraints is not None and save_bounds:\n                root.perturbation.add_aux_bounds(temp_lower, temp_upper)\n            elif isinstance(root.perturbation, PerturbationLpNorm):\n                # Otherwise, always clear_aux_bounds to prevent confusion\n                root.perturbation.clear_aux_bounds()\n\n            # The concretization result from this root is accumulated into the final bounds.\n            # We add temp_lower onto res_lower instead of lower, because lower is used again\n            # later and must not be modified in place.\n            res_lower = res_lower + temp_lower\n            res_upper = res_upper + temp_upper\n            # Advance the offset into the flattened input dimension for the next perturbed root\n            prev_dim_in = prev_dim_in + root.dim\n\n    res_lower = res_lower + lower\n    res_upper = res_upper + upper\n    return res_lower, res_upper, has_constraints\n"
  },
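The constraint-solving module that follows works with constraints in the standard form `A @ x + b' <= 0`. A minimal pure-Python sketch of the conversion from the non-standard form `A @ x + b <= rhs` (hypothetical helper name; the library performs this in `construct_constraints` with torch tensors):

```python
# Sketch of the standard-form conversion: A @ x + b <= rhs  becomes  A @ x + b' <= 0
# with b' = b - rhs; `sign` flips the inequality direction when needed.
def to_standard_form(A, b, rhs, sign=1):
    A_std = [[sign * a for a in row] for row in A]
    b_std = [sign * (bi - ri) for bi, ri in zip(b, rhs)]
    return A_std, b_std

# 2*x1 + 1 <= 3  becomes  2*x1 - 2 <= 0
A_std, b_std = to_standard_form([[2.0, 0.0]], [1.0], [3.0])
```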
  {
    "path": "auto_LiRPA/concretize_func.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport torch\n\nfrom math import floor, ceil\nfrom .utils import eyeC\n\n# Declaration of the shape naming:\n\n# B / batchsize  : The number of batches. In this `concretize_func.py`, if a tensor has batch dimension, we assume\n#                    it will only be the first dimention of this tensor . That is: B = tensor.shape[0]\n# \n# B_act          : The number of active batches. We will only apply constraints to a subset of batches, and these\n#                   batches are called active batches. B_act <= B. There are two cases:\n#                       -- When `no_return_inf` mode is disabled, we will keep B_act static throughout the entire \n#                           BaB iteration. 
In this case, B_act equals the number of batches not fully covered by \n#                           constraints, as determined by the `sort_out_constr_batches` function.\n#                       -- When `no_return_inf` mode is enabled, B_act decreases over iterations, since more\n#                           batches will be marked as infeasible. See `PerturbationLpNorm.add_infeasible_batches`.\n# \n# X / x_dim      : The number of input neurons (batch dimension excluded). It stands for the input shape of the\n#                   neural network. For tensors such as x0, epsilon, x_U, x_L, X = prod(*tensor.shape[1:])\n# \n# H / hidden_dim : The number of hidden neurons (batch and input dimension excluded). It stands for the output\n#                   shape of this hidden layer. For the objective A tensor, there are two cases:\n#                       -- The tensor has a batch dimension: H = tensor.view(B, -1, X).shape[1]\n#                       -- The tensor does not have a batch dimension: H = tensor.view(-1, X).shape[0]\n# \n# H_act          : The number of active hidden neurons. We may only apply constraints to a subset of hidden neurons,\n#                   and these neurons are called active neurons. H_act <= H.\n#\n# N_constr       : The number of constraints. For the constraints_A matrix:\n#                       -- In the `sort_out_constr_batches` function, its shape is (B, N_constr, X)\n#                       -- In the `constraints_solving` function, its shape is (B_act, N_constr, X)\n\ndef construct_constraints(constr_A: torch.Tensor, constr_b: torch.Tensor, constr_rhs: torch.Tensor,\n                            batchsize, x_dim, sign=1):\n    r\"\"\"\n    Construct the constraints tuple. 
This function provides a unified interface to generate this tuple.\n    All users should carefully read this function to fully understand the standard form of constraints.\n\n    The first three arguments express the non-standard form of the constraints:\n                                    A @ x + b <= rhs\n    We will first convert it into the standard form:\n                                    A @ x + b' <= 0\n    The standard expression of the constraints is then (constr_A, constr_b')\n\n    Args:\n        constr_A:   The coefficient A matrix of constraints.\n                        It should be able to be reshaped into: (B, N_constr, X)\n        constr_b:   The bias term of constraints.\n                        It should be able to be reshaped into: (B, N_constr)\n        constr_rhs: The right-hand-side term of constraints.\n                        It should be able to be reshaped into: (B, N_constr)\n        batchsize:  The batchsize B.\n        x_dim:      The input dimension X (batchsize dimension excluded)\n    \"\"\"\n    constr_A = sign * constr_A.reshape((batchsize, -1, x_dim))\n    if constr_rhs is not None and not torch.all(constr_rhs == 0):\n        constr_b = sign * (constr_b - constr_rhs).reshape((batchsize, -1))\n    else:\n        constr_b = sign * constr_b.reshape((batchsize, -1))\n    return (constr_A, constr_b)\n\ndef _sort_out_constraints(A, b, x0, epsilon):\n    r\"\"\"\n    Filter out batches whose constraints do not intersect with the input region\n\n    Args:\n        A (Tensor): A matrix of constraints with shape of (batchsize, n_constraints, x_dim)\n        b (Tensor): Bias term of constraints with shape of (batchsize, n_constraints)\n        x0 (Tensor): Centroid of the input space with shape of (batchsize, x_dim, 1)\n        epsilon (Tensor): Offset from the centroid to the input space boundary with shape of (batchsize, x_dim, 1)\n    Return:\n        no_intersection (Tensor): A boolean tensor with shape 
(batchsize, ), indicating whether a certain batch is infeasible\n            because a constraint does not intersect with the input space\n        fully_covered (Tensor): A boolean tensor with shape (batchsize, ), indicating whether all the constraints in a certain\n            batch fully cover the corresponding input region. In this case, we can simply treat the batch as if it has no constraints\n    \"\"\"\n    # minimal and maximal value of A*x + b\n    x0_term = A.bmm(x0).squeeze(-1) + b        # shape: (B, N_constr)\n    eps_term = A.abs().bmm(epsilon).squeeze(-1) # shape: (B, N_constr)\n    minimal_val = x0_term - eps_term            # shape: (B, N_constr)\n    maximal_val = x0_term + eps_term            # shape: (B, N_constr)\n\n    # for any constraint: A * x + b <= 0,\n    # if its min(A * x + b) > 0, it has no intersection with x0 +- epsilon\n    # if its max(A * x + b) <= 0, it fully covers x0 +- epsilon\n    no_intersection = (minimal_val > 0).any(1)  # shape: (B, )\n    if not no_intersection.any():\n        no_intersection = None\n    fully_covered = (maximal_val <= 0).all(1)   # shape: (B, )\n    return no_intersection, fully_covered\n\n@torch.jit.script\ndef _dist_rearrange(constraints_A, constraints_b, x0):\n    r\"\"\"\n    Reorder the constraints according to their distance to x_prime\n\n    Args:\n        constraints_A (Tensor): A matrix of constraints with shape of (batchsize, n_constraints, x_dim)\n        constraints_b (Tensor): Bias term of constraints with shape of (batchsize, n_constraints)\n        x0 (Tensor): x0 tensor with shape of (batchsize, x_dim, 1). 
Based on the heuristic,\n        this can be the input space centroid x0, or the original optimal point x_prime\n    Return:\n        rearranged_A (Tensor): Rearranged matrix of constraints with shape of (batchsize, n_constraints, x_dim)\n        rearranged_b (Tensor): Bias term of constraints with shape of (batchsize, n_constraints)\n    \"\"\"\n    # Compute the normalized, directional distance from x_prime to constraints hyper-plane.\n    distance = (constraints_A.bmm(x0).squeeze(-1) + constraints_b) # shape: (B, N_constr)\n    l2_norm  = constraints_A.norm(p=2, dim=-1)                     # shape: (B, N_constr)\n    normed_dist = distance / l2_norm                               # shape: (B, N_constr)\n\n    # Sort the constraints according to this distance.\n    order = torch.sort(normed_dist, descending=True, dim=1)[1]\n    order_expand = order.unsqueeze(-1).expand(-1, -1, constraints_A.size(-1))\n    rearranged_A = constraints_A.gather(index=order_expand, dim=1)\n    rearranged_b = constraints_b.gather(index=order, dim=1)\n    return rearranged_A, rearranged_b\n\n@torch.jit.script\ndef _solve_dual_var(constr_a, object_a, constr_d, epsilon, a_mul_e=None):\n    r\"\"\"\n    Solve the following optimization problem:\n\n    Primal:         min_x   object_a^T x\n                    s.t.    constr_a^T x + constr_d <= 0,\n                            x0-epsilon <= x <= x0+epsilon\n\n    Dual:           min_x max_beta  object_a^T x + beta * (constr_a^T x + constr_d)\n                    s.t.            
x0 - epsilon <= x <= x0 + epsilon\n                                    beta >= 0\n\n    Strong duality:\n                    max_{beta >= 0} min_{x \\in X} object_a^T x + beta * (constr_a^T x + constr_d)\n\n    Dual norm:\n                    max_{beta >= 0} - |object_a + beta * constr_a|^T epsilon + beta * (constr_a^T x0 + constr_d) + object_a^T x0\n\n    The resulting one-dimensional problem over beta is piece-wise linear, so we just have to check each\n    turning point and the end points of beta (0 and +inf)\n\n    Args:\n        constr_a (Tensor): Constraint A matrix with shape of (batchsize, x_dim)\n        object_a (Tensor): Objective A matrix with shape of (batchsize, h_dim, x_dim)\n        constr_d (Tensor): Pre-computed bias term of constraint with shape of (batchsize, )\n                    constr_d = constr_a^T x0 + constr_b\n        epsilon (Tensor): Offset from the centroid to the input space boundary with shape of (batchsize, x_dim, 1)\n    Return:\n        optimal_beta (Tensor): The optimal beta value with shape of (batchsize, h_dim)\n    \"\"\"\n\n    B_act = constr_a.size(0)\n    H_act = object_a.size(1)\n    device = constr_a.device\n    dtype = constr_a.dtype\n\n    # --- prepare fill-in tensors\n    zeros = torch.zeros((1, 1, 1), device=device, dtype=dtype).expand(B_act, H_act, 1)\n    infs = torch.full((1, 1, 1), fill_value=torch.inf, dtype=dtype, device=device).expand(B_act, H_act, 1)\n\n    a_reshape = constr_a.unsqueeze(1)                   # shape: (B_act, 1, X)\n    epsilon_reshape = epsilon.view((B_act, 1, -1))      # shape: (B_act, 1, X)\n    b_reshape = constr_d.view((-1, 1, 1))               # shape: (B_act, 1, 1)\n\n    # q holds the turning points of the piece-wise linear function.\n    q = - object_a/a_reshape                            # shape: (B_act, H_act, X)\n    # idx indicates the ascending order of these turning points.\n    q_sort, idx = q.sort(dim=-1)                        # shape: (B_act, H_act, X)\n\n    # --- calculating the gradient 
w.r.t. beta within each interval ---\n    a_mul_e = (a_reshape * epsilon_reshape).expand(-1, H_act, -1)   # (B_act, H_act, X)\n    # a_mul_e = a_mul_e.expand(-1, H_act, -1)\n\n    #               (B_act, H_act, X)       (B_act, H_act, X)\n    a_sort = torch.gather(a_mul_e, dim=-1, index=idx)               # (B_act, H_act, X)\n\n    a_neg_cumsum = a_sort.abs().cumsum(dim=-1)              # shape: (B_act, H_act, X)\n    a_neg_cumsum = torch.cat((zeros, a_neg_cumsum), dim=-1) # shape: (B_act, H_act, 1+X)\n    a_pos_cumsum = a_neg_cumsum - a_neg_cumsum[:, :, -1:]   # shape: (B_act, H_act, 1+X)\n    grad_beta = a_pos_cumsum + a_neg_cumsum - b_reshape     # shape: (B_act, H_act, 1+X)\n\n    # Due to the non-increasing nature of grad_beta, if there is a turning point,\n    # the gradient must change from positive to negative, and this turning point is the optimal beta.\n    sign_change = torch.searchsorted(grad_beta, zeros, right=False)\n\n    # It might be the case that grad_beta is always positive when beta > 0.\n    # This means the maximization objective is ever-increasing, hence it is unbounded.\n    # In this case, an inf value is returned.\n\n    # Following comes a case of sign_change where all the turning points q are positive:\n    # (g stands for grad_beta, q stands for turning points)\n    #    g[0] = 2       g[1] = 1       g[2] = -1       g[3] = -3\n    # 0 --------- q[0] --------- q[1] ----------- q[2] ----------- ... 
--------> +inf\n    #                             ^\n    #                      sign_change=2\n    #\n    # q should represent the interval endpoints, hence we need to pad the left and right ends with 0 and inf respectively.\n\n    # cat shape: (B_act, H_act, 1+X+1)\n    q_new = torch.cat((zeros, q_sort, infs), dim=-1)                                       # shape: (B_act, H_act, X+2)\n    optimal_beta = torch.gather(q_new, dim=-1, index=sign_change).clamp(min=0).squeeze(-1) # shape: (B_act, H_act)\n\n    return optimal_beta\n\ndef sort_out_constr_batches(x_L, x_U, constraints, rearrange_constraints=False, no_return_inf=False):\n    r\"\"\"\n    Filter and preprocess input batches based on constraint feasibility.\n\n    This function examines which input regions\n        1) have no intersection with one of the constraints.\n        2) are fully covered by all the constraints.\n\n    It also optionally rearranges the constraint order for better numerical behavior,\n    and converts the constraint form from `(A, b)` to `(A, d)` where `d = A @ x0 + b`.\n    Here x0 means the centroid of the input region, that is x0 = (x_L + x_U) / 2.\n\n    Args:\n        x_L (Tensor): Lower bound of input box, shape (B, *).\n        x_U (Tensor): Upper bound of input box, shape (B, *).\n        constraints (Tuple[Tensor, Tensor] or None):\n            A tuple `(A, b)` representing per-batch linear constraints.\n            - `A`: shape (B, N_constr, X)\n            - `b`: shape (B, N_constr)\n            If None or empty, the function returns early.\n        rearrange_constraints (bool):\n            Whether to rearrange constraints for better solver performance. Default: False.\n        no_return_inf (bool):\n            If True, infeasible batches will be excluded from `active_indices`.\n            Otherwise, infeasible batches are still marked active. 
Default: False.\n\n    Returns:\n        constraints (Optional[Tuple[Tensor, Tensor]]):\n            Filtered and reshaped constraint tuple `(A, d)` for active batches only.\n            - `A`: shape (B_active, N_constr, X)\n            - `d`: shape (B_active, N_constr)\n            If all batches are fully covered, returns None.\n\n        sorted_out_batches (dict): Diagnostic and filtering info with keys:\n            - 'infeasible_batches' (BoolTensor): Shape (B,), True if the batch has no feasible region.\n                                                 If all elements are False, it is set to None, which saves space and time.\n            - 'fully_covered' (BoolTensor): Shape (B,), True if the batch is completely covered by constraints.\n            - 'active_indices' (LongTensor): Indices of batches that are neither fully covered nor infeasible.\n    \"\"\"\n    sorted_out_batches = None\n    if constraints is None or constraints[0] is None or constraints[0].numel() == 0:\n        return None, sorted_out_batches\n\n    # Read arguments and do some necessary reshaping\n    assert x_L is not None and x_U is not None, \"If constrained concretize is enabled, x_L and x_U cannot be None!\"\n    x0 = (x_L + x_U) / 2\n    epsilon = (x_U - x_L) / 2\n    constraints_A, constraints_b = constraints\n    batch_size = x0.shape[0]\n    x_dim = x0[0].numel()\n    x0 = x0.view((batch_size, x_dim, 1))                        # shape: (B, X, 1)\n    epsilon = epsilon.view((batch_size, x_dim, 1))              # shape: (B, X, 1)\n\n    no_intersection, fully_covered = _sort_out_constraints(constraints_A, constraints_b, x0, epsilon)\n    if fully_covered.all():\n        print(\"All the added constraints fully cover the input space. No need to apply constraints.\")\n        return None, sorted_out_batches\n    sorted_out_batches = {}\n    sorted_out_batches[\"infeasible_batches\"] = no_intersection\n    # If there's no infeasible batch, simply set it to None. 
\n    # This provides a shortcut when updating the infeasible_batches vector.\n    # When the batch size is large and the NN model has many perturbed roots, this can save us some time.\n    sorted_out_batches[\"fully_covered\"] = fully_covered\n    active_mask = ~fully_covered\n    if no_intersection is not None and no_return_inf:\n        active_mask = ~no_intersection & active_mask\n    active_indices = torch.nonzero(active_mask, as_tuple=True)[0]\n    sorted_out_batches[\"active_indices\"] = active_indices\n\n    # Now the constraints tuple only contains active constraints, shape change: (B, N_constr, X) -> (B_act, N_constr, X)\n    constraints_A = constraints_A[active_indices]   # shape: (B_act, N_constr, X)\n    constraints_b = constraints_b[active_indices]   # shape: (B_act, N_constr)\n    active_x0 = x0[active_indices]\n    if rearrange_constraints:\n        constraints_A, constraints_b = _dist_rearrange(constraints_A, constraints_b, active_x0)\n    # Also, we replace the constraints_b term with the constraints_d term.\n    # For the usage of constraints_d, please check the _solve function and the constraints_solving function.\n    constraints_d = torch.einsum('bkx, bxo->bk', constraints_A, active_x0) + constraints_b    # shape: (B_act, N_constr)\n    # Only store the constraints for active batches.\n    constraints = (constraints_A, constraints_d)\n\n    return constraints, sorted_out_batches\n\ndef constraints_solving(\n    x_L, x_U, objective, constraints, sign=-1.0,\n    sorted_out_batches={}, objective_indices=None,\n    constraints_enable=True, no_return_inf=False,\n    max_chunk_size=None, safety_factor=0.8, solver_memory_factor=2.0,\n    timer=None,\n    aux_bounds=None,\n    x0=None, epsilon=None,\n    act_x0=None, act_eps=None,\n    use_grad=True\n    ):\n    r\"\"\"\n    Combined constraint solving function with conditional logic based on objective shape.\n\n    - If objective is eyeC or broadcastable (shape[0]=1), uses a vectorized,\n        auto-chunked 
approach.\n    - If objective has batch dim matching input (shape[0]=N_batch), uses the\n        original approach (repeating inputs, no chunking).\n\n    Solves LP: max / min A_t * x, s.t. A_c * x + b_c <= 0, x_L <= x <= x_U\n\n    Args:\n        x_L, x_U (Tensor)               : Input bounds tensors.\n        objective (Tensor)              : Target coefficients (Tensor or eyeC).\n            - Tensor shape: (H, X), (1, H, X), or (N_batch, H, X).\n            - eyeC: Represents the identity matrix.\n        constraints (tuple, optional)   : Tuple (A_c, d_c) or None.\n        sign (float, optional)          : -1.0 for lower bound, +1.0 for upper bound.\n        sorted_out_batches (dict, optional): Dict with pre-filtered batch masks. Please check `sort_out_constr_batches` for more info.\n        constraints_enable (bool, optional): Flag for enabling constraint solving. This is set for heuristic hybrid solving and should be True by default.\n        no_return_inf (bool, optional)  : Flag for returning inf values. If True, this function will return inf for all the infeasible subproblems.\n                        Otherwise, it returns naive bounds for infeasible ones.\n        max_chunk_size, safety_factor, solver_memory_factor: Params for chunking memory.\n                max_chunk_size:\n                        A hard upper limit on the number of problems to be processed in a single\n                        chunk, regardless of available memory. If set to an integer, the\n                        auto-calculated chunk size will not exceed this value.\n                        Use Case: Prevents the solver from creating a single, massive chunk that\n                        could cause system unresponsiveness, even if memory is technically\n                        available. 
Set to None to allow the function to use its own dynamic\n                        calculation.\n                safety_factor:\n                        A float between 0.0 and 1.0 that specifies what fraction of the free\n                        GPU memory should be considered \"usable\" for the calculation. For example,\n                        a value of 0.8 means the function will only use 80% of the available\n                        free memory as its budget.\n                        Use Case: This buffer helps prevent \"Out of Memory\" (OOM) errors by\n                        accounting for memory fragmentation, memory used by other processes, or\n                        overhead from the CUDA driver itself. A lower value is safer but may\n                        result in smaller chunks and thus slower overall processing.\n                solver_memory_factor:\n                        A heuristic multiplier used to estimate the memory consumed by the\n                        iterative solver loop. The theoretical memory usage is multiplied by\n                        this factor to create a more realistic estimate.\n                        Use Case: The exact memory allocated for intermediate tensors and\n                        computations within the solver can be complex to predict perfectly. This\n                        factor provides a \"fudge factor\" to pad the memory estimation, ensuring\n                        that the dynamically created tensors inside the solver loop do not cause\n                        an OOM error. Adjust this if you consistently face memory issues during\n                        the solver phase.\n        objective_indices (Tensor, optional): Indices tensor of shape (N_batch, H_active) indicating\n                            which objectives to compute. 
If None, all are computed.\n        timer: Optional Timer object.\n        aux_bounds (Tensor, optional)   : When hybrid constraint solving is enabled, the constraints_solving function will be called twice.\n                                       For its second run, we load the result from the first run to save the time of computing naive results.\n        x0, epsilon (Tensor, optional)  : x0 and epsilon to solve on.\n                                    If not given, x0 and epsilon are computed from x_L and x_U.\n        act_x0, act_eps (Tensor, optional): Active x0 and epsilon to solve on.\n        use_grad (bool, optional): If False, the main computation is wrapped in\n                                    `torch.no_grad()` for better performance and lower\n                                    memory usage. Set to True only when gradients are\n                                    required (e.g., for clip during alpha crown). Defaults to True.\n\n    Returns:\n        bound (Tensor): Computed bounds (N_batch, H, 1).\n        infeasible_batches (BoolTensor, optional) : If no_return_inf is True, `infeasible_batches` will be returned.\n                                                    It is a boolean tensor with shape (batch_size, ), with True indicating the batch is infeasible. 
\n    \"\"\"\n    if timer: timer.start('init')\n    if timer: timer.start(\"concretize\")\n\n    device = x_L.device\n    N_batch = x_L.size(0)\n\n    epsilon = (x_U - x_L) / 2.0 if epsilon is None else epsilon\n    x0 = (x_U + x_L) / 2.0 if x0 is None else x0\n    epsilon = epsilon.reshape((N_batch, -1, 1))\n    x0 = x0.reshape((N_batch, -1, 1))\n\n    is_eyeC = isinstance(objective, eyeC)\n\n    # --- Naive Case (No Constraints) ---\n    no_constraints_condition = (constraints is None) or (constraints[0].numel() == 0)\n    if no_constraints_condition or (not constraints_enable):\n        if is_eyeC:\n            solved_obj = x0 + sign * epsilon                                    # Shape: (N_batch, X, 1)\n        else:\n            base_term = torch.einsum('bhx,bxo->bho', objective, x0)             # Shape: (N_batch, X, 1)\n            eps_term = torch.einsum('bhx,bxo->bho', objective.abs(), epsilon)   # Shape: (N_batch, X, 1)\n            solved_obj = base_term + sign * eps_term # Shape: (N_batch, H, 1)\n        if timer: timer.add(\"init\")\n        if timer: timer.add(\"concretize\")\n        if no_return_inf:\n            return solved_obj, None\n        else:\n            return solved_obj\n\n    with torch.set_grad_enabled(use_grad):\n        is_broadcastable = False\n        is_batch_specific = False\n        H = -1 # Hidden dimension\n        X = x0.size(1) # Input X dimension\n        if is_eyeC:\n            is_broadcastable = True\n            H = X\n            # Internally represent eyeC as identity matrix for broadcastable path.\n            objective_tensor = torch.eye(X, device=device).unsqueeze(0) # Shape (1, X, X)\n        else:\n            if objective.shape[0] != N_batch:\n                # objective comes in shape of (H, X) or (1, H, X).\n                # It will be broadcasted to (B, H, X) later.\n                # Currently, is_broadcastable is designed for relu-bab, which usually takes much gpu memory,\n                # so 
is_broadcastable is also a control flag for objective chunking.\n                is_broadcastable = True\n            else:\n                # objective comes in shape of (B, H, X).\n                is_batch_specific = True\n            H = objective.shape[1]\n            objective_tensor = objective\n            if objective.shape[2] != X: raise ValueError(\"Objective shape mismatch\")\n\n        # --- Constrained Case ---\n        # --- Calculate Naive Bounds (used as default/fallback) ---\n        naive_bounds = torch.zeros(N_batch, H, 1, device=device)\n        if aux_bounds is not None:\n            naive_bounds_all = aux_bounds.flatten(1).unsqueeze(-1)\n        elif is_eyeC:\n            naive_bounds_all = x0 + sign * epsilon # Shape (N_batch, X, 1) -> (N_batch, H, 1)\n        elif is_broadcastable:\n            # obj_tensor is (1, H, X)\n            base_term_naive = torch.einsum('shx,bxo->bho', objective_tensor, x0)\n            eps_term_naive = torch.einsum('shx,bxo->bho', objective_tensor.abs(), epsilon)\n            naive_bounds_all = base_term_naive + sign * eps_term_naive # Shape (N_batch, H, 1)\n        elif is_batch_specific:\n            # obj_tensor is (N, H, X)\n            base_term_naive = torch.einsum('bhx,bxo->bho', objective_tensor, x0)\n            eps_term_naive = torch.einsum('bhx,bxo->bho', objective_tensor.abs(), epsilon)\n            naive_bounds_all = base_term_naive + sign * eps_term_naive # Shape (N_batch, H, 1)\n        else:\n            raise RuntimeError(\"Internal logic error in naive bound calculation\")\n        naive_bounds = naive_bounds_all # Assign calculated bounds\n\n        # Final bounds tensor initialized as naive bounds\n        final_bounds = naive_bounds\n        fill_value_inf = torch.tensor(torch.inf if sign == -1.0 else -torch.inf, device=device)\n\n        # --- Initial Batch Filtering (Common Logic) ---\n        active_indices = sorted_out_batches.get(\"active_indices\", None)\n        if active_indices is 
None:\n            fully_covered = sorted_out_batches.get(\"fully_covered\", torch.zeros(N_batch, dtype=torch.bool, device=device))\n            active_batches_mask = ~fully_covered # Batches requiring solver\n            if no_return_inf:\n                infeasible_batches = sorted_out_batches.get(\"infeasible_batches\", torch.zeros(N_batch, dtype=torch.bool, device=device))\n                active_batches_mask = ~infeasible_batches & active_batches_mask\n            active_indices = torch.nonzero(active_batches_mask, as_tuple=True)[0]\n        B_act = active_indices.numel() # Number of batches needing the solver.\n        if timer: timer.add('init') # Combined timing for setup.\n\n        # --- Early Exit if No Active Batches ---\n        if B_act == 0:\n            print(f\"Constrained concretize: No active batches after filtering.\")\n            # Ensure non-active parts have naive bounds before returning.\n            # (already done above by initializing with naive/inf)\n            if timer: timer.add(\"concretize\")\n            final_bounds = naive_bounds\n            if no_return_inf:\n                return final_bounds, None\n            else:\n                return final_bounds\n\n        constraints_A, constraints_d = constraints\n        n_constraints = constraints_A.size(1)\n\n        # --- Dynamic Chunk Size Calculation ---\n        if is_batch_specific:\n            # If objective is batch-specific, we do not chunk it.\n            num_chunks = 1\n            final_chunk_size = B_act\n        else:\n            # This block dynamically estimates the optimal chunk size to maximize GPU\n            # utilization while preventing out-of-memory (OOM) errors.\n            calculated_chunk_size = B_act\n            free_mem, total_mem = torch.cuda.mem_get_info()\n            usable_mem = free_mem * safety_factor\n            obj_dtype = objective.dtype\n            dtype_size = torch.finfo(obj_dtype).bits // 8\n            mem_constraints_per_item = 
(n_constraints * X + n_constraints) * dtype_size\n            mem_x0eps_per_item = 2 * X * dtype_size\n            mem_ori_c_per_item = H * X * dtype_size\n            mem_dual_obj_per_item = H * dtype_size\n            mem_solver_per_item_bh = H * (X + X + 1 + X + 1) * dtype_size * solver_memory_factor\n            mem_masks_temps_per_item = H * 2 # approx\n            mem_per_item_est = (mem_constraints_per_item + mem_x0eps_per_item +\n                                mem_ori_c_per_item + mem_dual_obj_per_item +\n                                mem_solver_per_item_bh + mem_masks_temps_per_item) * 5\n            if mem_per_item_est > 0:\n                estimated_max_chunk = max(1, floor(usable_mem / mem_per_item_est))\n                calculated_chunk_size = min(B_act, estimated_max_chunk)\n            if max_chunk_size is not None and max_chunk_size > 0:\n                final_chunk_size = min(calculated_chunk_size, max_chunk_size)\n            else:\n                final_chunk_size = calculated_chunk_size\n            final_chunk_size = max(1, final_chunk_size) # Ensure chunk size is at least 1.\n            num_chunks = ceil(B_act / final_chunk_size)\n\n        if no_return_inf:\n            # Initialize infeasible_batches boolean mask to be None at first.\n            # If an infeasible batch does occur later, we will then initialize it to be an actual vector.\n            infeasible_batches = None\n\n        for i_chunk in range(num_chunks):\n            # --- Handle size and idx for this chunk ---\n            chunk_start_idx_rel = i_chunk * final_chunk_size\n            chunk_end_idx_rel = min(chunk_start_idx_rel + final_chunk_size, B_act)\n            current_chunk_size = chunk_end_idx_rel - chunk_start_idx_rel\n            if current_chunk_size == 0: continue\n            chunk_indices_abs = active_indices[chunk_start_idx_rel:chunk_end_idx_rel]\n\n            # --- Get matrices for this chunk ---\n            constr_A_mat = 
constraints_A[chunk_start_idx_rel:chunk_end_idx_rel]             # shape (B_act, n_constraints, X)\n            constr_d_mat = constraints_d[chunk_start_idx_rel:chunk_end_idx_rel]             # shape (B_act, n_constraints)\n            if act_x0 is not None:\n                x0_mat = act_x0[chunk_start_idx_rel:chunk_end_idx_rel]\n            else:\n                x0_mat = x0[chunk_indices_abs]                                              # shape (B_act, X, 1)\n            if act_eps is not None:\n                eps_mat = act_eps[chunk_start_idx_rel:chunk_end_idx_rel]\n            else:\n                eps_mat = epsilon[chunk_indices_abs]                                        # shape (B_act, X, 1)\n\n            if is_broadcastable:\n                ori_c_mat = objective_tensor.expand(current_chunk_size, H, X).clone()\n            else:\n                ori_c_mat = objective_tensor[chunk_indices_abs].clone()             # shape: (B_act, H, X)\n\n            if objective_indices is not None:                                       # shape: (B, H_act) \n                # Select the mask rows corresponding to the active batches in this chunk\n                current_objective_indices = objective_indices[chunk_indices_abs]    # shape: (B_act, H_act)\n                idx_unsqueeze = current_objective_indices.unsqueeze(-1)             # shape: (B_act, H_act, 1)\n                idx_expand = idx_unsqueeze.expand(-1, -1, X)                        # shape: (B_act, H_act, X)\n                ori_c_mat = ori_c_mat.gather(index=idx_expand, dim=1)               # shape: (B_act, H_act, X)\n\n            obj_mat = ori_c_mat                                                             # shape (B_act, H_act, X)\n            # Initialize dual part and base part\n            # Note that the final minimal value is:\n            #    objective^T x0 +                                                                base_part\n            #    constr_d_0 * beta_0 + constr_d_1 * beta_1 + 
... +                               dual_part 1\n            #    - ( objective+ constr_a_0 * beta_0 + constr_a_1 * beta_1)^T epsilon             dual_part 2\n            base_objective_term = torch.einsum('bhx,bxo->bh', obj_mat, x0_mat) # shape: (B_act, H_act)\n            dual_objective_part = torch.zeros_like(base_objective_term)        # shape: (B_act, H_act)\n\n            # --- Initialize State for Vectorized Loop (Chunk) ---\n            if sign == 1.0: # Adjust for minimization problem solved by _solve\n                obj_mat *= -1.0                                                # shape (B_act, H_act, X)\n                base_objective_term *= -1.0\n\n            # --- Vectorized Constraint Loop (Operating on Chunk) ---\n            for k in range(n_constraints):\n                constr_a_solve = constr_A_mat[:, k, :] # constraint A matrix shape (B_act, X)\n                constr_d_solve = constr_d_mat[:, k]    # related bias term   shape (B_act,)\n                epsilon_solve = eps_mat                # epsilon             shape (B_act, X)\n                object_a_solve = obj_mat               # objective matrix    shape (B_act, H_act, X)\n\n                with torch.no_grad():   # Otherwise, the gradients will mess up the alpha crown optimization.\n                    optimal_beta = _solve_dual_var(constr_a_solve, object_a_solve, constr_d_solve, epsilon_solve) # shape (B_act, H_act)\n\n                # Accumulation for the parentheses term in dual part 2\n                obj_mat += optimal_beta.unsqueeze(-1) * constr_a_solve.unsqueeze(1)    # shape (B_act, H_act, X)        \n                #            (B_act, H_act, 1)          (B_act, 1, X)\n                # Accumulation of dual part 1\n                dual_objective_part += optimal_beta * constr_d_solve.unsqueeze(1)      # shape (B_act, H_act)\n                #                     (B_act, H_act)    (B_act, 1)\n\n            # --- End of k loop ---\n            # --- Final Objective Calculation 
for Unfinished Items in Chunk ---\n            final_obj_abs = obj_mat.abs() # shape: (B_act, H_act, X)\n            final_eps_mat = eps_mat       # shape: (B_act, X, 1)\n            final_eps_term = torch.einsum('nhx,nxo->nh', final_obj_abs, final_eps_mat) # shape: (B_act, H_act)\n            dual_objective_part -= final_eps_term\n\n            # --- Combine terms and handle mask ---\n            final_obj_minimized = base_objective_term + dual_objective_part # shape: (B_act, H_act)\n            if sign == 1.0: final_obj_optimal = -final_obj_minimized        # Flip sign back if maximizing.\n            else: final_obj_optimal = final_obj_minimized\n\n            # Previously, infeasible batches were handled after running through all the chunks, while processing final_bounds.\n            # But that would require creating a copy of the naive bounds.\n            # To save space and time, we process final_obj_optimal here instead.\n            final_obj_optimal = torch.nan_to_num(final_obj_optimal, nan=fill_value_inf.item(), posinf=fill_value_inf.item(), neginf=-fill_value_inf.item())\n            if no_return_inf:\n                infeasible_batches_chunk = final_obj_optimal.isinf().any(1)\n                if infeasible_batches_chunk.any():\n                    # Note that infeasible_batches was initialized as None\n                    infeasible_batches = torch.full((N_batch, ), fill_value=False, device=device, dtype=torch.bool) if infeasible_batches is None else infeasible_batches\n                    infeasible_batches[chunk_indices_abs] = infeasible_batches_chunk\n                    # Set the bounds of infeasible batches to be naive bounds\n                    infeasible_batches_chunk_indices_abs = chunk_indices_abs[infeasible_batches_chunk]\n                    if objective_indices is not None:\n                        naive_bounds_chunk = naive_bounds[infeasible_batches_chunk_indices_abs].squeeze(-1)\n                        # Get the infeasible objective indices 
for this chunk.\n                        current_infeasible_objective_indices = current_objective_indices[infeasible_batches_chunk]\n                        final_obj_optimal[infeasible_batches_chunk] = torch.gather(naive_bounds_chunk, dim=1, index=current_infeasible_objective_indices)\n                    else:\n                        final_obj_optimal[infeasible_batches_chunk] = naive_bounds[infeasible_batches_chunk_indices_abs].squeeze(-1)\n\n            # Put the result of this chunk back into the overall final bounds\n            if objective_indices is not None:\n                final_bounds_active_chunk = final_bounds[chunk_indices_abs]  \n                final_bounds_active_chunk.scatter_(dim=1, index=idx_unsqueeze, src=final_obj_optimal.unsqueeze(-1))\n                final_bounds[chunk_indices_abs] = final_bounds_active_chunk\n            else:\n                final_bounds[chunk_indices_abs] = final_obj_optimal.unsqueeze(-1)\n\n        if no_return_inf:\n            if timer: timer.add(\"concretize\")\n            return final_bounds, infeasible_batches\n        else:\n            if timer: timer.add(\"concretize\")\n            return final_bounds\n\n"
  },
  {
    "path": "auto_LiRPA/cuda/cuda_kernels.cu",
    "content": "#include <torch/extension.h>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <vector>\n\n__global__ void cuda_double2float_rd_kernel(const double* __restrict__ inputs,\n    float* __restrict__ outputs, const size_t tensor_size) {\n  const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < tensor_size) {\n    outputs[idx] = __double2float_rd(inputs[idx]);\n  }\n}\n\n__global__ void cuda_double2float_ru_kernel(const double* __restrict__ inputs,\n    float* __restrict__ outputs, const size_t tensor_size) {\n  const int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx < tensor_size) {\n    outputs[idx] = __double2float_ru(inputs[idx]);\n  }\n}\n\ntorch::Tensor cuda_double2float_forward(torch::Tensor input,\n    const std::string direction) {\n  auto total_elem = input.numel();\n  auto output = torch::empty_like(input, torch::ScalarType::Float);\n\n  const int threads = 1024;\n  const int blocks = (total_elem + threads - 1) / threads;\n  \n  if (direction == \"down\") {\n    cuda_double2float_rd_kernel<<<blocks, threads>>>(input.data<double>(), output.data<float>(), total_elem);\n  }\n  else {\n    cuda_double2float_ru_kernel<<<blocks, threads>>>(input.data<double>(), output.data<float>(), total_elem);\n  }\n  return output;\n}\n\n"
  },
  {
    "path": "auto_LiRPA/cuda/cuda_utils.cpp",
    "content": "#include <torch/extension.h>\n\n#include <vector>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n\ntorch::Tensor cuda_double2float_forward(\n    torch::Tensor input, const std::string direction);\n\ntorch::Tensor double2float_foward(\n    torch::Tensor input, const std::string direction) {\n  TORCH_CHECK((direction == \"down\") || (direction == \"up\"), \"Unsupported direction, must be down or up.\");\n  TORCH_CHECK(input.type().scalarType() == torch::ScalarType::Double, \"This function only supports DoubleTensor as inputs.\");\n  CHECK_CUDA(input);\n  return cuda_double2float_forward(input, direction);\n}\n\n/* \n * Usage: double2float(tensor, direction)\n * \"tensor\" must be a DoubleTensor on GPU.\n * \"direction\" is a string, can be \"up\" (round up) or \"down\" (round down).\n */\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"double2float\", &double2float_foward, \"Convert double to float with rounding direction control (direction = 'up' or 'down').\");\n}\n"
  },
  {
    "path": "auto_LiRPA/cuda_utils.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport os\nimport sys\nimport torch\nfrom torch.utils.cpp_extension import load, BuildExtension, CUDAExtension\nfrom setuptools import setup\n\nclass DummyCudaClass:\n    \"\"\"A dummy class with error message when a CUDA function is called.\"\"\"\n    def __getattr__(self, attr):\n        if attr == \"double2float\":\n            # When CUDA module is not built successfully, use a workaround.\n            def _f(x, d):\n                print('WARNING: Missing CUDA kernels. 
Please enable CUDA build by setting environment variable AUTOLIRPA_ENABLE_CUDA_BUILD=1 for the correct behavior!')\n                return x.float()\n            return _f\n        def _f(*args, **kwargs):\n            raise RuntimeError(f\"method {attr} not available because CUDA module was not built.\")\n        return _f\n\nif __name__ == \"__main__\" and len(sys.argv) > 1:\n    # Build and install native CUDA modules that can be directly imported later\n    print('Building and installing native CUDA modules...')\n    setup(\n        name='auto_LiRPA_cuda_utils',\n        ext_modules=[CUDAExtension('auto_LiRPA_cuda_utils', [\n            'auto_LiRPA/cuda/cuda_utils.cpp',\n            'auto_LiRPA/cuda/cuda_kernels.cu'\n        ])],\n        cmdclass={'build_ext': BuildExtension.with_options()},\n    )\n    exit(0)\n\nif torch.cuda.is_available() and os.environ.get('AUTOLIRPA_ENABLE_CUDA_BUILD', False):\n    try:\n        import auto_LiRPA_cuda_utils as _cuda_utils\n    except:\n        print('CUDA modules have not been installed')\n        try:\n            print('Building native CUDA modules...')\n            code_dir = os.path.dirname(os.path.abspath(__file__))\n            verbose = os.environ.get('AUTOLIRPA_DEBUG_CUDA_BUILD', None) is not None\n            _cuda_utils = load(\n                'cuda_utils', [os.path.join(code_dir, 'cuda', 'cuda_utils.cpp'), os.path.join(code_dir, 'cuda', 'cuda_kernels.cu')], verbose=verbose)\n            print('CUDA modules have been built.')\n        except:\n            print('CUDA module build failure. 
Some features will be unavailable.')\n            print('Please make sure the latest CUDA toolkit is installed in your system.')\n            if verbose:\n                # Print the actual build error traceback instead of the traceback object's repr.\n                import traceback\n                traceback.print_exc()\n            else:\n                print('Set environment variable AUTOLIRPA_DEBUG_CUDA_BUILD=1 to view build log.')\n            _cuda_utils = DummyCudaClass()\nelse:\n    if os.environ.get('AUTOLIRPA_ENABLE_CUDA_BUILD', False):\n        print('CUDA unavailable. Some features are disabled.')\n    _cuda_utils = DummyCudaClass()\n\ndouble2float = _cuda_utils.double2float\n\ndef test_double2float():\n    # Test the double2float function.\n    import time\n    shape = (3,4,5)\n\n    a = torch.randn(size=shape, dtype=torch.float64, device='cuda')\n    a = a.transpose(0,1)\n\n    au = _cuda_utils.double2float(a, \"up\")\n    ad = _cuda_utils.double2float(a, \"down\")\n\n    print(a.size(), au.size(), ad.size())\n\n    a_flatten = a.reshape(-1)\n    au_flatten = au.reshape(-1)\n    ad_flatten = ad.reshape(-1)\n\n    for i in range(a_flatten.numel()):\n        ai = a_flatten[i].item()\n        aui = au_flatten[i].item()\n        adi = ad_flatten[i].item()\n        print(adi, ai, aui)\n        assert adi <= ai\n        assert aui >= ai\n    del a, au, ad, a_flatten, au_flatten, ad_flatten\n\n    # Performance benchmark.\n    for j in [1, 4, 16, 64, 256, 1024]:\n        shape = (j, 512, 1024)\n        print(f'shape: {shape}')\n        t = torch.randn(size=shape, dtype=torch.float64, device='cuda')\n\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for i in range(10):\n            tt = t.float()\n        torch.cuda.synchronize()\n        del tt\n        pytorch_time = time.time() - start_time\n        print(f'pytorch rounding time: {pytorch_time:.4f}')\n\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for i in range(10):\n            tu = _cuda_utils.double2float(t, \"up\")\n        torch.cuda.synchronize()\n        del tu\n  
      roundup_time = time.time() - start_time\n        print(f'cuda round up time: {roundup_time:.4f}')\n\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for i in range(10):\n            td = _cuda_utils.double2float(t, \"down\")\n        torch.cuda.synchronize()\n        del td\n        rounddown_time = time.time() - start_time\n        print(f'cuda round down time: {rounddown_time:.4f}')\n\n        del t\n\n\nif __name__ == \"__main__\":\n    if len(sys.argv) == 1:\n        # Some tests. It's not possible to test them automatically because travis does not have CUDA.\n        test_double2float()\n"
  },
  {
    "path": "auto_LiRPA/edit_graph.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Edit the computational graph in BoundedModule.\"\"\"\n\nfrom auto_LiRPA.bound_ops import Bound\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\n# Make sure the nodes already have `name` and `input_name`\ndef add_nodes(self: 'BoundedModule', nodes):\n    # TODO check duplicate names\n    nodes = [(node if isinstance(node, Bound) else node.bound_node)\n                for node in nodes]\n    for node in nodes:\n        if node.name in self._modules:\n            raise NameError(f'Node with name {node.name} already exists')\n        self._modules[node.name] = node\n        node.output_name = []\n        if len(node.inputs) == 0:\n            self.root_names.append(node.name)\n    for node in nodes:\n        for l_pre in node.inputs:\n            l_pre.output_name.append(node.name)\n        if (getattr(node, 'has_constraint', False) and\n                node.name not in self.layers_with_constraint):\n            self.layers_with_constraint.append(node.name)\n\n\ndef add_input_node(self: 'BoundedModule', node, index=None):\n    self.add_nodes([node])\n    self.input_name.append(node.name)\n    # default value for input_index\n    if index == 'auto':\n        index = max([0] + [(i + 1)\n                    for i in self.input_index if i is not None])\n    self.input_index.append(index)\n\n\ndef delete_node(self: 'BoundedModule', node):\n    for node_inp in node.inputs:\n        node_inp.output_name.pop(node_inp.output_name.index(node.name))\n    self._modules.pop(node.name)\n    # TODO Create a list to contain all such lists such as\n    # \"relus\" and \"optimizable_activations\"\n    self.relus = [\n        item for item in self.relus if item != node]\n    self.optimizable_activations = [\n        item for item in self.optimizable_activations if item != 
node]\n\n\ndef replace_node(self: 'BoundedModule', node_old, node_new):\n    assert node_old != node_new\n    for node in self.nodes():\n        for i in range(len(node.inputs)):\n            if node.inputs[i] == node_old:\n                node.inputs[i] = node_new\n    node_new.output_name += node_old.output_name\n    if self.final_name == node_old.name:\n        self.final_name = node_new.name\n    for i in range(len(self.output_name)):\n        if self.output_name[i] == node_old.name:\n            self.output_name[i] = node_new.name\n    self.delete_node(node_old)\n"
  },
  {
    "path": "auto_LiRPA/eps_scheduler.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport random\nfrom .utils import logger\n\nclass BaseScheduler(object):\n    def __init__(self, max_eps, opt_str):\n        self.parse_opts(opt_str)\n        self.prev_loss = self.loss = self.max_eps = self.epoch_length = float(\"nan\")\n        self.eps = 0.0\n        self.max_eps = max_eps\n        self.is_training = True\n        self.epoch = 0\n        self.batch = 0\n\n    def __repr__(self):\n        return '<BaseScheduler: eps {}, max_eps {}>'.format(self.eps, self.max_eps)\n\n    def parse_opts(self, s):\n        opts = s.split(',')\n        self.params = {}\n        for o in opts:\n            if o.strip():\n                key, val = o.split('=')\n                self.params[key] = val\n\n    def get_max_eps(self):\n        return self.max_eps\n\n    def get_eps(self):\n        return self.eps\n\n    def reached_max_eps(self):\n        return abs(self.eps - self.max_eps) < 1e-3\n\n    def step_batch(self, verbose=False):\n        if self.is_training:\n            self.batch += 1\n        return\n\n    def step_epoch(self, verbose=False):\n        if self.is_training:\n            self.epoch += 1\n        return\n\n    def update_loss(self, new_loss):\n        self.prev_loss = self.loss\n        self.loss = new_loss\n\n    def train(self):\n        self.is_training = True\n\n    def eval(self):\n        self.is_training = False\n\n    # Set how many batches in an epoch\n    def set_epoch_length(self, epoch_length):\n        self.epoch_length = epoch_length\n\n\nclass FixedScheduler(BaseScheduler):\n    def __init__(self, max_eps, opt_str=\"\"):\n        super(FixedScheduler, self).__init__(max_eps, opt_str)\n        self.eps = self.max_eps\n\n\nclass LinearScheduler(BaseScheduler):\n\n    def __init__(self, max_eps, opt_str):\n        super(LinearScheduler, self).__init__(max_eps, opt_str)\n        
self.schedule_start = int(self.params['start'])\n        self.schedule_length = int(self.params['length'])\n        self.epoch_start_eps = self.epoch_end_eps = 0\n\n    def __repr__(self):\n        return '<LinearScheduler: start_eps {:.3f}, end_eps {:.3f}>'.format(\n            self.epoch_start_eps, self.epoch_end_eps)\n\n    def step_epoch(self, verbose = True):\n        self.epoch += 1\n        self.batch = 0\n        if self.epoch < self.schedule_start:\n            self.epoch_start_eps = 0\n            self.epoch_end_eps = 0\n        else:\n            eps_epoch = self.epoch - self.schedule_start\n            if self.schedule_length == 0:\n                self.epoch_start_eps = self.epoch_end_eps = self.max_eps\n            else:\n                eps_epoch_step = self.max_eps / self.schedule_length\n                self.epoch_start_eps = min(eps_epoch * eps_epoch_step, self.max_eps)\n                self.epoch_end_eps = min((eps_epoch + 1) * eps_epoch_step, self.max_eps)\n        self.eps = self.epoch_start_eps\n        if verbose:\n            logger.info(\"Epoch {:3d} eps start {:7.5f} end {:7.5f}\".format(self.epoch, self.epoch_start_eps, self.epoch_end_eps))\n\n    def step_batch(self):\n        if self.is_training:\n            self.batch += 1\n            eps_batch_step = (self.epoch_end_eps - self.epoch_start_eps) / self.epoch_length\n            self.eps = self.epoch_start_eps + eps_batch_step * (self.batch - 1)\n            if self.batch > self.epoch_length:\n                logger.warning('Warning: we expect {} batches in this epoch but this is batch {}'.format(self.epoch_length, self.batch))\n                self.eps = self.epoch_end_eps\n\nclass RangeScheduler(BaseScheduler):\n\n    def __init__(self, max_eps, opt_str):\n        super(RangeScheduler, self).__init__(max_eps, opt_str)\n        self.schedule_start = int(self.params['start'])\n        self.schedule_length = int(self.params['length'])\n\n    def __repr__(self):\n        return 
'<RangeScheduler: epoch [{}, {}]>'.format(\n            self.schedule_start, self.schedule_start + self.schedule_length)\n\n    def step_epoch(self, verbose = True):\n        self.epoch += 1\n        if self.epoch >= self.schedule_start and self.epoch < self.schedule_start + self.schedule_length:\n            self.eps = self.max_eps\n        else:\n            self.eps = 0\n\n    def step_batch(self):\n        pass\n\nclass BiLinearScheduler(LinearScheduler):\n\n    def __init__(self, max_eps, opt_str):\n        super(BiLinearScheduler, self).__init__(max_eps, opt_str)\n        self.schedule_start = int(self.params['start'])\n        self.schedule_length = int(self.params['length'])\n        self.schedule_length_half = self.schedule_length / 2\n        self.epoch_start_eps = self.epoch_end_eps = 0\n\n    def __repr__(self):\n        return '<BiLinearScheduler: start_eps {:.5f}, end_eps {:.5f}>'.format(\n            self.epoch_start_eps, self.epoch_end_eps)\n\n    def step_epoch(self, verbose = True):\n        self.epoch += 1\n        self.batch = 0\n        if self.epoch < self.schedule_start:\n            self.epoch_start_eps = 0\n            self.epoch_end_eps = 0\n        else:\n            eps_epoch = self.epoch - self.schedule_start\n            eps_epoch_step = self.max_eps / self.schedule_length_half\n            if eps_epoch < self.schedule_length_half:\n                self.epoch_start_eps = min(eps_epoch * eps_epoch_step, self.max_eps)\n                self.epoch_end_eps = min((eps_epoch + 1) * eps_epoch_step, self.max_eps)\n            else:\n                self.epoch_start_eps = max(0,\n                    self.max_eps - ((eps_epoch - self.schedule_length_half) * eps_epoch_step))\n                self.epoch_end_eps = max(0, self.epoch_start_eps - eps_epoch_step)\n        self.eps = self.epoch_start_eps\n        if verbose:\n            logger.info(\"Epoch {:3d} eps start {:7.5f} end {:7.5f}\".format(self.epoch, self.epoch_start_eps, 
self.epoch_end_eps))\n\n\nclass SmoothedScheduler(BaseScheduler):\n\n    def __init__(self, max_eps, opt_str):\n        super(SmoothedScheduler, self).__init__(max_eps, opt_str)\n        # Epoch number to start schedule\n        self.schedule_start = int(self.params['start'])\n        # Epoch length for completing the schedule\n        self.schedule_length = int(self.params['length'])\n        # Mid point to change exponential to linear schedule\n        self.mid_point = float(self.params.get('mid', 0.25))\n        # Exponential\n        self.beta = float(self.params.get('beta', 4.0))\n        assert self.beta >= 2.\n        assert self.mid_point >= 0. and self.mid_point <= 1.\n        self.batch = 0\n\n\n    # Set how many batches in an epoch\n    def set_epoch_length(self, epoch_length):\n        if self.epoch_length != self.epoch_length:\n            self.epoch_length = epoch_length\n        else:\n            if self.epoch_length != epoch_length:\n                raise ValueError(\"epoch_length must stay the same for SmoothedScheduler\")\n\n    def step_epoch(self, verbose = True):\n        super(SmoothedScheduler, self).step_epoch()\n        # FIXME\n        if verbose == False:\n            for i in range(self.epoch_length):\n                self.step_batch()\n\n    # Smooth schedule that slowly morphs into a linear schedule.\n    # Code is based on DeepMind's IBP implementation:\n    # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L84\n    def step_batch(self, verbose=False):\n        if self.is_training:\n            self.batch += 1\n            init_value = 0.0\n            final_value = self.max_eps\n            beta = self.beta\n            step = self.batch - 1\n            # Batch number for schedule start\n            init_step = (self.schedule_start - 1) * self.epoch_length\n            # Batch number for schedule end\n            final_step = 
(self.schedule_start + self.schedule_length - 1) * self.epoch_length\n            # Batch number for switching from exponential to linear schedule\n            mid_step = int((final_step - init_step) * self.mid_point) + init_step\n            t = (mid_step - init_step) ** (beta - 1.)\n            # find coefficient for exponential growth, such that at mid point the gradient is the same as a linear ramp to final value\n            alpha = (final_value - init_value) / ((final_step - mid_step) * beta * t + (mid_step - init_step) * t)\n            # value at switching point\n            mid_value = init_value + alpha * (mid_step - init_step) ** beta\n            # return init_value when we have not started\n            is_ramp = float(step > init_step)\n            # linear schedule after mid step\n            is_linear = float(step >= mid_step)\n            exp_value = init_value + alpha * float(step - init_step) ** beta\n            linear_value = min(mid_value + (final_value - mid_value) * (step - mid_step) / (final_step - mid_step), final_value)\n            self.eps = is_ramp * ((1.0 - is_linear) * exp_value + is_linear * linear_value) + (1.0 - is_ramp) * init_value\n\nclass AdaptiveScheduler(BaseScheduler):\n    def __init__(self, max_eps, opt_str):\n        super(AdaptiveScheduler, self).__init__(max_eps, opt_str)\n        self.schedule_start = int(self.params['start'])\n        self.min_eps_step = float(self.params.get('min_step', 1e-9))\n        self.max_eps_step = float(self.params.get('max_step', 1e-4))\n        self.eps_increase_thresh = float(self.params.get('increase_thresh', 1.0))\n        self.eps_increase_factor = float(self.params.get('increase_factor', 1.5))\n        self.eps_decrease_thresh = float(self.params.get('decrease_thresh', 1.5))\n        self.eps_decrease_factor = float(self.params.get('decrease_factor', 2.0))\n        self.small_loss_thresh = float(self.params.get('small_loss_thresh', 0.05))\n        self.epoch = 0\n        self.eps_step 
= self.min_eps_step\n\n    def step_batch(self):\n        if self.eps < self.max_eps and self.epoch >= self.schedule_start and self.is_training:\n            if self.loss != self.loss or self.prev_loss != self.prev_loss:\n                # First 2 steps. Use min eps step\n                self.eps += self.min_eps_step\n            else:\n                # loss decreasing or loss very small. Increase eps step\n                if self.loss < self.eps_increase_thresh * self.prev_loss or self.loss < self.small_loss_thresh:\n                    self.eps_step = min(self.eps_step * self.eps_increase_factor, self.max_eps_step)\n                # loss increasing. Decrease eps step\n                elif self.loss > self.eps_decrease_thresh * self.prev_loss:\n                    self.eps_step = max(self.eps_step / self.eps_decrease_factor, self.min_eps_step)\n                # print(\"loss {:7.5f} prev_loss {:7.5f} eps_step {:7.5g}\".format(self.loss, self.prev_loss, self.eps_step))\n                # increase eps according to loss\n                self.eps = min(self.eps + self.eps_step, self.max_eps)\n            # print(\"eps step size {:7.5f}, eps {:7.5f}\".format(self.eps_step, self.eps))\n\n\nif __name__ == \"__main__\":\n    s = SmoothedScheduler(0.1, \"start=2,length=10,mid=0.3\")\n    epochs = 20\n    batches = 10\n    loss = 1.0\n    eps = []\n    s.set_epoch_length(batches)\n    for epoch in range(1,epochs+1):\n        s.step_epoch()\n        for batch in range(1,batches+1):\n            s.step_batch()\n            loss = loss * (0.975 + random.random() / 20)\n            eps.append(s.get_eps())\n            print('epoch {:5d} batch {:5d} eps {:7.5f} loss {:7.5f}'.format(epoch, batch, s.get_eps(), loss))\n            # update_loss is only necessary for adaptive eps scheduler\n            s.update_loss(loss)\n    # plot epsilon values\n    import matplotlib\n    matplotlib.use('Agg')\n    from matplotlib import pyplot as plt\n    plt.figure(figsize=(10,8))\n    
plt.plot(eps)\n    plt.xticks(range(0, epochs*batches+batches, batches))\n    plt.grid()\n    plt.tight_layout()\n    plt.savefig('epsilon.pdf')\n"
  },
  {
    "path": "auto_LiRPA/forward_bound.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport torch\nimport warnings\nfrom .bound_ops import *\nfrom .utils import *\nfrom .linear_bound import LinearBound\nfrom .perturbations import PerturbationLpNorm\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\nimport sys\nsys.setrecursionlimit(1000000)\n\n\ndef forward_general(self: 'BoundedModule', C=None, node:'Bound'=None, concretize=False,\n                    offset=0, from_node=False):\n\n    if self.dynamic:\n        return self.forward_general_dynamic(C=C, node=node, concretize=concretize, offset=offset)\n    if C is None:\n        if (hasattr(node, 'linear') and\n            node.linear.lower is not None and node.linear.upper is not None):\n            return node.linear.lower, node.linear.upper\n        if not node.from_input:\n            node.linear = LinearBound(None, node.value, None, node.value, node.value, node.value)\n            return node.value, node.value\n        if not node.perturbed:\n            node.lower = node.upper = self.get_forward_value(node)\n        if node.is_lower_bound_current():\n            node.linear = LinearBound(None, node.lower, None, node.upper, node.lower, node.upper)\n            return node.lower, node.upper\n\n    for l_pre in node.inputs:\n        if not hasattr(l_pre, 'linear'):\n            self.forward_general(node=l_pre, offset=offset, from_node=from_node)\n    inp = [l_pre.linear for l_pre in node.inputs]\n    node._start = '_forward'\n    if (C is not None and type(node) is BoundLinear and\n            not node.is_input_perturbed(1) and not node.is_input_perturbed(2)):\n        linear = node.bound_forward(self.dim_in, *inp, C=C)\n        C_merged = True\n    else:\n        linear = node.linear = node.bound_forward(self.dim_in, *inp)\n        C_merged = False\n\n    lw, uw = linear.lw, linear.uw\n    
lower, upper = linear.lb, linear.ub\n\n    # Combine linear bounds with C matrix\n    if C is not None and not C_merged:\n        # FIXME use bound_forward of BoundLinear\n        C_pos, C_neg = C.clamp(min=0), C.clamp(max=0)\n        # Flatten lw, uw for matrix multiplication\n        lw = lw.reshape(self.batch_size, self.dim_in, -1)\n        uw = uw.reshape(self.batch_size, self.dim_in, -1)\n        _lw = torch.matmul(lw, C_pos.transpose(-1, -2)) + torch.matmul(uw, C_neg.transpose(-1, -2))\n        _uw = torch.matmul(uw, C_pos.transpose(-1, -2)) + torch.matmul(lw, C_neg.transpose(-1, -2))\n        lw, uw = _lw, _uw\n        # Flatten lower, upper for matrix multiplication\n        lower = lower.reshape(self.batch_size, -1)\n        upper = upper.reshape(self.batch_size, -1)\n        _lower = ( torch.matmul(lower.unsqueeze(1), C_pos.transpose(-1, -2)) \n                    + torch.matmul(upper.unsqueeze(1), C_neg.transpose(-1, -2)) )\n        _upper = ( torch.matmul(upper.unsqueeze(1), C_pos.transpose(-1, -2))\n                    + torch.matmul(lower.unsqueeze(1), C_neg.transpose(-1, -2)) )\n        lower, upper = _lower.squeeze(1), _upper.squeeze(1)\n\n    logger.debug(f'Forward bounds to {node}')\n\n    if concretize:\n        if lw is not None or uw is not None:\n            lower, upper = self.concretize_bounds(\n                node=node,\n                lower=lower,\n                upper=upper,\n                concretize_mode='forward',\n                lw=lw,\n                uw=uw,\n                clip_neuron_selection_value=self.clip_neuron_selection_value,\n                clip_neuron_selection_type=self.clip_neuron_selection_type\n            )\n\n        linear.lower, linear.upper = lower, upper\n\n        if C is None:\n            node.linear = linear\n            node.lower, node.upper = lower, upper\n\n        if self.bound_opts['forward_refinement']:\n            need_refinement = False\n            for out in node.output_name:\n              
  out_node = self[out]\n                for i in getattr(out_node, 'requires_input_bounds', []):\n                    if out_node.inputs[i] == node:\n                        need_refinement = True\n                        break\n            if need_refinement:\n                self.forward_refinement(node)\n        return lower, upper\n\n\ndef forward_general_dynamic(self: 'BoundedModule', C=None, node:'Bound'=None,\n                            concretize=False, offset=0):\n    max_dim = self.bound_opts['forward_max_dim']\n\n    if C is None:\n        if hasattr(node, 'linear'):\n            assert not concretize\n\n            linear = node.linear\n            if offset == 0:\n                if linear.lw is None:\n                    return linear\n                elif linear.lw.shape[1] <= max_dim:\n                    return linear\n            if linear.lw is not None:\n                lw = linear.lw[:, offset:offset+max_dim]\n                x_L = linear.x_L[:, offset:offset+max_dim]\n                x_U = linear.x_U[:, offset:offset+max_dim]\n                tot_dim = linear.tot_dim\n                if offset == 0:\n                    lb = linear.lb\n                else:\n                    lb = torch.zeros_like(linear.lb)\n            else:\n                lw = x_L = x_U = None\n                tot_dim = 0\n                lb = linear.lb\n            return LinearBound(\n                lw, lb, lw, lb, x_L=x_L, x_U=x_U,\n                offset=offset, tot_dim=tot_dim,\n            )\n\n        # These cases have no coefficient tensor\n        if not node.from_input:\n            if concretize:\n                return node.value, node.value\n            else:\n                node.linear = LinearBound(\n                    None, node.value, None, node.value, node.value, node.value)\n                return node.linear\n        if not node.perturbed:\n            if not node.is_lower_bound_current():\n                node.lower = node.upper = 
self.get_forward_value(node)\n            if concretize:\n                return node.lower, node.upper\n            else:\n                if offset > 0:\n                    lb = torch.zeros_like(node.lower)\n                else:\n                    lb = node.lower\n                node.linear = LinearBound(None, lb, None, lb, node.lower, node.upper)\n                return node.linear\n\n    if offset == 0:\n        logger.debug(f'forward_general_dynamic: node={node}')\n\n    inp = []\n    for l_pre in node.inputs:\n        linear_inp = self.forward_general_dynamic(node=l_pre, offset=offset)\n        linear_inp.lower = l_pre.lower\n        linear_inp.upper = l_pre.upper\n        inp.append(linear_inp)\n    node._start = '_forward'\n    if (C is not None and isinstance(node, BoundLinear) and\n            not node.is_input_perturbed(1) and not node.is_input_perturbed(2)):\n        linear = node.bound_dynamic_forward(\n            *inp, C=C, max_dim=max_dim, offset=offset)\n        C_merged = True\n    else:\n        linear = node.bound_dynamic_forward(\n            *inp, max_dim=max_dim, offset=offset)\n        C_merged = False\n    if offset > 0:\n        linear.lb = linear.ub = torch.zeros_like(linear.lb)\n\n    lw, lb, tot_dim = linear.lw, linear.lb, linear.tot_dim\n    #logger.debug(f'forward_general_dynamic: node={node}, w_size={lw.shape[1]}, tot_dim={tot_dim}')\n\n    if C is not None and not C_merged:\n        # FIXME use bound_forward of BoundLinear\n        lw = torch.matmul(lw, C.transpose(-1, -2))\n        lb = torch.matmul(lb.unsqueeze(1), C.transpose(-1, -2)).squeeze(1)\n\n    if concretize:\n        lower = upper = lb\n        if lw is not None:\n            batch_size = lw.shape[0]\n            assert (lw.ndim > 1)\n            if lw.shape[1] > 0:\n                A = lw.reshape(batch_size, lw.shape[1], -1).transpose(1, 2)\n                ptb = PerturbationLpNorm(x_L=linear.x_L, x_U=linear.x_U)\n                lower = lower + 
ptb.concretize(x=None, A=A, sign=-1).view(lb.shape)\n                upper = upper + ptb.concretize(x=None, A=A, sign=1).view(lb.shape)\n            offset_next = offset + max_dim\n            more = offset_next < tot_dim\n        else:\n            more = False\n\n        if C is None and offset == 0 and not more:\n            node.linear = linear\n\n        if more:\n            if lw is not None and lw.shape[1] > 0:\n                del A\n                del ptb\n                del lw\n                del linear\n            del inp\n            # TODO make it non-recursive\n            lower_next, upper_next = self.forward_general_dynamic(\n                C, node, concretize=True, offset=offset_next)\n            lower = lower + lower_next\n            upper = upper + upper_next\n\n        if C is None:\n            node.lower, node.upper = lower, upper\n\n        return lower, upper\n    else:\n        return linear\n\n\ndef clean_memory(self: 'BoundedModule', node):\n    \"\"\" Remove linear bounds that are no longer needed. \"\"\"\n    # TODO add an option to retain these bounds\n    for inp in node.inputs:\n        if hasattr(inp, 'linear') and inp.linear is not None:\n            clean = True\n            for out in inp.output_name:\n                out_node = self[out]\n                if not (hasattr(out_node, 'linear') and out_node.linear is not None):\n                    clean = False\n            if clean:\n                if isinstance(inp.linear, tuple):\n                    for item in inp.linear:\n                        del item\n                delattr(inp, 'linear')\n\n\ndef forward_refinement(self: 'BoundedModule', node):\n    \"\"\" Refine forward bounds with backward bound propagation\n    (only refine unstable positions). 
\"\"\"\n    unstable_size_before = torch.logical_and(node.lower < 0, node.upper > 0).sum()\n    if unstable_size_before == 0:\n        return\n    unstable_idx, unstable_size = self.get_unstable_locations(\n        node.lower, node.upper, conv=isinstance(node, BoundConv))\n    logger.debug(f'Forward refinement for {node}')\n    batch_size = node.lower.shape[0]\n    ret = self.batched_backward(\n        node, C=None, unstable_idx=unstable_idx, batch_size=batch_size)\n    self.restore_sparse_bounds(\n        node, unstable_idx, unstable_size, node.lower, node.upper,\n        new_lower=ret[0], new_upper=ret[1])\n    unstable_size_after = torch.logical_and(node.lower < 0, node.upper > 0).sum()\n    logger.debug(f'  Unstable neurons: {unstable_size_before} -> {unstable_size_after}')\n    # TODO also update linear bounds?\n\n\ndef init_forward(self: 'BoundedModule', roots, dim_in):\n    if dim_in == 0:\n        raise ValueError(\"At least one node should have a specified perturbation\")\n    prev_dim_in = 0\n    # Assumption: roots[0] is the input node which implies batch_size\n    batch_size = roots[0].value.shape[0]\n    for i in range(len(roots)):\n        if hasattr(roots[i], 'perturbation') and roots[i].perturbation is not None:\n            shape = roots[i].linear.lw.shape\n            if self.dynamic:\n                if shape[1] != dim_in:\n                    raise NotImplementedError('Dynamic forward bound is not supported yet when there are multiple perturbed inputs.')\n                ptb = roots[i].perturbation\n                if (type(ptb) != PerturbationLpNorm or ptb.norm < np.inf\n                        or ptb.x_L is None or ptb.x_U is None):\n                    raise NotImplementedError(\n                        'For dynamic forward bounds, only Linf (box) perturbations are supported, and x_L and x_U must be explicitly provided.')\n                roots[i].linear.x_L = (\n                    ptb.x_L_sparse.view(batch_size, -1) if ptb.sparse\n          
          else ptb.x_L.view(batch_size, -1))\n                roots[i].linear.x_U = (\n                    ptb.x_U_sparse.view(batch_size, -1) if ptb.sparse\n                    else ptb.x_U.view(batch_size, -1))\n            else:\n                lw = torch.zeros(shape[0], dim_in, *shape[2:]).to(roots[i].linear.lw)\n                lw[:, prev_dim_in:(prev_dim_in+shape[1])] = roots[i].linear.lw\n                if roots[i].linear.lw.data_ptr() == roots[i].linear.uw.data_ptr():\n                    uw = lw\n                else:\n                    uw = torch.zeros(shape[0], dim_in, *shape[2:]).to(roots[i].linear.uw)\n                    uw[:, prev_dim_in:(prev_dim_in+shape[1])] = roots[i].linear.uw\n                roots[i].linear.lw = lw\n                roots[i].linear.uw = uw\n            if i >= self.num_global_inputs:\n                roots[i].forward_value = roots[i].forward_value.unsqueeze(0).repeat(\n                    *([batch_size] + [1] * self.forward_value.ndim))\n            prev_dim_in += shape[1]\n        else:\n            b = fv = roots[i].forward_value\n            shape = fv.shape\n            if roots[i].from_input:\n                w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device)\n                warnings.warn(f'Creating a LinearBound with zero weights with shape {w.shape}')\n            else:\n                w = None\n            roots[i].linear = LinearBound(w, b, w, b, b, b)\n            roots[i].lower = roots[i].upper = b"
  },
  {
    "path": "auto_LiRPA/interval_bound.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport torch\nfrom .bound_ops import *\nfrom .utils import logger\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef IBP_general(self: 'BoundedModule', node=None, C=None,\n                delete_bounds_after_use=False):\n\n    logger.debug('IBP for %s', node)\n\n    def _delete_unused_bounds(node_list: List[Bound]):\n        \"\"\"Delete bounds from input layers after use to save memory. 
Used when\n        sparse_intermediate_bounds_with_ibp is true.\"\"\"\n        if delete_bounds_after_use:\n            for n in node_list:\n                del n.interval\n                n.delete_lower_and_upper_bounds()\n\n    if self.bound_opts.get('loss_fusion', False):\n        res = self._IBP_loss_fusion(node, C)\n        if res is not None:\n            return res\n\n    if not node.perturbed:\n        fv = self.get_forward_value(node)\n        node.lower, node.upper = node.interval = (fv, fv)\n\n    to_be_deleted_bounds = []\n    if not hasattr(node, 'interval'):\n        for n in node.inputs:\n            if not hasattr(n, 'interval'):\n                # Node n does not have interval bounds; we must compute it.\n                self.IBP_general(\n                    n, delete_bounds_after_use=delete_bounds_after_use)\n                to_be_deleted_bounds.append(n)\n        inp = [n_pre.interval for n_pre in node.inputs]\n        if (C is not None and isinstance(node, BoundLinear)\n                and not node.is_input_perturbed(1)):\n            # merge the last BoundLinear node with the specification, available\n            # when weights of this layer are not perturbed\n            ret = node.interval_propagate(*inp, C=C)\n            _delete_unused_bounds(to_be_deleted_bounds)\n            return ret\n        else:\n            node.interval = node.interval_propagate(*inp)\n\n        node.lower, node.upper = node.interval\n        if isinstance(node.lower, torch.Size):\n            node.lower = torch.tensor(node.lower)\n        if isinstance(node.upper, torch.Size):\n            node.upper = torch.tensor(node.upper)\n\n        # Handle NaNs in lower and upper bounds\n        if torch.isnan(node.lower).any():\n            print(\n                f'[Interval Warning] NaN detected in lower bounds of node {node}. 
'\n                f'Replacing with -inf.'\n            )\n            node.lower = torch.where(\n                torch.isnan(node.lower),\n                torch.full_like(node.lower, float('-inf')),\n                node.lower\n            )\n        if torch.isnan(node.upper).any():\n            print(\n                f'[Interval Warning] NaN detected in upper bounds of node {node}. '\n                f'Replacing with +inf.'\n            )\n            node.upper = torch.where(\n                torch.isnan(node.upper),\n                torch.full_like(node.upper, float('inf')),\n                node.upper\n            )\n        node.interval = Interval.make_interval(node.lower, node.upper, other=node.interval)\n\n    if C is not None:\n        _delete_unused_bounds(to_be_deleted_bounds)\n        return BoundLinear.interval_propagate(None, node.interval, C=C)\n    else:\n        _delete_unused_bounds(to_be_deleted_bounds)\n        return node.interval\n\n\ndef _IBP_loss_fusion(self: 'BoundedModule', node, C):\n    \"\"\"Merge BoundLinear, BoundGatherElements and BoundSub.\n\n    Improvement when loss fusion is used in training.\n    \"\"\"\n    # not using loss fusion\n    if not self.bound_opts.get('loss_fusion', False):\n        return None\n\n    # Currently this function has issues in more complicated networks.\n    if self.bound_opts.get('no_ibp_loss_fusion', False):\n        return None\n\n    if (C is None and isinstance(node, BoundSub)\n            and isinstance(node.inputs[1], BoundGatherElements)\n            and isinstance(node.inputs[0], BoundLinear)):\n        node_gather = node.inputs[1]\n        node_linear = node.inputs[0]\n        node_start = node_linear.inputs[0]\n        w = node_linear.inputs[1].param\n        b = node_linear.inputs[2].param\n        labels = node_gather.inputs[1]\n        if not hasattr(node_start, 'interval'):\n            self.IBP_general(node_start)\n        for n in node_gather.inputs:\n            if not hasattr(n, 
'interval'):\n                self.IBP_general(n)\n        if torch.isclose(labels.lower, labels.upper, 1e-8).all():\n            labels = labels.lower\n            batch_size = labels.shape[0]\n            w = w.expand(batch_size, *w.shape)\n            w = w - torch.gather(\n                w, dim=1,\n                index=labels.unsqueeze(-1).repeat(1, w.shape[1], w.shape[2]))\n            b = b.expand(batch_size, *b.shape)\n            b = b - torch.gather(b, dim=1,\n                                    index=labels.repeat(1, b.shape[1]))\n            lower, upper = node_start.interval\n            lower, upper = lower.unsqueeze(1), upper.unsqueeze(1)\n            node.lower, node.upper = node_linear.interval_propagate(\n                (lower, upper), (w, w), (b.unsqueeze(1), b.unsqueeze(1)))\n            node.interval = node.lower, node.upper = (\n                node.lower.squeeze(1), node.upper.squeeze(1))\n            return node.interval\n\n    return None\n\n\ndef check_IBP_intermediate(self: 'BoundedModule', node):\n    \"\"\" Check if we use IBP bounds to compute intermediate bounds on this node.\n\n    Currently, assume all eligible operators have exactly one input.\n    \"\"\"\n    tighten_input_bounds = (\n        self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n    )\n    directly_optimize_layer_names = (\n        self.bound_opts['optimize_bound_args']['directly_optimize']\n    )\n    if isinstance(node, BoundInput) and tighten_input_bounds:\n        return False\n    if node.name in directly_optimize_layer_names:\n        return False\n\n    if self.ibp_nodes is not None and node.name in self.ibp_nodes:\n        self.IBP_general(node)\n        return True\n\n    if (isinstance(node, BoundReshape)\n            and node.inputs[0].is_lower_bound_current()\n            and hasattr(node.inputs[1], 'value')):\n        # Node for input value.\n        val_input = node.inputs[0]\n        # Node for input parameter (e.g., shape, permute)\n    
    arg_input = node.inputs[1]\n        node.lower = node.forward(val_input.lower, arg_input.value)\n        node.upper = node.forward(val_input.upper, arg_input.value)\n        node.interval = (node.lower, node.upper)\n        return True\n\n    # Use IBP if node.ibp_intermediate == True (for nodes such as ReLU)\n    nodes = []\n    while (not node.is_lower_bound_current() or not node.is_upper_bound_current()):\n        if not node.ibp_intermediate:\n            return False\n        nodes.append(node)\n        node = node.inputs[0]\n    nodes.reverse()\n    for n in nodes:\n        self.IBP_general(n)\n\n    return True\n\n\ndef check_IBP_first_linear(self: 'BoundedModule', node):\n    \"\"\"Here we avoid creating a big C matrix in the first linear layer.\n    Disable this optimization when we have beta for intermediate layer bounds.\n    Disable this optimization when we need the A matrix of the first nonlinear\n    layer, forcibly use CROWN to record A matrix.\n    \"\"\"\n    tighten_input_bounds = (\n        self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n    )\n    directly_optimize_layer_names = (\n        self.bound_opts['optimize_bound_args']['directly_optimize']\n    )\n    if isinstance(node, BoundInput) and tighten_input_bounds:\n        return False\n    if node.name in directly_optimize_layer_names:\n        return False\n\n    # This is the list of all intermediate layers where we need to refine.\n    if self.intermediate_constr is not None:\n        intermediate_beta_enabled_layers = [\n            k for v in self.intermediate_constr.values() for k in v]\n    else:\n        intermediate_beta_enabled_layers = []\n\n    if (node.name not in self.needed_A_dict.keys()\n            and (type(node) == BoundLinear\n                or type(node) == BoundConv\n                and node.name not in intermediate_beta_enabled_layers)):\n        if type(node.inputs[0]) == BoundInput:\n            node.lower, node.upper = self.IBP_general(node)\n  
          return True\n\n    return False\n\n\ndef compare_with_IBP(self, node, lower, upper, C=None):\n    \"\"\"Re-compute the bounds by IBP given the existing intermediate bounds.\n    Update the bounds if IBP gives tighter bounds.\"\"\"\n\n    lower_ibp, upper_ibp = self.IBP_general(node, C=C, delete_bounds_after_use=True)\n    if lower is not None:\n        lower = torch.max(lower, lower_ibp)\n    if upper is not None:\n        upper = torch.min(upper, upper_ibp)\n    return lower, upper\n"
  },
  {
    "path": "auto_LiRPA/jacobian.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Handle Jacobian bounds.\"\"\"\n\nimport torch\nfrom auto_LiRPA.bound_ops import JacobianOP, GradNorm  # pylint: disable=unused-import\nfrom auto_LiRPA.bound_ops import (\n    BoundInput, BoundAdd, BoundRelu, BoundJacobianInit,\n    BoundJacobianOP)\nfrom auto_LiRPA.utils import logger, prod\nfrom collections import deque\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef _expand_jacobian(self):\n    self.jacobian_start_nodes = []\n    for node in list(self.nodes()):\n        if isinstance(node, BoundJacobianOP):\n            self.jacobian_start_nodes.append(node.inputs[0])\n            expand_jacobian_node(self, node)\n    if self.jacobian_start_nodes:\n        # Disable unstable options\n        self.bound_opts.update({\n            'sparse_intermediate_bounds': False,\n            'sparse_conv_intermediate_bounds': False,\n            'sparse_intermediate_bounds_with_ibp': False,\n            'sparse_features_alpha': False,\n            'sparse_spec_alpha': False,\n        })\n        # Optimize new nodes if possible\n        self._optimize_graph()\n        for node in self.nodes():\n            if isinstance(node, BoundRelu):\n                node.use_sparse_spec_alpha = node.use_sparse_features_alpha = False\n        # If Jacobian nodes are added, we need to redo the forward pass to update the\n        # properties of newly added nodes (e.g., output shape, forward value, etc.)\n        self.forward(*self.global_input)\n\n\ndef expand_jacobian_node(self, jacobian_node):\n    logger.info(f'Expanding Jacobian node {jacobian_node}')\n\n    output_node = jacobian_node.inputs[0]\n    input_node = jacobian_node.inputs[1]\n    batch_size = output_node.output_shape[0]\n    output_dim = prod(output_node.output_shape[1:])\n\n    # Gradient values in `grad` 
may not be accurate. We do not consider gradient\n    # accumulation from multiple succeeding nodes. We only want the shapes but\n    # not the accurate values.\n    grad = {}\n    # Dummy values in grad_start\n    grad_start = torch.ones(batch_size, output_dim,\n                            *output_node.output_shape[1:], device=self.device)\n    grad[output_node.name] = grad_start\n    input_node_found = False\n\n    # First BFS pass: traverse the graph, count degrees, and build gradient\n    # layers.\n    # Degrees of nodes.\n    degree = {}\n    # Original layer for gradient computation.\n    node_grad_ori = {}\n\n    degree[output_node.name] = 0\n    queue = deque([output_node])\n    while len(queue) > 0:\n        node = queue.popleft()\n\n        if node == input_node:\n            input_node_found = True\n            continue\n        elif node.no_jacobian or not node.from_input:\n            continue\n        else:\n            node_grad_ori[node.name] = node.build_gradient_node(grad[node.name])\n            node_grad_ori[node.name] += [None] * (\n                len(node.inputs) - len(node_grad_ori[node.name]))\n\n        logger.debug(f'Building gradient node for {node}')\n        if not isinstance(node, BoundInput):\n            for i in range(len(node.inputs)):\n                if node_grad_ori[node.name][i] is None:\n                    continue\n                grad[node.inputs[i].name] = node_grad_ori[\n                    node.name][i][0](*node_grad_ori[node.name][i][1])\n                if not node.inputs[i].name in degree:\n                    degree[node.inputs[i].name] = 0\n                    queue.append(node.inputs[i])\n                degree[node.inputs[i].name] += 1\n\n    if not input_node_found:\n        raise RuntimeError('Input node not found')\n\n    # Second BFS pass: build the backward computational graph\n    grad_node = {}\n    initial_name = f'/jacobian{output_node.name}{output_node.name}'\n    grad_node[output_node.name] = 
BoundJacobianInit(inputs=[output_node])\n    grad_node[output_node.name].name = initial_name\n    self.add_nodes([grad_node[output_node.name]])\n    queue = deque([output_node])\n    while len(queue) > 0:\n        node = queue.popleft()\n\n        if node == input_node:\n            self.replace_node(jacobian_node, grad_node[node.name])\n            continue\n        if node.no_jacobian or not node.from_input:\n            continue\n\n        logger.debug(f'Converting gradient node for {node}')\n        for k in range(len(node.inputs)):\n            if node_grad_ori[node.name][k] is None:\n                continue\n            nodes_op, nodes_in, nodes_out, _ = self._convert_nodes(\n                node_grad_ori[node.name][k][0],\n                tuple(item.detach()\n                      for item in node_grad_ori[node.name][k][1]))\n            rename_dict = {}\n            assert isinstance(nodes_in[0], BoundInput)\n            rename_dict[nodes_in[0].name] = grad_node[node.name].name\n            for i in range(1, len(nodes_in)):\n                # Assume it's a parameter here\n                new_name = f'/jacobian{output_node.name}{node.name}/{k}/params{nodes_in[i].name}'\n                rename_dict[nodes_in[i].name] = new_name\n            for i in range(len(nodes_op)):\n                # intermediate nodes\n                if not nodes_op[i].name in rename_dict:\n                    new_name = f'/jacobian{output_node.name}{node.name}/{k}/tmp{nodes_op[i].name}'\n                    rename_dict[nodes_op[i].name] = new_name\n            assert len(nodes_out) == 1\n            nodes_out = nodes_out[0]\n            rename_dict[nodes_out.name] = f'/jacobian{output_node.name}{node.name}/{k}/output'\n\n            self.rename_nodes(nodes_op, nodes_in, rename_dict)\n            input_nodes_replace = (\n                [self._modules[nodes_in[0].name]] + node_grad_ori[node.name][k][2])\n            for i in range(len(input_nodes_replace)):\n                for n in 
nodes_op:\n                    for j in range(len(n.inputs)):\n                        if n.inputs[j].name == nodes_in[i].name:\n                            n.inputs[j] = input_nodes_replace[i]\n            self.add_nodes(nodes_op + nodes_in[len(input_nodes_replace):])\n\n            if node.inputs[k].name in grad_node:\n                node_cur = grad_node[node.inputs[k].name]\n                node_add = BoundAdd(\n                    attr=None, inputs=[node_cur, nodes_out],\n                    output_index=0, options={})\n                node_add.name = f'{nodes_out.name}/add'\n                grad_node[node.inputs[k].name] = node_add\n                self.add_nodes([node_add])\n            else:\n                grad_node[node.inputs[k].name] = nodes_out\n            degree[node.inputs[k].name] -= 1\n            if degree[node.inputs[k].name] == 0:\n                queue.append(node.inputs[k])\n\n\ndef compute_jacobian_bounds(self: 'BoundedModule', x, optimize=True,\n                            optimize_output_node=None,\n                            bound_lower=True, bound_upper=True):\n    \"\"\"Compute Jacobian bounds on the pre-augmented graph (new API).\"\"\"\n\n    if isinstance(x, torch.Tensor):\n        x = (x,)\n\n    if optimize:\n        if optimize_output_node is None:\n            if len(self.jacobian_start_nodes) == 1:\n                optimize_output_node = self.jacobian_start_nodes[0]\n            else:\n                raise NotImplementedError(\n                    'Multiple Jacobian nodes found. '\n                    'An output node for optimizable bounds (optimize_output_node) '\n                    'must be specified explicitly.')\n        self.compute_bounds(\n            method='CROWN-Optimized',\n            C=None, x=x, bound_upper=False,\n            final_node_name=optimize_output_node.name)\n        intermediate_bounds = {}\n        for node in self._modules.values():\n            if node.is_lower_bound_current():\n                
intermediate_bounds[node.name] = (node.lower, node.upper)\n    else:\n        intermediate_bounds = None\n    lb, ub = self.compute_bounds(\n        method='CROWN', x=x,\n        bound_lower=bound_lower, bound_upper=bound_upper,\n        interm_bounds=intermediate_bounds)\n    return lb, ub\n"
  },
  {
    "path": "auto_LiRPA/linear_bound.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nclass LinearBound:\n    def __init__(\n            self, lw=None, lb=None, uw=None, ub=None, lower=None, upper=None,\n            from_input=None, x_L=None, x_U=None, offset=0, tot_dim=None):\n        self.lw = lw\n        self.lb = lb\n        self.uw = uw\n        self.ub = ub\n        self.lower = lower\n        self.upper = upper\n        self.from_input = from_input\n        self.x_L = x_L\n        self.x_U = x_U\n        # Offset for input variables. 
Used for batched forward bound\n        # propagation.\n        self.offset = offset\n        if tot_dim is not None:\n            self.tot_dim = tot_dim\n        elif lw is not None:\n            self.tot_dim = lw.shape[1]\n        else:\n            self.tot_dim = 0\n\n    def is_single_bound(self):\n        \"\"\"Check whether the linear lower bound and the linear upper bound are\n        the same.\"\"\"\n        if (self.lw is not None and self.uw is not None\n                and self.lb is not None and self.ub is not None):\n            return (self.lw.data_ptr() == self.uw.data_ptr()\n                and self.lb.data_ptr() == self.ub.data_ptr()\n                and self.x_L is not None and self.x_U is not None)\n        else:\n            return True\n"
  },
  {
    "path": "auto_LiRPA/operators/__init__.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom .base import *\nfrom .linear import *\nfrom .convolution import *\nfrom .pooling import *\nfrom .activation_base import *\nfrom .activations import *\nfrom .s_shaped import *\nfrom .relu import *\nfrom .bivariate import *\nfrom .add_sub import *\nfrom .normalization import *\nfrom .shape import *\nfrom .reduce import *\nfrom .rnn import *\nfrom .softmax import *\nfrom .constant import *\nfrom .leaf import *\nfrom .logical import *\nfrom .dropout import *\nfrom .dtype import *\nfrom .trigonometric import *\nfrom .cut_ops import *\nfrom .solver_utils import grb\nfrom .resize import *\nfrom .jacobian import *\nfrom .indexing import *\nfrom .slice_concat import *\nfrom .reshape import *\nfrom .minmax import *\nfrom .convex_concave import *\nfrom .gelu import *\nfrom .tile import *\n"
  },
  {
    "path": "auto_LiRPA/operators/activation_base.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Activation operators or other unary nonlinear operators\"\"\"\nimport torch\nfrom torch import Tensor\nfrom collections import OrderedDict\nfrom .base import *\nfrom .clampmult import multiply_by_A_signs\n\ntorch._C._jit_set_profiling_executor(False)\ntorch._C._jit_set_profiling_mode(False)\n\n\nclass BoundActivation(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.requires_input_bounds = [0]\n        self.use_default_ibp = True\n        self.splittable = False\n        # \"core\" region of input where precomputation can be done\n        self.range_l = -10\n        self.range_u = 10\n\n    def _init_masks(self, x):\n        self.mask_pos = x.lower >= 0\n        self.mask_neg = x.upper <= 0\n        self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg))\n\n    def init_linear_relaxation(self, x):\n        self._init_masks(x)\n        self.lw = torch.zeros_like(x.lower)\n        self.lb = self.lw.clone()\n        self.uw = self.lw.clone()\n        self.ub = self.lw.clone()\n\n    def add_linear_relaxation(self, mask, type, k, x0, y0=None):\n        if y0 is None:\n            y0 = self.forward(x0)\n\n        if type == 'lower':\n            w_out, b_out = self.lw, self.lb\n        else:\n            w_out, b_out = self.uw, self.ub\n\n        if mask is None:\n            if isinstance(k, Tensor) and k.ndim > 0:\n                w_out[:] = k\n            else:\n                w_out.fill_(k)\n        else:\n            w_out[..., mask] = (k[..., mask].to(w_out) if isinstance(k, Tensor)\n                                else k)\n\n        if (not isinstance(x0, Tensor) and x0 == 0\n                and not isinstance(y0, Tensor) and y0 == 0):\n            pass\n        
else:\n            b = -x0 * k + y0\n            if mask is None:\n                if b.ndim > 0:\n                    b_out[:] = b\n                else:\n                    b_out.fill_(b)\n            else:\n                b_out[..., mask] = b[..., mask]\n\n    def bound_relax(self, x, init=False):\n        return not_implemented_op(self, 'bound_relax')\n\n    def bound_backward(self, last_lA, last_uA, x, reduce_bias=True, **kwargs):\n        self.bound_relax(x, init=True)\n\n        def _bound_oneside(last_A, sign=-1):\n            if last_A is None:\n                return None, 0\n            if sign == -1:\n                w_pos, b_pos, w_neg, b_neg = (\n                    self.lw.unsqueeze(0), self.lb.unsqueeze(0),\n                    self.uw.unsqueeze(0), self.ub.unsqueeze(0))\n            else:\n                w_pos, b_pos, w_neg, b_neg = (\n                    self.uw.unsqueeze(0), self.ub.unsqueeze(0),\n                    self.lw.unsqueeze(0), self.lb.unsqueeze(0))\n            w_pos = maybe_unfold_patches(w_pos, last_A)\n            w_neg = maybe_unfold_patches(w_neg, last_A)\n            b_pos = maybe_unfold_patches(b_pos, last_A)\n            b_neg = maybe_unfold_patches(b_neg, last_A)\n            if self.batch_dim == 0:\n                _A, _bias = multiply_by_A_signs(\n                    last_A, w_pos, w_neg, b_pos, b_neg, reduce_bias=reduce_bias)\n            elif self.batch_dim == -1:\n                # FIXME: why this is different from above?\n                assert reduce_bias\n                mask = torch.gt(last_A, 0.).to(torch.float)\n                _A = last_A * (mask * w_pos.unsqueeze(1) +\n                               (1 - mask) * w_neg.unsqueeze(1))\n                _bias = last_A * (mask * b_pos.unsqueeze(1) +\n                                  (1 - mask) * b_neg.unsqueeze(1))\n                if _bias.ndim > 2:\n                    _bias = torch.sum(_bias, dim=list(range(2, _bias.ndim)))\n            else:\n                
raise NotImplementedError\n\n            return _A, _bias\n\n        lA, lbias = _bound_oneside(last_lA, sign=-1)\n        uA, ubias = _bound_oneside(last_uA, sign=+1)\n\n        return [(lA, uA)], lbias, ubias\n\n    @staticmethod\n    @torch.jit.script\n    def bound_forward_w(\n            relax_lw: Tensor, relax_uw: Tensor, x_lw: Tensor, x_uw: Tensor, dim: int):\n        lw = (relax_lw.unsqueeze(dim).clamp(min=0) * x_lw +\n              relax_lw.unsqueeze(dim).clamp(max=0) * x_uw)\n        uw = (relax_uw.unsqueeze(dim).clamp(max=0) * x_lw +\n              relax_uw.unsqueeze(dim).clamp(min=0) * x_uw)\n        return lw, uw\n\n    @staticmethod\n    @torch.jit.script\n    def bound_forward_b(\n            relax_lw: Tensor, relax_uw: Tensor, relax_lb: Tensor,\n            relax_ub: Tensor, x_lb: Tensor, x_ub: Tensor):\n        lb = relax_lw.clamp(min=0) * x_lb + relax_lw.clamp(max=0) * x_ub + relax_lb\n        ub = relax_uw.clamp(max=0) * x_lb + relax_uw.clamp(min=0) * x_ub + relax_ub\n        return lb, ub\n\n    def bound_forward(self, dim_in, x):\n        self.bound_relax(x, init=True)\n\n        assert (x.lw is None) == (x.uw is None)\n\n        dim = 1 if self.lw.ndim > 0 else 0\n\n        if x.lw is not None:\n            lw, uw = BoundActivation.bound_forward_w(\n                self.lw, self.uw, x.lw, x.uw, dim)\n        else:\n            lw = uw = None\n        lb, ub = BoundActivation.bound_forward_b(\n            self.lw, self.uw, self.lb, self.ub, x.lb, x.ub)\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def interval_propagate(self, *v):\n        h_L, h_U = v[0][0], v[0][1]\n        return self.forward(h_L), self.forward(h_U)\n\n    def get_split_mask(self, lower, upper, input_index):\n        \"\"\"Return a mask to indicate if each neuron potentially needs a split.\n\n        0: Stable (linear) neuron; 1: unstable (nonlinear) neuron.\n        \"\"\"\n        return torch.ones_like(lower, dtype=torch.bool)\n\n    # Return heuristic to select 
which neuron should use constraints_solving concretization\n    def compute_bound_improvement_heuristics(self, lower, upper):\n        \"\"\"Return a heuristic score for each lower-upper bound pair.\n        It indicates the possible bound improvement for each neuron.\n        We will then decide whether a neuron's bound needs to be further\n        tightened based on this heuristic.\n        \"\"\"\n        return (-lower * upper).clamp(min=0)\n\n    def get_unstable_mask(self, lower, upper):\n        \"\"\"Return a mask to indicate if each neuron is unstable.\n            Here we mark all the neurons as unstable by default.\n\n        0: Stable (linear) neuron; 1: unstable (nonlinear) neuron.\n        \"\"\"\n        return torch.ones_like(lower, dtype=torch.bool)\n\nclass BoundOptimizableActivation(BoundActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        if 'optimize_bound_args' not in self.options:\n            self.options['optimize_bound_args'] = {}\n        self.optimizable = True\n        # Stages:\n        #   * `init`: initializing parameters\n        #   * `opt`: optimizing parameters\n        #   * `reuse`: not optimizing parameters but reusing saved values\n        # If `None`, it means activation optimization is currently not used.\n        self.opt_stage = None\n        self.alpha = OrderedDict()\n        # Save patch sizes during bound_backward() for each output_node.\n        self.patch_size = {}\n        # A torch.bool mask of shape Tensor([batch_size]) that conditions the\n        # sample of alpha and beta to update\n        # If set to None, update all samples\n        # If not None, select those corresponding to 1 to update\n\n    def opt_init(self):\n        \"\"\"Enter the stage for initializing bound optimization. 
Optimized bounds\n        are not used in this stage.\"\"\"\n        self.opt_stage = 'init'\n\n    def opt_start(self):\n        \"\"\"Start optimizing bounds.\"\"\"\n        self.opt_stage = 'opt'\n\n    def opt_reuse(self):\n        \"\"\" Reuse optimizing bounds \"\"\"\n        self.opt_stage = 'reuse'\n\n    def opt_no_reuse(self):\n        \"\"\" Finish reusing optimized bounds \"\"\"\n        if self.opt_stage == 'reuse':\n            self.opt_stage = None\n\n    def opt_end(self):\n        \"\"\" End optimizing bounds \"\"\"\n        self.opt_stage = None\n\n    def clip_alpha(self):\n        pass\n\n    def init_opt_parameters(self, start_nodes):\n        \"\"\" start_nodes: a list of starting nodes [(node, size)] during\n        CROWN backward bound propagation\"\"\"\n        self.alpha = OrderedDict()\n        for start_node in start_nodes:\n            ns, size_s = start_node[:2]\n            # TODO do not give torch.Size\n            if isinstance(size_s, (torch.Size, list, tuple)):\n                size_s = prod(size_s)\n            self.alpha[ns] = self._init_opt_parameters_impl(size_s, name_start=ns)\n\n    def _init_opt_parameters_impl(self, size_spec, name_start=None):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        raise NotImplementedError\n\n    def init_linear_relaxation(self, x, dim_opt=None):\n        self._init_masks(x)\n        # The first dimension of size 2 is used for lA and uA respectively,\n        # when computing intermediate bounds.\n        if self.opt_stage in ['opt', 'reuse'] and dim_opt is not None:\n            # For optimized bounds, we have independent lw for each output\n            # dimension for bound optimization.\n            # If the output layer is a fully connected layer, len(dim_opt) = 1.\n            # If the output layer is a conv layer, len(dim_opt) = 3 but we only\n            # use the out_c dimension to create slopes/bias.\n            # Variables are shared among 
out_h, out_w dimensions so far.\n            if isinstance(dim_opt, int):\n                dim = dim_opt\n            elif isinstance(dim_opt, torch.Size):\n                dim = prod(dim_opt)\n            else:\n                dim = dim_opt[0]\n            self.lw = torch.zeros(2, dim, *x.lower.shape).to(x.lower)\n        else:\n            # Without optimized bounds, the lw, lb (slope, bias) etc. only\n            # depend on intermediate layer bounds,\n            # and are shared among different output dimensions.\n            self.lw = torch.zeros_like(x.lower)\n        self.lb = self.lw.clone()\n        self.uw = self.lw.clone()\n        self.ub = self.lw.clone()\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        return not_implemented_op(self, 'bound_relax')\n\n    def bound_backward(self, last_lA, last_uA, x, start_node=None,\n                       start_shape=None, reduce_bias=True, **kwargs):\n        self._start = start_node.name\n        if self.opt_stage not in ['opt', 'reuse']:\n            last_A = last_lA if last_lA is not None else last_uA\n            # Returned [(lA, uA)], lbias, ubias\n            As, lbias, ubias = super().bound_backward(\n                last_lA, last_uA, x, reduce_bias=reduce_bias)\n            if isinstance(last_A, Patches):\n                A_prod = As[0][1].patches if As[0][0] is None else As[0][0].patches\n                # FIXME: Unify this function with BoundReLU\n                # Save the patch size, which will be used in init_slope() to\n                # determine the number of optimizable parameters.\n                if start_node is not None:\n                    if last_A.unstable_idx is not None:\n                        # Sparse patches; we need to construct the full patch size:\n                        # (out_c, batch, out_h, out_w, c, h, w).\n                        self.patch_size[start_node.name] = [\n                            last_A.output_shape[1], A_prod.size(1),\n                    
        last_A.output_shape[2], last_A.output_shape[3],\n                            A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]\n                    else:\n                        # Regular patches.\n                        self.patch_size[start_node.name] = A_prod.size()\n            return As, lbias, ubias\n        assert self.batch_dim == 0\n\n        self.bound_relax(x, init=True, dim_opt=start_shape)\n\n        def _bound_oneside(last_A, sign=-1):\n            if last_A is None:\n                return None, 0\n            if sign == -1:\n                w_pos, b_pos, w_neg, b_neg = self.lw[0], self.lb[0], self.uw[0], self.ub[0]\n            else:\n                w_pos, b_pos, w_neg, b_neg = self.uw[1], self.ub[1], self.lw[1], self.lb[1]\n            w_pos = maybe_unfold_patches(w_pos, last_A)\n            w_neg = maybe_unfold_patches(w_neg, last_A)\n            b_pos = maybe_unfold_patches(b_pos, last_A)\n            b_neg = maybe_unfold_patches(b_neg, last_A)\n            unstable_idx = kwargs.get('unstable_idx', None)\n            if unstable_idx is not None:\n                assert isinstance(unstable_idx, Tensor) and unstable_idx.ndim == 1\n                # Shape is (spec, batch, neurons).\n                # FIXME: Sigmoid and other activation functions should also support\n                # sparse-spec alpha, so alpha will be created with a smaller shape.\n                w_pos = self.non_deter_index_select(w_pos, index=unstable_idx, dim=0)\n                w_neg = self.non_deter_index_select(w_neg, index=unstable_idx, dim=0)\n                b_pos = self.non_deter_index_select(b_pos, index=unstable_idx, dim=0)\n                b_neg = self.non_deter_index_select(b_neg, index=unstable_idx, dim=0)\n            A_prod, _bias = multiply_by_A_signs(\n                last_A, w_pos, w_neg, b_pos, b_neg, reduce_bias)\n            return A_prod, _bias\n\n        lA, lbias = _bound_oneside(last_lA, sign=-1)\n        uA, ubias = _bound_oneside(last_uA, 
sign=+1)\n\n        return [(lA, uA)], lbias, ubias\n\n    def _no_bound_parameters(self):\n        raise AttributeError('Bound parameters have not been initialized. '\n                             'Please call `compute_bounds` with `method=CROWN-optimized`'\n                             ' at least once.')\n\n    def _transfer_alpha(self, alpha, device=None, dtype=None, non_blocking=False, require_grad=False):\n        alpha = {spec_name: transfer(alpha_value, device=device, dtype=dtype, non_blocking=non_blocking).detach().requires_grad_(require_grad)\n                    for spec_name, alpha_value in alpha.items()}\n        return alpha\n\n    def dump_alpha(self, device=None, dtype=None, non_blocking=False):\n        \"\"\"\n        Dump alpha parameters to a dictionary.\n        \"\"\"\n        return {'alpha': self._transfer_alpha(self.alpha, device=device, dtype=dtype, non_blocking=non_blocking, require_grad=False)}\n\n    def restore_alpha(self, alpha, device=None, dtype=None, non_blocking=False):\n        \"\"\"\n        Restore alpha parameters from a dictionary.\n        \"\"\"\n        self.alpha = self._transfer_alpha(alpha['alpha'], device=device, dtype=dtype, non_blocking=non_blocking, require_grad=True)\n\n    def drop_unused_alpha(self, keep_nodes):\n        \"\"\"\n        Drop unused alpha parameters based on keep_nodes.\n        This function is not used in auto_LiRPA for now, but is used in alpha-beta-CROWN.\n        \"\"\"\n        for spec_name in list(self.alpha.keys()):\n            if spec_name not in keep_nodes:\n                del self.alpha[spec_name]"
  },
  {
    "path": "auto_LiRPA/operators/activations.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Activation operators or other unary nonlinear operators, not including\nthose placed in separate files.\"\"\"\nimport torch\nfrom torch.nn import Module\nfrom .base import *\nfrom .activation_base import BoundActivation, BoundOptimizableActivation\nfrom .clampmult import multiply_by_A_signs\n\ntorch._C._jit_set_profiling_executor(False)\ntorch._C._jit_set_profiling_mode(False)\n\n\nclass BoundSoftplus(BoundActivation):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.softplus = nn.Softplus()\n\n    def forward(self, x):\n        return self.softplus(x)\n\n\nclass BoundAbs(BoundActivation):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.ibp_intermediate = True\n\n    def forward(self, x):\n        return x.abs()\n\n    def bound_relax(self, x, init=False):\n        if init:\n            self.init_linear_relaxation(x)\n        x_L = x.lower.clamp(max=0)\n        x_U = torch.max(x.upper.clamp(min=0), x_L + 1e-8)\n        # upper_k: connect (x_L, |x_L|) and (x_U, |x_U|)\n        upper_k = (x_U.abs() - x_L.abs()) / (x_U - x_L)\n        # lower_k: choose between -1 and 1 depending on which is closer to zero\n        lower_k = (x_U > -x_L).to(x_L) * 2 - 1\n        self.add_linear_relaxation(mask=None, type='upper', k=upper_k, x0=x_L)\n        self.add_linear_relaxation(mask=None, type='lower', k=lower_k, x0=0, y0=0)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        x_L = x.lower.clamp(max=0)\n        x_U = torch.max(x.upper.clamp(min=0), x_L + 1e-8)\n        mask_neg = x_U <= 0\n        mask_pos = x_L >= 0\n        y_L = x_L.abs()\n        y_U = x_U.abs()\n        upper_k = (y_U - y_L) / (x_U - x_L)\n        upper_b = y_L - upper_k * x_L\n        # TODO: Here for the \"mask_both\" case 
lower_k = 0, but not sure if it's optimal.\n        # lower_b should just be 0?\n        lower_k = (mask_neg * (-1.0) + mask_pos * 1.0)\n        lower_b = (mask_neg + mask_pos) * (y_L - lower_k * x_L)\n        if last_uA is not None:\n            # Special case if we only want the upper bound with non-negative coefficients\n            if last_uA.min() >= 0:\n                uA = last_uA * upper_k\n                ubias = self.get_bias(last_uA, upper_b)\n            else:\n                last_uA_pos = last_uA.clamp(min=0)\n                last_uA_neg = last_uA.clamp(max=0)\n                uA = last_uA_pos * upper_k + last_uA_neg * lower_k\n                ubias = (self.get_bias(last_uA_pos, upper_b)\n                         + self.get_bias(last_uA_neg, lower_b))\n        else:\n            uA, ubias = None, 0\n        if last_lA is not None:\n            if last_lA.max() <= 0:\n                lA = last_lA * upper_k\n                lbias = self.get_bias(last_lA, upper_b)\n            else:\n                last_lA_pos = last_lA.clamp(min=0)\n                last_lA_neg = last_lA.clamp(max=0)\n                lA = last_lA_pos * lower_k + last_lA_neg * upper_k\n                lbias = (self.get_bias(last_lA_pos, lower_b)\n                         + self.get_bias(last_lA_neg, upper_b))\n        else:\n            lA, lbias = None, 0\n        return [(lA, uA)], lbias, ubias\n\n    def interval_propagate(self, *v):\n        h_L, h_U = v[0][0], v[0][1]\n        lower = ((h_U < 0) * h_U.abs() + (h_L > 0) * h_L.abs())\n        upper = torch.max(h_L.abs(), h_U.abs())\n        return lower, upper\n\n\nclass BoundATenHeaviside(BoundOptimizableActivation):\n    def forward(self, *x):\n        self.input_shape = x[0].shape\n        # x[0]: input; x[1]: value when x == 0\n        return torch.heaviside(x[0], x[1])\n\n    def interval_propagate(self, *v):\n        assert not self.is_input_perturbed(1)\n        return self.forward(v[0][0], v[1][0]), self.forward(v[0][1], 
v[1][0])\n\n    def _init_opt_parameters_impl(self, size_spec, name_start):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l = self.inputs[0].lower\n        return torch.zeros_like(l).unsqueeze(0).repeat(2, *[1] * l.ndim)\n\n    def clip_alpha(self):\n        for v in self.alpha.values():\n            v.data = torch.clamp(v.data, 0., 1.)\n\n    def bound_backward(self, last_lA, last_uA, *x, start_node=None,\n                       start_shape=None, **kwargs):\n        x = x[0]\n        if x is not None:\n            lb_r = x.lower\n            ub_r = x.upper\n        else:\n            lb_r = self.lower\n            ub_r = self.upper\n\n        if self.opt_stage not in ['opt', 'reuse']:\n            # zero slope:\n            upper_d = torch.zeros_like(lb_r, device=lb_r.device, dtype=lb_r.dtype)\n            lower_d = torch.zeros_like(ub_r, device=ub_r.device, dtype=ub_r.dtype)\n        else:\n            upper_d = self.alpha[start_node.name][0].clamp(0, 1) * (1. / (-lb_r).clamp(min=1e-3))\n            lower_d = self.alpha[start_node.name][1].clamp(0, 1) * (1. / (ub_r.clamp(min=1e-3)))\n\n        upper_b = torch.ones_like(lb_r, device=lb_r.device, dtype=lb_r.dtype)\n        lower_b = torch.zeros_like(lb_r, device=lb_r.device, dtype=lb_r.dtype)\n        # For stable neurons, set fixed slope and bias.\n        ub_mask = (ub_r <= 0).to(dtype=ub_r.dtype)\n        lb_mask = (lb_r >= 0).to(dtype=lb_r.dtype)\n        upper_b = upper_b - upper_b * ub_mask\n        lower_b = lower_b * (1. 
- lb_mask) + lb_mask\n        upper_d = upper_d - upper_d * ub_mask - upper_d * lb_mask\n        lower_d = lower_d - lower_d * lb_mask - lower_d * ub_mask\n        upper_d = upper_d.unsqueeze(0)\n        lower_d = lower_d.unsqueeze(0)\n        # Choose upper or lower bounds based on the sign of last_A\n        uA = lA = None\n        ubias = lbias = 0\n        if last_uA is not None:\n            neg_uA = last_uA.clamp(max=0)\n            pos_uA = last_uA.clamp(min=0)\n            uA = upper_d * pos_uA + lower_d * neg_uA\n            ubias = (pos_uA * upper_b + neg_uA * lower_b).flatten(2).sum(-1)\n        if last_lA is not None:\n            neg_lA = last_lA.clamp(max=0)\n            pos_lA = last_lA.clamp(min=0)\n            lA = upper_d * neg_lA + lower_d * pos_lA\n            lbias = (pos_lA * lower_b + neg_lA * upper_b).flatten(2).sum(-1)\n\n        return [(lA, uA), (None, None)], lbias, ubias\n\n\nclass BoundSqr(BoundOptimizableActivation):\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.splittable = True\n\n    def forward(self, x):\n        return x ** 2\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n        upper_k = x.lower + x.upper\n        # Upper bound: connect the two points (x_l, x_l^2) and (x_u, x_u^2).\n        # The upper bound should always be better than IBP.\n        self.add_linear_relaxation(\n            mask=None, type='upper', k=upper_k, x0=x.lower)\n\n        if self.opt_stage in ['opt', 'reuse']:\n            mid = self.alpha[self._start]\n        else:\n            # Lower bound is a z=0 line if x_l and x_u have different signs.\n            # Otherwise, the lower bound is a tangent line at x_l.\n            # The lower bound should always be better than IBP.\n            # If both x_l and x_u < 0, select x_u. 
If both > 0, select x_l.\n            # If x_l < 0 and x_u > 0, we use the z=0 line as the lower bound.\n            mid = F.relu(x.lower) - F.relu(-x.upper)\n\n        self.add_linear_relaxation(mask=None, type='lower', k=2 * mid, x0=mid)\n\n    def _init_opt_parameters_impl(self, size_spec, **kwargs):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        alpha = torch.empty(2, size_spec, *l.shape, device=l.device)\n        alpha.data[:2] = F.relu(l) - F.relu(-u)\n        return alpha\n\n    def interval_propagate(self, *v):\n        h_L, h_U = v[0][0], v[0][1]\n        lower = ((h_U < 0) * (h_U ** 2) + (h_L > 0) * (h_L ** 2))\n        upper = torch.max(h_L ** 2, h_U ** 2)\n        return lower, upper\n\n    def build_gradient_node(self, grad_upstream):\n        return [(SqrGrad(), (grad_upstream, self.inputs[0].forward_value), [self.inputs[0]])]\n\n\nclass SqrGrad(Module):\n    def forward(self, grad_last, preact):\n        # (x^2)' = 2*x\n        return grad_last * 2 * preact.unsqueeze(1)\n\n\nclass BoundHardTanh(BoundActivation):\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.splittable = True\n        self.activation_name = \"HardTanh\"\n        self.patch_size = {}\n        self.hardtanh_options = options.get('hardtanh', 'same-slope')\n\n    def forward(self, x, min_val, max_val):\n        return F.hardtanh(x, min_val, max_val)\n\n    def bound_backward(self, last_lA, last_uA, x, min_val, max_val, start_node=None,\n                       unstable_idx=None, reduce_bias=True, **kwargs):\n        if self.is_input_perturbed(1) or self.is_input_perturbed(\n                2):  # Checking if min_value and max_value are not perturbed\n            raise NotImplementedError(\n                    f'{self.activation_name} is unsupported with perturbed min_val and 
max_val')\n\n        self.bound_relax(x, min_val, max_val, init=True)\n\n        def _bound_oneside(last_A, sign=-1):\n            if last_A is None:\n                return None, 0\n            if sign == -1:\n                w_pos, b_pos, w_neg, b_neg = (\n                    self.lw.unsqueeze(0), self.lb.unsqueeze(0),\n                    self.uw.unsqueeze(0), self.ub.unsqueeze(0))\n            else:\n                w_pos, b_pos, w_neg, b_neg = (\n                    self.uw.unsqueeze(0), self.ub.unsqueeze(0),\n                    self.lw.unsqueeze(0), self.lb.unsqueeze(0))\n            w_pos = maybe_unfold_patches(w_pos, last_A)\n            w_neg = maybe_unfold_patches(w_neg, last_A)\n            b_pos = maybe_unfold_patches(b_pos, last_A)\n            b_neg = maybe_unfold_patches(b_neg, last_A)\n\n            # Shapes of w_pos, w_neg, b_pos, b_neg\n            # For toy.py - Final Shape - torch.Size([1, 1, 2]) torch.Size([1, 1, 2]) torch.Size([1, 1, 2]) torch.Size([1, 1, 2])\n            # For simple_verification.py - Final Shape -  torch.Size([1, 2, 16, 14, 14]) torch.Size([1, 2, 16, 14, 14]) torch.Size([1, 2, 16, 14, 14]) torch.Size([1, 2, 16, 14, 14])\n\n            # For all tensors having batch as the first dimension (batch,.....)\n            _A, _bias = multiply_by_A_signs(\n                last_A, w_pos, w_neg, b_pos, b_neg)\n\n            return _A, _bias\n\n        lA, lbias = _bound_oneside(last_lA, sign=-1)\n        uA, ubias = _bound_oneside(last_uA, sign=+1)\n\n        return [(lA, uA), (None, None), (None, None)], lbias, ubias\n\n    def bound_relax(self, x, min_val, max_val, init=False, dim_opt=None):\n        epsilon = 1e-8\n        preact_lb = x.lower.clamp(max=max_val.value)\n        preact_ub = torch.max(x.upper.clamp(min=min_val.value), preact_lb + epsilon)\n\n        min_val = min_val.value\n        max_val = max_val.value\n\n        uw = torch.zeros_like(preact_ub)\n        ub = torch.zeros_like(preact_ub)\n        lw = 
torch.zeros_like(preact_lb)\n        lb = torch.zeros_like(preact_lb)\n\n        # Case 1:\n        # When upper bound is smaller than min value,\n        # the activated value will always be min value,\n        # so the upper bound and lower bound are both\n        # min value.\n        case1 = (preact_ub <= min_val).to(preact_ub.dtype)\n\n        # Computing intermediate values only once for Case 1\n        value = case1 * min_val\n        ub += value\n        lb += value\n\n        # Case 2:\n        # When lower bound is larger than max value,\n        # the activated value will always be max value,\n        # so the upper bound and lower bound are both\n        # max value.\n        case2 = (preact_lb >= max_val).to(preact_ub.dtype)\n\n        # Computing intermediate values only once for Case 2\n        value = case2 * max_val\n        ub += value\n        lb += value\n\n        # Case 3:\n        # In this case, the activated output for x is always x\n        # so the bias is always zero and slope will also always\n        # be one.\n        case3 = ((preact_lb >= min_val) & (preact_ub <= max_val)).to(preact_ub.dtype)\n        uw += case3\n        lw += case3\n\n        # Case 4:\n        # Upper bound is larger than max val and lower bound is\n        # smaller than min val, in this case, we will use two\n        # line to bound, the upper bound will pass through points\n        # (max_val, max_val) and (lb_r, min_val) and the lower\n        # bound will pass through (min_val, min_val) and (ub_r, max_val).\n        # So, the slope d of the upper line is (max_val - min_val)/(max_val - lb_r)\n        # and the intercept of the upper line is max_val - d * max_val\n        # Similarly, the slope d of the lower line is (max_val - min_val)/(ub_r - min_val)\n        # and the intercept of the lower line is min_val - d * min_val.\n\n        # Computing intermediate values only once for Case 4\n        diff = max_val - min_val\n        val1 = max_val - preact_lb + 
epsilon\n\n        case4 = ((preact_lb < min_val) & (preact_ub > max_val)).to(preact_ub.dtype)\n        uw += case4 * diff / val1\n        lw += case4 * diff / (preact_ub - min_val + epsilon)\n        ub += case4 * (max_val - diff / val1 * max_val)\n        lb += case4 * (min_val - diff / (preact_ub - min_val + epsilon) * min_val)\n\n        # Computing intermediate values only once (Case 5 & 6)\n        denom = preact_ub - preact_lb + epsilon\n\n        # Case 5:\n        # Lower bound is smaller than the min val and the upper bound\n        # is larger than or equal to the min val and smaller or\n        # equal to max val. In this case, we use a single line that\n        # passes through (lb_r, min_val) and (ub_r, ub_r) as the upper\n        # bound. For the lower bound, we use a line with the same slope\n        # as the upper bound that passes through (min_val, min_val) as\n        # the lower bound.\n        # So, the slope d of the upper bound is (ub_r - min_val)/(ub_r - lb_r)\n        # and the intercept of the upper bound is ub_r - d * ub_r.\n        # The slope d of the lower bound is the same as the upper bound's and the\n        # intercept of the lower bound is min_val - d * min_val.\n\n        # Computing intermediate values only once for Case 5\n        val1 = preact_ub - min_val\n        case5 = ((preact_lb < min_val) & (min_val <= preact_ub) & (preact_ub <= max_val)).to(preact_ub.dtype)\n        uw += case5 * val1 / denom\n        ub += case5 * (preact_ub - val1 / denom * preact_ub)\n\n        if self.hardtanh_options == \"same-slope\":\n            lw += case5 * val1 / denom\n            lb += case5 * (min_val - val1 / denom * min_val)\n\n        elif self.hardtanh_options == \"adaptive\":\n            cond = (uw > 0.5).to(uw)\n            lw += case5 * cond\n            lb += case5 * min_val * (1 - cond)\n\n        # Case 6:\n        # Upper bound is larger than the max val and the lower bound\n        # is larger than or equal to the min val and smaller or\n      
  # equal to max val. In this case, we use a single line that\n        # pass through (ub_r, max_val) and (lb_r, lb_r) as the lower\n        # bound. And for upper bound, we use a line with the same slope\n        # as lower bound which passes through (max_val, max_val) as the\n        # upper bound.\n        # So, the slope d of the lower bound is (max_val - lb_r)/(ub_r - lb_r).\n        # And the intercept of the lower bound is lb_r - d * lb_r.\n        # The slope d of the upper bound is (max_val - lb_r)/(ub_r - lb_r),\n        # and the intercept of the upper bound is max_val - d * max_val.\n\n        # Computing intermediate values only once for Case 6\n        val1 = max_val - preact_lb\n        case6 = ((min_val <= preact_lb) & (preact_lb <= max_val) & (preact_ub > max_val)).to(preact_ub.dtype)\n        lw += case6 * val1 / denom\n        lb += case6 * (preact_lb - val1 / denom * preact_lb)\n\n        if self.hardtanh_options == \"same-slope\":\n            uw += case6 * val1 / denom\n            ub += case6 * (max_val - val1 / denom * max_val)\n\n        elif self.hardtanh_options == \"adaptive\":\n            cond = (lw > 0.5).to(lw)\n            uw += case6 * cond\n            ub += (case6 * max_val) * (1 - cond)\n\n        self.uw = uw\n        self.lw = lw\n        self.ub = ub\n        self.lb = lb\n\n    def interval_propagate(self, *v):\n        h_L, h_U = v[0][0], v[0][1]\n        min_val = v[1][0]\n        max_val = v[2][0]\n        assert v[1][0] == v[1][1] and v[2][0] == v[2][1]\n        return self.forward(h_L, min_val, max_val), self.forward(h_U, min_val, max_val)\n\n\nclass BoundFloor(BoundActivation):\n    def forward(self, x):\n        return torch.floor(x)\n\n    def bound_relax(self, x, init=False):\n        if init:\n            self.init_linear_relaxation(x)\n        self.lb += torch.floor(x.lower)\n        self.ub += torch.floor(x.upper)\n\n\nclass BoundMultiPiecewiseNonlinear(BoundOptimizableActivation):\n    def __init__(self, 
attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.splittable = True\n\n    def forward(self, x, weight, offset):\n        return (F.relu(x.unsqueeze(-1) - offset) * weight).sum(dim=-1)\n\n    def clip_alpha(self):\n        for v in self.alpha.values():\n            v.data = torch.clamp(v.data, 0., 1.)\n\n    def bound_backward(self, last_lA, last_uA, x, weight, offset,\n                       reduce_bias=True, start_node=None, **kwargs):\n        assert not self.is_input_perturbed(1)\n        assert not self.is_input_perturbed(2)\n\n        weight = (\n            self.inputs[1].forward_value\n            if hasattr(self.inputs[1], 'forward_value')\n            else self.inputs[1].forward()\n        )\n        offset = (\n            self.inputs[2].forward_value\n            if hasattr(self.inputs[2], 'forward_value')\n            else self.inputs[2].forward()\n        )\n\n        relu_x_lower = (x.lower.unsqueeze(-1) - offset).clamp(max=0)\n        relu_x_upper = (x.upper.unsqueeze(-1) - offset).clamp(min=0)\n        relu_x_upper = torch.max(relu_x_upper, relu_x_lower + 1e-8)\n        relu_upper_k = relu_x_upper / (relu_x_upper - relu_x_lower)\n        relu_upper_b = -relu_x_lower * relu_upper_k\n        if self.opt_stage not in ['opt', 'reuse']:\n            self.init_lower_k = relu_lower_k = (relu_upper_k > 0.5).to(relu_upper_k)\n            relu_lower_k_for_lA = relu_lower_k_for_uA = relu_lower_k.unsqueeze(0)\n        else:\n            relu_lower_k = self.alpha[start_node.name]\n            relu_lower_k_for_lA = relu_lower_k[0]\n            relu_lower_k_for_uA = relu_lower_k[1]\n        relu_lower_b = torch.zeros_like(relu_upper_b)\n        relu_lower_b = relu_lower_b.unsqueeze(0)\n        relu_upper_k = relu_upper_k.unsqueeze(0)\n        relu_upper_b = relu_upper_b.unsqueeze(0)\n\n        def _bound_oneside(last_A, pos_k, pos_b, neg_k, neg_b, weight, offset, reduce_bias):\n     
       if last_A is None:\n                return None, 0\n            last_A = last_A.unsqueeze(-1) * weight\n            A_pos = last_A.clamp(min=0)\n            A_neg = last_A.clamp(max=0)\n            A = A_pos * pos_k + A_neg * neg_k\n            b = -A * offset + A_pos * pos_b + A_neg * neg_b\n            A = A.sum(dim=-1)\n            if reduce_bias:\n                b = b.sum(dim=[-1, -2])\n            else:\n                b = b.sum(dim=-1)\n            return A, b\n\n        lA, lb = _bound_oneside(last_lA, relu_lower_k_for_lA, relu_lower_b,\n                                relu_upper_k, relu_upper_b,\n                                weight, offset, reduce_bias)\n        uA, ub = _bound_oneside(last_uA, relu_upper_k, relu_upper_b,\n                                relu_lower_k_for_uA, relu_lower_b,\n                                weight, offset, reduce_bias)\n\n        return [(lA, uA), (None, None), (None, None)], lb, ub\n\n    def _init_opt_parameters_impl(self, size_spec, **kwargs):\n        alpha = torch.empty(2, size_spec, *self.init_lower_k.shape,\n                            device=self.init_lower_k.device)\n        alpha.data[:2] = self.init_lower_k\n        return alpha\n\n    def get_split_mask(self, lower, upper, input_index):\n        offset = (\n            self.inputs[2].forward_value\n            if hasattr(self.inputs[2], 'forward_value')\n            else self.inputs[2].forward()\n        )\n        return ((lower.unsqueeze(-1) < offset) & (upper.unsqueeze(-1) > offset)).any(dim=-1)\n"
  },
  {
    "path": "auto_LiRPA/operators/add_sub.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom torch.nn import Module\nfrom .base import *\nfrom .constant import BoundConstant\nfrom .solver_utils import grb\n\n\nclass BoundAdd(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        options = options or {}\n        # FIXME: This is not the right way to enable patches mode.\n        # Instead we must traverse the graph and determine when patches mode needs to be used.\n\n        self.mode = options.get(\"conv_mode\", \"matrix\")\n\n    def forward(self, x, y):\n        self.x_shape = x.shape\n        self.y_shape = y.shape\n        return x + y\n\n    def bound_backward(self, last_lA, last_uA, x, y, **kwargs):\n        def _bound_oneside(last_A, w):\n            if last_A is None:\n                return None\n            return self.broadcast_backward(last_A, w)\n\n        uA_x = _bound_oneside(last_uA, x)\n        uA_y = _bound_oneside(last_uA, y)\n        lA_x = _bound_oneside(last_lA, x)\n        lA_y = _bound_oneside(last_lA, y)\n        return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0\n\n    def bound_forward(self, dim_in, x, y):\n        lb, ub = x.lb + y.lb, x.ub + y.ub\n\n        def add_w(x_w, y_w, x_b, y_b):\n            if x_w is None and y_w is None:\n                return None\n            elif x_w is not None and y_w is not None:\n                return x_w + y_w\n            elif y_w is None:\n                return x_w + torch.zeros_like(y_b)\n            else:\n                return y_w + torch.zeros_like(x_b)\n\n        lw = add_w(x.lw, y.lw, x.lb, y.lb)\n        uw = add_w(x.uw, y.uw, x.ub, y.ub)\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def interval_propagate(self, x, y):\n        assert (not isinstance(y, Tensor))\n        return x[0] + y[0], x[1] + y[1]\n\n    def 
build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        if isinstance(v[0], Tensor) and isinstance(v[1], Tensor):\n            # constants if both inputs are tensors\n            self.solver_vars = self.forward(v[0], v[1])\n            return\n        # we have both gurobi vars as inputs\n        this_layer_shape = self.output_shape\n        gvar_array1 = np.array(v[0])\n        if isinstance(v[1], Tensor):\n            var2 = v[1].cpu().numpy()\n            # flatten to create vars and constrs first\n            gvar_array1 = gvar_array1.reshape(-1)\n            new_layer_gurobi_vars = []\n            for neuron_idx, var1 in enumerate(gvar_array1):\n                var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0,\n                                   vtype=grb.GRB.CONTINUOUS,\n                                   name=f'lay{self.name}_{neuron_idx}')\n                model.addConstr(var == (var1 + var2), name=f'lay{self.name}_{neuron_idx}_eq')\n                new_layer_gurobi_vars.append(var)\n        else:\n            gvar_array2 = np.array(v[1])\n            assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:]\n\n            # flatten to create vars and constrs first\n            gvar_array1 = gvar_array1.reshape(-1)\n            gvar_array2 = gvar_array2.reshape(-1)\n            new_layer_gurobi_vars = []\n            for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)):\n                var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0,\n                                vtype=grb.GRB.CONTINUOUS,\n                                name=f'lay{self.name}_{neuron_idx}')\n                model.addConstr(var == (var1 + var2), name=f'lay{self.name}_{neuron_idx}_eq')\n                new_layer_gurobi_vars.append(var)\n        # reshape to the correct list shape of solver vars\n        self.solver_vars = 
np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist()\n        model.update()\n\n    def build_gradient_node(self, grad_upstream):\n        if not self.inputs[0].no_jacobian:\n            grad0_node = AddGrad(self.inputs[0].output_shape if self.inputs[0].batch_dim != -1 else\n                                 torch.Size((1,) + self.inputs[0].output_shape))\n            grad0 = (grad0_node, (grad_upstream,), [])\n        else:\n            grad0 = None\n        if not self.inputs[1].no_jacobian:\n            grad1_node = AddGrad(self.inputs[1].output_shape if self.inputs[1].batch_dim != -1 else\n                                 torch.Size((1,) + self.inputs[1].output_shape))\n            grad1 = (grad1_node, (grad_upstream,), [])\n        else:\n            grad1 = None\n        return [grad0, grad1]\n\n\nclass BoundSub(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        # FIXME: This is not the right way to enable patches mode. 
Instead we must traverse the graph and determine when patches mode needs to be used.\n        options = options or {}\n        self.mode = options.get(\"conv_mode\", \"matrix\")\n\n    def forward(self, x, y):\n        self.x_shape = x.shape\n        self.y_shape = y.shape\n        return x - y\n\n    def bound_backward(self, last_lA, last_uA, x, y, **kwargs):\n        def _bound_oneside(last_A, w, sign=-1):\n            if last_A is None:\n                return None\n            if isinstance(last_A, torch.Tensor):\n                return self.broadcast_backward(sign * last_A, w)\n            elif isinstance(last_A, Patches):\n                if sign == 1:\n                    # Patches shape requires no broadcast.\n                    return last_A\n                else:\n                    # Multiply by the sign.\n                    return last_A.create_similar(sign * last_A.patches)\n            else:\n                raise ValueError(f'Unknown last_A type {type(last_A)}')\n\n        uA_x = _bound_oneside(last_uA, x, sign=1)\n        uA_y = _bound_oneside(last_uA, y, sign=-1)\n        lA_x = _bound_oneside(last_lA, x, sign=1)\n        lA_y = _bound_oneside(last_lA, y, sign=-1)\n        return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0\n\n    def bound_forward(self, dim_in, x, y):\n        lb, ub = x.lb - y.ub, x.ub - y.lb\n\n        def add_w(x_w, y_w, x_b, y_b):\n            if x_w is None and y_w is None:\n                return None\n            elif x_w is not None and y_w is not None:\n                return x_w + y_w\n            elif y_w is None:\n                return x_w + torch.zeros_like(y_b)\n            else:\n                return y_w + torch.zeros_like(x_b)\n\n        # Some nodes such as BoundConstant do not have uw and lw.\n        lw = add_w(x.lw, -y.uw if y.uw is not None else None, x.lb, y.lb)\n        uw = add_w(x.uw, -y.lw if y.lw is not None else None, x.ub, y.ub)\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def interval_propagate(self, x, y):\n        return x[0] 
- y[1], x[1] - y[0]\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        if isinstance(v[0], Tensor) and isinstance(v[1], Tensor):\n            # constants if both inputs are tensors\n            self.solver_vars = self.forward(v[0], v[1])\n            return\n        # we have both gurobi vars as inputs\n        this_layer_shape = self.output_shape\n        gvar_array1 = np.array(v[0])\n        gvar_array2 = np.array(v[1])\n        assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:]\n\n        # flatten to create vars and constrs first\n        gvar_array1 = gvar_array1.reshape(-1)\n        gvar_array2 = gvar_array2.reshape(-1)\n        new_layer_gurobi_vars = []\n        for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)):\n            var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0,\n                            vtype=grb.GRB.CONTINUOUS,\n                            name=f'lay{self.name}_{neuron_idx}')\n            model.addConstr(var == (var1 - var2), name=f'lay{self.name}_{neuron_idx}_eq')\n            new_layer_gurobi_vars.append(var)\n\n        # reshape to the correct list shape of solver vars\n        self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist()\n        model.update()\n\n    def build_gradient_node(self, grad_upstream):\n        if not self.inputs[0].no_jacobian:\n            grad_node_0 = AddGrad(self.inputs[0].output_shape if self.inputs[0].batch_dim != -1 else\n                                  torch.Size((1,) + self.inputs[0].output_shape), w=1.0)\n            grad0 = (grad_node_0, (grad_upstream,), [])\n        else:\n            grad0 = None\n        if not self.inputs[1].no_jacobian:\n            grad_node_1 = AddGrad(self.inputs[1].output_shape if self.inputs[1].batch_dim != -1 else\n                                  torch.Size((1,) + self.inputs[1].output_shape), w=-1.0)\n            
grad1 = (grad_node_1, (grad_upstream,), [])\n        else:\n            grad1 = None\n        return [grad0, grad1]\n\n\nclass AddGrad(Module):\n    def __init__(self, input_shape, w=1.0):\n        super().__init__()\n        # We need the input shape to handle broadcasting.\n        self.input_shape = input_shape\n        self.w = w\n\n    def forward(self, grad_last):\n        return reduce_broadcast_dims(grad_last * self.w, self.input_shape)\n"
  },
  {
    "path": "auto_LiRPA/operators/base.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Base class and functions for implementing bound operators\"\"\"\nfrom typing import Optional, List\nimport warnings\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\nimport numpy as np\n\nfrom ..perturbations import *\nfrom ..utils import *\nfrom ..patches import *\n\ntorch._C._jit_set_profiling_executor(False)\ntorch._C._jit_set_profiling_mode(False)\n\nepsilon = 1e-12\n\n\ndef not_implemented_op(node, func):\n    message = (\n        f'Function `{func}` of `{node}` is not supported yet.'\n        ' Please help to open an issue at https://github.com/Verified-Intelligence/auto_LiRPA'\n        ' or implement this function in auto_LiRPA/bound_ops.py'\n        ' or auto_LiRPA/operators by yourself.')\n    raise NotImplementedError(message)\n\n\nclass Interval(tuple):\n    \"\"\"Interval object for interval bound propagation.\"\"\"\n\n    # Subclassing tuple object so that all previous code can be reused.\n    def __new__(self, lb=None, ub=None, ptb=None):\n        return tuple.__new__(Interval, (lb, ub))\n\n    def __init__(self, lb, ub, ptb=None):\n        if ptb is None:\n            self.ptb = None\n            # `self.ptb == None` means that this interval\n            # is not perturbed and it shall be treated as a constant and lb = ub.\n            # To avoid mistakes, in this case the caller must make sure lb and ub are the same object.\n            assert lb is ub\n        else:\n            if not isinstance(ptb, Perturbation):\n                raise ValueError(\"ptb must be a Perturbation object or None. 
Got type {}\".format(type(ptb)))\n            else:\n                self.ptb = ptb\n\n    def __str__(self):\n        return \"({}, {}) with ptb={}\".format(self[0], self[1], self.ptb)\n\n    def __repr__(self):\n        return \"Interval(lb={}, ub={}, ptb={})\".format(self[0], self[1], self.ptb)\n\n    @staticmethod\n    def make_interval(lb, ub, other=None):\n        \"\"\"Create an interval; if `other` is an Interval, keep its perturbation.\"\"\"\n        if isinstance(other, Interval):\n            return Interval(lb, ub, ptb=other.ptb)\n        else:\n            return (lb, ub)\n\n    @staticmethod\n    def get_perturbation(interval):\n        \"\"\"Given a tuple or Interval object, returns the norm and eps (and the ratio for L0 perturbations).\"\"\"\n        if isinstance(interval, Interval) and interval.ptb is not None:\n            if isinstance(interval.ptb, PerturbationLpNorm):\n                return interval.ptb.norm, interval.ptb.eps\n            elif isinstance(interval.ptb, PerturbationSynonym):\n                return torch.inf, 1.0\n            elif isinstance(interval.ptb, PerturbationL0Norm):\n                return 0, interval.ptb.eps, interval.ptb.ratio\n            elif isinstance(interval.ptb, PerturbationLinear):\n                return torch.inf, 0.0\n            else:\n                raise RuntimeError(\"get_perturbation() does not know how to handle {}\".format(type(interval.ptb)))\n        else:\n            # Tuple object. Assuming L infinity norm lower and upper bounds.\n            return torch.inf, np.nan\n\n    @staticmethod\n    def is_perturbed(interval):\n        \"\"\"Check if an Interval or tuple object has perturbation enabled.\"\"\"\n        if isinstance(interval, Interval) and interval.ptb is None:\n            return False\n        else:\n            return True\n\n\nclass Bound(nn.Module):\n    r\"\"\"\n    Base class for supporting the bound computation of an operator. 
Please see examples\n    at `auto_LiRPA/operators`.\n\n    Args:\n        attr (dict): Attributes of the operator.\n\n        inputs (list): A list of input nodes.\n\n        output_index (int): The index in the output if the operator has multiple outputs. Usually output_index=0.\n\n        options (dict): Bound options.\n\n    Be sure to run `super().__init__(attr, inputs, output_index, options)`\n    first in the `__init__` function.\n    \"\"\"\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__()\n        attr = {} if attr is None else attr\n        inputs = [] if inputs is None else inputs\n        options = {} if options is None else options\n        self.name: Optional[str] = None\n        self.output_name = []\n        self.device = attr.get('device')\n        self.attr = attr\n        self.inputs: List['Bound'] = inputs\n        self.output_index = output_index\n        self.options = options\n        # Mark if this node is used in the bound computation (from the output node).\n        self.used = False\n        self.forward_value = None\n        self.output_shape = None\n        self.from_input = False\n        self.bounded = False\n        self.IBP_rets = None\n        self.requires_input_bounds = []\n        self.from_complex_node = None\n        # If True, when building the Jacobian graph, this node should be treated\n        # as a constant and there is no need to further propagate Jacobian.\n        self.no_jacobian = False\n        # If True, when we are computing intermediate bounds for these ops,\n        # we simply use IBP to propagate bounds from its input nodes\n        # instead of CROWN. Currently only operators with a single input are\n        # supported.\n        self.ibp_intermediate = False\n        self.splittable = self.force_not_splittable = False\n        # Determine if this node has a perturbed output or not. 
The function BoundedModule._mark_perturbed_nodes() will set this property.\n        self.perturbed = False\n        self.never_perturbed = False\n        if options is not None and 'loss_fusion' in options:\n            self.loss_fusion = options['loss_fusion']\n        else:\n            self.loss_fusion = False\n        self.options = options\n        # Use `default_interval_propagate`\n        self.use_default_ibp = False\n        # If set to true, the backward bound output of this node is 0.\n        self.zero_backward_coeffs_l = False\n        self.zero_backward_coeffs_u = False\n        # If set to true, the A matrix accumulated on this node is 0.\n        self.zero_lA_mtx = False\n        self.zero_uA_mtx = False\n        self.patches_start = False\n        self.alpha_beta_update_mask = None\n        self.is_final_node = False\n        # By default, we assume this node has no batch dimension.\n        # It will be updated in BoundedModule.get_forward_value().\n        self.batch_dim = -1\n\n        # The .lower and .upper properties are written to as part of the bound propagation.\n        # Usually, in iterative refinement, each bound only depends on bounds previously\n        # computed in the same iteration. However, this changes if INVPROP is used to incorporate\n        # output constraints. Then, we also need bounds of layers *after* the currently bounded\n        # layer. 
Therefore, we have to cache the older bounds.\n        self._is_lower_bound_current = False\n        self._lower = None\n        self._is_upper_bound_current = False\n        self._upper = None\n        \n        # A list containing the output ACTIVATIONS node from this node.\n        # Please check backward_bound.py, forward_bound.py, batch_branch_and_bound.py for more info.\n        self.output_activations = None\n\n    def __repr__(self, attrs=None):\n        inputs = ', '.join([node.name for node in self.inputs])\n        ret = (f'{self.__class__.__name__}(name={self.name}, '\n                f'inputs=[{inputs}], perturbed={self.perturbed}')\n        if attrs is not None:\n            for k, v in attrs.items():\n                ret += f', {k}={v}'\n        ret += ')'\n        return ret\n\n    @property\n    def lower(self):\n        return self._lower\n\n    @lower.setter\n    def lower(self, value):\n        if not (value is None or isinstance(value, torch.Tensor)):\n            raise TypeError(f'lower must be a tensor or None, got {type(value)}')\n        if value is None:\n            self._is_lower_bound_current = False\n        else:\n            self._is_lower_bound_current = True\n        self._lower = value\n\n    @property\n    def upper(self):\n        return self._upper\n\n    @upper.setter\n    def upper(self, value):\n        if not (value is None or isinstance(value, torch.Tensor)):\n            raise TypeError(f'upper must be a tensor or None, got {type(value)}')\n        if value is None:\n            self._is_upper_bound_current = False\n        else:\n            self._is_upper_bound_current = True\n        self._upper = value\n\n    def move_lower_and_upper_bounds_to_cache(self):\n        if self._lower is not None:\n            self._lower = self._lower.detach().requires_grad_(False)\n            self._is_lower_bound_current = False\n        if self._upper is not None:\n            self._upper = self._upper.detach().requires_grad_(False)\n  
          self._is_upper_bound_current = False\n\n    def delete_lower_and_upper_bounds(self):\n        self._lower = None\n        self._upper = None\n        self._is_lower_bound_current = False\n        self._is_upper_bound_current = False\n\n    def is_lower_bound_current(self):\n        return self._is_lower_bound_current\n\n    def is_upper_bound_current(self):\n        return self._is_upper_bound_current\n\n    def are_output_constraints_activated_for_layer(\n        self: 'Bound',\n        apply_output_constraints_to: Optional[List[str]],\n    ):\n        if self.is_final_node:\n            return False\n        if apply_output_constraints_to is None:\n            return False\n        for layer_type_or_name in apply_output_constraints_to:\n            if layer_type_or_name.startswith('/'):\n                if self.name == layer_type_or_name:\n                    return True\n            else:\n                assert layer_type_or_name.startswith('Bound'), (\n                    'To apply output constraints to tighten layer bounds, pass either the layer name '\n                    '(starting with \"/\", e.g. \"/input.7\") or the layer type (starting with \"Bound\", '\n                    'e.g. 
\"BoundLinear\")'\n                )\n                if type(self).__name__ == layer_type_or_name:\n                    return True\n        return False\n\n    def init_gammas(self, num_constraints):\n        if not self.are_output_constraints_activated_for_layer(\n            self.options.get('optimize_bound_args', {}).get('apply_output_constraints_to', [])\n        ):\n            return\n        assert len(self.output_shape) > 0, self\n        neurons_in_this_layer = 1\n        for d in self.output_shape[1:]:\n            neurons_in_this_layer *= d\n        init_gamma_value = 0.0\n        # We need a different number of gammas depending on whether or not they are shared.\n        # However, to the code outside of this class, this should be transparent.\n        # We create the correct number of gammas in gammas_underlying_tensor and if necessary\n        # expand it to simulate a larger tensor. This is just a view, no additional memory is created.\n        # From the outside, only .gammas should be used. 
However, we must take care to update this view\n        # whenever gammas_underlying_tensor was changed (see clip_gammas)\n        # Note that _set_gammas in optimized_bounds.py needs to refer to the gammas_underlying_tensor,\n        # because that's the leaf tensor for which we need to compute gradients.\n        if self.options.get('optimize_bound_args', {}).get('share_gammas', False):\n            self.gammas_underlying_tensor = torch.full((2, num_constraints, 1), init_gamma_value, requires_grad=True, device=self.device)\n            self.gammas = self.gammas_underlying_tensor.expand(-1, -1, neurons_in_this_layer)\n        else:\n            self.gammas_underlying_tensor = torch.full((2, num_constraints, neurons_in_this_layer), init_gamma_value, requires_grad=True, device=self.device)\n            self.gammas = self.gammas_underlying_tensor\n\n    def clip_gammas(self):\n        if not hasattr(self, \"gammas\"):\n            return\n        self.gammas_underlying_tensor.data = torch.clamp(self.gammas_underlying_tensor.data, min=0.0)\n\n        # If gammas are shared, self.gammas != self.gammas_underlying_tensor\n        # We've changed self.gammas_underlying_tensor, those changes must be propagated to self.gammas\n        neurons_in_this_layer = 1\n        for d in self.output_shape[1:]:\n            neurons_in_this_layer *= d\n        if self.options.get('optimize_bound_args', {}).get('share_gammas', False):\n            self.gammas = self.gammas_underlying_tensor.expand(-1, -1, neurons_in_this_layer)\n\n    def is_input_perturbed(self, i=0):\n        r\"\"\"Check if the i-th input is with perturbation or not.\"\"\"\n        return i < len(self.inputs) and self.inputs[i].perturbed\n\n    def clear(self):\n        \"\"\" Clear attributes when there is a new input to the network\"\"\"\n        pass\n\n    @property\n    def input_name(self):\n        return [node.name for node in self.inputs]\n\n    def forward(self, *x):\n        r\"\"\"\n        Function for 
standard/clean forward.\n\n        Args:\n            x: A list of input values. The length of the list is equal to the number of input nodes.\n\n        Returns:\n            output (Tensor): The standard/clean output of this node.\n        \"\"\"\n        return not_implemented_op(self, 'forward')\n\n    def interval_propagate(self, *v):\n        r\"\"\"\n        Function for interval bound propagation (IBP) computation.\n\n        There is a default function `self.default_interval_propagate(*v)` in the base class,\n        which can be used if the operator is *monotonic*. To use it, set `self.use_default_ibp = True`\n        in the `__init__` function, and the implementation of this function can be skipped.\n\n        Args:\n            v: A list of the interval bound of input nodes.\n            Generally, for each element `v[i]`, `v[i][0]` is the lower interval bound,\n            and `v[i][1]` is the upper interval bound.\n\n        Returns:\n            bound: The interval bound of this node, in the same format as v[i].\n        \"\"\"\n        if self.use_default_ibp or self.never_perturbed:\n            return self.default_interval_propagate(*v)\n        else:\n            return not_implemented_op(self, 'interval_propagate')\n\n    def default_interval_propagate(self, *v):\n        \"\"\"Default IBP using the forward function.\n\n        For unary monotonic functions or functions for altering shapes only\n        but not values.\n        \"\"\"\n        if len(v) == 0:\n            return Interval.make_interval(self.forward(), self.forward())\n        else:\n            if len(v) > 1:\n                for i in range(1, len(v)):\n                    assert not self.is_input_perturbed(i)\n            return Interval.make_interval(\n                self.forward(v[0][0], *[vv[0] for vv in v[1:]]),\n                self.forward(v[0][1], *[vv[0] for vv in v[1:]]), v[0])\n\n    def bound_forward(self, dim_in, *x):\n        r\"\"\"\n        Function for forward 
mode bound propagation.\n\n        Forward mode LiRPA computes a `LinearBound`\n        instance representing the linear bound for each involved node.\n        Major attributes of `LinearBound` include\n        `lw`, `uw`, `lb`, `ub`, `lower`, and `upper`.\n\n        `lw` and `uw` are coefficients of linear bounds w.r.t. model input.\n        Their shape is `(batch_size, dim_in, *standard_shape)`,\n        where `dim_in` is the total dimension of perturbed input nodes of the model,\n        and `standard_shape` is the shape of the standard/clean output.\n        `lb` and `ub` are bias terms of linear bounds, and their shape is equal\n        to the shape of standard/clean output.\n        `lower` and `upper` are concretized lower and upper bounds that will be\n        computed later in BoundedModule.\n\n        Args:\n            dim_in (int): Total dimension of perturbed input nodes of the model.\n\n            x: A list of the linear bound of input nodes. Each element in x is a `LinearBound` instance.\n\n        Returns:\n            bound (LinearBound): The linear bound of this node.\n        \"\"\"\n        return not_implemented_op(self, 'bound_forward')\n\n    def bound_dynamic_forward(self, *x, max_dim=None, offset=0):\n        raise NotImplementedError(f'bound_dynamic_forward is not implemented for {self}.')\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        r\"\"\"\n        Function for backward mode bound propagation.\n\n        Args:\n            last_lA (Tensor): `A` matrix for lower bound computation propagated to this node. It can be `None` if lower bound is not needed.\n\n            last_uA (Tensor): `A` matrix for upper bound computation propagated to this node. It can be `None` if upper bound is not needed.\n\n            x: A list of input nodes, with x[i].lower and x[i].upper that can be used as pre-activation bounds.\n\n        Returns:\n            A: A list of A matrices for the input nodes. 
Each element is a tuple (lA, uA).\n\n            lbias (Tensor): The bias term for lower bound computation, introduced by the linear relaxation of this node.\n\n            ubias (Tensor): The bias term for upper bound computation, introduced by the linear relaxation of this node.\n        \"\"\"\n        return not_implemented_op(self, 'bound_backward')\n\n    def broadcast_backward(self, A, x):\n        \"\"\"\n        Adjust shape of A, adding or removing broadcast dimensions, based on the other operand x.\n\n        Typically, A has [spec, batch, ...].\n        The other operand x may have shape [batch, ...], or no batch dimension.\n        Here the \"...\" dimensions may be different.\n        We need to make sure the two match, by adding or removing dimensions in A.\n        \"\"\"\n        if isinstance(A, Tensor):\n            shape = x.output_shape\n            if x.batch_dim == -1:\n                # The other operand has no batch dimension (e.g., constants).\n                # Add batch dimension to it.\n                if len(shape) < len(A.shape) - 1:\n                    shape = torch.Size([1] + list(shape))\n                else:\n                    # The not-from-input operand has batch dimension.\n                    # This can happen when the user explicitly unsqueezes the batch dimension on\n                    # a constant tensor when building the computation graph.\n                    warnings.warn(f\"Constant operand of node \\033[96m{self}\\033[0m has batch dimension. \"\n                                  \"Please check your model implementation. 
\"\n                                  \"Constant operands \\033[93mSHOULD NOT\\033[0m have batch dimension.\")\n            A = reduce_broadcast_dims(A, shape)\n        else:\n            pass\n        return A\n\n    def build_gradient_node(self, grad_upstream):\n        r\"\"\"\n        Function for building the gradient node to bound the Jacobian.\n\n        Args:\n            grad_upstream: Upstream gradient in the gradient back-propagation.\n\n        Returns:\n            A list. Each item contains the following for computing the gradient\n            of each input:\n                module_grad (torch.nn.Module): Gradient node.\n\n                grad_input (list): Inputs to the gradient node. Values do not\n                matter. We only want the shapes.\n\n                grad_extra_nodes (list): Extra nodes needed for the gradient.\n        \"\"\"\n        return not_implemented_op(self, 'build_gradient_node')\n\n    def get_bias(self, A, bias):\n        if A is None:\n            return 0\n        if not Benchmarking:\n            assert not isnan(A)\n            assert not isnan(bias)\n        if torch.isinf(bias).any():\n            warnings.warn('There is an inf value in the bias of LiRPA bounds.')\n\n        if isinstance(A, Tensor):\n            if self.batch_dim != -1:\n                bias_new = torch.einsum('sb...,b...->sb', A, bias)\n            else:\n                bias_new = torch.einsum('sb...,...->sb', A, bias)\n            if isnan(bias_new):\n                # NaN can be caused by 0 * inf, if 0 appears in `A` and inf appears in `bias`.\n                # Force the whole bias to be 0, to avoid gradient issues.\n                # FIXME maybe find a more robust solution.\n                return 0\n            else:\n                # FIXME (09/17): handle the case for pieces.unstable_idx.\n                return bias_new\n        elif isinstance(A, eyeC):\n            batch_size = A.shape[1]\n            if self.batch_dim != -1:\n           
     return bias.reshape(batch_size, -1).t()\n            else:\n                return bias.reshape(-1).unsqueeze(-1).repeat(1, batch_size)\n        elif type(A) == Patches:\n            if self.batch_dim != -1:\n                # Input A patches has shape (spec, batch, out_h, out_w, in_c, H, W) or (unstable_size, batch, in_c, H, W).\n                patches = A.patches\n                # After unfolding, the size of bias is [batch_size, out_h, out_w, in_c, H, W].\n                bias = inplace_unfold(bias, kernel_size=A.patches.shape[-2:], stride=A.stride, padding=A.padding, inserted_zeros=A.inserted_zeros, output_padding=A.output_padding)\n                if A.unstable_idx is not None:\n                    # Sparse bias has shape [batch_size, unstable_size, in_c, H, W]. No need to select over the out_c dimension.\n                    bias = bias[:, A.unstable_idx[1], A.unstable_idx[2]]\n                    # bias_new has shape (unstable_size, batch).\n                    bias_new = torch.einsum('bschw,sbchw->sb', bias, patches)\n                else:\n                    # Sum over the in_c, H, W dimensions. Use torch.einsum() to save memory; equal to:\n                    # bias_new = (bias * patches).sum(-1,-2,-3).transpose(-2, -1)\n                    # bias_new has shape (spec, batch, out_h, out_w).\n                    bias_new = torch.einsum('bijchw,sbijchw->sbij', bias, patches)\n            else:\n                # Similar to BoundConstant. (BoundConstant does not have batch_dim).\n                # FIXME (09/16): bias size is different for BoundConstant. 
We should use the same size!\n                patches = A.patches\n                bias_new = torch.sum(patches, dim=(-1, -2, -3)) * bias.to(self.device)\n                # Return shape is (spec, batch, out_h, out_w) or (unstable_size, batch).\n                return bias_new\n            return bias_new\n        else:\n            raise NotImplementedError()\n\n    def make_axis_non_negative(self, axis, shape='input'):\n        \"\"\"Convert negative axis to non-negative axis.\n        Args:\n            axis (int or tuple or list): The axis to be converted.\n\n            shape (str or torch.Size): The shape of the tensor. If 'input', use self.input_shape.\n                If 'output', use self.output_shape. Otherwise, it should be a torch.Size object.\n                For example, if the tensor shape is (2, 3, 4), then axis -1 will be converted to 2.\n                For the \"squeeze\" operation, the shape should be the 'input' shape.\n                While for the \"unsqueeze\" operation, the shape should be the 'output' shape.\n\n        Returns:\n            axis (int or tuple): The non-negative axis.\n        \"\"\"\n        if isinstance(axis, (tuple, list)):\n            return tuple(sorted([self.make_axis_non_negative(item, shape)\n                                 for item in axis]))\n        if shape == 'input':\n            shape = self.input_shape\n        elif shape == 'output':\n            shape = self.output_shape\n        else:\n            assert isinstance(shape, torch.Size)\n        if axis < 0:\n            return axis + len(shape)\n        else:\n            return axis\n\n    def update_requires_input_bounds(self):\n        \"\"\"Update requires_input_bounds.\n\n        This function is called once we know if the input nodes are perturbed.\n        \"\"\"\n        pass\n\n    def clamp_interim_bounds(self):\n        \"\"\"Clamp intermediate bounds.\"\"\"\n        pass\n\n    def check_constraint_available(self, node, flag=False):\n        if hasattr(node, 'cstr_interval'):\n            flag = True\n        for n in node.inputs:\n            if not n.from_input:\n                flag = flag or self.check_constraint_available(n, flag)\n        return flag\n\n    def _ibp_constraint(self, node: 'Bound', delete_bounds_after_use=False):\n        def _delete_unused_bounds(node_list):\n            \"\"\"Delete bounds from input layers after use to save memory. Used when\n            sparse_intermediate_bounds_with_ibp is true.\"\"\"\n            if delete_bounds_after_use:\n                for n in node_list:\n                    del n.cstr_interval\n                    del n.cstr_lower\n                    del n.cstr_upper\n\n        if not node.perturbed and hasattr(node, 'forward_value'):\n            node.cstr_lower, node.cstr_upper = node.cstr_interval = (\n                node.forward_value, node.forward_value)\n\n        to_be_deleted_bounds = []\n        if not hasattr(node, 'cstr_interval'):\n            for n in node.inputs:\n                if not hasattr(n, 'cstr_interval'):\n                    # Node n does not have interval bounds; we must compute it.\n                    self._ibp_constraint(\n                        n, delete_bounds_after_use=delete_bounds_after_use)\n                    to_be_deleted_bounds.append(n)\n            inp = [n_pre.cstr_interval for n_pre in node.inputs]\n            node.cstr_interval = node.interval_propagate(*inp)\n\n            node.cstr_lower, node.cstr_upper = node.cstr_interval\n            if isinstance(node.cstr_lower, torch.Size):\n                node.cstr_lower = torch.tensor(node.cstr_lower)\n                node.cstr_interval = (node.cstr_lower, node.cstr_upper)\n            if isinstance(node.cstr_upper, torch.Size):\n                node.cstr_upper = torch.tensor(node.cstr_upper)\n                node.cstr_interval = (node.cstr_lower, node.cstr_upper)\n\n        if node.is_lower_bound_current():\n            node.lower = torch.where(node.lower 
>= node.cstr_lower, node.lower,\n                            node.cstr_lower)\n            node.upper = torch.where(node.upper <= node.cstr_upper, node.upper,\n                            node.cstr_upper)\n            node.interval = (node.lower, node.upper)\n\n        _delete_unused_bounds(to_be_deleted_bounds)\n        return node.cstr_interval\n\n    def _check_weight_perturbation(self):\n        weight_perturbation = False\n        for n in self.inputs[1:]:\n            if hasattr(n, 'perturbation'):\n                if n.perturbation is not None:\n                    weight_perturbation = True\n        if weight_perturbation:\n            self.requires_input_bounds = list(range(len(self.inputs)))\n        else:\n            self.requires_input_bounds = []\n        return weight_perturbation\n\n    def non_deter_wrapper(self, op, *args, **kwargs):\n        \"\"\"Some operations are non-deterministic and deterministic mode will fail.\n        So we temporarily disable it.\"\"\"\n        if self.options.get('deterministic', False):\n            torch.use_deterministic_algorithms(False)\n        ret = op(*args, **kwargs)\n        if self.options.get('deterministic', False):\n            torch.use_deterministic_algorithms(True)\n        return ret\n\n    def non_deter_scatter_add(self, *args, **kwargs):\n        return self.non_deter_wrapper(torch.scatter_add, *args, **kwargs)\n\n    def non_deter_index_select(self, *args, **kwargs):\n        return self.non_deter_wrapper(torch.index_select, *args, **kwargs)\n
  },
  {
    "path": "auto_LiRPA/operators/bivariate.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Bivariate operators\"\"\"\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module\nfrom typing import Dict, Optional\nfrom .base import *\nfrom .activation_base import BoundOptimizableActivation\nfrom .convex_concave import BoundSqrt\nfrom .clampmult import multiply_by_A_signs\nfrom ..utils import *\n\n\nclass MulHelper:\n    \"\"\"Handle linear relaxation for multiplication.\n\n    This helper can be used by BoundMul, BoundMatMul,\n    BoundLinear (with weight perturbation).\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    @staticmethod\n    def interpolated_relaxation(x_l: Tensor, x_u: Tensor,\n                                y_l: Tensor, y_u: Tensor,\n                                r_l: Optional[Tensor] = None,\n                                r_u: Optional[Tensor] = None,\n                                middle: bool = False,\n                               ) -> Tuple[Tensor, Tensor, Tensor,\n                                          Tensor, Tensor, Tensor]:\n        \"\"\"Interpolate two optimal linear relaxations for optimizable bounds.\"\"\"\n        if r_l is None and r_u is None:\n            if middle:\n                # This option is equivalent to optimized linear relaxation\n                # with 0.5 as the fixed parameter.\n                # It interpolates two valid linear relaxations.\n                # See Appendix C in https://openreview.net/pdf?id=BJxwPJHFwS\n                alpha_l = (y_l - y_u) * 0.5 + y_u\n                beta_l = (x_l - x_u) * 0.5 + x_u\n                gamma_l = (y_u * x_u - y_l * x_l) * 0.5 - y_u * x_u\n                alpha_u = (y_u - y_l) * 0.5 + y_l\n                beta_u = (x_l - x_u) * 0.5 + x_u\n                gamma_u = (y_l * x_u - y_u * x_l) * 0.5 - y_l * x_u\n            else:\n                alpha_l, beta_l, gamma_l = y_l, x_l, 
-y_l * x_l\n                alpha_u, beta_u, gamma_u = y_u, x_l, -y_u * x_l\n            return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u\n        else:\n            assert isinstance(r_l, Tensor) and isinstance(r_u, Tensor)\n            # TODO (for zhouxing/qirui): this function may benefit from JIT,\n            # because it has many element-wise operation which can be fused.\n            # Need to benchmark and see performance.\n            alpha_l = (y_l - y_u) * r_l + y_u\n            beta_l = (x_l - x_u) * r_l + x_u\n            gamma_l = (y_u * x_u - y_l * x_l) * r_l - y_u * x_u\n            alpha_u = (y_u - y_l) * r_u + y_l\n            beta_u = (x_l - x_u) * r_u + x_u\n            gamma_u = (y_l * x_u - y_u * x_l) * r_u - y_l * x_u\n            return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u\n\n    @staticmethod\n    def get_relaxation(x_l: Tensor, x_u: Tensor, y_l: Tensor, y_u: Tensor,\n                       opt_stage: Optional[str],\n                       alphas: Optional[Dict[str, Tensor]],\n                       start_name: Optional[str],\n                       middle: bool = False,\n                      ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n        if opt_stage in ['opt', 'reuse']:\n            assert x_l.ndim == y_l.ndim\n            ns = start_name\n            alphas[ns].data[:] = alphas[ns].data[:].clamp(min=0, max=1)\n            return MulHelper.interpolated_relaxation(\n                x_l, x_u, y_l, y_u, alphas[ns][:2], alphas[ns][2:4])\n        else:\n            return MulHelper.interpolated_relaxation(\n                x_l, x_u, y_l, y_u, middle=middle)\n\n    @staticmethod\n    def get_forward_relaxation(x_l, x_u, y_l, y_u, opt_stage, alpha, start_name):\n        # Broadcast\n        # FIXME perhaps use a more efficient way\n        x_l = x_l + torch.zeros_like(y_l)\n        x_u = x_u + torch.zeros_like(y_u)\n        y_l = y_l + torch.zeros_like(x_l)\n        y_u = y_u + 
torch.zeros_like(x_u)\n        return MulHelper.get_relaxation(x_l, x_u, y_l, y_u, opt_stage, alpha, start_name)\n\n    @staticmethod\n    def _get_gap(x, y, alpha, beta):\n        return x * y - alpha * x - beta * y\n\n\nclass BoundMul(BoundOptimizableActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.splittable = True\n        self.mul_helper = MulHelper()\n        if options is None:\n            options = {}\n        self.middle = options.get('mul', {}).get('middle', False)\n\n    def forward(self, x, y):\n        self.x_shape = x.shape\n        self.y_shape = y.shape\n        return x * y\n\n    def get_relaxation_opt(self, x_l, x_u, y_l, y_u):\n        return self.mul_helper.get_relaxation(\n            x_l, x_u, y_l, y_u, self.opt_stage, getattr(self, 'alpha', None),\n            getattr(self, '_start', None), middle=self.middle)\n\n    def _init_opt_parameters_impl(self, size_spec, **kwargs):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        x_l = self.inputs[0].lower\n        y_l = self.inputs[1].lower\n        assert x_l.ndim == y_l.ndim\n        shape = [max(x_l.shape[i], y_l.shape[i]) for i in range(x_l.ndim)]\n        alpha = torch.ones(4, size_spec, *shape, device=x_l.device)\n        return alpha\n\n    def _is_softmax(self):\n        \"\"\"This multiplication comes from softmax.\n\n        It is the division converted to BoundMul + BoundReciprocal.\n        \"\"\"\n        return (\n            self.from_complex_node == 'BoundSoftmax'\n            and type(self.inputs[0]).__name__ == 'BoundExp'\n            and type(self.inputs[1]).__name__ == 'BoundReciprocal'\n            and type(self.inputs[1].inputs[0]).__name__ == 'BoundReduceSum'\n            and type(self.inputs[1].inputs[0].inputs[0]).__name__ == 'BoundExp')\n\n    def bound_relax(self, x, y, init=False, dim_opt=None):\n        if 
init:\n            pass\n        (alpha_l, beta_l, gamma_l,\n         alpha_u, beta_u, gamma_u) = self.get_relaxation_opt(\n            x.lower, x.upper, y.lower, y.upper)\n\n        # Check NaN which can happen in softmax if Exp's bounds are too loose\n        if self._is_softmax():\n            assert alpha_l.shape[:-1] == beta_l.shape[:-1]\n            assert alpha_l.shape[-1] == 1 or alpha_l.shape[-1] == beta_l.shape[-1]\n            assert beta_l.shape == gamma_l.shape\n            mask = (alpha_l.isnan().expand(beta_l.shape)\n                    | alpha_l.isinf().expand(beta_l.shape)\n                    | beta_l.isnan() | beta_l.isinf()\n                    | gamma_l.isnan() | gamma_l.isinf())\n            if mask.any():\n                alpha_l = alpha_l.clone()\n                alpha_l[mask.any(dim=-1)] = 0\n                beta_l = beta_l.clone()\n                beta_l[mask] = 0\n                gamma_l = gamma_l.clone()\n                gamma_l[mask] = 0\n\n            assert alpha_u.shape[:-1] == beta_u.shape[:-1]\n            assert alpha_u.shape[-1] == 1 or alpha_u.shape[-1] == beta_u.shape[-1]\n            assert beta_u.shape == gamma_u.shape\n            mask = (alpha_u.isnan().expand(beta_u.shape)\n                    | alpha_u.isinf().expand(beta_u.shape)\n                    | beta_u.isnan() | beta_u.isinf()\n                    | gamma_u.isnan() | gamma_u.isinf())\n            if mask.any():\n                alpha_u = alpha_u.clone()\n                alpha_u[mask.any(dim=-1)] = 0\n                beta_u = beta_u.clone()\n                beta_u[mask] = 0\n                gamma_u = gamma_u.clone()\n                gamma_u[mask] = 1.\n\n        self.lw = [alpha_l, beta_l]\n        self.lb = gamma_l\n        self.uw = [alpha_u, beta_u]\n        self.ub = gamma_u\n\n    @staticmethod\n    def _multiply_by_const(x, const):\n        if isinstance(x, torch.Tensor):\n            return x * const\n        elif isinstance(x, Patches):\n            # 
Multiplies patches by a const. Assuming const is a tensor, and it must be in nchw format.\n            assert isinstance(const, torch.Tensor) and const.ndim == 4\n            if (const.size(0) == 1 or const.size(0) == x.patches.size(1)) and const.size(1) == x.patches.size(-3) and const.size(2) == const.size(3) == 1:\n                # The case that we can do channel-wise broadcasting multiplication\n                # Shape of const: (batch, in_c, 1, 1)\n                # Shape of patches when unstable_idx is not None: (unstable_size, batch, in_c, patch_h, patch_w)\n                # Shape of patches when unstable_idx is None: (out_c, batch, out_h, out_w, in_c, patch_h, patch_w)\n                const_reshaped = const\n            else:\n                assert x.unstable_idx is None and (x.padding == 0 or x.padding == [0,0,0,0]) and x.stride == 1 and x.patches.size(-1) == x.patches.size(-2) == 1\n                # The assumed dimension is (out_c, N, out_h, out_w, in_c, 1, 1) with padding = 0 and stride = 1.\n                # In this special case we can directly multiply.\n                # After reshape it is (1, N, H, W, C, 1, 1)\n                const_reshaped = const.permute(0, 2, 3, 1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n            return x.create_similar(x.patches * const_reshaped)\n        else:\n            raise ValueError(f'Unsupported x type {type(x)}')\n\n    def bound_backward_constant(self, last_lA, last_uA, x, y, op=None,\n                                reduce_bias=True, **kwargs):\n        assert reduce_bias\n        op = BoundMul._multiply_by_const if op is None else op\n        # Handle the case of multiplication by a constant.\n        factor = None\n        if x.perturbed:\n            factor = y.forward_value\n        if y.perturbed:\n            factor = x.forward_value\n        # No need to compute A matrix if it is Constant.\n        lAx = (None if not x.perturbed or last_lA is None\n               else self.broadcast_backward(op(last_lA, 
factor), x))\n        uAx = (None if not x.perturbed or last_uA is None\n               else self.broadcast_backward(op(last_uA, factor), x))\n        lAy = (None if not y.perturbed or last_lA is None\n               else self.broadcast_backward(op(last_lA, factor), y))\n        uAy = (None if not y.perturbed or last_uA is None\n               else self.broadcast_backward(op(last_uA, factor), y))\n        return [(lAx, uAx), (lAy, uAy)], 0., 0.\n\n    def bound_backward(self, last_lA, last_uA, x, y, start_node=None, **kwargs):\n        if start_node is not None:\n            self._start = start_node.name\n        if self.is_linear_op:\n            ret = self.bound_backward_constant(last_lA, last_uA, x, y, **kwargs)\n        else:\n            ret = self.bound_backward_both_perturbed(\n                last_lA, last_uA, x, y, **kwargs)\n        return ret\n\n    def bound_backward_both_perturbed(self, last_lA, last_uA, x, y,\n                                      reduce_bias=True, **kwargs):\n        self.bound_relax(x, y)\n\n        def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos,\n                           alpha_neg, beta_neg, gamma_neg, opt=False):\n            if last_A is None:\n                return None, None, 0\n\n            if type(last_A) == Patches:\n                assert reduce_bias\n                assert last_A.identity == 0\n                # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W].\n                # Here out_c is the spec dimension.\n                # for patches mode, we need to unfold the alpha_pos/neg and beta_pos/neg\n                alpha_pos = maybe_unfold_patches(alpha_pos, last_A)\n                alpha_neg = maybe_unfold_patches(alpha_neg, last_A)\n                beta_pos = maybe_unfold_patches(beta_pos, last_A)\n                beta_neg = maybe_unfold_patches(beta_neg, last_A)\n                gamma_pos = maybe_unfold_patches(gamma_pos, last_A)\n                gamma_neg = maybe_unfold_patches(gamma_neg, 
last_A)\n                A_x, bias = multiply_by_A_signs(\n                    last_A, alpha_pos, alpha_neg, gamma_pos, gamma_neg)\n                A_y, _ = multiply_by_A_signs(\n                    last_A, beta_pos, beta_neg, None, None)\n            elif type(last_A) == Tensor:\n                last_A_pos, last_A_neg = last_A.clamp(min=0), last_A.clamp(max=0)\n                A_x, _ = multiply_by_A_signs(last_A, alpha_pos, alpha_neg, None, None)\n                A_y, _ = multiply_by_A_signs(last_A, beta_pos, beta_neg, None, None)\n                A_x = self.broadcast_backward(A_x, x)\n                A_y = self.broadcast_backward(A_y, y)\n                if reduce_bias:\n                    if opt:\n                        bias = (torch.einsum('sb...,sb...->sb',\n                                             last_A_pos, gamma_pos)\n                                + torch.einsum('sb...,sb...->sb',\n                                               last_A_neg, gamma_neg))\n                    else:\n                        bias = (self.get_bias(last_A_pos, gamma_pos.squeeze(0)) +\n                            self.get_bias(last_A_neg, gamma_neg.squeeze(0)))\n                else:\n                    assert not opt\n                    bias = last_A_pos * gamma_pos + last_A_neg * gamma_neg\n                    assert len(x.output_shape) == bias.ndim - 1\n                    assert len(y.output_shape) == bias.ndim - 1\n                    bias_x = bias_y = bias\n                    for i in range(2, bias.ndim):\n                        if bias_x.shape[i] != x.output_shape[i - 1]:\n                            assert x.output_shape[i - 1] == 1\n                            bias_x = bias_x.sum(i, keepdim=True)\n                    for i in range(2, bias.ndim):\n                        if bias_y.shape[i] != y.output_shape[i - 1]:\n                            assert y.output_shape[i - 1] == 1\n                            bias_y = bias_y.sum(i, keepdim=True)\n                   
 bias = (bias_x, bias_y)\n            else:\n                raise NotImplementedError(last_A)\n            return A_x, A_y, bias\n\n        alpha_l, beta_l, gamma_l = self.lw[0], self.lw[1], self.lb\n        alpha_u, beta_u, gamma_u = self.uw[0], self.uw[1], self.ub\n\n        if self.opt_stage in ['opt', 'reuse']:\n            lA_x, lA_y, lbias = _bound_oneside(\n                last_lA, alpha_l[0], beta_l[0], gamma_l[0],\n                alpha_u[0], beta_u[0], gamma_u[0], opt=True)\n            uA_x, uA_y, ubias = _bound_oneside(\n                last_uA, alpha_u[1], beta_u[1], gamma_u[1],\n                alpha_l[1], beta_l[1], gamma_l[1], opt=True)\n        else:\n            alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0)\n            beta_l, beta_u = beta_l.unsqueeze(0), beta_u.unsqueeze(0)\n            gamma_l, gamma_u = gamma_l.unsqueeze(0), gamma_u.unsqueeze(0)\n            lA_x, lA_y, lbias = _bound_oneside(\n                last_lA, alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u)\n            uA_x, uA_y, ubias = _bound_oneside(\n                last_uA, alpha_u, beta_u, gamma_u, alpha_l, beta_l, gamma_l)\n\n        return [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias\n\n    def bound_forward(self, dim_in, x, y):\n        if self.is_linear_op:\n            if not self.inputs[0].perturbed:\n                return self.bound_forward_constant(x, y, self.inputs[0].batch_dim != -1)\n            elif not self.inputs[1].perturbed:\n                return self.bound_forward_constant(y, x, self.inputs[1].batch_dim != -1)\n            else:\n                assert False, \"When is_linear_op is True, at least one input should be constant.\"\n        return self.bound_forward_both_perturbed(dim_in, x, y)\n    \n    def bound_forward_constant(self, x, y, batched_constant):\n        # x is constant\n        const = x.lb\n        const_pos, const_neg = const.clamp(min=0), const.clamp(max=0)\n        lb = const_pos * y.lb + const_neg * y.ub\n        ub = 
const_pos * y.ub + const_neg * y.lb\n        if batched_constant:\n            # If x is batched, its first dimension will be the batch dimension\n            # We need to unsqueeze an extra dimension to align the batch dimension\n            # x and y both have shape (B, a_1, a_2, ..., a_n)\n            # lw/uw has shape (B, dim_in, a_1, a_2, ..., a_n)\n            const_pos = const_pos.unsqueeze(1)\n            const_neg = const_neg.unsqueeze(1)\n        lw = const_pos * y.lw + const_neg * y.uw\n        uw = const_pos * y.uw + const_neg * y.lw\n        return LinearBound(lw, lb, uw, ub)\n\n    def bound_forward_both_perturbed(self, dim_in, x, y):\n        x_lw, x_lb, x_uw, x_ub = x.lw, x.lb, x.uw, x.ub\n        y_lw, y_lb, y_uw, y_ub = y.lw, y.lb, y.uw, y.ub\n\n        (alpha_l, beta_l, gamma_l,\n         alpha_u, beta_u, gamma_u) = MulHelper.get_forward_relaxation(\n             x.lower, x.upper, y.lower, y.upper, self.opt_stage, getattr(self, 'alpha', None), self._start)\n\n        if x_lw is None: x_lw = 0\n        if y_lw is None: y_lw = 0\n        if x_uw is None: x_uw = 0\n        if y_uw is None: y_uw = 0\n\n        lw = alpha_l.unsqueeze(1).clamp(min=0) * x_lw + alpha_l.unsqueeze(1).clamp(max=0) * x_uw\n        lw = lw + beta_l.unsqueeze(1).clamp(min=0) * y_lw + beta_l.unsqueeze(1).clamp(max=0) * y_uw\n        lb = (alpha_l.clamp(min=0) * x_lb + alpha_l.clamp(max=0) * x_ub +\n             beta_l.clamp(min=0) * y_lb + beta_l.clamp(max=0) * y_ub + gamma_l)\n        uw = alpha_u.unsqueeze(1).clamp(max=0) * x_lw + alpha_u.unsqueeze(1).clamp(min=0) * x_uw\n        uw = uw + beta_u.unsqueeze(1).clamp(max=0) * y_lw + beta_u.unsqueeze(1).clamp(min=0) * y_uw\n        ub = (alpha_u.clamp(max=0) * x_lb + alpha_u.clamp(min=0) * x_ub +\n             beta_u.clamp(max=0) * y_lb + beta_u.clamp(min=0) * y_ub + gamma_u)\n\n        return LinearBound(lw, lb, uw, ub)\n\n    @staticmethod\n    def interval_propagate_constant(x, y, op=lambda x, const: x * const):\n        # x 
is constant\n        const = x[0]\n        inp_lb = y[0]\n        inp_ub = y[1]\n        pos_mask = (const > 0).to(dtype=inp_lb.dtype)\n        neg_mask = 1. - pos_mask\n        lb = op(inp_lb, const * pos_mask) + op(inp_ub, const * neg_mask)\n        ub = op(inp_ub, const * pos_mask) + op(inp_lb, const * neg_mask)\n        return lb, ub\n\n    def interval_propagate(self, x, y):\n        if self.is_linear_op:\n            if not self.inputs[0].perturbed:\n                return self.interval_propagate_constant(x, y)\n            elif not self.inputs[1].perturbed:\n                return self.interval_propagate_constant(y, x)\n            else:\n                assert False, \"When is_linear_op is True, at least one input should be constant.\"\n        else:\n            lower, upper = self.interval_propagate_both_perturbed(x, y)\n            if self._is_softmax():\n                lower = lower.clamp(min=0)\n                upper = upper.clamp(max=1)\n            return lower, upper\n\n    @staticmethod\n    def interval_propagate_both_perturbed(*v):\n        x, y = v[0], v[1]\n        if x is y:\n            # A shortcut for x * x.\n            h_L, h_U = v[0]\n            r0 = h_L * h_L\n            r1 = h_U * h_U\n            # When h_L < 0, h_U > 0, lower bound is 0.\n            # When h_L < 0, h_U < 0, lower bound is h_U * h_U.\n            # When h_L > 0, h_U > 0, lower bound is h_L * h_L.\n            l = F.relu(h_L) - F.relu(-h_U)\n            return l * l, torch.max(r0, r1)\n\n        r0, r1, r2, r3 = x[0] * y[0], x[0] * y[1], x[1] * y[0], x[1] * y[1]\n        lower = torch.min(torch.min(r0, r1), torch.min(r2, r3))\n        upper = torch.max(torch.max(r0, r1), torch.max(r2, r3))\n\n        return lower, upper\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        if isinstance(v[0], Tensor):\n            self.solver_vars = self.forward(*v)\n            return\n        gvar_array = np.array(v[0])\n        
gvar_array = gvar_array * v[1].cpu().numpy()\n        self.solver_vars = gvar_array.tolist()\n\n    def update_requires_input_bounds(self):\n        self.is_linear_op = False\n        for inp in self.inputs:\n            if not inp.perturbed:\n                # If either of the two inputs is constant, we do not need input bounds.\n                self.is_linear_op = True\n        if self.is_linear_op:\n            # One input is constant; no bounds required.\n            self.requires_input_bounds = []\n            self.splittable = False\n        else:\n            # Both inputs are perturbed. Need relaxation.\n            self.requires_input_bounds = [0, 1]\n            if not self.force_not_splittable:\n                self.splittable = True\n\n    def build_gradient_node(self, grad_upstream):\n        grad_node_0 = MulGrad(self.inputs[0].output_shape if self.inputs[0].batch_dim != -1 else\n                              torch.Size((1,) + self.inputs[0].output_shape))\n        grad_node_1 = MulGrad(self.inputs[1].output_shape if self.inputs[1].batch_dim != -1 else\n                              torch.Size((1,) + self.inputs[1].output_shape))\n        return [(grad_node_0, (grad_upstream, self.inputs[1].forward_value), [self.inputs[1]]),\n                (grad_node_1, (grad_upstream, self.inputs[0].forward_value), [self.inputs[0]])]\n\n\nclass MulGrad(Module):\n    def __init__(self, input_shape):\n        super().__init__()\n        # We need the input shape to handle broadcasting\n        self.input_shape = input_shape\n\n    def forward(self, grad_last, y):\n        # z = x * y\n        # ∂z/∂x = y\n        if y.ndim > 0:\n            # If y is not a scalar, unsqueeze it to align with the spec dimension of grad_last\n            y = y.unsqueeze(1)\n        return reduce_broadcast_dims(grad_last * y, self.input_shape)\n\n\nclass BoundDiv(Bound):\n\n    def forward(self, x, y):\n        # FIXME (05/11/2022): ad-hoc implementation for layer normalization\n        if 
isinstance(self.inputs[1], BoundSqrt):\n            input = self.inputs[0].inputs[0]\n            x = input.forward_value\n            n = input.forward_value.shape[-1]\n\n            dev = x * (1. - 1. / n) - (x.sum(dim=-1, keepdim=True) - x) / n\n            dev_sqr = dev ** 2\n            s = (dev_sqr.sum(dim=-1, keepdim=True) - dev_sqr) / dev_sqr.clamp(min=epsilon)\n            sqrt = torch.sqrt(1. / n * (s + 1))\n            return torch.sign(dev) * (1. / sqrt)\n\n        return x / y\n"
  },
  {
    "path": "auto_LiRPA/operators/clampmult.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Element multiplication with the A matrix based on its sign.\"\"\"\nimport torch\nfrom typing import Optional, Tuple\nfrom torch import Tensor\nfrom ..patches import Patches\n\n\ntorch._C._jit_set_profiling_executor(False)\ntorch._C._jit_set_profiling_mode(False)\n\n\nclass ClampedMultiplication(torch.autograd.Function):\n    @staticmethod\n    @torch.no_grad()\n    @torch.jit.script\n    def clamp_mutiply_forward(A: Tensor, d_pos: Tensor, d_neg: Tensor,\n            b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool,\n            reduce_bias: bool = False, same_slope: bool = False\n        ) -> Tuple[Tensor, Tensor]:\n        \"\"\"Forward operations; actually the same as the reference implementation.\"\"\"\n        A_pos = A.clamp(min=0)\n        A_neg = A.clamp(max=0)    \n        if same_slope:\n            # \"same-slope\" option is enabled; lower and upper bounds use the same A.\n            A_new = d_pos * A          \n        else:  \n            A_new = d_pos * A_pos + d_neg * A_neg\n        \n        bias_pos = bias_neg = torch.zeros(\n                (), dtype=A_new.dtype, device=A_new.device)\n        if b_pos is not None:\n            if not reduce_bias:\n                bias_pos = A_pos * b_pos\n            else:\n                if patches_mode:\n                    bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos)\n                else:\n                    bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos)\n        if b_neg is not None:\n            if not reduce_bias:\n                bias_neg = A_neg * b_neg\n            else:\n                if patches_mode:\n                    bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg)\n                else:\n                    bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg)\n      
  return A_new, bias_pos + bias_neg\n\n    @staticmethod\n    @torch.no_grad()\n    @torch.jit.script\n    def clamp_mutiply_backward(A: Tensor, d_pos: Tensor, d_neg: Tensor,\n            b_pos: Optional[Tensor], b_neg: Optional[Tensor],\n            grad_output_A: Tensor, grad_output_bias: Optional[Tensor], same_slope: bool = False\n        ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor],\n                   None, None, None]:\n        \"\"\"Improved backward operation. This should be better than the backward\n        function generated by PyTorch.\"\"\"\n        if grad_output_bias is not None:\n            extension_dim = len(A.shape) - len(grad_output_bias.shape)\n            grad_output_bias = grad_output_bias.view(\n                grad_output_bias.shape + (1, ) * extension_dim)\n        A_pos_mask = (A >= 0).to(dtype=grad_output_A.dtype)\n        A_neg_mask = 1. - A_pos_mask\n        A_pos_grad_output_A = A_pos_mask * grad_output_A\n        A_neg_grad_output_A = A_neg_mask * grad_output_A\n        # Even when d_pos and d_neg are the same tensor (the same-slope case),\n        # gd_pos and gd_neg still need to be computed separately.\n        gd_pos = A * A_pos_grad_output_A\n        gd_neg = A * A_neg_grad_output_A\n        if b_pos is not None and b_neg is not None and grad_output_bias is not None:\n            A_pos_grad_output_bias = A_pos_mask * grad_output_bias\n            A_neg_grad_output_bias = A_neg_mask * grad_output_bias\n            gb_neg = A * A_neg_grad_output_bias\n            gb_pos = A * A_pos_grad_output_bias\n            if same_slope:\n                gA = (d_pos * grad_output_A\n                      + b_pos * A_pos_grad_output_bias\n                      + b_neg * A_neg_grad_output_bias)\n            else:\n                gA = (d_pos * A_pos_grad_output_A\n                    + d_neg * A_neg_grad_output_A\n                    + b_pos * A_pos_grad_output_bias\n                    + b_neg * A_neg_grad_output_bias)\n        elif b_neg is not None and grad_output_bias is not 
None:\n            A_neg_grad_output_bias = A_neg_mask * grad_output_bias\n            gb_neg = A * A_neg_grad_output_bias\n            gb_pos = None\n            if same_slope:\n                gA = (d_pos * grad_output_A\n                      + b_neg * A_neg_grad_output_bias)\n            else:\n                gA = (d_pos * A_pos_grad_output_A\n                    + d_neg * A_neg_grad_output_A\n                    + b_neg * A_neg_grad_output_bias)\n        elif b_pos is not None and grad_output_bias is not None:\n            A_pos_grad_output_bias = A_pos_mask * grad_output_bias\n            gb_pos = A * A_pos_grad_output_bias\n            gb_neg = None\n            if same_slope:\n                gA = (d_pos * grad_output_A\n                      + b_pos * A_pos_grad_output_bias)\n            else:\n                gA = (d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A\n                    + b_pos * A_pos_grad_output_bias)\n        else:\n            if same_slope:\n                gA = d_pos * grad_output_A\n            else:\n                gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A\n            gb_pos = gb_neg = None\n        return gA, gd_pos, gd_neg, gb_pos, gb_neg, None, None, None\n\n    @staticmethod\n    def forward(ctx, A, d_pos, d_neg, b_pos, b_neg, patches_mode, reduce_bias=True, same_slope=False):\n        # No need to save the intermediate A_pos, A_neg as they have been fused into the computation.\n        ctx.save_for_backward(A, d_pos, d_neg, b_pos, b_neg)\n        ctx.patches_mode = patches_mode\n        ctx.reduce_bias = reduce_bias\n        ctx.same_slope = same_slope\n        return ClampedMultiplication.clamp_mutiply_forward(\n            A, d_pos, d_neg, b_pos, b_neg, patches_mode, reduce_bias, same_slope)\n\n    @staticmethod\n    def backward(ctx, grad_output_A, grad_output_bias):\n        A, d_pos, d_neg, b_pos, b_neg = ctx.saved_tensors\n        assert ctx.reduce_bias\n        return 
ClampedMultiplication.clamp_mutiply_backward(\n            A, d_pos, d_neg, b_pos, b_neg,\n            grad_output_A, grad_output_bias, ctx.same_slope)\n\n\ndef multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto',\n                        reduce_bias=True, same_slope=False):\n    if isinstance(A, Tensor):\n        if contiguous is True or contiguous == 'auto':\n            # For dense mode, convert d_pos and d_neg to contiguous tensor by default.\n            d_pos = d_pos.contiguous()\n            d_neg = d_neg.contiguous()\n        if d_pos.ndim == 1:\n            # Special case for LSTM: the bias term is 1-dimensional. (FIXME)\n            assert d_neg.ndim == 1 and b_pos.ndim == 1 and b_neg.ndim == 1\n            new_A = A.clamp(min=0) * d_pos + A.clamp(max=0) * d_neg\n            new_bias = A.clamp(min=0) * b_pos + A.clamp(max=0) * b_neg\n            return new_A, new_bias\n        return ClampedMultiplication.apply(\n            A, d_pos, d_neg, b_pos, b_neg, False, reduce_bias, same_slope)\n    elif isinstance(A, Patches):\n        if contiguous is True:\n            # For patches mode, do not convert d_pos and d_neg to contiguous tensor by default.\n            d_pos = d_pos.contiguous()\n            d_neg = d_neg.contiguous()\n        assert A.identity == 0  # TODO: handle the A.identity = 1 case. Currently not used.\n        patches = A.patches\n        patches_shape = patches.shape\n        # patches shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. 
Here out_c is the spec dimension.\n        # or (unstable_size, batch_size, in_c, H, W) when it is sparse.\n        if len(patches_shape) == 6:\n            patches = patches.view(*patches_shape[:2], -1, *patches_shape[-2:])\n            d_pos = d_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if d_pos is not None else None\n            d_neg = d_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if d_neg is not None else None\n            b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_pos is not None else None\n            b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_neg is not None else None\n        # Apply the multiplication based on signs.\n        A_prod, bias = ClampedMultiplication.apply(\n            patches, d_pos, d_neg, b_pos, b_neg, True, reduce_bias, same_slope)\n        # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse.\n        # For sparse patches the return bias size is (unstable_size, batch).\n        # For regular patches the return bias size is (spec, batch, out_h, out_w).\n        if len(patches_shape) == 6:\n            A_prod = A_prod.view(*patches_shape)\n        return A.create_similar(A_prod), bias\n\n"
  },
  {
    "path": "auto_LiRPA/operators/constant.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Constant operators, including operators that are usually fixed nodes and not perturbed \"\"\"\nfrom .base import *\n\n\nclass BoundConstant(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.value = attr['value'].to(self.device)\n        self.use_default_ibp = True\n        self.no_jacobian = True\n\n    def __repr__(self):\n        if self.value.numel() == 1:\n            return f'BoundConstant(name={self.name}, value={self.value})'\n        else:\n            return super().__repr__()\n\n    def forward(self):\n        return self.value.to(self.device)\n\n    def bound_backward(self, last_lA, last_uA, **kwargs):\n        def _bound_oneside(A):\n            if A is None:\n                return 0.0\n\n            if type(A) == Tensor:\n                if A.ndim > 2:\n                    A = torch.sum(A, dim=list(range(2, A.ndim)))\n            elif type(A) == Patches:\n                assert A.padding == 0 or A.padding == (0, 0, 0, 0) or self.value == 0  # FIXME (09/19): adding padding here.\n                patches_reshape = torch.sum(A.patches, dim=(-1, -2, -3)) * self.value.to(self.device)\n                # Expected shape for bias is (spec, batch, out_h, out_w) or (unstable_size, batch)\n                return patches_reshape\n\n            return A * self.value.to(self.device)\n\n        lbias = _bound_oneside(last_lA)\n        ubias = _bound_oneside(last_uA)\n        return [], lbias, ubias\n\n    def bound_forward(self, dim_in):\n        lw = uw = torch.zeros(dim_in, device=self.device)\n        lb = ub = self.value\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = 
self.value\n\n\nclass BoundPrimConstant(Bound):\n    def forward(self):\n        return torch.tensor([], device=self.device)\n\n\nclass BoundConstantOfShape(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.value = attr['value'].to(self.device)\n        self.no_jacobian = True\n\n    def forward(self, x):\n        self.x = x\n        self.from_input = True\n        return self.value.expand(*list(x))\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        if last_lA is not None:\n            lower_sum_b = last_lA * self.value\n            while lower_sum_b.ndim > 2:\n                lower_sum_b = torch.sum(lower_sum_b, dim=-1)\n        else:\n            lower_sum_b = 0\n\n        if last_uA is not None:\n            upper_sum_b = last_uA * self.value\n            while upper_sum_b.ndim > 2:\n                upper_sum_b = torch.sum(upper_sum_b, dim=-1)\n        else:\n            upper_sum_b = 0\n\n        return [(None, None)], lower_sum_b, upper_sum_b\n\n    def bound_forward(self, dim_in, x):\n        assert (len(self.x) >= 1)\n        lb = ub = torch.ones(self.output_shape, device=self.device) * self.value\n        lw = uw = torch.zeros(self.x[0], dim_in, *self.x[1:], device=self.device)\n        return LinearBound(lw, lb, uw, ub)\n\n    def interval_propagate(self, *v):\n        self.x = v[0][0]\n        value = torch.ones(tuple(v[0][0]), device=self.device) * self.value\n        return value, value\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(v)\n\n\nclass BoundRange(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.device = attr['device']\n\n    def forward(self, start, end, step):\n        if start.dtype == end.dtype == step.dtype 
== torch.int64:\n            return torch.arange(start, end, step, dtype=torch.int64, device=self.device)\n        else:\n            return torch.arange(start, end, step, device=self.device)\n\n\nclass BoundATenDiag(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.device = attr['device']\n\n    def forward(self, x, diagonal=0):\n        return torch.diag(x, diagonal=diagonal)\n\n    def interval_propagate(self, *v):\n        return Interval.make_interval(torch.diag(v[0][0], v[1][0]), torch.diag(v[0][1], v[1][0]), v[0])\n\n\nclass BoundATenDiagonal(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.device = attr['device']\n\n    def forward(self, x, offset=0, dim1=0, dim2=1):\n        return torch.diagonal(x, offset=offset, dim1=dim1, dim2=dim2)\n\n    def interval_propagate(self, *v):\n        params = (v[1][0], v[2][0], v[3][0])\n        return Interval.make_interval(torch.diagonal(v[0][0], *params), torch.diagonal(v[0][1], *params), v[0])\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        for i in range(1, 4):\n            assert isinstance(self.inputs[i], BoundConstant)\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            A = torch.zeros(*last_A.shape[:2], *self.inputs[0].output_shape[1:]).to(last_A)\n            dim1, dim2 = self.inputs[2].value, self.inputs[3].value\n            assert dim1 != 0 and dim2 != 0\n            if dim1 > 0:\n                dim1 += 1\n            if dim2 > 0:\n                dim2 += 1\n            A = torch.diagonal_scatter(\n                A, last_A,\n                offset=self.inputs[1].value, dim1=dim1, dim2=dim2)\n            return A\n\n        return ([(_bound_oneside(last_lA), _bound_oneside(last_uA))]\n           
     + [(None, None)] * 3), 0, 0\n"
  },
  {
    "path": "auto_LiRPA/operators/convex_concave.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Nonlinear functions that are either convex or concave within the entire domain.\"\"\"\nimport torch\nfrom torch.nn import Module\nfrom .base import *\nfrom .activation_base import BoundActivation, BoundOptimizableActivation\n\n\nclass BoundLog(BoundActivation):\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.range_l = 1e-6\n\n    def forward(self, x):\n        # NOTE: ad-hoc implementation for loss fusion\n        if self.loss_fusion:\n            return torch.logsumexp(self.inputs[0].inputs[0].inputs[0].forward_value, dim=-1)\n        return torch.log(x.clamp(min=epsilon))\n\n    def bound_relax(self, x, init=False):\n        if init:\n            self.init_linear_relaxation(x)\n        rl, ru = self.forward(x.lower), self.forward(x.upper)\n        ku = (ru - rl) / (x.upper - x.lower + epsilon)\n        self.add_linear_relaxation(mask=None, type='lower', k=ku, x0=x.lower, y0=rl)\n        m = (x.lower + x.upper) / 2\n        k = torch.reciprocal(m)\n        rm = self.forward(m)\n        self.add_linear_relaxation(mask=None, type='upper', k=k, x0=m, y0=rm)\n\n    def interval_propagate(self, *v):\n        # NOTE: ad-hoc implementation for loss fusion\n        if self.loss_fusion:\n            par = self.inputs[0].inputs[0].inputs[0]\n            lower = torch.logsumexp(par.lower, dim=-1)\n            upper = torch.logsumexp(par.upper, dim=-1)\n            return lower, upper\n        return super().interval_propagate(*v)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        A, lbias, ubias = super().bound_backward(last_lA, last_uA, x)\n        # NOTE: ad-hoc implementation for loss fusion\n        if self.loss_fusion:\n            assert A[0][0] is None\n            exp_module = 
self.inputs[0].inputs[0]\n            ubias = ubias + self.get_bias(A[0][1], exp_module.max_input.squeeze(-1))\n        return A, lbias, ubias\n\n\nclass BoundSqrt(BoundOptimizableActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_prior_constraint = True\n        self.has_constraint = True\n        self.range_l = 1e-6\n\n    def forward(self, x):\n        return torch.sqrt(x)\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n\n        if self.opt_stage in ['opt', 'reuse']:\n            self.alpha[self._start].data[:2] = torch.min(torch.max(\n                self.alpha[self._start].data[:2], x.lower), x.upper)\n            mid = self.alpha[self._start]\n        else:\n            mid = (x.lower + x.upper) / 2\n        k = 0.5 / self.forward(mid)\n        self.add_linear_relaxation(mask=None, type='upper', k=k, x0=mid)\n\n        sqrt_l = self.forward(x.lower)\n        sqrt_u = self.forward(x.upper)\n        k = (sqrt_u - sqrt_l) / (x.upper - x.lower).clamp(min=1e-8)\n        self.add_linear_relaxation(mask=None, type='lower', k=k, x0=x.lower)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        if self.use_prior_constraint and self.check_constraint_available(x):\n            if hasattr(x, 'cstr_interval'):\n                del x.cstr_interval\n                del x.cstr_lower\n                del x.cstr_upper\n\n            x_l, x_u = self._ibp_constraint(x, delete_bounds_after_use=True)\n            x_u = torch.max(x_u, x_l + 1e-8)\n        return super().bound_backward(last_lA, last_uA, x, **kwargs)\n\n    def clamp_interim_bounds(self):\n        self.cstr_lower = self.lower.clamp(min=0)\n        self.cstr_upper = self.upper.clamp(min=0)\n        self.cstr_interval = (self.cstr_lower, self.cstr_upper)\n\n    def _init_opt_parameters_impl(self, 
size_spec, **kwargs):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        alpha = torch.empty(2, size_spec, *l.shape, device=l.device)\n        alpha.data[:2] = (l + u) / 2\n        return alpha\n\n\nclass BoundReciprocal(BoundOptimizableActivation):\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.splittable = True\n        self.range_l = 1e-6\n\n    def forward(self, x):\n        return torch.reciprocal(x)\n\n    def interval_propagate(self, *v):\n        h_L = v[0][0].to(dtype=torch.get_default_dtype())\n        h_U = v[0][1].to(dtype=torch.get_default_dtype())\n        assert h_L.min() > 0, 'Only positive values are supported in BoundReciprocal'\n        return torch.reciprocal(h_U), torch.reciprocal(h_L)\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n\n        assert x.lower.min() >= 0\n\n        ku = -1. 
/ (x.lower * x.upper)\n        self.add_linear_relaxation(mask=None, type='upper', k=ku, x0=x.lower)\n\n        if self.opt_stage in ['opt', 'reuse']:\n            self.alpha[self._start].data[:2] = torch.min(torch.max(\n                self.alpha[self._start].data[:2], x.lower), x.upper)\n            mid = self.alpha[self._start].clamp(min=0.01)\n        else:\n            mid = (x.lower + x.upper) / 2\n\n        self.add_linear_relaxation(\n            mask=None, type='lower', k=-1./(mid**2), x0=mid)\n\n        if x.lower.min() <= 0:\n            mask = x.lower == 0\n            self.uw[..., mask] = 0\n            self.ub[..., mask] = torch.inf\n        if x.upper.isinf().any():\n            mask = x.upper.isinf()\n            self.lw[..., mask] = 0\n            self.lb[..., mask] = 0\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        As, lbias, ubias = super().bound_backward(last_lA, last_uA, x, **kwargs)\n        if isinstance(ubias, torch.Tensor) and ubias.isnan().any():\n            ubias[ubias.isnan()] = torch.inf if (last_uA != 0).any() else 0.\n        if isinstance(lbias, torch.Tensor) and lbias.isnan().any():\n            lbias[lbias.isnan()] = 0.\n        return As, lbias, ubias\n\n    def _init_opt_parameters_impl(self, size_spec, **kwargs):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        alpha = torch.empty(2, size_spec, *l.shape, device=l.device)\n        alpha.data[:2] = (l + u) / 2\n        return alpha\n\n    def build_gradient_node(self, grad_upstream):\n        return [(ReciprocalGrad(), (grad_upstream, self.inputs[0].forward_value), [self.inputs[0]])]\n\n\nclass ReciprocalGrad(Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, g, x):\n        # partial derivative of 1/x is -1/x^2\n        return -g / torch.square(x).unsqueeze(1)\n\n\nclass BoundExp(BoundOptimizableActivation):\n    def 
__init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        if options is None:\n            options = {}\n        self.options = options.get('exp', {})\n        self.max_input = 0\n\n    def forward(self, x):\n        if self.loss_fusion and self.options != 'no-max-input':\n            self.max_input = torch.max(x, dim=-1, keepdim=True)[0].detach()\n            return torch.exp(x - self.max_input)\n        return torch.exp(x)\n\n    def interval_propagate(self, *v):\n        assert len(v) == 1\n        # exp is a unary, monotonically increasing function\n        h_L, h_U = v[0]\n        if self.loss_fusion and self.options != 'no-max-input':\n            self.max_input = torch.max(h_U, dim=-1, keepdim=True)[0]\n            h_L, h_U = h_L - self.max_input, h_U - self.max_input\n        else:\n            self.max_input = 0\n        return torch.exp(h_L), torch.exp(h_U)\n\n    def bound_forward(self, dim_in, x):\n        m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)\n\n        exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper)\n\n        kl = exp_m\n        lw = x.lw * kl.unsqueeze(1)\n        lb = kl * (x.lb - m + 1)\n\n        ku = (exp_u - exp_l) / (x.upper - x.lower + epsilon)\n        uw = x.uw * ku.unsqueeze(1)\n        ub = x.ub * ku - ku * x.lower + exp_l\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        # Special case when computing log_softmax (FIXME: find a better solution, this trigger condition is not reliable).\n        if self.loss_fusion and last_lA is None and last_uA is not None and torch.min(\n                last_uA) >= 0 and x.from_input:\n            # Adding an extra bias term to the input. 
This is equivalent to adding a constant and subtract layer before exp.\n            # Note that we also need to adjust the bias term at the end.\n            if self.options == 'no-detach':\n                self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0]\n            elif self.options != 'no-max-input':\n                self.max_input = torch.max(x.upper, dim=-1, keepdim=True)[0].detach()\n            else:\n                self.max_input = 0\n            adjusted_lower = x.lower - self.max_input\n            adjusted_upper = x.upper - self.max_input\n            # relaxation for upper bound only (used in loss fusion)\n            exp_l, exp_u = torch.exp(adjusted_lower), torch.exp(adjusted_upper)\n            k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower).clamp(min=1e-8)\n            if k.requires_grad:\n                k = k.clamp(min=1e-8)\n            uA = last_uA * k.unsqueeze(0)\n            ubias = last_uA * (-adjusted_lower * k + exp_l).unsqueeze(0)\n\n            if ubias.ndim > 2:\n                ubias = torch.sum(ubias, dim=tuple(range(2, ubias.ndim)))\n            # Also adjust the missing ubias term.\n            if uA.ndim > self.max_input.ndim:\n                A = torch.sum(uA, dim=tuple(range(self.max_input.ndim, uA.ndim)))\n            else:\n                A = uA\n\n            # These should hold true in loss fusion\n            assert self.batch_dim == 0\n            assert A.shape[0] == 1\n\n            batch_size = A.shape[1]\n            ubias -= (A.reshape(batch_size, -1) * self.max_input.reshape(batch_size, -1)).sum(dim=-1).unsqueeze(0)\n            return [(None, uA)], 0, ubias\n        else:\n            As, lbias, ubias = super().bound_backward(last_lA, last_uA, x, **kwargs)\n            lA, uA = As[0]\n            lA, lbias = self._check_nan(lA, lbias, last_lA, 0)\n            uA, ubias = self._check_nan(uA, ubias, last_uA, torch.inf)\n            return [(lA, uA)], lbias, ubias\n\n    def _check_nan(self, A, 
bias, last_A, const_bound):\n        \"\"\"Check for NaN caused by 0 in A and inf in lw/lb/uw/ub.\n        It can happen if the pre-activation bounds are very loose for exp.\n        \"\"\"\n        if A is None:\n            return A, bias\n        if bias.isnan().any():\n            # These assertions ensure that 0 is in A and inf is in lw/lb/uw/ub\n            assert not last_A.isnan().any()\n            assert not last_A.isinf().any()\n            assert not self.lw.isnan().any()\n            assert not self.uw.isnan().any()\n            assert not self.lb.isnan().any()\n            assert not self.ub.isnan().any()\n            A_ = A.view(-1, *A.shape[2:]).clone()\n            bias_ = bias.view(-1).clone()\n            mask = bias_.isnan()\n            A_[mask] = 0\n            assert (last_A >= 0).all()\n            bias_[mask] = const_bound if (last_A != 0).any() else 0.\n            A = A_.view(A.shape)\n            bias = bias_.view(bias.shape)\n        return A, bias\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n        min_val = -1e9\n        l, u = x.lower.clamp(min=min_val), x.upper.clamp(min=min_val)\n        if self.opt_stage in ['opt', 'reuse']:\n            self.alpha[self._start].data[:2] = torch.min(torch.max(\n                self.alpha[self._start].data[:2], x.lower), x.upper)\n            m = torch.min(self.alpha[self._start], x.lower + 0.99)\n        else:\n            m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99)\n        exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper)\n        k = exp_m\n        self.add_linear_relaxation(mask=None, type='lower', k=k, x0=m, y0=exp_m)\n        k = (exp_u - exp_l) / (u - l).clamp(min=1e-8)\n        self.add_linear_relaxation(mask=None, type='upper', k=k, x0=l, y0=exp_l)\n\n    def _init_opt_parameters_impl(self, size_spec, **kwargs):\n        \"\"\"Implementation of init_opt_parameters for 
each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        alpha = torch.empty(2, size_spec, *l.shape, device=l.device)\n        alpha.data[:2] = (l + u) / 2\n        return alpha\n\n    def build_gradient_node(self, grad_upstream):\n        if self.loss_fusion:\n            raise NotImplementedError('Gradient computation for exp with loss fusion is not supported.')\n        return [(ExpGrad(), (grad_upstream, self.inputs[0].forward_value), [self.inputs[0]])]\n\n\nclass ExpGrad(Module):\n    def __init__(self):\n        super().__init__()\n    \n    def forward(self, g, preact):\n        # exp'(x) = exp(x)\n        return g * torch.exp(preact).unsqueeze(1)\n"
  },
  {
    "path": "auto_LiRPA/operators/convolution.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Convolution and padding operators\"\"\"\nfrom torch.autograd import Function\nfrom torch.nn import Module\nfrom .base import *\nimport numpy as np\nfrom .solver_utils import grb\nfrom ..patches import unify_shape, compute_patches_stride_padding, is_shape_used, create_valid_mask\n\nEPS = 1e-2\n\nclass BoundConv(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n\n        if len(attr['kernel_shape']) == 1:\n            # for 1d conv\n            assert (attr['pads'][0] == attr['pads'][1])\n            self.padding = [attr['pads'][0]]\n            self.F_conv = F.conv1d\n            self.conv_dim = 1\n        else:\n            # for 2d conv\n            assert (attr['pads'][0] == attr['pads'][2])\n            assert (attr['pads'][1] == attr['pads'][3])\n            self.padding = [attr['pads'][0], attr['pads'][1]]\n            self.F_conv = F.conv2d\n            self.conv_dim = 2\n\n        self.stride = attr['strides']\n        self.dilation = attr['dilations']\n        self.groups = attr['group']\n        if len(inputs) == 3:\n            self.has_bias = True\n        else:\n            self.has_bias = False\n        self.patches_start = True\n        if options is None:\n            options = {}\n        self.mode = options.get(\"conv_mode\", \"matrix\")\n        # denote whether this Conv is followed by a ReLU\n        # if self.relu_followed is False, we need to manually pad the conv patches.\n        # If self.relu_followed is True, the patches are padded in the ReLU layer\n        # and the manual padding is not needed.\n        self.relu_followed = False\n\n    def forward(self, *x):\n        # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias\n        bias = x[2] if self.has_bias else None\n\n     
   output = self.F_conv(x[0], x[1], bias, self.stride, self.padding, self.dilation, self.groups)\n\n        return output\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        if self.is_input_perturbed(1):\n            raise NotImplementedError(\n                'Weight perturbation for convolution layers has not been implemented.')\n\n        lA_y = uA_y = lA_bias = uA_bias = None\n        weight = x[1].lower\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None, 0\n            if type(last_A) is OneHotC:\n                # Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead.\n                last_A = onehotc_to_dense(last_A, dtype=weight.dtype)\n\n            if type(last_A) == Tensor:\n                shape = last_A.size()\n                # when (W−F+2P)%S != 0, construct the output_padding\n                if self.conv_dim == 2:\n                    output_padding0 = (\n                        int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 *\n                        self.padding[0] - 1 - (int(weight.size()[2] - 1) * self.dilation[0]))\n                    output_padding1 = (\n                        int(self.input_shape[3]) - (int(self.output_shape[3]) - 1) * self.stride[1] + 2 *\n                        self.padding[1] - 1 - (int(weight.size()[3] - 1) * self.dilation[1]))\n                    next_A = F.conv_transpose2d(\n                        last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None,\n                        stride=self.stride, padding=self.padding, dilation=self.dilation,\n                        groups=self.groups, output_padding=(output_padding0, output_padding1))\n                else:\n                    # for 1d conv, we use conv_transpose1d()\n                    output_padding = (\n                            int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 
*\n                            self.padding[0] - 1 - (int(weight.size()[2] - 1) * self.dilation[0]))\n                    next_A = F.conv_transpose1d(\n                        last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None,\n                        stride=self.stride, padding=self.padding, dilation=self.dilation,\n                        groups=self.groups, output_padding=output_padding)\n\n                next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])\n                if self.has_bias:\n                    # sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2)\n                    sum_bias = torch.einsum('sbc...,c->sb', last_A, x[2].lower)\n                else:\n                    sum_bias = 0\n                return next_A, sum_bias\n            elif type(last_A) == Patches:\n                # Here we build and propagate a Patch object with (patches, stride, padding)\n                assert self.conv_dim == 2, 'Patches mode does not support conv1d so far.'\n                assert type(last_A) == Patches\n                if last_A.identity == 0:\n                    # FIXME (09/20): Don't call it relu_followed. 
Instead, make this a property of A, called \"padded\" and propagate this property.\n                    if not self.relu_followed:\n                        patches = last_A.create_padding(self.output_shape)\n                    else:\n                        patches = last_A.patches\n\n                    if self.has_bias:\n                        # bias is x[2] (lower and upper are the same), and has shape (c,).\n                        # Patches either has [out_c, batch, out_h, out_w, c, h, w] or [unstable_size, batch, c, h, w].\n                        sum_bias = torch.einsum('sb...chw,c->sb...', patches, x[2].lower)\n                        # sum_bias has shape (out_c, batch, out_h, out_w) or (unstable_size, batch).\n                    else:\n                        sum_bias = 0\n\n                    flattened_patches = patches.reshape(\n                        -1, patches.size(-3), patches.size(-2), patches.size(-1))\n                    pieces = F.conv_transpose2d(\n                        flattened_patches, insert_zeros(weight, last_A.inserted_zeros)\n                        , stride=self.stride)\n                    # New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w).\n                    pieces = pieces.view(\n                        *patches.shape[:-3], pieces.size(-3), pieces.size(-2),\n                        pieces.size(-1))\n\n                elif last_A.identity == 1:\n                    # New patches have size [out_c, batch, out_h, out_w, c, h, w] if it is not sparse.\n                    # New patches have size [unstable_size, batch, c, h, w] if it is sparse.\n                    if last_A.unstable_idx is not None:\n                        pieces = weight.view(\n                            weight.size(0), 1, weight.size(1), weight.size(2),\n                            weight.size(3))\n                        # Select based on the output channel (out_h and out_w are irrelevant here).\n                        
pieces = pieces[last_A.unstable_idx[0]]\n                        # Expand the batch dimension.\n                        pieces = pieces.expand(-1, last_A.shape[1], -1, -1, -1)\n                        # Do the same for the bias.\n                        if self.has_bias:\n                            sum_bias = x[2].lower[last_A.unstable_idx[0]].unsqueeze(-1)\n                            # bias has shape (unstable_size, batch).\n                            sum_bias = sum_bias.expand(-1, last_A.shape[1])\n                        else:\n                            sum_bias = 0\n                    else:\n                        assert weight.size(0) == last_A.shape[0]\n                        pieces = weight.view(\n                            weight.size(0), 1, 1, 1, weight.size(1), weight.size(2),\n                            weight.size(3)).expand(-1, *last_A.shape[1:4], -1, -1, -1)\n                        # The bias (x[2].lower) has shape (out_c,); we need to make it (out_c, batch, out_h, out_w).\n                        # Here we should transpose sum_bias to set the batch dim to 1, aiming to keep it consistent with the matrix version\n                        if self.has_bias:\n                            sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4])\n                        else:\n                            sum_bias = 0\n                else:\n                    raise NotImplementedError()\n                padding = last_A.padding if last_A is not None else (0, 0, 0, 0)  # (left, right, top, bottom)\n                stride = last_A.stride if last_A is not None else (1, 1)\n                inserted_zeros = last_A.inserted_zeros if last_A is not None else 0\n                output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0)\n\n                padding, stride, output_padding = compute_patches_stride_padding(\n                    self.input_shape, padding, stride, self.padding, self.stride,\n                    
inserted_zeros, output_padding)\n\n                if (inserted_zeros == 0 and not is_shape_used(output_padding)\n                    and pieces.shape[-1] > self.input_shape[-1]):  # the patches are too large and from now on, we will use matrix mode instead of patches mode.\n                    # This is our desired matrix: the input will be flattened to (batch_size, input_channel*input_x * input_y) and multiplied by this matrix.\n                    # After multiplication, the desired output is (batch_size, out_channel*output_x*output_y).\n                    # A_matrix has size (batch, out_c*out_h*out_w, in_c*in_h*in_w)\n                    A_matrix = patches_to_matrix(\n                        pieces, self.input_shape[1:], stride, padding,\n                        last_A.output_shape, last_A.unstable_idx)\n                    # print(f'Converting patches to matrix: old shape {pieces.shape}, size {pieces.numel()}; new shape {A_matrix.shape}, size {A_matrix.numel()}')\n                    if isinstance(sum_bias, Tensor) and last_A.unstable_idx is None:\n                        sum_bias = sum_bias.transpose(0, 1)\n                        sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1)\n                    A_matrix = A_matrix.transpose(0,1)  # Spec dimension at the front.\n                    return A_matrix, sum_bias\n                new_patches = last_A.create_similar(\n                        pieces, stride=stride, padding=padding, output_padding=output_padding,\n                        identity=0, input_shape=self.input_shape)\n                # if last_A is last_lA:\n                #     print(f'Conv : start_node {kwargs[\"start_node\"].name} layer {self.name} {new_patches}')\n                return new_patches, sum_bias\n            else:\n                raise NotImplementedError()\n\n        lA_x, lbias = _bound_oneside(last_lA)\n        uA_x, ubias = _bound_oneside(last_uA)\n        return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], 
lbias, ubias\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        if self.is_input_perturbed(1):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n\n        assert self.dilation == (1, 1) or self.dilation == [1, 1]\n        # e.g., last layer input gurobi vars (3,32,32)\n        gvars_array = np.array(v[0])\n        # pre_layer_shape (1,3,32,32)\n        pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape\n        # this layer shape (1,8,16,16)\n        this_layer_shape = self.output_shape\n        out_lbs, out_ubs = None, None\n        if self.is_lower_bound_current():\n            # self.lower shape (1,8,16,16)\n            out_lbs = self.lower.detach().cpu().numpy()\n            out_ubs = self.upper.detach().cpu().numpy()\n\n        # current layer weight (8,3,4,4)\n        this_layer_weight = v[1].detach().cpu().numpy()\n        # current layer bias (8,)\n        this_layer_bias = None\n        if self.has_bias:\n            this_layer_bias = v[2].detach().cpu().numpy()\n        weight_shape2, weight_shape3 = this_layer_weight.shape[2], this_layer_weight.shape[3]\n        padding0, padding1 = self.padding[0], self.padding[1]\n        stride0, stride1 = self.stride[0], self.stride[1]\n\n        new_layer_gurobi_vars = []\n        new_layer_gurobi_constrs = []\n\n        # precompute row and column index mappings\n\n        # compute row mapping: from current row to input rows\n        # vectorization of following code:\n        # for out_row_idx in range(this_layer_shape[2]):\n        #     ker_row_min, ker_row_max = 0, weight_shape2\n        #     in_row_idx_min = -padding0 + stride0 * out_row_idx\n        #     in_row_idx_max = in_row_idx_min + weight_shape2 - 1\n        #     if in_row_idx_min < 0:\n        #         ker_row_min = -in_row_idx_min\n        #     if in_row_idx_max >= pre_layer_shape[2]:\n        #         ker_row_max = 
ker_row_max - in_row_idx_max + pre_layer_shape[2] - 1\n        #     in_row_idx_min, in_row_idx_max = max(in_row_idx_min, 0), min(in_row_idx_max,\n        #                                                                  pre_layer_shape[2] - 1)\n        in_row_idx_mins = np.arange(this_layer_shape[2]) * stride0 - padding0\n        in_row_idx_maxs = in_row_idx_mins + weight_shape2 - 1\n        ker_row_mins = np.zeros(this_layer_shape[2], dtype=int)\n        ker_row_maxs = np.ones(this_layer_shape[2], dtype=int) * weight_shape2\n        ker_row_mins[in_row_idx_mins < 0] = -in_row_idx_mins[in_row_idx_mins < 0]\n        ker_row_maxs[in_row_idx_maxs >= pre_layer_shape[2]] = \\\n            ker_row_maxs[in_row_idx_maxs >= pre_layer_shape[2]] - in_row_idx_maxs[in_row_idx_maxs >= pre_layer_shape[2]]\\\n            + pre_layer_shape[2] - 1\n        in_row_idx_mins = np.maximum(in_row_idx_mins, 0)\n        in_row_idx_maxs = np.minimum(in_row_idx_maxs, pre_layer_shape[2] - 1)\n\n        # compute column mapping: from current column to input columns\n        # vectorization of following code:\n        # for out_col_idx in range(this_layer_shape[3]):\n        #     ker_col_min, ker_col_max = 0, weight_shape3\n        #     in_col_idx_min = -padding1 + stride1 * out_col_idx\n        #     in_col_idx_max = in_col_idx_min + weight_shape3 - 1\n        #     if in_col_idx_min < 0:\n        #         ker_col_min = -in_col_idx_min\n        #     if in_col_idx_max >= pre_layer_shape[3]:\n        #         ker_col_max = ker_col_max - in_col_idx_max + pre_layer_shape[3] - 1\n        #     in_col_idx_min, in_col_idx_max = max(in_col_idx_min, 0), min(in_col_idx_max,\n        #                                                                  pre_layer_shape[3] - 1)\n        in_col_idx_mins = np.arange(this_layer_shape[3]) * stride1 - padding1\n        in_col_idx_maxs = in_col_idx_mins + weight_shape3 - 1\n        ker_col_mins = np.zeros(this_layer_shape[3], dtype=int)\n        ker_col_maxs 
= np.ones(this_layer_shape[3], dtype=int) * weight_shape3\n        ker_col_mins[in_col_idx_mins < 0] = -in_col_idx_mins[in_col_idx_mins < 0]\n        ker_col_maxs[in_col_idx_maxs >= pre_layer_shape[3]] = \\\n            ker_col_maxs[in_col_idx_maxs >= pre_layer_shape[3]] - in_col_idx_maxs[in_col_idx_maxs >= pre_layer_shape[3]]\\\n            + pre_layer_shape[3] - 1\n        in_col_idx_mins = np.maximum(in_col_idx_mins, 0)\n        in_col_idx_maxs = np.minimum(in_col_idx_maxs, pre_layer_shape[3] - 1)\n\n        neuron_idx = 0\n        for out_chan_idx in range(this_layer_shape[1]):\n            out_chan_vars = []\n            for out_row_idx in range(this_layer_shape[2]):\n                out_row_vars = []\n\n                # get row index range from precomputed arrays\n                ker_row_min, ker_row_max = ker_row_mins[out_row_idx], ker_row_maxs[out_row_idx]\n                in_row_idx_min, in_row_idx_max = in_row_idx_mins[out_row_idx], in_row_idx_maxs[out_row_idx]\n\n                for out_col_idx in range(this_layer_shape[3]):\n\n                    # get col index range from precomputed arrays\n                    ker_col_min, ker_col_max = ker_col_mins[out_col_idx], ker_col_maxs[out_col_idx]\n                    in_col_idx_min, in_col_idx_max = in_col_idx_mins[out_col_idx], in_col_idx_maxs[out_col_idx]\n\n                    # init linear expression\n                    lin_expr = this_layer_bias[out_chan_idx] if self.has_bias else 0\n\n                    # init linear constraint LHS implied by the conv operation\n                    for in_chan_idx in range(this_layer_weight.shape[1]):\n\n                        coeffs = this_layer_weight[out_chan_idx, in_chan_idx, ker_row_min:ker_row_max, ker_col_min:ker_col_max].reshape(-1)\n                        gvars = gvars_array[in_chan_idx, in_row_idx_min:in_row_idx_max+1, in_col_idx_min:in_col_idx_max+1].reshape(-1)\n                        if solver_pkg == 'gurobi':\n                            lin_expr += 
grb.LinExpr(coeffs, gvars)\n                        else:\n                            for i in range(len(coeffs)):\n                                try:\n                                    lin_expr += coeffs[i] * gvars[i]\n                                except TypeError:\n                                    lin_expr += coeffs[i] * gvars[i].var\n\n                    # init potential lb and ub, which helps solver to finish faster\n                    out_lb = out_lbs[0, out_chan_idx, out_row_idx, out_col_idx] if out_lbs is not None else -float('inf')\n                    out_ub = out_ubs[0, out_chan_idx, out_row_idx, out_col_idx] if out_ubs is not None else float('inf')\n                    if out_ub - out_lb < EPS:\n                        # If the inferred lb and ub are too close, it could lead to floating point disagreement\n                        # between solver's inferred lb and ub constraints and the computed ones from ab-crown.\n                        # Such disagreement can lead to \"infeasible\" result from the solver for feasible problem.\n                        # To avoid so, we relax the box constraints.\n                        # This should not affect the solver's result correctness,\n                        # since the tighter lb and ub can be inferred by the solver.\n                        out_lb, out_ub = (out_lb + out_ub - EPS) / 2., (out_lb + out_ub + EPS) / 2.\n\n                    # add the output var and constraint\n                    var = model.addVar(lb=out_lb, ub=out_ub,\n                                            obj=0, vtype=grb.GRB.CONTINUOUS,\n                                            name=f'lay{self.name}_{neuron_idx}')\n                    model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')\n                    neuron_idx += 1\n\n                    out_row_vars.append(var)\n                out_chan_vars.append(out_row_vars)\n            new_layer_gurobi_vars.append(out_chan_vars)\n\n        
self.solver_vars = new_layer_gurobi_vars\n        model.update()\n\n    def interval_propagate(self, *v, C=None):\n        if self.is_input_perturbed(1):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n\n        norm = Interval.get_perturbation(v[0])\n        norm = norm[0]\n\n        h_L, h_U = v[0]\n        weight = v[1][0]\n        bias = v[2][0] if self.has_bias else None\n\n        if norm == torch.inf:\n            mid = (h_U + h_L) / 2.0\n            diff = (h_U - h_L) / 2.0\n            weight_abs = weight.abs()\n            deviation = self.F_conv(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups)\n        elif norm > 0:\n            norm, eps = Interval.get_perturbation(v[0])\n            # L2 norm, h_U and h_L are the same.\n            mid = h_U\n            # TODO: padding\n            assert not isinstance(eps, torch.Tensor) or eps.numel() == 1\n            deviation = torch.mul(weight, weight).sum((1, 2, 3)).sqrt() * eps\n            deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n        else: # Here we calculate the L0 norm IBP bound using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020]\n            norm, eps, ratio = Interval.get_perturbation(v[0])\n            mid = h_U\n            k = int(eps)\n            weight_sum = torch.sum(weight.abs(), 1)\n            deviation = torch.sum(torch.topk(weight_sum.view(weight_sum.shape[0], -1), k)[0], dim=1) * ratio\n\n            if self.has_bias:\n                center = self.F_conv(mid, weight, v[2][0], self.stride, self.padding, self.dilation, self.groups)\n            else:\n                center = self.F_conv(mid, weight, None, self.stride, self.padding, self.dilation, self.groups)\n\n            ss = center.shape\n            deviation = deviation.repeat(ss[2] * ss[3]).view(-1, ss[1]).t().view(ss[1], ss[2], ss[3])\n\n        center = self.F_conv(mid, weight, 
bias, self.stride, self.padding, self.dilation, self.groups)\n\n        upper = center + deviation\n        lower = center - deviation\n        return lower, upper\n\n    def bound_dynamic_forward(self, *x, max_dim=None, offset=0):\n        if self.is_input_perturbed(1) or self.is_input_perturbed(2):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n        weight = x[1].lb\n        bias = x[2].lb if self.has_bias else None\n        x = x[0]\n        w = x.lw\n        b = x.lb\n        shape = w.shape\n        shape_wconv = [shape[0] * shape[1]] + list(shape[2:])\n        def conv2d(input, weight, bias, stride, padding, dilation, groups):\n            \"\"\" There may be some CUDA error (illegal memory access) when\n            the batch size is too large. Thus split the input into several\n            batches when needed. \"\"\"\n            max_batch_size = 50\n            if input.device != torch.device('cpu') and input.shape[0] > max_batch_size:\n                ret = []\n                for i in range((input.shape[0] + max_batch_size - 1) // max_batch_size):\n                    ret.append(self.F_conv(\n                        input[i*max_batch_size:(i+1)*max_batch_size],\n                        weight, bias, stride, padding, dilation, groups))\n                return torch.cat(ret, dim=0)\n            else:\n                return self.F_conv(input, weight, bias, stride, padding, dilation, groups)\n        w_new = conv2d(\n            w.reshape(shape_wconv), weight, None, self.stride, self.padding,\n            self.dilation, self.groups)\n        w_new = w_new.reshape(shape[0], -1, *w_new.shape[1:])\n        b_new = conv2d(\n            b, weight, bias, self.stride, self.padding, self.dilation, self.groups)\n        return LinearBound(w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim)\n\n    def bound_forward(self, dim_in, *x):\n        if self.is_input_perturbed(1) or 
self.is_input_perturbed(2):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n\n        weight = x[1].lb\n        bias = x[2].lb if self.has_bias else None\n        x = x[0]\n\n        mid_w = (x.lw + x.uw) / 2\n        mid_b = (x.lb + x.ub) / 2\n        diff_w = (x.uw - x.lw) / 2\n        diff_b = (x.ub - x.lb) / 2\n        weight_abs = weight.abs()\n        shape = mid_w.shape\n        shape_wconv = [shape[0] * shape[1]] + list(shape[2:])\n        deviation_w = self.F_conv(\n            diff_w.reshape(shape_wconv), weight_abs, None,\n            self.stride, self.padding, self.dilation, self.groups)\n        deviation_b = self.F_conv(\n            diff_b, weight_abs, None,\n            self.stride, self.padding, self.dilation, self.groups)\n        center_w = self.F_conv(\n            mid_w.reshape(shape_wconv), weight, None,\n            self.stride, self.padding, self.dilation, self.groups)\n        center_b = self.F_conv(\n            mid_b, weight, bias,\n            self.stride, self.padding, self.dilation, self.groups)\n        deviation_w = deviation_w.reshape(shape[0], -1, *deviation_w.shape[1:])\n        center_w = center_w.reshape(shape[0], -1, *center_w.shape[1:])\n\n        return LinearBound(\n            lw = center_w - deviation_w,\n            lb = center_b - deviation_b,\n            uw = center_w + deviation_w,\n            ub = center_b + deviation_b)\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = Conv2dGrad(\n            self, self.inputs[1].param, self.stride, self.padding,\n            self.dilation, self.groups)\n        return [(node_grad, (grad_upstream,), [])]\n\n    def update_requires_input_bounds(self):\n        self._check_weight_perturbation()\n\n\nclass BoundConvTranspose(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        assert 
(attr['pads'][0] == attr['pads'][2])\n        assert (attr['pads'][1] == attr['pads'][3])\n\n        self.stride = attr['strides']\n        self.padding = [attr['pads'][0], attr['pads'][1]]\n        self.dilation = attr['dilations']\n        self.groups = attr['group']\n        self.output_padding = [attr.get('output_padding', [0, 0])[0], attr.get('output_padding', [0, 0])[1]]\n        assert len(attr['kernel_shape']) == 2  # 2d transposed convolution.\n        if len(inputs) == 3:\n            self.has_bias = True\n        else:\n            self.has_bias = False\n        self.mode = options.get(\"conv_mode\", \"matrix\")\n        assert self.output_padding == [0, 0]\n        assert self.dilation == [1, 1]\n        assert self.stride[0] == self.stride[1]\n        assert self.groups == 1\n\n        self.F_convtranspose = F.conv_transpose2d\n\n    def forward(self, *x):\n        # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias\n        bias = x[2] if self.has_bias else None\n        output = F.conv_transpose2d(x[0], x[1], bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding)\n        return output\n\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        if self.is_input_perturbed(1):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n\n        lA_y = uA_y = lA_bias = uA_bias = None\n        weight = x[1].lower\n        assert weight.size(-1) == weight.size(-2)\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None, 0\n            if type(last_A) is OneHotC:\n                # Conv layer does not support the OneHotC fast path. 
We have to create a dense matrix instead.\n                last_A = onehotc_to_dense(last_A, dtype=weight.dtype)\n\n            if type(last_A) == Tensor:\n                shape = last_A.size()\n                next_A = F.conv2d(last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None,\n                                            stride=self.stride, padding=self.padding, dilation=self.dilation,\n                                            groups=self.groups)\n                next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])\n                if self.has_bias:\n                    sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2)\n                else:\n                    sum_bias = 0\n                return next_A, sum_bias\n            elif type(last_A) == Patches:\n                # Here we build and propagate a Patch object with (patches, stride, padding)\n                assert type(last_A) == Patches\n                if last_A.identity == 0:\n                    patches = last_A.patches\n\n                    # FIXME: so far, assume there will be a relu layer in its input.\n\n                    if self.has_bias:\n                        # bias is x[2] (lower and upper are the same), and has shape (c,).\n                        # Patches either has [out_c, batch, out_h, out_w, c, h, w] or [unstable_size, batch, c, h, w].\n                        sum_bias = torch.einsum('sb...chw,c->sb...', patches, x[2].lower)\n                        # sum_bias has shape (out_c, batch, out_h, out_w) or (unstable_size, batch).\n                    else:\n                        sum_bias = 0\n\n                    flattened_patches = patches.reshape(-1, patches.size(-3), patches.size(-2), patches.size(-1))\n                    # Merge patches with this layer's weights. 
Weight must be flipped here; and if stride != 1, we must insert zeros in the input image.\n                    # For conv_transpose2d, the weight matrix is in the (in, out, k, k) shape.\n                    # pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=self.stride)\n                    # pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=last_A.inserted_zeros + 1)\n                    # Use padding in conv_transpose2d directly.\n                    pieces = F.conv_transpose2d(\n                            # Transpose because the weight has in_channel before out_channel.\n                            flattened_patches, insert_zeros(weight.transpose(0,1).flip(-1,-2), last_A.inserted_zeros))\n                    # New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w).\n                    pieces = pieces.view(*patches.shape[:-3], pieces.size(-3), pieces.size(-2), pieces.size(-1))\n\n                elif last_A.identity == 1:\n                    # New patches have size [out_c, batch, out_h, out_w, c, h, w] if it is not sparse.\n                    # New patches have size [unstable_size, batch, c, h, w] if it is sparse.\n                    if last_A.unstable_idx is not None:\n                        raise NotImplementedError()\n                    else:\n                        assert weight.size(0) == last_A.shape[0]\n                        pieces = weight.view(weight.size(0), 1, 1, 1, weight.size(1), weight.size(2), weight.size(3)).expand(-1, *last_A.shape[1:4], -1, -1, -1)\n                        # The bias (x[2].lower) has shape (out_c,); we need to make it (out_c, batch, out_h, out_w).\n                        # Here we should transpose sum_bias to set the batch dim to 1, aiming to keep it consistent with the matrix version\n                        sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4])\n                else:\n        
            raise NotImplementedError()\n                patches_padding = last_A.padding if last_A is not None else (0, 0, 0, 0)  # (left, right, top, bottom)\n                output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0)  # (left, right, top, bottom)\n                inserted_zeros = last_A.inserted_zeros\n                assert self.stride[0] == self.stride[1]\n\n                # Unify the shape to 4-tuple.\n                output_padding = unify_shape(output_padding)\n                patches_padding = unify_shape(patches_padding)\n                this_stride = unify_shape(self.stride)\n                this_padding = unify_shape(self.padding)\n\n                # Compute new padding. Due to the shape flip during merging, we need to check the stride/size on the dimension 3 - j.\n                # TODO: testing for asymmetric shapes.\n                padding = tuple(p * (inserted_zeros + 1) + (weight.size(3 - j//2) - 1) for j, p in enumerate(patches_padding))\n\n                # Compute new output padding\n                output_padding = tuple(p * (inserted_zeros + 1) + this_padding[j] for j, p in enumerate(output_padding))\n                # When we run insert_zeros, it's missing the rightmost column and the bottom row.\n                # padding = (padding[0], padding[1] + inserted_zeros, padding[2], padding[3] + inserted_zeros)\n\n                # If no transposed conv has been encountered so far, inserted_zeros is 0.\n                # When a transposed conv is encountered, the stride is multiplied onto it.\n                inserted_zeros = (inserted_zeros + 1) * this_stride[0] - 1\n\n                # FIXME: disabled patches_to_matrix because not all parameters are supported.\n                if inserted_zeros == 0 and not is_shape_used(output_padding) and pieces.shape[-1] > self.input_shape[-1]:  # the patches are too large; from now on, we will use matrix mode instead of patches mode.\n                    # This is our desired matrix: the input 
will be flattened to (batch_size, input_channel*input_x*input_y) and multiplied by this matrix.\n                    # After multiplication, the desired output is (batch_size, out_channel*output_x*output_y).\n                    # A_matrix has size (batch, out_c*out_h*out_w, in_c*in_h*in_w)\n                    assert inserted_zeros == 0\n                    A_matrix = patches_to_matrix(pieces, self.input_shape[1:], last_A.stride, padding, last_A.output_shape, last_A.unstable_idx)\n                    if isinstance(sum_bias, Tensor) and last_A.unstable_idx is None:\n                        sum_bias = sum_bias.transpose(0, 1)\n                        sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1)\n                    A_matrix = A_matrix.transpose(0,1)  # Spec dimension at the front.\n                    return A_matrix, sum_bias\n                new_patches = last_A.create_similar(\n                        pieces, padding=padding, inserted_zeros=inserted_zeros, output_padding=output_padding,\n                        input_shape=self.input_shape)\n                return new_patches, sum_bias\n            else:\n                raise NotImplementedError()\n\n        lA_x, lbias = _bound_oneside(last_lA)\n        uA_x, ubias = _bound_oneside(last_uA)\n        return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias\n\n    def interval_propagate(self, *v, C=None):\n        if self.is_input_perturbed(1):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n\n        norm = Interval.get_perturbation(v[0])\n        norm = norm[0]\n\n        h_L, h_U = v[0]\n        weight = v[1][0]\n        bias = v[2][0] if self.has_bias else None\n\n        if norm == torch.inf:\n            mid = (h_U + h_L) / 2.0\n            diff = (h_U - h_L) / 2.0\n            weight_abs = weight.abs()\n            deviation = F.conv_transpose2d(diff, weight_abs, None, stride=self.stride, 
padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding)\n        elif norm > 0:\n            raise NotImplementedError()\n            norm, eps = Interval.get_perturbation(v[0])\n            # L2 norm, h_U and h_L are the same.\n            mid = h_U\n            # TODO: padding\n            deviation = torch.mul(weight, weight).sum((1, 2, 3)).sqrt() * eps\n            deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n        else: # Here we calculate the L0 norm IBP bound using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020]\n            raise NotImplementedError()\n\n        center = F.conv_transpose2d(mid, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding)\n\n        upper = center + deviation\n        lower = center - deviation\n        return lower, upper\n\n    def bound_forward(self, dim_in, *x):\n        if self.is_input_perturbed(1) or self.is_input_perturbed(2):\n            raise NotImplementedError(\"Weight perturbation for convolution layers has not been implemented.\")\n\n        weight = x[1].lb\n        bias = x[2].lb if self.has_bias else None\n        x = x[0]\n\n        mid_w = (x.lw + x.uw) / 2\n        mid_b = (x.lb + x.ub) / 2\n        diff_w = (x.uw - x.lw) / 2\n        diff_b = (x.ub - x.lb) / 2\n        weight_abs = weight.abs()\n        shape = mid_w.shape\n        shape_wconv = [shape[0] * shape[1]] + list(shape[2:])\n        deviation_w = self.F_convtranspose(\n            diff_w.reshape(shape_wconv), weight_abs, None, output_padding=self.output_padding,\n            stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)\n        deviation_b = self.F_convtranspose(\n            diff_b, weight_abs, None, output_padding=self.output_padding,\n            stride=self.stride, padding=self.padding, dilation=self.dilation, 
groups=self.groups)\n        center_w = self.F_convtranspose(\n            mid_w.reshape(shape_wconv), weight, output_padding=self.output_padding,\n            stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)\n        center_b = self.F_convtranspose(\n            mid_b, weight, bias, output_padding=self.output_padding,\n            stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)\n        deviation_w = deviation_w.reshape(shape[0], -1, *deviation_w.shape[1:])\n        center_w = center_w.reshape(shape[0], -1, *center_w.shape[1:])\n\n        return LinearBound(\n            lw = center_w - deviation_w,\n            lb = center_b - deviation_b,\n            uw = center_w + deviation_w,\n            ub = center_b + deviation_b)\n\n\nclass BoundPad(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        if hasattr(attr, 'pads'):\n            self.padding = attr['pads'][2:4] + attr['pads'][6:8]\n        else:\n            self.padding = [0, 0, 0, 0]\n        self.value = attr.get('value', 0.0)\n        assert self.padding == [0, 0, 0, 0]\n\n    def forward(self, x, pad, value=0.0):\n        # TODO: padding for 3-D or more dimensional inputs.\n        assert x.ndim == 4\n        # x[1] should be [0,0,pad_top,pad_left,0,0,pad_bottom,pad_right]\n        assert pad[0] == pad[1] == pad[4] == pad[5] == 0\n        pad = [int(pad[3]), int(pad[7]), int(pad[2]), int(pad[6])]\n        final = F.pad(x, pad, value=value)\n        self.padding, self.value = pad, value\n        return final\n\n    def interval_propagate(self, *v):\n        l, u = zip(*v)\n        return Interval.make_interval(self.forward(*l), self.forward(*u), v[0])\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        # TODO: padding for 3-D or more dimensional inputs.\n        left, right, top, bottom = self.padding\n     
   def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            assert type(last_A) is Patches or last_A.ndim == 5\n            if type(last_A) is Patches:\n                if isinstance(last_A.padding, tuple):\n                    new_padding = (last_A.padding[0] + left, last_A.padding[1] + right, last_A.padding[2] + top, last_A.padding[3] + bottom)\n                else:\n                    new_padding = (last_A.padding + left, last_A.padding + right, last_A.padding + top, last_A.padding + bottom)\n                return last_A.create_similar(padding=new_padding)\n            else:\n                shape = last_A.size()\n                return last_A[:, :, :, top:(shape[3] - bottom), left:(shape[4] - right)]\n        last_lA = _bound_oneside(last_lA)\n        last_uA = _bound_oneside(last_uA)\n        return [(last_lA, last_uA), (None, None), (None, None)], 0, 0\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        # e.g., last layer input gurobi vars (3,32,32)\n        gvars_array = np.array(v[0])\n        # pre_layer_shape (1,3,32,32)\n        pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape\n        # this layer shape (1,3,35,35)\n        this_layer_shape = self.output_shape\n        # v1 = tensor([0, 0, 1, 1, 0, 0, 2, 2])\n        # [0,0,pad_top,pad_left,0,0,pad_bottom,pad_right]\n        # => [left, right, top, bottom]\n        padding = [int(v[1][3]), int(v[1][7]), int(v[1][2]), int(v[1][6])]\n        left, right, top, bottom = padding\n        # The height (dim 2) grows by top + bottom; the width (dim 3) grows by left + right.\n        assert pre_layer_shape[2] + top + bottom == this_layer_shape[2]\n        assert pre_layer_shape[3] + left + right == this_layer_shape[3]\n\n        new_layer_gurobi_vars = []\n        neuron_idx = 0\n        for out_chan_idx in range(this_layer_shape[1]):\n            out_chan_vars = []\n            for out_row_idx in range(this_layer_shape[2]):\n                out_row_vars = []\n
                row_pad = out_row_idx < top or out_row_idx >= this_layer_shape[2] - bottom\n                for out_col_idx in range(this_layer_shape[3]):\n                    col_pad = out_col_idx < left or out_col_idx >= this_layer_shape[3] - right\n                    if row_pad or col_pad:\n                        v = model.addVar(lb=0, ub=0,\n                                    obj=0, vtype=grb.GRB.CONTINUOUS,\n                                    name=f'pad{self.name}_{neuron_idx}')\n                    else:\n                        v = gvars_array[out_chan_idx, out_row_idx - top, out_col_idx - left]\n                    # print(out_chan_idx, out_row_idx, out_col_idx, row_pad, col_pad, v.LB, v.UB)\n                    neuron_idx += 1\n\n                    out_row_vars.append(v)\n                out_chan_vars.append(out_row_vars)\n            new_layer_gurobi_vars.append(out_chan_vars)\n\n        self.solver_vars = new_layer_gurobi_vars\n        model.update()\n\n\nclass Conv2dGrad(Module):\n    def __init__(self, fw_module, weight, stride, padding, dilation, groups):\n        super().__init__()\n        self.weight = weight\n        self.dilation = dilation\n        self.groups = groups\n        self.fw_module = fw_module\n\n        assert isinstance(stride, list) and stride[0] == stride[1]\n        assert isinstance(padding, list) and padding[0] == padding[1]\n        assert isinstance(dilation, list) and dilation[0] == dilation[1]\n        self.stride = stride[0]\n        self.padding = padding[0]\n        self.dilation = dilation[0]\n\n    def forward(self, grad_last):\n        output_padding0 = (\n            int(self.fw_module.input_shape[2])\n            - (int(self.fw_module.output_shape[2]) - 1) * self.stride\n            + 2 * self.padding - 1 - (int(self.weight.size()[2] - 1) * self.dilation))\n        output_padding1 = (\n            int(self.fw_module.input_shape[3])\n            - (int(self.fw_module.output_shape[3]) - 1) * self.stride\n            + 2 * 
self.padding - 1 - (int(self.weight.size()[3] - 1) * self.dilation))\n\n        return Conv2dGradOp.apply(\n            grad_last, self.weight, self.stride, self.padding, self.dilation,\n            self.groups, output_padding0, output_padding1)\n\n\nclass Conv2dGradOp(Function):\n    @staticmethod\n    def symbolic(g, x, w, stride, padding, dilation, groups,\n                 output_padding0, output_padding1):\n        return g.op(\n            'grad::Conv2d', x, w, stride_i=stride, padding_i=padding,\n            dilation_i=dilation, groups_i=groups,\n            output_padding0_i=output_padding0,\n            output_padding1_i=output_padding1).setType(x.type())\n\n    @staticmethod\n    def forward(\n            ctx, grad_last, w, stride, padding, dilation, groups, output_padding0,\n            output_padding1):\n        grad_shape = grad_last.shape\n        grad = F.conv_transpose2d(\n            grad_last.view(grad_shape[0], *grad_shape[1:]), w, None,\n            stride=stride, padding=padding, dilation=dilation,\n            groups=groups, output_padding=(output_padding0, output_padding1))\n\n        grad = grad.view((grad_shape[0], *grad.shape[1:]))\n        return grad\n\n\nclass BoundConv2dGrad(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.stride = attr['stride']\n        self.padding = attr['padding']\n        self.dilation = attr['dilation']\n        self.groups = attr['groups']\n        self.output_padding = [\n            attr.get('output_padding0', 0),\n            attr.get('output_padding1', 0)\n        ]\n        self.has_bias = len(inputs) == 3\n        self.mode = options.get('conv_mode', 'matrix')\n        self.patches_start = True\n\n    def forward(self, *x):\n        # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias\n        return F.conv_transpose2d(\n            x[0], x[1], None,\n            stride=self.stride, 
padding=self.padding, dilation=self.dilation,\n            groups=self.groups, output_padding=self.output_padding)\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        assert not self.is_input_perturbed(1)\n\n        lA_y = uA_y = lA_bias = uA_bias = None\n        weight = x[1].lower\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None, 0\n\n            if isinstance(last_A, torch.Tensor):\n                shape = last_A.size()\n                next_A = F.conv2d(\n                    last_A.reshape(shape[0] * shape[1], *shape[2:]),\n                    weight, None, stride=self.stride, padding=self.padding,\n                    dilation=self.dilation, groups=self.groups)\n                next_A = next_A.view(\n                    shape[0], shape[1], *next_A.shape[1:])\n                if self.has_bias:\n                    sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2)\n                else:\n                    sum_bias = 0\n                return next_A, sum_bias\n            elif isinstance(last_A, Patches):\n                # Here we build and propagate a Patches object with\n                # (patches, stride, padding)\n                assert self.stride == 1, 'The patches mode only supports stride = 1'\n                if last_A.identity == 1:\n                    # create an identity patch\n                    # [out_dim, batch, out_c, out_h, out_w, in_dim, in_c, in_h, in_w]\n                    patch_shape = last_A.shape\n                    if last_A.unstable_idx is not None:\n                        # FIXME Somehow the usage of unstable_idx seems to have\n                        # been changed, and the previous code is no longer working.\n                        raise NotImplementedError(\n                            'Sparse patches for '\n                            'BoundConv2dGrad is not supported yet.')\n                        output_shape = last_A.output_shape\n                    
    patches = torch.eye(\n                            patch_shape[0]).to(weight)\n                        patches = patches.view([\n                            patch_shape[0], 1, 1, 1, 1, patch_shape[0], 1, 1])\n                        # [out_dim, bsz, out_c, out_h, out_w, out_dim, in_c, in_h, in_w]\n                        patches = patches.expand([\n                            patch_shape[0], patch_shape[1], patch_shape[2],\n                            output_shape[2], output_shape[3],\n                            patch_shape[0], 1, 1])\n                        patches = patches.transpose(0, 1)\n                        patches = patches[\n                            :,torch.tensor(list(range(patch_shape[0]))),\n                            last_A.unstable_idx[0], last_A.unstable_idx[1],\n                            last_A.unstable_idx[2]]\n                        patches = patches.transpose(0, 1)\n                    else:\n                        # out_dim * out_c\n                        patches = torch.eye(patch_shape[0]).to(weight)\n                        patches = patches.view([\n                            patch_shape[0], 1, 1, 1, patch_shape[0], 1, 1])\n                        patches = patches.expand(patch_shape)\n                else:\n                    patches = last_A.patches\n\n                if self.has_bias:\n                    # bias is x[2] (lower and upper are the same), and has\n                    # shape (c,).\n                    # Patches either has\n                    # [out_dim, batch, out_c, out_h, out_w, out_dim, c, h, w]\n                    # or [unstable_size, batch, out_dim, c, h, w].\n                    # sum_bias has shape (out_dim, batch, out_c, out_h, out_w)\n                    # or (unstable_size, batch).\n                    sum_bias = torch.einsum(\n                        'sb...ochw,c->sb...', patches, x[2].lower)\n                else:\n                    sum_bias = 0\n\n                flattened_patches = 
patches.reshape(\n                    -1, patches.size(-3), patches.size(-2), patches.size(-1))\n                # Pad to the full size\n                pieces = F.conv2d(\n                    flattened_patches, weight, stride=self.stride,\n                    padding=weight.shape[2]-1)\n                # New patch size:\n                # (out_c, batch, out_h, out_w, c, h, w)\n                # or (unstable_size, batch, c, h, w).\n                pieces = pieces.view(\n                    *patches.shape[:-3], pieces.size(-3), pieces.size(-2),\n                    pieces.size(-1))\n\n                # (left, right, top, bottom)\n                padding = last_A.padding if last_A is not None else (0, 0, 0, 0)\n                stride = last_A.stride if last_A is not None else 1\n\n                if isinstance(padding, int):\n                    padding = padding + weight.shape[2] - 1\n                else:\n                    padding = tuple(p + weight.shape[2] - 1 for p in padding)\n\n                return Patches(\n                    pieces, stride, padding, pieces.shape,\n                    unstable_idx=last_A.unstable_idx,\n                    output_shape=last_A.output_shape), sum_bias\n            else:\n                raise NotImplementedError()\n\n        lA_x, lbias = _bound_oneside(last_lA)\n        uA_x, ubias = _bound_oneside(last_uA)\n        return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias\n\n    def interval_propagate(self, *v, C=None):\n        assert not self.is_input_perturbed(1)\n\n        norm = Interval.get_perturbation(v[0])[0]\n        h_L, h_U = v[0]\n\n        weight = v[1][0]\n        bias = v[2][0] if self.has_bias else None\n\n        if norm == torch.inf:\n            mid = (h_U + h_L) / 2.0\n            diff = (h_U - h_L) / 2.0\n            weight_abs = weight.abs()\n            deviation = F.conv_transpose2d(\n                diff, weight_abs, None, stride=self.stride,\n                padding=self.padding, 
dilation=self.dilation,\n                groups=self.groups, output_padding=self.output_padding)\n        else:\n            raise NotImplementedError\n        center = F.conv_transpose2d(\n            mid, weight, bias, stride=self.stride, padding=self.padding,\n            dilation=self.dilation, groups=self.groups,\n            output_padding=self.output_padding)\n        upper = center + deviation\n        lower = center - deviation\n        return lower, upper\n"
  },
  {
    "path": "auto_LiRPA/operators/cut_ops.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Cut operators\"\"\"\nfrom .base import *\nfrom .clampmult import multiply_by_A_signs\n\n\nclass CutModule():\n    # store under BoundedModule\n    def __init__(self, relu_nodes=[], general_beta=None, x_coeffs=None,\n                 active_cuts=None, cut_bias=None):\n        # all dict, storing cut parameters for each start node\n        # {start node name: (2 (lA, uA), spec (out_c, out_h, out_w), batch, num_cuts)}\n        self.general_beta = general_beta\n        # {start node name: (# active cut constraints,)}\n        self.active_cuts = active_cuts\n\n        # all dict with tensor, storing coeffs for each relu layer, no grad\n        # coeffs: {relu layer name: (num_cuts, flattened_nodes)}\n        self.relu_coeffs, self.arelu_coeffs, self.pre_coeffs = {}, {}, {}\n        for m in relu_nodes:\n            self.relu_coeffs[m.name] = self.arelu_coeffs[m.name] = self.pre_coeffs[m.name] = None\n\n        # single tensor, always the same, no grad\n        # bias: (num_cuts,)\n        self.cut_bias = cut_bias\n        # x_coeffs: (num_cuts, flattened input dims)\n        self.x_coeffs = x_coeffs\n\n    def use_patches(self, start_node):\n        # check if we are using patches mode for the start node\n        A = start_node.lA if start_node.lA is not None else start_node.uA\n        return type(A) is Patches\n\n    def select_active_general_beta(self, start_node, unstable_idx=None):\n        # if one constraint has nodes deeper than the start node, we do not count its effect for now\n        # self.general_beta[start_node.name]: (2(0 lower, 1 upper), spec (out_c, out_h, out_w/# fc nodes), batch, num_constrs)\n        # self.active_cuts[start_node.name]: a long() tensor with constraint indices that\n        # should be indexed on the current layer with the current start node\n        if 
self.general_beta[start_node.name].ndim == 4:\n            general_beta = self.general_beta[start_node.name][:, :, :, self.active_cuts[start_node.name]]\n        elif self.general_beta[start_node.name].ndim == 6:\n            general_beta = self.general_beta[start_node.name][:, :, :, :, :, self.active_cuts[start_node.name]]\n        else:\n            raise NotImplementedError(\"general beta shape not supported!\")\n        if unstable_idx is not None:\n            if self.use_patches(start_node):\n                general_beta = general_beta[:, unstable_idx[0], unstable_idx[1], unstable_idx[2], :, :]\n            else:\n                # matrix mode\n                if general_beta.ndim == 6:\n                    # conv layers general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)\n                    _, out_c, out_h, out_w, batch, num_constrs = general_beta.shape\n                    general_beta = general_beta.view(2, -1, batch, num_constrs)\n                else:\n                    # dense layers general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs)\n                    pass\n                general_beta = general_beta[:, unstable_idx]\n        else:\n            # unstable_idx is None\n            if general_beta.ndim == 6:\n                # flatten spec layer shape\n                _, out_c, out_h, out_w, batch, num_constrs = general_beta.shape\n                general_beta = general_beta.view(2, -1, batch, num_constrs)\n        return general_beta\n\n    def general_beta_coeffs_mm(self, unstable_spec_beta, coeffs, A, current_layer_shape):\n        if type(A) is Patches:\n            # lA, uA are patches, we have to unfold beta and coeffs to match lA and uA\n            # coeffs: (num_constrs, current_c, current_h, current_w)\n            # coeffs_unfolded: (num_constrs, out_h, out_w, in_c, H, W)\n            # current_layer_shape = x.lower.size()[1:]\n            coeffs_unfolded = inplace_unfold(coeffs.view(-1, 
*current_layer_shape), \\\n                                             kernel_size=A.patches.shape[-2:], padding=A.padding, stride=A.stride)\n            # unstable_coeffs_unfolded: (num_constrs, unstable, in_c, H, W)\n            # A.unstable_idx is the unstable idx for spec layer\n            unstable_coeffs_unfolded = coeffs_unfolded[:, A.unstable_idx[1], A.unstable_idx[2], :, :, :]\n            # A.unstable_idx: unstable index on out_c, out_h and out_w\n            # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)\n            # unstable_spec_beta: (2(0 lower, 1 upper), unstable, batch, num_constrs)\n            # unstable_spec_beta = general_beta[:, A.unstable_idx[0],\\\n            #             A.unstable_idx[1], A.unstable_idx[2], :, :]\n            # beta_mm_coeffs_unfolded: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)\n            beta_mm_coeffs = torch.einsum('sihj,jiabc->sihabc', unstable_spec_beta, unstable_coeffs_unfolded)\n        else:\n            # unstable_spec_beta: (2(0 lower, 1 upper), unstable, batch, num_constrs)\n            # coeffs: (num_constrs, current flattened layer nodes)\n            # beta_mm_coeffs: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)\n            beta_mm_coeffs = torch.einsum('sihj,ja->siha', unstable_spec_beta, coeffs)\n            assert beta_mm_coeffs[0].numel() == A.numel(), f\"the shape of beta is not initialized correctly! {beta_mm_coeffs[0].shape} v.s. 
{A.shape}\"\n        return beta_mm_coeffs.reshape(2, *A.shape)\n\n    def general_beta_coeffs_addmm_to_A(self, lA, uA, general_beta, coeffs, current_layer_shape):\n        A = lA if lA is not None else uA\n        # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)\n        # coeffs: (num_constrs, current_c, current_h, current_w)\n        # beta_mm_coeffs[0] shape is the same as A\n        # patches mode: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)\n        # not patches: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)\n        beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, coeffs, A, current_layer_shape)\n        assert beta_mm_coeffs[0].shape == A.shape\n        if type(A) is Patches:\n            # lA, uA are patches, we have to unfold beta and coeffs to match lA and uA\n            # lA_patches: (unstable, batch, in_c, H, W)\n            if lA is not None:\n                lA = Patches(lA.patches - beta_mm_coeffs[0], A.stride, A.padding, \\\n                             A.patches.shape, unstable_idx=A.unstable_idx, output_shape=A.output_shape)\n            if uA is not None:\n                uA = Patches(uA.patches + beta_mm_coeffs[1], A.stride, A.padding, \\\n                             A.patches.shape, unstable_idx=A.unstable_idx, output_shape=A.output_shape)\n        else:\n            # dense layers\n            if lA is not None:\n                lA = lA - beta_mm_coeffs[0]\n            if uA is not None:\n                uA = uA + beta_mm_coeffs[1]\n        return lA, uA\n\n    def patch_trick(self, start_node, layer_name, A, current_layer_shape):\n        ######## A problem with patches mode for cut constraint start ##########\n        # There are cases where a node appears in the constraint but is not selected by the patches for the output node.\n        # trick: only count the small patches that have all the split node coeffs[ci].sum() equal to coeffs_unfolded[ci][out_h, out_w, 
-1].sum()\n        # we should force these beta to be 0 to disable the effect of these constraints\n        # this only applies if the current layer uses patches mode; if the target layer is patches but the current layer is not, we should not use it!\n        assert type(A) is Patches, \"this trick fix only works for patches mode\"\n        # unstable_spec_beta stores the current propagation, self.general_beta[start_node.name] selected with active_cuts, spec unstable\n        coeffs = 0\n        if layer_name != \"input\":\n            if self.relu_coeffs[layer_name] is not None:\n                coeffs = coeffs + self.relu_coeffs[layer_name]\n            if self.arelu_coeffs[layer_name] is not None:\n                coeffs = coeffs + self.arelu_coeffs[layer_name]\n            if self.pre_coeffs[layer_name] is not None:\n                coeffs = coeffs + self.pre_coeffs[layer_name]\n        else:\n            if self.x_coeffs is not None:\n                coeffs = coeffs + self.x_coeffs\n        coeffs_unfolded = inplace_unfold(coeffs.view(-1, *current_layer_shape), \\\n                                         kernel_size=A.patches.shape[-2:], padding=A.padding, stride=A.stride)\n        num_constrs, out_h, out_w, in_c, H, W = coeffs_unfolded.shape\n        # make sure the selected small patch includes all the nonzero coeffs\n        ####### NOTE: This check could be costly #######\n        patch_mask_on_beta = (coeffs_unfolded.reshape(num_constrs, out_h, out_w, -1).abs().sum(-1) == \\\n                              coeffs.reshape(num_constrs, -1).abs().sum(-1).reshape(num_constrs, 1, 1))\n        # patch_mask_on_beta: (out_h, out_w, num_constrs)\n        patch_mask_on_beta = patch_mask_on_beta.permute(1, 2, 0)\n        # 2(lower, upper), out_c, out_h, out_w, batch, num_constrs\n        patch_mask_on_beta = patch_mask_on_beta.reshape(1, 1, out_h, out_w, 1, num_constrs)\n        self.general_beta[start_node.name].data = self.general_beta[start_node.name].data * 
patch_mask_on_beta\n\n    def relu_cut(self, start_node, layer_name, last_lA, last_uA, current_layer_shape, unstable_idx=None, batch_mask=None):\n        # propagate relu neuron in cut constraints through relu layer\n        # start_node.name in self.general_beta means there are intermediate betas that can optimize this start node separately\n        relu_coeffs = self.relu_coeffs[layer_name]\n        active_cuts = self.active_cuts[start_node.name]\n        # active_cuts.size(0) == 0 means all constraints containing this layer have deep layer nodes\n        if relu_coeffs is None or active_cuts.size(0) == 0:\n            # do nothing\n            return last_lA, last_uA\n        assert start_node.name in self.general_beta\n        # select current relu layer general beta\n        general_beta = self.select_active_general_beta(start_node, unstable_idx)\n        relu_coeffs = relu_coeffs[active_cuts]\n        if batch_mask is not None:\n            general_beta = general_beta[:, :, batch_mask]\n        last_lA, last_uA = self.general_beta_coeffs_addmm_to_A(last_lA, last_uA, general_beta,\n                                                               relu_coeffs, current_layer_shape)\n        return last_lA, last_uA\n\n    def pre_cut(self, start_node, layer_name, lA, uA, current_layer_shape, unstable_idx=None, batch_mask=None):\n        # propagate prerelu neuron in cut constraints through relu layer\n        # start_node.name in self.general_beta means there are intermediate betas that can optimize this start node separately\n        pre_coeffs = self.pre_coeffs[layer_name]\n        active_cuts = self.active_cuts[start_node.name]\n        # active_cuts.size(0) == 0 means all constraints containing this layer have deep layer nodes\n        if pre_coeffs is None or active_cuts.size(0) == 0:\n            # do nothing\n            return lA, uA\n        general_beta = self.select_active_general_beta(start_node, unstable_idx)\n        pre_coeffs = 
pre_coeffs[active_cuts]\n        if batch_mask is not None:\n            general_beta = general_beta[:, :, batch_mask]\n        lA, uA = self.general_beta_coeffs_addmm_to_A(lA, uA, general_beta, pre_coeffs, current_layer_shape)\n        return lA, uA\n\n\n    @staticmethod\n    @torch.jit.script\n    def jit_arelu_lA(last_lA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, upper_d, I_z1, I_z0):\n        nu_hat_pos = last_lA.clamp(max=0.).abs()\n        # gamma = (-lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[0]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)\n        pi = (upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[0]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)\n        pi = torch.min(pi, nu_hat_pos)#, torch.min(gamma, nu_hat_pos)\n        pi = pi.clamp(min=0.)#, gamma.clamp(min=0.)\n        pi = nu_hat_pos * I_z1 + pi * (~I_z1 * ~I_z0)\n        new_upper_d = pi / (nu_hat_pos + 1e-10)\n        # need to customize the upper bound slope and lbias for (1) unstable relus and\n        # (2) relus that are used with upper boundary relaxation\n        # original upper bound slope is u/(u-l) also equal to pi/(pi+gamma) if no beta_mm_coeffs[0]\n        # now the upper bound slope should be pi/(pi+gamma) updated with beta_mm_coeffs[0]\n        unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(last_lA < 0)\n        # conv layer:\n        # upper_d: 1, batch, current_c, current_w, current_h\n        # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current_c, current_w, current_h\n        # dense layer:\n        # upper_d: 1, batch, current flattened nodes\n        # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current flattened nodes\n        # we may need a new mask to filter out the unstable nodes that are not in the current layer\n        new_upper_d = (new_upper_d * unstable_upper_bound_index.to(lower.dtype) +\n                      upper_d * (1. 
- unstable_upper_bound_index.to(lower.dtype)))\n        return nu_hat_pos, pi, new_upper_d, unstable_upper_bound_index\n\n    @staticmethod\n    @torch.jit.script\n    def jit_arelu_lbias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, lower, upper, lbias, pi, I_z1, I_z0):\n        # if no unstable, following bias should always be 0\n        if unstable_or_cut_index.sum() > 0:\n            # update lbias with new form, only contribued by unstable relus\n            uC = -upper.unsqueeze(0) * nu_hat_pos\n            lC = -lower.unsqueeze(0) * nu_hat_pos\n            # lbias: (spec unstable, batch, current flattened nodes) same as lA\n            lbias = (pi * lower.unsqueeze(0))\n\n            # previous implementation\n            # uC_mask = (beta_mm_coeffs[0] <= uC).to(lbias)\n            # lC_mask = (beta_mm_coeffs[0] >= lC).to(lbias)\n\n            # complete implementation\n            uC_mask = ((beta_mm_coeffs[0] <= uC) | I_z0).to(lbias)\n            lC_mask = ((beta_mm_coeffs[0] >= lC) | I_z1).to(lbias)\n            default_mask = ((1-uC_mask) * (1-lC_mask)).to(lbias)\n            lbias = - beta_mm_coeffs[0].to(lbias) * lC_mask + lbias * default_mask\n\n            # lbias[beta_mm_coeffs[0] <= uC] = 0.\n            # lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias)\n\n            # final lbias: (spec unstable, batch)\n            lbias = (lbias * unstable_or_cut_index.unsqueeze(0).to(lower.dtype)).view(lbias.shape[0], lbias.shape[1], -1).sum(-1)\n        return lbias\n\n    @staticmethod\n    @torch.jit.script\n    def jit_arelu_uA(last_uA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, upper_d, I_z1, I_z0):\n        nu_hat_pos = (-last_uA).clamp(max=0.).abs()\n        # gamma = (- lower.unsqueeze(0) * nu_hat_pos - beta_mm_coeffs[1]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)\n        pi = (upper.unsqueeze(0) * nu_hat_pos + beta_mm_coeffs[1]) / (upper.unsqueeze(0) - lower.unsqueeze(0) + 1e-10)\n        
pi = pi.clamp(min=0.)\n        pi = torch.min(pi, nu_hat_pos)\n        pi = pi * I_z1 + nu_hat_pos * (~I_z1 * ~I_z0)\n        new_upper_d = pi / (nu_hat_pos + 1e-10)\n\n        # assert ((gamma + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, \"pi+gamma should always be the same as nu_hat_pos\"\n\n        # unstable_or_cut_index = self.I.logical_or(self.arelu_coeffs.abs().sum(0).view(self.I.shape) != 0)\n        unstable_upper_bound_index = unstable_or_cut_index.unsqueeze(0).logical_and(-last_uA < 0)\n        new_upper_d = new_upper_d * unstable_upper_bound_index.to(lower.dtype) + \\\n                      upper_d * (1. - unstable_upper_bound_index.to(lower.dtype))\n        return nu_hat_pos, pi, new_upper_d, unstable_upper_bound_index\n\n    @staticmethod\n    @torch.jit.script\n    def jit_arelu_ubias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, lower, upper, ubias, pi, I_z1, I_z0):\n        if unstable_or_cut_index.sum() > 0:\n            uC = -upper.unsqueeze(0) * nu_hat_pos\n            lC = -lower.unsqueeze(0) * nu_hat_pos\n            ubias = -(pi * lower.unsqueeze(0))\n\n            # uC_mask = (beta_mm_coeffs[1] <= uC).to(ubias)\n            # lC_mask = (beta_mm_coeffs[1] >= lC).to(ubias)\n            uC_mask = ((beta_mm_coeffs[1] <= uC) | I_z0).to(ubias)\n            lC_mask = ((beta_mm_coeffs[1] >= lC) | I_z1).to(ubias)\n\n            default_mask = ((1-uC_mask) * (1-lC_mask)).to(ubias)\n            ubias = beta_mm_coeffs[1].to(ubias) * lC_mask + ubias * default_mask\n\n            # ubias[beta_mm_coeffs[1] <= uC] = 0.\n            # ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias)\n\n            ubias = (ubias * unstable_or_cut_index.unsqueeze(0).to(lower.dtype)).view(ubias.shape[0], ubias.shape[1], -1).sum(-1)\n        return ubias\n\n\n    def arelu_cut(self, start_node, layer_name, last_lA, last_uA, lower_d, upper_d,\n                  lower_b, upper_b, lb_lower_d, ub_lower_d, relu_indicators, 
x, patch_size,\n                  current_layer_shape, unstable_idx=None, batch_mask=None):\n        \"\"\"\n        We want to calculate the pi and gamma for the lower bound of the next layer.\n        To make GCP-CROWN complete, we have to consider the case when z is a constant.\n        We now discuss the cases z = 0, z = 1 (constant), and 0 < z < 1 (variable).\n            lbias is h(beta) in the paper.\n            upper_d is the upper bound slope of the current layer.\n        1. z = 0 -> pi = 0, gamma = nu_hat_pos, tau = 0, mu = (alpha) * nu_hat_neg\n            lbias = 0.\n            upper_d = pi / (pi + gamma) = 0.\n        2. z = 1 -> pi = nu_hat_pos, gamma = 0, tau = alpha * nu_hat_neg, mu = 0\n            lbias = - beta_mm_coeffs[0].\n            upper_d = pi / (pi + gamma) = 1.\n        3. 0 < z < 1. We do the regular calculation using the closed-form solution.\n            lbias = pi * lower, if -upper * nu_hat_pos <= beta_mm_coeffs[0] <= -lower * nu_hat_pos\n            lbias = 0, if beta_mm_coeffs[0] <= -upper * nu_hat_pos\n            lbias = -beta_mm_coeffs[0], if beta_mm_coeffs[0] >= -lower * nu_hat_pos\n            upper_d = pi / (nu_hat_pos).\n            where\n                pi = (upper * nu_hat_pos + beta_mm_coeffs[0]) / (upper - lower),\n                pi = min(pi, nu_hat_pos),\n                pi = max(pi, 0),\n                gamma = (-lower * nu_hat_pos - beta_mm_coeffs[0]) / (upper - lower),\n                gamma = min(gamma, nu_hat_pos),\n                gamma = max(gamma, 0).\n        Thus, we have the following implementation.\n        if z = 0:\n            pi = 0.\n        if z = 1:\n            pi = nu_hat_pos.\n        Otherwise:\n            if -upper * nu_hat_pos <= beta_mm_coeffs[0] <= -lower * nu_hat_pos:\n                pi = (upper * nu_hat_pos + beta_mm_coeffs[0]) / (upper - lower),\n                pi = min(pi, nu_hat_pos),\n                pi = max(pi, 0),\n                lbias = pi * lower,\n                
upper_d = pi / (nu_hat_pos).\n            if beta_mm_coeffs[0] <= -upper * nu_hat_pos:\n                lbias = 0.\n            if beta_mm_coeffs[0] >= -lower * nu_hat_pos:\n                lbias = -beta_mm_coeffs[0].\n        \"\"\"\n        # propagate integer var of relu neuron (arelu) in cut constraints through relu layer\n        # I[0]. unstable neuron mask.\n        # I[1]. previous unstable now split on z = 1.\n        # I[2]. previous unstable now split on z = 0.\n        unstable_neurons_mask, z_split_to_1_mask, z_split_to_0_mask = relu_indicators\n        arelu_coeffs = self.arelu_coeffs[layer_name]\n        active_cuts = self.active_cuts[start_node.name]\n        # active_cuts.size(0) == 0 means all constraints containing this layer have deep layer nodes\n        if arelu_coeffs is None or active_cuts.size(0) == 0:\n            # do regular propagation without cut\n            uA, ubias = _bound_oneside(last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size)\n            lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, lower_b, upper_b, start_node, patch_size)\n            return lA, uA, lbias, ubias\n\n        # general_beta: (2(0 lower, 1 upper), spec (out_c, out_h, out_w), batch, num_constrs)\n        general_beta = self.select_active_general_beta(start_node, unstable_idx)\n        # arelu_coeffs: (num_constrs, flattened current layer nodes)\n        arelu_coeffs = arelu_coeffs[active_cuts]\n        if batch_mask is not None:\n            general_beta = general_beta[:, :, batch_mask]\n        A = last_lA if last_lA is not None else last_uA\n        # beta_mm_coeffs[0] shape is the same as A\n        # patches mode: (2(0 lower, 1 upper), unstable, batch, in_c, H, W)\n        # not patches: (2(0 lower, 1 upper), unstable, batch, current flattened layer nodes)\n        beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, arelu_coeffs, A, current_layer_shape)\n  
      # unstable_this_layer = torch.logical_and(x.lower < 0, x.upper > 0).unsqueeze(0)\n        # relu_indicator is the unstable index in this relu layer: (batch, *layer shape)\n        # if there is node in cut constraint that is stable, also need to count its effect\n        # self.arelu_coeffs: (num_constrs, flattened current layer)\n        # self.arelu_coeffs do not have a batch dimension - only one cut can be applied to all batch elements.\n        # We will handle the neurons which are unstable or those have cut constraints below, thus creating the mask.\n        unstable_or_cut_index = unstable_neurons_mask.logical_or(arelu_coeffs.abs().sum(0).view(unstable_neurons_mask[0:1].shape) != 0)\n        # Shape of unstable_or_cut_index is (batch, num_neurons). It is a binary mask.\n\n        if type(A) is Patches:\n            # patches mode, conv layer only\n            # x.lower (always regular shape): batch, current_c, current_h, current_w\n            # x_lower_unfold: unstable, batch, in_C, H, W (same as patches last_lA)\n            x_lower_unfold = _maybe_unfold(x.lower.unsqueeze(0), A)\n            x_upper_unfold = _maybe_unfold(x.upper.unsqueeze(0), A)\n            # first minus upper and lower and then unfold to patch size will save memory\n            x_upper_minus_lower_unfold = _maybe_unfold((x.upper - x.lower).unsqueeze(0), A)\n            ####### be careful with the unstable_this_layer and unstable_idx #######\n            # unstable_this_layer is the unstable index in this layer\n            # unstable_idx is the unstable index in spec layer\n            # unstable_this_layer: spec unstable, batch, in_C, H, W (same as patches last_lA)\n            # unstable_this_layer = torch.logical_and(x_lower_unfold < 0, x_upper_unfold > 0)\n            # unstable_this_layer = _maybe_unfold(self.I.unsqueeze(0), A)\n            unstable_or_cut_index = _maybe_unfold(unstable_or_cut_index.unsqueeze(0), A)\n            if last_lA is not None:\n                
assert beta_mm_coeffs[0].shape == last_lA.shape, f\"{beta_mm_coeffs[0].shape} != {last_lA.shape}\"\n                # last_lA.patches, nu_hat_pos, gamma, pi: (unstable, batch, in_c, H, W)\n                nu_hat_pos = last_lA.patches.clamp(max=0.).abs()\n                # gamma = (-x_lower_unfold * nu_hat_pos - beta_mm_coeffs[0]) / (x_upper_minus_lower_unfold.clamp(min=1e-10))\n                pi = (x_upper_unfold * nu_hat_pos + beta_mm_coeffs[0]) / (x_upper_minus_lower_unfold.clamp(min=1e-10))\n                pi = torch.min(pi, nu_hat_pos).clamp(min=0.)\n                pi = nu_hat_pos * z_split_to_1_mask + pi * (~z_split_to_1_mask * ~z_split_to_0_mask)\n                new_upper_d = pi / (nu_hat_pos + 1e-10)\n\n                # assert ((gamma + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, \"pi+gamma should always be the same as nu_hat_pos\"\n\n                # unstable_upper_bound_index: spec unstable, batch, in_C, H, W (same as patches last_lA)\n                unstable_upper_bound_index = unstable_or_cut_index.logical_and(last_lA.patches < 0)\n                # upper_d: (spec unstable, 1, in_C, H, W) (unfolded shape, same as patches last_lA)\n                new_upper_d = new_upper_d * unstable_upper_bound_index.to(x_lower_unfold.dtype) + \\\n                              upper_d * (1. 
- unstable_upper_bound_index.to(x_lower_unfold.dtype))\n\n                if last_uA is None: uA, ubias = None, 0.\n                # lbias: unstable, batch\n                # lA: unstable, batch, in_C, H, W (same as patches last_lA)\n                lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, new_upper_d, lower_b, upper_b, start_node, patch_size)\n\n                # if general_beta[0].sum()!=0: import pdb; pdb.set_trace()\n                # there is any unstable relus in this layer\n                if unstable_or_cut_index.sum() > 0:\n                    uC = -x_upper_unfold * nu_hat_pos\n                    lC = -x_lower_unfold * nu_hat_pos\n                    lbias = (pi * x_lower_unfold)\n                    # lbias[beta_mm_coeffs[0] <= uC] = 0.\n                    # lbias[beta_mm_coeffs[0] >= lC] = -beta_mm_coeffs[0][beta_mm_coeffs[0] >= lC].to(lbias)\n                    lbias[(beta_mm_coeffs[0] <= uC)| z_split_to_0_mask] = 0.\n                    lbias[(beta_mm_coeffs[0] >= lC)| z_split_to_1_mask] = -beta_mm_coeffs[0][(beta_mm_coeffs[0] >= lC)| z_split_to_1_mask].to(lbias)\n                    # lbias: unstable, batch, in_C, H, W (same as patches last_lA) => lbias: (unstable, batch)\n                    lbias = (lbias * unstable_or_cut_index.to(x_lower_unfold.dtype)).view(lbias.shape[0], lbias.shape[1], -1).sum(-1)\n\n            if last_uA is not None:\n                # get the upper bound\n                nu_hat_pos = (-last_uA.patches).clamp(max=0.).abs()\n                # gamma = (-x_lower_unfold * nu_hat_pos - beta_mm_coeffs[1]) / (x_upper_minus_lower_unfold + 1e-10)\n                pi = (x_upper_unfold * nu_hat_pos + beta_mm_coeffs[1]) / (x_upper_minus_lower_unfold + 1e-10)\n                pi = torch.min(pi, nu_hat_pos).clamp(min=0.)\n                pi = nu_hat_pos * z_split_to_1_mask + pi * (~z_split_to_1_mask * ~z_split_to_0_mask)\n                new_upper_d = pi / (nu_hat_pos + 1e-10)\n\n                # 
assert ((gamma + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, \"pi+gamma should always be the same as nu_hat_pos\"\n\n                unstable_upper_bound_index = unstable_or_cut_index.logical_and((-last_uA.patches) < 0)\n                new_upper_d = new_upper_d * unstable_upper_bound_index.to(x_lower_unfold.dtype) + \\\n                              upper_d * (1. - unstable_upper_bound_index.to(x_lower_unfold.dtype))\n\n                uA, ubias = _bound_oneside(last_uA, new_upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size)\n                if last_lA is None: lA, lbias = None, 0.\n\n                if unstable_or_cut_index.sum() > 0:\n                    uC = -x_upper_unfold * nu_hat_pos\n                    lC = -x_lower_unfold * nu_hat_pos\n                    ubias = -(pi * x_lower_unfold)\n                    # ubias[beta_mm_coeffs[1] <= uC] = 0.\n                    # ubias[beta_mm_coeffs[1] >= lC] = beta_mm_coeffs[1][beta_mm_coeffs[1] >= lC].to(ubias)\n                    ubias[(beta_mm_coeffs[1] <= uC) | z_split_to_0_mask] = 0.\n                    ubias[(beta_mm_coeffs[1] >= lC) | z_split_to_1_mask] = beta_mm_coeffs[1][(beta_mm_coeffs[1] >= lC) | z_split_to_1_mask].to(ubias)\n                    # ubias: unstable, batch, in_C, H, W (same as patches last_uA) => ubias: (unstable, batch)\n                    ubias = (ubias * unstable_or_cut_index.to(x_lower_unfold.dtype)).view(ubias.shape[0], ubias.shape[1], -1).sum(-1)\n        else:\n            # dense\n            if last_lA is not None:\n                # #####################\n                # # C is nu_hat_pos\n                # # last_lA: (spec unstable, batch, current flattened nodes (current_c*current_h*current_w))\n                # nu_hat_pos = last_lA.clamp(max=0.).abs()\n                # # pi, gamma: spec_unstable, batch, current layer shape (same as last_lA)\n\n                # # need to customize the upper bound slope and lbias 
for (1) unstable relus and\n                # # (2) relus that are used with upper boundary relaxation\n                # # original upper bound slope is u/(u-l) also equal to pi/(pi+gamma) if no beta_mm_coeffs[0]\n                # # now the upper bound slope should be pi/(pi+gamma) updated with beta_mm_coeffs[0]\n\n                # # conv layer:\n                # # upper_d: 1, batch, current_c, current_w, current_h\n                # # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current_c, current_w, current_h\n                # # dense layer:\n                # # upper_d: 1, batch, current flattened nodes\n                # # unstable_upper_bound_index, new_upper_d: spec unstable, batch, current flattened nodes\n\n                nu_hat_pos, pi, new_upper_d, unstable_upper_bound_index = self.jit_arelu_lA(last_lA, x.lower, x.upper, beta_mm_coeffs, unstable_or_cut_index, upper_d, z_split_to_1_mask, z_split_to_0_mask)\n\n                if last_uA is None: uA, ubias = None, 0.\n                lA, lbias = _bound_oneside(last_lA, lb_lower_d if lower_d is None else lower_d, new_upper_d, lower_b, upper_b, start_node, patch_size)\n                lbias = self.jit_arelu_lbias(unstable_or_cut_index, nu_hat_pos, beta_mm_coeffs, x.lower, x.upper, lbias, pi, z_split_to_1_mask, z_split_to_0_mask)\n\n            if last_uA is not None:\n                # # C is nu_hat_pos\n                nu_hat_pos, pi, new_upper_d, unstable_upper_bound_index = self.jit_arelu_uA(last_uA, x.lower, x.upper, beta_mm_coeffs, unstable_or_cut_index, upper_d, z_split_to_1_mask, z_split_to_0_mask)\n\n                # one can test uA by optimizing -obj, which should give the same obj value\n                uA, ubias = _bound_oneside(last_uA, new_upper_d, ub_lower_d if lower_d is None else lower_d, upper_b, lower_b, start_node, patch_size)\n                if last_lA is None: lA, lbias = None, 0.\n                ubias = self.jit_arelu_ubias(unstable_or_cut_index, nu_hat_pos, 
beta_mm_coeffs, x.lower, x.upper, ubias, pi, z_split_to_1_mask, z_split_to_0_mask)\n\n        return lA, uA, lbias, ubias\n\n    def input_cut(self, start_node, lA, uA, current_layer_shape, unstable_idx=None, batch_mask=None):\n        # propagate input neuron in cut constraints through relu layer\n        active_cuts = self.active_cuts[start_node.name]\n        if self.x_coeffs is None or active_cuts.size(0) == 0:\n            return lA, uA\n\n        if type(lA) is Patches:\n            A = lA if lA is not None else uA\n            self.patch_trick(start_node, \"input\", A, current_layer_shape)\n\n        general_beta = self.select_active_general_beta(start_node, unstable_idx)\n        x_coeffs = self.x_coeffs[active_cuts]\n        if batch_mask is not None:\n            general_beta = general_beta[:, :, batch_mask]\n        # general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs)\n        # x_coeffs: (num_constrs, flattened input dims)\n        # beta_bias: (2(0 lower, 1 upper), batch, spec)\n        lA, uA = self.general_beta_coeffs_addmm_to_A(lA, uA, general_beta, x_coeffs, current_layer_shape)\n        return lA, uA\n\n    def bias_cut(self, start_node, lb, ub, unstable_idx=None, batch_mask=None):\n        active_cuts = self.active_cuts[start_node.name]\n        if self.cut_bias is None or active_cuts.size(0) == 0:\n            return lb, ub\n        bias_coeffs = self.cut_bias[active_cuts]\n        general_beta = self.select_active_general_beta(start_node, unstable_idx)\n        if batch_mask is not None:\n            general_beta = general_beta[:, :, batch_mask]\n        # add bias for the bias term of general cut\n        # general_beta: (2(0 lower, 1 upper), spec, batch, num_constrs)\n        # bias_coeffs: (num_constrs,)\n        # beta_bias: (2(0 lower, 1 upper), batch, spec)\n        beta_bias = torch.einsum('sihj,j->shi', general_beta.to(lb.dtype), bias_coeffs.to(lb.dtype))\n        lb = lb + beta_bias[0] if lb is not None else None\n        ub 
= ub - beta_bias[1] if ub is not None else None\n        return lb, ub\n\n\n# Choose upper or lower bounds based on the sign of last_A\n# this is a copy from activation.py\ndef _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg, start_node, patch_size):\n    if last_A is None:\n        return None, 0\n    if type(last_A) == Tensor:\n        A, bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg, contiguous=True)\n        return A, bias\n    elif type(last_A) == Patches:\n        # if last_A is not an identity matrix\n        assert last_A.identity == 0\n        if last_A.identity == 0:\n            # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension.\n            # or (unstable_size, batch_size, in_c, H, W) when it is sparse.\n            patches = last_A.patches\n            patches_shape = patches.shape\n            if len(patches_shape) == 6:\n                patches = patches.view(*patches_shape[:2], -1, *patches_shape[-2:])\n                if d_pos is not None:\n                    d_pos = d_pos.view(*patches_shape[:2], -1, *patches_shape[-2:])\n                if d_neg is not None:\n                    d_neg = d_neg.view(*patches_shape[:2], -1, *patches_shape[-2:])\n                if b_pos is not None:\n                    b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:])\n                if b_neg is not None:\n                    b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:])\n            A_prod, bias = multiply_by_A_signs(patches, d_pos, d_neg, b_pos, b_neg)\n            # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse.\n            # For sparse patches the return bias size is (unstable_size, batch).\n            # For regular patches the return bias size is (spec, batch, out_h, out_w).\n            if len(patches_shape) == 6:\n                A_prod = A_prod.view(*patches_shape)\n            # Save 
the patch size, which will be used in init_slope() to determine the number of optimizable parameters.\n            if start_node is not None:\n                if last_A.unstable_idx is not None:\n                    # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).\n                    patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]\n                else:\n                    # Regular patches.\n                    patch_size[start_node.name] = A_prod.size()\n            return Patches(A_prod, last_A.stride, last_A.padding, A_prod.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape), bias\n\n\n# In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return.\n# this is a copy from activation.py\ndef _maybe_unfold(d_tensor, last_A):\n    # d_tensor (out_c, current_c, current_h, current_w): out_c shared the same alpha for spec layer\n    if d_tensor is None:\n        return None\n    # if mode == \"matrix\" or d_tensor is None or last_A is None:\n    if type(last_A) is not Patches or d_tensor is None or last_A is None:\n        return d_tensor\n    # Input are slopes with shape (spec, batch, input_c, input_h, input_w)\n    # Here spec is the same as out_c.\n    # assert d_tensor.ndim == 5\n    origin_d_shape = d_tensor.shape\n    if d_tensor.ndim == 6:\n        d_tensor = d_tensor.view(*origin_d_shape[:2], -1, *origin_d_shape[-2:])\n    d_shape = d_tensor.size()\n    # Reshape to 4-D tensor to unfold.\n    d_tensor = d_tensor.view(-1, *d_tensor.shape[-3:])\n    # unfold the slope matrix as patches. 
Patch shape is (spec * batch, out_h, out_w, in_c, H, W).\n    d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding)\n    # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c.\n    d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:])\n    if last_A.unstable_idx is not None:\n        if d_unfolded_r.size(0) == 1:\n            if len(last_A.unstable_idx) == 3:\n                # Broadcast the spec shape, so we only need to select the rest of the dimensions.\n                # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).\n                d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)\n                d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]\n            elif len(last_A.unstable_idx) == 4:\n                # [spec, batch, output_h, output_w, input_c, H, W]\n                # to [output_h, output_w, batch, in_c, H, W]\n                d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)\n                d_unfolded_r = d_unfolded_r[last_A.unstable_idx[2], last_A.unstable_idx[3]]\n            else:\n                raise NotImplementedError()\n            # output shape: (unstable_size, batch, in_c, H, W).\n        else:\n            d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n        # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).\n    # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W).\n    if d_unfolded_r.ndim != last_A.patches.ndim:\n        d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4)\n    return d_unfolded_r\n"
  },
  {
    "path": "auto_LiRPA/operators/dropout.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom .base import *\n\nclass BoundDropout(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        if 'ratio' in attr:\n            self.ratio = attr['ratio']\n            self.dynamic = False\n        else:\n            self.ratio = None\n            self.dynamic = True\n        self.clear()\n\n    def clear(self):\n        self.mask = None\n\n    def forward(self, *inputs):\n        x = inputs[0]\n        if not self.training:\n            return x\n        if self.dynamic:\n            # Inputs: data, ratio (optional), training_mode (optional).\n            # We assume the ratio is always present in the inputs.\n            # We ignore training_mode, but will use self.training, which can be\n            # changed after BoundedModule is built.\n            assert (inputs[1].dtype == torch.float32 or\n                    inputs[1].dtype == torch.float64)\n            self.ratio = inputs[1]\n        if self.ratio >= 1:\n            raise ValueError('Ratio in dropout should be less than 1')\n        # Use the device of x: self.ratio may be a Python float when it comes\n        # from a static attribute instead of a tensor input.\n        self.mask = torch.rand(x.shape, device=x.device) > self.ratio\n        return x * self.mask / (1 - self.ratio)\n\n    def _check_forward(self):\n        \"\"\"In training mode, a forward pass must have been called.\"\"\"\n        if self.training and self.mask is None:\n            raise RuntimeError('For a model with dropout in the training mode, '\\\n                'a clean forward pass must be called before bound computation')\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        empty_A = [(None, None)] * (len(args) - 1)\n        if not self.training:\n            return [(last_lA, last_uA), *empty_A], 0, 0\n        self._check_forward()\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            return last_A * self.mask / (1 - self.ratio)\n        lA = _bound_oneside(last_lA)\n        uA = _bound_oneside(last_uA)\n        return [(lA, uA), *empty_A], 0, 0\n\n    def bound_forward(self, dim_in, x, *args):\n        if not self.training:\n            return x\n        self._check_forward()\n        lw = x.lw * self.mask.unsqueeze(1) / (1 - self.ratio)\n        lb = x.lb * self.mask / (1 - self.ratio)\n        uw = x.uw * self.mask.unsqueeze(1) / (1 - self.ratio)\n        ub = x.ub * self.mask / (1 - self.ratio)\n        return LinearBound(lw, lb, uw, ub)\n\n    def interval_propagate(self, *v):\n        if not self.training:\n            return v[0]\n        self._check_forward()\n        h_L, h_U = v[0]\n        lower = h_L * self.mask / (1 - self.ratio)\n        upper = h_U * self.mask / (1 - self.ratio)\n        return lower, upper\n"
  },
  {
    "path": "auto_LiRPA/operators/dtype.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom .base import *\nfrom ..utils import Patches\n\nclass BoundCast(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.to = attr['to']\n        # See values of enum DataType in TensorProto.\n        # Unsupported: str, uint16, uint32, uint64.\n        self.data_types = [\n            None,  torch.float, torch.uint8, torch.int8,\n            None,  torch.int16, torch.int32, torch.int64,\n            None,  torch.bool, torch.float16, torch.float64,\n            None,  None, torch.complex64, torch.complex128\n        ]\n        self.type = self.data_types[self.to]\n        assert self.type is not None, \"Unsupported type conversion.\"\n        self.use_default_ibp = True\n\n    def forward(self, x):\n        self.type_in = x.dtype\n        return x.to(self.type)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        if isinstance(last_lA, Tensor) or isinstance(last_uA, Tensor):\n            lA = last_lA.to(self.type_in) if last_lA is not None else None\n            uA = last_uA.to(self.type_in) if last_uA is not None else None\n        else:\n            # Initialize to None so that lA and uA are always defined, even\n            # when only one of last_lA/last_uA is given in Patches mode.\n            lA = uA = None\n            if last_lA is not None:\n                lA = Patches(last_lA.patches.to(self.type_in), last_lA.stride, last_lA.padding, last_lA.shape, last_lA.identity, last_lA.unstable_idx, last_lA.output_shape)\n            if last_uA is not None:\n                uA = Patches(last_uA.patches.to(self.type_in), last_uA.stride, last_uA.padding, last_uA.shape, last_uA.identity, last_uA.unstable_idx, last_uA.output_shape)\n        return [(lA, uA)], 0, 0\n\n    def bound_forward(self, dim_in, x):\n        return LinearBound(\n            x.lw.to(self.type), x.lb.to(self.type),\n            x.uw.to(self.type), x.ub.to(self.type))\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(v[0])\n"
  },
  {
    "path": "auto_LiRPA/operators/gelu.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom .s_shaped import BoundTanh\nfrom .base import logger\n\n\n# FIXME resolve duplicate code with BoundTanh\nclass BoundGelu(BoundTanh):\n    sqrt_2 = math.sqrt(2)\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options, precompute=False)\n        self.ibp_intermediate = False\n        self.act_func = F.gelu\n        def d_act_func(x):\n            return (0.5 * (1 + torch.erf(x / self.sqrt_2))\n                    + x * torch.exp(-0.5 * x ** 2) / math.sqrt(2 * torch.pi))\n        self.d_act_func = d_act_func\n        def d2_act_func(x):\n            return (2 * torch.exp(-0.5 * x ** 2) / math.sqrt(2 * torch.pi)\n                    - x ** 2 * torch.exp(-0.5 * x ** 2) / math.sqrt(2 * torch.pi))\n        self.d2_act_func = d2_act_func\n        self.precompute_relaxation(self.act_func, self.d_act_func)\n\n    def _init_masks(self, x):\n        lower = x.lower\n        upper = x.upper\n        self.mask_left_pos = torch.logical_and(lower >= -self.sqrt_2, upper <= 0)\n        self.mask_left_neg = upper <= -self.sqrt_2\n        self.mask_left = torch.logical_xor(upper <= 0,\n                torch.logical_or(self.mask_left_pos, self.mask_left_neg))\n\n        self.mask_right_pos = lower >= self.sqrt_2\n        self.mask_right_neg = torch.logical_and(upper <= self.sqrt_2, lower >= 0)\n        self.mask_right = torch.logical_xor(lower >= 0,\n                torch.logical_or(self.mask_right_pos, self.mask_right_neg))\n\n        self.mask_2 = torch.logical_and(torch.logical_and(upper > 0, upper <= self.sqrt_2),\n                    torch.logical_and(lower < 0, lower >= -self.sqrt_2))\n        self.mask_left_3 = torch.logical_and(lower < -self.sqrt_2, 
torch.logical_and(\n            upper > 0, upper <= self.sqrt_2))\n        self.mask_right_3 = torch.logical_and(upper > self.sqrt_2, torch.logical_and(\n            lower < 0, lower >= -self.sqrt_2))\n        self.mask_4 = torch.logical_and(lower < -self.sqrt_2, upper > self.sqrt_2)\n        self.mask_both = torch.logical_or(self.mask_2, torch.logical_or(self.mask_4,\n                    torch.logical_or(self.mask_left_3, self.mask_right_3)))\n\n    @torch.no_grad()\n    def precompute_relaxation(self, func, dfunc, x_limit=1000):\n        \"\"\"\n        This function precomputes the tangent lines that will be used as\n        lower/upper bounds for S-shaped functions.\n        \"\"\"\n        self.x_limit = x_limit\n        self.step_pre = 0.01\n        self.num_points_pre = int(self.x_limit / self.step_pre)\n        max_iter = 100\n\n        logger.debug('Precomputing relaxation for GeLU (pre-activation limit: %f)',\n                     x_limit)\n\n        def check_lower(upper, d):\n            \"\"\"Given two points upper and d (d <= upper), check whether the\n            tangent line at d stays below f(upper) at upper.\"\"\"\n            k = dfunc(d)\n            # Return True if the tangent line is a lower bound.\n            return k * (upper - d) + func(d) <= func(upper)\n\n        def check_upper(lower, d):\n            \"\"\"Given two points lower and d (d >= lower), check whether the\n            tangent line at d stays above f(lower) at lower.\"\"\"\n            k = dfunc(d)\n            # Return True if the tangent line is an upper bound.\n            return k * (lower - d) + func(d) >= func(lower)\n\n        # Given an upper bound point (>=0), find a line that is guaranteed to\n        # be a lower bound of this function.\n        upper = self.step_pre * torch.arange(\n            0, self.num_points_pre + 5, device=self.device) + self.sqrt_2\n        r = torch.ones_like(upper)\n        # Initial guess, the tangent line is at -1.\n        l = -torch.ones_like(upper)\n        while True:\n            # Check if the tangent line at the guessed point is a lower bound at f(upper).\n            checked = check_lower(upper, l).int()\n            # If the initial guess is not small enough, then double it (-2, -4, etc.).\n            l = checked * l + (1 - checked) * (l * 2)\n            if checked.sum() == l.numel():\n                break\n        # Now we have a starting point l whose tangent line is guaranteed to\n        # be a lower bound at f(upper).\n        # We want to further tighten this bound by moving it closer to 0.\n        for _ in range(max_iter):\n            # Binary search.\n            m = (l + r) / 2\n            checked = check_lower(upper, m).int()\n            l = checked * m + (1 - checked) * l\n            r = checked * r + (1 - checked) * m\n        # At upper, the tangent line at point l is guaranteed to lower bound the function.\n        self.d_lower_right = l.clone()\n\n        # Do the same again:\n        # Given a lower bound point (<=0), find a line that is guaranteed to\n        # be an upper bound of this function.\n        lower = (\n            -self.step_pre * torch.arange(\n                0, self.num_points_pre + 5, device=self.device\n            ) + self.sqrt_2).clamp(min=0.01)\n        l = torch.zeros_like(upper) + self.sqrt_2\n        r = torch.zeros_like(upper) + x_limit\n        while True:\n            checked = check_upper(lower, r).int()\n            r = checked * r + (1 - checked) * (r * 2)\n            if checked.sum() == l.numel():\n                break\n        for _ in range(max_iter):\n            m = (l + r) / 2\n            checked = check_upper(lower, m).int()\n            l = (1 - checked) * m + checked * l\n            r = (1 - checked) * r + checked * m\n        self.d_upper_right = r.clone()\n\n        upper = -self.step_pre * torch.arange(\n            0, self.num_points_pre + 5, device=self.device) - self.sqrt_2\n        r = torch.zeros_like(upper) - 0.7517916\n        # Initial guesses: r at the minimum point of GeLU (-0.7517916), l at -sqrt(2).\n        l = torch.zeros_like(upper) - self.sqrt_2\n        while True:\n            checked = check_lower(upper, r).int()\n            r = checked * r + (1 - checked) * (r * 2)\n            if checked.sum() == l.numel():\n                break\n        # Now we have a starting point r whose tangent line is guaranteed to be\n        # a lower bound at f(upper).\n        # We want to further tighten this bound by moving it closer to 0.\n        for _ in range(max_iter):\n            # Binary search.\n            m = (l + r) / 2\n            checked = check_lower(upper, m).int()\n            l = (1 - checked) * m + checked * l\n            r = (1 - checked) * r + checked * m\n        # At upper, the tangent line at point r is guaranteed to lower bound the function.\n        self.d_lower_left = r.clone()\n\n        # Do the same again:\n        # Given a lower bound point (<=0), find a line that is guaranteed to\n        # be an upper bound of this function.\n        lower = (\n            self.step_pre * torch.arange(\n                0, self.num_points_pre + 5, device=self.device\n            ) - self.sqrt_2).clamp(max=0)\n        l = torch.zeros_like(upper) - x_limit\n        r = torch.zeros_like(upper) - self.sqrt_2\n        while True:\n            checked = check_upper(lower, l).int()\n            l = checked * l + (1 - checked) * (l * 2)\n            if checked.sum() == l.numel():\n                break\n        for _ in range(max_iter):\n            m = (l + r) / 2\n            checked = check_upper(lower, m).int()\n            l = (1 - checked) * m + checked * l\n            r = (1 - checked) * r + checked * m\n        self.d_upper_left = r.clone()\n\n        logger.debug('Done')\n\n    def opt_init(self):\n        super().opt_init()\n        self.tp_right_lower_init = {}\n        self.tp_right_upper_init = {}\n        self.tp_left_lower_init = {}\n        self.tp_left_upper_init = {}\n        self.tp_both_lower_init = {}\n\n    def _init_opt_parameters_impl(self, size_spec, name_start):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        shape = [size_spec] + list(l.shape)\n        alpha = torch.empty(14, *shape, device=l.device)\n        alpha.data[:4] = ((l + u) / 2).unsqueeze(0).expand(4, *shape)\n        alpha.data[4:6] = self.tp_right_lower_init[name_start].expand(2, *shape)\n        alpha.data[6:8] = self.tp_right_upper_init[name_start].expand(2, *shape)\n        alpha.data[8:10] = self.tp_left_lower_init[name_start].expand(2, *shape)\n        alpha.data[10:12] = self.tp_left_upper_init[name_start].expand(2, *shape)\n        alpha.data[12:14] = self.tp_both_lower_init[name_start].expand(2, *shape)\n        return alpha\n\n    def forward(self, x):\n        return F.gelu(x)\n\n    def bound_relax_impl(self, x, func, dfunc):\n        lower, upper = x.lower, x.upper\n        y_l, y_u = func(lower), func(upper)\n        # k_direct is the slope of the line directly connecting\n        # (lower, func(lower)) and (upper, func(upper)).\n        k_direct = k = (y_u - y_l) / (upper - lower).clamp(min=1e-8)\n\n        # Fixed bounds that cannot be optimized. 
In these cases the bound\n        # is determined by convexity alone.\n        # The upper bound is the direct line for intervals inside the convex\n        # region (|x| <= sqrt(2)), and also for the crossing cases in mask_both.\n        self.add_linear_relaxation(\n            mask=torch.logical_or(\n                torch.logical_or(self.mask_left_pos, self.mask_right_neg),\n                self.mask_both\n            ), type='upper', k=k_direct, x0=lower, y0=y_l)\n        # The lower bound is the direct line for intervals inside the concave\n        # regions (upper <= -sqrt(2) or lower >= sqrt(2)).\n        self.add_linear_relaxation(\n            mask=torch.logical_or(self.mask_left_neg,\n                    self.mask_right_pos), type='lower', k=k_direct, x0=lower, y0=y_l)\n\n        # Indices of neurons with input upper bound >= sqrt(2),\n        # whose optimal slope to lower bound on the right side was pre-computed.\n        d_lower_right = self.retrieve_from_precompute(\n            self.d_lower_right, upper - self.sqrt_2, lower)\n\n        # Indices of neurons with input lower bound <= -sqrt(2),\n        # whose optimal slope to lower bound on the left side was pre-computed.\n        d_lower_left = self.retrieve_from_precompute(\n            self.d_lower_left, -lower - self.sqrt_2, upper)\n\n        # Indices of neurons with input lower bound <= sqrt(2),\n        # whose optimal slope to upper bound on the right side was pre-computed.\n        d_upper_right = self.retrieve_from_precompute(\n            self.d_upper_right, -lower + self.sqrt_2, upper)\n\n        # Indices of neurons with input lower bound <= -sqrt(2),\n        # whose optimal slope to upper bound on the left side was pre-computed.\n        d_upper_left = self.retrieve_from_precompute(\n            self.d_upper_left, -lower - self.sqrt_2, upper)\n\n        if self.opt_stage in ['opt', 'reuse']:\n            if not hasattr(self, 'alpha'):\n                # Raise an error if alpha is not created.\n                self._no_bound_parameters()\n            ns = self._start\n\n            # Clipping is done 
here rather than after `opt.step()` call\n            # because it depends on pre-activation bounds\n            self.alpha[ns].data[0:2] = torch.max(\n                torch.min(self.alpha[ns][0:2], upper), lower)\n            self.alpha[ns].data[2:4] = torch.max(\n                torch.min(self.alpha[ns][2:4], upper), lower)\n            self.alpha[ns].data[4:6] = torch.max(\n                torch.min(self.alpha[ns][4:6], d_lower_right), lower)\n            self.alpha[ns].data[6:8] = torch.max(\n                self.alpha[ns][6:8], d_upper_right)\n            self.alpha[ns].data[8:10] = torch.min(\n                torch.max(self.alpha[ns][8:10], d_lower_left), upper)\n            self.alpha[ns].data[10:12] = torch.min(\n                self.alpha[ns][10:12], d_upper_left)\n            self.alpha[ns].data[12:14] = torch.min(\n                torch.max(self.alpha[ns][12:14], d_lower_left), d_lower_right)\n\n            # shape [2, out_c, n, c, h, w].\n            tp_pos = self.alpha[ns][0:2]  # For upper bound relaxation\n            tp_neg = self.alpha[ns][2:4]  # For lower bound relaxation\n            tp_right_lower = self.alpha[ns][4:6]\n            tp_right_upper = self.alpha[ns][6:8]\n            tp_left_lower = self.alpha[ns][8:10]\n            tp_left_upper = self.alpha[ns][10:12]\n            tp_both_lower = self.alpha[ns][12:14]\n\n            # No need to use tangent line, when the tangent point is at the left\n            # side of the preactivation lower bound. 
Simply connect the two sides.\n            mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(lower))\n            self.add_linear_relaxation(\n                mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_or(self.mask_right_3,\n                    torch.logical_xor(self.mask_right, mask_direct)), type='lower',\n                k=dfunc(tp_right_lower), x0=tp_right_lower)\n            mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(upper))\n            self.add_linear_relaxation(\n                mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_or(self.mask_left_3,\n                    torch.logical_xor(self.mask_left, mask_direct)), type='lower',\n                k=dfunc(tp_left_lower), x0=tp_left_lower)\n\n            mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(upper))\n            self.add_linear_relaxation(\n                mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_right, mask_direct), type='upper',\n                k=dfunc(tp_right_upper), x0=tp_right_upper)\n            mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(lower))\n            self.add_linear_relaxation(\n                mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_left, mask_direct), type='upper',\n                k=dfunc(tp_left_upper), x0=tp_left_upper)\n\n            self.add_linear_relaxation(\n                mask=self.mask_4, type='lower', k=dfunc(tp_both_lower), x0=tp_both_lower)\n            self.add_linear_relaxation(\n                mask=torch.logical_or(torch.logical_or(self.mask_left_pos, self.mask_right_neg),\n           
         self.mask_2), type='lower', k=dfunc(tp_neg), x0=tp_neg)\n            self.add_linear_relaxation(\n                mask=torch.logical_or(self.mask_right_pos,\n                    self.mask_left_neg), type='upper', k=dfunc(tp_pos), x0=tp_pos)\n        else:\n            if self.opt_stage == 'init':\n                # Initialize optimizable slope.\n                tp_right_lower_init = d_lower_right.detach()\n                tp_right_upper_init = d_upper_right.detach()\n                tp_left_lower_init = d_lower_left.detach()\n                tp_left_upper_init = d_upper_left.detach()\n                tp_both_lower_init = d_lower_right.detach()\n\n                ns = self._start\n                self.tp_right_lower_init[ns] = tp_right_lower_init\n                self.tp_right_upper_init[ns] = tp_right_upper_init\n                self.tp_left_lower_init[ns] = tp_left_lower_init\n                self.tp_left_upper_init[ns] = tp_left_upper_init\n                self.tp_both_lower_init[ns] = tp_both_lower_init\n\n            # Not optimized (vanilla CROWN bound).\n            # Use the middle point slope as the lower/upper bound. 
Not optimized.\n            m = (lower + upper) / 2\n            y_m = func(m)\n            k = dfunc(m)\n            # Lower bound is the middle point slope for the case input upper bound <= 0.\n            # Note that the upper bound in this case is the direct line between\n            # (lower, func(lower)) and (upper, func(upper)).\n            self.add_linear_relaxation(\n                mask=torch.logical_or(\n                    torch.logical_or(self.mask_left_pos, self.mask_right_neg),\n                    self.mask_2\n                ), type='lower', k=k, x0=m, y0=y_m)\n            # Upper bound is the middle point slope for the case input lower bound >= 0.\n            # Note that the lower bound in this case is the direct line between\n            # (lower, func(lower)) and (upper, func(upper)).\n            self.add_linear_relaxation(mask=torch.logical_or(self.mask_right_pos,\n                    self.mask_left_neg), type='upper', k=k, x0=m, y0=y_m)\n\n            # Now handle the case where input lower bound <=0 and upper bound >= 0.\n            # A tangent line starting at d_lower is guaranteed to be a lower bound\n            # given the input upper bound.\n            mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(lower))\n            self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)\n            # Otherwise we do not use the direct line, we use the d_lower slope.\n            self.add_linear_relaxation(\n                mask=torch.logical_or(torch.logical_or(self.mask_right_3, self.mask_4),\n                    torch.logical_xor(self.mask_right, mask_direct)), type='lower',\n                k=dfunc(d_lower_right), x0=d_lower_right)\n            mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(upper))\n            self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                
mask=torch.logical_or(self.mask_left_3,\n                    torch.logical_xor(self.mask_left, mask_direct)), type='lower',\n                k=dfunc(d_lower_left), x0=d_lower_left)\n\n            mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(upper))\n            self.add_linear_relaxation(\n                mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_right, mask_direct), type='upper',\n                k=dfunc(d_upper_right), x0=d_upper_right)\n            mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(lower))\n            self.add_linear_relaxation(\n                mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_left, mask_direct), type='upper',\n                k=dfunc(d_upper_left), x0=d_upper_left)\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n        self.bound_relax_impl(x, self.act_func, self.d_act_func)\n\n    def interval_propagate(self, *v):\n        pl, pu = self.forward(v[0][0]), self.forward(v[0][1])\n        pl, pu = torch.min(pl, pu), torch.max(pl, pu)\n        min_global = self.forward(torch.tensor(-0.7517916))\n        pl, pu = torch.min(min_global, torch.min(pl, pu)), torch.max(pl, pu)\n        return pl, pu\n\n\nclass GELUOp(torch.autograd.Function):\n    sqrt_2 = math.sqrt(2)\n    sqrt_2pi = math.sqrt(2 * math.pi)\n\n    @staticmethod\n    def symbolic(g, x):\n        return g.op('custom::Gelu', x)\n\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return torch.nn.functional.gelu(x)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x, = ctx.saved_tensors\n        grad_input = grad_output.clone()\n        grad = (0.5 * (1 + torch.erf(x / GELUOp.sqrt_2))\n        
        + x * torch.exp(-0.5 * x ** 2) / GELUOp.sqrt_2pi)\n        return grad_input * grad\n\n\nclass GELU(nn.Module):\n    def forward(self, x):\n        return GELUOp.apply(x)\n"
  },
  {
    "path": "auto_LiRPA/operators/indexing.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom .base import *\nfrom ..patches import Patches, patches_to_matrix\nfrom torch.nn import Module\n\n\nclass BoundGather(Bound):\n    def __init__(self, attr, x, output_index, options):\n        super().__init__(attr, x, output_index, options)\n        self.axis = attr['axis'] if 'axis' in attr else 0\n\n    def forward(self, x, indices):\n        self.indices = indices\n        if self.axis == -1:\n            self.axis = len(x.shape) - 1\n        # BoundShape.shape() will return values on cpu only\n        x = x.to(self.indices.device)\n        if indices.ndim == 0:\n            if indices == -1:\n                self.indices = x.shape[self.axis] + indices\n            return torch.index_select(x, dim=self.axis, index=self.indices).squeeze(self.axis)\n        elif indices.ndim == 1:\n            if self.axis == 0:\n                assert not self.perturbed\n            # `index_select` requires `indices` to be a 1-D tensor\n            return torch.index_select(x, dim=self.axis, index=indices)\n\n        raise ValueError('Unsupported shapes in Gather: '\n                         f'data {x.shape}, indices {indices.shape}, '\n                         f'axis {self.axis}')\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        assert self.from_input\n\n        def _expand_A_with_zeros(A, axis, idx, max_axis_size):\n            # Need to recreate A with three parts: before the gathered element, gathered element, and after gathered element.\n            tensors = []\n            if idx < 0:\n                idx = max_axis_size + idx\n            if idx > 0:\n                shape_pre = list(A.shape)\n                shape_pre[axis] *= idx\n                # Create the same shape as A, except for the dimension to be gathered.\n                tensors.append(torch.zeros(shape_pre, 
device=A.device))\n            # The gathered element itself, in the middle.\n            tensors.append(A)\n            if max_axis_size - idx - 1 > 0:\n                shape_next = list(A.shape)\n                shape_next[axis] *= max_axis_size - idx - 1\n                # Create the rest part of A.\n                tensors.append(torch.zeros(shape_next, device=A.device))\n            # Concatenate all three parts together.\n            return torch.cat(tensors, dim=axis)\n\n        def _bound_oneside(A):\n            if A is None:\n                return None\n\n            if isinstance(A, torch.Tensor):\n                if self.indices.ndim == 0:\n                    A = A.unsqueeze(self.axis + 1)\n                    idx = int(self.indices)\n                    return _expand_A_with_zeros(A, self.axis + 1, idx, self.input_shape[self.axis])\n                else:\n                    shape = list(A.shape)\n                    final_A = torch.zeros(*shape[:self.axis + 1], self.input_shape[self.axis], *shape[self.axis + 2:], device=A.device)\n                    idx = self.indices.view([*[1]*(self.axis+1), -1, *[1]*len(shape[self.axis + 2:])])\n                    idx = idx.repeat([*A.shape[:self.axis+1], 1, *A.shape[self.axis+2:]])\n                    final_A.scatter_add_(dim=self.axis+1, index=idx, src=A)\n                    return final_A\n            elif isinstance(A, Patches):\n                if self.indices.ndim == 0:\n                    idx = int(self.indices)\n                    assert len(self.input_shape) == 4 and self.axis == 1, \"Gather is only supported on the channel dimension for Patches mode.\"\n                    # For gather in the channel dimension, we only need to deal with the in_c dimension (-3) in patches.\n                    patches = A.patches\n                    # -3 is the in_c dimension.\n                    new_patches = _expand_A_with_zeros(patches, axis=-3, idx=idx, max_axis_size=self.input_shape[self.axis])\n             
       return A.create_similar(new_patches)\n                else:\n                    raise NotImplementedError\n            else:\n                raise ValueError(f'Unknown last_A type {type(A)}')\n\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, x, indices):\n        assert self.indices.numel() == 1 and self.indices.ndim <= 1 and (self.indices >= 0).all()\n        if isinstance(x, torch.Size):\n            lw = uw = torch.zeros(dim_in, device=self.device)\n            lb = ub = torch.index_select(\n                torch.tensor(x, device=self.device),\n                dim=self.axis, index=self.indices).squeeze(self.axis)\n        else:\n            axis = self.axis + 1\n            lw = torch.index_select(x.lw, dim=self.axis + 1, index=self.indices)\n            uw = torch.index_select(x.uw, dim=self.axis + 1, index=self.indices)\n            lb = torch.index_select(x.lb, dim=self.axis, index=self.indices)\n            ub = torch.index_select(x.ub, dim=self.axis, index=self.indices)\n            if self.indices.ndim == 0:\n                lw = lw.squeeze(axis)\n                uw = uw.squeeze(axis)\n                lb = lb.squeeze(self.axis)\n                ub = ub.squeeze(self.axis)\n        return LinearBound(lw, lb, uw, ub)\n\n    def interval_propagate(self, *v):\n        assert not self.is_input_perturbed(1)\n        return self.forward(v[0][0], v[1][0]), self.forward(v[0][1], v[1][0])\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(v[0], v[1])\n\n    def build_gradient_node(self, grad_upstream):\n        return [(GatherGrad(self.axis, self.indices, self.input_shape), (grad_upstream,), []), None]\n\n\nclass GatherGrad(Module):\n    def __init__(self, axis, indices, input_shape):\n        super().__init__()\n        self.axis = axis\n        self.indices = indices\n        self.input_shape = 
input_shape\n    \n    def forward(self, grad_last):\n        # TODO: It's better to use scatter_add_ instead of cat.\n        # This is a workaround for the fact that scatter_add_ does not support negative indices.\n\n        # Scalar indices case (ndim == 0)\n        if self.indices.ndim == 0:\n            grad_unsq = grad_last.unsqueeze(self.axis)\n            \n            # Get the scalar index and adjust if negative.\n            idx = int(self.indices)\n            if idx < 0:\n                idx = self.input_shape[self.axis] + idx\n            \n            # Build the gradient by concatenating three parts along self.axis:\n            tensors = []\n            # 1. Zeros block before the gathered element (if idx > 0)\n            if idx > 0:\n                shape_pre = list(grad_unsq.shape)\n                shape_pre[self.axis] = idx  # pre-block has size idx along self.axis\n                zeros_pre = torch.zeros(shape_pre, dtype=grad_last.dtype, device=grad_last.device)\n                tensors.append(zeros_pre)\n            \n            # 2. The gathered gradient slice (already in grad_unsq)\n            tensors.append(grad_unsq)\n            \n            # 3. 
Zeros block after the gathered element\n            num_after = self.input_shape[self.axis] - idx - 1\n            if num_after > 0:\n                shape_post = list(grad_unsq.shape)\n                shape_post[self.axis] = num_after\n                zeros_post = torch.zeros(shape_post, dtype=grad_last.dtype, device=grad_last.device)\n                tensors.append(zeros_post)\n            \n            # Concatenate all parts along self.axis to form the full gradient tensor.\n            grad_input = torch.cat(tensors, dim=self.axis)\n            return grad_input\n\n        # 1-D indices case (ndim == 1)\n        elif self.indices.ndim == 1:\n            grad_slices = []\n            # Iterate over each position in the original input along self.axis.\n            for i in range(self.input_shape[self.axis]):\n                # matching: tensor of indices (in grad_last) where the gathered index equals i.\n                matching = (self.indices == i).nonzero(as_tuple=False).squeeze(-1)\n                \n                if matching.numel() == 0:\n                    # No matching index: create a zeros slice with the same shape as one slice of grad_last.\n                    slice_shape = list(grad_last.shape)\n                    slice_shape[self.axis] = 1  # single slice along self.axis\n                    grad_slice = torch.zeros(slice_shape, dtype=grad_last.dtype, device=grad_last.device)\n                else:\n                    # There are one or more matching positions.\n                    # For each matching index j, extract the corresponding slice from grad_last.\n                    slice_list = []\n                    for j in matching.tolist():\n                        # Build slicing object: select all elements, but at self.axis take index j.\n                        slicer = [slice(None)] * grad_last.dim()\n                        slicer[self.axis] = j\n                        # Extract the slice and add back the missing dimension.\n              
          slice_j = grad_last[tuple(slicer)].unsqueeze(self.axis)\n                        slice_list.append(slice_j)\n                    # Concatenate all slices along self.axis; if there are duplicates, sum them.\n                    cat_slices = torch.cat(slice_list, dim=self.axis)\n                    # Sum along self.axis to accumulate contributions from duplicate indices.\n                    grad_slice = cat_slices.sum(dim=self.axis, keepdim=True)\n                # Append the slice corresponding to position i.\n                grad_slices.append(grad_slice)\n            \n            # Concatenate all slices in order along self.axis to form the final gradient tensor.\n            grad_input = torch.cat(grad_slices, dim=self.axis)\n            return grad_input\n\n        else:\n            raise ValueError(\"Unsupported indices dimensions in gradient for Gather\")\n\n\nclass BoundGatherElements(Bound):\n    def __init__(self, attr, input, output_index, options):\n        super().__init__(attr, input, output_index, options)\n        self.axis = attr['axis']\n\n    def forward(self, x, index):\n        self.index = index\n        return torch.gather(x, dim=self.axis, index=index)\n\n    def bound_backward(self, last_lA, last_uA, x, index, **kwargs):\n        assert self.from_input\n\n        dim = self._get_dim()\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            A = torch.zeros(\n                last_A.shape[0], last_A.shape[1], *x.output_shape[1:], device=last_A.device)\n            A.scatter_(\n                dim=dim + 1,\n                index=self.index.unsqueeze(0).repeat(A.shape[0], *([1] * (A.ndim - 1))),\n                src=last_A)\n            return A\n\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def interval_propagate(self, *v):\n        assert not self.is_input_perturbed(1)\n        return self.forward(v[0][0], v[1][0]), \\\n       
        self.forward(v[0][1], v[1][1])\n\n    def bound_forward(self, dim_in, x, index):\n        assert self.axis != 0\n        dim = self._get_dim()\n        return LinearBound(\n            torch.gather(x.lw, dim=dim + 1, index=self.index.unsqueeze(1).repeat(1, dim_in, 1)),\n            torch.gather(x.lb, dim=dim, index=self.index),\n            torch.gather(x.uw, dim=dim + 1, index=self.index.unsqueeze(1).repeat(1, dim_in, 1)),\n            torch.gather(x.ub, dim=dim, index=self.index))\n\n    def _get_dim(self):\n        dim = self.axis\n        if dim < 0:\n            dim = len(self.output_shape) + dim\n        return dim\n"
  },
  {
    "path": "auto_LiRPA/operators/jacobian.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport torch\nfrom torch.nn import Module\nfrom .base import Bound\nfrom ..utils import prod\n\n\nclass JacobianOP(torch.autograd.Function):\n    @staticmethod\n    def symbolic(g, output, input):\n        return g.op('grad::jacobian', output, input).setType(output.type())\n\n    @staticmethod\n    def forward(ctx, output, input):\n        output_ = output.flatten(1)\n        return torch.zeros(\n            output.shape[0], output_.shape[-1], *input.shape[1:],\n            device=output.device)\n\n\nclass BoundJacobianOP(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n\n    def forward(self, output, input):\n        return JacobianOP.apply(output, input)\n\n\nclass BoundJacobianInit(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.never_perturbed = True\n\n    def forward(self, x):\n        dim = prod(x.shape[1:])\n        eye = torch.eye(dim, device=x.device, requires_grad=x.requires_grad)\n        eye = eye.unsqueeze(0).expand(\n            x.shape[0], -1, -1\n        ).view(x.shape[0], dim, *x.shape[1:])\n        return eye\n\n\nclass GradNorm(Module):\n    def __init__(self, norm=1):\n        super().__init__()\n        self.norm = norm\n\n    def forward(self, grad):\n        grad = grad.view(grad.size(0), -1)\n        if self.norm == 1:\n            # torch.norm is not supported in auto_LiRPA yet\n            # use simpler operators for now\n            return grad.abs().sum(dim=-1, keepdim=True)\n        elif self.norm == 2:\n            return (grad * grad).sum(dim=-1)\n        else:\n            raise NotImplementedError(self.norm)\n"
  },
  {
    "path": "auto_LiRPA/operators/leaf.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Leaf nodes (independent nodes in the auto_LiRPA paper).\n\nIncluding input, parameter, buffer, etc.\"\"\"\n\nfrom itertools import chain\nfrom .base import *\n\n\nclass BoundInput(Bound):\n    def __init__(self, ori_name, value, perturbation=None, input_index=None, options=None, attr=None):\n        super().__init__(options=options, attr=attr)\n        self.ori_name = ori_name\n        self.value = value\n        self.perturbation = perturbation\n        self.from_input = True\n        self.input_index = input_index\n        self.no_jacobian = True\n\n    def __setattr__(self, key, value):\n        super().__setattr__(key, value)\n        # Update perturbed property based on the perturbation set.\n        if key == \"perturbation\":\n            if self.perturbation is not None:\n                self.perturbed = True\n            else:\n                self.perturbed = False\n\n    def forward(self):\n        return self.value\n\n    def bound_forward(self, dim_in):\n        assert 0\n\n    def bound_backward(self, last_lA, last_uA, **kwargs):\n        raise ValueError('{} is a BoundInput node and should not be visited here'.format(\n            self.name))\n\n    def interval_propagate(self, *v):\n        raise ValueError('{} is a BoundInput node and should not be visited here'.format(\n            self.name))\n\nclass BoundParams(BoundInput):\n    def __init__(self, ori_name, value, perturbation=None, options=None, attr=None):\n        super().__init__(ori_name, None, perturbation, attr=attr)\n        self.register_parameter('param', value)\n        if options is None:\n            options = {}\n        self.auto_requires_grad = options.get(\"param\", {}).get(\"auto_requires_grad\", True)\n        self.from_input = False\n\n    def register_parameter(self, name, param):\n        \"\"\"Override 
register_parameter() hook to register only needed parameters.\"\"\"\n        if name == 'param':\n            return super().register_parameter(name, param)\n        else:\n            # Just register it as a normal property of class.\n            object.__setattr__(self, name, param)\n\n    def init(self, initializing=False):\n        self.initializing = initializing\n\n    def forward(self):\n        param = self.param\n        if self.auto_requires_grad:\n            param = param.requires_grad_(self.training)\n        return param\n\nclass BoundBuffers(BoundInput):\n    def __init__(self, ori_name, value, perturbation=None, options=None, attr=None):\n        super().__init__(ori_name, None, perturbation, attr=attr)\n        self.register_buffer('buffer', value.clone().detach())\n        # BoundBuffers are like constants and they are by default not from inputs.\n        # The \"has_batchdim\" was a hack that will forcibly set BoundBuffer to be\n        # from inputs, to workaround buffers with a batch size dimension. This is\n        # not needed in most cases now.\n        if 'buffers' in options and 'has_batchdim' in options['buffers']:\n            warnings.warn('The \"has_batchdim\" option for BoundBuffers is deprecated.'\n                          ' It may be removed from the next release.')\n        self.from_input = options.get('buffers', {}).get('has_batchdim', False)\n\n    def forward(self):\n        return self.buffer\n"
  },
  {
    "path": "auto_LiRPA/operators/linear.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Linear (possibly with weight perturbation) or Dot product layers \"\"\"\nfrom torch import Tensor\nfrom torch.nn import Module\nfrom typing import Tuple, List\nfrom .activation_base import BoundOptimizableActivation\nfrom .base import *\nfrom .bivariate import BoundMul, MulHelper\nfrom .leaf import BoundParams, BoundBuffers\nfrom ..patches import Patches, inplace_unfold\nfrom .solver_utils import grb\nfrom .clampmult import multiply_by_A_signs\n\nEPS = 1e-2\n\nclass BoundLinear(BoundOptimizableActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        # Gemm:\n        # A = A if transA == 0 else A.T\n        # B = B if transB == 0 else B.T\n        # C = C if C is not None else np.array(0)\n        # Y = alpha * np.dot(A, B) + beta * C\n        # return Y\n\n        super().__init__(attr, inputs, output_index, options)\n\n        # Defaults in ONNX\n        self.transA = 0\n        self.transB = 0\n        self.alpha_linear = 1.0\n        self.beta_linear = 1.0\n        if attr is not None:\n            self.transA = attr['transA'] if 'transA' in attr else self.transA\n            self.transB = attr['transB'] if 'transB' in attr else self.transB\n            self.alpha_linear = attr['alpha'] if 'alpha' in attr else self.alpha_linear\n            self.beta_linear = attr['beta'] if 'beta' in attr else self.beta_linear\n\n        options = options or {}\n        self.opt_matmul = options.get('matmul')\n        self.splittable = False\n\n        self.mul_helper = MulHelper()\n        self.use_seperate_weights_for_lower_and_upper_bounds = False\n        self.batched_weight_and_bias = False\n        self.share_alphas = options.get('matmul', {}).get('share_alphas', False)\n        self.mul_middle = options.get('mul', {}).get('middle', False)\n        # For MatMul, it's 
possible that only the second input is perturbed.\n        # In this case, we swap the roles of x and weight.\n        self.swap_x_and_weight = False\n\n    def _preprocess(self, a, b, c=None):\n        \"\"\"Handle transpose and linear coefficients.\"\"\"\n        if self.transA and isinstance(a, Tensor):\n            a = a.transpose(-2,-1)\n        if self.alpha_linear != 1.0:\n            a = self.alpha_linear * a\n        if not self.transB and isinstance(b, Tensor):\n            # our code assumes B is transposed (common case), so we transpose B\n            # only when it is not transposed in gemm.\n            b = b.transpose(-2, -1)\n        if c is not None:\n            if self.beta_linear != 1.0:\n                c = self.beta_linear * c\n        return a, b, c\n\n    def init_opt_parameters(self, start_nodes):\n        shared_alpha_dims = []\n        if self.share_alphas:\n            # TODO Temporarily an ad hoc check for alpha sharing.\n            count_matmul = len([item for item in self._all_optimizable_activations\n                                if isinstance(item, BoundLinear)])\n            if count_matmul >= 6:\n                shared_alpha_dims = [1, 2, 3]\n            elif count_matmul >= 4:\n                shared_alpha_dims = [1, 2]\n\n        input_lb = [xi.lower for xi in self.inputs]\n        input_ub = [xi.upper for xi in self.inputs]\n        input_lb = self._preprocess(*input_lb)\n        input_ub = self._preprocess(*input_ub)\n        x_l, x_u, y_l, y_u = self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1])\n        assert x_l.ndim == y_l.ndim\n        shape = [1 if i in shared_alpha_dims\n                 else max(x_l.shape[i], y_l.shape[i]) for i in range(x_l.ndim)]\n        for start_node in start_nodes:\n            ns, size_s = start_node[:2]\n            # start_node[3] == False means that this start node is not the final node\n            # if not start_node[3]:\n            #     # NOTE Experimental code. 
Please check how it will impact the results.\n            #     size_s = 1\n            if isinstance(size_s, torch.Size):\n                # TODO do not give torch.Size\n                size_s = prod(size_s)\n            elif isinstance(size_s, (list, tuple)):\n                size_s = size_s[0]\n            self.alpha[ns] = torch.ones(4, size_s, *shape, device=x_l.device)\n\n    def forward(self, x, w, b=None):\n        x, w, b = self._preprocess(x, w, b)\n        self.input_shape = self.x_shape = x.shape\n        self.y_shape = w.t().shape\n        res = x.matmul(w.t())\n        if b is not None:\n            res += b\n        return res\n\n    def onehot_mult(self, weight, bias, C, batch_size):\n        \"\"\"Multiply weight matrix with a diagonal matrix with selected rows.\"\"\"\n\n        if C is None:\n            return None, 0.0\n\n        new_weight = None\n        new_bias = 0.0\n\n        if C.index.ndim == 2:\n            # Shape is [spec, batch]\n            index = C.index.transpose(0, 1)\n            coeffs = C.coeffs.transpose(0, 1)\n        else:\n            index = C.index\n            coeffs = C.coeffs\n\n        if C.index.ndim == 1:\n            # Every element in the batch shares the same rows.\n            if weight is not None:\n                new_weight = self.non_deter_index_select(\n                    weight, dim=0, index=index\n                ).unsqueeze(1).expand(\n                    [-1, batch_size] + [-1] * (weight.ndim - 1))\n            if bias is not None:\n                new_bias = self.non_deter_index_select(\n                    bias, dim=0, index=index\n                ).unsqueeze(1).expand(-1, batch_size)\n        elif C.index.ndim == 2:\n            # Every element in the batch has different rows, but the number of\n            # rows is the same. 
This essentially needs a batched index_select function.\n            if weight is not None:\n                new_weight = batched_index_select(\n                    weight.unsqueeze(0), dim=1, index=index)\n            if bias is not None:\n                new_bias = batched_index_select(\n                    bias.unsqueeze(0), dim=1, index=index)\n        if C.coeffs is not None:\n            if weight is not None:\n                new_weight = new_weight * coeffs.unsqueeze(-1)\n            if bias is not None:\n                new_bias = new_bias * coeffs\n        if C.index.ndim == 2:\n            # Eventually, the shape of A is [spec, batch, *node] so need a transpose.\n            new_weight = new_weight.transpose(0, 1)\n            new_bias = new_bias.transpose(0, 1)\n        return new_weight, new_bias\n\n    def bound_backward(self, last_lA, last_uA, *x, start_node=None,\n                       reduce_bias=True, **kwargs):\n        assert len(x) == 2 or len(x) == 3\n        if start_node is not None:\n            self._start = start_node.name\n        has_bias = len(x) == 3\n        # x[0]: input node, x[1]: weight, x[2]: bias\n        input_lb = [xi.lower for xi in x]\n        input_ub = [xi.upper for xi in x]\n        if self.swap_x_and_weight:\n            input_lb = [input_lb[1].transpose(-1, -2) if input_lb[1] is not None else None,\n                        input_lb[0].transpose(-1, -2) if input_lb[0] is not None else None,\n                        input_lb[2:]]\n            input_ub = [input_ub[1].transpose(-1, -2) if input_ub[1] is not None else None,\n                        input_ub[0].transpose(-1, -2) if input_ub[0] is not None else None,\n                        input_ub[2:]]\n            if last_lA is not None:\n                if isinstance(last_lA, torch.Tensor):\n                    last_lA = last_lA.transpose(-1, -2)\n                elif isinstance(last_lA, eyeC):\n                    last_lA = last_lA._replace(shape=last_lA.shape[:-2] + 
(last_lA.shape[-1], last_lA.shape[-2]))\n                else:\n                    raise NotImplementedError(\n                        f\"last_lA's type {type(last_lA)} is not supported for transpose in the case of swapping x and weight.\")\n            if last_uA is not None:\n                if isinstance(last_uA, torch.Tensor):\n                    last_uA = last_uA.transpose(-1, -2)\n                elif isinstance(last_uA, eyeC):\n                    last_uA = last_uA._replace(shape=last_uA.shape[:-2] + (last_uA.shape[-1], last_uA.shape[-2]))\n                else:\n                    raise NotImplementedError(\n                        f\"last_uA's type {type(last_uA)} is not supported for transpose in the case of swapping x and weight.\")\n\n        # transpose and scale each term if necessary.\n        input_lb = self._preprocess(*input_lb)\n        input_ub = self._preprocess(*input_ub)\n        lA_y = uA_y = lA_bias = uA_bias = None\n        lbias = ubias = 0\n        batch_size = last_lA.shape[1] if last_lA is not None else last_uA.shape[1]\n        weight = input_lb[1]\n        bias = input_lb[2] if has_bias else None\n\n        def _bound_oneside(last_A, weight_override=None):\n            # For most applications, weight_override should be left as None\n            # This will cause used_weight to be set to weight, which is the weight\n            # assigned to input_lb[1]. 
The only reason to provide an override weight\n            # is if this layer has different weights for its lower and upper bounds.\n            # That is currently only the case for the implementation of output\n            # constraints, where lower and upper bounds use distinct gammas.\n            if weight_override is None:\n                used_weight = weight\n            else:\n                used_weight = weight_override\n\n            if last_A is None:\n                return None, 0\n            if isinstance(last_A, torch.Tensor):\n                # Matrix mode.\n                # Just multiply this layer's weight into bound matrices, and produce biases.\n                if self.batched_weight_and_bias:\n                    # last_A is the A at the current layer (self)\n                    # next_A is the A for the layer consumed by the current (self) one\n                    # \"next_A\" makes sense because we're backpropagating. However, the below shapes\n                    # will refer to \"prev_layer\", which also is the layer that is consumed by\n                    # the current (self) one. 
That's because they should match the documentation in\n                    # output_constraints.py, which is written from a \"forward facing\" point of view.\n\n                    # We have: last_A.shape = (unstable_neurons, batch_size, this_layer_neurons)\n                    # We want: next_A.shape = (unstable_neurons, batch_size, prev_layer_neurons)\n\n                    # We also have\n                    # used_weight.shape = (batch_size, this_layer_neurons, prev_layer_neurons)\n\n                    mod_last_A = last_A.unsqueeze(2)\n                    mod_used_weight = used_weight.unsqueeze(0)\n                    # mod_last_A.shape = (unstable_neurons, batch_size, 1, this_layer_neurons)\n                    # mod_used_weight.shape = (1, batch_size, this_layer_neurons, prev_layer_neurons)\n\n                    mod_next_A = mod_last_A.to(mod_used_weight).matmul(mod_used_weight)\n                    # mod_next_A.shape = (unstable_neurons, batch_size, 1, prev_layer_neurons)\n\n                    next_A = mod_next_A.squeeze(2)\n                    # next_A.shape = (unstable_neurons, batch_size, prev_layer_neurons)\n\n                    if has_bias:\n                        # bias.shape = (batch_size, this_layer_neurons)\n\n                        mod_bias = bias.unsqueeze(0).unsqueeze(3)\n                        # mod_bias.shape = (1, batch_size, this_layer_neurons, 1)\n                        # mod_last_A.shape = (unstable_neurons, batch_size, 1, this_layer_neurons)\n\n                        mod_sum_bias = mod_last_A.to(mod_bias).matmul(mod_bias)\n                        # mod_sum_bias.shape = (unstable_neurons, batch_size, 1, 1)\n\n                        sum_bias = mod_sum_bias.squeeze(3).squeeze(2)\n                        # sum_bias.shape = (unstable_neurons, batch_size)\n                else:\n                    next_A = last_A.to(used_weight).matmul(used_weight)\n                    sum_bias = (last_A.to(bias).matmul(bias)\n                        
if has_bias else 0.0)\n            else:\n                assert isinstance(last_A, Patches)\n                assert not self.batched_weight_and_bias\n                # Patches mode. After propagating through this layer, it will become a matrix.\n                # Reshape the weight matrix as a conv image.\n                # Weight was in (linear_output_shape, linear_input_shape)\n                # Reshape it to (linear_input_shape, c, h, w)\n                reshaped_weight = used_weight.transpose(0, 1).view(\n                    -1, *last_A.input_shape[1:])\n                # After unfolding the shape is\n                # (linear_input_shape, output_h, output_w, in_c, patch_h, patch_w)\n                unfolded_weight = inplace_unfold(\n                    reshaped_weight,\n                    kernel_size=last_A.patches.shape[-2:],\n                    stride=last_A.stride, padding=last_A.padding,\n                    inserted_zeros=last_A.inserted_zeros,\n                    output_padding=last_A.output_padding)\n                if has_bias:\n                    # Do the same for the bias.\n                    reshaped_bias = bias.view(*last_A.input_shape[1:]).unsqueeze(0)\n                    # After unfolding the bias shape is (1, output_h, output_w, in_c, patch_h, patch_w)\n                    unfolded_bias = inplace_unfold(\n                        reshaped_bias, kernel_size=last_A.patches.shape[-2:],\n                        stride=last_A.stride, padding=last_A.padding,\n                        inserted_zeros=last_A.inserted_zeros,\n                        output_padding=last_A.output_padding)\n                if last_A.unstable_idx is not None:\n                    # In this case, the last_A shape is (num_unstable, batch, out_c, patch_h, patch_w)\n                    # Reshape our weight to (output_h, output_w, 1, in_c, patch_h, patch_w, linear_input_shape), 1 is the inserted batch dim.\n                    unfolded_weight_r = unfolded_weight.permute(1, 2, 
3, 4, 5, 0).unsqueeze(2)\n                    # for sparse patches the shape is (unstable_size, batch, in_c, patch_h, patch_w). Batch size is 1 so no need to select here.\n                    # We select in the (output_h, out_w) dimension.\n                    selected_weight = unfolded_weight_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                    next_A = torch.einsum('sbchw,sbchwi->sbi', last_A.patches, selected_weight)\n                    if has_bias:\n                        # Reshape our bias to (output_h, output_w, 1, in_c, patch_h, patch_w). We already have the batch dim.\n                        unfolded_bias_r = unfolded_bias.permute(1, 2, 0, 3, 4, 5)\n                        selected_bias = unfolded_bias_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                        sum_bias = torch.einsum('sbchw,sbchw->sb', last_A.patches, selected_bias)\n                else:\n                    # Reshape our weight to (1, 1, output_h, output_w, in_c, patch_h, patch_w, linear_input_shape), 1 is the spec and batch.\n                    selected_weight = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(0).unsqueeze(0)\n                    next_A_r = torch.einsum('sbpqchw,sbpqchwi->spqbi', last_A.patches, selected_weight)\n                    # We return a matrix with flattened spec dimension (corresponding to out_c * out_h * out_w).\n                    next_A = next_A_r.reshape(-1, next_A_r.size(-2), next_A_r.size(-1))\n                    if has_bias:\n                        # Reshape our bias to (1, 1, output_h, output_w, in_c, patch_h, patch_w)\n                        selected_bias = unfolded_bias.unsqueeze(0)\n                        sum_bias_r = torch.einsum('sbpqchw,sbpqchw->spqb', last_A.patches, selected_bias)\n                        sum_bias = sum_bias_r.reshape(-1, sum_bias_r.size(-1))\n            return next_A, sum_bias if has_bias else 0.0\n\n        # Case #1: No weight/bias perturbation, only perturbation on input.\n     
   if ((not self.is_input_perturbed(0) or not self.is_input_perturbed(1)) and \n            (not has_bias or not self.is_input_perturbed(2))):\n            # If last_lA and last_uA are identity matrices.\n            # FIXME (12/28): we should check last_lA and last_uA separately.\n            # Same applies to the weight perturbed, bias perturbed settings.\n\n            def multiply_with_weight(weight, set_l: bool, set_u: bool):\n                lA_x = uA_x = None\n                lbias = ubias = 0.\n                if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC):\n                    # Use this layer's W as the next bound matrices.\n                    # Shape of inputs: (B, s_k, s_{k-1}, ..., s_1, m, n) @ (s_l, s_{l-1}, ..., s_1, n, p)\n                    #               or (B, s_k, s_{k-1}, ..., s_1, m, n) @ (B, s_k, s_{k-1}, ..., s_1, n, p)\n                    # Shape of output: (B, s_k, ..., s_1, m, p)\n                    # last_lA: (specs, B, s_k, ..., s_1, m, p)\n                    # weight: (s_l, ..., s_1, p, n) where l <= k, or (B, s_k, ..., s_1, p, n)\n\n                    if len(last_lA.shape) == 3:\n                        # input x is a vector\n                        m = 1\n                        p = last_lA.shape[-1]\n                    else:\n                        # general input shape\n                        m, p = last_lA.shape[-2:]\n                    n = weight.size(-1)\n\n                    assert last_lA.shape == last_uA.shape\n                    # shape of \"broadcast dimensions\" \\prod_{i=1...k} s_i\n                    shape_broadcast = last_lA.shape[2:-2]\n                    prod_broadcast = prod(shape_broadcast)\n                    ndim_broadcast = len(shape_broadcast)\n\n                    assert weight.ndim - 3 <= ndim_broadcast, \"Broadcasting on input 'x' is not supported.\"\n                    weight_has_batch = weight.ndim - 3 == ndim_broadcast\n\n                    # A_identity: (s_k, ...s_1, m, 1, 
s_k, ..., s_1, m, 1) where two 1s are for the two \"matmul dimensions\"\n                    A_identity = torch.eye(\n                        prod_broadcast * m, device=weight.device, dtype=weight.dtype\n                    ).view(*shape_broadcast, m, 1, *shape_broadcast, m, 1)\n                    # Assert specs = {product of shape of output} = \\prod s_i * m * p\n                    assert last_lA.shape[0] == prod_broadcast * m * p\n\n                    if not weight_has_batch:\n                        # Pad the \"broadcast dimensions\" of weight according to shape of input\n                        # (s_l, ..., s_1, p, n) -> (1, ..., 1, s_l, ..., s_1, p, n) where there are (k-l) 1s\n                        w_padding = weight.reshape(*[1] * (ndim_broadcast + 2 - len(weight.shape)), *weight.shape)\n                        # Duplicate the \"broadcast dimensions\" to match both sides of A_identity\n                        # (*broadcast_dims, p, n) -> (*broadcast_dims, p, *broadcast_dims, n)\n                        w_eye_mask = torch.eye(prod_broadcast, device=weight.device, dtype=weight.dtype).reshape(*shape_broadcast, 1, *shape_broadcast, 1)\n                        w = w_eye_mask * w_padding.reshape(*w_padding.shape[:-1], *[1] * (len(w_padding.shape) - 2), w_padding.size(-1))\n                        # Add two slots for the \"m\" dimension in A_identity\n                        # (*broadcast_dims, p, *broadcast_dims, n) -> (*broadcast_dims, 1, p, *broadcast_dims, 1, n)\n                        w = w.view(*w.shape[:ndim_broadcast], 1, p, *w.shape[:ndim_broadcast], 1, n)\n                        w = w * A_identity  # (*broadcast_dims, m, p, *broadcast_dims, m, n)\n                        # expand the batch_size dim\n                        # (*broadcast_dims, m, p, *broadcast_dims, m, n) -> (Prod(broadcast_dims)*m*p, B, *broadcast_dims, m, n)\n                        tmp_A_x = w.reshape(last_lA.shape[0], 1, *last_lA.shape[2:-1], 
weight.size(-1)).expand(last_lA.shape[0], *last_lA.shape[1:-1], weight.size(-1))\n                    else:\n                        # There's no need to pad the weight tensor if it has a batch dimension.\n                        # Duplicate the \"broadcast dimensions\" to match both sides of A_identity\n                        # (B, *broadcast_dims, p, n) -> (B, *broadcast_dims, p, *broadcast_dims, n)\n                        w_eye_mask = torch.eye(prod_broadcast, device=weight.device, dtype=weight.dtype).reshape(*shape_broadcast, 1, *shape_broadcast, 1)\n                        w = w_eye_mask * weight.reshape(*weight.shape[:-1], *[1] * (len(weight.shape) - 3), weight.size(-1))\n                        # Add two slots for the \"m\" dimension in A_identity\n                        # (B, *broadcast_dims, p, *broadcast_dims, n) -> (B, *broadcast_dims, 1, p, *broadcast_dims, 1, n)\n                        w = w.view(w.shape[0], *w.shape[1:ndim_broadcast+1], 1, p, *w.shape[1:ndim_broadcast+1], 1, n)\n                        w = w * A_identity  # (B, *broadcast_dims, m, p, *broadcast_dims, m, n)\n                        # (B, *broadcast_dims, m, p, *broadcast_dims, m, n) -> (Prod(broadcast_dims)*m*p, B, *broadcast_dims, m, n)\n                        tmp_A_x = w.reshape(w.shape[0], last_lA.shape[0], *last_lA.shape[2:-1], weight.size(-1)).transpose(0, 1)                            \n                    if set_l:\n                        lA_x = tmp_A_x\n                    if set_u:\n                        uA_x = tmp_A_x\n\n                    if has_bias:\n                        tmp_bias = bias.unsqueeze(1).repeat(1, batch_size)\n                        if set_l:\n                            lbias = tmp_bias\n                        if set_u:\n                            ubias = tmp_bias\n                elif isinstance(last_lA, OneHotC) or isinstance(last_uA, OneHotC):\n                    # We need to select several rows from the weight matrix\n                    # 
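# A one-hot specification matrix makes a dense matmul with W unnecessary:
# multiplying signed one-hot rows with W is just a row gather scaled by each
# row's coefficient, which is the saving the OneHotC branch here exploits.
# A minimal NumPy sketch of that identity (illustrative only; this
# `onehot_mult` is a hypothetical stand-in, not the method used below):

```python
import numpy as np

def onehot_mult(W, b, indices, coeffs):
    # One-hot rows select (and scale) rows of W and entries of b --
    # a gather instead of a (spec x out) @ (out x in) matmul.
    return coeffs[:, None] * W[indices], coeffs * b[indices]

rng = np.random.default_rng(1)
W = rng.standard_normal((5, 3))
b = rng.standard_normal(5)
indices = np.array([3, 0, 4])
coeffs = np.array([1.0, -1.0, 1.0])
A_x, bias = onehot_mult(W, b, indices, coeffs)

# Reference: materialize the dense one-hot matrix and multiply.
A = np.zeros((3, 5))
A[np.arange(3), indices] = coeffs
assert np.allclose(A_x, A @ W) and np.allclose(bias, A @ b)
```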
(its shape is output_size * input_size).\n                    if set_l:\n                        lA_x, lbias = self.onehot_mult(weight, bias, last_lA, batch_size)\n                    if last_lA is last_uA and set_l and set_u:\n                        uA_x = lA_x\n                        ubias = lbias\n                    elif set_u:\n                        uA_x, ubias = self.onehot_mult(weight, bias, last_uA, batch_size)\n                else:\n                    if set_l:\n                        lA_x, lbias = _bound_oneside(last_lA, weight_override=weight)\n                    if set_u:\n                        uA_x, ubias = _bound_oneside(last_uA, weight_override=weight)\n                return lA_x, uA_x, lbias, ubias\n\n            if self.use_seperate_weights_for_lower_and_upper_bounds:\n                lA_x, _, lbias, _ = multiply_with_weight(input_lb[1], set_l=True, set_u=False)\n                _, uA_x, _, ubias = multiply_with_weight(input_ub[1], set_l=False, set_u=True)\n            else:\n                lA_x, uA_x, lbias, ubias = multiply_with_weight(weight, set_l=True, set_u=True)\n\n        # Case #2: weight is perturbed. bias may or may not be perturbed.\n        elif self.is_input_perturbed(1):\n            assert not self.use_seperate_weights_for_lower_and_upper_bounds\n            # Obtain relaxations for matrix multiplication.\n            [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias = self.bound_backward_with_weight(\n                last_lA, last_uA, input_lb, input_ub, x[0], x[1],\n                reduce_bias=reduce_bias, **kwargs)\n            if has_bias:\n                assert reduce_bias\n                if x[2].perturbation is not None:\n                    # Bias is also perturbed. 
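# Backward bound propagation through an affine layer y = W x + b maps linear
# coefficients A over y to A @ W over x plus a bias term A @ b; when A is the
# identity (the eyeC shortcut above), these reduce to W and b themselves, and
# the bias handling below adds A @ b for non-identity A. A NumPy sketch of
# this invariant (illustrative only; `backward_affine` is a hypothetical name):

```python
import numpy as np

def backward_affine(A, W, b):
    # Coefficients over x are A @ W; the accumulated bias term is A @ b.
    return A @ W, A @ b

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 3))
b = rng.standard_normal(4)
x = rng.standard_normal(3)

A = np.eye(4)
A_x, bias = backward_affine(A, W, b)
# Identity A: the propagated coefficients are the layer's own parameters.
assert np.allclose(A_x, W) and np.allclose(bias, b)
# The propagated linear function reproduces the layer output exactly.
assert np.allclose(A_x @ x + bias, W @ x + b)
```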
Since bias is directly added to the\n                    # output, in backward mode it is treated as an input with\n                    # last_lA and last_uA as associated bounds matrices.\n                    # It's okay if last_lA or last_uA is eyeC, as it will be\n                    # handled in the perturbation object.\n                    lA_bias = last_lA\n                    uA_bias = last_uA\n                else:\n                    # Bias not perturbed, so directly adding the bias of this\n                    # layer to the final bound bias term.\n                    if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC):\n                        # Bias will be directly added to output.\n                        lbias += input_lb[2].unsqueeze(1).repeat(1, batch_size)\n                        ubias += input_lb[2].unsqueeze(1).repeat(1, batch_size)\n                    else:\n                        if last_lA is not None:\n                            lbias += last_lA.matmul(input_lb[2])\n                        if last_uA is not None:\n                            ubias += last_uA.matmul(input_lb[2])\n            # If not has_bias, no need to compute lA_bias and uA_bias\n        # Case 3: Only bias is perturbed, weight is not perturbed.\n        elif not self.is_input_perturbed(1) and has_bias and self.is_input_perturbed(2):\n            assert not self.use_seperate_weights_for_lower_and_upper_bounds\n            assert reduce_bias\n            if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC):\n                # Use this layer's W as the next bound matrices. Duplicate the\n                # batch dimension. 
Other dimensions are kept 1.\n                lA_x = uA_x = input_lb[1].unsqueeze(1).repeat(\n                    [1, batch_size] + [1] * (input_lb[1].ndim - 1))\n            else:\n                lA_x = last_lA.matmul(input_lb[1])\n                uA_x = last_uA.matmul(input_lb[1])\n            # It's okay if last_lA or last_uA is eyeC, as it will be handled in the perturbation object.\n            lA_bias = last_lA\n            uA_bias = last_uA\n        else:\n            assert not self.use_seperate_weights_for_lower_and_upper_bounds\n\n        if self.swap_x_and_weight:\n            return [(None, None),\n                    (lA_x.transpose(-1, -2) if lA_x is not None else None,\n                     uA_x.transpose(-1, -2) if uA_x is not None else None),\n                    (lA_bias, uA_bias)], lbias, ubias\n        return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias\n\n    def _reshape(self, x_l, x_u, y_l, y_u):\n        x_shape, y_shape = self.input_shape, self.y_shape\n\n        # (x_1, x_2, ..., x_{n-1}, -1, x_n) # FIXME\n        x_l = x_l.unsqueeze(-2)\n        x_u = x_u.unsqueeze(-2)\n\n        # FIXME merge these two cases\n        if len(x_shape) == len(y_shape):\n            # (x_1, x_2, ..., -1, y_n, y_{n-1})\n            y_l = y_l.unsqueeze(-3)\n            y_u = y_u.unsqueeze(-3)\n        elif len(y_shape) == 2:\n            # (x_1, x_2, ..., -1, y_2, y_1)\n            y_l = y_l.reshape(*([1] * (len(x_shape) - 2)), *y_shape).unsqueeze(-3)\n            y_u = y_u.reshape(*([1] * (len(x_shape) - 2)), *y_shape).unsqueeze(-3)\n        else:\n            raise ValueError(f'Unsupported shapes: x_shape {x_shape}, y_shape {y_shape}')\n\n        return x_l, x_u, y_l, y_u\n\n    @staticmethod\n    # @torch.jit.script\n    def propagate_A_xy(last_A: Tensor, alpha_pos: Tensor, alpha_neg: Tensor,\n                       beta_pos: Tensor, beta_neg: Tensor,\n                       dim_y: List[int]) -> Tuple[Tensor, Tensor]:\n        # last_uA 
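# The clamp(min=0)/clamp(max=0) pairs used in `propagate_A_xy` below implement
# a standard sign split: for l <= z <= u elementwise, a . z is maximized by
# pairing positive coefficients with u and negative ones with l, and vice
# versa for the minimum. A NumPy sketch of the split (illustrative only;
# `bounds_via_signs` is a hypothetical name):

```python
import numpy as np

def bounds_via_signs(a, l, u):
    # Positive coefficients take the matching bound, negative ones take
    # the opposite bound (the clamp(min=0)/clamp(max=0) decomposition).
    a_pos, a_neg = np.clip(a, 0, None), np.clip(a, None, 0)
    return a_pos @ l + a_neg @ u, a_pos @ u + a_neg @ l

rng = np.random.default_rng(5)
a = rng.standard_normal(6)
l = rng.standard_normal(6)
u = l + rng.uniform(0.0, 1.0, 6)
lo, hi = bounds_via_signs(a, l, u)

# Any feasible z stays within the computed bounds.
z = rng.uniform(l, u)
assert lo <= a @ z + 1e-9 and a @ z <= hi + 1e-9
```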
has size (batch, spec, output)\n        last_A_pos = last_A.clamp(min=0).unsqueeze(-1)\n        last_A_neg = last_A.clamp(max=0).unsqueeze(-1)\n        # alpha_u has size (batch, spec, output, input)\n        # uA_x has size (batch, spec, input).\n        A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) +\n                alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1)\n        # beta_u has size (batch, spec, output, input)\n        # uA_y is for weight matrix, with parameter size (output, input)\n        # uA_y has size (batch, spec, output, input). This is an element-wise multiplication.\n        # TODO (for zhouxing/qirui): generalize multiply_by_A_signs() to calculate A_x,\n        # so last_A_pos and last_A_neg are not needed. This saves memory.\n        A_y, _ = multiply_by_A_signs(last_A.unsqueeze(-1), beta_pos, beta_neg, None, None)\n        if len(dim_y) != 0:\n            A_y = torch.sum(A_y, dim=dim_y)\n        return A_x, A_y\n\n    def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub,\n                                   x, y, reduce_bias=True, **kwargs):\n        # FIXME This is nonlinear. 
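# When both factors of a product are bounded, `mul_helper.get_relaxation`
# below returns linear planes alpha*x + beta*y + gamma that bound x*y over a
# box, in the spirit of the classic McCormick envelopes. A NumPy sketch of two
# such planes (illustrative only; `mccormick_planes` is not part of this
# module, and only one lower/upper pair of the four planes is shown):

```python
import numpy as np

def mccormick_planes(x_l, x_u, y_l, y_u):
    # From (x - x_l)(y - y_l) >= 0:  x*y >= y_l*x + x_l*y - x_l*y_l
    lower = (y_l, x_l, -x_l * y_l)
    # From (x - x_u)(y - y_l) <= 0:  x*y <= y_l*x + x_u*y - x_u*y_l
    upper = (y_l, x_u, -x_u * y_l)
    return lower, upper

x_l, x_u, y_l, y_u = -1.0, 2.0, 0.5, 3.0
(al, bl, cl), (au, bu, cu) = mccormick_planes(x_l, x_u, y_l, y_u)

# The planes are sound everywhere on the box.
rng = np.random.default_rng(2)
x = rng.uniform(x_l, x_u, 1000)
y = rng.uniform(y_l, y_u, 1000)
z = x * y
assert np.all(al * x + bl * y + cl <= z + 1e-9)
assert np.all(au * x + bu * y + cu >= z - 1e-9)
```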
Move to `bivariate.py`.

        # Note: x and y are not transposed or scaled, and we should avoid using them directly.
        # Use input_lb and input_ub instead.
        (alpha_l, beta_l, gamma_l,
         alpha_u, beta_u, gamma_u) = self.mul_helper.get_relaxation(
            *self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1]),
            self.opt_stage, getattr(self, 'alpha', None),
            getattr(self, '_start', None), middle=self.mul_middle)
        x_shape = input_lb[0].size()
        if reduce_bias:
            gamma_l = torch.sum(gamma_l, dim=-1)
            gamma_u = torch.sum(gamma_u, dim=-1)

        if len(x.output_shape) != 2 and len(x.output_shape) == len(y.output_shape):
            dim_y = [-3]
        elif len(y.output_shape) == 2:
            dim_y = list(range(2, 2 + len(x_shape) - 2))
        else:
            raise NotImplementedError

        def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, alpha_neg, beta_neg, gamma_neg):
            if last_A is None:
                return None, None, 0
            if isinstance(last_A, eyeC):  # FIXME (12/28): Handle the OneHotC case.
                # FIXME: the previous implementation was incorrect;
                #        expanding eyeC for now.
                last_A = (torch.eye(last_A.shape[0], device=last_A.device)
                    .view(last_A.shape[0], 1, *last_A.shape[2:]).expand(last_A.shape))

            A_x, A_y = BoundLinear.propagate_A_xy(
                last_A, alpha_pos, alpha_neg, beta_pos, beta_neg, dim_y)

            if reduce_bias:
                # last_uA has size (batch, spec, output)
                # gamma_u has size (batch, output, 1)
                # ubias has size (batch, spec, 1)
                if self.opt_stage in ['opt', 'reuse']:
                    bias = (torch.einsum('sb...,sb...->sb',
                                        last_A.clamp(min=0), gamma_pos)
                            + 
torch.einsum('sb...,sb...->sb',\n                                        last_A.clamp(max=0), gamma_neg))\n                else:\n                    bias = (\n                        self.get_bias(last_A.clamp(min=0), gamma_pos)\n                        + self.get_bias(last_A.clamp(max=0), gamma_neg)\n                    )\n            else:\n                assert self.batch_dim == 0\n                assert self.opt_stage not in ['opt', 'reuse']\n                assert dim_y == [-3]\n                bias = (last_A.unsqueeze(-1).clamp(min=0) * gamma_pos\n                        + last_A.unsqueeze(-1).clamp(max=0) * gamma_neg)\n                bias_x = bias.sum(dim=-2)\n                bias_y = bias.sum(dim=-3)\n                bias = (bias_x, bias_y)\n            return A_x, A_y, bias\n\n        if self.opt_stage in ['opt', 'reuse']:\n            lA_x, lA_y, lbias = _bound_oneside(\n                last_lA, alpha_l[0], beta_l[0], gamma_l[0],\n                alpha_u[0], beta_u[0], gamma_u[0])\n            uA_x, uA_y, ubias = _bound_oneside(\n                last_uA, alpha_u[1], beta_u[1], gamma_u[1],\n                alpha_l[1], beta_l[1], gamma_l[1])\n        else:\n            lA_x, lA_y, lbias = _bound_oneside(\n                last_lA, alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u)\n            uA_x, uA_y, ubias = _bound_oneside(\n                last_uA, alpha_u, beta_u, gamma_u, alpha_l, beta_l, gamma_l)\n\n        return [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias\n\n    @staticmethod\n    def _propagate_Linf(x, w):\n        h_L, h_U = x\n        mid = (h_L + h_U) / 2\n        diff = (h_U - h_L) / 2\n        w_abs = w.abs()\n        if mid.ndim == 2 and w.ndim == 3:\n            center = torch.bmm(mid.unsqueeze(1), w.transpose(-1, -2)).squeeze(1)\n            deviation = torch.bmm(diff.unsqueeze(1), w_abs.transpose(-1, -2)).squeeze(1)\n        else:\n            center = mid.matmul(w.transpose(-1, -2))\n            deviation = 
diff.matmul(w_abs.transpose(-1, -2))
        return center, deviation

    def interval_propagate(self, *v, C=None, w=None):
        has_bias = self is not None and len(v) == 3
        if self is not None:
            # This will convert an Interval object to a tuple.
            # We need to add the perturbation property later.
            v_lb, v_ub = zip(*v)
            v_lb = self._preprocess(*v_lb)
            v_ub = self._preprocess(*v_ub)
            # After preprocessing the lower and upper bounds, we make them Intervals again.
            v = [Interval.make_interval(bounds[0], bounds[1], bounds[2])
                 for bounds in zip(v_lb, v_ub, v)]
        if w is None and self is None:
            # Use C as the weight, no bias.
            w, lb, ub = C, torch.tensor(0., device=C.device), torch.tensor(0., device=C.device)
        else:
            if w is None:
                # No specified weight, use this layer's weight.
                if self.is_input_perturbed(1):  # input index 1 is weight.
                    # w is a perturbed tensor. 
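# `_propagate_Linf` above computes the exact interval image of a box under a
# linear map: center W @ mid and radius |W| @ diff, where mid and diff are the
# box center and half-width. A NumPy sketch of the same computation
# (illustrative only; `propagate_linf` is a hypothetical stand-alone version):

```python
import numpy as np

def propagate_linf(h_L, h_U, W):
    # Exact interval image of the box [h_L, h_U] under x -> W @ x:
    # the center maps through W, the radius through |W|.
    mid = (h_L + h_U) / 2
    diff = (h_U - h_L) / 2
    return W @ mid, np.abs(W) @ diff  # center, deviation

rng = np.random.default_rng(3)
W = rng.standard_normal((4, 3))
h_L = rng.standard_normal(3)
h_U = h_L + rng.uniform(0.0, 1.0, 3)
center, dev = propagate_linf(h_L, h_U, W)

# Every point of the box maps inside [center - dev, center + dev].
x = rng.uniform(h_L, h_U)
y = W @ x
assert np.all(center - dev <= y + 1e-9) and np.all(y <= center + dev + 1e-9)
```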
Use IBP with weight perturbation.\n                    # C matrix merging not supported.\n                    assert C is None\n                    res = self.interval_propagate_with_weight(*v)\n                    l, u = res\n                    if has_bias:\n                        return l + v[2][0], u + v[2][1]\n                    else:\n                        return l, u\n                else:\n                    # Use weight\n                    w = v[1][0]\n            if has_bias:\n                lb, ub = v[2]\n            else:\n                lb = ub = 0.0\n\n            if C is not None:\n                w = C.matmul(w)\n                lb = C.matmul(lb) if not isinstance(lb, float) else lb\n                ub = C.matmul(ub) if not isinstance(ub, float) else ub\n\n        # interval_propagate() of the Linear layer may encounter input with different norms.\n        norm, eps = Interval.get_perturbation(v[0])[:2]\n        if norm == torch.inf:\n            interval = BoundLinear._propagate_Linf(v[0], w)\n            center, deviation = interval\n        elif norm > 0:\n            # General Lp norm.\n            norm, eps = Interval.get_perturbation(v[0])\n            mid = v[0][0]\n            dual_norm = np.float64(1.0) / (1 - 1.0 / norm)\n            if w.ndim == 3:\n                # Extra batch dimension.\n                # mid has dimension [batch, input], w has dimension [batch, output, input].\n                center = w.matmul(mid.unsqueeze(-1)).squeeze(-1)\n            else:\n                # mid has dimension [batch, input], w has dimension [output, input].\n                center = mid.matmul(w.t())\n            deviation = w.norm(dual_norm, dim=-1) * eps\n        else:\n            # here we calculate the L0 norm IBP bound of Linear layers,\n            # using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020]\n            norm, eps, ratio = Interval.get_perturbation(v[0])\n            mid = v[0][0]\n          
  weight_abs = w.abs()\n            if w.ndim == 3:\n                # Extra batch dimension.\n                # mid has dimension [batch, input], w has dimension [batch, output, input].\n                center = w.matmul(mid.unsqueeze(-1)).squeeze(-1)\n            else:\n                # mid has dimension [batch, input], w has dimension [output, input].\n                center = mid.matmul(w.t())\n            # L0 norm perturbation\n            k = int(eps)\n            deviation = torch.sum(torch.topk(weight_abs, k)[0], dim=1) * ratio\n\n        lower, upper = center - deviation + lb, center + deviation + ub\n\n        return (lower, upper)\n\n    def interval_propagate_with_weight(self, *v):\n        input_norm, input_eps = Interval.get_perturbation(v[0])\n        weight_norm, weight_eps = Interval.get_perturbation(v[1])\n\n        if input_norm == torch.inf and weight_norm == torch.inf:\n            # A memory-efficient implementation without expanding all the elementary multiplications\n            if self.opt_matmul == 'economic':\n                x_l, x_u = v[0][0], v[0][1]\n                y_l, y_u = v[1][0].transpose(-1, -2), v[1][1].transpose(-1, -2)\n\n                dx, dy = F.relu(x_u - x_l), F.relu(y_u - y_l)\n                base = x_l.matmul(y_l)\n\n                mask_xp, mask_xn = (x_l > 0).to(x_l.dtype), (x_u < 0).to(x_u.dtype)\n                mask_xpn = 1 - mask_xp - mask_xn\n                mask_yp, mask_yn = (y_l > 0).to(y_l.dtype), (y_u < 0).to(y_u.dtype)\n                mask_ypn = 1 - mask_yp - mask_yn\n\n                lower, upper = base.clone(), base.clone()\n\n                lower += dx.matmul(y_l.clamp(max=0)) - (dx * mask_xn).matmul(y_l * mask_ypn)\n                upper += dx.matmul(y_l.clamp(min=0)) + (dx * mask_xp).matmul(y_l * mask_ypn)\n\n                lower += x_l.clamp(max=0).matmul(dy) - (x_l * mask_xpn).matmul(dy * mask_yn)\n                upper += x_l.clamp(min=0).matmul(dy) + (x_l * mask_xpn).matmul(dy * 
mask_yp)

                lower += (dx * mask_xn).matmul(dy * mask_yn)
                upper += (dx * (mask_xpn + mask_xp)).matmul(dy * (mask_ypn + mask_yp))
            else:
                # Both input data and weight are Linf perturbed (with upper and lower bounds).
                # We need an x_l, x_u for each row of the weight matrix.
                x_l, x_u = v[0][0].unsqueeze(-2), v[0][1].unsqueeze(-2)
                y_l, y_u = v[1][0].unsqueeze(-3), v[1][1].unsqueeze(-3)
                # Reuse the multiplication bounds and sum over results.
                lower, upper = BoundMul.interval_propagate_both_perturbed(*[(x_l, x_u), (y_l, y_u)])
                lower, upper = torch.sum(lower, -1), torch.sum(upper, -1)

            return lower, upper
        elif input_norm == torch.inf and weight_norm == 2:
            # This eps is actually the epsilon per row, as only one row is involved for each output element.
            eps = weight_eps
            # Input data is Linf perturbed (with upper and lower bounds); weight is L2 perturbed per row.
            h_L, h_U = v[0]
            # First, handle non-perturbed weight with Linf perturbed data.
            center, deviation = BoundLinear._propagate_Linf(v[0], v[1][0])
            # Compute the maximal L2 norm of data. 
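# The widening below adds eps times the largest L2 norm the input can attain
# over its box: by Cauchy-Schwarz, a per-row weight perturbation delta with
# ||delta||_2 <= eps changes each output by at most eps * ||x||_2, and ||x||_2
# is maximized coordinate-wise at max(|h_L|, |h_U|). A NumPy sketch
# (illustrative only; `l2_weight_margin` is a hypothetical name):

```python
import numpy as np

def l2_weight_margin(h_L, h_U, eps):
    # Largest L2 norm over the box, attained coordinate-wise at the
    # larger-magnitude endpoint, scaled by the per-row weight budget eps.
    max_l2 = np.linalg.norm(np.maximum(np.abs(h_L), np.abs(h_U)))
    return max_l2 * eps

h_L = np.array([-1.0, 0.5])
h_U = np.array([0.5, 2.0])
eps = 0.1
margin = l2_weight_margin(h_L, h_U, eps)

# Any admissible input and any row perturbation of norm <= eps stays within it.
rng = np.random.default_rng(4)
x = rng.uniform(h_L, h_U)
delta = rng.standard_normal(2)
delta *= eps / np.linalg.norm(delta)
assert abs(delta @ x) <= margin + 1e-9
```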
Size is [batch, 1].\n            max_l2 = torch.max(h_L.abs(), h_U.abs()).norm(2, dim=-1).unsqueeze(-1)\n            # Add the L2 eps to bounds.\n            lb, ub = center - deviation - max_l2 * eps, center + deviation + max_l2 * eps\n            return lb, ub\n        else:\n            raise NotImplementedError(\n                \"Unsupported perturbation combination: data={}, weight={}\".format(input_norm, weight_norm))\n\n    @staticmethod\n    @torch.jit.script\n    def bound_forward_mul(x_lw: Tensor, x_lb: Tensor, x_uw: Tensor, x_ub: Tensor,\n                          w: Tensor, weight_has_batch: bool = False, swap_x_and_weight: bool = False):\n        w_pos = w.clamp(min=0)\n        w_neg = w.clamp(max=0)\n        if swap_x_and_weight:\n            lw = matmul_maybe_batched(w_pos, x_lw, weight_has_batch) + matmul_maybe_batched(w_neg, x_uw, weight_has_batch)\n            uw = matmul_maybe_batched(w_pos, x_uw, weight_has_batch) + matmul_maybe_batched(w_neg, x_lw, weight_has_batch)\n            lb = matmul_maybe_batched(w_pos, x_lb, weight_has_batch) + matmul_maybe_batched(w_neg, x_ub, weight_has_batch)\n            ub = matmul_maybe_batched(w_pos, x_ub, weight_has_batch) + matmul_maybe_batched(w_neg, x_lb, weight_has_batch)\n        else:\n            lw = matmul_maybe_batched(x_lw, w_pos, weight_has_batch) + matmul_maybe_batched(x_uw, w_neg, weight_has_batch)\n            uw = matmul_maybe_batched(x_uw, w_pos, weight_has_batch) + matmul_maybe_batched(x_lw, w_neg, weight_has_batch)\n            lb = matmul_maybe_batched(x_lb, w_pos, weight_has_batch) + matmul_maybe_batched(x_ub, w_neg, weight_has_batch)\n            ub = matmul_maybe_batched(x_ub, w_pos, weight_has_batch) + matmul_maybe_batched(x_lb, w_neg, weight_has_batch)\n        return lw, lb, uw, ub\n\n    # w: an optional argument which can be utilized by BoundMatMul\n    def bound_dynamic_forward(self, x, w=None, b=None, C=None, max_dim=None, offset=0):\n        assert not self.transA and 
self.alpha_linear == 1.0 and self.transB and self.beta_linear == 1.0\n        assert not self.is_input_perturbed(1)\n        assert not self.is_input_perturbed(2)\n\n        weight = w.lb\n        bias = b.lb if b is not None else None\n        if C is not None:\n            weight = C.to(weight).matmul(weight).transpose(-1, -2)\n            if bias is not None:\n                bias = C.to(bias).matmul(bias)\n            lb = x.lb.unsqueeze(1)\n        else:\n            weight = weight.transpose(-1, -2)\n            lb = x.lb\n\n        w_new = x.lw.matmul(weight)\n        b_new = lb.matmul(weight)\n        if C is not None:\n            b_new = b_new.squeeze(1)\n        if bias is not None:\n            b_new += bias\n\n        return LinearBound(w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim)\n\n    # w: an optional argument which can be utilized by BoundMatMul\n    def bound_forward(self, dim_in, x, w=None, b=None, C=None, weight_has_batch=False):\n        has_bias = b is not None\n        #FIXME _preprocess can only be applied to tensors so far but not linear bounds.\n        x, w, b = self._preprocess(x, w, b)\n\n        # Shape of x: (B, s_k, s_{k-1}, ..., s_1, m, n)\n        # Shape of w: (s_l, s_{l-1}, ..., s_1, p, n) or (B, s_k, s_{k-1}, ..., s_1, p, n) if weight_has_batch\n        # Forward pass: (B, s_k, s_{k-1}, ..., s_1, m, n) @ (s_l, s_{l-1}, ..., s_1, p, n)^T\n        # Here, the transpose of w means transposing the last two dimensions of w.\n\n        # Case #1: No weight/bias perturbation, only perturbation on input.\n        if ((not self.is_input_perturbed(0) or not self.is_input_perturbed(1)) and\n            (not has_bias or not self.is_input_perturbed(2))):\n            if isinstance(w, LinearBound):\n                w = w.lower\n            if isinstance(b, LinearBound):\n                b = b.lower\n            if C is not None:\n                w = C.to(w).matmul(w).transpose(-1, -2)\n                if b is not 
None:\n                    b = C.to(b).matmul(b)\n                x_lb, x_ub = x.lb.unsqueeze(1), x.ub.unsqueeze(1)\n            else:\n                w = w.transpose(-1, -2)\n                x_lb, x_ub = x.lb, x.ub\n            lw, lb, uw, ub = BoundLinear.bound_forward_mul(\n                x.lw, x_lb, x.uw, x_ub, w, weight_has_batch,\n                swap_x_and_weight=self.is_input_perturbed(1))\n\n            if C is not None:\n                lb, ub = lb.squeeze(1), ub.squeeze(1)\n\n            if b is not None:\n                lb += b\n                ub += b\n        # Case #2: weight is perturbed. bias may or may not be perturbed.\n        elif self.is_input_perturbed(1):\n            if C is not None:\n                raise NotImplementedError\n            res = self.bound_forward_with_weight(dim_in, x, w)\n            if has_bias:\n                raise NotImplementedError\n            lw, lb, uw, ub = res.lw, res.lb, res.uw, res.ub\n        # Case 3: Only bias is perturbed, weight is not perturbed.\n        elif not self.is_input_perturbed(1) and has_bias and self.is_input_perturbed(2):\n            raise NotImplementedError\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def bound_forward_with_weight(self, dim_in, x, y):\n        # x has shape (B, s_k, s_{k-1}, ..., s_1, m, n)\n        # y has shape (B, s_k, s_{k-1}, ..., s_1, p, n)\n        # We need to reshape x and y to (B, s_k, s_{k-1}, ..., s_1, m, 1, n)\n        #                           and (B, s_k, s_{k-1}, ..., s_1, 1, p, n)\n        # respectively.\n        # Then we can use the bound_forward_mul function to compute the bounds\n        # for element-wise multiplication and sum over the last dimension.\n        # The result will have shape (B, s_k, s_{k-1}, ..., s_1, m, p)\n        x_unsqueeze = LinearBound(\n            x.lw.unsqueeze(-2), x.lb.unsqueeze(-2),\n            x.uw.unsqueeze(-2), x.ub.unsqueeze(-2),\n            x.lower.unsqueeze(-2), x.upper.unsqueeze(-2),\n        )\n    
    y_unsqueeze = LinearBound(\n            y.lw.unsqueeze(-3), y.lb.unsqueeze(-3),\n            y.uw.unsqueeze(-3), y.ub.unsqueeze(-3),\n            y.lower.unsqueeze(-3), y.upper.unsqueeze(-3),\n        )\n        res_mul = BoundMul.bound_forward_both_perturbed(self, dim_in, x_unsqueeze, y_unsqueeze)\n        return LinearBound(\n            res_mul.lw.sum(dim=-1) if res_mul.lw is not None else None,\n            res_mul.lb.sum(dim=-1),\n            res_mul.uw.sum(dim=-1) if res_mul.uw is not None else None,\n            res_mul.ub.sum(dim=-1)\n        )\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        has_bias = self is not None and len(v) == 3\n        # Aggregate a batch of bounds by taking minimum/maximum over the batch dimension.\n        out_lbs = self.lower.min(dim=0).values.detach().cpu().numpy() if self.lower is not None else None\n        out_ubs = self.upper.max(dim=0).values.detach().cpu().numpy() if self.upper is not None else None\n\n        # current layer weight (out_width, in_width)\n        this_layer_weight = v[1]\n        if self.transB == 0:\n            this_layer_weight = this_layer_weight.transpose(1, 0)\n        #### make sure if this is correct for per-label operations\n        if C is not None:\n            # merge specification C into last layer weights\n            # only last layer has C not None\n            this_layer_weight = C.squeeze(0).mm(this_layer_weight)\n        this_layer_weight = this_layer_weight.detach().cpu().numpy()\n        this_layer_shape = this_layer_weight.shape\n\n        this_layer_bias = None\n        if has_bias:\n            # current layer bias (out_width,)\n            this_layer_bias = v[2]\n            if C is not None:\n                this_layer_bias = C.squeeze(0).mm(this_layer_bias.unsqueeze(-1)).view(-1)\n            this_layer_bias = this_layer_bias.detach().cpu().numpy()\n\n        new_layer_gurobi_vars = []\n\n        for neuron_idx in 
range(this_layer_shape[0]):
            out_lb = out_lbs[neuron_idx] if out_lbs is not None else -float('inf')
            out_ub = out_ubs[neuron_idx] if out_ubs is not None else float('inf')
            if out_lbs is not None and out_ubs is not None:
                """
                    If the inferred lb and ub are too close, floating-point disagreement may arise
                    between the solver's inferred lb/ub constraints and the ones computed by ab-crown.
                    Such disagreement can make the solver report "infeasible" on a feasible problem,
                    and can also make lb exceed ub due to floating-point error.
                    To avoid this, we relax the box constraints slightly.
                    This does not affect the correctness of the solver's result,
                    since the solver can still infer the tighter lb and ub.
                """
                if out_lb != float('-inf') and out_ub != float('inf'):
                    diff = out_ub - out_lb
                    avg = (out_ub + out_lb) / 2.0
                    condition = (diff < EPS)
                    out_lb = np.where(condition, avg - EPS / 2.0, out_lb)
                    out_ub = np.where(condition, avg + EPS / 2.0, out_ub)
            lin_expr = 0
            if has_bias:
                lin_expr = this_layer_bias[neuron_idx].item()
            coeffs = this_layer_weight[neuron_idx, :]

            if solver_pkg == 'gurobi':
                lin_expr += grb.LinExpr(coeffs, v[0])
            else:
                # FIXME (01/12/22): This is slow, must be fixed using addRow() or similar.
                for i in range(len(coeffs)):
                    try:
                        lin_expr += coeffs[i] * v[0][i]
                    except TypeError:
                        lin_expr += coeffs[i] * v[0][i].var

            var = model.addVar(lb=out_lb, ub=out_ub, obj=0,
                         
           vtype=grb.GRB.CONTINUOUS,\n                                    name=f'lay{self.name}_{neuron_idx}')\n            model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')\n            new_layer_gurobi_vars.append(var)\n\n        self.solver_vars = new_layer_gurobi_vars\n        model.update()\n\n    def build_gradient_node(self, grad_upstream):\n        if not self.is_input_perturbed(1):\n            if isinstance(self.inputs[1], BoundParams):\n                w = self.inputs[1].param\n            elif isinstance(self.inputs[1], BoundBuffers):\n                w = self.inputs[1].buffer\n            else:\n                w = self.inputs[1].value\n            if not self.transB:\n                w = w.t()\n            node_grad = LinearGrad(w.detach())\n            return [(node_grad, (grad_upstream,), [])]\n        else:\n            raise NotImplementedError(\n                \"Gradient computation for weight perturbation is not supported yet.\")\n\n    def update_requires_input_bounds(self):\n        self._check_weight_perturbation()\n\n\nclass BoundMatMul(BoundLinear):\n    # Reuse most functions from BoundLinear.\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.transA = 0\n        self.transB = 0\n        self.splittable = True\n\n    def forward(self, x, y):\n        self.x_shape = x.shape\n        self.y_shape = y.shape\n        return x.matmul(y)\n\n    def interval_propagate(self, *v, C=None):\n        lower, upper = super().interval_propagate(*v, C=C)\n        return lower, upper\n\n    def bound_backward(self, last_lA, last_uA, *x, start_node=None, **kwargs):\n        assert len(x) == 2\n        # Determine if two inputs should be swapped\n        self.swap_x_and_weight = not self.is_input_perturbed(0) and self.is_input_perturbed(1)\n        idx_weight = 0 if self.swap_x_and_weight else 1\n        if start_node is not 
None:\n            self._start = start_node.name\n        results = list(super().bound_backward(last_lA, last_uA, *x, **kwargs))\n        # Transpose weight-related tensors\n        def transpose_weight(A_weight):\n            return A_weight.transpose(-1, -2) if A_weight is not None else None\n        results[0][idx_weight] = (transpose_weight(results[0][idx_weight][0]),\n                                  transpose_weight(results[0][idx_weight][1]))\n        if isinstance(results[1], tuple):\n            lbias = (results[1][0], results[1][1].transpose(-1, -2))\n        else:\n            lbias = results[1]\n        if isinstance(results[2], tuple):\n            ubias = (results[2][0], results[2][1].transpose(-1, -2))\n        else:\n            ubias = results[2]\n        # Reduce the broadcast dimensions\n        lA_x = self.broadcast_backward(results[0][0][0], x[0])\n        uA_x = self.broadcast_backward(results[0][0][1], x[0])\n        lA_y = self.broadcast_backward(results[0][1][0], x[1])\n        uA_y = self.broadcast_backward(results[0][1][1], x[1])\n        return [(lA_x, uA_x), (lA_y, uA_y), results[0][2]], lbias, ubias\n\n    def bound_forward(self, dim_in, x, y):\n        def _bound_forward(x, y, weight_index=1):\n            # We assume that x is perturbed and y is not perturbed (weight).\n            weight_has_batch = (self.inputs[weight_index].batch_dim != -1)\n            return super(BoundMatMul, self).bound_forward(dim_in, x, LinearBound(\n                y.lw.transpose(-1, -2) if y.lw is not None else None,\n                y.lb.transpose(-1, -2) if y.lb is not None else None,\n                y.uw.transpose(-1, -2) if y.uw is not None else None,\n                y.ub.transpose(-1, -2) if y.ub is not None else None,\n                y.lower.transpose(-1, -2) if y.lower is not None else None,\n                y.upper.transpose(-1, -2) if y.upper is not None else None\n            ), weight_has_batch=weight_has_batch)\n        \n        # Check if 
we need to swap x and y\n        if not self.is_input_perturbed(0) and self.is_input_perturbed(1):\n            return _bound_forward(y, x, weight_index=0)\n        else:\n            return _bound_forward(x, y, weight_index=1)\n\n    def update_requires_input_bounds(self):\n        # If any multiplier is a constant, we do not need input bounds.\n        self.is_linear_op = not self.inputs[1].perturbed or not self.inputs[0].perturbed\n        if self.is_linear_op:\n            # One input is constant; no bounds required.\n            self.requires_input_bounds = []\n            self.splittable = False\n        else:\n            # Both inputs are perturbed. Need relaxation.\n            self.requires_input_bounds = [0, 1]\n            if not self.force_not_splittable:\n                self.splittable = True\n\n\nclass BoundNeg(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.ibp_intermediate = True\n\n    def forward(self, x):\n        return -x\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        if type(last_lA) == Tensor or type(last_uA) == Tensor:\n            return [(-last_lA if last_lA is not None else None,\n                 -last_uA if last_uA is not None else None)], 0, 0\n        elif type(last_lA) == Patches or type(last_uA) == Patches:\n            if last_lA is not None:\n                lA = Patches(-last_lA.patches, last_lA.stride, last_lA.padding,\n                             last_lA.shape, unstable_idx=last_lA.unstable_idx,\n                             output_shape=last_lA.output_shape)\n            else:\n                lA = None\n\n            if last_uA is not None:\n                uA = Patches(-last_uA.patches, last_uA.stride, last_uA.padding,\n                             last_uA.shape, unstable_idx=last_uA.unstable_idx,\n                             output_shape=last_uA.output_shape)\n            else:\n  
              uA = None\n            return [(lA, uA)], 0, 0\n        else:\n            raise NotImplementedError\n\n    def bound_forward(self, dim_in, x):\n        return LinearBound(-x.uw, -x.ub, -x.lw, -x.lb)\n\n    def interval_propagate(self, *v):\n        return -v[0][1], -v[0][0]\n\n    def build_gradient_node(self, grad_upstream):\n        return [(NegGrad(), (grad_upstream,), [])]\n\n\nclass NegGrad(Module):\n    def forward(self, grad_last):\n        return -grad_last\n\n\nclass BoundCumSum(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n\n    def forward(self, x, axis):\n        self.axis = axis\n        return torch.cumsum(x, axis)\n\nclass BoundIdentity(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n\n    def forward(self, x):\n        return x\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        return [(last_lA, last_uA)], 0, 0\n\n    def bound_forward(self, dim_in, x):\n        return x\n\n\nclass LinearGrad(Module):\n    def __init__(self, weight):\n        super().__init__()\n        self.weight = weight\n\n    def forward(self, grad_last):\n        weight = self.weight.to(grad_last).t()\n        return F.linear(grad_last, weight)\n\n\nclass MatMulGrad(Module):\n    def forward(self, grad_last, x):\n        return grad_last.matmul(x.transpose(-1, -2))\n"
  },
  {
    "path": "auto_LiRPA/operators/logical.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Logical operators\"\"\"\nfrom .base import *\n\n\nclass BoundWhere(Bound):\n    def forward(self, condition, x, y):\n        return torch.where(condition.to(torch.bool), x, y)\n\n    def interval_propagate(self, *v):\n        assert not self.is_input_perturbed(0)\n        condition = v[0][0]\n        return tuple([torch.where(condition, v[1][j], v[2][j]) for j in range(2)])\n\n    def bound_backward(self, last_lA, last_uA, condition, x, y, **kwargs):\n        assert torch.allclose(condition.lower.float(), condition.upper.float())\n        assert self.from_input\n        mask = condition.lower.float()\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None, None\n            assert last_A.ndim > 1\n            A_x = self.broadcast_backward(mask.unsqueeze(0) * last_A, x)\n            A_y = self.broadcast_backward((1 - mask).unsqueeze(0) * last_A, y)\n            return A_x, A_y\n\n        lA_x, lA_y = _bound_oneside(last_lA)\n        uA_x, uA_y = _bound_oneside(last_uA)\n\n        return [(None, None), (lA_x, uA_x), (lA_y, uA_y)], 0, 0\n\nclass BoundNot(Bound):\n    def forward(self, x):\n        return x.logical_not()\n\n\nclass BoundEqual(Bound):\n    def forward(self, x, y):\n        return x == y\n"
  },
  {
    "path": "auto_LiRPA/operators/minmax.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport torch\nfrom .base import *\nfrom .clampmult import multiply_by_A_signs\nfrom .activation_base import BoundOptimizableActivation\n\n\nclass BoundMinMax(BoundOptimizableActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.options = options\n        self.requires_input_bounds = [0, 1]\n        self.op = None\n\n    def _init_opt_parameters_impl(self, size_spec, name_start):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l = self.inputs[0].lower\n        # Alpha dimension is (2, output_shape, batch, *shape).\n        shape = [2, size_spec] + list(l.shape)\n        return torch.ones(shape, device=l.device)\n\n    def clip_alpha(self):\n        # See https://www.overleaf.com/read/jzgrcmqtqpcx#9dbf97 for the math behind this code.\n        lb_x = self._cached_lb_x\n        ub_x = self._cached_ub_x\n        lb_y = self._cached_lb_y\n        ub_y = self._cached_ub_y\n\n        for v in self.alpha.values():\n            eps = torch.tensor(1e-6).to(lb_x.dtype)\n            if self.op == 'max':\n                # Case 1: l_x >= u_y\n                case1 = (lb_x >= ub_y).requires_grad_(False).to(lb_x.dtype)\n                alpha_u_lb = torch.zeros_like(case1)\n                alpha_u_ub = torch.zeros_like(case1)\n                alpha_l_lb = torch.zeros_like(case1)\n                alpha_l_ub = torch.zeros_like(case1)\n\n                # Case 2: l_x < u_y && u_x > u_y\n                case2 = ((lb_x < ub_y) * (ub_x > ub_y)).requires_grad_(False).to(lb_x.dtype)\n                alpha_u_lb += case2 * (ub_x - ub_y) / (ub_x - torch.maximum(lb_x, lb_y))\n                alpha_u_ub += case2\n                alpha_l_ub += case2\n\n                # Case 3: l_x < u_y && 
u_x == u_y\n                case3 = ((lb_x < ub_y) * (ub_x == ub_y)).requires_grad_(False).to(lb_x.dtype)\n                alpha_u_ub += case3\n                alpha_l_ub += case3\n\n                alpha_u_lb = torch.clamp(alpha_u_lb, min=eps)\n                alpha_u_ub = torch.clamp(alpha_u_ub, min=eps)\n            elif self.op == 'min':\n                # Case 1: l_y >= u_x\n                case1 = (lb_y >= ub_x).requires_grad_(False).to(lb_x.dtype)\n                alpha_u_lb = torch.zeros_like(case1)\n                alpha_u_ub = torch.zeros_like(case1)\n                alpha_l_lb = torch.zeros_like(case1)\n                alpha_l_ub = torch.zeros_like(case1)\n\n                # Case 2: l_y < u_x && l_y > l_x\n                case2 = ((lb_y < ub_x) * (lb_y > lb_x)).requires_grad_(False).to(lb_x.dtype)\n                alpha_u_ub += case2\n                alpha_l_lb += case2 * (lb_y - lb_x) / (torch.minimum(ub_x, ub_y) - lb_x)\n                alpha_l_ub += case2\n\n                # Case 3: l_y < u_x && l_y == l_x\n                case3 = ((lb_y < ub_x) * (lb_y == lb_x)).requires_grad_(False).to(lb_x.dtype)\n                alpha_u_ub += case3\n                alpha_l_ub += case3\n\n                alpha_l_lb = torch.clamp(alpha_l_lb, min=eps)\n                alpha_l_ub = torch.clamp(alpha_l_ub, min=eps)\n\n            v.data[0] = torch.clamp(v.data[0], alpha_u_lb, alpha_u_ub)\n            v.data[1] = torch.clamp(v.data[1], alpha_l_lb, alpha_l_ub)\n\n    def forward(self, x, y):\n        if self.op == 'max':\n            return torch.max(x, y)\n        elif self.op == 'min':\n            return torch.min(x, y)\n        else:\n            raise NotImplementedError\n\n    def _backward_relaxation(self, x, y, start_node=None):\n        # See https://www.overleaf.com/read/jzgrcmqtqpcx#9dbf97 for the math behind this code.\n\n        lb_x = x.lower\n        ub_x = x.upper\n        lb_y = y.lower\n        ub_y = y.upper\n\n        if self.opt_stage in ['opt', 
'reuse']:\n            selected_alpha = self.alpha[start_node.name]\n            alpha_u = selected_alpha[0]\n            alpha_l = selected_alpha[1]\n        else:\n            alpha_u = alpha_l = 1\n\n        ub_x = ub_x.unsqueeze(0)\n        ub_y = ub_y.unsqueeze(0)\n        lb_x = lb_x.unsqueeze(0)\n        lb_y = lb_y.unsqueeze(0)\n\n        if self.op == 'max':\n            swapped_inputs = ub_x < ub_y\n        elif self.op == 'min':\n            swapped_inputs = lb_y < lb_x\n        else:\n            raise NotImplementedError\n        lb_x, lb_y = torch.where(swapped_inputs, lb_y, lb_x), torch.where(swapped_inputs, lb_x, lb_y)\n        ub_x, ub_y = torch.where(swapped_inputs, ub_y, ub_x), torch.where(swapped_inputs, ub_x, ub_y)\n\n        self._cached_lb_x = lb_x.detach()\n        self._cached_ub_x = ub_x.detach()\n        self._cached_lb_y = lb_y.detach()\n        self._cached_ub_y = ub_y.detach()\n\n        epsilon = 1e-6\n        ub_x = torch.max(ub_x, lb_x + epsilon)\n        ub_y = torch.max(ub_y, lb_y + epsilon)\n        # Ideally, if x or y are constant, this layer should be replaced by a ReLU\n        # max{x, c} = max{x − c, 0} + c\n        # min{x, c} = −max{−x, −c} = −(max{−x + c, 0} − c) = −max{−x + c, 0} + c\n        if torch.any(lb_x + 1e-4 >= ub_x) or torch.any(lb_y + 1e-4 >= ub_y):\n            print(\"Warning: MinMax layer (often used for clamping) received at \"\n                  \"least one input with lower bound almost equal to the upper \"\n                  \"bound. This can happen e.g. if x or y are constants. Consider \"\n                  \"replacing this layer with a ReLU for higher efficiency.\")\n        assert torch.all(ub_x != lb_x) and torch.all(ub_y != lb_y), (\n            'Lower/upper bounds are too close and epsilon was rounded away. 
'\n            'To fix this, increase epsilon.'\n        )\n\n        if isinstance(alpha_u, torch.Tensor):\n            assert alpha_u.shape[1:] == ub_x.shape[1:]\n            shape = alpha_u.shape\n        else:\n            shape = ub_x.shape\n        upper_dx = torch.zeros(shape, device=ub_x.device)\n        upper_dy = torch.zeros(shape, device=ub_x.device)\n        lower_dx = torch.zeros(shape, device=ub_x.device)\n        lower_dy = torch.zeros(shape, device=ub_x.device)\n        upper_b = torch.zeros(shape, device=ub_x.device)\n        lower_b = torch.zeros(shape, device=ub_x.device)\n        if self.op == 'max':\n            # Case 1: l_x >= u_y\n            case1 = (lb_x >= ub_y).requires_grad_(False).to(lb_x.dtype)\n            upper_dx += case1\n            lower_dx += case1\n\n            # Case 2: l_x < u_y && u_x > u_y\n            case2 = ((lb_x < ub_y) * (ub_x > ub_y)).requires_grad_(False).to(lb_x.dtype)\n            upper_dx = upper_dx + case2 * (ub_y - ub_x) / (alpha_u * (lb_x - ub_x))\n            upper_dy = upper_dy + case2 * (alpha_u - 1) * (ub_y - ub_x) / (alpha_u * (ub_y - lb_y))\n            upper_b = upper_b + case2 * (ub_x - (ub_x * (ub_y - ub_x)) / (alpha_u * (lb_x - ub_x))\n                                - ((alpha_u - 1) * (ub_y - ub_x) * lb_y) / (alpha_u * (ub_y - lb_y)))\n            lower_dx = lower_dx + case2 * (1 - alpha_l)\n            lower_dy = lower_dy + case2 * alpha_l\n\n            # Case 3: l_x < u_y && u_x == u_y\n            case3 = ((lb_x < ub_y) * (ub_x == ub_y)).requires_grad_(False).to(lb_x.dtype)\n            upper_dx = upper_dx + case3 * alpha_u * (ub_x - torch.maximum(lb_x, lb_y)) / (ub_x - lb_x)\n            upper_dy = upper_dy + case3 * alpha_u * (ub_x - torch.maximum(lb_x, lb_y)) / (ub_y - lb_y)\n            upper_b = upper_b + case3 * (ub_x -\n                        (alpha_u * (ub_x - torch.maximum(lb_x, lb_y)) * lb_x) / (ub_x - lb_x) -\n                        (alpha_u * (ub_x - torch.maximum(lb_x, lb_y)) * 
ub_y) / (ub_y - lb_y))\n            lower_dx = lower_dx + case3 * (1 - alpha_l)\n            lower_dy = lower_dy + case3 * alpha_l\n        elif self.op == 'min':\n            # Case 1: l_y >= u_x\n            case1 = (lb_y >= ub_x).requires_grad_(False).to(lb_x.dtype)\n            upper_dx = case1.clone()\n            lower_dx = case1.clone()\n            upper_dy = torch.zeros_like(case1)\n            lower_dy = torch.zeros_like(case1)\n            upper_b = torch.zeros_like(case1)\n            lower_b = torch.zeros_like(case1)\n\n            # Case 2: l_y < u_x && l_y > l_x\n            case2 = ((lb_y < ub_x) * (lb_y > lb_x)).requires_grad_(False).to(lb_x.dtype)\n            upper_dx = upper_dx + case2 * (1 - alpha_u)\n            upper_dy = upper_dy + case2 * alpha_u\n            lower_dx = lower_dx + case2 * (lb_x - lb_y) / (alpha_l * (lb_x - ub_x))\n            lower_dy = lower_dy + case2 * (alpha_l - 1) * (lb_x - lb_y) / (alpha_l * (ub_y - lb_y))\n            lower_b = lower_b + case2 * (lb_y - (ub_x * (lb_x - lb_y)) / (alpha_l * (lb_x - ub_x))\n                                - ((alpha_l - 1) * (lb_x - lb_y) * lb_y) / (alpha_l * (ub_y - lb_y)))\n\n            # Case 3: l_y < u_x && l_y == l_x\n            case3 = ((lb_y < ub_x) * (lb_y == lb_x)).requires_grad_(False).to(lb_x.dtype)\n            upper_dx = upper_dx + case3 * (1 - alpha_u)\n            upper_dy = upper_dy + case3 * alpha_u\n            lower_dx = lower_dx + case3 * alpha_l * (torch.minimum(ub_x, ub_y) - lb_x) / (ub_x - lb_x)\n            lower_dy = lower_dy + case3 * alpha_l * (torch.minimum(ub_x, ub_y) - lb_x) / (ub_y - lb_y)\n            lower_b = lower_b + case3 * (lb_x -\n                        (alpha_l * (torch.minimum(ub_x, ub_y) - lb_x) * lb_x) / (ub_x - lb_x) -\n                        (alpha_l * (torch.minimum(ub_x, ub_y) - lb_x) * ub_y) / (ub_y - lb_y))\n        else:\n            raise NotImplementedError\n\n        lower_dx, lower_dy = torch.where(swapped_inputs, lower_dy, 
lower_dx), torch.where(swapped_inputs, lower_dx, lower_dy)\n        upper_dx, upper_dy = torch.where(swapped_inputs, upper_dy, upper_dx), torch.where(swapped_inputs, upper_dx, upper_dy)\n\n        return upper_dx, upper_dy, upper_b, lower_dx, lower_dy, lower_b\n\n    def bound_backward(self, last_lA, last_uA, x=None, y=None, start_shape=None,\n                       start_node=None, **kwargs):\n        # Get element-wise CROWN linear relaxations.\n        upper_dx, upper_dy, upper_b, lower_dx, lower_dy, lower_b = \\\n            self._backward_relaxation(x, y, start_node)\n\n        # Choose upper or lower bounds based on the sign of last_A\n        def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):\n            if last_A is None:\n                return None, 0\n            # Obtain the new linear relaxation coefficients based on the signs in last_A.\n            _A, _bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg)\n            if isinstance(last_A, Patches):\n                # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters.\n                A_prod = _A.patches\n                if start_node is not None:\n                    # Regular patches.\n                    self.patch_size[start_node.name] = A_prod.size()\n            return _A, _bias\n\n        # In patches mode we might need an unfold.\n        # lower_dx, lower_dy, upper_dx, upper_dy, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None\n        # In _backward_relaxation, the lb_x etc. potentially got swapped. This may cause the memory to become\n        # non-contiguous. This is not a problem if the spec_size is 1, e.g. 
if alphas are shared.\n        upper_dx = upper_dx.contiguous()\n        upper_dy = upper_dy.contiguous()\n        lower_dx = lower_dx.contiguous()\n        lower_dy = lower_dy.contiguous()\n        upper_b = upper_b.contiguous()\n        lower_b = lower_b.contiguous()\n\n\n        upper_dx = maybe_unfold_patches(upper_dx, last_lA if last_lA is not None else last_uA)\n        upper_dy = maybe_unfold_patches(upper_dy, last_lA if last_lA is not None else last_uA)\n        lower_dx = maybe_unfold_patches(lower_dx, last_lA if last_lA is not None else last_uA)\n        lower_dy = maybe_unfold_patches(lower_dy, last_lA if last_lA is not None else last_uA)\n        upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA)\n        lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA)\n\n        uAx, ubias = _bound_oneside(last_uA, upper_dx, lower_dx, upper_b, lower_b)\n        uAy, ubias2 = _bound_oneside(last_uA, upper_dy, lower_dy, upper_b, lower_b)\n        if isinstance(ubias, torch.Tensor):\n            assert isinstance(ubias2, torch.Tensor)\n            assert torch.all(ubias == ubias2)\n        else:\n            assert ubias == ubias2 == 0\n        lAx, lbias = _bound_oneside(last_lA, lower_dx, upper_dx, lower_b, upper_b)\n        lAy, lbias2 = _bound_oneside(last_lA, lower_dy, upper_dy, lower_b, upper_b)\n        if isinstance(lbias, torch.Tensor):\n            assert isinstance(lbias2, torch.Tensor)\n            assert torch.all(lbias == lbias2)\n        else:\n            assert lbias == lbias2 == 0\n\n        return [(lAx, uAx), (lAy, uAy)], lbias, ubias\n\n    def interval_propagate(self, *v):\n        h_Lx, h_Ux = v[0][0], v[0][1]\n        h_Ly, h_Uy = v[1][0], v[1][1]\n        return self.forward(h_Lx, h_Ly), self.forward(h_Ux, h_Uy)\n\n\nclass BoundMax(BoundMinMax):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.op = 'max'\n\n\nclass 
BoundMin(BoundMinMax):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.op = 'min'\n"
  },
  {
    "path": "auto_LiRPA/operators/normalization.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Normalization operators\"\"\"\nimport copy\n\nimport torch\nimport torch.nn as nn\n\nfrom .base import *\nfrom .constant import BoundConstant\nfrom .leaf import BoundParams\nfrom .solver_utils import grb\n\n\nclass BoundBatchNormalization(Bound):\n    def __init__(self, attr, inputs, output_index, options, training):\n        super().__init__(attr, inputs, output_index, options)\n        self.eps = attr['epsilon']\n        self.momentum = round(1 - attr['momentum'], 5)  # take care!\n        self.options = options.get(\"bn\", {})\n        # modes:\n        #   - forward: use mean and variance estimated from clean forward pass\n        #   - ibp: use mean and variance estimated from ibp\n        self.bn_mode = self.options.get(\"mode\", \"forward\")\n        self.use_mean = self.options.get(\"mean\", True)\n        self.use_var = self.options.get(\"var\", True)\n        self.use_affine = self.options.get(\"affine\", True)\n        self.training = training\n        self.patches_start = True\n        self.mode = options.get(\"conv_mode\", \"matrix\")\n        if not self.use_mean or not self.use_var:\n            logger.info(f'Batch normalization node {self.name}: use_mean {self.use_mean}, use_var {self.use_var}')\n\n    def _check_unused_mean_or_var(self):\n        # Check if either mean or var is opted out\n        if not self.use_mean:\n            self.current_mean = torch.zeros_like(self.current_mean)\n        if not self.use_var:\n            self.current_var = torch.ones_like(self.current_var)\n\n    def forward(self, x, w, b, m, v):\n        if len(x.shape) == 2:\n            self.patches_start = False\n        if self.training:\n            dim = [0] + list(range(2, x.ndim))\n            self.current_mean = x.mean(dim)\n            self.current_var = x.var(dim, unbiased=False)\n        else:\n         
   self.current_mean = m.data\n            self.current_var = v.data\n        self._check_unused_mean_or_var()\n        if not self.use_affine:\n            w = torch.ones_like(w)\n            b = torch.zeros_like(b)\n        result = F.batch_norm(x, m, v, w, b, self.training, self.momentum, self.eps)\n        if not self.use_mean or not self.use_var:\n            # If mean or variance is disabled, recompute the output from self.current_mean\n            # and self.current_var instead of using standard F.batch_norm.\n            w = w / torch.sqrt(self.current_var + self.eps)\n            b = b - self.current_mean * w\n            shape = (1, -1) + (1,) * (x.ndim - 2)\n            result = w.view(*shape) * x + b.view(*shape)\n        return result\n\n    def bound_forward(self, dim_in, *x):\n        inp = x[0]\n        assert (x[1].lower == x[1].upper).all(), \"unsupported forward bound with perturbed weight\"\n        assert (x[2].lower == x[2].upper).all(), \"unsupported forward bound with perturbed bias\"\n        weight, bias = x[1].lower, x[2].lower\n        if not self.training:\n            assert (x[3].lower == x[3].upper).all(), \"unsupported forward bound with perturbed mean\"\n            assert (x[4].lower == x[4].upper).all(), \"unsupported forward bound with perturbed var\"\n            self.current_mean = x[3].lower\n            self.current_var = x[4].lower\n        self._check_unused_mean_or_var()\n        if not self.use_affine:\n            weight = torch.ones_like(weight)\n            bias = torch.zeros_like(bias)\n\n        tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight\n        tmp_weight = weight / torch.sqrt(self.current_var + self.eps)\n\n        tmp_weight = tmp_weight.view(*((1, 1, -1) + (1,) * (inp.lw.ndim - 3)))\n        new_lw = torch.clamp(tmp_weight, min=0.) * inp.lw + torch.clamp(tmp_weight, max=0.) * inp.uw\n        new_uw = torch.clamp(tmp_weight, min=0.) 
* inp.uw + torch.clamp(tmp_weight, max=0.) * inp.lw\n\n        tmp_weight = tmp_weight.view(*((1, -1) + (1,) * (inp.lb.ndim - 2)))\n        tmp_bias = tmp_bias.view(*((1, -1) + (1,) * (inp.lb.ndim - 2)))\n        new_lb = torch.clamp(tmp_weight, min=0.) * inp.lb + torch.clamp(tmp_weight, max=0.) * inp.ub + tmp_bias\n        new_ub = torch.clamp(tmp_weight, min=0.) * inp.ub + torch.clamp(tmp_weight, max=0.) * inp.lb + tmp_bias\n\n        return LinearBound(\n            lw = new_lw,\n            lb = new_lb,\n            uw = new_uw,\n            ub = new_ub)\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        assert not self.is_input_perturbed(1) and not self.is_input_perturbed(2), \\\n            'Weight perturbation is not supported for BoundBatchNormalization'\n\n        def get_param(p):\n            if isinstance(p, BoundConstant):\n                # When affine is disabled in BN\n                return p.value\n            elif isinstance(p, BoundParams):\n                return p.param\n            else:\n                raise TypeError(p)\n\n        # x[0]: input, x[1]: weight, x[2]: bias, x[3]: running_mean, x[4]: running_var\n        weight = get_param(x[1])\n        bias = get_param(x[2])\n        if not self.training:\n            self.current_mean = x[3].value\n            self.current_var = x[4].value\n        self._check_unused_mean_or_var()\n        if not self.use_affine:\n            weight = torch.ones_like(weight)\n            bias = torch.zeros_like(bias)\n\n        tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight\n        tmp_weight = weight / torch.sqrt(self.current_var + self.eps)\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None, 0\n            if type(last_A) == Tensor:\n                next_A = last_A * tmp_weight.view(*((1, 1, -1) + (1,) * (last_A.ndim - 3)))\n                if last_A.ndim > 3:\n                    sum_bias = 
(last_A.sum(tuple(range(3, last_A.ndim))) * tmp_bias).sum(2)\n                else:\n                    sum_bias = (last_A * tmp_bias).sum(2)\n            elif type(last_A) == Patches:\n                # TODO Only 4-dim BN supported in the Patches mode\n                if last_A.identity == 0:\n                    # FIXME (09/17): Need to check whether padding has already been applied.\n                    # Patch has dimension (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w)\n                    patches = last_A.patches\n\n                    # tmp_weight has shape (c,), it will be applied on the (c,) dimension.\n                    patches = patches * tmp_weight.view(*([1] * (patches.ndim - 3)), -1, 1, 1)  # Match with sparse or non-sparse patches.\n                    next_A = last_A.create_similar(patches)\n\n                    # bias to size (c,), need expansion before unfold.\n                    bias = tmp_bias.view(-1,1,1).expand(self.input_shape[1:]).unsqueeze(0)\n                    # Unfolded bias has shape (1, out_h, out_w, in_c, H, W).\n                    bias_unfolded = inplace_unfold(bias, kernel_size=last_A.patches.shape[-2:], padding=last_A.padding, stride=last_A.stride,\n                            inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding)\n                    if last_A.unstable_idx is not None:\n                        # Sparse bias has shape (unstable_size, batch, in_c, H, W).\n                        bias_unfolded = bias_unfolded[:, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                        sum_bias = torch.einsum('bschw,sbchw->sb', bias_unfolded, last_A.patches)\n                        # Output sum_bias has shape (unstable_size, batch).\n                    else:\n                        # Patch has dimension (out_c, batch, out_h, out_w, c, h, w).\n                        sum_bias = torch.einsum('bijchw,sbijchw->sbij', bias_unfolded, last_A.patches)\n                        # 
Output sum_bias has shape (out_c, batch, out_h, out_w).\n                else:\n                    # we should create a real identity Patch\n                    num_channel = tmp_weight.numel()\n                    # desired Shape is (c, batch, out_w, out_h, c, 1, 1) or (unstable_size, batch, c, 1, 1).\n                    patches = (torch.eye(num_channel, device=tmp_weight.device) * tmp_weight.view(-1)).view(num_channel, 1, 1, 1, num_channel, 1, 1)\n                    # Expand out_h, out_w dimensions but not for batch dimension.\n                    patches = patches.expand(-1, -1, last_A.output_shape[2], last_A.output_shape[3], -1, 1, 1)\n                    if last_A.unstable_idx is not None:\n                        # Select based on unstable indices.\n                        patches = patches[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                    # Expand the batch dimension.\n                    patches = patches.expand(-1, last_A.shape[1], *([-1] * (patches.ndim - 2)))\n                    next_A = last_A.create_similar(patches, stride=1, padding=0, identity=0)\n                    if last_A.unstable_idx is not None:\n                        # Need to expand the bias and choose the selected ones.\n                        bias = tmp_bias.view(-1,1,1,1).expand(-1, 1, last_A.output_shape[2], last_A.output_shape[3])\n                        bias = bias[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                        # Expand the batch dimension, and final output shape is (unstable_size, batch).\n                        sum_bias = bias.expand(-1, last_A.shape[1])\n                    else:\n                        # Output sum_bias has shape (out_c, batch, out_h, out_w).\n                        sum_bias = tmp_bias.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4])\n            else:\n                raise NotImplementedError()\n            return next_A, sum_bias\n\n        lA, lbias = 
_bound_oneside(last_lA)\n        uA, ubias = _bound_oneside(last_uA)\n\n        return [(lA, uA), (None, None), (None, None), (None, None), (None, None)], lbias, ubias\n\n    def interval_propagate(self, *v):\n        assert not self.is_input_perturbed(1) and not self.is_input_perturbed(2), \\\n            'Weight perturbation is not supported for BoundBatchNormalization'\n\n        h_L, h_U = v[0]\n        weight, bias = v[1][0], v[2][0]\n\n        mid = (h_U + h_L) / 2.0\n        diff = (h_U - h_L) / 2.0\n\n        # Use `mid` in IBP to compute mean and variance for BN.\n        # In this case, `forward` should not have been called.\n        if self.bn_mode == 'ibp' and not hasattr(self, 'forward_value'):\n            # Gather weight, bias, mean and var without shadowing the argument `v`, which is used again below.\n            bn_params = tuple(self.inputs[i].forward() for i in range(1, 5))\n            self.forward(mid, *bn_params)\n\n        if not self.training:\n            assert not (self.is_input_perturbed(3) or self.is_input_perturbed(4))\n            self.current_mean = v[3][0]\n            self.current_var = v[4][0]\n        self._check_unused_mean_or_var()\n        if not self.use_affine:\n            weight = torch.ones_like(weight)\n            bias = torch.zeros_like(bias)\n\n        tmp_weight = weight / torch.sqrt(self.current_var + self.eps)\n        tmp_weight_abs = tmp_weight.abs()\n        tmp_bias = bias - self.current_mean * tmp_weight\n        shape = (1, -1) + (1,) * (mid.ndim - 2)\n\n        # Like interval_propagate() of the Linear layer, we may encounter inputs with different norms.\n        norm, eps = Interval.get_perturbation(v[0])[:2]\n        if norm == torch.inf:\n            center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape)\n            deviation = tmp_weight_abs.view(*shape) * diff\n        elif norm > 0:\n            mid = v[0][0]\n            center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape)\n            if norm == 2:\n                ptb = copy.deepcopy(v[0].ptb)\n                ptb.eps = eps * tmp_weight_abs.max()\n 
               return Interval(center, center, ptb=ptb)\n            else:\n                # General Lp norm.\n                center = tmp_weight.view(*shape) * mid\n                deviation = tmp_weight_abs.view(*shape) * eps  # use a Linf ball to replace Lp norm\n        else:\n            raise NotImplementedError\n\n        lower, upper = center - deviation, center + deviation\n\n        return lower, upper\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        # e.g., last layer input gurobi vars (3,32,32)\n        gvars_array = np.array(v[0])\n        # pre_layer_shape (1,3,32,32)\n        pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape\n        # this layer shape (1,8,16,16)\n        this_layer_shape = self.output_shape\n\n        weight, bias = v[1], v[2]\n\n        self.current_mean = v[3]\n        self.current_var = v[4]\n        self._check_unused_mean_or_var()\n        if not self.use_affine:\n            weight = torch.ones_like(weight)\n            bias = torch.zeros_like(bias)\n\n        tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight\n        tmp_weight = weight / torch.sqrt(self.current_var + self.eps)\n\n        new_layer_gurobi_vars = []\n        neuron_idx = 0\n        for out_chan_idx in range(this_layer_shape[1]):\n            out_chan_vars = []\n            for out_row_idx in range(this_layer_shape[2]):\n                out_row_vars = []\n                for out_col_idx in range(this_layer_shape[3]):\n                    # print(this_layer_bias.shape, out_chan_idx, out_lbs.size(1))\n                    lin_expr = tmp_bias[out_chan_idx].item() + tmp_weight[out_chan_idx].item() * gvars_array[out_chan_idx, out_row_idx, out_col_idx]\n                    var = model.addVar(lb=-float('inf'), ub=float('inf'),\n                                            obj=0, vtype=grb.GRB.CONTINUOUS,\n                                            
name=f'lay{self.name}_{neuron_idx}')\n                    model.addConstr(lin_expr == var, name=f'lay{self.name}_{neuron_idx}_eq')\n                    neuron_idx += 1\n\n                    out_row_vars.append(var)\n                out_chan_vars.append(out_row_vars)\n            new_layer_gurobi_vars.append(out_chan_vars)\n\n        self.solver_vars = new_layer_gurobi_vars\n        model.update()\n\n    def update_requires_input_bounds(self):\n        self._check_weight_perturbation()\n\n\nclass LayerNormImpl(nn.Module):\n    def __init__(self, axis, epsilon):\n        super().__init__()\n        self.axis = axis\n        self.epsilon = epsilon\n\n    def forward(self, x, scale, bias):\n        mean = x.mean(self.axis, keepdim=True)\n        d = x - mean\n        dd = d**2\n        var = dd.mean(self.axis, keepdim=True)\n        var_eps = var + self.epsilon\n        std_dev = torch.sqrt(var_eps)\n        inv_std_dev = torch.reciprocal(std_dev)\n        normalized = d * inv_std_dev\n        normalized_scaled = normalized * scale + bias\n        return normalized_scaled\n\n\nclass BoundLayerNormalization(Bound):\n    def __init__(self, attr, inputs, output_index, options):\n        super().__init__(attr, inputs, output_index, options)\n        self.complex = True\n        self.model = LayerNormImpl(self.attr['axis'], self.attr['epsilon'])\n\n    def forward(self, x, scale, bias):\n        self.input = (x, scale, bias)\n        return self.model(x, scale, bias)\n"
  },
  {
    "path": "auto_LiRPA/operators/pooling.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Pooling operators.\"\"\"\nfrom collections import OrderedDict\nfrom .base import *\nfrom .activation_base import BoundOptimizableActivation\nimport numpy as np\nfrom .solver_utils import grb\n\n\nclass BoundMaxPool(BoundOptimizableActivation):\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])\n        assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3])\n\n        self.requires_input_bounds = [0]\n        self.kernel_size = attr['kernel_shape']\n        self.stride = attr['strides']\n        self.padding = [attr['pads'][0], attr['pads'][1]]\n        self.ceil_mode = False\n        self.use_default_ibp = True\n        self.alpha = {}\n        self.init = {}\n\n    def forward(self, x):\n        output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,\n                                 return_indices=True, ceil_mode=self.ceil_mode)\n        return output\n\n    def project_simplex(self, patches):\n        sorted = torch.flatten(patches, -2)\n        sorted, _ = torch.sort(sorted, -1, descending=True)\n        rho_sum = torch.cumsum(sorted, -1)\n        rho_value = 1 - rho_sum\n        rho_value = (sorted + rho_value/torch.tensor(\n            range(1, sorted.size(-1)+1), dtype=torch.float,\n            device=sorted.device)) > 0\n        _, rho_index = torch.max(torch.cumsum(rho_value, -1), -1)\n        rho_sum = torch.gather(rho_sum, -1, rho_index.unsqueeze(-1)).squeeze(-1)\n        lbd = 1/(rho_index+1)* (1-rho_sum)\n\n        return torch.clamp(patches + lbd.unsqueeze(-1).unsqueeze(-1), min=0)\n\n    def _init_opt_parameters_impl(self, size_spec, name_start):\n        if name_start == '_forward':\n            
warnings.warn(\"MaxPool's optimization is not supported for forward mode\")\n            return None\n        ref = self.inputs[0].lower # a reference variable for getting the shape\n        alpha = torch.empty(\n            [1, size_spec, self.input_shape[0], self.input_shape[1],\n            self.output_shape[-2], self.output_shape[-1],\n            self.kernel_size[0], self.kernel_size[1]],\n            dtype=torch.float, device=ref.device, requires_grad=True)\n        self.init[name_start] = False\n        return alpha\n\n    @staticmethod\n    @torch.jit.script\n    def jit_mutiply(Apos, Aneg, pos, neg):\n        return pos.contiguous() * Apos + neg.contiguous() * Aneg\n\n    def bound_backward(self, last_lA, last_uA, x, start_node=None,\n                       unstable_idx=None, **kwargs):\n        # self.padding is a tuple of two elements: (height dimension padding, width dimension padding).\n        paddings = tuple((self.padding[0], self.padding[0], self.padding[1], self.padding[1]))\n\n        if self.stride[0] != self.kernel_size[0]:\n            raise ValueError(\"self.stride ({}) != self.kernel_size ({})\".format(self.stride, self.kernel_size))\n\n        shape = self.input_shape\n        batch_size = x.lower.shape[0]\n        shape = list(shape[:-2]) + [a + 2*b for a, b in zip(self.input_shape[-2:], self.padding)]\n        shape[0] = batch_size\n        # Lower and upper D matrices. They have size (batch_size, input_c, x, y) and will be multiplied onto the A matrices after the A matrices are enlarged via F.interpolate.\n        upper_d = torch.zeros(shape, device=x.device)\n        lower_d = None\n\n        # Size of upper_b and lower_b: (batch_size, output_c, h, w).\n        upper_b = torch.zeros(batch_size, *self.output_shape[1:], device=x.device)\n        lower_b = torch.zeros(batch_size, *self.output_shape[1:], device=x.device)\n\n        # Find the maxpool neuron whose input bounds satisfy l_i > max_j u_j for all j != i. 
In this case, the maxpool neuron is linear, and we can set upper_d = lower_d = 1.\n        # We first find which index has the largest lower bound.\n        max_lower, max_lower_index = F.max_pool2d(\n            x.lower, self.kernel_size, self.stride, self.padding,\n            return_indices=True, ceil_mode=self.ceil_mode)\n        # Set the upper bound of the i-th input to -inf so it will not be selected as the max.\n\n        if paddings == (0,0,0,0):\n            delete_upper = torch.scatter(\n                torch.flatten(x.upper, -2), -1,\n                torch.flatten(max_lower_index, -2), -torch.inf).view(upper_d.shape)\n        else:\n            delete_upper = torch.scatter(\n                torch.flatten(F.pad(x.upper, paddings), -2), -1,\n                torch.flatten(max_lower_index, -2),\n                -torch.inf).view(upper_d.shape)\n        # Find the max upper bound over the remaining ones.\n        max_upper, _ = F.max_pool2d(\n            delete_upper, self.kernel_size, self.stride, 0,\n            return_indices=True, ceil_mode=self.ceil_mode)\n\n        # The upper bound slope for maxpool is either 1 when the input satisfies l_i > max_j u_j (linear), or 0 everywhere. 
Upper bound is not optimized.\n        values = torch.zeros_like(max_lower)\n        values[max_lower >= max_upper] = 1.0\n        upper_d = torch.scatter(\n            torch.flatten(upper_d, -2), -1,\n            torch.flatten(max_lower_index, -2),\n            torch.flatten(values, -2)).view(upper_d.shape)\n\n        if self.opt_stage == 'opt':\n            if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1:\n                if isinstance(unstable_idx, tuple):\n                    raise NotImplementedError('Please use --conv_mode matrix')\n                elif unstable_idx.ndim == 1:\n                    # Only unstable neurons of the start_node neurons are used.\n                    alpha = self.non_deter_index_select(\n                        self.alpha[start_node.name], index=unstable_idx, dim=1)\n                elif unstable_idx.ndim == 2:\n                    # Each element in the batch selects different neurons.\n                    alpha = batched_index_select(\n                        self.alpha[start_node.name], index=unstable_idx, dim=1)\n                else:\n                    raise ValueError\n            else:\n                alpha = self.alpha[start_node.name]\n\n            if not self.init[start_node.name]:\n                lower_d = torch.zeros((shape), device=x.device)\n                # [batch, C, H, W]\n                lower_d = torch.scatter(\n                    torch.flatten(lower_d, -2), -1,\n                    torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape)\n                # shape [batch, C*k*k, L]\n                lower_d_unfold = F.unfold(\n                    lower_d, self.kernel_size, 1, stride=self.stride)\n\n                # [batch, C, k, k, out_H, out_W]\n                alpha_data = lower_d_unfold.view(\n                    lower_d.shape[0], lower_d.shape[1], self.kernel_size[0],\n                    self.kernel_size[1], self.output_shape[-2], self.output_shape[-1])\n\n                # 
[batch, C, out_H, out_W, k, k]\n                alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach())\n                self.init[start_node.name] = True\n                # In optimization mode, we use the same lower_d once built.\n                if self.padding[0] > 0 or self.padding[1] > 0:\n                    lower_d = lower_d[...,self.padding[0]:-self.padding[0],\n                                      self.padding[1]:-self.padding[1]]\n            # The lower bound coefficients must be positive and projected onto a unit simplex.\n            alpha.data = self.project_simplex(alpha.data).clone().detach()  # TODO: don't do this, never re-assign the .data property. Use copy_ instead.\n            # Permute alpha to [1, spec, batch, C, k, k, out_H, out_W], which prepares for the fold operation.\n            alpha = alpha.permute((0,1,2,3,6,7,4,5))\n            alpha_shape = alpha.shape\n            alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2],\n                                   -1, alpha_shape[-2]*alpha_shape[-1]))\n            lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1,\n                             self.padding, self.stride)\n            lower_d = lower_d.view(alpha_shape[0], alpha_shape[1],\n                                   alpha_shape[2], *lower_d.shape[1:])\n            lower_d = lower_d.squeeze(0)\n        else:\n            lower_d = torch.zeros((shape), device=x.device)\n            # Not optimizable bounds. 
We simply set \\hat{z} >= z_i where i is the input element with the largest lower bound.\n            lower_d = torch.scatter(torch.flatten(lower_d, -2), -1,\n                                    torch.flatten(max_lower_index, -2),\n                                    1.0).view(upper_d.shape)\n            if self.padding[0] > 0 or self.padding[1] > 0:\n                lower_d = lower_d[...,self.padding[0]:-self.padding[0],\n                                  self.padding[1]:-self.padding[1]]\n\n        # For the upper bound, we set the bias term to concrete upper bounds for maxpool neurons that are not linear.\n        max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride,\n                                     self.padding, return_indices=True,\n                                     ceil_mode=self.ceil_mode)\n        upper_b[max_upper > max_lower] = max_upper_[max_upper > max_lower]\n\n        def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):\n            if last_A is None:\n                return None, 0\n\n            bias = 0\n\n            if isinstance(last_A, torch.Tensor):\n                pos_A = last_A.clamp(min=0)\n                neg_A = last_A.clamp(max=0)\n\n                if b_pos is not None:\n                    # This is matrix mode, and padding is considered in the previous layers\n                    bias = bias + self.get_bias(pos_A, b_pos)\n                if b_neg is not None:\n                    bias = bias + self.get_bias(neg_A, b_neg)\n\n                # Here we should confirm that the maxpool patches do not overlap.\n                shape = last_A.size()\n\n                padding = [self.padding[0], self.padding[0], self.padding[1], self.padding[1]]\n                d_pos = F.pad(d_pos, padding)\n                d_neg = F.pad(d_neg, padding)\n\n                pos_A = F.interpolate(\n                    pos_A.view(shape[0] * shape[1], *shape[2:]),\n                    scale_factor=self.kernel_size)\n                if 
d_pos.shape[-2] > pos_A.shape[-2] or d_pos.shape[-1] > pos_A.shape[-1]:\n                    if not (d_pos.shape[-2] > pos_A.shape[-2] and d_pos.shape[-1] > pos_A.shape[-1]):\n                        raise NotImplementedError(\n                            \"Asymmetric padding of maxpool not implemented.\")\n                    pos_A = F.pad(pos_A, (0, d_pos.shape[-2] - pos_A.shape[-2],\n                                          0, d_pos.shape[-1] - pos_A.shape[-1]))\n                else:\n                    d_pos = F.pad(d_pos, (0, pos_A.shape[-2] - d_pos.shape[-2],\n                                          0, pos_A.shape[-1] - d_pos.shape[-1]))\n                pos_A = pos_A.view(shape[0], shape[1], *pos_A.shape[1:])\n\n                neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]),\n                                      scale_factor=self.kernel_size)\n                if d_neg.shape[-2] > neg_A.shape[-2] or d_neg.shape[-1] > neg_A.shape[-1]:\n                    if not (d_neg.shape[-2] > neg_A.shape[-2] and d_neg.shape[-1] > neg_A.shape[-1]):\n                        raise NotImplementedError(\"Asymmetric padding of maxpool not implemented.\")\n                    neg_A = F.pad(neg_A, (0, d_neg.shape[-2] - neg_A.shape[-2],\n                                          0, d_neg.shape[-1] - neg_A.shape[-1]))\n                else:\n                    d_neg = F.pad(d_neg, (0, neg_A.shape[-2] - d_neg.shape[-2],\n                                          0, neg_A.shape[-1] - d_neg.shape[-1]))\n                neg_A = neg_A.view(shape[0], shape[1], *neg_A.shape[1:])\n\n                next_A = self.jit_mutiply(pos_A, neg_A, d_pos, d_neg)\n                if self.padding[0] > 0 or self.padding[1] > 0:\n                    next_A = next_A[...,self.padding[0]:-self.padding[0],\n                                    self.padding[1]:-self.padding[1]]\n            elif isinstance(last_A, Patches):\n                # The last_A.patches was not padded, so we 
need to pad them here.\n                # If this Conv layer is followed by a ReLU layer, then the padding was already handled there and there is no need to pad again.\n                one_d = torch.ones(tuple(1 for i in self.output_shape[1:]),\n                                   device=last_A.patches.device, dtype=last_A.patches.dtype).expand(self.output_shape[1:])\n                # Add batch dimension.\n                one_d = one_d.unsqueeze(0)\n                # After unfolding, the shape is (1, out_h, out_w, in_c, h, w)\n                one_d_unfolded = inplace_unfold(\n                    one_d, kernel_size=last_A.patches.shape[-2:],\n                    stride=last_A.stride, padding=last_A.padding,\n                    inserted_zeros=last_A.inserted_zeros,\n                    output_padding=last_A.output_padding)\n                if last_A.unstable_idx is not None:\n                    # Move out_h, out_w dimension to the front for easier selection.\n                    one_d_unfolded_r = one_d_unfolded.permute(1, 2, 0, 3, 4, 5)\n                    # for sparse patches the shape is (unstable_size, batch, in_c, h, w). 
Batch size is 1 so no need to select here.\n                    one_d_unfolded_r = one_d_unfolded_r[\n                        last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                else:\n                    # Append the spec dimension.\n                    one_d_unfolded_r = one_d_unfolded.unsqueeze(0)\n                patches = last_A.patches * one_d_unfolded_r\n\n                if b_pos is not None:\n                    patch_pos = Patches(\n                        patches.clamp(min=0), last_A.stride, last_A.padding,\n                        last_A.shape, unstable_idx=last_A.unstable_idx,\n                        output_shape=last_A.output_shape)\n                    bias = bias + self.get_bias(patch_pos, b_pos)\n                if b_neg is not None:\n                    patch_neg = Patches(\n                        patches.clamp(max=0), last_A.stride, last_A.padding,\n                        last_A.shape, unstable_idx=last_A.unstable_idx,\n                        output_shape=last_A.output_shape)\n                    bias = bias + self.get_bias(patch_neg, b_neg)\n\n                # bias = bias.transpose(0,1)\n                shape = last_A.shape\n                pos_A = last_A.patches.clamp(min=0)\n                neg_A = last_A.patches.clamp(max=0)\n\n                def upsample(last_patches, last_A):\n                    if last_A.unstable_idx is None:\n                        patches = F.interpolate(\n                            last_patches.view(shape[0] * shape[1] * shape[2], *shape[3:]),\n                            scale_factor=[1,]+self.kernel_size)\n                        patches = patches.view(shape[0], shape[1], shape[2], *patches.shape[1:])\n                    else:\n                        patches = F.interpolate(\n                            last_patches, scale_factor=[1,] + self.kernel_size)\n                    return Patches(\n                        patches, stride=last_A.stride, padding=last_A.padding,\n                        
shape=patches.shape, unstable_idx=last_A.unstable_idx,\n                        output_shape=last_A.output_shape)\n\n                pos_A = upsample(pos_A, last_A)\n                neg_A = upsample(neg_A, last_A)\n\n                padding, stride, output_padding = compute_patches_stride_padding(\n                    self.input_shape, last_A.padding, last_A.stride, self.padding,\n                    self.stride, last_A.inserted_zeros, last_A.output_padding)\n\n                pos_A.padding, pos_A.stride, pos_A.output_padding = padding, stride, output_padding\n                neg_A.padding, neg_A.stride, neg_A.output_padding = padding, stride, output_padding\n\n                # unsqueeze for the spec dimension\n                d_pos = maybe_unfold_patches(d_pos.unsqueeze(0), pos_A)\n                d_neg = maybe_unfold_patches(d_neg.unsqueeze(0), neg_A)\n\n                next_A_patches = self.jit_mutiply(\n                    pos_A.patches, neg_A.patches, d_pos, d_neg)\n\n                if start_node is not None:\n                    self.patch_size[start_node.name] = next_A_patches.size()\n\n                next_A = Patches(\n                    next_A_patches, stride, padding, next_A_patches.shape,\n                    unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape,\n                    inserted_zeros=last_A.inserted_zeros, output_padding=output_padding)\n\n            return next_A, bias\n\n        if self.padding[0] > 0:\n            upper_d = upper_d[...,self.padding[0]:-self.padding[0],\n                              self.padding[0]:-self.padding[0]]\n\n        uA, ubias = _bound_oneside(last_uA, upper_d, lower_d, upper_b, lower_b)\n        lA, lbias = _bound_oneside(last_lA, lower_d, upper_d, lower_b, upper_b)\n\n        return [(lA, uA)], lbias, ubias\n\n    def bound_forward(self, dim_in, x):\n        lower_d, lower_b, upper_d, upper_b = self.bound_relax(x, init=False)\n\n        def _bound_oneside(w_pos, b_pos, w_neg, b_neg, d, 
b):\n            d_pos, d_neg = d.clamp(min=0), d.clamp(max=0)\n            w_new = d_pos.unsqueeze(1) * w_pos + d_neg.unsqueeze(1) * w_neg\n            b_new = d_pos * b_pos + d_neg * b_neg\n            if isinstance(self.kernel_size, list) and len(self.kernel_size) == 2:\n                tot_kernel_size = prod(self.kernel_size)\n            elif isinstance(self.kernel_size, int):\n                tot_kernel_size = self.kernel_size ** 2\n            else:\n                raise ValueError(f'Unsupported kernel size {self.kernel_size}')\n            w_pooled = (F.avg_pool2d(w_new.view(-1, *w_new.shape[2:]),\n                self.kernel_size, self.stride, self.padding,\n                ceil_mode=self.ceil_mode) * tot_kernel_size)\n            w_pooled = w_pooled.reshape(w_new.shape[0], -1, *w_pooled.shape[1:])\n            b_pooled = F.avg_pool2d(b_new, self.kernel_size, self.stride, self.padding,\n                ceil_mode=self.ceil_mode) * tot_kernel_size + b\n            return w_pooled, b_pooled\n\n        lw, lb = _bound_oneside(x.lw, x.lb, x.uw, x.ub, lower_d, lower_b)\n        uw, ub = _bound_oneside(x.uw, x.ub, x.lw, x.lb, upper_d, upper_b)\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n\n        # Only used by forward mode\n        paddings = tuple(self.padding + self.padding)\n        self.upper, self.lower = x.upper, x.lower\n\n        # A_shape = last_lA.shape if last_lA is not None else last_uA.shape\n        # batch_size, input_c, x, y\n        upper_d = torch.zeros_like(x.lower)\n        lower_d = torch.zeros_like(x.lower)\n\n        upper_d = F.pad(upper_d, paddings)\n        lower_d = F.pad(lower_d, paddings)\n\n        # batch_size, output_c, x, y\n        upper_b = torch.zeros((list(self.output_shape))).to(x.lower)\n        lower_b = torch.zeros((list(self.output_shape))).to(x.lower)\n\n        # 1. 
find the index i where l_i > u_j for all j != i, then set upper_d = lower_d = 1\n        max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode)\n        delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -torch.inf).view(upper_d.shape)\n        max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode)\n\n        values = torch.zeros_like(max_lower)\n        values[max_lower >= max_upper] = 1.0\n        upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape)\n\n        if self.opt_stage == 'opt':\n            raise NotImplementedError\n        else:\n            lower_d = torch.scatter(torch.flatten(lower_d, -2), -1,\n                                    torch.flatten(max_lower_index, -2),\n                                    1.0).view(upper_d.shape)\n            if self.padding[0] > 0:\n                lower_d = lower_d[...,self.padding[0]:-self.padding[0],\n                                  self.padding[0]:-self.padding[0]]\n\n        values[:] = 0.0\n        max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride,\n                                     self.padding, return_indices=True,\n                                     ceil_mode=self.ceil_mode)\n        values[max_upper > max_lower] = max_upper_[max_upper > max_lower]\n        upper_b = values\n\n        if self.padding[0] > 0:\n            upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]]\n\n        return lower_d, lower_b, upper_d, upper_b\n\n    def dump_alpha(self, device=None, dtype=None, non_blocking=False):\n        ret = {'alpha': self._transfer_alpha(self.alpha, device=device, dtype=dtype, non_blocking=non_blocking, require_grad=False)}\n        ret['init'] = 
self.init\n        return ret\n\n    def restore_alpha(self, alpha, device=None, dtype=None, non_blocking=False):\n        self.alpha = self._transfer_alpha(alpha['alpha'], device=device, dtype=dtype, non_blocking=non_blocking, require_grad=True)\n        self.init = alpha['init']\n\n    def drop_unused_alpha(self, keep_nodes):\n        for spec_name in list(self.alpha.keys()):\n            if spec_name not in keep_nodes:\n                del self.alpha[spec_name]\n                del self.init[spec_name]\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        # e.g., last layer input gurobi vars (3,32,32)\n        gvars_array = np.array(v[0])\n        # pre_layer_shape (1,32,27,27)\n        pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape\n        # this layer shape (1,32,6,6)\n        this_layer_shape = self.output_shape\n        assert this_layer_shape[2] ==  ((2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1))//self.stride[0])\n\n        new_layer_gurobi_vars = []\n        neuron_idx = 0\n        pre_ubs = self.forward(self.inputs[0].upper).detach().cpu().numpy()\n\n        for out_chan_idx in range(this_layer_shape[1]):\n            out_chan_vars = []\n            for out_row_idx in range(this_layer_shape[2]):\n                out_row_vars = []\n                for out_col_idx in range(this_layer_shape[3]):\n                    a_sum = 0.0\n                    v = model.addVar(lb=-float('inf'), ub=float('inf'),\n                                            obj=0, vtype=grb.GRB.CONTINUOUS,\n                                            name=f'lay{self.name}_{neuron_idx}')\n                    for ker_row_idx in range(self.kernel_size[0]):\n                        in_row_idx = -self.padding[0] + self.stride[0] * out_row_idx + ker_row_idx\n                        if (in_row_idx < 0) or (in_row_idx == len(gvars_array[out_chan_idx][ker_row_idx])):\n                            # This is padding -> 
value of 0\n                            continue\n                        for ker_col_idx in range(self.kernel_size[1]):\n                            in_col_idx = -self.padding[1] + self.stride[1] * out_col_idx + ker_col_idx\n                            if (in_col_idx < 0) or (in_col_idx == pre_layer_shape[3]):\n                                # This is padding -> value of 0\n                                continue\n                            var = gvars_array[out_chan_idx][in_row_idx][in_col_idx]\n                            a = model.addVar(vtype=grb.GRB.BINARY)\n                            a_sum += a\n                            model.addConstr(v >= var)\n                            model.addConstr(v <= var + (1 - a) * pre_ubs[\n                                0, out_chan_idx, out_row_idx, out_col_idx])\n                    model.addConstr(a_sum == 1, name=f'lay{self.name}_{neuron_idx}_eq')\n                    out_row_vars.append(v)\n                out_chan_vars.append(out_row_vars)\n            new_layer_gurobi_vars.append(out_chan_vars)\n\n        self.solver_vars = new_layer_gurobi_vars\n        model.update()\n\n\n\nclass BoundGlobalAveragePool(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n\n    def forward(self, x):\n        output = nn.AdaptiveAvgPool2d((1, 1)).forward(x)  # adaptiveAveragePool with output size (1, 1)\n        return output\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        H, W = self.input_shape[-2], self.input_shape[-1]\n\n        lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None\n        uA = (last_uA.expand(list(last_uA.shape[:-2]) + [H, W]) / (H * W)) if last_uA is not None else None\n\n        return [(lA, uA)], 0, 0\n\n    def interval_propagate(self, *v):\n        h_L, h_U = v[0]\n        h_L = F.adaptive_avg_pool2d(h_L, (1, 1))\n        h_U = 
F.adaptive_avg_pool2d(h_U, (1, 1))\n        return h_L, h_U\n\n\nclass BoundAveragePool(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        # assumptions: ceil_mode=False, count_include_pad=True\n        super().__init__(attr, inputs, output_index, options)\n\n        assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2])\n        assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3])\n\n        self.kernel_size = attr['kernel_shape']\n        assert len(self.kernel_size) == 2\n        self.stride = attr['strides']\n        assert len(self.stride) == 2\n        # FIXME (22/07/02): padding is inconsistently handled. Should use 4-tuple.\n\n        if 'pads' not in attr:\n            self.padding = [0, 0]\n        else:\n            self.padding = [attr['pads'][0], attr['pads'][1]]\n        self.ceil_mode = False\n        self.count_include_pad = True\n        self.use_default_ibp = True\n        self.relu_followed = False\n\n    def forward(self, x):\n        return F.avg_pool2d(x, self.kernel_size, self.stride,\n                            self.padding, self.ceil_mode, self.count_include_pad)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None, 0\n            equal_kernel_stride = (self.kernel_size[0] == self.stride[0]\n                                   and self.kernel_size[1] == self.stride[1])\n            if isinstance(last_A, torch.Tensor):\n                shape = last_A.size()\n                if equal_kernel_stride:\n                    # propagate A to the next layer, with batch concatenated together\n                    next_A = F.interpolate(\n                        last_A.reshape(shape[0] * shape[1], *shape[2:]),\n                        scale_factor=self.kernel_size\n                    ) / (prod(self.kernel_size))\n                    next_A = F.pad(\n                        
next_A, (0, self.input_shape[-2] - next_A.shape[-2],\n                                 0, self.input_shape[-1] - next_A.shape[-1]))\n                    next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])\n                else:\n                    # Treat pooling as a general convolution\n                    weight = torch.zeros(\n                        self.input_shape[1], self.output_shape[1], *self.kernel_size,\n                        dtype=last_A.dtype, device=last_A.device)\n                    assert self.input_shape[1] == self.output_shape[1]\n                    weight = torch.eye(self.input_shape[1], dtype=last_A.dtype, device=last_A.device)\n                    weight = weight / prod(self.kernel_size)\n                    weight = weight.view(self.output_shape[1], self.input_shape[1], 1, 1)\n                    weight = weight.expand(self.output_shape[1], self.input_shape[1], *self.kernel_size)\n                    output_padding0 = (\n                        int(self.input_shape[2])\n                        - (int(self.output_shape[2]) - 1) * self.stride[0]\n                        + 2 * self.padding[0] - 1 - (int(weight.size()[2] - 1)))\n                    output_padding1 = (\n                        int(self.input_shape[3])\n                        - (int(self.output_shape[3]) - 1) * self.stride[1]\n                        + 2 * self.padding[1] - 1 - (int(weight.size()[3] - 1)))\n                    next_A = F.conv_transpose2d(\n                        last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None,\n                        stride=self.stride, padding=self.padding,\n                        output_padding=(output_padding0, output_padding1))\n                    next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:])\n            elif isinstance(last_A, Patches):\n                patches = last_A.patches\n                shape = patches.size()\n                # When the number of inserted zeros can cancel out the stride, 
we use a shortcut that can reduce computation.\n                simplify_patch = (equal_kernel_stride\n                                  and last_A.inserted_zeros + 1 == self.kernel_size[0]\n                                  and self.kernel_size[0] == self.kernel_size[1])\n                padding, stride, output_padding = compute_patches_stride_padding(\n                    self.input_shape, last_A.padding, last_A.stride,\n                    self.padding, self.stride,\n                    inserted_zeros=last_A.inserted_zeros,\n                    output_padding=last_A.output_padding,\n                    simplify=not simplify_patch)\n                inserted_zeros = last_A.inserted_zeros\n                if equal_kernel_stride and last_A.inserted_zeros == 0:\n                    # No inserted zeros, can be handled using interpolate.\n                    if last_A.unstable_idx is None:\n                        # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]\n                        up_sampled_patches = F.interpolate(\n                            patches.reshape(shape[0] * shape[1],\n                                         shape[2] * shape[3], *shape[4:]),\n                            scale_factor=[1,] + self.kernel_size)\n                        # The dimension of patch-H and patch_W has changed.\n                        up_sampled_patches = up_sampled_patches.reshape(\n                            *shape[:-2], up_sampled_patches.size(-2),\n                            up_sampled_patches.size(-1))\n                    else:\n                        # shape is: [spec, batch, in_c, patch_H, patch_W]\n                        up_sampled_patches = F.interpolate(\n                            patches, scale_factor=[1,] + self.kernel_size)\n                    # Divided by the averaging factor.\n                    up_sampled_patches = up_sampled_patches / prod(self.kernel_size)\n                elif simplify_patch:\n                    padding = tuple(p // 
s - o for p, s, o in zip(padding, stride, output_padding))\n                    output_padding = (0, 0, 0, 0)\n                    stride = 1  # Stride and inserted zero canceled out. No need to insert zeros and add output_padding.\n                    inserted_zeros = 0\n                    value = 1. / prod(self.kernel_size)\n                    # In the case where the stride and adding_zeros cancel out, we do not need to insert zeros.\n                    weight = torch.full(\n                        size=(self.input_shape[1], 1, *self.kernel_size),\n                        fill_value=value, dtype=patches.dtype,\n                        device=patches.device)\n                    if last_A.unstable_idx is None:\n                        # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]\n                        up_sampled_patches = F.conv_transpose2d(\n                            patches.reshape(\n                                shape[0] * shape[1] * shape[2] * shape[3],\n                                *shape[4:]\n                            ), weight, stride=1, groups=self.input_shape[1])\n                    else:\n                        # shape is: [spec, batch, in_c, patch_H, patch_W]\n                        up_sampled_patches = F.conv_transpose2d(\n                            patches.reshape(shape[0] * shape[1], *shape[2:]),\n                            weight, stride=1, groups=self.input_shape[1])\n                    up_sampled_patches = up_sampled_patches.view(\n                        *shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1))\n                else:\n                    # With inserted zeros, must be handled by treating pooling as general convolution.\n                    value = 1. 
/ prod(self.kernel_size)\n                    weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size),\n                                        fill_value=value, dtype=patches.dtype,\n                                        device=patches.device)\n                    if not self.relu_followed:\n                        patches = last_A.create_padding(self.output_shape)\n                    weight = insert_zeros(weight, last_A.inserted_zeros)\n                    if last_A.unstable_idx is None:\n                        # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W]\n                        up_sampled_patches = F.conv_transpose2d(\n                            patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]),\n                            weight, stride=self.stride,\n                            groups=self.input_shape[1])\n                    else:\n                        # shape is: [spec, batch, in_c, patch_H, patch_W]\n                        up_sampled_patches = F.conv_transpose2d(\n                            patches.reshape(shape[0] * shape[1], *shape[2:]),\n                            weight, stride=self.stride,\n                            groups=self.input_shape[1])\n                    up_sampled_patches = up_sampled_patches.view(\n                        *shape[:-2], up_sampled_patches.size(-2),\n                        up_sampled_patches.size(-1))\n                next_A = last_A.create_similar(\n                    up_sampled_patches, stride=stride, padding=padding,\n                    output_padding=output_padding,\n                    inserted_zeros=inserted_zeros)\n            else:\n                raise ValueError(f'last_A has unexpected type {type(last_A)}')\n            return next_A, 0.\n\n        lA, lbias = _bound_oneside(last_lA)\n        uA, ubias = _bound_oneside(last_uA)\n        return [(lA, uA)], lbias, ubias\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", 
solver_pkg=\"gurobi\"):\n        # e.g., last layer input gurobi vars (3,32,32)\n        gvars_array = np.array(v[0])\n        # pre_layer_shape (1,32,27,27)\n        pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape\n        # this layer shape (1,32,6,6)\n        this_layer_shape = self.output_shape\n        assert this_layer_shape[2] ==  (\n            (2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1)\n        ) // self.stride[0])\n\n        value = 1.0/(self.kernel_size[0] * self.kernel_size[1])\n        new_layer_gurobi_vars = []\n        neuron_idx = 0\n        for out_chan_idx in range(this_layer_shape[1]):\n            out_chan_vars = []\n            for out_row_idx in range(this_layer_shape[2]):\n                out_row_vars = []\n                for out_col_idx in range(this_layer_shape[3]):\n                    # print(self.bias.shape, out_chan_idx, out_lbs.size(1))\n                    lin_expr = 0.0\n                    for ker_row_idx in range(self.kernel_size[0]):\n                        in_row_idx = -self.padding[0] + self.stride[0] * out_row_idx + ker_row_idx\n                        if (in_row_idx < 0) or (in_row_idx == len(gvars_array[out_chan_idx][ker_row_idx])):\n                            # This is padding -> value of 0\n                            continue\n                        for ker_col_idx in range(self.kernel_size[1]):\n                            in_col_idx = -self.padding[1] + self.stride[1] * out_col_idx + ker_col_idx\n                            if (in_col_idx < 0) or (in_col_idx == pre_layer_shape[3]):\n                                # This is padding -> value of 0\n                                continue\n                            coeff = value\n                            lin_expr += coeff * gvars_array[out_chan_idx][in_row_idx][in_col_idx]\n                    v = model.addVar(lb=-float('inf'), ub=float('inf'),\n                                            obj=0, vtype=grb.GRB.CONTINUOUS,\n         
                                   name=f'lay{self.name}_{neuron_idx}')\n                    model.addConstr(lin_expr == v, name=f'lay{self.name}_{neuron_idx}_eq')\n                    neuron_idx += 1\n\n                    out_row_vars.append(v)\n                out_chan_vars.append(out_row_vars)\n            new_layer_gurobi_vars.append(out_chan_vars)\n\n        self.solver_vars = new_layer_gurobi_vars\n        model.update()"
  },
  {
    "path": "auto_LiRPA/operators/reduce.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Reduce operators\"\"\"\nfrom .base import *\nfrom torch.nn import Module\n\n\nclass BoundReduce(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.axis = attr.get('axes', None)\n        self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True\n        self.use_default_ibp = True\n\n    def _parse_input_and_axis(self, *x):\n        if len(x) > 1:\n            assert not self.is_input_perturbed(1)\n            self.axis = tuple(item.item() for item in tuple(x[1]))\n        self.axis = self.make_axis_non_negative(self.axis)\n        return x[0]\n\n    def _return_bound_backward(self, lA, uA):\n        return [(lA, uA)] + [(None, None)] * (len(self.inputs) - 1), 0, 0\n\n\nclass BoundReduceMax(BoundReduce):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        \"\"\"Assume that the indexes with the maximum values are not perturbed.\n        This generally doesn't hold true, but can still be used for the input shift\n        in Softmax of Transformers.\"\"\"\n        self.fixed_max_index = options.get('fixed_reducemax_index', False)\n\n    def _parse_input_and_axis(self, *x):\n        x = super()._parse_input_and_axis(*x)\n        # for torch.max, `dim` must be an int\n        if isinstance(self.axis, tuple):\n            assert len(self.axis) == 1\n            self.axis = self.axis[0]\n        return x\n\n    def forward(self, *x):\n        x = self._parse_input_and_axis(*x)\n        res = torch.max(x, dim=self.axis, keepdim=self.keepdim)\n        self.indices = res.indices\n        return res.values\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        if 
self.fixed_max_index:\n            def _bound_oneside(last_A):\n                if last_A is None:\n                    return None\n                indices = self.indices.unsqueeze(0)\n                if not self.keepdim:\n                    assert (self.from_input)\n                    last_A = last_A.unsqueeze(self.axis + 1)\n                    indices = indices.unsqueeze(self.axis + 1)\n                shape = list(last_A.shape)\n                shape[self.axis + 1] *= self.input_shape[self.axis]\n                A = torch.zeros(shape, device=last_A.device)\n                indices = indices.expand(*last_A.shape)\n                A.scatter_(dim=self.axis + 1, index=indices, src=last_A)\n                return A\n\n            return self._return_bound_backward(_bound_oneside(last_lA),\n                                               _bound_oneside(last_uA))\n        else:\n            raise NotImplementedError(\n                '`bound_backward` for BoundReduceMax with perturbed maximum '\n                'indexes is not implemented.')\n\n    def build_gradient_node(self, grad_upstream):\n        if self.fixed_max_index:\n            node_grad = ReduceMaxGrad(self.axis, self.keepdim, self.input_shape, self.indices)\n            return [(node_grad, (grad_upstream,), [])]\n        else:\n            raise NotImplementedError(\n                '`build_gradient_node` for BoundReduceMax with perturbed maximum '\n                'indexes is not implemented.')\n\n\nclass ReduceMaxGrad(Module):\n    def __init__(self, axis, keepdim, input_shape, indices):\n        super().__init__()\n        self.axis = axis\n        self.keepdim = keepdim\n        self.input_shape = input_shape\n        self.indices = indices.unsqueeze(0)\n\n    def forward(self, grad_last):\n        # Only keep the gradient at the maximum index\n        # The gradient at other indices is 0\n        # If keepdim is False, add a singleton dimension at the specified axis\n        if not 
self.keepdim:\n            grad_last = grad_last.unsqueeze(self.axis + 1)\n            indices = self.indices.unsqueeze(self.axis + 1)\n        else:\n            indices = self.indices\n            assert grad_last.shape[self.axis + 1] == 1\n        # Calculate the target dimension size at axis + 1\n        new_dim = self.input_shape[self.axis]\n        # Create the output tensor shape\n        new_shape = list(grad_last.shape)\n        new_shape[self.axis + 1] = new_dim\n\n        ########################################################################\n        # TODO: The following lines are equivalent to:\n        #\n        # grad = torch.zeros(new_shape, device=grad_last.device)\n        # indices = indices.expand(*grad_last.shape)\n        # grad.scatter_(dim=self.axis + 1, index=indices, src=grad_last)\n        #\n        # But auto_LiRPA does not support scatter_ yet.\n        # So we use a workaround to avoid using scatter_.\n        ########################################################################\n\n        # Expand indices to match the target shape,\n        # filling axis + 1 with new_dim\n        indices_expanded = indices.expand(\n            *grad_last.shape[:self.axis + 1],\n            new_dim,\n            *grad_last.shape[self.axis + 2:]\n            ).to(grad_last.device)\n        # Create a coordinate tensor for comparison along axis + 1\n        coord_shape = [1] * grad_last.dim()\n        coord_shape[self.axis + 1] = new_dim\n        coord = torch.arange(new_dim, device=grad_last.device).view(*coord_shape)\n        # Create a binary mask where 1 indicates the desired position for each gradient\n        mask = (coord == indices_expanded).type_as(grad_last)\n        # Expand grad_last to match the target shape for element-wise multiplication\n        grad_last_expanded = grad_last.expand(\n            *grad_last.shape[:self.axis + 1],\n            new_dim,\n            *grad_last.shape[self.axis + 2:])\n        # Use the mask to retain 
values only at the correct positions\n        grad = mask * grad_last_expanded\n        return grad\n\n\nclass BoundReduceMin(BoundReduceMax):\n    def forward(self, *x):\n        x = self._parse_input_and_axis(*x)\n        res = torch.min(x, dim=self.axis, keepdim=self.keepdim)\n        self.indices = res.indices\n        return res.values\n\n\nclass BoundReduceMean(BoundReduce):\n    def forward(self, *x):\n        x = self._parse_input_and_axis(*x)\n        return torch.mean(x, dim=self.axis, keepdim=self.keepdim)\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            if not self.keepdim:\n                assert (self.from_input)\n                for axis in self.axis:\n                    if axis > 0:\n                        last_A = last_A.unsqueeze(axis + 1)\n            shape = list(last_A.shape)\n            shape[2:] = self.input_shape[1:]\n            # We perform expansion as in BoundReduceSum. 
\n            # and divide by the product of the sizes of the reduced dimensions.\n            last_A = last_A.expand(*shape) / np.prod(np.take(self.input_shape, self.axis))\n            return last_A\n\n        return self._return_bound_backward(_bound_oneside(last_lA),\n                                           _bound_oneside(last_uA))\n\n    def bound_forward(self, dim_in, x, *args):\n        assert self.keepdim\n        assert len(self.axis) == 1\n        axis = self.make_axis_non_negative(self.axis[0])\n        assert (axis > 0)\n        size = self.input_shape[axis]\n        lw = x.lw.sum(dim=axis + 1, keepdim=True) / size\n        lb = x.lb.sum(dim=axis, keepdim=True) / size\n        uw = x.uw.sum(dim=axis + 1, keepdim=True) / size\n        ub = x.ub.sum(dim=axis, keepdim=True) / size\n        return LinearBound(lw, lb, uw, ub)\n\n\nclass BoundReduceSum(BoundReduce):\n    def forward(self, *x):\n        x = self._parse_input_and_axis(*x)\n        if self.axis is not None:\n            return torch.sum(x, dim=self.axis, keepdim=self.keepdim)\n        else:\n            return torch.sum(x)\n\n    def bound_backward(self, last_lA, last_uA, x, *args, **kwargs):\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            if not self.keepdim:\n                assert (self.from_input)\n                for axis in self.axis:\n                    if axis > 0:\n                        last_A = last_A.unsqueeze(axis + 1)\n            # last_A.shape = [num_spec, batch_size, ..., dim_size_1 (1), ...]\n            shape = list(last_A.shape)\n            # self.input_shape = [batch_size_original, ..., dim_size_1_before_reduction, ...]\n            # we expand last_A while keeping its batch_size instead of that from self.input_shape.\n            shape[2:] = self.input_shape[1:]\n            # For reduced dims, their dim_size will be expanded from 1 to the original size.\n            # For non-reduced dims, their dim_size will 
be unchanged.\n            last_A = last_A.expand(*shape)\n            return last_A\n\n        return self._return_bound_backward(_bound_oneside(last_lA),\n                                           _bound_oneside(last_uA))\n\n    def bound_forward(self, dim_in, x, *args):\n        # Handle possibly multiple axes\n        axes = [self.make_axis_non_negative(ax) for ax in self.axis]\n        # Ensure all axes are greater than 0 (not batch dimension)\n        assert all(ax > 0 for ax in axes)\n        # For lw/uw, need to shift by 1 due to an extra leading dimension (num_spec)\n        lw = x.lw.sum(dim=[ax + 1 for ax in axes], keepdim=self.keepdim)\n        lb = x.lb.sum(dim=axes, keepdim=self.keepdim)\n        uw = x.uw.sum(dim=[ax + 1 for ax in axes], keepdim=self.keepdim)\n        ub = x.ub.sum(dim=axes, keepdim=self.keepdim)\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = ReduceSumGrad(self.axis, self.keepdim, self.input_shape)\n        return [(node_grad, (grad_upstream,), [])]\n        \n\nclass ReduceSumGrad(Module):\n    def __init__(self, axis, keepdim, input_shape):\n        super().__init__()\n        self.axis = axis\n        self.keepdim = keepdim\n        self.input_shape = input_shape\n    \n    def forward(self, grad_last):\n        grad_new = grad_last.clone()\n        if not self.keepdim:\n            for axis in self.axis:\n                if axis > 0:\n                    grad_new = grad_new.unsqueeze(axis + 1)\n        # For ReduceSum, ∂y/∂x = 1, so we just need to expand the gradient\n        # along each axis that is reduced.\n        shape = list(grad_new.shape)\n        shape[2:] = self.input_shape[1:]\n        grad_new = grad_new.expand(*shape)\n        return grad_new\n"
  },
  {
    "path": "auto_LiRPA/operators/relu.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"BoundRelu.\"\"\"\nfrom typing import Optional, Tuple\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module\nfrom torch.autograd import Function\nfrom collections import OrderedDict\nfrom .base import *\nfrom .clampmult import multiply_by_A_signs\nfrom .activation_base import BoundActivation, BoundOptimizableActivation\nfrom .solver_utils import grb\nfrom ..utils import unravel_index, prod\n\n\nclass BoundTwoPieceLinear(BoundOptimizableActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        if options is None:\n            options = {}\n        self.options = options\n        self.ibp_intermediate = True\n        self.splittable = True\n        self.relu_options = options.get('activation_bound_option', 'adaptive')\n        self.use_sparse_spec_alpha = options.get('sparse_spec_alpha', False)\n        self.use_sparse_features_alpha = options.get('sparse_features_alpha', False)\n        self.alpha_lookup_idx = self.alpha_indices = None\n        self.beta = self.masked_beta = self.sparse_betas = None\n        self.split_beta_used = False\n        self.history_beta_used = False\n        self.flattened_nodes = None\n        self.patch_size = {}\n        self.cut_used = False\n        self.cut_module = None\n        self.gcp_unstable_relu_indicators = None\n\n    def init_opt_parameters(self, start_nodes):\n        ref = self.inputs[0].lower # a reference variable for getting the shape\n        batch_size = ref.size(0)\n        self.alpha = OrderedDict()\n        self.alpha_lookup_idx = OrderedDict()  # For alpha with sparse spec dimension.\n        self.alpha_indices = None  # indices of non-zero alphas.\n        verbosity = self.options.get('verbosity', 0)\n\n        # Alpha can be sparse in 
both the spec dimension and the C*H*W dimension.\n        # We first deal with the sparse-feature alpha, which is sparse in the\n        # C*H*W dimension of this layer.\n        minimum_sparsity = self.options.get('minimum_sparsity', 0.9)\n        if (self.use_sparse_features_alpha\n                and self.inputs[0].is_lower_bound_current()\n                and self.inputs[0].is_upper_bound_current()):\n            # Pre-activation bounds available, we will store the alpha for unstable neurons only.\n            # Since each element in a batch can have different unstable neurons,\n            # for simplicity we find a super-set using any(dim=0).\n            # This can be non-ideal if the x in a batch are very different.\n            self.get_unstable_idx()\n            total_neuron_size = self.inputs[0].lower.numel() // batch_size\n            if self.alpha_indices[0].size(0) <= minimum_sparsity * total_neuron_size:\n                # Shape is the number of unstable neurons in this layer.\n                alpha_shape = [self.alpha_indices[0].size(0)]\n                # Skip the batch, spec dimension, and find the lower slopes for all unstable neurons.\n                if len(self.alpha_indices) == 1:\n                    # This layer is after a linear layer.\n                    alpha_init = self.init_d[:, :, self.alpha_indices[0]]\n                elif len(self.alpha_indices) == 3:\n                    # This layer is after a conv2d layer.\n                    alpha_init = self.init_d[\n                        :, :, self.alpha_indices[0], self.alpha_indices[1],\n                        self.alpha_indices[2]]\n                elif len(self.alpha_indices) == 2:\n                    # This layer is after a conv1d layer.\n                    alpha_init = self.init_d[\n                                 :, :, self.alpha_indices[0], self.alpha_indices[1]]\n                else:\n                    raise ValueError\n                if verbosity > 0:\n                    
print(f'layer {self.name} using sparse-features alpha with shape {alpha_shape}; unstable size '\n                          f'{self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({list(ref.shape)})')\n            else:\n                alpha_shape = self.shape  # Full alpha.\n                alpha_init = self.init_d\n                if verbosity > 0:\n                    print(f'layer {self.name} using full alpha with shape {alpha_shape}; unstable size '\n                          f'{self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({list(ref.shape)})')\n                self.alpha_indices = None  # Use full alpha.\n        else:\n            alpha_shape = self.shape  # Full alpha.\n            alpha_init = self.init_d\n        # Now we start to create alphas for all start nodes.\n        # When sparse-spec feature is enabled, alpha is created for only\n        # unstable neurons in start node.\n        for start_node in start_nodes:\n            ns, output_shape, unstable_idx = start_node[:3]\n            if isinstance(output_shape, (list, tuple)):\n                if len(output_shape) > 1:\n                    size_s = prod(output_shape)  # Conv layers.\n                else:\n                    size_s = output_shape[0]\n            else:\n                size_s = output_shape\n            # unstable_idx may be a tensor (dense layer or conv layer\n            # with shared alpha), or tuple of 3-d tensors (conv layer with\n            # non-sharing alpha).\n            sparsity = float('inf') if unstable_idx is None else unstable_idx.size(0) if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].size(0)\n            if sparsity <= minimum_sparsity * size_s and self.use_sparse_spec_alpha:\n                # For fully connected layer, or conv layer with shared alpha per channel.\n                # shape is (2, sparse_spec, batch, this_layer_shape)\n                # We create sparse specification dimension, where the spec dimension 
of alpha only includes slopes for unstable neurons in start_node.\n                self.alpha[ns] = torch.empty([self.alpha_size, sparsity + 1, batch_size, *alpha_shape],\n                                             dtype=torch.float, device=ref.device, requires_grad=True)\n                self.alpha[ns].data.copy_(alpha_init.data)  # This will broadcast to (2, sparse_spec) dimensions.\n                if verbosity > 0:\n                    print(f'layer {self.name} start_node {ns} using sparse-spec alpha {list(self.alpha[ns].size())}'\n                          f' with unstable size {sparsity} total_size {size_s} output_shape {output_shape}')\n                # unstable_idx is a list of used neurons (or channels for BoundConv) for the start_node.\n                assert unstable_idx.ndim == 1 if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].ndim == 1\n                # We only need the alpha for the unstable neurons in start_node.\n                indices = torch.arange(1, sparsity + 1, device=alpha_init.device, dtype=torch.long)\n                if isinstance(output_shape, int) or len(output_shape) == 1:\n                    # Fully connected layers, or conv layer in patches mode with partially shared alpha (pixels in the same channel use the same alpha).\n                    self.alpha_lookup_idx[ns] = torch.zeros(size_s, dtype=torch.long, device=alpha_init.device)\n                    # This lookup table maps the unstable_idx to the actual alpha location in self.alpha[ns].\n                    # Note that self.alpha[ns][:,0] is reserved for any unstable neurons that are not found in the lookup table. This usually should not\n                    # happen, unless reference bounds are not properly set.\n                    self.alpha_lookup_idx[ns].data[unstable_idx] = indices\n                else:\n                    # conv layer in matrix mode, or in patches mode but with non-shared alpha. 
The lookup table is 3-d.
                    assert len(output_shape) == 3
                    self.alpha_lookup_idx[ns] = torch.zeros(output_shape, dtype=torch.long, device=alpha_init.device)
                    if isinstance(unstable_idx, torch.Tensor):
                        # Convert the unstable index from flattened 1-d to 3-d (matrix mode).
                        unstable_idx_3d = unravel_index(unstable_idx, output_shape)
                    else:
                        # Patches mode with non-shared alpha; unstable_idx is already 3-d.
                        unstable_idx_3d = unstable_idx
                    # Build look-up table.
                    self.alpha_lookup_idx[ns].data[unstable_idx_3d[0], unstable_idx_3d[1], unstable_idx_3d[2]] = indices
            else:
                # alpha shape is (2, spec, batch, this_layer_shape). "this_layer_shape" may still be sparse.
                self.alpha[ns] = torch.empty([self.alpha_size, size_s, batch_size, *alpha_shape],
                                             dtype=torch.float, device=ref.device, requires_grad=True)
                self.alpha[ns].data.copy_(alpha_init.data)  # This will broadcast to (2, spec) dimensions.
                if verbosity > 0:
                    print(f'layer {self.name} start_node {ns} using full alpha {list(self.alpha[ns].size())} with unstable '
                          f'size {sparsity if unstable_idx is not None else None} total_size {size_s} output_shape {output_shape}')
                # alpha_lookup_idx can be used for checking if sparse alpha is used or not.
                self.alpha_lookup_idx[ns] = None

    def select_alpha_by_idx(self, last_lA, last_uA, unstable_idx, start_node):
        # Each alpha has shape (2, output_shape, batch_size, *relu_node_shape).
        # If slope is shared, output_shape will be 1.
        # The *relu_node_shape might be sparse (sparse-feature alpha), where the non-zero values are indicated by 
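A plain-Python sketch of what `unravel_index` does here for a 3-d `(C, H, W)` output shape (the helper name `unravel3` is hypothetical; the real `unravel_index` is imported elsewhere in this package):

```python
def unravel3(flat, shape):
    # Convert a flat row-major index into (c, h, w) coordinates, so that
    # flat == (c * H + h) * W + w holds.
    C, H, W = shape
    c, rem = divmod(flat, H * W)
    h, w = divmod(rem, W)
    return c, h, w
```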
self.alpha_indices.
        # The out_shape might be sparse (sparse-spec alpha), where the non-zero values are indexed by self.alpha_lookup_idx.
        if unstable_idx is not None:
            # print(f'relu layer {self.name}, start_node {start_node}, unstable_idx {type(unstable_idx)} alpha idx {self.alpha_lookup_idx[start_node.name].size()}')
            if self.alpha_lookup_idx is not None:
                alpha_lookup_idx = self.alpha_lookup_idx[start_node.name]
            else:
                alpha_lookup_idx = None
            if isinstance(unstable_idx, tuple):
                # Start node is a conv node.
                selected_alpha = self.alpha[start_node.name]
                if isinstance(last_lA, Tensor) or isinstance(last_uA, Tensor):
                    # Start node is a conv node but we received tensors as A matrices.
                    # Patches mode converted to matrix, or matrix mode used. Need to select across the spec dimension.
                    # For this node, since it is in matrix mode, the spec dimension is out_c * out_h * out_w.
                    # Shape is [2, spec, batch, *this_layer_shape]
                    if alpha_lookup_idx is None:
                        if self.options['optimize_bound_args'].get('use_shared_alpha', False):
                            # alpha is shared, and its spec dimension is always 1. 
In this case we do not need to select.
                            # selected_alpha will have shape [2, 1, batch, *this_layer_shape]
                            pass
                        else:
                            # alpha is not shared, so it has shape [2, spec, batch, *this_layer_shape].
                            # Reshape the spec dimension to c*h*w so we can select used alphas based on unstable index.
                            # Shape becomes [2, out_c, out_h, out_w, batch, *this_layer_shape]
                            selected_alpha = selected_alpha.view(selected_alpha.size(0), *start_node.output_shape[1:], *selected_alpha.shape[2:])
                            selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]]
                    else:
                        assert alpha_lookup_idx.ndim == 3
                        # We only stored some alphas, and A is also sparse, so the unstable_idx must first be translated to real indices.
                        # alpha shape is (2, sparse_spec_shape, batch_size, *relu_node_shape) where relu_node_shape can also be sparse.
                        # We use sparse-spec alphas. Need to convert unstable_idx[0], unstable_idx[1], unstable_idx[2] using the lookup table.
                        _unstable_idx = alpha_lookup_idx[unstable_idx[0], unstable_idx[1], unstable_idx[2]]
                        selected_alpha = self.non_deter_index_select(selected_alpha, index=_unstable_idx, dim=1)
                else:
                    # Patches mode. 
Alpha must be selected after unfolding, so cannot be done here.\n                    # Selection is deferred to maybe_unfold() using alpha_lookup_idx.\n                    # For partially shared alpha, its shape is (2, out_c, batch_size, *relu_node_shape).\n                    # For full alpha, its shape is (2, out_c*out_h*out_w, batch_size, *relu_node_shape).\n                    # Both the spec dimension and relu_node_shape dimensions can be sparse.\n                    pass\n            elif unstable_idx.ndim == 1:\n                # Start node is a FC node.\n                # Only unstable neurons of the start_node neurons are used.\n                assert alpha_lookup_idx is None or alpha_lookup_idx.ndim == 1\n                if self.options['optimize_bound_args'].get('use_shared_alpha', False):\n                    # Shared alpha is used, all output specs use the same alpha. No selection is needed.\n                    # The spec dim is 1 and will be broadcast.\n                    selected_alpha = self.alpha[start_node.name]\n                else:\n                    _unstable_idx = alpha_lookup_idx[unstable_idx] if alpha_lookup_idx is not None else unstable_idx\n                    selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=_unstable_idx, dim=1)\n            elif unstable_idx.ndim == 2:\n                assert alpha_lookup_idx is None, \"sparse spec alpha has not been implemented yet.\"\n                # Each element in the batch selects different neurons.\n                selected_alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1)\n            else:\n                raise ValueError\n        else:\n            # Spec dimension is dense. 
Alpha must not be created sparsely.
            assert self.alpha_lookup_idx is None or self.alpha_lookup_idx[start_node.name] is None
            selected_alpha = self.alpha[start_node.name]
            alpha_lookup_idx = None
        return selected_alpha, alpha_lookup_idx

    def reconstruct_full_alpha(self, sparse_alpha, full_alpha_shape, alpha_indices):
        full_alpha = torch.zeros(full_alpha_shape, dtype=sparse_alpha.dtype, device=sparse_alpha.device)
        if len(alpha_indices) == 1:
            # Relu after a dense layer.
            full_alpha[:, :, alpha_indices[0]] = sparse_alpha
        elif len(alpha_indices) == 3:
            # Relu after a conv2d layer.
            full_alpha[:, :, alpha_indices[0], alpha_indices[1], alpha_indices[2]] = sparse_alpha
        elif len(alpha_indices) == 2:
            # Relu after a conv1d layer.
            full_alpha[:, :, alpha_indices[0], alpha_indices[1]] = sparse_alpha
        else:
            raise ValueError
        return full_alpha

    def bound_backward(self, last_lA, last_uA, x=None, start_node=None,
                       unstable_idx=None, reduce_bias=True, **kwargs):
        """
        start_node: the name of the layer where the backward bound propagation starts.
                    Can be the output layer or an intermediate layer.
        unstable_idx: indices for the unstable neurons, whose bounds need to be computed.
                      Either a tuple (for patches) or a 1-D tensor.
        """
        lower = x.lower
        upper = x.upper
        # Get element-wise CROWN linear relaxations.
        (upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d,
            lb_upper_d, ub_upper_d, lb_upper_b, ub_upper_b, alpha_lookup_idx) = \
            self._backward_relaxation(last_lA, last_uA, x, start_node, unstable_idx)
        # Saved for calculating the BaBSR branching score.
        self.d = upper_d
        self.lA = last_lA
        # Save for 
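The scatter performed by `reconstruct_full_alpha` can be pictured with a 1-d plain-Python analogue (hypothetical helper, dense-layer case only; positions not listed among the unstable indices keep a zero placeholder):

```python
def reconstruct_full(sparse_vals, full_size, indices):
    # Scatter sparse alpha values back into a dense layout over all
    # neurons of the layer; stable neurons stay at the zero placeholder.
    full = [0.0] * full_size
    for value, i in zip(sparse_vals, indices):
        full[i] = value
    return full
```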
initialization bounds.
        self.init_d = lower_d

        # Choose upper or lower bounds based on the sign of last_A
        def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg):
            if last_A is None:
                return None, 0
            # Obtain the new linear relaxation coefficients based on the signs in last_A.
            same_slope = self.relu_options == "same-slope"
            _A, _bias = multiply_by_A_signs(
                last_A, d_pos, d_neg, b_pos, b_neg, reduce_bias=reduce_bias, same_slope=same_slope)
            if isinstance(last_A, Patches):
                # Save the patch size, which will be used in init_alpha() to determine the number of optimizable parameters.
                A_prod = _A.patches
                if start_node is not None:
                    if last_A.unstable_idx is not None:
                        # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w).
                        self.patch_size[start_node.name] = [
                            last_A.output_shape[1], A_prod.size(1),
                            last_A.output_shape[2], last_A.output_shape[3],
                            A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)]
                    else:
                        # Regular patches.
                        self.patch_size[start_node.name] = A_prod.size()
            return _A, _bias

        ######## A problem with patches mode for cut constraint start ##########
        # There are cases where a node appears in the cut constraint but is not selected by the patches for the output node.
        # Trick: only count the small patches that have all the split node coeffs[ci].sum() equal to coeffs_unfolded[ci][out_h, out_w, -1].sum().
        # We force the corresponding betas to be 0 to disable the effect of these constraints.
        A = last_lA if last_lA is not None else last_uA
        current_layer_shape = 
lower.size()[1:]\n        if self.cut_used and type(A) is Patches:\n            self.cut_module.patch_trick(start_node, self.name, A, current_layer_shape)\n        ######## A problem with patches mode for cut constraint end ##########\n\n        if self.cut_used:\n            if self.leaky_alpha > 0:\n                raise NotImplementedError\n            # propagate postrelu node in cut constraints\n            last_lA, last_uA = self.cut_module.relu_cut(\n                start_node, self.name, last_lA, last_uA, current_layer_shape,\n                unstable_idx, batch_mask=self.inputs[0].alpha_beta_update_mask)\n\n        # In patches mode we might need an unfold.\n        # lower_d, upper_d, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None\n        upper_d = maybe_unfold_patches(upper_d, last_lA if last_lA is not None else last_uA)\n        lower_d = maybe_unfold_patches(lower_d, last_lA if last_lA is not None else last_uA)\n        upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA)\n        lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA)  # for ReLU it is always None; keeping it here for completeness.\n        # ub_lower_d and lb_lower_d might have sparse spec dimension, so they may need alpha_lookup_idx to convert to actual spec dim.\n        ub_lower_d = maybe_unfold_patches(ub_lower_d, last_uA, alpha_lookup_idx=alpha_lookup_idx)\n        ub_upper_d = maybe_unfold_patches(ub_upper_d, last_uA, alpha_lookup_idx=alpha_lookup_idx)\n        # optimizable slope lb_lower_d: spec (only channels in spec layer), batch, current_c, current_w, current_h\n        # patches mode lb_lower_d after unfold: unstable, batch, in_C, H, W\n        lb_lower_d = maybe_unfold_patches(lb_lower_d, last_lA, alpha_lookup_idx=alpha_lookup_idx)\n        lb_upper_d = maybe_unfold_patches(lb_upper_d, last_lA, alpha_lookup_idx=alpha_lookup_idx)\n        # ub_upper_b and lb_upper_b can also be optimizable 
variables, just like ub/lb_upper/lower_d.
        # This is only possible when alpha is optimized in the "same-slope" setting, where we move the linear upper bound together with the lower bound.
        ub_upper_b = maybe_unfold_patches(ub_upper_b, last_uA, alpha_lookup_idx=alpha_lookup_idx)
        lb_upper_b = maybe_unfold_patches(lb_upper_b, last_lA, alpha_lookup_idx=alpha_lookup_idx)

        if self.cut_used:
            assert reduce_bias
            # Here, we create a tuple of 3 masks:
            # unstable_indicators: unstable neuron mask.
            # positive_indicators: previously unstable, now split on z = 1.
            # negative_indicators: previously unstable, now split on z = 0.
            unstable_indicators = (lower < 0) * (upper > 0)
            positive_indicators = ~(lower < 0) & self.gcp_unstable_relu_indicators
            negative_indicators = ~(upper > 0) & self.gcp_unstable_relu_indicators
            relu_indicators = (unstable_indicators, positive_indicators, negative_indicators)
            # Propagate the integer variable of each ReLU neuron (arelu) in cut constraints through the ReLU layer.
            lA, uA, lbias, ubias = self.cut_module.arelu_cut(
                start_node, self.name, last_lA, last_uA, lower_d, upper_d,
                lower_b, upper_b, lb_lower_d, ub_lower_d, relu_indicators, x, self.patch_size,
                current_layer_shape, unstable_idx,
                batch_mask=self.inputs[0].alpha_beta_update_mask)
        else:
            uA, ubias = _bound_oneside(
                last_uA, ub_upper_d if upper_d is None else upper_d,
                ub_lower_d if lower_d is None else lower_d,
                ub_upper_b if ub_upper_b is not None else upper_b, lower_b)
            lA, lbias = _bound_oneside(
                last_lA, lb_lower_d if lower_d is None else lower_d,
                lb_upper_d if upper_d is None else upper_d,
                lower_b, lb_upper_b if lb_upper_b is not None else 
upper_b)\n\n        if self.cut_used:\n            # propagate prerelu node in cut constraints\n            lA, uA = self.cut_module.pre_cut(\n                start_node, self.name, lA, uA, current_layer_shape, unstable_idx,\n                batch_mask=self.inputs[0].alpha_beta_update_mask)\n        self.masked_beta_lower = self.masked_beta_upper = None\n\n        return [(lA, uA)], lbias, ubias\n\n    def _transfer_alpha_lookup_idx(self, alpha_lookup_idx, device=None, dtype=None, non_blocking=False):\n        if alpha_lookup_idx is None:\n            return None\n        alpha_lookup_idx = {spec_name: transfer(idx, device=device, dtype=dtype, non_blocking=non_blocking) if idx is not None else None\n                            for spec_name, idx in alpha_lookup_idx.items()}\n        return alpha_lookup_idx\n\n    def _transfer_alpha_indices(self, alpha_indices, device=None, dtype=None, non_blocking=False):\n        if alpha_indices is None:\n            return None\n        alpha_indices = [transfer(indices, device=device, dtype=dtype, non_blocking=non_blocking) for indices in alpha_indices]\n        return alpha_indices\n\n    def dump_alpha(self, device=None, dtype=None, non_blocking=False):\n        ret = {'alpha': self._transfer_alpha(self.alpha, device=device, dtype=dtype, non_blocking=non_blocking, require_grad=False)}\n        if self.use_sparse_spec_alpha:\n            ret['alpha_lookup_idx'] = self._transfer_alpha_lookup_idx(self.alpha_lookup_idx, device=device, dtype=None, non_blocking=non_blocking)\n        if self.use_sparse_features_alpha:\n            ret['alpha_indices'] = self._transfer_alpha_indices(self.alpha_indices, device=device, dtype=None, non_blocking=non_blocking)\n        return ret\n\n    def restore_alpha(self, alpha, device=None, dtype=None, non_blocking=False):\n        self.alpha = self._transfer_alpha(alpha['alpha'], device=device, dtype=dtype, non_blocking=non_blocking, require_grad=True)\n        if self.use_sparse_spec_alpha:\n    
        self.alpha_lookup_idx = self._transfer_alpha_lookup_idx(alpha['alpha_lookup_idx'], device=device, dtype=None, non_blocking=non_blocking)\n        if self.use_sparse_features_alpha:\n            self.alpha_indices = self._transfer_alpha_indices(alpha['alpha_indices'], device=device, dtype=None, non_blocking=non_blocking)\n\n    def drop_unused_alpha(self, keep_nodes):\n        for spec_name in list(self.alpha.keys()):\n            # If the spec_name is not in keep_nodes, we delete it.\n            if spec_name not in keep_nodes:\n                del self.alpha[spec_name]\n                # if use_sparse_spec_alpha is True, we also delete the alpha_lookup_idx if needed.\n                if self.use_sparse_spec_alpha:\n                    del self.alpha_lookup_idx[spec_name]\n\n        # if there is no alpha left and use_sparse_features_alpha is True,\n        # we also delete the alpha_indices.\n        if not self.alpha and self.use_sparse_features_alpha:\n            self.alpha_indices = None\n\n\nclass BoundRelu(BoundTwoPieceLinear):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        if attr is None:\n            attr = {}\n        self.leaky_alpha = attr.get('alpha', 0)\n        self.alpha_size = 2\n        # Alpha dimension is (2, output_shape, batch, *shape) for ReLU.\n\n    def get_unstable_idx(self):\n        self.alpha_indices = torch.logical_and(\n            self.inputs[0].lower < 0, self.inputs[0].upper > 0).any(dim=0).nonzero(as_tuple=True)\n\n    def clip_alpha(self):\n        for v in self.alpha.values():\n            v.data = torch.clamp(v.data, self.leaky_alpha, 1.)\n\n    def forward(self, x):\n        self.shape = x.shape[1:]\n        if self.flattened_nodes is None:\n            self.flattened_nodes = x[0].reshape(-1).shape[0]\n        if self.leaky_alpha > 0:\n            return F.leaky_relu(x, negative_slope=self.leaky_alpha)\n        
else:
            return F.relu(x)

    def _relu_lower_bound_init(self, upper_k):
        """Return the initial lower-bound slope before optimization."""
        if self.relu_options == "same-slope":
            # the same slope for upper and lower
            lower_k = upper_k
        elif self.relu_options == "zero-lb":
            # Use slope 0 as the lower bound for unstable neurons (any value
            # between 0 and 1 is a valid lower bound for CROWN); slope 1 is
            # used where upper_k >= 1, i.e., the neuron is always active.
            lower_k = (upper_k >= 1.0).to(upper_k)
            if self.leaky_alpha > 0:
                lower_k += (upper_k < 1.0).to(upper_k) * self.leaky_alpha
        elif self.relu_options == "one-lb":
            # Always use slope 1 as lower bound
            lower_k = ((upper_k > self.leaky_alpha).to(upper_k)
                       + (upper_k <= self.leaky_alpha).to(upper_k)
                          * self.leaky_alpha)
        else:
            # adaptive
            if self.leaky_alpha == 0:
                lower_k = (upper_k > 0.5).to(upper_k)
            else:
                # FIXME this may not be optimal for leaky relu
                lower_k = ((upper_k > 0.5).to(upper_k)
                           + (upper_k <= 0.5).to(upper_k) * self.leaky_alpha)
        return lower_k

    def _relu_upper_opt_same_slope(self, lb_lower_d, ub_lower_d, upper_d, lower, upper):
        """
        When the "same-slope" option is enabled in the CROWN-Optimized method, lower_d is taken
        directly from the optimizable parameters, so we force upper_d to be the same as lower_d.

        We want the same-slope upper bound to be as tight as possible, so it should pass through
        one of the vertices of the triangular convex hull of ReLU.

        upper_d contains the slopes of the upper bounds computed with the normal triangle relaxation.
        For a single element:
        - lb_lower_d > upper_d => The same-slope upper bound should pass through the left endpoint of relu;
   
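For plain ReLU (`leaky_alpha == 0`), the slope choices in `_relu_lower_bound_init` reduce to simple scalar rules; a sketch with a hypothetical scalar helper, where `upper_k` is the triangle upper-bound slope:

```python
def lower_slope(upper_k, option):
    # Any slope in [0, 1] yields a sound CROWN lower bound for an
    # unstable ReLU neuron; these heuristics pick a starting point.
    if option == "same-slope":
        return upper_k
    if option == "zero-lb":
        # Slope 0 unless the neuron is always active (upper_k >= 1).
        return 1.0 if upper_k >= 1.0 else 0.0
    if option == "one-lb":
        # Slope 1 unless the neuron is always inactive (upper_k == 0).
        return 1.0 if upper_k > 0.0 else 0.0
    # Default "adaptive": slope 1 when the neuron is mostly positive,
    # which minimizes the area of the relaxation.
    return 1.0 if upper_k > 0.5 else 0.0
```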
     - lb_lower_d < upper_d => The same-slope upper bound should pass through the right endpoint of relu.
        """
        lower_y = F.relu(lower)
        upper_y = F.relu(upper)

        if lb_lower_d is None:
            lb_upper_d = lb_upper_b = None
        else:
            lb_upper_d = lb_lower_d
            b_left = lower_y - lb_upper_d * lower
            b_right = upper_y - lb_upper_d * upper
            use_left_end = (lb_lower_d >= upper_d)
            lb_upper_b = use_left_end * b_left + ~use_left_end * b_right

        if ub_lower_d is None:
            ub_upper_d = ub_upper_b = None
        else:
            ub_upper_d = ub_lower_d
            b_left = lower_y - ub_upper_d * lower
            b_right = upper_y - ub_upper_d * upper
            use_left_end = (ub_lower_d >= upper_d)
            ub_upper_b = use_left_end * b_left + ~use_left_end * b_right

        return lb_upper_d, lb_upper_b, ub_upper_d, ub_upper_b


    def _forward_relaxation(self, x):
        self._init_masks(x)
        self.mask_pos = self.mask_pos.to(x.lower)
        self.mask_both = self.mask_both.to(x.lower)

        upper_k, upper_b = self._relu_upper_bound(
            x.lower, x.upper, self.leaky_alpha)
        self.uw = self.mask_pos + self.mask_both * upper_k
        self.ub = self.mask_both * upper_b

        if self.opt_stage in ['opt', 'reuse']:
            # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape).
            # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape)
            # and we do not need its first two dimensions.
            lower_k = self.alpha['_forward'][0, 0]
        else:
            lower_k = self._relu_lower_bound_init(upper_k)

        # NOTE #FIXME Saved for initialization bounds for optimization.
        # In the backward mode, same-slope bounds are used.
        # But here it is using adaptive bounds which seem to be better
        # for 
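The endpoint choice in `_relu_upper_opt_same_slope` can be checked with a scalar sketch (hypothetical helper; assumes an unstable neuron with `lower < 0 < upper`):

```python
def same_slope_upper_b(slope, lower, upper, upper_d):
    # Shift the shared-slope line so it passes through a vertex of the
    # ReLU triangle: the left endpoint (lower, 0) when slope >= upper_d,
    # otherwise the right endpoint (upper, upper).
    b_left = max(lower, 0.0) - slope * lower
    b_right = max(upper, 0.0) - slope * upper
    return b_left if slope >= upper_d else b_right
```

Either choice keeps the line above ReLU on `[lower, upper]`; picking the vertex on the correct side makes it the tightest line with that slope.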
nn4sys benchmark with loose input bounds. Need confirmation
        # for other cases.
        self.lower_d = lower_k.detach() # saved for initializing optimized bounds

        self.lw = self.mask_both * lower_k + self.mask_pos

    def bound_dynamic_forward(self, x, max_dim=None, offset=0):
        if self.leaky_alpha > 0:
            raise NotImplementedError

        if not hasattr(self, 'upper_k'):
            # x.lower and x.upper remain the same all the time,
            # so the following only needs to be done once.
            self.upper_k, self.upper_b = self._relu_upper_bound(
                x.lower, x.upper, self.leaky_alpha)
            self.upper_b /= 2

            self.device = x.lw.device
            self.batch_size = x.lower.shape[0]
            self.unstable = torch.logical_and(x.lower < 0, x.upper > 0).view(self.batch_size, -1).to(torch.int)
            self.tot_dim = x.tot_dim + int(self.unstable.sum(dim=-1).max())

            self.b_new = self.upper_k * x.lb + self.upper_b

        b_new = self.b_new
        batch_size = self.batch_size
        device = self.device
        unstable = self.unstable
        if x.lw.shape[1]:
            # Compute only when x.lw is not empty
            w_new = self.upper_k.unsqueeze(1) * x.lw
        else:
            w_new = torch.empty_like(x.lw)

        if offset + w_new.shape[1] < x.tot_dim:
            return LinearBound(
                w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=self.tot_dim)

        # Create new variables for unstable ReLU
        index = torch.cumsum(unstable, dim=-1).to(torch.int64)
        index = (index - (offset + w_new.shape[1] - x.tot_dim)).clamp(min=0)
        num_new_dim = int(index.max())
        num_new_dim_actual = min(num_new_dim, max_dim - w_new.shape[1])
        index = index.clamp(max=num_new_dim_actual+1)
        w_unstable = torch.zeros(batch_size, num_new_dim_actual + 2, unstable.size(-1), device=device)
        
x_L_unstable = -torch.ones(batch_size, num_new_dim_actual, device=device)
        x_U_unstable = torch.ones(batch_size, num_new_dim_actual, device=device)
        w_unstable.scatter_(dim=1, index=index.unsqueeze(1), src=self.upper_b.view(batch_size, 1, -1), reduce='add')
        w_unstable = w_unstable[:, 1:-1].view(batch_size, num_new_dim_actual, *w_new.shape[2:])

        w_new = torch.cat([w_new, w_unstable], dim=1)
        x_L_new = torch.cat([x.x_L, x_L_unstable], dim=-1)
        x_U_new = torch.cat([x.x_U, x_U_unstable], dim=-1)

        return LinearBound(
            w_new, b_new, w_new, b_new, x_L=x_L_new, x_U=x_U_new, tot_dim=self.tot_dim)

    def bound_forward(self, dim_in, x):
        self._forward_relaxation(x)
        lb = self.lw * x.lb
        ub = self.uw * x.ub + self.ub
        lw = (self.lw.unsqueeze(1) * x.lw) if x.lw is not None else None
        uw = (self.uw.unsqueeze(1) * x.uw) if x.uw is not None else None
        if lw is not None and not lw.requires_grad:  # lw is None when x.lw is None.
            del self.mask_both, self.mask_pos
            del self.lw, self.uw, self.ub
        return LinearBound(lw, lb, uw, ub)

    @staticmethod
    @torch.jit.script
    def _relu_upper_bound(lb, ub, leaky_alpha: float):
        """Upper bound slope and intercept according to CROWN relaxation."""
        lb_r = lb.clamp(max=0)
        ub_r = ub.clamp(min=0)
        ub_r = torch.max(ub_r, lb_r + 1e-8)
        if leaky_alpha > 0:
            upper_d = (ub_r - leaky_alpha * lb_r) / (ub_r - lb_r)
            upper_b = - lb_r * upper_d + leaky_alpha * lb_r
        else:
            upper_d = ub_r / (ub_r - lb_r)
            upper_b = - lb_r * upper_d
        return upper_d, upper_b

    @staticmethod
    def _relu_mask_alpha(lower, upper, lb_lower_d : Optional[Tensor],
                         ub_lower_d : Optional[Tensor], leaky_alpha : float = 0,
                        ) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]:
        lower_mask = (lower 
>= 0).requires_grad_(False).to(lower.dtype)\n        upper_mask = (upper <= 0).requires_grad_(False)\n        if leaky_alpha > 0:\n            zero_coeffs = False\n        else:\n            zero_coeffs = upper_mask.all()\n        no_mask = (1. - lower_mask) * (1. - upper_mask.to(upper.dtype))\n        if lb_lower_d is not None:\n            lb_lower_d = (\n                torch.clamp(lb_lower_d, min=leaky_alpha, max=1.) * no_mask\n                + lower_mask)\n            if leaky_alpha > 0:\n                lb_lower_d += upper_mask * leaky_alpha\n        if ub_lower_d is not None:\n            ub_lower_d = (\n                torch.clamp(ub_lower_d, min=leaky_alpha, max=1.) * no_mask\n                + lower_mask)\n            if leaky_alpha > 0:\n                ub_lower_d += upper_mask * leaky_alpha\n        return lb_lower_d, ub_lower_d, zero_coeffs\n\n    def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx):\n        # Usage of output constraints requires access to bounds of the previous iteration\n        # (see _clear_and_set_new)\n        if x is not None:\n            lower = x.lower\n            upper = x.upper\n        else:\n            lower = self.lower\n            upper = self.upper\n\n        # Upper bound slope and intercept according to CROWN relaxation.\n        upper_d, upper_b = self._relu_upper_bound(lower, upper, self.leaky_alpha)\n\n        flag_expand = False\n\n        ub_lower_d = lb_lower_d = None\n        ub_upper_d = lb_upper_d = None\n        ub_upper_b = lb_upper_b = None\n\n        lower_b = None  # ReLU does not have lower bound intercept (=0).\n        alpha_lookup_idx = None  # For sparse-spec alpha.\n        if self.opt_stage in ['opt', 'reuse']:\n            # Alpha-CROWN.\n            lower_d = None\n            selected_alpha, alpha_lookup_idx = self.select_alpha_by_idx(\n                last_lA, last_uA, unstable_idx, start_node)\n            # The first dimension is lower/upper intermediate 
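The chord computed in `_relu_upper_bound` has a simple scalar form for an unstable neuron (hypothetical scalar helper; omits the clamping and the `1e-8` stabilizer used by the real code):

```python
def relu_upper(lb, ub, leaky_alpha=0.0):
    # Line through (lb, leaky_alpha * lb) and (ub, ub) for lb < 0 < ub:
    # the tightest sound upper bound of (leaky) ReLU on [lb, ub].
    slope = (ub - leaky_alpha * lb) / (ub - lb)
    intercept = -lb * slope + leaky_alpha * lb
    return slope, intercept
```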
bound.\n            if last_lA is not None:\n                lb_lower_d = selected_alpha[0]\n            if last_uA is not None:\n                ub_lower_d = selected_alpha[1]\n\n            if self.alpha_indices is not None:\n                # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only.\n                # Recover to full alpha first.\n                sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape\n                full_alpha_shape = sparse_alpha_shape[:-1] + self.shape\n                if lb_lower_d is not None:\n                    lb_lower_d = self.reconstruct_full_alpha(\n                        lb_lower_d, full_alpha_shape, self.alpha_indices)\n                if ub_lower_d is not None:\n                    ub_lower_d = self.reconstruct_full_alpha(\n                        ub_lower_d, full_alpha_shape, self.alpha_indices)\n\n            lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d, ub_lower_d, leaky_alpha=self.leaky_alpha)\n            self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = zero_coeffs\n            flag_expand = True  # we already have the spec dimension.\n\n            if self.relu_options == \"same-slope\":\n                # same-slope with optimized lower_d\n                # We force upper_d to be the same as lower_d, and compute the corresponding upper_b\n                lb_upper_d, lb_upper_b, ub_upper_d, ub_upper_b = self._relu_upper_opt_same_slope(lb_lower_d, ub_lower_d, upper_d, lower, upper)\n\n        else:\n            # FIXME: the shape can be incorrect if unstable_idx is not None.\n            # This will cause problem if some ReLU layers are optimized, some are not.\n            lower_d = self._relu_lower_bound_init(upper_d)\n\n        # Upper bound always needs an extra specification dimension, since they only depend on lb and ub.\n        upper_d = upper_d.unsqueeze(0)\n        upper_b = 
upper_b.unsqueeze(0)
        if not flag_expand:
            # FIXME: The following lines seem unused since
            # flag_expand must be True when self.opt_stage is in ['opt', 'reuse'].
            if self.opt_stage in ['opt', 'reuse']:
                # We have different slopes for lower and upper bounds propagation.
                lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None
                ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None

                if self.relu_options == "same-slope":
                    upper_d = None
                    lb_upper_d = lb_upper_d.unsqueeze(0) if last_lA is not None else None
                    lb_upper_b = lb_upper_b.unsqueeze(0) if last_lA is not None else None
                    ub_upper_d = ub_upper_d.unsqueeze(0) if last_uA is not None else None
                    ub_upper_b = ub_upper_b.unsqueeze(0) if last_uA is not None else None
            else:
                lower_d = lower_d.unsqueeze(0)

        if self.opt_stage in ['opt', 'reuse'] and self.relu_options == "same-slope":
            # Remove upper_d and upper_b to avoid confusion later
            upper_d = None
            upper_b = None

        return (upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d,
                lb_upper_d, ub_upper_d, lb_upper_b, ub_upper_b, alpha_lookup_idx)

    def interval_propagate(self, *v):
        h_L, h_U = v[0][0], v[0][1]
        return self.forward(h_L), self.forward(h_U)

    def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"):
        if self.leaky_alpha > 0:
            raise NotImplementedError

        # e.g., last layer input gurobi vars (8,16,16)
        gvars_array = np.array(v[0])
        this_layer_shape = gvars_array.shape
        assert gvars_array.shape == self.output_shape[1:]

        pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1)
        pre_ubs = 
self.inputs[0].upper.cpu().detach().numpy().reshape(-1)\n\n        new_layer_gurobi_vars = []\n        relu_integer_vars = []\n        new_relu_layer_constrs = []\n        # predefined zero variable shared in the whole solver model\n        zero_var = model.getVarByName(\"zero\")\n\n        for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)):\n            pre_ub = pre_ubs[neuron_idx]\n            pre_lb = pre_lbs[neuron_idx]\n\n            if pre_lb >= 0:\n                # ReLU is always passing\n                var = pre_var\n            elif pre_ub <= 0:\n                var = zero_var\n            else:\n                ub = pre_ub\n\n                var = model.addVar(ub=ub, lb=0,\n                                   obj=0,\n                                   vtype=grb.GRB.CONTINUOUS,\n                                   name=f'ReLU{self.name}_{neuron_idx}')\n\n                if model_type == \"mip\" or model_type == \"lp_integer\":\n                    # binary indicator\n                    if model_type == \"mip\":\n                        a = model.addVar(vtype=grb.GRB.BINARY, name=f'aReLU{self.name}_{neuron_idx}')\n                    elif model_type == \"lp_integer\":\n                        a = model.addVar(ub=1, lb=0, vtype=grb.GRB.CONTINUOUS, name=f'aReLU{self.name}_{neuron_idx}')\n                    relu_integer_vars.append(a)\n\n                    new_relu_layer_constrs.append(\n                        model.addConstr(pre_var - pre_lb * (1 - a) >= var,\n                                        name=f'ReLU{self.name}_{neuron_idx}_a_0'))\n                    new_relu_layer_constrs.append(\n                        model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1'))\n                    new_relu_layer_constrs.append(\n                        model.addConstr(pre_ub * a >= var, name=f'ReLU{self.name}_{neuron_idx}_a_2'))\n\n                elif model_type == \"lp\":\n                    new_relu_layer_constrs.append(\n     
                   model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_0'))\n                    new_relu_layer_constrs.append(model.addConstr(\n                        pre_ub * pre_var - (pre_ub - pre_lb) * var >= pre_ub * pre_lb,\n                        name=f'ReLU{self.name}_{neuron_idx}_a_1'))\n\n                else:\n                    print(f\"gurobi model type {model_type} not supported!\")\n\n            new_layer_gurobi_vars.append(var)\n\n        new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist()\n        if model_type in [\"mip\", \"lp_integer\"]:\n            self.integer_vars = relu_integer_vars\n        self.solver_vars = new_layer_gurobi_vars\n        self.solver_constrs = new_relu_layer_constrs\n        model.update()\n\n    def build_gradient_node(self, grad_upstream):\n        if self.leaky_alpha > 0:\n            raise NotImplementedError\n        node_grad = ReLUGrad()\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        # An extra node is needed to consider the state of ReLU activation\n        grad_extra_nodes = [self.inputs[0]]\n        return [(node_grad, grad_input, grad_extra_nodes)]\n\n    def get_split_mask(self, lower, upper, input_index):\n        assert input_index == 0\n        return torch.logical_and(lower < 0, upper > 0)\n\n    # Return unstable mask to determine which neuron should use constraints_solving concretization\n    def get_unstable_mask(self, lower, upper):\n        \"\"\"Return a mask to indicate if each neuron is unstable.\n\n        0: Stable (linear) neuron; 1: unstable (nonlinear) neuron.\n        \"\"\"\n        return torch.logical_and(lower < 0, upper > 0)\n\n    # Return heuristic to select which neuron should use constraints_solving concretization\n    def compute_bound_improvement_heuristics(self, lower, upper):\n        \"\"\"Return a heuristic score for each lower-upper bound pair.\n        It indicates the possible bound 
improvement for each neuron.\n        We will then choose whether a neuron's bound needs to be further tightened based on the heuristic.\n        \"\"\"\n        # This heuristic is actually BaBSR-interception-only.\n        return (-lower * upper).clamp(min=0) / (upper - lower + 1e-8).abs()\n\nclass BoundLeakyRelu(BoundRelu):\n    pass\n\n\nclass BoundSign(BoundActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.splittable = True\n\n    def forward(self, x):\n        return torch.sign(x)\n\n    def bound_relax(self, x, init=False):\n        if init:\n            self.init_linear_relaxation(x)\n        mask_0 = torch.logical_and(x.lower == 0, x.upper == 0)\n        mask_pos_0 = torch.logical_and(x.lower == 0, x.upper > 0)\n        mask_neg_0 = torch.logical_and(x.lower < 0, x.upper == 0)\n        mask_pos = x.lower > 0\n        mask_neg = x.upper < 0\n        mask_both = torch.logical_not(torch.logical_or(torch.logical_or(\n            mask_0, torch.logical_or(mask_pos, mask_pos_0)),\n            torch.logical_or(mask_neg, mask_neg_0)))\n        self.add_linear_relaxation(mask=mask_0, type='lower',\n            k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=0)\n        self.add_linear_relaxation(mask=mask_0, type='upper',\n            k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=0)\n\n        self.add_linear_relaxation(mask=mask_pos_0, type='lower',\n            k=1/x.upper.clamp(min=1e-8), x0=torch.zeros_like(x.upper), y0=0)\n        self.add_linear_relaxation(mask=torch.logical_or(mask_pos_0, mask_pos), type='upper',\n            k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=1)\n\n        self.add_linear_relaxation(mask=torch.logical_or(mask_neg_0, mask_neg), type='lower',\n            k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=-1)\n        self.add_linear_relaxation(mask=mask_neg_0, type='upper',\n           
 k=-1/x.lower.clamp(max=-1e-8), x0=torch.zeros_like(x.upper), y0=0)\n\n        self.add_linear_relaxation(mask=mask_pos, type='lower', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=1)\n        self.add_linear_relaxation(mask=mask_neg, type='upper', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=-1)\n        self.add_linear_relaxation(mask=mask_both, type='lower', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=-1)\n        self.add_linear_relaxation(mask=mask_both, type='upper', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=1)\n\n\nclass SignMergeFunction_loose(torch.autograd.Function):\n    # Modified SignMerge operator.\n    # Change its backward function so that the \"gradient\" can be used for pgd attack\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        output = torch.sign(torch.sign(input) + 1e-1)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        eps = 5     # should be carefully chosen\n        input, = ctx.saved_tensors\n        grad_input = grad_output.clone()\n        grad_input[abs(input) >= eps] = 0\n        grad_input /= eps\n        return grad_input\n\nclass SignMergeFunction_tight(torch.autograd.Function):\n    # Modified SignMerge operator.\n    # Change its backward function so that the \"gradient\" can be used for pgd attack\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        output = torch.sign(torch.sign(input) + 1e-1)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        eps = 0.1     # should be carefully chosen\n        input, = ctx.saved_tensors\n        grad_input = grad_output.clone()\n        grad_input[abs(input) >= eps] = 0\n        grad_input /= eps\n        return grad_input\n\n\nclass BoundSignMerge(BoundTwoPieceLinear):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        
super().__init__(attr, inputs, output_index, options)\n        self.alpha_size = 4\n        self.loose_function = SignMergeFunction_loose\n        self.tight_function = SignMergeFunction_tight\n        self.signmergefunction = self.tight_function    # default\n\n    def get_unstable_idx(self):\n        self.alpha_indices = torch.logical_and(\n            self.inputs[0].lower < 0, self.inputs[0].upper >= 0).any(dim=0).nonzero(as_tuple=True)\n\n    def forward(self, x):\n        self.shape = x.shape[1:]\n        return self.signmergefunction.apply(x)\n\n    def _mask_alpha(self, lower, upper, lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d):\n        lower_mask = (lower >= 0.).requires_grad_(False).to(lower.dtype)\n        upper_mask = (upper < 0.).requires_grad_(False).to(upper.dtype)\n        no_mask = 1. - (lower_mask + upper_mask)\n        if lb_lower_d is not None:\n            lb_lower_d = torch.min(lb_lower_d, 2/upper.clamp(min=1e-8))\n            lb_lower_d = torch.clamp(lb_lower_d, min=0) * no_mask\n            lb_upper_d = torch.min(lb_upper_d, -2/lower.clamp(max=-1e-8))\n            lb_upper_d = torch.clamp(lb_upper_d, min=0) * no_mask\n        if ub_lower_d is not None:\n            ub_lower_d = torch.min(ub_lower_d, 2/upper.clamp(min=1e-8))\n            ub_lower_d = torch.clamp(ub_lower_d, min=0) * no_mask\n            ub_upper_d = torch.min(ub_upper_d, -2/lower.clamp(max=-1e-8))\n            ub_upper_d = torch.clamp(ub_upper_d, min=0) * no_mask\n        return lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d\n\n    def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx):\n        if x is not None:\n            lower, upper = x.lower, x.upper\n        else:\n            lower, upper = self.lower, self.upper\n\n        flag_expand = False\n        ub_lower_d = lb_lower_d = lb_upper_d = ub_upper_d = None\n        alpha_lookup_idx = None  # For sparse-spec alpha.\n        if self.opt_stage in ['opt', 'reuse']:\n            # 
Alpha-CROWN.\n            upper_d = lower_d = None\n            selected_alpha, alpha_lookup_idx = self.select_alpha_by_idx(\n                last_lA, last_uA, unstable_idx, start_node)\n            # The first dimension is lower/upper intermediate bound.\n            if last_lA is not None:\n                lb_lower_d = selected_alpha[0]\n                lb_upper_d = selected_alpha[2]\n            if last_uA is not None:\n                ub_lower_d = selected_alpha[1]\n                ub_upper_d = selected_alpha[3]\n\n            if self.alpha_indices is not None:\n                # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only.\n                # Recover to full alpha first.\n                sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape\n                full_alpha_shape = sparse_alpha_shape[:-1] + self.shape\n                if lb_lower_d is not None:\n                    lb_lower_d = self.reconstruct_full_alpha(\n                        lb_lower_d, full_alpha_shape, self.alpha_indices)\n                    lb_upper_d = self.reconstruct_full_alpha(\n                        lb_upper_d, full_alpha_shape, self.alpha_indices)\n                if ub_lower_d is not None:\n                    ub_lower_d = self.reconstruct_full_alpha(\n                        ub_lower_d, full_alpha_shape, self.alpha_indices)\n                    ub_upper_d = self.reconstruct_full_alpha(\n                        ub_upper_d, full_alpha_shape, self.alpha_indices)\n\n            lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d = self._mask_alpha(lower, upper,\n                lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d)\n            flag_expand = True  # we already have the spec dimension.\n        else:\n            lower_d = torch.zeros_like(upper, requires_grad=True)\n            upper_d = torch.zeros_like(upper, requires_grad=True)\n\n        mask_pos = (x.lower >= 
0.).requires_grad_(False).to(x.lower.dtype)\n        mask_neg = (x.upper < 0.).requires_grad_(False).to(x.upper.dtype)\n        lower_b = (-1 * (1 - mask_pos) + mask_pos).unsqueeze(0)\n        upper_b = (-1 * mask_neg + (1 - mask_neg)).unsqueeze(0)\n\n        # Upper bound always needs an extra specification dimension, since they only depend on lb and ub.\n        if not flag_expand:\n            if self.opt_stage in ['opt', 'reuse']:\n                # We have different slopes for lower and upper bounds propagation.\n                lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None\n                ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None\n                lb_upper_d = lb_upper_d.unsqueeze(0) if last_lA is not None else None\n                ub_upper_d = ub_upper_d.unsqueeze(0) if last_uA is not None else None\n            else:\n                lower_d = lower_d.unsqueeze(0)\n                upper_d = upper_d.unsqueeze(0)\n        return (upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d,\n            lb_upper_d, ub_upper_d, None, None, alpha_lookup_idx)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n\n        # e.g., last layer input gurobi vars (8,16,16)\n        gvars_array = np.array(v[0])\n        this_layer_shape = gvars_array.shape\n        assert gvars_array.shape == self.output_shape[1:]\n\n        pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1)\n        pre_ubs = self.inputs[0].upper.cpu().detach().numpy().reshape(-1)\n\n        new_layer_gurobi_vars = []\n        integer_vars = []\n        layer_constrs = []\n        # predefined zero variable shared in the whole solver model\n        one_var = model.getVarByName(\"one\")\n        neg_one_var = model.getVarByName(\"neg_one\")\n\n        for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)):\n            pre_ub = pre_ubs[neuron_idx]\n            pre_lb = 
pre_lbs[neuron_idx]\n\n            if pre_lb >= 0:\n                var = one_var\n            elif pre_ub < 0:\n                var = neg_one_var\n            else:\n                ub = pre_ub\n\n                var = model.addVar(ub=ub, lb=pre_lb,\n                                   obj=0,\n                                   vtype=grb.GRB.CONTINUOUS,\n                                   name=f'Sign{self.name}_{neuron_idx}')\n\n                a = model.addVar(vtype=grb.GRB.BINARY, name=f'aSign{self.name}_{neuron_idx}')\n                integer_vars.append(a)\n\n                layer_constrs.append(\n                    model.addConstr(pre_lb * a <= pre_var, name=f'Sign{self.name}_{neuron_idx}_a_0'))\n                layer_constrs.append(\n                    model.addConstr(pre_ub * (1 - a) >= pre_var, name=f'Sign{self.name}_{neuron_idx}_a_1'))\n                layer_constrs.append(\n                    model.addConstr(var == 1 - 2*a, name=f'Sign{self.name}_{neuron_idx}_a_2'))\n\n            new_layer_gurobi_vars.append(var)\n\n        new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist()\n        if model_type in [\"mip\", \"lp_integer\"]:\n            self.integer_vars = integer_vars\n        self.solver_vars = new_layer_gurobi_vars\n        self.solver_constrs = layer_constrs\n        model.update()\n\n\ndef relu_grad(preact):\n    return (preact > 0).float()\n\n\nclass ReLUGradOp(Function):\n    \"\"\" Local gradient of ReLU.\n\n    Not including multiplication with gradients from other layers.\n    \"\"\"\n    @staticmethod\n    def symbolic(_, g, g_relu, g_relu_rev, preact):\n        return _.op('grad::Relu', g, g_relu, g_relu_rev, preact).setType(g.type())\n\n    @staticmethod\n    def forward(ctx, g, g_relu, g_relu_rev, preact):\n        return g * relu_grad(preact)\n\n\nclass ReLUGrad(Module):\n    def forward(self, g, preact):\n        g_relu = F.relu(g)\n        g_relu_rev = -F.relu(-g)\n        return 
ReLUGradOp.apply(g, g_relu, g_relu_rev, preact)\n\n# FIXME reuse the function from auto_LiRPA.patches\ndef _maybe_unfold(d_tensor, last_A):\n    if d_tensor is None:\n        return None\n\n    #[batch, out_dim, in_c, in_H, in_W]\n    d_shape = d_tensor.size()\n\n    # Reshape to 4-D tensor to unfold.\n    #[batch, out_dim*in_c, in_H, in_W]\n    d_tensor = d_tensor.view(d_shape[0], -1, *d_shape[-2:])\n    # unfold the slope matrix as patches.\n    # Patch shape is [batch, out_h, out_w, out_dim*in_c, H, W].\n    d_unfolded = inplace_unfold(\n        d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride,\n        padding=last_A.padding)\n    # Reshape to [batch, out_H, out_W, out_dim, in_C, H, W]\n    d_unfolded_r = d_unfolded.view(\n        *d_unfolded.shape[:3], d_shape[1], *d_unfolded.shape[-2:])\n    if last_A.unstable_idx is not None:\n        if len(last_A.unstable_idx) == 4:\n            # [batch, out_H, out_W, out_dim, in_C, H, W]\n            # to [out_H, out_W, batch, out_dim, in_C, H, W]\n            d_unfolded_r = d_unfolded_r.permute(1, 2, 0, 3, 4, 5, 6)\n            d_unfolded_r = d_unfolded_r[\n                last_A.unstable_idx[2], last_A.unstable_idx[3]]\n        else:\n            raise NotImplementedError\n    # For sparse patches, the shape after unfold is\n    # (unstable_size, batch_size, in_c, H, W).\n    # For regular patches, the shape after unfold is\n    # (spec, batch, out_h, out_w, in_c, H, W).\n    return d_unfolded_r\n\n\nclass BoundReluGrad(BoundActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.requires_input_bounds = [3]\n        self.recurjac = options.get('recurjac', False)\n\n    @staticmethod\n    def relu_grad(preact):\n        return (preact > 0).float()\n\n    def forward(self, g, g_relu, g_relu_rev, preact):\n        if g.ndim == preact.ndim + 1:\n            preact = preact.unsqueeze(1)\n     
   return g * relu_grad(preact)\n\n    def interval_propagate(self, *v):\n        g_lower, g_upper = v[0]\n        preact_lower, preact_upper = v[3]\n        relu_grad_lower = relu_grad(preact_lower)\n        relu_grad_upper = relu_grad(preact_upper)\n        if g_lower.ndim == relu_grad_lower.ndim + 1:\n            relu_grad_lower = relu_grad_lower.unsqueeze(1)\n            relu_grad_upper = relu_grad_upper.unsqueeze(1)\n        lower = torch.min(g_lower * relu_grad_lower, g_lower * relu_grad_upper)\n        upper = torch.max(g_upper * relu_grad_lower, g_upper * relu_grad_upper)\n        return lower, upper\n\n    def bound_backward(self, last_lA, last_uA, g, g_relu, g_relu_rev, preact,\n                       **kwargs):\n        mask_active = (preact.lower > 0).float()\n        mask_inactive = (preact.upper < 0).float()\n        mask_unstable = 1 - mask_active - mask_inactive\n\n        if self.recurjac and self.inputs[0].perturbed:\n            upper_grad = preact.upper >= 0\n            lower_interval = self.inputs[0].lower * upper_grad\n            upper_interval = self.inputs[0].upper * upper_grad\n        else:\n            lower_interval = upper_interval = None\n\n        def _bound_oneside(last_A, pos_interval=None, neg_interval=None):\n            if last_A is None:\n                return None, None, None, 0\n\n            if isinstance(last_A, torch.Tensor):\n                if self.recurjac and self.inputs[0].perturbed:\n                    mask_unstable_grad = (\n                        (self.inputs[0].lower < 0) * (self.inputs[0].upper > 0))\n                    last_A_unstable = last_A * mask_unstable_grad\n                    bias = (\n                        last_A_unstable.clamp(min=0) * pos_interval\n                        + last_A_unstable.clamp(max=0) * neg_interval)\n                    bias = bias.reshape(\n                        bias.shape[0], bias.shape[1], -1).sum(dim=-1)\n                    last_A = last_A * 
torch.logical_not(mask_unstable_grad)\n                else:\n                    bias = 0\n                A = last_A * mask_active\n                A_pos = last_A.clamp(min=0) * mask_unstable\n                A_neg = last_A.clamp(max=0) * mask_unstable\n                return A, A_pos, A_neg, bias\n            elif isinstance(last_A, Patches):\n                last_A_patches = last_A.patches\n\n                if self.recurjac and self.inputs[0].perturbed:\n                    mask_unstable_grad = (\n                        (self.inputs[0].lower < 0) * (self.inputs[0].upper > 0))\n                    mask_unstable_grad_unfold = _maybe_unfold(\n                        mask_unstable_grad, last_A)\n                    last_A_unstable = (\n                        last_A.to_matrix(mask_unstable_grad.shape)\n                        * mask_unstable_grad)\n                    bias = (\n                        last_A_unstable.clamp(min=0) * pos_interval\n                        + last_A_unstable.clamp(max=0) * neg_interval)\n                    # FIXME Clean up patches. 
This implementation does not seem\n                    # to support general shapes.\n                    assert bias.ndim == 5\n                    bias = bias.sum(dim=[-1, -2, -3]).view(-1, 1)\n                    last_A_patches = (\n                        last_A_patches\n                        * torch.logical_not(mask_unstable_grad_unfold))\n                else:\n                    bias = 0\n\n                # need to unfold mask_active and mask_unstable\n                # [batch, 1, in_c, in_H, in_W]\n                mask_active_unfold = _maybe_unfold(mask_active, last_A)\n                mask_unstable_unfold = _maybe_unfold(mask_unstable, last_A)\n                # [spec, batch, 1, in_c, in_H, in_W]\n\n                mask_active_unfold = mask_active_unfold.expand(last_A.shape)\n                mask_unstable_unfold = mask_unstable_unfold.expand(last_A.shape)\n\n                A = Patches(\n                    last_A_patches * mask_active_unfold,\n                    last_A.stride, last_A.padding, last_A.shape,\n                    last_A.identity, last_A.unstable_idx, last_A.output_shape)\n\n                A_pos_patches = last_A_patches.clamp(min=0) * mask_unstable_unfold\n                A_neg_patches = last_A_patches.clamp(max=0) * mask_unstable_unfold\n\n                A_pos = Patches(\n                    A_pos_patches, last_A.stride, last_A.padding, last_A.shape,\n                    last_A.identity, last_A.unstable_idx, last_A.output_shape)\n                A_neg = Patches(\n                    A_neg_patches, last_A.stride, last_A.padding, last_A.shape,\n                    last_A.identity, last_A.unstable_idx, last_A.output_shape)\n\n                return A, A_pos, A_neg, bias\n\n        lA, lA_pos, lA_neg, lbias = _bound_oneside(\n            last_lA, pos_interval=lower_interval, neg_interval=upper_interval)\n        uA, uA_pos, uA_neg, ubias = _bound_oneside(\n            last_uA, pos_interval=upper_interval, neg_interval=lower_interval)\n\n   
     return (\n            [(lA, uA), (lA_neg, uA_pos), (lA_pos, uA_neg), (None, None)],\n            lbias, ubias)\n"
  },
  {
    "path": "auto_LiRPA/operators/reshape.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom torch.nn import Module\nfrom .base import *\nfrom ..patches import Patches, patches_to_matrix\nfrom .linear import BoundLinear\nfrom .constant import BoundConstant\n\n\nclass BoundReshape(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        # It can be set to `view`, so that `view` instead of `reshape` will be used.\n        self.option = options.get('reshape', 'reshape')\n\n    def forward(self, x, shape):\n        shape = list(shape)\n        for i in range(len(shape)):\n            if shape[i] == -1:\n                shape[i] = prod(x.shape) // int(prod(shape[:i]) * prod(shape[(i + 1):]))\n        self.shape = shape\n        if self.option == 'view':\n            return x.contiguous().view(shape)\n        else:\n            return x.reshape(shape)\n\n    def bound_backward(self, last_lA, last_uA, x, shape, **kwargs):\n        def _bound_oneside(A):\n            if A is None:\n                return None\n            if type(A) == Patches:\n                # output shape should be [batch, in_c, in_H, in_W] since it's followed by Conv2d\n                assert len(self.output_shape) == 4\n                if type(self.inputs[0]) == BoundLinear:\n                    # Save the shape and it will be converted to matrix in Linear layer.\n                    return A.create_similar(input_shape=self.output_shape)\n                if A.unstable_idx is None:\n                    patches = A.patches\n                    # non-sparse: [batch, out_dim, out_c, out_H, out_W, out_dim, in_c, H, W]\n                    # [out_dim*out_c, batch, out_H, out_W, out_dim*in_c, H, W]\n                    # expected next_A shape [batch, spec, in_c, in_H , in_W].\n                    next_A = patches_to_matrix(\n                  
      pieces=patches, input_shape=self.output_shape,\n                        stride=A.stride, padding=A.padding)\n                else:\n                    # sparse: [spec, batch, in_c, patch_H, patch_W] (specs depends on the number of unstable neurons).\n                    patches = A.patches\n                    # expected next_A shape [batch, spec, input_c, in_H, in_W].\n                    next_A = patches_to_matrix(\n                        pieces=patches, input_shape=self.output_shape,\n                        stride=A.stride, padding=A.padding, \n                        output_shape=A.output_shape, unstable_idx=A.unstable_idx)\n                # Reshape it to [spec, batch, *input_shape]  (input_shape is the shape before Reshape operation).\n                return next_A.transpose(0, 1).reshape(-1, A.shape[1], *self.input_shape[1:])\n            else:\n                return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:])\n        #FIXME check reshape or view\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, x, shape):\n        batch_size = x.lw.shape[0]\n        lw = x.lw.reshape(batch_size, dim_in, *self.shape[1:])\n        uw = x.uw.reshape(batch_size, dim_in, *self.shape[1:])\n        lb = x.lb.reshape(batch_size, *self.shape[1:])\n        ub = x.ub.reshape(batch_size, *self.shape[1:])\n        return LinearBound(lw, lb, uw, ub)\n\n    def bound_dynamic_forward(self, x, shape, max_dim=None, offset=0):\n        w = x.lw.reshape(x.lw.shape[0], x.lw.shape[1], *self.shape[1:])\n        b = x.lb.reshape(x.lb.shape[0], *self.shape[1:])\n        return LinearBound(w, b, w, b, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim)\n\n    def interval_propagate(self, *v):\n        return Interval.make_interval(\n            self.forward(v[0][0], v[1][0]),\n            self.forward(v[0][1], v[1][0]), v[0])\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", 
solver_pkg=\"gurobi\"):\n        if isinstance(v[0], Tensor):\n            self.solver_vars = self.forward(*v)\n            return\n        gvar_array = np.array(v[0])\n        gvar_array = gvar_array.reshape(v[1].detach().cpu().numpy())[0]\n        self.solver_vars = gvar_array.tolist()\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = ReshapeGrad()\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        return [(node_grad, grad_input, [])]\n\n\nclass BoundUnsqueeze(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n        if 'axes' in attr:\n            self.axes = attr['axes']\n            assert len(self.axes) == 1\n            self.axes = self.axes[0]\n        else:\n            self.axes = None\n\n    def forward(self, *x):\n        data = x[0]\n        if self.axes is not None:\n            axes = self.axes\n        else:\n            axes = x[1].item()\n            self.axes = axes\n        return data.unsqueeze(axes)\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        if self.axes is not None:\n            axes = self.make_axis_non_negative(self.axes, 'output')\n        else:\n            axes = self.make_axis_non_negative(x[1].value.item(), 'output')\n        if axes == 0:\n            raise ValueError(\"Unsqueezing with axes == 0 is not allowed\")\n        else:\n            def squeeze_A(last_A):\n                if type(last_A) == Patches:\n                    return Patches(\n                        last_A.patches.squeeze(axes - 5),\n                        last_A.stride, last_A.padding, last_A.shape,\n                        last_A.identity, last_A.unstable_idx, last_A.output_shape)\n                elif last_A is not None:\n                    return last_A.squeeze(axes + 1)\n                else:\n                    return None\n            lA = 
squeeze_A(last_lA)\n            uA = squeeze_A(last_uA)\n            return [(lA, uA), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, *x):\n        axes = self.make_axis_non_negative(\n            self.axes if self.axes is not None else x[1].lb.item(), 'output')\n        x = x[0]\n        if len(self.input_shape) == 0:\n            lw, lb = x.lw.unsqueeze(1), x.lb.unsqueeze(0)\n            uw, ub = x.uw.unsqueeze(1), x.ub.unsqueeze(0)\n        else:\n            lw, lb = x.lw.unsqueeze(axes + 1), x.lb.unsqueeze(axes)\n            uw, ub = x.uw.unsqueeze(axes + 1), x.ub.unsqueeze(axes)\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(v[0])\n\n    def build_gradient_node(self, grad_upstream):\n        axes = self.make_axis_non_negative(self.axes, 'output')\n        if axes == 0:\n            raise ValueError(\"Unsqueezing with axes == 0 is not allowed\")\n        node_grad = UnsqueezeGrad(axes)\n        return [(node_grad, (grad_upstream,), [])]\n\n\nclass UnsqueezeGrad(Module):\n    def __init__(self, axes):\n        super().__init__()\n        self.axes = axes\n\n    def forward(self, grad_last):\n        return grad_last.squeeze(self.axes + 1)\n\n\nclass BoundExpand(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n\n    def forward(self, x, y):\n        y = y.clone()\n        assert y.ndim == 1\n        n, m = x.ndim, y.shape[0]\n        assert n <= m\n        for i in range(n):\n            if y[m - n + i] == 1:\n                y[m - n + i] = x.shape[i]\n            else:\n                assert x.shape[i] == 1 or x.shape[i] == y[m - n + i]\n        return x.expand(*list(y))\n    \n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        assert not 
self.is_input_perturbed(1)\n        # Although torch.expand supports prepending dimensions,\n        # bound computation doesn't since we must always keep\n        # the batch dimension at the beginning\n        assert (\n            len(x[0].output_shape) == len(self.output_shape)\n        ), \"BoundExpand with changed ndim is not supported by bound computation\"\n        n = len(self.output_shape)\n\n        def _bound_oneside(A):\n            if A is None:\n                return None\n            dims_to_sum = [i + 1 for i in range(1, n)\n                           if x[0].output_shape[i] == 1 and A.shape[i + 1] > 1]\n            return A.sum(dim=dims_to_sum, keepdim=True) if dims_to_sum else A\n\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, *x):\n        # This doesn't support the general Expand operator.\n        # It is only for the Expand operator converted from torch.repeat, where\n        # it should just be an identity operator.\n        shape = x[1].lb\n        if not (len(x[0].lb.shape) == len(shape) and (shape == 1).all()):\n            raise NotImplementedError(\"General onnx::Expand is not supported\")\n        return x[0]\n\n    def build_gradient_node(self, grad_upstream):\n        shape = self.inputs[1].forward_value\n        if not (len(self.inputs[0].output_shape) == len(shape) and (shape == 1).all()):\n            raise NotImplementedError(\"General onnx::Expand is not supported\")\n        return [(ExpandGrad(shape), (grad_upstream,), []), None]\n\n\nclass ExpandGrad(Module):\n    # This doesn't support the general Expand operator.\n    # It is only for the Expand operator converted from torch.repeat, where\n    # it should just be an identity operator.\n    def __init__(self, shape):\n        super().__init__()\n        self.shape = shape\n\n    def forward(self, grad_last):\n        return grad_last\n\n\nclass BoundSqueeze(Bound):\n    def 
__init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n        if 'axes' in attr:\n            self.axes = attr['axes']\n            assert len(self.axes) == 1\n            self.axes = self.axes[0]\n        else:\n            self.axes = None\n\n    def forward(self, *x):\n        data = x[0]\n        if self.axes is not None:\n            axes = self.axes\n        else:\n            axes = x[1].item()\n        return data.squeeze(axes)\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        if self.axes is not None:\n            axes = self.axes\n        else:\n            axes = self.make_axis_non_negative(x[1].value.item(), 'input')\n        if axes == 0:\n            raise ValueError(\"Squeezing with axes == 0 is not allowed\")\n        return [(last_lA.unsqueeze(axes + 1) if last_lA is not None else None,\n                 last_uA.unsqueeze(axes + 1) if last_uA is not None else None),\n                (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, *x):\n        if self.axes is not None:\n            axes = self.axes\n        else:\n            axes = self.make_axis_non_negative(x[1].lb.item(), 'input')\n        x = x[0]\n        return LinearBound(\n            x.lw.squeeze(axes + 1),\n            x.lb.squeeze(axes),\n            x.uw.squeeze(axes + 1),\n            x.ub.squeeze(axes)\n        )\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(v[0])\n\n\nclass BoundFlatten(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n        self.axis = attr['axis']\n\n    def forward(self, x):\n        return torch.flatten(x, self.axis)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        
def _bound_oneside(A):\n            if A is None:\n                return None\n            return A.reshape(A.shape[0], A.shape[1], *self.input_shape[1:])\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def bound_dynamic_forward(self, x, max_dim=None, offset=0):\n        w = torch.flatten(x.lw, self.axis + 1)\n        b = torch.flatten(x.lb, self.axis)\n        return LinearBound(w, b, w, b, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim)\n\n    def bound_forward(self, dim_in, x):\n        self.axis = self.make_axis_non_negative(self.axis)\n        assert self.axis > 0\n        return LinearBound(\n            torch.flatten(x.lw, self.axis + 1),\n            torch.flatten(x.lb, self.axis),\n            torch.flatten(x.uw, self.axis + 1),\n            torch.flatten(x.ub, self.axis),\n        )\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        # e.g., v[0] input shape (16, 8, 8) => output shape (1024,)\n        self.solver_vars = np.array(v[0]).reshape(-1).tolist()\n        model.update()\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = ReshapeGrad()\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        return [(node_grad, grad_input, [])]\n\n\nclass BoundATenUnflatten(BoundReshape):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n\n    def forward(self, x, dim, sizes):\n        self.dim = dim.item()\n        self.sizes = sizes.tolist()\n        fval = torch.unflatten(x, self.dim, self.sizes)\n        self.shape = fval.shape\n        return fval\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        A, lbias, ubias = super().bound_backward(last_lA, last_uA, x[0], shape=None, **kwargs)\n        # One more input for Unflatten\n        A.append((None, None))\n        return A, lbias, ubias\n\n    def 
bound_forward(self, dim_in, *x):\n        return super().bound_forward(dim_in=dim_in, x=x[0], shape=None)\n\n    def bound_dynamic_forward(self, *x, max_dim=None, offset=0):\n        return super().bound_dynamic_forward(x=x[0], shape=None, max_dim=max_dim, offset=offset)\n\n    def interval_propagate(self, x, dim, sizes):\n        return Interval.make_interval(\n            self.forward(x[0], dim[0], sizes[0]),\n            self.forward(x[1], dim[0], sizes[0]), x)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        shape = torch.tensor([v[0].shape[0], *self.shape[1:]])\n        return super().build_solver((v[0], shape), model=model, C=C, model_type=model_type, solver_pkg=solver_pkg)\n\n\nclass ReshapeGrad(Module):\n    def forward(self, grad_last, inp):\n        if grad_last.numel() == inp.numel():\n            return grad_last.reshape(grad_last.shape[0], *inp.shape[1:])\n        else:\n            return grad_last.reshape(*grad_last.shape[:2], *inp.shape[1:])\n\n\nclass BoundTranspose(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.perm = attr['perm']\n        self.perm_inv_inc_one = [-1] * (len(self.perm) + 1)\n        self.perm_inv_inc_one[0] = 0\n        for i in range(len(self.perm)):\n            self.perm_inv_inc_one[self.perm[i] + 1] = i + 1\n        self.use_default_ibp = True\n        self.ibp_intermediate = True\n\n    def forward(self, x):\n        return x.permute(*self.perm)\n\n    def bound_backward(self, last_lA, last_uA, x, **kwargs):\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            return last_A.permute(self.perm_inv_inc_one)\n\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0\n\n    def bound_forward(self, dim_in, x):\n        if self.input_shape[0] != 1:\n            perm = [0] + [(p + 1) 
for p in self.perm]\n        else:\n            assert (self.perm[0] == 0)\n            perm = [0, 1] + [(p + 1) for p in self.perm[1:]]\n        lw, lb = x.lw.permute(*perm), x.lb.permute(self.perm)\n        uw, ub = x.uw.permute(*perm), x.ub.permute(self.perm)\n\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(*v)\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = TransposeGrad(self.perm_inv_inc_one)\n        grad_input = (grad_upstream,)\n        return [(node_grad, grad_input, [])]\n\n\nclass TransposeGrad(Module):\n    def __init__(self, perm_inv):\n        super().__init__()\n        self.perm_inv = perm_inv\n\n    def forward(self, grad_last):\n        return grad_last.permute(*self.perm_inv)\n"
  },
  {
    "path": "auto_LiRPA/operators/resize.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Resize operator \"\"\"\nimport torch\n\nfrom .base import *\nimport numpy as np\nfrom .solver_utils import grb\nfrom ..patches import unify_shape, create_valid_mask, is_shape_used\n\n\nclass BoundResize(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        # only support nearest mode for now\n        assert attr[\"mode\"] == \"nearest\"\n        self.mode = attr[\"mode\"]\n        self.scale_factor = None\n\n    def forward(self, x, size=None, scale_factor=None):\n        # currently, forwarding size is not supported.\n        assert isinstance(size, torch.Tensor) and len(size.tolist()) == 0\n        # currently, we only support enlarging the tensor size by an integer factor.\n        assert len(scale_factor.tolist()) == 4 and np.array([tmp.is_integer() and tmp > 0 for tmp in scale_factor.tolist()]).all()\n        assert (scale_factor[0:2].to(torch.long) == 1).all(), 'only support resize on the H and W dim'\n        self.scale_factor = tuple([int(tmp) for tmp in scale_factor][2:])\n        if x.ndim == 4:\n            final = F.interpolate(\n                x, None, self.scale_factor, mode=self.mode)\n        else:\n            raise NotImplementedError(\n                \"Interpolation in 3D or interpolation with parameter size has not been implemented.\")\n        return final\n\n    def interval_propagate(self, *v):\n        l, u = zip(*v)\n        return Interval.make_interval(self.forward(*l), self.forward(*u), v[0])\n\n    def bound_forward(self, dim_in, *inp):\n        x = inp[0]\n        lw, lb, uw, ub = x.lw, x.lb, x.uw, x.ub\n        new_lw, new_lb, new_uw, new_ub = \\\n            torch.nn.functional.upsample(lw, scale_factor=([1] * (lw.ndim - 4)) + list(self.scale_factor), mode=self.mode), \\\n         
   torch.nn.functional.upsample(lb, scale_factor=([1] * (lb.ndim - 4)) + list(self.scale_factor), mode=self.mode), \\\n            torch.nn.functional.upsample(uw, scale_factor=([1] * (uw.ndim - 4)) + list(self.scale_factor), mode=self.mode), \\\n            torch.nn.functional.upsample(ub, scale_factor=([1] * (ub.ndim - 4)) + list(self.scale_factor), mode=self.mode)\n        return LinearBound(\n            lw = new_lw,\n            lb = new_lb,\n            uw = new_uw,\n            ub = new_ub)\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            assert type(last_A) is Patches or last_A.ndim == 5\n            # in case the kernel size cannot be divided by scale_factor, we round up the shape\n            split_shape = tuple((torch.tensor(\n                last_A.shape)[-2:] / torch.tensor(self.scale_factor)).ceil().to(torch.long).tolist())\n            new_shape = last_A.shape[:-2] + split_shape\n            if not type(last_A) is Patches:\n                # classical mode is simple to handle by\n                # sum the grid elements by using avg_pool2d with divisor_override=1\n                return torch.nn.functional.avg_pool2d(\n                    last_A.reshape(-1, *last_A.shape[-2:]), kernel_size=self.scale_factor, stride=self.scale_factor,\n                    divisor_override=1).reshape(new_shape)\n            else:\n                # for patches mode\n                assert type(last_A) is Patches\n                assert self.scale_factor[0] == self.scale_factor[1]\n                if self.scale_factor[0] == 1:\n                    # identity upsampling\n                    return last_A\n                if isinstance(last_A.padding, int) and last_A.padding % self.scale_factor[0] == 0 and last_A.stride % self.scale_factor[0] == 0 and last_A.inserted_zeros == 0:\n                    # an easy case where patch sliding windows 
coincide with the nearest sampling scaling windows\n                    # in this case, we divide each patch into sub-matrices of size scale_factor,\n                    # and sum up each sub-matrix separately\n                    # print(last_A.shape)\n                    padding = last_A.shape[-1] % self.scale_factor[-1]\n                    new_patches = torch.nn.functional.pad(last_A.patches, (0, padding, 0, padding))\n                    new_patches = torch.nn.functional.avg_pool2d(\n                        new_patches.reshape(-1, *new_patches.shape[-2:]), kernel_size=self.scale_factor,\n                        stride=self.scale_factor, divisor_override=1).reshape(new_shape)\n                    return last_A.create_similar(patches=new_patches,\n                                                 stride=last_A.stride//self.scale_factor[0],\n                                                 padding=last_A.padding//self.scale_factor[0],\n                                                 )\n                else:\n                    \"\"\"\n                        The following part is created and mainly maintained by Linyi\n                        Time complexity = O(A.numel * scale_factor + outH * kerH + outW * kerW + A.numel * kerH * kerW)\n                        With Python loop complexity = O(outH + outW + kerH * kerW * scale_factor^2)\n                    \"\"\"\n                    # preparation: unify shape\n                    if last_A.padding:\n                        padding = unify_shape(last_A.padding)\n                    else:\n                        padding = (0,0,0,0)\n                    # padding = (left, right, top, bottom)\n                    if last_A.output_padding:\n                        output_padding = unify_shape(last_A.output_padding)\n                    else:\n                        output_padding = (0,0,0,0)\n                    # output_padding = (left, right, top, bottom)\n\n                    \"\"\"\n                        
Step 0: filter out valid entries that map to real cells of input\n                        E.g., with inserted zeros = 2, [x 0 0 x 0 0 x]. Only \"x\" cells are kept\n                        Borrowed from one_d generation from Conv patches\n                    \"\"\"\n                    one_d_unfolded_r = create_valid_mask(self.output_shape,\n                                                         last_A.patches.device,\n                                                         last_A.patches.dtype,\n                                                         last_A.patches.shape[-2:],\n                                                         last_A.stride,\n                                                         last_A.inserted_zeros,\n                                                         last_A.padding,\n                                                         last_A.output_padding,\n                                                         last_A.unstable_idx)\n                    patches = last_A.patches * one_d_unfolded_r\n\n                    \"\"\"\n                        Step 1: compute the coordinate mapping from patch coordinates to input coordinates\n                        Time complexity: O(outH + outW)\n                        note: last_A shape is [outC, batch, outH, outW, inC, kerH, kerW]\n                        We create H_idx_map and W_idx_map of shape [outH] and [outW] respectively,\n                        recording the start idx of row/column for patches at position [.,.,.,.,.,i,j]\n                        in H_idx_map[i] and W_idx_map[j]\n                    \"\"\"\n                    ker_size_h, ker_size_w = last_A.shape[-2], last_A.shape[-1]\n                    if last_A.unstable_idx is None:\n                        # we can get the real output H and W from shape[2] and shape[3]\n                        out_h, out_w = last_A.shape[2], last_A.shape[3]\n                    else:\n                        # it seems to be stored in 
output_shape\n                        out_h, out_w = last_A.output_shape[-2], last_A.output_shape[-1]\n                    h_idx_map = torch.arange(0, out_h) * last_A.stride - padding[-2] + output_padding[-2] * last_A.stride\n                    h_idx_map = h_idx_map.to(last_A.device)\n                    w_idx_map = torch.arange(0, out_w) * last_A.stride - padding[-4] + output_padding[-4] * last_A.stride\n                    w_idx_map = w_idx_map.to(last_A.device)\n\n                    r\"\"\"\n                        Step 2: compute the compressed patches\n                        Time complexity: O(outH * kerH + outW * kerW + A.numel * kerH * kerW)\n                        Upsampling needs to sum up A cells in scale_factor * scale_factor sub-blocks\n                        Example: when scale factor is 2\n                        [ a b c d\n                          e f g h    ---\\    [ a+b+e+f c+d+g+h\n                          i j k l    ---/      i+j+m+n k+l+o+p]\n                          m n o p]\n                        In patches mode, we need to sum up cells in each patch accordingly.\n                        The summing mechanism could change at different locations.\n                        For each spatial dimension, we create a binary sum_mask tensor [outH, ker_size_h, new_ker_size_h]\n                            to select the cells to sum up\n                        Example:\n                        For [a b c d] -> [a+b c+d], with 3x3 patch covering [0..2] and [2..4].\n                        The first patch needs to sum to [a+b c]; the second patch needs to sum to [b c+d]\n                        So we have sum_mask\n                        [ for patch 1: [[1, 1, 0],    (first entry sums up index 0 and 1)\n                                        [0, 0, 1]]^T, (second entry sums up index 2)\n                          for patch 2: [[1, 0, 0],    (first entry sums up index 0)\n                                        [0, 1, 1]]^T  (second entry sums up 
index 1 and 2)\n                        ]\n                        With the mask, we can now compute the new patches with einsum:\n                            [outC, batch, outH, outW, inC, kerH, kerW] * [outH, kerH, new_kerH] -> [outC, batch, outH, outW, inC, new_kerH, kerW]\n                    \"\"\"\n                    tot_scale_fac = ((last_A.inserted_zeros + 1) * self.scale_factor[0], (last_A.inserted_zeros + 1) * self.scale_factor[1])\n                    new_ker_size_h, new_ker_size_w = \\\n                        (tot_scale_fac[0] + ker_size_h - 2) // tot_scale_fac[0] + 1, \\\n                        (tot_scale_fac[1] + ker_size_w - 2) // tot_scale_fac[1] + 1\n\n                    min_h_idx, max_h_idx = h_idx_map[0], h_idx_map[-1] + ker_size_h\n                    shrank_h_idx = (torch.arange(min_h_idx, max_h_idx) + last_A.inserted_zeros).div(tot_scale_fac[0], rounding_mode='floor')\n                    if last_A.unstable_idx is None:\n                        # with nonsparse index, create full-sized sum mask for rows\n                        ker_h_indexer = torch.arange(0, ker_size_h).to(last_A.device)\n                        sum_mask_h = torch.zeros(last_A.shape[2], ker_size_h, new_ker_size_h).to(last_A.device)\n                        for i in range(last_A.shape[2]):\n                            sum_mask_h[i, ker_h_indexer, \\\n                                shrank_h_idx[h_idx_map[i] - min_h_idx: h_idx_map[i] - min_h_idx + ker_size_h] - shrank_h_idx[h_idx_map[i] - min_h_idx]] = 1\n                            # set zero to those in padding area\n                            padding_place_mask = (ker_h_indexer + h_idx_map[i] < 0)\n                            sum_mask_h[i, padding_place_mask] = 0\n                    else:\n                        # with sparse index, create sparse sum mask\n                        sum_mask_h = torch.zeros(last_A.shape[0], ker_size_h, new_ker_size_h).to(last_A.device)\n\n                        row_nos = 
last_A.unstable_idx[1]\n                        unstable_loc_indexer = torch.arange(0, row_nos.shape[0]).to(last_A.device)\n\n                        for k in range(ker_size_h):\n                            place_in_new_ker = shrank_h_idx[h_idx_map[row_nos] - min_h_idx + k] - shrank_h_idx[h_idx_map[row_nos] - min_h_idx]\n                            sum_mask_h[unstable_loc_indexer, k, place_in_new_ker] = 1\n                            # set zero to those in padding area\n                            padding_place_mask = (h_idx_map[row_nos] + k < 0)\n                            sum_mask_h[padding_place_mask, k] = 0\n\n                    min_w_idx, max_w_idx = w_idx_map[0], w_idx_map[-1] + ker_size_w\n                    shrank_w_idx = (torch.arange(min_w_idx, max_w_idx) + last_A.inserted_zeros).div(tot_scale_fac[1], rounding_mode='floor')\n                    if last_A.unstable_idx is None:\n                        # with nonsparse index, create full-sized sum mask for columns\n                        ker_w_indexer = torch.arange(0, ker_size_w).to(last_A.device)\n                        sum_mask_w = torch.zeros(last_A.shape[3], ker_size_w, new_ker_size_w).to(last_A.device)\n                        for i in range(last_A.shape[3]):\n                            sum_mask_w[i, ker_w_indexer, \\\n                                shrank_w_idx[w_idx_map[i] - min_w_idx: w_idx_map[i] - min_w_idx + ker_size_w] - shrank_w_idx[w_idx_map[i] - min_w_idx]] = 1\n                            # set zero to those in padding area\n                            padding_place_mask = (ker_w_indexer + w_idx_map[i] < 0)\n                            sum_mask_w[i, padding_place_mask] = 0\n                    else:\n                        # with sparse index, create sparse sum mask\n                        sum_mask_w = torch.zeros(last_A.shape[0], ker_size_w, new_ker_size_w).to(last_A.device)\n\n                        col_nos = last_A.unstable_idx[2]\n                        unstable_loc_indexer = 
torch.arange(0, col_nos.shape[0]).to(last_A.device)\n\n                        for k in range(ker_size_w):\n                            place_in_new_ker = shrank_w_idx[w_idx_map[col_nos] - min_w_idx + k] - shrank_w_idx[w_idx_map[col_nos] - min_w_idx]\n                            sum_mask_w[unstable_loc_indexer, k, place_in_new_ker] = 1\n                            # set zero to those in padding area\n                            padding_place_mask = (w_idx_map[col_nos] + k < 0)\n                            sum_mask_w[padding_place_mask, k] = 0\n\n                    if last_A.unstable_idx is None:\n                        # nonsparse aggregation\n                        new_patches = torch.einsum(\"ObhwIij,hix,wjy->ObhwIxy\", patches, sum_mask_h, sum_mask_w)\n                    else:\n                        # sparse aggregation\n                        new_patches = torch.einsum(\"NbIij,Nix,Njy->NbIxy\", patches, sum_mask_h, sum_mask_w)\n\n                    \"\"\"\n                        Step 3: broadcasting the new_patches by repeating elements,\n                            since later we would need to apply insert_zeros\n                        For example, scale_factor = 3, repeat patch [a,b] to [a,a,a,b,b,b]\n                        Time complexity: O(A.numel * scale_factor)\n                    \"\"\"\n                    ext_new_ker_size_h, ext_new_ker_size_w = \\\n                        new_ker_size_h * tot_scale_fac[0], new_ker_size_w * tot_scale_fac[1]\n                    ext_new_patches = torch.zeros(list(new_patches.shape[:-2]) +\n                                                  [ext_new_ker_size_h, ext_new_ker_size_w], device=new_patches.device)\n                    for i in range(ext_new_ker_size_h):\n                        for j in range(ext_new_ker_size_w):\n                            ext_new_patches[..., i, j] = new_patches[..., i // tot_scale_fac[0], j // tot_scale_fac[1]]\n\n                    \"\"\"\n                        Step 4: 
compute new padding, stride, shape, insert_zeros, and output_padding\n                    \"\"\"\n                    # stride should be the same after upsampling, stride is an integer\n                    # new_stride = last_A.stride\n                    # padding can change much, the beginning should extend by (scale - 1) entries,\n                    # the ending should extend by (ext_new_ker_size - ker_size) entries\n                    # padding = (left, right, top, bottom)\n                    new_padding = (padding[0] + (self.scale_factor[1] - 1) * (last_A.inserted_zeros + 1),\n                                   padding[1] + ext_new_ker_size_w - ker_size_w,\n                                   padding[2] + (self.scale_factor[0] - 1) * (last_A.inserted_zeros + 1),\n                                   padding[3] + ext_new_ker_size_h - ker_size_h)\n                    if new_padding[0] == new_padding[1] and new_padding[1] == new_padding[2] and new_padding[2] == new_padding[3]:\n                        # simplify to an int\n                        new_padding = new_padding[0]\n                    # only support uniform scaling on H and W now, i.e., self.scale_factor[0] == self.scale_factor[1]\n                    inserted_zeros = tot_scale_fac[0] - 1\n                    # output padding seems not to change\n                    # new_output_padding = last_A.output_padding\n\n                    \"\"\"\n                        Package and create\n                    \"\"\"\n                    # sparse tensor doesn't support einsum which is necessary for subsequent computes, so deprecated\n                    # if inserted_zeros >= 3:\n                    #     # mask unused cells\n                    #     input_shape = list(self.output_shape)\n                    #     input_shape[-2], input_shape[-1] = input_shape[-2] // self.scale_factor[-2], \\\n                    #         input_shape[-1] // self.scale_factor[-1]\n                    #     one_unfolded = 
create_valid_mask(input_shape, ext_new_patches.device,\n                    #                                       ext_new_patches.dtype, ext_new_patches.shape[-2:],\n                    #                                       last_A.stride, inserted_zeros, new_padding,\n                    #                                       last_A.output_padding,\n                    #                                       last_A.unstable_idx if last_A.unstable_idx else None)\n                    #     ext_new_patches = (ext_new_patches * one_unfolded).to_sparse()\n\n                    # print the shape change after upsampling, if needed\n                    # print(f'After upsampling, '\n                    #       f'{last_A.patches.shape} (pad={padding}, iz={last_A.inserted_zeros}, s={last_A.stride}) -> '\n                    #       f'{ext_new_patches.shape} (pad={new_padding}, iz={inserted_zeros}, s={last_A.stride})')\n                    ret_patches_A = last_A.create_similar(patches=ext_new_patches,\n                                                          padding=new_padding,\n                                                          inserted_zeros=inserted_zeros)\n                    if self.input_shape[-2] < ret_patches_A.shape[-2] and self.input_shape[-1] < ret_patches_A.shape[-1] \\\n                            and not is_shape_used(ret_patches_A.output_padding):\n                        # using matrix mode could be more memory efficient\n                        ret_matrix_A = ret_patches_A.to_matrix(self.input_shape)\n                        # print(f'After upsampling, to_matrix: {ret_matrix_A.shape}')\n                        ret_matrix_A = ret_matrix_A.transpose(0, 1)\n                        return ret_matrix_A\n                    else:\n                        return ret_patches_A\n\n        last_lA = _bound_oneside(last_lA)\n        last_uA = _bound_oneside(last_uA)\n        return [(last_lA, last_uA), (None, None), (None, None)], 0, 0\n\n"
  },
  {
    "path": "auto_LiRPA/operators/rnn.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"RNN.\"\"\"\nfrom .base import *\n\n\nclass BoundRNN(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.complex = True\n        self.output_index = output_index\n        raise NotImplementedError(\n            'torch.nn.RNN is not supported at this time. '\n            'Please implement your RNN with torch.nn.RNNCell and a manual for-loop. '\n            'See an example of LSTM: '\n            'https://github.com/Verified-Intelligence/auto_LiRPA/blob/10a9b30/examples/sequence/lstm.py#L9')\n\n    def forward(self, x, weight_input, weight_recurrent, bias, sequence_length, initial_h):\n        assert (torch.sum(torch.abs(initial_h)) == 0)\n\n        self.input_size = x.shape[-1]\n        self.hidden_size = weight_input.shape[-2]\n\n        class BoundRNNImpl(nn.Module):\n            def __init__(self, input_size, hidden_size,\n                         weight_input, weight_recurrent, bias, output_index):\n                super().__init__()\n\n                self.input_size = input_size\n                self.hidden_size = hidden_size\n\n                self.cell = torch.nn.RNNCell(\n                    input_size=input_size,\n                    hidden_size=hidden_size\n                )\n\n                self.cell.weight_ih.data.copy_(weight_input.squeeze(0).data)\n                self.cell.weight_hh.data.copy_(weight_recurrent.squeeze(0).data)\n                self.cell.bias_ih.data.copy_((bias.squeeze(0))[:hidden_size].data)\n                self.cell.bias_hh.data.copy_((bias.squeeze(0))[hidden_size:].data)\n\n                self.output_index = output_index\n\n            def forward(self, x, hidden):\n                length = x.shape[0]\n                outputs = []\n                for i in 
range(length):\n                    hidden = self.cell(x[i, :], hidden)\n                    outputs.append(hidden.unsqueeze(0))\n                outputs = torch.cat(outputs, dim=0)\n\n                if self.output_index == 0:\n                    return outputs\n                else:\n                    return hidden\n\n        self.model = BoundRNNImpl(\n            self.input_size, self.hidden_size,\n            weight_input, weight_recurrent, bias,\n            self.output_index)\n        self.input = (x, initial_h)\n\n        return self.model(*self.input)"
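The `NotImplementedError` above directs users to unroll recurrences manually with `torch.nn.RNNCell`. A minimal sketch of that pattern (the module name, sizes, and zero initial hidden state are illustrative, not part of the library):

```python
import torch
import torch.nn as nn


class UnrolledRNN(nn.Module):
    """An RNN unrolled with RNNCell so each time step is an explicit graph op."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = nn.RNNCell(input_size, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x):
        # x has shape (seq_len, batch, input_size); start from a zero hidden state.
        h = torch.zeros(x.shape[1], self.hidden_size)
        outputs = []
        for t in range(x.shape[0]):
            h = self.cell(x[t], h)
            outputs.append(h.unsqueeze(0))
        # Return all hidden states and the final one, mirroring BoundRNNImpl's
        # output_index convention.
        return torch.cat(outputs, dim=0), h


model = UnrolledRNN(input_size=4, hidden_size=8)
out, h_n = model(torch.randn(5, 2, 4))
```

Because every step is an ordinary `RNNCell` call, the resulting graph contains only operators that the bound propagation already supports.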
  },
  {
    "path": "auto_LiRPA/operators/s_shaped.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"S-shaped base class, activation functions, and relevant ops.\"\"\"\nimport torch\nfrom torch.nn import Module\nfrom torch.autograd import Function\nfrom .base import *\nfrom .activation_base import BoundOptimizableActivation\n\n\nclass BoundSShaped(BoundOptimizableActivation):\n    \"\"\"\n    Base class for computing output bounds of globally and partially s-shaped nonlinear functions\n    (e.g., sigmoid, tanh, sin, cos) over given input intervals.\n    \"\"\"\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None, activation=(None, None, None), precompute=False):\n        super().__init__(attr, inputs, output_index, options)\n        if options is None:\n            options = {}\n        self.splittable = True\n        self.inverse_s_shape = False\n        self.ibp_intermediate = True\n\n        self.activation = activation\n        self.activation_name = activation[0]\n\n        self.act_func = activation[1]\n        self.d_act_func = activation[2]\n\n        self.step_pre = 0.01\n        if precompute:\n            self.precompute_relaxation(self.act_func, self.d_act_func)\n            self.precompute_dfunc_values(self.act_func, self.d_act_func)\n        # TODO make them configurable when implementing a general nonlinear activation.\n        # Neurons whose gap between pre-activation bounds is smaller than this\n        # threshold will be masked and don't need branching.\n        self.split_min_gap = 1e-2  # 1e-4\n        # Neurons whose pre-activation bounds don't overlap with this range\n        # are considered as stable (with values either 0 or 1) and don't need\n        # branching.\n        self.split_range = (self.range_l, self.range_u)\n        # The initialization will be adjusted if the pre-activation bounds are too loose.\n        self.loose_threshold = 
options.get(self.activation_name, {}).get(\n            'loose_threshold', None)\n        self.convex_concave = None\n        self.activation_bound_option = options.get('activation_bound_option', 'adaptive')\n\n        self.inflections = [0.]\n        self.extremes = []\n        self.sigmoid_like_mask = None\n\n        # FIXME: Smoothness enhancement for s-shaped functions should be enabled by default.\n        # This enhancement makes the linear bounds change smoothly between different cases.\n        # We provide this option only to reproduce results from previous papers.\n        self.disable_smoothness_enhancement = options.get(\n            's_shaped_disable_smoothness_enhancement', False)\n\n    def opt_init(self):\n        super().opt_init()\n        self.tp_both_lower_init = {}\n        self.tp_both_upper_init = {}\n\n    def branch_input_domain(self, lb, ub):\n        # For functions that are only partially s-shaped, such as sin and cos, the non-s-shaped intervals are identified\n        # and masked here. sigmoid_like_mask marks the strictly s-shaped intervals, and branch_mask marks the non-s-\n        # shaped ones. 
For globally s-shaped functions like tanh and sigmoid, sigmoid_like_mask stores all 1s and\n        # branch_mask stores all 0s.\n        self.sigmoid_like_mask = torch.ones_like(lb, dtype=torch.bool)\n        self.branch_mask = torch.zeros_like(lb, dtype=torch.bool)\n\n    def _init_opt_parameters_impl(self, size_spec, name_start, num_params=10):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        shape = l.shape\n        # Alpha dimension is (num_params, output_shape, batch, *shape) for the s-shaped activation function.\n        alpha = torch.empty(num_params, size_spec, *shape, device=l.device)\n        alpha.data[:4] = (l + u) / 2\n        alpha.data[4:6] = self.tp_both_lower_init[name_start]\n        alpha.data[6:8] = self.tp_both_upper_init[name_start]\n        if num_params > 8:\n            alpha.data[8:] = 0\n        return alpha\n\n    @torch.no_grad()\n    def precompute_relaxation(self, func, dfunc, x_limit=500):\n        \"\"\"\n        This function precomputes the tangent lines that will be used as\n        lower/upper bounds for S-shaped functions centered at 0 along the x-axis.\n        \"\"\"\n        self.x_limit = x_limit\n        self.num_points_pre = int(self.x_limit / self.step_pre)\n        max_iter = 100\n\n        logger.debug('Precomputing relaxation for %s (pre-activation limit: %f)',\n                     self.__class__.__name__, x_limit)\n\n        def check_lower(upper, d):\n            \"\"\"Given two points upper, d (d <= upper),\n            check if the slope at d will be less than f(upper) at upper.\"\"\"\n            k = dfunc(d)\n            # Return True if the slope is a lower bound.\n            return k * (upper - d) + func(d) <= func(upper)\n\n        def check_upper(lower, d):\n            \"\"\"Given two points lower, d (d >= lower),\n            check if the slope at d will be greater than f(lower) at lower.\"\"\"\n           
 k = dfunc(d)\n            # Return True if the slope is an upper bound.\n            return k * (lower - d) + func(d) >= func(lower)\n\n        # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function.\n        upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device)\n        r = torch.zeros_like(upper)\n        # Initial guess, the tangent line is at -1.\n        l = -torch.ones_like(upper)\n        while True:\n            # Check if the tangent line at the guessed point is a lower bound at f(upper).\n            checked = check_lower(upper, l).int()\n            # If the initial guess is not small enough, double it (-2, -4, etc.).\n            l = checked * l + (1 - checked) * (l * 2)\n            if checked.sum() == l.numel():\n                break\n        # Now we have a starting point l whose tangent line is guaranteed to be a lower bound at f(upper).\n        # We want to further tighten this bound by moving it closer to 0.\n        for _ in range(max_iter):\n            # Binary search.\n            m = (l + r) / 2\n            checked = check_lower(upper, m).int()\n            l = checked * m + (1 - checked) * l\n            r = checked * r + (1 - checked) * m\n        # At upper, the tangent line drawn at point l is guaranteed to lower bound the function.\n        self.d_lower = l.clone()\n\n        # Do the same again:\n        # Given a lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function.\n        lower = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device)\n        l = torch.zeros_like(upper)\n        r = torch.ones_like(upper)\n        while True:\n            checked = check_upper(lower, r).int()\n            r = checked * r + (1 - checked) * (r * 2)\n            if checked.sum() == l.numel():\n                break\n        for _ in range(max_iter):\n            m = (l + r) / 2\n            checked = 
check_upper(lower, m).int()\n            l = (1 - checked) * m + checked * l\n            r = (1 - checked) * r + checked * m\n        self.d_upper = r.clone()\n\n        logger.debug('Done')\n\n    def precompute_dfunc_values(self, func, dfunc, x_limit=500):\n        \"\"\"\n        This function precomputes a list of values for dfunc.\n        \"\"\"\n        upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device)\n        self.dfunc_values = dfunc(upper)\n\n    def forward(self, x):\n        return self.act_func(x)\n\n    def retrieve_from_precompute(self, precomputed_d, input_bound, default_d):\n        \"\"\"\n        precomputed_d: The precomputed tangent points.\n        input_bound: The input bound of the function.\n        default_d: If input bound goes out of precompute range, we will use default_d.\n        All of the inputs should share the same shape.\n        \"\"\"\n\n        # divide input bound into number of steps to the inflection point (at x=0)\n        index = torch.max(\n            torch.zeros(input_bound.numel(), dtype=torch.long, device=input_bound.device),\n            (input_bound / self.step_pre).to(torch.long).reshape(-1)\n        ) + 1\n        # If precompute range is smaller than input, tangent points will be taken from default.\n        # The default value should be a guaranteed bound\n        if index.max() >= precomputed_d.numel():\n            warnings.warn(f'Pre-activation bounds are too loose for {self}')\n            return torch.where(\n                (index < precomputed_d.numel()).view(input_bound.shape),\n                torch.index_select(\n                    precomputed_d, 0, index.clamp(max=precomputed_d.numel() - 1)\n                ).view(input_bound.shape),\n                default_d,\n            ).view(input_bound.shape)\n        else:\n            return torch.index_select(precomputed_d, 0, index).view(input_bound.shape)\n\n    def generate_d_lower_upper(self, lower, upper):\n        
# Indices of neurons with input upper bound >=0, whose optimal slope to\n        # lower bound the function was pre-computed.\n        # Note that neurons whose input lower bound is also >=0\n        # will be masked later.\n        d_lower = self.retrieve_from_precompute(self.d_lower, upper, lower)\n\n        # Indices of neurons with lower bound <=0, whose optimal slope to upper\n        # bound the function was pre-computed.\n        d_upper = self.retrieve_from_precompute(self.d_upper, -lower, upper)\n        return d_lower, d_upper\n\n    def retrieve_d_from_k(self, k, func):\n        d_indices = torch.searchsorted(torch.flip(self.dfunc_values, [0]), k, right=False)\n        d_indices = self.num_points_pre - d_indices + 4\n        d_left = d_indices * self.step_pre\n        d_right = d_left + self.step_pre\n        y_left = func(d_left)\n        y_right = func(d_right)\n        k_left = self.dfunc_values[d_indices]\n        k_right = self.dfunc_values[torch.clamp(d_indices+1, max=self.dfunc_values.shape[0]-1)]\n        # We choose the intersection of the two tangent lines.\n        d_return = (k_left * d_left - k_right * d_right - y_left + y_right) / (k_left - k_right).clamp(min=1e-8)\n        mask_almost_the_same = abs(k_left - k_right) < 1e-5\n        d_return[mask_almost_the_same] = d_left[mask_almost_the_same]\n        y_d = k_left * (d_return - d_left) + y_left\n        return d_return, y_d\n\n    def bound_relax_impl_same_slope(self, x, func, dfunc):\n        lower, upper = x.lower, x.upper\n        y_l, y_u = func(lower), func(upper)\n        # k_direct is the slope of the line directly connecting (lower, func(lower)) and (upper, func(upper)).\n        k_direct = k = (y_u - y_l) / (upper - lower).clamp(min=1e-8)\n        mask_almost_the_same = abs(upper - lower) < 1e-4\n        k_direct[mask_almost_the_same] = dfunc(lower)[mask_almost_the_same]\n\n        mask_direct_lower = k_direct <= dfunc(lower)\n        mask_direct_upper = k_direct <= dfunc(upper)\n\n   
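The doubling-then-bisection search used by `precompute_relaxation` can be illustrated in plain Python for a single endpoint. The scalar sigmoid helpers below are illustrative stand-ins for the library's vectorized `func`/`dfunc` and precomputed tables; only the search logic mirrors the method:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def d_sigmoid(x):
    s = sigmoid(x)
    return s * (1.0 - s)


def check_lower(upper, d):
    # Tangent line at d: y = f(d) + f'(d) * (x - d). It is a sound lower
    # bound at `upper` if its value there does not exceed f(upper).
    return d_sigmoid(d) * (upper - d) + sigmoid(d) <= sigmoid(upper)


def tangent_point_for_lower_bound(upper, max_iter=100):
    # Start at d = -1 and double until the tangent at d is a valid lower
    # bound at `upper`; then bisect toward 0 to tighten the bound.
    l, r = -1.0, 0.0
    while not check_lower(upper, l):
        l *= 2.0
    for _ in range(max_iter):
        m = 0.5 * (l + r)
        if check_lower(upper, m):
            l = m  # still sound: move the tangent point closer to 0
        else:
            r = m
    return l


d = tangent_point_for_lower_bound(2.0)
```

The returned `d` plays the role of one entry of `self.d_lower`: the tightest tangent point (up to bisection precision) whose tangent line still lower-bounds the sigmoid at the queried upper endpoint.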
     # We now find the tangent line with the same slope as k_direct.\n        # In the case of \"mask_direct_lower (or upper)\", there should be only one possible tangent point\n        # at which we obtain the same slope within the interval [lower, upper].\n        d, y_d = self.retrieve_d_from_k(k_direct, func)\n        d[lower + upper < 0] *= -1  # This is the case \"direct upper\"\n        y_d[lower + upper < 0] = 2 * func(torch.tensor(0)) - y_d[lower + upper < 0]\n        d_clamped = torch.clamp(d, min=lower, max=upper)\n        y_d[d_clamped != d] = func(d_clamped[d_clamped != d])\n        self.add_linear_relaxation(\n            mask=mask_direct_lower, type='lower', k=k_direct, x0=lower, y0=y_l\n        )\n        self.add_linear_relaxation(\n            mask=mask_direct_lower, type='upper', k=k_direct, x0=d_clamped, y0=y_d\n        )\n        self.add_linear_relaxation(\n            mask=mask_direct_upper, type='upper', k=k_direct, x0=upper, y0=y_u\n        )\n        self.add_linear_relaxation(\n            mask=mask_direct_upper, type='lower', k=k_direct, x0=d_clamped, y0=y_d\n        )\n        # Now we turn to the case where no direct line can be used.\n        d_lower, d_upper = self.generate_d_lower_upper(lower, upper)\n        mask_both = torch.logical_not(mask_direct_upper + mask_direct_lower)\n        # To make sure the upper and lower bounds have the same slope,\n        # we need the two tangents to be symmetrical.\n        d_same_slope = torch.max(torch.abs(d_lower), torch.abs(d_upper))\n        k = dfunc(d_same_slope)\n        y_d_same_slope = func(d_same_slope)\n        y_d_same_slope_opposite = 2*func(torch.tensor(0)) - y_d_same_slope\n        self.add_linear_relaxation(\n            mask=mask_both, type='upper', k=k, x0=d_same_slope, y0=y_d_same_slope\n        )\n        self.add_linear_relaxation(\n            mask=mask_both, type='lower', k=k, x0=-d_same_slope, y0=y_d_same_slope_opposite\n        )\n\n    def bound_relax_impl(self, x, func, 
dfunc):\n        lower, upper = x.lower, x.upper\n        y_l, y_u = func(lower), func(upper)\n        # k_direct is the slope of the line directly connecting the two endpoints of the function inside the interval:\n        # (lower, func(lower)) and (upper, func(upper)).\n        k_direct = k = (y_u - y_l) / (upper - lower).clamp(min=1e-8)\n\n        # Fixed bounds that cannot be optimized.\n        # self.mask_neg is the mask for neurons with upper bound <= 0, i.e., the whole input interval lies below 0.\n        # self.mask_pos is the mask for neurons with lower bound >= 0, i.e., the whole input interval lies above 0.\n        # For negative intervals, we can derive the linear upper bound by connecting the two endpoints,\n        # i.e., starting from (lower, func(lower)) and setting the slope to k_direct.\n        self.add_linear_relaxation(\n            mask=self.mask_neg, type='upper', k=k_direct, x0=lower, y0=y_l)\n        # For positive intervals, we connect the two endpoints to find the linear lower bound instead.\n        self.add_linear_relaxation(\n            mask=self.mask_pos, type='lower', k=k_direct, x0=lower, y0=y_l)\n\n        # Store the x-coordinates of the points of tangency.\n        # d_lower is the closest value to upper such that the tangent line at (d_lower, func(d_lower)) still lower-\n        # bounds the function in the interval (lower, upper).\n        # d_upper is the closest value to lower such that the tangent line at (d_upper, func(d_upper)) still upper-\n        # bounds the function in the interval (lower, upper).\n        # d_lower and d_upper can be regarded as the default points of tangency to draw linear bounds through.\n        d_lower, d_upper = self.generate_d_lower_upper(lower, upper)\n\n        # self.mask_both is the mask for neurons where lower < 0 < upper, i.e., the input interval contains 0.\n        # mask_direct_lower is the mask for neurons whose input interval contains zero and whose linear lower bound can\n    
    # be derived by connecting the two endpoints.\n        # mask_direct_upper is the mask for neurons whose input interval contains zero and whose linear upper bound can\n        # be derived by connecting the two endpoints.\n        if self.convex_concave is None:\n            mask_direct_lower = k_direct < dfunc(lower)\n            mask_direct_upper = k_direct < dfunc(upper)\n        else:\n            mask_direct_lower = torch.where(\n                self.convex_concave,\n                k_direct < dfunc(lower), k_direct > dfunc(upper))\n            mask_direct_upper = torch.where(\n                self.convex_concave,\n                k_direct < dfunc(upper), k_direct > dfunc(lower))\n        mask_direct_lower = torch.logical_and(mask_direct_lower, self.mask_both)\n        mask_direct_upper = torch.logical_and(mask_direct_upper, self.mask_both)\n\n        if self.opt_stage in ['opt', 'reuse']:\n            if not hasattr(self, 'alpha'):\n                # Raise an error if alpha is not created.\n                self._no_bound_parameters()\n            ns = self._start\n\n            # Clamping is done here rather than after the `opt.step()` call\n            # because it depends on pre-activation bounds.\n            self.alpha[ns].data[0:2] = torch.max(\n                torch.min(self.alpha[ns][0:2], upper), lower)\n            self.alpha[ns].data[2:4] = torch.max(\n                torch.min(self.alpha[ns][2:4], upper), lower)\n            if self.convex_concave is None:\n                self.alpha[ns].data[4:6] = torch.min(\n                    self.alpha[ns][4:6], d_lower)\n                self.alpha[ns].data[6:8] = torch.max(\n                    self.alpha[ns][6:8], d_upper)\n            else:\n                self.alpha[ns].data[4:6, :] = torch.where(\n                    self.convex_concave,\n                    torch.max(lower, torch.min(self.alpha[ns][4:6, :], d_lower)),\n                    torch.min(upper, torch.max(self.alpha[ns][4:6, :], d_lower))\n   
             )\n                self.alpha[ns].data[6:8, :] = torch.where(\n                    self.convex_concave,\n                    torch.min(upper, torch.max(self.alpha[ns][6:8, :], d_upper)),\n                    torch.max(lower, torch.min(self.alpha[ns][6:8, :], d_upper))\n                )\n\n            # shape [2, out_c, n, c, h, w].\n            tp_pos = self.alpha[ns][0:2]  # For upper bound relaxation\n            tp_neg = self.alpha[ns][2:4]  # For lower bound relaxation\n            tp_both_lower = self.alpha[ns][4:6]\n            tp_both_upper = self.alpha[ns][6:8]\n\n            # No need to use tangent line, when the tangent point is at the left\n            # side of the preactivation lower bound. Simply connect the two sides.\n            self.add_linear_relaxation(\n                mask=mask_direct_lower, type='lower', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_both, mask_direct_lower), type='lower',\n                k=dfunc(tp_both_lower), x0=tp_both_lower, y0=func(tp_both_lower))\n\n            self.add_linear_relaxation(\n                mask=mask_direct_upper, type='upper', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_both, mask_direct_upper), type='upper',\n                k=dfunc(tp_both_upper), x0=tp_both_upper, y0=func(tp_both_upper))\n\n            self.add_linear_relaxation(\n                mask=self.mask_neg, type='lower', k=dfunc(tp_neg),\n                x0=tp_neg, y0=func(tp_neg))\n            self.add_linear_relaxation(\n                mask=self.mask_pos, type='upper', k=dfunc(tp_pos),\n                x0=tp_pos, y0=func(tp_pos))\n        else:\n            if self.opt_stage == 'init':\n                # Initialize optimizable slope.\n                tp_both_lower_init = d_lower.detach()\n                tp_both_upper_init = d_upper.detach()\n\n                if 
self.loose_threshold is not None:\n                    # We will modify d_lower and d_upper inplace.\n                    # So make a copy for these two.\n                    tp_both_lower_init = tp_both_lower_init.clone()\n                    tp_both_upper_init = tp_both_upper_init.clone()\n                    # A different initialization if the pre-activation bounds\n                    # are too loose\n                    loose = torch.logical_or(lower < -self.loose_threshold,\n                                            upper > self.loose_threshold)\n                    d_lower[loose] = lower[loose]\n                    d_upper[loose] = upper[loose]\n\n                ns = self._start\n                self.tp_both_lower_init[ns] = tp_both_lower_init\n                self.tp_both_upper_init[ns] = tp_both_upper_init\n\n            # Not optimized (vanilla CROWN bound).\n            # Use the middle point slope as the lower/upper bound. Not optimized.\n            m = (lower + upper) / 2\n            y_m = func(m)\n            k_m = dfunc(m)\n            # Lower bound is the middle point slope for the case input upper bound <= 0.\n            # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)).\n            self.add_linear_relaxation(mask=self.mask_neg, type='lower', k=k_m, x0=m, y0=y_m)\n            # Upper bound is the middle point slope for the case input lower bound >= 0.\n            # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)).\n            self.add_linear_relaxation(mask=self.mask_pos, type='upper', k=k_m, x0=m, y0=y_m)\n            # Now handle the case where input lower bound <=0 and upper bound >= 0.\n            # A tangent line starting at d_lower is guaranteed to be a lower bound given the input upper bound.\n            k = dfunc(d_lower)\n            # Another possibility is to use the direct line as the lower bound, 
when this direct line does not intersect with f.\n            # This is only valid when the slope of the function at the input lower bound is greater than the slope of the direct line.\n            self.add_linear_relaxation(mask=mask_direct_lower, type='lower', k=k_direct, x0=lower, y0=y_l)\n            # Otherwise (i.e., when the input interval crosses zero and mask_direct_lower is not true),\n            # we do not use the direct line; we use the tangent line at d_lower instead.\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_both, mask_direct_lower),\n                type='lower', k=k, x0=d_lower, y0=func(d_lower))\n            # Do the same for the upper bound side when the input lower bound <= 0 and upper bound >= 0.\n            k = dfunc(d_upper)\n            self.add_linear_relaxation(\n                mask=mask_direct_upper, type='upper', k=k_direct, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_xor(self.mask_both, mask_direct_upper),\n                type='upper', k=k, x0=d_upper, y0=func(d_upper))\n\n            if self.disable_smoothness_enhancement:\n                return\n            # Partially modify the linear bound computation for intervals that contain 0 so that the linear bound\n            # changes smoothly w.r.t. the input bounds. For example, when we fix the input lower bound and drag the\n            # input upper bound, we do not expect the linear bound to change abruptly at any point.\n            # Therefore, under certain conditions, we do not use the above heuristics. 
Instead, we draw a tangent line\n            # through the middle point (m, func(m)) where m = (lower + upper) / 2 and use it as a linear bound.\n            if self.inverse_s_shape:\n                # When the function has an inverse s-shape (such as pow3), we switch to drawing a tangent line through\n                # the middle point as the lower bound when the default point of tangency is on the left of the middle\n                # point. Otherwise, the lower bound will be too loose on the side of the input upper bound. The change\n                # will make the bound on the other side a little bit looser as a tradeoff for overall tightness.\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(self.mask_both, d_lower < m),\n                    type='lower', k=k_m, x0=m, y0=y_m)\n                # We make a similar change to the linear upper bound when the default point of tangency is on\n                # the right of the middle point.\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(self.mask_both, d_upper >= m),\n                    type='upper', k=k_m, x0=m, y0=y_m)\n            elif self.sigmoid_like_mask is not None:\n                # self.sigmoid_like_mask is originally defined for periodic functions like sin and cos. It marks\n                # intervals on the s-shaped or flipped-s-shaped parts of the function. Whether the part is flipped-s-\n                # shaped is determined by comparing func(lower) and func(upper). Currently, some globally s-shaped\n                # functions, such as tanh and sigmoid, also have this mask. 
In the future, we will make it default for\n                # both completely and partially s-shaped functions to reduce branching in the code.\n                y_l = func(lower)\n                y_u = func(upper)\n                # If the input interval is on the s-shaped part of the function, we switch to drawing a tangent line\n                # through the middle point as the lower bound when the default point of tangency is on the right of the\n                # middle point.\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(torch.logical_and(self.sigmoid_like_mask, y_l < y_u), d_lower >= m),\n                    type='lower', k=k_m, x0=m, y0=y_m)\n                # We switch to drawing a tangent line through the middle point as the upper bound when the default point\n                # of tangency is on the left of the middle point.\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(torch.logical_and(self.sigmoid_like_mask, y_l < y_u), d_upper < m),\n                    type='upper', k=k_m, x0=m, y0=y_m)\n                # If the input interval is on the flipped-s-shaped part of the function, we flip the condition as well\n                # as whether we change the lower or upper bound.\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(torch.logical_and(self.sigmoid_like_mask, y_l >= y_u), d_lower < m),\n                    type='lower', k=k_m, x0=m, y0=y_m)\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(torch.logical_and(self.sigmoid_like_mask, y_l >= y_u), d_upper >= m),\n                    type='upper', k=k_m, x0=m, y0=y_m)\n            else:\n                # Handle simple cases where the function has the most common s shape. Now it serves as a safeguard\n                # against any child operator class whose self.sigmoid_like_mask is uninitialized. 
Here self.mask_both is\n                # equivalent to self.sigmoid_like_mask & (y_l < y_u) in the case above.\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(self.mask_both, d_lower >= m),\n                    type='lower', k=k_m, x0=m, y0=y_m)\n                self.add_linear_relaxation(\n                    mask=torch.logical_and(self.mask_both, d_upper < m),\n                    type='upper', k=k_m, x0=m, y0=y_m)\n\n    def bound_relax_branch(self, lb, ub):\n        # For functions that are only partially s-shaped, such as sin and cos, the non-s-shaped intervals are re-bounded\n        # here. This method returns the linear bound coefficients (lower_slope, lower_bias, upper_slope, upper_bias) of\n        # the non-s-shaped intervals. For globally s-shaped functions like tanh and sigmoid, the method returns 0s.\n        return 0., 0., 0., 0.\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n        lb = x.lower\n        ub = x.upper\n        self.branch_input_domain(lb, ub)\n        if self.activation_bound_option == 'same-slope':\n            self.bound_relax_impl_same_slope(x, self.act_func, self.d_act_func)\n        else:\n            self.bound_relax_impl(x, self.act_func, self.d_act_func)\n        lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_branch(lb, ub)\n        self.lw = self.lw * self.sigmoid_like_mask + self.branch_mask * lower_slope\n        self.lb = self.lb * self.sigmoid_like_mask + self.branch_mask * lower_bias\n        self.uw = self.uw * self.sigmoid_like_mask + self.branch_mask * upper_slope\n        self.ub = self.ub * self.sigmoid_like_mask + self.branch_mask * upper_bias\n\n    def get_split_mask(self, lower, upper, input_index):\n        assert input_index == 0\n        return torch.logical_and(\n            upper - lower >= self.split_min_gap,\n            torch.logical_or(upper >= 
self.split_range[0],\n                             lower <= self.split_range[1])\n        )\n\nclass BoundPow(BoundSShaped):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        self.exponent = 2\n        super().__init__(attr, inputs, output_index, options)\n        self.ibp_intermediate = False\n        self.has_constraint = True\n\n        def act_func(x):\n            return torch.pow(x, self.exponent)\n        self.act_func = act_func\n        def d_act_func(x):\n            return self.exponent * torch.pow(x, self.exponent - 1)\n        self.d_act_func = d_act_func\n        def d2_act_func(x):\n            return self.exponent * (self.exponent - 1) * torch.pow(x, self.exponent - 2)\n        self.d2_act_func = d2_act_func\n\n    def generate_d_lower_upper(self, lower, upper):\n        if self.exponent % 2:\n            # Indices of neurons with input upper bound >=0,\n            # whose optimal slope to lower bound the function was pre-computed.\n            # Note that for neurons with also input lower bound >=0, they will be masked later.\n            d_upper = self.retrieve_from_precompute(self.d_upper, upper, lower)\n\n            # Indices of neurons with lower bound <=0,\n            # whose optimal slope to upper bound the function was pre-computed.\n            d_lower = self.retrieve_from_precompute(self.d_lower, -lower, upper)\n            return d_lower, d_upper\n        else:\n            return torch.zeros_like(upper), torch.zeros_like(upper)\n\n    def branch_input_domain(self, lb, ub):\n        lower = lb\n        upper = ub\n        num_inflection = torch.zeros_like(lower)\n        inflection_mat = lower\n        for inflection in self.inflections:\n            num_inflection += torch.logical_and(\n                lower <= inflection, upper >= inflection)\n            inflection_mat = torch.where(\n                torch.logical_and(lower <= inflection, upper >= inflection),\n                
torch.tensor(inflection, device=lb.device), inflection_mat)\n        inflection_mask = num_inflection <= 1.\n\n        extreme_mask = torch.ones_like(lower)\n        for extreme in self.extremes:\n            extreme_mask *= torch.logical_or(lower >= extreme, upper <= extreme)\n\n        self.sigmoid_like_mask = torch.logical_and(inflection_mask, extreme_mask)\n        self.branch_mask = torch.logical_xor(torch.ones_like(lower), self.sigmoid_like_mask)\n        self.inflection_mat = torch.where(self.sigmoid_like_mask, inflection_mat, lower)\n\n        self.mask_neg = torch.logical_and((self.d2_act_func(lower) >= 0),\n            torch.logical_and((self.d2_act_func(upper) >= 0),\n            self.sigmoid_like_mask))\n        self.mask_pos = torch.logical_and((self.d2_act_func(lower) < 0),\n            torch.logical_and((self.d2_act_func(upper) < 0),\n            self.sigmoid_like_mask))\n        self.mask_both = torch.logical_xor(self.sigmoid_like_mask,\n            torch.logical_or(self.mask_neg, self.mask_pos))\n        self.convex_concave = self.d2_act_func(lower) >= 0\n\n    @torch.no_grad()\n    def precompute_relaxation(self, func, dfunc, x_limit=500):\n        \"\"\"\n        This function precomputes the tangent lines that will be used as\n        lower/upper bounds for S-shaped functions.\n        \"\"\"\n        self.x_limit = x_limit\n        self.num_points_pre = int(self.x_limit / self.step_pre)\n\n        max_iter = 100\n\n        def check_lower(upper, d):\n            \"\"\"Given two points upper, d (d <= upper), check if the slope at d\n            will be less than f(upper) at upper.\"\"\"\n            k = dfunc(d)\n            # Return True if the slope is a lower bound.\n            return k * (upper - d) + func(d) <= func(upper)\n\n        def check_upper(lower, d):\n            \"\"\"Given two points lower, d (d >= lower), check if the slope at d\n            will be greater than f(lower) at lower.\"\"\"\n            k = dfunc(d)\n            
# Return True if the slope is an upper bound.\n            return k * (lower - d) + func(d) >= func(lower)\n\n        # Given an upper bound point (>=0), find a line that is guaranteed to\n        # be a lower bound of this function.\n        upper = self.step_pre * torch.arange(\n            0, self.num_points_pre + 5, device=self.device)\n        r = torch.zeros_like(upper)\n        # Initial guess, the tangent line is at -1.\n        l = -torch.ones_like(upper)\n        while True:\n            # Check if the tangent line at the guessed point is a lower bound at f(upper).\n            checked = check_upper(upper, l).int()\n            # If the initial guess is not small enough, then double it (-2, -4, etc.).\n            l = checked * l + (1 - checked) * (l * 2)\n            if checked.sum() == l.numel():\n                break\n        # Now we have a starting point at l; its tangent line is guaranteed to\n        # be a lower bound at f(upper).\n        # We want to further tighten this bound by moving it closer to 0.\n        for _ in range(max_iter):\n            # Binary search.\n            m = (l + r) / 2\n            checked = check_upper(upper, m).int()\n            l = checked * m + (1 - checked) * l\n            r = checked * r + (1 - checked) * m\n        # At upper, a line with slope l is guaranteed to lower bound the function.\n        self.d_upper = l.clone()\n\n        # Do the same again:\n        # Given a lower bound point (<=0), find a line that is guaranteed to\n        # be an upper bound of this function.\n        lower = -self.step_pre * torch.arange(\n            0, self.num_points_pre + 5, device=self.device)\n        l = torch.zeros_like(upper)\n        r = torch.ones_like(upper)\n        while True:\n            checked = check_lower(lower, r).int()\n            r = checked * r + (1 - checked) * (r * 2)\n            if checked.sum() == l.numel():\n                break\n        for _ in range(max_iter):\n            m = (l + r) / 2\n 
           checked = check_lower(lower, m).int()\n            l = (1 - checked) * m + checked * l\n            r = (1 - checked) * r + checked * m\n        self.d_lower = r.clone()\n\n    def forward(self, x, y):\n        return torch.pow(x, y)\n\n    def bound_backward(self, last_lA, last_uA, x, y, start_node=None,\n                       start_shape=None, **kwargs):\n        assert not self.is_input_perturbed(1)\n        self._start = start_node.name if start_node is not None else None\n        y = y.value\n        if y == int(y):\n            x.upper = torch.max(x.upper, x.lower + 1e-8)\n            self.exponent = int(y)\n            assert self.exponent >= 2\n            if self.exponent % 2:\n                self.precompute_relaxation(self.act_func, self.d_act_func)\n\n            As, lbias, ubias = super().bound_backward(\n                last_lA, last_uA, x, start_node, start_shape, **kwargs)\n            return [As[0], (None, None)], lbias, ubias\n        else:\n            raise NotImplementedError('Exponent is not supported yet')\n\n    def bound_forward(self, dim_in, x, y):\n        assert y.lower == y.upper == int(y.lower)\n        y = y.lower\n        x.upper = torch.max(x.upper, x.lower + 1e-8)\n        self.exponent = int(y)\n\n        assert self.exponent >= 2\n        if self.exponent % 2:\n            self.precompute_relaxation(self.act_func, self.d_act_func)\n        return super().bound_forward(dim_in, x)\n\n    def bound_relax_branch(self, lb, ub):\n        if self.opt_stage in ['opt', 'reuse']:\n            if not hasattr(self, 'alpha'):\n                # Raise an error if alpha is not created.\n                self._no_bound_parameters()\n            ns = self._start\n\n            self.alpha[ns].data[8:10] = torch.max(\n                torch.min(self.alpha[ns][8:10], ub), lb)\n            lb_point = self.alpha[ns][8:10]\n            lower_slope = self.d_act_func(lb_point)\n            lower_bias = self.act_func(lb_point) - lower_slope * 
lb_point\n        else:\n            lower_slope = 0\n            lower_bias = 0\n\n        upper_slope = (self.act_func(ub) - self.act_func(lb)) / (ub - lb).clamp(min=1e-8)\n        upper_bias = self.act_func(ub) - ub * upper_slope\n        return lower_slope, lower_bias, upper_slope, upper_bias\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        # For powers with odd exponents, such as x^3, the overall shape is inverse S-like.\n        self.inverse_s_shape = self.exponent % 2 == 1\n        if self.exponent % 2:\n            self.inflections = [0.]\n        else:\n            self.extremes = [0.]\n        super().bound_relax(x, init, dim_opt)\n\n    def interval_propagate(self, *v):\n        assert not self.is_input_perturbed(1)\n        exp = v[1][0]\n        assert exp == int(exp)\n        exp = int(exp)\n        pl, pu = torch.pow(v[0][0], exp), torch.pow(v[0][1], exp)\n        if exp % 2 == 1:\n            return pl, pu\n        else:\n            pl, pu = torch.min(pl, pu), torch.max(pl, pu)\n            mask = 1 - ((v[0][0] < 0) * (v[0][1] > 0)).to(pl.dtype)\n            return pl * mask, pu\n\n    def clamp_interim_bounds(self):\n        if self.exponent % 2 == 0:\n            self.cstr_lower = self.lower.clamp(min=0)\n            self.cstr_upper = self.upper.clamp(min=0)\n            self.cstr_interval = (self.cstr_lower, self.cstr_upper)\n\n\ndef dtanh(x):\n    return 1 - torch.tanh(x).pow(2)\n\ndef dsigmoid(x):\n    return torch.sigmoid(x) * (1 - torch.sigmoid(x))\n\ndef darctan(x):\n    return (x.square() + 1.).reciprocal()\n\ndef d2tanh(x):\n    return -2 * torch.tanh(x) * (1 - torch.tanh(x).pow(2))\n\ndef d2sigmoid(x):\n    return dsigmoid(x) * (1 - 2 * torch.sigmoid(x))\n\n\nclass BoundTanh(BoundSShaped):\n    \"\"\"\n    BoundTanh is based on the S-shaped BoundSShaped. 
It also serves as the\n    base class for other globally S-shaped functions such as Sigmoid and Atan.\n    \"\"\"\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None,\n                 activation=('tanh', torch.tanh, dtanh), precompute=True):\n        super().__init__(attr, inputs, output_index, options, activation, precompute)\n\n\n    def _init_opt_parameters_impl(self, size_spec, name_start):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        return super()._init_opt_parameters_impl(size_spec, name_start, num_params=8)\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = TanhGrad()\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        grad_extra_nodes = [self.inputs[0]]\n        return [(node_grad, grad_input, grad_extra_nodes)]\n\n\nclass TanhGradOp(Function):\n    @staticmethod\n    def symbolic(_, preact):\n        return _.op('grad::Tanh', preact).setType(preact.type())\n    \n    @staticmethod\n    def forward(ctx, preact):\n        return 1 - torch.tanh(preact)**2\n\n\nclass TanhGrad(Module):\n    def forward(self, g, preact):\n        return g * TanhGradOp.apply(preact).unsqueeze(1)\n\n\nclass BoundTanhGrad(BoundOptimizableActivation):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None,\n                 activation=('tanh', dtanh, d2tanh), precompute=True):\n        super().__init__(attr, inputs, output_index, options)\n        self.requires_input_bounds = [0]\n        # The inflection point is where d2f/dx2 = 0.\n        self.inflection_point = 0.6585026\n        self.func = activation[1]\n        self.dfunc = activation[2]\n        if precompute:\n            self.precompute_relaxation()\n\n    def forward(self, x):\n        return self.func(x)\n\n    def interval_propagate(self, *v):\n        lower, upper = v[0]\n        f_lower = self.func(lower)\n        f_upper = self.func(upper)\n        next_lower = 
torch.min(f_lower, f_upper)\n        next_upper = torch.max(f_lower, f_upper)\n        mask_both = torch.logical_and(lower < 0, upper > 0)\n        next_upper[mask_both] = self.func(torch.tensor(0))\n        return next_lower, next_upper\n    \n    def bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n        return self.bound_relax_impl(x)\n    \n    def precompute_relaxation(self, x_limit=500):\n        \"\"\"\n        This function precomputes the tangent lines that will be used as\n        the lower/upper bounds for bell-shaped functions.\n        Three tensors are precomputed:\n        - self.precompute_x: The x values of the upper preactivation bound.\n        - self.d_lower: The tangent points of the lower bound.\n        - self.d_upper: The tangent points of the upper bound.\n        \"\"\"\n\n        self.x_limit = x_limit\n        self.step_pre = 0.01\n        self.num_points_pre = int(self.x_limit / self.step_pre)\n\n        max_iter = 100\n        func, dfunc = self.func, self.dfunc\n\n        logger.debug('Precomputing relaxation for %s (pre-activation limit: %f)',\n                     self.__class__.__name__, x_limit)\n\n        def check_lower(upper, d):\n            \"\"\"Given two points upper, d (d <= upper),\n            check if the tangent line at d will be below f(upper) at upper.\"\"\"\n            k = dfunc(d)\n            # Return True if the slope is a lower bound.\n            return k * (upper - d) + func(d) <= func(upper)\n\n        def check_upper(lower, d):\n            \"\"\"Given two points lower, d (d <= lower),\n            check if the tangent line at d will be above f(lower) at lower.\"\"\"\n            k = dfunc(d)\n            # Return True if the slope is an upper bound.\n            return k * (lower - d) + func(d) >= func(lower)\n\n        self.precompute_x = torch.arange(-self.x_limit, self.x_limit + self.step_pre, self.step_pre, device=self.device)\n       
 self.d_lower = torch.zeros_like(self.precompute_x)\n        self.d_upper = torch.zeros_like(self.precompute_x)\n\n        # upper point that needs lower precomputed tangent line\n        mask_need_d_lower = self.precompute_x >= -self.inflection_point\n        upper = self.precompute_x[mask_need_d_lower]\n        # 1. Initial guess, the tangent is at -2*inflection_point (should be between (-inf, -inflection_point))\n        r = -self.inflection_point * torch.ones_like(upper)\n        l = -2 * self.inflection_point * torch.ones_like(upper)\n        while True:\n            # Check if the tangent line at the guessed point is a lower bound at f(upper).\n            checked = check_lower(upper, l).int()\n            # If the initial guess is not small enough, then double it (-2, -4, etc.).\n            l = checked * l + (1 - checked) * (l * 2)\n            if checked.sum() == l.numel():\n                break\n        # Now we have a starting point at l; its tangent line is guaranteed to be a lower bound at f(upper).\n        # We want to further tighten this bound by moving it closer to upper.\n        for _ in range(max_iter):\n            # Binary search.\n            m = (l + r) / 2\n            checked = check_lower(upper, m).int()\n            l = checked * m + (1 - checked) * l\n            r = checked * r + (1 - checked) * m\n        # At upper, a line with slope l is guaranteed to lower bound the function.\n        self.d_lower[mask_need_d_lower] = l.clone()\n\n        # upper point that needs upper precomputed tangent line\n        mask_need_upper_d = self.precompute_x >= self.inflection_point\n        upper = self.precompute_x[mask_need_upper_d]\n        # 1. 
Initial guess, the tangent is at inflection_point/2 (should be between (0, inflection_point))\n        r = self.inflection_point * torch.ones_like(upper)\n        l = self.inflection_point / 2 * torch.ones_like(upper)\n        while True:\n            # Check if the tangent line at the guessed point is an upper bound at f(upper).\n            checked = check_upper(upper, l).int()\n            # If the initial guess is not small enough, then halve it.\n            l = checked * l + (1 - checked) * (l / 2)\n            if checked.sum() == l.numel():\n                break\n        # Now we have a starting point at l; its tangent line is guaranteed to be an upper bound at f(upper).\n        # We want to further tighten this bound by moving it closer to upper.\n        for _ in range(max_iter):\n            # Binary search.\n            m = (l + r) / 2\n            checked = check_upper(upper, m).int()\n            l = checked * m + (1 - checked) * l\n            r = checked * r + (1 - checked) * m\n        # At upper, a line with slope l is guaranteed to upper bound the function.\n        self.d_upper[mask_need_upper_d] = l.clone()\n\n    def retrieve_from_precompute(self, x, flip=False):\n        if not flip:\n            if x.max() > self.x_limit:\n                warnings.warn(f'Pre-activation bounds are too loose for {self}')\n            # Take the left endpoint of the interval\n            x_indices = torch.searchsorted(self.precompute_x, x, right=True) - 1\n            return self.d_lower[x_indices], self.d_upper[x_indices]\n        else:\n            if x.min() < -self.x_limit:\n                warnings.warn(f'Pre-activation bounds are too loose for {self}')\n            # Take the right endpoint of the interval\n            x_indices = torch.searchsorted(self.precompute_x, -x, right=False)\n            return -self.d_lower[x_indices], -self.d_upper[x_indices]\n\n    def bound_relax_impl(self, x):\n        lower, upper = x.lower, x.upper\n       
 func, dfunc = self.func, self.dfunc\n        y_l, y_u = func(lower), func(upper)\n        # k_direct is the slope of the line directly connecting (lower, func(lower)) and (upper, func(upper)).\n        k_direct = (y_u - y_l) / (upper - lower).clamp(min=1e-8)\n\n        # The tangent line at the midpoint can be a good approximation\n        midpoint = (lower + upper) / 2\n        k_midpoint = dfunc(midpoint)\n        y_midpoint = func(midpoint)\n\n        # If -inflection_point <= lower < upper <= inflection_point,\n        # we call this the \"completely concave\" region.\n        mask_completely_concave = torch.logical_and(\n            lower >= -self.inflection_point,\n            upper <= self.inflection_point\n        )\n        self.add_linear_relaxation(\n            mask=mask_completely_concave, type='lower', k=k_direct, x0=lower, y0=y_l)\n        self.add_linear_relaxation(\n            mask=mask_completely_concave, type='upper', k=k_midpoint, x0=midpoint, y0=y_midpoint)\n\n        # From now on, we assume at least one of the bounds is outside the completely concave region.\n        # Without loss of generality, we assume upper > inflection_point (indicated by mask_right).\n        mask_right = lower + upper >= 0\n\n        dl, du = self.retrieve_from_precompute(upper, flip=False)\n        dl_, du_ = self.retrieve_from_precompute(lower, flip=True)\n\n        # Case 1: Similar to a convex function\n        mask_case1 = torch.logical_or(\n            torch.logical_and(mask_right, lower >= self.inflection_point),\n            torch.logical_and(torch.logical_not(mask_right), upper <= -self.inflection_point)\n        )\n        self.add_linear_relaxation(\n            mask=mask_case1, type='upper', k=k_direct, x0=lower, y0=y_l)\n        self.add_linear_relaxation(\n            mask=mask_case1, type='lower', k=k_midpoint, x0=midpoint, y0=y_midpoint)\n\n        # Case 2: Similar to an S-shaped function\n        mask_case2_right = torch.logical_and(mask_right, 
torch.logical_and(\n            upper > self.inflection_point, lower < self.inflection_point))\n        # The upper tangent point is linearly interpolated between 0 and du,\n        # given lower ranging between -upper and du.\n        d_mask_case2_right_upper = du * (lower + upper) / (du + upper)\n        k_mask_case2_right_upper = dfunc(d_mask_case2_right_upper)\n        y_mask_case2_right_upper = func(d_mask_case2_right_upper)\n        self.add_linear_relaxation(\n            mask=mask_case2_right, type='upper',\n            k=k_mask_case2_right_upper, x0=d_mask_case2_right_upper, y0=y_mask_case2_right_upper)\n        # The lower tangent point is found based on lower.\n        d_mask_case2_right_lower = (dl_ + upper) / 2\n        k_mask_case2_right_lower = dfunc(d_mask_case2_right_lower)\n        y_mask_case2_right_lower = func(d_mask_case2_right_lower)\n        self.add_linear_relaxation(\n            mask=torch.logical_and(mask_case2_right, dl_ < upper), type='lower',\n            k=k_mask_case2_right_lower, x0=d_mask_case2_right_lower, y0=y_mask_case2_right_lower)\n        self.add_linear_relaxation(\n            mask=torch.logical_and(mask_case2_right, dl_ >= upper), type='lower',\n            k=k_direct, x0=lower, y0=y_l)\n\n        mask_case2_left = torch.logical_and(torch.logical_not(mask_right), torch.logical_and(\n            lower < -self.inflection_point, upper > -self.inflection_point))\n        # The upper tangent point is linearly interpolated between du_ and 0,\n        # given upper ranging between du_ and -lower.\n        d_mask_case2_left_upper = du_ * (upper + lower) / (du_ + lower)\n        k_mask_case2_left_upper = dfunc(d_mask_case2_left_upper)\n        y_mask_case2_left_upper = func(d_mask_case2_left_upper)\n        self.add_linear_relaxation(\n            mask=mask_case2_left, type='upper',\n            k=k_mask_case2_left_upper, x0=d_mask_case2_left_upper, y0=y_mask_case2_left_upper)\n        # The lower tangent point is found based on 
upper.\n        d_mask_case2_left_lower = (dl + lower) / 2\n        k_mask_case2_left_lower = dfunc(d_mask_case2_left_lower)\n        y_mask_case2_left_lower = func(d_mask_case2_left_lower)\n        self.add_linear_relaxation(\n            mask=torch.logical_and(mask_case2_left, dl > lower), type='lower',\n            k=k_mask_case2_left_lower, x0=d_mask_case2_left_lower, y0=y_mask_case2_left_lower)\n        self.add_linear_relaxation(\n            mask=torch.logical_and(mask_case2_left, dl <= lower), type='lower',\n            k=k_direct, x0=upper, y0=y_u)\n        \n        # If the lower and upper bounds are too close, we just use IBP bounds to avoid numerical issues.\n        mask_very_close = upper - lower < 1e-6\n        if mask_very_close.any():\n            self.add_linear_relaxation(\n                mask=torch.logical_and(mask_very_close, self.mask_neg), type='lower', k=0, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_and(mask_very_close, self.mask_neg), type='upper', k=0, x0=upper, y0=y_u)\n            self.add_linear_relaxation(\n                mask=torch.logical_and(mask_very_close, self.mask_pos), type='lower', k=0, x0=upper, y0=y_u)\n            self.add_linear_relaxation(\n                mask=torch.logical_and(mask_very_close, self.mask_pos), type='upper', k=0, x0=lower, y0=y_l)\n            self.add_linear_relaxation(\n                mask=torch.logical_and(mask_very_close, self.mask_both), type='lower', k=0, x0=lower, y0=torch.min(y_l, y_u))\n            self.add_linear_relaxation(\n                mask=torch.logical_and(mask_very_close, self.mask_both), type='upper', k=0, x0=upper, y0=torch.full_like(y_l, func(torch.tensor(0))))\n\n\nclass BoundSigmoid(BoundTanh):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options,\n                         activation=('sigmoid', torch.sigmoid, dsigmoid))\n    \n    def 
build_gradient_node(self, grad_upstream):\n        node_grad = SigmoidGrad()\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        grad_extra_nodes = [self.inputs[0]]\n        return [(node_grad, grad_input, grad_extra_nodes)]\n\n\nclass SigmoidGradOp(Function):\n    @staticmethod\n    def symbolic(_, preact):\n        return _.op('grad::Sigmoid', preact).setType(preact.type())\n    \n    @staticmethod\n    def forward(ctx, preact):\n        sigmoid_x = torch.sigmoid(preact)\n        return sigmoid_x * (1 - sigmoid_x)\n\n\nclass SigmoidGrad(Module):\n    def forward(self, g, preact):\n        return g * SigmoidGradOp.apply(preact).unsqueeze(1)\n\n\nclass BoundSigmoidGrad(BoundTanhGrad):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None,\n                 activation=('sigmoid', dsigmoid, d2sigmoid), precompute=True):\n        super().__init__(attr, inputs, output_index, options, activation, precompute=False)\n        self.inflection_point = 1.3169614\n        if precompute:\n            self.precompute_relaxation()\n\n\nclass BoundAtan(BoundTanh):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options,\n                         activation=('arctan', torch.arctan, darctan))\n        self.split_range = (-torch.inf, torch.inf)\n\n    def build_gradient_node(self, grad_upstream):\n        node_grad = AtanGrad()\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        grad_extra_nodes = [self.inputs[0]]\n        return [(node_grad, grad_input, grad_extra_nodes)]\n\n\nclass AtanGrad(Module):\n    def forward(self, g, preact):\n        # arctan'(x) = 1 / (1 + x^2)\n        return g / (1 + preact.square()).unsqueeze(1)\n\n\nclass BoundTan(BoundAtan):\n    \"\"\"\n    The implementation of BoundTan is based on the S-shaped BoundAtan. 
We use the bounds from its\n    inverse function and directly convert the bounds of the inverse function to bounds of the original\n    function. This trick allows us to quickly implement bounds on inverse functions.\n    \"\"\"\n\n    def forward(self, x):\n        return torch.tan(x)\n\n    def _check_bounds(self, lower, upper):\n        # Lower and upper bounds must be within the same [-½π, ½π] region.\n        lower_periods = torch.floor((lower + 0.5 * torch.pi) / torch.pi)\n        upper_periods = torch.floor((upper + 0.5 * torch.pi) / torch.pi)\n        if not torch.allclose(lower_periods, upper_periods):\n            print('Tan preactivation lower bounds:\\n', lower)\n            print('Tan preactivation upper bounds:\\n', upper)\n            raise ValueError(\"BoundTan received pre-activation bounds that produce infinity. \"\n                    \"The preactivation bounds are too loose. Try to reduce perturbation region.\")\n        # Return the period number for each neuron.\n        # Period is 0 => bounds are within [-½π, ½π],\n        # Period is 1 => bounds are within [-½π + π, ½π + π]\n        # Period is -1 => bounds are within [-½π - π, ½π - π]\n        return lower_periods\n\n    def _init_masks(self, x):\n        # The masks now must consider the periodicity.\n        lower = torch.remainder(x.lower + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi\n        upper = torch.remainder(x.upper + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi\n        self.mask_pos = lower >= 0\n        self.mask_neg = upper <= 0\n        self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg))\n\n    def interval_propagate(self, *v):\n        # We need to check if the input lower and upper bounds are within the same period.\n        # Otherwise the bounds become infinity.\n        concrete_lower, concrete_upper = v[0][0], v[0][1]\n        self._check_bounds(concrete_lower, concrete_upper)\n        return super().interval_propagate(*v)\n\n    def 
bound_relax(self, x, init=False, dim_opt=None):\n        if init:\n            self.init_linear_relaxation(x, dim_opt)\n        periods = self._check_bounds(x.lower, x.upper)\n        periods = torch.pi * periods\n        # Create a fake x with inverted lower and upper.\n        inverse_x = lambda: None\n        inverse_x.lower = torch.tan(x.lower)\n        inverse_x.upper = torch.tan(x.upper)\n        super().bound_relax(inverse_x, init=init, dim_opt=dim_opt)\n        # Lower slope, lower bias, upper slope and upper bias are saved to\n        # self.lw, self.lb, self.uw, self.ub. We need to invert them.\n        # E.g., y = self.lw * x + self.lb now becomes x = 1./self.lw * y - self.lb / self.lw.\n        # Additionally, we need to add back the missing \u03c0 periods.\n        new_upper_slope = 1. / self.lw\n        new_upper_bias = - self.lb / self.lw - periods / self.lw\n        new_lower_slope = 1. / self.uw\n        new_lower_bias = - self.ub / self.uw - periods / self.uw\n\n        # NaN can happen if lw=0 or uw=0 when the pre-activation bounds are too close.\n        # Replace the bounds with interval bounds.\n        if (self.lw == 0).any():\n            mask = self.lw == 0\n            new_upper_slope[mask] = 0\n            new_upper_bias[mask] = inverse_x.upper[mask]\n        if (self.uw == 0).any():\n            mask = self.uw == 0\n            new_lower_slope[mask] = 0\n            new_lower_bias[mask] = inverse_x.lower[mask]\n\n        self.lw = new_lower_slope\n        self.lb = new_lower_bias\n        self.uw = new_upper_slope\n        self.ub = new_upper_bias\n"
  },
  {
    "path": "auto_LiRPA/operators/shape.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Shape operators \"\"\"\nfrom .base import *\n\n\nclass BoundShape(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.never_perturbed = True\n\n    @staticmethod\n    def shape(x):\n        return x.shape if isinstance(x, Tensor) else torch.tensor(x).shape\n\n    def forward(self, x):\n        self.from_input = False\n        return BoundShape.shape(x)\n\n    def bound_forward(self, dim_in, x):\n        return self.forward_value\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        if not isinstance(v[0], Tensor):\n            # e.g., v[0] input shape (8, 7, 7) => output its shape (1, 8, 7, 7)\n            gvars_array = np.array(v[0])\n            self.solver_vars = torch.tensor(np.expand_dims(gvars_array, axis=0).shape).long()\n        else:\n            self.solver_vars = torch.tensor(self.forward(v[0])).long()\n"
  },
  {
    "path": "auto_LiRPA/operators/slice_concat.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Shape operators \"\"\"\nfrom torch.nn import Module\nfrom torch.autograd import Function\nfrom .base import *\nfrom ..patches import Patches\nfrom .constant import BoundConstant\n\n\nclass BoundConcat(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.axis = attr['axis']\n        self.IBP_rets = None\n        self.ibp_intermediate = True\n\n    def forward(self, *x):  # x is a list of tensors\n        x = [(item if isinstance(item, Tensor) else torch.tensor(item)) for item in x]\n        self.input_size = [item.shape[self.axis] for item in x]\n        self.axis = self.make_axis_non_negative(self.axis)\n        return torch.cat(x, dim=int(self.axis))\n\n    def interval_propagate(self, *v):\n        norms = []\n        eps = []\n        # Collect perturbation information for all inputs.\n        for i, _v in enumerate(v):\n            if self.is_input_perturbed(i):\n                n, e = Interval.get_perturbation(_v)\n                norms.append(n)\n                eps.append(e)\n            else:\n                norms.append(None)\n                eps.append(0.0)\n        eps = np.array(eps)\n        # Supporting two cases: all inputs are Linf norm, or all inputs are L2 norm perturbed.\n        # Some inputs can be constants without perturbations.\n        all_inf = all(map(lambda x: x is None or x == torch.inf, norms))\n        all_2 = all(map(lambda x: x is None or x == 2, norms))\n\n        h_L = [_v[0] for _v in v]\n        h_U = [_v[1] for _v in v]\n        if all_inf:\n            # Simply returns a tuple. 
Every subtensor has its own lower and upper bounds.\n            return self.forward(*h_L), self.forward(*h_U)\n        elif all_2:\n            # Sum the L2 norm over all subtensors, and use that value as the new L2 norm.\n            # This will be an over-approximation of the original perturbation (we can prove it).\n            max_eps = np.sqrt(np.sum(eps * eps))\n            # For L2-norm perturbed inputs and for constants, lb=ub, so just propagate one object.\n            r = self.forward(*h_L)\n            ptb = PerturbationLpNorm(norm=2, eps=max_eps)\n            return Interval(r, r, ptb=ptb)\n        else:\n            raise RuntimeError(f\"BoundConcat does not support inputs with norm {norms}\")\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        self.axis = self.make_axis_non_negative(self.axis, 'output')\n        assert self.axis > 0\n\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            if isinstance(last_A, torch.Tensor):\n                ret = list(torch.split(last_A, self.input_size, dim=self.axis + 1))\n                # Skip unused input nodes to reduce the cost of computing unused intermediate bounds\n                for i in range(len(ret)):\n                    if (ret[i] == 0).all():\n                        ret[i] = None\n                return ret\n            elif isinstance(last_A, Patches):\n                assert len(self.input_shape) == 4 and self.axis == 1, \"Splitting the channel dimension is supported; others are unimplemented.\"\n                # Patches shape can be [out_c, batch, out_h, out_w, in_c, patch_h, patch_w]\n                # Or [spec, batch, in_c, patch_h, patch_w]  (sparse)\n                new_patches = torch.split(last_A.patches, self.input_size, dim=-3)  # splitting the in_c dimension is easy.\n                return [last_A.create_similar(p) for p in new_patches]\n            else:\n                raise RuntimeError(f'Unsupported type for 
last_A: {type(last_A)}')\n\n        uA = _bound_oneside(last_uA)\n        lA = _bound_oneside(last_lA)\n\n        if uA is None:\n            return [(lA[i] if lA is not None else None, None)\n                    for i in range(len(lA))], 0, 0\n        if lA is None:\n            return [(None, uA[i] if uA is not None else None)\n                    for i in range(len(uA))], 0, 0\n\n        # To avoid issues in other parts of the code, we prune unused\n        # lA and uA only when they are both unused.\n        for i in range(len(lA)):\n            if lA[i] is None and uA[i] is not None:\n                lA[i] = torch.zeros_like(uA[i])\n            elif lA[i] is not None and uA[i] is None:\n                uA[i] = torch.zeros_like(lA[i])\n\n        return [(lA[i], uA[i]) for i in range(len(lA))], 0, 0\n\n    def bound_forward(self, dim_in, *x):\n        self.axis = self.make_axis_non_negative(self.axis)\n        assert (self.axis == 0 and not self.from_input or self.from_input)\n        # Concatenate each input's bounds along the axis.\n        # If x[i].lw and x[i].uw is None, it means the input is a constant,\n        # so we concatenate a tensor of zeros with the corresponding shape.\n        lw = torch.cat([item.lw if item.lw is not None else\n                        torch.zeros(item.lb.shape[0], dim_in, *item.lb.shape[1:], device=item.lb.device)\n                        for item in x], dim=self.axis + 1)\n        lb = torch.cat([item.lb for item in x], dim=self.axis)\n        uw = torch.cat([item.uw if item.uw is not None else\n                        torch.zeros(item.ub.shape[0], dim_in, *item.ub.shape[1:], device=item.ub.device)\n                        for item in x], dim=self.axis + 1)\n        ub = torch.cat([item.ub for item in x], dim=self.axis)\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(*v)\n\n    def 
build_gradient_node(self, grad_upstream):\n        ret = []\n        for i in range(len(self.inputs)):\n            node_grad = ConcatGrad(self.axis, i)\n            grad_input = (grad_upstream, ) + tuple(inp.forward_value for inp in self.inputs)\n            ret.append((node_grad, grad_input, []))\n        return ret\n\n\nBoundConcatFromSequence = BoundConcat\n\n\nclass BoundSlice(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.start = attr[\"starts\"][0] if \"starts\" in attr else None\n        self.end = attr[\"ends\"][0] if \"ends\" in attr else None\n        self.axes = attr[\"axes\"][0] if \"axes\" in attr else None\n        self.use_default_ibp = False\n        self.ibp_intermediate = True\n\n    def __repr__(self):\n        attrs = {}\n        if (len(self.inputs) == 5\n            and all(isinstance(item, BoundConstant) and item.value.numel() == 1\n                    for item in self.inputs[1:])):\n            attrs['start'] = self.inputs[1].value.item()\n            attrs['end'] = self.inputs[2].value.item()\n            attrs['axes'] = self.inputs[3].value.item()\n            attrs['step'] = self.inputs[4].value.item()\n        return super().__repr__(attrs)\n\n    def _fixup_params(self, shape, start, end, axes, steps):\n        if start < 0:\n            start += shape[axes]\n        if end < 0:\n            if end == -9223372036854775807:  # -inf in ONNX\n                end = 0  # only possible when step == -1\n            else:\n                end += shape[axes]\n        if steps == -1:\n            start, end = end, start + 1  # TODO: add more tests for negative step sizes.\n        end = min(end, shape[axes])\n        return start, end\n\n    # Older PyTorch versions only pass steps as input.\n    def forward(self, x, start=None, end=None, axes=None, steps=1):\n        start = self.start if start is None else start\n        end = 
self.end if end is None else end\n        axes = self.axes if axes is None else axes\n        assert (steps == 1 or steps == -1) and axes == int(axes) and start == int(start) and end == int(end)\n        shape = x.shape if isinstance(x, Tensor) else [len(x)]\n        start, end = self._fixup_params(shape, start, end, axes, steps)\n        final = torch.narrow(x, dim=int(axes), start=int(start), length=int(end - start))\n        if steps == -1:\n            final = torch.flip(final, dims=(int(axes),))\n        return final\n\n    def interval_propagate(self, *v):\n        lb = tuple(map(lambda x: x[0], v))\n        ub = tuple(map(lambda x: x[1], v))\n        return Interval.make_interval(self.forward(*lb), self.forward(*ub))\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(*v)\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        def _bound_oneside(A, start, end, axes, steps):\n            if A is None:\n                return None\n            if isinstance(A, torch.Tensor):\n                # Reuse the batch and spec dimensions of A, and replace the other shapes with the input shape.\n                A_shape = A.shape[:2] + self.input_shape[1:]\n                new_A = torch.zeros(size=A_shape, device=A.device,\n                                    requires_grad=A.requires_grad)\n                # Fill part of new_A based on start, end, axes and steps,\n                # skipping the spec dimension at the front (axes + 1).\n                dim = axes if axes < 0 else axes + 1\n                indices = torch.arange(start, end, device=A.device)\n                new_A = torch.index_copy(new_A, dim=dim, index=indices, source=A)\n            elif isinstance(A, Patches):\n                assert A.unstable_idx is None\n                assert len(self.input_shape) == 4 and axes == 1, \"Slice is only supported on the channel dimension.\"\n                patches = A.patches\n                # 
patches shape is [out_c, batch, out_h, out_w, in_c, patch_h, patch_w].\n                new_patches_shape = patches.shape[:4] + (self.input_shape[1], ) + patches.shape[-2:]\n                new_patches = torch.zeros(\n                    size=new_patches_shape, device=patches.device,\n                    requires_grad=patches.requires_grad)\n                indices = torch.arange(start, end, device=patches.device)\n                new_patches = torch.index_copy(new_patches, dim=-3, index=indices, source=patches)\n                # Only the in_c dimension is changed.\n                new_A = A.create_similar(new_patches)\n            else:\n                raise ValueError(f'Unsupported A type {type(A)}')\n            return new_A\n\n        start, end, axes = x[1].value.item(), x[2].value.item(), x[3].value.item()\n        steps = x[4].value.item() if len(x) == 5 else 1  # If step is not specified, it is 1.\n        # Other step sizes are untested; do not enable them for now.\n        assert steps == 1 and axes == int(axes) and start == int(start) and end == int(end)\n        start, end = self._fixup_params(self.input_shape, start, end, axes, steps)\n        # Find the original shape of A.\n        lA = _bound_oneside(last_lA, start, end, axes, steps)\n        uA = _bound_oneside(last_uA, start, end, axes, steps)\n        return [(lA, uA), (None, None), (None, None), (None, None), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, *inputs):\n        assert len(inputs) == 5 or len(inputs) == 4\n        start = inputs[1].lb.item()\n        end = inputs[2].lb.item()\n        axis = self.make_axis_non_negative(inputs[3].lb.item())\n        assert axis > 0, \"Slicing along the batch dimension is not supported yet\"\n        steps = inputs[4].lb.item() if len(inputs) == 5 else 1  # If step is not specified, it is 1.\n        assert steps in [1, -1]\n        x = inputs[0]\n        shape = x.lb.shape\n        start, end = self._fixup_params(shape, start, end, axis, steps)\n    
    lw = torch.narrow(x.lw, dim=axis+1, start=start, length=end - start)\n        uw = torch.narrow(x.uw, dim=axis+1, start=start, length=end - start)\n        lb = torch.narrow(x.lb, dim=axis, start=start, length=end - start)\n        ub = torch.narrow(x.ub, dim=axis, start=start, length=end - start)\n        if steps == -1:\n            lw = torch.flip(lw, dims=(axis + 1,))\n            uw = torch.flip(uw, dims=(axis + 1,))\n            lb = torch.flip(lb, dims=(axis,))\n            ub = torch.flip(ub, dims=(axis,))\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_gradient_node(self, grad_upstream):\n        assert len(self.inputs) == 5\n        start = self.inputs[1].value.item()\n        end = self.inputs[2].value.item()\n        axes = self.inputs[3].value.item()\n        steps = self.inputs[4].value.item()\n        assert steps == 1\n        node_grad = SliceGrad(start, end, axes, steps)\n        grad_input = (grad_upstream, self.inputs[0].forward_value)\n        return [(node_grad, grad_input, [])]\n\n\nclass BoundSplit(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.axis = attr['axis']\n        self.use_default_ibp = True\n        if 'split' in attr:\n            self.split = attr['split']\n        else:\n            self.split = None\n\n    def forward(self, *x):\n        data = x[0]\n        split = self.split if self.split is not None else x[1].tolist()\n        if self.axis == -1:\n            self.axis = len(data.shape) - 1\n        return torch.split(data, split, dim=self.axis)[self.output_index]\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        assert self.axis > 0\n        split = self.split if self.split is not None else x[1].value.tolist()\n        pre = sum(split[:self.output_index])\n        suc = sum(split[(self.output_index + 1):])\n\n        def _bound_oneside(last_A):\n            
if last_A is None:\n                return None\n            A = []\n            if pre > 0:\n                A.append(torch.zeros(\n                    *last_A.shape[:(self.axis + 1)], pre, *last_A.shape[(self.axis + 2):],\n                    device=last_A.device))\n            A.append(last_A)\n            if suc > 0:\n                A.append(torch.zeros(\n                    *last_A.shape[:(self.axis + 1)], suc, *last_A.shape[(self.axis + 2):],\n                    device=last_A.device))\n            return torch.cat(A, dim=self.axis + 1)\n\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, *x):\n        assert self.axis > 0 and self.from_input\n        split = self.split if self.split is not None else x[1].lb.tolist()\n        x = x[0]\n        lw = torch.split(x.lw, split, dim=self.axis + 1)[self.output_index]\n        uw = torch.split(x.uw, split, dim=self.axis + 1)[self.output_index]\n        lb = torch.split(x.lb, split, dim=self.axis)[self.output_index]\n        ub = torch.split(x.ub, split, dim=self.axis)[self.output_index]\n        return LinearBound(lw, lb, uw, ub)\n\n    def build_solver(self, *v, model, C=None, model_type=\"mip\", solver_pkg=\"gurobi\"):\n        self.solver_vars = self.forward(v[0])\n\n\ndef slice_grad(x, input_shape, start, end, axes, steps):\n    assert steps == 1\n    assert axes > 0\n    out = torch.zeros(*x.shape[:2], *input_shape[1:]).to(x)\n    end = min(end, input_shape[axes])\n    index = torch.arange(start, end, device=x.device)\n    # Make index.ndim == x.ndim\n    index = index.view(\n        *((1,) * (axes + 1)),\n        end - start,\n        *((1,) * (x.ndim - axes - 2)))\n    # Make index.shape == x.shape\n    index = index.repeat(\n        *x.shape[:axes + 1],\n        1,\n        *x.shape[axes + 2:]\n    )\n    out.scatter_(axes + 1, index, x)\n    return out\n\n\nclass SliceGradOp(Function):\n    \"\"\" Local gradient of BoundSlice.\n\n  
  Not including multiplication with gradients from other layers.\n    \"\"\"\n    @staticmethod\n    def symbolic(_, grad_last, input, start=None, end=None, axes=None, steps=1):\n        return _.op(\n            'grad::Slice', grad_last, input,\n            start_i=start, end_i=end, axes_i=axes, steps_i=steps\n        ).setType(grad_last.type())\n\n    @staticmethod\n    def forward(ctx, grad_last, input, start, end, axes, steps):\n        return slice_grad(grad_last, input.shape, start, end, axes, steps)\n\n\nclass SliceGrad(Module):\n    def __init__(self, start, end, axes, steps):\n        super().__init__()\n        self.start = start\n        self.end = end\n        self.axes = axes\n        self.steps = steps\n\n    def forward(self, grad_last, input):\n        return SliceGradOp.apply(\n            grad_last, input,\n            self.start, self.end, self.axes, self.steps)\n\n\nclass BoundSliceGrad(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.start = attr['start']\n        self.end = attr['end']\n        self.axes = attr['axes']\n        self.steps = attr['steps']\n        self.use_default_ibp = True\n\n    def forward(self, grad_last, input):\n        return slice_grad(grad_last, input.shape,\n                          self.start, self.end, self.axes, self.steps)\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            assert self.axes > 0\n            last_A_ = last_A.reshape(-1, *self.inputs[1].output_shape[self.axes:])\n            last_A_ = last_A_[:, self.start:self.end]\n            last_A = last_A_.reshape(\n                *last_A.shape[:self.axes+2], -1,\n                *self.inputs[1].output_shape[self.axes+1:])\n            return last_A\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)),\n  
              (None, None)], 0, 0\n\n\ndef concat_grad(x, axis, input_index, *inputs):\n    cur = 0\n    for i in range(input_index):\n        cur += inputs[i].shape[axis]\n    x_ = x.reshape(-1, *x.shape[axis + 1:])\n    ret = x_[:, cur:cur+inputs[input_index].shape[axis]]\n    ret = ret.reshape(*x.shape[:axis + 1], *ret.shape[1:])\n    return ret\n\n\nclass ConcatGradOp(Function):\n    @staticmethod\n    def symbolic(_, grad_last, axis, input_index, *inputs):\n        return _.op('grad::Concat', grad_last, *inputs,\n                    axis_i=axis, input_index_i=input_index).setType(grad_last.type())\n\n    @staticmethod\n    def forward(ctx, grad_last, axis, input_index, *inputs):\n        return concat_grad(grad_last, axis, input_index, *inputs)\n\n\nclass ConcatGrad(Module):\n    def __init__(self, axis, input_index):\n        super().__init__()\n        self.input_index = input_index\n        self.axis = axis\n\n    def forward(self, grad_last, *input):\n        return ConcatGradOp.apply(grad_last, self.axis, self.input_index, *input)\n\n\nclass BoundConcatGrad(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.axis = attr['axis']\n        self.input_index = attr['input_index']\n        self.use_default_ibp = True\n\n    def forward(self, grad_last, *inputs):\n        return concat_grad(grad_last, self.axis, self.input_index, *inputs)\n\n    def bound_backward(self, last_lA, last_uA, *args, **kwargs):\n        def _bound_oneside(last_A):\n            if last_A is None:\n                return None\n            assert self.axis > 0\n            start = sum([self.inputs[i + 1].output_shape[self.axis]\n                         for i in range(self.input_index)])\n            end = start + self.output_shape[self.axis+1]\n            shape_behind = self.inputs[0].output_shape[self.axis+1:]\n            A = torch.zeros(*last_A.shape[:self.axis+2], 
*shape_behind, device=last_A.device)\n            A = A.view(-1, *shape_behind)\n            A[:, start:end] = last_A.reshape(-1, *last_A.shape[self.axis+2:])\n            A = A.view(*last_A.shape[:self.axis+2], *shape_behind)\n            return A\n\n        return ([(_bound_oneside(last_lA), _bound_oneside(last_uA))]\n                + [(None, None)] * (len(self.inputs) - 1)), 0, 0
  },
  {
    "path": "auto_LiRPA/operators/softmax.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\" Softmax \"\"\"\nfrom .base import *\n\nclass BoundSoftmaxImpl(nn.Module):\n    def __init__(self, axis):\n        super().__init__()\n        self.axis = axis\n        assert self.axis == int(self.axis)\n\n    def forward(self, x):\n        max_x = torch.max(x, dim=self.axis).values\n        x = torch.exp(x - max_x.unsqueeze(self.axis))\n        s = torch.sum(x, dim=self.axis, keepdim=True)\n        return x / s\n\n# The `option != 'complex'` case is not used in the auto_LiRPA main paper.\nclass BoundSoftmax(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.axis = attr['axis']\n        self.option = options.get('softmax', 'complex')\n        if self.option == 'complex':\n            self.complex = True\n        else:\n            self.max_input = 30\n\n    def forward(self, x):\n        assert self.axis == int(self.axis)\n        if self.option == 'complex':\n            self.input = (x,)\n            self.model = BoundSoftmaxImpl(self.axis)\n            self.model.device = self.device\n            return self.model(x)\n        else:\n            return F.softmax(x, dim=self.axis)\n\n    def interval_propagate(self, *v):\n        assert self.option != 'complex'\n        assert self.perturbed\n        h_L, h_U = v[0]\n        shift = h_U.max(dim=self.axis, keepdim=True).values\n        exp_L, exp_U = torch.exp(h_L - shift), torch.exp(h_U - shift)\n        lower = exp_L / (torch.sum(exp_U, dim=self.axis, keepdim=True) - exp_U + exp_L + epsilon)\n        upper = exp_U / (torch.sum(exp_L, dim=self.axis, keepdim=True) - exp_L + exp_U + epsilon)\n        return lower, upper\n"
  },
  {
    "path": "auto_LiRPA/operators/solver_utils.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nclass DummyGurobipyClass:\n    \"\"\"A dummy class with error message when gurobi is not installed.\"\"\"\n    def __getattr__(self, attr):\n        def _f(*args, **kwargs):\n            raise RuntimeError(f\"method {attr} not available because gurobipy module was not built.\")\n        return _f\n\ntry:\n    import gurobipy as grb\nexcept ModuleNotFoundError:\n    grb = DummyGurobipyClass()"
  },
  {
    "path": "auto_LiRPA/operators/tile.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"BoundTile\"\"\"\nfrom torch.nn import Module\nfrom .base import *\n\nclass BoundTile(Bound):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.use_default_ibp = True\n    \n    def forward(self, x, repeats):\n        return x.repeat(repeats.tolist())\n\n    def bound_backward(self, last_lA, last_uA, *x, **kwargs):\n        assert not self.is_input_perturbed(1)\n        repeats = x[1].value\n\n        def _bound_oneside(A):\n            if A is None:\n                return None\n            # block_shape: (specs, d1/r1, r1, d2/r2, r2, ..., dn/rn, rn)\n            # Reshaping A to block_shape and summing along the \"r\" dimensions\n            # is equivalent to summing up all block fragments of A.\n            block_shape = [A.shape[0]]\n            axes_to_sum = []\n            for i in range(len(repeats)):\n                block_shape.append(A.size(i + 1) // repeats[i].item())\n                block_shape.append(repeats[i].item())\n                axes_to_sum.append(2 * i + 2)\n            reshaped_A = A.reshape(*block_shape)\n            next_A = reshaped_A.sum(dim=axes_to_sum)\n            return next_A\n\n        return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0\n\n    def bound_forward(self, dim_in, *x):\n        assert (x[1].lb == x[1].ub).all(), \"repeats should be constant.\"\n        repeats = x[1].lb.tolist()\n        assert repeats[0] == 1, \"shouldn't repeat on the batch dimension.\"\n        # lb and ub have the same shape as x, so we repeat them with \"repeats\"\n        lb = x[0].lb.repeat(repeats)\n        ub = x[0].ub.repeat(repeats)\n        # lw and uw have shape (batch_size, input_dim, *shape_of_the_current_layer)\n        # so we need to repeat them with 
\"repeats\" as well, but we need to\n        # insert 1 at the second position to keep the input dimension unchanged.\n        repeats.insert(1, 1)\n        lw = x[0].lw.repeat(repeats)\n        uw = x[0].uw.repeat(repeats)\n        return LinearBound(lw, lb, uw, ub)\n"
  },
  {
    "path": "auto_LiRPA/operators/trigonometric.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nfrom types import SimpleNamespace\n\nimport torch\nfrom torch.autograd import Function\n\nfrom .activation_base import BoundActivation\nfrom .s_shaped import BoundSShaped\n\n\nclass BoundSin(BoundSShaped):\n    # Lookup tables shared by all BoundSin classes.\n    xl_lower_tb = None\n    xl_upper_tb = None\n    xu_lower_tb = None\n    xu_upper_tb = None\n    func, d_func = torch.sin, torch.cos\n    n_table_entries = 1001\n\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.ibp_intermediate = True\n        self.act_func = torch.sin\n        self.d_act_func = torch.cos\n\n        # Bound limits used by IBP.\n        self.ibp_max_point = torch.pi / 2\n        self.ibp_min_point = torch.pi * 3 / 2\n\n        self.all_table_x = torch.linspace(\n            0, 2 * torch.pi, BoundSin.n_table_entries, device=self.device)\n        self.precompute_relaxation(self.act_func, self.d_act_func, x_limit = torch.pi / 2)\n        if BoundSin.xl_lower_tb is None:\n            # Generate look-up tables.\n            BoundSin.xl_lower_tb = BoundSin.get_lower_left_bound(self.all_table_x)\n            BoundSin.xl_upper_tb = BoundSin.get_upper_left_bound(self.all_table_x)\n            BoundSin.xu_lower_tb = BoundSin.get_lower_right_bound(self.all_table_x)\n            BoundSin.xu_upper_tb = BoundSin.get_upper_right_bound(self.all_table_x)\n\n    def d2_act_func(self, x):\n        return -torch.sin(x)\n\n    def _init_opt_parameters_impl(self, size_spec, name_start):\n        \"\"\"Implementation of init_opt_parameters for each start_node.\"\"\"\n        l, u = self.inputs[0].lower, self.inputs[0].upper\n        shape = [size_spec] + list(l.shape)\n        alpha = torch.empty(12, *shape, device=l.device)\n        alpha.data[:4] = ((l + 
u) / 2).unsqueeze(0).expand(4, *shape)\n        alpha.data[4:6] = self.tp_both_lower_init[name_start].expand(2, *shape)\n        alpha.data[6:8] = self.tp_both_upper_init[name_start].expand(2, *shape)\n        alpha.data[8:10] = self.tp_lower_init[name_start].expand(2, *shape)\n        alpha.data[10:12] = self.tp_upper_init[name_start].expand(2, *shape)\n        return alpha\n\n    def opt_init(self):\n        super().opt_init()\n        self.tp_both_lower_init = {}\n        self.tp_both_upper_init = {}\n        self.tp_lower_init = {}\n        self.tp_upper_init = {}\n\n    def branch_input_domain(self, lb, ub):\n        # Map all input lower and upper bounds to the [0, 2*pi] interval.\n        lb_clamped = lb - torch.floor(lb / (2 * torch.pi)) * (2 * torch.pi)\n        ub_clamped = ub - torch.floor(ub / (2 * torch.pi)) * (2 * torch.pi)\n\n        # Mask the mapped lower and upper bounds according to whether they are in [0, 0.5*pi), [0.5*pi, pi),\n        # [pi, 1.5*pi), or [1.5*pi, 2*pi).\n        mask_lb_1 = torch.logical_and(lb_clamped >= 0, lb_clamped < torch.pi / 2)\n        mask_lb_2 = torch.logical_and(lb_clamped >= torch.pi / 2, lb_clamped < torch.pi)\n        mask_lb_3 = torch.logical_and(lb_clamped >= torch.pi, lb_clamped < 3 * torch.pi / 2)\n        mask_lb_4 = torch.logical_and(lb_clamped >= 3 * torch.pi / 2, lb_clamped < 2 * torch.pi)\n\n        mask_ub_1 = torch.logical_and(ub_clamped >= 0, ub_clamped < torch.pi / 2)\n        mask_ub_2 = torch.logical_and(ub_clamped >= torch.pi / 2, ub_clamped < torch.pi)\n        mask_ub_3 = torch.logical_and(ub_clamped >= torch.pi, ub_clamped < 3 * torch.pi / 2)\n        mask_ub_4 = torch.logical_and(ub_clamped >= 3 * torch.pi / 2, ub_clamped < 2 * torch.pi)\n\n        self.sigmoid_like_mask = torch.logical_and(\n            ub - lb <= torch.pi,\n            torch.logical_or(\n                torch.logical_and(\n                    torch.logical_or(mask_lb_2, mask_lb_3),\n                    
torch.logical_or(mask_ub_2, mask_ub_3)\n                ),\n                torch.logical_and(\n                    torch.logical_or(mask_lb_1, mask_lb_4),\n                    torch.logical_or(mask_ub_1, mask_ub_4)\n                )\n            )\n        )\n        self.branch_mask = torch.logical_not(self.sigmoid_like_mask)\n\n        self.mask_neg = torch.logical_and(torch.logical_or(mask_lb_3, mask_lb_4),\n                                          torch.logical_and(torch.logical_or(mask_ub_3, mask_ub_4),\n                                                            self.sigmoid_like_mask))\n\n        self.mask_pos = torch.logical_and(torch.logical_or(mask_lb_1, mask_lb_2),\n                                          torch.logical_and(torch.logical_or(mask_ub_1, mask_ub_2),\n                                                            self.sigmoid_like_mask))\n\n        self.mask_both = torch.logical_xor(self.sigmoid_like_mask,\n                                           torch.logical_or(self.mask_neg, self.mask_pos))\n\n        self.convex_concave = self.d2_act_func(lb) >= 0\n\n    def generate_d_lower_upper(self, lower, upper):\n        # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed.\n        # Note that for neurons with also input lower bound >=0, they will be masked later.\n        k_tensor = torch.floor(upper / (2 * torch.pi))\n        upper_clamped = upper - k_tensor * (2 * torch.pi)\n        case1_mask = torch.logical_and(upper_clamped >= 0, upper_clamped <= torch.pi / 2)\n        upper_clamped_new = upper_clamped.clamp(min=0, max=torch.pi / 2)\n        index = torch.max(\n            torch.zeros(upper.numel(), dtype=torch.long, device=upper.device),\n            (upper_clamped_new / self.step_pre).to(torch.long).reshape(-1)\n        ) + 1\n        # Lookup the lower bound slope from the pre-computed table.\n        d_lower = (torch.index_select(self.d_lower, 0, index).view(lower.shape)\n  
                 + k_tensor * 2 * torch.pi) * case1_mask\n\n        case2_mask = torch.logical_and(upper_clamped >= torch.pi, upper_clamped <= 3 * torch.pi / 2)\n        upper_clamped_new = upper_clamped.clamp(min=torch.pi, max=3 * torch.pi / 2)\n        index = torch.max(\n            torch.zeros(upper.numel(), dtype=torch.long, device=upper.device),\n            ((torch.pi - upper_clamped_new) / -self.step_pre).to(torch.long).reshape(-1)\n        ) + 1\n        # Lookup the lower bound slope from the pre-computed table.\n        d_upper = (torch.pi - torch.index_select(self.d_upper, 0, index).view(lower.shape)\n                   + k_tensor * 2 * torch.pi) * case2_mask\n\n        # Indices of neurons with lower bound <=0, whose optimal slope to upper bound the function was pre-computed.\n        k_tensor = torch.floor(lower / (2 * torch.pi))\n        lower_clamped = lower - k_tensor * (2 * torch.pi)\n        case3_mask = torch.logical_and(lower_clamped >= 3 * torch.pi / 2, lower_clamped <= 2 * torch.pi)\n        lower_clamped_new = lower_clamped.clamp(min=(3 * torch.pi / 2), max=2 * torch.pi)\n        index = torch.max(\n            torch.zeros(lower.numel(), dtype=torch.long, device=lower.device),\n            ((lower_clamped_new - 2 * torch.pi) / -self.step_pre).to(torch.long).reshape(-1)\n        ) + 1\n        d_upper += (torch.index_select(self.d_upper, 0, index).view(upper.shape)\n                    + (k_tensor + 1) * 2 * torch.pi) * case3_mask\n\n        case4_mask = torch.logical_and(lower_clamped >= torch.pi / 2, lower_clamped <= torch.pi)\n        lower_clamped_new = lower_clamped.clamp(min=(torch.pi / 2), max=torch.pi)\n        index = torch.max(\n            torch.zeros(lower.numel(), dtype=torch.long, device=lower.device),\n            ((torch.pi - lower_clamped_new) / self.step_pre).to(torch.long).reshape(-1)\n        ) + 1\n        d_lower += (torch.pi - torch.index_select(self.d_lower, 0, index).view(upper.shape)\n                    + 
k_tensor * 2 * torch.pi) * case4_mask\n        return d_lower, d_upper\n\n    @staticmethod\n    def arcsin(c):\n        \"\"\"Arcsin with gradient fixes.\n\n        arcsin(-1) and arcsin(1) have pathological gradients and should be avoided.\n        \"\"\"\n        if c.min() > -1 and c.max() < 1:\n            return torch.arcsin(c)\n        c_ = c.clone()\n        mask_neg = c == -1\n        mask_pos = c == 1\n        c_[mask_neg] = 0\n        c_[mask_pos] = 0\n        ret = torch.arcsin(c_)\n        ret[mask_neg] = -torch.pi / 2\n        ret[mask_pos] = torch.pi / 2\n        return ret\n\n    @staticmethod\n    def get_intersection(start, end, c, theta=0.):\n        \"\"\"Get the number of intersections between y = sin(x + theta) and y = c between start and end.\"\"\"\n        # Use arcsine to find the first 2 intersections.\n        crossing1 = BoundSin.arcsin(c) - theta\n        crossing2 = torch.pi - crossing1 - 2 * theta  # Problematic at exact 1/2 pi, but ok in our case (happens only when lb=ub).\n        return BoundSin.n_crossing(start, end, crossing1) + BoundSin.n_crossing(start, end, crossing2)\n\n    @staticmethod\n    def n_crossing(start, end, s):\n        \"\"\"Check how many times we will encounter value s + k*2*pi within start and end for any integer k.\"\"\"\n        cycles = torch.floor((end - start) / (2 * torch.pi))  # Number of 2pi cycles.\n        # Move s and end to the same 2 * pi cycle as start.\n        dist = torch.floor((s - start) / (2 * torch.pi))\n        real_s = s - dist * 2 * torch.pi\n        real_end = end - cycles * 2 * torch.pi\n        return (real_s >= start).to(s) * (real_s <= real_end).to(s) + cycles\n\n    @staticmethod\n    def check_bound(tangent_point, x):\n        \"\"\"Check whether the tangent line at tangent_point is a valid lower/upper bound for x.\"\"\"\n        # evaluate the value of the tangent line at x and see it is >= 0 or <=0.\n        d = BoundSin.d_func(tangent_point)\n        val = d * (x - 
tangent_point) + BoundSin.func(tangent_point)\n        # We want a positive margin when finding a lower line, but as close to 0 as possible.\n        # We want a negative margin when finding an upper line, but as close to 0 as possible.\n        margin = BoundSin.func(x) - val\n        return margin\n\n    @staticmethod\n    @torch.no_grad()\n    def get_lower_left_bound(xl, steps=20):\n        \"\"\"Get a global lower bound given lower bound on x. Return the tangent point and cycle offset.\"\"\"\n        dtype = xl.dtype\n        # Constrain xl into the -0.5 pi to 1.5 pi region.\n        cycles = torch.floor((xl + 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi)\n        xl = xl - cycles\n        use_tangent_line = (xl >= torch.pi).to(dtype)\n        # Case 1: xl >= pi. The lower tangent line is the only possible lower bound.\n        # Case 2: Binary search needed. Testing from another tangent endpoint in [pi, 1.5*pi]. It must be in this region.\n        left = torch.pi * torch.ones_like(xl)\n        # The right end guarantees the margin > 0 because it is basically an IBP lower bound (-1).\n        right = (1.5 * torch.pi) * torch.ones_like(xl)\n        last_right = right.clone()\n        for _ in range(steps):\n            mid = (left + right) / 2.\n            margin = BoundSin.check_bound(mid, xl)\n            pos_mask = (margin > 0).to(dtype)  # We want the margin > 0 but as close to 0 as possible.\n            neg_mask = 1.0 - pos_mask\n            right = mid * pos_mask + right * neg_mask  # We have positive margin, reduce right hand side.\n            last_right = mid * pos_mask + last_right * neg_mask  # Always sound, since the margin is positive.\n            left = mid * neg_mask + left * pos_mask\n        d = xl * use_tangent_line + last_right * (1. 
- use_tangent_line)\n        # Return tangent point and cycle offset.\n        return [d, cycles]\n\n    @staticmethod\n    @torch.no_grad()\n    def get_upper_left_bound(xl, steps=20):\n        \"\"\"Get a global upper bound given lower bound on x. Return the tangent point and cycle offset.\"\"\"\n        dtype = xl.dtype\n        # Constrain xl into the 0.5 pi to 2.5 pi region.\n        cycles = torch.floor((xl - 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi)\n        xl = xl - cycles\n        use_tangent_line = (xl >= 2.0 * torch.pi).to(dtype)\n        # Case 1: xl >= 2*pi. The upper tangent line is the only possible upper bound.\n        # Case 2: Binary search needed. Testing from another tangent endpoint in [2*pi, 2.5*pi]. It must be in this region.\n        left = (2.0 * torch.pi) * torch.ones_like(xl)\n        # The right end guarantees the margin < 0 because it is basically an IBP upper bound (+1).\n        right = (2.5 * torch.pi) * torch.ones_like(xl)\n        last_right = right.clone()\n        for _ in range(steps):\n            mid = (left + right) / 2.\n            margin = BoundSin.check_bound(mid, xl)\n            pos_mask = (margin > 0).to(dtype)  # We want the margin < 0 but as close to 0 as possible.\n            neg_mask = 1.0 - pos_mask\n            right = mid * neg_mask + right * pos_mask  # We have negative margin, reduce right hand side.\n            last_right = mid * neg_mask + last_right * pos_mask  # Always sound, since the margin is negative.\n            left = mid * pos_mask + left * neg_mask\n        d = xl * use_tangent_line + last_right * (1. - use_tangent_line)\n        # Return tangent point and cycle offset.\n        return [d, cycles]\n\n    @staticmethod\n    @torch.no_grad()\n    def get_lower_right_bound(xu, steps=20):\n        \"\"\"Get a global lower bound given upper bound on x. 
Return the tangent point and cycle offset.\"\"\"\n        # Constrain xu into the -0.5 pi to 1.5 pi region.\n        cycles = torch.floor((xu + 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi)\n        xu = xu - cycles\n        d, _ = BoundSin.get_lower_left_bound(torch.pi - xu, steps)\n        return [3 * torch.pi - d, cycles - 2 * torch.pi]\n\n    @staticmethod\n    @torch.no_grad()\n    def get_upper_right_bound(xu, steps=20):\n        \"\"\"Get a global upper bound given upper bound on x. Return the tangent point and cycle offset.\"\"\"\n        # Constrain xu into the 0.5 pi to 2.5 pi region.\n        cycles = torch.floor((xu - 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi)\n        xu = xu - cycles\n        d, _ = BoundSin.get_upper_left_bound(3 * torch.pi - xu, steps)\n        return [5 * torch.pi - d, cycles - 2 * torch.pi]\n\n    def get_bound_tb(self, lb, ub):\n        \"\"\"Find lower or upper bounds from lookup table.\"\"\"\n        lower, upper = lb, ub\n        step = 2 * torch.pi / (BoundSin.n_table_entries - 1)\n        # Move to 0 to 2 pi region.\n        lb_cycles = torch.floor(lb / (2 * torch.pi)) * (2 * torch.pi)\n        lb = torch.clamp(lb - lb_cycles, min=0, max=2 * torch.pi)\n        ub_cycles = torch.floor(ub / (2 * torch.pi)) * (2 * torch.pi)\n        ub = torch.clamp(ub - ub_cycles, min=0, max=2 * torch.pi)\n        # Find the indices within the lookup table covering 0 to 2*pi.\n        indices_lb = lb.div(step).long()\n        indices_ub = ub.div(step).long()\n        tangent_left_lower = BoundSin.xl_lower_tb[0][indices_lb]\n        tangent_left_upper = BoundSin.xl_upper_tb[0][indices_lb]\n        tangent_right_lower = BoundSin.xu_lower_tb[0][indices_ub]\n        tangent_right_upper = BoundSin.xu_upper_tb[0][indices_ub]\n        if self.opt_stage in ['opt', 'reuse']:\n            if not hasattr(self, 'alpha'):\n                # Raise an error if alpha is not created.\n                self._no_bound_parameters()\n            ns = self._start\n\n            
self.alpha[ns].data[8:10, :] = torch.min(\n                torch.max(self.alpha[ns][8:10, :], tangent_left_lower), tangent_right_lower)\n            self.alpha[ns].data[10:12, :] = torch.min(\n                torch.max(self.alpha[ns][10:12, :], tangent_left_upper), tangent_right_upper)\n            tangent_lower = self.alpha[ns][8:10, :]\n            tangent_upper = self.alpha[ns][10:12, :]\n        else:\n            # add cycles to optimizable tangent region\n            unfolded_left_lower = (tangent_left_lower +\n                BoundSin.xl_lower_tb[1][indices_lb] + lb_cycles)\n            left_lower_ends = 1.5*torch.pi + BoundSin.xl_lower_tb[1][indices_lb] + lb_cycles\n            unfolded_right_lower = (tangent_right_lower +\n                BoundSin.xu_lower_tb[1][indices_ub] + ub_cycles)\n            right_lower_ends = 1.5*torch.pi + BoundSin.xu_lower_tb[1][indices_ub] + ub_cycles\n            mid = (lower + upper) / 2\n\n            leftmost_mask = torch.logical_and(mid < unfolded_left_lower,\n                unfolded_left_lower <= upper)\n            left_range_mask = torch.logical_and(mid >= unfolded_left_lower,\n                mid < left_lower_ends)\n            inbetween_mask = torch.logical_and(mid >= left_lower_ends,\n                mid < right_lower_ends)\n            rightmost_mask = torch.logical_and(mid >= unfolded_right_lower,\n                unfolded_right_lower >= lower)\n            right_range_mask = torch.logical_and(~left_range_mask, torch.logical_and(mid >= right_lower_ends,\n                mid < unfolded_right_lower))\n\n            tangent_lower = (leftmost_mask * tangent_left_lower +\n                left_range_mask * (mid - BoundSin.xl_lower_tb[1][indices_lb] - lb_cycles) +\n                inbetween_mask * 1.5*torch.pi + rightmost_mask * tangent_right_lower +\n                right_range_mask * (mid - BoundSin.xu_lower_tb[1][indices_ub] - ub_cycles))\n\n            unfolded_left_upper = (tangent_left_upper +\n                
BoundSin.xl_upper_tb[1][indices_lb] + lb_cycles)\n            left_upper_ends = 2.5*torch.pi + BoundSin.xl_upper_tb[1][indices_lb] + lb_cycles\n            unfolded_right_upper = (tangent_right_upper +\n                BoundSin.xu_upper_tb[1][indices_ub] + ub_cycles)\n            right_upper_ends = 2.5*torch.pi + BoundSin.xu_upper_tb[1][indices_ub] + ub_cycles\n            mid = (lower + upper) / 2\n\n            leftmost_mask = torch.logical_and(mid < unfolded_left_upper,\n                unfolded_left_upper <= upper)\n            left_range_mask = torch.logical_and(mid >= unfolded_left_upper,\n                mid < left_upper_ends)\n            inbetween_mask = torch.logical_and(mid >= left_upper_ends,\n                mid < right_upper_ends)\n            rightmost_mask = torch.logical_and(mid >= unfolded_right_upper,\n                unfolded_right_upper >= lower)\n            right_range_mask = torch.logical_and(~left_range_mask, torch.logical_and(mid >= right_upper_ends,\n                mid < unfolded_right_upper))\n\n            tangent_upper = (leftmost_mask * tangent_left_upper +\n                left_range_mask * (mid - BoundSin.xl_upper_tb[1][indices_lb] - lb_cycles) +\n                inbetween_mask * 2.5*torch.pi + rightmost_mask * tangent_right_upper +\n                right_range_mask * (mid - BoundSin.xu_upper_tb[1][indices_ub] - ub_cycles))\n\n            if self.opt_stage == 'init':\n                ns = self._start\n                self.tp_lower_init[ns] = tangent_lower.detach()\n                self.tp_upper_init[ns] = tangent_upper.detach()\n\n        d_lower = BoundSin.d_func(tangent_lower)\n        b_lower = BoundSin.func(tangent_lower) - d_lower * (tangent_lower +\n                    torch.where(tangent_lower <= 1.5*torch.pi,\n                        BoundSin.xl_lower_tb[1][indices_lb] + lb_cycles,\n                        BoundSin.xu_lower_tb[1][indices_ub] + ub_cycles))\n        d_upper = BoundSin.d_func(tangent_upper)\n        b_upper = 
BoundSin.func(tangent_upper) - d_upper * (tangent_upper +\n                    torch.where(tangent_upper <= 2.5*torch.pi,\n                        BoundSin.xl_upper_tb[1][indices_lb] + lb_cycles,\n                        BoundSin.xu_upper_tb[1][indices_ub] + ub_cycles))\n        return d_lower, b_lower, d_upper, b_upper\n\n    def forward(self, x):\n        return torch.sin(x)\n\n    def interval_propagate(self, *v):\n        # Check if a point is in [l, u], considering the 2pi period\n        def check_crossing(ll, uu, point):\n            return ((((uu - point) / (2 * torch.pi)).floor()\n                     - ((ll - point) / (2 * torch.pi)).floor()) > 0).to(h_Ls.dtype)\n        h_L, h_U = v[0][0], v[0][1]\n        h_Ls, h_Us = self.forward(h_L), self.forward(h_U)\n        # If crossing pi/2, then max is fixed 1.0\n        max_mask = check_crossing(h_L, h_U, self.ibp_max_point)\n        # If crossing pi*3/2, then min is fixed -1.0\n        min_mask = check_crossing(h_L, h_U, self.ibp_min_point)\n        ub = torch.max(h_Ls, h_Us)\n        ub = max_mask + (1 - max_mask) * ub\n        lb = torch.min(h_Ls, h_Us)\n        lb = - min_mask + (1 - min_mask) * lb\n        return lb, ub\n\n    def bound_relax_branch(self, lb, ub):\n        dtype = lb.dtype\n\n        ub = torch.max(ub, lb + 1e-8)\n\n        # Case 1: Connect the two points as a line\n        sub = self.func(ub)\n        slb = self.func(lb)\n        mid = (sub + slb) / 2.\n        smid = self.func((ub + lb) / 2)\n        gap = smid - mid\n\n        case1_line_slope = (sub - slb) / (ub - lb).clamp(min=1e-10)\n        case1_line_bias = slb - case1_line_slope * lb\n        # Check if there are crossings between the line and the sin function.\n        grad_crossings = self.get_intersection(lb, ub, case1_line_slope, theta=0.5 * torch.pi)\n        # If there is no crossing, then we can connect the two points together as a lower/upper bound.\n        use_line = grad_crossings == 1\n        # Connected line is the 
upper bound.\n        upper_use_line = torch.logical_and(gap < 0, use_line)\n        # Connected line is the lower bound.\n        lower_use_line = torch.logical_and(gap >= 0, use_line)\n\n        # Case 2: we will try the global lower/upper bounds at lb and ub.\n        # For the points at lb and ub, we can construct both lower and upper bounds.\n        (case_2_lower_slope, case_2_lower_bias,\n            case_2_upper_slope, case_2_upper_bias) = self.get_bound_tb(lb, ub)\n\n        # Finally, choose between case 1 and case 2.\n        lower_use_line = lower_use_line.to(dtype)\n        not_lower_use_line = 1. - lower_use_line\n        upper_use_line = upper_use_line.to(dtype)\n        not_upper_use_line = 1. - upper_use_line\n        lower_slope = lower_use_line * case1_line_slope + not_lower_use_line * case_2_lower_slope\n        lower_bias = lower_use_line * case1_line_bias + not_lower_use_line * case_2_lower_bias\n        upper_slope = upper_use_line * case1_line_slope + not_upper_use_line * case_2_upper_slope\n        upper_bias = upper_use_line * case1_line_bias + not_upper_use_line * case_2_upper_bias\n        return lower_slope, lower_bias, upper_slope, upper_bias\n\n\nclass BoundCos(BoundSin):\n    def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.ibp_max_point = 0.0\n        self.ibp_min_point = torch.pi\n\n    def forward(self, x):\n        return torch.cos(x)\n\n    def bound_relax(self, x, init=False, dim_opt=None):\n        # Shift the input by half_pi, and shift the linear bounds back.\n        half_pi = 0.5 * torch.pi\n        x_shifted = SimpleNamespace()\n        x_shifted.lower = x.lower + half_pi\n        x_shifted.upper = x.upper + half_pi\n        super().bound_relax(x_shifted, init=init, dim_opt=dim_opt)\n        self.lb = self.lb + self.lw * half_pi\n        self.ub = self.ub + self.uw * half_pi\n\n\nclass BoundSec(BoundActivation):\n    
def __init__(self, attr=None, inputs=None, output_index=0, options=None):\n        super().__init__(attr, inputs, output_index, options)\n        self.ibp_intermediate = True\n\n    def forward(self, x):\n        return 1. / torch.cos(x)\n\n    def bound_relax(self, x, init=False):\n        assert x.lower.min() > -torch.pi / 2\n        assert x.upper.max() < torch.pi / 2\n\n        x_L = x.lower\n        x_U = x.upper\n        y_L = self.forward(x_L)\n        y_U = self.forward(x_U)\n        mask_close = x_U - x_L < 1e-8\n        upper_k = torch.where(\n            mask_close,\n            y_L * torch.tan(x_L),\n            (y_U - y_L) / (x_U - x_L).clamp(min=1e-8)\n        )\n        self.uw = upper_k\n        self.ub = -upper_k * x_L + y_L\n\n        mid = (x_L + x_U) / 2\n        y_mid = self.forward(mid)\n        lower_k = y_mid * torch.tan(mid)\n        self.lw = lower_k\n        self.lb = -lower_k * mid + y_mid\n\n    def interval_propagate(self, *v):\n        h_L, h_U = v[0][0], v[0][1]\n        assert h_L.min() > -torch.pi / 2\n        assert h_U.max() < torch.pi / 2\n        y_L = self.forward(h_L)\n        y_U = self.forward(h_U)\n        lower = (h_U < 0) * (y_U - 1) + (h_L > 0) * (y_L - 1) + 1\n        upper = torch.max(y_L, y_U)\n        return lower, upper\n\n\nclass SinGradOp(Function):\n    @staticmethod\n    def symbolic(_, x):\n        return _.op('grad::Sin', x)\n\n    @staticmethod\n    def forward(ctx, input):\n        return torch.cos(input)\n\n\nclass CosGradOp(Function):\n    @staticmethod\n    def symbolic(_, x):\n        return _.op('grad::Cos', x)\n\n    @staticmethod\n    def forward(ctx, input):\n        return -torch.sin(input)\n\n\nclass TanhGradOp(Function):\n    @staticmethod\n    def symbolic(_, x):\n        return _.op('grad::Tanh', x)\n\n    @staticmethod\n    def forward(ctx, input):\n        return 1 - torch.tanh(input)**2\n"
  },
  {
    "path": "auto_LiRPA/opt_pruner.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Pruning during the optimization.\"\"\"\n\nimport time\n\nimport torch\n\n\nclass OptPruner:\n\n    def __init__(self, x, threshold, multi_spec_keep_func, loss_reduction_func,\n                 decision_thresh, fix_interm_bounds,\n                 epsilon_over_decision_thresh):\n        self.x = x\n        self.threshold = threshold\n        self.multi_spec_keep_func = multi_spec_keep_func\n        self.loss_reduction_func = loss_reduction_func\n        self.decision_thresh = decision_thresh\n        self.fix_interm_bounds = fix_interm_bounds\n        self.epsilon_over_decision_thresh = epsilon_over_decision_thresh\n\n        # For computing the positive domain ratio\n        self.original_size = x[0].shape[0]\n        self.pruning_in_iteration = False\n        self.preserve_mask = None\n        self.preserve_mask_next = None\n        self.time = 0\n\n        # For holding full-sized alphas\n        self.cached_alphas = {}\n\n    def prune(self, x, C, ret_l, ret_u, ret, full_l, full_ret_l, full_ret_u,\n              full_ret, interm_bounds, aux_reference_bounds, reference_bounds,\n              stop_criterion_func, bound_lower):\n        # positive domains may already be filtered out, so we use all domains -\n        # negative domains to compute\n        # FIXME Only using ret_l but not ret_u.\n        if self.decision_thresh is not None and ret_l is not None:\n            if (isinstance(self.decision_thresh, torch.Tensor)\n                    and self.decision_thresh.numel() > 1\n                    and self.preserve_mask is not None):\n                if self.decision_thresh.shape[-1] == 1:\n                    # single spec with pruned domains\n                    negative_domain = (\n                        ret_l.view(-1)\n                        <= self.decision_thresh[self.preserve_mask].view(-1)\n     
               ).sum()\n                else:\n                    # multiple spec with pruned domains\n                    negative_domain = self.multi_spec_keep_func(\n                        ret_l <= self.decision_thresh[self.preserve_mask]).sum()\n            else:\n                if ret_l.shape[-1] == 1:\n                    # single spec\n                    negative_domain = (\n                        ret_l.view(-1) <= self.decision_thresh.view(-1)).sum()\n                else:\n                    # multiple spec\n                    negative_domain = self.multi_spec_keep_func(\n                        ret_l <= self.decision_thresh).sum()\n            positive_domain_num = self.original_size - negative_domain\n        else:\n            positive_domain_num = -1\n        positive_domain_ratio = float(\n            positive_domain_num) / float(self.original_size)\n        # threshold is 10% by default\n        self.next_iter_pruning_in_iteration = (\n            self.decision_thresh is not None\n            and positive_domain_ratio > self.threshold)\n\n        if self.pruning_in_iteration:\n            stime = time.time()\n            self.get_preserve_mask(ret_l)\n            # prune C\n            if C is not None and C.shape[0] == x[0].shape[0]:\n                C = C[self.now_preserve_mask]  # means C is also batch specific\n            # prune x\n            x, pre_prune_size = self._prune_x(x)\n            # prune bounds\n            ret_prune = self._prune_bounds_by_mask(\n                ret_l, ret_u, ret,\n                interm_bounds, aux_reference_bounds, reference_bounds, pre_prune_size)\n            full_l, full_ret_l, full_ret_u, full_ret = ret_prune\n            self.time += time.time() - stime\n\n        stop_criterion = stop_criterion_func(\n            full_ret_l) if bound_lower else stop_criterion_func(-full_ret_u)\n        if (type(stop_criterion) != bool and stop_criterion.numel() > 1\n                and self.pruning_in_iteration):\n  
          stop_criterion = stop_criterion[self.preserve_mask]\n\n        return (x, C, full_l, full_ret_l, full_ret_u,\n                full_ret, stop_criterion)\n\n    def prune_idx(self, idx_mask, idx, x):\n        if self.pruning_in_iteration:\n            # local sparse index of preserved samples where\n            # idx == true\n            local_idx = idx_mask[self.preserve_mask].nonzero().view(-1)\n            # idx is global sparse index of preserved samples where\n            # idx == true\n            new_idx = torch.zeros_like(\n                idx_mask, dtype=torch.bool, device=x[0].device)\n            new_idx[self.preserve_mask] = idx_mask[self.preserve_mask]\n            idx = new_idx.nonzero().view(-1)\n            reference_idx = local_idx\n        else:\n            reference_idx = idx\n        return reference_idx, idx\n\n    def next_iter(self):\n        if self.pruning_in_iteration:\n            self.preserve_mask = self.preserve_mask_next\n        if (not self.pruning_in_iteration\n                and self.next_iter_pruning_in_iteration):\n            # init preserve_mask etc\n            self.preserve_mask = torch.arange(\n                0, self.x[0].shape[0], device=self.x[0].device, dtype=torch.long)\n            self.pruning_in_iteration = True\n\n    def update_best(self, full_ret_l, full_ret_u, best_ret):\n        if self.pruning_in_iteration:\n            # overwrite pruned cells in best_ret by threshold + eps\n            fin_l, fin_u = best_ret\n            if fin_l is not None:\n                new_fin_l = full_ret_l\n                new_fin_l[self.preserve_mask] = fin_l[self.preserve_mask]\n                fin_l = new_fin_l\n            if fin_u is not None:\n                new_fin_u = full_ret_u\n                new_fin_u[self.preserve_mask] = fin_u[self.preserve_mask]\n                fin_u = new_fin_u\n            best_ret = (fin_l, fin_u)\n        return best_ret\n\n    def update_ratio(self, full_l, full_ret_l):\n        if 
self.decision_thresh is not None and full_l.numel() > 0:\n            stime = time.time()\n            with torch.no_grad():\n                if isinstance(self.decision_thresh, torch.Tensor):\n                    if self.decision_thresh.shape[-1] == 1:\n                        neg_domain_num = torch.sum(\n                            full_ret_l.view(-1) <= self.decision_thresh.view(-1)\n                        ).item()\n                    else:\n                        neg_domain_num = torch.sum(self.multi_spec_keep_func(\n                            full_ret_l <= self.decision_thresh)).item()\n                else:\n                    if full_l.shape[-1] == 1:\n                        neg_domain_num = torch.sum(\n                            full_ret_l.view(-1) <= self.decision_thresh).item()\n                    else:\n                        neg_domain_num = torch.sum(self.multi_spec_keep_func(\n                            full_ret_l <= self.decision_thresh)).item()\n                now_pruning_ratio = (\n                    1.0 - float(neg_domain_num) / float(full_l.shape[0]))\n                print('pruning_in_iteration open status:',\n                      self.pruning_in_iteration)\n                print('ratio of positive domain =',\n                    full_l.shape[0] - neg_domain_num,\n                    '/', full_l.numel(), '=', now_pruning_ratio)\n            self.time += time.time() - stime\n            print('pruning-in-iteration extra time:', self.time)\n\n    @torch.no_grad()\n    def _prune_x(self, x):\n        \"\"\"\n        Prune x by given now_preserve_mask.\n        \"\"\"\n        x = list(x)\n        pre_prune_size = x[0].shape[0]\n        x[0].data = x[0][self.now_preserve_mask].data\n        if hasattr(x[0], 'ptb'):\n            if x[0].ptb.x_L is not None:\n                x[0].ptb.x_L = x[0].ptb.x_L[self.now_preserve_mask]\n            if x[0].ptb.x_U is not None:\n                x[0].ptb.x_U = x[0].ptb.x_U[self.now_preserve_mask]\n   
     x = tuple(x)\n\n        return x, pre_prune_size\n\n    def _prune_dict_of_lists(self, dict_of_lists, pre_prune_size):\n        if dict_of_lists is not None:\n            for k, v in dict_of_lists.items():\n                v_l, v_r = v[0], v[1]\n                if v_l.shape[0] == pre_prune_size:\n                    # the first dim is batch size and matches the preserve mask\n                    v_l = v_l[self.now_preserve_mask]\n                if v_r.shape[0] == pre_prune_size:\n                    # the first dim is batch size and matches the preserve mask\n                    v_r = v_r[self.now_preserve_mask]\n                dict_of_lists[k] = [v_l, v_r]\n\n    @torch.no_grad()\n    def _prune_bounds_by_mask(self, ret_l, ret_u, ret, interm_bounds,\n                              aux_reference_bounds, reference_bounds, pre_prune_size):\n        \"\"\"\n        Prune bounds by given now_preserve_mask.\n        \"\"\"\n        full_ret_l, full_l = self._recover_bounds_to_full_batch(ret_l)\n        full_ret_u, full_u = self._recover_bounds_to_full_batch(ret_u)\n\n        full_ret = (full_ret_l, full_ret_u) + ret[2:]\n\n        if self.fix_interm_bounds:\n            interval_to_prune = interm_bounds\n        else:\n            interval_to_prune = None\n\n        self._prune_dict_of_lists(interval_to_prune, pre_prune_size)\n        self._prune_dict_of_lists(aux_reference_bounds, pre_prune_size)\n        self._prune_dict_of_lists(reference_bounds, pre_prune_size)\n\n        # update the global mask here for possible next iteration\n        self.preserve_mask_next = self.preserve_mask[self.now_preserve_mask]\n\n        return full_l, full_ret_l, full_ret_u, full_ret\n\n    @torch.no_grad()\n    def get_preserve_mask(self, ret_l):\n        \"\"\"\n        Get preserve mask by decision_thresh to filter out the satisfied bounds.\n        \"\"\"\n        if (isinstance(self.decision_thresh, torch.Tensor)\n                and self.decision_thresh.numel() > 1):\n       
     if self.decision_thresh.shape[-1] == 1:\n                self.now_preserve_mask = (\n                    ret_l <= self.decision_thresh[self.preserve_mask]\n                ).view(-1).nonzero().view(-1)\n            else:\n                self.now_preserve_mask = self.multi_spec_keep_func(\n                    ret_l <= self.decision_thresh[self.preserve_mask]\n                ).nonzero().view(-1)\n        else:\n            if self.decision_thresh.shape[-1] == 1:\n                self.now_preserve_mask = (\n                    ret_l <= self.decision_thresh).view(-1).nonzero().view(-1)\n            else:\n                self.now_preserve_mask = self.multi_spec_keep_func(\n                    ret_l <= self.decision_thresh).nonzero().view(-1)\n\n    def _recover_bounds_to_full_batch(self, ret):\n        \"\"\"\n        Recover lower and upper bounds to full batch size so that later we can\n        directly update using the full batch size of l and u.\n        \"\"\"\n        if ret is not None:\n            if (isinstance(self.decision_thresh, torch.Tensor)\n                    and self.decision_thresh.numel() > 1):\n                full_ret = (\n                    self.decision_thresh.clone().to(ret.device).type(ret.dtype)\n                    + self.epsilon_over_decision_thresh)\n            else:\n                num_decision_thresh = self.decision_thresh\n                if isinstance(num_decision_thresh, torch.Tensor):\n                    num_decision_thresh = num_decision_thresh.item()\n                full_ret = torch.full(\n                    (self.original_size,) + tuple(ret.shape[1:]),\n                    fill_value=(num_decision_thresh\n                                + self.epsilon_over_decision_thresh),\n                    device=ret.device, dtype=ret.dtype)\n            full_ret[self.preserve_mask] = ret\n            if full_ret.shape[1] > 1:\n                full_reduced_ret = self.loss_reduction_func(full_ret)\n            else:\n             
   full_reduced_ret = full_ret\n        else:\n            full_ret = full_reduced_ret = None\n\n        return full_ret, full_reduced_ret\n\n    def cache_full_sized_alpha(self, optimizable_activations: list):\n        \"\"\"\n        When preserve mask is in use, cache the full-sized alphas in self.cached_alphas,\n        and rewrite the alphas in nodes according to the preserve mask.\n        The full-sized alphas will be recovered back to nodes after compute_bounds,\n        via the function named recover_full_sized_alpha()\n        :param optimizable_activations: list of nodes that may have slope alphas as optimizable variables\n        :return: None\n        \"\"\"\n        if self.pruning_in_iteration:\n            for act in optimizable_activations:\n                if act.name in self.cached_alphas:\n                    self.cached_alphas[act.name].clear()\n                self.cached_alphas[act.name] = {}\n                if act.alpha is not None:\n                    for start_node in act.alpha:\n                        # cached alphas and alphas stored in nodes should share the same memory space\n                        self.cached_alphas[act.name][start_node] = act.alpha[start_node]\n                        act.alpha[start_node] = act.alpha[start_node][:, :, self.preserve_mask]\n\n    def recover_full_sized_alpha(self, optimizable_activations: list):\n        \"\"\"\n        After bound computation, recover the full-sized alphas back to nodes.\n        :param optimizable_activations: list of nodes that may have slope alphas as optimizable variables\n        :return: None\n        \"\"\"\n        if self.pruning_in_iteration:\n            for act in optimizable_activations:\n                for start_node in self.cached_alphas[act.name]:\n                    act.alpha[start_node] = self.cached_alphas[act.name][start_node]\n\n    def clean_full_sized_alpha_cache(self):\n        for act_node in self.cached_alphas:\n            
self.cached_alphas[act_node].clear()\n        self.cached_alphas.clear()\n"
  },
  {
    "path": "auto_LiRPA/optimize_graph.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\"\"\"Optimize the graph to merge nodes and remove unnecessary ones.\n\nInitial and experimental code only.\n\"\"\"\n\nfrom auto_LiRPA.bound_ops import *\nfrom auto_LiRPA.utils import logger\nimport torch\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef _optimize_graph(self: 'BoundedModule'):\n    \"\"\"Optimize the graph to remove some unnecessary nodes.\"\"\"\n    merge_identical_act(self)\n    convert_sqr(self)\n    div_to_mul(self)\n    merge_sec(self)\n    minmax_to_relu(self)\n    optimize_relu_relation(self)\n\n    if self.bound_opts['optimize_graph']['optimizer'] is not None:\n        # Use the custom graph optimizer\n        self.bound_opts['optimize_graph']['optimizer'](self)\n\n    for node in list(self.nodes()):\n        if (not node.output_name\n                and node.name != self.final_name\n                and node.name not in self.root_names):\n            self.delete_node(node)\n\n\ndef _copy_node_properties(new, ref):\n    new.output_shape = ref.output_shape\n    new.device = ref.device\n    new.attr['device'] = ref.attr['device']\n    new.batch_dim = ref.batch_dim\n    new.from_complex_node = ref.from_complex_node\n\n\ndef merge_sec(model: 'BoundedModule'):\n    nodes = list(model.nodes())\n    for node in nodes:\n        if type(node) == BoundReciprocal and type(node.inputs[0]) == BoundCos:\n            node_new = BoundSec(inputs=[node.inputs[0].inputs[0]])\n            node_new.name = f'{node.inputs[0].name}/sec'\n            _copy_node_properties(node_new, node)\n            if node_new.name in model._modules:\n                node_existing = model._modules[node_new.name]\n                assert isinstance(node_existing, BoundSec)\n                assert node_existing.inputs[0] == node.inputs[0].inputs[0]\n                
model.replace_node(node, node_existing)\n            else:\n                model.add_nodes([node_new])\n                model.replace_node(node, node_new)\n\n\ndef div_to_mul(model: 'BoundedModule'):\n    nodes = list(model.nodes())\n    for node in nodes:\n        if type(node) == BoundDiv:\n            logger.debug('Replacing BoundDiv node: %s', node)\n            node_reciprocal = BoundReciprocal(inputs=[node.inputs[1]])\n            node_reciprocal.name = f'{node.name}/reciprocal'\n            # Properties of the reciprocal node only depend on inputs[1], i.e.\n            # the denominator node. They can be different from those of\n            # the original BoundDiv node, due to possible broadcasting and\n            # perturbed/unperturbed switching in multiplication.\n            _copy_node_properties(node_reciprocal, node.inputs[1])\n            model.add_nodes([node_reciprocal])\n            node_mul = BoundMul(inputs=[node.inputs[0], node_reciprocal],\n                                options=model.bound_opts)\n            node_mul.name = f'{node.name}/mul'\n            _copy_node_properties(node_mul, node)\n            model.add_nodes([node_mul])\n            model.replace_node(node, node_mul)\n\n\ndef convert_sqr(model: 'BoundedModule'):\n    \"\"\"Replace BoundMul or BoundPow with BoundSqr if applicable.\n\n    1. If the two input nodes of a BoundMul node are the same, use BoundSqr.\n    2. 
Pow(x, 2) can be replaced with BoundSqr.\n    \"\"\"\n    nodes = list(model.nodes())\n    for node in nodes:\n        replace = False\n        if type(node) == BoundMul and node.inputs[0] == node.inputs[1]:\n            replace = True\n        elif type(node) == BoundPow:\n            if ((isinstance(node.inputs[1], BoundBuffers) and node.inputs[1].buffer == 2) or\n                (isinstance(node.inputs[1], BoundConstant) and node.inputs[1].value == 2)):\n                replace = True\n        if replace:\n            node_new = BoundSqr(inputs=[node.inputs[0]])\n            node_new.name = f'{node.name}/sqr'\n            _copy_node_properties(node_new, node)\n            model.add_nodes([node_new])\n            logger.debug('Replacing %s with %s', node, node_new)\n            model.replace_node(node, node_new)\n\n\ndef merge_identical_act(model: 'BoundedModule'):\n    \"\"\"Merge identical BoundActivation nodes that share the same input.\"\"\"\n    nodes = list(model.nodes())\n    merged = [False] * len(nodes)\n    for i in range(len(nodes)):\n        if (not merged[i]\n                and isinstance(nodes[i], BoundActivation)\n                and len(nodes[i].inputs) == 1):\n            for j in range(i + 1, len(nodes)):\n                if (not merged[j]\n                        and type(nodes[j]) == type(nodes[i])\n                        and len(nodes[j].inputs) == 1):\n                    if nodes[i].inputs[0] == nodes[j].inputs[0]:\n                        logger.debug('Merging node %s to %s', nodes[j], nodes[i])\n                        model.replace_node(nodes[j], nodes[i])\n                        merged[j] = True\n\n\ndef minmax_to_relu(model: 'BoundedModule'):\n    \"\"\"Replace BoundMax/BoundMin with BoundRelu if one of the inputs is constant.\"\"\"\n    nodes = list(model.nodes())\n    for node in nodes:\n        if type(node) == BoundMax:\n            for i, input_node in enumerate(node.inputs):\n                if not input_node.perturbed:\n                    logger.debug('Replacing 
BoundMax node %s', node)\n                    # max(x, c) = ReLU(x - c) + c\n                    node_sub = BoundSub(inputs=[node.inputs[1-i], input_node],\n                                        options=model.bound_opts)\n                    node_sub.name = f'{node.name}/sub'\n                    _copy_node_properties(node_sub, node)\n                    node_relu = BoundRelu(inputs=[node_sub],\n                                          options=model.bound_opts)\n                    node_relu.name = f'{node.name}/relu'\n                    _copy_node_properties(node_relu, node)\n                    node_add = BoundAdd(inputs=[node_relu, input_node],\n                                        options=model.bound_opts)\n                    node_add.name = f'{node.name}/add'\n                    _copy_node_properties(node_add, node)\n                    model.add_nodes([node_sub, node_relu, node_add])\n                    model.replace_node(node, node_add)\n                    break\n        elif type(node) == BoundMin:\n            for i, input_node in enumerate(node.inputs):\n                if not input_node.perturbed:\n                    logger.debug('Replacing BoundMin node %s', node)\n                    # min(x, c) = -ReLU(c - x) + c\n                    node_sub_1 = BoundSub(inputs=[input_node, node.inputs[1-i]],\n                                          options=model.bound_opts)\n                    node_sub_1.name = f'{node.name}/sub/1'\n                    _copy_node_properties(node_sub_1, node)\n                    node_relu = BoundRelu(inputs=[node_sub_1],\n                                          options=model.bound_opts)\n                    node_relu.name = f'{node.name}/relu'\n                    _copy_node_properties(node_relu, node)\n                    node_sub_2 = BoundSub(inputs=[input_node, node_relu],\n                                          options=model.bound_opts)\n                    node_sub_2.name = f'{node.name}/sub/2'\n             
       _copy_node_properties(node_sub_2, node)\n                    model.add_nodes([node_sub_1, node_relu, node_sub_2])\n                    model.replace_node(node, node_sub_2)\n                    break\n\n\ndef _pair_row(Ws, bs, Wm, j, atol=1e-8):\n    \"\"\"\n    Check whether rows j and j+1 fit the relation ReLU(x) - ReLU(-x) = x.\n    Return the row index in the merge weight if the relation holds,\n    otherwise return None.\n    \"\"\"\n    # Check whether this fits the pattern in the docstring.\n    if not (torch.allclose(Ws[j+1], -Ws[j], atol=atol)\n            and abs(float(bs[j] + bs[j+1])) < atol):\n        return None\n\n    # Make the merge weight 4D so Gemm and Conv share the same indexing\n    if Wm.dim() == 2:                 # Gemm path\n        Wm4 = Wm.unsqueeze(-1).unsqueeze(-1)\n    else:                             # Conv path\n        Wm4 = Wm\n\n    # Find the corresponding columns of the merge weight.\n    # We check 1) The two nonzero elements are in the same row\n    #          2) The two entries are +1 and -1\n    # If the checks pass, we return the row index; otherwise it\n    # is not a valid pattern match and we return None.\n    rows = torch.nonzero(Wm4[:, [j, j+1], 0, 0], as_tuple=False)\n    if rows.size(0) != 2 or rows[0, 0] != rows[1, 0]:\n        return None\n    r = int(rows[0, 0])\n\n    ok = (abs(float(Wm4[r, j, 0, 0] - 1)) < atol and\n          abs(float(Wm4[r, j+1, 0, 0] + 1)) < atol and\n          torch.count_nonzero(Wm4[r]) == 2)\n    return r if ok else None\n\n\ndef optimize_relu_relation(model: 'BoundedModule'):\n    \"\"\"\n    This graph optimization detects an optimizable path with\n    the identity\n        ReLU(ReLU(x + b) - ReLU(-x - b)) = ReLU(x + b)\n    for both linear and convolution layers. 
Replace a\n    sequence of nodes matching the pattern\n        Gemm -> ReLU -> Gemm -> ReLU or\n        Conv -> ReLU -> Conv -> ReLU\n    with a single Gemm -> ReLU or Conv -> ReLU.\n    \"\"\"\n    nodes = list(model.nodes())\n    i = 0\n    while i + 3 < len(nodes):\n        A, B, C, D = nodes[i:i+4]\n\n        # In Conv layers, we detect whether the optimization can be done\n        # for pairs of channels. If so, the optimization eliminates one\n        # Conv layer and recovers the original results with the identity\n        # in the docstring.\n        if (isinstance(A, BoundConv) and isinstance(B, BoundRelu) and\n            isinstance(C, BoundConv) and isinstance(D, BoundRelu) and tuple(C.attr['kernel_shape']) == (1, 1)):\n\n            # Use forward() to extract the weights, so that BoundParams, BoundConstant, or any other node\n            # that could represent weights is handled via a unified interface.\n            Ws = C.inputs[1].forward()\n            Wc = A.inputs[1].forward()\n\n            # We only care about 2D conv\n            if Ws.ndim != 4 or Wc.ndim != 4:\n                i += 1\n                continue\n\n            bs = C.inputs[2].forward() if C.has_bias else torch.zeros_like(Ws[:, 0, 0, 0])\n            bc = A.inputs[2].forward() if A.has_bias else torch.zeros_like(Wc[:, 0, 0, 0])\n\n            # Detect whether and where the identity is present in the weight matrix.\n            pairs, skip = {}, set()\n            for j in range(0, Wc.size(0) - 1):\n                r = _pair_row(Wc, bc, Ws, j)\n                if r is not None:\n                    pairs[j] = r\n                    skip.add(j + 1)\n\n            if pairs:\n                Cout, Cin, kH, kW = Ws.size(0), Wc.size(1), *Wc.shape[2:]\n                W_new = torch.empty((Cout, Cin, kH, kW), dtype=Wc.dtype, device=Wc.device)\n                b_new = torch.empty((Cout,), dtype=bc.dtype, device=bc.device)\n\n        
        # Build the fused weight and bias\n                dst = 0\n                for src in range(Wc.size(0)):\n                    if src in skip:\n                        continue\n                    b_new[dst] = bs[pairs[src]] + bc[src] if src in pairs else bc[src]\n                    W_new[dst] = Wc[src]\n                    dst += 1\n\n                # Modify the graph using the newly built weight and bias\n                weight_node = BoundParams('fused_weight', torch.nn.Parameter(W_new))\n                bias_node = BoundParams('fused_bias', torch.nn.Parameter(b_new))\n                weight_node.name = f'{A.name}/optimized/weight'\n                bias_node.name = f'{A.name}/optimized/bias'\n\n                fused = BoundConv(\n                    attr=A.attr.copy(),\n                    inputs=[A.inputs[0], weight_node, bias_node],\n                    output_index=A.output_index,\n                    options=model.bound_opts\n                )\n                fused.name = f'{A.name}/optimized'\n                _copy_node_properties(fused, A)\n                relu = BoundRelu(inputs=[fused], options=model.bound_opts)\n                relu.name = f'{A.name}/optimized/relu'\n                _copy_node_properties(relu, D)\n\n                model.add_nodes([weight_node, bias_node, fused, relu])\n                model.replace_node(D, relu)\n                model.replace_node(A, fused)\n                model.delete_node(B)\n                model.delete_node(C)\n\n                # Skip the full sequence once the pattern is detected\n                i += 4\n                continue\n\n        # In Linear layers, we detect whether the optimization can be\n        # done for pairs of rows. The code structure is similar to the\n        # one in the Conv branch.
\n        elif (isinstance(A, BoundLinear) and isinstance(B, BoundRelu) and\n            isinstance(C, BoundLinear) and isinstance(D, BoundRelu)):\n            \n            Ws = A.inputs[1].forward()\n            Wm = C.inputs[1].forward()\n            bs = A.inputs[2].forward() if len(A.inputs) == 3 else torch.zeros_like(Ws[:, 0])\n            bm = C.inputs[2].forward() if len(C.inputs) == 3 else torch.zeros_like(Wm[:, 0])\n            \n            pairs, skip = {}, set()\n            for j in range(0, Ws.size(0) - 1):\n                r = _pair_row(Ws, bs, Wm, j)\n                if r is not None:\n                    pairs[j] = r\n                    skip.add(j + 1)\n                 \n            if pairs:\n                n_out = Wm.shape[0]\n                W_new = torch.empty((n_out, Ws.shape[1]), dtype=Ws.dtype, device=A.attr['device'])\n                b_new = torch.empty((n_out,), dtype=bs.dtype, device=A.attr['device'])\n\n                dst = 0\n                for src in range(Ws.size(0)):\n                    if src in skip:\n                        continue\n                    b_new[dst] = bm[pairs[src]] + bs[src] if src in pairs else bs[src]\n                    W_new[dst] = Ws[src]\n                    dst += 1\n                \n                weight_node = BoundParams('fused_weight', torch.nn.Parameter(W_new), attr=dict(device=A.attr['device']))\n                bias_node = BoundParams('fused_bias', torch.nn.Parameter(b_new), attr=dict(device=A.attr['device']))\n                weight_node.name = f'{A.name}/optimized/weight'\n                bias_node.name = f'{A.name}/optimized/bias'\n                \n                fused = BoundLinear(\n                    attr=A.attr.copy(),\n                    inputs=[A.inputs[0], weight_node, bias_node],\n                    output_index=A.output_index,\n                    options=model.bound_opts\n                )\n                fused.name = f'{A.name}/optimized'\n                
_copy_node_properties(fused, A)\n                relu = BoundRelu(inputs=[fused], options=model.bound_opts)\n                relu.name = f'{A.name}/optimized/relu'\n                _copy_node_properties(relu, D)\n                \n                model.add_nodes([weight_node, bias_node, fused, relu])\n                model.replace_node(D, relu)\n                model.delete_node(A)\n                model.delete_node(B)\n                model.delete_node(C)\n                \n                i += 4\n                continue\n        i += 1\n"
  },
  {
    "path": "auto_LiRPA/optimized_bounds.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport time\nimport os\nfrom collections import OrderedDict\nfrom contextlib import ExitStack\n\nimport torch\nfrom torch import optim, Tensor\nfrom .beta_crown import print_optimized_beta\nfrom .cuda_utils import double2float\nfrom .utils import reduction_sum, multi_spec_keep_func_all, clone_sub_A_dict\nfrom .opt_pruner import OptPruner\nfrom .perturbations import PerturbationLpNorm\n\nfrom typing import TYPE_CHECKING, Union, Tuple, Optional, Dict\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndefault_optimize_bound_args = {\n    'enable_alpha_crown': True,  # Enable optimization of alpha.\n    'enable_beta_crown': False,  # Enable beta split constraint.\n\n    'apply_output_constraints_to': [],  # Enable optimization w.r.t. 
output constraints.\n    'tighten_input_bounds': False,  # Don't tighten input bounds\n    # If output constraints are activated, use only bounds computed with them.\n    'best_of_oc_and_no_oc': False,\n    'directly_optimize': [],  # No layer should be directly optimized\n    'oc_lr': 0.1,  # learning rate for dualized output constraints\n    'share_gammas': False,\n\n    'iteration': 20,  # Number of alpha/beta optimization iterations.\n    # Share some alpha variables to save memory at the cost of slightly\n    # looser bounds.\n    'use_shared_alpha': False,\n    # Optimizer used for alpha and beta optimization.\n    'optimizer': 'adam',\n    # Save best results of alpha/beta/bounds during optimization.\n    'keep_best': True,\n    # Only optimize bounds of last layer during alpha/beta CROWN.\n    'fix_interm_bounds': True,\n    # Learning rate for the optimizable parameter alpha in alpha-CROWN.\n    'lr_alpha': 0.5,\n    # Learning rate for the optimizable parameter beta in beta-CROWN.\n    'lr_beta': 0.05,\n    'lr_cut_beta': 5e-3,  # Learning rate for optimizing cut betas.\n    # Initial alpha variables by calling CROWN once.\n    'init_alpha': True,\n    'lr_coeffs': 0.01,  # Learning rate for coeffs for refinement\n    # Layers to be refined, separated by commas.\n    # -1 means preactivation before last activation.\n    'intermediate_refinement_layers': [-1],\n    # When batch size is not 1, this reduction function is applied to\n    # reduce the bounds into a scalar.\n    'loss_reduction_func': reduction_sum,\n    # Criteria function of early stop.\n    'stop_criterion_func': lambda x: False,\n    # Learning rate decay factor during bounds optimization.\n    'lr_decay': 0.98,\n    # Number of iterations that we will start considering early stop\n    # if tracking no improvement.\n    'early_stop_patience': 10,\n    # Start to save optimized best bounds\n    # when current_iteration > int(iteration*start_save_best)\n    'start_save_best': 0.5,\n    # Use 
double fp (float64) at the last iteration in alpha/beta CROWN.\n    'use_float64_in_last_iteration': False,\n    # Prune verified domains within iterations.\n    'pruning_in_iteration': False,\n    # Percentage of the minimum domains that can apply pruning.\n    'pruning_in_iteration_threshold': 0.2,\n    # For specifications that output multiple bounds for one\n    # property, we use this function to prune them.\n    'multi_spec_keep_func': multi_spec_keep_func_all,\n    # Use the newly fixed loss function, which tries to ensure that the\n    # parameters always match with the optimized bounds. By default, it is\n    # set to False for compatibility with existing use cases.\n    'deterministic': False,\n    'max_time': 1e9,\n}\n\n\ndef opt_reuse(self: 'BoundedModule'):\n    for node in self.get_enabled_opt_act():\n        node.opt_reuse()\n\n\ndef opt_no_reuse(self: 'BoundedModule'):\n    for node in self.get_enabled_opt_act():\n        node.opt_no_reuse()\n\n\ndef _set_alpha(optimizable_activations, parameters, alphas, lr):\n    \"\"\"Set best_alphas, alphas and parameters list.\"\"\"\n    for node in optimizable_activations:\n        alphas.extend(list(node.alpha.values()))\n        node.opt_start()\n    # Alpha has shape (2, output_shape, batch_dim, node_shape)\n    parameters.append({'params': alphas, 'lr': lr, 'batch_dim': 2})\n    # best_alpha is a dictionary of dictionaries. 
Each key is the alpha variable\n    # for one activation layer, and each value is a dictionary that contains all\n    # activation layers after that layer as keys.\n    best_alphas = OrderedDict()\n    for m in optimizable_activations:\n        best_alphas[m.name] = {}\n        for alpha_m in m.alpha:\n            best_alphas[m.name][alpha_m] = m.alpha[alpha_m].detach().clone()\n            # We will directly replace the dictionary for each activation layer after\n            # optimization, so the saved alpha might not have requires_grad=True.\n            m.alpha[alpha_m].requires_grad_()\n\n    return best_alphas\n\n\ndef _set_gammas(nodes, parameters):\n    \"\"\"\n    Add gammas to the parameters list.\n    \"\"\"\n    gammas = []\n    gamma_lr = 0.1\n    for node in nodes:\n        if hasattr(node, 'gammas'):\n            gammas.append(node.gammas_underlying_tensor)\n            # The learning rate is the same for all layers\n            gamma_lr = node.options['optimize_bound_args']['oc_lr']\n    parameters.append({'params': gammas, 'lr': gamma_lr})\n\n\ndef _save_ret_first_time(bounds, best_ret):\n    \"\"\"Save results at the first iteration to best_ret.\"\"\"\n    if bounds is not None:\n        best_ret.append(bounds.detach().clone())\n    else:\n        best_ret.append(None)\n\n\ndef _to_float64(self: 'BoundedModule', C, x, aux_reference_bounds, interm_bounds):\n    \"\"\"\n    Transfer variables to float64 only in the last iteration to help alleviate\n    floating point error.\n    \"\"\"\n    self.to(torch.float64)\n    C = C.to(torch.float64)\n    x = self._to(x, torch.float64)\n    # best_intermediate_bounds is linked to aux_reference_bounds!\n    # We only need to call .to() for one of them.\n    self._to(aux_reference_bounds, torch.float64, inplace=True)\n    interm_bounds = self._to(\n        interm_bounds, torch.float64)\n\n    return C, x, interm_bounds\n\n\ndef _to_default_dtype(self: 'BoundedModule', x, total_loss, full_ret, ret,\n                      
best_intermediate_bounds, return_A):\n    \"\"\"\n    Switch back to default precision from float64 typically to adapt to\n    afterwards operations.\n    \"\"\"\n    total_loss = total_loss.to(torch.get_default_dtype())\n    self.to(torch.get_default_dtype())\n    x[0].to(torch.get_default_dtype())\n    full_ret = list(full_ret)\n    if isinstance(ret[0], torch.Tensor):\n        # round down lower bound\n        full_ret[0] = double2float(full_ret[0], 'down')\n    if isinstance(ret[1], torch.Tensor):\n        # round up upper bound\n        full_ret[1] = double2float(full_ret[1], 'up')\n    for _k, _v in best_intermediate_bounds.items():\n        _v[0] = double2float(_v[0], 'down')\n        _v[1] = double2float(_v[1], 'up')\n        best_intermediate_bounds[_k] = _v\n    if return_A:\n        full_ret[2] = self._to(full_ret[2], torch.get_default_dtype())\n\n    return total_loss, x, full_ret\n\n\ndef _get_idx_mask(idx: int, full_ret_bound: Tensor, best_ret_bound: Tensor, loss_reduction_func\n                  ) -> Tuple[Tensor, Optional[Tensor]]:\n    \"\"\"\n    Get index for improved elements.\n    :param idx:                 0 := updating the lower bound, 1 := updating the upper bound\n    :param full_ret_bound:      Lower/upper bound results for this iteration\n    :param best_ret_bound:      The best lower/upper bound results seen thus far\n    :param loss_reduction_func: Loss reduction function that reduces the losses to just the batch\n                                dimension.\n    :return:\n            idx_mask:           A mask on the batch dimension where the mask is true if a\n                                sub-problem has seen loss improvement.\n            improved_idx:       A Tensor of the indices in the batch dimension that have seen loss\n                                improvement.\n    \"\"\"\n    assert idx in (0, 1), 'idx must be 0 (lower bound) or 1 (upper bound)'\n    reduced_full = loss_reduction_func(full_ret_bound)\n    reduced_best = 
loss_reduction_func(best_ret_bound)\n    idx_mask = (reduced_full > reduced_best) if idx == 0 else (reduced_full < reduced_best)\n    idx_mask = idx_mask.view(-1)\n\n    improved_idx = idx_mask.nonzero(as_tuple=True)[0] if idx_mask.any() else None\n    return idx_mask, improved_idx\n\n\ndef _update_best_ret(\n    full_ret: Dict[str, Dict[str, Dict[str, Union[Tensor, 'Patches', Tuple]]]],\n    best_ret: Dict[str, Dict[str, Dict[str, Union[Tensor, 'Patches', Tuple]]]],\n    loss_reduction_func,\n    idx: int,\n    deterministic: bool = False,\n    best_out_in_A_dict: Optional[Dict[str, Union[Tensor, 'Patches', Tuple]]] = None,\n    out_in_keys: Optional[Tuple[str, str]] = None\n):\n    \"\"\"\n    Update best_ret_bound and best_ret by comparing with new results.\n    :param full_ret:                The full return from the 'compute_bounds' method.\n    :param best_ret:                The best return during optimization in the same format as\n                                    'full_ret'\n    :param loss_reduction_func:     Loss reduction function that reduces the losses to just the\n                                    batch dimension.\n    :param idx:                     0 := updating the lower bound, 1 := updating the upper bound\n    :param deterministic:           If true, problems that have seen loss improvement will have\n                                    their bounds directly saved as the new best bound. Otherwise,\n                                    the current bounds will be compared to the current best bounds\n                                    and the comparison result is saved as the new best bound. 
In\n                                    other words, deterministic is true if an improvement in the\n                                    loss function is a sufficient condition for bound improvement.\n    :param best_out_in_A_dict:      If given, this is the A_dict entry corresponding to the output\n    :param out_in_keys:             If given, this is a tuple whose first element is the first index\n                                    into the A_dict and whose second element is the second index\n                                    into the A_dict. In particular, the first element should be the\n                                    name of the output layer of the network, and the second\n                                    element should be the name of the input layer. If these indices\n                                    are not given correctly, an indexing error will be thrown. If\n                                    given, it is assumed that we should use these keys to update\n                                    lA/uA/lbias/ubias depending on if the bounds have improved.\n                                    Therefore, we must assert that 'full_ret' and 'best_ret' contain\n                                    an A_dict.\n    :return:\n            best_ret:\n            best_out_in_A_dict:     An updated A_dict entry corresponding to the output/input layer\n            need_update:            Set to True in this method if at least one sub-problem has seen\n                                    bound improvement.\n            idx_mask:               A mask on the batch dimension where the mask is true if a\n                                    sub-problem has seen loss improvement.\n            improved_idx:           A Tensor of the indices in the batch dimension that have seen\n                                    loss improvement.\n    \"\"\"\n    assert idx in (0, 1), 'idx must be 0 (lower bound) or 1 (upper bound)'\n\n    idx_mask, improved_idx = _get_idx_mask(idx, 
full_ret[idx], best_ret[idx], loss_reduction_func)\n    if improved_idx is None:\n        return best_ret, best_out_in_A_dict, False, idx_mask, None\n\n    compare_fn = torch.max if idx == 0 else torch.min\n    # Update detailed return tensors (if present)\n    if full_ret[idx] is not None:\n        if deterministic:\n            best_ret[idx][improved_idx] = full_ret[idx][improved_idx]\n            if out_in_keys is not None:\n                _update_A_dict(\n                    best_out_in_A_dict,\n                    full_ret[2][out_in_keys[0]][out_in_keys[1]],\n                    improved_idx\n                )\n        else:\n            if out_in_keys is not None:\n                # Since we must also update the A_dict, we don't want to use the original\n                # 'compare' method as we need to know which specific problems have\n                # seen improvement.\n                cmp_op = (lambda x, y: (x > y)) if idx==0 else (lambda x, y: (x < y))\n                c_mask = cmp_op(full_ret[idx][improved_idx], best_ret[idx][improved_idx])\n                best_ret[idx][improved_idx] = torch.where(\n                    c_mask, full_ret[idx][improved_idx], best_ret[idx][improved_idx])\n                # Also update the lA/uA/lbias/ubias matrices/vectors from the output layer to\n                # the input layer if the bounds have improved and if the output and input layer\n                # keys were specified\n                _update_A_dict(\n                    best_out_in_A_dict,\n                    full_ret[2][out_in_keys[0]][out_in_keys[1]],\n                    improved_idx, c_mask\n                )\n            else:\n                # Simple tensor-wise comparison (no A_dict)\n                best_ret[idx][improved_idx] = compare_fn(\n                    full_ret[idx][improved_idx],\n                    best_ret[idx][improved_idx])\n\n    return best_ret, best_out_in_A_dict, True, idx_mask, improved_idx\n\n\ndef _update_A_dict(best_A, 
full_A, improved_idx, c_mask: Optional[Tensor] = None):\n    \"\"\"\n    Update best_A dict by full_A for entries at improved_idx.\n    :param best_A:         The A_dict entry to be updated.\n    :param full_A:         The A_dict entry containing the new values.\n    :param improved_idx:   The indices in the batch dimension that have seen bound improvement.\n    :param c_mask:         A mask on the batch dimension where the mask is true if a\n                            sub-problem has seen bound improvement. If None, then the entire\n                            slice at improved_idx will be replaced.\n    \"\"\"\n    for key, val in full_A.items():\n        if val is None:\n            # An entry for lA/uA/lbias/ubias may be None depending on if we are\n            # lower or upper bounding the network\n            continue\n        target = best_A[key][improved_idx]\n        source = val[improved_idx]\n        if c_mask is not None:\n            c_mask_expanded = c_mask.view(\n                *c_mask.shape,\n                *([1] * (val.dim() - c_mask.dim()))\n            ).expand_as(val[improved_idx])\n            # Selectively update entries based on c_mask\n            best_A[key][improved_idx] = torch.where(c_mask_expanded, source, target)\n        else:\n            # Replace the entire slice if no mask is provided\n            best_A[key][improved_idx] = source\n\n\ndef _update_optimizable_activations(\n        optimizable_activations, interm_bounds,\n        fix_interm_bounds, best_intermediate_bounds,\n        reference_idx, idx, alpha, best_alphas, deterministic):\n    \"\"\"\n    Update bounds and alpha of optimizable_activations.\n    \"\"\"\n    for node in optimizable_activations:\n        # Update best intermediate layer bounds only when they are optimized.\n        # If they are already fixed in interm_bounds, then do\n        # nothing.\n        if node.name not in best_intermediate_bounds:\n            continue\n        if (interm_bounds is 
None\n                or node.inputs[0].name not in interm_bounds\n                or not fix_interm_bounds):\n            if deterministic:\n                best_intermediate_bounds[node.name][0][idx] = node.inputs[0].lower[reference_idx]\n                best_intermediate_bounds[node.name][1][idx] = node.inputs[0].upper[reference_idx]\n            else:\n                best_intermediate_bounds[node.name][0][idx] = torch.max(\n                    best_intermediate_bounds[node.name][0][idx],\n                    node.inputs[0].lower[reference_idx])\n                best_intermediate_bounds[node.name][1][idx] = torch.min(\n                    best_intermediate_bounds[node.name][1][idx],\n                    node.inputs[0].upper[reference_idx])\n        if alpha:\n            # Each alpha has shape (2, output_shape, batch, *shape) for act.\n            # For other activation functions this can be different.\n            for alpha_m in node.alpha:\n                best_alphas[node.name][alpha_m][:, :,\n                    idx] = node.alpha[alpha_m][:, :, idx]\n\n\ndef update_best_beta(self: 'BoundedModule', enable_opt_interm_bounds, betas,\n                     best_betas, idx):\n    \"\"\"\n    Update best beta by given idx.\n    \"\"\"\n    if enable_opt_interm_bounds and betas:\n        for node in self.splittable_activations:\n            for node_input in node.inputs:\n                for key in node_input.sparse_betas.keys():\n                    best_betas[node_input.name][key] = (\n                        node_input.sparse_betas[key].val.detach().clone())\n        if self.cut_used:\n            for gbidx, general_betas in enumerate(self.cut_beta_params):\n                # FIXME need to check if 'cut' is a node name\n                best_betas['cut'][gbidx] = general_betas.detach().clone()\n    else:\n        for node in self.nodes_with_beta:\n            best_betas[node.name][idx] = node.sparse_betas[0].val[idx]\n        if self.cut_used:\n            
regular_beta_length = len(betas) - len(self.cut_beta_params)\n            for cut_beta_idx in range(len(self.cut_beta_params)):\n                # general betas (general_betas) used by beta-CROWN with cuts\n                best_betas['cut'][cut_beta_idx][:, :, idx,\n                    :] = betas[regular_beta_length + cut_beta_idx][:, :, idx, :]\n\n\ndef _get_optimized_bounds(\n        self: 'BoundedModule', x=None, aux=None, C=None, IBP=False,\n        forward=False, method='backward', bound_side='lower',\n        reuse_ibp=False, return_A=False, average_A=False, final_node_name=None,\n        interm_bounds=None, reference_bounds=None,\n        aux_reference_bounds=None, needed_A_dict=None, cutter=None,\n        decision_thresh=None, epsilon_over_decision_thresh=1e-4):\n    \"\"\"\n    Optimize CROWN lower/upper bounds by alpha and/or beta.\n    \"\"\"\n\n    opts = self.bound_opts['optimize_bound_args']\n    iteration = opts['iteration']\n    max_time = opts['max_time']\n    beta = opts['enable_beta_crown']\n    alpha = opts['enable_alpha_crown']\n    apply_output_constraints_to = opts['apply_output_constraints_to']\n    opt_choice = opts['optimizer']\n    keep_best = opts['keep_best']\n    fix_interm_bounds = opts['fix_interm_bounds']\n    loss_reduction_func = opts['loss_reduction_func']\n    stop_criterion_func = opts['stop_criterion_func']\n    use_float64_in_last_iteration = opts['use_float64_in_last_iteration']\n    early_stop_patience = opts['early_stop_patience']\n    start_save_best = opts['start_save_best']\n    multi_spec_keep_func = opts['multi_spec_keep_func']\n    deterministic = opts['deterministic']\n    enable_opt_interm_bounds = self.bound_opts.get(\n        'enable_opt_interm_bounds', False)\n    sparse_intermediate_bounds = self.bound_opts.get(\n        'sparse_intermediate_bounds', False)\n    verbosity = self.bound_opts['verbosity']\n\n    if bound_side not in ['lower', 'upper']:\n        raise ValueError(bound_side)\n    bound_lower = bound_side == 'lower'\n    
bound_upper = bound_side == 'upper'\n\n    assert alpha or beta, (\n        'nothing to optimize, use compute_bounds() instead!')\n\n    if C is not None:\n        self.final_shape = C.size()[:2]\n        self.bound_opts.update({'final_shape': self.final_shape})\n    if opts['init_alpha']:\n        # TODO: this should set up aux_reference_bounds.\n        self.init_alpha(x, share_alphas=opts['use_shared_alpha'],\n                        method=method, c=C, final_node_name=final_node_name)\n\n    optimizable_activations = self.get_enabled_opt_act()\n\n    alphas, parameters = [], []\n    dense_coeffs_mask = []\n    if alpha:\n        best_alphas = _set_alpha(\n            optimizable_activations, parameters, alphas, opts['lr_alpha'])\n    else:\n        best_alphas = None\n    if beta:\n        ret_set_beta = self.set_beta(\n            enable_opt_interm_bounds, parameters,\n            opts['lr_beta'], opts['lr_cut_beta'], cutter, dense_coeffs_mask)\n        betas, best_betas, coeffs, dense_coeffs_mask = ret_set_beta[:4]\n    if apply_output_constraints_to is not None and len(apply_output_constraints_to) > 0:\n        _set_gammas(self.nodes(), parameters)\n\n    start = time.time()\n\n    if isinstance(decision_thresh, torch.Tensor):\n        if decision_thresh.dim() == 1:\n            # add the spec dim to be aligned with compute_bounds return\n            decision_thresh = decision_thresh.unsqueeze(-1)\n\n    if opts['pruning_in_iteration']:\n        if return_A:\n            raise NotImplementedError(\n                'Pruning in iteration optimization does not support '\n                'return A yet. 
'\n                'Please fix or discard this optimization by setting '\n                '--disable_pruning_in_iteration '\n                'or bab: pruning_in_iteration: false')\n        pruner = OptPruner(\n            x, threshold=opts['pruning_in_iteration_threshold'],\n            multi_spec_keep_func=multi_spec_keep_func,\n            loss_reduction_func=loss_reduction_func,\n            decision_thresh=decision_thresh,\n            epsilon_over_decision_thresh=epsilon_over_decision_thresh,\n            fix_interm_bounds=fix_interm_bounds)\n    else:\n        pruner = None\n\n    if opt_choice == 'adam-autolr':\n        opt = AdamElementLR(parameters)\n    elif opt_choice == 'adam':\n        opt = optim.Adam(parameters)\n    elif opt_choice == 'sgd':\n        opt = optim.SGD(parameters, momentum=0.9)\n    else:\n        raise NotImplementedError(opt_choice)\n\n    # Create a weight vector to scale learning rate.\n    loss_weight = torch.ones(size=(x[0].size(0),), device=x[0].device)\n    scheduler = optim.lr_scheduler.ExponentialLR(opt, opts['lr_decay'])\n\n    # best_intermediate_bounds is linked to aux_reference_bounds!\n    best_intermediate_bounds = {}\n    if (sparse_intermediate_bounds and aux_reference_bounds is None\n            and reference_bounds is not None):\n        aux_reference_bounds = {}\n        for name, (lb, ub) in reference_bounds.items():\n            aux_reference_bounds[name] = [\n                lb.detach().clone(), ub.detach().clone()]\n    if aux_reference_bounds is None:\n        aux_reference_bounds = {}\n\n    if len(apply_output_constraints_to) > 0:\n        # INVPROP requires that all layers have cached bounds. 
This may not be the case\n        # unless we explicitly compute them.\n        self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] = []\n        with torch.no_grad():\n            self.compute_bounds(\n                x=x, C=C, method='backward', bound_lower=bound_lower,\n                bound_upper=bound_upper, final_node_name=final_node_name,\n                interm_bounds=interm_bounds)\n        self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] = (\n            apply_output_constraints_to\n        )\n\n    if (return_A and self.output_name[0] in needed_A_dict.keys()\n            and self.input_name[0] in needed_A_dict[self.output_name[0]]):\n        # If the A dict will be returned, and we expect to retrieve the hyperplanes relating the\n        # output layer to the input layer, then we store these keys and pass them to the\n        # '_update_best_ret' method so that these entries may be updated during the optimization\n        # process. Only these output/input layer entries will be updated, and if other entries need\n        # to be updated, '_update_best_ret' is not the correct method to update them.\n        out_in_keys = (self.output_name[0], self.input_name[0])\n    else:\n        out_in_keys = None\n\n    need_grad = True\n    patience = 0\n    ret_0 = None\n    for i in range(iteration):\n        if i == 0:\n            # If we are at the first iteration, we need to\n            # set the constraints_optimized to None\n            self.constraints_optimized = None\n\n        if cutter:\n            # cuts may be optimized by cutter\n            self.cut_module = cutter.cut_module\n\n        if self.constraints_optimized is not None:\n            for root in self.roots():\n                if ( hasattr(root, 'perturbation')\n                    and root.perturbation is not None\n                    # Currently constraints solving is designed for LpNorm.\n                    and isinstance(root.perturbation, 
PerturbationLpNorm) ):\n\n                    # Reset the constraints for this root.\n                    # TODO: Currently, the `reset` function simply overwrites,\n                    #       should support more sophisticated reset logic.\n                    root.perturbation.reset_constraints(\n                        self.constraints_optimized, decision_thresh)\n\n\n        intermediate_constr = None\n\n        if not fix_interm_bounds:\n            # If we still optimize all intermediate neurons, we can use\n            # interm_bounds as reference bounds.\n            if reference_bounds is None:\n                reference_bounds = {}\n            if interm_bounds is not None:\n                reference_bounds.update(interm_bounds)\n            interm_bounds = {}\n\n        if i == iteration - 1:\n            # No grad update needed for the last iteration\n            need_grad = False\n            if (self.device == 'cuda'\n                    and torch.get_default_dtype() == torch.float32\n                    and use_float64_in_last_iteration):\n                C, x, interm_bounds = self._to_float64(\n                    C, x, aux_reference_bounds, interm_bounds)\n\n        if pruner:\n            # we will use last update preserve mask in caller functions to recover\n            # lA, l, u, etc to full batch size\n            self.last_update_preserve_mask = pruner.preserve_mask\n            pruner.cache_full_sized_alpha(optimizable_activations)\n\n        # If input bounds are tightened with output constraints, they depend on the\n        # relaxations of all other layers. 
The current iteration will recompute them.\n        # This involves concretizing them, so they will depend on themselves.\n        # To avoid a loop of gradients, remove gradients here.\n        tighten_input_bounds = (\n                self.bound_opts['optimize_bound_args']['tighten_input_bounds']\n        )\n        if tighten_input_bounds:\n            for root in self.roots():\n                if hasattr(root, 'perturbation') and root.perturbation is not None:\n                    root.perturbation.x_L = root.perturbation.x_L.detach()\n                    root.perturbation.x_U = root.perturbation.x_U.detach()\n\n        with torch.no_grad() if not need_grad else ExitStack():\n            # ret is lb, ub or lb, ub, A_dict (if return_A is set to true)\n            ret = self.compute_bounds(\n                x, aux, C, method=method, IBP=IBP, forward=forward,\n                bound_lower=bound_lower, bound_upper=bound_upper,\n                reuse_ibp=reuse_ibp, return_A=return_A,\n                final_node_name=final_node_name, average_A=average_A,\n                # When intermediate bounds are recomputed, we must set it\n                # to None\n                interm_bounds=interm_bounds if fix_interm_bounds else None,\n                # This is the currently tightest interval, which will be used to\n                # pass split constraints when intermediate betas are used.\n                reference_bounds=reference_bounds,\n                # This is the interval used for checking for unstable neurons.\n                aux_reference_bounds=aux_reference_bounds if sparse_intermediate_bounds else None,\n                # These are intermediate layer beta variables and their\n                # corresponding A matrices and biases.\n                intermediate_constr=intermediate_constr,\n                needed_A_dict=needed_A_dict,\n                update_mask=pruner.preserve_mask if pruner else None,\n                
cache_bounds=len(apply_output_constraints_to) > 0,\n            )\n        # If output constraints are used, it's possible that no inputs satisfy them.\n        # If one of the layers that use output constraints detects this, it sets\n        # self.infeasible_bounds = True for this element in the batch.\n        if self.infeasible_bounds is not None and torch.any(self.infeasible_bounds):\n            if ret[0] is not None:\n                ret = (\n                    torch.where(\n                        self.infeasible_bounds.unsqueeze(1),\n                        torch.full_like(ret[0], float('inf')),\n                        ret[0],\n                    ),\n                    ret[1],\n                )\n            if ret[1] is not None:\n                ret = (\n                    ret[0],\n                    torch.where(\n                        self.infeasible_bounds.unsqueeze(1),\n                        torch.full_like(ret[1], float('-inf')),\n                        ret[1],\n                    ),\n                )\n        ret_l, ret_u = ret[0], ret[1]\n\n        if pruner:\n            pruner.recover_full_sized_alpha(optimizable_activations)\n\n        if (self.cut_used and i % cutter.log_interval == 0\n                and len(self.cut_beta_params) > 0):\n            # betas[-1]: (2(0 lower, 1 upper), spec, batch, num_constrs)\n            if ret_l is not None:\n                print(i, 'lb beta sum:',\n                      f'{self.cut_beta_params[-1][0].sum() / ret_l.size(0)},',\n                      f'worst {ret_l.min()}')\n            if ret_u is not None:\n                print(i, 'ub beta sum:',\n                      f'{self.cut_beta_params[-1][1].sum() / ret_u.size(0)},',\n                      f'worst {ret_u.max()}')\n\n        if i == 0:\n            # save results at the first iteration\n            best_ret = [r.detach().clone() if r is not None else None for r in ret[:2]]\n            ret_0 = ret[0].detach().clone() if 
bound_lower else ret[1].detach().clone()\n\n            for node in optimizable_activations:\n                if node.inputs[0].lower is None and node.inputs[0].upper is None:\n                    continue\n                new_intermediate = [node.inputs[0].lower.detach().clone(),\n                                    node.inputs[0].upper.detach().clone()]\n                best_intermediate_bounds[node.name] = new_intermediate\n                if sparse_intermediate_bounds:\n                    # Always using the best bounds so far as the reference\n                    # bounds.\n                    aux_reference_bounds[node.inputs[0].name] = new_intermediate\n\n            if out_in_keys is not None:\n                best_out_in_A_dict = clone_sub_A_dict(ret[2], out_in_keys)\n            else:\n                best_out_in_A_dict = None\n\n        l = ret_l\n        # Reduction over the spec dimension.\n        if ret_l is not None and ret_l.shape[1] != 1:\n            l = loss_reduction_func(ret_l)\n        u = ret_u\n        if ret_u is not None and ret_u.shape[1] != 1:\n            u = loss_reduction_func(ret_u)\n\n        # full_l, full_ret_l and full_u, full_ret_u are used to update the best results\n        full_ret_l, full_ret_u = ret_l, ret_u\n        full_l = l\n        full_ret = ret\n\n        if pruner:\n            (x, C, full_l, full_ret_l, full_ret_u,\n             full_ret, stop_criterion) = pruner.prune(\n                x, C, ret_l, ret_u, ret, full_l, full_ret_l, full_ret_u,\n                full_ret, interm_bounds, aux_reference_bounds, reference_bounds,\n                stop_criterion_func, bound_lower)\n        else:\n            stop_criterion = (stop_criterion_func(full_ret_l) if bound_lower\n                              else stop_criterion_func(-full_ret_u))\n\n        loss_ = l if bound_lower else -u\n        total_loss = -1 * loss_\n        directly_optimize_layers = self.bound_opts['optimize_bound_args']['directly_optimize']\n        for 
directly_optimize_layer_name in directly_optimize_layers:\n            total_loss += (\n                self[directly_optimize_layer_name].upper.sum()\n                - self[directly_optimize_layer_name].lower.sum()\n            )\n\n        if isinstance(stop_criterion, bool):\n            loss = total_loss.sum() * (not stop_criterion)\n        else:\n            assert total_loss.shape == stop_criterion.shape\n            loss = (total_loss * stop_criterion.logical_not()).sum()\n\n        stop_criterion_final = isinstance(\n            stop_criterion, torch.Tensor) and stop_criterion.all()\n\n        if i == iteration - 1:\n            best_ret = list(best_ret)\n            if best_ret[0] is not None:\n                best_ret[0] = best_ret[0].to(torch.get_default_dtype())\n            if best_ret[1] is not None:\n                best_ret[1] = best_ret[1].to(torch.get_default_dtype())\n\n        if (i == iteration - 1 and self.device == 'cuda'\n                and torch.get_default_dtype() == torch.float32\n                and use_float64_in_last_iteration):\n            total_loss, x, full_ret = self._to_default_dtype(\n                x, total_loss, full_ret, ret, best_intermediate_bounds, return_A)\n\n        with torch.no_grad():\n            # for lb and ub, we update them in every iteration since updating them is cheap\n            need_update = False\n            improved_idx = None\n            if keep_best:\n                if best_ret[0] is not None:\n                    (\n                        best_ret, best_out_in_A_dict,\n                        need_update, idx_mask, improved_idx,\n                    ) = _update_best_ret(\n                        full_ret, best_ret,\n                        loss_reduction_func,\n                        idx=0,\n                        deterministic=deterministic,\n                        best_out_in_A_dict=best_out_in_A_dict,\n                        out_in_keys=out_in_keys,\n                    )\n                
if best_ret[1] is not None:\n                    (\n                        best_ret, best_out_in_A_dict,\n                        need_update, idx_mask, improved_idx,\n                    ) = _update_best_ret(\n                        full_ret, best_ret,\n                        loss_reduction_func,\n                        idx=1,\n                        deterministic=deterministic,\n                        best_out_in_A_dict=best_out_in_A_dict,\n                        out_in_keys=out_in_keys,\n                    )\n            else:\n                # Not saving the best, just keep the last iteration.\n                if full_ret[0] is not None:\n                    best_ret[0] = full_ret[0]\n                if full_ret[1] is not None:\n                    best_ret[1] = full_ret[1]\n\n            if return_A:\n                best_ret = [best_ret[0], best_ret[1], full_ret[2]]\n                if out_in_keys is not None:\n                    # Update A_dict entry for output/input layer\n                    # This entry corresponds to the best bounds.\n                    # Other A_dict entries may not, as they are copied from the last iteration.\n                    best_ret[2][out_in_keys[0]][out_in_keys[1]] = best_out_in_A_dict\n\n            patience = 0 if need_update else patience + 1\n            time_spent = time.time() - start\n\n            # Save variables if this is the best iteration.\n            # To save computational cost, we only check keep_best at the first\n            # iteration (in case of divergence) and during the second half of the\n            # iterations, or before early stopping via either stop_criterion or\n            # early_stop_patience being reached\n            if (\n                i < 1\n                or i > int(iteration * start_save_best)\n                or deterministic\n                or stop_criterion_final\n                or patience == early_stop_patience\n                or time_spent > max_time\n            ):\n                # compare with the first 
iteration results and get improved indices\n                if bound_lower:\n                    if deterministic:\n                        idx_mask, idx = improved_idx, None\n                    else:\n                        idx_mask, idx = _get_idx_mask(0, full_ret_l, ret_0, loss_reduction_func)\n                    ret_0[idx] = full_ret_l[idx]\n                else:\n                    if deterministic:\n                        idx_mask, idx = improved_idx, None\n                    else:\n                        idx_mask, idx = _get_idx_mask(1, full_ret_u, ret_0, loss_reduction_func)\n                    ret_0[idx] = full_ret_u[idx]\n\n                if idx is not None:\n                    # for updating purposes, we condition the idx to update only\n                    # on preserved domains\n                    if pruner:\n                        reference_idx, idx = pruner.prune_idx(idx_mask, idx, x)\n                    else:\n                        reference_idx = idx\n\n                    _update_optimizable_activations(\n                        optimizable_activations, interm_bounds,\n                        fix_interm_bounds, best_intermediate_bounds,\n                        reference_idx, idx, alpha, best_alphas, deterministic)\n\n                    if beta:\n                        self.update_best_beta(enable_opt_interm_bounds, betas,\n                                              best_betas, idx)\n\n        if os.environ.get('AUTOLIRPA_DEBUG_OPT', False):\n            print(f'****** iter [{i}]',\n                  f'loss: {loss.item()}, lr: {opt.param_groups[0][\"lr\"]}',\n                  (' pruning_in_iteration open status: '\n                     f'{pruner.pruning_in_iteration}') if pruner else '')\n\n        if stop_criterion_final:\n            print(f'\\nall verified at {i}th iter')\n            break\n\n        if patience > early_stop_patience:\n            print(f'Early stop at {i}th iter due to {early_stop_patience}'\n                
  ' iterations with no improvement!')\n            break\n\n        if time_spent > max_time:\n            print(f'Early stop at {i}th iter due to exceeding the time limit '\n                  f'for the optimization (time spent: {time_spent})')\n            break\n\n        if i != iteration - 1 and not loss.requires_grad:\n            assert i == 0, (i, iteration)\n            print('[WARNING] No optimizable parameters found. Will skip optimization. '\n                  'This happens e.g. if all optimizable layers are frozen or the '\n                  'network has no optimizable layers.')\n            break\n\n        opt.zero_grad(set_to_none=True)\n\n        if verbosity > 2:\n            current_lr = [param_group['lr'] for param_group in opt.param_groups]\n            print(f'*** iter [{i}]\\n', f'loss: {loss.item()}',\n                  total_loss.squeeze().detach().cpu().numpy(), 'lr: ',\n                  current_lr)\n            if beta:\n                print_optimized_beta(optimizable_activations)\n            if beta and i == 0 and verbosity > 2:\n                breakpoint()\n\n        if i != iteration - 1:\n            # we do not need to update parameters in the last step since the\n            # best result has already been obtained\n            loss.backward()\n\n            # All intermediate variables are not needed at this point.\n            self._clear_and_set_new(\n                None,\n                cache_bounds=len(apply_output_constraints_to) > 0,\n            )\n            if opt_choice == 'adam-autolr':\n                opt.step(lr_scale=[loss_weight, loss_weight])\n            else:\n                opt.step()\n\n        if beta:\n            for b in betas:\n                # project beta onto the non-negative orthant\n                b.data = (b >= 0) * b.data\n            for dmi in range(len(dense_coeffs_mask)):\n                # apply dense mask to the dense split coeffs matrix\n                coeffs[dmi].data = (\n                    dense_coeffs_mask[dmi].float() * coeffs[dmi].data)\n\n        
if alpha:\n            for m in optimizable_activations:\n                m.clip_alpha()\n        if apply_output_constraints_to is not None and len(apply_output_constraints_to) > 0:\n            for m in self.nodes():\n                m.clip_gammas()\n\n        scheduler.step()\n\n        if pruner:\n            pruner.next_iter()\n\n    if pruner:\n        best_ret = pruner.update_best(full_ret_l, full_ret_u, best_ret)\n\n    if verbosity > 3:\n        breakpoint()\n\n    if keep_best:\n        # Set all variables to their saved best values.\n        with torch.no_grad():\n            for idx, node in enumerate(optimizable_activations):\n                if node.name not in best_intermediate_bounds:\n                    continue\n                if alpha:\n                    # Assigns a new dictionary.\n                    node.alpha = best_alphas[node.name]\n                # Update best intermediate layer bounds only when they are\n                # optimized. If they are already fixed in\n                # interm_bounds, then do nothing.\n                best_intermediate = best_intermediate_bounds[node.name]\n                node.inputs[0].lower.data = best_intermediate[0].data\n                node.inputs[0].upper.data = best_intermediate[1].data\n            if beta:\n                for node in self.nodes_with_beta:\n                    assert getattr(node, 'sparse_betas', None) is not None\n                    if enable_opt_interm_bounds:\n                        for key in node.sparse_betas.keys():\n                            node.sparse_betas[key].val.copy_(\n                                best_betas[node.name][key])\n                    else:\n                        node.sparse_betas[0].val.copy_(best_betas[node.name])\n            if self.cut_used:\n                for ii in range(len(self.cut_beta_params)):\n                    self.cut_beta_params[ii].data = best_betas['cut'][ii].data\n\n    if interm_bounds is not None and not 
fix_interm_bounds:\n        for l in self._modules.values():\n            if (l.name in interm_bounds.keys()\n                    and l.is_lower_bound_current()):\n                l.lower = torch.max(l.lower, interm_bounds[l.name][0])\n                l.upper = torch.min(l.upper, interm_bounds[l.name][1])\n                infeasible_neurons = l.lower > l.upper\n                if infeasible_neurons.any():\n                    print(f'Infeasibility detected in layer {l.name}.',\n                          infeasible_neurons.sum().item(),\n                          infeasible_neurons.nonzero()[:, 0])\n\n    if verbosity > 0:\n        if best_ret[0] is not None:\n            # FIXME: unify the handling of l and u.\n            print('best_l after optimization:', best_ret[0].sum().item())\n            if beta:\n                print('beta sum per layer:', [p.sum().item() for p in betas])\n        print('alpha/beta optimization time:', time.time() - start)\n\n    for node in optimizable_activations:\n        node.opt_end()\n\n    if pruner:\n        pruner.update_ratio(full_l, full_ret_l)\n        pruner.clean_full_sized_alpha_cache()\n\n    if os.environ.get('AUTOLIRPA_DEBUG_OPT', False):\n        print()\n\n    return best_ret\n\n\ndef init_alpha(self: 'BoundedModule', x, share_alphas=False, method='backward',\n               c=None, bound_lower=True, bound_upper=True, final_node_name=None,\n               interm_bounds=None, reference_alphas=None,\n               skip_bound_compute=False):\n    self(*x)  # Do a forward pass to set perturbed nodes\n    final = (self.final_node() if final_node_name is None\n             else self[final_node_name])\n    self._set_used_nodes(final)\n\n    optimizable_activations = self.get_enabled_opt_act()\n    for node in optimizable_activations:\n        # TODO(7/6/2023) In the future, we may need to enable alpha sharing\n        # automatically by considering the size of all the optimizable nodes in the\n        # graph. 
For now, only an ad hoc check in MatMul is added.\n        node._all_optimizable_activations = optimizable_activations\n\n        # initialize the parameters\n        node.opt_init()\n\n    apply_output_constraints_to = (\n        self.bound_opts['optimize_bound_args']['apply_output_constraints_to']\n    )\n    if (not skip_bound_compute or interm_bounds is None or\n            reference_alphas is None or not all(\n                [act.name in reference_alphas\n                 for act in optimizable_activations])):\n        skipped = False\n        # if new interval is None, then CROWN interval is not present\n        # in this case, we still need to redo a CROWN pass to initialize\n        # lower/upper\n        with torch.no_grad():\n            # We temporarily deactivate output constraints\n            self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] = []\n            l, u = self.compute_bounds(\n                x=x, C=c, method=method, bound_lower=bound_lower,\n                bound_upper=bound_upper, final_node_name=final_node_name,\n                interm_bounds=interm_bounds)\n            self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] = (\n                apply_output_constraints_to\n            )\n            if len(apply_output_constraints_to) > 0:\n                # Some layers, such as the BoundTanh layer, do some of their initialization\n                # in the forward pass. 
We need to call the forward pass again to ensure\n                # that they are initialized for the output constraints, too.\n                l, u = self.compute_bounds(\n                    x=x, C=c, method=method, bound_lower=bound_lower,\n                    bound_upper=bound_upper, final_node_name=final_node_name,\n                    interm_bounds=interm_bounds, cache_bounds=True)\n    else:\n        # we skip, but we still would like to figure out the \"used\",\n        # \"perturbed\", \"backward_from\" properties of each node in the graph\n        skipped = True\n        # this sets the \"perturbed\" property\n        self.set_input(*x, interm_bounds=interm_bounds)\n        self.backward_from = {node: [final] for node in self._modules}\n        l = u = None\n\n    final_node_name = final_node_name or self.final_name\n\n    init_intermediate_bounds = {}\n    for node in optimizable_activations:\n        start_nodes = []\n        if method in ['forward', 'forward+backward']:\n            start_nodes.append(('_forward', 1, None, False))\n        if method in ['backward', 'forward+backward']:\n            start_nodes += self.get_alpha_crown_start_nodes(\n                node,\n                c=c,\n                share_alphas=share_alphas,\n                final_node_name=final_node_name,\n            )\n        if not start_nodes:\n            continue\n        if skipped:\n            node.restore_alpha(reference_alphas[node.name], device=x[0].device, dtype=x[0].dtype)\n        else:\n            node.init_opt_parameters(start_nodes)\n        if node in self.splittable_activations:\n            for i in node.requires_input_bounds:\n                input_node = node.inputs[i]\n                if (not input_node.perturbed\n                        or (node.inputs[i].lower is None\n                            and node.inputs[i].upper is None)):\n                    continue\n                init_intermediate_bounds[node.inputs[i].name] = (\n                    
[node.inputs[i].lower.detach(),\n                    node.inputs[i].upper.detach()])\n    if (\n        apply_output_constraints_to is not None\n        and len(apply_output_constraints_to) > 0\n        and hasattr(self, 'constraints')\n    ):\n        # self.constraints.shape = (batch_size, num_constraints, num_output_neurons)\n        # For abCROWN we know that:\n        # If the output constraints are a conjunction, the shape is (1, num_constraints, *)\n        # If the output constraints are a disjunction, the shape is (num_constraints, 1, *)\n        # Checking which entry is 1 allows us to distinguish the two cases.\n        # If auto_LiRPA is used directly, we could have batches of inputs with more than one\n        # constraint. This is currently not supported.\n        if self.constraints.size(0) == 1:\n            num_gammas = self.constraints.size(1)\n        elif self.constraints.size(1) == 1:\n            num_gammas = self.constraints.size(0)\n        else:\n            raise NotImplementedError(\n                'To use output constraints, either have a batch size of 1 or use only one '\n                'output constraint'\n            )\n        for node in self.nodes():\n            node.init_gammas(num_gammas)\n\n    if self.bound_opts['verbosity'] >= 1:\n        print('Optimizable variables initialized.')\n    if skip_bound_compute:\n        return init_intermediate_bounds\n    else:\n        return l, u, init_intermediate_bounds\n"
  },
  {
    "path": "auto_LiRPA/output_constraints.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\n\nfrom .utils import *\nfrom .bound_ops import *\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef invprop_enabled(self: 'BoundedModule'):\n    return self.bound_opts['optimize_bound_args']['apply_output_constraints_to']\n\n\ndef invprop_init_infeasible_bounds(self: 'BoundedModule', bound_node, C):\n    # Infeasible bounds can result from unsatisfiable output constraints.\n    # We track them to set the corresponding lower bounds to inf and upper bounds to\n    # -inf.\n    if self.infeasible_bounds is None:\n        device = bound_node.attr['device']\n        if isinstance(C, Patches):\n            self.infeasible_bounds = torch.full((C.shape[1],), False, device=device)\n        else:\n            assert isinstance(C, (torch.Tensor, eyeC, OneHotC)), type(C)\n            self.infeasible_bounds = torch.full((C.shape[0],), False, device=device)\n\n\ndef invprop_check_infeasible_bounds(self: 'BoundedModule', lb, ub):\n    if torch.any(self.infeasible_bounds):\n        if lb is not None:\n            assert lb.size(0) == self.infeasible_bounds.size(0)\n            lb = torch.where(self.infeasible_bounds.unsqueeze(1),\n                             torch.tensor(float('inf'), device=lb.device), lb)\n        if ub is not None:\n            assert ub.size(0) == self.infeasible_bounds.size(0)\n            ub = torch.where(self.infeasible_bounds.unsqueeze(1),\n                             torch.tensor(float('-inf'), device=ub.device), ub)\n    return lb, ub\n\n\ndef backward_general_invprop(\n    self: 'BoundedModule',\n    initial_As, initial_lb, initial_ub,\n    bound_node,\n    C,\n    start_backpropagation_at_node = None,\n    bound_lower=True,\n    bound_upper=True,\n    average_A=False,\n    need_A_only=False,\n    unstable_idx=None,\n    
update_mask=None,\n):\n    use_beta_crown = self.bound_opts['optimize_bound_args']['enable_beta_crown']\n    # Sometimes, not using output constraints can give better results.\n    # When this flag is set, the bounds are computed both with and without\n    # output constraints, and the best of the two is returned.\n    best_of_oc_and_no_oc = (\n        self.bound_opts['optimize_bound_args']['best_of_oc_and_no_oc']\n    )\n\n    assert not use_beta_crown\n    assert not self.cut_used\n    assert initial_As is None\n    assert initial_lb is None\n    assert initial_ub is None\n    if best_of_oc_and_no_oc:\n        # Important: If input bounds are tightened, then this call must be done\n        # *before* the use of output constraints.\n        # At the end of backward_general, the bounds are concretized. For the input\n        # bounds, those concrete bounds are used to overwrite the bounds in the\n        # input perturbations, so they'll then be used by all other layers during\n        # their concretization. These input bounds *must* have their gradients\n        # w.r.t. the relaxations set up. The call to backward_general without\n        # output constraints will overwrite these bounds with values that do not\n        # have gradients. 
So it must come first.\n        with torch.no_grad():\n            o_res = self.backward_general(\n                bound_node=bound_node,\n                C=C,\n                start_backpropagation_at_node=start_backpropagation_at_node,\n                bound_lower=bound_lower,\n                bound_upper=bound_upper,\n                average_A=average_A,\n                need_A_only=need_A_only,\n                unstable_idx=unstable_idx,\n                update_mask=update_mask,\n                apply_output_constraints_to=[],\n            )\n    res = self.backward_general_with_output_constraint(\n        bound_node=bound_node,\n        C=C,\n        start_backporpagation_at_node=start_backpropagation_at_node,\n        bound_lower=bound_lower,\n        bound_upper=bound_upper,\n        average_A=average_A,\n        need_A_only=need_A_only,\n        unstable_idx=unstable_idx,\n        update_mask=update_mask,\n    )\n    if best_of_oc_and_no_oc:\n        # We use the best of both results. This would convert Infs to NaNs\n        # (because inf - inf = nan), so those entries get masked.\n        res0_inf_mask = torch.isinf(res[0])\n        r0 = res[0] - res[0].detach() + torch.max(res[0].detach(), o_res[0].detach())\n        r0 = torch.where(res0_inf_mask, res[0], r0)\n        res1_inf_mask = torch.isinf(res[1])\n        r1 = res[1] - res[1].detach() + torch.min(res[1].detach(), o_res[1].detach())\n        r1 = torch.where(res1_inf_mask, res[1], r1)\n        if self.return_A:\n            if res[2] != {}:\n                raise NotImplementedError(\n                    \"Merging of A not implemented yet. 
If set, try disabling --best_of_oc_and_no_oc\"\n                )\n            res = (r0, r1, {})\n        else:\n            res = (r0, r1)\n    batch_size = res[0].size(0)\n    infeasible_bounds = torch.any(res[0].reshape((batch_size, -1)) > res[1].reshape((batch_size, -1)), dim=1)\n    if torch.any(infeasible_bounds):\n        self.infeasible_bounds = torch.logical_or(self.infeasible_bounds, infeasible_bounds)\n    return res\n\n\ndef backward_general_with_output_constraint(\n    self: 'BoundedModule',\n    bound_node,\n    C,\n    start_backporpagation_at_node = None,\n    bound_lower=True,\n    bound_upper=True,\n    average_A=False,\n    need_A_only=False,\n    unstable_idx=None,\n    update_mask=None,\n):\n    assert start_backporpagation_at_node is None\n    assert not isinstance(C, str)\n\n    neurons_in_layer = 1\n    for d in bound_node.output_shape[1:]:\n        neurons_in_layer *= d\n\n    # backward_general uses C to compute batch_size, output_dim and output_shape, just like below.\n    # When output constraints are applied, it will perform a different backpropagation,\n    # but those variables need to be computed regardless. So we need to retain the original C\n    # and pass it on to backward_general. 
If initial_As is set (which it is, if this code here\n    # is executed), it will not use C for anything else.\n    orig_C = C\n\n    C, batch_size, output_dim, output_shape = self._preprocess_C(C, bound_node)\n    device = bound_node.device\n    if device is None and hasattr(C, 'device'):\n        device = C.device\n    # self.constraints.shape == (batch_size, num_constraints, output_neurons)\n    batch_size = self.constraints.size(0)\n    num_constraints = self.constraints.size(1)\n\n    # 1) Linear: Hx + d\n    # Result is a tensor, <= 0 for all entries if output constraint is satisfied\n    H = self.constraints.transpose(1,2)  # (batch_size, output_neurons, num_constraints)\n    d = -self.thresholds  # (batch)\n    assert H.ndim == 3\n    assert H.size(0) == batch_size\n    assert H.size(2) == num_constraints\n    assert d.ndim == 1\n    if batch_size > 1:\n        assert num_constraints == 1\n        assert d.size(0) == batch_size\n    else:\n        assert d.size(0) == num_constraints\n\n    if hasattr(bound_node, 'gammas'):\n        gammas = bound_node.gammas\n    else:\n        if hasattr(bound_node, 'opt_stage'):\n            assert bound_node.opt_stage not in ['opt', 'reuse']\n        if batch_size == 1:\n            gammas = torch.zeros((2, num_constraints, neurons_in_layer), device=device)\n        else:\n            gammas = torch.zeros((2, batch_size, neurons_in_layer), device=device)\n\n    # H.shape = (batch_size, output_neurons, num_constraints==1)\n    # We need used_weight.shape = (batch_size, this_layer_neurons, prev_layer_neurons)\n    # This is satisfied by H, because it will be transposed before being accessed and\n    # output_neurons == prev_layer_neurons\n    linear_Hxd_layer_weight_value = nn.Parameter(H.to(gammas))\n    linear_Hxd_layer_weight = BoundParams(\n        ori_name=\"/linear_Hxd_layer_weight\",\n        value=None,\n        perturbation=None,\n    )\n    linear_Hxd_layer_weight.name = \"linear_Hxd_layer_weight\"\n    
linear_Hxd_layer_weight.lower = linear_Hxd_layer_weight_value\n    linear_Hxd_layer_weight.upper = linear_Hxd_layer_weight_value\n\n    if batch_size == 1:\n        linear_Hxd_layer_bias_value = nn.Parameter(d.float().to(device))\n    else:\n        linear_Hxd_layer_bias_value = nn.Parameter(d.float().to(device).unsqueeze(1))\n    linear_Hxd_layer_bias = BoundParams(\n        ori_name=\"/linear_Hxd_layer_bias\",\n        value=None,\n        perturbation=None,\n    )\n    linear_Hxd_layer_bias.name = \"linear_Hxd_layer_bias\"\n    linear_Hxd_layer_bias.lower = linear_Hxd_layer_bias_value\n    linear_Hxd_layer_bias.upper = linear_Hxd_layer_bias_value\n\n    linear_Hxd_layer = BoundLinear(\n        attr=None,\n        inputs=[\n            self.final_node(),\n            linear_Hxd_layer_weight,\n            linear_Hxd_layer_bias,\n        ],\n        output_index=0,\n        options=self.bound_opts,\n    )\n    linear_Hxd_layer.name = \"/linear_Hxd_layer\"\n    linear_Hxd_layer.device = device\n    linear_Hxd_layer.perturbed = True\n    linear_Hxd_layer.output_shape = torch.Size([1, num_constraints])\n    linear_Hxd_layer.batch_dim = bound_node.batch_dim\n    linear_Hxd_layer.batched_weight_and_bias = (batch_size > 1)\n\n    # 2) Gamma\n    # A separate gamma per output constraint. 
All gammas are always positive.\n    # Depending on the configuration, gammas are shared across neurons in the\n    # optimized layer.\n    gamma_layer_weight = BoundParams(\n        ori_name=\"/gamma_layer_weight\",\n        value=None,\n        perturbation=None,\n    )\n    gamma_layer_weight.name = \"gamma_layer_weight\"\n    assert gammas.ndim == 3\n    assert gammas.size(0) == 2\n    if batch_size == 1:\n        # gammas.shape = (2, num_constraints, this_layer_neurons)\n        assert gammas.ndim == 3\n        assert gammas.size(0) == 2\n        assert gammas.size(1) == num_constraints\n        this_layer_neurons = gammas.size(2)\n\n        # In linear.py, these weights will be used to compute next_A based on last_A:\n        # last_A.shape = (unstable_neurons, batch_size==1, this_layer_neurons)\n        # next_A.shape = (unstable_neurons, batch_size==1, prev_layer_neurons)\n        # prev_layer_neurons == num_constraints\n        # So we set the weights as\n        # (num_constraints, this_layer_neurons)\n        # This will be transposed and accessed by linear.py as\n        # (this_layer_neurons, num_constraints)\n        # Note that the shape will be further modified in linear.py\n        gamma_layer_weight.lower = gammas[0].unsqueeze(0)\n        gamma_layer_weight.upper = -gammas[1].unsqueeze(0)\n    else:\n        # abCROWN optimizes the computation by transposing the query.\n        # Instead of one batch entry with N constraints, we have N batch entries\n        # with one constraint each. 
We do not support multiple batch entries\n        # each with multiple constraints.\n        # gammas.shape = (2, batch_size, this_layer_neurons)\n        # Here, we can only check that the batch size is correct.\n        assert gammas.size(1) == batch_size\n        assert num_constraints == 1\n\n        this_layer_neurons = gammas.size(2)\n\n        # In linear.py, these weights will be used to compute next_A based on last_A:\n        # last_A.shape = (unstable_neurons, batch_size, this_layer_neurons)\n        # next_A.shape = (unstable_neurons, batch_size, prev_layer_neurons==1)\n        # prev_layer_neurons == 1 because it's num_constraints\n        # So we set the weights as\n        # (batch_size, 1, this_layer_neurons)\n        # This will be transposed and accessed by linear.py as\n        # (batch_size, this_layer_neurons, 1)\n        # Note that the shape will be further modified in linear.py\n        gamma_layer_weight.lower = gammas[0].unsqueeze(1)\n        gamma_layer_weight.upper = -gammas[1].unsqueeze(1)\n    gamma_layer = BoundLinear(\n        attr=None,\n        inputs=[linear_Hxd_layer, gamma_layer_weight],\n        output_index=0,\n        options=self.bound_opts,\n    )\n    gamma_layer.name = \"/gamma_layer\"\n    gamma_layer.device = device\n    gamma_layer.perturbed = True\n    gamma_layer.input_shape = linear_Hxd_layer.output_shape\n    gamma_layer.output_shape = torch.Size([1, this_layer_neurons])\n    gamma_layer.batch_dim = bound_node.batch_dim\n    gamma_layer.use_seperate_weights_for_lower_and_upper_bounds = True\n    gamma_layer.batched_weight_and_bias = (batch_size > 1)\n\n    # 3) Reshape\n    # To the same shape as the layer that's optimized.\n    reshape_layer_output_shape = BoundBuffers(\n        ori_name=\"/reshape_layer_output_shape\",\n        value = torch.tensor(bound_node.output_shape[1:]),\n        perturbation=None,\n        options=self.bound_opts,\n    )\n    reshape_layer_output_shape.name = 
\"reshape_layer_output_shape\"\n    reshape_layer = BoundReshape(\n        attr=None,\n        inputs = [gamma_layer, reshape_layer_output_shape],\n        output_index=0,\n        options=self.bound_opts,\n    )\n    reshape_layer.name = \"/reshape_layer\"\n    reshape_layer.device = device\n    reshape_layer.perturbed = True\n    reshape_layer.input_shape = gamma_layer.output_shape\n    reshape_layer.output_shape = bound_node.output_shape\n    reshape_layer.batch_dim = bound_node.batch_dim\n\n    # The residual connection that connects the optimized layer and the reshape\n    # layer from above is not explicitly coded, it's handled implicitly:\n    # Here, we propagate backwards through 5->4->3->2->1->regular output layer and let\n    # CROWN handle the propagation from there on backwards to the input layer.\n    # The other half of the residual connection is implemented by explicitly setting\n    # the .lA and .uA values of the optimized layer to C.\n    # This is done via initial_As, initial_lb, initial_ub.\n\n    if isinstance(C, (OneHotC, eyeC)):\n        batch_size = C.shape[1]\n        assert C.shape[0] <= C.shape[2]\n        assert len(C.shape) == 3\n        # This is expensive, but Reshape doesn't support OneHotC objects\n        if isinstance(C, OneHotC):\n            C = torch.eye(C.shape[2], device=C.device)[C.index].unsqueeze(1).expand(-1, batch_size, -1)\n        else:\n            C = torch.eye(C.shape[2], device=C.device).unsqueeze(1).expand(-1, batch_size, -1)\n\n    start_shape = None\n    lA = C if bound_lower else None\n    uA = C if bound_upper else None\n\n    # 3) Reshape\n    A, lower_b, upper_b = reshape_layer.bound_backward(\n        lA, uA, *reshape_layer.inputs,\n        start_node=bound_node, unstable_idx=unstable_idx,\n        start_shape=start_shape)\n    assert lower_b == 0\n    assert upper_b == 0\n    lA = A[0][0]\n    uA = A[0][1]\n\n    # 2) Gamma\n    A, lower_b, upper_b = gamma_layer.bound_backward(\n        lA, uA, 
*gamma_layer.inputs,\n        start_node=bound_node, unstable_idx=unstable_idx,\n        start_shape=start_shape)\n    assert lower_b == 0\n    assert upper_b == 0\n    lA = A[0][0]\n    uA = A[0][1]\n\n    # 1) Hx + d\n    A, lower_b, upper_b = linear_Hxd_layer.bound_backward(\n        lA, uA, *linear_Hxd_layer.inputs,\n        start_node=bound_node, unstable_idx=unstable_idx,\n        start_shape=start_shape)\n    # lower_b and upper_b are no longer 0, because d wasn't 0.\n    lA = A[0][0]\n    uA = A[0][1]\n\n    # This encodes the residual connection.\n    initial_As = {\n        self.final_node().name: (lA, uA),\n        bound_node.name: (C, C),\n    }\n\n    assert lower_b.ndim == 2\n    assert upper_b.ndim == 2\n\n    return self.backward_general(\n        bound_node = bound_node,\n        start_backpropagation_at_node = self.final_node(),\n        C = orig_C,  #  only used for batch_size, output_dim, output_shape computation\n        bound_lower = bound_lower,\n        bound_upper = bound_upper,\n        average_A = average_A,\n        need_A_only = need_A_only,\n        unstable_idx = unstable_idx,\n        update_mask = update_mask,\n        apply_output_constraints_to = [],  # no nested application\n        initial_As = initial_As,\n        initial_lb = lower_b,\n        initial_ub = upper_b,\n    )\n"
  },
  {
    "path": "auto_LiRPA/parse_graph.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport torch\nfrom torch.onnx.utils import _optimize_graph\nfrom collections import OrderedDict\nfrom collections import namedtuple\nfrom packaging import version\nimport re\nimport os\nimport traceback\nfrom .bounded_tensor import BoundedTensor, BoundedParameter\nfrom .utils import logger, unpack_inputs\n\nNode = namedtuple('Node', (\n    'name', 'ori_name', 'inputs', 'attr', 'op', 'param', 'input_index',\n    'bound_node', 'output_index', 'perturbation'), defaults=(None,) * 10)\n\ndef get_node_name(node):\n    return node.debugName()\n\ndef get_node_attribute(node, attribute_name):\n    if hasattr(torch.onnx.symbolic_helper, '_node_get'):\n        # Pytorch >= 1.13.\n        return torch.onnx.symbolic_helper._node_get(node, attribute_name)\n    else:\n        # Pytorch <= 1.12. This will call _node_getitem in torch.onnx.utils.\n        return node[attribute_name]\n\ndef parse_graph(graph, inputs, params):\n    input_all = []\n    input_used = []\n    scope = {}\n    for n in graph.inputs():\n        input_all.append(n.debugName())\n    for n in graph.nodes():\n        n_inputs = [get_node_name(i) for i in n.inputs()]\n        for inp in n.inputs():\n            input_used.append(inp.debugName())\n        for out in n.outputs():\n            scope[get_node_name(out)] = n.scopeName()\n    for node in graph.inputs():\n        name = get_node_name(node)\n        scope[name] = ''\n    for n in graph.outputs():\n        name = get_node_name(n)\n        if name in input_all:\n            # This output node directly comes from an input node with an Op\n            input_used.append(n.debugName())\n\n    def name_with_scope(node):\n        name = get_node_name(node)\n        name = '/'.join([scope[name], name])\n        if '.' 
in name:\n            # \".\" should not be used as it could cause issues in state_dict loading\n            # where PyTorch would treat it as having submodules\n            name = name.replace('.', '-')\n        return name\n\n    nodesOP = []\n    for n in graph.nodes():\n        attrs = {k: get_node_attribute(n, k) for k in n.attributeNames()}\n        n_inputs = [name_with_scope(i) for i in n.inputs()]\n        for i, out in enumerate(list(n.outputs())):\n            nodesOP.append(Node(**{\n                'name': name_with_scope(out),\n                'op': n.kind(),\n                'inputs': n_inputs,\n                'attr': attrs,\n                'output_index': i,\n            }))\n\n    # filter out input nodes in `graph.inputs()` that are actually used\n    nodesIn = []\n    used_by_index = []\n    for i, n in enumerate(graph.inputs()):\n        name = get_node_name(n)\n        used = name in input_used\n        used_by_index.append(used)\n        if used:\n            nodesIn.append(n)\n\n    # filter out input nodes in `inputs` that are actually used\n    inputs_unpacked = unpack_inputs(inputs)\n    assert len(list(graph.inputs())) == len(inputs_unpacked) + len(params)\n    inputs = [inputs_unpacked[i] for i in range(len(inputs_unpacked)) if used_by_index[i]]\n    # index of the used inputs among all the inputs\n    input_index = [i for i in range(len(inputs_unpacked)) if used_by_index[i]]\n    # Add a name to all inputs\n    inputs = list(zip([\"input_{}\".format(input_index[i]) for i in range(len(inputs))], inputs))\n    # filter out params that are actually used\n    params = [params[i] for i in range(len(params)) if used_by_index[i + len(inputs_unpacked)]]\n    inputs_and_params = inputs + params\n    assert len(nodesIn) == len(inputs_and_params)\n\n    # output nodes of the module\n    nodesOut = []\n    for n in graph.outputs():\n        # we only record names\n        nodesOut.append(name_with_scope(n))\n\n    for i, n in enumerate(nodesIn):\n      
  if (isinstance(inputs_and_params[i][1], BoundedTensor) or\n                isinstance(inputs_and_params[i][1], BoundedParameter)):\n            perturbation = inputs_and_params[i][1].ptb\n        else:\n            perturbation = None\n        if i > 0 and n.type().sizes() != list(inputs_and_params[i][1].size()):\n            raise RuntimeError(\"Input tensor shapes do not match: {} != {}\".format(\n                n.type().sizes(), list(inputs_and_params[i][1].size())))\n        name = name_with_scope(n)\n        nodesIn[i] = Node(**{\n            'name': name,\n            'ori_name': inputs_and_params[i][0],\n            'op': 'Parameter',\n            'inputs': [],\n            'attr': str(n.type()),\n            'param': inputs_and_params[i][1] if i >= len(inputs) else None,\n            # index among all the inputs including unused ones\n            'input_index': input_index[i] if i < len(inputs) else None,\n            # Input nodes may have perturbation, if they are wrapped in BoundedTensor or BoundedParameter\n            'perturbation': perturbation,\n        })\n\n    return nodesOP, nodesIn, nodesOut\n\ndef _get_jit_params(module, param_exclude, param_include):\n    state_dict = torch.jit._unique_state_dict(module, keep_vars=True)\n\n    if param_exclude is not None:\n        param_exclude = re.compile(param_exclude)\n    if param_include is not None:\n        param_include = re.compile(param_include)\n\n    new_state_dict = OrderedDict()\n    for k, v in state_dict.items():\n        if param_exclude is not None and param_exclude.match(k) is not None:\n            print(f'\\nremove input element {k} from nodesIn\\n')\n            continue\n        if param_include is not None and param_include.match(k) is None:\n            continue\n        new_state_dict[k] = v\n\n    params = zip(new_state_dict.keys(), new_state_dict.values())\n\n    return params\n\ndef get_output_template(out):\n    \"\"\"Construct a template for the module output with `None` 
representing places\n    to be filled with tensor results\"\"\"\n    if isinstance(out, torch.Tensor):\n        return None\n    elif isinstance(out, list):\n        return list([get_output_template(o) for o in out])\n    elif isinstance(out, tuple):\n        return tuple([get_output_template(o) for o in out])\n    elif isinstance(out, dict):\n        template = {}\n        for key in out:\n            template[key] = get_output_template(out[key])\n        return template\n    else:\n        raise NotImplementedError\n\ndef parse_source(node):\n    kind = node.kind()\n    if hasattr(node, 'sourceRange'):\n        source_range_str = node.sourceRange()\n        # split source_range_str by '\\n' and drop any lines containing 'torch/nn'\n        source_range_str = '\\n'.join([line for line in source_range_str.split('\\n') if 'torch/nn' not in line])\n        match = re.match(r'([^ ]+\\.py)\\((\\d+)\\)', source_range_str)\n        if match:\n            # match.group(1) is the file name\n            # match.group(2) is the line number\n            return f\"{kind}_{os.path.basename(match.group(1)).split('.')[0]}_{match.group(2)}\"\n    return kind\n\ndef update_debug_names(trace_graph):\n    visited = []\n    for n in trace_graph.nodes():\n        for input in n.inputs():\n            if input.debugName() not in visited:\n                input.setDebugName(f\"{input.debugName()}_{parse_source(n)}\")\n                visited.append(input.debugName())\n        for output in n.outputs():\n            if output.debugName() not in visited:\n                output.setDebugName(f\"{output.debugName()}_{parse_source(n)}\")\n                visited.append(output.debugName())\n\ndef parse_module(module, inputs, param_exclude=\".*AuxLogits.*\", param_include=None):\n    params = _get_jit_params(module, param_exclude=param_exclude, param_include=param_include)\n    try:\n        trace, out = torch.jit._get_trace_graph(module, inputs)\n    except:\n        
print(traceback.format_exc())\n        raise RuntimeError(\n            'Failed to get the trace. '\n            'Please check that the model and inputs are compatible with torch.jit.')\n\n    if version.parse(torch.__version__) < version.parse(\"2.0.0\"):\n        from torch.onnx.symbolic_helper import _set_opset_version\n        _set_opset_version(12)\n    if version.parse(torch.__version__) >= version.parse(\"2.1.0\"):\n        # This is needed for BoundConcatGrad to work with torch 2.1.0 and later\n        if version.parse(torch.__version__) < version.parse(\"2.9.0\"):\n            from torch.onnx._globals import GLOBALS\n        else:\n            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS\n        GLOBALS.autograd_inlining = False\n\n    logger.debug(\"Graph before ONNX conversion:\")\n    logger.debug(trace)\n\n    # Assuming that the first node in the graph is the primary input node.\n    # It must have a batch dimension.\n    primary_input = get_node_name(next(iter(trace.inputs())))\n    trace_graph = _optimize_graph(\n        trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,\n        params_dict={},\n        input_names=[primary_input],\n        dynamic_axes={primary_input: {0: 'batch'}})\n    logger.debug('trace_graph: %s', trace_graph)\n\n    if os.environ.get('AUTOLIRPA_DEBUG_NAMES', 0):\n        update_debug_names(trace_graph)\n\n    logger.debug(\"ONNX graph:\")\n    logger.debug(trace_graph)\n\n    if not isinstance(inputs, tuple):\n        inputs = (inputs, )\n\n    nodesOP, nodesIn, nodesOut = parse_graph(trace_graph, tuple(inputs), tuple(params))\n\n    for i in range(len(nodesOP)):\n        param_in = OrderedDict()\n        for inp in nodesOP[i].inputs:\n            for n in nodesIn:\n                if inp == n.name:\n                    param_in.update({inp:n.param})\n        nodesOP[i] = nodesOP[i]._replace(param=param_in)\n\n    template = get_output_template(out)\n\n    return nodesOP, nodesIn, nodesOut, 
template\n"
  },
  {
    "path": "auto_LiRPA/patches.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\n\ndef insert_zeros(image, s):\n    \"\"\"\n    Insert s columns and rows 0 between every pixel in the image. 
For example:\n    image = [[1, 2, 3],\n             [4, 5, 6],\n             [7, 8, 9]]\n    s = 2\n    output = [[1, 0, 0, 2, 0, 0, 3],\n              [0, 0, 0, 0, 0, 0, 0],\n              [0, 0, 0, 0, 0, 0, 0],\n              [4, 0, 0, 5, 0, 0, 6],\n              [0, 0, 0, 0, 0, 0, 0],\n              [0, 0, 0, 0, 0, 0, 0],\n              [7, 0, 0, 8, 0, 0, 9]]\n    \"\"\"\n    if s <= 0:\n        return image\n    matrix = torch.zeros(size=(image.size(0), image.size(1), image.size(2) * (s+1) - s, image.size(3) * (s+1) - s), dtype=image.dtype, device=image.device)\n    matrix_stride = matrix.stride()\n    selected_matrix = torch.as_strided(matrix, [\n          # Shape of the output matrix.\n          matrix.size(0),  # Batch size.\n          matrix.size(1),  # Channel.\n          image.size(2),  # H (without zeros)\n          image.size(3),  # W (without zeros)\n          ], [\n          # Stride of the output matrix.\n          matrix_stride[0],  # Batch size dimension, keep using the old stride.\n          matrix_stride[1],  # Channel dimension.\n          matrix_stride[2] * (s + 1),  # Move s+1 rows.\n          s+1,  # Move s+1 pixels.\n    ])  # Move a pixel (on the width direction).\n    selected_matrix[:] = image\n    return matrix\n\n\ndef remove_zeros(image, s, remove_zero_start_idx=(0,0)):\n    if s <= 0:\n        return image\n    matrix_stride = image.stride()\n    storage_offset = image.storage_offset()\n    return torch.as_strided(image, [\n        # Shape of the output matrix.\n        *image.shape[:-2],\n        (image.size(-2) - remove_zero_start_idx[-2] + (s + 1) - 1) // (s + 1),  # H (without zeros)\n        (image.size(-1) - remove_zero_start_idx[-1] + (s + 1) - 1) // (s + 1),  # W (without zeros)\n        ], [\n        # Stride of the output matrix.\n        *matrix_stride[:-2],\n        matrix_stride[-2] * (s + 1),  # Move s+1 rows.\n        matrix_stride[-1] * (s + 1),  # Move s+1 pixels.\n        ],\n        storage_offset + 
matrix_stride[-2] * remove_zero_start_idx[-2] + matrix_stride[-1] * remove_zero_start_idx[-1]\n    )\n\n\ndef unify_shape(shape):\n    \"\"\"\n    Convert shapes to 4-tuple: (left, right, top, bottom).\n    \"\"\"\n    if shape is not None:\n        if isinstance(shape, int):\n            # Same on all four directions.\n            shape = (shape, shape, shape, shape)\n        if len(shape) == 2:\n            # (height direction, width direction).\n            shape = (shape[1], shape[1], shape[0], shape[0])\n        assert len(shape) == 4\n    # Returned: (left, right, top, bottom).\n    return shape\n\n\ndef simplify_shape(shape):\n    \"\"\"\n    Convert shapes to 2-tuple or a single number.\n    Used to avoid extra padding operation because the padding\n    operation in F.conv2d is not general enough.\n    \"\"\"\n    if len(shape) == 4:\n        # 4-tuple: (left, right, top, bottom).\n        if shape[0] == shape[1] and shape[2] == shape[3]:\n            shape = (shape[2], shape[0])\n    if len(shape) == 2:\n        # 2-tuple: (height direction, width direction).\n        if shape[0] == shape[1]:\n            shape = shape[0]\n    return shape\n\n\ndef is_shape_used(shape, expected=0):\n    if isinstance(shape, int):\n        return shape != expected\n    else:\n        return sum(shape) != expected\n\n\nclass Patches:\n    \"\"\"\n    A special class which denotes a convolutional operator as a group of patches\n    the shape of Patches.patches is [batch_size, num_of_patches, out_channel, in_channel, M, M]\n    M is the size of a single patch\n    Assume that we have a conv2D layer with w.weight(out_channel, in_channel, M, M), stride and padding applied on an image (N * N)\n    num_of_patches = ((N + padding * 2 - M)//stride + 1) ** 2\n    Here we only consider kernels with the same H and W\n    \"\"\"\n    def __init__(\n            self, patches=None, stride=1, padding=0, shape=None, identity=0,\n            unstable_idx=None, output_shape=None, 
inserted_zeros=0, output_padding=0, input_shape=None):\n        # Shape: [batch_size, num_of_patches, out_channel, in_channel, M, M]\n        # M is the size of a single patch\n        # Assume that we have a conv2D layer with w.weight(out_channel, in_channel, M, M), stride and padding applied on an image (N * N)\n        # num_of_patches = ((N + padding * 2 - M)//stride + 1) ** 2\n        # Here we only consider kernels with the same H and W\n        self.patches = patches\n        self.stride = stride\n        self.padding = padding\n        self.shape = shape\n        self.identity = identity\n        self.unstable_idx = unstable_idx\n        self.output_shape = output_shape\n        self.input_shape = input_shape\n        self.inserted_zeros = inserted_zeros\n        self.output_padding = output_padding\n        self.simplify()\n\n    def __add__(self, other):\n        if isinstance(other, Patches):\n            # Insert images with zero to make stride the same, if necessary.\n            assert self.stride == other.stride\n            if self.unstable_idx is not None or other.unstable_idx is not None:\n                if self.unstable_idx is not other.unstable_idx:  # Same tuple object.\n                    raise ValueError('Please set bound option \"sparse_conv_intermediate_bounds\" to False to run this model.')\n                assert self.output_shape == other.output_shape\n            A1 = self.patches\n            A2 = other.patches\n            # change paddings to merge the two patches\n            sp = torch.tensor(unify_shape(self.padding))\n            op = torch.tensor(unify_shape(other.padding))\n            if (sp - op).abs().sum().item() > 0:\n                if (sp - op >= 0).all():\n                    A2 = F.pad(A2, (sp - op).tolist())\n                elif (sp - op <= 0).all():\n                    A1 = F.pad(A1, (op - sp).tolist())\n                else:\n                    raise ValueError(\"Unsupported padding 
size\")\n            ret = A1 + A2\n            return Patches(ret, other.stride, torch.max(sp, op).tolist(),\n                           ret.shape, unstable_idx=self.unstable_idx, output_shape=self.output_shape,\n                           inserted_zeros=self.inserted_zeros, output_padding=self.output_padding)\n        else:\n            assert self.inserted_zeros == 0\n            assert not is_shape_used(self.output_padding)\n            # Patches has shape (out_c, batch, out_h, out_w, in_c, h, w).\n            input_shape = other.shape[3:]\n            matrix = other\n            pieces = self.patches\n            if pieces.ndim == 9:\n                pieces = pieces.transpose(0, 1)\n                pieces = pieces.view(pieces.shape[0], -1, pieces.shape[3], pieces.shape[4], pieces.shape[5]*pieces.shape[6], pieces.shape[7], pieces.shape[8]).transpose(0,1)\n            if pieces.ndim == 8:\n                pieces = pieces.transpose(0, 1)\n                pieces = pieces.view(pieces.shape[0], -1, pieces.shape[3], pieces.shape[4], pieces.shape[5], pieces.shape[6], pieces.shape[7]).transpose(0,1)\n            A1_matrix = patches_to_matrix(\n                pieces, input_shape, self.stride, self.padding,\n                output_shape=self.output_shape, unstable_idx=self.unstable_idx)\n            return A1_matrix.transpose(0, 1) + matrix\n\n    def __str__(self):\n        return (\n                f\"Patches(stride={self.stride}, padding={self.padding}, \"\n                f\"output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}, \"\n                f\"kernel_shape={list(self.patches.shape)}, input_shape={self.input_shape}, \"\n                f\"output_shape={self.output_shape}, unstable_idx={type(self.unstable_idx)})\"\n        )\n\n    @property\n    def device(self):\n        if self.patches is not None:\n            return self.patches.device\n        if self.unstable_idx is not None:\n            if isinstance(self.unstable_idx, tuple):\n   
             return self.unstable_idx[0].device\n            else:\n                return self.unstable_idx.device\n        raise RuntimeError(\"Patches object is uninitialized and cannot determine its device.\")\n\n    def create_similar(self, patches=None, stride=None, padding=None, identity=None,\n                       unstable_idx=None, output_shape=None, inserted_zeros=None, output_padding=None,\n                       input_shape=None):\n        \"\"\"\n        Create a new Patches object with new patches weights, and keep other properties the same.\n        \"\"\"\n        new_patches = self.patches.clone() if patches is None else patches\n        new_identity = self.identity if identity is None else identity\n        if new_identity and (new_patches is not None):\n            raise ValueError(\"Identity Patches should have .patches property set to 0.\")\n        return Patches(\n            new_patches,\n            stride=self.stride if stride is None else stride,\n            padding=self.padding if padding is None else padding,\n            shape=new_patches.shape,\n            identity=new_identity,\n            unstable_idx=self.unstable_idx if unstable_idx is None else unstable_idx,\n            output_shape=self.output_shape if output_shape is None else output_shape,\n            inserted_zeros=self.inserted_zeros if inserted_zeros is None else inserted_zeros,\n            output_padding=self.output_padding if output_padding is None else output_padding,\n            input_shape=self.input_shape if input_shape is None else input_shape,\n        )\n    \n    def clone(self):\n        return self.create_similar()\n    \n    def detach(self):\n        new_obj = Patches(\n            patches=self.patches.detach() if self.patches is not None else None,\n            stride=self.stride,\n            padding=self.padding,\n            shape=self.shape,\n            identity=self.identity,\n            unstable_idx=(\n                tuple(idx.detach() for idx in self.unstable_idx)\n                if isinstance(self.unstable_idx, tuple)\n                else self.unstable_idx.detach()\n            ) if self.unstable_idx is not None else None,\n            output_shape=self.output_shape,\n            inserted_zeros=self.inserted_zeros,\n            output_padding=self.output_padding,\n            input_shape=self.input_shape,\n        )\n        return new_obj\n\n    def to_matrix(self, input_shape):\n        assert not is_shape_used(self.output_padding)\n        return patches_to_matrix(\n            self.patches, input_shape, self.stride, self.padding,\n            self.output_shape, self.unstable_idx, self.inserted_zeros\n        )\n\n    def simplify(self):\n        \"\"\"Merge stride and inserted_zeros; if they are the same they can cancel out.\"\"\"\n        stride = [self.stride, self.stride] if isinstance(self.stride, int) else self.stride\n        if (self.inserted_zeros > 0 and self.inserted_zeros + 1 == stride[0] and\n                stride[0] == stride[1] and (self.patches.size(-1) % stride[1]) == 0 and (self.patches.size(-2) % stride[0]) == 0):\n            # print(f'before simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}')\n            full_stride = [stride[1], stride[1], stride[0], stride[0]]\n            # output_padding = tuple(p // s for p, s in zip(output_padding, full_stride))\n            padding = unify_shape(self.padding)\n            # Since zero insertion does not append zeros at the ends (i.e., [x 0 0 x 0 0 x] instead of [x 0 0 x 0 0 x 0 0]),\n            # when computing the simplified padding we should treat the (stride - 1) padding entries at one end\n            # as part of the inserted-zero pattern (i.e., \"consumed\").\n            consumed_padding = (padding[0], padding[1] - (stride[1] - 1), padding[2], padding[3] - (stride[0] - 1))\n            tentative_padding = 
tuple(p // s - o for p, s, o in zip(consumed_padding, full_stride, unify_shape(self.output_padding)))\n            # negative padding is inconvenient\n            if all([p >= 0 for p in tentative_padding]):\n                remove_zero_start_idx = (padding[2] % stride[0], padding[0] % stride[1])\n                self.padding = tentative_padding\n                self.patches = remove_zeros(self.patches, self.inserted_zeros, remove_zero_start_idx=remove_zero_start_idx)\n                self.stride = 1\n                self.inserted_zeros = 0\n                self.output_padding = 0\n                # print(f'after simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}')\n\n    def matmul(self, input, patch_abs=False, input_shape=None):\n        \"\"\"\n        Broadcast multiplication for patches and a matrix.\n\n        Input shape: (batch_size, in_c, in_h, in_w).\n        If any of the in_c, in_h, in_w dimensions has size 1, the input will be expanded to the given input_shape to support broadcasting.\n\n        Output shape: [batch_size, unstable_size] when unstable_idx is not None,\n                      [batch_size, out_c, out_h, out_w] when unstable_idx is None.\n        \"\"\"\n\n        patches = self.patches\n        if patch_abs:\n            patches = patches.abs()\n\n        if input_shape is not None:\n            # For cases where the input has singleton dimensions, like (1, in_c, 1, 1).\n            input = input.expand(input_shape)\n            # Expand to (batch_size, in_c, in_h, in_w)\n\n        # unfold the input as [batch_size, out_h, out_w, in_c, H, W]\n        unfold_input = inplace_unfold(\n            input, kernel_size=patches.shape[-2:],\n            padding=self.padding, stride=self.stride,\n            inserted_zeros=self.inserted_zeros, output_padding=self.output_padding)\n        if self.unstable_idx is not None:\n            # We need to add an out_c dimension 
and select from it.\n            unfold_input = unfold_input.unsqueeze(0).expand(self.output_shape[1], -1, -1, -1, -1, -1, -1)\n            # Shape: [unstable_size, batch_size, in_c, H, W].\n            # Here unfold_input will match this shape.\n            unfold_input = unfold_input[self.unstable_idx[0], :, self.unstable_idx[1], self.unstable_idx[2]]\n            # shape: [batch_size, unstable_size].\n            return torch.einsum('sbchw,sbchw->bs', unfold_input, patches)\n        else:\n            # shape: [batch_size, out_c, out_h, out_w].\n            return torch.einsum('bijchw,sbijchw->bsij', unfold_input, patches)\n\n    def create_padding(self, output_shape):\n        # patches was not padded, so we need to pad them here.\n        # If this layer is followed by a ReLU layer, then the padding was already handled there and there is no need to pad again.\n        one_d_unfolded_r = create_valid_mask(\n            output_shape, self.patches.device,\n            self.patches.dtype,\n            self.patches.shape[-2:],\n            self.stride,\n            self.inserted_zeros,\n            self.padding,\n            self.output_padding,\n            self.unstable_idx if self.unstable_idx else None)\n        patches = self.patches * one_d_unfolded_r\n        return patches\n\n\ndef compute_patches_stride_padding(input_shape, patches_padding, patches_stride, op_padding, op_stride, inserted_zeros=0, output_padding=0, simplify=True):\n    \"\"\"\n    Compute stride and padding after a conv layer with patches mode.\n    \"\"\"\n    for p in (patches_padding, patches_stride, op_padding, op_stride):\n        assert isinstance(p, int) or (isinstance(p, (list, tuple)) and (len(p) == 2 or len(p) == 4))\n    # If p is int, then same padding on all 4 sides.\n    # If p is 2-tuple, then it is padding p[0] on both sides of H, p[1] on both sides of W\n    # If p is 4-tuple, then it is padding p[2], p[3] on top and bottom sides of H, p[0] and p[1] on left and right sides 
of W\n\n    # If any of the inputs are not tuple/list, we convert them to tuple.\n    full_patch_padding, full_op_padding, full_patch_stride, full_op_stride = [\n            (p, p) if isinstance(p, int) else p for p in [patches_padding, op_padding, patches_stride, op_stride]]\n    full_patch_padding, full_op_padding, full_patch_stride, full_op_stride = [\n            (p[1], p[1], p[0], p[0]) if len(p) == 2 else p for p in [full_patch_padding, full_op_padding, full_patch_stride, full_op_stride]]\n    # Compute the new padding and stride after this layer.\n    new_padding = tuple(pp * os + op * (inserted_zeros + 1) for pp, op, os in zip(full_patch_padding, full_op_padding, full_op_stride))\n    new_stride = tuple(ps * os for ps, os in zip(full_patch_stride, full_op_stride))\n\n    output_padding = unify_shape(output_padding)\n    new_output_padding = (output_padding[0],  # Left\n          output_padding[1] + inserted_zeros * input_shape[3] % full_op_stride[2],  # Right\n          output_padding[2],  # Top\n          output_padding[3] + inserted_zeros * input_shape[2] % full_op_stride[0])  # Bottom\n\n    # Merge into a single number if all numbers are identical.\n    if simplify:\n        if new_padding.count(new_padding[0]) == len(new_padding):\n            new_padding = new_padding[0]\n        if new_stride.count(new_stride[0]) == len(new_stride):\n            new_stride = new_stride[0]\n\n    return new_padding, new_stride, new_output_padding\n\n\ndef patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None,\n                      unstable_idx=None, inserted_zeros=0):\n    \"\"\"Converting a Patches piece into a full dense matrix.\"\"\"\n\n    # torch.as_strided may cause unpredictable error under deterministic mode,\n    # so we temporarily disable it.\n    deterministic = torch.are_deterministic_algorithms_enabled()\n    torch.use_deterministic_algorithms(False)\n\n    if type(padding) == int:\n        padding = (padding, padding, padding, 
padding)\n\n    if pieces.ndim == 9:\n        # Squeeze two additional dimensions for output and input respectively\n        assert pieces.shape[1] == 1 and pieces.shape[5] == 1\n        pieces = pieces.reshape(\n            pieces.shape[0], *pieces.shape[2:5],\n            *pieces.shape[6:]\n        )\n\n    if unstable_idx is None:\n        assert pieces.ndim == 7\n        # Non-sparse pieces, with shape (out_c, batch, out_h, out_w, c, h, w).\n        output_channel, batch_size, output_x, output_y = pieces.shape[:4]\n    else:\n        batch_size = pieces.shape[1]\n        output_channel, output_x, output_y = output_shape[1:]\n    input_channel, kernel_x, kernel_y = pieces.shape[-3:]\n    input_x, input_y = input_shape[-2:]\n\n    if inserted_zeros > 0:\n        input_x, input_y = (input_x - 1) * (inserted_zeros + 1) + 1, (input_y - 1) * (inserted_zeros + 1) + 1\n\n    if unstable_idx is None:\n        # Fill all patches into a full A matrix.\n        A_matrix = torch.zeros(batch_size, output_channel, output_x, output_y, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype)\n        # Save its original stride.\n        orig_stride = A_matrix.stride()\n        # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution.\n        # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.\n        matrix_strided = torch.as_strided(A_matrix, [batch_size, output_channel, output_x, output_y, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], orig_stride[2], orig_stride[3], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[4], input_y + padding[0] + padding[1], 1])\n        # Now we need to fill the conv kernel parameters into the last three dimensions of matrix_strided.\n        first_indices = torch.arange(output_x * output_y, 
device=pieces.device)\n        second_indices = torch.div(first_indices, output_y, rounding_mode=\"trunc\")\n        third_indices = torch.fmod(first_indices, output_y)\n        # pieces have shape (out_c, batch, out_h, out_w, c, h, w).\n        pieces = pieces.transpose(0, 1)   # pieces has the out_c dimension at the front, need to move it to the second.\n        matrix_strided[:,:,second_indices,third_indices,second_indices,third_indices,:,:,:] = pieces.reshape(*pieces.shape[:2], -1, *pieces.shape[4:])\n        A_matrix = A_matrix.view(batch_size, output_channel * output_x * output_y, input_channel, input_x + padding[2] + padding[3], input_y + padding[0] + padding[1])\n    else:\n        # Fill only a selection of patches.\n        # Create only a partial A matrix.\n        unstable_size = unstable_idx[0].numel()\n        A_matrix = torch.zeros(batch_size, unstable_size, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype)\n        # Save its original stride.\n        orig_stride = A_matrix.stride()\n        # This is the main trick - we create a *view* of the original matrix, and it contains all sliding windows for the convolution.\n        # Since we only created a view (in fact, only metadata of the matrix changed), it should be very efficient.\n        matrix_strided = torch.as_strided(A_matrix, [batch_size, unstable_size, output_x, output_y, input_channel, kernel_x, kernel_y], [orig_stride[0], orig_stride[1], (input_x + padding[2] + padding[3]) * stride, stride, orig_stride[2], input_y + padding[0] + padding[1], 1])\n        # pieces have shape (unstable_size, batch, c, h, w).\n        first_indices = torch.arange(unstable_size, device=pieces.device)\n        matrix_strided[:,first_indices,unstable_idx[1],unstable_idx[2],:,:,:] = pieces.transpose(0, 1).to(matrix_strided)\n        A_matrix = A_matrix.view(batch_size, unstable_size, input_channel, input_x + padding[2] + padding[3], 
input_y + padding[0] + padding[1])\n\n    A_matrix = A_matrix[:,:,:,padding[2]:input_x + padding[2],padding[0]:input_y + padding[0]]\n\n    if inserted_zeros > 0:\n        A_matrix = A_matrix[:,:,:, ::(inserted_zeros+1), ::(inserted_zeros+1)]\n\n    # Restore the previous deterministic algorithms setting.\n    torch.use_deterministic_algorithms(deterministic)\n\n    return A_matrix\n\n\ndef check_patch_biases(lb, ub, lower_b, upper_b):\n    # When we use patches mode, it's possible that we need to add two biases:\n    # one from the Tensor mode and one from the patches mode.\n    # We need to detect this case and reshape the biases accordingly.\n    if lower_b.ndim < lb.ndim:\n        lb = lb.transpose(0,1).reshape(lb.size(1), lb.size(0), -1)\n        lb = lb.expand(lb.size(0), lb.size(1), lower_b.size(0)//lb.size(1))\n        lb = lb.reshape(lb.size(0), -1).t()\n        ub = ub.transpose(0,1).reshape(ub.size(1), ub.size(0), -1)\n        ub = ub.expand(ub.size(0), ub.size(1), upper_b.size(0)//ub.size(1))\n        ub = ub.reshape(ub.size(0), -1).t()\n    elif lower_b.ndim > lb.ndim:\n        lower_b = lower_b.transpose(0, 1).reshape(lower_b.size(1), -1).t()\n        upper_b = upper_b.transpose(0, 1).reshape(upper_b.size(1), -1).t()\n    return lb, ub, lower_b, upper_b\n\n\ndef inplace_unfold(image, kernel_size, stride=1, padding=0, inserted_zeros=0, output_padding=0):\n    # Image has size (batch_size, channel, height, width).\n    assert image.ndim == 4\n    if isinstance(kernel_size, int):\n        kernel_size = (kernel_size, kernel_size)\n    if isinstance(padding, int):\n        padding = (padding, padding, padding, padding)  # (left, right, top, bottom).\n    if len(padding) == 2:  # (height direction, width direction).\n        padding = (padding[1], padding[1], padding[0], padding[0])\n    if isinstance(output_padding, int):\n        output_padding = (output_padding, output_padding, output_padding, output_padding)  # (left, right, top, bottom).\n    if len(output_padding) == 2:  # 
(height direction, width direction).\n        output_padding = (output_padding[1], output_padding[1], output_padding[0], output_padding[0])\n    if isinstance(stride, int):\n        stride = (stride, stride)  # (height direction, width direction).\n    assert len(kernel_size) == 2 and len(padding) == 4 and len(stride) == 2\n    # Make sure the image is large enough for the kernel.\n    assert image.size(2) + padding[2] + padding[3] >= kernel_size[0] and image.size(3) + padding[0] + padding[1] >= kernel_size[1]\n    if inserted_zeros > 0:\n        # We first need to insert zeros in the image before unfolding.\n        image = insert_zeros(image, inserted_zeros)\n        # padding = (padding[0], padding[1] + 1, padding[2], padding[3] + 1)\n    # Compute the number of patches.\n    # Formulation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold\n    patches_h = int((image.size(2) + padding[2] + padding[3] - (kernel_size[0] - 1) - 1) / stride[0] + 1)\n    patches_w = int((image.size(3) + padding[0] + padding[1] - (kernel_size[1] - 1) - 1) / stride[1] + 1)\n    # Pad image.\n    if sum(padding) != 0:\n        image = torch.nn.functional.pad(image, padding)\n    # Save its original stride.\n    image_stride = image.stride()\n    matrix_strided = torch.as_strided(image, [\n        # Shape of the output matrix.\n        image.size(0),  # Batch size.\n        patches_h,  # indices for each patch.\n        patches_w,\n        image.size(1),  # Channel.\n        kernel_size[0],   # indices for each pixel on a patch.\n        kernel_size[1]], [\n        # Stride of the output matrix.\n        image_stride[0],  # Batch size dimension, keep using the old stride.\n        image_stride[2] * stride[0],  # Move patch in the height dimension.\n        image_stride[3] * stride[1],  # Move patch in the width dimension.\n        image_stride[1],  # Move to the next channel.\n        image_stride[2],  # Move to the next row.\n        image_stride[3]])  # Move 
a pixel (on the width direction).\n    # Output shape is (batch_size, patches_h, patches_w, channel, kernel_height, kernel_width)\n    if sum(output_padding) > 0:\n        output_padding = tuple(p if p > 0 else None for p in output_padding)\n        matrix_strided = matrix_strided[:, output_padding[2]:-output_padding[3] if output_padding[3] is not None else None,\n                                        output_padding[0]:-output_padding[1] if output_padding[1] is not None else None, :, :, :]\n    return matrix_strided\n\n\ndef maybe_unfold_patches(d_tensor, last_A, alpha_lookup_idx=None):\n    \"\"\"\n    Utility function to handle patch mode bound propagation in activation functions.\n    In patches mode, we need to unfold lower and upper slopes (as input \"d_tensor\").\n    In matrix mode we simply return.\n    \"\"\"\n    if d_tensor is None or last_A is None or isinstance(last_A, Tensor):\n        return d_tensor\n\n    # Shape for d_tensor:\n    #   sparse: [spec, batch, in_c, in_h, in_w]\n    #   non-sparse (partially shared): [out_c, batch, in_c, in_h, in_w]\n    #   non-sparse (not shared): [out_c*out_h*out_w, batch, in_c, in_h, in_w]\n    #   shared (independent of output spec): [1, batch, in_c, in_h, in_w]\n    # The in_h, in_w dimensions must be unfolded as patches.\n    origin_d_shape = d_tensor.shape\n    if d_tensor.ndim == 6:\n        # Merge the (out_h, out_w) dimensions.\n        d_tensor = d_tensor.view(*origin_d_shape[:2], -1, *origin_d_shape[-2:])\n    d_shape = d_tensor.size()\n    # Reshape to 4-D tensor to unfold.\n    d_tensor = d_tensor.view(-1, *d_tensor.shape[-3:])\n    # unfold the slope matrix as patches. 
Patch shape is (spec * batch, out_h, out_w, in_c, H, W).\n    d_unfolded = inplace_unfold(\n        d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride,\n        padding=last_A.padding, inserted_zeros=last_A.inserted_zeros,\n        output_padding=last_A.output_padding)\n    # Reshape to the original shape of d, e.g., for non-sparse it is (out_c, batch, out_h, out_w, in_c, H, W).\n    d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:])\n    if last_A.unstable_idx is not None:\n        # Here we have d for all output neurons, but we only need to select unstable ones.\n        if d_unfolded_r.size(0) == 1 and alpha_lookup_idx is None:\n            # Shared alpha; sparse alpha should not be used.\n            # Note: d_unfolded_r.size(0) == 1 alone does not imply a shared alpha,\n            #   since the activation may have no unstable neurons at all, so\n            #   the first dim (1 + number of unstable neurons) can still equal 1.\n            if len(last_A.unstable_idx) == 3:\n                # Broadcast the spec shape, so only need to select the rest dimensions.\n                # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W).\n                d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)\n                d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]]\n            elif len(last_A.unstable_idx) == 4:\n                # [spec, batch, output_h, output_w, input_c, H, W]\n                # to [output_h, output_w, batch, in_c, H, W]\n                d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5)\n                d_unfolded_r = d_unfolded_r[last_A.unstable_idx[2], last_A.unstable_idx[3]]\n            else:\n                raise NotImplementedError()\n            # output shape: (unstable_size, batch, in_c, H, W).\n        else:\n            # The spec dimension may be sparse and contains unstable neurons for the spec layer only.\n            if alpha_lookup_idx is None:\n                # alpha is spec-dense. Possible because the number of unstable neurons may decrease.\n                if last_A.output_shape[1] == d_unfolded_r.size(0):\n                    # Non spec-sparse, partially shared alpha among output channel dimension.\n                    # Shape after unfolding is (out_c, batch, out_h, out_w, in_c, patch_h, patch_w).\n                    d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                else:\n                    # Non spec-sparse, non-shared alpha.\n                    # Shape after unfolding is (out_c*out_h*out_w, batch, out_h, out_w, in_c, patch_h, patch_w).\n                    # Reshaped to (out_c, out_h, out_w, batch, out_h, out_w, in_c, patch_h, patch_w).\n                    d_unfolded_r = d_unfolded_r.view(last_A.shape[0], last_A.shape[2], last_A.shape[3], -1, *d_unfolded_r.shape[2:])\n                    # Select on all out_c, out_h, out_w dimensions.\n                    d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], last_A.unstable_idx[1],\n                            last_A.unstable_idx[2], :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n            elif alpha_lookup_idx.ndim == 1:\n                # sparse alpha: [spec, batch, in_c, in_h, in_w]\n                # Partially shared alpha on the spec dimension - all output neurons on the same channel use the same alpha.\n                # If alpha_lookup_idx is not None, we need to convert the sparse indices using alpha_lookup_idx.\n                _unstable_idx = alpha_lookup_idx[last_A.unstable_idx[0]]\n                # The selection is only used on the channel dimension.\n                d_unfolded_r = d_unfolded_r[_unstable_idx, :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n            elif alpha_lookup_idx is not None and alpha_lookup_idx.ndim == 3:\n                # sparse alpha: [spec, batch, in_c, in_h, in_w]\n                # We 
created alpha as full output shape; alpha not shared among channel dimension.\n                # Shape of alpha is (out_c*out_h*out_w, batch, in_c, in_h, in_w), note that the first 3 dimensions\n                # is merged into one to allow simpler selection.\n                _unstable_idx = alpha_lookup_idx[\n                    last_A.unstable_idx[0],\n                    last_A.unstable_idx[1],\n                    last_A.unstable_idx[2]]\n                # d_unfolded_r shape from (out_c, batch, out_h, out_w, in_c, in_h, in_w)\n                # to (out_c * out_h * out_w(sparse), batch, in_c, in_h, in_w)\n                # Note that the dimensions out_h, out_w come from unfolding, not specs in alpha, so they will be selected\n                # directly without translating using the lookup table.\n                d_unfolded_r = d_unfolded_r[_unstable_idx, :, last_A.unstable_idx[1], last_A.unstable_idx[2]]\n                # after selection we return (unstable_size, batch_size, in_c, H, W)\n                return d_unfolded_r\n            else:\n                raise ValueError\n    else:\n        # A is not sparse. 
Alpha shouldn't be sparse either.\n        assert alpha_lookup_idx is None\n        if last_A.patches.size(0) != d_unfolded_r.size(0) and d_unfolded_r.size(0) != 1:\n            # Non-shared alpha, shape after unfolding is (out_c*out_h*out_w, batch, out_h, out_w, in_c, patch_h, patch_w).\n            # Reshaped to (out_c, out_h*out_w, batch, out_h*out_w, in_c, patch_h, patch_w).\n            d_unfolded_r = d_unfolded_r.reshape(last_A.shape[0], last_A.shape[2] * last_A.shape[3], -1,\n                    d_unfolded_r.shape[2] * d_unfolded_r.shape[3], *d_unfolded_r.shape[4:])\n            # Select the \"diagonal\" elements in the out_h*out_w dimension.\n            # New shape is (out_c, batch, in_c, patch_h, patch_w, out_h*out_w)\n            d_unfolded_r = d_unfolded_r.diagonal(offset=0, dim1=1, dim2=3)\n            # New shape is (out_c, batch, in_c, patch_h, patch_w, out_h, out_w)\n            d_unfolded_r = d_unfolded_r.view(*d_unfolded_r.shape[:-1], last_A.shape[2], last_A.shape[3])\n            # New shape is (out_c, batch, out_h, out_w, in_c, patch_h, patch_w)\n            d_unfolded_r = d_unfolded_r.permute(0, 1, 5, 6, 2, 3, 4)\n\n\n    # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W).\n    # For regular patches, the shape after unfold is (out_c, batch, out_h, out_w, in_c, H, W).\n    if d_unfolded_r.ndim != last_A.patches.ndim:\n        # For the case where d is independent of the output neuron (e.g., the vanilla CROWN bound), d does not have\n        # the out_h, out_w dimensions and out_c = 1 (spec). 
We add 1s for the out_h, out_w dimensions.\n        d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4)\n    return d_unfolded_r\n\ndef create_valid_mask(output_shape, device, dtype, kernel_size, stride, inserted_zeros, padding, output_padding,\n                      unstable_idx=None):\n    \"\"\"\n        Create a 0-1 mask with the shape of the patch pieces (except the batch dim),\n        where 1 indicates the cells corresponding to valid image pixels.\n        Can be used to mask out unused A cells.\n    :return: tensor of patch pieces shape, containing the binary mask\n    \"\"\"\n    one_d = torch.ones(\n        tuple(1 for i in output_shape[1:]),\n        device=device, dtype=dtype\n    ).expand(output_shape[1:])\n    # Add batch dimension.\n    one_d = one_d.unsqueeze(0)\n    # After unfolding, the shape is (1, out_h, out_w, in_c, h, w)\n    one_d_unfolded = inplace_unfold(\n        one_d, kernel_size=kernel_size,\n        stride=stride, padding=padding,\n        inserted_zeros=inserted_zeros,\n        output_padding=output_padding)\n    if unstable_idx is not None:\n        # Move out_h, out_w dimension to the front for easier selection.\n        ans = one_d_unfolded.permute(1, 2, 0, 3, 4, 5)\n        # For sparse patches the shape is (unstable_size, batch, in_c, h, w).\n        # Batch size is 1 so no need to select here.\n        ans = ans[unstable_idx[1], unstable_idx[2]]\n    else:\n        # Append the spec dimension.\n        ans = one_d_unfolded.unsqueeze(0)\n    return ans\n"
  },
  {
    "path": "auto_LiRPA/perturbations.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport json\nimport math\nimport os\nimport numpy as np\nimport torch\nfrom .utils import logger, eyeC\nfrom .patches import Patches, patches_to_matrix\nfrom .linear_bound import LinearBound\n\nfrom .concretize_func import constraints_solving, sort_out_constr_batches, construct_constraints\n\nclass Perturbation:\n    r\"\"\"\n    Base class for a perturbation specification. 
Please see examples\n    at `auto_LiRPA/perturbations.py`.\n\n    Examples:\n\n    * `PerturbationLpNorm`: Lp-norm (p>=1) perturbation.\n\n    * `PerturbationL0Norm`: L0-norm perturbation.\n\n    * `PerturbationSynonym`: Synonym substitution perturbation for NLP.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def set_eps(self, eps):\n        self.eps = eps\n\n    def concretize(self, x, A, sign=-1, aux=None):\n        r\"\"\"\n        Concretize bounds according to the perturbation specification.\n\n        Args:\n            x (Tensor): Input before perturbation.\n\n            A (Tensor): A matrix from LiRPA computation.\n\n            sign (-1 or +1): If -1, concretize for lower bound; if +1, concretize for upper bound.\n\n            aux (object, optional): Auxiliary information for concretization.\n\n        Returns:\n            bound (Tensor): Concretized bound with the same shape as the clean output.\n        \"\"\"\n        raise NotImplementedError\n\n    def init(self, x, aux=None, forward=False):\n        r\"\"\"\n        Initialize bounds before LiRPA computation.\n\n        Args:\n            x (Tensor): Input before perturbation.\n\n            aux (object, optional): Auxiliary information.\n\n            forward (bool): Whether forward mode LiRPA is involved.\n\n        Returns:\n            bound (LinearBound): Initialized bounds.\n\n            center (Tensor): Center of perturbation. It can simply be `x`, or some other value.\n\n            aux (object, optional): Auxiliary information. 
Bound initialization may modify or add auxiliary information.\n        \"\"\"\n\n        raise NotImplementedError\n\n\nclass PerturbationL0Norm(Perturbation):\n    \"\"\"Perturbation constrained by the L_0 norm.\n\n    Assuming input data is in the range of 0-1.\n    \"\"\"\n\n    def __init__(self, eps, x_L=None, x_U=None, ratio=1.0):\n        self.eps = eps\n        self.x_U = x_U\n        self.x_L = x_L\n        self.ratio = ratio\n\n    def concretize(self, x, A, sign=-1, aux=None):\n        if A is None:\n            return None\n\n        eps = math.ceil(self.eps)\n        x = x.reshape(x.shape[0], -1, 1)\n        center = A.matmul(x)\n\n        x = x.reshape(x.shape[0], 1, -1)\n\n        original = A * x.expand(x.shape[0], A.shape[-2], x.shape[2])\n        neg_mask = A < 0\n        pos_mask = A >= 0\n\n        if sign == 1:\n            A_diff = torch.zeros_like(A)\n            A_diff[pos_mask] = A[pos_mask] - original[pos_mask]  # The change that one weight can contribute to the value.\n            A_diff[neg_mask] = - original[neg_mask]\n        else:\n            A_diff = torch.zeros_like(A)\n            A_diff[pos_mask] = original[pos_mask]\n            A_diff[neg_mask] = original[neg_mask] - A[neg_mask]\n\n        # FIXME: this assumes the input pixel range is between 0 and 1!\n        A_diff, _ = torch.sort(A_diff, dim=2, descending=True)\n\n        bound = center + sign * A_diff[:, :, :eps].sum(dim=2).unsqueeze(2) * self.ratio\n\n        return bound.squeeze(2)\n\n    def init(self, x, aux=None, forward=False):\n        # For other norms, we pass in the BoundedTensor objects directly.\n        x_L = x\n        x_U = x\n        if not forward:\n            return LinearBound(None, None, None, None, x_L, x_U), x, None\n        batch_size = x.shape[0]\n        dim = x.reshape(batch_size, -1).shape[-1]\n        eye = torch.eye(dim).to(x.device).unsqueeze(0).repeat(batch_size, 1, 1)\n        lw = eye.reshape(batch_size, dim, *x.shape[1:])\n        lb = 
torch.zeros_like(x).to(x.device)\n        uw, ub = lw.clone(), lb.clone()\n        return LinearBound(lw, lb, uw, ub, x_L, x_U), x, None\n\n    def __repr__(self):\n        return 'PerturbationLpNorm(norm=0, eps={})'.format(self.eps)\n\n\nclass PerturbationLpNorm(Perturbation):\n    \"\"\"Perturbation constrained by the L_p norm.\"\"\"\n    def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None, eps_min=0, \n                 constraints=None, rearrange_constraints=False, no_return_inf=False, timer=None):\n        r\"\"\"\n        Initialize a p-norm perturbation instance.\n        There are two ways to initialize it:\n            -- x_L, x_U: (Higher priority)\n            -- eps     : (Lower priority)\n        If eps is used to initialize it, the centroid x (or x0 as in the member attribute) will be\n            passed into the `init` and `concretize` functions.\n        For the shape notations such as 'B' or 'X', please check the shape declaration\n            at the beginning of concretize_func.py.\n\n        Args:\n            eps (Tensor): The epsilon tensor; it represents the perturbation added to a BoundedTensor.\n            norm (int or torch.inf): The p in p-norm perturbation.\n            x_L (Tensor): Lower bound of input box, shape (B, *input_shape[1:]).\n            x_U (Tensor): Upper bound of input box, shape (B, *input_shape[1:]).\n            eps_min ()\n            constraints (Tuple[Tensor, Tensor] or None):\n                A tuple `(A, b)` representing per-batch linear constraints.\n                - `A`: shape (B, N_constr, X)\n                - `b`: shape (B, N_constr)\n            rearrange_constraints (bool):\n                Whether to rearrange constraints for better solver performance. Default: False.\n            no_return_inf (bool):\n                If True, infeasible batches will be excluded from `active_indices`.\n                Otherwise, infeasible batches are still marked active. 
Default: False.\n                Please check `constraints_solving` and `sort_out_constr_batches` for more details.\n            timer (Timer):\n                A timer recording the concretization time.\n        \"\"\"        \n        self.eps = eps\n        self.x0 = None\n\n        # For p = inf, pre-computing x0 and eps accelerates the concretize function.\n        if norm == np.inf and x_L is not None and x_U is not None:\n            self.eps = (x_U - x_L) / 2\n            self.x0 = (x_U + x_L) / 2\n        \n        # x0_act and eps_act stand for the x0 and eps matrices for batches with active constraints.\n        self.x0_act = None          # shape (batchsize, *X_shape)\n        self.eps_act = None         # shape (batchsize, *X_shape)\n        # x0_sparse_act and eps_sparse_act are the active sparse x0 and eps matrices when sparse perturbation is enabled.\n        # Check init_sparse_linf to see how sparse x0, eps, x0_act, eps_act are created.\n        self.x0_sparse_act = None   # shape (batchsize, *X_sparse_shape)\n        self.eps_sparse_act = None  # shape (batchsize, *X_sparse_shape)\n\n        self.eps_min = eps_min\n        self.norm = norm\n        self.dual_norm = 1 if (norm == np.inf) else (np.float64(1.0) / (1 - 1.0 / self.norm))\n        self.x_L = x_L\n        self.x_U = x_U\n        self.sparse = False\n\n        self.timer = timer\n        self.aux_lb = None\n        self.aux_ub = None\n\n        self.rearrange_constraints = rearrange_constraints\n\n        # constraints is a tuple containing both the coefficient matrix and bias term\n        # of the constraints. The constraints would appear in the form of:\n        #                           A_c * x + b_c <= 0\n        # Coefficient matrix will be reshaped into (batchsize, # of constraints,\n        # input_dim). 
Bias term will be reshaped into (batchsize, # of constraints).\n        # Also see `constraints_solving` in concretize_func.py.\n\n        # Pre-process the constraints.\n        self.constraints, self.sorted_out_batches = sort_out_constr_batches(x_L, x_U, constraints, \n                                                                            rearrange_constraints=rearrange_constraints,\n                                                                            no_return_inf=no_return_inf)\n        # The indices of hidden neurons to apply constraints.\n        self.objective_indices = None   # shape: (batchsize, num_of_neurons)\n        if self.constraints is None or self.constraints[0].numel() == 0:\n            self._constraints_enable = False\n        else:\n            self._constraints_enable = True\n        self.no_return_inf = no_return_inf\n\n        self._use_grad = False\n\n    def get_input_bounds(self, x, A):\n        if self.sparse:\n            if self.x_L_sparse.shape[-1] == A.shape[-1]:\n                x_L, x_U = self.x_L_sparse, self.x_U_sparse\n                act_x0, act_eps = self.x0_sparse_act, self.eps_sparse_act\n            else:\n                # In backward mode, A is not sparse.\n                x_L, x_U = self.x_L, self.x_U\n                act_x0, act_eps = self.x0_act, self.eps_act\n        else:\n            x_L = x - self.eps if self.x_L is None else self.x_L\n            x_U = x + self.eps if self.x_U is None else self.x_U\n            act_x0, act_eps = self.x0_act, self.eps_act\n        return x_L, x_U, act_x0, act_eps\n\n    def get_constraints(self, A):\n        if self.constraints is None:\n            return None\n        if self.sparse and self.x_L_sparse.shape[-1] == A.shape[-1]:\n            return self.constraints_sparse\n        else:\n            return self.constraints\n\n    def concretize_matrix(self, x, A, sign, constraints=None):\n        # If A is an identity matrix, we handle it specially.\n        if 
not isinstance(A, eyeC):\n            # A has (Batch, spec, *input_size). For intermediate neurons, spec is *neuron_size.\n            A = A.reshape(A.shape[0], A.shape[1], -1)\n\n        if self.norm == np.inf:\n            x_L, x_U, act_x0, act_eps = self.get_input_bounds(x, A)\n            if constraints is None:\n                constraints = self.get_constraints(A)\n            # The original code for matrix concretize has been merged into `constraints_solving`.\n            # Pick out auxiliary bound based on the sign.\n            aux_bounds = self.aux_lb if sign == -1.0 else self.aux_ub\n            results = constraints_solving(x_L, x_U, A, constraints, sign,\n                                        sorted_out_batches=self.sorted_out_batches, objective_indices=self.objective_indices, \n                                        constraints_enable=self._constraints_enable, no_return_inf=self.no_return_inf,\n                                        timer=self.timer, \n                                        aux_bounds=aux_bounds, act_x0=act_x0, act_eps=act_eps,\n                                        use_grad=self._use_grad)\n            \n            if self.no_return_inf:\n                # return: (bound, infeasible_bounds)\n                bound = results[0]\n                infeasible_bounds = results[1]\n                self.add_infeasible_batches(infeasible_bounds)\n            else:\n                # return: bound\n                bound = results\n        else:\n            x = x.reshape(x.shape[0], -1, 1)\n            if not isinstance(A, eyeC):\n                # Find the upper and lower bounds via dual norm.\n                deviation = A.norm(self.dual_norm, -1) * self.eps\n                bound = A.matmul(x) + sign * deviation.unsqueeze(-1)\n            else:\n                # A is an identity matrix. 
Its norm is all 1.\n                bound = x + sign * self.eps\n        bound = bound.squeeze(-1)\n        return bound\n\n    def concretize_patches(self, x, A, sign):\n        if self.norm == np.inf:\n            x_L, x_U, _, _ = self.get_input_bounds(x, A)\n\n            # Here we should not reshape.\n            # Find the upper and lower bound similarly to IBP.\n            center = (x_U + x_L) / 2.0\n            diff = (x_U - x_L) / 2.0\n\n            if not A.identity == 1:\n                bound = A.matmul(center)\n                bound_diff = A.matmul(diff, patch_abs=True)\n                if sign == 1:\n                    bound += bound_diff\n                elif sign == -1:\n                    bound -= bound_diff\n                else:\n                    raise ValueError(\"Unsupported Sign\")\n            else:\n                # A is an identity matrix. No need to do this matmul.\n                bound = center + sign * diff\n            return bound\n        else:  # Lp norm\n            input_shape = x.shape\n            if not A.identity:\n                # Find the upper and lower bounds via dual norm.\n                # matrix has shape\n                # (batch_size, out_c * out_h * out_w, input_c, input_h, input_w)\n                # or (batch_size, unstable_size, input_c, input_h, input_w)\n                matrix = patches_to_matrix(\n                    A.patches, input_shape, A.stride, A.padding, A.output_shape,\n                    A.unstable_idx)\n                # Note that we should avoid reshaping the matrix.\n                # Due to padding, matrix cannot be reshaped without copying.\n                deviation = matrix.norm(p=self.dual_norm, dim=(-3,-2,-1)) * self.eps\n                # Bound has shape (batch, out_c * out_h * out_w) or (batch, unstable_size).\n                bound = torch.einsum('bschw,bchw->bs', matrix, x) + sign * deviation\n                if A.unstable_idx is None:\n                    # Reshape to (batch, 
out_c, out_h, out_w).\n                    bound = bound.view(matrix.size(0), A.patches.size(0),\n                                       A.patches.size(2), A.patches.size(3))\n            else:\n                # A is an identity matrix. Its norm is all 1.\n                bound = x + sign * self.eps\n            return bound\n\n    def concretize(self, x, A, sign=-1, constraints=None, aux=None):\n        \"\"\"Given a variable x and its bound matrix A, compute the worst-case bound according to the Lp norm.\"\"\"\n        if A is None:\n            return None\n        if isinstance(A, eyeC) or isinstance(A, torch.Tensor):\n            ret = self.concretize_matrix(x, A, sign, constraints)\n        elif isinstance(A, Patches):\n            ret = self.concretize_patches(x, A, sign)\n        else:\n            raise NotImplementedError()\n        if ret.ndim > 2:\n            ret = ret.reshape(A.shape[1], -1)\n        return ret\n\n    def init_sparse_linf(self, x, x_L, x_U):\n        \"\"\" Sparse Linf perturbation where only a few dimensions are actually perturbed\"\"\"\n        self.sparse = True\n        batch_size = x_L.shape[0]\n        perturbed = (x_U > x_L).int()\n        logger.debug(f'Perturbed: {perturbed.sum()}')\n        lb = ub = x_L * (1 - perturbed) # x_L=x_U holds when perturbed=0\n        perturbed = perturbed.view(batch_size, -1)\n        index = torch.cumsum(perturbed, dim=-1)\n        dim = max(perturbed.view(batch_size, -1).sum(dim=-1).max(), 1)\n        self.x_L_sparse = torch.zeros(batch_size, dim + 1).to(x_L)\n        self.x_L_sparse.scatter_(dim=-1, index=index, src=(x_L - lb).view(batch_size, -1), reduce='add')\n        self.x_U_sparse = torch.zeros(batch_size, dim + 1).to(x_U)\n        self.x_U_sparse.scatter_(dim=-1, index=index, src=(x_U - ub).view(batch_size, -1), reduce='add')\n        self.x_L_sparse, self.x_U_sparse = self.x_L_sparse[:, 1:], self.x_U_sparse[:, 1:]\n        \n        # --- create x0 and eps for Lp Norm\n        
self.x0_sparse = (self.x_L_sparse + self.x_U_sparse) / 2\n        self.eps_sparse = (self.x_U_sparse - self.x_L_sparse) / 2\n        if self.sorted_out_batches is not None:\n            active_indices = self.sorted_out_batches[\"active_indices\"]\n            self.x0_sparse_act = self.x0_sparse[active_indices].unsqueeze(-1)\n            self.eps_sparse_act = self.eps_sparse[active_indices].unsqueeze(-1)\n\n        lw = torch.zeros(batch_size, dim + 1, perturbed.shape[-1], device=x.device)\n        perturbed = perturbed.to(torch.get_default_dtype())\n        lw.scatter_(dim=1, index=index.unsqueeze(1), src=perturbed.unsqueeze(1))\n        lw = uw = lw[:, 1:, :].view(batch_size, dim, *x.shape[1:])\n        print(f'Using Linf sparse perturbation. Perturbed dimensions: {dim}.')\n        print(f'Avg perturbation: {(self.x_U_sparse - self.x_L_sparse).mean()}')\n\n        # When sparse linf is enabled, the input x perturbation would change its shape\n        # Hence, the shape of constraints_A should change accordingly.\n        # But for the final layer, we still need the dense linf, and use the original (dense) constraints\n        if self.constraints is not None:\n            # constraints_A: (batchsize, n_constraints, x_dim)\n            constraints_A, constraints_b = self.constraints\n            # reversed_lw: (batchsize, x_dim, sparse_dim)\n            reversed_lw = lw.reshape((batch_size, dim, -1)).transpose(1, 2)\n            lb_act = lb\n            # When pre-processing the constraints, we only kept the active ones.\n            # Hence, reversed_lw and lb_act should also be re-collected.\n            active_indices = self.sorted_out_batches[\"active_indices\"]\n            reversed_lw = reversed_lw[active_indices]\n            lb_act = lb_act[active_indices]\n            # reversed lw will sort out the sparse dimensions out of all x dimension\n            new_constraints_A = constraints_A.bmm(reversed_lw)\n            # Besides original constraint_b, should 
also include the a*x terms where x is not perturbed\n            # new_constraints_b = constraints_b + torch.einsum(\"bcx, bx -> bc\", constraints_A, lb_act.flatten(1))\n            new_constraints_b = constraints_b\n            self.constraints_sparse = (new_constraints_A, new_constraints_b)\n        return LinearBound(\n            lw, lb, uw, ub, x_L, x_U), x, None\n\n    def init(self, x, aux=None, forward=False):\n        self.sparse = False\n        if self.norm == np.inf:\n            x_L = x - self.eps if self.x_L is None else self.x_L\n            x_U = x + self.eps if self.x_U is None else self.x_U\n        else:\n            if int(os.environ.get('AUTOLIRPA_L2_DEBUG', 0)) == 1:\n                # FIXME Experimental code. Need to change the IBP code also.\n                x_L = x - self.eps if self.x_L is None else self.x_L\n                x_U = x + self.eps if self.x_U is None else self.x_U\n            else:\n                # FIXME This causes confusing lower bound and upper bound\n                # For other norms, we pass in the BoundedTensor objects directly.\n                x_L = x_U = x\n\n        if self.x_L is not None and self.x_U is not None:\n            self.x0 = (self.x_L + self.x_U) / 2\n        else:\n            self.x0 = x.data\n        if self.sorted_out_batches is not None and self.sorted_out_batches.get(\"active_indices\") is not None:\n            active_indices = self.sorted_out_batches[\"active_indices\"]\n            self.x0_act = self.x0[active_indices].flatten(1).unsqueeze(-1)\n            self.eps_act = self.eps[active_indices].flatten(1).unsqueeze(-1)\n\n        if not forward:\n            return LinearBound(\n                None, None, None, None, x_L, x_U), x, None\n        if (self.norm == np.inf and x_L.numel() > 1\n                and (x_L == x_U).sum() > 0.5 * x_L.numel()):\n            return self.init_sparse_linf(x, x_L, x_U)\n\n        batch_size = x.shape[0]\n        dim = x.reshape(batch_size, -1).shape[-1]\n   
     lb = ub = torch.zeros_like(x)\n        eye = torch.eye(dim).to(x).expand(batch_size, dim, dim)\n        lw = uw = eye.reshape(batch_size, dim, *x.shape[1:])\n        return LinearBound(\n            lw, lb, uw, ub, x_L, x_U), x, None\n\n    def add_infeasible_batches(self, infeasible_batches):\n        r\"\"\"\n        Synchronize the `infeasible_batches` tensor between the global graph and the local perturbation node.\n\n        If the computation graph includes multiple perturbed inputs, the BoundedModule (entire network) maintains a global\n        `infeasible_batches` tensor, while each perturbed input (root) keeps its own local copy.\n\n        - Before concretization: copy the global tensor to the local one.\n        - After concretization: propagate updates from the local tensor back to the global tensor.\n\n        Args:\n            infeasible_batches: A boolean vector with shape (batchsize, ). A True value indicates that a batch is infeasible\n                                given its constraints.\n        \"\"\"\n        if self.constraints is not None and infeasible_batches is not None and infeasible_batches.any():\n            if self.sorted_out_batches[\"infeasible_batches\"] is None:\n                self.sorted_out_batches[\"infeasible_batches\"] = infeasible_batches\n            else:\n                infeasible_batches = infeasible_batches | self.sorted_out_batches[\"infeasible_batches\"]\n                self.sorted_out_batches[\"infeasible_batches\"] = infeasible_batches\n            \n            active_indices = self.sorted_out_batches[\"active_indices\"]\n            B_act = active_indices.numel()\n            active_feasible_mask = (~infeasible_batches)[active_indices]\n            if active_feasible_mask.sum() < B_act:\n                self.sorted_out_batches[\"active_indices\"] = active_indices[active_feasible_mask]\n                self.x0_act = self.x0_act[active_feasible_mask]\n                self.eps_act = 
self.eps_act[active_feasible_mask]\n                constraints_A, constraints_b = self.constraints\n                constraints_A = constraints_A[active_feasible_mask]\n                constraints_b = constraints_b[active_feasible_mask]\n                self.constraints = (constraints_A, constraints_b)\n\n    def add_objective_indices(self, objective_indices):\n        if self.constraints is not None:\n            self.objective_indices = objective_indices\n\n    @property\n    def constraints_enable(self):\n        '''\n        Enable / Disable the constrained concretize mode, regardless whether constraints is None or not. \n        '''\n        return self._constraints_enable\n    \n    @constraints_enable.setter\n    def constraints_enable(self, enable: bool):\n        self._constraints_enable = enable\n\n    @constraints_enable.deleter\n    def constraints_enable(self):\n        del self._constraints_enable  \n\n    @property\n    def use_grad(self):\n        '''\n        Enable / Disable the constrained concretize with gradient. \n        '''\n        return self._use_grad\n    \n    @use_grad.setter\n    def use_grad(self, use_grad: bool):\n        self._use_grad = use_grad\n\n    @use_grad.deleter\n    def use_grad(self):\n        del self._use_grad  \n\n    def add_aux_bounds(self, aux_lb, aux_ub):\n        self.aux_lb = aux_lb\n        self.aux_ub = aux_ub\n\n    def clear_aux_bounds(self):\n        self.aux_lb = None\n        self.aux_ub = None\n\n    def reset_constraints(self, constraints, decision_thresh):\n        r\"\"\"\n        Reset the constraints of this perturbation. 
It will also call `sort_out_constr_batches` to preprocess the constraints.\n        Be sure not to reset with the same constraints input repeatedly.\n        \"\"\"\n        # We have to enable the gradient computation for the constraints\n        # when using constraints_solving within alpha-CROWN.\n        self.use_grad = True\n        constraints = construct_constraints(constraints[0], constraints[1], decision_thresh, self.x_L.shape[0], self.x_L.flatten(1).shape[1])\n        self.constraints, self.sorted_out_batches = sort_out_constr_batches(self.x_L, self.x_U, constraints, \n                                                                            rearrange_constraints=self.rearrange_constraints,\n                                                                            no_return_inf=self.no_return_inf)\n\n    def __repr__(self):\n        if self.norm == np.inf:\n            if self.x_L is None and self.x_U is None:\n                return f'PerturbationLpNorm(norm=inf, eps={self.eps})'\n            else:\n                return f'PerturbationLpNorm(norm=inf, eps={self.eps}, x_L={self.x_L}, x_U={self.x_U})'\n        else:\n            return f'PerturbationLpNorm(norm={self.norm}, eps={self.eps})'\n\n\nclass PerturbationLinear(Perturbation):\n    \"\"\"\n    Perturbation defined by a linear transformation.\n    Args:\n        lower_A: Lower bound matrix of shape (B, output_dim, input_dim)\n        upper_A: Upper bound matrix of shape (B, output_dim, input_dim)\n        lower_b: Lower bound bias of shape (B, output_dim)\n        upper_b: Upper bound bias of shape (B, output_dim)\n        input_lb: Input lower bound of shape (B, input_dim)\n        input_ub: Input upper bound of shape (B, input_dim)\n        x_L: Output lower bound of shape (B, output_dim)\n        x_U: Output upper bound of shape (B, output_dim)\n\n        x_L and x_U can be None, in which case they will be computed from the other parameters.
\n    \"\"\"\n    def __init__(self, lower_A, upper_A, lower_b, upper_b, input_lb, input_ub, x_L=None, x_U=None):\n        super(PerturbationLinear, self).__init__()\n        self.lower_A = lower_A\n        self.upper_A = upper_A\n        self.lower_b = lower_b.unsqueeze(-1) if lower_b is not None else None\n        self.upper_b = upper_b.unsqueeze(-1) if upper_b is not None else None\n        self.input_lb = input_lb.unsqueeze(-1) if input_lb is not None else None\n        self.input_ub = input_ub.unsqueeze(-1) if input_ub is not None else None\n        if x_L is None or x_U is None:\n            mid = (self.input_lb + self.input_ub) / 2\n            diff = (self.input_ub - self.input_lb) / 2\n            self.x_U = (self.upper_A @ mid + torch.abs(self.upper_A) @ diff + self.upper_b).squeeze(-1)\n            self.x_L = (self.lower_A @ mid - torch.abs(self.lower_A) @ diff + self.lower_b).squeeze(-1)\n        else:\n            self.x_L = x_L\n            self.x_U = x_U\n\n    def concretize(self, x, A, sign=-1, aux=None):\n        if A is None:\n            return None\n        else:\n            A_pos = torch.clamp(A, min=0)\n            A_neg = torch.clamp(A, max=0)\n\n            center = (self.input_lb + self.input_ub) / 2\n            diff = (self.input_ub - self.input_lb) / 2\n\n            if sign == 1:\n                composite_A = A_pos @ self.upper_A + A_neg @ self.lower_A\n                composite_b = A_pos @ self.upper_b + A_neg @ self.lower_b\n                bound = composite_A @ center + torch.abs(composite_A) @ diff + composite_b\n            else:\n                composite_A = A_pos @ self.lower_A + A_neg @ self.upper_A\n                composite_b = A_pos @ self.lower_b + A_neg @ self.upper_b\n                bound = composite_A @ center - torch.abs(composite_A) @ diff + composite_b\n            return bound.squeeze(-1)\n\n    def init(self, x, aux=None, forward=False):\n        if not forward:\n            return LinearBound(None, None, None, 
None, self.x_L, self.x_U), x, None\n        else:\n            raise NotImplementedError(\"Linear perturbation does not support forward mode.\")\n\n\nclass PerturbationSynonym(Perturbation):\n    def __init__(self, budget, eps=1.0, use_simple=False):\n        super(PerturbationSynonym, self).__init__()\n        self._load_synonyms()\n        self.budget = budget\n        self.eps = eps\n        self.use_simple = use_simple\n        self.model = None\n        self.train = False\n\n    def __repr__(self):\n        return (f'perturbation(Synonym-based word substitution '\n                f'budget={self.budget}, eps={self.eps})')\n\n    def _load_synonyms(self, path='data/synonyms.json'):\n        with open(path) as file:\n            self.synonym = json.loads(file.read())\n        logger.info('Synonym list loaded for {} words'.format(len(self.synonym)))\n\n    def set_train(self, train):\n        self.train = train\n\n    def concretize(self, x, A, sign, aux):\n        assert(self.model is not None)\n\n        x_rep, mask, can_be_replaced = aux\n        batch_size, length, dim_word = x.shape[0], x.shape[1], x.shape[2]\n        dim_out = A.shape[1]\n        max_num_cand = x_rep.shape[2]\n\n        mask_rep = torch.tensor(can_be_replaced, dtype=torch.get_default_dtype(), device=A.device)\n\n        num_pos = int(np.max(np.sum(can_be_replaced, axis=-1)))\n        update_A = A.shape[-1] > num_pos * dim_word\n        if update_A:\n            bias = torch.bmm(A, (x * (1 - mask_rep).unsqueeze(-1)).reshape(batch_size, -1, 1)).squeeze(-1)\n        else:\n            bias = 0.\n        A = A.reshape(batch_size, dim_out, -1, dim_word)\n\n        A_new, x_new, x_rep_new, mask_new = [], [], [], []\n        zeros_A = torch.zeros(dim_out, dim_word, device=A.device)\n        zeros_w = torch.zeros(dim_word, device=A.device)\n        zeros_rep = torch.zeros(max_num_cand, dim_word, device=A.device)\n        zeros_mask = torch.zeros(max_num_cand, device=A.device)\n        for t in 
range(batch_size):\n            cnt = 0\n            for i in range(0, length):\n                if can_be_replaced[t][i]:\n                    if update_A:\n                        A_new.append(A[t, :, i, :])\n                    x_new.append(x[t][i])\n                    x_rep_new.append(x_rep[t][i])\n                    mask_new.append(mask[t][i])\n                    cnt += 1\n            if update_A:\n                A_new += [zeros_A] * (num_pos - cnt)\n            x_new += [zeros_w] * (num_pos - cnt)\n            x_rep_new += [zeros_rep] * (num_pos - cnt)\n            mask_new += [zeros_mask] * (num_pos - cnt)\n        if update_A:\n            A = torch.cat(A_new).reshape(batch_size, num_pos, dim_out, dim_word).transpose(1, 2)\n        x = torch.cat(x_new).reshape(batch_size, num_pos, dim_word)\n        x_rep = torch.cat(x_rep_new).reshape(batch_size, num_pos, max_num_cand, dim_word)\n        mask = torch.cat(mask_new).reshape(batch_size, num_pos, max_num_cand)\n        length = num_pos\n\n        A = A.reshape(batch_size, A.shape[1], length, -1).transpose(1, 2)\n        x = x.reshape(batch_size, length, -1, 1)\n\n        if sign == 1:\n            cmp, init = torch.max, -1e30\n        else:\n            cmp, init = torch.min, 1e30\n\n        init_tensor = torch.ones(batch_size, dim_out).to(x.device) * init\n        dp = [[init_tensor] * (self.budget + 1) for i in range(0, length + 1)]\n        dp[0][0] = torch.zeros(batch_size, dim_out).to(x.device)\n\n        A = A.reshape(batch_size * length, A.shape[2], A.shape[3])\n        Ax = torch.bmm(\n            A,\n            x.reshape(batch_size * length, x.shape[2], x.shape[3])\n        ).reshape(batch_size, length, A.shape[1])\n\n        Ax_rep = torch.bmm(\n            A,\n            x_rep.reshape(batch_size * length, max_num_cand, x.shape[2]).transpose(-1, -2)\n        ).reshape(batch_size, length, A.shape[1], max_num_cand)\n        Ax_rep = Ax_rep * mask.unsqueeze(2) + init * (1 - mask).unsqueeze(2)\n    
    Ax_rep_bound = cmp(Ax_rep, dim=-1).values\n\n        if self.use_simple and self.train:\n            return torch.sum(cmp(Ax, Ax_rep_bound), dim=1) + bias\n\n        for i in range(1, length + 1):\n            dp[i][0] = dp[i - 1][0] + Ax[:, i - 1]\n            for j in range(1, self.budget + 1):\n                dp[i][j] = cmp(\n                    dp[i - 1][j] + Ax[:, i - 1],\n                    dp[i - 1][j - 1] + Ax_rep_bound[:, i - 1]\n                )\n        dp = torch.cat(dp[length], dim=0).reshape(self.budget + 1, batch_size, dim_out)\n\n        return cmp(dp, dim=0).values + bias\n\n    def init(self, x, aux=None, forward=False):\n        tokens, batch = aux\n        self.tokens = tokens # DEBUG\n        assert(len(x.shape) == 3)\n        batch_size, length, dim_word = x.shape[0], x.shape[1], x.shape[2]\n\n        max_pos = 1\n        can_be_replaced = np.zeros((batch_size, length), dtype=bool)\n\n        self._build_substitution(batch)\n\n        for t in range(batch_size):\n            cnt = 0\n            candidates = batch[t]['candidates']\n            # for transformers\n            if tokens[t][0] == '[CLS]':\n                candidates = [[]] + candidates + [[]]\n            for i in range(len(tokens[t])):\n                if tokens[t][i] == '[UNK]' or \\\n                        len(candidates[i]) == 0 or tokens[t][i] != candidates[i][0]:\n                    continue\n                for w in candidates[i][1:]:\n                    if w in self.model.vocab:\n                        can_be_replaced[t][i] = True\n                        cnt += 1\n                        break\n            max_pos = max(max_pos, cnt)\n\n        dim = max_pos * dim_word\n        if forward:\n            eye = torch.eye(dim_word).to(x.device)\n            lw = torch.zeros(batch_size, dim, length, dim_word).to(x.device)\n            lb = torch.zeros_like(x).to(x.device)\n        word_embeddings = self.model.word_embeddings.weight\n        vocab = 
self.model.vocab\n        x_rep = [[[] for i in range(length)] for t in range(batch_size)]\n        max_num_cand = 1\n        for t in range(batch_size):\n            candidates = batch[t]['candidates']\n            # for transformers\n            if tokens[t][0] == '[CLS]':\n                candidates = [[]] + candidates + [[]]\n            cnt = 0\n            for i in range(length):\n                if can_be_replaced[t][i]:\n                    word_embed = word_embeddings[vocab[tokens[t][i]]]\n                    # positional embedding and token type embedding\n                    other_embed = x[t, i] - word_embed\n                    if forward:\n                        lw[t, (cnt * dim_word):((cnt + 1) * dim_word), i, :] = eye\n                        lb[t, i, :] = torch.zeros_like(word_embed)\n                    for w in candidates[i][1:]:\n                        if w in self.model.vocab:\n                            x_rep[t][i].append(\n                                word_embeddings[self.model.vocab[w]] + other_embed)\n                    max_num_cand = max(max_num_cand, len(x_rep[t][i]))\n                    cnt += 1\n                else:\n                    if forward:\n                        lb[t, i, :] = x[t, i, :]\n        if forward:\n            uw, ub = lw, lb\n        else:\n            lw = lb = uw = ub = None\n        zeros = torch.zeros(dim_word, device=x.device)\n\n        x_rep_, mask = [], []\n        for t in range(batch_size):\n            for i in range(length):\n                x_rep_ += x_rep[t][i] + [zeros] * (max_num_cand - len(x_rep[t][i]))\n                mask += [1] * len(x_rep[t][i]) + [0] * (max_num_cand - len(x_rep[t][i]))\n        x_rep_ = torch.cat(x_rep_).reshape(batch_size, length, max_num_cand, dim_word)\n        mask = torch.tensor(mask, dtype=torch.get_default_dtype(), device=x.device)\\\n            .reshape(batch_size, length, max_num_cand)\n        x_rep_ = x_rep_ * self.eps + x.unsqueeze(2) * (1 - 
self.eps)\n\n        inf = 1e20\n        lower = torch.min(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * inf, dim=2).values\n        upper = torch.max(mask.unsqueeze(-1) * x_rep_ + (1 - mask).unsqueeze(-1) * (-inf), dim=2).values\n        lower = torch.min(lower, x)\n        upper = torch.max(upper, x)\n\n        return LinearBound(lw, lb, uw, ub, lower, upper), x, (x_rep_, mask, can_be_replaced)\n\n    def _build_substitution(self, batch):\n        for example in batch:\n            if not 'candidates' in example or example['candidates'] is None:\n                candidates = []\n                tokens = example['sentence'].strip().lower().split(' ')\n                for i in range(len(tokens)):\n                    _cand = []\n                    if tokens[i] in self.synonym:\n                        for w in self.synonym[tokens[i]]:\n                            if w in self.model.vocab:\n                                _cand.append(w)\n                    if len(_cand) > 0:\n                        _cand = [tokens[i]] + _cand\n                    candidates.append(_cand)\n                example['candidates'] = candidates\n\n"
  },
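The file above bounds word-substitution perturbations with a small dynamic program over token positions: `dp[i][j]` is the optimal total contribution of the first `i` tokens using exactly `j` substitutions, choosing at each step between the unperturbed term `Ax` and the worst-case replacement term `Ax_rep_bound`. A minimal standalone sketch of that recurrence, using plain Python floats instead of the library's batched tensors (the name `dp_bound` and the scalar interface are ours, not the library API):

```python
# Sketch of the substitution-budget DP: dp[i][j] is the best total over the
# first i positions using exactly j substitutions.
def dp_bound(ax, ax_rep_bound, budget, sign=1):
    """ax[i]: linear-term contribution of position i with the original word;
    ax_rep_bound[i]: best (per `sign`) contribution if position i is substituted.
    Returns the optimum over all ways to substitute at most `budget` positions."""
    cmp = max if sign == 1 else min
    init = float('-inf') if sign == 1 else float('inf')
    n = len(ax)
    dp = [[init] * (budget + 1) for _ in range(n + 1)]
    dp[0][0] = 0.0
    for i in range(1, n + 1):
        dp[i][0] = dp[i - 1][0] + ax[i - 1]  # no substitution at position i-1
        for j in range(1, budget + 1):
            dp[i][j] = cmp(
                dp[i - 1][j] + ax[i - 1],               # keep the original word
                dp[i - 1][j - 1] + ax_rep_bound[i - 1]  # substitute, spending one budget unit
            )
    return cmp(dp[n])  # best over 0..budget substitutions
```

With `sign=1` this mirrors the upper-bound branch (`torch.max`, init `-1e30`); `sign=-1` mirrors the lower bound. The final `cmp` over `dp[n]` corresponds to the `cmp(dp, dim=0).values` reduction in the library code.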
  {
    "path": "auto_LiRPA/solver_module.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nfrom .bound_ops import *\n\nfrom typing import TYPE_CHECKING\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef build_solver_module(self: 'BoundedModule', x=None, C=None, interm_bounds=None,\n                        final_node_name=None, model_type=\"mip\", solver_pkg=\"gurobi\", set_input=True):\n    r\"\"\"build lp/mip solvers in general graph.\n\n    Args:\n        x: inputs, a list of BoundedTensor. If set to None, we reuse exisint bounds that\n        were previously computed in compute_bounds().\n        C (Tensor): The specification matrix that can map the output of the model with an\n        additional linear layer. 
This is usually used for mapping the logits output of the\n        model to classification margins.\n        interm_bounds: if specified, will replace existing intermediate layer bounds.\n        Otherwise we reuse existing intermediate bounds.\n\n        final_node_name (String): the name of the target layer to optimize\n\n        solver_pkg (String): the backbone of the solver; gurobi by default, scipy is also supported\n\n    Returns:\n        output vars (list): a list of final nodes to optimize\n    \"\"\"\n    # self.root_names: list of root node names\n    # self.final_name: list of output node names\n    # self.final_node: output module\n    # <module>.input: a list of input modules of this layer module\n    # <module>.solver_vars: a list of gurobi vars of every layer module\n    #       list with conv shape if conv layers, otherwise flattened\n    # if last layer we need to be careful with:\n    #       C: specification matrix\n    #       <module>.is_input_perturbed(1)\n    if x is not None:\n        assert interm_bounds is not None\n        # Set the model to use new intermediate layer bounds, ignoring the original ones.\n        self.set_input(x, interm_bounds=interm_bounds)\n\n    roots = [self[name] for name in self.root_names]\n\n    # Create interval ranges for the input and other weight parameters.\n    for i in range(len(roots)):\n        # if isinstance(root[i], BoundInput) and not isinstance(root[i], BoundParams):\n        if type(roots[i]) is BoundInput:\n            # Create input vars for the gurobi model self.model.\n            if set_input:\n                inp_gurobi_vars = self._build_solver_input(roots[i])\n        else:\n            value = roots[i].forward()\n            # regular weights\n            roots[i].solver_vars = value\n\n    final = self.final_node() if final_node_name is None else self[final_node_name]\n\n    # Backward propagate through every layer, including the last one.\n    self._build_solver_general(node=final, C=C, model_type=model_type, 
solver_pkg=solver_pkg)\n\n    # a list of output solver vars\n    return final.solver_vars\n\n\ndef _build_solver_general(self: 'BoundedModule', node: Bound, C=None, model_type=\"mip\",\n                          solver_pkg=\"gurobi\"):\n    if not hasattr(node, 'solver_vars'):\n        if not node.perturbed:\n            # if not perturbed, just forward\n            node.solver_vars = self.get_forward_value(node)\n            return node.solver_vars\n        for n in node.inputs:\n            self._build_solver_general(n, C=C, model_type=model_type)\n        inp = [n_pre.solver_vars for n_pre in node.inputs]\n        if C is not None and isinstance(node, BoundLinear) and\\\n                not node.is_input_perturbed(1) and self.final_name == node.name:\n            # when node is the last layer\n            # merge the last BoundLinear node with the specification,\n            # available when weights of this layer are not perturbed\n            solver_vars = node.build_solver(*inp, model=self.solver_model, C=C,\n                model_type=model_type, solver_pkg=solver_pkg)\n        else:\n            solver_vars = node.build_solver(*inp, model=self.solver_model, C=None,\n                    model_type=model_type, solver_pkg=solver_pkg)\n        # just return output node gurobi vars\n        return solver_vars\n\ndef _reset_solver_vars(self: 'BoundedModule', node: Bound, iteration=True):\n    if hasattr(node, 'solver_vars'):\n        del node.solver_vars\n    if iteration:\n        if hasattr(node, 'inputs'):\n            for n in node.inputs:\n                self._reset_solver_vars(n)\n                \ndef _reset_solver_model(self: 'BoundedModule'):\n    self.solver_model.remove(self.solver_model.getVars())\n    self.solver_model.remove(self.solver_model.getConstrs())\n    self.solver_model.update()\n\ndef _build_solver_input(self: 'BoundedModule', node):\n    ## Do the input layer, which is a special case\n    assert isinstance(node, BoundInput)\n    assert 
node.perturbation is not None\n\n    if self.solver_model is None:\n        self.solver_model = grb.Model()\n    # zero var will be shared within the solver model\n    zero_var = self.solver_model.addVar(lb=0, ub=0, obj=0, vtype=grb.GRB.CONTINUOUS, name='zero')\n    one_var = self.solver_model.addVar(lb=1, ub=1, obj=0, vtype=grb.GRB.CONTINUOUS, name='one')\n    neg_one_var = self.solver_model.addVar(lb=-1, ub=-1, obj=0, vtype=grb.GRB.CONTINUOUS, name='neg_one')\n\n    x_L = node.value - node.perturbation.eps if node.perturbation.x_L is None else node.perturbation.x_L\n    x_U = node.value + node.perturbation.eps if node.perturbation.x_U is None else node.perturbation.x_U\n    x_L = x_L.min(dim=0).values\n    x_U = x_U.max(dim=0).values\n\n    input_shape = x_L.shape\n    name_array = [f'inp_{idx}' for idx in range(prod(input_shape))]\n    inp_gurobi_vars_dict = self.solver_model.addVars(*input_shape, lb=x_L, ub=x_U,\n                                                      obj=0, vtype=grb.GRB.CONTINUOUS, name=name_array)\n\n    inp_gurobi_vars = np.empty(input_shape, dtype=object)\n    for idx in inp_gurobi_vars_dict:\n        inp_gurobi_vars[idx] = inp_gurobi_vars_dict[idx]\n    inp_gurobi_vars = inp_gurobi_vars.tolist()\n    \n    # Flatten the input solver_vars. 
\n    def flatten(x):\n        if isinstance(x, list):\n            result = []\n            for item in x:\n                result.extend(flatten(item))\n            return result\n        else:\n            return [x]\n\n    # Add extra constraints for the inputs if the perturbation norm is not L_inf.\n    if node.perturbation.norm != float(\"inf\"):\n        if isinstance(inp_gurobi_vars, (list, tuple)):\n            flat_inp_gurobi_vars = flatten(inp_gurobi_vars)\n        else:\n            flat_inp_gurobi_vars = inp_gurobi_vars\n        if hasattr(node.value[0], \"flatten\"):\n            flat_node_value = node.value.flatten().tolist()\n        else:\n            flat_node_value = node.value\n        assert len(flat_inp_gurobi_vars) == len(flat_node_value), \"The input doesn't match the variables\"\n\n        if node.perturbation.norm == 2:\n            # For L2 norm, we directly add a quadratic constraint for cplex compatibility.\n            # TODO: Compare efficiency with the second method below. 
If the second method is faster,\n            # we should use it for L2 norm by default (when cplex is not used).\n            print(f'setup L2 constraint for input with radius {node.perturbation.eps}.')\n            quad_expr = grb.QuadExpr()\n            for var, val in zip(flat_inp_gurobi_vars, flat_node_value):\n                quad_expr.add((var - val) * (var - val))\n\n            self.solver_model.addQConstr(\n                quad_expr <= node.perturbation.eps ** 2,\n                name=\"l2_perturbation\"\n            )\n        else:\n            print(f'setup Lp constraint for input with radius {node.perturbation.eps}.')\n            n = len(flat_inp_gurobi_vars)\n            # Create variables to set up the lp constraint.\n            # We set input = x0 + delta where delta is under the Lp norm constraint.\n            senses = ['='] * n\n            delta_vars = self.solver_model.addVars(\n                n,\n                lb=-grb.GRB.INFINITY,\n                ub=grb.GRB.INFINITY,\n                name=\"delta\"\n            )\n            diff = -np.array(flat_node_value)\n            vars_list = list(delta_vars.values()) + flat_inp_gurobi_vars\n            self.solver_model.update()\n            A = np.hstack([np.eye(n), -np.eye(n)])\n            # Add constraints input = x0 + delta as delta - input = -x0.\n            # Here x0 is \"flat_node_value\" and input is \"flat_inp_gurobi_vars\".\n            self.solver_model.addMConstr(A, vars_list, senses, diff)\n            # Set up the lp constraint here: \\| delta \\|_p <= eps.\n            lp_norm_var = self.solver_model.addVar(\n                lb=0, \n                vtype=grb.GRB.CONTINUOUS,\n                name=\"lp_norm\"\n            )\n            self.solver_model.addGenConstrNorm(\n                lp_norm_var,\n                delta_vars,\n                node.perturbation.norm,\n                name=\"lp_norm_constr\"\n            )\n            self.solver_model.addConstr(\n             
   lp_norm_var <= node.perturbation.eps,\n                name=\"lp_perturbation_radius\"\n            )\n    \n    node.solver_vars = inp_gurobi_vars\n    # Save the gurobi input variables so that we can later extract primal values in input space easily.\n    self.input_vars = inp_gurobi_vars\n    self.solver_model.update()\n    return inp_gurobi_vars\n\n"
  },
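`build_solver_module` above delegates to `_build_solver_general`, which is a memoized post-order graph traversal: a node's solver variables are created only after all of its inputs have variables, and the result is cached as a `solver_vars` attribute so shared subgraphs are built exactly once. A minimal sketch of that pattern (the `Node` class, `build_general` name, and the string standing in for Gurobi variables are ours, for illustration only):

```python
# Memoized post-order traversal: build each node's "solver vars" after its inputs.
class Node:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)

build_order = []  # records the order in which nodes are built

def build_general(node):
    if hasattr(node, 'solver_vars'):
        return node.solver_vars        # already built; reuse the cached result
    for n in node.inputs:              # recurse into inputs first (post-order)
        build_general(n)
    build_order.append(node.name)
    node.solver_vars = f'vars({node.name})'  # stand-in for node.build_solver(...)
    return node.solver_vars

# Diamond-shaped graph: a feeds both b and c, which both feed d.
a = Node('a'); b = Node('b', [a]); c = Node('c', [a]); d = Node('d', [b, c])
result = build_general(d)
```

Even though `a` is reachable through both `b` and `c`, the cache ensures it is built once, which is why `_build_solver_general` checks `hasattr(node, 'solver_vars')` before recursing.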
  {
    "path": "auto_LiRPA/tools.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             
##\n##                                                                     ##\n#########################################################################\nimport torch\nfrom graphviz import Digraph\nimport shutil\nimport re\n\nfrom typing import TYPE_CHECKING, List\nif TYPE_CHECKING:\n    from .bound_general import BoundedModule\n\n\ndef visualize(self: 'BoundedModule', output_path, print_bounds=False):\n    r\"\"\"A visualization tool for BoundedModule.\n    If the dot engine is available in the system environment, it renders the graph and outputs {output_path}.png.\n    Otherwise, it outputs {output_path}.dot.\n\n    Args:\n        output_path (str): The path to save the graph (without file extension).\n        print_bounds (bool): Whether to display the mean width of the bounds for each node.\n    \"\"\"\n\n    nodes = list(self.nodes())\n    # Create a directed graph\n    dot = Digraph(format='png', engine='dot')\n    # Add nodes with optional attributes\n    for node in nodes:\n        # We name the Graphviz nodes with the sanitized node name,\n        # while keeping the original name in the label, which is displayed in the graph.\n        export_node_name = sanitize_graphviz_name(node.name)\n        label = f\"\"\"<\n            <TABLE BORDER=\"0\" CELLBORDER=\"0\" CELLPADDING=\"4\">\n                <TR><TD><FONT FACE=\"Arial\" COLOR=\"black\">{node.name}</FONT></TD></TR>\n                <TR><TD><FONT FACE=\"Courier\" COLOR=\"blue\">{node.__class__.__name__}</FONT></TD></TR>\n                <TR><TD><FONT FACE=\"Courier\" COLOR=\"black\">{\n                    tuple(node.output_shape) if node.output_shape is not None else None}</FONT></TD></TR>\n            </TABLE>\n        >\"\"\"\n        if print_bounds:\n            # Display the mean width of the bounds\n            # (both the empirical bound from the forward value and the computed bound, if available).\n            label = f\"\"\"<\n                <TABLE BORDER=\"0\" CELLBORDER=\"0\" CELLPADDING=\"4\">\n     
                <TR><TD><FONT FACE=\"Arial\" COLOR=\"black\">{node.name}</FONT></TD></TR>\n                    <TR><TD><FONT FACE=\"Courier\" COLOR=\"blue\">{node.__class__.__name__}</FONT></TD></TR>\n                    <TR><TD><FONT FACE=\"Courier\" COLOR=\"black\">{\n                        tuple(node.output_shape) if node.output_shape is not None else None}</FONT></TD></TR>\n                    <TR><TD><FONT FACE=\"Courier\" COLOR=\"black\">{\n                        (node.forward_value.max(dim=0)[0] - node.forward_value.min(dim=0)[0]).to(dtype=torch.float).mean().item() if (\n                            node.perturbed and\n                            hasattr(node, \"forward_value\") and\n                            isinstance(node.forward_value, torch.Tensor)) else None}</FONT></TD></TR>\n                    <TR><TD><FONT FACE=\"Courier\" COLOR=\"black\">{\n                        (node.upper - node.lower).to(dtype=torch.float).mean().item() if (\n                            node.perturbed and\n                            hasattr(node, \"lower\") and hasattr(node, \"upper\") and\n                            node.lower is not None and node.upper is not None) else None}</FONT></TD></TR>\n                </TABLE>\n            >\"\"\"\n        # Perturbed nodes are highlighted in grey.\n        if getattr(node, \"perturbed\", False):\n            style_attrs = {'style': 'filled', 'fillcolor': 'lightgrey'}\n        else:\n            style_attrs = {}\n        if node.__class__.__name__ in [\"BoundParams\", \"BoundConstant\", \"BoundBuffers\"]:\n            dot.node(export_node_name, label=label, fontsize=\"8\", width=\"0.5\", height=\"0.2\", shape=\"ellipse\", **style_attrs)\n        elif node.__class__.__name__ == \"BoundInput\":\n            dot.node(export_node_name, label=label, shape=\"diamond\", **style_attrs)\n        else:\n            dot.node(export_node_name, label=label, shape=\"square\", **style_attrs)\n        for inp in node.inputs:\n            
dot.edge(sanitize_graphviz_name(inp.name), export_node_name)\n    # Render graph\n    if shutil.which(\"dot\") is None:\n        print(\"Cannot render the graphviz file (dot not found).\")\n        print(f\"Graph saved to {output_path}.dot\")\n        dot.save(output_path + \".dot\")\n    else:\n        dot.render(output_path, cleanup=True)\n        print(f\"Graph saved to {output_path}.png\")\n\ndef sanitize_graphviz_name(name):\n    \"\"\"\n    Convert problematic characters (like `:`, `::`) in a Graphviz node name to a safe alternative character `_`.\n    \"\"\"\n    unsafe_chars = r'[:;,\\[\\]{}()<>|#*@&=+`~^?\"\\\\\\s]'\n    safe_name = re.sub(unsafe_chars, \"_\", name)\n    \n    return safe_name\n"
  },
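`sanitize_graphviz_name` in the file above exists because characters like `:` have special meaning in Graphviz node IDs (ports), so raw ONNX-style node names would corrupt the rendered graph; the original name survives only in the displayed label. A small standalone illustration of the same substitution (the `sanitize` wrapper name is ours; the character class matches the one in `tools.py`):

```python
import re

# Punctuation and whitespace that can confuse Graphviz node IDs are mapped to '_'.
UNSAFE_CHARS = r'[:;,\[\]{}()<>|#*@&=+`~^?"\\\s]'

def sanitize(name):
    """Replace every unsafe character in a Graphviz node name with '_'."""
    return re.sub(UNSAFE_CHARS, '_', name)
```

Note that `/` is deliberately not in the class, so path-like ONNX names keep their structure while separators like `::` are neutralized.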
  {
    "path": "auto_LiRPA/utils.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport logging\nimport time\nimport torch\nimport torch.nn as nn\nimport os\nimport sys\nimport appdirs\nfrom collections import defaultdict, namedtuple\nfrom functools import reduce\nimport operator\nimport warnings\nfrom typing import Tuple\nfrom .patches import Patches\n\n\nlogging.basicConfig(\n    format='%(levelname)-8s %(asctime)-12s [%(filename)s:%(lineno)d] %(message)s',\n    datefmt='%H:%M:%S',\n    stream=sys.stdout,\n    level=logging.INFO\n)\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.DEBUG if os.environ.get('AUTOLIRPA_DEBUG', 0) else logging.INFO)\n\nwarnings.simplefilter(\"once\")\n\n# Special identity matrix. 
Avoid extra computation of identity matrix multiplication in various places.\neyeC = namedtuple('eyeC', 'shape device')\nOneHotC = namedtuple('OneHotC', 'shape device index coeffs')\nBatchedCrownC = namedtuple('BatchedCrownC', 'type')\n\ndef onehotc_to_dense(one_hot_c: OneHotC, dtype: torch.dtype) -> torch.Tensor:\n    shape = one_hot_c.shape  # [spec, batch, C, H, W]\n    dim = int(prod(shape[2:]))\n    dense = torch.zeros(\n        size=(shape[0], shape[1], dim), device=one_hot_c.device, dtype=dtype)\n    # one_hot_c.index has size (spec, batch); its values are the indices of the one-hot non-zero elements in A.\n    # one_hot_c.coeffs holds the values of the non-zero elements.\n    dense = torch.scatter(\n        dense, dim=2, index=one_hot_c.index.unsqueeze(-1),\n        src=one_hot_c.coeffs.unsqueeze(-1))\n    dense = dense.view(shape[0], shape[1], *shape[2:])\n    return dense\n\n# Benchmarking mode disables some expensive assertions.\nBenchmarking = True\n\nreduction_sum = lambda x: x.sum(dim=tuple(range(1, x.dim())), keepdim=True)\nreduction_mean = lambda x: x.mean(dim=tuple(range(1, x.dim())), keepdim=True)\nreduction_max = lambda x: x.amax(dim=tuple(range(1, x.dim())), keepdim=True)\nreduction_min = lambda x: x.amin(dim=tuple(range(1, x.dim())), keepdim=True)\n\nMIN_HALF_FP = 5e-8  # Approximately 2**-24, the smallest positive value representable in float16.\n\n\ndef reduction_str2func(reduction_func):\n    if type(reduction_func) == str:\n        if reduction_func == 'min':\n            return reduction_min\n        elif reduction_func == 'max':\n            return reduction_max\n        elif reduction_func == 'sum':\n            return reduction_sum\n        elif reduction_func == 'mean':\n            return reduction_mean\n        else:\n            raise NotImplementedError(f'Unknown reduction_func {reduction_func}')\n    else:\n        return reduction_func\n\ndef stop_criterion_placeholder(threshold=0):\n    def _not_specified(x):\n        # Raise (rather than merely construct) the error if this placeholder is ever called.\n        raise RuntimeError(\"BUG: bound optimization stop criterion not specified.\")\n    return _not_specified\n\ndef stop_criterion_min(threshold=0):\n    return lambda x: (x.min(1, keepdim=True).values > threshold)\n\ndef stop_criterion_all(threshold=0):\n    # The dimension of x should be (batch, spec).\n    # This was used in the incomplete verifier, where the spec dimension can\n    # represent statements in an OR clause.\n    return lambda x: (x > threshold).all(dim=1, keepdim=True)\n\ndef stop_criterion_max(threshold=0):\n    return lambda x: (x.max(1, keepdim=True).values > threshold)\n\ndef stop_criterion_batch(threshold=0):\n    # May broadcast unexpectedly; pay attention to the shape of threshold.\n    # x shape: batch, number_bounds; threshold shape: batch, number_bounds\n    return lambda x: (x > threshold)\n\ndef stop_criterion_batch_any(threshold=0):\n    \"\"\"If any spec >= rhs, then this sample can be stopped;\n       if all samples can be stopped, stop = True, otherwise False.\n    \"\"\"\n    # May broadcast unexpectedly; pay attention to the shape of threshold.\n    # x shape: batch, number_bounds; threshold shape: batch, number_bounds\n    return lambda x: (x > threshold).any(dim=1, keepdim=True)\n\ndef stop_criterion_general(or_spec_size, threshold=0):\n    \"\"\"\n    If any spec in a group >= rhs, then this group can be stopped;\n    if all groups can be stopped, stop = True, otherwise False.\n    Args:\n        or_clause_indices: [num_clause]. The index of the OR clause that each AND clause belongs to.\n        num_or: the number of OR clauses.\n        threshold: [batch, num_clause]. The threshold for each spec. 
sum(or_clause_indices) == num_clauses.\n    \"\"\"\n    def stop_criterion_per_or(x):\n        # get the indices of OR clauses assigned to their corresponding atom clauses, [num_clause]\n        num_or = or_spec_size.shape[0]\n        or_clause_indices = torch.repeat_interleave(\n            torch.arange(num_or, device=or_spec_size.device), or_spec_size\n        ).view(1, -1).expand(x.shape)\n        # get the result for each spec. [batch, num_clause]\n        result_per_spec = (x > threshold) \n        # get the number of verified ANDs for each OR clause. [batch, num_or]\n        num_verified_and_per_or = torch.scatter_reduce(result_per_spec[:, :num_or], 1, or_clause_indices, result_per_spec, 'sum', include_self=False)\n        # result of any spec in a OR (group of ANDs) is True (sum >= 1) -> result of the OR is True.\n        return num_verified_and_per_or >= 1\n    # if all OR clauses are True, then return True. [batch, 1]\n    return lambda x: stop_criterion_per_or(x).all(dim=1, keepdim=True)\n\ndef stop_criterion_batch_topk(threshold=0, k=1314):\n    # x shape: batch, number_bounds; threshold shape: batch, number_bounds\n    return lambda x: (torch.kthvalue(x, k, dim=-1, keepdim=True).values > threshold).any(dim=1)\n\ndef multi_spec_keep_func_all(x):\n    return torch.all(x, dim=-1)\n\n\nuser_data_dir = appdirs.user_data_dir('auto_LiRPA')\nif not os.path.exists(user_data_dir):\n    try:\n        os.makedirs(user_data_dir)\n    except:\n        logger.error('Failed to create directory {}'.format(user_data_dir))\n\n\nclass MultiAverageMeter(object):\n    \"\"\"Computes and stores the average and current value for multiple metrics\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.sum_meter = defaultdict(float)\n        self.lasts = defaultdict(float)\n        self.counts_meter = defaultdict(int)\n        self.batch_size = 1\n\n    def set_batch_size(self, batch_size):\n        self.batch_size = batch_size\n\n    def 
update(self, key, val, n=None):\n        if val is None:\n            return\n        if n is None:\n            n = self.batch_size\n        if isinstance(val, torch.Tensor):\n            val = val.item()\n        self.lasts[key] = val\n        self.sum_meter[key] += val * n\n        self.counts_meter[key] += n\n\n    def last(self, key):\n        return self.lasts[key]\n\n    def avg(self, key):\n        if self.counts_meter[key] == 0:\n            return 0.0\n        else:\n            return self.sum_meter[key] / self.counts_meter[key]\n\n    def __repr__(self):\n        s = \"\"\n        for k in self.sum_meter:\n            s += \"{}={:.4f} \".format(k, self.avg(k))\n        return s.strip()\n\n\nclass MultiTimer(object):\n    \"\"\"Count the time for each part of training.\"\"\"\n    def __init__(self):\n        self.reset()\n    def reset(self):\n        self.timer_starts = defaultdict(float)\n        self.timer_total = defaultdict(float)\n    def start(self, key):\n        if self.timer_starts[key] != 0:\n            raise RuntimeError(\"start() is called more than once\")\n        self.timer_starts[key] = time.time()\n    def stop(self, key):\n        if key not in self.timer_starts:\n            raise RuntimeError(\"Key does not exist; please call start() before stop()\")\n        self.timer_total[key] += time.time() - self.timer_starts[key]\n        self.timer_starts[key] = 0\n    def total(self, key):\n        return self.timer_total[key]\n    def __repr__(self):\n        s = \"\"\n        for k in self.timer_total:\n            s += \"{}_time={:.3f} \".format(k, self.timer_total[k])\n        return s.strip()\n\n\nclass Flatten(nn.Flatten):\n    \"\"\"Legacy Flatten class.\n\n    It was previously created when nn.Flatten was not supported. 
Simply use\n    nn.Flatten in the future.\"\"\"\n    pass\n\n\nclass Unflatten(nn.Module):\n    def __init__(self, wh):\n        super().__init__()\n        self.wh = wh # width and height of the feature maps\n    def forward(self, x):\n        return x.view(x.size(0), -1, self.wh, self.wh)\n\n\nclass Max(nn.Module):\n\n    def __init__(self):\n        super(Max, self).__init__()\n\n    def forward(self, x, y):\n        return torch.max(x, y)\n\n\nclass Min(nn.Module):\n\n    def __init__(self):\n        super(Min, self).__init__()\n\n    def forward(self, x, y):\n        return torch.min(x, y)\n\n\ndef scale_gradients(optimizer, gradient_accumulation_steps, grad_clip=None):\n    parameters = []\n    for param_group in optimizer.param_groups:\n        for param in param_group['params']:\n            parameters.append(param)\n            if param.grad is not None:\n                param.grad.data /= gradient_accumulation_steps\n    if grad_clip is not None:\n        return torch.nn.utils.clip_grad_norm_(parameters, grad_clip)\n\n\n# Unpack a tuple, dict, or list into a single flat list.\n# TODO: not sure if the order matches graph.inputs()\ndef unpack_inputs(inputs, device=None):\n    if isinstance(inputs, dict):\n        inputs = list(inputs.values())\n    if isinstance(inputs, tuple) or isinstance(inputs, list):\n        res = []\n        for item in inputs:\n            res += unpack_inputs(item, device=device)\n        return res\n    else:\n        if device is not None:\n            inputs = inputs.to(device)\n        return [inputs]\n\n\ndef isnan(x):\n    if isinstance(x, Patches):\n        return False\n    return torch.isnan(x).any()\n\n\ndef prod(x):\n    return reduce(operator.mul, x, 1)\n\n\ndef batched_index_select(input, dim, index):\n    # Assuming the input has a batch dimension.\n    # index has dimension [spec, batch].\n    if input.ndim == 4:\n        # Alphas for fully connected layers, shape [2, spec, batch, neurons]\n        index = 
index.unsqueeze(-1).unsqueeze(0).expand(input.size(0), -1, -1, input.size(3))\n    elif input.ndim == 6:\n        # Alphas for convolutional layers, shape [2, spec, batch, c, h, w].\n        index = index.view(1, index.size(0), index.size(1), *([1] * (input.ndim - 3))).expand(input.size(0), -1, -1, *input.shape[3:])\n    elif input.ndim == 3:\n        # Weights.\n        input = input.expand(index.size(0), -1, -1)\n        index = index.unsqueeze(-1).expand(-1, -1, input.size(2))\n    elif input.ndim == 2:\n        # Bias.\n        input = input.expand(index.size(0), -1)\n    else:\n        raise ValueError\n    return torch.gather(input, dim, index)\n\n\ndef get_spec_matrix(X, y, num_classes):\n    with torch.no_grad():\n        c = (torch.eye(num_classes).type_as(X)[y].unsqueeze(1)\n            - torch.eye(num_classes).type_as(X).unsqueeze(0))\n        I = (~(y.unsqueeze(1) == torch.arange(num_classes).type_as(y).unsqueeze(0)))\n        c = (c[I].view(X.size(0), num_classes - 1, num_classes))\n    return c\n\n\ndef unravel_index(\n    indices: torch.LongTensor,\n    shape: Tuple[int, ...],\n) -> torch.LongTensor:\n    r\"\"\"Converts flat indices into unraveled coordinates in a target shape.\n\n    Args:\n        indices: A tensor of (flat) indices, (*, N).\n        shape: The targeted shape, (D,).\n\n    Returns:\n        The unraveled coordinates, a list with tensors in shape (N, D).\n\n    Code borrowed from:\n        https://github.com/pytorch/pytorch/issues/35674\n    \"\"\"\n\n    coord = []\n\n    for dim in reversed(shape):\n        coord.append(indices % dim)\n        indices = torch.div(indices, dim, rounding_mode='trunc')\n\n    return list(reversed(coord))\n\n\nclass AutoBatchSize:\n    def __init__(self, init_batch_size, device, vram_ratio=0.9, enable=True):\n        self.batch_size = init_batch_size\n        self.max_actual_batch_size = 0\n        self.device = device\n        self.vram_ratio = vram_ratio\n        self.enable = enable\n\n    def 
record_actual_batch_size(self, actual_batch_size):\n        \"\"\"Record the actual batch size used.\n\n        It may be smaller than self.batch_size, especially for the early batches.\n        \"\"\"\n        self.max_actual_batch_size = max(self.max_actual_batch_size, actual_batch_size)\n\n    def update(self):\n        \"\"\"Check if the batch size can be enlarged.\"\"\"\n        if not self.enable:\n            return None\n        # Only try to update the batch size if the current batch size has\n        # been actually used, as indicated by `max_actual_batch_size`\n        if self.device == 'cpu' or self.max_actual_batch_size < self.batch_size:\n            return None\n        total_vram = torch.cuda.get_device_properties(self.device).total_memory\n        current_vram = torch.cuda.memory_reserved(self.device)\n        if current_vram * 2 >= total_vram * self.vram_ratio:\n            return None\n        new_batch_size = self.batch_size * 2\n        self.batch_size = new_batch_size\n        logger.debug('Automatically updated batch size to %d', new_batch_size)\n        return {\n            'current_vram': current_vram,\n            'total_vram': total_vram,\n        }\n\n\ndef sync_params(model_ori: torch.nn.Module,\n                model: 'BoundedModule',\n                loss_fusion: bool = False):\n    \"\"\"Sync the parameters from a BoundedModule to the original model.\"\"\"\n    state_dict_loss = model.state_dict()\n    state_dict = model_ori.state_dict()\n    for name in state_dict_loss:\n        v = state_dict_loss[name]\n        if name.endswith('.param'):\n            name = name[:-6]\n        elif name.endswith('.buffer'):\n            name = name[:-7]\n        else:\n            raise NameError(name)\n        name_ori = model[name].ori_name\n        if loss_fusion:\n            assert name_ori.startswith('model.')\n            name_ori = name_ori[6:]\n        assert name_ori in state_dict\n        state_dict[name_ori] = v\n    
model_ori.load_state_dict(state_dict)\n    return state_dict\n\n\ndef reduce_broadcast_dims(A, target_shape, left_extra_dims=1):\n    \"\"\"\n    When backward propagating tensors that are automatically broadcasted,\n    we need to reduce the broadcasted dimensions to match the input shape.\n    This can be useful for backward bound propagation and backward gradient\n    computation.\n\n    Args:\n        A: The input tensor.\n        target_shape: The target shape to reduce to.\n        left_extra_dims: The number of dimensions that A should have but the target\n            shape doesn't have. These dimensions are usually added to the left of the\n            target shape and don't need to be reduced (e.g. spec).\n\n    Example:\n        x1 has shape [a1, a2, a3, a4], x2 has shape [a2, 1, a4], y = x1 * x2.\n        Two types of broadcasting here:\n            1. Adding additional dimensions to x2 to match the dimension of x1.\n            2. Broadcasting along existing dimensions of length 1.\n        In backward computation from y to x2, we need to reduce (sum) the A matrix\n        to match the shape of x2. 
The first dimension of A is usually for spec, so\n        the shape usually aligns from the second dimension.\n    \"\"\"\n    # Step 1: Dimension doesn't exist in target shape but exists in A.\n    # cnt_sum is the number of dimensions that are broadcast.\n    # (The additional dimensions in A that are not in target shape).\n    cnt_sum = (A.ndim - left_extra_dims) - len(target_shape)\n    # The broadcast dimensions must be the first dimensions in A\n    # (except the extra dimensions and batch dimension).\n    dims = list(range(left_extra_dims + 1, cnt_sum + left_extra_dims + 1))\n    if dims:\n        A = torch.sum(A, dim=dims, keepdim=False)\n    # Step 2: Dimension exists in target shape, broadcast from 1.\n    # FIXME (05/11/2022): the following condition is not always correct.\n    # We should not rely on checking dimension is \"1\" or not.\n    dims = [i + left_extra_dims for i in range(left_extra_dims, len(target_shape))\n            if target_shape[i] == 1 and A.shape[i + left_extra_dims] != 1]\n    if dims:\n        A = torch.sum(A, dim=dims, keepdim=True)\n    # Check the final shape - it should be compatible.\n    assert A.shape[2:] == target_shape[1:]  # skip the spec and batch dimension.\n    return A\n\n\n@torch.jit.script\ndef matmul_maybe_batched(a: torch.Tensor, b: torch.Tensor, both_batched: bool):\n    # Basically just matmul, but we need to handle the batch dimension.\n    if both_batched:\n        return torch.einsum(\"b...ij,b...jk->b...ik\", a, b)\n    else:\n        return a.matmul(b)\n\ndef transfer(tensor, device=None, dtype=None, non_blocking=False):\n    \"\"\"Transfer a tensor to a specific device or dtype.\"\"\"\n    if device:\n        tensor = tensor.to(device, non_blocking=non_blocking)\n    if dtype:\n        tensor = tensor.to(dtype)\n\n    return tensor\n\n\ndef clone_sub_A_dict(A_dict, out_in_keys: Tuple):\n    \"\"\"\n    Deep copy the A_dict structure for specific out_in_keys.\n    Args:\n        A_dict: The A_dict to be 
copied.\n        out_in_keys: The (out_key, in_key) pairs to be copied.\n    Returns:\n        A new A_dict with all tensors cloned.\n    \"\"\"\n    # Structure: A_dict[out_key][in_key][key]\n    # key in [lA, uA, lbias, ubias, unstable_idx]\n    # lA, uA are tensors or Patches\n    # (there're also types like eyeC, OneHotC, not supported here)\n    # lbias, ubias are tensors\n    # unstable_idx is tensor or tuple of tensors\n\n    out_key, in_key = out_in_keys\n    src_subdict = A_dict[out_key][in_key]\n    cloned_subdict = {}\n\n    for key, val in src_subdict.items():\n        if val is None:\n            cloned_subdict[key] = None\n            continue\n\n        if isinstance(val, (torch.Tensor, Patches)):\n            cloned_subdict[key] = val.detach().clone()\n        elif isinstance(val, tuple):\n            cloned_subdict[key] = tuple(v.detach().clone() for v in val)\n        else:\n            raise NotImplementedError(f'Unsupported A type {type(val)} for copying.')\n    return cloned_subdict\n\n\ndef clone_full_A_dict(A_dict):\n    \"\"\"\n    Deep copy the A_dict structure.\n    Args:\n        A_dict: The A_dict to be copied.\n    Returns:\n        A new A_dict with all tensors cloned.\n    \"\"\"\n    new_A_dict = {}\n    for out_key, in_dict in A_dict.items():\n        new_A_dict[out_key] = {}\n        for in_key in in_dict:\n            new_A_dict[out_key][in_key] = clone_sub_A_dict(A_dict, (out_key, in_key))\n    return new_A_dict"
  },
  {
    "path": "auto_LiRPA/wrapper.py",
    "content": "#########################################################################\n##   This file is part of the auto_LiRPA library, a core part of the   ##\n##   α,β-CROWN (alpha-beta-CROWN) neural network verifier developed    ##\n##   by the α,β-CROWN Team                                             ##\n##                                                                     ##\n##   Copyright (C) 2020-2025 The α,β-CROWN Team                        ##\n##   Team leaders:                                                     ##\n##          Faculty:   Huan Zhang <huan@huan-zhang.com> (UIUC)         ##\n##          Student:   Xiangru Zhong <xiangru4@illinois.edu> (UIUC)    ##\n##                                                                     ##\n##   See CONTRIBUTORS for all current and past developers in the team. ##\n##                                                                     ##\n##     This program is licensed under the BSD 3-Clause License,        ##\n##        contained in the LICENCE file in this directory.             ##\n##                                                                     ##\n#########################################################################\nimport torch\nimport torch.nn as nn\n\nclass CrossEntropyWrapper(nn.Module):\n    def __init__(self, model):\n        super(CrossEntropyWrapper, self).__init__()\n        self.model = model\n\n    def forward(self, x, labels):\n        y = self.model(x)\n        logits = y - torch.gather(y, dim=-1, index=labels.unsqueeze(-1))\n        return torch.exp(logits).sum(dim=-1, keepdim=True)\n\nclass CrossEntropyWrapperMultiInput(nn.Module):\n    def __init__(self, model):\n        super(CrossEntropyWrapperMultiInput, self).__init__()\n        self.model = model\n\n    def forward(self, labels, *x):\n        y = self.model(*x)\n        logits = y - torch.gather(y, dim=-1, index=labels.unsqueeze(-1))\n        return torch.exp(logits).sum(dim=-1, keepdim=True)"
  },
  {
    "path": "doc/.gitignore",
    "content": "_build\nsections\n*.md\n!src/*.md\n!README.md"
  },
  {
    "path": "doc/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "doc/README.md",
"content": "# Documentation\n\nThis directory contains source files for building our documentation.\nPlease view the compiled documentation on our [documentation page](https://auto-lirpa.readthedocs.io/en/latest/?badge=latest), as some links may not work here on GitHub.\n\n## Dependencies\n\nInstall the additional libraries needed for building the documentation:\n\n```bash\npip install -r requirements.txt\n```\n\n## Build\n\nBuild the documentation in HTML:\n\n```\nmake html\n```\n\nThe documentation will be generated at `_build/html`.\n"
  },
  {
    "path": "doc/api.rst",
    "content": "API Usage\n======================================\n\n.. autoclass:: auto_LiRPA.BoundedModule\n\n   .. autofunction:: auto_LiRPA.BoundedModule.forward\n   .. autofunction:: auto_LiRPA.BoundedModule.compute_bounds\n   .. autofunction:: auto_LiRPA.BoundedModule.save_intermediate\n\n.. autoclass:: auto_LiRPA.bound_ops.Bound\n\n   .. autofunction:: auto_LiRPA.bound_ops.Bound.forward\n   .. autofunction:: auto_LiRPA.bound_ops.Bound.interval_propagate\n   .. autofunction:: auto_LiRPA.bound_ops.Bound.bound_forward\n   .. autofunction:: auto_LiRPA.bound_ops.Bound.bound_backward\n\n.. autoclass:: auto_LiRPA.perturbations.Perturbation\n\n   .. autofunction:: auto_LiRPA.perturbations.Perturbation.concretize\n   .. autofunction:: auto_LiRPA.perturbations.Perturbation.init\n\nIndices and tables\n-------------------\n\n* :ref:`genindex`\n* :ref:`search`\n\n..\n   * :ref:`modindex`"
  },
  {
    "path": "doc/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport subprocess\nimport inspect\nimport sys\nfrom pygit2 import Repository\nsys.path.insert(0, '..')\nimport auto_LiRPA\n\nsubprocess.run(['python', 'process.py'])\n\n# -- Project information -----------------------------------------------------\n\nproject = 'auto_LiRPA'\nauthor = '<a href=\"https://github.com/Verified-Intelligence/auto_LiRPA#developers-and-copyright\">auto-LiRPA authors</a>'\ncopyright = f'2020-2025, {author}'\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.autodoc',\n    'sphinx.ext.linkcode',\n    'm2r2',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'src', 'Thumbs.db', '.DS_Store']\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'alabaster'\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n\nrepo = Repository('../')\nbranch = repo.head.shorthand\n\n# Resolve function for the linkcode extension.\ndef linkcode_resolve(domain, info):\n    def find_source():\n        obj = auto_LiRPA\n        parts = info['fullname'].split('.')\n        if info['module'].endswith(f'.{parts[0]}'):\n            module = info['module'][:-len(parts[0])-1]\n        else:\n            module = info['module']\n        obj = sys.modules[module]\n        for part in parts:\n            obj = getattr(obj, part)\n        fn = inspect.getsourcefile(obj)\n        source, lineno = inspect.getsourcelines(obj)\n        return fn, lineno, lineno + len(source) - 1\n\n    fn, lineno_start, lineno_end = find_source()\n    filename = f'{fn}#L{lineno_start}-L{lineno_end}'\n\n    return f\"https://github.com/Verified-Intelligence/auto_LiRPA/blob/{branch}/doc/{filename}\"\n"
  },
  {
    "path": "doc/index.rst",
"content": ".. auto_LiRPA documentation master file, created by\n   sphinx-quickstart on Wed Jul 14 21:56:10 2021.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nDocumentation for `auto_LiRPA <https://github.com/Verified-Intelligence/auto_LiRPA>`_\n======================================================================================\n\n.. toctree::\n   :hidden:\n\n   installation\n   quick-start\n   examples\n   api\n   custom_op\n   paper\n\n.. raw:: html\n\n   <p align=\"center\">\n   <a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_lirpa_2.png\" width=\"45%\" height=\"45%\" float=\"left\"></a>\n   <a href=\"http://PaperCode.cc/AutoLiRPA-Video\"><img src=\"http://www.huan-zhang.com/images/upload/lirpa/auto_lirpa_1.png\" width=\"45%\" height=\"45%\" float=\"right\"></a>\n   </p>\n\n.. mdinclude:: sections/introduction.md\n\nUsage\n-----\n\n* :doc:`Installation <sections/installation>`\n* :doc:`Quick Start <sections/quick-start>`\n* :doc:`More Working Examples <examples>`\n* :doc:`API Usage <api>`\n* :doc:`Custom Operators <custom_op>`\n* :doc:`Reproducing our NeurIPS 2020 paper <paper>`\n"
  },
  {
    "path": "doc/process.py",
"content": "\"\"\" Process source files before running Sphinx\"\"\"\nimport re\nimport os\nimport shutil\nfrom pygit2 import Repository\n\nrepo = 'https://github.com/Verified-Intelligence/auto_LiRPA'\nbranch = Repository('.').head.shorthand\nrepo_file_path = os.path.join(repo, 'tree', branch)\n\n# Parse README.md into sections which can be reused\nheading = ''\ncopied = {}\nprint('Parsing markdown sections from README:')\nwith open('../README.md') as file:\n    for line in file.readlines():\n        if line.startswith('##'):\n            heading = line[2:].strip()\n        else:\n            if not heading in copied:\n                copied[heading] = ''\n            copied[heading] += line\nif not os.path.exists('sections'):\n    os.makedirs('sections')\nfor key in copied:\n    if key == '':\n        continue\n    filename = re.sub(r\"[?+\\'\\\"]\", '', key.lower())\n    filename = re.sub(r\" \", '-', filename) + '.md'\n    print(filename)\n    with open(os.path.join('sections', filename), 'w') as file:\n        file.write(f'## {key}\\n')\n        file.write(copied[key])\nprint()\n\n# Load source files and fix links to GitHub\nfor folder in ['src', 'sections']:\n    for filename in os.listdir(folder):\n        print(f'Processing {folder}/{filename}')\n        with open(os.path.join(folder, filename)) as file:\n            source = file.read()\n        source_new = ''\n        ptr = 0\n        for m in re.finditer('(\\[.*\\])(\\(.*\\))', source):\n            assert m.start() >= ptr\n            source_new += source[ptr:m.start()]\n            ptr = m.start()\n            source_new += m.group(1)\n            ptr += len(m.group(1))\n            link_raw = m.group(2)\n            while len(link_raw) >= 2 and link_raw[-2] == ')':\n                link_raw = link_raw[:-1]\n            link = link_raw[1:-1]\n            if link.startswith('https://') or link.startswith('http://') or '.html#' in link:\n                link_new = link\n            else:\n             
   if folder == 'sections':\n                    link_new = os.path.join(repo_file_path, link)\n                else:\n                    link_new = os.path.join(repo_file_path, 'doc/src', link)\n                print(f'Fix link {link} -> {link_new}')\n            source_new += f'({link_new})'\n            ptr += len(link_raw)\n        source_new += source[ptr:]\n        with open(filename, 'w') as file:\n            file.write(source_new)\n        print()\n"
  },
  {
    "path": "examples/.gitignore",
    "content": "auto_LiRPA\n"
  },
  {
    "path": "examples/__init__.py",
    "content": ""
  },
  {
    "path": "examples/language/.gitignore",
    "content": "model*\n!modeling*\nlog*\nres_test.pkl\nckpt_*\ndata_language.tar.gz\ndata/\n"
  },
  {
    "path": "examples/language/Transformer/Transformer.py",
"content": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import absolute_import, division, print_function\n\nimport os\nimport torch\nimport torch.nn as nn\n\nfrom Transformer.modeling import BertForSequenceClassification\nfrom pytorch_pretrained_bert.modeling import BertConfig\nfrom Transformer.utils import convert_examples_to_features\nfrom language_utils import build_vocab\nfrom auto_LiRPA.utils import logger\n\n\nclass Transformer(nn.Module):\n    def __init__(self, args, data_train):\n        super().__init__()\n        self.args = args\n        self.max_seq_length = args.max_sent_length\n        self.drop_unk = args.drop_unk\n        self.num_labels = args.num_classes\n        self.label_list = range(args.num_classes)\n        self.device = args.device\n        self.lr = args.lr\n\n        self.dir = args.dir\n        self.vocab = build_vocab(data_train, args.min_word_freq)\n        if not os.path.exists(self.dir):\n            os.makedirs(self.dir)\n        self.checkpoint = 0\n        config = BertConfig(len(self.vocab))\n        config.num_hidden_layers = args.num_layers\n        config.embedding_size = args.embedding_size\n        config.hidden_size = args.hidden_size\n        config.intermediate_size = args.intermediate_size\n        config.hidden_act = 
args.hidden_act\n        config.num_attention_heads = args.num_attention_heads\n        config.layer_norm = args.layer_norm\n        config.hidden_dropout_prob = args.dropout\n        self.model = BertForSequenceClassification(\n            config, self.num_labels, vocab=self.vocab).to(self.device)\n        logger.info(\"Model initialized\")\n        if args.load:\n            checkpoint = torch.load(args.load, map_location=torch.device(self.device))\n            epoch = checkpoint['epoch']\n            self.model.embeddings.load_state_dict(checkpoint['state_dict_embeddings'])\n            self.model.model_from_embeddings.load_state_dict(checkpoint['state_dict_model_from_embeddings'])\n            logger.info('Checkpoint loaded: {}'.format(args.load))\n\n        self.model_from_embeddings = self.model.model_from_embeddings\n        self.word_embeddings = self.model.embeddings.word_embeddings\n        self.model_from_embeddings.device = self.device\n\n    def save(self, epoch):\n        self.model.model_from_embeddings = self.model_from_embeddings\n        path = os.path.join(self.dir, \"ckpt_{}\".format(epoch))\n        torch.save({\n            'state_dict_embeddings': self.model.embeddings.state_dict(),\n            'state_dict_model_from_embeddings': self.model.model_from_embeddings.state_dict(),\n            'epoch': epoch\n        }, path)\n        logger.info(\"Model saved to {}\".format(path))\n\n    def build_optimizer(self):\n        # update the original model with the converted model\n        self.model.model_from_embeddings = self.model_from_embeddings\n        param_group = [\n            {\"params\": [p[1] for p in self.model.named_parameters()], \"weight_decay\": 0.},\n        ]\n        return torch.optim.Adam(param_group, lr=self.lr)\n\n    def train(self):\n        self.model.train()\n        self.model_from_embeddings.train()\n\n    def eval(self):\n        self.model.eval()\n        self.model_from_embeddings.eval()\n\n    def get_input(self, 
batch):\n        features = convert_examples_to_features(\n            batch, self.label_list, self.max_seq_length, self.vocab, drop_unk=self.drop_unk)\n\n        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(self.device)\n        input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long).to(self.device)\n        segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long).to(self.device)\n        label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long).to(self.device)\n        tokens = [f.tokens for f in features]\n\n        embeddings, extended_attention_mask = \\\n            self.model(input_ids, segment_ids, input_mask, embed_only=True)\n\n        return embeddings, extended_attention_mask, tokens, label_ids\n\n    def forward(self, batch):\n        embeddings, extended_attention_mask, tokens, label_ids = self.get_input(batch)\n        logits = self.model_from_embeddings(embeddings, extended_attention_mask)\n        preds = torch.argmax(logits, dim=1)\n        return preds"
  },
  {
    "path": "examples/language/Transformer/__init__.py",
    "content": ""
  },
  {
    "path": "examples/language/Transformer/modeling.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch BERT model.\"\"\"\n\nfrom __future__ import absolute_import, division, print_function, unicode_literals\n\nimport torch\nfrom torch import nn\n\nfrom pytorch_pretrained_bert.modeling import BertIntermediate, BertSelfAttention, BertPreTrainedModel\n\nclass BertLayerNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-12):\n        super(BertLayerNorm, self).__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.bias = nn.Parameter(torch.zeros(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, x):\n        u = x.mean(-1, keepdim=True)\n        s = (x - u).pow(2).mean(-1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.variance_epsilon)\n        return self.weight * x + self.bias\n\nclass BertLayerNormNoVar(nn.Module):\n    def __init__(self, hidden_size, eps=1e-12):\n        super(BertLayerNormNoVar, self).__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.bias = nn.Parameter(torch.zeros(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, x):\n        u = x.mean(-1, keepdim=True)\n        x = x - u\n        return self.weight * x + self.bias\n\nclass BertEmbeddings(nn.Module):\n    
\"\"\"Construct the embeddings from word, position and token_type embeddings.\n    \"\"\"\n    def __init__(self, config, glove=None, vocab=None):\n        super(BertEmbeddings, self).__init__()\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0)\n        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)\n        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)\n\n        self.config = config\n\n    def forward(self, input_ids, token_type_ids=None):\n        seq_length = input_ids.size(1)\n        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)\n        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)\n        if token_type_ids is None:\n            token_type_ids = torch.zeros_like(input_ids)\n\n        words_embeddings = self.word_embeddings(input_ids)\n        position_embeddings = self.position_embeddings(position_ids)\n        token_type_embeddings = self.token_type_embeddings(token_type_ids)\n\n        # position/token_type embedding disabled\n        # embeddings = words_embeddings + position_embeddings + token_type_embeddings\n\n        embeddings = words_embeddings\n        return embeddings\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config):\n        super(BertSelfOutput, self).__init__()\n        self.config = config\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if hasattr(config, \"layer_norm\") and config.layer_norm == \"no_var\":\n            self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12)\n        else:\n            self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n      
  if hidden_states.shape[-1] == input_tensor.shape[-1]:\n            hidden_states = hidden_states + input_tensor\n        if hasattr(self.config, \"layer_norm\") and self.config.layer_norm == \"no\":\n            pass\n        else:\n            hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, input_size):\n        super(BertAttention, self).__init__()\n        self.self = BertSelfAttention(config)\n        self.output = BertSelfOutput(config)\n\n    def forward(self, input_tensor, attention_mask):\n        self_output = self.self(input_tensor, attention_mask)\n        attention_output = self.output(self_output, input_tensor)\n\n        return attention_output\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super(BertOutput, self).__init__()\n        self.config = config\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        if hasattr(config, \"layer_norm\") and config.layer_norm == \"no_var\":\n            self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12)\n        else:\n            self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = hidden_states + input_tensor\n        if hasattr(self.config, \"layer_norm\") and self.config.layer_norm == \"no\":\n            pass\n        else:\n            hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\nclass BertLayer(nn.Module):\n    def __init__(self, config, layer_id):\n        super(BertLayer, self).__init__()\n        self.input_size = config.hidden_size\n        self.attention = BertAttention(config, self.input_size)\n        self.intermediate = 
BertIntermediate(config)\n        self.output = BertOutput(config)\n\n    def forward(self, hidden_states, attention_mask):\n        attention_output = self.attention(hidden_states, attention_mask)\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n\n        return layer_output\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super(BertEncoder, self).__init__()\n        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])\n\n    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):\n        all_encoder_layers = []\n        for layer_module in self.layer:\n            hidden_states = layer_module(hidden_states, attention_mask)\n            if output_all_encoded_layers:\n                all_encoder_layers.append(hidden_states)\n        if not output_all_encoded_layers:\n            all_encoder_layers.append(hidden_states)\n        return all_encoder_layers\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super(BertPooler, self).__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\nclass BertModelFromEmbeddings(BertPreTrainedModel):\n    def __init__(self, config):\n        super(BertModelFromEmbeddings, self).__init__(config)\n        self.encoder = BertEncoder(config)\n        self.pooler = BertPooler(config)\n        self.apply(self.init_bert_weights)\n\n    def forward(self, embeddings, extended_attention_mask):\n        encoded_layers  = 
self.encoder(embeddings, extended_attention_mask)\n        sequence_output = encoded_layers[-1]\n        pooled_output = self.pooler(sequence_output)\n        return pooled_output\n\nclass BertForSequenceClassificationFromEmbeddings(BertPreTrainedModel):\n    def __init__(self, config, num_labels=2):\n        super(BertForSequenceClassificationFromEmbeddings, self).__init__(config)\n        self.num_labels = num_labels\n        self.bert = BertModelFromEmbeddings(config)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        self.classifier = nn.Linear(config.hidden_size, num_labels)\n        self.linear_in = nn.Linear(config.embedding_size, config.hidden_size)\n\n        self.layer_norm = config.layer_norm\n        if hasattr(config, \"layer_norm\") and config.layer_norm == \"no_var\":\n            self.LayerNorm = BertLayerNormNoVar(config.embedding_size, eps=1e-12)\n        else:\n            self.LayerNorm = BertLayerNorm(config.embedding_size, eps=1e-12)\n\n        self.apply(self.init_bert_weights)\n\n    def forward(self, embeddings, extended_attention_mask):\n        embeddings = self.linear_in(embeddings)\n\n        if self.layer_norm == \"no\":\n            pass\n        else:\n            embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n\n        pooled_output = self.bert(embeddings, extended_attention_mask)\n        pooled_output = self.dropout(pooled_output)\n        logits = self.classifier(pooled_output)\n        return logits\n\nclass BertForSequenceClassification(BertPreTrainedModel):\n    def __init__(self, config, num_labels=2, glove=None, vocab=None):\n        super(BertForSequenceClassification, self).__init__(config)\n        self.model_from_embeddings = BertForSequenceClassificationFromEmbeddings(\n            config, num_labels\n        )\n        self.num_labels = num_labels\n        self.embeddings = BertEmbeddings(config, glove=glove, vocab=vocab)\n        
self.apply(self.init_bert_weights)\n\n    def forward(self, input_ids, token_type_ids=None, attention_mask=None, embed_only=False):\n        if attention_mask is None:\n            attention_mask = torch.ones_like(input_ids)\n        if token_type_ids is None:\n            token_type_ids = torch.zeros_like(input_ids)\n        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)\n        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        embeddings = self.embeddings(input_ids, token_type_ids)\n        if embed_only:\n            return embeddings, extended_attention_mask\n        logits = self.model_from_embeddings(embeddings, extended_attention_mask)\n        return logits\n"
  },
  {
    "path": "examples/language/Transformer/utils.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018, NVIDIA CORPORATION.  All rights   rved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom language_utils import tokenize, token_to_id\n\nclass InputExample(object):\n    def __init__(self, guid, text_a, text_b=None, label=None):\n        self.guid = guid\n        self.text_a = text_a\n        self.text_b = text_b\n        self.label = label\n\nclass InputFeatures(object):\n    def __init__(self, input_ids, input_mask, segment_ids, label_id, tokens):\n        self.input_ids = input_ids\n        self.input_mask = input_mask\n        self.segment_ids = segment_ids\n        self.label_id = label_id\n        self.tokens = tokens\n\ndef convert_examples_to_features(examples, label_list, max_seq_length,\n                                vocab, drop_unk=False):\n                                #tokenizer):\n    \"\"\"Loads a data file into a list of `InputBatch`s.\"\"\"\n    features = []\n    all_tokens = tokenize(examples, vocab, max_seq_length - 2, drop_unk=drop_unk)\n    for i in range(len(all_tokens)):\n        all_tokens[i] = [\"[CLS]\"] + all_tokens[i] + [\"[SEP]\"]\n    all_ids = token_to_id(all_tokens, vocab)\n\n    max_seq_length = min(max_seq_length, max([len(tokens) for tokens in all_tokens]))\n    for (ex_index, example) in enumerate(examples):\n        tokens = all_tokens[ex_index]\n        segment_ids = [0] * 
len(tokens)\n        input_ids = all_ids[ex_index]\n        input_mask = [1] * len(input_ids)\n        padding = [0] * (max_seq_length - len(input_ids))\n        input_ids += padding\n        input_mask += padding\n        segment_ids += padding\n\n        assert len(input_ids) == max_seq_length\n        assert len(input_mask) == max_seq_length\n        assert len(segment_ids) == max_seq_length\n\n        features.append(InputFeatures(\n            input_ids=input_ids,\n            input_mask=input_mask,\n            segment_ids=segment_ids,\n            label_id=example[\"label\"],\n            tokens=tokens))\n\n    return features\n"
  },
  {
    "path": "examples/language/data_utils.py",
    "content": "import random\nimport json\nfrom auto_LiRPA.utils import logger\n\ndef load_data_sst():\n    data = []\n    for split in ['train_all_nodes', 'train', 'dev', 'test']:\n        with open('data/sst/{}.json'.format(split)) as file:\n            data.append(json.loads(file.read()))\n    return data\n\ndef load_data(dataset):    \n    if dataset == \"sst\":\n        return load_data_sst()\n    else:\n        raise NotImplementedError('Unknown dataset {}'.format(dataset))\n\ndef clean_data(data):\n    return [example for example in data if example['candidates'] is not None]\n\ndef get_batches(data, batch_size):\n    batches = []\n    random.shuffle(data)\n    for i in range((len(data) + batch_size - 1) // batch_size):\n        batches.append(data[i * batch_size : (i + 1) * batch_size])\n    return batches\n"
  },
  {
    "path": "examples/language/language_utils.py",
    "content": "from auto_LiRPA.utils import logger\nimport numpy as np\n\ndef build_vocab(data_train, min_word_freq, dump=False, include=[]):\n    vocab = {\n        '[PAD]': 0,\n        '[UNK]': 1,\n        '[CLS]': 2,\n        '[SEP]': 3,\n        '[MASK]': 4\n    }\n    cnt = {}\n    for example in data_train:\n        for token in example['sentence'].strip().lower().split():\n            if token in cnt:\n                cnt[token] += 1\n            else:\n                cnt[token] = 1\n    for w in cnt:\n        if cnt[w] >= min_word_freq or w in include:\n            vocab[w] = len(vocab)\n    logger.info('Vocabulary size: {}'.format(len(vocab)))\n\n    if dump:\n        with open('tmp/vocab.txt', 'w') as file:\n            for w in vocab.keys():\n                file.write('{}\\n'.format(w))\n\n    return vocab\n\ndef tokenize(batch, vocab, max_seq_length, drop_unk=False):\n    res = []\n    for example in batch:\n        t = example['sentence'].strip().lower().split(' ')\n        if drop_unk:\n            tokens = [w for w in t if w in vocab][:max_seq_length]\n        else:\n            tokens = []\n            for token in t[:max_seq_length]:\n                if token in vocab:\n                    tokens.append(token)\n                else:\n                    tokens.append('[UNK]')\n        res.append(tokens)    \n    return res\n\ndef token_to_id(tokens, vocab):\n    ids = []\n    for t in tokens:\n        ids.append([vocab[w] for w in t])\n    return ids"
  },
  {
    "path": "examples/language/lstm.py",
    "content": "import os\nimport shutil\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom auto_LiRPA.utils import logger\nfrom language_utils import build_vocab\n\nclass LSTMFromEmbeddings(nn.Module):\n    def __init__(self, args, vocab_size):\n        super(LSTMFromEmbeddings, self).__init__()\n\n        self.embedding_size = args.embedding_size\n        self.hidden_size = args.hidden_size\n        self.num_classes = args.num_classes\n        self.device = args.device\n\n        self.cell_f = nn.LSTMCell(self.embedding_size, self.hidden_size)\n        self.cell_b = nn.LSTMCell(self.embedding_size, self.hidden_size)\n        self.linear = nn.Linear(self.hidden_size * 2, self.num_classes)\n        if args.dropout is not None:\n            self.dropout = nn.Dropout(p=args.dropout)\n            logger.info('LSTM dropout: {}'.format(args.dropout))\n        else:\n            self.dropout = None\n\n    def forward(self, embeddings, mask):\n        if self.dropout is not None:\n            embeddings = self.dropout(embeddings)\n        embeddings = embeddings * mask.unsqueeze(-1)\n        batch_size = embeddings.shape[0]\n        length = embeddings.shape[1]\n        h_f = torch.zeros(batch_size, self.hidden_size).to(embeddings.device)\n        c_f = h_f.clone()\n        h_b, c_b = h_f.clone(), c_f.clone()\n        h_f_sum, h_b_sum = h_f.clone(), h_b.clone()\n        for i in range(length):\n            h_f, c_f = self.cell_f(embeddings[:, i], (h_f, c_f))\n            h_b, c_b = self.cell_b(embeddings[:, length - i - 1], (h_b, c_b))\n            h_f_sum = h_f_sum + h_f\n            h_b_sum = h_b_sum + h_b\n        states = torch.cat([h_f_sum / float(length), h_b_sum / float(length)], dim=-1)\n        logits = self.linear(states)\n        return logits\n\nclass LSTM(nn.Module):\n    def __init__(self, args, data_train):\n        super(LSTM, self).__init__()\n        self.args = args\n        self.embedding_size = args.embedding_size\n        
self.max_seq_length = args.max_sent_length\n        self.min_word_freq = args.min_word_freq\n        self.device = args.device\n        self.lr = args.lr\n\n        self.dir = args.dir\n        if not os.path.exists(self.dir):\n            os.makedirs(self.dir)\n        self.vocab = self.vocab_actual = build_vocab(data_train, args.min_word_freq)\n        self.checkpoint = 0\n\n        if args.load:\n            ckpt = torch.load(args.load, map_location=torch.device(self.device))\n            self.embedding = torch.nn.Embedding(len(self.vocab), self.embedding_size)\n            self.model_from_embeddings = LSTMFromEmbeddings(args, len(self.vocab))\n            self.model = self.embedding, LSTMFromEmbeddings(args, len(self.vocab))\n            self.embedding.load_state_dict(ckpt['state_dict_embedding'])\n            self.model_from_embeddings.load_state_dict(ckpt['state_dict_model_from_embeddings'])\n            self.checkpoint = ckpt['epoch']\n        else:\n            self.embedding = torch.nn.Embedding(len(self.vocab), self.embedding_size)\n            self.model_from_embeddings = LSTMFromEmbeddings(args, len(self.vocab))\n            self.model = self.embedding, LSTMFromEmbeddings(args, len(self.vocab))\n            logger.info(\"Model initialized\")\n        self.embedding = self.embedding.to(self.device)\n        self.model_from_embeddings = self.model_from_embeddings.to(self.device)\n        self.word_embeddings = self.embedding\n\n    def save(self, epoch):\n        path = os.path.join(self.dir, 'ckpt_{}'.format(epoch))\n        torch.save({\n            'state_dict_embedding': self.embedding.state_dict(),\n            'state_dict_model_from_embeddings': self.model_from_embeddings.state_dict(),\n            'epoch': epoch\n        }, path)\n        logger.info('LSTM saved: {}'.format(path))\n\n    def build_optimizer(self):\n        self.model = (self.model[0], self.model_from_embeddings)\n        param_group = []\n        for m in self.model:\n            
for p in m.named_parameters():\n                param_group.append(p)\n        param_group = [{\"params\": [p[1] for p in param_group], \"weight_decay\": 0.}]\n        return torch.optim.Adam(param_group, lr=self.lr)\n\n    def get_input(self, batch):\n        mask, tokens = [], []\n        for example in batch:\n            _tokens = []\n            for token in example[\"sentence\"].strip().lower().split(' ')[:self.max_seq_length]:\n                if token in self.vocab:\n                    _tokens.append(token)\n                else:\n                    _tokens.append(\"[UNK]\")\n            tokens.append(_tokens)\n        max_seq_length = max([len(t) for t in tokens])\n        token_ids = []\n        for t in tokens:\n            ids = [self.vocab[w] for w in t]\n            mask.append(torch.cat([\n                torch.ones(1, len(ids)),\n                torch.zeros(1, self.max_seq_length - len(ids))\n            ], dim=-1).to(self.device))\n            ids += [self.vocab[\"[PAD]\"]] * (self.max_seq_length - len(ids))\n            token_ids.append(ids)\n        embeddings = self.embedding(torch.tensor(token_ids, dtype=torch.long).to(self.device))\n        mask = torch.cat(mask, dim=0)\n        label_ids = torch.tensor([example[\"label\"] for example in batch]).to(self.device)\n        return embeddings, mask, tokens, label_ids\n\n    def train(self):\n        self.model_from_embeddings.train()\n\n    def eval(self):\n        self.model_from_embeddings.eval()\n"
  },
  {
    "path": "examples/language/oracle.py",
    "content": "import torch\nfrom auto_LiRPA.utils import logger\nfrom auto_LiRPA import PerturbationSynonym\nfrom data_utils import get_batches\n\ndef oracle(args, model, ptb, data, type):\n    logger.info('Running oracle for {}'.format(type))\n    model.eval()\n    assert(isinstance(ptb, PerturbationSynonym))\n    cnt_cor = 0\n    word_embeddings = model.word_embeddings.weight\n    vocab = model.vocab    \n    for t, example in enumerate(data):\n        embeddings, mask, tokens, label_ids = model.get_input([example])\n        candidates = example['candidates']\n        if tokens[0][0] == '[CLS]':\n            candidates = [[]] + candidates + [[]]   \n        embeddings_all = []\n        def dfs(tokens, embeddings, budget, index):\n            if index == len(tokens):\n                embeddings_all.append(embeddings.cpu())\n                return\n            dfs(tokens, embeddings, budget, index + 1)\n            if budget > 0 and tokens[index] != '[UNK]' and len(candidates[index]) > 0\\\n                    and tokens[index] == candidates[index][0]:\n                for w in candidates[index][1:]:\n                    if w in vocab:\n                        _embeddings = torch.cat([\n                            embeddings[:index],\n                            word_embeddings[vocab[w]].unsqueeze(0),\n                            embeddings[index + 1:]\n                        ], dim=0)\n                        dfs(tokens, _embeddings, budget - 1, index + 1)\n        dfs(tokens[0], embeddings[0], ptb.budget, 0)\n        cor = True\n        for embeddings in get_batches(embeddings_all, args.oracle_batch_size):\n            embeddings_tensor = torch.cat(embeddings).cuda().reshape(len(embeddings), *embeddings[0].shape)\n            logits = model.model_from_embeddings(embeddings_tensor, mask)        \n            for pred in list(torch.argmax(logits, dim=1)):\n                if pred != example['label']:\n                    cor = False\n            if not cor: 
break\n        cnt_cor += cor\n\n        if (t + 1) % args.log_interval == 0:\n            logger.info('{} {}/{}: oracle robust acc {:.3f}'.format(type, t + 1, len(data), cnt_cor * 1. / (t + 1)))\n    logger.info('{}: oracle robust acc {:.3f}'.format(type, cnt_cor * 1. / (t + 1)))\n"
  },
  {
    "path": "examples/language/preprocess/pre_compute_lm_scores.py",
    "content": "# Ref: https://worksheets.codalab.org/rest/bundles/0x3f614472f4a14393b3d85d5568114591/contents/blob/precompute_lm_scores.py\n\n\"\"\"Precompute language model scores.\"\"\"\nimport argparse\nimport json\nimport os\nimport sys\nimport torch\nfrom tqdm import tqdm\n\nfrom data_utils import load_data\n\nsys.path.insert(0, 'tmp/windweller-l2w/adaptive_softmax')\nimport query as lmquery\n\nOPTS = None\n\ndef parse_args():\n  parser = argparse.ArgumentParser('Insert a description of this script.')\n  parser.add_argument('--data', type=str, default='sst')\n  parser.add_argument('--out', default='tmp')\n  parser.add_argument('--window-radius', '-w', type=int, default=6)\n  parser.add_argument('--neighbor-file', type=str, default='tmp/synonyms.json')\n  return parser.parse_args()\n\ndef main():\n  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n  query_handler = lmquery.load_model(device)\n  with open(OPTS.neighbor_file) as f:\n    neighbors = json.load(f)\n\n  data_train_warmup, data_train, data_dev, data_test = load_data(OPTS.data)\n  split = [('train', data_train), ('dev', data_dev), ('test', data_test)]\n\n  for s in split:\n    data = s[1]\n    out_file = os.path.join(OPTS.out, '{}_lm_scores.txt'.format(s[0]))\n\n    with open(out_file, 'w') as f:\n      for sent_idx, example in enumerate(tqdm(data)):\n        sentence = example[\"sentence\"]\n        print('%d\\t%s' % (sent_idx, sentence), file=f)\n        words = sentence.lower().strip().split(' ')\n        for i, w in enumerate(words):\n          if w in neighbors:\n            options = [w] + neighbors[w]\n            start = max(0, i - OPTS.window_radius)\n            end = min(len(words), i + 1 + OPTS.window_radius)\n            # Remove OOV words from prefix and suffix\n            prefix = [x for x in words[start:i] if x in query_handler.word_to_idx]\n            suffix = [x for x in words[i+1:end] if x in query_handler.word_to_idx]\n            queries = 
[]\n            in_vocab_options = []\n            for opt in options:\n              if opt in query_handler.word_to_idx:\n                queries.append(prefix + [opt] + suffix)\n                in_vocab_options.append(opt)\n              else:\n                print('%d\\t%d\\t%s\\t%s' % (sent_idx, i, opt, float('-inf')), file=f)\n            if queries:\n              log_probs = query_handler.query(queries, batch_size=16)\n              for x, lp in zip(in_vocab_options, log_probs):\n                print('%d\\t%d\\t%s\\t%s' % (sent_idx, i, x, lp), file=f)\n        f.flush()\n\nif __name__ == '__main__':\n  OPTS = parse_args()\n  main()"
  },
  {
    "path": "examples/language/preprocess/preprocess_sst.py",
    "content": "import random, json\n\ndef load_data_sst():\n    # training data\n    path = \"train-nodes.tsv\"\n    data_train_all_nodes = []  \n    with open(path) as file:\n        for line in file.readlines()[1:]:\n            data_train_all_nodes.append({\n                \"sentence\": line.split(\"\\t\")[0],\n                \"label\": int(line.split(\"\\t\")[1])\n            })   \n     \n    # train/dev/test data\n    for subset in [\"train\", \"dev\", \"test\"]:\n        path = \"{}.txt\".format(subset)\n        data = []  \n        with open(path) as file:\n            for line in file.readlines():\n                segs = line[:-1].split(\" \")\n                tokens, word_labels = [], []\n                label = int(segs[0][1])\n                if label < 2: \n                    label = 0\n                elif label >= 3: \n                    label = 1\n                else: \n                    continue\n                for i in range(len(segs) - 1):\n                    if segs[i][0] == \"(\" and segs[i][1] in [\"0\", \"1\", \"2\", \"3\", \"4\"]\\\n                            and segs[i + 1][0] != \"(\":\n                        tokens.append(segs[i + 1][:segs[i + 1].find(\")\")])\n                        word_labels.append(int(segs[i][1]))\n                data.append({\n                    \"label\": label,\n                    \"sentence\": \" \".join(tokens),\n                    \"word_labels\": word_labels\n                })\n        if subset == \"train\":\n            data_train = data\n        elif subset == \"dev\":\n            data_dev = data\n        else:\n            data_test = data\n\n    return data_train_all_nodes, data_train, data_dev, data_test\n\ndef read_scores(split):\n    res = {}\n    with open('{}_lm_scores.txt'.format(split)) as file:\n        line = file.readline().strip().split('\\t')\n        while True:\n            if len(line) < 2: break\n            sentence = line[-1]\n            tokens = 
sentence.lower().split(' ')\n            candidates = [[] for i in range(len(tokens))]\n            while True:\n                line = file.readline().strip().split('\\t')\n                if len(line) != 4: break\n                pos, word, score = int(line[1]), line[2], float(line[3])\n                if score == float('-inf'):\n                    continue\n                if len(candidates[pos]) == 0:\n                    if word != tokens[pos]:\n                        continue\n                elif score < candidates[pos][0][1] - 5.0:\n                    continue\n                candidates[pos].append((word, score))\n            res[sentence] = [[w[0] for w in cand] for cand in candidates]\n    return res\n\ndata_train_all_nodes, data_train, data_dev, data_test = load_data_sst()\ncandidates_dev = read_scores('dev')\ncandidates_test = read_scores('test')\nfor example in data_dev:\n    example['candidates'] = candidates_dev[example['sentence']]\nfor example in data_test:\n    example['candidates'] = candidates_test[example['sentence']]\nwith open('train_all_nodes.json', 'w') as file:\n    file.write(json.dumps(data_train_all_nodes))\nwith open('train.json', 'w') as file:\n    file.write(json.dumps(data_train))\nwith open('dev.json', 'w') as file:\n    file.write(json.dumps(data_dev))\nwith open('test.json', 'w') as file:\n    file.write(json.dumps(data_test))\n"
  },
  {
    "path": "examples/language/train.py",
    "content": "import argparse\nimport random\nimport pickle\nimport os\nimport pdb\nimport time\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import CrossEntropyLoss\nfrom torch.utils.tensorboard import SummaryWriter\nfrom auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput\nfrom auto_LiRPA.utils import MultiAverageMeter, logger, scale_gradients\nfrom auto_LiRPA.eps_scheduler import *\nfrom Transformer.Transformer import Transformer\nfrom lstm import LSTM\nfrom data_utils import load_data, clean_data, get_batches\nfrom oracle import oracle\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--train', action='store_true')\nparser.add_argument('--robust', action='store_true')\nparser.add_argument('--oracle', action='store_true')\nparser.add_argument('--dir', type=str, default='model')\nparser.add_argument('--checkpoint', type=int, default=None)\nparser.add_argument('--data', type=str, default='sst', choices=['sst'])\nparser.add_argument('--seed', type=int, default=0)\nparser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'])\nparser.add_argument('--load', type=str, default=None)\nparser.add_argument('--legacy_loading', action='store_true', help='use a deprecated way of loading checkpoints for previously saved models')\nparser.add_argument('--auto_test', action='store_true')\n\nparser.add_argument('--eps', type=float, default=1.0)\nparser.add_argument('--budget', type=int, default=6)\nparser.add_argument('--method', type=str, default=None,\n                    choices=['IBP', 'IBP+backward', 'IBP+backward_train', 'forward', 'forward+backward'])\n\nparser.add_argument('--model', type=str, default='transformer',\n                    choices=['transformer', 'lstm'])\nparser.add_argument('--num_epochs', type=int, default=25)\nparser.add_argument('--num_epochs_all_nodes', type=int, 
default=20)\nparser.add_argument('--eps_start', type=int, default=1)\nparser.add_argument('--eps_length', type=int, default=10)\nparser.add_argument('--log_interval', type=int, default=100)\nparser.add_argument('--min_word_freq', type=int, default=2)\nparser.add_argument('--batch_size', type=int, default=32)\nparser.add_argument('--oracle_batch_size', type=int, default=1024)\nparser.add_argument('--gradient_accumulation_steps', type=int, default=1)\nparser.add_argument('--max_sent_length', type=int, default=32)\nparser.add_argument('--vocab_size', type=int, default=50000)\nparser.add_argument('--lr', type=float, default=1e-4)\nparser.add_argument('--lr_decay', type=float, default=1)\nparser.add_argument('--grad_clip', type=float, default=10.0)\nparser.add_argument('--num_classes', type=int, default=2)\nparser.add_argument('--num_layers', type=int, default=1)\nparser.add_argument('--num_attention_heads', type=int, default=4)\nparser.add_argument('--hidden_size', type=int, default=64)\nparser.add_argument('--embedding_size', type=int, default=64)\nparser.add_argument('--intermediate_size', type=int, default=128)\nparser.add_argument('--drop_unk', action='store_true')\nparser.add_argument('--hidden_act', type=str, default='relu')\nparser.add_argument('--layer_norm', type=str, default='no_var',\n                    choices=['standard', 'no', 'no_var'])\nparser.add_argument('--loss_fusion', action='store_true')\nparser.add_argument('--dropout', type=float, default=0.1)\nparser.add_argument('--bound_opts_relu', type=str, default='zero-lb')\n\nargs = parser.parse_args()\n\nwriter = SummaryWriter(os.path.join(args.dir, 'log'), flush_secs=10)\nfile_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log'))\nfile_handler.setFormatter(logging.Formatter('%(levelname)-8s %(asctime)-12s %(message)s'))\nlogger.addHandler(file_handler)\n\ndata_train_all_nodes, data_train, data_dev, data_test = load_data(args.data)\nif args.robust:\n    data_dev, data_test = 
clean_data(data_dev), clean_data(data_test)\nif args.auto_test:\n    random.seed(args.seed)\n    random.shuffle(data_test)\n    data_test = data_test[:10]\n    assert args.batch_size >= 10\n    # Use double precision and deterministic algorithm for automatic testing.\n    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'\n    torch.use_deterministic_algorithms(True)\n    torch.set_default_dtype(torch.float64)\n\nlogger.info('Dataset sizes: {}/{}/{}/{}'.format(\n    len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test)))\n\nrandom.seed(args.seed)\nnp.random.seed(args.seed)\ntorch.manual_seed(args.seed)\ntorch.cuda.manual_seed_all(args.seed)\n\ndummy_embeddings = torch.zeros(1, args.max_sent_length, args.embedding_size, device=args.device)\ndummy_labels = torch.zeros(1, dtype=torch.long, device=args.device)\n\nif args.model == 'transformer':\n    dummy_mask = torch.zeros(1, 1, 1, args.max_sent_length, device=args.device)\n    model = Transformer(args, data_train)\nelif args.model == 'lstm':\n    dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device)\n    model = LSTM(args, data_train)\n\ndev_batches = get_batches(data_dev, args.batch_size)\ntest_batches = get_batches(data_test, args.batch_size)\n\nptb = PerturbationSynonym(budget=args.budget)\ndummy_embeddings = BoundedTensor(dummy_embeddings, ptb)\nmodel_ori = model.model_from_embeddings\nbound_opts = { 'activation_bound_option': args.bound_opts_relu, 'exp': 'no-max-input', 'fixed_reducemax_index': True }\nif isinstance(model_ori, BoundedModule):\n    model_bound = model_ori\nelse:\n    model_bound = BoundedModule(\n        model_ori, (dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device)\nmodel.model_from_embeddings = model_bound\nif args.loss_fusion:\n    bound_opts['loss_fusion'] = True\n    model_loss = BoundedModule(\n        CrossEntropyWrapperMultiInput(model_ori),\n        (torch.zeros(1, dtype=torch.long), dummy_embeddings, dummy_mask),\n        
bound_opts=bound_opts, device=args.device)\n\nptb.model = model\noptimizer = model.build_optimizer()\nif args.lr_decay < 1:\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=args.lr_decay)\nelse:\n    lr_scheduler = None\nif args.robust:\n    eps_scheduler = LinearScheduler(args.eps, 'start={},length={}'.format(args.eps_start, args.eps_length))\n    for i in range(model.checkpoint):\n        eps_scheduler.step_epoch(verbose=False)\nelse:\n    eps_scheduler = None\nlogger.info('Model converted to support bounds')\n\ndef step(model, ptb, batch, eps=1.0, train=False):\n    model_bound = model.model_from_embeddings\n    if train:\n        model.train()\n        model_bound.train()\n        grad = torch.enable_grad()\n        if args.loss_fusion:\n            model_loss.train()\n    else:\n        model.eval()\n        model_bound.eval()\n        grad = torch.no_grad()\n    if args.auto_test:\n        grad = torch.enable_grad()\n\n    with grad:\n        ptb.set_eps(eps)\n        ptb.set_train(train)\n        embeddings_unbounded, mask, tokens, labels = model.get_input(batch)\n        aux = (tokens, batch)\n        if args.robust and eps > 1e-9:\n            embeddings = BoundedTensor(embeddings_unbounded, ptb)\n        else:\n            embeddings = embeddings_unbounded.detach().requires_grad_(True)\n\n        robust = args.robust and eps > 1e-6\n\n        if train and robust and args.loss_fusion:\n            # loss_fusion loss\n            if args.method == 'IBP+backward_train':\n                lb, ub = model_loss.compute_bounds(\n                    x=(labels, embeddings, mask), aux=aux,\n                    C=None, method='IBP+backward', bound_lower=False)\n            else:\n                raise NotImplementedError\n            loss_robust = torch.log(ub).mean()\n            loss = acc = acc_robust = -1 # unknown\n        else:\n            # regular loss\n            logits = model_bound(embeddings, mask)\n            loss = 
CrossEntropyLoss()(logits, labels)\n            acc = (torch.argmax(logits, dim=1) == labels).float().mean()\n\n            if robust:\n                num_class = args.num_classes\n                c = torch.eye(num_class).type_as(embeddings)[labels].unsqueeze(1) - \\\n                    torch.eye(num_class).type_as(embeddings).unsqueeze(0)\n                I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))\n                c = (c[I].view(embeddings.size(0), num_class - 1, num_class))\n                if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:\n                    lb, ub = model_bound.compute_bounds(aux=aux, C=c, method=args.method, bound_upper=False)\n                elif args.method == 'IBP+backward_train':\n                    # CROWN-IBP\n                    if 1 - eps > 1e-4:\n                        lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False)\n                        ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True)\n                        lb = eps * ilb + (1 - eps) * lb\n                    else:\n                        lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP')\n                else:\n                    raise NotImplementedError\n                lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n                fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n                loss_robust = robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)\n                acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float())\n            else:\n                acc_robust, loss_robust = acc, loss\n\n    if train or args.auto_test:\n        loss_robust.backward()\n        grad_embed = torch.autograd.grad(\n            embeddings_unbounded, model.word_embeddings.weight,\n            
grad_outputs=embeddings.grad)[0]\n        if model.word_embeddings.weight.grad is None:\n            model.word_embeddings.weight.grad = grad_embed\n        else:\n            model.word_embeddings.weight.grad += grad_embed\n\n    if args.auto_test:\n        print('Saving results for automated tests.')\n        print(f'acc={acc}, loss={loss}, robust_acc={acc_robust}, robust_loss={loss_robust}')\n        print('gradients:')\n        print(grad_embed)\n        with open('res_test.pkl', 'wb') as file:\n            pickle.dump((\n                float(acc), float(loss), float(acc_robust), float(loss_robust),\n                grad_embed.detach().numpy()), file)\n\n    return acc, loss, acc_robust, loss_robust\n\ndef train(epoch, batches, type):\n    meter = MultiAverageMeter()\n    assert(optimizer is not None)\n    train = type == 'train'\n    if args.robust:\n        eps_scheduler.set_epoch_length(len(batches))\n        if train:\n            eps_scheduler.train()\n            eps_scheduler.step_epoch()\n        else:\n            eps_scheduler.eval()\n    for i, batch in enumerate(batches):\n        if args.robust:\n            eps_scheduler.step_batch()\n            eps = eps_scheduler.get_eps()\n        else:\n            eps = 0\n        acc, loss, acc_robust, loss_robust = step(\n            model, ptb, batch, eps=eps, train=train)\n        meter.update('acc', acc, len(batch))\n        meter.update('loss', loss, len(batch))\n        meter.update('acc_rob', acc_robust, len(batch))\n        meter.update('loss_rob', loss_robust, len(batch))\n        if train:\n            if (i + 1) % args.gradient_accumulation_steps == 0 or (i + 1) == len(batches):\n                scale_gradients(optimizer, i % args.gradient_accumulation_steps + 1, args.grad_clip)\n                optimizer.step()\n                optimizer.zero_grad()\n            if lr_scheduler is not None:\n                lr_scheduler.step()\n            writer.add_scalar('loss_train_{}'.format(epoch), 
meter.avg('loss'), i + 1)\n            writer.add_scalar('loss_robust_train_{}'.format(epoch), meter.avg('loss_rob'), i + 1)\n            writer.add_scalar('acc_train_{}'.format(epoch), meter.avg('acc'), i + 1)\n            writer.add_scalar('acc_robust_train_{}'.format(epoch), meter.avg('acc_rob'), i + 1)\n        if (i + 1) % args.log_interval == 0 or (i + 1) == len(batches):\n            logger.info('Epoch {}, {} step {}/{}: eps {:.5f}, {}'.format(\n                epoch, type, i + 1, len(batches), eps, meter))\n            if lr_scheduler is not None:\n                logger.info('lr {}'.format(lr_scheduler.get_lr()))\n    writer.add_scalar('loss/{}'.format(type), meter.avg('loss'), epoch)\n    writer.add_scalar('loss_robust/{}'.format(type), meter.avg('loss_rob'), epoch)\n    writer.add_scalar('acc/{}'.format(type), meter.avg('acc'), epoch)\n    writer.add_scalar('acc_robust/{}'.format(type), meter.avg('acc_rob'), epoch)\n\n    if train:\n        if args.loss_fusion:\n            state_dict_loss = model_loss.state_dict()\n            state_dict = {}\n            for name in state_dict_loss:\n                assert(name.startswith('model.'))\n                state_dict[name[6:]] = state_dict_loss[name]\n            model_ori.load_state_dict(state_dict)\n            model_bound = BoundedModule(\n                model_ori, (dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device)\n            model.model_from_embeddings = model_bound\n        model.save(epoch)\n\n    return meter.avg('acc_rob')\n\ndef main():\n    if args.train:\n        for t in range(model.checkpoint, args.num_epochs):\n            if t + 1 <= args.num_epochs_all_nodes:\n                train(t + 1, get_batches(data_train_all_nodes, args.batch_size), 'train')\n            else:\n                train(t + 1, get_batches(data_train, args.batch_size), 'train')\n            train(t + 1, dev_batches, 'dev')\n            train(t + 1, test_batches, 'test')\n    elif args.oracle:\n     
   oracle(args, model, ptb, data_test, 'test')\n    else:\n        if args.robust:\n            for i in range(args.num_epochs):\n                eps_scheduler.step_epoch(verbose=False)\n            res = []\n            for i in range(1, args.budget + 1):\n                logger.info('budget {}'.format(i))\n                ptb.budget = i\n                acc_rob = train(None, test_batches, 'test')\n                res.append(acc_rob)\n            logger.info('Verification results:')\n            for i in range(len(res)):\n                logger.info('budget {} acc_rob {:.3f}'.format(i + 1, res[i]))\n            logger.info(res)\n        else:\n            train(None, test_batches, 'test')\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/sequence/.gitignore",
    "content": "model/\ndata/"
  },
  {
    "path": "examples/sequence/__init__.py",
    "content": ""
  },
  {
    "path": "examples/sequence/data_utils.py",
    "content": "import random\nfrom torchvision import transforms\nfrom torchvision.datasets.mnist import MNIST as mnist\n\ndef load_data():\n    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n    data_train = mnist(\"data\", train=True, download=True, transform=transform)\n    data_test = mnist(\"data\", train=False, download=True, transform=transform)\n    data_train = [data_train[i] for i in range(len(data_train))]\n    data_test = [data_test[i] for i in range(len(data_test))]\n    return data_train, data_test\n\ndef get_batches(data, batch_size):\n    batches = []\n    random.shuffle(data)\n    for i in range((len(data) + batch_size - 1) // batch_size):\n        batches.append(data[i * batch_size : (i + 1) * batch_size])\n    return batches"
  },
  {
    "path": "examples/sequence/lstm.py",
    "content": "import os\nimport shutil\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA.utils import logger\n\nclass LSTMCore(nn.Module):\n    def __init__(self, args):\n        super(LSTMCore, self).__init__()\n\n        self.input_size = args.input_size // args.num_slices\n        self.hidden_size = args.hidden_size\n        self.num_classes = args.num_classes\n        self.device = args.device\n\n        self.cell_f = nn.LSTMCell(self.input_size, self.hidden_size)\n        self.linear = nn.Linear(self.hidden_size, self.num_classes)\n\n    def forward(self, X):\n        batch_size, length = X.shape[0], X.shape[1]\n        h_f = torch.zeros(batch_size, self.hidden_size).to(X.device)\n        c_f = h_f.clone()\n        h_f_sum = h_f.clone()\n        for i in range(length):\n            h_f, c_f = self.cell_f(X[:, i], (h_f, c_f))\n            h_f_sum = h_f_sum + h_f\n        states = h_f_sum / float(length)\n        logits = self.linear(states)\n        return logits\n\nclass LSTM(nn.Module):\n    def __init__(self, args):\n        super(LSTM, self).__init__()\n        self.args = args\n        self.device = args.device\n        self.lr = args.lr\n        self.num_slices = args.num_slices\n\n        self.dir = args.dir\n        if not os.path.exists(self.dir):\n            os.makedirs(self.dir)\n        self.checkpoint = 0\n        self.model = LSTMCore(args)\n        if args.load:\n            self.model.load_state_dict(args.load)\n            logger.info(f\"Model loaded: {args.load}\")\n        else:\n            logger.info(\"Model initialized\")\n        self.model = self.model.to(self.device)\n        self.core = self.model\n\n    def save(self, epoch):\n        output_dir = os.path.join(self.dir, \"ckpt-%d\" % epoch)\n        if os.path.exists(output_dir):\n            shutil.rmtree(output_dir)\n        os.mkdir(output_dir)\n        path = os.path.join(output_dir, \"model\")\n        torch.save(self.core.state_dict(), path)\n        with 
open(os.path.join(self.dir, \"checkpoint\"), \"w\") as file:\n            file.write(str(epoch))\n        logger.info(\"LSTM saved: %s\" % output_dir)\n\n    def build_optimizer(self):\n        param_group = []\n        for p in self.core.named_parameters():\n            param_group.append(p)\n        param_group = [{\"params\": [p[1] for p in param_group], \"weight_decay\": 0.}]\n        return torch.optim.Adam(param_group, lr=self.lr)\n\n    def get_input(self, batch):\n        X = torch.cat([example[0].reshape(1, self.num_slices, -1) for example in batch])\n        y = torch.tensor([example[1] for example in batch], dtype=torch.long)\n        return X.to(self.device), y.to(self.device)\n\n    def train(self):\n        self.core.train()\n\n    def eval(self):\n        self.core.eval()"
  },
  {
    "path": "examples/sequence/train.py",
    "content": "import argparse\nimport random\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom lstm import LSTM\nfrom data_utils import load_data, get_batches\nfrom auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm\nfrom auto_LiRPA.utils import MultiAverageMeter, logger, get_spec_matrix\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--seed\", type=int, default=0)\nparser.add_argument(\"--load\", type=str, default=None)\nparser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cuda\", \"cpu\"])\nparser.add_argument(\"--norm\", type=int, default=np.inf)\nparser.add_argument(\"--eps\", type=float, default=0.1)\nparser.add_argument(\"--num_epochs\", type=int, default=20)\nparser.add_argument(\"--batch_size\", type=int, default=512)\nparser.add_argument(\"--num_slices\", type=int, default=8)\nparser.add_argument(\"--hidden_size\", type=int, default=256)\nparser.add_argument(\"--num_classes\", type=int, default=10)\nparser.add_argument(\"--input_size\", type=int, default=784)\nparser.add_argument(\"--lr\", type=float, default=1e-2)\nparser.add_argument(\"--dir\", type=str, default=\"model\", help=\"directory to load or save the model\")\nparser.add_argument(\"--num_epochs_warmup\", type=int, default=10, help=\"number of epochs for the warmup stage when eps is linearly increased from 0 to the full value\")\nparser.add_argument(\"--log_interval\", type=int, default=10, help=\"interval of printing the log during training\")\nargs = parser.parse_args()\n\n\n## Train or test one batch.\ndef step(model, ptb, batch, eps=args.eps, train=False):\n    # We increase the perturbation each batch.\n    ptb.set_eps(eps)\n    # We create a BoundedTensor object with current batch of data.\n    X, y = model.get_input(batch)\n    X = BoundedTensor(X, ptb)\n    logits = model.core(X)\n\n    # Form the linear speicifications, which are margins of ground truth class and other classes.\n    num_class 
= args.num_classes\n    c = get_spec_matrix(X, y, num_class)\n\n    # Compute CROWN-IBP (IBP+backward) bounds for training. We only need the lower bound.\n    # Here we can omit the x=(X,) argument because we have just used X for forward propagation.\n    lb, ub = model.core.compute_bounds(C=c, method='CROWN-IBP', bound_upper=False)\n\n    # Compute robust cross entropy loss.\n    lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n    fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n    loss = nn.CrossEntropyLoss()(-lb_padded, fake_labels)\n\n    # Report accuracy and robust accuracy.\n    acc = (torch.argmax(logits, dim=-1) == y).float().mean()\n    acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float())\n\n    if train:\n        loss.backward()\n\n    return acc.detach(), acc_robust.detach(), loss.detach()\n\n\n## Train one epoch.\ndef train(epoch):\n    meter = MultiAverageMeter()\n    model.train()\n    # Load data for an epoch.\n    train_batches = get_batches(data_train, args.batch_size)\n\n    eps_inc_per_step = 1.0 / (args.num_epochs_warmup * len(train_batches))\n\n    for i, batch in enumerate(train_batches):\n        # We increase eps linearly every batch.\n        eps = args.eps * min(eps_inc_per_step * ((epoch - 1) * len(train_batches) + i + 1), 1.0)\n        # Run one training step.\n        acc, acc_robust, loss = step(model, ptb, batch, eps=eps, train=True)\n        # Optimize the loss.\n        torch.nn.utils.clip_grad_norm_(model.core.parameters(), 5.0)\n        optimizer.step()\n        optimizer.zero_grad()\n        meter.set_batch_size(len(batch))\n        meter.update('acc', acc)\n        meter.update('acc_rob', acc_robust)\n        meter.update('loss', loss)\n        if (i + 1) % args.log_interval == 0:\n            logger.info(\"Epoch {}, training step {}/{}: {}, eps {:.3f}\".format(\n                epoch, i + 1, len(train_batches), meter, eps))\n   
 model.save(epoch)\n\n\n## Test accuracy and robust accuracy.\ndef test(epoch, batches):\n    meter = MultiAverageMeter()\n    model.eval()\n    for batch in batches:\n        acc, acc_robust, loss = step(model, ptb, batch)\n        meter.set_batch_size(len(batch))\n        meter.update('acc', acc)\n        meter.update('acc_rob', acc_robust)\n        meter.update('loss', loss)\n    logger.info(\"Epoch {} test: {}\".format(epoch, meter))\n\n# Load MNIST dataset\nlogger.info(\"Loading data...\")\ndata_train, data_test = load_data()\nlogger.info(\"Dataset sizes: {}/{}\".format(len(data_train), len(data_test)))\ntest_batches = get_batches(data_test, args.batch_size)\n\n# Set all random seeds.\nrandom.seed(args.seed)\nnp.random.seed(args.seed)\ntorch.manual_seed(args.seed)\ntorch.cuda.manual_seed_all(args.seed)\n\n# Create an LSTM sequence classifier.\nlogger.info(\"Creating LSTM model...\")\nmodel = LSTM(args).to(args.device)\nX, y = model.get_input(test_batches[0])\n# Create the perturbation object once here, and we can reuse it.\nptb = PerturbationLpNorm(norm=args.norm, eps=args.eps)\n# Convert the LSTM to BoundedModule\nX = BoundedTensor(X, ptb)\nmodel.core = BoundedModule(model.core, (X,), device=args.device)\noptimizer = model.build_optimizer()\n\n# Main training loop.\nfor t in range(model.checkpoint, args.num_epochs):\n    train(t + 1)\n    test(t + 1, test_batches)\n\n# If the loaded model has already reached the last epoch, test it directly.\nif model.checkpoint == args.num_epochs:\n    test(args.num_epochs, test_batches)\n\n"
  },
  {
    "path": "examples/simple/invprop.py",
    "content": "\"\"\"\nA toy example for bounding neural network outputs under input perturbations using INVPROP\n\nSee https://arxiv.org/abs/2302.01404\n\"\"\"\nimport torch\nfrom collections import defaultdict\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\n\nclass simple_model(torch.nn.Module):\n    \"\"\"\n    A very simple 2-layer neural network for demonstration.\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n        # Weights of linear layers.\n        self.w1 = torch.tensor([[1., -1.], [2., -1.]])\n        self.w2 = torch.tensor([[1., -1.]])\n\n    def forward(self, x):\n        # Linear layer.\n        z1 = x.matmul(self.w1.t())\n        # Relu layer.\n        hz1 = torch.nn.functional.relu(z1)\n        # Linear layer.\n        z2 = hz1.matmul(self.w2.t())\n        return z2\n\n\nmodel = simple_model()\n\n# Input x.\nx = torch.tensor([[1., 1.]])\n# Lowe and upper bounds of x.\nlower = torch.tensor([[-1., -2.]])\nupper = torch.tensor([[2., 1.]])\n\n# Compute bounds using LiRPA using the given lower and upper bounds.\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, x_L=lower, x_U=upper)\nbounded_x = BoundedTensor(x, ptb)\n\n# INVPROP configuration\n# apply_output_constraints_to: list of layer names or types to which the output\n#     constraints should be applied. Here, they will be applied to all layers of type\n#     'BoundMatMul' and 'BoundInput'. To only apply them to specific layers, use their\n#     names, e.g. ['/0', '/z1']. The currently recommended way to get those names is\n#     either to first construct an instance of BoundedModule with arbitrary bound_opts,\n#     print it to stdout and inspect their names manually, or to access the layer names\n#     as lirpa_model.final_node().inputs[0].inputs[0].name\n# tighten_input_bounds: whether to tighten the input bounds. This will modify the\n#     perturbation of the input. 
If set, apply_output_constraints_to should contain\n#     'BoundInput' or the corresponding layer name. Otherwise, this will have no effect.\n#     Similarly, adding 'BoundInput' to apply_output_constraints_to will have no effect\n#     unless tighten_input_bounds is set.\n# best_of_oc_and_no_oc: Using output constraints may sometimes lead to worse results,\n#     because the optimization might find bad local minima. If this is set to True,\n#     every optimization step will be run twice, once with and once without output\n#     constraints, and the better result will be chosen.\n# directly_optimize: Usually, only linear layers preceding non-linear layers are\n#     optimized using output constraints. If you want to optimize a specific layer that\n#     would usually be skipped, add its name to this list. This is most likely to be\n#     used when preimages should be computed as they might use linear combinations of\n#     the inputs. This requires the use of sequential linear layers. 
For detailed\n#     examples, see https://github.com/kothasuhas/verify-input\n# oc_lr: Learning rate for the optimization of output constraints.\n# share_gammas: Whether neurons in each layer should share the same gamma\n\nlirpa_model = BoundedModule(model, torch.empty_like(x), bound_opts={\n    'optimize_bound_args': {\n        'apply_output_constraints_to': ['BoundMatMul', 'BoundInput'],\n        'tighten_input_bounds': True,\n        'best_of_oc_and_no_oc': False,\n        'directly_optimize': [],\n        'oc_lr': 0.1,\n        'share_gammas': False,\n        'iteration': 1000,\n    }\n})\n# To dynamically set the apply_output_constraints_to option, set it to `[]` in the\n# above code, and then use the following:\n# lirpa_model.set_bound_opts({\n#   'optimize_bound_args': {\n#     'apply_output_constraints_to': [\n#       lirpa_model.final_node().inputs[0].inputs[0].inputs[0].name,\n#       lirpa_model.final_node().inputs[0].inputs[0].name,\n#     ]\n#   }\n# })\n\n# The scalar output must be <= -1\n# Constraints have the shape [1, num_constraints, num_output_neurons]\n# They are treated as conjunctions, i.e., all constraints must be satisfied.\nlirpa_model.constraints = torch.ones(1,1,1)\n# Thresholds have the shape [num_constraints]\nlirpa_model.thresholds = torch.tensor([-1.])\n\nprint(f\"Original perturbation: x0: [{ptb.x_L[0][0]}, {ptb.x_U[0][0]}], x1: [{ptb.x_L[0][1]}, {ptb.x_U[0][1]}]\")\nlb, ub = lirpa_model.compute_bounds(x=(bounded_x,), method='alpha-CROWN')\ntightened_ptb = lirpa_model['/0'].perturbation\nprint(f\"Tightened perturbation: x0: [{tightened_ptb.x_L[0][0]}, {tightened_ptb.x_U[0][0]}], x1: [{tightened_ptb.x_L[0][1]}, {tightened_ptb.x_U[0][1]}]\")\n\n# For the bounds without output constraints, refer to toy.py\nprint(f'alpha-CROWN bounds without output constraints: lower=-3, upper=2')\nprint(f'alpha-CROWN bounds with output constraints: lower={lb.item()}, upper={ub.item()}')"
  },
  {
    "path": "examples/simple/lp_full.py",
    "content": "\"\"\"\nA simple example for bounding neural network outputs using LP/MIP solvers.\n\nAuto_LiRPA supports constructing LP/MIP optimization formulations (using\nGurobi).  This example uses LP to solve all intermediate layer bounds and\nfinal layer bounds, reflecting the setting in the paper \"A Convex\nRelaxation Barrier to Tight Robustness Verification of Neural Networks\".\nThis is sometimes referred to as the LP-Full setting. This is in general,\nvery slow; alpha-CROWN is generally recommended to compute intermediate\nlayer bound rather than LP.\n\nExample usage: python lp_full.py --index 0 --norm 2.0 --perturbation 1.0\n\nHere `--index` is the dataset index (MNIST in this example), `--norm` is\nthe Lp perturbation norm used and `--perturbation` is the magnitude of\nthe perturbation added to model input.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.operators import BoundLinear, BoundConv\nimport gurobipy as grb\nimport time\nimport numpy as np\nimport argparse\n\n# Help function for generating output matrix. This function used for \n# generating matrix C to calculate the margin between true class and \n# the other classes.\ndef build_C(label, classes):\n    \"\"\"\n    label: shape (B,). 
Each label[b] in [0..classes-1].\n    Return:\n        C: shape (B, classes-1, classes).\n        For each sample b, each row is a \"negative class\" among [0..classes-1]\\{label[b]}.\n        Puts +1 at column=label[b], -1 at each negative class column.\n    \"\"\"\n    device = label.device\n    batch_size = label.size(0)\n    \n    # 1) Initialize\n    C = torch.zeros((batch_size, classes-1, classes), device=device)\n    \n    # 2) All class indices\n    # shape: (1, K) -> (B, K)\n    all_cls = torch.arange(classes, device=device).unsqueeze(0).expand(batch_size, -1)\n    \n    # 3) Negative classes only, shape (B, K-1)\n    # mask out the ground-truth\n    mask = all_cls != label.unsqueeze(1)\n    neg_cls = all_cls[mask].view(batch_size, -1)\n    \n    # 4) Scatter +1 at each sample’s ground-truth label\n    #    shape needed: (B, K-1, 1)\n    pos_idx = label.unsqueeze(1).expand(-1, classes-1).unsqueeze(-1)\n    C.scatter_(dim=2, index=pos_idx, value=1.0)\n    \n    # 5) Scatter -1 at each row’s negative label\n    #    We have (B, K-1) negative labels. 
For row j in each sample b, neg_cls[b, j] is that row’s negative label\n    row_idx = torch.arange(classes-1, device=device).unsqueeze(0).expand(batch_size, -1)\n    # shape: (B, K-1)\n    \n    # We can do advanced indexing:\n    C[torch.arange(batch_size).unsqueeze(1), row_idx, neg_cls] = -1.0\n    \n    return C\n    \nparser = argparse.ArgumentParser()\nparser.add_argument('--index', default=0, type=int, help='Index of data example (from MNIST dataset).')\nparser.add_argument('--norm', default='inf', type=str, help='Input perturbation norm.')\nparser.add_argument('--perturbation', default=0.05, type=float, help='Input perturbation magnitude.')\nparser.add_argument('--lr', default=0.5, type=float, help='Learning rate for alpha_crown.')\nparser.add_argument('--iteration', default=30, type=int, help='Iterations for alpha_crown.')\nargs = parser.parse_args()\n\n## Step 1: Define computational graph by implementing forward()\n# You can create your own model here.\nmodel = nn.Sequential(\n\t\tnn.Flatten(),\n\t\tnn.Linear(784, 100),\n\t\tnn.ReLU(),\n\t\tnn.Linear(100, 100),\n\t\tnn.ReLU(),\n\t\tnn.Linear(100, 10)\n\t)\n# Optionally, load the pretrained weights.\ncheckpoint = torch.load('./models/spectral_NOR_MLP_B.pth', weights_only=True)\nmodel.load_state_dict(checkpoint)\n\n## Step 2: Prepare dataset.\ntest_data = torchvision.datasets.MNIST(\n    './data', train=False, download=True,\n    transform=torchvision.transforms.ToTensor())\n\nn_classes = 10\nimage = test_data.data[args.index].to(torch.float32).unsqueeze(0).unsqueeze(0) / 255.0\ntrue_label = torch.tensor([test_data.targets[args.index]])\n\n## Step 3: Define perturbation.\neps = args.perturbation\nnorm = float(args.norm)\n# The upper bound and lower bound of mnist dataset is [0,1],\n# replace the bounds if using other dataset.\nif norm == float('inf'):\n    x_U = None\n    x_L = None\nelse:\n    x_U = torch.ones_like(image)\n    x_L = torch.zeros_like(image)\nptb = PerturbationLpNorm(norm = norm, eps = eps, 
x_U = x_U, x_L = x_L)\nprint(f'Verification of MNIST data index {args.index} with L{args.norm} perturbation of {args.perturbation}\\n')\n# Here we only use one image as input.\nimage = BoundedTensor(image, ptb)\nprint('Running LP-Full with LPs for all intermediate layers...')\nstart_time = time.time()\n\n## Step 4: Compute the bounds of different methods.\n# For CROWN/alpha-CROWN, we use the compute_bounds() method.\n# For LP and MIP, we use the build_solver_module() method.\ninterm_bounds = {}\nlirpa_model = BoundedModule(model, image, device=image.device)\n# Store the output shape for each layer first\nfor node in lirpa_model.nodes():\n    # For each intermediate layer, we first set its bounds to infinity as placeholders.\n    if hasattr(node, 'output_shape'):\n        interm_lb = torch.full(node.output_shape, -float('inf'))\n        interm_ub = torch.full(node.output_shape, float('inf'))\n        interm_bounds[node.name] = [interm_lb, interm_ub]\n\n# C is the specification matrix (groundtruth - target class).\nC = build_C(true_label, classes=n_classes)\n# Here we assume that the last node is the model output, and we start from intermediate layers first.\n# Technically, here we need a topological sort of all model nodes if the computation graph is general.\nfor node in lirpa_model.nodes():\n    # For simplicity, we assume the model contains linear, conv, and ReLU layers.\n    # We need to calculate the preactivation bounds before each ReLU layer, which are the bounds of the linear or conv layers.\n    if isinstance(node, (BoundLinear, BoundConv)):\n        interm_lb = torch.full(node.output_shape, -float('inf'))\n        interm_ub = torch.full(node.output_shape, float('inf'))\n        if node.is_final_node:\n            print(f'Solving LPs for final layer bounds...')\n            # Last node, all intermediate layer bounds have been obtained.\n            # For the last node, we need to use the specification matrix C to calculate the bounds on groundtruth - target 
labels.\n            solver_vars = lirpa_model.build_solver_module(model_type='lp', x=(image,), final_node_name=node.name, interm_bounds=interm_bounds, C=C)\n            lirpa_model.solver_model.setParam('OutputFlag', 0)\n            final_lb = torch.empty(n_classes-1)\n            final_ub = torch.empty(n_classes-1)\n            for i in range(n_classes-1):\n                print(f'Solving class {i}...')\n                # Now you can define objectives based on the variables on the output layer.\n                # And then solve them using gurobi. Here we just output the lower and upper\n                # bounds for each output neuron.\n                # Solve upper bound.\n                lirpa_model.solver_model.setObjective(solver_vars[i], grb.GRB.MAXIMIZE)\n                lirpa_model.solver_model.optimize()\n                # If the solver does not terminate, you will get a NaN.\n                if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                    final_ub[i] = lirpa_model.solver_model.objVal\n                # Solve lower bound.\n                lirpa_model.solver_model.setObjective(solver_vars[i], grb.GRB.MINIMIZE)\n                lirpa_model.solver_model.optimize()\n                if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                    final_lb[i] = lirpa_model.solver_model.objVal\n        else:\n            print(f'Solving LPs for layer {node.name} intermediate layer bounds...')\n            # Solve intermediate layer bounds, one by one.\n            solver_vars = lirpa_model.build_solver_module(model_type='lp', x=(image,), final_node_name=node.name, interm_bounds=interm_bounds)\n            lirpa_model.solver_model.setParam('OutputFlag', 0)\n            # For linear layer, the solver_vars shape is: (neurons).\n            if isinstance(node, BoundLinear):\n                for i, var in enumerate(solver_vars):\n                    lirpa_model.solver_model.setObjective(var, grb.GRB.MAXIMIZE)\n      
              lirpa_model.solver_model.optimize()\n                    if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                        interm_ub[0][i] = lirpa_model.solver_model.objVal\n                    # Solve lower bound.\n                    lirpa_model.solver_model.setObjective(var, grb.GRB.MINIMIZE)\n                    lirpa_model.solver_model.optimize()\n                    if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                        interm_lb[0][i] = lirpa_model.solver_model.objVal\n            # For convolutional layer, the solver_vars shape is (channel, out_w, out_h).\n            elif isinstance(node, BoundConv):\n                for i,channel in enumerate(solver_vars):\n                    for j, row in enumerate(channel):\n                        for k, var in enumerate(row):\n                            lirpa_model.solver_model.setObjective(var, grb.GRB.MAXIMIZE)\n                            lirpa_model.solver_model.optimize()\n                            if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                                interm_ub[0][i][j][k] = lirpa_model.solver_model.objVal\n                            # Solve lower bound.\n                            lirpa_model.solver_model.setObjective(var, grb.GRB.MINIMIZE)\n                            lirpa_model.solver_model.optimize()\n                            if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                                interm_lb[0][i][j][k] = lirpa_model.solver_model.objVal\n            interm_bounds[node.name] = [interm_lb, interm_ub]\n        print(f'Finished solving layer {node.name} with {len(solver_vars)} neurons')\nend_time = time.time()\nlp_time = end_time - start_time\nprint(f'LP-Full time: {lp_time}\\n')\n\nlirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)\nlirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': args.iteration, 'lr_alpha': 
args.lr}})\nstart_time = time.time()\nprint(f'Running alpha-CROWN with {args.iteration} iterations and learning rate of {args.lr}...')\ncrown_lb, crown_ub = lirpa_model.compute_bounds(x=(image, ), C=C, method='alpha-CROWN')\nend_time = time.time()\nalpha_crown_time = end_time - start_time\nprint(f'alpha-CROWN time: {alpha_crown_time}')\n\n# Step 5: output the final results of each method.\nprint(f'\\nResults for dataset index: {args.index}')\nprint(f'LP-Full bounds:')\nfor i in range(n_classes-1):\n    if i == true_label.item():\n        label = i + 1\n    else:\n        label = i\n    print('{l:8.3f} <= f_{k} - f_{j} <= {u:8.3f}'.format(\n        k=true_label.item(), j=label, l=final_lb[i].item(), u=final_ub[i].item()))\n\n# Alpha-CROWN should achieve similar results as LP full but without running any LPs.\nprint(f'\\nalpha-CROWN bounds:')\nfor i in range(n_classes-1):\n    if i == true_label.item():\n        label = i + 1\n    else:\n        label = i\n    print('{l:8.3f} <= f_{k} - f_{j} <= {u:8.3f}'.format(\n        k=true_label.item(), j=label, l=crown_lb[0][i].item(), u=crown_ub[0][i].item()))\nprint(f'alpha-CROWN bounds and LP-full bounds should be close for Linf norm; '\n      'adjust the number of iterations and learning rate when necessary.\\n')\n"
  },
  {
    "path": "examples/simple/mip_lp_solver.py",
    "content": "\"\"\"\nA simple example for bounding neural network outputs using LP/MIP solvers.\n\nAuto_LiRPA supports constructing LP/MIP optimization formulations (using\nGurobi).  This example serves as a skeleton for using the build_solver_module()\nmethod to obtain LP/MIP formulations of neural networks.\n\nNote that alpha-CROWN is used to calculate intermediate layer bounds for\nconstructing the convex relaxation of ReLU neurons. So we are actually using\n\"alpha-CROWN+MIP\" or \"alpha-CROWN+LP\" here. Calculating intermediate layer\nbounds using LP/MIP is often impractical due to the high cost.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nimport gurobipy as grb\n\n## Step 1: Define computational graph by implementing forward()\n# You can create your own model here.\nclass mnist_model(nn.Module):\n    def __init__(\n            self, input_size=28*28, hidden_size=128,\n            hidden_size_2=64, output_size=10):\n        super(mnist_model, self).__init__()\n        self.fc1 = nn.Linear(input_size, hidden_size)\n        self.fc2 = nn.Linear(hidden_size, hidden_size_2)\n        self.fc3 = nn.Linear(hidden_size_2, output_size)\n        self.relu = nn.ReLU()\n        \n    def forward(self, x):\n        x = x.view(-1, 784)\n        out = self.fc1(x)\n        out = self.relu(out)\n        out = self.fc2(out)\n        out = self.relu(out)\n        out = self.fc3(out)\n        return out\n\nmodel = mnist_model()\n# Optionally, load the pretrained weights.\ncheckpoint = torch.load('../vision/pretrained/mnist_fc_3layer.pth')\nmodel.load_state_dict(checkpoint)\n\n## Step 2: Prepare dataset.\ntest_data = torchvision.datasets.MNIST(\n    './data', train=False, download=True,\n    transform=torchvision.transforms.ToTensor())\n# For illustration we only use 1 image from dataset.\nN = 1\nn_classes = 10\nimage = test_data.data[:N].view(N, 
1, 28, 28)\ntrue_label = test_data.targets[:N]\nimage = image.to(torch.float32) / 255.0\n\n## Step 3: Define perturbation.\neps = 0.03\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, eps = eps)\n# Here we only use one image as input\nimage = BoundedTensor(image[0], ptb)\n\n## Step 4: Compute the bounds of different methods.\n# For CROWN/alpha-CROWN, we use the compute_bounds() method.\n# For LP and MIP, we use the build_solver_module() method.\nresult = {}\n# Note that here 'lp' or 'mip' are essentially 'alpha-CROWN+lp' and 'alpha-CROWN+mip'.\n# We use alpha-CROWN to calculate all the intermediate layer bounds for LP/MIP, because\n# using MIP/LP for all intermediate neurons will be very slow.\nfor method in ['alpha-CROWN','lp','mip']:\n    # To get clean results and avoid interference among methods, we create a\n    # new BoundedModule object.  However, in your production code please pay\n    # attention that BoundedModule() has high construction overhead.\n    lirpa_model = BoundedModule(model, torch.empty_like(image[0]), device=image.device)\n    # Call alpha-CROWN first, which gives all intermediate layer bounds.\n    lb, ub = lirpa_model.compute_bounds(x=(image,), method='alpha-CROWN')\n\n    if method != 'alpha-CROWN':\n        lb = torch.full_like(lb, float('nan'))\n        ub = torch.full_like(ub, float('nan'))\n        # Obtain the optimizer (Gurobi) variables for the output layer.\n        # Auto_LiRPA will construct the LP/MIP formulation based on computation graph.\n        # Note that pre-activation bounds are required for using this function.\n        # Preactivation bounds have been computed using alpha-CROWN above.\n        solver_vars = lirpa_model.build_solver_module(model_type=method)\n        # Set some parameters for Gurobi optimizer.\n        lirpa_model.solver_model.setParam('OutputFlag', 0)\n        for i in range(n_classes):\n            print(f'Solving class {i} with method {method}')\n            # Now you can define 
objectives based on the variables on the output layer\n            # and then solve them using Gurobi. Here we just output the lower and upper\n            # bounds for each output neuron.\n            # Solve upper bound.\n            lirpa_model.solver_model.setObjective(solver_vars[i], grb.GRB.MAXIMIZE)\n            lirpa_model.solver_model.optimize()\n            # If the solver does not terminate, you will get a NaN.\n            if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                ub[0][i] = lirpa_model.solver_model.objVal\n            # Solve lower bound.\n            lirpa_model.solver_model.setObjective(solver_vars[i], grb.GRB.MINIMIZE)\n            lirpa_model.solver_model.optimize()\n            if lirpa_model.solver_model.status == grb.GRB.Status.OPTIMAL:\n                lb[0][i] = lirpa_model.solver_model.objVal\n    result[method] = (lb, ub)\n\n## Step 5: Output the final results of each method.\nfor method in result.keys():\n    print(f'Bounding method: {method}')\n    lb, ub = result[method]\n    for i in range(n_classes):\n        print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}'.format(\n            j=i, l=lb[0][i].item(), u=ub[0][i].item()))\n"
  },
  {
    "path": "examples/simple/toy.py",
    "content": "\"\"\"\nA toy example for bounding neural network outputs under input perturbations.\n\"\"\"\nimport torch\nfrom collections import defaultdict\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\n\nclass simple_model(torch.nn.Module):\n    \"\"\"\n    A very simple 2-layer neural network for demonstration.\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n        # Weights of linear layers.\n        self.w1 = torch.tensor([[1., -1.], [2., -1.]])\n        self.w2 = torch.tensor([[1., -1.]])\n\n    def forward(self, x):\n        # Linear layer.\n        z1 = x.matmul(self.w1.t())\n        # Relu layer.\n        hz1 = torch.nn.functional.relu(z1)\n        # Linear layer.\n        z2 = hz1.matmul(self.w2.t())\n        return z2\n\n\nmodel = simple_model()\n\n# Input x.\nx = torch.tensor([[1., 1.]])\n# Lowe and upper bounds of x.\nlower = torch.tensor([[-1., -2.]])\nupper = torch.tensor([[2., 1.]])\n\n# Wrap model with auto_LiRPA for bound computation.\n# The second parameter is for constructing the trace of the computational graph,\n# and its content is not important.\nlirpa_model = BoundedModule(model, torch.empty_like(x))\npred = lirpa_model(x)\nprint(f'Model prediction: {pred.item()}')\n\n# Compute bounds using LiRPA using the given lower and upper bounds.\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, x_L=lower, x_U=upper)\nbounded_x = BoundedTensor(x, ptb)\n\n# Compute bounds.\nlb, ub = lirpa_model.compute_bounds(x=(bounded_x,), method='IBP')\nprint(f'IBP bounds: lower={lb.item()}, upper={ub.item()}')\nlb, ub = lirpa_model.compute_bounds(x=(bounded_x,), method='CROWN')\nprint(f'CROWN bounds: lower={lb.item()}, upper={ub.item()}')\n\n# Getting the linear bound coefficients (A matrix).\nrequired_A = defaultdict(set)\nrequired_A[lirpa_model.output_name[0]].add(lirpa_model.input_name[0])\nlb, ub, A = lirpa_model.compute_bounds(x=(bounded_x,), method='CROWN', 
return_A=True, needed_A_dict=required_A)\nprint('CROWN linear (symbolic) bounds: lA x + lbias <= f(x) <= uA x + ubias, where')\nprint(A[lirpa_model.output_name[0]][lirpa_model.input_name[0]])\n\n# Opimized bounds, which is tighter.\nlb, ub, A = lirpa_model.compute_bounds(x=(bounded_x,), method='alpha-CROWN', return_A=True, needed_A_dict=required_A)\nprint(f'alpha-CROWN bounds: lower={lb.item()}, upper={ub.item()}')\nprint('alpha-CROWN linear (symbolic) bounds: lA x + lbias <= f(x) <= uA x + ubias, where')\nprint(A[lirpa_model.output_name[0]][lirpa_model.input_name[0]])\n"
  },
  {
    "path": "examples/vision/.gitignore",
    "content": "exp\nexp_inv\n__pycache__\nmodel_*\n!model_gurobi.py\nsaved_models\nconfig\n"
  },
  {
    "path": "examples/vision/bound_option.py",
    "content": "\"\"\"\nA simple example for bounding neural network outputs with different bound options on ReLU activation functions.\n\n\"\"\"\nimport os\nfrom collections import defaultdict\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import Flatten\n\n## Step 1: Define computational graph by implementing forward()\n# This simple model comes from https://github.com/locuslab/convex_adversarial\ndef mnist_model():\n    model = nn.Sequential(\n        nn.Conv2d(1, 16, 4, stride=2, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(16, 32, 4, stride=2, padding=1),\n        nn.ReLU(),\n        Flatten(),\n        nn.Linear(32*7*7,100),\n        nn.ReLU(),\n        nn.Linear(100, 10)\n    )\n    return model\n\nmodel = mnist_model()\n# Optionally, load the pretrained weights.\ncheckpoint = torch.load(\n    os.path.join(os.path.dirname(__file__), 'pretrained/mnist_a_adv.pth'),\n    map_location=torch.device('cpu'))\nmodel.load_state_dict(checkpoint)\n\n## Step 2: Prepare dataset as usual\ntest_data = torchvision.datasets.MNIST(\n    './data', train=False, download=True,\n    transform=torchvision.transforms.ToTensor())\n# For illustration we only use one image from dataset\nN = 1\nn_classes = 10\nimage = test_data.data[:N].view(N,1,28,28)\ntrue_label = test_data.targets[:N]\n# Convert to float\nimage = image.to(torch.float32) / 255.0\nif torch.cuda.is_available():\n    image = image.cuda()\n    model = model.cuda()\n\n## Step 3: wrap model with auto_LiRPA\n# Use default bound_option\nlirpa_model_default = BoundedModule(model, torch.empty_like(image), device=image.device)\n# Use same-slope option for ReLU functions\nlirpa_model_sameslope = BoundedModule(model, torch.empty_like(image), device=image.device, \n                                      bound_opts={'activation_bound_option': 'same-slope'})\nprint('Running on', 
image.device)\n\n## Step 4: Compute bounds using LiRPA given a perturbation\neps = 0.3\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, eps = eps)\nimage = BoundedTensor(image, ptb)\n# Get model prediction as usual\npred = lirpa_model_default(image)\nlabel = torch.argmax(pred, dim=1).cpu().detach().numpy()\n\nprint()\nprint('Demonstration 1.1: Bound computation and comparisons of different options.')\n## Step 5: Compute bounds for final output\nprint('Bounding method:', 'backward (CROWN)')\nprint('Bounding option:', 'Default (adaptive)')\nlb, ub = lirpa_model_default.compute_bounds(x=(image,), method='backward')\nfor i in range(N):\n    print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}')\n    for j in range(n_classes):\n        indicator = '(ground-truth)' if j == true_label[i] else ''\n        print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format(\n            j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))\nprint()\n\nprint('Bounding option:', 'same-slope')\nlb, ub = lirpa_model_sameslope.compute_bounds(x=(image,), method='backward')\nfor i in range(N):\n    print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}')\n    for j in range(n_classes):\n        indicator = '(ground-truth)' if j == true_label[i] else ''\n        print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format(\n            j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))\nprint()\n\nprint('Demonstration 1.2: same-slope option is also available with CROWN-Optimized')\nprint('Bounding method:', 'CROWN-Optimized (alpha-CROWN)')\nprint('Bounding option:', 'Default (adaptive)')\nlb, ub = lirpa_model_default.compute_bounds(x=(image,), method='CROWN-Optimized')\nfor i in range(N):\n    print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}')\n    for j in range(n_classes):\n        indicator = '(ground-truth)' if j == true_label[i] else ''\n        print('f_{j}(x_0): 
{l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format(\n            j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))\nprint()\n\nprint('Bounding option:', 'same-slope')\nlb, ub = lirpa_model_sameslope.compute_bounds(x=(image,), method='CROWN-Optimized')\nfor i in range(N):\n    print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}')\n    for j in range(n_classes):\n        indicator = '(ground-truth)' if j == true_label[i] else ''\n        print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format(\n            j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))\nprint()\n\n\nprint('Demonstration 2: Obtaining linear coefficients of the lower and upper bounds.')\nprint('With same-slope option, two linear coefficients should be the same.')\n# There are many bound coefficients during CROWN bound calculation; here we are interested in the linear bounds\n# of the output layer, with respect to the input layer (the image).\nrequired_A = defaultdict(set)\nrequired_A[lirpa_model_sameslope.output_name[0]].add(lirpa_model_sameslope.input_name[0])\n\nprint(\"Bounding method:\", 'backward')\nprint(\"Bounding option:\", 'same-slope')\nlb, ub, A_dict = lirpa_model_sameslope.compute_bounds(x=(image,), method='backward', return_A=True, needed_A_dict=required_A)\nlower_A, lower_bias = A_dict[lirpa_model_sameslope.output_name[0]][lirpa_model_sameslope.input_name[0]]['lA'], A_dict[lirpa_model_sameslope.output_name[0]][lirpa_model_sameslope.input_name[0]]['lbias']\nupper_A, upper_bias = A_dict[lirpa_model_sameslope.output_name[0]][lirpa_model_sameslope.input_name[0]]['uA'], A_dict[lirpa_model_sameslope.output_name[0]][lirpa_model_sameslope.input_name[0]]['ubias']\nprint(f'lower bound linear coefficients size (batch, output_dim, *input_dims): {list(lower_A.size())}')\nprint(f'lower bound bias term size (batch, output_dim): {list(lower_bias.size())}')\nprint(f'upper bound linear coefficients size (batch, output_dim, *input_dims): 
{list(upper_A.size())}')\nprint(f'upper bound bias term size (batch, output_dim): {list(upper_bias.size())}')\nprint()\nprint(f'lower bound linear coefficients should be the same as upper bound linear coefficients: {(lower_A - upper_A).abs().max() < 1e-5}')\nprint()\n"
  },
  {
    "path": "examples/vision/cifar_training.py",
    "content": "import argparse\nimport multiprocessing\nimport random\nimport time\nimport logging\nimport os\n\nimport torch.optim as optim\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nfrom torch.nn import CrossEntropyLoss\n\nimport models\nfrom auto_LiRPA import BoundedModule, BoundedTensor, BoundDataParallel, CrossEntropyWrapper\nfrom auto_LiRPA.bound_ops import BoundExp\nfrom auto_LiRPA.eps_scheduler import LinearScheduler, SmoothedScheduler, AdaptiveScheduler, FixedScheduler\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import MultiAverageMeter, logger, get_spec_matrix, sync_params\n\ndef get_exp_module(bounded_module):\n    for _, node in bounded_module.named_modules():\n        # Find the Exp neuron in computational graph\n        if isinstance(node, BoundExp):\n            return node\n    return None\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--verify\", action=\"store_true\", help='verification mode, do not train')\nparser.add_argument(\"--no_loss_fusion\", action=\"store_true\", help='without loss fusion, slower training mode')\nparser.add_argument(\"--load\", type=str, default=\"\", help='Load pretrained model')\nparser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cpu\", \"cuda\"], help='use cpu or cuda')\nparser.add_argument(\"--data\", type=str, default=\"CIFAR\", choices=[\"MNIST\", \"CIFAR\"], help='dataset')\nparser.add_argument(\"--seed\", type=int, default=100, help='random seed')\nparser.add_argument(\"--eps\", type=float, default=8.8/255, help='Target training epsilon')\nparser.add_argument(\"--norm\", type=float, default='inf', help='p norm for epsilon perturbation')\nparser.add_argument(\"--bound_type\", type=str, default=\"CROWN-IBP\",\n                    choices=[\"IBP\", \"CROWN-IBP\", \"CROWN\"], help='method of bound analysis')\nparser.add_argument(\"--model\", type=str, default=\"cnn_7layer_bn\",\n                    help='model name 
(Densenet_cifar_32, resnet18, ResNeXt_cifar, MobileNet_cifar, wide_resnet_cifar_bn_wo_pooling)')\nparser.add_argument(\"--num_epochs\", type=int, default=2000, help='number of total epochs')\nparser.add_argument(\"--batch_size\", type=int, default=256, help='batch size')\nparser.add_argument(\"--lr\", type=float, default=5e-4, help='learning rate')\nparser.add_argument(\"--lr_decay_rate\", type=float, default=0.1, help='learning rate decay rate')\nparser.add_argument(\"--lr_decay_milestones\", nargs='+', type=int, default=[1400, 1700], help='learning rate decay milestones')\nparser.add_argument(\"--scheduler_name\", type=str, default=\"SmoothedScheduler\",\n                    choices=[\"LinearScheduler\", \"SmoothedScheduler\"], help='epsilon scheduler')\nparser.add_argument(\"--scheduler_opts\", type=str, default=\"start=101,length=801,mid=0.4\", help='options for epsilon scheduler')\nparser.add_argument(\"--bound_opts\", type=str, default=None, choices=[\"same-slope\", \"zero-lb\", \"one-lb\"],\n                    help='bound options for relu')\nparser.add_argument('--clip_grad_norm', type=float, default=8.0)\n\nargs = parser.parse_args()\nexp_name = args.model + '_b' + str(args.batch_size) + '_' + str(args.bound_type) + '_epoch' + str(args.num_epochs) + '_' + args.scheduler_opts + '_' + str(args.eps)[:6]\nos.makedirs('saved_models/', exist_ok=True)\nlog_file = f'saved_models/{exp_name}{\"_test\" if args.verify else \"\"}.log'\nfile_handler = logging.FileHandler(log_file)\nlogger.addHandler(file_handler)\n\ndef Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None):\n    num_class = 10\n    meter = MultiAverageMeter()\n    if train:\n        model.train()\n        eps_scheduler.train()\n        eps_scheduler.step_epoch()\n        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))\n    else:\n        model.eval()\n        
eps_scheduler.eval()\n\n    exp_module = get_exp_module(model)\n\n    def get_bound_loss(x=None, c=None):\n        if loss_fusion:\n            bound_lower, bound_upper = False, True\n        else:\n            bound_lower, bound_upper = True, False\n\n        if bound_type == 'IBP':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, IBP=True, C=c, method=None, final_node_name=final_node_name, no_replicas=True)\n        elif bound_type == 'CROWN':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, IBP=False, C=c, method='backward',\n                                          bound_lower=bound_lower, bound_upper=bound_upper)\n        elif bound_type == 'CROWN-IBP':\n            # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method='backward')  # pure IBP bound\n            # We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020).\n            factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()\n            ilb, iub = model(method_opt=\"compute_bounds\", x=x, IBP=True, C=c, method=None, final_node_name=final_node_name, no_replicas=True)\n            if factor < 1e-50:\n                lb, ub = ilb, iub\n            else:\n                clb, cub = model(method_opt=\"compute_bounds\", IBP=False, C=c, method='backward',\n                             bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name, no_replicas=True)\n                if loss_fusion:\n                    ub = cub * factor + iub * (1 - factor)\n                else:\n                    lb = clb * factor + ilb * (1 - factor)\n\n        if loss_fusion:\n            if isinstance(model, BoundDataParallel):\n                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')\n            else:\n                max_input = exp_module.max_input\n            return None, torch.mean(torch.log(ub) + max_input)\n        else:\n            # Pad a zero at the beginning for each example, and use fake label '0' for all examples.\n            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)\n            return lb, robust_ce\n\n    for i, (data, labels) in enumerate(loader):\n        start = time.time()\n        eps_scheduler.step_batch()\n        eps = eps_scheduler.get_eps()\n        # For small eps, just use natural training; no need to compute LiRPA bounds.\n        batch_method = method\n        if eps < 1e-50:\n            batch_method = \"natural\"\n        if train:\n            opt.zero_grad()\n        # Bounded input range is only needed for the Linf norm.\n        if norm == np.inf:\n            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))\n            data_min = torch.reshape((0. - loader.mean) / loader.std, (1, -1, 1, 1))\n            data_ub = torch.min(data + (eps / loader.std).view(1,-1,1,1), data_max)\n            data_lb = torch.max(data - (eps / loader.std).view(1,-1,1,1), data_min)\n        else:\n            data_ub = data_lb = data\n\n        if list(model.parameters())[0].is_cuda:\n            data, labels = data.cuda(), labels.cuda()\n            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()\n\n        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)\n        x = BoundedTensor(data, ptb)\n        if loss_fusion:\n            if batch_method == 'natural' or not train:\n                output = model(x, labels)  # , disable_multi_gpu=True\n                regular_ce = torch.mean(torch.log(output))\n            else:\n                model(x, labels)\n                regular_ce = torch.tensor(0., device=data.device)\n            meter.update('CE', regular_ce.item(), x.size(0))\n            x = (x, labels)\n            c = None\n        else:\n            # Generate specification matrix (when loss fusion is not used).\n            c = get_spec_matrix(data, labels, num_class)\n            x = (x,) if final_node_name is None else (x, labels)\n            output = model(x, final_node_name=final_node_name)\n            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up\n            meter.update('CE', regular_ce.item(), x[0].size(0))\n            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))\n\n        if batch_method == 'robust':\n            lb, robust_ce = get_bound_loss(x=x, c=c)\n            loss = robust_ce\n        elif batch_method == 'natural':\n            loss = regular_ce\n\n        if train:\n            loss.backward()\n\n            if args.clip_grad_norm:\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)\n                meter.update('grad_norm', grad_norm)\n\n            if isinstance(eps_scheduler, AdaptiveScheduler):\n                eps_scheduler.update_loss(loss.item() - regular_ce.item())\n            opt.step()\n        meter.update('Loss', loss.item(), data.size(0))\n\n        if batch_method != 'natural':\n            meter.update('Robust_CE', robust_ce.item(), data.size(0))\n            if not loss_fusion:\n                # For an example, if the lower bounds of margins are > 0 for all classes, the output is verifiably correct.\n                # If any margin is < 0, this example is counted as an error.\n                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))\n        meter.update('Time', time.time() - start)\n\n        if (i + 1) % 50 == 0 and train:\n            logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n\n    logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n    return meter\n\n\ndef main(args):\n  
  torch.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py\n    if args.data == 'MNIST':\n        model_ori = models.Models[args.model](in_ch=1, in_dim=28)\n    else:\n        model_ori = models.Models[args.model](in_ch=3, in_dim=32)\n    epoch = 0\n    if args.load:\n        checkpoint = torch.load(args.load)\n        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']\n        opt_state = None\n        try:\n            opt_state = checkpoint['optimizer']\n        except KeyError:\n            print('no opt_state found')\n        for k, v in state_dict.items():\n            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0\n        model_ori.load_state_dict(state_dict)\n        logger.info('Checkpoint loaded: {}'.format(args.load))\n\n    ## Step 2: Prepare dataset as usual\n    if args.data == 'MNIST':\n        dummy_input = torch.randn(2, 1, 28, 28)\n        train_data = datasets.MNIST(\"./data\", train=True, download=True, transform=transforms.ToTensor())\n        test_data = datasets.MNIST(\"./data\", train=False, download=True, transform=transforms.ToTensor())\n    elif args.data == 'CIFAR':\n        dummy_input = torch.randn(2, 3, 32, 32)\n        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])\n        train_data = datasets.CIFAR10(\"./data\", train=True, download=True,\n                transform=transforms.Compose([\n                    transforms.RandomHorizontalFlip(),\n                    transforms.RandomCrop(32, 4, padding_mode='edge'),\n                    transforms.ToTensor(),\n                    normalize]))\n        test_data = datasets.CIFAR10(\"./data\", train=False, download=True,\n                
transform=transforms.Compose([transforms.ToTensor(), normalize]))\n\n    train_data = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),4))\n    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size//2, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),4))\n    if args.data == 'MNIST':\n        train_data.mean = test_data.mean = torch.tensor([0.0])\n        train_data.std = test_data.std = torch.tensor([1.0])\n    elif args.data == 'CIFAR':\n        train_data.mean = test_data.mean = torch.tensor([0.4914, 0.4822, 0.4465])\n        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])\n\n    ## Step 3: wrap model with auto_LiRPA\n    # The second parameter dummy_input is for constructing the trace of the computational graph.\n    model = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts}, device=args.device)\n    final_name1 = model.final_name\n    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),\n                               bound_opts={'activation_bound_option': args.bound_opts, 'loss_fusion': True}, device=args.device)\n    # After CrossEntropyWrapper, the final name will change because of one additional input node in CrossEntropyWrapper.\n    final_name2 = model_loss._modules[final_name1].output_name[0]\n    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])\n    if args.no_loss_fusion:\n        model_loss = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts}, device=args.device)\n        final_name2 = None\n    model_loss = BoundDataParallel(model_loss)\n\n    ## Step 4: Prepare optimizer, epsilon scheduler and learning rate scheduler\n    opt = optim.Adam(model_loss.parameters(), lr=args.lr)\n    norm = float(args.norm)\n    lr_scheduler = 
optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=args.lr_decay_rate)\n    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)\n    logger.info(str(model_ori))\n\n    # skip epochs\n    if epoch > 0:\n        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)\n        eps_scheduler.set_epoch_length(epoch_length)\n        eps_scheduler.train()\n        for i in range(epoch):\n            lr_scheduler.step()\n            eps_scheduler.step_epoch(verbose=True)\n            for j in range(epoch_length):\n                eps_scheduler.step_batch()\n        logger.info('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))\n\n    if args.load:\n        if opt_state:\n            opt.load_state_dict(opt_state)\n            logger.info('resume opt_state')\n\n    ## Step 5: start training\n    if args.verify:\n        eps_scheduler = FixedScheduler(args.eps)\n        with torch.no_grad():\n            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False, final_node_name=None)\n    else:\n        timer = 0.0\n        best_err = 1e10\n        for t in range(epoch + 1, args.num_epochs+1):\n            logger.info(\"Epoch {}, learning rate {}\".format(t, lr_scheduler.get_last_lr()))\n            start_time = time.time()\n            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=not args.no_loss_fusion)\n            lr_scheduler.step()\n            epoch_time = time.time() - start_time\n            timer += epoch_time\n            logger.info('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))\n\n            logger.info(\"Evaluating...\")\n            torch.cuda.empty_cache()\n\n            state_dict = sync_params(model_ori, model_loss, loss_fusion=True)\n\n            with torch.no_grad():\n                if t > int(eps_scheduler.params['start']) + 
int(eps_scheduler.params['length']):\n                    m = Train(model_loss, t, test_data, FixedScheduler(8./255), norm, False, None, 'IBP', loss_fusion=False,\n                              final_node_name=final_name2)\n                else:\n                    m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False, final_node_name=final_name2)\n\n            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}\n            if t < int(eps_scheduler.params['start']):\n                torch.save(save_dict, 'saved_models/natural_' + exp_name)\n            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):\n                current_err = m.avg('Verified_Err')\n                if current_err < best_err:\n                    best_err = current_err\n                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_err)[:6])\n                else:\n                    torch.save(save_dict, 'saved_models/' + exp_name)\n            else:\n                torch.save(save_dict, 'saved_models/' + exp_name)\n            torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n    logger.info(args)\n    main(args)\n"
  },
  {
    "path": "examples/vision/custom_op.py",
    "content": "\"\"\" A example for custom operators.\n\nIn this example, we create a custom operator called \"PlusConstant\", which can\nbe written as \"f(x) = x + c\" for some constant \"c\" (an attribute of the operator).\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor, register_custom_op\nfrom auto_LiRPA.operators import Bound\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import Flatten\n\n\"\"\" Step 1: Define a `torch.autograd.Function` class to declare and implement the\ncomputation of the operator. \"\"\"\nclass PlusConstantOp(torch.autograd.Function):\n    @staticmethod\n    def symbolic(g, x, const):\n        \"\"\" In this function, define the arguments and attributes of the operator.\n        \"custom::PlusConstant\" is the name of the new operator, \"x\" is an argument\n        of the operator, \"const_i\" is an attribute which stands for \"c\" in the operator.\n        There can be multiple arguments and attributes. For attribute naming,\n        use a suffix such as \"_i\" to specify the data type, where \"_i\" stands for\n        integer, \"_t\" stands for tensor, \"_f\" stands for float, etc. \"\"\"\n        return g.op('custom::PlusConstant', x, const_i=const)\n\n    @staticmethod\n    def forward(ctx, x, const):\n        \"\"\" In this function, implement the computation for the operator, i.e.,\n        f(x) = x + c in this case. \"\"\"\n        return x + const\n\n\"\"\" Step 2: Define a `torch.nn.Module` class to declare a module using the defined\ncustom operator. \"\"\"\nclass PlusConstant(nn.Module):\n    def __init__(self, const=1):\n        super().__init__()\n        self.const = const\n\n    def forward(self, x):\n        \"\"\" Use `PlusConstantOp.apply` to call the defined custom operator. 
\"\"\"\n        return PlusConstantOp.apply(x, self.const)\n\n\"\"\" Step 3: Implement a Bound class to support bound computation for the new operator. \"\"\"\nclass BoundPlusConstant(Bound):\n    def __init__(self, attr, inputs, output_index, options):\n        \"\"\" `const` is an attribute and can be obtained from the dict `attr` \"\"\"\n        super().__init__(attr, inputs, output_index, options)\n        self.const = attr['const']\n\n    def forward(self, x):\n        return x + self.const\n\n    def bound_backward(self, last_lA, last_uA, x, *args, **kwargs):\n        \"\"\" Backward mode bound propagation \"\"\"\n        print('Calling bound_backward for custom::PlusConstant')\n        def _bound_oneside(last_A):\n            # If last_lA or last_uA is None, it means lower or upper bound\n            # is not required, so we simply return None.\n            if last_A is None:\n                return None, 0\n            # The function f(x) = x + c is a linear function with coefficient 1.\n            # Then A · f(x) = A · (x + c) = A · x + A · c.\n            # Thus the new A matrix is the same as the last A matrix:\n            A = last_A\n            # For bias, compute A · c and reduce the dimensions by sum:\n            bias = last_A.sum(dim=list(range(2, last_A.ndim))) * self.const\n            return A, bias\n        lA, lbias = _bound_oneside(last_lA)\n        uA, ubias = _bound_oneside(last_lA)\n        return [(lA, uA)], lbias, ubias\n\n    def interval_propagate(self, *v):\n        \"\"\" IBP computation \"\"\"\n        print('Calling interval_propagate for custom::PlusConstant')\n        # Interval bound of the input\n        h_L, h_U = v[0]\n        # Since this function is monotonic, we can get the lower bound and upper bound\n        # by applying the function on h_L and h_U respectively.\n        lower = h_L + self.const\n        upper = h_U + self.const\n        return lower, upper\n\n\"\"\" Step 4: Register the custom operator 
\"\"\"\nregister_custom_op(\"custom::PlusConstant\", BoundPlusConstant)\n\n# Use the `PlusConstant` module in model definition\nmodel = nn.Sequential(\n    Flatten(),\n    nn.Linear(28 * 28, 256),\n    PlusConstant(const=1),\n    nn.Linear(256, 10),\n)\nprint(\"Model:\", model)\n\ntest_data = torchvision.datasets.MNIST(\"./data\", train=False, download=True, transform=torchvision.transforms.ToTensor())\nN = 1\nn_classes = 10\nimage = test_data.data[:N].view(N,1,28,28)\ntrue_label = test_data.targets[:N]\nimage = image.to(torch.float32) / 255.0\nif torch.cuda.is_available():\n    image = image.cuda()\n    model = model.cuda()\n\nlirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)\n\neps = 0.3\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, eps = eps)\nimage = BoundedTensor(image, ptb)\npred = lirpa_model(image)\nlabel = torch.argmax(pred, dim=1).cpu().detach().numpy()\n\nfor method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)']:\n    print(\"Bounding method:\", method)\n    lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0])\n    for i in range(N):\n        print(\"Image {} top-1 prediction {} ground-truth {}\".format(i, label[i], true_label[i]))\n        for j in range(n_classes):\n            indicator = '(ground-truth)' if j == true_label[i] else ''\n            print(\"f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}\".format(\n                j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))\n    print()\n\n"
  },
  {
    "path": "examples/vision/data/.gitignore",
    "content": "MNIST\ncifar*"
  },
  {
    "path": "examples/vision/data/ImageNet64/imagenet_data_loader.py",
    "content": "import os\n\nimport numpy as np\nfrom PIL import Image\n\n\nclass DatasetDownsampledImageNet():\n    def __init__(self):\n        # self.data_path = data_path\n        os.mkdir('train')\n        os.mkdir('test')\n        for i in range(1000):\n            os.mkdir('train/' + str(i))\n            os.mkdir('test/' + str(i))\n            print(i)\n        self.load_data('raw_data/Imagenet64_train_npz', count=0, fname='train/')\n        self.load_data('raw_data/Imagenet64_val_npz', count=1e8, fname='test/')\n\n    def load_data(self, data_path, img_size=64, count=0., fname=''):\n        files = os.listdir(data_path)\n        img_size2 = img_size * img_size\n\n        # count = 0  # 1e8  # test data start with 1\n        for file in files:\n            f = np.load(data_path + '/' + file)\n            x = np.array(f['data'])\n            y = np.array(f['labels']) - 1\n            x = np.dstack((x[:, :img_size2], x[:, img_size2:2 * img_size2], x[:, 2 * img_size2:]))\n            x = x.reshape((x.shape[0], img_size, img_size, 3))\n\n            for i, img in enumerate(x):\n                img = Image.fromarray(img.reshape(img_size, img_size, 3))\n                name = str(int(count)).zfill(9)\n                label = str(y[i])\n                print(count, fname + label + '/' + name + '_label_' + label.zfill(4) + '.png')\n                # img.show()\n                img.save(fname + label + '/' + name + '_label_' + label.zfill(4) + '.png')\n\n                count += 1\n\n\nif __name__ == \"__main__\":\n    DatasetDownsampledImageNet()\n"
  },
  {
    "path": "examples/vision/data/tinyImageNet/.gitignore",
    "content": "tiny-imagenet-200*\n"
  },
  {
    "path": "examples/vision/data/tinyImageNet/tinyimagenet_download.sh",
    "content": "#!/bin/bash\n\n# download and unzip dataset\nwget http://cs231n.stanford.edu/tiny-imagenet-200.zip\nunzip tiny-imagenet-200.zip\n\ncurrent=\"$(pwd)/tiny-imagenet-200\"\n\n# training data\necho \"preparing training data...\"\ncd $current/train\nfor DIR in $(ls); do\n   cd $DIR\n   rm *.txt\n   mv images/* .\n   rm -r images\n   cd ..\ndone\n\n# validation data\necho \"preparing validation data...\"\ncd $current/val\nannotate_file=\"val_annotations.txt\"\nlength=$(cat $annotate_file | wc -l)\nfor i in $(seq 1 $length); do\n    # fetch i th line\n    line=$(sed -n ${i}p $annotate_file)\n    # get file name and directory name\n    file=$(echo $line | cut -f1 -d\" \" )\n    directory=$(echo $line | cut -f2 -d\" \")\n    mkdir -p $directory\n    mv images/$file $directory\ndone\nrm -r images\necho \"done\"\n"
  },
  {
    "path": "examples/vision/datasets.py",
"content": "import multiprocessing\nimport torch\nfrom torch.utils import data\nfrom functools import partial\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\n\n# compute image statistics (by Andreas https://discuss.pytorch.org/t/computing-the-mean-and-std-of-dataset/34949/4)\ndef get_stats(loader):\n    mean = 0.0\n    for images, _ in loader:\n        batch_samples = images.size(0) \n        reshaped_img = images.view(batch_samples, images.size(1), -1)\n        mean += reshaped_img.mean(2).sum(0)\n    w = images.size(2)\n    h = images.size(3)\n    mean = mean / len(loader.dataset)\n\n    var = 0.0\n    for images, _ in loader:\n        batch_samples = images.size(0)\n        images = images.view(batch_samples, images.size(1), -1)\n        var += ((images - mean.unsqueeze(1))**2).sum([0,2])\n    std = torch.sqrt(var / (len(loader.dataset)*w*h))\n    return mean, std\n\n# load MNIST or Fashion-MNIST\ndef mnist_loaders(dataset, batch_size, shuffle_train = True, shuffle_test = False, ratio=None, test_batch_size=None):\n    # Use the AWS mirror and avoid the yann.lecun.com mirror.\n    dataset.mirrors = [\n        'https://ossci-datasets.s3.amazonaws.com/mnist/',\n    ]\n\n    mnist_train = dataset(\"./data\", train=True, download=True, transform=transforms.ToTensor())\n    mnist_test = dataset(\"./data\", train=False, download=True, transform=transforms.ToTensor())\n\n    if ratio is not None:\n        # only sample in training data\n        num_of_each_class_train = int(len(mnist_train) // 10 * ratio)\n        # num_of_each_class_test = int(len(mnist_test)//10*ratio)\n\n        class_idx_train = [(mnist_train.targets == _).nonzero().numpy().squeeze() for _ in range(10)]\n        # class_idx_test = [(mnist_test.targets==_).nonzero().numpy().squeeze() for _ in range(10)]\n\n        for i in range(len(class_idx_train)):\n            class_idx_train[i] = class_idx_train[i][:num_of_each_class_train]\n            # 
class_idx_test[i] = class_idx_test[i][:num_of_each_class_test]\n\n        mnist_train = data.Subset(mnist_train, [y for z in class_idx_train for y in z])\n\n    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=shuffle_train, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),2))\n    if test_batch_size:\n        batch_size = test_batch_size\n    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=shuffle_test, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),2))\n    std = [1.0]\n    train_loader.std = std\n    test_loader.std = std\n    return train_loader, test_loader\n\n\ndef cifar_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None): \n    if normalize_input:\n        std = [0.2023, 0.1994, 0.2010]\n        normalize = transforms.Normalize(mean = [0.4914, 0.4822, 0.4465],\n                                          std = std)\n    else:\n        std = [1.0, 1.0, 1.0]\n        normalize = transforms.Normalize(mean=[0, 0, 0],\n                                         std=std)\n    if train_random_transform:\n        if normalize_input:\n            train = datasets.CIFAR10('./data', train=True, download=True, \n                transform=transforms.Compose([\n                    transforms.RandomHorizontalFlip(),\n                    transforms.RandomCrop(32, 4),\n                    transforms.ToTensor(),\n                    normalize,\n                ]))\n        else:\n            train = datasets.CIFAR10('./data', train=True, download=True, \n                transform=transforms.Compose([\n                    transforms.RandomHorizontalFlip(),\n                    transforms.RandomCrop(32, 4),\n                    transforms.ToTensor(),\n                ]))\n    else:\n        train = datasets.CIFAR10('./data', train=True, download=True, \n            
transform=transforms.Compose([transforms.ToTensor(),normalize]))\n    test = datasets.CIFAR10('./data', train=False, \n        transform=transforms.Compose([transforms.ToTensor(), normalize]))\n    \n    if num_examples:\n        indices = list(range(num_examples))\n        train = data.Subset(train, indices)\n        test = data.Subset(test, indices)\n\n    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,\n        shuffle=shuffle_train, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))\n    if test_batch_size:\n        batch_size = test_batch_size\n    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),\n        shuffle=shuffle_test, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))\n    train_loader.std = std\n    test_loader.std = std\n    return train_loader, test_loader\n\ndef svhn_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None): \n    if normalize_input:\n        mean = [0.43768206, 0.44376972, 0.47280434] \n        std = [0.19803014, 0.20101564, 0.19703615]\n        normalize = transforms.Normalize(mean = mean,\n                                          std = std)\n    else:\n        std = [1.0, 1.0, 1.0]\n        normalize = transforms.Normalize(mean=[0, 0, 0],\n                                         std=std)\n    if train_random_transform:\n        if normalize_input:\n            train = datasets.SVHN('./data', split='train', download=True, \n                transform=transforms.Compose([\n                    transforms.RandomCrop(32, 4),\n                    transforms.ToTensor(),\n                    normalize,\n                ]))\n        else:\n            train = datasets.SVHN('./data', split='train', download=True, \n                transform=transforms.Compose([\n                    transforms.RandomCrop(32, 4),\n                    
transforms.ToTensor(),\n                ]))\n    else:\n        train = datasets.SVHN('./data', split='train', download=True, \n            transform=transforms.Compose([transforms.ToTensor(),normalize]))\n    test = datasets.SVHN('./data', split='test', download=True,\n        transform=transforms.Compose([transforms.ToTensor(), normalize]))\n    \n    if num_examples:\n        indices = list(range(num_examples))\n        train = data.Subset(train, indices)\n        test = data.Subset(test, indices)\n\n    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,\n        shuffle=shuffle_train, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))\n    if test_batch_size:\n        batch_size = test_batch_size\n    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),\n        shuffle=shuffle_test, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),6))\n    train_loader.std = std\n    test_loader.std = std\n    mean, std = get_stats(train_loader)\n    print('dataset mean = ', mean.numpy(), 'std = ', std.numpy())\n    return train_loader, test_loader\n\ndef load_data(data, batch_size):\n    if data == 'MNIST':\n        dummy_input = torch.randn(1, 1, 28, 28)\n        train_data = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())\n        test_data = datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())\n    elif data == 'CIFAR':\n        dummy_input = torch.randn(1, 3, 32, 32)\n        normalize = transforms.Normalize(mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010])\n        train_data = datasets.CIFAR10('./data', train=True, download=True,\n                transform=transforms.Compose([\n                    transforms.RandomHorizontalFlip(),\n                    transforms.RandomCrop(32, 4, padding_mode='edge'),\n                    transforms.ToTensor(),\n                    normalize]))\n        test_data = 
datasets.CIFAR10('./data', train=False, download=True, \n                transform=transforms.Compose([transforms.ToTensor(), normalize]))\n\n    train_data = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),4))\n    test_data = torch.utils.data.DataLoader(test_data, batch_size=batch_size, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),4))\n    if data == 'MNIST':\n        train_data.mean = test_data.mean = torch.tensor([0.0])\n        train_data.std = test_data.std = torch.tensor([1.0])\n    elif data == 'CIFAR':\n        train_data.mean = test_data.mean = torch.tensor([0.4914, 0.4822, 0.4465])\n        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])\n\n    return dummy_input, train_data, test_data\n\n# when new loaders is added, they must be registered here\nloaders = {\n        \"MNIST\": partial(mnist_loaders, datasets.MNIST),\n        \"FashionMNIST\": partial(mnist_loaders, datasets.FashionMNIST),\n        \"CIFAR\": cifar_loaders,\n        \"svhn\": svhn_loaders,\n        }\n\n"
  },
  {
    "path": "examples/vision/efficient_convolution.py",
"content": "\"\"\"\nDemonstration of efficient convolutional network implementation in auto_LiRPA.\n\nThe auto_LiRPA library supports an efficient algorithm for computing bounds for\nconvolutional networks. The \"patches\" mode implementation makes full backward\nbounds (CROWN) for convolutional layers significantly faster by using more\nefficient GPU operators.  The convolution mode can be set by the \"conv_mode\"\nkey in the bound_opts parameter when constructing your BoundedModule object and\nthe new \"patches\" mode is enabled by default.  In this example we show the\ndifferences between \"patches\" mode and the old \"matrix\" mode in memory\nconsumption, on a relatively large ResNet network.\n\n\"\"\"\n\nimport sys\nimport torch\nimport random\nimport numpy as np\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nimport models\n\ndevice = 'cpu'\nif torch.cuda.is_available():\n    device = 'cuda'\nconv_mode = sys.argv[1] if len(sys.argv) > 1 else 'patches' # conv_mode can be set as 'matrix' or 'patches'\n\nseed = 1234\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\nrandom.seed(seed)\nnp.random.seed(seed)\n\n## Step 1: Define the model\n# model_ori = models.model_resnet(width=1, mult=4)\n# model_ori = models.ResNet18(in_planes=2)\n# model_ori = models.vnncomp_resnet2b()\nmodel_ori = models.vnncomp_resnet4b()\nmodel_ori = model_ori.to(device=device)\n\n## Step 2: Prepare dataset as usual.\n# test_data = torchvision.datasets.MNIST(\"./data\", train=False, download=True, transform=torchvision.transforms.ToTensor())\nnormalize = torchvision.transforms.Normalize(mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010])\ntest_data = torchvision.datasets.CIFAR10(\"./data\", train=False, download=True, \n                transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize]))\n# For illustration we only use 1 image from dataset\nN = 1\nn_classes = 10\nimage 
= torch.Tensor(test_data.data[:N]).permute(0, 3, 1, 2)  # CIFAR data is stored as NHWC; permute to NCHW\n# Convert to float between 0. and 1.\nimage = image.to(torch.float32) / 255.0\nif device == 'cuda':\n    image = image.cuda()\n\n## Step 3: wrap model with auto_LiRPA.\n# The second parameter is for constructing the trace of the computational graph, and its content is not important.\n# The new \"patches\" conv_mode provides a more efficient implementation for convolutional neural networks.\nmodel = BoundedModule(model_ori, image, bound_opts={\"conv_mode\": conv_mode}, device=device) \n\n## Step 4: Compute bounds using LiRPA given a perturbation\neps = 0.1\nnorm = 2\nptb = PerturbationLpNorm(norm = norm, eps = eps)\nimage = BoundedTensor(image, ptb)\n# Get model prediction as usual\npred = model(image)\n\n# Compute bounds\nif device == 'cuda':\n    torch.cuda.empty_cache()\nprint('Using {} mode to compute convolution.'.format(conv_mode))\nlb, ub = model.compute_bounds(IBP=False, C=None, method='backward')\n\n## Step 5: Final output\n# pred = pred.detach().cpu().numpy()\nlb = lb.detach().cpu().numpy()\nub = ub.detach().cpu().numpy()\nfor i in range(N):\n    # print(\"Image {} top-1 prediction {}\".format(i, label[i]))\n    for j in range(n_classes):\n        print(\"f_{j}(x_0): {l:8.5f} <= f_{j}(x_0+delta) <= {u:8.5f}\".format(j=j, l=lb[i][j], u=ub[i][j]))\n    print()\n\n# Print the GPU memory usage\nprint('Memory usage in \"{}\" mode:'.format(conv_mode))\nif device == 'cuda':\n    print(torch.cuda.memory_summary())\n"
  },
  {
    "path": "examples/vision/imagenet_training.py",
    "content": "import random\nimport time\nimport argparse\nimport multiprocessing\nimport logging\nimport torch.optim as optim\nfrom torch.nn import CrossEntropyLoss\nfrom auto_LiRPA import BoundedModule, BoundedTensor, BoundDataParallel, CrossEntropyWrapper\nfrom auto_LiRPA.bound_ops import BoundExp\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import MultiAverageMeter, logger, get_spec_matrix, sync_params\nimport models\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nfrom auto_LiRPA.eps_scheduler import *\n\ndef get_exp_module(bounded_module):\n    for _, node in bounded_module.named_modules():\n        # Find the Exp neuron in computational graph\n        if isinstance(node, BoundExp):\n            return node\n    return None\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--verify\", action=\"store_true\", help='verification mode, do not train')\nparser.add_argument(\"--load\", type=str, default=\"\", help='Load pretrained model')\nparser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cpu\", \"cuda\"], help='use cpu or cuda')\nparser.add_argument(\"--data_dir\", type=str, default=\"data/ImageNet64\",\n                    help='dir of dataset')\nparser.add_argument(\"--seed\", type=int, default=100, help='random seed')\nparser.add_argument(\"--eps\", type=float, default=1. 
/ 255, help='Target training epsilon')\nparser.add_argument(\"--norm\", type=float, default='inf', help='p norm for epsilon perturbation')\nparser.add_argument(\"--bound_type\", type=str, default=\"CROWN-IBP\",\n                    choices=[\"IBP\", \"CROWN-IBP\", \"CROWN\"], help='method of bound analysis')\nparser.add_argument(\"--model\", type=str, default=\"wide_resnet_imagenet64_1000class\",\n                    help='model name (mlp_3layer, cnn_4layer, cnn_6layer, cnn_7layer, resnet)')\nparser.add_argument(\"--num_epochs\", type=int, default=240, help='number of total epochs')\nparser.add_argument(\"--batch_size\", type=int, default=125, help='batch size')\nparser.add_argument(\"--lr\", type=float, default=1e-3, help='learning rate')\nparser.add_argument(\"--lr_decay_milestones\", nargs='+', type=int, default=[200, 220], help='learning rate decay milestones')\nparser.add_argument(\"--scheduler_name\", type=str, default=\"SmoothedScheduler\",\n                    choices=[\"LinearScheduler\", \"AdaptiveScheduler\", \"SmoothedScheduler\"], help='epsilon scheduler')\nparser.add_argument(\"--scheduler_opts\", type=str, default=\"start=100,length=80\", help='options for epsilon scheduler')\nparser.add_argument(\"--bound_opts\", type=str, default=None, choices=[\"same-slope\", \"zero-lb\", \"one-lb\"],\n                    help='bound options')\nparser.add_argument('--clip_grad_norm', type=float, default=8.0)\nparser.add_argument('--in_planes', type=int, default=16)\nparser.add_argument('--widen_factor', type=int, default=10)\n\nargs = parser.parse_args()\n\nexp_name = args.model + '_b' + str(args.batch_size) + '_' + str(args.bound_type) + '_epoch' + str(\n    args.num_epochs) + '_' + args.scheduler_opts + '_' + str(args.eps)[:6]\nlog_file = f'saved_models/{exp_name}{\"_test\" if args.verify else \"\"}.log'\nfile_handler = logging.FileHandler(log_file)\nlogger.addHandler(file_handler)\n\ndef Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, 
method='robust', loss_fusion=True,\n          final_node_name=None):\n    num_class = 1000\n    meter = MultiAverageMeter()\n    if train:\n        model.train()\n        eps_scheduler.train()\n        eps_scheduler.step_epoch()\n        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))\n    else:\n        model.eval()\n        eps_scheduler.eval()\n\n    exp_module = get_exp_module(model)\n\n    def get_bound_loss(x=None, c=None):\n        if loss_fusion:\n            bound_lower, bound_upper = False, True\n        else:\n            bound_lower, bound_upper = True, False\n\n        if bound_type == 'IBP':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, IBP=True, C=c, method=None,\n                           final_node_name=final_node_name, no_replicas=True)\n        elif bound_type == 'CROWN':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, IBP=False, C=c, method='backward',\n                           bound_lower=bound_lower, bound_upper=bound_upper)\n        elif bound_type == 'CROWN-IBP':\n            # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method='backward')  # pure IBP bound\n            # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)\n            factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()\n            ilb, iub = model(method_opt=\"compute_bounds\", x=x, IBP=True, C=c, method=None,\n                             final_node_name=final_node_name, no_replicas=True)\n            if factor < 1e-50:\n                lb, ub = ilb, iub\n            else:\n                clb, cub = model(method_opt=\"compute_bounds\", IBP=False, C=c, method='backward',\n                                 bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name,\n                                 no_replicas=True)\n                if loss_fusion:\n        
            ub = cub * factor + iub * (1 - factor)\n                else:\n                    lb = clb * factor + ilb * (1 - factor)\n\n        if loss_fusion:\n            if isinstance(model, BoundDataParallel):\n                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')\n            else:\n                max_input = exp_module.max_input\n            return None, torch.mean(torch.log(ub) + max_input)\n        else:\n            # Pad zero at the beginning for each example, and use fake label '0' for all examples\n            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)\n            return lb, robust_ce\n\n    for i, (data, labels) in enumerate(loader):\n        start = time.time()\n        eps_scheduler.step_batch()\n        eps = eps_scheduler.get_eps()\n        # For small eps just use natural training, no need to compute LiRPA bounds\n        batch_method = method\n        if eps < 1e-50:\n            batch_method = \"natural\"\n        if train:\n            opt.zero_grad()\n        # bound input for Linf norm used only\n        if norm == np.inf:\n            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))\n            data_min = torch.reshape((0. 
- loader.mean) / loader.std, (1, -1, 1, 1))\n            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1), data_max)\n            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1), data_min)\n        else:\n            data_ub = data_lb = data\n\n        if list(model.parameters())[0].is_cuda:\n            data, labels = data.cuda(), labels.cuda()\n            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()\n\n        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)\n        x = BoundedTensor(data, ptb)\n        if loss_fusion:\n            if batch_method == 'natural' or not train:\n                output = model(x, labels)\n                regular_ce = torch.mean(torch.log(output))\n            else:\n                model(x, labels)\n                regular_ce = torch.tensor(0., device=data.device)\n            meter.update('CE', regular_ce.item(), x.size(0))\n            x = (x, labels)\n            c = None\n        else:\n            c = get_spec_matrix(data, labels, num_class)\n            x = (x, labels)\n            output = model(x, final_node_name=final_node_name)\n            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up\n            meter.update('CE', regular_ce.item(), x[0].size(0))\n            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))\n\n        if batch_method == 'robust':\n            # print(data.sum())\n            lb, robust_ce = get_bound_loss(x=x, c=c)\n            loss = robust_ce\n        elif batch_method == 'natural':\n            loss = regular_ce\n\n        if train:\n            loss.backward()\n\n            if args.clip_grad_norm:\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)\n                meter.update('grad_norm', grad_norm)\n\n            if isinstance(eps_scheduler, AdaptiveScheduler):\n               
 eps_scheduler.update_loss(loss.item() - regular_ce.item())\n            opt.step()\n        meter.update('Loss', loss.item(), data.size(0))\n\n        if batch_method != 'natural':\n            meter.update('Robust_CE', robust_ce.item(), data.size(0))\n            if not loss_fusion:\n                # For an example, if the lower bounds of margins are > 0 for all classes, the output is verifiably correct.\n                # If any margin is < 0 this example is counted as an error\n                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))\n        meter.update('Time', time.time() - start)\n\n        if (i + 1) % 500 == 0 and train:\n            logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n\n    logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n    return meter\n\n\ndef main(args):\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    ## Step 1: Initialize the original model as usual, see model details in models/example_feedforward.py and models/example_resnet.py\n    model_ori = models.Models[args.model](in_planes=args.in_planes, widen_factor=args.widen_factor)\n    epoch = 0\n    if args.load:\n        checkpoint = torch.load(args.load)\n        epoch, state_dict, opt_state = checkpoint['epoch'], checkpoint['state_dict'], checkpoint.get('optimizer')\n        for k, v in state_dict.items():\n            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0\n        model_ori.load_state_dict(state_dict)\n        logger.info('Checkpoint loaded: {}'.format(args.load))\n\n    ## Step 2: Prepare dataset as usual\n    dummy_input = torch.randn(2, 3, 56, 56)\n    normalize = transforms.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2153, 0.2111, 0.2121])\n    train_data = datasets.ImageFolder(args.data_dir + '/train',\n                                 
     transform=transforms.Compose([\n                                          transforms.RandomHorizontalFlip(),\n                                          transforms.RandomCrop(56, padding_mode='edge'),\n                                          transforms.ToTensor(),\n                                          normalize,\n                                      ]))\n    test_data = datasets.ImageFolder(args.data_dir + '/test',\n                                     transform=transforms.Compose([\n                                         # transforms.RandomResizedCrop(64, scale=(0.875, 0.875), ratio=(1., 1.)),\n                                         transforms.CenterCrop(56),\n                                         transforms.ToTensor(),\n                                         normalize]))\n\n    train_data = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True,\n                                             num_workers=min(multiprocessing.cpu_count(), 4))\n    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size // 4, pin_memory=True,\n                                            num_workers=min(multiprocessing.cpu_count(), 4))\n    train_data.mean = test_data.mean = torch.tensor([0.4815, 0.4578, 0.4082])\n    train_data.std = test_data.std = torch.tensor([0.2153, 0.2111, 0.2121])\n\n    ## Step 3: wrap model with auto_LiRPA\n    # The second parameter dummy_input is for constructing the trace of the computational graph.\n    model = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts}, device=args.device)\n    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),\n                               bound_opts= { 'activation_bound_option': args.bound_opts, 'loss_fusion': True }, device=args.device)\n    model_loss = BoundDataParallel(model_loss)\n\n    ## Step 4 prepare optimizer, epsilon scheduler and 
learning rate scheduler\n    opt = optim.Adam(model_loss.parameters(), lr=args.lr)\n    norm = float(args.norm)\n    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)\n    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)\n    logger.info(str(model_ori))\n\n    if args.load:\n        if opt_state:\n            opt.load_state_dict(opt_state)\n            logger.info('resume opt_state')\n\n    # skip epochs\n    if epoch > 0:\n        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)\n        eps_scheduler.set_epoch_length(epoch_length)\n        eps_scheduler.train()\n        for i in range(epoch):\n            lr_scheduler.step()\n            eps_scheduler.step_epoch(verbose=True)\n            for j in range(epoch_length):\n                eps_scheduler.step_batch()\n        logger.info('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))\n\n    ## Step 5: start training\n    if args.verify:\n        eps_scheduler = FixedScheduler(args.eps)\n        with torch.no_grad():\n            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False, final_node_name=None)\n    else:\n        timer = 0.0\n        best_err = 1e10\n        for t in range(epoch + 1, args.num_epochs + 1):\n            logger.info(\"Epoch {}, learning rate {}\".format(t, lr_scheduler.get_last_lr()))\n            start_time = time.time()\n            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)\n            lr_scheduler.step()\n            epoch_time = time.time() - start_time\n            timer += epoch_time\n            logger.info('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))\n\n            logger.info(\"Evaluating...\")\n            torch.cuda.empty_cache()\n\n            state_dict = sync_params(model_ori, model_loss, loss_fusion=True)\n\n            with 
torch.no_grad():\n                if int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']) > t >= int(\n                        eps_scheduler.params['start']):\n                    m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, args.bound_type, loss_fusion=True)\n                else:\n                    model_ori.load_state_dict(state_dict)\n                    model = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts}, device=args.device)\n                    model = BoundDataParallel(model)\n                    m = Train(model, t, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False)\n                    del model\n\n            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}\n            if t < int(eps_scheduler.params['start']):\n                torch.save(save_dict, 'saved_models/natural_' + exp_name)\n            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):\n                current_err = m.avg('Verified_Err')\n                if current_err < best_err:\n                    best_err = current_err\n                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_err)[:6])\n            else:\n                torch.save(save_dict, 'saved_models/' + exp_name)\n            torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n    logger.info(args)\n    main(args)\n"
  },
  {
    "path": "examples/vision/jacobian.py",
    "content": "\"\"\"Examples of computing Jacobian bounds.\n\nWe show examples of:\n- Computing Jacobian bounds\n- Computing Linf local Lipschitz constants\n- Computing JVP bounds\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import Flatten\nfrom auto_LiRPA.jacobian import JacobianOP, GradNorm\n\n\ndef build_model(in_ch=3, in_dim=32):\n    model = nn.Sequential(\n        Flatten(),\n        nn.Linear(in_ch*in_dim**2, 100),\n        nn.ReLU(),\n        nn.Linear(100, 200),\n        nn.ReLU(),\n        nn.Linear(200, 10),\n    )\n    return model\n\n\ndef example_jacobian(model_ori, x0, bound_opts, device):\n    \"\"\"Example: computing Jacobian bounds.\"\"\"\n\n    class JacobianWrapper(nn.Module):\n        def __init__(self, model):\n            super().__init__()\n            self.model = model\n\n        def forward(self, x):\n            y = self.model(x)\n            return JacobianOP.apply(y, x)\n\n    model = BoundedModule(JacobianWrapper(model_ori), x0, bound_opts=bound_opts, device=device)\n\n    def func(x0):\n        return model_ori(x0.requires_grad_(True))\n    ret_ori = torch.autograd.functional.jacobian(func, x0).squeeze(2)\n    ret_new = model(x0)\n    assert torch.allclose(ret_ori, ret_new)\n\n    ret = []\n    for eps in [0, 1./255, 4./255]:\n        x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps))\n        lower, upper = model.compute_jacobian_bounds(x)\n        print(f'Gap between upper and lower Jacobian bound for eps={eps:.5f}',\n            (upper - lower).max())\n        if eps == 0:\n            assert torch.allclose(\n                ret_new.view(-1),\n                lower.sum(dim=0, keepdim=True).view(-1))\n            assert torch.allclose(\n                ret_new.view(-1),\n                upper.sum(dim=0, keepdim=True).view(-1))\n        ret.append((lower.detach(), 
upper.detach()))\n\n    return ret\n\n\ndef example_local_lipschitz(model_ori, x0, bound_opts, device):\n    \"\"\"Example: computing Linf local Lipschitz constant.\"\"\"\n\n    class LocalLipschitzWrapper(nn.Module):\n        def __init__(self, model):\n            super().__init__()\n            self.model = model\n            self.grad_norm = GradNorm(norm=1)\n\n        def forward(self, x, mask):\n            y = self.model(x)\n            y_selected = y.matmul(mask)\n            jacobian = JacobianOP.apply(y_selected, x)\n            lipschitz = self.grad_norm(jacobian)\n            return lipschitz\n\n    mask = torch.zeros(10, 1, device=device)\n    mask[1, 0] = 1\n    model = BoundedModule(LocalLipschitzWrapper(model_ori), (BoundedTensor(x0), mask),\n                          bound_opts=bound_opts, device=device)\n\n    y = model_ori(x0.requires_grad_(True))\n    ret_ori = torch.autograd.grad(y[:, 1].sum(), x0)[0].abs().flatten(1).sum(dim=-1).view(-1)\n    ret_new = model(x0, mask).view(-1)\n    assert torch.allclose(ret_ori, ret_new)\n\n    ret = []\n    for eps in [0, 1./255, 4./255]:\n        x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps))\n        lip = []\n        for i in range(mask.shape[0]):\n            mask.zero_()\n            mask[i, 0] = 1\n            ub = model.compute_jacobian_bounds((x, mask), bound_lower=False)[1]\n            lip.append(ub)\n        lip = torch.concat(lip).max()\n        print(f'Linf local Lipschitz constant for eps={eps:.5f}: {lip.item()}')\n        ret.append(lip.detach())\n\n    return ret\n\n\ndef example_jvp(model_ori, x0, bound_opts, device):\n    \"\"\"Example: computing Jacobian-Vector Product.\"\"\"\n\n    class JVPWrapper(nn.Module):\n        def __init__(self, model):\n            super().__init__()\n            self.model = model\n            self.grad_norm = GradNorm(norm=1)\n\n        def forward(self, x, v):\n            y = self.model(x)\n            jacobian = JacobianOP.apply(y, 
x).flatten(2)\n            jvp = (jacobian * v.flatten(1).unsqueeze(1)).sum(dim=-1)\n            return jvp\n\n    vector = torch.rand_like(x0)\n    model = BoundedModule(JVPWrapper(model_ori), (BoundedTensor(x0), vector),\n                          bound_opts=bound_opts, device=device)\n\n    def func(x0):\n        return model_ori(x0.requires_grad_(True))\n    ret_ori = torch.autograd.functional.jvp(func, x0, vector)[-1].view(-1)\n    ret_new = model(x0, vector)\n    assert torch.allclose(ret_ori, ret_new)\n\n    ret = []\n    for eps in [0, 1./255, 4./255]:\n        x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps))\n        lb, ub = model.compute_jacobian_bounds((x, vector))\n        print(f'JVP lower bound for eps={eps:.5f}: {lb}')\n        print(f'JVP upper bound for eps={eps:.5f}: {ub}')\n        ret.append((lb, ub))\n\n    return ret\n\n\ndef compute_jacobians(model_ori, x0, bound_opts=None, device='cpu'):\n    results = [[] for _ in range(3)]\n\n    model_ori = model_ori.to(device)\n    x0 = x0.to(device)\n    print('Model:', model_ori)\n\n    results[0] = example_jacobian(model_ori, x0, bound_opts, device)\n    results[1] = example_local_lipschitz(model_ori, x0, bound_opts, device)\n    results[2] = example_jvp(model_ori, x0, bound_opts, device)\n\n    return results\n\n\nif __name__ == '__main__':\n    torch.manual_seed(0)\n\n    # Create a small model with randomly initialized parameters.\n    model_ori = build_model(in_dim=8)\n    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n    x0 = torch.randn(1, 3, 8, 8, device=device)\n\n    compute_jacobians(model_ori, x0, device=device)\n"
  },
  {
    "path": "examples/vision/models/__init__.py",
    "content": "from models.resnet import model_resnet\nfrom models.feedforward import *\nfrom models.resnext import *\nfrom models.resnext_imagenet64 import *\nfrom models.densenet import *\nfrom models.mobilenet import *\nfrom models.densenet_no_bn import *\nfrom models.densenet_imagenet import *\nfrom models.wide_resnet_imagenet64 import *\nfrom models.wide_resnet_cifar import *\nfrom models.resnet18 import *\nfrom models.vnncomp_resnet import resnet2b as vnncomp_resnet2b, resnet4b as vnncomp_resnet4b\n\n\nModels = {\n    'mlp_2layer': mlp_2layer,\n    'mlp_3layer': mlp_3layer,\n    'mlp_3layer_weight_perturb': mlp_3layer_weight_perturb,\n    'mlp_5layer': mlp_5layer,\n    'cnn_4layer': cnn_4layer,\n    'cnn_6layer': cnn_6layer,\n    'cnn_7layer': cnn_7layer,\n    'cnn_7layer_bn': cnn_7layer_bn,\n    'cnn_7layer_bn_imagenet': cnn_7layer_bn_imagenet,\n    'resnet': model_resnet,\n    'resnet18': ResNet18,\n    'ResNeXt_cifar': ResNeXt_cifar,\n    'ResNeXt_imagenet64': ResNeXt_imagenet64,\n    'Densenet_cifar_32': Densenet_cifar_32,\n    'Densenet_cifar_wobn': Densenet_cifar_wobn,\n    'Densenet_imagenet': Densenet_imagenet,\n    'MobileNet_cifar': MobileNetV2,\n    'wide_resnet_cifar': wide_resnet_cifar,\n    'wide_resnet_cifar_bn': wide_resnet_cifar_bn,\n    'wide_resnet_cifar_bn_wo_pooling': wide_resnet_cifar_bn_wo_pooling,\n    'wide_resnet_cifar_bn_wo_pooling_dropout': wide_resnet_cifar_bn_wo_pooling_dropout,\n    'wide_resnet_imagenet64': wide_resnet_imagenet64,\n    'wide_resnet_imagenet64_1000class': wide_resnet_imagenet64_1000class,\n    'vnncomp_resnet2b': vnncomp_resnet2b,\n    'vnncomp_resnet4b': vnncomp_resnet4b,\n}\n"
  },
  {
    "path": "examples/vision/models/densenet.py",
    "content": "'''DenseNet in PyTorch.\nhttps://github.com/kuangliu/pytorch-cifar\n'''\n\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Bottleneck(nn.Module):\n    def __init__(self, in_planes, growth_rate):\n        super(Bottleneck, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=True)\n        self.bn2 = nn.BatchNorm2d(4*growth_rate)\n        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=True)\n\n    def forward(self, x):\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = self.conv2(F.relu(self.bn2(out)))\n        # out = self.conv1(F.relu(x))\n        # out = self.conv2(F.relu(out))\n        out = torch.cat([out,x], 1)\n        return out\n\n\nclass Transition(nn.Module):\n    def __init__(self, in_planes, out_planes):\n        super(Transition, self).__init__()\n        self.bn = nn.BatchNorm2d(in_planes)\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=True)\n\n    def forward(self, x):\n        out = self.conv(F.relu(self.bn(x)))\n        out = F.avg_pool2d(out, 2)\n        return out\n\n\nclass DenseNet(nn.Module):\n    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):\n        super(DenseNet, self).__init__()\n        self.growth_rate = growth_rate\n\n        num_planes = 2*growth_rate\n        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=True)\n\n        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])\n        num_planes += nblocks[0]*growth_rate\n        out_planes = int(math.floor(num_planes*reduction))\n        self.trans1 = Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])\n        num_planes += nblocks[1]*growth_rate\n        out_planes = 
int(math.floor(num_planes*reduction))\n        self.trans2 = Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])\n        num_planes += nblocks[2]*growth_rate\n        # out_planes = int(math.floor(num_planes*reduction))\n        # self.trans3 = Transition(num_planes, out_planes)\n        # num_planes = out_planes\n\n        # self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])\n        # num_planes += nblocks[3]*growth_rate\n\n        self.bn = nn.BatchNorm2d(num_planes)\n        self.linear1 = nn.Linear(14336, 512)\n        self.linear2 = nn.Linear(512, num_classes)\n\n\n    def _make_dense_layers(self, block, in_planes, nblock):\n        layers = []\n        for i in range(nblock):\n            layers.append(block(in_planes, self.growth_rate))\n            in_planes += self.growth_rate\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.trans1(self.dense1(out))\n        out = self.trans2(self.dense2(out))\n        out = self.dense3(out)\n        out = F.relu(self.bn(out))\n        out = torch.flatten(out, 1)\n        out = F.relu(self.linear1(out))\n        out = self.linear2(out)\n\n        return out\n\ndef Densenet_cifar_32(in_ch=3, in_dim=32):\n    return DenseNet(Bottleneck, [2,4,4], growth_rate=32)\n\nif __name__ == \"__main__\":\n    from thop import profile\n\n    net = Densenet_cifar_32()\n    x = torch.randn(1,3,32,32)\n    y = net(x)\n    print(net)\n    macs, params = profile(net, (torch.randn(1, 3, 32, 32),))\n    print(macs / 1000000, params / 1000000)  # 6830M, 7M\n    print(y)\n"
  },
  {
    "path": "examples/vision/models/densenet_imagenet.py",
    "content": "'''DenseNet in PyTorch.\nhttps://github.com/kuangliu/pytorch-cifar\n'''\n\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Bottleneck(nn.Module):\n    def __init__(self, in_planes, growth_rate):\n        super(Bottleneck, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=True)\n        self.bn2 = nn.BatchNorm2d(4*growth_rate)\n        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=True)\n\n    def forward(self, x):\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = self.conv2(F.relu(self.bn2(out)))\n        # out = self.conv1(F.relu(x))\n        # out = self.conv2(F.relu(out))\n        out = torch.cat([out,x], 1)\n        return out\n\n\nclass Transition(nn.Module):\n    def __init__(self, in_planes, out_planes):\n        super(Transition, self).__init__()\n        self.bn = nn.BatchNorm2d(in_planes)\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=True)\n\n    def forward(self, x):\n        out = self.conv(F.relu(self.bn(x)))\n        out = F.avg_pool2d(out, 2)\n        return out\n\n\nclass DenseNet(nn.Module):\n    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=200):\n        super(DenseNet, self).__init__()\n        self.growth_rate = growth_rate\n\n        num_planes = 2*growth_rate\n        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=True)\n\n        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])\n        num_planes += nblocks[0]*growth_rate\n        out_planes = int(math.floor(num_planes*reduction))\n        self.trans1 = Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])\n        num_planes += nblocks[1]*growth_rate\n        out_planes = 
int(math.floor(num_planes*reduction))\n        self.trans2 = Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])\n        num_planes += nblocks[2]*growth_rate\n        # out_planes = int(math.floor(num_planes*reduction))\n        # self.trans3 = Transition(num_planes, out_planes)\n        # num_planes = out_planes\n\n        # self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])\n        # num_planes += nblocks[3]*growth_rate\n\n        self.bn = nn.BatchNorm2d(num_planes)\n        self.linear1 = nn.Linear(43904, 512)\n        self.linear2 = nn.Linear(512, num_classes)\n\n\n    def _make_dense_layers(self, block, in_planes, nblock):\n        layers = []\n        for i in range(nblock):\n            layers.append(block(in_planes, self.growth_rate))\n            in_planes += self.growth_rate\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.trans1(self.dense1(out))\n        out = self.trans2(self.dense2(out))\n        out = self.dense3(out)\n        out = F.relu(self.bn(out))\n        out = torch.flatten(out, 1)\n        out = F.relu(self.linear1(out))\n        out = self.linear2(out)\n\n        return out\n\n\ndef Densenet_imagenet(in_ch=3, in_dim=56):\n    return DenseNet(Bottleneck, [2,4,4], growth_rate=32)\n\nif __name__ == \"__main__\":\n    from thop import profile\n\n    net = Densenet_imagenet()\n    x = torch.randn(1,3,56,56)\n    y = net(x)\n    print(net)\n    macs, params = profile(net, (torch.randn(1, 3, 56, 56),))\n    print(macs / 1000000, params / 1000000)  # 564M, 11M\n    print(y.shape)\n"
  },
  {
    "path": "examples/vision/models/densenet_no_bn.py",
    "content": "'''DenseNet in PyTorch.\nhttps://github.com/kuangliu/pytorch-cifar\n'''\n\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Bottleneck(nn.Module):\n    def __init__(self, in_planes, growth_rate):\n        super(Bottleneck, self).__init__()\n        # self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=True)\n        # self.bn2 = nn.BatchNorm2d(4*growth_rate)\n        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=True)\n\n    def forward(self, x):\n        # out = self.conv1(F.relu(self.bn1(x)))\n        # out = self.conv2(F.relu(self.bn2(out)))\n        out = self.conv1(F.relu(x))\n        out = self.conv2(F.relu(out))\n        out = torch.cat([out,x], 1)\n        return out\n\n\nclass Transition(nn.Module):\n    def __init__(self, in_planes, out_planes):\n        super(Transition, self).__init__()\n        # self.bn = nn.BatchNorm2d(in_planes)\n        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=True)\n\n    def forward(self, x):\n        out = self.conv(F.relu(x))\n        out = F.avg_pool2d(out, 2)\n        return out\n\n\nclass DenseNet(nn.Module):\n    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):\n        super(DenseNet, self).__init__()\n        self.growth_rate = growth_rate\n\n        num_planes = 2*growth_rate\n        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=True)\n\n        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])\n        num_planes += nblocks[0]*growth_rate\n        out_planes = int(math.floor(num_planes*reduction))\n        self.trans1 = Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])\n        num_planes += nblocks[1]*growth_rate\n        out_planes = 
int(math.floor(num_planes*reduction))\n        self.trans2 = Transition(num_planes, out_planes)\n        num_planes = out_planes\n\n        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])\n        num_planes += nblocks[2]*growth_rate\n        # out_planes = int(math.floor(num_planes*reduction))\n        # self.trans3 = Transition(num_planes, out_planes)\n        # num_planes = out_planes\n\n        # self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])\n        # num_planes += nblocks[3]*growth_rate\n\n        # self.bn = nn.BatchNorm2d(num_planes)\n        self.linear1 = nn.Linear(9216, 512)\n        self.linear2 = nn.Linear(512, num_classes)\n\n\n    def _make_dense_layers(self, block, in_planes, nblock):\n        layers = []\n        for i in range(nblock):\n            layers.append(block(in_planes, self.growth_rate))\n            in_planes += self.growth_rate\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.trans1(self.dense1(out))\n        out = self.trans2(self.dense2(out))\n        out = self.dense3(out)\n        out = F.relu(out)\n        out = torch.flatten(out, 1)\n        out = F.relu(self.linear1(out))\n        out = self.linear2(out)\n\n        return out\n\ndef Densenet_cifar_wobn(in_ch=3, in_dim=56):\n    return DenseNet(Bottleneck, [2,4,6], growth_rate=16)\n\n\nif __name__ == \"__main__\":\n    net = Densenet_cifar_wobn()\n    x = torch.randn(1,3,32,32)\n    y = net(x)\n    print(net)\n    print(y)\n"
  },
  {
    "path": "examples/vision/models/feedforward.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom auto_LiRPA import PerturbationLpNorm, BoundedParameter\n\n\n# CNN, relatively large 4-layer\n# parameter in_ch: input image channel, 1 for MNIST and 3 for CIFAR\n# parameter in_dim: input dimension, 28 for MNIST and 32 for CIFAR\n# parameter width: width multiplier\nclass cnn_4layer(nn.Module):\n    def __init__(self, in_ch, in_dim, width=2, linear_size=256):\n        super(cnn_4layer, self).__init__()\n        self.conv1 = nn.Conv2d(in_ch, 4 * width, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(4 * width, 8 * width, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(8 * width * (in_dim // 4) * (in_dim // 4), linear_size)\n        self.fc2 = nn.Linear(linear_size, 10)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        x = torch.flatten(x, 1)\n        x = F.relu(self.fc1(x))\n        x = self.fc2(x)\n\n        return x\n\n\nclass mlp_2layer(nn.Module):\n    def __init__(self, in_ch, in_dim, width=1):\n        super(mlp_2layer, self).__init__()\n        self.fc1 = nn.Linear(in_ch * in_dim * in_dim, 256 * width)\n        self.fc2 = nn.Linear(256 * width, 10)\n\n    def forward(self, x):\n        x = torch.flatten(x, 1)\n        x = F.relu(self.fc1(x))\n        x = self.fc2(x)\n        return x\n\n\nclass mlp_3layer(nn.Module):\n    def __init__(self, in_ch, in_dim, width=1):\n        super(mlp_3layer, self).__init__()\n        self.fc1 = nn.Linear(in_ch * in_dim * in_dim, 256 * width)\n        self.fc2 = nn.Linear(256 * width, 128 * width)\n        self.fc3 = nn.Linear(128 * width, 10)\n\n    def forward(self, x):\n        x = torch.flatten(x, 1)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nclass mlp_3layer_weight_perturb(nn.Module):\n    def __init__(self, in_ch=1, in_dim=28, width=1, pert_weight=True, pert_bias=False, norm=2):\n        
super(mlp_3layer_weight_perturb, self).__init__()\n        self.fc1 = nn.Linear(in_ch * in_dim * in_dim, 64 * width)\n        self.fc2 = nn.Linear(64 * width, 64 * width)\n        self.fc3 = nn.Linear(64 * width, 10)\n\n        eps = 0.01\n        self.ptb = PerturbationLpNorm(norm=norm, eps=eps)\n\n        if pert_weight:\n            self.fc1.weight = BoundedParameter(self.fc1.weight.data, self.ptb)\n            self.fc2.weight = BoundedParameter(self.fc2.weight.data, self.ptb)\n            self.fc3.weight = BoundedParameter(self.fc3.weight.data, self.ptb)\n\n        if pert_bias:\n            self.fc1.bias = BoundedParameter(self.fc1.bias.data, self.ptb)\n            self.fc2.bias = BoundedParameter(self.fc2.bias.data, self.ptb)\n            self.fc3.bias = BoundedParameter(self.fc3.bias.data, self.ptb)\n\n    def forward(self, x):\n        x = x.view(-1, 784)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nclass mlp_5layer(nn.Module):\n    def __init__(self, in_ch, in_dim, width=1):\n        super(mlp_5layer, self).__init__()\n        self.fc1 = nn.Linear(in_ch * in_dim * in_dim, 256 * width)\n        self.fc2 = nn.Linear(256 * width, 256 * width)\n        self.fc3 = nn.Linear(256 * width, 256 * width)\n        self.fc4 = nn.Linear(256 * width, 128 * width)\n        self.fc5 = nn.Linear(128 * width, 10)\n\n    def forward(self, x):\n        x = torch.flatten(x, 1)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = F.relu(self.fc3(x))\n        x = F.relu(self.fc4(x))\n        x = self.fc5(x)\n        return x\n\n\n# Model can also be defined as a nn.Sequential\ndef cnn_7layer(in_ch=3, in_dim=32, width=64, linear_size=512):\n    model = nn.Sequential(\n        nn.Conv2d(in_ch, width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(width, width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(width, 2 * width, 3, stride=2, padding=1),\n     
   nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Flatten(),\n        nn.Linear((in_dim//2) * (in_dim//2) * 2 * width, linear_size),\n        nn.ReLU(),\n        nn.Linear(linear_size,10)\n    )\n    return model\n\ndef cnn_7layer_bn(in_ch=3, in_dim=32, width=64, linear_size=512):\n    model = nn.Sequential(\n        nn.Conv2d(in_ch, width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(width),\n        nn.ReLU(),\n        nn.Conv2d(width, width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(width),\n        nn.ReLU(),\n        nn.Conv2d(width, 2 * width, 3, stride=2, padding=1),\n        nn.BatchNorm2d(2 * width),\n        nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(2 * width),\n        nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(2 * width),\n        nn.ReLU(),\n        nn.Flatten(),\n        nn.Linear((in_dim//2) * (in_dim//2) * 2 * width, linear_size),\n        nn.ReLU(),\n        nn.Linear(linear_size,10)\n    )\n    return model\n\ndef cnn_7layer_bn_imagenet(in_ch=3, in_dim=32, width=64, linear_size=512):\n    model = nn.Sequential(\n        nn.Conv2d(in_ch, width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(width),\n        nn.ReLU(),\n        nn.Conv2d(width, width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(width),\n        nn.ReLU(),\n        nn.Conv2d(width, 2 * width, 3, stride=2, padding=1),\n        nn.BatchNorm2d(2 * width),\n        nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1),\n        nn.BatchNorm2d(2 * width),\n        nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=2, padding=1),\n        nn.BatchNorm2d(2 * width),\n        nn.ReLU(),\n        nn.Flatten(),\n        nn.Linear(25088, linear_size),\n        nn.ReLU(),\n        
nn.Linear(linear_size,200)\n    )\n    return model\n\ndef cnn_6layer(in_ch, in_dim, width=32, linear_size=256):\n    model = nn.Sequential(\n        nn.Conv2d(in_ch, width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(width, width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(width, 2 * width, 3, stride=2, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1),\n        nn.ReLU(),\n        nn.Flatten(),\n        nn.Linear((in_dim//2) * (in_dim//2) * 2 * width, linear_size),\n        nn.ReLU(),\n        nn.Linear(linear_size,10)\n    )\n    return model\n"
  },
  {
    "path": "examples/vision/models/mobilenet.py",
    "content": "'''MobileNetV2 in PyTorch.\n\nSee the paper \"Inverted Residuals and Linear Bottlenecks:\nMobile Networks for Classification, Detection and Segmentation\" for more details.\n'''\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Block(nn.Module):\n    '''expand + depthwise + pointwise'''\n    def __init__(self, in_planes, out_planes, expansion, stride):\n        super(Block, self).__init__()\n        self.stride = stride\n\n        planes = expansion * in_planes\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)\n        # self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)\n        # self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)\n        # self.bn3 = nn.BatchNorm2d(out_planes)\n\n        self.shortcut = nn.Sequential()\n        if stride == 1 and in_planes != out_planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),\n                # nn.BatchNorm2d(out_planes),\n            )\n\n    def forward(self, x):\n        out = F.relu((self.conv1(x)))\n        out = F.relu((self.conv2(out)))\n        out = self.conv3(out)\n        out = out + self.shortcut(x) if self.stride==1 else out\n        return out\n\n\nclass MobileNetV2(nn.Module):\n    # (expansion, out_planes, num_blocks, stride)\n    cfg = [(1,  16, 1, 1),\n           (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10\n           (6,  32, 3, 2),\n           (6,  64, 4, 2),\n           (6,  96, 3, 1),\n           (6, 160, 3, 2),\n           (6, 320, 1, 1)]\n\n    def __init__(self, num_classes=10):\n        super(MobileNetV2, self).__init__()\n        # NOTE: change conv1 stride 2 -> 1 for CIFAR10\n        self.conv1 = nn.Conv2d(3, 32, 
kernel_size=3, stride=1, padding=1, bias=False)\n        # self.bn1 = nn.BatchNorm2d(32)\n        self.layers = self._make_layers(in_planes=32)\n        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)\n        # self.bn2 = nn.BatchNorm2d(1280)\n        self.linear = nn.Linear(1280, num_classes)\n\n    def _make_layers(self, in_planes):\n        layers = []\n        for expansion, out_planes, num_blocks, stride in self.cfg:\n            strides = [stride] + [1]*(num_blocks-1)\n            for stride in strides:\n                layers.append(Block(in_planes, out_planes, expansion, stride))\n                in_planes = out_planes\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu((self.conv1(x)))\n        out = self.layers(out)\n        out = F.relu((self.conv2(out)))\n        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10\n        out = F.avg_pool2d(out, 4)\n        out = torch.flatten(out, 1)\n        out = self.linear(out)\n        return out\n\n\nif __name__ == \"__main__\":\n    net = MobileNetV2()\n    x = torch.randn(2,3,32,32)\n    y = net(x)\n    print(y.size())\n"
  },
  {
    "path": "examples/vision/models/resnet.py",
    "content": "'''\nResNet used in https://arxiv.org/pdf/1805.12514.pdf\nhttps://github.com/locuslab/convex_adversarial/blob/0d11e671ad9318745a2439afce513c82dc6bf5ce/examples/problems.py\n'''\nimport torch\nimport torch.nn as nn\nimport math\n\n\nclass Dense(nn.Module):\n    def __init__(self, *Ws):\n        super(Dense, self).__init__()\n        self.Ws = nn.ModuleList(list(Ws))\n        if len(Ws) > 0 and hasattr(Ws[0], 'out_features'):\n            self.out_features = Ws[0].out_features\n\n    def forward(self, *xs):\n        xs = xs[-len(self.Ws):]\n        out = sum(W(x) for x, W in zip(xs, self.Ws) if W is not None)\n        return out\n\n\nclass DenseSequential(nn.Sequential):\n    def forward(self, x):\n        xs = [x]\n        for module in self._modules.values():\n            if 'Dense' in type(module).__name__:\n                xs.append(module(*xs))\n            else:\n                xs.append(module(xs[-1]))\n        return xs[-1]\n\n\ndef model_resnet(in_ch=3, in_dim=32, width=1, mult=16, N=1):\n    def block(in_filters, out_filters, k, downsample):\n        if not downsample:\n            k_first = 3\n            skip_stride = 1\n            k_skip = 1\n        else:\n            k_first = 4\n            skip_stride = 2\n            k_skip = 2\n        return [\n            Dense(nn.Conv2d(in_filters, out_filters, k_first, stride=skip_stride, padding=1)),\n            nn.ReLU(),\n            Dense(nn.Conv2d(in_filters, out_filters, k_skip, stride=skip_stride, padding=0),\n                  None,\n                  nn.Conv2d(out_filters, out_filters, k, stride=1, padding=1)),\n            nn.ReLU()\n        ]\n\n    conv1 = [nn.Conv2d(in_ch, mult, 3, stride=1, padding=3 if in_dim == 28 else 1), nn.ReLU()]\n    conv2 = block(mult, mult * width, 3, False)\n    for _ in range(N):\n        conv2.extend(block(mult * width, mult * width, 3, False))\n    conv3 = block(mult * width, mult * 2 * width, 3, True)\n    for _ in range(N - 1):\n        
conv3.extend(block(mult * 2 * width, mult * 2 * width, 3, False))\n    conv4 = block(mult * 2 * width, mult * 4 * width, 3, True)\n    for _ in range(N - 1):\n        conv4.extend(block(mult * 4 * width, mult * 4 * width, 3, False))\n    layers = (\n            conv1 +\n            conv2 +\n            conv3 +\n            conv4 +\n            [nn.Flatten(),\n             nn.Linear(mult * 4 * width * 8 * 8, 1000),\n             nn.ReLU(),\n             nn.Linear(1000, 10)]\n    )\n    model = DenseSequential(\n        *layers\n    )\n\n    for m in model.modules():\n        if isinstance(m, nn.Conv2d):\n            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            m.weight.data.normal_(0, math.sqrt(2. / n))\n            if m.bias is not None:\n                m.bias.data.zero_()\n    return model\n\n\nif __name__ == \"__main__\":\n    model = model_resnet(in_ch=1, in_dim=28)\n    dummy = torch.randn(8, 1, 28, 28)\n    print(model)\n    print(model(dummy).shape)\n"
  },
  {
    "path": "examples/vision/models/resnet18.py",
    "content": "'''ResNet in PyTorch.\n\nFor Pre-activation ResNet, see 'preact_resnet.py'.\n\nReference:\n[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n    Deep Residual Learning for Image Recognition. arXiv:1512.03385\n'''\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_planes, planes, stride=1):\n        super(BasicBlock, self).__init__()\n        self.conv1 = nn.Conv2d(\n            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                               stride=1, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion*planes,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(self.expansion*planes)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.bn2(self.conv2(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, in_planes, planes, stride=1):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                               stride=stride, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, self.expansion *\n                               planes, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(self.expansion*planes)\n\n        self.shortcut = 
nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion*planes,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(self.expansion*planes)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = F.relu(self.bn2(self.conv2(out)))\n        out = self.bn3(self.conv3(out))\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(self, block, num_blocks, num_classes=10, in_planes=64):\n        super(ResNet, self).__init__()\n        self.in_planes = in_planes\n\n        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3,\n                               stride=1, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.layer1 = self._make_layer(block, in_planes, num_blocks[0], stride=1)\n        self.layer2 = self._make_layer(block, in_planes * 2, num_blocks[1], stride=2)\n        self.layer3 = self._make_layer(block, in_planes * 4, num_blocks[2], stride=2)\n        self.layer4 = self._make_layer(block, in_planes * 8, num_blocks[3], stride=2)\n        self.linear = nn.Linear(in_planes * 8 * block.expansion, num_classes)\n\n    def _make_layer(self, block, planes, num_blocks, stride):\n        strides = [stride] + [1]*(num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, stride))\n            self.in_planes = planes * block.expansion\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = self.layer4(out)\n        out = F.avg_pool2d(out, 4)\n        out = torch.flatten(out, 1)\n        out = self.linear(out)\n        return 
out\n\ndef ResNet18(in_planes=64):\n    return ResNet(BasicBlock, [2, 2, 2, 2], in_planes=in_planes)\n\nif __name__ == \"__main__\":\n    from thop import profile\n    net = ResNet18(in_planes=64)\n    x = torch.randn(1,3,32,32)\n    y = net(x)\n    print(net)\n    macs, params = profile(net, (torch.randn(1, 3, 32, 32),))\n    print(macs / 1000000, params / 1000000)  # 556M, 11M\n    print(y)\n"
  },
  {
    "path": "examples/vision/models/resnext.py",
    "content": "'''ResNeXt in PyTorch.\nSee the paper \"Aggregated Residual Transformations for Deep Neural Networks\" for more details.\nhttps://github.com/kuangliu/pytorch-cifar\n'''\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Block(nn.Module):\n    '''Grouped convolution block.'''\n    expansion = 2\n\n    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):\n        super(Block, self).__init__()\n        group_width = cardinality * bottleneck_width\n        self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=True)\n        self.bn1 = nn.BatchNorm2d(group_width)\n        self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=True)\n        # self.bn2 = nn.BatchNorm2d(group_width)\n        self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=True)\n        # self.bn3 = nn.BatchNorm2d(self.expansion*group_width)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*group_width:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=True),\n                # nn.BatchNorm2d(self.expansion*group_width)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        # out = F.relu(self.bn2(self.conv2(out)))\n        # out = self.bn3(self.conv3(out))\n        # out = F.relu(self.conv1(x))\n        out = F.relu(self.conv2(out))\n        out = self.conv3(out)\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass ResNeXt(nn.Module):\n    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):\n        super(ResNeXt, self).__init__()\n        self.cardinality = cardinality\n        self.bottleneck_width = bottleneck_width\n        self.in_planes = 16\n\n        self.conv1 = 
nn.Conv2d(3, 16, kernel_size=3, bias=True, padding=1)\n        # self.bn1 = nn.BatchNorm2d(16)\n        self.layer1 = self._make_layer(num_blocks[0], 1)\n        self.layer2 = self._make_layer(num_blocks[1], 2)\n        self.layer3 = self._make_layer(num_blocks[2], 2)\n        # self.layer4 = self._make_layer(num_blocks[3], 2)\n        self.linear1 = nn.Linear(cardinality*bottleneck_width*512, 512)\n        self.linear2 = nn.Linear(512, num_classes)\n\n\n    def _make_layer(self, num_blocks, stride):\n        strides = [stride] + [1]*(num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))\n            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width\n        # Increase bottleneck_width by 2 after each stage.\n        self.bottleneck_width *= 2\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.conv1(x))\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = torch.flatten(out, 1)\n        out = F.relu(self.linear1(out))\n        out = self.linear2(out)\n        return out\n\n\ndef ResNeXt29_2x64d():\n    return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)\n\ndef ResNeXt29_4x64d():\n    return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)\n\ndef ResNeXt29_8x64d():\n    return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)\n\ndef ResNeXt29_32x4d():\n    return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)\n\ndef ResNeXt_cifar(in_ch=3, in_dim=32):\n    return ResNeXt(num_blocks=[1,1,1], cardinality=2, bottleneck_width=32)\n\nif __name__ == \"__main__\":\n    from thop import profile\n    net = ResNeXt_cifar()\n    x = torch.randn(1,3,32,32)\n    y = net(x)\n    print(net)\n    macs, params = profile(net, (torch.randn(1, 3, 32, 32),))\n    print(macs / 1000000, 
params / 1000000)  # 6830M, 7M\n    print(y)"
  },
  {
    "path": "examples/vision/models/resnext_imagenet64.py",
    "content": "'''ResNeXt in PyTorch.\nSee the paper \"Aggregated Residual Transformations for Deep Neural Networks\" for more details.\nhttps://github.com/kuangliu/pytorch-cifar\n'''\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Block(nn.Module):\n    '''Grouped convolution block.'''\n    expansion = 2\n\n    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):\n        super(Block, self).__init__()\n        group_width = cardinality * bottleneck_width\n        self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=True)\n        self.bn1 = nn.BatchNorm2d(group_width)\n        self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=True)\n        # self.bn2 = nn.BatchNorm2d(group_width)\n        self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=True)\n        # self.bn3 = nn.BatchNorm2d(self.expansion*group_width)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*group_width:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=True),\n                # nn.BatchNorm2d(self.expansion*group_width)\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(self.conv1(x)))\n        # out = F.relu(self.bn2(self.conv2(out)))\n        # out = self.bn3(self.conv3(out))\n        # out = F.relu(self.conv1(x))\n        out = F.relu(self.conv2(out))\n        out = self.conv3(out)\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\nclass ResNeXt(nn.Module):\n    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=200):\n        super(ResNeXt, self).__init__()\n        self.cardinality = cardinality\n        self.bottleneck_width = bottleneck_width\n        self.in_planes = 16\n\n        self.conv1 = 
nn.Conv2d(3, 16, kernel_size=3, bias=True, padding=1)\n        # self.bn1 = nn.BatchNorm2d(16)\n        self.layer1 = self._make_layer(num_blocks[0], 1)\n        self.layer2 = self._make_layer(num_blocks[1], 2)\n        self.layer3 = self._make_layer(num_blocks[2], 2)\n        # self.layer4 = self._make_layer(num_blocks[3], 2)\n        self.linear1 = nn.Linear(cardinality*bottleneck_width*1568, 512)\n        self.linear2 = nn.Linear(512, num_classes)\n\n\n    def _make_layer(self, num_blocks, stride):\n        strides = [stride] + [1]*(num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))\n            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width\n        # Increase bottleneck_width by 2 after each stage.\n        self.bottleneck_width *= 2\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = F.relu(self.conv1(x))\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = torch.flatten(out, 1)\n        out = F.relu(self.linear1(out))\n        out = self.linear2(out)\n        return out\n\ndef ResNeXt_imagenet64():\n    return ResNeXt(num_blocks=[2,2,2], cardinality=2, bottleneck_width=8)\n\nif __name__ == \"__main__\":\n    from thop import profile\n    net = ResNeXt_imagenet64()\n    x = torch.randn(1,3,56,56)\n    y = net(x)\n    print(net)\n    macs, params = profile(net, (torch.randn(1, 3, 56, 56),))\n    print(macs / 1000000, params / 1000000)  # 64M, 13M\n    print(y.shape)"
  },
  {
    "path": "examples/vision/models/vnncomp_resnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_planes, planes, stride=1, bn=True, kernel=3):\n        super(BasicBlock, self).__init__()\n        self.bn = bn\n        if kernel == 3:\n            self.conv1 = nn.Conv2d(\n                in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=(not self.bn))\n            if self.bn:\n                self.bn1 = nn.BatchNorm2d(planes)\n            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,\n                                   stride=1, padding=1, bias=(not self.bn))\n        elif kernel == 2:\n            self.conv1 = nn.Conv2d(\n                in_planes, planes, kernel_size=2, stride=stride, padding=1, bias=(not self.bn))\n            if self.bn:\n                self.bn1 = nn.BatchNorm2d(planes)\n            self.conv2 = nn.Conv2d(planes, planes, kernel_size=2,\n                                   stride=1, padding=0, bias=(not self.bn))\n        elif kernel == 1:\n            self.conv1 = nn.Conv2d(\n                in_planes, planes, kernel_size=1, stride=stride, padding=0, bias=(not self.bn))\n            if self.bn:\n                self.bn1 = nn.BatchNorm2d(planes)\n            self.conv2 = nn.Conv2d(planes, planes, kernel_size=1,\n                                   stride=1, padding=0, bias=(not self.bn))\n        else:\n            exit(\"kernel not supported!\")\n\n        if self.bn:\n            self.bn2 = nn.BatchNorm2d(planes)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != self.expansion*planes:\n            if self.bn:\n                self.shortcut = nn.Sequential(\n                    nn.Conv2d(in_planes, self.expansion*planes,\n                              kernel_size=1, stride=stride, bias=(not self.bn)),\n                    nn.BatchNorm2d(self.expansion*planes)\n                )\n            else:\n                
self.shortcut = nn.Sequential(\n                    nn.Conv2d(in_planes, self.expansion*planes,\n                              kernel_size=1, stride=stride, bias=(not self.bn)),\n                )\n\n    def forward(self, x):\n        if self.bn:\n            out = F.relu(self.bn1(self.conv1(x)))\n            out = self.bn2(self.conv2(out))\n        else:\n            out = F.relu(self.conv1(x))\n            out = self.conv2(out)\n        out += self.shortcut(x)\n        out = F.relu(out)\n        return out\n\n\n\nclass ResNet5(nn.Module):\n    def __init__(self, block, num_blocks=2, num_classes=10, in_planes=64, bn=True, last_layer=\"avg\"):\n        super(ResNet5, self).__init__()\n        self.in_planes = in_planes\n        self.bn = bn\n        self.last_layer = last_layer\n        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3,\n                               stride=2, padding=1, bias=not self.bn)\n        if self.bn: self.bn1 = nn.BatchNorm2d(in_planes)\n        self.layer1 = self._make_layer(block, in_planes*2, num_blocks, stride=2, bn=bn, kernel=3)\n        if self.last_layer == \"avg\":\n            self.avg2d = nn.AvgPool2d(4)\n            self.linear = nn.Linear(in_planes * 8 * block.expansion, num_classes)\n        elif self.last_layer == \"dense\":\n            self.linear1 = nn.Linear(in_planes * 8 * block.expansion * 16, 100)\n            self.linear2 = nn.Linear(100, num_classes)\n        else:\n            exit(\"last_layer type not supported!\")\n\n    def _make_layer(self, block, planes, num_blocks, stride, bn, kernel):\n        strides = [stride] + [1]*(num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, stride, bn, kernel))\n            self.in_planes = planes * block.expansion\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        if self.bn:\n            out = F.relu(self.bn1(self.conv1(x)))\n        else:\n            out = 
F.relu(self.conv1(x))\n        out = self.layer1(out)\n        if self.last_layer == \"avg\":\n            out = self.avg2d(out)\n            out = torch.flatten(out, 1)\n            out = self.linear(out)\n        elif self.last_layer == \"dense\":\n            out = torch.flatten(out, 1)\n            out = F.relu(self.linear1(out))\n            out = self.linear2(out)\n        return out\n\n\nclass ResNet9(nn.Module):\n    def __init__(self, block, num_blocks=2, num_classes=10, in_planes=64, bn=True, last_layer=\"avg\"):\n        super(ResNet9, self).__init__()\n        self.in_planes = in_planes\n        self.bn = bn\n        self.last_layer = last_layer\n        self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3,\n                               stride=2, padding=1, bias=not self.bn)\n        if self.bn: self.bn1 = nn.BatchNorm2d(in_planes)\n        self.layer1 = self._make_layer(block, in_planes*2, num_blocks, stride=2, bn=bn, kernel=3)\n        self.layer2 = self._make_layer(block, in_planes*2, num_blocks, stride=2, bn=bn, kernel=3)\n        if self.last_layer == \"avg\":\n            self.avg2d = nn.AvgPool2d(4)\n            self.linear = nn.Linear(in_planes * 2 * block.expansion, num_classes)\n        elif self.last_layer == \"dense\":\n            self.linear1 = nn.Linear(in_planes * 2 * block.expansion * 16, 100)\n            self.linear2 = nn.Linear(100, num_classes)\n        else:\n            exit(\"last_layer type not supported!\")\n\n    def _make_layer(self, block, planes, num_blocks, stride, bn, kernel):\n        strides = [stride] + [1]*(num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, stride, bn, kernel))\n            self.in_planes = planes * block.expansion\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        if self.bn:\n            out = F.relu(self.bn1(self.conv1(x)))\n        else:\n            out = F.relu(self.conv1(x))\n        out = 
self.layer1(out)\n        out = self.layer2(out)\n        if self.last_layer == \"avg\":\n            out = self.avg2d(out)\n            out = torch.flatten(out, 1)\n            out = self.linear(out)\n        elif self.last_layer == \"dense\":\n            out = torch.flatten(out, 1)\n            out = F.relu(self.linear1(out))\n            out = self.linear2(out)\n        return out\n\n\ndef resnet2b():\n    return ResNet5(BasicBlock, num_blocks=2, in_planes=8, bn=False, last_layer=\"dense\")\n\ndef resnet4b():\n    return ResNet9(BasicBlock, num_blocks=2, in_planes=16, bn=False, last_layer=\"dense\")\n\n\nif __name__ == '__main__':\n    print('ResNet-2B:\\n', resnet2b())\n    print('ResNet-4B:\\n', resnet4b())\n"
  },
  {
    "path": "examples/vision/models/wide_resnet_cifar.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\nimport sys\nimport numpy as np\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)\n\ndef conv_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        init.xavier_uniform_(m.weight, gain=np.sqrt(2))\n        init.constant_(m.bias, 0)\n    elif classname.find('BatchNorm') != -1:\n        init.constant_(m.weight, 1)\n        init.constant_(m.bias, 0)\n\nclass wide_basic(nn.Module):\n    def __init__(self, in_planes, planes, dropout_rate, stride=1, use_bn=False):\n        super(wide_basic, self).__init__()\n        self.use_bn = use_bn\n        self.dropout_rate = dropout_rate\n        if use_bn:\n            self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)\n        if dropout_rate:\n            self.dropout = nn.Dropout(p=dropout_rate)\n        # self.bn2 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),\n            )\n\n    def forward(self, x):\n        # out = self.dropout(self.conv1(F.relu(self.bn1(x))))\n        if self.use_bn:\n            out = self.conv1(F.relu(self.bn1(x)))\n        else:\n            out = self.conv1(F.relu(x))\n        if self.dropout_rate:\n            out = self.dropout(out)\n        # out = self.conv2(F.relu(self.bn2(out)))\n        out = self.conv2(F.relu(out))\n\n        out += self.shortcut(x)\n\n        return out\n\nclass Wide_ResNet(nn.Module):\n    def __init__(self, depth, 
widen_factor, dropout_rate, num_classes, use_bn=False, use_pooling=True):\n        super(Wide_ResNet, self).__init__()\n        self.in_planes = 16\n        self.use_bn = use_bn\n        self.use_pooling = use_pooling\n        assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'\n        n = (depth-4)/6\n        k = widen_factor\n\n        print('| Wide-Resnet %dx%d' %(depth, k))\n        nStages = [self.in_planes, self.in_planes*2*k, self.in_planes*4*k, self.in_planes*8*k]\n\n        self.conv1 = conv3x3(3,nStages[0])\n        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)\n        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)\n        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)\n        # self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.1)\n        if self.use_pooling:\n            self.linear1 = nn.Linear(nStages[3], 512)\n        else:\n            self.linear1 = nn.Linear(nStages[3]*64, 512)\n\n        self.linear2 = nn.Linear(512, num_classes)\n\n\n    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):\n        strides = [stride] + [1]*(int(num_blocks)-1)\n        layers = []\n\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, dropout_rate, stride, self.use_bn))\n            self.in_planes = planes\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = F.relu(out)\n        if self.use_pooling:\n            out = F.avg_pool2d(out, 8)\n        out = torch.flatten(out, 1)\n        out = F.relu(self.linear1(out))\n        out = self.linear2(out)\n\n        return out\n\ndef wide_resnet_cifar(in_ch=3, in_dim=32):\n    return Wide_ResNet(16, 4, 0.3, 10)\n\ndef wide_resnet_cifar_bn(in_ch=3, in_dim=32):\n    return Wide_ResNet(10, 4, None, 
10, use_bn=True)\n\ndef wide_resnet_cifar_bn_wo_pooling(in_ch=3, in_dim=32): # 1113M, 21M\n    return Wide_ResNet(10, 4, None, 10, use_bn=True, use_pooling=False)\n\ndef wide_resnet_cifar_bn_wo_pooling_dropout(in_ch=3, in_dim=32): # 1113M, 21M\n    return Wide_ResNet(10, 4, 0.3, 10, use_bn=True, use_pooling=False)\n\nif __name__ == '__main__':\n    from thop import profile\n    net = wide_resnet_cifar_bn_wo_pooling_dropout()\n    print(net)\n    y = net(torch.randn(1,3,32,32))\n    macs, params = profile(net, (torch.randn(1, 3, 32, 32),))\n    print(macs/1000000, params/1000000)  # 1096M, 5M\n    print(y.size())"
  },
  {
    "path": "examples/vision/models/wide_resnet_imagenet64.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\n\nimport sys\nimport numpy as np\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)\n\ndef conv_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        init.xavier_uniform_(m.weight, gain=np.sqrt(2))\n        init.constant_(m.bias, 0)\n    elif classname.find('BatchNorm') != -1:\n        init.constant_(m.weight, 1)\n        init.constant_(m.bias, 0)\n\nclass wide_basic(nn.Module):\n    def __init__(self, in_planes, planes, dropout_rate, stride=1):\n        super(wide_basic, self).__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)\n        # self.dropout = nn.Dropout(p=dropout_rate)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)\n\n        self.shortcut = nn.Sequential()\n        if stride != 1 or in_planes != planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),\n            )\n\n    def forward(self, x):\n        # out = self.dropout(self.conv1(F.relu(self.bn1(x))))\n        out = self.conv1(F.relu(self.bn1(x)))\n        out = self.conv2(F.relu(self.bn2(out)))\n        out += self.shortcut(x)\n\n        return out\n\nclass Wide_ResNet(nn.Module):\n    def __init__(self, depth, widen_factor, dropout_rate, num_classes,\n            in_planes=16, in_dim=56):\n        super(Wide_ResNet, self).__init__()\n        self.in_planes = in_planes\n\n        assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'\n        n = (depth-4)/6\n        k = widen_factor\n\n        print('| Wide-Resnet %dx%d' %(depth, k))\n        nStages = [in_planes, in_planes*k, 
in_planes*2*k, in_planes*4*k]\n\n        self.conv1 = conv3x3(3,nStages[0])\n        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)\n        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)\n        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)\n        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.1)\n        self.linear = nn.Linear(nStages[3] * (in_dim//4//7)**2, num_classes)\n\n    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):\n        strides = [stride] + [1]*(int(num_blocks)-1)\n        layers = []\n\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, dropout_rate, stride))\n            self.in_planes = planes\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = F.relu(self.bn1(out))\n        out = F.avg_pool2d(out, 7)\n        out = torch.flatten(out, 1)\n        out = self.linear(out)\n\n        return out\n\ndef wide_resnet_imagenet64(in_ch=3, in_dim=56, in_planes=16, widen_factor=10):\n    return Wide_ResNet(10, widen_factor, 0.3, 200, in_dim=in_dim, in_planes=in_planes)\n\ndef wide_resnet_imagenet64_1000class(in_ch=3, in_dim=56, in_planes=16, widen_factor=10):\n    return Wide_ResNet(10, widen_factor, 0.3, 1000, in_dim=in_dim, in_planes=in_planes)\n\nif __name__ == '__main__':\n    from thop import profile\n    net = wide_resnet_imagenet64_1000class()\n    print(net)\n    y = net(torch.randn(1,3,56,56))\n    macs, params = profile(net, (torch.randn(1, 3, 56, 56),))\n    print(macs, params)  # 5229M, 8M\n    print(y.size())"
  },
  {
    "path": "examples/vision/save_intermediate_bound.py",
    "content": "\"\"\"\nA simple example for saving intermediate bounds.\n\"\"\"\nimport os\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import Flatten\n\ndef mnist_model():\n    model = nn.Sequential(\n        nn.Conv2d(1, 16, 4, stride=2, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(16, 32, 4, stride=2, padding=1),\n        nn.ReLU(),\n        Flatten(),\n        nn.Linear(32*7*7,100),\n        nn.ReLU(),\n        nn.Linear(100, 10)\n    )\n    return model\n\nmodel = mnist_model()\n# Optionally, load the pretrained weights.\ncheckpoint = torch.load(\n    os.path.join(os.path.dirname(__file__), 'pretrained/mnist_a_adv.pth'),\n    map_location=torch.device('cpu'))\nmodel.load_state_dict(checkpoint)\n\ntest_data = torchvision.datasets.MNIST(\n    './data', train=False, download=True,\n    transform=torchvision.transforms.ToTensor())\n# For illustration, we only use 2 images from the dataset.\nN = 2\nn_classes = 10\nimage = test_data.data[:N].view(N,1,28,28)\ntrue_label = test_data.targets[:N]\n# Convert images to float in [0, 1].\nimage = image.to(torch.float32) / 255.0\nif torch.cuda.is_available():\n    image = image.cuda()\n    model = model.cuda()\n\nlirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)\nprint('Running on', image.device)\n\neps = 0.3\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, eps = eps)\nimage = BoundedTensor(image, ptb)\n\nlirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1, }})\nlb, ub = lirpa_model.compute_bounds(x=(image,), method='CROWN-Optimized')\n# Intermediate layer bounds are returned as a dictionary, and if an argument is given,\n# a PyTorch checkpoint will also be saved to disk.\nsave_dict = lirpa_model.save_intermediate('./mnist_a_adv_bounds.pt')\n# To avoid saving the file and get just the bounds, call without any arguments:\n# 
save_dict = lirpa_model.save_intermediate()\n"
  },
  {
    "path": "examples/vision/simple_training.py",
    "content": "\"\"\"\nA simple script to train certified defense using the auto_LiRPA library.\n\nWe compute output bounds under input perturbations using auto_LiRPA, and use\nthem to form a \"robust loss\" for certified defense.  Several different bound\noptions are supported, such as IBP, CROWN, and CROWN-IBP. This is a basic\nexample on MNIST and CIFAR-10 datasets with Lp (p>=0) norm perturbation. For\nfaster training, please see our examples with loss fusion such as\ncifar_training.py and tinyimagenet_training.py\n\"\"\"\n\nimport time\nimport random\nimport multiprocessing\nimport argparse\nimport torch.optim as optim\nfrom torch.nn import CrossEntropyLoss\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import MultiAverageMeter\nfrom auto_LiRPA.eps_scheduler import LinearScheduler, AdaptiveScheduler, SmoothedScheduler, FixedScheduler\nimport models\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--verify\", action=\"store_true\", help='verification mode, do not train')\nparser.add_argument(\"--load\", type=str, default=\"\", help='Load pretrained model')\nparser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cpu\", \"cuda\"], help='use cpu or cuda')\nparser.add_argument(\"--data\", type=str, default=\"MNIST\", choices=[\"MNIST\", \"CIFAR\"], help='dataset')\nparser.add_argument(\"--seed\", type=int, default=100, help='random seed')\nparser.add_argument(\"--eps\", type=float, default=0.3, help='Target training epsilon')\nparser.add_argument(\"--norm\", type=float, default='inf', help='p norm for epsilon perturbation')\nparser.add_argument(\"--bound_type\", type=str, default=\"CROWN-IBP\",\n                    choices=[\"IBP\", \"CROWN-IBP\", \"CROWN\", \"CROWN-FAST\"], help='method of bound analysis')\nparser.add_argument(\"--model\", type=str, default=\"resnet\", 
help='model name (mlp_3layer, cnn_4layer, cnn_6layer, cnn_7layer, resnet)')\nparser.add_argument(\"--num_epochs\", type=int, default=100, help='number of total epochs')\nparser.add_argument(\"--batch_size\", type=int, default=256, help='batch size')\nparser.add_argument(\"--lr\", type=float, default=5e-4, help='learning rate')\nparser.add_argument(\"--scheduler_name\", type=str, default=\"SmoothedScheduler\",\n                    choices=[\"LinearScheduler\", \"AdaptiveScheduler\", \"SmoothedScheduler\", \"FixedScheduler\"], help='epsilon scheduler')\nparser.add_argument(\"--scheduler_opts\", type=str, default=\"start=3,length=60\", help='options for epsilon scheduler')\nparser.add_argument(\"--bound_opts\", type=str, default=None, choices=[\"same-slope\", \"zero-lb\", \"one-lb\"],\n                    help='bound options')\nparser.add_argument(\"--conv_mode\", type=str, choices=[\"matrix\", \"patches\"], default=\"patches\")\nparser.add_argument(\"--save_model\", type=str, default='')\n\nargs = parser.parse_args()\n\n\ndef Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust'):\n    num_class = 10\n    meter = MultiAverageMeter()\n    if train:\n        model.train()\n        eps_scheduler.train()\n        eps_scheduler.step_epoch()\n        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))\n    else:\n        model.eval()\n        eps_scheduler.eval()\n\n    for i, (data, labels) in enumerate(loader):\n        start = time.time()\n        eps_scheduler.step_batch()\n        eps = eps_scheduler.get_eps()\n        # For small eps just use natural training, no need to compute LiRPA bounds\n        batch_method = method\n        if eps < 1e-20:\n            batch_method = \"natural\"\n        if train:\n            opt.zero_grad()\n        # generate specifications\n        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(0)\n  
      # Remove the specification entry comparing the true class to itself.\n        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))\n        c = (c[I].view(data.size(0), num_class - 1, num_class))\n        # Element-wise input bounds are only used for the Linf norm.\n        if norm == np.inf:\n            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))\n            data_min = torch.reshape((0. - loader.mean) / loader.std, (1, -1, 1, 1))\n            data_ub = torch.min(data + (eps / loader.std).view(1,-1,1,1), data_max)\n            data_lb = torch.max(data - (eps / loader.std).view(1,-1,1,1), data_min)\n        else:\n            data_ub = data_lb = data\n\n        if list(model.parameters())[0].is_cuda:\n            data, labels, c = data.cuda(), labels.cuda(), c.cuda()\n            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()\n\n        # Specify the Lp norm perturbation.\n        # When using the Linf perturbation, we manually set the element-wise bounds x_L and x_U; eps is then not used.\n        if norm > 0:\n            ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)\n        elif norm == 0:\n            ptb = PerturbationL0Norm(eps = eps_scheduler.get_max_eps(), ratio = eps_scheduler.get_eps()/eps_scheduler.get_max_eps())\n        x = BoundedTensor(data, ptb)\n\n        output = model(x)\n        regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up\n        meter.update('CE', regular_ce.item(), x.size(0))\n        meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).cpu().detach().numpy() / x.size(0), x.size(0))\n\n        if batch_method == \"robust\":\n            if bound_type == \"IBP\":\n                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)\n            elif bound_type == \"CROWN\":\n                lb, ub = model.compute_bounds(IBP=False, C=c, method=\"backward\", bound_upper=False)\n            elif bound_type == \"CROWN-IBP\":\n        
        # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method=\"backward\")  # pure IBP bound\n                # We use a mixture of IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020).\n                factor = (eps_scheduler.get_max_eps() - eps) / eps_scheduler.get_max_eps()\n                ilb, iub = model.compute_bounds(IBP=True, C=c, method=None)\n                if factor < 1e-5:\n                    lb = ilb\n                else:\n                    clb, cub = model.compute_bounds(IBP=False, C=c, method=\"backward\", bound_upper=False)\n                    lb = clb * factor + ilb * (1 - factor)\n            elif bound_type == \"CROWN-FAST\":\n                # Similar to CROWN-IBP, but without mixing the IBP and CROWN bounds.\n                # The first (IBP) pass computes intermediate layer bounds reused by the backward pass.\n                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)\n                lb, ub = model.compute_bounds(IBP=False, C=c, method=\"backward\", bound_upper=False)\n\n\n            # Pad a zero at the beginning for each example, and use the fake label \"0\" for all examples.\n            lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)\n        if batch_method == \"robust\":\n            loss = robust_ce\n        elif batch_method == \"natural\":\n            loss = regular_ce\n        if train:\n            loss.backward()\n            eps_scheduler.update_loss(loss.item() - regular_ce.item())\n            opt.step()\n        meter.update('Loss', loss.item(), data.size(0))\n        if batch_method != \"natural\":\n            meter.update('Robust_CE', robust_ce.item(), data.size(0))\n            # If the lower bounds of the margins are > 0 for all classes, the output for this example is verifiably correct.\n            # If any margin is < 0, this example is counted as a verified error.\n            meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))\n        meter.update('Time', time.time() - start)\n        if i % 50 == 0 and train:\n            print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))\n    print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))\n\ndef main(args):\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py\n    if args.data == 'MNIST':\n        model_ori = models.Models[args.model](in_ch=1, in_dim=28)\n    else:\n        model_ori = models.Models[args.model](in_ch=3, in_dim=32)\n    if args.load:\n        state_dict = torch.load(args.load)['state_dict']\n        model_ori.load_state_dict(state_dict)\n\n    ## Step 2: Prepare dataset as usual\n    if args.data == 'MNIST':\n        dummy_input = torch.randn(2, 1, 28, 28)\n        train_data = datasets.MNIST(\"./data\", train=True, download=True, transform=transforms.ToTensor())\n        test_data = datasets.MNIST(\"./data\", train=False, download=True, transform=transforms.ToTensor())\n    elif args.data == 'CIFAR':\n        dummy_input = torch.randn(2, 3, 32, 32)\n        normalize = transforms.Normalize(mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010])\n        train_data = datasets.CIFAR10(\"./data\", train=True, download=True,\n                transform=transforms.Compose([\n                    transforms.RandomHorizontalFlip(),\n                    transforms.RandomCrop(32, 4),\n                    transforms.ToTensor(),\n                    normalize]))\n        test_data = datasets.CIFAR10(\"./data\", train=False, download=True, \n                transform=transforms.Compose([transforms.ToTensor(), normalize]))\n\n    train_data = torch.utils.data.DataLoader(train_data, 
batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),4))\n    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, pin_memory=True, num_workers=min(multiprocessing.cpu_count(),4))\n    if args.data == 'MNIST':\n        train_data.mean = test_data.mean = torch.tensor([0.0])\n        train_data.std = test_data.std = torch.tensor([1.0])\n    elif args.data == 'CIFAR':\n        train_data.mean = test_data.mean = torch.tensor([0.4914, 0.4822, 0.4465])\n        train_data.std = test_data.std = torch.tensor([0.2023, 0.1994, 0.2010])\n\n    ## Step 3: wrap model with auto_LiRPA\n    # The second parameter dummy_input is for constructing the trace of the computational graph.\n    model = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts, 'conv_mode': args.conv_mode}, device=args.device)\n\n    ## Step 4: prepare optimizer, epsilon scheduler and learning rate scheduler\n    opt = optim.Adam(model.parameters(), lr=args.lr)\n    norm = float(args.norm)\n    lr_scheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)\n    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)\n    print(\"Model structure: \\n\", str(model_ori))\n\n    ## Step 5: start training\n    if args.verify:\n        eps_scheduler = FixedScheduler(args.eps)\n        with torch.no_grad():\n            Train(model, 1, test_data, eps_scheduler, norm, False, None, args.bound_type)\n    else:\n        timer = 0.0\n        for t in range(1, args.num_epochs+1):\n            if eps_scheduler.reached_max_eps():\n                # Only decay learning rate after reaching the maximum eps\n                lr_scheduler.step()\n            print(\"Epoch {}, learning rate {}\".format(t, lr_scheduler.get_lr()))\n            start_time = time.time()\n            Train(model, t, train_data, eps_scheduler, norm, True, opt, args.bound_type)\n            epoch_time = time.time() - 
start_time\n            timer += epoch_time\n            print('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))\n            print(\"Evaluating...\")\n            with torch.no_grad():\n                Train(model, t, test_data, eps_scheduler, norm, False, None, args.bound_type)\n            torch.save({'state_dict': model_ori.state_dict(), 'epoch': t}, args.save_model if args.save_model != \"\" else args.model)\n\n\nif __name__ == \"__main__\":\n    main(args)\n"
  },
  {
    "path": "examples/vision/simple_verification.py",
    "content": "\"\"\"\nA simple example for bounding neural network outputs under input perturbations.\n\nThis example serves as a skeleton for robustness verification of neural networks.\n\"\"\"\nimport os\nfrom collections import defaultdict\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import Flatten\n\n## Step 1: Define computational graph by implementing forward()\n# This simple model comes from https://github.com/locuslab/convex_adversarial\ndef mnist_model():\n    model = nn.Sequential(\n        nn.Conv2d(1, 16, 4, stride=2, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(16, 32, 4, stride=2, padding=1),\n        nn.ReLU(),\n        Flatten(),\n        nn.Linear(32*7*7,100),\n        nn.ReLU(),\n        nn.Linear(100, 10)\n    )\n    return model\n\nmodel = mnist_model()\n# Optionally, load the pretrained weights.\ncheckpoint = torch.load(\n    os.path.join(os.path.dirname(__file__), 'pretrained/mnist_a_adv.pth'),\n    map_location=torch.device('cpu'))\nmodel.load_state_dict(checkpoint)\n\n## Step 2: Prepare dataset as usual\ntest_data = torchvision.datasets.MNIST(\n    './data', train=False, download=True,\n    transform=torchvision.transforms.ToTensor())\n# For illustration we only use 2 image from dataset\nN = 2\nn_classes = 10\nimage = test_data.data[:N].view(N,1,28,28)\ntrue_label = test_data.targets[:N]\n# Convert to float\nimage = image.to(torch.float32) / 255.0\nif torch.cuda.is_available():\n    image = image.cuda()\n    model = model.cuda()\n\n## Step 3: wrap model with auto_LiRPA\n# The second parameter is for constructing the trace of the computational graph,\n# and its content is not important.\nlirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)\nprint('Running on', image.device)\n# Visualize the lirpa_model\n# Visualization file is saved as \"bounded_mnist_model.png\" 
or \"bounded_mnist_model.dot\"\nlirpa_model.visualize(\"bounded_mnist_model\")\nprint()\n\n## Step 4: Compute bounds using LiRPA given a perturbation\neps = 0.3\nnorm = float(\"inf\")\nptb = PerturbationLpNorm(norm = norm, eps = eps)\nimage = BoundedTensor(image, ptb)\n# Get model prediction as usual\npred = lirpa_model(image)\nlabel = torch.argmax(pred, dim=1).cpu().detach().numpy()\nprint('Demonstration 1: Bound computation and comparisons of different methods.\\n')\n\n## Step 5: Compute bounds for final output\nfor method in [\n        'IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)',\n        'CROWN-Optimized (alpha-CROWN)']:\n    print('Bounding method:', method)\n    if 'Optimized' in method:\n        # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values.\n        lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}})\n    lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0])\n    for i in range(N):\n        print(f'Image {i} top-1 prediction {label[i]} ground-truth {true_label[i]}')\n        for j in range(n_classes):\n            indicator = '(ground-truth)' if j == true_label[i] else ''\n            print('f_{j}(x_0): {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f} {ind}'.format(\n                j=j, l=lb[i][j].item(), u=ub[i][j].item(), ind=indicator))\n    print()\n\n\nprint('Demonstration 2: Obtaining linear coefficients of the lower and upper bounds.\\n')\n# There are many bound coefficients during CROWN bound calculation; here we are interested in the linear bounds\n# of the output layer, with respect to the input layer (the image).\nrequired_A = defaultdict(set)\nrequired_A[lirpa_model.output_name[0]].add(lirpa_model.input_name[0])\n\n# Helper functions to concretize the linear bounds\ndef concretize_bound(A, bias, xL, xU, upper: bool):\n    \"\"\"\n    Concretize linear bound.\n    If upper is True: 
use A_pos * xU + A_neg * xL + bias\n    If upper is False: use A_pos * xL + A_neg * xU + bias\n    \"\"\"\n    A_pos = torch.clamp(A, min=0.0)\n    A_neg = torch.clamp(A, max=0.0)\n    if upper:\n        return (\n            torch.einsum(\"boijk,boijk->bo\", A_pos, xU)\n            + torch.einsum(\"boijk,boijk->bo\", A_neg, xL)\n            + bias\n        )\n    else:\n        return (\n            torch.einsum(\"boijk,boijk->bo\", A_pos, xL)\n            + torch.einsum(\"boijk,boijk->bo\", A_neg, xU)\n            + bias\n        )\n\n# Prepare input bounds\nx_L = (image - eps).unsqueeze(1)\nx_U = (image + eps).unsqueeze(1)\n\nfor method in [\n        'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN',\n        'CROWN-Optimized (alpha-CROWN)']:\n    print(\"Bounding method:\", method)\n    if 'Optimized' in method:\n        # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values.\n        lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 30, 'lr_alpha': 0.1}})\n    lb, ub, A_dict = lirpa_model.compute_bounds(x=(image,), method=method.split()[0], return_A=True, needed_A_dict=required_A)\n    lower_A, lower_bias = A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['lA'], A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['lbias']\n    upper_A, upper_bias = A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['uA'], A_dict[lirpa_model.output_name[0]][lirpa_model.input_name[0]]['ubias']\n    print(f'lower bound linear coefficients size (batch, output_dim, *input_dims): {list(lower_A.size())}')\n    print(f'lower bound linear coefficients norm (smaller is better): {lower_A.norm()}')\n    print(f'lower bound bias term size (batch, output_dim): {list(lower_bias.size())}')\n    print(f'lower bound bias term sum (larger is better): {lower_bias.sum()}')\n    print(f'upper bound linear coefficients size (batch, output_dim, 
*input_dims): {list(upper_A.size())}')\n    print(f'upper bound linear coefficients norm (smaller is better): {upper_A.norm()}')\n    print(f'upper bound bias term size (batch, output_dim): {list(upper_bias.size())}')\n    print(f'upper bound bias term sum (smaller is better): {upper_bias.sum()}')\n    print(f'These linear lower and upper bounds are valid everywhere within the perturbation radii.\\n')\n\n    # Validate the concretization of the linear bounds\n    concretized_lb = concretize_bound(lower_A, lower_bias, x_L, x_U, upper=False)\n    concretized_ub = concretize_bound(upper_A, upper_bias, x_L, x_U, upper=True)\n    assert torch.allclose(\n        concretized_lb, lb, rtol=1e-4, atol=1e-5), \"Lower bound mismatch! Error: {}\".format((concretized_lb - lb).abs().max())\n    assert torch.allclose(\n        concretized_ub, ub, rtol=1e-4, atol=1e-5), \"Upper bound mismatch! Error: {}\".format((concretized_ub - ub).abs().max())\n\n\n## An example for computing margin bounds.\n# In compute_bounds() function you can pass in a specification matrix C, which is a final linear matrix applied to the last layer NN output.\n# For example, if you are interested in the margin between the groundtruth class and another class, you can use C to specify the margin.\n# This generally yields tighter bounds.\n# Here we compute the margin between groundtruth class and groundtruth class + 1.\n# If you have more than 1 specifications per batch element, you can expand the second dimension of C (it is 1 here for demonstration).\nlirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device)\nC = torch.zeros(size=(N, 1, n_classes), device=image.device)\ngroundtruth = true_label.to(device=image.device).unsqueeze(1).unsqueeze(1)\ntarget_label = (groundtruth + 1) % n_classes\nC.scatter_(dim=2, index=groundtruth, value=1.0)\nC.scatter_(dim=2, index=target_label, value=-1.0)\nprint('Demonstration 3: Computing bounds with a specification matrix.\\n')\nprint('Specification 
matrix:\\n', C)\n\nfor method in ['IBP', 'IBP+backward (CROWN-IBP)', 'backward (CROWN)', 'CROWN-Optimized (alpha-CROWN)']:\n    print('Bounding method:', method)\n    if 'Optimized' in method:\n        # For optimized bound, you can change the number of iterations, learning rate, etc here. Also you can increase verbosity to see per-iteration loss values.\n        lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1, }})\n    lb, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0], C=C)\n    for i in range(N):\n        print('Image {} top-1 prediction {} ground-truth {}'.format(i, label[i], true_label[i]))\n        print('margin bounds: {l:8.3f} <= f_{j}(x_0+delta) - f_{target}(x_0+delta) <= {u:8.3f}'.format(\n            j=true_label[i], target=(true_label[i] + 1) % n_classes, l=lb[i][0].item(), u=ub[i][0].item()))\n    print()\n"
  },
  {
    "path": "examples/vision/tinyimagenet_training.py",
    "content": "import os\nimport random\nimport time\nimport argparse\nimport multiprocessing\nimport logging\nimport torch.optim as optim\nfrom torch.nn import CrossEntropyLoss\nfrom auto_LiRPA import BoundedModule, BoundedTensor, BoundDataParallel, CrossEntropyWrapper\nfrom auto_LiRPA.bound_ops import BoundExp\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import MultiAverageMeter, logger, get_spec_matrix, sync_params\nimport models\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nfrom auto_LiRPA.eps_scheduler import *\n\ndef get_exp_module(bounded_module):\n    for _, node in bounded_module.named_modules():\n        # Find the Exp neuron in computational graph\n        if isinstance(node, BoundExp):\n            return node\n    return None\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--verify\", action=\"store_true\", help='verification mode, do not train')\nparser.add_argument(\"--load\", type=str, default=\"\", help='Load pretrained model')\nparser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cpu\", \"cuda\"], help='use cpu or cuda')\nparser.add_argument(\"--data_dir\", type=str, default=\"data/tinyImageNet/tiny-imagenet-200\",\n                    help='dir of dataset')\nparser.add_argument(\"--seed\", type=int, default=100, help='random seed')\nparser.add_argument(\"--eps\", type=float, default=1. 
/ 255, help='Target training epsilon')\nparser.add_argument(\"--norm\", type=float, default='inf', help='p norm for epsilon perturbation')\nparser.add_argument(\"--bound_type\", type=str, default=\"CROWN-IBP\",\n                    choices=[\"IBP\", \"CROWN-IBP\", \"CROWN\"], help='method of bound analysis')\nparser.add_argument(\"--model\", type=str, default=\"wide_resnet_imagenet64\",\n                    help='model name (cnn_7layer_bn_imagenet, ResNeXt_imagenet64, wide_resnet_imagenet64)')\nparser.add_argument(\"--num_epochs\", type=int, default=600, help='number of total epochs')\nparser.add_argument(\"--batch_size\", type=int, default=128, help='batch size')\nparser.add_argument(\"--lr\", type=float, default=5e-4, help='learning rate')\nparser.add_argument(\"--lr_decay_milestones\", nargs='+', type=int, default=[600, 700], help='learning rate decay milestones')\nparser.add_argument(\"--scheduler_name\", type=str, default=\"SmoothedScheduler\",\n                    choices=[\"LinearScheduler\", \"AdaptiveScheduler\", \"SmoothedScheduler\"], help='epsilon scheduler')\nparser.add_argument(\"--scheduler_opts\", type=str, default=\"start=100,length=400,mid=0.4\", help='options for epsilon scheduler')\nparser.add_argument(\"--bound_opts\", type=str, default=None, choices=[\"same-slope\", \"zero-lb\", \"one-lb\"],\n                    help='bound options')\nparser.add_argument('--clip_grad_norm', type=float, default=8.0)\nparser.add_argument('--in_planes', type=int, default=16)\nparser.add_argument('--widen_factor', type=int, default=10)\n\nargs = parser.parse_args()\n\nexp_name = args.model + '_b' + str(args.batch_size) + '_' + str(args.bound_type) + '_epoch' + str(\n    args.num_epochs) + '_' + args.scheduler_opts + '_ImageNet_' + str(args.eps)[:6]\nos.makedirs('saved_models/', exist_ok=True)\nlog_file = f'saved_models/{exp_name}{\"_test\" if args.verify else \"\"}.log'\nfile_handler = logging.FileHandler(log_file)\nlogger.addHandler(file_handler)\n\ndef Train(model, 
t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True,\n          final_node_name=None):\n    num_class = 200\n    meter = MultiAverageMeter()\n    if train:\n        model.train()\n        eps_scheduler.train()\n        eps_scheduler.step_epoch()\n        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))\n    else:\n        model.eval()\n        eps_scheduler.eval()\n\n    exp_module = get_exp_module(model)\n\n    def get_bound_loss(x=None, c=None):\n        if loss_fusion:\n            bound_lower, bound_upper = False, True\n        else:\n            bound_lower, bound_upper = True, False\n\n        if bound_type == 'IBP':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, IBP=True, C=c, method=None,\n                           final_node_name=final_node_name, no_replicas=True)\n        elif bound_type == 'CROWN':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, IBP=False, C=c, method='backward',\n                           bound_lower=bound_lower, bound_upper=bound_upper)\n        elif bound_type == 'CROWN-IBP':\n            # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method=None)  # pure IBP bound\n            # We use a mix of IBP and CROWN bounds (CROWN-IBP), leading to better performance (Zhang et al., ICLR 2020).\n            factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()\n            ilb, iub = model(method_opt=\"compute_bounds\", x=x, IBP=True, C=c, method=None,\n                             final_node_name=final_node_name, no_replicas=True)\n            if factor < 1e-50:\n                lb, ub = ilb, iub\n            else:\n                clb, cub = model(method_opt=\"compute_bounds\", IBP=False, C=c, method='backward',\n                                 bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name,\n                                
no_replicas=True)\n                if loss_fusion:\n                    ub = cub * factor + iub * (1 - factor)\n                else:\n                    lb = clb * factor + ilb * (1 - factor)\n\n        if loss_fusion:\n            if isinstance(model, BoundDataParallel):\n                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')\n            else:\n                max_input = exp_module.max_input\n            return None, torch.mean(torch.log(ub) + max_input)\n        else:\n            # Pad zero at the beginning for each example, and use fake label '0' for all examples\n            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)\n            return lb, robust_ce\n\n    for i, (data, labels) in enumerate(loader):\n        start = time.time()\n        eps_scheduler.step_batch()\n        eps = eps_scheduler.get_eps()\n        # For small eps just use natural training, no need to compute LiRPA bounds\n        batch_method = method\n        if eps < 1e-50:\n            batch_method = \"natural\"\n        if train:\n            opt.zero_grad()\n        # bound input for Linf norm used only\n        if norm == np.inf:\n            data_max = torch.reshape((1. - loader.mean) / loader.std, (1, -1, 1, 1))\n            data_min = torch.reshape((0. 
- loader.mean) / loader.std, (1, -1, 1, 1))\n            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1), data_max)\n            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1), data_min)\n        else:\n            data_ub = data_lb = data\n\n        if list(model.parameters())[0].is_cuda:\n            data, labels = data.cuda(), labels.cuda()\n            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()\n\n        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)\n        x = BoundedTensor(data, ptb)\n        if loss_fusion:\n            if batch_method == 'natural' or not train:\n                output = model(x, labels)\n                regular_ce = torch.mean(torch.log(output))\n            else:\n                model(x, labels)\n                regular_ce = torch.tensor(0., device=data.device)\n            meter.update('CE', regular_ce.item(), x.size(0))\n            x = (x, labels)\n            c = None\n        else:\n            c = get_spec_matrix(data, labels, num_class)\n            x = (x, labels)\n            output = model(x, final_node_name=final_node_name)\n            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up\n            meter.update('CE', regular_ce.item(), x[0].size(0))\n            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))\n\n        if batch_method == 'robust':\n            # print(data.sum())\n            lb, robust_ce = get_bound_loss(x=x, c=c)\n            loss = robust_ce\n        elif batch_method == 'natural':\n            loss = regular_ce\n\n        if train:\n            loss.backward()\n\n            if args.clip_grad_norm:\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)\n                meter.update('grad_norm', grad_norm)\n\n            if isinstance(eps_scheduler, AdaptiveScheduler):\n               
 eps_scheduler.update_loss(loss.item() - regular_ce.item())\n            opt.step()\n        meter.update('Loss', loss.item(), data.size(0))\n\n        if batch_method != 'natural':\n            meter.update('Robust_CE', robust_ce.item(), data.size(0))\n            if not loss_fusion:\n                # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.\n                # If any margin is < 0, this example is counted as an error.\n                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))\n        meter.update('Time', time.time() - start)\n\n        if (i + 1) % 250 == 0 and train:\n            logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n\n    logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n    return meter\n\n\ndef main(args):\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    ## Step 1: Initialize the original model as usual; see model details in models/example_feedforward.py and models/example_resnet.py\n    model_ori = models.Models[args.model](in_planes=args.in_planes, widen_factor=args.widen_factor)\n    epoch = 0\n    if args.load:\n        checkpoint = torch.load(args.load)\n        epoch, state_dict, opt_state = checkpoint['epoch'], checkpoint['state_dict'], checkpoint.get('optimizer')\n        for k, v in state_dict.items():\n            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0\n        model_ori.load_state_dict(state_dict)\n        logger.info('Checkpoint loaded: {}'.format(args.load))\n\n    ## Step 2: Prepare dataset as usual\n    dummy_input = torch.randn(2, 3, 56, 56)\n    normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262])\n    train_data = datasets.ImageFolder(args.data_dir + '/train',\n                                 
     transform=transforms.Compose([\n                                          transforms.RandomHorizontalFlip(),\n                                          transforms.RandomCrop(56, padding_mode='edge'),\n                                          transforms.ToTensor(),\n                                          normalize,\n                                      ]))\n    test_data = datasets.ImageFolder(args.data_dir + '/val',\n                                     transform=transforms.Compose([\n                                         # transforms.RandomResizedCrop(64, scale=(0.875, 0.875), ratio=(1., 1.)),\n                                         transforms.CenterCrop(56),\n                                         transforms.ToTensor(),\n                                         normalize]))\n\n    train_data = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True,\n                                             num_workers=min(multiprocessing.cpu_count(), 4))\n    test_data = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size // 5, pin_memory=True,\n                                            num_workers=min(multiprocessing.cpu_count(), 4))\n    train_data.mean = test_data.mean = torch.tensor([0.4802, 0.4481, 0.3975])\n    train_data.std = test_data.std = torch.tensor([0.2302, 0.2265, 0.2262])\n\n    ## Step 3: wrap model with auto_LiRPA\n    # The second parameter dummy_input is for constructing the trace of the computational graph.\n    model = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts}, device=args.device)\n    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),\n                               bound_opts= { 'activation_bound_option': args.bound_opts, 'loss_fusion': True }, device=args.device)\n    model_loss = BoundDataParallel(model_loss)\n\n    ## Step 4 prepare optimizer, epsilon scheduler and 
learning rate scheduler\n    opt = optim.Adam(model_loss.parameters(), lr=args.lr)\n    norm = float(args.norm)\n    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)\n    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)\n    logger.info(str(model_ori))\n\n    if args.load:\n        if opt_state:\n            opt.load_state_dict(opt_state)\n            logger.info('resume opt_state')\n\n    # skip epochs\n    if epoch > 0:\n        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)\n        eps_scheduler.set_epoch_length(epoch_length)\n        eps_scheduler.train()\n        for i in range(epoch):\n            lr_scheduler.step()\n            eps_scheduler.step_epoch(verbose=True)\n            for j in range(epoch_length):\n                eps_scheduler.step_batch()\n        logger.info('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))\n\n    ## Step 5: start training\n    if args.verify:\n        eps_scheduler = FixedScheduler(args.eps)\n        with torch.no_grad():\n            Train(model, 1, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False, final_node_name=None)\n    else:\n        timer = 0.0\n        best_err = 1e10\n        for t in range(epoch + 1, args.num_epochs + 1):\n            logger.info(\"Epoch {}, learning rate {}\".format(t, lr_scheduler.get_last_lr()))\n            start_time = time.time()\n            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)\n            lr_scheduler.step()\n            epoch_time = time.time() - start_time\n            timer += epoch_time\n            logger.info('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))\n\n            logger.info(\"Evaluating...\")\n            torch.cuda.empty_cache()\n\n            state_dict = sync_params(model_ori, model_loss, loss_fusion=True)\n\n            with 
torch.no_grad():\n                if int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']) > t >= int(\n                        eps_scheduler.params['start']):\n                    m = Train(model_loss, t, test_data, eps_scheduler, norm, False, None, args.bound_type, loss_fusion=True)\n                else:\n                    model_ori.load_state_dict(state_dict)\n                    model = BoundedModule(model_ori, dummy_input, bound_opts={'activation_bound_option':args.bound_opts}, device=args.device)\n                    model = BoundDataParallel(model)\n                    m = Train(model, t, test_data, eps_scheduler, norm, False, None, 'IBP', loss_fusion=False)\n                    del model\n\n            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}\n            if t < int(eps_scheduler.params['start']):\n                torch.save(save_dict, 'saved_models/natural_' + exp_name)\n            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):\n                current_err = m.avg('Verified_Err')\n                if current_err < best_err:\n                    best_err = current_err\n                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_err)[:6])\n            else:\n                torch.save(save_dict, 'saved_models/' + exp_name)\n            torch.cuda.empty_cache()\n\n\nif __name__ == \"__main__\":\n    logger.info(args)\n    main(args)\n"
  },
  {
    "path": "examples/vision/verify_two_node.py",
    "content": "\"\"\"\nExample for multi-node perturbation. An input image is splited to two parts\nwhere each part is perturbed respectively constained by L-inf norm. It is\nexpected to output the same results as running `simple_verification.py` where\nthe whole image is perturbed constained by L-inf norm.\n\"\"\"\n\nimport os\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\n\n## Step 1: Define computational graph by implementing forward()\nclass cnn_MNIST(nn.Module):\n    def __init__(self):\n        super(cnn_MNIST, self).__init__()\n        self.conv1 = nn.Conv2d(1, 8, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(8, 16, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(784, 256)\n        self.fc2 = nn.Linear(256, 10)\n\n    def forward(self, x, y):\n        x = torch.cat([x, y], dim=2) # concat the two parts of input\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        x = x.view(-1, 784)\n        x = F.relu(self.fc1(x))\n        x = self.fc2(x)\n        return x\n\nmodel = cnn_MNIST()\n# Load the pretrained weights\ncheckpoint = torch.load(os.path.join(os.path.dirname(__file__),\"pretrained/mnist_cnn_small.pth\"),\n                        map_location=torch.device('cpu'))\nmodel.load_state_dict(checkpoint)\n\n## Step 2: Prepare dataset as usual\ntest_data = torchvision.datasets.MNIST(\n    \"./data\", train=False, download=True, transform=torchvision.transforms.ToTensor())\n# For illustration we only use 2 image from dataset\nN = 2\nn_classes = 10\nimage = test_data.data[:N].view(N,1,28,28)\n# Convert to float\nimage = image.to(torch.float32) / 255.0\nif torch.cuda.is_available():\n    image = image.cuda()\n    model = model.cuda()\n\n## Step 3: wrap model with auto_LiRPA\n# The second parameter is for constructing the trace of the computational graph,\n# and its content is not important.\nimage_1, 
image_2 = torch.split(torch.empty_like(image), [14, 14], dim=2)\nmodel = BoundedModule(\n    model, (image_1, image_2), device=image.device,\n    bound_opts={'conv_mode': 'matrix'} # Patches mode is not supported currently\n)\n\n## Step 4: Compute bounds using LiRPA given a perturbation\neps = 0.3\nnorm = np.inf\nptb = PerturbationLpNorm(norm=norm, eps=eps)\nimage_1, image_2 = torch.split(image, [14, 14], dim=2)\nimage_1 = BoundedTensor(image_1, ptb)\nimage_2 = BoundedTensor(image_2, ptb)\n# Get model prediction as usual\npred = model(image_1, image_2)\nlabel = torch.argmax(pred, dim=1).cpu().numpy()\n# Compute bounds\nlb, ub = model.compute_bounds()\n\n## Step 5: Final output\npred = pred.detach().cpu().numpy()\nlb = lb.detach().cpu().numpy()\nub = ub.detach().cpu().numpy()\nfor i in range(N):\n    print(\"Image {} top-1 prediction {}\".format(i, label[i]))\n    for j in range(n_classes):\n        print(\"f_{j}(x_0) = {fx0:8.3f},   {l:8.3f} <= f_{j}(x_0+delta) <= {u:8.3f}\".format(\n            j=j, fx0=pred[i][j], l=lb[i][j], u=ub[i][j]))\n    print()\n"
  },
  {
    "path": "examples/vision/weight_perturbation_training.py",
    "content": "\"\"\"\nA simple example for certified robustness against model weight perturbations.\n\nSince our framework works on general computational graphs, where both model\nweights and model inputs are inputs of the computational graph, our\nperturbation analysis can naturally be applied to the model weights, allowing\nanalysis for certified model robustness under weight perturbations. This file\nprovides a simple example of certified defense for model weight perturbations.\n\nSee our paper https://arxiv.org/abs/2002.12920 for more details.\n\"\"\"\nimport random\nimport time\nimport os\nimport argparse\nimport logging\nimport torch.optim as optim\nfrom torch.nn import CrossEntropyLoss\nfrom auto_LiRPA import BoundedModule, CrossEntropyWrapper, BoundDataParallel, BoundedParameter\nfrom auto_LiRPA.bound_ops import BoundExp\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import MultiAverageMeter, logger, get_spec_matrix\nfrom datasets import mnist_loaders\nimport torchvision.datasets as datasets\nimport models\nfrom auto_LiRPA.eps_scheduler import LinearScheduler, AdaptiveScheduler, SmoothedScheduler, FixedScheduler\n\ndef get_exp_module(bounded_module):\n    for _, node in bounded_module.named_modules():\n        # Find the Exp neuron in computational graph\n        if isinstance(node, BoundExp):\n            return node\n    return None\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--verify\", action=\"store_true\", help='verification mode, do not train')\nparser.add_argument(\"--load\", type=str, default=\"\", help='Load pretrained model')\nparser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cpu\", \"cuda\"], help='use cpu or cuda')\nparser.add_argument(\"--data\", type=str, default=\"MNIST\", choices=[\"MNIST\", \"FashionMNIST\"], help='dataset')\nparser.add_argument(\"--ratio\", type=float, default=None, help='percent of training used, None means whole training data')\nparser.add_argument(\"--seed\", 
type=int, default=100, help='random seed')\nparser.add_argument(\"--eps\", type=float, default=0.1, help='Target training epsilon for weight perturbations')\nparser.add_argument(\"--norm\", type=float, default='inf', help='p norm for epsilon perturbation')\nparser.add_argument(\"--bound_type\", type=str, default=\"CROWN-IBP\",\n                    choices=[\"IBP\", \"CROWN-IBP\", \"CROWN\"], help='method of bound analysis')\nparser.add_argument(\"--opt\", type=str, default='ADAM', choices=[\"ADAM\", \"SGD\"], help='optimizer')\nparser.add_argument(\"--num_epochs\", type=int, default=150, help='number of total epochs')\nparser.add_argument(\"--batch_size\", type=int, default=256, help='batch size')\nparser.add_argument(\"--lr\", type=float, default=0.001, help='learning rate')\nparser.add_argument(\"--lr_decay_milestones\", nargs='+', type=int, default=[120, 140], help='learning rate decay milestones')\nparser.add_argument(\"--scheduler_name\", type=str, default=\"LinearScheduler\",\n                    choices=[\"LinearScheduler\", \"AdaptiveScheduler\", \"SmoothedScheduler\"], help='epsilon scheduler')\nparser.add_argument(\"--scheduler_opts\", type=str, default=\"start=10,length=100\", help='options for epsilon scheduler')\nparser.add_argument(\"--bound_opts\", type=str, default=None, choices=[\"same-slope\", \"zero-lb\", \"one-lb\"],\n                    help='bound options')\nparser.add_argument('--clip_grad_norm', type=float, default=8.0)\nparser.add_argument('--truncate_data', type=int, help='Truncate the training/test batches in unit test')\nparser.add_argument('--multigpu', action='store_true', help='MultiGPU training')\n\nnum_class = 10\nargs = parser.parse_args()\nexp_name = 'mlp_MNIST'+'_b'+str(args.batch_size)+'_'+str(args.bound_type)+'_epoch'+str(args.num_epochs)+'_'+args.scheduler_opts+'_'+str(args.eps)[:6]\nlog_file = f'{exp_name}{\"_test\" if args.verify else \"\"}.log'\nfile_handler = logging.FileHandler(log_file)\nlogger.addHandler(file_handler) 
\n\n## Training one epoch.\ndef Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None):\n    meter = MultiAverageMeter()\n    if train:\n        model.train()\n        eps_scheduler.train()\n        eps_scheduler.step_epoch(verbose=False)\n        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))\n    else:\n        model.eval()\n        eps_scheduler.eval()\n    \n    # Used for loss fusion. Get the exp operation in the computational graph.\n    exp_module = get_exp_module(model)\n\n    def get_bound_loss(x=None, c=None):\n        if loss_fusion:\n            # When loss fusion is used, we need the upper bound for the final loss function.\n            bound_lower, bound_upper = False, True\n        else:\n            # When loss fusion is not used, we need the lower bound for the logit layer.\n            bound_lower, bound_upper = True, False\n\n        if bound_type == 'IBP':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, C=c, method=\"IBP\", final_node_name=final_node_name, no_replicas=True)\n        elif bound_type == 'CROWN':\n            lb, ub = model(method_opt=\"compute_bounds\", x=x, C=c, method=\"backward\",\n                                          bound_lower=bound_lower, bound_upper=bound_upper)\n        elif bound_type == 'CROWN-IBP':\n            # We use mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)\n            # factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()\n            ilb, iub = model(method_opt=\"compute_bounds\", x=x, C=c, method=\"IBP\", final_node_name=final_node_name, no_replicas=True)\n            lb, ub = model(method_opt=\"compute_bounds\", C=c, method=\"CROWN-IBP\",\n                         bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name, average_A=True, no_replicas=True)\n 
       if loss_fusion:\n            # When loss fusion is enabled, we need to get the common factor before softmax.\n            if isinstance(model, BoundDataParallel):\n                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')\n            else:\n                max_input = exp_module.max_input\n            return None, torch.mean(torch.log(ub) + max_input)\n        else:\n            # Pad zero at the beginning for each example, and use fake label '0' for all examples\n            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)\n            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)\n            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)\n            return lb, robust_ce\n\n    for i, (data, labels) in enumerate(loader):\n        # For unit test. We only use a small number of batches\n        if args.truncate_data:\n            if i >= args.truncate_data:\n                break\n\n        start = time.time()\n        eps_scheduler.step_batch()\n        eps = eps_scheduler.get_eps()\n        # For small eps just use natural training, no need to compute LiRPA bounds\n        batch_method = method\n        if eps < 1e-50:\n            batch_method = \"natural\"\n        if train:\n            opt.zero_grad()\n\n        if list(model.parameters())[0].is_cuda:\n            data, labels = data.cuda(), labels.cuda()\n\n        model.ptb.eps = eps\n        x = data\n        if loss_fusion:\n            if batch_method == 'natural' or not train:\n                output = model(x, labels)  # , disable_multi_gpu=True\n                regular_ce = torch.mean(torch.log(output))\n            else:\n                model(x, labels)\n                regular_ce = torch.tensor(0., device=data.device)\n            meter.update('CE', regular_ce.item(), x.size(0))\n            x = (x, labels)\n            c = None\n        else:\n         
   # Generate specification matrix (when loss fusion is not used).\n            c = get_spec_matrix(data, labels, num_class)\n            x = (x, labels)\n            output = model(x, final_node_name=final_node_name)\n            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up\n            meter.update('CE', regular_ce.item(), x[0].size(0))\n            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))\n\n        if batch_method == 'robust':\n            lb, robust_ce = get_bound_loss(x=x, c=c)\n            loss = robust_ce\n        elif batch_method == 'natural':\n            loss = regular_ce\n\n        if train:\n            loss.backward()\n\n            if args.clip_grad_norm:\n                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)\n                meter.update('grad_norm', grad_norm)\n\n            if isinstance(eps_scheduler, AdaptiveScheduler):\n                eps_scheduler.update_loss(loss.item() - regular_ce.item())\n            opt.step()\n        meter.update('Loss', loss.item(), data.size(0))\n\n        if batch_method != 'natural':\n            meter.update('Robust_CE', robust_ce.item(), data.size(0))\n            if not loss_fusion:\n                # For example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.\n                # If any margin is < 0, this example is counted as an error.\n                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))\n        meter.update('Time', time.time() - start)\n\n        if (i + 1) % 50 == 0 and train:\n            logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n\n    logger.info('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))\n    return meter\n\n\ndef main(args):\n    torch.manual_seed(args.seed)\n    
torch.cuda.manual_seed_all(args.seed)\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n\n    ## Load the model with BoundedParameter for weight perturbation.\n    model_ori = models.Models['mlp_3layer_weight_perturb']()\n\n    epoch = 0\n    ## Load a checkpoint, if requested.\n    if args.load:\n        checkpoint = torch.load(args.load)\n        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']\n        opt_state = None\n        try:\n            opt_state = checkpoint['optimizer']\n        except KeyError:\n            print('no opt_state found')\n        for k, v in state_dict.items():\n            assert torch.isnan(v).any().cpu().numpy() == 0 and torch.isinf(v).any().cpu().numpy() == 0\n        model_ori.load_state_dict(state_dict)\n        logger.info('Checkpoint loaded: {}'.format(args.load))\n\n    ## Step 2: Prepare dataset as usual\n    dummy_input = torch.randn(2, 1, 28, 28)\n    train_data,  test_data = mnist_loaders(datasets.MNIST, batch_size=args.batch_size, ratio=args.ratio)\n    train_data.mean = test_data.mean = torch.tensor([0.0])\n    train_data.std = test_data.std = torch.tensor([1.0])\n\n    ## Step 3: wrap model with auto_LiRPA\n    # The second parameter dummy_input is for constructing the trace of the computational graph.\n    model = BoundedModule(model_ori, dummy_input, device=args.device, bound_opts={\n        'activation_bound_option':args.bound_opts, 'sparse_intermediate_bounds': False,\n        'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False})\n    final_name1 = model.final_name\n    model_loss = BoundedModule(CrossEntropyWrapper(model_ori), (dummy_input, torch.zeros(1, dtype=torch.long)),\n            device=args.device, bound_opts= {'activation_bound_option': args.bound_opts, 'loss_fusion': True,\n                                             'sparse_intermediate_bounds': False,\n                                             'sparse_conv_intermediate_bounds': False,\n   
                                          'sparse_intermediate_bounds_with_ibp': False})\n\n    # After CrossEntropyWrapper, the final node name will change because of one more input node in CrossEntropyWrapper\n    final_name2 = model_loss._modules[final_name1].output_name[0]\n    assert type(model._modules[final_name1]) == type(model_loss._modules[final_name2])\n    \n    if args.multigpu:\n        model_loss = BoundDataParallel(model_loss)\n        \n    model_loss.ptb = model.ptb = model_ori.ptb  # Perturbation on the parameters\n\n    ## Step 4: prepare optimizer, epsilon scheduler and learning rate scheduler\n    if args.opt == 'ADAM':\n        opt = optim.Adam(model_loss.parameters(), lr=args.lr, weight_decay=0.01)\n    elif args.opt == 'SGD':\n        opt = optim.SGD(model_loss.parameters(), lr=args.lr, weight_decay=0.01)\n\n    norm = float(args.norm)\n    lr_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=args.lr_decay_milestones, gamma=0.1)\n    eps_scheduler = eval(args.scheduler_name)(args.eps, args.scheduler_opts)\n    logger.info(str(model_ori))\n\n    # Skip epochs if we continue training from a checkpoint.\n    if epoch > 0:\n        epoch_length = int((len(train_data.dataset) + train_data.batch_size - 1) / train_data.batch_size)\n        eps_scheduler.set_epoch_length(epoch_length)\n        eps_scheduler.train()\n        for i in range(epoch):\n            lr_scheduler.step()\n            eps_scheduler.step_epoch(verbose=True)\n            for j in range(epoch_length):\n                eps_scheduler.step_batch()\n        logger.info('resume from eps={:.12f}'.format(eps_scheduler.get_eps()))\n\n    if args.load:\n        if opt_state:\n            opt.load_state_dict(opt_state)\n            logger.info('resume opt_state')\n\n    ## Step 5: start training.\n    if args.verify:\n        eps_scheduler = FixedScheduler(args.eps)\n        with torch.no_grad():\n            Train(model_loss, 1, test_data, eps_scheduler, norm, False, None, 
args.bound_type, loss_fusion=False, final_node_name=final_name2)\n    else:\n        timer = 0.0\n        best_loss = 1e10\n        # Main training loop\n        for t in range(epoch + 1, args.num_epochs+1):\n            logger.info(\"Epoch {}, learning rate {}\".format(t, lr_scheduler.get_last_lr()))\n            start_time = time.time()\n\n            # Training one epoch\n            Train(model_loss, t, train_data, eps_scheduler, norm, True, opt, args.bound_type, loss_fusion=True)\n            lr_scheduler.step()\n            epoch_time = time.time() - start_time\n            timer += epoch_time\n            logger.info('Epoch time: {:.4f}, Total time: {:.4f}'.format(epoch_time, timer))\n\n            logger.info(\"Evaluating...\")\n            torch.cuda.empty_cache()\n\n            state_dict = model_loss.state_dict()\n\n            # Test one epoch.\n            with torch.no_grad():\n                m = Train(model, t, test_data, eps_scheduler, norm, False, None, args.bound_type,\n              loss_fusion=False, final_node_name=final_name1)\n\n            # Save checkpoints.\n            save_dict = {'state_dict': state_dict, 'epoch': t, 'optimizer': opt.state_dict()}\n            if not os.path.exists('saved_models'):\n                os.mkdir('saved_models')\n            if t < int(eps_scheduler.params['start']):\n                torch.save(save_dict, 'saved_models/natural_' + exp_name)\n            elif t > int(eps_scheduler.params['start']) + int(eps_scheduler.params['length']):\n                current_loss = m.avg('Loss')\n                if current_loss < best_loss:\n                    best_loss = current_loss\n                    torch.save(save_dict, 'saved_models/' + exp_name + '_best_' + str(best_loss)[:6])\n                else:\n                    torch.save(save_dict, 'saved_models/' + exp_name)\n            else:\n                torch.save(save_dict, 'saved_models/' + exp_name)\n            torch.cuda.empty_cache()\n\n\nif __name__ == 
\"__main__\":\n    main(args)\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\nfrom pathlib import Path\n\n# Check PyTorch version\npytorch_version_l = '2.0.0'\npytorch_version_u = '2.9.0' # excluded\ntorchvision_version_l = '0.12.0'\ntorchvision_version_u = '0.24.0' # excluded\nmsg_install_pytorch = (f'It is recommended to manually install PyTorch '\n                    f'(>={pytorch_version_l},<{pytorch_version_u}) suitable '\n                    'for your system ahead: https://pytorch.org/get-started.\\n')\ntry:\n    import torch\n    if torch.__version__ < pytorch_version_l:\n        print(f'PyTorch version {torch.__version__} is too low. '\n                        + msg_install_pytorch)\n    if torch.__version__ >= pytorch_version_u:\n        print(f'PyTorch version {torch.__version__} is too high. '\n                        + msg_install_pytorch)\nexcept ModuleNotFoundError:\n    print(f'PyTorch is not installed. {msg_install_pytorch}')\n\nwith open('auto_LiRPA/__init__.py') as file:\n    for line in file.readlines():\n        if '__version__' in line:\n            version = eval(line.strip().split()[-1])\n\nthis_directory = Path(__file__).parent\nlong_description = (this_directory / 'README.md').read_text()\n\nprint(f'Installing auto_LiRPA {version}')\nsetup(\n    name='auto_LiRPA',\n    version=version,\n    description='A library for Automatic Linear Relaxation based Perturbation Analysis (LiRPA) on general computational graphs, with a focus on adversarial robustness verification and certification of deep neural networks.',\n    long_description=long_description,\n    long_description_content_type='text/markdown',\n    url='https://github.com/Verified-Intelligence/auto_LiRPA',\n    author='α,β-CROWN Team',\n    author_email='huan@huan-zhang.com, xiangru4@illinois.edu',\n    packages=find_packages(),\n    install_requires=[\n        f'torch>={pytorch_version_l},<{pytorch_version_u}',\n        f'torchvision>={torchvision_version_l},<{torchvision_version_u}',\n        
'numpy>=1.20',\n        'packaging>=20.0',\n        'pytest==8.1.1',\n        'pylint>=2.15',\n        'pytest-order>=1.0.0',\n        'pytest-mock>=3.14',\n        'appdirs>=1.4',\n        'pyyaml>=5.0',\n        'ninja>=1.10',\n        'tqdm>=4.64',\n        'graphviz>=0.20.3'\n    ],\n    platforms=['any'],\n    license='BSD',\n)\n"
  },
  {
    "path": "tests/.gitignore",
    "content": ".cache\n"
  },
  {
    "path": "tests/data/.gitignore",
    "content": "cifar-10-python.tar.gz\ncifar-10-batches-py\nMNIST"
  },
  {
    "path": "tests/test_1d_activation.py",
    "content": "\"\"\"Test one dimensional activation functions (e.g., ReLU, tanh, exp, sin, etc)\"\"\"\nimport functools\n\nimport pytest\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import logger\nfrom auto_LiRPA.operators.s_shaped import TanhGradOp, SigmoidGradOp\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n# Wrap the computation with a nn.Module\nclass test_model(nn.Module):\n    def __init__(self, act_func):\n        super().__init__()\n        self.act_func = act_func\n\n    def forward(self, x):\n        return self.act_func(x)\n\ndef pow_2(x):\n    return torch.pow(x, 2)\n\ndef pow_3(x):\n    return torch.pow(x, 3)\n\nclass GELUOp(torch.autograd.Function):\n    @staticmethod\n    def symbolic(g, x):\n        return g.op('custom::Gelu', x)\n\n    @staticmethod\n    def forward(ctx, x):\n        return torch.nn.functional.gelu(x)\n\ndef GELU(x):\n    return GELUOp.apply(x)\n\ndef gen_hardtanh(min_val, max_val):\n   return functools.partial(torch.nn.functional.hardtanh, min_val=min_val, max_val=max_val)\n\n# The original tanhgrad and sigmoidgrad also take in the gradient from the following layer\n# and multiply it. 
Here we only implement the part that computes the local gradient.\ndef tanhgrad(x):\n    return TanhGradOp.apply(x)\n\ndef sigmoidgrad(x):\n    return SigmoidGradOp.apply(x)\n\n\nclass Test1DActivation(TestCase):\n    def __init__(self, methodName='runTest', device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n\n    def create_test(self, act_func, low, high, ntests=1000, nsamples=1000,\n                    method='IBP', activation_bound_option='adaptive', input_lb=None, input_ub=None):\n        print(f'Testing activation {act_func} (method {method}, activation_bound_option {activation_bound_option})')\n\n        model = test_model(act_func)\n        image = torch.zeros(1, ntests)\n        bounded_model = BoundedModule(\n            model, image, bound_opts={\n                'optimize_bound_args': {'iteration': 2},\n                'activation_bound_option': activation_bound_option\n            }, device=self.default_device)\n\n        if input_lb is None or input_ub is None:\n            # Generate randomly bounded inputs.\n            p = torch.rand(1, ntests) * (high - low) + low\n            q = torch.rand(1, ntests) * (high - low) + low\n            input_lb = torch.min(p, q)\n            input_ub = torch.max(p, q)\n        else:\n            low, high = torch.min(input_lb), torch.max(input_ub)\n        input_center = (input_lb + input_ub) / 2.0\n        ptb = PerturbationLpNorm(norm=float(\"inf\"), eps=None, x_L=input_lb, x_U=input_ub)\n        ptb_data = BoundedTensor(input_center, ptb)\n\n        # Generate reference results.\n        table = act_func(torch.linspace(start=low, end=high, steps=nsamples+1))\n        def lookup(l, u):\n            assert torch.all(u <= high)\n            assert torch.all(l >= low)\n            shape = l.size()\n            l = l.squeeze()\n            u = u.squeeze()\n            # select all sample points between l and u.\n            low_index = torch.ceil((l - low) 
/ (high - low) * nsamples).int()  # Make sure we do not have index 0.\n            high_index = torch.floor((u - low) / (high - low) * nsamples).int()\n            real_lb = torch.empty_like(l)\n            real_ub = torch.empty_like(u)\n            for i, (li, hi) in enumerate(zip(low_index, high_index)):\n                if li == hi + 1:\n                    # Not enough precision. l and u are too close so we cannot tell.\n                    real_lb[i] = float(\"inf\")\n                    real_ub[i] = float(\"-inf\")\n                else:\n                    selected = table[li : hi+1]\n                    real_lb[i] = torch.min(selected)\n                    real_ub[i] = torch.max(selected)\n            real_lb = real_lb.view(*shape)\n            real_ub = real_ub.view(*shape)\n            return real_lb, real_ub\n\n        # These are reference results. IBP results should be very close to these.\n        # Linear bound results can be looser than these.\n        ref_forward = model(input_center)\n        ref_output_lb, ref_output_ub = lookup(input_lb, input_ub)\n\n        # Get bounding results.\n        forward = bounded_model(ptb_data)\n        output_lb, output_ub = bounded_model.compute_bounds(\n            x=(ptb_data,), method=method)\n        bounded_model.set_bound_opts({\n            'optimize_bound_args': {'iteration': 2, 'init_alpha': True},\n        })\n\n        # Compare.\n        assert torch.allclose(forward, ref_forward)\n        for i in range(ntests):\n            show = False\n            if output_ub[0,i] < ref_output_ub[0,i] - 1e-5:\n                logger.warning(f'upper bound is wrong {ref_output_ub[0,i] - output_ub[0,i]}')\n                show = True\n            if output_lb[0,i] > ref_output_lb[0,i] + 1e-5:\n                logger.warning(f'lower bound is wrong {output_lb[0,i] - ref_output_lb[0,i]}')\n                show = True\n            if show:\n                logger.warning(f'input_lb={input_lb[0,i]:8.3f}, 
input_ub={input_ub[0,i]:8.3f}, lb={output_lb[0,i]:8.3f}, ref_lb={ref_output_lb[0,i]:8.3f}, ub={output_ub[0,i]:8.3f}, ref_ub={ref_output_ub[0,i]:8.3f}')\n        assert torch.all(output_ub + 1e-5 >= ref_output_ub)\n        assert torch.all(output_lb - 1e-5 <= ref_output_lb)\n\n    @pytest.mark.skip(reason=\"Known issue: https://github.com/Verified-Intelligence/Verifier_Development/issues/164\")\n    def test_tan(self):\n        # Test tan(x) in different periods.\n        for i in range(-5, 5):\n            self.create_test(\n                act_func=torch.tan,\n                low=-0.5*torch.pi + i*torch.pi + 1e-20,\n                high=0.5*torch.pi + i*torch.pi - 1e-20, method='IBP')\n            self.create_test(\n                act_func=torch.tan,\n                low=-0.5*torch.pi + i*torch.pi + 1e-20,\n                high=0.5*torch.pi + i*torch.pi - 1e-20, method='CROWN')\n\n    def test_acts(self):\n        for act_func in [torch.nn.functional.relu,\n                         torch.sin, torch.cos,\n                         torch.tanh, torch.sigmoid, torch.arctan,\n                         torch.exp, pow_2, pow_3,\n                         torch.sign, GELU, gen_hardtanh(-1,1),gen_hardtanh(-0.25,0.25),gen_hardtanh(1,10),gen_hardtanh(-5,2),\n                         tanhgrad, sigmoidgrad]:\n            low, high = -10, 10\n            if act_func == torch.reciprocal:\n                # So far only positive values are supported.\n                low = 0.01\n            self.create_test(act_func=act_func, low=low, high=high, method='IBP')\n            self.create_test(act_func=act_func, low=low, high=high, method='CROWN')\n            if act_func not in [torch.exp, torch.sign, torch.sin, torch.cos, tanhgrad, sigmoidgrad]:\n                # Use optimized bounds\n                self.create_test(act_func=act_func, low=low, high=high,\n                                 method='CROWN-Optimized')\n            if act_func in [torch.sin, torch.cos]:\n                
test_samples = 10\n                for _ in range(test_samples):\n                    self.create_test(act_func=act_func, low=low, high=high, method='CROWN-Optimized')\n\n            if act_func in [torch.nn.functional.relu]:\n                self.create_test(act_func=act_func, low=low, high=high, method='Dynamic-Forward')\n            if act_func in [torch.nn.functional.relu, torch.tanh]:\n                self.create_test(act_func=act_func, low=low, high=high, method='CROWN', activation_bound_option='same-slope')\n\n        print('Testing activations with large input range')\n        for act_func in [torch.sin, torch.tanh,\n                        pow_3, GELU]:\n            low, high = -600, 600\n            self.create_test(act_func=act_func, low=low, high=high, method='CROWN')\n\n\nif __name__ == '__main__':\n    testcase = Test1DActivation()\n    testcase.test_acts()\n"
  },
  {
    "path": "tests/test_2d_activation.py",
    "content": "\"\"\"Test two dimensional activation functions (e.g., min, max, etc)\"\"\"\nimport tqdm\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom auto_LiRPA.utils import logger\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n# Wrap the computation with a nn.Module\nclass test_model(nn.Module):\n    def __init__(self, act_func):\n        super().__init__()\n        self.act_func = act_func\n\n    def forward(self, x, y):\n        return self.act_func(x, y)\n\n\ndef mul(x, y):\n    return x * y\n\n\nclass Test2DActivation(TestCase):\n    def __init__(self, methodName='runTest', device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n\n    def create_test(self, act_func, low_x, high_x, low_y, high_y,\n                    ntests=10000, nsamples=1000, method='IBP'):\n        print(f'Testing activation {act_func}')\n\n        model = test_model(act_func)\n        image = torch.zeros(2, ntests)\n        bounded_model = BoundedModule(model, (image[0], image[1]), device=self.default_device)\n\n        # Generate randomly bounded inputs.\n        p_x = torch.rand(1, ntests) * (high_x - low_x) + low_x\n        q_x = torch.rand(1, ntests) * (high_x - low_x) + low_x\n        input_lb_x = torch.min(p_x, q_x)\n        input_ub_x = torch.max(p_x, q_x)\n        input_center_x = (input_lb_x + input_ub_x) / 2.0\n        ptb_x = PerturbationLpNorm(x_L=input_lb_x, x_U=input_ub_x)\n        ptb_data_x = BoundedTensor(input_center_x, ptb_x)\n\n        p_y = torch.rand(1, ntests) * (high_y - low_y) + low_y\n        q_y = torch.rand(1, ntests) * (high_y - low_y) + low_y\n        input_lb_y = torch.min(p_y, q_y)\n        input_ub_y = torch.max(p_y, q_y)\n        input_center_y = (input_lb_y + input_ub_y) / 2.0\n        ptb_y = PerturbationLpNorm(x_L=input_lb_y, x_U=input_ub_y)\n        ptb_data_y = 
BoundedTensor(input_center_y, ptb_y)\n\n        # Generate reference results.\n        range_xy = torch.linspace(start=low_x, end=high_x, steps=nsamples+1)\n        table = torch.empty([range_xy.shape[0], range_xy.shape[0]])\n        for i in range(range_xy.shape[0]):\n            x = range_xy[i]\n            table_y = act_func(x, torch.linspace(start=low_y, end=high_y, steps=nsamples+1))\n            table[i] = table_y\n        def lookup(l_x, u_x, l_y, u_y):\n            assert torch.all(u_x <= high_x)\n            assert torch.all(l_x >= low_x)\n            assert torch.all(u_y <= high_y)\n            assert torch.all(l_y >= low_y)\n            shape = l_x.size()\n            l_x = l_x.squeeze()\n            u_x = u_x.squeeze()\n            l_y = l_y.squeeze()\n            u_y = u_y.squeeze()\n            # select all sample points between l and u.\n            low_index_x = torch.ceil((l_x - low_x) / (high_x - low_x) * nsamples).int()  # Make sure we do not have index 0.\n            high_index_x = torch.floor((u_x - low_x) / (high_x - low_x) * nsamples).int()\n            low_index_y = torch.ceil((l_y - low_y) / (high_y - low_y) * nsamples).int()  # Make sure we do not have index 0.\n            high_index_y = torch.floor((u_y - low_y) / (high_y - low_y) * nsamples).int()\n            real_lb = torch.empty_like(l_x)\n            real_ub = torch.empty_like(u_x)\n            for i, (li_x, hi_x) in enumerate(zip(low_index_x, high_index_x)):\n                li_y = low_index_y[i]\n                hi_y = high_index_y[i]\n                if li_x == hi_x + 1 or li_y == hi_y + 1:\n                    # Not enough precision. 
l and u are too close so we cannot tell.\n                    real_lb[i] = float(\"inf\")\n                    real_ub[i] = float(\"-inf\")\n                else:\n                    selected = table[li_x : hi_x+1, li_y : hi_y+1].reshape(-1)\n                    real_lb[i] = torch.min(selected)\n                    real_ub[i] = torch.max(selected)\n            real_lb = real_lb.view(*shape)\n            real_ub = real_ub.view(*shape)\n            return real_lb, real_ub\n        # These are reference results. IBP results should be very close to these. Linear bound results can be looser than these.\n        ref_forward = model(input_center_x, input_center_y)\n        ref_output_lb, ref_output_ub = lookup(input_lb_x, input_ub_x, input_lb_y, input_ub_y)\n\n        # Get bounding results.\n        forward = bounded_model(ptb_data_x, ptb_data_y)\n        output_lb, output_ub = bounded_model.compute_bounds(x=(ptb_data_x, ptb_data_y), method = method)\n\n        # Compare.\n        assert torch.allclose(forward, ref_forward)\n        for i in tqdm.tqdm(range(ntests)):\n            show = False\n            if output_ub[0,i] < ref_output_ub[0,i] - 1e-5:\n                logger.warning(f'upper bound is wrong {ref_output_ub[0,i] - output_ub[0,i]}')\n                show = True\n            if output_lb[0,i] > ref_output_lb[0,i] + 1e-5:\n                logger.warning(f'lower bound is wrong {output_lb[0,i] - ref_output_lb[0,i]}')\n                show = True\n            if show:\n                logger.warning(f'input_lb_x={input_lb_x[0,i]:8.3f}, input_ub_x={input_ub_x[0,i]:8.3f},input_lb_y={input_lb_y[0,i]:8.3f}, input_ub_y={input_ub_y[0,i]:8.3f}, lb={output_lb[0,i]:8.3f}, ref_lb={ref_output_lb[0,i]:8.3f}, ub={output_ub[0,i]:8.3f}, ref_ub={ref_output_ub[0,i]:8.3f}')\n        assert torch.all(output_ub + 1e-5 >= ref_output_ub)\n        assert torch.all(output_lb - 1e-5 <= ref_output_lb)\n\n    def test_max(self):\n        self.create_test(act_func=torch.max, low_x=-10, 
high_x=5, low_y=-1, high_y=10, method='IBP')\n        self.create_test(act_func=torch.max, low_x=-10, high_x=5, low_y=-1, high_y=10, method='CROWN')\n\n    def test_min(self):\n        self.create_test(act_func=torch.min, low_x=-10, high_x=5, low_y=-1, high_y=10, method='IBP')\n        self.create_test(act_func=torch.min, low_x=-10, high_x=5, low_y=-1, high_y=10, method='CROWN')\n\n    def test_mul(self):\n        self.create_test(act_func=mul, low_x=-10, high_x=5, low_y=-1, high_y=10, method='IBP')\n        self.create_test(act_func=mul, low_x=-10, high_x=5, low_y=-1, high_y=10, method='CROWN')\n\nif __name__ == '__main__':\n    testcase = Test2DActivation()\n    testcase.test_max()\n    testcase.test_min()\n    testcase.test_mul()\n"
  },
  {
    "path": "tests/test_avgpool.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\ndef ff(num_conv=2, num_mlp_only=None, pooling=False, activation=\"ReLU\",\n       hidden_size=256, input_ch=1, input_dim=28, num_classes=10, pool_kernel=3, pool_stride=1, pool_padding=1):\n    activation = eval(f\"nn.{activation}()\")\n    layers = []\n    if num_conv:\n        layers.append(nn.Conv2d(input_ch, 4, 3, stride=1, padding=1))\n        layers.append(activation)\n        num_channels = 4\n        if pooling:\n            layers.append(nn.AvgPool2d(kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding))\n        if num_conv >= 2:\n            layers.append(nn.Conv2d(4, 8, 3, stride=1, padding=1))\n            layers.append(nn.ReLU())\n            if pooling:\n                layers.append(nn.AvgPool2d(kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding))\n            num_channels = 8\n        for _ in range(num_conv - 2):\n            layers.append(nn.Conv2d(8, 8, 3, stride=1, padding=1))\n            layers.append(nn.ReLU())\n            if pooling:\n                layers.append(nn.AvgPool2d(kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding))\n        layers.append(nn.Flatten(1))\n\n        # Calculate output size after pooling operations\n        if pooling and num_conv > 0:\n            pooled_dim = input_dim\n            for _ in range(num_conv):\n                pooled_dim = (pooled_dim + 2 * pool_padding - pool_kernel) // pool_stride + 1\n            linear_input_size = num_channels * (pooled_dim ** 2)\n        else:\n            linear_input_size = num_channels * (input_dim ** 2)\n\n        layers.append(nn.Linear(linear_input_size, hidden_size))\n        layers.append(nn.ReLU())\n        layers.append(nn.Linear(hidden_size, num_classes))\n    else:\n        
layers.append(nn.Flatten(1))\n        cur = input_ch * (input_dim ** 2)\n        for _ in range(num_mlp_only - 1):\n            layers.append(nn.Linear(cur, hidden_size))\n            layers.append(activation)\n            cur = hidden_size\n        # Use `cur` here so the final layer is correct even when num_mlp_only == 1.\n        layers.append(nn.Linear(cur, num_classes))\n    return nn.Sequential(*layers)\n\n\ndef synthetic_net(input_ch, input_dim, **kwargs):\n    return ff(input_ch=input_ch, input_dim=input_dim, num_classes=2, **kwargs)\n\n\ndef synthetic_4c2f_pool(input_ch, input_dim, **kwargs):\n    return synthetic_net(input_ch, input_dim, num_conv=4, pooling=True, **kwargs)\n\n\nclass TestAvgPool(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1234, ref_name='avgpool_test_data',\n            generate=generate, device=device, dtype=dtype)\n\n    def test(self):\n        test_configs = [\n            {'input_ch': 1, 'input_dim': 5, 'hidden_size': 8, 'pool_kernel': 3, 'pool_stride': 1, 'pool_padding': 1},\n            {'input_ch': 1, 'input_dim': 32, 'hidden_size': 16, 'pool_kernel': 2, 'pool_stride': 2, 'pool_padding': 0}\n        ]\n\n        self.result = []\n\n        for config in test_configs:\n            print(f\"Testing config: {config}\")\n\n            model_ori = synthetic_4c2f_pool(**config)\n            model_ori = model_ori.eval().to(self.default_device).to(self.default_dtype)\n\n            x = torch.randn(8, config['input_ch'], config['input_dim'], config['input_dim'])\n\n            ptb = PerturbationLpNorm(norm=np.inf, eps=100)\n            x_bounded = BoundedTensor(x, ptb)\n\n            print(f\"  Testing with default conv_mode (patches)\")\n            model = BoundedModule(model_ori, x, device=self.default_device)\n\n            lb_patches, ub_patches = model.compute_bounds(x=(x_bounded,), method='backward')\n            print(f\"    Patches mode - LB: {lb_patches}\")\n            
print(f\"    Patches mode - UB: {ub_patches}\")\n\n            self.result += [lb_patches, ub_patches]\n\n            print(f\"  Testing with conv_mode='matrix'\")\n            model_matrix = BoundedModule(model_ori, x, bound_opts={'conv_mode': 'matrix'})\n\n            lb_matrix, ub_matrix = model_matrix.compute_bounds(x=(x_bounded,), method='backward')\n            print(f\"    Matrix mode - LB: {lb_matrix}\")\n            print(f\"    Matrix mode - UB: {ub_matrix}\")\n\n            self.result += [lb_matrix, ub_matrix]\n\n            lb_diff = torch.abs(lb_patches - lb_matrix).max().item()\n            ub_diff = torch.abs(ub_patches - ub_matrix).max().item()\n            print(f\"    Max difference in LB between patches and matrix: {lb_diff}\")\n            print(f\"    Max difference in UB between patches and matrix: {ub_diff}\")\n\n            assert torch.allclose(lb_patches, lb_matrix, atol=1e-6), f\"Lower bounds not equivalent between patches and matrix modes\"\n            assert torch.allclose(ub_patches, ub_matrix, atol=1e-6), f\"Upper bounds not equivalent between patches and matrix modes\"\n            print(f\"    Matrix and patches modes produce equivalent results\")\n            print()\n\n        self.check()\n\n\nif __name__ == '__main__':\n    testcase = TestAvgPool(generate=False)\n    testcase.test()"
  },
  {
    "path": "tests/test_bound_ops.py",
    "content": "\"\"\"Test classes for bound operators\"\"\"\nimport torch\nfrom auto_LiRPA.bound_ops import *\nfrom auto_LiRPA.linear_bound import LinearBound\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass Dummy:\n    \"\"\"Dummy node for testing\"\"\"\n    def __init__(self, lower, upper=None, perturbed=False):\n        self.lower = lower\n        self.upper = upper if upper is not None else lower\n        self.perturbed = perturbed\n        self.output_shape = lower.shape\n\n\nclass TestBoundOp(TestCase):\n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name='bound_ops_data',\n            generate=generate, device=device, dtype=dtype)\n\n    def test(self):\n        device = self.default_device\n        dtype = self.default_dtype\n        batch_size = 5\n        dim_final = 7\n        dim_output = 9\n        dim_input = 11\n\n        # multiplication of [batch_size, dim_input] and [dim_output, dim_input]^T\n        weight = torch.randn(dim_output, dim_input, device=device)\n        bias = torch.randn(dim_output, device=device)\n        data_in = torch.randn(batch_size, dim_input, device=device)\n        data_in_delta = torch.randn(batch_size, dim_input, device=device)\n        dummy_in = Dummy(\n            data_in - torch.abs(data_in_delta),\n            data_in + torch.abs(data_in_delta), True)\n        dummy_weight = Dummy(weight)\n        dummy_bias = Dummy(bias)\n\n        op = BoundLinear(\n            attr={'transB': 1},\n            inputs=[dummy_in, dummy_weight, dummy_bias],\n            output_index=0, options={})\n        op.batch_dim = 0\n\n        # test `forward`\n        data_out = op(data_in, weight, bias)\n        self.assertEqual(data_out, data_in.matmul(weight.t()) + bias)\n\n        # test `bound_backward`\n        # The `transpose` here to make the randomization consistent with 
the previous reference.\n        # It can be removed once a new reference is generated.\n        last_lA = torch.randn(batch_size, dim_final, dim_output, device=device).transpose(0, 1)\n        last_uA = torch.randn(batch_size, dim_final, dim_output, device=device).transpose(0, 1)\n        A, lbias, ubias = op.bound_backward(last_lA, last_uA, *op.inputs)\n        self.assertEqual(A[0][0], last_lA.matmul(weight))\n        self.assertEqual(A[0][1], last_uA.matmul(weight))\n        self.assertEqual(lbias, last_lA.matmul(bias))\n        self.assertEqual(ubias, last_uA.matmul(bias))\n\n        # test `bound_forward`\n        # note that the upper bound may be actually smaller than the lower bound\n        # in these dummy linear bounds\n        bound_in = LinearBound(\n            lw=torch.randn(batch_size, dim_final, dim_input, device=device),\n            lb=torch.randn(batch_size, dim_input, device=device),\n            uw=torch.randn(batch_size, dim_final, dim_input, device=device),\n            ub=torch.randn(batch_size, dim_input, device=device),\n            lower=None, upper=None)\n        bound_weight = LinearBound(None, None, None, None, dummy_weight.lower, dummy_weight.upper)\n        bound_bias = LinearBound(None, None, None, None, dummy_bias.lower, dummy_bias.upper)\n        bound_out = op.bound_forward(dim_final, bound_in, bound_weight, bound_bias)\n        self.assertEqual(\n            bound_out.lw, bound_in.lw.matmul(weight.t().clamp(min=0))\n            + bound_in.uw.matmul(weight.t().clamp(max=0)))\n        self.assertEqual(\n            bound_out.uw, bound_in.uw.matmul(weight.t().clamp(min=0))\n            + bound_in.lw.matmul(weight.t().clamp(max=0)))\n        self.assertEqual(\n            bound_out.lb, bound_in.lb.matmul(weight.t().clamp(min=0))\n            + bound_in.ub.matmul(weight.t().clamp(max=0)) + bias)\n        self.assertEqual(\n            bound_out.ub, bound_in.ub.matmul(weight.t().clamp(min=0))\n            + 
bound_in.lb.matmul(weight.t().clamp(max=0)) + bias)\n\n        # test `interval_propagate`\n        bound_in = (\n            torch.randn(*data_in.shape, device=device),\n            torch.randn(*data_in.shape, device=device))\n        bound_weight = (bound_weight.lower, bound_weight.upper)\n        bound_bias = (bound_bias.lower, bound_bias.upper)\n        bound_out = op.interval_propagate(bound_in, bound_weight, bound_bias)\n        self.assertEqual(bound_out[0],\n                         bound_in[0].matmul(weight.t().clamp(min=0))\n                         + bound_in[1].matmul(weight.t().clamp(max=0)) + bias)\n        self.assertEqual(bound_out[1],\n                         bound_in[1].matmul(weight.t().clamp(min=0))\n                         + bound_in[0].matmul(weight.t().clamp(max=0)) + bias)\n\n        # test weight perturbation\n        # `bound_backward`\n        ptb_weight = torch.randn(weight.shape)\n        op.inputs[1].upper += ptb_weight\n        op.inputs[1].perturbed = True\n        op.inputs[2].perturbation = None # no perturbation on bias\n        A, lbias, ubias = op.bound_backward(last_lA, last_uA, *op.inputs)\n        # `interval_propagate`\n        bound_weight = (op.inputs[1].lower, op.inputs[1].upper)\n        bound_out = op.interval_propagate(bound_in, bound_weight, bound_bias)\n\n        self.result = (A, lbias, ubias, bound_out)\n\n        if self.generate:\n            self.save()\n            self.reference = self.result\n\n        A_ref, lbias_ref, ubias_ref, bound_out_ref = self.reference\n        for i in range(3):\n            for j in range(2):\n                if A_ref[i][j] is not None:\n                    ref = A_ref[i][j].to(device=device, dtype=dtype)\n                    self.assertEqual(A[i][j], ref)\n                    \n        lbias_ref = lbias_ref.to(device=device, dtype=dtype)\n        ubias_ref = ubias_ref.to(device=device, dtype=dtype)\n        bound_out_ref = (\n            bound_out_ref[0].to(device=device, 
dtype=dtype),\n            bound_out_ref[1].to(device=device, dtype=dtype)\n        )\n        self.assertEqual(lbias, lbias_ref)\n        self.assertEqual(ubias, ubias_ref)\n        self.assertEqual(bound_out[0], bound_out_ref[0])\n        self.assertEqual(bound_out[1], bound_out_ref[1])\n\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference results\n    testcase = TestBoundOp(generate=False)\n    testcase.setUp()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_branching_heuristics.py",
    "content": "import sys\nimport torch\nfrom types import SimpleNamespace\n\nsys.path.insert(0, '../complete_verifier')\n\nfrom heuristics.base import RandomNeuronBranching\nfrom testcase import DEFAULT_DEVICE, DEFAULT_DTYPE, set_default_dtype_device\n\ndef test_branching_heuristics():\n    device = DEFAULT_DEVICE\n    dtype = DEFAULT_DTYPE\n    set_default_dtype_device(dtype, device)\n    import random\n    import numpy as np\n    seed = 123\n    torch.manual_seed(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n\n    net = SimpleNamespace()\n    branching_heuristic = RandomNeuronBranching(net)\n\n    for _ in range(10000):\n        batch_size = random.randint(1, 5)\n        # Number of layers, and we will split the total_layers into this\n        # many of layers.\n        n_layers = random.randint(1, 5)\n        total_len = random.randint(n_layers, 100)\n        net.split_nodes = []\n        net.split_activations = {}\n        for i in range(n_layers):\n            layer = SimpleNamespace()\n            layer.name = i\n            activation = SimpleNamespace()\n            activation.name = f'{i}_activation'\n            net.split_nodes.append(layer)\n            net.split_activations[layer.name] = [(activation, 0)]\n        # Total number of neurons in all layers.\n        topk = random.randint(1, total_len)\n        # Generate random and unique scores.\n        # scores = torch.argsort(torch.rand(batch_size, total_len)) + 1\n        scores = torch.rand(batch_size, total_len) + 1e-8\n        # Generate random mask. 
Mask = 1 means this neuron can be split.\n        masks = (torch.rand(batch_size, total_len) > 0.75).float()\n        # Generate random split locations.\n        split_position = torch.randint(\n            low=0, high=total_len, size=(n_layers - 1,)).sort().values\n        print(f'testing batch={batch_size}, n_layers={n_layers}, '\n              f'total_len={total_len}, topk={topk}, split={split_position}')\n        segment_lengths = (torch.cat(\n            [split_position, torch.full(size=(1,),\n                                        fill_value=total_len,\n                                        device=split_position.device)])\n                           - torch.cat([torch.zeros((1,), device=split_position.device),\n                                        split_position]))\n        segment_lengths = segment_lengths.int().tolist()\n        # Cap to the minimum number of valid neurons in each batch.\n        min_k = int(masks.sum(dim=1).min().item())\n        # Find the topk scores and indices across all layers.\n        topk_scores, topk_indices = (scores * masks).topk(k=min(min_k, topk))\n        # Map the indices to groundtruth layer number.\n        topk_layers = torch.searchsorted(\n            split_position, topk_indices, right=True)\n        # Map the indices to groundtruth neuron number.\n        topk_neurons = topk_indices - torch.cat(\n            [torch.zeros(1, device=split_position.device, dtype=torch.int64),\n             split_position]\n        ).view(1, -1).repeat(batch_size, 1).gather(\n            dim=1, index=topk_layers)\n        # Split into a list of scores for testing.\n        all_layer_scores = scores.split(segment_lengths, dim=1)\n        all_layer_masks = masks.split(segment_lengths, dim=1)\n        all_layer_scores = {i: item for i, item in enumerate(all_layer_scores)}\n        all_layer_masks = {i: item for i, item in enumerate(all_layer_masks)}\n        branching_heuristic.update_batch_size_and_device(all_layer_scores)\n        
(calculated_layers, calculated_neurons,\n         calculated_scores) = branching_heuristic.find_topk_scores(\n            all_layer_scores, all_layer_masks, k=topk, return_scores=True)\n        torch.testing.assert_close(calculated_layers, topk_layers)\n        torch.testing.assert_close(calculated_neurons, topk_neurons)\n        torch.testing.assert_close(calculated_scores, topk_scores)\n\n\nif __name__ == \"__main__\":\n    test_branching_heuristics()\n"
  },
  {
    "path": "tests/test_clip_domains.py",
    "content": "\"\"\"\nTests clip_domains\n\nTo run tests: py.test             test_clip_domains.py\n          or: python -m pytest    test_clip_domains.py\nVerbose (-v): py.test -v          test_clip_domains.py\n          or: python -m pytest -v test_clip_domains.py\n\"\"\"\nimport torch\nfrom torch import Tensor\nfrom random import randint\nfrom typing import Union, Tuple\n\nimport sys\nsys.path.append('../complete_verifier')\n\n# importing clip_domains from CROWN\nfrom input_split.clip import clip_domains\nfrom testcase import DEFAULT_DEVICE, DEFAULT_DTYPE, set_default_dtype_device\n\nbatches = 2 # Do not use large batch sizes when running on CI\ndevice = DEFAULT_DEVICE  # CI is not equipped with CUDA\ndtype = DEFAULT_DTYPE\n\nset_default_dtype_device(dtype, device)\n\natol = 1e-4  # my references are defined at this level of tolerance\n\ndef setup_module(module):\n    \"\"\"\n    Displays global information about the test run\n    @param module:\n    @return:\n    \"\"\"\n    print()\n    print(\"setup_module      module:%s\" % module.__name__)\n    print(f\"Using device: {device}\")\n    print(f\"Using dtype: {dtype}\")\n    print(f\"Using atol: {atol}\")\n    print(f\"Using number of batches (batch copies): {batches}\")\n    print()\n\ndef setup_function(function):\n    \"\"\"\n    Adds spacing between tests\n    @param function:\n    @return:\n    \"\"\"\n    print(f\"\\nRunning test case: {function.__name__}\")\n\ndef _tensor(x):\n    return torch.tensor(x, device=device, dtype=dtype)\n\ndef test_case_one_one():\n    print()\n    # Define the base 2D tensors\n    A_bar_base = _tensor([[4 / 5, -7 / 20], [3 / 10, -3 / 7]])\n    x_L_base = _tensor([-3, -2])\n    x_U_base = _tensor([3, 2])\n    c_bar_base = _tensor([[1 / 10], [3 / 10]])\n    target_base = _tensor([[0], [0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n      
                                                    batches)\n\n    # In this suite, we have a reference for x_L/U\n    ref_x_L = _tensor([-3., -1.4]).unsqueeze(0).expand(batches, -1)\n    ref_x_U = _tensor([0.75, 2.0000]).unsqueeze(0).expand(batches, -1)\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert (new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\n    # check the returned x_L/U matches the expected x_L/U values\n    x_L_eq = torch.allclose(new_x_L, ref_x_L, atol=atol)\n    x_U_eq = torch.allclose(new_x_U, ref_x_U, atol=atol)\n    assert x_L_eq, \"x_L is not correct\"\n    assert x_U_eq, \"x_U is not correct\"\n\ndef test_case_one_two():\n    print()\n    # Define the base 2D tensors\n    A_bar_base = _tensor([[3 / 10, -3 / 7]])\n    x_L_base = _tensor([-3, -2])\n    x_U_base = _tensor([3, 2])\n    c_bar_base = _tensor([[3 / 10]])\n    target_base = _tensor([[0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n                                                          batches)\n\n    # In this suite, we have a reference for x_L/U\n    ref_x_L = _tensor([-3., -1.4]).unsqueeze(0).expand(batches, -1)\n    ref_x_U = _tensor([1.8571, 2.0000]).unsqueeze(0).expand(batches, -1)\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert (new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\n    # check the returned x_L/U matches the expected x_L/U values\n    x_L_eq = torch.allclose(new_x_L, ref_x_L, atol=atol)\n    x_U_eq = torch.allclose(new_x_U, ref_x_U, atol=atol)\n    assert x_L_eq, \"x_L 
is not correct\"\n    assert x_U_eq, \"x_U is not correct\"\n\ndef test_case_one_three():\n    print()\n    # Define the base 2D tensors\n    A_bar_base = _tensor([[3 / 10, -3 / 7], [3 / 10, -3 / 7]])\n    x_L_base = _tensor([-3, -2])\n    x_U_base = _tensor([3, 2])\n    c_bar_base = _tensor([[3 / 10], [3 / 10]])\n    target_base = _tensor([[0], [0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n                                                          batches)\n\n    # In this suite, we have a reference for x_L/U\n    ref_x_L = _tensor([-3., -1.4]).unsqueeze(0).expand(batches, -1)\n    ref_x_U = _tensor([1.8571, 2.0000]).unsqueeze(0).expand(batches, -1)\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert (new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\n    # check the returned x_L/U matches the expected x_L/U values\n    x_L_eq = torch.allclose(new_x_L, ref_x_L, atol=atol)\n    x_U_eq = torch.allclose(new_x_U, ref_x_U, atol=atol)\n    assert x_L_eq, \"x_L is not correct\"\n    assert x_U_eq, \"x_U is not correct\"\n\n\ndef test_case_one_four():\n    print()\n    # Define the base 2D tensors\n    A_bar_base = _tensor([[4 / 5, -7 / 20, 0.1], [3 / 10, -3 / 7, 0.1]])\n    x_L_base = _tensor([-3, -2, -1])\n    x_U_base = _tensor([3, 2, 1])\n    c_bar_base = _tensor([[1 / 10], [3 / 10]])\n    target_base = _tensor([[0], [0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n                                                          batches)\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, 
x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert (new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\ndef test_case_two_one():\n    \"\"\"\n    Visualize this test case at\n    https://www.desmos.com/3d/fz6e11ovm3\n    @return:\n    \"\"\"\n    print()\n    # Define the base 2D tensors\n    A_bar_base = _tensor([[5/5, 1/5], [2/5, 1/5], [10/35, 1/5]])\n    x_L_base = _tensor([0, 0])\n    x_U_base = _tensor([1, 1])\n    c_bar_base = _tensor([[-1/5], [-1/5], [-1/5]])\n    target_base = _tensor([[0], [0], [0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n                                                          batches)\n\n    # In this suite, we have a reference for x_L/U\n    ref_x_L = _tensor([0., 0.]).unsqueeze(0).expand(batches, -1)\n    ref_x_U = _tensor([0.2000, 1.0000]).unsqueeze(0).expand(batches, -1)\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert (new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\n    # check the returned x_L/U matches the expected x_L/U values\n    x_L_eq = torch.allclose(new_x_L, ref_x_L, atol=atol)\n    x_U_eq = torch.allclose(new_x_U, ref_x_U, atol=atol)\n    assert x_L_eq, \"x_L is not correct\"\n    assert x_U_eq, \"x_U is not correct\"\n\n\ndef test_case_two_two():\n    \"\"\"\n    Visualize this test case at\n    https://www.desmos.com/3d/ruty3i54wu\n    @return:\n    \"\"\"\n    print()\n    # Define the base 2D tensors\n    A_bar_base = -1. * _tensor([[5 / 5, 1 / 5], [2 / 5, 1 / 5], [10 / 35, 1 / 5]])\n    x_L_base = _tensor([0, 0])\n    x_U_base = _tensor([1, 1])\n    c_bar_base = -1. 
* _tensor([[-1 / 5], [-1 / 5], [-1 / 5]])\n    target_base = _tensor([[0], [0], [0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n                                                          batches)\n\n    # In this suite, we have a reference for x_L/U\n    ref_x_L = x_L.clone()\n    ref_x_U = x_U.clone()\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert (new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\n    # check the returned x_L/U matches the expected x_L/U values\n    x_L_eq = torch.allclose(new_x_L, ref_x_L, atol=atol)\n    x_U_eq = torch.allclose(new_x_U, ref_x_U, atol=atol)\n    assert x_L_eq, \"x_L is not correct\"\n    assert x_U_eq, \"x_U is not correct\"\n\ndef test_case_two_three():\n    \"\"\"\n    Visualize this test case at\n    https://www.desmos.com/3d/vogsjthmav\n    @return:\n    \"\"\"\n    print()\n    # Define the base 2D tensors\n    A_bar_base = _tensor([[-5 / 5, -1 / 5], [2 / 5, 1 / 5], [10 / 35, 1 / 5]])\n    x_L_base = _tensor([0, 0])\n    x_U_base = _tensor([1, 1])\n    c_bar_base = _tensor([[1 / 5], [-1 / 5], [-1 / 5]])\n    target_base = _tensor([[0], [0], [0]])\n\n    # Expand the base tensors along the batch dimension\n    lA, x_L, x_U, c_bar, thresholds, dm_lb = setup_test_matrices(A_bar_base, x_L_base, x_U_base, c_bar_base, target_base,\n                                                          batches)\n\n    # In this suite, we have a reference for x_L/U\n    ref_x_L = x_L.clone()\n    ref_x_U = torch.zeros_like(x_U)\n    ref_x_U[:] = _tensor([0.5, 1.0])\n\n    old_x_L = x_L.clone()\n    old_x_U = x_U.clone()\n    ret = clip_domains(x_L, x_U, thresholds, lA, None, dm_lb)\n    new_x_L, new_x_U = ret\n    assert 
(new_x_L.shape == old_x_L.shape) and (new_x_U.shape == old_x_U.shape), \"x_L(U) should have the same shape as before\"\n\n    # check the returned x_L/U matches the expected x_L/U values\n    x_L_eq = torch.allclose(new_x_L, ref_x_L, atol=atol)\n    x_U_eq = torch.allclose(new_x_U, ref_x_U, atol=atol)\n    assert x_L_eq, \"x_L is not correct\"\n    assert x_U_eq, \"x_U is not correct\"\n\n# Rest of file are helper functions\n\ndef concretize_bounds(\n        x_hat: torch.Tensor,\n        x_eps: torch.Tensor,\n        lA: torch.Tensor,\n        lbias: Union[torch.Tensor, int],\n        C: Union[torch.Tensor, None] = None,\n        lower: bool = True):\n    \"\"\"\n    Takes batches and concretizes them\n    @param x_hat: shape (batch, input_dim)                  The origin position of the input domain\n    @param x_eps: shape (batch, input_dim)                  The epsilon disturbance from the origin of the input domain\n    @param lA: shape (batch, spec_dim/lA rows, input_dim)   The lA matrix calculated by CROWN; When C is None, we refer\n                                                            to the second dimension as spec_dim. 
When C is given, this\n                                                            is denoted as lA rows\n    @param lbias: shape (batch, spec_dim)                   The bias vector calculated by CROWN\n    @param lower:                                           Whether the lower or upper bound should be concretized\n    @param C: shape (batch, spec_dim, lA rows)              When not None, is transposed and distributed to lA and lbias\n                                                            to produce the specification of interest\n    @return:                                                The lower/upper bound of the batches\n    \"\"\"\n    lA = lA.view(lA.shape[0], lA.shape[1], -1)\n    batches, spec_dim, input_dim = lA.shape\n    if isinstance(lbias, int):\n        lbias = _tensor([lbias]).expand(batches, spec_dim)\n    lbias = lbias.unsqueeze(-1)  # change lbiases to be column vectors\n    if C is not None:\n        # Let C act like the new last linear layer of the network and distribute it to lA and lbias\n        # Update shapes\n        C = C.reshape(batches, spec_dim, -1)\n        C = C.transpose(1, 2)\n        lA = C.bmm(lA)\n        lbias = C.bmm(lbias)\n        batches, spec_dim, input_dim = lA.shape\n    # lA shape: (batch, spec_dim, # inputs)\n    # dom_lb shape: (batch, spec_dim)\n    # thresholds shape: (batch, spec_dim)\n    # lbias shape: (batch, spec_dim, 1)\n\n    sign = -1 if lower else 1\n    x_hat = x_hat.unsqueeze(-1)\n    x_eps = x_eps.unsqueeze(-1)\n\n    ret = lA.bmm(x_hat) + sign * lA.abs().bmm(x_eps) + lbias\n\n    return ret.squeeze(2)\n\ndef setup_test_matrices(\n        A_bar_base: Tensor,\n        x_L_base: Tensor,\n        x_U_base: Tensor,\n        l_bias_base: Tensor,\n        target_base: Tensor,\n        batches: int\n) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n    \"\"\"\n    Creates batch copies of base Tensors and formats them in the same format that they would be in CROWN.\n    @param A_bar_base: 
shape (spec_dim, input_dim)  The lA matrix of the instance\n    @param x_L_base: shape (input_dim,)             The lower bound on the input domain\n    @param x_U_base: shape (input_dim,)             The upper bound on the input domain\n    @param l_bias_base: shape (spec_dim,)           The bias vector of the instance\n    @param target_base: shape (spec_dim,)           The threshold/specification to verify\n    @param batches:                                 The number of batch copies to produce of the instance\n    @return:                                        Returns same instance in batch form\n    \"\"\"\n    # create the copies\n    lA, x_L, x_U, c_bar, thresholds = create_batch_copies(A_bar_base, x_L_base, x_U_base, l_bias_base, target_base,\n                                                          batches)\n\n    # This is how x_L, x_U, lbias will be received in CROWN\n    # x_L/U shape: (batch, # inputs)\n    # lA shape: (batch, spec_dim, # inputs)\n    # dom_lb shape: (batch, spec_dim)\n    # thresholds shape: (batch, spec_dim)\n    x_L = x_L.flatten(1)\n    x_U = x_U.flatten(1)\n    c_bar = c_bar.squeeze(-1)\n    thresholds = thresholds.squeeze(-1)\n\n    # get the global lb\n    x_hat = (x_U + x_L) / 2\n    x_eps = (x_U - x_L) / 2\n    dm_lb = concretize_bounds(x_hat, x_eps, lA, c_bar)\n\n    return lA, x_L, x_U, c_bar, thresholds, dm_lb\n\ndef create_batch_copies(\n        A_bar_base: Tensor,\n        x_L_base: Tensor,\n        x_U_base: Tensor,\n        l_bias_base: Tensor,\n        target_base: Tensor,\n        batches: int\n) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n    \"\"\"\n    Takes a problem not in batch form and turns them into batches.\n    If batches = 1, we only solve the initial problem in batch form, and if batches > 1, we are solving the same\n    problem but in multiple batches.\n    @param A_bar_base:\n    @param x_L_base:\n    @param x_U_base:\n    @param l_bias_base:\n    @param target_base:\n    @param batches:\n    
@return:\n    \"\"\"\n    A_bar = A_bar_base.unsqueeze(0).repeat(batches, 1, 1)\n    x_L = x_L_base.unsqueeze(0).repeat(batches, 1)\n    x_U = x_U_base.unsqueeze(0).repeat(batches, 1)\n    l_bias = l_bias_base.unsqueeze(0).repeat(batches, 1, 1)\n    target = target_base.unsqueeze(0).repeat(batches, 1, 1)\n\n    return A_bar, x_L, x_U, l_bias, target\n\n\ndef random_setup_generator(\n        randint_range=(1, 10),\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:\n    \"\"\"\n    Creates random problem set-ups to test out if our new heuristic is compatible with various dimensions\n    @param randint_range:   A range where batches, spec_dim, and input_dim will exist in\n    @return:\n    \"\"\"\n    batches, spec_dim, input_dim = randint(*randint_range), randint(*randint_range), randint(*randint_range)\n    lA = torch.rand((batches, spec_dim, input_dim))\n    lbias = torch.rand((batches, spec_dim, 1))\n    thresholds = torch.rand((batches, spec_dim, 1))\n    parameters = {\n        \"batches\": batches,\n        \"spec_dim\": spec_dim,\n        \"input_dim\": input_dim\n    }\n    return lA, lbias, thresholds, parameters\n"
  },
  {
    "path": "tests/test_constant.py",
    "content": "\"\"\"Test BoundConstant\"\"\"\nimport torch\nimport os\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass cnn_MNIST(nn.Module):\n    def __init__(self):\n        super(cnn_MNIST, self).__init__()\n        self.conv1 = nn.Conv2d(1, 8, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(8, 16, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(784, 256)\n        self.fc2 = nn.Linear(256, 10)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        x = x.view(-1, 784)\n        x = 2.0 * x\n        x = F.relu(self.fc1(x))\n        x = self.fc2(x)\n        return 0.5 * x\n\nclass TestConstant(TestCase):\n    \n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name='constant_test_data',\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        model = cnn_MNIST()\n        checkpoint = torch.load(\"../examples/vision/pretrained/mnist_cnn_small.pth\", map_location=self.default_device)\n        model.load_state_dict(checkpoint)\n\n        N = 2\n        n_classes = 10\n        image = torch.randn(N, 1, 28, 28)\n        image = image.to(device=self.default_device,\n                         dtype=self.default_dtype) / 255.0\n\n        model = BoundedModule(model, torch.empty_like(image), device=self.default_device)\n        eps = 0.3\n        norm = np.inf\n        ptb = PerturbationLpNorm(norm=norm, eps=eps)\n        image = BoundedTensor(image, ptb)\n        pred = model(image)\n        lb, ub = model.compute_bounds()\n\n        assert lb.shape == ub.shape == torch.Size((2, 10))\n\n        self.result = (lb, ub)\n        if 
self.reference:\n            self.reference = (\n                self.reference[0].to(\n                    device=self.default_device, dtype=self.default_dtype),\n                self.reference[1].to(\n                    device=self.default_device, dtype=self.default_dtype)\n            )\n\n        self.rtol = 5e-4\n        self.check()\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference results\n    testcase = TestConstant(generate=False)\n    testcase.setUp()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_constrained_concretize.py",
    "content": "\"\"\"Test concretization with additional input constraints.\"\"\"\nimport torch\nimport numpy as np\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\n\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass ConstrainedConcretizeModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.w1 = torch.tensor([[1., -1.], [2., -1.]])\n        self.w2 = torch.tensor([[1., -1.]])\n\n    def forward(self, x):\n        z1 = x.matmul(self.w1.t())\n        hz1 = torch.nn.functional.relu(z1)\n        z2 = hz1.matmul(self.w2.t())\n        return z2\n\nclass TestConstrainedConcretize(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, 1, \"test_constrained_concretize\", generate, device=device, dtype=dtype)\n\n    def test(self):\n        model = ConstrainedConcretizeModel().to(self.default_device).to(self.default_dtype)\n        # Input x.\n        x = torch.tensor([[1., 1.]], dtype=self.default_dtype, device=self.default_device)\n        # Lower and upper bounds of x.\n        lower = torch.tensor([[-1., -2.]], dtype=self.default_dtype, device=self.default_device)\n        upper = torch.tensor([[2., 1.]], dtype=self.default_dtype, device=self.default_device)\n\n        # Wrap model with auto_LiRPA for bound computation.\n        # The second parameter is for constructing the trace of the computational graph,\n        # and its content is not important.\n\n        lirpa_model = BoundedModule(model, torch.empty_like(x))\n        pred = lirpa_model(x)\n        print(f'Model prediction: {pred.item()}')\n\n        # Compute bounds with LiRPA using the given lower and upper bounds.\n        norm = float(\"inf\")\n        ptb = PerturbationLpNorm(norm = norm, x_L=lower, x_U=upper)\n        bounded_x = BoundedTensor(x, ptb)\n\n        # Compute bounds.\n        lb, ub = 
lirpa_model.compute_bounds(x=(bounded_x,), method='CROWN')\n        print(f'CROWN bounds: lower={lb.item()}, upper={ub.item()}')\n\n        # Add a new constraint:\n        #    1*x_0 + 1*x_1 + 2 <= 0\n        constraint_a = torch.tensor([[[1.0, 1.0]]], dtype=self.default_dtype, device=self.default_device)\n        constraint_b = torch.tensor([[2.0]], dtype=self.default_dtype, device=self.default_device)\n        constraints = (constraint_a, constraint_b)\n\n        norm = float(\"inf\")\n        ptb = PerturbationLpNorm(norm = norm, x_L=lower, x_U=upper, constraints=constraints)\n        bounded_x = BoundedTensor(x, ptb)\n        # Compute bounds.\n        constrained_lb, constrained_ub = lirpa_model.compute_bounds(x=(bounded_x,), method='CROWN')\n        print(f'CROWN bounds (with constraints): lower={constrained_lb.item()}, upper={constrained_ub.item()}')\n\n        self.result = (lb, ub, constrained_lb, constrained_ub)\n        self.check()\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference results\n    testcase = TestConstrainedConcretize(generate=False)\n    testcase.setUp()\n    testcase.test()"
  },
  {
    "path": "tests/test_conv.py",
    "content": "import torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass cnn_model(nn.Module):\n    def __init__(self, layers, padding, stride, linear):\n        super(cnn_model, self).__init__()\n        self.module_list = []\n        channel = 1\n        length = 28\n        for i in range(layers):\n            self.module_list.append(nn.Conv2d(channel, 3, 4, stride = stride, padding = padding))\n            channel = 3\n            length = (length + 2 * padding - 4)//stride + 1\n            assert length > 0\n            self.module_list.append(nn.ReLU())\n        self.module_list.append(nn.Flatten())\n        if linear:\n            self.module_list.append(nn.Linear(3 * length * length, 256))\n            self.module_list.append(nn.Linear(256, 10))\n        self.model = nn.Sequential(*self.module_list)\n\n    def forward(self, x):\n        x = self.model(x)\n        return x\n\nclass TestConv(TestCase):\n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name=None,\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        models = [1, 2, 3]\n        paddings = [1, 2]\n        strides = [1, 3]\n\n        N = 2\n        n_classes = 10\n        image = torch.randn(N, 1, 28, 28, dtype=self.default_dtype, device=self.default_device)\n        image = image / 255.0\n\n        for layer_num in models:\n            for padding in paddings:\n                for stride in strides:\n                    for linear in [True, False]:\n                        model_ori = cnn_model(layer_num, padding, stride, linear)\n                        print('Model:', model_ori)\n                        model_ori = model_ori.to(\n                            
device=self.default_device, dtype=self.default_dtype)\n\n                        model = BoundedModule(model_ori, image, device=self.default_device, bound_opts={\"conv_mode\": \"patches\"})\n                        eps = 0.3\n                        ptb = PerturbationLpNorm(x_L=image-eps, x_U=image+eps)\n                        image = BoundedTensor(image, ptb)\n                        pred = model(image)\n                        lb, ub = model.compute_bounds()\n\n                        model = BoundedModule(model_ori, image, device=self.default_device, bound_opts={\"conv_mode\": \"matrix\"})\n                        pred = model(image)\n                        lb_ref, ub_ref = model.compute_bounds()\n\n                        if linear:\n                            assert lb.shape == ub.shape == torch.Size((N, n_classes))\n                        self.assertEqual(lb, lb_ref)\n                        self.assertEqual(ub, ub_ref)\n\n                        if not linear and layer_num == 1:\n                            pred = model(image)\n                            lb_forward, ub_forward = model.compute_bounds(method='forward')\n                            self.assertEqual(lb, lb_forward)\n                            self.assertEqual(ub, ub_forward)\n                            pred = model(image)\n                            lb_forward, ub_forward = model.compute_bounds(method='dynamic-forward+backward')\n                            self.assertEqual(lb, lb_forward)\n                            self.assertEqual(ub, ub_forward)\n\nif __name__ == '__main__':\n    testcase = TestConv()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_conv1d.py",
    "content": "\"\"\"Test Conv1d.\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass Model(nn.Module):\n    def __init__(self, kernel_size=2, stride=1, padding=0, in_features=1,out_features=1):\n        super(Model, self).__init__()\n        self.n_n_conv1d_1 = nn.Conv1d(**{'groups': 1, 'dilation': 1, 'out_channels': 1, 'padding': padding, 'kernel_size': kernel_size, 'stride': stride, 'in_channels': 1, 'bias': True})\n        self.n_n_conv1d_2 = nn.Conv1d(**{'groups': 1, 'dilation': 1, 'out_channels': 1, 'padding': padding, 'kernel_size': kernel_size, 'stride': stride, 'in_channels': 1, 'bias': True})\n        self.relu_2 = nn.ReLU()\n        self.n_n_conv1d_3 = nn.Conv1d(**{'groups': 1, 'dilation': 1, 'out_channels': 1, 'padding': padding, 'kernel_size': kernel_size, 'stride': stride, 'in_channels': 1, 'bias': True})\n        self.relu_3 = nn.ReLU()\n        self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1})\n        L_in, dilation = in_features, 1\n        L_out_1 = math.floor((L_in+2*padding-dilation*(kernel_size-1)-1)/stride+1)\n        L_out_2 = math.floor((L_out_1+2*padding-dilation*(kernel_size-1)-1)/stride+1)\n        L_out_3 = math.floor((L_out_2+2*padding-dilation*(kernel_size-1)-1)/stride+1)\n        self.n_n_linear = nn.Linear(**{'in_features':L_out_3, 'out_features':out_features,'bias':True})\n\n    def forward(self, *inputs,debug=False):\n        t_ImageInputLayer, = inputs\n        t_conv1d_1 = self.n_n_conv1d_1(t_ImageInputLayer)\n        if debug: print(\"t_ImageInputLayer\",t_ImageInputLayer.shape)\n        if debug: print(\"t_conv1d_1\",t_conv1d_1.shape)\n        t_conv1d_relu_1 = F.relu(t_conv1d_1)\n        t_conv1d_2 = self.n_n_conv1d_2(t_conv1d_relu_1)\n        if debug: print(\"t_conv1d_2\",t_conv1d_2.shape)\n        
t_conv1d_relu_2 = F.relu(t_conv1d_2)\n        t_conv1d_3 = self.n_n_conv1d_3(t_conv1d_relu_2)\n        if debug: print(\"t_conv1d_3\",t_conv1d_3.shape)\n        t_conv1d_relu_3 = F.relu(t_conv1d_3)\n        t_flatten = self.n_n_activation_Flatten(t_conv1d_relu_3)\n        if debug: print(\"t_flatten\",t_flatten.shape)\n        t_linear = self.n_n_linear(t_flatten)        \n        if debug: print(\"t_linear\",t_linear.shape)\n        return t_linear\n\nclass TestConv1D(TestCase):\n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name=None,\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        if self.default_dtype == torch.float64:\n            data_path = 'data_64/'\n        else:\n            data_path = 'data/'\n\n        N = 3\n        C = 1\n        M = 173\n        n_classes = 2\n        for kernel_size in [3,4]:\n            for padding in [0,1]:\n                for stride in [2,3]:\n                    print(kernel_size, padding, stride)\n\n                    model_ori = Model(kernel_size=kernel_size, padding=padding, stride=stride, in_features=M,out_features=n_classes)\n                    model_ori = model_ori.to(dtype=self.default_dtype, device=self.default_device)\n                    if not self.generate:\n                        data = torch.load(data_path + 'conv1d_test_data_{}-{}-{}'.format(kernel_size, padding, stride), weights_only=False)\n                        image = data['input'].to(dtype=self.default_dtype, device=self.default_device)\n                        model_ori(image)\n                        model_ori.load_state_dict(data['model'])\n                    else:\n                        image = torch.rand([N, C, M], dtype=self.default_dtype, device=self.default_device)\n                        model_ori(image)\n\n\n                    conv_mode = 
\"matrix\"\n\n                    model = BoundedModule(model_ori, image, device=self.default_device, bound_opts={\"conv_mode\": conv_mode})\n                    eps = 0.3\n                    norm = np.inf\n                    ptb = PerturbationLpNorm(norm=norm, eps=eps)\n                    image = BoundedTensor(image, ptb)\n                    lb, ub, A = model.compute_bounds((image,), return_A=True, needed_A_dict={model.output_name[0]:model.input_name[0]},)\n                    '''\n                    # 1. testing if lb == ub == pred when eps = 0\n                    assert (lb == ub).all() and torch.allclose(lb,pred,rtol=1e-5) and torch.allclose(ub,pred,rtol=1e-5)\n                    # 2. test if A matrix equals to gradient of the input\n                    # get output's grad with respect to the input without iterating through torch.autograd.grad:\n                    # https://stackoverflow.com/questions/64988010/getting-the-outputs-grad-with-respect-to-the-input\n                    uA = A[model.output_name[0]][model.input_name[0]]['uA']\n                    lA = A[model.output_name[0]][model.input_name[0]]['lA']\n                    assert (uA==lA).all()\n                    assert (torch.autograd.functional.jacobian(model_ori,image_clean).sum(dim=2)==uA).all()\n                    assert (torch.autograd.functional.jacobian(model_ori,image_clean).sum(dim=2)==lA).all()\n                    # double check\n                    input_grads = torch.zeros(uA.shape)\n                    for i in range(N):\n                        for j in range(n_classes):\n                            input_grads[i][j]=torch.autograd.grad(outputs=output_clean[i,j], inputs=image_clean, retain_graph=True)[0].sum(dim=0)\n                    assert (input_grads==uA).all()\n                    assert (input_grads==lA).all()\n                    '''\n                    # 3. 
test when eps = 0.3 (uncommented)\n                    if self.generate:\n                        torch.save(\n                            {'model': model_ori.state_dict(),\n                            'input': image,\n                            'lb': lb,\n                            'ub': ub}, data_path + '/conv1d_test_data_{}-{}-{}'.format(kernel_size, padding, stride)\n                        )\n\n                    if not self.generate:\n                        lb_ref = data['lb']\n                        ub_ref = data['ub']\n                        assert torch.allclose(lb, lb_ref, 1e-3)\n                        assert torch.allclose(ub, ub_ref, 1e-3)\n\n\nif __name__ == '__main__':\n    testcase = TestConv1D(generate=False)\n    testcase.test()\n"
  },
  {
    "path": "tests/test_distinct_patches.py",
    "content": "import torch\nimport random\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nimport sys\nsys.path.append('../examples/vision')\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\ndef reset_seed(seed=1234):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    np.random.seed(seed)\n\n\nclass cnn_4layer_b(nn.Module):\n    def __init__(self, paddingA, paddingB):\n        super().__init__()\n        self.padA = nn.ZeroPad2d(paddingA)\n        self.padB = nn.ZeroPad2d(paddingB)\n\n        self.conv1 = nn.Conv2d(3, 32, (5,5), stride=2, padding=0)\n        self.conv2 = nn.Conv2d(32, 128, (4,4), stride=2, padding=1)\n\n        self.linear = None\n        self.fc = nn.Linear(250, 10)\n\n    def forward(self, x):\n        x = self.padA(x)\n        x = self.conv1(x)\n        x = self.conv2(self.padB(F.relu(x)))\n        x = F.relu(x)\n        x = x.view(x.size(0), -1)\n        if self.linear is None:\n            self.linear = nn.Linear(x.size(1), 250)\n        x = self.linear(x)\n        return self.fc(F.relu(x))\n\n\nclass TestDistinctPatches(TestCase):\n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):        \n        super().__init__(methodName,\n            seed=1234, ref_name='distinct_patches_test_data',\n            generate=generate,\n            device=device, dtype=dtype)\n\n        self.cases = [(2,1,2,1), (0,0,0,0), (1,3,3,1), (2,2,3,1)]\n\n        normalize = torchvision.transforms.Normalize(\n            mean = [0.4914, 0.4822, 0.4465],\n            std = [0.2023, 0.1994, 0.2010]\n        )\n        test_data = torchvision.datasets.CIFAR10(\n            \"./data\", train=False, download=True,\n            transform=torchvision.transforms.Compose([\n              
  torchvision.transforms.ToTensor(),\n                normalize\n            ])\n        )\n        imgs = torch.from_numpy(test_data.data[:1]).reshape(1,3,32,32).float() / 255.0\n        self.single_img = imgs.to(dtype=self.default_dtype, device=self.default_device)\n\n    def run_conv_mode(self, model, img, conv_mode):\n        model(img)  # dummy run to initialize shapes\n        model_lirpa = BoundedModule(\n            model, img, device=self.default_device,\n            bound_opts={\"conv_mode\": conv_mode}\n        )\n        ptb = PerturbationLpNorm(norm = np.inf, eps = 0.03)\n        img_perturbed = BoundedTensor(img, ptb)\n\n        lb, ub = model_lirpa.compute_bounds(\n            x=(img_perturbed,), IBP=False, C=None, method='backward'\n        )\n        return lb, ub\n\n    def test(self):\n        self.result = []\n        for paddingA in self.cases:\n            for paddingB in self.cases:\n                print(\"Testing\", paddingA, paddingB)\n                reset_seed()\n                model_ori = cnn_4layer_b(paddingA, paddingB).to(\n                    device=self.default_device, dtype=self.default_dtype\n                )\n\n                lb_patch, ub_patch = self.run_conv_mode(\n                    model_ori, self.single_img, conv_mode='patches'\n                )\n                self.result.append((lb_patch, ub_patch))\n\n                if self.generate:\n                    # We only compare with matrix mode when generating reference results\n                    lb_matrix, ub_matrix = self.run_conv_mode(\n                        model_ori, self.single_img, conv_mode='matrix'\n                    )\n                    # Check equality\n                    assert torch.allclose(lb_patch, lb_matrix), \"Lower bounds differ!\"\n                    assert torch.allclose(ub_patch, ub_matrix), \"Upper bounds differ!\"\n\n        self.check()\n\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference 
results\n    testcase = TestDistinctPatches(generate=False)\n    testcase.test()"
  },
  {
    "path": "tests/test_examples.py",
    "content": "\"\"\"Test all the examples before release.\n\nThis script is expected to be manually run and is not used in automatic tests.\"\"\"\n\nimport pytest\nimport subprocess\nimport os\nimport sys\nimport shlex\n\n\npytest_skip = pytest.mark.skip(\n    reason=\"It should be tested on a GPU server and excluded from CI\")\n\nif 'CACHE_DIR' not in os.environ:\n    cache_dir = os.path.join(os.getcwd(), '.cache')\nelse:\n    cache_dir = os.environ['CACHE_DIR']\nif not os.path.exists(cache_dir):\n    os.makedirs(cache_dir)\n\ndef download_data_language():\n    url = \"http://download.huan-zhang.com/datasets/language/data_language.tar.gz\"\n    if not os.path.exists('../examples/language/data/sst'):\n        subprocess.run(shlex.split(f\"wget {url}\"), cwd=\"../examples/language\")\n        subprocess.run(shlex.split(\"tar xvf data_language.tar.gz\"),\n            cwd=\"../examples/language\")\n\n@pytest_skip\ndef test_transformer():\n    cmd = f\"\"\"python train.py --dir {cache_dir} --robust\n        --method IBP+backward_train --train --num_epochs 2 --num_epochs_all_nodes 2\n        --eps_start 2 --eps_length 1 --eps 0.1\"\"\"\n    print(cmd, file=sys.stderr)\n    download_data_language()\n    subprocess.run(shlex.split(cmd), cwd='../examples/language')\n\n@pytest_skip\ndef test_lstm():\n    cmd = f\"\"\"python train.py --dir {cache_dir}\n        --model lstm --lr 1e-3 --dropout 0.5 --robust\n        --method IBP+backward_train --train --num_epochs 2 --num_epochs_all_nodes 2\n        --eps_start 2 --eps_length 1 --eps 0.1\n        --hidden_size 2 --embedding_size 2 --intermediate_size 2 --max_sent_length 4\"\"\"\n    print(cmd, file=sys.stderr)\n    download_data_language()\n    subprocess.run(shlex.split(cmd), cwd='../examples/language')\n\n@pytest_skip\ndef test_lstm_seq():\n    cmd = f\"\"\"python train.py --dir {cache_dir}\n        --hidden_size 2 --num_epochs 2 --num_slices 4\"\"\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), 
cwd='../examples/sequence')\n\n@pytest_skip\ndef test_simple_verification():\n    cmd = \"python simple_verification.py\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_custom_op():\n    cmd = \"python custom_op.py\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_efficient_convolution():\n    cmd = \"python efficient_convolution.py\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_two_node():\n    cmd = \"python verify_two_node.py\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_simple_training():\n    cmd = \"\"\"python simple_training.py\n        --num_epochs 5 --scheduler_opts start=2,length=2\"\"\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_cifar_training():\n    cmd = \"\"\"python cifar_training.py\n        --batch_size 64 --model ResNeXt_cifar\n        --num_epochs 5 --scheduler_opts start=2,length=2\"\"\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_weight_perturbation():\n    cmd = \"\"\"python weight_perturbation_training.py\n        --norm 2 --bound_type CROWN-IBP\n        --num_epochs 3 --scheduler_opts start=2,length=1 --eps 0.01\"\"\"\n    print(cmd, file=sys.stderr)\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_tinyimagenet():\n    cmd = f\"\"\"python tinyimagenet_training.py\n        --batch_size 32 --model wide_resnet_imagenet64\n        --num_epochs 3 --scheduler_opts start=2,length=1 --eps {0.1/255}\n        --in_planes 2 --widen_factor 2\"\"\"\n    print(cmd, file=sys.stderr)\n    if not 
os.path.exists('../examples/vision/data/tinyImageNet/tiny-imagenet-200'):\n        subprocess.run(shlex.split(\"bash tinyimagenet_download.sh\"),\n        cwd=\"../examples/vision/data/tinyImageNet\")\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n@pytest_skip\ndef test_imagenet():\n    cmd = f\"\"\"python imagenet_training.py\n        --batch_size 32 --model wide_resnet_imagenet64_1000class\n        --num_epochs 3 --scheduler_opts start=2,length=1 --eps {0.1/255}\n        --in_planes 2 --widen_factor 2\"\"\"\n    print(cmd)\n    if (not os.path.exists('../examples/vision/data/ImageNet64/train') or\n            not os.path.exists('../examples/vision/data/ImageNet64/test')):\n        print('Error: ImageNet64 dataset is not ready.')\n        return -1\n    subprocess.run(shlex.split(cmd), cwd='../examples/vision')\n\n\ndef test_release():\n    \"\"\"Run all tests that don't require a GPU server.\"\"\"\n    test_simple_verification()\n    test_custom_op()\n    test_efficient_convolution()\n    test_two_node()\n\nif __name__ == '__main__':\n    test_release()\n"
  },
  {
    "path": "tests/test_examples_ci.py",
    "content": "import subprocess\nimport traceback\n\nimport test_examples\n\noriginal_subprocess_run = subprocess.run\n\n\ndef custom_run(*args, **kwargs):\n    kwargs.setdefault('check', True)\n    return original_subprocess_run(*args, **kwargs)\n\n\nsubprocess.run = custom_run\n\n\ndef run_tests():\n    # Get all functions in test_examples whose names start with 'test', except\n    # 'test_release' and 'test_cifar_training' (which cannot run on a GPU with\n    # less than 32GB of memory).\n    test_functions = [\n        getattr(test_examples, func) for func in dir(test_examples)\n        if callable(getattr(test_examples, func)) and func.startswith('test')\n        and func not in ['test_release', 'test_cifar_training']\n    ]\n\n    try:\n        for test_func in test_functions:\n            test_func()\n            print(f\"{test_func.__name__} executed successfully.\")\n    except Exception as e:\n        print(f\"Exception in {test_func.__name__}: {e}\")\n        traceback.print_exc()  # Print detailed exception information\n\n        print(\"Examples Test Result:\")\n        print(\"\\nFailed tests:\")\n        print(test_func.__name__)\n        raise\n\n    print(\"Examples Test Result:\")\n    print(\"\\nAll tests passed successfully.\")\n\n\nif __name__ == '__main__':\n    run_tests()\n"
  },
  {
    "path": "tests/test_general_nonlinear.py",
    "content": "import sys\nimport pytest\nimport torch.nn as nn\n\nsys.path.insert(0, '../complete_verifier')\n\nimport arguments\nfrom beta_CROWN_solver import LiRPANet\nfrom bab import general_bab\n\nfrom auto_LiRPA import BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass Sin(nn.Module):\n    def forward(self, x):\n        return torch.sin(x)\n\n\ndef cifar_model_wide():\n    # cifar wide\n    model = nn.Sequential(\n        nn.Conv2d(3, 16, 4, stride=2, padding=1),\n        Sin(),\n        nn.Conv2d(16, 32, 4, stride=2, padding=1),\n        Sin(),\n        nn.Flatten(),\n        nn.Linear(32 * 8 * 8, 100),\n        Sin(),\n        nn.Linear(100, 10)\n    )\n    return model\n\n\ndef bab(model_ori, data, target, norm, eps, data_max=None, data_min=None, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n    data = data.to(device=device, dtype=dtype)\n    eps = eps.to(device=device, dtype=dtype)\n    if norm == np.inf:\n        if data_max is None:\n            data_ub = data + eps\n            data_lb = data - eps\n        else:\n            data_max = data_max.to(device=device, dtype=dtype)\n            data_min = data_min.to(device=device, dtype=dtype)\n            data_ub = torch.min(data + eps, data_max)\n            data_lb = torch.max(data - eps, data_min)\n    else:\n        data_ub = data_lb = data\n    pred = torch.argmax(model_ori(data), dim=1)\n\n    c = torch.zeros((1, 1, 10), device=device, dtype=dtype) # we only support c with shape of (1, 1, n)\n    c[0, 0, pred] = 1\n    c[0, 0, target] = -1\n    rhs = torch.tensor(arguments.Config[\"bab\"][\"decision_thresh\"], dtype=dtype, device=device).view(c.shape[:2])\n\n    arguments.Config.parse_config(args={})\n\n    arguments.Config['general']['device'] = 'cpu'\n    arguments.Config[\"solver\"][\"batch_size\"] = 200\n    arguments.Config[\"bab\"][\"decision_thresh\"] = np.float64(10)  # naive float obj has no max() function, np.inf will lead 
to an infeasible domain\n    arguments.Config[\"solver\"][\"beta-crown\"][\"iteration\"] = 20\n    arguments.Config[\"bab\"][\"timeout\"] = 60 #300\n\n    arguments.Config[\"solver\"][\"alpha-crown\"][\"lr_alpha\"] = 0.1\n    arguments.Config[\"solver\"][\"beta-crown\"][\"lr_beta\"] = 0.1\n    arguments.Config[\"bab\"][\"branching\"][\"method\"] = 'nonlinear'\n    arguments.Config[\"bab\"][\"branching\"][\"candidates\"] = 2\n    arguments.Config[\"general\"][\"enable_incomplete_verification\"] = False\n    arguments.Config[\"data\"][\"dataset\"] = 'cifar'\n\n    # LiRPA wrapper\n    model = LiRPANet(model_ori, device=device, in_size=(1, 3, 32, 32))\n\n    ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)\n    x = BoundedTensor(data, ptb)\n    forward = model_ori(x)\n\n    min_lb = general_bab(model, x, c, rhs)[0]\n\n    if isinstance(min_lb, torch.Tensor):\n        min_lb = min_lb.item()\n\n    min_lb += arguments.Config[\"bab\"][\"decision_thresh\"]\n    print(min_lb)\n\n    assert min_lb < torch.min(forward)\n\n# This test takes a long time so it is set as the last test case.\n@pytest.mark.skip(reason=\"The test is failing now after removing index clamping.\")\n# @pytest.mark.order(-1)\ndef test(device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n    model_ori = cifar_model_wide()\n    data = torch.load('data/beta_crown_test_data')\n    model_ori.load_state_dict(data['state_dict'])\n    model_ori = model_ori.to(device=device, dtype=dtype)\n    x = data['x']\n    pidx = data['pidx']\n    eps_temp = data['eps_temp']\n    data_max = data['data_max']\n    data_min = data['data_min']\n\n    bab(model_ori, x, pidx, float('inf'), eps_temp, data_max=data_max, data_min=data_min, device=device, dtype=dtype)\n\n\nif __name__ == \"__main__\":\n    test()\n"
  },
  {
    "path": "tests/test_general_shape.py",
    "content": "\"\"\" Test inputs of general shapes (especially for matmul)\"\"\"\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm\nfrom auto_LiRPA.operators import BoundMatMul\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nBATCH_SIZE = 2\n\nclass GeneralShapeModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.weight_1 = nn.Parameter(torch.randn(3, 4))\n        self.weight_2 = nn.Parameter(torch.randn(4, 3))\n        self.weight_3 = nn.Parameter(torch.randn(3, 4))\n        self.weight_4 = nn.Parameter(torch.randn(4, 4, 3))\n        self.weight_5 = nn.Parameter(torch.randn(6, 3, 4))\n        self.weight_6 = nn.Parameter(torch.randn(3, 5))\n        self.relu = nn.ReLU()\n        \n    def forward(self, x, w):\n        # Basic MatMul (B, 3) @ (3, 4) -> (B, 4)\n        y1 = x.matmul(self.weight_1)\n\n        # BoundUnsqueeze and BoundTile\n        y2 = self.relu(y1)\n        y2 = y2.unsqueeze(1).repeat(1, 5, 1)   # (B, 5, 4)\n        y2 = y2.matmul(self.weight_2)   # (B, 5, 4) @ (4, 3) -> (B, 5, 3)\n\n        # More dimensions on x\n        y3 = self.relu(y2)\n        y3 = y3.unsqueeze(1).repeat(1, 4, 1, 1)     # (B, 4, 5, 3)\n        y3 = y3.matmul(self.weight_3)   # (B, 4, 5, 3) @ (3, 4) -> (B, 4, 5, 4)\n\n        # More dimensions on weight\n        y4 = self.relu(y3)\n        y4 = y4.matmul(self.weight_4)   # (B, 4, 5, 4) @ (4, 4, 3) -> (B, 4, 5, 3)\n\n        # Automatically broadcast x\n        y5 = self.relu(y4)\n        y5 = y5.unsqueeze(2)   # (B, 4, 1, 5, 3)\n        y5 = y5.matmul(self.weight_5)   # (B, 4, 1, 5, 3) @ (6, 3, 4) -> (B, 4, 6, 5, 4)\n\n        # Multiply with a weight with batch dimension\n        y6 = self.relu(y5)\n        y6 = y6.matmul(w)   # (B, 4, 6, 5, 4) @ (B, 4, 6, 4, 3) -> (B, 4, 6, 5, 3)\n\n        # Swap x and weight\n        y7 = self.relu(y6)\n        y7 = self.weight_6.matmul(y7)   # (3, 5) @ (B, 
4, 6, 5, 3) -> (B, 4, 6, 3, 3)\n\n        return y7\n\nclass TestGeneralShape(TestCase):\n    def __init__(self, methodName='runTest', seed=1, generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed, 'test_general_shape_data', generate, device=device, dtype=dtype)\n        self.rtol = 1e-4\n\n    def test(self):\n        model = GeneralShapeModel().to(device=self.default_device, dtype=self.default_dtype)\n        input = torch.randn(\n            (BATCH_SIZE, 3), device=self.default_device, dtype=self.default_dtype)\n        eps = 100\n        ptb = PerturbationLpNorm(norm=np.inf, eps=eps)\n        x = BoundedTensor(input, ptb)\n        # w is an unperturbed input, but it still has a batch dimension\n        w = torch.randn((BATCH_SIZE, 4, 6, 4, 3),\n                        device=self.default_device, dtype=self.default_dtype)\n        lirpa_model = BoundedModule(model, (x, w), device=self.default_device)\n\n        lb, ub = lirpa_model.compute_bounds((x, w), method=\"backward\")\n\n        # # Test by sampling\n        # sample_ptb = torch.rand(BATCH_SIZE, *input.shape[1:]) * 2 * eps - eps\n        # sample_inputs = input[0] + sample_ptb\n        # sample_output = model(sample_inputs, w)\n        # assert (sample_output <= ub).all()\n        # assert (sample_output >= lb).all()\n\n        self.result = []\n        for node in lirpa_model.nodes():\n            if type(node) == BoundMatMul:\n                self.result.append((node.lower, node.upper))\n        self.result.append((lb, ub))\n\n        self.check()\n\nif __name__ == '__main__':\n    testcase = TestGeneralShape(generate=False)\n    testcase.setUp()\n    testcase.test()"
  },
  {
    "path": "tests/test_identity.py",
    "content": "\"\"\"Test a model with an nn.Identity layer only\"\"\"\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE \n\nclass TestIdentity(TestCase):\n    def __init__(self, methodName='runTest', device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n\n    def test(self):\n        model = nn.Sequential(nn.Identity())\n        x = torch.randn(2, 10, device=self.default_device,\n                        dtype=self.default_dtype)\n        y = model(x)\n        eps = 0.1\n        ptb = PerturbationLpNorm(norm=np.inf, eps=eps)\n        x = BoundedTensor(x, ptb)\n        model = BoundedModule(model, x, device=self.default_device)\n        y_l, y_u = model.compute_bounds()\n        self.assertEqual(torch.Tensor(x), y)\n        self.assertEqual(y_l, x - eps)\n        self.assertEqual(y_u, x + eps)\n\n\nif __name__ == '__main__':\n    testcase = TestIdentity()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_invprop.py",
    "content": "\"\"\"Test INVPROP.\"\"\"\nimport sys\nsys.path.append('../complete_verifier')\nfrom complete_verifier.load_model import unzip_and_optimize_onnx\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass SimpleExampleModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Weights of linear layers.\n        self.w1 = torch.tensor([[1., -1.], [2., -1.]])\n        self.w2 = torch.tensor([[1., -1.]])\n\n    def forward(self, x):\n        # Linear layer.\n        z1 = x.matmul(self.w1.t())\n        # Relu layer.\n        hz1 = torch.nn.functional.relu(z1)\n        # Linear layer.\n        z2 = hz1.matmul(self.w2.t())\n        return z2\n\nclass TestInvpropSimpleExample(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name=None,\n                         generate=generate, device=device, dtype=dtype)\n\n    def test(self):\n        np.random.seed(123)\n\n        model_ori = SimpleExampleModel().to(\n            device=self.default_device, dtype=self.default_dtype)\n\n        apply_output_constraints_to = ['BoundMatMul', 'BoundInput']\n        x = torch.tensor([[1., 1.]], device=self.default_device,\n                         dtype=self.default_dtype)\n        model = BoundedModule(model_ori, torch.empty_like(x), bound_opts={\n            'optimize_bound_args': {\n                'apply_output_constraints_to': apply_output_constraints_to,\n                'tighten_input_bounds': True,\n                'best_of_oc_and_no_oc': False,\n                'directly_optimize': [],\n                'oc_lr': 0.1,\n                'share_gammas': False,\n                'iteration': 1000,\n            }\n        },\n            device=self.default_device\n        )\n      
  model.constraints = torch.ones(\n            1, 1, 1, device=self.default_device, dtype=self.default_dtype)\n        model.thresholds = torch.tensor(\n            [-1.], device=self.default_device, dtype=self.default_dtype)\n\n        norm = float(\"inf\")\n        lower = torch.tensor(\n            [[-1., -2.]], device=self.default_device, dtype=self.default_dtype)\n        upper = torch.tensor(\n            [[2., 1.]], device=self.default_device, dtype=self.default_dtype)\n        ptb = PerturbationLpNorm(norm = norm, x_L=lower, x_U=upper)\n        bounded_x = BoundedTensor(x, ptb)\n\n        lb, ub = model.compute_bounds(x=(bounded_x,), method='alpha-CROWN')\n\n        if '/0' in model._modules:\n            tightened_ptb = model['/0'].perturbation\n        else:\n            tightened_ptb = model['/x'].perturbation\n\n        if self.default_dtype == torch.float64:\n            data_path = 'data_64/'\n        else:\n            data_path = 'data/'\n\n        if self.generate:\n            torch.save({\n                'lb': lb,\n                'ub': ub,\n                'x_L': tightened_ptb.x_L,\n                'x_U': tightened_ptb.x_U\n            }, data_path + 'invprop/simple_reference')\n        else:\n            data = torch.load(data_path + 'invprop/simple_reference')\n            lb_ref = data['lb']\n            ub_ref = data['ub']\n            x_L_ref = data['x_L']\n            x_U_ref = data['x_U']\n\n            assert torch.allclose(lb, lb_ref, 1e-4)\n            assert torch.allclose(ub, ub_ref, 1e-4)\n            assert torch.allclose(tightened_ptb.x_L, x_L_ref, 1e-4)\n            assert torch.allclose(tightened_ptb.x_U, x_U_ref, 1e-4)\n\nclass TestInvpropOODExample(TestCase):\n    # Based on https://github.com/kothasuhas/verify-input/tree/main/examples/ood\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed=1, ref_name=None,\n                     
    generate=generate, device=device, dtype=dtype)\n\n    def test(self):\n        np.random.seed(123)\n\n        import onnx2pytorch\n\n        model_ori = onnx2pytorch.ConvertModel(unzip_and_optimize_onnx('data/invprop/ood.onnx')).eval()\n        model_ori = model_ori.to(\n            device=self.default_device, dtype=self.default_dtype)\n\n        x = torch.tensor([[1., 1.]], device=self.default_device,\n                         dtype=self.default_dtype)\n        model = BoundedModule(model_ori, torch.empty_like(x), bound_opts={\n            'optimize_bound_args': {\n                'apply_output_constraints_to': ['BoundInput', \"/input\", \"/input-3\", \"/21\"],\n                'tighten_input_bounds': True,\n                'best_of_oc_and_no_oc': True,\n                'directly_optimize': ['/input'],\n                'oc_lr': 0.01,\n                'iteration': 1000,\n                'share_gammas': False,\n                'lr_decay': 0.99,\n                'early_stop_patience': 1000,\n                'init_alpha': False,\n                'lr_alpha': 0.4,\n                'start_save_best': -1,\n            }\n        },\n            device=self.default_device\n        )\n        model.constraints = torch.tensor(\n            [[[-1., 0., 1.]], [[0., -1., 1.]]], device=self.default_device, dtype=self.default_dtype)\n        model.thresholds = torch.tensor(\n            [0., 0.], device=self.default_device, dtype=self.default_dtype)\n\n        norm = float(\"inf\")\n        lower = torch.tensor(\n            [[-2., -2.], [-2., -2.]], device=self.default_device, dtype=self.default_dtype)\n        upper = torch.tensor(\n            [[0., 0.], [0., 0.]], device=self.default_device, dtype=self.default_dtype)\n        ptb = PerturbationLpNorm(norm = norm, x_L=lower, x_U=upper)\n        x_expand = BoundedTensor(torch.tensor(\n            [[-1., -1.], [-1., -1.]], device=self.default_device, dtype=self.default_dtype), ptb)\n        c = torch.tensor([[[-1.,  0.,  
1.]], [[0., -1.,  1.]]],\n                         device=self.default_device, dtype=self.default_dtype)\n\n        # Init manually, to set bound_upper=False\n        model.init_alpha(\n                (x_expand,), share_alphas=False, c=c, bound_upper=False)\n\n        model.compute_bounds(x=(x_expand,), C=c, method='CROWN-Optimized')\n\n        if self.default_dtype == torch.float64:\n            data_path = 'data_64/'\n        else:\n            data_path = 'data/'\n\n        if self.generate:\n            torch.save({\n                'lower': model['/input'].lower,\n                'upper': model['/input'].upper,\n            }, data_path + 'invprop/ood_reference')\n        else:\n            data = torch.load(data_path + 'invprop/ood_reference')\n            lower_ref = data['lower']\n            upper_ref = data['upper']\n\n            lower_diff = model['/input'].lower[0] - lower_ref[0]\n            assert torch.allclose(model['/input'].lower[0], lower_ref[0], atol=1e-3), (lower_diff, lower_diff.abs().max())\n            assert torch.all(torch.isposinf(lower_ref[1]))\n            assert torch.all(torch.isposinf(model['/input'].lower[1]))\n            upper_diff = model['/input'].upper[0] - upper_ref[0]\n            assert torch.allclose(model['/input'].upper[0], upper_ref[0], atol=1e-3), (upper_diff, upper_diff.abs().max())\n            assert torch.all(torch.isneginf(upper_ref[1]))\n            assert torch.all(torch.isneginf(model['/input'].upper[1]))\n\n\nif __name__ == '__main__':\n    testcase = TestInvpropSimpleExample(generate=False)\n    testcase.test()\n    testcase = TestInvpropOODExample(generate=False)\n    testcase.test()\n"
  },
  {
    "path": "tests/test_jacobian.py",
    "content": "# pylint: disable=wrong-import-position\n\"\"\"Test Jacobian bounds.\"\"\"\nimport sys\nimport torch\nimport torch.nn as nn\n\nsys.path.append('../examples/vision')\nfrom jacobian import compute_jacobians\nfrom auto_LiRPA import BoundedModule\nfrom auto_LiRPA.utils import Flatten\nfrom auto_LiRPA.jacobian import JacobianOP\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass TestJacobian(TestCase):\n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(\n            methodName, seed=1, ref_name='jacobian_test_data',\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        in_dim, linear_size = 8, 100\n        model = nn.Sequential(\n            Flatten(),\n            nn.Linear(3*in_dim**2, linear_size),\n            nn.ReLU(),\n            nn.Linear(linear_size, linear_size),\n            nn.Tanh(),\n            nn.Linear(linear_size, linear_size),\n            nn.Sigmoid(),\n            nn.Linear(linear_size, 10),\n        )\n        model = model.to(device=self.default_device, dtype=self.default_dtype)\n        x0 = torch.randn(1, 3, in_dim, in_dim,\n                         device=self.default_device, dtype=self.default_dtype)\n        self.result = compute_jacobians(model, x0)\n        self.check()\n\n    def test_concat_jacobian(self):\n        '''\n        Test JacobianOP with Concat operation. This needs some special handling\n        in auto_LiRPA to make it work properly. 
(See parse_graph.py for details.)\n        '''\n        class ConcatModule(nn.Module):\n            def forward(self, x):\n                return JacobianOP.apply(torch.cat([x, x], dim=1), x)\n        concatmodel = ConcatModule().to(device=self.default_device, dtype=self.default_dtype)\n        x0 = torch.randn(1, 5, device=self.default_device, dtype=self.default_dtype)\n        BoundedModule(concatmodel, x0)\n        print('Concat JacobianOP test passed.')\n\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference results\n    testcase = TestJacobian(generate=False)\n    testcase.setUp()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_language_models.py",
"content": "\"\"\"Test classes for Transformer and LSTM on language tasks\"\"\"\nimport os\nimport argparse\nimport pickle\nimport torch\nfrom auto_LiRPA.utils import logger\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--gen_ref', action='store_true', help='generate reference results')\nparser.add_argument('--train', action='store_true', help='pre-train the models')\nparser.add_argument('--keep_results', action='store_true', help='keep intermediate results.')\nparser.add_argument('--load_results', action='store_true', help='load intermediate results without rerunning.')\nargs, unknown = parser.parse_known_args()\n\ndef prepare_data():\n    os.system('cd ../examples/language;\\\n        wget http://download.huan-zhang.com/datasets/language/data_language.tar.gz;\\\n        tar xvf data_language.tar.gz')\n\ncmd_transformer_train = 'cd ../examples/language; \\\n    DIR=model_transformer_test; \\\n    python train.py --hidden_size=16 --embedding_size=16 --intermediate_size=16 --max_sent_length=16 \\\n    --dir=$DIR --robust --method=IBP+backward_train \\\n    --num_epochs=2 --num_epochs_all_nodes=1 --eps_start=2 --train'\ncmd_transformer_test = 'cd ../examples/language; \\\n    python train.py --hidden_size=16 --embedding_size=16 --intermediate_size=16 --max_sent_length=16 \\\n    --robust --method=IBP+backward --budget=1 --auto_test --eps=0.2 --load=../../tests/data/ckpt_transformer \\\n    --device=cpu'\ncmd_lstm_train = 'cd ../examples/language; \\\n    DIR=model_lstm_test; \\\n    python train.py  --hidden_size=16 --embedding_size=16 --max_sent_length=16 \\\n    --dir=$DIR --model=lstm --lr=1e-3 --robust --method=IBP+backward_train --dropout=0.5 \\\n    --num_epochs=2 --num_epochs_all_nodes=1 --eps_start=2 --train'\ncmd_lstm_test = 'cd ../examples/language; \\\n    python train.py --model=lstm --hidden_size=16 --embedding_size=16 --max_sent_length=16 \\\n    --robust --method=IBP+backward --budget=1 --auto_test --eps=0.2 
--load=../../tests/data/ckpt_lstm \\\n    --device=cpu'\nres_path = '../examples/language/res_test.pkl'\n\n\"\"\"Pre-train a simple Transformer and LSTM respectively\"\"\"\ndef train():\n    if os.path.exists(\"../examples/language/model_transformer_test\"):\n        os.system(\"rm -rf ../examples/language/model_transformer_test\")\n    if os.path.exists(\"../examples/language/model_lstm_test\"):\n        os.system(\"rm -rf ../examples/language/model_lstm_test\")\n    logger.info(\"\\nTraining a Transformer\")\n    print(cmd_transformer_train)\n    print()\n    os.system(cmd_transformer_train)\n    os.system(\"cp ../examples/language/model_transformer_test/ckpt_2 data/ckpt_transformer\")\n    logger.info(\"\\nTraining an LSTM\")\n    print(cmd_lstm_train)\n    print()\n    os.system(cmd_lstm_train)\n    os.system(\"cp ../examples/language/model_lstm_test/ckpt_2 data/ckpt_lstm\")\n\ndef read_res():\n    with open(res_path, 'rb') as file:\n        return pickle.load(file)\n\ndef evaluate():\n    if args.load_results:\n        print(\"loading intermediate results...\")\n        with open(\"./tmp_language_results.pkl\", \"rb\") as file:\n            return pickle.load(file)\n    logger.info('\\nEvaluating the trained LSTM')\n    print(cmd_lstm_test)\n    print()\n    os.system(cmd_lstm_test)\n    res_lstm = read_res()\n    logger.info('\\nEvaluating the trained Transformer')\n    print(cmd_transformer_test)\n    print()\n    os.system(cmd_transformer_test)\n    res_transformer = read_res()\n    os.system(\"rm {}\".format(res_path))\n    if args.keep_results:\n        with open(\"./tmp_language_results.pkl\", \"wb\") as file:\n            pickle.dump((res_transformer, res_lstm), file)\n        print(\"intermediate results saved.\")\n    return res_transformer, res_lstm\n\ndef gen_ref():\n    if args.train:\n        train()\n    res_transformer, res_lstm = evaluate()\n    with open('data/language_test_data', 'wb') as file:\n        pickle.dump((res_transformer, 
res_lstm), file)\n    logger.info('Reference results saved')\n\ndef check():\n    with open('data/language_test_data', 'rb') as file:\n        res_transformer_ref, res_lstm_ref = pickle.load(file)\n    res_transformer, res_lstm = evaluate()\n    for res, res_ref in zip([res_transformer, res_lstm], [res_transformer_ref, res_lstm_ref]):\n        for a, b in zip(res, res_ref):\n            ta, tb = torch.tensor(a), torch.tensor(b)\n            diff = torch.max(torch.abs(ta - tb))\n            assert diff < 1e-5, diff\n            assert (torch.tensor(a) - torch.tensor(b)).pow(2).sum() < 1e-9\n\ndef test():\n    if not os.path.exists('../examples/language/data'):\n        prepare_data()\n    if args.gen_ref:\n        gen_ref()\n    else:\n        check()\n    logger.info(\"test_Language done\")\n\nif __name__ == '__main__':\n    test()\n"
  },
  {
    "path": "tests/test_linear_cnn_model.py",
    "content": "\"\"\"Test bounds on a 1 layer CNN network.\"\"\"\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom test_linear_model import TestLinearModel\nfrom testcase import DEFAULT_DEVICE, DEFAULT_DTYPE\n\ninput_dim = 8\nout_channel = 2\nN = 10\n\nclass LinearCNNModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = nn.Conv2d(1, out_channel, 3, stride=2, padding=1)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = x.view(-1, input_dim //2 * input_dim // 2 * out_channel)\n        return x\n\nclass TestLinearCNNModel(TestLinearModel):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n        self.original_model = LinearCNNModel().to(device=device, dtype=dtype)\n\n    def compute_and_compare_bounds(self, eps, norm, IBP, method):\n        input_data = torch.randn((N, 1, input_dim, input_dim))\n        model = BoundedModule(self.original_model, torch.empty_like(input_data), device=self.default_device)\n        ptb = PerturbationLpNorm(norm=norm, eps=eps)\n        ptb_data = BoundedTensor(input_data, ptb)\n        pred = model(ptb_data)\n        label = torch.argmax(pred, dim=1).cpu().detach().numpy()\n        # Compute bounds.\n        lb, ub = model.compute_bounds(IBP=IBP, method=method)\n        # Compute reference.\n        conv_weight, conv_bias = list(model.parameters())\n        conv_bias = conv_bias.view(1, out_channel, 1, 1)\n        matrix_eye = torch.eye(input_dim * input_dim).view(input_dim * input_dim, 1, input_dim, input_dim)\n        # Obtain equivalent weight and bias for convolution.\n        weight = self.original_model.conv(matrix_eye) - conv_bias # Output is (batch, channel, weight, height).\n        weight = weight.view(input_dim * input_dim, -1) # Dimension is (flattened_input, 
flattened_output).\n        bias = conv_bias.repeat(1, 1, input_dim // 2, input_dim // 2).view(-1)\n        flattened_data = input_data.view(N, -1)\n        # Compute dual norm.\n        if norm == 1:\n            q = np.inf\n        elif norm == np.inf:\n            q = 1.0\n        else:\n            q = 1.0 / (1.0 - (1.0 / norm))\n        # Manually compute bounds.\n        norm = weight.t().norm(p=q, dim=1)\n        expected_pred = flattened_data.matmul(weight) + bias\n        expected_ub = eps * norm + expected_pred\n        expected_lb = -eps * norm + expected_pred\n        # Check equivalence.\n        if method == 'backward' or method == 'forward':\n            self.rtol = 1e-4\n            self.assertEqual(expected_pred, pred)\n            self.assertEqual(expected_ub, ub)\n            self.assertEqual(expected_lb, lb)\n"
  },
  {
    "path": "tests/test_linear_model.py",
    "content": "\"\"\"Test bounds on a 1 layer linear network.\"\"\"\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nn_classes = 3\nN = 10\n\nclass LinearModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.fc = nn.Linear(256, n_classes)\n\n    def forward(self, x):\n        x = self.fc(x)\n        return x\n\nclass TestLinearModel(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed=0, device=device, dtype=dtype)\n        self.original_model = LinearModel().to(device=device, dtype=dtype)\n\n    def compute_and_compare_bounds(self, eps, norm, IBP, method):\n        input_data = torch.randn(\n            (N, 256), device=self.default_device, dtype=self.default_dtype)\n        model = BoundedModule(self.original_model, torch.empty_like(input_data), device=self.default_device)\n        ptb = PerturbationLpNorm(norm=norm, eps=eps)\n        ptb_data = BoundedTensor(input_data, ptb)\n        pred = model(ptb_data)\n        label = torch.argmax(pred, dim=1).cpu().detach().numpy()\n        # Compute bounds.\n        lb, ub = model.compute_bounds(IBP=IBP, method=method)\n        # Compute dual norm.\n        if norm == 1:\n            q = np.inf\n        elif norm == np.inf:\n            q = 1.0\n        else:\n            q = 1.0 / (1.0 - (1.0 / norm))\n        # Compute reference manually.\n        weight, bias = list(model.parameters())\n        norm = weight.norm(p=q, dim=1)\n        expected_pred = input_data.matmul(weight.t()) + bias\n        expected_ub = eps * norm + expected_pred\n        expected_lb = -eps * norm + expected_pred\n\n        # Check equivalence.\n        self.rtol = 1e-4\n        self.assertEqual(expected_pred, pred)\n        self.assertEqual(expected_ub, ub)\n        
self.assertEqual(expected_lb, lb)\n\n    def test_Linf_forward(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=0.3, norm=np.inf, IBP=False, method='forward')\n\n    def test_Linf_backward(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=0.3, norm=np.inf, IBP=False, method='backward')\n\n    def test_Linf_IBP(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=0.3, norm=np.inf, IBP=True, method=None)\n\n    def test_Linf_backward_IBP(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=0.3, norm=np.inf, IBP=True, method='backward')\n\n    def test_L2_forward(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=1.0, norm=2, IBP=False, method='forward')\n\n    def test_L2_backward(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=1.0, norm=2, IBP=False, method='backward')\n\n    def test_L2_IBP(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=1.0, norm=2, IBP=True, method=None)\n\n    def test_L2_backward_IBP(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=1.0, norm=2, IBP=True, method='backward')\n\n    def test_L1_forward(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=3.0, norm=1, IBP=False, method='forward')\n\n    def test_L1_backward(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=3.0, norm=1, IBP=False, method='backward')\n\n    def test_L1_IBP(self):\n        with np.errstate(divide='ignore'):\n            self.compute_and_compare_bounds(eps=3.0, norm=1, IBP=True, method=None)\n\n    def test_L1_backward_IBP(self):\n        with np.errstate(divide='ignore'):\n         
   self.compute_and_compare_bounds(eps=3.0, norm=1, IBP=True, method='backward')\n\n"
  },
  {
    "path": "tests/test_maxpool.py",
    "content": "\"\"\"Test max pooling.\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass Model(nn.Module):\n    def __init__(self, kernel_size=4, stride=4, padding=0, conv_padding=0):\n        super(Model, self).__init__()\n        self.n_n_conv2d = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 1, 'padding': conv_padding, 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True})\n        self.n_n_maxpool = nn.MaxPool2d(**{'kernel_size': [kernel_size, kernel_size], 'ceil_mode': False, 'stride': [stride, stride], 'padding': [padding, padding]})\n        self.n_n_conv2d_2 = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 1, 'padding': [conv_padding, conv_padding], 'kernel_size': (2, 2), 'stride': [1, 1], 'in_channels': 1, 'bias': True})\n        self.n_n_maxpool_2 = nn.MaxPool2d(**{'kernel_size': [kernel_size, kernel_size], 'ceil_mode': False, 'stride': [stride, stride], 'padding': [padding, padding]})\n        self.n_n_flatten_Flatten = nn.Flatten(**{'start_dim': 1})\n\n        self.n_n_dense = None\n\n        self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1})\n\n    def forward(self, *inputs):\n        t_ImageInputLayer, = inputs\n        t_conv2d = self.n_n_conv2d(t_ImageInputLayer)\n        t_conv2d_relu = F.relu(t_conv2d)\n        t_maxpool = self.n_n_maxpool(t_conv2d_relu)[:, :, :, :]\n        t_conv2d_max = self.n_n_conv2d_2(t_maxpool)\n        t_conv2d_max = F.relu(t_conv2d_max)\n        # t_maxpool_2 = self.n_n_maxpool_2(t_conv2d_max)\n        t_flatten_Transpose = t_conv2d_max.permute(*[0, 2, 3, 1])\n        t_flatten_Flatten = self.n_n_flatten_Flatten(t_flatten_Transpose)\n        t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Flatten, 2)\n        t_flatten_Unsqueeze = torch.unsqueeze(t_flatten_Unsqueeze, 
3)\n\n        if self.n_n_dense is None:\n            self.n_n_dense = nn.Conv2d(**{'groups': 1, 'dilation': [1, 1], 'out_channels': 2, 'padding': [0, 0], 'kernel_size': (1, 1), 'stride': [1, 1], 'in_channels': t_flatten_Unsqueeze.shape[1], 'bias': True})\n        t_dense = self.n_n_dense(t_flatten_Unsqueeze)\n        t_activation_Flatten = self.n_n_activation_Flatten(t_dense)\n\n        return t_activation_Flatten\n\nclass TestMaxPool(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name=None,\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        if self.default_dtype == torch.float64:\n            data_path = 'data_64/'\n        else:\n            data_path = 'data/'\n\n        N = 2\n\n        for kernel_size in [3,4]:\n            for padding in [0,1]:\n                for conv_padding in [0,1]:\n                    print(kernel_size, padding, kernel_size, conv_padding)\n\n                    model_ori = Model(kernel_size=kernel_size, padding=padding, stride=kernel_size, conv_padding=conv_padding).to(\n                        device=self.default_device, dtype=self.default_dtype)\n                    if not self.generate:\n                        data = torch.load(data_path + 'maxpool_test_data_{}-{}-{}-{}'.format(kernel_size, padding, kernel_size, conv_padding), weights_only=False)\n                        image = data['input']\n                        model_ori(image)\n                        model_ori.load_state_dict(data['model'])\n                    else:\n                        image = torch.rand([N, 1, 28, 28])\n                        model_ori(image)\n\n                    if self.generate:\n                        conv_mode = \"matrix\"\n                    else:\n                        conv_mode = \"patches\"\n\n                    model = BoundedModule(model_ori, 
image, device=self.default_device, bound_opts={\"conv_mode\": conv_mode})\n                    eps = 0.3\n                    norm = np.inf\n                    ptb = PerturbationLpNorm(norm=norm, eps=eps)\n                    image = BoundedTensor(image, ptb)\n\n                    lb, ub = model.compute_bounds((image,))\n\n                    if self.generate:\n                        torch.save(\n                            {'model': model_ori.state_dict(),\n                            'input': image,\n                            'lb': lb,\n                            'ub': ub}, data_path + 'maxpool_test_data_{}-{}-{}-{}'.format(kernel_size, padding, kernel_size, conv_padding)\n                        )\n\n                    if not self.generate:\n                        lb_ref = data['lb']\n                        ub_ref = data['ub']\n\n                        assert torch.allclose(lb, lb_ref, 1e-4)\n                        assert torch.allclose(ub, ub_ref, 1e-4)\n\n\nif __name__ == '__main__':\n    testcase = TestMaxPool(generate=False)\n    testcase.test()\n"
  },
  {
    "path": "tests/test_min_max.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE \n\nclass Test_Model(nn.Module):\n    def __init__(self):\n        super(Test_Model, self).__init__()\n\n        self.seq1 = nn.Sequential(\n            nn.Conv2d(1, 16, 4, stride=2, padding=1),\n            nn.ReLU(),\n            nn.Conv2d(16, 32, 4, stride=2, padding=1)\n        )\n\n        self.seq2 = nn.Sequential(\n            nn.Conv2d(1, 16, 4, stride=2, padding=1),\n            nn.ReLU(),\n            nn.Conv2d(16, 32, 4, stride=2, padding=1)\n        )\n\n        self.seq3 = nn.Sequential(\n            nn.Conv2d(32, 8, 2, stride=2, padding=1),\n            nn.ReLU(),\n            Flatten(),\n            nn.Linear(8*4*4,100),\n            nn.ReLU(),\n            nn.Linear(100, 10)\n        )\n\n    def forward(self, x):\n        return self.seq3(torch.max(self.seq1(x), self.seq2(x)))\n\nclass TestMinMax(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1, ref_name='min_max_test_data', generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        self.result = []\n        for conv_mode in ['patches', 'matrix']:\n            for use_shared_alpha in [True, False]:\n                model = Test_Model().to(device=self.default_device, dtype=self.default_dtype)\n                checkpoint = torch.load(\n                    os.path.join(os.path.dirname(__file__), '../examples/vision/pretrained/test_min_max.pth'),\n                    map_location=self.default_device)\n                model.load_state_dict(checkpoint)\n\n                test_data = torchvision.datasets.MNIST(\n                    './data', train=False, 
download=True,\n                    transform=torchvision.transforms.ToTensor())\n\n                N = 2\n                image = test_data.data[:N].view(N,1,28,28)\n                image = image.to(device=self.default_device,\n                                 dtype=self.default_dtype) / 255.0\n\n                lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device, bound_opts={\"conv_mode\": conv_mode})\n\n                eps = 0.3\n                ptb = PerturbationLpNorm(eps = eps)\n                image = BoundedTensor(image, ptb)\n\n                lirpa_model.set_bound_opts({\n                    'optimize_bound_args': {\n                        'iteration': 5,\n                        'lr_alpha': 0.1,\n                        'use_shared_alpha': use_shared_alpha,\n                    }\n                })\n                lb, ub = lirpa_model.compute_bounds(x=(image,), method='CROWN-Optimized')\n                print(lb, ub)\n\n                self.result.append((lb, ub))\n\n        self.setUp()\n        self.rtol = 1e-4\n        self.check()\n\nif __name__ == \"__main__\":\n    testcase = TestMinMax(generate=False)\n    testcase.test()\n"
  },
  {
    "path": "tests/test_perturbation.py",
    "content": "\"\"\" Test different Perturbation classes\"\"\"\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm, PerturbationLinear\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nBATCH = 2\nIN_DIM = 3\nOUT_DIM = 4\n\n\nclass ToyModel(nn.Module):\n    \"\"\"Small model with two MatMuls and ReLU.\"\"\"\n    def __init__(self):\n        super().__init__()\n        self.fc1 = nn.Linear(OUT_DIM, 8)\n        self.fc2 = nn.Linear(8, OUT_DIM)\n        self.relu = nn.ReLU()\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        return x\n\n\nclass TestPerturbation(TestCase):\n    \"\"\"\n    Tests for:\n    - PerturbationLinear\n    - PerturbationLpNorm\n    \"\"\"\n    def __init__(self, methodName='runTest', seed=1, generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed, 'test_perturbation_data', generate, device=device, dtype=dtype)\n\n    def test(self):\n        device = self.default_device\n        dtype = self.default_dtype\n\n        model = ToyModel().to(device=device, dtype=dtype)\n\n        # Prepare base input interval\n        input_lb = torch.rand(BATCH, IN_DIM, device=device, dtype=dtype)\n        input_ub = input_lb + torch.rand_like(input_lb)    # ensure ub > lb\n\n        self.result = []\n\n        # =================================================================\n        # Test PerturbationLinear\n        # =================================================================\n        # Build A matrices\n        lower_A = torch.randn(BATCH, OUT_DIM, IN_DIM, device=device, dtype=dtype)\n        upper_A = lower_A + torch.rand_like(lower_A)\n        # biases\n        lower_b = torch.randn(BATCH, OUT_DIM, device=device, dtype=dtype)\n        upper_b = lower_b + torch.rand_like(lower_b)\n\n        
# Manual concretization\n        mid = ((input_lb + input_ub) / 2.0).unsqueeze(-1)   # (B, IN_DIM, 1)\n        diff = ((input_ub - input_lb) / 2.0).unsqueeze(-1)   # (B, IN_DIM, 1)\n\n        manual_L = (lower_A @ mid - torch.abs(lower_A) @ diff).squeeze(-1) + lower_b\n        manual_U = (upper_A @ mid + torch.abs(upper_A) @ diff).squeeze(-1) + upper_b\n        assert (manual_L < manual_U).all(), \"Invalid manual bounds construction.\"\n\n        ptb_linear = PerturbationLinear(\n            lower_A=lower_A, upper_A=upper_A, lower_b=lower_b, upper_b=upper_b,\n            input_lb=input_lb, input_ub=input_ub,\n            x_L=manual_L, x_U=manual_U\n        )\n        bounded_x = BoundedTensor((manual_L + manual_U) / 2, ptb_linear)\n        lirpa_model = BoundedModule(model, bounded_x)\n        lb_linear, ub_linear = lirpa_model.compute_bounds(bounded_x, method='backward')\n        assert (lb_linear <= ub_linear).all(), \"Invalid bounds from PerturbationLinear.\"\n        self.result.append((lb_linear, ub_linear))\n\n\n        # =================================================================\n        # Test PerturbationLpNorm\n        # =================================================================\n        # We directly use manual concretization here for testing\n        ptb_linf = PerturbationLpNorm(x_L=manual_L, x_U=manual_U)\n        bounded_x = BoundedTensor((manual_L + manual_U) / 2, ptb_linf)\n        lirpa_model = BoundedModule(model, bounded_x)\n        lb_linf, ub_linf = lirpa_model.compute_bounds(bounded_x, method='backward')\n        assert (lb_linf <= ub_linf).all(), \"Invalid bounds from PerturbationLpNorm.\"\n        self.result.append((lb_linf, ub_linf))\n\n        # Notice that with the same x_L and x_U, PerturbationLinear should give\n        # tighter bounds than PerturbationLpNorm. 
This is because\n        # PerturbationLinear uses additional information (A matrices and biases).\n        assert (lb_linear >= lb_linf).all() and (ub_linear <= ub_linf).all(\n        ), \"PerturbationLinear should give tighter bounds than PerturbationLpNorm.\"\n\n        self.check()\n\n\nif __name__ == '__main__':\n    testcase = TestPerturbation(generate=False)\n    testcase.test()\n"
  },
  {
    "path": "tests/test_rectangle_patches.py",
    "content": "import sys\nimport torch\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nsys.path.append('../examples/vision')\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass cnn_4layer_resnet(nn.Module):\n    def __init__(self):\n        super(cnn_4layer_resnet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 3, 4, stride=2, padding=1)\n        self.bn = nn.BatchNorm2d(3)\n        self.shortcut = nn.Conv2d(3, 3, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(3, 3, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(168, 10)\n\n    def forward(self, x):\n        x_ = x\n        x = F.relu(self.conv1(self.bn(x)))\n        x += self.shortcut(x_)\n        x = F.relu(self.conv2(x))\n        x = x.view(x.size(0), -1)\n        print(x.size())\n        x = self.fc1(x)\n\n        return x\n\nclass TestResnetPatches(TestCase): \n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, \n            seed=1234, ref_name='rectangle_patches_test_data',\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        model_oris = [\n            cnn_4layer_resnet(),\n        ]\n        self.result = []\n        if not self.generate:\n            self.reference = torch.load(\n                self.ref_path, map_location=self.default_device)\n\n        for model_ori in model_oris:\n            conv_mode = 'patches' # conv_mode can be set as 'matrix' or 'patches'        \n                \n            normalize = torchvision.transforms.Normalize(mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010])\n            test_data = torchvision.datasets.CIFAR10(\"./data\", train=False, download=True, \n                            
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize]))\n            N = 1\n            n_classes = 10\n\n            image = torch.Tensor(test_data.data[:N]).reshape(N,3,32,32)\n            image = image[:, :, :28, :]\n            image = image.to(device=self.default_device,\n                             dtype=self.default_dtype) / 255.0\n\n            model_ori = model_ori.to(\n                device=self.default_device, dtype=self.default_dtype)\n            model = BoundedModule(model_ori, image, bound_opts={\n                                  \"conv_mode\": conv_mode}, device=self.default_device)\n\n            ptb = PerturbationLpNorm(norm = np.inf, eps = 0.03)\n            image = BoundedTensor(image, ptb)\n            pred = model(image)\n            lb, ub = model.compute_bounds(IBP=False, C=None, method='backward')\n            self.result += [lb, ub]\n\n        self.check()\n\nif __name__ == '__main__':\n    # Change to generate=True when genearting reference results\n    testcase = TestResnetPatches(generate=False)\n    testcase.test()"
  },
  {
    "path": "tests/test_resnet_patches.py",
    "content": "import sys\nimport torch\nimport numpy as np\nimport torchvision\nimport models\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\nsys.path.append('../examples/vision')\n\n\n\nclass TestResnetPatches(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName,\n            seed=1234, ref_name='resnet_patches_test_data',\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def test(self):\n        model_oris = [\n            models.model_resnet(width=1, mult=2),\n            models.ResNet18(in_planes=2)\n        ]\n        self.result = []\n\n        for model_ori in model_oris:\n            conv_mode = 'patches' # conv_mode can be set as 'matrix' or 'patches'\n\n            normalize = torchvision.transforms.Normalize(mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010])\n            test_data = torchvision.datasets.CIFAR10(\"./data\", train=False, download=True,\n                            transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize]))\n            N = 1\n            n_classes = 10\n\n            image = torch.Tensor(test_data.data[:N]).reshape(N,3,32,32)\n            image = image.to(device=self.default_device,\n                             dtype=self.default_dtype) / 255.0\n\n            model_ori = model_ori.to(\n                device=self.default_device, dtype=self.default_dtype)\n            model = BoundedModule(model_ori, image, bound_opts={\"conv_mode\": conv_mode}, device=self.default_device)\n\n            ptb = PerturbationLpNorm(norm = np.inf, eps = 0.03)\n            image = BoundedTensor(image, ptb)\n            pred = model(image)\n            lb, ub = model.compute_bounds(IBP=False, C=None, method='backward')\n            self.result += [lb, ub]\n\n       
 self.check()\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference results\n    testcase = TestResnetPatches(generate=False)\n    testcase.test()"
  },
  {
    "path": "tests/test_s_shaped.py",
    "content": "# pylint: disable=wrong-import-position\n\"\"\"Test S-shaped activation functions.\"\"\"\nimport torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass test_model(nn.Module):\n    def __init__(self, act_func):\n        super().__init__()\n        self.act_func = act_func\n\n    def forward(self, x):\n        return self.act_func(x)\n\ndef sigmoid(x):\n    return torch.sigmoid(x)\n\ndef sin(x):\n    return torch.sin(x)\n\n\ndef verify_bounds(model, input_lb, input_ub, lb, ub):\n    \"\"\"\n    Empirically verify that the model's output bounds are correct given input bounds.\n\n    Args:\n        model: The neural network model.\n        input_lb: Lower bound of the input.\n        input_ub: Upper bound of the input.\n        lb: Computed lower bound of the output.\n        ub: Computed upper bound of the output.\n    \"\"\"\n    n_samples = 100000\n    atol = 1e-5\n    inputs = torch.rand(n_samples, *input_lb.shape[1:]) * (input_ub - input_lb) + input_lb\n    outputs = model(inputs)\n    empirical_lb = outputs.min(dim=0).values\n    empirical_ub = outputs.max(dim=0).values\n    if not (empirical_lb - lb >= -atol).all():\n        max_violation = (lb - empirical_lb).max().item()\n        raise AssertionError(f\"Lower bound violated. Max violation: {max_violation}\")\n    if not (empirical_ub - ub <= atol).all():\n        max_violation = (empirical_ub - ub).max().item()\n        raise AssertionError(f\"Upper bound violated. 
Max violation: {max_violation}\")\n    print(\"Bounds verified successfully.\")\n\n\nclass TestSShaped(TestCase):\n    def __init__(self, methodName='runTest', generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(\n            methodName, seed=1, ref_name='s_shape_test_data',\n            generate=generate,\n            device=device, dtype=dtype)\n\n    def _run_bound_test(self, model, input_lb, input_ub, methods):\n        \"\"\"Helper to compute and verify bounds.\"\"\"\n        model = model.to(device=self.default_device, dtype=self.default_dtype)\n        lirpa_model = BoundedModule(model, torch.empty_like(input_lb), device=self.default_device)\n        ptb = PerturbationLpNorm(x_L=input_lb, x_U=input_ub)\n        ptb_data = BoundedTensor(input_lb, ptb)\n\n        for method in methods:\n            lb, ub = lirpa_model.compute_bounds(x=(ptb_data,), method=method)\n            verify_bounds(model, input_lb, input_ub, lb, ub)\n            self.result.append((lb, ub))\n\n    def test(self):\n        self.result = []\n        methods = ['CROWN', 'CROWN-OPTIMIZED']\n\n        # ----- Test BoundSin -----\n        model_sin = test_model(sin)\n        start, end = -10, 10\n        n_intervals = end - start - 1\n\n        # Inputs as multiples of pi\n        input_lb = torch.linspace(start, end - 1, n_intervals) * torch.pi\n        input_ub = torch.linspace(start + 1, end, n_intervals) * torch.pi\n        input_lb, input_ub = input_lb.unsqueeze(0), input_ub.unsqueeze(0)\n\n        self._run_bound_test(model_sin, input_lb, input_ub, methods)\n\n        # Inputs as multiples of pi / 2\n        self._run_bound_test(model_sin, input_lb / 2, input_ub / 2, methods)\n\n        # ----- Test BoundSigmoid -----\n        model_sigmoid = test_model(sigmoid)\n        input_lb = torch.tensor([[-2., -0.1]], device=self.default_device, dtype=self.default_dtype)\n        input_ub = torch.tensor([[0.1, 2.]], 
device=self.default_device, dtype=self.default_dtype)\n\n        self._run_bound_test(model_sigmoid, input_lb, input_ub, methods)\n\n        # Check reference results\n        self.check()\n\n\nif __name__ == '__main__':\n    # Change to generate=True when generating reference results\n    testcase = TestSShaped(generate=False)\n    testcase.setUp()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_save_intermediate.py",
    "content": "import torch\nimport torch.nn as nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import _to, TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass test_model(nn.Module):\n    def __init__(self):\n        super(test_model, self).__init__()\n        self.model = nn.Sequential(\n            nn.Flatten(),\n            nn.Linear(3 * 32 * 32, 1000),\n            nn.Sigmoid(),\n            nn.Linear(1000, 500),\n            nn.Linear(500, 200),\n            nn.Linear(200, 100),\n            nn.ReLU(),\n            nn.Linear(100, 10)\n        )\n\n    def forward(self, x):\n        x = self.model(x)\n        return x\n\nclass TestSave(TestCase):\n    def __init__(self, methodName='runTest', device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n\n    def test(self, gen_ref=False):\n        image = torch.randn(1, 3, 32, 32)\n        image = image.to(device=self.default_device,\n                         dtype=self.default_dtype) / 255.0\n        model = test_model().to(device=self.default_device, dtype=self.default_dtype)\n\n        bounded_model = BoundedModule(\n            model, image, bound_opts={\n                'optimize_bound_args': {'iteration': 2},\n            }, device=self.default_device)\n\n        ptb = PerturbationLpNorm(eps=3/255)\n        x = BoundedTensor(image, ptb)\n        bounded_model.compute_bounds(x=(x,), method='CROWN-Optimized')\n        if self.default_dtype == torch.float32:\n            data_path = 'data/'\n        elif self.default_dtype == torch.float64:\n            data_path = 'data_64/'\n        data_path += 'test_save_data'\n\n        save_dict = bounded_model.save_intermediate(\n            save_path=data_path if gen_ref else None)\n\n        if gen_ref:\n            torch.save(save_dict, data_path)\n            return\n\n        ref_dict = torch.load(data_path)\n        ref_dict = _to(\n            
ref_dict, device=self.default_device, dtype=self.default_dtype)\n\n\n        for node in ref_dict.keys():\n            assert torch.allclose(ref_dict[node][0], save_dict[node][0], atol=1e-5)\n            assert torch.allclose(ref_dict[node][1], save_dict[node][1], atol=1e-5)\n\n\nif __name__ == '__main__':\n    testcase = TestSave()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_simple_verification.py",
    "content": "\"\"\"Test optimized bounds in simple_verification.\"\"\"\nimport torch\nimport torch.nn as nn\nimport torchvision\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import PerturbationLpNorm\nfrom auto_LiRPA.utils import Flatten\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n# This simple model comes from https://github.com/locuslab/convex_adversarial\ndef mnist_model():\n    model = nn.Sequential(\n        nn.Conv2d(1, 16, 4, stride=2, padding=1),\n        nn.ReLU(),\n        nn.Conv2d(16, 32, 4, stride=2, padding=1),\n        nn.ReLU(),\n        Flatten(),\n        nn.Linear(32*7*7,100),\n        nn.ReLU(),\n        nn.Linear(100, 10)\n    )\n    return model\n\nclass TestSimpleVerification(TestCase):\n    def __init__(self, methodName='runTest', device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n\n    def test(self):\n      model = mnist_model()\n      checkpoint = torch.load(\n        '../examples/vision/pretrained/mnist_a_adv.pth',\n        map_location=torch.device('cpu'))\n      model.load_state_dict(checkpoint)\n      model = model.to(device=self.default_device, dtype=self.default_dtype)\n      test_data = torchvision.datasets.MNIST(\n        './data', train=False, download=True, transform=torchvision.transforms.ToTensor())\n      N = 2\n      image = test_data.data[:N].view(N,1,28,28)\n      image = image.to(device=self.default_device,\n                       dtype=self.default_dtype) / 255.0\n\n      lirpa_model = BoundedModule(model, torch.empty_like(image), device=self.default_device)\n      ptb = PerturbationLpNorm(0.3)\n      image = BoundedTensor(image, ptb)\n\n      method = 'CROWN-Optimized (alpha-CROWN)'\n      lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}})\n      _, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0])\n      self.assertEqual(ub[0][7], 
torch.tensor(12.5080))\n\nif __name__ == '__main__':\n    testcase = TestSimpleVerification()\n    testcase.test()\n"
  },
  {
    "path": "tests/test_state_dict_name.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom auto_LiRPA import BoundedModule\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass FeatureExtraction(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(1, 8, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(8, 16, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(784, 256)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        x = x.view(-1, 784)\n        x = F.relu(self.fc1(x))\n        return x\n\n\nclass cnn_MNIST(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.features = BoundedModule(FeatureExtraction(), torch.empty((1, 1, 28, 28)))\n        self.fc = nn.Linear(256, 10)\n\n    def forward(self, x):\n        x = self.features(x)\n        return self.fc(x)\n\n\nclass TestStateDictName(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n\n    def test(self):\n        model = cnn_MNIST().to(device=self.default_device, dtype=self.default_dtype)\n        state_dict = model.state_dict()\n        dummy = torch.randn((1, 1, 28, 28))\n        ret1 = model(dummy)\n\n        # create second model and load state_dict to test load_state_dict() whether works proper\n        model = cnn_MNIST().to(device=self.default_device, dtype=self.default_dtype)\n        model.load_state_dict(state_dict, strict=True)\n        ret2 = model(dummy)\n        self.assertEqual(ret1, ret2)\n\n\nif __name__ == '__main__':\n    # Change to generate=True when genearting reference results\n    testcase = TestStateDictName(generate=False)\n    testcase.test()\n"
  },
  {
    "path": "tests/test_tensor_storage.py",
    "content": "\nimport random\nimport torch\nfrom complete_verifier.tensor_storage import StackTensorStorage, QueueTensorStorage\n\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\n\nclass TestTensorStorage(TestCase):\n    def __init__(self, methodName='runTest', device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, device=device, dtype=dtype)\n        \n    def test_content(self, seed=123):\n        self.set_seed(seed)\n        storage_classes_and_pop_behavior = [\n            (\n                StackTensorStorage,\n                lambda tensor_list, num_pop: (tensor_list[-num_pop:], tensor_list[:-num_pop])\n            ),\n            (\n                QueueTensorStorage,\n                lambda tensor_list, num_pop: (tensor_list[:num_pop], tensor_list[num_pop:])\n            )\n        ]\n\n        for storage_class, pop_behavior in storage_classes_and_pop_behavior:\n            for concat_dim in [0, 1, 2]:\n                # The call to `.size()` has side effects for `QueueTensorStorage`, because it will\n                # cause a call to `.tensor()` which may change the internal storage.\n                for check_size in [True, False]:\n                    stored_tensors = []\n                    shape = [2,3,4]\n                    def make_random_tensor():\n                        random_size = random.randint(1, 100)\n                        tensors = []\n                        for _ in range(random_size):\n                            random_tensor = torch.randn(\n                                shape[:concat_dim] + shape[concat_dim+1:], device=self.default_device, dtype=self.default_dtype).unsqueeze(concat_dim)\n                            tensors.append(random_tensor)\n                        return torch.cat(tensors, dim=concat_dim), tensors\n                    s = storage_class(full_shape=shape, initial_size=16, switching_size=65536, concat_dim=concat_dim)\n                    for _ in range(1000):\n     
                   random_tensor, tensors = make_random_tensor()\n                        s.append(random_tensor)\n                        stored_tensors.extend(tensors)\n                        if check_size:\n                            assert s.size(concat_dim) == len(stored_tensors)\n\n                        num_pop = random.randint(1, 100)\n                        popped_tensors, stored_tensors = pop_behavior(stored_tensors, num_pop)\n                        popped_tensor = s.pop(num_pop)\n                        assert torch.allclose(popped_tensor, torch.cat(popped_tensors, dim=concat_dim))\n                        if check_size:\n                            assert s.size(concat_dim) == len(stored_tensors)\n\n    def test_tensor_call(self, seed=123):\n        # The call to `.tensor()` has side effects for `QueueTensorStorage`, because it will\n        # cause a call to `.size()` which may change the internal storage.\n        self.set_seed(seed)\n        pop_behavior = lambda tensor_list, num_pop: (tensor_list[:num_pop], tensor_list[num_pop:])\n\n        for concat_dim in [0, 1, 2]:\n            stored_tensors = []\n            shape = [2,3,4]\n            def make_random_tensor():\n                random_size = random.randint(1, 100)\n                tensors = []\n                for _ in range(random_size):\n                    random_tensor = torch.randn(shape[:concat_dim] + shape[concat_dim+1:], dtype=self.default_dtype).unsqueeze(concat_dim)\n                    tensors.append(random_tensor)\n                return torch.cat(tensors, dim=concat_dim), tensors\n            s = QueueTensorStorage(full_shape=shape, initial_size=16, switching_size=16, concat_dim=concat_dim)\n            for _ in range(1000):\n                random_tensor, tensors = make_random_tensor()\n                s.append(random_tensor)\n                stored_tensors.extend(tensors)\n\n                num_pop = random.randint(1, 10)\n                _, stored_tensors = 
pop_behavior(stored_tensors, num_pop)\n                _ = s.pop(num_pop)\n                if s._usage_start + s.num_used > s._storage.size(concat_dim):\n                    storage_content = s.tensor()\n                    assert torch.allclose(storage_content, torch.cat(stored_tensors, dim=concat_dim))\n\n\n    def test_size_queue(self):\n        for concat_dim in [0, 1, 2]:\n            shape = [1,1,1]\n            shape[concat_dim] = -1 # does not matter.\n            zero_shape = shape.copy()\n            zero_shape[concat_dim] = 0\n            def make_tensor(x): return torch.arange(\n                1, x+1, device=self.default_device, dtype=self.default_dtype).view(*shape)\n            s = QueueTensorStorage(full_shape=shape, initial_size=16, switching_size=65536, concat_dim=concat_dim)\n            s.append(make_tensor(1))\n            assert s.sum() == 1, s.tensor()\n            s.append(make_tensor(3))\n            assert s.sum() == 1 + 6, s.tensor()\n            s.append(make_tensor(5))\n            assert s.sum() == 1 + 6 + 15, s.tensor()\n            t = s.pop(5)\n            assert torch.allclose(t.squeeze(), torch.tensor(\n                [1, 1, 2, 3, 1], device=self.default_device, dtype=self.default_dtype))\n            t = s.pop(0)\n            assert t.shape == torch.Size(zero_shape)\n            t = s.pop(-1)\n            assert t.shape == torch.Size(zero_shape)\n            s.append(make_tensor(100))\n            expected_sum = 1 + sum(range(1,4)) + sum(range(1,6)) - (1 + 1 + 2 + 3 + 1) + sum(range(1,101))\n            assert s.sum() == expected_sum, (s.sum(), expected_sum)\n            t = s.pop(5)\n            assert torch.allclose(t.squeeze(), torch.tensor(\n                [2, 3, 4, 5, 1], device=self.default_device, dtype=self.default_dtype)), print(t)\n            assert s.size(concat_dim) == 99, print(s.size())\n            assert s._storage.size(concat_dim) == 104, print(s._storage.size())\n            s.append(make_tensor(10))\n            assert s.size(concat_dim) == 109, print(s.size())\n            assert s._storage.size(concat_dim) == 208, print(s._storage.size())\n            s.append(make_tensor(32768))\n            assert s.size(concat_dim) == 32877, print(s.size())\n            assert s._storage.size(concat_dim) == 32877, print(s._storage.size())\n            s.pop(1)\n            s.append(make_tensor(2))\n            assert s.size(concat_dim) == 32878, print(s.size())\n            assert s._storage.size(concat_dim) == 32877*2, print(s._storage.size())\n            s.append(make_tensor(32800))\n            s.append(make_tensor(100))\n            assert s._storage.size(concat_dim) == 32877*2+100*32, print(s._storage.size())\n            s.pop(100000)\n            assert s._storage.size(concat_dim) == 32877*2+100*32, print(s._storage.size())\n            assert s.size(concat_dim) == 0, print(s.size())\n            t = s.pop(1)\n            assert t.shape == torch.Size(zero_shape)\n            t = s.pop(0)\n            assert t.shape == torch.Size(zero_shape)\n            t = s.pop(-1)\n            assert t.shape == torch.Size(zero_shape)\n\n    def test_size_stack(self):\n        for concat_dim in [0, 1, 2]:\n            shape = [1,1,1]\n            shape[concat_dim] = -1 # does not matter.\n            zero_shape = shape.copy()\n            zero_shape[concat_dim] = 0\n            make_tensor = lambda x: torch.arange(1,x+1, dtype=self.default_dtype).view(*shape)\n            s = StackTensorStorage(full_shape=shape, initial_size=16, switching_size=65536, concat_dim=concat_dim)\n            s.append(make_tensor(1))\n            assert s.sum() == 1, print(s)\n            s.append(make_tensor(3))\n            assert s.sum() == 1 + 6, print(s)\n            s.append(make_tensor(5))\n            assert s.sum() == 1 + 6 + 15, print(s)\n            t = s.pop(5)\n            assert torch.allclose(t.squeeze(), torch.tensor(\n                [1, 2, 3, 4, 5], device=self.default_device, dtype=self.default_dtype)), print(t)\n            t = s.pop(0)\n            assert t.shape == torch.Size(zero_shape)\n            t = s.pop(-1)\n            assert t.shape == torch.Size(zero_shape)\n            s.append(make_tensor(100))\n            assert s.sum() == 1 + 6 + 50*101\n            t = s.pop(5)\n            assert torch.allclose(t.squeeze(), torch.tensor(\n                [96, 97, 98, 99, 100], device=self.default_device, dtype=self.default_dtype)), print(t)\n            assert s.size(concat_dim) == 99, print(s.size())\n            assert s._storage.size(concat_dim) == 104, print(s._storage.size())\n            s.append(make_tensor(10))\n            assert s.size(concat_dim) == 109, print(s.size())\n            assert s._storage.size(concat_dim) == 208, print(s._storage.size())\n            s.append(make_tensor(32768))\n            assert s.size(concat_dim) == 32877, print(s.size())\n            assert s._storage.size(concat_dim) == 32877, print(s._storage.size())\n            s.pop(1)\n            s.append(make_tensor(2))\n            assert s.size(concat_dim) == 32878, print(s.size())\n            assert s._storage.size(concat_dim) == 32877*2, print(s._storage.size())\n            s.append(make_tensor(32800))\n            s.append(make_tensor(100))\n            assert s._storage.size(concat_dim) == 32877*2+100*32, print(s._storage.size())\n            s.pop(100000)\n            assert s._storage.size(concat_dim) == 32877*2+100*32, print(s._storage.size())\n            assert s.size(concat_dim) == 0, print(s.size())\n            t = s.pop(1)\n            assert t.shape == torch.Size(zero_shape)\n            t = s.pop(0)\n            assert t.shape == torch.Size(zero_shape)\n            t = s.pop(-1)\n            assert t.shape == torch.Size(zero_shape)\n\nif __name__ == \"__main__\":\n    testcase = TestTensorStorage()\n    testcase.test_tensor_call()\n    testcase.test_size_stack()\n    testcase.test_size_queue()\n    testcase.test_content()\n"
  },
  {
    "path": "tests/test_upsample.py",
    "content": "from collections import defaultdict\n\nfrom torch import nn\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\n\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass Model(nn.Module):\n\n    def __init__(self,\n                 input_dim=5, image_size=4,\n                 scale_factor=2, conv_kernel_size=3, stride=1, padding=1,\n                 conv_in_channels=16, conv_out_channels=4):\n        super(Model, self).__init__()\n        self.conv_in_channels = conv_in_channels\n        self.input_dim = input_dim\n        self.image_size = image_size\n\n        self.fc1 = nn.Linear(input_dim, conv_in_channels * image_size * image_size)\n        self.upsample = nn.Upsample(scale_factor=(scale_factor, scale_factor), mode='nearest')\n        # H = W = 4 * scale_factor now\n        self.conv1 = nn.Conv2d(in_channels=conv_in_channels, out_channels=conv_out_channels,\n                               kernel_size=(conv_kernel_size, conv_kernel_size), stride=(stride, stride), padding=padding)\n        # H = W = (4 * scale + 2 * pad - ker + s) // s\n        size_after_conv = (4 * scale_factor + 2 * padding - conv_kernel_size + stride) // stride\n        assert size_after_conv > 0, \"0 size after convolution, please use more padding, more scale_factor,\" \\\n                                    \"smaller kernel, or smaller stride\"\n        self.relu = nn.ReLU()\n        self.flatten = nn.Flatten()\n        self.fc2 = nn.Linear(size_after_conv * size_after_conv * conv_out_channels, 1)\n        # self.sigmoid = nn.Sigmoid()\n\n    def forward(self, input_z):\n        f1 = self.fc1(input_z)\n        d1 = f1.reshape(-1, self.conv_in_channels, self.image_size, self.image_size)\n        d2 = self.upsample(d1)\n        d3 = self.conv1(d2)\n        d4 = self.relu(d3)\n        f2 = self.flatten(d4)\n        f3 = self.fc2(f2)\n        # out = self.sigmoid(f3)\n        return f3\n\nclass 
ModelReducedCGAN(nn.Module):\n    def __init__(self):\n        \"\"\"\n            The network has the same architecture as the merged-BN CGAN upsampling model,\n            but with reduced channel counts\n        \"\"\"\n        super(ModelReducedCGAN, self).__init__()\n        self.fc1 = nn.Linear(5, 32)\n        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')\n        self.conv1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1)\n        self.relu1 = nn.ReLU()\n        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')\n        self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3, stride=1, padding=1)\n        self.relu2 = nn.ReLU()\n        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')\n        self.conv3 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1)\n        self.relu3 = nn.ReLU()\n        self.conv4 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3, stride=1, padding=1)\n        self.conv5 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1)\n        self.relu4 = nn.ReLU()\n        self.conv6 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1)\n        self.relu5 = nn.ReLU()\n        self.conv7 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1)\n        self.relu6 = nn.ReLU()\n        self.conv8 = nn.Conv2d(in_channels=4, out_channels=4, kernel_size=3, stride=2, padding=1)\n        self.relu7 = nn.ReLU()\n        self.fc2 = nn.Linear(4 * 2 * 2, 1)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, input_z):\n        f1 = self.fc1(input_z)\n        f2 = f1.reshape(-1, 2, 4, 4)\n        f3 = self.up1(f2)\n        f4 = self.conv1(f3)\n        f5 = self.relu1(f4)\n        f6 = self.up2(f5)\n        f7 = self.conv2(f6)\n        f8 = self.relu2(f7)\n        f9 = self.up3(f8)\n        f10 = self.conv3(f9)\n        f11 = self.relu3(f10)\n        f12 = self.conv4(f11)\n        f13 = 
self.conv5(f12)\n        f14 = self.relu4(f13)\n        f15 = self.conv6(f14)\n        f16 = self.relu5(f15)\n        f17 = self.conv7(f16)\n        f18 = self.relu6(f17)\n        f19 = self.conv8(f18)\n        f20 = self.relu7(f19)\n        f21 = f20.reshape(f20.shape[0], -1)\n        f22 = self.fc2(f21)\n        # f23 = self.sigmoid(f22)\n        return f22\n\n\n\ndef recursive_allclose(a, b: dict, verbose=False, prefix=''):\n    \"\"\"\n        Recursively check that every pair of corresponding tensors in the two dicts is close\n    :param a: dict a\n    :param b: dict b\n    :param verbose: print the path of every tensor checked\n    :param prefix: reserved for path tracking in recursive calls for error printing\n    :return: bool: True if all corresponding tensors are close\n    \"\"\"\n    tot_tensor = 0\n    tot_dict = 0\n    for k in a:\n        if isinstance(a[k], torch.Tensor):\n            if k == 'unstable_idx': continue\n            if verbose:\n                print(f'recursive_allclose(): Checking {prefix}{k}')\n            assert k in b and (isinstance(b[k], torch.Tensor) or isinstance(b[k], Patches)), f'recursive_allclose(): Tensor not found in path {prefix}{k}'\n            if isinstance(b[k], torch.Tensor):\n                assert torch.allclose(a[k].reshape(-1), b[k].reshape(-1), 1e-4, 1e-5), f'recursive_allclose(): Inconsistency found in path {prefix}{k}'\n            tot_tensor += 1\n        elif isinstance(a[k], dict):\n            assert k in b and isinstance(b[k], dict), f'recursive_allclose(): dict not found in path {prefix}{k}'\n            recursive_allclose(a[k], b[k], verbose, prefix + k)\n            tot_dict += 1\n    tot_b_tensor = sum([1 if (isinstance(v, torch.Tensor) or isinstance(v, Patches)) and k != 'unstable_idx' else 0 for k, v in b.items()])\n    tot_b_dict = sum([1 if isinstance(v, dict) else 0 for v in b.values()])\n    assert tot_tensor == tot_b_tensor, f'recursive_allclose(): Extra tensors found in path {prefix}'\n    assert tot_dict == tot_b_dict, f'recursive_allclose(): Extra recursive paths found in path 
{prefix}'\n    return True\n\n\nclass TestUpSample(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed=1, ref_name=None, generate=generate,\n                         device=device, dtype=dtype)\n        # self.device = device\n\n    def test(self, seed=123):\n        for kernel_size in [3,5]:\n            for scaling_factor in [2,3,4]:\n                for stride in [1,2]:\n                    for padding in [1]:\n                        self.test_instance(kernel_size, scaling_factor, stride, padding, seed=seed)\n\n    def test_instance(self, kernel_size=3, scaling_factor=2, stride=1, padding=1, seed=123):\n        self.set_seed(seed)\n\n        print(f'kernel_size = {kernel_size}, scaling_factor = {scaling_factor}, stride = {stride}, padding = {padding}')\n        random_input = torch.randn(\n            (1, 5), device=self.default_device, dtype=self.default_dtype) * 1000.\n        eps = 0.3\n\n        model_ori = Model(scale_factor=scaling_factor,\n                          conv_kernel_size=kernel_size,\n                          stride=stride,\n                          padding=padding).to(device=self.default_device, dtype=self.default_dtype)\n\n        ptb = PerturbationLpNorm(norm=np.inf, eps=eps)\n        z1_clean = random_input.detach().clone().requires_grad_(requires_grad=True)\n\n        z1 = BoundedTensor(random_input, ptb)\n        model_mat = BoundedModule(model_ori, (random_input,), device=self.default_device, bound_opts={\"conv_mode\": \"matrix\"})\n        pred_of_mat = model_mat(z1)\n        lb_m, ub_m, A_m = model_mat.compute_bounds(return_A=True, needed_A_dict={model_mat.output_name[0]: model_mat.input_name[0]}, )\n\n        model_pat = BoundedModule(model_ori, (random_input,), device=self.default_device,\n                                  bound_opts={\"conv_mode\": \"patches\"})\n        pred_of_patch = model_pat(z1)\n        lb_p, ub_p, 
A_p = model_pat.compute_bounds(return_A=True, needed_A_dict={\n            model_pat.output_name[0]: model_pat.input_name[0]}, )\n\n        assert torch.allclose(pred_of_mat, pred_of_patch, 1e-5)\n        assert torch.allclose(lb_m, lb_p, 1e-5)\n        assert torch.allclose(ub_m, ub_p, 1e-5)\n        assert recursive_allclose(A_m, A_p, verbose=True)\n\nclass TestReducedCGAN(TestCase):\n\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed=1, ref_name=None, generate=generate,\n                         device=device, dtype=dtype)\n        # self.device = device\n\n    def test(self, seed=456):\n        self.set_seed(seed)\n        input = torch.tensor([[0.583, -0.97, -0.97, 0.598, 0.737]])\n        eps = 0.1\n\n        model_ori = ModelReducedCGAN().to(\n            device=self.default_device, dtype=self.default_dtype)\n\n        ptb = PerturbationLpNorm(norm=np.inf, eps=eps)\n        z1_clean = input.detach().clone().requires_grad_(requires_grad=True)\n\n        z1 = BoundedTensor(input, ptb)\n        model_mat = BoundedModule(model_ori, (input,), device=self.default_device,\n                                  bound_opts={\"conv_mode\": \"matrix\"})\n        pred_of_mat = model_mat(z1)\n\n        needed_A_dict = defaultdict(set)\n        for node in model_mat.nodes():\n            needed_A_dict[node.name] = set()\n\n        lb_m, ub_m, A_m = model_mat.compute_bounds((z1,), return_A=True, needed_A_dict=needed_A_dict, method='crown')\n\n        model_pat = BoundedModule(model_ori, (input,), device=self.default_device,\n                                  bound_opts={\"conv_mode\": \"patches\", \"sparse_features_alpha\": False})\n        pred_of_patch = model_pat(z1)\n        lb_p, ub_p, A_p = model_pat.compute_bounds((z1,), return_A=True, needed_A_dict=needed_A_dict, method='crown')\n\n        # print(pred_of_mat, pred_of_patch)\n        assert torch.allclose(pred_of_mat, 
pred_of_patch, 1e-5)\n        assert torch.allclose(lb_m, lb_p, 1e-5)\n        assert torch.allclose(ub_m, ub_p, 1e-5)\n        assert recursive_allclose(A_m, A_p, verbose=True)\n\nif __name__ == '__main__':\n    # should use device = 'cpu' for GitHub CI\n    testcase = TestUpSample(generate=False)\n    testcase.test(seed=123)\n\n    # The following test is much stronger, but it finishes within 30s only on GPUs\n    # (required GPU memory: 1.5 GiB), so it may be slow on CPU-only CI\n    testhardcase = TestReducedCGAN(generate=False)\n    testhardcase.test(seed=456)\n"
  },
  {
    "path": "tests/test_vision_models.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom auto_LiRPA import BoundedModule, BoundedTensor\nfrom auto_LiRPA.perturbations import *\nfrom testcase import _to, TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass cnn_4layer_test(nn.Module):\n    def __init__(self):\n        super(cnn_4layer_test, self).__init__()\n        self.conv1 = nn.Conv2d(3, 3, 4, stride=2, padding=1)\n        self.bn = nn.BatchNorm2d(3)\n        self.shortcut = nn.Conv2d(3, 3, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(3, 3, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(192, 10)\n\n    def forward(self, x):\n        x_ = x\n        x = F.relu(self.conv1(self.bn(x)))\n        x += self.shortcut(x_)\n        x = F.relu(self.conv2(x))\n        x = x.view(x.size(0), -1)\n        x = self.fc1(x)\n\n        return x\n\nclass TestVisionModels(TestCase):\n    def __init__(self, methodName='runTest', ref_name='vision_test_data', model=cnn_4layer_test(), generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, seed=1234, ref_name=ref_name,\n                         generate=generate, device=device, dtype=dtype)\n        self.result = {}\n        self.model = model.to(device=self.default_device,\n                              dtype=self.default_dtype)\n\n    def setUp(self):\n        super().setUp()\n        if self.reference:\n            self.reference = _to(self.reference, self.default_device)\n            self.reference = _to(self.reference, dtype=self.default_dtype)\n        if self.generate:\n            # a state_dict from an existing reference is needed\n            self.reference = torch.load(self.ref_path)\n\n    def verify_bounds(self, model, x, IBP, method, forward_ret, lb_name, ub_name):\n        lb, ub = model(method_opt=\"compute_bounds\", x=(x,), IBP=IBP, method=method)\n        self.result[lb_name] = lb\n        self.result[ub_name] = ub\n\n        if method != 'CROWN-Optimized':\n        
# test gradient backward propagation\n        # only when method is not \"CROWN-Optimized\" (in that case, lb and ub don't have gradient)\n            loss = (ub - lb).abs().sum()\n            loss.backward()\n            grad = x.grad\n            self.result[lb_name[:-2] + 'grad'] = grad.clone()\n\n        if not self.generate:\n            if method != 'CROWN-Optimized':\n                assert torch.allclose(lb, self.reference[lb_name], 1e-4, atol=2e-7), (lb - self.reference[lb_name]).abs().max()\n                assert torch.allclose(ub, self.reference[ub_name], 1e-4, atol=2e-7), (ub - self.reference[ub_name]).abs().max()\n                assert ((lb - self.reference[lb_name]).pow(2).sum() < 1.3e-9), (lb - self.reference[lb_name]).pow(2).sum()\n                assert ((ub - self.reference[ub_name]).pow(2).sum() < 1.3e-9), (ub - self.reference[ub_name]).pow(2).sum()\n                if \"same-slope\" not in lb_name:\n                    assert torch.allclose(grad, self.reference[lb_name[:-2] + 'grad'], 1e-4, 1e-6),  (grad - self.reference[lb_name[:-2] + 'grad']).abs().max()\n                    assert (grad - self.reference[lb_name[:-2] + 'grad']).pow(2).sum() < 1.e-6, (grad - self.reference[lb_name[:-2] + 'grad']).pow(2).sum()\n            else:\n                assert torch.allclose(lb, self.reference[lb_name], 1e-4, atol=5e-6), (lb - self.reference[lb_name]).abs().max()\n                assert torch.allclose(ub, self.reference[ub_name], 1e-4, atol=5e-6), (ub - self.reference[ub_name]).abs().max()\n                assert ((lb - self.reference[lb_name]).pow(2).sum() < 1.3e-9), (lb - self.reference[lb_name]).pow(2).sum()\n                assert ((ub - self.reference[ub_name]).pow(2).sum() < 1.3e-9), (ub - self.reference[ub_name]).pow(2).sum()\n\n\n    def test_bounds(self, bound_opts=None, optimize = True):\n        if bound_opts is None:\n            bound_opts = {'activation_bound_option': 'same-slope'}\n        np.random.seed(123)  # FIXME inconsistent 
seeds\n        model_ori = self.model.eval()\n        model_ori.load_state_dict(self.reference['model'])\n        dummy_input = self.reference['data'].to(dtype=self.default_dtype, device=self.default_device)\n        inputs = (dummy_input,)\n\n        model = BoundedModule(model_ori, inputs, device=self.default_device)\n        model.set_bound_opts({'optimize_bound_args': {'lr_alpha': 0.1}})\n        forward_ret = model(dummy_input)\n        model_ori.eval()\n\n        assert torch.allclose(model_ori(dummy_input), model(dummy_input), 1e-4, 1e-6)\n\n        model_same_slope = BoundedModule(model_ori, inputs, device=self.default_device, bound_opts=bound_opts)\n        model_same_slope.set_bound_opts({'optimize_bound_args': {'lr_alpha': 0.1}})\n\n        # Linf\n        ptb = PerturbationLpNorm(norm=np.inf, eps=0.01)\n        x = BoundedTensor(dummy_input, ptb)\n        x.requires_grad_()\n\n        self.verify_bounds(model, x, IBP=True, method=None, forward_ret=forward_ret, lb_name='l_inf_IBP_lb',\n                    ub_name='l_inf_IBP_ub')  # IBP\n        self.verify_bounds(model, x, IBP=True, method='backward', forward_ret=forward_ret, lb_name='l_inf_CROWN-IBP_lb',\n                    ub_name='l_inf_CROWN-IBP_ub')  # CROWN-IBP\n        self.verify_bounds(model, x, IBP=False, method='backward', forward_ret=forward_ret, lb_name='l_inf_CROWN_lb',\n                    ub_name='l_inf_CROWN_ub')  # CROWN\n        self.verify_bounds(model_same_slope, x, IBP=False, method='backward', forward_ret=forward_ret, lb_name='l_inf_CROWN-same-slope_lb',\n                    ub_name='l_inf_CROWN-same-slope_ub') # CROWN-same-slope\n        if optimize:\n            self.verify_bounds(model, x, IBP=False, method='CROWN-Optimized', forward_ret=forward_ret, lb_name='l_inf_CROWN-Optimized_lb',\n                        ub_name='l_inf_CROWN-Optimized_ub') # CROWN-Optimized\n            self.verify_bounds(model_same_slope, x, IBP=False, method='CROWN-Optimized', forward_ret=forward_ret, 
lb_name='l_inf_CROWN-Optimized-same-slope_lb',\n                        ub_name='l_inf_CROWN-Optimized-same-slope_ub')  # Crown-Optimized-same-slope\n\n\n        # L2\n        ptb = PerturbationLpNorm(norm=2, eps=0.01)\n        x = BoundedTensor(dummy_input, ptb)\n        x.requires_grad_()\n\n        self.verify_bounds(model, x, IBP=True, method=None, forward_ret=forward_ret, lb_name='l_2_IBP_lb',\n                    ub_name='l_2_IBP_ub')  # IBP\n        self.verify_bounds(model, x, IBP=True, method='backward', forward_ret=forward_ret, lb_name='l_2_CROWN-IBP_lb',\n                    ub_name='l_2_CROWN-IBP_ub')  # CROWN-IBP\n        self.verify_bounds(model, x, IBP=False, method='backward', forward_ret=forward_ret, lb_name='l_2_CROWN_lb',\n                    ub_name='l_2_CROWN_ub')  # CROWN\n        self.verify_bounds(model_same_slope, x, IBP=False, method='backward', forward_ret=forward_ret, lb_name='l_2_CROWN-same-slope_lb',\n                    ub_name='l_2_CROWN-same-slope_ub') # CROWN-same-slope\n        if optimize:\n            self.verify_bounds(model, x, IBP=False, method='CROWN-Optimized', forward_ret=forward_ret, lb_name='l_2_CROWN-Optimized_lb',\n                        ub_name='l_2_CROWN-Optimized_ub') # CROWN-Optimized\n            self.verify_bounds(model_same_slope, x, IBP=False, method='CROWN-Optimized', forward_ret=forward_ret, lb_name='l_2_CROWN-Optimized-same-slope_lb',\n                        ub_name='l_2_CROWN-Optimized-same-slope_ub')  # Crown-Optimized-same-slope\n\n        if self.generate:\n            self.result['data'] = self.reference['data']\n            self.result['model'] = self.reference['model']\n            self.save()\n\n\nif __name__ ==\"__main__\":\n    t = TestVisionModels(generate=False)\n    # t = TestVisionModels()\n    t.setUp()\n    t.test_bounds()\n"
  },
  {
    "path": "tests/test_vision_models_hardtanh.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom auto_LiRPA.perturbations import *\nfrom test_vision_models import TestVisionModels\nfrom testcase import DEFAULT_DEVICE, DEFAULT_DTYPE\n\nclass cnn_4layer_test_hardtanh(nn.Module):\n    def __init__(self, in_ch, in_dim, width=2, linear_size=256):\n        super(cnn_4layer_test_hardtanh, self).__init__()\n        self.conv1 = nn.Conv2d(in_ch, 4 * width, 4, stride=2, padding=1)\n        self.conv2 = nn.Conv2d(4 * width, 8 * width, 4, stride=2, padding=1)\n        self.fc1 = nn.Linear(8 * width * (in_dim // 4) * (in_dim // 4), linear_size)\n        self.fc2 = nn.Linear(linear_size, 10)\n\n    def forward(self, x):\n        x = F.hardtanh(self.conv1(x))\n        x = F.hardtanh(self.conv2(x))\n        x = torch.flatten(x, 1)\n        x = F.hardtanh(self.fc1(x))\n        x = self.fc2(x)\n\n        return x\n\nclass TestCustomVisionModel(TestVisionModels):\n    def __init__(self, methodName='runTest', model=cnn_4layer_test_hardtanh(in_ch=1, in_dim=28), generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(methodName, 'vision_clip_test_data', model, generate, device=device, dtype=dtype)\n\n    def test_bounds(self, bound_opts=None, optimize=False):\n        if bound_opts is None:\n            bound_opts = {'hardtanh': 'same-slope'}\n        super().test_bounds(bound_opts=bound_opts, optimize=optimize)\n\nif __name__ == \"__main__\":\n    t = TestCustomVisionModel()\n    t.setUp()\n    t.test_bounds()\n"
  },
  {
    "path": "tests/test_weight_perturbation.py",
    "content": "import copy\nimport subprocess\nimport numpy as np\nfrom testcase import TestCase, DEFAULT_DEVICE, DEFAULT_DTYPE \nimport sys\nsys.path.append('../examples/vision')\nimport models\nfrom auto_LiRPA import BoundedModule\nfrom auto_LiRPA.perturbations import *\n\n\nclass TestWeightPerturbation(TestCase):\n    def __init__(self, methodName='runTest', generate=False, device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n        super().__init__(\n            methodName, seed=1234,\n            ref_name='weight_perturbation_test_data', generate=generate,\n            device=device, dtype=dtype)\n        self.result = {}\n\n    def test_training(self):\n        # python weight_perturbation_training.py --device cpu --scheduler_opts start=1,length=100 --num_epochs 1  --truncate_data 5\n        ret = subprocess.run(\n            ['python', 'weight_perturbation_training.py',\n            '--device', 'cpu',\n            '--scheduler_opts', 'start=1,length=100',\n            '--num_epochs',  '1',\n            '--truncate_data', '5'],\n            cwd='../examples/vision', capture_output=True)\n        self.assertEqual(ret.returncode, 0, ret.stderr)\n        res_test = ret.stdout.decode().split('\\n')[-2].split(' ')\n        assert abs(float(res_test[-3].split('=')[1]) - 2.246) < 0.01\n\n    def verify_bounds(self, model, x, IBP, method, forward_ret, lb_name, ub_name):\n        lb, ub = model(method_opt=\"compute_bounds\", x=(x,), IBP=IBP, method=method)\n        self.result[lb_name] = lb.detach().data.clone()\n        self.result[ub_name] = ub.detach().data.clone()\n\n        # test gradient backward propagation\n        loss = (ub - lb).abs().sum()\n        loss.backward()\n        # gradient w.r.t input only\n        grad = x.grad\n        self.result[lb_name+'_grad'] = grad.detach().data.clone()\n\n        if not self.generate:\n            assert torch.allclose(self.reference[lb_name], self.result[lb_name], 1e-4, 1e-6)\n            assert 
torch.allclose(self.reference[ub_name], self.result[ub_name], 1e-4, 1e-6)\n            assert ((self.reference[lb_name] - self.result[lb_name]).pow(2).sum() < 1e-8)\n            assert ((self.reference[ub_name] - self.result[ub_name]).pow(2).sum() < 1e-8)\n            assert torch.allclose(self.reference[lb_name+'_grad'],\n                                  self.result[lb_name + '_grad'], 1e-4, 1e-6)\n            assert ((self.reference[lb_name + '_grad']\n                     - self.result[lb_name + '_grad']).pow(2).sum() < 1e-8)\n\n    def test_perturbation(self):\n        np.random.seed(123) # FIXME This seed is inconsistent with other seeds (1234)\n\n        model_ori = models.Models['mlp_3layer_weight_perturb'](pert_weight=True, pert_bias=True).eval()\n        self.result['model'] = model_ori.state_dict()\n        self.result['data'] = torch.randn(8, 1, 28, 28)\n        model_ori.load_state_dict(self.result['model'])\n        state_dict = copy.deepcopy(model_ori.state_dict())\n        dummy_input = self.result['data'].requires_grad_()\n        inputs = (dummy_input,)\n\n        model = BoundedModule(model_ori, inputs, bound_opts={\n            'sparse_intermediate_bounds': False, 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False}, device=self.default_device)\n        forward_ret = model(dummy_input)\n        model_ori.eval()\n\n        assert torch.allclose(model_ori(dummy_input), model(dummy_input), 1e-8)\n\n        def verify_model(pert_weight=True, pert_bias=True, norm=np.inf, lb_name='', ub_name=''):\n            model_ori_ = models.Models['mlp_3layer_weight_perturb'](pert_weight=pert_weight, pert_bias=pert_bias, norm=norm).eval()\n            model_ori_.load_state_dict(state_dict)\n            model_ = BoundedModule(model_ori_, inputs, bound_opts={\n                'sparse_intermediate_bounds': False, 'sparse_conv_intermediate_bounds': False, 'sparse_intermediate_bounds_with_ibp': False})\n            
model_.ptb = model_ori.ptb\n\n            self.verify_bounds(model_, dummy_input, IBP=True, method='backward', forward_ret=forward_ret,\n                        lb_name=lb_name + '_CROWN-IBP', ub_name=ub_name + '_CROWN-IBP')  # CROWN-IBP\n            self.verify_bounds(model_, dummy_input, IBP=False, method='backward', forward_ret=forward_ret,\n                        lb_name=lb_name + '_CROWN', ub_name=ub_name + '_CROWN')  # CROWN\n\n        # Linf\n        verify_model(pert_weight=True, pert_bias=True, norm=np.inf, lb_name='l_inf_weights_bias_lb', ub_name='l_inf_weights_bias_ub')\n        verify_model(pert_weight=True, pert_bias=False, norm=np.inf, lb_name='l_inf_weights_lb', ub_name='l_inf_weights_ub')\n        verify_model(pert_weight=False, pert_bias=True, norm=np.inf, lb_name='l_inf_bias_lb', ub_name='l_inf_bias_ub')\n\n        # L2\n        verify_model(pert_weight=True, pert_bias=True, norm=2, lb_name='l_2_weights_bias_lb', ub_name='l_2_weights_bias_ub')\n        verify_model(pert_weight=True, pert_bias=False, norm=2, lb_name='l_2_weights_lb', ub_name='l_2_weights_ub')\n        verify_model(pert_weight=False, pert_bias=True, norm=2, lb_name='l_2_bias_lb', ub_name='l_2_bias_ub')\n\n        if self.generate:\n            self.save()\n\nif __name__ == '__main__':\n    from testcase import _to\n    testcase = TestWeightPerturbation(generate=False)\n    testcase.setUp()\n    testcase.reference = _to(testcase.reference, device=testcase.default_device,\n                             dtype=testcase.default_dtype)\n    testcase.test_perturbation()\n    testcase.test_training()\n"
  },
  {
    "path": "tests/testcase.py",
    "content": "import unittest\nimport random\nimport torch\nimport numpy as np\n\nDEFAULT_DEVICE = 'cpu'\nDEFAULT_DTYPE = torch.float32\n\nclass TestCase(unittest.TestCase):\n    \"\"\"Superclass for unit test cases in auto_LiRPA.\"\"\"\n\n    def __init__(self, methodName='runTest', seed=1, ref_name=None, generate=False,\n                 device=DEFAULT_DEVICE, dtype=DEFAULT_DTYPE):\n\n        super().__init__(methodName)\n\n        self.addTypeEqualityFunc(np.ndarray, '_assert_array_equal')\n        self.addTypeEqualityFunc(torch.Tensor, '_assert_tensor_equal')\n        self.rtol = 1e-5\n        self.atol = 1e-6\n        self.default_dtype = dtype\n        self.default_device = device\n        set_default_dtype_device(dtype, device)\n        self.set_seed(seed)\n        data_path = 'data_64/' if dtype == torch.float64 else 'data/'\n        self.ref_path = data_path + ref_name if ref_name else None\n        self.generate = generate\n        self.setUp()\n\n    def set_seed(self, seed):\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        random.seed(seed)\n        np.random.seed(seed)\n\n    def setUp(self):\n        \"\"\"Load the reference result if it exists.\"\"\"\n        if self.generate:\n            self.reference = None\n        else:\n            self.reference = torch.load(self.ref_path, weights_only=False) if self.ref_path else None\n                        \n    def save(self):\n        \"\"\"Save result for future comparison.\"\"\"\n        print('Saving result to', self.ref_path)\n        torch.save(self.result, self.ref_path)\n\n    def check(self):\n        \"\"\"Save or check the results.\n\n        This function can be called at the end of each test.\n        If `self.generate == True`, save results for future comparison;\n        otherwise, compare the current results `self.result` with the loaded\n        reference `self.reference`. 
Results are expected to be a list or tuple\n        of `torch.Tensor` instances.\n        \"\"\"\n        if self.generate:\n            self.save()\n        else:\n            self.result = _to(\n                self.result, device=self.default_device, dtype=self.default_dtype)\n            self.reference = _to(\n                self.reference, device=self.default_device, dtype=self.default_dtype)\n            self._assert_equal(self.result, self.reference)\n\n    def _assert_equal(self, a, b):\n        assert type(a) == type(b)\n        if isinstance(a, (list, tuple)):\n            for a_, b_ in zip(a, b):\n                self._assert_equal(a_, b_)\n        else:\n            self.assertEqual(a, b)\n\n    def _assert_array_equal(self, a, b, msg=None):\n        if not a.shape == b.shape:\n            if msg is None:\n                msg = f\"Shapes are not equal: {a.shape} {b.shape}\"\n            raise self.failureException(msg)\n        if not np.allclose(a, b, rtol=self.rtol, atol=self.atol):\n            if msg is None:\n                msg = f\"Arrays are not equal:\\n{a}\\n{b}, max diff: {np.max(np.abs(a - b))}\"\n            raise self.failureException(msg)\n\n    def _assert_tensor_equal(self, a, b, msg=None):\n        if not a.shape == b.shape:\n            if msg is None:\n                msg = f\"Shapes are not equal: {a.shape} {b.shape}\"\n            raise self.failureException(msg)\n        if not torch.allclose(a, b, rtol=self.rtol, atol=self.atol):\n            if msg is None:\n                msg = f\"Tensors are not equal:\\n{a}\\n{b}, max diff: {torch.max(torch.abs(a - b))}\"\n            raise self.failureException(msg)\n\n\ndef _to(obj, device=None, dtype=None, inplace=False):\n    \"\"\" Move all tensors in the object to a specified dest\n    (device or dtype). 
The inplace=True option is available for dict.\"\"\"\n    if obj is None:\n        return obj\n    elif isinstance(obj, torch.Tensor):\n        return obj.to(device=device if device is not None else obj.device,\n                      dtype=dtype if dtype is not None else obj.dtype)\n    elif isinstance(obj, tuple):\n        return tuple([_to(item, device=device, dtype=dtype) for item in obj])\n    elif isinstance(obj, list):\n        return [_to(item, device=device, dtype=dtype) for item in obj]\n    elif isinstance(obj, dict):\n        if inplace:\n            for k, v in obj.items():\n                obj[k] = _to(v, device=device, dtype=dtype, inplace=True)\n            return obj\n        else:\n            return {k: _to(v, device=device, dtype=dtype) for k, v in obj.items()}\n    else:\n        raise NotImplementedError(f\"Unsupported type: {type(obj)}\")\n\n\ndef set_default_dtype_device(dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE):\n    \"\"\"Utility function to set default dtype and device.\"\"\"\n    torch.set_default_dtype(dtype)\n    torch.set_default_device(torch.device(device))\n\n\n__all__ = ['TestCase', 'DEFAULT_DEVICE',\n           'DEFAULT_DTYPE', '_to', 'set_default_dtype_device']\n"
  }
]