Repository: Ucas-HaoranWei/GOT-OCR2.0
Branch: main
Commit: 179ed086ad6b
Files: 41
Total size: 224.5 KB
Directory structure:
gitextract_is3bdbjd/
├── GOT-OCR-2.0-master/
│ ├── GOT/
│ │ ├── __init__.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── base_dataset.py
│ │ │ └── conversation_dataset_qwen.py
│ │ ├── demo/
│ │ │ ├── process_results.py
│ │ │ ├── run_ocr_2.0.py
│ │ │ └── run_ocr_2.0_crop.py
│ │ ├── eval/
│ │ │ ├── eval_GOT_ocr.py
│ │ │ ├── evaluate_GOT.py
│ │ │ ├── multi_hardware_eval_GOT.py
│ │ │ └── pyevaltools/
│ │ │ ├── __init__.py
│ │ │ ├── eval_ocr.py
│ │ │ ├── eval_ocr_format.py
│ │ │ ├── eval_ocr_scene.py
│ │ │ └── merge_results.py
│ │ ├── model/
│ │ │ ├── GOT_ocr_2_0.py
│ │ │ ├── __init__.py
│ │ │ ├── plug/
│ │ │ │ └── blip_process.py
│ │ │ └── vision_encoder/
│ │ │ ├── __init__.py
│ │ │ └── vary_b.py
│ │ ├── train/
│ │ │ ├── train.py
│ │ │ ├── train_GOT.py
│ │ │ ├── train_flash_attn.py
│ │ │ ├── train_lora.py
│ │ │ ├── train_lora_flash_attn.py
│ │ │ ├── trainer.py
│ │ │ ├── trainer_llm_llrd.py
│ │ │ ├── trainer_vit_fixlr.py
│ │ │ └── trainer_vit_llrd.py
│ │ └── utils/
│ │ ├── arguments.py
│ │ ├── constants.py
│ │ ├── conversation.py
│ │ └── utils.py
│ ├── pyproject.toml
│ ├── pyvenv.cfg
│ ├── render_tools/
│ │ ├── content-mmd-to-html.html
│ │ └── tikz.html
│ ├── results/
│ │ └── demo.html
│ └── zero_config/
│ ├── zero2.json
│ └── zero3.json
└── README.md
================================================
FILE CONTENTS
================================================
================================================
FILE: GOT-OCR-2.0-master/GOT/__init__.py
================================================
================================================
FILE: GOT-OCR-2.0-master/GOT/data/__init__.py
================================================
import torch
import transformers
from dataclasses import dataclass, field
from GOT.utils.constants import *
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate dataset samples into one padded batch for supervised training.

    Every instance must provide ``input_ids``, ``labels``, ``image`` and
    ``image_high``; the two image entries are lists of tensors that are
    stacked per sample and then paired up as ``(image, image_high)``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        pad_id = self.tokenizer.pad_token_id

        token_seqs = [sample["input_ids"] for sample in instances]
        label_seqs = [sample["labels"] for sample in instances]

        # One (stacked image, stacked image_high) tuple per sample.
        low_res = [torch.stack(sample['image']) for sample in instances]
        high_res = [torch.stack(sample['image_high']) for sample in instances]
        paired_images = list(zip(low_res, high_res))

        # Pad ids with the tokenizer's pad id and labels with IGNORE_INDEX,
        # the same value the dataset uses to mask non-trainable targets.
        padded_ids = torch.nn.utils.rnn.pad_sequence(
            token_seqs, batch_first=True, padding_value=pad_id)
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

        return {
            "input_ids": padded_ids,
            "labels": padded_labels,
            # Attention is disabled exactly where the pad id appears.
            "attention_mask": padded_ids.ne(pad_id),
            "images": paired_images,
        }
def make_supervised_data_module(interleave, with_box, tokenizer, data_args):
    """Build the training dataset and collator handed to the trainer.

    Args:
        interleave: not used by this function; kept for call compatibility.
        with_box: not used by this function; kept for call compatibility.
        tokenizer: tokenizer shared by the dataset and the collator.
        data_args: data arguments; ``conversation_version`` must be ``'mpt'``.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (always None) and
        ``data_collator``.

    Raises:
        ValueError: if ``data_args.conversation_version`` is not ``'mpt'``.
    """
    # Only the mpt format has a dataset class. Previously any other value fell
    # through and crashed with an UnboundLocalError on ``dataset_cls``.
    if data_args.conversation_version != 'mpt':
        raise ValueError(
            f"Unsupported conversation_version {data_args.conversation_version!r}; "
            "only 'mpt' is supported."
        )
    from GOT.data.conversation_dataset_qwen import ConversationDataset
    dataset_cls = ConversationDataset

    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        datasets=data_args.datasets,
        multimodal_cfg=dict(
            sep_image_conv_front=data_args.sep_image_conv_front,
            image_token_len=data_args.image_token_len,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=data_args.use_im_start_end,
            image_processor=data_args.image_processor,
            image_processor_high=data_args.image_processor_high,
            box_limit=data_args.box_limit,
        )
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
================================================
FILE: GOT-OCR-2.0-master/GOT/data/base_dataset.py
================================================
import io
import os
import copy
import json
import logging
import torch
import transformers
import boto3
from typing import List, Optional, Tuple, Union, Dict, Sequence
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from GOT.utils.constants import *
class BaseDataset(Dataset):
    """Common base for the GOT training datasets.

    Stores the tokenizer and the multimodal config. Subclasses must set
    ``self.list_data_dict`` (used by ``__len__``) and implement ``__getitem__``.
    """

    def __init__(
        self,
        datasets: str,
        tokenizer: transformers.PreTrainedTokenizer,
        multimodal_cfg: dict
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.multimodal_cfg = multimodal_cfg
        logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")

    def image_processor(self, image):
        """Apply the configured ``image_processor_high`` to a copy of *image*."""
        # Only the second ("high") processor is applied here; the copy leaves
        # the caller's image object untouched.
        high_res_processor = self.multimodal_cfg['image_processor_high']
        return high_res_processor(image.copy())

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Placeholder; concrete datasets override this.
        return None
================================================
FILE: GOT-OCR-2.0-master/GOT/data/conversation_dataset_qwen.py
================================================
import io
import os
import copy
import json
import logging
import torch
import random
from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from GOT.data.base_dataset import BaseDataset
from GOT.utils.constants import *
from GOT.utils import conversation as conversation_lib
import boto3
import smart_open
from megfile import smart_glob
from natsort import natsorted
class ConversationDataset(BaseDataset):
    """Conversation format dataset stage2 fine-tuning.

    Loads the JSON annotation files of every dataset group named in
    ``datasets`` (``'+'``-separated keys of ``got_data_dict``), shuffles the
    samples together with their image roots, and turns each sample into
    ``input_ids`` / ``labels`` tensors plus processed image tensors.
    """

    def __init__(self, datasets, tokenizer, multimodal_cfg):
        super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # v0 version format conversation
        conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
        logging.warning("Formatting inputs into conversation type: mpt-fixed")
        logging.warning("Loading data...")

        list_data_dict = []
        list_image_path = []

        # TODO add your data  [data1, data2, data3, .....]
        got_data_dict = {
            "pdf-ocr": ["data1", "data2"],
            'scene-ocr': ["data3", "data4"]
            # ......
        }
        for name_all in datasets.split("+"):
            for name in got_data_dict[name_all]:
                dataset = CONVERSATION_DATA[name]

                data_path = dataset['annotations']
                # Use a context manager so the annotation file handle is closed.
                with open(data_path, "r") as f:
                    data = json.load(f)

                list_data_dict.extend(data)

                # Every sample of this dataset shares the same image root.
                image_path = dataset['images']
                list_image_path.extend([image_path] * len(data))

                logging.warning(f"Data from {data_path} provide {len(data)} conversations.")

        assert len(list_data_dict) == len(list_image_path)
        logging.warning(f"{len(list_data_dict)} conversations in total.")

        # Shuffle samples and image roots together so they stay aligned.
        a_new_list = list(zip(list_data_dict, list_image_path))
        random.shuffle(a_new_list)
        if a_new_list:
            list_data_dict_new, list_image_path_new = zip(*a_new_list)
        else:
            # zip(*[]) yields nothing to unpack; fall back to empty tuples.
            list_data_dict_new, list_image_path_new = (), ()
        self.list_data_dict = list_data_dict_new
        self.list_image_path = list_image_path_new

        # Token ids for the image patch / start / end placeholders
        # (presumably ids in the Qwen tokenizer vocabulary — not read in this class).
        self.im_patch_token = 151859
        self.im_start_token = 151857
        self.im_end_token = 151858

    def multimodal_processor(self, sources, flag_num_patches):
        """Expand DEFAULT_IMAGE_TOKEN in every sentence into the image placeholder span.

        The span is DEFAULT_IM_START_TOKEN, then
        ``image_token_len * flag_num_patches`` copies of DEFAULT_IMAGE_PATCH_TOKEN,
        then DEFAULT_IM_END_TOKEN. *sources* is modified in place and returned.
        """
        # The expanded placeholder is the same for every sentence, so build it once.
        replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len'] * flag_num_patches
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
        for source in sources:
            if self.multimodal_cfg['sep_image_conv_front']:
                # Move the image token in front of the first user turn.
                assert DEFAULT_IMAGE_TOKEN in source[0]['value']
                source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
            for sentence in source:
                sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return sources

    def _tokenize_fn(self, strings):
        """Tokenize a list of strings.

        Returns a dict whose ``input_ids``/``labels`` are the same list of 1-D
        id tensors, and whose ``*_lens`` are the non-pad token counts.
        Not called elsewhere in this class.
        """
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ) for text in strings
        ]
        input_ids = labels = [
            tokenized.input_ids[0] for tokenized in tokenized_list
        ]
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def _mask_targets(self, target, tokenized_lens, speakers):
        """Set *target* to IGNORE_INDEX in place over the header and human turns.

        ``tokenized_lens[0]`` is the header length; the remaining lengths pair
        with *speakers*. For human turns the first 2 tokens are left unmasked
        (NOTE(review): presumably the role marker — confirm). Not called
        elsewhere in this class.
        """
        cur_idx = tokenized_lens[0]
        tokenized_lens = tokenized_lens[1:]
        target[:cur_idx] = IGNORE_INDEX
        for tokenized_len, speaker in zip(tokenized_lens, speakers):
            if speaker.lower() == "human":
                target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
            cur_idx += tokenized_len

    def token_processor(self, sources, image_name):
        """Tokenize conversations and mask everything except the assistant replies.

        Args:
            sources: list of conversations, each a list of ``{"from", "value"}``
                turns with ``from`` in {"human", "gpt"}.
            image_name: printed with tokenization-mismatch warnings (may be None).

        Returns:
            dict with batched ``input_ids`` and ``labels``; masked label
            positions hold IGNORE_INDEX.
        """
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply prompt templates
        conversations = []
        for i, source in enumerate(sources):
            if roles[source[0]["from"]] != conv.roles[0]:
                # Skip the first one if it is not from human
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # Tokenize conversations
        input_ids = self.tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).input_ids

        targets = input_ids.clone()
        assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

        # Mask targets
        sep = conv.sep + conv.roles[1]
        # The separator's token count is the same for every round; compute it once.
        sep_len = len(self.tokenizer(conv.sep).input_ids)
        for conversation, target in zip(conversations, targets):
            total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep)
            re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
            for conv_idx in range(3, len(rounds), 2):
                re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))  # user + gpt
            cur_len = 0
            for rou in re_rounds:
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep
                round_len = len(self.tokenizer(rou).input_ids) + sep_len
                instruction_len = len(self.tokenizer(parts[0]).input_ids)
                # Mask the instruction part (up to and including the assistant
                # role marker); only the reply stays trainable.
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                cur_len += round_len
            target[cur_len:] = IGNORE_INDEX

            if cur_len < self.tokenizer.model_max_length:
                if cur_len != total_len:
                    # Per-round re-tokenization disagrees with the batch tokenization,
                    # so the mask offsets are unreliable: drop the sample from the loss.
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )
                    print(image_name)

        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample *i* as ``input_ids``/``labels`` plus image tensors.

        A sample whose image(s) cannot be found, opened or processed is
        replaced by sample 0 after a message is printed.
        """
        data = copy.deepcopy(self.list_data_dict[i])
        # Reported in tokenization warnings. Stays None for samples without an
        # image (previously those crashed with NameError on image_path).
        image_name = None

        if isinstance(data, dict):
            image_list = []
            image_high_list = []
            flag_num_patches = 1
            if 'image' in data:
                image_path = self.list_image_path[i]
                image_file = data['image']
                image_name = image_path + image_file

                # multi-crop or multi page, only support .png files
                if ('.jpg' not in image_file and '.png' not in image_file and '.jpeg' not in image_file) and ('.jpg' not in image_path and '.png' not in image_path and '.jpeg' not in image_path):
                    if image_file[0] == '/':
                        # Drop image_path's last character to avoid a doubled
                        # separator (presumably image_path ends with '/').
                        patch_dir = image_path[:-1] + image_file
                    else:
                        patch_dir = image_path + image_file
                    patches = smart_glob(patch_dir + '*.png')

                    if not patches:
                        print(f'cannot glob the dir {patch_dir}.')
                        return self.__getitem__(0)

                    # sort multi images by name
                    patches = natsorted(patches)
                    flag_num_patches = len(patches)

                    for patch in patches:
                        try:
                            image = Image.open(patch).convert('RGB')
                        except Exception:
                            print(f'cannot identify image file {patch}.')
                            return self.__getitem__(0)

                        try:
                            img = self.image_processor(image)
                            image_list.append(img)
                            image_high_list.append(img)
                        except Exception:
                            print(f'image {image_path + image_file + patch} are broken or grayscale! we thus select 0-th sample instead!')
                            return self.__getitem__(0)
                else:
                    flag_num_patches = 1
                    try:
                        image = Image.open(image_path + image_file).convert('RGB')
                    except Exception:
                        print(f'cannot identify image file {image_file}.')
                        return self.__getitem__(0)

                    try:
                        image = self.image_processor(image)
                    except Exception:
                        print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                        return self.__getitem__(0)

            conversations = self.multimodal_processor([data["conversations"]], flag_num_patches)
        else:
            conversations = [data]

        # align with fastchat & llava here, put the conversation into a list for tokenization
        data_dict = self.token_processor(conversations, image_name)
        data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

        if isinstance(data, dict) and 'image' in data:
            if image_list and image_high_list:
                # Multi-patch sample: one processed tensor per patch.
                data_dict['image'] = image_list
                data_dict['image_high'] = image_high_list
            else:
                data_dict['image'] = [image]
                data_dict['image_high'] = [image]
        else:
            # GOT does not use data_dict['image']; zero placeholders keep the keys
            # the collator stacks for every instance.
            data_dict['image'] = [torch.zeros(3, 1024, 1024)]
            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
        return data_dict
================================================
FILE: GOT-OCR-2.0-master/GOT/demo/process_results.py
================================================
import string
# Full-width (CJK) punctuation -> ASCII equivalents, consumed via str.maketrans/str.translate.
# The comma key is the full-width "，" (U+FF0C); an ASCII "," key would be a no-op mapping.
punctuation_dict = {
    "，": ",",
    "。": ".",
}
# import os
def svg_to_html(svg_content, output_filename):
    """Write a standalone HTML page embedding *svg_content* to *output_filename*.

    Args:
        svg_content: SVG markup (e.g. a full ``<svg>...</svg>`` element).
        output_filename: path of the HTML file to create/overwrite.
    """
    # The previous version never interpolated svg_content, so the page was empty.
    html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>SVG Embedded in HTML</title>
</head>
<body>
{svg_content}
</body>
</html>
"""
    # Explicit UTF-8 so output does not depend on the platform's locale encoding.
    with open(output_filename, 'w', encoding='utf-8') as file:
        file.write(html_content)
================================================
FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py
================================================
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
import string
# Image placeholder strings spliced into the prompt by eval_model.
# NOTE(review): all four are empty strings here, so the prompt carries no image
# placeholder text at all — confirm they match the training-time tokens.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""

# str.translate table built from punctuation_dict (punctuation normalization).
translation_table = str.maketrans(punctuation_dict)
def load_image(image_file):
    """Load *image_file* (local path or http(s) URL) as an RGB PIL image.

    Raises:
        requests.HTTPError: if a URL answers with an error status.
    """
    # Match the URL scheme rather than a bare 'http' prefix, so a local file
    # such as 'http_sample.png' is not mistaken for a URL.
    if image_file.startswith(('http://', 'https://')):
        # Timeout prevents hanging forever on an unresponsive server.
        response = requests.get(image_file, timeout=60)
        # Report HTTP failures directly instead of a confusing PIL decode error.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
def eval_model(args):
    """Run GOT OCR on one image, stream the text, and optionally render it.

    Args:
        args: parsed CLI namespace with ``model_name``, ``image_file``,
            ``type`` ('format' selects formatted OCR, anything else plain OCR),
            ``box``, ``color`` and ``render``. ``args.conv_mode`` is
            overwritten with "mpt".

    With ``render`` set, output containing '**kern' is rendered to SVG via
    verovio into ./results/demo.html; formatted output is inserted into an
    HTML template from ./render_tools/ and written to ./results/demo.html.
    """
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
    model.to(device='cuda', dtype=torch.bfloat16)

    # TODO vary old codes, NEED del
    image_processor = BlipImageEvalProcessor(image_size=1024)
    image_processor_high = BlipImageEvalProcessor(image_size=1024)

    use_im_start_end = True
    image_token_len = 256

    image = load_image(args.image_file)
    w, h = image.size

    task = 'OCR with format: ' if args.type == 'format' else 'OCR: '
    qs = task

    if args.box:
        # NOTE(review): eval() executes arbitrary Python from the command line;
        # ast.literal_eval would suffice to parse a "[x1, y1, x2, y2]" list.
        bbox = eval(args.box)
        if len(bbox) in (2, 4):
            # Scale pixel coordinates by 1000 relative to the image size:
            # even indices are x (divided by width), odd are y (divided by height).
            for idx in range(len(bbox)):
                size = w if idx % 2 == 0 else h
                bbox[idx] = int(bbox[idx] / size * 1000)
        qs = str(bbox) + ' ' + task

    if args.color:
        # A colour prompt takes precedence over a box prompt.
        qs = '[' + args.color + ']' + ' ' + task

    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    conv_mode = "mpt"
    args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    print(prompt)

    inputs = tokenizer([prompt])

    # vary old codes, no use
    image_1 = image.copy()
    image_tensor = image_processor(image)
    image_tensor_1 = image_processor_high(image_1)

    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams=1,
            no_repeat_ngram_size=20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
        )

    if not args.render:
        return

    print('==============rendering===============')
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    if '**kern' in outputs:
        # Only verovio is needed; the previously imported cairosvg/cv2/numpy were unused.
        import verovio
        tk = verovio.toolkit()
        tk.loadData(outputs)
        tk.setOptions({"pageWidth": 2100, "footer": 'none',
                       'barLineWidth': 0.5, 'beamMaxSlope': 15,
                       'staffLineWidth': 0.2, 'spacingStaff': 6})
        tk.getPageCount()
        svg = tk.renderToSVG()
        svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
        svg_to_html(svg, "./results/demo.html")

    if args.type == 'format' and '**kern' not in outputs:
        html_path_2 = "./results/demo.html"
        if '\\begin{tikzpicture}' not in outputs:
            html_path = "./render_tools/" + "/content-mmd-to-html.html"
            # Escaped backslashes: '\l' is an invalid escape sequence (SyntaxWarning on 3.12+).
            right_num = outputs.count('\\right')
            left_num = outputs.count('\\left')

            if right_num != left_num:
                # Unbalanced \left/\right pairs: strip them to plain delimiters.
                for old, new in (('\\left(', '('), ('\\right)', ')'),
                                 ('\\left[', '['), ('\\right]', ']'),
                                 ('\\left{', '{'), ('\\right}', '}'),
                                 ('\\left|', '|'), ('\\right|', '|'),
                                 ('\\left.', '.'), ('\\right.', '.')):
                    outputs = outputs.replace(old, new)

            outputs = outputs.replace('"', '``').replace('$', '')

            # Build a JS string expression: one quoted line per output line, joined with '+'.
            gt = '+\n'.join(
                '"' + out.replace('\\', '\\\\') + r'\n' + '"'
                for out in outputs.split('\n')
            )

            with open(html_path, 'r', encoding='utf-8') as web_f:
                lines = web_f.read().split("const text =")
            new_web = lines[0] + 'const text =' + gt + lines[1]
        else:
            html_path = "./render_tools/" + "/tikz.html"
            outputs = outputs.translate(translation_table)
            gt = ''
            for out in outputs.split('\n'):
                if not out:
                    continue
                if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
                    # rstrip instead of the old char-by-char loop, which raised
                    # IndexError on lines made only of spaces.
                    out = out.rstrip(' ')
                    if out:
                        if out[-1] != ';':
                            # NOTE(review): drops the last character before appending ';'
                            # (presumably to replace stray trailing punctuation) — confirm.
                            gt += out[:-1] + ';\n'
                        else:
                            gt += out + '\n'
                else:
                    gt += out + '\n'

            with open(html_path, 'r', encoding='utf-8') as web_f:
                lines = web_f.read().split("const text =")
            new_web = lines[0] + gt + lines[1]

        with open(html_path_2, 'w', encoding='utf-8') as web_f_new:
            web_f_new.write(new_web)
