Repository: Kwai-Kolors/MPS Branch: main Commit: be9027f3e909 Files: 8 Total size: 31.7 KB Directory structure: gitextract_7pb8rbkq/ ├── LICENSE ├── README.md ├── eval_overall_mps_on_hpdv2.py ├── eval_overall_mps_on_imagereward.py ├── requirements.txt └── trainer/ └── models/ ├── base_model.py ├── clip_model.py └── cross_modeling.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2021 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ # Learning Multi-dimensional Human Preference for Text-to-Image Generation (CVPR 2024) This repository contains the code and model for the paper [Learning Multi-dimensional Human Preference for Text-to-Image Generation](https://openaccess.thecvf.com/content/CVPR2024/papers/Zhang_Learning_Multi-Dimensional_Human_Preference_for_Text-to-Image_Generation_CVPR_2024_paper.pdf). ## Installation Create a virtual env and install torch: ```bash conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia ``` Install the requirements: ```bash pip install -r requirements.txt pip install -e . ``` ## Inference with MPS Here is an example of running inference with MPS: ```python # import from io import BytesIO from transformers import AutoTokenizer, CLIPImageProcessor from PIL import Image import torch # load model device = "cuda" processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path) tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True) model_ckpt_path = "outputs/MPS_overall_checkpoint.pth" model = torch.load(model_ckpt_path, weights_only=False) model.eval().to(device) def infer_example(images, prompt, condition, clip_model, clip_processor, tokenizer, device): def _process_image(image): if isinstance(image, dict): image = image["bytes"] if isinstance(image, bytes): image = Image.open(BytesIO(image)) if isinstance(image, str): image = Image.open( image ) image = image.convert("RGB") pixel_values = clip_processor(image, return_tensors="pt")["pixel_values"] return pixel_values def _tokenize(caption): input_ids = tokenizer( caption, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids return input_ids image_inputs = torch.concatenate([_process_image(images[0]).to(device), _process_image(images[1]).to(device)])
text_inputs = _tokenize(prompt).to(device) condition_inputs = _tokenize(condition).to(device) with torch.no_grad(): text_features, image_0_features, image_1_features = clip_model(text_inputs, image_inputs, condition_inputs) image_0_features = image_0_features / image_0_features.norm(dim=-1, keepdim=True) image_1_features = image_1_features / image_1_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) image_0_scores = clip_model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_0_features)) image_1_scores = clip_model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_1_features)) scores = torch.stack([image_0_scores, image_1_scores], dim=-1) probs = torch.softmax(scores, dim=-1)[0] return probs.cpu().tolist() img_0, img_1 = "image1.jpg", "image2.jpg" # infer the best image for the caption prompt = "the caption of image" # condition for overall condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things." print(infer_example([img_0, img_1], prompt, condition, model, image_processor, tokenizer, device)) ``` ## Download the MPS checkpoint
Training data used for each MPS model:

