[
  {
    "path": "GOT-OCR-2.0-master/GOT/__init__.py",
    "content": ""
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/data/__init__.py",
    "content": "\nimport torch\nimport transformers\nfrom dataclasses import dataclass, field\n\nfrom GOT.utils.constants import *\n\n\n@dataclass\nclass DataCollatorForSupervisedDataset(object):\n    tokenizer: transformers.PreTrainedTokenizer\n\n    def __call__(self, instances):\n        # print(instances)\n        # exit()\n        input_ids, labels = tuple([instance[key] for instance in instances] for key in (\"input_ids\", \"labels\"))\n        images = [torch.stack(instance['image']) for instance in instances]\n\n        # if 'flattened_patches' in instances[0]['image_high'][0].keys():\n        #     images_high = [torch.stack([instance['image_high'][0]['flattened_patches']]) for instance in instances]\n        # else:\n        images_high = [torch.stack(instance['image_high']) for instance in instances]\n\n        images = list(zip(images, images_high))\n\n\n        input_ids = torch.nn.utils.rnn.pad_sequence(\n            input_ids,\n            batch_first=True,\n            padding_value=self.tokenizer.pad_token_id)\n            \n        labels = torch.nn.utils.rnn.pad_sequence(\n            labels,\n            batch_first=True,\n            padding_value=IGNORE_INDEX)\n        \n        batch = dict(\n            input_ids=input_ids,\n            labels=labels,\n            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n            images=images,\n        )\n        return batch\n    \n\ndef make_supervised_data_module(interleave, with_box, tokenizer, data_args):\n\n    if data_args.conversation_version == 'mpt':\n        from GOT.data.conversation_dataset_qwen import ConversationDataset\n        dataset_cls = ConversationDataset\n        \n    train_dataset = dataset_cls(\n        tokenizer=tokenizer,\n        datasets=data_args.datasets,\n        multimodal_cfg=dict(\n            sep_image_conv_front=data_args.sep_image_conv_front,\n            image_token_len=data_args.image_token_len,\n            
image_aspect_ratio=data_args.image_aspect_ratio,\n            use_im_start_end=data_args.use_im_start_end,\n            image_processor=data_args.image_processor,\n            image_processor_high = data_args.image_processor_high,\n            box_limit=data_args.box_limit,\n        )\n    )\n    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n    return dict(train_dataset=train_dataset,\n                eval_dataset=None,\n                data_collator=data_collator)"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/data/base_dataset.py",
    "content": "import io\nimport os\nimport copy\nimport json\nimport logging\nimport torch\nimport transformers\nimport boto3\nfrom typing import List, Optional, Tuple, Union, Dict, Sequence\nfrom torch.utils.data import Dataset\nfrom PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\nfrom GOT.utils.constants import *\n\n\n\nclass BaseDataset(Dataset):\n    def __init__(\n        self, \n        datasets: str,\n        tokenizer: transformers.PreTrainedTokenizer,\n        multimodal_cfg: dict\n    ):\n        super(BaseDataset, self).__init__()\n        self.tokenizer = tokenizer\n        self.multimodal_cfg = multimodal_cfg\n\n        logging.warning(f\"Using {multimodal_cfg['image_token_len']} tokens for representing image\")\n\n    def image_processor(self, image):\n        # processor = self.multimodal_cfg['image_processor']  # the first processor, usually is the clip pretrained model (vit)\n        processor_high = self.multimodal_cfg['image_processor_high'] # the second processor, usually is the designed image encoder (sam/swin/cnn)\n        image_high = image.copy()\n\n        #  Vary old codes\n        \n        # # TODO the 'keep', 'padding' only used for the first processor\n        # if self.multimodal_cfg['image_aspect_ratio'] == 'keep':\n        #     max_hw, min_hw = max(image.size), min(image.size)\n        #     aspect_ratio = max_hw / min_hw\n        #     max_len, min_len = 448, 224\n        #     shortest_edge = int(min(max_len / aspect_ratio, min_len))\n        #     image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={\"shortest_edge\": shortest_edge})['pixel_values'][0]\n        # elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':\n        #     def expand2square(pil_img, background_color):\n        #         width, height = pil_img.size\n        #         if width == height:\n        #             return pil_img\n        #         elif width > height:\n        #             result = 
Image.new(pil_img.mode, (width, width), background_color)\n        #             result.paste(pil_img) # for simpler box processing\n        #             return result\n        #         else:\n        #             result = Image.new(pil_img.mode, (height, height), background_color)\n        #             result.paste(pil_img) # for simpler box processing\n        #             return result\n        #     image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))\n        #     image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={\"shortest_edge\": 224})['pixel_values'][0]\n        # else:\n        #     image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n\n        image_high = processor_high(image_high)\n\n        return image_high\n\n    def __len__(self):\n        return len(self.list_data_dict)\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        pass"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/data/conversation_dataset_qwen.py",
    "content": "\nimport io\nimport os\nimport copy\nimport json\nimport logging\nimport torch\nimport random\n\nfrom typing import List, Optional, Tuple, Union, Dict, Sequence\nfrom PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\nfrom GOT.data.base_dataset import BaseDataset\nfrom GOT.utils.constants import *\nfrom GOT.utils import conversation as conversation_lib\nimport boto3\nimport smart_open\nfrom megfile import smart_glob\nfrom natsort import natsorted\n\n\nclass ConversationDataset(BaseDataset):\n    \"\"\"Conversation format dataset stage2 fine-tuning.\"\"\"\n\n    def __init__(self, datasets, tokenizer, multimodal_cfg):\n        super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)\n        # v0 version format conversation\n        conversation_lib.default_conversation = conversation_lib.conv_templates[\"mpt\"]\n        logging.warning(\"Formatting inputs into conversation type: mpt-fixed\")\n        logging.warning(\"Loading data...\")\n\n        list_data_dict = []\n        list_image_path = []\n\n        # TODO add your data  [data1, data2, data3, .....]\n        got_data_dict = {\n            \"pdf-ocr\": [\"data1\", \"data2\"],\n            'scene-ocr': [\"data3\", \"data4\"]\n            # ......\n        }\n        for name_all in datasets.split(\"+\"):\n            for name in got_data_dict[name_all]:\n                dataset = CONVERSATION_DATA[name]\n\n                data_path = dataset['annotations']\n                data = json.load(open(data_path, \"r\"))\n\n                list_data_dict.extend(data)\n\n                image_path = dataset['images']\n\n                list_image_path.extend([image_path] * len(data))\n\n                logging.warning(f\"Data from {data_path} provide {len(data)} conversations.\")\n\n        assert len(list_data_dict) == len(list_image_path)\n        logging.warning(f\"{len(list_data_dict)} conversations in total.\")\n        a_new_list = 
list(zip(list_data_dict, list_image_path))\n        random.shuffle(a_new_list)\n        list_data_dict_new, list_image_path_new = zip(*a_new_list)\n        self.list_data_dict = list_data_dict_new\n        self.list_image_path = list_image_path_new\n\n        self.im_patch_token = 151859\n\n        self.im_start_token = 151857\n\n        self.im_end_token = 151858\n    \n    def multimodal_processor(self, sources, flag_num_patches):\n        for source in sources:\n            if self.multimodal_cfg['sep_image_conv_front']:\n                assert DEFAULT_IMAGE_TOKEN in source[0]['value']\n                source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()\n                source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + \": \" + source[0]['value']\n\n            for sentence in source:\n                replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']*flag_num_patches\n                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n                # sentence[\"value\"] = str(sentence[\"value\"]).replace('\\qquad', '\\quad')\n                sentence[\"value\"] = str(sentence[\"value\"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)\n        return sources\n\n    def _tokenize_fn(self, strings):\n        \"\"\"Tokenize a list of strings.\"\"\"\n        tokenized_list = [\n            self.tokenizer(\n                text,\n                return_tensors=\"pt\",\n                padding=\"longest\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n            ) for text in strings\n        ]\n        input_ids = labels = [\n            tokenized.input_ids[0] for tokenized in tokenized_list\n        ]\n        input_ids_lens = labels_lens = [\n            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()\n            for tokenized 
in tokenized_list\n        ]\n        return dict(\n            input_ids=input_ids,\n            labels=labels,\n            input_ids_lens=input_ids_lens,\n            labels_lens=labels_lens,\n        )\n\n    def _mask_targets(self, target, tokenized_lens, speakers):\n        # cur_idx = 0\n        cur_idx = tokenized_lens[0]\n        tokenized_lens = tokenized_lens[1:]\n        target[:cur_idx] = IGNORE_INDEX\n        for tokenized_len, speaker in zip(tokenized_lens, speakers):\n            if speaker.lower() == \"human\":\n                target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX\n            cur_idx += tokenized_len\n\n    def token_processor(self, sources, image_name):\n        conv = conversation_lib.default_conversation.copy()\n        roles = {\"human\": conv.roles[0], \"gpt\": conv.roles[1]}\n\n        # Apply prompt templates\n        conversations = []\n        for i, source in enumerate(sources):\n            if roles[source[0][\"from\"]] != conv.roles[0]:\n                # Skip the first one if it is not from human\n                source = source[1:]\n\n            conv.messages = []\n            for j, sentence in enumerate(source):\n                role = roles[sentence[\"from\"]]\n                assert role == conv.roles[j % 2], f\"{i}\"\n                conv.append_message(role, sentence[\"value\"])\n            conversations.append(conv.get_prompt())\n\n        # Tokenize conversations\n\n\n        input_ids = self.tokenizer(\n            conversations,\n            return_tensors=\"pt\",\n            padding=\"longest\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n        ).input_ids\n\n        # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)\n        targets = input_ids.clone()\n        assert conv.sep_style == conversation_lib.SeparatorStyle.MPT\n\n        # Mask targets\n        sep = conv.sep + 
conv.roles[1]\n        for conversation, target in zip(conversations, targets):\n            total_len = int(target.ne(self.tokenizer.pad_token_id).sum())\n\n            rounds = conversation.split(conv.sep)\n            re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt\n            for conv_idx in range(3, len(rounds), 2):\n                re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt\n            cur_len = 0\n            target[:cur_len] = IGNORE_INDEX\n            for i, rou in enumerate(re_rounds):\n                if rou == \"\":\n                    break\n\n                parts = rou.split(sep)\n                if len(parts) != 2:\n                    break\n                parts[0] += sep\n                round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)\n                # round_len = len(tokenizer_image_token(rou, self.tokenizer)) + len(tokenizer_image_token(conv.sep, self.tokenizer))\n                # instruction_len = len(tokenizer_image_token(parts[0], tokenizer))\n                instruction_len = len(self.tokenizer(parts[0]).input_ids)\n                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX\n\n                cur_len += round_len\n            target[cur_len:] = IGNORE_INDEX\n\n            if cur_len < self.tokenizer.model_max_length:\n                if cur_len != total_len:\n                    target[:] = IGNORE_INDEX\n                    print(\n                        f\"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}.\"\n                        f\" (ignored)\"\n                    )\n                    print(image_name)\n\n        return dict(\n            input_ids=input_ids,\n            labels=targets,\n        )\n\n    def __getitem__(self, i) -> Dict[str, torch.Tensor]:\n        # data = self.list_data_dict[i]\n        data = copy.deepcopy(self.list_data_dict[i])\n\n        if isinstance(data, dict):\n            image_list =  []\n            image_high_list = []\n            flag_num_patches = 1\n            if 'image' in data:\n                image_path = self.list_image_path[i]\n                image_file = data['image']\n\n                # multi-crop or multi page, only support .png files\n                if ('.jpg' not in image_file and '.png' not in image_file and '.jpeg' not in image_file) and ('.jpg' not in image_path and '.png' not in image_path and '.jpeg' not in image_path):\n                    if image_file[0] == '/':\n                        patch_dir = image_path[:-1] + image_file\n                        patches = smart_glob(patch_dir + '*.png')\n                    else:\n                        patch_dir = image_path + image_file\n                        patches = smart_glob(patch_dir + '*.png')\n\n                    # print(patches)\n                    if not patches:\n                        print(f'cannot glob the dir {patch_dir}.')\n                        return self.__getitem__(0)\n\n                    # sort multi images by name\n                    patches = natsorted(patches)\n                    flag_num_patches = len(patches)\n\n                    for patch in patches:\n                        try:\n                            image = Image.open(patch).convert('RGB')\n                        except:\n                            print(f'cannot identify image file {patch}.')\n                            return self.__getitem__(0)\n\n                        try:\n                            img = self.image_processor(image)\n    
                        image_list.append(img)\n                            image_high_list.append(img)\n                        except:\n                            print(f'image {image_path + image_file + patch} are broken or grayscale! we thus select 0-th sample instead!')\n                            return self.__getitem__(0)\n\n                else:\n                    flag_num_patches = 1\n                    try:\n                        image = Image.open(image_path + image_file).convert('RGB')\n                    except:\n                        print(f'cannot identify image file {image_file}.')\n                        return self.__getitem__(0)\n\n                    try:\n                        image = self.image_processor(image)\n                    except:\n                        print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')\n                        return self.__getitem__(0)\n\n            conversations = self.multimodal_processor([data[\"conversations\"]], flag_num_patches)\n            # print(conversations)\n            # exit()\n        else:\n            conversations = [data]\n\n        # align with fastchat & llava here, put the conversation into a list for tokenization\n        image_name = image_path + image_file\n        data_dict = self.token_processor(conversations, image_name)\n        data_dict = dict(input_ids=data_dict[\"input_ids\"][0], labels=data_dict[\"labels\"][0])\n        \n        if isinstance(data, dict) and 'image' in data:\n            if image_list and image_high_list:\n                data_dict['image'] = image_list\n                data_dict['image_high'] = image_high_list\n            else:\n                data_dict['image'] = [image]\n                data_dict['image_high'] = [image]\n        else:\n            # crop_size = self.multimodal_cfg['image_processor'].crop_size\n            # data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]\n   
         # Vary for two image, GOT does not use the data_dict['image]\n            data_dict['image'] = [torch.zeros(3, 1024, 1024)]\n            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]\n        return data_dict\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/demo/process_results.py",
    "content": "import string\n\npunctuation_dict = {\n    \"，\": \",\",\n    \"。\": \".\",\n\n}\n\n\n# import os\n \ndef svg_to_html(svg_content, output_filename):\n\n    html_content = f\"\"\"\n    <!DOCTYPE html>\n    <html lang=\"en\">\n    <head>\n        <meta charset=\"UTF-8\">\n        <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n        <title>SVG Embedded in HTML</title>\n    </head>\n    <body>\n        <svg width=\"2100\" height=\"15000\" xmlns=\"http://www.w3.org/2000/svg\">\n            {svg_content}\n        </svg>\n    </body>\n    </html>\n    \"\"\"\n\n    with open(output_filename, 'w') as file:\n        file.write(html_content)\n \n\n "
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0.py",
    "content": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\nfrom GOT.utils.conversation import conv_templates, SeparatorStyle\nfrom GOT.utils.utils import disable_torch_init\nfrom transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria\nfrom GOT.model import *\nfrom GOT.utils.utils import KeywordsStoppingCriteria\n\nfrom PIL import Image\n\nimport os\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nfrom GOT.model.plug.blip_process import BlipImageEvalProcessor\n\nfrom transformers import TextStreamer\nimport re\nfrom GOT.demo.process_results import punctuation_dict, svg_to_html\nimport string\n\nDEFAULT_IMAGE_TOKEN = \"<image>\"\nDEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'\n\nDEFAULT_IM_START_TOKEN = '<img>'\nDEFAULT_IM_END_TOKEN = '</img>'\n\n\n \ntranslation_table = str.maketrans(punctuation_dict)\n\n\ndef load_image(image_file):\n    if image_file.startswith('http') or image_file.startswith('https'):\n        response = requests.get(image_file)\n        image = Image.open(BytesIO(response.content)).convert('RGB')\n    else:\n        image = Image.open(image_file).convert('RGB')\n    return image\n\n\ndef eval_model(args):\n    # Model\n    disable_torch_init()\n    model_name = os.path.expanduser(args.model_name)\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n\n\n    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()\n\n    \n\n    model.to(device='cuda',  dtype=torch.bfloat16)\n\n\n    # TODO vary old codes, NEED del \n    image_processor = BlipImageEvalProcessor(image_size=1024)\n\n    image_processor_high =  BlipImageEvalProcessor(image_size=1024)\n\n    use_im_start_end = True\n\n    image_token_len = 256\n\n    image = load_image(args.image_file)\n\n    w, h = image.size\n    # print(image.size)\n    \n    if args.type == 'format':\n    
    qs = 'OCR with format: '\n    else:\n        qs = 'OCR: '\n\n    if args.box:\n        bbox = eval(args.box)\n        if len(bbox) == 2:\n            bbox[0] = int(bbox[0]/w*1000)\n            bbox[1] = int(bbox[1]/h*1000)\n        if len(bbox) == 4:\n            bbox[0] = int(bbox[0]/w*1000)\n            bbox[1] = int(bbox[1]/h*1000)\n            bbox[2] = int(bbox[2]/w*1000)\n            bbox[3] = int(bbox[3]/h*1000)\n        if args.type == 'format':\n            qs = str(bbox) + ' ' + 'OCR with format: '\n        else:\n            qs = str(bbox) + ' ' + 'OCR: '\n\n    if args.color:\n        if args.type == 'format':\n            qs = '[' + args.color + ']' + ' ' + 'OCR with format: '\n        else:\n            qs = '[' + args.color + ']' + ' ' + 'OCR: '\n\n    if use_im_start_end:\n        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\\n' + qs \n    else:\n        qs = DEFAULT_IMAGE_TOKEN + '\\n' + qs\n\n\n\n    conv_mode = \"mpt\"\n    args.conv_mode = conv_mode\n\n    conv = conv_templates[args.conv_mode].copy()\n    conv.append_message(conv.roles[0], qs)\n    conv.append_message(conv.roles[1], None)\n    prompt = conv.get_prompt()\n\n    print(prompt)\n\n\n    inputs = tokenizer([prompt])\n\n\n    # vary old codes, no use\n    image_1 = image.copy()\n    image_tensor = image_processor(image)\n\n\n    image_tensor_1 = image_processor_high(image_1)\n\n\n    input_ids = torch.as_tensor(inputs.input_ids).cuda()\n\n    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n    keywords = [stop_str]\n    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n\n\n    with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n        output_ids = model.generate(\n            input_ids,\n            images=[(image_tensor.unsqueeze(0).half().cuda(), 