if __name__ == "__main__":
    # Command-line entry point: single-image OCR.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-name", type=str, default="facebook/opt-350m")
    cli.add_argument("--image-file", type=str, required=True)
    cli.add_argument("--type", type=str, required=True)
    cli.add_argument("--box", type=str, default='')
    cli.add_argument("--color", type=str, default='')
    cli.add_argument("--render", action='store_true')
    args = cli.parse_args()
    eval_model(args)
================================================
FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0_crop.py
================================================
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from natsort import natsorted
import glob
# Image placeholder strings spliced into the prompt by eval_model.
# NOTE(review): all four are empty strings here, so the prompt carries no image
# placeholder text at all — confirm they match the training-time tokens.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""
def load_image(image_file):
    """Load *image_file* (local path or http(s) URL) as an RGB PIL image.

    Raises:
        requests.HTTPError: if a URL answers with an error status.
    """
    # Match the URL scheme rather than a bare 'http' prefix, so a local file
    # such as 'http_sample.png' is not mistaken for a URL.
    if image_file.startswith(('http://', 'https://')):
        # Timeout prevents hanging forever on an unresponsive server.
        response = requests.get(image_file, timeout=60)
        # Report HTTP failures directly instead of a confusing PIL decode error.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Return the (cols, rows) candidate whose ratio cols/rows is closest to *aspect_ratio*.

    On an exact tie with the current best, the later candidate wins only if the
    image area ``width * height`` exceeds half of the area that grid covers at
    ``image_size`` pixels per tile. Returns ``(1, 1)`` if there are no candidates.
    """
    image_area = width * height
    best = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            best = candidate
        elif gap == smallest_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            best = candidate
    return best
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
    """Cut *image* into a grid of ``image_size`` x ``image_size`` tiles.

    The grid (cols, rows) with ``min_num <= cols * rows <= max_num`` whose ratio
    best matches the image is chosen via ``find_closest_aspect_ratio``. The image
    is resized to exactly cover that grid and cropped row-major. When more than
    one tile results and *use_thumbnail* is true, a resized copy of the whole
    image is appended last.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every grid whose tile count lies in [min_num, max_num], smallest first.
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    tiles = []
    for idx in range(blocks):
        col, row = idx % tiles_per_row, idx // tiles_per_row
        left, upper = col * image_size, row * image_size
        tiles.append(resized_img.crop((left, upper, left + image_size, upper + image_size)))
    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def eval_model(args):
    """Run GOT multi-crop / multi-page OCR and optionally render it to HTML.

    With ``args.multi_page`` every ``*png`` file in the ``args.image_file``
    directory is one page (natural-sorted); otherwise the single image is tiled
    by ``dynamic_preprocess``. All pages/tiles go to the model as one stacked
    tensor, and the prompt reserves ``image_token_len`` patch tokens for each.
    ``args.conv_mode`` is always overwritten with "mpt".

    Raises:
        ValueError: if multi-page mode finds no png files.
    """
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
    model.to(device='cuda', dtype=torch.bfloat16)

    # Only the "high" processor's output is fed to the model; the old unused
    # low-res processing pass has been dropped.
    image_processor_high = BlipImageEvalProcessor(image_size=1024)

    use_im_start_end = True
    image_token_len = 256

    if args.multi_page:
        qs = 'OCR with format across multi pages: '
        # only for png files
        patches = natsorted(glob.glob(args.image_file + '/*png'))
        sub_images = [load_image(sub_image) for sub_image in patches]
    else:
        qs = 'OCR with format upon the patch reference: '
        img = load_image(args.image_file)
        sub_images = dynamic_preprocess(img)

    if not sub_images:
        # torch.stack([]) would otherwise fail with an opaque error.
        raise ValueError(f'no png pages found under {args.image_file}')
    ll = len(sub_images)

    image_list = torch.stack([image_processor_high(p.copy()) for p in sub_images])

    print('====new images batch size======: ', image_list.shape)

    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    # NOTE(review): this overrides any --conv-mode given on the command line.
    conv_mode = "mpt"
    args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])
    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_list.half().cuda(), image_list.half().cuda())],
            do_sample=False,
            num_beams=1,
            # no_repeat_ngram_size = 20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
        )

    if not args.render:
        return

    print('==============rendering===============')
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    html_path = "./render_tools/" + "/content-mmd-to-html.html"
    html_path_2 = "./results/demo.html"
    # Escaped backslashes: '\l' is an invalid escape sequence (SyntaxWarning on 3.12+).
    right_num = outputs.count('\\right')
    left_num = outputs.count('\\left')

    if right_num != left_num:
        # Unbalanced \left/\right pairs: strip them to plain delimiters.
        for old, new in (('\\left(', '('), ('\\right)', ')'),
                         ('\\left[', '['), ('\\right]', ']'),
                         ('\\left{', '{'), ('\\right}', '}'),
                         ('\\left|', '|'), ('\\right|', '|'),
                         ('\\left.', '.'), ('\\right.', '.')):
            outputs = outputs.replace(old, new)

    outputs = outputs.replace('"', '``').replace('$', '')

    # Build a JS string expression: one quoted line per output line, joined with '+'.
    gt = '+\n'.join(
        '"' + out.replace('\\', '\\\\') + r'\n' + '"'
        for out in outputs.split('\n')
    )

    with open(html_path, 'r', encoding='utf-8') as web_f:
        lines = web_f.read().split("const text =")
    new_web = lines[0] + 'const text =' + gt + lines[1]

    with open(html_path_2, 'w', encoding='utf-8') as web_f_new:
        web_f_new.write(new_web)
if __name__ == "__main__":
    # Command-line entry point: multi-crop / multi-page OCR.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-name", type=str, default="facebook/opt-350m")
    cli.add_argument("--image-file", type=str, required=True)
    cli.add_argument("--conv-mode", type=str, default=None)
    cli.add_argument("--multi-page", action='store_true')
    cli.add_argument("--render", action='store_true')
    args = cli.parse_args()
    eval_model(args)
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/eval_GOT_ocr.py
================================================
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from tqdm import tqdm
from PIL import Image
import json
import os
import requests
from PIL import Image
from io import BytesIO
import math
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria
from PIL import Image
import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from GOT.model.plug.transforms import train_transform, test_transform
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
# Image placeholder strings spliced into the prompt.
# NOTE(review): all four are empty strings here, so the prompt carries no image
# placeholder text at all — confirm they match the training-time tokens.
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ""
DEFAULT_IM_START_TOKEN = ""
DEFAULT_IM_END_TOKEN = ""
import string

# str.translate table built from punctuation_dict (punctuation normalization).
translation_table = str.maketrans(punctuation_dict)
def load_image(image_file):
    """Load *image_file* (local path or http(s) URL) as an RGB PIL image.

    Raises:
        requests.HTTPError: if a URL answers with an error status.
    """
    # Match the URL scheme rather than a bare 'http' prefix, so a local file
    # such as 'http_sample.png' is not mistaken for a URL.
    if image_file.startswith(('http://', 'https://')):
        # Timeout prevents hanging forever on an unresponsive server.
        response = requests.get(image_file, timeout=60)
        # Report HTTP failures directly instead of a confusing PIL decode error.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Return the (cols, rows) candidate whose ratio cols/rows is closest to *aspect_ratio*.

    On an exact tie with the current best, the later candidate wins only if the
    image area ``width * height`` exceeds half of the area that grid covers at
    ``image_size`` pixels per tile. Returns ``(1, 1)`` if there are no candidates.
    """
    image_area = width * height
    best = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            best = candidate
        elif gap == smallest_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            best = candidate
    return best
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
    """Cut *image* into a grid of ``image_size`` x ``image_size`` tiles.

    The grid (cols, rows) with ``min_num <= cols * rows <= max_num`` whose ratio
    best matches the image is chosen via ``find_closest_aspect_ratio``. The image
    is resized to exactly cover that grid and cropped row-major. When more than
    one tile results and *use_thumbnail* is true, a resized copy of the whole
    image is appended last.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every grid whose tile count lies in [min_num, max_num], smallest first.
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size)

    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    tiles = []
    for idx in range(blocks):
        col, row = idx % tiles_per_row, idx // tiles_per_row
        left, upper = col * image_size, row * image_size
        tiles.append(resized_img.crop((left, upper, left + image_size, upper + image_size)))
    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def split_list(lst, n):
    """Split a list into exactly n (roughly) equal-sized chunks.

    Chunks hold ceil(len(lst) / n) items each, the last one possibly fewer.
    When that sizing yields fewer than ``n`` chunks (e.g. 5 items over 4
    chunks) or ``lst`` is empty, the result is padded with empty chunks. This
    guarantees that every index ``0 <= k < n`` is valid for ``get_chunk``.
    ``n == 0`` still raises ZeroDivisionError.
    """
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    if chunk_size > 0:
        chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    else:
        # Empty input: a zero step would make range() raise ValueError.
        chunks = []
    # Pad with empty slices (same sequence type as lst) up to n chunks.
    chunks += [lst[len(lst):] for _ in range(n - len(chunks))]
    return chunks
def get_chunk(lst, n, k):
    """Return chunk ``k`` of ``split_list(lst, n)``."""
    return split_list(lst, n)[k]
# Result records accumulated by eval_model() (one dict per sample); dumped to
# <out_path>/results_<chunk_idx>.json once inference finishes.
output_list = []


def eval_model(args):
    """Run GOT OCR inference on this process's chunk of the ground-truth file.

    Args:
        args: CLI namespace from the ``__main__`` block. Reads ``model_name``,
            ``gtfile_path``, ``image_path``, ``out_path``, ``datatype``,
            ``num_chunks``, ``chunk_idx`` and ``conv_mode``. ``conv_mode`` is
            set to the default template name when it is None.

    Side effects:
        Appends one record per sample to the module-level ``output_list`` and
        writes it to ``<out_path>/results_<chunk_idx>.json``.
    """
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = GOTQwenForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, device_map='cuda',
        use_safetensors=True, pad_token_id=151643).eval()

    # Only the 1024x1024 processor is used (the old Vary processor was never read).
    image_processor_high = BlipImageEvalProcessor(image_size=1024)
    # Placeholder patch tokens per 1024x1024 view.
    image_token_len = 256
    # Dynamic multi-crop tiling is disabled; set to True to feed tiles instead.
    multi_crop = False

    # The conversation template does not depend on the sample, so resolve it
    # (and emit any override warning) once instead of on every iteration.
    conv_mode = "mpt"
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    # Explicit UTF-8 so GT files with CJK text load regardless of the locale;
    # the handle is closed as soon as the JSON is parsed.
    with open(args.gtfile_path, encoding='utf-8') as gt_file:
        gts = json.load(gt_file)

    print("Generate Results......")
    # Shard every datatype. Previously only "OCR" data was sharded, so with
    # --num-chunks > 1 each worker re-ran the full set and the merged
    # results contained duplicates.
    gts = get_chunk(gts, args.num_chunks, args.chunk_idx)

    for ann in tqdm(gts):
        output_json = {}

        image_file = ann["image"]
        if 'Text' in args.datatype:
            image_file = image_file + '.jpg'
        if "VQAv2" in args.datatype:
            image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg'
        if "Cap" in args.datatype:
            image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg'
        image_file_path = os.path.join(args.image_path, image_file)

        if multi_crop:
            # One processed tensor per tile, stacked along a new leading axis.
            sub_images = dynamic_preprocess(load_image(image_file_path))
            ll = len(sub_images)
            image_list = torch.stack([image_processor_high(tile.copy()) for tile in sub_images])
        else:
            ll = 1
            image = load_image(image_file_path)
            image_1 = image.copy()
            # Two views from the same processor. The model encodes only the
            # second element of each pair; the first fills a legacy Vary slot.
            image_tensor = image_processor_high(image_1)
            image_tensor_1 = image_processor_high(image_1)

        # The dataset's question text is not used: GOT is always prompted with
        # the image placeholder tokens followed by a fixed OCR instruction.
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + 'OCR with format: '

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        inputs = tokenizer([prompt])
        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        if multi_crop:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_ids = model.generate(
                    input_ids,
                    images=[(image_list.half().cuda(), image_list.half().cuda())],
                    do_sample=False,
                    num_beams = 1,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
                    )
        else:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_ids = model.generate(
                    input_ids,
                    images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
                    do_sample=False,
                    num_beams = 1,
                    no_repeat_ngram_size = 20,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
                    )

        # Decode only the newly generated tokens, then drop a trailing stop string.
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        # Guard against an empty stop string: outputs[:-0] would erase the output.
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        if "Cap" in args.datatype:
            output_json['image_id'] = ann["id"]
            output_json["caption"] = outputs
        else:
            output_json['image'] = ann["image"]
            output_json['question'] = qs
            output_json['label'] = ann["conversations"][1]["value"]
            output_json['answer'] = outputs
        output_list.append(output_json)

    filename = os.path.join(args.out_path, "results_" + str(args.chunk_idx) + ".json")
    with open(filename, 'w', encoding="utf-8") as file_obj:
        json.dump(output_list, file_obj, ensure_ascii=False, indent=1)
if __name__ == "__main__":
    # Single-chunk OCR inference worker; normally launched by
    # GOT.eval.multi_hardware_eval_GOT with one process per GPU.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-name", type=str, default="facebook/opt-350m")
    cli.add_argument("--gtfile_path", type=str, required=True)
    cli.add_argument("--image_path", type=str, required=True)
    cli.add_argument("--out_path", type=str, required=True)
    cli.add_argument("--datatype", type=str, required=True)  # Text or Doc
    cli.add_argument("--num-chunks", type=int, default=1)
    cli.add_argument("--chunk-idx", type=int, default=0)
    cli.add_argument("--conv-mode", type=str, default=None)
    # Accepted for CLI compatibility; eval_model does not read it.
    cli.add_argument("--temperature", type=float, default=0.2)
    args = cli.parse_args()
    print(args)
    eval_model(args)
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/evaluate_GOT.py
================================================
import os
import argparse
import subprocess
import sys

# End-to-end driver: sharded inference -> merge per-chunk results -> scoring.
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--gtfile_path", type=str, required=True)
parser.add_argument("--image_path", type=str, required=True)
parser.add_argument("--out_path", type=str, required=True)
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--datatype", type=str, required=True) # Text\Doc\VQAv2\Cap
args = parser.parse_args()


def _run_module(module, *cli_args):
    """Run ``python -m module cli_args...`` and return its exit code.

    Arguments go through an argv list (no shell), so paths with spaces or
    shell metacharacters are passed intact. The current interpreter is used
    instead of whatever ``python3`` happens to be on PATH.
    """
    return subprocess.run([sys.executable, "-m", module, *map(str, cli_args)]).returncode


_run_module("GOT.eval.multi_hardware_eval_GOT",
            "--model-name", args.model_name,
            "--gtfile_path", args.gtfile_path,
            "--image_path", args.image_path,
            "--out_path", args.out_path,
            "--num-chunks", args.num_chunks,
            "--temperature", args.temperature,
            "--datatype", args.datatype)

print("Evaluating.....", flush=True)
_run_module("GOT.eval.pyevaltools.merge_results", "--out_path", args.out_path)

# Scoring flavour: 'plain' (plain-text OCR), 'format' (text/math/table split)
# or 'scene' (scene-text OCR).
a_type = 'plain'
_EVAL_MODULES = {
    'plain': "GOT.eval.pyevaltools.eval_ocr",
    'format': "GOT.eval.pyevaltools.eval_ocr_format",
    'scene': "GOT.eval.pyevaltools.eval_ocr_scene",
}
if a_type in _EVAL_MODULES:
    _run_module(_EVAL_MODULES[a_type],
                "--out_path", args.out_path,
                "--gt_path", args.gtfile_path,
                "--datatype", args.datatype)
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/multi_hardware_eval_GOT.py
================================================
import os
import argparse
from multiprocessing import Pool
# from GOT.eval.merge_results import merge_outputs
# from GOT.eval.doctextVQA import doc_text_eval
def run_eval(chunk_id, model_name, gtfile_path, image_path, out_path, num_chunks, datatype, temperature):
    """Run ``GOT.eval.eval_GOT_ocr`` on chunk ``chunk_id`` pinned to GPU ``chunk_id``.

    The child inherits the current environment with ``CUDA_VISIBLE_DEVICES``
    set to ``chunk_id``. Arguments are passed as an argv list (no shell), so
    paths containing spaces or shell metacharacters are safe. The current
    interpreter is used rather than ``python3`` from PATH.

    Returns:
        The child process's exit code.
    """
    import subprocess
    import sys

    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(chunk_id))
    cmd = [
        sys.executable, "-m", "GOT.eval.eval_GOT_ocr",
        "--model-name", str(model_name),
        "--gtfile_path", str(gtfile_path),
        "--image_path", str(image_path),
        "--out_path", str(out_path),
        "--num-chunks", str(num_chunks),
        "--chunk-idx", str(chunk_id),
        "--temperature", str(temperature),
        "--datatype", str(datatype),
    ]
    return subprocess.run(cmd, env=env).returncode
if __name__ == "__main__":
    # Fan out one eval_GOT_ocr process per chunk/GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--gtfile_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--out_path", type=str, required=True)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--datatype", type=str, required=True) # Text or Doc
    args = parser.parse_args()

    num_chunks = args.num_chunks

    os.makedirs(args.out_path, exist_ok=True)

    # Chunk i runs with CUDA_VISIBLE_DEVICES=i (see run_eval).
    with Pool(num_chunks) as p:
        jobs = [
            p.apply_async(run_eval, (chunk_id, args.model_name, args.gtfile_path,
                                     args.image_path, args.out_path, num_chunks,
                                     args.datatype, args.temperature))
            for chunk_id in range(num_chunks)
        ]
        p.close()
        p.join()
        # get() re-raises any exception from a worker; previously such
        # failures were silently discarded.
        exit_codes = [job.get() for job in jobs]

    failed = [chunk_id for chunk_id, code in enumerate(exit_codes) if code]
    if failed:
        print("[WARNING] eval chunks {} exited with a non-zero status".format(failed))
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/__init__.py
================================================
author = "aagrawal"  # package author tag
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr.py
================================================
import json
# from doctextVQAeval import VQAEval
import argparse
# import fitz as pymupdf
import nltk
from nltk.metrics import precision, recall, f_measure
import numpy as np
import jieba
# import megfile as mf
import pickle
import pandas as pd
import re
# from loguru import logger
# nltk.download('wordnet')
from nltk.translate import meteor_score
# from marker_scoring import score_text
# from utils import contain_chinese_string
# Command-line interface, parsed at import time because this module is a script.
parser = argparse.ArgumentParser()
for _flag in ("--out_path", "--gt_path", "--datatype"):
    parser.add_argument(_flag, type=str, required=True)
del _flag
args = parser.parse_args()
def preprocess(text, predict_root_):
    """Strip InternVL prompt/footer boilerplate from a prediction string.

    If ``predict_root_`` contains 'InternVL', keep only the text after the
    first "All words in the image:\\n" and before the first
    "[UNUSED_TOKEN_145]". This raises IndexError if the leading marker is
    missing. Otherwise ``text`` is returned unchanged.
    """
    if 'InternVL' not in predict_root_:
        return text
    body = text.split("All words in the image:\n")[1]
    return body.split("[UNUSED_TOKEN_145]")[0]
def contain_chinese_string(text):
    """Return True if ``text`` contains at least one CJK Unified Ideograph.

    The class spans the whole basic CJK Unified Ideographs block
    (U+4E00-U+9FFF). The former upper bound U+9FA5 missed the ideographs
    encoded at U+9FA6-U+9FFF.
    """
    # Match any character in the CJK Unified Ideographs block.
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
    return bool(chinese_pattern.search(text))