| ID | Overall | Aesthetics | Alignment | Detail | MPS Model |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | ✓ | - | - | - | Model Link |
| 2 | ✓ | ✓ | ✓ | ✓ | - |
Due to the internal model approval process within the company, we only release MPS trained on overall preference, while MPS trained on multi-dimensional human preferences will be open-sourced once it passes the approval process; however, there is a risk of delays and the possibility of force majeure events. (Move the checkpoint file to `outputs/MPS_overall_checkpoint.pth`.) ## Evaluation Test MPS on the ImageReward benchmark: Please download the file `datasets/test.json` from [ImageReward](https://github.com/kekewind/ImageReward) to `imagereward/test.json`, and the related images from [ImageRewardDB](https://huggingface.co/datasets/THUDM/ImageRewardDB) as well. ```bash python eval_overall_mps_on_imagereward.py ``` Test MPS on the HPDv2 benchmark: Please download the annotation file `test.json` to `hpdv2/test.json`, and the related images (test dataset) from [HPDv2](https://huggingface.co/datasets/ymhao/HPDv2/tree/main). ```bash python eval_overall_mps_on_hpdv2.py ``` ## Results on different datasets | ID | Preference Model | ImageReward | HPD v2 | MHP (Overall) | |:-:|:-:|:-:|:-:|:-:| | 1 | CLIP score | 54.3 | 71.2 | 63.7 | | 2 | Aesthetic Score | 57.4 | 72.6 | 62.9 | | 3 | ImageReward | 65.1 | 70.6 | 67.5 | | 4 | HPS | 61.2 | 73.1 | 65.5 | | 5 | PickScore | 62.9 | 79.8 | 69.5 | | 6 | HPS v2 | 65.7 | 83.3 | 65.5 | | 7 | **MPS (Ours)** | **67.5** | **83.5** | **74.2** | ## Citation If you find this work useful, please cite: ```bibtex @inproceedings{MPS, title={Learning Multi-dimensional Human Preference for Text-to-Image Generation}, author={Zhang, Sixian and Wang, Bohan and Wu, Junqiang and Li, Yan and Gao, Tingting and Zhang, Di and Wang, Zhongyuan}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={8018--8027}, year={2024} } ``` ## Acknowledgments We thank the authors of [ImageReward](https://github.com/kekewind/ImageReward), [HPS](https://github.com/tgxs002/align_sd), [HPS v2](https://github.com/tgxs002/HPSv2), and
[PickScore](https://github.com/yuvalkirstain/PickScore) for their codes and papers, which greatly contributed to our work. ================================================ FILE: eval_overall_mps_on_hpdv2.py ================================================ import numpy as np import torch from PIL import Image from io import BytesIO from tqdm.auto import tqdm from fire import Fire from transformers import CLIPFeatureExtractor, CLIPImageProcessor from dataclasses import dataclass from transformers import CLIPModel as HFCLIPModel from torch import nn, einsum from trainer.models.base_model import BaseModelConfig from transformers import CLIPConfig from transformers import AutoProcessor, AutoModel, AutoTokenizer from typing import Any, Optional, Tuple, Union import torch import cv2 import os from trainer.models.cross_modeling import Cross_model import matplotlib.pyplot as plt import torch.nn.functional as F import gc import json @torch.no_grad() def infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device, condition=None): def _process_image(image): if isinstance(image, dict): image = image["bytes"] if isinstance(image, bytes): image = Image.open(BytesIO(image)) if isinstance(image, str): image = Image.open( image ) image = image.convert("RGB") pixel_values = clip_processor(image, return_tensors="pt")["pixel_values"] return pixel_values def _tokenize(caption): input_ids = tokenizer( caption, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids return input_ids image_input = _process_image(image).to(device) text_input = _tokenize(prompt).to(device) if condition is None: condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things." 
condition_batch = _tokenize(condition).repeat(text_input.shape[0],1).to(device) with torch.no_grad(): text_f, text_features = clip_model.model.get_text_features(text_input) image_f = clip_model.model.get_image_features(image_input.half()) condition_f, _ = clip_model.model.get_text_features(condition_batch) sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f) sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0] sim_text_condition = sim_text_condition / sim_text_condition.max() mask = torch.where(sim_text_condition > 0.3, 0, float('-inf')) mask = mask.repeat(1,image_f.shape[1],1) image_features = clip_model.cross_model(image_f, text_f,mask.half())[:,0,:] image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) image_score = clip_model.logit_scale.exp() * text_features @ image_features.T return image_score[0] def infer_example(images, prompt, clip_model, clip_processor, tokenizer, device): scores = [] for image in images: score = infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device) scores.append(score) scores = torch.stack(scores, dim=-1) probs = torch.softmax(scores, dim=-1)[0] return probs.cpu().tolist() def acc(score_sample, predict_sample): tol_cnt = 0. true_cnt = 0. 
for idx in range(len(score_sample)): item_base = score_sample[idx]["rank"] item = predict_sample[idx]["rewards"] for i in range(len(item_base)): for j in range(i+1, len(item_base)): if item_base[i] > item_base[j]: if item[i] >= item[j]: tol_cnt += 1 elif item[i] < item[j]: tol_cnt += 1 true_cnt += 1 elif item_base[i] < item_base[j]: if item[i] > item[j]: tol_cnt += 1 true_cnt += 1 elif item[i] <= item[j]: tol_cnt += 1 return true_cnt / tol_cnt def inversion_score(predict_sample, score_sample): n = len(score_sample) cnt = 0 for i in range(n-1): for j in range(i+1, n): if score_sample[i] > score_sample[j] and predict_sample[i] > predict_sample[j]: cnt += 1 elif score_sample[i] < score_sample[j] and predict_sample[i] < predict_sample[j]: cnt += 1 return 1 - cnt / (n * (n - 1) / 2) def main(): processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" device = "cuda" image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path) tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True) model_ckpt_path = "outputs/MPS_overall_checkpoint.pth" model = torch.load(model_ckpt_path) model.eval().to(device) score_sample = [] with open("hpdv2/test.json", "r") as f: score_sample = json.load(f) predict_sample = [] score = 0. 
with torch.no_grad(): for i in range(len(score_sample)): item = score_sample[i] rewards = infer_example(item["image_path"], item["prompt"], model, image_processor, tokenizer, device) score += inversion_score(rewards, item['rank']) test_acc = score / len(score_sample) print(f"HPDv2 Test Acc: {100 * test_acc:.2f}%") if __name__ == '__main__': Fire(main) ================================================ FILE: eval_overall_mps_on_imagereward.py ================================================ import numpy as np # from transformers import AutoProcessor #, AutoModel import torch from PIL import Image from io import BytesIO from tqdm.auto import tqdm from fire import Fire from transformers import CLIPFeatureExtractor, CLIPImageProcessor from dataclasses import dataclass from transformers import CLIPModel as HFCLIPModel from torch import nn, einsum from trainer.models.base_model import BaseModelConfig from transformers import CLIPConfig from transformers import AutoProcessor, AutoModel, AutoTokenizer from typing import Any, Optional, Tuple, Union import torch import cv2 import os from trainer.models.cross_modeling import Cross_model import matplotlib.pyplot as plt import torch.nn.functional as F import gc import json @torch.no_grad() def infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device, condition=None): def _process_image(image): if isinstance(image, dict): image = image["bytes"] if isinstance(image, bytes): image = Image.open(BytesIO(image)) if isinstance(image, str): image = Image.open( image ) image = image.convert("RGB") pixel_values = clip_processor(image, return_tensors="pt")["pixel_values"] return pixel_values def _tokenize(caption): input_ids = tokenizer( caption, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ).input_ids return input_ids image_input = _process_image(image).to(device) text_input = _tokenize(prompt).to(device) if condition is None: condition = "light, color, clarity, 
tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things." condition_batch = _tokenize(condition).repeat(text_input.shape[0],1).to(device) with torch.no_grad(): text_f, text_features = clip_model.model.get_text_features(text_input) image_f = clip_model.model.get_image_features(image_input.half()) condition_f, _ = clip_model.model.get_text_features(condition_batch) sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f) sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0] sim_text_condition = sim_text_condition / sim_text_condition.max() mask = torch.where(sim_text_condition > 0.3, 0, float('-inf')) mask = mask.repeat(1,image_f.shape[1],1) image_features = clip_model.cross_model(image_f, text_f,mask.half())[:,0,:] image_features = image_features / image_features.norm(dim=-1, keepdim=True) text_features = text_features / text_features.norm(dim=-1, keepdim=True) image_score = clip_model.logit_scale.exp() * text_features @ image_features.T return image_score[0] def infer_example(images, prompt, clip_model, clip_processor, tokenizer, device): scores = [] for image in images: score = infer_one_sample(image, prompt, clip_model, clip_processor, tokenizer, device) scores.append(score) scores = torch.stack(scores, dim=-1) probs = torch.softmax(scores, dim=-1)[0] return probs.cpu().tolist() def acc(score_sample, predict_sample): tol_cnt = 0. true_cnt = 0. 
for idx in range(len(score_sample)): item_base = score_sample[idx]["ranking"] item = predict_sample[idx]["rewards"] for i in range(len(item_base)): for j in range(i+1, len(item_base)): if item_base[i] > item_base[j]: if item[i] >= item[j]: tol_cnt += 1 elif item[i] < item[j]: tol_cnt += 1 true_cnt += 1 elif item_base[i] < item_base[j]: if item[i] > item[j]: tol_cnt += 1 true_cnt += 1 elif item[i] <= item[j]: tol_cnt += 1 return true_cnt / tol_cnt def main(): processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" device = "cuda" image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path) tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True) model_ckpt_path = "outputs/MPS_overall_checkpoint.pth" model = torch.load(model_ckpt_path) model.eval().to(device) score_sample = [] with open("imagereward/test.json", "r") as f: # change the path to the ImageReward test dataset score_sample = json.load(f) predict_sample = [] with torch.no_grad(): for item in score_sample: rewards = infer_example(item["generations"], item["prompt"], model, image_processor, tokenizer, device) predict_item = { "id": item["id"], "prompt": item["prompt"], "rewards": rewards } predict_sample.append(predict_item) test_acc = acc(score_sample, predict_sample) print(f"ImageReward Test Acc: {100 * test_acc:.2f}%") if __name__ == '__main__': Fire(main) ================================================ FILE: requirements.txt ================================================ accelerate @ git+https://github.com/huggingface/accelerate.git@d1aa558119859c4b205a324afabaecabd9ef375e datasets==2.10.1 deepspeed==0.8.3 fire==0.4.0 hydra-core==1.3.2 rich==13.3.2 submitit==1.4.5 transformers==4.27.3 wandb==0.12.21 ================================================ FILE: trainer/models/base_model.py ================================================ from dataclasses import dataclass @dataclass class BaseModelConfig: pass 
================================================ FILE: trainer/models/clip_model.py ================================================ from dataclasses import dataclass from transformers import CLIPModel as HFCLIPModel from transformers import AutoTokenizer from torch import nn, einsum from trainer.models.base_model import BaseModelConfig from transformers import CLIPConfig from typing import Any, Optional, Tuple, Union import torch from trainer.models.cross_modeling import Cross_model import gc class XCLIPModel(HFCLIPModel): def __init__(self, config: CLIPConfig): super().__init__(config) def get_text_features( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # pooled_output = text_outputs[1] # text_features = self.text_projection(pooled_output) last_hidden_state = text_outputs[0] text_features = self.text_projection(last_hidden_state) pooled_output = text_outputs[1] text_features_EOS = self.text_projection(pooled_output) # del last_hidden_state, text_outputs # gc.collect() return text_features, text_features_EOS def get_image_features( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> torch.FloatTensor: # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) # pooled_output = vision_outputs[1] # pooled_output # image_features = self.visual_projection(pooled_output) last_hidden_state = vision_outputs[0] image_features = self.visual_projection(last_hidden_state) return image_features @dataclass class ClipModelConfig(BaseModelConfig): _target_: str = "trainer.models.clip_model.CLIPModel" pretrained_model_name_or_path: str ="openai/clip-vit-base-patch32" class CLIPModel(nn.Module): def __init__(self, ckpt): super().__init__() self.model = XCLIPModel.from_pretrained(ckpt) self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16) def get_text_features(self, *args, **kwargs): return self.model.get_text_features(*args, **kwargs) def get_image_features(self, *args, **kwargs): return self.model.get_image_features(*args, **kwargs) def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None): outputs = () text_f, text_EOS = self.model.get_text_features(text_inputs) # B*77*1024 outputs += text_EOS, image_f = self.model.get_image_features(image_inputs.half()) # 2B*257*1024 condition_f, _ = self.model.get_text_features(condition_inputs) # B*5*1024 sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f) sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0] sim_text_condition = sim_text_condition / sim_text_condition.max() mask = torch.where(sim_text_condition > 0.01, 0, float('-inf')) # B*1*77 mask = mask.repeat(1,image_f.shape[1],1) # B*257*77 bc = int(image_f.shape[0]/2) sim0 = 
self.cross_model(image_f[:bc,:,:], text_f,mask.half()) sim1 = self.cross_model(image_f[bc:,:,:], text_f,mask.half()) outputs += sim0[:,0,:], outputs += sim1[:,0,:], return outputs @property def logit_scale(self): return self.model.logit_scale def save(self, path): self.model.save_pretrained(path) ================================================ FILE: trainer/models/cross_modeling.py ================================================ import torch from torch import einsum, nn import torch.nn.functional as F from einops import rearrange, repeat # helper functions def exists(val): return val is not None def default(val, d): return val if exists(val) else d # normalization # they use layernorm without bias, something that pytorch does not offer class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.register_buffer("bias", torch.zeros(dim)) def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.weight, self.bias) # residual class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x # rotary positional embedding # https://arxiv.org/abs/2104.09864 class RotaryEmbedding(nn.Module): def __init__(self, dim): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, max_seq_len, *, device): seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) freqs = einsum("i , j -> i j", seq, self.inv_freq) return torch.cat((freqs, freqs), dim=-1) def rotate_half(x): x = rearrange(x, "... (j d) -> ... 
j d", j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(pos, t): return (t * pos.cos()) + (rotate_half(t) * pos.sin()) # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward # https://arxiv.org/abs/2002.05202 class SwiGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x # parallel attention and feedforward with residual # discovered by Wang et al + EleutherAI from GPT-J fame class ParallelTransformerBlock(nn.Module): def __init__(self, dim, dim_head=64, heads=8, ff_mult=4): super().__init__() self.norm = LayerNorm(dim) attn_inner_dim = dim_head * heads ff_inner_dim = dim * ff_mult self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) self.heads = heads self.scale = dim_head**-0.5 self.rotary_emb = RotaryEmbedding(dim_head) self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) self.ff_out = nn.Sequential( SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False) ) self.register_buffer("pos_emb", None, persistent=False) def get_rotary_embedding(self, n, device): if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: return self.pos_emb[:n] pos_emb = self.rotary_emb(n, device=device) self.register_buffer("pos_emb", pos_emb, persistent=False) return pos_emb def forward(self, x, attn_mask=None): """ einstein notation b - batch h - heads n, i, j - sequence length (base sequence length, source, target) d - feature dimension """ n, device, h = x.shape[1], x.device, self.heads # pre layernorm x = self.norm(x) # attention queries, keys, values, and feedforward inner q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) # split heads # they use multi-query single-key-value attention, yet another Noam Shazeer paper # they found no performance loss past a certain scale, and more efficient decoding obviously # 
https://arxiv.org/abs/1911.02150 q = rearrange(q, "b n (h d) -> b h n d", h=h) # rotary embeddings positions = self.get_rotary_embedding(n, device) q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) # scale q = q * self.scale # similarity sim = einsum("b h i d, b j d -> b h i j", q, k) # extra attention mask - for masking out attention from text CLS token to padding if exists(attn_mask): attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j') sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) # attention sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum("b h i j, b j d -> b h i d", attn, v) # merge heads out = rearrange(out, "b h n d -> b n (h d)") return self.attn_out(out) + self.ff_out(ff) # cross attention - using multi-query + one-headed key / values as in PaLM w/ optional parallel feedforward class CrossAttention(nn.Module): def __init__( self, dim, *, context_dim=None, dim_head=64, heads=12, parallel_ff=False, ff_mult=4, norm_context=False ): super().__init__() self.heads = heads self.scale = dim_head ** -0.5 inner_dim = heads * dim_head context_dim = default(context_dim, dim) self.norm = LayerNorm(dim) self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity() self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) # whether to have parallel feedforward ff_inner_dim = ff_mult * dim self.ff = nn.Sequential( nn.Linear(dim, ff_inner_dim * 2, bias=False), SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False) ) if parallel_ff else None def forward(self, x, context, mask): """ einstein notation b - batch h - heads n, i, j - sequence length (base sequence length, source, target) d - feature dimension """ # pre-layernorm, for queries and context x = self.norm(x) context = self.context_norm(context) # get queries q = self.to_q(x) q = rearrange(q, 'b n (h d) -> b 
h n d', h = self.heads) # scale q = q * self.scale # get key / values k, v = self.to_kv(context).chunk(2, dim=-1) # query / key similarity sim = einsum('b h i d, b j d -> b h i j', q, k) # attention mask = mask.unsqueeze(1).repeat(1,self.heads,1,1) sim = sim + mask # context mask sim = sim - sim.amax(dim=-1, keepdim=True) attn = sim.softmax(dim=-1) # aggregate out = einsum('b h i j, b j d -> b h i d', attn, v) # merge and combine heads out = rearrange(out, 'b h n d -> b n (h d)') out = self.to_out(out) # add parallel feedforward (for multimodal layers) if exists(self.ff): out = out + self.ff(x) return out class Cross_model(nn.Module): def __init__( self, dim=512, layer_num=4, dim_head=64, heads=8, ff_mult=4 ): super().__init__() self.layers = nn.ModuleList([]) for ind in range(layer_num): self.layers.append(nn.ModuleList([ Residual(CrossAttention(dim=dim, dim_head=dim_head, heads=heads, parallel_ff=True, ff_mult=ff_mult)), Residual(ParallelTransformerBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult)) ])) def forward( self, query_tokens, context_tokens, mask ): for cross_attn, self_attn_ff in self.layers: query_tokens = cross_attn(query_tokens, context_tokens,mask) query_tokens = self_attn_ff(query_tokens) return query_tokens