image_tensor_1.unsqueeze(0).half().cuda())],\n            do_sample=False,\n            num_beams = 1,\n            no_repeat_ngram_size = 20,\n            streamer=streamer,\n            max_new_tokens=4096,\n            stopping_criteria=[stopping_criteria]\n            )\n        \n\n        if args.render:\n            print('==============rendering===============')\n\n            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n            \n            if outputs.endswith(stop_str):\n                outputs = outputs[:-len(stop_str)]\n            outputs = outputs.strip()\n\n            if '**kern' in outputs:\n                import verovio\n                from cairosvg import svg2png\n                import cv2\n                import numpy as np\n                tk = verovio.toolkit()\n                tk.loadData(outputs)\n                tk.setOptions({\"pageWidth\": 2100, \"footer\": 'none',\n               'barLineWidth': 0.5, 'beamMaxSlope': 15,\n               'staffLineWidth': 0.2, 'spacingStaff': 6})\n                tk.getPageCount()\n                svg = tk.renderToSVG()\n                svg = svg.replace(\"overflow=\\\"inherit\\\"\", \"overflow=\\\"visible\\\"\")\n\n                svg_to_html(svg, \"./results/demo.html\")\n\n            if args.type == 'format' and '**kern' not in outputs:\n\n                \n                if  '\\\\begin{tikzpicture}' not in outputs:\n                    html_path = \"./render_tools/\" + \"/content-mmd-to-html.html\"\n                    html_path_2 = \"./results/demo.html\"\n                    right_num = outputs.count('\\\\right')\n                    left_num = outputs.count('\\left')\n\n                    if right_num != left_num:\n                        outputs = outputs.replace('\\left(', '(').replace('\\\\right)', ')').replace('\\left[', '[').replace('\\\\right]', ']').replace('\\left{', '{').replace('\\\\right}', '}').replace('\\left|', '|').replace('\\\\right|', 
'|').replace('\\left.', '.').replace('\\\\right.', '.')\n\n\n                    outputs = outputs.replace('\"', '``').replace('$', '')\n\n                    outputs_list = outputs.split('\\n')\n                    gt= ''\n                    for out in outputs_list:\n                        gt +=  '\"' + out.replace('\\\\', '\\\\\\\\') + r'\\n' + '\"' + '+' + '\\n' \n                    \n                    gt = gt[:-2]\n\n                    with open(html_path, 'r') as web_f:\n                        lines = web_f.read()\n                        lines = lines.split(\"const text =\")\n                        new_web = lines[0] + 'const text ='  + gt  + lines[1]\n                else:\n                    html_path = \"./render_tools/\" + \"/tikz.html\"\n                    html_path_2 = \"./results/demo.html\"\n                    outputs = outputs.translate(translation_table)\n                    outputs_list = outputs.split('\\n')\n                    gt= ''\n                    for out in outputs_list:\n                        if out:\n                            if '\\\\begin{tikzpicture}' not in out and '\\\\end{tikzpicture}' not in out:\n                                while out[-1] == ' ':\n                                    out = out[:-1]\n                                    if out is None:\n                                        break\n    \n                                if out:\n                                    if out[-1] != ';':\n                                        gt += out[:-1] + ';\\n'\n                                    else:\n                                        gt += out + '\\n'\n                            else:\n                                gt += out + '\\n'\n\n\n                    with open(html_path, 'r') as web_f:\n                        lines = web_f.read()\n                        lines = lines.split(\"const text =\")\n                        new_web = lines[0] + gt + lines[1]\n\n                with open(html_path_2, 
'w') as web_f_new:\n                    web_f_new.write(new_web)\n\n\n\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-name\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--image-file\", type=str, required=True)\n    parser.add_argument(\"--type\", type=str, required=True)\n    parser.add_argument(\"--box\", type=str, default= '')\n    parser.add_argument(\"--color\", type=str, default= '')\n    parser.add_argument(\"--render\", action='store_true')\n    args = parser.parse_args()\n\n    eval_model(args)\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/demo/run_ocr_2.0_crop.py",
    "content": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\nfrom GOT.utils.conversation import conv_templates, SeparatorStyle\nfrom GOT.utils.utils import disable_torch_init\nfrom transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria\nfrom GOT.model import *\nfrom GOT.utils.utils import KeywordsStoppingCriteria\n\nfrom PIL import Image\n\nimport os\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nfrom GOT.model.plug.blip_process import BlipImageEvalProcessor\nfrom transformers import TextStreamer\nfrom natsort import natsorted\nimport glob\n\n\n\n\nDEFAULT_IMAGE_TOKEN = \"<image>\"\nDEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'\nDEFAULT_IM_START_TOKEN = '<img>'\nDEFAULT_IM_END_TOKEN = '</img>'\n\n\n\ndef load_image(image_file):\n    if image_file.startswith('http') or image_file.startswith('https'):\n        response = requests.get(image_file)\n        image = Image.open(BytesIO(response.content)).convert('RGB')\n    else:\n        image = Image.open(image_file).convert('RGB')\n    return image\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n    best_ratio_diff = float('inf')\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                best_ratio = ratio\n    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')\n    return best_ratio\n\n\ndef dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the 
existing image aspect ratio\n    target_ratios = set(\n        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n        i * j <= max_num and i * j >= min_num)\n    # print(target_ratios)\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(\n        aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n\n    # print(target_aspect_ratio)\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = (\n            (i % (target_width // image_size)) * image_size,\n            (i // (target_width // image_size)) * image_size,\n            ((i % (target_width // image_size)) + 1) * image_size,\n            ((i // (target_width // image_size)) + 1) * image_size\n        )\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images\n\n\n\ndef eval_model(args):\n    # Model\n    disable_torch_init()\n    model_name = os.path.expanduser(args.model_name)\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n\n\n    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()\n\n\n\n    model.to(device='cuda',  dtype=torch.bfloat16)\n\n\n    # vary old codes, no use\n    
image_processor = BlipImageEvalProcessor(image_size=1024)\n\n    image_processor_high =  BlipImageEvalProcessor(image_size=1024)\n\n    use_im_start_end = True\n\n\n    image_token_len = 256\n\n\n\n\n    image_list = []\n\n    if args.multi_page:\n        qs = 'OCR with format across multi pages: '\n        # only for png files\n        patches = glob.glob(args.image_file + '/*png')\n        patches = natsorted(patches)\n        sub_images = []\n        for sub_image in patches:\n            sub_images.append(load_image(sub_image))\n\n        ll = len(patches)\n\n    else:\n        qs = 'OCR with format upon the patch reference: '\n        img = load_image(args.image_file)\n        sub_images = dynamic_preprocess(img)\n        ll = len(sub_images)\n\n    for p in sub_images:\n\n        image = p\n        image_1 = image.copy()\n        # no use, vary old codes\n        image_tensor = image_processor(image)\n\n\n        image_tensor_1 = image_processor_high(image_1)\n\n        image_list.append(image_tensor_1)\n\n\n    image_list = torch.stack(image_list)\n\n    print('====new images batch size======:  ',image_list.shape)\n\n\n\n\n\n    # qs = args.query\n    if use_im_start_end:\n        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\\n' + qs \n    else:\n        qs = DEFAULT_IMAGE_TOKEN + '\\n' + qs\n\n\n    \n\n    conv_mode = \"mpt\"\n    args.conv_mode = conv_mode\n\n    conv = conv_templates[args.conv_mode].copy()\n    conv.append_message(conv.roles[0], qs)\n    conv.append_message(conv.roles[1], None)\n    prompt = conv.get_prompt()\n\n\n    inputs = tokenizer([prompt])\n\n    input_ids = torch.as_tensor(inputs.input_ids).cuda()\n\n    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n    keywords = [stop_str]\n    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n\n\n  
  with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n        output_ids = model.generate(\n            input_ids,\n            images=[(image_list.half().cuda(), image_list.half().cuda())],\n            do_sample=False,\n            num_beams = 1,\n            # no_repeat_ngram_size = 20,\n            streamer=streamer,\n            max_new_tokens=4096,\n            stopping_criteria=[stopping_criteria]\n            )\n        \n    if args.render:\n        print('==============rendering===============')\n        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n        \n        if outputs.endswith(stop_str):\n            outputs = outputs[:-len(stop_str)]\n        outputs = outputs.strip()\n\n        html_path = \"./render_tools/\" + \"/content-mmd-to-html.html\"\n        html_path_2 = \"./results/demo.html\"\n        right_num = outputs.count('\\\\right')\n        left_num = outputs.count('\\left')\n\n        if right_num != left_num:\n            outputs = outputs.replace('\\left(', '(').replace('\\\\right)', ')').replace('\\left[', '[').replace('\\\\right]', ']').replace('\\left{', '{').replace('\\\\right}', '}').replace('\\left|', '|').replace('\\\\right|', '|').replace('\\left.', '.').replace('\\\\right.', '.')\n\n\n        outputs = outputs.replace('\"', '``').replace('$', '')\n\n        outputs_list = outputs.split('\\n')\n        gt= ''\n        for out in outputs_list:\n            gt +=  '\"' + out.replace('\\\\', '\\\\\\\\') + r'\\n' + '\"' + '+' + '\\n' \n        \n        gt = gt[:-2]\n\n        with open(html_path, 'r') as web_f:\n            lines = web_f.read()\n            lines = lines.split(\"const text =\")\n            new_web = lines[0] + 'const text ='  + gt  + lines[1]\n            \n        with open(html_path_2, 'w') as web_f_new:\n            web_f_new.write(new_web)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-name\", type=str, 
default=\"facebook/opt-350m\")\n    parser.add_argument(\"--image-file\", type=str, required=True)\n    parser.add_argument(\"--conv-mode\", type=str, default=None)\n    parser.add_argument(\"--multi-page\", action='store_true')\n    parser.add_argument(\"--render\", action='store_true')\n    args = parser.parse_args()\n\n    eval_model(args)\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/eval_GOT_ocr.py",
    "content": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\n\nfrom tqdm import tqdm\nfrom PIL import Image\nimport json\nimport os\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nimport math\n\nimport argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport torch\nimport os\nfrom GOT.utils.conversation import conv_templates, SeparatorStyle\nfrom GOT.utils.utils import disable_torch_init\nfrom transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria\nfrom GOT.model import *\nfrom GOT.utils.utils import KeywordsStoppingCriteria\n\nfrom PIL import Image\n\nimport os\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nfrom GOT.model.plug.blip_process import BlipImageEvalProcessor\n\nfrom transformers import TextStreamer\nfrom GOT.model.plug.transforms import train_transform, test_transform\nimport re\nfrom GOT.demo.process_results import punctuation_dict, svg_to_html\n\nDEFAULT_IMAGE_TOKEN = \"<image>\"\nDEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'\nDEFAULT_IM_START_TOKEN = '<img>'\nDEFAULT_IM_END_TOKEN = '</img>'\n\n\nimport string\n \ntranslation_table = str.maketrans(punctuation_dict)\n\n\ndef load_image(image_file):\n    if image_file.startswith('http') or image_file.startswith('https'):\n        response = requests.get(image_file)\n        image = Image.open(BytesIO(response.content)).convert('RGB')\n    else:\n        image = Image.open(image_file).convert('RGB')\n    return image\n\n\ndef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):\n    best_ratio_diff = float('inf')\n    best_ratio = (1, 1)\n    area = width * height\n    for ratio in target_ratios:\n        target_aspect_ratio = ratio[0] / ratio[1]\n        ratio_diff = abs(aspect_ratio - target_aspect_ratio)\n        if ratio_diff < best_ratio_diff:\n            best_ratio_diff = ratio_diff\n            best_ratio = ratio\n        elif ratio_diff == 
best_ratio_diff:\n            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:\n                best_ratio = ratio\n    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')\n    return best_ratio\n\n\ndef dynamic_preprocess(image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):\n    orig_width, orig_height = image.size\n    aspect_ratio = orig_width / orig_height\n\n    # calculate the existing image aspect ratio\n    target_ratios = set(\n        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if\n        i * j <= max_num and i * j >= min_num)\n    # print(target_ratios)\n    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])\n\n    # find the closest aspect ratio to the target\n    target_aspect_ratio = find_closest_aspect_ratio(\n        aspect_ratio, target_ratios, orig_width, orig_height, image_size)\n\n    # print(target_aspect_ratio)\n    # calculate the target width and height\n    target_width = image_size * target_aspect_ratio[0]\n    target_height = image_size * target_aspect_ratio[1]\n    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]\n\n    # print(blocks)\n\n    # resize the image\n    resized_img = image.resize((target_width, target_height))\n    processed_images = []\n    for i in range(blocks):\n        box = (\n            (i % (target_width // image_size)) * image_size,\n            (i // (target_width // image_size)) * image_size,\n            ((i % (target_width // image_size)) + 1) * image_size,\n            ((i // (target_width // image_size)) + 1) * image_size\n        )\n        # split the image\n        split_img = resized_img.crop(box)\n        processed_images.append(split_img)\n    assert len(processed_images) == blocks\n    if use_thumbnail and len(processed_images) != 1:\n        thumbnail_img = image.resize((image_size, image_size))\n        processed_images.append(thumbnail_img)\n    return processed_images\n\n\n\ndef 
split_list(lst, n):\n    \"\"\"Split a list into n (roughly) equal-sized chunks\"\"\"\n    chunk_size = math.ceil(len(lst) / n)  # integer division\n    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]\n\n\ndef get_chunk(lst, n, k):\n    chunks = split_list(lst, n)\n    return chunks[k]\n\n\n\noutput_list = []\n\ndef eval_model(args):\n    # Model\n    disable_torch_init()\n    model_name = os.path.expanduser(args.model_name)\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n\n\n    model = GOTQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=151643).eval()\n\n\n    # vary old codes, no use\n    image_processor = BlipImageEvalProcessor(image_size=1024)\n\n\n    # image_processor_high = BlipImageEvalProcessor(image_size=1280)\n    image_processor_high = BlipImageEvalProcessor(image_size=1024)\n    use_im_start_end = True\n\n\n\n    # image_token_len = 400\n    image_token_len = 256\n    gts_path = args.gtfile_path\n    gts = json.load(open(gts_path))\n\n    # gts = gts[0]\n\n\n    print(\"Generate Results......\")\n\n\n    if \"OCR\" in args.datatype:\n            gts = get_chunk(gts, args.num_chunks, args.chunk_idx)\n\n\n    for ann in tqdm(gts):\n        output_json = {}\n        \n        if \"OCR\" in args.datatype:\n            qs = ann[\"conversations\"][0][\"value\"]\n        else:\n            qs = ann[\"question\"]\n            # ans = ann[\"answers\"][0]\n        \n        qs2 = qs\n        image_file = ann[\"image\"] \n        if 'Text' in args.datatype:\n            image_file = image_file + '.jpg'\n        if \"VQAv2\" in args.datatype:\n            image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg'\n        if \"Cap\" in args.datatype:\n            image_file = 'COCO_' + 'val2014' + '_'+ str(image_file).zfill(12) + '.jpg'\n\n        image_file_path = os.path.join(args.image_path, image_file)\n        # 
print(image_file_path)\n        # exit()\n\n        # qs = args.query\n        # if mm_use_im_start_end:\n\n\n\n        multi_crop = False\n        if multi_crop:\n            image_list = []\n            # qs =  DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN  + '\\n' +  'OCR with format upon the patch reference: '\n            img = load_image(image_file_path)\n            sub_images = dynamic_preprocess(img)\n            ll = len(sub_images)\n            for p in sub_images:\n                image = p\n                image_1 = image.copy()\n                # vary old code, NO USE\n                image_tensor = image_processor_high(image_1)\n\n                # image_tensor_1 = image_processor_high.preprocess(image_1, return_tensors='pt')['pixel_values'][0]\n\n                image_tensor_1 = image_processor_high(image_1)\n\n                image_list.append(image_tensor_1)\n\n                # print(image_tensor_1.shape)\n\n            image_list = torch.stack(image_list)\n\n        else:\n            ll = 1\n            image = load_image(image_file_path)\n            image_1 = image.copy()\n            # image_1 = image_1.resize((1024, 1024))\n\n            # vary old code, NO USE\n            image_tensor = image_processor_high(image_1)\n\n            image_tensor_1 = image_processor_high(image_1)\n            # image_tensor_1 = torch.zeros(3, 1024, 1024)\n\n\n        qs =  DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len*ll + DEFAULT_IM_END_TOKEN  + '\\n' +  'OCR with format: '\n\n        \n\n        conv_mode = \"mpt\"\n\n        if args.conv_mode is not None and conv_mode != args.conv_mode:\n            print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))\n        else:\n            args.conv_mode = conv_mode\n\n        conv = conv_templates[args.conv_mode].copy()\n        
conv.append_message(conv.roles[0], qs)\n        conv.append_message(conv.roles[1], None)\n        prompt = conv.get_prompt()\n        inputs = tokenizer([prompt])\n\n        input_ids = torch.as_tensor(inputs.input_ids).cuda()\n\n        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n        keywords = [stop_str]\n        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n\n        if multi_crop:\n            with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n                output_ids = model.generate(\n                    input_ids,\n                    images=[(image_list.half().cuda(), image_list.half().cuda())],\n                    do_sample=False,\n                    num_beams = 1,\n                    # temperature=0.2,\n                    # no_repeat_ngram_size = 20,\n                    # streamer=streamer,\n                    max_new_tokens=4096,\n                    stopping_criteria=[stopping_criteria]\n                    )\n        else:\n            with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n                output_ids = model.generate(\n                    input_ids,\n                    images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],\n                    do_sample=False,\n                    num_beams = 1,\n                    # temperature=0.2,\n                    no_repeat_ngram_size = 20,\n                    # encoder_repetition_penalty = 1.2,\n                    # penalty_alpha=0.2,\n                    # top_k=3,\n                    max_new_tokens=4096,\n                    stopping_criteria=[stopping_criteria]\n                    )\n\n        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n        \n        if outputs.endswith(stop_str):\n            outputs = outputs[:-len(stop_str)]\n        outputs = outputs.strip()\n        # outputs = outputs.strip()[:-1]\n        if \"Cap\" in 
args.datatype:\n            # output_json['image'] = ann[\"image\"]\n            output_json['image_id'] = ann[\"id\"]\n            output_json[\"caption\"] = outputs\n        else:\n            # output_json['questionId'] = qs_id\n            # output_json['question_id'] = qs_id\n            output_json['image'] = ann[\"image\"]\n            output_json['question'] = qs \n            output_json['label'] = ann[\"conversations\"][1][\"value\"]\n            output_json['answer'] = outputs\n        output_list.append(output_json)\n\n    filename = args.out_path + \"/results_\" + str(args.chunk_idx) + \".json\"\n    with open(filename, 'w', encoding=\"utf-8\") as file_obj:\n        json.dump(output_list, file_obj, ensure_ascii=False, indent=1)\n        # print(outputs)\n    # print(\"Evaluate Results... \")\n    # doc_text_eval(gts_path, filename, args.datatype)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-name\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--gtfile_path\", type=str, required=True)\n    parser.add_argument(\"--image_path\", type=str, required=True)\n    parser.add_argument(\"--out_path\", type=str, required=True)\n    parser.add_argument(\"--datatype\", type=str, required=True)  # Text or Doc\n    parser.add_argument(\"--num-chunks\", type=int, default=1)\n    parser.add_argument(\"--chunk-idx\", type=int, default=0)\n    # parser.add_argument(\"--query\", type=str, required=True)\n    parser.add_argument(\"--conv-mode\", type=str, default=None)\n    parser.add_argument(\"--temperature\", type=float, default=0.2)\n    args = parser.parse_args()\n    print(args)\n    eval_model(args)\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/evaluate_GOT.py",
    "content": "import os\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--model-name\", type=str, default=\"facebook/opt-350m\")\nparser.add_argument(\"--gtfile_path\", type=str, required=True)\nparser.add_argument(\"--image_path\", type=str, required=True)\nparser.add_argument(\"--out_path\", type=str, required=True)\nparser.add_argument(\"--num-chunks\", type=int, default=1)\nparser.add_argument(\"--temperature\", type=float, default=0.2)\nparser.add_argument(\"--datatype\", type=str, required=True)  # Text\\Doc\\VQAv2\\Cap\n# parser.add_argument(\"--eval\", type=str, required=True)\nargs = parser.parse_args()\n\nos.system(\"python3 -m GOT.eval.multi_hardware_eval_GOT\" + \" \"\n          + \"--model-name\" + \" \" + args.model_name + \" \"\n          + \"--gtfile_path\" + \" \" + args.gtfile_path + \" \"\n          + \"--image_path\" + \" \" + args.image_path + \" \"\n          + \"--out_path\" + \" \" + args.out_path + \" \"\n          + \"--num-chunks\" + \" \" + str(args.num_chunks) + \" \"\n          + \"--temperature\" + \" \" + str(args.temperature) + \" \"\n          + \"--datatype\" + \" \" + args.datatype\n          )\n\nprint(\"Evaluating.....\")\nos.system(\"python3 -m GOT.eval.pyevaltools.merge_results\" + \" \"\n          + \"--out_path\" + \" \" + args.out_path)\n\n\n# if args.datatype == \"OCR\":\n\n\na_type = 'plain'  # 'palin'; 'format'; 'scene'\n\nif a_type == 'plain':\n    os.system(\"python3 -m GOT.eval.pyevaltools.eval_ocr\" + \" \"\n                + \"--out_path\" + \" \" + args.out_path + \" \"\n                + \"--gt_path\" + \" \" + args.gtfile_path + \" \"\n                + \"--datatype\" + \" \" + args.datatype\n                )\nif a_type == 'format':\n    os.system(\"python3 -m GOT.eval.pyevaltools.eval_ocr_format\" + \" \"\n            + \"--out_path\" + \" \" + args.out_path + \" \"\n            + \"--gt_path\" + \" \" + args.gtfile_path + \" \"\n            + \"--datatype\" + \" \" + 
args.datatype\n            )\nif a_type == 'scene':\n    os.system(\"python3 -m GOT.eval.pyevaltools.eval_ocr_scene\" + \" \"\n        + \"--out_path\" + \" \" + args.out_path + \" \"\n        + \"--gt_path\" + \" \" + args.gtfile_path + \" \"\n        + \"--datatype\" + \" \" + args.datatype\n        )"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/multi_hardware_eval_GOT.py",
    "content": "import os\nimport argparse\nfrom multiprocessing import Pool\n# from GOT.eval.merge_results import merge_outputs\n# from GOT.eval.doctextVQA import doc_text_eval\n\n\ndef run_eval(chunk_id, model_name, gtfile_path, image_path, out_path, num_chunks, datatype, temperature):\n    os.system(\"CUDA_VISIBLE_DEVICES=\" + str(chunk_id) + \" \"\n              + \"python3 -m GOT.eval.eval_GOT_ocr\" + \" \"\n              + \"--model-name\" + \" \" + model_name + \" \"\n              + \"--gtfile_path\" +  \" \" + gtfile_path + \" \"\n              + \"--image_path\" + \" \" + image_path + \" \"\n              + \"--out_path\" +  \" \" + out_path + \" \"\n              + \"--num-chunks\" + \" \" +  str(num_chunks) + \" \"\n              + \"--chunk-idx\" + \" \" +  str(chunk_id) + \" \"\n              + \"--temperature\" + \" \" +  str(temperature) + \" \"\n              + \"--datatype\" + \" \" + datatype\n              )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-name\", type=str, default=\"facebook/opt-350m\")\n    parser.add_argument(\"--gtfile_path\", type=str, required=True)\n    parser.add_argument(\"--image_path\", type=str, required=True)\n    parser.add_argument(\"--out_path\", type=str, required=True)\n    parser.add_argument(\"--num-chunks\", type=int, default=1)\n    parser.add_argument(\"--temperature\", type=float, default=0.2)\n    parser.add_argument(\"--datatype\", type=str, required=True)  # Text or Doc\n    # parser.add_argument(\"--eval\", type=str, required=True)\n    args = parser.parse_args()\n\n    num_chunks = args.num_chunks\n    \n    if os.path.exists(args.out_path) == False:\n        os.makedirs(args.out_path)\n\n\n    with Pool(num_chunks) as p:\n        for i in range(num_chunks):\n            chunk_id = i\n            p.apply_async(run_eval, (chunk_id, args.model_name, args.gtfile_path,\n                                     args.image_path, args.out_path, 
num_chunks, args.datatype, args.temperature))\n        p.close()\n        p.join()\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/__init__.py",
    "content": "author='aagrawal'\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr.py",
    "content": "import json\n# from doctextVQAeval import VQAEval\n\nimport argparse\n# import fitz as pymupdf\nimport nltk\nfrom nltk.metrics import precision, recall, f_measure\nimport numpy as np\nimport jieba\n# import megfile as mf\nimport pickle\nimport pandas as pd\nimport re\n# from loguru import logger\n# nltk.download('wordnet')\nfrom nltk.translate import meteor_score\n\n# from marker_scoring import score_text\n# from utils import contain_chinese_string\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--out_path\", type=str, required=True)\nparser.add_argument(\"--gt_path\", type=str, required=True)\nparser.add_argument(\"--datatype\", type=str, required=True)\nargs = parser.parse_args()\n\ndef preprocess(text, predict_root_):\n    if 'InternVL' in predict_root_:\n        text = text.split(\"All words in the image:\\n\")[1]\n        text = text.split(\"[UNUSED_TOKEN_145]\")[0]\n    return text\n\ndef contain_chinese_string(text):\n    # 使用正则表达式匹配中文字符\n    chinese_pattern = re.compile(r'[\\u4e00-\\u9fa5]')\n    return bool(chinese_pattern.search(text))\n\n\ninline_reg = re.compile(r\"\\\\\\((.*?)(?<!\\\\)\\\\\\)\")\ndisplay_reg = re.compile(r\"\\\\\\[(.+?)(?<!\\\\)\\\\\\]\")\ntable_reg = re.compile(r\"\\\\begin\\{tabular\\}(.+?)(?:\\\\end\\{tabular\\}|$)\", re.S)\n\ndef split_text(pages, a_type):\n    \"\"\"\n    Split a list of pages into text, inline math, display math, and table blocks.\n\n    Args:\n        pages: The pages to split.\n    \"\"\"\n    text, math, table = [], [], []\n    for page in pages:\n        for i, reg in enumerate([inline_reg, display_reg, table_reg]):\n            matches = \"\\n\".join(reg.findall(page[a_type]))\n            if i == 2:\n                table.append(matches)\n            elif i == 1:\n                math[-1] += matches\n            else:\n                math.append(matches)\n        page_str = page[a_type]\n        text.append(page_str.strip())\n    return text, math, table\n\ndef 
nougat_per_metrics(predict_root_, pred, gt, minlen=1, heavy_mode: int = 2):\n    \"\"\"\n    Args:\n    - heavy_mode:\n        0 is clean mode, only similar, bleu, f1\n        1 is normal, do not include edit_dist\n        2 is heavy, total\n    \"\"\"\n    metrics = {}\n\n    # pred = preprocess(pred, predict_root_)\n\n    if len(pred) < minlen or len(gt) < minlen:\n        return metrics\n\n    # metrics[\"similar\"] = score_text(pred, gt)\n    if contain_chinese_string(gt) or contain_chinese_string(pred):\n        reference = jieba.lcut(gt)\n        hypothesis = jieba.lcut(pred)\n    else:\n        reference = gt.split()\n        hypothesis = pred.split()\n\n    metrics[\"bleu\"] = nltk.translate.bleu([reference], hypothesis)\n    if heavy_mode >= 1:\n        # try:\n        metrics[\"meteor\"] = meteor_score.meteor_score([reference], hypothesis)\n        # except LookupError:\n        #     metrics[\"meteor\"] = np.nan\n\n    reference = set(reference)\n    hypothesis = set(hypothesis)\n    metrics[\"f_measure\"] = f_measure(reference, hypothesis)\n\n    if heavy_mode >= 1:\n        metrics[\"precision\"] = precision(reference, hypothesis)\n        metrics[\"recall\"] = recall(reference, hypothesis)\n    if heavy_mode == 2:\n        # 速度太慢\n        metrics[\"edit_dist\"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))\n    return metrics\n\ndef doc_formated_text_eval(gt_root_, predict_root_, datatype):\n\n    predicts = json.load(open(predict_root_, encoding='utf-8'))\n    \n    # print(predicts)\n\n    gt_text_split, gt_math_split, gt_table_split= split_text(predicts, 'label')\n    pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer')\n    text_results = []\n    math_results = []\n    table_results = []\n\n    for gt0, pre0, gt1, pre1, gt2, pre2 in zip(gt_text_split, pre_text_split, gt_math_split, pre_math_split, gt_table_split, pre_table_split):\n        # try:\n        # text, math, table\n        text_gts, text_pres = 
gt0, pre0\n        math_gts, math_pres = gt1, pre1\n        table_gts, table_pres = gt2, pre2\n\n        # for text_gt, text_pre in zip(text_gts, text_pres):\n        ans = nougat_per_metrics(predict_root_, text_gts, text_pres)\n        # if len(ans) == 0:\n        #     continue\n        if ans:\n            text_results.append(ans)\n        # for math_gt, math_pre in zip(math_gts, math_pres):\n        ans = nougat_per_metrics(predict_root_, math_gts, math_pres)\n        # if len(ans) == 0:\n        #     continue\n        if ans:\n            math_results.append(ans)\n        \n        # for table_gt, table_pre in zip(table_gts, table_pres):\n        ans = nougat_per_metrics(predict_root_, table_gts, table_pres)\n        # if len(ans) == 0:\n        #     continue\n        if ans:\n            table_results.append(ans)\n    \n    mean_dict = {}\n    # print((result))\n    # print(len(result))\n    mean_dict[\"eval question num\"] = len(text_results)\n    mean_dict['text'] = {}\n    mean_dict['math'] = {}\n    mean_dict['table'] = {}\n\n    for k, v in text_results[0].items():\n        mean_dict['text'][k] = 0\n        mean_dict['math'][k] = 0\n        mean_dict['table'][k] = 0\n    \n    for each in text_results:\n        for k, v in each.items():\n            mean_dict['text'][k] += v\n    \n    for each in math_results:\n        for k, v in each.items():\n            mean_dict['math'][k] += v\n\n    for each in table_results:\n        for k, v in each.items():\n            mean_dict['table'][k] += v\n\n    for k, v in mean_dict['text'].items():\n        mean_dict['text'][k] /= len(text_results)\n\n    for k, v in mean_dict['math'].items():\n        mean_dict['math'][k] /= len(math_results)\n\n\n    for k, v in mean_dict['table'].items():\n        mean_dict['table'][k] /= len(table_results)\n\n    print(json.dumps(mean_dict, indent=4))\n\ndef doc_text_eval(gt_root_, predict_root_, datatype):\n\n   \n    predicts = json.load(open(predict_root_, 
encoding='utf-8'))\n    \n    # print(predicts)\n    result = []\n    for ann in predicts:\n        try:\n            ans = nougat_per_metrics(predict_root_, ann[\"label\"], ann[\"answer\"])\n            if len(ans) == 0:\n                continue\n            result.append(ans)\n        except:\n            assert False, print(\"ERROR!!! Check yout output!!!\")\n    \n    mean_dict = {}\n    # print((result))\n    # print(len(result))\n    mean_dict[\"eval question num\"] = len(result)\n    for k, v in result[0].items():\n        mean_dict[k] = 0\n    \n    for each in result:\n        for k, v in each.items():\n            mean_dict[k] += v\n\n    for k, v in mean_dict.items():\n        if k == \"eval question num\":\n            continue\n        mean_dict[k] /= len(result)\n    print(json.dumps(mean_dict, indent=4))\n\n# doc_text_eval(\"/data/data/DocVQA/val/val_v1.0.json\", \"/data/codes/GOT_docshot-main/results_cc595k-freeze-docvqa-unfreeze-224/results_final.json\", \"Doc\")\n\n\n# doc_formated_text_eval(args.gt_path, args.out_path + \"/results_final.json\", args.datatype)\n\ndoc_text_eval(args.gt_path, args.out_path + \"/results_final.json\", args.datatype)"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_format.py",
    "content": "import json\n# from doctextVQAeval import VQAEval\n\nimport argparse\n# import fitz as pymupdf\nimport nltk\nfrom nltk.metrics import precision, recall, f_measure\nimport numpy as np\nimport jieba\n# import megfile as mf\nimport pickle\nimport pandas as pd\nimport re\n# from loguru import logger\n# nltk.download('wordnet')\nfrom nltk.translate import meteor_score\n\n# from marker_scoring import score_text\n# from utils import contain_chinese_string\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--out_path\", type=str, required=True)\nparser.add_argument(\"--gt_path\", type=str, required=True)\nparser.add_argument(\"--datatype\", type=str, required=True)\nargs = parser.parse_args()\n\ndef preprocess(text, predict_root_):\n    if 'InternVL' in predict_root_:\n        text = text.split(\"All words in the image:\\n\")[1]\n        text = text.split(\"[UNUSED_TOKEN_145]\")[0]\n    return text\n\ndef contain_chinese_string(text):\n    # 使用正则表达式匹配中文字符\n    chinese_pattern = re.compile(r'[\\u4e00-\\u9fa5]')\n    return bool(chinese_pattern.search(text))\n\n\ninline_reg = re.compile(r\"\\\\\\((.*?)(?<!\\\\)\\\\\\)\")\ndisplay_reg = re.compile(r\"\\\\\\[(.+?)(?<!\\\\)\\\\\\]\")\ntable_reg = re.compile(r\"\\\\begin\\{tabular\\}(.+?)(?:\\\\end\\{tabular\\}|$)\", re.S)\n\ndef split_text(pages, a_type):\n    \"\"\"\n    Split a list of pages into text, inline math, display math, and table blocks.\n\n    Args:\n        pages: The pages to split.\n    \"\"\"\n    text, math, table = [], [], []\n    for page in pages:\n        for i, reg in enumerate([inline_reg, display_reg, table_reg]):\n            matches = \"\\n\".join(reg.findall(page[a_type]))\n            if i == 2:\n                table.append(matches)\n            elif i == 1:\n                math[-1] += matches\n            else:\n                math.append(matches)\n        page_str = page[a_type]\n        text.append(page_str.strip())\n    return text, math, table\n\ndef 