inline_reg = re.compile(r"\\\((.*?)(?= 1:
# try:
metrics["meteor"] = meteor_score.meteor_score([reference], hypothesis)
# except LookupError:
# metrics["meteor"] = np.nan
reference = set(reference)
hypothesis = set(hypothesis)
metrics["f_measure"] = f_measure(reference, hypothesis)
if heavy_mode >= 1:
metrics["precision"] = precision(reference, hypothesis)
metrics["recall"] = recall(reference, hypothesis)
if heavy_mode == 2:
# 速度太慢
metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
return metrics
def doc_formated_text_eval(gt_root_, predict_root_, datatype):
    """Score formatted OCR predictions separately on text, math and table content.

    ``predict_root_`` is a JSON file holding a list of records with ``label``
    and ``answer`` fields. ``split_text`` splits both fields into text, math
    and table parts. Each aligned pair is scored with ``nougat_per_metrics``,
    skipping empty results. The per-category mean of every metric is printed
    as JSON.

    A category with no scored pairs is reported as ``{}``; previously this
    crashed with ZeroDivisionError, or IndexError when there were no text
    pairs. ``gt_root_`` and ``datatype`` are accepted for CLI symmetry but
    unused.
    """
    def _mean_metrics(results):
        # Average every metric over a list of per-sample dicts; {} when empty.
        if not results:
            return {}
        totals = {key: 0 for key in results[0]}
        for each in results:
            for key, value in each.items():
                totals[key] += value
        return {key: total / len(results) for key, total in totals.items()}

    with open(predict_root_, encoding='utf-8') as predict_file:
        predicts = json.load(predict_file)
    gt_text_split, gt_math_split, gt_table_split = split_text(predicts, 'label')
    pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer')

    text_results, math_results, table_results = [], [], []
    for gt0, pre0, gt1, pre1, gt2, pre2 in zip(gt_text_split, pre_text_split, gt_math_split, pre_math_split, gt_table_split, pre_table_split):
        # Same call order as before: label part first, then answer part.
        for gts, pres, bucket in ((gt0, pre0, text_results),
                                  (gt1, pre1, math_results),
                                  (gt2, pre2, table_results)):
            ans = nougat_per_metrics(predict_root_, gts, pres)
            if ans:
                bucket.append(ans)

    mean_dict = {
        "eval question num": len(text_results),
        'text': _mean_metrics(text_results),
        'math': _mean_metrics(math_results),
        'table': _mean_metrics(table_results),
    }
    print(json.dumps(mean_dict, indent=4))
def doc_text_eval(gt_root_, predict_root_, datatype):
    """Score plain OCR predictions and print the mean of every metric as JSON.

    ``predict_root_`` is a JSON file holding a list of records with ``label``
    and ``answer`` fields. Each pair is scored with ``nougat_per_metrics``,
    and pairs that yield no metrics are skipped. The printed dict holds the
    number of scored pairs under ``"eval question num"`` plus one mean per
    metric; only the count is printed when nothing could be scored.

    Raises:
        AssertionError: if scoring a record fails (e.g. a missing field). The
            underlying exception is chained as ``__cause__``.

    ``gt_root_`` and ``datatype`` are accepted for CLI symmetry but unused.
    """
    with open(predict_root_, encoding='utf-8') as predict_file:
        predicts = json.load(predict_file)

    result = []
    for idx, ann in enumerate(predicts):
        try:
            # NOTE(review): the label is passed before the answer. If
            # nougat_per_metrics here has the (predict_root_, pred, gt)
            # signature used in eval_ocr_scene, ground truth lands in `pred`.
            # Kept for parity with reported numbers; TODO confirm.
            ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"])
        except Exception as err:
            # An explicit raise survives `python -O` (the old `assert False`
            # did not) and no longer swallows KeyboardInterrupt.
            raise AssertionError("ERROR!!! Check your output!!! (record %d)" % idx) from err
        if ans:
            result.append(ans)

    mean_dict = {"eval question num": len(result)}
    if result:
        for key in result[0]:
            mean_dict[key] = 0
        for each in result:
            for key, value in each.items():
                mean_dict[key] += value
        for key in result[0]:
            mean_dict[key] /= len(result)
    print(json.dumps(mean_dict, indent=4))
# Plain-text OCR scoring of the merged predictions (results_final.json, the
# file written by GOT.eval.pyevaltools.merge_results).
_merged_results = args.out_path + "/results_final.json"
doc_text_eval(args.gt_path, _merged_results, args.datatype)
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_format.py
================================================
import json
# from doctextVQAeval import VQAEval
import argparse
# import fitz as pymupdf
import nltk
from nltk.metrics import precision, recall, f_measure
import numpy as np
import jieba
# import megfile as mf
import pickle
import pandas as pd
import re
# from loguru import logger
# nltk.download('wordnet')
from nltk.translate import meteor_score
# from marker_scoring import score_text
# from utils import contain_chinese_string
# Command-line interface, parsed at import time because this module is a script.
parser = argparse.ArgumentParser()
for _flag in ("--out_path", "--gt_path", "--datatype"):
    parser.add_argument(_flag, type=str, required=True)
del _flag
args = parser.parse_args()
def preprocess(text, predict_root_):
    """Strip InternVL prompt/footer boilerplate from a prediction string.

    If ``predict_root_`` contains 'InternVL', keep only the text after the
    first "All words in the image:\\n" and before the first
    "[UNUSED_TOKEN_145]". This raises IndexError if the leading marker is
    missing. Otherwise ``text`` is returned unchanged.
    """
    if 'InternVL' not in predict_root_:
        return text
    body = text.split("All words in the image:\n")[1]
    return body.split("[UNUSED_TOKEN_145]")[0]
def contain_chinese_string(text):
    """Return True if ``text`` contains at least one CJK Unified Ideograph.

    The class spans the whole basic CJK Unified Ideographs block
    (U+4E00-U+9FFF). The former upper bound U+9FA5 missed the ideographs
    encoded at U+9FA6-U+9FFF.
    """
    # Match any character in the CJK Unified Ideographs block.
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
    return bool(chinese_pattern.search(text))
inline_reg = re.compile(r"\\\((.*?)(?= 1:
# try:
metrics["meteor"] = meteor_score.meteor_score([reference], hypothesis)
# except LookupError:
# metrics["meteor"] = np.nan
reference = set(reference)
hypothesis = set(hypothesis)
metrics["f_measure"] = f_measure(reference, hypothesis)
if heavy_mode >= 1:
metrics["precision"] = precision(reference, hypothesis)
metrics["recall"] = recall(reference, hypothesis)
if heavy_mode == 2:
# 速度太慢
metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
return metrics
def doc_formated_text_eval(gt_root_, predict_root_, datatype):
    """Score formatted OCR predictions separately on text, math and table content.

    ``predict_root_`` is a JSON file holding a list of records with ``label``
    and ``answer`` fields. ``split_text`` splits both fields into text, math
    and table parts. Each aligned pair is scored with ``nougat_per_metrics``,
    skipping empty results. The per-category mean of every metric is printed
    as JSON.

    A category with no scored pairs is reported as ``{}``; previously this
    crashed with ZeroDivisionError, or IndexError when there were no text
    pairs. ``gt_root_`` and ``datatype`` are accepted for CLI symmetry but
    unused.
    """
    def _mean_metrics(results):
        # Average every metric over a list of per-sample dicts; {} when empty.
        if not results:
            return {}
        totals = {key: 0 for key in results[0]}
        for each in results:
            for key, value in each.items():
                totals[key] += value
        return {key: total / len(results) for key, total in totals.items()}

    with open(predict_root_, encoding='utf-8') as predict_file:
        predicts = json.load(predict_file)
    gt_text_split, gt_math_split, gt_table_split = split_text(predicts, 'label')
    pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer')

    text_results, math_results, table_results = [], [], []
    for gt0, pre0, gt1, pre1, gt2, pre2 in zip(gt_text_split, pre_text_split, gt_math_split, pre_math_split, gt_table_split, pre_table_split):
        # Same call order as before: label part first, then answer part.
        for gts, pres, bucket in ((gt0, pre0, text_results),
                                  (gt1, pre1, math_results),
                                  (gt2, pre2, table_results)):
            ans = nougat_per_metrics(predict_root_, gts, pres)
            if ans:
                bucket.append(ans)

    mean_dict = {
        "eval question num": len(text_results),
        'text': _mean_metrics(text_results),
        'math': _mean_metrics(math_results),
        'table': _mean_metrics(table_results),
    }
    print(json.dumps(mean_dict, indent=4))
def doc_text_eval(gt_root_, predict_root_, datatype):
    """Score plain OCR predictions and print the mean of every metric as JSON.

    ``predict_root_`` is a JSON file holding a list of records with ``label``
    and ``answer`` fields. Each pair is scored with ``nougat_per_metrics``,
    and pairs that yield no metrics are skipped. The printed dict holds the
    number of scored pairs under ``"eval question num"`` plus one mean per
    metric; only the count is printed when nothing could be scored.

    Raises:
        AssertionError: if scoring a record fails (e.g. a missing field). The
            underlying exception is chained as ``__cause__``.

    ``gt_root_`` and ``datatype`` are accepted for CLI symmetry but unused.
    """
    with open(predict_root_, encoding='utf-8') as predict_file:
        predicts = json.load(predict_file)

    result = []
    for idx, ann in enumerate(predicts):
        try:
            # NOTE(review): the label is passed before the answer. If
            # nougat_per_metrics here has the (predict_root_, pred, gt)
            # signature used in eval_ocr_scene, ground truth lands in `pred`.
            # Kept for parity with reported numbers; TODO confirm.
            ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"])
        except Exception as err:
            # An explicit raise survives `python -O` (the old `assert False`
            # did not) and no longer swallows KeyboardInterrupt.
            raise AssertionError("ERROR!!! Check your output!!! (record %d)" % idx) from err
        if ans:
            result.append(ans)

    mean_dict = {"eval question num": len(result)}
    if result:
        for key in result[0]:
            mean_dict[key] = 0
        for each in result:
            for key, value in each.items():
                mean_dict[key] += value
        for key in result[0]:
            mean_dict[key] /= len(result)
    print(json.dumps(mean_dict, indent=4))
# Formatted-OCR scoring (text / math / table) of the merged predictions
# (results_final.json, the file written by GOT.eval.pyevaltools.merge_results).
_merged_results = args.out_path + "/results_final.json"
doc_formated_text_eval(args.gt_path, _merged_results, args.datatype)
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_scene.py
================================================
import json
import argparse
import nltk
from nltk.metrics import precision, recall, f_measure
import numpy as np
import jieba
# import megfile as mf
import pickle
import pandas as pd
import re
from nltk.translate import meteor_score
# Command-line interface, parsed at import time because this module is a script.
parser = argparse.ArgumentParser()
for _flag in ("--out_path", "--gt_path", "--datatype"):
    parser.add_argument(_flag, type=str, required=True)
del _flag
args = parser.parse_args()
def preprocess(text, predict_root_):
    """Strip InternVL prompt/footer boilerplate from a prediction string.

    If ``predict_root_`` contains 'InternVL', keep only the text after the
    first "All words in the image:\\n" and before the first
    "[UNUSED_TOKEN_145]". This raises IndexError if the leading marker is
    missing. Otherwise ``text`` is returned unchanged.
    """
    if 'InternVL' not in predict_root_:
        return text
    body = text.split("All words in the image:\n")[1]
    return body.split("[UNUSED_TOKEN_145]")[0]
def contain_chinese_string(text):
    """Return True if ``text`` contains at least one CJK Unified Ideograph.

    The class spans the whole basic CJK Unified Ideographs block
    (U+4E00-U+9FFF). The former upper bound U+9FA5 missed the ideographs
    encoded at U+9FA6-U+9FFF.
    """
    chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
    return bool(chinese_pattern.search(text))
def nougat_per_metrics(predict_root_, pred, gt, minlen=1):
    """Character-level similarity metrics between ``pred`` and ``gt``.

    Returns ``{}`` when either string is shorter than ``minlen``. Otherwise
    returns a dict with, in order:
    - ``bleu`` and ``meteor`` over the character sequences, with ``gt`` as
      the reference;
    - ``f_measure``, ``precision`` and ``recall`` over the character sets;
    - ``edit_dist``: Levenshtein distance divided by the longer length.

    ``predict_root_`` is unused.
    """
    if min(len(pred), len(gt)) < minlen:
        return {}

    ref_chars = list(gt)
    hyp_chars = list(pred)
    metrics = {
        "bleu": nltk.translate.bleu([ref_chars], hyp_chars),
        "meteor": meteor_score.meteor_score([ref_chars], hyp_chars),
    }

    ref_set, hyp_set = set(ref_chars), set(hyp_chars)
    metrics["f_measure"] = f_measure(ref_set, hyp_set)
    metrics["precision"] = precision(ref_set, hyp_set)
    metrics["recall"] = recall(ref_set, hyp_set)
    metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
    return metrics
def doc_text_eval(gt_root_, predict_root_, datatype):
    """Score scene-text OCR predictions and print the mean of every metric as JSON.

    ``predict_root_`` is a JSON file holding a list of records with ``label``
    and ``answer`` fields. Each pair is scored with ``nougat_per_metrics``,
    and pairs that yield no metrics are skipped. The printed dict holds the
    number of scored pairs under ``"eval question num"`` plus one mean per
    metric; only the count is printed when nothing could be scored.

    Raises:
        AssertionError: if scoring a record fails (e.g. a missing field). The
            underlying exception is chained as ``__cause__``.

    ``gt_root_`` and ``datatype`` are accepted for CLI symmetry but unused.
    """
    with open(predict_root_, encoding='utf-8') as predict_file:
        predicts = json.load(predict_file)

    result = []
    for idx, ann in enumerate(predicts):
        try:
            # NOTE(review): the label binds to nougat_per_metrics' `pred`
            # parameter and the answer to `gt`. Its BLEU/METEOR reference is
            # therefore the model output, and precision/recall are swapped
            # relative to their usual meaning. Kept for parity with reported
            # numbers; confirm before changing.
            ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"])
        except Exception as err:
            # An explicit raise survives `python -O` (the old `assert False`
            # did not) and no longer swallows KeyboardInterrupt.
            raise AssertionError("ERROR!!! Check your output!!! (record %d)" % idx) from err
        if ans:
            result.append(ans)

    mean_dict = {"eval question num": len(result)}
    if result:
        for key in result[0]:
            mean_dict[key] = 0
        for each in result:
            for key, value in each.items():
                mean_dict[key] += value
        for key in result[0]:
            mean_dict[key] /= len(result)
    print(json.dumps(mean_dict, indent=4))
# Scene-text OCR scoring of the merged predictions (results_final.json).
_merged_results = args.out_path + "/results_final.json"
doc_text_eval(args.gt_path, _merged_results, args.datatype)
================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/merge_results.py
================================================
import os
import json
import argparse
def merge_outputs(out_path):
    """Concatenate the per-chunk ``results_<k>.json`` files in ``out_path``.

    Only files named ``results_<integer>.json`` are merged, in ascending
    chunk order. This means a ``results_final.json`` left by an earlier run,
    or any unrelated file, is never folded into the new merge. Previously
    every directory entry was merged, so re-runs duplicated records and
    non-JSON files crashed the merge.

    The merged list is written to ``<out_path>/results_final.json`` and also
    returned.
    """
    prefix, suffix = "results_", ".json"

    def _chunk_index(name):
        # Integer k for "results_<k>.json"; None for any other file name.
        if not (name.startswith(prefix) and name.endswith(suffix)):
            return None
        stem = name[len(prefix):len(name) - len(suffix)]
        return int(stem) if stem.isdecimal() else None

    indexed = []
    for name in os.listdir(out_path):
        idx = _chunk_index(name)
        if idx is not None:
            indexed.append((idx, name))
    indexed.sort()

    alist = []
    for _, name in indexed:
        with open(os.path.join(out_path, name), encoding='utf-8') as chunk_file:
            alist += json.load(chunk_file)

    filename = os.path.join(out_path, "results_final.json")
    with open(filename, 'w', encoding="utf-8") as file_obj:
        json.dump(alist, file_obj, ensure_ascii=False, indent=1)
    return alist
# CLI: python -m GOT.eval.pyevaltools.merge_results --out_path <dir>
parser = argparse.ArgumentParser()
parser.add_argument("--out_path", type=str, required=True)
args = parser.parse_args()
merge_outputs(out_path=args.out_path)
================================================
FILE: GOT-OCR-2.0-master/GOT/model/GOT_ocr_2_0.py
================================================
from transformers import AutoConfig, AutoModelForCausalLM, \
Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \
CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from typing import List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from GOT.utils.constants import *
from GOT.model.vision_encoder.vary_b import build_vary_vit_b
from GOT.model.plug.blip_process import BlipImageEvalProcessor
class GOTConfig(Qwen2Config):
    """Qwen2 configuration subclass whose ``model_type`` is ``"GOT"``."""

    model_type = "GOT"
class GOTQwenModel(Qwen2Model):
    """Qwen2 decoder whose image-placeholder embeddings are replaced by vision features.

    ``vision_tower_high`` (built by ``build_vary_vit_b``) encodes each image.
    ``mm_projector_vary`` (a 1024 -> 1024 linear layer) projects the encoded
    features. The projected features overwrite the embeddings that sit
    between each image-start / image-end token pair.
    """

    config_class = GOTConfig

    def __init__(self, config: Qwen2Config):
        super(GOTQwenModel, self).__init__(config)

        self.vision_tower_high = build_vary_vit_b()

        self.mm_projector_vary = nn.Linear(1024, 1024)

    def initialize_vision_modules(
        self,
        vision_tower,
        pretrained_stage1_model=None,
        freeze_vision_tower=False,
        use_im_start_end=False,
        vision_select_layer=-1,
        dtype=torch.float16,
        device="cuda"
    ):
        """Move the vision modules to ``dtype``/``device`` and record vision settings.

        ``pretrained_stage1_model`` and ``use_im_start_end`` are ignored;
        ``config.use_im_start_end`` is always set to True.

        Returns:
            dict with ``image_processor`` and ``image_processor_high`` (both
            1024x1024 BlipImageEvalProcessor) and ``image_token_len`` (256).
        """
        # Vary old codes, not use in GOT
        image_processor = BlipImageEvalProcessor(image_size=1024)
        # 1024*1024
        image_processor_high = BlipImageEvalProcessor(image_size=1024)

        self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
        self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)

        image_token_len = 256

        self.config.vision_tower = vision_tower
        self.config.image_token_len = image_token_len
        self.config.use_im_start_end = True
        self.config.vision_select_layer = vision_select_layer
        self.config.freeze_vision_tower = freeze_vision_tower

        return dict(
            image_processor=image_processor,
            image_processor_high=image_processor_high,
            image_token_len=image_token_len,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed tokens, splice in image features, then run the Qwen2 decoder.

        ``images`` is a list with one entry per batch sample. Only
        ``image[1]`` of each entry is encoded, and it must be 4-D
        (P, C, H, W). When P > 1 the views are encoded one at a time and
        concatenated along the token axis.

        Images are spliced only when ``input_ids`` is given and either has
        more than one position or the model is training. Single-token decode
        steps reuse the cache.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            with torch.no_grad():
                self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower_high', None)

        # input_ids is None when the caller passes inputs_embeds directly. There
        # are then no placeholder tokens to locate, so skip the splice instead
        # of failing on input_ids.shape.
        if (vision_tower_high is not None and images is not None and input_ids is not None
                and (input_ids.shape[1] != 1 or self.training)):
            # Placeholder / start / end token ids (hard-coded; config values
            # are not consulted).
            im_patch_token = 151859
            im_start_token = 151857
            im_end_token = 151858

            image_features = []
            for image in images:
                P, C, H, W = image[1].shape
                if P == 1:
                    with torch.set_grad_enabled(False):
                        cnn_feature = vision_tower_high(image[1])
                        cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
                    image_feature = self.mm_projector_vary(cnn_feature)
                    image_features.append(image_feature)
                else:
                    image_patches = torch.unbind(image[1])
                    image_patches_features = []
                    for image_patch in image_patches:
                        image_p = torch.stack([image_patch])
                        with torch.set_grad_enabled(False):
                            cnn_feature_p = vision_tower_high(image_p)
                            cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
                        image_feature_p = self.mm_projector_vary(cnn_feature_p)
                        image_patches_features.append(image_feature_p)
                    # Concatenate the per-view features along the token axis.
                    image_feature = torch.cat(image_patches_features, dim=1)
                    image_features.append(image_feature)

            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)

            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
                if (cur_input_ids == im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                    raise ValueError("The number of image start tokens and image end tokens should be the same.")

                image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                    per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                    num_patches = per_cur_image_features.shape[0]

                    # Exactly num_patches placeholders must sit between start and end.
                    if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                        raise ValueError("The image end token should follow the image start token.")

                    # Keep [.. start], insert the features, keep [end ..].
                    cur_input_embeds = torch.cat(
                        (
                            cur_input_embeds[:image_start_token_pos+1],
                            per_cur_image_features,
                            cur_input_embeds[image_start_token_pos + num_patches + 1:]
                        ),
                        dim=0
                    )

                new_input_embeds.append(cur_input_embeds)

            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        return super(GOTQwenModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
class GOTQwenForCausalLM(Qwen2ForCausalLM):
config_class = GOTConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
    """Build a GOTQwenModel backbone plus a bias-free linear LM head."""
    # super(Qwen2ForCausalLM, self) skips Qwen2ForCausalLM.__init__ itself and
    # runs the next __init__ in the MRO. Presumably this avoids building a
    # stock Qwen2Model that would immediately be replaced below — confirm.
    super(Qwen2ForCausalLM, self).__init__(config)
    self.model = GOTQwenModel(config)

    self.vocab_size = config.vocab_size
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
def get_model(self):
    """Return the GOTQwenModel backbone stored as ``self.model``."""
    return self.model
# def _set_gradient_checkpointing(self, module, value=False):
# if isinstance(module, GOTQwenModel):
# module.gradient_checkpointing = value
# @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
# @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the GOT backbone and LM head; optionally compute the next-token loss.

    Unset ``output_attentions``, ``output_hidden_states`` and ``return_dict``
    fall back to the config. Logits are upcast to float32. When ``labels`` is
    given, the loss is cross-entropy between logits at position t and labels
    at t + 1, using the default ignore index of CrossEntropyLoss.

    Returns:
        ``CausalLMOutputWithPast``, or a tuple ``(loss?, logits, *rest)`` when
        ``return_dict`` is false.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    backbone_out = self.model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        images=images,
        return_dict=return_dict,
    )

    logits = self.lm_head(backbone_out[0]).float()

    loss = None
    if labels is not None:
        # Position t predicts token t + 1: drop the last logit and the first label.
        pred_logits = logits[..., :-1, :].contiguous().view(-1, cfg.vocab_size)
        target_ids = labels[..., 1:].contiguous().view(-1)
        # Labels may live on a different device under model parallelism.
        target_ids = target_ids.to(pred_logits.device)
        loss = CrossEntropyLoss()(pred_logits, target_ids)

    if not return_dict:
        tuple_out = (logits,) + backbone_out[1:]
        return tuple_out if loss is None else (loss,) + tuple_out

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        hidden_states=backbone_out.hidden_states,
        attentions=backbone_out.attentions,
    )
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the keyword arguments for one ``forward`` call during generation.

    When a key/value cache is present, tokens already covered by the cache are
    dropped from ``input_ids`` and the attention mask may be cropped to the cache
    capacity. Position ids are derived from ``attention_mask`` when not supplied,
    and ``images`` is forwarded from ``kwargs`` so the vision path sees it.

    Args:
        input_ids: token ids accumulated so far (indexed as [batch, seq]).
        past_key_values: either a ``Cache`` object or a legacy per-layer tuple
            whose ``[0][0]`` entry has the cached length at dim 2.
        attention_mask: optional mask; its length may exceed ``input_ids``.
        inputs_embeds: optional embeddings, used only when there is no cache.
        **kwargs: may carry ``position_ids``, ``use_cache`` and ``images``.

    Returns:
        dict of model inputs passed on to ``forward``.
    """
    # Omit tokens covered by past_key_values
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            # NOTE(review): ``seen_tokens`` / ``get_max_length`` come from the
            # transformers Cache API of the version this was written against;
            # newer releases deprecate them — confirm the pinned version.
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy tuple cache: cached length read from the first layer's key tensor.
            cache_length = past_length = past_key_values[0][0].shape[2]
            max_cache_length = None
        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]
    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        # (running count of attended tokens minus one; masked positions get a dummy value of 1)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            # With a cache only the positions of the tokens still being fed are needed.
            position_ids = position_ids[:, -input_ids.shape[1] :]
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # Forward the image batch so forward() can run the vision path.
            "images": kwargs.get("images", None),
        }
    )
    return model_inputs
def initialize_vision_tokenizer(
    self,
    tokenizer,
    freeze_lm_model=False,
    pretrained_stage1_model=None,
    device="cuda"
):
    """Resize the LM embeddings to ``tokenizer`` and record the image-token ids on the config.

    The tokenizer itself is not modified here, so the embedding table only needs
    to be resized once to ``len(tokenizer)``. Image start/end markers are always
    enabled (``config.use_im_start_end = True``).

    Args:
        tokenizer: tokenizer whose length the embedding table is resized to.
        freeze_lm_model: accepted for API compatibility; currently unused.
        pretrained_stage1_model: accepted for API compatibility; currently unused.
        device: accepted for API compatibility; currently unused.
    """
    config = self.get_model().config
    # A single resize suffices: nothing adds tokens between what used to be two
    # identical resize calls, so the second one was a no-op.
    self.resize_token_embeddings(len(tokenizer))
    # NOTE(review): ids are hard-coded instead of being looked up with
    # tokenizer.convert_tokens_to_ids; this assumes the tokenizer already holds the
    # image patch/start/end tokens at exactly these ids — confirm per tokenizer.
    config.im_patch_token = 151859
    config.use_im_start_end = True
    config.im_start_token, config.im_end_token = 151857, 151858
# Register the GOT config under the "GOT" model type and map it to
# GOTQwenForCausalLM, so the transformers Auto* factories can build this model.
# Runs once at import time (module-level side effect).
AutoConfig.register("GOT", GOTConfig)
AutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM)
================================================
FILE: GOT-OCR-2.0-master/GOT/model/__init__.py
================================================
from .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig
================================================
FILE: GOT-OCR-2.0-master/GOT/model/plug/blip_process.py
================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import cv2
import numpy as np
import torch
# from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
class BaseProcessor:
    """Callable wrapper that applies ``self.transform`` to its argument.

    The base transform is the identity; subclasses install a real pipeline by
    assigning ``self.transform``.
    """

    def __init__(self):
        # Identity until a subclass replaces it.
        self.transform = lambda item: item

    def __call__(self, item):
        """Run the configured transform on ``item`` and return its result."""
        transform = self.transform
        return transform(item)
class BlipImageBaseProcessor(BaseProcessor):
    """Base for BLIP image processors: owns the ``transforms.Normalize`` step.

    Args:
        mean: per-channel mean for normalization; defaults to the tuple below.
        std: per-channel standard deviation; defaults to the tuple below.
    """

    def __init__(self, mean=None, std=None):
        # Initialize the base class so ``self.transform`` always exists (identity
        # until a subclass installs its pipeline). Previously a bare instance had
        # no ``transform`` and calling it raised AttributeError.
        super().__init__()
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = transforms.Normalize(mean, std)
## aug functions
def identity_func(img):
    """No-op augmentation: hand ``img`` back untouched."""
    unchanged = img
    return unchanged
def autocontrast_func(img, cutoff=0):
    """
    same output as PIL.ImageOps.autocontrast

    Each channel is linearly stretched so that its (optionally clipped) intensity
    range maps onto [0, 255].

    Args:
        img: image array split into channels with ``cv2.split``.
        cutoff: percentage of pixels ignored at each end of the histogram.

    Returns:
        Channel-merged array produced via a per-channel uint8 lookup table.
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            # Convert to Python ints: for an unsigned (uint8) channel, ``-low`` on a
            # NumPy scalar wraps around (e.g. -10 -> 246), which made ``offset``
            # positive and pushed the whole lookup table to 255.
            high, low = int(ch.max()), int(ch.min())
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else int(low[0, 0])
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - int(high[0, 0])
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
def equalize_func(img):
    """Histogram-equalize each channel, reproducing PIL.ImageOps.equalize.

    PIL's algorithm differs from ``cv2.equalizeHist``, so it is re-implemented
    here with a per-channel lookup table.
    """
    n_bins = 256

    def _equalize_channel(channel):
        hist = cv2.calcHist([channel], [0], None, [n_bins], [0, n_bins])
        occupied = hist[hist != 0].reshape(-1)
        step = np.sum(occupied[:-1]) // (n_bins - 1)
        # Degenerate histogram (e.g. a single intensity): leave the channel as is.
        if step == 0:
            return channel
        shifted = np.empty_like(hist)
        shifted[0] = step // 2
        shifted[1:] = hist[:-1]
        lut = (np.cumsum(shifted) // step).clip(0, 255).astype(np.uint8)
        return lut[channel]

    return cv2.merge([_equalize_channel(c) for c in cv2.split(img)])
def rotate_func(img, degree, fill=(0, 0, 0)):
    """Rotate ``img`` about its centre by ``degree`` degrees (degrees, like PIL).

    The output keeps the input size; uncovered areas are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)
def solarize_func(img, thresh=128):
    """
    same output as PIL.ImageOps.solarize

    Every pixel value >= ``thresh`` is inverted (v -> 255 - v); values below
    ``thresh`` are kept.

    Args:
        img: integer array with values in [0, 255] (used to index a 256-entry table).
        thresh: inversion threshold; 256 or more leaves the image unchanged.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    levels = np.arange(256)
    table = np.where(levels < thresh, levels, 255 - levels)
    table = table.clip(0, 255).astype(np.uint8)
    return table[img]
def color_func(img, factor):
    """Adjust colour saturation, following PIL.ImageEnhance.Color.

    Each output channel is ``factor * channel + (1 - factor) * gray``, with gray the
    weighted sum of the three input channels, done as a single 3x3 matmul.

    NOTE(review): the weights are ordered (0.114, 0.587, 0.299), i.e. they assume
    BGR channel order; RandomAugment(isPIL=True) passes RGB arrays — confirm.
    """
    # Closed form of eye(3) * factor + w[:, None] * (1 - factor), where
    # w = (0.114, 0.587, 0.299); faster than blending with an explicit gray image.
    saturation = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    luma = np.float32([[0.114], [0.587], [0.299]])
    mix = saturation * factor + luma
    blended = np.matmul(img, mix)
    return blended.clip(0, 255).astype(np.uint8)
def contrast_func(img, factor):
    """Scale intensities around the image's mean gray level, following PIL.ImageEnhance.Contrast.

    The mean uses channel weights (0.114, 0.587, 0.299), i.e. BGR order.
    """
    weights = np.array([0.114, 0.587, 0.299])
    channel_means = np.mean(img, axis=(0, 1))
    mean = np.sum(channel_means * weights)
    lut = np.array([(level - mean) * factor + mean for level in range(256)])
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
def brightness_func(img, factor):
    """
    Approximates PIL.ImageEnhance.Brightness: every value is multiplied by
    ``factor`` (a blend with a black image), then clipped to [0, 255].

    Values are truncated when cast to uint8, not rounded.

    Args:
        img: integer array with values in [0, 255] (used to index a 256-entry table).
        factor: brightness multiplier; 0 gives black, 1 leaves the image unchanged.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    return table[img]
def sharpness_func(img, factor):
    """
    Adjust sharpness, following PIL.ImageEnhance.Sharpness.

    ``img`` is blended with a copy smoothed by the 3x3 kernel
    [[1, 1, 1], [1, 5, 1], [1, 1, 1]] / 13:
    factor 0 returns the smoothed image, 1 returns ``img`` itself, and other
    values interpolate (or extrapolate, for factor > 1) on the interior only.
    The differences from PIL are all on the 4 boundaries; the centre areas match.

    Args:
        img: (H, W, C) image array (the interior is indexed with three axes).
        factor: sharpness factor.

    Returns:
        uint8 array with the same shape as ``img`` (for factor 1, ``img`` itself).
    """
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        return degenerate
    if factor == 1.0:
        return img
    out = img.astype(np.float32)
    inner = degenerate.astype(np.float32)[1:-1, 1:-1, :]
    out[1:-1, 1:-1, :] = inner + factor * (out[1:-1, 1:-1, :] - inner)
    # Extrapolation (factor > 1) can leave [0, 255]; clip before the uint8 cast so
    # values saturate instead of wrapping around.
    return out.clip(0, 255).astype(np.uint8)
def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Shear ``img`` horizontally with the affine matrix [[1, factor, 0], [0, 1, 0]].

    The output keeps the input size; uncovered areas are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    shear = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, shear, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def translate_x_func(img, offset, fill=(0, 0, 0)):
    """Shift ``img`` along x with the affine matrix [[1, 0, -offset], [0, 1, 0]].

    Intended to match PIL.Image.transform; uncovered areas are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, -offset], [0, 1, 0]])
    moved = cv2.warpAffine(
        img, shift, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return moved.astype(np.uint8)
def translate_y_func(img, offset, fill=(0, 0, 0)):
    """Shift ``img`` along y with the affine matrix [[1, 0, 0], [0, 1, -offset]].

    Intended to match PIL.Image.transform; uncovered areas are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, 0], [0, 1, -offset]])
    moved = cv2.warpAffine(
        img, shift, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return moved.astype(np.uint8)
def posterize_func(img, bits):
    """
    same output as PIL.ImageOps.posterize

    Keeps only the ``bits`` most significant bits of every value.

    Args:
        img: uint8 array.
        bits: number of high bits to keep (0..8; 0 yields all zeros).

    Returns:
        uint8 array with the same shape as ``img``.
    """
    # Trim the mask to 8 bits in Python-int space before converting: for
    # bits < 8, ``255 << (8 - bits)`` exceeds 255, and np.uint8() of an
    # out-of-range Python int raises OverflowError on NumPy >= 2.
    mask = np.uint8((0xFF << (8 - bits)) & 0xFF)
    return np.bitwise_and(img, mask)
def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Shear ``img`` vertically with the affine matrix [[1, 0, 0], [factor, 1, 0]].

    The output keeps the input size; uncovered areas are filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    shear = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(
        img, shear, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """Paint a square of side about ``pad_size`` at a random centre with ``replace``.

    The square is clipped to the image bounds. A modified copy is returned and
    ``img`` is left untouched.
    """
    fill = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    row_frac, col_frac = np.random.random(2)
    half = pad_size // 2
    center_row, center_col = int(row_frac * height), int(col_frac * width)
    row_lo, row_hi = max(center_row - half, 0), min(center_row + half, height)
    col_lo, col_hi = max(center_col - half, 0), min(center_col + half, width)
    patched = img.copy()
    patched[row_lo:row_hi, col_lo:col_hi, :] = fill
    return patched
### level to args
def enhance_level_to_args(MAX_LEVEL):
    """Build a mapper from a magnitude level to ``(factor,)``, factor = level/MAX_LEVEL * 1.8 + 0.1."""
    def _to_args(level):
        factor = (level / MAX_LEVEL) * 1.8 + 0.1
        return (factor,)
    return _to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(shear, replace_value)``.

    |shear| = level / MAX_LEVEL * 0.3, negated with probability about 0.5.
    """
    def _to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        flip = np.random.random() > 0.5
        return (-magnitude if flip else magnitude, replace_value)
    return _to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(offset, replace_value)``.

    |offset| = level / MAX_LEVEL * translate_const, negated with probability about 0.5.
    """
    limit = float(translate_const)

    def _to_args(level):
        magnitude = (level / MAX_LEVEL) * limit
        flip = np.random.random() > 0.5
        return (-magnitude if flip else magnitude, replace_value)
    return _to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(pad_size, replace_value)``, pad_size = int(level/MAX_LEVEL * cutout_const)."""
    def _to_args(level):
        pad_size = int((level / MAX_LEVEL) * cutout_const)
        return (pad_size, replace_value)
    return _to_args
def solarize_level_to_args(MAX_LEVEL):
    """Build a mapper from a level to ``(threshold,)``, threshold = int(level/MAX_LEVEL * 256)."""
    def _to_args(level):
        threshold = int((level / MAX_LEVEL) * 256)
        return (threshold,)
    return _to_args
def none_level_to_args(level):
    """Mapper for ops that take no extra arguments: always an empty tuple."""
    return tuple()
def posterize_level_to_args(MAX_LEVEL):
    """Build a mapper from a level to ``(bits,)``, bits = int(level/MAX_LEVEL * 4)."""
    def _to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits,)
    return _to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(degrees, replace_value)``.

    |degrees| = level / MAX_LEVEL * 30, negated with probability about 0.5.
    """
    def _to_args(level):
        degrees = (level / MAX_LEVEL) * 30
        flip = np.random.random() < 0.5
        return (-degrees if flip else degrees, replace_value)
    return _to_args
# Augmentation name -> implementation. Each implementation takes the image first,
# followed by the extra positional arguments from the same-named arg_dict entry.
func_dict = dict(
    Identity=identity_func,
    AutoContrast=autocontrast_func,
    Equalize=equalize_func,
    Rotate=rotate_func,
    Solarize=solarize_func,
    Color=color_func,
    Contrast=contrast_func,
    Brightness=brightness_func,
    Sharpness=sharpness_func,
    ShearX=shear_x_func,
    TranslateX=translate_x_func,
    TranslateY=translate_y_func,
    Posterize=posterize_func,
    ShearY=shear_y_func,
)

# Shared magnitude settings used to build the argument mappers below.
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)

# Augmentation name -> mapper from a magnitude level to the extra arguments for
# func_dict[name]. Key order matters: the augmenters default to list(arg_dict.keys()).
arg_dict = dict(
    Identity=none_level_to_args,
    AutoContrast=none_level_to_args,
    Equalize=none_level_to_args,
    Rotate=rotate_level_to_args(MAX_LEVEL, replace_value),
    Solarize=solarize_level_to_args(MAX_LEVEL),
    Color=enhance_level_to_args(MAX_LEVEL),
    Contrast=enhance_level_to_args(MAX_LEVEL),
    Brightness=enhance_level_to_args(MAX_LEVEL),
    Sharpness=enhance_level_to_args(MAX_LEVEL),
    ShearX=shear_level_to_args(MAX_LEVEL, replace_value),
    TranslateX=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    TranslateY=translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    Posterize=posterize_level_to_args(MAX_LEVEL),
    ShearY=shear_level_to_args(MAX_LEVEL, replace_value),
)
class RandomAugment(object):
    """RandAugment-style augmenter: apply ``N`` randomly sampled ops at magnitude ``M``.

    Args:
        N: number of ops sampled (with replacement) per call.
        M: magnitude level passed to each op's ``arg_dict`` mapper.
        isPIL: if True, the input is converted with ``np.array`` before augmenting.
        augs: op names (keys of ``arg_dict``/``func_dict``) to sample from;
            ``None`` or empty selects every op in ``arg_dict``.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # ``None`` replaces the former mutable ``[]`` default; any falsy value
        # still selects all registered ops, so existing callers are unaffected.
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Sample ``N`` op names; each is paired with apply-probability 0.5 and level ``M``."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        """Apply the sampled ops in order; each is skipped when a uniform draw exceeds its probability.

        Returns a NumPy array when ``isPIL`` is set (the input is converted first).
        """
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
class VideoRandomAugment(object):
    """RandAugment for clips: one set of ops is sampled per call and applied to every frame.

    Args:
        N: number of distinct ops sampled (without replacement) per call.
        M: magnitude level passed to each op's ``arg_dict`` mapper.
        p: a sampled op is applied when its uniform draw exceeds ``p``.
        tensor_in_tensor_out: if True, ``frames`` is a torch tensor that is
            converted to a uint8 NumPy array before augmenting.
        augs: op names to sample from; ``None`` or empty selects every op in ``arg_dict``.
    """

    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=None):
        self.N = N
        self.M = M
        self.p = p
        self.tensor_in_tensor_out = tensor_in_tensor_out
        # ``None`` replaces the former mutable ``[]`` default; any falsy value
        # still selects all registered ops.
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Sample ``N`` distinct op names, each paired with level ``M``."""
        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
        return [(op, self.M) for op in sampled_ops]

    def __call__(self, frames):
        """Augment a clip of shape (T, H, W, 3) and return a float tensor of the same shape."""
        assert (
            frames.shape[-1] == 3
        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
        if self.tensor_in_tensor_out:
            frames = frames.numpy().astype(np.uint8)
        num_frames = frames.shape[0]
        # One op list and one apply-mask are replicated for every frame, so the
        # whole clip receives the same augmentation.
        ops = num_frames * [self.get_random_ops()]
        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
        frames = torch.stack(
            list(map(self._aug, frames, ops, apply_or_not)), dim=0
        ).float()
        return frames

    def _aug(self, img, ops, apply_or_not):
        """Apply the enabled ops to one frame and return it as a torch tensor."""
        for i, (name, level) in enumerate(ops):
            if not apply_or_not[i]:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return torch.from_numpy(img)
# if __name__ == "__main__":
# a = RandomAugment()
# img = np.random.randn(32, 32, 3)
# a(img)
class BlipImageTrainProcessor(BlipImageBaseProcessor):
    """Training pipeline: random resized crop, light RandAugment, ToTensor, normalization.

    Args:
        image_size: output side length of the random resized crop.
        mean: per-channel mean for normalization (``None`` uses the base default).
        std: per-channel std for normalization (``None`` uses the base default).
        min_scale: lower bound of the crop area fraction.
        max_scale: upper bound of the crop area fraction.
    """

    def __init__(
        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)
        crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        # Two ops at magnitude 5 drawn from photometric ops only. Horizontal flip,
        # AutoContrast, ShearX/ShearY, TranslateX/TranslateY and Rotate are not enabled.
        augment = RandomAugment(
            2,
            5,
            isPIL=True,
            augs=["Identity", "Brightness", "Sharpness", "Equalize"],
        )
        self.transform = transforms.Compose(
            [crop, augment, transforms.ToTensor(), self.normalize]
        )

    def __call__(self, item):
        return self.transform(item)
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation pipeline: bicubic resize to ``image_size`` x ``image_size``, ToTensor, normalization."""

    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)
        steps = [
            transforms.Resize(
                (image_size, image_size), interpolation=InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        return self.transform(item)
# if __name__ == "__main__":
# a = BlipImageTrainProcessor(image_size=1024)
# # img = np.random.randn(1024, 1024, 3)
# # x = torch.zeros(1024, 1024, 3)
# x = Image.open("/data/codes/GOT-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB")
# print(x.size)
# y = a(x)
# print(y.size())
================================================
FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/__init__.py
================================================
================================================
FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/vary_b.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
import torch
import torch.nn as nn
from typing import Type
# from GOT.model.vision_encoder.vitg_qwen import Resampler
import math
class Projector(nn.Module):
    """Attention-pool ``width``-dim features, then LayerNorm and a learned square projection.

    forward: ``Resampler`` pooling -> LayerNorm(output_dim) -> ``x @ proj`` with
    ``proj`` an (output_dim, output_dim) parameter.

    NOTE(review): ``Resampler`` is neither defined nor imported in this module (its
    import is commented out), so constructing a Projector raises NameError as
    written — confirm whether this class is still used.
    """
    def __init__(
        self,
        width: int,  # NOTE(review): was annotated ``width: 256`` with no default; 256 may have been meant as the default — confirm
        n_queries: int = 256,  # grid_size is int(sqrt(n_queries)), so a perfect square is presumably expected
        output_dim: int = 4096,
        **kwargs  # accepted and ignored
    ):
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_pool = Resampler(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
        )
        self.ln_post = norm_layer(output_dim)
        # Square projection initialized with std output_dim ** -0.5.
        self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
    def forward(self, x: torch.Tensor):
        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj
        return x
class MLPBlock(nn.Module):
    """Feed-forward block: Linear(embedding_dim -> mlp_dim), activation, Linear(mlp_dim -> embedding_dim)."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Creation order (lin1, lin2, act) is kept so random initialization and
        # state_dict keys are unchanged.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """LayerNorm over the channel axis (dim 1) of a channels-first tensor, with per-channel affine."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        # Broadcast the per-channel scale/shift over the trailing spatial dims.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normed + shift
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """ViTDet-style image encoder followed by a conv neck and two stride-2 downsampling convs.

    Pipeline: patch embedding -> (+ absolute position embedding) -> transformer
    blocks -> 1x1/3x3 conv neck with LayerNorm2d -> ``net_2`` (512 ch) -> ``net_3`` (1024 ch).
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Output channels of the neck, which are also the input channels of ``net_2``.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes attend globally (window 0).
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )
        # Each stride-2 conv halves the spatial size (rounding up). The input width
        # now follows ``out_chans``; it used to be hard-coded to 256, which broke
        # encoders built with any other ``out_chans``. Attribute names are unchanged,
        # so existing checkpoints (out_chans=256) still load.
        self.net_2 = nn.Conv2d(out_chans, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: (B, in_chans, H, W) images. With ``use_abs_pos`` the patch grid must
                match ``img_size // patch_size``, since ``pos_embed`` is added as is.

        Returns:
            (B, 1024, H', W') feature map, where H'/W' is the patch grid after two
            stride-2 convolutions.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        # Channels-last (B, H, W, C) -> channels-first for the conv neck.
        x = self.neck(x.permute(0, 3, 1, 2))
        x = self.net_2(x)
        x = self.net_3(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block (residual attention + residual MLP) with optional window attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): MLP hidden size as a multiple of ``dim``.
            qkv_bias (bool): Learnable bias on the query/key/value projection.
            norm_layer (nn.Module): Normalization layer class.
            act_layer (nn.Module): Activation class used inside the MLP.
            use_rel_pos (bool): Add relative positional embeddings to attention.
            rel_pos_zero_init (bool): Forwarded to ``Attention``.
            window_size (int): Window side for window attention; 0 means global attention.
            input_size (tuple(int, int) or None): Token grid size used to size the
                relative-position tables when attention is global.
        """
        super().__init__()
        # Relative position tables cover what one attention call sees: the whole
        # grid for global attention, otherwise a single window.
        attn_extent = input_size if window_size == 0 else (window_size, window_size)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_extent,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        h = self.norm1(x)
        windowed = self.window_size > 0
        if windowed:
            orig_hw = (h.shape[1], h.shape[2])
            h, padded_hw = window_partition(h, self.window_size)
        h = self.attn(h)
        if windowed:
            h = window_unpartition(h, self.window_size, padded_hw, orig_hw)
        x = residual + h
        return x + self.mlp(self.norm2(x))
class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) token grid, with optional decomposed relative positions."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): Learnable bias on the fused qkv projection.
            use_rel_pos (bool): Add decomposed relative positional embeddings.
            rel_pos_zero_init (bool): Accepted but unused here; the tables are always zero-initialized.
            input_size (tuple(int, int) or None): Grid size the relative tables are
                built for; required when ``use_rel_pos`` is True.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One table entry per possible offset along each axis: 2 * size - 1.
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        tokens = H * W
        # (3, B, heads, tokens, head_dim)
        qkv = self.qkv(x).reshape(B, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v each (B * heads, tokens, head_dim)
        q, k, v = qkv.reshape(3, B * self.num_heads, tokens, -1).unbind(0)
        scores = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            scores = add_decomposed_rel_pos(scores, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        probs = scores.softmax(dim=-1)
        merged = (probs @ v).view(B, self.num_heads, H, W, -1)
        merged = merged.permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(merged)
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a [B, H, W, C] grid into non-overlapping square windows, zero-padding
    the bottom/right edges up to a multiple of ``window_size``.

    Returns:
        windows: [B * num_windows, window_size, window_size, C], row-major over windows.
        (Hp, Wp): padded height and width before partitioning.
    """
    B, H, W, C = x.shape
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    grid = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Inverse of ``window_partition``: reassemble windows into a [B, H, W, C] grid
    and crop away the padding.

    Args:
        windows (tensor): [B * num_windows, window_size, window_size, C].
        window_size (int): window side length.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    """
    Hp, Wp = pad_hw
    H, W = hw
    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image
    grid = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    padded = Hp > H or Wp > W
    return x[:, :H, :W, :].contiguous() if padded else x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Gather relative positional embeddings for every (query, key) position pair.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C); linearly resized
            along L when L != 2 * max(q_size, k_size) - 1.

    Returns:
        Tensor of shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        table = resized.reshape(-1, max_rel_dist).permute(1, 0)
    # When q and k differ in length, stretch the shorter axis so offsets line up.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    offsets = (q_coords - k_coords) + (k_size - 1) * k_scale
    return table[offsets.long()]
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Add decomposed (height + width) relative positional terms to an attention map,
    as in MViTv2.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (Tensor): attention map, viewable as (B, q_h, q_w, k_h, k_w).
        q (Tensor): query with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): height-axis embeddings (Lh, C).
        rel_pos_w (Tensor): width-axis embeddings (Lw, C).
        q_size (Tuple): (q_h, q_w).
        k_size (Tuple): (k_h, k_w).

    Returns:
        attn (Tensor): (B, q_h * q_w, k_h * k_w) with the positional terms added.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)
    B, _, dim = q.shape
    q_grid = q.reshape(B, q_h, q_w, dim)
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w)
    # bias_h varies along key rows, bias_w along key columns.
    grid = attn.view(B, q_h, q_w, k_h, k_w) + bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)
    return grid.view(B, q_h * q_w, k_h * k_w)
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding: a strided Conv2d whose output is returned channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): projection kernel size.
            stride (Tuple): projection stride.
            padding (Tuple): projection padding.
            in_chans (int): input image channels.
            embed_dim (int): patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> conv -> (B, H', W', embed_dim)
        return self.proj(x).permute(0, 2, 3, 1)
