Repository: furiousteabag/doppelganger
Branch: master
Commit: 1143bcd9d4c8
Files: 10
Total size: 24.2 KB

Directory structure:
gitextract__ysxx7ew/
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── data/
│   └── .gitignore
├── finetune_full.py
├── finetune_lora.py
├── prepare_dataset.py
├── requirements.txt
└── utils/
    └── finetune_utils.py


================================================
FILE CONTENTS
================================================


================================================
FILE: .gitignore
================================================
__pycache__
.ipynb_checkpoints
.pytest_cache
.vim-arsync
*.log
result.json
weights/
wandb/
.cache/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Alexander Smirnov

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


================================================
FILE: Makefile
================================================
LINE_WIDTH=120
ISORT_FLAGS=--line-width=${LINE_WIDTH} --profile black
BLACK_FLAGS=--line-length=${LINE_WIDTH}

install:
	pip install -r requirements.txt

install-format:
	pip install black isort

format:
	isort ${ISORT_FLAGS} --check-only --diff .
	black ${BLACK_FLAGS} --check --diff .

format-fix:
	isort ${ISORT_FLAGS} .
	black ${BLACK_FLAGS} .


================================================
FILE: README.md
================================================
# Doppelganger

Fine-tuning an LLM on my Telegram chats. You can read the full story [in my blog](https://asmirnov.xyz/doppelganger).

## Dataset Preparation

First, we have to get the data. Open Telegram, go to 'Settings' -> 'Advanced' -> 'Export Telegram Data' and unselect everything except 'Personal chats' and 'Private groups' (don't select 'Only my messages there'). As the output format, choose 'Machine-readable JSON'. This produces `result.json`.

Use `prepare_dataset.py` to transform `result.json` into a JSON file with a list of sessions:

```bash
python prepare_dataset.py "./data/result.json" "./data/messages.json"
```

The script accepts several flags; you can read more about them in `--help`:

```bash
python prepare_dataset.py --help
```
output

```
NAME
    prepare_dataset.py - Transforms chat histories from .json telegram export to .json with a list of sessions. Session is a list of messages, where each message is a dict with fields 'author' and 'text'.

SYNOPSIS
    prepare_dataset.py INPUT OUTPUT

DESCRIPTION
    Transforms chat histories from .json telegram export to .json with a list of sessions. Session is a list of messages, where each message is a dict with fields 'author' and 'text'.

POSITIONAL ARGUMENTS
    INPUT
        Type: str
        Path to .json telegram export, usually called result.json
    OUTPUT
        Type: str
        Path to output .json file

FLAGS
    -t, --target_name=TARGET_NAME
        Type: Optional[str | None]
        Default: None
        The name of the person to target. This person will be present in every session. If empty, the script will try to detect it from "Saved Messages"
    -l, --last_x_months=LAST_X_MONTHS
        Type: int
        Default: 60
        Number of last months to use messages from
    -s, --session_minutes_threshold=SESSION_MINUTES_THRESHOLD
        Type: int
        Default: 10
        Threshold in minutes where messages will belong to the same session
    -c, --concat_one_user_messages_delimeter=CONCAT_ONE_USER_MESSAGES_DELIMETER
        Type: str
        Default: '\n>>> '
        Users might type several messages one after another. They are concatenated using this delimiter

NOTES
    You can also use flags syntax for POSITIONAL ARGUMENTS
```
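
To make `--session_minutes_threshold` concrete, here is a minimal sketch of the grouping rule, assuming the messages are already sorted by time. `split_into_sessions` is a hypothetical helper for illustration, not a function in `prepare_dataset.py`. A new session starts whenever the gap to the previous message is at least the threshold. The sketch measures the gap with `timedelta.total_seconds()`, because `timedelta.seconds` holds only the seconds component (0 to 86399) and would treat two messages exactly one day apart as having a zero-minute gap.

```python
from datetime import datetime, timedelta
from typing import List


def split_into_sessions(timestamps: List[datetime], threshold_minutes: int = 10) -> List[List[datetime]]:
    """Group time-sorted timestamps into sessions (hypothetical helper)."""
    sessions: List[List[datetime]] = []
    current: List[datetime] = []
    for ts in timestamps:
        # Extend the current session while the gap to the previous message is below the threshold
        if not current or (ts - current[-1]).total_seconds() / 60 < threshold_minutes:
            current.append(ts)
        else:
            sessions.append(current)
            current = [ts]
    if current:
        sessions.append(current)
    return sessions


start = datetime(2023, 10, 31, 15, 0)
times = [
    start,
    start + timedelta(minutes=5),  # 5-minute gap: same session
    start + timedelta(minutes=30),  # 25-minute gap: new session
    start + timedelta(days=1, minutes=30),  # exactly one day later: new session
]
print([len(s) for s in split_into_sessions(times)])  # [2, 1, 1]

# The pitfall that total_seconds() avoids: .seconds drops the days component
print(timedelta(days=1).seconds)  # 0
```

With the default threshold of 10 minutes, a quick back-and-forth stays in one session, while a reply the next day starts a new one.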
If you are interested, Telegram has several types of messages, each of which must be handled differently:
default text message

```
{
  "id": 123,
  "type": "message",
  "date": "2023-10-31T15:23:38",
  "date_unixtime": "1698746018",
  "from": "Username",
  "from_id": "user123",
  "text": "ты где?",
  "text_entities": [
    {
      "type": "plain",
      "text": "ты где?"
    }
  ]
}
```

multiple text entities

```
{
  "id": 345,
  "type": "message",
  "date": "2023-10-25T01:56:50",
  "date_unixtime": "1698179210",
  "from": "Username",
  "from_id": "user456",
  "text": [
    "California suspends GM Cruise's autonomous vehicle deployment | Hacker News\n",
    {
      "type": "link",
      "text": "https://news.ycombinator.com/item?id=38002752"
    }
  ],
  "text_entities": [
    {
      "type": "plain",
      "text": "California suspends GM Cruise's autonomous vehicle deployment | Hacker News\n"
    },
    {
      "type": "link",
      "text": "https://news.ycombinator.com/item?id=38002752"
    }
  ]
}
```

sticker

```
{
  "id": 789,
  "type": "message",
  "date": "2023-10-30T23:24:20",
  "date_unixtime": "1698688460",
  "from": "Username",
  "from_id": "user789",
  "file": "(File not included. Change data exporting settings to download.)",
  "thumbnail": "(File not included. Change data exporting settings to download.)",
  "media_type": "sticker",
  "sticker_emoji": "🤗",
  "width": 512,
  "height": 501,
  "text": "",
  "text_entities": []
}
```
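
A minimal sketch of how these three shapes can be reduced to one string, mirroring the approach of `load_chats` in `prepare_dataset.py`: concatenate the `text` of every entry in `text_entities` and append `sticker_emoji` when present. `text_entities` is used rather than `text` because it is a flat list in all three cases, whereas `text` may be either a string or a mixed list. `message_to_text` is a hypothetical name for illustration, and the dicts below are trimmed versions of the examples above.

```python
from typing import Any, Dict, Optional


def message_to_text(msg: Dict[str, Any]) -> Optional[str]:
    """Flatten one exported Telegram message into plain text (hypothetical helper).

    Returns None for messages that would be skipped: no author, or neither
    text entities nor a sticker emoji.
    """
    entities = msg.get("text_entities", [])
    if not msg.get("from") or not (entities or "sticker_emoji" in msg):
        return None
    # text_entities is always a list of {"type": ..., "text": ...} dicts,
    # even when the top-level "text" field is a plain string
    text = "".join(entity["text"] for entity in entities)
    # Stickers carry no text, only an emoji
    return text + msg.get("sticker_emoji", "")


plain = {"from": "Username", "text_entities": [{"type": "plain", "text": "ты где?"}]}
link = {
    "from": "Username",
    "text_entities": [
        {"type": "plain", "text": "Hacker News\n"},
        {"type": "link", "text": "https://news.ycombinator.com/item?id=38002752"},
    ],
}
sticker = {"from": "Username", "text_entities": [], "sticker_emoji": "🤗"}

print(message_to_text(plain))  # ты где?
print(repr(message_to_text(link)))  # 'Hacker News\nhttps://news.ycombinator.com/item?id=38002752'
print(message_to_text(sticker))  # 🤗
```

Service messages and other entries without a `from` author return `None`, which corresponds to the filter in `load_chats` that drops them.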
## Training

The final versions of the models were trained with the default parameters of the training scripts. Training logs are available on [WandB](https://wandb.ai/furiousteabag/doppelganger).

### LoRA fine-tune

To launch a LoRA fine-tune with my default params, you will need a GPU with 20GB of VRAM. An RTX 3090 is a good option for its price. To lower the amount of VRAM required, reduce `micro_batch_size` or `max_seq_length`. The effective `batch_size` stays the same as long as `micro_batch_size` divides it, because `gradient_accumulation_steps` is computed as `batch_size // micro_batch_size`.

To get the full list of parameters, run:

```
python finetune_lora.py --help
```

To train LoRA, run:

```
python finetune_lora.py
```

### Full fine-tune

To list the available params with their default values, run:

```
python finetune_full.py --help
```

To train:

```
torchrun --nnodes=1 --nproc_per_node=NUMBER_OF_GPUS finetune_full.py
```

## Launching

Use [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui). If you used LoRA, clone [ehartford/dolphin-2.2.1-mistral-7b](https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b) (or whichever model you used as the base model) and put the trained LoRA adapters into the `./loras/` folder within text-generation-webui. If you did a full fine-tune, copy the training result into `./models/`.
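
Whichever way you launch the model, the prompt should match the training format that `prepare_dataset` in `utils/finetune_utils.py` produces. Each message is rendered as `<|im_start|>{author}\n{text}<|im_end|>`, messages are joined with `\n`, and each consecutive message from the same author starts on its own `>>> ` line (the default `concat_one_user_messages_delimeter`). Below is a minimal sketch. `build_prompt` and the names in the example are hypothetical. Ending the prompt with an open `<|im_start|>{target}\n` header is an assumption about how to cue a reply; the repository itself does not do this. During training the header tokens are masked but the `>>>` after them is not, so the model is expected to produce that prefix itself.

```python
from typing import List, Tuple


def build_prompt(history: List[Tuple[str, List[str]]], target_name: str, delimiter: str = "\n>>> ") -> str:
    """Render a chat history in the ChatML-like training format (hypothetical helper).

    history: (author, messages) pairs; consecutive messages by one author are merged
    the same way prepare_dataset.py merges them.
    """
    turns = []
    for author, messages in history:
        # The first message gets the stripped delimiter (">>> "), the rest get the full one ("\n>>> ")
        text = delimiter.lstrip() + delimiter.join(messages)
        turns.append(f"<|im_start|>{author}\n{text}<|im_end|>")
    # Open a turn for the target so that generation continues in their voice (assumption)
    turns.append(f"<|im_start|>{target_name}\n")
    return "\n".join(turns)


prompt = build_prompt([("John Smith", ["hi", "you there?"])], target_name="Alexander Smirnov")
print(prompt)
```

You will probably want to stop generation once the model emits `<|im_end|>`, the token that closes every message during training.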


================================================
FILE: data/.gitignore
================================================
*
!.gitignore


================================================
FILE: finetune_full.py
================================================
import os
from typing import Dict, List

import fire
import torch
import wandb
from datasets import Dataset
from loguru import logger
from tokenizers import AddedToken
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, TrainingArguments
from trl import SFTTrainer

from utils.finetune_utils import DataCollatorForLanguageModelingChatML, prepare_dataset, print_trainable_parameters

DEFAULT_BOS_TOKEN = ""
DEFAULT_EOS_TOKEN = "<|im_end|>"
DEFAULT_UNK_TOKEN = ""
DEFAULT_PAD_TOKEN = ""


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict, tokenizer: PreTrainedTokenizer, model: PreTrainedModel, tokens_list: List = []
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(tokens_list)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the new token embeddings with the mean of the existing ones
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def train(
    model_name_or_path: str = "mistralai/Mistral-7B-v0.1",
    data_path: str = "./data/messages.json",
    output_dir: str = "./weights/full/",
    gradient_accumulation_steps: int = 4,
    micro_batch_size: int = 2,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    lr_scheduler_type: str = "cosine",
    warmup_ratio: float = 0.03,
    weight_decay: float = 0.0,
    max_seq_length: int = 1024,
    fsdp: str = "full_shard auto_wrap",
    fsdp_transformer_layer_cls_to_wrap: str = "MistralDecoderLayer",
    wandb_project: str = "doppelganger",
    logging_steps: int = 1,
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        batch_size = micro_batch_size * gradient_accumulation_steps * int(os.environ.get("WORLD_SIZE", 1))
        logger.info(f"Total batch size: {batch_size}")

    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)
    model.config.use_cache = False
    print_trainable_parameters(model)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, model_max_length=max_seq_length, padding_side="right", use_fast=False
    )

    special_tokens_dict = dict()
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    tokens_list = [AddedToken("<|im_start|>", normalized=False)]

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokens_list=tokens_list,
        tokenizer=tokenizer,
        model=model,
    )

    data_collator = DataCollatorForLanguageModelingChatML(tokenizer=tokenizer)
    dataset = Dataset.from_dict({"session": prepare_dataset(data_path)})

    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        wandb.init(project=wandb_project, job_type="train", anonymous="allow")

    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim="adamw_torch",
        save_steps=500,
        logging_steps=logging_steps,
        logging_first_step=True,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        bf16=True,
        max_steps=-1,
        warmup_ratio=warmup_ratio,
        group_by_length=True,
        lr_scheduler_type=lr_scheduler_type,
        fsdp=fsdp,
        fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
        report_to=["wandb"] if int(os.environ.get("LOCAL_RANK", 0)) == 0 else [],
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        data_collator=data_collator,
        dataset_text_field="session",
        max_seq_length=max_seq_length,
        packing=False,
        args=training_arguments,
    )

    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir)

    wandb.finish()


if __name__ == "__main__":
    fire.Fire(train)


================================================
FILE: finetune_lora.py
================================================
import fire
import torch
import wandb
from datasets import Dataset
from loguru import logger
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, MistralForCausalLM, TrainingArguments
from trl import SFTTrainer

from utils.finetune_utils import DataCollatorForLanguageModelingChatML, prepare_dataset, print_trainable_parameters


def train(
    model_name_or_path: str = "ehartford/dolphin-2.2.1-mistral-7b",
    data_path: str = "./data/messages.json",
    output_dir: str = "./weights/LoRA/",
    batch_size: int = 16,
    micro_batch_size: int = 8,
    num_epochs: int = 3,
    lora_r: int = 32,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    learning_rate: float = 2e-4,
    lr_scheduler_type: str = "cosine",
    warmup_ratio: float = 0.03,
    weight_decay: float = 0.001,
    max_seq_length: int = 1024,
    wandb_project: str = "doppelganger",
):
    gradient_accumulation_steps = batch_size // micro_batch_size

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    data_collator = DataCollatorForLanguageModelingChatML(tokenizer=tokenizer)
    dataset = Dataset.from_dict({"session": prepare_dataset(data_path)})

    # print(dataset[2500]["session"])
    # collator_res = data_collator([tokenizer(dataset["session"][2500], return_tensors="pt")["input_ids"][0]])
    # print(tokenizer.decode(collator_res["labels"][0][collator_res["labels"][0] != -100]))

    model: MistralForCausalLM = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        ),
        device_map={"": 0},
    )
    model.gradient_checkpointing_enable()

    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
    )
    model = prepare_model_for_kbit_training(model)
    logger.info(model)
    model = get_peft_model(model, peft_config)
    logger.info(model)
    print_trainable_parameters(model)

    wandb.init(project=wandb_project, job_type="train", anonymous="allow")

    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim="paged_adamw_8bit",
        save_steps=500,
        logging_steps=10,
        logging_first_step=True,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=False,
        bf16=False,
        max_grad_norm=0.3,
        max_steps=-1,
        warmup_ratio=warmup_ratio,
        group_by_length=True,
        lr_scheduler_type=lr_scheduler_type,
        report_to=["wandb"],
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        max_seq_length=max_seq_length,
        dataset_text_field="session",
        tokenizer=tokenizer,
        args=training_arguments,
        packing=False,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.model.save_pretrained(output_dir)

    wandb.finish()


if __name__ == "__main__":
    fire.Fire(train)


================================================
FILE: prepare_dataset.py
================================================
import json
from datetime import datetime, timedelta
from typing import List, Literal, Optional, Tuple

import fire
from loguru import logger
from pydantic import BaseModel
from tqdm import tqdm


class Message(BaseModel):
    date: datetime
    author: str
    text: str


class Chat(BaseModel):
    name: str
    type: Literal["personal_chat", "private_group", "private_supergroup"]
    messages: List[Message]
    sessions: Optional[List[List[Message]]] = []


def load_chats(path: str) -> Tuple[List[Chat], Tuple[int | None, str | None]]:
    chats: List[Chat] = []
    target_id, target_name = None, None

    logger.info(f"Loading chats from '{path}'...")
    with open(path, "r", encoding="utf-8") as f:
        for chat in json.load(f)["chats"]["list"]:
            # A chat without a "name" key is 'Saved Messages', from which
            # we can extract the id and name of the target person
            if "name" not in chat:
                target_id = int(chat["id"])
                target_name = str(
                    next(msg for msg in chat["messages"] if msg.get("from_id") == f"user{target_id}")["from"]
                )
            # An empty name means a "Deleted Account"; such chats are skipped
            elif chat["name"]:
                messages = [
                    Message(
                        date=msg["date"],
                        author=msg["from"],
                        text="".join([text_entity["text"] for text_entity in msg["text_entities"]])
                        + msg.get("sticker_emoji", ""),
                    )
                    for msg in chat["messages"]
                    if "from" in msg and msg["from"] and (msg["text_entities"] or "sticker_emoji" in msg)
                ]
                if messages:
                    chat = Chat(name=chat["name"], type=chat["type"], messages=messages)
                    chats.append(chat)

    logger.info(f"Found {len(chats)} chats in file '{path}'")
    if not target_name:
        logger.warning("Was not able to detect target name from 'Saved Messages'!")

    return chats, (target_id, target_name)


def transform_chats(
    input: str,
    output: str,
    target_name: str | None = None,
    last_x_months: int = 60,
    session_minutes_threshold: int = 10,
    concat_one_user_messages_delimeter: str = "\n>>> ",
):
    """
    Transforms chat histories from .json telegram export to .json with a list of sessions.
    Session is a list of messages, where each message is a dict with fields 'author' and 'text'.

    :param input: Path to .json telegram export, usually called result.json
    :param output: Path to output .json file
    :param target_name: The name of the person to target. This person will be present in every session.
        If empty, the script will try to detect it from "Saved Messages"
    :param last_x_months: Number of last months to use messages from
    :param session_minutes_threshold: Threshold in minutes where messages will belong to the same session
    :param concat_one_user_messages_delimeter: Users might type several messages one after another.
        They are concatenated using this delimiter
    """
    chats, (_, extracted_target_name) = load_chats(input)
    if not target_name:
        target_name = extracted_target_name

    logger.info(f"Preparing dataset for user with name '{target_name}'...")

    # Drop all messages older than last_x_months (a month is approximated as 30 days)
    for chat in chats:
        chat.messages = [msg for msg in chat.messages if msg.date > datetime.now() - timedelta(days=last_x_months * 30)]
    chats = [chat for chat in chats if chat.messages]
    logger.info(f"After filtering by date, there are {len(chats)} chats left")

    # Create sessions for each chat by combining messages within
    # session_minutes_threshold into one session
    for chat in chats:
        sessions = []
        current_session = []
        for msg in chat.messages:
            # total_seconds() rather than .seconds: .seconds ignores the days component,
            # so messages a whole number of days apart would land in the same session
            if (
                not current_session
                or (msg.date - current_session[-1].date).total_seconds() / 60 < session_minutes_threshold
            ):
                current_session.append(msg)
            else:
                sessions.append(current_session)
                current_session = [msg]
        if current_session:
            sessions.append(current_session)
        chat.sessions = sessions
    logger.info("Combined messages into sessions")

    # Combine consecutive messages from a single user into one message
    for chat in chats:
        sessions = []
        for session in chat.sessions:
            current_session = []
            current_message = session[0]
            current_message.text = concat_one_user_messages_delimeter.lstrip() + current_message.text
            for msg in session[1:]:
                if msg.author == current_message.author:
                    current_message.text += concat_one_user_messages_delimeter + msg.text
                else:
                    current_session.append(current_message)
                    current_message = msg
                    current_message.text = concat_one_user_messages_delimeter.lstrip() + current_message.text
            current_session.append(current_message)
            sessions.append(current_session)
        chat.sessions = sessions
    logger.info("Combined consecutive messages from single user into one message")

    # Only keep sessions in which target_name appears
    # (the first message does not count, as we can't use it for training)
    for chat in chats:
        chat.sessions = [session for session in chat.sessions if any(msg.author == target_name for msg in session[1:])]

    # # Cut off messages in each session by last message from target_name
    # for chat in chats:
    #     for session in chat.sessions:
    #         session[:] = session[: max(i for i, msg in enumerate(session) if msg.author == target_name) + 1]

    # Remove date from messages
    for chat in chats:
        for session in chat.sessions:
            for msg in session:
                del msg.date

    # Write sessions to a JSON file (a single list of sessions, not JSONL)
    all_sessions = []
    for chat in chats:
        for session in chat.sessions:
            all_sessions.append(session)
    with open(output, "w", encoding="utf-8") as f:
        json.dump(
            [[{"author": msg.author, "text": msg.text} for msg in session] for session in all_sessions],
            f,
            indent=4,
            ensure_ascii=False,
        )

    logger.info(
        f"Took {len(all_sessions)} sessions from {len(chats)} chats and wrote them to '{output}'. Average session length is {round(sum(len(session) for session in all_sessions) / len(all_sessions), 2)} messages"
    )


if __name__ == "__main__":
    fire.Fire(transform_chats)


================================================
FILE: requirements.txt
================================================
scipy
torch
transformers
peft
accelerate
bitsandbytes
trl
datasets
sentencepiece
wandb
loguru
tqdm
fire
pydantic


================================================
FILE: utils/finetune_utils.py
================================================
import json
from typing import Any, Dict, List, Union

from loguru import logger
from transformers import DataCollatorForLanguageModeling


def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    logger.info(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {round(100 * trainable_params / all_param, 2)}"
    )


def prepare_dataset(path: str) -> List[str]:
    with open(path, "r", encoding="utf-8") as f:
        sessions = json.load(f)

    # Render each session as ChatML turns joined by newlines
    final_sessions = []
    for session in sessions:
        session_str = "\n".join([f"<|im_start|>{msg['author']}\n{msg['text']}<|im_end|>" for msg in session])
        final_sessions.append(session_str)

    return final_sessions


class DataCollatorForLanguageModelingChatML(DataCollatorForLanguageModeling):
    """
    Data collator for [ChatML](https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/chatml.md) format, like:
    ```
    <|im_start|>John Smith
    >>> damn, can't get around the 135 time limit
    >>> trying to do everything super optimally, but no luck<|im_end|>
    <|im_start|>Alexander Smirnov
    >>> yeah same
    >>> you still going with the same idea?<|im_end|>
    ```
    Labels for rows that do not start with '>>>' (the `<|im_start|>` + author header, up to and including
    its newline) and for the BOS token are set to `ignore_index`, so the loss covers only message text
    and the closing `<|im_end|>`.

    Reference data collator implementation: [DataCollatorForCompletionOnlyLM](https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L56)

    Args:
        mlm (`bool`, *optional*, defaults to `False`):
            Whether or not to use masked language modeling in the underlying
            `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
            for flexibility and backwards-compatibility.
        ignore_index (`int`, *optional*, defaults to `-100`):
            The label value assigned to ignored (masked) tokens
    """

    def __init__(
        self,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        self.ignore_index = ignore_index
        # Token ids used to locate ChatML headers in the labels
        self.start_token = self.tokenizer.encode("<|im_start|>", add_special_tokens=False)[0]
        self.end_token = self.tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
        self.new_line_token = self.tokenizer.encode("\n", add_special_tokens=False)[-1]
        self.bos_token = self.tokenizer.bos_token_id

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        for i in range(len(examples)):
            # True while inside an "<|im_start|>{author}\n" header
            if_start = False
            for j in range(len(batch["labels"][i])):
                token = batch["labels"][i][j].item()

                # <|im_start|> opens a header
                if token == self.start_token:
                    if_start = True

                # Mask header tokens and BOS so that no loss is computed on them
                if if_start or token == self.bos_token:
                    batch["labels"][i][j] = self.ignore_index

                # The first newline after <|im_start|> closes the header (and is itself masked)
                if token == self.new_line_token:
                    if_start = False

        return batch
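

# --- Illustration only (hypothetical helper, not used by the training scripts) ---
# A minimal sketch that restates the masking loop of
# DataCollatorForLanguageModelingChatML.torch_call on a plain list of token ids,
# so the rule can be checked without loading a tokenizer. The ids in the demo
# are made up: 1 = BOS, 10 = <|im_start|>, 11 = <|im_end|>, 13 = "\n".
def mask_chatml_labels(
    labels: list[int], start_token: int, new_line_token: int, bos_token: int, ignore_index: int = -100
) -> list[int]:
    masked = []
    in_header = False
    for token in labels:
        if token == start_token:
            in_header = True  # <|im_start|> opens an author header
        masked.append(ignore_index if in_header or token == bos_token else token)
        if token == new_line_token:
            in_header = False  # the newline closes the header (and was itself masked above)
    return masked


if __name__ == "__main__":
    # BOS, <|im_start|>, "John", "\n", ">>>", "hi", <|im_end|>
    print(mask_chatml_labels([1, 10, 42, 13, 20, 21, 11], start_token=10, new_line_token=13, bos_token=1))
    # -> [-100, -100, -100, -100, 20, 21, 11]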