Full Code of Ucas-HaoranWei/GOT-OCR2.0 for AI

main 179ed086ad6b cached
41 files
224.5 KB
56.5k tokens
183 symbols
1 requests
Download .txt
Showing preview only (238K chars total). Download the full file or copy to clipboard to get everything.
Repository: Ucas-HaoranWei/GOT-OCR2.0
Branch: main
Commit: 179ed086ad6b
Files: 41
Total size: 224.5 KB

Directory structure:
gitextract_is3bdbjd/

├── GOT-OCR-2.0-master/
│   ├── GOT/
│   │   ├── __init__.py
│   │   ├── data/
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   └── conversation_dataset_qwen.py
│   │   ├── demo/
│   │   │   ├── process_results.py
│   │   │   ├── run_ocr_2.0.py
│   │   │   └── run_ocr_2.0_crop.py
│   │   ├── eval/
│   │   │   ├── eval_GOT_ocr.py
│   │   │   ├── evaluate_GOT.py
│   │   │   ├── multi_hardware_eval_GOT.py
│   │   │   └── pyevaltools/
│   │   │       ├── __init__.py
│   │   │       ├── eval_ocr.py
│   │   │       ├── eval_ocr_format.py
│   │   │       ├── eval_ocr_scene.py
│   │   │       └── merge_results.py
│   │   ├── model/
│   │   │   ├── GOT_ocr_2_0.py
│   │   │   ├── __init__.py
│   │   │   ├── plug/
│   │   │   │   └── blip_process.py
│   │   │   └── vision_encoder/
│   │   │       ├── __init__.py
│   │   │       └── vary_b.py
│   │   ├── train/
│   │   │   ├── train.py
│   │   │   ├── train_GOT.py
│   │   │   ├── train_flash_attn.py
│   │   │   ├── train_lora.py
│   │   │   ├── train_lora_flash_attn.py
│   │   │   ├── trainer.py
│   │   │   ├── trainer_llm_llrd.py
│   │   │   ├── trainer_vit_fixlr.py
│   │   │   └── trainer_vit_llrd.py
│   │   └── utils/
│   │       ├── arguments.py
│   │       ├── constants.py
│   │       ├── conversation.py
│   │       └── utils.py
│   ├── pyproject.toml
│   ├── pyvenv.cfg
│   ├── render_tools/
│   │   ├── content-mmd-to-html.html
│   │   └── tikz.html
│   ├── results/
│   │   └── demo.html
│   └── zero_config/
│       ├── zero2.json
│       └── zero3.json
└── README.md

================================================
FILE CONTENTS
================================================

================================================
FILE: GOT-OCR-2.0-master/GOT/__init__.py
================================================


================================================
FILE: GOT-OCR-2.0-master/GOT/data/__init__.py
================================================

import torch
import transformers
from dataclasses import dataclass, field

from GOT.utils.constants import *


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Batch collator for supervised fine-tuning samples.

    Every instance handed to the collator must provide ``input_ids``,
    ``labels``, ``image`` and ``image_high``; the two image entries are
    sequences of tensors that get stacked per sample.
    """
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        """Pad the token sequences and bundle the per-sample image tensors.

        Returns:
            dict with ``input_ids`` and ``labels`` (padded, batch first),
            ``attention_mask`` (True where ``input_ids`` is not the pad id)
            and ``images`` (list of ``(image, image_high)`` pairs, one per
            sample).
        """
        token_seqs = [sample["input_ids"] for sample in instances]
        label_seqs = [sample["labels"] for sample in instances]

        low_res = [torch.stack(sample['image']) for sample in instances]
        high_res = [torch.stack(sample['image_high']) for sample in instances]
        paired_images = list(zip(low_res, high_res))

        pad_sequence = torch.nn.utils.rnn.pad_sequence
        pad_id = self.tokenizer.pad_token_id

        # Inputs are padded with the tokenizer pad id, labels with IGNORE_INDEX.
        padded_ids = pad_sequence(token_seqs, batch_first=True, padding_value=pad_id)
        padded_labels = pad_sequence(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

        return dict(
            input_ids=padded_ids,
            labels=padded_labels,
            attention_mask=padded_ids.ne(pad_id),
            images=paired_images,
        )
    

def make_supervised_data_module(interleave, with_box, tokenizer, data_args):
    """Build the training dataset and its collator.

    Args:
        interleave: Not used by this function; kept so existing callers work.
        with_box: Not used by this function; kept so existing callers work.
        tokenizer: Tokenizer handed to both the dataset and the collator.
        data_args: Object exposing ``conversation_version``, ``datasets`` and
            the multimodal settings read below.

    Returns:
        dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator``.

    Raises:
        ValueError: if ``data_args.conversation_version`` is not ``'mpt'``,
            the only version a dataset class is wired up for here.
    """
    if data_args.conversation_version == 'mpt':
        # Imported lazily so the heavy dataset dependencies are only loaded
        # when a dataset is actually built.
        from GOT.data.conversation_dataset_qwen import ConversationDataset
        dataset_cls = ConversationDataset
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on ``dataset_cls``.
        raise ValueError(
            f"Unsupported conversation_version: {data_args.conversation_version!r} "
            f"(only 'mpt' is supported)"
        )

    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        datasets=data_args.datasets,
        multimodal_cfg=dict(
            sep_image_conv_front=data_args.sep_image_conv_front,
            image_token_len=data_args.image_token_len,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=data_args.use_im_start_end,
            image_processor=data_args.image_processor,
            image_processor_high=data_args.image_processor_high,
            box_limit=data_args.box_limit,
        )
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

================================================
FILE: GOT-OCR-2.0-master/GOT/data/base_dataset.py
================================================
import io
import os
import copy
import json
import logging
import torch
import transformers
import boto3
from typing import List, Optional, Tuple, Union, Dict, Sequence
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from GOT.utils.constants import *



class BaseDataset(Dataset):
    """Shared base for the GOT datasets.

    Stores the tokenizer and the multimodal configuration and offers the
    image preprocessing step. Subclasses are expected to populate
    ``self.list_data_dict`` (read by ``__len__``) and to override
    ``__getitem__``.
    """

    def __init__(
        self,
        datasets: str,
        tokenizer: transformers.PreTrainedTokenizer,
        multimodal_cfg: dict
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.multimodal_cfg = multimodal_cfg

        logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")

    def image_processor(self, image):
        """Run the configured ``image_processor_high`` callable on a copy of ``image``.

        The legacy Vary code path, which also ran a first processor with
        'keep' / 'pad' aspect-ratio handling, is no longer executed; only the
        second processor is applied.
        """
        high_res_transform = self.multimodal_cfg['image_processor_high']
        # Work on a copy so the caller's image object is left untouched.
        return high_res_transform(image.copy())

    def __len__(self):
        """Number of samples held in ``self.list_data_dict``."""
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Placeholder; returns ``None`` until a subclass overrides it."""
        return None

================================================
FILE: GOT-OCR-2.0-master/GOT/data/conversation_dataset_qwen.py
================================================

import io
import os
import copy
import json
import logging
import torch
import random

from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from GOT.data.base_dataset import BaseDataset
from GOT.utils.constants import *
from GOT.utils import conversation as conversation_lib
import boto3
import smart_open
from megfile import smart_glob
from natsort import natsorted


class ConversationDataset(BaseDataset):
    """Conversation format dataset stage2 fine-tuning.

    Loads the annotation JSON files of the requested dataset groups, shuffles
    the samples once at construction time and yields per-sample dicts with
    ``input_ids``, ``labels``, ``image`` and ``image_high``.
    """

    def __init__(self, datasets, tokenizer, multimodal_cfg):
        """Load and shuffle every annotation file of the requested groups.

        Args:
            datasets: '+'-separated group names; each must be a key of the
                ``got_data_dict`` table below.
            tokenizer: Tokenizer used to encode the conversations.
            multimodal_cfg: Dict of multimodal settings (see ``BaseDataset``).
        """
        super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # v0 version format conversation
        conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
        logging.warning("Formatting inputs into conversation type: mpt-fixed")
        logging.warning("Loading data...")

        list_data_dict = []
        list_image_path = []

        # TODO add your data  [data1, data2, data3, .....]
        got_data_dict = {
            "pdf-ocr": ["data1", "data2"],
            'scene-ocr': ["data3", "data4"]
            # ......
        }
        for name_all in datasets.split("+"):
            for name in got_data_dict[name_all]:
                dataset = CONVERSATION_DATA[name]

                data_path = dataset['annotations']
                # Context manager so the annotation file handle is closed
                # deterministically instead of being leaked.
                with open(data_path, "r") as annotation_file:
                    data = json.load(annotation_file)

                list_data_dict.extend(data)

                image_path = dataset['images']

                # One image root per sample so both lists can be shuffled together.
                list_image_path.extend([image_path] * len(data))

                logging.warning(f"Data from {data_path} provide {len(data)} conversations.")

        assert len(list_data_dict) == len(list_image_path)
        logging.warning(f"{len(list_data_dict)} conversations in total.")
        a_new_list = list(zip(list_data_dict, list_image_path))
        random.shuffle(a_new_list)
        list_data_dict_new, list_image_path_new = zip(*a_new_list)
        self.list_data_dict = list_data_dict_new
        self.list_image_path = list_image_path_new

        # Hard-coded special-token ids; presumably the ids of the image patch /
        # start / end tokens in the tokenizer vocabulary -- TODO confirm.
        self.im_patch_token = 151859

        self.im_start_token = 151857

        self.im_end_token = 151858

    def multimodal_processor(self, sources, flag_num_patches):
        """Expand every image placeholder into start + patch tokens + end.

        Args:
            sources: List of conversations (each a list of turn dicts with a
                ``value`` entry); modified in place.
            flag_num_patches: Number of images of the sample; the patch-token
                run is ``image_token_len * flag_num_patches`` long.

        Returns:
            The same ``sources`` object.
        """
        for source in sources:
            if self.multimodal_cfg['sep_image_conv_front']:
                assert DEFAULT_IMAGE_TOKEN in source[0]['value']
                source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']

            for sentence in source:
                replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']*flag_num_patches
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                # sentence["value"] = str(sentence["value"]).replace('\qquad', '\quad')
                sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return sources

    def _tokenize_fn(self, strings):
        """Tokenize a list of strings.

        Returns a dict with the token id tensors (also exposed as labels, the
        very same list object) and their non-pad lengths. Not called by the
        other methods of this class.
        """
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ) for text in strings
        ]
        input_ids = labels = [
            tokenized.input_ids[0] for tokenized in tokenized_list
        ]
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def _mask_targets(self, target, tokenized_lens, speakers):
        """Mask the header and the human turns of ``target`` in place.

        The first ``tokenized_lens[0]`` positions are always masked; for each
        following segment spoken by "human" everything except its first two
        tokens is set to ``IGNORE_INDEX``. Not called by the other methods of
        this class.
        """
        # cur_idx = 0
        cur_idx = tokenized_lens[0]
        tokenized_lens = tokenized_lens[1:]
        target[:cur_idx] = IGNORE_INDEX
        for tokenized_len, speaker in zip(tokenized_lens, speakers):
            if speaker.lower() == "human":
                target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
            cur_idx += tokenized_len

    def token_processor(self, sources, image_name):
        """Render conversations with the MPT template, tokenize and mask them.

        Args:
            sources: List of conversations (lists of ``{"from", "value"}`` turns).
            image_name: Only printed when a tokenization mismatch is detected.

        Returns:
            dict with ``input_ids`` and ``labels``; in ``labels`` the
            instruction part of every round and everything after the last
            round is set to ``IGNORE_INDEX``.
        """
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply prompt templates
        conversations = []
        for i, source in enumerate(sources):
            if roles[source[0]["from"]] != conv.roles[0]:
                # Skip the first one if it is not from human
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                # Turns must strictly alternate human / gpt.
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # Tokenize conversations
        input_ids = self.tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).input_ids

        # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
        targets = input_ids.clone()
        assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

        # Mask targets
        sep = conv.sep + conv.roles[1]
        for conversation, target in zip(conversations, targets):
            total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep)
            re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
            for conv_idx in range(3, len(rounds), 2):
                re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt
            cur_len = 0
            target[:cur_len] = IGNORE_INDEX
            for i, rou in enumerate(re_rounds):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep
                # A round is re-tokenized on its own; the separator that the
                # split removed is counted back in.
                round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
                # round_len = len(tokenizer_image_token(rou, self.tokenizer)) + len(tokenizer_image_token(conv.sep, self.tokenizer))
                # instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
                instruction_len = len(self.tokenizer(parts[0]).input_ids)
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                cur_len += round_len
            target[cur_len:] = IGNORE_INDEX

            if cur_len < self.tokenizer.model_max_length:
                if cur_len != total_len:
                    # The per-round lengths do not add up to the real length:
                    # drop the whole sample from the loss.
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )
                    print(image_name)

        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the tokenized conversation and image tensors of sample ``i``.

        Samples whose images cannot be found, opened or processed are
        replaced by sample 0. NOTE: if sample 0 itself is unreadable this
        recurses until ``RecursionError``.

        Text-only samples (a dict without ``'image'``, or a non-dict entry)
        get zero tensors of shape ``(3, 1024, 1024)`` as image placeholders.
        """
        # data = self.list_data_dict[i]
        data = copy.deepcopy(self.list_data_dict[i])

        # Defaults so that text-only samples no longer hit unbound locals
        # (previously ``image_path + image_file`` raised UnboundLocalError).
        image_name = ''
        image = None
        image_list = []
        image_high_list = []

        if isinstance(data, dict):
            flag_num_patches = 1
            if 'image' in data:
                image_path = self.list_image_path[i]
                image_file = data['image']
                image_name = image_path + image_file

                # multi-crop or multi page, only support .png files
                known_exts = ('.jpg', '.png', '.jpeg')
                is_patch_dir = not any(
                    ext in image_file or ext in image_path for ext in known_exts
                )
                if is_patch_dir:
                    # startswith() instead of image_file[0] so an empty name
                    # does not raise IndexError.
                    if image_file.startswith('/'):
                        # Drop the last character of the root (assumed to be
                        # a trailing '/') to avoid a doubled separator.
                        patch_dir = image_path[:-1] + image_file
                    else:
                        patch_dir = image_path + image_file
                    patches = smart_glob(patch_dir + '*.png')

                    # print(patches)
                    if not patches:
                        print(f'cannot glob the dir {patch_dir}.')
                        return self.__getitem__(0)

                    # sort multi images by name
                    patches = natsorted(patches)
                    flag_num_patches = len(patches)

                    for patch in patches:
                        try:
                            image = Image.open(patch).convert('RGB')
                        except Exception:
                            print(f'cannot identify image file {patch}.')
                            return self.__getitem__(0)

                        try:
                            img = self.image_processor(image)
                            # GOT feeds the same tensor to both image slots.
                            image_list.append(img)
                            image_high_list.append(img)
                        except Exception:
                            print(f'image {image_path + image_file + patch} are broken or grayscale! we thus select 0-th sample instead!')
                            return self.__getitem__(0)

                else:
                    flag_num_patches = 1
                    try:
                        image = Image.open(image_path + image_file).convert('RGB')
                    except Exception:
                        print(f'cannot identify image file {image_file}.')
                        return self.__getitem__(0)

                    try:
                        image = self.image_processor(image)
                    except Exception:
                        print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                        return self.__getitem__(0)

            conversations = self.multimodal_processor([data["conversations"]], flag_num_patches)
            # print(conversations)
            # exit()
        else:
            conversations = [data]

        # align with fastchat & llava here, put the conversation into a list for tokenization
        data_dict = self.token_processor(conversations, image_name)
        data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

        if isinstance(data, dict) and 'image' in data:
            if image_list and image_high_list:
                data_dict['image'] = image_list
                data_dict['image_high'] = image_high_list
            else:
                data_dict['image'] = [image]
                data_dict['image_high'] = [image]
        else:
            # crop_size = self.multimodal_cfg['image_processor'].crop_size
            # data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
            # Vary for two image, GOT does not use the data_dict['image]
            data_dict['image'] = [torch.zeros(3, 1024, 1024)]
            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
        return data_dict



================================================
FILE: GOT-OCR-2.0-master/GOT/demo/process_results.py
================================================
import string

# Character-to-character replacement table; run_ocr_2.0.py turns it into a
# translation table with str.maketrans().
# NOTE(review): as written the first entry maps ',' to itself -- possibly a
# full-width comma that was normalized somewhere; confirm against the repo.
punctuation_dict = {
    ",": ",",
    "。": ".",

}


# import os
 
def svg_to_html(svg_content, output_filename):
    """Wrap ``svg_content`` in a minimal HTML page and write it to disk.

    Args:
        svg_content: SVG markup placed inside a fixed-size ``<svg>`` element.
        output_filename: Path of the HTML file to create or overwrite.
    """
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>SVG Embedded in HTML</title>
    </head>
    <body>
        <svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
            {svg_content}
        </svg>
    </body>
    </html>
    """

    # The page declares charset UTF-8, so write it as UTF-8 explicitly rather
    # than with the platform's locale encoding (which can fail or garble
    # non-ASCII SVG content).
    with open(output_filename, 'w', encoding='utf-8') as file:
        file.write(html_content)
 

 

================================================
FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py
================================================
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria

from PIL import Image

import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor

from transformers import TextStreamer
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html
import string

# Placeholder strings used when building the image part of the prompt:
# a generic image tag, the per-patch pad token, and the start/end markers.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'

DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'



# Translation table built from punctuation_dict; eval_model applies it to the
# model output in the TikZ rendering branch.
translation_table = str.maketrans(punctuation_dict)


