Repository: openmedlab/XrayPULSE Branch: main Commit: 530a0c013e6d Files: 49 Total size: 276.2 KB Directory structure: gitextract_y6ened2f/ ├── README.md ├── demo.py ├── demo_configs/ │ └── xraypulse_demo.yaml ├── env.yml ├── prompts/ │ └── alignment.txt ├── run_demo.sh └── xraypulse/ ├── __init__.py ├── common/ │ ├── __init__.py │ ├── config.py │ ├── dist_utils.py │ ├── gradcam.py │ ├── logger.py │ ├── optims.py │ ├── registry.py │ └── utils.py ├── configs/ │ ├── datasets/ │ │ ├── mimic/ │ │ │ └── defaults.yaml │ │ └── openi/ │ │ └── defaults.yaml │ ├── default.yaml │ └── models/ │ └── xraypulse.yaml ├── conversation/ │ ├── __init__.py │ └── conversation.py ├── datasets/ │ ├── __init__.py │ ├── builders/ │ │ ├── __init__.py │ │ ├── base_dataset_builder.py │ │ └── image_text_pair_builder.py │ ├── data_utils.py │ └── datasets/ │ ├── __init__.py │ ├── base_dataset.py │ ├── caption_datasets.py │ ├── dataloader_utils.py │ ├── mimic_dataset.py │ └── openi_dataset.py ├── models/ │ ├── Qformer.py │ ├── __init__.py │ ├── base_model.py │ ├── blip2.py │ ├── blip2_outputs.py │ ├── eva_vit.py │ ├── pos_embed.py │ └── xray_pulse.py ├── processors/ │ ├── __init__.py │ ├── base_processor.py │ ├── blip_processors.py │ └── randaugment.py ├── runners/ │ ├── __init__.py │ └── runner_base.py └── tasks/ ├── __init__.py ├── base_task.py └── image_text_pretrain.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # XrayPULSE
---

## Key Features

This repository provides the official implementation of XrayPULSE. Key features:

