Showing preview only (218K chars total). Download the full file or copy to clipboard to get everything.
Repository: mlx-chat/mlx-chat-app
Branch: main
Commit: 20f441dc1188
Files: 66
Total size: 201.3 KB
Directory structure:
gitextract_9znslswk/
├── .github/
│ └── workflows/
│ └── lint.yml
├── .gitignore
├── .vscode/
│ └── settings.json
├── CODEOWNERS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app/
│ ├── .eslintrc.cjs
│ ├── assets/
│ │ └── icon.icns
│ ├── components.json
│ ├── dprint.json
│ ├── mac/
│ │ └── entitlements.mac.inherit.plist
│ ├── main/
│ │ ├── main.ts
│ │ ├── preload.ts
│ │ ├── renderer.d.ts
│ │ ├── splash/
│ │ │ ├── index.css
│ │ │ └── index.html
│ │ └── tsconfig.json
│ ├── next.config.js
│ ├── notarize.js
│ ├── package.json
│ ├── postcss.config.js
│ ├── src/
│ │ ├── AppProvider.tsx
│ │ ├── app/
│ │ │ ├── globals.css
│ │ │ ├── layout.tsx
│ │ │ ├── page.tsx
│ │ │ └── settings/
│ │ │ └── page.tsx
│ │ ├── components/
│ │ │ ├── chat/
│ │ │ │ ├── Chat.tsx
│ │ │ │ ├── ChatInput.tsx
│ │ │ │ ├── ChatMessage.tsx
│ │ │ │ ├── ChatMessages.tsx
│ │ │ │ └── SystemMessage.tsx
│ │ │ ├── options/
│ │ │ │ ├── SelectDirectory.tsx
│ │ │ │ └── SelectModel.tsx
│ │ │ └── ui/
│ │ │ ├── button.tsx
│ │ │ ├── input.tsx
│ │ │ ├── resizable.tsx
│ │ │ ├── select.tsx
│ │ │ ├── textarea.tsx
│ │ │ └── tooltip.tsx
│ │ ├── constants/
│ │ │ └── chat.tsx
│ │ └── lib/
│ │ ├── hooks.ts
│ │ ├── store.ts
│ │ └── utils.ts
│ ├── tailwind.config.main.js
│ ├── tailwind.config.ts
│ └── tsconfig.json
├── runner.py
├── runner.sh
└── server/
├── __init__.py
├── convert.py
├── models/
│ ├── __init__.py
│ ├── base.py
│ ├── bert.py
│ ├── gemma.py
│ ├── layers.py
│ └── llama.py
├── py.typed
├── requirements.txt
├── retriever/
│ ├── document.py
│ ├── embeddings.py
│ ├── loader.py
│ ├── splitter.py
│ └── vectorstore.py
├── server.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/lint.yml
================================================
name: Lint
on: [pull_request]
jobs:
  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      # checkout@v1 / setup-node@v1 run on a deprecated runner runtime;
      # v4 are the current supported releases with the same inputs used here.
      - uses: actions/checkout@v4
        with:
          fetch-depth: 1
      - name: Use Node.js 16
        uses: actions/setup-node@v4
        with:
          node-version: 16
      - name: Install App Deps
        run: npm i --ignore-scripts
        working-directory: ./app
      - name: Lint App
        working-directory: ./app
        run: npm run lint