nougat_per_metrics(predict_root_, pred, gt, minlen=1, heavy_mode: int = 2):\n    \"\"\"\n    Args:\n    - heavy_mode:\n        0 is clean mode, only similar, bleu, f1\n        1 is normal, do not include edit_dist\n        2 is heavy, total\n    \"\"\"\n    metrics = {}\n\n    # pred = preprocess(pred, predict_root_)\n\n    if len(pred) < minlen or len(gt) < minlen:\n        return metrics\n\n    # metrics[\"similar\"] = score_text(pred, gt)\n    if contain_chinese_string(gt) or contain_chinese_string(pred):\n        reference = jieba.lcut(gt)\n        hypothesis = jieba.lcut(pred)\n    else:\n        reference = gt.split()\n        hypothesis = pred.split()\n\n    metrics[\"bleu\"] = nltk.translate.bleu([reference], hypothesis)\n    if heavy_mode >= 1:\n        # try:\n        metrics[\"meteor\"] = meteor_score.meteor_score([reference], hypothesis)\n        # except LookupError:\n        #     metrics[\"meteor\"] = np.nan\n\n    reference = set(reference)\n    hypothesis = set(hypothesis)\n    metrics[\"f_measure\"] = f_measure(reference, hypothesis)\n\n    if heavy_mode >= 1:\n        metrics[\"precision\"] = precision(reference, hypothesis)\n        metrics[\"recall\"] = recall(reference, hypothesis)\n    if heavy_mode == 2:\n        # 速度太慢\n        metrics[\"edit_dist\"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))\n    return metrics\n\ndef doc_formated_text_eval(gt_root_, predict_root_, datatype):\n\n    predicts = json.load(open(predict_root_, encoding='utf-8'))\n    \n    # print(predicts)\n\n    gt_text_split, gt_math_split, gt_table_split= split_text(predicts, 'label')\n    pre_text_split, pre_math_split, pre_table_split = split_text(predicts, 'answer')\n    text_results = []\n    math_results = []\n    table_results = []\n\n    for gt0, pre0, gt1, pre1, gt2, pre2 in zip(gt_text_split, pre_text_split, gt_math_split, pre_math_split, gt_table_split, pre_table_split):\n        # try:\n        # text, math, table\n        text_gts, text_pres = 
gt0, pre0\n        math_gts, math_pres = gt1, pre1\n        table_gts, table_pres = gt2, pre2\n\n        # for text_gt, text_pre in zip(text_gts, text_pres):\n        ans = nougat_per_metrics(predict_root_, text_gts, text_pres)\n        # if len(ans) == 0:\n        #     continue\n        if ans:\n            text_results.append(ans)\n        # for math_gt, math_pre in zip(math_gts, math_pres):\n        ans = nougat_per_metrics(predict_root_, math_gts, math_pres)\n        # if len(ans) == 0:\n        #     continue\n        if ans:\n            math_results.append(ans)\n        \n        # for table_gt, table_pre in zip(table_gts, table_pres):\n        ans = nougat_per_metrics(predict_root_, table_gts, table_pres)\n        # if len(ans) == 0:\n        #     continue\n        if ans:\n            table_results.append(ans)\n    \n    mean_dict = {}\n    # print((result))\n    # print(len(result))\n    mean_dict[\"eval question num\"] = len(text_results)\n    mean_dict['text'] = {}\n    mean_dict['math'] = {}\n    mean_dict['table'] = {}\n\n    for k, v in text_results[0].items():\n        mean_dict['text'][k] = 0\n        mean_dict['math'][k] = 0\n        mean_dict['table'][k] = 0\n    \n    for each in text_results:\n        for k, v in each.items():\n            mean_dict['text'][k] += v\n    \n    for each in math_results:\n        for k, v in each.items():\n            mean_dict['math'][k] += v\n\n    for each in table_results:\n        for k, v in each.items():\n            mean_dict['table'][k] += v\n\n    for k, v in mean_dict['text'].items():\n        mean_dict['text'][k] /= len(text_results)\n\n    for k, v in mean_dict['math'].items():\n        mean_dict['math'][k] /= len(math_results)\n\n\n    for k, v in mean_dict['table'].items():\n        mean_dict['table'][k] /= len(table_results)\n\n    print(json.dumps(mean_dict, indent=4))\n\ndef doc_text_eval(gt_root_, predict_root_, datatype):\n\n   \n    predicts = json.load(open(predict_root_, 
encoding='utf-8'))\n    \n    # print(predicts)\n    result = []\n    for ann in predicts:\n        try:\n            ans = nougat_per_metrics(predict_root_, ann[\"label\"], ann[\"answer\"])\n            if len(ans) == 0:\n                continue\n            result.append(ans)\n        except:\n            assert False, print(\"ERROR!!! Check yout output!!!\")\n    \n    mean_dict = {}\n    # print((result))\n    # print(len(result))\n    mean_dict[\"eval question num\"] = len(result)\n    for k, v in result[0].items():\n        mean_dict[k] = 0\n    \n    for each in result:\n        for k, v in each.items():\n            mean_dict[k] += v\n\n    for k, v in mean_dict.items():\n        if k == \"eval question num\":\n            continue\n        mean_dict[k] /= len(result)\n    print(json.dumps(mean_dict, indent=4))\n\n# doc_text_eval(\"/data/data/DocVQA/val/val_v1.0.json\", \"/data/codes/GOT_docshot-main/results_cc595k-freeze-docvqa-unfreeze-224/results_final.json\", \"Doc\")\n\n\ndoc_formated_text_eval(args.gt_path, args.out_path + \"/results_final.json\", args.datatype)\n\n# doc_text_eval(args.gt_path, args.out_path + \"/results_final.json\", args.datatype)"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/eval_ocr_scene.py",
    "content": "import json\nimport argparse\nimport nltk\nfrom nltk.metrics import precision, recall, f_measure\nimport numpy as np\nimport jieba\n# import megfile as mf\nimport pickle\nimport pandas as pd\nimport re\nfrom nltk.translate import meteor_score\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--out_path\", type=str, required=True)\nparser.add_argument(\"--gt_path\", type=str, required=True)\nparser.add_argument(\"--datatype\", type=str, required=True)\nargs = parser.parse_args()\n\ndef preprocess(text, predict_root_):\n    if 'InternVL' in predict_root_:\n        text = text.split(\"All words in the image:\\n\")[1]\n        text = text.split(\"[UNUSED_TOKEN_145]\")[0]\n    return text\n\ndef contain_chinese_string(text):\n    chinese_pattern = re.compile(r'[\\u4e00-\\u9fa5]')\n    return bool(chinese_pattern.search(text))\n\ndef nougat_per_metrics(predict_root_, pred, gt, minlen=1):\n\n    metrics = {}\n\n    if len(pred) < minlen or len(gt) < minlen:\n        return metrics\n\n\n    reference = list(gt)\n    hypothesis = list(pred)\n\n    metrics[\"bleu\"] = nltk.translate.bleu([reference], hypothesis)\n\n    metrics[\"meteor\"] = meteor_score.meteor_score([reference], hypothesis)\n\n    reference = set(reference)\n    hypothesis = set(hypothesis)\n    metrics[\"f_measure\"] = f_measure(reference, hypothesis)\n    metrics[\"precision\"] = precision(reference, hypothesis)\n    metrics[\"recall\"] = recall(reference, hypothesis)\n    metrics[\"edit_dist\"] = nltk.edit_distance(pred, gt) / max(len(pred), len(gt))\n\n    return metrics\n\ndef doc_text_eval(gt_root_, predict_root_, datatype):\n\n   \n\n    predicts = json.load(open(predict_root_, encoding='utf-8'))\n    \n    result = []\n    for ann in predicts:\n        try:\n            ans = nougat_per_metrics(predict_root_, ann[\"label\"], ann[\"answer\"])\n            if len(ans) == 0:\n                continue\n            result.append(ans)\n        except:\n            assert False, 
print(\"ERROR!!! Check yout output!!!\")\n    \n    mean_dict = {}\n\n    mean_dict[\"eval question num\"] = len(result)\n    for k, v in result[0].items():\n        mean_dict[k] = 0\n    \n    for each in result:\n        for k, v in each.items():\n            mean_dict[k] += v\n\n    for k, v in mean_dict.items():\n        if k == \"eval question num\":\n            continue\n        mean_dict[k] /= len(result)\n    print(json.dumps(mean_dict, indent=4))\n\n\ndoc_text_eval(args.gt_path, args.out_path + \"/results_final.json\", args.datatype)"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/eval/pyevaltools/merge_results.py",
    "content": "import os\nimport json\nimport argparse\n\ndef merge_outputs(out_path):\n    files = os.listdir(out_path)\n    # print(files)\n    alist = []\n    for file in files:\n        alist += json.load(open(os.path.join(out_path, file), encoding='utf-8'))\n    # print(len(alist))\n\n    filename = out_path + \"/results_final\" + \".json\"\n    with open(filename, 'w', encoding=\"utf-8\") as file_obj:\n        json.dump(alist, file_obj, ensure_ascii=False, indent=1)\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--out_path\", type=str, required=True)\nargs = parser.parse_args()\n\nmerge_outputs(args.out_path)"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/GOT_ocr_2_0.py",
    "content": "from transformers import AutoConfig, AutoModelForCausalLM, \\\n                         Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \\\n                         CLIPVisionModel, CLIPImageProcessor\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\nfrom typing import List, Optional, Tuple, Union\nfrom transformers.cache_utils import Cache, DynamicCache\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import CrossEntropyLoss\nfrom GOT.utils.constants import *\nfrom GOT.model.vision_encoder.vary_b import build_vary_vit_b\nfrom GOT.model.plug.blip_process import BlipImageEvalProcessor\n\nclass GOTConfig(Qwen2Config):\n    model_type = \"GOT\"\n\n\nclass GOTQwenModel(Qwen2Model):\n    config_class = GOTConfig\n\n    def __init__(self, config: Qwen2Config):\n        super(GOTQwenModel, self).__init__(config)\n\n        self.vision_tower_high = build_vary_vit_b()\n\n        self.mm_projector_vary =  nn.Linear(1024, 1024)\n\n\n    def initialize_vision_modules(\n        self, \n        vision_tower,\n        pretrained_stage1_model=None,\n        freeze_vision_tower=False,\n        use_im_start_end=False,\n        vision_select_layer=-1,\n        dtype=torch.float16,\n        device=\"cuda\"\n    ):\n\n        # Vary old codes, not use in GOT\n        image_processor = BlipImageEvalProcessor(image_size=1024)\n        # 1024*1024\n\n        image_processor_high = BlipImageEvalProcessor(image_size=1024)\n\n\n      \n        self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)\n\n        self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)\n\n\n        image_token_len = 256\n\n        self.config.vision_tower = vision_tower\n        self.config.image_token_len = image_token_len\n        # self.config.use_im_start_end = use_im_start_end\n        self.config.use_im_start_end = True\n\n        self.config.vision_select_layer = 
vision_select_layer\n        self.config.freeze_vision_tower = freeze_vision_tower\n        \n        return dict(\n            image_processor=image_processor,\n            image_processor_high=image_processor_high,\n            image_token_len=image_token_len,\n        )\n         \n    # def get_input_embeddings(self, x):\n    #     return self.wte(x)\n    \n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n\n        # HACK: replace back original embeddings for LLaVA pretraining\n        orig_embeds_params = getattr(self, 'orig_embeds_params', None)\n        if orig_embeds_params is not None:\n            with torch.no_grad():\n                self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n\n        vision_tower_high = getattr(self, 'vision_tower_high', None)\n\n\n        if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:\n        # if True:\n            # assert type(images) is list, ValueError(\"To fit both interleave and conversation, images must be list of batches of images\")\n            # print(im)\n            use_im_start_end = getattr(self.config, \"use_im_start_end\", -1)\n\n            vision_select_layer = getattr(self.config, \"vision_select_layer\", -1)\n            im_patch_token = 
getattr(self.config, \"im_patch_token\", -1)\n            im_start_token = getattr(self.config, \"im_start_token\", -1)\n            im_end_token = getattr(self.config, \"im_end_token\", -1)\n            freeze_vision_tower = getattr(self.config, \"freeze_vision_tower\", False)\n\n            im_patch_token = 151859\n\n            im_start_token = 151857\n\n            im_end_token = 151858\n            \n\n\n            image_features = []\n            \n\n            for image in images:\n                P, C, H, W = image[1].shape\n                # with torch.set_grad_enabled(True):\n                #     # print(image[1].shape)\n                #     cnn_feature = vision_tower_high(image[1])\n                #     cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256  1024\n                #     # image_features.append(cnn_feature)\n                # image_features_2.append(cnn_feature)\n                if P == 1:\n                    with torch.set_grad_enabled(False):\n                        # print(image[1].shape)\n                        cnn_feature = vision_tower_high(image[1])\n                        cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024\n                        # image_features.append(cnn_feature)\n                    # image_features_2.append(cnn_feature)\n                    image_feature = self.mm_projector_vary(cnn_feature)\n                    image_features.append(image_feature)\n\n                else:\n                    image_patches = torch.unbind(image[1])\n                    image_patches_features = []\n                    for image_patch in image_patches:\n                        image_p = torch.stack([image_patch])\n                        with torch.set_grad_enabled(False):\n                            cnn_feature_p = vision_tower_high(image_p)\n                            cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)\n                        image_feature_p = 
self.mm_projector_vary(cnn_feature_p)\n                        image_patches_features.append(image_feature_p)\n                    image_feature = torch.cat(image_patches_features, dim=1)\n                    # print(P)\n                    # print(image_feature.shape)\n                    # exit()\n                    image_features.append(image_feature)\n\n\n\n            dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)\n            # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)\n            dummy_image_features = dummy_image_features_2\n            use_im_start_end = True\n            new_input_embeds = []\n            for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):\n                if (cur_input_ids == im_patch_token).sum() == 0:\n                    # multimodal LLM, but the current sample is not multimodal\n                    cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum()\n                    new_input_embeds.append(cur_input_embeds)\n                    continue\n\n                if use_im_start_end:\n                    if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():\n                        raise ValueError(\"The number of image start tokens and image end tokens should be the same.\")\n                    \n                    image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]\n                    for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):\n                        per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)\n                        num_patches = per_cur_image_features.shape[0]\n\n                        if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:\n                            raise ValueError(\"The image end token should follow the image start token.\")\n                        \n                        cur_input_embeds = torch.cat(\n                            (\n                                cur_input_embeds[:image_start_token_pos+1], \n                                per_cur_image_features, \n                                cur_input_embeds[image_start_token_pos + num_patches + 1:]\n                            ), \n                            dim=0\n                        )\n\n\n                    new_input_embeds.append(cur_input_embeds)\n                else:\n                    raise NotImplementedError\n\n            inputs_embeds = torch.stack(new_input_embeds, dim=0)\n\n        return super(GOTQwenModel, self).forward(\n            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,\n            output_attentions=output_attentions, output_hidden_states=output_hidden_states,\n            
return_dict=return_dict\n        )\n\n\n\nclass GOTQwenForCausalLM(Qwen2ForCausalLM):\n    config_class = GOTConfig\n    # supports_gradient_checkpointing = True\n\n    def __init__(self, config):\n        super(Qwen2ForCausalLM, self).__init__(config)\n        self.model = GOTQwenModel(config)\n\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_model(self):\n        return self.model\n\n    # def _set_gradient_checkpointing(self, module, value=False):\n    #     if isinstance(module, GOTQwenModel):\n    #         module.gradient_checkpointing = value\n    # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)\n    # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        images: Optional[torch.FloatTensor] = None,\n        return_dict: Optional[bool] = None,\n        \n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of 
(dec_features, layer_state, dec_hidden, dec_attn)\n        # print(input_ids)\n        # print(len(images))\n\n        # print(inputs_embeds)\n\n        outputs  = self.model(\n            input_ids=input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            images=images,\n            return_dict=return_dict\n            \n        )\n\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        # logits\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ):\n        # Omit tokens covered by past_key_values\n        if past_key_values is not None:\n            if 
isinstance(past_key_values, Cache):\n                cache_length = past_key_values.get_seq_length()\n                past_length = past_key_values.seen_tokens\n                max_cache_length = past_key_values.get_max_length()\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as\n            # input)\n            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n          
      position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n                \"images\": kwargs.get(\"images\", None),\n            }\n        )\n        return model_inputs\n\n    def initialize_vision_tokenizer(\n        self, \n        tokenizer, \n        freeze_lm_model=False, \n        pretrained_stage1_model=None,\n        device=\"cuda\"\n    ):\n        config = self.get_model().config\n\n        # add image patch token <image>\n        # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n        self.resize_token_embeddings(len(tokenizer))\n        # config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n\n        config.im_patch_token = 151859\n\n        config.use_im_start_end = True\n\n        # add image start token <im_start> and end token <im_end>\n        if config.use_im_start_end:\n            # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n            self.resize_token_embeddings(len(tokenizer))\n            # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n\n            config.im_start_token, config.im_end_token = 151857, 151858\n\n\nAutoConfig.register(\"GOT\", GOTConfig)\nAutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM)\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/__init__.py",
    "content": "\nfrom .GOT_ocr_2_0 import GOTQwenModel, GOTQwenForCausalLM, GOTConfig\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/plug/blip_process.py",
    "content": "\"\"\"\n Copyright (c) 2022, salesforce.com, inc.\n All rights reserved.\n SPDX-License-Identifier: BSD-3-Clause\n For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause\n\"\"\"\n\nimport cv2\nimport numpy as np\n\nimport torch\n\n# from omegaconf import OmegaConf\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import InterpolationMode\nfrom PIL import Image\n\nclass BaseProcessor:\n    def __init__(self):\n        self.transform = lambda x: x\n        return\n\n    def __call__(self, item):\n        return self.transform(item)\n\n    # @classmethod\n    # def from_config(cls, cfg=None):\n    #     return cls()\n\n    # def build(self, **kwargs):\n    #     cfg = OmegaConf.create(kwargs)\n\n    #     return self.from_config(cfg)\n\nclass BlipImageBaseProcessor(BaseProcessor):\n    def __init__(self, mean=None, std=None):\n        if mean is None:\n            mean = (0.48145466, 0.4578275, 0.40821073)\n        if std is None:\n            std = (0.26862954, 0.26130258, 0.27577711)\n        # mean = (0.0, 0.0, 0.0)\n        # std = (1.0, 1.0, 1.0)\n\n        self.normalize = transforms.Normalize(mean, std)\n\n\n## aug functions\ndef identity_func(img):\n    return img\n\n\ndef autocontrast_func(img, cutoff=0):\n    \"\"\"\n    same output as PIL.ImageOps.autocontrast\n    \"\"\"\n    n_bins = 256\n\n    def tune_channel(ch):\n        n = ch.size\n        cut = cutoff * n // 100\n        if cut == 0:\n            high, low = ch.max(), ch.min()\n        else:\n            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n            low = np.argwhere(np.cumsum(hist) > cut)\n            low = 0 if low.shape[0] == 0 else low[0]\n            high = np.argwhere(np.cumsum(hist[::-1]) > cut)\n            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]\n        if high <= low:\n            table = np.arange(n_bins)\n        else:\n            scale = 
(n_bins - 1) / (high - low)\n            offset = -low * scale\n            table = np.arange(n_bins) * scale + offset\n            table[table < 0] = 0\n            table[table > n_bins - 1] = n_bins - 1\n        table = table.clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef equalize_func(img):\n    \"\"\"\n    same output as PIL.ImageOps.equalize\n    PIL's implementation is different from cv2.equalize\n    \"\"\"\n    n_bins = 256\n\n    def tune_channel(ch):\n        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n        non_zero_hist = hist[hist != 0].reshape(-1)\n        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)\n        if step == 0:\n            return ch\n        n = np.empty_like(hist)\n        n[0] = step // 2\n        n[1:] = hist[:-1]\n        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef rotate_func(img, degree, fill=(0, 0, 0)):\n    \"\"\"\n    like PIL, rotate by degree, not radians\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    center = W / 2, H / 2\n    M = cv2.getRotationMatrix2D(center, degree, 1)\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)\n    return out\n\n\ndef solarize_func(img, thresh=128):\n    \"\"\"\n    same output as PIL.ImageOps.posterize\n    \"\"\"\n    table = np.array([el if el < thresh else 255 - el for el in range(256)])\n    table = table.clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef color_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Color\n    \"\"\"\n    ## implementation according to PIL definition, quite slow\n    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]\n    #  out = blend(degenerate, img, factor)\n    #  M = (\n  
  #      np.eye(3) * factor\n    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)\n    #  )[np.newaxis, np.newaxis, :]\n    M = np.float32(\n        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]\n    ) * factor + np.float32([[0.114], [0.587], [0.299]])\n    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)\n    return out\n\n\ndef contrast_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Contrast\n    \"\"\"\n    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))\n    table = (\n        np.array([(el - mean) * factor + mean for el in range(256)])\n        .clip(0, 255)\n        .astype(np.uint8)\n    )\n    out = table[img]\n    return out\n\n\ndef brightness_func(img, factor):\n    \"\"\"\n    same output as PIL.ImageEnhance.Contrast\n    \"\"\"\n    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef sharpness_func(img, factor):\n    \"\"\"\n    The differences the this result and PIL are all on the 4 boundaries, the center\n    areas are same\n    \"\"\"\n    kernel = np.ones((3, 3), dtype=np.float32)\n    kernel[1][1] = 5\n    kernel /= 13\n    degenerate = cv2.filter2D(img, -1, kernel)\n    if factor == 0.0:\n        out = degenerate\n    elif factor == 1.0:\n        out = img\n    else:\n        out = img.astype(np.float32)\n        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]\n        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)\n        out = out.astype(np.uint8)\n    return out\n\n\ndef shear_x_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, factor, 0], [0, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef translate_x_func(img, offset, fill=(0, 0, 0)):\n    \"\"\"\n    same output as 
PIL.Image.transform\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, -offset], [0, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef translate_y_func(img, offset, fill=(0, 0, 0)):\n    \"\"\"\n    same output as PIL.Image.transform\n    \"\"\"\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [0, 1, -offset]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef posterize_func(img, bits):\n    \"\"\"\n    same output as PIL.ImageOps.posterize\n    \"\"\"\n    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))\n    return out\n\n\ndef shear_y_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [factor, 1, 0]])\n    out = cv2.warpAffine(\n        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR\n    ).astype(np.uint8)\n    return out\n\n\ndef cutout_func(img, pad_size, replace=(0, 0, 0)):\n    replace = np.array(replace, dtype=np.uint8)\n    H, W = img.shape[0], img.shape[1]\n    rh, rw = np.random.random(2)\n    pad_size = pad_size // 2\n    ch, cw = int(rh * H), int(rw * W)\n    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)\n    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)\n    out = img.copy()\n    out[x1:x2, y1:y2, :] = replace\n    return out\n\n\n### level to args\ndef enhance_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        return ((level / MAX_LEVEL) * 1.8 + 0.1,)\n\n    return level_to_args\n\n\ndef shear_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 0.3\n        if np.random.random() > 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef translate_level_to_args(translate_const, MAX_LEVEL, 
replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * float(translate_const)\n        if np.random.random() > 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * cutout_const)\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef solarize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 256)\n        return (level,)\n\n    return level_to_args\n\n\ndef none_level_to_args(level):\n    return ()\n\n\ndef posterize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 4)\n        return (level,)\n\n    return level_to_args\n\n\ndef rotate_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 30\n        if np.random.random() < 0.5:\n            level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\nfunc_dict = {\n    \"Identity\": identity_func,\n    \"AutoContrast\": autocontrast_func,\n    \"Equalize\": equalize_func,\n    \"Rotate\": rotate_func,\n    \"Solarize\": solarize_func,\n    \"Color\": color_func,\n    \"Contrast\": contrast_func,\n    \"Brightness\": brightness_func,\n    \"Sharpness\": sharpness_func,\n    \"ShearX\": shear_x_func,\n    \"TranslateX\": translate_x_func,\n    \"TranslateY\": translate_y_func,\n    \"Posterize\": posterize_func,\n    \"ShearY\": shear_y_func,\n}\n\ntranslate_const = 10\nMAX_LEVEL = 10\nreplace_value = (128, 128, 128)\narg_dict = {\n    \"Identity\": none_level_to_args,\n    \"AutoContrast\": none_level_to_args,\n    \"Equalize\": none_level_to_args,\n    \"Rotate\": rotate_level_to_args(MAX_LEVEL, replace_value),\n    \"Solarize\": solarize_level_to_args(MAX_LEVEL),\n    \"Color\": 
enhance_level_to_args(MAX_LEVEL),\n    \"Contrast\": enhance_level_to_args(MAX_LEVEL),\n    \"Brightness\": enhance_level_to_args(MAX_LEVEL),\n    \"Sharpness\": enhance_level_to_args(MAX_LEVEL),\n    \"ShearX\": shear_level_to_args(MAX_LEVEL, replace_value),\n    \"TranslateX\": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),\n    \"TranslateY\": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),\n    \"Posterize\": posterize_level_to_args(MAX_LEVEL),\n    \"ShearY\": shear_level_to_args(MAX_LEVEL, replace_value),\n}\n\n\nclass RandomAugment(object):\n    def __init__(self, N=2, M=10, isPIL=False, augs=[]):\n        self.N = N\n        self.M = M\n        self.isPIL = isPIL\n        if augs:\n            self.augs = augs\n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N)\n        return [(op, 0.5, self.M) for op in sampled_ops]\n\n    def __call__(self, img):\n        if self.isPIL:\n            img = np.array(img)\n        ops = self.get_random_ops()\n        for name, prob, level in ops:\n            if np.random.random() > prob:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args)\n        return img\n\n\nclass VideoRandomAugment(object):\n    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):\n        self.N = N\n        self.M = M\n        self.p = p\n        self.tensor_in_tensor_out = tensor_in_tensor_out\n        if augs:\n            self.augs = augs\n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N, replace=False)\n        return [(op, self.M) for op in sampled_ops]\n\n    def __call__(self, frames):\n        assert (\n            frames.shape[-1] == 3\n        ), \"Expecting last dimension for 3-channels RGB (b, h, w, c).\"\n\n     
   if self.tensor_in_tensor_out:\n            frames = frames.numpy().astype(np.uint8)\n\n        num_frames = frames.shape[0]\n\n        ops = num_frames * [self.get_random_ops()]\n        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]\n\n        frames = torch.stack(\n            list(map(self._aug, frames, ops, apply_or_not)), dim=0\n        ).float()\n\n        return frames\n\n    def _aug(self, img, ops, apply_or_not):\n        for i, (name, level) in enumerate(ops):\n            if not apply_or_not[i]:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args)\n        return torch.from_numpy(img)\n\n\n# if __name__ == \"__main__\":\n#     a = RandomAugment()\n#     img = np.random.randn(32, 32, 3)\n#     a(img)\n\n\n\n\n\n\nclass BlipImageTrainProcessor(BlipImageBaseProcessor):\n    def __init__(\n        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0\n    ):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(\n                    image_size,\n                    scale=(min_scale, max_scale),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                # transforms.RandomHorizontalFlip(),\n                RandomAugment(\n                    2,\n                    5,\n                    isPIL=True,\n                    augs=[\n                        \"Identity\",\n                        # \"AutoContrast\",\n                        \"Brightness\",\n                        \"Sharpness\",\n                        \"Equalize\",\n                        # \"ShearX\",\n                        # \"ShearY\",\n                        # \"TranslateX\",\n                        # \"TranslateY\",\n                        # \"Rotate\",\n                    ],\n                ),\n                transforms.ToTensor(),\n        
        self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n\nclass BlipImageEvalProcessor(BlipImageBaseProcessor):\n    def __init__(self, image_size=384, mean=None, std=None):\n        super().__init__(mean=mean, std=std)\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(\n                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC\n                ),\n                transforms.ToTensor(),\n                self.normalize,\n            ]\n        )\n\n    def __call__(self, item):\n        return self.transform(item)\n\n\n# if __name__ == \"__main__\":\n#     a = BlipImageTrainProcessor(image_size=1024)\n#     # img = np.random.randn(1024, 1024, 3)\n#     # x = torch.zeros(1024, 1024, 3)\n#     x = Image.open(\"/data/codes/GOT-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg\").convert(\"RGB\")\n#     print(x.size)\n#     y = a(x)\n\n#     print(y.size())\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/vision_encoder/__init__.py",
    "content": "\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/model/vision_encoder/vary_b.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom typing import Optional, Tuple, Type\n\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\n\nfrom typing import Type\n\n# from GOT.model.vision_encoder.vitg_qwen import Resampler\nimport math\n\n\nclass Projector(nn.Module):\n    def __init__(\n        self,\n        width: 256,\n        n_queries: int = 256,\n        output_dim: int = 4096,\n        **kwargs\n    ):\n        super().__init__()\n\n        norm_layer = partial(nn.LayerNorm, eps=1e-6)\n        self.attn_pool = Resampler(\n            grid_size=int(math.sqrt(n_queries)),\n            embed_dim=output_dim,\n            num_heads=output_dim // 128,\n            kv_dim=width,\n            norm_layer=norm_layer,\n        )\n        self.ln_post = norm_layer(output_dim)\n        self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))\n\n    def forward(self, x: torch.Tensor):\n        x = self.attn_pool(x)\n        x = self.ln_post(x)\n        x = x @ self.proj\n\n        return x\n\n\nclass MLPBlock(nn.Module):\n    def __init__(\n        self,\n        embedding_dim: int,\n        mlp_dim: int,\n        act: Type[nn.Module] = nn.GELU,\n    ) -> None:\n        super().__init__()\n        self.lin1 = nn.Linear(embedding_dim, mlp_dim)\n        self.lin2 = nn.Linear(mlp_dim, embedding_dim)\n        self.act = act()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.lin2(self.act(self.lin1(x)))\n\n\n# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa\n# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa\nclass 
