Repository: ixaxaar/VardaGPT Branch: master Commit: 191954143b16 Files: 28 Total size: 76.2 KB Directory structure: gitextract_5qwi3ryd/ ├── .github/ │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Makefile ├── README.md ├── STORY.md ├── requirements.txt ├── setup.cfg ├── src/ │ ├── data.py │ ├── memory/ │ │ ├── __init__.py │ │ ├── associative.py │ │ ├── batch_associative.py │ │ └── multitoken_batch_associative.py │ ├── models/ │ │ ├── __init__.py │ │ ├── gpt2_associative.py │ │ └── gpt2_working.py │ ├── train.py │ └── train_parallel.py └── test/ ├── __init__.py ├── memory/ │ ├── __init__..py │ ├── test_associative.py │ ├── test_associative_batch.py │ ├── test_memory_features.py │ ├── test_memory_types.py │ └── test_multitoken.py ├── models/ │ ├── __init__.py │ └── test_gpt2_associative.py └── test_training.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ci.yml ================================================ name: Python CI on: push: branches: - main pull_request: jobs: test: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: '3.10' - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - name: Run tests and coverage run: | pytest --cov=./src coverage html shell: bash - name: Save coverage report uses: actions/upload-artifact@v2 with: name: coverage-report path: htmlcov - name: Generate coverage comment uses: rabelenda/python-coverage-comment@v1 with: artifact: coverage-report artifact-type: path ================================================ FILE: .gitignore ================================================ __pycache__/ *.py[cod] *.mo /venv .env env.production env.devtest .DS_Store ._* *.log !.vscode/settings.json .vscode/launch.json docker-compose.yml geo .vscode/ .coverage 
================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black args: [--line-length=120] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: - id: check-json - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-prettier rev: v2.6.2 hooks: - id: prettier args: - --print-width=80 - --prose-wrap=always - --tab-width=2 - --single-quote=true - --no-bracket-spacing=true - --trailing-comma=es5 files: \.(md|json)$ additional_dependencies: - "prettier@2.6.2" - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.9.0 hooks: - id: python-check-blanket-noqa - id: python-check-mock-methods - id: python-use-type-annotations - repo: https://github.com/pycqa/flake8 rev: 4.0.1 hooks: - id: flake8 additional_dependencies: - flake8-bugbear - flake8-builtins - flake8-comprehensions - pep8-naming - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.920 hooks: - id: mypy additional_dependencies: - types-requests ================================================ FILE: Makefile ================================================ .PHONY: help setup train evaluate inference precommit format clean .DEFAULT_GOAL := help help: @echo "VardaGPT - Memory-enhanced GPT-2 model powered by Hugging Face Transformers and FAISS" @echo "Usage:" @echo " make " @echo "" @echo "Commands:" @echo " help Display this help message" @echo " setup Set up the project by creating a virtual environment and installing dependencies" @echo " train Train the VardaGPT model" @echo " evaluate Evaluate the trained model on validation and testing sets" @echo " inference Generate text using the memory-enhanced GPT-2 model" @echo " precommit Run pre-commit hooks manually on all files" @echo " format Format code using black, flake8, mypy, and prettier" @echo " clean Clean up the project directory by 
removing virtual environment and temporary files" setup: python -m venv venv venv/bin/pip install -r requirements.txt train: venv/bin/python src/train.py train-parallel: venv/bin/python src/train_parallel.py evaluate: venv/bin/python src/evaluate.py inference: venv/bin/python src/inference.py --prompt "Your prompt text here" precommit: pre-commit run --all-files format: pre-commit run --all-files test: coverage run -m pytest --log-cli-level=DEBUG --capture=tee-sys -v . clean: rm -rf venv/ find . -type f -name "*.pyc" -exec rm -f {} \; find . -type d -name "__pycache__" -exec rm -rf {} \; ================================================ FILE: README.md ================================================ # VardaGPT - [VardaGPT](#vardagpt) - [TLDR - Training](#tldr---training) - [Requirements](#requirements) - [Usage](#usage) - [Overview](#overview) - [Models](#models) - [Training, Evaluation, and Fine-tuning Process](#training-evaluation-and-fine-tuning-process) - [1. Data Preparation](#1-data-preparation) - [2. GPT-2 Model Adaptation](#2-gpt-2-model-adaptation) - [3. Training](#3-training) - [4. Evaluation](#4-evaluation) - [5. Fine-tuning (if necessary)](#5-fine-tuning-if-necessary) - [Prerequisites](#prerequisites) - [Setup](#setup) - [Directory Structure](#directory-structure) - [Usage](#usage-1) - [Data Preparation](#data-preparation) - [Training](#training) - [Evaluation](#evaluation) - [Inference](#inference) - [Contributing](#contributing) - [Code Formatting and Pre-commit](#code-formatting-and-pre-commit) - [Setup](#setup-1) - [Using Pre-commit](#using-pre-commit) - [License](#license) VardaGPT is a memory-enhanced GPT-2 model powered by Hugging Face Transformers and FAISS. Inspired by J.R.R. Tolkien's Silmarillion, VardaGPT aims to provide guidance and knowledge through its memory-augmented text generation capabilities.
## TLDR - Training The `VardaGPTAssociative` model combines GPT-2 with an associative memory to improve context retrieval. This repository includes a script to train this model on the WikiText-2 dataset. ### Requirements - Python 3.7+ - PyTorch 2.0.0 - torchtext 0.15.1 - transformers 4.28.1 - rich 13.3.5 - faiss-cpu 1.7.3 To install the required packages, you can use the following command: ```bash pip install -r requirements.txt ``` ### Usage To train the `VardaGPTAssociative` model on the WikiText-2 dataset, use the provided training script (`src/train.py`). You can customize the training settings by passing command-line arguments. Here's a basic example: ```bash python src/train.py --epochs 5 --learning_rate 1e-4 --use_gpu ``` Available command-line arguments: - `--epochs`: Number of epochs to train the model (default: 5). - `--learning_rate`: Learning rate for the optimizer (default: 1e-4). - `--memory_size`: Maximum number of items the associative memory can store (default: 10000). - `--memory_dim`: Dimensionality of the embeddings stored in the associative memory (default: 768). - `--index_type`: Type of index used for the associative memory (default: "flat"). - `--num_clusters`: Number of clusters to use for the memory if the index type is "ivf" (default: 1024). - `--num_search_results`: Number of search results to return from the associative memory (default: 5). - `--use_gpu`: Whether to use the GPU for the model if available (default: False). - `--batch_size`: Batch size for training (default: 1). - `--forgetfulness_factor`: Forgetfulness factor for the associative memory (default: 0.001). During training, the script will periodically print the training loss, validation loss, and elapsed time for each epoch, along with a progress bar for each training step. After training, you can use the trained model for your specific use case, such as text generation or fine-tuning for a particular task. ## Overview
Click me ```plantuml @startuml !define AWSPUML https://raw.githubusercontent.com/awslabs/aws-icons-for-plantuml/v14.0 actor User skinparam component { BackgroundColor<> LightSkyBlue BackgroundColor<> Plum BackgroundColor<> LightGreen BackgroundColor<> LightSalmon BackgroundColor<> LightCoral BorderColor Black FontName Arial } package "VardaGPT" { [Data Preparation]<> --> [FAISS Memory]<> [Data Preparation]<> --> [GPT-2 Adaptation]<> [FAISS Memory]<> --> [GPT-2 Adaptation]<> [GPT-2 Adaptation]<> --> [Training]<> [Training]<> --> [Inference]<> [FAISS Memory]<> --> [Inference]<> User --> [Data Preparation]<> : Dataset User --> [Inference]<> : Prompts } @enduml ```
![overview](./assets/README.svg) This diagram shows the main components of the VardaGPT project and their interactions. The Data Preparation component processes the dataset and feeds it to both the FAISS Memory Model and the GPT-2 Model Adaptation component. The FAISS Memory Model generates embeddings, which are used by the GPT-2 Model Adaptation component to create a modified GPT-2 model. The modified GPT-2 model is then trained and evaluated, and the final trained model is used in the Inference and Application component. The user provides the dataset and prompts for text generation. ## Models The associative memory model:
Click me ```plantuml @startuml rectangle "Input Vectors" as input #b3e0ff rectangle "Memory" as memory #f2d7b9 rectangle "Concatenated Input" as concatenated_input #f6e3c6 rectangle "Fully Connected Layer (fc)" as fc #e5ebf0 rectangle "GPT-2 Transformer" as transformer #c6e0b4 rectangle "GPT-2 LM Head" as lm_head #c9daf8 rectangle "Fully Connected Layer\n(fc_storable_vector)" as fc_storable_vector #c9daf8 rectangle "Fully Connected Layer\n(fc_store_decision)" as fc_store_decision #c9daf8 input -down-> memory : Perform search in memory memory -down-> concatenated_input : Concatenate search results with input vectors concatenated_input -down-> fc : Apply fully connected layer (fc) fc -down-> transformer : Pass through GPT-2 transformer transformer -down-> lm_head : Apply GPT-2 lm_head transformer -right-> fc_storable_vector : Apply fully connected layer (fc_storable_vector) transformer -right-> fc_store_decision : Apply fully connected layer (fc_store_decision) note right of fc_storable_vector: Calculate storable vector\n and store decision note right of fc_store_decision: Store the storable_vector in\n the associative memory if\n the store_decision is affirmative note bottom of lm_head: Return logits @enduml ```
![model1](./assets/README_001.svg)
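The store-decision path in the diagram above can be illustrated with a small NumPy sketch. The weights (`W_vec`, `W_store`), shapes, and the 0.5 gate threshold below are stand-ins for the model's learned `fc_storable_vector` and `fc_store_decision` layers, not the repository's actual parameters:

```python
import numpy as np

rng = np.random.default_rng(1)
batch, seq_len, emb_dim, mem_dim = 2, 4, 8, 8

# Stand-ins for the transformer's hidden states and the two heads' weights.
hidden_states = rng.standard_normal((batch, seq_len, emb_dim)).astype(np.float32)
W_vec = rng.standard_normal((emb_dim, mem_dim)).astype(np.float32)  # fc_storable_vector
W_store = rng.standard_normal((emb_dim, 1)).astype(np.float32)      # fc_store_decision

storable_vector = hidden_states @ W_vec                    # (batch, seq_len, mem_dim)
store_logit = hidden_states @ W_store                      # (batch, seq_len, 1)
store_decision = 1.0 / (1.0 + np.exp(-store_logit)) > 0.5  # sigmoid gate per token

# Only the vectors whose gate fires would be written to the associative memory.
to_store = storable_vector[store_decision.squeeze(-1)]     # (n_selected, mem_dim)
print(to_store.shape)
```

In the real model the gate and the storable vector are trained jointly with the language-modeling objective; here they only demonstrate the data flow.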
Click me ```plantuml @startuml title Forward Function !define Tensor(t,d) t + " (" + d + ")" !define DEVICE "device" actor "input_vectors" as input_vectors actor "memory_input" as memory_input note right of input_vectors Tensor: (batch_size, seq_len, embedding_dim) end note note right of memory_input Tensor (optional): (batch_size, seq_len, embedding_dim) end note input_vectors -> DEVICE memory_input -> DEVICE DEVICE -> "search(memory_input)" as search search --> "indices, distances" as search_result note right of search_result Tensors: indices: (batch_size, seq_len, num_search_results) distances: (batch_size, seq_len, num_search_results) end note search_result -> "get_all_embeddings()" as all_embeddings note right of all_embeddings Tensor: (memory_size, embedding_dim) end note all_embeddings -> "search_results" as search_results note right of search_results Tensor: (batch_size, seq_len, search_results_dim) end note search_results --> "concatenate(input_vectors, search_results)" as concatenated_input note right of concatenated_input Tensor: (batch_size, seq_len, embedding_dim + search_results_dim) end note concatenated_input --> "self.fc(concatenated_input)" as fc_output note right of fc_output Tensor: (batch_size, seq_len, embedding_dim) end note fc_output --> "self.gpt2_model.transformer(inputs_embeds=input_vectors)" as transformer_outputs transformer_outputs --> "hidden_states" as hidden_states note right of hidden_states Tensor: (batch_size, seq_len, embedding_dim) end note hidden_states --> "self.gpt2_model.lm_head(hidden_states)" as logits note right of logits Tensor: (batch_size, seq_len, vocab_size) end note hidden_states --> "self.fc_storable_vector(hidden_states)" as storable_vector note right of storable_vector Tensor: (batch_size, seq_len, memory_dim) end note hidden_states --> "self.fc_store_decision(hidden_states)" as store_decision note right of store_decision Tensor: (batch_size, seq_len, 1) end note hidden_states --> 
"self.fc_delete_decision(hidden_states)" as delete_decision note right of delete_decision Tensor: (batch_size, seq_len, num_search_results) end note hidden_states --> "self.fc_deletable_vector(hidden_states)" as deletable_vector note right of deletable_vector Tensor: (batch_size, seq_len, memory_dim) end note storable_vector --> "self.memory.add(storable_vector_to_store)" as add_memory deletable_vector --> "calculate L2 distances" as l2_distances note right of l2_distances Tensor: (batch_size, num_search_results) end note l2_distances --> "threshold comparison" as threshold_comparison note right of threshold_comparison Tensor (bool): (batch_size, num_search_results) end note threshold_comparison --> "self.memory.remove(indices_to_delete_flat)" as remove_memory logits --> "return logits" as return_logits @enduml ```
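The search-concatenate-project steps of the forward pass above can be sketched in plain NumPy. The memory contents, the projection weights `W_fc`, and all shapes are illustrative stand-ins (a brute-force L2 scan replaces the FAISS search), not the model's actual parameters:

```python
import numpy as np

# Toy sizes: 2 sequences of 4 tokens, 8-dim embeddings, 32 memory slots, top-3 search.
batch, seq_len, emb_dim, mem_size, k = 2, 4, 8, 32, 3

rng = np.random.default_rng(0)
memory = rng.standard_normal((mem_size, emb_dim)).astype(np.float32)
input_vectors = rng.standard_normal((batch, seq_len, emb_dim)).astype(np.float32)

# 1. Search: for every token embedding, find the k nearest memory rows (L2).
flat = input_vectors.reshape(-1, emb_dim)                    # (batch*seq_len, emb_dim)
d2 = ((flat[:, None, :] - memory[None, :, :]) ** 2).sum(-1)  # (batch*seq_len, mem_size)
indices = np.argsort(d2, axis=1)[:, :k]                      # (batch*seq_len, k)

# 2. Concatenate each token with its retrieved neighbours.
retrieved = memory[indices].reshape(batch, seq_len, k * emb_dim)
concatenated = np.concatenate([input_vectors, retrieved], axis=-1)
# shape: (batch, seq_len, emb_dim + k * emb_dim)

# 3. Project back to emb_dim with a fully connected layer (random weights here;
#    in the model this is the learned `fc` feeding the GPT-2 transformer).
W_fc = rng.standard_normal((emb_dim + k * emb_dim, emb_dim)).astype(np.float32)
fused = concatenated @ W_fc                                  # (batch, seq_len, emb_dim)
print(fused.shape)
```

The fused vectors then go through the GPT-2 transformer and `lm_head` exactly as ordinary token embeddings would.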
![model](./assets/README_002.svg) ## Training, Evaluation, and Fine-tuning Process
Click me ```plantuml @startuml skinparam activity { BackgroundColor LightSkyBlue BorderColor Black FontName Arial } start :Data Preparation; partition "FAISS Memory Model" { :Create FAISS Index; :Encode and Decode Text Data; :Test FAISS Index; } partition "GPT-2 Model Adaptation" { :Load Pre-trained GPT-2 Model; :Modify GPT-2 Architecture; :Define Custom Loss Function; } partition "Training" { :Train Adapted GPT-2 Model; :Save Model Checkpoints; } partition "Evaluation" { :Evaluate Model on Testing Set; :Calculate Metrics; } if (Fine-tuning needed?) then (Yes) partition "Fine-tuning" { :Adjust Hyperparameters; :Iterate Training and Evaluation; } endif partition "Inference and Application" { :Inference Function; :API or Interface; } stop @enduml ```
![process](./assets/README_003.svg) ### 1. Data Preparation - Collect and preprocess a dataset for training, evaluation, and fine-tuning. - Split the dataset into training, validation, and testing sets. - Create data loaders for handling data. ### 2. GPT-2 Model Adaptation - Load a pre-trained GPT-2 model from Hugging Face Transformers. - Modify the GPT-2 model architecture to incorporate the FAISS memory model. - Define a custom loss function that considers both the GPT-2 model's output and the memory model. ### 3. Training - Set up the training loop and train the adapted GPT-2 model. - Save model checkpoints and track training metrics (loss, perplexity, etc.). - Monitor the training progress, validate the model on the validation set, and perform early stopping if necessary. ### 4. Evaluation - Evaluate the trained model on the testing set. - Calculate evaluation metrics (e.g., perplexity, accuracy, F1-score). ### 5. Fine-tuning (if necessary) - If the model's performance on the testing set is not satisfactory, fine-tune the model with different hyperparameters, learning rates, or architectures. - Iterate through the training and evaluation steps until the desired performance is achieved. ## Prerequisites - Python 3.7 or higher - PyTorch - Hugging Face Transformers - FAISS (CPU or GPU version) ## Setup 1. Clone the repository: ```bash git clone https://github.com/ixaxaar/VardaGPT.git cd VardaGPT ``` 2. Create and activate a virtual environment: ```bash python -m venv venv source venv/bin/activate ``` 3. Install the required libraries: ```bash pip install -r requirements.txt ``` ## Directory Structure - `src/`: Contains the Python source code for the project. - `data/`: Stores the datasets used for training and evaluation. - `models/`: Holds the trained models and their checkpoints. ## Usage ### Data Preparation 1. Place your dataset in the `data/` directory. 2.
Preprocess and split your dataset into training, validation, and testing sets using the provided scripts in `src/`. ### Training 1. Configure the training settings and model hyperparameters in the `src/config.py` file. 2. Run the training script: ```bash python src/train.py ``` 3. Monitor the training progress and save model checkpoints in the `models/` directory. ### Evaluation 1. Evaluate the trained model on the validation and testing sets using the provided evaluation script: ```bash python src/evaluate.py ``` ### Inference 1. Use the provided inference script to generate text with the memory-enhanced GPT-2 model: ```bash python src/inference.py --prompt "Your prompt text here" ``` ## Contributing Feel free to contribute to this project by submitting pull requests or opening issues for bug reports and feature requests. ## Code Formatting and Pre-commit This project uses `black`, `flake8`, and `mypy` for Python code formatting and linting. We also use `prettier` to format JSON and Markdown files. The configuration for these tools is in the `.pre-commit-config.yaml` file. ### Setup 1. Install `pre-commit` if you haven't already: ```bash pip install pre-commit ``` 2. Set up the git hooks: ```bash pre-commit install ``` ### Using Pre-commit Whenever you commit changes, the pre-commit hooks will automatically format your code and check for issues. If the hooks detect any problems, the commit will be aborted, and you'll see a list of issues that need to be fixed. Once you've resolved the issues, you can try committing again. You can also run the pre-commit hooks manually on all files: ```bash pre-commit run --all-files ``` Or run the hooks on specific files: ```bash pre-commit run --files ``` By following this setup and using pre-commit hooks, you can ensure that the code in the repository remains consistently formatted and adheres to the project's coding standards. ## License This project is licensed under the [MIT License](LICENSE). 
================================================ FILE: STORY.md ================================================ # Story of this project 😅 ## Background 🤔 With all the hype around ChatGPT, I wondered how much impact ChatGPT really had. I mean, for a programmer, would ChatGPT be like a pair programmer? Like GitHub Copilot++? Or would ChatGPT totally replace programmers so that product managers could tell it what feature to build, and it would just build it! Imagine a bunch of product managers sitting in a sprint planning meeting where, after signing off on the tasks to be done this sprint and starting the sprint, ChatGPT was deployed on those tasks. The sprint lasted for about 2 hours, and everyone met again to do the next day's sprint grooming. 😆 ## Project Idea 💡 Now, what the heck should I build to test this? Why not try attaching a memory module to a GPT? I've seen some folks on the internet complain about the "low memory" problem of language models. I've also used FAISS and FLANN before, so I am familiar with how to technically achieve this. Whether it will actually work or not—well, my 1080Ti is on its deathbed with a broken fan, and I don't have the 💸 to train this thing on AWS anyway. Let's aim for unit tests to work then. ## Process 🏃 Okay. I started with the project plan: ![start](./assets/1.png) Then I made ChatGPT generate the project foundations, step by step, from creating project directories, Makefile, README, pre-commit, vscode settings for the same tools in pre-commit, setup.cfg, and a GitHub workflow to run tests. In each case, I had to specify exactly what I wanted it to generate. ![start](./assets/2.png) Yes, I made ChatGPT choose the project name as well. ![start](./assets/3.png) If I forgot something, I would go back and ask ChatGPT to add it: ![start](./assets/4.png) In fact, I found it better to let ChatGPT generate a toy-ish version of the code first, then let it add things to it step-by-step. 
This resulted in much better output than, say, asking ChatGPT to generate production-quality code with all features in the first go. This also gave me a way to break down my requirements and feed them one at a time - as I was also acting as a code-reviewer for the generated output, and so this method was also easier for me to work with. ![start](./assets/5.png) ![start](./assets/6.png) ![start](./assets/7.png) ![start](./assets/8.png) ![start](./assets/9.png) Of course, I made ChatGPT write unit tests, and if they failed, I would just copy the pytest output and feed it back into ChatGPT. ![start](./assets/10.png) ![start](./assets/11.png) ChatGPT even figured this out!: ![start](./assets/12.png) The result - I present to you [VardaGPT](https://github.com/ixaxaar/vardagpt)—every inch of this repository was generated by ChatGPT-4! It took a few hours, mostly around 3 weekends, mostly at odd times, to generate this project. ## Experience 😮 It felt neither like a Copilot++ nor like the product manager scenario but rather all at the same time. Sometimes I was amazed at what ChatGPT was able to understand, sometimes I had to stubbornly push it to go in a certain direction, sometimes it generated things I did not think of, sometimes I got super frustrated while making ChatGPT fix the code in a certain way. It was more like handholding a fresh grad who had absorbed all of human knowledge but needed someone to tie various parts of that knowledge to create something useful. Also ChatGPT is bad at dealing with abstractions beyond 2 layers. ChatGPT is definitely a productivity multiplier. I think it is rather a differential productivity multiplier, as it would enhance more the capabilities of those who already know more. If I did not understand deep learning and FAISS, or how projects are structured, I don't think I would have been able to pull this off. 
On the other hand, it also has some sort of a leveling effect—I have not worked on PyTorch in a while, have no idea of FAISS's new APIs, etc., but these gaps were filled in by ChatGPT. Finally, it was also tiring. Imagine being reduced to giving only instructions and doing code review. Reading and understanding code is tiring! ## Conclusion ❓ It looks like my job is safe this year. Time to generate an elaborate software project and claim supremacy on my ChatGPT usage abilities to hedge against next year. I wonder if by the time ChatGPT-6 comes out, would engineering teams be like, "Hey, let's generate our own Grafana with a purple theme 😄". ## Aside 🦄 I could not resist but add this bit. ChatGPT is great at generating Agda! Maybe this would also be the ultimate tool that can be used to formalize all of pure math? 😱 ![start](./assets/13.png) ================================================ FILE: requirements.txt ================================================ black==23.3.0 certifi==2022.12.7 charset-normalizer==3.1.0 click==8.1.3 cmake==3.26.3 coverage==7.2.3 exceptiongroup==1.1.1 faiss-cpu==1.7.3 filelock==3.12.0 flake8==6.0.0 huggingface-hub==0.13.4 idna==3.4 iniconfig==2.0.0 Jinja2==3.1.2 lit==16.0.1 markdown-it-py==2.2.0 MarkupSafe==2.1.2 mccabe==0.7.0 mdurl==0.1.2 mpmath==1.3.0 mypy==1.2.0 mypy-extensions==1.0.0 networkx==3.1 numpy==1.24.2 nvidia-cublas-cu11==11.10.3.66 nvidia-cuda-cupti-cu11==11.7.101 nvidia-cuda-nvrtc-cu11==11.7.99 nvidia-cuda-runtime-cu11==11.7.99 nvidia-cudnn-cu11==8.5.0.96 nvidia-cufft-cu11==10.9.0.58 nvidia-curand-cu11==10.2.10.91 nvidia-cusolver-cu11==11.4.0.1 nvidia-cusparse-cu11==11.7.4.91 nvidia-nccl-cu11==2.14.3 nvidia-nvtx-cu11==11.7.91 packaging==23.1 pathspec==0.11.1 platformdirs==3.2.0 pluggy==1.0.0 pycodestyle==2.10.0 pydantic==1.10.7 pyflakes==3.0.1 Pygments==2.15.1 pytest==7.3.1 pytest-cov==4.0.0 PyYAML==6.0 regex==2023.3.23 requests==2.28.2 rich==13.3.5 sympy==1.11.1 tokenizers==0.13.3 tomli==2.0.1 torch==2.0.0 
torchdata==0.6.0 torchtext==0.15.1 tqdm==4.65.0 transformers==4.28.1 triton==2.0.0 typing_extensions==4.5.0 urllib3==1.26.15 ================================================ FILE: setup.cfg ================================================ [flake8] exclude = */__init__.py,migrations/* ignore = E111, E114, E121, E131, W503, F405, F403, E126, E501, F841, E124, E251, E203 max-line-length = 120 max-doc-length = 120 show-source = true statistics = false doctests = true [tool.black] line-length = 120 [mypy] ignore_missing_imports = True [mypy-tests.*] disallow_untyped_defs = False [tool:pytest] addopts = -p no:warnings ignore = tests ================================================ FILE: src/data.py ================================================ from typing import Any, Tuple import torch from torch.utils.data import DataLoader from torchtext.data import get_tokenizer from torchtext.datasets import WikiText2 from transformers import GPT2Tokenizer def load_wikitext2() -> Tuple[DataLoader[Any], DataLoader[Any], DataLoader[Any]]: """ Load the WikiText-2 dataset for training, validation, and testing. :return: A tuple of three DataLoaders for train, valid, and test sets. 
""" # Define a tokenizer using the GPT2Tokenizer from the transformers library tokenizer = GPT2Tokenizer.from_pretrained("gpt2") # Define a tokenize function using the GPT2Tokenizer def tokenize(text: str) -> torch.Tensor: return tokenizer.encode(text, return_tensors="pt").squeeze(0) # type: ignore # Use torchtext's get_tokenizer function to create a tokenizer pipeline text_pipeline = get_tokenizer(tokenize) # Load the WikiText-2 dataset train_data, valid_data, test_data = WikiText2() # Create DataLoaders for the train, valid, and test sets train_loader = DataLoader([text_pipeline(text) for text in train_data], batch_size=1, shuffle=True) # type: ignore valid_loader = DataLoader([text_pipeline(text) for text in valid_data], batch_size=1, shuffle=True) # type: ignore test_loader = DataLoader([text_pipeline(text) for text in test_data], batch_size=1, shuffle=True) # type: ignore return train_loader, valid_loader, test_loader ================================================ FILE: src/memory/__init__.py ================================================ from typing import Any, Optional import numpy as np import numpy.typing as npt from faiss import IndexIVF class Memory: memory_size: int = 0 embedding_dim: int = 0 index: IndexIVF def add(self, embeddings: npt.NDArray[Any]) -> None: self.index.add(embeddings) def remove(self, ids: npt.NDArray[Any]) -> None: self.index.remove_ids(ids) def update(self, ids: npt.NDArray[Any], updated_embeddings: npt.NDArray[Any]) -> None: self.remove(ids) self.add(updated_embeddings) def search(self, query_vectors: npt.NDArray[Any], k: int = 10) -> Any: raise NotImplementedError() def refresh(self) -> None: self.index.reset() ================================================ FILE: src/memory/associative.py ================================================ from typing import Any, Union import faiss import numpy as np import numpy.typing as npt import torch from . 
import Memory class AssociativeMemory(Memory): def __init__( self, memory_size: int, embedding_dim: int, index_type: str = "flat", num_clusters: int = 1024, m: int = 8, ef_construction: int = 100, ef_search: int = 64, use_gpu: bool = False, gpu_device: int = 0, forgetfulness_factor: float = 0.0001, ): """ Initialize the associative memory. :param memory_size: The maximum number of items the memory can store. :param embedding_dim: The dimensionality of the input embeddings. :param index_type: The type of FAISS index to use (default: 'flat'). :param num_clusters: The number of clusters to use for an IVF index (default: 1024). :param m: The number of product quantization codes to use for an IVFPQ index (default: 8). :param ef_construction: The size of the entry list for an HNSW index (default: 100). :param ef_search: The search list size for an HNSW index (default: 64). :param use_gpu: Whether to use GPU acceleration for the FAISS index (default: False). :param gpu_device: The ID of the GPU device to use (default: 0). :param forgetfulness_factor: The fraction of items to remove during random forgetting (default: 0.0001).
""" # Initialize memory parameters self.memory_size = memory_size self.embedding_dim = embedding_dim self.forgetfulness_factor = forgetfulness_factor self.use_gpu = use_gpu self.gpu_device = gpu_device # Create the appropriate Faiss index based on the specified type if index_type == "flat": # Inverted File with Flat index - a compressed index with an inverted file structure quantizer = faiss.IndexFlatL2(embedding_dim) index = faiss.IndexIVFFlat(quantizer, embedding_dim, num_clusters, faiss.METRIC_L2) index.make_direct_map() index.set_direct_map_type(faiss.DirectMap.Hashtable) elif index_type == "compressed": # Inverted File with Product Quantization index - a compressed index with a product quantization compression quantizer = faiss.IndexFlatL2(embedding_dim) index = faiss.IndexIVFPQ(quantizer, embedding_dim, num_clusters, m, 8) elif index_type == "graph": # Hierarchical Navigable Small World index - a graph-based index with a flat storage index = faiss.IndexHNSWFlat(embedding_dim, ef_construction, faiss.METRIC_L2) index.hnsw.efSearch = ef_search else: raise ValueError(f"Invalid index_type: {index_type}") self.index_type = index_type # Enable GPU support if specified if use_gpu: self.res = faiss.StandardGpuResources() self.index = faiss.index_cpu_to_gpu(self.res, gpu_device, index) else: self.index = index # Train the index with empty data self.index.train(np.zeros((max(memory_size, num_clusters), embedding_dim), dtype=np.float32)) # Initialize an empty array to store the input vectors self.input_vectors = np.zeros((memory_size, embedding_dim), dtype=np.float32) def _to_numpy(self, data: Union[npt.NDArray[Any], torch.Tensor]) -> Union[npt.NDArray[Any], torch.Tensor]: """ Convert input data to a NumPy array if it's a PyTorch tensor and not using GPU. :param data: Input data to be converted. Can be either a NumPy array or a PyTorch tensor. :return: Converted data as a NumPy array or a PyTorch tensor. 
""" if isinstance(data, torch.Tensor) and not self.use_gpu: return data.detach().cpu().numpy() # type: ignore return data def add(self, embeddings: Union[npt.NDArray[Any], torch.Tensor]) -> None: """ Add embeddings to the memory. :param embeddings: A 2D array of shape (n, embedding_dim) containing the embeddings to be added, where n is the number of items to add. Can be either a NumPy array or a PyTorch tensor. """ embeddings = self._to_numpy(embeddings) n_added = self.index.ntotal # Existing number of added items n_to_add = embeddings.shape[0] # Number of items to add # Update the input_vectors array with the new embeddings self.input_vectors[n_added : n_added + n_to_add] = embeddings # Add embeddings to the index if self.use_gpu: ptr = faiss.torch_utils.swig_ptr_from_FloatTensor(embeddings) self.index.add_c(n_to_add, ptr) else: self.index.add(embeddings) def remove(self, ids: Union[npt.NDArray[Any], torch.Tensor]) -> None: """ Remove embeddings with the specified IDs from the memory. :param ids: A 1D array of shape (n,) containing the indices of the items to be removed, where n is the number of items to remove. Can be either a NumPy array or a PyTorch tensor. """ ids = self._to_numpy(ids) if self.index_type != "flat": raise ValueError( f"Removal is not supported by FAISS for this index type; use 'flat' instead of: {self.index_type}" ) self.index.remove_ids(ids) # Remove input vectors from the input_vectors array self.input_vectors = np.delete(self.input_vectors, ids, axis=0) def update( self, ids: Union[npt.NDArray[Any], torch.Tensor], updated_embeddings: Union[npt.NDArray[Any], torch.Tensor] ) -> None: """ Update embeddings with the specified IDs in the memory. :param ids: A 1D array of shape (n,) containing the indices of the items to be updated, where n is the number of items to update. Can be either a NumPy array or a PyTorch tensor.
:param updated_embeddings: A 2D array of shape (n, embedding_dim) containing the updated embeddings, where n is the number of items to update. Can be either a NumPy array or a PyTorch tensor. """ ids = self._to_numpy(ids) updated_embeddings = self._to_numpy(updated_embeddings) if self.index_type != "flat": raise ValueError( f"Update is not supported by FAISS for this index type; use 'flat' instead of: {self.index_type}" ) self.remove(ids) self.add(updated_embeddings) def search(self, query_vectors: Any, k: int = 10) -> Any: """ Search the memory for the top k closest embeddings to the query vectors. :param query_vectors: A 2D array or tensor of shape (n, embedding_dim) containing the query vectors, where n is the number of query vectors. :param k: The number of nearest neighbors to return for each query. :return: A tuple containing two 2D arrays or tensors for indices and distances, both of shape (n, k). """ if self.use_gpu and isinstance(query_vectors, torch.Tensor): n_query = query_vectors.shape[0] distances = torch.empty((n_query, k), device=self.gpu_device) # type: ignore indices = torch.empty((n_query, k), dtype=torch.long, device=self.gpu_device) # type: ignore ptr_query = faiss.torch_utils.swig_ptr_from_FloatTensor(query_vectors) ptr_distances = faiss.torch_utils.swig_ptr_from_FloatTensor(distances) ptr_indices = faiss.torch_utils.swig_ptr_from_LongTensor(indices) self.index.search_c(n_query, ptr_query, k, ptr_distances, ptr_indices) return indices, distances else: # Remember whether the caller passed a tensor so results can be returned in kind was_tensor = isinstance(query_vectors, torch.Tensor) if was_tensor: query_vectors = query_vectors.detach().cpu().numpy() distances, indices = self.index.search(query_vectors, k) if was_tensor: return torch.from_numpy(indices), torch.from_numpy(distances) else: return indices, distances def age_memory(self, decay_factor: float = 0.999) -> None: """ Age the memory embeddings by multiplying them with a decay factor.
        :param decay_factor: float, factor to multiply the embeddings with (default: 0.999)
        """
        assert 0 <= decay_factor <= 1, "Decay factor should be between 0 and 1."

        # Get the current embeddings from the memory
        current_embeddings = np.zeros((self.index.ntotal, self.embedding_dim), dtype=np.float32)
        for idx in range(self.index.ntotal):
            current_embeddings[idx] = self.index.reconstruct(idx)

        # Apply the decay factor
        aged_embeddings = current_embeddings * decay_factor

        # Update the memory with the aged embeddings
        ids = np.arange(self.index.ntotal, dtype=np.int64)
        self.update(ids, aged_embeddings)

    def forget_randomly(self) -> None:
        """
        Remove a random subset of embeddings from the memory based on the forgetfulness factor.
        """
        assert 0 <= self.forgetfulness_factor <= 1, "Forgetfulness factor should be between 0 and 1."

        total_embeddings = self.index.ntotal
        num_embeddings_to_forget = int(total_embeddings * self.forgetfulness_factor)

        # Select random embeddings to forget
        ids_to_forget = np.random.choice(total_embeddings, size=num_embeddings_to_forget, replace=False)

        # Remove the selected embeddings from the memory
        self.remove(ids_to_forget)

    def garbage_collect(self, threshold: float = 1e-6) -> npt.NDArray[Any]:
        """
        Remove nearly zero vectors from the memory and return the indices of the empty vectors.

        :param threshold: Threshold value to consider a vector nearly zero.
        :return: Indices of the empty vectors.
        """
        # Fetch all embeddings from the memory
        embeddings = self.get_all_embeddings()

        # Calculate the L2 norms of the embeddings
        norms = np.linalg.norm(embeddings, axis=1)

        # Identify nearly zero vectors based on the threshold
        nearly_zero_vectors = np.where(norms < threshold)[0]

        # Remove nearly zero vectors from the memory
        self.remove(nearly_zero_vectors)

        return nearly_zero_vectors

    def get_all_embeddings(self) -> npt.NDArray[Any]:
        """
        Retrieve all embeddings stored in the memory.

        :return: A 2D array of shape (n, embedding_dim) containing all stored embeddings,
            where n is the number of stored items.
        """
        # Return the stored input_vectors directly
        return self.input_vectors[: self.index.ntotal]  # Use slicing to get only the added items

    def __getitem__(self, index: int) -> Any:
        """
        Retrieve the input vector at the specified index.

        :param index: The index of the input vector to retrieve.
        :return: A 1D array of shape (embedding_dim,) containing the input vector.
        """
        if index >= self.index.ntotal:
            raise IndexError("Index out of range.")
        return self.input_vectors[index]

    def size(self) -> int:
        """
        Get the number of items stored in the memory.

        :return: The number of items stored in the memory.
        """
        return int(self.index.ntotal)


================================================
FILE: src/memory/batch_associative.py
================================================
from typing import Any, List, Tuple, Union

import numpy as np
import numpy.typing as npt

try:
    import torch

    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False

from .associative import AssociativeMemory


class BatchAssociativeMemory:
    def __init__(
        self,
        num_batches: int,
        memory_size: int,
        embedding_dim: int,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a batch associative memory.

        :param num_batches: Number of separate associative memories in this batch.
        :param memory_size: The maximum number of items each memory can store.
        :param embedding_dim: The dimensionality of the input embeddings.
        :param kwargs: Additional arguments to be passed to the AssociativeMemory constructor.
        """
        self.num_batches = num_batches
        self.embedding_dim = embedding_dim
        self.memories = [AssociativeMemory(memory_size, embedding_dim, **kwargs) for _ in range(num_batches)]

    def batch_add(self, embeddings: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of add operations. Assumes a single token per batch.
        :param embeddings: A 2D array of shape (num_batches, embedding_dim) containing the embeddings of
            single tokens to be added for each batch.
        """
        for memory, batch_embeddings in zip(self.memories, embeddings):
            memory.add(batch_embeddings.reshape(1, -1))

    def batch_remove(self, ids: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of remove operations. Assumes a single token per batch.

        :param ids: A 1D array of shape (num_batches,) containing the indices of the items to be removed.
        """
        for memory, batch_id in zip(self.memories, ids):
            memory.remove(np.array([batch_id], dtype=np.int64))

    def batch_search(
        self, query_vectors: Union[npt.NDArray[Any], "torch.Tensor"], k: int = 10
    ) -> List[Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]]:
        """
        Perform a batch of search operations. Assumes a single token per batch.

        :param query_vectors: A 2D array of shape (num_batches, embedding_dim) containing the query vectors
            of single tokens for each batch.
        :param k: The number of nearest neighbors to return for each query.
        :return: A list of tuples, each containing two arrays of shape (1, k) for indices and distances.
        """
        results = []
        for memory, batch_query_vector in zip(self.memories, query_vectors.reshape(-1, 1, self.embedding_dim)):
            indices, distances = memory.search(batch_query_vector, k)
            results.append((indices, distances))
        return results


================================================
FILE: src/memory/multitoken_batch_associative.py
================================================
from typing import Any, List, Tuple, Union

import numpy as np
import numpy.typing as npt

try:
    import torch

    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False

from .associative import AssociativeMemory


class MultiTokenBatchAssociativeMemory:
    def __init__(
        self,
        num_batches: int,
        memory_size: int,
        embedding_dim: int,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a batch associative memory.

        :param num_batches: Number of separate associative memories in this batch.
        :param memory_size: The maximum number of items each memory can store.
        :param embedding_dim: The dimensionality of the input embeddings.
        :param kwargs: Additional arguments to be passed to the AssociativeMemory constructor.
        """
        self.num_batches = num_batches
        self.embedding_dim = embedding_dim
        self.memories = [AssociativeMemory(memory_size, embedding_dim, **kwargs) for _ in range(num_batches)]

    def batch_add(self, embeddings: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of add operations.

        :param embeddings: A 3D array of shape (num_batches, num_tokens, embedding_dim) containing the
            embeddings to be added for each batch.
        """
        for memory, batch_embeddings in zip(self.memories, embeddings):
            memory.add(batch_embeddings)

    def batch_remove(self, ids: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of remove operations.

        :param ids: A 2D array of shape (num_batches, num_tokens) containing the indices of the items
            to be removed.
        """
        for memory, batch_ids in zip(self.memories, ids):
            memory.remove(batch_ids)

    def batch_search(
        self, query_vectors: Union[npt.NDArray[Any], "torch.Tensor"], k: int = 10
    ) -> List[Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]]:
        """
        Perform a batch of search operations.

        :param query_vectors: A 3D array of shape (num_batches, num_tokens, embedding_dim) containing the
            query vectors for each batch.
        :param k: The number of nearest neighbors to return for each query.
        :return: A list of tuples, each containing two 2D arrays of shape (num_tokens, k) for indices and distances.
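        The per-batch search can be sketched with a brute-force numpy nearest-neighbour
        loop (a simplified stand-in for the FAISS index; all names and sizes here are
        illustrative):

        ```python
        import numpy as np

        num_batches, num_tokens, embedding_dim, k = 2, 3, 4, 2
        rng = np.random.default_rng(0)

        # One small "memory" of 5 stored vectors per batch
        memories = [rng.random((5, embedding_dim), dtype=np.float32) for _ in range(num_batches)]
        queries = rng.random((num_batches, num_tokens, embedding_dim), dtype=np.float32)

        results = []
        for memory, batch_queries in zip(memories, queries):
            # Squared L2 distance between every query token and every stored vector
            dists = ((batch_queries[:, None, :] - memory[None, :, :]) ** 2).sum(-1)
            indices = np.argsort(dists, axis=1)[:, :k]           # (num_tokens, k)
            distances = np.take_along_axis(dists, indices, 1)    # (num_tokens, k)
            results.append((indices, distances))
        ```

        The real class returns the same per-batch list of (indices, distances) pairs,
        but delegates the distance computation to FAISS.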
""" results = [] for memory, batch_query_vectors in zip(self.memories, query_vectors): indices, distances = memory.search(batch_query_vectors, k) results.append((indices, distances)) return results ================================================ FILE: src/models/__init__.py ================================================ ================================================ FILE: src/models/gpt2_associative.py ================================================ from typing import Any import torch import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel from memory.batch_associative import BatchAssociativeMemory class VardaGPTAssociative(nn.Module): def __init__( self, gpt2_model_name: str = "gpt2", memory_size: int = 10000, memory_dim: int = 768, index_type: str = "flat", num_clusters: int = 1024, num_search_results: int = 5, use_gpu: bool = False, batch_size: int = 1, forgetfulness_factor: float = 0.001, ): """ Initialize a GPT-2 model with associative memory. :param gpt2_model_name: The name of the GPT-2 model to load. Default is "gpt2". :param memory_size: The maximum number of items the associative memory can store. Default is 10000. :param memory_dim: The dimensionality of the embeddings stored in the associative memory. Default is 768. :param index_type: The type of index used for the associative memory. Default is "flat". :param num_clusters: The number of clusters to use for the memory if the index type is "ivf". Default is 1024. :param num_search_results: The number of search results to return from the associative memory. Default is 5. :param use_gpu: Whether to use the GPU for the model if available. Default is False. 
""" super(VardaGPTAssociative, self).__init__() # Set up the device for the model self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu") # Load the GPT-2 model and configuration self.gpt2_config = GPT2Config.from_pretrained(gpt2_model_name) self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name) # Initialize the BatchAssociativeMemory module self.memory = BatchAssociativeMemory( num_batches=batch_size, memory_size=memory_size, embedding_dim=memory_dim, index_type=index_type, num_clusters=num_clusters, use_gpu=use_gpu, forgetfulness_factor=forgetfulness_factor, ) # Define dimensions for search results and output self.search_results_dim = memory_dim * num_search_results # Linear layers for concatenated input, storable vector, and store decision self.fc = nn.Linear(self.gpt2_config.n_embd + self.search_results_dim, self.gpt2_config.n_embd) self.fc_storable_vector = nn.Linear(self.gpt2_config.n_embd, memory_dim) self.fc_store_decision = nn.Linear(self.gpt2_config.n_embd, 1) # Move all layers to the device self.to(self.device) self.memory_dim = memory_dim self.num_search_results = num_search_results self.forgetfulness_factor = forgetfulness_factor def forward(self, input_vectors: torch.Tensor) -> Any: """ Perform a forward pass through the GPT-2 model with associative memory. :param input_vectors: A 3D tensor of shape (batch_size, sequence_length, input_dim) containing the input vectors for each token in the batch. :param memory_input: A 2D tensor of shape (batch_size, memory_dim) containing the memory input for each item in the batch. If not provided, memory will not be used. :return: A 3D tensor of shape (batch_size, sequence_length, vocab_size) containing the logits from the GPT-2 model. 
""" input_vectors = input_vectors.to(self.device) batch_size, seq_len, _ = input_vectors.shape # Initialize search_results tensor with the correct shape search_results = torch.zeros((batch_size, seq_len, self.search_results_dim), device=self.device) # Search for relevant results for each item in the batch for t in range(seq_len): search_results_list = self.memory.batch_search(input_vectors[:, t, :].squeeze(1), self.num_search_results) retrieved_embeddings_list = [] # Retrieve and concatenate search results with input vectors for ctr, (indices, _) in enumerate(search_results_list): retrieved_embeddings = torch.cat( [ self.memory.memories[ctr][i].unsqueeze(0) if i >= 0 else torch.zeros(self.memory_dim).unsqueeze(0) for i in indices.squeeze() ], dim=0, ) # Update the corresponding search_results tensor retrieved_embeddings_list.append(retrieved_embeddings) search_results[:, t, :] = torch.cat(retrieved_embeddings_list, dim=0).view(batch_size, -1) concatenated_input = torch.cat([input_vectors, search_results], dim=-1) input_vectors = self.fc(concatenated_input) # Pass input_vectors through GPT-2 model's transformer and obtain hidden states transformer_outputs = self.gpt2_model.transformer(inputs_embeds=input_vectors) hidden_states = transformer_outputs.last_hidden_state # Get logits from hidden states logits = self.gpt2_model.lm_head(hidden_states) # Calculate storable vector and store decision storable_vector = self.fc_storable_vector(hidden_states) store_decision = self.fc_store_decision(hidden_states) # Store the storable_vector in the associative memory if the store_decision is affirmative store_threshold = 0.5 # Define a threshold for store decision store_mask = (store_decision > store_threshold).float() storable_vector_to_store = storable_vector * store_mask for i in range(seq_len): storable_vector_to_store_i = storable_vector_to_store[:, i, :].view(batch_size, -1).detach() self.memory.batch_add(storable_vector_to_store_i) # Randomly forget items from the memory 
with a specified probability for memory in self.memory.memories: memory.forget_randomly() return logits ================================================ FILE: src/models/gpt2_working.py ================================================ from typing import Any, Optional import torch import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel from memory.associative import AssociativeMemory class VardaGPTWorking(nn.Module): def __init__( self, gpt2_model_name: str = "gpt2-small", memory_size: int = 10000, memory_dim: int = 768, index_type: str = "flat", num_clusters: int = 1024, num_search_results: int = 5, use_gpu: bool = False, ): super(VardaGPTWorking, self).__init__() # Set up the device for the model self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu") # Load the GPT-2 model and configuration self.gpt2_config = GPT2Config.from_pretrained(gpt2_model_name) self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name) # Initialize the AssociativeMemory module self.memory = AssociativeMemory( memory_size=memory_size, embedding_dim=memory_dim, index_type=index_type, num_clusters=num_clusters, use_gpu=use_gpu, ) # Define dimensions for search results and output self.search_results_dim = memory_dim * num_search_results # Linear layers for concatenated input, storable vector, store decision, delete decision, and deletable vector self.fc = nn.Linear(self.gpt2_config.n_embd + self.search_results_dim, self.gpt2_config.n_embd) self.fc_storable_vector = nn.Linear(self.gpt2_config.n_embd, memory_dim) self.fc_store_decision = nn.Linear(self.gpt2_config.n_embd, 1) # Move all layers to the device self.to(self.device) self.num_search_results = num_search_results def forward(self, input_vectors: torch.Tensor, memory_input: Optional[torch.Tensor] = None) -> Any: input_vectors = input_vectors.to(self.device) # Search for relevant results if memory_input is provided if memory_input is not None: indices, distances = 
self.memory.search(memory_input.cpu().numpy()) # Retrieve and concatenate search results with input vectors search_results = self.memory.get_all_embeddings()[indices].reshape(-1, self.search_results_dim) search_results = torch.tensor(search_results).to(self.device) concatenated_input = torch.cat([input_vectors, search_results], dim=-1) # Pass concatenated input through linear layer input_vectors = self.fc(concatenated_input) # Pass input_vectors through GPT-2 model's transformer and obtain hidden states transformer_outputs = self.gpt2_model.transformer(inputs_embeds=input_vectors) hidden_states = transformer_outputs.last_hidden_state # Get logits from hidden states logits = self.gpt2_model.lm_head(hidden_states) # Calculate storable vector, store decision, delete decision, and deletable vector storable_vector = self.fc_storable_vector(hidden_states) store_decision = self.fc_store_decision(hidden_states) # Store the storable_vector in the associative memory if the store_decision is affirmative store_threshold = 0.5 # Define a threshold for store decision store_mask = (store_decision > store_threshold).float() storable_vector_to_store = storable_vector * store_mask self.memory.add(storable_vector_to_store) return logits ================================================ FILE: src/train.py ================================================ import argparse import time from typing import Any import torch import torch.optim as optim from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel from rich.progress import track from rich.table import Table from rich.theme import Theme from torch.utils.data import DataLoader from data import load_wikitext2 from models.gpt2_associative import VardaGPTAssociative def train( model: VardaGPTAssociative, train_loader: DataLoader[Any], valid_loader: DataLoader[Any], epochs: int, lr: float, device: torch.device, ) -> None: """ Train the model with the given training and validation data. 
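    The loss computation below flattens the batch and sequence dimensions before
    applying cross-entropy. The reshape is sketched here with numpy stand-ins
    (toy shapes, not the real vocabulary size):

    ```python
    import numpy as np

    batch_size, seq_len, vocab_size = 2, 5, 11

    logits = np.zeros((batch_size, seq_len, vocab_size), dtype=np.float32)
    targets = np.zeros((batch_size, seq_len), dtype=np.int64)

    # CrossEntropyLoss expects (N, C) scores and (N,) class indices,
    # so both tensors are flattened over batch * sequence
    flat_logits = logits.reshape(-1, logits.shape[-1])  # (10, 11)
    flat_targets = targets.reshape(-1)                  # (10,)
    ```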
    :param model: The VardaGPTAssociative model to be trained.
    :param train_loader: DataLoader for the training data.
    :param valid_loader: DataLoader for the validation data.
    :param epochs: Number of epochs to train the model.
    :param lr: Learning rate for the optimizer.
    :param device: The device to use for training (CPU or GPU).
    """
    # Initialize the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_function = torch.nn.CrossEntropyLoss()

    # Create a console object for printing colorful prompts
    theme = Theme({"info": "dim green", "warning": "yellow", "error": "bold red"})
    console = Console(theme=theme)

    # Training loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for _, batch in enumerate(track(train_loader, description=f"[bold][info]Epoch {epoch + 1}")):
            # Move the input to the device
            input_vectors = batch.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            logits = model(input_vectors)

            # Calculate loss
            loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))

            # Backward pass
            loss.backward()

            # Update weights
            optimizer.step()

            epoch_loss += loss.item()

        # Calculate average epoch loss
        average_epoch_loss = epoch_loss / len(train_loader)
        epoch_time = time.time() - epoch_start_time

        # Validation
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for _, batch in enumerate(valid_loader):
                input_vectors = batch.to(device)
                logits = model(input_vectors)
                loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))
                valid_loss += loss.item()

        # Calculate average validation loss
        average_valid_loss = valid_loss / len(valid_loader)

        # Print epoch summary
        table = Table(title=f"Epoch {epoch + 1} Summary")
        table.add_column("Metric", style="bold")
        table.add_column("Value", style="bold")
        table.add_row("Training Loss", f"{average_epoch_loss:.4f}")
        table.add_row("Validation Loss", f"{average_valid_loss:.4f}")
        table.add_row("Time", f"{epoch_time:.2f} seconds")
        console.print(table)


if __name__ == "__main__":
    console = Console()
    console.print(Panel.fit("[bold blue]VardaGPTAssociative Training Script[/bold blue]"))

    description = """\
This script trains a VardaGPTAssociative model on the WikiText-2 dataset.
The model combines GPT-2 with an associative memory to improve context retrieval.
"""
    console.print(Markdown(description))

    parser = argparse.ArgumentParser(description="Train VardaGPTAssociative model on WikiText-2 dataset")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train the model")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument(
        "--memory_size", type=int, default=10000, help="Maximum number of items the associative memory can store"
    )
    parser.add_argument(
        "--memory_dim", type=int, default=768, help="Dimensionality of the embeddings stored in the associative memory"
    )
    parser.add_argument("--index_type", type=str, default="flat", help="Type of index used for the associative memory")
    parser.add_argument(
        "--num_clusters",
        type=int,
        default=1024,
        help="Number of clusters to use for the memory if the index type is 'ivf'",
    )
    parser.add_argument(
        "--num_search_results",
        type=int,
        default=5,
        help="Number of search results to return from the associative memory",
    )
    parser.add_argument("--use_gpu", action="store_true", help="Whether to use the GPU for the model if available")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument(
        "--forgetfulness_factor", type=float, default=0.001, help="Forgetfulness factor for the associative memory"
    )
    args = parser.parse_args()

    console.print("[bold green]Training settings:[/bold green]")
    console.print(f"  Epochs: {args.epochs}")
    console.print(f"  Learning rate: {args.learning_rate}")

    model = VardaGPTAssociative(
        gpt2_model_name="gpt2",
        memory_size=args.memory_size,
        memory_dim=args.memory_dim,
        index_type=args.index_type,
        num_clusters=args.num_clusters,
        num_search_results=args.num_search_results,
        use_gpu=args.use_gpu,
        batch_size=args.batch_size,
        forgetfulness_factor=args.forgetfulness_factor,
    )

    # Move the model to the device
    device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
    model.to(device)

    train_loader, valid_loader, test_loader = load_wikitext2()

    # Train the model
    train(model, train_loader, valid_loader, args.epochs, args.learning_rate, device)


================================================
FILE: src/train_parallel.py
================================================
import argparse
import time
from typing import Any, Union

import torch
import torch.optim as optim
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import track
from rich.table import Table
from rich.theme import Theme
from torch.utils.data import DataLoader

from data import load_wikitext2
from models.gpt2_associative import VardaGPTAssociative


def train(
    model: Union[VardaGPTAssociative, torch.nn.DataParallel],
    train_loader: DataLoader[Any],
    valid_loader: DataLoader[Any],
    epochs: int,
    lr: float,
    device: torch.device,
) -> None:
    """
    Train the model with the given training and validation data.

    :param model: The VardaGPTAssociative model to be trained.
    :param train_loader: DataLoader for the training data.
    :param valid_loader: DataLoader for the validation data.
    :param epochs: Number of epochs to train the model.
    :param lr: Learning rate for the optimizer.
    :param device: The device to use for training (CPU or GPU).
""" # Initialize the optimizer and loss function optimizer = optim.Adam(model.parameters(), lr=lr) loss_function = torch.nn.CrossEntropyLoss() # Create a console object for printing colorful prompts theme = Theme({"info": "dim green", "warning": "yellow", "error": "bold red"}) console = Console(theme=theme) # Training loop for epoch in range(epochs): model.train() epoch_loss = 0.0 epoch_start_time = time.time() for _, batch in enumerate(track(train_loader, description=f"[bold][info]Epoch {epoch + 1}")): # Move the input to the device input_vectors = batch.to(device) # Zero the gradients optimizer.zero_grad() # Forward pass logits = model(input_vectors) # Calculate loss loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1)) # Backward pass loss.backward() # Update weights optimizer.step() epoch_loss += loss.item() # Calculate average epoch loss average_epoch_loss = epoch_loss / len(train_loader) epoch_time = time.time() - epoch_start_time # Validation model.eval() valid_loss = 0.0 with torch.no_grad(): for _, batch in enumerate(valid_loader): input_vectors = batch.to(device) logits = model(input_vectors) loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1)) valid_loss += loss.item() # Calculate average validation loss average_valid_loss = valid_loss / len(valid_loader) # Print epoch summary table = Table(title=f"Epoch {epoch + 1} Summary") table.add_column("Metric", style="bold") table.add_column("Value", style="bold") table.add_row("Training Loss", f"{average_epoch_loss:.4f}") table.add_row("Validation Loss", f"{average_valid_loss:.4f}") table.add_row("Time", f"{epoch_time:.2f} seconds") console.print(table) if __name__ == "__main__": console = Console() console.print(Panel.fit("[bold blue]VardaGPTAssociative Training Script[/bold blue]")) description = """\ This script trains a VardaGPTAssociative model on the WikiText-2 dataset. The model combines GPT-2 with an associative memory to improve context retrieval. 
""" console.print(Markdown(description)) parser = argparse.ArgumentParser(description="Train VardaGPTAssociative model on WikiText-2 dataset") parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train the model") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer") parser.add_argument( "--memory_size", type=int, default=10000, help="Maximum number of items the associative memory can store" ) parser.add_argument( "--memory_dim", type=int, default=768, help="Dimensionality of the embeddings stored in the associative memory" ) parser.add_argument("--index_type", type=str, default="flat", help="Type of index used for the associative memory") parser.add_argument( "--num_clusters", type=int, default=1024, help="Number of clusters to use for the memory if the index type is 'ivf'", ) parser.add_argument( "--num_search_results", type=int, default=5, help="Number of search results to return from the associative memory", ) parser.add_argument("--use_gpu", action="store_true", help="Whether to use the GPU for the model if available") parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training") parser.add_argument( "--forgetfulness_factor", type=float, default=0.001, help="Forgetfulness factor for the associative memory" ) args = parser.parse_args() console.print("[bold green]Training settings:[/bold green]") console.print(f" Epochs: {args.epochs}") console.print(f" Learning rate: {args.learning_rate}") model = VardaGPTAssociative( gpt2_model_name="gpt2", memory_size=args.memory_size, memory_dim=args.memory_dim, index_type=args.index_type, num_clusters=args.num_clusters, num_search_results=args.num_search_results, use_gpu=args.use_gpu, batch_size=args.batch_size, forgetfulness_factor=args.forgetfulness_factor, ) # Move the model to the device device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu") if torch.cuda.device_count() > 1 and 
args.use_gpu: print(f"Using {torch.cuda.device_count()} GPUs for training.") model = torch.nn.DataParallel(model) # type: ignore model.to(device) train_loader, valid_loader, test_loader = load_wikitext2() # Train the model train(model, train_loader, valid_loader, args.epochs, args.learning_rate, device) ================================================ FILE: test/__init__.py ================================================ ================================================ FILE: test/memory/__init__..py ================================================ ================================================ FILE: test/memory/test_associative.py ================================================ # type: ignore import numpy as np import torch import pytest from src.memory.associative import AssociativeMemory @pytest.fixture(params=["numpy", "torch"]) def embeddings(request): embeddings_np = np.random.rand(1000, 768).astype(np.float32) if request.param == "numpy": return embeddings_np else: return torch.from_numpy(embeddings_np) def test_add_and_search(embeddings): memory = AssociativeMemory(memory_size=50000, embedding_dim=768) memory.add(embeddings) query_vectors = embeddings[:5] indices, distances = memory.search(query_vectors) assert indices.shape == (5, 10) def test_update(embeddings): memory = AssociativeMemory(memory_size=50000, embedding_dim=768) memory.add(embeddings) element_id = np.array([0], dtype=np.int64) updated_embedding = embeddings[:1] memory.update(element_id, updated_embedding) query_vectors = embeddings[1:2] indices, distances = memory.search(query_vectors) assert 0 not in indices[0] def test_refresh_memory(embeddings): memory = AssociativeMemory(memory_size=50000, embedding_dim=768) memory.add(embeddings) memory.refresh() query_vectors = embeddings[:5] indices, distances = memory.search(query_vectors) assert indices.shape == (5, 10) def test_getitem(embeddings): memory = AssociativeMemory(memory_size=50000, embedding_dim=768) memory.add(embeddings) # Test 
__getitem__ method retrieved_vector = memory[5] assert np.allclose(retrieved_vector, embeddings[5]) ================================================ FILE: test/memory/test_associative_batch.py ================================================ # type: ignore import numpy as np import pytest try: import torch except ImportError: pass from src.memory.batch_associative import BatchAssociativeMemory @pytest.fixture def batch_memory(): num_batches = 3 memory_size = 1000 embedding_dim = 128 return BatchAssociativeMemory(num_batches, memory_size, embedding_dim) @pytest.fixture(params=["numpy", "torch"]) def tensor_type(request): return request.param def create_tensor(tensor_type, data): if tensor_type == "numpy": return data else: return torch.from_numpy(data) def test_batch_add(batch_memory, tensor_type): # Create integer batched embeddings embeddings_np = np.random.randint(0, 10, (batch_memory.num_batches, batch_memory.embedding_dim)).astype(np.float32) embeddings = create_tensor(tensor_type, embeddings_np) # Add embeddings to the batch memory batch_memory.batch_add(embeddings) # Check if the embeddings are added to the corresponding memories for i, memory in enumerate(batch_memory.memories): all_embeddings = memory.get_all_embeddings() assert all_embeddings.shape == (1, batch_memory.embedding_dim) assert np.array_equal(embeddings_np[i], all_embeddings[0]) def test_batch_remove(batch_memory, tensor_type): # Create integer batched embeddings embeddings_np = np.random.randint(0, 10, (batch_memory.num_batches, batch_memory.embedding_dim)).astype(np.float32) embeddings = create_tensor(tensor_type, embeddings_np) # Add embeddings to the batch memory batch_memory.batch_add(embeddings) # Remove embeddings by index indices_to_remove_np = np.zeros(batch_memory.num_batches) indices_to_remove = create_tensor(tensor_type, indices_to_remove_np) batch_memory.batch_remove(indices_to_remove) # Check if the embeddings are removed from the corresponding memories for _, memory in 
enumerate(batch_memory.memories): all_embeddings = memory.get_all_embeddings() all_embeddings.shape = (0, batch_memory.embedding_dim) def test_batch_search(batch_memory, tensor_type): # Create a specific vector specific_vector_np = np.ones((1, batch_memory.embedding_dim), dtype=np.float32) specific_vector = create_tensor(tensor_type, specific_vector_np) # Create batched embeddings embeddings_np = np.random.randint(0, 10, size=(batch_memory.num_batches - 1, batch_memory.embedding_dim)).astype( np.float32 ) # Combine the specific vector with random embeddings embeddings_np = np.vstack((specific_vector_np, embeddings_np)) embeddings = create_tensor(tensor_type, embeddings_np) # Add embeddings to the batch memory batch_memory.batch_add(embeddings) # Perform batched search for the specific vector k = 1 search_results = batch_memory.batch_search(specific_vector, k) # Check if the first search result is the same as the specific vector indices, distances = search_results[0] found_vector = batch_memory.memories[0].get_all_embeddings()[indices[0][0]] assert np.allclose(specific_vector_np, found_vector) ================================================ FILE: test/memory/test_memory_features.py ================================================ # type: ignore import numpy as np import torch import pytest from src.memory.associative import AssociativeMemory @pytest.fixture(params=["numpy", "torch"]) def embeddings(request): embeddings_np = np.random.rand(1000, 128).astype(np.float32) if request.param == "numpy": return embeddings_np else: return torch.from_numpy(embeddings_np) @pytest.fixture def forgetful_memory(): memory_size = 1000 embedding_dim = 128 forgetfulness_factor = 0.9 return AssociativeMemory(memory_size, embedding_dim, forgetfulness_factor=forgetfulness_factor) @pytest.fixture def memory(): memory_size = 1000 embedding_dim = 128 return AssociativeMemory(memory_size, embedding_dim) def test_age_memory(embeddings, forgetful_memory): # Add some embeddings to the memory 
    forgetful_memory.add(embeddings[:10])

    # Age the memory
    forgetful_memory.age_memory(0.5)

    # Check if the embeddings have been aged
    aged_embeddings = forgetful_memory.get_all_embeddings()
    assert aged_embeddings.shape == embeddings[:10].shape
    assert np.allclose(embeddings[:10] * 0.5, aged_embeddings)


def test_forget_randomly(embeddings, forgetful_memory):
    # Add some embeddings to the memory
    forgetful_memory.add(embeddings[:100])

    # Forget randomly
    forgetful_memory.forget_randomly()

    # Check if the number of embeddings in memory has decreased
    remaining_embeddings = forgetful_memory.get_all_embeddings()
    expected_remaining_embeddings = int(embeddings[:100].shape[0] * (1 - forgetful_memory.forgetfulness_factor))
    assert (
        remaining_embeddings.shape[0] == expected_remaining_embeddings
        or remaining_embeddings.shape[0] == expected_remaining_embeddings + 1
    )


def test_garbage_collect(embeddings, memory):
    # Add some nearly zero embeddings to the memory
    nearly_zero_embeddings = embeddings[:10] * 1e-7
    memory.add(nearly_zero_embeddings)

    # Garbage collect
    removed_indices = memory.garbage_collect(threshold=1e-6)

    # Check if the nearly zero embeddings have been removed
    assert len(removed_indices) == len(nearly_zero_embeddings)
    remaining_embeddings = memory.get_all_embeddings()
    assert remaining_embeddings.shape[0] == 0
    assert np.array_equal(removed_indices, np.arange(len(nearly_zero_embeddings)))


================================================
FILE: test/memory/test_memory_types.py
================================================
# type: ignore
import numpy as np
import torch
import pytest

from src.memory.associative import AssociativeMemory


@pytest.fixture(params=["flat", "compressed", "graph"])
def index_type(request):
    return request.param


@pytest.fixture(params=["numpy", "torch"])
def tensor_type(request):
    return request.param


def to_tensor(tensor_type, array):
    if tensor_type == "numpy":
        return array
    elif tensor_type == "torch":
        return torch.from_numpy(array)
    else:
        raise ValueError("Unsupported tensor_type")


def test_associative_memory_add_remove_update_search(index_type, tensor_type):
    memory_size = 1000
    embedding_dim = 128
    num_test_vectors = 100
    k = 10

    memory = AssociativeMemory(memory_size, embedding_dim, index_type=index_type)

    # Add test embeddings to memory
    test_embeddings = to_tensor(tensor_type, np.random.random((num_test_vectors, embedding_dim)).astype(np.float32))
    memory.add(test_embeddings)

    # Search for closest embeddings in memory
    query_vectors = to_tensor(tensor_type, np.random.random((5, embedding_dim)).astype(np.float32))
    search_results, search_distances = memory.search(query_vectors, k)
    assert search_results.shape == (query_vectors.shape[0], k)

    if index_type == "flat":
        # Remove some embeddings from memory
        ids_to_remove = np.array([2, 5, 10, 30, 50])
        memory.remove(ids_to_remove)

        # Update some embeddings in memory
        ids_to_update = np.array([0, 1, 3, 4])
        updated_embeddings = to_tensor(
            tensor_type, np.random.random((len(ids_to_update), embedding_dim)).astype(np.float32)
        )
        memory.update(ids_to_update, updated_embeddings)

        # Check updated embeddings
        updated_search_results, updated_distances = memory.search(updated_embeddings, k=1)
        for i, _ in enumerate(ids_to_update):
            assert np.isclose(updated_distances[i, 0], 0, atol=1e-6)
    elif index_type in ["compressed", "graph"]:
        with pytest.raises(ValueError):
            memory.remove(np.array([0]))
        with pytest.raises(ValueError):
            memory.update(
                np.array([0]), to_tensor(tensor_type, np.random.random((1, embedding_dim)).astype(np.float32))
            )


================================================
FILE: test/memory/test_multitoken.py
================================================
# type: ignore
import numpy as np
import torch

from src.memory.multitoken_batch_associative import MultiTokenBatchAssociativeMemory


def test_batch_add():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4
    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim)
    memory.batch_add(embeddings)

    for mem in memory.memories:
        assert mem.size() == num_tokens


def test_batch_remove():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4
    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim)
    memory.batch_add(embeddings)

    ids_to_remove = np.array([[0, 1], [2, 3]])
    memory.batch_remove(ids_to_remove)

    for mem in memory.memories:
        assert mem.size() == num_tokens - 2


def test_batch_search():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4
    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim)
    memory.batch_add(embeddings)

    query_vectors = np.random.randn(num_batches, num_tokens, embedding_dim)
    k = 3
    search_results = memory.batch_search(query_vectors, k)

    for indices, distances in search_results:
        assert indices.shape == (num_tokens, k)
        assert distances.shape == (num_tokens, k)


def test_torch_tensors():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4
    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = torch.randn(num_batches, num_tokens, embedding_dim)
    memory.batch_add(embeddings)

    query_vectors = torch.randn(num_batches, num_tokens, embedding_dim)
    k = 3
    search_results = memory.batch_search(query_vectors, k)

    for indices, distances in search_results:
        assert indices.shape == (num_tokens, k)
        assert distances.shape == (num_tokens, k)


================================================
FILE: test/models/__init__.py
================================================

================================================
FILE: test/models/test_gpt2_associative.py
================================================
# type: ignore
import pytest
import torch

from src.models.gpt2_associative import VardaGPTAssociative
from transformers import GPT2Config


@pytest.fixture
def varda_gpt_associative():
    return VardaGPTAssociative(gpt2_model_name="gpt2", use_gpu=False, batch_size=2)


def test_initialization(varda_gpt_associative):
    assert isinstance(varda_gpt_associative, VardaGPTAssociative)
    assert varda_gpt_associative.device.type == "cpu"
    assert isinstance(varda_gpt_associative.gpt2_config, GPT2Config)
    assert varda_gpt_associative.num_search_results == 5
    assert varda_gpt_associative.forgetfulness_factor == 0.001


def test_forward_pass_no_memory(varda_gpt_associative):
    batch_size = 2
    sequence_length = 4
    input_dim = varda_gpt_associative.gpt2_config.n_embd

    input_vectors = torch.randn(batch_size, sequence_length, input_dim)
    logits = varda_gpt_associative.forward(input_vectors)
    assert logits.shape == (batch_size, sequence_length, varda_gpt_associative.gpt2_config.vocab_size)


def test_forward_pass_with_memory(varda_gpt_associative):
    batch_size = 2
    sequence_length = 4
    input_dim = varda_gpt_associative.gpt2_config.n_embd

    input_vectors = torch.randn(batch_size, sequence_length, input_dim)
    logits = varda_gpt_associative.forward(input_vectors)
    assert logits.shape == (batch_size, sequence_length, varda_gpt_associative.gpt2_config.vocab_size)


================================================
FILE: test/test_training.py
================================================
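The memory tests above all exercise one retrieval contract: `search(queries, k)` returns an index array and a distance array, each of shape `(n_queries, k)`, and an exactly stored vector comes back at distance near zero (the `atol=1e-6` check in test_memory_types.py). A minimal pure-NumPy sketch of that contract can make the expected shapes concrete. `FlatMemorySketch` is a hypothetical stand-in for illustration only, not the repository's FAISS-backed `AssociativeMemory`:

```python
import numpy as np


class FlatMemorySketch:
    """Illustrative flat (exact, squared-L2) memory; assumed API, not the repo's class."""

    def __init__(self, embedding_dim):
        self.embedding_dim = embedding_dim
        self._store = np.empty((0, embedding_dim), dtype=np.float32)

    def add(self, embeddings):
        # Append new rows to the store.
        self._store = np.vstack([self._store, embeddings.astype(np.float32)])

    def get_all_embeddings(self):
        return self._store

    def search(self, queries, k):
        # Squared L2 distance between every query and every stored vector.
        diffs = queries[:, None, :].astype(np.float32) - self._store[None, :, :]
        d2 = (diffs**2).sum(axis=-1)                # (n_queries, n_stored)
        idx = np.argsort(d2, axis=1)[:, :k]         # (n_queries, k) nearest indices
        dist = np.take_along_axis(d2, idx, axis=1)  # (n_queries, k) their distances
        return idx, dist


mem = FlatMemorySketch(embedding_dim=8)
mem.add(np.random.rand(100, 8).astype(np.float32))

queries = np.random.rand(5, 8).astype(np.float32)
indices, distances = mem.search(queries, k=10)
assert indices.shape == (5, 10) and distances.shape == (5, 10)

# An exactly stored vector retrieves itself at distance ~0.
self_idx, self_dist = mem.search(mem.get_all_embeddings()[:1], k=1)
assert self_idx[0, 0] == 0 and self_dist[0, 0] < 1e-6
```

A FAISS `IndexFlatL2` behaves the same way (`index.search` also returns `(n_queries, k)` arrays, distances first), which is presumably why the tests only assert shapes and near-zero self-distances rather than exact orderings.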