def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.HTTPError: if the download answers with a 4xx/5xx status.
    """
    # Match the full scheme: a bare 'http' prefix would also treat local
    # files such as 'http_scan.png' as URLs.
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        # Fail on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(args):
    """Run GOT-OCR on one image and optionally render the result to HTML.

    Reads from ``args``:
        model_name: Checkpoint path or hub id ('~' is expanded).
        image_file: Path or URL of the image.
        type: ``'format'`` selects the formatted-OCR prompt, anything else
            plain OCR.
        box: Optional string with 2 or 4 pixel coordinates, e.g.
            ``"[x1,y1,x2,y2]"``; rescaled to a 0-1000 range and prepended to
            the prompt.
        color: Optional color name; when set it replaces the box prefix.
        render: When true, the decoded output is written to
            ``./results/demo.html``.

    The generated text is streamed to stdout; nothing is returned. Requires a
    CUDA device.
    """
    import ast  # function-scope import: only needed to parse --box

    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()

    model.to(device='cuda',  dtype=torch.bfloat16)

    # TODO vary old codes, NEED del
    image_processor = BlipImageEvalProcessor(image_size=1024)
    image_processor_high = BlipImageEvalProcessor(image_size=1024)

    use_im_start_end = True
    image_token_len = 256

    image = load_image(args.image_file)
    w, h = image.size
    # print(image.size)

    task = 'OCR with format: ' if args.type == 'format' else 'OCR: '
    qs = task

    if args.box:
        # SECURITY: this used to be eval(args.box), which executes arbitrary
        # code from the command line. literal_eval only accepts Python
        # literals such as "[10, 20, 300, 400]".
        bbox = list(ast.literal_eval(args.box))
        if len(bbox) in (2, 4):
            # Even positions are x (scaled by width), odd ones y (by height).
            bbox = [
                int(value / (w if position % 2 == 0 else h) * 1000)
                for position, value in enumerate(bbox)
            ]
        qs = str(bbox) + ' ' + task

    if args.color:
        # A color prompt takes precedence over a box prompt.
        qs = '[' + args.color + ']' + ' ' + task

    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    conv_mode = "mpt"
    args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    print(prompt)

    inputs = tokenizer([prompt])

    # vary old codes, no use
    image_1 = image.copy()
    image_tensor = image_processor(image)

    image_tensor_1 = image_processor_high(image_1)

    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
            do_sample=False,
            num_beams = 1,
            no_repeat_ngram_size = 20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
            )

        if args.render:
            print('==============rendering===============')

            # Decode only the newly generated tokens.
            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            if '**kern' in outputs:
                # Sheet music in **kern notation: render with verovio. The
                # cairosvg / cv2 / numpy imports that used to sit here were
                # never used and only added install requirements.
                import verovio
                tk = verovio.toolkit()
                tk.loadData(outputs)
                tk.setOptions({"pageWidth": 2100, "footer": 'none',
               'barLineWidth': 0.5, 'beamMaxSlope': 15,
               'staffLineWidth': 0.2, 'spacingStaff': 6})
                tk.getPageCount()
                svg = tk.renderToSVG()
                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")

                svg_to_html(svg, "./results/demo.html")

            if args.type == 'format' and '**kern' not in outputs:

                if '\\begin{tikzpicture}' not in outputs:
                    html_path = "./render_tools/" + "/content-mmd-to-html.html"
                    html_path_2 = "./results/demo.html"
                    right_num = outputs.count('\\right')
                    left_num = outputs.count('\\left')

                    if right_num != left_num:
                        # Unbalanced \left / \right would break the renderer:
                        # fall back to the plain delimiters.
                        for tex_delim, plain_delim in (
                            ('\\left(', '('), ('\\right)', ')'),
                            ('\\left[', '['), ('\\right]', ']'),
                            ('\\left{', '{'), ('\\right}', '}'),
                            ('\\left|', '|'), ('\\right|', '|'),
                            ('\\left.', '.'), ('\\right.', '.'),
                        ):
                            outputs = outputs.replace(tex_delim, plain_delim)

                    outputs = outputs.replace('"', '``').replace('$', '')

                    outputs_list = outputs.split('\n')
                    # Emit one JS string literal per output line, joined with
                    # '+', to be spliced after "const text =" in the template.
                    gt = '+\n'.join(
                        '"' + out.replace('\\', '\\\\') + r'\n' + '"'
                        for out in outputs_list
                    )

                    with open(html_path, 'r', encoding='utf-8') as web_f:
                        lines = web_f.read()
                        lines = lines.split("const text =")
                        new_web = lines[0] + 'const text ='  + gt  + lines[1]
                else:
                    html_path = "./render_tools/" + "/tikz.html"
                    html_path_2 = "./results/demo.html"
                    outputs = outputs.translate(translation_table)
                    outputs_list = outputs.split('\n')
                    gt = ''
                    for out in outputs_list:
                        if out:
                            if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
                                # rstrip instead of the former while-loop,
                                # which raised IndexError on lines made of
                                # spaces only.
                                out = out.rstrip(' ')

                                if out:
                                    if out[-1] != ';':
                                        # NOTE(review): this replaces the last
                                        # character with ';' rather than
                                        # appending one -- presumably to fix a
                                        # wrong trailing punctuation mark;
                                        # confirm.
                                        gt += out[:-1] + ';\n'
                                    else:
                                        gt += out + '\n'
                            else:
                                gt += out + '\n'

                    with open(html_path, 'r', encoding='utf-8') as web_f:
                        lines = web_f.read()
                        lines = lines.split("const text =")
                        new_web = lines[0] + gt + lines[1]

                # OCR output may contain non-ASCII text, so the page is
                # written as UTF-8 regardless of the platform locale.
                with open(html_path_2, 'w', encoding='utf-8') as web_f_new:
                    web_f_new.write(new_web)





if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, parse them
    # and hand the namespace to eval_model.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--model-name", dict(type=str, default="facebook/opt-350m")),
        ("--image-file", dict(type=str, required=True)),
        ("--type", dict(type=str, required=True)),
        ("--box", dict(type=str, default='')),
        ("--color", dict(type=str, default='')),
        ("--render", dict(action='store_true')),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    eval_model(args)


================================================
FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0_crop.py
================================================
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria

from PIL import Image

import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor
from transformers import TextStreamer
from natsort import natsorted
import glob




# Placeholder strings used when building the image part of the prompt:
# a generic image tag, the per-patch pad token, and the start/end markers.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'



def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.HTTPError: if the download answers with a 4xx/5xx status.
    """
    # Match the full scheme: a bare 'http' prefix would also treat local
    # files such as 'http_scan.png' as URLs.
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        # Fail on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(columns, rows)`` candidates, scanned in
            the given order.
        width, height: Source image size in pixels.
        image_size: Edge length of one square tile.

    Returns:
        The winning candidate, or ``(1, 1)`` when nothing beats the initial
        infinite distance (e.g. an empty candidate list).
    """
    pixel_area = width * height
    winner = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, winner = gap, candidate
        elif gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            # Tie: switch to the later candidate only when the image has more
            # than half the pixels that grid would provide.
            winner = candidate

    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
    return winner


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
    """Cut an image into square tiles laid out on its best-fitting grid.

    The image is resized to ``columns * image_size`` by ``rows * image_size``
    for the grid chosen by ``find_closest_aspect_ratio`` and cropped row by
    row into ``image_size`` tiles.

    Args:
        image: Source image (needs ``size``, ``resize`` and ``crop``).
        min_num, max_num: Inclusive bounds on the number of tiles in a grid.
        image_size: Edge length of each tile in pixels.
        use_thumbnail: When true and more than one tile was produced, a
            resized copy of the whole image is appended as a last element.

    Returns:
        List of tile images, optionally followed by the thumbnail.
    """
    source_w, source_h = image.size
    aspect_ratio = source_w / source_h

    # Every (columns, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    grid_options = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    # print(target_ratios)
    grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])

    # find the closest aspect ratio to the target
    chosen_grid = find_closest_aspect_ratio(
        aspect_ratio, grid_options, source_w, source_h, image_size)

    # print(target_aspect_ratio)
    # calculate the target width and height
    target_width = image_size * chosen_grid[0]
    target_height = image_size * chosen_grid[1]
    blocks = chosen_grid[0] * chosen_grid[1]

    # resize the image, then slice it left-to-right, top-to-bottom
    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size))
        )
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images



def eval_model(args):
    """Run GOT OCR on a tiled image (or a folder of PNG pages) and optionally render it.

    Args:
        args: Parsed CLI namespace. Reads ``model_name``, ``image_file``,
            ``multi_page`` and ``render``. ``conv_mode`` is overwritten with
            ``"mpt"``.

    Side effects:
        The generated text is streamed to stdout by ``TextStreamer``. When
        ``args.render`` is set, the text is spliced into the HTML template
        under ``./render_tools`` and written to ``./results/demo.html``.
    """
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()
    model.to(device='cuda',  dtype=torch.bfloat16)

    # vary old codes, no use
    image_processor = BlipImageEvalProcessor(image_size=1024)
    image_processor_high = BlipImageEvalProcessor(image_size=1024)

    use_im_start_end = True
    image_token_len = 256

    image_list = []

    if args.multi_page:
        qs = 'OCR with format across multi pages: '
        # only for png files; natural sort keeps page_2 before page_10
        patches = glob.glob(args.image_file + '/*png')
        patches = natsorted(patches)
        sub_images = [load_image(page_path) for page_path in patches]
        ll = len(patches)
    else:
        qs = 'OCR with format upon the patch reference: '
        img = load_image(args.image_file)
        sub_images = dynamic_preprocess(img)
        ll = len(sub_images)

    for p in sub_images:
        image = p
        image_1 = image.copy()
        # no use, vary old codes (call kept as in the original; result is unused)
        image_tensor = image_processor(image)

        image_tensor_1 = image_processor_high(image_1)
        image_list.append(image_tensor_1)

    image_list = torch.stack(image_list)

    print('====new images batch size======:  ',image_list.shape)

    # One run of image_token_len placeholder tokens per sub-image.
    if use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    conv_mode = "mpt"
    args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    inputs = tokenizer([prompt])
    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_list.half().cuda(), image_list.half().cuda())],
            do_sample=False,
            num_beams = 1,
            # no_repeat_ngram_size = 20,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
            )

    if args.render:
        print('==============rendering===============')
        # Decode only the newly generated tokens (everything after the prompt).
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        html_path = "./render_tools/" + "/content-mmd-to-html.html"
        html_path_2 = "./results/demo.html"

        # Backslashes are written explicitly: '\l' is an invalid escape
        # sequence and triggers a SyntaxWarning on recent Python versions.
        right_num = outputs.count('\\right')
        left_num = outputs.count('\\left')

        if right_num != left_num:
            # Unbalanced \left / \right: fall back to plain delimiters.
            delimiter_fixes = (
                ('\\left(', '('), ('\\right)', ')'),
                ('\\left[', '['), ('\\right]', ']'),
                ('\\left{', '{'), ('\\right}', '}'),
                ('\\left|', '|'), ('\\right|', '|'),
                ('\\left.', '.'), ('\\right.', '.'),
            )
            for latex_form, plain_form in delimiter_fixes:
                outputs = outputs.replace(latex_form, plain_form)

        outputs = outputs.replace('"', '``').replace('$', '')

        # Emit the text as a JavaScript expression: one quoted string per
        # line, each ending in a literal "\n", concatenated with "+".
        js_lines = [
            '"' + out.replace('\\', '\\\\') + r'\n' + '"'
            for out in outputs.split('\n')
        ]
        gt = ('+' + '\n').join(js_lines)

        # Explicit UTF-8 so non-ASCII OCR text does not depend on the locale.
        with open(html_path, 'r', encoding='utf-8') as web_f:
            lines = web_f.read()
            lines = lines.split("const text =")
            new_web = lines[0] + 'const text ='  + gt  + lines[1]

        with open(html_path_2, 'w', encoding='utf-8') as web_f_new:
            web_f_new.write(new_web)


if __name__ == "__main__":
    # Command-line entry point for the demo.
    parser = argparse.ArgumentParser()
    # Passed to from_pretrained() for both the tokenizer and the model.
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    # A single image, or (with --multi-page) a directory that is globbed for '*png'.
    parser.add_argument("--image-file", type=str, required=True)
    # Accepted, but eval_model overwrites args.conv_mode with "mpt".
    parser.add_argument("--conv-mode", type=str, default=None)
    # Treat --image-file as a directory of page images instead of tiling one image.
    parser.add_argument("--multi-page", action='store_true')
    # Also write the result into ./results/demo.html.
    parser.add_argument("--render", action='store_true')
    args = parser.parse_args()

    eval_model(args)


================================================
FILE: GOT-OCR-2.0-master/GOT/eval/eval_GOT_ocr.py
================================================
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

from tqdm import tqdm
from PIL import Image
import json
import os
import requests
from PIL import Image
from io import BytesIO
import math

import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from GOT.utils.conversation import conv_templates, SeparatorStyle
from GOT.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from GOT.model import *
from GOT.utils.utils import KeywordsStoppingCriteria

from PIL import Image

import os
import requests
from PIL import Image
from io import BytesIO
from GOT.model.plug.blip_process import BlipImageEvalProcessor

from transformers import TextStreamer
from GOT.model.plug.transforms import train_transform, test_transform
import re
from GOT.demo.process_results import punctuation_dict, svg_to_html

# Special tokens used when building the image placeholder part of the prompt:
# eval_model wraps a run of DEFAULT_IMAGE_PATCH_TOKEN between the start/end tokens.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
DEFAULT_IM_START_TOKEN = '<img>'
DEFAULT_IM_END_TOKEN = '</img>'


import string

# Translation table built from punctuation_dict (imported from
# GOT.demo.process_results). It is not referenced again in this file.
translation_table = str.maketrans(punctuation_dict)


def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The image converted to RGB.

    Raises:
        requests.HTTPError: If the URL responds with an error status, so an
            error page is never handed to the image decoder.
    """
    # Match the full scheme so a local file whose name merely starts with
    # "http" is still opened from disk.
    if image_file.startswith(('http://', 'https://')):
        # A timeout keeps an unresponsive server from hanging the run.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the grid whose columns/rows ratio is nearest to ``aspect_ratio``.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of ``(columns, rows)`` candidates, scanned in
            the order given.
        width: Source image width.
        height: Source image height.
        image_size: Edge length of one tile.

    Returns:
        The winning candidate, or ``(1, 1)`` when there are no candidates.
        On an exact tie a later candidate replaces the current choice only
        if the source area exceeds half of that candidate's total tile area.
    """
    source_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        if gap != smallest_gap:
            continue
        # Tie: take the later grid only when the image is large enough for it.
        if source_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
    """Split an image into a grid of square tiles that suits its aspect ratio.

    All ``(columns, rows)`` grids with between ``min_num`` and ``max_num``
    tiles are considered, and ``find_closest_aspect_ratio`` picks one. The
    image is resized to ``columns * image_size`` by ``rows * image_size``
    and cut into tiles in row-major order.

    Args:
        image: Image object providing ``size``, ``resize`` and ``crop``.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Side length of every tile.
        use_thumbnail: Append an ``image_size`` square resize of the whole
            image when the grid has more than one tile.

    Returns:
        The list of tiles, plus the thumbnail when one was added.
    """
    full_width, full_height = image.size
    full_ratio = full_width / full_height

    # Same generation order as the original so the stable sort is unchanged.
    grids = {
        (across, down)
        for n in range(min_num, max_num + 1)
        for across in range(1, n + 1)
        for down in range(1, n + 1)
        if min_num <= across * down <= max_num
    }
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    best_grid = find_closest_aspect_ratio(
        full_ratio, grids, full_width, full_height, image_size)

    resized_width = image_size * best_grid[0]
    resized_height = image_size * best_grid[1]
    n_tiles = best_grid[0] * best_grid[1]

    resized = image.resize((resized_width, resized_height))
    per_row = resized_width // image_size
    crops = []
    for tile_index in range(n_tiles):
        row_index, col_index = divmod(tile_index, per_row)
        x0 = col_index * image_size
        y0 = row_index * image_size
        crops.append(resized.crop((x0, y0, x0 + image_size, y0 + image_size)))

    assert len(crops) == n_tiles
    if use_thumbnail and len(crops) != 1:
        crops.append(image.resize((image_size, image_size)))
    return crops



def split_list(lst, n):
    """Split a list into at most ``n`` contiguous chunks of roughly equal size.

    The chunk size is ``ceil(len(lst) / n)``, so fewer than ``n`` chunks can
    come back (5 items with ``n=4`` give chunks of sizes 2, 2 and 1).

    Args:
        lst: Sequence to split.
        n: Requested number of chunks.

    Returns:
        List of slices of ``lst`` in order; ``[]`` when ``lst`` is empty.
    """
    if not lst:
        # A chunk size of 0 would make range() raise "arg 3 must not be zero".
        return []
    chunk_size = math.ceil(len(lst) / n)  # round up so no item is left over
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of the ``n`` chunks that ``split_list`` makes from ``lst``.

    ``split_list`` rounds the chunk size up, so it can return fewer than
    ``n`` chunks. A worker whose index is valid (``k < n``) but has no chunk
    gets an empty list instead of an IndexError.

    Args:
        lst: Sequence to split.
        n: Total number of chunks requested.
        k: Index of the chunk to return.

    Returns:
        The selected chunk, or ``[]`` when chunk ``k`` does not exist but
        ``k < n``.

    Raises:
        IndexError: If ``k`` is outside the requested range.
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return []
    return chunks[k]



# Module-level accumulator: eval_model appends one result dict per sample and
# dumps the whole list to JSON at the end.
output_list = []

def eval_model(args):
    """Generate OCR predictions for one chunk of the ground-truth file and save them.

    Args:
        args: Parsed CLI namespace. Reads ``model_name``, ``gtfile_path``,
            ``image_path``, ``out_path``, ``datatype``, ``num_chunks``,
            ``chunk_idx`` and ``conv_mode``. ``conv_mode`` is set to
            ``"mpt"`` when it was not given.

    Side effects:
        Appends one dict per sample to the module-level ``output_list`` and
        writes it to ``<out_path>/results_<chunk_idx>.json``.
    """
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()

    # vary old codes, no use
    image_processor = BlipImageEvalProcessor(image_size=1024)

    image_processor_high = BlipImageEvalProcessor(image_size=1024)
    use_im_start_end = True

    image_token_len = 256
    gts_path = args.gtfile_path
    # Close the handle deterministically and read as UTF-8, matching the
    # encoding used for the results file written below.
    with open(gts_path, encoding="utf-8") as gt_file:
        gts = json.load(gt_file)

    print("Generate Results......")

    # Only OCR datatypes are sharded across workers.
    if "OCR" in args.datatype:
        gts = get_chunk(gts, args.num_chunks, args.chunk_idx)

    for ann in tqdm(gts):
        output_json = {}

        # The question text is replaced by a fixed prompt further down; the
        # lookup is kept so a malformed annotation still fails here.
        if "OCR" in args.datatype:
            qs = ann["conversations"][0]["value"]
        else:
            qs = ann["question"]

        image_file = ann["image"]
        if 'Text' in args.datatype:
            image_file = image_file + '.jpg'
        if "VQAv2" in args.datatype:
            image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg'
        if "Cap" in args.datatype:
            image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg'

        image_file_path = os.path.join(args.image_path, image_file)

        # Hard-coded switch: the tiled (multi-crop) path is disabled here.
        multi_crop = False
        if multi_crop:
            image_list = []
            img = load_image(image_file_path)
            sub_images = dynamic_preprocess(img)
            ll = len(sub_images)
            for p in sub_images:
                image = p
                image_1 = image.copy()
                # vary old code, NO USE
                image_tensor = image_processor_high(image_1)

                image_tensor_1 = image_processor_high(image_1)

                image_list.append(image_tensor_1)

            image_list = torch.stack(image_list)

        else:
            ll = 1
            image = load_image(image_file_path)
            image_1 = image.copy()

            # vary old code, NO USE
            image_tensor = image_processor_high(image_1)

            image_tensor_1 = image_processor_high(image_1)

        # image_token_len placeholder tokens per image, then the instruction.
        qs =  DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len*ll + DEFAULT_IM_END_TOKEN  + '\n' +  'OCR with format: '

        conv_mode = "mpt"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
        else:
            args.conv_mode = conv_mode

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])

        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        if multi_crop:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_ids = model.generate(
                    input_ids,
                    images=[(image_list.half().cuda(), image_list.half().cuda())],
                    do_sample=False,
                    num_beams = 1,
                    # temperature=0.2,
                    # no_repeat_ngram_size = 20,
                    # streamer=streamer,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
                    )
        else:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_ids = model.generate(
                    input_ids,
                    images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
                    do_sample=False,
                    num_beams = 1,
                    # temperature=0.2,
                    no_repeat_ngram_size = 20,
                    # encoder_repetition_penalty = 1.2,
                    # penalty_alpha=0.2,
                    # top_k=3,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
                    )

        # Decode only the tokens generated after the prompt.
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        if "Cap" in args.datatype:
            output_json['image_id'] = ann["id"]
            output_json["caption"] = outputs
        else:
            output_json['image'] = ann["image"]
            output_json['question'] = qs
            output_json['label'] = ann["conversations"][1]["value"]
            output_json['answer'] = outputs
        output_list.append(output_json)

    filename = args.out_path + "/results_" + str(args.chunk_idx) + ".json"
    with open(filename, 'w', encoding="utf-8") as file_obj:
        json.dump(output_list, file_obj, ensure_ascii=False, indent=1)

if __name__ == "__main__":
    # Command-line entry point for evaluating one chunk.
    parser = argparse.ArgumentParser()
    # Passed to from_pretrained() for both the tokenizer and the model.
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    # JSON file with the ground-truth annotations.
    parser.add_argument("--gtfile_path", type=str, required=True)
    # Directory that each annotation's image file name is joined onto.
    parser.add_argument("--image_path", type=str, required=True)
    # Directory that receives results_<chunk-idx>.json.
    parser.add_argument("--out_path", type=str, required=True)
    # eval_model tests this string for the substrings "OCR", "Text", "VQAv2" and "Cap".
    parser.add_argument("--datatype", type=str, required=True)  # Text or Doc
    # Sharding: this process handles chunk --chunk-idx out of --num-chunks.
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    # parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--conv-mode", type=str, default=None)
    # Parsed but not read by eval_model, which generates with do_sample=False.
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()
    print(args)
    eval_model(args)


================================================
FILE: GOT-OCR-2.0-master/GOT/eval/evaluate_GOT.py
================================================
# Driver script: runs the sharded generation step, merges the per-chunk
# results, then runs one of the OCR scoring modules on the merged file.
import os
import argparse
import subprocess

parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--gtfile_path", type=str, required=True)
parser.add_argument("--image_path", type=str, required=True)
parser.add_argument("--out_path", type=str, required=True)
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--datatype", type=str, required=True)  # Text\Doc\VQAv2\Cap
# parser.add_argument("--eval", type=str, required=True)
args = parser.parse_args()


def _run_module(module, *cli_args):
    """Run ``python3 -m <module> <cli_args...>`` and return its exit status.

    Arguments go to the child as a list, not a concatenated shell string, so
    paths with spaces or shell metacharacters arrive intact. A non-zero exit
    status does not raise, matching the previous ``os.system`` behaviour.
    """
    return subprocess.run(["python3", "-m", module, *cli_args], check=False).returncode


# Step 1: generate predictions, one worker process per chunk.
_run_module(
    "GOT.eval.multi_hardware_eval_GOT",
    "--model-name", args.model_name,
    "--gtfile_path", args.gtfile_path,
    "--image_path", args.image_path,
    "--out_path", args.out_path,
    "--num-chunks", str(args.num_chunks),
    "--temperature", str(args.temperature),
    "--datatype", args.datatype,
)

# Step 2: merge the per-chunk result files.
print("Evaluating.....")
_run_module("GOT.eval.pyevaltools.merge_results", "--out_path", args.out_path)


# if args.datatype == "OCR":


# Step 3: score the merged results with the module selected by a_type.
a_type = 'plain'  # 'plain'; 'format'; 'scene'

_EVAL_MODULES = {
    'plain': "GOT.eval.pyevaltools.eval_ocr",
    'format': "GOT.eval.pyevaltools.eval_ocr_format",
    'scene': "GOT.eval.pyevaltools.eval_ocr_scene",
}

if a_type in _EVAL_MODULES:
    _run_module(
        _EVAL_MODULES[a_type],
        "--out_path", args.out_path,
        "--gt_path", args.gtfile_path,
        "--datatype", args.datatype,
    )

================================================
FILE: GOT-OCR-2.0-master/GOT/eval/multi_hardware_eval_GOT.py
================================================
import os
import argparse
from multiprocessing import Pool
# from GOT.eval.merge_results import merge_outputs
# from GOT.eval.doctextVQA import doc_text_eval


def run_eval(chunk_id, model_name, gtfile_path, image_path, out_path, num_chunks, datatype, temperature):
    """Evaluate one chunk in a child ``python3 -m GOT.eval.eval_GOT_ocr`` process.

    The child runs with ``CUDA_VISIBLE_DEVICES`` set to ``chunk_id``, so
    chunk ``i`` is bound to GPU ``i``.

    Args:
        chunk_id: Chunk index, also used as the visible CUDA device id.
        model_name: Forwarded as ``--model-name``.
        gtfile_path: Forwarded as ``--gtfile_path``.
        image_path: Forwarded as ``--image_path``.
        out_path: Forwarded as ``--out_path``.
        num_chunks: Forwarded as ``--num-chunks``.
        datatype: Forwarded as ``--datatype``.
        temperature: Forwarded as ``--temperature``.
    """
    import subprocess  # function-scope import; the file header only imports os/argparse/Pool

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact; the device is selected through the environment.
    child_env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(chunk_id))
    subprocess.run(
        [
            "python3", "-m", "GOT.eval.eval_GOT_ocr",
            "--model-name", str(model_name),
            "--gtfile_path", str(gtfile_path),
            "--image_path", str(image_path),
            "--out_path", str(out_path),
            "--num-chunks", str(num_chunks),
            "--chunk-idx", str(chunk_id),
            "--temperature", str(temperature),
            "--datatype", str(datatype),
        ],
        env=child_env,
        check=False,  # as with os.system, a failing child does not raise
    )


if __name__ == "__main__":
    # Fan the evaluation out over num_chunks worker processes; run_eval is
    # called once per chunk index.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--gtfile_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--out_path", type=str, required=True)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--datatype", type=str, required=True)  # Text or Doc
    # parser.add_argument("--eval", type=str, required=True)
    args = parser.parse_args()

    num_chunks = args.num_chunks

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.out_path, exist_ok=True)

    with Pool(num_chunks) as p:
        pending = [
            p.apply_async(run_eval, (chunk_id, args.model_name, args.gtfile_path,
                                     args.image_path, args.out_path, num_chunks, args.datatype, args.temperature))
            for chunk_id in range(num_chunks)
        ]
        p.close()
        p.join()
        # apply_async stores worker exceptions in the result object; get()
        # re-raises them so a failed chunk is not silently dropped.
        for result in pending:
            result.get()



================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/__init__.py
================================================
# Package-level attribution string.
author='aagrawal'


================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr.py
================================================
import json
# from doctextVQAeval import VQAEval

import argparse
# import fitz as pymupdf
import nltk
from nltk.metrics import precision, recall, f_measure
import numpy as np
import jieba
# import megfile as mf
import pickle
import pandas as pd
import re
# from loguru import logger
# nltk.download('wordnet')
from nltk.translate import meteor_score

# from marker_scoring import score_text
# from utils import contain_chinese_string
# NOTE: arguments are parsed at import time, so this module only works when
# run as a script with the three options below.
parser = argparse.ArgumentParser()

# Directory that holds results_final.json (see the call at the end of this file).
parser.add_argument("--out_path", type=str, required=True)
# Ground-truth path; forwarded to the eval function as gt_root_.
parser.add_argument("--gt_path", type=str, required=True)
# Dataset type label; forwarded to the eval function as datatype.
parser.add_argument("--datatype", type=str, required=True)
args = parser.parse_args()

def preprocess(text, predict_root_):
    """Strip the InternVL prompt echo and end token from a prediction.

    Only applied when ``predict_root_`` contains ``'InternVL'``; any other
    prediction is returned unchanged.

    Args:
        text: Raw model output.
        predict_root_: Path of the predictions file, used to detect the model.

    Returns:
        For InternVL output, the segment after the first
        ``"All words in the image:\\n"`` marker (up to the next marker, if
        any), cut at ``"[UNUSED_TOKEN_145]"``. If the marker is missing the
        text is kept and only the end token is cut.
    """
    if 'InternVL' in predict_root_:
        parts = text.split("All words in the image:\n")
        # Indexing [1] unconditionally raised IndexError on outputs that
        # lack the marker; keep the text as-is in that case.
        if len(parts) > 1:
            text = parts[1]
        text = text.split("[UNUSED_TOKEN_145]")[0]
    return text

def contain_chinese_string(text):
    """Return True if ``text`` has at least one character in U+4E00..U+9FA5."""
    # Regular-expression search for a Chinese character.
    return re.search(r'[\u4e00-\u9fa5]', text) is not None


# Inline math: captures the text between "\(" and the nearest "\)" that is
# not preceded by a backslash (non-greedy, may be empty).
inline_reg = re.compile(r"\\\((.*?)(?<!\\)\\\)")
# Display math: the same for "\[" ... "\]", requiring at least one character.
display_reg = re.compile(r"\\\[(.+?)(?<!\\)\\\]")
# Tables: captures what follows "\begin{tabular}" up to "\end{tabular}" or,
# failing that, the end of the text; re.S lets "." span newlines.
table_reg = re.compile(r"\\begin\{tabular\}(.+?)(?:\\end\{tabular\}|$)", re.S)

def split_text(pages, a_type):
    """Collect the text, math and table content of each page.

    Args:
        pages: Iterable of dicts, one per page.
        a_type: Key of the string field to read from each page.

    Returns:
        Three lists with one entry per page:
        ``text`` holds the whole field, stripped; ``math`` holds the inline
        matches joined by newlines, immediately followed by the display
        matches joined by newlines; ``table`` holds the tabular matches
        joined by newlines.
    """
    text, math, table = [], [], []
    for page in pages:
        content = page[a_type]
        inline_found = "\n".join(inline_reg.findall(content))
        display_found = "\n".join(display_reg.findall(content))
        table_found = "\n".join(table_reg.findall(content))
        # Inline and display math share one entry, with no separator between them.
        math.append(inline_found + display_found)
        table.append(table_found)
        text.append(content.strip())
    return text, math, table

def nougat_per_metrics(predict_root_, pred, gt, minlen=1, heavy_mode: int = 2):
    """Score one prediction string against one reference string.

    Args:
        predict_root_: Path of the predictions file (not used in the scoring).
        pred: Predicted text.
        gt: Reference text.
        minlen: If either string is shorter than this, nothing is scored.
        heavy_mode: 0 computes only bleu and f_measure; 1 adds meteor,
            precision and recall; 2 also adds the normalised edit distance.

    Returns:
        Dict of metric name to value, or an empty dict when either input is
        shorter than ``minlen``.
    """
    if len(pred) < minlen or len(gt) < minlen:
        return {}

    # jieba segmentation when either side has Chinese characters,
    # whitespace tokenisation otherwise.
    if contain_chinese_string(gt) or contain_chinese_string(pred):
        ref_tokens = jieba.lcut(gt)
        hyp_tokens = jieba.lcut(pred)
    else:
        ref_tokens = gt.split()
        hyp_tokens = pred.split()

    with_extras = heavy_mode >= 1

    scores = {"bleu": nltk.translate.bleu([ref_tokens], hyp_tokens)}
    if with_extras:
        scores["meteor"] = meteor_score.meteor_score([ref_tokens], hyp_tokens)

    # The set-based metrics work on the token vocabularies.
    ref_vocab = set(ref_tokens)
    hyp_vocab = set(hyp_tokens)
    scores["f_measure"] = f_measure(ref_vocab, hyp_vocab)
    if with_extras:
        scores["precision"] = precision(ref_vocab, hyp_vocab)
        scores["recall"] = recall(ref_vocab, hyp_vocab)

    if heavy_mode == 2:
        # Too slow to run in the lighter modes.
        scores["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
    return scores

def doc_formated_text_eval(gt_root_, predict_root_, datatype):
    """Print mean metrics for the text, math and table parts of the predictions.

    Args:
        gt_root_: Ground-truth path (not read here; labels come from the
            predictions file).
        predict_root_: JSON file holding a list of dicts with ``'label'`` and
            ``'answer'`` strings.
        datatype: Dataset type label (not used here).

    Output:
        Prints a JSON object with ``"eval question num"`` and one dict of
        averaged metrics per category. A category with no scorable sample
        gets an empty dict instead of a division by zero.
    """
    with open(predict_root_, encoding='utf-8') as pred_file:
        predicts = json.load(pred_file)

    gt_text_split, gt_math_split, gt_table_split = split_text(predicts, 'label')
    pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer')

    paired = {
        'text': zip(gt_text_split, pre_text_split),
        'math': zip(gt_math_split, pre_math_split),
        'table': zip(gt_table_split, pre_table_split),
    }

    # NOTE(review): the label goes into nougat_per_metrics' `pred` slot and
    # the answer into `gt`, as in the original call; confirm this is intended
    # before comparing precision and recall.
    results = {}
    for category, pairs in paired.items():
        scored = []
        for label_part, answer_part in pairs:
            ans = nougat_per_metrics(predict_root_, label_part, answer_part)
            if ans:  # empty dict: one side was too short to score
                scored.append(ans)
        results[category] = scored

    mean_dict = {"eval question num": len(results['text'])}
    for category in ('text', 'math', 'table'):
        scored = results[category]
        if not scored:
            # No sample in this category: averaging would divide by zero.
            mean_dict[category] = {}
            continue
        mean_dict[category] = {
            k: sum(each[k] for each in scored) / len(scored)
            for k in scored[0]
        }

    print(json.dumps(mean_dict, indent=4))

def doc_text_eval(gt_root_, predict_root_, datatype):
    """Print the mean plain-text OCR metrics over all predictions.

    Args:
        gt_root_: Ground-truth path (not read here; labels come from the
            predictions file).
        predict_root_: JSON file holding a list of dicts with ``'label'`` and
            ``'answer'`` strings.
        datatype: Dataset type label (not used here).

    Output:
        Prints a JSON object with ``"eval question num"`` and the average of
        every metric. With no scorable sample only the count (0) is printed.

    Raises:
        AssertionError: If scoring a sample fails; the original exception is
            chained as the cause.
    """
    with open(predict_root_, encoding='utf-8') as pred_file:
        predicts = json.load(pred_file)

    result = []
    for ann in predicts:
        # NOTE(review): the label goes into nougat_per_metrics' `pred` slot
        # and the answer into `gt`, as in the original call; confirm this is
        # intended before comparing precision and recall.
        try:
            ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"])
        except Exception as err:
            # Same message and exception type as before, but the underlying
            # error is kept and KeyboardInterrupt is no longer swallowed.
            print("ERROR!!! Check yout output!!!")
            raise AssertionError("ERROR!!! Check yout output!!!") from err
        if ans:  # empty dict: label or answer was too short to score
            result.append(ans)

    mean_dict = {"eval question num": len(result)}
    if result:
        for k in result[0]:
            mean_dict[k] = sum(each[k] for each in result) / len(result)
    print(json.dumps(mean_dict, indent=4))

# doc_text_eval("/data/data/DocVQA/val/val_v1.0.json", "/data/codes/GOT_docshot-main/results_cc595k-freeze-docvqa-unfreeze-224/results_final.json", "Doc")


# doc_formated_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype)

# Script body: score <out_path>/results_final.json with the plain-text
# metrics. Runs at import time, using the arguments parsed above.
doc_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype)

================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_format.py
================================================
import json
# from doctextVQAeval import VQAEval

import argparse
# import fitz as pymupdf
import nltk
from nltk.metrics import precision, recall, f_measure
import numpy as np
import jieba
# import megfile as mf
import pickle
import pandas as pd
import re
# from loguru import logger
# nltk.download('wordnet')
from nltk.translate import meteor_score

# from marker_scoring import score_text
# from utils import contain_chinese_string
# NOTE: arguments are parsed at import time, so this module only works when
# run as a script with the three options below.
parser = argparse.ArgumentParser()

# Output directory of the evaluation run (presumably where results_final.json
# lives; the call site is outside this view).
parser.add_argument("--out_path", type=str, required=True)
# Ground-truth path.
parser.add_argument("--gt_path", type=str, required=True)
# Dataset type label.
parser.add_argument("--datatype", type=str, required=True)
args = parser.parse_args()

def preprocess(text, predict_root_):
    """Strip the InternVL prompt echo and end token from a prediction.

    Only applied when ``predict_root_`` contains ``'InternVL'``; any other
    prediction is returned unchanged.

    Args:
        text: Raw model output.
        predict_root_: Path of the predictions file, used to detect the model.

    Returns:
        For InternVL output, the segment after the first
        ``"All words in the image:\\n"`` marker (up to the next marker, if
        any), cut at ``"[UNUSED_TOKEN_145]"``. If the marker is missing the
        text is kept and only the end token is cut.
    """
    if 'InternVL' in predict_root_:
        segments = text.split("All words in the image:\n")
        # Indexing [1] unconditionally raised IndexError on outputs that
        # lack the marker; keep the text as-is in that case.
        if len(segments) > 1:
            text = segments[1]
        text = text.split("[UNUSED_TOKEN_145]")[0]
    return text

def contain_chinese_string(text):
    """Return True if ``text`` has at least one character in U+4E00..U+9FA5."""
    # Regular-expression search for a Chinese character.
    found = re.search(r'[\u4e00-\u9fa5]', text)
    return found is not None


# Inline math: captures the text between "\(" and the nearest "\)" that is
# not preceded by a backslash (non-greedy, may be empty).
inline_reg = re.compile(r"\\\((.*?)(?<!\\)\\\)")
# Display math: the same for "\[" ... "\]", requiring at least one character.
display_reg = re.compile(r"\\\[(.+?)(?<!\\)\\\]")
# Tables: captures what follows "\begin{tabular}" up to "\end{tabular}" or,
# failing that, the end of the text; re.S lets "." span newlines.
table_reg = re.compile(r"\\begin\{tabular\}(.+?)(?:\\end\{tabular\}|$)", re.S)

def split_text(pages, a_type):
    """Collect the text, math and table content of each page.

    Args:
        pages: Iterable of dicts, one per page.
        a_type: Key of the string field to read from each page.

    Returns:
        Three lists with one entry per page:
        ``text`` holds the whole field, stripped; ``math`` holds the inline
        matches joined by newlines, immediately followed by the display
        matches joined by newlines; ``table`` holds the tabular matches
        joined by newlines.
    """
    text, math, table = [], [], []
    for page in pages:
        raw = page[a_type]
        inline_part = "\n".join(inline_reg.findall(raw))
        display_part = "\n".join(display_reg.findall(raw))
        tabular_part = "\n".join(table_reg.findall(raw))
        # Inline and display math share one entry, with no separator between them.
        math.append(inline_part + display_part)
        table.append(tabular_part)
        text.append(raw.strip())
    return text, math, table

def nougat_per_metrics(predict_root_, pred, gt, minlen=1, heavy_mode: int = 2):
    """Score one prediction string against one reference string.

    Args:
        predict_root_: Path of the predictions file (not used in the scoring).
        pred: Predicted text.
        gt: Reference text.
        minlen: If either string is shorter than this, nothing is scored.
        heavy_mode: 0 computes only bleu and f_measure; 1 adds meteor,
            precision and recall; 2 also adds the normalised edit distance.

    Returns:
        Dict of metric name to value, or an empty dict when either input is
        shorter than ``minlen``.
    """
    if len(pred) < minlen or len(gt) < minlen:
        return {}

    # jieba segmentation when either side has Chinese characters,
    # whitespace tokenisation otherwise.
    if contain_chinese_string(gt) or contain_chinese_string(pred):
        ref_tokens = jieba.lcut(gt)
        hyp_tokens = jieba.lcut(pred)
    else:
        ref_tokens = gt.split()
        hyp_tokens = pred.split()

    with_extras = heavy_mode >= 1

    scores = {"bleu": nltk.translate.bleu([ref_tokens], hyp_tokens)}
    if with_extras:
        scores["meteor"] = meteor_score.meteor_score([ref_tokens], hyp_tokens)

    # The set-based metrics work on the token vocabularies.
    ref_vocab = set(ref_tokens)
    hyp_vocab = set(hyp_tokens)
    scores["f_measure"] = f_measure(ref_vocab, hyp_vocab)
    if with_extras:
        scores["precision"] = precision(ref_vocab, hyp_vocab)
        scores["recall"] = recall(ref_vocab, hyp_vocab)

    if heavy_mode == 2:
        # Too slow to run in the lighter modes.
        scores["edit_dist"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))
    return scores

def _mean_metrics(results):
    """Average a list of metric dicts key by key; return {} for an empty list."""
    if not results:
        return {}
    count = len(results)
    return {k: sum(each[k] for each in results) / count for k in results[0]}


def doc_formated_text_eval(gt_root_, predict_root_, datatype):
    """
    Evaluate formatted-document OCR output separately for text, math and tables.

    Loads the JSON list at ``predict_root_`` (each record carries a 'label' and
    an 'answer' string), splits both fields with ``split_text``, scores every
    page with ``nougat_per_metrics`` and prints the per-category means as JSON.

    Args:
        gt_root_: Unused; kept for signature compatibility.
        predict_root_: Path to the merged results JSON file.
        datatype: Unused; kept for signature compatibility.

    A category for which no page produced a score (for example no page holds a
    table) is reported as an empty dict instead of aborting the evaluation
    with ZeroDivisionError / IndexError.
    """
    with open(predict_root_, encoding='utf-8') as results_file:
        predicts = json.load(results_file)

    gt_text_split, gt_math_split, gt_table_split = split_text(predicts, 'label')
    pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer')

    def _score_pairs(gts, pres):
        # NOTE(review): the label is passed in nougat_per_metrics' `pred` slot
        # and the answer in its `gt` slot; the order is kept as-is so reported
        # numbers do not change - confirm whether this is intended.
        scored = (nougat_per_metrics(predict_root_, g, p) for g, p in zip(gts, pres))
        # Pairs that are too short yield an empty dict and are left out.
        return [s for s in scored if s]

    text_results = _score_pairs(gt_text_split, pre_text_split)
    math_results = _score_pairs(gt_math_split, pre_math_split)
    table_results = _score_pairs(gt_table_split, pre_table_split)

    mean_dict = {
        "eval question num": len(text_results),
        'text': _mean_metrics(text_results),
        'math': _mean_metrics(math_results),
        'table': _mean_metrics(table_results),
    }

    print(json.dumps(mean_dict, indent=4))

def doc_text_eval(gt_root_, predict_root_, datatype):
    """
    Evaluate plain OCR output and print the mean of every metric as JSON.

    Args:
        gt_root_: Unused; kept for signature compatibility.
        predict_root_: Path to a JSON list of records with 'label' and
            'answer' strings.
        datatype: Unused; kept for signature compatibility.

    Raises:
        AssertionError: If a record cannot be scored (e.g. a missing key).
            The original error is chained as the cause.
    """
    with open(predict_root_, encoding='utf-8') as results_file:
        predicts = json.load(results_file)

    result = []
    for ann in predicts:
        try:
            # NOTE(review): the label is passed in the `pred` slot and the
            # answer in the `gt` slot; kept as-is so reported numbers do not
            # change - confirm whether this is intended.
            ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"])
        except Exception as err:
            # Raised explicitly: an `assert` statement would vanish under -O
            # and silently skip the broken record.
            raise AssertionError("ERROR!!! Check yout output!!!") from err
        if ans:
            result.append(ans)

    mean_dict = {"eval question num": len(result)}
    # With no scorable record there is nothing to average; only the count
    # (0) is reported instead of failing on result[0].
    if result:
        total = len(result)
        for k in result[0]:
            mean_dict[k] = sum(each[k] for each in result) / total
    print(json.dumps(mean_dict, indent=4))

# doc_text_eval("/data/data/DocVQA/val/val_v1.0.json", "/data/codes/GOT_docshot-main/results_cc595k-freeze-docvqa-unfreeze-224/results_final.json", "Doc")


# Script entry point: score <out_path>/results_final.json with the
# formatted-document (text / math / table) evaluation.
doc_formated_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype)

# doc_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype)

================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_scene.py
================================================
import json
import argparse
import nltk
from nltk.metrics import precision, recall, f_measure
import numpy as np
import jieba
# import megfile as mf
import pickle
import pandas as pd
import re
from nltk.translate import meteor_score

# Command-line interface; the arguments are parsed at import time, so this
# module is meant to be run as a script.
parser = argparse.ArgumentParser()

# Directory that contains results_final.json (see the call at the bottom).
parser.add_argument("--out_path", type=str, required=True)
# Forwarded to doc_text_eval as gt_root_.
parser.add_argument("--gt_path", type=str, required=True)
# Forwarded to doc_text_eval as datatype.
parser.add_argument("--datatype", type=str, required=True)
args = parser.parse_args()

def preprocess(text, predict_root_):
    """
    Strip model-specific wrapping from a prediction string.

    Only applies when ``predict_root_`` contains 'InternVL': the text between
    the first "All words in the image:\\n" marker and the next such marker (or
    the end) is kept, then everything from "[UNUSED_TOKEN_145]" on is dropped.
    If the marker is absent the text is kept as-is (previously this raised
    IndexError) and only the trailing token is removed.

    Args:
        text: Raw prediction string.
        predict_root_: Path of the results file, used to detect the model.

    Returns:
        The cleaned string, or ``text`` unchanged for other models.
    """
    if 'InternVL' in predict_root_:
        parts = text.split("All words in the image:\n")
        if len(parts) > 1:
            text = parts[1]
        text = text.split("[UNUSED_TOKEN_145]")[0]
    return text

def contain_chinese_string(text):
    """Return True if ``text`` has at least one character in U+4E00..U+9FA5."""
    return re.search(r'[\u4e00-\u9fa5]', text) is not None

def nougat_per_metrics(predict_root_, pred, gt, minlen=1):
    """
    Score one pair of strings at character level.

    Both strings are turned into lists of characters for BLEU and METEOR and
    into sets of characters for f-measure / precision / recall; the edit
    distance is normalised by the longer string.

    Args:
        predict_root_: Path of the results file; not used by the scoring.
        pred: String scored as the hypothesis.
        gt: String scored as the reference.
        minlen: Both strings must have at least this many characters.

    Returns:
        Dict of metric name -> value; empty when either string is shorter
        than ``minlen``.
    """
    scores = {}

    if not (len(pred) >= minlen and len(gt) >= minlen):
        return scores

    ref_chars = list(gt)
    hyp_chars = list(pred)

    scores["bleu"] = nltk.translate.bleu([ref_chars], hyp_chars)
    scores["meteor"] = meteor_score.meteor_score([ref_chars], hyp_chars)

    ref_charset = set(ref_chars)
    hyp_charset = set(hyp_chars)
    scores["f_measure"] = f_measure(ref_charset, hyp_charset)
    scores["precision"] = precision(ref_charset, hyp_charset)
    scores["recall"] = recall(ref_charset, hyp_charset)

    longest = max(len(pred), len(gt))
    scores["edit_dist"] = nltk.edit_distance(pred, gt) / longest

    return scores

def doc_text_eval(gt_root_, predict_root_, datatype):
    """
    Evaluate scene-text OCR output and print the mean of every metric as JSON.

    Args:
        gt_root_: Unused; kept for signature compatibility.
        predict_root_: Path to a JSON list of records with 'label' and
            'answer' strings.
        datatype: Unused; kept for signature compatibility.

    Raises:
        AssertionError: If a record cannot be scored (e.g. a missing key).
            The original error is chained as the cause.
    """
    with open(predict_root_, encoding='utf-8') as results_file:
        predicts = json.load(results_file)

    result = []
    for ann in predicts:
        try:
            # NOTE(review): the label is passed in the `pred` slot and the
            # answer in the `gt` slot; kept as-is so reported numbers do not
            # change - confirm whether this is intended.
            ans = nougat_per_metrics(predict_root_, ann["label"], ann["answer"])
        except Exception as err:
            # Raised explicitly: an `assert` statement would vanish under -O
            # and silently skip the broken record.
            raise AssertionError("ERROR!!! Check yout output!!!") from err
        if ans:
            result.append(ans)

    mean_dict = {"eval question num": len(result)}
    # With no scorable record there is nothing to average; only the count
    # (0) is reported instead of failing on result[0].
    if result:
        total = len(result)
        for k in result[0]:
            mean_dict[k] = sum(each[k] for each in result) / total
    print(json.dumps(mean_dict, indent=4))


# Script entry point: score <out_path>/results_final.json.
doc_text_eval(args.gt_path, args.out_path + "/results_final.json", args.datatype)

================================================
FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/merge_results.py
================================================
import os
import json
import argparse

def merge_outputs(out_path):
    """
    Concatenate every JSON list in ``out_path`` into ``results_final.json``.

    Each file in the directory is expected to hold a JSON list; the lists are
    concatenated in sorted file-name order, so the output is reproducible.
    A ``results_final.json`` left over from an earlier run is skipped, so
    re-running the merge does not duplicate entries.

    Args:
        out_path: Directory holding the partial result files; the merged file
            is written into the same directory.
    """
    final_name = "results_final.json"
    merged = []
    for name in sorted(os.listdir(out_path)):
        if name == final_name:
            # Our own previous output, not a partial result.
            continue
        with open(os.path.join(out_path, name), encoding='utf-8') as part:
            merged += json.load(part)

    filename = out_path + "/results_final" + ".json"
    with open(filename, 'w', encoding="utf-8") as file_obj:
        json.dump(merged, file_obj, ensure_ascii=False, indent=1)

# Script entry point: merge every partial result file found in --out_path.
parser = argparse.ArgumentParser()
parser.add_argument("--out_path", type=str, required=True)
args = parser.parse_args()

merge_outputs(args.out_path)

================================================
FILE: GOT-OCR-2.0-master/GOT/model/GOT_ocr_2_0.py
================================================
from transformers import AutoConfig, AutoModelForCausalLM, \
                         Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \
                         CLIPVisionModel, CLIPImageProcessor
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from typing import List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from GOT.utils.constants import *
from GOT.model.vision_encoder.vary_b import build_vary_vit_b
from GOT.model.plug.blip_process import BlipImageEvalProcessor

class GOTConfig(Qwen2Config):
    """Qwen2 configuration published under the model type "GOT"."""
    model_type = "GOT"


class GOTQwenModel(Qwen2Model):
    """
    Qwen2 decoder backbone with a vision encoder and a linear projector.

    During ``forward`` the image views are encoded by ``vision_tower_high``
    (without gradients), projected by ``mm_projector_vary`` and written into
    the token embeddings at the positions following each image-start token.
    """
    config_class = GOTConfig

    def __init__(self, config: Qwen2Config):
        super(GOTQwenModel, self).__init__(config)

        self.vision_tower_high = build_vary_vit_b()

        self.mm_projector_vary = nn.Linear(1024, 1024)

    def initialize_vision_modules(
        self,
        vision_tower,
        pretrained_stage1_model=None,
        freeze_vision_tower=False,
        use_im_start_end=False,
        vision_select_layer=-1,
        dtype=torch.float16,
        device="cuda"
    ):
        """
        Move the vision modules to ``dtype``/``device`` and record vision
        settings on the config.

        Args:
            vision_tower: Stored verbatim as ``config.vision_tower``.
            pretrained_stage1_model: Unused.
            freeze_vision_tower: Stored as ``config.freeze_vision_tower``.
            use_im_start_end: Ignored; ``config.use_im_start_end`` is always
                set to True.
            vision_select_layer: Stored as ``config.vision_select_layer``.
            dtype: Target dtype for the vision tower and the projector.
            device: Target device for the vision tower and the projector.

        Returns:
            Dict with ``image_processor``, ``image_processor_high`` (two
            separate 1024-pixel ``BlipImageEvalProcessor`` instances) and
            ``image_token_len`` (256).
        """
        # Legacy Vary entry, not used by GOT itself but still part of the
        # returned dict.
        image_processor = BlipImageEvalProcessor(image_size=1024)
        image_processor_high = BlipImageEvalProcessor(image_size=1024)

        self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
        self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)

        image_token_len = 256

        self.config.vision_tower = vision_tower
        self.config.image_token_len = image_token_len
        self.config.use_im_start_end = True

        self.config.vision_select_layer = vision_select_layer
        self.config.freeze_vision_tower = freeze_vision_tower

        return dict(
            image_processor=image_processor,
            image_processor_high=image_processor_high,
            image_token_len=image_token_len,
        )

    def _encode_image_views(self, views):
        """
        Encode a stack of image views and return their projected features.

        Each view along dim 0 is run through the vision tower on its own
        (gradients disabled) and then through the trainable projector; the
        per-view features are concatenated along the token dimension (dim 1).

        Args:
            views: 4-D tensor unpacked as (P, C, H, W).
        """
        if views.dim() != 4:
            raise ValueError(
                f"expected a 4-D (P, C, H, W) image tensor, got {views.dim()} dims"
            )
        per_view_features = []
        for view in torch.unbind(views):
            with torch.no_grad():
                cnn_feature = self.vision_tower_high(view.unsqueeze(0))
                # Flatten the spatial grid into a token axis: (1, tokens, channels).
                cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
            per_view_features.append(self.mm_projector_vary(cnn_feature))
        return torch.cat(per_view_features, dim=1)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the decoder, first splicing image features into the embeddings.

        Image features are only inserted when ``images`` and ``input_ids`` are
        both given and the call is not a single-token decoding step (or the
        module is in training mode). A call that supplies ``inputs_embeds``
        without ``input_ids`` skips the insertion instead of failing on
        ``input_ids.shape``.

        Args:
            images: Iterable with one entry per sample; ``image[1]`` of every
                entry is the (P, C, H, W) tensor of views that gets encoded.
            (all other arguments are forwarded to ``Qwen2Model.forward``)

        Raises:
            ValueError: If image start/end tokens are unbalanced or an end
                token does not follow its image span.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        if orig_embeds_params is not None:
            with torch.no_grad():
                self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower_high = getattr(self, 'vision_tower_high', None)

        # `images`/`input_ids` are checked before `input_ids.shape` is read.
        splice_images = (
            vision_tower_high is not None
            and images is not None
            and input_ids is not None
            and (input_ids.shape[1] != 1 or self.training)
        )

        if splice_images:
            # Token ids are hard-coded here; the im_*_token values stored on
            # the config are not consulted.
            im_patch_token = 151859
            im_start_token = 151857
            im_end_token = 151858

            image_features = [self._encode_image_views(image[1]) for image in images]

            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            new_input_embeds = []
            for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
                if (cur_input_ids == im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal;
                    # the added term is numerically zero.
                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                    raise ValueError("The number of image start tokens and image end tokens should be the same.")

                image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                    per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
                    num_patches = per_cur_image_features.shape[0]

                    if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                        raise ValueError("The image end token should follow the image start token.")

                    # Keep everything up to and including the start token,
                    # replace the next num_patches embeddings with the image
                    # features, then continue from the end token.
                    cur_input_embeds = torch.cat(
                        (
                            cur_input_embeds[:image_start_token_pos+1],
                            per_cur_image_features,
                            cur_input_embeds[image_start_token_pos + num_patches + 1:]
                        ),
                        dim=0
                    )

                new_input_embeds.append(cur_input_embeds)

            inputs_embeds = torch.stack(new_input_embeds, dim=0)

        # input_ids is dropped on purpose: the (possibly modified) embeddings
        # are what the decoder consumes.
        return super(GOTQwenModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )



class GOTQwenForCausalLM(Qwen2ForCausalLM):
    """Causal-LM head on top of ``GOTQwenModel``; forwards ``images`` to the backbone."""
    config_class = GOTConfig
    # supports_gradient_checkpointing = True

    def __init__(self, config):
        # super(Qwen2ForCausalLM, self) starts the MRO lookup *after*
        # Qwen2ForCausalLM, so that class's own __init__ is bypassed; the
        # backbone and the LM head are created here instead.
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = GOTQwenModel(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``GOTQwenModel`` backbone."""
        return self.model

    # def _set_gradient_checkpointing(self, module, value=False):
    #     if isinstance(module, GOTQwenModel):
    #         module.gradient_checkpointing = value
    # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the backbone, project to vocabulary logits and optionally compute
        the next-token cross-entropy loss.

        ``labels`` (when given) are shifted by one position against the
        logits; ``images`` is passed through to ``GOTQwenModel.forward``.
        Returns a ``CausalLMOutputWithPast`` or, when ``return_dict`` is
        falsy, a tuple with the loss (if any) first.
        """
        # Fall back to the config defaults for any flag that was not given.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # print(input_ids)
        # print(len(images))

        # print(inputs_embeds)

        outputs  = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            return_dict=return_dict
            
        )


        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Logits are upcast to float32 before the loss is computed.
        logits = logits.float()

        # logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Build the keyword arguments for one generation step.

        Trims ``input_ids`` to the tokens not yet covered by the cache,
        derives ``position_ids`` from the attention mask when they are not
        supplied, and passes ``images`` and ``use_cache`` through from
        ``kwargs``.
        """
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                # NOTE(review): `seen_tokens` and `get_max_length()` belong to
                # a specific transformers Cache API generation - confirm they
                # exist in the pinned transformers version.
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: length is read from dim 2 of the first
                # layer's first tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get the placeholder position 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def initialize_vision_tokenizer(
        self, 
        tokenizer, 
        freeze_lm_model=False, 
        pretrained_stage1_model=None,
        device="cuda"
    ):
        """
        Resize the token embeddings to ``len(tokenizer)`` and record the image
        token ids on the config.

        No tokens are added to the tokenizer here (those calls are commented
        out); the ids are hard-coded. ``freeze_lm_model``,
        ``pretrained_stage1_model`` and ``device`` are not used.
        """
        config = self.get_model().config

        # add image patch token <image>
        # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        self.resize_token_embeddings(len(tokenizer))
        # config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]

        config.im_patch_token = 151859

        config.use_im_start_end = True

        # add image start token <im_start> and end token <im_end>
        if config.use_im_start_end:
            # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            # The tokenizer is unchanged since the resize above, so this
            # second resize targets the same length.
            self.resize_token_embeddings(len(tokenizer))
            # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            config.im_start_token, config.im_end_token = 151857, 151858