LayerNorm2d(nn.Module):\n    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(num_channels))\n        self.bias = nn.Parameter(torch.zeros(num_channels))\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        u = x.mean(1, keepdim=True)\n        s = (x - u).pow(2).mean(1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.eps)\n        x = self.weight[:, None, None] * x + self.bias[:, None, None]\n        return x\n\n\n# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa\nclass ImageEncoderViT(nn.Module):\n    def __init__(\n        self,\n        img_size: int = 1024,\n        patch_size: int = 16,\n        in_chans: int = 3,\n        embed_dim: int = 768,\n        depth: int = 12,\n        num_heads: int = 12,\n        mlp_ratio: float = 4.0,\n        out_chans: int = 256,\n        qkv_bias: bool = True,\n        norm_layer: Type[nn.Module] = nn.LayerNorm,\n        act_layer: Type[nn.Module] = nn.GELU,\n        use_abs_pos: bool = True,\n        use_rel_pos: bool = False,\n        rel_pos_zero_init: bool = True,\n        window_size: int = 0,\n        global_attn_indexes: Tuple[int, ...] 
= (),\n    ) -> None:\n        \"\"\"\n        Args:\n            img_size (int): Input image size.\n            patch_size (int): Patch size.\n            in_chans (int): Number of input image channels.\n            embed_dim (int): Patch embedding dimension.\n            depth (int): Depth of ViT.\n            num_heads (int): Number of attention heads in each ViT block.\n            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n            qkv_bias (bool): If True, add a learnable bias to query, key, value.\n            norm_layer (nn.Module): Normalization layer.\n            act_layer (nn.Module): Activation layer.\n            use_abs_pos (bool): If True, use absolute positional embeddings.\n            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.\n            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.\n            window_size (int): Window size for window attention blocks.\n            global_attn_indexes (list): Indexes for blocks using global attention.\n        \"\"\"\n        super().__init__()\n        self.img_size = img_size\n\n        self.patch_embed = PatchEmbed(\n            kernel_size=(patch_size, patch_size),\n            stride=(patch_size, patch_size),\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n        )\n\n        self.pos_embed: Optional[nn.Parameter] = None\n        if use_abs_pos:\n            # Initialize absolute positional embedding with pretrain image size.\n            self.pos_embed = nn.Parameter(\n                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)\n            )\n\n        self.blocks = nn.ModuleList()\n        for i in range(depth):\n            block = Block(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                norm_layer=norm_layer,\n                
act_layer=act_layer,\n                use_rel_pos=use_rel_pos,\n                rel_pos_zero_init=rel_pos_zero_init,\n                window_size=window_size if i not in global_attn_indexes else 0,\n                input_size=(img_size // patch_size, img_size // patch_size),\n            )\n            self.blocks.append(block)\n\n        self.neck = nn.Sequential(\n            nn.Conv2d(\n                embed_dim,\n                out_chans,\n                kernel_size=1,\n                bias=False,\n            ),\n            LayerNorm2d(out_chans),\n            nn.Conv2d(\n                out_chans,\n                out_chans,\n                kernel_size=3,\n                padding=1,\n                bias=False,\n            ),\n            LayerNorm2d(out_chans),\n        )\n\n        \n        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)\n        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.patch_embed(x)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x = self.neck(x.permute(0, 3, 1, 2))\n        x = self.net_2(x)\n        x = self.net_3(x)\n\n\n        return x\n\n\nclass Block(nn.Module):\n    \"\"\"Transformer blocks with support of window attention and residual propagation blocks\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = True,\n        norm_layer: Type[nn.Module] = nn.LayerNorm,\n        act_layer: Type[nn.Module] = nn.GELU,\n        use_rel_pos: bool = False,\n        rel_pos_zero_init: bool = True,\n        window_size: int = 0,\n        input_size: Optional[Tuple[int, int]] = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim (int): Number of input channels.\n            num_heads 
(int): Number of attention heads in each ViT block.\n            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n            qkv_bias (bool): If True, add a learnable bias to query, key, value.\n            norm_layer (nn.Module): Normalization layer.\n            act_layer (nn.Module): Activation layer.\n            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.\n            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.\n            window_size (int): Window size for window attention blocks. If it equals 0, then\n                use global attention.\n            input_size (tuple(int, int) or None): Input resolution for calculating the relative\n                positional parameter size.\n        \"\"\"\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            use_rel_pos=use_rel_pos,\n            rel_pos_zero_init=rel_pos_zero_init,\n            input_size=input_size if window_size == 0 else (window_size, window_size),\n        )\n\n        self.norm2 = norm_layer(dim)\n        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)\n\n        self.window_size = window_size\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        shortcut = x\n        x = self.norm1(x)\n        # Window partition\n        if self.window_size > 0:\n            H, W = x.shape[1], x.shape[2]\n            x, pad_hw = window_partition(x, self.window_size)\n\n        x = self.attn(x)\n        # Reverse window partition\n        if self.window_size > 0:\n            x = window_unpartition(x, self.window_size, pad_hw, (H, W))\n\n        x = shortcut + x\n        x = x + self.mlp(self.norm2(x))\n\n        return x\n\n\nclass Attention(nn.Module):\n    \"\"\"Multi-head Attention block with relative position 
embeddings.\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = True,\n        use_rel_pos: bool = False,\n        rel_pos_zero_init: bool = True,\n        input_size: Optional[Tuple[int, int]] = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim (int): Number of input channels.\n            num_heads (int): Number of attention heads.\n            qkv_bias (bool):  If True, add a learnable bias to query, key, value.\n            rel_pos (bool): If True, add relative positional embeddings to the attention map.\n            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.\n            input_size (tuple(int, int) or None): Input resolution for calculating the relative\n                positional parameter size.\n        \"\"\"\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(dim, dim)\n\n        self.use_rel_pos = use_rel_pos\n        if self.use_rel_pos:\n            assert (\n                input_size is not None\n            ), \"Input size must be provided if using relative positional encoding.\"\n            # initialize relative positional embeddings\n            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))\n            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, H, W, _ = x.shape\n        # qkv with shape (3, B, nHead, H * W, C)\n        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        # q, k, v with shape (B * nHead, H * W, C)\n        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)\n\n        attn = (q * self.scale) @ k.transpose(-2, -1)\n\n        if self.use_rel_pos:\n      
      attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))\n\n        attn = attn.softmax(dim=-1)\n        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)\n        x = self.proj(x)\n\n        return x\n\n\ndef window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:\n    \"\"\"\n    Partition into non-overlapping windows with padding if needed.\n    Args:\n        x (tensor): input tokens with [B, H, W, C].\n        window_size (int): window size.\n\n    Returns:\n        windows: windows after partition with [B * num_windows, window_size, window_size, C].\n        (Hp, Wp): padded height and width before partition\n    \"\"\"\n    B, H, W, C = x.shape\n\n    pad_h = (window_size - H % window_size) % window_size\n    pad_w = (window_size - W % window_size) % window_size\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))\n    Hp, Wp = H + pad_h, W + pad_w\n\n    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows, (Hp, Wp)\n\n\ndef window_unpartition(\n    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]\n) -> torch.Tensor:\n    \"\"\"\n    Window unpartition into original sequences and removing padding.\n    Args:\n        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].\n        window_size (int): window size.\n        pad_hw (Tuple): padded height and width (Hp, Wp).\n        hw (Tuple): original height and width (H, W) before padding.\n\n    Returns:\n        x: unpartitioned sequences with [B, H, W, C].\n    \"\"\"\n    Hp, Wp = pad_hw\n    H, W = hw\n    B = windows.shape[0] // (Hp * Wp // window_size // window_size)\n    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, 
-1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)\n\n    if Hp > H or Wp > W:\n        x = x[:, :H, :W, :].contiguous()\n    return x\n\n\ndef get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Get relative positional embeddings according to the relative positions of\n        query and key sizes.\n    Args:\n        q_size (int): size of query q.\n        k_size (int): size of key k.\n        rel_pos (Tensor): relative position embeddings (L, C).\n\n    Returns:\n        Extracted positional embeddings according to relative positions.\n    \"\"\"\n    max_rel_dist = int(2 * max(q_size, k_size) - 1)\n    # Interpolate rel pos if needed.\n    if rel_pos.shape[0] != max_rel_dist:\n        # Interpolate rel pos.\n        rel_pos_resized = F.interpolate(\n            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),\n            size=max_rel_dist,\n            mode=\"linear\",\n        )\n        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)\n    else:\n        rel_pos_resized = rel_pos\n\n    # Scale the coords with short length if shapes for q and k are different.\n    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)\n    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)\n    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)\n\n    return rel_pos_resized[relative_coords.long()]\n\n\ndef add_decomposed_rel_pos(\n    attn: torch.Tensor,\n    q: torch.Tensor,\n    rel_pos_h: torch.Tensor,\n    rel_pos_w: torch.Tensor,\n    q_size: Tuple[int, int],\n    k_size: Tuple[int, int],\n) -> torch.Tensor:\n    \"\"\"\n    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.\n    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950\n    Args:\n        attn (Tensor): attention map.\n        q (Tensor): query q in 
the attention layer with shape (B, q_h * q_w, C).\n        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.\n        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.\n        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).\n        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).\n\n    Returns:\n        attn (Tensor): attention map with added relative positional embeddings.\n    \"\"\"\n    q_h, q_w = q_size\n    k_h, k_w = k_size\n    Rh = get_rel_pos(q_h, k_h, rel_pos_h)\n    Rw = get_rel_pos(q_w, k_w, rel_pos_w)\n\n    B, _, dim = q.shape\n    r_q = q.reshape(B, q_h, q_w, dim)\n    rel_h = torch.einsum(\"bhwc,hkc->bhwk\", r_q, Rh)\n    rel_w = torch.einsum(\"bhwc,wkc->bhwk\", r_q, Rw)\n\n    attn = (\n        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]\n    ).view(B, q_h * q_w, k_h * k_w)\n\n    return attn\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    Image to Patch Embedding.\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size: Tuple[int, int] = (16, 16),\n        stride: Tuple[int, int] = (16, 16),\n        padding: Tuple[int, int] = (0, 0),\n        in_chans: int = 3,\n        embed_dim: int = 768,\n    ) -> None:\n        \"\"\"\n        Args:\n            kernel_size (Tuple): kernel size of the projection layer.\n            stride (Tuple): stride of the projection layer.\n            padding (Tuple): padding size of the projection layer.\n            in_chans (int): Number of input image channels.\n            embed_dim (int): Patch embedding dimension.\n        \"\"\"\n        super().__init__()\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.proj(x)\n        # B C H W -> B H W C\n        x = x.permute(0, 2, 3, 1)\n        return x\n\n\n\ndef 
build_vary_vit_b(checkpoint=None):\n    return _build_vary(\n        encoder_embed_dim=768,\n        encoder_depth=12,\n        encoder_num_heads=12,\n        encoder_global_attn_indexes=[2, 5, 8, 11],\n        checkpoint=checkpoint,\n    )\n\n\ndef _build_vary(\n    encoder_embed_dim,\n    encoder_depth,\n    encoder_num_heads,\n    encoder_global_attn_indexes,\n    checkpoint=None,\n):\n    prompt_embed_dim = 256\n    image_size = 1024\n    vit_patch_size = 16\n    image_embedding_size = image_size // vit_patch_size\n    image_encoder=ImageEncoderViT(\n            depth=encoder_depth,\n            embed_dim=encoder_embed_dim,\n            img_size=image_size,\n            mlp_ratio=4,\n            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),\n            num_heads=encoder_num_heads,\n            patch_size=vit_patch_size,\n            qkv_bias=True,\n            use_rel_pos=True,\n            global_attn_indexes=encoder_global_attn_indexes,\n            window_size=14,\n            out_chans=prompt_embed_dim,\n        )\n    \n    # if checkpoint is not None:\n    #     # with open(checkpoint, \"rb\") as f:\n    #     state_dict = torch.load(checkpoint)\n    #     # print(state_dict.keys())\n    #     # for key in state_dict:\n    #     # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)\n    #     # ocr-anyting\n    #     # image_encoder.load_state_dict(state_dict, strict=True)\n    #     # tob\n    #     # model.vision_tower.\n    #     image_encoder.load_state_dict({k[19:]: v for k, v in state_dict.items() if 'vision_tower' in k}, strict=True)\n    #     print(checkpoint)\n    return image_encoder\n\n\n\n\nif __name__ == '__main__':\n\n    x = torch.zeros(2, 3, 1024, 1024)\n\n    # x.permute(0, 3, 1, 2)\n\n    net = build_vary_vit_b(checkpoint ='/mnt/shared-storage/tenant/hypertext/xpkong/jycode/checkpoint/pytorch_model.bin')\n    \n    # mlp = Projector(width=256, n_queries = 256, output_dim = 
768)\n    y = net(x)\n    y = y.flatten(2).permute(0, 2, 1)\n    print(y.shape)\n    # y = mlp(y)\n    \n    # y = net_2(y)\n    # y = net_3(y)\n    # \n\n    # print(y.shape)"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport logging\nimport pathlib\nimport torch\nimport transformers\n\n# from GOT.train.trainer import GOTTrainer\n# from GOT.train.trainer_vit_llrd import GOTTrainer\nfrom GOT.train.trainer_vit_fixlr import GOTTrainer\nfrom GOT.model import GOTLlamaForCausalLM\nfrom GOT.model import *\nfrom GOT.data import make_supervised_data_module\nfrom GOT.utils.arguments import *\nfrom GOT.utils.constants import *\nfrom GOT.utils.utils import smart_tokenizer_and_embedding_resize\nfrom GOT.model.vision_encoder.sam import build_sam_vit_b\nfrom GOT.model.vision_encoder.swin_transformer import build_swin_transformer\ndef train():\n    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n    model = GOTLlamaForCausalLM.from_pretrained(\n        model_args.model_name_or_path,\n        cache_dir=training_args.cache_dir,\n    )\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\n        '/data/hypertext/xpkong/newcode/checkpoints/kly-vary-1025-cc595-pretrain/',\n        cache_dir=training_args.cache_dir,\n        
model_max_length=training_args.model_max_length,\n        padding_side=\"right\",\n        use_fast=False,\n    )\n\n    # tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side=\"right\", model_max_length=training_args.model_max_length,)\n\n    # # model = AutoModelForCausalLM.from_pretrained(\"/data/public/ucaswei/cache/Qwen/qwen/\", device_map=\"cuda\", trust_remote_code=True).eval()\n\n    # model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda')\n\n\n    if data_args.conversation_version == \"v0\" or \"models--decapoda-research--llama-7b-hf\" in model_args.model_name_or_path:\n        if tokenizer.pad_token is None:\n            smart_tokenizer_and_embedding_resize(\n                special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),\n                tokenizer=tokenizer,\n                model=model,\n            )\n        if \"llama\" in model_args.model_name_or_path:\n            tokenizer.add_special_tokens({\n                \"eos_token\": DEFAULT_EOS_TOKEN,\n                \"bos_token\": DEFAULT_BOS_TOKEN,\n                \"unk_token\": DEFAULT_UNK_TOKEN,\n            })\n    else:\n        tokenizer.pad_token = tokenizer.unk_token\n\n    # tokenizer.pad_token = DEFAULT_UNK_TOKEN\n    # tokenizer.pad_token = tokenizer.eos_token\n    # tokenizer.add_special_tokens({'pad_token':'<|endoftext|>'})\n\n    dtype = torch.float32\n    if training_args.fp16:\n        dtype = torch.float16\n    if training_args.bf16:\n        dtype = torch.bfloat16\n\n    vision_tower_dict = model.get_model().initialize_vision_modules(\n        vision_tower=model_args.vision_tower,\n        pretrained_stage1_model=model_args.pretrained_stage1_model,\n        freeze_vision_tower=model_args.freeze_vision_tower,\n        use_im_start_end=model_args.use_im_start_end,\n        vision_select_layer=model_args.vision_select_layer,\n        dtype=dtype,\n    
    device=training_args.device\n    )\n\n    model.initialize_vision_tokenizer(\n        tokenizer=tokenizer, \n        freeze_lm_model=model_args.freeze_lm_model, \n        pretrained_stage1_model=model_args.pretrained_stage1_model,\n        device=training_args.device,\n    )\n    model.get_model().vision_tower = transformers.CLIPVisionModel.from_pretrained(\n        '/data/public/ucaswei/pretrain/vit-large-patch14')\n    model.get_model().vision_tower_high = build_sam_vit_b(checkpoint='/data/hypertext/xpkong/newcode/checkpoints/kly-sam-opt-all-1023-new/pytorch_model.bin')\n    # model.get_model().mm_projector = create_perciever()\n\n    model.to(dtype=dtype, device=training_args.device)\n    # 'image_processor_high\n    # data_args.image_token_len = vision_tower_dict['image_token_len']\n    data_args.image_token_len = 256\n    data_args.image_processor = vision_tower_dict['image_processor']\n    data_args.image_processor_high = vision_tower_dict['image_processor_high']\n    data_args.use_im_start_end = model_args.use_im_start_end\n\n    # mixed relation, to be fixed\n    if model_args.freeze_lm_model:\n        model.requires_grad_(False)\n        for p in model.get_model().mm_projector.parameters():\n            p.requires_grad = True\n        # for p in model.get_model().vision_encoder.parameters():\n        #     p.requires_grad = True\n        # for p in model.get_model().chatt.parameters():\n        #     p.requires_grad = True\n        for p in model.get_input_embeddings().parameters():\n            p.requires_grad = True\n        # conv_final\n        # for p in model.get_model().conv_final.parameters():\n        #     p.requires_grad = True\n\n\n        if not model_args.freeze_vision_tower:\n\n            model.get_model().vision_tower.requires_grad_(True)\n            # for i in range(20):\n            #     model.get_model().vision_tower.vision_model.encoder.layers[i].requires_grad_(False)\n            # 