def build_vary_vit_b(checkpoint=None):
    """Build the ViT-B sized Vary encoder: 768-dim, 12 layers, 12 heads, global attention in blocks 2, 5, 8, 11.

    ``checkpoint`` is passed through to ``_build_vary``.
    """
    vit_b_config = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_vary(checkpoint=checkpoint, **vit_b_config)
def _build_vary(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Construct an ``ImageEncoderViT`` for 1024x1024 inputs with 16x16 patches.

    Fixed settings: 256-channel neck, relative positions enabled, window size 14
    outside the global-attention blocks, LayerNorm with eps 1e-6.

    Args:
        encoder_embed_dim: transformer width.
        encoder_depth: number of transformer blocks.
        encoder_num_heads: attention heads per block.
        encoder_global_attn_indexes: indexes of the blocks using global attention.
        checkpoint: accepted for API compatibility but currently ignored; the
            encoder is returned with its fresh initialization.

    Returns:
        The constructed ``ImageEncoderViT``.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    # NOTE: loading from ``checkpoint`` is disabled. The former (commented-out)
    # loader did torch.load(checkpoint), kept entries whose key contains
    # 'vision_tower' and stripped the 19-char 'model.vision_tower.' prefix before
    # load_state_dict(..., strict=True). Load weights externally if needed.
    return image_encoder
if __name__ == '__main__':
    # Smoke test: run two blank 1024x1024 RGB images through the encoder and
    # print the shape of the flattened token sequence.
    dummy_images = torch.zeros(2, 3, 1024, 1024)
    net = build_vary_vit_b(checkpoint ='/mnt/shared-storage/tenant/hypertext/xpkong/jycode/checkpoint/pytorch_model.bin')
    feature_map = net(dummy_images)
    # (B, C, H, W) -> (B, H * W, C)
    tokens = feature_map.flatten(2).permute(0, 2, 1)
    print(tokens.shape)
================================================
FILE: GOT-OCR-2.0-master/GOT/train/train.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import torch
import transformers
# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import GOTLlamaForCausalLM
from GOT.model import *
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import smart_tokenizer_and_embedding_resize
from GOT.model.vision_encoder.sam import build_sam_vit_b
from GOT.model.vision_encoder.swin_transformer import build_swin_transformer
# Training entry point for the LLaMA-based GOT model.
#
# Steps:
# - Parse (ModelArguments, DataArguments, TrainingArguments) from the command line.
# - Load GOTLlamaForCausalLM and a slow (use_fast=False) tokenizer, then
#   initialize the vision modules.
# - Overwrite the vision towers with a transformers CLIPVisionModel and a
#   build_sam_vit_b model.
# - Set which parameters are trainable and build the supervised data module.
# - Train with GOTTrainer (imported from GOT.train.trainer_vit_fixlr),
#   resuming when output_dir already contains checkpoint-* folders.
#
# NOTE(review): the tokenizer, the CLIP weights and the SAM checkpoint all
# come from hard-coded cluster paths, so this only runs where they exist.
# NOTE(review): this file imports GOT.model.vision_encoder.sam and
# GOT.model.vision_encoder.swin_transformer, neither of which appears in the
# repository listing. Confirm the script is still importable before use.
def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model = GOTLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)
# NOTE(review): the tokenizer is loaded from a fixed path, not from
# model_args.model_name_or_path, so the two can silently diverge.
tokenizer = transformers.AutoTokenizer.from_pretrained(
'/data/hypertext/xpkong/newcode/checkpoints/kly-vary-1025-cc595-pretrain/',
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,)
# # model = AutoModelForCausalLM.from_pretrained("/data/public/ucaswei/cache/Qwen/qwen/", device_map="cuda", trust_remote_code=True).eval()
# model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda')
# Padding / special tokens. This applies to the "v0" conversation format or
# the decapoda LLaMA-7B checkpoint:
# - if the tokenizer has no pad token, add DEFAULT_PAD_TOKEN and resize the
#   embeddings via smart_tokenizer_and_embedding_resize;
# - if the model path mentions "llama", also register the default
#   eos/bos/unk tokens.
# Otherwise (presumably every other format) the unk token is used as padding.
if data_args.conversation_version == "v0" or "models--decapoda-research--llama-7b-hf" in model_args.model_name_or_path:
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
if "llama" in model_args.model_name_or_path:
tokenizer.add_special_tokens({
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
})
else:
tokenizer.pad_token = tokenizer.unk_token
# tokenizer.pad_token = DEFAULT_UNK_TOKEN
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.add_special_tokens({'pad_token':'<|endoftext|>'})
# Training dtype: float32 by default, float16 if fp16 is set, bfloat16 if
# bf16 is set (bf16 is checked last, so it wins when both are set).
dtype = torch.float32
if training_args.fp16:
dtype = torch.float16
if training_args.bf16:
dtype = torch.bfloat16
# Build the vision modules. The returned dict provides the
# 'image_processor' / 'image_processor_high' entries used by the dataset below.
vision_tower_dict = model.get_model().initialize_vision_modules(
vision_tower=model_args.vision_tower,
pretrained_stage1_model=model_args.pretrained_stage1_model,
freeze_vision_tower=model_args.freeze_vision_tower,
use_im_start_end=model_args.use_im_start_end,
vision_select_layer=model_args.vision_select_layer,
dtype=dtype,
device=training_args.device
)
# Presumably registers the image-related special tokens with the tokenizer
# (implementation lives in GOT.model; confirm there).
model.initialize_vision_tokenizer(
tokenizer=tokenizer,
freeze_lm_model=model_args.freeze_lm_model,
pretrained_stage1_model=model_args.pretrained_stage1_model,
device=training_args.device,
)
# Overwrite both vision towers:
# - vision_tower: a CLIPVisionModel from the local 'vit-large-patch14' directory;
# - vision_tower_high: a build_sam_vit_b model from a fixed checkpoint file.
model.get_model().vision_tower = transformers.CLIPVisionModel.from_pretrained(
'/data/public/ucaswei/pretrain/vit-large-patch14')
model.get_model().vision_tower_high = build_sam_vit_b(checkpoint='/data/hypertext/xpkong/newcode/checkpoints/kly-sam-opt-all-1023-new/pytorch_model.bin')
# model.get_model().mm_projector = create_perciever()
# Cast and move the whole model, including the towers just assigned, to the
# training dtype and device.
model.to(dtype=dtype, device=training_args.device)
# 'image_processor_high
# data_args.image_token_len = vision_tower_dict['image_token_len']
# Dataset-side settings. image_token_len is fixed at 256 here rather than
# taken from vision_tower_dict (see the commented-out line above).
data_args.image_token_len = 256
data_args.image_processor = vision_tower_dict['image_processor']
data_args.image_processor_high = vision_tower_dict['image_processor_high']
data_args.use_im_start_end = model_args.use_im_start_end
# mixed relation, to be fixed
# With freeze_lm_model: freeze every parameter, then re-enable gradients for
# mm_projector and the input embeddings.
if model_args.freeze_lm_model:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
# for p in model.get_model().vision_encoder.parameters():
# p.requires_grad = True
# for p in model.get_model().chatt.parameters():
# p.requires_grad = True
for p in model.get_input_embeddings().parameters():
p.requires_grad = True
# conv_final
# for p in model.get_model().conv_final.parameters():
# p.requires_grad = True
# Make the vision tower trainable when freeze_vision_tower is False.
if not model_args.freeze_vision_tower:
model.get_model().vision_tower.requires_grad_(True)
# for i in range(20):
# model.get_model().vision_tower.vision_model.encoder.layers[i].requires_grad_(False)
# model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False)
# model.get_model().vision_tower.vision_model.embeddings.requires_grad_(False)
# model.get_model().vision_tower.vision_model.pre_layrnorm.requires_grad_(False)
# model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False)
# for p in model.get_model().vision_encoder.parameters():
# p.requires_grad = True
# for n, p in model.named_parameters():
# print(n, p.requires_grad)
# Freeze the vision tower when freeze_vision_tower is True.
if model_args.freeze_vision_tower:
model.get_model().vision_tower.requires_grad_(False)
# Report the trainable parameter count in units of 2**20 (printed as "M").
params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")
# params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
# if len(params_no_grad) > 0:
# if training_args.fsdp is not None and len(training_args.fsdp) > 0:
# if len(params_no_grad) < 10:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
# else:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
# print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
# print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
# from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
# def patch_FSDP_use_orig_params(func):
# def wrap_func(*args, **kwargs):
# use_orig_params = kwargs.pop('use_orig_params', True)
# return func(*args, **kwargs, use_orig_params=use_orig_params)
# return wrap_func
# FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
# The returned dict is unpacked into GOTTrainer(...). It presumably holds
# train_dataset / data_collator entries -- confirm in GOT.data.
data_module = make_supervised_data_module(
interleave=training_args.interleave,
with_box=training_args.with_box,
tokenizer=tokenizer,
data_args=data_args
)
trainer = GOTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
# Resume when output_dir already contains checkpoint-* folders.
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
# Persist the trainer state and the final weights.
trainer.save_state()
trainer._safe_save(output_dir=training_args.output_dir)
# Script entry point.
if __name__ == "__main__":
train()
================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_GOT.py
================================================
# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import torch
# torch.set_num_threads(1)
import transformers
# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import *
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import smart_tokenizer_and_embedding_resize
from GOT.model.vision_encoder.vary_b import build_vary_vit_b
import os
# os.environ['NCCL_IB_DISABLE'] = '1'
# Default NCCL logging level and OSS endpoint for this training entry point.
# setdefault keeps any value the launching environment already exported,
# instead of silently overwriting it at import time.
os.environ.setdefault('NCCL_DEBUG', 'INFO')
os.environ.setdefault('OSS_ENDPOINT', "http://oss.i.shaipower.com")
# Training entry point for GOT on the Qwen backbone (GOTQwenForCausalLM).
#
# Steps:
# - Parse (ModelArguments, DataArguments, TrainingArguments).
# - Load the tokenizer and model from model_args.model_name_or_path.
# - Set '<|endoftext|>' as the pad token, resizing the embeddings.
# - Initialize the vision modules and configure trainable parameters.
# - Build the supervised data module and train with GOTTrainer from
#   GOT.train.trainer_vit_fixlr, resuming when output_dir already contains
#   checkpoint-* folders.
def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,)
model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, use_safetensors=True)
# Register '<|endoftext|>' as the pad token and resize the model embeddings to match.
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token='<|endoftext|>'),
tokenizer=tokenizer,
model=model,
)
# Training dtype: float32 by default, float16 if fp16 is set, bfloat16 if
# bf16 is set (bf16 is checked last, so it wins when both are set).
dtype = torch.float32
if training_args.fp16:
dtype = torch.float16
if training_args.bf16:
dtype = torch.bfloat16
# Build the vision modules. The returned dict provides the
# 'image_processor' / 'image_processor_high' entries used by the dataset below.
vision_tower_dict = model.get_model().initialize_vision_modules(
vision_tower=model_args.vision_tower,
pretrained_stage1_model=model_args.pretrained_stage1_model,
freeze_vision_tower=model_args.freeze_vision_tower,
use_im_start_end=model_args.use_im_start_end,
vision_select_layer=model_args.vision_select_layer,
dtype=dtype,
device=training_args.device
)
# Presumably registers the image-related special tokens with the tokenizer
# (implementation lives in GOT.model; confirm there).
model.initialize_vision_tokenizer(
tokenizer=tokenizer,
freeze_lm_model=model_args.freeze_lm_model,
pretrained_stage1_model=model_args.pretrained_stage1_model,
device=training_args.device,
)
# Cast and move the whole model to the training dtype and device.
model.to(dtype=dtype, device=training_args.device)
# 'image_processor_high
# data_args.image_token_len = vision_tower_dict['image_token_len']
# Dataset-side settings. image_token_len is fixed at 256 here rather than
# taken from vision_tower_dict (see the commented-out line above).
data_args.image_token_len = 256
data_args.image_processor = vision_tower_dict['image_processor']
data_args.image_processor_high = vision_tower_dict['image_processor_high']
data_args.use_im_start_end = model_args.use_im_start_end
# mixed relation, to be fixed
# With freeze_lm_model: freeze every parameter, then re-enable gradients for
# mm_projector, mm_projector_vary and the input embeddings.
if model_args.freeze_lm_model:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
for p in model.get_model().mm_projector_vary.parameters():
p.requires_grad = True
for p in model.get_input_embeddings().parameters():
p.requires_grad = True
# Report the trainable parameter count in units of 2**20 (printed as "M").
params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")
# params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
# if len(params_no_grad) > 0:
# if training_args.fsdp is not None and len(training_args.fsdp) > 0:
# if len(params_no_grad) < 10:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
# else:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
# print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
# print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
# from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
# def patch_FSDP_use_orig_params(func):
# def wrap_func(*args, **kwargs):
# use_orig_params = kwargs.pop('use_orig_params', True)
# return func(*args, **kwargs, use_orig_params=use_orig_params)
# return wrap_func
# FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
# The returned dict is unpacked into GOTTrainer(...). It presumably holds
# train_dataset / data_collator entries -- confirm in GOT.data.
data_module = make_supervised_data_module(
interleave=training_args.interleave,
with_box=training_args.with_box,
tokenizer=tokenizer,
data_args=data_args
)
trainer = GOTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
# Resume when output_dir already contains checkpoint-* folders.
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
# Persist the trainer state and the final weights.
trainer.save_state()
trainer._safe_save(output_dir=training_args.output_dir)
# Script entry point.
if __name__ == "__main__":
train()
================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_flash_attn.py
================================================
# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
# Statement order matters: the flash-attention monkey patch is applied
# before GOT.train.train (and the model code it pulls in) is imported.
# NOTE(review): the repository listing has no
# GOT/utils/llama_flash_attn_monkey_patch.py, so this import presumably fails
# unless that module is added -- verify.
from GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from GOT.train.train import train
# Script entry point: run the regular training loop with the patch in place.
if __name__ == "__main__":
train()
================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_lora.py
================================================
# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import torch
import transformers
# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import GOTLlamaForCausalLM
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import *
# def find_all_linear_names(model):
# cls = torch.nn.Linear
# lora_module_names = set()
# for name, module in model.named_modules():
# if isinstance(module, cls):
# names = name.split('.')
# lora_module_names.add(names[0] if len(names) == 1 else names[-1])
# if 'lm_head' in lora_module_names: # needed for 16-bit
# lora_module_names.remove('lm_head')
# return list(lora_module_names)
# LoRA-capable training entry point (Qwen backbone).
#
# Steps:
# - Load a tokenizer from a fixed Qwen-chat path and GOTQwenForCausalLM from
#   model_args.model_name_or_path, then add DEFAULT_PAD_TOKEN.
# - Optionally wrap the model with PEFT LoRA adapters (training_args.lora_enable).
# - Initialize the vision modules, then overwrite vision_tower / mm_projector
#   with create_clip_vit_g(448) / create_perciever().
# - Configure trainable parameters and train with GOTTrainer.
# - Save either the LoRA + non-LoRA weights or the full model.
#
# NOTE(review): GOTQwenForCausalLM, find_all_linear_names, create_clip_vit_g,
# create_perciever, get_peft_state_maybe_zero_3,
# get_peft_state_non_lora_maybe_zero_3 and os are not imported by name in
# this file (the local find_all_linear_names above is commented out). They
# presumably come from `from GOT.utils.utils import *` -- verify, otherwise
# these lines raise NameError.
def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# model = GOTLlamaForCausalLM.from_pretrained(
# model_args.model_name_or_path,
# cache_dir=training_args.cache_dir,
# )
# tokenizer = transformers.AutoTokenizer.from_pretrained(
# model_args.model_name_or_path,
# cache_dir=training_args.cache_dir,
# model_max_length=training_args.model_max_length,
# padding_side="right",
# use_fast=False,
# )
# NOTE(review): the tokenizer path is hard-coded and independent of
# model_args.model_name_or_path.
tokenizer = transformers.AutoTokenizer.from_pretrained("/data/public/ucaswei/cache/Qwen/qwen-chat/", trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,)
# # model = AutoModelForCausalLM.from_pretrained("/data/public/ucaswei/cache/Qwen/qwen/", device_map="cuda", trust_remote_code=True).eval()
model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda')
# Register DEFAULT_PAD_TOKEN as the pad token and resize the embeddings to match.
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
# if data_args.conversation_version == "v0" or "models--decapoda-research--llama-7b-hf" in model_args.model_name_or_path:
# if tokenizer.pad_token is None:
# smart_tokenizer_and_embedding_resize(
# special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
# tokenizer=tokenizer,
# model=model,
# )
# if "llama" in model_args.model_name_or_path:
# tokenizer.add_special_tokens({
# "eos_token": DEFAULT_EOS_TOKEN,
# "bos_token": DEFAULT_BOS_TOKEN,
# "unk_token": DEFAULT_UNK_TOKEN,
# })
# else:
# tokenizer.pad_token = tokenizer.unk_token
# Training dtype: float32 by default, float16 if fp16 is set, bfloat16 if
# bf16 is set (bf16 is checked last, so it wins when both are set).
dtype = torch.float32
if training_args.fp16:
dtype = torch.float16
if training_args.bf16:
dtype = torch.bfloat16
# Optionally wrap the model with LoRA adapters on the modules selected by
# find_all_linear_names.
# NOTE(review): this runs before initialize_vision_modules, so the later
# model.get_model() / initialize_vision_tokenizer calls go through the PEFT
# wrapper -- confirm it forwards these attributes.
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
logging.warning("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
# Build the vision modules. The returned dict provides image_token_len and
# the image processors used by the dataset below.
vision_tower_dict = model.get_model().initialize_vision_modules(
vision_tower=model_args.vision_tower,
pretrained_stage1_model=model_args.pretrained_stage1_model,
freeze_vision_tower=model_args.freeze_vision_tower,
use_im_start_end=model_args.use_im_start_end,
vision_select_layer=model_args.vision_select_layer,
dtype=dtype,
device=training_args.device
)
model.initialize_vision_tokenizer(
tokenizer=tokenizer,
freeze_lm_model=model_args.freeze_lm_model,
pretrained_stage1_model=model_args.pretrained_stage1_model,
device=training_args.device,
)
# Overwrite the vision tower and projector built above, then cast/move the
# whole model to the training dtype and device.
model.get_model().vision_tower = create_clip_vit_g(448)
model.get_model().mm_projector = create_perciever()
model.to(dtype=dtype, device=training_args.device)
data_args.image_token_len = vision_tower_dict['image_token_len']
data_args.image_processor = vision_tower_dict['image_processor']
data_args.image_processor_high = vision_tower_dict['image_processor_high']
data_args.use_im_start_end = model_args.use_im_start_end
# mixed relation, to be fixed
# With freeze_lm_model: freeze every parameter, then re-enable gradients for
# mm_projector, the input embeddings, conv_final and vision_encoder.
if model_args.freeze_lm_model:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
p.requires_grad = True
for p in model.get_input_embeddings().parameters():
p.requires_grad = True
for p in model.get_model().conv_final.parameters():
p.requires_grad = True
for p in model.get_model().vision_encoder.parameters():
p.requires_grad = True
# Make the vision tower trainable when freeze_vision_tower is False.
if not model_args.freeze_vision_tower:
model.get_model().vision_tower.requires_grad_(True)
# for i in range(20):
# model.get_model().vision_tower.vision_model.encoder.layers[i].requires_grad_(False)
# These calls disable gradients for the last vision encoder layer and for
# post_layernorm. This assumes the tower exposes .vision_model like a CLIP
# model -- verify for create_clip_vit_g.
model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False)
# model.get_model().vision_tower.vision_model.embeddings.requires_grad_(False)
# model.get_model().vision_tower.vision_model.pre_layrnorm.requires_grad_(False)
model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False)
# Debug dump: every parameter name with its requires_grad flag.
for n, p in model.named_parameters():
print(n, p.requires_grad)
# Report the trainable parameter count in units of 2**20 (printed as "M").
params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")
# params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
# if len(params_no_grad) > 0:
# if training_args.fsdp is not None and len(training_args.fsdp) > 0:
# if len(params_no_grad) < 10:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
# else:
# print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
# print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
# print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
# from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
# def patch_FSDP_use_orig_params(func):
# def wrap_func(*args, **kwargs):
# use_orig_params = kwargs.pop('use_orig_params', True)
# return func(*args, **kwargs, use_orig_params=use_orig_params)
# return wrap_func
# FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
data_module = make_supervised_data_module(
interleave=training_args.interleave,
with_box=training_args.with_box,
tokenizer=tokenizer,
data_args=data_args
)
trainer = GOTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
# Resume when output_dir already contains checkpoint-* folders.
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
# Saving:
# - With LoRA: collect the adapter weights and the non-LoRA trainable weights
#   via the *_maybe_zero_3 helpers (presumably gathering ZeRO-3 partitioned
#   params). Only local_rank 0 or -1 writes the config, the adapter weights
#   and non_lora_trainables.bin.
# - Without LoRA: save the full model via _safe_save.
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(
model.named_parameters(), training_args.lora_bias
)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
model.named_parameters()
)
if training_args.local_rank == 0 or training_args.local_rank == -1:
model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
else:
trainer._safe_save(output_dir=training_args.output_dir)
# Script entry point.
if __name__ == "__main__":
train()
================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_lora_flash_attn.py
================================================
# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
# Statement order matters: the flash-attention monkey patch is applied
# before the LoRA training module (and the model code it pulls in) is imported.
# NOTE(review): the repository listing has no
# GOT/utils/llama_flash_attn_monkey_patch.py, so this import presumably fails
# unless that module is added -- verify.
from GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
# from GOT.train.train import train
from GOT.train.train_lora import train
# Script entry point: run the LoRA training loop with the patch in place.
if __name__ == "__main__":
train()
================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer.py
================================================
import os
import torch
import torch.nn as nn
from transformers import Trainer
from typing import Dict, Optional, Sequence
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Peel every wrapper layer off ``model``.

    Distributed-training containers expose the wrapped network as
    ``.module``. This follows that attribute until it reaches an object
    without one.

    Args:
        model (`torch.nn.Module`): The possibly-wrapped model.

    Returns:
        The innermost object reached through ``.module`` attributes.
    """
    # Guard clause: nothing left to unwrap.
    if not hasattr(model, "module"):
        return model
    return unwrap_model(model.module)
class GOTTrainer(Trainer):
    """Hugging Face ``Trainer`` with GOT-specific checkpoint saving."""

    def _safe_save(self, output_dir: str):
        """Collect the state dict and dump it to ``output_dir``.

        With DeepSpeed enabled, saving is delegated to ``save_model`` after a
        CUDA synchronize. Otherwise the state dict is copied to CPU, and only
        the process for which ``args.should_save`` is true writes it.
        """
        if self.deepspeed:
            torch.cuda.synchronize()
            self.save_model(output_dir)
            return

        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model; when tuning the adapter, also dump its weights.

        If ``args.tune_mm_mlp_adapter`` is set, every state-dict entry whose
        key contains ``mm_projector``, ``embed_tokens`` or ``embed_in`` is
        written with ``torch.save``:

        * for a ``checkpoint-*`` directory, to
          ``<parent>/mm_projector/<checkpoint-*>.bin``;
        * otherwise, to ``<output_dir>/mm_projector.bin``.

        In both cases the full model is then saved by ``Trainer._save``.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
            state_dict: Optional pre-collected state dict (e.g. from
                ``_safe_save``); collected from the unwrapped model if None.
        """
        # Resolve the default up front: the path handling below needs a real string.
        if output_dir is None:
            output_dir = self.args.output_dir

        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # normpath drops a trailing separator, so basename/dirname see the
            # actual folder name (split('/')[-1] returned '' for "ckpt/").
            normalized_dir = os.path.normpath(output_dir)
            current_folder = os.path.basename(normalized_dir)
            parent_folder = os.path.dirname(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                # Trainer._save creates output_dir only after this point.
                os.makedirs(output_dir, exist_ok=True)
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))

        super(GOTTrainer, self)._save(output_dir, state_dict)