# Register the "GOT" model type with the transformers Auto* factories
# (runs as a side effect of importing this module).
AutoConfig.register("GOT", GOTConfig)
AutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM)



================================================
FILE: GOT-OCR-2.0-master/GOT/model/__init__.py
================================================

from .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig



================================================
FILE: GOT-OCR-2.0-master/GOT/model/plug/blip_process.py
================================================
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import cv2
import numpy as np

import torch

# from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

class BaseProcessor:
    """Callable that applies ``self.transform``; the default transform is the identity."""

    def __init__(self):
        self.transform = lambda item: item

    def __call__(self, item):
        transform = self.transform
        return transform(item)

class BlipImageBaseProcessor(BaseProcessor):
    """
    Base image processor that owns the channel normalisation transform.

    Args:
        mean: Per-channel mean for ``transforms.Normalize``; defaults to
            (0.48145466, 0.4578275, 0.40821073).
        std: Per-channel standard deviation; defaults to
            (0.26862954, 0.26130258, 0.27577711).
    """

    def __init__(self, mean=None, std=None):
        # Install the identity transform from BaseProcessor so that an
        # instance is callable even before a subclass assigns its own
        # ``self.transform``.
        super().__init__()

        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)


## aug functions
def identity_func(img):
    """No-op augmentation: hand back the very same image object."""
    return img


