[
  {
    "path": ".github/ci.yml",
    "content": "name: Python CI\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v2\n\n    - name: Set up Python\n      uses: actions/setup-python@v2\n      with:\n        python-version: '3.10'\n\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install -r requirements.txt\n\n    - name: Run tests and coverage\n      run: |\n        pytest --cov=./src\n        coverage html\n      shell: bash\n\n    - name: Save coverage report\n      uses: actions/upload-artifact@v2\n      with:\n        name: coverage-report\n        path: htmlcov\n\n    - name: Generate coverage comment\n      uses: rabelenda/python-coverage-comment@v1\n      with:\n        artifact: coverage-report\n        artifact-type: path\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__/\n*.py[cod]\n*.mo\n/venv\n.env\nenv.production\nenv.devtest\n.DS_Store\n._*\n*.log\n!.vscode/settings.json\n.vscode/launch.json\ndocker-compose.yml\ngeo\n.vscode/\n.coverage\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/psf/black\n    rev: 22.3.0\n    hooks:\n      - id: black\n        args: [--line-length=120]\n\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.1.0\n    hooks:\n      - id: check-json\n      - id: check-yaml\n      - id: end-of-file-fixer\n      - id: trailing-whitespace\n\n  - repo: https://github.com/pre-commit/mirrors-prettier\n    rev: v2.6.2\n    hooks:\n      - id: prettier\n        args:\n          - --print-width=80\n          - --prose-wrap=always\n          - --tab-width=2\n          - --single-quote=true\n          - --no-bracket-spacing=true\n          - --trailing-comma=es5\n        files: \\.(md|json)$\n        additional_dependencies:\n          - \"prettier@2.6.2\"\n\n  - repo: https://github.com/pre-commit/pygrep-hooks\n    rev: v1.9.0\n    hooks:\n      - id: python-check-blanket-noqa\n      - id: python-check-mock-methods\n      - id: python-use-type-annotations\n\n  - repo: https://github.com/pycqa/flake8\n    rev: 4.0.1\n    hooks:\n      - id: flake8\n        additional_dependencies:\n          - flake8-bugbear\n          - flake8-builtins\n          - flake8-comprehensions\n          - pep8-naming\n\n  - repo: https://github.com/pre-commit/mirrors-mypy\n    rev: v0.920\n    hooks:\n      - id: mypy\n        additional_dependencies:\n          - types-requests\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: help setup train evaluate inference precommit format clean\n\n.DEFAULT_GOAL := help\n\nhelp:\n\t@echo \"\u001b[35mVardaGPT\u001b[0m - Memory-enhanced GPT-2 model powered by Hugging Face Transformers and FAISS\"\n\t@echo \"\u001b[1mUsage:\u001b[0m\"\n\t@echo \"  make \u001b[35m<command>\u001b[0m\"\n\t@echo \"\"\n\t@echo \"\u001b[1mCommands:\u001b[0m\"\n\t@echo \"  \u001b[35mhelp\u001b[0m        Display this help message\"\n\t@echo \"  \u001b[35msetup\u001b[0m       Set up the project by creating a virtual environment and installing dependencies\"\n\t@echo \"  \u001b[35mtrain\u001b[0m       Train the VardaGPT model\"\n\t@echo \"  \u001b[35mevaluate\u001b[0m    Evaluate the trained model on validation and testing sets\"\n\t@echo \"  \u001b[35minference\u001b[0m   Generate text using the memory-enhanced GPT-2 model\"\n\t@echo \"  \u001b[35mprecommit\u001b[0m   Run pre-commit hooks manually on all files\"\n\t@echo \"  \u001b[35mformat\u001b[0m      Format code using black, flake8, mypy, and prettier\"\n\t@echo \"  \u001b[35mclean\u001b[0m       Clean up the project directory by removing virtual environment and temporary files\"\n\nsetup:\n\tpython -m venv venv\n\tsource venv/bin/activate\n\tpip install -r requirements.txt\n\ntrain:\n\tsource venv/bin/activate\n\tpython src/train.py\n\ntrain-parallel:\n\tsource venv/bin/activate\n\tpython src/train_parallel.py\n\nevaluate:\n\tsource venv/bin/activate\n\tpython src/evaluate.py\n\ninference:\n\tsource venv/bin/activate\n\tpython src/inference.py --prompt \"Your prompt text here\"\n\nprecommit:\n\tpre-commit run --all-files\n\nformat:\n\tpre-commit run --all-files\n\ntest:\n\tcoverage run -m pytest --log-cli-level=DEBUG --capture=tee-sys -v .\n\nclean:\n\trm -rf venv/\n\tfind . -type f -name \"*.pyc\" -exec rm -f {} \\;\n\tfind . -type d -name \"__pycache__\" -exec rm -rf {} \\;\n"
  },
  {
    "path": "README.md",
    "content": "# VardaGPT\n\n<!-- START doctoc generated TOC please keep comment here to allow auto update -->\n<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->\n\n- [VardaGPT](#vardagpt)\n  - [TLDR - Training](#tldr---training)\n    - [Requirements](#requirements)\n    - [Usage](#usage)\n  - [Overview](#overview)\n  - [Models](#models)\n  - [Training, Evaluation, and Fine-tuning Process](#training-evaluation-and-fine-tuning-process)\n    - [1. Data Preparation](#1-data-preparation)\n    - [2. GPT-2 Model Adaptation](#2-gpt-2-model-adaptation)\n    - [3. Training](#3-training)\n    - [4. Evaluation](#4-evaluation)\n    - [5. Fine-tuning (if necessary)](#5-fine-tuning-if-necessary)\n  - [Prerequisites](#prerequisites)\n  - [Setup](#setup)\n  - [Directory Structure](#directory-structure)\n  - [Usage](#usage-1)\n    - [Data Preparation](#data-preparation)\n    - [Training](#training)\n    - [Evaluation](#evaluation)\n    - [Inference](#inference)\n  - [Contributing](#contributing)\n  - [Code Formatting and Pre-commit](#code-formatting-and-pre-commit)\n    - [Setup](#setup-1)\n    - [Using Pre-commit](#using-pre-commit)\n  - [License](#license)\n\n<!-- END doctoc generated TOC please keep comment here to allow auto update -->\n\nVardaGPT is a memory-enhanced GPT-2 model powered by Hugging Face Transformers\nand FAISS. Inspired by J.R.R. Tolkien's Silmarillion, VardaGPT aims to provide\nguidance and knowledge through its memory-augmented text generation\ncapabilities.\n\n## TLDR - Training\n\nThe `VardaGPTAssociative` model combines GPT-2 with an associative memory to\nimprove context retrieval. 
This repository includes a script to train this model\non the WikiText-2 dataset.\n\n### Requirements\n\n- Python 3.7+\n- PyTorch 1.8.1+\n- torchtext 0.9.1\n- transformers 4.10.0\n- rich 10.3.0\n- faiss-cpu 1.7.1\n\nTo install the required packages, you can use the following command:\n\n```bash\npip install -r requirements.txt\n```\n\n### Usage\n\nTo train the `VardaGPTAssociative` model on the WikiText-2 dataset, use the\nprovided training script (`train_varda_gpt_associative.py`). You can customize\nthe training settings by passing command-line arguments. Here's a basic example:\n\n```bash\npython train_varda_gpt_associative.py --epochs 5 --learning_rate 1e-4 --use_gpu\n```\n\nAvailable command-line arguments:\n\n- `--epochs`: Number of epochs to train the model (default: 5).\n- `--learning_rate`: Learning rate for the optimizer (default: 1e-4).\n- `--memory_size`: Maximum number of items the associative memory can store\n  (default: 10000).\n- `--memory_dim`: Dimensionality of the embeddings stored in the associative\n  memory (default: 768).\n- `--index_type`: Type of index used for the associative memory (default:\n  \"flat\").\n- `--num_clusters`: Number of clusters to use for the memory if the index type\n  is \"ivf\" (default: 1024).\n- `--num_search_results`: Number of search results to return from the\n  associative memory (default: 5).\n- `--use_gpu`: Whether to use the GPU for the model if available (default:\n  False).\n- `--batch_size`: Batch size for training (default: 1).\n- `--forgetfulness_factor`: Forgetfulness factor for the associative memory\n  (default: 0.001).\n\nDuring training, the script will periodically print the training loss,\nvalidation loss, and elapsed time for each epoch, along with a progress bar for\neach training step.\n\nAfter training, you can use the trained model for your specific use case, such\nas text generation or fine-tuning for a particular task.\n\n## Overview\n\n<details>\n  <summary>Click 
me</summary>\n\n```plantuml\n@startuml\n!define AWSPUML https://raw.githubusercontent.com/awslabs/aws-icons-for-plantuml/v14.0\n\nactor User\n\nskinparam component {\n  BackgroundColor<<Data Preparation>> LightSkyBlue\n  BackgroundColor<<FAISS Memory>> Plum\n  BackgroundColor<<GPT-2 Adaptation>> LightGreen\n  BackgroundColor<<Training>> LightSalmon\n  BackgroundColor<<Inference>> LightCoral\n  BorderColor Black\n  FontName Arial\n}\n\npackage \"VardaGPT\" {\n  [Data Preparation]<<Data Preparation>> --> [FAISS Memory]<<FAISS Memory>>\n  [Data Preparation]<<Data Preparation>> --> [GPT-2 Adaptation]<<GPT-2 Adaptation>>\n\n  [FAISS Memory]<<FAISS Memory>> --> [GPT-2 Adaptation]<<GPT-2 Adaptation>>\n  [GPT-2 Adaptation]<<GPT-2 Adaptation>> --> [Training]<<Training>>\n\n  [Training]<<Training>> --> [Inference]<<Inference>>\n  [FAISS Memory]<<FAISS Memory>> --> [Inference]<<Inference>>\n\n  User --> [Data Preparation]<<Data Preparation>> : Dataset\n  User --> [Inference]<<Inference>> : Prompts\n}\n\n@enduml\n```\n\n</details>\n\n![overview](./assets/README.svg)\n\nThis diagram shows the main components of the VardaGPT project and their\ninteractions. The Data Preparation component processes the dataset and feeds it\nto both the FAISS Memory Model and the GPT-2 Model Adaptation component. The\nFAISS Memory Model generates embeddings, which are used by the GPT-2 Model\nAdaptation component to create a modified GPT-2 model. The modified GPT-2 model\nis then trained and evaluated, and the final trained model is used in the\nInference and Application component. 
The user provides the dataset and prompts\nfor text generation.\n\n## Models\n\nThe associative memory model:\n\n<details>\n  <summary>Click me</summary>\n\n```plantuml\n@startuml\n\nrectangle \"Input Vectors\" as input #b3e0ff\nrectangle \"Memory\" as memory #f2d7b9\nrectangle \"Concatenated Input\" as concatenated_input #f6e3c6\nrectangle \"Fully Connected Layer (fc)\" as fc #e5ebf0\nrectangle \"GPT-2 Transformer\" as transformer #c6e0b4\nrectangle \"GPT-2 LM Head\" as lm_head #c9daf8\nrectangle \"Fully Connected Layer\\n(fc_storable_vector)\" as fc_storable_vector #c9daf8\nrectangle \"Fully Connected Layer\\n(fc_store_decision)\" as fc_store_decision #c9daf8\n\ninput -down-> memory : Perform search in memory\nmemory -down-> concatenated_input : Concatenate search results with input vectors\nconcatenated_input -down-> fc : Apply fully connected layer (fc)\nfc -down-> transformer : Pass through GPT-2 transformer\ntransformer -down-> lm_head : Apply GPT-2 lm_head\ntransformer -right-> fc_storable_vector : Apply fully connected layer (fc_storable_vector)\ntransformer -right-> fc_store_decision : Apply fully connected layer (fc_store_decision)\n\nnote right of fc_storable_vector: Calculate storable vector\\n and store decision\nnote right of fc_store_decision: Store the storable_vector in\\n the associative memory if\\n the store_decision is affirmative\nnote bottom of lm_head: Return logits\n\n@enduml\n\n```\n\n</details>\n\n![model1](./assets/README_001.svg)\n\n<details>\n  <summary>Click me</summary>\n\n```plantuml\n@startuml\ntitle Forward Function\n\n!define Tensor(t,d) t + \" (\" + d + \")\"\n!define DEVICE \"device\"\n\nactor \"input_vectors\" as input_vectors\nactor \"memory_input\" as memory_input\n\nnote right of input_vectors\n  Tensor:\n  (batch_size, seq_len, embedding_dim)\nend note\n\nnote right of memory_input\n  Tensor (optional):\n  (batch_size, seq_len, embedding_dim)\nend note\n\ninput_vectors -> DEVICE\nmemory_input -> DEVICE\n\nDEVICE -> 
\"search(memory_input)\" as search\nsearch --> \"indices, distances\" as search_result\nnote right of search_result\n  Tensors:\n  indices: (batch_size, seq_len, num_search_results)\n  distances: (batch_size, seq_len, num_search_results)\nend note\n\nsearch_result -> \"get_all_embeddings()\" as all_embeddings\nnote right of all_embeddings\n  Tensor:\n  (memory_size, embedding_dim)\nend note\n\nall_embeddings -> \"search_results\" as search_results\nnote right of search_results\n  Tensor:\n  (batch_size, seq_len, search_results_dim)\nend note\n\nsearch_results --> \"concatenate(input_vectors, search_results)\" as concatenated_input\nnote right of concatenated_input\n  Tensor:\n  (batch_size, seq_len, embedding_dim + search_results_dim)\nend note\n\nconcatenated_input --> \"self.fc(concatenated_input)\" as fc_output\nnote right of fc_output\n  Tensor:\n  (batch_size, seq_len, embedding_dim)\nend note\n\nfc_output --> \"self.gpt2_model.transformer(inputs_embeds=input_vectors)\" as transformer_outputs\ntransformer_outputs --> \"hidden_states\" as hidden_states\nnote right of hidden_states\n  Tensor:\n  (batch_size, seq_len, embedding_dim)\nend note\n\nhidden_states --> \"self.gpt2_model.lm_head(hidden_states)\" as logits\nnote right of logits\n  Tensor:\n  (batch_size, seq_len, vocab_size)\nend note\n\nhidden_states --> \"self.fc_storable_vector(hidden_states)\" as storable_vector\nnote right of storable_vector\n  Tensor:\n  (batch_size, seq_len, memory_dim)\nend note\n\nhidden_states --> \"self.fc_store_decision(hidden_states)\" as store_decision\nnote right of store_decision\n  Tensor:\n  (batch_size, seq_len, 1)\nend note\n\nhidden_states --> \"self.fc_delete_decision(hidden_states)\" as delete_decision\nnote right of delete_decision\n  Tensor:\n  (batch_size, seq_len, num_search_results)\nend note\n\nhidden_states --> \"self.fc_deletable_vector(hidden_states)\" as deletable_vector\nnote right of deletable_vector\n  Tensor:\n  (batch_size, seq_len, memory_dim)\nend 
note\n\nstorable_vector --> \"self.memory.add(storable_vector_to_store)\" as add_memory\n\ndeletable_vector --> \"calculate L2 distances\" as l2_distances\nnote right of l2_distances\n  Tensor:\n  (batch_size, num_search_results)\nend note\n\nl2_distances --> \"threshold comparison\" as threshold_comparison\nnote right of threshold_comparison\n  Tensor (bool):\n  (batch_size, num_search_results)\nend note\n\nthreshold_comparison --> \"self.memory.remove(indices_to_delete_flat)\" as remove_memory\n\nlogits --> \"return logits\" as return_logits\n\n@enduml\n```\n\n</details>\n\n![model](./assets/README_002.svg)\n\n## Training, Evaluation, and Fine-tuning Process\n\n<details>\n  <summary>Click me</summary>\n\n```plantuml\n@startuml\n\nskinparam activity {\n  BackgroundColor LightSkyBlue\n  BorderColor Black\n  FontName Arial\n}\n\nstart\n\n:Data Preparation;\n\npartition \"FAISS Memory Model\" {\n  :Create FAISS Index;\n  :Encode and Decode Text Data;\n  :Test FAISS Index;\n}\n\npartition \"GPT-2 Model Adaptation\" {\n  :Load Pre-trained GPT-2 Model;\n  :Modify GPT-2 Architecture;\n  :Define Custom Loss Function;\n}\n\npartition \"Training\" {\n  :Train Adapted GPT-2 Model;\n  :Save Model Checkpoints;\n}\n\npartition \"Evaluation\" {\n  :Evaluate Model on Testing Set;\n  :Calculate Metrics;\n}\n\nif (Fine-tuning needed?) then (Yes)\n  partition \"Fine-tuning\" {\n    :Adjust Hyperparameters;\n    :Iterate Training and Evaluation;\n  }\nendif\n\npartition \"Inference and Application\" {\n  :Inference Function;\n  :API or Interface;\n}\n\nstop\n\n@enduml\n```\n\n</details>\n\n![process](./assets/README_003.svg)\n\n### 1. Data Preparation\n\n- Collect and preprocess a dataset for training, evaluation, and fine-tuning.\n- Split the dataset into training, validation, and testing sets.\n- Create data loaders for handling data.\n\n### 2. 
GPT-2 Model Adaptation\n\n- Load a pre-trained GPT-2 model from Hugging Face Transformers.\n- Modify the GPT-2 model architecture to incorporate the FAISS memory model.\n- Define a custom loss function that considers both the GPT-2 model's output and\n  the memory model.\n\n### 3. Training\n\n- Set up the training loop and train the adapted GPT-2 model.\n- Save model checkpoints and track training metrics (loss, perplexity, etc.).\n- Monitor the training progress, validate the model on the validation set, and\n  perform early stopping if necessary.\n\n### 4. Evaluation\n\n- Evaluate the trained model on the testing set.\n- Calculate evaluation metrics (e.g., perplexity, accuracy, F1-score).\n\n### 5. Fine-tuning (if necessary)\n\n- If the model's performance on the testing set is not satisfactory, fine-tune\n  the model with different hyperparameters, learning rates, or architectures.\n- Iterate through the training and evaluation steps until the desired\n  performance is achieved.\n\n## Prerequisites\n\n- Python 3.6 or higher\n- PyTorch\n- Hugging Face Transformers\n- FAISS (CPU or GPU version)\n\n## Setup\n\n1. Clone the repository:\n\n```bash\ngit clone https://github.com/yourusername/VardaGPT.git\ncd VardaGPT\n```\n\n2. Create and activate a virtual environment:\n\n```bash\npython -m venv venv\nsource venv/bin/activate\n```\n\n3. Install the required libraries:\n\n```bash\npip install -r requirements.txt\n```\n\n## Directory Structure\n\n- `src/`: Contains the Python source code for the project.\n- `data/`: Stores the datasets used for training and evaluation.\n- `models/`: Holds the trained models and their checkpoints.\n\n## Usage\n\n### Data Preparation\n\n1. Place your dataset in the `data/` directory.\n2. Preprocess and split your dataset into training, validation, and testing sets\n   using the provided scripts in `src/`.\n\n### Training\n\n1. Configure the training settings and model hyperparameters in the\n   `src/config.py` file.\n2. 
Run the training script:\n\n```bash\npython src/train.py\n```\n\n3. Monitor the training progress and save model checkpoints in the `models/`\n   directory.\n\n### Evaluation\n\n1. Evaluate the trained model on the validation and testing sets using the\n   provided evaluation script:\n\n```bash\npython src/evaluate.py\n```\n\n### Inference\n\n1. Use the provided inference script to generate text with the memory-enhanced\n   GPT-2 model:\n\n```bash\npython src/inference.py --prompt \"Your prompt text here\"\n```\n\n## Contributing\n\nFeel free to contribute to this project by submitting pull requests or opening\nissues for bug reports and feature requests.\n\n## Code Formatting and Pre-commit\n\nThis project uses `black`, `flake8`, and `mypy` for Python code formatting and\nlinting. We also use `prettier` to format JSON and Markdown files. The\nconfiguration for these tools is in the `.pre-commit-config.yaml` file.\n\n### Setup\n\n1. Install `pre-commit` if you haven't already:\n\n```bash\npip install pre-commit\n```\n\n2. Set up the git hooks:\n\n```bash\npre-commit install\n```\n\n### Using Pre-commit\n\nWhenever you commit changes, the pre-commit hooks will automatically format your\ncode and check for issues. If the hooks detect any problems, the commit will be\naborted, and you'll see a list of issues that need to be fixed. Once you've\nresolved the issues, you can try committing again.\n\nYou can also run the pre-commit hooks manually on all files:\n\n```bash\npre-commit run --all-files\n```\n\nOr run the hooks on specific files:\n\n```bash\npre-commit run --files <file1> <file2>\n```\n\nBy following this setup and using pre-commit hooks, you can ensure that the code\nin the repository remains consistently formatted and adheres to the project's\ncoding standards.\n\n## License\n\nThis project is licensed under the [MIT License](LICENSE).\n"
  },
  {
    "path": "STORY.md",
    "content": "# Story of this project 😅\n\n## Background 🤔\n\nWith all the hype around ChatGPT, I wondered how much impact ChatGPT really had.\nI mean, for a programmer, would ChatGPT be like a pair programmer? Like GitHub\nCopilot++? Or would ChatGPT totally replace programmers so that product managers\ncould tell it what feature to build, and it would just build it!\n\nImagine a bunch of product managers sitting in a sprint planning meeting where,\nafter signing off on the tasks to be done this sprint and starting the sprint,\nChatGPT was deployed on those tasks. The sprint lasted for about 2 hours, and\neveryone met again to do the next day's sprint grooming. 😆\n\n## Project Idea 💡\n\nNow, what the heck should I build to test this? Why not try attaching a memory\nmodule to a GPT? I've seen some folks on the internet complain about the \"low\nmemory\" problem of language models. I've also used FAISS and FLANN before, so I\nam familiar with how to technically achieve this. Whether it will actually work\nor not—well, my 1080Ti is on its deathbed with a broken fan, and I don't have\nthe 💸 to train this thing on AWS anyway. Let's aim for unit tests to work then.\n\n## Process 🏃\n\nOkay.\n\nI started with the project plan:\n\n![start](./assets/1.png)\n\nThen I made ChatGPT generate the project foundations, step by step, from\ncreating project directories, Makefile, README, pre-commit, vscode settings for\nthe same tools in pre-commit, setup.cfg, and a GitHub workflow to run tests. In\neach case, I had to specify exactly what I wanted it to generate.\n\n![start](./assets/2.png)\n\nYes, I made ChatGPT choose the project name as well.\n\n![start](./assets/3.png)\n\nIf I forgot something, I would go back and ask ChatGPT to add it:\n\n![start](./assets/4.png)\n\nIn fact, I found it better to let ChatGPT generate a toy-ish version of the code\nfirst, then let it add things to it step-by-step. 
This resulted in much better\noutput than, say, asking ChatGPT to generate production-quality code with all\nfeatures in the first go. This also gave me a way to break down my requirements\nand feed them one at a time - as I was also acting as a code-reviewer for the\ngenerated output, and so this method was also easier for me to work with.\n\n![start](./assets/5.png) ![start](./assets/6.png) ![start](./assets/7.png)\n![start](./assets/8.png) ![start](./assets/9.png)\n\nOf course, I made ChatGPT write unit tests, and if they failed, I would just\ncopy the pytest output and feed it back into ChatGPT.\n\n![start](./assets/10.png) ![start](./assets/11.png)\n\nChatGPT even figured this out!:\n\n![start](./assets/12.png)\n\nThe result - I present to you\n[VardaGPT](https://github.com/ixaxaar/vardagpt)—every inch of this repository\nwas generated by ChatGPT-4! It took a few hours, mostly around 3 weekends,\nmostly at odd times, to generate this project.\n\n## Experience 😮\n\nIt felt neither like a Copilot++ nor like the product manager scenario but\nrather all at the same time. Sometimes I was amazed at what ChatGPT was able to\nunderstand, sometimes I had to stubbornly push it to go in a certain direction,\nsometimes it generated things I did not think of, sometimes I got super\nfrustrated while making ChatGPT fix the code in a certain way.\n\nIt was more like handholding a fresh grad who had absorbed all of human\nknowledge but needed someone to tie various parts of that knowledge to create\nsomething useful. Also ChatGPT is bad at dealing with abstractions beyond 2\nlayers.\n\nChatGPT is definitely a productivity multiplier. I think it is rather a\ndifferential productivity multiplier, as it would enhance more the capabilities\nof those who already know more. If I did not understand deep learning and FAISS,\nor how projects are structured, I don't think I would have been able to pull\nthis off. 
On the other hand, it also has some sort of a leveling effect—I have\nnot worked on PyTorch in a while, have no idea of FAISS's new APIs, etc., but\nthese gaps were filled in by ChatGPT.\n\nFinally, it was also tiring. Imagine being reduced to giving only instructions\nand doing code review. Reading and understanding code is tiring!\n\n## Conclusion ❓\n\nIt looks like my job is safe this year. Time to generate an elaborate software\nproject and claim supremacy on my ChatGPT usage abilities to hedge against next\nyear.\n\nI wonder if by the time ChatGPT-6 comes out, would engineering teams be like,\n\"Hey, let's generate our own Grafana with a purple theme 😄\".\n\n## Aside 🦄\n\nI could not resist but add this bit. ChatGPT is great at generating Agda! Maybe\nthis would also be the ultimate tool that can be used to formalize all of pure\nmath? 😱\n\n![start](./assets/13.png)\n"
  },
  {
    "path": "requirements.txt",
    "content": "black==23.3.0\ncertifi==2022.12.7\ncharset-normalizer==3.1.0\nclick==8.1.3\ncmake==3.26.3\ncoverage==7.2.3\nexceptiongroup==1.1.1\nfaiss-cpu==1.7.3\nfilelock==3.12.0\nflake8==6.0.0\nhuggingface-hub==0.13.4\nidna==3.4\niniconfig==2.0.0\nJinja2==3.1.2\nlit==16.0.1\nmarkdown-it-py==2.2.0\nMarkupSafe==2.1.2\nmccabe==0.7.0\nmdurl==0.1.2\nmpmath==1.3.0\nmypy==1.2.0\nmypy-extensions==1.0.0\nnetworkx==3.1\nnumpy==1.24.2\nnvidia-cublas-cu11==11.10.3.66\nnvidia-cuda-cupti-cu11==11.7.101\nnvidia-cuda-nvrtc-cu11==11.7.99\nnvidia-cuda-runtime-cu11==11.7.99\nnvidia-cudnn-cu11==8.5.0.96\nnvidia-cufft-cu11==10.9.0.58\nnvidia-curand-cu11==10.2.10.91\nnvidia-cusolver-cu11==11.4.0.1\nnvidia-cusparse-cu11==11.7.4.91\nnvidia-nccl-cu11==2.14.3\nnvidia-nvtx-cu11==11.7.91\npackaging==23.1\npathspec==0.11.1\nplatformdirs==3.2.0\npluggy==1.0.0\npycodestyle==2.10.0\npydantic==1.10.7\npyflakes==3.0.1\nPygments==2.15.1\npytest==7.3.1\npytest-cov==4.0.0\nPyYAML==6.0\nregex==2023.3.23\nrequests==2.28.2\nrich==13.3.5\nsympy==1.11.1\ntokenizers==0.13.3\ntomli==2.0.1\ntorch==2.0.0\ntorchdata==0.6.0\ntorchtext==0.15.1\ntqdm==4.65.0\ntransformers==4.28.1\ntriton==2.0.0\ntyping_extensions==4.5.0\nurllib3==1.26.15\n"
  },
  {
    "path": "setup.cfg",
    "content": "[flake8]\nexclude = */__init__.py,migrations/*\nignore = E111, E114, E121, E131, W503, F405, F403, E126, E501, F841, E124, E251, E203\nmax-line-length = 120\nmax-doc-length = 120\nshow-source = true\nstatistics = false\ndoctests = true\n\n[tool.black]\nline-length = 120\n\n[mypy]\nignore_missing_imports = True\n\n[mypy-tests.*]\ndisallow_untyped_defs = False\n\n[tool:pytest]\naddopts = -p no:warnings\nignore = tests\n"
  },
  {
    "path": "src/data.py",
    "content": "from typing import Any, Tuple\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchtext.data import get_tokenizer\nfrom torchtext.datasets import WikiText2\nfrom transformers import GPT2Tokenizer\n\n\ndef load_wikitext2() -> Tuple[DataLoader[Any], DataLoader[Any], DataLoader[Any]]:\n    \"\"\"\n    Load the WikiText-2 dataset for training, validation, and testing.\n\n    :return: A tuple of three DataLoaders for train, valid, and test sets.\n    \"\"\"\n\n    # Define a tokenizer using the GPT2Tokenizer from the transformers library\n    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n\n    # Define a tokenize function using the GPT2Tokenizer\n    def tokenize(text: str) -> torch.Tensor:\n        return tokenizer.encode(text, return_tensors=\"pt\").squeeze(0)  # type: ignore\n\n    # Use torchtext's get_tokenizer function to create a tokenizer pipeline\n    text_pipeline = get_tokenizer(tokenize)\n\n    # Load the WikiText-2 dataset\n    train_data, valid_data, test_data = WikiText2()\n\n    # Create DataLoaders for the train, valid, and test sets\n    train_loader = DataLoader([text_pipeline(text) for text in train_data], batch_size=1, shuffle=True)  # type: ignore\n    valid_loader = DataLoader([text_pipeline(text) for text in valid_data], batch_size=1, shuffle=True)  # type: ignore\n    test_loader = DataLoader([text_pipeline(text) for text in test_data], batch_size=1, shuffle=True)  # type: ignore\n\n    return train_loader, valid_loader, test_loader\n"
  },
  {
    "path": "src/memory/__init__.py",
    "content": "from typing import Any, Optional\n\nimport numpy as np\nimport numpy.typing as npt\nfrom faiss import IndexIVF\n\n\nclass Memory:\n    memory_size: int = 0\n    embedding_dim: int = 0\n    index: IndexIVF\n\n    def add(self, embeddings: npt.NDArray[Any]) -> None:\n        self.index.add(embeddings)\n\n    def remove(self, ids: npt.NDArray[Any]) -> None:\n        self.index.remove_ids(ids)\n\n    def update(self, ids: npt.NDArray[Any], updated_embeddings: npt.NDArray[Any]) -> None:\n        self.remove(ids)\n        self.add(updated_embeddings)\n\n    def search(self, query_vectors: npt.NDArray[Any], k: int = 10) -> Any:\n        raise NotImplementedError()\n\n    def refresh(self) -> None:\n        self.index.reset()\n"
  },
  {
    "path": "src/memory/associative.py",
    "content": "from typing import Any, Union\n\nimport faiss\nimport numpy as np\nimport numpy.typing as npt\nimport torch\n\nfrom . import Memory\n\n\nclass AssociativeMemory(Memory):\n    def __init__(\n        self,\n        memory_size: int,\n        embedding_dim: int,\n        index_type: str = \"flat\",\n        num_clusters: int = 1024,\n        m: int = 8,\n        ef_construction: int = 100,\n        ef_search: int = 64,\n        use_gpu: bool = False,\n        gpu_device: int = 0,\n        forgetfulness_factor: float = 0.0001,\n    ):\n        \"\"\"\n        Initialize the associative memory.\n\n        :param memory_size: The maximum number of items the memory can store.\n        :param embedding_dim: The dimensionality of the input embeddings.\n        :param index_type: The type of FAISS index to use (default: 'flat').\n        :param num_clusters: The number of clusters to use for an IVF index (default: 1024).\n        :param m: The number of product quantization codes to use for an IVFPQ index (default: 8).\n        :param ef_construction: The size of the entry list for an HNSW index (default: 100).\n        :param ef_search: The search list size for an HNSW index (default: 64).\n        :param use_gpu: Whether to use GPU acceleration for the FAISS index (default: False).\n        :param gpu_device: The ID of the GPU device to use (default: 0).\n        :param forgetfulness_factor: The percentage of items to remove during random forgetting (default: 1).\n        \"\"\"\n        # Initialize memory parameters\n        self.memory_size = memory_size\n        self.embedding_dim = embedding_dim\n        self.forgetfulness_factor = forgetfulness_factor\n        self.use_gpu = use_gpu\n        self.gpu_device = gpu_device\n\n        # Create the appropriate Faiss index based on the specified type\n        if index_type == \"flat\":\n            # Inverted File with Flat index - a compressed index with an inverted file structure\n            quantizer = 
faiss.IndexFlatL2(embedding_dim)\n            index = faiss.IndexIVFFlat(quantizer, embedding_dim, num_clusters, faiss.METRIC_L2)\n            index.make_direct_map()\n            index.set_direct_map_type(faiss.DirectMap.Hashtable)\n        elif index_type == \"compressed\":\n            # Inverted File with Product Quantization index - a compressed index with a product quantization compression\n            quantizer = faiss.IndexFlatL2(embedding_dim)\n            index = faiss.IndexIVFPQ(quantizer, embedding_dim, num_clusters, m, 8)\n        elif index_type == \"graph\":\n            # Hierarchical Navigable Small World index - a graph-based index with a flat storage\n            index = faiss.IndexHNSWFlat(embedding_dim, ef_construction, faiss.METRIC_L2)\n            index.hnsw.efSearch = ef_search\n        else:\n            raise ValueError(f\"Invalid index_type: {index_type}\")\n        self.index_type = index_type\n\n        # Enable GPU support if specified\n        if use_gpu:\n            self.res = faiss.StandardGpuResources()\n            self.index = faiss.index_cpu_to_gpu(self.res, gpu_device, index)\n        else:\n            self.index = index\n\n        # Train the index with empty data\n        self.index.train(np.zeros((max(memory_size, num_clusters), embedding_dim), dtype=np.float32))\n\n        # Initialize an empty array to store the input vectors\n        self.input_vectors = np.zeros((memory_size, embedding_dim), dtype=np.float32)\n\n    def _to_numpy(self, data: Union[npt.NDArray[Any], torch.Tensor]) -> Union[npt.NDArray[Any], torch.Tensor]:\n        \"\"\"\n        Convert input data to a NumPy array if it's a PyTorch tensor and not using GPU.\n\n        :param data: Input data to be converted. 
Can be either a NumPy array or a PyTorch tensor.\n        :return: Converted data as a NumPy array or a PyTorch tensor.\n        \"\"\"\n        if isinstance(data, torch.Tensor) and not self.use_gpu:\n            return data.detach().cpu().numpy()  # type: ignore\n        return data\n\n    def add(self, embeddings: Union[npt.NDArray[Any], torch.Tensor]) -> None:\n        \"\"\"\n        Add embeddings to the memory.\n\n        :param embeddings: A 2D array of shape (n, embedding_dim) containing the embeddings to be added,\n                           where n is the number of items to add. Can be either a NumPy array or a PyTorch tensor.\n        \"\"\"\n        embeddings = self._to_numpy(embeddings)\n        n_added = self.index.ntotal  # Existing number of added items\n        n_to_add = embeddings.shape[0]  # Number of items to add\n\n        # Update the input_vectors array with the new embeddings\n        self.input_vectors[n_added : n_added + n_to_add] = embeddings\n\n        # Add embeddings to the index\n        if self.use_gpu:\n            ptr = faiss.torch_utils.swig_ptr_from_FloatTensor(embeddings)\n            self.index.add_c(n_to_add, ptr)\n        else:\n            self.index.add(embeddings)\n\n    def remove(self, ids: Union[npt.NDArray[Any], torch.Tensor]) -> None:\n        \"\"\"\n        Remove embeddings with the specified IDs from the memory.\n\n        :param ids: A 1D array of shape (n,) containing the indices of the items to be removed,\n                    where n is the number of items to remove. 
Can be either a NumPy array or a PyTorch tensor.\n        \"\"\"\n        ids = self._to_numpy(ids)\n\n        if self.index_type != \"flat\":\n            raise ValueError(\n                f\"Remove is not implemented in FAISS for this type of index, use flat instead of: {self.index_type}\"\n            )\n        self.index.remove_ids(ids)\n\n        # Remove input vectors from the input_vectors array\n        self.input_vectors = np.delete(self.input_vectors, ids, axis=0)\n\n    def update(\n        self, ids: Union[npt.NDArray[Any], torch.Tensor], updated_embeddings: Union[npt.NDArray[Any], torch.Tensor]\n    ) -> None:\n        \"\"\"\n        Update embeddings with the specified IDs in the memory.\n\n        :param ids: A 1D array of shape (n,) containing the indices of the items to be updated,\n                    where n is the number of items to update. Can be either a NumPy array or a PyTorch tensor.\n        :param updated_embeddings: A 2D array of shape (n, embedding_dim) containing the updated embeddings,\n                    where n is the number of items to update. 
Can be either a NumPy array or a PyTorch tensor.\n        \"\"\"\n        ids = self._to_numpy(ids)\n        updated_embeddings = self._to_numpy(updated_embeddings)\n\n        if self.index_type != \"flat\":\n            raise ValueError(\n                f\"Update is not implemented in FAISS for this type of index, use flat instead of: {self.index_type}\"\n            )\n        self.remove(ids)\n        self.add(updated_embeddings)\n\n    def search(self, query_vectors: Any, k: int = 10) -> Any:\n        \"\"\"\n        Search the memory for the top k closest embeddings to the query vectors.\n\n        :param query_vectors: A 2D array or tensor of shape (n, embedding_dim) containing the query vectors,\n                            where n is the number of query vectors.\n        :param k: The number of nearest neighbors to return for each query.\n        :return: A tuple containing two 2D arrays or tensors for indices and distances, both of shape (n, k).\n        \"\"\"\n        if self.use_gpu and isinstance(query_vectors, torch.Tensor):\n            n_query = query_vectors.shape[0]\n            distances = torch.empty((n_query, k), device=self.gpu_device)  # type: ignore\n            indices = torch.empty((n_query, k), dtype=torch.long, device=self.gpu_device)  # type: ignore\n\n            ptr_query = faiss.torch_utils.swig_ptr_from_FloatTensor(query_vectors)\n            ptr_distances = faiss.torch_utils.swig_ptr_from_FloatTensor(distances)\n            ptr_indices = faiss.torch_utils.swig_ptr_from_LongTensor(indices)\n            self.index.search_c(n_query, ptr_query, k, ptr_distances, ptr_indices)\n\n            return indices, distances\n        else:\n            was_tensor = isinstance(query_vectors, torch.Tensor)\n            if was_tensor:\n                query_vectors = query_vectors.numpy()\n            distances, indices = self.index.search(query_vectors, k)\n            if was_tensor:\n                return torch.from_numpy(indices), 
torch.from_numpy(distances)\n            else:\n                return indices, distances\n\n    def age_memory(self, decay_factor: float = 0.999) -> None:\n        \"\"\"\n        Age the memory embeddings by multiplying them with a decay factor.\n\n        :param decay_factor: float, factor to multiply the embeddings with (default: 0.999)\n        \"\"\"\n        assert 0 <= decay_factor <= 1, \"Decay factor should be between 0 and 1.\"\n\n        # Get the current embeddings from the memory\n        current_embeddings = np.zeros((self.index.ntotal, self.embedding_dim), dtype=np.float32)\n        for idx in range(self.index.ntotal):\n            current_embeddings[idx] = self.index.reconstruct(idx)\n\n        # Apply the decay factor\n        aged_embeddings = current_embeddings * decay_factor\n\n        # Update the memory with the aged embeddings\n        ids = np.arange(self.index.ntotal, dtype=np.int64)\n        self.update(ids, aged_embeddings)\n\n    def forget_randomly(self) -> None:\n        \"\"\"\n        Remove a random subset of embeddings from the memory based on the forgetfulness factor.\n        \"\"\"\n        assert 0 <= self.forgetfulness_factor <= 1, \"Forgetfulness factor should be between 0 and 1.\"\n\n        total_embeddings = self.index.ntotal\n        num_embeddings_to_forget = int(total_embeddings * self.forgetfulness_factor)\n\n        # Select random embeddings to forget\n        ids_to_forget = np.random.choice(total_embeddings, size=num_embeddings_to_forget, replace=False)\n\n        # Remove the selected embeddings from the memory\n        self.remove(ids_to_forget)\n\n    def garbage_collect(self, threshold: float = 1e-6) -> npt.NDArray[Any]:\n        \"\"\"\n        Remove nearly zero vectors from the memory and return the indices of the empty vectors.\n\n        Parameters:\n        threshold (float): Threshold value to consider a vector nearly zero.\n\n        Returns:\n        npt.NDArray[Any]: Indices of the empty vectors.\n    
    \"\"\"\n        # Fetch all embeddings from the memory\n        embeddings = self.get_all_embeddings()\n\n        # Calculate the L2 norms of the embeddings\n        norms = np.linalg.norm(embeddings, axis=1)\n\n        # Identify nearly zero vectors based on the threshold\n        nearly_zero_vectors = np.where(norms < threshold)[0]\n\n        # Remove nearly zero vectors from the memory\n        self.remove(nearly_zero_vectors)\n\n        return nearly_zero_vectors\n\n    def get_all_embeddings(self) -> npt.NDArray[Any]:\n        \"\"\"\n        Retrieve all embeddings stored in the memory.\n\n        Returns:\n        npt.NDArray[Any]: A 2D array of shape (n, embedding_dim) containing all stored embeddings,\n                          where n is the number of stored items.\n        \"\"\"\n        # Return the stored input_vectors directly\n        return self.input_vectors[: self.index.ntotal]  # Use slicing to get only the added items\n\n    def __getitem__(self, index: int) -> Any:\n        \"\"\"\n        Retrieve the input vector at the specified index.\n\n        :param index: The index of the input vector to retrieve.\n        :return: A 1D array of shape (embedding_dim,) containing the input vector.\n        \"\"\"\n        if index >= self.index.ntotal:\n            raise IndexError(\"Index out of range.\")\n        return self.input_vectors[index]\n\n    def size(self) -> int:\n        \"\"\"\n        Get the number of items stored in the memory.\n\n        :return: The number of items stored in the memory.\n        \"\"\"\n        return int(self.index.ntotal)\n"
  },
  {
    "path": "src/memory/batch_associative.py",
    "content": "from typing import Any, List, Tuple, Union\n\nimport numpy as np\nimport numpy.typing as npt\n\ntry:\n    import torch\n\n    _TORCH_AVAILABLE = True\nexcept ImportError:\n    _TORCH_AVAILABLE = False\n\nfrom .associative import AssociativeMemory\n\n\nclass BatchAssociativeMemory:\n    def __init__(\n        self,\n        num_batches: int,\n        memory_size: int,\n        embedding_dim: int,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Initialize a batch associative memory.\n\n        :param num_batches: Number of separate associative memories in this batch.\n        :param memory_size: The maximum number of items each memory can store.\n        :param embedding_dim: The dimensionality of the input embeddings.\n        :param kwargs: Additional arguments to be passed to the AssociativeMemory constructor.\n        \"\"\"\n        self.num_batches = num_batches\n        self.embedding_dim = embedding_dim\n        self.memories = [AssociativeMemory(memory_size, embedding_dim, **kwargs) for _ in range(num_batches)]\n\n    def batch_add(self, embeddings: Union[npt.NDArray[Any], \"torch.Tensor\"]) -> None:\n        \"\"\"\n        Perform a batch of add operations. Assumes a single token per batch.\n\n        :param embeddings: A 2D array of shape (num_batches, embedding_dim) containing the embeddings\n                           of single tokens to be added for each batch.\n        \"\"\"\n        for memory, batch_embeddings in zip(self.memories, embeddings):\n            memory.add(batch_embeddings.reshape(1, -1))\n\n    def batch_remove(self, ids: Union[npt.NDArray[Any], \"torch.Tensor\"]) -> None:\n        \"\"\"\n        Perform a batch of remove operations. 
Assumes a single token per batch.\n\n        :param ids: A 1D array of shape (num_batches,) containing the indices of the items to be removed.\n        \"\"\"\n        for memory, batch_id in zip(self.memories, ids):\n            memory.remove(np.array([batch_id], dtype=np.int64))\n\n    def batch_search(\n        self, query_vectors: Union[npt.NDArray[Any], \"torch.Tensor\"], k: int = 10\n    ) -> List[Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]]:\n        \"\"\"\n        Perform a batch of search operations. Assumes a single token per batch.\n\n        :param query_vectors: A 2D array of shape (num_batches, embedding_dim) containing the query vectors\n                              of single tokens for each batch.\n        :param k: The number of nearest neighbors to return for each query.\n        :return: A list of tuples, each containing two 1D arrays of shape (1,k) for indices and distances.\n        \"\"\"\n        results = []\n        for memory, batch_query_vector in zip(self.memories, query_vectors.reshape(-1, 1, self.embedding_dim)):\n            indices, distances = memory.search(batch_query_vector, k)\n            results.append((indices, distances))\n        return results\n"
  },
  {
    "path": "src/memory/multitoken_batch_associative.py",
    "content": "from typing import Any, List, Tuple, Union\n\nimport numpy as np\nimport numpy.typing as npt\n\ntry:\n    import torch\n\n    _TORCH_AVAILABLE = True\nexcept ImportError:\n    _TORCH_AVAILABLE = False\n\nfrom .associative import AssociativeMemory\n\n\nclass MultiTokenBatchAssociativeMemory:\n    def __init__(\n        self,\n        num_batches: int,\n        memory_size: int,\n        embedding_dim: int,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Initialize a batch associative memory.\n\n        :param num_batches: Number of separate associative memories in this batch.\n        :param memory_size: The maximum number of items each memory can store.\n        :param embedding_dim: The dimensionality of the input embeddings.\n        :param kwargs: Additional arguments to be passed to the AssociativeMemory constructor.\n        \"\"\"\n        self.num_batches = num_batches\n        self.embedding_dim = embedding_dim\n        self.memories = [AssociativeMemory(memory_size, embedding_dim, **kwargs) for _ in range(num_batches)]\n\n    def batch_add(self, embeddings: Union[npt.NDArray[Any], \"torch.Tensor\"]) -> None:\n        \"\"\"\n        Perform a batch of add operations.\n\n        :param embeddings: A 3D array of shape (num_batches, num_tokens, embedding_dim) containing the embeddings\n                           to be added for each batch.\n        \"\"\"\n        for memory, batch_embeddings in zip(self.memories, embeddings):\n            memory.add(batch_embeddings)\n\n    def batch_remove(self, ids: Union[npt.NDArray[Any], \"torch.Tensor\"]) -> None:\n        \"\"\"\n        Perform a batch of remove operations.\n\n        :param ids: A 2D array of shape (num_batches, num_tokens) containing the indices of the items to be removed.\n        \"\"\"\n        for memory, batch_ids in zip(self.memories, ids):\n            memory.remove(batch_ids)\n\n    def batch_search(\n        self, query_vectors: Union[npt.NDArray[Any], 
\"torch.Tensor\"], k: int = 10\n    ) -> List[Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]]:\n        \"\"\"\n        Perform a batch of search operations.\n\n        :param query_vectors: A 3D array of shape (num_batches, num_tokens, embedding_dim) containing the query vectors\n                              for each batch.\n        :param k: The number of nearest neighbors to return for each query.\n        :return: A list of tuples, each containing two 2D arrays of shape (num_tokens, k) for indices and distances.\n        \"\"\"\n        results = []\n        for memory, batch_query_vectors in zip(self.memories, query_vectors):\n            indices, distances = memory.search(batch_query_vectors, k)\n            results.append((indices, distances))\n        return results\n"
  },
  {
    "path": "src/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/models/gpt2_associative.py",
    "content": "from typing import Any\n\nimport torch\nimport torch.nn as nn\nfrom transformers import GPT2Config, GPT2LMHeadModel\n\nfrom memory.batch_associative import BatchAssociativeMemory\n\n\nclass VardaGPTAssociative(nn.Module):\n    def __init__(\n        self,\n        gpt2_model_name: str = \"gpt2\",\n        memory_size: int = 10000,\n        memory_dim: int = 768,\n        index_type: str = \"flat\",\n        num_clusters: int = 1024,\n        num_search_results: int = 5,\n        use_gpu: bool = False,\n        batch_size: int = 1,\n        forgetfulness_factor: float = 0.001,\n    ):\n        \"\"\"\n        Initialize a GPT-2 model with associative memory.\n\n        :param gpt2_model_name: The name of the GPT-2 model to load. Default is \"gpt2\".\n        :param memory_size: The maximum number of items the associative memory can store. Default is 10000.\n        :param memory_dim: The dimensionality of the embeddings stored in the associative memory. Default is 768.\n        :param index_type: The type of index used for the associative memory. Default is \"flat\".\n        :param num_clusters: The number of clusters to use for the memory if the index type is \"ivf\". Default is 1024.\n        :param num_search_results: The number of search results to return from the associative memory. Default is 5.\n        :param use_gpu: Whether to use the GPU for the model if available. 
Default is False.\n        \"\"\"\n        super(VardaGPTAssociative, self).__init__()\n\n        # Set up the device for the model\n        self.device = torch.device(\"cuda\" if use_gpu and torch.cuda.is_available() else \"cpu\")\n\n        # Load the GPT-2 model and configuration\n        self.gpt2_config = GPT2Config.from_pretrained(gpt2_model_name)\n        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)\n\n        # Initialize the BatchAssociativeMemory module\n        self.memory = BatchAssociativeMemory(\n            num_batches=batch_size,\n            memory_size=memory_size,\n            embedding_dim=memory_dim,\n            index_type=index_type,\n            num_clusters=num_clusters,\n            use_gpu=use_gpu,\n            forgetfulness_factor=forgetfulness_factor,\n        )\n\n        # Define dimensions for search results and output\n        self.search_results_dim = memory_dim * num_search_results\n\n        # Linear layers for concatenated input, storable vector, and store decision\n        self.fc = nn.Linear(self.gpt2_config.n_embd + self.search_results_dim, self.gpt2_config.n_embd)\n        self.fc_storable_vector = nn.Linear(self.gpt2_config.n_embd, memory_dim)\n        self.fc_store_decision = nn.Linear(self.gpt2_config.n_embd, 1)\n\n        # Move all layers to the device\n        self.to(self.device)\n\n        self.memory_dim = memory_dim\n        self.num_search_results = num_search_results\n        self.forgetfulness_factor = forgetfulness_factor\n\n    def forward(self, input_vectors: torch.Tensor) -> Any:\n        \"\"\"\n        Perform a forward pass through the GPT-2 model with associative memory.\n\n        :param input_vectors: A 3D tensor of shape (batch_size, sequence_length, input_dim) containing\n            the input vectors for each token in the batch.\n        :param memory_input: A 2D tensor of shape (batch_size, memory_dim) containing the memory input for each item\n            in the batch. 
If not provided, memory will not be used.\n        :return: A 3D tensor of shape (batch_size, sequence_length, vocab_size) containing\n        the logits from the GPT-2 model.\n        \"\"\"\n        input_vectors = input_vectors.to(self.device)\n        batch_size, seq_len, _ = input_vectors.shape\n\n        # Initialize search_results tensor with the correct shape\n        search_results = torch.zeros((batch_size, seq_len, self.search_results_dim), device=self.device)\n\n        # Search for relevant results for each item in the batch\n        for t in range(seq_len):\n            search_results_list = self.memory.batch_search(input_vectors[:, t, :].squeeze(1), self.num_search_results)\n            retrieved_embeddings_list = []\n            # Retrieve and concatenate search results with input vectors\n            for ctr, (indices, _) in enumerate(search_results_list):\n                retrieved_embeddings = torch.cat(\n                    [\n                        self.memory.memories[ctr][i].unsqueeze(0)\n                        if i >= 0\n                        else torch.zeros(self.memory_dim).unsqueeze(0)\n                        for i in indices.squeeze()\n                    ],\n                    dim=0,\n                )\n                # Update the corresponding search_results tensor\n                retrieved_embeddings_list.append(retrieved_embeddings)\n            search_results[:, t, :] = torch.cat(retrieved_embeddings_list, dim=0).view(batch_size, -1)\n\n        concatenated_input = torch.cat([input_vectors, search_results], dim=-1)\n\n        input_vectors = self.fc(concatenated_input)\n\n        # Pass input_vectors through GPT-2 model's transformer and obtain hidden states\n        transformer_outputs = self.gpt2_model.transformer(inputs_embeds=input_vectors)\n        hidden_states = transformer_outputs.last_hidden_state\n\n        # Get logits from hidden states\n        logits = self.gpt2_model.lm_head(hidden_states)\n\n        # 
Calculate storable vector and store decision\n        storable_vector = self.fc_storable_vector(hidden_states)\n        store_decision = self.fc_store_decision(hidden_states)\n\n        # Store the storable_vector in the associative memory if the store_decision is affirmative\n        store_threshold = 0.5  # Define a threshold for store decision\n        store_mask = (store_decision > store_threshold).float()\n        storable_vector_to_store = storable_vector * store_mask\n\n        for i in range(seq_len):\n            storable_vector_to_store_i = storable_vector_to_store[:, i, :].view(batch_size, -1).detach()\n            self.memory.batch_add(storable_vector_to_store_i)\n\n        # Randomly forget items from the memory with a specified probability\n        for memory in self.memory.memories:\n            memory.forget_randomly()\n\n        return logits\n"
  },
  {
    "path": "src/models/gpt2_working.py",
    "content": "from typing import Any, Optional\n\nimport torch\nimport torch.nn as nn\nfrom transformers import GPT2Config, GPT2LMHeadModel\n\nfrom memory.associative import AssociativeMemory\n\n\nclass VardaGPTWorking(nn.Module):\n    def __init__(\n        self,\n        gpt2_model_name: str = \"gpt2\",\n        memory_size: int = 10000,\n        memory_dim: int = 768,\n        index_type: str = \"flat\",\n        num_clusters: int = 1024,\n        num_search_results: int = 5,\n        use_gpu: bool = False,\n    ):\n        super(VardaGPTWorking, self).__init__()\n\n        # Set up the device for the model\n        self.device = torch.device(\"cuda\" if use_gpu and torch.cuda.is_available() else \"cpu\")\n\n        # Load the GPT-2 model and configuration\n        self.gpt2_config = GPT2Config.from_pretrained(gpt2_model_name)\n        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)\n\n        # Initialize the AssociativeMemory module\n        self.memory = AssociativeMemory(\n            memory_size=memory_size,\n            embedding_dim=memory_dim,\n            index_type=index_type,\n            num_clusters=num_clusters,\n            use_gpu=use_gpu,\n        )\n\n        # Define dimensions for search results and output\n        self.search_results_dim = memory_dim * num_search_results\n\n        # Linear layers for concatenated input, storable vector, store decision, delete decision, and deletable vector\n        self.fc = nn.Linear(self.gpt2_config.n_embd + self.search_results_dim, self.gpt2_config.n_embd)\n        self.fc_storable_vector = nn.Linear(self.gpt2_config.n_embd, memory_dim)\n        self.fc_store_decision = nn.Linear(self.gpt2_config.n_embd, 1)\n\n        # Move all layers to the device\n        self.to(self.device)\n\n        self.num_search_results = num_search_results\n\n    def forward(self, input_vectors: torch.Tensor, memory_input: Optional[torch.Tensor] = None) -> Any:\n        input_vectors = 
input_vectors.to(self.device)\n\n        # Search for relevant results if memory_input is provided\n        if memory_input is not None:\n            indices, distances = self.memory.search(memory_input.cpu().numpy())\n\n            # Retrieve and concatenate search results with input vectors\n            search_results = self.memory.get_all_embeddings()[indices].reshape(-1, self.search_results_dim)\n            search_results = torch.tensor(search_results).to(self.device)\n            concatenated_input = torch.cat([input_vectors, search_results], dim=-1)\n\n            # Pass concatenated input through linear layer\n            input_vectors = self.fc(concatenated_input)\n\n        # Pass input_vectors through GPT-2 model's transformer and obtain hidden states\n        transformer_outputs = self.gpt2_model.transformer(inputs_embeds=input_vectors)\n        hidden_states = transformer_outputs.last_hidden_state\n\n        # Get logits from hidden states\n        logits = self.gpt2_model.lm_head(hidden_states)\n\n        # Calculate storable vector, store decision, delete decision, and deletable vector\n        storable_vector = self.fc_storable_vector(hidden_states)\n        store_decision = self.fc_store_decision(hidden_states)\n\n        # Store the storable_vector in the associative memory if the store_decision is affirmative\n        store_threshold = 0.5  # Define a threshold for store decision\n        store_mask = (store_decision > store_threshold).float()\n        storable_vector_to_store = storable_vector * store_mask\n        self.memory.add(storable_vector_to_store)\n\n        return logits\n"
  },
  {
    "path": "src/train.py",
    "content": "import argparse\nimport time\nfrom typing import Any\n\nimport torch\nimport torch.optim as optim\nfrom rich.console import Console\nfrom rich.markdown import Markdown\nfrom rich.panel import Panel\nfrom rich.progress import track\nfrom rich.table import Table\nfrom rich.theme import Theme\nfrom torch.utils.data import DataLoader\n\nfrom data import load_wikitext2\nfrom models.gpt2_associative import VardaGPTAssociative\n\n\ndef train(\n    model: VardaGPTAssociative,\n    train_loader: DataLoader[Any],\n    valid_loader: DataLoader[Any],\n    epochs: int,\n    lr: float,\n    device: torch.device,\n) -> None:\n    \"\"\"\n    Train the model with the given training and validation data.\n\n    :param model: The VardaGPTAssociative model to be trained.\n    :param train_loader: DataLoader for the training data.\n    :param valid_loader: DataLoader for the validation data.\n    :param epochs: Number of epochs to train the model.\n    :param lr: Learning rate for the optimizer.\n    :param device: The device to use for training (CPU or GPU).\n    \"\"\"\n\n    # Initialize the optimizer and loss function\n    optimizer = optim.Adam(model.parameters(), lr=lr)\n    loss_function = torch.nn.CrossEntropyLoss()\n\n    # Create a console object for printing colorful prompts\n    theme = Theme({\"info\": \"dim green\", \"warning\": \"yellow\", \"error\": \"bold red\"})\n    console = Console(theme=theme)\n\n    # Training loop\n    for epoch in range(epochs):\n        model.train()\n        epoch_loss = 0.0\n        epoch_start_time = time.time()\n\n        for _, batch in enumerate(track(train_loader, description=f\"[bold][info]Epoch {epoch + 1}\")):\n            # Move the input to the device\n            input_vectors = batch.to(device)\n\n            # Zero the gradients\n            optimizer.zero_grad()\n\n            # Forward pass\n            logits = model(input_vectors)\n\n            # Calculate loss\n            loss = 
loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))\n\n            # Backward pass\n            loss.backward()\n\n            # Update weights\n            optimizer.step()\n\n            epoch_loss += loss.item()\n\n        # Calculate average epoch loss\n        average_epoch_loss = epoch_loss / len(train_loader)\n        epoch_time = time.time() - epoch_start_time\n\n        # Validation\n        model.eval()\n        valid_loss = 0.0\n        with torch.no_grad():\n            for _, batch in enumerate(valid_loader):\n                input_vectors = batch.to(device)\n                logits = model(input_vectors)\n                loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))\n                valid_loss += loss.item()\n\n        # Calculate average validation loss\n        average_valid_loss = valid_loss / len(valid_loader)\n\n        # Print epoch summary\n        table = Table(title=f\"Epoch {epoch + 1} Summary\")\n        table.add_column(\"Metric\", style=\"bold\")\n        table.add_column(\"Value\", style=\"bold\")\n        table.add_row(\"Training Loss\", f\"{average_epoch_loss:.4f}\")\n        table.add_row(\"Validation Loss\", f\"{average_valid_loss:.4f}\")\n        table.add_row(\"Time\", f\"{epoch_time:.2f} seconds\")\n        console.print(table)\n\n\nif __name__ == \"__main__\":\n    console = Console()\n\n    console.print(Panel.fit(\"[bold blue]VardaGPTAssociative Training Script[/bold blue]\"))\n\n    description = \"\"\"\\\nThis script trains a VardaGPTAssociative model on the WikiText-2 dataset. 
The model combines GPT-2 with an associative memory to improve context retrieval.\n\"\"\"\n    console.print(Markdown(description))\n\n    parser = argparse.ArgumentParser(description=\"Train VardaGPTAssociative model on WikiText-2 dataset\")\n    parser.add_argument(\"--epochs\", type=int, default=5, help=\"Number of epochs to train the model\")\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4, help=\"Learning rate for the optimizer\")\n    parser.add_argument(\n        \"--memory_size\", type=int, default=10000, help=\"Maximum number of items the associative memory can store\"\n    )\n    parser.add_argument(\n        \"--memory_dim\", type=int, default=768, help=\"Dimensionality of the embeddings stored in the associative memory\"\n    )\n    parser.add_argument(\"--index_type\", type=str, default=\"flat\", help=\"Type of index used for the associative memory\")\n    parser.add_argument(\n        \"--num_clusters\",\n        type=int,\n        default=1024,\n        help=\"Number of clusters to use for the memory if the index type is 'ivf'\",\n    )\n    parser.add_argument(\n        \"--num_search_results\",\n        type=int,\n        default=5,\n        help=\"Number of search results to return from the associative memory\",\n    )\n    parser.add_argument(\"--use_gpu\", action=\"store_true\", help=\"Whether to use the GPU for the model if available\")\n    parser.add_argument(\"--batch_size\", type=int, default=1, help=\"Batch size for training\")\n    parser.add_argument(\n        \"--forgetfulness_factor\", type=float, default=0.001, help=\"Forgetfulness factor for the associative memory\"\n    )\n\n    args = parser.parse_args()\n\n    console.print(\"[bold green]Training settings:[/bold green]\")\n    console.print(f\"  Epochs: {args.epochs}\")\n    console.print(f\"  Learning rate: {args.learning_rate}\")\n\n    model = VardaGPTAssociative(\n        gpt2_model_name=\"gpt2\",\n        memory_size=args.memory_size,\n        
memory_dim=args.memory_dim,\n        index_type=args.index_type,\n        num_clusters=args.num_clusters,\n        num_search_results=args.num_search_results,\n        use_gpu=args.use_gpu,\n        batch_size=args.batch_size,\n        forgetfulness_factor=args.forgetfulness_factor,\n    )\n\n    # Move the model to the device\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.use_gpu else \"cpu\")\n    model.to(device)\n\n    train_loader, valid_loader, test_loader = load_wikitext2()\n\n    # Train the model\n    train(model, train_loader, valid_loader, args.epochs, args.learning_rate, device)\n"
  },
  {
    "path": "src/train_parallel.py",
    "content": "import argparse\nimport time\nfrom typing import Any, Union\n\nimport torch\nimport torch.optim as optim\nfrom rich.console import Console\nfrom rich.markdown import Markdown\nfrom rich.panel import Panel\nfrom rich.progress import track\nfrom rich.table import Table\nfrom rich.theme import Theme\nfrom torch.utils.data import DataLoader\n\nfrom data import load_wikitext2\nfrom models.gpt2_associative import VardaGPTAssociative\n\n\ndef train(\n    model: Union[VardaGPTAssociative, torch.nn.DataParallel],\n    train_loader: DataLoader[Any],\n    valid_loader: DataLoader[Any],\n    epochs: int,\n    lr: float,\n    device: torch.device,\n) -> None:\n    \"\"\"\n    Train the model with the given training and validation data.\n\n    :param model: The VardaGPTAssociative model to be trained.\n    :param train_loader: DataLoader for the training data.\n    :param valid_loader: DataLoader for the validation data.\n    :param epochs: Number of epochs to train the model.\n    :param lr: Learning rate for the optimizer.\n    :param device: The device to use for training (CPU or GPU).\n    \"\"\"\n\n    # Initialize the optimizer and loss function\n    optimizer = optim.Adam(model.parameters(), lr=lr)\n    loss_function = torch.nn.CrossEntropyLoss()\n\n    # Create a console object for printing colorful prompts\n    theme = Theme({\"info\": \"dim green\", \"warning\": \"yellow\", \"error\": \"bold red\"})\n    console = Console(theme=theme)\n\n    # Training loop\n    for epoch in range(epochs):\n        model.train()\n        epoch_loss = 0.0\n        epoch_start_time = time.time()\n\n        for _, batch in enumerate(track(train_loader, description=f\"[bold][info]Epoch {epoch + 1}\")):\n            # Move the input to the device\n            input_vectors = batch.to(device)\n\n            # Zero the gradients\n            optimizer.zero_grad()\n\n            # Forward pass\n            logits = model(input_vectors)\n\n            # Calculate loss\n          
  loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))\n\n            # Backward pass\n            loss.backward()\n\n            # Update weights\n            optimizer.step()\n\n            epoch_loss += loss.item()\n\n        # Calculate average epoch loss\n        average_epoch_loss = epoch_loss / len(train_loader)\n        epoch_time = time.time() - epoch_start_time\n\n        # Validation\n        model.eval()\n        valid_loss = 0.0\n        with torch.no_grad():\n            for _, batch in enumerate(valid_loader):\n                input_vectors = batch.to(device)\n                logits = model(input_vectors)\n                loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))\n                valid_loss += loss.item()\n\n        # Calculate average validation loss\n        average_valid_loss = valid_loss / len(valid_loader)\n\n        # Print epoch summary\n        table = Table(title=f\"Epoch {epoch + 1} Summary\")\n        table.add_column(\"Metric\", style=\"bold\")\n        table.add_column(\"Value\", style=\"bold\")\n        table.add_row(\"Training Loss\", f\"{average_epoch_loss:.4f}\")\n        table.add_row(\"Validation Loss\", f\"{average_valid_loss:.4f}\")\n        table.add_row(\"Time\", f\"{epoch_time:.2f} seconds\")\n        console.print(table)\n\n\nif __name__ == \"__main__\":\n    console = Console()\n\n    console.print(Panel.fit(\"[bold blue]VardaGPTAssociative Training Script[/bold blue]\"))\n\n    description = \"\"\"\\\nThis script trains a VardaGPTAssociative model on the WikiText-2 dataset. 
