[
  {
    "path": ".gitignore",
    "content": "__pycache__\n.ipynb_checkpoints\n.pytest_cache\n.vim-arsync\n*.log\nresult.json\nweights/\nwandb/\n.cache/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Alexander Smirnov\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Makefile",
    "content": "LINE_WIDTH=120\nISORT_FLAGS=--line-width=${LINE_WIDTH} --profile black\nBLACK_FLAGS=--line-length=${LINE_WIDTH}\n\ninstall:\n\tpip install -r requirements.txt\n\ninstall-format:\n\tpip install black isort\n\nformat:\n\tisort ${ISORT_FLAGS} --check-only --diff .\n\tblack ${BLACK_FLAGS} --check --diff .\n\nformat-fix:\n\tisort ${ISORT_FLAGS} .\n\tblack ${BLACK_FLAGS} .\n"
  },
  {
    "path": "README.md",
    "content": "# Doppelganger\n\nFine-tuning LLM on my Telegram chats. You may read full story [in my blog](https://asmirnov.xyz/doppelganger).\n\n## Dataset Preparation\n\nFirst, we have to get the data. Open Telegram, go to 'Setting' -> 'Advanced' -> 'Export Telegram Data' and unselect everything except 'Personal chats' and 'Private groups' (don't select 'Only my messages there'). As output format choose 'Machine-readable JSON'. It will result in `result.json`.\n\nUse `prepare_dataset.py` to transform `result.json` to JSON with a list of sessions:\n\n```bash\npython prepare_dataset.py \"./data/result.json\" \"./data/messages.json\"\n```\n\nThere are some flags available for this script, you can read more in `--help`:\n\n```bash\npython prepare_dataset.py --help\n```\n\n<details>\n<summary>output</summary>\n\n```\nNAME\n    prepare_dataset.py - Transforms chat histories from .json telegram export to .json with a list of sessions. Session is a list of messages, where each message is a dict with fields 'author' and 'text'.\n\nSYNOPSIS\n    prepare_dataset.py INPUT OUTPUT <flags>\n\nDESCRIPTION\n    Transforms chat histories from .json telegram export to .json with a list of sessions. Session is a list of messages, where each message is a dict with fields 'author' and 'text'.\n\nPOSITIONAL ARGUMENTS\n    INPUT\n        Type: str\n        Path to .json telegram export, usually called result.json\n    OUTPUT\n        Type: str\n        Path to output .json file\n\nFLAGS\n    -t, --target_name=TARGET_NAME\n        Type: Optional[str | None]\n        Default: None\n        The name of the person to target. This person will be present in every session. If empty, will be tried to be detected from \"Saved Messages\"\n    -l, --last_x_months=LAST_X_MONTHS\n        Type: int\n        Default: 24\n        Number of last months to use messages from\n    -s, --session_minutes_threshold=SESSION_MINUTES_THRESHOLD\n        Type: int\n        Default: 10\n        Threshold in minutes where messages will belong to the same session\n    -c, --concat_one_user_messages_delimeter=CONCAT_ONE_USER_MESSAGES_DELIMETER\n        Type: str\n        Default: '\\n>>> '\n        Users might type several messages one after each other. 
They are concatenated using this delimeter\n\nNOTES\n    You can also use flags syntax for POSITIONAL ARGUMENTS\n```\n\n</details>\n\nIf you are interested, Telegram have several types of messages which should be handled differently:\n\n<details>\n<summary>default text message</summary>\n\n```\n{\n \"id\": 123,\n \"type\": \"message\",\n \"date\": \"2023-10-31T15:23:38\",\n \"date_unixtime\": \"1698746018\",\n \"from\": \"Username\",\n \"from_id\": \"user123\",\n \"text\": \"ты где?\",\n \"text_entities\": [\n  {\n   \"type\": \"plain\",\n   \"text\": \"ты где?\"\n  }\n ]\n}\n```\n\n</details>\n\n<details>\n<summary>multiple text entities</summary>\n\n```\n{\n \"id\": 345,\n \"type\": \"message\",\n \"date\": \"2023-10-25T01:56:50\",\n \"date_unixtime\": \"1698179210\",\n \"from\": \"Username\",\n \"from_id\": \"user456\",\n \"text\": [\n  \"California suspends GM Cruise's autonomous vehicle deployment | Hacker News\\n\",\n  {\n   \"type\": \"link\",\n   \"text\": \"https://news.ycombinator.com/item?id=38002752\"\n  }\n ],\n \"text_entities\": [\n  {\n   \"type\": \"plain\",\n   \"text\": \"California suspends GM Cruise's autonomous vehicle deployment | Hacker News\\n\"\n  },\n  {\n   \"type\": \"link\",\n   \"text\": \"https://news.ycombinator.com/item?id=38002752\"\n  }\n ]\n}\n```\n\n</details>\n\n<details>\n<summary>sticker</summary>\n\n```\n{\n \"id\": 789,\n \"type\": \"message\",\n \"date\": \"2023-10-30T23:24:20\",\n \"date_unixtime\": \"1698688460\",\n \"from\": \"Username\",\n \"from_id\": \"user789\",\n \"file\": \"(File not included. Change data exporting settings to download.)\",\n \"thumbnail\": \"(File not included. Change data exporting settings to download.)\",\n \"media_type\": \"sticker\",\n \"sticker_emoji\": \"🤗\",\n \"width\": 512,\n \"height\": 501,\n \"text\": \"\",\n \"text_entities\": []\n}\n```\n\n</details>\n\n## Training\n\nFinal version of models were trained with the parameters which are default in training scripts. Training logs can be accessed on [WandB](https://wandb.ai/furiousteabag/doppelganger).\n\n### LoRA fine-tune\n\nTo launch LoRA fine-tune with my default params, you will need GPU with 20GB VRAM. RTX 3090 is a good option for it's money. You may reduce `micro_batch_size` or `max_seq_length` if you want to lower the amount of VRAM required. To get full list of parameters, run:\n\n```\npython finetune_lora.py --help\n```\n\nTo train LoRA, run:\n\n```\npython finetune_lora.py\n```\n\n### Full fine-tune\n\nTo list available params with their default values, run:\n\n```\npython finetune_full.py --help\n```\n\nTo train:\n\n```\ntorchrun --nnodes=1 --nproc_per_node=NUMBER_OF_GPUS finetune_full.py\n```\n\n## Launching\n\nUse [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui). If you used LoRA, then clone [ehartford/dolphin-2.2.1-mistral-7b](https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b) or whatever model you are used as a base model and put trained LoRA connectors to `./loras/` folder within text-generation-webui. If you did full fine-tune, then copy training result to `./models/`.\n"
  },
  {
    "path": "data/.gitignore",
    "content": "*\n!.gitignore\n"
  },
  {
    "path": "finetune_full.py",
    "content": "import os\nfrom typing import Dict, List\n\nimport fire\nimport torch\nimport wandb\nfrom datasets import Dataset\nfrom loguru import logger\nfrom tokenizers import AddedToken\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, TrainingArguments\nfrom trl import SFTTrainer\n\nfrom utils.finetune_utils import DataCollatorForLanguageModelingChatML, prepare_dataset, print_trainable_parameters\n\nDEFAULT_BOS_TOKEN = \"<s>\"\nDEFAULT_EOS_TOKEN = \"<|im_end|>\"\nDEFAULT_UNK_TOKEN = \"<unk>\"\nDEFAULT_PAD_TOKEN = \"</s>\"\n\n\ndef smart_tokenizer_and_embedding_resize(\n    special_tokens_dict: Dict, tokenizer: PreTrainedTokenizer, model: PreTrainedModel, tokens_list: List = []\n):\n    \"\"\"Resize tokenizer and embedding.\n\n    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.\n    \"\"\"\n    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(tokens_list)\n    model.resize_token_embeddings(len(tokenizer))\n\n    if num_new_tokens > 0:\n        input_embeddings = model.get_input_embeddings().weight.data\n        output_embeddings = model.get_output_embeddings().weight.data\n\n        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)\n        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)\n\n        input_embeddings[-num_new_tokens:] = input_embeddings_avg\n        output_embeddings[-num_new_tokens:] = output_embeddings_avg\n\n\ndef train(\n    model_name_or_path: str = \"mistralai/Mistral-7B-v0.1\",\n    data_path: str = \"./data/messages.json\",\n    output_dir: str = \"./weights/full/\",\n    gradient_accumulation_steps: int = 4,\n    micro_batch_size: int = 2,\n    num_epochs: int = 3,\n    learning_rate: float = 2e-5,\n    lr_scheduler_type: str = \"cosine\",\n    warmup_ratio: float = 0.03,\n    weight_decay: float = 0.0,\n    max_seq_length: int = 1024,\n    fsdp: str = \"full_shard auto_wrap\",\n    fsdp_transformer_layer_cls_to_wrap: str = \"MistralDecoderLayer\",\n    wandb_project: str = \"doppelganger\",\n    logging_steps: int = 1,\n):\n    if int(os.environ.get(\"LOCAL_RANK\", 0)) == 0:\n        batch_size = micro_batch_size * gradient_accumulation_steps * int(os.environ.get(\"WORLD_SIZE\", 1))\n        logger.info(f\"Total batch size: {batch_size}\")\n\n    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)\n    model.config.use_cache = False\n\n    print_trainable_parameters(model)\n\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_name_or_path, model_max_length=max_seq_length, padding_side=\"right\", use_fast=False\n    )\n\n    special_tokens_dict = dict()\n    if tokenizer.bos_token is None:\n        special_tokens_dict[\"bos_token\"] = DEFAULT_BOS_TOKEN\n    if tokenizer.pad_token is None:\n        special_tokens_dict[\"pad_token\"] = DEFAULT_PAD_TOKEN\n    if tokenizer.unk_token is None:\n        special_tokens_dict[\"unk_token\"] = DEFAULT_UNK_TOKEN\n    special_tokens_dict[\"eos_token\"] = DEFAULT_EOS_TOKEN\n    tokens_list = [AddedToken(\"<|im_start|>\", normalized=False)]\n\n    smart_tokenizer_and_embedding_resize(\n        special_tokens_dict=special_tokens_dict,\n        tokens_list=tokens_list,\n        tokenizer=tokenizer,\n        model=model,\n    )\n\n    data_collator = DataCollatorForLanguageModelingChatML(tokenizer=tokenizer)\n    dataset = 
Dataset.from_dict({\"session\": prepare_dataset(data_path)})\n\n    if int(os.environ.get(\"LOCAL_RANK\", 0)) == 0:\n        wandb.init(project=wandb_project, job_type=\"train\", anonymous=\"allow\")\n\n    training_arguments = TrainingArguments(\n        output_dir=output_dir,\n        num_train_epochs=num_epochs,\n        per_device_train_batch_size=micro_batch_size,\n        gradient_accumulation_steps=gradient_accumulation_steps,\n        optim=\"adamw_torch\",\n        save_steps=500,\n        logging_steps=logging_steps,\n        logging_first_step=True,\n        learning_rate=learning_rate,\n        weight_decay=weight_decay,\n        bf16=True,\n        max_steps=-1,\n        warmup_ratio=warmup_ratio,\n        group_by_length=True,\n        lr_scheduler_type=lr_scheduler_type,\n        fsdp=fsdp,\n        fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,\n        report_to=[\"wandb\"] if int(os.environ.get(\"LOCAL_RANK\", 0)) == 0 else [],\n    )\n\n    trainer = SFTTrainer(\n        model=model,\n        tokenizer=tokenizer,\n        train_dataset=dataset,\n        data_collator=data_collator,\n        dataset_text_field=\"session\",\n        max_seq_length=max_seq_length,\n        packing=False,\n        args=training_arguments,\n    )\n\n    trainer.train()\n    trainer.save_state()\n    trainer.save_model(output_dir)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(train)\n"
  },
  {
    "path": "finetune_lora.py",
    "content": "import fire\nimport torch\nimport wandb\nfrom datasets import Dataset\nfrom loguru import logger\nfrom peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, MistralForCausalLM, TrainingArguments\nfrom trl import SFTTrainer\n\nfrom utils.finetune_utils import DataCollatorForLanguageModelingChatML, prepare_dataset, print_trainable_parameters\n\n\ndef train(\n    model_name_or_path: str = \"ehartford/dolphin-2.2.1-mistral-7b\",\n    data_path: str = \"./data/messages.json\",\n    output_dir: str = \"./weights/LoRA/\",\n    batch_size: int = 16,\n    micro_batch_size: int = 8,\n    num_epochs: int = 3,\n    lora_r: int = 32,\n    lora_alpha: int = 16,\n    lora_dropout: float = 0.05,\n    learning_rate: float = 2e-4,\n    lr_scheduler_type: str = \"cosine\",\n    warmup_ratio: float = 0.03,\n    weight_decay: float = 0.001,\n    max_seq_length: int = 1024,\n    wandb_project: str = \"doppelganger\",\n):\n    gradient_accumulation_steps = batch_size // micro_batch_size\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n    data_collator = DataCollatorForLanguageModelingChatML(tokenizer=tokenizer)\n\n    dataset = Dataset.from_dict({\"session\": prepare_dataset(data_path)})\n\n    # print(dataset[2500][\"session\"])\n    # collator_res = data_collator([tokenizer(dataset[\"session\"][2500], return_tensors=\"pt\")[\"input_ids\"][0]])\n    # print(tokenizer.decode(collator_res[\"labels\"][0][collator_res[\"labels\"][0] != -100]))\n\n    model: MistralForCausalLM = AutoModelForCausalLM.from_pretrained(\n        model_name_or_path,\n        quantization_config=BitsAndBytesConfig(\n            load_in_4bit=True,\n            bnb_4bit_quant_type=\"nf4\",\n            bnb_4bit_compute_dtype=torch.bfloat16,\n            bnb_4bit_use_double_quant=False,\n        ),\n        device_map={\"\": 0},\n    )\n    model.config\n    model.gradient_checkpointing_enable()\n\n    peft_config = LoraConfig(\n        r=lora_r,\n        lora_alpha=lora_alpha,\n        lora_dropout=lora_dropout,\n        bias=\"none\",\n        task_type=\"CAUSAL_LM\",\n        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\"],\n    )\n\n    model = prepare_model_for_kbit_training(model)\n    logger.info(model)\n\n    model = get_peft_model(model, peft_config)\n    logger.info(model)\n    print_trainable_parameters(model)\n\n    wandb.init(project=wandb_project, job_type=\"train\", anonymous=\"allow\")\n\n    training_arguments = TrainingArguments(\n        output_dir=output_dir,\n        num_train_epochs=num_epochs,\n        per_device_train_batch_size=micro_batch_size,\n        gradient_accumulation_steps=gradient_accumulation_steps,\n        optim=\"paged_adamw_8bit\",\n        save_steps=500,\n        logging_steps=10,\n        logging_first_step=True,\n        learning_rate=learning_rate,\n        weight_decay=weight_decay,\n        fp16=False,\n        bf16=False,\n        max_grad_norm=0.3,\n        max_steps=-1,\n        warmup_ratio=warmup_ratio,\n        group_by_length=True,\n        lr_scheduler_type=lr_scheduler_type,\n        report_to=[\"wandb\"],\n    )\n\n    trainer = SFTTrainer(\n        model=model,\n        train_dataset=dataset,\n        peft_config=peft_config,\n        max_seq_length=max_seq_length,\n        dataset_text_field=\"session\",\n        tokenizer=tokenizer,\n        args=training_arguments,\n        packing=False,\n        
data_collator=data_collator,\n    )\n\n    trainer.train()\n    trainer.model.save_pretrained(output_dir)\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    fire.Fire(train)\n"
  },
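  {
    "path": "generate_lora.py",
    "content": "\"\"\"Minimal, illustrative sketch (not part of the original training pipeline): chat with the LoRA\nfine-tune directly from Python instead of text-generation-webui. The base model, adapter path and\nnames below are assumptions matching the defaults of finetune_lora.py and the ChatML format from\nutils/finetune_utils.py; adjust them to your setup.\"\"\"\n\nimport fire\nimport torch\nfrom peft import PeftModel\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n\ndef chat(\n    base_model: str = \"ehartford/dolphin-2.2.1-mistral-7b\",\n    adapter_path: str = \"./weights/LoRA/\",\n    author_name: str = \"John Smith\",\n    target_name: str = \"Alexander Smirnov\",\n    message: str = \">>> hey, what's up?\",\n    max_new_tokens: int = 128,\n):\n    tokenizer = AutoTokenizer.from_pretrained(base_model)\n    model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16, device_map=\"auto\")\n    # Attach the adapter saved by finetune_lora.py\n    model = PeftModel.from_pretrained(model, adapter_path)\n    model.eval()\n\n    # Same ChatML layout the data collator was trained on\n    prompt = f\"<|im_start|>{author_name}\\n{message}<|im_end|>\\n<|im_start|>{target_name}\\n\"\n    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n    with torch.no_grad():\n        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.8)\n    print(tokenizer.decode(output[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True))\n\n\nif __name__ == \"__main__\":\n    fire.Fire(chat)\n"
  },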
  {
    "path": "prepare_dataset.py",
    "content": "import json\nfrom datetime import datetime, timedelta\nfrom typing import List, Literal, Optional, Tuple\n\nimport fire\nfrom loguru import logger\nfrom pydantic import BaseModel\nfrom tqdm import tqdm\n\n\nclass Message(BaseModel):\n    date: datetime\n    author: str\n    text: str\n\n\nclass Chat(BaseModel):\n    name: str\n    type: Literal[\"personal_chat\", \"private_group\", \"private_supergroup\"]\n    messages: List[Message]\n    sessions: Optional[List[List[Message]]] = []\n\n\ndef load_chats(path: str) -> Tuple[List[Chat], Tuple[int | None, str | None]]:\n    chats: List[Chat] = []\n    target_id, target_name = None, None\n    logger.info(f\"Loading chats from '{path}'...\")\n    with open(path, \"r\") as f:\n        for chat in json.load(f)[\"chats\"][\"list\"]:\n            # It means we encountered 'Saved Messages', from which\n            # we can extract id and a name of a target person\n            if \"name\" not in chat:\n                target_id = int(chat[\"id\"])\n                target_name = str(next(msg for msg in chat[\"messages\"] if msg[\"from_id\"] == f\"user{target_id}\")[\"from\"])\n            # If chat does not contain name that means we\n            # encountered \"Deleted Account\"\n            elif chat[\"name\"]:\n                messages = [\n                    Message(\n                        date=msg[\"date\"],\n                        author=msg[\"from\"],\n                        text=\"\".join([text_entity[\"text\"] for text_entity in msg[\"text_entities\"]])\n                        + msg.get(\"sticker_emoji\", \"\"),\n                    )\n                    for msg in chat[\"messages\"]\n                    if \"from\" in msg and msg[\"from\"] and (msg[\"text_entities\"] or \"sticker_emoji\" in msg)\n                ]\n                if messages:\n                    chat = Chat(name=chat[\"name\"], type=chat[\"type\"], messages=messages)\n                    chats.append(chat)\n    logger.info(f\"Found {len(chats)} chats in file '{path}'\")\n    if not target_name:\n        logger.warning(f\"Was not able to detect target name from 'Saved Messages'!\")\n    return chats, (target_id, target_name)\n\n\ndef transform_chats(\n    input: str,\n    output: str,\n    target_name: str | None = None,\n    last_x_months: int = 60,\n    session_minutes_threshold: int = 10,\n    concat_one_user_messages_delimeter: str = \"\\n>>> \",\n):\n    \"\"\"\n    Transforms chat histories from .json telegram export to .json with a list of sessions.\n    Session is a list of messages, where each message is a dict with fields 'author' and 'text'.\n\n    :param input: Path to .json telegram export, usually called result.json\n    :param output: Path to output .json file\n    :param target_name: The name of the person to target. This person will be present in every session. If empty, will be tried to be detected from \"Saved Messages\"\n    :param last_x_months: Number of last months to use messages from\n    :param session_minutes_threshold: Threshold in minutes where messages will belong to the same session\n    :param concat_one_user_messages_delimeter: Users might type several messages one after each other. 
They are concatenated using this delimeter\n    \"\"\"\n    chats, (_, extracted_target_name) = load_chats(input)\n    if not target_name:\n        target_name = extracted_target_name\n    logger.info(f\"Preparing dataset for user with name '{target_name}'...\")\n\n    # Dropping all messages which are older than given date\n    for chat in chats:\n        chat.messages = [msg for msg in chat.messages if msg.date > datetime.now() - timedelta(days=last_x_months * 30)]\n    chats = [chat for chat in chats if chat.messages]\n    logger.info(f\"After filtering by date, there are {len(chats)} chats left\")\n\n    # Create sessions for each chat by combining messages within\n    # session_minutes_threshold into one session\n    for chat in chats:\n        sessions = []\n        current_session = []\n        for msg in chat.messages:\n            if not current_session or (msg.date - current_session[-1].date).seconds / 60 < session_minutes_threshold:\n                current_session.append(msg)\n            else:\n                sessions.append(current_session)\n                current_session = [msg]\n        if current_session:\n            sessions.append(current_session)\n        chat.sessions = sessions\n    logger.info(f\"Combined messages into sessions\")\n\n    # Combine consecutive messages from single user into one message\n    for chat in chats:\n        sessions = []\n        for session in chat.sessions:\n            current_session = []\n            current_message = session[0]\n            current_message.text = concat_one_user_messages_delimeter.lstrip() + current_message.text\n            for msg in session[1:]:\n                if msg.author == current_message.author:\n                    current_message.text += concat_one_user_messages_delimeter + msg.text\n                else:\n                    current_session.append(current_message)\n                    current_message = msg\n                    current_message.text = concat_one_user_messages_delimeter.lstrip() + current_message.text\n            current_session.append(current_message)\n            sessions.append(current_session)\n        chat.sessions = sessions\n    logger.info(f\"Combined consecutive messages from single user into one message\")\n\n    # Only leave sessions which have target_name in them\n    # (1st does not count as we can't use it for training)\n    for chat in chats:\n        chat.sessions = [session for session in chat.sessions if any(msg.author == target_name for msg in session[1:])]\n\n    # # Cut off messages in each session by last message from target_name\n    # for chat in chats:\n    #     for session in chat.sessions:\n    #         session[:] = session[: max(i for i, msg in enumerate(session) if msg.author == target_name) + 1]\n\n    # Remove date from messages\n    for chat in chats:\n        for session in chat.sessions:\n            for msg in session:\n                del msg.date\n\n    # Write sessions to jsonl\n    all_sessions = []\n    for chat in chats:\n        for session in chat.sessions:\n            all_sessions.append(session)\n    with open(output, \"w\") as f:\n        json.dump(\n            [[{\"author\": msg.author, \"text\": msg.text} for msg in session] for session in all_sessions],\n            f,\n            indent=4,\n            ensure_ascii=False,\n        )\n    logger.info(\n        f\"Took {len(all_sessions)} sessions from {len(chats)} chats and wrote them to '{output}'. 
Average session length is {round(sum(len(session) for session in all_sessions) / len(all_sessions), 2)} messages\"\n    )\n\n\nif __name__ == \"__main__\":\n    fire.Fire(transform_chats)\n"
  },
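  {
    "path": "preview_session.py",
    "content": "\"\"\"Minimal, illustrative sketch (not part of the original pipeline): prints one session from\n./data/messages.json (the output of prepare_dataset.py) rendered in the same ChatML layout that\nfinetune_lora.py and finetune_full.py train on. The default path and index are assumptions;\nadjust them to your export.\"\"\"\n\nimport fire\nfrom loguru import logger\n\nfrom utils.finetune_utils import prepare_dataset\n\n\ndef preview(data_path: str = \"./data/messages.json\", index: int = 0):\n    # prepare_dataset returns one ChatML-formatted string per session\n    sessions = prepare_dataset(data_path)\n    logger.info(f\"Loaded {len(sessions)} sessions from '{data_path}'\")\n    print(sessions[index])\n\n\nif __name__ == \"__main__\":\n    fire.Fire(preview)\n"
  },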
  {
    "path": "requirements.txt",
    "content": "scipy\ntorch\ntransformers\npeft\naccelerate\nbitsandbytes\ntrl\ndatasets\nsentencepiece\nwandb\nloguru\ntqdm\nfire\n"
  },
  {
    "path": "utils/finetune_utils.py",
    "content": "import json\nfrom typing import Any, Dict, List, Union\n\nfrom loguru import logger\nfrom transformers import DataCollatorForLanguageModeling\n\n\ndef print_trainable_parameters(model):\n    trainable_params = 0\n    all_param = 0\n    for _, param in model.named_parameters():\n        all_param += param.numel()\n        if param.requires_grad:\n            trainable_params += param.numel()\n    logger.info(\n        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {round(100 * trainable_params / all_param, 2)}\"\n    )\n\n\ndef prepare_dataset(path: str) -> List[str]:\n    with open(path, \"r\") as f:\n        sessions = json.load(f)\n\n    final_sessions = []\n    for session in sessions:\n        session_str = \"\\n\".join([f\"<|im_start|>{msg['author']}\\n{msg['text']}<|im_end|>\" for msg in session])\n        final_sessions.append(session_str)\n\n    return final_sessions\n\n\nclass DataCollatorForLanguageModelingChatML(DataCollatorForLanguageModeling):\n    \"\"\"\n    Data collator for [ChatML](https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/chatml.md) format, like:\n    ```\n    <|im_start|>John Smith\n    >>> damn, can't get around the 135 time limit\n    >>> trying to do everything super optimally, but no luck<|im_end|>\n    <|im_start|>Alexander Smirnov\n    >>> yeah same\n    >>> you still going with the same idea?<|im_end|>\n    ```\n    It will label any rows which are not starting with '>>>' to ignore.\n\n    Reference data collator implementation: [DataCollatorForCompletionOnlyLM](https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L56)\n\n    Args:\n        mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying\n            `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present\n             for flexibility and backwards-compatibility.\n        ignore_index (`int`, *optional*, defaults to `-100`):\n            The index to use to ignore the initial tokens with\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        mlm: bool = False,\n        ignore_index: int = -100,\n        **kwargs,\n    ):\n        super().__init__(*args, mlm=mlm, **kwargs)\n        self.ignore_index = ignore_index\n        self.start_token = self.tokenizer.encode(\"<|im_start|>\", add_special_tokens=False)[0]\n        self.end_token = self.tokenizer.encode(\"<|im_end|>\", add_special_tokens=False)[0]\n        self.new_line_token = self.tokenizer.encode(\"\\n\", add_special_tokens=False)[-1]\n        self.bos_token = self.tokenizer.bos_token_id\n\n    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:\n        batch = super().torch_call(examples)\n\n        for i in range(len(examples)):\n            if_start = False\n            for j in range(len(batch[\"labels\"][i])):\n                token = batch[\"labels\"][i][j].item()\n\n                if token == self.start_token:\n                    if_start = True\n\n                if if_start or token == self.bos_token:\n                    batch[\"labels\"][i][j] = self.ignore_index\n\n                if token == self.new_line_token:\n                    if_start = False\n\n        return batch\n"
  }
]