def autocontrast_func(img, cutoff=0):
    """
    same output as PIL.ImageOps.autocontrast

    Every channel is remapped through its own 256-entry lookup table that
    stretches [low, high] onto [0, 255]. ``cutoff`` is the percentage of
    pixels ignored at each end of the histogram when picking low / high.
    """
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        # Number of pixels to ignore at each end of the histogram.
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            # First bin whose cumulative count exceeds the cut (np.argwhere
            # returns index arrays, so low / high are arrays in this branch).
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            # Same search from the bright end, on the reversed histogram.
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            # Degenerate range: leave the channel unchanged.
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    """
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize

    Every channel is remapped through a lookup table derived from its
    cumulative histogram.
    """
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        # Step size: the pixel count of all non-empty bins except the last,
        # divided by the number of output intervals.
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            # Nothing to spread out: return the channel untouched.
            return ch
        # Histogram shifted by one bin, seeded with half a step.
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    """
    Rotate ``img`` about its centre; ``degree`` is in degrees (as in PIL),
    not radians. The output keeps the input size and uncovered pixels are
    filled with ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    pivot = (width / 2, height / 2)
    rotation = cv2.getRotationMatrix2D(pivot, degree, 1)
    return cv2.warpAffine(img, rotation, (width, height), borderValue=fill)


def solarize_func(img, thresh=128):
    """
    Solarize: values below ``thresh`` are kept, values at or above it are
    inverted (255 - value). Applied through a 256-entry lookup table.
    """
    levels = np.arange(256)
    table = np.where(levels < thresh, levels, 255 - levels)
    table = table.clip(0, 255).astype(np.uint8)
    return table[img]


def color_func(img, factor):
    """
    same output as PIL.ImageEnhance.Color

    The blend between the grayscale image and the original is folded into a
    single 3x3 matrix that is applied to every pixel with one matmul (the
    literal blend-based PIL formulation is much slower).
    """
    colour_part = np.float32(
        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
    )
    gray_part = np.float32([[0.114], [0.587], [0.299]])
    M = colour_part * factor + gray_part
    mixed = np.matmul(img, M)
    return mixed.clip(0, 255).astype(np.uint8)


def contrast_func(img, factor):
    """
    same output as PIL.ImageEnhance.Contrast

    Every value is moved away from (factor > 1) or towards (factor < 1) the
    weighted mean of the per-channel averages, via a 256-entry lookup table.
    """
    channel_weights = np.array([0.114, 0.587, 0.299])
    channel_means = np.mean(img, axis=(0, 1))
    mean = np.sum(channel_means * channel_weights)
    levels = np.arange(256)
    table = ((levels - mean) * factor + mean).clip(0, 255).astype(np.uint8)
    return table[img]


def brightness_func(img, factor):
    """
    Scale every value by ``factor`` and clamp to [0, 255], via a 256-entry
    lookup table (the brightness counterpart of the PIL.ImageEnhance helpers).
    """
    scaled_levels = np.arange(256, dtype=np.float32) * factor
    table = scaled_levels.clip(0, 255).astype(np.uint8)
    return table[img]


def sharpness_func(img, factor):
    """
    Blend the image with a smoothed copy of itself.

    ``factor == 0`` returns the smoothed image, ``factor == 1`` returns the
    input itself; any other factor interpolates / extrapolates between the
    two on the interior pixels only. Results differ from PIL only on the four
    borders; the centre area is the same.
    """
    smoothing_kernel = np.float32([[1, 1, 1], [1, 5, 1], [1, 1, 1]]) / 13
    smoothed = cv2.filter2D(img, -1, smoothing_kernel)
    if factor == 0.0:
        return smoothed
    if factor == 1.0:
        return img

    result = img.astype(np.float32)
    inner_smoothed = smoothed.astype(np.float32)[1:-1, 1:-1, :]
    inner_original = result[1:-1, 1:-1, :]
    # The one-pixel border keeps the original values.
    result[1:-1, 1:-1, :] = inner_smoothed + factor * (inner_original - inner_smoothed)
    return result.astype(np.uint8)


def shear_x_func(img, factor, fill=(0, 0, 0)):
    """Shear along x with the affine matrix [[1, factor, 0], [0, 1, 0]]; size is kept, gaps take ``fill``."""
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)


def translate_x_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Applies the affine matrix [[1, 0, -offset], [0, 1, 0]]; size is kept and
    gaps take ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)


def translate_y_func(img, offset, fill=(0, 0, 0)):
    """
    same output as PIL.Image.transform

    Applies the affine matrix [[1, 0, 0], [0, 1, -offset]]; size is kept and
    gaps take ``fill``.
    """
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)


def posterize_func(img, bits):
    """
    same output as PIL.ImageOps.posterize

    Keeps the ``bits`` most significant bits of every value and clears the
    rest.

    Args:
        img: Integer image array (uint8 in this module).
        bits: Number of high bits to keep, 0..8.
    """
    # The shifted value is reduced to 8 bits *before* it becomes a uint8:
    # 255 << (8 - bits) exceeds 255 for any bits < 8, and NumPy 2 raises
    # OverflowError when asked to convert such a Python int to uint8.
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    out = np.bitwise_and(img, mask)
    return out


def shear_y_func(img, factor, fill=(0, 0, 0)):
    """Shear along y with the affine matrix [[1, 0, 0], [factor, 1, 0]]; size is kept, gaps take ``fill``."""
    height, width = img.shape[0], img.shape[1]
    affine = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(
        img, affine, (width, height), borderValue=fill, flags=cv2.INTER_LINEAR
    )
    return warped.astype(np.uint8)


def cutout_func(img, pad_size, replace=(0, 0, 0)):
    """
    Paint a randomly placed patch of ``img`` with the colour ``replace``.

    The patch is centred on a uniformly drawn pixel, extends ``pad_size // 2``
    pixels in each direction and is clipped to the image. A copy is returned;
    the input is left untouched.
    """
    fill_colour = np.array(replace, dtype=np.uint8)
    height, width = img.shape[0], img.shape[1]
    row_frac, col_frac = np.random.random(2)
    half = pad_size // 2
    centre_row = int(row_frac * height)
    centre_col = int(col_frac * width)
    top, bottom = max(centre_row - half, 0), min(centre_row + half, height)
    left, right = max(centre_col - half, 0), min(centre_col + half, width)
    result = img.copy()
    result[top:bottom, left:right, :] = fill_colour
    return result


### level to args
def enhance_level_to_args(MAX_LEVEL):
    """Build a mapper from a level in [0, MAX_LEVEL] to a 1-tuple factor in [0.1, 1.9]."""
    def level_to_args(level):
        factor = (level / MAX_LEVEL) * 1.8 + 0.1
        return (factor,)

    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(shear, replace_value)``; shear is at most 0.3 with a random sign."""
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * 0.3
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(offset, replace_value)``; offset is at most ``translate_const`` with a random sign."""
    def level_to_args(level):
        magnitude = (level / MAX_LEVEL) * float(translate_const)
        negate = np.random.random() > 0.5
        return (-magnitude if negate else magnitude, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    """Build a mapper from a level to ``(pad_size, replace_value)``; pad_size is truncated to an int and is at most ``cutout_const``."""
    def level_to_args(level):
        pad_size = int((level / MAX_LEVEL) * cutout_const)
        return (pad_size, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    """
    Build the level -> args converter for the solarize op.

    The returned callable scales ``level`` so that ``MAX_LEVEL`` maps to 256,
    truncates to int, and returns it as a one-element tuple.
    """

    def _to_threshold(level):
        fraction = level / MAX_LEVEL
        return (int(fraction * 256),)

    return _to_threshold


def none_level_to_args(level):
    """Converter for ops that take no extra arguments: always returns ``()``."""
    return tuple()


def posterize_level_to_args(MAX_LEVEL):
    """
    Build the level -> args converter for the posterize op.

    The returned callable scales ``level`` so that ``MAX_LEVEL`` maps to 4,
    truncates to int, and returns it as a one-element tuple.
    """

    def _to_bits(level):
        fraction = level / MAX_LEVEL
        return (int(fraction * 4),)

    return _to_bits


def rotate_level_to_args(MAX_LEVEL, replace_value):
    """
    Build the level -> args converter for the rotate op.

    The returned callable scales ``level`` so that ``MAX_LEVEL`` maps to 30,
    negates the value when a uniform draw is below 0.5, and returns
    ``(signed_value, replace_value)``.
    """

    def _to_angle(level):
        angle = (level / MAX_LEVEL) * 30
        # Random direction: one uniform draw per call (note the `<`).
        flip = np.random.random() < 0.5
        return (-angle if flip else angle, replace_value)

    return _to_angle


# Registry of augmentation ops: op name -> callable.  RandomAugment and
# VideoRandomAugment invoke an entry as ``func_dict[name](img, *args)``, where
# ``args`` comes from the same-named entry in ``arg_dict`` below, so the two
# dictionaries must share the same keys.
# NOTE: ``cutout_func`` is defined above but is not registered here.
func_dict = {
    "Identity": identity_func,
    "AutoContrast": autocontrast_func,
    "Equalize": equalize_func,
    "Rotate": rotate_func,
    "Solarize": solarize_func,
    "Color": color_func,
    "Contrast": contrast_func,
    "Brightness": brightness_func,
    "Sharpness": sharpness_func,
    "ShearX": shear_x_func,
    "TranslateX": translate_x_func,
    "TranslateY": translate_y_func,
    "Posterize": posterize_func,
    "ShearY": shear_y_func,
}

# Translation magnitude produced at level == MAX_LEVEL (see
# translate_level_to_args).
translate_const = 10
# Top of the level scale: every converter below divides the level by this.
MAX_LEVEL = 10
# Value forwarded as the second argument to the Rotate/Shear/Translate ops.
replace_value = (128, 128, 128)
# op name -> callable(level) returning the tuple of extra positional arguments
# for the same-named function in ``func_dict``.  The default op list of
# RandomAugment / VideoRandomAugment is ``list(arg_dict.keys())``.
# NOTE: ``cutout_level_to_args`` is defined above but has no entry here.
arg_dict = {
    "Identity": none_level_to_args,
    "AutoContrast": none_level_to_args,
    "Equalize": none_level_to_args,
    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
    "Solarize": solarize_level_to_args(MAX_LEVEL),
    "Color": enhance_level_to_args(MAX_LEVEL),
    "Contrast": enhance_level_to_args(MAX_LEVEL),
    "Brightness": enhance_level_to_args(MAX_LEVEL),
    "Sharpness": enhance_level_to_args(MAX_LEVEL),
    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
    "Posterize": posterize_level_to_args(MAX_LEVEL),
    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
}


class RandomAugment(object):
    """
    RandAugment-style transform for a single image.

    On every call ``N`` op names are sampled (with replacement) from
    ``augs``; each sampled op is then applied with probability 0.5 at
    magnitude ``M``.

    Args:
        N: number of ops sampled per call.
        M: magnitude level handed to the ``arg_dict`` converters.
        isPIL: when True the input is first converted with ``np.array``.
            The return value is a NumPy array in that case, not a PIL image.
        augs: iterable of op names (keys of ``arg_dict`` / ``func_dict``).
            ``None`` or an empty sequence selects every registered op.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # ``None`` replaces the former mutable ``[]`` default.  The names are
        # copied so that later mutation of the caller's list cannot change
        # the behaviour of this instance.
        if augs:
            self.augs = list(augs)
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Sample ``N`` ops and return them as ``(name, prob, level)`` tuples."""
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        """Apply the randomly sampled ops to ``img`` and return the result."""
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            # Skip this op with probability (1 - prob).
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img


class VideoRandomAugment(object):
    """
    RandAugment-style transform for a clip of frames.

    One set of ``N`` distinct ops and one apply/skip mask are sampled per
    call and reused for every frame.  The op arguments, however, are computed
    per frame through ``arg_dict``, so converters that draw a random sign may
    pick different signs for different frames.

    Args:
        N: number of distinct ops sampled per call (sampling is without
            replacement, so ``N`` must not exceed ``len(augs)``).
        M: magnitude level handed to the ``arg_dict`` converters.
        p: an op is applied only when a uniform draw is greater than ``p``,
            i.e. ``p`` acts as the probability of skipping an op.
        tensor_in_tensor_out: when True the input is a torch tensor and is
            converted to a uint8 NumPy array first.  The output is a float
            torch tensor regardless of this flag.
        augs: iterable of op names (keys of ``arg_dict`` / ``func_dict``).
            ``None`` or an empty sequence selects every registered op.
    """

    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=None):
        self.N = N
        self.M = M
        self.p = p
        self.tensor_in_tensor_out = tensor_in_tensor_out
        # ``None`` replaces the former mutable ``[]`` default.  The names are
        # copied so that later mutation of the caller's list cannot change
        # the behaviour of this instance.
        if augs:
            self.augs = list(augs)
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        """Sample ``N`` distinct ops and return ``(name, level)`` tuples."""
        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
        return [(op, self.M) for op in sampled_ops]

    def __call__(self, frames):
        """Augment ``frames`` of shape (b, h, w, 3); returns a float tensor."""
        assert (
            frames.shape[-1] == 3
        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."

        if self.tensor_in_tensor_out:
            frames = frames.numpy().astype(np.uint8)

        num_frames = frames.shape[0]

        # List repetition: every frame shares the same ops and the same mask.
        ops = num_frames * [self.get_random_ops()]
        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]

        frames = torch.stack(
            list(map(self._aug, frames, ops, apply_or_not)), dim=0
        ).float()

        return frames

    def _aug(self, img, ops, apply_or_not):
        """Apply the selected ops to one frame and wrap it as a tensor."""
        for i, (name, level) in enumerate(ops):
            if not apply_or_not[i]:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return torch.from_numpy(img)


# if __name__ == "__main__":
#     a = RandomAugment()
#     img = np.random.randn(32, 32, 3)
#     a(img)






class BlipImageTrainProcessor(BlipImageBaseProcessor):
    """
    Training-time image pipeline.

    Steps, in order: random resized crop (bicubic), RandomAugment restricted
    to a few photometric ops, conversion to a tensor, then ``self.normalize``
    (presumably set up by the base class from ``mean``/``std`` -- verify
    against BlipImageBaseProcessor).

    Args:
        image_size: output size of the random resized crop.
        mean, std: forwarded to the base-class constructor.
        min_scale, max_scale: area-scale range of the random resized crop.
    """

    def __init__(
        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
    ):
        super().__init__(mean=mean, std=std)

        random_crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, max_scale),
            interpolation=InterpolationMode.BICUBIC,
        )
        # Two ops per image at level 5.  Horizontal flip, AutoContrast,
        # ShearX/ShearY, TranslateX/TranslateY and Rotate are not enabled in
        # this pipeline.
        enabled_ops = ["Identity", "Brightness", "Sharpness", "Equalize"]
        photometric = RandomAugment(2, 5, isPIL=True, augs=enabled_ops)

        self.transform = transforms.Compose(
            [random_crop, photometric, transforms.ToTensor(), self.normalize]
        )

    def __call__(self, item):
        """Run the training pipeline on one image."""
        processed = self.transform(item)
        return processed


class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """
    Evaluation-time image pipeline with no randomness.

    Steps, in order: bicubic resize to ``image_size`` x ``image_size``,
    conversion to a tensor, then ``self.normalize`` (presumably set up by the
    base class from ``mean``/``std`` -- verify against
    BlipImageBaseProcessor).
    """

    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        target_size = (image_size, image_size)
        steps = [
            transforms.Resize(target_size, interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        """Run the evaluation pipeline on one image."""
        processed = self.transform(item)
        return processed


# if __name__ == "__main__":
#     a = BlipImageTrainProcessor(image_size=1024)
#     # img = np.random.randn(1024, 1024, 3)
#     # x = torch.zeros(1024, 1024, 3)
#     x = Image.open("/data/codes/GOT-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB")
#     print(x.size)
#     y = a(x)

#     print(y.size())


================================================
FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/__init__.py
================================================



================================================
FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/vary_b.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Type

from functools import partial

import torch
import torch.nn as nn

from typing import Type

# from GOT.model.vision_encoder.vitg_qwen import Resampler
import math


class Projector(nn.Module):
    """
    Attention-pooling projector: Resampler -> LayerNorm -> learned matrix.

    NOTE(review): ``Resampler`` is neither defined nor imported in this
    module (its import from ``GOT.model.vision_encoder.vitg_qwen`` is
    commented out), so constructing this class raises ``NameError`` until
    that symbol is made available.

    Args:
        width: feature dimension of the incoming tokens (passed to the
            resampler as ``kv_dim``).
        n_queries: number of query tokens; the resampler grid size is
            ``int(sqrt(n_queries))``.
        output_dim: embedding dimension of the projected tokens.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        width: int,
        n_queries: int = 256,
        output_dim: int = 4096,
        **kwargs
    ):
        super().__init__()

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_pool = Resampler(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
        )
        self.ln_post = norm_layer(output_dim)
        # Square projection initialised with N(0, 1) entries scaled by output_dim ** -0.5.
        self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))

    def forward(self, x: torch.Tensor):
        """Pool ``x`` with the resampler, normalise, and right-multiply by ``proj``."""
        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj

        return x