================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer_llm_llrd.py
================================================
import os
import torch
import torch.nn as nn
import time
import functools
import re
from transformers import Trainer
from transformers.trainer_pt_utils import (
get_module_class_from_name,
get_parameter_names,
)
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_neuroncore_available,
)
from transformers.trainer_utils import (
FSDPOption,
ShardedDDPOption,
)
from transformers.training_args import ParallelMode
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from typing import Dict, Optional, Sequence
def lr_scale_func(key, num_layers=32, layer_decay=0.86):
    """Return the learning-rate multiplier for the parameter named ``key``.

    Rules, checked in this order (plain substring tests):

    * ``embed_tokens.weight`` -> 0
    * ``mm_projector`` or ``vision_tower`` -> 0.01
    * ``norm.weight`` or ``lm_head.weight`` -> 1 (this also matches per-layer
      ``*layernorm.weight`` parameters)
    * ``...layers.<i>....`` -> ``layer_decay ** (num_layers - i - 1)``. The
      last layer gets 1; earlier layers get geometrically smaller values.

    Args:
        key: Parameter name as produced by ``named_parameters()``.
        num_layers: Number of decoder layers (default 32, the previously
            hard-coded value).
        layer_decay: Per-layer decay factor (default 0.86).

    Returns:
        The multiplier (int or float).

    Raises:
        ValueError: if ``key`` matches none of the rules above.
    """
    if "embed_tokens.weight" in key:
        return 0
    if "mm_projector" in key or "vision_tower" in key:
        return 0.01
    if "norm.weight" in key or "lm_head.weight" in key:
        return 1
    # Raw string: the old non-raw f-string held invalid escape sequences.
    match = re.search(r"layers\.(\d+)\.", key)
    if match is None:
        raise ValueError(f"lr_scale_func: cannot determine a layer index for parameter {key!r}")
    layer_index = int(match.group(1))
    return layer_decay ** (num_layers - layer_index - 1)