================================================
FILE: .gitignore
================================================
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
node_modules/
.pnp/
.pnp.js
# testing
coverage/
# next.js
.next/
out/
# production
build/
app/main/tailwind.css
dist/
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
# local env files
.env*.local
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
server/lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: .vscode/settings.json
================================================
{
"[python]": {
"editor.tabSize": 4,
"editor.defaultFormatter": "ms-python.autopep8",
},
// "It is recommended that either ESLint or cspell checks a file, but not both."
// https://www.npmjs.com/package/@cspell/eslint-plugin
"cSpell.enableFiletypes": [
"!javascript",
"!typescript",
],
"css.format.spaceAroundSelectorSeparator": true,
"editor.codeActionsOnSave": [
"source.fixAll.eslint",
],
"eslint.codeActionsOnSave.mode": "problems",
"eslint.options": {
"reportUnusedDisableDirectives": "error",
},
"eslint.rules.customizations": [
{ "rule": "*", "severity": "warn" },
],
"editor.defaultFormatter": "dprint.dprint",
"editor.formatOnSave": true,
"editor.tabSize": 2,
"editor.wordWrapColumn": 100,
"eslint.workingDirectories": [
"./app",
],
"files.insertFinalNewline": true,
"files.trimFinalNewlines": true,
"git.allowForcePush": true,
"git.inputValidationSubjectLength": 100,
"git.inputValidationLength": 100,
"javascript.preferences.quoteStyle": "single",
"typescript.preferences.quoteStyle": "single",
"scss.format.spaceAroundSelectorSeparator": true,
"typescript.tsdk": "node_modules/typescript/lib",
}
================================================
FILE: CODEOWNERS
================================================
* @parkersm1th @stockeh
================================================
FILE: CONTRIBUTING.md
================================================
# Welcome to our contribution guide
Thank you for wanting to contribute to our project! We appreciate any contributions that you make.
Chat with MLX is an open source project and we love to receive contributions from our community — you! There are many ways to contribute, from writing tutorials or blog posts, improving the documentation, submitting bug reports and feature requests or writing code which can be incorporated into the application itself.
## New Contributor Guide
To get an overview of the project, read the [README](https://github.com/mlx-chat/mlx-chat-app/blob/main/README.md). Here are some resources to help you get started with open source contributions:
- [Ways to contribute on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github)
- [Setup Git](https://docs.github.com/en/get-started/quickstart/set-up-git)
- [GitHub workflow](https://docs.github.com/en/get-started/quickstart/github-flow)
- [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests)
## Getting Started
### Issues
**Create**: If you spot a problem, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). If a related issue doesn't exist, you can open a new issue!
**Solve**: Scan through our [existing issues](https://github.com/mlx-chat/mlx-chat-app) to find one that interests you. If you find an issue to work on, you are welcome to assign it to yourself and open a PR with a fix.
### Make Changes
1. Create your own fork of the code
2. Create a working branch and start with your changes
3. Commit and send a pull request
### Pull Request
When you're finished with the changes, create a pull request.
- Check to see your pull request passes our continuous integration (CI). If you cannot get a certain integration test to pass, let us know. We can assist you in fixing these issues or approve a merge manually.
- Make sure your additions are properly documented!
- Don't forget to [link PR to issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) if you are solving one.
- We may ask for changes to be made before a PR can be merged, either using [suggested changes](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/incorporating-feedback-in-your-pull-request) or pull request comments. You can apply suggested changes directly through the UI. You can make any other changes in your fork, then commit them to your branch.
- As you update your PR and apply changes, mark each conversation as [resolved](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/commenting-on-a-pull-request#resolving-conversations).
- If you run into any merge issues, checkout this [git tutorial](https://github.com/skills/resolve-merge-conflicts) to help you resolve merge conflicts and other issues.
### Your PR is Merged!
Congratulations :tada::tada: we thank you! :sparkles:
Once your PR is merged, your contributions will be publicly visible in the [Chat with MLX Repository](https://github.com/mlx-chat/mlx-chat-app).
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 MLX Chat
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================


**Chat with MLX** is a high-performance macOS application that connects your local documents to a personalized large language model (LLM). By leveraging retrieval-augmented generation (RAG), open source LLMs, and MLX for accelerated machine learning on Apple silicon, you can efficiently search, query, and interact with your documents *without information ever leaving your device.*
Our high-level features include:
- **Query**: load and search with document-specific prompts
- **Converse**: switch model interaction modes (converse vs. assist) in real time
- **Instruct**: provide personalization and response tuning
## Installation and Setup
:warning: **Preliminary Steps**: we are working to release with correct packaging ([pyinstaller](https://github.com/pyinstaller/pyinstaller/) & [electron-builder](https://github.com/electron-userland/electron-builder)) and authentication ([Apple codesign](https://developer.apple.com/support/code-signing/)). In the interim, please clone and run in development by first setting up authentication and requirements.
First, setup huggingface [access tokens](https://huggingface.co/settings/tokens) to download models (request access to [google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)), then
```bash
huggingface-cli login
```
Then download the npm/python requirements
```bash
cd app && npm install
pip install -r server/requirements.txt
```
Finally, start the application
```bash
cd app && npm run dev
```
## Contributions
All contributions are welcome. Please take a look at [contributing](CONTRIBUTING.md) guide.
================================================
FILE: app/.eslintrc.cjs
================================================
// Shared options array for @typescript-eslint/naming-convention.
// First element is the severity; each later object targets a selector,
// with more specific selectors overriding the defaults above them.
const namingConventions = [
  'error',
  {
    // Fallback: anything not matched below must be camelCase.
    format: ['camelCase'],
    selector: 'default',
  },
  {
    format: ['camelCase', 'UPPER_CASE'],
    selector: 'variable',
  },
  {
    // Exported/global consts may also be PascalCase (e.g. component refs).
    format: ['camelCase', 'UPPER_CASE', 'PascalCase'],
    modifiers: ['const', 'exported', 'global'],
    selector: 'variable',
  },
  {
    // Leading underscore marks intentionally-unused parameters.
    format: ['camelCase'],
    leadingUnderscore: 'allow',
    selector: 'parameter',
  },
  {
    format: ['camelCase'],
    leadingUnderscore: 'allow',
    modifiers: ['private'],
    selector: 'memberLike',
  },
  {
    format: ['PascalCase', 'UPPER_CASE'],
    selector: ['enum', 'enumMember'],
  },
  {
    format: ['PascalCase'],
    selector: 'typeLike',
  },
  {
    // Destructured names mirror their source object; leave unconstrained.
    format: null,
    modifiers: ['destructured'],
    selector: 'variable',
  },
  {
    // Properties that require quoting (e.g. header names) are exempt.
    format: null,
    modifiers: ['requiresQuotes'],
    selector: [
      'classProperty',
      'objectLiteralProperty',
      'typeProperty',
      'classMethod',
      'objectLiteralMethod',
      'typeMethod',
      'accessor',
      'enumMember',
    ],
  },
  {
    format: ['camelCase', 'PascalCase', 'UPPER_CASE'],
    leadingUnderscore: 'allow',
    selector: 'import',
  },
];
// Extra naming entries appended for .tsx files only (see overrides below):
// module-level variables/functions may be PascalCase but must not start
// with an underscore.
const tsxNamingConventions = [
  {
    format: ['camelCase', 'PascalCase', 'UPPER_CASE'],
    leadingUnderscore: 'forbid',
    modifiers: ['global'],
    selector: ['variable', 'function'],
  },
];
// ESLint configuration for the Electron + Next.js app. Base rules apply
// everywhere; `overrides` layers stricter rule sets onto config files,
// all TS files, `src/` sources (type-aware rules), and .tsx components.
module.exports = {
  env: {
    es2020: true,
  },
  extends: [
    'airbnb-base',
    'plugin:jsdoc/recommended',
    'plugin:@typescript-eslint/recommended',
    'plugin:import/typescript',
    'plugin:no-unsanitized/DOM',
  ],
  // `main` holds compiled Electron output; `out` holds the Next.js export.
  ignorePatterns: [
    'node_modules',
    'main',
    '.eslintrc.*',
    'out',
  ],
  overrides: [
    // Config files
    {
      files: [
        'common/**/*.ts*',
        '**/app*config.ts',
        '**/app*Config.ts',
      ],
      rules: {
        '@typescript-eslint/member-ordering': ['error', { default: { order: 'alphabetically' } }],
        'sort-keys': ['error', 'asc', { minKeys: 2, natural: true }],
      },
    },
    // All TypeScript files.
    {
      files: [
        '*.ts*',
      ],
      rules: {
        '@typescript-eslint/no-shadow': 'error',
        '@typescript-eslint/no-unused-vars': 'off', // Using unused-imports plugin instead
        '@typescript-eslint/space-before-function-paren': [
          'error',
          {
            anonymous: 'never',
            asyncArrow: 'always',
            named: 'never',
          },
        ],
        'no-redeclare': 'off', // @typescript-eslint/no-redeclare is enabled and is more correct
        'no-shadow': 'off', // @typescript-eslint/no-shadow is enabled and is more correct
        'no-undef-init': 'off',
        'no-unused-vars': 'off', // Using unused-imports plugin instead
        'space-before-function-paren': 'off', // Using @typescript-eslint/space-before-function-paren instead
        'unused-imports/no-unused-imports': 'error',
        'unused-imports/no-unused-vars': ['error', {
          args: 'after-used',
          argsIgnorePattern: '^_',
          destructuredArrayIgnorePattern: '^_',
          ignoreRestSiblings: true,
        }],
      },
    },
    // Application sources: rules that need type information.
    {
      files: [
        'src/**/*.ts*',
      ],
      rules: {
        '@typescript-eslint/await-thenable': 'error',
        '@typescript-eslint/dot-notation': ['error', { allowIndexSignaturePropertyAccess: true }],
        '@typescript-eslint/no-base-to-string': ['error', {
          ignoredTypeNames: ['Error', 'RegExp'],
        }],
        '@typescript-eslint/no-floating-promises': 'error',
        '@typescript-eslint/no-for-in-array': 'error',
        '@typescript-eslint/no-misused-promises': ['error', { checksVoidReturn: false }],
        '@typescript-eslint/no-throw-literal': 'error',
        '@typescript-eslint/no-unnecessary-condition': 'error',
        '@typescript-eslint/no-unnecessary-type-assertion': 'error',
        '@typescript-eslint/non-nullable-type-assertion-style': 'error',
        '@typescript-eslint/prefer-includes': 'error',
        '@typescript-eslint/prefer-optional-chain': 'error',
        '@typescript-eslint/prefer-string-starts-ends-with': 'error',
        '@typescript-eslint/require-await': 'error',
        '@typescript-eslint/space-infix-ops': 'error',
        'dot-notation': 'off',
        'no-throw-literal': 'off',
        'require-await': 'off',
        'space-infix-ops': 'off',
      },
    },
    // React component files: add the tsx naming additions.
    {
      files: [
        '*.tsx',
      ],
      rules: {
        '@typescript-eslint/naming-convention': [
          ...namingConventions,
          ...tsxNamingConventions,
        ],
        '@typescript-eslint/require-await': 'error',
        'require-await': 'off',
      },
    },
  ],
  parser: '@typescript-eslint/parser',
  parserOptions: {
    project: 'tsconfig.json',
  },
  plugins: [
    'react',
    '@typescript-eslint',
    'jest-formatting',
    'modules-newlines',
    'unused-imports',
  ],
  root: true,
  rules: {
    '@typescript-eslint/ban-types': [
      'error',
      {
        extendDefaults: true,
        types: {
          object: {
            message: [
              'The `object` type is currently hard to use ([see this issue](https://github.com/microsoft/TypeScript/issues/21732)).',
              'Consider using `Record<string, unknown>` instead, as it allows you to more easily inspect and use the keys.',
            ].join('\n'),
          },
        },
      },
    ],
    'implicit-arrow-linebreak': 'off',
    '@typescript-eslint/consistent-type-assertions': ['error', { assertionStyle: 'never' }],
    '@typescript-eslint/consistent-type-imports': 'error',
    '@typescript-eslint/init-declarations': 'error',
    '@typescript-eslint/member-ordering': 'error',
    '@typescript-eslint/naming-convention': namingConventions,
    '@typescript-eslint/no-explicit-any': 'error',
    '@typescript-eslint/no-non-null-asserted-nullish-coalescing': 'error',
    '@typescript-eslint/no-use-before-define': ['error', {
      functions: false,
    }],
    '@typescript-eslint/prefer-for-of': 'error',
    '@typescript-eslint/type-annotation-spacing': 'error',
    'array-element-newline': ['error', 'consistent'],
    'block-spacing': 'off',
    camelcase: 'off', // Using @typescript-eslint/naming-convention instead.
    'comma-dangle': 'off',
    'default-param-last': 'off',
    'import/extensions': 'off',
    'import/no-relative-packages': 'off',
    'import/order': 'off',
    'import/prefer-default-export': 'off',
    'jsdoc/check-indentation': ['error', { excludeTags: ['description', 'example'] }],
    'jsdoc/check-line-alignment': 'error',
    'jsdoc/check-tag-names': ['error', {
      definedTags: ['jest-environment', 'jest-environment-options'],
    }],
    'jsdoc/no-bad-blocks': 'error',
    'jsdoc/no-multi-asterisks': 'off',
    'jsdoc/no-undefined-types': 'off',
    'jsdoc/require-jsdoc': 'off',
    'jsdoc/require-param': 'off',
    'jsdoc/require-param-description': 'off',
    'jsdoc/require-param-name': 'off',
    'jsdoc/require-param-type': 'off',
    'jsdoc/require-property': 'off',
    'jsdoc/require-property-description': 'off',
    'jsdoc/require-property-name': 'off',
    'jsdoc/require-property-type': 'off',
    'jsdoc/require-returns': 'off',
    'jsdoc/require-returns-description': 'off',
    'jsdoc/require-returns-type': 'off',
    'jsdoc/require-yields': 'off',
    'jsdoc/require-yields-check': 'off',
    'jsdoc/tag-lines': ['error', 'any', { startLines: 1 }],
    'max-classes-per-file': 'off',
    'max-len': ['error', {
      code: 100,
      ignorePattern: '(/* eslint |eslint-disable-next-line |@ts-expect-error )',
      ignoreRegExpLiterals: true,
      ignoreStrings: true,
      ignoreTemplateLiterals: true,
      ignoreUrls: true,
    }],
    'max-params': ['error', 3],
    'new-cap': [
      'error',
      {
        capIsNew: true,
        capIsNewExceptions: [
          'express.Router',
          'Immutable.Map',
          'Immutable.Set',
          'Immutable.List',
          'RightRailView',
          'URLWithSearchParams',
        ],
        newIsCap: true,
        newIsCapExceptions: [],
        properties: true,
      },
    ],
    'no-console': 'error',
    'no-continue': 'off',
    'no-empty-function': 'off',
    'no-promise-executor-return': 'off',
    'no-redeclare': 'error',
    'no-restricted-properties': [
      'error',
    ],
    'no-restricted-syntax': [
      'error',
    ],
    'no-use-before-define': 'off',
    'no-void': ['error', { allowAsStatement: true }],
    'padding-line-between-statements': [
      'error',
      { blankLine: 'never', next: 'import', prev: 'import' },
    ],
    'prefer-arrow-callback': ['error', { allowNamedFunctions: true }],
    'prefer-exponentiation-operator': 'off',
    'prefer-regex-literals': 'off',
  },
  settings: {
    'import/typescript': {
      typescript: {},
    },
  },
};
================================================
FILE: app/components.json
================================================
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "new-york",
"rsc": true,
"tsx": true,
"tailwind": {
"config": "tailwind.config.ts",
"css": "main/splash/index.css",
"baseColor": "slate",
"cssVariables": true,
"prefix": ""
},
"aliases": {
"components": "@/components",
"utils": "@/lib/utils"
}
}
================================================
FILE: app/dprint.json
================================================
{
"lineWidth": 100,
"typescript": {
"indentWidth": 2,
"quoteStyle": "alwaysSingle",
"semiColons": "always",
"quoteProps": "asNeeded",
"useBraces": "always",
"trailingCommas": "onlyMultiLine",
"module.sortImportDeclarations": "caseInsensitive",
"exportDeclaration.forceMultiLine": true,
"importDeclaration.forceMultiLine": true
},
"json": {
"jsonTrailingCommaFiles": [
".vscode/launch.json",
".vscode/extensions.json",
".vscode/settings.json",
".vscode/tasks.json",
"tsconfig.json"
]
},
"excludes": [
"**/node_modules",
"**/*-lock.json",
"**/Dockerfile",
"**/src/ui-tests/fixtures/**/*",
"**/storybook-static/**/*",
"**/build/**/*",
"**/dist/**/*",
"**/artifacts/**/*",
"extension/src/assets/**/*.json"
],
"plugins": [
"https://plugins.dprint.dev/typescript-0.88.3.wasm",
"https://plugins.dprint.dev/json-0.19.0.wasm",
"https://plugins.dprint.dev/dockerfile-0.3.0.wasm"
]
}
================================================
FILE: app/mac/entitlements.mac.inherit.plist
================================================
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.security.cs.allow-jit</key>
<true/>
<key>com.apple.security.cs.allow-unsigned-executable-memory</key>
<true/>
<key>com.apple.security.cs.disable-library-validation</key>
<true/>
</dict>
</plist>
================================================
FILE: app/main/main.ts
================================================
// Main File for Electron
import {
exec,
execFile,
} from 'child_process';
import {
app,
BrowserWindow,
dialog,
globalShortcut,
ipcMain,
Menu,
nativeImage,
Tray,
} from 'electron';
import * as contextMenu from 'electron-context-menu';
import Store from 'electron-store';
import * as net from 'net';
const path = require('path');
const serve = require('electron-serve');
const { spawn } = require('child_process');
// IPC handler: retitle the window belonging to the renderer that sent the event.
// Silently does nothing when the sender has no associated BrowserWindow.
function handleSetTitle(event: any, title: string) {
  const senderWindow = BrowserWindow.fromWebContents(event.sender);
  senderWindow?.setTitle(title);
}
// Python Server
// Manages the lifecycle of the local Python inference server: port
// selection, spawning (dev vs. packaged), readiness detection, and teardown.
class ServerManager {
  // Child process handle for the spawned server (null when not running).
  private serverProcess: any | null = null;

  // Port the server was started on; read by the quit handler to clean up.
  public port: number | null = null;

  // Resolve the first free TCP port at or above `startingPort`, probing by
  // actually binding a throwaway net.Server; on EADDRINUSE, retry the next
  // port recursively.
  private findOpenPort(startingPort: number): Promise<number> {
    return new Promise<number>((resolve) => {
      const server = net.createServer();
      server.listen(startingPort, () => {
        const port = (server.address() as net.AddressInfo).port;
        server.close(() => resolve(port));
      });
      server.on(
        'error',
        (err: any) => err.code === 'EADDRINUSE' && resolve(this.findOpenPort(startingPort + 1)),
      );
    });
  }

  // Spawn the server: the packaged `runner` binary in production, or
  // `python -m server.server` from the repo root in development. Both
  // stdout and stderr are piped to the Electron console.
  private runPythonServer(port: number): any {
    const args = ['--host 127.0.0.1', `--port ${port}`];
    // Split each "--flag value" string into separate argv entries.
    const modifiedArgs = args.flatMap(arg => arg.split(/\s+/));
    const pythonProcess = isProd
      ? execFile(path.join(process.resourcesPath, 'server', 'runner'), modifiedArgs)
      : spawn('python', ['-m', 'server.server', ...modifiedArgs], {
        cwd: '../',
      });
    pythonProcess.stdout.on(
      'data',
      (data: Buffer) => console.log('Server output:', data.toString('utf8')),
    );
    pythonProcess.stderr.on(
      'data',
      (data: Buffer) => console.log(`Server error: ${data.toString('utf8')}`),
    );
    return pythonProcess;
  }

  // (Re)start the server on a free port and resolve once the /api/init POST
  // for `model` succeeds. Rejects if init fails or the process errors.
  start(model: string): Promise<void> {
    return new Promise<void>((resolve, reject) => {
      this.stop();
      this.findOpenPort(8080).then((port) => {
        this.port = port;
        console.log(`APP: Starting server for model: ${model} on port: ${port}`);
        this.serverProcess = this.runPythonServer(port);
        this.serverProcess.stdout.on('data', async (data: Buffer) => {
          const output = data.toString('utf8');
          // Brief delay before probing — presumably to let the HTTP listener
          // finish binding after it logs startup. TODO confirm necessity.
          await new Promise((resolve) => setTimeout(resolve, 1000));
          // Check if the server is ready
          if (output.includes('Starting httpd')) {
            fetch(`http://127.0.0.1:${port}/api/init`, {
              method: 'POST',
              headers: {
                'Content-Type': 'application/json',
              },
              body: JSON.stringify({ model }),
            }).then(() => {
              resolve(); // Resolve the promise when the server is ready
            }).catch((err) => {
              console.error('Error initializing the server:', err);
              reject(err);
            });
          }
        });
        this.serverProcess.on('close', (code: number | null) => {
          console.log(`Server process exited with code ${code}`);
          this.serverProcess = null;
        });
        this.serverProcess.on('error', (err: any) => {
          console.error(`Error in server process: ${err}`);
          this.serverProcess = null;
          reject(err);
        });
      });
    });
  }

  // Kill the tracked server process, if any, and clear the handle.
  stop(): void {
    if (this.serverProcess) {
      console.log('Stopping the server...');
      this.serverProcess.kill();
      this.serverProcess = null;
    }
  }
}
// Loading Screen
let splash: BrowserWindow | null;

// Create and show a small frameless splash window while the app boots.
// The window is revealed only after its content finishes loading.
const createSplashScreen = () => {
  const splashWindow = new BrowserWindow({
    width: 200,
    height: 100,
    focusable: false,
    // frameless window, no taskbar presence
    frame: false,
    skipTaskbar: true,
    autoHideMenuBar: true,
  });
  splash = splashWindow;
  splashWindow.setResizable(false);
  splashWindow.loadURL(`file://${__dirname}/../splash/index.html`);
  splashWindow.on('closed', () => {
    splash = null;
  });
  splashWindow.webContents.on('did-finish-load', () => {
    splash?.show();
  });
};
// run renderer
// Production is the default; only NODE_ENV=development counts as dev.
const isProd = process.env.NODE_ENV !== 'development';
if (isProd) {
  // Serve the statically exported renderer from ./out via electron-serve.
  serve({ directory: 'out' });
} else {
  // Keep the dev profile's userData separate from the production profile.
  app.setPath('userData', `${app.getPath('userData')} (development)`);
}
contextMenu.default({
  // Right-click "Inspect Element" only in development builds.
  showInspectElement: !isProd,
});
// Which modal (if any) is currently open; gates blur/hide behavior elsewhere.
let openModal: 'settings' | 'directory' | null = null;
// The main chat window, once created.
let globalWindow: BrowserWindow | null = null;

// Global-keybind handler: toggle the main window between shown and blurred.
// Does nothing while a modal is open or before the window exists.
const triggerShortcut = () => {
  const win = globalWindow;
  if (openModal !== null || win === null) {
    return;
  }
  if (win.isFocused()) {
    win.blur();
  } else {
    win.show();
  }
};
// Persisted app settings (electron-store) with schema-enforced defaults.
const store = new Store({
  schema: {
    // Global shortcut that toggles the chat window (see triggerShortcut).
    keybind: {
      type: 'string',
      default: 'Cmd+O',
    },
    // Model identifier sent to the Python server's /api/init at startup.
    model: {
      type: 'string',
      default: 'mistralai/Mistral-7B-Instruct-v0.2',
    },
    // Free-form text; presumably injected into prompts by the renderer —
    // only stored/fetched here, never read directly in this file.
    personalization: {
      type: 'string',
      default: '',
    },
    // Free-form text; same caveat as `personalization`.
    customResponse: {
      type: 'string',
      default: '',
    },
  },
});

// Single shared manager for the Python server process.
const serverManager = new ServerManager();
// Create the tray icon, the main chat window, the (lazily built) settings
// modal, and the application menu. Also wires focus/blur shortcut handling.
const createWindow = () => {
  const icon = nativeImage.createFromPath(
    !isProd
      ? '../assets/IconTemplate.png'
      : path.join(process.resourcesPath, 'IconTemplate.png'),
  );
  // if you want to resize it, be careful, it creates a copy
  const trayIcon = icon.resize({ width: 16 });
  // here is the important part (has to be set on the resized version)
  trayIcon.setTemplateImage(true);
  let tray = new Tray(trayIcon);
  // 'M' label distinguishes the dev instance in the menu bar.
  tray.setTitle(isProd ? '' : 'M');
  const win = new BrowserWindow({
    webPreferences: {
      preload: path.join(__dirname, 'preload.js'),
      devTools: !isProd,
    },
    show: false,
    width: 600,
    height: 99,
    resizable: false,
    type: 'panel',
    frame: false,
    skipTaskbar: true,
    autoHideMenuBar: true,
    vibrancy: 'under-window', // on MacOS
    backgroundMaterial: 'acrylic',
    icon: __dirname + '../../assets/public/icon.icns',
  });
  globalWindow = win;
  win.setWindowButtonVisibility(false);
  win.setAlwaysOnTop(true, 'floating');
  win.setVisibleOnAllWorkspaces(true);
  // Expose URL
  if (isProd) {
    win.loadURL('app://./home.html');
  } else {
    // const port = process.argv[2];
    win.loadURL('http://localhost:3000/');
  }
  // Tray click toggles the chat window's focus.
  tray.addListener('click', () => {
    if (win.isFocused()) {
      win.blur();
      return;
    }
    win.show();
  });
  // After the renderer loads: start the Python server, swap the splash for
  // the main window, hide the dock icon, and register the toggle shortcut.
  win.webContents.on('did-finish-load', async () => {
    await serverManager.start(store.get('model') as string);
    /// then close the loading screen window and show the main window
    if (splash) {
      splash.close();
    }
    app.dock.hide();
    win.show();
    globalShortcut.register(store.get('keybind') as string, triggerShortcut.bind(null));
  });
  // @ts-expect-error -- We don't have types for electron
  win.on('blur', (event) => {
    if (openModal) {
      win.setAlwaysOnTop(false);
    }
    // Keep the window up while the native directory picker is open.
    if (openModal === 'directory') {
      return;
    }
    if (win.webContents.isDevToolsOpened()) {
      return;
    }
    // Window-scoped shortcuts only apply while focused.
    globalShortcut.unregister('Escape');
    globalShortcut.unregister('Cmd+Q');
    win.hide();
    if (openModal) {
      return;
    }
    // Return focus to the previously active app (macOS responder chain).
    Menu.sendActionToFirstResponder('hide:');
  });
  // While focused: Cmd+Q quits, Escape dismisses.
  win.on('focus', () => {
    globalShortcut.register('Cmd+Q', () => {
      if (!win.isFocused()) {
        return;
      }
      app.quit();
    });
    globalShortcut.register('Escape', () => {
      if (!win.isFocused()) {
        return;
      }
      win.blur();
    });
  });
  let settingsModal: BrowserWindow | null = null;
  // Build the settings modal (dev: Next.js route; prod: exported file).
  const createSettings = () => {
    settingsModal = new BrowserWindow({
      webPreferences: {
        preload: path.join(__dirname, 'preload.js'),
      },
      width: 500,
      height: 500,
      resizable: false,
      minimizable: false,
      titleBarStyle: 'hidden',
      show: false,
      backgroundColor: '#000',
    });
    if (isProd) {
      settingsModal.loadURL('app://./settings.html');
    } else {
      // const port = process.argv[2];
      settingsModal.loadURL('http://localhost:3000/settings');
    }
    settingsModal.on('closed', () => {
      openModal = null;
      settingsModal?.destroy();
      settingsModal = null;
    });
    settingsModal.on('ready-to-show', () => {
      settingsModal?.show();
    });
    return settingsModal;
  };
  // Application menu: Settings entry plus standard Edit roles.
  const nativeMenus: (Electron.MenuItemConstructorOptions | Electron.MenuItem)[] = [
    {
      label: 'MLX Chat',
      submenu: [
        {
          label: 'Settings',
          click() {
            openModal = 'settings';
            // Close any stale modal before opening a fresh one.
            if (settingsModal !== null) {
              settingsModal.close();
            }
            createSettings();
          },
          accelerator: 'Cmd+,',
        },
      ],
    },
    {
      label: 'Edit',
      submenu: [
        { role: 'undo' },
        { role: 'redo' },
        { type: 'separator' },
        { role: 'cut' },
        { role: 'copy' },
        { role: 'paste' },
        { role: 'pasteAndMatchStyle' },
        { role: 'delete' },
        { role: 'selectAll' },
        { type: 'separator' },
        {
          label: 'Speech',
          submenu: [
            { role: 'startSpeaking' },
            { role: 'stopSpeaking' },
          ],
        },
      ],
    },
  ];
  const menu = Menu.buildFromTemplate(nativeMenus);
  Menu.setApplicationMenu(menu);
};
// App bootstrap: wire renderer→main IPC handlers, show the splash, then
// create the main window shortly after.
app.whenReady().then(() => {
  ipcMain.on('set-title', handleSetTitle);
  // Open the native directory picker and send chosen paths back to sender.
  ipcMain.on('select-directory', (event: any) => {
    openModal = 'directory';
    dialog.showOpenDialog({ properties: ['openDirectory'] }).then((result: any) => {
      const win = BrowserWindow.fromWebContents(event.sender);
      // Weird hack to bring the window to the front after allowing windows in front of it
      win?.setAlwaysOnTop(true, 'floating');
      openModal = null;
      event.sender.send('selected-directory', result.filePaths);
    });
  });
  // Renderer-requested height changes (content-sized window), re-centered.
  ipcMain.on('resize-window', (event, arg) => {
    const win = BrowserWindow.fromWebContents(event.sender);
    if (!win) {
      return;
    }
    win.setBounds({
      height: arg.height,
    });
    win.center();
  });
  // Synchronous setting read (renderer blocks on the reply).
  ipcMain.on('fetch-setting', (event, arg) => {
    event.returnValue = store.get(arg);
  });
  // Persist a setting; re-bind the global shortcut when the keybind changes.
  ipcMain.on('update-setting', (_event, arg) => {
    if (arg.key === 'keybind') {
      globalShortcut.unregister(store.get('keybind') as string);
      globalShortcut.register(arg.value, triggerShortcut.bind(null));
    }
    store.set(arg.key, arg.value);
  });
  createSplashScreen();
  // Short delay before building the main window — presumably so the splash
  // paints first. TODO confirm the 500ms is still needed.
  setTimeout(() => {
    createWindow();
  }, 500);
  app.on('activate', () => {
    if (BrowserWindow.getAllWindows().length === 0) { createWindow(); }
  });
});
// Cleanup on quit: terminate the Python server and tear down all windows.
app.on('will-quit', () => {
  // Kill the tracked child process first, if we still own one.
  serverManager.stop();
  // Belt and braces: kill anything else still bound to the server port.
  // Guarded so a never-started server (port === null) doesn't produce a
  // bogus `lsof -i :null` invocation.
  if (serverManager.port !== null) {
    exec(
      `lsof -i :${serverManager.port} -P | awk 'NR>1 {print $2}' | xargs kill`,
      (err, stdout, stderr) => {
        if (err) {
          console.log(err);
          return;
        }
        console.log(`stdout: ${stdout}`);
        console.log(`stderr: ${stderr}`);
      },
    );
  }
  BrowserWindow.getAllWindows().forEach((win) => {
    win.close();
    win.destroy();
  });
});
================================================
FILE: app/main/preload.ts
================================================
// eslint-disable-next-line import/no-extraneous-dependencies
import {
contextBridge,
ipcRenderer,
} from 'electron';
// API surface exposed to the renderer process via the context bridge.
// Each entry forwards to a main-process ipcMain handler in main.ts.
export const electronAPI = {
  // Ask main to retitle the sender's window.
  setTitle: (title: string) => ipcRenderer.send('set-title', title),
  // Open the native directory picker; results arrive via onSelectDirectory.
  selectDirectory: () => ipcRenderer.send('select-directory'),
  // Subscribe to directory-picker results; cb receives the chosen paths.
  onSelectDirectory: (cb: (customData: string[]) => void) => {
    ipcRenderer.on('selected-directory', (event, customData) => {
      // eslint-disable-next-line no-console
      console.log(event);
      cb(customData);
    });
  },
  // Ask main to resize the window to the given height (content-sized UI).
  resizeWindow: (height: number) => ipcRenderer.send('resize-window', { height }),
  // Synchronously read a persisted setting from the main-process store.
  fetchSetting: (key: string) => ipcRenderer.sendSync('fetch-setting', key),
  // Persist a setting in the main-process store.
  updateSetting: (key: string, value: any) => ipcRenderer.send('update-setting', { key, value }),
};
contextBridge.exposeInMainWorld('electronAPI', electronAPI);
================================================
FILE: app/main/renderer.d.ts
================================================
// Ambient declaration: gives renderer TypeScript the exact type of the
// bridge object that preload.ts exposes on `window.electronAPI`.
import { electronAPI } from "./preload";

declare global {
  interface Window {
    // Injected by contextBridge.exposeInMainWorld in app/main/preload.ts.
    electronAPI: typeof electronAPI;
  }
}

export {};
================================================
FILE: app/main/splash/index.css
================================================
@tailwind base;
@tailwind components;
@tailwind utilities;

/* Splash window is chrome-less: every div is draggable and non-selectable. */
div {
  -webkit-user-select: none;
  -webkit-app-region: drag;
}

/* Track of the indeterminate loading bar. */
.loading-bar {
  display: block;
  height: 0.2em;
  background-color: rgba(255, 255, 255, 0.2);
  position: relative;
  overflow: hidden;
  border-radius: 1rem;
}

/* Moving white segment, animated across the track indefinitely. */
.loading-bar:before {
  content: "";
  display: block;
  position: absolute;
  left: -100%;
  width: 100%;
  height: 100%;
  background-color: white;
  animation: loading-bar 1.5s ease-in-out infinite;
}

/* Slide the segment from fully off-screen left to fully off-screen right. */
@keyframes loading-bar {
  from {
    left: -100%;
  }
  to {
    left: 100%;
  }
}
================================================
FILE: app/main/splash/index.html
================================================
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8" />
    <title>FLOATING LOADING SCREEN</title>
    <!-- Generated Tailwind bundle (built by the `build:tailwindMain` script). -->
    <link rel="stylesheet" href="../tailwind.css" />
    <link rel="stylesheet" href="./index.css" />
  </head>
  <body>
    <!-- Full-screen overlay with the indeterminate loading bar (see index.css). -->
    <div
      class="h-screen w-screen fixed z-50 flex flex-col items-center justify-center text-white text-4xl gap-8 bg-slate-600"
    >
      <div class="loading-bar w-3/4 md:w-1/2 lg:w-1/3"></div>
    </div>
  </body>
</html>
================================================
FILE: app/main/tsconfig.json
================================================
{
  // TypeScript config for the Electron main process: emits CommonJS into
  // ./out, which package.json points at via "main": "main/out/main.js".
  "compilerOptions": {
    "allowJs": true,
    "alwaysStrict": true,
    "esModuleInterop": true,
    "forceConsistentCasingInFileNames": true,
    "isolatedModules": true,
    "jsx": "preserve",
    "lib": ["dom", "es2017"],
    "module": "commonjs",
    "moduleResolution": "node",
    // Must emit JS (unlike a typical Next.js renderer config) so Electron
    // can load the compiled entry point.
    "noEmit": false,
    "noFallthroughCasesInSwitch": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "resolveJsonModule": true,
    "skipLibCheck": true,
    "strict": true,
    "target": "esnext",
    "outDir": "./out",
  },
  "compileOnSave": true,
  "exclude": ["node_modules", "./out/**/*"],
  "include": ["**/*.ts", "**/*.tsx", "**/*.js", "public/**.icns"],
}
================================================
FILE: app/next.config.js
================================================
/** @type {import('next').NextConfig} */
// Static export into "out/" so Electron can serve the built site from disk.
const nextConfig = {
  output: "export",
  distDir: "out",
};

module.exports = nextConfig;
================================================
FILE: app/notarize.js
================================================
require('dotenv').config();
const { notarize } = require('electron-notarize');
exports.default = async function notarizing(context) {
const { electronPlatformName, appOutDir } = context;
if (electronPlatformName !== 'darwin') {
return;
}
const appName = context.packager.appInfo.productFilename;
return await notarize({
appBundleId: 'com.parkersmith.mlx-chat',
appPath: `${appOutDir}/${appName}.app`,
appleId: process.env.APPLEID,
appleIdPassword: process.env.APPLEIDPASS,
});
};
================================================
FILE: app/package.json
================================================
{
"name": "electron-app",
"productName": "Electron App",
"version": "0.1.0",
"private": true,
"main": "main/out/main.js",
"homepage": "./",
"description": "My Next.js project",
"author": "test",
"scripts": {
"dev": "cross-env NODE_ENV=development concurrently -k \"cross-env BROWSER=none npm run next:dev\" \"npm run electron:dev\"",
"build": " npm run build:main && next build",
"start": "cross-env npm run electron",
"build:tailwindMain": "npx tailwindcss build --config tailwind.config.main.js -o ./main/tailwind.css",
"build:main": "tsc -p main && npm run build:tailwindMain",
"next:dev": "next dev",
"next:start": "next start",
"next:lint": "next lint",
"electron:dev": "npm run build:main && wait-on tcp:3000 && electron .",
"electron": "electron .",
"pack": "npm run build && electron-builder --dir",
"dist": "npm run build && electron-builder",
"lint": "npx eslint --max-warnings 0 --ext=.ts ."
},
"dependencies": {
"@electron/osx-sign": "^1.0.5",
"@fortawesome/fontawesome-free": "^6.5.1",
"@fortawesome/fontawesome-svg-core": "^6.5.1",
"@fortawesome/free-regular-svg-icons": "^6.5.1",
"@fortawesome/free-solid-svg-icons": "^6.5.1",
"@fortawesome/react-fontawesome": "^0.2.0",
"@matejmazur/react-katex": "^3.1.3",
"@radix-ui/react-icons": "^1.3.0",
"@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-tooltip": "^1.0.7",
"@reduxjs/toolkit": "^2.2.1",
"@types/electron": "^1.6.10",
"@types/node": "^20.6.0",
"@types/react": "^18.2.21",
"@types/react-dom": "^18.2.7",
"autoprefixer": "^10.4.15",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.0",
"concurrently": "^8.2.1",
"cross-env": "^7.0.3",
"dprint": "^0.45.0",
"electron-context-menu": "^3.6.1",
"electron-serve": "^1.1.0",
"electron-squirrel-startup": "^1.0.0",
"electron-store": "^8.1.0",
"eslint": "8.41.0",
"eslint-config-next": "13.4.3",
"markdown-to-jsx": "^7.4.1",
"next": "13.4.3",
"postcss": "^8.4.29",
"react": "18.2.0",
"react-dom": "18.2.0",
"react-redux": "^9.1.0",
"react-resizable-panels": "^2.0.11",
"rxjs": "^7.8.1",
"tailwind-merge": "^2.2.1",
"tailwindcss": "^3.3.3",
"tailwindcss-animate": "^1.0.7",
"wait-on": "^7.0.1"
},
"devDependencies": {
"@typescript-eslint/eslint-plugin": "^6.21.0",
"@typescript-eslint/parser": "^6.21.0",
"dotenv": "^16.4.5",
"dprint": "^0.45.0",
"electron": "^26.2.0",
"electron-builder": "^24.6.4",
"electron-notarize": "^1.2.2",
"eslint": "^8.56.0",
"eslint-config-airbnb-base": "^15.0.0",
"eslint-plugin-compat": "^4.2.0",
"eslint-plugin-jest": "^27.6.3",
"eslint-plugin-jest-formatting": "^3.1.0",
"eslint-plugin-jsdoc": "^48.0.6",
"eslint-plugin-jsx-a11y": "^6.8.0",
"eslint-plugin-justinanastos": "^1.3.1",
"eslint-plugin-modules-newlines": "^0.0.7",
"eslint-plugin-no-unsanitized": "^4.0.2",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-unused-imports": "^3.0.0",
"typescript": "^5.2.2"
},
"build": {
"appId": "mlx-chat",
"productName": "MLX Chat",
"afterSign": "notarize.js",
"win": {
"target": [
"nsis"
]
},
"nsis": {
"oneClick": false,
"perMachine": true,
"allowToChangeInstallationDirectory": true,
"uninstallDisplayName": "MLX Chat"
},
"mac": {
"category": "your.app.category.type",
"target": [
"dmg"
],
"gatekeeperAssess": false,
"hardenedRuntime": true,
"icon": "assets/icon.icns",
"entitlements": "./mac/entitlements.mac.inherit.plist",
"entitlementsInherit": "./mac/entitlements.mac.inherit.plist"
},
"dmg": {
"title": "MLX Chat Installer",
"sign": false
},
"extraFiles": [
{
"from": "assets",
"to": "resources",
"filter": [
"**/*"
]
},
{
"from": "../dist",
"to": "resources/server",
"filter": [
"**/*"
]
}
]
}
}
================================================
FILE: app/postcss.config.js
================================================
// PostCSS pipeline: run Tailwind first, then add vendor prefixes.
module.exports = {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}
================================================
FILE: app/src/AppProvider.tsx
================================================
'use client';
import {
useRef,
} from 'react';
import {
Provider,
} from 'react-redux';
import type {
AppStore,
} from './lib/store';
import {
makeStore,
} from './lib/store';
/**
 * Wraps the app in a Redux <Provider>, creating the store lazily exactly once
 * per mounted tree so re-renders never rebuild it.
 */
export default function StoreProvider({
  children,
}: {
  children: React.ReactNode;
}) {
  const store = useRef<AppStore>();
  // First render only: instantiate the store.
  store.current ??= makeStore();
  return <Provider store={store.current}>{children}</Provider>;
}
================================================
FILE: app/src/app/globals.css
================================================
@tailwind base;
@tailwind components;
@tailwind utilities;

/* Design tokens consumed by the Tailwind theme as hsl(var(--token)). */
@layer base {
  :root {
    --background: 0 0% 100%;
    --foreground: 240 10% 3.9%;
    --card: 0 0% 100%;
    --card-foreground: 240 10% 3.9%;
    --popover: 0 0% 100%;
    --popover-foreground: 240 10% 3.9%;
    --primary: 240 5.9% 10%;
    --primary-foreground: 0 0% 98%;
    --secondary: 240 4.8% 95.9%;
    --secondary-foreground: 240 5.9% 10%;
    --muted: 240 4.8% 95.9%;
    --muted-foreground: 240 3.8% 46.1%;
    --accent: 240 4.8% 95.9%;
    --accent-foreground: 240 5.9% 10%;
    --destructive: 0 72.22% 50.59%;
    --destructive-foreground: 0 0% 98%;
    --border: 240 5.9% 90%;
    --input: 240 5.9% 90%;
    --ring: 240 5% 64.9%;
    --radius: 0.5rem;
  }
  /* Dark-mode overrides follow the OS preference (no class-based toggle). */
  @media screen and (prefers-color-scheme: dark) {
    :root {
      --background: 240 10% 3.9%;
      --foreground: 0 0% 98%;
      --card: 240 10% 3.9%;
      --card-foreground: 0 0% 98%;
      --popover: 240 10% 3.9%;
      --popover-foreground: 0 0% 98%;
      --primary: 0 0% 98%;
      --primary-foreground: 240 5.9% 10%;
      --secondary: 240 3.7% 15.9%;
      --secondary-foreground: 0 0% 98%;
      --muted: 240 3.7% 15.9%;
      --muted-foreground: 240 5% 64.9%;
      --accent: 240 3.7% 15.9%;
      --accent-foreground: 0 0% 98%;
      --destructive: 0 62.8% 30.6%;
      --destructive-foreground: 0 85.7% 97.3%;
      --border: 240 3.7% 15.9%;
      --input: 240 3.7% 15.9%;
      --ring: 240 4.9% 83.9%;
    }
  }
}

@layer base {
  /* Default every element's border color to the theme token. */
  * {
    @apply border-border;
  }
  body {
    @apply text-foreground;
    /* font-feature-settings: "rlig" 1, "calt" 1; */
    font-synthesis-weight: none;
    text-rendering: optimizeLegibility;
  }
}

/* Numbered-step marker: each .step renders its counter in a circled badge. */
@layer utilities {
  .step {
    counter-increment: step;
  }
  .step:before {
    @apply absolute w-9 h-9 bg-muted rounded-full font-mono font-medium text-center text-base inline-flex items-center justify-center -indent-px border-4 border-background;
    @apply ml-[-50px] mt-[-4px];
    content: counter(step);
  }
}

@media (max-width: 640px) {
  .container {
    @apply px-4;
  }
}

/* Update scrollbar when in dark mode */
@media screen and (prefers-color-scheme: dark) {
  ::-webkit-scrollbar-thumb {
    background-color: hsl(var(--muted));
    border-radius: 5px;
    transition: all;
  }
  ::-webkit-scrollbar-thumb:hover {
    background-color: hsl(255, 4%, 20%);
  }
}
/* Update scrollbar when in light mode */
@media screen and (prefers-color-scheme: light) {
  ::-webkit-scrollbar-thumb {
    background-color: rgb(38 38 38);
    border-radius: 5px;
    transition: all;
  }
  ::-webkit-scrollbar-thumb:hover {
    background-color: hsl(255, 4%, 20%);
  }
}
::-webkit-scrollbar {
  width: 7px;
}
::-webkit-scrollbar-track {
  background-color: transparent;
}

html {
  font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont;
}

ol,
ul,
menu {
  list-style: outside;
  margin: 0;
  padding-left: 20px;
}

/* Electron window-drag helpers (-webkit-app-region). */
.drag {
  -webkit-app-region: drag;
}
.no-drag {
  -webkit-app-region: no-drag;
}
================================================
FILE: app/src/app/layout.tsx
================================================
'use client';

import StoreProvider from '../AppProvider';
import './globals.css';
import '@fortawesome/fontawesome-svg-core/styles.css';
// Prevent fontawesome from adding its CSS since we did it manually above:
import {
  config,
} from '@fortawesome/fontawesome-svg-core';
import {
  TooltipProvider,
} from '../components/ui/tooltip';

config.autoAddCss = false;

/**
 * Root layout for every page: mounts the tooltip and Redux store providers
 * and disables text selection / vertical scrolling for a desktop-app feel.
 */
export default function RootLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  return (
    <html lang='en'>
      <body
        className={'min-h-screen overflow-y-hidden'}
        style={{
          userSelect: 'none',
        }}
      >
        <TooltipProvider>
          <StoreProvider>
            {children}
          </StoreProvider>
        </TooltipProvider>
      </body>
    </html>
  );
}
================================================
FILE: app/src/app/page.tsx
================================================
'use client';
import {
faBan,
faCheckCircle,
} from '@fortawesome/free-solid-svg-icons';
import {
FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
useEffect,
useState,
} from 'react';
import Chat from '../components/chat/Chat';
import SelectDirectory from '../components/options/SelectDirectory';
import {
Button,
} from '../components/ui/button';
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from '../components/ui/tooltip';
import type {
ChatMessage,
} from '../constants/chat';
import {
useAppDispatch,
} from '../lib/hooks';
import {
startDirectoryIndexing,
stopDirectoryIndexing,
} from '../lib/store';
/**
 * Top-level chat screen: the chat pane plus a bottom toolbar holding the
 * clear-history button and the directory selector.
 */
export default function Home() {
  const [selectedDirectory, setSelectedDirectory] = useState<string | null>(null);
  const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);
  const dispatch = useAppDispatch();

  // Open the native directory picker; main replies via onSelectDirectory.
  function handleOpen() {
    if (typeof window !== 'undefined') {
      window.electronAPI.selectDirectory();
    }
  }

  useEffect(() => {
    // When a directory is chosen, ask the local server to index it and track
    // progress in the store so the UI can show an indexing state.
    window.electronAPI.onSelectDirectory(async (customData: string[]) => {
      setSelectedDirectory(customData[0]);
      try {
        dispatch(startDirectoryIndexing());
        await fetch('http://localhost:8080/api/index', {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            directory: customData[0],
          }),
        });
        dispatch(stopDirectoryIndexing());
        // TODO: spinner while indexing
      } catch (error) {
        // eslint-disable-next-line no-console
        console.error('Error sending message: ', error);
        dispatch(stopDirectoryIndexing());
      }
    });
  }, []);

  useEffect(() => {
    // Mark the mode switch in the transcript when a directory is selected.
    // NOTE(review): this re-subscribes on every chatHistory change and the
    // preload bridge exposes no unsubscribe, so stale listeners accumulate;
    // the duplicates are benign today but the bridge should gain a cleanup
    // path — confirm before relying on listener count.
    window.electronAPI.onSelectDirectory(() => {
      if (chatHistory.length) {
        setChatHistory([
          ...chatHistory,
          { role: 'system', content: 'Assist' },
        ]);
      }
    });
  }, [chatHistory]);

  // Reset the conversation and shrink the window back to input-only height.
  const handleClearHistory = () => {
    setChatHistory([]);
    if (typeof window !== 'undefined') {
      window.electronAPI.resizeWindow(99);
    }
  };

  // Drop the indexed directory and note the switch back to converse mode.
  const clearDirectory = () => {
    setSelectedDirectory(null);
    if (chatHistory.length) {
      setChatHistory([
        ...chatHistory,
        { role: 'system', content: 'Converse' },
      ]);
    }
  };

  return (
    <main className='flex flex-col'>
      <Chat
        chatHistory={chatHistory}
        setChatHistory={setChatHistory}
        selectedDirectory={selectedDirectory}
      />
      {/* Fixed: the class string was previously split mid-word ("neut\nral"),
          which broke the border-color utility and injected a newline into the
          class attribute. */}
      <div className='border-t border-t-neutral-400 dark:border-t-neutral-700 pt-[5px] px-2'>
        <div className='flex justify-between drag'>
          {chatHistory.length
            ? (
              <Tooltip delayDuration={0}>
                <TooltipTrigger>
                  <Button
                    className='bg-transparent no-drag text-neutral-800 gap-1 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-3 flex items-center cursor-pointer'
                    onClick={handleClearHistory}
                  >
                    <FontAwesomeIcon icon={faCheckCircle} className='text-green-500' />
                  </Button>
                </TooltipTrigger>
                <TooltipContent>Clear History</TooltipContent>
              </Tooltip>
            )
            : (
              <Button
                className='bg-transparent no-drag text-neutral-800 gap-1 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-3 flex items-center cursor-pointer'
                disabled={true}
              >
                <FontAwesomeIcon icon={faBan} className='text-red-400' />
              </Button>
            )}
          <SelectDirectory
            clearDirectory={clearDirectory}
            handleOpen={handleOpen}
            selectedDirectory={selectedDirectory}
          />
        </div>
      </div>
    </main>
  );
}
================================================
FILE: app/src/app/settings/page.tsx
================================================
'use client';
import type {
IconProp,
} from '@fortawesome/fontawesome-svg-core';
import {
faCog,
faMessage,
} from '@fortawesome/free-solid-svg-icons';
import {
FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
useEffect,
} from 'react';
import SelectModel from '../../components/options/SelectModel';
import {
Textarea,
} from '../../components/ui/textarea';
import {
convertToNiceShortcut,
useKeyboardShortcut,
} from '../../lib/hooks';
import {
cn,
} from '../../lib/utils';
// Identifies which settings pane is currently shown.
enum SETTINGS {
  GENERAL,
  PROMPTS,
}

/**
 * One tab (icon + label) in the settings header toolbar. Highlights itself
 * when `selected` and calls `onClick` to switch panes.
 */
function SettingsOption({
  title,
  icon,
  onClick,
  selected,
}: {
  title: string;
  icon: IconProp;
  onClick: () => void;
  selected?: boolean;
}) {
  return (
    <div
      onClick={onClick}
      className={cn(
        'flex flex-col items-center h-[44px] gap-[2px] no-drag w-fit p-1 px-2 justify-center hover:bg-[#E2E3E2] dark:hover:bg-[#454544] active:bg-[#D5D6D5] dark:active:bg-[#525251] cursor-default rounded-md text-[#6C6C6C] dark:text-[#9C9C9B] active:text-[#202020] dark:active:text-[#EEEEEE]',
        {
          'bg-[#E2E3E2] dark:bg-[#454544]': selected,
        },
      )}
    >
      <FontAwesomeIcon
        className={cn('text-[20px] pt-1', {
          'dark:text-[#0D87FF] text-[#0066EB]': selected,
        })}
        icon={icon}
      />
      <h1
        className={cn('text-[11px]', {
          'dark:text-[#0D87FF] text-[#0066EB]': selected,
        })}
      >
        {title}
      </h1>
    </div>
  );
}
/**
 * General settings pane: capture a new launch keybind and pick the default
 * model. Values are read from and written to the main-process settings store.
 */
function GeneralSettings() {
  const {
    startListening,
    stopListening,
    shortcut,
  } = useKeyboardShortcut();
  // Seed from the persisted store; during SSR fall back to defaults.
  const [keybind, setKeybind] = React.useState<string>(
    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('keybind') : '⌘O',
  );
  const [model, setModel] = React.useState<string>(
    typeof window !== 'undefined'
      ? window.electronAPI.fetchSetting('model')
      : 'mistralai/Mistral-7B-Instruct-v0.2',
  );
  // Mirror any newly captured shortcut into the displayed keybind.
  useEffect(() => {
    if (!shortcut) {
      return;
    }
    setKeybind(shortcut);
  }, [shortcut]);
  return (
    <div className='flex flex-col justify-center w-full items-center'>
      <div className='flex items-center mt-2'>
        <p className='text-sm'>Launch keybind:</p>
        {/* Fixed: bg-[#fffff] was an invalid 5-digit hex Tailwind ignored. */}
        <input
          className='rounded-sm text-[12px] text-center drop-shadow-sm bg-[#ffffff] dark:bg-[#343432] border-0 h-[18px] w-28 ml-3 active:border-0 focus:border-0 outline-offset-2 focus:outline-blue-400'
          type='text'
          readOnly
          value={convertToNiceShortcut(keybind)}
          onFocus={startListening}
          onBlur={() => {
            stopListening();
            // Persist only when a shortcut was actually captured: writing an
            // empty value would make the main process unregister the current
            // accelerator and attempt to register an invalid one.
            if (typeof window !== 'undefined' && shortcut) {
              window.electronAPI.updateSetting('keybind', shortcut);
            }
          }}
        />
      </div>
      <div className='flex items-center mt-2'>
        <p className='text-sm mr-2'>Default model:</p>
        <SelectModel
          selectedModel={model}
          handleModelChange={(selectedModel) => {
            setModel(selectedModel);
            if (typeof window !== 'undefined' && selectedModel) {
              window.electronAPI.updateSetting('model', selectedModel);
            }
          }}
        />
      </div>
    </div>
  );
}
/**
 * Prompt settings pane: free-text "personalization" and "custom response"
 * instructions, persisted on every keystroke to the main-process store and
 * later sent with each chat query (see Chat.tsx `instructions`).
 */
function PromptSettings() {
  const [personalization, setPersonalization] = React.useState<string>(
    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('personalization') : '',
  );
  const [response, setResponse] = React.useState<string>(
    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('customResponse') : '',
  );
  return (
    <div className='flex flex-col justify-center w-full items-center gap-4'>
      <div className='flex flex-col items-center mt-2 gap-2'>
        <p className='text-sm flex-shrink-0 font-bold'>Personalization</p>
        <Textarea
          className='bg-[#C9C9C9] dark:bg-[#252523] border-[#B5B5B5] dark:border-[#3B3B39] border resize-none w-[300px]'
          value={personalization}
          onChange={(e) => {
            setPersonalization(e.target.value);
            if (typeof window !== 'undefined') {
              window.electronAPI.updateSetting('personalization', e.target.value);
            }
          }}
          rows={5}
          placeholder={`Things to know about you... e.g.,
- I enjoy thought provoking conversation
- I am a fan of the arts and culture`}
        />
      </div>
      <div className='flex flex-col items-center mt-2 gap-2'>
        <p className='text-sm flex-shrink-0 font-bold'>Custom Response</p>
        <Textarea
          className='bg-[#C9C9C9] dark:bg-[#252523] border-[#B5B5B5] dark:border-[#3B3B39] border resize-none w-[300px]'
          value={response}
          onChange={(e) => {
            setResponse(e.target.value);
            if (typeof window !== 'undefined') {
              window.electronAPI.updateSetting('customResponse', e.target.value);
            }
          }}
          rows={5}
          placeholder={`How to format responses... e.g.,
- Respond in a concise manner
- Do not use slang`}
        />
      </div>
    </div>
  );
}
/**
 * Settings window root: header toolbar with General/Prompts tabs and the
 * pane for whichever tab is selected.
 */
export default function Settings() {
  const [selectedSetting, setSelectedSetting] = React.useState<SETTINGS>(SETTINGS.GENERAL);
  return (
    <main className='flex flex-col bg-[#F1F1F1] dark:bg-[#383736] h-screen'>
      <div className='h-[81px] border-0 border-b border-b-neutral-300 dark:border-b-zinc-950 drag flex flex-col'>
        <h1 className='text-[12px] font-bold text-[#6C6C6C] dark:text-[#9C9C9B] text-center pt-1'>
          {selectedSetting === SETTINGS.GENERAL
            ? 'General'
            : 'Prompt'}
        </h1>
        <div className='flex justify-center items-center gap-[1px] mt-1'>
          <SettingsOption
            title='General'
            icon={faCog}
            onClick={() => setSelectedSetting(SETTINGS.GENERAL)}
            selected={selectedSetting === SETTINGS.GENERAL}
          />
          <SettingsOption
            title='Prompts'
            icon={faMessage}
            onClick={() => setSelectedSetting(SETTINGS.PROMPTS)}
            selected={selectedSetting === SETTINGS.PROMPTS}
          />
        </div>
      </div>
      <div className='flex-grow dark:bg-[#292929] bg-[#EEEDEC]'>
        {selectedSetting === SETTINGS.GENERAL
          ? <GeneralSettings />
          : <PromptSettings />}
      </div>
    </main>
  );
}
================================================
FILE: app/src/components/chat/Chat.tsx
================================================
import React from 'react';
import type {
ChatMessage,
} from '../../constants/chat';
import {
useAppDispatch,
} from '../../lib/hooks';
import {
startWaitingForResponse,
stopWaitingForResponse,
} from '../../lib/store';
import {
cn,
} from '../../lib/utils';
import ChatInput from './ChatInput';
import ChatMessages from './ChatMessages';
/**
 * Chat pane: the prompt input plus the message list. Sends user prompts to
 * the local inference server and appends the assistant's reply.
 */
const Chat = ({
  selectedDirectory,
  chatHistory,
  setChatHistory,
}: {
  selectedDirectory: string | null;
  chatHistory: ChatMessage[];
  setChatHistory: (chats: ChatMessage[]) => void;
}) => {
  const dispatch = useAppDispatch();

  // POST the conversation (or, in directory mode, just the latest message)
  // to the server and append the assistant's answer to the history.
  const sendMessage = async (message: string) => {
    try {
      // First message: grow the window from input-only to full chat height.
      if (chatHistory.length === 0) {
        window.electronAPI.resizeWindow(500);
      }
      const newHistory = [
        ...chatHistory,
        { role: 'user' as const, content: message },
      ];
      setChatHistory(newHistory);
      dispatch(startWaitingForResponse());
      const response = await fetch('http://localhost:8080/api/query', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          messages: selectedDirectory
            ? [{ role: 'user', content: message }]
            : newHistory.filter((chat) => chat.role !== 'system'),
          temperature: 0.7,
          // eslint-disable-next-line @typescript-eslint/naming-convention
          top_p: 1.0,
          // eslint-disable-next-line @typescript-eslint/naming-convention
          max_tokens: 200,
          directory: selectedDirectory,
          instructions: {
            personalization: typeof window !== 'undefined'
              ? window.electronAPI.fetchSetting('personalization')
              : '',
            response: typeof window !== 'undefined'
              ? window.electronAPI.fetchSetting('customResponse')
              : '',
          },
        }),
      });
      dispatch(stopWaitingForResponse());
      // Surface HTTP errors explicitly instead of failing with an opaque
      // TypeError when indexing into the JSON body below.
      if (!response.ok) {
        throw new Error(`Query failed with status ${response.status}`);
      }
      const responseData = await response.json();
      const assistantResponse = responseData.choices[0].message.content;
      setChatHistory([
        ...newHistory,
        { role: 'assistant', content: assistantResponse },
      ]);
    } catch (error) {
      dispatch(stopWaitingForResponse());
      // eslint-disable-next-line no-console
      console.error('Error sending message: ', error);
    }
  };

  return (
    <>
      <div
        className={cn('flex justify-center border-b-neutral-400 dark:border-b-neutral-700', {
          'border-b': chatHistory.length > 0,
        })}
      >
        <ChatInput sendMessage={sendMessage} />
      </div>
      <div
        className={cn(
          'flex-grow min-w-full border-0 flex h-0',
          {
            'h-[400px]': chatHistory.length > 0,
          },
        )}
      >
        <ChatMessages chatHistory={chatHistory} />
      </div>
    </>
  );
};
export default Chat;
================================================
FILE: app/src/components/chat/ChatInput.tsx
================================================
import React, {
useEffect,
useRef,
useState,
} from 'react';
import {
useAppSelector,
} from '../../lib/hooks';
import {
Input,
} from '../ui/input';
/**
 * Single-line prompt input. Submits on Enter, clears itself, and grabs focus
 * whenever the window regains focus or its container is clicked.
 */
const ChatInput = ({
  sendMessage,
}: {
  sendMessage: (text: string) => void;
}) => {
  const [message, setMessage] = useState<string>('');
  const inputRef = useRef<HTMLInputElement>(null);
  // Enter with a non-empty message: reset the field, then dispatch the text.
  const handleSend = (e: React.KeyboardEvent) => {
    if (e.key !== 'Enter' || message.length === 0) {
      return;
    }
    e.preventDefault();
    setMessage('');
    sendMessage(message);
  };
  // detect website focus and focus the input
  const handleFocus = () => {
    if (document.activeElement !== inputRef.current) {
      inputRef.current?.focus();
    }
  };
  useEffect(() => {
    window.addEventListener('focus', handleFocus);
    return () => {
      window.removeEventListener('focus', handleFocus);
    };
  }, []);
  // While a directory is being indexed, disable input and show a hint.
  const isDirectoryIndexing = useAppSelector((state) => state.isDirectoryIndexing);
  return (
    <div
      className='w-full py-2 drag'
      onClick={() => {
        if (document.activeElement !== inputRef.current) {
          inputRef.current?.focus();
        }
      }}
    >
      <Input
        value={message}
        onChange={(e) => setMessage(e.target.value)}
        placeholder={isDirectoryIndexing ? 'Indexing your files..' : 'Enter prompt here'}
        onKeyDown={handleSend}
        ref={inputRef}
        disabled={isDirectoryIndexing}
        className={'text-xl no-drag border-0 focus-visible:outline-transparent focus-visible:ring-0 focus-visible:shadow-0 w-full shadow-0'}
      />
    </div>
  );
};
export default ChatInput;
================================================
FILE: app/src/components/chat/ChatMessage.tsx
================================================
import Markdown from 'markdown-to-jsx';
import React from 'react';
import type {
ChatMessage,
} from '../../constants/chat';
const Message = ({
message,
}: {
message: ChatMessage;
}) => (
<div
className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}
>
<div
className={`p-2 rounded-sm ${
message.role === 'user'
? 'bg-blue-500 text-white'
: 'bg-[#E9E9EB] dark:bg-zinc-500'
}`}
>
<div className='text-md select-text'>
<Markdown
children={message.content ?? ''}
/>
</div>
</div>
</div>
);
export default Message;
================================================
FILE: app/src/components/chat/ChatMessages.tsx
================================================
/* eslint-disable function-paren-newline */
import {
faCircleNotch,
} from '@fortawesome/free-solid-svg-icons';
import {
FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
useEffect,
} from 'react';
import type {
ChatMessage,
} from '../../constants/chat';
import {
useAppSelector,
} from '../../lib/hooks';
import Message from './ChatMessage';
import SystemMessage from './SystemMessage';
/**
 * Scrollable message list. Auto-scrolls to the bottom on new messages unless
 * the user has scrolled well up into history; shows a spinner bubble while a
 * response is pending.
 */