class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        embedding_dim: size of the input and output features.
        mlp_dim: size of the hidden layer.
        act: activation module class, instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Expand to the hidden width, then project back.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
class LayerNorm2d(nn.Module):
    """Layer normalisation over the channel axis (dim 1) of a (B, C, H, W) tensor.

    Each spatial position is normalised across its channels, then scaled and
    shifted by per-channel learnable ``weight`` and ``bias``.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_mean = x.mean(1, keepdim=True)
        # Biased variance across channels.
        channel_var = (x - channel_mean).pow(2).mean(1, keepdim=True)
        normalized = (x - channel_mean) / torch.sqrt(channel_var + self.eps)
        # Broadcast the (C,) parameters over the two spatial axes.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """ViT image encoder with a convolutional neck and two stride-2 convolutions.

    Pipeline: patch embedding -> optional absolute position embedding ->
    ``depth`` transformer blocks -> neck (1x1 conv, LayerNorm2d, 3x3 conv,
    LayerNorm2d) -> ``net_2`` -> ``net_3``.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes attend globally (window 0).
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        # Two stride-2 convolutions after the neck: out_chans -> 512 -> 1024
        # channels, each roughly halving the spatial size.  The first one
        # takes its input width from ``out_chans`` (previously a hard-coded
        # 256) so that a non-default ``out_chans`` still matches the neck.
        self.net_2 = nn.Conv2d(out_chans, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a (B, in_chans, H, W) batch into a (B, 1024, h, w) feature map."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # Added as-is (no interpolation): the patch grid must match the
            # grid the embedding was created for.
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        # Blocks work channels-last; the convolutional layers need channels-first.
        x = self.neck(x.permute(0, 3, 1, 2))
        x = self.net_2(x)
        x = self.net_3(x)

        return x


class Block(nn.Module):
    """Pre-norm transformer block whose attention is either windowed or global."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Hidden width of the MLP as a multiple of ``dim``.
            qkv_bias (bool): Whether the query/key/value projection has a bias.
            norm_layer (nn.Module): Normalization layer class.
            act_layer (nn.Module): Activation layer class for the MLP.
            use_rel_pos (bool): Whether attention adds relative positional terms.
            rel_pos_zero_init (bool): Forwarded to the attention layer.
            window_size (int): Side of the attention window; 0 selects global
                attention.
            input_size (tuple(int, int) or None): Token grid size, used to size
                the relative positional tables when attention is global.
        """
        super().__init__()
        # Windowed attention sees window-sized grids; global attention sees
        # the full token grid.
        if window_size == 0:
            attn_input_size = input_size
        else:
            attn_input_size = (window_size, window_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )

        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=hidden_dim, act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        windowed = self.window_size > 0
        residual = x
        x = self.norm1(x)

        if windowed:
            # Remember the unpadded grid so it can be restored afterwards.
            height, width = x.shape[1], x.shape[2]
            x, padded_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        if windowed:
            x = window_unpartition(x, self.window_size, padded_hw, (height, width))

        x = residual + x
        return x + self.mlp(self.norm2(x))


class Attention(nn.Module):
    """Multi-head self-attention over a (B, H, W, C) token grid, optionally with
    decomposed relative position terms."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): Whether the fused q/k/v projection has a bias.
            use_rel_pos (bool): Whether to add relative positional terms to
                the attention logits.
            rel_pos_zero_init (bool): Accepted but not read here; the
                relative positional tables are always created as zeros.
            input_size (tuple(int, int) or None): Token grid size; required
                when ``use_rel_pos`` is True because it sizes the tables.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if not self.use_rel_pos:
            return

        assert (
            input_size is not None
        ), "Input size must be provided if using relative positional encoding."
        # One table per axis, with 2 * size - 1 rows (one per relative offset).
        grid_h, grid_w = input_size[0], input_size[1]
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * grid_h - 1, head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * grid_w - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, height, width, _ = x.shape
        num_tokens = height * width

        # Fused projection, rearranged to (3, B, nHead, H * W, head_dim).
        packed = self.qkv(x).reshape(batch, num_tokens, 3, self.num_heads, -1)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Fold heads into the batch axis: each of q, k, v is (B * nHead, H * W, head_dim).
        q, k, v = packed.reshape(3, batch * self.num_heads, num_tokens, -1).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            grid = (height, width)
            scores = add_decomposed_rel_pos(
                scores, q, self.rel_pos_h, self.rel_pos_w, grid, grid
            )

        weights = scores.softmax(dim=-1)
        # Back to channels-last with heads concatenated along the feature axis.
        mixed = (weights @ v).view(batch, self.num_heads, height, width, -1)
        mixed = mixed.permute(0, 2, 3, 1, 4).reshape(batch, height, width, -1)
        return self.proj(mixed)


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a token grid into non-overlapping square windows.

    The grid is zero-padded at the bottom and right when its height or width
    is not a multiple of ``window_size``.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): side length of each window.

    Returns:
        windows: tensor with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): height and width after padding, before partition.
    """
    batch, height, width, channels = x.shape

    # Amount needed to reach the next multiple of window_size (0 if aligned).
    extra_h = -height % window_size
    extra_w = -width % window_size
    if extra_h > 0 or extra_w > 0:
        # F.pad lists pads from the last axis backwards: C, then W, then H.
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    padded_h = height + extra_h
    padded_w = width + extra_w

    rows = padded_h // window_size
    cols = padded_w // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes together, then flatten them into the batch.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = tiled.view(-1, window_size, window_size, channels)
    return windows, (padded_h, padded_w)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Reassemble windows into a token grid and strip any padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): side length of each window.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: tensor with [B, H, W, C].
    """
    padded_h, padded_w = pad_hw
    height, width = hw

    windows_per_image = padded_h * padded_w // window_size // window_size
    batch = windows.shape[0] // windows_per_image

    rows = padded_h // window_size
    cols = padded_w // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave window rows with in-window rows (likewise for columns).
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    grid = grid.view(batch, padded_h, padded_w, -1)

    if padded_h <= height and padded_w <= width:
        return grid
    # Drop the rows/columns that were added as padding.
    return grid[:, :height, :width, :].contiguous()


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up the relative positional embedding for every (query, key) pair
    along one axis.

    Args:
        q_size (int): number of query positions along the axis.
        k_size (int): number of key positions along the axis.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C).
    """
    table_len = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == table_len:
        table = rel_pos
    else:
        # Linearly resample the table to the length this grid needs.
        # F.interpolate expects (N, C, L), hence the permutes around it.
        channels_first = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        resampled = F.interpolate(channels_first, size=table_len, mode="linear")
        table = resampled.reshape(-1, table_len).permute(1, 0)

    # When q and k differ in length, stretch the shorter axis' coordinates.
    q_stretch = max(k_size / q_size, 1.0)
    k_stretch = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_stretch
    k_coords = torch.arange(k_size)[None, :] * k_stretch
    # Shift so that the smallest offset maps to row 0 of the table.
    offsets = (q_coords - k_coords) + (k_size - 1) * k_stretch

    return table[offsets.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Add decomposed relative positional terms (one per axis) to attention logits,
    following :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention logits with shape (B, q_h * q_w, k_h * k_w).
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial size of the query grid, (q_h, q_w).
        k_size (Tuple): spatial size of the key grid, (k_h, k_w).

    Returns:
        attn (Tensor): logits with the height and width terms added.
    """
    query_h, query_w = q_size
    key_h, key_w = k_size
    table_h = get_rel_pos(query_h, key_h, rel_pos_h)
    table_w = get_rel_pos(query_w, key_w, rel_pos_w)

    batch, _, channels = q.shape
    q_grid = q.reshape(batch, query_h, query_w, channels)
    # Per-axis dot products between each query and the embedding of its
    # offset to every key row / key column.
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w)

    # The height term broadcasts over key columns, the width term over key rows.
    logits = attn.view(batch, query_h, query_w, key_h, key_w)
    logits = logits + bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)

    return logits.view(batch, query_h * query_w, key_h * key_w)


class PatchEmbed(nn.Module):
    """
    Turn an image into a channels-last grid of patch embeddings using a
    single strided convolution.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection convolution.
            stride (Tuple): stride of the projection convolution.
            padding (Tuple): padding of the projection convolution.
            in_chans (int): Number of input image channels.
            embed_dim (int): Number of output channels (embedding size).
        """
        super().__init__()

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, **conv_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embedded = self.proj(x)
        # (B, C, H, W) -> (B, H, W, C)
        return embedded.permute(0, 2, 3, 1)



def build_vary_vit_b(checkpoint=None):
    """
    Build the ViT-B sized image encoder.

    Configuration: embedding width 768, 12 blocks, 12 heads, global attention
    in blocks 2, 5, 8 and 11.  ``checkpoint`` is forwarded unchanged to
    ``_build_vary``.
    """
    vit_b_config = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_vary(checkpoint=checkpoint, **vit_b_config)


def _build_vary(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """
    Construct an ``ImageEncoderViT`` with the fixed Vary settings.

    Fixed settings: 1024-pixel input, 16-pixel patches, 256 neck channels,
    window size 14, relative positional embeddings enabled, LayerNorm with
    eps=1e-6.

    Args:
        encoder_embed_dim: transformer embedding width.
        encoder_depth: number of transformer blocks.
        encoder_num_heads: attention heads per block.
        encoder_global_attn_indexes: indexes of the blocks that use global
            (non-windowed) attention.
        checkpoint: currently NOT loaded.  The weight-loading code in this
            builder is disabled, so the returned encoder always has freshly
            initialised weights; a warning is emitted when a value is passed.

    Returns:
        The constructed (randomly initialised) image encoder.
    """
    if checkpoint is not None:
        # Passing a checkpoint used to be silently ignored, which makes a
        # randomly initialised encoder look like a loaded one.
        import warnings

        warnings.warn(
            "_build_vary: `checkpoint` is ignored; the encoder is returned "
            "with freshly initialised weights.",
            stacklevel=2,
        )

    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )

    # Disabled loader, kept for reference: it read the file with torch.load,
    # kept the entries whose key contains 'vision_tower', stripped the first
    # 19 characters of each key, and called load_state_dict(..., strict=True).
    return image_encoder




if __name__ == '__main__':
    # Smoke test: push a dummy batch through the ViT-B encoder and print the
    # shape of the resulting token sequence.
    dummy_batch = torch.zeros(2, 3, 1024, 1024)

    encoder = build_vary_vit_b(checkpoint ='/mnt/shared-storage/tenant/hypertext/xpkong/jycode/checkpoint/pytorch_model.bin')

    feature_map = encoder(dummy_batch)
    # (B, C, h, w) -> (B, h * w, C)
    token_seq = feature_map.flatten(2).permute(0, 2, 1)
    print(token_seq.shape)

================================================
FILE: GOT-OCR-2.0-master/GOT/train/train.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import logging
import pathlib
import torch
import transformers

# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import GOTLlamaForCausalLM
from GOT.model import *
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import smart_tokenizer_and_embedding_resize
from GOT.model.vision_encoder.sam import build_sam_vit_b
from GOT.model.vision_encoder.swin_transformer import build_swin_transformer
def train():
    """
    Command-line training entry point for the LLaMA-based GOT model.

    Parses ModelArguments / DataArguments / TrainingArguments from the command
    line, loads the model and tokenizer, sets up the vision modules and vision
    tokens, decides which parameters are trainable, builds the data module and
    runs ``GOTTrainer`` (resuming when the output directory already contains a
    ``checkpoint-*`` entry).  Finally the trainer state and model are saved to
    ``training_args.output_dir``.

    NOTE(review): several paths below are hard-coded to one cluster's
    filesystem; they must exist for this function to run.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = GOTLlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # The tokenizer comes from a hard-coded checkpoint directory, not from
    # model_args.model_name_or_path.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        '/data/hypertext/xpkong/newcode/checkpoints/kly-vary-1025-cc595-pretrain/',
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    # tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,)

    # # model = AutoModelForCausalLM.from_pretrained("/data/public/ucaswei/cache/Qwen/qwen/", device_map="cuda", trust_remote_code=True).eval()

    # model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda')


    # Pad-token setup: for conversation version "v0" (or the decapoda LLaMA
    # checkpoint) a dedicated pad token is added when missing and, for paths
    # containing "llama", the eos/bos/unk tokens are (re)declared; otherwise
    # the unk token doubles as the pad token.
    if data_args.conversation_version == "v0" or "models--decapoda-research--llama-7b-hf" in model_args.model_name_or_path:
        if tokenizer.pad_token is None:
            smart_tokenizer_and_embedding_resize(
                special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
                tokenizer=tokenizer,
                model=model,
            )
        if "llama" in model_args.model_name_or_path:
            tokenizer.add_special_tokens({
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            })
    else:
        tokenizer.pad_token = tokenizer.unk_token

    # tokenizer.pad_token = DEFAULT_UNK_TOKEN
    # tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.add_special_tokens({'pad_token':'<|endoftext|>'})

    # Working precision: float32 unless fp16/bf16 is requested (bf16 wins when
    # both flags are set because it is checked last).
    dtype = torch.float32
    if training_args.fp16:
        dtype = torch.float16
    if training_args.bf16:
        dtype = torch.bfloat16

    vision_tower_dict = model.get_model().initialize_vision_modules(
        vision_tower=model_args.vision_tower,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        freeze_vision_tower=model_args.freeze_vision_tower,
        use_im_start_end=model_args.use_im_start_end,
        vision_select_layer=model_args.vision_select_layer,
        dtype=dtype,
        device=training_args.device
    )

    model.initialize_vision_tokenizer(
        tokenizer=tokenizer, 
        freeze_lm_model=model_args.freeze_lm_model, 
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        device=training_args.device,
    )
    # The vision towers set up above are overwritten here with models loaded
    # from hard-coded paths.
    # NOTE(review): build_sam_vit_b is imported from
    # GOT.model.vision_encoder.sam, which does not appear in the repository's
    # directory listing -- confirm that module is available.
    model.get_model().vision_tower = transformers.CLIPVisionModel.from_pretrained(
        '/data/public/ucaswei/pretrain/vit-large-patch14')
    model.get_model().vision_tower_high = build_sam_vit_b(checkpoint='/data/hypertext/xpkong/newcode/checkpoints/kly-sam-opt-all-1023-new/pytorch_model.bin')
    # model.get_model().mm_projector = create_perciever()

    model.to(dtype=dtype, device=training_args.device)
    # 'image_processor_high
    # data_args.image_token_len = vision_tower_dict['image_token_len']
    # The image token length is fixed at 256 instead of being taken from
    # vision_tower_dict.
    data_args.image_token_len = 256
    data_args.image_processor = vision_tower_dict['image_processor']
    data_args.image_processor_high = vision_tower_dict['image_processor_high']
    data_args.use_im_start_end = model_args.use_im_start_end

    # mixed relation, to be fixed
    # With a frozen LM, everything is frozen first and then the projector and
    # the input embeddings are re-enabled (plus the vision tower unless it is
    # frozen as well).
    if model_args.freeze_lm_model:
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True
        # for p in model.get_model().vision_encoder.parameters():
        #     p.requires_grad = True
        # for p in model.get_model().chatt.parameters():
        #     p.requires_grad = True
        for p in model.get_input_embeddings().parameters():
            p.requires_grad = True
        # conv_final
        # for p in model.get_model().conv_final.parameters():
        #     p.requires_grad = True


        if not model_args.freeze_vision_tower:

            model.get_model().vision_tower.requires_grad_(True)
            # for i in range(20):
            #     model.get_model().vision_tower.vision_model.encoder.layers[i].requires_grad_(False)
            # model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False)
            # model.get_model().vision_tower.vision_model.embeddings.requires_grad_(False)
            # model.get_model().vision_tower.vision_model.pre_layrnorm.requires_grad_(False)
            # model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False)

            #         for p in model.get_model().vision_encoder.parameters():
            # p.requires_grad = True

            # for n, p in model.named_parameters():
            #     print(n, p.requires_grad)

    # Applied regardless of freeze_lm_model.  Only ``vision_tower`` is touched
    # here; ``vision_tower_high`` is not.
    if model_args.freeze_vision_tower:
        model.get_model().vision_tower.requires_grad_(False)
    
    # Trainable-parameter count, reported in units of 2**20 parameters.
    params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
    print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")

    # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    # if len(params_no_grad) > 0:
    #     if training_args.fsdp is not None and len(training_args.fsdp) > 0:
    #         if len(params_no_grad) < 10:
    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
    #         else:
    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
    #         print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
    #         print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.  See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")

    #         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
    #         def patch_FSDP_use_orig_params(func):
    #             def wrap_func(*args, **kwargs):
    #                 use_orig_params = kwargs.pop('use_orig_params', True)
    #                 return func(*args, **kwargs, use_orig_params=use_orig_params)
    #             return wrap_func

    #         FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)


    data_module = make_supervised_data_module(
        interleave=training_args.interleave, 
        with_box=training_args.with_box, 
        tokenizer=tokenizer, 
        data_args=data_args
    )

    trainer = GOTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module)

    # Resume automatically when the output directory already holds at least
    # one "checkpoint-*" entry.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    # _safe_save is provided by GOTTrainer (not visible here); presumably it
    # writes the final model weights -- verify against the trainer.
    trainer._safe_save(output_dir=training_args.output_dir)


# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    train()




================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_GOT.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import logging
import pathlib
import torch
# torch.set_num_threads(1)
import transformers

# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import *
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import smart_tokenizer_and_embedding_resize
from GOT.model.vision_encoder.vary_b import build_vary_vit_b
import os

# os.environ['NCCL_IB_DISABLE'] = '1'
# Process-wide defaults applied at import time.  ``setdefault`` keeps any value
# that the launcher or the user already exported instead of overwriting it.
os.environ.setdefault('NCCL_DEBUG', 'INFO')
# NOTE(review): hard-coded endpoint carried over from the original code;
# presumably an internal object-storage service -- confirm before relying on it.
os.environ.setdefault('OSS_ENDPOINT', "http://oss.i.shaipower.com")

def train():
    """Command-line entry point for GOT training.

    Parses ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``,
    loads the tokenizer and ``GOTQwenForCausalLM`` from
    ``model_args.model_name_or_path``, initialises the vision modules and the
    vision tokenizer, builds the supervised data module and runs
    ``GOTTrainer``.  Training resumes from a checkpoint when ``output_dir``
    already contains ``checkpoint-*`` entries.  The trainer state and the
    model are saved at the end.
    """
    arg_parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = arg_parser.parse_args_into_dataclasses()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        padding_side="right",
        model_max_length=training_args.model_max_length,
    )

    model = GOTQwenForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        use_safetensors=True,
    )

    # Install '<|endoftext|>' as the pad token through the project helper.
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict={"pad_token": "<|endoftext|>"},
        tokenizer=tokenizer,
        model=model,
    )

    # bf16 takes precedence over fp16; full precision is the fallback.
    if training_args.bf16:
        dtype = torch.bfloat16
    elif training_args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    vision_tower_dict = model.get_model().initialize_vision_modules(
        vision_tower=model_args.vision_tower,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        freeze_vision_tower=model_args.freeze_vision_tower,
        use_im_start_end=model_args.use_im_start_end,
        vision_select_layer=model_args.vision_select_layer,
        dtype=dtype,
        device=training_args.device,
    )

    model.initialize_vision_tokenizer(
        tokenizer=tokenizer,
        freeze_lm_model=model_args.freeze_lm_model,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        device=training_args.device,
    )

    model.to(dtype=dtype, device=training_args.device)

    # The image token count is fixed at 256 here rather than read from
    # ``vision_tower_dict``; the processors do come from that dict.
    data_args.image_token_len = 256
    data_args.image_processor = vision_tower_dict['image_processor']
    data_args.image_processor_high = vision_tower_dict['image_processor_high']
    data_args.use_im_start_end = model_args.use_im_start_end

    # mixed relation, to be fixed
    if model_args.freeze_lm_model:
        # Freeze everything, then re-enable both projectors and the input
        # embeddings.
        model.requires_grad_(False)
        backbone = model.get_model()
        trainable_modules = (
            backbone.mm_projector,
            backbone.mm_projector_vary,
            model.get_input_embeddings(),
        )
        for module in trainable_modules:
            for param in module.parameters():
                param.requires_grad = True

    trainable_count = sum(
        param.numel() for _, param in model.named_parameters() if param.requires_grad
    )
    print(f"Number of Mapping Trainable Parameters: {trainable_count / (1 << 20):.2f} M")

    data_module = make_supervised_data_module(
        interleave=training_args.interleave,
        with_box=training_args.with_box,
        tokenizer=tokenizer,
        data_args=data_args,
    )

    trainer = GOTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module,
    )

    # Resume automatically when a previous run left checkpoints behind.
    checkpoint_dirs = pathlib.Path(training_args.output_dir).glob("checkpoint-*")
    train_kwargs = {"resume_from_checkpoint": True} if any(True for _ in checkpoint_dirs) else {}
    trainer.train(**train_kwargs)

    trainer.save_state()
    trainer._safe_save(output_dir=training_args.output_dir)


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()


================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_flash_attn.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
# NOTE(review): the repository listing shows only arguments.py, constants.py,
# conversation.py and utils.py under GOT/utils -- confirm that
# ``llama_flash_attn_monkey_patch`` actually exists, otherwise this import fails.
from GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Apply the patch at import time, before the training module below is imported.
replace_llama_attn_with_flash_attn()