def get_param_groups(model, no_weight_decay_cond, scale_lr_cond):
    """Bucket the trainable parameters of ``model`` into optimizer groups.

    Each parameter is classified along two axes:

    * weight decay: skipped when ``no_weight_decay_cond(name, param)`` is
      truthy, or, when that callable is None, for ``.bias`` parameters and
      1-D tensors;
    * learning-rate scale: ``scale_lr_cond(name)`` gives ``lr_mult``, and any
      value other than 1 puts the parameter into a per-multiplier bucket. Each
      multiplier is printed together with the parameter name. With
      ``scale_lr_cond`` None, nothing is scaled.

    Returns:
        A list of dicts with keys ``params``, ``wd_mult`` and ``lr_mult``, in
        this order:

        1. decayed and unscaled;
        2. decayed and scaled, one group per multiplier;
        3. undecayed and unscaled;
        4. undecayed and scaled, one group per multiplier.

        Empty buckets are omitted.
    """
    decayed_plain = []
    decayed_scaled = {}
    undecayed_plain = []
    undecayed_scaled = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if no_weight_decay_cond is None:
            # Default rule: never regularize biases or 1-D (norm) tensors.
            skip_decay = name.endswith(".bias") or len(param.shape) == 1
        else:
            skip_decay = no_weight_decay_cond(name, param)

        if scale_lr_cond is None:
            scaled = False
        else:
            lr_mult = scale_lr_cond(name)
            print(name, lr_mult)
            scaled = lr_mult != 1

        if scaled:
            bucket = undecayed_scaled if skip_decay else decayed_scaled
            bucket.setdefault(lr_mult, []).append(param)
        elif skip_decay:
            undecayed_plain.append(param)
        else:
            decayed_plain.append(param)

    groups = []
    if decayed_plain:
        groups.append({"params": decayed_plain, "wd_mult": 1.0, "lr_mult": 1.0})
    groups.extend(
        {"params": bucket_params, "wd_mult": 1.0, "lr_mult": mult}
        for mult, bucket_params in decayed_scaled.items()
    )
    if undecayed_plain:
        groups.append({"params": undecayed_plain, "wd_mult": 0.0, "lr_mult": 1.0})
    groups.extend(
        {"params": bucket_params, "wd_mult": 0.0, "lr_mult": mult}
        for mult, bucket_params in undecayed_scaled.items()
    )
    return groups
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Return the network hidden inside any number of wrapper containers.

    A wrapper is anything that exposes a ``module`` attribute (as
    distributed-training containers do). The search recurses through it
    until no further ``module`` exists.

    Args:
        model (`torch.nn.Module`): The possibly-wrapped model.

    Returns:
        The innermost object reached through ``.module`` attributes.
    """
    if not hasattr(model, "module"):
        # Already the bare model.
        return model
    return unwrap_model(model.module)
class GOTTrainer(Trainer):
def _safe_save(self, output_dir: str):
    """Collect the state dict and dump it to ``output_dir``.

    With DeepSpeed enabled, saving is delegated to ``save_model`` after a
    CUDA synchronize, matching ``GOTTrainer._safe_save`` in
    ``GOT/train/trainer.py``. Under DeepSpeed the parameters may be
    partitioned, so a plain ``state_dict()`` is not a reliable full snapshot.
    Otherwise the state dict is copied to CPU, and only the process for which
    ``args.should_save`` is true writes it.
    """
    if self.deepspeed:
        torch.cuda.synchronize()
        self.save_model(output_dir)
        return

    state_dict = self.model.state_dict()
    if self.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        self._save(output_dir, state_dict=cpu_state_dict)  # noqa
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """Save the model; when tuning the adapter, also dump its weights.

    If ``args.tune_mm_mlp_adapter`` is set, every state-dict entry whose key
    contains ``mm_projector``, ``embed_tokens`` or ``embed_in`` is written
    with ``torch.save``:

    * for a ``checkpoint-*`` directory, to
      ``<parent>/mm_projector/<checkpoint-*>.bin``;
    * otherwise, to ``<output_dir>/mm_projector.bin``.

    In both cases the full model is then saved by ``Trainer._save``.

    Args:
        output_dir: Target directory; defaults to ``self.args.output_dir``.
        state_dict: Optional pre-collected state dict (e.g. from
            ``_safe_save``); collected from the unwrapped model if None.
    """
    # Resolve the default up front: the path handling below needs a real string.
    if output_dir is None:
        output_dir = self.args.output_dir

    if getattr(self.args, 'tune_mm_mlp_adapter', False):
        _state_dict = state_dict
        if _state_dict is None:
            # Only save the model itself if we are using distributed training
            model_to_save = unwrap_model(self.model)
            _state_dict = model_to_save.state_dict()

        keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
        weight_to_save = {
            k: v for k, v in _state_dict.items()
            if any(key_match in k for key_match in keys_to_match)
        }

        # normpath drops a trailing separator, so basename/dirname see the
        # actual folder name (split('/')[-1] returned '' for "ckpt/").
        normalized_dir = os.path.normpath(output_dir)
        current_folder = os.path.basename(normalized_dir)
        parent_folder = os.path.dirname(normalized_dir)
        if current_folder.startswith('checkpoint-'):
            mm_projector_folder = os.path.join(parent_folder, "mm_projector")
            os.makedirs(mm_projector_folder, exist_ok=True)
            torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
        else:
            # Trainer._save creates output_dir only after this point.
            os.makedirs(output_dir, exist_ok=True)
            torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))

    super(GOTTrainer, self)._save(output_dir, state_dict)
def create_optimizer(self):
    """
    Set up the optimizer with layer-wise learning-rate decay.

    Trainable parameters are grouped by ``get_param_groups`` using
    ``lr_scale_func``. Each group carries ``lr_mult`` / ``wd_mult`` factors.
    torch optimizers ignore unknown param-group keys, so those factors are
    turned into concrete per-group values here:

    * ``lr = base_lr * lr_mult``
    * ``weight_decay = args.weight_decay * wd_mult``

    Without this conversion every group would train at the base learning
    rate with the optimizer's default weight decay. The ``*_mult`` keys are
    kept on the groups for anything else that reads them.

    An already-existing ``self.optimizer`` (e.g. passed via ``optimizers``)
    is returned unchanged.
    """
    opt_model = self.model
    if self.optimizer is None:
        optimizer_grouped_parameters = get_param_groups(opt_model, None, lr_scale_func)
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        base_lr = optimizer_kwargs.get("lr", self.args.learning_rate)
        for group in optimizer_grouped_parameters:
            group["lr"] = base_lr * group["lr_mult"]
            group["weight_decay"] = self.args.weight_decay * group["wd_mult"]
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    return self.optimizer
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.use_ipex:
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
model = self.ipex_optimize_model(model, training, dtype=dtype)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
return self.model_wrapped
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
# already initialized its own DDP and AMP
if self.deepspeed:
return self.deepspeed
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
if unwrap_model(model) is not model:
return model
# Mixed precision training with apex (torch < 1.6)
if self.use_apex and training:
from apex import amp
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
# Multi-gpu training (should be after apex fp16 initialization)
if self.args.n_gpu > 1:
model = nn.DataParallel(model)
if self.args.jit_mode_eval:
start_time = time.time()
model = self.torch_jit_model_eval(model, dataloader, training)
self.jit_compilation_time = round(time.time() - start_time, 4)
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
if not training:
return model
# Distributed training (should be after apex fp16 initialization)
if self.sharded_ddp is not None:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.nn.wrap import auto_wrap
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16 or self.args.bf16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
model = auto_wrap(model)
self.model = model = FullyShardedDDP(
model,
mixed_precision=mixed_precision,
reshard_after_forward=zero_3,
cpu_offload=cpu_offload,
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
if not self.args.fsdp_config["xla"]:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)
auto_wrap_policy = None
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
backward_prefetch=self.backward_prefetch,
forward_prefetch=self.forword_prefetch,
limit_all_gathers=self.limit_all_gathers,
use_orig_params=True,
)
else:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
import torch_xla.core.xla_model as xm
# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss
xm.optimizer_step = patched_optimizer_step
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
)
elif self.args.local_rank != -1:
kwargs = {}
if self.args.ddp_find_unused_parameters is not None:
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
elif isinstance(model, PreTrainedModel):
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
else:
kwargs["find_unused_parameters"] = True
if self.args.ddp_bucket_cap_mb is not None:
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
if is_torch_neuroncore_available():
return model
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
**kwargs,
)
# torch.compile() needs to be called after wrapping the model with FSDP or DDP
# to ensure that it accounts for the graph breaks required by those wrappers
if self.args.torch_compile:
model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
return model
================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_fixlr.py
================================================
import os
import torch
import torch.nn as nn
from transformers import Trainer
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from typing import Dict, Optional, Sequence
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Strip every wrapper container from ``model`` and return the innermost object.

    Wrappers used for distributed training (e.g. ``nn.DataParallel`` /
    ``DistributedDataParallel``) keep the wrapped network under ``.module``.
    This follows that attribute repeatedly, so any nesting depth is handled.

    Args:
        model (`torch.nn.Module`): The possibly-wrapped model.

    Returns:
        The first object in the ``.module`` chain that has no ``module`` attribute.
    """
    inner = model
    while hasattr(inner, "module"):
        inner = inner.module
    return inner