model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False)\n            # model.get_model().vision_tower.vision_model.embeddings.requires_grad_(False)\n            # model.get_model().vision_tower.vision_model.pre_layrnorm.requires_grad_(False)\n            # model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False)\n\n            #         for p in model.get_model().vision_encoder.parameters():\n            # p.requires_grad = True\n\n            # for n, p in model.named_parameters():\n            #     print(n, p.requires_grad)\n\n    if model_args.freeze_vision_tower:\n        model.get_model().vision_tower.requires_grad_(False)\n    \n    params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]\n    print(f\"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M\")\n\n    # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]\n    # if len(params_no_grad) > 0:\n    #     if training_args.fsdp is not None and len(training_args.fsdp) > 0:\n    #         if len(params_no_grad) < 10:\n    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))\n    #         else:\n    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))\n    #         print(\"[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.\")\n    #         print(\"[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.  
See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining\")\n\n    #         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    #         def patch_FSDP_use_orig_params(func):\n    #             def wrap_func(*args, **kwargs):\n    #                 use_orig_params = kwargs.pop('use_orig_params', True)\n    #                 return func(*args, **kwargs, use_orig_params=use_orig_params)\n    #             return wrap_func\n\n    #         FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)\n\n\n    data_module = make_supervised_data_module(\n        interleave=training_args.interleave, \n        with_box=training_args.with_box, \n        tokenizer=tokenizer, \n        data_args=data_args\n    )\n\n    trainer = GOTTrainer(\n        model=model,\n        tokenizer=tokenizer,\n        args=training_args,\n        **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n    trainer._safe_save(output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_GOT.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport logging\nimport pathlib\nimport torch\n# torch.set_num_threads(1)\nimport transformers\n\n# from GOT.train.trainer import GOTTrainer\n# from GOT.train.trainer_vit_llrd import GOTTrainer\nfrom GOT.train.trainer_vit_fixlr import GOTTrainer\nfrom GOT.model import *\nfrom GOT.data import make_supervised_data_module\nfrom GOT.utils.arguments import *\nfrom GOT.utils.constants import *\nfrom GOT.utils.utils import smart_tokenizer_and_embedding_resize\nfrom GOT.model.vision_encoder.vary_b import build_vary_vit_b\nimport os\n\n# os.environ['NCCL_IB_DISABLE'] = '1'\nos.environ['NCCL_DEBUG'] = 'INFO'\nos.environ['OSS_ENDPOINT'] = \"http://oss.i.shaipower.com\"\n\ndef train():\n    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side=\"right\", model_max_length=training_args.model_max_length,)\n\n\n    model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, 
use_safetensors=True)\n\n\n\n    smart_tokenizer_and_embedding_resize(\n        special_tokens_dict=dict(pad_token='<|endoftext|>'),\n        tokenizer=tokenizer,\n        model=model,\n        )\n\n\n    dtype = torch.float32\n    if training_args.fp16:\n        dtype = torch.float16\n    if training_args.bf16:\n        dtype = torch.bfloat16\n\n    vision_tower_dict = model.get_model().initialize_vision_modules(\n        vision_tower=model_args.vision_tower,\n        pretrained_stage1_model=model_args.pretrained_stage1_model,\n        freeze_vision_tower=model_args.freeze_vision_tower,\n        use_im_start_end=model_args.use_im_start_end,\n        vision_select_layer=model_args.vision_select_layer,\n        dtype=dtype,\n        device=training_args.device\n    )\n\n    model.initialize_vision_tokenizer(\n        tokenizer=tokenizer, \n        freeze_lm_model=model_args.freeze_lm_model, \n        pretrained_stage1_model=model_args.pretrained_stage1_model,\n        device=training_args.device,\n    )\n\n\n    model.to(dtype=dtype, device=training_args.device)\n    # 'image_processor_high\n    # data_args.image_token_len = vision_tower_dict['image_token_len']\n    data_args.image_token_len = 256\n    data_args.image_processor = vision_tower_dict['image_processor']\n    data_args.image_processor_high = vision_tower_dict['image_processor_high']\n    data_args.use_im_start_end = model_args.use_im_start_end\n\n    # mixed relation, to be fixed\n    if model_args.freeze_lm_model:\n        model.requires_grad_(False)\n        for p in model.get_model().mm_projector.parameters():\n            p.requires_grad = True\n        for p in model.get_model().mm_projector_vary.parameters():\n            p.requires_grad = True\n        for p in model.get_input_embeddings().parameters():\n            p.requires_grad = True\n\n\n                \n    params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]\n    print(f\"Number of Mapping Trainable 
Parameters: {sum(params_grad) / (1 << 20):.2f} M\")\n\n    # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]\n    # if len(params_no_grad) > 0:\n    #     if training_args.fsdp is not None and len(training_args.fsdp) > 0:\n    #         if len(params_no_grad) < 10:\n    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))\n    #         else:\n    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))\n    #         print(\"[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.\")\n    #         print(\"[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.  See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining\")\n\n    #         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    #         def patch_FSDP_use_orig_params(func):\n    #             def wrap_func(*args, **kwargs):\n    #                 use_orig_params = kwargs.pop('use_orig_params', True)\n    #                 return func(*args, **kwargs, use_orig_params=use_orig_params)\n    #             return wrap_func\n\n    #         FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)\n\n    \n\n    data_module = make_supervised_data_module(\n        interleave=training_args.interleave, \n        with_box=training_args.with_box, \n        tokenizer=tokenizer, \n        data_args=data_args\n    )\n\n    trainer = GOTTrainer(\n        model=model,\n        tokenizer=tokenizer,\n        args=training_args,\n        **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    
trainer.save_state()\n    trainer._safe_save(output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_flash_attn.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.\n\n# Need to call this before importing transformers.\nfrom GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\n\nreplace_llama_attn_with_flash_attn()\n\nfrom GOT.train.train import train\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_lora.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li\n#\n#    Licensed under the Apache License, Version 2.0 (the \"License\");\n#    you may not use this file except in compliance with the License.\n#    You may obtain a copy of the License at\n#\n#        http://www.apache.org/licenses/LICENSE-2.0\n#\n#    Unless required by applicable law or agreed to in writing, software\n#    distributed under the License is distributed on an \"AS IS\" BASIS,\n#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#    See the License for the specific language governing permissions and\n#    limitations under the License.\n\nimport logging\nimport pathlib\nimport torch\nimport transformers\n\n# from GOT.train.trainer import GOTTrainer\n# from GOT.train.trainer_vit_llrd import GOTTrainer\nfrom GOT.train.trainer_vit_fixlr import GOTTrainer\nfrom GOT.model import GOTLlamaForCausalLM\nfrom GOT.data import make_supervised_data_module\nfrom GOT.utils.arguments import *\nfrom GOT.utils.constants import *\nfrom GOT.utils.utils import *\n\n\n# def find_all_linear_names(model):\n#     cls = torch.nn.Linear\n#     lora_module_names = set()\n#     for name, module in model.named_modules():\n#         if isinstance(module, cls):\n#             names = name.split('.')\n#             lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n\n\n#     if 'lm_head' in lora_module_names: # needed for 16-bit\n#         lora_module_names.remove('lm_head')\n#     return list(lora_module_names)\n\ndef train():\n    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))\n    model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n    # model = GOTLlamaForCausalLM.from_pretrained(\n    #     
model_args.model_name_or_path,\n    #     cache_dir=training_args.cache_dir,\n    # )\n\n    # tokenizer = transformers.AutoTokenizer.from_pretrained(\n    #     model_args.model_name_or_path,\n    #     cache_dir=training_args.cache_dir,\n    #     model_max_length=training_args.model_max_length,\n    #     padding_side=\"right\",\n    #     use_fast=False,\n    # )\n\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\"/data/public/ucaswei/cache/Qwen/qwen-chat/\", trust_remote_code=True, padding_side=\"right\", model_max_length=training_args.model_max_length,)\n\n    # # model = AutoModelForCausalLM.from_pretrained(\"/data/public/ucaswei/cache/Qwen/qwen/\", device_map=\"cuda\", trust_remote_code=True).eval()\n\n    model = GOTQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda')\n\n    smart_tokenizer_and_embedding_resize(\n        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),\n        tokenizer=tokenizer,\n        model=model,\n        )\n\n    # if data_args.conversation_version == \"v0\" or \"models--decapoda-research--llama-7b-hf\" in model_args.model_name_or_path:\n    #     if tokenizer.pad_token is None:\n    #         smart_tokenizer_and_embedding_resize(\n    #             special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),\n    #             tokenizer=tokenizer,\n    #             model=model,\n    #         )\n    #     if \"llama\" in model_args.model_name_or_path:\n    #         tokenizer.add_special_tokens({\n    #             \"eos_token\": DEFAULT_EOS_TOKEN,\n    #             \"bos_token\": DEFAULT_BOS_TOKEN,\n    #             \"unk_token\": DEFAULT_UNK_TOKEN,\n    #         })\n    # else:\n    #     tokenizer.pad_token = tokenizer.unk_token\n\n    dtype = torch.float32\n    if training_args.fp16:\n        dtype = torch.float16\n    if training_args.bf16:\n        dtype = torch.bfloat16\n\n    if training_args.lora_enable:\n        from peft import LoraConfig, 
get_peft_model\n        lora_config = LoraConfig(\n            r=training_args.lora_r,\n            lora_alpha=training_args.lora_alpha,\n            target_modules=find_all_linear_names(model),\n            lora_dropout=training_args.lora_dropout,\n            bias=training_args.lora_bias,\n            task_type=\"CAUSAL_LM\",\n        )\n        logging.warning(\"Adding LoRA adapters...\")\n        model = get_peft_model(model, lora_config)\n\n    vision_tower_dict = model.get_model().initialize_vision_modules(\n        vision_tower=model_args.vision_tower,\n        pretrained_stage1_model=model_args.pretrained_stage1_model,\n        freeze_vision_tower=model_args.freeze_vision_tower,\n        use_im_start_end=model_args.use_im_start_end,\n        vision_select_layer=model_args.vision_select_layer,\n        dtype=dtype,\n        device=training_args.device\n    )\n\n    model.initialize_vision_tokenizer(\n        tokenizer=tokenizer, \n        freeze_lm_model=model_args.freeze_lm_model, \n        pretrained_stage1_model=model_args.pretrained_stage1_model,\n        device=training_args.device,\n    )\n\n    model.get_model().vision_tower = create_clip_vit_g(448)\n    model.get_model().mm_projector = create_perciever()\n    model.to(dtype=dtype, device=training_args.device)\n\n    data_args.image_token_len = vision_tower_dict['image_token_len']\n    data_args.image_processor = vision_tower_dict['image_processor']\n    data_args.image_processor_high = vision_tower_dict['image_processor_high']\n    data_args.use_im_start_end = model_args.use_im_start_end\n\n    # mixed relation, to be fixed\n    if model_args.freeze_lm_model:\n        model.requires_grad_(False)\n        for p in model.get_model().mm_projector.parameters():\n            p.requires_grad = True\n        for p in model.get_input_embeddings().parameters():\n            p.requires_grad = True\n        for p in model.get_model().conv_final.parameters():\n            p.requires_grad = True\n        for p in 
model.get_model().vision_encoder.parameters():\n            p.requires_grad = True\n\n        if not model_args.freeze_vision_tower:\n            model.get_model().vision_tower.requires_grad_(True)\n            # for i in range(20):\n            #     model.get_model().vision_tower.vision_model.encoder.layers[i].requires_grad_(False)\n            model.get_model().vision_tower.vision_model.encoder.layers[-1].requires_grad_(False)\n            # model.get_model().vision_tower.vision_model.embeddings.requires_grad_(False)\n            # model.get_model().vision_tower.vision_model.pre_layrnorm.requires_grad_(False)\n            model.get_model().vision_tower.vision_model.post_layernorm.requires_grad_(False)\n\n            for n, p in model.named_parameters():\n                print(n, p.requires_grad)\n                \n    params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]\n    print(f\"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M\")\n\n    # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]\n    # if len(params_no_grad) > 0:\n    #     if training_args.fsdp is not None and len(training_args.fsdp) > 0:\n    #         if len(params_no_grad) < 10:\n    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))\n    #         else:\n    #             print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))\n    #         print(\"[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.\")\n    #         print(\"[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.  
See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining\")\n\n    #         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n    #         def patch_FSDP_use_orig_params(func):\n    #             def wrap_func(*args, **kwargs):\n    #                 use_orig_params = kwargs.pop('use_orig_params', True)\n    #                 return func(*args, **kwargs, use_orig_params=use_orig_params)\n    #             return wrap_func\n\n    #         FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)\n\n    data_module = make_supervised_data_module(\n        interleave=training_args.interleave, \n        with_box=training_args.with_box, \n        tokenizer=tokenizer, \n        data_args=data_args\n    )\n\n    trainer = GOTTrainer(\n        model=model,\n        tokenizer=tokenizer,\n        args=training_args,\n        **data_module)\n\n    if list(pathlib.Path(training_args.output_dir).glob(\"checkpoint-*\")):\n        trainer.train(resume_from_checkpoint=True)\n    else:\n        trainer.train()\n    trainer.save_state()\n\n    if training_args.lora_enable:\n        state_dict = get_peft_state_maybe_zero_3(\n            model.named_parameters(), training_args.lora_bias\n        )\n        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(\n            model.named_parameters()\n        )\n        if training_args.local_rank == 0 or training_args.local_rank == -1:\n            model.config.save_pretrained(training_args.output_dir)\n            model.save_pretrained(training_args.output_dir, state_dict=state_dict)\n            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))\n    else:\n        trainer._safe_save(output_dir=training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/train_lora_flash_attn.py",
    "content": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:\n# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.\n\n# Need to call this before importing transformers.\nfrom GOT.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn\n\nreplace_llama_attn_with_flash_attn()\n\n# from GOT.train.train import train\nfrom GOT.train.train_lora import train\n\nif __name__ == \"__main__\":\n    train()\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\n\nfrom transformers import Trainer\nfrom typing import Dict, Optional, Sequence\n\n\ndef unwrap_model(model: nn.Module) -> nn.Module:\n    \"\"\"\n    Recursively unwraps a model from potential containers (as used in distributed training).\n\n    Args:\n        model (`torch.nn.Module`): The model to unwrap.\n    \"\"\"\n    # since there could be multiple levels of wrapping, unwrap recursively\n    if hasattr(model, \"module\"):\n        return unwrap_model(model.module)\n    else:\n        return model\n\n\nclass GOTTrainer(Trainer):\n\n    def _safe_save(self, output_dir: str):\n        \"\"\"Collects the state dict and dump to disk.\"\"\"\n        if self.deepspeed:\n            torch.cuda.synchronize()\n            self.save_model(output_dir)\n            return\n    \n        state_dict = self.model.state_dict()\n        if self.args.should_save:\n            cpu_state_dict = {\n                key: value.cpu()\n                for key, value in state_dict.items()\n            }\n            del state_dict\n            self._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        if getattr(self.args, 'tune_mm_mlp_adapter', False):\n            # Save the model\n            _state_dict = state_dict\n            if _state_dict is None:\n                # Only save the model itself if we are using distributed training\n                model_to_save = unwrap_model(self.model)\n                _state_dict = model_to_save.state_dict()\n\n            weight_to_save = {}\n            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']\n            for k, v in _state_dict.items():\n                if any(key_match in k for key_match in keys_to_match):\n                    weight_to_save[k] = v\n\n            current_folder = output_dir.split('/')[-1]\n            parent_folder = os.path.dirname(output_dir)\n        
    if current_folder.startswith('checkpoint-'):\n                mm_projector_folder = os.path.join(parent_folder, \"mm_projector\")\n                os.makedirs(mm_projector_folder, exist_ok=True)\n                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))\n            else:\n                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))\n\n        super(GOTTrainer, self)._save(output_dir, state_dict)\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer_llm_llrd.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nimport time\nimport functools\nimport re\n\nfrom transformers import Trainer\nfrom transformers.trainer_pt_utils import (\n    get_module_class_from_name,\n    get_parameter_names,\n)\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.utils import (\n    is_sagemaker_dp_enabled,\n    is_sagemaker_mp_enabled,\n    is_torch_neuroncore_available,\n)\nfrom transformers.trainer_utils import (\n    FSDPOption,\n    ShardedDDPOption,\n)\nfrom transformers.training_args import ParallelMode\nfrom transformers.modeling_utils import PreTrainedModel, unwrap_model\nfrom typing import Dict, Optional, Sequence\n\n\ndef lr_scale_func(key):\n    if \"embed_tokens.weight\" in key:\n        return 0\n    if \"mm_projector\" in key:\n        return 0.01\n        # return 1\n    elif \"vision_tower\" in key:\n        return 0.01\n        # return 1\n    elif \"norm.weight\" in key or \"lm_head.weight\" in key:\n        return 1\n    else:\n        in_pp_layer = int(re.findall(f\"layers\\.(\\d+)\\.\", key)[0])\n        decay = 0.86 ** (32 - in_pp_layer - 1)\n        return decay\n                \n\ndef get_param_groups(model, no_weight_decay_cond, scale_lr_cond):\n    \"\"\"creates param groups based on weight decay condition (regularized vs non regularized)\n    and learning rate scale condition (args.lr vs lr_mult * args.lr)\n    scale_lr_cond is used during finetuning where head of the network requires a scaled\n    version of the base learning rate.\n    \"\"\"\n    wd_no_scale_lr = []\n    wd_scale_lr = {}\n    no_wd_no_scale_lr = []\n    no_wd_scale_lr = {}\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n\n        if no_weight_decay_cond is not None:\n            no_wd = no_weight_decay_cond(name, param)\n        else:\n            # do not regularize biases nor Norm parameters\n            no_wd = name.endswith(\".bias\") or 
len(param.shape) == 1\n\n        if scale_lr_cond is not None:\n            lr_mult = scale_lr_cond(name)\n            print(name, lr_mult)\n            scale_lr = lr_mult != 1\n        else:\n            scale_lr = False\n\n        if not no_wd and not scale_lr:\n            wd_no_scale_lr.append(param)\n        elif not no_wd and scale_lr:\n            if lr_mult not in wd_scale_lr:\n                wd_scale_lr[lr_mult] = [param]\n            else:\n                wd_scale_lr[lr_mult].append(param)\n        elif no_wd and not scale_lr:\n            no_wd_no_scale_lr.append(param)\n        else:\n            if lr_mult not in no_wd_scale_lr:\n                no_wd_scale_lr[lr_mult] = [param]\n            else:\n                no_wd_scale_lr[lr_mult].append(param)\n\n    param_groups = []\n    if len(wd_no_scale_lr):\n        param_groups.append({\"params\": wd_no_scale_lr, \"wd_mult\": 1.0, \"lr_mult\": 1.0})\n    if len(wd_scale_lr):\n        for lr_mult, params in wd_scale_lr.items():\n            param_groups.append({\"params\": params, \"wd_mult\": 1.0, \"lr_mult\": lr_mult})\n    if len(no_wd_no_scale_lr):\n        param_groups.append(\n            {\"params\": no_wd_no_scale_lr, \"wd_mult\": 0.0, \"lr_mult\": 1.0}\n        )\n    if len(no_wd_scale_lr):\n        for lr_mult, params in no_wd_scale_lr.items():\n            param_groups.append({\"params\": params, \"wd_mult\": 0.0, \"lr_mult\": lr_mult})\n\n    return param_groups\n\n\ndef unwrap_model(model: nn.Module) -> nn.Module:\n    \"\"\"\n    Recursively unwraps a model from potential containers (as used in distributed training).\n\n    Args:\n        model (`torch.nn.Module`): The model to unwrap.\n    \"\"\"\n    # since there could be multiple levels of wrapping, unwrap recursively\n    if hasattr(model, \"module\"):\n        return unwrap_model(model.module)\n    else:\n        return model\n\n\nclass GOTTrainer(Trainer):\n\n    def _safe_save(self, output_dir: str):\n        \"\"\"Collects the 
state dict and dump to disk.\"\"\"\n        state_dict = self.model.state_dict()\n        if self.args.should_save:\n            cpu_state_dict = {\n                key: value.cpu()\n                for key, value in state_dict.items()\n            }\n            del state_dict\n            self._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        if getattr(self.args, 'tune_mm_mlp_adapter', False):\n            # Save the model\n            _state_dict = state_dict\n            if _state_dict is None:\n                # Only save the model itself if we are using distributed training\n                model_to_save = unwrap_model(self.model)\n                _state_dict = model_to_save.state_dict()\n\n            weight_to_save = {}\n            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']\n            for k, v in _state_dict.items():\n                if any(key_match in k for key_match in keys_to_match):\n                    weight_to_save[k] = v\n\n            current_folder = output_dir.split('/')[-1]\n            parent_folder = os.path.dirname(output_dir)\n            if current_folder.startswith('checkpoint-'):\n                mm_projector_folder = os.path.join(parent_folder, \"mm_projector\")\n                os.makedirs(mm_projector_folder, exist_ok=True)\n                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))\n            else:\n                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))\n\n        super(GOTTrainer, self)._save(output_dir, state_dict)\n\n    def create_optimizer(self):\n        \"\"\"\n        Setup the optimizer.\n\n        We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        opt_model = self.model\n\n        if self.optimizer is None:\n            # decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            # decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            # optimizer_grouped_parameters = [\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": self.args.weight_decay,\n            #     },\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": 0.0,\n            #     },\n            # ]\n\n            optimizer_grouped_parameters = get_param_groups(opt_model, None, lr_scale_func)\n            \n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n\n        return self.optimizer\n    \n\n    def _wrap_model(self, model, training=True, dataloader=None):\n        if self.args.use_ipex:\n            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32\n            model = self.ipex_optimize_model(model, training, dtype=dtype)\n\n        if is_sagemaker_mp_enabled():\n            import smdistributed.modelparallel.torch as smp\n            # Wrapping the base model twice in a DistributedModel will raise an error.\n            if isinstance(self.model_wrapped, smp.model.DistributedModel):\n                return self.model_wrapped\n            return smp.DistributedModel(model, 
backward_passes_per_step=self.args.gradient_accumulation_steps)\n        # already initialized its own DDP and AMP\n        if self.deepspeed:\n            return self.deepspeed\n\n        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again\n        if unwrap_model(model) is not model:\n            return model\n\n        # Mixed precision training with apex (torch < 1.6)\n        if self.use_apex and training:\n            from apex import amp\n            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)\n\n        # Multi-gpu training (should be after apex fp16 initialization)\n        if self.args.n_gpu > 1:\n            model = nn.DataParallel(model)\n\n        if self.args.jit_mode_eval:\n            start_time = time.time()\n            model = self.torch_jit_model_eval(model, dataloader, training)\n            self.jit_compilation_time = round(time.time() - start_time, 4)\n\n        # Note: in torch.distributed mode, there's no point in wrapping the model\n        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.\n        if not training:\n            return model\n\n        # Distributed training (should be after apex fp16 initialization)\n        if self.sharded_ddp is not None:\n            from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP\n            from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP\n            from fairscale.nn.wrap import auto_wrap\n            # Sharded DDP!\n            if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n                model = ShardedDDP(model, self.optimizer)\n            else:\n                mixed_precision = self.args.fp16 or self.args.bf16\n                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp\n                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3\n                # XXX: Breaking the self.model convention but 
I see no way around it for now.\n                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:\n                    model = auto_wrap(model)\n                self.model = model = FullyShardedDDP(\n                    model,\n                    mixed_precision=mixed_precision,\n                    reshard_after_forward=zero_3,\n                    cpu_offload=cpu_offload,\n                ).to(self.args.device)\n        # Distributed training using PyTorch FSDP\n        elif self.fsdp is not None:\n            if not self.args.fsdp_config[\"xla\"]:\n                # PyTorch FSDP!\n                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision\n                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy\n\n                if FSDPOption.OFFLOAD in self.args.fsdp:\n                    cpu_offload = CPUOffload(offload_params=True)\n                else:\n                    cpu_offload = CPUOffload(offload_params=False)\n\n                auto_wrap_policy = None\n\n                if FSDPOption.AUTO_WRAP in self.args.fsdp:\n                    if self.args.fsdp_config[\"fsdp_min_num_params\"] > 0:\n                        auto_wrap_policy = functools.partial(\n                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config[\"fsdp_min_num_params\"]\n                        )\n                    elif self.args.fsdp_config.get(\"fsdp_transformer_layer_cls_to_wrap\", None) is not None:\n                        transformer_cls_to_wrap = set()\n                        for layer_class in self.args.fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"]:\n                            transformer_cls = get_module_class_from_name(model, layer_class)\n                            if transformer_cls is None:\n                               
 raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n                            else:\n                                transformer_cls_to_wrap.add(transformer_cls)\n                        auto_wrap_policy = functools.partial(\n                            transformer_auto_wrap_policy,\n                            # Transformer layer class to wrap\n                            transformer_layer_cls=transformer_cls_to_wrap,\n                        )\n                mixed_precision_policy = None\n                dtype = None\n                if self.args.fp16:\n                    dtype = torch.float16\n                elif self.args.bf16:\n                    dtype = torch.bfloat16\n                if dtype is not None:\n                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)\n                if type(model) != FSDP:\n                    # XXX: Breaking the self.model convention but I see no way around it for now.\n                    self.model = model = FSDP(\n                        model,\n                        sharding_strategy=self.fsdp,\n                        cpu_offload=cpu_offload,\n                        auto_wrap_policy=auto_wrap_policy,\n                        mixed_precision=mixed_precision_policy,\n                        device_id=self.args.device,\n                        backward_prefetch=self.backward_prefetch,\n                        forward_prefetch=self.forword_prefetch,\n                        limit_all_gathers=self.limit_all_gathers,\n                        use_orig_params=True,\n                    )\n            else:\n                try:\n                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP\n                    from torch_xla.distributed.fsdp import checkpoint_module\n                    from torch_xla.distributed.fsdp.wrap import (\n                        size_based_auto_wrap_policy,\n       
                 transformer_auto_wrap_policy,\n                    )\n                except ImportError:\n                    raise ImportError(\"Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.\")\n                auto_wrap_policy = None\n                auto_wrapper_callable = None\n                if self.args.fsdp_config[\"fsdp_min_num_params\"] > 0:\n                    auto_wrap_policy = functools.partial(\n                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config[\"fsdp_min_num_params\"]\n                    )\n                elif self.args.fsdp_config.get(\"fsdp_transformer_layer_cls_to_wrap\", None) is not None:\n                    transformer_cls_to_wrap = set()\n                    for layer_class in self.args.fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"]:\n                        transformer_cls = get_module_class_from_name(model, layer_class)\n                        if transformer_cls is None:\n                            raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n                        else:\n                            transformer_cls_to_wrap.add(transformer_cls)\n                    auto_wrap_policy = functools.partial(\n                        transformer_auto_wrap_policy,\n                        # Transformer layer class to wrap\n                        transformer_layer_cls=transformer_cls_to_wrap,\n                    )\n                fsdp_kwargs = self.args.xla_fsdp_config\n                if self.args.fsdp_config[\"xla_fsdp_grad_ckpt\"]:\n                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified\n                    def auto_wrapper_callable(m, *args, **kwargs):\n                        return FSDP(checkpoint_module(m), *args, **kwargs)\n\n                # Wrap the base model with an outer FSDP wrapper\n                self.model = model = FSDP(\n                    model,\n                    
auto_wrap_policy=auto_wrap_policy,\n                    auto_wrapper_callable=auto_wrapper_callable,\n                    **fsdp_kwargs,\n                )\n\n                import torch_xla.core.xla_model as xm\n                # Patch `xm.optimizer_step` should not reduce gradients in this case,\n                # as FSDP does not need gradient reduction over sharded parameters.\n                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):\n                    loss = optimizer.step(**optimizer_args)\n                    if barrier:\n                        xm.mark_step()\n                    return loss\n\n                xm.optimizer_step = patched_optimizer_step\n        elif is_sagemaker_dp_enabled():\n            model = nn.parallel.DistributedDataParallel(\n                model, device_ids=[int(os.getenv(\"SMDATAPARALLEL_LOCAL_RANK\"))]\n            )\n        elif self.args.local_rank != -1:\n            kwargs = {}\n            if self.args.ddp_find_unused_parameters is not None:\n                kwargs[\"find_unused_parameters\"] = self.args.ddp_find_unused_parameters\n            elif isinstance(model, PreTrainedModel):\n                # find_unused_parameters breaks checkpointing as per\n                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021\n                kwargs[\"find_unused_parameters\"] = not model.is_gradient_checkpointing\n            else:\n                kwargs[\"find_unused_parameters\"] = True\n\n            if self.args.ddp_bucket_cap_mb is not None:\n                kwargs[\"bucket_cap_mb\"] = self.args.ddp_bucket_cap_mb\n            if is_torch_neuroncore_available():\n                return model\n            model = nn.parallel.DistributedDataParallel(\n                model,\n                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,\n                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,\n                
**kwargs,\n            )\n\n        # torch.compile() needs to be called after wrapping the model with FSDP or DDP\n        # to ensure that it accounts for the graph breaks required by those wrappers\n        if self.args.torch_compile:\n            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)\n\n        return model\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer_vit_fixlr.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\n\nfrom transformers import Trainer\nfrom transformers.trainer_pt_utils import get_parameter_names\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom typing import Dict, Optional, Sequence\n\n\ndef unwrap_model(model: nn.Module) -> nn.Module:\n    \"\"\"\n    Recursively unwraps a model from potential containers (as used in distributed training).\n\n    Args:\n        model (`torch.nn.Module`): The model to unwrap.\n    \"\"\"\n    # since there could be multiple levels of wrapping, unwrap recursively\n    if hasattr(model, \"module\"):\n        return unwrap_model(model.module)\n    else:\n        return model\n\n\nclass GOTTrainer(Trainer):\n\n    def _safe_save(self, output_dir: str):\n        \"\"\"Collects the state dict and dump to disk.\"\"\"\n        state_dict = self.model.state_dict()\n        if self.args.should_save:\n            cpu_state_dict = {\n                key: value.cpu()\n                for key, value in state_dict.items()\n            }\n            del state_dict\n            self._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        if getattr(self.args, 'tune_mm_mlp_adapter', False):\n            # Save the model\n            _state_dict = state_dict\n            if _state_dict is None:\n                # Only save the model itself if we are using distributed training\n                model_to_save = unwrap_model(self.model)\n                _state_dict = model_to_save.state_dict()\n\n            weight_to_save = {}\n            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']\n            for k, v in _state_dict.items():\n                if any(key_match in k for key_match in keys_to_match):\n                    weight_to_save[k] = v\n\n            current_folder = output_dir.split('/')[-1]\n            parent_folder = os.path.dirname(output_dir)\n            if 
current_folder.startswith('checkpoint-'):\n                mm_projector_folder = os.path.join(parent_folder, \"mm_projector\")\n                os.makedirs(mm_projector_folder, exist_ok=True)\n                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))\n            else:\n                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))\n\n        super(GOTTrainer, self)._save(output_dir, state_dict)\n\n    def create_optimizer(self):\n        \"\"\"\n        Setup the optimizer.\n\n        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        opt_model = self.model\n\n        if self.optimizer is None:\n            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            optimizer_grouped_parameters = [\n                {\n                    \"params\": [\n                        p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n in decay_parameters and p.requires_grad\n                    ],\n                    \"weight_decay\": self.args.weight_decay,\n                    \"lr\": self.args.learning_rate,\n                },\n                {\n                    \"params\": [\n                        p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n not in decay_parameters and p.requires_grad],\n                    \"weight_decay\": 0.0,\n                    \"lr\": self.args.learning_rate,\n                },\n                {\n                    \"params\": [\n                        p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n in decay_parameters and p.requires_grad],\n                    \"weight_decay\": 
self.args.weight_decay,\n                    \"lr\": self.args.learning_rate,\n                },\n                {\n                    \"params\": [\n                        p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n not in decay_parameters and p.requires_grad\n                    ],\n                    \"weight_decay\": 0.0,\n                    \"lr\": self.args.learning_rate,\n                },\n            ]\n            for idx, group in enumerate(optimizer_grouped_parameters):\n                print(idx, len(group['params']), group['lr'])\n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n\n        return self.optimizer"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/train/trainer_vit_llrd.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\nimport time\nimport functools\nimport re\n\nfrom transformers import Trainer\nfrom transformers.trainer_pt_utils import (\n    get_module_class_from_name,\n    get_parameter_names,\n)\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.utils import (\n    is_sagemaker_dp_enabled,\n    is_sagemaker_mp_enabled,\n    is_torch_neuroncore_available,\n)\nfrom transformers.trainer_utils import (\n    FSDPOption,\n    ShardedDDPOption,\n)\nfrom transformers.training_args import ParallelMode\nfrom transformers.modeling_utils import PreTrainedModel, unwrap_model\nfrom typing import Dict, Optional, Sequence\n\n\ndef lr_scale_func(key):\n    if \"vision_model.encoder.layers\" in key:\n        in_pp_layer = int(re.findall(f\"layers\\.(\\d+)\\.\", key)[0])\n        # decay = 0.81 ** (23 - in_pp_layer - 1)\n        decay = 0.81 ** (23 - in_pp_layer - 1) * 0.01\n        # decay = 0.66 ** (23 - in_pp_layer - 1)\n        return decay\n        # return 0.01\n    elif \"vision_model\" in key:\n        # return 0.01\n        return 0.0001\n    return 1\n                \n\ndef get_param_groups(model, no_weight_decay_cond, scale_lr_cond, lr, wd):\n    \"\"\"creates param groups based on weight decay condition (regularized vs non regularized)\n    and learning rate scale condition (args.lr vs lr_mult * args.lr)\n    scale_lr_cond is used during finetuning where head of the network requires a scaled\n    version of the base learning rate.\n    \"\"\"\n    wd_no_scale_lr = []\n    wd_scale_lr = {}\n    no_wd_no_scale_lr = []\n    no_wd_scale_lr = {}\n    for name, param in model.named_parameters():\n        if not param.requires_grad:\n            continue\n\n        if no_weight_decay_cond is not None:\n            no_wd = no_weight_decay_cond(name, param)\n        else:\n            # do not regularize biases nor Norm parameters\n            no_wd = name.endswith(\".bias\") or len(param.shape) == 
1\n\n        if scale_lr_cond is not None:\n            lr_mult = scale_lr_cond(name)\n            print(name, lr_mult)\n            scale_lr = lr_mult != 1\n        else:\n            scale_lr = False\n\n        if not no_wd and not scale_lr:\n            wd_no_scale_lr.append(param)\n        elif not no_wd and scale_lr:\n            if lr_mult not in wd_scale_lr:\n                wd_scale_lr[lr_mult] = [param]\n            else:\n                wd_scale_lr[lr_mult].append(param)\n        elif no_wd and not scale_lr:\n            no_wd_no_scale_lr.append(param)\n        else:\n            if lr_mult not in no_wd_scale_lr:\n                no_wd_scale_lr[lr_mult] = [param]\n            else:\n                no_wd_scale_lr[lr_mult].append(param)\n\n    param_groups = []\n    if len(wd_no_scale_lr):\n        param_groups.append({\"params\": wd_no_scale_lr, \"weight_decay\": wd, \"lr\": lr})\n    if len(wd_scale_lr):\n        for lr_mult, params in wd_scale_lr.items():\n            param_groups.append({\"params\": params, \"weight_decay\": wd, \"lr\": lr * lr_mult})\n    if len(no_wd_no_scale_lr):\n        param_groups.append(\n            {\"params\": no_wd_no_scale_lr, \"weight_decay\": 0.0, \"lr\": lr}\n        )\n    if len(no_wd_scale_lr):\n        for lr_mult, params in no_wd_scale_lr.items():\n            param_groups.append({\"params\": params, \"weight_decay\": 0.0, \"lr\": lr * lr_mult})\n\n    return param_groups\n\n\ndef unwrap_model(model: nn.Module) -> nn.Module:\n    \"\"\"\n    Recursively unwraps a model from potential containers (as used in distributed training).\n\n    Args:\n        model (`torch.nn.Module`): The model to unwrap.\n    \"\"\"\n    # since there could be multiple levels of wrapping, unwrap recursively\n    if hasattr(model, \"module\"):\n        return unwrap_model(model.module)\n    else:\n        return model\n\n\nclass GOTTrainer(Trainer):\n\n    def _safe_save(self, output_dir: str):\n        \"\"\"Collects the state dict and 
dump to disk.\"\"\"\n        state_dict = self.model.state_dict()\n        if self.args.should_save:\n            cpu_state_dict = {\n                key: value.cpu()\n                for key, value in state_dict.items()\n            }\n            del state_dict\n            self._save(output_dir, state_dict=cpu_state_dict)  # noqa\n\n\n    def _save(self, output_dir: Optional[str] = None, state_dict=None):\n        if getattr(self.args, 'tune_mm_mlp_adapter', False):\n            # Save the model\n            _state_dict = state_dict\n            if _state_dict is None:\n                # Only save the model itself if we are using distributed training\n                model_to_save = unwrap_model(self.model)\n                _state_dict = model_to_save.state_dict()\n\n            weight_to_save = {}\n            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']\n            for k, v in _state_dict.items():\n                if any(key_match in k for key_match in keys_to_match):\n                    weight_to_save[k] = v\n\n            current_folder = output_dir.split('/')[-1]\n            parent_folder = os.path.dirname(output_dir)\n            if current_folder.startswith('checkpoint-'):\n                mm_projector_folder = os.path.join(parent_folder, \"mm_projector\")\n                os.makedirs(mm_projector_folder, exist_ok=True)\n                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))\n            else:\n                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))\n\n        super(GOTTrainer, self)._save(output_dir, state_dict)\n\n    def create_optimizer(self):\n        \"\"\"\n        Setup the optimizer.\n\n        We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the\n        Trainer's init through `optimizers`, or subclass and override this method in a subclass.\n        \"\"\"\n        opt_model = self.model\n\n        if self.optimizer is None:\n            # decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)\n            # decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n            # optimizer_grouped_parameters = [\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": self.args.weight_decay,\n            #     },\n            #     {\n            #         \"params\": [\n            #             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)\n            #         ],\n            #         \"weight_decay\": 0.0,\n            #     },\n            # ]\n\n            optimizer_grouped_parameters = get_param_groups(opt_model, None, lr_scale_func, self.args.learning_rate, self.args.weight_decay)\n            \n            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)\n            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)\n\n        return self.optimizer\n    \n\n    def _wrap_model(self, model, training=True, dataloader=None):\n        if self.args.use_ipex:\n            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32\n            model = self.ipex_optimize_model(model, training, dtype=dtype)\n\n        if is_sagemaker_mp_enabled():\n            import smdistributed.modelparallel.torch as smp\n            # Wrapping the base model twice in a DistributedModel will raise an error.\n            if isinstance(self.model_wrapped, smp.model.DistributedModel):\n                return 
self.model_wrapped\n            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)\n        # already initialized its own DDP and AMP\n        if self.deepspeed:\n            return self.deepspeed\n\n        # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again\n        if unwrap_model(model) is not model:\n            return model\n\n        # Mixed precision training with apex (torch < 1.6)\n        if self.use_apex and training:\n            from apex import amp\n            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)\n\n        # Multi-gpu training (should be after apex fp16 initialization)\n        if self.args.n_gpu > 1:\n            model = nn.DataParallel(model)\n\n        if self.args.jit_mode_eval:\n            start_time = time.time()\n            model = self.torch_jit_model_eval(model, dataloader, training)\n            self.jit_compilation_time = round(time.time() - start_time, 4)\n\n        # Note: in torch.distributed mode, there's no point in wrapping the model\n        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.\n        if not training:\n            return model\n\n        # Distributed training (should be after apex fp16 initialization)\n        if self.sharded_ddp is not None:\n            from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP\n            from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP\n            from fairscale.nn.wrap import auto_wrap\n            # Sharded DDP!\n            if self.sharded_ddp == ShardedDDPOption.SIMPLE:\n                model = ShardedDDP(model, self.optimizer)\n            else:\n                mixed_precision = self.args.fp16 or self.args.bf16\n                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp\n                zero_3 = self.sharded_ddp == 
ShardedDDPOption.ZERO_DP_3\n                # XXX: Breaking the self.model convention but I see no way around it for now.\n                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:\n                    model = auto_wrap(model)\n                self.model = model = FullyShardedDDP(\n                    model,\n                    mixed_precision=mixed_precision,\n                    reshard_after_forward=zero_3,\n                    cpu_offload=cpu_offload,\n                ).to(self.args.device)\n        # Distributed training using PyTorch FSDP\n        elif self.fsdp is not None:\n            if not self.args.fsdp_config[\"xla\"]:\n                # PyTorch FSDP!\n                from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision\n                from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP\n                from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy\n\n                if FSDPOption.OFFLOAD in self.args.fsdp:\n                    cpu_offload = CPUOffload(offload_params=True)\n                else:\n                    cpu_offload = CPUOffload(offload_params=False)\n\n                auto_wrap_policy = None\n\n                if FSDPOption.AUTO_WRAP in self.args.fsdp:\n                    if self.args.fsdp_config[\"fsdp_min_num_params\"] > 0:\n                        auto_wrap_policy = functools.partial(\n                            size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config[\"fsdp_min_num_params\"]\n                        )\n                    elif self.args.fsdp_config.get(\"fsdp_transformer_layer_cls_to_wrap\", None) is not None:\n                        transformer_cls_to_wrap = set()\n                        for layer_class in self.args.fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"]:\n                            transformer_cls = get_module_class_from_name(model, 
layer_class)\n                            if transformer_cls is None:\n                                raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n                            else:\n                                transformer_cls_to_wrap.add(transformer_cls)\n                        auto_wrap_policy = functools.partial(\n                            transformer_auto_wrap_policy,\n                            # Transformer layer class to wrap\n                            transformer_layer_cls=transformer_cls_to_wrap,\n                        )\n                mixed_precision_policy = None\n                dtype = None\n                if self.args.fp16:\n                    dtype = torch.float16\n                elif self.args.bf16:\n                    dtype = torch.bfloat16\n                if dtype is not None:\n                    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)\n                if type(model) != FSDP:\n                    # XXX: Breaking the self.model convention but I see no way around it for now.\n                    self.model = model = FSDP(\n                        model,\n                        sharding_strategy=self.fsdp,\n                        cpu_offload=cpu_offload,\n                        auto_wrap_policy=auto_wrap_policy,\n                        mixed_precision=mixed_precision_policy,\n                        device_id=self.args.device,\n                        backward_prefetch=self.backward_prefetch,\n                        forward_prefetch=self.forword_prefetch,\n                        limit_all_gathers=self.limit_all_gathers,\n                        use_orig_params=True,\n                    )\n            else:\n                try:\n                    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP\n                    from torch_xla.distributed.fsdp import checkpoint_module\n                    from 
torch_xla.distributed.fsdp.wrap import (\n                        size_based_auto_wrap_policy,\n                        transformer_auto_wrap_policy,\n                    )\n                except ImportError:\n                    raise ImportError(\"Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.\")\n                auto_wrap_policy = None\n                auto_wrapper_callable = None\n                if self.args.fsdp_config[\"fsdp_min_num_params\"] > 0:\n                    auto_wrap_policy = functools.partial(\n                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config[\"fsdp_min_num_params\"]\n                    )\n                elif self.args.fsdp_config.get(\"fsdp_transformer_layer_cls_to_wrap\", None) is not None:\n                    transformer_cls_to_wrap = set()\n                    for layer_class in self.args.fsdp_config[\"fsdp_transformer_layer_cls_to_wrap\"]:\n                        transformer_cls = get_module_class_from_name(model, layer_class)\n                        if transformer_cls is None:\n                            raise Exception(\"Could not find the transformer layer class to wrap in the model.\")\n                        else:\n                            transformer_cls_to_wrap.add(transformer_cls)\n                    auto_wrap_policy = functools.partial(\n                        transformer_auto_wrap_policy,\n                        # Transformer layer class to wrap\n                        transformer_layer_cls=transformer_cls_to_wrap,\n                    )\n                fsdp_kwargs = self.args.xla_fsdp_config\n                if self.args.fsdp_config[\"xla_fsdp_grad_ckpt\"]:\n                    # Apply gradient checkpointing to auto-wrapped sub-modules if specified\n                    def auto_wrapper_callable(m, *args, **kwargs):\n                        return FSDP(checkpoint_module(m), *args, **kwargs)\n\n                # Wrap the base model with an outer 
FSDP wrapper\n                self.model = model = FSDP(\n                    model,\n                    auto_wrap_policy=auto_wrap_policy,\n                    auto_wrapper_callable=auto_wrapper_callable,\n                    **fsdp_kwargs,\n                )\n\n                import torch_xla.core.xla_model as xm\n                # Patch `xm.optimizer_step` should not reduce gradients in this case,\n                # as FSDP does not need gradient reduction over sharded parameters.\n                def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):\n                    loss = optimizer.step(**optimizer_args)\n                    if barrier:\n                        xm.mark_step()\n                    return loss\n\n                xm.optimizer_step = patched_optimizer_step\n        elif is_sagemaker_dp_enabled():\n            model = nn.parallel.DistributedDataParallel(\n                model, device_ids=[int(os.getenv(\"SMDATAPARALLEL_LOCAL_RANK\"))]\n            )\n        elif self.args.local_rank != -1:\n            kwargs = {}\n            if self.args.ddp_find_unused_parameters is not None:\n                kwargs[\"find_unused_parameters\"] = self.args.ddp_find_unused_parameters\n            elif isinstance(model, PreTrainedModel):\n                # find_unused_parameters breaks checkpointing as per\n                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021\n                kwargs[\"find_unused_parameters\"] = not model.is_gradient_checkpointing\n            else:\n                kwargs[\"find_unused_parameters\"] = True\n\n            if self.args.ddp_bucket_cap_mb is not None:\n                kwargs[\"bucket_cap_mb\"] = self.args.ddp_bucket_cap_mb\n            if is_torch_neuroncore_available():\n                return model\n            model = nn.parallel.DistributedDataParallel(\n                model,\n                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else 
None,\n                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,\n                **kwargs,\n            )\n\n        # torch.compile() needs to be called after wrapping the model with FSDP or DDP\n        # to ensure that it accounts for the graph breaks required by those wrappers\n        if self.args.torch_compile:\n            model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)\n\n        return model\n\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/arguments.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Dict, Optional, Sequence\nimport transformers\n\n\n@dataclass\nclass ModelArguments:\n    model_name_or_path: Optional[str] = field(default=\"facebook/opt-125m\")\n    use_cache: bool = field(default=False)\n    vision_tower: Optional[str] = field(default=\"~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/\")\n    freeze_vision_tower: bool = field(default=False)\n    freeze_lm_model: bool = field(default=False)\n    pretrained_stage1_model: Optional[str] = field(default=None) # mlp &/ vision tower\n    vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer\n    use_im_start_end: bool = field(default=False)\n\n\n@dataclass\nclass DataArguments:\n    datasets: str = field(default=None, metadata={\"help\": \"combinations of the training data.\"})\n    sep_image_conv_front: bool = False\n    image_token_len: int = 256\n    image_aspect_ratio: str = 'square'\n    conversation_version: str = 'mpt'\n    # conversation_version: str = 'v0'\n    # conversation_version: str = 'v1'\n    # conversation_version: str = 'nougat'\n    # conversation_version: str = 'baichuan'\n    # conversation_version: str = 'opt'\n    box_limit: int = 0\n\n\n@dataclass\nclass TrainingArguments(transformers.TrainingArguments):\n    cache_dir: Optional[str] = field(default=None)\n    optim: str = field(default=\"adamw_torch\")\n    remove_unused_columns: bool = field(default=False)\n    force_fsdp: bool = field(default=False)\n    interleave: bool = field(default=False)\n    with_box: bool = field(default=False)\n    model_max_length: int = field(\n        default=512,\n        metadata={\n            \"help\":\n            \"Maximum sequence length. 
Sequences will be right padded (and possibly truncated).\"\n        },\n    )\n    lora_enable: bool = False\n    lora_r: int = 8\n    lora_alpha: int = 16\n    lora_dropout: float = 0.05\n    lora_weight_path: str = \"\"\n    lora_bias: str = \"none\""
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/constants.py",
    "content": "CONTROLLER_HEART_BEAT_EXPIRATION = 30\nWORKER_HEART_BEAT_INTERVAL = 15\n\nLOGDIR = \"log\"\n\nIGNORE_INDEX = -100\n# DEFAULT_PAD_TOKEN = \"[PAD]\"\n\nDEFAULT_PAD_TOKEN = \"<|endoftext|>\"\nDEFAULT_EOS_TOKEN = \"</s>\"\nDEFAULT_BOS_TOKEN = \"</s>\"\nDEFAULT_UNK_TOKEN = \"<unk>\"\nDEFAULT_IMAGE_TOKEN = \"<image>\"\nDEFAULT_BOX_TOKEN = \"<box>\"\n\nDEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'\n\nDEFAULT_IM_START_TOKEN = '<img>'\nDEFAULT_IM_END_TOKEN = '</img>'\n\n\n\nCONVERSATION_DATA = {\n\n    'data_1': {\n        'images': '/path/',\n        'annotations': '/path/data1.json',\n    },\n    'data_2': {\n        'images': '/path/',\n        'annotations': '/path/data2.json',\n    },\n    'data_3': {\n        'images': '/path/',\n        'annotations': '/path/data3.json',\n    },\n\n\n}"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/conversation.py",
    "content": "import dataclasses\nfrom enum import auto, Enum\nfrom typing import List, Tuple\n\n\nclass SeparatorStyle(Enum):\n    \"\"\"Different separator style.\"\"\"\n    SINGLE = auto()\n    TWO = auto()\n    MPT = auto()\n\n\n\n# simple_conv_multimodal = Conversation(\n#     system=\"You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.\"\n#            \"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\"\n#            \"Follow the instructions carefully and explain your answers in detail.\",\n#     # system=\"\",\n#     roles=(\"Human\", \"Assistant\"),\n#     messages=(\n#         (\"Human\", \"Hi!\"),\n#         (\"Assistant\", \"Hi there!  How can I help you today?\\n\")\n#     ),\n#     offset=2,\n#     sep_style=SeparatorStyle.SINGLE,\n#     sep=\"###\",\n# )\n\n# conv_mpt = Conversation(\n#     system=\"\"\"<|im_start|>system\n# - You are a helpful language and vision assistant.\n# - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n# - You should follow the instructions carefully and explain your answers in detail.\"\"\",\n#     roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n#     version=\"mpt\",\n#     messages=(),\n#     offset=0,\n#     sep_style=SeparatorStyle.MPT,\n#     sep=\"<|im_end|>\",\n# )\n\n@dataclasses.dataclass\nclass Conversation:\n    \"\"\"A class that keeps all conversation history.\"\"\"\n    system: str\n    roles: List[str]\n    messages: List[List[str]]\n    offset: int\n    sep_style: SeparatorStyle = SeparatorStyle.SINGLE\n    sep: str = \"<|im_end|>\"\n    sep2: str = None\n    version: str = \"Unknown\"\n\n    skip_next: bool = False\n\n    def get_prompt(self):\n        if self.sep_style == SeparatorStyle.SINGLE:\n            ret = self.system + self.sep + '\\n'\n            for 
role, message in self.messages:\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + \": \" + message + self.sep\n                else:\n                    ret += role + \":\"\n            return ret\n        elif self.sep_style == SeparatorStyle.TWO:\n            seps = [self.sep, self.sep2]\n            ret = self.system + seps[0]\n            for i, (role, message) in enumerate(self.messages):\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + \": \" + message + seps[i % 2]\n                else:\n                    ret += role + \":\"\n            return ret\n        if self.sep_style == SeparatorStyle.MPT:\n            if self.system:\n                ret = self.system + self.sep \n            else:\n                ret = ''\n            for role, message in self.messages:\n                if message:\n                    if type(message) is tuple:\n                        message, _, _ = message\n                    ret += role + message + self.sep\n                else:\n                    ret += role\n            return ret\n        else:\n            raise ValueError(f\"Invalid style: {self.sep_style}\")\n        # if self.sep_style == SeparatorStyle.MPT:\n        #     if self.system:\n        #         ret = self.system + self.sep\n        #     else:\n        #         ret = ''\n        #     for role, message in self.messages:\n        #         if message:\n        #             if type(message) is tuple:\n        #                 message, _, _ = message\n        #             ret += role + message + self.sep \n        #             # if 'user' in role:\n        #             #     ret += role + message + self.sep + \"\\n\"\n        #             # else:\n        #             #     ret += role + message + self.sep \n        #        
 else:\n        #             ret += role\n        #     return ret\n        # else:\n        #     raise ValueError(f\"Invalid style: {self.sep_style}\")\n\n    def append_message(self, role, message):\n        self.messages.append([role, message])\n\n    def get_images(self, return_pil=False):\n        images = []\n        for i, (role, msg) in enumerate(self.messages[self.offset:]):\n            if i % 2 == 0:\n                if type(msg) is tuple:\n                    import base64\n                    from io import BytesIO\n                    from PIL import Image\n                    msg, image, image_process_mode = msg\n                    if image_process_mode == \"Pad\":\n                        def expand2square(pil_img, background_color=(122, 116, 104)):\n                            width, height = pil_img.size\n                            if width == height:\n                                return pil_img\n                            elif width > height:\n                                result = Image.new(pil_img.mode, (width, width), background_color)\n                                # result.paste(pil_img, (0, (width - height) // 2))\n                                result.paste(pil_img)\n                                return result\n                            else:\n                                result = Image.new(pil_img.mode, (height, height), background_color)\n                                # result.paste(pil_img, ((height - width) // 2, 0))\n                                result.paste(pil_img)\n                                return result\n                        image = expand2square(image)\n                    elif image_process_mode == \"Crop\":\n                        max_hw, min_hw = max(image.size), min(image.size)\n                        aspect_ratio = max_hw / min_hw\n                        max_len, min_len = 800, 400\n                        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))\n                 
       longest_edge = int(shortest_edge * aspect_ratio)\n                        W, H = image.size\n                        if H > W:\n                            H, W = longest_edge, shortest_edge\n                        else:\n                            H, W = shortest_edge, longest_edge\n                        image = image.resize((W, H))\n                    elif image_process_mode == \"Resize\":\n                        image = image.resize((224, 224))\n                    else:\n                        raise ValueError(f\"Invalid image_process_mode: {image_process_mode}\")\n\n                    if return_pil:\n                        images.append(image)\n                    else:\n                        buffered = BytesIO()\n                        image.convert('RGB').save(buffered, format=\"JPEG\")\n                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()\n                        images.append(img_b64_str)\n        return images\n\n    def to_gradio_chatbot(self):\n        ret = []\n        for i, (role, msg) in enumerate(self.messages[self.offset:]):\n            if i % 2 == 0:\n                if type(msg) is tuple:\n                    import base64\n                    from io import BytesIO\n                    msg, image, image_process_mode = msg\n                    max_hw, min_hw = max(image.size), min(image.size)\n                    aspect_ratio = max_hw / min_hw\n                    max_len, min_len = 800, 400\n                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))\n                    longest_edge = int(shortest_edge * aspect_ratio)\n                    W, H = image.size\n                    if H > W:\n                        H, W = longest_edge, shortest_edge\n                    else:\n                        H, W = shortest_edge, longest_edge\n                    image = image.resize((W, H))\n                    # image = image.resize((224, 224))\n                    buffered = 
BytesIO()\n                    image.save(buffered, format=\"JPEG\")\n                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()\n                    img_str = f'<img src=\"data:image/png;base64,{img_b64_str}\" alt=\"user upload image\" />'\n                    msg = msg.replace('<image>', img_str)\n                ret.append([msg, None])\n            else:\n                ret[-1][-1] = msg\n        return ret\n\n    def copy(self):\n        return Conversation(\n            system=self.system,\n            roles=self.roles,\n            messages=[[x, y] for x, y in self.messages],\n            offset=self.offset,\n            sep_style=self.sep_style,\n            sep=self.sep,\n            sep2=self.sep2)\n\n    def dict(self):\n        if len(self.get_images()) > 0:\n            return {\n                \"system\": self.system,\n                \"roles\": self.roles,\n                \"messages\": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],\n                \"offset\": self.offset,\n                \"sep\": self.sep,\n                \"sep2\": self.sep2,\n            }\n        return {\n            \"system\": self.system,\n            \"roles\": self.roles,\n            \"messages\": self.messages,\n            \"offset\": self.offset,\n            \"sep\": self.sep,\n            \"sep2\": self.sep2,\n        }\n\n\nconv_v1 = Conversation(\n    system=\"A chat between a curious human and an artificial intelligence assistant. \"\n           \"The assistant gives helpful, detailed, and polite answers to the human's questions.\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n        (\"Human\", \"Give three tips for staying healthy.\"),\n        (\"Assistant\",\n            \"Sure, here are three tips for staying healthy:\\n\"\n            \"1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. 
\"\n            \"It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, \"\n            \"and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or \"\n            \"75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening \"\n            \"activities at least two days per week.\\n\"\n            \"2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, \"\n            \"vegetables, whole grains, lean proteins, and healthy fats can help support \"\n            \"your overall health. Try to limit your intake of processed and high-sugar foods, \"\n            \"and aim to drink plenty of water throughout the day.\\n\"\n            \"3. Get enough sleep: Getting enough quality sleep is essential for your physical \"\n            \"and mental health. Adults should aim for seven to nine hours of sleep per night. \"\n            \"Establish a regular sleep schedule and try to create a relaxing bedtime routine to \"\n            \"help improve the quality of your sleep.\")\n    ),\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nconv_v1_2 = Conversation(\n    system=\"A chat between a curious human and an artificial intelligence assistant. \"\n           \"The assistant gives helpful, detailed, and polite answers to the human's questions.\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n        (\"Human\", \"What are the key differences between renewable and non-renewable energy sources?\"),\n        (\"Assistant\",\n            \"Renewable energy sources are those that can be replenished naturally in a relatively \"\n            \"short amount of time, such as solar, wind, hydro, geothermal, and biomass. \"\n            \"Non-renewable energy sources, on the other hand, are finite and will eventually be \"\n            \"depleted, such as coal, oil, and natural gas. 
Here are some key differences between \"\n            \"renewable and non-renewable energy sources:\\n\"\n            \"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable \"\n            \"energy sources are finite and will eventually run out.\\n\"\n            \"2. Environmental impact: Renewable energy sources have a much lower environmental impact \"\n            \"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, \"\n            \"and other negative effects.\\n\"\n            \"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically \"\n            \"have lower operational costs than non-renewable sources.\\n\"\n            \"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote \"\n            \"locations than non-renewable sources.\\n\"\n            \"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different \"\n            \"situations and needs, while non-renewable sources are more rigid and inflexible.\\n\"\n            \"6. Sustainability: Renewable energy sources are more sustainable over the long term, while \"\n            \"non-renewable sources are not, and their depletion can lead to economic and social instability.\\n\")\n    ),\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nconv_vicuna_v1_1 = Conversation(\n    system=\"A chat between a curious user and an artificial intelligence assistant. 
\"\n    \"The assistant gives helpful, detailed, and polite answers to the user's questions.\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    version=\"v1\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.TWO,\n    sep=\" \",\n    sep2=\"</s>\",\n)\n\n# conv_mpt = Conversation(\n#     system=\"\"\"<|im_start|>system\n# - You are designed by Megvii(旷视), and your name is GOT.\n# - 你叫GOT, 你来自旷视, 你是旷视开发的。\n# - 你擅长分析表格，仔细读图表中的内容，然后给出你的答案。\"\"\",\n#     roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n#     version=\"mpt\",\n#     messages=(),\n#     offset=0,\n#     sep_style=SeparatorStyle.MPT,\n#     sep=\"<|im_end|>\",\n# )\n\nconv_mpt = Conversation(\n    system=\"\"\"<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail.\"\"\",\n    # system = None,\n    roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n    version=\"mpt\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.MPT,\n    sep=\"<|im_end|>\",\n)\n\nconv_mpt_eval = Conversation(\n    system=\"\",\n    roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n    version=\"mpt\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.MPT,\n    sep=\"<|im_end|>\",\n)\n\nconv_mpt_text = Conversation(\n    system=\"\"\"<|im_start|>system\n- You are a helpful assistant chatbot trained by MosaicML.\n- You answer questions.\n- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.\"\"\",\n    roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n    version=\"mpt\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.MPT,\n    sep=\"<|im_end|>\",\n)\n\nconv_bair_v1 = Conversation(\n    system=\"BEGINNING OF CONVERSATION:\",\n    roles=(\"USER\", \"GPT\"),\n    messages=(),\n    offset=0,\n    
sep_style=SeparatorStyle.TWO,\n    sep=\" \",\n    sep2=\"</s>\",\n)\n\n# simple_conv = Conversation(\n#     system=\"You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture.\"\n#            \"You are designed to assist human with a variety of tasks using natural language.\"\n#            \"Follow the instructions carefully.\",\n#     roles=(\"Human\", \"Assistant\"),\n#     messages=(\n#         (\"Human\", \"Hi!\"),\n#         (\"Assistant\", \"Hi there!  How can I help you today?\\n\")\n#     ),\n#     offset=2,\n#     sep_style=SeparatorStyle.SINGLE,\n#     sep=\"###\",\n# )\n\n\nsimple_conv = Conversation(\n    system=\"\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n    ),\n    offset=0,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nsimple_conv_multimodal = Conversation(\n    system=\"You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.\"\n           \"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\"\n           \"Follow the instructions carefully and explain your answers in detail.\",\n    # system=\"\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n        (\"Human\", \"Hi!\"),\n        (\"Assistant\", \"Hi there!  
How can I help you today?\\n\")\n    ),\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nsimple_conv_mpt_multimodal = Conversation(\n    system=\"\"\"<|im_start|>system\n- You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.\n- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n- You should follow the instructions carefully and explain your answers in detail.\"\"\",\n    roles=(\"<|im_start|>user\\n\", \"<|im_start|>assistant\\n\"),\n    version=\"mpt\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.MPT,\n    sep=\"<|im_end|>\",\n)\n\nsimple_conv_legacy = Conversation(\n    system=\"You are GOT, a large language model trained by Foundation Model Group, Megvii Technology.\"\n           \"You are designed to assist human with a variety of tasks using natural language.\"\n           \"Follow the instructions carefully.\",\n    roles=(\"Human\", \"Assistant\"),\n    messages=(\n        (\"Human\", \"Hi!\\n\\n### Response:\"),\n        (\"Assistant\", \"Hi there!  
How can I help you today?\\n\")\n    ),\n    offset=2,\n    sep_style=SeparatorStyle.SINGLE,\n    sep=\"###\",\n)\n\nconv_llava_v1 = Conversation(\n    system=\"You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.\"\n           \"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\"\n           \"Follow the instructions carefully and explain your answers in detail.\",\n    roles=(\"USER\", \"ASSISTANT\"),\n    version=\"v1\",\n    messages=(),\n    offset=0,\n    sep_style=SeparatorStyle.TWO,\n    sep=\" \",\n    sep2=\"</s>\",\n)\n\ndefault_conversation = conv_mpt\nconv_templates = {\n    \"default\": simple_conv_multimodal,\n    \"simple\": simple_conv,\n    \"simple_legacy\": simple_conv_legacy,\n    \"multimodal\": simple_conv,\n    \"mpt_multimodal\": simple_conv_mpt_multimodal,\n    \"llava_v1\": conv_llava_v1,\n    \"mpt_eval\": conv_mpt_eval,\n    # fastchat\n    \"v1\": conv_vicuna_v1_1,\n    \"bair_v1\": conv_bair_v1,\n    \"vicuna_v1_1\": conv_vicuna_v1_1,\n    \"mpt\": conv_mpt,\n    \"mpt_text\": conv_mpt_text,\n}\n\n\nif __name__ == \"__main__\":\n    print(default_conversation.get_prompt())\n"
  },
  {
    "path": "GOT-OCR-2.0-master/GOT/utils/utils.py",
    "content": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\nimport torch\nimport requests\n\nfrom transformers import StoppingCriteria\nfrom GOT.utils.constants import LOGDIR\n\nserver_error_msg = \"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**\"\nmoderation_msg = \"YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN.\"\n\nhandler = None\n\n\ndef build_logger(logger_name, logger_filename):\n    global handler\n\n    formatter = logging.Formatter(\n        fmt=\"%(asctime)s | %(levelname)s | %(name)s | %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n    )\n\n    # Set the format of root handlers\n    if not logging.getLogger().handlers:\n        logging.basicConfig(level=logging.INFO)\n    logging.getLogger().handlers[0].setFormatter(formatter)\n\n    # Redirect stdout and stderr to loggers\n    stdout_logger = logging.getLogger(\"stdout\")\n    stdout_logger.setLevel(logging.INFO)\n    sl = StreamToLogger(stdout_logger, logging.INFO)\n    sys.stdout = sl\n\n    stderr_logger = logging.getLogger(\"stderr\")\n    stderr_logger.setLevel(logging.ERROR)\n    sl = StreamToLogger(stderr_logger, logging.ERROR)\n    sys.stderr = sl\n\n    # Get logger\n    logger = logging.getLogger(logger_name)\n    logger.setLevel(logging.INFO)\n\n    # Add a file handler for all loggers\n    if handler is None:\n        os.makedirs(LOGDIR, exist_ok=True)\n        filename = os.path.join(LOGDIR, logger_filename)\n        handler = logging.handlers.TimedRotatingFileHandler(\n            filename, when='D', utc=True)\n        handler.setFormatter(formatter)\n\n        for name, item in logging.root.manager.loggerDict.items():\n            if isinstance(item, logging.Logger):\n                item.addHandler(handler)\n\n    return logger\n\n\nclass StreamToLogger(object):\n    \"\"\"\n    Fake file-like stream object that redirects writes to a logger instance.\n    \"\"\"\n    def __init__(self, 
logger, log_level=logging.INFO):\n        self.terminal = sys.stdout\n        self.logger = logger\n        self.log_level = log_level\n        self.linebuf = ''\n\n    def __getattr__(self, attr):\n        return getattr(self.terminal, attr)\n\n    def write(self, buf):\n        temp_linebuf = self.linebuf + buf\n        self.linebuf = ''\n        for line in temp_linebuf.splitlines(True):\n            # From the io.TextIOWrapper docs:\n            #   On output, if newline is None, any '\\n' characters written\n            #   are translated to the system default line separator.\n            # By default sys.stdout.write() expects '\\n' newlines and then\n            # translates them so this is still cross platform.\n            if line[-1] == '\\n':\n                self.logger.log(self.log_level, line.rstrip())\n            else:\n                self.linebuf += line\n\n    def flush(self):\n        if self.linebuf != '':\n            self.logger.log(self.log_level, self.linebuf.rstrip())\n        self.linebuf = ''\n\n\ndef disable_torch_init():\n    \"\"\"\n    Disable the redundant torch default initialization to accelerate model creation.\n    \"\"\"\n    import torch\n    setattr(torch.nn.Linear, \"reset_parameters\", lambda self: None)\n    setattr(torch.nn.LayerNorm, \"reset_parameters\", lambda self: None)\n\n\ndef violates_moderation(text):\n    \"\"\"\n    Check whether the text violates OpenAI moderation API.\n    \"\"\"\n    url = \"https://api.openai.com/v1/moderations\"\n    headers = {\"Content-Type\": \"application/json\",\n               \"Authorization\": \"Bearer \" + os.environ[\"OPENAI_API_KEY\"]}\n    text = text.replace(\"\\n\", \"\")\n    data = \"{\" + '\"input\": ' + f'\"{text}\"' + \"}\"\n    data = data.encode(\"utf-8\")\n    try:\n        ret = requests.post(url, headers=headers, data=data, timeout=5)\n        flagged = ret.json()[\"results\"][0][\"flagged\"]\n    except requests.exceptions.RequestException as e:\n        flagged = 
False\n    except KeyError as e:\n        flagged = False\n\n    return flagged\n\n\ndef pretty_print_semaphore(semaphore):\n    if semaphore is None:\n        return \"None\"\n    return f\"Semaphore(value={semaphore._value}, locked={semaphore.locked()})\"\n\n\nclass KeywordsStoppingCriteria(StoppingCriteria):\n    def __init__(self, keywords, tokenizer, input_ids):\n        self.keywords = keywords\n        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]\n        self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]\n        self.tokenizer = tokenizer\n        self.start_len = None\n        self.input_ids = input_ids\n\n    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:\n        if self.start_len is None:\n            self.start_len = self.input_ids.shape[1]\n        else:\n            for keyword_id in self.keyword_ids:\n                if output_ids[0, -1] == keyword_id:\n                    return True\n            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]\n            for keyword in self.keywords:\n                if keyword in outputs:\n                    return True\n        return False\n\n\ndef smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    # # num_new_tokens = 1\n    # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True)\n    # model.resize_token_embeddings(len(tokenizer))\n\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = 
model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(\n            dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n    \ndef maybe_zero_3(param, ignore_status=False, name=None):\n    from deepspeed import zero\n    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus\n    if hasattr(param, \"ds_id\"):\n        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:\n            if not ignore_status:\n                logging.warning(f\"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}\")\n        with zero.GatheredParameters([param]):\n            param = param.data.detach().cpu().clone()\n    else:\n        param = param.detach().cpu().clone()\n    return param\n\n\n# Borrowed from peft.utils.get_peft_model_state_dict\ndef get_peft_state_maybe_zero_3(named_params, bias):\n    if bias == \"none\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k}\n    elif bias == \"all\":\n        to_return = {k: t for k, t in named_params if \"lora_\" in k or \"bias\" in k}\n    elif bias == \"lora_only\":\n        to_return = {}\n        maybe_lora_bias = {}\n        lora_bias_names = set()\n        for k, t in named_params:\n            if \"lora_\" in k:\n                to_return[k] = t\n                bias_name = k.split(\"lora_\")[0] + \"bias\"\n                lora_bias_names.add(bias_name)\n            elif \"bias\" in k:\n                maybe_lora_bias[k] = t\n        for k, t in maybe_lora_bias:\n            if bias_name in lora_bias_names:\n                to_return[bias_name] = t\n    else:\n        raise NotImplementedError\n    to_return = {k: 
maybe_zero_3(v, name=k) for k, v in to_return.items()}\n    return to_return\n\n\ndef get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):\n    to_return = {k: t for k, t in named_params if \"lora_\" not in k}\n    if require_grad_only:\n        to_return = {k: t for k, t in to_return.items() if t.requires_grad}\n    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}\n    return to_return\n\n\ndef find_all_linear_names(model):\n    cls = torch.nn.Linear\n    lora_module_names = set()\n    for name, module in model.named_modules():\n        if isinstance(module, cls) and 'vision_model' not in name and 'mm_projector' not in name and 'vision_encoder' not in name and 'conv_final' not in name and'lm_head' not in name:\n            lora_module_names.add(name)\n\n    print(lora_module_names)\n    return list(lora_module_names)"
  },
  {
    "path": "GOT-OCR-2.0-master/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"GOT\"\nversion = \"0.1.0\"\ndescription = \"Towards OCR-2.0.\"\nreadme = \"README.md\"\nrequires-python = \">=3.8\"\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n]\ndependencies = [\n    \"markdown2[all]\", \"numpy\",\n    \"requests\", \"sentencepiece\", \"tokenizers>=0.15.2\",\n    \"torch\", \"torchvision\", \"wandb\",\n    \"shortuuid\", \"httpx==0.24.0\",\n    \"deepspeed==0.12.3\",\n    \"peft==0.4.0\",\n    \"albumentations\",\n    \"opencv-python\",\n    \"tiktoken==0.6.0\",\n    \"accelerate==0.28.0\",\n    \"transformers==4.37.2\",\n    \"bitsandbytes==0.41.0\",\n    \"scikit-learn==1.2.2\",\n    \"sentencepiece==0.1.99\",\n    \"einops==0.6.1\", \"einops-exts==0.0.4\", \"timm==0.6.13\",\n]\n\n[tool.setuptools.packages.find]\nexclude = [\"assets*\", \"benchmark*\", \"docs\", \"dist*\", \"playground*\", \"scripts*\", \"tests*\"]\n\n[tool.wheel]\nexclude = [\"assets*\", \"benchmark*\", \"docs\", \"dist*\", \"playground*\", \"scripts*\", \"tests*\"]\n"
  },
  {
    "path": "GOT-OCR-2.0-master/pyvenv.cfg",
    "content": "home = /usr/bin\nimplementation = CPython\nversion_info = 3.8.10.final.0\nvirtualenv = 20.16.7\ninclude-system-site-packages = true\nbase-prefix = /usr\nbase-exec-prefix = /usr\nbase-executable = /usr/bin/python3\n"
  },
  {
    "path": "GOT-OCR-2.0-master/render_tools/content-mmd-to-html.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\" data-lt-installed=\"true\"><head>\n  <meta charset=\"UTF-8\">\n  <title>Title</title>\n  <script>\n    const text = \n  </script>\n  <style>\n    #content {\n      max-width: 800px;\n      margin: auto;\n    }\n  </style>\n  <script>\n    let script = document.createElement('script');\n    script.src = \"https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js\";\n    document.head.append(script);\n\n    script.onload = function() {\n      const isLoaded = window.loadMathJax();\n      if (isLoaded) {\n        console.log('Styles loaded!')\n      }\n\n      const el = window.document.getElementById('content-text');\n      if (el) {\n        const options = {\n          htmlTags: true\n        };\n        const html = window.render(text, options);\n        el.outerHTML = html;\n      }\n    };\n  </script>\n</head>\n<body>\n  <div id=\"content\"><div id=\"content-text\"></div></div>\n</body>\n</html>\n"
  },
  {
    "path": "GOT-OCR-2.0-master/render_tools/tikz.html",
    "content": "<!DOCTYPE html>\r\n\r\n<html>\r\n\r\n<head>\r\n<meta charset=\"UTF-8\">\r\n<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\r\n<title>Document</title>\r\n<link rel=\"stylesheet\" type=\"text/css\" href=\"https://tikzjax.com/v1/fonts.css\">\r\n<script src=\"https://tikzjax.com/v1/tikzjax.js\"></script>\r\n</head>\r\n<body>\r\n<script type=\"text/tikz\">\r\nconst text =\r\n</script>\r\n</body>\r\n</html>"
  },
  {
    "path": "GOT-OCR-2.0-master/results/demo.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\" data-lt-installed=\"true\"><head>\n  <meta charset=\"UTF-8\">\n  <title>Title</title>\n  <script>\n    const text =\"\\\\title{\\n\"+\n\"MADRIX \\\\({ }^{\\\\circledR}\\\\) PLEXUS -\\n\"+\n\"}\\n\"+\n\"\\\\section*{Quick Start Guide \\\\& Technical Manual}\\n\"+\n\"\\\\(5^{\\\\text {th }}\\\\) Edition - November 2017\\n\"+\n\"Thank You For Purchasing MADRIK \\\\({ }^{\\\\circledR}\\\\) PLEXUS!\\n\"+\n\"Please read this guide carefully and thoroughly before using MADRIX \\\\({ }^{\\\\circledR}\\\\) PLEXUS. Make sure that you fully understand all information.\\n\"+\n\"This MADRIX \\\\({ }^{\\\\circledR}\\\\) PLEXUS Quick Start Guide and the MADRIX \\\\({ }^{\\\\circledR}\\\\) PLEXUS User Manual are written in English and German.\\n\"+\n\"Developed and made in Germany.\\n\"+\n\"\\\\section*{Imprint}\\n\"+\n\"inaage GmbH\\n\"+\n\"Wiener Straße 56\\n\"+\n\"01219 Dresden\\n\"+\n\"Germany\\n\"+\n\"Managing Directors: Christian Hertel, Sebastian Pinzer, Sebastian Wissmann\\n\"+\n\"Web www.madrix.com\\n\"+\n\"E-mail info@madrix.com\\n\"+\n\"Phone +4935186268690\\n\" \n  </script>\n  <style>\n    #content {\n      max-width: 800px;\n      margin: auto;\n    }\n  </style>\n  <script>\n    let script = document.createElement('script');\n    script.src = \"https://cdn.jsdelivr.net/npm/mathpix-markdown-it@1.3.6/es5/bundle.js\";\n    document.head.append(script);\n\n    script.onload = function() {\n      const isLoaded = window.loadMathJax();\n      if (isLoaded) {\n        console.log('Styles loaded!')\n      }\n\n      const el = window.document.getElementById('content-text');\n      if (el) {\n        const options = {\n          htmlTags: true\n        };\n        const html = window.render(text, options);\n        el.outerHTML = html;\n      }\n    };\n  </script>\n</head>\n<body>\n  <div id=\"content\"><div id=\"content-text\"></div></div>\n</body>\n</html>\n"
  },
  {
    "path": "GOT-OCR-2.0-master/zero_config/zero2.json",
    "content": "{\r\n    \"bf16\": {\r\n        \"enabled\": true\r\n    },\r\n    \"train_micro_batch_size_per_gpu\": \"auto\",\r\n    \"zero_optimization\": {\r\n        \"stage\": 2,\r\n        \"overlap_comm\": true,\r\n        \"contiguous_gradients\": true,\r\n        \"sub_group_size\": 1e9,\r\n        \"reduce_bucket_size\": \"auto\"\r\n    }\r\n}"
  },
  {
    "path": "GOT-OCR-2.0-master/zero_config/zero3.json",
    "content": "{\r\n    \"fp16\": {\r\n        \"enabled\": \"auto\",\r\n        \"loss_scale\": 0,\r\n        \"loss_scale_window\": 1000,\r\n        \"initial_scale_power\": 16,\r\n        \"hysteresis\": 2,\r\n        \"min_loss_scale\": 1\r\n    },\r\n    \"bf16\": {\r\n        \"enabled\": \"auto\"\r\n    },\r\n    \"train_micro_batch_size_per_gpu\": \"auto\",\r\n    \"train_batch_size\": \"auto\",\r\n    \"gradient_accumulation_steps\": \"auto\",\r\n    \"zero_optimization\": {\r\n        \"stage\": 3,\r\n        \"overlap_comm\": true,\r\n        \"contiguous_gradients\": true,\r\n        \"sub_group_size\": 1e9,\r\n        \"reduce_bucket_size\": \"auto\",\r\n        \"stage3_prefetch_bucket_size\": \"auto\",\r\n        \"stage3_param_persistence_threshold\": \"auto\",\r\n        \"stage3_max_live_parameters\": 1e9,\r\n        \"stage3_max_reuse_distance\": 1e9,\r\n        \"stage3_gather_16bit_weights_on_model_save\": true\r\n    }\r\n}"
  },
  {
    "path": "README.md",
    "content": "<h3><a href=\"\">General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model</a></h3>\n\n<a href=\"https://huggingface.co/ucaslcl/GOT-OCR2_0\"><img src=\"https://img.shields.io/badge/Huggingface-yellow\"></a>\n<a href=\"https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0\"><img src=\"https://img.shields.io/badge/Modelscope-red\"></a>\n<a href=\"https://arxiv.org/abs/2409.01704\"><img src=\"https://img.shields.io/badge/Paper-PDF-orange\"></a> \n<a href=\"https://zhuanlan.zhihu.com/p/718163422\"><img src=\"https://img.shields.io/badge/zhihu-red\"></a> \n<a href=\"https://huggingface.co/spaces/ucaslcl/GOT_online\"><img src=\"https://img.shields.io/badge/demo-green\"></a> \n\n[Haoran Wei*](https://scholar.google.com/citations?user=J4naK0MAAAAJ&hl=en), Chenglong Liu*, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, [Zheng Ge](https://joker316701882.github.io/), Liang Zhao, [Jianjian Sun](https://scholar.google.com/citations?user=MVZrGkYAAAAJ&hl=en), [Yuang Peng](https://yuangpeng.com), Chunrui Han, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ&hl=en)\n\n<p align=\"center\">\n<img src=\"assets/got_logo.png\" style=\"width: 200px\" align=center>\n</p>\n\n\n## Release\n- [2025/2/1] 🚀🚀🚀 GOT-OCR2.0 has been merged into [Huggingface-transformers](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)/[space](https://huggingface.co/spaces/yonigozlan/GOT-OCR-Transformers) and supports batched inference. Thanks to Huggingface MLE [Yoni](https://github.com/yonigozlan).\n- [2024/12/24] 🔥🔥🔥 My new work on system-2 perception, [slow-perception](https://github.com/Ucas-HaoranWei/Slow-Perception), has been released.\n- [2024/12/18] 🚀🚀🚀 GOT-OCR2.0 is supported in [PaddleMIX](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/paddlemix/examples/GOT_OCR_2_0) by the Paddle team. 
Thanks to the Paddle team!\n- [2024/12/8] 🔥🔥🔥 Model downloads have exceeded 1M on [Huggingface](https://huggingface.co/stepfun-ai/GOT-OCR2_0).\n- [2024/12/5] The seventh WeChat [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/Wechat7.jpg).\n- [2024/11/4] The sixth WeChat [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat6-2.jpg).\n- [2024/10/24] The previous four WeChat groups are full, so we created a fifth [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat5.png).\n- [2024/10/11] Too many friends want to join the WeChat group, so we created a fourth [group](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat4.jpg).\n- [2024/10/2] [ONNX](https://github.com/BaofengZan/GOT-OCRv2-onnx) and [MNN](https://github.com/BaofengZan/mnn-llm-GOT-OCR2.0) versions of GOT-OCR2.0 are available.\n- [2024/9/29]🔥🔥🔥 The community has implemented the first version of [llama_cpp_inference](https://github.com/1694439208/GOT-OCR-Inference).\n- [2024/9/24]🔥🔥🔥 Support quick [fine-tuning](#fine-tune) on your own data with [ms-swift](https://github.com/modelscope/ms-swift/issues/2122).\n- [2024/9/23]🔥🔥🔥 We release the official [Modelscope demo](https://modelscope.cn/studios/stepfun-ai/GOT_official_online_demo). Many thanks to Modelscope for providing the GPU resources.\n- [2024/9/19]🔥🔥🔥 GOT-OCR2.0 reaches #1 on Huggingface trending.\n- [2024/9/14]🔥🔥🔥 We release the official [demo](https://huggingface.co/spaces/ucaslcl/GOT_online). Many thanks to Huggingface for providing the GPU resources.\n- [2024/9/13]🔥🔥🔥 We release the [Huggingface](https://huggingface.co/ucaslcl/GOT-OCR2_0) deployment.\n- [2024/9/03]🔥🔥🔥 We open-source the code, weights, and benchmarks. The paper can be found in this [repo](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/GOT-OCR-2.0-paper.pdf). We have also submitted it to arXiv.\n- [2024/9/03]🔥🔥🔥 We release the OCR-2.0 model GOT! 
\n\n\n[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)\n[![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)\n\n\n\n\n## Community contributions\nWe encourage everyone to develop GOT applications based on this repo. Thanks to the following contributors:\n\n[OpenVINO](https://github.com/can-gaa-hou/GOT-OCR2.0-OpenVINO) ~ contributor: [@can-gaa-hou](https://github.com/can-gaa-hou)\n\n[GGUF and Llama.cpp inference](https://github.com/MosRat/got.cpp) ~ contributor: [@MosRat](https://github.com/MosRat)\n\n[vLLM reference](https://github.com/liunian-Jay/MU-GOT/blob/master/PDF_parsing/GOT/GOT/model/modeling_GOT_vllm.py) ~ contributor: [@Jay](https://github.com/liunian-Jay)\n\n[ONNX and MNN support](https://github.com/BaofengZan/GOT-OCRv2-onnx) ~ contributor: [@BaofengZan](https://github.com/BaofengZan)\n\n[llama_cpp inference](https://github.com/1694439208/GOT-OCR-Inference) ~ contributor: [@1694439208](https://github.com/1694439208)\n\n[Colab of GOT](https://colab.research.google.com/drive/1nmiNciZ5ugQVp4rFbL9ZWpEPd92Y9o7p?usp=sharing) ~ contributor: [@Zizhe Wang](https://github.com/PaperPlaneDeemo)\n\n[CPU version of GOT](https://github.com/ElvisClaros/GOT-OCR2.0) ~ contributor: [@ElvisClaros](https://github.com/ElvisClaros)\n\n[Online demo](https://huggingface.co/spaces/Tonic/GOT-OCR) ~ contributor: [@Joseph Pollack](https://huggingface.co/Tonic)\n\n[Docker & client demo](https://github.com/QIN2DIM/GOT-OCR2.0) ~ contributor: [@QIN2DIM](https://github.com/QIN2DIM)\n\n[GUI of GOT](https://github.com/XJF2332/GOT-OCR-2-GUI) ~ contributor: [@XJF2332](https://github.com/XJF2332)\n\n## Contents\n- [Install](#install)\n- [GOT Weights](#got-weights)\n- [Benchmarks](#benchmarks)\n- [Demo](#demo)\n- [Train](#train)\n- [Fine-tune](#fine-tune)\n- 
[Eval](#eval)\n\n***\n<p align=\"center\">\n<img src=\"assets/got_support.jpg\" style=\"width: 800px\" align=center>\n</p>\n<p align=\"center\">\n<a href=\"https://arxiv.org/abs/2409.01704\">Towards OCR-2.0 via a Unified End-to-end Model</a>\n</p>\n\n***\n\n\n## Install\n0. Our environment is CUDA 11.8 + torch 2.0.1.\n1. Clone this repository and navigate to the GOT folder:\n```bash\ngit clone https://github.com/Ucas-HaoranWei/GOT-OCR2.0.git\ncd 'the GOT folder'\n```\n2. Install the package:\n```Shell\nconda create -n got python=3.10 -y\nconda activate got\npip install -e .\n```\n\n3. Install Flash-Attention:\n```\npip install ninja\npip install flash-attn --no-build-isolation\n```\n## GOT Weights\n- [Huggingface](https://huggingface.co/ucaslcl/GOT-OCR2_0)\n- [Google Drive](https://drive.google.com/drive/folders/1OdDtsJ8bFJYlNUzCQG4hRkUL6V-qBQaN?usp=sharing)\n- [BaiduYun](https://pan.baidu.com/s/1G4aArpCOt6I_trHv_1SE2g) code: OCR2\n\n## Benchmarks\n- [Google Drive](https://drive.google.com/drive/folders/1OdDtsJ8bFJYlNUzCQG4hRkUL6V-qBQaN?usp=sharing)\n- [BaiduYun](https://pan.baidu.com/s/1G4aArpCOt6I_trHv_1SE2g) code: OCR2\n\n## Demo\n1. Plain-text OCR:\n```Shell\npython3 GOT/demo/run_ocr_2.0.py  --model-name  /GOT_weights/  --image-file  /an/image/file.png  --type ocr\n```\n2. Formatted-text OCR:\n```Shell\npython3 GOT/demo/run_ocr_2.0.py  --model-name  /GOT_weights/  --image-file  /an/image/file.png  --type format\n```\n3. Fine-grained OCR:\n```Shell\npython3 GOT/demo/run_ocr_2.0.py  --model-name  /GOT_weights/  --image-file  /an/image/file.png  --type format/ocr --box [x1,y1,x2,y2]\n```\n```Shell\npython3 GOT/demo/run_ocr_2.0.py  --model-name  /GOT_weights/  --image-file  /an/image/file.png  --type format/ocr --color red/green/blue\n```\n4. Multi-crop OCR:\n```Shell\npython3 GOT/demo/run_ocr_2.0_crop.py  --model-name  /GOT_weights/ --image-file  /an/image/file.png \n```\n5. **Note**: This feature is not batch inference! It works at the token level.  
Please read the paper and then use multi-page OCR correctly (the image path contains multiple .png files):\n```Shell\npython3 GOT/demo/run_ocr_2.0_crop.py  --model-name  /GOT_weights/ --image-file  /images/path/  --multi-page\n```\n6. Render the formatted OCR results:\n```Shell\npython3 GOT/demo/run_ocr_2.0.py  --model-name  /GOT_weights/  --image-file  /an/image/file.png  --type format --render\n```\n**Note**:\nThe rendering results can be found in /results/demo.html. Please open demo.html in a browser to see the results.\n\n\n## Train\n0. A training sample can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/train_sample.jpg). Note that the '\\<image>' placeholder in 'conversations' -> 'human' -> 'value' is required!\n1. This codebase only supports post-training (stage-2/stage-3) on top of our GOT weights.\n2. If you want to train from stage-1 as described in our paper, you need this [repo](https://github.com/Ucas-HaoranWei/Vary-tiny-600k).\n\n```Shell\ndeepspeed   /GOT-OCR-2.0-master/GOT/train/train_GOT.py \\\n --deepspeed /GOT-OCR-2.0-master/zero_config/zero2.json    --model_name_or_path /GOT_weights/ \\\n --use_im_start_end True   \\\n --bf16 True   \\\n --gradient_accumulation_steps 2    \\\n --evaluation_strategy \"no\"   \\\n --save_strategy \"steps\"  \\\n --save_steps 200   \\\n --save_total_limit 1   \\\n --weight_decay 0.    \\\n --warmup_ratio 0.001     \\\n --lr_scheduler_type \"cosine\"    \\\n --logging_steps 1    \\\n --tf32 True     \\\n --model_max_length 8192    \\\n --gradient_checkpointing True   \\\n --dataloader_num_workers 8    \\\n --report_to none  \\\n --per_device_train_batch_size 2    \\\n --num_train_epochs 1  \\\n --learning_rate 2e-5   \\\n --datasets pdf-ocr+scence \\\n --output_dir /your/output/path\n```\n\n\n**Note**:\n1. Change the corresponding data information in [constants.py](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/tree/main/GOT-OCR-2.0-master/GOT/utils).\n2. 
Change line 37 in [conversation_dataset_qwen.py](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/tree/main/GOT-OCR-2.0-master/GOT/data) to your data_name.\n\n## Fine-tune\nQuick fine-tuning with ms-swift:\n\n```Shell\ngit clone https://github.com/modelscope/ms-swift.git\ncd ms-swift\npip install -e '.[llm]'\n```\n```Shell\n# default: sft LLM & projector, freeze vision encoder\nCUDA_VISIBLE_DEVICES=0 swift sft \\\n--model_type got-ocr2 \\\n--model_id_or_path stepfun-ai/GOT-OCR2_0 \\\n--sft_type lora \\\n--dataset latex-ocr-print#5000\n\n# DeepSpeed ZeRO-2\nNPROC_PER_NODE=4 \\\nCUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \\\n--model_type got-ocr2 \\\n--model_id_or_path stepfun-ai/GOT-OCR2_0 \\\n--sft_type lora \\\n--dataset latex-ocr-print#5000 \\\n--deepspeed default-zero2\n```\n\n**With your data**:\n```Shell\n--dataset train.jsonl\n--val_dataset val.jsonl (optional)\n```\n**Data format**:\n```Shell\n{\"query\": \"<image>55555\", \"response\": \"66666\", \"images\": [\"image_path\"]}\n{\"query\": \"<image><image>eeeee\", \"response\": \"fffff\", \"history\": [], \"images\": [\"image_path1\", \"image_path2\"]}\n{\"query\": \"EEEEE\", \"response\": \"FFFFF\", \"history\": [[\"query1\", \"response1\"], [\"query2\", \"response2\"]]}\n```\nMore details can be found in [ms-swift](https://github.com/modelscope/ms-swift/issues/2122).\n\n## Eval\n1. We use the [Fox](https://github.com/ucaslcl/Fox) and [OneChart](https://github.com/LingyvKong/OneChart) benchmarks; other benchmarks can be found in the weights download link.\n2. The eval code can be found in GOT/eval.\n3. You can use evaluate_GOT.py to run the evaluation. 
If you have 8 GPUs, the --num-chunks can be set to 8.\n```Shell\npython3 GOT/eval/evaluate_GOT.py --model-name /GOT_weights/ --gtfile_path xxxx.json --image_path  /image/path/ --out_path /data/eval_results/GOT_mathpix_test/ --num-chunks 8 --datatype OCR\n```\n\n## Contact\nIf you are interested in this work or have questions about the code or the paper, please join our [WeChat](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/wechat.jpg) group.\n\n**Note**:\nAll six WeChat groups are full; please join [group 7](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/blob/main/assets/Wechat7.jpg).\n\nDon't hesitate to contact me by email at weihaoran18@mails.ucas.ac.cn if you have any questions.\n\n## Acknowledgement\n- [Vary](https://github.com/Ucas-HaoranWei/Vary/): the codebase we built upon!\n- [Qwen](https://github.com/QwenLM/Qwen): the LLM base model of Vary, which is good at both English and Chinese!\n\n\n## Citation\n```bibtex\n@article{wei2024general,\n  title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model},\n  author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others},\n  journal={arXiv preprint arXiv:2409.01704},\n  year={2024}\n}\n@article{wei2023vary,\n  title={Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models},\n  author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},\n  journal={arXiv preprint arXiv:2312.06109},\n  year={2023}\n}\n\n\n"
  }
]