# Imported after the patch on purpose; do not move this above the call.
from GOT.train.train import train

# Script entry point: run the regular training loop with the patch applied.
if __name__ == "__main__":
    train()


================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_lora.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import logging
import pathlib
import torch
import transformers

# from GOT.train.trainer import GOTTrainer
# from GOT.train.trainer_vit_llrd import GOTTrainer
from GOT.train.trainer_vit_fixlr import GOTTrainer
from GOT.model import GOTLlamaForCausalLM
from GOT.data import make_supervised_data_module
from GOT.utils.arguments import *
from GOT.utils.constants import *
from GOT.utils.utils import *


# def find_all_linear_names(model):
#     cls = torch.nn.Linear
#     lora_module_names = set()
#     for name, module in model.named_modules():
#         if isinstance(module, cls):
#             names = name.split('.')
#             lora_module_names.add(names[0] if len(names) == 1 else names[-1])


#     if 'lm_head' in lora_module_names: # needed for 16-bit
#         lora_module_names.remove('lm_head')
#     return list(lora_module_names)

def train():
    """Command-line entry point for GOT training with optional LoRA adapters.

    Steps, in order:

    1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.
    2. Load the tokenizer and ``GOTQwenForCausalLM`` from
       ``model_args.model_name_or_path`` and register the pad token.
    3. When ``training_args.lora_enable`` is set, wrap the model with PEFT
       LoRA adapters.
    4. Initialise the vision modules / vision tokenizer and choose which
       parameters are trainable.
    5. Build the data module and run ``GOTTrainer``, resuming when
       ``output_dir`` already contains ``checkpoint-*`` entries.
    6. Save the trainer state and either the LoRA / non-LoRA weights or the
       full model.
    """
    # Function-scope imports: this module has no top-level ``import os`` and
    # explicitly imports only ``GOTLlamaForCausalLM`` from ``GOT.model``, yet
    # the body below needs both of these names.
    import os
    from GOT.model import GOTQwenForCausalLM

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the tokenizer from the same location as the model weights instead
    # of a machine-specific absolute path, so the script runs on any host.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        padding_side="right",
        model_max_length=training_args.model_max_length,
    )

    model = GOTQwenForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        low_cpu_mem_usage=True,
        device_map='cuda',
    )

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        tokenizer=tokenizer,
        model=model,
    )

    # bf16 takes precedence over fp16; full precision is the fallback.
    dtype = torch.float32
    if training_args.fp16:
        dtype = torch.float16
    if training_args.bf16:
        dtype = torch.bfloat16

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        # NOTE(review): ``find_all_linear_names`` is only present as
        # commented-out code in this file; presumably it is provided by one of
        # the star imports -- confirm.
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        logging.warning("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    vision_tower_dict = model.get_model().initialize_vision_modules(
        vision_tower=model_args.vision_tower,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        freeze_vision_tower=model_args.freeze_vision_tower,
        use_im_start_end=model_args.use_im_start_end,
        vision_select_layer=model_args.vision_select_layer,
        dtype=dtype,
        device=training_args.device
    )

    model.initialize_vision_tokenizer(
        tokenizer=tokenizer,
        freeze_lm_model=model_args.freeze_lm_model,
        pretrained_stage1_model=model_args.pretrained_stage1_model,
        device=training_args.device,
    )

    # NOTE(review): ``create_clip_vit_g`` and ``create_perciever`` are not
    # defined or explicitly imported in this file; presumably they come from a
    # star import -- confirm they exist before running this script.
    model.get_model().vision_tower = create_clip_vit_g(448)
    model.get_model().mm_projector = create_perciever()
    model.to(dtype=dtype, device=training_args.device)

    data_args.image_token_len = vision_tower_dict['image_token_len']
    data_args.image_processor = vision_tower_dict['image_processor']
    data_args.image_processor_high = vision_tower_dict['image_processor_high']
    data_args.use_im_start_end = model_args.use_im_start_end

    # mixed relation, to be fixed
    if model_args.freeze_lm_model:
        # Freeze everything, then re-enable the listed sub-modules.
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True
        for p in model.get_input_embeddings().parameters():
            p.requires_grad = True
        for p in model.get_model().conv_final.parameters():
            p.requires_grad = True
        for p in model.get_model().vision_encoder.parameters():
            p.requires_grad = True

        if not model_args.freeze_vision_tower:
            # Train the vision tower except its last encoder layer and the
            # post layer norm.
            model.get_model().vision_tower.requires_grad_(True)
            model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False)
            model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False)

            # Dump the trainable flag of every parameter for inspection.
            for n, p in model.named_parameters():
                print(n, p.requires_grad)

    params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
    print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")

    data_module = make_supervised_data_module(
        interleave=training_args.interleave,
        with_box=training_args.with_box,
        tokenizer=tokenizer,
        data_args=data_args
    )

    trainer = GOTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module)

    # Resume automatically when a previous run left checkpoints behind.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    if training_args.lora_enable:
        # Collect LoRA and non-LoRA weights separately; only local rank 0
        # (or a non-distributed run, local_rank == -1) writes them to disk.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        trainer._safe_save(output_dir=training_args.output_dir)


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()


================================================
FILE: GOT-OCR-2.0-master/GOT/train/train_lora_flash_attn.py
================================================
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
# NOTE(review): the repository listing shows only arguments.py, constants.py,
# conversation.py and utils.py under GOT/utils -- confirm that
# ``llama_flash_attn_monkey_patch`` actually exists, otherwise this import fails.
from GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Apply the patch at import time, before the training module below is imported.
replace_llama_attn_with_flash_attn()

# from GOT.train.train import train
# The LoRA variant of the training entry point is used here.
from GOT.train.train_lora import train

# Script entry point: run LoRA training with the patch applied.
if __name__ == "__main__":
    train()


================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer.py
================================================
import os
import torch
import torch.nn as nn

from transformers import Trainer
from typing import Dict, Optional, Sequence


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Strip every wrapper layer (as added by distributed-training containers) from a model.

    Args:
        model (`torch.nn.Module`): The possibly wrapped model.

    Returns:
        The innermost object reached by following `.module` attributes until
        an object without that attribute is found.
    """
    # wrappers may be nested, so keep descending while a `.module` attribute exists
    innermost = model
    while hasattr(innermost, "module"):
        innermost = innermost.module
    return innermost


class GOTTrainer(Trainer):
    """`transformers.Trainer` subclass with GOT-specific checkpoint saving."""

    def _safe_save(self, output_dir: str):
        """Collect the model weights and dump them to disk.

        With DeepSpeed active, saving is delegated to ``save_model``.
        Otherwise the state dict is copied to CPU and written through
        ``_save``, but only when ``self.args.should_save`` is true.

        Args:
            output_dir: Directory the model is written to.
        """
        if self.deepspeed:
            torch.cuda.synchronize()
            self.save_model(output_dir)
            return

        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {
                key: value.cpu()
                for key, value in state_dict.items()
            }
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model, plus a projector/embedding-only file when requested.

        When ``self.args.tune_mm_mlp_adapter`` is set, every weight whose key
        contains ``mm_projector``, ``embed_tokens`` or ``embed_in`` is also
        written to its own file:

        * ``<parent>/mm_projector/<checkpoint-N>.bin`` when ``output_dir`` is
          a ``checkpoint-*`` directory;
        * ``<output_dir>/mm_projector.bin`` otherwise.

        The regular ``Trainer._save`` always runs afterwards.

        Args:
            output_dir: Target directory; ``None`` falls back to
                ``self.args.output_dir``.
            state_dict: Weights to save; ``None`` means "take them from the
                unwrapped model".
        """
        # Resolve the default up front so the path handling below never sees None.
        if output_dir is None:
            output_dir = self.args.output_dir

        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v
                for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # normpath drops a trailing separator, so ".../checkpoint-10/" is
            # still recognised as a checkpoint directory; basename/dirname
            # work with the platform's separator, not just "/".
            normalized_dir = os.path.normpath(output_dir)
            current_folder = os.path.basename(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                target_dir = os.path.join(os.path.dirname(normalized_dir), "mm_projector")
                file_name = f'{current_folder}.bin'
            else:
                target_dir = output_dir
                file_name = 'mm_projector.bin'
            # The directory may not exist yet: the base-class save that would
            # create it only runs after this block.
            os.makedirs(target_dir, exist_ok=True)
            torch.save(weight_to_save, os.path.join(target_dir, file_name))

        super(GOTTrainer, self)._save(output_dir, state_dict)


================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer_llm_llrd.py
================================================
import os
import torch
import torch.nn as nn
import time
import functools
import re

from transformers import Trainer
from transformers.trainer_pt_utils import (
    get_module_class_from_name,
    get_parameter_names,
)
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_neuroncore_available,
)
from transformers.trainer_utils import (
    FSDPOption,
    ShardedDDPOption,
)
from transformers.training_args import ParallelMode
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from typing import Dict, Optional, Sequence


def lr_scale_func(key, num_layers=32, decay_rate=0.86):
    """Return the learning-rate multiplier for the parameter named ``key``.

    Rules, checked in this order:

    * names containing ``embed_tokens.weight`` -> ``0``
    * names containing ``mm_projector`` or ``vision_tower`` -> ``0.01``
    * names containing ``norm.weight`` or ``lm_head.weight`` -> ``1``
      (this substring test also matches per-layer norm weights such as
      ``...layers.3.input_layernorm.weight``)
    * everything else must contain ``layers.<i>.`` and gets
      ``decay_rate ** (num_layers - i - 1)``, so the last layer gets ``1``.

    Args:
        key: Fully qualified parameter name.
        num_layers: Number of transformer layers used for the decay schedule.
        decay_rate: Per-layer decay factor.

    Returns:
        The multiplier to apply to the base learning rate.

    Raises:
        ValueError: If ``key`` matches none of the rules above.
    """
    if "embed_tokens.weight" in key:
        return 0
    if "mm_projector" in key or "vision_tower" in key:
        return 0.01
    if "norm.weight" in key or "lm_head.weight" in key:
        return 1

    # Raw string: "\." and "\d" are invalid escapes in a normal string literal.
    layer_match = re.search(r"layers\.(\d+)\.", key)
    if layer_match is None:
        raise ValueError(
            f"Cannot derive a learning-rate scale for parameter {key!r}: "
            "expected a 'layers.<index>.' component in its name."
        )
    layer_index = int(layer_match.group(1))
    return decay_rate ** (num_layers - layer_index - 1)
                

def get_param_groups(model, no_weight_decay_cond, scale_lr_cond):
    """Split the trainable parameters of ``model`` into optimizer groups.

    Parameters with ``requires_grad`` false are ignored.  The rest are
    bucketed along two axes:

    * weight decay: decided by ``no_weight_decay_cond(name, param)``; when
      that is ``None``, names ending in ``.bias`` and one-dimensional
      parameters get no decay;
    * learning-rate scale: ``scale_lr_cond(name)`` gives a multiplier, and
      any value different from 1 puts the parameter in a bucket of its own
      multiplier; when it is ``None`` nothing is scaled.

    Returns:
        A list of dicts with keys ``params``, ``wd_mult`` and ``lr_mult``,
        ordered as: decayed/unscaled, decayed/scaled (one group per
        multiplier), undecayed/unscaled, undecayed/scaled.  Empty groups are
        left out.
    """
    decay_plain, nodecay_plain = [], []
    decay_scaled, nodecay_scaled = {}, {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if no_weight_decay_cond is None:
            # do not regularize biases nor Norm parameters
            skip_decay = name.endswith(".bias") or len(param.shape) == 1
        else:
            skip_decay = no_weight_decay_cond(name, param)

        lr_mult = 1
        if scale_lr_cond is not None:
            lr_mult = scale_lr_cond(name)
            print(name, lr_mult)

        if skip_decay:
            plain_bucket, scaled_buckets = nodecay_plain, nodecay_scaled
        else:
            plain_bucket, scaled_buckets = decay_plain, decay_scaled

        if lr_mult != 1:
            scaled_buckets.setdefault(lr_mult, []).append(param)
        else:
            plain_bucket.append(param)

    groups = []
    if decay_plain:
        groups.append({"params": decay_plain, "wd_mult": 1.0, "lr_mult": 1.0})
    groups.extend(
        {"params": members, "wd_mult": 1.0, "lr_mult": mult}
        for mult, members in decay_scaled.items()
    )
    if nodecay_plain:
        groups.append({"params": nodecay_plain, "wd_mult": 0.0, "lr_mult": 1.0})
    groups.extend(
        {"params": members, "wd_mult": 0.0, "lr_mult": mult}
        for mult, members in nodecay_scaled.items()
    )
    return groups


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Peel off every container a model may be wrapped in (as used in distributed training).

    This definition replaces the `unwrap_model` name imported from
    `transformers.modeling_utils` at the top of this file.

    Args:
        model (`torch.nn.Module`): The model to unwrap.

    Returns:
        The first object in the `.module` chain that has no `.module` attribute.
    """
    _missing = object()
    # there could be several levels of wrapping, so follow `.module` until it runs out
    while True:
        inner = getattr(model, "module", _missing)
        if inner is _missing:
            return model
        model = inner


class GOTTrainer(Trainer):

    def _safe_save(self, output_dir: str):
        """Collects the state dict and dump to disk."""
        state_dict = self.model.state_dict()
        if self.args.should_save:
            cpu_state_dict = {
                key: value.cpu()
                for key, value in state_dict.items()
            }
            del state_dict
            self._save(output_dir, state_dict=cpu_state_dict)  # noqa


    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model, plus a projector/embedding-only file when requested.

        When ``self.args.tune_mm_mlp_adapter`` is set, every weight whose key
        contains ``mm_projector``, ``embed_tokens`` or ``embed_in`` is also
        written to its own file:

        * ``<parent>/mm_projector/<checkpoint-N>.bin`` when ``output_dir`` is
          a ``checkpoint-*`` directory;
        * ``<output_dir>/mm_projector.bin`` otherwise.

        The regular ``Trainer._save`` always runs afterwards.

        Args:
            output_dir: Target directory; ``None`` falls back to
                ``self.args.output_dir``.
            state_dict: Weights to save; ``None`` means "take them from the
                unwrapped model".
        """
        # Resolve the default up front so the path handling below never sees None.
        if output_dir is None:
            output_dir = self.args.output_dir

        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v
                for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # normpath drops a trailing separator, so ".../checkpoint-10/" is
            # still recognised as a checkpoint directory; basename/dirname
            # work with the platform's separator, not just "/".
            normalized_dir = os.path.normpath(output_dir)
            current_folder = os.path.basename(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                target_dir = os.path.join(os.path.dirname(normalized_dir), "mm_projector")
                file_name = f'{current_folder}.bin'
            else:
                target_dir = output_dir
                file_name = 'mm_projector.bin'
            # The directory may not exist yet: the base-class save that would
            # create it only runs after this block.
            os.makedirs(target_dir, exist_ok=True)
            torch.save(weight_to_save, os.path.join(target_dir, file_name))

        super(GOTTrainer, self)._save(output_dir, state_dict)

    def create_optimizer(self):
        """
        Setup the optimizer with layer-wise learning-rate scaling.

        Parameter groups come from `get_param_groups(..., lr_scale_func)`,
        which tags every group with `lr_mult` and `wd_mult`.  PyTorch
        optimizers only act on the option names they define themselves, so
        those multipliers are converted here into concrete per-group `lr` and
        `weight_decay` values:

        * `lr = base_lr * lr_mult`
        * `weight_decay = self.args.weight_decay * wd_mult`

        Without this conversion every group would silently fall back to the
        optimizer-wide learning rate and the optimizer's default weight decay.

        Nothing is rebuilt when `self.optimizer` is already set.

        Returns:
            The optimizer stored on `self.optimizer`.
        """
        opt_model = self.model

        if self.optimizer is None:
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            base_lr = optimizer_kwargs.get("lr", self.args.learning_rate)
            base_weight_decay = self.args.weight_decay

            optimizer_grouped_parameters = get_param_groups(opt_model, None, lr_scale_func)
            for group in optimizer_grouped_parameters:
                # setdefault: an explicit value supplied by the group wins.
                group.setdefault("lr", base_lr * group.get("lr_mult", 1.0))
                group.setdefault("weight_decay", base_weight_decay * group.get("wd_mult", 1.0))

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer
    

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            import smdistributed.modelparallel.torch as smp
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            from apex import amp
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
            from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
            from fairscale.nn.wrap import auto_wrap
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None

                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )
                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        backward_prefetch=self.backward_prefetch,
                        forward_prefetch=self.forword_prefetch,
                        limit_all_gathers=self.limit_all_gathers,
                        use_orig_params=True,
                    )
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError:
                    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
                auto_wrap_policy = None
                auto_wrapper_callable = None
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                import torch_xla.core.xla_model as xm
                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                    loss = optimizer.step(**optimizer_args)
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            if is_torch_neuroncore_available():
                return model
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
        # to ensure that it accounts for the graph breaks required by those wrappers
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        return model



================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_fixlr.py
================================================
import os
import torch
import torch.nn as nn

from transformers import Trainer
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from typing import Dict, Optional, Sequence


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Peel off every wrapper layer and return the innermost model.

    Any object exposing a `module` attribute is treated as a wrapper (as used in
    distributed training); the chain of `module` attributes is followed until an
    object without one is reached.

    Args:
        model (`torch.nn.Module`): The possibly wrapped model.
    """
    innermost = model
    # Wrappers can be stacked, so keep descending until nothing is left to unwrap.
    while hasattr(innermost, "module"):
        innermost = innermost.module
    return innermost


class GOTTrainer(Trainer):

    def _safe_save(self, output_dir: str):
        """Gather the model state dict, move it to CPU and write it via `_save`.

        The state dict is collected unconditionally; only when
        `self.args.should_save` is true is it copied to CPU and handed to
        `self._save`.
        """
        full_state = self.model.state_dict()
        if not self.args.should_save:
            return

        on_cpu = {}
        for name, tensor in full_state.items():
            on_cpu[name] = tensor.cpu()
        # Drop the reference to the original tensors before writing to disk.
        del full_state
        self._save(output_dir, state_dict=on_cpu)  # noqa


    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model, plus a separate adapter-only file when tuning the adapter.

        When `self.args.tune_mm_mlp_adapter` is truthy, every state-dict entry whose
        key contains 'mm_projector', 'embed_tokens' or 'embed_in' is additionally
        written with `torch.save`:

        * `<parent>/mm_projector/<checkpoint-N>.bin` when the target directory is
          named `checkpoint-*`;
        * `<output_dir>/mm_projector.bin` otherwise.

        The regular save is always delegated to the parent class afterwards.

        Args:
            output_dir: Target directory. `None` falls back to `self.args.output_dir`
                for the adapter file (mirroring the fallback used by
                `transformers.Trainer._save`); the parent call receives the value
                unchanged.
            state_dict: Optional pre-collected state dict. When `None`, it is read
                from the unwrapped `self.model`.
        """
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # `output_dir` is Optional, so resolve it before doing path arithmetic.
            target_dir = output_dir if output_dir is not None else self.args.output_dir

            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                _state_dict = unwrap_model(self.model).state_dict()

            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v
                for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # Normalise first so a trailing separator ("out/checkpoint-10/") still
            # yields "checkpoint-10" as the folder name and "out" as its parent.
            normalized_dir = os.path.normpath(target_dir)
            current_folder = os.path.basename(normalized_dir)
            parent_folder = os.path.dirname(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                # The parent only creates the directory later, inside its own _save.
                os.makedirs(target_dir, exist_ok=True)
                torch.save(weight_to_save, os.path.join(target_dir, 'mm_projector.bin'))

        super(GOTTrainer, self)._save(output_dir, state_dict)

    def create_optimizer(self):
        """
        Setup the optimizer.

        Trainable parameters are split into four groups, in this order:

        0. `vision_encoder` parameters with weight decay
        1. `vision_encoder` parameters without weight decay
        2. all other parameters with weight decay
        3. all other parameters without weight decay

        A parameter gets weight decay when `get_parameter_names` (excluding
        `ALL_LAYERNORM_LAYERS`) lists it and its name does not contain "bias".
        Every group currently uses `self.args.learning_rate`; the `lr` entries are
        kept per group so they can be set independently.

        Returns:
            The existing `self.optimizer` if one is already set, otherwise the newly
            created optimizer (also stored on `self.optimizer`).
        """
        if self.optimizer is not None:
            return self.optimizer

        opt_model = self.model

        # A set makes the per-parameter membership test O(1); the names are only
        # ever used for lookups.
        decay_parameters = {
            name
            for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            if "bias" not in name
        }

        # One pass over the parameters; bucket key is (is_vision_encoder, uses_decay).
        buckets = {
            (True, True): [],
            (True, False): [],
            (False, True): [],
            (False, False): [],
        }
        for n, p in opt_model.named_parameters():
            if p.requires_grad:
                buckets['vision_encoder' in n, n in decay_parameters].append(p)

        optimizer_grouped_parameters = [
            {
                "params": buckets[True, True],
                "weight_decay": self.args.weight_decay,
                "lr": self.args.learning_rate,
            },
            {
                "params": buckets[True, False],
                "weight_decay": 0.0,
                "lr": self.args.learning_rate,
            },
            {
                "params": buckets[False, True],
                "weight_decay": self.args.weight_decay,
                "lr": self.args.learning_rate,
            },
            {
                "params": buckets[False, False],
                "weight_decay": 0.0,
                "lr": self.args.learning_rate,
            },
        ]
        # Report group index, parameter count and learning rate for each group.
        for idx, group in enumerate(optimizer_grouped_parameters):
            print(idx, len(group['params']), group['lr'])

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer

================================================
FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_llrd.py
================================================
import os
import torch
import torch.nn as nn
import time
import functools
import re

from transformers import Trainer
from transformers.trainer_pt_utils import (
    get_module_class_from_name,
    get_parameter_names,
)
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_torch_neuroncore_available,
)
from transformers.trainer_utils import (
    FSDPOption,
    ShardedDDPOption,
)
from transformers.training_args import ParallelMode
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from typing import Dict, Optional, Sequence


def lr_scale_func(key, *, num_layers=23, layer_decay=0.81, layer_lr_scale=0.01, vision_lr_scale=0.0001):
    """Return the learning-rate multiplier for the parameter named `key`.

    * Names containing "vision_model.encoder.layers" get a layer-wise decayed
      multiplier: `layer_decay ** (num_layers - layer_index - 1) * layer_lr_scale`,
      where `layer_index` is the first `layers.<N>.` number found in the name.
    * Any other name containing "vision_model" gets `vision_lr_scale`.
    * Everything else gets 1 (no scaling).

    The defaults reproduce the previously hard-coded schedule. Variants that used
    to be tried by editing the code (e.g. a decay of 0.66, or no extra 0.01
    factor) can be passed by keyword instead.

    Args:
        key: Parameter name as produced by `model.named_parameters()`.
        num_layers: Layer count used as the exponent offset; presumably the number
            of encoder layers being tuned -- confirm against the vision model.
        layer_decay: Per-layer decay base.
        layer_lr_scale: Constant factor applied on top of the layer-wise decay.
        vision_lr_scale: Multiplier for vision-model parameters outside the
            encoder layers.

    Raises:
        IndexError: If `key` contains "vision_model.encoder.layers" but no
            `layers.<N>.` segment.
    """
    if "vision_model.encoder.layers" in key:
        # Raw string: "\." and "\d" are regex escapes, not Python string escapes.
        layer_index = int(re.findall(r"layers\.(\d+)\.", key)[0])
        return layer_decay ** (num_layers - layer_index - 1) * layer_lr_scale
    if "vision_model" in key:
        return vision_lr_scale
    return 1
                

