Repository: Ucas-HaoranWei/GOT-OCR2.0 Branch: main Commit: 179ed086ad6b Files: 41 Total size: 224.5 KB Directory structure: gitextract_is3bdbjd/ ├── GOT-OCR-2.0-master/ │ ├── GOT/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_dataset.py │ │ │ └── conversation_dataset_qwen.py │ │ ├── demo/ │ │ │ ├── process_results.py │ │ │ ├── run_ocr_2.0.py │ │ │ └── run_ocr_2.0_crop.py │ │ ├── eval/ │ │ │ ├── eval_GOT_ocr.py │ │ │ ├── evaluate_GOT.py │ │ │ ├── multi_hardware_eval_GOT.py │ │ │ └── pyevaltools/ │ │ │ ├── __init__.py │ │ │ ├── eval_ocr.py │ │ │ ├── eval_ocr_format.py │ │ │ ├── eval_ocr_scene.py │ │ │ └── merge_results.py │ │ ├── model/ │ │ │ ├── GOT_ocr_2_0.py │ │ │ ├── __init__.py │ │ │ ├── plug/ │ │ │ │ └── blip_process.py │ │ │ └── vision_encoder/ │ │ │ ├── __init__.py │ │ │ └── vary_b.py │ │ ├── train/ │ │ │ ├── train.py │ │ │ ├── train_GOT.py │ │ │ ├── train_flash_attn.py │ │ │ ├── train_lora.py │ │ │ ├── train_lora_flash_attn.py │ │ │ ├── trainer.py │ │ │ ├── trainer_llm_llrd.py │ │ │ ├── trainer_vit_fixlr.py │ │ │ └── trainer_vit_llrd.py │ │ └── utils/ │ │ ├── arguments.py │ │ ├── constants.py │ │ ├── conversation.py │ │ └── utils.py │ ├── pyproject.toml │ ├── pyvenv.cfg │ ├── render_tools/ │ │ ├── content-mmd-to-html.html │ │ └── tikz.html │ ├── results/ │ │ └── demo.html │ └── zero_config/ │ ├── zero2.json │ └── zero3.json └── README.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: GOT-OCR-2.0-master/GOT/__init__.py ================================================ ================================================ FILE: GOT-OCR-2.0-master/GOT/data/__init__.py ================================================ import torch import transformers from dataclasses import dataclass, field from GOT.utils.constants import * @dataclass class DataCollatorForSupervisedDataset(object): tokenizer: transformers.PreTrainedTokenizer 
def __call__(self, instances): # print(instances) # exit() input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) images = [torch.stack(instance['image']) for instance in instances] # if 'flattened_patches' in instances[0]['image_high'][0].keys(): # images_high = [torch.stack([instance['image_high'][0]['flattened_patches']]) for instance in instances] # else: images_high = [torch.stack(instance['image_high']) for instance in instances] images = list(zip(images, images_high)) input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX) batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), images=images, ) return batch def make_supervised_data_module(interleave, with_box, tokenizer, data_args): if data_args.conversation_version == 'mpt': from GOT.data.conversation_dataset_qwen import ConversationDataset dataset_cls = ConversationDataset train_dataset = dataset_cls( tokenizer=tokenizer, datasets=data_args.datasets, multimodal_cfg=dict( sep_image_conv_front=data_args.sep_image_conv_front, image_token_len=data_args.image_token_len, image_aspect_ratio=data_args.image_aspect_ratio, use_im_start_end=data_args.use_im_start_end, image_processor=data_args.image_processor, image_processor_high = data_args.image_processor_high, box_limit=data_args.box_limit, ) ) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) ================================================ FILE: GOT-OCR-2.0-master/GOT/data/base_dataset.py ================================================ import io import os import copy import json import logging import torch import transformers import boto3 from typing import List, Optional, Tuple, Union, Dict, 
Sequence from torch.utils.data import Dataset from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True from GOT.utils.constants import * class BaseDataset(Dataset): def __init__( self, datasets: str, tokenizer: transformers.PreTrainedTokenizer, multimodal_cfg: dict ): super(BaseDataset, self).__init__() self.tokenizer = tokenizer self.multimodal_cfg = multimodal_cfg logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image") def image_processor(self, image): # processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit) processor_high = self.multimodal_cfg['image_processor_high'] # the second processor, usually is the designed image encoder (sam/swin/cnn) image_high = image.copy() # Vary old codes # # TODO the 'keep', 'padding' only used for the first processor # if self.multimodal_cfg['image_aspect_ratio'] == 'keep': # max_hw, min_hw = max(image.size), min(image.size) # aspect_ratio = max_hw / min_hw # max_len, min_len = 448, 224 # shortest_edge = int(min(max_len / aspect_ratio, min_len)) # image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0] # elif self.multimodal_cfg['image_aspect_ratio'] == 'pad': # def expand2square(pil_img, background_color): # width, height = pil_img.size # if width == height: # return pil_img # elif width > height: # result = Image.new(pil_img.mode, (width, width), background_color) # result.paste(pil_img) # for simpler box processing # return result # else: # result = Image.new(pil_img.mode, (height, height), background_color) # result.paste(pil_img) # for simpler box processing # return result # image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) # image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0] # else: # image = processor.preprocess(image, 
return_tensors='pt')['pixel_values'][0] image_high = processor_high(image_high) return image_high def __len__(self): return len(self.list_data_dict) def __getitem__(self, i) -> Dict[str, torch.Tensor]: pass ================================================ FILE: GOT-OCR-2.0-master/GOT/data/conversation_dataset_qwen.py ================================================ import io import os import copy import json import logging import torch import random from typing import List, Optional, Tuple, Union, Dict, Sequence from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True from GOT.data.base_dataset import BaseDataset from GOT.utils.constants import * from GOT.utils import conversation as conversation_lib import boto3 import smart_open from megfile import smart_glob from natsort import natsorted class ConversationDataset(BaseDataset): """Conversation format dataset stage2 fine-tuning.""" def __init__(self, datasets, tokenizer, multimodal_cfg): super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg) # v0 version format conversation conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"] logging.warning("Formatting inputs into conversation type: mpt-fixed") logging.warning("Loading data...") list_data_dict = [] list_image_path = [] # TODO add your data [data1, data2, data3, .....] got_data_dict = { "pdf-ocr": ["data1", "data2"], 'scene-ocr': ["data3", "data4"] # ...... 
} for name_all in datasets.split("+"): for name in got_data_dict[name_all]: dataset = CONVERSATION_DATA[name] data_path = dataset['annotations'] data = json.load(open(data_path, "r")) list_data_dict.extend(data) image_path = dataset['images'] list_image_path.extend([image_path] * len(data)) logging.warning(f"Data from {data_path} provide {len(data)} conversations.") assert len(list_data_dict) == len(list_image_path) logging.warning(f"{len(list_data_dict)} conversations in total.") a_new_list = list(zip(list_data_dict, list_image_path)) random.shuffle(a_new_list) list_data_dict_new, list_image_path_new = zip(*a_new_list) self.list_data_dict = list_data_dict_new self.list_image_path = list_image_path_new self.im_patch_token = 151859 self.im_start_token = 151857 self.im_end_token = 151858 def multimodal_processor(self, sources, flag_num_patches): for source in sources: if self.multimodal_cfg['sep_image_conv_front']: assert DEFAULT_IMAGE_TOKEN in source[0]['value'] source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value'] for sentence in source: replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']*flag_num_patches replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN # sentence["value"] = str(sentence["value"]).replace('\qquad', '\quad') sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def _tokenize_fn(self, strings): """Tokenize a list of strings.""" tokenized_list = [ self.tokenizer( text, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True, ) for text in strings ] input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] input_ids_lens = labels_lens = [ 
tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(self, target, tokenized_lens, speakers): # cur_idx = 0 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker.lower() == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len def token_processor(self, sources, image_name): conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates conversations = [] for i, source in enumerate(sources): if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations input_ids = self.tokenizer( conversations, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True, ).input_ids # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) targets = input_ids.clone() assert conv.sep_style == conversation_lib.SeparatorStyle.MPT # Mask targets sep = conv.sep + conv.roles[1] for conversation, target in zip(conversations, targets): total_len = int(target.ne(self.tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): if rou == "": break parts = 
rou.split(sep) if len(parts) != 2: break parts[0] += sep round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids) # round_len = len(tokenizer_image_token(rou, self.tokenizer)) + len(tokenizer_image_token(conv.sep, self.tokenizer)) # instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) instruction_len = len(self.tokenizer(parts[0]).input_ids) target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX if cur_len < self.tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) print(image_name) return dict( input_ids=input_ids, labels=targets, ) def __getitem__(self, i) -> Dict[str, torch.Tensor]: # data = self.list_data_dict[i] data = copy.deepcopy(self.list_data_dict[i]) if isinstance(data, dict): image_list = [] image_high_list = [] flag_num_patches = 1 if 'image' in data: image_path = self.list_image_path[i] image_file = data['image'] # multi-crop or multi page, only support .png files if ('.jpg' not in image_file and '.png' not in image_file and '.jpeg' not in image_file) and ('.jpg' not in image_path and '.png' not in image_path and '.jpeg' not in image_path): if image_file[0] == '/': patch_dir = image_path[:-1] + image_file patches = smart_glob(patch_dir + '*.png') else: patch_dir = image_path + image_file patches = smart_glob(patch_dir + '*.png') # print(patches) if not patches: print(f'cannot glob the dir {patch_dir}.') return self.__getitem__(0) # sort multi images by name patches = natsorted(patches) flag_num_patches = len(patches) for patch in patches: try: image = Image.open(patch).convert('RGB') except: print(f'cannot identify image file {patch}.') return self.__getitem__(0) try: img = self.image_processor(image) image_list.append(img) image_high_list.append(img) except: print(f'image {image_path + image_file + patch} are broken or grayscale! 
we thus select 0-th sample instead!') return self.__getitem__(0) else: flag_num_patches = 1 try: image = Image.open(image_path + image_file).convert('RGB') except: print(f'cannot identify image file {image_file}.') return self.__getitem__(0) try: image = self.image_processor(image) except: print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!') return self.__getitem__(0) conversations = self.multimodal_processor([data["conversations"]], flag_num_patches) # print(conversations) # exit() else: conversations = [data] # align with fastchat & llava here, put the conversation into a list for tokenization image_name = image_path + image_file data_dict = self.token_processor(conversations, image_name) data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) if isinstance(data, dict) and 'image' in data: if image_list and image_high_list: data_dict['image'] = image_list data_dict['image_high'] = image_high_list else: data_dict['image'] = [image] data_dict['image_high'] = [image] else: # crop_size = self.multimodal_cfg['image_processor'].crop_size # data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])] # Vary for two image, GOT does not use the data_dict['image] data_dict['image'] = [torch.zeros(3, 1024, 1024)] data_dict['image_high'] = [torch.zeros(3, 1024, 1024)] return data_dict ================================================ FILE: GOT-OCR-2.0-master/GOT/demo/process_results.py ================================================ import string punctuation_dict = { ",": ",", "。": ".", } # import os def svg_to_html(svg_content, output_filename): html_content = f""" SVG Embedded in HTML {svg_content} """ with open(output_filename, 'w') as file: file.write(html_content) ================================================ FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py ================================================ import argparse from transformers import AutoTokenizer, AutoModelForCausalLM 
import torch import os from GOT.utils.conversation import conv_templates, SeparatorStyle from GOT.utils.utils import disable_torch_init from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria from GOT.model import * from GOT.utils.utils import KeywordsStoppingCriteria from PIL import Image import os import requests from PIL import Image from io import BytesIO from GOT.model.plug.blip_process import BlipImageEvalProcessor from transformers import TextStreamer import re from GOT.demo.process_results import punctuation_dict, svg_to_html import string DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = '' DEFAULT_IM_START_TOKEN = '' DEFAULT_IM_END_TOKEN = '' translation_table = str.maketrans(punctuation_dict) def load_image(image_file): if image_file.startswith('http') or image_file.startswith('https'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def eval_model(args): # Model disable_torch_init() model_name = os.path.expanduser(args.model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval() model.to(device='cuda', dtype=torch.bfloat16) # TODO vary old codes, NEED del image_processor = BlipImageEvalProcessor(image_size=1024) image_processor_high = BlipImageEvalProcessor(image_size=1024) use_im_start_end = True image_token_len = 256 image = load_image(args.image_file) w, h = image.size # print(image.size) if args.type == 'format': qs = 'OCR with format: ' else: qs = 'OCR: ' if args.box: bbox = eval(args.box) if len(bbox) == 2: bbox[0] = int(bbox[0]/w*1000) bbox[1] = int(bbox[1]/h*1000) if len(bbox) == 4: bbox[0] = int(bbox[0]/w*1000) bbox[1] = int(bbox[1]/h*1000) bbox[2] = int(bbox[2]/w*1000) bbox[3] = int(bbox[3]/h*1000) if args.type == 'format': qs = 
str(bbox) + ' ' + 'OCR with format: ' else: qs = str(bbox) + ' ' + 'OCR: ' if args.color: if args.type == 'format': qs = '[' + args.color + ']' + ' ' + 'OCR with format: ' else: qs = '[' + args.color + ']' + ' ' + 'OCR: ' if use_im_start_end: qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs conv_mode = "mpt" args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() print(prompt) inputs = tokenizer([prompt]) # vary old codes, no use image_1 = image.copy() image_tensor = image_processor(image) image_tensor_1 = image_processor_high(image_1) input_ids = torch.as_tensor(inputs.input_ids).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) with torch.autocast("cuda", dtype=torch.bfloat16): output_ids = model.generate( input_ids, images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())], do_sample=False, num_beams = 1, no_repeat_ngram_size = 20, streamer=streamer, max_new_tokens=4096, stopping_criteria=[stopping_criteria] ) if args.render: print('==============rendering===============') outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() if '**kern' in outputs: import verovio from cairosvg import svg2png import cv2 import numpy as np tk = verovio.toolkit() tk.loadData(outputs) tk.setOptions({"pageWidth": 2100, "footer": 'none', 'barLineWidth': 0.5, 'beamMaxSlope': 15, 'staffLineWidth': 0.2, 'spacingStaff': 6}) tk.getPageCount() svg = tk.renderToSVG() svg = svg.replace("overflow=\"inherit\"", 
"overflow=\"visible\"") svg_to_html(svg, "./results/demo.html") if args.type == 'format' and '**kern' not in outputs: if '\\begin{tikzpicture}' not in outputs: html_path = "./render_tools/" + "/content-mmd-to-html.html" html_path_2 = "./results/demo.html" right_num = outputs.count('\\right') left_num = outputs.count('\left') if right_num != left_num: outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.') outputs = outputs.replace('"', '``').replace('$', '') outputs_list = outputs.split('\n') gt= '' for out in outputs_list: gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n' gt = gt[:-2] with open(html_path, 'r') as web_f: lines = web_f.read() lines = lines.split("const text =") new_web = lines[0] + 'const text =' + gt + lines[1] else: html_path = "./render_tools/" + "/tikz.html" html_path_2 = "./results/demo.html" outputs = outputs.translate(translation_table) outputs_list = outputs.split('\n') gt= '' for out in outputs_list: if out: if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out: while out[-1] == ' ': out = out[:-1] if out is None: break if out: if out[-1] != ';': gt += out[:-1] + ';\n' else: gt += out + '\n' else: gt += out + '\n' with open(html_path, 'r') as web_f: lines = web_f.read() lines = lines.split("const text =") new_web = lines[0] + gt + lines[1] with open(html_path_2, 'w') as web_f_new: web_f_new.write(new_web) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--image-file", type=str, required=True) parser.add_argument("--type", type=str, required=True) parser.add_argument("--box", type=str, default= '') parser.add_argument("--color", type=str, default= '') parser.add_argument("--render", action='store_true') 
args = parser.parse_args() eval_model(args) ================================================ FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0_crop.py ================================================ import argparse from transformers import AutoTokenizer, AutoModelForCausalLM import torch import os from GOT.utils.conversation import conv_templates, SeparatorStyle from GOT.utils.utils import disable_torch_init from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria from GOT.model import * from GOT.utils.utils import KeywordsStoppingCriteria from PIL import Image import os import requests from PIL import Image from io import BytesIO from GOT.model.plug.blip_process import BlipImageEvalProcessor from transformers import TextStreamer from natsort import natsorted import glob DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = '' DEFAULT_IM_START_TOKEN = '' DEFAULT_IM_END_TOKEN = '' def load_image(image_file): if image_file.startswith('http') or image_file.startswith('https'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') return best_ratio def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in 
range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) # print(target_ratios) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # print(target_aspect_ratio) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images def eval_model(args): # Model disable_torch_init() model_name = os.path.expanduser(args.model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval() model.to(device='cuda', dtype=torch.bfloat16) # vary old codes, no use image_processor = BlipImageEvalProcessor(image_size=1024) image_processor_high = BlipImageEvalProcessor(image_size=1024) use_im_start_end = True image_token_len = 256 image_list = [] if args.multi_page: qs = 'OCR with format across multi pages: ' # only for png files patches = glob.glob(args.image_file + '/*png') patches = natsorted(patches) sub_images = [] for 
sub_image in patches: sub_images.append(load_image(sub_image)) ll = len(patches) else: qs = 'OCR with format upon the patch reference: ' img = load_image(args.image_file) sub_images = dynamic_preprocess(img) ll = len(sub_images) for p in sub_images: image = p image_1 = image.copy() # no use, vary old codes image_tensor = image_processor(image) image_tensor_1 = image_processor_high(image_1) image_list.append(image_tensor_1) image_list = torch.stack(image_list) print('====new images batch size======: ',image_list.shape) # qs = args.query if use_im_start_end: qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs else: qs = DEFAULT_IMAGE_TOKEN + '\n' + qs conv_mode = "mpt" args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() inputs = tokenizer([prompt]) input_ids = torch.as_tensor(inputs.input_ids).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) with torch.autocast("cuda", dtype=torch.bfloat16): output_ids = model.generate( input_ids, images=[(image_list.half().cuda(), image_list.half().cuda())], do_sample=False, num_beams = 1, # no_repeat_ngram_size = 20, streamer=streamer, max_new_tokens=4096, stopping_criteria=[stopping_criteria] ) if args.render: print('==============rendering===============') outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() html_path = "./render_tools/" + "/content-mmd-to-html.html" html_path_2 = "./results/demo.html" right_num = outputs.count('\\right') left_num = outputs.count('\left') if right_num != left_num: outputs = outputs.replace('\left(', 
'(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.') outputs = outputs.replace('"', '``').replace('$', '') outputs_list = outputs.split('\n') gt= '' for out in outputs_list: gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n' gt = gt[:-2] with open(html_path, 'r') as web_f: lines = web_f.read() lines = lines.split("const text =") new_web = lines[0] + 'const text =' + gt + lines[1] with open(html_path_2, 'w') as web_f_new: web_f_new.write(new_web) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--image-file", type=str, required=True) parser.add_argument("--conv-mode", type=str, default=None) parser.add_argument("--multi-page", action='store_true') parser.add_argument("--render", action='store_true') args = parser.parse_args() eval_model(args) ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/eval_GOT_ocr.py ================================================ import argparse from transformers import AutoTokenizer, AutoModelForCausalLM import torch import os from tqdm import tqdm from PIL import Image import json import os import requests from PIL import Image from io import BytesIO import math import argparse from transformers import AutoTokenizer, AutoModelForCausalLM import torch import os from GOT.utils.conversation import conv_templates, SeparatorStyle from GOT.utils.utils import disable_torch_init from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria from GOT.model import * from GOT.utils.utils import KeywordsStoppingCriteria from PIL import Image import os import requests from PIL import Image from io import BytesIO from GOT.model.plug.blip_process import BlipImageEvalProcessor from transformers import TextStreamer from 
GOT.model.plug.transforms import train_transform, test_transform import re from GOT.demo.process_results import punctuation_dict, svg_to_html DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = '' DEFAULT_IM_START_TOKEN = '' DEFAULT_IM_END_TOKEN = '' import string translation_table = str.maketrans(punctuation_dict) def load_image(image_file): if image_file.startswith('http') or image_file.startswith('https'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float('inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') return best_ratio def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) # print(target_ratios) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # print(target_aspect_ratio) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # 
print(blocks) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" chunk_size = math.ceil(len(lst) / n) # integer division return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] def get_chunk(lst, n, k): chunks = split_list(lst, n) return chunks[k] output_list = [] def eval_model(args): # Model disable_torch_init() model_name = os.path.expanduser(args.model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval() # vary old codes, no use image_processor = BlipImageEvalProcessor(image_size=1024) # image_processor_high = BlipImageEvalProcessor(image_size=1280) image_processor_high = BlipImageEvalProcessor(image_size=1024) use_im_start_end = True # image_token_len = 400 image_token_len = 256 gts_path = args.gtfile_path gts = json.load(open(gts_path)) # gts = gts[0] print("Generate Results......") if "OCR" in args.datatype: gts = get_chunk(gts, args.num_chunks, args.chunk_idx) for ann in tqdm(gts): output_json = {} if "OCR" in args.datatype: qs = ann["conversations"][0]["value"] else: qs = ann["question"] # ans = ann["answers"][0] qs2 = qs image_file = ann["image"] if 'Text' in args.datatype: image_file = image_file + '.jpg' if "VQAv2" 
in args.datatype: image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg' if "Cap" in args.datatype: image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg' image_file_path = os.path.join(args.image_path, image_file) # print(image_file_path) # exit() # qs = args.query # if mm_use_im_start_end: multi_crop = False if multi_crop: image_list = [] # qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + '\n' + 'OCR with format upon the patch reference: ' img = load_image(image_file_path) sub_images = dynamic_preprocess(img) ll = len(sub_images) for p in sub_images: image = p image_1 = image.copy() # vary old code, NO USE image_tensor = image_processor_high(image_1) # image_tensor_1 = image_processor_high.preprocess(image_1, return_tensors='pt')['pixel_values'][0] image_tensor_1 = image_processor_high(image_1) image_list.append(image_tensor_1) # print(image_tensor_1.shape) image_list = torch.stack(image_list) else: ll = 1 image = load_image(image_file_path) image_1 = image.copy() # image_1 = image_1.resize((1024, 1024)) # vary old code, NO USE image_tensor = image_processor_high(image_1) image_tensor_1 = image_processor_high(image_1) # image_tensor_1 = torch.zeros(3, 1024, 1024) qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + 'OCR with format: ' conv_mode = "mpt" if args.conv_mode is not None and conv_mode != args.conv_mode: print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) else: args.conv_mode = conv_mode conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() inputs = tokenizer([prompt]) input_ids = torch.as_tensor(inputs.input_ids).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] 
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) if multi_crop: with torch.autocast("cuda", dtype=torch.bfloat16): output_ids = model.generate( input_ids, images=[(image_list.half().cuda(), image_list.half().cuda())], do_sample=False, num_beams = 1, # temperature=0.2, # no_repeat_ngram_size = 20, # streamer=streamer, max_new_tokens=4096, stopping_criteria=[stopping_criteria] ) else: with torch.autocast("cuda", dtype=torch.bfloat16): output_ids = model.generate( input_ids, images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())], do_sample=False, num_beams = 1, # temperature=0.2, no_repeat_ngram_size = 20, # encoder_repetition_penalty = 1.2, # penalty_alpha=0.2, # top_k=3, max_new_tokens=4096, stopping_criteria=[stopping_criteria] ) outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() if outputs.endswith(stop_str): outputs = outputs[:-len(stop_str)] outputs = outputs.strip() # outputs = outputs.strip()[:-1] if "Cap" in args.datatype: # output_json['image'] = ann["image"] output_json['image_id'] = ann["id"] output_json["caption"] = outputs else: # output_json['questionId'] = qs_id # output_json['question_id'] = qs_id output_json['image'] = ann["image"] output_json['question'] = qs output_json['label'] = ann["conversations"][1]["value"] output_json['answer'] = outputs output_list.append(output_json) filename = args.out_path + "/results_" + str(args.chunk_idx) + ".json" with open(filename, 'w', encoding="utf-8") as file_obj: json.dump(output_list, file_obj, ensure_ascii=False, indent=1) # print(outputs) # print("Evaluate Results... 
") # doc_text_eval(gts_path, filename, args.datatype) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--gtfile_path", type=str, required=True) parser.add_argument("--image_path", type=str, required=True) parser.add_argument("--out_path", type=str, required=True) parser.add_argument("--datatype", type=str, required=True) # Text or Doc parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--chunk-idx", type=int, default=0) # parser.add_argument("--query", type=str, required=True) parser.add_argument("--conv-mode", type=str, default=None) parser.add_argument("--temperature", type=float, default=0.2) args = parser.parse_args() print(args) eval_model(args) ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/evaluate_GOT.py ================================================ import os import argparse parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--gtfile_path", type=str, required=True) parser.add_argument("--image_path", type=str, required=True) parser.add_argument("--out_path", type=str, required=True) parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--datatype", type=str, required=True) # Text\Doc\VQAv2\Cap # parser.add_argument("--eval", type=str, required=True) args = parser.parse_args() os.system("python3 -m GOT.eval.multi_hardware_eval_GOT" + " " + "--model-name" + " " + args.model_name + " " + "--gtfile_path" + " " + args.gtfile_path + " " + "--image_path" + " " + args.image_path + " " + "--out_path" + " " + args.out_path + " " + "--num-chunks" + " " + str(args.num_chunks) + " " + "--temperature" + " " + str(args.temperature) + " " + "--datatype" + " " + args.datatype ) print("Evaluating.....") os.system("python3 -m 
GOT.eval.pyevaltools.merge_results" + " " + "--out_path" + " " + args.out_path) # if args.datatype == "OCR": a_type = 'plain' # 'palin'; 'format'; 'scene' if a_type == 'plain': os.system("python3 -m GOT.eval.pyevaltools.eval_ocr" + " " + "--out_path" + " " + args.out_path + " " + "--gt_path" + " " + args.gtfile_path + " " + "--datatype" + " " + args.datatype ) if a_type == 'format': os.system("python3 -m GOT.eval.pyevaltools.eval_ocr_format" + " " + "--out_path" + " " + args.out_path + " " + "--gt_path" + " " + args.gtfile_path + " " + "--datatype" + " " + args.datatype ) if a_type == 'scene': os.system("python3 -m GOT.eval.pyevaltools.eval_ocr_scene" + " " + "--out_path" + " " + args.out_path + " " + "--gt_path" + " " + args.gtfile_path + " " + "--datatype" + " " + args.datatype ) ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/multi_hardware_eval_GOT.py ================================================ import os import argparse from multiprocessing import Pool # from GOT.eval.merge_results import merge_outputs # from GOT.eval.doctextVQA import doc_text_eval def run_eval(chunk_id, model_name, gtfile_path, image_path, out_path, num_chunks, datatype, temperature): os.system("CUDA_VISIBLE_DEVICES=" + str(chunk_id) + " " + "python3 -m GOT.eval.eval_GOT_ocr" + " " + "--model-name" + " " + model_name + " " + "--gtfile_path" + " " + gtfile_path + " " + "--image_path" + " " + image_path + " " + "--out_path" + " " + out_path + " " + "--num-chunks" + " " + str(num_chunks) + " " + "--chunk-idx" + " " + str(chunk_id) + " " + "--temperature" + " " + str(temperature) + " " + "--datatype" + " " + datatype ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--gtfile_path", type=str, required=True) parser.add_argument("--image_path", type=str, required=True) parser.add_argument("--out_path", type=str, required=True) 
parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--datatype", type=str, required=True) # Text or Doc # parser.add_argument("--eval", type=str, required=True) args = parser.parse_args() num_chunks = args.num_chunks if os.path.exists(args.out_path) == False: os.makedirs(args.out_path) with Pool(num_chunks) as p: for i in range(num_chunks): chunk_id = i p.apply_async(run_eval, (chunk_id, args.model_name, args.gtfile_path, args.image_path, args.out_path, num_chunks, args.datatype, args.temperature)) p.close() p.join() ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/__init__.py ================================================ author='aagrawal' ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr.py ================================================ import json # from doctextVQAeval import VQAEval import argparse # import fitz as pymupdf import nltk from nltk.metrics import precision, recall, f_measure import numpy as np import jieba # import megfile as mf import pickle import pandas as pd import re # from loguru import logger # nltk.download('wordnet') from nltk.translate import meteor_score # from marker_scoring import score_text # from utils import contain_chinese_string parser = argparse.ArgumentParser() parser.add_argument("--out_path", type=str, required=True) parser.add_argument("--gt_path", type=str, required=True) parser.add_argument("--datatype", type=str, required=True) args = parser.parse_args() def preprocess(text, predict_root_): if 'InternVL' in predict_root_: text = text.split("All words in the image:\n")[1] text = text.split("[UNUSED_TOKEN_145]")[0] return text def contain_chinese_string(text): # 使用正则表达式匹配中文字符 chinese_pattern = re.compile(r'[\u4e00-\u9fa5]') return bool(chinese_pattern.search(text)) inline_reg = re.compile(r"\\\((.*?)(?= 1: # try: metrics["meteor"] = 
meteor_score.meteor_score([reference], hypothesis) # except LookupError: # metrics["meteor"] = np.nan reference = set(reference) hypothesis = set(hypothesis) metrics["f_measure"] = f_measure(reference, hypothesis) if heavy_mode >= 1: metrics["precision"] = precision(reference, hypothesis) metrics["recall"] = recall(reference, hypothesis) if heavy_mode == 2: # 速度太慢 metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt)) return metrics def doc_formated_text_eval(gt_root_, predict_root_, datatype): predicts = json.load(open(predict_root_, encoding='utf-8')) # print(predicts) gt_text_split, gt_math_split, gt_table_split= split_text(predicts, 'label') pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer') text_results = [] math_results = [] table_results = [] for gt0, pre0, gt1, pre1, gt2, pre2 in zip(gt_text_split, pre_text_split, gt_math_split, pre_math_split, gt_table_split, pre_table_split): # try: # text, math, table text_gts, text_pres = gt0, pre0 math_gts, math_pres = gt1, pre1 table_gts, table_pres = gt2, pre2 # for text_gt, text_pre in zip(text_gts, text_pres): ans = nougat_per_metrics(predict_root_, text_gts, text_pres) # if len(ans) == 0: # continue if ans: text_results.append(ans) # for math_gt, math_pre in zip(math_gts, math_pres): ans = nougat_per_metrics(predict_root_, math_gts, math_pres) # if len(ans) == 0: # continue if ans: math_results.append(ans) # for table_gt, table_pre in zip(table_gts, table_pres): ans = nougat_per_metrics(predict_root_, table_gts, table_pres) # if len(ans) == 0: # continue if ans: table_results.append(ans) mean_dict = {} # print((result)) # print(len(result)) mean_dict["eval question num"] = len(text_results) mean_dict['text'] = {} mean_dict['math'] = {} mean_dict['table'] = {} for k, v in text_results[0].items(): mean_dict['text'][k] = 0 mean_dict['math'][k] = 0 mean_dict['table'][k] = 0 for each in text_results: for k, v in each.items(): mean_dict['text'][k] += v for each in 
math_results: for k, v in each.items(): mean_dict['math'][k] += v for each in table_results: for k, v in each.items(): mean_dict['table'][k] += v for k, v in mean_dict['text'].items(): mean_dict['text'][k] /= len(text_results) for k, v in mean_dict['math'].items(): mean_dict['math'][k] /= len(math_results) for k, v in mean_dict['table'].items(): mean_dict['table'][k] /= len(table_results) print(json.dumps(mean_dict, indent=4)) def doc_text_eval(gt_root_, predict_root_, datatype): predicts = json.load(open(predict_root_, encoding='utf-8')) # print(predicts) result = [] for ann in predicts: try: ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"]) if len(ans) == 0: continue result.append(ans) except: assert False, print("ERROR!!! Check yout output!!!") mean_dict = {} # print((result)) # print(len(result)) mean_dict["eval question num"] = len(result) for k, v in result[0].items(): mean_dict[k] = 0 for each in result: for k, v in each.items(): mean_dict[k] += v for k, v in mean_dict.items(): if k == "eval question num": continue mean_dict[k] /= len(result) print(json.dumps(mean_dict, indent=4)) # doc_text_eval("/data/data/DocVQA/val/val_v1.0.json", "/data/codes/GOT_docshot-main/results_cc595k-freeze-docvqa-unfreeze-224/results_final.json", "Doc") # doc_formated_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype) doc_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype) ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_format.py ================================================ import json # from doctextVQAeval import VQAEval import argparse # import fitz as pymupdf import nltk from nltk.metrics import precision, recall, f_measure import numpy as np import jieba # import megfile as mf import pickle import pandas as pd import re # from loguru import logger # nltk.download('wordnet') from nltk.translate import meteor_score # from marker_scoring import 
score_text # from utils import contain_chinese_string parser = argparse.ArgumentParser() parser.add_argument("--out_path", type=str, required=True) parser.add_argument("--gt_path", type=str, required=True) parser.add_argument("--datatype", type=str, required=True) args = parser.parse_args() def preprocess(text, predict_root_): if 'InternVL' in predict_root_: text = text.split("All words in the image:\n")[1] text = text.split("[UNUSED_TOKEN_145]")[0] return text def contain_chinese_string(text): # 使用正则表达式匹配中文字符 chinese_pattern = re.compile(r'[\u4e00-\u9fa5]') return bool(chinese_pattern.search(text)) inline_reg = re.compile(r"\\\((.*?)(?= 1: # try: metrics["meteor"] = meteor_score.meteor_score([reference], hypothesis) # except LookupError: # metrics["meteor"] = np.nan reference = set(reference) hypothesis = set(hypothesis) metrics["f_measure"] = f_measure(reference, hypothesis) if heavy_mode >= 1: metrics["precision"] = precision(reference, hypothesis) metrics["recall"] = recall(reference, hypothesis) if heavy_mode == 2: # 速度太慢 metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt)) return metrics def doc_formated_text_eval(gt_root_, predict_root_, datatype): predicts = json.load(open(predict_root_, encoding='utf-8')) # print(predicts) gt_text_split, gt_math_split, gt_table_split= split_text(predicts, 'label') pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer') text_results = [] math_results = [] table_results = [] for gt0, pre0, gt1, pre1, gt2, pre2 in zip(gt_text_split, pre_text_split, gt_math_split, pre_math_split, gt_table_split, pre_table_split): # try: # text, math, table text_gts, text_pres = gt0, pre0 math_gts, math_pres = gt1, pre1 table_gts, table_pres = gt2, pre2 # for text_gt, text_pre in zip(text_gts, text_pres): ans = nougat_per_metrics(predict_root_, text_gts, text_pres) # if len(ans) == 0: # continue if ans: text_results.append(ans) # for math_gt, math_pre in zip(math_gts, math_pres): ans = 
nougat_per_metrics(predict_root_, math_gts, math_pres) # if len(ans) == 0: # continue if ans: math_results.append(ans) # for table_gt, table_pre in zip(table_gts, table_pres): ans = nougat_per_metrics(predict_root_, table_gts, table_pres) # if len(ans) == 0: # continue if ans: table_results.append(ans) mean_dict = {} # print((result)) # print(len(result)) mean_dict["eval question num"] = len(text_results) mean_dict['text'] = {} mean_dict['math'] = {} mean_dict['table'] = {} for k, v in text_results[0].items(): mean_dict['text'][k] = 0 mean_dict['math'][k] = 0 mean_dict['table'][k] = 0 for each in text_results: for k, v in each.items(): mean_dict['text'][k] += v for each in math_results: for k, v in each.items(): mean_dict['math'][k] += v for each in table_results: for k, v in each.items(): mean_dict['table'][k] += v for k, v in mean_dict['text'].items(): mean_dict['text'][k] /= len(text_results) for k, v in mean_dict['math'].items(): mean_dict['math'][k] /= len(math_results) for k, v in mean_dict['table'].items(): mean_dict['table'][k] /= len(table_results) print(json.dumps(mean_dict, indent=4)) def doc_text_eval(gt_root_, predict_root_, datatype): predicts = json.load(open(predict_root_, encoding='utf-8')) # print(predicts) result = [] for ann in predicts: try: ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"]) if len(ans) == 0: continue result.append(ans) except: assert False, print("ERROR!!! 
Check yout output!!!") mean_dict = {} # print((result)) # print(len(result)) mean_dict["eval question num"] = len(result) for k, v in result[0].items(): mean_dict[k] = 0 for each in result: for k, v in each.items(): mean_dict[k] += v for k, v in mean_dict.items(): if k == "eval question num": continue mean_dict[k] /= len(result) print(json.dumps(mean_dict, indent=4)) # doc_text_eval("/data/data/DocVQA/val/val_v1.0.json", "/data/codes/GOT_docshot-main/results_cc595k-freeze-docvqa-unfreeze-224/results_final.json", "Doc") doc_formated_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype) # doc_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype) ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_scene.py ================================================ import json import argparse import nltk from nltk.metrics import precision, recall, f_measure import numpy as np import jieba # import megfile as mf import pickle import pandas as pd import re from nltk.translate import meteor_score parser = argparse.ArgumentParser() parser.add_argument("--out_path", type=str, required=True) parser.add_argument("--gt_path", type=str, required=True) parser.add_argument("--datatype", type=str, required=True) args = parser.parse_args() def preprocess(text, predict_root_): if 'InternVL' in predict_root_: text = text.split("All words in the image:\n")[1] text = text.split("[UNUSED_TOKEN_145]")[0] return text def contain_chinese_string(text): chinese_pattern = re.compile(r'[\u4e00-\u9fa5]') return bool(chinese_pattern.search(text)) def nougat_per_metrics(predict_root_, pred, gt, minlen=1): metrics = {} if len(pred) < minlen or len(gt) < minlen: return metrics reference = list(gt) hypothesis = list(pred) metrics["bleu"] = nltk.translate.bleu([reference], hypothesis) metrics["meteor"] = meteor_score.meteor_score([reference], hypothesis) reference = set(reference) hypothesis = set(hypothesis) 
metrics["f_measure"] = f_measure(reference, hypothesis) metrics["precision"] = precision(reference, hypothesis) metrics["recall"] = recall(reference, hypothesis) metrics["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt)) return metrics def doc_text_eval(gt_root_, predict_root_, datatype): predicts = json.load(open(predict_root_, encoding='utf-8')) result = [] for ann in predicts: try: ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"]) if len(ans) == 0: continue result.append(ans) except: assert False, print("ERROR!!! Check yout output!!!") mean_dict = {} mean_dict["eval question num"] = len(result) for k, v in result[0].items(): mean_dict[k] = 0 for each in result: for k, v in each.items(): mean_dict[k] += v for k, v in mean_dict.items(): if k == "eval question num": continue mean_dict[k] /= len(result) print(json.dumps(mean_dict, indent=4)) doc_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype) ================================================ FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/merge_results.py ================================================ import os import json import argparse def merge_outputs(out_path): files = os.listdir(out_path) # print(files) alist = [] for file in files: alist += json.load(open(os.path.join(out_path, file), encoding='utf-8')) # print(len(alist)) filename = out_path + "/results_final" + ".json" with open(filename, 'w', encoding="utf-8") as file_obj: json.dump(alist, file_obj, ensure_ascii=False, indent=1) parser = argparse.ArgumentParser() parser.add_argument("--out_path", type=str, required=True) args = parser.parse_args() merge_outputs(args.out_path) ================================================ FILE: GOT-OCR-2.0-master/GOT/model/GOT_ocr_2_0.py ================================================ from transformers import AutoConfig, AutoModelForCausalLM, \ Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \ CLIPVisionModel, CLIPImageProcessor from 
transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from typing import List, Optional, Tuple, Union from transformers.cache_utils import Cache, DynamicCache import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss from GOT.utils.constants import * from GOT.model.vision_encoder.vary_b import build_vary_vit_b from GOT.model.plug.blip_process import BlipImageEvalProcessor class GOTConfig(Qwen2Config): model_type = "GOT" class GOTQwenModel(Qwen2Model): config_class = GOTConfig def __init__(self, config: Qwen2Config): super(GOTQwenModel, self).__init__(config) self.vision_tower_high = build_vary_vit_b() self.mm_projector_vary = nn.Linear(1024, 1024) def initialize_vision_modules( self, vision_tower, pretrained_stage1_model=None, freeze_vision_tower=False, use_im_start_end=False, vision_select_layer=-1, dtype=torch.float16, device="cuda" ): # Vary old codes, not use in GOT image_processor = BlipImageEvalProcessor(image_size=1024) # 1024*1024 image_processor_high = BlipImageEvalProcessor(image_size=1024) self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device) self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device) image_token_len = 256 self.config.vision_tower = vision_tower self.config.image_token_len = image_token_len # self.config.use_im_start_end = use_im_start_end self.config.use_im_start_end = True self.config.vision_select_layer = vision_select_layer self.config.freeze_vision_tower = freeze_vision_tower return dict( image_processor=image_processor, image_processor_high=image_processor_high, image_token_len=image_token_len, ) # def get_input_embeddings(self, x): # return self.wte(x) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, 
use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: # HACK: replace back original embeddings for LLaVA pretraining orig_embeds_params = getattr(self, 'orig_embeds_params', None) if orig_embeds_params is not None: with torch.no_grad(): self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) vision_tower_high = getattr(self, 'vision_tower_high', None) if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: # if True: # assert type(images) is list, ValueError("To fit both interleave and conversation, images must be list of batches of images") # print(im) use_im_start_end = getattr(self.config, "use_im_start_end", -1) vision_select_layer = getattr(self.config, "vision_select_layer", -1) im_patch_token = getattr(self.config, "im_patch_token", -1) im_start_token = getattr(self.config, "im_start_token", -1) im_end_token = getattr(self.config, "im_end_token", -1) freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False) im_patch_token = 151859 im_start_token = 151857 im_end_token = 151858 image_features = [] for image in images: P, C, H, W = image[1].shape # with torch.set_grad_enabled(True): # # print(image[1].shape) # cnn_feature = vision_tower_high(image[1]) # cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256 1024 # # image_features.append(cnn_feature) # image_features_2.append(cnn_feature) if P == 1: with torch.set_grad_enabled(False): # print(image[1].shape) cnn_feature = vision_tower_high(image[1]) cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024 # image_features.append(cnn_feature) # image_features_2.append(cnn_feature) image_feature = 
self.mm_projector_vary(cnn_feature) image_features.append(image_feature) else: image_patches = torch.unbind(image[1]) image_patches_features = [] for image_patch in image_patches: image_p = torch.stack([image_patch]) with torch.set_grad_enabled(False): cnn_feature_p = vision_tower_high(image_p) cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1) image_feature_p = self.mm_projector_vary(cnn_feature_p) image_patches_features.append(image_feature_p) image_feature = torch.cat(image_patches_features, dim=1) # print(P) # print(image_feature.shape) # exit() image_features.append(image_feature) dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2) dummy_image_features = dummy_image_features_2 use_im_start_end = True new_input_embeds = [] for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features): if (cur_input_ids == im_patch_token).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() new_input_embeds.append(cur_input_embeds) continue if use_im_start_end: if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum(): raise ValueError("The number of image start tokens and image end tokens should be the same.") image_start_tokens = torch.where(cur_input_ids == im_start_token)[0] for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features): per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device) num_patches = per_cur_image_features.shape[0] if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token: raise ValueError("The image end token should follow the image start token.") cur_input_embeds = torch.cat( ( cur_input_embeds[:image_start_token_pos+1], per_cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:] ), dim=0 ) new_input_embeds.append(cur_input_embeds) else: raise NotImplementedError inputs_embeds = torch.stack(new_input_embeds, dim=0) return super(GOTQwenModel, self).forward( input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) class GOTQwenForCausalLM(Qwen2ForCausalLM): config_class = GOTConfig # supports_gradient_checkpointing = True def __init__(self, config): super(Qwen2ForCausalLM, self).__init__(config) self.model = GOTQwenModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_model(self): return self.model # def _set_gradient_checkpointing(self, module, value=False): # if isinstance(module, GOTQwenModel): # module.gradient_checkpointing = value # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) # 
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # print(input_ids) # print(len(images)) # print(inputs_embeds) outputs = self.model( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, images=images, return_dict=return_dict ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() # logits loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return 
(loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "images": kwargs.get("images", None), } ) return model_inputs def initialize_vision_tokenizer( self, tokenizer, freeze_lm_model=False, pretrained_stage1_model=None, device="cuda" ): config = self.get_model().config # add image patch token # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) # config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] config.im_patch_token = 151859 config.use_im_start_end = True # add image start token and end token if config.use_im_start_end: # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) config.im_start_token, config.im_end_token = 151857, 151858 AutoConfig.register("GOT", GOTConfig) AutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM) ================================================ FILE: 
GOT-OCR-2.0-master/GOT/model/__init__.py ================================================ from .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig ================================================ FILE: GOT-OCR-2.0-master/GOT/model/plug/blip_process.py ================================================ """ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import cv2 import numpy as np import torch # from omegaconf import OmegaConf from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from PIL import Image class BaseProcessor: def __init__(self): self.transform = lambda x: x return def __call__(self, item): return self.transform(item) # @classmethod # def from_config(cls, cfg=None): # return cls() # def build(self, **kwargs): # cfg = OmegaConf.create(kwargs) # return self.from_config(cfg) class BlipImageBaseProcessor(BaseProcessor): def __init__(self, mean=None, std=None): if mean is None: mean = (0.48145466, 0.4578275, 0.40821073) if std is None: std = (0.26862954, 0.26130258, 0.27577711) # mean = (0.0, 0.0, 0.0) # std = (1.0, 1.0, 1.0) self.normalize = transforms.Normalize(mean, std) ## aug functions def identity_func(img): return img def autocontrast_func(img, cutoff=0): """ same output as PIL.ImageOps.autocontrast """ n_bins = 256 def tune_channel(ch): n = ch.size cut = cutoff * n // 100 if cut == 0: high, low = ch.max(), ch.min() else: hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) low = np.argwhere(np.cumsum(hist) > cut) low = 0 if low.shape[0] == 0 else low[0] high = np.argwhere(np.cumsum(hist[::-1]) > cut) high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] if high <= low: table = np.arange(n_bins) else: scale = (n_bins - 1) / (high - low) offset = -low * scale table = np.arange(n_bins) * scale + offset table[table < 
0] = 0 table[table > n_bins - 1] = n_bins - 1 table = table.clip(0, 255).astype(np.uint8) return table[ch] channels = [tune_channel(ch) for ch in cv2.split(img)] out = cv2.merge(channels) return out def equalize_func(img): """ same output as PIL.ImageOps.equalize PIL's implementation is different from cv2.equalize """ n_bins = 256 def tune_channel(ch): hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) non_zero_hist = hist[hist != 0].reshape(-1) step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) if step == 0: return ch n = np.empty_like(hist) n[0] = step // 2 n[1:] = hist[:-1] table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) return table[ch] channels = [tune_channel(ch) for ch in cv2.split(img)] out = cv2.merge(channels) return out def rotate_func(img, degree, fill=(0, 0, 0)): """ like PIL, rotate by degree, not radians """ H, W = img.shape[0], img.shape[1] center = W / 2, H / 2 M = cv2.getRotationMatrix2D(center, degree, 1) out = cv2.warpAffine(img, M, (W, H), borderValue=fill) return out def solarize_func(img, thresh=128): """ same output as PIL.ImageOps.posterize """ table = np.array([el if el < thresh else 255 - el for el in range(256)]) table = table.clip(0, 255).astype(np.uint8) out = table[img] return out def color_func(img, factor): """ same output as PIL.ImageEnhance.Color """ ## implementation according to PIL definition, quite slow # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] # out = blend(degenerate, img, factor) # M = ( # np.eye(3) * factor # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) # )[np.newaxis, np.newaxis, :] M = np.float32( [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] ) * factor + np.float32([[0.114], [0.587], [0.299]]) out = np.matmul(img, M).clip(0, 255).astype(np.uint8) return out def contrast_func(img, factor): """ same output as PIL.ImageEnhance.Contrast """ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) table = ( np.array([(el - mean) * factor + mean for el in range(256)]) .clip(0, 255) .astype(np.uint8) ) out = table[img] return out def brightness_func(img, factor): """ same output as PIL.ImageEnhance.Contrast """ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) out = table[img] return out def sharpness_func(img, factor): """ The differences the this result and PIL are all on the 4 boundaries, the center areas are same """ kernel = np.ones((3, 3), dtype=np.float32) kernel[1][1] = 5 kernel /= 13 degenerate = cv2.filter2D(img, -1, kernel) if factor == 0.0: out = degenerate elif factor == 1.0: out = img else: out = img.astype(np.float32) degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) out = out.astype(np.uint8) return out def shear_x_func(img, factor, fill=(0, 0, 0)): H, W = img.shape[0], img.shape[1] M = np.float32([[1, factor, 0], [0, 1, 0]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def translate_x_func(img, offset, fill=(0, 0, 0)): """ same output as PIL.Image.transform """ H, W = img.shape[0], img.shape[1] M = np.float32([[1, 0, -offset], [0, 1, 0]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def translate_y_func(img, offset, fill=(0, 0, 0)): """ same output as PIL.Image.transform """ H, W = img.shape[0], img.shape[1] M = np.float32([[1, 0, 0], [0, 1, -offset]]) out = cv2.warpAffine( img, M, (W, H), 
borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def posterize_func(img, bits): """ same output as PIL.ImageOps.posterize """ out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) return out def shear_y_func(img, factor, fill=(0, 0, 0)): H, W = img.shape[0], img.shape[1] M = np.float32([[1, 0, 0], [factor, 1, 0]]) out = cv2.warpAffine( img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR ).astype(np.uint8) return out def cutout_func(img, pad_size, replace=(0, 0, 0)): replace = np.array(replace, dtype=np.uint8) H, W = img.shape[0], img.shape[1] rh, rw = np.random.random(2) pad_size = pad_size // 2 ch, cw = int(rh * H), int(rw * W) x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) out = img.copy() out[x1:x2, y1:y2, :] = replace return out ### level to args def enhance_level_to_args(MAX_LEVEL): def level_to_args(level): return ((level / MAX_LEVEL) * 1.8 + 0.1,) return level_to_args def shear_level_to_args(MAX_LEVEL, replace_value): def level_to_args(level): level = (level / MAX_LEVEL) * 0.3 if np.random.random() > 0.5: level = -level return (level, replace_value) return level_to_args def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): def level_to_args(level): level = (level / MAX_LEVEL) * float(translate_const) if np.random.random() > 0.5: level = -level return (level, replace_value) return level_to_args def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): def level_to_args(level): level = int((level / MAX_LEVEL) * cutout_const) return (level, replace_value) return level_to_args def solarize_level_to_args(MAX_LEVEL): def level_to_args(level): level = int((level / MAX_LEVEL) * 256) return (level,) return level_to_args def none_level_to_args(level): return () def posterize_level_to_args(MAX_LEVEL): def level_to_args(level): level = int((level / MAX_LEVEL) * 4) return (level,) return level_to_args def rotate_level_to_args(MAX_LEVEL, replace_value): def 
level_to_args(level): level = (level / MAX_LEVEL) * 30 if np.random.random() < 0.5: level = -level return (level, replace_value) return level_to_args func_dict = { "Identity": identity_func, "AutoContrast": autocontrast_func, "Equalize": equalize_func, "Rotate": rotate_func, "Solarize": solarize_func, "Color": color_func, "Contrast": contrast_func, "Brightness": brightness_func, "Sharpness": sharpness_func, "ShearX": shear_x_func, "TranslateX": translate_x_func, "TranslateY": translate_y_func, "Posterize": posterize_func, "ShearY": shear_y_func, } translate_const = 10 MAX_LEVEL = 10 replace_value = (128, 128, 128) arg_dict = { "Identity": none_level_to_args, "AutoContrast": none_level_to_args, "Equalize": none_level_to_args, "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), "Solarize": solarize_level_to_args(MAX_LEVEL), "Color": enhance_level_to_args(MAX_LEVEL), "Contrast": enhance_level_to_args(MAX_LEVEL), "Brightness": enhance_level_to_args(MAX_LEVEL), "Sharpness": enhance_level_to_args(MAX_LEVEL), "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), "Posterize": posterize_level_to_args(MAX_LEVEL), "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), } class RandomAugment(object): def __init__(self, N=2, M=10, isPIL=False, augs=[]): self.N = N self.M = M self.isPIL = isPIL if augs: self.augs = augs else: self.augs = list(arg_dict.keys()) def get_random_ops(self): sampled_ops = np.random.choice(self.augs, self.N) return [(op, 0.5, self.M) for op in sampled_ops] def __call__(self, img): if self.isPIL: img = np.array(img) ops = self.get_random_ops() for name, prob, level in ops: if np.random.random() > prob: continue args = arg_dict[name](level) img = func_dict[name](img, *args) return img class VideoRandomAugment(object): def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, 
augs=[]): self.N = N self.M = M self.p = p self.tensor_in_tensor_out = tensor_in_tensor_out if augs: self.augs = augs else: self.augs = list(arg_dict.keys()) def get_random_ops(self): sampled_ops = np.random.choice(self.augs, self.N, replace=False) return [(op, self.M) for op in sampled_ops] def __call__(self, frames): assert ( frames.shape[-1] == 3 ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." if self.tensor_in_tensor_out: frames = frames.numpy().astype(np.uint8) num_frames = frames.shape[0] ops = num_frames * [self.get_random_ops()] apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] frames = torch.stack( list(map(self._aug, frames, ops, apply_or_not)), dim=0 ).float() return frames def _aug(self, img, ops, apply_or_not): for i, (name, level) in enumerate(ops): if not apply_or_not[i]: continue args = arg_dict[name](level) img = func_dict[name](img, *args) return torch.from_numpy(img) # if __name__ == "__main__": # a = RandomAugment() # img = np.random.randn(32, 32, 3) # a(img) class BlipImageTrainProcessor(BlipImageBaseProcessor): def __init__( self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0 ): super().__init__(mean=mean, std=std) self.transform = transforms.Compose( [ transforms.RandomResizedCrop( image_size, scale=(min_scale, max_scale), interpolation=InterpolationMode.BICUBIC, ), # transforms.RandomHorizontalFlip(), RandomAugment( 2, 5, isPIL=True, augs=[ "Identity", # "AutoContrast", "Brightness", "Sharpness", "Equalize", # "ShearX", # "ShearY", # "TranslateX", # "TranslateY", # "Rotate", ], ), transforms.ToTensor(), self.normalize, ] ) def __call__(self, item): return self.transform(item) class BlipImageEvalProcessor(BlipImageBaseProcessor): def __init__(self, image_size=384, mean=None, std=None): super().__init__(mean=mean, std=std) self.transform = transforms.Compose( [ transforms.Resize( (image_size, image_size), interpolation=InterpolationMode.BICUBIC ), transforms.ToTensor(), self.normalize, ] ) 
def __call__(self, item): return self.transform(item) # if __name__ == "__main__": # a = BlipImageTrainProcessor(image_size=1024) # # img = np.random.randn(1024, 1024, 3) # # x = torch.zeros(1024, 1024, 3) # x = Image.open("/data/codes/GOT-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB") # print(x.size) # y = a(x) # print(y.size()) ================================================ FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/__init__.py ================================================ ================================================ FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/vary_b.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Type from functools import partial import torch import torch.nn as nn from typing import Type # from GOT.model.vision_encoder.vitg_qwen import Resampler import math class Projector(nn.Module): def __init__( self, width: 256, n_queries: int = 256, output_dim: int = 4096, **kwargs ): super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) self.attn_pool = Resampler( grid_size=int(math.sqrt(n_queries)), embed_dim=output_dim, num_heads=output_dim // 128, kv_dim=width, norm_layer=norm_layer, ) self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) def forward(self, x: torch.Tensor): x = self.attn_pool(x) x = self.ln_post(x) x = x @ self.proj return x class MLPBlock(nn.Module): def __init__( self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU, ) -> None: super().__init__() self.lin1 = nn.Linear(embedding_dim, mlp_dim) self.lin2 = nn.Linear(mlp_dim, embedding_dim) self.act = act() def forward(self, x: 
torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa class LayerNorm2d(nn.Module): def __init__(self, num_channels: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(num_channels)) self.bias = nn.Parameter(torch.zeros(num_channels)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa class ImageEncoderViT(nn.Module): def __init__( self, img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ Args: img_size (int): Input image size. patch_size (int): Patch size. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. depth (int): Depth of ViT. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. 
use_abs_pos (bool): If True, use absolute positional embeddings. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. global_attn_indexes (list): Indexes for blocks using global attention. """ super().__init__() self.img_size = img_size self.patch_embed = PatchEmbed( kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), in_chans=in_chans, embed_dim=embed_dim, ) self.pos_embed: Optional[nn.Parameter] = None if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) ) self.blocks = nn.ModuleList() for i in range(depth): block = Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer, act_layer=act_layer, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, window_size=window_size if i not in global_attn_indexes else 0, input_size=(img_size // patch_size, img_size // patch_size), ) self.blocks.append(block) self.neck = nn.Sequential( nn.Conv2d( embed_dim, out_chans, kernel_size=1, bias=False, ), LayerNorm2d(out_chans), nn.Conv2d( out_chans, out_chans, kernel_size=3, padding=1, bias=False, ), LayerNorm2d(out_chans), ) self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed for blk in self.blocks: x = blk(x) x = self.neck(x.permute(0, 3, 1, 2)) x = self.net_2(x) x = self.net_3(x) return x class Block(nn.Module): """Transformer blocks with support of window attention and residual propagation blocks""" def __init__( self, dim: int, num_heads: int, mlp_ratio: float 
= 4.0, qkv_bias: bool = True, norm_layer: Type[nn.Module] = nn.LayerNorm, act_layer: Type[nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads in each ViT block. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool): If True, add a learnable bias to query, key, value. norm_layer (nn.Module): Normalization layer. act_layer (nn.Module): Activation layer. use_rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. window_size (int): Window size for window attention blocks. If it equals 0, then use global attention. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init, input_size=input_size if window_size == 0 else (window_size, window_size), ) self.norm2 = norm_layer(dim) self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) self.window_size = window_size def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) # Window partition if self.window_size > 0: H, W = x.shape[1], x.shape[2] x, pad_hw = window_partition(x, self.window_size) x = self.attn(x) # Reverse window partition if self.window_size > 0: x = window_unpartition(x, self.window_size, pad_hw, (H, W)) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class Attention(nn.Module): """Multi-head Attention block with relative position embeddings.""" def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: 
Optional[Tuple[int, int]] = None, ) -> None: """ Args: dim (int): Number of input channels. num_heads (int): Number of attention heads. qkv_bias (bool): If True, add a learnable bias to query, key, value. rel_pos (bool): If True, add relative positional embeddings to the attention map. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. input_size (tuple(int, int) or None): Input resolution for calculating the relative positional parameter size. """ super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) self.use_rel_pos = use_rel_pos if self.use_rel_pos: assert ( input_size is not None ), "Input size must be provided if using relative positional encoding." # initialize relative positional embeddings self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1) x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) x = self.proj(x) return x def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: x (tensor): input tokens with [B, H, W, C]. window_size (int): window size. Returns: windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
(Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) return windows, (Hp, Wp) def window_unpartition( windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. Args: windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. window_size (int): window size. pad_hw (Tuple): padded height and width (Hp, Wp). hw (Tuple): original height and width (H, W) before padding. Returns: x: unpartitioned sequences with [B, H, W, C]. """ Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: """ Get relative positional embeddings according to the relative positions of query and key sizes. Args: q_size (int): size of query q. k_size (int): size of key k. rel_pos (Tensor): relative position embeddings (L, C). Returns: Extracted positional embeddings according to relative positions. """ max_rel_dist = int(2 * max(q_size, k_size) - 1) # Interpolate rel pos if needed. if rel_pos.shape[0] != max_rel_dist: # Interpolate rel pos. 
rel_pos_resized = F.interpolate( rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear", ) rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) else: rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()] def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ q_h, q_w = q_size k_h, k_w = k_size Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape r_q = q.reshape(B, q_h, q_w, dim) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn class PatchEmbed(nn.Module): """ Image to Patch Embedding. 
""" def __init__( self, kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768, ) -> None: """ Args: kernel_size (Tuple): kernel size of the projection layer. stride (Tuple): stride of the projection layer. padding (Tuple): padding size of the projection layer. in_chans (int): Number of input image channels. embed_dim (int): Patch embedding dimension. """ super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x def build_vary_vit_b(checkpoint=None): return _build_vary( encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=[2, 5, 8, 11], checkpoint=checkpoint, ) def _build_vary( encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, ): prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size image_encoder=ImageEncoderViT( depth=encoder_depth, embed_dim=encoder_embed_dim, img_size=image_size, mlp_ratio=4, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), num_heads=encoder_num_heads, patch_size=vit_patch_size, qkv_bias=True, use_rel_pos=True, global_attn_indexes=encoder_global_attn_indexes, window_size=14, out_chans=prompt_embed_dim, ) # if checkpoint is not None: # # with open(checkpoint, "rb") as f: # state_dict = torch.load(checkpoint) # # print(state_dict.keys()) # # for key in state_dict: # # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False) # # ocr-anyting # # image_encoder.load_state_dict(state_dict, strict=True) # # tob # # model.vision_tower. 
# image_encoder.load_state_dict({k[19:]: v for k, v in state_dict.items() if 'vision_tower' in k}, strict=True) # print(checkpoint) return image_encoder if __name__ == '__main__': x = torch.zeros(2, 3, 1024, 1024) # x.permute(0, 3, 1, 2) net = build_vary_vit_b(checkpoint ='/mnt/shared-storage/tenant/hypertext/xpkong/jycode/checkpoint/pytorch_model.bin') # mlp = Projector(width=256, n_queries = 256, output_dim = 768) y = net(x) y = y.flatten(2).permute(0, 2, 1) print(y.shape) # y = mlp(y) # y = net_2(y) # y = net_3(y) # # print(y.shape) ================================================ FILE: GOT-OCR-2.0-master/GOT/train/train.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging
import pathlib
import torch
import transformers

# Alternative trainer variants that can be swapped in:
# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import GOTLlamaForCausalLM
from GOT.model import *
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import smart_tokenizer_and_embedding_resize
# NOTE(review): the repository tree lists only `__init__.py` and `vary_b.py` under
# GOT/model/vision_encoder/. If so, the next two imports cannot resolve and this
# module (and train_flash_attn.py, which imports it) fails at import time.
# `build_swin_transformer` is also never used below. Confirm before running.
from GOT.model.vision_encoder.sam import build_sam_vit_b
from GOT.model.vision_encoder.swin_transformer import build_swin_transformer


def train():
    """LLaMA-based (``GOTLlamaForCausalLM``) training entry point.

    Steps:
      1. Parse ModelArguments / DataArguments / TrainingArguments from the CLI.
      2. Load the model and tokenizer and fix up special tokens.
      3. Initialise the vision modules, then replace ``vision_tower`` with a CLIP
         vision model and ``vision_tower_high`` with ``build_sam_vit_b``.
      4. Choose the trainable parameters and train with ``GOTTrainer``.

    Training resumes when ``output_dir`` already contains a ``checkpoint-*``
    directory. The final weights are written with ``trainer._safe_save``.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = GOTLlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # NOTE(review): the tokenizer comes from a hard-coded absolute path rather than
    # model_args.model_name_or_path, so this only works where that path exists.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        '/data/hypertext/xpkong/newcode/checkpoints/kly-vary-1025-cc595-pretrain/',
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    # Special-token setup. For the "v0" conversation format (or the decapoda LLaMA
    # checkpoint), add a pad token if missing and, for LLaMA paths, register the
    # eos/bos/unk tokens. Otherwise reuse unk as the pad token.
    if data_args.conversation_version == "v0" or "models--decapoda-research--llama-7b-hf" in model_args.model_name_or_path:
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
                tokenizer=tokenizer,
                model=model,
            )
        if "llama" in model_args.model_name_or_path:
            tokenizer.add_special_tokens({
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })
    else:
        tokenizer.pad_token = tokenizer.unk_token

    # Training dtype: bf16 takes precedence over fp16 when both flags are set.
    dtype = torch.float32
    if training_args.fp16:
        dtype = torch.float16
    if training_args.bf16:
        dtype = torch.bfloat16

    vision_tower_dict = model.get_model().initialize_vision_modules(
        vision_tower=model_args.vision_tower,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        freeze_vision_tower=model_args.freeze_vision_tower,
        use_im_start_end=model_args.use_im_start_end,
        vision_select_layer=model_args.vision_select_layer,
        dtype=dtype,
        device=training_args.device
    )

    model.initialize_vision_tokenizer(
        tokenizer=tokenizer,
        freeze_lm_model=model_args.freeze_lm_model,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        device=training_args.device,
    )

    # Both towers are replaced after initialize_vision_modules. The weights come
    # from hard-coded absolute paths.
    # NOTE(review): these paths are machine-specific -- confirm / parameterise.
    model.get_model().vision_tower = transformers.CLIPVisionModel.from_pretrained(
        '/data/public/ucaswei/pretrain/vit-large-patch14')
    model.get_model().vision_tower_high = build_sam_vit_b(checkpoint='/data/hypertext/xpkong/newcode/checkpoints/kly-sam-opt-all-1023-new/pytorch_model.bin')

    model.to(dtype=dtype, device=training_args.device)

    # The image token length is hard-coded to 256 instead of being read from
    # vision_tower_dict['image_token_len'].
    data_args.image_token_len = 256
    data_args.image_processor = vision_tower_dict['image_processor']
    data_args.image_processor_high = vision_tower_dict['image_processor_high']
    data_args.use_im_start_end = model_args.use_im_start_end

    # mixed relation, to be fixed
    if model_args.freeze_lm_model:
        # Freeze everything, then re-enable the projector and the input embeddings.
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True
        for p in model.get_input_embeddings().parameters():
            p.requires_grad = True

        if not model_args.freeze_vision_tower:
            # Undo the blanket freeze above for the CLIP vision tower.
            model.get_model().vision_tower.requires_grad_(True)

    if model_args.freeze_vision_tower:
        model.get_model().vision_tower.requires_grad_(False)

    # Reported in units of 2**20 parameters.
    params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
    print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")

    data_module = make_supervised_data_module(
        interleave=training_args.interleave,
        with_box=training_args.with_box,
        tokenizer=tokenizer,
        data_args=data_args
    )

    trainer = GOTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module)

    # Resume automatically if a previous run left checkpoint-* directories behind.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    trainer._safe_save(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_GOT.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import logging import pathlib import torch # torch.set_num_threads(1) import transformers # from GOT.train.trainer import GOTTrainer # from GOT.train.trainer_vit_llrd import GOTTrainer from GOT.train.trainer_vit_fixlr import GOTTrainer from GOT.model import * from GOT.data import make_supervised_data_module from GOT.utils.arguments import * from GOT.utils.constants import * from GOT.utils.utils import smart_tokenizer_and_embedding_resize from GOT.model.vision_encoder.vary_b import build_vary_vit_b import os # os.environ['NCCL_IB_DISABLE'] = '1' os.environ['NCCL_DEBUG'] = 'INFO' os.environ['OSS_ENDPOINT'] = "http://oss.i.shaipower.com" def train(): parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,) model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, use_safetensors=True) smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token='<|endoftext|>'), tokenizer=tokenizer, model=model, ) dtype = torch.float32 if training_args.fp16: dtype = torch.float16 if training_args.bf16: dtype = torch.bfloat16 vision_tower_dict = model.get_model().initialize_vision_modules( vision_tower=model_args.vision_tower, pretrained_stage1_model=model_args.pretrained_stage1_model, freeze_vision_tower=model_args.freeze_vision_tower, use_im_start_end=model_args.use_im_start_end, vision_select_layer=model_args.vision_select_layer, dtype=dtype, device=training_args.device ) model.initialize_vision_tokenizer( tokenizer=tokenizer, freeze_lm_model=model_args.freeze_lm_model, pretrained_stage1_model=model_args.pretrained_stage1_model, device=training_args.device, ) model.to(dtype=dtype, device=training_args.device) # 'image_processor_high # data_args.image_token_len = 
vision_tower_dict['image_token_len'] data_args.image_token_len = 256 data_args.image_processor = vision_tower_dict['image_processor'] data_args.image_processor_high = vision_tower_dict['image_processor_high'] data_args.use_im_start_end = model_args.use_im_start_end # mixed relation, to be fixed if model_args.freeze_lm_model: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True for p in model.get_model().mm_projector_vary.parameters(): p.requires_grad = True for p in model.get_input_embeddings().parameters(): p.requires_grad = True params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad] print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] # if len(params_no_grad) > 0: # if training_args.fsdp is not None and len(training_args.fsdp) > 0: # if len(params_no_grad) < 10: # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) # else: # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10]))) # print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.") # print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. 
See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP # def patch_FSDP_use_orig_params(func): # def wrap_func(*args, **kwargs): # use_orig_params = kwargs.pop('use_orig_params', True) # return func(*args, **kwargs, use_orig_params=use_orig_params) # return wrap_func # FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) data_module = make_supervised_data_module( interleave=training_args.interleave, with_box=training_args.with_box, tokenizer=tokenizer, data_args=data_args ) trainer = GOTTrainer( model=model, tokenizer=tokenizer, args=training_args, **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() trainer._safe_save(output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: GOT-OCR-2.0-master/GOT/train/train_flash_attn.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. from GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() from GOT.train.train import train if __name__ == "__main__": train() ================================================ FILE: GOT-OCR-2.0-master/GOT/train/train_lora.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. 
Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import pathlib import torch import transformers # from GOT.train.trainer import GOTTrainer # from GOT.train.trainer_vit_llrd import GOTTrainer from GOT.train.trainer_vit_fixlr import GOTTrainer from GOT.model import GOTLlamaForCausalLM from GOT.data import make_supervised_data_module from GOT.utils.arguments import * from GOT.utils.constants import * from GOT.utils.utils import * # def find_all_linear_names(model): # cls = torch.nn.Linear # lora_module_names = set() # for name, module in model.named_modules(): # if isinstance(module, cls): # names = name.split('.') # lora_module_names.add(names[0] if len(names) == 1 else names[-1]) # if 'lm_head' in lora_module_names: # needed for 16-bit # lora_module_names.remove('lm_head') # return list(lora_module_names) def train(): parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() # model = GOTLlamaForCausalLM.from_pretrained( # model_args.model_name_or_path, # cache_dir=training_args.cache_dir, # ) # tokenizer = transformers.AutoTokenizer.from_pretrained( # model_args.model_name_or_path, # cache_dir=training_args.cache_dir, # model_max_length=training_args.model_max_length, # padding_side="right", # use_fast=False, # ) tokenizer = 
transformers.AutoTokenizer.from_pretrained("/data/public/ucaswei/cache/Qwen/qwen-chat/", trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,) # # model = AutoModelForCausalLM.from_pretrained("/data/public/ucaswei/cache/Qwen/qwen/", device_map="cuda", trust_remote_code=True).eval() model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda') smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), tokenizer=tokenizer, model=model, ) # if data_args.conversation_version == "v0" or "models--decapoda-research--llama-7b-hf" in model_args.model_name_or_path: # if tokenizer.pad_token is None: # smart_tokenizer_and_embedding_resize( # special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), # tokenizer=tokenizer, # model=model, # ) # if "llama" in model_args.model_name_or_path: # tokenizer.add_special_tokens({ # "eos_token": DEFAULT_EOS_TOKEN, # "bos_token": DEFAULT_BOS_TOKEN, # "unk_token": DEFAULT_UNK_TOKEN, # }) # else: # tokenizer.pad_token = tokenizer.unk_token dtype = torch.float32 if training_args.fp16: dtype = torch.float16 if training_args.bf16: dtype = torch.bfloat16 if training_args.lora_enable: from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=training_args.lora_r, lora_alpha=training_args.lora_alpha, target_modules=find_all_linear_names(model), lora_dropout=training_args.lora_dropout, bias=training_args.lora_bias, task_type="CAUSAL_LM", ) logging.warning("Adding LoRA adapters...") model = get_peft_model(model, lora_config) vision_tower_dict = model.get_model().initialize_vision_modules( vision_tower=model_args.vision_tower, pretrained_stage1_model=model_args.pretrained_stage1_model, freeze_vision_tower=model_args.freeze_vision_tower, use_im_start_end=model_args.use_im_start_end, vision_select_layer=model_args.vision_select_layer, dtype=dtype, device=training_args.device ) model.initialize_vision_tokenizer( 
tokenizer=tokenizer, freeze_lm_model=model_args.freeze_lm_model, pretrained_stage1_model=model_args.pretrained_stage1_model, device=training_args.device, ) model.get_model().vision_tower = create_clip_vit_g(448) model.get_model().mm_projector = create_perciever() model.to(dtype=dtype, device=training_args.device) data_args.image_token_len = vision_tower_dict['image_token_len'] data_args.image_processor = vision_tower_dict['image_processor'] data_args.image_processor_high = vision_tower_dict['image_processor_high'] data_args.use_im_start_end = model_args.use_im_start_end # mixed relation, to be fixed if model_args.freeze_lm_model: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): p.requires_grad = True for p in model.get_input_embeddings().parameters(): p.requires_grad = True for p in model.get_model().conv_final.parameters(): p.requires_grad = True for p in model.get_model().vision_encoder.parameters(): p.requires_grad = True if not model_args.freeze_vision_tower: model.get_model().vision_tower.requires_grad_(True) # for i in range(20): # model.get_model().vision_tower.vision_model.encoder.layers[i].requires_grad_(False) model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False) # model.get_model().vision_tower.vision_model.embeddings.requires_grad_(False) # model.get_model().vision_tower.vision_model.pre_layrnorm.requires_grad_(False) model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False) for n, p in model.named_parameters(): print(n, p.requires_grad) params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad] print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] # if len(params_no_grad) > 0: # if training_args.fsdp is not None and len(training_args.fsdp) > 0: # if len(params_no_grad) < 10: # print('[WARNING] Attempting to use FSDP while {} 
parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) # else: # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10]))) # print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.") # print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP # def patch_FSDP_use_orig_params(func): # def wrap_func(*args, **kwargs): # use_orig_params = kwargs.pop('use_orig_params', True) # return func(*args, **kwargs, use_orig_params=use_orig_params) # return wrap_func # FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) data_module = make_supervised_data_module( interleave=training_args.interleave, with_box=training_args.with_box, tokenizer=tokenizer, data_args=data_args ) trainer = GOTTrainer( model=model, tokenizer=tokenizer, args=training_args, **data_module) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() if training_args.lora_enable: state_dict = get_peft_state_maybe_zero_3( model.named_parameters(), training_args.lora_bias ) non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( model.named_parameters() ) if training_args.local_rank == 0 or training_args.local_rank == -1: model.config.save_pretrained(training_args.output_dir) model.save_pretrained(training_args.output_dir, state_dict=state_dict) torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) else: trainer._safe_save(output_dir=training_args.output_dir) if __name__ == "__main__": train() ================================================ FILE: 
GOT-OCR-2.0-master/GOT/train/train_lora_flash_attn.py ================================================ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. from GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() # from GOT.train.train import train from GOT.train.train_lora import train if __name__ == "__main__": train() ================================================ FILE: GOT-OCR-2.0-master/GOT/train/trainer.py ================================================ import os import torch import torch.nn as nn from transformers import Trainer from typing import Dict, Optional, Sequence def unwrap_model(model: nn.Module) -> nn.Module: """ Recursively unwraps a model from potential containers (as used in distributed training). Args: model (`torch.nn.Module`): The model to unwrap. 
""" # since there could be multiple levels of wrapping, unwrap recursively if hasattr(model, "module"): return unwrap_model(model.module) else: return model class GOTTrainer(Trainer): def _safe_save(self, output_dir: str): """Collects the state dict and dump to disk.""" if self.deepspeed: torch.cuda.synchronize() self.save_model(output_dir) return state_dict = self.model.state_dict() if self.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict self._save(output_dir, state_dict=cpu_state_dict) # noqa def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): # Save the model _state_dict = state_dict if _state_dict is None: # Only save the model itself if we are using distributed training model_to_save = unwrap_model(self.model) _state_dict = model_to_save.state_dict() weight_to_save = {} keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] for k, v in _state_dict.items(): if any(key_match in k for key_match in keys_to_match): weight_to_save[k] = v current_folder = output_dir.split('/')[-1] parent_folder = os.path.dirname(output_dir) if current_folder.startswith('checkpoint-'): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) else: torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) super(GOTTrainer, self)._save(output_dir, state_dict) ================================================ FILE: GOT-OCR-2.0-master/GOT/train/trainer_llm_llrd.py ================================================ import os import torch import torch.nn as nn import time import functools import re from transformers import Trainer from transformers.trainer_pt_utils import ( get_module_class_from_name, get_parameter_names, ) from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from 
transformers.utils import ( is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_neuroncore_available, ) from transformers.trainer_utils import ( FSDPOption, ShardedDDPOption, ) from transformers.training_args import ParallelMode from transformers.modeling_utils import PreTrainedModel, unwrap_model from typing import Dict, Optional, Sequence def lr_scale_func(key): if "embed_tokens.weight" in key: return 0 if "mm_projector" in key: return 0.01 # return 1 elif "vision_tower" in key: return 0.01 # return 1 elif "norm.weight" in key or "lm_head.weight" in key: return 1 else: in_pp_layer = int(re.findall(f"layers\.(\d+)\.", key)[0]) decay = 0.86 ** (32 - in_pp_layer - 1) return decay def get_param_groups(model, no_weight_decay_cond, scale_lr_cond): """creates param groups based on weight decay condition (regularized vs non regularized) and learning rate scale condition (args.lr vs lr_mult * args.lr) scale_lr_cond is used during finetuning where head of the network requires a scaled version of the base learning rate. 
""" wd_no_scale_lr = [] wd_scale_lr = {} no_wd_no_scale_lr = [] no_wd_scale_lr = {} for name, param in model.named_parameters(): if not param.requires_grad: continue if no_weight_decay_cond is not None: no_wd = no_weight_decay_cond(name, param) else: # do not regularize biases nor Norm parameters no_wd = name.endswith(".bias") or len(param.shape) == 1 if scale_lr_cond is not None: lr_mult = scale_lr_cond(name) print(name, lr_mult) scale_lr = lr_mult != 1 else: scale_lr = False if not no_wd and not scale_lr: wd_no_scale_lr.append(param) elif not no_wd and scale_lr: if lr_mult not in wd_scale_lr: wd_scale_lr[lr_mult] = [param] else: wd_scale_lr[lr_mult].append(param) elif no_wd and not scale_lr: no_wd_no_scale_lr.append(param) else: if lr_mult not in no_wd_scale_lr: no_wd_scale_lr[lr_mult] = [param] else: no_wd_scale_lr[lr_mult].append(param) param_groups = [] if len(wd_no_scale_lr): param_groups.append({"params": wd_no_scale_lr, "wd_mult": 1.0, "lr_mult": 1.0}) if len(wd_scale_lr): for lr_mult, params in wd_scale_lr.items(): param_groups.append({"params": params, "wd_mult": 1.0, "lr_mult": lr_mult}) if len(no_wd_no_scale_lr): param_groups.append( {"params": no_wd_no_scale_lr, "wd_mult": 0.0, "lr_mult": 1.0} ) if len(no_wd_scale_lr): for lr_mult, params in no_wd_scale_lr.items(): param_groups.append({"params": params, "wd_mult": 0.0, "lr_mult": lr_mult}) return param_groups def unwrap_model(model: nn.Module) -> nn.Module: """ Recursively unwraps a model from potential containers (as used in distributed training). Args: model (`torch.nn.Module`): The model to unwrap. 
""" # since there could be multiple levels of wrapping, unwrap recursively if hasattr(model, "module"): return unwrap_model(model.module) else: return model class GOTTrainer(Trainer): def _safe_save(self, output_dir: str): """Collects the state dict and dump to disk.""" state_dict = self.model.state_dict() if self.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict self._save(output_dir, state_dict=cpu_state_dict) # noqa def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): # Save the model _state_dict = state_dict if _state_dict is None: # Only save the model itself if we are using distributed training model_to_save = unwrap_model(self.model) _state_dict = model_to_save.state_dict() weight_to_save = {} keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] for k, v in _state_dict.items(): if any(key_match in k for key_match in keys_to_match): weight_to_save[k] = v current_folder = output_dir.split('/')[-1] parent_folder = os.path.dirname(output_dir) if current_folder.startswith('checkpoint-'): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) else: torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) super(GOTTrainer, self)._save(output_dir, state_dict) def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
""" opt_model = self.model if self.optimizer is None: # decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) # decay_parameters = [name for name in decay_parameters if "bias" not in name] # optimizer_grouped_parameters = [ # { # "params": [ # p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) # ], # "weight_decay": self.args.weight_decay, # }, # { # "params": [ # p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) # ], # "weight_decay": 0.0, # }, # ] optimizer_grouped_parameters = get_param_groups(opt_model, None, lr_scale_func) optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) return self.optimizer def _wrap_model(self, model, training=True, dataloader=None): if self.args.use_ipex: dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 model = self.ipex_optimize_model(model, training, dtype=dtype) if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp # Wrapping the base model twice in a DistributedModel will raise an error. 
if isinstance(self.model_wrapped, smp.model.DistributedModel): return self.model_wrapped return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) # already initialized its own DDP and AMP if self.deepspeed: return self.deepspeed # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again if unwrap_model(model) is not model: return model # Mixed precision training with apex (torch < 1.6) if self.use_apex and training: from apex import amp model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) # Multi-gpu training (should be after apex fp16 initialization) if self.args.n_gpu > 1: model = nn.DataParallel(model) if self.args.jit_mode_eval: start_time = time.time() model = self.torch_jit_model_eval(model, dataloader, training) self.jit_compilation_time = round(time.time() - start_time, 4) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. if not training: return model # Distributed training (should be after apex fp16 initialization) if self.sharded_ddp is not None: from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP from fairscale.nn.wrap import auto_wrap # Sharded DDP! if self.sharded_ddp == ShardedDDPOption.SIMPLE: model = ShardedDDP(model, self.optimizer) else: mixed_precision = self.args.fp16 or self.args.bf16 cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 # XXX: Breaking the self.model convention but I see no way around it for now. 
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: model = auto_wrap(model) self.model = model = FullyShardedDDP( model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload, ).to(self.args.device) # Distributed training using PyTorch FSDP elif self.fsdp is not None: if not self.args.fsdp_config["xla"]: # PyTorch FSDP! from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy if FSDPOption.OFFLOAD in self.args.fsdp: cpu_offload = CPUOffload(offload_params=True) else: cpu_offload = CPUOffload(offload_params=False) auto_wrap_policy = None if FSDPOption.AUTO_WRAP in self.args.fsdp: if self.args.fsdp_config["fsdp_min_num_params"] > 0: auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] ) elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: transformer_cls_to_wrap = set() for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: transformer_cls = get_module_class_from_name(model, layer_class) if transformer_cls is None: raise Exception("Could not find the transformer layer class to wrap in the model.") else: transformer_cls_to_wrap.add(transformer_cls) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, # Transformer layer class to wrap transformer_layer_cls=transformer_cls_to_wrap, ) mixed_precision_policy = None dtype = None if self.args.fp16: dtype = torch.float16 elif self.args.bf16: dtype = torch.bfloat16 if dtype is not None: mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) if type(model) != FSDP: # XXX: Breaking the self.model convention but I see no way around it for now. 
self.model = model = FSDP( model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy, mixed_precision=mixed_precision_policy, device_id=self.args.device, backward_prefetch=self.backward_prefetch, forward_prefetch=self.forword_prefetch, limit_all_gathers=self.limit_all_gathers, use_orig_params=True, ) else: try: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP from torch_xla.distributed.fsdp import checkpoint_module from torch_xla.distributed.fsdp.wrap import ( size_based_auto_wrap_policy, transformer_auto_wrap_policy, ) except ImportError: raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") auto_wrap_policy = None auto_wrapper_callable = None if self.args.fsdp_config["fsdp_min_num_params"] > 0: auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] ) elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: transformer_cls_to_wrap = set() for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: transformer_cls = get_module_class_from_name(model, layer_class) if transformer_cls is None: raise Exception("Could not find the transformer layer class to wrap in the model.") else: transformer_cls_to_wrap.add(transformer_cls) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, # Transformer layer class to wrap transformer_layer_cls=transformer_cls_to_wrap, ) fsdp_kwargs = self.args.xla_fsdp_config if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: # Apply gradient checkpointing to auto-wrapped sub-modules if specified def auto_wrapper_callable(m, *args, **kwargs): return FSDP(checkpoint_module(m), *args, **kwargs) # Wrap the base model with an outer FSDP wrapper self.model = model = FSDP( model, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable, **fsdp_kwargs, ) import torch_xla.core.xla_model as xm # 
Patch `xm.optimizer_step` should not reduce gradients in this case, # as FSDP does not need gradient reduction over sharded parameters. def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): loss = optimizer.step(**optimizer_args) if barrier: xm.mark_step() return loss xm.optimizer_step = patched_optimizer_step elif is_sagemaker_dp_enabled(): model = nn.parallel.DistributedDataParallel( model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] ) elif self.args.local_rank != -1: kwargs = {} if self.args.ddp_find_unused_parameters is not None: kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters elif isinstance(model, PreTrainedModel): # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing else: kwargs["find_unused_parameters"] = True if self.args.ddp_bucket_cap_mb is not None: kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb if is_torch_neuroncore_available(): return model model = nn.parallel.DistributedDataParallel( model, device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None, output_device=self.args.local_rank if self.args._n_gpu != 0 else None, **kwargs, ) # torch.compile() needs to be called after wrapping the model with FSDP or DDP # to ensure that it accounts for the graph breaks required by those wrappers if self.args.torch_compile: model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) return model ================================================ FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_fixlr.py ================================================ import os import torch import torch.nn as nn from transformers import Trainer from transformers.trainer_pt_utils import get_parameter_names from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from typing import Dict, Optional, Sequence def 
unwrap_model(model: nn.Module) -> nn.Module: """ Recursively unwraps a model from potential containers (as used in distributed training). Args: model (`torch.nn.Module`): The model to unwrap. """ # since there could be multiple levels of wrapping, unwrap recursively if hasattr(model, "module"): return unwrap_model(model.module) else: return model class GOTTrainer(Trainer): def _safe_save(self, output_dir: str): """Collects the state dict and dump to disk.""" state_dict = self.model.state_dict() if self.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict self._save(output_dir, state_dict=cpu_state_dict) # noqa def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): # Save the model _state_dict = state_dict if _state_dict is None: # Only save the model itself if we are using distributed training model_to_save = unwrap_model(self.model) _state_dict = model_to_save.state_dict() weight_to_save = {} keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] for k, v in _state_dict.items(): if any(key_match in k for key_match in keys_to_match): weight_to_save[k] = v current_folder = output_dir.split('/')[-1] parent_folder = os.path.dirname(output_dir) if current_folder.startswith('checkpoint-'): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) else: torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) super(GOTTrainer, self)._save(output_dir, state_dict) def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
""" opt_model = self.model if self.optimizer is None: decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) decay_parameters = [name for name in decay_parameters if "bias" not in name] optimizer_grouped_parameters = [ { "params": [ p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n in decay_parameters and p.requires_grad ], "weight_decay": self.args.weight_decay, "lr": self.args.learning_rate, }, { "params": [ p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n not in decay_parameters and p.requires_grad], "weight_decay": 0.0, "lr": self.args.learning_rate, }, { "params": [ p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n in decay_parameters and p.requires_grad], "weight_decay": self.args.weight_decay, "lr": self.args.learning_rate, }, { "params": [ p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n not in decay_parameters and p.requires_grad ], "weight_decay": 0.0, "lr": self.args.learning_rate, }, ] for idx, group in enumerate(optimizer_grouped_parameters): print(idx, len(group['params']), group['lr']) optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) return self.optimizer ================================================ FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_llrd.py ================================================ import os import torch import torch.nn as nn import time import functools import re from transformers import Trainer from transformers.trainer_pt_utils import ( get_module_class_from_name, get_parameter_names, ) from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_neuroncore_available, ) from transformers.trainer_utils import ( FSDPOption, ShardedDDPOption, ) from transformers.training_args import ParallelMode from 
transformers.modeling_utils import PreTrainedModel, unwrap_model from typing import Dict, Optional, Sequence def lr_scale_func(key): if "vision_model.encoder.layers" in key: in_pp_layer = int(re.findall(f"layers\.(\d+)\.", key)[0]) # decay = 0.81 ** (23 - in_pp_layer - 1) decay = 0.81 ** (23 - in_pp_layer - 1) * 0.01 # decay = 0.66 ** (23 - in_pp_layer - 1) return decay # return 0.01 elif "vision_model" in key: # return 0.01 return 0.0001 return 1 def get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr, wd): """creates param groups based on weight decay condition (regularized vs non regularized) and learning rate scale condition (args.lr vs lr_mult * args.lr) scale_lr_cond is used during finetuning where head of the network requires a scaled version of the base learning rate. """ wd_no_scale_lr = [] wd_scale_lr = {} no_wd_no_scale_lr = [] no_wd_scale_lr = {} for name, param in model.named_parameters(): if not param.requires_grad: continue if no_weight_decay_cond is not None: no_wd = no_weight_decay_cond(name, param) else: # do not regularize biases nor Norm parameters no_wd = name.endswith(".bias") or len(param.shape) == 1 if scale_lr_cond is not None: lr_mult = scale_lr_cond(name) print(name, lr_mult) scale_lr = lr_mult != 1 else: scale_lr = False if not no_wd and not scale_lr: wd_no_scale_lr.append(param) elif not no_wd and scale_lr: if lr_mult not in wd_scale_lr: wd_scale_lr[lr_mult] = [param] else: wd_scale_lr[lr_mult].append(param) elif no_wd and not scale_lr: no_wd_no_scale_lr.append(param) else: if lr_mult not in no_wd_scale_lr: no_wd_scale_lr[lr_mult] = [param] else: no_wd_scale_lr[lr_mult].append(param) param_groups = [] if len(wd_no_scale_lr): param_groups.append({"params": wd_no_scale_lr, "weight_decay": wd, "lr": lr}) if len(wd_scale_lr): for lr_mult, params in wd_scale_lr.items(): param_groups.append({"params": params, "weight_decay": wd, "lr": lr * lr_mult}) if len(no_wd_no_scale_lr): param_groups.append( {"params": no_wd_no_scale_lr, 
"weight_decay": 0.0, "lr": lr} ) if len(no_wd_scale_lr): for lr_mult, params in no_wd_scale_lr.items(): param_groups.append({"params": params, "weight_decay": 0.0, "lr": lr * lr_mult}) return param_groups def unwrap_model(model: nn.Module) -> nn.Module: """ Recursively unwraps a model from potential containers (as used in distributed training). Args: model (`torch.nn.Module`): The model to unwrap. """ # since there could be multiple levels of wrapping, unwrap recursively if hasattr(model, "module"): return unwrap_model(model.module) else: return model class GOTTrainer(Trainer): def _safe_save(self, output_dir: str): """Collects the state dict and dump to disk.""" state_dict = self.model.state_dict() if self.args.should_save: cpu_state_dict = { key: value.cpu() for key, value in state_dict.items() } del state_dict self._save(output_dir, state_dict=cpu_state_dict) # noqa def _save(self, output_dir: Optional[str] = None, state_dict=None): if getattr(self.args, 'tune_mm_mlp_adapter', False): # Save the model _state_dict = state_dict if _state_dict is None: # Only save the model itself if we are using distributed training model_to_save = unwrap_model(self.model) _state_dict = model_to_save.state_dict() weight_to_save = {} keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] for k, v in _state_dict.items(): if any(key_match in k for key_match in keys_to_match): weight_to_save[k] = v current_folder = output_dir.split('/')[-1] parent_folder = os.path.dirname(output_dir) if current_folder.startswith('checkpoint-'): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) else: torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) super(GOTTrainer, self)._save(output_dir, state_dict) def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`,
        or subclass and override this method in a subclass.

        Here the param groups come from get_param_groups(), which splits parameters by
        weight decay and by the per-parameter LR multiplier returned by lr_scale_func.
        """
        opt_model = self.model

        if self.optimizer is None:
            # Previous (HF-default style) decay/no-decay grouping, kept for reference:
            # decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            # decay_parameters = [name for name in decay_parameters if "bias" not in name]
            # optimizer_grouped_parameters = [
            #     {
            #         "params": [
            #             p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
            #         ],
            #         "weight_decay": self.args.weight_decay,
            #     },
            #     {
            #         "params": [
            #             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
            #         ],
            #         "weight_decay": 0.0,
            #     },
            # ]
            # No custom no-decay condition (None): biases and 1-D params are not decayed.
            optimizer_grouped_parameters = get_param_groups(opt_model, None, lr_scale_func, self.args.learning_rate, self.args.weight_decay)
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer

    def _wrap_model(self, model, training=True, dataloader=None):
        """Wrap ``model`` for the configured execution backend and return the wrapper.

        Branches, in order: IPEX optimization, SageMaker model parallel, DeepSpeed,
        "already wrapped" short-circuit, apex AMP, DataParallel (n_gpu > 1), JIT eval,
        then (training only) fairscale sharded DDP / PyTorch FSDP / XLA FSDP /
        SageMaker data parallel / DDP, and finally optional ``torch.compile``.

        NOTE(review): this appears to mirror ``transformers.Trainer._wrap_model`` from
        a release that still exposes ``ShardedDDPOption`` / ``self.sharded_ddp`` and
        the ``self.forword_prefetch`` (sic) attribute -- keep in sync with the pinned
        transformers version.

        Args:
            model: the module to wrap.
            training: when False, distributed wrappers are skipped (see early return).
            dataloader: only forwarded to ``torch_jit_model_eval``.

        Returns:
            The (possibly) wrapped model, or ``self.deepspeed`` / ``self.model_wrapped``
            in the DeepSpeed / SageMaker-MP cases.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            import smdistributed.modelparallel.torch as smp
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        # (uses this module's unwrap_model, which follows `.module` attributes)
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            from apex import amp
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
            from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
            from fairscale.nn.wrap import auto_wrap

            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None

                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    # Size-based wrapping takes precedence over layer-class wrapping.
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )
                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        backward_prefetch=self.backward_prefetch,
                        # NOTE(review): attribute name is spelled "forword"; it must match
                        # the name set by the installed Trainer -- confirm before renaming.
                        forward_prefetch=self.forword_prefetch,
                        limit_all_gathers=self.limit_all_gathers,
                        use_orig_params=True,
                    )
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError:
                    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
                auto_wrap_policy = None
                auto_wrapper_callable = None
                # Unlike the PyTorch branch, auto-wrapping here is not gated on FSDPOption.AUTO_WRAP.
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                import torch_xla.core.xla_model as xm

                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                # (This replaces the function on the xm module itself, i.e. process-wide.)
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                    loss = optimizer.step(**optimizer_args)
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            # Explicit user setting wins; otherwise derive from gradient checkpointing.
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            if is_torch_neuroncore_available():
                return model
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
        # to ensure that it accounts for the graph breaks required by those wrappers
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        return model


================================================
FILE: GOT-OCR-2.0-master/GOT/utils/arguments.py
================================================
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import transformers


# Model-side command-line options (base LM checkpoint, vision tower, freezing flags).
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    use_cache: bool = field(default=False)
    vision_tower:
Optional[str] = field(default="~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/")
    # NOTE(review): the default above is a machine-specific snapshot path containing "~";
    # this dataclass does not expand it -- confirm the consumer does.
    freeze_vision_tower: bool = field(default=False)
    freeze_lm_model: bool = field(default=False)
    pretrained_stage1_model: Optional[str] = field(default=None)  # checkpoint for the mlp &/ vision tower
    vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer
    use_im_start_end: bool = field(default=False)


# Dataset / prompt-formatting options.
@dataclass
class DataArguments:
    # Combination of dataset names; presumably keys of CONVERSATION_DATA -- verify in the dataset code.
    datasets: str = field(default=None, metadata={"help": "combinations of the training data."})
    sep_image_conv_front: bool = False
    image_token_len: int = 256
    image_aspect_ratio: str = 'square'
    # Only 'mpt' is handled by make_supervised_data_module; older alternatives kept below.
    conversation_version: str = 'mpt'
    # conversation_version: str = 'v0'
    # conversation_version: str = 'v1'
    # conversation_version: str = 'nougat'
    # conversation_version: str = 'baichuan'
    # conversation_version: str = 'opt'
    box_limit: int = 0


# Extends HF TrainingArguments with GOT-specific and LoRA options.
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    # Presumably False so extra batch keys (e.g. "images") are not dropped by the Trainer.
    remove_unused_columns: bool = field(default=False)
    force_fsdp: bool = field(default=False)
    interleave: bool = field(default=False)
    with_box: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # LoRA fine-tuning settings.
    lora_enable: bool = False
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"


================================================
FILE: GOT-OCR-2.0-master/GOT/utils/constants.py
================================================
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "log"

# Label value excluded from the loss (padding value used for `labels` in the collator).
IGNORE_INDEX = -100
# DEFAULT_PAD_TOKEN = "[PAD]"

DEFAULT_PAD_TOKEN = "<|endoftext|>"
# NOTE(review): the special-token constants below are empty strings in this copy;
# angle-bracket token text may have been stripped during extraction -- verify
# against the upstream repository before relying on these values.
DEFAULT_EOS_TOKEN = ""
DEFAULT_BOS_TOKEN = ""
DEFAULT_UNK_TOKEN = ""
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_BOX_TOKEN = ""

DEFAULT_IMAGE_PATCH_TOKEN = ''

DEFAULT_IM_START_TOKEN = ''
DEFAULT_IM_END_TOKEN = ''


# Registry of training datasets: name -> image root and annotation file (placeholder paths).
CONVERSATION_DATA = {

    'data_1': {
        'images': '/path/',
        'annotations': '/path/data1.json',
    },

    'data_2': {
        'images': '/path/',
        'annotations': '/path/data2.json',
    },

    'data_3': {
        'images': '/path/',
        'annotations': '/path/data3.json',
    },


}


================================================
FILE: GOT-OCR-2.0-master/GOT/utils/conversation.py
================================================
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style (selects the formatting branch in Conversation.get_prompt)."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


# Disabled templates, kept for reference:
# simple_conv_multimodal = Conversation(
#     system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
#            "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
#            "Follow the instructions carefully and explain your answers in detail.",
#     # system="",
#     roles=("Human", "Assistant"),
#     messages=(
#         ("Human", "Hi!"),
#         ("Assistant", "Hi there! How can I help you today?\n")
#     ),
#     offset=2,
#     sep_style=SeparatorStyle.SINGLE,
#     sep="###",
# )

# conv_mpt = Conversation(
#     system="""<|im_start|>system
#     - You are a helpful language and vision assistant.
# - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. # - You should follow the instructions carefully and explain your answers in detail.""", # roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), # version="mpt", # messages=(), # offset=0, # sep_style=SeparatorStyle.MPT, # sep="<|im_end|>", # ) @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" system: str roles: List[str] messages: List[List[str]] offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "<|im_end|>" sep2: str = None version: str = "Unknown" skip_next: bool = False def get_prompt(self): if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep + '\n' for role, message in self.messages: if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + self.sep else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] for i, (role, message) in enumerate(self.messages): if message: if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] else: ret += role + ":" return ret if self.sep_style == SeparatorStyle.MPT: if self.system: ret = self.system + self.sep else: ret = '' for role, message in self.messages: if message: if type(message) is tuple: message, _, _ = message ret += role + message + self.sep else: ret += role return ret else: raise ValueError(f"Invalid style: {self.sep_style}") # if self.sep_style == SeparatorStyle.MPT: # if self.system: # ret = self.system + self.sep # else: # ret = '' # for role, message in self.messages: # if message: # if type(message) is tuple: # message, _, _ = message # ret += role + message + self.sep # # if 'user' in role: # # ret += role + message + self.sep + "\n" # # else: # # ret += role + message + self.sep # else: # ret += role # return ret # else: # raise 
ValueError(f"Invalid style: {self.sep_style}") def append_message(self, role, message): self.messages.append([role, message]) def get_images(self, return_pil=False): images = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO from PIL import Image msg, image, image_process_mode = msg if image_process_mode == "Pad": def expand2square(pil_img, background_color=(122, 116, 104)): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) # result.paste(pil_img, (0, (width - height) // 2)) result.paste(pil_img) return result else: result = Image.new(pil_img.mode, (height, height), background_color) # result.paste(pil_img, ((height - width) // 2, 0)) result.paste(pil_img) return result image = expand2square(image) elif image_process_mode == "Crop": max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) elif image_process_mode == "Resize": image = image.resize((224, 224)) else: raise ValueError(f"Invalid image_process_mode: {image_process_mode}") if return_pil: images.append(image) else: buffered = BytesIO() image.convert('RGB').save(buffered, format="JPEG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() images.append(img_b64_str) return images def to_gradio_chatbot(self): ret = [] for i, (role, msg) in enumerate(self.messages[self.offset:]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO msg, image, image_process_mode = msg max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 shortest_edge = 
int(min(max_len / aspect_ratio, min_len, min_hw)) longest_edge = int(shortest_edge * aspect_ratio) W, H = image.size if H > W: H, W = longest_edge, shortest_edge else: H, W = shortest_edge, longest_edge image = image.resize((W, H)) # image = image.resize((224, 224)) buffered = BytesIO() image.save(buffered, format="JPEG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() img_str = f'user upload image' msg = msg.replace('', img_str) ret.append([msg, None]) else: ret[-1][-1] = msg return ret def copy(self): return Conversation( system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2) def dict(self): if len(self.get_images()) > 0: return { "system": self.system, "roles": self.roles, "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } return { "system": self.system, "roles": self.roles, "messages": self.messages, "offset": self.offset, "sep": self.sep, "sep2": self.sep2, } conv_v1 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "Give three tips for staying healthy."), ("Assistant", "Sure, here are three tips for staying healthy:\n" "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or " "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening " "activities at least two days per week.\n" "2. 
Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
            "vegetables, whole grains, lean proteins, and healthy fats can help support "
            "your overall health. Try to limit your intake of processed and high-sugar foods, "
            "and aim to drink plenty of water throughout the day.\n"
            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
            "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Alternative few-shot SINGLE-style template (one example exchange, skipped by offset=2).
conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4.
Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# TWO-style template: turns alternate between sep (" ") and sep2.
# NOTE(review): sep2 is an empty string in this copy; an end-of-sequence token may
# have been stripped during extraction -- verify against upstream.
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

# Disabled variant of conv_mpt (some lines were originally in Chinese; translated here):
# conv_mpt = Conversation(
#     system="""<|im_start|>system
#     - You are designed by Megvii, and your name is GOT.
#     - Your name is GOT, you come from Megvii, and you were developed by Megvii.
#     - You are good at analyzing tables: read the contents of charts carefully, then give your answer.""",
#     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
#     version="mpt",
#     messages=(),
#     offset=0,
#     sep_style=SeparatorStyle.MPT,
#     sep="<|im_end|>",
# )

# ChatML-style (MPT) template with a short system prompt; roles embed the <|im_start|> markers.
conv_mpt = Conversation(
    system="""<|im_start|>system You should follow the instructions carefully and explain your answers in detail.""",
    # system = None,
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# MPT template without a system prompt (get_prompt omits the system segment when it is empty).
conv_mpt_eval = Conversation(
    system="",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Text-only MPT template with a generic chatbot system prompt.
conv_mpt_text = Conversation(
    system="""<|im_start|>system - You are a helpful assistant chatbot trained by MosaicML. - You answer questions. - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# TWO-style template with "USER"/"GPT" roles.
# NOTE(review): sep2 is empty here as well -- verify against upstream.
conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

# simple_conv = Conversation(
#     system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture."
#            "You are designed to assist human with a variety of tasks using natural language."
#            "Follow the instructions carefully.",
#     roles=("Human", "Assistant"),
#     messages=(
#         ("Human", "Hi!"),
#         ("Assistant", "Hi there! How can I help you today?\n")
#     ),
#     offset=2,
#     sep_style=SeparatorStyle.SINGLE,
#     sep="###",
# )

# Minimal SINGLE-style template: empty system prompt, no example turns.
simple_conv = Conversation(
    system="",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# SINGLE-style multimodal template with one greeting exchange (skipped by offset=2).
# NOTE(review): the implicitly concatenated system sentences have no separating
# spaces ("...Technology.You are..."); changing the text may affect prompts seen in training.
simple_conv_multimodal = Conversation(
    system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    # system="",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# MPT-style multimodal template (definition continues beyond this section).
simple_conv_mpt_multimodal = Conversation(
    system="""<|im_start|>system - You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology. - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", messages=(), offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", ) simple_conv_legacy = Conversation( system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology." "You are designed to assist human with a variety of tasks using natural language." "Follow the instructions carefully.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!\n\n### Response:"), ("Assistant", "Hi there! How can I help you today?\n") ), offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", ) conv_llava_v1 = Conversation( system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology." "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." "Follow the instructions carefully and explain your answers in detail.", roles=("USER", "ASSISTANT"), version="v1", messages=(), offset=0, sep_style=SeparatorStyle.TWO, sep=" ", sep2="", ) default_conversation = conv_mpt conv_templates = { "default": simple_conv_multimodal, "simple": simple_conv, "simple_legacy": simple_conv_legacy, "multimodal": simple_conv, "mpt_multimodal": simple_conv_mpt_multimodal, "llava_v1": conv_llava_v1, "mpt_eval": conv_mpt_eval, # fastchat "v1": conv_vicuna_v1_1, "bair_v1": conv_bair_v1, "vicuna_v1_1": conv_vicuna_v1_1, "mpt": conv_mpt, "mpt_text": conv_mpt_text, } if __name__ == "__main__": print(default_conversation.get_prompt()) ================================================ FILE: GOT-OCR-2.0-master/GOT/utils/utils.py ================================================ import datetime import logging import logging.handlers import os import sys import torch import requests from transformers import StoppingCriteria from GOT.utils.constants import LOGDIR server_error_msg = "**NETWORK ERROR DUE 
TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." handler = None def build_logger(logger_name, logger_filename): global handler formatter = logging.Formatter( fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # Set the format of root handlers if not logging.getLogger().handlers: logging.basicConfig(level=logging.INFO) logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers stdout_logger = logging.getLogger("stdout") stdout_logger.setLevel(logging.INFO) sl = StreamToLogger(stdout_logger, logging.INFO) sys.stdout = sl stderr_logger = logging.getLogger("stderr") stderr_logger.setLevel(logging.ERROR) sl = StreamToLogger(stderr_logger, logging.ERROR) sys.stderr = sl # Get logger logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO) # Add a file handler for all loggers if handler is None: os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, logger_filename) handler = logging.handlers.TimedRotatingFileHandler( filename, when='D', utc=True) handler.setFormatter(formatter) for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) return logger class StreamToLogger(object): """ Fake file-like stream object that redirects writes to a logger instance. """ def __init__(self, logger, log_level=logging.INFO): self.terminal = sys.stdout self.logger = logger self.log_level = log_level self.linebuf = '' def __getattr__(self, attr): return getattr(self.terminal, attr) def write(self, buf): temp_linebuf = self.linebuf + buf self.linebuf = '' for line in temp_linebuf.splitlines(True): # From the io.TextIOWrapper docs: # On output, if newline is None, any '\n' characters written # are translated to the system default line separator. 
# By default sys.stdout.write() expects '\n' newlines and then # translates them so this is still cross platform. if line[-1] == '\n': self.logger.log(self.log_level, line.rstrip()) else: self.linebuf += line def flush(self): if self.linebuf != '': self.logger.log(self.log_level, self.linebuf.rstrip()) self.linebuf = '' def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. """ import torch setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) def violates_moderation(text): """ Check whether the text violates OpenAI moderation API. """ url = "https://api.openai.com/v1/moderations" headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} text = text.replace("\n", "") data = "{" + '"input": ' + f'"{text}"' + "}" data = data.encode("utf-8") try: ret = requests.post(url, headers=headers, data=data, timeout=5) flagged = ret.json()["results"][0]["flagged"] except requests.exceptions.RequestException as e: flagged = False except KeyError as e: flagged = False return flagged def pretty_print_semaphore(semaphore): if semaphore is None: return "None" return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1] self.tokenizer = tokenizer self.start_len = None self.input_ids = input_ids def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: if self.start_len is None: self.start_len = self.input_ids.shape[1] else: for keyword_id in self.keyword_ids: if output_ids[0, -1] == keyword_id: return True outputs = 
self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] for keyword in self.keywords: if keyword in outputs: return True return False def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model): """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. """ # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) # # num_new_tokens = 1 # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True) # model.resize_token_embeddings(len(tokenizer)) num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param # Borrowed from peft.utils.get_peft_model_state_dict def get_peft_state_maybe_zero_3(named_params, bias): if bias == "none": to_return = {k: t for k, t in named_params if "lora_" in k} elif bias == "all": to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} elif bias == "lora_only": to_return = {} maybe_lora_bias = {} lora_bias_names = set() for 
k, t in named_params: if "lora_" in k: to_return[k] = t bias_name = k.split("lora_")[0] + "bias" lora_bias_names.add(bias_name) elif "bias" in k: maybe_lora_bias[k] = t for k, t in maybe_lora_bias: if bias_name in lora_bias_names: to_return[bias_name] = t else: raise NotImplementedError to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()} return to_return def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): to_return = {k: t for k, t in named_params if "lora_" not in k} if require_grad_only: to_return = {k: t for k, t in to_return.items() if t.requires_grad} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def find_all_linear_names(model): cls = torch.nn.Linear lora_module_names = set() for name, module in model.named_modules(): if isinstance(module, cls) and 'vision_model' not in name and 'mm_projector' not in name and 'vision_encoder' not in name and 'conv_final' not in name and'lm_head' not in name: lora_module_names.add(name) print(lora_module_names) return list(lora_module_names) ================================================ FILE: GOT-OCR-2.0-master/pyproject.toml ================================================ [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "GOT" version = "0.1.0" description = "Towards OCR-2.0." 
readme = "README.md" requires-python = ">=3.8" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] dependencies = [ "markdown2[all]", "numpy", "requests", "sentencepiece", "tokenizers>=0.15.2", "torch", "torchvision", "wandb", "shortuuid", "httpx==0.24.0", "deepspeed==0.12.3", "peft==0.4.0", "albumentations", "opencv-python", "tiktoken==0.6.0", "accelerate==0.28.0", "transformers==4.37.2", "bitsandbytes==0.41.0", "scikit-learn==1.2.2", "sentencepiece==0.1.99", "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", ] [tool.setuptools.packages.find] exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] [tool.wheel] exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] ================================================ FILE: GOT-OCR-2.0-master/pyvenv.cfg ================================================ home = /usr/bin implementation = CPython version_info = 3.8.10.final.0 virtualenv = 20.16.7 include-system-site-packages = true base-prefix = /usr base-exec-prefix = /usr base-executable = /usr/bin/python3 ================================================ FILE: GOT-OCR-2.0-master/render_tools/content-mmd-to-html.html ================================================ Title
================================================ FILE: GOT-OCR-2.0-master/render_tools/tikz.html ================================================ Document ================================================ FILE: GOT-OCR-2.0-master/results/demo.html ================================================ Title
================================================ FILE: GOT-OCR-2.0-master/zero_config/zero2.json ================================================ { "bf16": { "enabled": true }, "train_micro_batch_size_per_gpu": "auto", "zero_optimization": { "stage": 2, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": "auto" } } ================================================ FILE: GOT-OCR-2.0-master/zero_config/zero3.json ================================================ { "fp16": { "enabled": "auto", "loss_scale": 0, "loss_scale_window": 1000, "initial_scale_power": 16, "hysteresis": 2, "min_loss_scale": 1 }, "bf16": { "enabled": "auto" }, "train_micro_batch_size_per_gpu": "auto", "train_batch_size": "auto", "gradient_accumulation_steps": "auto", "zero_optimization": { "stage": 3, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true } } ================================================ FILE: README.md ================================================

General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model

[Haoran Wei*](https://scholar.google.com/citations?user=J4naK0MAAAAJ&hl=en), Chenglong Liu*, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, [Zheng Ge](https://joker316701882.github.io/), Liang Zhao, [Jianjian Sun](https://scholar.google.com/citations?user=MVZrGkYAAAAJ&hl=en), [Yuang Peng](https://yuangpeng.com), Chunrui Han, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ&hl=en)

## Release - [2025/2/1] 🚀🚀🚀 GOT-OCR2.0 has been merged into [Huggingface-transformers](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)/[space](https://huggingface.co/spaces/yonigozlan/GOT-OCR-Transformers). It supports batched inference. Thanks to Huggingface MLE [Yoni](https://github.com/yonigozlan). - [2024/12/24] 🔥🔥🔥 My new work on system-2 perception, [slow-perception](https://github.com/Ucas-HaoranWei/Slow-Perception), has been released. - [2024/12/18] 🚀🚀🚀 GOT-OCR2.0 is supported in [PaddleMIX](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/paddlemix/examples/GOT_OCR_2_0) by the Paddle team. Thanks to the Paddle team! - [2024/12/8] 🔥🔥🔥 Model downloads have exceeded 1M on [Huggingface](https://huggingface.co/stepfun-ai/GOT-OCR2_0). - [2024/12/5] The seventh WeChat [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/Wechat7.jpg). - [2024/11/4] The sixth WeChat [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat6-2.jpg). - [2024/10/24] The previous four WeChat groups are full, so we created a fifth [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat5.png). - [2024/10/11] Too many friends want to join the WeChat group, so we created a fourth [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat4.jpg). - [2024/10/2] [onnx](https://github.com/BaofengZan/GOT-OCRv2-onnx) and [mnn](https://github.com/BaofengZan/mnn-llm-GOT-OCR2.0) versions of GOT-OCR2.0. - [2024/9/29]🔥🔥🔥 The community has implemented the first version of [llama_cpp_inference](https://github.com/1694439208/GOT-OCR-Inference). - [2024/9/24]🔥🔥🔥 Support for [ms-swift](https://github.com/modelscope/ms-swift/issues/2122) quick [Fine-tune](#fine-tune) on your own data. - [2024/9/23]🔥🔥🔥 We release the official [Modelscope demo](https://modelscope.cn/studios/stepfun-ai/GOT_official_online_demo). Many thanks to Modelscope for providing the GPU resources. - [2024/9/19]🔥🔥🔥 GOT-OCR2.0 reaches #1 on Huggingface trending.
- [2024/9/14]🔥🔥🔥 We release the official [demo](https://huggingface.co/spaces/ucaslcl/GOT_online). Many thanks to Huggingface for providing the GPU resources. - [2024/9/13]🔥🔥🔥 We release the [Huggingface](https://huggingface.co/ucaslcl/GOT-OCR2_0) deployment. - [2024/9/03]🔥🔥🔥 We open-source the code, weights, and benchmarks. The paper can be found in this [repo](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/GOT-OCR-2.0-paper.pdf). We have also submitted it to arXiv. - [2024/9/03]🔥🔥🔥 We release the OCR-2.0 model GOT! [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE) [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE) ## Community contributions We encourage everyone to develop GOT applications based on this repo. Thanks to the following contributions: [OpenVINO](https://github.com/can-gaa-hou/GOT-OCR2.0-OpenVINO)~ contributor: [@can-gaa-hou](https://github.com/can-gaa-hou) [GGUF and Llama.cpp inference](https://github.com/MosRat/got.cpp)~ contributor: [@MosRat](https://github.com/MosRat) [vllm reference](https://github.com/liunian-Jay/MU-GOT/blob/master/PDF_parsing/GOT/GOT/model/modeling_GOT_vllm.py) ~ contributor: [@Jay](https://github.com/liunian-Jay) [onnx and mnn support](https://github.com/BaofengZan/GOT-OCRv2-onnx) ~ contributor: [@BaofengZan](https://github.com/BaofengZan) [llama_cpp inference](https://github.com/1694439208/GOT-OCR-Inference) ~ contributor: [@1694439208](https://github.com/1694439208) [Colab of GOT](https://colab.research.google.com/drive/1nmiNciZ5ugQVp4rFbL9ZWpEPd92Y9o7p?usp=sharing) ~ contributor: [@Zizhe Wang](https://github.com/PaperPlaneDeemo) [CPU version of GOT](https://github.com/ElvisClaros/GOT-OCR2.0) ~ contributor: [@ElvisClaros](https://github.com/ElvisClaros) [Online
demo](https://huggingface.co/spaces/Tonic/GOT-OCR) ~ contributor: [@Joseph Pollack](https://huggingface.co/Tonic) [Docker & client demo](https://github.com/QIN2DIM/GOT-OCR2.0) ~ contributor: [@QIN2DIM](https://github.com/QIN2DIM) [GUI of GOT](https://github.com/XJF2332/GOT-OCR-2-GUI) ~ contributor: [@XJF2332](https://github.com/XJF2332) ## Contents - [Install](#install) - [GOT Weights](#got-weights) - [Benchmarks](#benchmarks) - [Demo](#demo) - [Train](#train) - [Fine-tune](#fine-tune) - [Eval](#eval) ***

Towards OCR-2.0 via a Unified End-to-end Model

*** ## Install 0. Our environment is CUDA 11.8 + torch 2.0.1. 1. Clone this repository and navigate to the GOT folder: ```bash git clone https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git cd 'the GOT folder' ``` 2. Install the package: ```Shell conda create -n got python=3.10 -y conda activate got pip install -e . ``` 3. Install Flash-Attention: ``` pip install ninja pip install flash-attn --no-build-isolation ``` ## GOT Weights - [Huggingface](https://huggingface.co/ucaslcl/GOT-OCR2_0) - [Google Drive](https://drive.google.com/drive/folders/1OdDtsJ8bFJYlNUzCQG4hRkUL6V-qBQaN?usp=sharing) - [BaiduYun](https://pan.baidu.com/s/1G4aArpCOt6I_trHv_1SE2g) code: OCR2 ## Benchmarks - [Google Drive](https://drive.google.com/drive/folders/1OdDtsJ8bFJYlNUzCQG4hRkUL6V-qBQaN?usp=sharing) - [BaiduYun](https://pan.baidu.com/s/1G4aArpCOt6I_trHv_1SE2g) code: OCR2 ## Demo 1. Plain-text OCR: ```Shell python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type ocr ``` 2. Formatted-text OCR: ```Shell python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format ``` 3. Fine-grained OCR: ```Shell python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format/ocr --box [x1,y1,x2,y2] ``` ```Shell python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format/ocr --color red/green/blue ``` 4. Multi-crop OCR: ```Shell python3 GOT/demo/run_ocr_2.0_crop.py --model-name /GOT_weights/ --image-file /an/image/file.png ``` 5. **Note**: This feature is not batch inference! It works at the token level. Please read the paper first, then use multi-page OCR correctly (the image path should contain multiple .png files): ```Shell python3 GOT/demo/run_ocr_2.0_crop.py --model-name /GOT_weights/ --image-file /images/path/ --multi-page ``` 6.
Render the formatted OCR results: ```Shell python3 GOT/demo/run_ocr_2.0.py --model-name /GOT_weights/ --image-file /an/image/file.png --type format --render ``` **Note**: The rendered results can be found in /results/demo.html. Open demo.html to view them. ## Train 0. A training sample can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/train_sample.jpg). Note that the '\<image>' in the 'conversations'-'human'-'value' field is necessary! 1. This codebase only supports post-training (stage-2/stage-3) on top of our GOT weights. 2. If you want to train from stage-1 as described in our paper, use this [repo](https://github.com/Ucas-HaoranWei/Vary-tiny-600k). ```Shell deepspeed /GOT-OCR-2.0-master/GOT/train/train_GOT.py \ --deepspeed /GOT-OCR-2.0-master/zero_config/zero2.json --model_name_or_path /GOT_weights/ \ --use_im_start_end True \ --bf16 True \ --gradient_accumulation_steps 2 \ --evaluation_strategy "no" \ --save_strategy "steps" \ --save_steps 200 \ --save_total_limit 1 \ --weight_decay 0. \ --warmup_ratio 0.001 \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --tf32 True \ --model_max_length 8192 \ --gradient_checkpointing True \ --dataloader_num_workers 8 \ --report_to none \ --per_device_train_batch_size 2 \ --num_train_epochs 1 \ --learning_rate 2e-5 \ --datasets pdf-ocr+scence \ --output_dir /your/output/path ``` **Note**: 1. Change the corresponding data information in [constants.py](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/tree/main/GOT-OCR-2.0-master/GOT/utils). 2. Change line 37 in [conversation_dataset_qwen.py](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/tree/main/GOT-OCR-2.0-master/GOT/data) to your data_name.
## Fine-tune Quick fine-tuning with ms-swift: ```Shell git clone https://github.com/modelscope/ms-swift.git cd ms-swift pip install -e .[llm] ``` ```Shell # default:sft LLM & projector, freeze vision encoder CUDA_VISIBLE_DEVICES=0 swift sft\ --model_type got-ocr2 \ --model_id_or_path stepfun-ai/GOT-OCR2_0 \ --sft_type lora \ --dataset latex-ocr-print#5000 # Deepspeed ZeRO2 NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \ --model_type got-ocr2 \ --model_id_or_path stepfun-ai/GOT-OCR2_0 \ --sft_type lora \ --dataset latex-ocr-print#5000 \ --deepspeed default-zero2 ``` **With your data**: ```Shell --dataset train.jsonl --val_dataset val.jsonl (optional) ``` **Data format**: ```Shell {"query": "55555", "response": "66666", "images": ["image_path"]} {"query": "eeeee", "response": "fffff", "history": [], "images": ["image_path1", "image_path2"]} {"query": "EEEEE", "response": "FFFFF", "history": [["query1", "response1"], ["query2", "response2"]]} ``` More details can be found in [ms-swift](https://github.com/modelscope/ms-swift/issues/2122). ## Eval 1. We use the [Fox](https://github.com/ucaslcl/Fox) and [OneChart](https://github.com/LingyvKong/OneChart) benchmarks; other benchmarks can be found via the weights download links. 2. The evaluation code can be found in GOT/eval. 3. You can use evaluate_GOT.py to run the evaluation. If you have 8 GPUs, --num-chunks can be set to 8. ```Shell python3 GOT/eval/evaluate_GOT.py --model-name /GOT_weights/ --gtfile_path xxxx.json --image_path /image/path/ --out_path /data/eval_results/GOT_mathpix_test/ --num-chunks 8 --datatype OCR ``` ## Contact If you are interested in this work or have questions about the code or the paper, please join our [WeChat](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat.jpg) discussion group. **Note**: All six WeChat groups are full; please join [group 7](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/Wechat7.jpg).
Don't hesitate to contact me by email, weihaoran18@mails.ucas.ac.cn, if you have any questions. ## Acknowledgement - [Vary](https://github.com/Ucas-HaoranWei/Vary/): the codebase we built upon! - [Qwen](https://github.com/QwenLM/Qwen): the LLM base model of Vary, which is good at both English and Chinese! ## Citation ```bibtex @article{wei2024general, title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model}, author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others}, journal={arXiv preprint arXiv:2409.01704}, year={2024} } @article{wei2023vary, title={Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models}, author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu}, journal={arXiv preprint arXiv:2312.06109}, year={2023} }