[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Learning Multi-dimensional Human Preference for Text-to-Image Generation (CVPR 2024)\nThis repository contains the code and model for the paper [Learning Multi-dimensional Human Preference for Text-to-Image Generation](https://openaccess.thecvf.com/content/CVPR2024/papers/Zhang_Learning_Multi-Dimensional_Human_Preference_for_Text-to-Image_Generation_CVPR_2024_paper.pdf). \n\n<img src=\"./framework.png\" width=\"60%\">\n\n## Installation\nCreate a virtual env and install torch:\n\n```bash\nconda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia\n```\n\nInstall the requirements:\n```bash\npip install -r requirements.txt\npip install -e .\n```\n\n## Inference with MPS\nWe display here an example for running inference with MPS:\n```python\n# import\nfrom io import BytesIO\nfrom transformers import AutoProcessor, AutoModel, AutoTokenizer, CLIPImageProcessor\nfrom PIL import Image\nimport torch\n\n# load model\ndevice = \"cuda\"\nprocessor_name_or_path = \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\"\nimage_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)\ntokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)\n\nmodel_ckpt_path = \"outputs/MPS_overall_checkpoint.pth\"\nmodel = torch.load(model_ckpt_path)\nmodel.eval().to(device)\n\ndef infer_example(images, prompt, condition, clip_model, clip_processor, tokenizer, device):\n    def _process_image(image):\n        if isinstance(image, dict):\n            image = image[\"bytes\"]\n        if isinstance(image, bytes):\n            image = Image.open(BytesIO(image))\n        if isinstance(image, str):\n            image = Image.open( image )\n        image = image.convert(\"RGB\")\n        pixel_values = clip_processor(image, return_tensors=\"pt\")[\"pixel_values\"]\n        return pixel_values\n    \n    def _tokenize(caption):\n        input_ids = tokenizer(\n            caption,\n            max_length=tokenizer.model_max_length,\n            padding=\"max_length\",\n            
truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        return input_ids\n    \n    image_inputs = torch.concatenate([_process_image(images[0]).to(device), _process_image(images[1]).to(device)])\n    text_inputs = _tokenize(prompt).to(device)\n    condition_inputs = _tokenize(condition).to(device)\n\n    with torch.no_grad():\n        text_features, image_0_features, image_1_features = clip_model(text_inputs, image_inputs, condition_inputs)\n        image_0_features = image_0_features / image_0_features.norm(dim=-1, keepdim=True)\n        image_1_features = image_1_features / image_1_features.norm(dim=-1, keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n        image_0_scores = clip_model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_0_features))\n        image_1_scores = clip_model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_1_features))\n        scores = torch.stack([image_0_scores, image_1_scores], dim=-1)\n        probs = torch.softmax(scores, dim=-1)[0]\n\n    return probs.cpu().tolist()\n\nimg_0, img_1 = \"image1.jpg\", \"image2.jpg\"\n# infer the best image for the caption\nprompt = \"the caption of image\" \n\n# condition for overall\ncondition = \"light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things.\" \n\nprint(infer_example([img_0, img_1], prompt, condition, model, image_processor, tokenizer, device))\n```\n\n## Download the MPS checkpoint\n<table>\n  <tr>\n    <th rowspan=\"2\" text-align=\"center\">ID</th>\n    <th colspan=\"4\" text-align=\"center\">Training Data</th>\n    <th rowspan=\"2\" text-align=\"center\">MPS Model</th>\n  </tr>\n  <tr>\n    <th text-align=\"center\">Overall</th>\n    <th text-align=\"center\">Aesthetics</th>\n    <th text-align=\"center\">Alignment</th>\n    <th 
text-align=\"center\">Detail</th>\n  </tr>\n  <tr text-align=\"center\">\n    <td>&nbsp;1</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&#10003;</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;-</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;-</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;-</td>\n    <td>&nbsp;<a href=\"http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing\">Model Link</a></td>\n  </tr>\n  <tr>\n    <td>&nbsp;2</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&#10003;</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&#10003;</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&#10003;</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&#10003;</td>\n    <td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;-</td>\n  </tr>\n</table>\n\nDue to the internal model approval process within the company, we only release MPS trained on overall preference, while MPS trained on multi-dimensional human preferences will be open-sourced once it passes the approval process; however, there is a risk of delays and the possibility of force majeure events.\n(Move the checkpoint file to `outputs/MPS_overall_checkpoint.pth`)\n\n\n## Evaluation\nTest MPS on ImageReward benchmark:\n\nPlease download the file, `datasets/test.json` to `imagereward/test.json` from [ImageReward](https://github.com/kekewind/ImageReward) and the related images from [ImageRewardDB](https://huggingface.co/datasets/THUDM/ImageRewardDB) as well.\n```bash\n python eval_overall_mps_on_imagereward.py\n```\nTest MPS on hpd_v2 benchmark:\n\nPlease download the annotation file, `test.json` to `hpdv2/test.json` and the related images (test dataset) from [HPDv2](https://huggingface.co/datasets/ymhao/HPDv2/tree/main).\n```bash\n python eval_overall_mps_on_hpdv2.py\n```\n\n## Results on different datasets\n| ID  | Preference Model     | ImageReward | HPD v2 | MHP (Overall) |\n|:-:|:-:|:-:|:-:|:-:|\n| 1   | CLIP score           | 54.3        | 71.2   | 63.7          |\n| 2   | 
Aesthetic Score      | 57.4        | 72.6   | 62.9          |\n| 3   | ImageReward          | 65.1        | 70.6   | 67.5          |\n| 4   | HPS                  | 61.2        | 73.1   | 65.5          |\n| 5   | PickScore            | 62.9        | 79.8   | 69.5          |\n| 6   | HPS v2               | 65.7        | 83.3   | 65.5          |\n| 7   | **MPS (Ours)**           | **67.5**        | **83.5**   | **74.2**          |\n\n\n## Citation\nIf you find this work useful, please cite:\n```bibtex\n@inproceedings{MPS,\n  title={Learning Multi-dimensional Human Preference for Text-to-Image Generation},\n  author={Zhang, Sixian and Wang, Bohan and Wu, Junqiang and Li, Yan and Gao, Tingting and Zhang, Di and Wang, Zhongyuan},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={8018--8027},\n  year={2024}\n}\n```\n\n## Acknowledgments\nWe thank the authors of [ImageReward](https://github.com/kekewind/ImageReward), [HPS](https://github.com/tgxs002/align_sd), [HPS v2](https://github.com/tgxs002/HPSv2), and [PickScore](https://github.com/yuvalkirstain/PickScore) for their codes and papers, which greatly contributed to our work.\n"
  },
  {
    "path": "eval_overall_mps_on_hpdv2.py",
    "content": "import numpy as np\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\nfrom tqdm.auto import tqdm\nfrom fire import Fire\nfrom transformers import CLIPFeatureExtractor, CLIPImageProcessor\n\nfrom dataclasses import dataclass\nfrom transformers import CLIPModel as HFCLIPModel\n\nfrom torch import nn, einsum\n\nfrom trainer.models.base_model import BaseModelConfig\n\nfrom transformers import CLIPConfig\nfrom transformers import AutoProcessor, AutoModel, AutoTokenizer\nfrom typing import Any, Optional, Tuple, Union\nimport torch\nimport cv2\nimport os\n\nfrom trainer.models.cross_modeling import Cross_model\nimport matplotlib.pyplot as plt\nimport torch.nn.functional as F\n\nimport gc\nimport json\n\n\n@torch.no_grad()\n\ndef infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device, condition=None):\n    def _process_image(image):\n        if isinstance(image, dict):\n            image = image[\"bytes\"]\n        if isinstance(image, bytes):\n            image = Image.open(BytesIO(image))\n        if isinstance(image, str):\n            image = Image.open( image )\n        image = image.convert(\"RGB\")\n        pixel_values = clip_processor(image, return_tensors=\"pt\")[\"pixel_values\"]\n        return pixel_values\n    \n    def _tokenize(caption):\n        input_ids = tokenizer(\n            caption,\n            max_length=tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        return input_ids\n    \n    image_input = _process_image(image).to(device)\n    text_input = _tokenize(prompt).to(device)\n    if condition is None:\n        condition = \"light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things.\"\n    condition_batch = _tokenize(condition).repeat(text_input.shape[0],1).to(device)\n\n 
   with torch.no_grad():\n        text_f, text_features = clip_model.model.get_text_features(text_input)\n\n        image_f = clip_model.model.get_image_features(image_input.half())\n        condition_f, _ = clip_model.model.get_text_features(condition_batch)\n\n        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)\n        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]\n        sim_text_condition = sim_text_condition / sim_text_condition.max()\n        mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))\n        mask = mask.repeat(1,image_f.shape[1],1)\n        image_features = clip_model.cross_model(image_f, text_f,mask.half())[:,0,:]\n\n        image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n        image_score = clip_model.logit_scale.exp() * text_features @ image_features.T\n\n    return image_score[0]\n\ndef infer_example(images, prompt, clip_model, clip_processor, tokenizer, device):\n    scores = []\n    for image in images:\n        score = infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device)\n        scores.append(score)\n    scores = torch.stack(scores, dim=-1)\n    probs = torch.softmax(scores, dim=-1)[0]\n    return probs.cpu().tolist()\n\ndef acc(score_sample, predict_sample):\n    tol_cnt = 0.\n    true_cnt = 0.\n    for idx in range(len(score_sample)):\n        item_base = score_sample[idx][\"rank\"]\n        item = predict_sample[idx][\"rewards\"]\n        for i in range(len(item_base)):\n            for j in range(i+1, len(item_base)):\n                if item_base[i] > item_base[j]:\n                    if item[i] >= item[j]:\n                        tol_cnt += 1\n                    elif item[i] < item[j]:\n                        tol_cnt += 1\n                        true_cnt += 1\n                elif item_base[i] < item_base[j]:\n             
       if item[i] > item[j]:\n                        tol_cnt += 1\n                        true_cnt += 1\n                    elif item[i] <= item[j]:\n                        tol_cnt += 1\n    return true_cnt / tol_cnt\n\ndef inversion_score(predict_sample, score_sample):\n    n = len(score_sample)\n    cnt = 0\n    for i in range(n-1):\n        for j in range(i+1, n):\n            if score_sample[i] > score_sample[j] and predict_sample[i] > predict_sample[j]:\n                cnt += 1\n            elif score_sample[i] < score_sample[j] and predict_sample[i] < predict_sample[j]:\n                cnt += 1\n    return 1 - cnt / (n * (n - 1) / 2)\n\ndef main():\n    processor_name_or_path = \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\"\n\n    device = \"cuda\"\n    image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)\n    tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)\n\n    model_ckpt_path = \"outputs/MPS_overall_checkpoint.pth\"\n    model = torch.load(model_ckpt_path)\n    model.eval().to(device)\n\n    score_sample = []\n    with open(\"hpdv2/test.json\", \"r\") as f:\n        score_sample = json.load(f)\n    \n    predict_sample = []\n    score = 0.\n    with torch.no_grad():\n        for i in range(len(score_sample)):\n            item = score_sample[i]\n            rewards = infer_example(item[\"image_path\"], item[\"prompt\"], model, image_processor, tokenizer, device)\n            score += inversion_score(rewards, item['rank'])\n    test_acc = score / len(score_sample)\n    print(f\"HPDv2 Test Acc: {100 * test_acc:.2f}%\")\n\n\nif __name__ == '__main__':\n    Fire(main)\n"
  },
  {
    "path": "eval_overall_mps_on_imagereward.py",
    "content": "import numpy as np\n# from transformers import AutoProcessor #, AutoModel\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\nfrom tqdm.auto import tqdm\nfrom fire import Fire\nfrom transformers import CLIPFeatureExtractor, CLIPImageProcessor\n\nfrom dataclasses import dataclass\nfrom transformers import CLIPModel as HFCLIPModel\n\nfrom torch import nn, einsum\n\nfrom trainer.models.base_model import BaseModelConfig\n\nfrom transformers import CLIPConfig\nfrom transformers import AutoProcessor, AutoModel, AutoTokenizer\nfrom typing import Any, Optional, Tuple, Union\nimport torch\nimport cv2\nimport os\n\nfrom trainer.models.cross_modeling import Cross_model\nimport matplotlib.pyplot as plt\nimport torch.nn.functional as F\n\nimport gc\nimport json\n\n\n@torch.no_grad()\n\ndef infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device, condition=None):\n    def _process_image(image):\n        if isinstance(image, dict):\n            image = image[\"bytes\"]\n        if isinstance(image, bytes):\n            image = Image.open(BytesIO(image))\n        if isinstance(image, str):\n            image = Image.open( image )\n        image = image.convert(\"RGB\")\n        pixel_values = clip_processor(image, return_tensors=\"pt\")[\"pixel_values\"]\n        return pixel_values\n    \n    def _tokenize(caption):\n        input_ids = tokenizer(\n            caption,\n            max_length=tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\"\n        ).input_ids\n        return input_ids\n\n    image_input = _process_image(image).to(device)\n    text_input = _tokenize(prompt).to(device)\n    if condition is None:\n        condition = \"light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things.\"\n    condition_batch = 
_tokenize(condition).repeat(text_input.shape[0],1).to(device)\n\n    with torch.no_grad():\n        text_f, text_features = clip_model.model.get_text_features(text_input)\n\n        image_f = clip_model.model.get_image_features(image_input.half())\n        condition_f, _ = clip_model.model.get_text_features(condition_batch)\n\n        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)\n        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]\n        sim_text_condition = sim_text_condition / sim_text_condition.max()\n        mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))\n        mask = mask.repeat(1,image_f.shape[1],1)\n        image_features = clip_model.cross_model(image_f, text_f,mask.half())[:,0,:]\n\n        image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n        image_score = clip_model.logit_scale.exp() * text_features @ image_features.T\n    return image_score[0]\n\ndef infer_example(images, prompt, clip_model, clip_processor, tokenizer, device):\n    scores = []\n    for image in images:\n        score = infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device)\n        scores.append(score)\n    scores = torch.stack(scores, dim=-1)\n    probs = torch.softmax(scores, dim=-1)[0]\n    return probs.cpu().tolist()\n\n\ndef acc(score_sample, predict_sample):\n    tol_cnt = 0.\n    true_cnt = 0.\n    for idx in range(len(score_sample)):\n        item_base = score_sample[idx][\"ranking\"]\n        item = predict_sample[idx][\"rewards\"]\n        for i in range(len(item_base)):\n            for j in range(i+1, len(item_base)):\n                if item_base[i] > item_base[j]:\n                    if item[i] >= item[j]:\n                        tol_cnt += 1\n                    elif item[i] < item[j]:\n                        tol_cnt += 1\n                        true_cnt 
+= 1\n                elif item_base[i] < item_base[j]:\n                    if item[i] > item[j]:\n                        tol_cnt += 1\n                        true_cnt += 1\n                    elif item[i] <= item[j]:\n                        tol_cnt += 1\n    return true_cnt / tol_cnt\n                \n\ndef main():\n    processor_name_or_path = \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\"\n\n    device = \"cuda\"\n    image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)\n    tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)\n\n    model_ckpt_path = \"outputs/MPS_overall_checkpoint.pth\"\n    model = torch.load(model_ckpt_path)\n    model.eval().to(device)\n\n    score_sample = []\n    with open(\"imagereward/test.json\", \"r\") as f: # change the path to the ImageReward test dataset\n        score_sample = json.load(f)\n    \n    predict_sample = []\n    with torch.no_grad():\n        for item in score_sample:\n            rewards = infer_example(item[\"generations\"], item[\"prompt\"], model, image_processor, tokenizer, device)\n            predict_item = {\n                \"id\": item[\"id\"],\n                \"prompt\": item[\"prompt\"],\n                \"rewards\": rewards\n            }\n            predict_sample.append(predict_item)\n    test_acc = acc(score_sample, predict_sample)\n    print(f\"ImageReward Test Acc: {100 * test_acc:.2f}%\")\n\n\nif __name__ == '__main__':\n    Fire(main)\n"
  },
  {
    "path": "requirements.txt",
    "content": "accelerate @ git+https://github.com/huggingface/accelerate.git@d1aa558119859c4b205a324afabaecabd9ef375e\ndatasets==2.10.1\ndeepspeed==0.8.3\nfire==0.4.0\nhydra-core==1.3.2\nrich==13.3.2\nsubmitit==1.4.5\ntransformers==4.27.3\nwandb==0.12.21"
  },
  {
    "path": "trainer/models/base_model.py",
    "content": "from dataclasses import dataclass\n\n\n\n@dataclass\nclass BaseModelConfig:\n    pass\n"
  },
  {
    "path": "trainer/models/clip_model.py",
    "content": "from dataclasses import dataclass\nfrom transformers import CLIPModel as HFCLIPModel\nfrom transformers import AutoTokenizer\n\nfrom torch import nn, einsum\n\nfrom trainer.models.base_model import BaseModelConfig\n\nfrom transformers import CLIPConfig\nfrom typing import Any, Optional, Tuple, Union\nimport torch\n\nfrom trainer.models.cross_modeling import Cross_model\n\nimport gc\n\nclass XCLIPModel(HFCLIPModel):\n    def __init__(self, config: CLIPConfig):\n        super().__init__(config)\n    \n    def get_text_features(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> torch.FloatTensor:\n\n        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        text_outputs = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        # pooled_output = text_outputs[1]\n        # text_features = self.text_projection(pooled_output)\n        last_hidden_state = text_outputs[0]\n        text_features = self.text_projection(last_hidden_state)\n\n        pooled_output = text_outputs[1]\n        text_features_EOS = 
self.text_projection(pooled_output)\n\n\n        # del last_hidden_state, text_outputs\n        # gc.collect()\n\n        return text_features, text_features_EOS\n\n    def get_image_features(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> torch.FloatTensor:\n        \n        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        vision_outputs = self.vision_model(\n            pixel_values=pixel_values,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        # pooled_output = vision_outputs[1]  # pooled_output\n        # image_features = self.visual_projection(pooled_output)\n        last_hidden_state = vision_outputs[0]\n        image_features = self.visual_projection(last_hidden_state)\n\n        return image_features\n\n\n\n@dataclass\nclass ClipModelConfig(BaseModelConfig):\n    _target_: str = \"trainer.models.clip_model.CLIPModel\"\n    pretrained_model_name_or_path: str =\"openai/clip-vit-base-patch32\"\n\n\nclass CLIPModel(nn.Module):\n    def __init__(self, ckpt):\n        super().__init__()\n        self.model = XCLIPModel.from_pretrained(ckpt)\n        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)\n    \n    def get_text_features(self, *args, **kwargs):\n        return self.model.get_text_features(*args, **kwargs)\n\n    def 
get_image_features(self, *args, **kwargs):\n        return self.model.get_image_features(*args, **kwargs)\n\n    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):\n        outputs = ()\n\n        text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024\n        outputs += text_EOS,\n\n        image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024\n        condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024\n\n        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)\n        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]\n        sim_text_condition = sim_text_condition / sim_text_condition.max()\n        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77\n\n        mask = mask.repeat(1,image_f.shape[1],1) # B*257*77\n        bc = int(image_f.shape[0]/2)\n\n        sim0 = self.cross_model(image_f[:bc,:,:], text_f,mask.half())\n        sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half())\n        outputs += sim0[:,0,:],\n        outputs += sim1[:,0,:],\n\n        return outputs\n\n    @property\n    def logit_scale(self):\n        return self.model.logit_scale\n\n    def save(self, path):\n        self.model.save_pretrained(path)\n\n"
  },
  {
    "path": "trainer/models/cross_modeling.py",
    "content": "import torch\r\nfrom torch import einsum, nn\r\nimport torch.nn.functional as F\r\nfrom einops import rearrange, repeat\r\n\r\n# helper functions\r\n\r\ndef exists(val):\r\n    return val is not None\r\n\r\ndef default(val, d):\r\n    return val if exists(val) else d\r\n\r\n# normalization\r\n# they use layernorm without bias, something that pytorch does not offer\r\n\r\n\r\nclass LayerNorm(nn.Module):\r\n    def __init__(self, dim):\r\n        super().__init__()\r\n        self.weight = nn.Parameter(torch.ones(dim))\r\n        self.register_buffer(\"bias\", torch.zeros(dim))\r\n\r\n    def forward(self, x):\r\n        return F.layer_norm(x, x.shape[-1:], self.weight, self.bias)\r\n\r\n# residual\r\n\r\n\r\nclass Residual(nn.Module):\r\n    def __init__(self, fn):\r\n        super().__init__()\r\n        self.fn = fn\r\n\r\n    def forward(self, x, *args, **kwargs):\r\n        return self.fn(x, *args, **kwargs) + x\r\n\r\n\r\n# rotary positional embedding\r\n# https://arxiv.org/abs/2104.09864\r\n\r\n\r\nclass RotaryEmbedding(nn.Module):\r\n    def __init__(self, dim):\r\n        super().__init__()\r\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))\r\n        self.register_buffer(\"inv_freq\", inv_freq)\r\n\r\n    def forward(self, max_seq_len, *, device):\r\n        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)\r\n        freqs = einsum(\"i , j -> i j\", seq, self.inv_freq)\r\n        return torch.cat((freqs, freqs), dim=-1)\r\n\r\n\r\ndef rotate_half(x):\r\n    x = rearrange(x, \"... (j d) -> ... 
j d\", j=2)\r\n    x1, x2 = x.unbind(dim=-2)\r\n    return torch.cat((-x2, x1), dim=-1)\r\n\r\n\r\ndef apply_rotary_pos_emb(pos, t):\r\n    return (t * pos.cos()) + (rotate_half(t) * pos.sin())\r\n\r\n\r\n# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward\r\n# https://arxiv.org/abs/2002.05202\r\n\r\n\r\nclass SwiGLU(nn.Module):\r\n    def forward(self, x):\r\n        x, gate = x.chunk(2, dim=-1)\r\n        return F.silu(gate) * x\r\n\r\n\r\n# parallel attention and feedforward with residual\r\n# discovered by Wang et al + EleutherAI from GPT-J fame\r\n\r\nclass ParallelTransformerBlock(nn.Module):\r\n    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):\r\n        super().__init__()\r\n        self.norm = LayerNorm(dim)\r\n\r\n        attn_inner_dim = dim_head * heads\r\n        ff_inner_dim = dim * ff_mult\r\n        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))\r\n\r\n        self.heads = heads\r\n        self.scale = dim_head**-0.5\r\n        self.rotary_emb = RotaryEmbedding(dim_head)\r\n\r\n        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)\r\n        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)\r\n\r\n        self.ff_out = nn.Sequential(\r\n            SwiGLU(),\r\n            nn.Linear(ff_inner_dim, dim, bias=False)\r\n        )\r\n\r\n        self.register_buffer(\"pos_emb\", None, persistent=False)\r\n\r\n\r\n    def get_rotary_embedding(self, n, device):\r\n        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:\r\n            return self.pos_emb[:n]\r\n\r\n        pos_emb = self.rotary_emb(n, device=device)\r\n        self.register_buffer(\"pos_emb\", pos_emb, persistent=False)\r\n        return pos_emb\r\n\r\n    def forward(self, x, attn_mask=None):\r\n        \"\"\"\r\n        einstein notation\r\n        b - batch\r\n        h - heads\r\n        n, i, j - sequence length (base sequence 
length, source, target)\r\n        d - feature dimension\r\n        \"\"\"\r\n\r\n        n, device, h = x.shape[1], x.device, self.heads\r\n\r\n        # pre layernorm\r\n\r\n        x = self.norm(x)\r\n\r\n        # attention queries, keys, values, and feedforward inner\r\n\r\n        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)\r\n\r\n        # split heads\r\n        # they use multi-query single-key-value attention, yet another Noam Shazeer paper\r\n        # they found no performance loss past a certain scale, and more efficient decoding obviously\r\n        # https://arxiv.org/abs/1911.02150\r\n\r\n        q = rearrange(q, \"b n (h d) -> b h n d\", h=h)\r\n\r\n        # rotary embeddings\r\n\r\n        positions = self.get_rotary_embedding(n, device)\r\n        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))\r\n\r\n        # scale\r\n\r\n        q = q * self.scale\r\n\r\n        # similarity\r\n\r\n        sim = einsum(\"b h i d, b j d -> b h i j\", q, k)\r\n\r\n\r\n        # extra attention mask - for masking out attention from text CLS token to padding\r\n\r\n        if exists(attn_mask):\r\n            attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')\r\n            sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)\r\n\r\n        # attention\r\n\r\n        sim = sim - sim.amax(dim=-1, keepdim=True).detach()\r\n        attn = sim.softmax(dim=-1)\r\n\r\n        # aggregate values\r\n\r\n        out = einsum(\"b h i j, b j d -> b h i d\", attn, v)\r\n\r\n        # merge heads\r\n\r\n        out = rearrange(out, \"b h n d -> b n (h d)\")\r\n        return self.attn_out(out) + self.ff_out(ff)\r\n\r\n# cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward\r\n\r\nclass CrossAttention(nn.Module):\r\n    def __init__(\r\n        self,\r\n        dim,\r\n        *,\r\n        context_dim=None,\r\n        dim_head=64,\r\n        heads=12,\r\n        
parallel_ff=False,\r\n        ff_mult=4,\r\n        norm_context=False\r\n    ):\r\n        super().__init__()\r\n        self.heads = heads\r\n        self.scale = dim_head ** -0.5\r\n        inner_dim = heads * dim_head\r\n        context_dim = default(context_dim, dim)\r\n\r\n        self.norm = LayerNorm(dim)\r\n        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()\r\n\r\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\r\n        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)\r\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\r\n\r\n        # whether to have parallel feedforward\r\n\r\n        ff_inner_dim = ff_mult * dim\r\n\r\n        self.ff = nn.Sequential(\r\n            nn.Linear(dim, ff_inner_dim * 2, bias=False),\r\n            SwiGLU(),\r\n            nn.Linear(ff_inner_dim, dim, bias=False)\r\n        ) if parallel_ff else None\r\n\r\n    def forward(self, x, context, mask):\r\n        \"\"\"\r\n        einstein notation\r\n        b - batch\r\n        h - heads\r\n        n, i, j - sequence length (base sequence length, source, target)\r\n        d - feature dimension\r\n        \"\"\"\r\n\r\n        # pre-layernorm, for queries and context\r\n\r\n        x = self.norm(x)\r\n        context = self.context_norm(context)\r\n\r\n        # get queries\r\n\r\n        q = self.to_q(x)\r\n        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)\r\n\r\n        # scale\r\n\r\n        q = q * self.scale\r\n\r\n        # get key / values\r\n\r\n        k, v = self.to_kv(context).chunk(2, dim=-1)\r\n\r\n        # query / key similarity\r\n\r\n        sim = einsum('b h i d, b j d -> b h i j', q, k)\r\n\r\n        # attention\r\n        mask = mask.unsqueeze(1).repeat(1,self.heads,1,1)\r\n        sim = sim + mask  # context mask\r\n        sim = sim - sim.amax(dim=-1, keepdim=True)\r\n        attn = sim.softmax(dim=-1)\r\n\r\n        # aggregate\r\n\r\n        out = einsum('b h i j, b j d 
-> b h i d', attn, v)\r\n\r\n        # merge and combine heads\r\n\r\n        out = rearrange(out, 'b h n d -> b n (h d)')\r\n        out = self.to_out(out)\r\n\r\n        # add parallel feedforward (for multimodal layers)\r\n\r\n        if exists(self.ff):\r\n            out = out + self.ff(x)\r\n\r\n        return out\r\n\r\n\r\nclass Cross_model(nn.Module):\r\n    def __init__(\r\n        self,\r\n        dim=512,\r\n        layer_num=4,\r\n        dim_head=64,\r\n        heads=8,\r\n        ff_mult=4\r\n    ):\r\n        super().__init__()\r\n\r\n        self.layers = nn.ModuleList([])\r\n\r\n\r\n        for ind in range(layer_num):\r\n            self.layers.append(nn.ModuleList([\r\n                Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)),\r\n                Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult))\r\n            ]))\r\n\r\n    def forward(\r\n        self,\r\n        query_tokens,\r\n        context_tokens,\r\n        mask\r\n    ):\r\n\r\n        for cross_attn, self_attn_ff in self.layers:\r\n            query_tokens = cross_attn(query_tokens, context_tokens,mask)\r\n            query_tokens = self_attn_ff(query_tokens)\r\n\r\n        return query_tokens\r\n"
  }
]