The model combines GPT-2 with an associative memory to improve context retrieval.\n\"\"\"\n    console.print(Markdown(description))\n\n    parser = argparse.ArgumentParser(description=\"Train VardaGPTAssociative model on WikiText-2 dataset\")\n    parser.add_argument(\"--epochs\", type=int, default=5, help=\"Number of epochs to train the model\")\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4, help=\"Learning rate for the optimizer\")\n    parser.add_argument(\n        \"--memory_size\", type=int, default=10000, help=\"Maximum number of items the associative memory can store\"\n    )\n    parser.add_argument(\n        \"--memory_dim\", type=int, default=768, help=\"Dimensionality of the embeddings stored in the associative memory\"\n    )\n    parser.add_argument(\"--index_type\", type=str, default=\"flat\", help=\"Type of index used for the associative memory\")\n    parser.add_argument(\n        \"--num_clusters\",\n        type=int,\n        default=1024,\n        help=\"Number of clusters to use for the memory if the index type is 'ivf'\",\n    )\n    parser.add_argument(\n        \"--num_search_results\",\n        type=int,\n        default=5,\n        help=\"Number of search results to return from the associative memory\",\n    )\n    parser.add_argument(\"--use_gpu\", action=\"store_true\", help=\"Whether to use the GPU for the model if available\")\n    parser.add_argument(\"--batch_size\", type=int, default=1, help=\"Batch size for training\")\n    parser.add_argument(\n        \"--forgetfulness_factor\", type=float, default=0.001, help=\"Forgetfulness factor for the associative memory\"\n    )\n\n    args = parser.parse_args()\n\n    console.print(\"[bold green]Training settings:[/bold green]\")\n    console.print(f\"  Epochs: {args.epochs}\")\n    console.print(f\"  Learning rate: {args.learning_rate}\")\n\n    model = VardaGPTAssociative(\n        gpt2_model_name=\"gpt2\",\n        memory_size=args.memory_size,\n        
memory_dim=args.memory_dim,\n        index_type=args.index_type,\n        num_clusters=args.num_clusters,\n        num_search_results=args.num_search_results,\n        use_gpu=args.use_gpu,\n        batch_size=args.batch_size,\n        forgetfulness_factor=args.forgetfulness_factor,\n    )\n\n    # Move the model to the device\n    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.use_gpu else \"cpu\")\n\n    if torch.cuda.device_count() > 1 and args.use_gpu:\n        print(f\"Using {torch.cuda.device_count()} GPUs for training.\")\n        model = torch.nn.DataParallel(model)  # type: ignore\n\n    model.to(device)\n\n    train_loader, valid_loader, test_loader = load_wikitext2()\n\n    # Train the model\n    train(model, train_loader, valid_loader, args.epochs, args.learning_rate, device)\n"
  },
  {
    "path": "test/__init__.py",
    "content": ""
  },
  {
    "path": "test/memory/__init__..py",
    "content": ""
  },
  {
    "path": "test/memory/test_associative.py",
    "content": "# type: ignore\nimport numpy as np\nimport torch\nimport pytest\n\nfrom src.memory.associative import AssociativeMemory\n\n\n@pytest.fixture(params=[\"numpy\", \"torch\"])\ndef embeddings(request):\n    embeddings_np = np.random.rand(1000, 768).astype(np.float32)\n    if request.param == \"numpy\":\n        return embeddings_np\n    else:\n        return torch.from_numpy(embeddings_np)\n\n\ndef test_add_and_search(embeddings):\n    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)\n\n    memory.add(embeddings)\n\n    query_vectors = embeddings[:5]\n    indices, distances = memory.search(query_vectors)\n\n    assert indices.shape == (5, 10)\n\n\ndef test_update(embeddings):\n    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)\n\n    memory.add(embeddings)\n\n    element_id = np.array([0], dtype=np.int64)\n    updated_embedding = embeddings[:1]\n    memory.update(element_id, updated_embedding)\n\n    query_vectors = embeddings[1:2]\n    indices, distances = memory.search(query_vectors)\n\n    assert 0 not in indices[0]\n\n\ndef test_refresh_memory(embeddings):\n    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)\n\n    memory.add(embeddings)\n\n    memory.refresh()\n\n    query_vectors = embeddings[:5]\n    indices, distances = memory.search(query_vectors)\n\n    assert indices.shape == (5, 10)\n\n\ndef test_getitem(embeddings):\n    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)\n\n    memory.add(embeddings)\n\n    # Test __getitem__ method\n    retrieved_vector = memory[5]\n    assert np.allclose(retrieved_vector, embeddings[5])\n"
  },
  {
    "path": "test/memory/test_associative_batch.py",
    "content": "# type: ignore\nimport numpy as np\nimport pytest\n\ntry:\n    import torch\nexcept ImportError:\n    pass\n\nfrom src.memory.batch_associative import BatchAssociativeMemory\n\n\n@pytest.fixture\ndef batch_memory():\n    num_batches = 3\n    memory_size = 1000\n    embedding_dim = 128\n    return BatchAssociativeMemory(num_batches, memory_size, embedding_dim)\n\n\n@pytest.fixture(params=[\"numpy\", \"torch\"])\ndef tensor_type(request):\n    return request.param\n\n\ndef create_tensor(tensor_type, data):\n    if tensor_type == \"numpy\":\n        return data\n    else:\n        return torch.from_numpy(data)\n\n\ndef test_batch_add(batch_memory, tensor_type):\n    # Create integer batched embeddings\n    embeddings_np = np.random.randint(0, 10, (batch_memory.num_batches, batch_memory.embedding_dim)).astype(np.float32)\n    embeddings = create_tensor(tensor_type, embeddings_np)\n\n    # Add embeddings to the batch memory\n    batch_memory.batch_add(embeddings)\n\n    # Check if the embeddings are added to the corresponding memories\n    for i, memory in enumerate(batch_memory.memories):\n        all_embeddings = memory.get_all_embeddings()\n        assert all_embeddings.shape == (1, batch_memory.embedding_dim)\n        assert np.array_equal(embeddings_np[i], all_embeddings[0])\n\n\ndef test_batch_remove(batch_memory, tensor_type):\n    # Create integer batched embeddings\n    embeddings_np = np.random.randint(0, 10, (batch_memory.num_batches, batch_memory.embedding_dim)).astype(np.float32)\n    embeddings = create_tensor(tensor_type, embeddings_np)\n\n    # Add embeddings to the batch memory\n    batch_memory.batch_add(embeddings)\n\n    # Remove embeddings by index\n    indices_to_remove_np = np.zeros(batch_memory.num_batches)\n    indices_to_remove = create_tensor(tensor_type, indices_to_remove_np)\n    batch_memory.batch_remove(indices_to_remove)\n\n    # Check if the embeddings are removed from the corresponding memories\n    for _, memory in 
enumerate(batch_memory.memories):\n        all_embeddings = memory.get_all_embeddings()\n        assert all_embeddings.shape == (0, batch_memory.embedding_dim)\n\n\ndef test_batch_search(batch_memory, tensor_type):\n    # Create a specific vector\n    specific_vector_np = np.ones((1, batch_memory.embedding_dim), dtype=np.float32)\n    specific_vector = create_tensor(tensor_type, specific_vector_np)\n\n    # Create batched embeddings\n    embeddings_np = np.random.randint(0, 10, size=(batch_memory.num_batches - 1, batch_memory.embedding_dim)).astype(\n        np.float32\n    )\n\n    # Combine the specific vector with random embeddings\n    embeddings_np = np.vstack((specific_vector_np, embeddings_np))\n    embeddings = create_tensor(tensor_type, embeddings_np)\n\n    # Add embeddings to the batch memory\n    batch_memory.batch_add(embeddings)\n\n    # Perform batched search for the specific vector\n    k = 1\n    search_results = batch_memory.batch_search(specific_vector, k)\n\n    # Check if the first search result is the same as the specific vector\n    indices, distances = search_results[0]\n    found_vector = batch_memory.memories[0].get_all_embeddings()[indices[0][0]]\n    assert np.allclose(specific_vector_np, found_vector)\n"
  },
  {
    "path": "test/memory/test_memory_features.py",
    "content": "# type: ignore\nimport numpy as np\nimport torch\nimport pytest\n\nfrom src.memory.associative import AssociativeMemory\n\n\n@pytest.fixture(params=[\"numpy\", \"torch\"])\ndef embeddings(request):\n    embeddings_np = np.random.rand(1000, 128).astype(np.float32)\n    if request.param == \"numpy\":\n        return embeddings_np\n    else:\n        return torch.from_numpy(embeddings_np)\n\n\n@pytest.fixture\ndef forgetful_memory():\n    memory_size = 1000\n    embedding_dim = 128\n    forgetfulness_factor = 0.9\n    return AssociativeMemory(memory_size, embedding_dim, forgetfulness_factor=forgetfulness_factor)\n\n\n@pytest.fixture\ndef memory():\n    memory_size = 1000\n    embedding_dim = 128\n    return AssociativeMemory(memory_size, embedding_dim)\n\n\ndef test_age_memory(embeddings, forgetful_memory):\n    # Add some embeddings to the memory\n    forgetful_memory.add(embeddings[:10])\n\n    # Age the memory\n    forgetful_memory.age_memory(0.5)\n\n    # Check if the embeddings have been aged\n    aged_embeddings = forgetful_memory.get_all_embeddings()\n    assert aged_embeddings.shape == embeddings[:10].shape\n    assert np.allclose(embeddings[:10] * 0.5, aged_embeddings)\n\n\ndef test_forget_randomly(embeddings, forgetful_memory):\n    # Add some embeddings to the memory\n    forgetful_memory.add(embeddings[:100])\n\n    # Forget randomly\n    forgetful_memory.forget_randomly()\n\n    # Check if the number of embeddings in memory has decreased\n    remaining_embeddings = forgetful_memory.get_all_embeddings()\n    expected_remaining_embeddings = int(embeddings[:100].shape[0] * (1 - forgetful_memory.forgetfulness_factor))\n    assert (\n        remaining_embeddings.shape[0] == expected_remaining_embeddings\n        or remaining_embeddings.shape[0] == expected_remaining_embeddings + 1\n    )\n\n\ndef test_garbage_collect(embeddings, memory):\n    # Add some nearly zero embeddings to the memory\n    nearly_zero_embeddings = embeddings[:10] * 1e-7\n  
  memory.add(nearly_zero_embeddings)\n\n    # Garbage collect\n    removed_indices = memory.garbage_collect(threshold=1e-6)\n\n    # Check if the nearly zero embeddings have been removed\n    assert len(removed_indices) == len(nearly_zero_embeddings)\n    remaining_embeddings = memory.get_all_embeddings()\n    assert remaining_embeddings.shape[0] == 0\n    assert np.array_equal(removed_indices, np.arange(len(nearly_zero_embeddings)))\n"
  },
  {
    "path": "test/memory/test_memory_types.py",
    "content": "# type: ignore\nimport numpy as np\nimport torch\nimport pytest\n\nfrom src.memory.associative import AssociativeMemory\n\n\n@pytest.fixture(params=[\"flat\", \"compressed\", \"graph\"])\ndef index_type(request):\n    return request.param\n\n\n@pytest.fixture(params=[\"numpy\", \"torch\"])\ndef tensor_type(request):\n    return request.param\n\n\ndef to_tensor(tensor_type, array):\n    if tensor_type == \"numpy\":\n        return array\n    elif tensor_type == \"torch\":\n        return torch.from_numpy(array)\n    else:\n        raise ValueError(\"Unsupported tensor_type\")\n\n\ndef test_associative_memory_add_remove_update_search(index_type, tensor_type):\n    memory_size = 1000\n    embedding_dim = 128\n    num_test_vectors = 100\n    k = 10\n\n    memory = AssociativeMemory(memory_size, embedding_dim, index_type=index_type)\n\n    # Add test embeddings to memory\n    test_embeddings = to_tensor(tensor_type, np.random.random((num_test_vectors, embedding_dim)).astype(np.float32))\n    memory.add(test_embeddings)\n\n    # Search for closest embeddings in memory\n    query_vectors = to_tensor(tensor_type, np.random.random((5, embedding_dim)).astype(np.float32))\n    search_results, search_distances = memory.search(query_vectors, k)\n\n    assert search_results.shape == (query_vectors.shape[0], k)\n\n    if index_type == \"flat\":\n        # Remove some embeddings from memory\n        ids_to_remove = np.array([2, 5, 10, 30, 50])\n        memory.remove(ids_to_remove)\n\n        # Update some embeddings in memory\n        ids_to_update = np.array([0, 1, 3, 4])\n        updated_embeddings = to_tensor(\n            tensor_type, np.random.random((len(ids_to_update), embedding_dim)).astype(np.float32)\n        )\n\n        memory.update(ids_to_update, updated_embeddings)\n\n        # Check updated embeddings\n        updated_search_results, updated_distances = memory.search(updated_embeddings, k=1)\n        for i, _ in enumerate(ids_to_update):\n           
 assert np.isclose(updated_distances[i, 0], 0, atol=1e-6)\n\n    elif index_type in [\"compressed\", \"graph\"]:\n        with pytest.raises(ValueError):\n            memory.remove(np.array([0]))\n\n        with pytest.raises(ValueError):\n            memory.update(\n                np.array([0]), to_tensor(tensor_type, np.random.random((1, embedding_dim)).astype(np.float32))\n            )\n"
  },
  {
    "path": "test/memory/test_multitoken.py",
    "content": "# type: ignore\nimport numpy as np\nimport torch\nfrom src.memory.multitoken_batch_associative import MultiTokenBatchAssociativeMemory\n\n\ndef test_batch_add():\n    num_batches = 2\n    memory_size = 5\n    embedding_dim = 3\n    num_tokens = 4\n\n    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)\n\n    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim)\n    memory.batch_add(embeddings)\n\n    for mem in memory.memories:\n        assert mem.size() == num_tokens\n\n\ndef test_batch_remove():\n    num_batches = 2\n    memory_size = 5\n    embedding_dim = 3\n    num_tokens = 4\n\n    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)\n\n    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim)\n    memory.batch_add(embeddings)\n\n    ids_to_remove = np.array([[0, 1], [2, 3]])\n    memory.batch_remove(ids_to_remove)\n\n    for mem in memory.memories:\n        assert mem.size() == num_tokens - 2\n\n\ndef test_batch_search():\n    num_batches = 2\n    memory_size = 5\n    embedding_dim = 3\n    num_tokens = 4\n\n    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)\n\n    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim)\n    memory.batch_add(embeddings)\n\n    query_vectors = np.random.randn(num_batches, num_tokens, embedding_dim)\n    k = 3\n    search_results = memory.batch_search(query_vectors, k)\n\n    for indices, distances in search_results:\n        assert indices.shape == (num_tokens, k)\n        assert distances.shape == (num_tokens, k)\n\n\ndef test_torch_tensors():\n    num_batches = 2\n    memory_size = 5\n    embedding_dim = 3\n    num_tokens = 4\n\n    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)\n\n    embeddings = torch.randn(num_batches, num_tokens, embedding_dim)\n    memory.batch_add(embeddings)\n\n    query_vectors = torch.randn(num_batches, 
num_tokens, embedding_dim)\n    k = 3\n    search_results = memory.batch_search(query_vectors, k)\n\n    for indices, distances in search_results:\n        assert indices.shape == (num_tokens, k)\n        assert distances.shape == (num_tokens, k)\n"
  },
  {
    "path": "test/models/__init__.py",
    "content": ""
  },
  {
    "path": "test/models/test_gpt2_associative.py",
    "content": "# type: ignore\nimport pytest\nimport torch\nfrom src.models.gpt2_associative import VardaGPTAssociative\nfrom transformers import GPT2Config\n\n\n@pytest.fixture\ndef varda_gpt_associative():\n    return VardaGPTAssociative(gpt2_model_name=\"gpt2\", use_gpu=False, batch_size=2)\n\n\ndef test_initialization(varda_gpt_associative):\n    assert isinstance(varda_gpt_associative, VardaGPTAssociative)\n    assert varda_gpt_associative.device.type == \"cpu\"\n    assert isinstance(varda_gpt_associative.gpt2_config, GPT2Config)\n    assert varda_gpt_associative.num_search_results == 5\n    assert varda_gpt_associative.forgetfulness_factor == 0.001\n\n\ndef test_forward_pass_no_memory(varda_gpt_associative):\n    batch_size = 2\n    sequence_length = 4\n    input_dim = varda_gpt_associative.gpt2_config.n_embd\n\n    input_vectors = torch.randn(batch_size, sequence_length, input_dim)\n    logits = varda_gpt_associative.forward(input_vectors)\n\n    assert logits.shape == (batch_size, sequence_length, varda_gpt_associative.gpt2_config.vocab_size)\n\n\ndef test_forward_pass_with_memory(varda_gpt_associative):\n    batch_size = 2\n    sequence_length = 4\n    input_dim = varda_gpt_associative.gpt2_config.n_embd\n\n    input_vectors = torch.randn(batch_size, sequence_length, input_dim)\n    logits = varda_gpt_associative.forward(input_vectors)\n\n    assert logits.shape == (batch_size, sequence_length, varda_gpt_associative.gpt2_config.vocab_size)\n"
  },
  {
    "path": "test/test_training.py",
    "content": ""
  }
]