Full Code of openmedlab/XrayPULSE for AI

Repository: openmedlab/XrayPULSE
Branch: main
Commit: 530a0c013e6d
Files: 49
Total size: 276.2 KB

Directory structure:
XrayPULSE/

├── README.md
├── demo.py
├── demo_configs/
│   └── xraypulse_demo.yaml
├── env.yml
├── prompts/
│   └── alignment.txt
├── run_demo.sh
└── xraypulse/
    ├── __init__.py
    ├── common/
    │   ├── __init__.py
    │   ├── config.py
    │   ├── dist_utils.py
    │   ├── gradcam.py
    │   ├── logger.py
    │   ├── optims.py
    │   ├── registry.py
    │   └── utils.py
    ├── configs/
    │   ├── datasets/
    │   │   ├── mimic/
    │   │   │   └── defaults.yaml
    │   │   └── openi/
    │   │       └── defaults.yaml
    │   ├── default.yaml
    │   └── models/
    │       └── xraypulse.yaml
    ├── conversation/
    │   ├── __init__.py
    │   └── conversation.py
    ├── datasets/
    │   ├── __init__.py
    │   ├── builders/
    │   │   ├── __init__.py
    │   │   ├── base_dataset_builder.py
    │   │   └── image_text_pair_builder.py
    │   ├── data_utils.py
    │   └── datasets/
    │       ├── __init__.py
    │       ├── base_dataset.py
    │       ├── caption_datasets.py
    │       ├── dataloader_utils.py
    │       ├── mimic_dataset.py
    │       └── openi_dataset.py
    ├── models/
    │   ├── Qformer.py
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── blip2.py
    │   ├── blip2_outputs.py
    │   ├── eva_vit.py
    │   ├── pos_embed.py
    │   └── xray_pulse.py
    ├── processors/
    │   ├── __init__.py
    │   ├── base_processor.py
    │   ├── blip_processors.py
    │   └── randaugment.py
    ├── runners/
    │   ├── __init__.py
    │   └── runner_base.py
    └── tasks/
        ├── __init__.py
        ├── base_task.py
        └── image_text_pretrain.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# XrayPULSE



<div align="center">
    <a href="https://"><img width="1000px" height="auto" src="./banner.png"></a>
</div>


---



## Key Features

This repository provides the official implementation of XrayPULSE:

- An attempt to extend [PULSE](https://huggingface.co/OpenMEDLab/PULSE-7bv5) to a biomedical multimodal conversational assistant.
- XrayPULSE is fine-tuned on Chinese Xray-report paired datasets.


## Details

Our model is based on PULSE. We use [MedCLIP](https://github.com/RyanWangZf/MedCLIP) as the medical visual encoder, with a Q-Former ([BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2)) followed by a simple linear transformation serving as the adapter that injects the image into PULSE. To align the frozen visual encoder with the LLM through this adapter, we generate Chinese Xray-report paired data from the free-text radiology reports of two datasets ([MIMIC-CXR](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and [OpenI](https://openi.nlm.nih.gov/faq#collection)) with the help of ChatGPT. To facilitate research in biomedical multimodal learning, we will release the data to the public.


<div align="center">
    <a href="https://"><img width="1000px" height="auto" src="./framework.png"></a>
</div>
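
For intuition, here is a minimal sketch of the adapter path described above (frozen visual features → Q-Former queries → linear projection into the LLM embedding space). All module and dimension choices below are illustrative placeholders, not the actual configuration; the real implementation lives in `xraypulse/models/xray_pulse.py`.

```python
import torch
import torch.nn as nn

class AdapterSketch(nn.Module):
    """Toy stand-in for the MedCLIP -> Q-Former -> linear adapter -> PULSE path."""

    def __init__(self, vis_dim=768, llm_dim=4096, num_query_tokens=32):
        super().__init__()
        # learned query tokens, as in BLIP-2's Q-Former
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, vis_dim))
        # single cross-attention layer standing in for the full Q-Former stack
        self.cross_attn = nn.MultiheadAttention(vis_dim, num_heads=8, batch_first=True)
        # the "simple linear transformation" into the LLM embedding space
        self.llm_proj = nn.Linear(vis_dim, llm_dim)

    def forward(self, image_embeds):  # (B, num_patches, vis_dim) from the frozen encoder
        queries = self.query_tokens.expand(image_embeds.size(0), -1, -1)
        queries, _ = self.cross_attn(queries, image_embeds, image_embeds)
        # (B, num_query_tokens, llm_dim): prepended to the text embeddings fed to PULSE
        return self.llm_proj(queries)

feats = torch.randn(2, 196, 768)   # fake frozen-encoder output for 2 images
prefix = AdapterSketch()(feats)    # (2, 32, 4096)
```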



## Get Started

**Installation**

```bash
git clone https://github.com/openmedlab/XrayPULSE.git
cd XrayPULSE
```

**Environment**

```bash
conda env create -f env.yml
conda activate xraypulse
```

**Prepare the pretrained weights**

Download the pretrained model weights:

- [PULSE\_Model](https://huggingface.co/OpenMEDLab/PULSE-7bv5) 
- [Pretrained_XrayPULSE_Checkpoint](https://drive.google.com/file/d/1VsO61-3DFuK4ysGPvoD4_JZaRFKvAJR_/view?usp=drive_link)

The PULSE weights should be placed in a single folder with a structure similar to the following:

```
pulse_weights
├── config.json
├── generation_config.json
├── tokenizer.json
├── tokenizer_config.json
├── special_tokens_map.json 
├── pytorch_model.bin.index.json
├── pytorch_model-00001-of-00002.bin
└── pytorch_model-00002-of-00002.bin
```

Then set the path of the PULSE weights as the value of "bloom_model" in the model config file "xraypulse/configs/models/xraypulse.yaml", and add the path of the pretrained XrayPULSE checkpoint as "ckpt" in "demo_configs/xraypulse_demo.yaml".
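
For example, with placeholder paths (only the two keys to edit are shown; leave the surrounding entries in those files unchanged):

```yaml
# xraypulse/configs/models/xraypulse.yaml
model:
  bloom_model: "/path/to/pulse_weights"

# demo_configs/xraypulse_demo.yaml
model:
  ckpt: "/path/to/XrayPULSE_ckpt.pth"
```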

**Run Demo**

```bash
bash run_demo.sh
```



## 🙏 Acknowledgement
This project is built upon the giant shoulders of [XrayGPT](https://github.com/mbzuai-oryx/XrayGPT). Great thanks to it!

We use the medical-aware image encoder from [MedCLIP](https://github.com/RyanWangZf/MedCLIP).

The model architecture of XrayPULSE follows [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2).


## 🛡️ License

This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.


================================================
FILE: demo.py
================================================
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from xraypulse.common.config import Config
from xraypulse.common.dist_utils import get_rank
from xraypulse.common.registry import registry
from xraypulse.conversation.conversation import Chat, CONV_ZH

# imports modules for registration
from xraypulse.datasets.builders import *
from xraypulse.models import *
from xraypulse.processors import *
from xraypulse.runners import *
from xraypulse.tasks import *


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the config; key-value pairs "
        "in xxx=yyy format are merged into the config file "
        "(deprecated, use --cfg-options instead).",
    )
    args = parser.parse_args()
    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


# ========================================
#             Model Initialization
# ========================================

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
print(model_config)
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
print(model_cls)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

vis_processor_cfg = cfg.datasets_cfg.openi.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

# ========================================
#             Gradio Setting
# ========================================

def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='请先上传图片', interactive=False),gr.update(value="上传图片并开始咨询", interactive=True), chat_state, img_list

def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
    chat_state = CONV_ZH.copy()
    img_list = []
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='输入问题'), gr.update(value="开始对话", interactive=False), chat_state, img_list

def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              num_beams=num_beams,
                              temperature=temperature,
                              max_new_tokens=300,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list

title = """<h1 align="center"> XrayPULSE </h1>"""
description = """<h3>上传X光影像,开始诊断咨询</h3>"""
disclaimer = """ 
            <h1 >使用说明:</h1>
            <ul> 
                <li>XrayPULSE为PULSE在医疗多模态领域的扩展应用之一,可以用于对X光影像进行医学诊断分析,辅助医生,并为患者提供诊断支持。</li>
                <li>XrayPULSE尝试通过分析X光影像提供准确和有用的结果。然而,我们对所提供结果的有效性、可靠性或完整性不作任何明确的保证或陈述。我们需要不断改善和完善服务,为医疗专业人员提供最好的协助</li>
            </ul>
            <hr> 
            <h3 align="center">OpenMedLab</h3>

            """

def set_example_xray(example: list) -> dict:
    return gr.Image.update(value=example[0])


def set_example_text_input(example_text: list) -> dict:
    return gr.Textbox.update(value=example_text[0])

#TODO show examples below

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="上传影像并开始咨询", interactive=True, variant="primary")
            clear = gr.Button("重置")
            
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers",
            )
            
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='XrayPULSE')
            text_input = gr.Textbox(label='用户', placeholder='请上传X光影像', interactive=False)


    with gr.Row():
        example_xrays = gr.Dataset(components=[image], label="X光影像范例",
                                    samples=[
                                        [os.path.join(os.path.dirname(__file__), "images/image1.png")],
                                        [os.path.join(os.path.dirname(__file__), "images/image2.png")],
                                        [os.path.join(os.path.dirname(__file__), "images/image3.png")],
                                        [os.path.join(os.path.dirname(__file__), "images/image4.png")],
                                        [os.path.join(os.path.dirname(__file__), "images/image5.png")],
                                        [os.path.join(os.path.dirname(__file__), "images/image6.png")],
                                    ])
        

    with gr.Row():
        example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],
                                    label="咨询问题范例",
                                    samples=[
                                        ["详细描述所给的胸部X光影像。"],
                                        ["请观察这张胸部X光影像,并阐述你的发现和总结。"],
                                        ["你能否对所给的胸部X光影像进行详细的描述?"],
                                        ["尽可能详细地描述所给的胸部X光影像。"],
                                        ["这张胸部X光影像中的关键症状是什么?"],
                                        ["你能在这张胸部X光影像中,指出存在的任何异常或需要注意的地方吗"],
                                        ["这张胸部X光影像中,有哪些肺部和心脏的具体特征可见?"],
                                        ["在这张胸部X光影像中,最显著的特征是什么,它是如何反映出病人的健康状况?"],
                                        ["根据从这张胸部X光影像中观察到的发现,给出影像的总体印象是正常还是异常?"],
                                    ],)
    
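    # Event wiring: clicking an example X-ray copies it into the image widget;
    # clicking an example question fills the textbox, submits it through
    # gradio_ask, and the follow-up callback generates the reply through gradio_answer.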
    example_xrays.click(fn=set_example_xray, inputs=example_xrays, outputs=example_xrays.components)

    upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
    
    click_response = example_texts.click(set_example_text_input, inputs=example_texts, outputs=text_input).then(
        gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state], queue=False)
    click_response.then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list], queue=False
    )
    
    submit_response = text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state], queue=False)
    submit_response.then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list], queue=False
    )
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
    
    gr.Markdown(disclaimer)
demo.launch(share=True, enable_queue=True)


================================================
FILE: demo_configs/xraypulse_demo.yaml
================================================
model:
  arch: xray_pulse
  model_type: pulse
  freeze_vit: True
  freeze_qformer: True
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_path: "prompts/alignment.txt"
  prompt_template: 'Instructions: You are PULSE, a large language model trained by SHAIlab. Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2022-02-01</s> User: {} </s> Helper: '
  ckpt: './XrayPULSE_ckpt.pth'

datasets:
  openi:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain


================================================
FILE: env.yml
================================================
name: xraypulse
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotlipy=0.7.0=py39h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.01.10=h06a4308_0
  - certifi=2022.12.7=py39h06a4308_0
  - cffi=1.15.1=py39h5eee18b_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=39.0.1=py39h9ce1e76_0
  - cudatoolkit=11.3.1=h2bc3f7f_2
  - ffmpeg=4.3=hf484d3e_0
  - flit-core=3.8.0=py39h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libdeflate=1.17=h5eee18b_0
  - libffi=3.4.2=h6a678d5_6
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.2=h7f8727e_0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.0=h6a678d5_2
  - libunistring=0.9.10=h27cfd23_0
  - libwebp=1.2.4=h11a3e52_1
  - libwebp-base=1.2.4=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py39h7f8727e_0
  - mkl_fft=1.3.1=py39hd3c417c_0
  - mkl_random=1.2.2=py39h51133e4_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - numpy=1.23.5=py39h14f4228_0
  - numpy-base=1.23.5=py39h31eccc5_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1t=h7f8727e_0
  - pillow=9.4.0=py39h6a678d5_0
  - pip=23.0.1=py39h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.0.0=py39h06a4308_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.16=h7a1cb2a_2
  - pytorch-mutex=1.0=cuda
  - readline=8.2=h5eee18b_0
  - requests=2.28.1=py39h06a4308_1
  - setuptools=66.0.0=py39h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=0.12.1=py39_cu113
  - torchvision=0.13.1=py39_cu113
  - typing_extensions=4.4.0=py39h06a4308_0
  - urllib3=1.26.15=py39h06a4308_0
  - wheel=0.38.4=py39h06a4308_0
  - xz=5.2.10=h5eee18b_1
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
      - accelerate==0.15.0
      - aiofiles==23.1.0
      - aiohttp==3.8.4
      - aiosignal==1.3.1
      - albumentations==1.3.0
      - altair==4.2.2
      - antlr4-python3-runtime==4.9.3
      - anyio==3.6.2
      - appdirs==1.4.4
      - argon2-cffi==21.3.0
      - argon2-cffi-bindings==21.2.0
      - arrow==1.2.3
      - asttokens==2.2.1
      - async-timeout==4.0.2
      - attrs==22.2.0
      - backcall==0.2.0
      - beautifulsoup4==4.12.2
      - bitsandbytes==0.37.0
      - bleach==6.0.0
      - blis==0.7.9
      - braceexpand==0.1.7
      - cachetools==5.3.0
      - catalogue==2.0.8
      - cchardet==2.1.7
      - chardet==3.0.4
      - click==8.1.3
      - cmake==3.26.3
      - comm==0.1.3
      - confection==0.0.4
      - contourpy==1.0.7
      - cycler==0.11.0
      - cymem==2.0.7
      - dataclasses==0.6
      - datasets==2.12.0
      - debugpy==1.6.7
      - decorator==5.1.1
      - decord==0.6.0
      - defusedxml==0.7.1
      - dill==0.3.6
      - docker-pycreds==0.4.0
      - entrypoints==0.4
      - et-xmlfile==1.1.0
      - evaluate==0.4.0
      - executing==1.2.0
      - exifread-nocycle==3.0.1
      - fairscale==0.4.13
      - fastapi==0.95.1
      - fastchat==0.1
      - fastjsonschema==2.16.3
      - ffmpy==0.3.0
      - filelock==3.9.0
      - fire==0.5.0
      - fonttools==4.38.0
      - fqdn==1.5.1
      - frozenlist==1.3.3
      - fschat==0.2.3
      - fsspec==2022.11.0
      - gensim==4.3.1
      - gitdb==4.0.10
      - gitpython==3.1.31
      - googletrans==3.0.0
      - gradio==3.23.0
      - gradio-client==0.0.8
      - h11==0.9.0
      - h2==3.2.0
      - hiq-python==1.1.12
      - hpack==3.0.0
      - hstspreload==2023.1.1
      - httpcore==0.9.1
      - httpx==0.13.3
      - huggingface-hub==0.13.4
      - hyperframe==5.2.0
      - idna==2.10
      - imageio==2.27.0
      - img2dataset==1.25.4
      - importlib-metadata==6.5.0
      - importlib-resources==5.12.0
      - iopath==0.1.10
      - ipykernel==6.22.0
      - ipython==8.12.0
      - ipython-genutils==0.2.0
      - isoduration==20.11.0
      - jedi==0.18.2
      - jinja2==3.1.2
      - joblib==1.2.0
      - jsonpointer==2.3
      - jsonschema==4.17.3
      - jupyter-client==8.2.0
      - jupyter-core==5.3.0
      - jupyter-events==0.6.3
      - jupyter-server==2.5.0
      - jupyter-server-terminals==0.4.4
      - jupyterlab-pygments==0.2.2
      - kiwisolver==1.4.4
      - langcodes==3.3.0
      - lazy-loader==0.2
      - linkify-it-py==2.0.0
      - lit==16.0.1
      - llvmlite==0.39.1
      - markdown-it-py==2.2.0
      - markdown2==2.4.8
      - markupsafe==2.1.2
      - matplotlib==3.7.0
      - matplotlib-inline==0.1.6
      - mdit-py-plugins==0.3.3
      - mdurl==0.1.2
      - medclip==0.0.3
      - mistune==2.0.5
      - mpmath==1.3.0
      - multidict==6.0.4
      - multiprocess==0.70.14
      - murmurhash==1.0.9
      - nbclassic==0.5.5
      - nbclient==0.7.3
      - nbconvert==7.3.1
      - nbformat==5.8.0
      - nest-asyncio==1.5.6
      - networkx==3.1
      - nltk==3.8.1
      - notebook==6.5.4
      - notebook-shim==0.2.2
      - numba==0.56.4
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cuda-cupti-cu11==11.7.101
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cudnn-cu11==8.5.0.96
      - nvidia-cufft-cu11==10.9.0.58
      - nvidia-curand-cu11==10.2.10.91
      - nvidia-cusolver-cu11==11.4.0.1
      - nvidia-cusparse-cu11==11.7.4.91
      - nvidia-nccl-cu11==2.14.3
      - nvidia-nvtx-cu11==11.7.91
      - omegaconf==2.3.0
      - openai==0.27.0
      - opencv-python==4.7.0.72
      - opencv-python-headless==4.7.0.72
      - openpyxl==3.1.2
      - orjson==3.8.10
      - packaging==23.0
      - pandas==1.5.3
      - pandocfilters==1.5.0
      - parso==0.8.3
      - pathtools==0.1.2
      - pathy==0.10.1
      - peft==0.2.0
      - pexpect==4.8.0
      - pickleshare==0.7.5
      - platformdirs==3.2.0
      - portalocker==2.7.0
      - preshed==3.0.8
      - prometheus-client==0.16.0
      - promise==2.3
      - prompt-toolkit==3.0.38
      - protobuf==3.20.3
      - psutil==5.9.4
      - ptyprocess==0.7.0
      - pure-eval==0.2.2
      - py-itree==0.0.19
      - pyarrow==12.0.1
      - pydantic==1.10.7
      - pydub==0.25.1
      - pygments==2.15.1
      - pyllama==0.0.9
      - pynndescent==0.5.9
      - pyparsing==3.0.9
      - pyrsistent==0.19.3
      - python-dateutil==2.8.2
      - python-json-logger==2.0.7
      - python-multipart==0.0.6
      - pytz==2023.3
      - pywavelets==1.4.1
      - pyyaml==6.0
      - pyzmq==25.0.2
      - qudida==0.0.4
      - regex==2022.10.31
      - responses==0.18.0
      - rfc3339-validator==0.1.4
      - rfc3986==1.5.0
      - rfc3986-validator==0.1.1
      - rich==13.3.4
      - scikit-image==0.20.0
      - scikit-learn==1.2.2
      - scipy==1.9.1
      - semantic-version==2.10.0
      - send2trash==1.8.0
      - sentence-transformers==2.2.2
      - sentencepiece==0.1.97
      - sentry-sdk==1.19.1
      - setproctitle==1.3.2
      - shortuuid==1.0.11
      - smart-open==6.3.0
      - smmap==5.0.0
      - sniffio==1.3.0
      - soupsieve==2.4.1
      - spacy==3.5.1
      - spacy-legacy==3.0.12
      - spacy-loggers==1.0.4
      - srsly==2.4.6
      - stack-data==0.6.2
      - starlette==0.26.1
      - svgwrite==1.4.3
      - sympy==1.11.1
      - tenacity==8.2.2
      - termcolor==2.2.0
      - terminado==0.17.1
      - textaugment==1.3.4
      - textblob==0.17.1
      - thinc==8.1.9
      - threadpoolctl==3.1.0
      - tifffile==2023.4.12
      - timm==0.6.13
      - tinycss2==1.2.1
      - tokenizers==0.13.2
      - toolz==0.12.0
      - torch==2.0.0
      - tornado==6.3
      - tqdm==4.64.1
      - traitlets==5.9.0
      - transformers==4.29.0
      - triton==2.0.0
      - typer==0.7.0
      - tzdata==2023.3
      - uc-micro-py==1.0.1
      - umap-learn==0.5.3
      - uri-template==1.2.0
      - uvicorn==0.21.1
      - wandb==0.12.21
      - wasabi==1.1.1
      - wavedrom==2.0.3.post3
      - wcwidth==0.2.6
      - webcolors==1.13
      - webdataset==0.2.48
      - webencodings==0.5.1
      - websocket-client==1.5.1
      - websockets==11.0.2
      - wget==3.2
      - xxhash==3.2.0
      - yarl==1.8.2
      - zipp==3.14.0


================================================
FILE: prompts/alignment.txt
================================================
<Img><图片></Img> 详细描述所给的胸部X光影像。
<Img><图片></Img> 请观察这张胸部X光影像,并阐述你的发现和总结。
<Img><图片></Img> 你能否对所给的胸部X光影像进行详细的描述?
<Img><图片></Img> 尽可能详细地描述所给的胸部X光影像。
<Img><图片></Img> 这张胸部X光影像中的关键症状是什么?
<Img><图片></Img> 你能在这张胸部X光影像中,指出存在的任何异常或需要注意的地方吗?
<Img><图片></Img> 这张胸部X光影像中,有哪些肺部和心脏的具体特征可见?
<Img><图片></Img> 在这张胸部X光影像中,最显著的特征是什么,它是如何反映出病人的健康状况?
<Img><图片></Img> 这张胸部X光影像提供了哪些观察发现和总体印象?
<Img><图片></Img> 这张胸部X光影像中,心脏的大小和形状如何?
<Img><图片></Img> 根据从这张胸部X光影像中观察到的发现,给出影像的总体印象是正常还是异常?
<Img><图片></Img> 在这张胸部X光影像中,有无感染或炎症的迹象?如果有,可能的原因是什么?
<Img><图片></Img> 根据这张胸部X光影像中的发现,请你给出总体印象。
<Img><图片></Img> 在这张胸部X光影像中,有没有患者淋巴结肿大或异常的可见迹象?
<Img><图片></Img> 这张胸部X光影像中观察到的异常有没有可能引发的并发症或风险?或者说,这张X光影像所展示的患者是正常的吗?

================================================
FILE: run_demo.sh
================================================
CUDA_VISIBLE_DEVICES=0 python -u demo.py --cfg-path demo_configs/xraypulse_demo.yaml  --gpu-id 0

================================================
FILE: xraypulse/__init__.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import sys

from omegaconf import OmegaConf

from xraypulse.common.registry import registry

from xraypulse.datasets.builders import *
from xraypulse.models import *
from xraypulse.processors import *
from xraypulse.tasks import *


root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])


================================================
FILE: xraypulse/common/__init__.py
================================================


================================================
FILE: xraypulse/common/config.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import json
from typing import Dict

from omegaconf import OmegaConf
from xraypulse.common.registry import registry


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register the config and configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        dataset_config = self.build_dataset_config(config)

        # Validate the user-provided runner configuration
        # model and dataset configuration are supposed to be validated by the respective classes
        # [TODO] validate the model/dataset configuration
        # self._validate_runner_config(runner_config)

        # Override the default configuration with user options.
        self.config = OmegaConf.merge(
            runner_config, model_config, dataset_config, user_config
        )

    def _validate_runner_config(self, runner_config):
        """
        This method validates the configuration, such that
            1) all the user specified options are valid;
            2) no type mismatches between the user specified options and the config.
        """
        runner_config_validator = create_runner_config_validator()
        runner_config_validator.validate(runner_config)

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        model_type = kwargs.get("model.model_type", None)
        if not model_type:
            model_type = model.get("model_type", None)
        # else use the model type selected by user.

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        model_config = OmegaConf.create()
        # hierarchy override, customized config > default config
        model_config = OmegaConf.merge(
            model_config,
            OmegaConf.load(model_config_path),
            {"model": config["model"]},
        )

        return model_config

    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}

    @staticmethod
    def build_dataset_config(config):
        datasets = config.get("datasets", None)
        if datasets is None:
            raise KeyError(
                "Expecting 'datasets' as the root key for dataset configuration."
            )

        dataset_config = OmegaConf.create()

        for dataset_name in datasets:
            builder_cls = registry.get_builder_class(dataset_name)

            dataset_config_type = datasets[dataset_name].get("type", "default")
            dataset_config_path = builder_cls.default_config_path(
                type=dataset_config_type
            )

            # hierarchy override, customized config > default config
            dataset_config = OmegaConf.merge(
                dataset_config,
                OmegaConf.load(dataset_config_path),
                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
            )

        return dataset_config

    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

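        # pair flat "key value" options into dot-list form,
        # e.g. ["run.seed", "42"] -> ["run.seed=42"]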
        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    def get_config(self):
        return self.config

    @property
    def run_cfg(self):
        return self.config.run

    @property
    def datasets_cfg(self):
        return self.config.datasets

    @property
    def model_cfg(self):
        return self.config.model

    def pretty_print(self):
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        datasets = self.config.datasets

        for dataset in datasets:
            if dataset in self.config.datasets:
                logging.info(f"\n======== {dataset} =======")
                dataset_config = self.config.datasets[dataset]
                logging.info(self._convert_node_to_json(dataset_config))
            else:
                logging.warning(f"No dataset named '{dataset}' in config. Skipping")

        logging.info(f"\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        return OmegaConf.to_container(self.config)


def node_to_dict(node):
    return OmegaConf.to_container(node)


class ConfigValidator:
    """
    This is a preliminary implementation to centralize and validate the configuration.
    May be altered in the future.

    A helper class to validate configurations from yaml file.

    This serves the following purposes:
        1. Ensure all the options in the yaml are defined; raise an error if not.
        2. Raise an error when type mismatches are found.
        3. Provide a central place to store and display helpful messages for supported configurations.

    """

    class _Argument:
        def __init__(self, name, choices=None, type=None, help=None):
            self.name = name
            self.val = None
            self.choices = choices
            self.type = type
            self.help = help

        def __str__(self):
            s = f"{self.name}={self.val}"
            if self.type is not None:
                s += f", ({self.type})"
            if self.choices is not None:
                s += f", choices: {self.choices}"
            if self.help is not None:
                s += f", ({self.help})"
            return s

    def __init__(self, description):
        self.description = description

        self.arguments = dict()

        self.parsed_args = None

    def __getitem__(self, key):
        assert self.parsed_args is not None, "No arguments parsed yet."

        return self.parsed_args[key]

    def __str__(self) -> str:
        return self.format_help()

    def add_argument(self, *args, **kwargs):
        """
        Assume the first argument is the name of the argument.
        """
        self.arguments[args[0]] = self._Argument(*args, **kwargs)

    def validate(self, config=None):
        """
        Validate the given config (dict-like) against the registered arguments.
        """
        for k, v in config.items():
            assert (
                k in self.arguments
            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""

            if self.arguments[k].type is not None:
                try:
                    self.arguments[k].val = self.arguments[k].type(v)
                except ValueError:
                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")

            if self.arguments[k].choices is not None:
                assert (
                    v in self.arguments[k].choices
                ), f"""{k} must be one of {self.arguments[k].choices}."""

        return config

    def format_arguments(self):
        return str([f"{k}" for k in sorted(self.arguments.keys())])

    def format_help(self):
        # description + key-value pair string for each argument
        help_msg = str(self.description)
        return help_msg + ", available arguments: " + self.format_arguments()

    def print_help(self):
        # display help message
        print(self.format_help())

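
# Example usage (hypothetical values):
#   validator = ConfigValidator(description="demo")
#   validator.add_argument("seed", type=int, help="random seed")
#   validator.validate({"seed": 42})  # raises on unknown keys or type mismatches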

def create_runner_config_validator():
    validator = ConfigValidator(description="Runner configurations")

    validator.add_argument(
        "runner",
        type=str,
        choices=["runner_base", "runner_iter"],
        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
            runner runs based on iters. Default: runner_base""",
    )
    # add arguments for training dataset ratios
    validator.add_argument(
        "train_dataset_ratios",
        type=Dict[str, float],
        help="""Ratios of training dataset. This is used in iteration-based runner.
        Do not support for epoch-based runner because how to define an epoch becomes tricky.
        Default: None""",
    )
    validator.add_argument(
        "max_iters",
        type=float,
        help="Maximum number of iterations to run.",
    )
    validator.add_argument(
        "max_epoch",
        type=int,
        help="Maximum number of epochs to run.",
    )
    # add arguments for iters_per_inner_epoch
    validator.add_argument(
        "iters_per_inner_epoch",
        type=float,
        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
    )
    lr_scheds_choices = registry.list_lr_schedulers()
    validator.add_argument(
        "lr_sched",
        type=str,
        choices=lr_scheds_choices,
        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
    )
    task_choices = registry.list_tasks()
    validator.add_argument(
        "task",
        type=str,
        choices=task_choices,
        help="Task to use, from {}".format(task_choices),
    )
    # add arguments for init_lr
    validator.add_argument(
        "init_lr",
        type=float,
        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
    )
    # add arguments for min_lr
    validator.add_argument(
        "min_lr",
        type=float,
        help="Minimum learning rate (after decay).",
    )
    # add arguments for warmup_lr
    validator.add_argument(
        "warmup_lr",
        type=float,
        help="Starting learning rate for warmup.",
    )
    # add arguments for learning rate decay rate
    validator.add_argument(
        "lr_decay_rate",
        type=float,
        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
    )
    # add arguments for weight decay
    validator.add_argument(
        "weight_decay",
        type=float,
        help="Weight decay rate.",
    )
    # add arguments for training batch size
    validator.add_argument(
        "batch_size_train",
        type=int,
        help="Training batch size.",
    )
    # add arguments for evaluation batch size
    validator.add_argument(
        "batch_size_eval",
        type=int,
        help="Evaluation batch size, including validation and testing.",
    )
    # add arguments for number of workers for data loading
    validator.add_argument(
        "num_workers",
        help="Number of workers for data loading.",
    )
    # add arguments for warm up steps
    validator.add_argument(
        "warmup_steps",
        type=int,
        help="Number of warmup steps. Required if a warmup schedule is used.",
    )
    # add arguments for random seed
    validator.add_argument(
        "seed",
        type=int,
        help="Random seed.",
    )
    # add arguments for output directory
    validator.add_argument(
        "output_dir",
        type=str,
        help="Output directory to save checkpoints and logs.",
    )
    # add arguments for whether only use evaluation
    validator.add_argument(
        "evaluate",
        help="Whether to only evaluate the model. If true, training will not be performed.",
    )
    # add arguments for splits used for training, e.g. ["train", "val"]
    validator.add_argument(
        "train_splits",
        type=list,
        help="Splits to use for training.",
    )
    # add arguments for splits used for validation, e.g. ["val"]
    validator.add_argument(
        "valid_splits",
        type=list,
        help="Splits to use for validation. If not provided, will skip the validation.",
    )
    # add arguments for splits used for testing, e.g. ["test"]
    validator.add_argument(
        "test_splits",
        type=list,
        help="Splits to use for testing. If not provided, will skip the testing.",
    )
    # add arguments for accumulating gradient for iterations
    validator.add_argument(
        "accum_grad_iters",
        type=int,
        help="Number of iterations to accumulate gradient for.",
    )

    # ====== distributed training ======
    validator.add_argument(
        "device",
        type=str,
        choices=["cpu", "cuda"],
        help="Device to use. Support 'cuda' or 'cpu' as for now.",
    )
    validator.add_argument(
        "world_size",
        type=int,
        help="Number of processes participating in the job.",
    )
    validator.add_argument("dist_url", type=str)
    validator.add_argument("distributed", type=bool)
    # add arguments to opt using distributed sampler during evaluation or not
    validator.add_argument(
        "use_dist_eval_sampler",
        type=bool,
        help="Whether to use distributed sampler during evaluation or not.",
    )

    # ====== task specific ======
    # generation task specific arguments
    # add arguments for maximal length of text output
    validator.add_argument(
        "max_len",
        type=int,
        help="Maximal length of text output.",
    )
    # add arguments for minimal length of text output
    validator.add_argument(
        "min_len",
        type=int,
        help="Minimal length of text output.",
    )
    # add arguments number of beams
    validator.add_argument(
        "num_beams",
        type=int,
        help="Number of beams used for beam search.",
    )

    # vqa task specific arguments
    # add arguments for number of answer candidates
    validator.add_argument(
        "num_ans_candidates",
        type=int,
        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
    )
    # add arguments for inference method
    validator.add_argument(
        "inference_method",
        type=str,
        choices=["genearte", "rank"],
        help="""Inference method to use for question answering. If rank, requires a answer list.""",
    )

    # ====== model specific ======
    validator.add_argument(
        "k_test",
        type=int,
        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
    )

    return validator


================================================
FILE: xraypulse/common/dist_utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(
            days=365
        ),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


================================================
FILE: xraypulse/common/gradcam.py
================================================
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
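    """Overlay a normalized attention map on an image with the jet colormap.

    Assumes (not enforced) that `img` is an (H, W, 3) float array in [0, 1]
    and `attMap` is a 2-D array; returns the blended overlay when overlap=True.
    """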
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap


================================================
FILE: xraypulse/common/logger.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from xraypulse.common import dist_utils


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )


================================================
FILE: xraypulse/common/optims.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from xraypulse.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )


@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )


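# Cosine decay: lr(t) = min_lr + (init_lr - min_lr) * 0.5 * (1 + cos(pi * t / max_epoch))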
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


================================================
FILE: xraypulse/common/registry.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_builder(cls, name):
        r"""Register a dataset builder to registry with key 'name'

        Args:
            name: Key with which the builder will be registered.

        Usage:

            from xraypulse.common.registry import registry
            from xraypulse.datasets.builders.base_dataset_builder import BaseDatasetBuilder
        """

        def wrap(builder_cls):
            from xraypulse.datasets.builders.base_dataset_builder import BaseDatasetBuilder

            assert issubclass(
                builder_cls, BaseDatasetBuilder
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
            if name in cls.mapping["builder_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["builder_name_mapping"][name]
                    )
                )
            cls.mapping["builder_name_mapping"][name] = builder_cls
            return builder_cls

        return wrap

    @classmethod
    def register_task(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from xraypulse.common.registry import registry
        """

        def wrap(task_cls):
            from xraypulse.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_model(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from xraypulse.common.registry import registry
        """

        def wrap(model_cls):
            from xraypulse.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the processor will be registered.

        Usage:

            from xraypulse.common.registry import registry
        """

        def wrap(processor_cls):
            from xraypulse.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from xraypulse.common.registry import registry
        """

        def wrap(lr_sched_cls):
            if name in cls.mapping["lr_scheduler_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["lr_scheduler_name_mapping"][name]
                    )
                )
            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from xraypulse.common.registry import registry
        """

        def wrap(runner_cls):
            if name in cls.mapping["runner_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["runner_name_mapping"][name]
                    )
                )
            cls.mapping["runner_name_mapping"][name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage:

            from xraypulse.common.registry import registry
        """
        assert isinstance(path, str), "All paths must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from xraypulse.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, no warning is generated when
                               the key doesn't exist. Useful for internal
                               operations. Default: False
        """
        original_name = name
        name = name.split(".")
        value = cls.mapping["state"]
        for subname in name:
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Args:
            name: Key which needs to be removed.
        Usage::

            from xraypulse.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()
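
# --- Illustrative usage (editorial sketch, not part of the original file) ---
# How the registry is typically exercised: decorate a class to register it
# under a string key, look it up by that key elsewhere, and stash shared
# state under (possibly dotted) names. All names below are made up for the
# demo.
if __name__ == "__main__":
    @registry.register_lr_scheduler("dummy_constant")
    class DummyConstantScheduler:
        def step(self, cur_epoch, cur_step):
            pass  # keep the lr unchanged

    assert registry.get_lr_scheduler_class("dummy_constant") is DummyConstantScheduler

    registry.register_path("demo_root", "/tmp/xraypulse_demo")
    print(registry.get_path("demo_root"))

    registry.register("run.seed", 42)  # nested under mapping["state"]
    print(registry.get("run.seed"))    # -> 42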


================================================
FILE: xraypulse/common/utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from xraypulse.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    # the trailing slice drops the last minute digit, yielding a
    # 10-minute-resolution timestamp of the form YYYYMMDDHHM
    return datetime.now().strftime("%Y%m%d%H%M")[:-1]


def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        print(f"Error creating directory: {dir_path}")
    return is_success


def get_redirected_url(url: str):
    """
    Given a URL, returns the URL it redirects to, or the
    original URL if there is no redirection.
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from Google Drive.
    Downloading a URL from Google Drive requires confirmation when
    the size of the file is too big (Google Drive notifies that
    anti-virus checks cannot be performed on such files).
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # response.read() returns bytes, so the sentinel must be b""
                # (the original "" sentinel never matches, leaving only the
                # explicit break below to end the loop)
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    if not chunk:
                        break
                    pbar.update(len(chunk))  # count actual bytes read
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.
    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
                                  If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # expand redirect chain if needed
    url = get_redirected_url(url)

    # check if file is located on Google Drive
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # download the file
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def cache_url(url: str, cache_dir: str) -> str:
    """
    This implementation downloads the remote resource and caches it locally.
    The resource will only be downloaded if not previously requested.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached


# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Create a symlink from file1 to file2.
    Useful during model checkpointing to symlink to the
    latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing in Boolean value to append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
    For .npy files, we support reading in mmap_mode.
    If reading in mmap_mode fails, we fall back to loading the data
    without mmap_mode.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path
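
# NOTE: makedir() and is_url() below redefine the helpers of the same names
# earlier in this module (both copies arrived with the torchvision/vissl-
# adapted code); being defined later, these versions are the ones bound
# when the module is imported.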


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def is_url(input_url):
    """
    Check if an input string is a URL. Looks for http(s)://, ignoring case.
    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
    return is_url


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
        logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, get the size of file in MB
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb
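
# --- Illustrative usage (editorial sketch, not part of the original file) ---
# Round-trip a dict through save_file/load_file. For ".json", save_file
# appends one JSON object per line by default, so append_to_json=False is
# passed here to keep the file a single document that json.load can parse.
if __name__ == "__main__":
    import tempfile

    demo_path = os.path.join(tempfile.mkdtemp(), "demo.json")
    save_file({"a": 1}, demo_path, append_to_json=False)
    print(load_file(demo_path))  # -> {'a': 1}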


================================================
FILE: xraypulse/configs/datasets/mimic/defaults.yaml
================================================
datasets:
  mimic:
    data_type: images
    build_info:
      storage: /mnt/petrelfs/share_data/huangzhongzhen/multimodal_pretrain/dataset/mimic


================================================
FILE: xraypulse/configs/datasets/openi/defaults.yaml
================================================
datasets:
  openi:
    data_type: images
    build_info:
      storage: /mnt/petrelfs/share_data/huangzhongzhen/multimodal_pretrain/dataset/openi


================================================
FILE: xraypulse/configs/default.yaml
================================================
env:
  # For default users
  # cache_root: "cache"
  # For internal use with persistent storage
  cache_root: "/export/home/.cache/xraypulse"


================================================
FILE: xraypulse/configs/models/xraypulse.yaml
================================================
model:
  arch: xray_pulse

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  num_query_token: 32

  # PULSE (Bloom-based LLM)
  bloom_model: "OpenMEDLab/PULSE-7bv5"

  # generation configs
  prompt: ""

preprocess:
    vis_processor:
        train:
          name: "blip2_image_train"
          image_size: 224
        eval:
          name: "blip2_image_eval"
          image_size: 224
    text_processor:
        train:
          name: "blip_caption"
        eval:
          name: "blip_caption"


================================================
FILE: xraypulse/conversation/__init__.py
================================================


================================================
FILE: xraypulse/conversation/conversation.py
================================================
import argparse
import time
from PIL import Image

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomTokenizerFast
from transformers import StoppingCriteria, StoppingCriteriaList

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any

from xraypulse.common.registry import registry


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        print('ret:')
        print(ret)
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False

CONV_ZH = Conversation(
    system="Instructions: You are PULSE, a large language model trained by SHAIlab. Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2022-02-01</s> User: {} </s> Helper: ",
        # "Please answer the medical questions based on the patient's description. Give the following medical scan: <Img>图片</Img>."
        # "You will be able to see the medical scan once I provide it to you. Please answer the patients questions.",
    roles=("User", "Helper"),
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="</s>",
    sep2="###",  
)
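
# --- Illustrative usage (editorial sketch, not part of the original file) ---
# How the template above is consumed: copy CONV_ZH, append a user turn
# containing the image placeholder plus a question, then an empty Helper
# turn, and render the prompt string that Chat.get_context_emb() later
# splits on the '<图片>' placeholder.
if __name__ == "__main__":
    conv = CONV_ZH.copy()
    conv.append_message(conv.roles[0], "<Img><图片></Img> What abnormality does this chest X-ray show?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())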


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0'):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        stop_words_ids = [torch.tensor([835]).to(self.device),
                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)

        embs = embs[:, begin_idx:]

        outputs = self.model.bloom_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.model.bloom_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('</s>')[0]  # remove the stop sign '</s>'
        output_text = output_text.split('###')[0]  # and the legacy separator '###'
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def upload_img(self, image, conv, img_list):
        if isinstance(image, str):  # is an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], "<Img><图片></Img>")
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<图片>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.bloom_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.bloom_model.transformer.word_embeddings_layernorm(self.model.bloom_model.transformer.word_embeddings(seg_t)) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
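
# --- Illustrative sketch (editorial, not part of the original file) ---
# The interleaving trick used in get_context_emb above, shown with dummy
# tensors: n+1 text-segment embeddings are zipped pairwise with n image
# embeddings and concatenated along the sequence dimension.
if __name__ == "__main__":
    seg_embs = [torch.zeros(1, 4, 8), torch.zeros(1, 5, 8)]  # two text pieces
    img_list = [torch.ones(1, 32, 8)]  # one image -> 32 query-token embeddings
    mixed = [e for pair in zip(seg_embs[:-1], img_list) for e in pair] + [seg_embs[-1]]
    print(torch.cat(mixed, dim=1).shape)  # torch.Size([1, 41, 8])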




================================================
FILE: xraypulse/datasets/__init__.py
================================================


================================================
FILE: xraypulse/datasets/builders/__init__.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from xraypulse.datasets.builders.base_dataset_builder import load_dataset_config
from xraypulse.datasets.builders.image_text_pair_builder import (
    MIMICBuilder,
    OpenIBuilder,
)
from xraypulse.common.registry import registry

__all__ = [
    "MIMICBuilder",
    "OpenIBuilder",
]


def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
    """
    Example

    >>> dataset = load_dataset("openi", cfg_path=None)
    >>> splits = dataset.keys()
    >>> print([len(dataset[split]) for split in splits])

    """
    if cfg_path is None:
        cfg = None
    else:
        cfg = load_dataset_config(cfg_path)

    try:
        builder = registry.get_builder_class(name)(cfg)
    except TypeError:
        print(
            f"Dataset {name} not found. Available datasets:\n"
            + ", ".join([str(k) for k in dataset_zoo.get_names()])
        )
        exit(1)

    if vis_path is not None:
        if data_type is None:
            # use default data type in the config
            data_type = builder.config.data_type

        assert (
            data_type in builder.config.build_info
        ), f"Invalid data_type {data_type} for {name}."

        builder.config.build_info.get(data_type).storage = vis_path

    dataset = builder.build_datasets()
    return dataset


class DatasetZoo:
    def __init__(self) -> None:
        self.dataset_zoo = {
            k: list(v.DATASET_CONFIG_DICT.keys())
            for k, v in sorted(registry.mapping["builder_name_mapping"].items())
        }

    def get_names(self):
        return list(self.dataset_zoo.keys())


dataset_zoo = DatasetZoo()
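
# --- Illustrative usage (editorial sketch, not part of the original file) ---
# With MIMICBuilder and OpenIBuilder imported above, the zoo lists their
# registry keys; load_dataset("openi") would then build the default-config
# splits, assuming the storage path in configs/datasets/openi/defaults.yaml
# exists on disk.
if __name__ == "__main__":
    print(dataset_zoo.get_names())  # e.g. ['mimic', 'openi']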


================================================
FILE: xraypulse/datasets/builders/base_dataset_builder.py
================================================
"""
 This file is from
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import xraypulse.common.utils as utils
from xraypulse.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from xraypulse.common.registry import registry
from xraypulse.processors.base_processor import BaseProcessor



class BaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

    def build_datasets(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_vis()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the vision-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # storage_path must name a file; a bare directory is ambiguous.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                    download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):

        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create datasets, keyed by split, that inherit torch.utils.data.Dataset.

        # build() can be dataset-specific. Overwrite to customize.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ["train", "val", "test"]:
                continue

            is_train = split == "train"

            # processors
            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            # annotation path
            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]

            abs_ann_paths = []
            for ann_path in ann_paths:
                if not os.path.isabs(ann_path):
                    ann_path = utils.get_cache_path(ann_path)
                abs_ann_paths.append(ann_path)
            ann_paths = abs_ann_paths

            # visual data storage path
            vis_path = os.path.join(vis_info.storage, split)

            if not os.path.isabs(vis_path):
                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            # create datasets
            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets


def load_dataset_config(cfg_path):
    cfg = OmegaConf.load(cfg_path).datasets
    cfg = cfg[list(cfg.keys())[0]]

    return cfg


================================================
FILE: xraypulse/datasets/builders/image_text_pair_builder.py
================================================
import os
import logging
import warnings

from xraypulse.common.registry import registry
from xraypulse.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from xraypulse.datasets.datasets.openi_dataset import OpenIDataset
from xraypulse.datasets.datasets.mimic_dataset import MIMICDataset


@registry.register_builder("mimic")
class MIMICBuilder(BaseDatasetBuilder):
    train_dataset_cls = MIMICDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/mimic/defaults.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build_datasets(self):
        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        datasets = dict()

        if not os.path.exists(storage_path):
            warnings.warn("storage path {} does not exist.".format(storage_path))

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(storage_path, 'zh_filter_cap.json')],
            vis_root=os.path.join(storage_path, 'image'),
        )

        return datasets


@registry.register_builder("openi")
class OpenIBuilder(BaseDatasetBuilder):
    train_dataset_cls = OpenIDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/openi/defaults.yaml"}

    def _download_ann(self):
        pass

    def _download_vis(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        datasets = dict()
        split = "train"

        # create datasets
        # [NOTE] return inner_datasets (wds.DataPipeline)
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(storage_path, 'zh_filter_cap.json')],
            vis_root=os.path.join(storage_path, 'image'),
        )

        return datasets


================================================
FILE: xraypulse/datasets/data_utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm

import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset

from xraypulse.common.registry import registry
from xraypulse.datasets.datasets.base_dataset import ConcatDataset


decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")


class ChainDataset(wds.DataPipeline):
    r"""Dataset for chaining multiple :class:`DataPipeline` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: List[wds.DataPipeline]) -> None:
        super().__init__()
        self.datasets = datasets
        self.prob = []
        self.names = []
        for dataset in self.datasets:
            if hasattr(dataset, 'name'):
                self.names.append(dataset.name)
            else:
                self.names.append('Unknown')
            if hasattr(dataset, 'sample_ratio'):
                self.prob.append(dataset.sample_ratio)
            else:
                self.prob.append(1)
                logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")

    def __iter__(self):
        datastreams = [iter(dataset) for dataset in self.datasets]
        while True:
            select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
            yield next(select_datastream)
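
# --- Illustrative sketch (editorial, not part of the original file) ---
# ChainDataset samples from its children in proportion to their
# 'sample_ratio' attributes. The toy pipeline below is hypothetical; any
# infinitely-iterable wds.DataPipeline subclass exercises the same logic.
if __name__ == "__main__":
    class _ToyPipeline(wds.DataPipeline):
        def __init__(self, name, sample_ratio):
            super().__init__()
            self.name = name
            self.sample_ratio = sample_ratio

        def __iter__(self):
            while True:
                yield self.name

    chained = ChainDataset([_ToyPipeline("mimic", 3), _ToyPipeline("openi", 1)])
    stream = iter(chained)
    print([next(stream) for _ in range(8)])  # ~3:1 mix of 'mimic'/'openi'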


def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)
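
# --- Illustrative sketch (editorial, not part of the original file) ---
# apply_to_sample maps a function over every tensor inside a nested
# dict/list sample, leaving non-tensor leaves untouched:
if __name__ == "__main__":
    sample = {"image": torch.zeros(2), "meta": [torch.ones(1), "id-0"]}
    doubled = apply_to_sample(lambda t: t * 2, sample)
    print(doubled["meta"][0], doubled["meta"][1])  # tensor([2.]) id-0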


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


def reorg_datasets_by_split(datasets):
    """
    Organizes datasets by split.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by name.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]}.
    """
    # if len(datasets) == 1:
    #     return datasets[list(datasets.keys())[0]]
    # else:
    reorg_datasets = dict()

    # reorganize by split
    for _, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
            else:
                reorg_datasets[split_name].append(dataset_split)

    return reorg_datasets
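
# --- Illustrative usage (editorial sketch, not part of the original file) ---
# reorg_datasets_by_split flips {name: {split: dataset}} into
# {split: [dataset, ...]}; plain strings stand in for datasets here.
if __name__ == "__main__":
    out = reorg_datasets_by_split({
        "mimic": {"train": "ds_a"},
        "openi": {"train": "ds_b", "val": "ds_c"},
    })
    print(out)  # {'train': ['ds_a', 'ds_b'], 'val': ['ds_c']}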


def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipeline from WebDataset. It currently does not
    support generic IterableDataset because that would require creating separate samplers.

    For now, only training datasets are concatenated; validation and testing splits are
    assumed to hold a single dataset each, because metrics should not be computed on
    concatenated datasets.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by split.

    Returns:
        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
        "val" and "test" remain the same.

        If the input training datasets contain both map-style and DataPipeline datasets, returns
        a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.

    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, wds.DataPipeline):
                    logging.info(
                        "Dataset {} is IterableDataset, can't be concatenated.".format(
                            dataset
                        )
                    )
                    iterable_datasets.append(dataset)
                elif isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # if len(iterable_datasets) > 0:
            # concatenate map-style datasets and iterable-style datasets separately
            if len(iterable_datasets) > 1:
                chained_datasets = (
                    ChainDataset(iterable_datasets)
                )
            elif len(iterable_datasets) == 1:
                chained_datasets = iterable_datasets[0]
            else:
                chained_datasets = None

            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets



================================================
FILE: xraypulse/datasets/datasets/__init__.py
================================================


================================================
FILE: xraypulse/datasets/datasets/base_dataset.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of strings): Paths to the annotation files
        """
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)


class ConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        # TODO For now only supports datasets with same underlying collater implementations

        all_keys = set()
        for s in samples:
            all_keys.update(s)

        shared_keys = all_keys
        for s in samples:
            shared_keys = shared_keys & set(s.keys())

        samples_shared_keys = []
        for s in samples:
            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})

        return self.datasets[0].collater(samples_shared_keys)
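
# --- Illustrative sketch (editorial, not part of the original file) ---
# The collater above keeps only the keys shared by every sample before
# delegating to the first child dataset's collater, so datasets with
# slightly different fields can still be batched on their common ones.
if __name__ == "__main__":
    import torch

    class _Toy(Dataset):
        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, i):
            return self.items[i]

        def collater(self, samples):
            return default_collate(samples)

    ds = ConcatDataset([_Toy([{"x": torch.tensor(0), "y": 1}]),
                        _Toy([{"x": torch.tensor(1)}])])
    print(ds.collater([ds[0], ds[1]]))  # only the shared key 'x' survives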


================================================
FILE: xraypulse/datasets/datasets/caption_datasets.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict

from xraypulse.datasets.datasets.base_dataset import BaseDataset
from PIL import Image


class __DisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class CaptionDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of strings): Paths to the annotation files
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{:0>12}.png'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

        return {
            "image": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
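
# (editorial note) image ids are mapped to zero-padded PNG filenames above:
# '{:0>12}.png'.format(17) -> '000000000017.png'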


class CaptionEvalDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of strings): Paths to the annotation files
        split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        # the lines below were added when computing the ROUGE score at test time
        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann["image_id"]
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        return {
            "image": image,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }


================================================
FILE: xraypulse/datasets/datasets/dataloader_utils.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch
from xraypulse.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Args:
        loaders (List[Iterator]): list of iterators to sample from.
        ratios (List[float]): sampling ratio for each loader. If None, all loaders are sampled uniformly.
    """

    def __init__(self, loaders, ratios=None):
        # assert that all loaders have a __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            ratios = [float(ratio) / sum(ratios) for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __next__(self):
        # random sample from each loader by ratio
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])
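

# Editorial sketch (not part of the original file): MultiIterLoader draws one
# batch at a time from several never-ending iterators according to `ratios`.
# The cycled lists below are hypothetical stand-ins for wrapped DataLoaders.
if __name__ == "__main__":
    from itertools import cycle

    mimic_loader = cycle(["mimic_batch"])
    openi_loader = cycle(["openi_batch"])
    # Roughly 75% of draws come from the first loader, 25% from the second.
    mixed = MultiIterLoader(loaders=[mimic_loader, openi_loader], ratios=[3, 1])
    print([next(mixed) for _ in range(8)])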


class PrefetchLoader(object):
    """
    Modified from https://github.com/ChenRocks/UNITER.

    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            if isinstance(batch, tuple):
                task, batch = batch
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        # Delegate any attribute not defined here to the wrapped loader.
        return getattr(self.loader, name)
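

# Editorial sketch (illustrative, not part of the original file): typical use
# is to wrap an existing DataLoader so that host-to-GPU copies for the next
# batch overlap with compute on the current one. Requires a CUDA device and
# pinned memory to be effective; the names below are placeholders.
#
#   loader = PrefetchLoader(DataLoader(dataset, batch_size=8, pin_memory=True))
#   for batch in loader:
#       loss = model(batch)  # tensors in `batch` are already on the GPU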


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, (list, tuple)):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
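

# Editorial sketch (not part of the original file): IterLoader turns a finite
# DataLoader into an infinite iterator, restarting it (and, under distributed
# training, re-seeding the sampler via set_epoch) whenever it is exhausted.
if __name__ == "__main__":
    finite = DataLoader(list(range(4)), batch_size=2)
    infinite = IterLoader(finite)
    batches = [next(infinite) for _ in range(4)]  # wraps around after 2 steps
    print(infinite.epoch, batches)  # epoch == 1 after one full pass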


================================================
FILE: xraypulse/datasets/datasets/mimic_dataset.py
================================================
import os
from PIL import Image
from xraypulse.datasets.datasets.caption_datasets import CaptionDataset

class MIMICDataset(CaptionDataset):

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.jpg'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = ann['caption']

        return {
            "image": image,
            "caption":caption,
            "image_id": self.img_ids[ann["image_id"]],
        }

================================================
FILE: xraypulse/datasets/datasets/openi_dataset.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from PIL import Image
from xraypulse.datasets.datasets.caption_datasets import CaptionDataset

class OpenIDataset(CaptionDataset):

    def __getitem__(self, index):

        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        img_file = '{}.png'.format(ann["image_id"])
        image_path = os.path.join(self.vis_root, img_file)
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)
        caption = ann['caption']

        return {
            "image": image,
            "caption":caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
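

# Editorial note (not part of the original file): OpenIDataset and
# MIMICDataset differ only in the image extension they expect (.png vs .jpg);
# both return the raw report text under "caption", bypassing the
# text_processor that the parent CaptionDataset applies.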



================================================
FILE: xraypulse/models/Qformer.py
================================================
"""
 * Copyright (c) 2023, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig

logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # Reshape (batch, seq_len, all_head_size) ->
        # (batch, num_heads, seq_len, head_size) so attention can be computed
        # per head with a single batched matmul.
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        ):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(
                            *inputs, past_key_value, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To be used as a decoder, the model needs to be initialized with both the :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
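
    # Editorial sketch (not part of the original source): what the arithmetic
    # above produces for a toy padding mask that keeps token 0 and masks token 1:
    #   attention_mask          = [[1, 0]]
    #   extended_attention_mask = (1.0 - [[[[1., 0.]]]]) * -10000.0
    #                           = [[[[-0.0, -10000.0]]]]   # added to raw scores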

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
            if past_key_values is not None
            else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder
            )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
    ):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx) for past_state in layer_past
                ),
            )
        return reordered_past
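

# Editorial sketch (illustrative, not part of the original source): in the
# Q-Former setting, BertLMHeadModel's underlying BertModel is driven with
# learnable query embeddings that cross-attend to frozen image features.
# The names and shapes below are assumptions, not repo API.
#
#   query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, config.hidden_size))
#   outputs = qformer.bert(
#       query_embeds=query_tokens.expand(image_embeds.size(0), -1, -1),
#       encoder_hidden_states=image_embeds,   # (B, num_patches, encoder_width)
#       encoder_attention_mask=image_atts,    # (B, num_patches)
#       return_dict=True,
#   )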


class BertForMaskedLM(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


================================================
FILE: xraypulse/models/__init__.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import torch
from omegaconf import OmegaConf

from xraypulse.common.registry import registry
from xraypulse.models.base_model import BaseModel
from xraypulse.models.blip2 import Blip2Base
from xraypulse.models.xray_pulse import XrayPulse
from xraypulse.processors.base_processor import BaseProcessor


__all__ = [
    "load_model",
    "BaseModel",
    "Blip2Base",
    "XrayPulse",
]


def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
    """
    Load supported models.

    To list all available models and types in registry:
    >>> from minigpt4.models import model_zoo
    >>> print(model_zoo)

    Args:
        name (str): name of the model.
        model_type (str): type of the model.
        is_eval (bool): whether the model is in eval mode. Default: False.
        device (str): device to use. Default: "cpu".
        checkpoint (str): path or URL to the checkpoint. Default: None.
            Note that the checkpoint is expected to have the same state_dict keys as the model.

    Returns:
        model (torch.nn.Module): model.
    """

    model = registry.get_model_class(name).from_pretrained(model_type=model_type)

    if checkpoint is not None:
        model.load_checkpoint(checkpoint)

    if is_eval:
        model.eval()

    if device == "cpu":
        model = model.float()

    return model.to(device)
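
# Usage sketch for load_model (illustrative; the name/model_type strings below
# are assumptions -- print(model_zoo) for the keys actually registered):
#
#   model = load_model(
#       name="xray_pulse", model_type="pretrain", is_eval=True, device="cpu"
#   )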


def load_preprocess(config):
    """
    Load preprocessor configs and construct preprocessors.

    If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.

    Args:
        config (dict): preprocessor configs.

    Returns:
        vis_processors (dict): preprocessors for visual inputs.
        txt_processors (dict): preprocessors for text inputs.

        Key is "train" or "eval" for processors used in training and evaluation respectively.
    """

    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else BaseProcessor()
        )

    vis_processors = dict()
    txt_processors = dict()

    vis_proc_cfg = config.get("vis_processor")
    txt_proc_cfg = config.get("text_processor")

    if vis_proc_cfg is not None:
        vis_train_cfg = vis_proc_cfg.get("train")
        vis_eval_cfg = vis_proc_cfg.get("eval")
    else:
        vis_train_cfg = None
        vis_eval_cfg = None

    vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
    vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)

    if txt_proc_cfg is not None:
        txt_train_cfg = txt_proc_cfg.get("train")
        txt_eval_cfg = txt_proc_cfg.get("eval")
    else:
        txt_train_cfg = None
        txt_eval_cfg = None

    txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
    txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)

    return vis_processors, txt_processors


def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
    """
    Load model and its related preprocessors.

    List all available models and types in the registry:
    >>> from xraypulse.models import model_zoo
    >>> print(model_zoo)

    Args:
        name (str): name of the model.
        model_type (str): type of the model.
        is_eval (bool): whether the model is in eval mode. Default: False.
        device (str): device to use. Default: "cpu".

    Returns:
        model (torch.nn.Module): model.
        vis_processors (dict): preprocessors for visual inputs.
        txt_processors (dict): preprocessors for text inputs.
    """
    model_cls = registry.get_model_class(name)

    # load model
    model = model_cls.from_pretrained(model_type=model_type)

    if is_eval:
        model.eval()

    # load preprocess
    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
    if cfg is not None:
        preprocess_cfg = cfg.preprocess

        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        logging.info(
            f"""No default preprocess for model {name} ({model_type}).
                This can happen if the model is not finetuned on downstream datasets,
                or it is not intended for direct use without finetuning.
            """
        )

    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors
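
# Usage sketch for load_model_and_preprocess (illustrative; the name/type
# strings and raw_pil_image are assumptions, not taken from this repo):
#
#   model, vis_processors, _ = load_model_and_preprocess(
#       name="xray_pulse", model_type="pretrain", is_eval=True, device="cpu"
#   )
#   image = vis_processors["eval"](raw_pil_image).unsqueeze(0)  # (1, C, H, W)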


class ModelZoo:
    """
    A utility class to create string representation of available model architectures and types.

    >>> from xraypulse.models import model_zoo
    >>> # list all available models
    >>> print(model_zoo)
    >>> # show total number of models
    >>> print(len(model_zoo))
    """

    def __init__(self) -> None:
        self.model_zoo = {
            k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
            for k, v in registry.mapping["model_name_mapping"].items()
        }

    def __str__(self) -> str:
        return (
            "=" * 50
            + "\n"
            + f"{'Architectures':<30} {'Types'}\n"
            + "=" * 50
            + "\n"
            + "\n".join(
                [
                    f"{name:<30} {', '.join(types)}"
                    for name, types in self.model_zoo.items()
                ]
            )
        )

    def __iter__(self):
        return iter(self.model_zoo.items())

    def __len__(self):
        return sum([len(v) for v in self.model_zoo.values()])


model_zoo = ModelZoo()


================================================
FILE: xraypulse/models/base_model.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os

import numpy as np
import torch
import torch.nn as nn
from xraypulse.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from xraypulse.common.utils import get_abs_path, is_url
from omegaconf import OmegaConf


class BaseModel(nn.Module):
    """Base class for models."""

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        return list(self.parameters())[0].device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This expects no mismatch between the model keys and the checkpoint keys.
        """

        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        if "model" in checkpoint.keys():
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from default configuration file, specified by model_type.

        Args:
            - model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

    @classmethod
    def default_config_path(cls, model_type):
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
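
    # Sketch of the contract default_config_path relies on: each concrete model
    # maps a model_type string to a config path relative to the library root.
    # The entry below is an assumption based on this repo's layout, not copied
    # from it:
    #
    #   class XrayPulse(BaseModel):
    #       PRETRAINED_MODEL_CONFIG_DICT = {
    #           "pretrain": "configs/models/xraypulse.yaml",
    #       }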

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define its
        own load_from_pretrained() method.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            # load pre-trained weights
            pretrain_path = cfg.get("pretrained", None)
            assert "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

    def before_evaluation(self, **kwargs):
        pass

    def show_n_params(self, return_str=True):
        tot = 0
        for p in self.parameters():
            w = 1
            for x in p.shape:
                w *= x
            tot += w
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot


class BaseEncoder(nn.Module):
    """
    Base class for primitive encoders, such as ViT, TimeSformer, etc.
    """

    def __init__(self):
        super().__init__()

    def forward_features(self, samples, **kwargs):
        raise NotImplementedError

    @property
    def device(self):
        return list(self.parameters())[0].device


class SharedQueueMixin:
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = concat_all_gather(idxs)
            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size  # move pointer
        self.queue_ptr[0] = ptr


class MomentumDistilationMixin:
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(
                model_pair[0].parameters(), model_pair[1].parameters()
            ):
                param_m.data = param_m.data * self.momentum + param.data * (
                    1.0 - self.momentum
                )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    unlike torch.distributed.all_gather, this implementation does not cut the gradients.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    # Queue the gathered tensors
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-proc case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # if use distributed training
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
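
# Note the contrast with all_gather_with_grad above: concat_all_gather runs
# under @torch.no_grad() and uses the plain all_gather collective, so no
# gradient flows back through the gathered copies; use all_gather_with_grad
# whenever the gathered tensors must participate in the loss.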


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(x, dim, order_index.to(x.device))
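
# Worked example of tile (each slice along `dim` is repeated n_tile times
# consecutively, whereas Tensor.repeat tiles the whole tensor):
#
#   tile(torch.tensor([[1, 2], [3, 4]]), dim=0, n_tile=2)
#   # -> tensor([[1, 2], [1, 2], [3, 4], [3, 4]])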


================================================
FILE: xraypulse/models/blip2.py
================================================
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F

import xraypulse.common.dist_utils as dist_utils
from xraypulse.common.dist_utils import download_cached_file
from xraypulse.common.utils import is_url
from xraypulse.common.logger import MetricLogger
from xraypulse.models.base_model import BaseModel
from xraypulse.models.Qformer import BertConfig, BertLMHeadModel
from xraypulse.models.eva_vit import create_eva_vit_g
from transformers import BertTokenizer


class Blip2Base(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
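
    # Usage sketch for init_Qformer (the numbers are illustrative, not read
    # from this repo's config): a 32-query Q-Former over 1408-dim ViT features.
    #
    #   Qformer, query_tokens = Blip2Base.init_Qformer(
    #       num_query_token=32, vision_width=1408
    #   )
    #   # query_tokens: (1, 32, hidden_size); expanded along the batch dim
    #   # and fed to the BERT encoder as query_embeds.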

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
    ):
        assert model_name == "eva_clip_g", "vit model must be eva_clip_g in the current version of XrayPULSE"
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision
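
    # Usage sketch for init_vision_encoder (argument values are illustrative
    # assumptions, not read from this repo's config):
    #
    #   visual_encoder, ln_vision = Blip2Base.init_vision_encoder(
    #       "eva_clip_g", img_size=224, drop_path_rate=0.0,
    #       use_grad_checkpoint=False, precision="fp16",
    #   )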

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def compute_sim_matrix(model, data_loader, **kwargs):
    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter="  ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_feat = model.forward_text(text_input)
        text_embed = F.normalize(model.text_proj(text_feat))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    vit_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"]

        image = image.to(model.device)
        image_feat, vit_feat = model.forward_image(image)
        image_embed = model.vision_proj(image_feat)
        image_embed = F.normalize(image_embed, dim=-1)

        vit_feats.append(vit_feat.cpu())
        image_embeds.append(image_embed)

    vit_feats = torch.cat(vit_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = []
    for image_embed in image_embeds:
        sim_q2t = image_embed @ text_embeds.t()
        sim_i2t, _ = sim_q2t.max(0)
        sims_matrix.append(sim_i2t)
    sims_matrix = torch.stack(sims_matrix, dim=0)

    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0
    ).to(model.device)

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[topk_idx],
            text_atts=text_atts[topk_idx],
        ).float()
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0
    ).to(model.device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[start + i].repeat(k_test, 1),
            text_atts=text_atts[start + i].repeat(k_test, 1),
        ).float()
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    if dist_utils.is_dist_avail_and_initialized():
        dist.barrier()
        torch.distributed.all_reduce(
            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
        )
        torch.distributed.all_reduce(
            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
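
# Note: compute_sim_matrix performs two-stage retrieval -- a cheap dense
# similarity pass over all image-text pairs, then ITM re-scoring of only the
# top-k_test candidates per query; the work is sharded across ranks and the
# score matrices are merged with all_reduce at the end.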


================================================
FILE: xraypulse/models/blip2_outputs.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import (
    ModelOutput,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)


@dataclass
class BlipSimilarity(ModelOutput):
    sim_i2t: torch.FloatTensor = None
    sim_t2i: torch.FloatTensor = None

    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None


@dataclass
class BlipIntermediateOutput(ModelOutput):
    """
    Data class for intermediate outputs of BLIP models.

    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).

    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).

    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.

    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
    decoder_labels (torch.LongTensor): labels for the captioning loss.

    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)

    """

    # uni-modal features
    image_embeds: torch.FloatTensor = None
    text_embeds: Optional[torch.FloatTensor] = None

    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # intermediate outputs of multimodal encoder
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # intermediate outputs of multimodal decoder
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None


@dataclass
class BlipOutput(ModelOutput):
    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
    sims: Optional[BlipSimilarity] = None

    intermediate_output: BlipIntermediateOutput = None

    loss: Optional[torch.FloatTensor] = None

    loss_itc: Optional[torch.FloatTensor] = None

    loss_itm: Optional[torch.FloatTensor] = None

    loss_lm: Optional[torch.FloatTensor] = None


@dataclass
class BlipOutputFeatures(ModelOutput):
    """
    Data class of features from BlipFeatureExtractor.

    Args:
        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional

        The first embedding or feature is for the [CLS] token.

        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None


================================================
FILE: xraypulse/models/eva_vit.py
================================================
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model

from xraypulse.common.dist_utils import download_cached_file

def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
    
    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
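
# Sanity-check sketch for DropPath (standalone, illustrative): in train mode
# each sample's residual branch is either zeroed or rescaled by
# 1/(1 - drop_prob), so the expected activation is preserved:
#
#   dp = DropPath(drop_prob=0.5)
#   dp.train()
#   out = dp(torch.ones(4, 3))  # each row is all 0.0 or all 2.0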


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out here to match the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token to cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# (remainder of eva_vit.py truncated in this preview; see the symbol index
# below for the file's remaining classes and functions)
================================================
SYMBOL INDEX (482 symbols across 34 files)
================================================

FILE: demo.py
  function parse_args (line 23) | def parse_args():
  function setup_seeds (line 38) | def setup_seeds(config):
  function gradio_reset (line 73) | def gradio_reset(chat_state, img_list):
  function upload_img (line 80) | def upload_img(gr_img, text_input, chat_state):
  function gradio_ask (line 88) | def gradio_ask(user_message, chatbot, chat_state):
  function gradio_answer (line 96) | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
  function set_example_xray (line 119) | def set_example_xray(example: list) -> dict:
  function set_example_text_input (line 123) | def set_example_text_input(example_text: str) -> dict:

FILE: xraypulse/common/config.py
  class Config (line 16) | class Config:
    method __init__ (line 17) | def __init__(self, args):
    method _validate_runner_config (line 43) | def _validate_runner_config(self, runner_config):
    method _build_opt_list (line 52) | def _build_opt_list(self, opts):
    method build_model_config (line 57) | def build_model_config(config, **kwargs):
    method build_runner_config (line 84) | def build_runner_config(config):
    method build_dataset_config (line 88) | def build_dataset_config(config):
    method _convert_to_dot_list (line 114) | def _convert_to_dot_list(self, opts):
    method get_config (line 128) | def get_config(self):
    method run_cfg (line 132) | def run_cfg(self):
    method datasets_cfg (line 136) | def datasets_cfg(self):
    method model_cfg (line 140) | def model_cfg(self):
    method pretty_print (line 143) | def pretty_print(self):
    method _convert_node_to_json (line 161) | def _convert_node_to_json(self, node):
    method to_dict (line 165) | def to_dict(self):
  function node_to_dict (line 169) | def node_to_dict(node):
  class ConfigValidator (line 173) | class ConfigValidator:
    class _Argument (line 187) | class _Argument:
      method __init__ (line 188) | def __init__(self, name, choices=None, type=None, help=None):
      method __str__ (line 195) | def __str__(self):
    method __init__ (line 205) | def __init__(self, description):
    method __getitem__ (line 212) | def __getitem__(self, key):
    method __str__ (line 217) | def __str__(self) -> str:
    method add_argument (line 220) | def add_argument(self, *args, **kwargs):
    method validate (line 226) | def validate(self, config=None):
    method format_arguments (line 248) | def format_arguments(self):
    method format_help (line 251) | def format_help(self):
    method print_help (line 256) | def print_help(self):
  function create_runner_config_validator (line 261) | def create_runner_config_validator():

FILE: xraypulse/common/dist_utils.py
  function setup_for_distributed (line 17) | def setup_for_distributed(is_master):
  function is_dist_avail_and_initialized (line 33) | def is_dist_avail_and_initialized():
  function get_world_size (line 41) | def get_world_size():
  function get_rank (line 47) | def get_rank():
  function is_main_process (line 53) | def is_main_process():
  function init_distributed_mode (line 57) | def init_distributed_mode(args):
  function get_dist_info (line 93) | def get_dist_info():
  function main_process (line 107) | def main_process(func):
  function download_cached_file (line 117) | def download_cached_file(url, check_hash=True, progress=False):

FILE: xraypulse/common/gradcam.py
  function getAttMap (line 7) | def getAttMap(img, attMap, blur=True, overlap=True):

FILE: xraypulse/common/logger.py
  class SmoothedValue (line 19) | class SmoothedValue(object):
    method __init__ (line 24) | def __init__(self, window_size=20, fmt=None):
    method update (line 32) | def update(self, value, n=1):
    method synchronize_between_processes (line 37) | def synchronize_between_processes(self):
    method median (line 51) | def median(self):
    method avg (line 56) | def avg(self):
    method global_avg (line 61) | def global_avg(self):
    method max (line 65) | def max(self):
    method value (line 69) | def value(self):
    method __str__ (line 72) | def __str__(self):
  class MetricLogger (line 82) | class MetricLogger(object):
    method __init__ (line 83) | def __init__(self, delimiter="\t"):
    method update (line 87) | def update(self, **kwargs):
    method __getattr__ (line 94) | def __getattr__(self, attr):
    method __str__ (line 103) | def __str__(self):
    method global_avg (line 109) | def global_avg(self):
    method synchronize_between_processes (line 115) | def synchronize_between_processes(self):
    method add_meter (line 119) | def add_meter(self, name, meter):
    method log_every (line 122) | def log_every(self, iterable, print_freq, header=None):
  class AttrDict (line 184) | class AttrDict(dict):
    method __init__ (line 185) | def __init__(self, *args, **kwargs):
  function setup_logger (line 190) | def setup_logger():

FILE: xraypulse/common/optims.py
  class LinearWarmupStepLRScheduler (line 14) | class LinearWarmupStepLRScheduler:
    method __init__ (line 15) | def __init__(
    method step (line 37) | def step(self, cur_epoch, cur_step):
  class LinearWarmupCosineLRScheduler (line 57) | class LinearWarmupCosineLRScheduler:
    method __init__ (line 58) | def __init__(
    method step (line 79) | def step(self, cur_epoch, cur_step):
  function cosine_lr_schedule (line 99) | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
  function warmup_lr_schedule (line 108) | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
  function step_lr_schedule (line 115) | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):

FILE: xraypulse/common/registry.py
  class Registry (line 9) | class Registry:
    method register_builder (line 22) | def register_builder(cls, name):
    method register_task (line 54) | def register_task(cls, name):
    method register_model (line 83) | def register_model(cls, name):
    method register_processor (line 112) | def register_processor(cls, name):
    method register_lr_scheduler (line 141) | def register_lr_scheduler(cls, name):
    method register_runner (line 165) | def register_runner(cls, name):
    method register_path (line 189) | def register_path(cls, name, path):
    method register (line 205) | def register(cls, name, obj):
    method get_builder_class (line 232) | def get_builder_class(cls, name):
    method get_model_class (line 236) | def get_model_class(cls, name):
    method get_task_class (line 240) | def get_task_class(cls, name):
    method get_processor_class (line 244) | def get_processor_class(cls, name):
    method get_lr_scheduler_class (line 248) | def get_lr_scheduler_class(cls, name):
    method get_runner_class (line 252) | def get_runner_class(cls, name):
    method list_runners (line 256) | def list_runners(cls):
    method list_models (line 260) | def list_models(cls):
    method list_tasks (line 264) | def list_tasks(cls):
    method list_processors (line 268) | def list_processors(cls):
    method list_lr_schedulers (line 272) | def list_lr_schedulers(cls):
    method list_datasets (line 276) | def list_datasets(cls):
    method get_path (line 280) | def get_path(cls, name):
    method get (line 284) | def get(cls, name, default=None, no_warning=False):
    method unregister (line 315) | def unregister(cls, name):

FILE: xraypulse/common/utils.py
  function now (line 35) | def now():
  function is_url (line 41) | def is_url(url_or_filename):
  function get_cache_path (line 46) | def get_cache_path(rel_path):
  function get_abs_path (line 50) | def get_abs_path(rel_path):
  function load_json (line 54) | def load_json(filename):
  function makedir (line 64) | def makedir(dir_path):
  function get_redirected_url (line 78) | def get_redirected_url(url: str):
  function to_google_drive_download_url (line 93) | def to_google_drive_download_url(view_url: str) -> str:
  function download_google_drive_url (line 108) | def download_google_drive_url(url: str, output_path: str, output_file_na...
  function _get_google_drive_file_id (line 141) | def _get_google_drive_file_id(url: str) -> Optional[str]:
  function _urlretrieve (line 154) | def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
  function download_url (line 167) | def download_url(
  function download_and_extract_archive (line 221) | def download_and_extract_archive(
  function cache_url (line 242) | def cache_url(url: str, cache_dir: str) -> str:
  function create_file_symlink (line 261) | def create_file_symlink(file1, file2):
  function save_file (line 275) | def save_file(data, filename, append_to_json=True, verbose=True):
  function load_file (line 313) | def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
  function abspath (line 374) | def abspath(resource_path: str):
  function makedir (line 386) | def makedir(dir_path):
  function is_url (line 400) | def is_url(input_url):
  function cleanup_dir (line 408) | def cleanup_dir(dir):
  function get_file_size (line 419) | def get_file_size(filename):

FILE: xraypulse/conversation/conversation.py
  class SeparatorStyle (line 16) | class SeparatorStyle(Enum):
  class Conversation (line 23) | class Conversation:
    method get_prompt (line 37) | def get_prompt(self):
    method append_message (line 58) | def append_message(self, role, message):
    method to_gradio_chatbot (line 61) | def to_gradio_chatbot(self):
    method copy (line 72) | def copy(self):
    method dict (line 84) | def dict(self):
  class StoppingCriteriaSub (line 97) | class StoppingCriteriaSub(StoppingCriteria):
    method __init__ (line 99) | def __init__(self, stops=[], encounters=1):
    method __call__ (line 103) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  class Chat (line 123) | class Chat:
    method __init__ (line 124) | def __init__(self, model, vis_processor, device='cuda:0'):
    method ask (line 132) | def ask(self, text, conv):
    method answer (line 139) | def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_...
    method upload_img (line 175) | def upload_img(self, image, conv, img_list):
    method get_context_emb (line 194) | def get_context_emb(self, conv, img_list):

FILE: xraypulse/datasets/builders/__init__.py
  function load_dataset (line 21) | def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
  class DatasetZoo (line 59) | class DatasetZoo:
    method __init__ (line 60) | def __init__(self) -> None:
    method get_names (line 66) | def get_names(self):

FILE: xraypulse/datasets/builders/base_dataset_builder.py
  class BaseDatasetBuilder (line 25) | class BaseDatasetBuilder:
    method __init__ (line 28) | def __init__(self, cfg=None):
    method build_datasets (line 45) | def build_datasets(self):
    method build_processors (line 61) | def build_processors(self):
    method _build_proc_from_cfg (line 80) | def _build_proc_from_cfg(cfg):
    method default_config_path (line 88) | def default_config_path(cls, type="default"):
    method _download_data (line 91) | def _download_data(self):
    method _download_ann (line 95) | def _download_ann(self):
    method _download_vis (line 152) | def _download_vis(self):
    method build (line 166) | def build(self):
  function load_dataset_config (line 232) | def load_dataset_config(cfg_path):

FILE: xraypulse/datasets/builders/image_text_pair_builder.py
  class MIMICBuilder (line 12) | class MIMICBuilder(BaseDatasetBuilder):
    method _download_ann (line 17) | def _download_ann(self):
    method _download_vis (line 20) | def _download_vis(self):
    method build_datasets (line 23) | def build_datasets(self):
  class OpenIBuilder (line 49) | class OpenIBuilder(BaseDatasetBuilder):
    method _download_ann (line 54) | def _download_ann(self):
    method _download_vis (line 57) | def _download_vis(self):
    method build (line 60) | def build(self):

FILE: xraypulse/datasets/data_utils.py
  class ChainDataset (line 33) | class ChainDataset(wds.DataPipeline):
    method __init__ (line 43) | def __init__(self, datasets: List[wds.DataPipeline]) -> None:
    method __iter__ (line 59) | def __iter__(self):
  function apply_to_sample (line 66) | def apply_to_sample(f, sample):
  function move_to_cuda (line 83) | def move_to_cuda(sample):
  function prepare_sample (line 90) | def prepare_sample(samples, cuda_enabled=True):
  function reorg_datasets_by_split (line 99) | def reorg_datasets_by_split(datasets):
  function concat_datasets (line 125) | def concat_datasets(datasets):

FILE: xraypulse/datasets/datasets/base_dataset.py
  class BaseDataset (line 15) | class BaseDataset(Dataset):
    method __init__ (line 16) | def __init__(
    method __len__ (line 34) | def __len__(self):
    method collater (line 37) | def collater(self, samples):
    method set_processors (line 40) | def set_processors(self, vis_processor, text_processor):
    method _add_instance_ids (line 44) | def _add_instance_ids(self, key="instance_id"):
  class ConcatDataset (line 49) | class ConcatDataset(ConcatDataset):
    method __init__ (line 50) | def __init__(self, datasets: Iterable[Dataset]) -> None:
    method collater (line 53) | def collater(self, samples):

FILE: xraypulse/datasets/datasets/caption_datasets.py
  class __DisplMixin (line 15) | class __DisplMixin:
    method displ_item (line 16) | def displ_item(self, index):
  class CaptionDataset (line 28) | class CaptionDataset(BaseDataset, __DisplMixin):
    method __init__ (line 29) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
    method __getitem__ (line 44) | def __getitem__(self, index):
  class CaptionEvalDataset (line 63) | class CaptionEvalDataset(BaseDataset, __DisplMixin):
    method __init__ (line 64) | def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
    method __getitem__ (line 81) | def __getitem__(self, index):

FILE: xraypulse/datasets/datasets/dataloader_utils.py
  class MultiIterLoader (line 15) | class MultiIterLoader:
    method __init__ (line 24) | def __init__(self, loaders, ratios=None):
    method __next__ (line 40) | def __next__(self):
  class PrefetchLoader (line 46) | class PrefetchLoader(object):
    method __init__ (line 54) | def __init__(self, loader):
    method __iter__ (line 58) | def __iter__(self):
    method __len__ (line 73) | def __len__(self):
    method preload (line 76) | def preload(self, it):
    method next (line 101) | def next(self, it):
    method __getattr__ (line 109) | def __getattr__(self, name):
  function record_cuda_stream (line 114) | def record_cuda_stream(batch):
  class IterLoader (line 127) | class IterLoader:
    method __init__ (line 135) | def __init__(self, dataloader: DataLoader, use_distributed: bool = Fal...
    method epoch (line 142) | def epoch(self) -> int:
    method __next__ (line 145) | def __next__(self):
    method __iter__ (line 158) | def __iter__(self):
    method __len__ (line 161) | def __len__(self):

FILE: xraypulse/datasets/datasets/mimic_dataset.py
  class MIMICDataset (line 7) | class MIMICDataset(CaptionDataset):
    method __getitem__ (line 9) | def __getitem__(self, index):

FILE: xraypulse/datasets/datasets/openi_dataset.py
  class OpenIDataset (line 14) | class OpenIDataset(CaptionDataset):
    method __getitem__ (line 16) | def __getitem__(self, index):

FILE: xraypulse/models/Qformer.py
  class BertEmbeddings (line 51) | class BertEmbeddings(nn.Module):
    method __init__ (line 54) | def __init__(self, config):
    method forward (line 78) | def forward(
  class BertSelfAttention (line 111) | class BertSelfAttention(nn.Module):
    method __init__ (line 112) | def __init__(self, config, is_cross_attention):
    method save_attn_gradients (line 149) | def save_attn_gradients(self, attn_gradients):
    method get_attn_gradients (line 152) | def get_attn_gradients(self):
    method save_attention_map (line 155) | def save_attention_map(self, attention_map):
    method get_attention_map (line 158) | def get_attention_map(self):
    method transpose_for_scores (line 161) | def transpose_for_scores(self, x):
    method forward (line 169) | def forward(
  class BertSelfOutput (line 278) | class BertSelfOutput(nn.Module):
    method __init__ (line 279) | def __init__(self, config):
    method forward (line 285) | def forward(self, hidden_states, input_tensor):
  class BertAttention (line 292) | class BertAttention(nn.Module):
    method __init__ (line 293) | def __init__(self, config, is_cross_attention=False):
    method prune_heads (line 299) | def prune_heads(self, heads):
    method forward (line 322) | def forward(
  class BertIntermediate (line 349) | class BertIntermediate(nn.Module):
    method __init__ (line 350) | def __init__(self, config):
    method forward (line 358) | def forward(self, hidden_states):
  class BertOutput (line 364) | class BertOutput(nn.Module):
    method __init__ (line 365) | def __init__(self, config):
    method forward (line 371) | def forward(self, hidden_states, input_tensor):
  class BertLayer (line 378) | class BertLayer(nn.Module):
    method __init__ (line 379) | def __init__(self, config, layer_num):
    method forward (line 402) | def forward(
    method feed_forward_chunk (line 476) | def feed_forward_chunk(self, attention_output):
    method feed_forward_chunk_query (line 481) | def feed_forward_chunk_query(self, attention_output):
  class BertEncoder (line 487) | class BertEncoder(nn.Module):
    method __init__ (line 488) | def __init__(self, config):
    method forward (line 495) | def forward(
  class BertPooler (line 592) | class BertPooler(nn.Module):
    method __init__ (line 593) | def __init__(self, config):
    method forward (line 598) | def forward(self, hidden_states):
  class BertPredictionHeadTransform (line 607) | class BertPredictionHeadTransform(nn.Module):
    method __init__ (line 608) | def __init__(self, config):
    method forward (line 617) | def forward(self, hidden_states):
  class BertLMPredictionHead (line 624) | class BertLMPredictionHead(nn.Module):
    method __init__ (line 625) | def __init__(self, config):
    method forward (line 638) | def forward(self, hidden_states):
  class BertOnlyMLMHead (line 644) | class BertOnlyMLMHead(nn.Module):
    method __init__ (line 645) | def __init__(self, config):
    method forward (line 649) | def forward(self, sequence_output):
  class BertPreTrainedModel (line 654) | class BertPreTrainedModel(PreTrainedModel):
    method _init_weights (line 664) | def _init_weights(self, module):
  class BertModel (line 677) | class BertModel(BertPreTrainedModel):
    method __init__ (line 687) | def __init__(self, config, add_pooling_layer=False):
    method get_input_embeddings (line 699) | def get_input_embeddings(self):
    method set_input_embeddings (line 702) | def set_input_embeddings(self, value):
    method _prune_heads (line 705) | def _prune_heads(self, heads_to_prune):
    method get_extended_attention_mask (line 713) | def get_extended_attention_mask(
    method forward (line 804) | def forward(
  class BertLMHeadModel (line 968) | class BertLMHeadModel(BertPreTrainedModel):
    method __init__ (line 973) | def __init__(self, config):
    method get_output_embeddings (line 981) | def get_output_embeddings(self):
    method set_output_embeddings (line 984) | def set_output_embeddings(self, new_embeddings):
    method forward (line 987) | def forward(
    method prepare_inputs_for_generation (line 1097) | def prepare_inputs_for_generation(
    method _reorder_cache (line 1120) | def _reorder_cache(self, past, beam_idx):
  class BertForMaskedLM (line 1131) | class BertForMaskedLM(BertPreTrainedModel):
    method __init__ (line 1136) | def __init__(self, config):
    method get_output_embeddings (line 1144) | def get_output_embeddings(self):
    method set_output_embeddings (line 1147) | def set_output_embeddings(self, new_embeddings):
    method forward (line 1150) | def forward(

FILE: xraypulse/models/__init__.py
  function load_model (line 27) | def load_model(name, model_type, is_eval=False, device="cpu", checkpoint...
  function load_preprocess (line 61) | def load_preprocess(config):
  function load_model_and_preprocess (line 113) | def load_model_and_preprocess(name, model_type, is_eval=False, device="c...
  class ModelZoo (line 161) | class ModelZoo:
    method __init__ (line 172) | def __init__(self) -> None:
    method __str__ (line 178) | def __str__(self) -> str:
    method __iter__ (line 193) | def __iter__(self):
    method __len__ (line 196) | def __len__(self):

FILE: xraypulse/models/base_model.py
  class BaseModel (line 19) | class BaseModel(nn.Module):
    method __init__ (line 22) | def __init__(self):
    method device (line 26) | def device(self):
    method load_checkpoint (line 29) | def load_checkpoint(self, url_or_filename):
    method from_pretrained (line 59) | def from_pretrained(cls, model_type):
    method default_config_path (line 75) | def default_config_path(cls, model_type):
    method load_checkpoint_from_config (line 81) | def load_checkpoint_from_config(self, cfg, **kwargs):
    method before_evaluation (line 102) | def before_evaluation(self, **kwargs):
    method show_n_params (line 105) | def show_n_params(self, return_str=True):
  class BaseEncoder (line 121) | class BaseEncoder(nn.Module):
    method __init__ (line 126) | def __init__(self):
    method forward_features (line 129) | def forward_features(self, samples, **kwargs):
    method device (line 133) | def device(self):
  class SharedQueueMixin (line 137) | class SharedQueueMixin:
    method _dequeue_and_enqueue (line 139) | def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
  class MomentumDistilationMixin (line 161) | class MomentumDistilationMixin:
    method copy_params (line 163) | def copy_params(self):
    method _momentum_update (line 172) | def _momentum_update(self):
  class GatherLayer (line 182) | class GatherLayer(torch.autograd.Function):
    method forward (line 189) | def forward(ctx, x):
    method backward (line 197) | def backward(ctx, *grads):
  function all_gather_with_grad (line 203) | def all_gather_with_grad(tensors):
  function concat_all_gather (line 221) | def concat_all_gather(tensor):
  function tile (line 239) | def tile(x, dim, n_tile):

FILE: xraypulse/models/blip2.py
  class Blip2Base (line 28) | class Blip2Base(BaseModel):
    method init_tokenizer (line 30) | def init_tokenizer(cls):
    method maybe_autocast (line 35) | def maybe_autocast(self, dtype=torch.float16):
    method init_Qformer (line 46) | def init_Qformer(cls, num_query_token, vision_width, cross_attention_f...
    method init_vision_encoder (line 61) | def init_vision_encoder(
    method load_from_pretrained (line 72) | def load_from_pretrained(self, url_or_filename):
  function disabled_train (line 93) | def disabled_train(self, mode=True):
  class LayerNorm (line 99) | class LayerNorm(nn.LayerNorm):
    method forward (line 102) | def forward(self, x: torch.Tensor):
  function compute_sim_matrix (line 108) | def compute_sim_matrix(model, data_loader, **kwargs):

FILE: xraypulse/models/blip2_outputs.py
  class BlipSimilarity (line 20) | class BlipSimilarity(ModelOutput):
  class BlipIntermediateOutput (line 32) | class BlipIntermediateOutput(ModelOutput):
  class BlipOutput (line 73) | class BlipOutput(ModelOutput):
  class BlipOutputFeatures (line 89) | class BlipOutputFeatures(ModelOutput):

FILE: xraypulse/models/eva_vit.py
  function _cfg (line 20) | def _cfg(url='', **kwargs):
  class DropPath (line 30) | class DropPath(nn.Module):
    method __init__ (line 33) | def __init__(self, drop_prob=None):
    method forward (line 37) | def forward(self, x):
    method extra_repr (line 40) | def extra_repr(self) -> str:
  class Mlp (line 44) | class Mlp(nn.Module):
    method __init__ (line 45) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 54) | def forward(self, x):
  class Attention (line 64) | class Attention(nn.Module):
    method __init__ (line 65) | def __init__(
    method forward (line 118) | def forward(self, x, rel_pos_bias=None):
  class Block (line 151) | class Block(nn.Module):
    method __init__ (line 153) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
    method forward (line 173) | def forward(self, x, rel_pos_bias=None):
  class PatchEmbed (line 183) | class PatchEmbed(nn.Module):
    method __init__ (line 186) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
    method forward (line 198) | def forward(self, x, **kwargs):
  class RelativePositionBias (line 207) | class RelativePositionBias(nn.Module):
    method __init__ (line 209) | def __init__(self, window_size, num_heads):
    method forward (line 238) | def forward(self):
  class VisionTransformer (line 246) | class VisionTransformer(nn.Module):
    method __init__ (line 249) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
    method fix_init_weight (line 300) | def fix_init_weight(self):
    method _init_weights (line 308) | def _init_weights(self, m):
    method get_classifier (line 317) | def get_classifier(self):
    method reset_classifier (line 320) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 324) | def forward_features(self, x):
    method forward (line 349) | def forward(self, x):
    method get_intermediate_layers (line 354) | def get_intermediate_layers(self, x):
  function interpolate_pos_embed (line 373) | def interpolate_pos_embed(model, checkpoint_model):
  function convert_weights_to_fp16 (line 397) | def convert_weights_to_fp16(model: nn.Module):
  function create_eva_vit_g (line 415) | def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=Fals...
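
DropPath is stochastic depth: during training each sample's residual branch is dropped with probability drop_prob and the survivors are rescaled. The file header credits the timm code base, so the standard timm kernel is a reasonable sketch of what DropPath.forward computes:

    import torch

    def drop_path(x, drop_prob: float = 0., training: bool = False):
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        mask.floor_()  # binarize to 0/1
        return x.div(keep_prob) * mask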

FILE: xraypulse/models/pos_embed.py
  function get_2d_sincos_pos_embed (line 20) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
  function get_2d_sincos_pos_embed_from_grid (line 38) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
  function get_1d_sincos_pos_embed_from_grid (line 49) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  function interpolate_pos_embed (line 75) | def interpolate_pos_embed(model, checkpoint_model):
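
These helpers build fixed 2D sine-cosine positional embeddings by concatenating two 1D sinusoids, one per grid axis. The 1D building block, in the usual MAE-style form (a sketch; the file's Meta copyright header points at the MAE code base):

    import numpy as np

    def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
        """embed_dim must be even; pos is a 1-D array of M positions.
        Returns an (M, embed_dim) array: [sin(pos * w_k) | cos(pos * w_k)]."""
        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float64)
        omega /= embed_dim / 2.
        omega = 1. / 10000 ** omega              # frequencies 1/10000^(2k/d)
        out = np.einsum('m,d->md', pos.reshape(-1), omega)
        return np.concatenate([np.sin(out), np.cos(out)], axis=1)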

FILE: xraypulse/models/xray_pulse.py
  class StoppingCriteriaSub (line 18) | class StoppingCriteriaSub(StoppingCriteria):
    method __init__ (line 20) | def __init__(self, stops=[], encounters=1):
    method __call__ (line 24) | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTen...
  class SeparatorStyle (line 31) | class SeparatorStyle(Enum):
  class XrayPulse (line 38) | class XrayPulse(Blip2Base):
    method __init__ (line 47) | def __init__(
    method vit_to_cpu (line 144) | def vit_to_cpu(self):
    method encode_img (line 151) | def encode_img(self, image):
    method prompt_wrap (line 174) | def prompt_wrap(self, img_embeds, atts_img, prompt):
    method forward (line 190) | def forward(self, samples):
    method test (line 246) | def test(self, samples, max_new_tokens=300, num_beams=1, min_length=1,...
    method get_context_emb (line 305) | def get_context_emb(self, conv, img):
    method from_config (line 329) | def from_config(cls, cfg):
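
prompt_wrap is where the image enters the language model: the textual prompt is split at the image placeholder and the Q-former's image embeddings are spliced between the two halves in embedding space. A standalone sketch of the MiniGPT-4-style mechanism (the '<图片>' placeholder matches prompts/alignment.txt; tokenizer and embed_tokens are illustrative parameters, not the repo's attribute names):

    import torch

    def prompt_wrap(img_embeds, atts_img, prompt, tokenizer, embed_tokens):
        # prompt e.g. "<Img><图片></Img> 详细描述所给的胸部X光影像。"
        p_before, p_after = prompt.split('<图片>')
        ids_before = tokenizer(p_before, return_tensors="pt",
                               add_special_tokens=False).input_ids.to(img_embeds.device)
        ids_after = tokenizer(p_after, return_tensors="pt",
                              add_special_tokens=False).input_ids.to(img_embeds.device)
        bsz = img_embeds.shape[0]
        emb_before = embed_tokens(ids_before).expand(bsz, -1, -1)
        emb_after = embed_tokens(ids_after).expand(bsz, -1, -1)
        wrapped = torch.cat([emb_before, img_embeds, emb_after], dim=1)
        # Attention mask simply covers the full spliced sequence.
        atts = atts_img[:, :1].expand(-1, wrapped.shape[1])
        return wrapped, atts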

FILE: xraypulse/processors/__init__.py
  function load_processor (line 25) | def load_processor(name, cfg=None):

FILE: xraypulse/processors/base_processor.py
  class BaseProcessor (line 11) | class BaseProcessor:
    method __init__ (line 12) | def __init__(self):
    method __call__ (line 16) | def __call__(self, item):
    method from_config (line 20) | def from_config(cls, cfg=None):
    method build (line 23) | def build(self, **kwargs):

FILE: xraypulse/processors/blip_processors.py
  class BlipImageBaseProcessor (line 18) | class BlipImageBaseProcessor(BaseProcessor):
    method __init__ (line 19) | def __init__(self, mean=None, std=None):
  class BlipCaptionProcessor (line 29) | class BlipCaptionProcessor(BaseProcessor):
    method __init__ (line 30) | def __init__(self, prompt="", max_words=50):
    method __call__ (line 34) | def __call__(self, caption):
    method from_config (line 40) | def from_config(cls, cfg=None):
    method pre_caption (line 49) | def pre_caption(self, caption):
  class Blip2ImageTrainProcessor (line 72) | class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
    method __init__ (line 73) | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5,...
    method __call__ (line 88) | def __call__(self, item):
    method from_config (line 92) | def from_config(cls, cfg=None):
  class Blip2ImageEvalProcessor (line 114) | class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    method __init__ (line 115) | def __init__(self, image_size=224, mean=None, std=None):
    method __call__ (line 128) | def __call__(self, item):
    method from_config (line 132) | def from_config(cls, cfg=None):
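
Blip2ImageTrainProcessor's signature (image_size=224, mean/std, min_scale=0.5) suggests the usual BLIP-2 torchvision pipeline. A sketch assuming the CLIP normalization constants that BLIP-family processors default to when mean/std are None:

    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0),
                                     interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])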

FILE: xraypulse/processors/randaugment.py
  function identity_func (line 15) | def identity_func(img):
  function autocontrast_func (line 19) | def autocontrast_func(img, cutoff=0):
  function equalize_func (line 52) | def equalize_func(img):
  function rotate_func (line 76) | def rotate_func(img, degree, fill=(0, 0, 0)):
  function solarize_func (line 87) | def solarize_func(img, thresh=128):
  function color_func (line 97) | def color_func(img, factor):
  function contrast_func (line 115) | def contrast_func(img, factor):
  function brightness_func (line 129) | def brightness_func(img, factor):
  function sharpness_func (line 138) | def sharpness_func(img, factor):
  function shear_x_func (line 159) | def shear_x_func(img, factor, fill=(0, 0, 0)):
  function translate_x_func (line 168) | def translate_x_func(img, offset, fill=(0, 0, 0)):
  function translate_y_func (line 180) | def translate_y_func(img, offset, fill=(0, 0, 0)):
  function posterize_func (line 192) | def posterize_func(img, bits):
  function shear_y_func (line 200) | def shear_y_func(img, factor, fill=(0, 0, 0)):
  function cutout_func (line 209) | def cutout_func(img, pad_size, replace=(0, 0, 0)):
  function enhance_level_to_args (line 223) | def enhance_level_to_args(MAX_LEVEL):
  function shear_level_to_args (line 230) | def shear_level_to_args(MAX_LEVEL, replace_value):
  function translate_level_to_args (line 240) | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
  function cutout_level_to_args (line 250) | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
  function solarize_level_to_args (line 258) | def solarize_level_to_args(MAX_LEVEL):
  function none_level_to_args (line 266) | def none_level_to_args(level):
  function posterize_level_to_args (line 270) | def posterize_level_to_args(MAX_LEVEL):
  function rotate_level_to_args (line 278) | def rotate_level_to_args(MAX_LEVEL, replace_value):
  class RandomAugment (line 326) | class RandomAugment(object):
    method __init__ (line 327) | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
    method get_random_ops (line 336) | def get_random_ops(self):
    method __call__ (line 340) | def __call__(self, img):
  class VideoRandomAugment (line 352) | class VideoRandomAugment(object):
    method __init__ (line 353) | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
    method get_random_ops (line 363) | def get_random_ops(self):
    method __call__ (line 367) | def __call__(self, frames):
    method _aug (line 386) | def _aug(self, img, ops, apply_or_not):
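
RandomAugment draws N of the registered ops per call and applies each at magnitude M. A usage sketch (the op names follow the standard BLIP randaugment set and the file path is hypothetical; treat both as assumptions):

    from PIL import Image
    from xraypulse.processors.randaugment import RandomAugment

    # Sample 2 ops per image, each at magnitude 5; isPIL=True converts the
    # PIL image to an ndarray before the ops run.
    aug = RandomAugment(2, 5, isPIL=True,
                        augs=["Identity", "AutoContrast", "Brightness",
                              "Sharpness", "Equalize", "Rotate"])
    img_aug = aug(Image.open("chest_xray.png").convert("RGB"))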

FILE: xraypulse/runners/runner_base.py
  class RunnerBase (line 38) | class RunnerBase:
    method __init__ (line 46) | def __init__(self, cfg, task, model, datasets, job_id):
    method device (line 68) | def device(self):
    method use_distributed (line 75) | def use_distributed(self):
    method model (line 79) | def model(self):
    method optimizer (line 99) | def optimizer(self):
    method scaler (line 132) | def scaler(self):
    method lr_scheduler (line 142) | def lr_scheduler(self):
    method dataloaders (line 182) | def dataloaders(self) -> dict:
    method cuda_enabled (line 279) | def cuda_enabled(self):
    method max_epoch (line 283) | def max_epoch(self):
    method log_freq (line 287) | def log_freq(self):
    method init_lr (line 292) | def init_lr(self):
    method min_lr (line 296) | def min_lr(self):
    method accum_grad_iters (line 300) | def accum_grad_iters(self):
    method valid_splits (line 304) | def valid_splits(self):
    method test_splits (line 313) | def test_splits(self):
    method train_splits (line 319) | def train_splits(self):
    method evaluate_only (line 328) | def evaluate_only(self):
    method use_dist_eval_sampler (line 335) | def use_dist_eval_sampler(self):
    method resume_ckpt_path (line 339) | def resume_ckpt_path(self):
    method train_loader (line 343) | def train_loader(self):
    method setup_output_dir (line 348) | def setup_output_dir(self):
    method train (line 363) | def train(self):
    method test (line 422) | def test(self):
    method evaluate (line 435) | def evaluate(self, ckpt, cur_epoch="best", skip_reload=False):
    method train_epoch (line 446) | def train_epoch(self, epoch):
    method test_epoch (line 462) | def test_epoch(self, epoch):
    method eval_epoch (line 479) | def eval_epoch(self, ckpt, split_name, cur_epoch, skip_reload=False):
    method unwrap_dist_model (line 513) | def unwrap_dist_model(self, model):
    method create_loaders (line 519) | def create_loaders(
    method _save_checkpoint (line 603) | def _save_checkpoint(self, cur_epoch, is_best=False):
    method _reload_best_model (line 630) | def _reload_best_model(self, model):
    method _reload_model (line 650) | def _reload_model(self, model,ckpt):
    method _load_checkpoint (line 669) | def _load_checkpoint(self, url_or_filename):
    method log_stats (line 694) | def log_stats(self, stats, split_name):
    method log_config (line 703) | def log_config(self):
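
_save_checkpoint and _load_checkpoint bracket the training state. The payload LAVIS-style runners persist is roughly the following (a sketch; key names and the output path are assumptions):

    import os
    import torch

    def save_checkpoint(model, optimizer, scaler, config, cur_epoch,
                        output_dir, is_best=False):
        save_obj = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict() if scaler is not None else None,
            "config": config,
            "epoch": cur_epoch,
        }
        name = "checkpoint_best.pth" if is_best else f"checkpoint_{cur_epoch}.pth"
        torch.save(save_obj, os.path.join(output_dir, name))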

FILE: xraypulse/tasks/__init__.py
  function setup_task (line 13) | def setup_task(cfg):

FILE: xraypulse/tasks/base_task.py
  class BaseTask (line 21) | class BaseTask:
    method __init__ (line 22) | def __init__(self, **kwargs):
    method setup_task (line 28) | def setup_task(cls, **kwargs):
    method build_model (line 31) | def build_model(self, cfg):
    method build_datasets (line 37) | def build_datasets(self, cfg):
    method train_step (line 69) | def train_step(self, model, samples):
    method test_step (line 73) | def test_step(self, model, samples):
    method valid_step (line 77) | def valid_step(self, model, samples):
    method before_evaluation (line 80) | def before_evaluation(self, model, dataset, **kwargs):
    method after_evaluation (line 83) | def after_evaluation(self, **kwargs):
    method inference_step (line 86) | def inference_step(self):
    method evaluation (line 89) | def evaluation(self, model, data_loader, cuda_enabled=True):
    method train_epoch (line 108) | def train_epoch(
    method test_epoch (line 133) | def test_epoch(
    method train_iters (line 158) | def train_iters(
    method test_iters (line 186) | def test_iters(
    method _train_inner_loop (line 214) | def _train_inner_loop(
    method _test_inner_loop (line 307) | def _test_inner_loop(
    method save_result (line 354) | def save_result(result, result_dir, filename, remove_duplicate=""):

FILE: xraypulse/tasks/image_text_pretrain.py
  class ImageTextPretrainTask (line 13) | class ImageTextPretrainTask(BaseTask):
    method __init__ (line 14) | def __init__(self):
    method evaluation (line 17) | def evaluation(self, model, data_loader, cuda_enabled=True):
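
setup_task resolves the task named in the run config through the shared registry; ImageTextPretrainTask is what the alignment training registers. A sketch of the registry lookup (the config keys follow the LAVIS convention and are an assumption):

    from xraypulse.common.registry import registry

    def setup_task(cfg):
        # cfg.run_cfg.task is e.g. "image_text_pretrain"; the registry maps that
        # string to the decorated task class, which builds itself from the config.
        task_name = cfg.run_cfg.task
        return registry.get_task_class(task_name).setup_task(cfg=cfg)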
Condensed preview — 49 files, each listed with its path, character count, and a short content snippet.
[
  {
    "path": "README.md",
    "chars": 3466,
    "preview": "# XrayPULSE\n\n<!--\n**Here are some ideas to get you started:**\n🙋‍♀️ A short introduction - what is your organization all "
  },
  {
    "path": "demo.py",
    "chars": 8179,
    "preview": "import argparse\nimport os\nimport random\n\nimport numpy as np\nimport torch\nimport torch.backends.cudnn as cudnn\nimport gra"
  },
  {
    "path": "demo_configs/xraypulse_demo.yaml",
    "chars": 630,
    "preview": "model:\n  arch: xray_pulse\n  model_type: pulse\n  freeze_vit: True\n  freeze_qformer: True\n  max_txt_len: 160\n  end_sym: \"<"
  },
  {
    "path": "env.yml",
    "chars": 8435,
    "preview": "name: xraypulse\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n"
  },
  {
    "path": "prompts/alignment.txt",
    "chars": 661,
    "preview": "<Img><图片></Img> 详细描述所给的胸部X光影像。\n<Img><图片></Img> 请观察这张胸部X光影像,并阐述你的发现和总结。\n<Img><图片></Img> 你能否对所给的胸部X光影像进行详细的描述?\n<Img><图片></"
  },
  {
    "path": "run_demo.sh",
    "chars": 96,
    "preview": "CUDA_VISIBLE_DEVICES=0 python -u demo.py --cfg-path demo_configs/xraypulse_demo.yaml  --gpu-id 0"
  },
  {
    "path": "xraypulse/__init__.py",
    "chars": 956,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/common/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "xraypulse/common/config.py",
    "chars": 15080,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/common/dist_utils.py",
    "chars": 3620,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/common/gradcam.py",
    "chars": 815,
    "preview": "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom scipy.ndimage import filters\nfrom skimage import transform "
  },
  {
    "path": "xraypulse/common/logger.py",
    "chars": 6002,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/common/optims.py",
    "chars": 3517,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/common/registry.py",
    "chars": 9926,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/common/utils.py",
    "chars": 13808,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/configs/datasets/mimic/defaults.yaml",
    "chars": 146,
    "preview": "datasets:\n  mimic:\n    data_type: images\n    build_info:\n      storage: /mnt/petrelfs/share_data/huangzhongzhen/multimod"
  },
  {
    "path": "xraypulse/configs/datasets/openi/defaults.yaml",
    "chars": 146,
    "preview": "datasets:\n  openi:\n    data_type: images\n    build_info:\n      storage: /mnt/petrelfs/share_data/huangzhongzhen/multimod"
  },
  {
    "path": "xraypulse/configs/default.yaml",
    "chars": 142,
    "preview": "env:\n  # For default users\n  # cache_root: \"cache\"\n  # For internal use with persistent storage\n  cache_root: \"/export/h"
  },
  {
    "path": "xraypulse/configs/models/xraypulse.yaml",
    "chars": 595,
    "preview": "model:\n  arch: xray_pulse\n\n  # vit encoder\n  image_size: 224\n  drop_path_rate: 0\n  use_grad_checkpoint: False\n  vit_prec"
  },
  {
    "path": "xraypulse/conversation/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "xraypulse/conversation/conversation.py",
    "chars": 7912,
    "preview": "import argparse\nimport time\nfrom PIL import Image\n\nimport torch\nfrom transformers import AutoTokenizer, AutoModelForCaus"
  },
  {
    "path": "xraypulse/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "xraypulse/datasets/builders/__init__.py",
    "chars": 1854,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/datasets/builders/base_dataset_builder.py",
    "chars": 8109,
    "preview": "\"\"\"\n This file is from\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-C"
  },
  {
    "path": "xraypulse/datasets/builders/image_text_pair_builder.py",
    "chars": 2364,
    "preview": "import os\nimport logging\nimport warnings\n\nfrom xraypulse.common.registry import registry\nfrom xraypulse.datasets.builder"
  },
  {
    "path": "xraypulse/datasets/data_utils.py",
    "chars": 6283,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/datasets/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "xraypulse/datasets/datasets/base_dataset.py",
    "chars": 2067,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/datasets/datasets/caption_datasets.py",
    "chars": 2885,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/datasets/datasets/dataloader_utils.py",
    "chars": 5259,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/datasets/datasets/mimic_dataset.py",
    "chars": 747,
    "preview": "import os\nfrom PIL import Image\nimport webdataset as wds\nfrom xraypulse.datasets.datasets.base_dataset import BaseDatase"
  },
  {
    "path": "xraypulse/datasets/datasets/openi_dataset.py",
    "chars": 981,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/models/Qformer.py",
    "chars": 48386,
    "preview": "\"\"\"\n * Copyright (c) 2023, salesforce.com, inc.\n * All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n * For "
  },
  {
    "path": "xraypulse/models/__init__.py",
    "chars": 5762,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/models/base_model.py",
    "chars": 7867,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/models/blip2.py",
    "chars": 7724,
    "preview": "\"\"\"\n Copyright (c) 2023, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/models/blip2_outputs.py",
    "chars": 4153,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/models/eva_vit.py",
    "chars": 19530,
    "preview": "# Based on EVA, BEIT, timm and DeiT code bases\n# https://github.com/baaivision/EVA\n# https://github.com/rwightman/pytorc"
  },
  {
    "path": "xraypulse/models/pos_embed.py",
    "chars": 4054,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "xraypulse/models/xray_pulse.py",
    "chars": 15580,
    "preview": "import logging\nimport random\n\nimport torch\nfrom torch.cuda.amp import autocast as autocast\nimport torch.nn as nn\nimport "
  },
  {
    "path": "xraypulse/processors/__init__.py",
    "chars": 826,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/processors/base_processor.py",
    "chars": 610,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/processors/blip_processors.py",
    "chars": 4006,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/processors/randaugment.py",
    "chars": 11298,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/runners/__init__.py",
    "chars": 307,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/runners/runner_base.py",
    "chars": 24701,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/tasks/__init__.py",
    "chars": 739,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/tasks/base_task.py",
    "chars": 12069,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "xraypulse/tasks/image_text_pretrain.py",
    "chars": 540,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  }
]

About this extraction

This page contains the full source code of the openmedlab/XrayPULSE GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 49 files (276.2 KB), approximately 66.8k tokens, and a symbol index with 482 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.