class GOTTrainer(Trainer):
    """``Trainer`` subclass used for GOT training (fixed learning rate for all groups).

    Overrides checkpoint saving and optimizer construction:

    * saving can additionally export projector/embedding weights on their own
      (when ``args.tune_mm_mlp_adapter`` is set);
    * the optimizer uses four parameter groups that separate ``vision_encoder``
      parameters from the rest, each split into weight-decay / no-weight-decay
      subsets, all at ``args.learning_rate``.
    """

    def _safe_save(self, output_dir: str):
        """Collect the state dict, move it to CPU and dump it to disk.

        The state dict is collected unconditionally. Only when
        ``self.args.should_save`` is true is it copied to CPU and written.
        """
        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
            # Drop the reference to the original (possibly on-device) tensors before saving.
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save a checkpoint, optionally exporting the adapter weights separately.

        When ``args.tune_mm_mlp_adapter`` is set, every state-dict entry whose key
        contains ``mm_projector``, ``embed_tokens`` or ``embed_in`` is also saved on
        its own:

        * if the target folder is named ``checkpoint-*``, to
          ``<parent>/mm_projector/<checkpoint-name>.bin``;
        * otherwise to ``<output_dir>/mm_projector.bin``.

        The full checkpoint is then written by ``Trainer._save`` in every case.

        Args:
            output_dir: Target directory; ``None`` falls back to ``args.output_dir``.
            state_dict: Optional pre-collected state dict. When ``None``, the
                unwrapped model's own state dict is used for the adapter export.
        """
        if output_dir is None:
            # Same fallback Trainer._save applies; needed here before path handling.
            output_dir = self.args.output_dir
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Read weights from the underlying model, not a distributed wrapper.
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()
            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }
            # normpath strips a trailing separator so basename/dirname still see the folder.
            normalized_dir = os.path.normpath(output_dir)
            current_folder = os.path.basename(normalized_dir)
            parent_folder = os.path.dirname(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        super(GOTTrainer, self)._save(output_dir, state_dict)

    def create_optimizer(self):
        """
        Setup the optimizer.

        Trainable parameters are split into four groups, in this order:
        vision-encoder/decay, vision-encoder/no-decay, other/decay, other/no-decay.
        A parameter is "vision-encoder" when its name contains ``vision_encoder``.
        It is decayed when its name is returned by
        ``get_parameter_names(model, ALL_LAYERNORM_LAYERS)`` and does not contain
        ``bias``. Every group uses ``args.learning_rate``.

        An optimizer that already exists (e.g. one passed through the Trainer's
        ``optimizers`` argument) is returned unchanged.
        """
        opt_model = self.model
        if self.optimizer is None:
            # A set keeps the membership test O(1); with the original list, the
            # per-parameter check made the grouping quadratic in the parameter count.
            decay_parameters = {
                name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
                if "bias" not in name
            }
            # (is_vision_encoder, is_decayed) -> params, filled in named_parameters() order.
            buckets = {
                (True, True): [],
                (True, False): [],
                (False, True): [],
                (False, False): [],
            }
            for n, p in opt_model.named_parameters():
                if p.requires_grad:
                    buckets[('vision_encoder' in n, n in decay_parameters)].append(p)
            lr = self.args.learning_rate
            weight_decay = self.args.weight_decay
            optimizer_grouped_parameters = [
                {"params": buckets[(True, True)], "weight_decay": weight_decay, "lr": lr},
                {"params": buckets[(True, False)], "weight_decay": 0.0, "lr": lr},
                {"params": buckets[(False, True)], "weight_decay": weight_decay, "lr": lr},
                {"params": buckets[(False, False)], "weight_decay": 0.0, "lr": lr},
            ]
            for idx, group in enumerate(optimizer_grouped_parameters):
                print(idx, len(group['params']), group['lr'])
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer
================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_llrd.py
================================================
import os
import torch
import torch.nn as nn
import time
import functools
import re
from transformers import Trainer
from transformers.trainer_pt_utils import (
get_module_class_from_name,
get_parameter_names,
)
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_neuroncore_available,
)
from transformers.trainer_utils import (
FSDPOption,
ShardedDDPOption,
)
from transformers.training_args import ParallelMode
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from typing import Dict, Optional, Sequence
def lr_scale_func(key):
    """Return the learning-rate multiplier for the parameter named ``key``.

    * ``vision_model.encoder.layers.<i>.…`` parameters get the layer-wise decayed
      multiplier ``0.81 ** (23 - i - 1) * 0.01``. Layer 22 gets exactly ``0.01``,
      shallower layers get smaller values and layer 23 a slightly larger one.
      NOTE(review): the extra ``- 1`` shifts the reference layer to 22; confirm
      that this is intended.
    * Any other ``vision_model`` parameter gets ``0.0001``. This includes an
      encoder-layer name that has no ``layers.<i>.`` index; previously such a
      name raised IndexError.
    * Everything else gets ``1`` (no scaling).

    Used as ``scale_lr_cond`` for ``get_param_groups``.
    """
    if "vision_model.encoder.layers" in key:
        # Raw string: ``\.`` and ``\d`` are regex escapes, not (invalid) string escapes.
        match = re.search(r"layers\.(\d+)\.", key)
        if match is not None:
            in_pp_layer = int(match.group(1))
            return 0.81 ** (23 - in_pp_layer - 1) * 0.01
    if "vision_model" in key:
        return 0.0001
    return 1
def get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr, wd):
    """Split the trainable parameters of ``model`` into optimizer param groups.

    Every parameter with ``requires_grad`` lands in a bucket chosen by two tests.

    * Weight decay: ``no_weight_decay_cond(name, param)`` decides when it is
      given. Otherwise parameters whose name ends in ``.bias``, or that are 1-D,
      skip weight decay.
    * Learning-rate scale: ``scale_lr_cond(name)`` returns a multiplier when it
      is given; any value other than 1 counts as scaled, and each
      name/multiplier pair is printed. Otherwise nothing is scaled.

    Scaled parameters are grouped per multiplier. Groups are emitted in this order:

    1. decayed/unscaled;
    2. decayed/scaled (one group per multiplier, first-seen order);
    3. undecayed/unscaled;
    4. undecayed/scaled.

    Empty buckets produce no group.

    Returns:
        A list of dicts with keys ``params``, ``weight_decay`` (``wd`` or 0.0) and
        ``lr`` (``lr`` or ``lr * multiplier``).
    """
    decayed_base, decayed_by_mult = [], {}
    undecayed_base, undecayed_by_mult = [], {}

    for pname, p in model.named_parameters():
        if not p.requires_grad:
            continue

        if no_weight_decay_cond is None:
            # Biases and 1-D (e.g. norm) parameters are not regularized by default.
            skip_decay = pname.endswith(".bias") or len(p.shape) == 1
        else:
            skip_decay = no_weight_decay_cond(pname, p)

        if scale_lr_cond is None:
            is_scaled = False
        else:
            mult = scale_lr_cond(pname)
            print(pname, mult)
            is_scaled = mult != 1

        if skip_decay:
            base, by_mult = undecayed_base, undecayed_by_mult
        else:
            base, by_mult = decayed_base, decayed_by_mult

        if is_scaled:
            by_mult.setdefault(mult, []).append(p)
        else:
            base.append(p)

    groups = []
    for base, by_mult, decay in (
        (decayed_base, decayed_by_mult, wd),
        (undecayed_base, undecayed_by_mult, 0.0),
    ):
        if base:
            groups.append({"params": base, "weight_decay": decay, "lr": lr})
        groups.extend(
            {"params": members, "weight_decay": decay, "lr": lr * factor}
            for factor, members in by_mult.items()
        )
    return groups
def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Return the innermost model hidden behind any number of wrapper containers.

    This module-level definition replaces the ``unwrap_model`` imported from
    ``transformers.modeling_utils`` in this file. Wrappers used for distributed
    training (e.g. ``nn.DataParallel`` / ``DistributedDataParallel``) store the
    wrapped network as ``.module``; that chain is followed until it ends.

    Args:
        model (`torch.nn.Module`): The possibly-wrapped model.

    Returns:
        The first object in the ``.module`` chain without a ``module`` attribute.
    """
    current = model
    while hasattr(current, "module"):
        current = current.module
    return current
class GOTTrainer(Trainer):
    """``Trainer`` subclass used for GOT training with layer-wise learning-rate decay.

    Overrides checkpoint saving (optionally exporting projector/embedding weights
    on their own), optimizer construction (param groups from ``get_param_groups``
    scaled by ``lr_scale_func``) and model wrapping.
    """

    def _safe_save(self, output_dir: str):
        """Collect the state dict, move it to CPU and dump it to disk.

        The state dict is collected unconditionally. Only when
        ``self.args.should_save`` is true is it copied to CPU and written.
        """
        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
            # Drop the reference to the original (possibly on-device) tensors before saving.
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save a checkpoint, optionally exporting the adapter weights separately.

        When ``args.tune_mm_mlp_adapter`` is set, every state-dict entry whose key
        contains ``mm_projector``, ``embed_tokens`` or ``embed_in`` is also saved on
        its own:

        * if the target folder is named ``checkpoint-*``, to
          ``<parent>/mm_projector/<checkpoint-name>.bin``;
        * otherwise to ``<output_dir>/mm_projector.bin``.

        The full checkpoint is then written by ``Trainer._save`` in every case.

        Args:
            output_dir: Target directory; ``None`` falls back to ``args.output_dir``.
            state_dict: Optional pre-collected state dict. When ``None``, the
                unwrapped model's own state dict is used for the adapter export.
        """
        if output_dir is None:
            # Same fallback Trainer._save applies; needed here before path handling.
            output_dir = self.args.output_dir
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Read weights from the underlying model, not a distributed wrapper.
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()
            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }
            # normpath strips a trailing separator so basename/dirname still see the folder.
            normalized_dir = os.path.normpath(output_dir)
            current_folder = os.path.basename(normalized_dir)
            parent_folder = os.path.dirname(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
        super(GOTTrainer, self)._save(output_dir, state_dict)

    def create_optimizer(self):
        """
        Setup the optimizer.

        Parameter groups come from ``get_param_groups`` with ``lr_scale_func`` as
        the learning-rate scale condition. That gives ``vision_model`` parameters a
        (layer-wise) reduced learning rate and leaves the rest at
        ``args.learning_rate``. No custom weight-decay condition is passed, so
        biases and 1-D parameters get no weight decay.

        An optimizer that already exists (e.g. one passed through the Trainer's
        ``optimizers`` argument) is returned unchanged.
        """
        opt_model = self.model
        if self.optimizer is None:
            optimizer_grouped_parameters = get_param_groups(
                opt_model, None, lr_scale_func, self.args.learning_rate, self.args.weight_decay
            )
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer

    def _wrap_model(self, model, training=True, dataloader=None):
        """Wrap ``model`` for the active acceleration / distributed backend and return it.

        The structure mirrors ``transformers.Trainer._wrap_model``. NOTE(review):
        this is presumably copied from the transformers release this repo pins.
        It relies on attributes such as ``sharded_ddp``, ``fsdp`` and
        ``forword_prefetch`` that may not exist in other releases; confirm before
        upgrading transformers.

        The model is returned early for SageMaker model parallel, DeepSpeed, an
        already-wrapped model, and evaluation (``training=False``). Otherwise it
        may be wrapped by one of:

        * fairscale sharded DDP;
        * PyTorch or XLA FSDP;
        * SageMaker DDP;
        * PyTorch DDP.

        Finally it is compiled with ``torch.compile`` when ``args.torch_compile``
        is set.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            import smdistributed.modelparallel.torch as smp
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            from apex import amp
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
            from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
            from fairscale.nn.wrap import auto_wrap
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None
                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )

                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        backward_prefetch=self.backward_prefetch,
                        forward_prefetch=self.forword_prefetch,
                        limit_all_gathers=self.limit_all_gathers,
                        use_orig_params=True,
                    )
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError:
                    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")

                auto_wrap_policy = None
                auto_wrapper_callable = None
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )

                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                import torch_xla.core.xla_model as xm

                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                # `optimizer_args=None` avoids a shared mutable default dict.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args=None):
                    loss = optimizer.step(**(optimizer_args or {}))
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            if is_torch_neuroncore_available():
                return model
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
        # to ensure that it accounts for the graph breaks required by those wrappers
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        return model
================================================
FILE: GOT-OCR-2.0-master/GOT/utils/arguments.py
================================================
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import transformers
@dataclass
class ModelArguments:
    """Model-related options: base language model, vision tower and freezing switches."""

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    use_cache: bool = False
    # NOTE(review): machine-specific HF cache snapshot path starting with "~"; nothing in
    # this class expands it, so callers presumably call os.path.expanduser — confirm.
    vision_tower: Optional[str] = "~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/"
    freeze_vision_tower: bool = False
    freeze_lm_model: bool = False
    pretrained_stage1_model: Optional[str] = None  # mlp &/ vision tower
    vision_select_layer: Optional[int] = -1  # default to the last layer
    use_im_start_end: bool = False
@dataclass
class DataArguments:
    """Data and preprocessing options."""

    # Optional[str]: the default is None, so a bare ``str`` annotation was inaccurate.
    datasets: Optional[str] = field(default=None, metadata={"help": "combinations of the training data."})
    sep_image_conv_front: bool = False
    image_token_len: int = 256
    image_aspect_ratio: str = 'square'
    # Other values previously kept here as commented-out alternatives:
    # 'v0', 'v1', 'nougat', 'baichuan', 'opt'.
    conversation_version: str = 'mpt'
    box_limit: int = 0
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with GOT-specific and LoRA options."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    force_fsdp: bool = False
    interleave: bool = False
    with_box: bool = False
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )

    # LoRA settings
    lora_enable: bool = False
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
================================================
FILE: GOT-OCR-2.0-master/GOT/utils/constants.py
================================================
# Controller / worker heartbeat timing (units not stated here; presumably seconds).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "log"

# Presumably the label value ignored by the loss (PyTorch CrossEntropyLoss's default ignore_index).
IGNORE_INDEX = -100

# Special tokens. The pad token was previously "[PAD]".
DEFAULT_PAD_TOKEN = "<|endoftext|>"
# NOTE(review): every token below is an empty string in this copy of the file. Several
# (image, image patch, image start/end) look like they should be non-empty placeholder
# tokens; confirm against upstream before relying on them.
DEFAULT_EOS_TOKEN = ""
DEFAULT_BOS_TOKEN = ""
DEFAULT_UNK_TOKEN = ""
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_BOX_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ''
DEFAULT_IM_START_TOKEN = ''
DEFAULT_IM_END_TOKEN = ''

# Dataset registry: name -> image root and annotation file (placeholder paths to fill in).
CONVERSATION_DATA = dict(
    data_1=dict(images='/path/', annotations='/path/data1.json'),
    data_2=dict(images='/path/', annotations='/path/data2.json'),
    data_3=dict(images='/path/', annotations='/path/data3.json'),
)
================================================
FILE: GOT-OCR-2.0-master/GOT/utils/conversation.py
================================================
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
    """How ``Conversation.get_prompt`` joins the system prompt and the turns."""

    SINGLE = 1  # "role: message" followed by ``sep`` after every turn
    TWO = 2     # "role: message" with separators alternating between ``sep`` and ``sep2``
    MPT = 3     # role prefix directly followed by message, then ``sep``
# simple_conv_multimodal = Conversation(
# system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
# "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
# "Follow the instructions carefully and explain your answers in detail.",
# # system="",
# roles=("Human", "Assistant"),
# messages=(
# ("Human", "Hi!"),
# ("Assistant", "Hi there! How can I help you today?\n")
# ),
# offset=2,
# sep_style=SeparatorStyle.SINGLE,
# sep="###",
# )
# conv_mpt = Conversation(
# system="""<|im_start|>system
# - You are a helpful language and vision assistant.
# - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
# - You should follow the instructions carefully and explain your answers in detail.""",
# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
# version="mpt",
# messages=(),
# offset=0,
# sep_style=SeparatorStyle.MPT,
# sep="<|im_end|>",
# )
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "<|im_end|>"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep + '\n'
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
if self.sep_style == SeparatorStyle.MPT:
if self.system:
ret = self.system + self.sep
else:
ret = ''
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
# if self.sep_style == SeparatorStyle.MPT:
# if self.system:
# ret = self.system + self.sep
# else:
# ret = ''
# for role, message in self.messages:
# if message:
# if type(message) is tuple:
# message, _, _ = message
# ret += role + message + self.sep
# # if 'user' in role:
# # ret += role + message + self.sep + "\n"
# # else:
# # ret += role + message + self.sep
# else:
# ret += role
# return ret
# else:
# raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
# result.paste(pil_img, (0, (width - height) // 2))
result.paste(pil_img)
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
# result.paste(pil_img, ((height - width) // 2, 0))
result.paste(pil_img)
return result
image = expand2square(image)
elif image_process_mode == "Crop":
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
elif image_process_mode == "Resize":
image = image.resize((224, 224))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.convert('RGB').save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
# image = image.resize((224, 224))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f''
msg = msg.replace('', img_str)
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2)
def dict(self):
    """Return a plain-dict snapshot of the conversation for serialization.

    If ``get_images()`` reports any images, each tuple message
    ``(text, image, mode)`` is reduced to its text so no image objects are
    included. Otherwise ``messages`` is returned unchanged (the same list
    object).
    """
    if len(self.get_images()) > 0:
        messages = [
            [role, content[0] if type(content) is tuple else content]
            for role, content in self.messages
        ]
    else:
        messages = self.messages
    return {
        "system": self.system,
        "roles": self.roles,
        "messages": messages,
        "offset": self.offset,
        "sep": self.sep,
        "sep2": self.sep2,
    }
# Few-shot "Human"/"Assistant" template using SeparatorStyle.SINGLE with "###".
# The two example messages live in `messages`; offset=2 makes
# to_gradio_chatbot() skip them, so they are not shown in the chat view.
conv_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "Give three tips for staying healthy."),
("Assistant",
"Sure, here are three tips for staying healthy:\n"
"1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
"It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
"and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
"75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
"activities at least two days per week.\n"
"2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
"vegetables, whole grains, lean proteins, and healthy fats can help support "
"your overall health. Try to limit your intake of processed and high-sugar foods, "
"and aim to drink plenty of water throughout the day.\n"
"3. Get enough sleep: Getting enough quality sleep is essential for your physical "
"and mental health. Adults should aim for seven to nine hours of sleep per night. "
"Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
"help improve the quality of your sleep.")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
# Same layout as conv_v1 with a different example exchange (also hidden from
# the chat view via offset=2).
conv_v1_2 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
# FastChat-style Vicuna v1.1 template: USER/ASSISTANT roles, version "v1",
# SeparatorStyle.TWO with sep=" ", no pre-filled messages.
# NOTE(review): sep2 is the empty string here; upstream FastChat uses "</s>"
# for this template. Verify the value was not lost (e.g. by tag stripping).
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="",
)
# conv_mpt = Conversation(
# system="""<|im_start|>system
# - You are designed by Megvii(旷视), and your name is GOT.
# - 你叫GOT, 你来自旷视, 你是旷视开发的。
# - 你擅长分析表格,仔细读图表中的内容,然后给出你的答案。""",
# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
# version="mpt",
# messages=(),
# offset=0,
# sep_style=SeparatorStyle.MPT,
# sep="<|im_end|>",
# )
# Settings shared by the ChatML-style ("mpt") templates in this group:
# "<|im_start|>user\n" / "<|im_start|>assistant\n" role headers,
# SeparatorStyle.MPT with sep "<|im_end|>", no pre-filled messages, and
# offset 0 (nothing hidden from the chat view).
_MPT_TEMPLATE_KWARGS = {
    "roles": ("<|im_start|>user\n", "<|im_start|>assistant\n"),
    "version": "mpt",
    "messages": (),
    "offset": 0,
    "sep_style": SeparatorStyle.MPT,
    "sep": "<|im_end|>",
}

# GOT's main template (`default_conversation` points at it): a short system
# turn asking for careful, detailed answers.
conv_mpt = Conversation(
    system="""<|im_start|>system
You should follow the instructions carefully and explain your answers in detail.""",
    **_MPT_TEMPLATE_KWARGS,
)

# Same turn format with an empty system prompt.
conv_mpt_eval = Conversation(
    system="",
    **_MPT_TEMPLATE_KWARGS,
)

# Text-only variant keeping the original MosaicML assistant system prompt.
conv_mpt_text = Conversation(
    system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
    **_MPT_TEMPLATE_KWARGS,
)