const ChatMessages = ({
  chatHistory,
}: {
  chatHistory: ChatMessage[];
}) => {
  const messagesRef = React.useRef<HTMLDivElement>(null);
  // Pin the container's scroll position to its bottom.
  const scrollToBottom = () => {
    const scrollHeight = messagesRef.current?.scrollHeight;
    const height = messagesRef.current?.clientHeight ?? 0;
    const maxScrollTop = scrollHeight ? scrollHeight - height : 0;
    if (messagesRef.current) {
      messagesRef.current.scrollTop = maxScrollTop > 0 ? maxScrollTop : 0;
    }
  };
  const isWaitingForResponse = useAppSelector((state) => state.isWaitingForResponse);
  useEffect(() => {
    // check if the user is not at the bottom of the chat
    const currentScroll = messagesRef.current?.scrollTop ?? 0;
    const scrollHeight = messagesRef.current?.scrollHeight;
    const height = messagesRef.current?.clientHeight ?? 0;
    const maxScrollTop = scrollHeight ? scrollHeight - height : 0;
    // More than 200px above the bottom counts as "reading history".
    const scrollInHistory = (maxScrollTop - currentScroll) > 200;
    // Don't yank the view down while reading history — unless the user just
    // sent a message themselves.
    if (scrollInHistory && chatHistory[chatHistory.length - 1]?.role !== 'user') {
      return;
    }
    scrollToBottom();
  }, [chatHistory]);
  return chatHistory.length
    ? (
      <div ref={messagesRef} className='flex flex-col flex-grow p-4 gap-4 overflow-y-scroll'>
        {chatHistory.map((message, index) => (message.role !== 'system'
          ? (
            <Message
              key={index}
              message={message}
            />
          )
          : (
            <SystemMessage
              key={index}
              message={message}
            />
          ))
        )}
        {isWaitingForResponse
          ? (
            <div
              className={'flex justify-start'}
            >
              <div
                className={'p-2 rounded-sm'}
              >
                <div className='text-md select-text'>
                  <FontAwesomeIcon className='animate-spin' icon={faCircleNotch} />
                </div>
              </div>
            </div>
          )
          : null}
      </div>
    )
    : null;
};
export default ChatMessages;
================================================
FILE: app/src/components/chat/SystemMessage.tsx
================================================
import React from 'react';
import type {
ChatMessage,
} from '../../constants/chat';
/**
 * Mode-change divider rendered for `role: 'system'` entries: a horizontal
 * red rule with the message text centered and a "Mode" tag on the right.
 */
