Full Code of ixaxaar/VardaGPT for AI

Repository: ixaxaar/VardaGPT
Branch: master
Commit: 191954143b16
Files: 28
Total size: 76.2 KB

Directory structure:
VardaGPT/

├── .github/
│   └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Makefile
├── README.md
├── STORY.md
├── requirements.txt
├── setup.cfg
├── src/
│   ├── data.py
│   ├── memory/
│   │   ├── __init__.py
│   │   ├── associative.py
│   │   ├── batch_associative.py
│   │   └── multitoken_batch_associative.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── gpt2_associative.py
│   │   └── gpt2_working.py
│   ├── train.py
│   └── train_parallel.py
└── test/
    ├── __init__.py
    ├── memory/
    │   ├── __init__.py
    │   ├── test_associative.py
    │   ├── test_associative_batch.py
    │   ├── test_memory_features.py
    │   ├── test_memory_types.py
    │   └── test_multitoken.py
    ├── models/
    │   ├── __init__.py
    │   └── test_gpt2_associative.py
    └── test_training.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/ci.yml
================================================
name: Python CI

on:
  push:
    branches:
      - master
  pull_request:

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2

    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.10'

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt

    - name: Run tests and coverage
      run: |
        pytest --cov=./src
        coverage html
      shell: bash

    - name: Save coverage report
      uses: actions/upload-artifact@v2
      with:
        name: coverage-report
        path: htmlcov

    - name: Generate coverage comment
      uses: rabelenda/python-coverage-comment@v1
      with:
        artifact: coverage-report
        artifact-type: path


================================================
FILE: .gitignore
================================================
__pycache__/
*.py[cod]
*.mo
/venv
.env
env.production
env.devtest
.DS_Store
._*
*.log
!.vscode/settings.json
.vscode/launch.json
docker-compose.yml
geo
.vscode/
.coverage


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
        args: [--line-length=120]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.1.0
    hooks:
      - id: check-json
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace

  - repo: https://github.com/pre-commit/mirrors-prettier
    rev: v2.6.2
    hooks:
      - id: prettier
        args:
          - --print-width=80
          - --prose-wrap=always
          - --tab-width=2
          - --single-quote=true
          - --no-bracket-spacing=true
          - --trailing-comma=es5
        files: \.(md|json)$
        additional_dependencies:
          - "prettier@2.6.2"

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.9.0
    hooks:
      - id: python-check-blanket-noqa
      - id: python-check-mock-methods
      - id: python-use-type-annotations

  - repo: https://github.com/pycqa/flake8
    rev: 4.0.1
    hooks:
      - id: flake8
        additional_dependencies:
          - flake8-bugbear
          - flake8-builtins
          - flake8-comprehensions
          - pep8-naming

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.920
    hooks:
      - id: mypy
        additional_dependencies:
          - types-requests


================================================
FILE: Makefile
================================================
.PHONY: help setup train train-parallel evaluate inference precommit format test clean

.DEFAULT_GOAL := help

help:
	@echo "VardaGPT - Memory-enhanced GPT-2 model powered by Hugging Face Transformers and FAISS"
	@echo "Usage:"
	@echo "  make <command>"
	@echo ""
	@echo "Commands:"
	@echo "  help        Display this help message"
	@echo "  setup       Set up the project by creating a virtual environment and installing dependencies"
	@echo "  train       Train the VardaGPT model"
	@echo "  evaluate    Evaluate the trained model on validation and testing sets"
	@echo "  inference   Generate text using the memory-enhanced GPT-2 model"
	@echo "  precommit   Run pre-commit hooks manually on all files"
	@echo "  format      Format code using black, flake8, mypy, and prettier"
	@echo "  clean       Clean up the project directory by removing virtual environment and temporary files"

# Each recipe line runs in its own shell, so `source venv/bin/activate` on one
# line has no effect on the next. Invoke the venv's interpreter directly instead.
setup:
	python -m venv venv
	./venv/bin/pip install -r requirements.txt

train:
	./venv/bin/python src/train.py

train-parallel:
	./venv/bin/python src/train_parallel.py

evaluate:
	./venv/bin/python src/evaluate.py

inference:
	./venv/bin/python src/inference.py --prompt "Your prompt text here"

precommit:
	pre-commit run --all-files

format:
	pre-commit run --all-files

test:
	coverage run -m pytest --log-cli-level=DEBUG --capture=tee-sys -v .

clean:
	rm -rf venv/
	find . -type f -name "*.pyc" -exec rm -f {} \;
	find . -type d -name "__pycache__" -exec rm -rf {} \;


================================================
FILE: README.md
================================================
# VardaGPT

<!-- START doctoc generated TOC please keep comment here to allow auto update -->
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->

- [VardaGPT](#vardagpt)
  - [TLDR - Training](#tldr---training)
    - [Requirements](#requirements)
    - [Usage](#usage)
  - [Overview](#overview)
  - [Models](#models)
  - [Training, Evaluation, and Fine-tuning Process](#training-evaluation-and-fine-tuning-process)
    - [1. Data Preparation](#1-data-preparation)
    - [2. GPT-2 Model Adaptation](#2-gpt-2-model-adaptation)
    - [3. Training](#3-training)
    - [4. Evaluation](#4-evaluation)
    - [5. Fine-tuning (if necessary)](#5-fine-tuning-if-necessary)
  - [Prerequisites](#prerequisites)
  - [Setup](#setup)
  - [Directory Structure](#directory-structure)
  - [Usage](#usage-1)
    - [Data Preparation](#data-preparation)
    - [Training](#training)
    - [Evaluation](#evaluation)
    - [Inference](#inference)
  - [Contributing](#contributing)
  - [Code Formatting and Pre-commit](#code-formatting-and-pre-commit)
    - [Setup](#setup-1)
    - [Using Pre-commit](#using-pre-commit)
  - [License](#license)

<!-- END doctoc generated TOC please keep comment here to allow auto update -->

VardaGPT is a memory-enhanced GPT-2 model powered by Hugging Face Transformers
and FAISS. Inspired by J.R.R. Tolkien's Silmarillion, VardaGPT aims to provide
guidance and knowledge through its memory-augmented text generation
capabilities.

## TLDR - Training

The `VardaGPTAssociative` model combines GPT-2 with an associative memory to
improve context retrieval. This repository includes a script to train this model
on the WikiText-2 dataset.

### Requirements

- Python 3.8+
- PyTorch 2.0.0
- torchtext 0.15.1
- transformers 4.28.1
- rich 13.3.5
- faiss-cpu 1.7.3

To install the required packages, you can use the following command:

```bash
pip install -r requirements.txt
```

### Usage

To train the `VardaGPTAssociative` model on the WikiText-2 dataset, use the
provided training script (`src/train.py`). You can customize the training
settings by passing command-line arguments. Here's a basic example:

```bash
python src/train.py --epochs 5 --learning_rate 1e-4 --use_gpu
```

Available command-line arguments:

- `--epochs`: Number of epochs to train the model (default: 5).
- `--learning_rate`: Learning rate for the optimizer (default: 1e-4).
- `--memory_size`: Maximum number of items the associative memory can store
  (default: 10000).
- `--memory_dim`: Dimensionality of the embeddings stored in the associative
  memory (default: 768).
- `--index_type`: Type of index used for the associative memory (default:
  "flat").
- `--num_clusters`: Number of clusters to use for the memory if the index type
  is "ivf" (default: 1024).
- `--num_search_results`: Number of search results to return from the
  associative memory (default: 5).
- `--use_gpu`: Whether to use the GPU for the model if available (default:
  False).
- `--batch_size`: Batch size for training (default: 1).
- `--forgetfulness_factor`: Forgetfulness factor for the associative memory
  (default: 0.001).
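
The flags above map naturally onto a standard `argparse` setup. The following is an illustrative sketch of how the documented defaults might be wired up, not the repository's actual parser:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Illustrative parser mirroring the documented flags and their defaults.
    p = argparse.ArgumentParser(description="Train VardaGPTAssociative")
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--memory_size", type=int, default=10000)
    p.add_argument("--memory_dim", type=int, default=768)
    p.add_argument("--index_type", type=str, default="flat")
    p.add_argument("--num_clusters", type=int, default=1024)
    p.add_argument("--num_search_results", type=int, default=5)
    p.add_argument("--use_gpu", action="store_true")  # defaults to False
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--forgetfulness_factor", type=float, default=0.001)
    return p


# Example: override two flags, leave the rest at their defaults.
args = build_parser().parse_args(["--epochs", "3", "--use_gpu"])
```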

During training, the script will periodically print the training loss,
validation loss, and elapsed time for each epoch, along with a progress bar for
each training step.

After training, you can use the trained model for your specific use case, such
as text generation or fine-tuning for a particular task.

## Overview

<details>
  <summary>Click me</summary>

```plantuml
@startuml
!define AWSPUML https://raw.githubusercontent.com/awslabs/aws-icons-for-plantuml/v14.0

actor User

skinparam component {
  BackgroundColor<<Data Preparation>> LightSkyBlue
  BackgroundColor<<FAISS Memory>> Plum
  BackgroundColor<<GPT-2 Adaptation>> LightGreen
  BackgroundColor<<Training>> LightSalmon
  BackgroundColor<<Inference>> LightCoral
  BorderColor Black
  FontName Arial
}

package "VardaGPT" {
  [Data Preparation]<<Data Preparation>> --> [FAISS Memory]<<FAISS Memory>>
  [Data Preparation]<<Data Preparation>> --> [GPT-2 Adaptation]<<GPT-2 Adaptation>>

  [FAISS Memory]<<FAISS Memory>> --> [GPT-2 Adaptation]<<GPT-2 Adaptation>>
  [GPT-2 Adaptation]<<GPT-2 Adaptation>> --> [Training]<<Training>>

  [Training]<<Training>> --> [Inference]<<Inference>>
  [FAISS Memory]<<FAISS Memory>> --> [Inference]<<Inference>>

  User --> [Data Preparation]<<Data Preparation>> : Dataset
  User --> [Inference]<<Inference>> : Prompts
}

@enduml
```

</details>

![overview](./assets/README.svg)

This diagram shows the main components of the VardaGPT project and their
interactions. The Data Preparation component processes the dataset and feeds it
to both the FAISS Memory Model and the GPT-2 Model Adaptation component. The
FAISS Memory Model generates embeddings, which are used by the GPT-2 Model
Adaptation component to create a modified GPT-2 model. The modified GPT-2 model
is then trained and evaluated, and the final trained model is used in the
Inference and Application component. The user provides the dataset and prompts
for text generation.

## Models

The associative memory model:

<details>
  <summary>Click me</summary>

```plantuml
@startuml

rectangle "Input Vectors" as input #b3e0ff
rectangle "Memory" as memory #f2d7b9
rectangle "Concatenated Input" as concatenated_input #f6e3c6
rectangle "Fully Connected Layer (fc)" as fc #e5ebf0
rectangle "GPT-2 Transformer" as transformer #c6e0b4
rectangle "GPT-2 LM Head" as lm_head #c9daf8
rectangle "Fully Connected Layer\n(fc_storable_vector)" as fc_storable_vector #c9daf8
rectangle "Fully Connected Layer\n(fc_store_decision)" as fc_store_decision #c9daf8

input -down-> memory : Perform search in memory
memory -down-> concatenated_input : Concatenate search results with input vectors
concatenated_input -down-> fc : Apply fully connected layer (fc)
fc -down-> transformer : Pass through GPT-2 transformer
transformer -down-> lm_head : Apply GPT-2 lm_head
transformer -right-> fc_storable_vector : Apply fully connected layer (fc_storable_vector)
transformer -right-> fc_store_decision : Apply fully connected layer (fc_store_decision)

note right of fc_storable_vector: Calculate storable vector\n and store decision
note right of fc_store_decision: Store the storable_vector in\n the associative memory if\n the store_decision is affirmative
note bottom of lm_head: Return logits

@enduml

```

</details>

![model1](./assets/README_001.svg)

<details>
  <summary>Click me</summary>

```plantuml
@startuml
title Forward Function

!define Tensor(t,d) t + " (" + d + ")"
!define DEVICE "device"

actor "input_vectors" as input_vectors
actor "memory_input" as memory_input

note right of input_vectors
  Tensor:
  (batch_size, seq_len, embedding_dim)
end note

note right of memory_input
  Tensor (optional):
  (batch_size, seq_len, embedding_dim)
end note

input_vectors -> DEVICE
memory_input -> DEVICE

DEVICE -> "search(memory_input)" as search
search --> "indices, distances" as search_result
note right of search_result
  Tensors:
  indices: (batch_size, seq_len, num_search_results)
  distances: (batch_size, seq_len, num_search_results)
end note

search_result -> "get_all_embeddings()" as all_embeddings
note right of all_embeddings
  Tensor:
  (memory_size, embedding_dim)
end note

all_embeddings -> "search_results" as search_results
note right of search_results
  Tensor:
  (batch_size, seq_len, search_results_dim)
end note

search_results --> "concatenate(input_vectors, search_results)" as concatenated_input
note right of concatenated_input
  Tensor:
  (batch_size, seq_len, embedding_dim + search_results_dim)
end note

concatenated_input --> "self.fc(concatenated_input)" as fc_output
note right of fc_output
  Tensor:
  (batch_size, seq_len, embedding_dim)
end note

fc_output --> "self.gpt2_model.transformer(inputs_embeds=input_vectors)" as transformer_outputs
transformer_outputs --> "hidden_states" as hidden_states
note right of hidden_states
  Tensor:
  (batch_size, seq_len, embedding_dim)
end note

hidden_states --> "self.gpt2_model.lm_head(hidden_states)" as logits
note right of logits
  Tensor:
  (batch_size, seq_len, vocab_size)
end note

hidden_states --> "self.fc_storable_vector(hidden_states)" as storable_vector
note right of storable_vector
  Tensor:
  (batch_size, seq_len, memory_dim)
end note

hidden_states --> "self.fc_store_decision(hidden_states)" as store_decision
note right of store_decision
  Tensor:
  (batch_size, seq_len, 1)
end note

hidden_states --> "self.fc_delete_decision(hidden_states)" as delete_decision
note right of delete_decision
  Tensor:
  (batch_size, seq_len, num_search_results)
end note

hidden_states --> "self.fc_deletable_vector(hidden_states)" as deletable_vector
note right of deletable_vector
  Tensor:
  (batch_size, seq_len, memory_dim)
end note

storable_vector --> "self.memory.add(storable_vector_to_store)" as add_memory

deletable_vector --> "calculate L2 distances" as l2_distances
note right of l2_distances
  Tensor:
  (batch_size, num_search_results)
end note

l2_distances --> "threshold comparison" as threshold_comparison
note right of threshold_comparison
  Tensor (bool):
  (batch_size, num_search_results)
end note

threshold_comparison --> "self.memory.remove(indices_to_delete_flat)" as remove_memory

logits --> "return logits" as return_logits

@enduml
```

</details>

![model](./assets/README_002.svg)
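
The tensor shapes in the forward-function diagram can be traced without any framework. The following stand-alone check shows how the dimensions compose; all concrete sizes are illustrative, and `search_results_dim` is assumed to be the per-token retrieval results flattened into one feature vector:

```python
# Illustrative sizes (not tied to a real GPT-2 config).
batch_size, seq_len = 2, 16
embedding_dim, memory_dim, vocab_size = 768, 768, 50257
num_search_results = 5

# Assumption: each token's num_search_results retrieved memory vectors
# are flattened into a single feature vector.
search_results_dim = num_search_results * memory_dim

input_vectors = (batch_size, seq_len, embedding_dim)
search_results = (batch_size, seq_len, search_results_dim)

# Concatenation along the feature axis:
concatenated = (batch_size, seq_len, embedding_dim + search_results_dim)

# fc projects back down to the transformer's embedding size:
fc_output = (batch_size, seq_len, embedding_dim)

# Heads on top of the transformer's hidden states:
logits = (batch_size, seq_len, vocab_size)
storable_vector = (batch_size, seq_len, memory_dim)
store_decision = (batch_size, seq_len, 1)
delete_decision = (batch_size, seq_len, num_search_results)

# The fc layer must accept the concatenated width and emit embedding_dim.
assert concatenated[-1] == embedding_dim + search_results_dim
assert fc_output[-1] == embedding_dim
```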

## Training, Evaluation, and Fine-tuning Process

<details>
  <summary>Click me</summary>

```plantuml
@startuml

skinparam activity {
  BackgroundColor LightSkyBlue
  BorderColor Black
  FontName Arial
}

start

:Data Preparation;

partition "FAISS Memory Model" {
  :Create FAISS Index;
  :Encode and Decode Text Data;
  :Test FAISS Index;
}

partition "GPT-2 Model Adaptation" {
  :Load Pre-trained GPT-2 Model;
  :Modify GPT-2 Architecture;
  :Define Custom Loss Function;
}

partition "Training" {
  :Train Adapted GPT-2 Model;
  :Save Model Checkpoints;
}

partition "Evaluation" {
  :Evaluate Model on Testing Set;
  :Calculate Metrics;
}

if (Fine-tuning needed?) then (Yes)
  partition "Fine-tuning" {
    :Adjust Hyperparameters;
    :Iterate Training and Evaluation;
  }
endif

partition "Inference and Application" {
  :Inference Function;
  :API or Interface;
}

stop

@enduml
```

</details>

![process](./assets/README_003.svg)

### 1. Data Preparation

- Collect and preprocess a dataset for training, evaluation, and fine-tuning.
- Split the dataset into training, validation, and testing sets.
- Create data loaders for handling data.
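
The split step can be sketched without any framework. The 80/10/10 ratio below is an assumption for illustration; the repository's actual splits come from WikiText-2's predefined partitions (see `src/data.py`):

```python
import random


def split_dataset(samples, train_frac=0.8, valid_frac=0.1, seed=42):
    # Shuffle a copy so the caller's sequence is left untouched.
    items = list(samples)
    random.Random(seed).shuffle(items)
    n_train = int(len(items) * train_frac)
    n_valid = int(len(items) * valid_frac)
    train = items[:n_train]
    valid = items[n_train : n_train + n_valid]
    test = items[n_train + n_valid :]  # remainder
    return train, valid, test


train, valid, test = split_dataset(range(100))
```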

### 2. GPT-2 Model Adaptation

- Load a pre-trained GPT-2 model from Hugging Face Transformers.
- Modify the GPT-2 model architecture to incorporate the FAISS memory model.
- Define a custom loss function that considers both the GPT-2 model's output and
  the memory model.
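
One common shape for such a combined loss is a weighted sum of the language-modeling loss and a memory term. The weight and both component values below are placeholders, not values from the repository:

```python
def combined_loss(lm_loss: float, memory_loss: float, memory_weight: float = 0.1) -> float:
    # Language-modeling loss plus a weighted penalty from the memory model
    # (e.g. a store/delete-decision term); the weighting scheme is illustrative.
    return lm_loss + memory_weight * memory_loss


loss = combined_loss(lm_loss=3.2, memory_loss=0.5)
```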

### 3. Training

- Set up the training loop and train the adapted GPT-2 model.
- Save model checkpoints and track training metrics (loss, perplexity, etc.).
- Monitor the training progress, validate the model on the validation set, and
  perform early stopping if necessary.
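
Early stopping on the validation loss can be sketched with a simple patience counter; the patience value here is illustrative:

```python
def early_stop_epoch(val_losses, patience=3):
    """Return the index of the epoch at which training would stop,
    or None if the patience budget is never exhausted."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best validation loss: reset the patience counter.
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Validation loss improves for three epochs, then fails to improve
# for three consecutive epochs, exhausting patience.
stop_at = early_stop_epoch([2.1, 1.8, 1.7, 1.75, 1.9, 1.85, 1.7])
```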

### 4. Evaluation

- Evaluate the trained model on the testing set.
- Calculate evaluation metrics (e.g., perplexity, accuracy, F1-score).
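
Of these metrics, perplexity follows directly from the mean cross-entropy loss (in nats); the loss value below is made up:

```python
import math


def perplexity(mean_cross_entropy: float) -> float:
    # Perplexity is the exponential of the per-token cross-entropy.
    return math.exp(mean_cross_entropy)


ppl = perplexity(3.0)
```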

### 5. Fine-tuning (if necessary)

- If the model's performance on the testing set is not satisfactory, fine-tune
  the model with different hyperparameters, learning rates, or architectures.
- Iterate through the training and evaluation steps until the desired
  performance is achieved.

## Prerequisites

- Python 3.8 or higher
- PyTorch
- Hugging Face Transformers
- FAISS (CPU or GPU version)

## Setup

1. Clone the repository:

```bash
git clone https://github.com/yourusername/VardaGPT.git
cd VardaGPT
```

2. Create and activate a virtual environment:

```bash
python -m venv venv
source venv/bin/activate
```

3. Install the required libraries:

```bash
pip install -r requirements.txt
```

## Directory Structure

- `src/`: Contains the Python source code for the project.
- `data/`: Stores the datasets used for training and evaluation.
- `models/`: Holds the trained models and their checkpoints.

## Usage

### Data Preparation

1. Place your dataset in the `data/` directory.
2. Preprocess and split your dataset into training, validation, and testing sets
   using the provided scripts in `src/`.

### Training

1. Configure the training settings and model hyperparameters in the
   `src/config.py` file.
2. Run the training script:

```bash
python src/train.py
```

3. Monitor the training progress and save model checkpoints in the `models/`
   directory.

### Evaluation

1. Evaluate the trained model on the validation and testing sets using the
   provided evaluation script:

```bash
python src/evaluate.py
```

### Inference

1. Use the provided inference script to generate text with the memory-enhanced
   GPT-2 model:

```bash
python src/inference.py --prompt "Your prompt text here"
```

## Contributing

Feel free to contribute to this project by submitting pull requests or opening
issues for bug reports and feature requests.

## Code Formatting and Pre-commit

This project uses `black`, `flake8`, and `mypy` for Python code formatting and
linting. We also use `prettier` to format JSON and Markdown files. The
configuration for these tools is in the `.pre-commit-config.yaml` file.

### Setup

1. Install `pre-commit` if you haven't already:

```bash
pip install pre-commit
```

2. Set up the git hooks:

```bash
pre-commit install
```

### Using Pre-commit

Whenever you commit changes, the pre-commit hooks will automatically format your
code and check for issues. If the hooks detect any problems, the commit will be
aborted, and you'll see a list of issues that need to be fixed. Once you've
resolved the issues, you can try committing again.

You can also run the pre-commit hooks manually on all files:

```bash
pre-commit run --all-files
```

Or run the hooks on specific files:

```bash
pre-commit run --files <file1> <file2>
```

By following this setup and using pre-commit hooks, you can ensure that the code
in the repository remains consistently formatted and adheres to the project's
coding standards.

## License

This project is licensed under the [MIT License](LICENSE).


================================================
FILE: STORY.md
================================================
# Story of this project 😅

## Background 🤔

With all the hype around ChatGPT, I wondered how much impact ChatGPT really had.
I mean, for a programmer, would ChatGPT be like a pair programmer? Like GitHub
Copilot++? Or would ChatGPT totally replace programmers so that product managers
could tell it what feature to build, and it would just build it!

Imagine a bunch of product managers sitting in a sprint planning meeting where,
after signing off on the tasks to be done this sprint and starting the sprint,
ChatGPT was deployed on those tasks. The sprint lasted for about 2 hours, and
everyone met again to do the next day's sprint grooming. 😆

## Project Idea 💡

Now, what the heck should I build to test this? Why not try attaching a memory
module to a GPT? I've seen some folks on the internet complain about the "low
memory" problem of language models. I've also used FAISS and FLANN before, so I
am familiar with how to technically achieve this. Whether it will actually work
or not—well, my 1080Ti is on its deathbed with a broken fan, and I don't have
the 💸 to train this thing on AWS anyway. Let's aim for unit tests to work then.

## Process 🏃

Okay.

I started with the project plan:

![start](./assets/1.png)

Then I made ChatGPT generate the project foundations, step by step, from
creating project directories, Makefile, README, pre-commit, vscode settings for
the same tools in pre-commit, setup.cfg, and a GitHub workflow to run tests. In
each case, I had to specify exactly what I wanted it to generate.

![start](./assets/2.png)

Yes, I made ChatGPT choose the project name as well.

![start](./assets/3.png)

If I forgot something, I would go back and ask ChatGPT to add it:

![start](./assets/4.png)

In fact, I found it better to let ChatGPT generate a toy-ish version of the
code first and then have it add features step by step. This produced much
better output than asking for production-quality code with all features in one
go. It also let me break my requirements down and feed them in one at a time,
which made the process easier for me, since I was also acting as code reviewer
for the generated output.

![start](./assets/5.png) ![start](./assets/6.png) ![start](./assets/7.png)
![start](./assets/8.png) ![start](./assets/9.png)

Of course, I made ChatGPT write unit tests, and if they failed, I would just
copy the pytest output and feed it back into ChatGPT.

![start](./assets/10.png) ![start](./assets/11.png)

ChatGPT even figured this out:

![start](./assets/12.png)

The result: I present to you
[VardaGPT](https://github.com/ixaxaar/vardagpt). Every inch of this repository
was generated by ChatGPT-4! It took a few hours in total, spread over about
three weekends, mostly at odd times.

## Experience 😮

It felt neither like a Copilot++ nor like the product manager scenario but
rather all at the same time. Sometimes I was amazed at what ChatGPT was able to
understand, sometimes I had to stubbornly push it to go in a certain direction,
sometimes it generated things I did not think of, sometimes I got super
frustrated while making ChatGPT fix the code in a certain way.

It was more like handholding a fresh grad who had absorbed all of human
knowledge but needed someone to tie the various parts of that knowledge
together into something useful. Also, ChatGPT is bad at dealing with
abstractions more than two layers deep.

ChatGPT is definitely a productivity multiplier. More precisely, it is a
differential productivity multiplier: it amplifies the capabilities of those
who already know more. If I did not understand deep learning and FAISS, or how
projects are structured, I don't think I would have been able to pull this
off. On the other hand, it also has a leveling effect: I have not worked with
PyTorch in a while and had no idea of FAISS's newer APIs, but ChatGPT filled
in those gaps.

Finally, it was also tiring. Imagine being reduced to giving only instructions
and doing code review. Reading and understanding code is tiring!

## Conclusion ❓

It looks like my job is safe this year. Time to generate an elaborate software
project and claim supremacy on my ChatGPT usage abilities to hedge against next
year.

I wonder whether, by the time ChatGPT-6 comes out, engineering teams will be
saying, "Hey, let's generate our own Grafana with a purple theme 😄".

## Aside 🦄

I could not resist adding this bit. ChatGPT is great at generating Agda! Maybe
it will also turn out to be the ultimate tool for formalizing all of pure
math? 😱

![start](./assets/13.png)


================================================
FILE: requirements.txt
================================================
black==23.3.0
certifi==2022.12.7
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.3
coverage==7.2.3
exceptiongroup==1.1.1
faiss-cpu==1.7.3
filelock==3.12.0
flake8==6.0.0
huggingface-hub==0.13.4
idna==3.4
iniconfig==2.0.0
Jinja2==3.1.2
lit==16.0.1
markdown-it-py==2.2.0
MarkupSafe==2.1.2
mccabe==0.7.0
mdurl==0.1.2
mpmath==1.3.0
mypy==1.2.0
mypy-extensions==1.0.0
networkx==3.1
numpy==1.24.2
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
packaging==23.1
pathspec==0.11.1
platformdirs==3.2.0
pluggy==1.0.0
pycodestyle==2.10.0
pydantic==1.10.7
pyflakes==3.0.1
Pygments==2.15.1
pytest==7.3.1
pytest-cov==4.0.0
PyYAML==6.0
regex==2023.3.23
requests==2.28.2
rich==13.3.5
sympy==1.11.1
tokenizers==0.13.3
tomli==2.0.1
torch==2.0.0
torchdata==0.6.0
torchtext==0.15.1
tqdm==4.65.0
transformers==4.28.1
triton==2.0.0
typing_extensions==4.5.0
urllib3==1.26.15


================================================
FILE: setup.cfg
================================================
[flake8]
exclude = */__init__.py,migrations/*
ignore = E111, E114, E121, E131, W503, F405, F403, E126, E501, F841, E124, E251, E203
max-line-length = 120
max-doc-length = 120
show-source = true
statistics = false
doctests = true

[tool.black]
line-length = 120

[mypy]
ignore_missing_imports = True

[mypy-tests.*]
disallow_untyped_defs = False

[tool:pytest]
addopts = -p no:warnings
ignore = tests


================================================
FILE: src/data.py
================================================
from typing import Any, Tuple

import torch
from torch.utils.data import DataLoader
from torchtext.data import get_tokenizer
from torchtext.datasets import WikiText2
from transformers import GPT2Tokenizer


def load_wikitext2() -> Tuple[DataLoader[Any], DataLoader[Any], DataLoader[Any]]:
    """
    Load the WikiText-2 dataset for training, validation, and testing.

    :return: A tuple of three DataLoaders for train, valid, and test sets.
    """

    # Define a tokenizer using the GPT2Tokenizer from the transformers library
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Define a tokenize function using the GPT2Tokenizer
    def tokenize(text: str) -> torch.Tensor:
        return tokenizer.encode(text, return_tensors="pt").squeeze(0)  # type: ignore

    # Use torchtext's get_tokenizer function to create a tokenizer pipeline
    text_pipeline = get_tokenizer(tokenize)

    # Load the WikiText-2 dataset
    train_data, valid_data, test_data = WikiText2()

    # Create DataLoaders for the train, valid, and test sets
    train_loader = DataLoader([text_pipeline(text) for text in train_data], batch_size=1, shuffle=True)  # type: ignore
    valid_loader = DataLoader([text_pipeline(text) for text in valid_data], batch_size=1, shuffle=True)  # type: ignore
    test_loader = DataLoader([text_pipeline(text) for text in test_data], batch_size=1, shuffle=True)  # type: ignore

    return train_loader, valid_loader, test_loader


================================================
FILE: src/memory/__init__.py
================================================
from typing import Any, Optional

import numpy as np
import numpy.typing as npt
from faiss import IndexIVF


class Memory:
    memory_size: int = 0
    embedding_dim: int = 0
    index: IndexIVF

    def add(self, embeddings: npt.NDArray[Any]) -> None:
        self.index.add(embeddings)

    def remove(self, ids: npt.NDArray[Any]) -> None:
        self.index.remove_ids(ids)

    def update(self, ids: npt.NDArray[Any], updated_embeddings: npt.NDArray[Any]) -> None:
        self.remove(ids)
        self.add(updated_embeddings)

    def search(self, query_vectors: npt.NDArray[Any], k: int = 10) -> Any:
        raise NotImplementedError()

    def refresh(self) -> None:
        self.index.reset()


================================================
FILE: src/memory/associative.py
================================================
from typing import Any, Union

import faiss
import numpy as np
import numpy.typing as npt
import torch

from . import Memory


class AssociativeMemory(Memory):
    def __init__(
        self,
        memory_size: int,
        embedding_dim: int,
        index_type: str = "flat",
        num_clusters: int = 1024,
        m: int = 8,
        ef_construction: int = 100,
        ef_search: int = 64,
        use_gpu: bool = False,
        gpu_device: int = 0,
        forgetfulness_factor: float = 0.0001,
    ):
        """
        Initialize the associative memory.

        :param memory_size: The maximum number of items the memory can store.
        :param embedding_dim: The dimensionality of the input embeddings.
        :param index_type: The type of FAISS index to use (default: 'flat').
        :param num_clusters: The number of clusters to use for an IVF index (default: 1024).
        :param m: The number of product quantization codes to use for an IVFPQ index (default: 8).
        :param ef_construction: The size of the entry list for an HNSW index (default: 100).
        :param ef_search: The search list size for an HNSW index (default: 64).
        :param use_gpu: Whether to use GPU acceleration for the FAISS index (default: False).
        :param gpu_device: The ID of the GPU device to use (default: 0).
        :param forgetfulness_factor: The fraction of stored items to remove during random forgetting (default: 0.0001).
        """
        # Initialize memory parameters
        self.memory_size = memory_size
        self.embedding_dim = embedding_dim
        self.forgetfulness_factor = forgetfulness_factor
        self.use_gpu = use_gpu
        self.gpu_device = gpu_device

        # Create the appropriate Faiss index based on the specified type
        if index_type == "flat":
            # Inverted File with Flat index - a compressed index with an inverted file structure
            quantizer = faiss.IndexFlatL2(embedding_dim)
            index = faiss.IndexIVFFlat(quantizer, embedding_dim, num_clusters, faiss.METRIC_L2)
            index.make_direct_map()
            index.set_direct_map_type(faiss.DirectMap.Hashtable)
        elif index_type == "compressed":
            # Inverted File with Product Quantization index - a compressed index with a product quantization compression
            quantizer = faiss.IndexFlatL2(embedding_dim)
            index = faiss.IndexIVFPQ(quantizer, embedding_dim, num_clusters, m, 8)
        elif index_type == "graph":
            # Hierarchical Navigable Small World index - a graph-based index with a flat storage
            index = faiss.IndexHNSWFlat(embedding_dim, ef_construction, faiss.METRIC_L2)
            index.hnsw.efSearch = ef_search
        else:
            raise ValueError(f"Invalid index_type: {index_type}")
        self.index_type = index_type

        # Enable GPU support if specified
        if use_gpu:
            self.res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(self.res, gpu_device, index)
        else:
            self.index = index

        # Train the index on placeholder zero vectors so it is ready to accept additions
        self.index.train(np.zeros((max(memory_size, num_clusters), embedding_dim), dtype=np.float32))

        # Initialize an empty array to store the input vectors
        self.input_vectors = np.zeros((memory_size, embedding_dim), dtype=np.float32)

    def _to_numpy(self, data: Union[npt.NDArray[Any], torch.Tensor]) -> Union[npt.NDArray[Any], torch.Tensor]:
        """
        Convert input data to a NumPy array if it's a PyTorch tensor and not using GPU.

        :param data: Input data to be converted. Can be either a NumPy array or a PyTorch tensor.
        :return: Converted data as a NumPy array or a PyTorch tensor.
        """
        if isinstance(data, torch.Tensor) and not self.use_gpu:
            return data.detach().cpu().numpy()  # type: ignore
        return data

    def add(self, embeddings: Union[npt.NDArray[Any], torch.Tensor]) -> None:
        """
        Add embeddings to the memory.

        :param embeddings: A 2D array of shape (n, embedding_dim) containing the embeddings to be added,
                           where n is the number of items to add. Can be either a NumPy array or a PyTorch tensor.
        """
        embeddings = self._to_numpy(embeddings)
        n_added = self.index.ntotal  # Existing number of added items
        n_to_add = embeddings.shape[0]  # Number of items to add

        # Update the input_vectors array with the new embeddings
        self.input_vectors[n_added : n_added + n_to_add] = embeddings

        # Add embeddings to the index
        if self.use_gpu:
            ptr = faiss.torch_utils.swig_ptr_from_FloatTensor(embeddings)
            self.index.add_c(n_to_add, ptr)
        else:
            self.index.add(embeddings)

    def remove(self, ids: Union[npt.NDArray[Any], torch.Tensor]) -> None:
        """
        Remove embeddings with the specified IDs from the memory.

        :param ids: A 1D array of shape (n,) containing the indices of the items to be removed,
                    where n is the number of items to remove. Can be either a NumPy array or a PyTorch tensor.
        """
        ids = self._to_numpy(ids)

        if self.index_type != "flat":
            raise ValueError(
                f"Removal is not supported by FAISS for this index type; use 'flat' instead of: {self.index_type}"
            )
        self.index.remove_ids(ids)

        # Remove input vectors from the input_vectors array
        self.input_vectors = np.delete(self.input_vectors, ids, axis=0)

    def update(
        self, ids: Union[npt.NDArray[Any], torch.Tensor], updated_embeddings: Union[npt.NDArray[Any], torch.Tensor]
    ) -> None:
        """
        Update embeddings with the specified IDs in the memory.

        :param ids: A 1D array of shape (n,) containing the indices of the items to be updated,
                    where n is the number of items to update. Can be either a NumPy array or a PyTorch tensor.
        :param updated_embeddings: A 2D array of shape (n, embedding_dim) containing the updated embeddings,
                    where n is the number of items to update. Can be either a NumPy array or a PyTorch tensor.
        """
        ids = self._to_numpy(ids)
        updated_embeddings = self._to_numpy(updated_embeddings)

        if self.index_type != "flat":
            raise ValueError(
                f"Update is not supported by FAISS for this index type; use 'flat' instead of: {self.index_type}"
            )
        self.remove(ids)
        self.add(updated_embeddings)

    def search(self, query_vectors: Any, k: int = 10) -> Any:
        """
        Search the memory for the top k closest embeddings to the query vectors.

        :param query_vectors: A 2D array or tensor of shape (n, embedding_dim) containing the query vectors,
                            where n is the number of query vectors.
        :param k: The number of nearest neighbors to return for each query.
        :return: A tuple containing two 2D arrays or tensors for indices and distances, both of shape (n, k).
        """
        if self.use_gpu and isinstance(query_vectors, torch.Tensor):
            n_query = query_vectors.shape[0]
            distances = torch.empty((n_query, k), device=self.gpu_device)  # type: ignore
            indices = torch.empty((n_query, k), dtype=torch.long, device=self.gpu_device)  # type: ignore

            ptr_query = faiss.torch_utils.swig_ptr_from_FloatTensor(query_vectors)
            ptr_distances = faiss.torch_utils.swig_ptr_from_FloatTensor(distances)
            ptr_indices = faiss.torch_utils.swig_ptr_from_LongTensor(indices)
            self.index.search_c(n_query, ptr_query, k, ptr_distances, ptr_indices)

            return indices, distances
        else:
            # Remember whether the caller passed a tensor *before* converting,
            # otherwise the tensor-return branch below can never be taken
            was_tensor = isinstance(query_vectors, torch.Tensor)
            if was_tensor:
                query_vectors = query_vectors.detach().cpu().numpy()
            distances, indices = self.index.search(query_vectors, k)
            if was_tensor:
                return torch.from_numpy(indices), torch.from_numpy(distances)
            else:
                return indices, distances

    def age_memory(self, decay_factor: float = 0.999) -> None:
        """
        Age the memory embeddings by multiplying them with a decay factor.

        :param decay_factor: float, factor to multiply the embeddings with (default: 0.999)
        """
        assert 0 <= decay_factor <= 1, "Decay factor should be between 0 and 1."

        # Get the current embeddings from the memory
        current_embeddings = np.zeros((self.index.ntotal, self.embedding_dim), dtype=np.float32)
        for idx in range(self.index.ntotal):
            current_embeddings[idx] = self.index.reconstruct(idx)

        # Apply the decay factor
        aged_embeddings = current_embeddings * decay_factor

        # Update the memory with the aged embeddings
        ids = np.arange(self.index.ntotal, dtype=np.int64)
        self.update(ids, aged_embeddings)

    def forget_randomly(self) -> None:
        """
        Remove a random subset of embeddings from the memory based on the forgetfulness factor.
        """
        assert 0 <= self.forgetfulness_factor <= 1, "Forgetfulness factor should be between 0 and 1."

        total_embeddings = self.index.ntotal
        num_embeddings_to_forget = int(total_embeddings * self.forgetfulness_factor)

        # Select random embeddings to forget
        ids_to_forget = np.random.choice(total_embeddings, size=num_embeddings_to_forget, replace=False)

        # Remove the selected embeddings from the memory
        self.remove(ids_to_forget)

    def garbage_collect(self, threshold: float = 1e-6) -> npt.NDArray[Any]:
        """
        Remove nearly zero vectors from the memory and return the indices of the empty vectors.

        Parameters:
        threshold (float): Threshold value to consider a vector nearly zero.

        Returns:
        npt.NDArray[Any]: Indices of the empty vectors.
        """
        # Fetch all embeddings from the memory
        embeddings = self.get_all_embeddings()

        # Calculate the L2 norms of the embeddings
        norms = np.linalg.norm(embeddings, axis=1)

        # Identify nearly zero vectors based on the threshold
        nearly_zero_vectors = np.where(norms < threshold)[0]

        # Remove nearly zero vectors from the memory
        self.remove(nearly_zero_vectors)

        return nearly_zero_vectors

    def get_all_embeddings(self) -> npt.NDArray[Any]:
        """
        Retrieve all embeddings stored in the memory.

        Returns:
        npt.NDArray[Any]: A 2D array of shape (n, embedding_dim) containing all stored embeddings,
                          where n is the number of stored items.
        """
        # Return the stored input_vectors directly
        return self.input_vectors[: self.index.ntotal]  # Use slicing to get only the added items

    def __getitem__(self, index: int) -> Any:
        """
        Retrieve the input vector at the specified index.

        :param index: The index of the input vector to retrieve.
        :return: A 1D array of shape (embedding_dim,) containing the input vector.
        """
        if index >= self.index.ntotal:
            raise IndexError("Index out of range.")
        return self.input_vectors[index]

    def size(self) -> int:
        """
        Get the number of items stored in the memory.

        :return: The number of items stored in the memory.
        """
        return int(self.index.ntotal)
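For reference, the search contract above (top-k L2 neighbors returned as `(indices, distances)`) can be sketched without faiss via a brute-force NumPy scan. `brute_force_search` is a hypothetical helper for illustration, not part of this module:

```python
import numpy as np

def brute_force_search(memory: np.ndarray, queries: np.ndarray, k: int):
    """Top-k squared-L2 neighbors, mirroring AssociativeMemory.search's return shape."""
    # Squared L2 distance between every query and every stored vector
    diffs = queries[:, None, :] - memory[None, :, :]        # (n_query, n_mem, dim)
    dists = np.sum(diffs * diffs, axis=-1)                  # (n_query, n_mem)
    indices = np.argsort(dists, axis=1)[:, :k]              # (n_query, k)
    distances = np.take_along_axis(dists, indices, axis=1)  # (n_query, k)
    return indices, distances

memory = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
queries = np.array([[0.9, 0.1]], dtype=np.float32)
idx, dst = brute_force_search(memory, queries, k=2)
# Nearest stored vector to (0.9, 0.1) is [1.0, 0.0] at index 1
```

The faiss index gives the same semantics with sublinear search time for the IVF and HNSW index types.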


================================================
FILE: src/memory/batch_associative.py
================================================
from typing import Any, List, Tuple, Union

import numpy as np
import numpy.typing as npt

try:
    import torch

    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False

from .associative import AssociativeMemory


class BatchAssociativeMemory:
    def __init__(
        self,
        num_batches: int,
        memory_size: int,
        embedding_dim: int,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a batch associative memory.

        :param num_batches: Number of separate associative memories in this batch.
        :param memory_size: The maximum number of items each memory can store.
        :param embedding_dim: The dimensionality of the input embeddings.
        :param kwargs: Additional arguments to be passed to the AssociativeMemory constructor.
        """
        self.num_batches = num_batches
        self.embedding_dim = embedding_dim
        self.memories = [AssociativeMemory(memory_size, embedding_dim, **kwargs) for _ in range(num_batches)]

    def batch_add(self, embeddings: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of add operations. Assumes a single token per batch.

        :param embeddings: A 2D array of shape (num_batches, embedding_dim) containing the embeddings
                           of single tokens to be added for each batch.
        """
        for memory, batch_embeddings in zip(self.memories, embeddings):
            memory.add(batch_embeddings.reshape(1, -1))

    def batch_remove(self, ids: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of remove operations. Assumes a single token per batch.

        :param ids: A 1D array of shape (num_batches,) containing the indices of the items to be removed.
        """
        for memory, batch_id in zip(self.memories, ids):
            memory.remove(np.array([batch_id], dtype=np.int64))

    def batch_search(
        self, query_vectors: Union[npt.NDArray[Any], "torch.Tensor"], k: int = 10
    ) -> List[Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]]:
        """
        Perform a batch of search operations. Assumes a single token per batch.

        :param query_vectors: A 2D array of shape (num_batches, embedding_dim) containing the query vectors
                              of single tokens for each batch.
        :param k: The number of nearest neighbors to return for each query.
        :return: A list of tuples, each containing two 2D arrays of shape (1, k) for indices and distances.
        """
        results = []
        for memory, batch_query_vector in zip(self.memories, query_vectors.reshape(-1, 1, self.embedding_dim)):
            indices, distances = memory.search(batch_query_vector, k)
            results.append((indices, distances))
        return results
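The key property of the batch wrapper above is that each batch element reads and writes only its own independent store. A faiss-free sketch of that per-batch search loop, with `batch_search_sketch` as a hypothetical stand-in for `batch_search`:

```python
import numpy as np

def batch_search_sketch(stores, batch_queries, k):
    """Each batch element searches only its own store, as BatchAssociativeMemory does."""
    results = []
    for store, query in zip(stores, batch_queries):
        dists = np.sum((store - query) ** 2, axis=1)   # squared L2 to every stored row
        idx = np.argsort(dists)[:k]
        results.append((idx, dists[idx]))
    return results

stores = [np.eye(3, dtype=np.float32),            # store for batch element 0
          np.eye(3, dtype=np.float32)[::-1]]      # store for batch element 1 (rows reversed)
queries = np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=np.float32)
results = batch_search_sketch(stores, queries, k=1)
# The same query vector hits different ids because the stores are independent
```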


================================================
FILE: src/memory/multitoken_batch_associative.py
================================================
from typing import Any, List, Tuple, Union

import numpy as np
import numpy.typing as npt

try:
    import torch

    _TORCH_AVAILABLE = True
except ImportError:
    _TORCH_AVAILABLE = False

from .associative import AssociativeMemory


class MultiTokenBatchAssociativeMemory:
    def __init__(
        self,
        num_batches: int,
        memory_size: int,
        embedding_dim: int,
        **kwargs: Any,
    ) -> None:
        """
        Initialize a batch associative memory.

        :param num_batches: Number of separate associative memories in this batch.
        :param memory_size: The maximum number of items each memory can store.
        :param embedding_dim: The dimensionality of the input embeddings.
        :param kwargs: Additional arguments to be passed to the AssociativeMemory constructor.
        """
        self.num_batches = num_batches
        self.embedding_dim = embedding_dim
        self.memories = [AssociativeMemory(memory_size, embedding_dim, **kwargs) for _ in range(num_batches)]

    def batch_add(self, embeddings: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of add operations.

        :param embeddings: A 3D array of shape (num_batches, num_tokens, embedding_dim) containing the embeddings
                           to be added for each batch.
        """
        for memory, batch_embeddings in zip(self.memories, embeddings):
            memory.add(batch_embeddings)

    def batch_remove(self, ids: Union[npt.NDArray[Any], "torch.Tensor"]) -> None:
        """
        Perform a batch of remove operations.

        :param ids: A 2D array of shape (num_batches, num_tokens) containing the indices of the items to be removed.
        """
        for memory, batch_ids in zip(self.memories, ids):
            memory.remove(batch_ids)

    def batch_search(
        self, query_vectors: Union[npt.NDArray[Any], "torch.Tensor"], k: int = 10
    ) -> List[Tuple[npt.NDArray[np.int64], npt.NDArray[np.float32]]]:
        """
        Perform a batch of search operations.

        :param query_vectors: A 3D array of shape (num_batches, num_tokens, embedding_dim) containing the query vectors
                              for each batch.
        :param k: The number of nearest neighbors to return for each query.
        :return: A list of tuples, each containing two 2D arrays of shape (num_tokens, k) for indices and distances.
        """
        results = []
        for memory, batch_query_vectors in zip(self.memories, query_vectors):
            indices, distances = memory.search(batch_query_vectors, k)
            results.append((indices, distances))
        return results


================================================
FILE: src/models/__init__.py
================================================


================================================
FILE: src/models/gpt2_associative.py
================================================
from typing import Any

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

from memory.batch_associative import BatchAssociativeMemory


class VardaGPTAssociative(nn.Module):
    def __init__(
        self,
        gpt2_model_name: str = "gpt2",
        memory_size: int = 10000,
        memory_dim: int = 768,
        index_type: str = "flat",
        num_clusters: int = 1024,
        num_search_results: int = 5,
        use_gpu: bool = False,
        batch_size: int = 1,
        forgetfulness_factor: float = 0.001,
    ):
        """
        Initialize a GPT-2 model with associative memory.

        :param gpt2_model_name: The name of the GPT-2 model to load. Default is "gpt2".
        :param memory_size: The maximum number of items the associative memory can store. Default is 10000.
        :param memory_dim: The dimensionality of the embeddings stored in the associative memory. Default is 768.
        :param index_type: The type of index used for the associative memory. Default is "flat".
        :param num_clusters: The number of clusters to use for the memory if the index type is "ivf". Default is 1024.
        :param num_search_results: The number of search results to return from the associative memory. Default is 5.
        :param use_gpu: Whether to use the GPU for the model if available. Default is False.
        :param batch_size: The number of per-batch associative memories to create. Default is 1.
        :param forgetfulness_factor: The fraction of memory items to randomly forget on each forward pass. Default is 0.001.
        """
        super(VardaGPTAssociative, self).__init__()

        # Set up the device for the model
        self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

        # Load the GPT-2 model and configuration
        self.gpt2_config = GPT2Config.from_pretrained(gpt2_model_name)
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)

        # Initialize the BatchAssociativeMemory module
        self.memory = BatchAssociativeMemory(
            num_batches=batch_size,
            memory_size=memory_size,
            embedding_dim=memory_dim,
            index_type=index_type,
            num_clusters=num_clusters,
            use_gpu=use_gpu,
            forgetfulness_factor=forgetfulness_factor,
        )

        # Define dimensions for search results and output
        self.search_results_dim = memory_dim * num_search_results

        # Linear layers for concatenated input, storable vector, and store decision
        self.fc = nn.Linear(self.gpt2_config.n_embd + self.search_results_dim, self.gpt2_config.n_embd)
        self.fc_storable_vector = nn.Linear(self.gpt2_config.n_embd, memory_dim)
        self.fc_store_decision = nn.Linear(self.gpt2_config.n_embd, 1)

        # Move all layers to the device
        self.to(self.device)

        self.memory_dim = memory_dim
        self.num_search_results = num_search_results
        self.forgetfulness_factor = forgetfulness_factor

    def forward(self, input_vectors: torch.Tensor) -> Any:
        """
        Perform a forward pass through the GPT-2 model with associative memory.

        :param input_vectors: A 3D tensor of shape (batch_size, sequence_length, input_dim) containing
            the input vectors for each token in the batch.
        :return: A 3D tensor of shape (batch_size, sequence_length, vocab_size) containing
            the logits from the GPT-2 model.
        """
        input_vectors = input_vectors.to(self.device)
        batch_size, seq_len, _ = input_vectors.shape

        # Initialize search_results tensor with the correct shape
        search_results = torch.zeros((batch_size, seq_len, self.search_results_dim), device=self.device)

        # Search for relevant results for each item in the batch
        for t in range(seq_len):
            search_results_list = self.memory.batch_search(input_vectors[:, t, :].squeeze(1), self.num_search_results)
            retrieved_embeddings_list = []
            # Retrieve and concatenate search results with input vectors
            for ctr, (indices, _) in enumerate(search_results_list):
                retrieved_embeddings = torch.cat(
                    [
                        self.memory.memories[ctr][i].unsqueeze(0)
                        if i >= 0
                        else torch.zeros(self.memory_dim).unsqueeze(0)
                        for i in indices.squeeze()
                    ],
                    dim=0,
                )
                # Update the corresponding search_results tensor
                retrieved_embeddings_list.append(retrieved_embeddings)
            search_results[:, t, :] = torch.cat(retrieved_embeddings_list, dim=0).view(batch_size, -1)

        concatenated_input = torch.cat([input_vectors, search_results], dim=-1)

        input_vectors = self.fc(concatenated_input)

        # Pass input_vectors through GPT-2 model's transformer and obtain hidden states
        transformer_outputs = self.gpt2_model.transformer(inputs_embeds=input_vectors)
        hidden_states = transformer_outputs.last_hidden_state

        # Get logits from hidden states
        logits = self.gpt2_model.lm_head(hidden_states)

        # Calculate storable vector and store decision
        storable_vector = self.fc_storable_vector(hidden_states)
        store_decision = self.fc_store_decision(hidden_states)

        # Store the storable_vector in the associative memory if the store_decision is affirmative
        store_threshold = 0.5  # Define a threshold for store decision
        store_mask = (store_decision > store_threshold).float()
        storable_vector_to_store = storable_vector * store_mask

        for i in range(seq_len):
            storable_vector_to_store_i = storable_vector_to_store[:, i, :].view(batch_size, -1).detach()
            self.memory.batch_add(storable_vector_to_store_i)

        # Randomly forget items from the memory with a specified probability
        for memory in self.memory.memories:
            memory.forget_randomly()

        return logits
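The store gate above multiplies rejected vectors by zero rather than skipping the add, so rejected slots are written as (near-)zero vectors that `AssociativeMemory.garbage_collect` can later reclaim. A minimal NumPy sketch of that gating, with `gate_storable` as a hypothetical helper:

```python
import numpy as np

def gate_storable(storable: np.ndarray, decision: np.ndarray, threshold: float = 0.5):
    """Zero out vectors whose store decision falls below the threshold.

    Mirrors `storable_vector * store_mask` above: rejected vectors become
    zeros (reclaimable by garbage_collect) instead of being skipped.
    """
    mask = (decision > threshold).astype(np.float32)     # (n, 1) gate per vector
    return storable * mask

storable = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
decision = np.array([[0.9], [0.1]], dtype=np.float32)    # keep first, drop second
gated = gate_storable(storable, decision)
# gated == [[1., 2.], [0., 0.]]
```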


================================================
FILE: src/models/gpt2_working.py
================================================
from typing import Any, Optional

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

from memory.associative import AssociativeMemory


class VardaGPTWorking(nn.Module):
    def __init__(
        self,
        gpt2_model_name: str = "gpt2-small",
        memory_size: int = 10000,
        memory_dim: int = 768,
        index_type: str = "flat",
        num_clusters: int = 1024,
        num_search_results: int = 5,
        use_gpu: bool = False,
    ):
        super(VardaGPTWorking, self).__init__()

        # Set up the device for the model
        self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

        # Load the GPT-2 model and configuration
        self.gpt2_config = GPT2Config.from_pretrained(gpt2_model_name)
        self.gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)

        # Initialize the AssociativeMemory module
        self.memory = AssociativeMemory(
            memory_size=memory_size,
            embedding_dim=memory_dim,
            index_type=index_type,
            num_clusters=num_clusters,
            use_gpu=use_gpu,
        )

        # Define dimensions for search results and output
        self.search_results_dim = memory_dim * num_search_results

        # Linear layers for concatenated input, storable vector, and store decision
        self.fc = nn.Linear(self.gpt2_config.n_embd + self.search_results_dim, self.gpt2_config.n_embd)
        self.fc_storable_vector = nn.Linear(self.gpt2_config.n_embd, memory_dim)
        self.fc_store_decision = nn.Linear(self.gpt2_config.n_embd, 1)

        # Move all layers to the device
        self.to(self.device)

        self.num_search_results = num_search_results

    def forward(self, input_vectors: torch.Tensor, memory_input: Optional[torch.Tensor] = None) -> Any:
        input_vectors = input_vectors.to(self.device)

        # Search for relevant results if memory_input is provided
        if memory_input is not None:
            indices, distances = self.memory.search(memory_input.cpu().numpy())

            # Retrieve and concatenate search results with input vectors
            search_results = self.memory.get_all_embeddings()[indices].reshape(-1, self.search_results_dim)
            search_results = torch.tensor(search_results).to(self.device)
            concatenated_input = torch.cat([input_vectors, search_results], dim=-1)

            # Pass concatenated input through linear layer
            input_vectors = self.fc(concatenated_input)

        # Pass input_vectors through GPT-2 model's transformer and obtain hidden states
        transformer_outputs = self.gpt2_model.transformer(inputs_embeds=input_vectors)
        hidden_states = transformer_outputs.last_hidden_state

        # Get logits from hidden states
        logits = self.gpt2_model.lm_head(hidden_states)

        # Calculate storable vector and store decision
        storable_vector = self.fc_storable_vector(hidden_states)
        store_decision = self.fc_store_decision(hidden_states)

        # Store the storable_vector in the associative memory if the store_decision is affirmative
        store_threshold = 0.5  # Define a threshold for store decision
        store_mask = (store_decision > store_threshold).float()
        storable_vector_to_store = storable_vector * store_mask
        self.memory.add(storable_vector_to_store)

        return logits


================================================
FILE: src/train.py
================================================
import argparse
import time
from typing import Any

import torch
import torch.optim as optim
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import track
from rich.table import Table
from rich.theme import Theme
from torch.utils.data import DataLoader

from data import load_wikitext2
from models.gpt2_associative import VardaGPTAssociative


def train(
    model: VardaGPTAssociative,
    train_loader: DataLoader[Any],
    valid_loader: DataLoader[Any],
    epochs: int,
    lr: float,
    device: torch.device,
) -> None:
    """
    Train the model with the given training and validation data.

    :param model: The VardaGPTAssociative model to be trained.
    :param train_loader: DataLoader for the training data.
    :param valid_loader: DataLoader for the validation data.
    :param epochs: Number of epochs to train the model.
    :param lr: Learning rate for the optimizer.
    :param device: The device to use for training (CPU or GPU).
    """

    # Initialize the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_function = torch.nn.CrossEntropyLoss()

    # Create a console object for printing colorful prompts
    theme = Theme({"info": "dim green", "warning": "yellow", "error": "bold red"})
    console = Console(theme=theme)

    # Training loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for _, batch in enumerate(track(train_loader, description=f"[bold][info]Epoch {epoch + 1}")):
            # Move the input to the device
            input_vectors = batch.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            logits = model(input_vectors)

            # Calculate loss
            loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))

            # Backward pass
            loss.backward()

            # Update weights
            optimizer.step()

            epoch_loss += loss.item()

        # Calculate average epoch loss
        average_epoch_loss = epoch_loss / len(train_loader)
        epoch_time = time.time() - epoch_start_time

        # Validation
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for _, batch in enumerate(valid_loader):
                input_vectors = batch.to(device)
                logits = model(input_vectors)
                loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))
                valid_loss += loss.item()

        # Calculate average validation loss
        average_valid_loss = valid_loss / len(valid_loader)

        # Print epoch summary
        table = Table(title=f"Epoch {epoch + 1} Summary")
        table.add_column("Metric", style="bold")
        table.add_column("Value", style="bold")
        table.add_row("Training Loss", f"{average_epoch_loss:.4f}")
        table.add_row("Validation Loss", f"{average_valid_loss:.4f}")
        table.add_row("Time", f"{epoch_time:.2f} seconds")
        console.print(table)


if __name__ == "__main__":
    console = Console()

    console.print(Panel.fit("[bold blue]VardaGPTAssociative Training Script[/bold blue]"))

    description = """\
This script trains a VardaGPTAssociative model on the WikiText-2 dataset. The model combines GPT-2 with an associative memory to improve context retrieval.
"""
    console.print(Markdown(description))

    parser = argparse.ArgumentParser(description="Train VardaGPTAssociative model on WikiText-2 dataset")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train the model")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument(
        "--memory_size", type=int, default=10000, help="Maximum number of items the associative memory can store"
    )
    parser.add_argument(
        "--memory_dim", type=int, default=768, help="Dimensionality of the embeddings stored in the associative memory"
    )
    parser.add_argument("--index_type", type=str, default="flat", help="Type of index used for the associative memory")
    parser.add_argument(
        "--num_clusters",
        type=int,
        default=1024,
        help="Number of clusters to use for the memory if the index type is 'ivf'",
    )
    parser.add_argument(
        "--num_search_results",
        type=int,
        default=5,
        help="Number of search results to return from the associative memory",
    )
    parser.add_argument("--use_gpu", action="store_true", help="Whether to use the GPU for the model if available")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument(
        "--forgetfulness_factor", type=float, default=0.001, help="Forgetfulness factor for the associative memory"
    )

    args = parser.parse_args()

    console.print("[bold green]Training settings:[/bold green]")
    console.print(f"  Epochs: {args.epochs}")
    console.print(f"  Learning rate: {args.learning_rate}")

    model = VardaGPTAssociative(
        gpt2_model_name="gpt2",
        memory_size=args.memory_size,
        memory_dim=args.memory_dim,
        index_type=args.index_type,
        num_clusters=args.num_clusters,
        num_search_results=args.num_search_results,
        use_gpu=args.use_gpu,
        batch_size=args.batch_size,
        forgetfulness_factor=args.forgetfulness_factor,
    )

    # Move the model to the device
    device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
    model.to(device)

    train_loader, valid_loader, test_loader = load_wikitext2()

    # Train the model
    train(model, train_loader, valid_loader, args.epochs, args.learning_rate, device)
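The loss computation in the loop above flattens logits to (batch*seq, vocab) and targets to (batch*seq,) before applying cross-entropy; note that `CrossEntropyLoss` expects the flattened targets to be integer token ids. A NumPy sketch of that flattened computation, assuming integer targets (`flattened_cross_entropy` is a hypothetical helper for illustration):

```python
import numpy as np

def flattened_cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean token-level cross-entropy after flattening, as in
    loss_function(logits.view(-1, V), targets.view(-1))."""
    flat_logits = logits.reshape(-1, logits.shape[-1])     # (batch*seq, vocab)
    flat_targets = targets.reshape(-1)                     # (batch*seq,)
    # Numerically stable log-softmax
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(flat_targets)), flat_targets].mean())

logits = np.zeros((2, 3, 4), dtype=np.float32)             # uniform over 4 classes
targets = np.zeros((2, 3), dtype=np.int64)
loss = flattened_cross_entropy(logits, targets)            # log(4) ≈ 1.386
```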


================================================
FILE: src/train_parallel.py
================================================
import argparse
import time
from typing import Any, Union

import torch
import torch.optim as optim
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import track
from rich.table import Table
from rich.theme import Theme
from torch.utils.data import DataLoader

from data import load_wikitext2
from models.gpt2_associative import VardaGPTAssociative


def train(
    model: Union[VardaGPTAssociative, torch.nn.DataParallel],
    train_loader: DataLoader[Any],
    valid_loader: DataLoader[Any],
    epochs: int,
    lr: float,
    device: torch.device,
) -> None:
    """
    Train the model with the given training and validation data.

    :param model: The VardaGPTAssociative model to be trained.
    :param train_loader: DataLoader for the training data.
    :param valid_loader: DataLoader for the validation data.
    :param epochs: Number of epochs to train the model.
    :param lr: Learning rate for the optimizer.
    :param device: The device to use for training (CPU or GPU).
    """

    # Initialize the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_function = torch.nn.CrossEntropyLoss()

    # Create a console object for printing colorful prompts
    theme = Theme({"info": "dim green", "warning": "yellow", "error": "bold red"})
    console = Console(theme=theme)

    # Training loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_start_time = time.time()

        for _, batch in enumerate(track(train_loader, description=f"[bold][info]Epoch {epoch + 1}")):
            # Move the input to the device
            input_vectors = batch.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            logits = model(input_vectors)

            # Calculate loss
            loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))

            # Backward pass
            loss.backward()

            # Update weights
            optimizer.step()

            epoch_loss += loss.item()

        # Calculate average epoch loss
        average_epoch_loss = epoch_loss / len(train_loader)
        epoch_time = time.time() - epoch_start_time

        # Validation
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for batch in valid_loader:
                input_vectors = batch.to(device)
                logits = model(input_vectors)
                loss = loss_function(logits.view(-1, logits.shape[-1]), input_vectors.view(-1))
                valid_loss += loss.item()

        # Calculate average validation loss
        average_valid_loss = valid_loss / len(valid_loader)

        # Print epoch summary
        table = Table(title=f"Epoch {epoch + 1} Summary")
        table.add_column("Metric", style="bold")
        table.add_column("Value", style="bold")
        table.add_row("Training Loss", f"{average_epoch_loss:.4f}")
        table.add_row("Validation Loss", f"{average_valid_loss:.4f}")
        table.add_row("Time", f"{epoch_time:.2f} seconds")
        console.print(table)


if __name__ == "__main__":
    console = Console()

    console.print(Panel.fit("[bold blue]VardaGPTAssociative Training Script[/bold blue]"))

    description = """\
This script trains a VardaGPTAssociative model on the WikiText-2 dataset. The model combines GPT-2 with an associative memory to improve context retrieval.
"""
    console.print(Markdown(description))

    parser = argparse.ArgumentParser(description="Train VardaGPTAssociative model on WikiText-2 dataset")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs to train the model")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument(
        "--memory_size", type=int, default=10000, help="Maximum number of items the associative memory can store"
    )
    parser.add_argument(
        "--memory_dim", type=int, default=768, help="Dimensionality of the embeddings stored in the associative memory"
    )
    parser.add_argument("--index_type", type=str, default="flat", help="Type of index used for the associative memory")
    parser.add_argument(
        "--num_clusters",
        type=int,
        default=1024,
        help="Number of clusters to use for the memory if the index type is 'ivf'",
    )
    parser.add_argument(
        "--num_search_results",
        type=int,
        default=5,
        help="Number of search results to return from the associative memory",
    )
    parser.add_argument("--use_gpu", action="store_true", help="Whether to use the GPU for the model if available")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument(
        "--forgetfulness_factor", type=float, default=0.001, help="Forgetfulness factor for the associative memory"
    )

    args = parser.parse_args()

    console.print("[bold green]Training settings:[/bold green]")
    console.print(f"  Epochs: {args.epochs}")
    console.print(f"  Learning rate: {args.learning_rate}")

    model = VardaGPTAssociative(
        gpt2_model_name="gpt2",
        memory_size=args.memory_size,
        memory_dim=args.memory_dim,
        index_type=args.index_type,
        num_clusters=args.num_clusters,
        num_search_results=args.num_search_results,
        use_gpu=args.use_gpu,
        batch_size=args.batch_size,
        forgetfulness_factor=args.forgetfulness_factor,
    )

    # Move the model to the device
    device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")

    if torch.cuda.device_count() > 1 and args.use_gpu:
        print(f"Using {torch.cuda.device_count()} GPUs for training.")
        model = torch.nn.DataParallel(model)  # type: ignore

    model.to(device)

    train_loader, valid_loader, test_loader = load_wikitext2()

    # Train the model
    train(model, train_loader, valid_loader, args.epochs, args.learning_rate, device)


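The `DataParallel` branch above replicates the module on every visible GPU, splits each input batch along dimension 0, and gathers the outputs, so `train()` can stay device-agnostic. A minimal runnable sketch of the same wrapping pattern, using a stand-in `nn.Linear` rather than the repository's `VardaGPTAssociative`:

```python
# Sketch of the DataParallel wrapping used in train_parallel.py.
# nn.Linear here is a hypothetical stand-in for VardaGPTAssociative;
# on a CPU-only machine the bare module is used unchanged.
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # one replica per GPU, batch split on dim 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

batch = torch.randn(6, 8, device=device)
out = model(batch)  # same call signature with or without the wrapper
print(out.shape)    # torch.Size([6, 4])
```

Because the wrapped and unwrapped modules share a call signature, the `Union[VardaGPTAssociative, torch.nn.DataParallel]` annotation on `train()` is the only place the wrapping is visible.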
================================================
FILE: test/__init__.py
================================================


================================================
FILE: test/memory/__init__..py
================================================


================================================
FILE: test/memory/test_associative.py
================================================
# type: ignore
import numpy as np
import torch
import pytest

from src.memory.associative import AssociativeMemory


@pytest.fixture(params=["numpy", "torch"])
def embeddings(request):
    embeddings_np = np.random.rand(1000, 768).astype(np.float32)
    if request.param == "numpy":
        return embeddings_np
    else:
        return torch.from_numpy(embeddings_np)


def test_add_and_search(embeddings):
    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)

    memory.add(embeddings)

    query_vectors = embeddings[:5]
    indices, distances = memory.search(query_vectors)

    assert indices.shape == (5, 10)


def test_update(embeddings):
    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)

    memory.add(embeddings)

    element_id = np.array([0], dtype=np.int64)
    updated_embedding = embeddings[:1]
    memory.update(element_id, updated_embedding)

    query_vectors = embeddings[1:2]
    indices, distances = memory.search(query_vectors)

    assert 0 not in indices[0]


def test_refresh_memory(embeddings):
    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)

    memory.add(embeddings)

    memory.refresh()

    query_vectors = embeddings[:5]
    indices, distances = memory.search(query_vectors)

    assert indices.shape == (5, 10)


def test_getitem(embeddings):
    memory = AssociativeMemory(memory_size=50000, embedding_dim=768)

    memory.add(embeddings)

    # Test __getitem__ method
    retrieved_vector = memory[5]
    assert np.allclose(retrieved_vector, embeddings[5])


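The flat-index behaviour these tests exercise amounts to brute-force L2 nearest-neighbour search over every stored vector. A self-contained NumPy sketch of that search (a conceptual stand-in, not the repository's FAISS-backed implementation):

```python
# Brute-force L2 search, conceptually what a "flat" index does:
# compare each query against every stored vector and keep the k closest.
import numpy as np

def flat_search(store: np.ndarray, queries: np.ndarray, k: int = 10):
    """Return (indices, distances) of the k nearest stored vectors per query."""
    # Squared L2 distance between every query and every stored vector.
    d2 = ((queries[:, None, :] - store[None, :, :]) ** 2).sum(-1)
    idx = np.argsort(d2, axis=1)[:, :k]
    return idx, np.take_along_axis(d2, idx, axis=1)

store = np.random.rand(1000, 768).astype(np.float32)
indices, distances = flat_search(store, store[:5], k=10)
assert indices.shape == (5, 10)               # mirrors test_add_and_search
assert (indices[:, 0] == np.arange(5)).all()  # each query finds itself first
```

The second assertion is why `test_add_and_search` can query with the first five stored embeddings: an exact (non-approximate) index always returns a stored vector at distance zero from itself.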
================================================
FILE: test/memory/test_associative_batch.py
================================================
# type: ignore
import numpy as np
import pytest

try:
    import torch
except ImportError:
    pass

from src.memory.batch_associative import BatchAssociativeMemory


@pytest.fixture
def batch_memory():
    num_batches = 3
    memory_size = 1000
    embedding_dim = 128
    return BatchAssociativeMemory(num_batches, memory_size, embedding_dim)


@pytest.fixture(params=["numpy", "torch"])
def tensor_type(request):
    return request.param


def create_tensor(tensor_type, data):
    if tensor_type == "numpy":
        return data
    else:
        return torch.from_numpy(data)


def test_batch_add(batch_memory, tensor_type):
    # Create integer batched embeddings
    embeddings_np = np.random.randint(0, 10, (batch_memory.num_batches, batch_memory.embedding_dim)).astype(np.float32)
    embeddings = create_tensor(tensor_type, embeddings_np)

    # Add embeddings to the batch memory
    batch_memory.batch_add(embeddings)

    # Check if the embeddings are added to the corresponding memories
    for i, memory in enumerate(batch_memory.memories):
        all_embeddings = memory.get_all_embeddings()
        assert all_embeddings.shape == (1, batch_memory.embedding_dim)
        assert np.array_equal(embeddings_np[i], all_embeddings[0])


def test_batch_remove(batch_memory, tensor_type):
    # Create integer batched embeddings
    embeddings_np = np.random.randint(0, 10, (batch_memory.num_batches, batch_memory.embedding_dim)).astype(np.float32)
    embeddings = create_tensor(tensor_type, embeddings_np)

    # Add embeddings to the batch memory
    batch_memory.batch_add(embeddings)

    # Remove embeddings by index
    indices_to_remove_np = np.zeros(batch_memory.num_batches, dtype=np.int64)
    indices_to_remove = create_tensor(tensor_type, indices_to_remove_np)
    batch_memory.batch_remove(indices_to_remove)

    # Check if the embeddings are removed from the corresponding memories
    for memory in batch_memory.memories:
        all_embeddings = memory.get_all_embeddings()
        assert all_embeddings.shape == (0, batch_memory.embedding_dim)


def test_batch_search(batch_memory, tensor_type):
    # Create a specific vector
    specific_vector_np = np.ones((1, batch_memory.embedding_dim), dtype=np.float32)
    specific_vector = create_tensor(tensor_type, specific_vector_np)

    # Create batched embeddings
    embeddings_np = np.random.randint(0, 10, size=(batch_memory.num_batches - 1, batch_memory.embedding_dim)).astype(
        np.float32
    )

    # Combine the specific vector with random embeddings
    embeddings_np = np.vstack((specific_vector_np, embeddings_np))
    embeddings = create_tensor(tensor_type, embeddings_np)

    # Add embeddings to the batch memory
    batch_memory.batch_add(embeddings)

    # Perform batched search for the specific vector
    k = 1
    search_results = batch_memory.batch_search(specific_vector, k)

    # Check if the first search result is the same as the specific vector
    indices, distances = search_results[0]
    found_vector = batch_memory.memories[0].get_all_embeddings()[indices[0][0]]
    assert np.allclose(specific_vector_np, found_vector)


================================================
FILE: test/memory/test_memory_features.py
================================================
# type: ignore
import numpy as np
import torch
import pytest

from src.memory.associative import AssociativeMemory


@pytest.fixture(params=["numpy", "torch"])
def embeddings(request):
    embeddings_np = np.random.rand(1000, 128).astype(np.float32)
    if request.param == "numpy":
        return embeddings_np
    else:
        return torch.from_numpy(embeddings_np)


@pytest.fixture
def forgetful_memory():
    memory_size = 1000
    embedding_dim = 128
    forgetfulness_factor = 0.9
    return AssociativeMemory(memory_size, embedding_dim, forgetfulness_factor=forgetfulness_factor)


@pytest.fixture
def memory():
    memory_size = 1000
    embedding_dim = 128
    return AssociativeMemory(memory_size, embedding_dim)


def test_age_memory(embeddings, forgetful_memory):
    # Add some embeddings to the memory
    forgetful_memory.add(embeddings[:10])

    # Age the memory
    forgetful_memory.age_memory(0.5)

    # Check if the embeddings have been aged
    aged_embeddings = forgetful_memory.get_all_embeddings()
    assert aged_embeddings.shape == embeddings[:10].shape
    assert np.allclose(embeddings[:10] * 0.5, aged_embeddings)


def test_forget_randomly(embeddings, forgetful_memory):
    # Add some embeddings to the memory
    forgetful_memory.add(embeddings[:100])

    # Forget randomly
    forgetful_memory.forget_randomly()

    # Check if the number of embeddings in memory has decreased
    remaining_embeddings = forgetful_memory.get_all_embeddings()
    expected_remaining_embeddings = int(embeddings[:100].shape[0] * (1 - forgetful_memory.forgetfulness_factor))
    assert (
        remaining_embeddings.shape[0] == expected_remaining_embeddings
        or remaining_embeddings.shape[0] == expected_remaining_embeddings + 1
    )


def test_garbage_collect(embeddings, memory):
    # Add some nearly zero embeddings to the memory
    nearly_zero_embeddings = embeddings[:10] * 1e-7
    memory.add(nearly_zero_embeddings)

    # Garbage collect
    removed_indices = memory.garbage_collect(threshold=1e-6)

    # Check if the nearly zero embeddings have been removed
    assert len(removed_indices) == len(nearly_zero_embeddings)
    remaining_embeddings = memory.get_all_embeddings()
    assert remaining_embeddings.shape[0] == 0
    assert np.array_equal(removed_indices, np.arange(len(nearly_zero_embeddings)))


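The two forgetting mechanisms covered above are multiplicative decay (`age_memory`) and random row removal (`forget_randomly`). A hypothetical NumPy sketch of both, not the repository's implementation (which operates through the underlying index):

```python
# Hypothetical sketch of memory aging and random forgetting.
import numpy as np

rng = np.random.default_rng(0)
memory = rng.random((100, 128)).astype(np.float32)

# Aging: scale every stored vector toward zero (cf. test_age_memory's 0.5).
aged = memory * 0.5

# Random forgetting: drop roughly a forgetfulness_factor fraction of rows.
forgetfulness_factor = 0.9
keep = rng.random(len(aged)) >= forgetfulness_factor
remaining = aged[keep]
print(remaining.shape)  # roughly 100 * (1 - 0.9) = ~10 rows survive
```

Repeated aging drives stale embeddings toward the zero vector, which is exactly what `garbage_collect(threshold=...)` then sweeps away, so the three mechanisms compose into a complete eviction policy.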
================================================
FILE: test/memory/test_memory_types.py
================================================
# type: ignore
import numpy as np
import torch
import pytest

from src.memory.associative import AssociativeMemory


@pytest.fixture(params=["flat", "compressed", "graph"])
def index_type(request):
    return request.param


@pytest.fixture(params=["numpy", "torch"])
def tensor_type(request):
    return request.param


def to_tensor(tensor_type, array):
    if tensor_type == "numpy":
        return array
    elif tensor_type == "torch":
        return torch.from_numpy(array)
    else:
        raise ValueError("Unsupported tensor_type")


def test_associative_memory_add_remove_update_search(index_type, tensor_type):
    memory_size = 1000
    embedding_dim = 128
    num_test_vectors = 100
    k = 10

    memory = AssociativeMemory(memory_size, embedding_dim, index_type=index_type)

    # Add test embeddings to memory
    test_embeddings = to_tensor(tensor_type, np.random.random((num_test_vectors, embedding_dim)).astype(np.float32))
    memory.add(test_embeddings)

    # Search for closest embeddings in memory
    query_vectors = to_tensor(tensor_type, np.random.random((5, embedding_dim)).astype(np.float32))
    search_results, search_distances = memory.search(query_vectors, k)

    assert search_results.shape == (query_vectors.shape[0], k)

    if index_type == "flat":
        # Remove some embeddings from memory
        ids_to_remove = np.array([2, 5, 10, 30, 50])
        memory.remove(ids_to_remove)

        # Update some embeddings in memory
        ids_to_update = np.array([0, 1, 3, 4])
        updated_embeddings = to_tensor(
            tensor_type, np.random.random((len(ids_to_update), embedding_dim)).astype(np.float32)
        )

        memory.update(ids_to_update, updated_embeddings)

        # Check updated embeddings
        updated_search_results, updated_distances = memory.search(updated_embeddings, k=1)
        for i, _ in enumerate(ids_to_update):
            assert np.isclose(updated_distances[i, 0], 0, atol=1e-6)

    elif index_type in ["compressed", "graph"]:
        with pytest.raises(ValueError):
            memory.remove(np.array([0]))

        with pytest.raises(ValueError):
            memory.update(
                np.array([0]), to_tensor(tensor_type, np.random.random((1, embedding_dim)).astype(np.float32))
            )


================================================
FILE: test/memory/test_multitoken.py
================================================
# type: ignore
import numpy as np
import torch
from src.memory.multitoken_batch_associative import MultiTokenBatchAssociativeMemory


def test_batch_add():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4

    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim).astype(np.float32)
    memory.batch_add(embeddings)

    for mem in memory.memories:
        assert mem.size() == num_tokens


def test_batch_remove():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4

    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim).astype(np.float32)
    memory.batch_add(embeddings)

    ids_to_remove = np.array([[0, 1], [2, 3]])
    memory.batch_remove(ids_to_remove)

    for mem in memory.memories:
        assert mem.size() == num_tokens - 2


def test_batch_search():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4

    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = np.random.randn(num_batches, num_tokens, embedding_dim).astype(np.float32)
    memory.batch_add(embeddings)

    query_vectors = np.random.randn(num_batches, num_tokens, embedding_dim).astype(np.float32)
    k = 3
    search_results = memory.batch_search(query_vectors, k)

    for indices, distances in search_results:
        assert indices.shape == (num_tokens, k)
        assert distances.shape == (num_tokens, k)


def test_torch_tensors():
    num_batches = 2
    memory_size = 5
    embedding_dim = 3
    num_tokens = 4

    memory = MultiTokenBatchAssociativeMemory(num_batches, memory_size, embedding_dim)

    embeddings = torch.randn(num_batches, num_tokens, embedding_dim)
    memory.batch_add(embeddings)

    query_vectors = torch.randn(num_batches, num_tokens, embedding_dim)
    k = 3
    search_results = memory.batch_search(query_vectors, k)

    for indices, distances in search_results:
        assert indices.shape == (num_tokens, k)
        assert distances.shape == (num_tokens, k)


================================================
FILE: test/models/__init__.py
================================================


================================================
FILE: test/models/test_gpt2_associative.py
================================================
# type: ignore
import pytest
import torch
from src.models.gpt2_associative import VardaGPTAssociative
from transformers import GPT2Config


@pytest.fixture
def varda_gpt_associative():
    return VardaGPTAssociative(gpt2_model_name="gpt2", use_gpu=False, batch_size=2)


def test_initialization(varda_gpt_associative):
    assert isinstance(varda_gpt_associative, VardaGPTAssociative)
    assert varda_gpt_associative.device.type == "cpu"
    assert isinstance(varda_gpt_associative.gpt2_config, GPT2Config)
    assert varda_gpt_associative.num_search_results == 5
    assert varda_gpt_associative.forgetfulness_factor == 0.001


def test_forward_pass_no_memory(varda_gpt_associative):
    batch_size = 2
    sequence_length = 4
    input_dim = varda_gpt_associative.gpt2_config.n_embd

    input_vectors = torch.randn(batch_size, sequence_length, input_dim)
    logits = varda_gpt_associative.forward(input_vectors)

    assert logits.shape == (batch_size, sequence_length, varda_gpt_associative.gpt2_config.vocab_size)


def test_forward_pass_with_memory(varda_gpt_associative):
    batch_size = 2
    sequence_length = 4
    input_dim = varda_gpt_associative.gpt2_config.n_embd

    input_vectors = torch.randn(batch_size, sequence_length, input_dim)
    logits = varda_gpt_associative.forward(input_vectors)

    assert logits.shape == (batch_size, sequence_length, varda_gpt_associative.gpt2_config.vocab_size)


================================================
FILE: test/test_training.py
================================================
SYMBOL INDEX (67 symbols across 15 files)

FILE: src/data.py
  function load_wikitext2 (line 10) | def load_wikitext2() -> Tuple[DataLoader[Any], DataLoader[Any], DataLoad...

FILE: src/memory/__init__.py
  class Memory (line 8) | class Memory:
    method add (line 13) | def add(self, embeddings: npt.NDArray[Any]) -> None:
    method remove (line 16) | def remove(self, ids: npt.NDArray[Any]) -> None:
    method update (line 19) | def update(self, ids: npt.NDArray[Any], updated_embeddings: npt.NDArra...
    method search (line 23) | def search(self, query_vectors: npt.NDArray[Any], k: int = 10) -> Any:
    method refresh (line 26) | def refresh(self) -> None:

FILE: src/memory/associative.py
  class AssociativeMemory (line 11) | class AssociativeMemory(Memory):
    method __init__ (line 12) | def __init__(
    method _to_numpy (line 78) | def _to_numpy(self, data: Union[npt.NDArray[Any], torch.Tensor]) -> Un...
    method add (line 89) | def add(self, embeddings: Union[npt.NDArray[Any], torch.Tensor]) -> None:
    method remove (line 110) | def remove(self, ids: Union[npt.NDArray[Any], torch.Tensor]) -> None:
    method update (line 128) | def update(
    method search (line 149) | def search(self, query_vectors: Any, k: int = 10) -> Any:
    method age_memory (line 178) | def age_memory(self, decay_factor: float = 0.999) -> None:
    method forget_randomly (line 198) | def forget_randomly(self) -> None:
    method garbage_collect (line 213) | def garbage_collect(self, threshold: float = 1e-6) -> npt.NDArray[Any]:
    method get_all_embeddings (line 237) | def get_all_embeddings(self) -> npt.NDArray[Any]:
    method __getitem__ (line 248) | def __getitem__(self, index: int) -> Any:
    method size (line 259) | def size(self) -> int:

FILE: src/memory/batch_associative.py
  class BatchAssociativeMemory (line 16) | class BatchAssociativeMemory:
    method __init__ (line 17) | def __init__(
    method batch_add (line 36) | def batch_add(self, embeddings: Union[npt.NDArray[Any], "torch.Tensor"...
    method batch_remove (line 46) | def batch_remove(self, ids: Union[npt.NDArray[Any], "torch.Tensor"]) -...
    method batch_search (line 55) | def batch_search(

FILE: src/memory/multitoken_batch_associative.py
  class MultiTokenBatchAssociativeMemory (line 16) | class MultiTokenBatchAssociativeMemory:
    method __init__ (line 17) | def __init__(
    method batch_add (line 36) | def batch_add(self, embeddings: Union[npt.NDArray[Any], "torch.Tensor"...
    method batch_remove (line 46) | def batch_remove(self, ids: Union[npt.NDArray[Any], "torch.Tensor"]) -...
    method batch_search (line 55) | def batch_search(

FILE: src/models/gpt2_associative.py
  class VardaGPTAssociative (line 10) | class VardaGPTAssociative(nn.Module):
    method __init__ (line 11) | def __init__(
    method forward (line 69) | def forward(self, input_vectors: torch.Tensor) -> Any:

FILE: src/models/gpt2_working.py
  class VardaGPTWorking (line 10) | class VardaGPTWorking(nn.Module):
    method __init__ (line 11) | def __init__(
    method forward (line 52) | def forward(self, input_vectors: torch.Tensor, memory_input: Optional[...

FILE: src/train.py
  function train (line 19) | def train(

FILE: src/train_parallel.py
  function train (line 19) | def train(

FILE: test/memory/test_associative.py
  function embeddings (line 10) | def embeddings(request):
  function test_add_and_search (line 18) | def test_add_and_search(embeddings):
  function test_update (line 29) | def test_update(embeddings):
  function test_refresh_memory (line 44) | def test_refresh_memory(embeddings):
  function test_getitem (line 57) | def test_getitem(embeddings):

FILE: test/memory/test_associative_batch.py
  function batch_memory (line 14) | def batch_memory():
  function tensor_type (line 22) | def tensor_type(request):
  function create_tensor (line 26) | def create_tensor(tensor_type, data):
  function test_batch_add (line 33) | def test_batch_add(batch_memory, tensor_type):
  function test_batch_remove (line 48) | def test_batch_remove(batch_memory, tensor_type):
  function test_batch_search (line 67) | def test_batch_search(batch_memory, tensor_type):

FILE: test/memory/test_memory_features.py
  function embeddings (line 10) | def embeddings(request):
  function forgetful_memory (line 19) | def forgetful_memory():
  function memory (line 27) | def memory():
  function test_age_memory (line 33) | def test_age_memory(embeddings, forgetful_memory):
  function test_forget_randomly (line 46) | def test_forget_randomly(embeddings, forgetful_memory):
  function test_garbage_collect (line 62) | def test_garbage_collect(embeddings, memory):

FILE: test/memory/test_memory_types.py
  function index_type (line 10) | def index_type(request):
  function tensor_type (line 15) | def tensor_type(request):
  function to_tensor (line 19) | def to_tensor(tensor_type, array):
  function test_associative_memory_add_remove_update_search (line 28) | def test_associative_memory_add_remove_update_search(index_type, tensor_...

FILE: test/memory/test_multitoken.py
  function test_batch_add (line 7) | def test_batch_add():
  function test_batch_remove (line 22) | def test_batch_remove():
  function test_batch_search (line 40) | def test_batch_search():
  function test_torch_tensors (line 60) | def test_torch_tensors():

FILE: test/models/test_gpt2_associative.py
  function varda_gpt_associative (line 9) | def varda_gpt_associative():
  function test_initialization (line 13) | def test_initialization(varda_gpt_associative):
  function test_forward_pass_no_memory (line 21) | def test_forward_pass_no_memory(varda_gpt_associative):
  function test_forward_pass_with_memory (line 32) | def test_forward_pass_with_memory(varda_gpt_associative):

About this extraction

This page contains the full source code of the ixaxaar/VardaGPT GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 28 files (76.2 KB), approximately 19.4k tokens, and a symbol index with 67 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