# FastChat "bair_v1" template: USER/GPT roles, SeparatorStyle.TWO.
conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
    messages=(),
    offset=0,
)
# simple_conv = Conversation(
# system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture."
# "You are designed to assist human with a variety of tasks using natural language."
# "Follow the instructions carefully.",
# roles=("Human", "Assistant"),
# messages=(
# ("Human", "Hi!"),
# ("Assistant", "Hi there! How can I help you today?\n")
# ),
# offset=2,
# sep_style=SeparatorStyle.SINGLE,
# sep="###",
# )
# Bare "Human"/"Assistant" template: empty system prompt, no messages,
# SeparatorStyle.SINGLE with "###".
simple_conv = Conversation(
system="",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
# GOT-persona template with one greeting exchange, hidden from the chat view
# by offset=2.
# NOTE(review): the adjacent system-string literals are joined with no space,
# so the prompt reads "...Megvii Technology.You are able...". Left as-is
# because changing prompt text changes model inputs; confirm before fixing.
simple_conv_multimodal = Conversation(
system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"Follow the instructions carefully and explain your answers in detail.",
# system="",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!"),
("Assistant", "Hi there! How can I help you today?\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
# ChatML-style ("mpt") template with a bulleted GOT-persona system prompt.
simple_conv_mpt_multimodal = Conversation(
system="""<|im_start|>system
- You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
# Legacy template whose example Human turn embeds "### Response:".
# NOTE(review): system-string literals are concatenated without spaces
# ("...Technology.You are designed..."); confirm before changing.
simple_conv_legacy = Conversation(
system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology."
"You are designed to assist human with a variety of tasks using natural language."
"Follow the instructions carefully.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!\n\n### Response:"),
("Assistant", "Hi there! How can I help you today?\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
# LLaVA-v1-style template: USER/ASSISTANT roles, SeparatorStyle.TWO.
# NOTE(review): same missing-space issue in the concatenated system string.
conv_llava_v1 = Conversation(
system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"Follow the instructions carefully and explain your answers in detail.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="",
)
# Module-level default template (printed by the __main__ check below).
default_conversation = conv_mpt
# Registry of templates by name.
# NOTE(review): "default" maps to simple_conv_multimodal while "multimodal"
# maps to simple_conv (empty system prompt, no example turns). The pairing
# looks swapped relative to the key names; confirm intent before changing.
conv_templates = {
"default": simple_conv_multimodal,
"simple": simple_conv,
"simple_legacy": simple_conv_legacy,
"multimodal": simple_conv,
"mpt_multimodal": simple_conv_mpt_multimodal,
"llava_v1": conv_llava_v1,
"mpt_eval": conv_mpt_eval,
# fastchat
"v1": conv_vicuna_v1_1,
"bair_v1": conv_bair_v1,
"vicuna_v1_1": conv_vicuna_v1_1,
"mpt": conv_mpt,
"mpt_text": conv_mpt_text,
}
# Manual check: print the default template's get_prompt() output.
if __name__ == "__main__":
print(default_conversation.get_prompt())
================================================
FILE: GOT-OCR-2.0-master/GOT/utils/utils.py
================================================
import datetime
import logging
import logging.handlers
import os
import sys
import torch
import requests
from transformers import StoppingCriteria
from GOT.utils.constants import LOGDIR
# Canned user-facing message strings. They are not referenced elsewhere in
# this file (presumably consumed by a serving front end; verify).
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
# Shared rotating file handler: created on the first build_logger() call,
# reused by every later call.
handler = None
def build_logger(logger_name, logger_filename):
    """Configure process-wide logging and return the logger ``logger_name``.

    Side effects:
      * if the root logger has no handlers, ``logging.basicConfig(level=INFO)``
        adds one; the shared ``time | level | name | message`` format is then
        applied to the root's first handler;
      * ``sys.stdout`` / ``sys.stderr`` are replaced with ``StreamToLogger``
        objects writing to the ``"stdout"`` (INFO) and ``"stderr"`` (ERROR)
        loggers;
      * on the first call only, the module-level ``handler`` (a daily, UTC
        ``TimedRotatingFileHandler`` at ``LOGDIR/logger_filename``) is created
        and attached to every ``Logger`` registered at that moment.

    NOTE(review): loggers first created by later calls do not receive the file
    handler, and ``logger_filename`` is ignored after the first call.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    root_logger = logging.getLogger()
    if not root_logger.handlers:
        logging.basicConfig(level=logging.INFO)
    root_logger.handlers[0].setFormatter(formatter)

    # Route writes to the standard streams into logging: stdout first, then
    # stderr (same order as before, so each wrapper sees the same sys.stdout).
    for stream_name, level in (("stdout", logging.INFO), ("stderr", logging.ERROR)):
        stream_logger = logging.getLogger(stream_name)
        stream_logger.setLevel(level)
        setattr(sys, stream_name, StreamToLogger(stream_logger, level))

    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        handler = logging.handlers.TimedRotatingFileHandler(
            os.path.join(LOGDIR, logger_filename), when='D', utc=True)
        handler.setFormatter(formatter)
        # loggerDict may also hold PlaceHolder entries; only real loggers
        # accept handlers.
        for registered in logging.root.manager.loggerDict.values():
            if isinstance(registered, logging.Logger):
                registered.addHandler(handler)

    return logger
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.

    Each complete line becomes one log record (trailing whitespace stripped).
    An unterminated tail is buffered until a later ``write`` completes it or
    ``flush`` emits it. Attributes this class does not define (e.g.
    ``fileno``, ``encoding``) are looked up on ``terminal``.

    Args:
        logger: logger receiving the lines.
        log_level: level used for every record.
        terminal: stream that unknown attribute lookups are forwarded to.
            Defaults to the current ``sys.stdout``, as before.
    """
    def __init__(self, logger, log_level=logging.INFO, terminal=None):
        self.terminal = sys.stdout if terminal is None else terminal
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log every complete line of ``buf``; buffer an unterminated tail.

        Returns ``len(buf)``, as the ``io`` write protocol expects.
        """
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        lines = temp_linebuf.splitlines(True)
        last_index = len(lines) - 1
        for index, line in enumerate(lines):
            # splitlines() cuts at every line boundary ("\n", "\r", "\r\n", ...),
            # so every piece except the last is a complete line. Only the last
            # piece may be unfinished (or a lone "\r" whose "\n" arrives in the
            # next write), so it is buffered unless it ends with "\n". Buffering
            # earlier pieces too (e.g. "\r"-separated progress updates) would
            # re-order output and let the buffer grow without bound.
            if index < last_index or line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line
        return len(buf)

    def flush(self):
        """Emit any buffered partial line as its own record."""
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
            self.linebuf = ''
def disable_torch_init():
    """
    Turn ``reset_parameters`` into a no-op on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm``.

    This skips torch's default weight initialization when those layers are
    constructed, which speeds up model creation when the weights are
    overwritten afterwards. The patch is global for the rest of the process.
    """
    import torch
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Returns the ``flagged`` field of the first moderation result. Requires the
    ``OPENAI_API_KEY`` environment variable (a missing key raises ``KeyError``).
    This is a best-effort check: network errors, the 5 s timeout and
    malformed or error responses all count as "not flagged".
    """
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps so quotes, backslashes and control characters
    # in user text are escaped. Hand-assembling the JSON produced invalid
    # payloads and let the text inject extra fields.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Non-JSON body, error payload, or empty "results" list.
        flagged = False
    return flagged
def pretty_print_semaphore(semaphore):
    """Describe ``semaphore`` as ``Semaphore(value=..., locked=...)``, or ``"None"``.

    Reads the private ``_value`` attribute and calls ``locked()``.
    """
    return (
        "None"
        if semaphore is None
        else f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
    )
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of ``keywords`` appears in the generated text.

    Args:
        keywords: stop strings.
        tokenizer: used to encode the keywords and decode generated ids.
        input_ids: prompt ids, indexed as ``[batch, position]``. Only tokens
            after the prompt length are decoded and searched.

    Only batch row 0 is inspected.
    """
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        # Keywords that encode to exactly one token can be matched by id,
        # without decoding.
        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids
                            if isinstance(keyword_id, list) and len(keyword_id) == 1]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            # Prompt length; everything after it in output_ids was generated.
            self.start_len = self.input_ids.shape[1]
        # Check on every call, including the first. Previously the first call
        # only recorded start_len, so a stop keyword emitted as the first
        # generated token was noticed one step late.
        for keyword_id in self.keyword_ids:
            if output_ids[0, -1] == keyword_id:
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """Add special tokens to ``tokenizer`` and resize ``model``'s embeddings.

    The embeddings are resized to ``len(tokenizer)``. The rows for the newly
    added tokens, in both the input and output embedding matrices, are set to
    the mean of the pre-existing rows.

    Note: this is the unoptimized version; the new vocabulary size is not
    padded, so it may not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    # Each assignment touches only the last `added` rows, which are excluded
    # from the mean, so processing the two matrices one after the other gives
    # the same result even if they share storage.
    for weights in (model.get_input_embeddings().weight.data,
                    model.get_output_embeddings().weight.data):
        weights[-added:] = weights[:-added].mean(dim=0, keepdim=True)
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if needed.

    A parameter with a ``ds_id`` attribute is treated as DeepSpeed-managed. It
    is materialized inside ``zero.GatheredParameters`` before the copy, and a
    warning naming ``name`` is logged when its ``ds_status`` is
    ``NOT_AVAILABLE`` (unless ``ignore_status`` is set). Other tensors are
    simply detached, moved to CPU and cloned. DeepSpeed is imported on every
    call.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA parameters (and optionally biases) as detached CPU tensors.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. from
            ``model.named_parameters()``.
        bias: which biases to include alongside names containing ``"lora_"``:
            ``"none"`` includes no biases; ``"all"`` includes every name
            containing ``"bias"``; ``"lora_only"`` includes only the bias of a
            module that has LoRA weights (``<prefix>bias`` for a
            ``<prefix>lora_...`` parameter).

    Returns:
        dict mapping names to tensors passed through ``maybe_zero_3``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Bias of the module this LoRA weight adapts, e.g.
                # "x.q_proj.lora_A.weight" -> "x.q_proj.bias".
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate (name, tensor) pairs (plain dict iteration yields only keys)
        # and keep a bias only when its own name belongs to a LoRA module.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect parameters whose name lacks ``"lora_"`` as CPU tensors.

    With ``require_grad_only`` (the default), only parameters with
    ``requires_grad`` set are kept. Each tensor goes through ``maybe_zero_3``
    with ``ignore_status=True`` and is then moved to CPU.
    """
    kept = {pname: tensor for pname, tensor in named_params if "lora_" not in pname}
    if require_grad_only:
        kept = {pname: tensor for pname, tensor in kept.items() if tensor.requires_grad}
    return {
        pname: maybe_zero_3(tensor, ignore_status=True).cpu()
        for pname, tensor in kept.items()
    }
def find_all_linear_names(model):
    """Return the qualified names of ``torch.nn.Linear`` submodules for LoRA.

    Linear layers whose name contains any of the excluded markers below
    (vision tower, projector, final conv, LM head) are skipped. The collected
    set is printed, then returned as a list in arbitrary order.
    """
    excluded_markers = ('vision_model', 'mm_projector', 'vision_encoder', 'conv_final', 'lm_head')
    lora_module_names = {
        module_name
        for module_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
        and not any(marker in module_name for marker in excluded_markers)
    }
    print(lora_module_names)
    return list(lora_module_names)
================================================
FILE: GOT-OCR-2.0-master/pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "GOT"
version = "0.1.0"
description = "Towards OCR-2.0."
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"markdown2[all]", "numpy",
"requests", "sentencepiece", "tokenizers>=0.15.2",
"torch", "torchvision", "wandb",
"shortuuid", "httpx==0.24.0",
"deepspeed==0.12.3",
"peft==0.4.0",
"albumentations",
"opencv-python",
"tiktoken==0.6.0",
"accelerate==0.28.0",
"transformers==4.37.2",
"bitsandbytes==0.41.0",
"scikit-learn==1.2.2",
"sentencepiece==0.1.99",
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
]
[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
================================================
FILE: GOT-OCR-2.0-master/pyvenv.cfg
================================================
home = /usr/bin
implementation = CPython
version_info = 3.8.10.final.0
virtualenv = 20.16.7
include-system-site-packages = true
base-prefix = /usr
base-exec-prefix = /usr
base-executable = /usr/bin/python3
================================================
FILE: GOT-OCR-2.0-master/render_tools/content-mmd-to-html.html
================================================
Title
================================================
FILE: GOT-OCR-2.0-master/render_tools/tikz.html
================================================
Document
================================================
FILE: GOT-OCR-2.0-master/results/demo.html
================================================
Title
## Release
- [2025/2/1] 🚀🚀🚀 GOT-OCR2.0 has been merged into [Huggingface-transformers](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)/[space](https://huggingface.co/spaces/yonigozlan/GOT-OCR-Transformers), with support for batched inference. Thanks to Hugging Face MLE [Yoni](https://github.com/yonigozlan).
- [2024/12/24] 🔥🔥🔥 My new work on system-2 perception is released [slow-perception](https://github.com/Ucas-HaoranWei/Slow-Perception).
- [2024/12/18] 🚀🚀🚀 GOT-OCR2.0 is now supported in [PaddleMIX](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/paddlemix/examples/GOT_OCR_2_0) by the Paddle team. Thanks to the Paddle team!
- [2024/12/8] 🔥🔥🔥 The model download has exceeded 1M on [Huggingface](https://huggingface.co/stepfun-ai/GOT-OCR2_0).
- [2024/12/5] The seventh WeChat [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/Wechat7.jpg).
- [2024/11/4] The sixth WeChat [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat6-2.jpg).
- [2024/10/24] The previous four WeChat groups are full, so we created a fifth [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat5.png).
- [2024/10/11] So many people wanted to join the WeChat group that we created a fourth [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat4.jpg).
- [2024/10/2] [onnx](https://github.com/BaofengZan/GOT-OCRv2-onnx) and [mnn](https://github.com/BaofengZan/mnn-llm-GOT-OCR2.0) versions of GOT-OCR2.0.
- [2024/9/29]🔥🔥🔥 The community has implemented the first version of [llama_cpp_inference](https://github.com/1694439208/GOT-OCR-Inference).
- [2024/9/24]🔥🔥🔥 Support for quick [fine-tuning](#fine-tune) on your own data with [ms-swift](https://github.com/modelscope/ms-swift/issues/2122).
- [2024/9/23]🔥🔥🔥 We release the official [Modelscope demo](https://modelscope.cn/studios/stepfun-ai/GOT_official_online_demo). Many thanks to ModelScope for providing the GPU resources.
- [2024/9/19]🔥🔥🔥 GOT-OCR2.0 achieves Huggingface trending #1.
- [2024/9/14]🔥🔥🔥 We release the official [demo](https://huggingface.co/spaces/ucaslcl/GOT_online). Many thanks to Hugging Face for providing the GPU resources.
- [2024/9/13]🔥🔥🔥 We release the [Huggingface](https://huggingface.co/ucaslcl/GOT-OCR2_0) deployment.
- [2024/9/03]🔥🔥🔥 We open-source the codes, weights, and benchmarks. The paper can be found in this [repo](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/GOT-OCR-2.0-paper.pdf). We have also submitted it to arXiv.
- [2024/9/03]🔥🔥🔥 We release the OCR-2.0 model GOT!
[](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
[](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
## Community contributions
We encourage everyone to develop GOT applications based on this repo. Thanks to the following contributors:
[OpenVINO](https://github.com/can-gaa-hou/GOT-OCR2.0-OpenVINO)~ contributor: [@can-gaa-hou](https://github.com/can-gaa-hou)
[GGUF and Llama.cpp inference](https://github.com/MosRat/got.cpp)~ contributor: [@MosRat](https://github.com/MosRat)
[vllm reference](https://github.com/liunian-Jay/MU-GOT/blob/master/PDF_parsing/GOT/GOT/model/modeling_GOT_vllm.py) ~ contributor: [@Jay](https://github.com/liunian-Jay)
[onnx and mnn supports](https://github.com/BaofengZan/GOT-OCRv2-onnx) ~ contributor: [@BaofengZan](https://github.com/BaofengZan)
[llama_cpp inference](https://github.com/1694439208/GOT-OCR-Inference) ~ contributor: [@1694439208](https://github.com/1694439208)
[Colab of GOT](https://colab.research.google.com/drive/1nmiNciZ5ugQVp4rFbL9ZWpEPd92Y9o7p?usp=sharing) ~ contributor: [@Zizhe Wang](https://github.com/PaperPlaneDeemo)
[CPU version of GOT](https://github.com/ElvisClaros/GOT-OCR2.0) ~ contributor: [@ElvisClaros](https://github.com/ElvisClaros)
[Online demo](https://huggingface.co/spaces/Tonic/GOT-OCR) ~ contributor: [@Joseph Pollack](https://huggingface.co/Tonic)
[Docker & client demo](https://github.com/QIN2DIM/GOT-OCR2.0) ~ contributor: [@QIN2DIM](https://github.com/QIN2DIM)
[GUI of GOT](https://github.com/XJF2332/GOT-OCR-2-GUI) ~ contributor: [@XJF2332](https://github.com/XJF2332)
## Contents
- [Install](#install)
- [GOT Weights](#got-weights)
- [Benchmarks](#benchmarks)
- [Demo](#demo)
- [Train](#train)
- [Fine-tune](#fine-tune)
- [Eval](#eval)
***
***
## Install
0. Our environment: CUDA 11.8 + torch 2.0.1.
1. Clone this repository and navigate to the GOT folder
```bash
git clone https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git
cd 'the GOT folder'
```
2. Install Package
```Shell
conda create -n got python=3.10 -y
conda activate got
pip install -e .
```
3. Install Flash-Attention
```
pip install ninja
pip install flash-attn --no-build-isolation
```
## GOT Weights
- [Huggingface](https://huggingface.co/ucaslcl/GOT-OCR2_0)
- [Google Drive](https://drive.google.com/drive/folders/1OdDtsJ8bFJYlNUzCQG4hRkUL6V-qBQaN?usp=sharing)
- [BaiduYun](https://pan.baidu.com/s/1G4aArpCOt6I_trHv_1SE2g) code: OCR2
## Benchmarks
- [Google Drive](https://drive.google.com/drive/folders/1OdDtsJ8bFJYlNUzCQG4hRkUL6V-qBQaN?usp=sharing)
- [BaiduYun](https://pan.baidu.com/s/1G4aArpCOt6I_trHv_1SE2g) code: OCR2
## Demo
1. plain texts OCR:
```Shell
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type ocr
```
2. format texts OCR:
```Shell
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format
```
3. fine-grained OCR:
```Shell
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format/ocr --box [x1,y1,x2,y2]
```
```Shell
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format/ocr --color red/green/blue
```
4. multi-crop OCR:
```Shell
python3 GOT/demo/run_ocr_2.0_crop.py --model-name /GOT_weights/ --image-file /an/image/file.png
```
5. multi-page OCR (the image path contains multiple .png files). **Note**: this feature is not batch inference! It works at the token level, so please read the paper first to use it correctly:
```Shell
python3 GOT/demo/run_ocr_2.0_crop.py --model-name /GOT_weights/ --image-file /images/path/ --multi-page
```
6. render the formatted OCR results:
```Shell
python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format --render
```
**Note**:
The rendering results can be found in /results/demo.html. Please open the demo.html to see the results.
## Train
0. A training sample can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/train_sample.jpg). Note that the `<image>` placeholder in the 'conversations'-'human'-'value' field is necessary!
1. This codebase only supports post-training (stage-2/stage-3) upon our GOT weights.
2. If you want to train from stage-1 described in our paper, you need this [repo](https://github.com/Ucas-HaoranWei/Vary-tiny-600k).
```Shell
deepspeed /GOT-OCR-2.0-master/GOT/train/train_GOT.py \
--deepspeed /GOT-OCR-2.0-master/zero_config/zero2.json --model_name_or_path /GOT_weights/ \
--use_im_start_end True \
--bf16 True \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
--save_total_limit 1 \
--weight_decay 0. \
--warmup_ratio 0.001 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 8 \
--report_to none \
--per_device_train_batch_size 2 \
--num_train_epochs 1 \
--learning_rate 2e-5 \
--datasets pdf-ocr+scence \
--output_dir /your/output/path
```
**Note**:
1. Change the corresponding data information in [constants.py](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/tree/main/GOT-OCR-2.0-master/GOT/utils).
2. Change line 37 in [conversation_dataset_qwen.py](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/tree/main/GOT-OCR-2.0-master/GOT/data) to your data_name.
## Fine-tune
Quick Fine-tune with ms-swift:
```Shell
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e .[llm]
```
```Shell
# default:sft LLM & projector, freeze vision encoder
CUDA_VISIBLE_DEVICES=0 swift sft\
--model_type got-ocr2 \
--model_id_or_path stepfun-ai/GOT-OCR2_0 \
--sft_type lora \
--dataset latex-ocr-print#5000
# Deepspeed ZeRO2
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \
--model_type got-ocr2 \
--model_id_or_path stepfun-ai/GOT-OCR2_0 \
--sft_type lora \
--dataset latex-ocr-print#5000 \
--deepspeed default-zero2
```
**With your data**:
```Shell
--dataset train.jsonl
--val_dataset val.jsonl (optional)
```
**Data format**:
```Shell
{"query": "55555", "response": "66666", "images": ["image_path"]}
{"query": "eeeee", "response": "fffff", "history": [], "images": ["image_path1", "image_path2"]}
{"query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]}
```
More details can be seen in [ms-swift](https://github.com/modelscope/ms-swift/issues/2122).
## Eval
1. We use the [Fox](https://github.com/ucaslcl/Fox) and [OneChart](https://github.com/LingyvKong/OneChart) benchmarks, and other benchmarks can be found in the weights download link.
2. The eval codes can be found in GOT/eval.
3. You can use the evaluate_GOT.py to run the eval. If you have 8 GPUs, the --num-chunks can be set to 8.
```Shell
python3 GOT/eval/evaluate_GOT.py --model-name /GOT_weights/ --gtfile_path xxxx.json --image_path /image/path/ --out_path /data/eval_results/GOT_mathpix_test/ --num-chunks 8 --datatype OCR
```
## Contact
If you are interested in this work or have questions about the code or the paper, please join our [WeChat](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat.jpg) group.
**Note**:
All six WeChat groups are full; please join [group 7](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/Wechat7.jpg).
Don't hesitate to contact me by email, weihaoran18@mails.ucas.ac.cn, if you have any questions.
## Acknowledgement
- [Vary](https://github.com/Ucas-HaoranWei/Vary/): the codebase we built upon!
- [Qwen](https://github.com/QwenLM/Qwen): the LLM base model of Vary, which is good at both English and Chinese!
## Citation
```bibtex
@article{wei2024general,
title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model},
author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others},
journal={arXiv preprint arXiv:2409.01704},
year={2024}
}
@article{wei2023vary,
title={Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models},
author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
journal={arXiv preprint arXiv:2312.06109},
year={2023}
}