const Message = ({
  message,
}: {
  message: ChatMessage;
}) => (
  <div
    className={'flex w-full'}
  >
    <div
      className={'rounded-sm w-full relative flex items-center'}
    >
      <div className='w-full h-[1px] bg-red-500 rounded-md' />
      <div className='text-[12px] font-semibold select-text px-2 text-red-500 bg-transparent flex-grow rounded-md text-center py-1 whitespace-nowrap'>
        {message.content}
      </div>
      <div className='w-full h-[1px] bg-red-500 rounded-md' />
      {/* Right-aligned "Mode" badge with a pointed SVG left edge. */}
      <div className='arrow-tag text-[10px] p-0 absolute right-0 font-bold select-text uppercase pr-1 pl-1 text-white bg-red-500 flex-grow rounded-sm text-center whitespace-nowrap'>
        <svg
          className='absolute -left-[5px] top-[1px] z-[-1]'
          aria-hidden='true'
          role='img'
          width='8'
          height='13'
          viewBox='0 0 8 13'
        >
          <path
            className='fill-red-500 text-red-500'
            stroke='currentColor'
            fill='transparent'
            d='M8.16639 0.5H9C10.933 0.5 12.5 2.067 12.5 4V9C12.5 10.933 10.933 12.5 9 12.5H8.16639C7.23921 12.5 6.34992 12.1321 5.69373 11.4771L0.707739 6.5L5.69373 1.52292C6.34992 0.86789 7.23921 0.5 8.16639 0.5Z'
          >
          </path>
        </svg>
        Mode
      </div>
    </div>
  </div>
);
export default Message;
================================================
FILE: app/src/components/options/SelectDirectory.tsx
================================================
import {
faCheckCircle,
faCircleNotch,
faXmark,
} from '@fortawesome/free-solid-svg-icons';
import {
FontAwesomeIcon,
} from '@fortawesome/react-fontawesome';
import React, {
useEffect,
} from 'react';
import {
useAppSelector,
usePrevious,
} from '../../lib/hooks';
import {
cn,
} from '../../lib/utils';
import {
Button,
} from '../ui/button';
/**
 * Toolbar button showing the currently selected directory (abbreviated).
 * Shows a spinner while indexing and flashes a green check for 3 seconds
 * when indexing finishes; hover reveals an "x" to clear the directory.
 */
const SelectDirectory = ({
  handleOpen,
  selectedDirectory,
  clearDirectory,
}: {
  handleOpen: () => void;
  selectedDirectory: string | null;
  clearDirectory: () => void;
}) => {
  // Abbreviate long paths to "/first/../last".
  const shortenedDirectory = selectedDirectory
    ? `/${selectedDirectory.split('/')[1]}/../${selectedDirectory.split('/').pop()}`
    : 'Select Directory';
  const [isCheckShowing, setIsCheckShowing] = React.useState(false);
  const isDirectoryIndexing = useAppSelector((state) => state.isDirectoryIndexing);
  const oldLoadingState = usePrevious(isDirectoryIndexing);
  useEffect(() => {
    // Falling edge of the indexing flag: show the success check briefly.
    if (oldLoadingState && !isDirectoryIndexing) {
      setIsCheckShowing(true);
      const timer = setTimeout(() => {
        setIsCheckShowing(false);
      }, 3000);
      // Fixed: clear the pending timer on unmount/re-run so we never call
      // setState on an unmounted component.
      return () => clearTimeout(timer);
    }
    return undefined;
  }, [isDirectoryIndexing]);
  return (
    <div className='flex no-drag items-center group'>
      <Button
        className={cn(
          'bg-transparent text-neutral-800 z-0 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-2 flex items-center cursor-pointer',
          {
            'hover:bg-transparent dark:hover:bg-transparent cursor-default': isDirectoryIndexing,
          },
        )}
        onClick={isDirectoryIndexing ? undefined : handleOpen}
      >
        <div className='pr-1'>
          {selectedDirectory && !isDirectoryIndexing && !isCheckShowing && (
            <div
              className='group-hover:opacity-100 opacity-0 px-1 z-10 hover:bg-neutral-300 dark:hover:bg-neutral-800 rounded-sm transition-all cursor-pointer'
              onClick={(e) => {
                // Don't trigger the surrounding button's directory picker.
                e.stopPropagation();
                clearDirectory();
              }}
            >
              <FontAwesomeIcon
                icon={faXmark}
              />
            </div>
          )}
          {isDirectoryIndexing && <FontAwesomeIcon className='animate-spin' icon={faCircleNotch} />}
          {isCheckShowing && <FontAwesomeIcon className='text-green-500' icon={faCheckCircle} />}
        </div>
        {shortenedDirectory}
      </Button>
    </div>
  );
};
export default SelectDirectory;
================================================
FILE: app/src/components/options/SelectModel.tsx
================================================
import React from 'react';
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectTrigger,
SelectValue,
} from '../ui/select';
/**
 * Dropdown for choosing the inference model. Controlled: displays
 * `selectedModel` and reports changes through `handleModelChange`.
 */
const SelectModel = ({
  selectedModel,
  handleModelChange,
}: {
  selectedModel: string | null;
  handleModelChange: (model: string) => void;
}) => (
  <div className='no-drag'>
    <Select
      value={selectedModel ?? ''}
      onValueChange={(value) => handleModelChange(value)}
    >
      <SelectTrigger className='a-icon w-[140px] h-5 border-none shadow-transparent bg-white dark:bg-[#606160] transition-all border border-zinc-600 focus:ring-0 focus-within:ring-0 focus-visible:ring-0 peer-focus-within:ring-0 text-neutral-800 dark:text-white'>
        <SelectValue placeholder='Select a model' />
      </SelectTrigger>
      <SelectContent>
        <SelectGroup>
          <SelectLabel>AI Model</SelectLabel>
          {/* Values are Hugging Face repo ids passed through to the server. */}
          <SelectItem value='mistralai/Mistral-7B-Instruct-v0.2'>Mistral7B</SelectItem>
          <SelectItem value='google/gemma-7b-it'>Gemma7B</SelectItem>
        </SelectGroup>
      </SelectContent>
    </Select>
  </div>
);
export default SelectModel;
================================================
FILE: app/src/components/ui/button.tsx
================================================
import {
Slot,
} from '@radix-ui/react-slot';
import {
cva,
type VariantProps,
} from 'class-variance-authority';
import * as React from 'react';
import {
cn,
} from '../../lib/utils';
// Tailwind class builder for <Button> (class-variance-authority): callers
// pick a `variant` and a `size`; both fall back to 'default'.
const buttonVariants = cva(
  'inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50',
  {
    variants: {
      variant: {
        default: 'bg-primary text-primary-foreground shadow hover:bg-primary/90',
        destructive: 'bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90',
        outline:
          'border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground',
        secondary: 'bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80',
        ghost: 'hover:bg-accent hover:text-accent-foreground',
        link: 'text-primary underline-offset-4 hover:underline',
      },
      size: {
        default: 'h-9 px-4 py-2',
        sm: 'h-8 rounded-md px-3 text-xs',
        lg: 'h-10 rounded-md px-8',
        icon: 'h-9 w-9',
      },
    },
    defaultVariants: {
      variant: 'default',
      size: 'default',
    },
  },
);
/**
 * Native <button> attributes plus the cva `variant`/`size` props.
 * When `asChild` is true, Button renders through Radix `Slot` instead of a
 * native <button>, forwarding its props to the child element.
 */
export interface ButtonProps
  extends React.ButtonHTMLAttributes<HTMLButtonElement>, VariantProps<typeof buttonVariants>
{
  asChild?: boolean;
}
/**
 * Themed button. Classes are computed from `variant`/`size` via
 * buttonVariants and merged with any caller-supplied `className`.
 */
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
  (props, ref) => {
    const { className, variant, size, asChild = false, ...rest } = props;
    // Render a Radix Slot when composing onto a child element.
    // eslint-disable-next-line @typescript-eslint/naming-convention
    const Comp = asChild ? Slot : 'button';
    const classes = cn(buttonVariants({ variant, size, className }));
    return <Comp ref={ref} className={classes} {...rest} />;
  },
);
Button.displayName = 'Button';
export {
Button,
buttonVariants,
};
================================================
FILE: app/src/components/ui/input.tsx
================================================
import * as React from 'react';
import {
cn,
} from '../../lib/utils';
// Props for <Input>: identical to the native <input> attribute set.
export interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {}
const Input = React.forwardRef<HTMLInputElement, InputProps>(
({ className, type, ...props }, ref) => (
<input
type={type}
className={cn(
'flex h-9 w-full rounded-md border border-input bg-transparent px-3 py-1 text-sm shadow-0 transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50',
className,
)}
ref={ref}
{...props}
/>
),
);
Input.displayName = 'Input';
export {
Input,
};
================================================
FILE: app/src/components/ui/resizable.tsx
================================================
'use client';
import {
DragHandleDots2Icon,
} from '@radix-ui/react-icons';
import * as ResizablePrimitive from 'react-resizable-panels';
import {
cn,
} from '../../lib/utils';
const ResizablePanelGroup = ({
className,
...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelGroup>) => (
<ResizablePrimitive.PanelGroup
className={cn(
'flex h-full w-full data-[panel-group-direction=vertical]:flex-col',
className,
)}
{...props}
/>
);
// Re-exported Panel primitive; no extra styling needed.
const ResizablePanel = ResizablePrimitive.Panel;
// Drag handle between panels. `withHandle` additionally renders a visible
// grip icon centered on the divider.
const ResizableHandle = ({
  withHandle,
  className,
  ...props
}: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & {
  withHandle?: boolean;
}) => (
  <ResizablePrimitive.PanelResizeHandle
    className={cn(
      'relative flex w-px items-center justify-center bg-border after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 [&[data-panel-group-direction=vertical]>div]:rotate-90',
      className,
    )}
    {...props}
  >
    {withHandle && (
      <div className='z-10 flex h-4 w-3 items-center justify-center rounded-sm border bg-border'>
        <DragHandleDots2Icon className='h-2.5 w-2.5' />
      </div>
    )}
  </ResizablePrimitive.PanelResizeHandle>
);
export {
ResizableHandle,
ResizablePanel,
ResizablePanelGroup,
};
================================================
FILE: app/src/components/ui/select.tsx
================================================
'use client';
import {
CaretSortIcon,
CheckIcon,
ChevronDownIcon,
ChevronUpIcon,
} from '@radix-ui/react-icons';
import * as SelectPrimitive from '@radix-ui/react-select';
import * as React from 'react';
import {
cn,
} from '../../lib/utils';
// Unstyled Radix primitives re-exported as-is.
const Select = SelectPrimitive.Root;
const SelectGroup = SelectPrimitive.Group;
const SelectValue = SelectPrimitive.Value;
// Styled trigger button for the select. NOTE(review): the sort caret is only
// rendered when the caller opts in by including the marker class 'a-icon' in
// `className` (see the `className?.includes('a-icon')` check below) — an
// unusual signaling mechanism; a dedicated boolean prop would be clearer.
const SelectTrigger = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Trigger>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
>(({ className, children, ...props }, ref) => (
  <SelectPrimitive.Trigger
    ref={ref}
    className={cn(
      'flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus:outline-none disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1',
      className,
    )}
    {...props}
  >
    {children}
    {className?.includes('a-icon')
      ? (
        <SelectPrimitive.Icon asChild>
          <CaretSortIcon className='h-4 w-4 opacity-50' />
        </SelectPrimitive.Icon>
      )
      : null}
  </SelectPrimitive.Trigger>
));
SelectTrigger.displayName = SelectPrimitive.Trigger.displayName;
const SelectScrollUpButton = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>
>(({ className, ...props }, ref) => (
<SelectPrimitive.ScrollUpButton
ref={ref}
className={cn(
'flex cursor-default items-center justify-center py-1',
className,
)}
{...props}
>
<ChevronUpIcon />
</SelectPrimitive.ScrollUpButton>
));
SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName;
const SelectScrollDownButton = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>
>(({ className, ...props }, ref) => (
<SelectPrimitive.ScrollDownButton
ref={ref}
className={cn(
'flex cursor-default items-center justify-center py-1',
className,
)}
{...props}
>
<ChevronDownIcon />
</SelectPrimitive.ScrollDownButton>
));
SelectScrollDownButton.displayName = SelectPrimitive.ScrollDownButton.displayName;
// Dropdown popup, rendered in a portal. With the default 'popper' position it
// is nudged away from the trigger and the viewport is sized to match the
// trigger's width/height via Radix CSS variables.
const SelectContent = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
>(({
  className,
  children,
  position = 'popper',
  ...props
}, ref) => (
  <SelectPrimitive.Portal>
    <SelectPrimitive.Content
      ref={ref}
      className={cn(
        'relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
        position === 'popper'
          && 'data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1',
        className,
      )}
      position={position}
      {...props}
    >
      <SelectScrollUpButton />
      <SelectPrimitive.Viewport
        className={cn(
          'p-1',
          position === 'popper'
            && 'h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]',
        )}
      >
        {children}
      </SelectPrimitive.Viewport>
      <SelectScrollDownButton />
    </SelectPrimitive.Content>
  </SelectPrimitive.Portal>
));
SelectContent.displayName = SelectPrimitive.Content.displayName;
const SelectLabel = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Label>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>
>(({ className, ...props }, ref) => (
<SelectPrimitive.Label
ref={ref}
className={cn('px-2 py-1.5 text-sm font-semibold', className)}
{...props}
/>
));
SelectLabel.displayName = SelectPrimitive.Label.displayName;
// Selectable option row; a check icon appears on the right when the item is
// the currently selected value (Radix ItemIndicator).
const SelectItem = React.forwardRef<
  React.ElementRef<typeof SelectPrimitive.Item>,
  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
>(({ className, children, ...props }, ref) => (
  <SelectPrimitive.Item
    ref={ref}
    className={cn(
      'relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-2 pr-8 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50',
      className,
    )}
    {...props}
  >
    <span className='absolute right-2 flex h-3.5 w-3.5 items-center justify-center'>
      <SelectPrimitive.ItemIndicator>
        <CheckIcon className='h-4 w-4' />
      </SelectPrimitive.ItemIndicator>
    </span>
    <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
  </SelectPrimitive.Item>
));
SelectItem.displayName = SelectPrimitive.Item.displayName;
const SelectSeparator = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Separator>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>
>(({ className, ...props }, ref) => (
<SelectPrimitive.Separator
ref={ref}
className={cn('-mx-1 my-1 h-px bg-muted', className)}
{...props}
/>
));
SelectSeparator.displayName = SelectPrimitive.Separator.displayName;
export {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectScrollDownButton,
SelectScrollUpButton,
SelectSeparator,
SelectTrigger,
SelectValue,
};
================================================
FILE: app/src/components/ui/textarea.tsx
================================================
import * as React from 'react';
import {
cn,
} from '../../lib/utils';
// Props for <Textarea>: identical to the native <textarea> attribute set.
export interface TextareaProps extends React.TextareaHTMLAttributes<HTMLTextAreaElement> {}
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
({ className, ...props }, ref) => (
<textarea
className={cn(
'flex min-h-[60px] w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50',
className,
)}
ref={ref}
{...props}
/>
),
);
Textarea.displayName = 'Textarea';
export {
Textarea,
};
================================================
FILE: app/src/components/ui/tooltip.tsx
================================================
'use client';
import * as TooltipPrimitive from '@radix-ui/react-tooltip';
import * as React from 'react';
import {
cn,
} from '../../lib/utils';
const TooltipProvider = TooltipPrimitive.Provider;
const Tooltip = TooltipPrimitive.Root;
const TooltipTrigger = TooltipPrimitive.Trigger;
const TooltipContent = React.forwardRef<
React.ElementRef<typeof TooltipPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>
>(({ className, sideOffset = 4, ...props }, ref) => (
<TooltipPrimitive.Content
ref={ref}
sideOffset={sideOffset}
className={cn(
'z-50 overflow-hidden rounded-md bg-primary px-3 py-1.5 text-xs text-primary-foreground animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
className,
)}
{...props}
/>
));
TooltipContent.displayName = TooltipPrimitive.Content.displayName;
export {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
};
================================================
FILE: app/src/constants/chat.tsx
================================================
// One entry in the conversation transcript.
// NOTE(review): `content` may be null — presumably while a response is
// pending/streaming; confirm against the chat components' usage.
export type ChatMessage = {
  role: 'user' | 'assistant' | 'system';
  content: string | null;
};
================================================
FILE: app/src/lib/hooks.ts
================================================
import {
useEffect,
useRef,
useState,
} from 'react';
import {
useDispatch,
useSelector,
useStore,
} from 'react-redux';
import type {
TypedUseSelectorHook,
} from 'react-redux';
import type {
AppDispatch,
AppStore,
RootState,
} from './store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
// so dispatch and state are typed against this app's store (see ./store).
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
export const useAppStore: () => AppStore = useStore;
/**
 * Returns the value this hook was called with on the previous render
 * (undefined on the first render). The ref is written inside an effect,
 * i.e. after the current render has committed.
 */
export function usePrevious<T>(value: T): T | undefined {
  const ref = useRef<T>();
  useEffect(() => {
    ref.current = value;
  });
  return ref.current;
}
/**
 * Formats an accelerator string (e.g. "Cmd+Shift+P") for display using macOS
 * modifier glyphs (e.g. "⌘⇧P"). Handles all four modifiers emitted by
 * useKeyboardShortcut — including 'Ctrl', which was previously left as
 * literal text — and strips the '+' separators.
 */
export function convertToNiceShortcut(shortcut: string) {
  return shortcut
    .replace('Ctrl', '⌃')
    .replace('Cmd', '⌘')
    .replace('Option', '⌥')
    .replace('Shift', '⇧')
    .replaceAll('+', '');
}
/**
 * Captures a keyboard combination while listening is enabled.
 * Returns `startListening`/`stopListening` toggles plus the most recent
 * combination as a '+'-joined string, e.g. "Cmd+Shift+P".
 */
export function useKeyboardShortcut() {
  const [isListening, setIsListening] = useState(false);
  const [shortcut, setShortcut] = useState('');
  useEffect(() => {
    const handleKeyDown = (event: KeyboardEvent) => {
      if (!isListening) { return; }
      // Prevent default action to avoid interfering with normal browser shortcuts
      event.preventDefault();
      // command key for mac icon
      //
      // Modifiers are collected first, in a fixed Ctrl/Shift/Option/Cmd order.
      const keys = [];
      if (event.ctrlKey) { keys.push('Ctrl'); }
      if (event.shiftKey) { keys.push('Shift'); }
      if (event.altKey) { keys.push('Option'); }
      if (event.metaKey) {
        keys.push('Cmd'); // Command key for Mac
      }
      // Avoid adding modifier keys alone, check if another key is also pressed
      // NOTE(review): `toUpperCase()` also uppercases named keys
      // ("Enter" -> "ENTER") — confirm this is the intended display form.
      if (
        event.key.length === 1
        || (event.key !== 'Control' && event.key !== 'Shift' && event.key !== 'Alt'
          && event.key !== 'Meta')
      ) {
        keys.push(event.key.toUpperCase());
      }
      const combination = keys.join('+');
      setShortcut(combination);
    };
    // The listener is attached only while capturing; cleanup always removes it.
    if (isListening) {
      window.addEventListener('keydown', handleKeyDown);
    }
    return () => {
      window.removeEventListener('keydown', handleKeyDown);
    };
  }, [isListening]);
  const startListening = () => setIsListening(true);
  const stopListening = () => setIsListening(false);
  return {
    startListening,
    stopListening,
    shortcut,
  };
}
================================================
FILE: app/src/lib/store.ts
================================================
import {
configureStore,
createSlice,
} from '@reduxjs/toolkit';
// App-wide UI flags:
// - isDirectoryIndexing: the selected directory is currently being indexed
//   (drives the spinner in SelectDirectory).
// - isWaitingForResponse: a model response is in flight.
const globalSlice = createSlice({
  name: 'global',
  initialState: {
    isDirectoryIndexing: false,
    isWaitingForResponse: false,
  },
  reducers: {
    startDirectoryIndexing: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isDirectoryIndexing = true;
    },
    stopDirectoryIndexing: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isDirectoryIndexing = false;
    },
    startWaitingForResponse: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isWaitingForResponse = true;
    },
    stopWaitingForResponse: (state) => {
      // eslint-disable-next-line no-param-reassign
      state.isWaitingForResponse = false;
    },
  },
});
// Action creators for the global UI flags.
export const {
  startDirectoryIndexing,
  stopDirectoryIndexing,
  startWaitingForResponse,
  stopWaitingForResponse,
} = globalSlice.actions;
// Factory (rather than a singleton) so each render tree can create its own
// store instance.
export const makeStore = () =>
  configureStore({
    reducer: globalSlice.reducer,
  });
// Infer the type of makeStore
export type AppStore = ReturnType<typeof makeStore>;
// Infer the `RootState` and `AppDispatch` types from the store itself
export type RootState = ReturnType<AppStore['getState']>;
export type AppDispatch = AppStore['dispatch'];
================================================
FILE: app/src/lib/utils.ts
================================================
import {
type ClassValue,
clsx,
} from 'clsx';
import {
twMerge,
} from 'tailwind-merge';
/**
 * Builds a className string: clsx resolves the conditional inputs, then
 * tailwind-merge deduplicates conflicting Tailwind utilities.
 */