- An attempt to extend [PULSE](https://huggingface.co/OpenMEDLab/PULSE-7bv5) into a biomedical multimodal conversational assistant.
- XrayPULSE is fine-tuned on Chinese X-ray–report paired datasets.

## Details

Our model is based on PULSE. We use [MedCLIP](https://github.com/RyanWangZf/MedCLIP) as the medical visual encoder, and a Q-Former ([BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2)) followed by a simple linear transformation as the adapter that injects the image into PULSE. To align the frozen visual encoder with the LLM through the adapter, we generated Chinese X-ray–report paired data from the free-text radiology reports of two datasets ([MIMIC-CXR](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and [OpenI](https://openi.nlm.nih.gov/faq#collection)) with the help of ChatGPT. To facilitate research in biomedical multimodal learning, we will release the data to the public.
## Get Started

**Installation**

```bash
git clone https://github.com/openmedlab/XrayPULSE.git
cd XrayPULSE
```

**Environment**

```bash
conda env create -f env.yml
conda activate xraypulse
```

**Prepare the pretrained weights**

The pretrained model weights can be downloaded here:

- [PULSE\_Model](https://huggingface.co/OpenMEDLab/PULSE-7bv5)
- [Pretrained_XrayPULSE_Checkpoint](https://drive.google.com/file/d/1VsO61-3DFuK4ysGPvoD4_JZaRFKvAJR_/view?usp=drive_link)

The PULSE weights should be placed in a single folder with a structure similar to the following:

```
pulse_weights
├── config.json
├── generation_config.json
├── tokenizer.json
├── tokenizer_config.json
├── special_tokens_map.json
├── pytorch_model.bin.index.json
├── pytorch_model-00001-of-00002.bin
├── pytorch_model-00002-of-00002.bin
```

Then set "bloom_model" in the model config file "xraypulse/configs/models/xraypulse.yaml" to the path of pulse_weights, and set the path of the pretrained checkpoint in "demo_configs/xraypulse_demo.yaml".

**Run Demo**

```bash
bash run_demo.sh
```

## 🙏 Acknowledgement

This project stands on the shoulders of [XrayGPT](https://github.com/mbzuai-oryx/XrayGPT). Many thanks to its authors! We use the medical-aware image encoder from [MedCLIP](https://github.com/RyanWangZf/MedCLIP). The model architecture of XrayGPT follows [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2).

## 🛡️ License

This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
================================================ FILE: demo.py ================================================ import argparse import os import random import numpy as np import torch import torch.backends.cudnn as cudnn import gradio as gr from xraypulse.common.config import Config from xraypulse.common.dist_utils import get_rank from xraypulse.common.registry import registry from xraypulse.conversation.conversation import Chat, CONV_ZH # imports modules for registration from xraypulse.datasets.builders import * from xraypulse.models import * from xraypulse.processors import * from xraypulse.runners import * from xraypulse.tasks import * def parse_args(): parser = argparse.ArgumentParser(description="Demo") parser.add_argument("--cfg-path", required=True, help="path to configuration file.") parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") parser.add_argument( "--options", nargs="+", help="override some settings in the used config, the key-value pair " "in xxx=yyy format will be merged into config file (deprecate), " "change to --cfg-options instead.", ) args = parser.parse_args() return args def setup_seeds(config): seed = config.run_cfg.seed + get_rank() random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) cudnn.benchmark = False cudnn.deterministic = True # ======================================== # Model Initialization # ======================================== print('Initializing Chat') args = parse_args() cfg = Config(args) model_config = cfg.model_cfg print(model_config) model_config.device_8bit = args.gpu_id model_cls = registry.get_model_class(model_config.arch) print(model_cls) model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) vis_processor_cfg = cfg.datasets_cfg.openi.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) 
print('Initialization Finished') # ======================================== # Gradio Setting # ======================================== def gradio_reset(chat_state, img_list): if chat_state is not None: chat_state.messages = [] if img_list is not None: img_list = [] return None, gr.update(value=None, interactive=True), gr.update(placeholder='请先上传图片', interactive=False),gr.update(value="上传图片并开始咨询", interactive=True), chat_state, img_list def upload_img(gr_img, text_input, chat_state): if gr_img is None: return None, None, gr.update(interactive=True), chat_state, None chat_state = CONV_ZH.copy() img_list = [] llm_message = chat.upload_img(gr_img, chat_state, img_list) return gr.update(interactive=False), gr.update(interactive=True, placeholder='输入问题'), gr.update(value="开始对话", interactive=False), chat_state, img_list def gradio_ask(user_message, chatbot, chat_state): if len(user_message) == 0: return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state chat.ask(user_message, chat_state) chatbot = chatbot + [[user_message, None]] return '', chatbot, chat_state def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): llm_message = chat.answer(conv=chat_state, img_list=img_list, num_beams=num_beams, temperature=temperature, max_new_tokens=300, max_length=2000)[0] chatbot[-1][1] = llm_message return chatbot, chat_state, img_list title = """

XrayPULSE

""" description = """

上传X光影像,开始诊断咨询

""" disclaimer = """

使用说明:


OpenMedLab

""" def set_example_xray(example: list) -> dict: return gr.Image.update(value=example[0]) def set_example_text_input(example_text: str) -> dict: return gr.Textbox.update(value=example_text[0]) #TODO show examples below with gr.Blocks() as demo: gr.Markdown(title) gr.Markdown(description) with gr.Row(): with gr.Column(scale=0.5): image = gr.Image(type="pil") upload_button = gr.Button(value="上传影像并开始咨询", interactive=True, variant="primary") clear = gr.Button("重制") num_beams = gr.Slider( minimum=1, maximum=10, value=1, step=1, interactive=True, label="beam search numbers", ) temperature = gr.Slider( minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature", ) with gr.Column(): chat_state = gr.State() img_list = gr.State() chatbot = gr.Chatbot(label='XrayPULSE') text_input = gr.Textbox(label='用户', placeholder='请上传X光影像', interactive=False) with gr.Row(): example_xrays = gr.Dataset(components=[image], label="X光影像范例", samples=[ [os.path.join(os.path.dirname(__file__), "images/image1.png")], [os.path.join(os.path.dirname(__file__), "images/image2.png")], [os.path.join(os.path.dirname(__file__), "images/image3.png")], [os.path.join(os.path.dirname(__file__), "images/image4.png")], [os.path.join(os.path.dirname(__file__), "images/image5.png")], [os.path.join(os.path.dirname(__file__), "images/image6.png")], ]) with gr.Row(): example_texts = gr.Dataset(components=[gr.Textbox(visible=False)], label="咨询问题范例", samples=[ ["详细描述所给的胸部X光影像。"], ["请观察这张胸部X光影像,并阐述你的发现和总结。"], ["你能否对所给的胸部X光影像进行详细的描述?"], ["尽可能详细地描述所给的胸部X光影像。"], ["这张胸部X光影像中的关键症状是什么?"], ["你能在这张胸部X光影像中,指出存在的任何异常或需要注意的地方吗"], ["这张胸部X光影像中,有哪些肺部和心脏的具体特征可见?"], ["在这张胸部X光影像中,最显著的特征是什么,它是如何反映出病人的健康状况?"], ["根据从这张胸部X光影像中观察到的发现,给出影像的总体印象是正常还是异常?"], ],) example_xrays.click(fn=set_example_xray, inputs=example_xrays, outputs=example_xrays.components) upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list]) click_response = 
example_texts.click(set_example_text_input, inputs=example_texts, outputs=text_input).then( gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state], queue=False) click_response.then( gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list], queue=False ) submit_response = text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state], queue=False) submit_response.then( gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list], queue=False ) clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False) gr.Markdown(disclaimer) demo.launch(share=True, enable_queue=True) ================================================ FILE: demo_configs/xraypulse_demo.yaml ================================================ model: arch: xray_pulse model_type: pulse freeze_vit: True freeze_qformer: True max_txt_len: 160 end_sym: "" low_resource: True prompt_path: "prompts/alignment.txt" prompt_template: 'Instructions: You are PULSE, a large language model trained by SHAIlab. 
Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2022-02-01 User: {} Helper: ' ckpt: './XrayPULSE_ckpt.pth' datasets: openi: vis_processor: train: name: "blip2_image_eval" image_size: 224 text_processor: train: name: "blip_caption" run: task: image_text_pretrain ================================================ FILE: env.yml ================================================ name: xraypulse channels: - pytorch - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - blas=1.0=mkl - brotlipy=0.7.0=py39h27cfd23_1003 - bzip2=1.0.8=h7b6447c_0 - ca-certificates=2023.01.10=h06a4308_0 - certifi=2022.12.7=py39h06a4308_0 - cffi=1.15.1=py39h5eee18b_3 - charset-normalizer=2.0.4=pyhd3eb1b0_0 - cryptography=39.0.1=py39h9ce1e76_0 - cudatoolkit=11.3.1=h2bc3f7f_2 - ffmpeg=4.3=hf484d3e_0 - flit-core=3.8.0=py39h06a4308_0 - freetype=2.12.1=h4a9f257_0 - giflib=5.2.1=h5eee18b_3 - gmp=6.2.1=h295c915_3 - gnutls=3.6.15=he1e5248_0 - intel-openmp=2021.4.0=h06a4308_3561 - jpeg=9e=h5eee18b_1 - lame=3.100=h7b6447c_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.38=h1181459_1 - lerc=3.0=h295c915_0 - libdeflate=1.17=h5eee18b_0 - libffi=3.4.2=h6a678d5_6 - libgcc-ng=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libiconv=1.16=h7f8727e_2 - libidn2=2.3.2=h7f8727e_0 - libpng=1.6.39=h5eee18b_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.0=h6a678d5_2 - libunistring=0.9.10=h27cfd23_0 - libwebp=1.2.4=h11a3e52_1 - libwebp-base=1.2.4=h5eee18b_1 - lz4-c=1.9.4=h6a678d5_0 - mkl=2021.4.0=h06a4308_640 - mkl-service=2.4.0=py39h7f8727e_0 - mkl_fft=1.3.1=py39hd3c417c_0 - mkl_random=1.2.2=py39h51133e4_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - numpy=1.23.5=py39h14f4228_0 - numpy-base=1.23.5=py39h31eccc5_0 - openh264=2.1.1=h4ff587b_0 - openssl=1.1.1t=h7f8727e_0 - pillow=9.4.0=py39h6a678d5_0 - pip=23.0.1=py39h06a4308_0 - pycparser=2.21=pyhd3eb1b0_0 - pyopenssl=23.0.0=py39h06a4308_0 - pysocks=1.7.1=py39h06a4308_0 - 
python=3.9.16=h7a1cb2a_2 - pytorch-mutex=1.0=cuda - readline=8.2=h5eee18b_0 - requests=2.28.1=py39h06a4308_1 - setuptools=66.0.0=py39h06a4308_0 - six=1.16.0=pyhd3eb1b0_1 - sqlite=3.41.2=h5eee18b_0 - tk=8.6.12=h1ccaba5_0 - torchaudio=0.12.1=py39_cu113 - torchvision=0.13.1=py39_cu113 - typing_extensions=4.4.0=py39h06a4308_0 - urllib3=1.26.15=py39h06a4308_0 - wheel=0.38.4=py39h06a4308_0 - xz=5.2.10=h5eee18b_1 - zlib=1.2.13=h5eee18b_0 - zstd=1.5.5=hc292b87_0 - pip: - accelerate==0.15.0 - aiofiles==23.1.0 - aiohttp==3.8.4 - aiosignal==1.3.1 - albumentations==1.3.0 - altair==4.2.2 - antlr4-python3-runtime==4.9.3 - anyio==3.6.2 - appdirs==1.4.4 - argon2-cffi==21.3.0 - argon2-cffi-bindings==21.2.0 - arrow==1.2.3 - asttokens==2.2.1 - async-timeout==4.0.2 - attrs==22.2.0 - backcall==0.2.0 - beautifulsoup4==4.12.2 - bitsandbytes==0.37.0 - bleach==6.0.0 - blis==0.7.9 - braceexpand==0.1.7 - cachetools==5.3.0 - catalogue==2.0.8 - cchardet==2.1.7 - chardet==3.0.4 - click==8.1.3 - cmake==3.26.3 - comm==0.1.3 - confection==0.0.4 - contourpy==1.0.7 - cycler==0.11.0 - cymem==2.0.7 - dataclasses==0.6 - datasets==2.12.0 - debugpy==1.6.7 - decorator==5.1.1 - decord==0.6.0 - defusedxml==0.7.1 - dill==0.3.6 - docker-pycreds==0.4.0 - entrypoints==0.4 - et-xmlfile==1.1.0 - evaluate==0.4.0 - executing==1.2.0 - exifread-nocycle==3.0.1 - fairscale==0.4.13 - fastapi==0.95.1 - fastchat==0.1 - fastjsonschema==2.16.3 - ffmpy==0.3.0 - filelock==3.9.0 - fire==0.5.0 - fonttools==4.38.0 - fqdn==1.5.1 - frozenlist==1.3.3 - fschat==0.2.3 - fsspec==2022.11.0 - gensim==4.3.1 - gitdb==4.0.10 - gitpython==3.1.31 - googletrans==3.0.0 - gradio==3.23.0 - gradio-client==0.0.8 - h11==0.9.0 - h2==3.2.0 - hiq-python==1.1.12 - hpack==3.0.0 - hstspreload==2023.1.1 - httpcore==0.9.1 - httpx==0.13.3 - huggingface-hub==0.13.4 - hyperframe==5.2.0 - idna==2.10 - imageio==2.27.0 - img2dataset==1.25.4 - importlib-metadata==6.5.0 - importlib-resources==5.12.0 - iopath==0.1.10 - ipykernel==6.22.0 - ipython==8.12.0 - 
ipython-genutils==0.2.0 - isoduration==20.11.0 - jedi==0.18.2 - jinja2==3.1.2 - joblib==1.2.0 - jsonpointer==2.3 - jsonschema==4.17.3 - jupyter-client==8.2.0 - jupyter-core==5.3.0 - jupyter-events==0.6.3 - jupyter-server==2.5.0 - jupyter-server-terminals==0.4.4 - jupyterlab-pygments==0.2.2 - kiwisolver==1.4.4 - langcodes==3.3.0 - lazy-loader==0.2 - linkify-it-py==2.0.0 - lit==16.0.1 - llvmlite==0.39.1 - markdown-it-py==2.2.0 - markdown2==2.4.8 - markupsafe==2.1.2 - matplotlib==3.7.0 - matplotlib-inline==0.1.6 - mdit-py-plugins==0.3.3 - mdurl==0.1.2 - medclip==0.0.3 - mistune==2.0.5 - mpmath==1.3.0 - multidict==6.0.4 - multiprocess==0.70.14 - murmurhash==1.0.9 - nbclassic==0.5.5 - nbclient==0.7.3 - nbconvert==7.3.1 - nbformat==5.8.0 - nest-asyncio==1.5.6 - networkx==3.1 - nltk==3.8.1 - notebook==6.5.4 - notebook-shim==0.2.2 - numba==0.56.4 - nvidia-cublas-cu11==11.10.3.66 - nvidia-cuda-cupti-cu11==11.7.101 - nvidia-cuda-nvrtc-cu11==11.7.99 - nvidia-cuda-runtime-cu11==11.7.99 - nvidia-cudnn-cu11==8.5.0.96 - nvidia-cufft-cu11==10.9.0.58 - nvidia-curand-cu11==10.2.10.91 - nvidia-cusolver-cu11==11.4.0.1 - nvidia-cusparse-cu11==11.7.4.91 - nvidia-nccl-cu11==2.14.3 - nvidia-nvtx-cu11==11.7.91 - omegaconf==2.3.0 - openai==0.27.0 - opencv-python==4.7.0.72 - opencv-python-headless==4.7.0.72 - openpyxl==3.1.2 - orjson==3.8.10 - packaging==23.0 - pandas==1.5.3 - pandocfilters==1.5.0 - parso==0.8.3 - pathtools==0.1.2 - pathy==0.10.1 - peft==0.2.0 - pexpect==4.8.0 - pickleshare==0.7.5 - platformdirs==3.2.0 - portalocker==2.7.0 - preshed==3.0.8 - prometheus-client==0.16.0 - promise==2.3 - prompt-toolkit==3.0.38 - protobuf==3.20.3 - psutil==5.9.4 - ptyprocess==0.7.0 - pure-eval==0.2.2 - py-itree==0.0.19 - pyarrow==12.0.1 - pydantic==1.10.7 - pydub==0.25.1 - pygments==2.15.1 - pyllama==0.0.9 - pynndescent==0.5.9 - pyparsing==3.0.9 - pyrsistent==0.19.3 - python-dateutil==2.8.2 - python-json-logger==2.0.7 - python-multipart==0.0.6 - pytz==2023.3 - pywavelets==1.4.1 - pyyaml==6.0 - 
pyzmq==25.0.2 - qudida==0.0.4 - regex==2022.10.31 - responses==0.18.0 - rfc3339-validator==0.1.4 - rfc3986==1.5.0 - rfc3986-validator==0.1.1 - rich==13.3.4 - scikit-image==0.20.0 - scikit-learn==1.2.2 - scipy==1.9.1 - semantic-version==2.10.0 - send2trash==1.8.0 - sentence-transformers==2.2.2 - sentencepiece==0.1.97 - sentry-sdk==1.19.1 - setproctitle==1.3.2 - shortuuid==1.0.11 - smart-open==6.3.0 - smmap==5.0.0 - sniffio==1.3.0 - soupsieve==2.4.1 - spacy==3.5.1 - spacy-legacy==3.0.12 - spacy-loggers==1.0.4 - srsly==2.4.6 - stack-data==0.6.2 - starlette==0.26.1 - svgwrite==1.4.3 - sympy==1.11.1 - tenacity==8.2.2 - termcolor==2.2.0 - terminado==0.17.1 - textaugment==1.3.4 - textblob==0.17.1 - thinc==8.1.9 - threadpoolctl==3.1.0 - tifffile==2023.4.12 - timm==0.6.13 - tinycss2==1.2.1 - tokenizers==0.13.2 - toolz==0.12.0 - torch==2.0.0 - tornado==6.3 - tqdm==4.64.1 - traitlets==5.9.0 - transformers==4.29.0 - triton==2.0.0 - typer==0.7.0 - tzdata==2023.3 - uc-micro-py==1.0.1 - umap-learn==0.5.3 - uri-template==1.2.0 - uvicorn==0.21.1 - wandb==0.12.21 - wasabi==1.1.1 - wavedrom==2.0.3.post3 - wcwidth==0.2.6 - webcolors==1.13 - webdataset==0.2.48 - webencodings==0.5.1 - websocket-client==1.5.1 - websockets==11.0.2 - wget==3.2 - xxhash==3.2.0 - yarl==1.8.2 - zipp==3.14.0 ================================================ FILE: prompts/alignment.txt ================================================ <图片> 详细描述所给的胸部X光影像。 <图片> 请观察这张胸部X光影像,并阐述你的发现和总结。 <图片> 你能否对所给的胸部X光影像进行详细的描述? <图片> 尽可能详细地描述所给的胸部X光影像。 <图片> 这张胸部X光影像中的关键症状是什么? <图片> 你能在这张胸部X光影像中,指出存在的任何异常或需要注意的地方吗? <图片> 这张胸部X光影像中,有哪些肺部和心脏的具体特征可见? <图片> 在这张胸部X光影像中,最显著的特征是什么,它是如何反映出病人的健康状况? <图片> 这张胸部X光影像提供了哪些观察发现和总体印象? <图片> 这张胸部X光影像中,心脏的大小和形状如何? <图片> 根据从这张胸部X光影像中观察到的发现,给出影像的总体印象是正常还是异常? <图片> 在这张胸部X光影像中,有无感染或炎症的迹象?如果有,可能的原因是什么? 
<图片> 根据这张胸部X光影像中的发现,请你给出总体印象。 <图片> 在这张胸部X光影像中,有没有患者淋巴结肿大或异常的可见迹象 <图片> 这张胸部X光影像中观察到的异常有没有可能引发的并发症或风险?或者说,这张X光影像所展示的患者是正常的吗 ================================================ FILE: run_demo.sh ================================================ CUDA_VISIBLE_DEVICES=0 python -u demo.py --cfg-path demo_configs/xraypulse_demo.yaml --gpu-id 0 ================================================ FILE: xraypulse/__init__.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import os import sys from omegaconf import OmegaConf from xraypulse.common.registry import registry from xraypulse.datasets.builders import * from xraypulse.models import * from xraypulse.processors import * from xraypulse.tasks import * root_dir = os.path.dirname(os.path.abspath(__file__)) default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) registry.register_path("library_root", root_dir) repo_root = os.path.join(root_dir, "..") registry.register_path("repo_root", repo_root) cache_root = os.path.join(repo_root, default_cfg.env.cache_root) registry.register_path("cache_root", cache_root) registry.register("MAX_INT", sys.maxsize) registry.register("SPLIT_NAMES", ["train", "val", "test"]) ================================================ FILE: xraypulse/common/__init__.py ================================================ ================================================ FILE: xraypulse/common/config.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import json from typing import Dict from omegaconf import OmegaConf from xraypulse.common.registry import registry class Config: def __init__(self, args): self.config = {} self.args = args # Register the config and configuration for setup registry.register("configuration", self) user_config = self._build_opt_list(self.args.options) config = OmegaConf.load(self.args.cfg_path) runner_config = self.build_runner_config(config) model_config = self.build_model_config(config, **user_config) dataset_config = self.build_dataset_config(config) # Validate the user-provided runner configuration # model and dataset configuration are supposed to be validated by the respective classes # [TODO] validate the model/dataset configuration # self._validate_runner_config(runner_config) # Override the default configuration with user options. self.config = OmegaConf.merge( runner_config, model_config, dataset_config, user_config ) def _validate_runner_config(self, runner_config): """ This method validates the configuration, such that 1) all the user specified options are valid; 2) no type mismatches between the user specified options and the config. """ runner_config_validator = create_runner_config_validator() runner_config_validator.validate(runner_config) def _build_opt_list(self, opts): opts_dot_list = self._convert_to_dot_list(opts) return OmegaConf.from_dotlist(opts_dot_list) @staticmethod def build_model_config(config, **kwargs): model = config.get("model", None) assert model is not None, "Missing model configuration file." model_cls = registry.get_model_class(model.arch) assert model_cls is not None, f"Model '{model.arch}' has not been registered." model_type = kwargs.get("model.model_type", None) if not model_type: model_type = model.get("model_type", None) # else use the model type selected by user. 
assert model_type is not None, "Missing model_type." model_config_path = model_cls.default_config_path(model_type=model_type) model_config = OmegaConf.create() # hierarchy override, customized config > default config model_config = OmegaConf.merge( model_config, OmegaConf.load(model_config_path), {"model": config["model"]}, ) return model_config @staticmethod def build_runner_config(config): return {"run": config.run} @staticmethod def build_dataset_config(config): datasets = config.get("datasets", None) if datasets is None: raise KeyError( "Expecting 'datasets' as the root key for dataset configuration." ) dataset_config = OmegaConf.create() for dataset_name in datasets: builder_cls = registry.get_builder_class(dataset_name) dataset_config_type = datasets[dataset_name].get("type", "default") dataset_config_path = builder_cls.default_config_path( type=dataset_config_type ) # hierarchy override, customized config > default config dataset_config = OmegaConf.merge( dataset_config, OmegaConf.load(dataset_config_path), {"datasets": {dataset_name: config["datasets"][dataset_name]}}, ) return dataset_config def _convert_to_dot_list(self, opts): if opts is None: opts = [] if len(opts) == 0: return opts has_equal = opts[0].find("=") != -1 if has_equal: return opts return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] def get_config(self): return self.config @property def run_cfg(self): return self.config.run @property def datasets_cfg(self): return self.config.datasets @property def model_cfg(self): return self.config.model def pretty_print(self): logging.info("\n===== Running Parameters =====") logging.info(self._convert_node_to_json(self.config.run)) logging.info("\n====== Dataset Attributes ======") datasets = self.config.datasets for dataset in datasets: if dataset in self.config.datasets: logging.info(f"\n======== {dataset} =======") dataset_config = self.config.datasets[dataset] logging.info(self._convert_node_to_json(dataset_config)) else: 
logging.warning(f"No dataset named '{dataset}' in config. Skipping") logging.info(f"\n====== Model Attributes ======") logging.info(self._convert_node_to_json(self.config.model)) def _convert_node_to_json(self, node): container = OmegaConf.to_container(node, resolve=True) return json.dumps(container, indent=4, sort_keys=True) def to_dict(self): return OmegaConf.to_container(self.config) def node_to_dict(node): return OmegaConf.to_container(node) class ConfigValidator: """ This is a preliminary implementation to centralize and validate the configuration. May be altered in the future. A helper class to validate configurations from yaml file. This serves the following purposes: 1. Ensure all the options in the yaml are defined, raise error if not. 2. when type mismatches are found, the validator will raise an error. 3. a central place to store and display helpful messages for supported configurations. """ class _Argument: def __init__(self, name, choices=None, type=None, help=None): self.name = name self.val = None self.choices = choices self.type = type self.help = help def __str__(self): s = f"{self.name}={self.val}" if self.type is not None: s += f", ({self.type})" if self.choices is not None: s += f", choices: {self.choices}" if self.help is not None: s += f", ({self.help})" return s def __init__(self, description): self.description = description self.arguments = dict() self.parsed_args = None def __getitem__(self, key): assert self.parsed_args is not None, "No arguments parsed yet." return self.parsed_args[key] def __str__(self) -> str: return self.format_help() def add_argument(self, *args, **kwargs): """ Assume the first argument is the name of the argument. """ self.arguments[args[0]] = self._Argument(*args, **kwargs) def validate(self, config=None): """ Convert yaml config (dict-like) to list, required by argparse. """ for k, v in config.items(): assert ( k in self.arguments ), f"""{k} is not a valid argument. 
Support arguments are {self.format_arguments()}.""" if self.arguments[k].type is not None: try: self.arguments[k].val = self.arguments[k].type(v) except ValueError: raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") if self.arguments[k].choices is not None: assert ( v in self.arguments[k].choices ), f"""{k} must be one of {self.arguments[k].choices}.""" return config def format_arguments(self): return str([f"{k}" for k in sorted(self.arguments.keys())]) def format_help(self): # description + key-value pair string for each argument help_msg = str(self.description) return help_msg + ", available arguments: " + self.format_arguments() def print_help(self): # display help message print(self.format_help()) def create_runner_config_validator(): validator = ConfigValidator(description="Runner configurations") validator.add_argument( "runner", type=str, choices=["runner_base", "runner_iter"], help="""Runner to use. The "runner_base" uses epoch-based training while iter-based runner runs based on iters. Default: runner_base""", ) # add argumetns for training dataset ratios validator.add_argument( "train_dataset_ratios", type=Dict[str, float], help="""Ratios of training dataset. This is used in iteration-based runner. Do not support for epoch-based runner because how to define an epoch becomes tricky. Default: None""", ) validator.add_argument( "max_iters", type=float, help="Maximum number of iterations to run.", ) validator.add_argument( "max_epoch", type=int, help="Maximum number of epochs to run.", ) # add arguments for iters_per_inner_epoch validator.add_argument( "iters_per_inner_epoch", type=float, help="Number of iterations per inner epoch. 
This is required when runner is runner_iter.", ) lr_scheds_choices = registry.list_lr_schedulers() validator.add_argument( "lr_sched", type=str, choices=lr_scheds_choices, help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), ) task_choices = registry.list_tasks() validator.add_argument( "task", type=str, choices=task_choices, help="Task to use, from {}".format(task_choices), ) # add arguments for init_lr validator.add_argument( "init_lr", type=float, help="Initial learning rate. This will be the learning rate after warmup and before decay.", ) # add arguments for min_lr validator.add_argument( "min_lr", type=float, help="Minimum learning rate (after decay).", ) # add arguments for warmup_lr validator.add_argument( "warmup_lr", type=float, help="Starting learning rate for warmup.", ) # add arguments for learning rate decay rate validator.add_argument( "lr_decay_rate", type=float, help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", ) # add arguments for weight decay validator.add_argument( "weight_decay", type=float, help="Weight decay rate.", ) # add arguments for training batch size validator.add_argument( "batch_size_train", type=int, help="Training batch size.", ) # add arguments for evaluation batch size validator.add_argument( "batch_size_eval", type=int, help="Evaluation batch size, including validation and testing.", ) # add arguments for number of workers for data loading validator.add_argument( "num_workers", help="Number of workers for data loading.", ) # add arguments for warm up steps validator.add_argument( "warmup_steps", type=int, help="Number of warmup steps. 
Required if a warmup schedule is used.", ) # add arguments for random seed validator.add_argument( "seed", type=int, help="Random seed.", ) # add arguments for output directory validator.add_argument( "output_dir", type=str, help="Output directory to save checkpoints and logs.", ) # add arguments for whether only use evaluation validator.add_argument( "evaluate", help="Whether to only evaluate the model. If true, training will not be performed.", ) # add arguments for splits used for training, e.g. ["train", "val"] validator.add_argument( "train_splits", type=list, help="Splits to use for training.", ) # add arguments for splits used for validation, e.g. ["val"] validator.add_argument( "valid_splits", type=list, help="Splits to use for validation. If not provided, will skip the validation.", ) # add arguments for splits used for testing, e.g. ["test"] validator.add_argument( "test_splits", type=list, help="Splits to use for testing. If not provided, will skip the testing.", ) # add arguments for accumulating gradient for iterations validator.add_argument( "accum_grad_iters", type=int, help="Number of iterations to accumulate gradient for.", ) # ====== distributed training ====== validator.add_argument( "device", type=str, choices=["cpu", "cuda"], help="Device to use. 
Support 'cuda' or 'cpu' as for now.", ) validator.add_argument( "world_size", type=int, help="Number of processes participating in the job.", ) validator.add_argument("dist_url", type=str) validator.add_argument("distributed", type=bool) # add arguments to opt using distributed sampler during evaluation or not validator.add_argument( "use_dist_eval_sampler", type=bool, help="Whether to use distributed sampler during evaluation or not.", ) # ====== task specific ====== # generation task specific arguments # add arguments for maximal length of text output validator.add_argument( "max_len", type=int, help="Maximal length of text output.", ) # add arguments for minimal length of text output validator.add_argument( "min_len", type=int, help="Minimal length of text output.", ) # add arguments number of beams validator.add_argument( "num_beams", type=int, help="Number of beams used for beam search.", ) # vqa task specific arguments # add arguments for number of answer candidates validator.add_argument( "num_ans_candidates", type=int, help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", ) # add arguments for inference method validator.add_argument( "inference_method", type=str, choices=["genearte", "rank"], help="""Inference method to use for question answering. If rank, requires a answer list.""", ) # ====== model specific ====== validator.add_argument( "k_test", type=int, help="Number of top k most similar samples from ITC/VTC selection to be tested.", ) return validator ================================================ FILE: xraypulse/common/dist_utils.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import datetime import functools import os import torch import torch.distributed as dist import timm.models.hub as timm_hub def setup_for_distributed(is_master): """ This function disables printing when not in master process """ import builtins as __builtin__ builtin_print = __builtin__.print def print(*args, **kwargs): force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) __builtin__.print = print def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def init_distributed_mode(args): if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) elif "SLURM_PROCID" in os.environ: args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() else: print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) args.dist_backend = "nccl" print( "| distributed init (rank {}, world {}): {}".format( args.rank, args.world_size, args.dist_url ), flush=True, ) torch.distributed.init_process_group( backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank, timeout=datetime.timedelta( days=365 ), # allow auto-downloading and de-compressing ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) def get_dist_info(): if torch.__version__ < "1.0": initialized = dist._initialized else: initialized = dist.is_initialized() if 
initialized: rank = dist.get_rank() world_size = dist.get_world_size() else: # non-distributed training rank = 0 world_size = 1 return rank, world_size def main_process(func): @functools.wraps(func) def wrapper(*args, **kwargs): rank, _ = get_dist_info() if rank == 0: return func(*args, **kwargs) return wrapper def download_cached_file(url, check_hash=True, progress=False): """ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. """ def get_cached_file_path(): # a hack to sync the file path across processes parts = torch.hub.urlparse(url) filename = os.path.basename(parts.path) cached_file = os.path.join(timm_hub.get_cache_dir(), filename) return cached_file if is_main_process(): timm_hub.download_cached_file(url, check_hash, progress) if is_dist_avail_and_initialized(): dist.barrier() return get_cached_file_path() ================================================ FILE: xraypulse/common/gradcam.py ================================================ import numpy as np from matplotlib import pyplot as plt from scipy.ndimage import filters from skimage import transform as skimage_transform def getAttMap(img, attMap, blur=True, overlap=True): attMap -= attMap.min() if attMap.max() > 0: attMap /= attMap.max() attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") if blur: attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) attMap -= attMap.min() attMap /= attMap.max() cmap = plt.get_cmap("jet") attMapV = cmap(attMap) attMapV = np.delete(attMapV, 3, 2) if overlap: attMap = ( 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV ) return attMap ================================================ FILE: xraypulse/common/logger.py ================================================ """ Copyright (c) 2022, 
salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import datetime import logging import time from collections import defaultdict, deque import torch import torch.distributed as dist from xraypulse.common import dist_utils class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ if not dist_utils.is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value, ) class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in 
self.__dict__: return self.__dict__[attr] raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, attr) ) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def global_avg(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = "" start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt="{avg:.4f}") data_time = SmoothedValue(fmt="{avg:.4f}") space_fmt = ":" + str(len(str(len(iterable)))) + "d" log_msg = [ header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}", ] if torch.cuda.is_available(): log_msg.append("max mem: {memory:.0f}") log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print( log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB, ) ) else: print( log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), ) ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print( "{} Total time: {} ({:.4f} s / it)".format( header, total_time_str, total_time / len(iterable) ) ) class 
AttrDict(dict): def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) self.__dict__ = self def setup_logger(): logging.basicConfig( level=logging.INFO if dist_utils.is_main_process() else logging.WARN, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler()], ) ================================================ FILE: xraypulse/common/optims.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import math from xraypulse.common.registry import registry @registry.register_lr_scheduler("linear_warmup_step_lr") class LinearWarmupStepLRScheduler: def __init__( self, optimizer, max_epoch, min_lr, init_lr, decay_rate=1, warmup_start_lr=-1, warmup_steps=0, **kwargs ): self.optimizer = optimizer self.max_epoch = max_epoch self.min_lr = min_lr self.decay_rate = decay_rate self.init_lr = init_lr self.warmup_steps = warmup_steps self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr def step(self, cur_epoch, cur_step): if cur_epoch == 0: warmup_lr_schedule( step=cur_step, optimizer=self.optimizer, max_step=self.warmup_steps, init_lr=self.warmup_start_lr, max_lr=self.init_lr, ) else: step_lr_schedule( epoch=cur_epoch, optimizer=self.optimizer, init_lr=self.init_lr, min_lr=self.min_lr, decay_rate=self.decay_rate, ) @registry.register_lr_scheduler("linear_warmup_cosine_lr") class LinearWarmupCosineLRScheduler: def __init__( self, optimizer, max_epoch, iters_per_epoch, min_lr, init_lr, warmup_steps=0, warmup_start_lr=-1, **kwargs ): self.optimizer = optimizer self.max_epoch = max_epoch self.iters_per_epoch = iters_per_epoch self.min_lr = min_lr self.init_lr = init_lr self.warmup_steps = warmup_steps self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr def 
step(self, cur_epoch, cur_step): total_cur_step = cur_epoch * self.iters_per_epoch + cur_step if total_cur_step < self.warmup_steps: warmup_lr_schedule( step=cur_step, optimizer=self.optimizer, max_step=self.warmup_steps, init_lr=self.warmup_start_lr, max_lr=self.init_lr, ) else: cosine_lr_schedule( epoch=total_cur_step, optimizer=self.optimizer, max_epoch=self.max_epoch * self.iters_per_epoch, init_lr=self.init_lr, min_lr=self.min_lr, ) def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): """Decay the learning rate""" lr = (init_lr - min_lr) * 0.5 * ( 1.0 + math.cos(math.pi * epoch / max_epoch) ) + min_lr for param_group in optimizer.param_groups: param_group["lr"] = lr def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): """Warmup the learning rate""" lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) for param_group in optimizer.param_groups: param_group["lr"] = lr def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): """Decay the learning rate""" lr = max(min_lr, init_lr * (decay_rate**epoch)) for param_group in optimizer.param_groups: param_group["lr"] = lr ================================================ FILE: xraypulse/common/registry.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ class Registry: mapping = { "builder_name_mapping": {}, "task_name_mapping": {}, "processor_name_mapping": {}, "model_name_mapping": {}, "lr_scheduler_name_mapping": {}, "runner_name_mapping": {}, "state": {}, "paths": {}, } @classmethod def register_builder(cls, name): r"""Register a dataset builder to registry with key 'name' Args: name: Key with which the builder will be registered. 
Usage: from xraypulse.common.registry import registry from xraypulse.datasets.base_dataset_builder import BaseDatasetBuilder """ def wrap(builder_cls): from xraypulse.datasets.builders.base_dataset_builder import BaseDatasetBuilder assert issubclass( builder_cls, BaseDatasetBuilder ), "All builders must inherit BaseDatasetBuilder class, found {}".format( builder_cls ) if name in cls.mapping["builder_name_mapping"]: raise KeyError( "Name '{}' already registered for {}.".format( name, cls.mapping["builder_name_mapping"][name] ) ) cls.mapping["builder_name_mapping"][name] = builder_cls return builder_cls return wrap @classmethod def register_task(cls, name): r"""Register a task to registry with key 'name' Args: name: Key with which the task will be registered. Usage: from minigpt4.common.registry import registry """ def wrap(task_cls): from xraypulse.tasks.base_task import BaseTask assert issubclass( task_cls, BaseTask ), "All tasks must inherit BaseTask class" if name in cls.mapping["task_name_mapping"]: raise KeyError( "Name '{}' already registered for {}.".format( name, cls.mapping["task_name_mapping"][name] ) ) cls.mapping["task_name_mapping"][name] = task_cls return task_cls return wrap @classmethod def register_model(cls, name): r"""Register a task to registry with key 'name' Args: name: Key with which the task will be registered. Usage: from xraypulse.common.registry import registry """ def wrap(model_cls): from xraypulse.models import BaseModel assert issubclass( model_cls, BaseModel ), "All models must inherit BaseModel class" if name in cls.mapping["model_name_mapping"]: raise KeyError( "Name '{}' already registered for {}.".format( name, cls.mapping["model_name_mapping"][name] ) ) cls.mapping["model_name_mapping"][name] = model_cls return model_cls return wrap @classmethod def register_processor(cls, name): r"""Register a processor to registry with key 'name' Args: name: Key with which the task will be registered. 
Usage: from xraypulse.common.registry import registry """ def wrap(processor_cls): from xraypulse.processors import BaseProcessor assert issubclass( processor_cls, BaseProcessor ), "All processors must inherit BaseProcessor class" if name in cls.mapping["processor_name_mapping"]: raise KeyError( "Name '{}' already registered for {}.".format( name, cls.mapping["processor_name_mapping"][name] ) ) cls.mapping["processor_name_mapping"][name] = processor_cls return processor_cls return wrap @classmethod def register_lr_scheduler(cls, name): r"""Register a model to registry with key 'name' Args: name: Key with which the task will be registered. Usage: from xraypulse.common.registry import registry """ def wrap(lr_sched_cls): if name in cls.mapping["lr_scheduler_name_mapping"]: raise KeyError( "Name '{}' already registered for {}.".format( name, cls.mapping["lr_scheduler_name_mapping"][name] ) ) cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls return lr_sched_cls return wrap @classmethod def register_runner(cls, name): r"""Register a model to registry with key 'name' Args: name: Key with which the task will be registered. Usage: from xraypulse.common.registry import registry """ def wrap(runner_cls): if name in cls.mapping["runner_name_mapping"]: raise KeyError( "Name '{}' already registered for {}.".format( name, cls.mapping["runner_name_mapping"][name] ) ) cls.mapping["runner_name_mapping"][name] = runner_cls return runner_cls return wrap @classmethod def register_path(cls, name, path): r"""Register a path to registry with key 'name' Args: name: Key with which the path will be registered. Usage: from xraypulse.common.registry import registry """ assert isinstance(path, str), "All path must be str." 
if name in cls.mapping["paths"]: raise KeyError("Name '{}' already registered.".format(name)) cls.mapping["paths"][name] = path @classmethod def register(cls, name, obj): r"""Register an item to registry with key 'name' Args: name: Key with which the item will be registered. Usage:: from minigpt4.common.registry import registry registry.register("config", {}) """ path = name.split(".") current = cls.mapping["state"] for part in path[:-1]: if part not in current: current[part] = {} current = current[part] current[path[-1]] = obj # @classmethod # def get_trainer_class(cls, name): # return cls.mapping["trainer_name_mapping"].get(name, None) @classmethod def get_builder_class(cls, name): return cls.mapping["builder_name_mapping"].get(name, None) @classmethod def get_model_class(cls, name): return cls.mapping["model_name_mapping"].get(name, None) @classmethod def get_task_class(cls, name): return cls.mapping["task_name_mapping"].get(name, None) @classmethod def get_processor_class(cls, name): return cls.mapping["processor_name_mapping"].get(name, None) @classmethod def get_lr_scheduler_class(cls, name): return cls.mapping["lr_scheduler_name_mapping"].get(name, None) @classmethod def get_runner_class(cls, name): return cls.mapping["runner_name_mapping"].get(name, None) @classmethod def list_runners(cls): return sorted(cls.mapping["runner_name_mapping"].keys()) @classmethod def list_models(cls): return sorted(cls.mapping["model_name_mapping"].keys()) @classmethod def list_tasks(cls): return sorted(cls.mapping["task_name_mapping"].keys()) @classmethod def list_processors(cls): return sorted(cls.mapping["processor_name_mapping"].keys()) @classmethod def list_lr_schedulers(cls): return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) @classmethod def list_datasets(cls): return sorted(cls.mapping["builder_name_mapping"].keys()) @classmethod def get_path(cls, name): return cls.mapping["paths"].get(name, None) @classmethod def get(cls, name, default=None, 
no_warning=False): r"""Get an item from registry with key 'name' Args: name (string): Key whose value needs to be retrieved. default: If passed and key is not in registry, default value will be returned with a warning. Default: None no_warning (bool): If passed as True, warning when key doesn't exist will not be generated. Useful for MMF's internal operations. Default: False """ original_name = name name = name.split(".") value = cls.mapping["state"] for subname in name: value = value.get(subname, default) if value is default: break if ( "writer" in cls.mapping["state"] and value == default and no_warning is False ): cls.mapping["state"]["writer"].warning( "Key {} is not present in registry, returning default value " "of {}".format(original_name, default) ) return value @classmethod def unregister(cls, name): r"""Remove an item from registry with key 'name' Args: name: Key which needs to be removed. Usage:: from mmf.common.registry import registry config = registry.unregister("config") """ return cls.mapping["state"].pop(name, None) registry = Registry() ================================================ FILE: xraypulse/common/utils.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import io import json import logging import os import pickle import re import shutil import urllib import urllib.error import urllib.request from typing import Optional from urllib.parse import urlparse import numpy as np import pandas as pd import yaml from iopath.common.download import download from iopath.common.file_io import file_lock, g_pathmgr from xraypulse.common.registry import registry from torch.utils.model_zoo import tqdm from torchvision.datasets.utils import ( check_integrity, download_file_from_google_drive, extract_archive, ) def now(): from datetime import datetime return datetime.now().strftime("%Y%m%d%H%M")[:-1] def is_url(url_or_filename): parsed = urlparse(url_or_filename) return parsed.scheme in ("http", "https") def get_cache_path(rel_path): return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) def get_abs_path(rel_path): return os.path.join(registry.get_path("library_root"), rel_path) def load_json(filename): with open(filename, "r") as f: return json.load(f) # The following are adapted from torchvision and vissl # torchvision: https://github.com/pytorch/vision # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py def makedir(dir_path): """ Create the directory if it does not exist. 
""" is_success = False try: if not g_pathmgr.exists(dir_path): g_pathmgr.mkdirs(dir_path) is_success = True except BaseException: print(f"Error creating directory: {dir_path}") return is_success def get_redirected_url(url: str): """ Given a URL, returns the URL it redirects to or the original URL in case of no indirection """ import requests with requests.Session() as session: with session.get(url, stream=True, allow_redirects=True) as response: if response.history: return response.url else: return url def to_google_drive_download_url(view_url: str) -> str: """ Utility function to transform a view URL of google drive to a download URL for google drive Example input: https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view Example output: https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp """ splits = view_url.split("/") assert splits[-1] == "view" file_id = splits[-2] return f"https://drive.google.com/uc?export=download&id={file_id}" def download_google_drive_url(url: str, output_path: str, output_file_name: str): """ Download a file from google drive Downloading an URL from google drive requires confirmation when the file of the size is too big (google drive notifies that anti-viral checks cannot be performed on such files) """ import requests with requests.Session() as session: # First get the confirmation token and append it to the URL with session.get(url, stream=True, allow_redirects=True) as response: for k, v in response.cookies.items(): if k.startswith("download_warning"): url = url + "&confirm=" + v # Then download the content of the file with session.get(url, stream=True, verify=True) as response: makedir(output_path) path = os.path.join(output_path, output_file_name) total_size = int(response.headers.get("Content-length", 0)) with open(path, "wb") as file: from tqdm import tqdm with tqdm(total=total_size) as progress_bar: for block in response.iter_content( chunk_size=io.DEFAULT_BUFFER_SIZE ): file.write(block) 
progress_bar.update(len(block)) def _get_google_drive_file_id(url: str) -> Optional[str]: parts = urlparse(url) if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: return None match = re.match(r"/file/d/(?P[^/]*)", parts.path) if match is None: return None return match.group("id") def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: with open(filename, "wb") as fh: with urllib.request.urlopen( urllib.request.Request(url, headers={"User-Agent": "vissl"}) ) as response: with tqdm(total=response.length) as pbar: for chunk in iter(lambda: response.read(chunk_size), ""): if not chunk: break pbar.update(chunk_size) fh.write(chunk) def download_url( url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, ) -> None: """Download a file from a url and place it in root. Args: url (str): URL to download file from root (str): Directory to place downloaded file in filename (str, optional): Name to save the file under. If None, use the basename of the URL. md5 (str, optional): MD5 checksum of the download. If None, do not check """ root = os.path.expanduser(root) if not filename: filename = os.path.basename(url) fpath = os.path.join(root, filename) makedir(root) # check if file is already present locally if check_integrity(fpath, md5): print("Using downloaded and verified file: " + fpath) return # expand redirect chain if needed url = get_redirected_url(url) # check if file is located on Google Drive file_id = _get_google_drive_file_id(url) if file_id is not None: return download_file_from_google_drive(file_id, root, filename, md5) # download the file try: print("Downloading " + url + " to " + fpath) _urlretrieve(url, fpath) except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] if url[:5] == "https": url = url.replace("https:", "http:") print( "Failed download. Trying https -> http instead." 
" Downloading " + url + " to " + fpath ) _urlretrieve(url, fpath) else: raise e # check integrity of downloaded file if not check_integrity(fpath, md5): raise RuntimeError("File not found or corrupted.") def download_and_extract_archive( url: str, download_root: str, extract_root: Optional[str] = None, filename: Optional[str] = None, md5: Optional[str] = None, remove_finished: bool = False, ) -> None: download_root = os.path.expanduser(download_root) if extract_root is None: extract_root = download_root if not filename: filename = os.path.basename(url) download_url(url, download_root, filename, md5) archive = os.path.join(download_root, filename) print("Extracting {} to {}".format(archive, extract_root)) extract_archive(archive, extract_root, remove_finished) def cache_url(url: str, cache_dir: str) -> str: """ This implementation downloads the remote resource and caches it locally. The resource will only be downloaded if not previously requested. """ parsed_url = urlparse(url) dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) makedir(dirname) filename = url.split("/")[-1] cached = os.path.join(dirname, filename) with file_lock(cached): if not os.path.isfile(cached): logging.info(f"Downloading {url} to {cached} ...") cached = download(url, dirname, filename=filename) logging.info(f"URL {url} cached in {cached}") return cached # TODO (prigoyal): convert this into RAII-style API def create_file_symlink(file1, file2): """ Simply create the symlinks for a given file1 to file2. Useful during model checkpointing to symlinks to the latest successful checkpoint. """ try: if g_pathmgr.exists(file2): g_pathmgr.rm(file2) g_pathmgr.symlink(file1, file2) except Exception as e: logging.info(f"Could NOT create symlink. Error: {e}") def save_file(data, filename, append_to_json=True, verbose=True): """ Common i/o utility to handle saving data to various file formats. 
Supported: .pkl, .pickle, .npy, .json Specifically for .json, users have the option to either append (default) or rewrite by passing in Boolean value to append_to_json. """ if verbose: logging.info(f"Saving data to file: {filename}") file_ext = os.path.splitext(filename)[1] if file_ext in [".pkl", ".pickle"]: with g_pathmgr.open(filename, "wb") as fopen: pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) elif file_ext == ".npy": with g_pathmgr.open(filename, "wb") as fopen: np.save(fopen, data) elif file_ext == ".json": if append_to_json: with g_pathmgr.open(filename, "a") as fopen: fopen.write(json.dumps(data, sort_keys=True) + "\n") fopen.flush() else: with g_pathmgr.open(filename, "w") as fopen: fopen.write(json.dumps(data, sort_keys=True) + "\n") fopen.flush() elif file_ext == ".yaml": with g_pathmgr.open(filename, "w") as fopen: dump = yaml.dump(data) fopen.write(dump) fopen.flush() else: raise Exception(f"Saving {file_ext} is not supported yet") if verbose: logging.info(f"Saved data to file: {filename}") def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): """ Common i/o utility to handle loading data from various file formats. Supported: .pkl, .pickle, .npy, .json For the npy files, we support reading the files in mmap_mode. If the mmap_mode of reading is not successful, we load data without the mmap_mode. """ if verbose: logging.info(f"Loading data from file: {filename}") file_ext = os.path.splitext(filename)[1] if file_ext == ".txt": with g_pathmgr.open(filename, "r") as fopen: data = fopen.readlines() elif file_ext in [".pkl", ".pickle"]: with g_pathmgr.open(filename, "rb") as fopen: data = pickle.load(fopen, encoding="latin1") elif file_ext == ".npy": if mmap_mode: try: with g_pathmgr.open(filename, "rb") as fopen: data = np.load( fopen, allow_pickle=allow_pickle, encoding="latin1", mmap_mode=mmap_mode, ) except ValueError as e: logging.info( f"Could not mmap {filename}: {e}. 
Trying without g_pathmgr" ) data = np.load( filename, allow_pickle=allow_pickle, encoding="latin1", mmap_mode=mmap_mode, ) logging.info("Successfully loaded without g_pathmgr") except Exception: logging.info("Could not mmap without g_pathmgr. Trying without mmap") with g_pathmgr.open(filename, "rb") as fopen: data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") else: with g_pathmgr.open(filename, "rb") as fopen: data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") elif file_ext == ".json": with g_pathmgr.open(filename, "r") as fopen: data = json.load(fopen) elif file_ext == ".yaml": with g_pathmgr.open(filename, "r") as fopen: data = yaml.load(fopen, Loader=yaml.FullLoader) elif file_ext == ".csv": with g_pathmgr.open(filename, "r") as fopen: data = pd.read_csv(fopen) else: raise Exception(f"Reading from {file_ext} is not supported yet") return data def abspath(resource_path: str): """ Make a path absolute, but take into account prefixes like "http://" or "manifold://" """ regex = re.compile(r"^\w+://") if regex.match(resource_path) is None: return os.path.abspath(resource_path) else: return resource_path def makedir(dir_path): """ Create the directory if it does not exist. """ is_success = False try: if not g_pathmgr.exists(dir_path): g_pathmgr.mkdirs(dir_path) is_success = True except BaseException: logging.info(f"Error creating directory: {dir_path}") return is_success def is_url(input_url): """ Check if an input string is a url. look for http(s):// and ignoring the case """ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None return is_url def cleanup_dir(dir): """ Utility for deleting a directory. Useful for cleaning the storage space that contains various training artifacts like checkpoints, data etc. 
""" if os.path.exists(dir): logging.info(f"Deleting directory: {dir}") shutil.rmtree(dir) logging.info(f"Deleted contents of directory: {dir}") def get_file_size(filename): """ Given a file, get the size of file in MB """ size_in_mb = os.path.getsize(filename) / float(1024**2) return size_in_mb ================================================ FILE: xraypulse/configs/datasets/mimic/defaults.yaml ================================================ datasets: mimic: data_type: images build_info: storage: /mnt/petrelfs/share_data/huangzhongzhen/multimodal_pretrain/dataset/mimic ================================================ FILE: xraypulse/configs/datasets/openi/defaults.yaml ================================================ datasets: openi: data_type: images build_info: storage: /mnt/petrelfs/share_data/huangzhongzhen/multimodal_pretrain/dataset/openi ================================================ FILE: xraypulse/configs/default.yaml ================================================ env: # For default users # cache_root: "cache" # For internal use with persistent storage cache_root: "/export/home/.cache/xraypulse" ================================================ FILE: xraypulse/configs/models/xraypulse.yaml ================================================ model: arch: xray_pulse # vit encoder image_size: 224 drop_path_rate: 0 use_grad_checkpoint: False vit_precision: "fp16" freeze_vit: True freeze_qformer: True # Q-Former num_query_token: 32 # Vicuna bloom_model: "OpenMEDLab/PULSE-7bv5" # generation configs prompt: "" preprocess: vis_processor: train: name: "blip2_image_train" image_size: 224 eval: name: "blip2_image_eval" image_size: 224 text_processor: train: name: "blip_caption" eval: name: "blip_caption" ================================================ FILE: xraypulse/conversation/__init__.py ================================================ ================================================ FILE: xraypulse/conversation/conversation.py 
================================================ import argparse import time from PIL import Image import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BloomTokenizerFast from transformers import StoppingCriteria, StoppingCriteriaList import dataclasses from enum import auto, Enum from typing import List, Tuple, Any from xraypulse.common.registry import registry class SeparatorStyle(Enum): """Different separator style.""" SINGLE = auto() TWO = auto() @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" system: str roles: List[str] messages: List[List[str]] offset: int # system_img: List[Image.Image] = [] sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "###" sep2: str = None skip_next: bool = False conv_id: Any = None def get_prompt(self): if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep for role, message in self.messages: if message: ret += role + ": " + message + self.sep else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] for i, (role, message) in enumerate(self.messages): if message: ret += role + ": " + message + seps[i % 2] else: ret += role + ":" return ret else: raise ValueError(f"Invalid style: {self.sep_style}") def append_message(self, role, message): self.messages.append([role, message]) def to_gradio_chatbot(self): ret = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: ret.append([msg, None]) else: ret[-1][-1] = msg print('ret:') print(ret) return ret def copy(self): return Conversation( system=self.system, # system_img=self.system_img, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, conv_id=self.conv_id) def dict(self): return { "system": self.system, # "system_img": self.system_img, "roles": self.roles, "messages": self.messages, "offset": self.offset, 
"sep": self.sep, "sep2": self.sep2, "conv_id": self.conv_id, } class StoppingCriteriaSub(StoppingCriteria): def __init__(self, stops=[], encounters=1): super().__init__() self.stops = stops def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): for stop in self.stops: if torch.all((stop == input_ids[0][-len(stop):])).item(): return True return False CONV_ZH = Conversation( system="Instructions: You are PULSE, a large language model trained by SHAIlab. Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2022-02-01 User: {} Helper: ", # "Please answer the medical questions based on the patient's description. Give the following medical scan: 图片." # "You will be able to see the medical scan once I provide it to you. Please answer the patients questions.", roles=("User", "Helper"), messages=[], offset=0, sep_style=SeparatorStyle.SINGLE, sep="", sep2="###", ) class Chat: def __init__(self, model, vis_processor, device='cuda:0'): self.device = device self.model = model self.vis_processor = vis_processor stop_words_ids = [torch.tensor([835]).to(self.device), torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) def ask(self, text, conv): if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ and conv.messages[-1][1][-6:] == '': # last message is image. conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) else: conv.append_message(conv.roles[0], text) def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): conv.append_message(conv.roles[1], None) embs = self.get_context_emb(conv, img_list) current_max_len = embs.shape[1] + max_new_tokens if current_max_len - max_length > 0: print('Warning: The number of tokens in current conversation exceeds the max length. 
' 'The model will not see the contexts outside the range.') begin_idx = max(0, current_max_len - max_length) embs = embs[:, begin_idx:] outputs = self.model.bloom_model.generate( inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria=self.stopping_criteria, num_beams=num_beams, do_sample=True, min_length=min_length, top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature, ) output_token = outputs[0] if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it output_token = output_token[1:] if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it output_token = output_token[1:] output_text = self.model.bloom_tokenizer.decode(output_token, add_special_tokens=False) output_text = output_text.split('')[0] # remove the stop sign '###' output_text = output_text.split('###')[0] conv.messages[-1][1] = output_text return output_text, output_token.cpu().numpy() def upload_img(self, image, conv, img_list): if isinstance(image, str): # is a image path raw_image = Image.open(image).convert('RGB') image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) elif isinstance(image, Image.Image): raw_image = image image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) elif isinstance(image, torch.Tensor): if len(image.shape) == 3: image = image.unsqueeze(0) image = image.to(self.device) image_emb, _ = self.model.encode_img(image) img_list.append(image_emb) conv.append_message(conv.roles[0], "<图片>") msg = "Received." # self.conv.append_message(self.conv.roles[1], msg) return msg def get_context_emb(self, conv, img_list): prompt = conv.get_prompt() prompt_segs = prompt.split('<图片>') assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." 
seg_tokens = [ self.model.bloom_tokenizer( seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids # only add bos to the first seg for i, seg in enumerate(prompt_segs) ] seg_embs = [self.model.bloom_model.transformer.word_embeddings_layernorm(self.model.bloom_model.transformer.word_embeddings(seg_t)) for seg_t in seg_tokens] mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] mixed_embs = torch.cat(mixed_embs, dim=1) return mixed_embs ================================================ FILE: xraypulse/datasets/__init__.py ================================================ ================================================ FILE: xraypulse/datasets/builders/__init__.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from xraypulse.datasets.builders.base_dataset_builder import load_dataset_config from xraypulse.datasets.builders.image_text_pair_builder import ( MIMICBuilder, OpenIBuilder, ) from xraypulse.common.registry import registry __all__ = [ "MIMICBuilder", "OpenIBuilder", ] def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): """ Example >>> dataset = load_dataset("coco_caption", cfg=None) >>> splits = dataset.keys() >>> print([len(dataset[split]) for split in splits]) """ if cfg_path is None: cfg = None else: cfg = load_dataset_config(cfg_path) try: builder = registry.get_builder_class(name)(cfg) except TypeError: print( f"Dataset {name} not found. Available datasets:\n" + ", ".join([str(k) for k in dataset_zoo.get_names()]) ) exit(1) if vis_path is not None: if data_type is None: # use default data type in the config data_type = builder.config.data_type assert ( data_type in builder.config.build_info ), f"Invalid data_type {data_type} for {name}." 
builder.config.build_info.get(data_type).storage = vis_path dataset = builder.build_datasets() return dataset class DatasetZoo: def __init__(self) -> None: self.dataset_zoo = { k: list(v.DATASET_CONFIG_DICT.keys()) for k, v in sorted(registry.mapping["builder_name_mapping"].items()) } def get_names(self): return list(self.dataset_zoo.keys()) dataset_zoo = DatasetZoo() ================================================ FILE: xraypulse/datasets/builders/base_dataset_builder.py ================================================ """ This file is from Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import os import shutil import warnings from omegaconf import OmegaConf import torch.distributed as dist from torchvision.datasets.utils import download_url import xraypulse.common.utils as utils from xraypulse.common.dist_utils import is_dist_avail_and_initialized, is_main_process from xraypulse.common.registry import registry from xraypulse.processors.base_processor import BaseProcessor class BaseDatasetBuilder: train_dataset_cls, eval_dataset_cls = None, None def __init__(self, cfg=None): super().__init__() if cfg is None: # help to create datasets from default config. self.config = load_dataset_config(self.default_config_path()) elif isinstance(cfg, str): self.config = load_dataset_config(cfg) else: # when called from task.build_dataset() self.config = cfg self.data_type = self.config.data_type self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} def build_datasets(self): # download, split, etc... 
# only called on 1 GPU/TPU in distributed if is_main_process(): self._download_data() if is_dist_avail_and_initialized(): dist.barrier() # at this point, all the annotations and image/videos should be all downloaded to the specified locations. logging.info("Building datasets...") datasets = self.build() # dataset['train'/'val'/'test'] return datasets def build_processors(self): vis_proc_cfg = self.config.get("vis_processor") txt_proc_cfg = self.config.get("text_processor") if vis_proc_cfg is not None: vis_train_cfg = vis_proc_cfg.get("train") vis_eval_cfg = vis_proc_cfg.get("eval") self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) if txt_proc_cfg is not None: txt_train_cfg = txt_proc_cfg.get("train") txt_eval_cfg = txt_proc_cfg.get("eval") self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) @staticmethod def _build_proc_from_cfg(cfg): return ( registry.get_processor_class(cfg.name).from_config(cfg) if cfg is not None else None ) @classmethod def default_config_path(cls, type="default"): return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) def _download_data(self): self._download_ann() self._download_vis() def _download_ann(self): """ Download annotation files if necessary. All the vision-language datasets should have annotations of unified format. storage_path can be: (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. Local annotation paths should be relative. 
""" anns = self.config.build_info.annotations splits = anns.keys() cache_root = registry.get_path("cache_root") for split in splits: info = anns[split] urls, storage_paths = info.get("url", None), info.storage if isinstance(urls, str): urls = [urls] if isinstance(storage_paths, str): storage_paths = [storage_paths] assert len(urls) == len(storage_paths) for url_or_filename, storage_path in zip(urls, storage_paths): # if storage_path is relative, make it full by prefixing with cache_root. if not os.path.isabs(storage_path): storage_path = os.path.join(cache_root, storage_path) dirname = os.path.dirname(storage_path) if not os.path.exists(dirname): os.makedirs(dirname) if os.path.isfile(url_or_filename): src, dst = url_or_filename, storage_path if not os.path.exists(dst): shutil.copyfile(src=src, dst=dst) else: logging.info("Using existing file {}.".format(dst)) else: if os.path.isdir(storage_path): # if only dirname is provided, suffix with basename of URL. raise ValueError( "Expecting storage_path to be a file path, got directory {}".format( storage_path ) ) else: filename = os.path.basename(storage_path) download_url(url=url_or_filename, root=dirname, filename=filename) def _download_vis(self): storage_path = self.config.build_info.get(self.data_type).storage storage_path = utils.get_cache_path(storage_path) if not os.path.exists(storage_path): warnings.warn( f""" The specified path {storage_path} for visual inputs does not exist. Please provide a correct path to the visual inputs or refer to datasets/download_scripts/README.md for downloading instructions. """ ) def build(self): """ Create by split datasets inheriting torch.utils.data.Datasets. # build() can be dataset-specific. Overwrite to customize. 
""" self.build_processors() build_info = self.config.build_info ann_info = build_info.annotations vis_info = build_info.get(self.data_type) datasets = dict() for split in ann_info.keys(): if split not in ["train", "val", "test"]: continue is_train = split == "train" # processors vis_processor = ( self.vis_processors["train"] if is_train else self.vis_processors["eval"] ) text_processor = ( self.text_processors["train"] if is_train else self.text_processors["eval"] ) # annotation path ann_paths = ann_info.get(split).storage if isinstance(ann_paths, str): ann_paths = [ann_paths] abs_ann_paths = [] for ann_path in ann_paths: if not os.path.isabs(ann_path): ann_path = utils.get_cache_path(ann_path) abs_ann_paths.append(ann_path) ann_paths = abs_ann_paths # visual data storage path vis_path = os.path.join(vis_info.storage, split) if not os.path.isabs(vis_path): # vis_path = os.path.join(utils.get_cache_path(), vis_path) vis_path = utils.get_cache_path(vis_path) if not os.path.exists(vis_path): warnings.warn("storage path {} does not exist.".format(vis_path)) # create datasets dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls datasets[split] = dataset_cls( vis_processor=vis_processor, text_processor=text_processor, ann_paths=ann_paths, vis_root=vis_path, ) return datasets def load_dataset_config(cfg_path): cfg = OmegaConf.load(cfg_path).datasets cfg = cfg[list(cfg.keys())[0]] return cfg ================================================ FILE: xraypulse/datasets/builders/image_text_pair_builder.py ================================================ import os import logging import warnings from xraypulse.common.registry import registry from xraypulse.datasets.builders.base_dataset_builder import BaseDatasetBuilder from xraypulse.datasets.datasets.openi_dataset import OpenIDataset from xraypulse.datasets.datasets.mimic_dataset import MIMICDataset @registry.register_builder("mimic") class MIMICBuilder(BaseDatasetBuilder): train_dataset_cls = MIMICDataset 
DATASET_CONFIG_DICT = {"default": "configs/datasets/mimic/defaults.yaml"} def _download_ann(self): pass def _download_vis(self): pass def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. logging.info("Building datasets...") self.build_processors() build_info = self.config.build_info storage_path = build_info.storage datasets = dict() if not os.path.exists(storage_path): warnings.warn("storage path {} does not exist.".format(storage_path)) # create datasets dataset_cls = self.train_dataset_cls datasets['train'] = dataset_cls( vis_processor=self.vis_processors["train"], text_processor=self.text_processors["train"], ann_paths=[os.path.join(storage_path, 'zh_filter_cap.json')], vis_root=os.path.join(storage_path, 'image'), ) return datasets @registry.register_builder("openi") class OpenIBuilder(BaseDatasetBuilder): train_dataset_cls = OpenIDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/openi/defaults.yaml"} def _download_ann(self): pass def _download_vis(self): pass def build(self): self.build_processors() build_info = self.config.build_info storage_path = build_info.storage datasets = dict() split = "train" # create datasets # [NOTE] return inner_datasets (wds.DataPipeline) dataset_cls = self.train_dataset_cls datasets[split] = dataset_cls( vis_processor=self.vis_processors["train"], text_processor=self.text_processors["train"], ann_paths=[os.path.join(storage_path, 'zh_filter_cap.json')], vis_root=os.path.join(storage_path, 'image'), ) return datasets ================================================ FILE: xraypulse/datasets/data_utils.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import gzip import logging import os import random as rnd import tarfile import zipfile import random from typing import List from tqdm import tqdm import decord from decord import VideoReader import webdataset as wds import numpy as np import torch from torch.utils.data.dataset import IterableDataset from xraypulse.common.registry import registry from xraypulse.datasets.datasets.base_dataset import ConcatDataset decord.bridge.set_bridge("torch") MAX_INT = registry.get("MAX_INT") class ChainDataset(wds.DataPipeline): r"""Dataset for chaining multiple :class:`DataPipeline` s. This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient. Args: datasets (iterable of IterableDataset): datasets to be chained together """ def __init__(self, datasets: List[wds.DataPipeline]) -> None: super().__init__() self.datasets = datasets self.prob = [] self.names = [] for dataset in self.datasets: if hasattr(dataset, 'name'): self.names.append(dataset.name) else: self.names.append('Unknown') if hasattr(dataset, 'sample_ratio'): self.prob.append(dataset.sample_ratio) else: self.prob.append(1) logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.") def __iter__(self): datastreams = [iter(dataset) for dataset in self.datasets] while True: select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0] yield next(select_datastream) def apply_to_sample(f, sample): if len(sample) == 0: return {} def _apply(x): if torch.is_tensor(x): return f(x) elif isinstance(x, dict): return {key: _apply(value) for key, value in x.items()} elif isinstance(x, list): return [_apply(x) for x in x] else: return x return _apply(sample) def move_to_cuda(sample): def _move_to_cuda(tensor): 
return tensor.cuda() return apply_to_sample(_move_to_cuda, sample) def prepare_sample(samples, cuda_enabled=True): if cuda_enabled: samples = move_to_cuda(samples) # TODO fp16 support return samples def reorg_datasets_by_split(datasets): """ Organizes datasets by split. Args: datasets: dict of torch.utils.data.Dataset objects by name. Returns: Dict of datasets by split {split_name: List[Datasets]}. """ # if len(datasets) == 1: # return datasets[list(datasets.keys())[0]] # else: reorg_datasets = dict() # reorganize by split for _, dataset in datasets.items(): for split_name, dataset_split in dataset.items(): if split_name not in reorg_datasets: reorg_datasets[split_name] = [dataset_split] else: reorg_datasets[split_name].append(dataset_split) return reorg_datasets def concat_datasets(datasets): """ Concatenates multiple datasets into a single dataset. It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support generic IterableDataset because it requires creating separate samplers. Now only supports conctenating training datasets and assuming validation and testing have only a single dataset. This is because metrics should not be computed on the concatenated datasets. Args: datasets: dict of torch.utils.data.Dataset objects by split. Returns: Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, "val" and "test" remain the same. If the input training datasets contain both map-style and DataPipeline datasets, returns a tuple, where the first element is a concatenated map-style dataset and the second element is a chained DataPipeline dataset. 
""" # concatenate datasets in the same split for split_name in datasets: if split_name != "train": assert ( len(datasets[split_name]) == 1 ), "Do not support multiple {} datasets.".format(split_name) datasets[split_name] = datasets[split_name][0] else: iterable_datasets, map_datasets = [], [] for dataset in datasets[split_name]: if isinstance(dataset, wds.DataPipeline): logging.info( "Dataset {} is IterableDataset, can't be concatenated.".format( dataset ) ) iterable_datasets.append(dataset) elif isinstance(dataset, IterableDataset): raise NotImplementedError( "Do not support concatenation of generic IterableDataset." ) else: map_datasets.append(dataset) # if len(iterable_datasets) > 0: # concatenate map-style datasets and iterable-style datasets separately if len(iterable_datasets) > 1: chained_datasets = ( ChainDataset(iterable_datasets) ) elif len(iterable_datasets) == 1: chained_datasets = iterable_datasets[0] else: chained_datasets = None concat_datasets = ( ConcatDataset(map_datasets) if len(map_datasets) > 0 else None ) train_datasets = concat_datasets, chained_datasets train_datasets = tuple([x for x in train_datasets if x is not None]) train_datasets = ( train_datasets[0] if len(train_datasets) == 1 else train_datasets ) datasets[split_name] = train_datasets return datasets ================================================ FILE: xraypulse/datasets/datasets/__init__.py ================================================ ================================================ FILE: xraypulse/datasets/datasets/base_dataset.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import json from typing import Iterable from torch.utils.data import Dataset, ConcatDataset from torch.utils.data.dataloader import default_collate class BaseDataset(Dataset): def __init__( self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] ): """ vis_root (string): Root directory of images (e.g. coco/images/) ann_root (string): directory to store the annotation file """ self.vis_root = vis_root self.annotation = [] for ann_path in ann_paths: self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) self.vis_processor = vis_processor self.text_processor = text_processor self._add_instance_ids() def __len__(self): return len(self.annotation) def collater(self, samples): return default_collate(samples) def set_processors(self, vis_processor, text_processor): self.vis_processor = vis_processor self.text_processor = text_processor def _add_instance_ids(self, key="instance_id"): for idx, ann in enumerate(self.annotation): ann[key] = str(idx) class ConcatDataset(ConcatDataset): def __init__(self, datasets: Iterable[Dataset]) -> None: super().__init__(datasets) def collater(self, samples): # TODO For now only supports datasets with same underlying collater implementations all_keys = set() for s in samples: all_keys.update(s) shared_keys = all_keys for s in samples: shared_keys = shared_keys & set(s.keys()) samples_shared_keys = [] for s in samples: samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) return self.datasets[0].collater(samples_shared_keys) ================================================ FILE: xraypulse/datasets/datasets/caption_datasets.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import os from collections import OrderedDict from xraypulse.datasets.datasets.base_dataset import BaseDataset from PIL import Image class __DisplMixin: def displ_item(self, index): sample, ann = self.__getitem__(index), self.annotation[index] return OrderedDict( { "file": ann["image"], "caption": ann["caption"], "image": sample["image"], } ) class CaptionDataset(BaseDataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): """ vis_root (string): Root directory of images (e.g. coco/images/) ann_root (string): directory to store the annotation file """ super().__init__(vis_processor, text_processor, vis_root, ann_paths) self.img_ids = {} n = 0 for ann in self.annotation: img_id = ann["image_id"] if img_id not in self.img_ids.keys(): self.img_ids[img_id] = n n += 1 def __getitem__(self, index): # TODO this assumes image input, not general enough ann = self.annotation[index] img_file = '{:0>12}.png'.format(ann["image_id"]) image_path = os.path.join(self.vis_root, img_file) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) caption = self.text_processor(ann["caption"]) return { "image": image, "text_input": caption, "image_id": self.img_ids[ann["image_id"]], } class CaptionEvalDataset(BaseDataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): """ vis_root (string): Root directory of images (e.g. 
coco/images/) ann_root (string): directory to store the annotation file split (string): val or test """ super().__init__(vis_processor, text_processor, vis_root, ann_paths) #below lines are added during test rogue score self.img_ids = {} n = 0 for ann in self.annotation: img_id = ann["image_id"] if img_id not in self.img_ids.keys(): self.img_ids[img_id] = n n += 1 def __getitem__(self, index): ann = self.annotation[index] image_path = os.path.join(self.vis_root, ann["image"]) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) return { "image": image, "image_id": ann["image_id"], "instance_id": ann["instance_id"], } ================================================ FILE: xraypulse/datasets/datasets/dataloader_utils.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import time import random import torch from xraypulse.datasets.data_utils import move_to_cuda from torch.utils.data import DataLoader class MultiIterLoader: """ A simple wrapper for iterating over multiple iterators. Args: loaders (List[Loader]): List of Iterator loaders. ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. 
""" def __init__(self, loaders, ratios=None): # assert all loaders has __next__ method for loader in loaders: assert hasattr( loader, "__next__" ), "Loader {} has no __next__ method.".format(loader) if ratios is None: ratios = [1.0] * len(loaders) else: assert len(ratios) == len(loaders) ratios = [float(ratio) / sum(ratios) for ratio in ratios] self.loaders = loaders self.ratios = ratios def __next__(self): # random sample from each loader by ratio loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] return next(self.loaders[loader_idx]) class PrefetchLoader(object): """ Modified from https://github.com/ChenRocks/UNITER. overlap compute and cuda data transfer (copied and then modified from nvidia apex) """ def __init__(self, loader): self.loader = loader self.stream = torch.cuda.Stream() def __iter__(self): loader_it = iter(self.loader) self.preload(loader_it) batch = self.next(loader_it) while batch is not None: is_tuple = isinstance(batch, tuple) if is_tuple: task, batch = batch if is_tuple: yield task, batch else: yield batch batch = self.next(loader_it) def __len__(self): return len(self.loader) def preload(self, it): try: self.batch = next(it) except StopIteration: self.batch = None return # if record_stream() doesn't work, another option is to make sure # device inputs are created on the main stream. # self.next_input_gpu = torch.empty_like(self.next_input, # device='cuda') # self.next_target_gpu = torch.empty_like(self.next_target, # device='cuda') # Need to make sure the memory allocated for next_* is not still in use # by the main stream at the time we start copying to next_*: # self.stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.stream): self.batch = move_to_cuda(self.batch) # more code for the alternative if record_stream() doesn't work: # copy_ will record the use of the pinned source tensor in this # side stream. 
# self.next_input_gpu.copy_(self.next_input, non_blocking=True) # self.next_target_gpu.copy_(self.next_target, non_blocking=True) # self.next_input = self.next_input_gpu # self.next_target = self.next_target_gpu def next(self, it): torch.cuda.current_stream().wait_stream(self.stream) batch = self.batch if batch is not None: record_cuda_stream(batch) self.preload(it) return batch def __getattr__(self, name): method = self.loader.__getattribute__(name) return method def record_cuda_stream(batch): if isinstance(batch, torch.Tensor): batch.record_stream(torch.cuda.current_stream()) elif isinstance(batch, list) or isinstance(batch, tuple): for t in batch: record_cuda_stream(t) elif isinstance(batch, dict): for t in batch.values(): record_cuda_stream(t) else: pass class IterLoader: """ A wrapper to convert DataLoader as an infinite iterator. Modified from: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py """ def __init__(self, dataloader: DataLoader, use_distributed: bool = False): self._dataloader = dataloader self.iter_loader = iter(self._dataloader) self._use_distributed = use_distributed self._epoch = 0 @property def epoch(self) -> int: return self._epoch def __next__(self): try: data = next(self.iter_loader) except StopIteration: self._epoch += 1 if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: self._dataloader.sampler.set_epoch(self._epoch) time.sleep(2) # Prevent possible deadlock during epoch transition self.iter_loader = iter(self._dataloader) data = next(self.iter_loader) return data def __iter__(self): return self def __len__(self): return len(self._dataloader) ================================================ FILE: xraypulse/datasets/datasets/mimic_dataset.py ================================================ import os from PIL import Image import webdataset as wds from xraypulse.datasets.datasets.base_dataset import BaseDataset from xraypulse.datasets.datasets.caption_datasets import CaptionDataset 
class MIMICDataset(CaptionDataset): def __getitem__(self, index): # TODO this assumes image input, not general enough ann = self.annotation[index] img_file = '{}.jpg'.format(ann["image_id"]) image_path = os.path.join(self.vis_root, img_file) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) caption = ann['caption'] return { "image": image, "caption":caption, "image_id": self.img_ids[ann["image_id"]], } ================================================ FILE: xraypulse/datasets/datasets/openi_dataset.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import os from PIL import Image import webdataset as wds from xraypulse.datasets.datasets.base_dataset import BaseDataset from xraypulse.datasets.datasets.caption_datasets import CaptionDataset class OpenIDataset(CaptionDataset): def __getitem__(self, index): # TODO this assumes image input, not general enough ann = self.annotation[index] img_file = '{}.png'.format(ann["image_id"]) image_path = os.path.join(self.vis_root, img_file) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) caption = ann['caption'] return { "image": image, "caption":caption, "image_id": self.img_ids[ann["image_id"]], } ================================================ FILE: xraypulse/models/Qformer.py ================================================ """ * Copyright (c) 2023, salesforce.com, inc. * All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause * By Junnan Li * Based on huggingface code base * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert """ import math import os import warnings from dataclasses import dataclass from typing import Optional, Tuple, Dict, Any import torch from torch import Tensor, device, dtype, nn import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss import torch.nn.functional as F from transformers.activations import ACT2FN from transformers.file_utils import ( ModelOutput, ) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, NextSentencePredictorOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) from transformers.utils import logging from transformers.models.bert.configuration_bert import BertConfig logger = logging.get_logger(__name__) class BertEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id ) self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size ) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized 
self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) ) self.position_embedding_type = getattr( config, "position_embedding_type", "absolute" ) self.config = config def forward( self, input_ids=None, position_ids=None, query_embeds=None, past_key_values_length=0, ): if input_ids is not None: seq_length = input_ids.size()[1] else: seq_length = 0 if position_ids is None: position_ids = self.position_ids[ :, past_key_values_length : seq_length + past_key_values_length ].clone() if input_ids is not None: embeddings = self.word_embeddings(input_ids) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings if query_embeds is not None: embeddings = torch.cat((query_embeds, embeddings), dim=1) else: embeddings = query_embeds embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config, is_cross_attention): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0 and not hasattr( config, "embedding_size" ): raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: self.key = nn.Linear(config.encoder_width, self.all_head_size) self.value = nn.Linear(config.encoder_width, self.all_head_size) else: self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr( 
config, "position_embedding_type", "absolute" ) if ( self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query" ): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding( 2 * config.max_position_embeddings - 1, self.attention_head_size ) self.save_attention = False def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None if is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if ( self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query" ): seq_length = hidden_states.size()[1] position_ids_l = torch.arange( seq_length, dtype=torch.long, device=hidden_states.device ).view(-1, 1) position_ids_r = torch.arange( seq_length, dtype=torch.long, device=hidden_states.device ).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding( distance + self.max_position_embeddings - 1 ) positional_embedding = positional_embedding.to( dtype=query_layer.dtype ) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding ) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding ) relative_position_scores_key = torch.einsum( 
"bhrd,lrd->bhlr", key_layer, positional_embedding ) attention_scores = ( attention_scores + relative_position_scores_query + relative_position_scores_key ) attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if is_cross_attention and self.save_attention: self.save_attention_map(attention_probs) attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) outputs = ( (context_layer, attention_probs) if output_attentions else (context_layer,) ) outputs = outputs + (past_key_value,) return outputs class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.self = 
BertSelfAttention(config, is_cross_attention) self.output = BertSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads, ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = ( self.self.attention_head_size * self.self.num_attention_heads ) self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[ 1: ] # add attentions if we output them return outputs class BertIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, 
eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config) self.layer_num = layer_num if ( self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0 ): self.crossattention = BertAttention( config, is_cross_attention=self.config.add_cross_attention ) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate = BertIntermediate(config) self.output = BertOutput(config) self.intermediate_query = BertIntermediate(config) self.output_query = BertOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, query_length=0, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = ( past_key_value[:2] if past_key_value is not None else None ) self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] if query_length > 0: query_attention_output = attention_output[:, :query_length, :] if self.has_cross_attention: assert ( encoder_hidden_states is not None ), "encoder_hidden_states must be given for cross-attention layers" cross_attention_outputs = self.crossattention( query_attention_output, attention_mask, head_mask, encoder_hidden_states, 
encoder_attention_mask, output_attentions=output_attentions, ) query_attention_output = cross_attention_outputs[0] outputs = ( outputs + cross_attention_outputs[1:-1] ) # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_attention_output, ) if attention_output.shape[1] > query_length: layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[:, query_length:, :], ) layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) outputs = (layer_output,) + outputs outputs = outputs + (present_key_value,) return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output def feed_forward_chunk_query(self, attention_output): intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList( [BertLayer(config, i) for i in range(config.num_hidden_layers)] ) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, query_length=0, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = ( () if output_attentions and self.config.add_cross_attention else None ) next_decoder_cache = () if use_cache else None for i in 
range(self.config.num_hidden_layers): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warn( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module( *inputs, past_key_value, output_attentions, query_length ) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, query_length, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden 
state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = BertConfig base_model_prefix = "bert" _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class BertModel(BertPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass. """ def __init__(self, config, add_pooling_layer=False): super().__init__(config) self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool, has_query: bool = False, ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (:obj:`Tuple[int]`): The shape of the input to the model. device: (:obj:`torch.device`): The device of the input to the model. Returns: :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if is_decoder: batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = ( seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] ) # add a prefix ones mask to the causal mask # causal and attention masks must have same type with pytorch version < 1.3 causal_mask = causal_mask.to(attention_mask.dtype) if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] if has_query: # UniLM style attention mask causal_mask = torch.cat( [ torch.zeros( (batch_size, prefix_seq_len, seq_length), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=1, ) causal_mask = torch.cat( [ torch.ones( (batch_size, causal_mask.shape[1], prefix_seq_len), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=-1, ) extended_attention_mask = ( causal_mask[:, None, :, :] * attention_mask[:, None, None, :] ) else: extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( input_shape, attention_mask.shape ) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. 
extended_attention_mask = extended_attention_mask.to( dtype=self.dtype ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). 
""" output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is None: assert ( query_embeds is not None ), "You have to specify query_embeds when input_ids is None" # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 ) query_length = query_embeds.shape[1] if query_embeds is not None else 0 embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, query_embeds=query_embeds, past_key_values_length=past_key_values_length, ) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device if attention_mask is None: attention_mask = torch.ones( ((batch_size, seq_length + past_key_values_length)), device=device ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
if is_decoder: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_ids.shape, device, is_decoder, has_query=(query_embeds is not None), ) else: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, device, is_decoder ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ 0 ].size() else: ( encoder_batch_size, encoder_sequence_length, _, ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [ self.invert_attention_mask(mask) for mask in encoder_attention_mask ] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask ) else: encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask ) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, 
query_length=query_length, ) sequence_output = encoder_outputs[0] pooled_output = ( self.pooler(sequence_output) if self.pooler is not None else None ) if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class BertLMHeadModel(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, past_key_values=None, use_cache=True, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=True, reduction="mean", ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). 
Returns: Example:: >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> config = BertConfig.from_pretrained("bert-base-cased") >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if labels is not None: use_cache = False if past_key_values is not None: query_embeds = None outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) sequence_output = outputs[0] if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1] :, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores[:, :-1, :].contiguous() lm_loss = None if labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) lm_loss = loss_fct( shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1), ) if reduction == "none": lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) if not return_dict: output = (prediction_scores,) + outputs[2:] return ((lm_loss,) + output) if lm_loss is not None else output return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, 
past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation( self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs ): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) query_mask = input_ids.new_ones(query_embeds.shape[:-1]) attention_mask = torch.cat([query_mask, attention_mask], dim=-1) # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "query_embeds": query_embeds, "attention_mask": attention_mask, "past_key_values": past, "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), "is_decoder": True, } def _reorder_cache(self, past, beam_idx): reordered_past = () for layer_past in past: reordered_past += ( tuple( past_state.index_select(0, beam_idx) for past_state in layer_past ), ) return reordered_past class BertForMaskedLM(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=False, ): r""" labels (:obj:`torch.LongTensor` 
of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1] :, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) ) if not return_dict: output = (prediction_scores,) + outputs[2:] return ( ((masked_lm_loss,) + output) if masked_lm_loss is not None else output ) return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) ================================================ FILE: xraypulse/models/__init__.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import torch from omegaconf import OmegaConf from xraypulse.common.registry import registry from xraypulse.models.base_model import BaseModel from xraypulse.models.blip2 import Blip2Base from xraypulse.models.xray_pulse import XrayPulse from xraypulse.processors.base_processor import BaseProcessor __all__ = [ "load_model", "BaseModel", "Blip2Base", "XrayPulse", ] def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): """ Load supported models. To list all available models and types in registry: >>> from minigpt4.models import model_zoo >>> print(model_zoo) Args: name (str): name of the model. model_type (str): type of the model. is_eval (bool): whether the model is in eval mode. Default: False. device (str): device to use. Default: "cpu". checkpoint (str): path or to checkpoint. Default: None. Note that expecting the checkpoint to have the same keys in state_dict as the model. Returns: model (torch.nn.Module): model. """ model = registry.get_model_class(name).from_pretrained(model_type=model_type) if checkpoint is not None: model.load_checkpoint(checkpoint) if is_eval: model.eval() if device == "cpu": model = model.float() return model.to(device) def load_preprocess(config): """ Load preprocessor configs and construct preprocessors. If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. Args: config (dict): preprocessor configs. Returns: vis_processors (dict): preprocessors for visual inputs. txt_processors (dict): preprocessors for text inputs. Key is "train" or "eval" for processors used in training and evaluation respectively. 
""" def _build_proc_from_cfg(cfg): return ( registry.get_processor_class(cfg.name).from_config(cfg) if cfg is not None else BaseProcessor() ) vis_processors = dict() txt_processors = dict() vis_proc_cfg = config.get("vis_processor") txt_proc_cfg = config.get("text_processor") if vis_proc_cfg is not None: vis_train_cfg = vis_proc_cfg.get("train") vis_eval_cfg = vis_proc_cfg.get("eval") else: vis_train_cfg = None vis_eval_cfg = None vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) if txt_proc_cfg is not None: txt_train_cfg = txt_proc_cfg.get("train") txt_eval_cfg = txt_proc_cfg.get("eval") else: txt_train_cfg = None txt_eval_cfg = None txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) return vis_processors, txt_processors def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): """ Load model and its related preprocessors. List all available models and types in registry: >>> from minigpt4.models import model_zoo >>> print(model_zoo) Args: name (str): name of the model. model_type (str): type of the model. is_eval (bool): whether the model is in eval mode. Default: False. device (str): device to use. Default: "cpu". Returns: model (torch.nn.Module): model. vis_processors (dict): preprocessors for visual inputs. txt_processors (dict): preprocessors for text inputs. """ model_cls = registry.get_model_class(name) # load model model = model_cls.from_pretrained(model_type=model_type) if is_eval: model.eval() # load preprocess cfg = OmegaConf.load(model_cls.default_config_path(model_type)) if cfg is not None: preprocess_cfg = cfg.preprocess vis_processors, txt_processors = load_preprocess(preprocess_cfg) else: vis_processors, txt_processors = None, None logging.info( f"""No default preprocess for model {name} ({model_type}). 
This can happen if the model is not finetuned on downstream datasets, or it is not intended for direct use without finetuning. """ ) if device == "cpu" or device == torch.device("cpu"): model = model.float() return model.to(device), vis_processors, txt_processors class ModelZoo: """ A utility class to create string representation of available model architectures and types. >>> from minigpt4.models import model_zoo >>> # list all available models >>> print(model_zoo) >>> # show total number of models >>> print(len(model_zoo)) """ def __init__(self) -> None: self.model_zoo = { k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) for k, v in registry.mapping["model_name_mapping"].items() } def __str__(self) -> str: return ( "=" * 50 + "\n" + f"{'Architectures':<30} {'Types'}\n" + "=" * 50 + "\n" + "\n".join( [ f"{name:<30} {', '.join(types)}" for name, types in self.model_zoo.items() ] ) ) def __iter__(self): return iter(self.model_zoo.items()) def __len__(self): return sum([len(v) for v in self.model_zoo.values()]) model_zoo = ModelZoo() ================================================ FILE: xraypulse/models/base_model.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import os import numpy as np import torch import torch.nn as nn from xraypulse.common.dist_utils import download_cached_file, is_dist_avail_and_initialized from xraypulse.common.utils import get_abs_path, is_url from omegaconf import OmegaConf class BaseModel(nn.Module): """Base class for models.""" def __init__(self): super().__init__() @property def device(self): return list(self.parameters())[0].device def load_checkpoint(self, url_or_filename): """ Load from a finetuned checkpoint. This should expect no mismatch in the model keys and the checkpoint keys. 
""" if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if "model" in checkpoint.keys(): state_dict = checkpoint["model"] else: state_dict = checkpoint msg = self.load_state_dict(state_dict, strict=False) logging.info("Missing keys {}".format(msg.missing_keys)) logging.info("load checkpoint from %s" % url_or_filename) return msg @classmethod def from_pretrained(cls, model_type): """ Build a pretrained model from default configuration file, specified by model_type. Args: - model_type (str): model type, specifying architecture and checkpoints. Returns: - model (nn.Module): pretrained or finetuned model, depending on the configuration. """ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model model = cls.from_config(model_cfg) return model @classmethod def default_config_path(cls, model_type): assert ( model_type in cls.PRETRAINED_MODEL_CONFIG_DICT ), "Unknown model type {}".format(model_type) return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) def load_checkpoint_from_config(self, cfg, **kwargs): """ Load checkpoint as specified in the config file. If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. When loading the pretrained model, each task-specific architecture may define their own load_from_pretrained() method. """ load_finetuned = cfg.get("load_finetuned", True) if load_finetuned: finetune_path = cfg.get("finetuned", None) assert ( finetune_path is not None ), "Found load_finetuned is True, but finetune_path is None." self.load_checkpoint(url_or_filename=finetune_path) else: # load pre-trained weights pretrain_path = cfg.get("pretrained", None) assert "Found load_finetuned is False, but pretrain_path is None." 
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) def before_evaluation(self, **kwargs): pass def show_n_params(self, return_str=True): tot = 0 for p in self.parameters(): w = 1 for x in p.shape: w *= x tot += w if return_str: if tot >= 1e6: return "{:.1f}M".format(tot / 1e6) else: return "{:.1f}K".format(tot / 1e3) else: return tot class BaseEncoder(nn.Module): """ Base class for primitive encoders, such as ViT, TimeSformer, etc. """ def __init__(self): super().__init__() def forward_features(self, samples, **kwargs): raise NotImplementedError @property def device(self): return list(self.parameters())[0].device class SharedQueueMixin: @torch.no_grad() def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): # gather keys before updating queue image_feats = concat_all_gather(image_feat) text_feats = concat_all_gather(text_feat) batch_size = image_feats.shape[0] ptr = int(self.queue_ptr) assert self.queue_size % batch_size == 0 # for simplicity # replace the keys at ptr (dequeue and enqueue) self.image_queue[:, ptr : ptr + batch_size] = image_feats.T self.text_queue[:, ptr : ptr + batch_size] = text_feats.T if idxs is not None: idxs = concat_all_gather(idxs) self.idx_queue[:, ptr : ptr + batch_size] = idxs.T ptr = (ptr + batch_size) % self.queue_size # move pointer self.queue_ptr[0] = ptr class MomentumDistilationMixin: @torch.no_grad() def copy_params(self): for model_pair in self.model_pairs: for param, param_m in zip( model_pair[0].parameters(), model_pair[1].parameters() ): param_m.data.copy_(param.data) # initialize param_m.requires_grad = False # not update by gradient @torch.no_grad() def _momentum_update(self): for model_pair in self.model_pairs: for param, param_m in zip( model_pair[0].parameters(), model_pair[1].parameters() ): param_m.data = param_m.data * self.momentum + param.data * ( 1.0 - self.momentum ) class GatherLayer(torch.autograd.Function): """ Gather tensors from all workers with support for backward propagation: This 
implementation does not cut the gradients as torch.distributed.all_gather does. """ @staticmethod def forward(ctx, x): output = [ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(output, x) return tuple(output) @staticmethod def backward(ctx, *grads): all_gradients = torch.stack(grads) torch.distributed.all_reduce(all_gradients) return all_gradients[torch.distributed.get_rank()] def all_gather_with_grad(tensors): """ Performs all_gather operation on the provided tensors. Graph remains connected for backward grad computation. """ # Queue the gathered tensors world_size = torch.distributed.get_world_size() # There is no need for reduction in the single-proc case if world_size == 1: return tensors # tensor_all = GatherLayer.apply(tensors) tensor_all = GatherLayer.apply(tensors) return torch.cat(tensor_all, dim=0) @torch.no_grad() def concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. """ # if use distributed training if not is_dist_avail_and_initialized(): return tensor tensors_gather = [ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(tensors_gather, tensor, async_op=False) output = torch.cat(tensors_gather, dim=0) return output def tile(x, dim, n_tile): init_dim = x.size(dim) repeat_idx = [1] * x.dim() repeat_idx[dim] = n_tile x = x.repeat(*(repeat_idx)) order_index = torch.LongTensor( np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) ) return torch.index_select(x, dim, order_index.to(x.device)) ================================================ FILE: xraypulse/models/blip2.py ================================================ """ Copyright (c) 2023, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import contextlib import logging import os import time import datetime import torch import torch.nn as nn import torch.distributed as dist import torch.nn.functional as F import xraypulse.common.dist_utils as dist_utils from xraypulse.common.dist_utils import download_cached_file from xraypulse.common.utils import is_url from xraypulse.common.logger import MetricLogger from xraypulse.models.base_model import BaseModel from xraypulse.models.Qformer import BertConfig, BertLMHeadModel from xraypulse.models.eva_vit import create_eva_vit_g from transformers import BertTokenizer class Blip2Base(BaseModel): @classmethod def init_tokenizer(cls): tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tokenizer.add_special_tokens({"bos_token": "[DEC]"}) return tokenizer def maybe_autocast(self, dtype=torch.float16): # if on cpu, don't use autocast # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 enable_autocast = self.device != torch.device("cpu") if enable_autocast: return torch.cuda.amp.autocast(dtype=dtype) else: return contextlib.nullcontext() @classmethod def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): encoder_config = BertConfig.from_pretrained("bert-base-uncased") encoder_config.encoder_width = vision_width # insert cross-attention layer every other block encoder_config.add_cross_attention = True encoder_config.cross_attention_freq = cross_attention_freq encoder_config.query_length = num_query_token Qformer = BertLMHeadModel(config=encoder_config) query_tokens = nn.Parameter( torch.zeros(1, num_query_token, encoder_config.hidden_size) ) query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) return Qformer, query_tokens @classmethod def init_vision_encoder( cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision 
): assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4" visual_encoder = create_eva_vit_g( img_size, drop_path_rate, use_grad_checkpoint, precision ) ln_vision = LayerNorm(visual_encoder.num_features) return visual_encoder, ln_vision def load_from_pretrained(self, url_or_filename): if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") state_dict = checkpoint["model"] msg = self.load_state_dict(state_dict, strict=False) # logging.info("Missing keys {}".format(msg.missing_keys)) logging.info("load checkpoint from %s" % url_or_filename) return msg def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor): orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) def compute_sim_matrix(model, data_loader, **kwargs): k_test = kwargs.pop("k_test") metric_logger = MetricLogger(delimiter=" ") header = "Evaluation:" logging.info("Computing features for evaluation...") start_time = time.time() texts = data_loader.dataset.text num_text = len(texts) text_bs = 256 text_ids = [] text_embeds = [] text_atts = [] for i in range(0, num_text, text_bs): text = texts[i : min(num_text, i + text_bs)] text_input = model.tokenizer( text, padding="max_length", truncation=True, max_length=35, return_tensors="pt", ).to(model.device) text_feat = model.forward_text(text_input) text_embed = F.normalize(model.text_proj(text_feat)) text_embeds.append(text_embed) text_ids.append(text_input.input_ids) text_atts.append(text_input.attention_mask) 
text_embeds = torch.cat(text_embeds, dim=0) text_ids = torch.cat(text_ids, dim=0) text_atts = torch.cat(text_atts, dim=0) vit_feats = [] image_embeds = [] for samples in data_loader: image = samples["image"] image = image.to(model.device) image_feat, vit_feat = model.forward_image(image) image_embed = model.vision_proj(image_feat) image_embed = F.normalize(image_embed, dim=-1) vit_feats.append(vit_feat.cpu()) image_embeds.append(image_embed) vit_feats = torch.cat(vit_feats, dim=0) image_embeds = torch.cat(image_embeds, dim=0) sims_matrix = [] for image_embed in image_embeds: sim_q2t = image_embed @ text_embeds.t() sim_i2t, _ = sim_q2t.max(0) sims_matrix.append(sim_i2t) sims_matrix = torch.stack(sims_matrix, dim=0) score_matrix_i2t = torch.full( (len(data_loader.dataset.image), len(texts)), -100.0 ).to(model.device) num_tasks = dist_utils.get_world_size() rank = dist_utils.get_rank() step = sims_matrix.size(0) // num_tasks + 1 start = rank * step end = min(sims_matrix.size(0), start + step) for i, sims in enumerate( metric_logger.log_every(sims_matrix[start:end], 50, header) ): topk_sim, topk_idx = sims.topk(k=k_test, dim=0) image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) score = model.compute_itm( image_inputs=image_inputs, text_ids=text_ids[topk_idx], text_atts=text_atts[topk_idx], ).float() score_matrix_i2t[start + i, topk_idx] = score + topk_sim sims_matrix = sims_matrix.t() score_matrix_t2i = torch.full( (len(texts), len(data_loader.dataset.image)), -100.0 ).to(model.device) step = sims_matrix.size(0) // num_tasks + 1 start = rank * step end = min(sims_matrix.size(0), start + step) for i, sims in enumerate( metric_logger.log_every(sims_matrix[start:end], 50, header) ): topk_sim, topk_idx = sims.topk(k=k_test, dim=0) image_inputs = vit_feats[topk_idx.cpu()].to(model.device) score = model.compute_itm( image_inputs=image_inputs, text_ids=text_ids[start + i].repeat(k_test, 1), text_atts=text_atts[start + i].repeat(k_test, 1), ).float() 
score_matrix_t2i[start + i, topk_idx] = score + topk_sim if dist_utils.is_dist_avail_and_initialized(): dist.barrier() torch.distributed.all_reduce( score_matrix_i2t, op=torch.distributed.ReduceOp.SUM ) torch.distributed.all_reduce( score_matrix_t2i, op=torch.distributed.ReduceOp.SUM ) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logging.info("Evaluation time {}".format(total_time_str)) return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() ================================================ FILE: xraypulse/models/blip2_outputs.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from dataclasses import dataclass from typing import Optional import torch from transformers.modeling_outputs import ( ModelOutput, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) @dataclass class BlipSimilarity(ModelOutput): sim_i2t: torch.FloatTensor = None sim_t2i: torch.FloatTensor = None sim_i2t_m: Optional[torch.FloatTensor] = None sim_t2i_m: Optional[torch.FloatTensor] = None sim_i2t_targets: Optional[torch.FloatTensor] = None sim_t2i_targets: Optional[torch.FloatTensor] = None @dataclass class BlipIntermediateOutput(ModelOutput): """ Data class for intermediate outputs of BLIP models. image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). 
encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. decoder_labels (torch.LongTensor): labels for the captioning loss. itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) """ # uni-modal features image_embeds: torch.FloatTensor = None text_embeds: Optional[torch.FloatTensor] = None image_embeds_m: Optional[torch.FloatTensor] = None text_embeds_m: Optional[torch.FloatTensor] = None # intermediate outputs of multimodal encoder encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None itm_logits: Optional[torch.FloatTensor] = None itm_labels: Optional[torch.LongTensor] = None # intermediate outputs of multimodal decoder decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None decoder_labels: Optional[torch.LongTensor] = None @dataclass class BlipOutput(ModelOutput): # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. sims: Optional[BlipSimilarity] = None intermediate_output: BlipIntermediateOutput = None loss: Optional[torch.FloatTensor] = None loss_itc: Optional[torch.FloatTensor] = None loss_itm: Optional[torch.FloatTensor] = None loss_lm: Optional[torch.FloatTensor] = None @dataclass class BlipOutputFeatures(ModelOutput): """ Data class of features from BlipFeatureExtractor. 
Args: image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional The first embedding or feature is for the [CLS] token. Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. """ image_embeds: Optional[torch.FloatTensor] = None image_embeds_proj: Optional[torch.FloatTensor] = None text_embeds: Optional[torch.FloatTensor] = None text_embeds_proj: Optional[torch.FloatTensor] = None multimodal_embeds: Optional[torch.FloatTensor] = None ================================================ FILE: xraypulse/models/eva_vit.py ================================================ # Based on EVA, BEIT, timm and DeiT code bases # https://github.com/baaivision/EVA # https://github.com/rwightman/pytorch-image-models/tree/master/timm # https://github.com/microsoft/unilm/tree/master/beit # https://github.com/facebookresearch/deit/ # https://github.com/facebookresearch/dino # --------------------------------------------------------' import math from functools import partial import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.models.layers import drop_path, to_2tuple, trunc_normal_ from timm.models.registry import register_model from xraypulse.common.dist_utils import download_cached_file def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), **kwargs } class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
"""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # `self.training` is forwarded so drop_path can behave differently in train and
        # eval mode (the exact eval-mode behavior lives in drop_path, outside this view).
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Extra text shown inside this module's repr(), e.g. "DropPath(p=0.1)".
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> fc2 -> dropout.

    Args:
        in_features: input channel count.
        hidden_features: hidden width; defaults to ``in_features``.
        out_features: output channel count; defaults to ``in_features``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, applied only after ``fc2``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # (dropout after the activation is commented out to follow the original BERT implementation)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention used by the EVA ViT blocks.

    The fused ``qkv`` projection has no bias of its own. When ``qkv_bias`` is True,
    learnable q and v biases are concatenated with a fixed all-zero k bias at forward time.

    When ``window_size`` is given, a per-layer learnable relative-position bias is added
    to the attention logits. Its table has (2*Wh-1)*(2*Ww-1)+3 rows; the 3 extra rows
    are used for the cls->token, token->cls and cls->cls pairs.

    Args:
        dim: input/output channel count.
        num_heads: number of attention heads.
        qkv_bias: enable the learnable q/v biases described above.
        qk_scale: logit scale; defaults to ``head_dim ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        window_size: (Wh, Ww) patch grid for the per-layer relative-position bias, or None.
        attn_head_dim: overrides the per-head width (default ``dim // num_heads``).
    """
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            # Row-major flattening of the 2-D offset into a single table index.
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            # +1 on each side for the cls token at position 0.
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            # The last three table rows are reserved for pairs involving the cls token.
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, dim).

        ``rel_pos_bias``, when given (e.g. the output of a shared RelativePositionBias),
        is added to the attention logits before the softmax.
        """
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # k gets a constant zero bias; only q and v biases are learnable.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by ``x + drop_path(mlp(norm2(x)))``.
    When ``init_values > 0``, each residual branch is scaled by a learnable per-channel
    vector (``gamma_1`` / ``gamma_2``) initialised to ``init_values``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to patch embedding: a Conv2d with kernel size = stride = ``patch_size``.

    ``forward`` requires the input to match ``img_size`` exactly (asserted) and returns
    a tensor of shape (B, num_patches, embed_dim).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, H', W') -> (B, H'*W', embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class RelativePositionBias(nn.Module):
    """Relative-position bias table shared by all blocks.

    Uses the same index layout as ``Attention``. ``forward()`` takes no input and
    returns a tensor of shape (num_heads, Wh*Ww+1, Wh*Ww+1).
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


class VisionTransformer(nn.Module):
    """EVA-style Vision Transformer backbone (patch-embedding input stage only).

    The classification head, final norm and pooling are commented out, so ``forward``
    returns the full token sequence (cls token + patch tokens) from the last block.
    ``num_classes``, ``use_mean_pooling`` and ``init_scale`` are accepted but have no
    effect in this version.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        self.use_checkpoint = use_checkpoint

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])
        # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        # trunc_normal_(self.mask_token, std=.02)
        # if isinstance(self.head, nn.Linear):
        #     trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()
        # if isinstance(self.head, nn.Linear):
        #     self.head.weight.data.mul_(init_scale)
        #     self.head.bias.data.mul_(init_scale)

    def fix_init_weight(self):
        """Divide each block's attn.proj and mlp.fc2 weights by sqrt(2 * layer_id).

        ``layer_id`` is the 1-based index of the block.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        # Linear: truncated-normal weights (std=.02) and zero bias; LayerNorm: weight=1, bias=0.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        # NOTE(review): self.head is never created in __init__ (its construction is commented
        # out), so this raises AttributeError unless reset_classifier() was called first.
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Return the last block's output, shape (B, 1 + num_patches, embed_dim); no final norm."""
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.use_checkpoint:
                # Gradient checkpointing: trade recomputation for activation memory.
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x
        # x = self.norm(x)

        # if self.fc_norm is not None:
        #     t = x[:, 1:, :]
        #     return self.fc_norm(t.mean(1))
        # else:
        #     return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x

    def get_intermediate_layers(self, x):
        """Return a list with the output of every block, in order.

        Unlike forward_features, this never uses gradient checkpointing.
        """
        x = self.patch_embed(x)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        features = []
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)

        return features


def interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to match ``model``'s patch grid.

    Leading extra tokens (e.g. cls) are kept as is; the patch part is resized with
    bicubic interpolation. Both grids are assumed square (side = int(sqrt(count))).
    Nothing is changed when the key is absent or the grid sizes already agree.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed


def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16.

    Only the weight/bias of Conv1d, Conv2d and Linear modules are cast, in place.
    Parameters of other module types (LayerNorm, cls_token, pos_embed, the Attention
    q/v biases, ...) keep their dtype.
    """

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

#         if isinstance(l, (nn.MultiheadAttention, Attention)):
#             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
#                 tensor = getattr(l, attr)
#                 if tensor is not None:
#                     tensor.data = tensor.data.half()

    model.apply(_convert_weights_to_fp16)


def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
    """Build the EVA ViT-g/14 encoder and load the LAVIS BLIP-2 pretrained weights.

    The architecture has depth 39, width 1408 and 1408 // 88 = 16 heads. Weights are
    downloaded (and cached) from the LAVIS URL below. ``pos_embed`` is interpolated to
    ``img_size``, then the state dict is loaded with ``strict=False``: mismatched keys
    are silently ignored (``incompatible_keys`` is not inspected). Conv/Linear weights
    are cast to fp16 when ``precision == "fp16"``.
    """
    model = VisionTransformer(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408//88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
    cached_file = download_cached_file(
        url, check_hash=False, progress=True
    )
    # NOTE(review): torch.load unpickles the file; acceptable for this fixed upstream URL only.
    state_dict = torch.load(cached_file, map_location="cpu")
    interpolate_pos_embed(model,state_dict)

    incompatible_keys = model.load_state_dict(state_dict, strict=False)
#     print(incompatible_keys)

    if precision == "fp16":
#         model.to("cuda")
        convert_weights_to_fp16(model)
    return model

================================================
FILE: xraypulse/models/pos_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# -------------------------------------------------------- # Position embedding utils # -------------------------------------------------------- import numpy as np import torch # -------------------------------------------------------- # 2D sine-cosine position embedding # References: # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py # MoCo v3: https://github.com/facebookresearch/moco-v3 # -------------------------------------------------------- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float) omega = omega / embed_dim / 2. omega = 1. 
/ 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb # -------------------------------------------------------- # Interpolate position embeddings for high-resolution # References: # DeiT: https://github.com/facebookresearch/deit # -------------------------------------------------------- def interpolate_pos_embed(model, checkpoint_model): if 'pos_embed' in checkpoint_model: pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed ================================================ FILE: xraypulse/models/xray_pulse.py ================================================ import logging import random import torch from torch.cuda.amp import autocast as autocast import torch.nn as nn import torch.nn.functional as F from 
xraypulse.common.registry import registry from xraypulse.models.blip2 import Blip2Base, disabled_train from transformers import BloomForCausalLM from transformers import BloomTokenizerFast from transformers import StoppingCriteria, StoppingCriteriaList #for imprgpt eval import csv #for imprgpt rogue eval from xraypulse.conversation.conversation import Conversation from enum import auto, Enum from typing import List, Tuple, Any class StoppingCriteriaSub(StoppingCriteria): def __init__(self, stops=[], encounters=1): super().__init__() self.stops = stops def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): for stop in self.stops: if torch.all((stop == input_ids[0][-len(stop):])).item(): return True return False class SeparatorStyle(Enum): """Different separator style.""" SINGLE = auto() TWO = auto() @registry.register_model("xray_pulse") class XrayPulse(Blip2Base): """ BLIP2 GPT-Bloom model. """ PRETRAINED_MODEL_CONFIG_DICT = { "pulse": "configs/models/xraypulse.yaml", } def __init__( self, vit_model="eva_clip_g", q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", img_size=224, drop_path_rate=0, use_grad_checkpoint=False, vit_precision="fp16", freeze_vit=True, freeze_qformer=True, num_query_token=32, bloom_model="", prompt_path="", prompt_template="", max_txt_len=32, low_resource=False, # use 8 bit and put vit in cpu end_sym='\n', ): super().__init__() self.tokenizer = self.init_tokenizer() self.low_resource = low_resource print('Loading VIT') self.visual_encoder, self.ln_vision = self.init_vision_encoder( vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision ) if freeze_vit: for name, param in self.visual_encoder.named_parameters(): param.requires_grad = False self.visual_encoder = self.visual_encoder.eval() self.visual_encoder.train = disabled_train for name, param in self.ln_vision.named_parameters(): param.requires_grad = False self.ln_vision = 
self.ln_vision.eval() self.ln_vision.train = disabled_train logging.info("freeze vision encoder") print('Loading VIT Done') print('Loading Q-Former') self.Qformer, self.query_tokens = self.init_Qformer( num_query_token, self.visual_encoder.num_features ) self.Qformer.cls = None self.Qformer.bert.embeddings.word_embeddings = None self.Qformer.bert.embeddings.position_embeddings = None for layer in self.Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None self.load_from_pretrained(url_or_filename=q_former_model) if freeze_qformer: for name, param in self.Qformer.named_parameters(): param.requires_grad = False self.Qformer = self.Qformer.eval() self.Qformer.train = disabled_train self.query_tokens.requires_grad = False logging.info("freeze Qformer") print('Loading Q-Former Done') print('Loading Bloom') self.bloom_tokenizer = BloomTokenizerFast.from_pretrained(bloom_model, use_fast=False) self.bloom_tokenizer.pad_token = self.bloom_tokenizer.eos_token if self.low_resource: self.bloom_model = BloomForCausalLM.from_pretrained( bloom_model, torch_dtype=torch.float16, device_map="auto" ) else: self.bloom_model = BloomForCausalLM.from_pretrained( bloom_model, torch_dtype=torch.float16, ) for name, param in self.bloom_model.named_parameters(): param.requires_grad = False print('Loading Bloom Done') self.bloom_proj = nn.Linear( self.Qformer.config.hidden_size, self.bloom_model.config.hidden_size ) self.max_txt_len = max_txt_len self.end_sym = end_sym if prompt_path: with open(prompt_path, 'r') as f: raw_prompts = f.read().splitlines() filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<图片>" in raw_prompt] self.prompt_list = [prompt_template.format(p) for p in filted_prompts] print('Load {} training prompts'.format(len(self.prompt_list))) print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) else: self.prompt_list = [] print('#'*100) def vit_to_cpu(self): self.ln_vision.to("cpu") self.ln_vision.float() 
self.visual_encoder.to("cpu") self.visual_encoder.float() def encode_img(self, image): device = image.device if self.low_resource: self.vit_to_cpu() image = image.to("cpu") with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) inputs_bloom = self.bloom_proj(query_output.last_hidden_state) atts_bloom = torch.ones(inputs_bloom.size()[:-1], dtype=torch.long).to(image.device) return inputs_bloom, atts_bloom def prompt_wrap(self, img_embeds, atts_img, prompt): if prompt: batch_size = img_embeds.shape[0] p_before, p_after = prompt.split('<图片>') p_before_tokens = self.bloom_tokenizer( p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) p_after_tokens = self.bloom_tokenizer( p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) p_before_embeds = self.bloom_model.transformer.word_embeddings_layernorm(self.bloom_model.transformer.word_embeddings(p_before_tokens.input_ids)).expand(batch_size, -1, -1) p_after_embeds = self.bloom_model.transformer.word_embeddings_layernorm(self.bloom_model.transformer.word_embeddings(p_after_tokens.input_ids)).expand(batch_size, -1, -1) wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1) wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1]) return wrapped_img_embeds, wrapped_atts_img else: return img_embeds, atts_img def forward(self, samples): image = samples["image"] img_embeds, atts_img = self.encode_img(image) if hasattr(samples, 'question_split'): # VQA dataset print('VQA Batch') vqa_prompt = 'Instructions: You are PULSE, a large language model trained by SHAIlab. 
Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2022-02-01 User: <图片> Helper: ' img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt) elif self.prompt_list: prompt = random.choice(self.prompt_list) img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt) self.bloom_tokenizer.padding_side = "right" text = [t + self.end_sym for t in samples["caption"]] to_regress_tokens = self.bloom_tokenizer( text, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len, add_special_tokens=False ).to(image.device) targets = to_regress_tokens.input_ids.masked_fill( to_regress_tokens.input_ids == self.bloom_tokenizer.pad_token_id, -100 ) empty_targets = ( torch.ones([atts_img.shape[0], atts_img.shape[1]+1], dtype=torch.long).to(image.device).fill_(-100) # plus one for bos ) targets = torch.cat([empty_targets, targets], dim=1) batch_size = img_embeds.shape[0] bos = torch.ones([batch_size, 1], dtype=to_regress_tokens.input_ids.dtype, device=to_regress_tokens.input_ids.device) * self.bloom_tokenizer.bos_token_id bos_embeds = self.bloom_model.transformer.word_embeddings_layernorm(self.bloom_model.transformer.word_embeddings(bos)) atts_bos = atts_img[:, :1] to_regress_embeds = self.bloom_model.transformer.word_embeddings_layernorm(self.bloom_model.transformer.word_embeddings(to_regress_tokens.input_ids)) inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) with self.maybe_autocast(): outputs = self.bloom_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, labels=targets, ) loss = outputs.loss return {"loss": loss} def test(self, samples, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): CONV_ZH = Conversation( system="Instructions: You are PULSE, a large language model 
trained by SHAIlab. Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2022-02-01 User: {} Helper: ", # "Please answer the medical questions based on the patient's description. Give the following medical scan: 图片." # "You will be able to see the medical scan once I provide it to you. Please answer the patients questions.", roles=("User", "Helper"), messages=[], offset=0, sep_style=SeparatorStyle.SINGLE, sep="", sep2="###", ) stop_words_ids = [torch.tensor([835]).to(samples["image"].device), torch.tensor([2277, 29937]).to(samples["image"].device)] stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) conv.append_message(conv.roles[1], None) image = samples["image"] img_embeds, atts_img = self.encode_img(image) conv.append_message(conv.roles[0], "<图片>") embs = self.get_context_emb(conv, img_embeds) current_max_len = embs.shape[1] + max_new_tokens if current_max_len - max_length > 0: print('Warning: The number of tokens in current conversation exceeds the max length. ' 'The model will not see the contexts outside the range.') begin_idx = max(0, current_max_len - max_length) embs = embs[:, begin_idx:] outputs = self.bloom_model.generate( inputs_embeds=embs, max_new_tokens=max_new_tokens, stopping_criteria = stopping_criteria, num_beams=num_beams, do_sample=True, min_length=min_length, top_p=top_p, repetition_penalty=repetition_penalty, length_penalty=length_penalty, temperature=temperature, ) output_token = outputs[0] if output_token[0] == 0: # the model might output a unknow token at the beginning. remove it output_token = output_token[1:] if output_token[0] == 1: # some users find that there is a start token at the beginning. 
remove it output_token = output_token[1:] output_text = self.bloom_tokenizer.decode(output_token, add_special_tokens=False) output_text = output_text.split('')[0] conv.messages[-1][1] = output_text return output_text, output_token.cpu().numpy() def get_context_emb(self, conv, img): prompt = random.choice(self.prompt_list) text = prompt.split("<图片>")[-1] if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ and conv.messages[-1][1][-6:] == '': # last message is image. conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) else: conv.append_message(conv.roles[0], prompt) prompt_segs = prompt.split('<图片>') img_list = [img] seg_tokens = [ self.bloom_tokenizer( seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids # only add bos to the first seg for i, seg in enumerate(prompt_segs) ] seg_embs = [self.bloom_model.transformer.word_embeddings_layernorm(self.bloom_model.transformer.word_embeddings(seg_t)) for seg_t in seg_tokens] mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] mixed_embs = torch.cat(mixed_embs, dim=1) return mixed_embs @classmethod def from_config(cls, cfg): vit_model = cfg.get("vit_model", "eva_clip_g") q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") img_size = cfg.get("image_size") num_query_token = cfg.get("num_query_token") bloom_model = cfg.get("bloom_model") drop_path_rate = cfg.get("drop_path_rate", 0) use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) vit_precision = cfg.get("vit_precision", "fp16") freeze_vit = cfg.get("freeze_vit", True) freeze_qformer = cfg.get("freeze_qformer", True) low_resource = cfg.get("low_resource", False) prompt_path = cfg.get("prompt_path", "") prompt_template = cfg.get("prompt_template", "") max_txt_len = cfg.get("max_txt_len", 32) end_sym = cfg.get("end_sym", '\n') model = cls( vit_model=vit_model, 
q_former_model=q_former_model, img_size=img_size, drop_path_rate=drop_path_rate, use_grad_checkpoint=use_grad_checkpoint, vit_precision=vit_precision, freeze_vit=freeze_vit, freeze_qformer=freeze_qformer, num_query_token=num_query_token, bloom_model=bloom_model, prompt_path=prompt_path, prompt_template=prompt_template, max_txt_len=max_txt_len, low_resource=low_resource, end_sym=end_sym ) ckpt_path = cfg.get("ckpt", "") if ckpt_path: print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) ckpt = torch.load(ckpt_path, map_location="cpu") msg = model.load_state_dict(ckpt['model'], strict=False) return model ================================================ FILE: xraypulse/processors/__init__.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from xraypulse.processors.base_processor import BaseProcessor from xraypulse.processors.blip_processors import ( Blip2ImageTrainProcessor, Blip2ImageEvalProcessor, BlipCaptionProcessor, ) from xraypulse.common.registry import registry __all__ = [ "BaseProcessor", "Blip2ImageTrainProcessor", "Blip2ImageEvalProcessor", "BlipCaptionProcessor", ] def load_processor(name, cfg=None): """ Example >>> processor = load_processor("alpro_video_train", cfg=None) """ processor = registry.get_processor_class(name).from_config(cfg) return processor ================================================ FILE: xraypulse/processors/base_processor.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from omegaconf import OmegaConf class BaseProcessor: def __init__(self): self.transform = lambda x: x return def __call__(self, item): return self.transform(item) @classmethod def from_config(cls, cfg=None): return cls() def build(self, **kwargs): cfg = OmegaConf.create(kwargs) return self.from_config(cfg) ================================================ FILE: xraypulse/processors/blip_processors.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import re from xraypulse.common.registry import registry from xraypulse.processors.base_processor import BaseProcessor from xraypulse.processors.randaugment import RandomAugment from omegaconf import OmegaConf from torchvision import transforms from torchvision.transforms.functional import InterpolationMode class BlipImageBaseProcessor(BaseProcessor): def __init__(self, mean=None, std=None): if mean is None: mean = (0.48145466, 0.4578275, 0.40821073) if std is None: std = (0.26862954, 0.26130258, 0.27577711) self.normalize = transforms.Normalize(mean, std) @registry.register_processor("blip_caption") class BlipCaptionProcessor(BaseProcessor): def __init__(self, prompt="", max_words=50): self.prompt = prompt self.max_words = max_words def __call__(self, caption): caption = self.prompt + self.pre_caption(caption) return caption @classmethod def from_config(cls, cfg=None): if cfg is None: cfg = OmegaConf.create() prompt = cfg.get("prompt", "") max_words = cfg.get("max_words", 50) return cls(prompt=prompt, max_words=max_words) def pre_caption(self, caption): caption = re.sub( r"([.!\"()*#:;~])", " ", caption.lower(), ) caption = re.sub( r"\s{2,}", " ", 
caption, ) caption = caption.rstrip("\n") caption = caption.strip(" ") # truncate caption caption_words = caption.split(" ") if len(caption_words) > self.max_words: caption = " ".join(caption_words[: self.max_words]) return caption @registry.register_processor("blip2_image_train") class Blip2ImageTrainProcessor(BlipImageBaseProcessor): def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): super().__init__(mean=mean, std=std) self.transform = transforms.Compose( [ transforms.RandomResizedCrop( image_size, scale=(min_scale, max_scale), interpolation=InterpolationMode.BICUBIC, ), transforms.ToTensor(), self.normalize, ] ) def __call__(self, item): return self.transform(item) @classmethod def from_config(cls, cfg=None): if cfg is None: cfg = OmegaConf.create() image_size = cfg.get("image_size", 224) mean = cfg.get("mean", None) std = cfg.get("std", None) min_scale = cfg.get("min_scale", 0.5) max_scale = cfg.get("max_scale", 1.0) return cls( image_size=image_size, mean=mean, std=std, min_scale=min_scale, max_scale=max_scale, ) @registry.register_processor("blip2_image_eval") class Blip2ImageEvalProcessor(BlipImageBaseProcessor): def __init__(self, image_size=224, mean=None, std=None): super().__init__(mean=mean, std=std) self.transform = transforms.Compose( [ transforms.Resize( (image_size, image_size), interpolation=InterpolationMode.BICUBIC ), transforms.ToTensor(), self.normalize, ] ) def __call__(self, item): return self.transform(item) @classmethod def from_config(cls, cfg=None): if cfg is None: cfg = OmegaConf.create() image_size = cfg.get("image_size", 224) mean = cfg.get("mean", None) std = cfg.get("std", None) return cls(image_size=image_size, mean=mean, std=std) ================================================ FILE: xraypulse/processors/randaugment.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import cv2 import numpy as np import torch ## aug functions def identity_func(img): return img def autocontrast_func(img, cutoff=0): """ same output as PIL.ImageOps.autocontrast """ n_bins = 256 def tune_channel(ch): n = ch.size cut = cutoff * n // 100 if cut == 0: high, low = ch.max(), ch.min() else: hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) low = np.argwhere(np.cumsum(hist) > cut) low = 0 if low.shape[0] == 0 else low[0] high = np.argwhere(np.cumsum(hist[::-1]) > cut) high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] if high <= low: table = np.arange(n_bins) else: scale = (n_bins - 1) / (high - low) offset = -low * scale table = np.arange(n_bins) * scale + offset table[table < 0] = 0 table[table > n_bins - 1] = n_bins - 1 table = table.clip(0, 255).astype(np.uint8) return table[ch] channels = [tune_channel(ch) for ch in cv2.split(img)] out = cv2.merge(channels) return out def equalize_func(img): """ same output as PIL.ImageOps.equalize PIL's implementation is different from cv2.equalize """ n_bins = 256 def tune_channel(ch): hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) non_zero_hist = hist[hist != 0].reshape(-1) step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) if step == 0: return ch n = np.empty_like(hist) n[0] = step // 2 n[1:] = hist[:-1] table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) return table[ch] channels = [tune_channel(ch) for ch in cv2.split(img)] out = cv2.merge(channels) return out def rotate_func(img, degree, fill=(0, 0, 0)): """ like PIL, rotate by degree, not radians """ H, W = img.shape[0], img.shape[1] center = W / 2, H / 2 M = cv2.getRotationMatrix2D(center, degree, 1) out = cv2.warpAffine(img, M, (W, H), borderValue=fill) return out def solarize_func(img, thresh=128): """ same output as PIL.ImageOps.posterize """ table = 
np.array([el if el < thresh else 255 - el for el in range(256)]) table = table.clip(0, 255).astype(np.uint8) out = table[img] return out def color_func(img, factor): """ same output as PIL.ImageEnhance.Color """ ## implementation according to PIL definition, quite slow # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] # out = blend(degenerate, img, factor) # M = ( # np.eye(3) * factor # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) # )[np.newaxis, np.newaxis, :] M = np.float32( [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] ) * factor + np.float32([[0.114], [0.587], [0.299]]) out = np.matmul(img, M).clip(0, 255).astype(np.uint8) return out def contrast_func(img, factor): """ same output as PIL.ImageEnhance.Contrast """ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) table = ( np.array([(el - mean) * factor + mean for el in range(256)]) .clip(0, 255) .astype(np.uint8) ) out = table[img] return out def brightness_func(img, factor): """ same output as PIL.ImageEnhance.Contrast """ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) out = table[img] return out def sharpness_func(img, factor): """ The differences the this result and PIL are all on the 4 boundaries, the center areas are same """ kernel = np.ones((3, 3), dtype=np.float32) kernel[1][1] = 5 kernel /= 13 degenerate = cv2.filter2D(img, -1, kernel) if factor == 0.0: out = degenerate elif factor == 1.0: out = img else: out = img.astype(np.float32) degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) out = out.astype(np.uint8) return out def shear_x_func(img, factor, fill=(0, 0, 0)): H, W = img.shape[0], img.shape[1] M = np.float32([[1, factor, 0], [0, 1, 0]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def translate_x_func(img, offset, 
fill=(0, 0, 0)): """ same output as PIL.Image.transform """ H, W = img.shape[0], img.shape[1] M = np.float32([[1, 0, -offset], [0, 1, 0]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def translate_y_func(img, offset, fill=(0, 0, 0)): """ same output as PIL.Image.transform """ H, W = img.shape[0], img.shape[1] M = np.float32([[1, 0, 0], [0, 1, -offset]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def posterize_func(img, bits): """ same output as PIL.ImageOps.posterize """ out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) return out def shear_y_func(img, factor, fill=(0, 0, 0)): H, W = img.shape[0], img.shape[1] M = np.float32([[1, 0, 0], [factor, 1, 0]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def cutout_func(img, pad_size, replace=(0, 0, 0)): replace = np.array(replace, dtype=np.uint8) H, W = img.shape[0], img.shape[1] rh, rw = np.random.random(2) pad_size = pad_size // 2 ch, cw = int(rh * H), int(rw * W) x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) out = img.copy() out[x1:x2, y1:y2, :] = replace return out ### level to args def enhance_level_to_args(MAX_LEVEL): def level_to_args(level): return ((level / MAX_LEVEL) * 1.8 + 0.1,) return level_to_args def shear_level_to_args(MAX_LEVEL, replace_value): def level_to_args(level): level = (level / MAX_LEVEL) * 0.3 if np.random.random() > 0.5: level = -level return (level, replace_value) return level_to_args def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): def level_to_args(level): level = (level / MAX_LEVEL) * float(translate_const) if np.random.random() > 0.5: level = -level return (level, replace_value) return level_to_args def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): def level_to_args(level): level = int((level 
/ MAX_LEVEL) * cutout_const) return (level, replace_value) return level_to_args def solarize_level_to_args(MAX_LEVEL): def level_to_args(level): level = int((level / MAX_LEVEL) * 256) return (level,) return level_to_args def none_level_to_args(level): return () def posterize_level_to_args(MAX_LEVEL): def level_to_args(level): level = int((level / MAX_LEVEL) * 4) return (level,) return level_to_args def rotate_level_to_args(MAX_LEVEL, replace_value): def level_to_args(level): level = (level / MAX_LEVEL) * 30 if np.random.random() < 0.5: level = -level return (level, replace_value) return level_to_args func_dict = { "Identity": identity_func, "AutoContrast": autocontrast_func, "Equalize": equalize_func, "Rotate": rotate_func, "Solarize": solarize_func, "Color": color_func, "Contrast": contrast_func, "Brightness": brightness_func, "Sharpness": sharpness_func, "ShearX": shear_x_func, "TranslateX": translate_x_func, "TranslateY": translate_y_func, "Posterize": posterize_func, "ShearY": shear_y_func, } translate_const = 10 MAX_LEVEL = 10 replace_value = (128, 128, 128) arg_dict = { "Identity": none_level_to_args, "AutoContrast": none_level_to_args, "Equalize": none_level_to_args, "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), "Solarize": solarize_level_to_args(MAX_LEVEL), "Color": enhance_level_to_args(MAX_LEVEL), "Contrast": enhance_level_to_args(MAX_LEVEL), "Brightness": enhance_level_to_args(MAX_LEVEL), "Sharpness": enhance_level_to_args(MAX_LEVEL), "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), "Posterize": posterize_level_to_args(MAX_LEVEL), "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), } class RandomAugment(object): def __init__(self, N=2, M=10, isPIL=False, augs=[]): self.N = N self.M = M self.isPIL = isPIL if augs: self.augs = augs else: self.augs = list(arg_dict.keys()) 
def get_random_ops(self): sampled_ops = np.random.choice(self.augs, self.N) return [(op, 0.5, self.M) for op in sampled_ops] def __call__(self, img): if self.isPIL: img = np.array(img) ops = self.get_random_ops() for name, prob, level in ops: if np.random.random() > prob: continue args = arg_dict[name](level) img = func_dict[name](img, *args) return img class VideoRandomAugment(object): def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): self.N = N self.M = M self.p = p self.tensor_in_tensor_out = tensor_in_tensor_out if augs: self.augs = augs else: self.augs = list(arg_dict.keys()) def get_random_ops(self): sampled_ops = np.random.choice(self.augs, self.N, replace=False) return [(op, self.M) for op in sampled_ops] def __call__(self, frames): assert ( frames.shape[-1] == 3 ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." if self.tensor_in_tensor_out: frames = frames.numpy().astype(np.uint8) num_frames = frames.shape[0] ops = num_frames * [self.get_random_ops()] apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] frames = torch.stack( list(map(self._aug, frames, ops, apply_or_not)), dim=0 ).float() return frames def _aug(self, img, ops, apply_or_not): for i, (name, level) in enumerate(ops): if not apply_or_not[i]: continue args = arg_dict[name](level) img = func_dict[name](img, *args) return torch.from_numpy(img) if __name__ == "__main__": a = RandomAugment() img = np.random.randn(32, 32, 3) a(img) ================================================ FILE: xraypulse/runners/__init__.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. 
SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from xraypulse.runners.runner_base import RunnerBase __all__ = ["RunnerBase"] ================================================ FILE: xraypulse/runners/runner_base.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import datetime import json import logging import os import time from pathlib import Path import torch import torch.distributed as dist import webdataset as wds from xraypulse.common.dist_utils import ( download_cached_file, get_rank, get_world_size, is_main_process, main_process, ) from xraypulse.common.registry import registry from xraypulse.common.utils import is_url from xraypulse.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset from xraypulse.datasets.datasets.dataloader_utils import ( IterLoader, MultiIterLoader, PrefetchLoader, ) from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler @registry.register_runner("runner_base") class RunnerBase: """ A runner class to train and evaluate a model given a task and datasets. The runner uses pytorch distributed data parallel by default. Future release will support other distributed frameworks. 
""" def __init__(self, cfg, task, model, datasets, job_id): self.config = cfg self.job_id = job_id self.task = task self.datasets = datasets self._model = model self._wrapped_model = None self._device = None self._optimizer = None self._scaler = None self._dataloaders = None self._lr_sched = None self.start_epoch = 0 # self.setup_seeds() self.setup_output_dir() @property def device(self): if self._device is None: self._device = torch.device(self.config.run_cfg.device) return self._device @property def use_distributed(self): return self.config.run_cfg.distributed @property def model(self): """ A property to get the DDP-wrapped model on the device. """ # move model to device if self._model.device != self.device: self._model = self._model.to(self.device) # distributed training wrapper if self.use_distributed: if self._wrapped_model is None: self._wrapped_model = DDP( self._model, device_ids=[self.config.run_cfg.gpu] ) else: self._wrapped_model = self._model return self._wrapped_model @property def optimizer(self): # TODO make optimizer class and configurations if self._optimizer is None: num_parameters = 0 p_wd, p_non_wd = [], [] for n, p in self.model.named_parameters(): if not p.requires_grad: continue # frozen weights print(n) if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n: p_non_wd.append(p) else: p_wd.append(p) num_parameters += p.data.nelement() logging.info("number of trainable parameters: %d" % num_parameters) optim_params = [ { "params": p_wd, "weight_decay": float(self.config.run_cfg.weight_decay), }, {"params": p_non_wd, "weight_decay": 0}, ] beta2 = self.config.run_cfg.get("beta2", 0.999) self._optimizer = torch.optim.AdamW( optim_params, lr=float(self.config.run_cfg.init_lr), weight_decay=float(self.config.run_cfg.weight_decay), betas=(0.9, beta2), ) return self._optimizer @property def scaler(self): amp = self.config.run_cfg.get("amp", False) if amp: if self._scaler is None: self._scaler = torch.cuda.amp.GradScaler() return self._scaler @property 
def lr_scheduler(self): """ A property to get and create learning rate scheduler by split just in need. """ if self._lr_sched is None: lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched) # max_epoch = self.config.run_cfg.max_epoch max_epoch = self.max_epoch # min_lr = self.config.run_cfg.min_lr min_lr = self.min_lr # init_lr = self.config.run_cfg.init_lr init_lr = self.init_lr # optional parameters decay_rate = self.config.run_cfg.get("lr_decay_rate", None) warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1) warmup_steps = self.config.run_cfg.get("warmup_steps", 0) iters_per_epoch = self.config.run_cfg.get("iters_per_epoch", None) if iters_per_epoch is None: try: iters_per_epoch = len(self.dataloaders['train']) except (AttributeError, TypeError): iters_per_epoch = 10000 self._lr_sched = lr_sched_cls( optimizer=self.optimizer, max_epoch=max_epoch, iters_per_epoch=iters_per_epoch, min_lr=min_lr, init_lr=init_lr, decay_rate=decay_rate, warmup_start_lr=warmup_start_lr, warmup_steps=warmup_steps, ) return self._lr_sched @property def dataloaders(self) -> dict: """ A property to get and create dataloaders by split just in need. If no train_dataset_ratio is provided, concatenate map-style datasets and chain wds.DataPipe datasets separately. Training set becomes a tuple (ConcatDataset, ChainDataset), both are optional but at least one of them is required. The resultant ConcatDataset and ChainDataset will be sampled evenly. If train_dataset_ratio is provided, create a MultiIterLoader to sample each dataset by ratios during training. Currently do not support multiple datasets for validation and test. Returns: dict: {split_name: (tuples of) dataloader} """ if self._dataloaders is None: # concatenate map-style datasets and chain wds.DataPipe datasets separately # training set becomes a tuple (ConcatDataset, ChainDataset), both are # optional but at least one of them is required. 
The resultant ConcatDataset # and ChainDataset will be sampled evenly. logging.info( "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." ) datasets = reorg_datasets_by_split(self.datasets) self.datasets = datasets # self.datasets = concat_datasets(datasets) # print dataset statistics after concatenation/chaining for split_name in self.datasets: if isinstance(self.datasets[split_name], tuple) or isinstance( self.datasets[split_name], list ): # mixed wds.DataPipeline and torch.utils.data.Dataset num_records = sum( [ len(d) if not type(d) in [wds.DataPipeline, ChainDataset] else 0 for d in self.datasets[split_name] ] ) else: if hasattr(self.datasets[split_name], "__len__"): # a single map-style dataset num_records = len(self.datasets[split_name]) else: # a single wds.DataPipeline num_records = -1 logging.info( "Only a single wds.DataPipeline dataset, no __len__ attribute." ) if num_records >= 0: logging.info( "Loaded {} records for {} split from the dataset.".format( num_records, split_name ) ) # create dataloaders split_names = sorted(self.datasets.keys()) datasets = [self.datasets[split] for split in split_names] is_trains = [split in self.train_splits for split in split_names] batch_sizes = [ self.config.run_cfg.batch_size_train if split == "train" else self.config.run_cfg.batch_size_eval for split in split_names ] collate_fns = [] for dataset in datasets: if isinstance(dataset, tuple) or isinstance(dataset, list): collate_fns.append([getattr(d, "collater", None) for d in dataset]) else: collate_fns.append(getattr(dataset, "collater", None)) dataloaders = self.create_loaders( datasets=datasets, num_workers=self.config.run_cfg.num_workers, batch_sizes=batch_sizes, is_trains=is_trains, collate_fns=collate_fns, ) self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} return self._dataloaders @property def cuda_enabled(self): return self.device.type == "cuda" @property def 
max_epoch(self): return int(self.config.run_cfg.max_epoch) @property def log_freq(self): log_freq = self.config.run_cfg.get("log_freq", 50) return int(log_freq) @property def init_lr(self): return float(self.config.run_cfg.init_lr) @property def min_lr(self): return float(self.config.run_cfg.min_lr) @property def accum_grad_iters(self): return int(self.config.run_cfg.get("accum_grad_iters", 1)) @property def valid_splits(self): valid_splits = self.config.run_cfg.get("valid_splits", []) if len(valid_splits) == 0: logging.info("No validation splits found.") return valid_splits @property def test_splits(self): test_splits = self.config.run_cfg.get("test_splits", []) return test_splits @property def train_splits(self): train_splits = self.config.run_cfg.get("train_splits", []) if len(train_splits) == 0: logging.info("Empty train splits.") return train_splits @property def evaluate_only(self): """ Set to True to skip training. """ return self.config.run_cfg.evaluate @property def use_dist_eval_sampler(self): return self.config.run_cfg.get("use_dist_eval_sampler", True) @property def resume_ckpt_path(self): return self.config.run_cfg.get("resume_ckpt_path", None) @property def train_loader(self): train_dataloader = self.dataloaders["train"] return train_dataloader def setup_output_dir(self): lib_root = Path(registry.get_path("library_root")) output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id result_dir = output_dir / "result" output_dir.mkdir(parents=True, exist_ok=True) result_dir.mkdir(parents=True, exist_ok=True) registry.register_path("result_dir", str(result_dir)) registry.register_path("output_dir", str(output_dir)) self.result_dir = result_dir self.output_dir = output_dir def train(self): start_time = time.time() best_agg_metric = 0 best_epoch = 0 self.log_config() # resume from checkpoint if specified if not self.evaluate_only and self.resume_ckpt_path is not None: self._load_checkpoint(self.resume_ckpt_path) for cur_epoch in 
range(self.start_epoch, self.max_epoch): # training phase if not self.evaluate_only: logging.info("Start training") train_stats = self.train_epoch(cur_epoch) self.log_stats(split_name="train", stats=train_stats) # evaluation phase if len(self.valid_splits) > 0: for split_name in self.valid_splits: logging.info("Evaluating on {}.".format(split_name)) val_log = self.eval_epoch( split_name=split_name, cur_epoch=cur_epoch ) if val_log is not None: if is_main_process(): assert ( "agg_metrics" in val_log ), "No agg_metrics found in validation log." agg_metrics = val_log["agg_metrics"] if agg_metrics > best_agg_metric and split_name == "val": best_epoch, best_agg_metric = cur_epoch, agg_metrics self._save_checkpoint(cur_epoch, is_best=True) val_log.update({"best_epoch": best_epoch}) self.log_stats(val_log, split_name) else: # if no validation split is provided, we just save the checkpoint at the end of each epoch. if not self.evaluate_only: self._save_checkpoint(cur_epoch, is_best=False) if self.evaluate_only: break if self.config.run_cfg.distributed: dist.barrier() # # testing phase # test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch # self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) logging.info("Training time {}".format(total_time_str)) def test(self): start_time = time.time() # resume from checkpoint if specified if not self.evaluate_only and self.resume_ckpt_path is not None: self._load_checkpoint(self.resume_ckpt_path) test_stats = self.test_epoch(1) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Testing time {}".format(total_time_str)) def evaluate(self, ckpt, cur_epoch="best", skip_reload=False): test_logs = dict() if len(self.test_splits) > 0: for split_name in self.test_splits: test_logs[split_name] = self.eval_epoch( ckpt, split_name=split_name, 
cur_epoch=cur_epoch, skip_reload=skip_reload ) return test_logs def train_epoch(self, epoch): # train self.model.train() return self.task.train_epoch( epoch=epoch, model=self.model, data_loader=self.train_loader, optimizer=self.optimizer, scaler=self.scaler, lr_scheduler=self.lr_scheduler, cuda_enabled=self.cuda_enabled, log_freq=self.log_freq, accum_grad_iters=self.accum_grad_iters, ) def test_epoch(self, epoch): # train self.model.eval() return self.task.test_epoch( epoch=epoch, model=self.model, data_loader=self.train_loader, optimizer=self.optimizer, scaler=self.scaler, lr_scheduler=self.lr_scheduler, cuda_enabled=self.cuda_enabled, log_freq=self.log_freq, accum_grad_iters=self.accum_grad_iters, ) @torch.no_grad() def eval_epoch(self, ckpt, split_name, cur_epoch, skip_reload=False): """ Evaluate the model on a given split. Args: split_name (str): name of the split to evaluate on. cur_epoch (int): current epoch. skip_reload_best (bool): whether to skip reloading the best checkpoint. During training, we will reload the best checkpoint for validation. During testing, we will use provided weights and skip reloading the best checkpoint . 
""" data_loader = self.dataloaders.get(split_name, None) assert data_loader, "data_loader for split {} is None.".format(split_name) # TODO In validation, you need to compute loss as well as metrics # TODO consider moving to model.before_evaluation() model = self.unwrap_dist_model(self.model) if not skip_reload and cur_epoch == "best": model = self._reload_model(model, ckpt) model.eval() self.task.before_evaluation( model=model, dataset=self.datasets[split_name], ) results = self.task.evaluation(model, data_loader) if results is not None: return self.task.after_evaluation( val_result=results, split_name=split_name, epoch=cur_epoch, ) def unwrap_dist_model(self, model): if self.use_distributed: return model.module else: return model def create_loaders( self, datasets, num_workers, batch_sizes, is_trains, collate_fns, dataset_ratios=None, ): """ Create dataloaders for training and validation. """ def _create_loader(dataset, num_workers, bsz, is_train, collate_fn): # create a single dataloader for each split if isinstance(dataset, ChainDataset) or isinstance( dataset, wds.DataPipeline ): # wds.WebdDataset instance are chained together # webdataset.DataPipeline has its own sampler and collate_fn loader = iter( DataLoader( dataset, batch_size=bsz, num_workers=num_workers, pin_memory=True, ) ) else: # map-style dataset are concatenated together # setup distributed sampler if self.use_distributed: sampler = DistributedSampler( dataset, shuffle=is_train, num_replicas=get_world_size(), rank=get_rank(), ) if not self.use_dist_eval_sampler: # e.g. 
retrieval evaluation sampler = sampler if is_train else None else: sampler = None loader = DataLoader( dataset, batch_size=bsz, num_workers=num_workers, pin_memory=True, sampler=sampler, shuffle=sampler is None and is_train, collate_fn=collate_fn, drop_last=True if is_train else False, ) loader = PrefetchLoader(loader) if is_train: loader = IterLoader(loader, use_distributed=self.use_distributed) return loader loaders = [] for dataset, bsz, is_train, collate_fn in zip( datasets, batch_sizes, is_trains, collate_fns ): if isinstance(dataset, list) or isinstance(dataset, tuple): if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None: dataset_ratios = [d.sample_ratio for d in dataset] loader = MultiIterLoader( loaders=[ _create_loader(d, num_workers, bsz, is_train, collate_fn[i]) for i, d in enumerate(dataset) ], ratios=dataset_ratios, ) else: loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn) loaders.append(loader) return loaders @main_process def _save_checkpoint(self, cur_epoch, is_best=False): """ Save the checkpoint at the current epoch. """ model_no_ddp = self.unwrap_dist_model(self.model) param_grad_dic = { k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() } state_dict = model_no_ddp.state_dict() for k in list(state_dict.keys()): if k in param_grad_dic.keys() and not param_grad_dic[k]: # delete parameters that do not require gradient del state_dict[k] save_obj = { "model": state_dict, "optimizer": self.optimizer.state_dict(), "config": self.config.to_dict(), "scaler": self.scaler.state_dict() if self.scaler else None, "epoch": cur_epoch, } save_to = os.path.join( self.output_dir, "checkpoint_{}.pth".format("best" if is_best else cur_epoch), ) logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to)) torch.save(save_obj, save_to) def _reload_best_model(self, model): """ Load the best checkpoint for evaluation. 
""" checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth") logging.info("Loading checkpoint from {}.".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path, map_location="cpu") try: model.load_state_dict(checkpoint["model"]) except RuntimeError as e: logging.warning( """ Key mismatch when loading checkpoint. This is expected if only part of the model is saved. Trying to load the model with strict=False. """ ) model.load_state_dict(checkpoint["model"], strict=False) return model def _reload_model(self, model,ckpt): """ Load the best checkpoint for evaluation. """ logging.info("Loading checkpoint from {}.".format(ckpt)) checkpoint = torch.load(ckpt, map_location="cpu") try: model.load_state_dict(checkpoint["model"]) except RuntimeError as e: logging.warning( """ Key mismatch when loading checkpoint. This is expected if only part of the model is saved. Trying to load the model with strict=False. """ ) model.load_state_dict(checkpoint["model"], strict=False) return model def _load_checkpoint(self, url_or_filename): """ Resume from a checkpoint. 
""" if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location=self.device) elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location=self.device) else: raise RuntimeError("checkpoint url or path is invalid") state_dict = checkpoint["model"] self.unwrap_dist_model(self.model).load_state_dict(state_dict,strict=False) self.optimizer.load_state_dict(checkpoint["optimizer"]) if self.scaler and "scaler" in checkpoint: self.scaler.load_state_dict(checkpoint["scaler"]) self.start_epoch = checkpoint["epoch"] + 1 logging.info("Resume checkpoint from {}".format(url_or_filename)) @main_process def log_stats(self, stats, split_name): if isinstance(stats, dict): log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}} with open(os.path.join(self.output_dir, "log.txt"), "a") as f: f.write(json.dumps(log_stats) + "\n") elif isinstance(stats, list): pass @main_process def log_config(self): with open(os.path.join(self.output_dir, "log.txt"), "a") as f: f.write(json.dumps(self.config.to_dict(), indent=4) + "\n") ================================================ FILE: xraypulse/tasks/__init__.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from xraypulse.common.registry import registry from xraypulse.tasks.base_task import BaseTask from xraypulse.tasks.image_text_pretrain import ImageTextPretrainTask def setup_task(cfg): assert "task" in cfg.run_cfg, "Task name must be provided." 
task_name = cfg.run_cfg.task task = registry.get_task_class(task_name).setup_task(cfg=cfg) assert task is not None, "Task {} not properly registered.".format(task_name) return task __all__ = [ "BaseTask", "ImageTextPretrainTask", ] ================================================ FILE: xraypulse/tasks/base_task.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import os import torch import torch.distributed as dist from xraypulse.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized from xraypulse.common.logger import MetricLogger, SmoothedValue from xraypulse.common.registry import registry from xraypulse.datasets.data_utils import prepare_sample import csv #for imprgpt rogue eval class BaseTask: def __init__(self, **kwargs): super().__init__() self.inst_id_key = "instance_id" @classmethod def setup_task(cls, **kwargs): return cls() def build_model(self, cfg): model_config = cfg.model_cfg model_cls = registry.get_model_class(model_config.arch) return model_cls.from_config(model_config) def build_datasets(self, cfg): """ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. Download dataset and annotations automatically if not exist. Args: cfg (common.config.Config): _description_ Returns: dict: Dictionary of torch.utils.data.Dataset objects by split. """ datasets = dict() datasets_config = cfg.datasets_cfg assert len(datasets_config) > 0, "At least one dataset has to be specified." 
for name in datasets_config: dataset_config = datasets_config[name] builder = registry.get_builder_class(name)(dataset_config) dataset = builder.build_datasets() dataset['train'].name = name if 'sample_ratio' in dataset_config: dataset['train'].sample_ratio = dataset_config.sample_ratio datasets[name] = dataset return datasets def train_step(self, model, samples): loss = model(samples)["loss"] return loss def test_step(self, model, samples): output_text, output_token = model.test(samples) return output_text def valid_step(self, model, samples): raise NotImplementedError def before_evaluation(self, model, dataset, **kwargs): model.before_evaluation(dataset=dataset, task_type=type(self)) def after_evaluation(self, **kwargs): pass def inference_step(self): raise NotImplementedError def evaluation(self, model, data_loader, cuda_enabled=True): metric_logger = MetricLogger(delimiter=" ") header = "Evaluation" # TODO make it configurable print_freq = 10 results = [] for samples in metric_logger.log_every(data_loader, print_freq, header): samples = prepare_sample(samples, cuda_enabled=cuda_enabled) eval_output = self.valid_step(model=model, samples=samples) results.extend(eval_output) if is_dist_avail_and_initialized(): dist.barrier() return results def train_epoch( self, epoch, model, data_loader, optimizer, lr_scheduler, scaler=None, cuda_enabled=False, log_freq=50, accum_grad_iters=1, ): return self._train_inner_loop( epoch=epoch, iters_per_epoch=lr_scheduler.iters_per_epoch, model=model, data_loader=data_loader, optimizer=optimizer, scaler=scaler, lr_scheduler=lr_scheduler, log_freq=log_freq, cuda_enabled=cuda_enabled, accum_grad_iters=accum_grad_iters, ) def test_epoch( self, epoch, model, data_loader, optimizer, lr_scheduler, scaler=None, cuda_enabled=False, log_freq=50, accum_grad_iters=1, ): return self._test_inner_loop( epoch=epoch, iters_per_epoch=lr_scheduler.iters_per_epoch, model=model, data_loader=data_loader, optimizer=optimizer, scaler=scaler, 
lr_scheduler=lr_scheduler, log_freq=log_freq, cuda_enabled=cuda_enabled, accum_grad_iters=accum_grad_iters, ) def train_iters( self, epoch, start_iters, iters_per_inner_epoch, model, data_loader, optimizer, lr_scheduler, scaler=None, cuda_enabled=False, log_freq=50, accum_grad_iters=1, ): return self._train_inner_loop( epoch=epoch, start_iters=start_iters, iters_per_epoch=iters_per_inner_epoch, model=model, data_loader=data_loader, optimizer=optimizer, scaler=scaler, lr_scheduler=lr_scheduler, log_freq=log_freq, cuda_enabled=cuda_enabled, accum_grad_iters=accum_grad_iters, ) def test_iters( self, epoch, start_iters, iters_per_inner_epoch, model, data_loader, optimizer, lr_scheduler, scaler=None, cuda_enabled=False, log_freq=50, accum_grad_iters=1, ): return self._test_inner_loop( epoch=epoch, start_iters=start_iters, iters_per_epoch=iters_per_inner_epoch, model=model, data_loader=data_loader, optimizer=optimizer, scaler=scaler, lr_scheduler=lr_scheduler, log_freq=log_freq, cuda_enabled=cuda_enabled, accum_grad_iters=accum_grad_iters, ) def _train_inner_loop( self, epoch, iters_per_epoch, model, data_loader, optimizer, lr_scheduler, scaler=None, start_iters=None, log_freq=50, cuda_enabled=False, accum_grad_iters=1, ): """ An inner training loop compatible with both epoch-based and iter-based training. When using epoch-based, training stops after one epoch; when using iter-based, training stops after #iters_per_epoch iterations. """ use_amp = scaler is not None if not hasattr(data_loader, "__next__"): # convert to iterator if not already data_loader = iter(data_loader) metric_logger = MetricLogger(delimiter=" ") metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) # if iter-based runner, schedule lr based on inner epoch. 
logging.info( "Start training epoch {}, {} iters per inner epoch.".format( epoch, iters_per_epoch ) ) header = "Train: data epoch: [{}]".format(epoch) if start_iters is None: # epoch-based runner inner_epoch = epoch else: # In iter-based runner, we schedule the learning rate based on iterations. inner_epoch = start_iters // iters_per_epoch header = header + "; inner epoch [{}]".format(inner_epoch) for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header): # if using iter-based runner, we stop after iters_per_epoch iterations. if i >= iters_per_epoch: break samples = next(data_loader) samples = prepare_sample(samples, cuda_enabled=cuda_enabled) samples.update( { "epoch": inner_epoch, "num_iters_per_epoch": iters_per_epoch, "iters": i, } ) lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i) with torch.cuda.amp.autocast(enabled=use_amp): loss = self.train_step(model=model, samples=samples) # after_train_step() if use_amp: scaler.scale(loss).backward() else: loss.backward() # update gradients every accum_grad_iters iterations if (i + 1) % accum_grad_iters == 0: if use_amp: scaler.step(optimizer) scaler.update() else: optimizer.step() optimizer.zero_grad() metric_logger.update(loss=loss.item()) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) # after train_epoch() # gather the stats from all processes metric_logger.synchronize_between_processes() logging.info("Averaged stats: " + str(metric_logger.global_avg())) return { k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items() } def _test_inner_loop( self, epoch, iters_per_epoch, model, data_loader, optimizer, lr_scheduler, scaler=None, start_iters=None, log_freq=50, cuda_enabled=False, accum_grad_iters=1, ): """ An inner testing loop . 
""" use_amp = scaler is not None if not hasattr(data_loader, "__next__"): # convert to iterator if not already data_loader = iter(data_loader) for I in range(iters_per_epoch) : samples = next(data_loader) samples = prepare_sample(samples, cuda_enabled=cuda_enabled) with torch.cuda.amp.autocast(enabled=use_amp): output_text = self.test_step(model=model, samples=samples) output_text = output_text.replace("", "") output_text = output_text.replace("Impression:", "") output_text = output_text.replace("\n", "") output_text = output_text.replace("#", "") output_text = output_text.replace("___", "") fields=[str(samples['image_id'][0].tolist()), output_text, samples['caption'][0]] with open(r'/mnt/lustre/huangzhongzhen/pretrain_our/ImprMiniGPT/ours_medclip/minigpt4_stage3_all_v2.2(mimic_chatgpt)_radiology_finetune/20230509154/result/vanilla_minigpt.csv', 'a') as f: writer = csv.writer(f) writer.writerow(fields) print('{}/{}'.format(I,iters_per_epoch)) @staticmethod def save_result(result, result_dir, filename, remove_duplicate=""): import json result_file = os.path.join( result_dir, "%s_rank%d.json" % (filename, get_rank()) ) final_result_file = os.path.join(result_dir, "%s.json" % filename) json.dump(result, open(result_file, "w")) if is_dist_avail_and_initialized(): dist.barrier() if is_main_process(): logging.warning("rank %d starts merging results." 
% get_rank()) # combine results from all processes result = [] for rank in range(get_world_size()): result_file = os.path.join( result_dir, "%s_rank%d.json" % (filename, rank) ) res = json.load(open(result_file, "r")) result += res if remove_duplicate: result_new = [] id_list = [] for res in result: if res[remove_duplicate] not in id_list: id_list.append(res[remove_duplicate]) result_new.append(res) result = result_new json.dump(result, open(final_result_file, "w")) print("result file saved to %s" % final_result_file) return final_result_file ================================================ FILE: xraypulse/tasks/image_text_pretrain.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ from xraypulse.common.registry import registry from xraypulse.tasks.base_task import BaseTask @registry.register_task("image_text_pretrain") class ImageTextPretrainTask(BaseTask): def __init__(self): super().__init__() def evaluation(self, model, data_loader, cuda_enabled=True): pass