def get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr, wd):
    """Build optimizer parameter groups split by weight decay and learning-rate scale.

    Only parameters with `requires_grad` set are considered.

    Args:
        model: Object exposing `named_parameters()`.
        no_weight_decay_cond: Optional callable `(name, param) -> bool`; a truthy
            result exempts the parameter from weight decay. When `None`, names
            ending in ".bias" and one-dimensional parameters are exempt.
        scale_lr_cond: Optional callable `name -> multiplier`. A multiplier other
            than 1 puts the parameter in a group whose learning rate is
            `lr * multiplier`. Each name and its multiplier are printed.
        lr: Base learning rate.
        wd: Weight decay used for the regularized groups.

    Returns:
        A list of dicts with keys "params", "weight_decay" and "lr", ordered as:
        decayed/unscaled, decayed/scaled (one per multiplier, in first-seen order),
        undecayed/unscaled, undecayed/scaled. Empty groups are omitted.
    """
    decayed_base = []
    undecayed_base = []
    # Scaled parameters are keyed by their multiplier so each one gets its own group.
    decayed_scaled = {}
    undecayed_scaled = {}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if no_weight_decay_cond is None:
            # do not regularize biases nor Norm parameters
            skip_decay = name.endswith(".bias") or len(param.shape) == 1
        else:
            skip_decay = no_weight_decay_cond(name, param)

        lr_mult = 1
        if scale_lr_cond is not None:
            lr_mult = scale_lr_cond(name)
            print(name, lr_mult)
        is_scaled = lr_mult != 1

        if is_scaled:
            by_mult = undecayed_scaled if skip_decay else decayed_scaled
            bucket = by_mult.setdefault(lr_mult, [])
        else:
            bucket = undecayed_base if skip_decay else decayed_base
        bucket.append(param)

    groups = []
    if decayed_base:
        groups.append({"params": decayed_base, "weight_decay": wd, "lr": lr})
    groups.extend(
        {"params": members, "weight_decay": wd, "lr": lr * mult}
        for mult, members in decayed_scaled.items()
    )
    if undecayed_base:
        groups.append({"params": undecayed_base, "weight_decay": 0.0, "lr": lr})
    groups.extend(
        {"params": members, "weight_decay": 0.0, "lr": lr * mult}
        for mult, members in undecayed_scaled.items()
    )
    return groups


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Return the innermost model hidden behind any number of wrapper containers.

    Defined at module level, so it replaces the `unwrap_model` name imported from
    `transformers.modeling_utils` at the top of this file.

    Args:
        model (`torch.nn.Module`): The model to unwrap.
    """
    current = model
    # Follow the chain of `.module` attributes; wrappers may be nested several deep.
    while hasattr(current, "module"):
        current = current.module
    return current


class GOTTrainer(Trainer):

    def _safe_save(self, output_dir: str):
        """Fetch the model state dict and, if this process should save, dump a CPU copy.

        The state dict is always collected first. Nothing is written unless
        `self.args.should_save` is true, in which case every value is moved to CPU
        and the result is passed to `self._save`.
        """
        collected = self.model.state_dict()
        if not self.args.should_save:
            return

        cpu_copy = dict((key, value.cpu()) for key, value in collected.items())
        # Release the original tensors before the file is written.
        del collected
        self._save(output_dir, state_dict=cpu_copy)  # noqa


    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save the model, plus a separate adapter-only file when tuning the adapter.

        When `self.args.tune_mm_mlp_adapter` is truthy, every state-dict entry whose
        key contains 'mm_projector', 'embed_tokens' or 'embed_in' is additionally
        written with `torch.save`:

        * `<parent>/mm_projector/<checkpoint-N>.bin` when the target directory is
          named `checkpoint-*`;
        * `<output_dir>/mm_projector.bin` otherwise.

        The regular save is always delegated to the parent class afterwards.

        Args:
            output_dir: Target directory. `None` falls back to `self.args.output_dir`
                for the adapter file (mirroring the fallback used by
                `transformers.Trainer._save`); the parent call receives the value
                unchanged.
            state_dict: Optional pre-collected state dict. When `None`, it is read
                from the unwrapped `self.model`.
        """
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # `output_dir` is Optional, so resolve it before doing path arithmetic.
            target_dir = output_dir if output_dir is not None else self.args.output_dir

            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                _state_dict = unwrap_model(self.model).state_dict()

            keys_to_match = ('mm_projector', 'embed_tokens', 'embed_in')
            weight_to_save = {
                k: v
                for k, v in _state_dict.items()
                if any(key_match in k for key_match in keys_to_match)
            }

            # Normalise first so a trailing separator ("out/checkpoint-10/") still
            # yields "checkpoint-10" as the folder name and "out" as its parent.
            normalized_dir = os.path.normpath(target_dir)
            current_folder = os.path.basename(normalized_dir)
            parent_folder = os.path.dirname(normalized_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                # The parent only creates the directory later, inside its own _save.
                os.makedirs(target_dir, exist_ok=True)
                torch.save(weight_to_save, os.path.join(target_dir, 'mm_projector.bin'))

        super(GOTTrainer, self)._save(output_dir, state_dict)

    def create_optimizer(self):
        """
        Build the optimizer once and return it.

        Parameter groups come from `get_param_groups`, using the default
        weight-decay rule and `lr_scale_func` as the per-parameter learning-rate
        multiplier. This replaces an earlier two-group (decay / no-decay) split
        that was left commented out here. The optimizer class and its keyword
        arguments are taken from `Trainer.get_optimizer_cls_and_kwargs`.

        If `self.optimizer` is already set it is returned untouched.
        """
        if self.optimizer is not None:
            return self.optimizer

        param_groups = get_param_groups(
            self.model,
            None,
            lr_scale_func,
            self.args.learning_rate,
            self.args.weight_decay,
        )
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(param_groups, **optimizer_kwargs)
        return self.optimizer
    

    def _wrap_model(self, model, training=True, dataloader=None):
        """
        Wrap `model` for the execution backend selected by `self.args` and return it.

        The checks run in a fixed order and several of them return early:
        IPEX optimization, SageMaker model parallelism, DeepSpeed, an already
        wrapped model, apex AMP, `nn.DataParallel`, JIT eval mode, and finally
        (training only) exactly one of fairscale sharded DDP, FSDP (PyTorch or
        XLA), SageMaker data parallelism or `DistributedDataParallel`.
        `torch.compile` is applied last when `self.args.torch_compile` is set.

        Args:
            model: The model to wrap.
            training: When False, the model is returned right after the JIT step,
                before any distributed wrapper is applied.
            dataloader: Forwarded to `self.torch_jit_model_eval` when
                `self.args.jit_mode_eval` is set.

        Returns:
            The wrapped (or unchanged) model.

        Side effects visible in this method: the fairscale fully-sharded branch and
        both FSDP branches rebind `self.model`; the apex branch rebinds
        `self.optimizer`; the XLA FSDP branch replaces `xm.optimizer_step`.
        """
        if self.args.use_ipex:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)

        if is_sagemaker_mp_enabled():
            import smdistributed.modelparallel.torch as smp
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
        # already initialized its own DDP and AMP
        if self.deepspeed:
            return self.deepspeed

        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
        if unwrap_model(model) is not model:
            return model

        # Mixed precision training with apex (torch < 1.6)
        if self.use_apex and training:
            from apex import amp
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # Multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)

        if self.args.jit_mode_eval:
            # Record how long the JIT step took, rounded to 4 decimals.
            start_time = time.time()
            model = self.torch_jit_model_eval(model, dataloader, training)
            self.jit_compilation_time = round(time.time() - start_time, 4)

        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
            return model

        # Distributed training (should be after apex fp16 initialization)
        if self.sharded_ddp is not None:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
            from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
            from fairscale.nn.wrap import auto_wrap
            # Sharded DDP!
            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                model = ShardedDDP(model, self.optimizer)
            else:
                mixed_precision = self.args.fp16 or self.args.bf16
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
                    model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
                    model,
                    mixed_precision=mixed_precision,
                    reshard_after_forward=zero_3,
                    cpu_offload=cpu_offload,
                ).to(self.args.device)
        # Distributed training using PyTorch FSDP
        elif self.fsdp is not None:
            if not self.args.fsdp_config["xla"]:
                # PyTorch FSDP!
                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

                if FSDPOption.OFFLOAD in self.args.fsdp:
                    cpu_offload = CPUOffload(offload_params=True)
                else:
                    cpu_offload = CPUOffload(offload_params=False)

                auto_wrap_policy = None

                # Auto-wrap policy: size-based takes precedence over class-based.
                if FSDPOption.AUTO_WRAP in self.args.fsdp:
                    if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                        auto_wrap_policy = functools.partial(
                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                        )
                    elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                        transformer_cls_to_wrap = set()
                        for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                            transformer_cls = get_module_class_from_name(model, layer_class)
                            if transformer_cls is None:
                                raise Exception("Could not find the transformer layer class to wrap in the model.")
                            else:
                                transformer_cls_to_wrap.add(transformer_cls)
                        auto_wrap_policy = functools.partial(
                            transformer_auto_wrap_policy,
                            # Transformer layer class to wrap
                            transformer_layer_cls=transformer_cls_to_wrap,
                        )
                # Use one dtype for params, gradient reduction and buffers; fp16 wins over bf16.
                mixed_precision_policy = None
                dtype = None
                if self.args.fp16:
                    dtype = torch.float16
                elif self.args.bf16:
                    dtype = torch.bfloat16
                if dtype is not None:
                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
                if type(model) != FSDP:
                    # XXX: Breaking the self.model convention but I see no way around it for now.
                    self.model = model = FSDP(
                        model,
                        sharding_strategy=self.fsdp,
                        cpu_offload=cpu_offload,
                        auto_wrap_policy=auto_wrap_policy,
                        mixed_precision=mixed_precision_policy,
                        device_id=self.args.device,
                        backward_prefetch=self.backward_prefetch,
                        # NOTE(review): the attribute is spelled `forword_prefetch`; it is not set in
                        # this class, so it presumably matches the parent Trainer -- verify.
                        forward_prefetch=self.forword_prefetch,
                        limit_all_gathers=self.limit_all_gathers,
                        # NOTE(review): presumably needed so the per-parameter lr groups built in
                        # create_optimizer keep working under FSDP -- confirm.
                        use_orig_params=True,
                    )
            else:
                try:
                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                    from torch_xla.distributed.fsdp import checkpoint_module
                    from torch_xla.distributed.fsdp.wrap import (
                        size_based_auto_wrap_policy,
                        transformer_auto_wrap_policy,
                    )
                except ImportError:
                    raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
                auto_wrap_policy = None
                auto_wrapper_callable = None
                # Same policy selection as the PyTorch branch, but not gated on FSDPOption.AUTO_WRAP.
                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                    auto_wrap_policy = functools.partial(
                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                    )
                elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
                    transformer_cls_to_wrap = set()
                    for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
                        transformer_cls = get_module_class_from_name(model, layer_class)
                        if transformer_cls is None:
                            raise Exception("Could not find the transformer layer class to wrap in the model.")
                        else:
                            transformer_cls_to_wrap.add(transformer_cls)
                    auto_wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        # Transformer layer class to wrap
                        transformer_layer_cls=transformer_cls_to_wrap,
                    )
                fsdp_kwargs = self.args.xla_fsdp_config
                if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                    def auto_wrapper_callable(m, *args, **kwargs):
                        return FSDP(checkpoint_module(m), *args, **kwargs)

                # Wrap the base model with an outer FSDP wrapper
                self.model = model = FSDP(
                    model,
                    auto_wrap_policy=auto_wrap_policy,
                    auto_wrapper_callable=auto_wrapper_callable,
                    **fsdp_kwargs,
                )

                import torch_xla.core.xla_model as xm
                # Patch `xm.optimizer_step` should not reduce gradients in this case,
                # as FSDP does not need gradient reduction over sharded parameters.
                # `optimizer_args` is only unpacked, never mutated, so the shared default dict is harmless.
                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                    loss = optimizer.step(**optimizer_args)
                    if barrier:
                        xm.mark_step()
                    return loss

                xm.optimizer_step = patched_optimizer_step
        elif is_sagemaker_dp_enabled():
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
            )
        elif self.args.local_rank != -1:
            # Plain DDP; an explicit ddp_find_unused_parameters setting overrides the heuristic below.
            kwargs = {}
            if self.args.ddp_find_unused_parameters is not None:
                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
            elif isinstance(model, PreTrainedModel):
                # find_unused_parameters breaks checkpointing as per
                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
            else:
                kwargs["find_unused_parameters"] = True

            if self.args.ddp_bucket_cap_mb is not None:
                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
            # With Neuron cores available the model is returned without the DDP wrapper
            # (and without the torch.compile step below).
            if is_torch_neuroncore_available():
                return model
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
                **kwargs,
            )

        # torch.compile() needs to be called after wrapping the model with FSDP or DDP
        # to ensure that it accounts for the graph breaks required by those wrappers
        if self.args.torch_compile:
            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)

        return model



================================================
FILE: GOT-OCR-2.0-master/GOT/utils/arguments.py
================================================
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import transformers


@dataclass
class ModelArguments:
    """Command-line options describing which model to load and which parts to freeze.

    Field meanings beyond the names and the inline notes are not visible in this
    file; check the training scripts that consume these arguments.
    """

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    use_cache: bool = False
    # Default points at a locally cached CLIP ViT-L/14 snapshot directory.
    vision_tower: Optional[str] = "~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/"
    freeze_vision_tower: bool = False
    freeze_lm_model: bool = False
    # mlp &/ vision tower
    pretrained_stage1_model: Optional[str] = None
    # default to the last layer
    vision_select_layer: Optional[int] = -1
    use_im_start_end: bool = False


@dataclass
class DataArguments:
    """Command-line options describing the training data and prompt format.

    Field meanings beyond the names and help strings are not visible in this
    file; check the dataset code that consumes these arguments.
    """

    # The default is None, so the annotation is Optional[str] rather than plain str.
    datasets: Optional[str] = field(default=None, metadata={"help": "combinations of the training data."})
    sep_image_conv_front: bool = False
    image_token_len: int = 256
    image_aspect_ratio: str = 'square'
    # Other versions previously listed here as commented-out alternatives:
    # 'v0', 'v1', 'nougat', 'baichuan', 'opt'.
    conversation_version: str = 'mpt'
    box_limit: int = 0


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """`transformers.TrainingArguments` extended with project-specific options.

    `optim` and `remove_unused_columns` are redeclared here with the defaults
    "adamw_torch" and False; the remaining fields are additions. The `lora_*`
    fields are presumably consumed by the LoRA training scripts -- confirm there.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    force_fsdp: bool = False
    interleave: bool = False
    with_box: bool = False
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=8)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

================================================
FILE: GOT-OCR-2.0-master/GOT/utils/constants.py
================================================
# Heart-beat timing for a controller/worker setup (units are not stated in this file).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Name of the logging directory.
LOGDIR = "log"

# Label value marking positions to ignore; matches the default `ignore_index`
# of `torch.nn.CrossEntropyLoss` (presumably used for loss masking -- confirm).
IGNORE_INDEX = -100

# Special text tokens. An earlier pad token, "[PAD]", was left commented out here.
# NOTE(review): BOS and EOS are the same string -- confirm this is intended.
DEFAULT_PAD_TOKEN = "<|endoftext|>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Placeholder tokens for image and box content inside prompts.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_BOX_TOKEN = "<box>"
DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"
DEFAULT_IM_START_TOKEN = "<img>"
DEFAULT_IM_END_TOKEN = "</img>"

# Registry of conversation datasets: name -> image root and annotation file.
# The paths below are placeholders to be replaced with real locations.
CONVERSATION_DATA = {
    "data_1": {
        "images": "/path/",
        "annotations": "/path/data1.json",
    },
    "data_2": {
        "images": "/path/",
        "annotations": "/path/data2.json",
    },
    "data_3": {
        "images": "/path/",
        "annotations": "/path/data3.json",
    },
}

================================================
FILE: GOT-OCR-2.0-master/GOT/utils/conversation.py
================================================
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style.

    Selects which prompt layout `Conversation.get_prompt` builds.
    """
    # system + sep + newline, then `role: message` + sep for each turn
    SINGLE = auto()
    # system + sep, then `role: message` followed by sep / sep2 alternating by turn index
    TWO = auto()
    # optional system + sep, then role immediately followed by message + sep
    MPT = auto()



# simple_conv_multimodal = Conversation(
#     system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
#            "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
#            "Follow the instructions carefully and explain your answers in detail.",
#     # system="",
#     roles=("Human", "Assistant"),
#     messages=(
#         ("Human", "Hi!"),
#         ("Assistant", "Hi there!  How can I help you today?\n")
#     ),
#     offset=2,
#     sep_style=SeparatorStyle.SINGLE,
#     sep="###",
# )

# conv_mpt = Conversation(
#     system="""<|im_start|>system
# - You are a helpful language and vision assistant.
# - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
# - You should follow the instructions carefully and explain your answers in detail.""",
#     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
#     version="mpt",
#     messages=(),
#     offset=0,
#     sep_style=SeparatorStyle.MPT,
#     sep="<|im_end|>",
# )