export function cn(...inputs: ClassValue[]) {
  const resolved = clsx(inputs);
  return twMerge(resolved);
}
================================================
FILE: app/tailwind.config.main.js
================================================
/** @type {import('tailwindcss').Config} */
module.exports = {
content: ["./main/**/*.{js,ts,jsx,tsx,mdx,html}"],
// purge: ['./subdir/index.html', './src/components/**/*.{js,ts,jsx,tsx,mdx}',
// ],
theme: {
extend: {},
},
plugins: [],
};
================================================
FILE: app/tailwind.config.ts
================================================
/* eslint-disable @typescript-eslint/naming-convention */
import type {
Config,
} from 'tailwindcss';
// shadcn/ui-style theme config. Color tokens reference CSS custom properties
// (--border, --primary, ...) presumably declared in the global stylesheet —
// verify against src/app/globals.css. Dark mode follows the OS ('media').
const config = {
  darkMode: 'media',
  content: [
    './pages/**/*.{ts,tsx}',
    './components/**/*.{ts,tsx}',
    './app/**/*.{ts,tsx}',
    './src/**/*.{ts,tsx}',
  ],
  prefix: '',
  theme: {
    container: {
      center: true,
      padding: '2rem',
      screens: {
        '2xl': '1400px',
      },
    },
    extend: {
      colors: {
        border: 'hsl(var(--border))',
        input: 'hsl(var(--input))',
        ring: 'hsl(var(--ring))',
        background: 'hsl(var(--background))',
        foreground: 'hsl(var(--foreground))',
        primary: {
          DEFAULT: 'hsl(var(--primary))',
          foreground: 'hsl(var(--primary-foreground))',
        },
        secondary: {
          DEFAULT: 'hsl(var(--secondary))',
          foreground: 'hsl(var(--secondary-foreground))',
        },
        destructive: {
          DEFAULT: 'hsl(var(--destructive))',
          foreground: 'hsl(var(--destructive-foreground))',
        },
        muted: {
          DEFAULT: 'hsl(var(--muted))',
          foreground: 'hsl(var(--muted-foreground))',
        },
        accent: {
          DEFAULT: 'hsl(var(--accent))',
          foreground: 'hsl(var(--accent-foreground))',
        },
        popover: {
          DEFAULT: 'hsl(var(--popover))',
          foreground: 'hsl(var(--popover-foreground))',
        },
        card: {
          DEFAULT: 'hsl(var(--card))',
          foreground: 'hsl(var(--card-foreground))',
        },
      },
      // Radii derive from a single --radius variable.
      borderRadius: {
        lg: 'var(--radius)',
        md: 'calc(var(--radius) - 2px)',
        sm: 'calc(var(--radius) - 4px)',
      },
      keyframes: {
        'accordion-down': {
          from: { height: '0' },
          to: { height: 'var(--radix-accordion-content-height)' },
        },
        'accordion-up': {
          from: { height: 'var(--radix-accordion-content-height)' },
          to: { height: '0' },
        },
      },
      animation: {
        'accordion-down': 'accordion-down 0.2s ease-out',
        'accordion-up': 'accordion-up 0.2s ease-out',
      },
    },
  },
  // eslint-disable-next-line global-require
  plugins: [require('tailwindcss-animate')],
} satisfies Config;
export default config;
================================================
FILE: app/tsconfig.json
================================================
{
"compilerOptions": {
"target": "es5",
"lib": [
"dom",
"dom.iterable",
"esnext",
],
"allowJs": true,
"skipLibCheck": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"incremental": true,
"plugins": [
{
"name": "next",
},
],
"paths": {
"@/*": [
"./src/*",
],
},
},
"include": [
"next-env.d.ts",
"**/*.ts",
"**/*.tsx",
".next/types/**/*.ts",
"build/types/**/*.ts",
"main/preload.ts",
"main/main.ts",
"out/types/**/*.ts",
".eslintrc.cjs",
],
"exclude": [
"node_modules",
"out",
"build",
],
}
================================================
FILE: runner.py
================================================
# Parent script to package (PyInstaller) server
#
# Example Usage:
#
# pyinstaller --onefile --collect-all mlx --copy-metadata opentelemetry-sdk \
#   --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert --hidden-import server.models.llama \
#   runner.py
from server import server

if __name__ == "__main__":
    # Guard the call so the server only starts when this file is the entry
    # point, not when it is merely imported (the recommended pattern for
    # PyInstaller entry scripts).
    server.main()
================================================
FILE: runner.sh
================================================
#!/bin/bash
# Build a single-file PyInstaller binary for the server.
set -euo pipefail

collect_modules=(
  "mlx"
  "chromadb"
)
hidden_imports=(
  "server.models"
  "server.models.gemma"
  "server.models.bert"
  "server.models.llama"
)
exclude_modules=(
  "matplotlib"
  "pandas"
  "PIL"
  "IPython"
)

# Assemble argv as an array and invoke it directly. The previous
# string-concatenation + `eval` approach re-parsed the whole command line,
# breaking on (and allowing injection via) any special characters.
cmd=(pyinstaller --onefile runner.py)
for module in "${collect_modules[@]}"; do
  cmd+=(--collect-all "$module")
done
for module in "${hidden_imports[@]}"; do
  cmd+=(--hidden-import "$module")
done
for module in "${exclude_modules[@]}"; do
  cmd+=(--exclude-module "$module")
done
cmd+=(--copy-metadata opentelemetry-sdk)

"${cmd[@]}"
================================================
FILE: server/__init__.py
================================================
from .utils import generate, load, convert
__version__ = "0.1.0"
================================================
FILE: server/convert.py
================================================
import argparse
from .utils import convert
def configure_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the Hugging Face -> MLX conversion script.

    Returns:
        argparse.ArgumentParser: Parser exposing the conversion options
        (model paths, quantization settings, dtype, optional upload repo).
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format")
    parser.add_argument("--hf-path", type=str,
                        help="Path to the Hugging Face model.")
    parser.add_argument("--mlx-path", type=str, default="mlx_model",
                        help="Path to save the MLX model.")
    parser.add_argument("-q", "--quantize", action="store_true",
                        help="Generate a quantized model.")
    parser.add_argument("--q-group-size", type=int, default=64,
                        help="Group size for quantization.")
    parser.add_argument("--q-bits", type=int, default=4,
                        help="Bits per weight for quantization.")
    parser.add_argument("--dtype", type=str, default="float16",
                        choices=["float16", "bfloat16", "float32"],
                        help="Type to save the parameters, ignored if -q is given.")
    parser.add_argument("--upload-repo", type=str, default=None,
                        help="The Hugging Face repo to upload the model to.")
    return parser
if __name__ == "__main__":
parser = configure_parser()
args = parser.parse_args()
convert(**vars(args))
================================================
FILE: server/models/__init__.py
================================================
================================================
FILE: server/models/base.py
================================================
import inspect
from dataclasses import dataclass
@dataclass
class BaseModelArgs:
    """Base class for model-config dataclasses loaded from HF config dicts."""

    @classmethod
    def from_dict(cls, params):
        # Keep only the keys that match a constructor parameter, so extra
        # fields in a Hugging Face config.json are silently ignored.
        accepted = inspect.signature(cls).parameters
        filtered = {key: value for key, value in params.items() if key in accepted}
        return cls(**filtered)
================================================
FILE: server/models/bert.py
================================================
import math
import inspect
import mlx.core as mx
import mlx.nn as nn
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Union, Callable
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    # BERT configuration, populated from a Hugging Face config dict via
    # BaseModelArgs.from_dict (unknown keys are dropped).
    model_type: str
    classifier_dropout: float
    hidden_act: str
    hidden_dropout_prob: float
    hidden_size: int
    initializer_range: float
    intermediate_size: int
    layer_norm_eps: float
    max_position_embeddings: int
    num_attention_heads: int
    num_hidden_layers: int
    pad_token_id: int
    position_embedding_type: str
    torch_dtype: str
    type_vocab_size: int
    use_cache: bool
    vocab_size: int
    # NOTE(review): default None would fail the `chunk_size > 0` comparison in
    # apply_chunking_to_forward if passed through — callers presumably supply
    # 0 or a positive int; confirm.
    chunk_size_feed_forward: Optional[int] = None
    attention_probs_dropout_prob: float = 0.0
    is_decoder: bool = False
    add_cross_attention: bool = False
    output_attentions: bool = False
    output_hidden_states: bool = False
    use_return_dict: bool = True
