Full Code of mlx-chat/mlx-chat-app for AI

Repository: mlx-chat/mlx-chat-app
Branch: main
Commit: 20f441dc1188
Files: 66
Total size: 201.3 KB

Directory structure:
gitextract_9znslswk/

├── .github/
│   └── workflows/
│       └── lint.yml
├── .gitignore
├── .vscode/
│   └── settings.json
├── CODEOWNERS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app/
│   ├── .eslintrc.cjs
│   ├── assets/
│   │   └── icon.icns
│   ├── components.json
│   ├── dprint.json
│   ├── mac/
│   │   └── entitlements.mac.inherit.plist
│   ├── main/
│   │   ├── main.ts
│   │   ├── preload.ts
│   │   ├── renderer.d.ts
│   │   ├── splash/
│   │   │   ├── index.css
│   │   │   └── index.html
│   │   └── tsconfig.json
│   ├── next.config.js
│   ├── notarize.js
│   ├── package.json
│   ├── postcss.config.js
│   ├── src/
│   │   ├── AppProvider.tsx
│   │   ├── app/
│   │   │   ├── globals.css
│   │   │   ├── layout.tsx
│   │   │   ├── page.tsx
│   │   │   └── settings/
│   │   │       └── page.tsx
│   │   ├── components/
│   │   │   ├── chat/
│   │   │   │   ├── Chat.tsx
│   │   │   │   ├── ChatInput.tsx
│   │   │   │   ├── ChatMessage.tsx
│   │   │   │   ├── ChatMessages.tsx
│   │   │   │   └── SystemMessage.tsx
│   │   │   ├── options/
│   │   │   │   ├── SelectDirectory.tsx
│   │   │   │   └── SelectModel.tsx
│   │   │   └── ui/
│   │   │       ├── button.tsx
│   │   │       ├── input.tsx
│   │   │       ├── resizable.tsx
│   │   │       ├── select.tsx
│   │   │       ├── textarea.tsx
│   │   │       └── tooltip.tsx
│   │   ├── constants/
│   │   │   └── chat.tsx
│   │   └── lib/
│   │       ├── hooks.ts
│   │       ├── store.ts
│   │       └── utils.ts
│   ├── tailwind.config.main.js
│   ├── tailwind.config.ts
│   └── tsconfig.json
├── runner.py
├── runner.sh
└── server/
    ├── __init__.py
    ├── convert.py
    ├── models/
    │   ├── __init__.py
    │   ├── base.py
    │   ├── bert.py
    │   ├── gemma.py
    │   ├── layers.py
    │   └── llama.py
    ├── py.typed
    ├── requirements.txt
    ├── retriever/
    │   ├── document.py
    │   ├── embeddings.py
    │   ├── loader.py
    │   ├── splitter.py
    │   └── vectorstore.py
    ├── server.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/lint.yml
================================================
name: Lint

on: [pull_request]

jobs:
  lint:
    name: Lint
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v1
        with:
          fetch-depth: 1

      - name: Use Node.js 16
        uses: actions/setup-node@v1
        with:
          node-version: 16

      - name: Install App Deps
        run: npm i --ignore-scripts
        working-directory: ./app
      - name: Lint App
        working-directory: ./app
        run: npm run lint


================================================
FILE: .gitignore
================================================
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
node_modules/
.pnp/
.pnp.js

# testing
coverage/

# next.js
.next/
out/

# production
build/
app/main/tailwind.css
dist/

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env*.local

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
server/lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: .vscode/settings.json
================================================
{
    "[python]": {
      "editor.tabSize": 4,
      "editor.defaultFormatter": "ms-python.autopep8",
    },
    // "It is recommended that either ESLint or cspell checks a file, but not both."
    // https://www.npmjs.com/package/@cspell/eslint-plugin
    "cSpell.enableFiletypes": [
      "!javascript",
      "!typescript",
    ],
    "css.format.spaceAroundSelectorSeparator": true,
    "editor.codeActionsOnSave": [
      "source.fixAll.eslint",
    ],
    "eslint.codeActionsOnSave.mode": "problems",
    "eslint.options": {
      "reportUnusedDisableDirectives": "error",
    },
    "eslint.rules.customizations": [
      { "rule": "*", "severity": "warn" },
    ],
    "editor.defaultFormatter": "dprint.dprint",
    "editor.formatOnSave": true,
    "editor.tabSize": 2,
    "editor.wordWrapColumn": 100,
    "eslint.workingDirectories": [
      "./app",
    ],
    "files.insertFinalNewline": true,
    "files.trimFinalNewlines": true,
    "git.allowForcePush": true,
    "git.inputValidationSubjectLength": 100,
    "git.inputValidationLength": 100,
    "javascript.preferences.quoteStyle": "single",
    "typescript.preferences.quoteStyle": "single",
    "scss.format.spaceAroundSelectorSeparator": true,
    "typescript.tsdk": "node_modules/typescript/lib",
  }
  

================================================
FILE: CODEOWNERS
================================================
*       @parkersm1th @stockeh


================================================
FILE: CONTRIBUTING.md
================================================
# Welcome to our contribution guide

Thank you for wanting to contribute to our project! We appreciate any contributions that you make.

Chat with MLX is an open source project and we love to receive contributions from our community — you! There are many ways to contribute, from writing tutorials or blog posts, improving the documentation, and submitting bug reports and feature requests, to writing code that can be incorporated into the application itself.

## New Contributor Guide

To get an overview of the project, read the [README](https://github.com/mlx-chat/mlx-chat-app/blob/main/README.md). Here are some resources to help you get started with open source contributions:

- [Ways to contribute on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github)
- [Setup Git](https://docs.github.com/en/get-started/quickstart/set-up-git)
- [GitHub workflow](https://docs.github.com/en/get-started/quickstart/github-flow)
- [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests)

## Getting Started

### Issues

**Create**: If you spot a problem, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). If a related issue doesn't exist, you can open a new issue!

**Solve**: Scan through our [existing issues](https://github.com/mlx-chat/mlx-chat-app/issues) to find one that interests you. If you find an issue to work on, you are welcome to assign it to yourself and open a PR with a fix.

### Make Changes

1. Create your own fork of the code
2. Create a working branch and start with your changes
3. Commit and send a pull request 

### Pull Request

When you're finished with the changes, create a pull request.
- Check that your pull request passes our continuous integration (CI) checks. If you cannot get a certain integration test to pass, let us know. We can assist you in fixing these issues or approve a merge manually.
- Make sure your additions are properly documented!
- Don't forget to [link PR to issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) if you are solving one.
- We may ask for changes to be made before a PR can be merged, either using [suggested changes](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/incorporating-feedback-in-your-pull-request) or pull request comments. You can apply suggested changes directly through the UI. You can make any other changes in your fork, then commit them to your branch.
- As you update your PR and apply changes, mark each conversation as [resolved](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/commenting-on-a-pull-request#resolving-conversations).
- If you run into any merge issues, check out this [git tutorial](https://github.com/skills/resolve-merge-conflicts) to help you resolve merge conflicts and other issues.

### Your PR is Merged!

Congratulations :tada::tada: we thank you! :sparkles:

Once your PR is merged, your contributions will be publicly visible in the [Chat with MLX Repository](https://github.com/mlx-chat/mlx-chat-app).


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 MLX Chat

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================

![](docs/design-logo-light.png#gh-light-mode-only)
![](docs/design-logo-dark.png#gh-dark-mode-only)


**Chat with MLX** is a high-performance macOS application that connects your local documents to a personalized large language model (LLM). By leveraging retrieval-augmented generation (RAG), open source LLMs, and MLX for accelerated machine learning on Apple silicon, you can efficiently search, query, and interact with your documents *without information ever leaving your device.*

Our high-level features include:
- **Query**: load and search with document-specific prompts
- **Converse**: switch model interaction modes (converse vs. assist) in real time
- **Instruct**: provide personalization and response tuning

## Installation and Setup

:warning: **Preliminary Steps**: we are working toward a release with correct packaging ([pyinstaller](https://github.com/pyinstaller/pyinstaller/) & [electron-builder](https://github.com/electron-userland/electron-builder)) and authentication ([Apple codesign](https://developer.apple.com/support/code-signing/)). In the interim, please clone and run in development by first setting up authentication and requirements.

First, set up Hugging Face [access tokens](https://huggingface.co/settings/tokens) to download models (request access to [google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)), then
```bash
huggingface-cli login
```
Then install the npm and Python requirements
```bash
cd app && npm install
pip install -r server/requirements.txt
```
Finally, start the application
```bash
cd app && npm run dev
```
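For debugging, the Python backend can presumably also be run on its own, mirroring how the Electron main process spawns it in `app/main/main.ts` (the `--host`/`--port` flags below are copied from that launcher; in the app the port is probed automatically starting at 8080, so the fixed port here is an assumption for illustration):
```bash
# Start the backend without Electron (run from the repository root).
# Flags mirror the spawn call in app/main/main.ts; adjust the port as needed.
python -m server.server --host 127.0.0.1 --port 8080
```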

## Contributions
All contributions are welcome. Please take a look at our [contributing guide](CONTRIBUTING.md).


================================================
FILE: app/.eslintrc.cjs
================================================
const namingConventions = [
  'error',
  {
    format: ['camelCase'],
    selector: 'default',
  },
  {
    format: ['camelCase', 'UPPER_CASE'],
    selector: 'variable',
  },
  {
    format: ['camelCase', 'UPPER_CASE', 'PascalCase'],
    modifiers: ['const', 'exported', 'global'],
    selector: 'variable',
  },
  {
    format: ['camelCase'],
    leadingUnderscore: 'allow',
    selector: 'parameter',
  },
  {
    format: ['camelCase'],
    leadingUnderscore: 'allow',
    modifiers: ['private'],
    selector: 'memberLike',
  },
  {
    format: ['PascalCase', 'UPPER_CASE'],
    selector: ['enum', 'enumMember'],
  },
  {
    format: ['PascalCase'],
    selector: 'typeLike',
  },
  {
    format: null,
    modifiers: ['destructured'],
    selector: 'variable',
  },
  {
    format: null,
    modifiers: ['requiresQuotes'],
    selector: [
      'classProperty',
      'objectLiteralProperty',
      'typeProperty',
      'classMethod',
      'objectLiteralMethod',
      'typeMethod',
      'accessor',
      'enumMember',
    ],
  },
  {
    format: ['camelCase', 'PascalCase', 'UPPER_CASE'],
    leadingUnderscore: 'allow',
    selector: 'import',
  },
];

const tsxNamingConventions = [
  {
    format: ['camelCase', 'PascalCase', 'UPPER_CASE'],
    leadingUnderscore: 'forbid',
    modifiers: ['global'],
    selector: ['variable', 'function'],
  },
];

module.exports = {
  env: {
    es2020: true,
  },
  extends: [
    'airbnb-base',
    'plugin:jsdoc/recommended',
    'plugin:@typescript-eslint/recommended',
    'plugin:import/typescript',
    'plugin:no-unsanitized/DOM',
  ],
  ignorePatterns: [
    'node_modules',
    'main',
    '.eslintrc.*',
    'out',
  ],
  overrides: [
    // Config files
    {
      files: [
        'common/**/*.ts*',
        '**/app*config.ts',
        '**/app*Config.ts',
      ],
      rules: {
        '@typescript-eslint/member-ordering': ['error', { default: { order: 'alphabetically' } }],
        'sort-keys': ['error', 'asc', { minKeys: 2, natural: true }],
      },
    },
    {
      files: [
        '*.ts*',
      ],
      rules: {
        '@typescript-eslint/no-shadow': 'error',
        '@typescript-eslint/no-unused-vars': 'off', // Using unused-imports plugin instead
        '@typescript-eslint/space-before-function-paren': [
          'error',
          {
            anonymous: 'never',
            asyncArrow: 'always',
            named: 'never',
          },
        ],
        'no-redeclare': 'off', // @typescript-eslint/no-redeclare is enabled and is more correct
        'no-shadow': 'off', // @typescript-eslint/no-shadow is enabled and is more correct
        'no-undef-init': 'off',
        'no-unused-vars': 'off', // Using unused-imports plugin instead
        'space-before-function-paren': 'off', // Using @typescript-eslint/space-before-function-paren instead
        'unused-imports/no-unused-imports': 'error',
        'unused-imports/no-unused-vars': ['error', {
          args: 'after-used',
          argsIgnorePattern: '^_',
          destructuredArrayIgnorePattern: '^_',
          ignoreRestSiblings: true,
        }],
      },
    },
    {
      files: [
        'src/**/*.ts*',
      ],
      rules: {
        '@typescript-eslint/await-thenable': 'error',
        '@typescript-eslint/dot-notation': ['error', { allowIndexSignaturePropertyAccess: true }],
        '@typescript-eslint/no-base-to-string': ['error', {
          ignoredTypeNames: ['Error', 'RegExp'],
        }],
        '@typescript-eslint/no-floating-promises': 'error',
        '@typescript-eslint/no-for-in-array': 'error',
        '@typescript-eslint/no-misused-promises': ['error', { checksVoidReturn: false }],
        '@typescript-eslint/no-throw-literal': 'error',
        '@typescript-eslint/no-unnecessary-condition': 'error',
        '@typescript-eslint/no-unnecessary-type-assertion': 'error',
        '@typescript-eslint/non-nullable-type-assertion-style': 'error',
        '@typescript-eslint/prefer-includes': 'error',
        '@typescript-eslint/prefer-optional-chain': 'error',
        '@typescript-eslint/prefer-string-starts-ends-with': 'error',
        '@typescript-eslint/require-await': 'error',
        '@typescript-eslint/space-infix-ops': 'error',
        'dot-notation': 'off',
        'no-throw-literal': 'off',
        'require-await': 'off',
        'space-infix-ops': 'off',
      },
    },
    {
      files: [
        '*.tsx',
      ],
      rules: {
        '@typescript-eslint/naming-convention': [
          ...namingConventions,
          ...tsxNamingConventions,
        ],
        '@typescript-eslint/require-await': 'error',
        'require-await': 'off',
      },
    },
  ],
  parser: '@typescript-eslint/parser',
  parserOptions: {
    project: 'tsconfig.json',
  },
  plugins: [
    'react',
    '@typescript-eslint',
    'jest-formatting',
    'modules-newlines',
    'unused-imports',
  ],
  root: true,
  rules: {
    '@typescript-eslint/ban-types': [
      'error',
      {
        extendDefaults: true,
        types: {
          object: {
            message: [
              'The `object` type is currently hard to use ([see this issue](https://github.com/microsoft/TypeScript/issues/21732)).',
              'Consider using `Record<string, unknown>` instead, as it allows you to more easily inspect and use the keys.',
            ].join('\n'),
          },
        },
      },
    ],
    'implicit-arrow-linebreak': 'off',
    '@typescript-eslint/consistent-type-assertions': ['error', { assertionStyle: 'never' }],
    '@typescript-eslint/consistent-type-imports': 'error',
    '@typescript-eslint/init-declarations': 'error',
    '@typescript-eslint/member-ordering': 'error',
    '@typescript-eslint/naming-convention': namingConventions,
    '@typescript-eslint/no-explicit-any': 'error',
    '@typescript-eslint/no-non-null-asserted-nullish-coalescing': 'error',
    '@typescript-eslint/no-use-before-define': ['error', {
      functions: false,
    }],
    '@typescript-eslint/prefer-for-of': 'error',
    '@typescript-eslint/type-annotation-spacing': 'error',
    'array-element-newline': ['error', 'consistent'],
    'block-spacing': 'off',
    camelcase: 'off', // Using @typescript-eslint/naming-convention instead.
    'comma-dangle': 'off',
    'default-param-last': 'off',
    'import/extensions': 'off',
    'import/no-relative-packages': 'off',
    'import/order': 'off',
    'import/prefer-default-export': 'off',
    'jsdoc/check-indentation': ['error', { excludeTags: ['description', 'example'] }],
    'jsdoc/check-line-alignment': 'error',
    'jsdoc/check-tag-names': ['error', {
      definedTags: ['jest-environment', 'jest-environment-options'],
    }],
    'jsdoc/no-bad-blocks': 'error',
    'jsdoc/no-multi-asterisks': 'off',
    'jsdoc/no-undefined-types': 'off',
    'jsdoc/require-jsdoc': 'off',
    'jsdoc/require-param': 'off',
    'jsdoc/require-param-description': 'off',
    'jsdoc/require-param-name': 'off',
    'jsdoc/require-param-type': 'off',
    'jsdoc/require-property': 'off',
    'jsdoc/require-property-description': 'off',
    'jsdoc/require-property-name': 'off',
    'jsdoc/require-property-type': 'off',
    'jsdoc/require-returns': 'off',
    'jsdoc/require-returns-description': 'off',
    'jsdoc/require-returns-type': 'off',
    'jsdoc/require-yields': 'off',
    'jsdoc/require-yields-check': 'off',
    'jsdoc/tag-lines': ['error', 'any', { startLines: 1 }],
    'max-classes-per-file': 'off',
    'max-len': ['error', {
      code: 100,
      ignorePattern: '(/* eslint |eslint-disable-next-line |@ts-expect-error )',
      ignoreRegExpLiterals: true,
      ignoreStrings: true,
      ignoreTemplateLiterals: true,
      ignoreUrls: true,
    }],
    'max-params': ['error', 3],
    'new-cap': [
      'error',
      {
        capIsNew: true,
        capIsNewExceptions: [
          'express.Router',
          'Immutable.Map',
          'Immutable.Set',
          'Immutable.List',
          'RightRailView',
          'URLWithSearchParams',
        ],
        newIsCap: true,
        newIsCapExceptions: [],
        properties: true,
      },
    ],
    'no-console': 'error',
    'no-continue': 'off',
    'no-empty-function': 'off',
    'no-promise-executor-return': 'off',
    'no-redeclare': 'error',
    'no-restricted-properties': [
      'error',
    ],
    'no-restricted-syntax': [
      'error',
    ],
    'no-use-before-define': 'off',
    'no-void': ['error', { allowAsStatement: true }],
    'padding-line-between-statements': [
      'error',
      { blankLine: 'never', next: 'import', prev: 'import' },
    ],
    'prefer-arrow-callback': ['error', { allowNamedFunctions: true }],
    'prefer-exponentiation-operator': 'off',
    'prefer-regex-literals': 'off',
  },
  settings: {
    'import/typescript': {
      typescript: {},
    },
  },
};


================================================
FILE: app/components.json
================================================
{
  "$schema": "https://ui.shadcn.com/schema.json",
  "style": "new-york",
  "rsc": true,
  "tsx": true,
  "tailwind": {
    "config": "tailwind.config.ts",
    "css": "main/splash/index.css",
    "baseColor": "slate",
    "cssVariables": true,
    "prefix": ""
  },
  "aliases": {
    "components": "@/components",
    "utils": "@/lib/utils"
  }
}

================================================
FILE: app/dprint.json
================================================
{
    "lineWidth": 100,
    "typescript": {
      "indentWidth": 2,
      "quoteStyle": "alwaysSingle",
      "semiColons": "always",
      "quoteProps": "asNeeded",
      "useBraces": "always",
      "trailingCommas": "onlyMultiLine",
      "module.sortImportDeclarations": "caseInsensitive",
      "exportDeclaration.forceMultiLine": true,
      "importDeclaration.forceMultiLine": true
    },
    "json": {
      "jsonTrailingCommaFiles": [
        ".vscode/launch.json",
        ".vscode/extensions.json",
        ".vscode/settings.json",
        ".vscode/tasks.json",
        "tsconfig.json"
      ]
    },
    "excludes": [
      "**/node_modules",
      "**/*-lock.json",
      "**/Dockerfile",
      "**/src/ui-tests/fixtures/**/*",
      "**/storybook-static/**/*",
      "**/build/**/*",
      "**/dist/**/*",
      "**/artifacts/**/*",
      "extension/src/assets/**/*.json"
    ],
    "plugins": [
      "https://plugins.dprint.dev/typescript-0.88.3.wasm",
      "https://plugins.dprint.dev/json-0.19.0.wasm",
      "https://plugins.dprint.dev/dockerfile-0.3.0.wasm"
    ]
  }
  

================================================
FILE: app/mac/entitlements.mac.inherit.plist
================================================
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
  <dict>
    <key>com.apple.security.cs.allow-jit</key>
    <true/>
    <key>com.apple.security.cs.allow-unsigned-executable-memory</key>
    <true/>
    <key>com.apple.security.cs.disable-library-validation</key>
    <true/>
  </dict>
</plist>


================================================
FILE: app/main/main.ts
================================================
// Main File for Electron

import {
  exec,
  execFile,
} from 'child_process';
import {
  app,
  BrowserWindow,
  dialog,
  globalShortcut,
  ipcMain,
  Menu,
  nativeImage,
  Tray,
} from 'electron';
import * as contextMenu from 'electron-context-menu';
import Store from 'electron-store';
import * as net from 'net';

const path = require('path');
const serve = require('electron-serve');
const { spawn } = require('child_process');

function handleSetTitle(event: any, title: string) {
  const webContents = event.sender;
  const win = BrowserWindow.fromWebContents(webContents);
  if (win !== null) {
    win.setTitle(title);
  }
}

// Python Server
class ServerManager {
  private serverProcess: any | null = null;
  public port: number | null = null;

  private findOpenPort(startingPort: number): Promise<number> {
    return new Promise<number>((resolve) => {
      const server = net.createServer();

      server.listen(startingPort, () => {
        const port = (server.address() as net.AddressInfo).port;
        server.close(() => resolve(port));
      });

      server.on(
        'error',
        (err: any) => err.code === 'EADDRINUSE' && resolve(this.findOpenPort(startingPort + 1)),
      );
    });
  }

  private runPythonServer(port: number): any {
    const args = ['--host 127.0.0.1', `--port ${port}`];
    const modifiedArgs = args.flatMap(arg => arg.split(/\s+/));
    const pythonProcess = isProd
      ? execFile(path.join(process.resourcesPath, 'server', 'runner'), modifiedArgs)
      : spawn('python', ['-m', 'server.server', ...modifiedArgs], {
        cwd: '../',
      });
    pythonProcess.stdout.on(
      'data',
      (data: Buffer) => console.log('Server output:', data.toString('utf8')),
    );
    pythonProcess.stderr.on(
      'data',
      (data: Buffer) => console.log(`Server error: ${data.toString('utf8')}`),
    );

    return pythonProcess;
  }

  start(model: string): Promise<void> {
    return new Promise<void>((resolve, reject) => {
      this.stop();

      this.findOpenPort(8080).then((port) => {
        this.port = port;
        console.log(`APP: Starting server for model: ${model} on port: ${port}`);
        this.serverProcess = this.runPythonServer(port);

        this.serverProcess.stdout.on('data', async (data: Buffer) => {
          const output = data.toString('utf8');

          await new Promise((resolve) => setTimeout(resolve, 1000));

          // Check if the server is ready
          if (output.includes('Starting httpd')) {
            fetch(`http://127.0.0.1:${port}/api/init`, {
              method: 'POST',
              headers: {
                'Content-Type': 'application/json',
              },
              body: JSON.stringify({ model }),
            }).then(() => {
              resolve(); // Resolve the promise when the server is ready
            }).catch((err) => {
              console.error('Error initializing the server:', err);
              reject(err);
            });
          }
        });

        this.serverProcess.on('close', (code: number | null) => {
          console.log(`Server process exited with code ${code}`);
          this.serverProcess = null;
        });

        this.serverProcess.on('error', (err: any) => {
          console.error(`Error in server process: ${err}`);
          this.serverProcess = null;
          reject(err);
        });
      });
    });
  }

  stop(): void {
    if (this.serverProcess) {
      console.log('Stopping the server...');
      this.serverProcess.kill();
      this.serverProcess = null;
    }
  }
}

// Loading Screen
let splash: BrowserWindow | null;
const createSplashScreen = () => {
  /// create a browser window
  splash = new BrowserWindow(
    {
      width: 200,
      height: 100,
      focusable: false,
      /// remove the window frame, so it will become a frameless window
      frame: false,
      skipTaskbar: true,
      autoHideMenuBar: true,
    },
  );
  splash.setResizable(false);
  splash.loadURL(`file://${__dirname}/../splash/index.html`);
  splash.on('closed', () => (splash = null));
  splash.webContents.on('did-finish-load', () => {
    if (splash) {
      splash.show();
    }
  });
};

// run renderer
const isProd = process.env.NODE_ENV !== 'development';
if (isProd) {
  serve({ directory: 'out' });
} else {
  app.setPath('userData', `${app.getPath('userData')} (development)`);
}

contextMenu.default({
  showInspectElement: !isProd,
});

let openModal: 'settings' | 'directory' | null = null;

let globalWindow: BrowserWindow | null = null;

const triggerShortcut = () => {
  if (openModal || !globalWindow) {
    return;
  }
  if (globalWindow.isFocused()) {
    globalWindow.blur();
    return;
  }
  globalWindow.show();
};

const store = new Store({
  schema: {
    keybind: {
      type: 'string',
      default: 'Cmd+O',
    },
    model: {
      type: 'string',
      default: 'mistralai/Mistral-7B-Instruct-v0.2',
    },
    personalization: {
      type: 'string',
      default: '',
    },
    customResponse: {
      type: 'string',
      default: '',
    },
  },
});

const serverManager = new ServerManager();

const createWindow = () => {
  const icon = nativeImage.createFromPath(
    !isProd
      ? '../assets/IconTemplate.png'
      : path.join(process.resourcesPath, 'IconTemplate.png'),
  );
  // if you want to resize it, be careful, it creates a copy
  const trayIcon = icon.resize({ width: 16 });
  // here is the important part (has to be set on the resized version)
  trayIcon.setTemplateImage(true);
  let tray = new Tray(trayIcon);
  tray.setTitle(isProd ? '' : 'M');

  const win = new BrowserWindow({
    webPreferences: {
      preload: path.join(__dirname, 'preload.js'),
      devTools: !isProd,
    },
    show: false,
    width: 600,
    height: 99,
    resizable: false,
    type: 'panel',
    frame: false,
    skipTaskbar: true,
    autoHideMenuBar: true,
    vibrancy: 'under-window', // on MacOS
    backgroundMaterial: 'acrylic',
    icon: path.join(__dirname, '../../assets/public/icon.icns'),
  });
  globalWindow = win;
  win.setWindowButtonVisibility(false);
  win.setAlwaysOnTop(true, 'floating');
  win.setVisibleOnAllWorkspaces(true);

  // Expose URL
  if (isProd) {
    win.loadURL('app://./home.html');
  } else {
    // const port = process.argv[2];
    win.loadURL('http://localhost:3000/');
  }

  tray.addListener('click', () => {
    if (win.isFocused()) {
      win.blur();
      return;
    }
    win.show();
  });

  win.webContents.on('did-finish-load', async () => {
    await serverManager.start(store.get('model') as string);
    /// then close the loading screen window and show the main window
    if (splash) {
      splash.close();
    }
    app.dock.hide();
    win.show();
    globalShortcut.register(store.get('keybind') as string, triggerShortcut.bind(null));
  });

  // @ts-expect-error -- We don't have types for electron
  win.on('blur', (event) => {
    if (openModal) {
      win.setAlwaysOnTop(false);
    }
    if (openModal === 'directory') {
      return;
    }
    if (win.webContents.isDevToolsOpened()) {
      return;
    }
    globalShortcut.unregister('Escape');
    globalShortcut.unregister('Cmd+Q');
    win.hide();
    if (openModal) {
      return;
    }

    Menu.sendActionToFirstResponder('hide:');
  });

  win.on('focus', () => {
    globalShortcut.register('Cmd+Q', () => {
      if (!win.isFocused()) {
        return;
      }
      app.quit();
    });
    globalShortcut.register('Escape', () => {
      if (!win.isFocused()) {
        return;
      }
      win.blur();
    });
  });

  let settingsModal: BrowserWindow | null = null;

  const createSettings = () => {
    settingsModal = new BrowserWindow({
      webPreferences: {
        preload: path.join(__dirname, 'preload.js'),
      },
      width: 500,
      height: 500,
      resizable: false,
      minimizable: false,
      titleBarStyle: 'hidden',
      show: false,
      backgroundColor: '#000',
    });

    if (isProd) {
      settingsModal.loadURL('app://./settings.html');
    } else {
      // const port = process.argv[2];
      settingsModal.loadURL('http://localhost:3000/settings');
    }

    settingsModal.on('closed', () => {
      openModal = null;
      settingsModal?.destroy();
      settingsModal = null;
    });

    settingsModal.on('ready-to-show', () => {
      settingsModal?.show();
    });

    return settingsModal;
  };

  const nativeMenus: (Electron.MenuItemConstructorOptions | Electron.MenuItem)[] = [
    {
      label: 'MLX Chat',
      submenu: [
        {
          label: 'Settings',
          click() {
            openModal = 'settings';
            if (settingsModal !== null) {
              settingsModal.close();
            }
            createSettings();
          },
          accelerator: 'Cmd+,',
        },
      ],
    },
    {
      label: 'Edit',
      submenu: [
        { role: 'undo' },
        { role: 'redo' },
        { type: 'separator' },
        { role: 'cut' },
        { role: 'copy' },
        { role: 'paste' },
        { role: 'pasteAndMatchStyle' },
        { role: 'delete' },
        { role: 'selectAll' },
        { type: 'separator' },
        {
          label: 'Speech',
          submenu: [
            { role: 'startSpeaking' },
            { role: 'stopSpeaking' },
          ],
        },
      ],
    },
  ];

  const menu = Menu.buildFromTemplate(nativeMenus);
  Menu.setApplicationMenu(menu);
};

app.whenReady().then(() => {
  ipcMain.on('set-title', handleSetTitle);
  ipcMain.on('select-directory', (event: any) => {
    openModal = 'directory';
    dialog.showOpenDialog({ properties: ['openDirectory'] }).then((result: any) => {
      const win = BrowserWindow.fromWebContents(event.sender);
      // Re-assert always-on-top, which was relaxed so the dialog could appear in front
      win?.setAlwaysOnTop(true, 'floating');

      openModal = null;
      event.sender.send('selected-directory', result.filePaths);
    });
  });

  ipcMain.on('resize-window', (event, arg) => {
    const win = BrowserWindow.fromWebContents(event.sender);
    if (!win) {
      return;
    }
    win.setBounds({
      height: arg.height,
    });
    win.center();
  });

  ipcMain.on('fetch-setting', (event, arg) => {
    event.returnValue = store.get(arg);
  });

  ipcMain.on('update-setting', (_event, arg) => {
    if (arg.key === 'keybind') {
      globalShortcut.unregister(store.get('keybind') as string);
      globalShortcut.register(arg.value, triggerShortcut.bind(null));
    }
    store.set(arg.key, arg.value);
  });

  createSplashScreen();

  setTimeout(() => {
    createWindow();
  }, 500);

  app.on('activate', () => {
    if (BrowserWindow.getAllWindows().length === 0) { createWindow(); }
  });
});

app.on('will-quit', () => {
  // The Python server runs outside Electron's process tree, so kill whatever
  // is still listening on its port before quitting.
  exec(
    `lsof -i :${serverManager.port} -P | awk 'NR>1 {print $2}' | xargs kill`,
    (err, stdout, stderr) => {
      if (err) {
        console.log(err);
        return;
      }
      console.log(`stdout: ${stdout}`);
      console.log(`stderr: ${stderr}`);
    },
  );
  BrowserWindow.getAllWindows().forEach((win) => {
    win.close();
    win.destroy();
  });
});


================================================
FILE: app/main/preload.ts
================================================
// eslint-disable-next-line import/no-extraneous-dependencies
import {
  contextBridge,
  ipcRenderer,
} from 'electron';

export const electronAPI = {
  setTitle: (title: string) => ipcRenderer.send('set-title', title),
  selectDirectory: () => ipcRenderer.send('select-directory'),
  onSelectDirectory: (cb: (customData: string[]) => void) => {
    ipcRenderer.on('selected-directory', (event, customData) => {
      // eslint-disable-next-line no-console
      console.log(event);
      cb(customData);
    });
  },
  resizeWindow: (height: number) => ipcRenderer.send('resize-window', { height }),
  fetchSetting: (key: string) => ipcRenderer.sendSync('fetch-setting', key),
  updateSetting: (key: string, value: any) => ipcRenderer.send('update-setting', { key, value }),
};

contextBridge.exposeInMainWorld('electronAPI', electronAPI);


================================================
FILE: app/main/renderer.d.ts
================================================
import { electronAPI } from "./preload";

declare global {
  interface Window {
    electronAPI: typeof electronAPI;
  }
}

export {};


================================================
FILE: app/main/splash/index.css
================================================
@tailwind base;

@tailwind components;

@tailwind utilities;

div {
  -webkit-user-select: none;
  -webkit-app-region: drag;
}

.loading-bar {
  display: block;
  height: 0.2em;
  background-color: rgba(255, 255, 255, 0.2);
  position: relative;
  overflow: hidden;
  border-radius: 1rem;
}

.loading-bar:before {
  content: "";
  display: block;
  position: absolute;
  left: -100%;
  width: 100%;
  height: 100%;
  background-color: white;
  animation: loading-bar 1.5s ease-in-out infinite;
}

@keyframes loading-bar {
  from {
    left: -100%;
  }
  to {
    left: 100%;
  }
}


================================================
FILE: app/main/splash/index.html
================================================
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <title>FLOATING LOADING SCREEN</title>
    <link rel="stylesheet" href="../tailwind.css" />
    <link rel="stylesheet" href="./index.css" />
  </head>
  <body>
    <div
      class="h-screen w-screen fixed z-50 flex flex-col items-center justify-center text-white text-4xl gap-8 bg-slate-600"
    >
      <div class="loading-bar w-3/4 md:w-1/2 lg:w-1/3"></div>
    </div>
  </body>
</html>


================================================
FILE: app/main/tsconfig.json
================================================
{
  "compilerOptions": {
    "allowJs": true,
    "alwaysStrict": true,
    "esModuleInterop": true,
    "forceConsistentCasingInFileNames": true,
    "isolatedModules": true,
    "jsx": "preserve",
    "lib": ["dom", "es2017"],
    "module": "commonjs",
    "moduleResolution": "node",
    "noEmit": false,
    "noFallthroughCasesInSwitch": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "resolveJsonModule": true,
    "skipLibCheck": true,
    "strict": true,
    "target": "esnext",
    "outDir": "./out",
  },
  "compileOnSave": true,
  "exclude": ["node_modules", "./out/**/*"],
  "include": ["**/*.ts", "**/*.tsx", "**/*.js", "public/**.icns"],
}


================================================
FILE: app/next.config.js
================================================
/** @type {import('next').NextConfig} */
const nextConfig = {
  output: "export",
  distDir: "out",
};

module.exports = nextConfig;


================================================
FILE: app/notarize.js
================================================
require('dotenv').config();
const { notarize } = require('electron-notarize');

exports.default = async function notarizing(context) {
  const { electronPlatformName, appOutDir } = context;
  if (electronPlatformName !== 'darwin') {
    return;
  }

  const appName = context.packager.appInfo.productFilename;

  return await notarize({
    appBundleId: 'com.parkersmith.mlx-chat',
    appPath: `${appOutDir}/${appName}.app`,
    appleId: process.env.APPLEID,
    appleIdPassword: process.env.APPLEIDPASS,
  });
};


================================================
FILE: app/package.json
================================================
{
  "name": "electron-app",
  "productName": "Electron App",
  "version": "0.1.0",
  "private": true,
  "main": "main/out/main.js",
  "homepage": "./",
  "description": "My Next.js project",
  "author": "test",
  "scripts": {
    "dev": "cross-env NODE_ENV=development concurrently -k \"cross-env BROWSER=none npm run next:dev\" \"npm run electron:dev\"",
    "build": "npm run build:main && next build",
    "start": "cross-env npm run electron",
    "build:tailwindMain": "npx tailwindcss build --config tailwind.config.main.js -o ./main/tailwind.css",
    "build:main": "tsc -p main && npm run build:tailwindMain",
    "next:dev": "next dev",
    "next:start": "next start",
    "next:lint": "next lint",
    "electron:dev": "npm run build:main && wait-on tcp:3000 && electron .",
    "electron": "electron .",
    "pack": "npm run build && electron-builder --dir",
    "dist": "npm run build && electron-builder",
    "lint": "npx eslint --max-warnings 0 --ext=.ts ."
  },
  "dependencies": {
    "@electron/osx-sign": "^1.0.5",
    "@fortawesome/fontawesome-free": "^6.5.1",
    "@fortawesome/fontawesome-svg-core": "^6.5.1",
    "@fortawesome/free-regular-svg-icons": "^6.5.1",
    "@fortawesome/free-solid-svg-icons": "^6.5.1",
    "@fortawesome/react-fontawesome": "^0.2.0",
    "@matejmazur/react-katex": "^3.1.3",
    "@radix-ui/react-icons": "^1.3.0",
    "@radix-ui/react-select": "^2.0.0",
    "@radix-ui/react-slot": "^1.0.2",
    "@radix-ui/react-tooltip": "^1.0.7",
    "@reduxjs/toolkit": "^2.2.1",
    "@types/electron": "^1.6.10",
    "@types/node": "^20.6.0",
    "@types/react": "^18.2.21",
    "@types/react-dom": "^18.2.7",
    "autoprefixer": "^10.4.15",
    "class-variance-authority": "^0.7.0",
    "clsx": "^2.1.0",
    "concurrently": "^8.2.1",
    "cross-env": "^7.0.3",
    "dprint": "^0.45.0",
    "electron-context-menu": "^3.6.1",
    "electron-serve": "^1.1.0",
    "electron-squirrel-startup": "^1.0.0",
    "electron-store": "^8.1.0",
    "eslint": "8.41.0",
    "eslint-config-next": "13.4.3",
    "markdown-to-jsx": "^7.4.1",
    "next": "13.4.3",
    "postcss": "^8.4.29",
    "react": "18.2.0",
    "react-dom": "18.2.0",
    "react-redux": "^9.1.0",
    "react-resizable-panels": "^2.0.11",
    "rxjs": "^7.8.1",
    "tailwind-merge": "^2.2.1",
    "tailwindcss": "^3.3.3",
    "tailwindcss-animate": "^1.0.7",
    "wait-on": "^7.0.1"
  },
  "devDependencies": {
    "@typescript-eslint/eslint-plugin": "^6.21.0",
    "@typescript-eslint/parser": "^6.21.0",
    "dotenv": "^16.4.5",
    "dprint": "^0.45.0",
    "electron": "^26.2.0",
    "electron-builder": "^24.6.4",
    "electron-notarize": "^1.2.2",
    "eslint": "^8.56.0",
    "eslint-config-airbnb-base": "^15.0.0",
    "eslint-plugin-compat": "^4.2.0",
    "eslint-plugin-jest": "^27.6.3",
    "eslint-plugin-jest-formatting": "^3.1.0",
    "eslint-plugin-jsdoc": "^48.0.6",
    "eslint-plugin-jsx-a11y": "^6.8.0",
    "eslint-plugin-justinanastos": "^1.3.1",
    "eslint-plugin-modules-newlines": "^0.0.7",
    "eslint-plugin-no-unsanitized": "^4.0.2",
    "eslint-plugin-react": "^7.33.2",
    "eslint-plugin-react-hooks": "^4.6.0",
    "eslint-plugin-unused-imports": "^3.0.0",
    "typescript": "^5.2.2"
  },
  "build": {
    "appId": "mlx-chat",
    "productName": "MLX Chat",
    "afterSign": "notarize.js",
    "win": {
      "target": [
        "nsis"
      ]
    },
    "nsis": {
      "oneClick": false,
      "perMachine": true,
      "allowToChangeInstallationDirectory": true,
      "uninstallDisplayName": "MLX Chat"
    },
    "mac": {
      "category": "your.app.category.type",
      "target": [
        "dmg"
      ],
      "gatekeeperAssess": false,
      "hardenedRuntime": true,
      "icon": "assets/icon.icns",
      "entitlements": "./mac/entitlements.mac.inherit.plist",
      "entitlementsInherit": "./mac/entitlements.mac.inherit.plist"
    },
    "dmg": {
      "title": "MLX Chat Installer",
      "sign": false
    },
    "extraFiles": [
      {
        "from": "assets",
        "to": "resources",
        "filter": [
          "**/*"
        ]
      },
      {
        "from": "../dist",
        "to": "resources/server",
        "filter": [
          "**/*"
        ]
      }
    ]
  }
}


================================================
FILE: app/postcss.config.js
================================================
module.exports = {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}


================================================
FILE: app/src/AppProvider.tsx
================================================
'use client';

import {
  useRef,
} from 'react';
import {
  Provider,
} from 'react-redux';
import type {
  AppStore,
} from './lib/store';
import {
  makeStore,
} from './lib/store';

export default function StoreProvider({
  children,
}: {
  children: React.ReactNode;
}) {
  const storeRef = useRef<AppStore>();
  if (!storeRef.current) {
    // Create the store instance the first time this renders
    storeRef.current = makeStore();
  }

  return <Provider store={storeRef.current}>{children}</Provider>;
}


================================================
FILE: app/src/app/globals.css
================================================
@tailwind base;
@tailwind components;
@tailwind utilities;

@layer base {
  :root {
    --background: 0 0% 100%;
    --foreground: 240 10% 3.9%;
    --card: 0 0% 100%;
    --card-foreground: 240 10% 3.9%;
    --popover: 0 0% 100%;
    --popover-foreground: 240 10% 3.9%;
    --primary: 240 5.9% 10%;
    --primary-foreground: 0 0% 98%;
    --secondary: 240 4.8% 95.9%;
    --secondary-foreground: 240 5.9% 10%;
    --muted: 240 4.8% 95.9%;
    --muted-foreground: 240 3.8% 46.1%;
    --accent: 240 4.8% 95.9%;
    --accent-foreground: 240 5.9% 10%;
    --destructive: 0 72.22% 50.59%;
    --destructive-foreground: 0 0% 98%;
    --border: 240 5.9% 90%;
    --input: 240 5.9% 90%;
    --ring: 240 5% 64.9%;
    --radius: 0.5rem;
  }

  @media screen and (prefers-color-scheme: dark) {
      :root {
      --background: 240 10% 3.9%;
      --foreground: 0 0% 98%;
      --card: 240 10% 3.9%;
      --card-foreground: 0 0% 98%;
      --popover: 240 10% 3.9%;
      --popover-foreground: 0 0% 98%;
      --primary: 0 0% 98%;
      --primary-foreground: 240 5.9% 10%;
      --secondary: 240 3.7% 15.9%;
      --secondary-foreground: 0 0% 98%;
      --muted: 240 3.7% 15.9%;
      --muted-foreground: 240 5% 64.9%;
      --accent: 240 3.7% 15.9%;
      --accent-foreground: 0 0% 98%;
      --destructive: 0 62.8% 30.6%;
      --destructive-foreground: 0 85.7% 97.3%;
      --border: 240 3.7% 15.9%;
      --input: 240 3.7% 15.9%;
      --ring: 240 4.9% 83.9%;
    }
  }
}

@layer base {
  * {
    @apply border-border;
  }
  body {
    @apply text-foreground;
    /* font-feature-settings: "rlig" 1, "calt" 1; */
    font-synthesis-weight: none;
    text-rendering: optimizeLegibility;
  }
}

@layer utilities {
  .step {
    counter-increment: step;
  }

  .step:before {
    @apply absolute w-9 h-9 bg-muted rounded-full font-mono font-medium text-center text-base inline-flex items-center justify-center -indent-px border-4 border-background;
    @apply ml-[-50px] mt-[-4px];
    content: counter(step);
  }
}

@media (max-width: 640px) {
  .container {
    @apply px-4;
  }
}

/* Update scrollbar when in dark mode */
@media screen and (prefers-color-scheme: dark) {
  ::-webkit-scrollbar-thumb {
    background-color: hsl(var(--muted));
    border-radius: 5px;
    transition: all;
  }
  ::-webkit-scrollbar-thumb:hover {
    background-color: hsl(255, 4%, 20%);
  }
}

/* Update scrollbar when in light mode */
@media screen and (prefers-color-scheme: light) {
  ::-webkit-scrollbar-thumb {
    background-color: rgb(38 38 38);
    border-radius: 5px;
    transition: all;
  }
  ::-webkit-scrollbar-thumb:hover {
    background-color: hsl(255, 4%, 20%);
  }
}

::-webkit-scrollbar {
  width: 7px;
}
::-webkit-scrollbar-track {
  background-color: transparent;
}

html {
  font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont;
}

ol,
ul,
menu {
  list-style: outside;
  margin: 0;
  padding-left: 20px;
}

.drag {
  -webkit-app-region: drag;
}

.no-drag {
  -webkit-app-region: no-drag;
}


================================================
FILE: app/src/app/layout.tsx
================================================
'use client';

import StoreProvider from '../AppProvider';
import './globals.css';
import '@fortawesome/fontawesome-svg-core/styles.css';
// Prevent fontawesome from adding its CSS since we did it manually above:
import {
  config,
} from '@fortawesome/fontawesome-svg-core';
import {
  TooltipProvider,
} from '../components/ui/tooltip';

config.autoAddCss = false;

export default function RootLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  return (
    <html lang='en'>
      <body
        className={'min-h-screen overflow-y-hidden'}
        style={{
          userSelect: 'none',
        }}
      >
        <TooltipProvider>
          <StoreProvider>
            {children}
          </StoreProvider>
        </TooltipProvider>
      </body>
    </html>
  );
}


================================================
FILE: app/src/app/page.tsx
================================================
'use client';

import {
  faBan,
  faCheckCircle,
} from '@fortawesome/free-solid-svg-icons';
import {
  FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
  useEffect,
  useState,
} from 'react';
import Chat from '../components/chat/Chat';
import SelectDirectory from '../components/options/SelectDirectory';
import {
  Button,
} from '../components/ui/button';
import {
  Tooltip,
  TooltipContent,
  TooltipTrigger,
} from '../components/ui/tooltip';
import type {
  ChatMessage,
} from '../constants/chat';
import {
  useAppDispatch,
} from '../lib/hooks';
import {
  startDirectoryIndexing,
  stopDirectoryIndexing,
} from '../lib/store';

export default function Home() {
  const [selectedDirectory, setSelectedDirectory] = useState<string | null>(null);
  const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);

  const dispatch = useAppDispatch();

  function handleOpen() {
    if (typeof window !== 'undefined') {
      window.electronAPI.selectDirectory();
    }
  }

  useEffect(() => {
    window.electronAPI.onSelectDirectory(async (customData: string[]) => {
      setSelectedDirectory(customData[0]);
      try {
        dispatch(startDirectoryIndexing());
        await fetch('http://localhost:8080/api/index', {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            directory: customData[0],
          }),
        });
        dispatch(stopDirectoryIndexing());
        // TODO: spinner while indexing
      } catch (error) {
        // eslint-disable-next-line no-console
        console.error('Error sending message: ', error);
        dispatch(stopDirectoryIndexing());
      }
    });
  }, []);

  useEffect(() => {
    window.electronAPI.onSelectDirectory(() => {
      if (chatHistory.length) {
        setChatHistory([
          ...chatHistory,
          { role: 'system', content: 'Assist' },
        ]);
      }
    });
  }, [chatHistory]);

  const handleClearHistory = () => {
    setChatHistory([]);
    if (typeof window !== 'undefined') {
      window.electronAPI.resizeWindow(99);
    }
  };

  const clearDirectory = () => {
    setSelectedDirectory(null);
    if (chatHistory.length) {
      setChatHistory([
        ...chatHistory,
        { role: 'system', content: 'Converse' },
      ]);
    }
  };

  return (
    <main className='flex flex-col'>
      <Chat
        chatHistory={chatHistory}
        setChatHistory={setChatHistory}
        selectedDirectory={selectedDirectory}
      />
      <div className='border-t border-t-neutral-400 dark:border-t-neutral-700 pt-[5px] px-2'>
        <div className='flex justify-between drag'>
          {chatHistory.length
            ? (
              <Tooltip delayDuration={0}>
                <TooltipTrigger>
                  <Button
                    className='bg-transparent no-drag text-neutral-800 gap-1 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-3 flex items-center cursor-pointer'
                    onClick={handleClearHistory}
                  >
                    <FontAwesomeIcon icon={faCheckCircle} className='text-green-500' />
                  </Button>
                </TooltipTrigger>
                <TooltipContent>Clear History</TooltipContent>
              </Tooltip>
            )
            : (
              <Button
                className='bg-transparent no-drag text-neutral-800 gap-1 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-3 flex items-center cursor-pointer'
                disabled={true}
              >
                <FontAwesomeIcon icon={faBan} className='text-red-400' />
              </Button>
            )}
          <SelectDirectory
            clearDirectory={clearDirectory}
            handleOpen={handleOpen}
            selectedDirectory={selectedDirectory}
          />
        </div>
      </div>
    </main>
  );
}


================================================
FILE: app/src/app/settings/page.tsx
================================================
'use client';

import type {
  IconProp,
} from '@fortawesome/fontawesome-svg-core';
import {
  faCog,
  faMessage,
} from '@fortawesome/free-solid-svg-icons';
import {
  FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
  useEffect,
} from 'react';
import SelectModel from '../../components/options/SelectModel';
import {
  Textarea,
} from '../../components/ui/textarea';
import {
  convertToNiceShortcut,
  useKeyboardShortcut,
} from '../../lib/hooks';
import {
  cn,
} from '../../lib/utils';

enum SETTINGS {
  GENERAL,
  PROMPTS,
}

function SettingsOption({
  title,
  icon,
  onClick,
  selected,
}: {
  title: string;
  icon: IconProp;
  onClick: () => void;
  selected?: boolean;
}) {
  return (
    <div
      onClick={onClick}
      className={cn(
        'flex flex-col items-center h-[44px] gap-[2px] no-drag w-fit p-1 px-2 justify-center hover:bg-[#E2E3E2] dark:hover:bg-[#454544] active:bg-[#D5D6D5] dark:active:bg-[#525251] cursor-default rounded-md text-[#6C6C6C] dark:text-[#9C9C9B] active:text-[#202020] dark:active:text-[#EEEEEE]',
        {
          'bg-[#E2E3E2] dark:bg-[#454544]': selected,
        },
      )}
    >
      <FontAwesomeIcon
        className={cn('text-[20px] pt-1', {
          'dark:text-[#0D87FF] text-[#0066EB]': selected,
        })}
        icon={icon}
      />
      <h1
        className={cn('text-[11px]', {
          'dark:text-[#0D87FF] text-[#0066EB]': selected,
        })}
      >
        {title}
      </h1>
    </div>
  );
}

function GeneralSettings() {
  const {
    startListening,
    stopListening,
    shortcut,
  } = useKeyboardShortcut();

  const [keybind, setKeybind] = React.useState<string>(
    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('keybind') : '⌘O',
  );
  const [model, setModel] = React.useState<string>(
    typeof window !== 'undefined'
      ? window.electronAPI.fetchSetting('model')
      : 'mistralai/Mistral-7B-Instruct-v0.2',
  );

  useEffect(() => {
    if (!shortcut) {
      return;
    }
    setKeybind(shortcut);
  }, [shortcut]);

  return (
    <div className='flex flex-col justify-center w-full items-center'>
      <div className='flex items-center mt-2'>
        <p className='text-sm'>Launch keybind:</p>
        <input
          className='rounded-sm text-[12px] text-center drop-shadow-sm bg-[#ffffff] dark:bg-[#343432] border-0 h-[18px] w-28 ml-3 active:border-0 focus:border-0 outline-offset-2 focus:outline-blue-400'
          type='text'
          readOnly
          value={convertToNiceShortcut(keybind)}
          onFocus={startListening}
          onBlur={() => {
            stopListening();
            if (typeof window !== 'undefined') {
              window.electronAPI.updateSetting('keybind', shortcut);
            }
          }}
        />
      </div>
      <div className='flex items-center mt-2'>
        <p className='text-sm mr-2'>Default model:</p>
        <SelectModel
          selectedModel={model}
          handleModelChange={(selectedModel) => {
            setModel(selectedModel);
            if (typeof window !== 'undefined' && selectedModel) {
              window.electronAPI.updateSetting('model', selectedModel);
            }
          }}
        />
      </div>
    </div>
  );
}

function PromptSettings() {
  const [personalization, setPersonalization] = React.useState<string>(
    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('personalization') : '',
  );
  const [response, setResponse] = React.useState<string>(
    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('customResponse') : '',
  );

  return (
    <div className='flex flex-col justify-center w-full items-center gap-4'>
      <div className='flex flex-col items-center mt-2 gap-2'>
        <p className='text-sm flex-shrink-0 font-bold'>Personalization</p>
        <Textarea
          className='bg-[#C9C9C9] dark:bg-[#252523] border-[#B5B5B5] dark:border-[#3B3B39] border resize-none w-[300px]'
          value={personalization}
          onChange={(e) => {
            setPersonalization(e.target.value);
            if (typeof window !== 'undefined') {
              window.electronAPI.updateSetting('personalization', e.target.value);
            }
          }}
          rows={5}
          placeholder={`Things to know about you... e.g.,
  - I enjoy thought provoking conversation
  - I am a fan of the arts and culture`}
        />
      </div>
      <div className='flex flex-col items-center mt-2 gap-2'>
        <p className='text-sm flex-shrink-0 font-bold'>Custom Response</p>
        <Textarea
          className='bg-[#C9C9C9] dark:bg-[#252523] border-[#B5B5B5] dark:border-[#3B3B39] border resize-none w-[300px]'
          value={response}
          onChange={(e) => {
            setResponse(e.target.value);
            if (typeof window !== 'undefined') {
              window.electronAPI.updateSetting('customResponse', e.target.value);
            }
          }}
          rows={5}
          placeholder={`How to format responses... e.g.,
  - Respond in a concise manner
  - Do not use slang`}
        />
      </div>
    </div>
  );
}

export default function Settings() {
  const [selectedSetting, setSelectedSetting] = React.useState<SETTINGS>(SETTINGS.GENERAL);

  return (
    <main className='flex flex-col bg-[#F1F1F1] dark:bg-[#383736] h-screen'>
      <div className='h-[81px] border-0 border-b border-b-neutral-300 dark:border-b-zinc-950 drag flex flex-col'>
        <h1 className='text-[12px] font-bold text-[#6C6C6C] dark:text-[#9C9C9B] text-center pt-1'>
          {selectedSetting === SETTINGS.GENERAL
            ? 'General'
            : 'Prompt'}
        </h1>
        <div className='flex justify-center items-center gap-[1px] mt-1'>
          <SettingsOption
            title='General'
            icon={faCog}
            onClick={() => setSelectedSetting(SETTINGS.GENERAL)}
            selected={selectedSetting === SETTINGS.GENERAL}
          />
          <SettingsOption
            title='Prompts'
            icon={faMessage}
            onClick={() => setSelectedSetting(SETTINGS.PROMPTS)}
            selected={selectedSetting === SETTINGS.PROMPTS}
          />
        </div>
      </div>
      <div className='flex-grow dark:bg-[#292929] bg-[#EEEDEC]'>
        {selectedSetting === SETTINGS.GENERAL
          ? <GeneralSettings />
          : <PromptSettings />}
      </div>
    </main>
  );
}


================================================
FILE: app/src/components/chat/Chat.tsx
================================================
import React from 'react';
import type {
  ChatMessage,
} from '../../constants/chat';
import {
  useAppDispatch,
} from '../../lib/hooks';
import {
  startWaitingForResponse,
  stopWaitingForResponse,
} from '../../lib/store';
import {
  cn,
} from '../../lib/utils';
import ChatInput from './ChatInput';
import ChatMessages from './ChatMessages';

const Chat = ({
  selectedDirectory,
  chatHistory,
  setChatHistory,
}: {
  selectedDirectory: string | null;
  chatHistory: ChatMessage[];
  setChatHistory: (chats: ChatMessage[]) => void;
}) => {
  const dispatch = useAppDispatch();
  const sendMessage = async (message: string) => {
    try {
      if (chatHistory.length === 0) {
        window.electronAPI.resizeWindow(500);
      }
      const newHistory = [
        ...chatHistory,
        { role: 'user' as const, content: message },
      ];
      setChatHistory(newHistory);
      dispatch(startWaitingForResponse());
      const response = await fetch('http://localhost:8080/api/query', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          messages: selectedDirectory
            ? [{ role: 'user', content: message }]
            : newHistory.filter((chat) => chat.role !== 'system'),
          temperature: 0.7,
          // eslint-disable-next-line @typescript-eslint/naming-convention
          top_p: 1.0,
          // eslint-disable-next-line @typescript-eslint/naming-convention
          max_tokens: 200,
          directory: selectedDirectory,
          instructions: {
            personalization: typeof window !== 'undefined'
              ? window.electronAPI.fetchSetting('personalization')
              : '',
            response: typeof window !== 'undefined'
              ? window.electronAPI.fetchSetting('customResponse')
              : '',
          },
        }),
      });
      dispatch(stopWaitingForResponse());

      const responseData = await response.json();
      const assistantResponse = responseData.choices[0].message.content;

      setChatHistory([
        ...newHistory,
        { role: 'assistant', content: assistantResponse },
      ]);
    } catch (error) {
      dispatch(stopWaitingForResponse());
      // eslint-disable-next-line no-console
      console.error('Error sending message: ', error);
    }
  };

  return (
    <>
      <div
        className={cn('flex justify-center border-b-neutral-400 dark:border-b-neutral-700', {
          'border-b': chatHistory.length > 0,
        })}
      >
        <ChatInput sendMessage={sendMessage} />
      </div>
      <div
        className={cn(
          'flex-grow min-w-full border-0 flex h-0',
          {
            'h-[400px]': chatHistory.length > 0,
          },
        )}
      >
        <ChatMessages chatHistory={chatHistory} />
      </div>
    </>
  );
};

export default Chat;


================================================
FILE: app/src/components/chat/ChatInput.tsx
================================================
import React, {
  useEffect,
  useRef,
  useState,
} from 'react';
import {
  useAppSelector,
} from '../../lib/hooks';
import {
  Input,
} from '../ui/input';

const ChatInput = ({
  sendMessage,
}: {
  sendMessage: (text: string) => void;
}) => {
  const [message, setMessage] = useState<string>('');
  const inputRef = useRef<HTMLInputElement>(null);

  const handleSend = (e: React.KeyboardEvent) => {
    if (e.key !== 'Enter' || message.length === 0) {
      return;
    }
    e.preventDefault();
    setMessage('');
    sendMessage(message);
  };

  // detect website focus and focus the input
  const handleFocus = () => {
    if (document.activeElement !== inputRef.current) {
      inputRef.current?.focus();
    }
  };

  useEffect(() => {
    window.addEventListener('focus', handleFocus);
    return () => {
      window.removeEventListener('focus', handleFocus);
    };
  }, []);

  const isDirectoryIndexing = useAppSelector((state) => state.isDirectoryIndexing);

  return (
    <div
      className='w-full py-2 drag'
      onClick={() => {
        if (document.activeElement !== inputRef.current) {
          inputRef.current?.focus();
        }
      }}
    >
      <Input
        value={message}
        onChange={(e) => setMessage(e.target.value)}
        placeholder={isDirectoryIndexing ? 'Indexing your files...' : 'Enter prompt here'}
        onKeyDown={handleSend}
        ref={inputRef}
        disabled={isDirectoryIndexing}
        className={'text-xl no-drag border-0 focus-visible:outline-transparent focus-visible:ring-0 focus-visible:shadow-0 w-full shadow-0'}
      />
    </div>
  );
};

export default ChatInput;


================================================
FILE: app/src/components/chat/ChatMessage.tsx
================================================
import Markdown from 'markdown-to-jsx';
import React from 'react';
import type {
  ChatMessage,
} from '../../constants/chat';

const Message = ({
  message,
}: {
  message: ChatMessage;
}) => (
  <div
    className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}
  >
    <div
      className={`p-2 rounded-sm ${
        message.role === 'user'
          ? 'bg-blue-500 text-white'
          : 'bg-[#E9E9EB] dark:bg-zinc-500'
      }`}
    >
      <div className='text-md select-text'>
        <Markdown>
          {message.content ?? ''}
        </Markdown>
      </div>
    </div>
  </div>
);

export default Message;


================================================
FILE: app/src/components/chat/ChatMessages.tsx
================================================
/* eslint-disable function-paren-newline */
import {
  faCircleNotch,
} from '@fortawesome/free-solid-svg-icons';
import {
  FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
  useEffect,
} from 'react';
import type {
  ChatMessage,
} from '../../constants/chat';
import {
  useAppSelector,
} from '../../lib/hooks';
import Message from './ChatMessage';
import SystemMessage from './SystemMessage';

const ChatMessages = ({
  chatHistory,
}: {
  chatHistory: ChatMessage[];
}) => {
  const messagesRef = React.useRef<HTMLDivElement>(null);

  const scrollToBottom = () => {
    const scrollHeight = messagesRef.current?.scrollHeight;
    const height = messagesRef.current?.clientHeight ?? 0;
    const maxScrollTop = scrollHeight ? scrollHeight - height : 0;
    if (messagesRef.current) {
      messagesRef.current.scrollTop = maxScrollTop > 0 ? maxScrollTop : 0;
    }
  };

  const isWaitingForResponse = useAppSelector((state) => state.isWaitingForResponse);

  useEffect(() => {
    // skip auto-scroll when the user has scrolled more than 200px up into
    // history, unless the latest message was just sent by the user
    const currentScroll = messagesRef.current?.scrollTop ?? 0;
    const scrollHeight = messagesRef.current?.scrollHeight;
    const height = messagesRef.current?.clientHeight ?? 0;
    const maxScrollTop = scrollHeight ? scrollHeight - height : 0;
    const scrollInHistory = (maxScrollTop - currentScroll) > 200;

    if (scrollInHistory && chatHistory[chatHistory.length - 1]?.role !== 'user') {
      return;
    }

    scrollToBottom();
  }, [chatHistory]);
  return chatHistory.length
    ? (
      <div ref={messagesRef} className='flex flex-col flex-grow p-4 gap-4 overflow-y-scroll'>
        {chatHistory.map((message, index) => (message.role !== 'system'
          ? (
            <Message
              key={index}
              message={message}
            />
          )
          : (
            <SystemMessage
              key={index}
              message={message}
            />
          ))
        )}
        {isWaitingForResponse
          ? (
            <div
              className={'flex justify-start'}
            >
              <div
                className={'p-2 rounded-sm'}
              >
                <div className='text-md select-text'>
                  <FontAwesomeIcon className='animate-spin' icon={faCircleNotch} />
                </div>
              </div>
            </div>
          )
          : null}
      </div>
    )
    : null;
};

export default ChatMessages;
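The auto-scroll guard in `ChatMessages` can be illustrated with concrete numbers (the values here are hypothetical):

```typescript
// With a 1000px scrollable area in a 400px viewport, the bottom position is
// scrollHeight - clientHeight = 600. A reader parked at 350 is 250px above
// the bottom, which exceeds the 200px threshold, so auto-scroll is skipped.
const scrollHeight = 1000;
const clientHeight = 400;
const currentScroll = 350;

const maxScrollTop = scrollHeight - clientHeight;
const scrollInHistory = (maxScrollTop - currentScroll) > 200;

console.log(maxScrollTop, scrollInHistory); // 600 true
```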


================================================
FILE: app/src/components/chat/SystemMessage.tsx
================================================
import React from 'react';
import type {
  ChatMessage,
} from '../../constants/chat';

const Message = ({
  message,
}: {
  message: ChatMessage;
}) => (
  <div
    className={'flex w-full'}
  >
    <div
      className={'rounded-sm w-full relative flex items-center'}
    >
      <div className='w-full h-[1px] bg-red-500 rounded-md' />
      <div className='text-[12px] font-semibold select-text px-2 text-red-500 bg-transparent flex-grow rounded-md text-center py-1 whitespace-nowrap'>
        {message.content}
      </div>
      <div className='w-full h-[1px] bg-red-500 rounded-md' />
      <div className='arrow-tag text-[10px] p-0 absolute right-0 font-bold select-text uppercase pr-1 pl-1 text-white bg-red-500 flex-grow rounded-sm text-center whitespace-nowrap'>
        <svg
          className='absolute -left-[5px] top-[1px] z-[-1]'
          aria-hidden='true'
          role='img'
          width='8'
          height='13'
          viewBox='0 0 8 13'
        >
          <path
            className='fill-red-500 text-red-500'
            stroke='currentColor'
            fill='transparent'
            d='M8.16639 0.5H9C10.933 0.5 12.5 2.067 12.5 4V9C12.5 10.933 10.933 12.5 9 12.5H8.16639C7.23921 12.5 6.34992 12.1321 5.69373 11.4771L0.707739 6.5L5.69373 1.52292C6.34992 0.86789 7.23921 0.5 8.16639 0.5Z'
          >
          </path>
        </svg>
        Mode
      </div>
    </div>
  </div>
);

export default Message;


================================================
FILE: app/src/components/options/SelectDirectory.tsx
================================================
import {
  faCheckCircle,
  faCircleNotch,
  faXmark,
} from '@fortawesome/free-solid-svg-icons';
import {
  FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
  useEffect,
} from 'react';
import {
  useAppSelector,
  usePrevious,
} from '../../lib/hooks';
import {
  cn,
} from '../../lib/utils';
import {
  Button,
} from '../ui/button';

const SelectDirectory = ({
  handleOpen,
  selectedDirectory,
  clearDirectory,
}: {
  handleOpen: () => void;
  selectedDirectory: string | null;
  clearDirectory: () => void;
}) => {
  const shortenedDirectory = selectedDirectory
    ? `/${selectedDirectory.split('/')[1]}/../${selectedDirectory.split('/').pop()}`
    : 'Select Directory';
  const [isCheckShowing, setIsCheckShowing] = React.useState(false);

  const isDirectoryIndexing = useAppSelector((state) => state.isDirectoryIndexing);

  const oldLoadingState = usePrevious(isDirectoryIndexing);

  useEffect(() => {
    if (oldLoadingState && !isDirectoryIndexing) {
      setIsCheckShowing(true);
      setTimeout(() => {
        setIsCheckShowing(false);
      }, 3000);
    }
  }, [isDirectoryIndexing]);

  return (
    <div className='flex no-drag items-center group'>
      <Button
        className={cn(
          'bg-transparent text-neutral-800 z-0 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-2 flex items-center cursor-pointer',
          {
            'hover:bg-transparent dark:hover:bg-transparent cursor-default': isDirectoryIndexing,
          },
        )}
        onClick={isDirectoryIndexing ? undefined : handleOpen}
      >
        <div className='pr-1'>
          {selectedDirectory && !isDirectoryIndexing && !isCheckShowing && (
            <div
              className='group-hover:opacity-100 opacity-0 px-1 z-10 hover:bg-neutral-300 dark:hover:bg-neutral-800 rounded-sm transition-all cursor-pointer'
              onClick={(e) => {
                e.stopPropagation();
                clearDirectory();
              }}
            >
              <FontAwesomeIcon
                icon={faXmark}
              />
            </div>
          )}
          {isDirectoryIndexing && <FontAwesomeIcon className='animate-spin' icon={faCircleNotch} />}
          {isCheckShowing && <FontAwesomeIcon className='text-green-500' icon={faCheckCircle} />}
        </div>
        {shortenedDirectory}
      </Button>
    </div>
  );
};

export default SelectDirectory;


================================================
FILE: app/src/components/options/SelectModel.tsx
================================================
import React from 'react';
import {
  Select,
  SelectContent,
  SelectGroup,
  SelectItem,
  SelectLabel,
  SelectTrigger,
  SelectValue,
} from '../ui/select';

const SelectModel = ({
  selectedModel,
  handleModelChange,
}: {
  selectedModel: string | null;
  handleModelChange: (model: string) => void;
}) => (
  <div className='no-drag'>
    <Select
      value={selectedModel ?? ''}
      onValueChange={(value) => handleModelChange(value)}
    >
      <SelectTrigger className='a-icon w-[140px] h-5 border-none shadow-transparent bg-white dark:bg-[#606160] transition-all border border-zinc-600 focus:ring-0 focus-within:ring-0 focus-visible:ring-0 peer-focus-within:ring-0 text-neutral-800 dark:text-white'>
        <SelectValue placeholder='Select a model' />
      </SelectTrigger>
      <SelectContent>
        <SelectGroup>
          <SelectLabel>AI Model</SelectLabel>
          <SelectItem value='mistralai/Mistral-7B-Instruct-v0.2'>Mistral7B</SelectItem>
          <SelectItem value='google/gemma-7b-it'>Gemma7B</SelectItem>
        </SelectGroup>
      </SelectContent>
    </Select>
  </div>
);

export default SelectModel;


================================================
FILE: app/src/components/ui/button.tsx
================================================
import {
  Slot,
} from '@radix-ui/react-slot';
import {
  cva,
  type VariantProps,
} from 'class-variance-authority';
import * as React from 'react';
import {
  cn,
} from '../../lib/utils';

const buttonVariants = cva(
  'inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50',
  {
    variants: {
      variant: {
        default: 'bg-primary text-primary-foreground shadow hover:bg-primary/90',
        destructive: 'bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90',
        outline:
          'border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground',
        secondary: 'bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80',
        ghost: 'hover:bg-accent hover:text-accent-foreground',
        link: 'text-primary underline-offset-4 hover:underline',
      },
      size: {
        default: 'h-9 px-4 py-2',
        sm: 'h-8 rounded-md px-3 text-xs',
        lg: 'h-10 rounded-md px-8',
        icon: 'h-9 w-9',
      },
    },
    defaultVariants: {
      variant: 'default',
      size: 'default',
    },
  },
);

export interface ButtonProps
  extends React.ButtonHTMLAttributes<HTMLButtonElement>, VariantProps<typeof buttonVariants>
{
  asChild?: boolean;
}

const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
  ({
    className,
    variant,
    size,
    asChild = false,
    ...props
  }, ref) => {
    // eslint-disable-next-line @typescript-eslint/naming-convention
    const Comp = asChild ? Slot : 'button';
    return (
      <Comp
        className={cn(buttonVariants({ variant, size, className }))}
        ref={ref}
        {...props}
      />
    );
  },
);
Button.displayName = 'Button';

export {
  Button,
  buttonVariants,
};


================================================
FILE: app/src/components/ui/input.tsx
================================================
import * as React from 'react';
import {
  cn,
} from '../../lib/utils';

export interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {}

const Input = React.forwardRef<HTMLInputElement, InputProps>(
  ({ className, type, ...props }, ref) => (
    <input
      type={type}
      className={cn(
        'flex h-9 w-full rounded-md border border-input bg-transparent px-3 py-1 text-sm shadow-0 transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50',
        className,
      )}
      ref={ref}
      {...props}
    />
  ),
);
Input.displayName = 'Input';

export {
  Input,
};


================================================
FILE: app/src/components/ui/resizable.tsx
================================================
'use client';

import {
  DragHandleDots2Icon,
} from '@radix-ui/react-icons';
import * as ResizablePrimitive from 'react-resizable-panels';
import {
  cn,
} from '../../lib/utils';

const ResizablePanelGroup = ({
  className,
  ...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelGroup>) => (
  <ResizablePrimitive.PanelGroup
    className={cn(
      'flex h-full w-full data-[panel-group-direction=vertical]:flex-col',
      className,
    )}
    {...props}
  />
);

const ResizablePanel = ResizablePrimitive.Panel;

const ResizableHandle = ({
  withHandle,
  className,
  ...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & {
  withHandle?: boolean;
}) => (
  <ResizablePrimitive.PanelResizeHandle
    className={cn(
      'relative flex w-px items-center justify-center bg-border after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 [&[data-panel-group-direction=vertical]>div]:rotate-90',
      className,
    )}
    {...props}
  >
    {withHandle && (
      <div className='z-10 flex h-4 w-3 items-center justify-center rounded-sm border bg-border'>
        <DragHandleDots2Icon className='h-2.5 w-2.5' />
      </div>
    )}
  </ResizablePrimitive.PanelResizeHandle>
);

export {
  ResizableHandle,
  ResizablePanel,
  ResizablePanelGroup,
};


================================================
FILE: app/src/components/ui/select.tsx
================================================
'use client';

import {
  CaretSortIcon,
  CheckIcon,
  ChevronDownIcon,
  ChevronUpIcon,
} from '@radix-ui/react-icons';
import * as SelectPrimitive from '@radix-ui/react-select';
import * as React from 'react';
import {
  cn,
} from '../../lib/utils';

const Select = SelectPrimitive.Root;

const SelectGroup = SelectPrimitive.Group;

const SelectValue = SelectPrimitive.Value;

const SelectTrigger = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Trigger>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
>(({ className, children, ...props }, ref) => (
  <SelectPrimitive.Trigger
    ref={ref}
    className={cn(
      'flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus:outline-none disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1',
      className,
    )}
    {...props}
  >
    {children}
    {className?.includes('a-icon')
      ? (
        <SelectPrimitive.Icon asChild>
          <CaretSortIcon className='h-4 w-4 opacity-50' />
        </SelectPrimitive.Icon>
      )
      : null}
  </SelectPrimitive.Trigger>
));
SelectTrigger.displayName = SelectPrimitive.Trigger.displayName;

const SelectScrollUpButton = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>
>(({ className, ...props }, ref) => (
  <SelectPrimitive.ScrollUpButton
    ref={ref}
    className={cn(
      'flex cursor-default items-center justify-center py-1',
      className,
    )}
    {...props}
  >
    <ChevronUpIcon />
  </SelectPrimitive.ScrollUpButton>
));
SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName;

const SelectScrollDownButton = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>
>(({ className, ...props }, ref) => (
  <SelectPrimitive.ScrollDownButton
    ref={ref}
    className={cn(
      'flex cursor-default items-center justify-center py-1',
      className,
    )}
    {...props}
  >
    <ChevronDownIcon />
  </SelectPrimitive.ScrollDownButton>
));
SelectScrollDownButton.displayName = SelectPrimitive.ScrollDownButton.displayName;

const SelectContent = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
>(({
  className,
  children,
  position = 'popper',
  ...props
}, ref) => (
  <SelectPrimitive.Portal>
    <SelectPrimitive.Content
      ref={ref}
      className={cn(
        'relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
        position === 'popper'
          && 'data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1',
        className,
      )}
      position={position}
      {...props}
    >
      <SelectScrollUpButton />
      <SelectPrimitive.Viewport
        className={cn(
          'p-1',
          position === 'popper'
            && 'h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]',
        )}
      >
        {children}
      </SelectPrimitive.Viewport>
      <SelectScrollDownButton />
    </SelectPrimitive.Content>
  </SelectPrimitive.Portal>
));
SelectContent.displayName = SelectPrimitive.Content.displayName;

const SelectLabel = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Label>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>
>(({ className, ...props }, ref) => (
  <SelectPrimitive.Label
    ref={ref}
    className={cn('px-2 py-1.5 text-sm font-semibold', className)}
    {...props}
  />
));
SelectLabel.displayName = SelectPrimitive.Label.displayName;

const SelectItem = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Item>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
>(({ className, children, ...props }, ref) => (
  <SelectPrimitive.Item
    ref={ref}
    className={cn(
      'relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-2 pr-8 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50',
      className,
    )}
    {...props}
  >
    <span className='absolute right-2 flex h-3.5 w-3.5 items-center justify-center'>
      <SelectPrimitive.ItemIndicator>
        <CheckIcon className='h-4 w-4' />
      </SelectPrimitive.ItemIndicator>
    </span>
    <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
  </SelectPrimitive.Item>
));
SelectItem.displayName = SelectPrimitive.Item.displayName;

const SelectSeparator = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Separator>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>
>(({ className, ...props }, ref) => (
  <SelectPrimitive.Separator
    ref={ref}
    className={cn('-mx-1 my-1 h-px bg-muted', className)}
    {...props}
  />
));
SelectSeparator.displayName = SelectPrimitive.Separator.displayName;

export {
  Select,
  SelectContent,
  SelectGroup,
  SelectItem,
  SelectLabel,
  SelectScrollDownButton,
  SelectScrollUpButton,
  SelectSeparator,
  SelectTrigger,
  SelectValue,
};


================================================
FILE: app/src/components/ui/textarea.tsx
================================================
import * as React from 'react';
import {
  cn,
} from '../../lib/utils';

export interface TextareaProps extends React.TextareaHTMLAttributes<HTMLTextAreaElement> {}

const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
  ({ className, ...props }, ref) => (
    <textarea
      className={cn(
        'flex min-h-[60px] w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50',
        className,
      )}
      ref={ref}
      {...props}
    />
  ),
);
Textarea.displayName = 'Textarea';

export {
  Textarea,
};


================================================
FILE: app/src/components/ui/tooltip.tsx
================================================
'use client';

import * as TooltipPrimitive from '@radix-ui/react-tooltip';
import * as React from 'react';
import {
  cn,
} from '../../lib/utils';

const TooltipProvider = TooltipPrimitive.Provider;

const Tooltip = TooltipPrimitive.Root;

const TooltipTrigger = TooltipPrimitive.Trigger;

const TooltipContent = React.forwardRef<
  React.ElementRef<typeof TooltipPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>
>(({ className, sideOffset = 4, ...props }, ref) => (
  <TooltipPrimitive.Content
    ref={ref}
    sideOffset={sideOffset}
    className={cn(
      'z-50 overflow-hidden rounded-md bg-primary px-3 py-1.5 text-xs text-primary-foreground animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
      className,
    )}
    {...props}
  />
));
TooltipContent.displayName = TooltipPrimitive.Content.displayName;

export {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
};


================================================
FILE: app/src/constants/chat.tsx
================================================
export type ChatMessage = {
  role: 'user' | 'assistant' | 'system';
  content: string | null;
};


================================================
FILE: app/src/lib/hooks.ts
================================================
import {
  useEffect,
  useRef,
  useState,
} from 'react';
import {
  useDispatch,
  useSelector,
  useStore,
} from 'react-redux';
import type {
  TypedUseSelectorHook,
} from 'react-redux';
import type {
  AppDispatch,
  AppStore,
  RootState,
} from './store';

// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
export const useAppStore: () => AppStore = useStore;

export function usePrevious<T>(value: T): T | undefined {
  const ref = useRef<T>();
  useEffect(() => {
    ref.current = value;
  });
  return ref.current;
}

export function convertToNiceShortcut(shortcut: string) {
  return shortcut
    .replace('Cmd', '⌘')
    .replace('Option', '⌥')
    .replace('Shift', '⇧')
    .replaceAll('+', '');
}

export function useKeyboardShortcut() {
  const [isListening, setIsListening] = useState(false);
  const [shortcut, setShortcut] = useState('');

  useEffect(() => {
    const handleKeyDown = (event: KeyboardEvent) => {
      if (!isListening) { return; }

      // Prevent default action to avoid interfering with normal browser shortcuts
      event.preventDefault();


      const keys = [];
      if (event.ctrlKey) { keys.push('Ctrl'); }
      if (event.shiftKey) { keys.push('Shift'); }
      if (event.altKey) { keys.push('Option'); }
      if (event.metaKey) {
        keys.push('Cmd'); // Command key for Mac
      }
      // Avoid adding modifier keys alone, check if another key is also pressed
      if (
        event.key.length === 1
        || (event.key !== 'Control' && event.key !== 'Shift' && event.key !== 'Alt'
          && event.key !== 'Meta')
      ) {
        keys.push(event.key.toUpperCase());
      }

      const combination = keys.join('+');
      setShortcut(combination);
    };

    if (isListening) {
      window.addEventListener('keydown', handleKeyDown);
    }

    return () => {
      window.removeEventListener('keydown', handleKeyDown);
    };
  }, [isListening]);

  const startListening = () => setIsListening(true);
  const stopListening = () => setIsListening(false);

  return {
    startListening,
    stopListening,
    shortcut,
  };
}


================================================
FILE: app/src/lib/store.ts
================================================
import {
  configureStore,
  createSlice,
} from '@reduxjs/toolkit';

const globalSlice = createSlice({
  name: 'global',
  initialState: {
    isDirectoryIndexing: false,
    isWaitingForResponse: false,
  },
  reducers: {
    startDirectoryIndexing: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isDirectoryIndexing = true;
    },
    stopDirectoryIndexing: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isDirectoryIndexing = false;
    },
    startWaitingForResponse: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isWaitingForResponse = true;
    },
    stopWaitingForResponse: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isWaitingForResponse = false;
    },
  },
});

export const {
  startDirectoryIndexing,
  stopDirectoryIndexing,
  startWaitingForResponse,
  stopWaitingForResponse,
} = globalSlice.actions;

export const makeStore = () =>
  configureStore({
    reducer: globalSlice.reducer,
  });

// Infer the type of makeStore
export type AppStore = ReturnType<typeof makeStore>;
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<AppStore['getState']>;
export type AppDispatch = AppStore['dispatch'];
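For readers unfamiliar with Redux Toolkit's Immer-backed reducers, here is a plain-TypeScript sketch of what the slice above amounts to (the `global/...` action type strings follow RTK's `name/reducer` naming convention, which is an assumption here, since RTK generates them):

```typescript
// Plain-TS equivalent of globalSlice, without Immer or Redux Toolkit.
type GlobalState = { isDirectoryIndexing: boolean; isWaitingForResponse: boolean };

// Hand-written action creator mirroring globalSlice.actions.startDirectoryIndexing.
const startDirectoryIndexing = () => ({ type: 'global/startDirectoryIndexing' });

const reducer = (state: GlobalState, action: { type: string }): GlobalState => {
  switch (action.type) {
    case 'global/startDirectoryIndexing':
      return { ...state, isDirectoryIndexing: true };
    case 'global/stopDirectoryIndexing':
      return { ...state, isDirectoryIndexing: false };
    case 'global/startWaitingForResponse':
      return { ...state, isWaitingForResponse: true };
    case 'global/stopWaitingForResponse':
      return { ...state, isWaitingForResponse: false };
    default:
      return state;
  }
};

const next = reducer(
  { isDirectoryIndexing: false, isWaitingForResponse: false },
  startDirectoryIndexing(),
);
console.log(next.isDirectoryIndexing); // true
```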


================================================
FILE: app/src/lib/utils.ts
================================================
import {
  type ClassValue,
  clsx,
} from 'clsx';
import {
  twMerge,
} from 'tailwind-merge';

export function cn(...inputs: ClassValue[]) {
  return twMerge(clsx(inputs));
}


================================================
FILE: app/tailwind.config.main.js
================================================
/** @type {import('tailwindcss').Config} */
module.exports = {
  content: ["./main/**/*.{js,ts,jsx,tsx,mdx,html}"],
  theme: {
    extend: {},
  },
  plugins: [],
};


================================================
FILE: app/tailwind.config.ts
================================================
/* eslint-disable @typescript-eslint/naming-convention */
import type {
  Config,
} from 'tailwindcss';

const config = {
  darkMode: 'media',
  content: [
    './pages/**/*.{ts,tsx}',
    './components/**/*.{ts,tsx}',
    './app/**/*.{ts,tsx}',
    './src/**/*.{ts,tsx}',
  ],
  prefix: '',
  theme: {
    container: {
      center: true,
      padding: '2rem',
      screens: {
        '2xl': '1400px',
      },
    },
    extend: {
      colors: {
        border: 'hsl(var(--border))',
        input: 'hsl(var(--input))',
        ring: 'hsl(var(--ring))',
        background: 'hsl(var(--background))',
        foreground: 'hsl(var(--foreground))',
        primary: {
          DEFAULT: 'hsl(var(--primary))',
          foreground: 'hsl(var(--primary-foreground))',
        },
        secondary: {
          DEFAULT: 'hsl(var(--secondary))',
          foreground: 'hsl(var(--secondary-foreground))',
        },
        destructive: {
          DEFAULT: 'hsl(var(--destructive))',
          foreground: 'hsl(var(--destructive-foreground))',
        },
        muted: {
          DEFAULT: 'hsl(var(--muted))',
          foreground: 'hsl(var(--muted-foreground))',
        },
        accent: {
          DEFAULT: 'hsl(var(--accent))',
          foreground: 'hsl(var(--accent-foreground))',
        },
        popover: {
          DEFAULT: 'hsl(var(--popover))',
          foreground: 'hsl(var(--popover-foreground))',
        },
        card: {
          DEFAULT: 'hsl(var(--card))',
          foreground: 'hsl(var(--card-foreground))',
        },
      },
      borderRadius: {
        lg: 'var(--radius)',
        md: 'calc(var(--radius) - 2px)',
        sm: 'calc(var(--radius) - 4px)',
      },
      keyframes: {
        'accordion-down': {
          from: { height: '0' },
          to: { height: 'var(--radix-accordion-content-height)' },
        },
        'accordion-up': {
          from: { height: 'var(--radix-accordion-content-height)' },
          to: { height: '0' },
        },
      },
      animation: {
        'accordion-down': 'accordion-down 0.2s ease-out',
        'accordion-up': 'accordion-up 0.2s ease-out',
      },
    },
  },
  // eslint-disable-next-line global-require
  plugins: [require('tailwindcss-animate')],
} satisfies Config;

export default config;


================================================
FILE: app/tsconfig.json
================================================
{
  "compilerOptions": {
    "target": "es5",
    "lib": [
      "dom",
      "dom.iterable",
      "esnext",
    ],
    "allowJs": true,
    "skipLibCheck": true,
    "strict": true,
    "forceConsistentCasingInFileNames": true,
    "noEmit": true,
    "esModuleInterop": true,
    "module": "esnext",
    "moduleResolution": "node",
    "resolveJsonModule": true,
    "isolatedModules": true,
    "jsx": "preserve",
    "incremental": true,
    "plugins": [
      {
        "name": "next",
      },
    ],
    "paths": {
      "@/*": [
        "./src/*",
      ],
    },
  },
  "include": [
    "next-env.d.ts",
    "**/*.ts",
    "**/*.tsx",
    ".next/types/**/*.ts",
    "build/types/**/*.ts",
    "main/preload.ts",
    "main/main.ts",
    "out/types/**/*.ts",
    ".eslintrc.cjs",
  ],
  "exclude": [
    "node_modules",
    "out",
    "build",
  ],
}


================================================
FILE: runner.py
================================================
# Entry-point script for packaging the server with PyInstaller
#
# Example Usage:
#
# pyinstaller --onefile --collect-all mlx --copy-metadata opentelemetry-sdk \
# --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert --hidden-import server.models.llama \
# runner.py

from server import server
server.main()


================================================
FILE: runner.sh
================================================
#!/bin/bash

collect_modules=(
  "mlx"
  "chromadb"
)

hidden_imports=(
  "server.models"
  "server.models.gemma"
  "server.models.bert"
  "server.models.llama"
)

exclude_modules=(
  "matplotlib"
  "pandas"
  "PIL"
  "IPython"
)

misc_params=(
  "--copy-metadata" "opentelemetry-sdk"
)

# Build the command as an array rather than a string, so arguments keep their
# word boundaries and no eval is needed
command=(pyinstaller --onefile runner.py)

for module in "${collect_modules[@]}"; do
  command+=(--collect-all "$module")
done
for module in "${hidden_imports[@]}"; do
  command+=(--hidden-import "$module")
done
for module in "${exclude_modules[@]}"; do
  command+=(--exclude-module "$module")
done
command+=("${misc_params[@]}")

"${command[@]}"


================================================
FILE: server/__init__.py
================================================
from .utils import generate, load, convert

__version__ = "0.1.0"


================================================
FILE: server/convert.py
================================================
import argparse

from .utils import convert


def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )

    parser.add_argument("--hf-path", type=str,
                        help="Path to the Hugging Face model.")
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters, ignored if -q is given.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    return parser


if __name__ == "__main__":
    parser = configure_parser()
    args = parser.parse_args()
    convert(**vars(args))


================================================
FILE: server/models/__init__.py
================================================


================================================
FILE: server/models/base.py
================================================
import inspect
from dataclasses import dataclass


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


================================================
FILE: server/models/bert.py
================================================
import math
import inspect
import mlx.core as mx
import mlx.nn as nn

from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Union, Callable

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    classifier_dropout: float
    hidden_act: str
    hidden_dropout_prob: float
    hidden_size: int
    initializer_range: float
    intermediate_size: int
    layer_norm_eps: float
    max_position_embeddings: int
    num_attention_heads: int
    num_hidden_layers: int
    pad_token_id: int
    position_embedding_type: str
    torch_dtype: str
    type_vocab_size: int
    use_cache: bool
    vocab_size: int
    chunk_size_feed_forward: int = 0  # 0 disables chunking in apply_chunking_to_forward
    attention_probs_dropout_prob: float = 0.0
    is_decoder: bool = False
    add_cross_attention: bool = False
    output_attentions: bool = False
    output_hidden_states: bool = False
    use_return_dict: bool = True


def apply_chunking_to_forward(
    forward_fn: Callable[..., mx.array], chunk_size: int, chunk_dim: int, *input_tensors
) -> mx.array:
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.

    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
    applying `forward_fn` to `input_tensors`.

    Args:
        forward_fn (`Callable[..., mx.array]`):
            The forward function of the model.
        chunk_size (`int`):
            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`Tuple[mx.array]`):
            The input tensors of `forward_fn` which will be chunked

    Returns:
        `mx.array`: A tensor with the same shape as `forward_fn` would have produced if applied directly.


    Examples:

    ```python
    # rename the usual __call__ computation to forward_chunk()
    def forward_chunk(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states


    # implement a chunked forward function
    def __call__(self, hidden_states):
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
    ```"""

    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # inspect.signature has existed since Python 3.5, so there is no backward
    # compatibility concern here
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, "
            f"but only {len(input_tensors)} input tensors are given"
        )

    if chunk_size > 0:
        tensor_shape = input_tensors[0].shape[chunk_dim]
        for input_tensor in input_tensors:
            if input_tensor.shape[chunk_dim] != tensor_shape:
                raise ValueError(
                    f"All input tensors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )

        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} "
                f"has to be a multiple of the chunk size {chunk_size}"
            )

        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size

        # chunk each input tensor into `num_chunks` pieces along `chunk_dim`
        # (mx.split is the MLX equivalent of torch.Tensor.chunk)
        input_tensors_chunks = tuple(
            mx.split(input_tensor, num_chunks, axis=chunk_dim)
            for input_tensor in input_tensors
        )
        # apply the forward fn to every tuple of chunks
        output_chunks = tuple(
            forward_fn(*input_tensors_chunk)
            for input_tensors_chunk in zip(*input_tensors_chunks)
        )
        # concatenate the outputs along the chunked dimension
        return mx.concatenate(output_chunks, axis=chunk_dim)

    return forward_fn(*input_tensors)
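# The invariant behind `apply_chunking_to_forward`: split-apply-concatenate
# equals a direct call whenever `forward_fn` acts independently along the
# chunked dimension. A minimal sketch of that invariant, with numpy standing
# in for `mx` purely so it is easy to check anywhere:

```python
import numpy as np


def chunked_apply(fn, chunk_size, axis, x):
    # split -> apply -> concatenate, mirroring apply_chunking_to_forward
    num_chunks = x.shape[axis] // chunk_size
    chunks = np.split(x, num_chunks, axis=axis)
    return np.concatenate([fn(c) for c in chunks], axis=axis)


x = np.arange(12.0).reshape(2, 6)
fn = lambda t: t * 2.0  # elementwise, hence independent across axis 1
direct = fn(x)
chunked = chunked_apply(fn, chunk_size=2, axis=1, x=x)
assert np.allclose(direct, chunked)
```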


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)

        self._position_ids = mx.expand_dims(
            mx.arange(0, config.max_position_embeddings), axis=0)
        self._token_type_ids = mx.zeros((self._position_ids.shape))

        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute")

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        token_type_ids: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        inputs_embeds: Optional[mx.array] = None,
        past_key_values_length: int = 0,
    ) -> mx.array:
        if input_ids is not None:
            input_shape = input_ids.shape
        else:
            input_shape = inputs_embeds.shape[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self._position_ids[:, past_key_values_length:
                                              seq_length + past_key_values_length]

        # Default token_type_ids to the all-zeros buffer created in the
        # constructor. This lets users trace the model without having to pass
        # token_type_ids explicitly (see issue #5664).
        if token_type_ids is None:
            if hasattr(self, "_token_type_ids"):
                buffered_token_type_ids = self._token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = mx.repeat(
                    buffered_token_type_ids, input_shape[0], axis=0)
                token_type_ids = buffered_token_type_ids_expanded.astype(
                    mx.int8)
            else:
                # token_type_ids are embedding indices, so they must be integers
                token_type_ids = mx.zeros(input_shape, dtype=mx.int32)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"Hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: mx.array) -> mx.array:
        new_x_shape = x.shape[
            :-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.reshape(new_x_shape)
        return x.transpose([0, 2, 1, 3])
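    # Shape walk-through for transpose_for_scores (numpy standing in for mx,
    # illustration only): a (batch, seq, hidden) projection is reshaped to
    # (batch, heads, seq, head_size) so each head attends independently.

    ```python
    import numpy as np

    batch, seq, heads, head_size = 2, 5, 4, 8
    hidden = heads * head_size

    x = np.zeros((batch, seq, hidden))
    # split the hidden axis into (heads, head_size), then move heads forward
    x = x.reshape(batch, seq, heads, head_size).transpose(0, 2, 1, 3)
    assert x.shape == (batch, heads, seq, head_size)
    ```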

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        head_mask: Optional[mx.array] = None,
        encoder_hidden_states: Optional[mx.array] = None,
        encoder_attention_mask: Optional[mx.array] = None,
        past_key_value: Optional[Tuple[mx.array, ...]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[mx.array]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            # mx.concatenate is the MLX equivalent of torch.cat
            key_layer = mx.concatenate([past_key_value[0], key_layer], axis=2)
            value_layer = mx.concatenate([past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(mx.array, mx.array) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(mx.array, mx.array) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = mx.matmul(
            query_layer, key_layer.transpose([0, 1, -1, -2]))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = mx.array(key_length - 1, dtype=mx.float16).reshape(
                    -1, 1
                )
            else:
                position_ids_l = mx.arange(
                    query_length, dtype=mx.float16).reshape(-1, 1)
            position_ids_r = mx.arange(
                key_length, dtype=mx.float16).reshape(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.astype(
                query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = mx.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = mx.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = mx.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + \
                    relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel's forward)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = mx.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = mx.matmul(attention_probs, value_layer)

        context_layer = context_layer.transpose([0, 2, 1, 3])
        new_context_layer_shape = context_layer.shape[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (
            context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(
            config, position_embedding_type=position_embedding_type)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        head_mask: Optional[mx.array] = None,
        encoder_hidden_states: Optional[mx.array] = None,
        encoder_attention_mask: Optional[mx.array] = None,
        past_key_value: Optional[Tuple[mx.array, ...]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[mx.array]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        # add attentions if we output them
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def __call__(self, hidden_states: mx.array) -> mx.array:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(
                    f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BertAttention(
                config, position_embedding_type="absolute")
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        head_mask: Optional[mx.array] = None,
        encoder_hidden_states: Optional[mx.array] = None,
        encoder_attention_mask: Optional[mx.array] = None,
        past_key_value: Optional[Tuple[mx.array, ...]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[mx.array]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # add self attentions if we output attention weights
            outputs = self_attention_outputs[1:]

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be "
                    "instantiated with cross-attention layers by setting "
                    "`config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            # add cross attentions if we output attention weights
            outputs = outputs + cross_attention_outputs[1:-1]

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = [
            BertLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.gradient_checkpointing = False

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        head_mask: Optional[mx.array] = None,
        encoder_hidden_states: Optional[mx.array] = None,
        encoder_attention_mask: Optional[mx.array] = None,
        past_key_values: Optional[Tuple[Tuple[mx.array, ...], ...]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[mx.array], Dict]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + \
                        (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return dict(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def __call__(self, hidden_states: mx.array) -> mx.array:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
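# A numeric sketch of the pooling step (numpy standing in for mx, illustration
# only): BertPooler simply takes the hidden state of the first ([CLS]) token
# before the dense + tanh projection.

```python
import numpy as np

hidden_states = np.arange(24.0).reshape(2, 3, 4)  # (batch, seq, hidden)
first_token = hidden_states[:, 0]                 # (batch, hidden)
assert first_token.shape == (2, 4)
assert (first_token[1] == hidden_states[1, 0]).all()
```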


class BertModel(nn.Module):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config: ModelArgs, add_pooling_layer=True):
        super().__init__()
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def get_extended_attention_mask(
        self, attention_mask: mx.array, input_shape: Tuple[int], dtype: float = mx.float16
    ) -> mx.array:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`mx.array`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `mx.array`: The extended attention mask, with the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if len(attention_mask.shape) == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif len(attention_mask.shape) == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, a causal mask should be applied in
            #   addition to the padding mask (not implemented here)
            # - if the model is an encoder, make the mask broadcastable to
            #   [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                raise NotImplementedError(
                    "Causal masking for decoder models is not implemented")
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) "
                f"or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.astype(
            dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (
            # torch.finfo(torch.float16).min
            1.0 - extended_attention_mask) * -65504
        return extended_attention_mask
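    # A numeric sketch of the mask arithmetic above (numpy standing in for mx,
    # illustration only): a (batch, seq) padding mask of ones and zeros becomes
    # a broadcastable (batch, 1, 1, seq) additive mask -- zero where attention
    # is allowed and -65504 (the float16 minimum) where it is not, so masked
    # logits vanish after softmax.

    ```python
    import numpy as np

    mask = np.array([[1, 1, 0]])                 # last token is padding
    ext = mask[:, None, None, :].astype(np.float32)
    ext = (1.0 - ext) * -65504.0
    assert ext.shape == (1, 1, 1, 3)
    assert ext[0, 0, 0, 0] == 0.0 and ext[0, 0, 0, 2] == -65504.0
    ```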

    def get_head_mask(
        self, head_mask: Optional[mx.array], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> mx.array:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`mx.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked (`bool`, *optional*, defaults to `False`):
                Whether or not the attention scores are computed in chunks.

        Returns:
            `mx.array` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(
                head_mask, num_hidden_layers)
            if is_attention_chunked:
                # mx.array has no `unsqueeze`; use mx.expand_dims instead
                head_mask = mx.expand_dims(head_mask, -1)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        token_type_ids: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        head_mask: Optional[mx.array] = None,
        inputs_embeds: Optional[mx.array] = None,
        encoder_hidden_states: Optional[mx.array] = None,
        encoder_attention_mask: Optional[mx.array] = None,
        past_key_values: Optional[List[float]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[mx.array], dict]:
        r"""
        encoder_hidden_states  (`mx.float16` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`mx.float16` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(mx.float16))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = mx.ones(
                ((batch_size, seq_length + past_key_values_length)))

        if token_type_ids is None:
            if hasattr(self.embeddings, "_token_type_ids"):
                buffered_token_type_ids = self.embeddings._token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = mx.repeat(
                    buffered_token_type_ids, batch_size, axis=0)
                token_type_ids = buffered_token_type_ids_expanded.astype(
                    mx.int32)
            else:
                # token type ids index an embedding table, so use an integer dtype
                token_type_ids = mx.zeros(input_shape, dtype=mx.int32)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: mx.array = self.get_extended_attention_mask(
            attention_mask, input_shape)
        encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(
            head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs['last_hidden_state'] if return_dict else encoder_outputs
        pooled_output = self.pooler(
            sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return dict(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs['past_key_values'],
            hidden_states=encoder_outputs['hidden_states'],
            attentions=encoder_outputs['attentions'],
            cross_attentions=encoder_outputs['cross_attentions'],
        )


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = BertModel(args)

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: mx.array = None,
    ):
        return self.model(input_ids, attention_mask)

    @staticmethod
    def sanitize(weights):
        # remove position_ids and add model.
        return {
            (k if 'model' in k else f'model.{k}'): v
            for k, v in weights.items()
            if 'embeddings.position_ids' not in k
        }

    @property
    def layers(self):
        return self.model.layers
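For reference, the additive-mask construction in `get_extended_attention_mask` above can be sketched in plain numpy (a stand-in for mlx here; `-65504` is the float16 minimum, as in the code):

```python
import numpy as np

def extend_padding_mask(mask: np.ndarray) -> np.ndarray:
    """[batch, seq] 1/0 padding mask -> [batch, 1, 1, seq] additive mask."""
    extended = mask[:, None, None, :].astype(np.float32)
    # 0.0 where we attend, a very large negative number where we mask,
    # so adding it to the raw scores effectively removes masked positions.
    return (1.0 - extended) * -65504.0

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

mask = np.array([[1, 1, 0]])        # last position is padding
scores = np.zeros((1, 1, 1, 3))     # uniform raw scores
probs = softmax(scores + extend_padding_mask(mask))
# the masked position gets ~0 probability; the rest split the mass
```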


================================================
FILE: server/models/gemma.py
================================================
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    head_dim: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
    rope_traditional: bool = False


@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (1.0 + weight) * x.astype(weight.dtype)


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return rms_norm(x, self.weight, self.eps)
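Note that Gemma's norm scales by `(1.0 + weight)`: checkpoints store the norm weight as an offset from 1, unlike the plain RMSNorm in `layers.py`. A numpy sketch of the same computation:

```python
import numpy as np

def gemma_rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize in float32, then scale by (1 + weight), mirroring the
    # compiled rms_norm above.
    x = x.astype(np.float32)
    x = x * (1.0 / np.sqrt(np.square(x).mean(-1, keepdims=True) + eps))
    return (1.0 + weight) * x

x = np.array([[3.0, 4.0]])
out = gemma_rms_norm(x, np.zeros(2))
# with a zero weight this reduces to plain RMS normalization
```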


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.head_dim = head_dim = args.head_dim

        self.repeats = n_heads // n_kv_heads

        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if self.repeats > 1:
            keys = mx.repeat(keys, self.repeats, axis=1)
            values = mx.repeat(values, self.repeats, axis=1)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
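The `repeats` logic above implements grouped-query attention: each KV head is duplicated so that every query head has a matching key/value head. A numpy shape sketch (hypothetical sizes):

```python
import numpy as np

n_heads, n_kv_heads, L, head_dim = 8, 2, 4, 16
repeats = n_heads // n_kv_heads  # each kv head serves 4 query heads

keys = np.random.randn(1, n_kv_heads, L, head_dim)
# duplicate along the head axis, mirroring mx.repeat(keys, repeats, axis=1)
keys_expanded = np.repeat(keys, repeats, axis=1)
# np.repeat repeats each head in place, so heads 0..3 all alias kv head 0
```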


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class GemmaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        h = h * (self.args.hidden_size**0.5)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = GemmaModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)
        out = out @ self.model.embed_tokens.weight.T
        return out, cache

    @property
    def layers(self):
        return self.model.layers
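Gemma ties its input embedding and output projection: `Model.__call__` computes logits as `out @ embed_tokens.weight.T` rather than using a separate `lm_head`. A numpy sketch of the tied projection (toy sizes):

```python
import numpy as np

vocab, dim = 10, 4
embed_weight = np.random.randn(vocab, dim)

token_ids = np.array([3, 7])
h = embed_weight[token_ids]   # embedding lookup
logits = h @ embed_weight.T   # tied output projection, as in Model.__call__
# logits[i, v] is the dot product of token i's hidden state with row v
```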


================================================
FILE: server/models/layers.py
================================================
from functools import partial

import mlx.core as mx
import mlx.nn as nn


@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return rms_norm(x, self.weight, self.eps)


@partial(mx.compile, shapeless=True)
def ln_norm(x, eps, weight=None, bias=None):
    t = x.dtype
    x = x.astype(mx.float32)
    means = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)
    x = (x - means) * mx.rsqrt(var + eps)
    x = x.astype(t)
    return (weight * x + bias) if weight is not None else x


class LayerNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if "weight" in self:
            return ln_norm(x, self.eps, self.weight, self.bias)
        else:
            return ln_norm(x, self.eps)
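The `ln_norm` kernel above is standard layer normalization computed in float32. An equivalent numpy sketch:

```python
import numpy as np

def layer_norm(x, weight=None, bias=None, eps=1e-5):
    # normalize over the last axis, then optionally apply the affine transform
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    return (weight * x + bias) if weight is not None else x

x = np.array([[1.0, 2.0, 3.0]])
out = layer_norm(x)
# output has zero mean and (roughly) unit variance per row
```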


================================================
FILE: server/models/llama.py
================================================
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .layers import RMSNorm


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.repeats = n_heads // n_kv_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if self.repeats > 1:
            keys = mx.repeat(keys, self.repeats, axis=1)
            values = mx.repeat(values, self.repeats, axis=1)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class LlamaModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)
        return self.lm_head(out), cache

    @staticmethod
    def sanitize(weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
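`nn.MultiHeadAttention.create_additive_causal_mask`, used in `LlamaModel.__call__`, builds an upper-triangular additive mask so position i cannot attend to positions j > i. A numpy equivalent (using -1e9 as a stand-in for the large negative fill value):

```python
import numpy as np

def additive_causal_mask(L: int) -> np.ndarray:
    # 0 on and below the diagonal, a large negative value strictly above it
    mask = np.triu(np.ones((L, L)), k=1)
    return mask * -1e9

m = additive_causal_mask(3)
# row i has zeros up to column i and -1e9 afterwards
```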


================================================
FILE: server/py.typed
================================================



================================================
FILE: server/requirements.txt
================================================
chromadb==0.4.23
huggingface_hub==0.20.3
mlx==0.4.0
mlx_data==0.0.2
transformers==4.38.1
pyinstaller==6.4.0


================================================
FILE: server/retriever/document.py
================================================
from typing import Any, Literal, Optional


class Document:
    """Class for storing a piece of text and associated metadata."""

    page_content: str
    """String text."""
    metadata: dict = dict()
    """Arbitrary metadata about the page content (e.g., source, relationships to other
        documents, etc.).
    """
    type: Literal["Document"] = "Document"

    def __init__(self, page_content: str, metadata: Optional[dict] = None, **kwargs: Any) -> None:
        """Pass page_content in as positional or named arg."""
        self.page_content = page_content
        self.metadata = metadata or dict()

        for key, value in kwargs.items():
            setattr(self, key, value)


================================================
FILE: server/retriever/embeddings.py
================================================
import os
import mlx.core as mx
import mlx.nn as nn

from transformers import PreTrainedTokenizer
from abc import ABC, abstractmethod
from typing import Any, List

from ..utils import load, get_mlx_path, convert


class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""


class E5Embeddings(Embeddings):

    model: Any = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', quantize: bool = False):
        mlx_path = get_mlx_path(hf_path, quantize=quantize)
        if not os.path.isdir(mlx_path):
            convert(hf_path, mlx_path, quantize=quantize)
        self.model, self.tokenizer = load(mlx_path)

    def _average_pool(self, last_hidden_states: mx.array,
                      attention_mask: mx.array) -> mx.array:
        last_hidden = mx.where(~attention_mask[..., None].astype(dtype=mx.bool_),
                               0.0, last_hidden_states)
        return mx.sum(last_hidden, axis=1) / mx.sum(attention_mask, axis=1, keepdims=True)

    def embed_documents(self, texts: List[str], batch_size: int = 8) -> List[List[float]]:
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_embeddings = self.embed_query(batch_texts, batch=True)
            embeddings.extend(batch_embeddings)
        return embeddings

    def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:
        tokens = self.tokenizer(texts, max_length=512, padding=True,
                                truncation=True, return_tensors='np',
                                return_attention_mask=True)
        tokens = {key: mx.array(v) for key, v in tokens.items()}
        outputs = self.model(**tokens)
        embeddings = self._average_pool(
            outputs['last_hidden_state'], tokens['attention_mask'])
        embeddings = embeddings / mx.linalg.norm(
            embeddings, ord=2, axis=1, keepdims=True)

        if batch:
            return embeddings.tolist()  # -> List[List[float]]

        return embeddings[0].tolist()  # -> List[float]


class ChatEmbeddings(Embeddings):

    model: nn.Module = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        h = self.model.embed_tokens(mx.array(
            self.tokenizer.encode(text, add_special_tokens=False)))
        # normalized to have unit length
        h = mx.mean(h, axis=0)
        h = h / mx.linalg.norm(h)
        return h.tolist()


================================================
FILE: server/retriever/loader.py
================================================
import os
import glob
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor

from .document import Document


def directory_loader(directory: Optional[str] = None) -> Optional[List[Document]]:
    if directory is not None and os.path.exists(directory):
        allowed_extensions = ['.txt', '.md', '.csv', '.json', '.xml', '.ts']

        def read_file(file_path):
            _, file_extension = os.path.splitext(file_path)
            if file_extension.lower() in allowed_extensions:
                with open(file_path, 'r', encoding='utf-8') as file:
                    return Document(page_content=file.read(), metadata={'source': file_path})

        files = glob.glob(os.path.join(directory, '**', '*.*'), recursive=True)

        with ThreadPoolExecutor() as executor:
            return list(filter(None, executor.map(read_file, files)))
    else:
        raise FileNotFoundError(f"Directory '{directory}' does not exist.")


================================================
FILE: server/retriever/splitter.py
================================================
import re
import copy

from abc import ABC, abstractmethod
from typing import (
    Any,
    List,
    Optional,
    Callable,
    Iterable
)

from .document import Document


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1]
                      for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
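The capturing group in `re.split` is what keeps the delimiters: they come back as separate list items, which the code above then glues onto the start of the following chunk. A small runnable demo of that recombination:

```python
import re

def split_keep_separator(text: str, separator: str):
    # mirror of _split_text_with_regex with keep_separator=True
    _splits = re.split(f"({separator})", text)
    # odd indices are separators; prepend each to the chunk that follows it
    splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
    if len(_splits) % 2 == 0:
        splits += _splits[-1:]
    return [_splits[0]] + splits

chunks = split_keep_separator("one. two. three", r"\. ")
```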


class TextSplitter(ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
            strip_whitespace: If `True`, strips whitespace from the start and end of
                              every document
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    offset = index + previous_chunk_len - self._chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        if self._strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    print(f"Created a chunk of size {total}, " +
                          f"which is longer than the specified {self._chunk_size}")
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len +
                            (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(
            separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        return self._split_text(text, self._separators)
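The recursive strategy above can be sketched standalone: split on the coarsest separator that appears in the text, and descend to finer separators only for pieces that are still too long. A minimal illustrative version (the `recursive_split` helper below is hypothetical and omits the `_merge_splits` overlap/merge step the real class performs):

```python
from typing import List


def recursive_split(text: str, separators: List[str], chunk_size: int) -> List[str]:
    # Illustrative sketch of RecursiveCharacterTextSplitter's strategy:
    # split on the first separator, then recurse into any piece that is
    # still longer than chunk_size using the remaining, finer separators.
    if len(text) <= chunk_size or not separators:
        return [text]
    sep, rest = separators[0], separators[1:]
    parts = text.split(sep) if sep else list(text)
    chunks: List[str] = []
    for part in parts:
        if len(part) > chunk_size:
            chunks.extend(recursive_split(part, rest, chunk_size))
        elif part:
            chunks.append(part)
    return chunks


chunks = recursive_split("one two\n\nthree four five", ["\n\n", " ", ""], 10)
```

The first piece fits within the chunk size and is kept whole; the second is recursively re-split on spaces. Unlike the class above, this sketch does not re-merge small pieces back up to `chunk_size` with overlap.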


================================================
FILE: server/retriever/vectorstore.py
================================================
import uuid
import functools
import mlx.core as mx

import chromadb
import chromadb.config

from chromadb.utils.batch_utils import create_batches
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
from typing import (
    Any,
    List,
    Dict,
    TypeVar,
    Callable,
    Iterable,
    Optional,
    Tuple,
    Type,
)
from .document import Document
from .embeddings import Embeddings

# Placeholder so `Chroma` can appear in annotations before the class below is defined;
# the class definition then rebinds the name.
Chroma = TypeVar('Chroma', bound='Chroma')


DEFAULT_K = 4  # Number of Documents to return.


def _results_to_docs(results: Any) -> List[Document]:
    return [doc for doc, _ in _results_to_docs_and_scores(results)]


def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    return [
        # TODO: Chroma can do batch querying,
        # we shouldn't hard code to the 1st result
        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
        for result in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]


def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Validate exactly one arg in each group is not None."""
            counts = [
                sum(1 for arg in arg_group if kwargs.get(arg) is not None)
                for arg_group in arg_groups
            ]
            invalid_groups = [
                i for i, count in enumerate(counts) if count != 1]
            if invalid_groups:
                invalid_group_names = [
                    ", ".join(arg_groups[i]) for i in invalid_groups]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(invalid_group_names)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
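The decorator above enforces that exactly one keyword argument per group is supplied (for example, `query_texts` XOR `query_embeddings`). A compact self-contained illustration of the same pattern; the `search` function is hypothetical, not part of this module:

```python
import functools
from typing import Any, Callable, Tuple


def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    # Same idea as the decorator above: each named group must receive
    # exactly one non-None keyword argument, otherwise raise.
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for group in arg_groups:
                provided = sum(kwargs.get(name) is not None for name in group)
                if provided != 1:
                    raise ValueError(f"Exactly one of {group} must be given")
            return func(*args, **kwargs)
        return wrapper
    return decorator


@xor_args(("query_text", "query_embedding"))
def search(query_text=None, query_embedding=None):
    return "text" if query_text is not None else "embedding"
```

Calling `search(query_text="hi")` succeeds, while passing both arguments (or neither) raises `ValueError`.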


def cosine_similarity(
    X: mx.array, T: mx.array, axis: int = 1
) -> mx.array:
    """Row-wise cosine similarity between two equal-width matrices."""
    X, T = mx.array(X), mx.array(T)
    X_norm = mx.linalg.norm(X, axis=axis)
    T_norm = mx.linalg.norm(T, axis=axis)
    similarity = X @ T.T / mx.outer(X_norm, T_norm)
    return similarity
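For intuition, each entry of the matrix above is `sim(x, t) = (x · t) / (|x| |t|)`. A dependency-free check of a single pair (pure Python, no MLX; the `cosine` helper is illustrative only):

```python
import math


def cosine(x, t):
    # sim(x, t) = dot(x, t) / (||x|| * ||t||), for one pair of vectors
    dot = sum(a * b for a, b in zip(x, t))
    return dot / (math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in t)))


score = cosine([1.0, 0.0], [1.0, 1.0])  # vectors 45 degrees apart -> 1/sqrt(2)
```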


def maximal_marginal_relevance(
    query_embedding: mx.array,
    embedding_list: mx.array,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance."""
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = mx.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(mx.argmax(similarity_to_query).tolist())
    idxs = [most_similar]
    selected = mx.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -mx.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = mx.concatenate([
            selected, embedding_list[idx_to_add:idx_to_add+1]], axis=0)
    return idxs
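The selection loop above trades relevance against redundancy: each remaining candidate is scored as `lambda_mult * sim(query, cand) - (1 - lambda_mult) * max sim(cand, selected)`. A dependency-free sketch of the same greedy loop (pure Python; `_cos` and `mmr` are illustrative helpers, not this module's API):

```python
import math


def _cos(x, t):
    dot = sum(a * b for a, b in zip(x, t))
    return dot / (math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in t)))


def mmr(query, candidates, lambda_mult=0.5, k=2):
    # Greedy MMR: seed with the most query-similar vector, then repeatedly
    # add the candidate maximizing relevance minus redundancy.
    sims = [_cos(query, c) for c in candidates]
    idxs = [max(range(len(candidates)), key=lambda i: sims[i])]
    while len(idxs) < min(k, len(candidates)):
        best, best_i = -float("inf"), -1
        for i, c in enumerate(candidates):
            if i in idxs:
                continue
            redundancy = max(_cos(c, candidates[j]) for j in idxs)
            score = lambda_mult * sims[i] - (1 - lambda_mult) * redundancy
            if score > best:
                best, best_i = score, i
        idxs.append(best_i)
    return idxs


# With a diversity-leaning lambda, the near-duplicate of the first pick loses
# to the orthogonal (dissimilar) candidate.
picked = mmr([1.0, 0.0], [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]], lambda_mult=0.3, k=2)
```

Raising `lambda_mult` toward 1 reverts to plain similarity ranking, which is the same trade-off the `max_marginal_relevance_search` methods below expose.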


class Chroma:
    """Chroma vector store wrapper.

    Key entry points: similarity_search and max_marginal_relevance_search.
    """
    _DEFAULT_COLLECTION_NAME = "mlx-chat-app"

    def __init__(
        self,
        collection_name: str = _DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromadb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:

        if client is not None:
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            if client_settings:
                # If client_settings is provided with persist_directory specified,
                # then it is "in-memory and persisting to disk" mode.
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None:
                    # Maintain backwards compatibility with chromadb < 0.4.0
                    major, minor, _ = chromadb.__version__.split(".")
                    if int(major) == 0 and int(minor) < 4:
                        client_settings.chroma_db_impl = "duckdb+parquet"

                _client_settings = client_settings
            elif persist_directory:
                # Maintain backwards compatibility with chromadb < 0.4.0
                major, minor, _ = chromadb.__version__.split(".")
                if int(major) == 0 and int(minor) < 4:
                    _client_settings = chromadb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromadb.config.Settings(
                        is_persistent=True)
                _client_settings.persist_directory = persist_directory
            else:
                _client_settings = chromadb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromadb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )

        self._embedding_function = embedding_function
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=None,
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn

    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection."""
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        embeddings = None
        texts = list(texts)
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        if metadatas:
            # fill metadatas with empty dicts if somebody
            # did not specify metadata for all texts
            length_diff = len(texts) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx]
                        for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                try:
                    self._collection.upsert(
                        metadatas=metadatas,
                        embeddings=embeddings_with_metadatas,
                        documents=texts_with_metadatas,
                        ids=ids_with_metadata,
                    )
                except ValueError as e:
                    if "Expected metadata value to be" in str(e):
                        msg = (
                            "Try filtering complex metadata from the document using "
                            "langchain_community.vectorstores.utils.filter_complex_metadata."
                        )
                        raise ValueError(e.args[0] + "\n\n" + msg)
                    else:
                        raise e
            if empty_ids:
                texts_without_metadatas = [texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=texts_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=texts,
                ids=ids,
            )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query text and cosine distance in float for each.
            Lower score represents more similarity.
        """
        if self._embedding_function is None:
            results = self.__query_collection(
                query_texts=[query],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )

        return _results_to_docs_and_scores(results)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """

        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=fetch_k,
            where=filter,
            where_document=where_document,
            include=["metadatas", "documents", "distances", "embeddings"],
            **kwargs,
        )
        mmr_selected = maximal_marginal_relevance(
            mx.array(embedding, dtype=mx.float32),
            mx.array(results["embeddings"][0]),
            k=k,
            lambda_mult=lambda_mult,
        )

        candidates = _results_to_docs(results)

        selected_results = [r for i, r in enumerate(
            candidates) if i in mmr_selected]
        return selected_results

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )

        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            where_document=where_document,
        )
        return docs

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Gets the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                   E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                    Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the documents.
                            E.g. `{$contains: "hello"}`. Optional.
            include: A list of what to include in the results.
                     Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                     Ids are always included.
                     Defaults to `["metadatas", "documents"]`. Optional.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }

        if include is not None:
            kwargs["include"] = include

        return self._collection.get(**kwargs)

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id (str): ID of the document to update.
            document (Document): Document to update.
        """
        return self.update_documents([document_id], [document])

    def update_documents(self, ids: List[str], documents: List[Document]) -> None:
        """Update a document in the collection.

        Args:
            ids (List[str]): List of ids of the document to update.
            documents (List[Document]): List of documents to update.
        """
        text = [document.page_content for document in documents]
        metadata = [document.metadata for document in documents]
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on creation."
            )
        embeddings = self._embedding_function.embed_documents(text)

        if hasattr(
            self._collection._client, "max_batch_size"
        ):
            for batch in create_batches(
                api=self._collection._client,
                ids=ids,
                metadatas=metadata,
                documents=text,
                embeddings=embeddings,
            ):
                self._collection.update(
                    ids=batch[0],
                    embeddings=batch[1],
                    documents=batch[3],
                    metadatas=batch[2],
                )
        else:
            self._collection.update(
                ids=ids,
                embeddings=embeddings,
                documents=text,
                metadatas=metadata,
            )

    @classmethod
    def from_texts(
        cls: Type[Chroma],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a raw documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            texts (List[str]): List of texts to add to the collection.
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            client_settings (Optional[chromadb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                                                  Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        if hasattr(
            chroma_collection._client, "max_batch_size"
        ):
            for batch in create_batches(
                api=chroma_collection._client,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            ):
                chroma_collection.add_texts(
                    texts=batch[3] if batch[3] else [],
                    metadatas=batch[2] if batch[2] else None,
                    ids=batch[0],
                )
        else:
            chroma_collection.add_texts(
                texts=texts, metadatas=metadatas, ids=ids)
        return chroma_collection

    @classmethod
    def from_documents(
        cls: Type[Chroma],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            documents (List[Document]): List of documents to add to the vectorstore.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            client_settings (Optional[chromadb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                                                  Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
        """
        self._collection.delete(ids=ids)


================================================
FILE: server/server.py
================================================
import os
import sys
import json
import time
import uuid

import mlx.core as mx
import mlx.nn as nn

from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Dict, Optional
from transformers import PreTrainedTokenizer

from .utils import load, generate_step, get_mlx_path, convert

from .retriever.loader import directory_loader
from .retriever.splitter import RecursiveCharacterTextSplitter
from .retriever.vectorstore import Chroma
from .retriever.embeddings import ChatEmbeddings, E5Embeddings

_model: Optional[nn.Module] = None
_tokenizer: Optional[PreTrainedTokenizer] = None
_database: Optional[Chroma] = None


def load_model(model_path: str, adapter_file: Optional[str] = None):
    global _model
    global _tokenizer

    models_to_quantize = ['mistral', 'llama', 'gemma']
    quantize = any(name in model_path for name in models_to_quantize)

    mlx_path = get_mlx_path(model_path, quantize=quantize)
    if not os.path.isdir(mlx_path):
        convert(model_path, mlx_path, quantize=quantize)

    _model, _tokenizer = load(mlx_path, adapter_file=adapter_file)


def index_directory(directory: str, use_embedding: bool = True):
    global _database
    start_t = time.time()
    raw_docs = directory_loader(directory)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512, chunk_overlap=32, add_start_index=True
    )
    embedding = E5Embeddings(quantize=True) if use_embedding else ChatEmbeddings(
        model=_model.model, tokenizer=_tokenizer)
    splits = text_splitter.split_documents(raw_docs)
    _database = Chroma.from_documents(
        documents=splits,
        embedding=embedding
    )
    print(f'>> indexed {len(splits)} documents in',
          f'{time.time() - start_t:.2f}s', flush=True)


def create_response(chat_id, prompt, tokens, text):
    response = {
        'id': chat_id,
        'object': 'chat.completion',
        'created': int(time.time()),
        'model': _model.model_type,
        'system_fingerprint': f'fp_{uuid.uuid4()}',
        'choices': [
            {
                'index': 0,
                'message': {
                    'role': 'assistant',
                    'content': text,
                },
                'logprobs': None,
                'finish_reason': None,
            }
        ],
        'usage': {
            'prompt_tokens': len(prompt),
            'completion_tokens': len(tokens),
            'total_tokens': len(prompt) + len(tokens),
        },
    }
    return response


def format_messages(messages: List[Dict], indexed_files: Optional[str],
                    instructions: Optional[Dict]):
    instructions = instructions or {}
    personalization = instructions.get(
        'personalization', '').strip().replace('\n', '; ')
    response = instructions.get('response', '').strip().replace('\n', '; ')

    # Note: backslashes are not allowed inside f-string expressions before
    # Python 3.12, so build the context string outside the f-string.
    if indexed_files:
        files = indexed_files.strip().replace('\n', '; ')
        context = f'with background knowledge of {files}'
    else:
        context = ''
    audience = personalization if personalization else 'general'
    style = response if response else 'technical, accurate, and professional'

    messages[-1]['content'] = f"""
<Context>
  you are my personalized AI chatbot {context}
</Context>
<Objective>
  respond to the following: {messages[-1]['content']}
</Objective>
<Style>
  {style}
</Style>
<Tone>
  friendly, helpful, and confident
</Tone>
<Audience>
  {audience}
</Audience>
<Response>
  brief, concise, and to the point. Please don't start with "Sure, ..."
</Response>
""".strip()


class APIHandler(BaseHTTPRequestHandler):
    def _set_headers(self, status_code=200):
        self.send_response(status_code)
        self.send_header('Content-type', 'application/json')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.end_headers()

    def do_OPTIONS(self):
        self._set_headers(204)

    def do_POST(self):
        """
        Endpoint: /api/index
            Desc: indexes the directory
            Body:
                {
                    directory: str
                }

        Endpoint: /api/init
            Desc: initializes the model
            Body:
                {
                    model: str
                }

        Endpoint: /api/query
            Desc: handles messages requests (with directory index)
            Body:
                {
                    messages: [ { role: str, content: str } ],
                    max_tokens: int,
                    repetition_penalty: float,
                    repetition_context_size: int,
                    temperature: float,
                    top_p: float,
                    instructions: {
                        personalization: str,
                        response: str
                    },
                    directory: str
                }
        """
        try:
            post_data = self.rfile.read(int(self.headers['Content-Length']))
            body = json.loads(post_data.decode('utf-8'))
            method = {
                '/api/index': self.index,
                '/api/query': self.query,
                '/api/init': self.init,
            }
            handle = method.get(self.path, None)
            if handle is None:
                self._set_headers(404)
                self.wfile.write(b'Not Found')
                return

            response = handle(body)
            self._set_headers(200)
            self.wfile.write(json.dumps(response).encode('utf-8'))

        except Exception as e:
            print(f"Error: {e}", flush=True)
            self._set_headers(500)
            self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8'))

    def index(self, body):
        directory = body.get('directory', None)
        index_directory(directory)
        return {'directory': directory}

    def init(self, body):
        model = body.get('model', None)
        load_model(model)
        return {'model': model}

    def query(self, body):
        chat_id = f'chatcmpl-{uuid.uuid4()}'

        directory = body.get('directory', None)
        messages = body.get('messages', [])
        instructions = body.get('instructions', None)

        indexed_files = ''
        if directory:
            # empirically better than `similarity_search`
            docs = _database.max_marginal_relevance_search(
                messages[-1]['content'],
                k=6  # number of documents to return
            )
            indexed_files = '\n'.join([doc.page_content for doc in docs])

            print(body, flush=True)
            print(('\n'+'--'*10+'\n').join([
                f'{doc.metadata}\n{doc.page_content}' for doc in docs]), flush=True)

        format_messages(messages, indexed_files, instructions)
        print(messages, flush=True)

        prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        ), add_special_tokens=True))

        max_tokens = body.get('max_tokens', 100)
        repetition_penalty = body.get('repetition_penalty', None)
        repetition_context_size = body.get('repetition_context_size', 20)
        temperature = body.get('temperature', 1.0)
        top_p = body.get('top_p', 1.0)

        tokens = []
        REPLACEMENT_CHAR = '\ufffd'
        for (token, prob), _ in zip(
            generate_step(
                prompt,
                _model,
                temperature,
                repetition_penalty,
                repetition_context_size,
                top_p,
            ),
            range(max_tokens),
        ):
            if token == _tokenizer.eos_token_id:
                break
            tokens.append(token.item())

        text = _tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
        # TODO: GEMMA IS OBSESSED WITH "Sure, ..."
        if text.startswith('Sure, '):
            lines = text.split('\n')
            lines[0] = lines[0].replace('Sure, ', '', 1).capitalize()
            text = '\n'.join(lines)
        return create_response(chat_id, prompt, tokens, text)


def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
    server_address = (host, port)
    httpd = server_class(server_address, handler_class)
    print(f'Starting httpd at {host} on port {port}...', flush=True)
    httpd.serve_forever()


def main():
    if len(sys.argv) < 2:
        print(
            "Usage: python script.py [--host <host_address>] [--port <port_number>]")
        sys.exit(1)

    args = {
        '--host': '127.0.0.1',
        '--port': 8080
    }

    i = 1
    while i < len(sys.argv):
        if sys.argv[i] in args and i + 1 < len(sys.argv):
            args[sys.argv[i]] = sys.argv[i + 1]
            i += 2
        else:
            print(f"Invalid or incomplete argument: {sys.argv[i]}")
            sys.exit(1)

    host = args['--host']
    port = int(args['--port'])

    print(f'>> starting server on {host}:{port}', flush=True)
    run(host, port)


if __name__ == '__main__':
    main()
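The `do_POST` routing above can be smoke-tested with a stripped-down handler that uses the same path-to-method dispatch, exercised over localhost. `MiniHandler` and `'example/model'` are illustrative stand-ins, not part of the repo:

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

# Stripped-down handler mirroring APIHandler.do_POST's dispatch pattern.
class MiniHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        body = json.loads(self.rfile.read(int(self.headers['Content-Length'])))
        routes = {'/api/init': lambda b: {'model': b.get('model')}}
        handle = routes.get(self.path)
        if handle is None:
            self.send_response(404)
            self.end_headers()
            return
        payload = json.dumps(handle(body)).encode('utf-8')
        self.send_response(200)
        self.send_header('Content-type', 'application/json')
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, *args):  # keep the test output quiet
        pass

server = HTTPServer(('127.0.0.1', 0), MiniHandler)  # port 0 = any free port
threading.Thread(target=server.serve_forever, daemon=True).start()

request = urllib.request.Request(
    f'http://127.0.0.1:{server.server_port}/api/init',
    data=json.dumps({'model': 'example/model'}).encode('utf-8'),
    headers={'Content-Type': 'application/json'},
)
with urllib.request.urlopen(request) as resp:
    result = json.load(resp)
server.shutdown()
print(result)  # {'model': 'example/model'}
```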


================================================
FILE: server/utils.py
================================================
import os
import copy
import glob
import shutil
import importlib
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

# Constants
MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}

MAX_FILE_SIZE_GB = 5

linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear)
    and m.weight.shape[0]
    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"server.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs


def get_model_path(path_or_hf_repo: str) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                allow_patterns=[
                    "*.json",
                    "*.safetensors",
                    "*.py",
                    "tokenizer.model",
                    "*.tiktoken",
                ],
            )
        )
    return model_path
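The resolution order in `get_model_path` can be sketched without the actual Hugging Face download: an existing local path wins, and anything else goes through a fetch step. Here `resolve` and the injected `fetch` callback are hypothetical stand-ins for the function and `snapshot_download`:

```python
from pathlib import Path

# Local-path-first resolution, with the download step injected as a callback.
def resolve(path_or_repo, fetch):
    path = Path(path_or_repo)
    return path if path.exists() else Path(fetch(path_or_repo))

print(resolve('.', fetch=lambda repo: '/cache/' + repo))          # existing local dir
print(resolve('org/model', fetch=lambda repo: '/cache/' + repo))  # goes through fetch
```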


def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
    """
    Apply repetition penalty to specific logits based on the given context.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        logits (mx.array): The logits produced by the language model.
        generated_tokens (any): A list of N previous tokens.
        penalty (float): The repetition penalty factor to be applied.

    Returns:
        logits (mx.array): Logits with repetition penalty applied to generated tokens.
    """
    if len(generated_tokens) > 0:
        indices = mx.array([token for token in generated_tokens])
        selected_logits = logits[:, indices]
        selected_logits = mx.where(
            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
        )
        logits[:, indices] = selected_logits
    return logits
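The penalty rule above can be restated over a plain Python list instead of an `mx.array` (the helper name `penalize` is illustrative): previously generated tokens have positive logits divided by the penalty and negative logits multiplied by it, pushing both toward "less likely":

```python
# Repetition penalty (https://arxiv.org/abs/1909.05858) on a single row of
# logits: shrink positive seen-token logits, amplify negative ones downward.
def penalize(logits, generated_tokens, penalty):
    for token in set(generated_tokens):
        value = logits[token]
        logits[token] = value * penalty if value < 0 else value / penalty
    return logits

print(penalize([2.0, -2.0, 1.0], [0, 1], penalty=2.0))  # [1.0, -4.0, 1.0]
```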


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.
        repetition_penalty (float, optional): The penalty factor for repeating tokens.
        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
        top_p (float, optional): Nucleus sampling threshold; 1.0 disables it (default 1.0).

    Yields:
        Generator[Tuple[mx.array, mx.array]]: A generator producing
        one token and probability per call.
    """

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
        softmax_logits = mx.softmax(logits)

        if temp == 0:
            token = mx.argmax(logits, axis=-1)
        else:
            if top_p > 0 and top_p < 1.0:
                if (
                    logits.dtype == mx.bfloat16
                ):  # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfl
gitextract_9znslswk/

├── .github/
│   └── workflows/
│       └── lint.yml
├── .gitignore
├── .vscode/
│   └── settings.json
├── CODEOWNERS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app/
│   ├── .eslintrc.cjs
│   ├── assets/
│   │   └── icon.icns
│   ├── components.json
│   ├── dprint.json
│   ├── mac/
│   │   └── entitlements.mac.inherit.plist
│   ├── main/
│   │   ├── main.ts
│   │   ├── preload.ts
│   │   ├── renderer.d.ts
│   │   ├── splash/
│   │   │   ├── index.css
│   │   │   └── index.html
│   │   └── tsconfig.json
│   ├── next.config.js
│   ├── notarize.js
│   ├── package.json
│   ├── postcss.config.js
│   ├── src/
│   │   ├── AppProvider.tsx
│   │   ├── app/
│   │   │   ├── globals.css
│   │   │   ├── layout.tsx
│   │   │   ├── page.tsx
│   │   │   └── settings/
│   │   │       └── page.tsx
│   │   ├── components/
│   │   │   ├── chat/
│   │   │   │   ├── Chat.tsx
│   │   │   │   ├── ChatInput.tsx
│   │   │   │   ├── ChatMessage.tsx
│   │   │   │   ├── ChatMessages.tsx
│   │   │   │   └── SystemMessage.tsx
│   │   │   ├── options/
│   │   │   │   ├── SelectDirectory.tsx
│   │   │   │   └── SelectModel.tsx
│   │   │   └── ui/
│   │   │       ├── button.tsx
│   │   │       ├── input.tsx
│   │   │       ├── resizable.tsx
│   │   │       ├── select.tsx
│   │   │       ├── textarea.tsx
│   │   │       └── tooltip.tsx
│   │   ├── constants/
│   │   │   └── chat.tsx
│   │   └── lib/
│   │       ├── hooks.ts
│   │       ├── store.ts
│   │       └── utils.ts
│   ├── tailwind.config.main.js
│   ├── tailwind.config.ts
│   └── tsconfig.json
├── runner.py
├── runner.sh
└── server/
    ├── __init__.py
    ├── convert.py
    ├── models/
    │   ├── __init__.py
    │   ├── base.py
    │   ├── bert.py
    │   ├── gemma.py
    │   ├── layers.py
    │   └── llama.py
    ├── py.typed
    ├── requirements.txt
    ├── retriever/
    │   ├── document.py
    │   ├── embeddings.py
    │   ├── loader.py
    │   ├── splitter.py
    │   └── vectorstore.py
    ├── server.py
    └── utils.py
SYMBOL INDEX (197 symbols across 26 files)

FILE: app/main/main.ts
  function handleSetTitle (line 25) | function handleSetTitle(event: any, title: string) {
  class ServerManager (line 34) | class ServerManager {
    method findOpenPort (line 38) | private findOpenPort(startingPort: number): Promise<number> {
    method runPythonServer (line 54) | private runPythonServer(port: number): any {
    method start (line 74) | start(model: string): Promise<void> {
    method stop (line 119) | stop(): void {
  method click (line 343) | click() {

FILE: app/main/renderer.d.ts
  type Window (line 4) | interface Window {

FILE: app/src/AppProvider.tsx
  function StoreProvider (line 16) | function StoreProvider({

FILE: app/src/app/layout.tsx
  function RootLayout (line 16) | function RootLayout({

FILE: app/src/app/page.tsx
  function Home (line 35) | function Home() {

FILE: app/src/app/settings/page.tsx
  type SETTINGS (line 28) | enum SETTINGS {
  function SettingsOption (line 33) | function SettingsOption({
  function GeneralSettings (line 71) | function GeneralSettings() {
  function PromptSettings (line 128) | function PromptSettings() {
  function Settings (line 176) | function Settings() {

FILE: app/src/components/ui/button.tsx
  type ButtonProps (line 40) | interface ButtonProps

FILE: app/src/components/ui/input.tsx
  type InputProps (line 6) | interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {}

FILE: app/src/components/ui/textarea.tsx
  type TextareaProps (line 6) | interface TextareaProps extends React.TextareaHTMLAttributes<HTMLTextAre...

FILE: app/src/constants/chat.tsx
  type ChatMessage (line 1) | type ChatMessage = {

FILE: app/src/lib/hooks.ts
  function usePrevious (line 25) | function usePrevious<T>(value: T): T | undefined {
  function convertToNiceShortcut (line 33) | function convertToNiceShortcut(shortcut: string) {
  function useKeyboardShortcut (line 40) | function useKeyboardShortcut() {

FILE: app/src/lib/store.ts
  type AppStore (line 45) | type AppStore = ReturnType<typeof makeStore>;
  type RootState (line 47) | type RootState = ReturnType<AppStore['getState']>;
  type AppDispatch (line 48) | type AppDispatch = AppStore['dispatch'];

FILE: app/src/lib/utils.ts
  function cn (line 9) | function cn(...inputs: ClassValue[]) {

FILE: server/convert.py
  function configure_parser (line 6) | def configure_parser() -> argparse.ArgumentParser:

FILE: server/models/base.py
  class BaseModelArgs (line 6) | class BaseModelArgs:
    method from_dict (line 8) | def from_dict(cls, params):

FILE: server/models/bert.py
  class ModelArgs (line 13) | class ModelArgs(BaseModelArgs):
  function apply_chunking_to_forward (line 40) | def apply_chunking_to_forward(
  class BertEmbeddings (line 122) | class BertEmbeddings(nn.Module):
    method __init__ (line 125) | def __init__(self, config):
    method __call__ (line 144) | def __call__(
  class BertSelfAttention (line 191) | class BertSelfAttention(nn.Module):
    method __init__ (line 192) | def __init__(self, config, position_embedding_type=None):
    method transpose_for_scores (line 220) | def transpose_for_scores(self, x: mx.array) -> mx.array:
    method __call__ (line 226) | def __call__(
  class BertSelfOutput (line 342) | class BertSelfOutput(nn.Module):
    method __init__ (line 343) | def __init__(self, config):
    method __call__ (line 350) | def __call__(self, hidden_states: mx.array, input_tensor: mx.array) ->...
  class BertAttention (line 357) | class BertAttention(nn.Module):
    method __init__ (line 358) | def __init__(self, config, position_embedding_type=None):
    method __call__ (line 365) | def __call__(
  class BertIntermediate (line 390) | class BertIntermediate(nn.Module):
    method __init__ (line 391) | def __init__(self, config):
    method __call__ (line 396) | def __call__(self, hidden_states: mx.array) -> mx.array:
  class BertOutput (line 402) | class BertOutput(nn.Module):
    method __init__ (line 403) | def __init__(self, config):
    method __call__ (line 410) | def __call__(self, hidden_states: mx.array, input_tensor: mx.array) ->...
  class BertLayer (line 417) | class BertLayer(nn.Module):
    method __init__ (line 418) | def __init__(self, config):
    method __call__ (line 434) | def __call__(
    method feed_forward_chunk (line 504) | def feed_forward_chunk(self, attention_output):
  class BertEncoder (line 510) | class BertEncoder(nn.Module):
    method __init__ (line 511) | def __init__(self, config):
    method __call__ (line 519) | def __call__(
  class BertPooler (line 603) | class BertPooler(nn.Module):
    method __init__ (line 604) | def __init__(self, config):
    method __call__ (line 609) | def __call__(self, hidden_states: mx.array) -> mx.array:
  class BertModel (line 618) | class BertModel(nn.Module):
    method __init__ (line 631) | def __init__(self, config: ModelArgs, add_pooling_layer=True):
    method get_input_embeddings (line 639) | def get_input_embeddings(self):
    method set_input_embeddings (line 642) | def set_input_embeddings(self, value):
    method get_extended_attention_mask (line 645) | def get_extended_attention_mask(
    method get_head_mask (line 690) | def get_head_mask(
    method __call__ (line 718) | def __call__(
  class Model (line 847) | class Model(nn.Module):
    method __init__ (line 848) | def __init__(self, args: ModelArgs):
    method __call__ (line 853) | def __call__(
    method sanitize (line 861) | def sanitize(weights):
    method layers (line 868) | def layers(self):

FILE: server/models/gemma.py
  class ModelArgs (line 12) | class ModelArgs(BaseModelArgs):
  function rms_norm (line 27) | def rms_norm(x, weight, eps):
  class RMSNorm (line 33) | class RMSNorm(nn.Module):
    method __init__ (line 34) | def __init__(self, dims: int, eps: float = 1e-5):
    method __call__ (line 39) | def __call__(self, x):
  class Attention (line 43) | class Attention(nn.Module):
    method __init__ (line 44) | def __init__(self, args: ModelArgs):
    method __call__ (line 67) | def __call__(
  class MLP (line 104) | class MLP(nn.Module):
    method __init__ (line 105) | def __init__(self, dim, hidden_dim):
    method __call__ (line 111) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 115) | class TransformerBlock(nn.Module):
    method __init__ (line 116) | def __init__(self, args: ModelArgs):
    method __call__ (line 126) | def __call__(
  class GemmaModel (line 139) | class GemmaModel(nn.Module):
    method __init__ (line 140) | def __init__(self, args: ModelArgs):
    method __call__ (line 152) | def __call__(
  class Model (line 174) | class Model(nn.Module):
    method __init__ (line 175) | def __init__(self, args: ModelArgs):
    method __call__ (line 180) | def __call__(
    method layers (line 190) | def layers(self):

FILE: server/models/layers.py
  function rms_norm (line 8) | def rms_norm(x, weight, eps):
  class RMSNorm (line 14) | class RMSNorm(nn.Module):
    method __init__ (line 15) | def __init__(self, dims: int, eps: float = 1e-5):
    method __call__ (line 20) | def __call__(self, x):
  function ln_norm (line 25) | def ln_norm(x, eps, weight=None, bias=None):
  class LayerNorm (line 35) | class LayerNorm(nn.Module):
    method __init__ (line 36) | def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
    method _extra_repr (line 44) | def _extra_repr(self):
    method __call__ (line 47) | def __call__(self, x: mx.array) -> mx.array:

FILE: server/models/llama.py
  class ModelArgs (line 12) | class ModelArgs(BaseModelArgs):
    method __post_init__ (line 25) | def __post_init__(self):
  class Attention (line 38) | class Attention(nn.Module):
    method __init__ (line 39) | def __init__(self, args: ModelArgs):
    method __call__ (line 68) | def __call__(
  class MLP (line 105) | class MLP(nn.Module):
    method __init__ (line 106) | def __init__(self, dim, hidden_dim):
    method __call__ (line 112) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 116) | class TransformerBlock(nn.Module):
    method __init__ (line 117) | def __init__(self, args: ModelArgs):
    method __call__ (line 127) | def __call__(
  class LlamaModel (line 140) | class LlamaModel(nn.Module):
    method __init__ (line 141) | def __init__(self, args: ModelArgs):
    method __call__ (line 153) | def __call__(
  class Model (line 174) | class Model(nn.Module):
    method __init__ (line 175) | def __init__(self, args: ModelArgs):
    method __call__ (line 181) | def __call__(
    method sanitize (line 190) | def sanitize(weights):
    method layers (line 197) | def layers(self):

FILE: server/retriever/document.py
  class Document (line 4) | class Document():
    method __init__ (line 15) | def __init__(self, page_content: str, metadata: Optional[dict] = None,...

FILE: server/retriever/embeddings.py
  class Embeddings (line 12) | class Embeddings(ABC):
    method embed_documents (line 16) | def embed_documents(self, texts: List[str]) -> List[List[float]]:
    method embed_query (line 20) | def embed_query(self, text: str) -> List[float]:
  class E5Embeddings (line 24) | class E5Embeddings(Embeddings):
    method __init__ (line 29) | def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', qu...
    method _average_pool (line 35) | def _average_pool(self, last_hidden_states: mx.array,
    method embed_documents (line 41) | def embed_documents(self, texts: List[str], batch_size: int = 8) -> Li...
    method embed_query (line 49) | def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:
  class ChatEmbeddings (line 66) | class ChatEmbeddings(Embeddings):
    method __init__ (line 71) | def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):
    method embed_documents (line 75) | def embed_documents(self, texts: List[str]) -> List[List[float]]:
    method embed_query (line 78) | def embed_query(self,  text: str) -> List[float]:

FILE: server/retriever/loader.py
  function directory_loader (line 9) | def directory_loader(directory: Optional[str] = None) -> Optional[List[D...

FILE: server/retriever/splitter.py
  function _split_text_with_regex (line 16) | def _split_text_with_regex(
  class TextSplitter (line 36) | class TextSplitter(ABC):
    method __init__ (line 39) | def __init__(
    method split_text (line 70) | def split_text(self, text: str) -> List[str]:
    method create_documents (line 73) | def create_documents(
    method split_documents (line 93) | def split_documents(self, documents: Iterable[Document]) -> List[Docum...
    method _join_docs (line 101) | def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
    method _merge_splits (line 110) | def _merge_splits(self, splits: Iterable[str], separator: str) -> List...
  class RecursiveCharacterTextSplitter (line 152) | class RecursiveCharacterTextSplitter(TextSplitter):
    method __init__ (line 159) | def __init__(
    method _split_text (line 171) | def _split_text(self, text: str, separators: List[str]) -> List[str]:
    method split_text (line 212) | def split_text(self, text: str) -> List[str]:

FILE: server/retriever/vectorstore.py
  function _results_to_docs (line 30) | def _results_to_docs(results: Any) -> List[Document]:
  function _results_to_docs_and_scores (line 34) | def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, fl...
  function xor_args (line 47) | def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
  function cosine_similarity (line 75) | def cosine_similarity(
  function maximal_marginal_relevance (line 86) | def maximal_marginal_relevance(
  class Chroma (line 121) | class Chroma():
    method __init__ (line 128) | def __init__(
    method embeddings (line 185) | def embeddings(self) -> Optional[Embeddings]:
    method __query_collection (line 189) | def __query_collection(
    method add_texts (line 208) | def add_texts(
    method similarity_search (line 288) | def similarity_search(
    method similarity_search_with_score (line 310) | def similarity_search_with_score(
    method max_marginal_relevance_search_by_vector (line 350) | def max_marginal_relevance_search_by_vector(
    method max_marginal_relevance_search (line 399) | def max_marginal_relevance_search(
    method delete_collection (line 442) | def delete_collection(self) -> None:
    method get (line 446) | def get(
    method update_document (line 484) | def update_document(self, document_id: str, document: Document) -> None:
    method update_documents (line 493) | def update_documents(self, ids: List[str], documents: List[Document]) ...
    method from_texts (line 533) | def from_texts(
    method from_documents (line 596) | def from_documents(
    method delete (line 641) | def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:

FILE: server/server.py
  function load_model (line 26) | def load_model(model_path: str, adapter_file: Optional[str] = None):
  function index_directory (line 40) | def index_directory(directory: str, use_embedding: bool = True):
  function create_response (line 58) | def create_response(chat_id, prompt, tokens, text):
  function format_messages (line 85) | def format_messages(messages: List[Dict], indexed_files: Optional[str], ...
  class APIHandler (line 117) | class APIHandler(BaseHTTPRequestHandler):
    method _set_headers (line 118) | def _set_headers(self, status_code=200):
    method do_OPTIONS (line 126) | def do_OPTIONS(self):
    method do_POST (line 129) | def do_POST(self):
    method index (line 185) | def index(self, body):
    method init (line 190) | def init(self, body):
    method query (line 195) | def query(self, body):
  function run (line 256) | def run(host: str, port: int, server_class=HTTPServer, handler_class=API...
  function main (line 263) | def main():

FILE: server/utils.py
  function _get_classes (line 34) | def _get_classes(config: dict):
  function get_model_path (line 56) | def get_model_path(path_or_hf_repo: str) -> Path:
  function apply_repetition_penalty (line 84) | def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, pe...
  function generate_step (line 108) | def generate_step(
  function generate (line 196) | def generate(
  function load_model (line 282) | def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
  function load (line 353) | def load(
  function fetch_from_hub (line 390) | def fetch_from_hub(
  function make_shards (line 401) | def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB)...
  function upload_to_hub (line 425) | def upload_to_hub(path: str, upload_repo: str, hf_path: str):
  function save_weights (line 471) | def save_weights(
  function quantize_model (line 523) | def quantize_model(
  function get_mlx_path (line 550) | def get_mlx_path(hf_path: str, quantize: bool = False) -> str:
  function convert (line 556) | def convert(
Condensed preview — 66 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".github/workflows/lint.yml",
    "chars": 463,
    "preview": "name: Lint\n\non: [pull_request]\n\njobs:\n  lint:\n    name: Lint\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: action"
  },
  {
    "path": ".gitignore",
    "chars": 2202,
    "preview": "# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.\n\n# dependencies\nnode_modules/\n.pnp"
  },
  {
    "path": ".vscode/settings.json",
    "chars": 1276,
    "preview": "{\n    \"[python]\": {\n      \"editor.tabSize\": 4,\n      \"editor.defaultFormatter\": \"ms-python.autopep8\",\n    },\n    // \"It "
  },
  {
    "path": "CODEOWNERS",
    "chars": 30,
    "preview": "*       @parkersm1th @stockeh\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 3286,
    "preview": "# Welcome to our contribution guide\n\nThank you for wanting to contribute to our project! We apprecaite any contributions"
  },
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright (c) 2024 MLX Chat\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README.md",
    "chars": 1681,
    "preview": "\n![](docs/design-logo-light.png#gh-light-mode-only)\n![](docs/design-logo-dark.png#gh-dark-mode-only)\n\n\n**Chat with MLX**"
  },
  {
    "path": "app/.eslintrc.cjs",
    "chars": 8829,
    "preview": "const namingConventions = [\n  'error',\n  {\n    format: ['camelCase'],\n    selector: 'default',\n  },\n  {\n    format: ['ca"
  },
  {
    "path": "app/components.json",
    "chars": 348,
    "preview": "{\n  \"$schema\": \"https://ui.shadcn.com/schema.json\",\n  \"style\": \"new-york\",\n  \"rsc\": true,\n  \"tsx\": true,\n  \"tailwind\": {"
  },
  {
    "path": "app/dprint.json",
    "chars": 1091,
    "preview": "{\n    \"lineWidth\": 100,\n    \"typescript\": {\n      \"indentWidth\": 2,\n      \"quoteStyle\": \"alwaysSingle\",\n      \"semiColon"
  },
  {
    "path": "app/mac/entitlements.mac.inherit.plist",
    "chars": 409,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/P"
  },
  {
    "path": "app/main/main.ts",
    "chars": 11149,
    "preview": "// Main File for Electron\n\nimport {\n  exec,\n  execFile,\n} from 'child_process';\nimport {\n  app,\n  BrowserWindow,\n  dialo"
  },
  {
    "path": "app/main/preload.ts",
    "chars": 842,
    "preview": "// eslint-disable-next-line import/no-extraneous-dependencies\nimport {\n  contextBridge,\n  ipcRenderer,\n} from 'electron'"
  },
  {
    "path": "app/main/renderer.d.ts",
    "chars": 135,
    "preview": "import { electronAPI } from \"./preload\";\n\ndeclare global {\n  interface Window {\n    electronAPI: typeof electronAPI;\n  }"
  },
  {
    "path": "app/main/splash/index.css",
    "chars": 581,
    "preview": "@tailwind base;\n\n@tailwind components;\n\n@tailwind utilities;\n\ndiv {\n  -webkit-user-select: none;\n  -webkit-app-region: d"
  },
  {
    "path": "app/main/splash/index.html",
    "chars": 455,
    "preview": "<!DOCTYPE html>\n<html>\n  <head>\n    <meta charset=\"UTF-8\" />\n    <title>FLOATING LOADING SCREEN</title>\n    <link rel=\"s"
  },
  {
    "path": "app/main/tsconfig.json",
    "chars": 676,
    "preview": "{\n  \"compilerOptions\": {\n    \"allowJs\": true,\n    \"alwaysStrict\": true,\n    \"esModuleInterop\": true,\n    \"forceConsisten"
  },
  {
    "path": "app/next.config.js",
    "chars": 133,
    "preview": "/** @type {import('next').NextConfig} */\nconst nextConfig = {\n  output: \"export\",\n  distDir: \"out\",\n};\n\nmodule.exports ="
  },
  {
    "path": "app/notarize.js",
    "chars": 515,
    "preview": "require('dotenv').config();\nconst { notarize } = require('electron-notarize');\n\nexports.default = async function notariz"
  },
  {
    "path": "app/package.json",
    "chars": 4222,
    "preview": "{\n  \"name\": \"electron-app\",\n  \"productName\": \"Electron App\",\n  \"version\": \"0.1.0\",\n  \"private\": true,\n  \"main\": \"main/ou"
  },
  {
    "path": "app/postcss.config.js",
    "chars": 82,
    "preview": "module.exports = {\n  plugins: {\n    tailwindcss: {},\n    autoprefixer: {},\n  },\n}\n"
  },
  {
    "path": "app/src/AppProvider.tsx",
    "chars": 514,
    "preview": "'use client';\n\nimport {\n  useRef,\n} from 'react';\nimport {\n  Provider,\n} from 'react-redux';\nimport type {\n  AppStore,\n}"
  },
  {
    "path": "app/src/app/globals.css",
    "chars": 3008,
    "preview": "@tailwind base;\n@tailwind components;\n@tailwind utilities;\n\n@layer base {\n  :root {\n    --background: 0 0% 100%;\n    --f"
  },
  {
    "path": "app/src/app/layout.tsx",
    "chars": 781,
    "preview": "'use client';\n\nimport StoreProvider from '../AppProvider';\nimport './globals.css';\nimport '@fortawesome/fontawesome-svg-"
  },
  {
    "path": "app/src/app/page.tsx",
    "chars": 4140,
    "preview": "'use client';\n\nimport {\n  faBan,\n  faCheckCircle,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n  FontAwesomeIcon"
  },
  {
    "path": "app/src/app/settings/page.tsx",
    "chars": 6432,
    "preview": "'use client';\n\nimport type {\n  IconProp,\n} from '@fortawesome/fontawesome-svg-core';\nimport {\n  faCog,\n  faMessage,\n} fr"
  },
  {
    "path": "app/src/components/chat/Chat.tsx",
    "chars": 2889,
    "preview": "import React from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/chat';\nimport {\n  useAppDispatch,\n} from"
  },
  {
    "path": "app/src/components/chat/ChatInput.tsx",
    "chars": 1648,
    "preview": "import React, {\n  useEffect,\n  useRef,\n  useState,\n} from 'react';\nimport {\n  useAppSelector,\n} from '../../lib/hooks';\n"
  },
  {
    "path": "app/src/components/chat/ChatMessage.tsx",
    "chars": 641,
    "preview": "import Markdown from 'markdown-to-jsx';\nimport React from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/"
  },
  {
    "path": "app/src/components/chat/ChatMessages.tsx",
    "chars": 2487,
    "preview": "/* eslint-disable function-paren-newline */\nimport {\n  faCircleNotch,\n} from '@fortawesome/free-solid-svg-icons';\nimport"
  },
  {
    "path": "app/src/components/chat/SystemMessage.tsx",
    "chars": 1444,
    "preview": "import React from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/chat';\n\nconst Message = ({\n  message,\n}:"
  },
  {
    "path": "app/src/components/options/SelectDirectory.tsx",
    "chars": 2526,
    "preview": "import {\n  faCheckCircle,\n  faCircleNotch,\n  faXmark,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n  FontAwesome"
  },
  {
    "path": "app/src/components/options/SelectModel.tsx",
    "chars": 1141,
    "preview": "import React from 'react';\nimport {\n  Select,\n  SelectContent,\n  SelectGroup,\n  SelectItem,\n  SelectLabel,\n  SelectTrigg"
  },
  {
    "path": "app/src/components/ui/button.tsx",
    "chars": 1925,
    "preview": "import {\n  Slot,\n} from '@radix-ui/react-slot';\nimport {\n  cva,\n  type VariantProps,\n} from 'class-variance-authority';\n"
  },
  {
    "path": "app/src/components/ui/input.tsx",
    "chars": 777,
    "preview": "import * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nexport interface InputProps extends React.Inpu"
  },
  {
    "path": "app/src/components/ui/resizable.tsx",
    "chars": 1773,
    "preview": "'use client';\n\nimport {\n  DragHandleDots2Icon,\n} from '@radix-ui/react-icons';\nimport * as ResizablePrimitive from 'reac"
  },
  {
    "path": "app/src/components/ui/select.tsx",
    "chars": 5719,
    "preview": "'use client';\n\nimport {\n  CaretSortIcon,\n  CheckIcon,\n  ChevronDownIcon,\n  ChevronUpIcon,\n} from '@radix-ui/react-icons'"
  },
  {
    "path": "app/src/components/ui/textarea.tsx",
    "chars": 711,
    "preview": "import * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nexport interface TextareaProps extends React.T"
  },
  {
    "path": "app/src/components/ui/tooltip.tsx",
    "chars": 1168,
    "preview": "'use client';\n\nimport * as TooltipPrimitive from '@radix-ui/react-tooltip';\nimport * as React from 'react';\nimport {\n  c"
  },
  {
    "path": "app/src/constants/chat.tsx",
    "chars": 98,
    "preview": "export type ChatMessage = {\n  role: 'user' | 'assistant' | 'system';\n  content: string | null;\n};\n"
  },
  {
    "path": "app/src/lib/hooks.ts",
    "chars": 2301,
    "preview": "import {\n  useEffect,\n  useRef,\n  useState,\n} from 'react';\nimport {\n  useDispatch,\n  useSelector,\n  useStore,\n} from 'r"
  },
  {
    "path": "app/src/lib/store.ts",
    "chars": 1295,
    "preview": "import {\n  configureStore,\n  createSlice,\n} from '@reduxjs/toolkit';\n\nconst globalSlice = createSlice({\n  name: 'global'"
  },
  {
    "path": "app/src/lib/utils.ts",
    "chars": 177,
    "preview": "import {\n  type ClassValue,\n  clsx,\n} from 'clsx';\nimport {\n  twMerge,\n} from 'tailwind-merge';\n\nexport function cn(...i"
  },
  {
    "path": "app/tailwind.config.main.js",
    "chars": 261,
    "preview": "/** @type {import('tailwindcss').Config} */\nmodule.exports = {\n  content: [\"./main/**/*.{js,ts,jsx,tsx,mdx,html}\"],\n  //"
  },
  {
    "path": "app/tailwind.config.ts",
    "chars": 2288,
    "preview": "/* eslint-disable @typescript-eslint/naming-convention */\nimport type {\n  Config,\n} from 'tailwindcss';\n\nconst config = "
  },
  {
    "path": "app/tsconfig.json",
    "chars": 859,
    "preview": "{\n  \"compilerOptions\": {\n    \"target\": \"es5\",\n    \"lib\": [\n      \"dom\",\n      \"dom.iterable\",\n      \"esnext\",\n    ],\n   "
  },
  {
    "path": "runner.py",
    "chars": 341,
    "preview": "# Parent script to package (PyInstaller) server\n#\n# Example Usage:\n#\n# pyinstaller --onefile --collect-all mlx --copy-me"
  },
  {
    "path": "runner.sh",
    "chars": 662,
    "preview": "#!/bin/bash\n\ncollect_modules=(\n  \"mlx\"\n  \"chromadb\"\n)\n\nhidden_imports=(\n  \"server.models\"\n  \"server.models.gemma\"\n  \"ser"
  },
  {
    "path": "server/__init__.py",
    "chars": 66,
    "preview": "from .utils import generate, load, convert\n\n__version__ = \"0.1.0\"\n"
  },
  {
    "path": "server/convert.py",
    "chars": 1442,
    "preview": "import argparse\n\nfrom .utils import convert\n\n\ndef configure_parser() -> argparse.ArgumentParser:\n    \"\"\"\n    Configures "
  },
  {
    "path": "server/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "server/models/base.py",
    "chars": 314,
    "preview": "import inspect\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass BaseModelArgs:\n    @classmethod\n    def from_dict(cl"
  },
  {
    "path": "server/models/bert.py",
    "chars": 37310,
    "preview": "import math\nimport inspect\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom dataclasses import dataclass\nfrom typing impo"
  },
  {
    "path": "server/models/gemma.py",
    "chars": 6069,
    "preview": "from dataclasses import dataclass\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nimport mlx.core as m"
  },
  {
    "path": "server/models/layers.py",
    "chars": 1448,
    "preview": "from functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\n@partial(mx.compile, shapeless=True)\ndef rms_"
  },
  {
    "path": "server/models/llama.py",
    "chars": 6603,
    "preview": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple, Union\n\nimport mlx.core as mx\nimport mlx.nn a"
  },
  {
    "path": "server/py.typed",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "server/requirements.txt",
    "chars": 108,
    "preview": "chromadb==0.4.23\nhuggingface_hub==0.20.3\nmlx==0.4.0\nmlx_data==0.0.2\ntransformers==4.38.1\npyinstaller==6.4.0\n"
  },
  {
    "path": "server/retriever/document.py",
    "chars": 697,
    "preview": "from typing import Any, Literal, Optional\n\n\nclass Document():\n    \"\"\"Class for storing a piece of text and associated me"
  },
  {
    "path": "server/retriever/embeddings.py",
    "chars": 3031,
    "preview": "import os\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom transformers import PreTrainedTokenizer\nfrom abc import ABC, a"
  },
  {
    "path": "server/retriever/loader.py",
    "chars": 964,
    "preview": "import os\nimport glob\nfrom typing import List, Optional\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom .documen"
  },
  {
    "path": "server/retriever/splitter.py",
    "chars": 8120,
    "preview": "import re\nimport copy\n\nfrom abc import ABC, abstractmethod\nfrom typing import (\n    Any,\n    List,\n    Optional,\n    Cal"
  },
  {
    "path": "server/retriever/vectorstore.py",
    "chars": 24443,
    "preview": "import uuid\nimport functools\nimport mlx.core as mx\n\nimport chromadb\nimport chromadb.config\n\nfrom chromadb.utils.batch_ut"
  },
  {
    "path": "server/server.py",
    "chars": 9136,
    "preview": "import os\nimport sys\nimport json\nimport time\nimport uuid\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom http.server im"
  },
  {
    "path": "server/utils.py",
    "chars": 18938,
    "preview": "import os\nimport copy\nimport glob\nimport shutil\nimport importlib\nimport json\nimport logging\nimport time\nfrom pathlib imp"
  }
]

// ... and 1 more file (download for full content)

About this extraction

This page contains the full source code of the mlx-chat/mlx-chat-app GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 66 files (201.3 KB), approximately 50.2k tokens, and a symbol index with 197 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
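Because the manifest above is plain JSON, the per-file metadata can be inspected programmatically. A minimal sketch of that idea, using a small inline sample that mirrors the entry shape shown above (the field names `path`, `chars`, and `preview` come from the manifest; the variable names are illustrative):

```python
import json

# A two-entry sample in the same shape as the manifest entries above:
# each entry has a file path, a character count, and a truncated preview.
manifest_json = """
[
  {"path": "server/requirements.txt", "chars": 108,
   "preview": "chromadb==0.4.23\\nhuggingface_hub==0.20.3"},
  {"path": "server/__init__.py", "chars": 66,
   "preview": "from .utils import generate, load, convert"}
]
"""

entries = json.loads(manifest_json)

# Aggregate statistics: total extracted characters and the largest file.
total_chars = sum(e["chars"] for e in entries)
largest = max(entries, key=lambda e: e["chars"])

print(total_chars)       # 174 for this two-entry sample
print(largest["path"])   # server/requirements.txt
```

The same loop applied to the full 66-entry array would reproduce the aggregate figures quoted above (file count, total size), which is handy when deciding how much of an extraction fits into a model's context window.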

Extracted by GitExtract, a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.