@dataclasses.dataclass
class Conversation:
    """Keep the history of one conversation and render it as a prompt.

    A message is stored as a ``[role, message]`` pair where ``message`` is
    either a string, a falsy value (an open turn, rendered as the bare role
    prefix) or a ``(text, image, image_process_mode)`` tuple.
    """
    # Text emitted before the first message; may be empty.
    system: str
    # Speaker names / prefixes, e.g. ("Human", "Assistant").
    roles: List[str]
    # The [role, message] pairs described in the class docstring.
    messages: List[List[str]]
    # Number of leading messages skipped by get_images / to_gradio_chatbot.
    offset: int
    # Selects the layout used by get_prompt.
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Separator appended after messages (all styles).
    sep: str = "<|im_end|>"
    # Second separator; only read when sep_style is SeparatorStyle.TWO.
    sep2: str = None
    # Free-form tag; not read by any method of this class.
    version: str = "Unknown"

    # Not read by any method of this class.
    # NOTE(review): presumably a flag for callers (e.g. a UI) -- confirm.
    skip_next: bool = False

    @staticmethod
    def _message_text(message):
        """Return the text part of a message (first item of an image tuple)."""
        if isinstance(message, tuple):
            message, _, _ = message
        return message

    @staticmethod
    def _expand_to_square(pil_img, background_color=(122, 116, 104)):
        """Pad a PIL image to a square canvas, keeping it at the top-left."""
        # PIL is imported lazily so text-only use does not require it.
        from PIL import Image
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        # Pasted at (0, 0); the image is intentionally not centred.
        result.paste(pil_img)
        return result

    @staticmethod
    def _resize_for_display(image, max_len=800, min_len=400):
        """Resize keeping aspect ratio: short edge <= min_len, long edge <= max_len."""
        max_hw, min_hw = max(image.size), min(image.size)
        aspect_ratio = max_hw / min_hw
        # Never upscale: the short edge is also capped by its current size.
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        W, H = image.size
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        return image.resize((W, H))

    @staticmethod
    def _encode_jpeg_base64(image):
        """Encode a PIL image as JPEG and return it as a base64 string."""
        import base64
        from io import BytesIO
        buffered = BytesIO()
        # JPEG cannot store alpha / palette modes, so convert to RGB first.
        image.convert('RGB').save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()

    def get_prompt(self):
        """Render the system text and all messages as one prompt string.

        Returns:
            The prompt, laid out according to ``self.sep_style``.

        Raises:
            ValueError: if ``self.sep_style`` is not a supported style.
        """
        style = self.sep_style
        if style == SeparatorStyle.SINGLE:
            parts = [self.system + self.sep + '\n']
            for role, message in self.messages:
                if message:
                    parts.append(role + ": " + self._message_text(message) + self.sep)
                else:
                    # Open turn: only the role prefix is emitted.
                    parts.append(role + ":")
            return "".join(parts)
        if style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            parts = [self.system + seps[0]]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # Even-indexed messages end with sep, odd-indexed with sep2.
                    parts.append(role + ": " + self._message_text(message) + seps[i % 2])
                else:
                    parts.append(role + ":")
            return "".join(parts)
        if style == SeparatorStyle.MPT:
            # An empty system text contributes nothing, not even a separator.
            parts = [self.system + self.sep] if self.system else []
            for role, message in self.messages:
                if message:
                    parts.append(role + self._message_text(message) + self.sep)
                else:
                    parts.append(role)
            return "".join(parts)
        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history.

        ``self.messages`` must be a list; a template built with a tuple of
        messages has to be ``copy()``-ed first.
        """
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect the images attached to even-indexed messages after ``offset``.

        Args:
            return_pil: when true return the processed PIL images, otherwise
                return base64-encoded JPEG strings.

        Returns:
            A list with one entry per image-carrying message.

        Raises:
            ValueError: if a message carries an unknown ``image_process_mode``.
        """
        images = []
        for i, (_role, msg) in enumerate(self.messages[self.offset:]):
            # Only even positions (relative to offset) are scanned for images.
            if i % 2 != 0 or not isinstance(msg, tuple):
                continue
            _text, image, image_process_mode = msg
            if image_process_mode == "Pad":
                image = self._expand_to_square(image)
            elif image_process_mode == "Crop":
                image = self._resize_for_display(image)
            elif image_process_mode == "Resize":
                image = image.resize((224, 224))
            else:
                raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

            if return_pil:
                images.append(image)
            else:
                images.append(self._encode_jpeg_base64(image))
        return images

    def to_gradio_chatbot(self):
        """Convert the history after ``offset`` into ``[[first, second], ...]`` pairs.

        Even-indexed messages open a new pair; the following odd-indexed
        message fills its second slot. An attached image is resized, embedded
        as an inline ``<img>`` tag and substituted for ``'<image>'`` in the text.
        """
        ret = []
        for i, (_role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    msg, image, _mode = msg
                    image = self._resize_for_display(image)
                    # Shares the encoder with get_images so RGBA / palette
                    # images no longer fail on JPEG save.
                    img_b64_str = self._encode_jpeg_base64(image)
                    # The payload is JPEG, so the data URI is labelled as such.
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace('<image>', img_str)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a new Conversation whose message list can be mutated freely.

        The message pairs are rebuilt as fresh lists; ``version`` is carried
        over. ``skip_next`` is not copied and takes its default.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return the conversation as a plain dict.

        Image tuples are reduced to their text part. When no message carries
        an image, ``self.messages`` itself is returned under ``"messages"``.
        """
        # A cheap scan replaces running the whole image pipeline just to learn
        # whether any image is present.
        if any(isinstance(y, tuple) for _, y in self.messages):
            messages = [[x, y[0] if isinstance(y, tuple) else y] for x, y in self.messages]
        else:
            messages = self.messages
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
            "Sure, here are three tips for staying healthy:\n"
            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
            "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
            "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
            "activities at least two days per week.\n"
            "2. Eat a balanced diet: 
Download .txt
gitextract_is3bdbjd/

├── GOT-OCR-2.0-master/
│   ├── GOT/
│   │   ├── __init__.py
│   │   ├── data/
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   └── conversation_dataset_qwen.py
│   │   ├── demo/
│   │   │   ├── process_results.py
│   │   │   ├── run_ocr_2.0.py
│   │   │   └── run_ocr_2.0_crop.py
│   │   ├── eval/
│   │   │   ├── eval_GOT_ocr.py
│   │   │   ├── evaluate_GOT.py
│   │   │   ├── multi_hardware_eval_GOT.py
│   │   │   └── pyevaltools/
│   │   │       ├── __init__.py
│   │   │       ├── eval_ocr.py
│   │   │       ├── eval_ocr_format.py
│   │   │       ├── eval_ocr_scene.py
│   │   │       └── merge_results.py
│   │   ├── model/
│   │   │   ├── GOT_ocr_2_0.py
│   │   │   ├── __init__.py
│   │   │   ├── plug/
│   │   │   │   └── blip_process.py
│   │   │   └── vision_encoder/
│   │   │       ├── __init__.py
│   │   │       └── vary_b.py
│   │   ├── train/
│   │   │   ├── train.py
│   │   │   ├── train_GOT.py
│   │   │   ├── train_flash_attn.py
│   │   │   ├── train_lora.py
│   │   │   ├── train_lora_flash_attn.py
│   │   │   ├── trainer.py
│   │   │   ├── trainer_llm_llrd.py
│   │   │   ├── trainer_vit_fixlr.py
│   │   │   └── trainer_vit_llrd.py
│   │   └── utils/
│   │       ├── arguments.py
│   │       ├── constants.py
│   │       ├── conversation.py
│   │       └── utils.py
│   ├── pyproject.toml
│   ├── pyvenv.cfg
│   ├── render_tools/
│   │   ├── content-mmd-to-html.html
│   │   └── tikz.html
│   ├── results/
│   │   └── demo.html
│   └── zero_config/
│       ├── zero2.json
│       └── zero3.json
└── README.md
Download .txt
SYMBOL INDEX (183 symbols across 25 files)

FILE: GOT-OCR-2.0-master/GOT/data/__init__.py
  class DataCollatorForSupervisedDataset (line 10) | class DataCollatorForSupervisedDataset(object):
    method __call__ (line 13) | def __call__(self, instances):
  function make_supervised_data_module (line 46) | def make_supervised_data_module(interleave, with_box, tokenizer, data_ar...

FILE: GOT-OCR-2.0-master/GOT/data/base_dataset.py
  class BaseDataset (line 17) | class BaseDataset(Dataset):
    method __init__ (line 18) | def __init__(
    method image_processor (line 30) | def image_processor(self, image):
    method __len__ (line 66) | def __len__(self):
    method __getitem__ (line 69) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:

FILE: GOT-OCR-2.0-master/GOT/data/conversation_dataset_qwen.py
  class ConversationDataset (line 23) | class ConversationDataset(BaseDataset):
    method __init__ (line 26) | def __init__(self, datasets, tokenizer, multimodal_cfg):
    method multimodal_processor (line 71) | def multimodal_processor(self, sources, flag_num_patches):
    method _tokenize_fn (line 85) | def _tokenize_fn(self, strings):
    method _mask_targets (line 110) | def _mask_targets(self, target, tokenized_lens, speakers):
    method token_processor (line 120) | def token_processor(self, sources, image_name):
    method __getitem__ (line 195) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:

FILE: GOT-OCR-2.0-master/GOT/demo/process_results.py
  function svg_to_html (line 12) | def svg_to_html(svg_content, output_filename):

FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py
  function load_image (line 35) | def load_image(image_file):
  function eval_model (line 44) | def eval_model(args):

FILE: GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0_crop.py
  function load_image (line 32) | def load_image(image_file):
  function find_closest_aspect_ratio (line 40) | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height...
  function dynamic_preprocess (line 57) | def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use...
  function eval_model (line 99) | def eval_model(args):

FILE: GOT-OCR-2.0-master/GOT/eval/eval_GOT_ocr.py
  function load_image (line 49) | def load_image(image_file):
  function find_closest_aspect_ratio (line 58) | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height...
  function dynamic_preprocess (line 75) | def dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use...
  function split_list (line 119) | def split_list(lst, n):
  function get_chunk (line 125) | def get_chunk(lst, n, k):
  function eval_model (line 133) | def eval_model(args):

FILE: GOT-OCR-2.0-master/GOT/eval/multi_hardware_eval_GOT.py
  function run_eval (line 8) | def run_eval(chunk_id, model_name, gtfile_path, image_path, out_path, nu...

FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr.py
  function preprocess (line 27) | def preprocess(text, predict_root_):
  function contain_chinese_string (line 33) | def contain_chinese_string(text):
  function split_text (line 43) | def split_text(pages, a_type):
  function nougat_per_metrics (line 64) | def nougat_per_metrics(predict_root_, pred, gt, minlen=1, heavy_mode: in...
  function doc_formated_text_eval (line 106) | def doc_formated_text_eval(gt_root_, predict_root_, datatype):
  function doc_text_eval (line 182) | def doc_text_eval(gt_root_, predict_root_, datatype):

FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_format.py
  function preprocess (line 27) | def preprocess(text, predict_root_):
  function contain_chinese_string (line 33) | def contain_chinese_string(text):
  function split_text (line 43) | def split_text(pages, a_type):
  function nougat_per_metrics (line 64) | def nougat_per_metrics(predict_root_, pred, gt, minlen=1, heavy_mode: in...
  function doc_formated_text_eval (line 106) | def doc_formated_text_eval(gt_root_, predict_root_, datatype):
  function doc_text_eval (line 182) | def doc_text_eval(gt_root_, predict_root_, datatype):

FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_scene.py
  function preprocess (line 20) | def preprocess(text, predict_root_):
  function contain_chinese_string (line 26) | def contain_chinese_string(text):
  function nougat_per_metrics (line 30) | def nougat_per_metrics(predict_root_, pred, gt, minlen=1):
  function doc_text_eval (line 54) | def doc_text_eval(gt_root_, predict_root_, datatype):

FILE: GOT-OCR-2.0-master/GOT/eval/pyevaltools/merge_results.py
  function merge_outputs (line 5) | def merge_outputs(out_path):

FILE: GOT-OCR-2.0-master/GOT/model/GOT_ocr_2_0.py
  class GOTConfig (line 15) | class GOTConfig(Qwen2Config):
  class GOTQwenModel (line 19) | class GOTQwenModel(Qwen2Model):
    method __init__ (line 22) | def __init__(self, config: Qwen2Config):
    method initialize_vision_modules (line 30) | def initialize_vision_modules(
    method forward (line 73) | def forward(
  class GOTQwenForCausalLM (line 208) | class GOTQwenForCausalLM(Qwen2ForCausalLM):
    method __init__ (line 212) | def __init__(self, config):
    method get_model (line 222) | def get_model(self):
    method forward (line 230) | def forward(
    method prepare_inputs_for_generation (line 304) | def prepare_inputs_for_generation(
    method initialize_vision_tokenizer (line 362) | def initialize_vision_tokenizer(

FILE: GOT-OCR-2.0-master/GOT/model/plug/blip_process.py
  class BaseProcessor (line 18) | class BaseProcessor:
    method __init__ (line 19) | def __init__(self):
    method __call__ (line 23) | def __call__(self, item):
  class BlipImageBaseProcessor (line 35) | class BlipImageBaseProcessor(BaseProcessor):
    method __init__ (line 36) | def __init__(self, mean=None, std=None):
  function identity_func (line 48) | def identity_func(img):
  function autocontrast_func (line 52) | def autocontrast_func(img, cutoff=0):
  function equalize_func (line 85) | def equalize_func(img):
  function rotate_func (line 109) | def rotate_func(img, degree, fill=(0, 0, 0)):
  function solarize_func (line 120) | def solarize_func(img, thresh=128):
  function color_func (line 130) | def color_func(img, factor):
  function contrast_func (line 148) | def contrast_func(img, factor):
  function brightness_func (line 162) | def brightness_func(img, factor):
  function sharpness_func (line 171) | def sharpness_func(img, factor):
  function shear_x_func (line 192) | def shear_x_func(img, factor, fill=(0, 0, 0)):
  function translate_x_func (line 201) | def translate_x_func(img, offset, fill=(0, 0, 0)):
  function translate_y_func (line 213) | def translate_y_func(img, offset, fill=(0, 0, 0)):
  function posterize_func (line 225) | def posterize_func(img, bits):
  function shear_y_func (line 233) | def shear_y_func(img, factor, fill=(0, 0, 0)):
  function cutout_func (line 242) | def cutout_func(img, pad_size, replace=(0, 0, 0)):
  function enhance_level_to_args (line 256) | def enhance_level_to_args(MAX_LEVEL):
  function shear_level_to_args (line 263) | def shear_level_to_args(MAX_LEVEL, replace_value):
  function translate_level_to_args (line 273) | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
  function cutout_level_to_args (line 283) | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
  function solarize_level_to_args (line 291) | def solarize_level_to_args(MAX_LEVEL):
  function none_level_to_args (line 299) | def none_level_to_args(level):
  function posterize_level_to_args (line 303) | def posterize_level_to_args(MAX_LEVEL):
  function rotate_level_to_args (line 311) | def rotate_level_to_args(MAX_LEVEL, replace_value):
  class RandomAugment (line 359) | class RandomAugment(object):
    method __init__ (line 360) | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
    method get_random_ops (line 369) | def get_random_ops(self):
    method __call__ (line 373) | def __call__(self, img):
  class VideoRandomAugment (line 385) | class VideoRandomAugment(object):
    method __init__ (line 386) | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
    method get_random_ops (line 396) | def get_random_ops(self):
    method __call__ (line 400) | def __call__(self, frames):
    method _aug (line 419) | def _aug(self, img, ops, apply_or_not):
  class BlipImageTrainProcessor (line 438) | class BlipImageTrainProcessor(BlipImageBaseProcessor):
    method __init__ (line 439) | def __init__(
    method __call__ (line 474) | def __call__(self, item):
  class BlipImageEvalProcessor (line 478) | class BlipImageEvalProcessor(BlipImageBaseProcessor):
    method __init__ (line 479) | def __init__(self, image_size=384, mean=None, std=None):
    method __call__ (line 492) | def __call__(self, item):

FILE: GOT-OCR-2.0-master/GOT/model/vision_encoder/vary_b.py
  class Projector (line 24) | class Projector(nn.Module):
    method __init__ (line 25) | def __init__(
    method forward (line 45) | def forward(self, x: torch.Tensor):
  class MLPBlock (line 53) | class MLPBlock(nn.Module):
    method __init__ (line 54) | def __init__(
    method forward (line 65) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class LayerNorm2d (line 71) | class LayerNorm2d(nn.Module):
    method __init__ (line 72) | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
    method forward (line 78) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class ImageEncoderViT (line 87) | class ImageEncoderViT(nn.Module):
    method __init__ (line 88) | def __init__(
    method forward (line 180) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Block (line 196) | class Block(nn.Module):
    method __init__ (line 199) | def __init__(
    method forward (line 243) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class Attention (line 262) | class Attention(nn.Module):
    method __init__ (line 265) | def __init__(
    method forward (line 301) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function window_partition (line 320) | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.T...
  function window_unpartition (line 344) | def window_unpartition(
  function get_rel_pos (line 369) | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torc...
  function add_decomposed_rel_pos (line 402) | def add_decomposed_rel_pos(
  class PatchEmbed (line 441) | class PatchEmbed(nn.Module):
    method __init__ (line 446) | def __init__(
    method forward (line 468) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  function build_vary_vit_b (line 476) | def build_vary_vit_b(checkpoint=None):
  function _build_vary (line 486) | def _build_vary(

FILE: GOT-OCR-2.0-master/GOT/train/train.py
  function train (line 33) | def train():

FILE: GOT-OCR-2.0-master/GOT/train/train_GOT.py
  function train (line 38) | def train():

FILE: GOT-OCR-2.0-master/GOT/train/train_lora.py
  function train (line 45) | def train():

FILE: GOT-OCR-2.0-master/GOT/train/trainer.py
  function unwrap_model (line 9) | def unwrap_model(model: nn.Module) -> nn.Module:
  class GOTTrainer (line 23) | class GOTTrainer(Trainer):
    method _safe_save (line 25) | def _safe_save(self, output_dir: str):
    method _save (line 42) | def _save(self, output_dir: Optional[str] = None, state_dict=None):

FILE: GOT-OCR-2.0-master/GOT/train/trainer_llm_llrd.py
  function lr_scale_func (line 28) | def lr_scale_func(key):
  function get_param_groups (line 45) | def get_param_groups(model, no_weight_decay_cond, scale_lr_cond):
  function unwrap_model (line 104) | def unwrap_model(model: nn.Module) -> nn.Module:
  class GOTTrainer (line 118) | class GOTTrainer(Trainer):
    method _safe_save (line 120) | def _safe_save(self, output_dir: str):
    method _save (line 132) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
    method create_optimizer (line 158) | def create_optimizer(self):
    method _wrap_model (line 193) | def _wrap_model(self, model, training=True, dataloader=None):

FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_fixlr.py
  function unwrap_model (line 11) | def unwrap_model(model: nn.Module) -> nn.Module:
  class GOTTrainer (line 25) | class GOTTrainer(Trainer):
    method _safe_save (line 27) | def _safe_save(self, output_dir: str):
    method _save (line 39) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
    method create_optimizer (line 65) | def create_optimizer(self):

FILE: GOT-OCR-2.0-master/GOT/train/trainer_vit_llrd.py
  function lr_scale_func (line 28) | def lr_scale_func(key):
  function get_param_groups (line 42) | def get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr, wd):
  function unwrap_model (line 101) | def unwrap_model(model: nn.Module) -> nn.Module:
  class GOTTrainer (line 115) | class GOTTrainer(Trainer):
    method _safe_save (line 117) | def _safe_save(self, output_dir: str):
    method _save (line 129) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
    method create_optimizer (line 155) | def create_optimizer(self):
    method _wrap_model (line 190) | def _wrap_model(self, model, training=True, dataloader=None):

FILE: GOT-OCR-2.0-master/GOT/utils/arguments.py
  class ModelArguments (line 7) | class ModelArguments:
  class DataArguments (line 19) | class DataArguments:
  class TrainingArguments (line 34) | class TrainingArguments(transformers.TrainingArguments):

FILE: GOT-OCR-2.0-master/GOT/utils/conversation.py
  class SeparatorStyle (line 6) | class SeparatorStyle(Enum):
  class Conversation (line 43) | class Conversation:
    method get_prompt (line 56) | def get_prompt(self):
    method append_message (line 113) | def append_message(self, role, message):
    method get_images (line 116) | def get_images(self, return_pil=False):
    method to_gradio_chatbot (line 167) | def to_gradio_chatbot(self):
    method copy (line 197) | def copy(self):
    method dict (line 207) | def dict(self):

FILE: GOT-OCR-2.0-master/GOT/utils/utils.py
  function build_logger (line 18) | def build_logger(logger_name, logger_filename):
  class StreamToLogger (line 61) | class StreamToLogger(object):
    method __init__ (line 65) | def __init__(self, logger, log_level=logging.INFO):
    method __getattr__ (line 71) | def __getattr__(self, attr):
    method write (line 74) | def write(self, buf):
    method flush (line 88) | def flush(self):
  function disable_torch_init (line 94) | def disable_torch_init():
  function violates_moderation (line 103) | def violates_moderation(text):
  function pretty_print_semaphore (line 124) | def pretty_print_semaphore(semaphore):
  class KeywordsStoppingCriteria (line 130) | class KeywordsStoppingCriteria(StoppingCriteria):
    method __init__ (line 131) | def __init__(self, keywords, tokenizer, input_ids):
    method __call__ (line 139) | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTe...
  function smart_tokenizer_and_embedding_resize (line 153) | def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer,...
  function maybe_zero_3 (line 179) | def maybe_zero_3(param, ignore_status=False, name=None):
  function get_peft_state_maybe_zero_3 (line 194) | def get_peft_state_maybe_zero_3(named_params, bias):
  function get_peft_state_non_lora_maybe_zero_3 (line 219) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
  function find_all_linear_names (line 227) | def find_all_linear_names(model):
Condensed preview — 41 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (242K chars).
[
  {
    "path": "GOT-OCR-2.0-master/GOT/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/data/__init__.py",
    "chars": 2369,
    "preview": "\nimport torch\nimport transformers\nfrom dataclasses import dataclass, field\n\nfrom GOT.utils.constants import *\n\n\n@datacla"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/data/base_dataset.py",
    "chars": 2868,
    "preview": "import io\nimport os\nimport copy\nimport json\nimport logging\nimport torch\nimport transformers\nimport boto3\nfrom typing imp"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/data/conversation_dataset_qwen.py",
    "chars": 11381,
    "preview": "\nimport io\nimport os\nimport copy\nimport json\nimport logging\nimport torch\nimport random\n\nfrom typing import List, Optiona"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/demo/process_results.py",
    "chars": 617,
    "preview": "import string\n\npunctuation_dict = {\n    \",\": \",\",\n    \"。\": \".\",\n\n}\n\n\n# import os\n \ndef svg_to_html(svg_content, output_f"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py",
    "chars": 8186,
    "preview": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\nfrom GOT.utils.conve"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0_crop.py",
    "chars": 7990,
    "preview": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\nfrom GOT.utils.conve"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/eval_GOT_ocr.py",
    "chars": 11244,
    "preview": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\n\nfrom tqdm import tq"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/evaluate_GOT.py",
    "chars": 2089,
    "preview": "import os\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model-name\", type=str, default=\"fac"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/multi_hardware_eval_GOT.py",
    "chars": 1941,
    "preview": "import os\nimport argparse\nfrom multiprocessing import Pool\n# from GOT.eval.merge_results import merge_outputs\n# from GOT"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/__init__.py",
    "chars": 18,
    "preview": "author='aagrawal'\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr.py",
    "chars": 6811,
    "preview": "import json\n# from doctextVQAeval import VQAEval\n\nimport argparse\n# import fitz as pymupdf\nimport nltk\nfrom nltk.metrics"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_format.py",
    "chars": 6811,
    "preview": "import json\n# from doctextVQAeval import VQAEval\n\nimport argparse\n# import fitz as pymupdf\nimport nltk\nfrom nltk.metrics"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_scene.py",
    "chars": 2410,
    "preview": "import json\nimport argparse\nimport nltk\nfrom nltk.metrics import precision, recall, f_measure\nimport numpy as np\nimport "
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/merge_results.py",
    "chars": 593,
    "preview": "import os\nimport json\nimport argparse\n\ndef merge_outputs(out_path):\n    files = os.listdir(out_path)\n    # print(files)\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/GOT_ocr_2_0.py",
    "chars": 16431,
    "preview": "from transformers import AutoConfig, AutoModelForCausalLM, \\\n                         Qwen2Config, Qwen2Model, Qwen2ForC"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/__init__.py",
    "chars": 71,
    "preview": "\nfrom .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/plug/blip_process.py",
    "chars": 14312,
    "preview": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full lic"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/vision_encoder/__init__.py",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/vision_encoder/vary_b.py",
    "chars": 18874,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train.py",
    "chars": 8652,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_GOT.py",
    "chars": 5922,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_flash_attn.py",
    "chars": 494,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_lora.py",
    "chars": 9411,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_lora_flash_attn.py",
    "chars": 535,
    "preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer.py",
    "chars": 2381,
    "preview": "import os\nimport torch\nimport torch.nn as nn\n\nfrom transformers import Trainer\nfrom typing import Dict, Optional, Sequen"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer_llm_llrd.py",
    "chars": 17657,
    "preview": "import os\nimport torch\nimport torch.nn as nn\nimport time\nimport functools\nimport re\n\nfrom transformers import Trainer\nfr"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer_vit_fixlr.py",
    "chars": 4628,
    "preview": "import os\nimport torch\nimport torch.nn as nn\n\nfrom transformers import Trainer\nfrom transformers.trainer_pt_utils import"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer_vit_llrd.py",
    "chars": 17702,
    "preview": "import os\nimport torch\nimport torch.nn as nn\nimport time\nimport functools\nimport re\n\nfrom transformers import Trainer\nfr"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/arguments.py",
    "chars": 1984,
    "preview": "from dataclasses import dataclass, field\nfrom typing import Dict, Optional, Sequence\nimport transformers\n\n\n@dataclass\ncl"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/constants.py",
    "chars": 733,
    "preview": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \"log\"\n\nIGNORE_INDEX = -100\n# DEFAULT_PAD"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/conversation.py",
    "chars": 18345,
    "preview": "import dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple\n\n\nclass SeparatorStyle(Enum):\n    \"\"\"Diffe"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/utils.py",
    "chars": 8494,
    "preview": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\nimport torch\nimport requests\n\nfrom transform"
  },
  {
    "path": "GOT-OCR-2.0-master/pyproject.toml",
    "chars": 1038,
    "preview": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"GOT\"\nversion ="
  },
  {
    "path": "GOT-OCR-2.0-master/pyvenv.cfg",
    "chars": 206,
    "preview": "home = /usr/bin\nimplementation = CPython\nversion_info = 3.8.10.final.0\nvirtualenv = 20.16.7\ninclude-system-site-packages"
  },
  {
    "path": "GOT-OCR-2.0-master/render_tools/content-mmd-to-html.html",
    "chars": 909,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\" data-lt-installed=\"true\"><head>\n  <meta charset=\"UTF-8\">\n  <title>Title</title>\n  <scrip"
  },
  {
    "path": "GOT-OCR-2.0-master/render_tools/tikz.html",
    "chars": 383,
    "preview": "<!DOCTYPE html>\r\n\r\n<html>\r\n\r\n<head>\r\n<meta charset=\"UTF-8\">\r\n<meta name=\"viewport\" content=\"width=device-width, initial-"
  },
  {
    "path": "GOT-OCR-2.0-master/results/demo.html",
    "chars": 1752,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\" data-lt-installed=\"true\"><head>\n  <meta charset=\"UTF-8\">\n  <title>Title</title>\n  <scrip"
  },
  {
    "path": "GOT-OCR-2.0-master/zero_config/zero2.json",
    "chars": 295,
    "preview": "{\r\n    \"bf16\": {\r\n        \"enabled\": true\r\n    },\r\n    \"train_micro_batch_size_per_gpu\": \"auto\",\r\n    \"zero_optimization"
  },
  {
    "path": "GOT-OCR-2.0-master/zero_config/zero3.json",
    "chars": 828,
    "preview": "{\r\n    \"fp16\": {\r\n        \"enabled\": \"auto\",\r\n        \"loss_scale\": 0,\r\n        \"loss_scale_window\": 1000,\r\n        \"ini"
  },
  {
    "path": "README.md",
    "chars": 12494,
    "preview": "<h3><a href=\"\">General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model</a></h3>\n\n<a href=\"https://huggingface"
  }
]

About this extraction

This page contains the full source code of the Ucas-HaoranWei/GOT-OCR2.0 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 41 files (224.5 KB), approximately 56.5k tokens, and a symbol index with 183 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!