def apply_chunking_to_forward(
    forward_fn: Callable[..., mx.array], chunk_size: int, chunk_dim: int, *input_tensors
) -> mx.array:
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.

    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
    applying `forward_fn` to `input_tensors`.

    Args:
        forward_fn (`Callable[..., mx.array]`):
            The forward function of the model.
        chunk_size (`int`):
            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
            A value <= 0 disables chunking.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`Tuple[mx.array]`):
            The input tensors of `forward_fn` which will be chunked

    Returns:
        `mx.array`: A tensor with the same shape as the `forward_fn` would have given if applied.

    Raises:
        ValueError: If the argument count mismatches, the inputs disagree on
            shape along `chunk_dim`, or that dimension is not divisible by
            `chunk_size`.
    """
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # forward_fn must accept exactly the tensors we were given.
    # (The original f-strings here contained newlines inside the braces,
    # which is a syntax error before Python 3.12.)
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, "
            f"but only {len(input_tensors)} input tensors are given"
        )

    if chunk_size > 0:
        tensor_shape = input_tensors[0].shape[chunk_dim]
        for input_tensor in input_tensors:
            if input_tensor.shape[chunk_dim] != tensor_shape:
                raise ValueError(
                    f"All input tenors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )
        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} "
                f"has to be a multiple of the chunk size {chunk_size}"
            )
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size

        # mx.array has no `.chunk` method (that is torch API); mx.split cuts
        # each tensor into `num_chunks` equal parts along `chunk_dim`.
        input_tensors_chunks = tuple(
            mx.split(input_tensor, num_chunks, axis=chunk_dim)
            for input_tensor in input_tensors
        )
        # Apply forward_fn to each aligned tuple of chunks.
        output_chunks = tuple(
            forward_fn(*input_tensors_chunk)
            for input_tensors_chunk in zip(*input_tensors_chunks)
        )
        # Concatenate along the chunked dimension (mlx uses `axis=`, not `dim=`).
        return mx.concatenate(output_chunks, axis=chunk_dim)
    return forward_fn(*input_tensors)
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""
    def __init__(self, config):
        super().__init__()
        # Learned lookup tables: tokens, absolute positions, segments.
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)
        # Precomputed buffers: position-id row [[0..max_pos-1]] and an
        # all-zeros token-type tensor of the same shape, used when callers
        # omit those arguments.
        self._position_ids = mx.expand_dims(
            mx.arange(0, config.max_position_embeddings), axis=0)
        self._token_type_ids = mx.zeros((self._position_ids.shape))
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # "absolute" (default) adds position embeddings in __call__; other
        # types are handled inside the attention layers.
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute")
    def __call__(
        self,
        input_ids: Optional[float] = None,
        token_type_ids: Optional[float] = None,
        position_ids: Optional[float] = None,
        inputs_embeds: Optional[float] = None,
        past_key_values_length: int = 0,
    ) -> mx.array:
        # One of input_ids / inputs_embeds is expected to be provided.
        if input_ids is not None:
            input_shape = input_ids.shape
        else:
            input_shape = inputs_embeds.shape[:-1]
        seq_length = input_shape[1]
        if position_ids is None:
            # Positions continue after any cached prefix (decoder use).
            position_ids = self._position_ids[:, past_key_values_length:
                                              seq_length + past_key_values_length]
        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "_token_type_ids"):
                buffered_token_type_ids = self._token_type_ids[:, :seq_length]
                # Tile the zero segment ids across the batch dimension.
                buffered_token_type_ids_expanded = mx.repeat(
                    buffered_token_type_ids, input_shape[0], axis=0)
                # NOTE(review): int8 here (and float16 below) are unusual
                # dtypes for embedding indices — confirm against nn.Embedding.
                token_type_ids = buffered_token_type_ids_expanded.astype(
                    mx.int8)
            else:
                token_type_ids = mx.zeros(
                    input_shape, dtype=mx.float16)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # Sum token + segment (+ absolute position) embeddings, then
        # normalize and apply dropout.
        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertSelfAttention(nn.Module):
    """Multi-head (optionally cross- or causal-) attention for BERT, ported
    from the Hugging Face implementation to mlx.

    Supports absolute and relative position embeddings, and an optional
    key/value cache when used as a decoder. This port fixes two torch-isms
    that do not exist in mlx: `mx.cat(..., dim=...)` (now
    `mx.concatenate(..., axis=...)`) and `.to(dtype=...)` (now `.astype(...)`).
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"Hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Separate projections for Q, K, V.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per possible (query - key) distance.
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: mx.array) -> mx.array:
        """Reshape (B, L, all_head_size) -> (B, heads, L, head_size)."""
        new_x_shape = x.shape[
            :-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.reshape(new_x_shape)
        return x.transpose([0, 2, 1, 3])

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[float] = None,
        head_mask: Optional[float] = None,
        encoder_hidden_states: Optional[float] = None,
        encoder_attention_mask: Optional[float] = None,
        past_key_value: Optional[Tuple[Tuple[float]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[mx.array]:
        mixed_query_layer = self.query(hidden_states)

        # Cross-attention takes keys/values from the encoder; the attention
        # mask must then hide the encoder's padding tokens.
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention and past_key_value is not None:
            # Reuse cached cross-attention k/v.
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with cache: append current k/v to the cache along
            # the sequence axis. mlx has no `mx.cat`; use mx.concatenate.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = mx.concatenate([past_key_value[0], key_layer], axis=2)
            value_layer = mx.concatenate(
                [past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # Cache k/v for the next step: cross-attention caches encoder
            # projections (reused via the first branch above); uni-directional
            # self-attention caches decoder states (concatenated in the third
            # branch). For encoder bi-directional self-attention
            # `past_key_value` is always None.
            past_key_value = (key_layer, value_layer)

        # Raw attention scores: Q @ K^T over the last two axes.
        attention_scores = mx.matmul(
            query_layer, key_layer.transpose([0, 1, -1, -2]))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            # NOTE(review): float16 position ids mirror the original port;
            # an integer dtype would be conventional — confirm.
            if use_cache:
                position_ids_l = mx.array(
                    key_length - 1, dtype=mx.float16).reshape(-1, 1)
            else:
                position_ids_l = mx.arange(
                    query_length, dtype=mx.float16).reshape(-1, 1)
            position_ids_r = mx.arange(
                key_length, dtype=mx.float16).reshape(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            # fp16 compatibility: mlx arrays use .astype, not torch's .to.
            positional_embedding = positional_embedding.astype(
                query_layer.dtype)
            if self.position_embedding_type == "relative_key":
                relative_position_scores = mx.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = mx.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = mx.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = (attention_scores
                                    + relative_position_scores_query
                                    + relative_position_scores_key)

        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask, precomputed in the BertModel forward pass.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = mx.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = mx.matmul(attention_probs, value_layer)
        # (B, heads, L, head_size) -> (B, L, all_head_size)
        context_layer = context_layer.transpose([0, 2, 1, 3])
        new_context_layer_shape = context_layer.shape[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (
            context_layer,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
class BertSelfOutput(nn.Module):
    """Post-attention projection: dense layer, dropout, then a residual
    connection followed by LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        # Project and regularize, then add the residual before normalizing.
        projected = self.dense(hidden_states)
        projected = self.dropout(projected)
        return self.LayerNorm(projected + input_tensor)
class BertAttention(nn.Module):
    """Bundles raw self-attention with its output projection block."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(
            config, position_embedding_type=position_embedding_type)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[float] = None,
        head_mask: Optional[float] = None,
        encoder_hidden_states: Optional[float] = None,
        encoder_attention_mask: Optional[float] = None,
        past_key_value: Optional[Tuple[Tuple[float]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[mx.array]:
        # Run self-attention first, then the residual/LayerNorm output block.
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        projected = self.output(attn_results[0], hidden_states)
        # Preserve any attention probabilities / cache entries that follow.
        return (projected,) + attn_results[1:]
class BertIntermediate(nn.Module):
    """First half of the transformer feed-forward network: expand to the
    intermediate width and apply the GELU nonlinearity."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = nn.GELU()

    def __call__(self, hidden_states: mx.array) -> mx.array:
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Second half of the transformer feed-forward network: project back
    to the hidden size, with dropout, residual add, and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        # Contract back to the model width, then residual + LayerNorm.
        projected = self.dense(hidden_states)
        projected = self.dropout(projected)
        return self.LayerNorm(projected + input_tensor)
class BertLayer(nn.Module):
    """One transformer layer: self-attention (with optional decoder cache),
    optional cross-attention toward `encoder_hidden_states`, then the
    feed-forward block applied via `apply_chunking_to_forward`."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Sequence axis used when chunking the feed-forward computation.
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(
                    f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention always uses absolute position embeddings.
            self.crossattention = BertAttention(
                config, position_embedding_type="absolute")
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[float] = None,
        head_mask: Optional[float] = None,
        encoder_hidden_states: Optional[float] = None,
        encoder_attention_mask: Optional[float] = None,
        past_key_value: Optional[Tuple[Tuple[float]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[mx.array]:
        """Returns (layer_output, [attentions...], [present_key_value]) —
        optional entries depend on output_attentions / is_decoder."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # add self attentions if we output attention weights
            outputs = self_attention_outputs[1:]
        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )
            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            # add cross attentions if we output attention weights
            outputs = outputs + cross_attention_outputs[1:-1]
            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs
        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        # Feed-forward applied to one chunk of the sequence at a time.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class BertEncoder(nn.Module):
    """Stack of `BertLayer`s; collects hidden states, attention maps and
    decoder caches according to the output flags."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = [
            BertLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.gradient_checkpointing = False

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[float] = None,
        head_mask: Optional[float] = None,
        encoder_hidden_states: Optional[float] = None,
        encoder_attention_mask: Optional[float] = None,
        past_key_values: Optional[Tuple[Tuple[float]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[mx.array], Dict]:
        """Run every layer; returns a dict (default) or, when return_dict
        is False, a tuple of the non-None aggregates."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        # Caching is incompatible with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False
        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            # Record the input hidden state of each layer.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            # Each layer appends its key/value cache as the last output.
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + \
                        (layer_outputs[2],)
        # Also include the final hidden state.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return dict(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class BertPooler(nn.Module):
    """Produces a fixed-size sequence representation from the first
    token's hidden state via a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def __call__(self, hidden_states: mx.array) -> mx.array:
        # We "pool" the model by simply taking the hidden state
        # corresponding to the first token of every sequence.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
class BertModel(nn.Module):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config: ModelArgs, add_pooling_layer=True):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # The pooler is only needed for sentence-level heads.
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def get_input_embeddings(self):
        """Return the word-embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding table."""
        self.embeddings.word_embeddings = value

    def get_extended_attention_mask(
        self, attention_mask: mx.array, input_shape: Tuple[int], dtype: float = mx.float16
    ) -> mx.array:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        Arguments:
            attention_mask (`mx.array`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
        Returns:
            `mx.array` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if len(attention_mask.shape) == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif len(attention_mask.shape) == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                # Fix: the original fell through with `pass`, leaving
                # `extended_attention_mask` unbound and raising an opaque
                # UnboundLocalError below. Fail with a clear message until
                # causal mask construction is ported to MLX.
                raise NotImplementedError(
                    "Causal (decoder) extended attention masks are not implemented"
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) "
                f"or attention_mask (shape {attention_mask.shape})"
            )
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.astype(
            dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (
            # torch.finfo(torch.float16).min
            1.0 - extended_attention_mask) * -65504
        return extended_attention_mask

    def get_head_mask(
        self, head_mask: Optional[mx.array], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> mx.array:
        """
        Prepare the head mask if needed.
        Args:
            head_mask (`mx.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked (`bool`, *optional*, defaults to `False`):
                Whether or not the attentions scores are computed by chunks or not.
        Returns:
            `mx.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        if head_mask is not None:
            # NOTE(review): `_convert_head_mask_to_5d` is not defined in this
            # class and `unsqueeze` is not an mx.array method — this branch
            # looks like a torch leftover and will fail if a head mask is
            # ever supplied. Confirm before relying on head masking.
            head_mask = self._convert_head_mask_to_5d(
                head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers
        return head_mask

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        token_type_ids: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
        head_mask: Optional[mx.array] = None,
        inputs_embeds: Optional[mx.array] = None,
        encoder_hidden_states: Optional[mx.array] = None,
        encoder_attention_mask: Optional[mx.array] = None,
        past_key_values: Optional[List[float]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[mx.array], dict]:
        r"""
        encoder_hidden_states  (`mx.float16` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`mx.float16` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(mx.float16))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        # Resolve output flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.shape[:-1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")
        batch_size, seq_length = input_shape
        # past_key_values_length: number of positions already cached.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        if attention_mask is None:
            attention_mask = mx.ones(
                ((batch_size, seq_length + past_key_values_length)))
        if token_type_ids is None:
            if hasattr(self.embeddings, "_token_type_ids"):
                # Reuse the buffered all-zero token type ids, expanded to
                # the batch size.
                buffered_token_type_ids = self.embeddings._token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = mx.repeat(
                    buffered_token_type_ids, batch_size, axis=0)
                token_type_ids = buffered_token_type_ids_expanded.astype(
                    mx.int8)
            else:
                token_type_ids = mx.zeros(
                    input_shape, dtype=mx.float16)
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: mx.array = self.get_extended_attention_mask(
            attention_mask, input_shape)
        encoder_extended_attention_mask = None
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(
            head_mask, self.config.num_hidden_layers)
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs['last_hidden_state'] if return_dict else encoder_outputs
        pooled_output = self.pooler(
            sequence_output) if self.pooler is not None else None
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return dict(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs['past_key_values'],
            hidden_states=encoder_outputs['hidden_states'],
            attentions=encoder_outputs['attentions'],
            cross_attentions=encoder_outputs['cross_attentions'],
        )
class Model(nn.Module):
    """Top-level wrapper exposing `BertModel` under the common `model`
    attribute used by the loading utilities."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = BertModel(args)

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: mx.array = None,
    ):
        return self.model(input_ids, attention_mask)

    @staticmethod
    def sanitize(weights):
        """Drop unused `position_ids` buffers and prefix bare keys with
        `model.` so they match this wrapper's parameter tree."""
        return {
            f'model.{k}' if 'model' not in k else k: v
            for k, v in weights.items()
            if 'embeddings.position_ids' not in k
        }

    @property
    def layers(self):
        # Fix: BertModel has no `layers` attribute — its transformer
        # blocks live at `encoder.layer`; the original `self.model.layers`
        # raised AttributeError.
        return self.model.encoder.layer
================================================
FILE: server/models/gemma.py
================================================
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the Gemma architecture."""
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    head_dim: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False

    def __post_init__(self):
        # Mirror llama's ModelArgs: fall back to standard multi-head
        # attention when the config omits grouped key/value heads, instead
        # of crashing on `n_heads // None` inside Attention.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    # Compiled RMSNorm kernel. Note Gemma scales by (1 + weight) rather
    # than `weight` alone.
    x = x.astype(mx.float32)  # accumulate the mean of squares in fp32
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (1.0 + weight) * x.astype(weight.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned scale; the
    arithmetic is delegated to the compiled `rms_norm` kernel."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dims,))

    def __call__(self, x):
        return rms_norm(x, self.weight, self.eps)
class Attention(nn.Module):
    """Multi-head attention with RoPE positions, grouped key/value heads,
    and an optional key/value cache for incremental decoding."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.head_dim = head_dim = args.head_dim
        # Number of query heads sharing each key/value head (GQA).
        self.repeats = n_heads // n_kv_heads
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Returns (output, (keys, values)); the tuple is the updated
        cache including this step's keys/values."""
        B, L, D = x.shape
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        # Expand kv heads so every query head has a matching kv head.
        if self.repeats > 1:
            keys = mx.repeat(keys, self.repeats, axis=1)
            values = mx.repeat(values, self.repeats, axis=1)
        if cache is not None:
            key_cache, value_cache = cache
            # Offset RoPE by the number of cached positions so new tokens
            # receive their correct absolute position.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)
        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        # Softmax in fp32 for numerical stability, then cast back.
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
class MLP(nn.Module):
    """Gated feed-forward block: a GELU-activated gate multiplied with an
    up-projection, then projected back down to the model width."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        gate = nn.gelu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each
    with its own RMSNorm and residual connection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        # Attention sub-layer with residual.
        attn_out, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        residual = x + attn_out
        # Feed-forward sub-layer with residual.
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return residual + mlp_out, cache
class GemmaModel(nn.Module):
    """Gemma transformer backbone: scaled token embeddings, a stack of
    pre-norm blocks, and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        # Gemma scales embeddings by sqrt(hidden_size) before the blocks.
        h = self.embed_tokens(inputs) * (self.args.hidden_size**0.5)
        # A causal mask is only needed when processing more than one token;
        # single-token decode steps attend to the entire cache.
        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(
                h.shape[1]).astype(h.dtype)
        if cache is None:
            cache = [None] * len(self.layers)
        for idx, block in enumerate(self.layers):
            h, cache[idx] = block(h, mask, cache[idx])
        return self.norm(h), cache
class Model(nn.Module):
    """Gemma language model head; logits come from multiplying by the
    (tied) embedding matrix rather than a separate lm_head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = GemmaModel(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        hidden, cache = self.model(inputs, cache)
        logits = hidden @ self.model.embed_tokens.weight.T
        return logits, cache

    @property
    def layers(self):
        return self.model.layers
================================================
FILE: server/models/layers.py
================================================
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    # Compiled RMSNorm kernel: normalize by the root mean square over the
    # last axis, computed in fp32, then rescale by the learned weight.
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned scale; the
    arithmetic is delegated to the compiled `rms_norm` kernel."""

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dims,))

    def __call__(self, x):
        return rms_norm(x, self.weight, self.eps)
@partial(mx.compile, shapeless=True)
def ln_norm(x, eps, weight=None, bias=None):
    # Compiled LayerNorm kernel: standardize over the last axis in fp32,
    # then optionally apply the learned affine transform.
    orig_dtype = x.dtype
    x = x.astype(mx.float32)
    mean = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)
    x = ((x - mean) * mx.rsqrt(var + eps)).astype(orig_dtype)
    if weight is None:
        return x
    return weight * x + bias
class LayerNorm(nn.Module):
    """Layer normalization over the last axis with an optional learned
    affine (weight/bias) transform."""

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        # The affine parameters only exist when affine=True was requested.
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if "weight" not in self:
            return ln_norm(x, self.eps)
        return ln_norm(x, self.eps, self.weight, self.bias)
================================================
FILE: server/models/llama.py
================================================
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the Llama architecture."""
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    def __post_init__(self):
        # Default to standard multi-head attention when no grouped
        # key/value head count is given.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        # Validate rope scaling configuration, if any.
        if not self.rope_scaling:
            return
        required_keys = {"factor", "type"}
        if not all(key in self.rope_scaling for key in required_keys):
            raise ValueError(f"rope_scaling must contain keys {required_keys}")
        if self.rope_scaling["type"] != "linear":
            raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
    """Multi-head attention with RoPE (optionally linearly scaled),
    grouped key/value heads, and a key/value cache for decoding."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        # Number of query heads sharing each key/value head (GQA).
        self.repeats = n_heads // n_kv_heads
        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        # Linear rope scaling divides positions by the configured factor.
        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Returns (output, (keys, values)); the tuple is the updated
        cache including this step's keys/values."""
        B, L, D = x.shape
        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        # Expand kv heads so every query head has a matching kv head.
        if self.repeats > 1:
            keys = mx.repeat(keys, self.repeats, axis=1)
            values = mx.repeat(values, self.repeats, axis=1)
        if cache is not None:
            key_cache, value_cache = cache
            # Offset RoPE by the number of cached positions so new tokens
            # receive their correct absolute position.
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)
        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        # Softmax in fp32 for numerical stability, then cast back.
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
class MLP(nn.Module):
    """Gated feed-forward block: a SiLU-activated gate multiplied with an
    up-projection, then projected back down to the model width."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        gate = nn.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each
    with its own RMSNorm and residual connection."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        # Attention sub-layer with residual.
        attn_out, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        residual = x + attn_out
        # Feed-forward sub-layer with residual.
        mlp_out = self.mlp(self.post_attention_layernorm(residual))
        return residual + mlp_out, cache
class LlamaModel(nn.Module):
    """Llama transformer backbone: token embeddings, a stack of pre-norm
    blocks, and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        # A causal mask is only needed when processing more than one token;
        # single-token decode steps attend to the entire cache.
        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(
                h.shape[1]).astype(h.dtype)
        if cache is None:
            cache = [None] * len(self.layers)
        for idx, block in enumerate(self.layers):
            h, cache[idx] = block(h, mask, cache[idx])
        return self.norm(h), cache
class Model(nn.Module):
    """Llama language model: transformer backbone plus a separate
    (untied) lm_head projection to vocabulary logits."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        hidden, cache = self.model(inputs, cache)
        return self.lm_head(hidden), cache

    @staticmethod
    def sanitize(weights):
        """Drop precomputed rotary-frequency buffers; this implementation
        recomputes them via nn.RoPE."""
        return {
            k: v
            for k, v in weights.items()
            if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
================================================
FILE: server/py.typed
================================================
================================================
FILE: server/requirements.txt
================================================
chromadb==0.4.23
huggingface_hub==0.20.3
mlx==0.4.0
mlx_data==0.0.2
transformers==4.38.1
pyinstaller==6.4.0
================================================
FILE: server/retriever/document.py
================================================
from typing import Any, Literal, Optional
class Document():
    """Class for storing a piece of text and associated metadata."""

    # String text.
    page_content: str
    # Arbitrary metadata about the page content (e.g., source,
    # relationships to other documents, etc.). Note: no class-level
    # `dict()` default — a shared mutable default could leak state between
    # instances that bypass __init__.
    metadata: dict
    # Discriminator used when (de)serializing documents.
    type: Literal["Document"] = "Document"

    def __init__(self, page_content: str, metadata: Optional[dict] = None, **kwargs: Any) -> None:
        """Pass page_content in as positional or named arg.

        Any extra keyword arguments become attributes on the instance.
        """
        self.page_content = page_content
        # Always build a fresh dict per instance.
        self.metadata = metadata or dict()
        for key, value in kwargs.items():
            setattr(self, key, value)
================================================
FILE: server/retriever/embeddings.py
================================================
import os
import mlx.core as mx
import mlx.nn as nn
from transformers import PreTrainedTokenizer
from abc import ABC, abstractmethod
from typing import Any, List
from ..utils import load, get_mlx_path, convert
class Embeddings(ABC):
    """Interface for embedding models."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs: return one embedding vector per input text."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Embed query text: return a single embedding vector."""
class E5Embeddings(Embeddings):
    """Embeddings backed by an MLX port of the multilingual E5 model;
    the Hugging Face checkpoint is converted to MLX format on first use."""

    # Loaded MLX model; populated in __init__.
    model: Any = None
    # Matching Hugging Face tokenizer; populated in __init__.
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', quantize: bool = False):
        # Convert the HF checkpoint once and cache the MLX weights on disk.
        mlx_path = get_mlx_path(hf_path, quantize=quantize)
        if not os.path.isdir(mlx_path):
            convert(hf_path, mlx_path, quantize=quantize)
        self.model, self.tokenizer = load(mlx_path)

    def _average_pool(self, last_hidden_states: mx.array,
                      attention_mask: mx.array) -> mx.array:
        # Mean over non-padding positions only: zero out masked positions,
        # then divide by each sequence's real token count.
        last_hidden = mx.where(~attention_mask[..., None].astype(dtype=mx.bool_),
                               0.0, last_hidden_states)
        return mx.sum(last_hidden, axis=1) / mx.sum(attention_mask, axis=1, keepdims=True)

    def embed_documents(self, texts: List[str], batch_size: int = 8) -> List[List[float]]:
        """Embed `texts` in batches of `batch_size`; one vector per text."""
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_embeddings = self.embed_query(batch_texts, batch=True)
            embeddings.extend(batch_embeddings)
        return embeddings

    def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:
        """Embed a single string (batch=False) or a list of strings
        (batch=True); returned vectors are L2-normalized."""
        tokens = self.tokenizer(texts, max_length=512, padding=True,
                                truncation=True, return_tensors='np',
                                return_attention_mask=True)
        tokens = {key: mx.array(v) for key, v in tokens.items()}
        outputs = self.model(**tokens)
        embeddings = self._average_pool(
            outputs['last_hidden_state'], tokens['attention_mask'])
        # Normalize each embedding to unit L2 norm.
        embeddings = embeddings / \
            mx.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)
        if batch:
            return embeddings.tolist()  # -> List[List[float]]
        return embeddings[0].tolist()  # -> List[float]
class ChatEmbeddings(Embeddings):
    """Embeddings derived from the chat model's own token-embedding table."""
    model: nn.Module = None
    tokenizer: PreTrainedTokenizer = None

    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each document independently through `embed_query`."""
        vectors = []
        for document_text in texts:
            vectors.append(self.embed_query(document_text))
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Mean-pool the model's token embeddings for `text`."""
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        hidden = self.model.embed_tokens(mx.array(token_ids))
        pooled = mx.mean(hidden, axis=0)
        # normalized to have unit length
        pooled = pooled / mx.linalg.norm(pooled)
        return pooled.tolist()
================================================
FILE: server/retriever/loader.py
================================================
import os
import glob
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor
from .document import Document
def directory_loader(directory: Optional[str] = None) -> Optional[List[Document]]:
    """Recursively load all supported text files under `directory` as Documents.

    Args:
        directory: Root directory to scan.

    Returns:
        One Document per readable file whose extension is in the allow-list;
        each Document's metadata carries the source path.

    Raises:
        FileNotFoundError: If `directory` is None or does not exist.
    """
    if directory is None or not os.path.exists(directory):
        raise FileNotFoundError(f"Directory '{directory}' does not exist.")

    # Set membership instead of a list scan; extensions compared lowercased.
    allowed_extensions = {'.txt', '.md', '.csv', '.json', '.xml', '.ts'}

    def read_file(file_path):
        _, file_extension = os.path.splitext(file_path)
        if file_extension.lower() not in allowed_extensions:
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return Document(page_content=file.read(), metadata={'source': file_path})
        except (OSError, UnicodeDecodeError):
            # Fix: a single unreadable or non-UTF-8 file previously raised and
            # aborted the entire directory load; skip it instead.
            return None

    # '*.*' matches only names containing a dot (parity with the original).
    files = glob.glob(os.path.join(directory, '**', '*.*'), recursive=True)
    with ThreadPoolExecutor() as executor:
        # `filter(None, ...)` drops skipped (None) entries.
        return list(filter(None, executor.map(read_file, files)))
================================================
FILE: server/retriever/splitter.py
================================================
import re
import copy
from abc import ABC, abstractmethod
from typing import (
Any,
List,
Optional,
Callable,
Iterable
)
from .document import Document
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1]
for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
class TextSplitter(ABC):
"""Interface for splitting text into chunks."""
def __init__(
self,
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
) -> None:
"""Create a new TextSplitter.
Args:
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
length_function: Function that measures the length of given chunks
keep_separator: Whether to keep the separator in the chunks
add_start_index: If `True`, includes chunk's start index in metadata
strip_whitespace: If `True`, strips whitespace from the start and end of
every document
"""
if chunk_overlap > chunk_size:
raise ValueError(f"Got a larger chunk overlap ({
chunk_overlap}) than chunk size ({chunk_size}), should be smaller.")
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator
self._add_start_index = add_start_index
self._strip_whitespace = strip_whitespace
@abstractmethod
def split_text(self, text: str) -> List[str]:
"""Split text into multiple components."""
def create_documents(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
index = 0
previous_chunk_len = 0
for chunk in self.split_text(text):
metadata = copy.deepcopy(_metadatas[i])
if self._add_start_index:
offset = index + previous_chunk_len - self._chunk_overlap
index = text.find(chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(chunk)
new_doc = Document(page_content=chunk, metadata=metadata)
documents.append(new_doc)
return documents
def split_documents(self, documents: Iterable[Document]) -> List[Document]:
"""Split documents."""
texts, metadatas = [], []
for doc in documents:
texts.append(doc.page_content)
metadatas.append(doc.metadata)
return self.create_documents(texts, metadatas=metadatas)
def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
text = separator.join(docs)
if self._strip_whitespace:
text = text.strip()
if text == "":
return None
else:
return text
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
docs = []
current_doc: List[str] = []
total = 0
for d in splits:
_len = self._length_function(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
):
if total > self._chunk_size:
print(f"Created a chunk of size {total}, " +
f"which is longer than the specified {self._chunk_size}")
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
# Keep on popping if:
# - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long
while total > self._chunk_overlap or (
total + _len +
(separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
return docs
class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            separators: Ordered candidate separators, coarsest first; defaults
                to paragraph, line, word, then character boundaries.
            keep_separator: Whether chunks retain the separator text.
            is_separator_regex: If False, separators are escaped before use.
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use: the first candidate actually
        # present in `text`; the remaining (finer) ones are kept for recursion.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break
        _separator = separator if self._is_separator_regex else re.escape(
            separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        # When separators are kept in the chunks, join with "" to avoid
        # duplicating them; otherwise re-insert the separator on join.
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                # Flush accumulated small pieces before handling the long one.
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # No finer separator remains: emit the oversized piece as-is.
                    final_chunks.append(s)
                else:
                    # Recurse with the finer separators.
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        """Split `text` using the configured separator hierarchy."""
        return self._split_text(text, self._separators)
================================================
FILE: server/retriever/vectorstore.py
================================================
import uuid
import functools
import mlx.core as mx
import chromadb
import chromadb.config
from chromadb.utils.batch_utils import create_batches
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
from typing import (
Any,
List,
Dict,
TypeVar,
Callable,
Iterable,
Optional,
Tuple,
Type,
)
from .document import Document
from .embeddings import Embeddings
# NOTE(review): this self-bound TypeVar exists only so the classmethods below
# can annotate `cls: Type[Chroma]`; the name is shadowed by the `Chroma` class
# defined later in this module.
Chroma = TypeVar('Chroma', bound='Chroma')
DEFAULT_K = 4  # Number of Documents to return.
def _results_to_docs(results: Any) -> List[Document]:
    """Convert a raw Chroma query result into Documents, dropping the scores."""
    return [pair[0] for pair in _results_to_docs_and_scores(results)]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Convert a raw Chroma query result into (Document, distance) pairs."""
    # TODO: Chroma can do batch querying,
    # we shouldn't hard code to the 1st result
    pairs = []
    for content, metadata, distance in zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    ):
        pairs.append(
            (Document(page_content=content, metadata=metadata or {}), distance))
    return pairs
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive.

    Each group in `arg_groups` must have exactly one of its names passed as a
    non-None keyword argument; otherwise a ValueError is raised.
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Validate exactly one arg in each group is not None."""
            bad_positions = []
            for position, group in enumerate(arg_groups):
                provided = [name for name in group
                            if kwargs.get(name) is not None]
                if len(provided) != 1:
                    bad_positions.append(position)
            if bad_positions:
                offending = [", ".join(arg_groups[i]) for i in bad_positions]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(offending)}"
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator
def cosine_similarity(
    X: mx.array, T: mx.array, axis: int = 1
) -> mx.array:
    """Row-wise cosine similarity between two equal-width matrices."""
    X, T = mx.array(X), mx.array(T)
    # Dot products normalized by the outer product of the row norms.
    norms_x = mx.linalg.norm(X, axis=axis)
    norms_t = mx.linalg.norm(T, axis=axis)
    return (X @ T.T) / mx.outer(norms_x, norms_t)
def maximal_marginal_relevance(
    query_embedding: mx.array,
    embedding_list: mx.array,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance.

    Greedily selects up to `k` row indices from `embedding_list`, trading off
    similarity to `query_embedding` against redundancy with rows already
    selected. `lambda_mult` weights relevance (1.0) vs. diversity (0.0).

    Returns:
        Indices into `embedding_list`, in selection order.
    """
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        # Promote to shape (1, dim) so cosine_similarity receives a matrix.
        query_embedding = mx.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(mx.argmax(similarity_to_query).tolist())
    # Seed the selection with the single most query-similar row.
    idxs = [most_similar]
    selected = mx.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -mx.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            # Redundancy = highest similarity to anything already selected.
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        # Slice keeps a 2-D row so concatenate stacks along axis 0.
        selected = mx.concatenate([
            selected, embedding_list[idx_to_add:idx_to_add+1]], axis=0)
    return idxs
class Chroma():
    """Vector store backed by a single chromadb collection.

    Main retrieval entry points:
        similarity_search
        max_marginal_relevance_search
    """

    _DEFAULT_COLLECTION_NAME = "mlx-chat-app"

    def __init__(
        self,
        collection_name: str = _DEFAULT_COLLECTION_NAME,
        embedding_function: Optional[Embeddings] = None,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        collection_metadata: Optional[Dict] = None,
        client: Optional[chromadb.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with a supplied chromadb client or create a new one.

        Args:
            collection_name: Name of the collection to get or create.
            embedding_function: Client-side embeddings for inserts/queries.
                If None, queries go to chromadb by raw text.
            persist_directory: Directory for on-disk persistence. Optional.
            client_settings: Explicit chromadb settings. Optional.
            collection_metadata: Collection configuration. Optional.
            client: Pre-built chromadb client to reuse. Optional.
            relevance_score_fn: Override for relevance score normalization.
        """
        if client is not None:
            # Caller supplied a ready client; trust its configuration as-is.
            self._client_settings = client_settings
            self._client = client
            self._persist_directory = persist_directory
        else:
            if client_settings:
                # If client_settings is provided with persist_directory specified,
                # then it is "in-memory and persisting to disk" mode.
                client_settings.persist_directory = (
                    persist_directory or client_settings.persist_directory
                )
                if client_settings.persist_directory is not None:
                    # Maintain backwards compatibility with chromadb < 0.4.0
                    major, minor, _ = chromadb.__version__.split(".")
                    if int(major) == 0 and int(minor) < 4:
                        client_settings.chroma_db_impl = "duckdb+parquet"
                _client_settings = client_settings
            elif persist_directory:
                # Maintain backwards compatibility with chromadb < 0.4.0
                major, minor, _ = chromadb.__version__.split(".")
                if int(major) == 0 and int(minor) < 4:
                    _client_settings = chromadb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                    )
                else:
                    _client_settings = chromadb.config.Settings(
                        is_persistent=True)
                _client_settings.persist_directory = persist_directory
            else:
                # Purely in-memory, ephemeral store.
                _client_settings = chromadb.config.Settings()
            self._client_settings = _client_settings
            self._client = chromadb.Client(_client_settings)
            self._persist_directory = (
                _client_settings.persist_directory or persist_directory
            )
        self._embedding_function = embedding_function
        # embedding_function=None here: embeddings are computed client-side by
        # this wrapper (or by chromadb itself when querying by text).
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=None,
            metadata=collection_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """The client-side embedding function, if one was provided."""
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the chroma collection.

        Exactly one of `query_texts` / `query_embeddings` must be provided
        (enforced by `xor_args`). Returns chromadb's raw result dict — the
        original `List[Document]` annotation was incorrect.
        """
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        # TODO: Handle the case where the user doesn't provide ids on the Collection
        if ids is None:
            # NOTE(review): uuid1 embeds host/time information; kept for
            # compatibility with previously generated IDs.
            ids = [str(uuid.uuid1()) for _ in texts]
        embeddings = None
        texts = list(texts)
        if self._embedding_function is not None:
            embeddings = self._embedding_function.embed_documents(texts)
        if metadatas:
            # fill metadatas with empty dicts if somebody
            # did not specify metadata for all texts
            length_diff = len(texts) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            # chromadb rejects empty metadata dicts, so rows with and without
            # metadata are upserted separately.
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx]
                     for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                try:
                    self._collection.upsert(
                        metadatas=metadatas,
                        embeddings=embeddings_with_metadatas,
                        documents=texts_with_metadatas,
                        ids=ids_with_metadata,
                    )
                except ValueError as e:
                    if "Expected metadata value to be" in str(e):
                        msg = (
                            "Try filtering complex metadata from the document using "
                            "langchain_community.vectorstores.utils.filter_complex_metadata."
                        )
                        raise ValueError(e.args[0] + "\n\n" + msg)
                    else:
                        raise e
            if empty_ids:
                texts_without_metadatas = [texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=texts_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=texts,
                ids=ids,
            )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with Chroma.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Document]: List of documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(
            query, k, filter=filter, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run similarity search with Chroma with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List[Tuple[Document, float]]: List of documents most similar to
            the query text and cosine distance in float for each.
            Lower score represents more similarity.
        """
        if self._embedding_function is None:
            # No client-side embeddings: let chromadb embed the raw text.
            results = self.__query_collection(
                query_texts=[query],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        else:
            query_embedding = self._embedding_function.embed_query(query)
            results = self.__query_collection(
                query_embeddings=[query_embedding],
                n_results=k,
                where=filter,
                where_document=where_document,
                **kwargs,
            )
        return _results_to_docs_and_scores(results)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        # NOTE(review): a flat List[float] is passed where the other call site
        # wraps in a list; chromadb appears to accept both — verify against
        # the pinned chromadb version.
        results = self.__query_collection(
            query_embeddings=embedding,
            n_results=fetch_k,
            where=filter,
            where_document=where_document,
            include=["metadatas", "documents", "distances", "embeddings"],
            **kwargs,
        )
        mmr_selected = maximal_marginal_relevance(
            mx.array(embedding, dtype=mx.float32),
            mx.array(results["embeddings"][0]),
            k=k,
            lambda_mult=lambda_mult,
        )
        candidates = _results_to_docs(results)
        selected_results = [r for i, r in enumerate(
            candidates) if i in mmr_selected]
        return selected_results

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: If no embedding function was provided at creation.
        """
        if self._embedding_function is None:
            # Fixed: the adjacent string literals previously concatenated to
            # "...embedding function oncreation." (missing space).
            raise ValueError(
                "For MMR search, you must specify an embedding function on "
                "creation."
            )
        embedding = self._embedding_function.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            where_document=where_document,
        )
        return docs

    def delete_collection(self) -> None:
        """Delete the collection."""
        self._client.delete_collection(self._collection.name)

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Gets the collection.

        Args:
            ids: The ids of the embeddings to get. Optional.
            where: A Where type dict used to filter results by.
                E.g. `{"color" : "red", "price": 4.20}`. Optional.
            limit: The number of documents to return. Optional.
            offset: The offset to start returning results from.
                Useful for paging results with limit. Optional.
            where_document: A WhereDocument type dict used to filter by the documents.
                E.g. `{$contains: "hello"}`. Optional.
            include: A list of what to include in the results.
                Can contain `"embeddings"`, `"metadatas"`, `"documents"`.
                Ids are always included.
                Defaults to `["metadatas", "documents"]`. Optional.
        """
        kwargs = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }
        # Only forward `include` when set so chromadb applies its default.
        if include is not None:
            kwargs["include"] = include
        return self._collection.get(**kwargs)

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the collection.

        Args:
            document_id (str): ID of the document to update.
            document (Document): Document to update.
        """
        return self.update_documents([document_id], [document])

    def update_documents(self, ids: List[str], documents: List[Document]) -> None:
        """Update a document in the collection.

        Args:
            ids (List[str]): List of ids of the document to update.
            documents (List[Document]): List of documents to update.

        Raises:
            ValueError: If no embedding function was provided at creation.
        """
        text = [document.page_content for document in documents]
        metadata = [document.metadata for document in documents]
        if self._embedding_function is None:
            raise ValueError(
                "For update, you must specify an embedding function on creation."
            )
        embeddings = self._embedding_function.embed_documents(text)
        # NOTE(review): `_client` is a private chromadb attribute; revisit when
        # upgrading chromadb.
        if hasattr(
            self._collection._client, "max_batch_size"
        ):
            # Respect the client's batch-size limit.
            for batch in create_batches(
                api=self._collection._client,
                ids=ids,
                metadatas=metadata,
                documents=text,
                embeddings=embeddings,
            ):
                self._collection.update(
                    ids=batch[0],
                    embeddings=batch[1],
                    documents=batch[3],
                    metadatas=batch[2],
                )
        else:
            self._collection.update(
                ids=ids,
                embeddings=embeddings,
                documents=text,
                metadatas=metadata,
            )

    @classmethod
    def from_texts(
        cls: Type[Chroma],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a raw documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            texts (List[str]): List of texts to add to the collection.
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            client_settings (Optional[chromadb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        chroma_collection = cls(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]
        if hasattr(
            chroma_collection._client, "max_batch_size"
        ):
            # Respect the client's batch-size limit when adding.
            for batch in create_batches(
                api=chroma_collection._client,
                ids=ids,
                metadatas=metadatas,
                documents=texts,
            ):
                chroma_collection.add_texts(
                    texts=batch[3] if batch[3] else [],
                    metadatas=batch[2] if batch[2] else None,
                    ids=batch[0],
                )
        else:
            chroma_collection.add_texts(
                texts=texts, metadatas=metadatas, ids=ids)
        return chroma_collection

    @classmethod
    def from_documents(
        cls: Type[Chroma],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        collection_name: str = _DEFAULT_COLLECTION_NAME,
        persist_directory: Optional[str] = None,
        client_settings: Optional[chromadb.config.Settings] = None,
        client: Optional[chromadb.Client] = None,
        collection_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Chroma:
        """Create a Chroma vectorstore from a list of documents.

        If a persist_directory is specified, the collection will be persisted there.
        Otherwise, the data will be ephemeral in-memory.

        Args:
            collection_name (str): Name of the collection to create.
            persist_directory (Optional[str]): Directory to persist the collection.
            ids (Optional[List[str]]): List of document IDs. Defaults to None.
            documents (List[Document]): List of documents to add to the vectorstore.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            client_settings (Optional[chromadb.config.Settings]): Chroma client settings
            collection_metadata (Optional[Dict]): Collection configurations.
                Defaults to None.

        Returns:
            Chroma: Chroma vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            persist_directory=persist_directory,
            client_settings=client_settings,
            client=client,
            collection_metadata=collection_metadata,
            **kwargs,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.
        """
        self._collection.delete(ids=ids)
================================================
FILE: server/server.py
================================================
import os
import sys
import json
import time
import uuid
import mlx.core as mx
import mlx.nn as nn
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List, Dict, Optional
from transformers import PreTrainedTokenizer
from .utils import load, generate_step, get_mlx_path, convert
from .retriever.loader import directory_loader
from .retriever.splitter import RecursiveCharacterTextSplitter
from .retriever.vectorstore import Chroma
from .retriever.embeddings import ChatEmbeddings, E5Embeddings
# Process-wide singletons, populated by load_model() / index_directory() below.
_model: Optional[nn.Module] = None  # chat model used for generation
_tokenizer: Optional[PreTrainedTokenizer] = None  # tokenizer paired with _model
_database: Optional[Chroma] = None  # vector store over the indexed directory
def load_model(model_path: str, adapter_file: Optional[str] = None):
    """Load the chat model and tokenizer into the module globals,
    converting the checkpoint to MLX format on first use."""
    global _model
    global _tokenizer
    # The large well-known families are converted with quantization enabled.
    quantized_families = ('mistral', 'llama', 'gemma')
    quantize = any(family in model_path for family in quantized_families)
    mlx_path = get_mlx_path(model_path, quantize=quantize)
    # Convert once; later launches reuse the cached converted weights.
    if not os.path.isdir(mlx_path):
        convert(model_path, mlx_path, quantize=quantize)
    _model, _tokenizer = load(mlx_path, adapter_file=adapter_file)
def index_directory(directory: str, use_embedding: bool = True):
    """Build the global vector store from every supported file in `directory`."""
    global _database
    started = time.time()
    documents = directory_loader(directory)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=512, chunk_overlap=32, add_start_index=True
    )
    # Prefer the dedicated E5 embedding model; otherwise fall back to the
    # chat model's own token-embedding table.
    if use_embedding:
        embedding = E5Embeddings(quantize=True)
    else:
        embedding = ChatEmbeddings(model=_model.model, tokenizer=_tokenizer)
    chunks = splitter.split_documents(documents)
    _database = Chroma.from_documents(
        documents=chunks,
        embedding=embedding,
    )
    print(f'>> indexed {len(chunks)} documents in',
          f'{time.time() - started:.2f}s', flush=True)
def create_response(chat_id, prompt, tokens, text):
    """Assemble an OpenAI-style chat.completion payload for one reply."""
    prompt_len = len(prompt)
    completion_len = len(tokens)
    choice = {
        'index': 0,
        'message': {
            'role': 'assistant',
            'content': text,
        },
        'logprobs': None,
        'finish_reason': None,
    }
    return {
        'id': chat_id,
        'object': 'chat.completion',
        'created': int(time.time()),
        'model': _model.model_type,
        'system_fingerprint': f'fp_{uuid.uuid4()}',
        'choices': [choice],
        'usage': {
            'prompt_tokens': prompt_len,
            'completion_tokens': completion_len,
            'total_tokens': prompt_len + completion_len,
        },
    }
def format_messages(messages: List[Dict], indexed_files: Optional[str], instructions: Optional[Dict]):
    """Rewrite the last message's content in place into the structured prompt.

    Args:
        messages: Chat history; the final entry's 'content' is replaced.
        indexed_files: Retrieved document text used as background knowledge.
        instructions: Optional dict with 'personalization' and 'response' hints.
    """
    # Fix: the request body may omit instructions entirely (the caller passes
    # body.get('instructions', None)); previously this raised AttributeError.
    instructions = instructions or {}
    personalization = instructions.get(
        'personalization', '').strip().replace('\n', '; ')
    response = instructions.get('response', '').strip().replace('\n', '; ')
    # Fix: build the context outside the f-string — the original nested a
    # multi-line expression with same-quote strings inside an f-string, which
    # is only valid syntax on Python 3.12+.
    if indexed_files:
        knowledge = indexed_files.strip().replace('\n', '; ')
        context = f"with background knowledge of {knowledge}"
    else:
        context = ''
    audience = personalization if personalization else 'general'
    style = response if response else 'technical, accurate, and professional'
    messages[-1]['content'] = f"""
<Context>
you are my personalized AI chatbot {context}
</Context>
<Objective>
respond to the following: {messages[-1]['content']}
</Objective>
<Style>
{style}
</Style>
<Tone>
friendly, helpful, and confident
</Tone>
<Audience>
{audience}
</Audience>
<Response>
brief, concise, and to the point. Please don't start with "Sure, ..."
</Response>
""".strip()
class APIHandler(BaseHTTPRequestHandler):
    # JSON POST API for the Electron front-end:
    #   /api/init  - load the model, /api/index - index a directory,
    #   /api/query - answer a chat request.

    def _set_headers(self, status_code=200):
        # Status line plus permissive CORS headers (the renderer runs on a
        # different origin than this local server).
        self.send_response(status_code)
        self.send_header('Content-type', 'application/json')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.end_headers()

    def do_OPTIONS(self):
        # CORS preflight: headers only, no body.
        self._set_headers(204)

    def do_POST(self):
        """
        Endpoint: /api/index
        Desc: indexes the directory
        Body:
        {
            directory: str
        }

        Endpoint: /api/init
        Desc: initializes the model
        Body:
        {
            model: str
        }

        Endpoint: /api/query
        Desc: handles messages requests (with directory index)
        Body:
        {
            messages: [ { role: str, content: str } ],
            max_tokens: int,
            repetition_penalty: float,
            repetition_context_size: int,
            temperature: float,
            top_p: float,
            instructions: {
                personalization: str,
                response: str
            },
            directory: str
        }
        """
        try:
            post_data = self.rfile.read(int(self.headers['Content-Length']))
            body = json.loads(post_data.decode('utf-8'))
            # Route on the raw request path; unknown paths get a 404.
            method = {
                '/api/index': self.index,
                '/api/query': self.query,
                '/api/init': self.init,
            }
            handle = method.get(self.path, None)
            if handle is None:
                self._set_headers(404)
                self.wfile.write(b'Not Found')
                return
            response = handle(body)
            self._set_headers(200)
            self.wfile.write(json.dumps(response).encode('utf-8'))
        except Exception as e:
            # Any handler failure becomes a 500 with the error text in JSON.
            print(f"Error: {e}", flush=True)
            self._set_headers(500)
            self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8'))

    def index(self, body):
        # Build the global vector store for the requested directory.
        directory = body.get('directory', None)
        index_directory(directory)
        return {'directory': directory}

    def init(self, body):
        # Load (converting if necessary) the requested model.
        model = body.get('model', None)
        load_model(model)
        return {'model': model}

    def query(self, body):
        # Answer a chat request, optionally augmented with retrieved context.
        chat_id = f'chatcmpl-{uuid.uuid4()}'
        directory = body.get('directory', None)
        messages = body.get('messages', [])
        instructions = body.get('instructions', None)
        indexed_files = ''
        if directory:
            # emperically better than `similarity_search`
            docs = _database.max_marginal_relevance_search(
                messages[-1]['content'],
                k=6  # number of documents to return
            )
            indexed_files = '\n'.join([doc.page_content for doc in docs])
            print(body, flush=True)
            print(('\n'+'--'*10+'\n').join([
                f'{doc.metadata}\n{doc.page_content}' for doc in docs]), flush=True)
        # Mutates messages[-1]['content'] into the structured prompt template.
        format_messages(messages, indexed_files, instructions)
        print(messages, flush=True)
        # Render the chat template to text, then tokenize for generation.
        prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        ), add_special_tokens=True))
        max_tokens = body.get('max_tokens', 100)
        repetition_penalty = body.get('repetition_penalty', None)
        repetition_context_size = body.get('repetition_context_size', 20)
        temperature = body.get('temperature', 1.0)
        top_p = body.get('top_p', 1.0)
        tokens = []
        REPLACEMENT_CHAR = '\ufffd'
        # zip() with range() caps generation at max_tokens steps.
        for (token, prob), _ in zip(
            generate_step(
                prompt,
                _model,
                temperature,
                repetition_penalty,
                repetition_context_size,
                top_p,
            ),
            range(max_tokens),
        ):
            if token == _tokenizer.eos_token_id:
                break
            tokens.append(token.item())
        text = _tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')
        # TODO: GEMMA IS OBSESSED WITH "Sure, ..."
        if text.startswith('Sure, '):
            # Drop the leading "Sure, " and re-capitalize the first line.
            text = text.split('\n')
            text[0] = text[0].replace('Sure, ', '').capitalize()
            text = '\n'.join([l for l in text])
        return create_response(chat_id, prompt, tokens, text)
def run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):
    """Build the HTTP server for (host, port) and serve requests forever.

    `server_class` and `handler_class` are injectable for testing;
    by default this wires up `HTTPServer` with `APIHandler`.
    """
    httpd = server_class((host, port), handler_class)
    print(f'Starting httpd at {host} on port {port}...', flush=True)
    httpd.serve_forever()
def main():
if len(sys.argv) < 2:
print(
"Usage: python script.py [--host <host_address>] [--port <port_number>]")
sys.exit(1)
args = {
'--host': '127.0.0.1',
'--port': 8080
}
i = 1
while i < len(sys.argv):
if sys.argv[i] in args:
args[sys.argv[i]] = sys.argv[i + 1]
i += 2
else:
print(f"Unknown argument: {sys.argv[i]}")
sys.exit(1)
# Now you can access the parsed arguments using args dictionary
host = args['--host']
port = int(args['--port'])
print(f'>> starting server on {host}:{port}', flush=True)
run(host, port)
if __name__ == '__main__':
main()
================================================
FILE: server/utils.py
================================================
import os
import copy
import glob
import shutil
import importlib
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
# Constants
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
}
MAX_FILE_SIZE_GB = 5
linear_class_predicate = (
lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0]
!= 8 # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"server.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
return arch.Model, arch.ModelArgs
def get_model_path(path_or_hf_repo: str) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

    Returns:
        Path: The path to the model.
    """
    local_path = Path(path_or_hf_repo)
    if local_path.exists():
        return local_path
    # Not a local path: treat it as a hub repo ID and fetch only the
    # files needed to load the model (configs, weights, tokenizer).
    downloaded = snapshot_download(
        repo_id=path_or_hf_repo,
        allow_patterns=[
            "*.json",
            "*.safetensors",
            "*.py",
            "tokenizer.model",
            "*.tiktoken",
        ],
    )
    return Path(downloaded)
def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
    """
    Apply repetition penalty to specific logits based on the given context.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        logits (mx.array): The logits produced by the language model,
            indexed as ``logits[:, token_ids]`` (vocab on the last axis).
        generated_tokens (any): A list of N previous token ids.
        penalty (float): The repetition penalty factor to be applied.

    Returns:
        logits (mx.array): Logits with repetition penalty applied to
            generated tokens. The input array is updated in place via
            item assignment and also returned.
    """
    if len(generated_tokens) > 0:
        # list() replaces the former identity comprehension
        # `[token for token in generated_tokens]`.
        indices = mx.array(list(generated_tokens))
        selected_logits = logits[:, indices]
        # CTRL-style penalty: push negative logits further down
        # (multiply) and positive logits down toward zero (divide).
        selected_logits = mx.where(
            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
        )
        logits[:, indices] = selected_logits
    return logits
def generate_step(
prompt: mx.array,
model: nn.Module,
temp: 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing
one token and probability per call.
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
softmax_logits = mx.softmax(logits)
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
if (
logits.dtype == mx.bfloat16
): # workdaround for unable to load kernel contiguous_scan_inclusive_sum_bfl
gitextract_9znslswk/
├── .github/
│ └── workflows/
│ └── lint.yml
├── .gitignore
├── .vscode/
│ └── settings.json
├── CODEOWNERS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app/
│ ├── .eslintrc.cjs
│ ├── assets/
│ │ └── icon.icns
│ ├── components.json
│ ├── dprint.json
│ ├── mac/
│ │ └── entitlements.mac.inherit.plist
│ ├── main/
│ │ ├── main.ts
│ │ ├── preload.ts
│ │ ├── renderer.d.ts
│ │ ├── splash/
│ │ │ ├── index.css
│ │ │ └── index.html
│ │ └── tsconfig.json
│ ├── next.config.js
│ ├── notarize.js
│ ├── package.json
│ ├── postcss.config.js
│ ├── src/
│ │ ├── AppProvider.tsx
│ │ ├── app/
│ │ │ ├── globals.css
│ │ │ ├── layout.tsx
│ │ │ ├── page.tsx
│ │ │ └── settings/
│ │ │ └── page.tsx
│ │ ├── components/
│ │ │ ├── chat/
│ │ │ │ ├── Chat.tsx
│ │ │ │ ├── ChatInput.tsx
│ │ │ │ ├── ChatMessage.tsx
│ │ │ │ ├── ChatMessages.tsx
│ │ │ │ └── SystemMessage.tsx
│ │ │ ├── options/
│ │ │ │ ├── SelectDirectory.tsx
│ │ │ │ └── SelectModel.tsx
│ │ │ └── ui/
│ │ │ ├── button.tsx
│ │ │ ├── input.tsx
│ │ │ ├── resizable.tsx
│ │ │ ├── select.tsx
│ │ │ ├── textarea.tsx
│ │ │ └── tooltip.tsx
│ │ ├── constants/
│ │ │ └── chat.tsx
│ │ └── lib/
│ │ ├── hooks.ts
│ │ ├── store.ts
│ │ └── utils.ts
│ ├── tailwind.config.main.js
│ ├── tailwind.config.ts
│ └── tsconfig.json
├── runner.py
├── runner.sh
└── server/
├── __init__.py
├── convert.py
├── models/
│ ├── __init__.py
│ ├── base.py
│ ├── bert.py
│ ├── gemma.py
│ ├── layers.py
│ └── llama.py
├── py.typed
├── requirements.txt
├── retriever/
│ ├── document.py
│ ├── embeddings.py
│ ├── loader.py
│ ├── splitter.py
│ └── vectorstore.py
├── server.py
└── utils.py
SYMBOL INDEX (197 symbols across 26 files)
FILE: app/main/main.ts
function handleSetTitle (line 25) | function handleSetTitle(event: any, title: string) {
class ServerManager (line 34) | class ServerManager {
method findOpenPort (line 38) | private findOpenPort(startingPort: number): Promise<number> {
method runPythonServer (line 54) | private runPythonServer(port: number): any {
method start (line 74) | start(model: string): Promise<void> {
method stop (line 119) | stop(): void {
method click (line 343) | click() {
FILE: app/main/renderer.d.ts
type Window (line 4) | interface Window {
FILE: app/src/AppProvider.tsx
function StoreProvider (line 16) | function StoreProvider({
FILE: app/src/app/layout.tsx
function RootLayout (line 16) | function RootLayout({
FILE: app/src/app/page.tsx
function Home (line 35) | function Home() {
FILE: app/src/app/settings/page.tsx
type SETTINGS (line 28) | enum SETTINGS {
function SettingsOption (line 33) | function SettingsOption({
function GeneralSettings (line 71) | function GeneralSettings() {
function PromptSettings (line 128) | function PromptSettings() {
function Settings (line 176) | function Settings() {
FILE: app/src/components/ui/button.tsx
type ButtonProps (line 40) | interface ButtonProps
FILE: app/src/components/ui/input.tsx
type InputProps (line 6) | interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {}
FILE: app/src/components/ui/textarea.tsx
type TextareaProps (line 6) | interface TextareaProps extends React.TextareaHTMLAttributes<HTMLTextAre...
FILE: app/src/constants/chat.tsx
type ChatMessage (line 1) | type ChatMessage = {
FILE: app/src/lib/hooks.ts
function usePrevious (line 25) | function usePrevious<T>(value: T): T | undefined {
function convertToNiceShortcut (line 33) | function convertToNiceShortcut(shortcut: string) {
function useKeyboardShortcut (line 40) | function useKeyboardShortcut() {
FILE: app/src/lib/store.ts
type AppStore (line 45) | type AppStore = ReturnType<typeof makeStore>;
type RootState (line 47) | type RootState = ReturnType<AppStore['getState']>;
type AppDispatch (line 48) | type AppDispatch = AppStore['dispatch'];
FILE: app/src/lib/utils.ts
function cn (line 9) | function cn(...inputs: ClassValue[]) {
FILE: server/convert.py
function configure_parser (line 6) | def configure_parser() -> argparse.ArgumentParser:
FILE: server/models/base.py
class BaseModelArgs (line 6) | class BaseModelArgs:
method from_dict (line 8) | def from_dict(cls, params):
FILE: server/models/bert.py
class ModelArgs (line 13) | class ModelArgs(BaseModelArgs):
function apply_chunking_to_forward (line 40) | def apply_chunking_to_forward(
class BertEmbeddings (line 122) | class BertEmbeddings(nn.Module):
method __init__ (line 125) | def __init__(self, config):
method __call__ (line 144) | def __call__(
class BertSelfAttention (line 191) | class BertSelfAttention(nn.Module):
method __init__ (line 192) | def __init__(self, config, position_embedding_type=None):
method transpose_for_scores (line 220) | def transpose_for_scores(self, x: mx.array) -> mx.array:
method __call__ (line 226) | def __call__(
class BertSelfOutput (line 342) | class BertSelfOutput(nn.Module):
method __init__ (line 343) | def __init__(self, config):
method __call__ (line 350) | def __call__(self, hidden_states: mx.array, input_tensor: mx.array) ->...
class BertAttention (line 357) | class BertAttention(nn.Module):
method __init__ (line 358) | def __init__(self, config, position_embedding_type=None):
method __call__ (line 365) | def __call__(
class BertIntermediate (line 390) | class BertIntermediate(nn.Module):
method __init__ (line 391) | def __init__(self, config):
method __call__ (line 396) | def __call__(self, hidden_states: mx.array) -> mx.array:
class BertOutput (line 402) | class BertOutput(nn.Module):
method __init__ (line 403) | def __init__(self, config):
method __call__ (line 410) | def __call__(self, hidden_states: mx.array, input_tensor: mx.array) ->...
class BertLayer (line 417) | class BertLayer(nn.Module):
method __init__ (line 418) | def __init__(self, config):
method __call__ (line 434) | def __call__(
method feed_forward_chunk (line 504) | def feed_forward_chunk(self, attention_output):
class BertEncoder (line 510) | class BertEncoder(nn.Module):
method __init__ (line 511) | def __init__(self, config):
method __call__ (line 519) | def __call__(
class BertPooler (line 603) | class BertPooler(nn.Module):
method __init__ (line 604) | def __init__(self, config):
method __call__ (line 609) | def __call__(self, hidden_states: mx.array) -> mx.array:
class BertModel (line 618) | class BertModel(nn.Module):
method __init__ (line 631) | def __init__(self, config: ModelArgs, add_pooling_layer=True):
method get_input_embeddings (line 639) | def get_input_embeddings(self):
method set_input_embeddings (line 642) | def set_input_embeddings(self, value):
method get_extended_attention_mask (line 645) | def get_extended_attention_mask(
method get_head_mask (line 690) | def get_head_mask(
method __call__ (line 718) | def __call__(
class Model (line 847) | class Model(nn.Module):
method __init__ (line 848) | def __init__(self, args: ModelArgs):
method __call__ (line 853) | def __call__(
method sanitize (line 861) | def sanitize(weights):
method layers (line 868) | def layers(self):
FILE: server/models/gemma.py
class ModelArgs (line 12) | class ModelArgs(BaseModelArgs):
function rms_norm (line 27) | def rms_norm(x, weight, eps):
class RMSNorm (line 33) | class RMSNorm(nn.Module):
method __init__ (line 34) | def __init__(self, dims: int, eps: float = 1e-5):
method __call__ (line 39) | def __call__(self, x):
class Attention (line 43) | class Attention(nn.Module):
method __init__ (line 44) | def __init__(self, args: ModelArgs):
method __call__ (line 67) | def __call__(
class MLP (line 104) | class MLP(nn.Module):
method __init__ (line 105) | def __init__(self, dim, hidden_dim):
method __call__ (line 111) | def __call__(self, x) -> mx.array:
class TransformerBlock (line 115) | class TransformerBlock(nn.Module):
method __init__ (line 116) | def __init__(self, args: ModelArgs):
method __call__ (line 126) | def __call__(
class GemmaModel (line 139) | class GemmaModel(nn.Module):
method __init__ (line 140) | def __init__(self, args: ModelArgs):
method __call__ (line 152) | def __call__(
class Model (line 174) | class Model(nn.Module):
method __init__ (line 175) | def __init__(self, args: ModelArgs):
method __call__ (line 180) | def __call__(
method layers (line 190) | def layers(self):
FILE: server/models/layers.py
function rms_norm (line 8) | def rms_norm(x, weight, eps):
class RMSNorm (line 14) | class RMSNorm(nn.Module):
method __init__ (line 15) | def __init__(self, dims: int, eps: float = 1e-5):
method __call__ (line 20) | def __call__(self, x):
function ln_norm (line 25) | def ln_norm(x, eps, weight=None, bias=None):
class LayerNorm (line 35) | class LayerNorm(nn.Module):
method __init__ (line 36) | def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
method _extra_repr (line 44) | def _extra_repr(self):
method __call__ (line 47) | def __call__(self, x: mx.array) -> mx.array:
FILE: server/models/llama.py
class ModelArgs (line 12) | class ModelArgs(BaseModelArgs):
method __post_init__ (line 25) | def __post_init__(self):
class Attention (line 38) | class Attention(nn.Module):
method __init__ (line 39) | def __init__(self, args: ModelArgs):
method __call__ (line 68) | def __call__(
class MLP (line 105) | class MLP(nn.Module):
method __init__ (line 106) | def __init__(self, dim, hidden_dim):
method __call__ (line 112) | def __call__(self, x) -> mx.array:
class TransformerBlock (line 116) | class TransformerBlock(nn.Module):
method __init__ (line 117) | def __init__(self, args: ModelArgs):
method __call__ (line 127) | def __call__(
class LlamaModel (line 140) | class LlamaModel(nn.Module):
method __init__ (line 141) | def __init__(self, args: ModelArgs):
method __call__ (line 153) | def __call__(
class Model (line 174) | class Model(nn.Module):
method __init__ (line 175) | def __init__(self, args: ModelArgs):
method __call__ (line 181) | def __call__(
method sanitize (line 190) | def sanitize(weights):
method layers (line 197) | def layers(self):
FILE: server/retriever/document.py
class Document (line 4) | class Document():
method __init__ (line 15) | def __init__(self, page_content: str, metadata: Optional[dict] = None,...
FILE: server/retriever/embeddings.py
class Embeddings (line 12) | class Embeddings(ABC):
method embed_documents (line 16) | def embed_documents(self, texts: List[str]) -> List[List[float]]:
method embed_query (line 20) | def embed_query(self, text: str) -> List[float]:
class E5Embeddings (line 24) | class E5Embeddings(Embeddings):
method __init__ (line 29) | def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', qu...
method _average_pool (line 35) | def _average_pool(self, last_hidden_states: mx.array,
method embed_documents (line 41) | def embed_documents(self, texts: List[str], batch_size: int = 8) -> Li...
method embed_query (line 49) | def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:
class ChatEmbeddings (line 66) | class ChatEmbeddings(Embeddings):
method __init__ (line 71) | def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):
method embed_documents (line 75) | def embed_documents(self, texts: List[str]) -> List[List[float]]:
method embed_query (line 78) | def embed_query(self, text: str) -> List[float]:
FILE: server/retriever/loader.py
function directory_loader (line 9) | def directory_loader(directory: Optional[str] = None) -> Optional[List[D...
FILE: server/retriever/splitter.py
function _split_text_with_regex (line 16) | def _split_text_with_regex(
class TextSplitter (line 36) | class TextSplitter(ABC):
method __init__ (line 39) | def __init__(
method split_text (line 70) | def split_text(self, text: str) -> List[str]:
method create_documents (line 73) | def create_documents(
method split_documents (line 93) | def split_documents(self, documents: Iterable[Document]) -> List[Docum...
method _join_docs (line 101) | def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
method _merge_splits (line 110) | def _merge_splits(self, splits: Iterable[str], separator: str) -> List...
class RecursiveCharacterTextSplitter (line 152) | class RecursiveCharacterTextSplitter(TextSplitter):
method __init__ (line 159) | def __init__(
method _split_text (line 171) | def _split_text(self, text: str, separators: List[str]) -> List[str]:
method split_text (line 212) | def split_text(self, text: str) -> List[str]:
FILE: server/retriever/vectorstore.py
function _results_to_docs (line 30) | def _results_to_docs(results: Any) -> List[Document]:
function _results_to_docs_and_scores (line 34) | def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, fl...
function xor_args (line 47) | def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
function cosine_similarity (line 75) | def cosine_similarity(
function maximal_marginal_relevance (line 86) | def maximal_marginal_relevance(
class Chroma (line 121) | class Chroma():
method __init__ (line 128) | def __init__(
method embeddings (line 185) | def embeddings(self) -> Optional[Embeddings]:
method __query_collection (line 189) | def __query_collection(
method add_texts (line 208) | def add_texts(
method similarity_search (line 288) | def similarity_search(
method similarity_search_with_score (line 310) | def similarity_search_with_score(
method max_marginal_relevance_search_by_vector (line 350) | def max_marginal_relevance_search_by_vector(
method max_marginal_relevance_search (line 399) | def max_marginal_relevance_search(
method delete_collection (line 442) | def delete_collection(self) -> None:
method get (line 446) | def get(
method update_document (line 484) | def update_document(self, document_id: str, document: Document) -> None:
method update_documents (line 493) | def update_documents(self, ids: List[str], documents: List[Document]) ...
method from_texts (line 533) | def from_texts(
method from_documents (line 596) | def from_documents(
method delete (line 641) | def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
FILE: server/server.py
function load_model (line 26) | def load_model(model_path: str, adapter_file: Optional[str] = None):
function index_directory (line 40) | def index_directory(directory: str, use_embedding: bool = True):
function create_response (line 58) | def create_response(chat_id, prompt, tokens, text):
function format_messages (line 85) | def format_messages(messages: List[Dict], indexed_files: Optional[str], ...
class APIHandler (line 117) | class APIHandler(BaseHTTPRequestHandler):
method _set_headers (line 118) | def _set_headers(self, status_code=200):
method do_OPTIONS (line 126) | def do_OPTIONS(self):
method do_POST (line 129) | def do_POST(self):
method index (line 185) | def index(self, body):
method init (line 190) | def init(self, body):
method query (line 195) | def query(self, body):
function run (line 256) | def run(host: str, port: int, server_class=HTTPServer, handler_class=API...
function main (line 263) | def main():
FILE: server/utils.py
function _get_classes (line 34) | def _get_classes(config: dict):
function get_model_path (line 56) | def get_model_path(path_or_hf_repo: str) -> Path:
function apply_repetition_penalty (line 84) | def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, pe...
function generate_step (line 108) | def generate_step(
function generate (line 196) | def generate(
function load_model (line 282) | def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
function load (line 353) | def load(
function fetch_from_hub (line 390) | def fetch_from_hub(
function make_shards (line 401) | def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB)...
function upload_to_hub (line 425) | def upload_to_hub(path: str, upload_repo: str, hf_path: str):
function save_weights (line 471) | def save_weights(
function quantize_model (line 523) | def quantize_model(
function get_mlx_path (line 550) | def get_mlx_path(hf_path: str, quantize: bool = False) -> str:
function convert (line 556) | def convert(
Condensed preview — 66 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (218K chars).
[
{
"path": ".github/workflows/lint.yml",
"chars": 463,
"preview": "name: Lint\n\non: [pull_request]\n\njobs:\n lint:\n name: Lint\n runs-on: ubuntu-latest\n\n steps:\n - uses: action"
},
{
"path": ".gitignore",
"chars": 2202,
"preview": "# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.\n\n# dependencies\nnode_modules/\n.pnp"
},
{
"path": ".vscode/settings.json",
"chars": 1276,
"preview": "{\n \"[python]\": {\n \"editor.tabSize\": 4,\n \"editor.defaultFormatter\": \"ms-python.autopep8\",\n },\n // \"It "
},
{
"path": "CODEOWNERS",
"chars": 30,
"preview": "* @parkersm1th @stockeh\n"
},
{
"path": "CONTRIBUTING.md",
"chars": 3286,
"preview": "# Welcome to our contribution guide\n\nThank you for wanting to contribute to our project! We appreciate any contributions"
},
{
"path": "LICENSE",
"chars": 1065,
"preview": "MIT License\n\nCopyright (c) 2024 MLX Chat\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
},
{
"path": "README.md",
"chars": 1681,
"preview": "\n\n\n\n\n**Chat with MLX**"
},
{
"path": "app/.eslintrc.cjs",
"chars": 8829,
"preview": "const namingConventions = [\n 'error',\n {\n format: ['camelCase'],\n selector: 'default',\n },\n {\n format: ['ca"
},
{
"path": "app/components.json",
"chars": 348,
"preview": "{\n \"$schema\": \"https://ui.shadcn.com/schema.json\",\n \"style\": \"new-york\",\n \"rsc\": true,\n \"tsx\": true,\n \"tailwind\": {"
},
{
"path": "app/dprint.json",
"chars": 1091,
"preview": "{\n \"lineWidth\": 100,\n \"typescript\": {\n \"indentWidth\": 2,\n \"quoteStyle\": \"alwaysSingle\",\n \"semiColon"
},
{
"path": "app/mac/entitlements.mac.inherit.plist",
"chars": 409,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/P"
},
{
"path": "app/main/main.ts",
"chars": 11149,
"preview": "// Main File for Electron\n\nimport {\n exec,\n execFile,\n} from 'child_process';\nimport {\n app,\n BrowserWindow,\n dialo"
},
{
"path": "app/main/preload.ts",
"chars": 842,
"preview": "// eslint-disable-next-line import/no-extraneous-dependencies\nimport {\n contextBridge,\n ipcRenderer,\n} from 'electron'"
},
{
"path": "app/main/renderer.d.ts",
"chars": 135,
"preview": "import { electronAPI } from \"./preload\";\n\ndeclare global {\n interface Window {\n electronAPI: typeof electronAPI;\n }"
},
{
"path": "app/main/splash/index.css",
"chars": 581,
"preview": "@tailwind base;\n\n@tailwind components;\n\n@tailwind utilities;\n\ndiv {\n -webkit-user-select: none;\n -webkit-app-region: d"
},
{
"path": "app/main/splash/index.html",
"chars": 455,
"preview": "<!DOCTYPE html>\n<html>\n <head>\n <meta charset=\"UTF-8\" />\n <title>FLOATING LOADING SCREEN</title>\n <link rel=\"s"
},
{
"path": "app/main/tsconfig.json",
"chars": 676,
"preview": "{\n \"compilerOptions\": {\n \"allowJs\": true,\n \"alwaysStrict\": true,\n \"esModuleInterop\": true,\n \"forceConsisten"
},
{
"path": "app/next.config.js",
"chars": 133,
"preview": "/** @type {import('next').NextConfig} */\nconst nextConfig = {\n output: \"export\",\n distDir: \"out\",\n};\n\nmodule.exports ="
},
{
"path": "app/notarize.js",
"chars": 515,
"preview": "require('dotenv').config();\nconst { notarize } = require('electron-notarize');\n\nexports.default = async function notariz"
},
{
"path": "app/package.json",
"chars": 4222,
"preview": "{\n \"name\": \"electron-app\",\n \"productName\": \"Electron App\",\n \"version\": \"0.1.0\",\n \"private\": true,\n \"main\": \"main/ou"
},
{
"path": "app/postcss.config.js",
"chars": 82,
"preview": "module.exports = {\n plugins: {\n tailwindcss: {},\n autoprefixer: {},\n },\n}\n"
},
{
"path": "app/src/AppProvider.tsx",
"chars": 514,
"preview": "'use client';\n\nimport {\n useRef,\n} from 'react';\nimport {\n Provider,\n} from 'react-redux';\nimport type {\n AppStore,\n}"
},
{
"path": "app/src/app/globals.css",
"chars": 3008,
"preview": "@tailwind base;\n@tailwind components;\n@tailwind utilities;\n\n@layer base {\n :root {\n --background: 0 0% 100%;\n --f"
},
{
"path": "app/src/app/layout.tsx",
"chars": 781,
"preview": "'use client';\n\nimport StoreProvider from '../AppProvider';\nimport './globals.css';\nimport '@fortawesome/fontawesome-svg-"
},
{
"path": "app/src/app/page.tsx",
"chars": 4140,
"preview": "'use client';\n\nimport {\n faBan,\n faCheckCircle,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n FontAwesomeIcon"
},
{
"path": "app/src/app/settings/page.tsx",
"chars": 6432,
"preview": "'use client';\n\nimport type {\n IconProp,\n} from '@fortawesome/fontawesome-svg-core';\nimport {\n faCog,\n faMessage,\n} fr"
},
{
"path": "app/src/components/chat/Chat.tsx",
"chars": 2889,
"preview": "import React from 'react';\nimport type {\n ChatMessage,\n} from '../../constants/chat';\nimport {\n useAppDispatch,\n} from"
},
{
"path": "app/src/components/chat/ChatInput.tsx",
"chars": 1648,
"preview": "import React, {\n useEffect,\n useRef,\n useState,\n} from 'react';\nimport {\n useAppSelector,\n} from '../../lib/hooks';\n"
},
{
"path": "app/src/components/chat/ChatMessage.tsx",
"chars": 641,
"preview": "import Markdown from 'markdown-to-jsx';\nimport React from 'react';\nimport type {\n ChatMessage,\n} from '../../constants/"
},
{
"path": "app/src/components/chat/ChatMessages.tsx",
"chars": 2487,
"preview": "/* eslint-disable function-paren-newline */\nimport {\n faCircleNotch,\n} from '@fortawesome/free-solid-svg-icons';\nimport"
},
{
"path": "app/src/components/chat/SystemMessage.tsx",
"chars": 1444,
"preview": "import React from 'react';\nimport type {\n ChatMessage,\n} from '../../constants/chat';\n\nconst Message = ({\n message,\n}:"
},
{
"path": "app/src/components/options/SelectDirectory.tsx",
"chars": 2526,
"preview": "import {\n faCheckCircle,\n faCircleNotch,\n faXmark,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n FontAwesome"
},
{
"path": "app/src/components/options/SelectModel.tsx",
"chars": 1141,
"preview": "import React from 'react';\nimport {\n Select,\n SelectContent,\n SelectGroup,\n SelectItem,\n SelectLabel,\n SelectTrigg"
},
{
"path": "app/src/components/ui/button.tsx",
"chars": 1925,
"preview": "import {\n Slot,\n} from '@radix-ui/react-slot';\nimport {\n cva,\n type VariantProps,\n} from 'class-variance-authority';\n"
},
{
"path": "app/src/components/ui/input.tsx",
"chars": 777,
"preview": "import * as React from 'react';\nimport {\n cn,\n} from '../../lib/utils';\n\nexport interface InputProps extends React.Inpu"
},
{
"path": "app/src/components/ui/resizable.tsx",
"chars": 1773,
"preview": "'use client';\n\nimport {\n DragHandleDots2Icon,\n} from '@radix-ui/react-icons';\nimport * as ResizablePrimitive from 'reac"
},
{
"path": "app/src/components/ui/select.tsx",
"chars": 5719,
"preview": "'use client';\n\nimport {\n CaretSortIcon,\n CheckIcon,\n ChevronDownIcon,\n ChevronUpIcon,\n} from '@radix-ui/react-icons'"
},
{
"path": "app/src/components/ui/textarea.tsx",
"chars": 711,
"preview": "import * as React from 'react';\nimport {\n cn,\n} from '../../lib/utils';\n\nexport interface TextareaProps extends React.T"
},
{
"path": "app/src/components/ui/tooltip.tsx",
"chars": 1168,
"preview": "'use client';\n\nimport * as TooltipPrimitive from '@radix-ui/react-tooltip';\nimport * as React from 'react';\nimport {\n c"
},
{
"path": "app/src/constants/chat.tsx",
"chars": 98,
"preview": "export type ChatMessage = {\n role: 'user' | 'assistant' | 'system';\n content: string | null;\n};\n"
},
{
"path": "app/src/lib/hooks.ts",
"chars": 2301,
"preview": "import {\n useEffect,\n useRef,\n useState,\n} from 'react';\nimport {\n useDispatch,\n useSelector,\n useStore,\n} from 'r"
},
{
"path": "app/src/lib/store.ts",
"chars": 1295,
"preview": "import {\n configureStore,\n createSlice,\n} from '@reduxjs/toolkit';\n\nconst globalSlice = createSlice({\n name: 'global'"
},
{
"path": "app/src/lib/utils.ts",
"chars": 177,
"preview": "import {\n type ClassValue,\n clsx,\n} from 'clsx';\nimport {\n twMerge,\n} from 'tailwind-merge';\n\nexport function cn(...i"
},
{
"path": "app/tailwind.config.main.js",
"chars": 261,
"preview": "/** @type {import('tailwindcss').Config} */\nmodule.exports = {\n content: [\"./main/**/*.{js,ts,jsx,tsx,mdx,html}\"],\n //"
},
{
"path": "app/tailwind.config.ts",
"chars": 2288,
"preview": "/* eslint-disable @typescript-eslint/naming-convention */\nimport type {\n Config,\n} from 'tailwindcss';\n\nconst config = "
},
{
"path": "app/tsconfig.json",
"chars": 859,
"preview": "{\n \"compilerOptions\": {\n \"target\": \"es5\",\n \"lib\": [\n \"dom\",\n \"dom.iterable\",\n \"esnext\",\n ],\n "
},
{
"path": "runner.py",
"chars": 341,
"preview": "# Parent script to package (PyInstaller) server\n#\n# Example Usage:\n#\n# pyinstaller --onefile --collect-all mlx --copy-me"
},
{
"path": "runner.sh",
"chars": 662,
"preview": "#!/bin/bash\n\ncollect_modules=(\n \"mlx\"\n \"chromadb\"\n)\n\nhidden_imports=(\n \"server.models\"\n \"server.models.gemma\"\n \"ser"
},
{
"path": "server/__init__.py",
"chars": 66,
"preview": "from .utils import generate, load, convert\n\n__version__ = \"0.1.0\"\n"
},
{
"path": "server/convert.py",
"chars": 1442,
"preview": "import argparse\n\nfrom .utils import convert\n\n\ndef configure_parser() -> argparse.ArgumentParser:\n \"\"\"\n Configures "
},
{
"path": "server/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "server/models/base.py",
"chars": 314,
"preview": "import inspect\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass BaseModelArgs:\n @classmethod\n def from_dict(cl"
},
{
"path": "server/models/bert.py",
"chars": 37310,
"preview": "import math\nimport inspect\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom dataclasses import dataclass\nfrom typing impo"
},
{
"path": "server/models/gemma.py",
"chars": 6069,
"preview": "from dataclasses import dataclass\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nimport mlx.core as m"
},
{
"path": "server/models/layers.py",
"chars": 1448,
"preview": "from functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\n@partial(mx.compile, shapeless=True)\ndef rms_"
},
{
"path": "server/models/llama.py",
"chars": 6603,
"preview": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple, Union\n\nimport mlx.core as mx\nimport mlx.nn a"
},
{
"path": "server/py.typed",
"chars": 1,
"preview": "\n"
},
{
"path": "server/requirements.txt",
"chars": 108,
"preview": "chromadb==0.4.23\nhuggingface_hub==0.20.3\nmlx==0.4.0\nmlx_data==0.0.2\ntransformers==4.38.1\npyinstaller==6.4.0\n"
},
{
"path": "server/retriever/document.py",
"chars": 697,
"preview": "from typing import Any, Literal, Optional\n\n\nclass Document():\n \"\"\"Class for storing a piece of text and associated me"
},
{
"path": "server/retriever/embeddings.py",
"chars": 3031,
"preview": "import os\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom transformers import PreTrainedTokenizer\nfrom abc import ABC, a"
},
{
"path": "server/retriever/loader.py",
"chars": 964,
"preview": "import os\nimport glob\nfrom typing import List, Optional\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom .documen"
},
{
"path": "server/retriever/splitter.py",
"chars": 8120,
"preview": "import re\nimport copy\n\nfrom abc import ABC, abstractmethod\nfrom typing import (\n Any,\n List,\n Optional,\n Cal"
},
{
"path": "server/retriever/vectorstore.py",
"chars": 24443,
"preview": "import uuid\nimport functools\nimport mlx.core as mx\n\nimport chromadb\nimport chromadb.config\n\nfrom chromadb.utils.batch_ut"
},
{
"path": "server/server.py",
"chars": 9136,
"preview": "import os\nimport sys\nimport json\nimport time\nimport uuid\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom http.server im"
},
{
"path": "server/utils.py",
"chars": 18938,
"preview": "import os\nimport copy\nimport glob\nimport shutil\nimport importlib\nimport json\nimport logging\nimport time\nfrom pathlib imp"
}
]
// ... and 1 more files (download for full content)
About this extraction
This page contains the full source code of the mlx-chat/mlx-chat-app GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 66 files (201.3 KB), approximately 50.2k tokens, and a symbol index with 197 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.