[
  {
    "path": ".github/workflows/lint.yml",
    "content": "name: Lint\n\non: [pull_request]\n\njobs:\n  lint:\n    name: Lint\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v1\n        with:\n          fetch-depth: 1\n\n      - name: Use Node.js 16\n        uses: actions/setup-node@v1\n        with:\n          node-version: 16\n\n      - name: Install App Deps\n        run: npm i --ignore-scripts\n        working-directory: ./app\n      - name: Lint App\n        working-directory: ./app\n        run: npm run lint\n"
  },
  {
    "path": ".gitignore",
    "content": "# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.\n\n# dependencies\nnode_modules/\n.pnp/\n.pnp.js\n\n# testing\ncoverage/\n\n# next.js\n.next/\nout/\n\n# production\nbuild/\napp/main/tailwind.css\ndist/\n\n# misc\n.DS_Store\n*.pem\n\n# debug\nnpm-debug.log*\nyarn-debug.log*\nyarn-error.log*\n\n# local env files\n.env*.local\n\n# vercel\n.vercel\n\n# typescript\n*.tsbuildinfo\nnext-env.d.ts\n\n# Byte-compiled / optimized / DLL files\n __pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nserver/lib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": ".vscode/settings.json",
    "content": "{\n    \"[python]\": {\n      \"editor.tabSize\": 4,\n      \"editor.defaultFormatter\": \"ms-python.autopep8\",\n    },\n    // \"It is recommended that either ESLint or cspell checks a file, but not both.\"\n    // https://www.npmjs.com/package/@cspell/eslint-plugin\n    \"cSpell.enableFiletypes\": [\n      \"!javascript\",\n      \"!typescript\",\n    ],\n    \"css.format.spaceAroundSelectorSeparator\": true,\n    \"editor.codeActionsOnSave\": [\n      \"source.fixAll.eslint\",\n    ],\n    \"eslint.codeActionsOnSave.mode\": \"problems\",\n    \"eslint.options\": {\n      \"reportUnusedDisableDirectives\": \"error\",\n    },\n    \"eslint.rules.customizations\": [\n      { \"rule\": \"*\", \"severity\": \"warn\" },\n    ],\n    \"editor.defaultFormatter\": \"dprint.dprint\",\n    \"editor.formatOnSave\": true,\n    \"editor.tabSize\": 2,\n    \"editor.wordWrapColumn\": 100,\n    \"eslint.workingDirectories\": [\n      \"./app\",\n    ],\n    \"files.insertFinalNewline\": true,\n    \"files.trimFinalNewlines\": true,\n    \"git.allowForcePush\": true,\n    \"git.inputValidationSubjectLength\": 100,\n    \"git.inputValidationLength\": 100,\n    \"javascript.preferences.quoteStyle\": \"single\",\n    \"typescript.preferences.quoteStyle\": \"single\",\n    \"scss.format.spaceAroundSelectorSeparator\": true,\n    \"typescript.tsdk\": \"node_modules/typescript/lib\",\n  }\n  "
  },
  {
    "path": "CODEOWNERS",
    "content": "*       @parkersm1th @stockeh\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Welcome to our contribution guide\n\nThank you for wanting to contribute to our project! We apprecaite any contributions that you make.\n\nChat with MLX is an open source project and we love to receive contributions from our community — you! There are many ways to contribute, from writing tutorials or blog posts, improving the documentation, submitting bug reports and feature requests or writing code which can be incorporated into the application itself.\n\n## New Contributor Guide\n\nTo get an overview of the project, read the [README](https://github.com/mlx-chat/mlx-chat-app/blob/main/README.md). Here are some resources to help you get started with open source contributions:\n\n- [Ways to contribute on GitHub](https://docs.github.com/en/get-started/exploring-projects-on-github/finding-ways-to-contribute-to-open-source-on-github)\n- [Setup Git](https://docs.github.com/en/get-started/quickstart/set-up-git)\n- [GitHub workflow](https://docs.github.com/en/get-started/quickstart/github-flow)\n- [Collaborating with pull requests](https://docs.github.com/en/github/collaborating-with-pull-requests)\n\n## Getting Started\n\n### Issues\n\n**Create**: If you spot a problem, [search if an issue already exists](https://docs.github.com/en/github/searching-for-information-on-github/searching-on-github/searching-issues-and-pull-requests#search-by-the-title-body-or-comments). If a related issue doesn't exist, you can open a new issue!\n\n**Solve**: Scan through our [existing issues](https://github.com/mlx-chat/mlx-chat-app) to find one that interests you. If you find an issue to work on, you are welcome to assign it to yourself and open a PR with a fix.\n\n### Make Changes\n\n1. Create your own fork of the code\n2. Create a working branch and start with your changes\n3. Commit and send a pull request \n\n### Pull Request\n\nWhen you're finished with the changes, create a pull request.\n- Check to see your pull request passes our continuous integration (CI). If you cannot get a certain integration test to pass, let us know. We can assist you in fixing these issues or approve a merge manually.\n- Make sure your additions are properly documented!\n- Don't forget to [link PR to issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) if you are solving one.\n- We may ask for changes to be made before a PR can be merged, either using [suggested changes](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/incorporating-feedback-in-your-pull-request) or pull request comments. You can apply suggested changes directly through the UI. You can make any other changes in your fork, then commit them to your branch.\n- As you update your PR and apply changes, mark each conversation as [resolved](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/commenting-on-a-pull-request#resolving-conversations).\n- If you run into any merge issues, checkout this [git tutorial](https://github.com/skills/resolve-merge-conflicts) to help you resolve merge conflicts and other issues.\n\n### Your PR is Merged!\n\nCongratulations :tada::tada: we thank you! :sparkles:\n\nOnce your PR is merged, your contributions will be publicly visible in the [Chat with MLX Repository](https://github.com/mlx-chat/mlx-chat-app).\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 MLX Chat\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "\n![](docs/design-logo-light.png#gh-light-mode-only)\n![](docs/design-logo-dark.png#gh-dark-mode-only)\n\n\n**Chat with MLX** is a high-performance macOS application that connects your local documents to a personalized large language model (LLM). By leveraging retrieval-augmented generation (RAG), open source LLMs, and MLX for accelerated machine learning on Apple silicon, you can efficently search, query, and interact with your documents *without information ever leaving your device.*\n\nOur high-level features include:\n- **Query**: load and search with document-specific prompts\n- **Converse**: switch model interaction modes (converse vs. assist) in real time\n- **Instruct**: provide personalization and response tuning\n\n## Installation and Setup\n\n:warning: **Preliminary Steps**: we are working to release with correct packaging ([pyinstaller](https://github.com/pyinstaller/pyinstaller/) & [electron-builder](https://github.com/electron-userland/electron-builder)) and authentication ([Apple codesign](https://developer.apple.com/support/code-signing/)). In the interium, please clone and run in development by first setting up authentication and requirements. \n\nFirst, setup huggingface [access tokens](https://huggingface.co/settings/tokens) to download models (request access to [google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)), then\n```bash\nhuggingface-cli login\n```\nThen download the npm/python requirements\n```bash\ncd app && npm install\npip install -r server/requirements.txt\n```\nFinally, start the application\n```bash\ncd app && npm run dev\n```\n\n## Contributions\nAll contributions are welcome. Please take a look at [contributing](CONTRIBUTING.md) guide.\n"
  },
  {
    "path": "app/.eslintrc.cjs",
    "content": "const namingConventions = [\n  'error',\n  {\n    format: ['camelCase'],\n    selector: 'default',\n  },\n  {\n    format: ['camelCase', 'UPPER_CASE'],\n    selector: 'variable',\n  },\n  {\n    format: ['camelCase', 'UPPER_CASE', 'PascalCase'],\n    modifiers: ['const', 'exported', 'global'],\n    selector: 'variable',\n  },\n  {\n    format: ['camelCase'],\n    leadingUnderscore: 'allow',\n    selector: 'parameter',\n  },\n  {\n    format: ['camelCase'],\n    leadingUnderscore: 'allow',\n    modifiers: ['private'],\n    selector: 'memberLike',\n  },\n  {\n    format: ['PascalCase', 'UPPER_CASE'],\n    selector: ['enum', 'enumMember'],\n  },\n  {\n    format: ['PascalCase'],\n    selector: 'typeLike',\n  },\n  {\n    format: null,\n    modifiers: ['destructured'],\n    selector: 'variable',\n  },\n  {\n    format: null,\n    modifiers: ['requiresQuotes'],\n    selector: [\n      'classProperty',\n      'objectLiteralProperty',\n      'typeProperty',\n      'classMethod',\n      'objectLiteralMethod',\n      'typeMethod',\n      'accessor',\n      'enumMember',\n    ],\n  },\n  {\n    format: ['camelCase', 'PascalCase', 'UPPER_CASE'],\n    leadingUnderscore: 'allow',\n    selector: 'import',\n  },\n];\n\nconst tsxNamingConventions = [\n  {\n    format: ['camelCase', 'PascalCase', 'UPPER_CASE'],\n    leadingUnderscore: 'forbid',\n    modifiers: ['global'],\n    selector: ['variable', 'function'],\n  },\n];\n\nmodule.exports = {\n  env: {\n    es2020: true,\n  },\n  extends: [\n    'airbnb-base',\n    'plugin:jsdoc/recommended',\n    'plugin:@typescript-eslint/recommended',\n    'plugin:import/typescript',\n    'plugin:no-unsanitized/DOM',\n  ],\n  ignorePatterns: [\n    'node_modules',\n    'main',\n    '.eslintrc.*',\n    'out',\n  ],\n  overrides: [\n    // Config files\n    {\n      files: [\n        'common/**/*.ts*',\n        '**/app*config.ts',\n        '**/app*Config.ts',\n      ],\n      rules: {\n        '@typescript-eslint/member-ordering': ['error', { default: { order: 'alphabetically' } }],\n        'sort-keys': ['error', 'asc', { minKeys: 2, natural: true }],\n      },\n    },\n    {\n      files: [\n        '*.ts*',\n      ],\n      rules: {\n        '@typescript-eslint/no-shadow': 'error',\n        '@typescript-eslint/no-unused-vars': 'off', // Using unused-imports plugin instead\n        '@typescript-eslint/space-before-function-paren': [\n          'error',\n          {\n            anonymous: 'never',\n            asyncArrow: 'always',\n            named: 'never',\n          },\n        ],\n        'no-redeclare': 'off', // @typescript-eslint/no-redeclare is enabled and is more correct\n        'no-shadow': 'off', // @typescript-eslint/no-shadow is enabled and is more correct\n        'no-undef-init': 'off',\n        'no-unused-vars': 'off', // Using unused-imports plugin instead\n        'space-before-function-paren': 'off', // Using @typescript-eslint/space-before-function-paren instead\n        'unused-imports/no-unused-imports': 'error',\n        'unused-imports/no-unused-vars': ['error', {\n          args: 'after-used',\n          argsIgnorePattern: '^_',\n          destructuredArrayIgnorePattern: '^_',\n          ignoreRestSiblings: true,\n        }],\n      },\n    },\n    {\n      files: [\n        'src/**/*.ts*',\n      ],\n      rules: {\n        '@typescript-eslint/await-thenable': 'error',\n        '@typescript-eslint/dot-notation': ['error', { allowIndexSignaturePropertyAccess: true }],\n        '@typescript-eslint/no-base-to-string': ['error', 
{\n          ignoredTypeNames: ['Error', 'RegExp'],\n        }],\n        '@typescript-eslint/no-floating-promises': 'error',\n        '@typescript-eslint/no-for-in-array': 'error',\n        '@typescript-eslint/no-misused-promises': ['error', { checksVoidReturn: false }],\n        '@typescript-eslint/no-throw-literal': 'error',\n        '@typescript-eslint/no-unnecessary-condition': 'error',\n        '@typescript-eslint/no-unnecessary-type-assertion': 'error',\n        '@typescript-eslint/non-nullable-type-assertion-style': 'error',\n        '@typescript-eslint/prefer-includes': 'error',\n        '@typescript-eslint/prefer-optional-chain': 'error',\n        '@typescript-eslint/prefer-string-starts-ends-with': 'error',\n        '@typescript-eslint/require-await': 'error',\n        '@typescript-eslint/space-infix-ops': 'error',\n        'dot-notation': 'off',\n        'no-throw-literal': 'off',\n        'require-await': 'off',\n        'space-infix-ops': 'off',\n      },\n    },\n    {\n      files: [\n        '*.tsx',\n      ],\n      rules: {\n        '@typescript-eslint/naming-convention': [\n          ...namingConventions,\n          ...tsxNamingConventions,\n        ],\n        '@typescript-eslint/require-await': 'error',\n        'require-await': 'off',\n      },\n    },\n  ],\n  parser: '@typescript-eslint/parser',\n  parserOptions: {\n    project: 'tsconfig.json',\n  },\n  plugins: [\n    'react',\n    '@typescript-eslint',\n    'jest-formatting',\n    'modules-newlines',\n    'unused-imports',\n  ],\n  root: true,\n  rules: {\n    '@typescript-eslint/ban-types': [\n      'error',\n      {\n        extendDefaults: true,\n        types: {\n          object: {\n            message: [\n              'The `object` type is currently hard to use ([see this issue](https://github.com/microsoft/TypeScript/issues/21732)).',\n              'Consider using `Record<string, unknown>` instead, as it allows you to more easily inspect and use the keys.',\n            ].join('\\n'),\n          },\n        },\n      },\n    ],\n    'implicit-arrow-linebreak': 'off',\n    '@typescript-eslint/consistent-type-assertions': ['error', { assertionStyle: 'never' }],\n    '@typescript-eslint/consistent-type-imports': 'error',\n    '@typescript-eslint/init-declarations': 'error',\n    '@typescript-eslint/member-ordering': 'error',\n    '@typescript-eslint/naming-convention': namingConventions,\n    '@typescript-eslint/no-explicit-any': 'error',\n    '@typescript-eslint/no-non-null-asserted-nullish-coalescing': 'error',\n    '@typescript-eslint/no-use-before-define': ['error', {\n      functions: false,\n    }],\n    '@typescript-eslint/prefer-for-of': 'error',\n    '@typescript-eslint/type-annotation-spacing': 'error',\n    'array-element-newline': ['error', 'consistent'],\n    'block-spacing': 'off',\n    camelcase: 'off', // Using @typescript-eslint/naming-convention instead.\n    'comma-dangle': 'off',\n    'default-param-last': 'off',\n    'import/extensions': 'off',\n    'import/no-relative-packages': 'off',\n    'import/order': 'off',\n    'import/prefer-default-export': 'off',\n    'jsdoc/check-indentation': ['error', { excludeTags: ['description', 'example'] }],\n    'jsdoc/check-line-alignment': 'error',\n    'jsdoc/check-tag-names': ['error', {\n      definedTags: ['jest-environment', 'jest-environment-options'],\n    }],\n    'jsdoc/no-bad-blocks': 'error',\n    'jsdoc/no-multi-asterisks': 'off',\n    'jsdoc/no-undefined-types': 'off',\n    'jsdoc/require-jsdoc': 'off',\n    'jsdoc/require-param': 
'off',\n    'jsdoc/require-param-description': 'off',\n    'jsdoc/require-param-name': 'off',\n    'jsdoc/require-param-type': 'off',\n    'jsdoc/require-property': 'off',\n    'jsdoc/require-property-description': 'off',\n    'jsdoc/require-property-name': 'off',\n    'jsdoc/require-property-type': 'off',\n    'jsdoc/require-returns': 'off',\n    'jsdoc/require-returns-description': 'off',\n    'jsdoc/require-returns-type': 'off',\n    'jsdoc/require-yields': 'off',\n    'jsdoc/require-yields-check': 'off',\n    'jsdoc/tag-lines': ['error', 'any', { startLines: 1 }],\n    'max-classes-per-file': 'off',\n    'max-len': ['error', {\n      code: 100,\n      ignorePattern: '(/* eslint |eslint-disable-next-line |@ts-expect-error )',\n      ignoreRegExpLiterals: true,\n      ignoreStrings: true,\n      ignoreTemplateLiterals: true,\n      ignoreUrls: true,\n    }],\n    'max-params': ['error', 3],\n    'new-cap': [\n      'error',\n      {\n        capIsNew: true,\n        capIsNewExceptions: [\n          'express.Router',\n          'Immutable.Map',\n          'Immutable.Set',\n          'Immutable.List',\n          'RightRailView',\n          'URLWithSearchParams',\n        ],\n        newIsCap: true,\n        newIsCapExceptions: [],\n        properties: true,\n      },\n    ],\n    'no-console': 'error',\n    'no-continue': 'off',\n    'no-empty-function': 'off',\n    'no-promise-executor-return': 'off',\n    'no-redeclare': 'error',\n    'no-restricted-properties': [\n      'error',\n    ],\n    'no-restricted-syntax': [\n      'error',\n    ],\n    'no-use-before-define': 'off',\n    'no-void': ['error', { allowAsStatement: true }],\n    'padding-line-between-statements': [\n      'error',\n      { blankLine: 'never', next: 'import', prev: 'import' },\n    ],\n    'prefer-arrow-callback': ['error', { allowNamedFunctions: true }],\n    'prefer-exponentiation-operator': 'off',\n    'prefer-regex-literals': 'off',\n  },\n  settings: {\n    'import/typescript': {\n      typescript: {},\n    },\n  },\n};\n"
  },
  {
    "path": "app/components.json",
    "content": "{\n  \"$schema\": \"https://ui.shadcn.com/schema.json\",\n  \"style\": \"new-york\",\n  \"rsc\": true,\n  \"tsx\": true,\n  \"tailwind\": {\n    \"config\": \"tailwind.config.ts\",\n    \"css\": \"main/splash/index.css\",\n    \"baseColor\": \"slate\",\n    \"cssVariables\": true,\n    \"prefix\": \"\"\n  },\n  \"aliases\": {\n    \"components\": \"@/components\",\n    \"utils\": \"@/lib/utils\"\n  }\n}"
  },
  {
    "path": "app/dprint.json",
    "content": "{\n    \"lineWidth\": 100,\n    \"typescript\": {\n      \"indentWidth\": 2,\n      \"quoteStyle\": \"alwaysSingle\",\n      \"semiColons\": \"always\",\n      \"quoteProps\": \"asNeeded\",\n      \"useBraces\": \"always\",\n      \"trailingCommas\": \"onlyMultiLine\",\n      \"module.sortImportDeclarations\": \"caseInsensitive\",\n      \"exportDeclaration.forceMultiLine\": true,\n      \"importDeclaration.forceMultiLine\": true\n    },\n    \"json\": {\n      \"jsonTrailingCommaFiles\": [\n        \".vscode/launch.json\",\n        \".vscode/extensions.json\",\n        \".vscode/settings.json\",\n        \".vscode/tasks.json\",\n        \"tsconfig.json\"\n      ]\n    },\n    \"excludes\": [\n      \"**/node_modules\",\n      \"**/*-lock.json\",\n      \"**/Dockerfile\",\n      \"**/src/ui-tests/fixtures/**/*\",\n      \"**/storybook-static/**/*\",\n      \"**/build/**/*\",\n      \"**/dist/**/*\",\n      \"**/artifacts/**/*\",\n      \"extension/src/assets/**/*.json\"\n    ],\n    \"plugins\": [\n      \"https://plugins.dprint.dev/typescript-0.88.3.wasm\",\n      \"https://plugins.dprint.dev/json-0.19.0.wasm\",\n      \"https://plugins.dprint.dev/dockerfile-0.3.0.wasm\"\n    ]\n  }\n  "
  },
  {
    "path": "app/mac/entitlements.mac.inherit.plist",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n  <dict>\n    <key>com.apple.security.cs.allow-jit</key>\n    <true/>\n    <key>com.apple.security.cs.allow-unsigned-executable-memory</key>\n    <true/>\n    <key>com.apple.security.cs.disable-library-validation</key>\n    <true/>\n  </dict>\n</plist>\n"
  },
  {
    "path": "app/main/main.ts",
    "content": "// Main File for Electron\n\nimport {\n  exec,\n  execFile,\n} from 'child_process';\nimport {\n  app,\n  BrowserWindow,\n  dialog,\n  globalShortcut,\n  ipcMain,\n  Menu,\n  nativeImage,\n  Tray,\n} from 'electron';\nimport * as contextMenu from 'electron-context-menu';\nimport Store from 'electron-store';\nimport * as net from 'net';\n\nconst path = require('path');\nconst serve = require('electron-serve');\nconst { spawn } = require('child_process');\n\nfunction handleSetTitle(event: any, title: string) {\n  const webContents = event.sender;\n  const win = BrowserWindow.fromWebContents(webContents);\n  if (win !== null) {\n    win.setTitle(title);\n  }\n}\n\n// Python Server\nclass ServerManager {\n  private serverProcess: any | null = null;\n  public port: number | null = null;\n\n  private findOpenPort(startingPort: number): Promise<number> {\n    return new Promise<number>((resolve) => {\n      const server = net.createServer();\n\n      server.listen(startingPort, () => {\n        const port = (server.address() as net.AddressInfo).port;\n        server.close(() => resolve(port));\n      });\n\n      server.on(\n        'error',\n        (err: any) => err.code === 'EADDRINUSE' && resolve(this.findOpenPort(startingPort + 1)),\n      );\n    });\n  }\n\n  private runPythonServer(port: number): any {\n    const args = ['--host 127.0.0.1', `--port ${port}`];\n    const modifiedArgs = args.flatMap(arg => arg.split(/\\s+/));\n    const pythonProcess = isProd\n      ? execFile(path.join(process.resourcesPath, 'server', 'runner'), modifiedArgs)\n      : spawn('python', ['-m', 'server.server', ...modifiedArgs], {\n        cwd: '../',\n      });\n    pythonProcess.stdout.on(\n      'data',\n      (data: Buffer) => console.log('Server output:', data.toString('utf8')),\n    );\n    pythonProcess.stderr.on(\n      'data',\n      (data: Buffer) => console.log(`Server error: ${data.toString('utf8')}`),\n    );\n\n    return pythonProcess;\n  }\n\n  start(model: string): Promise<void> {\n    return new Promise<void>((resolve, reject) => {\n      this.stop();\n\n      this.findOpenPort(8080).then((port) => {\n        this.port = port;\n        console.log(`APP: Starting server for model: ${model} on port: ${port}`);\n        this.serverProcess = this.runPythonServer(port);\n\n        this.serverProcess.stdout.on('data', async (data: Buffer) => {\n          const output = data.toString('utf8');\n\n          await new Promise((resolve) => setTimeout(resolve, 1000));\n\n          // Check if the server is ready\n          if (output.includes('Starting httpd')) {\n            fetch(`http://127.0.0.1:${port}/api/init`, {\n              method: 'POST',\n              headers: {\n                'Content-Type': 'application/json',\n              },\n              body: JSON.stringify({ model }),\n            }).then(() => {\n              resolve(); // Resolve the promise when the server is ready\n            }).catch((err) => {\n              console.error('Error initializing the server:', err);\n              reject(err);\n            });\n          }\n        });\n\n        this.serverProcess.on('close', (code: number | null) => {\n          console.log(`Server process exited with code ${code}`);\n          this.serverProcess = null;\n        });\n\n        this.serverProcess.on('error', (err: any) => {\n          console.error(`Error in server process: ${err}`);\n          this.serverProcess = null;\n          reject(err);\n        });\n      });\n    });\n  }\n\n  stop(): void {\n    if 
(this.serverProcess) {\n      console.log('Stopping the server...');\n      this.serverProcess.kill();\n      this.serverProcess = null;\n    }\n  }\n}\n\n// Loading Screen\nlet splash: BrowserWindow | null;\nconst createSplashScreen = () => {\n  /// create a browser window\n  splash = new BrowserWindow(\n    {\n      width: 200,\n      height: 100,\n      focusable: false,\n      /// remove the window frame, so it will become a frameless window\n      frame: false,\n      skipTaskbar: true,\n      autoHideMenuBar: true,\n    },\n  );\n  splash.setResizable(false);\n  splash.loadURL(`file://${__dirname}/../splash/index.html`);\n  splash.on('closed', () => (splash = null));\n  splash.webContents.on('did-finish-load', () => {\n    if (splash) {\n      splash.show();\n    }\n  });\n};\n\n// run renderer\nconst isProd = process.env.NODE_ENV !== 'development';\nif (isProd) {\n  serve({ directory: 'out' });\n} else {\n  app.setPath('userData', `${app.getPath('userData')} (development)`);\n}\n\ncontextMenu.default({\n  showInspectElement: !isProd,\n});\n\nlet openModal: 'settings' | 'directory' | null = null;\n\nlet globalWindow: BrowserWindow | null = null;\n\nconst triggerShortcut = () => {\n  if (openModal || !globalWindow) {\n    return;\n  }\n  if (globalWindow.isFocused()) {\n    globalWindow.blur();\n    return;\n  }\n  globalWindow.show();\n};\n\nconst store = new Store({\n  schema: {\n    keybind: {\n      type: 'string',\n      default: 'Cmd+O',\n    },\n    model: {\n      type: 'string',\n      default: 'mistralai/Mistral-7B-Instruct-v0.2',\n    },\n    personalization: {\n      type: 'string',\n      default: '',\n    },\n    customResponse: {\n      type: 'string',\n      default: '',\n    },\n  },\n});\n\nconst serverManager = new ServerManager();\n\nconst createWindow = () => {\n  const icon = nativeImage.createFromPath(\n    !isProd\n      ? '../assets/IconTemplate.png'\n      : path.join(process.resourcesPath, 'IconTemplate.png'),\n  );\n  // if you want to resize it, be careful, it creates a copy\n  const trayIcon = icon.resize({ width: 16 });\n  // here is the important part (has to be set on the resized version)\n  trayIcon.setTemplateImage(true);\n  let tray = new Tray(trayIcon);\n  tray.setTitle(isProd ? 
'' : 'M');\n\n  const win = new BrowserWindow({\n    webPreferences: {\n      preload: path.join(__dirname, 'preload.js'),\n      devTools: !isProd,\n    },\n    show: false,\n    width: 600,\n    height: 99,\n    resizable: false,\n    type: 'panel',\n    frame: false,\n    skipTaskbar: true,\n    autoHideMenuBar: true,\n    vibrancy: 'under-window', // on macOS\n    backgroundMaterial: 'acrylic',\n    // join so the path resolves relative to the compiled output directory\n    icon: path.join(__dirname, '../../assets/public/icon.icns'),\n  });\n  globalWindow = win;\n  win.setWindowButtonVisibility(false);\n  win.setAlwaysOnTop(true, 'floating');\n  win.setVisibleOnAllWorkspaces(true);\n\n  // Expose URL\n  if (isProd) {\n    win.loadURL('app://./home.html');\n  } else {\n    // const port = process.argv[2];\n    win.loadURL('http://localhost:3000/');\n  }\n\n  tray.addListener('click', () => {\n    if (win.isFocused()) {\n      win.blur();\n      return;\n    }\n    win.show();\n  });\n\n  win.webContents.on('did-finish-load', async () => {\n    await serverManager.start(store.get('model') as string);\n    /// then close the loading screen window and show the main window\n    if (splash) {\n      splash.close();\n    }\n    app.dock.hide();\n    win.show();\n    globalShortcut.register(store.get('keybind') as string, triggerShortcut.bind(null));\n  });\n\n  win.on('blur', () => {\n    if (openModal) {\n      win.setAlwaysOnTop(false);\n    }\n    if (openModal === 'directory') {\n      return;\n    }\n    if (win.webContents.isDevToolsOpened()) {\n      return;\n    }\n    globalShortcut.unregister('Escape');\n    globalShortcut.unregister('Cmd+Q');\n    win.hide();\n    if (openModal) {\n      return;\n    }\n\n    Menu.sendActionToFirstResponder('hide:');\n  });\n\n  win.on('focus', () => {\n    globalShortcut.register('Cmd+Q', () => {\n      if (!win.isFocused()) {\n        return;\n      }\n      app.quit();\n    });\n    globalShortcut.register('Escape', () => {\n      if (!win.isFocused()) {\n        return;\n      }\n      win.blur();\n    });\n  });\n\n  let settingsModal: BrowserWindow | null = null;\n\n  const createSettings = () => {\n    settingsModal = new BrowserWindow({\n      webPreferences: {\n        preload: path.join(__dirname, 'preload.js'),\n      },\n      width: 500,\n      height: 500,\n      resizable: false,\n      minimizable: false,\n      titleBarStyle: 'hidden',\n      show: false,\n      backgroundColor: '#000',\n    });\n\n    if (isProd) {\n      settingsModal.loadURL('app://./settings.html');\n    } else {\n      // const port = process.argv[2];\n      settingsModal.loadURL('http://localhost:3000/settings');\n    }\n\n    settingsModal.on('closed', () => {\n      openModal = null;\n      settingsModal?.destroy();\n      settingsModal = null;\n    });\n\n    settingsModal.on('ready-to-show', () => {\n      settingsModal?.show();\n    });\n\n    return settingsModal;\n  };\n\n  const nativeMenus: (Electron.MenuItemConstructorOptions | Electron.MenuItem)[] = [\n    {\n      label: 'MLX Chat',\n      submenu: [\n        {\n          label: 'Settings',\n          click() {\n            openModal = 'settings';\n            if (settingsModal !== null) {\n              settingsModal.close();\n            }\n            createSettings();\n          },\n          accelerator: 'Cmd+,',\n        },\n      ],\n    },\n    {\n      label: 'Edit',\n      submenu: [\n        { role: 'undo' },\n        { role: 'redo' },\n        { type: 'separator' },\n        { role: 'cut' },\n        { 
role: 'copy' },\n        { role: 'paste' },\n        { role: 'pasteAndMatchStyle' },\n        { role: 'delete' },\n        { role: 'selectAll' },\n        { type: 'separator' },\n        {\n          label: 'Speech',\n          submenu: [\n            { role: 'startSpeaking' },\n            { role: 'stopSpeaking' },\n          ],\n        },\n      ],\n    },\n  ];\n\n  const menu = Menu.buildFromTemplate(nativeMenus);\n  Menu.setApplicationMenu(menu);\n};\n\napp.whenReady().then(() => {\n  ipcMain.on('set-title', handleSetTitle);\n  ipcMain.on('select-directory', (event: any) => {\n    openModal = 'directory';\n    dialog.showOpenDialog({ properties: ['openDirectory'] }).then((result: any) => {\n      const win = BrowserWindow.fromWebContents(event.sender);\n      // Weird hack to bring the window to the front after allowing windows in front of it\n      win?.setAlwaysOnTop(true, 'floating');\n\n      openModal = null;\n      event.sender.send('selected-directory', result.filePaths);\n    });\n  });\n\n  ipcMain.on('resize-window', (event, arg) => {\n    const win = BrowserWindow.fromWebContents(event.sender);\n    if (!win) {\n      return;\n    }\n    win.setBounds({\n      height: arg.height,\n    });\n    win.center();\n  });\n\n  ipcMain.on('fetch-setting', (event, arg) => {\n    event.returnValue = store.get(arg);\n  });\n\n  ipcMain.on('update-setting', (_event, arg) => {\n    if (arg.key === 'keybind') {\n      globalShortcut.unregister(store.get('keybind') as string);\n      globalShortcut.register(arg.value, triggerShortcut.bind(null));\n    }\n    store.set(arg.key, arg.value);\n  });\n\n  createSplashScreen();\n\n  setTimeout(() => {\n    createWindow();\n  }, 500);\n\n  app.on('activate', () => {\n    if (BrowserWindow.getAllWindows().length === 0) { createWindow(); }\n  });\n});\n\napp.on('will-quit', () => {\n  exec(\n    `lsof -i :${serverManager.port} -P | awk 'NR>1 {print $2}' | xargs kill`,\n    (err, stdout, stderr) => {\n      if (err) {\n        console.log(err);\n        return;\n      }\n      console.log(`stdout: ${stdout}`);\n      console.log(`stderr: ${stderr}`);\n    },\n  );\n  BrowserWindow.getAllWindows().forEach((win) => {\n    win.close();\n    win.destroy();\n  });\n});\n"
  },
  {
    "path": "app/main/preload.ts",
    "content": "// eslint-disable-next-line import/no-extraneous-dependencies\nimport {\n  contextBridge,\n  ipcRenderer,\n} from 'electron';\n\nexport const electronAPI = {\n  setTitle: (title: string) => ipcRenderer.send('set-title', title),\n  selectDirectory: () => ipcRenderer.send('select-directory'),\n  onSelectDirectory: (cb: (customData: string[]) => void) => {\n    ipcRenderer.on('selected-directory', (event, customData) => {\n      // eslint-disable-next-line no-console\n      console.log(event);\n      cb(customData);\n    });\n  },\n  resizeWindow: (height: number) => ipcRenderer.send('resize-window', { height }),\n  fetchSetting: (key: string) => ipcRenderer.sendSync('fetch-setting', key),\n  updateSetting: (key: string, value: any) => ipcRenderer.send('update-setting', { key, value }),\n};\n\ncontextBridge.exposeInMainWorld('electronAPI', electronAPI);\n"
  },
  {
    "path": "app/main/renderer.d.ts",
    "content": "import { electronAPI } from \"./preload\";\n\ndeclare global {\n  interface Window {\n    electronAPI: typeof electronAPI;\n  }\n}\n\nexport {};\n"
  },
  {
    "path": "app/main/splash/index.css",
    "content": "@tailwind base;\n\n@tailwind components;\n\n@tailwind utilities;\n\ndiv {\n  -webkit-user-select: none;\n  -webkit-app-region: drag;\n}\n\n.loading-bar {\n  display: block;\n  height: 0.2em;\n  background-color: rgba(255, 255, 255, 0.2);\n  position: relative;\n  overflow: hidden;\n  border-radius: 1rem;\n}\n\n.loading-bar:before {\n  content: \"\";\n  display: block;\n  position: absolute;\n  left: -100%;\n  width: 100%;\n  height: 100%;\n  background-color: white;\n  animation: loading-bar 1.5s ease-in-out infinite;\n}\n\n@keyframes loading-bar {\n  from {\n    left: -100%;\n  }\n  to {\n    left: 100%;\n  }\n}\n"
  },
  {
    "path": "app/main/splash/index.html",
    "content": "<!DOCTYPE html>\n<html>\n  <head>\n    <meta charset=\"UTF-8\" />\n    <title>FLOATING LOADING SCREEN</title>\n    <link rel=\"stylesheet\" href=\"../tailwind.css\" />\n    <link rel=\"stylesheet\" href=\"./index.css\" />\n  </head>\n  <body>\n    <div\n      class=\"h-screen w-screen fixed z-50 flex flex-col items-center justify-center text-white text-4xl gap-8 bg-slate-600\"\n    >\n      <div class=\"loading-bar w-3/4 md:w-1/2 lg:w-1/3\"></div>\n    </div>\n  </body>\n</html>\n"
  },
  {
    "path": "app/main/tsconfig.json",
    "content": "{\n  \"compilerOptions\": {\n    \"allowJs\": true,\n    \"alwaysStrict\": true,\n    \"esModuleInterop\": true,\n    \"forceConsistentCasingInFileNames\": true,\n    \"isolatedModules\": true,\n    \"jsx\": \"preserve\",\n    \"lib\": [\"dom\", \"es2017\"],\n    \"module\": \"commonjs\",\n    \"moduleResolution\": \"node\",\n    \"noEmit\": false,\n    \"noFallthroughCasesInSwitch\": true,\n    \"noUnusedLocals\": true,\n    \"noUnusedParameters\": true,\n    \"resolveJsonModule\": true,\n    \"skipLibCheck\": true,\n    \"strict\": true,\n    \"target\": \"esnext\",\n    \"outDir\": \"./out\",\n  },\n  \"compileOnSave\": true,\n  \"exclude\": [\"node_modules\", \"./out/**/*\"],\n  \"include\": [\"**/*.ts\", \"**/*.tsx\", \"**/*.js\", \"public/**.icns\"],\n}\n"
  },
  {
    "path": "app/next.config.js",
    "content": "/** @type {import('next').NextConfig} */\nconst nextConfig = {\n  output: \"export\",\n  distDir: \"out\",\n};\n\nmodule.exports = nextConfig;\n"
  },
  {
    "path": "app/notarize.js",
    "content": "require('dotenv').config();\nconst { notarize } = require('electron-notarize');\n\nexports.default = async function notarizing(context) {\n  const { electronPlatformName, appOutDir } = context;\n  if (electronPlatformName !== 'darwin') {\n    return;\n  }\n\n  const appName = context.packager.appInfo.productFilename;\n\n  return await notarize({\n    appBundleId: 'com.parkersmith.mlx-chat',\n    appPath: `${appOutDir}/${appName}.app`,\n    appleId: process.env.APPLEID,\n    appleIdPassword: process.env.APPLEIDPASS,\n  });\n};\n"
  },
  {
    "path": "app/package.json",
    "content": "{\n  \"name\": \"electron-app\",\n  \"productName\": \"Electron App\",\n  \"version\": \"0.1.0\",\n  \"private\": true,\n  \"main\": \"main/out/main.js\",\n  \"homepage\": \"./\",\n  \"description\": \"My Next.js project\",\n  \"author\": \"test\",\n  \"scripts\": {\n    \"dev\": \"cross-env NODE_ENV=development concurrently -k \\\"cross-env BROWSER=none npm run next:dev\\\" \\\"npm run electron:dev\\\"\",\n    \"build\": \" npm run build:main && next build\",\n    \"start\": \"cross-env npm run electron\",\n    \"build:tailwindMain\": \"npx tailwindcss build --config tailwind.config.main.js -o ./main/tailwind.css\",\n    \"build:main\": \"tsc -p main && npm run build:tailwindMain\",\n    \"next:dev\": \"next dev\",\n    \"next:start\": \"next start\",\n    \"next:lint\": \"next lint\",\n    \"electron:dev\": \"npm run build:main && wait-on tcp:3000 && electron .\",\n    \"electron\": \"electron .\",\n    \"pack\": \"npm run build && electron-builder --dir\",\n    \"dist\": \"npm run build && electron-builder\",\n    \"lint\": \"npx eslint --max-warnings 0 --ext=.ts .\"\n  },\n  \"dependencies\": {\n    \"@electron/osx-sign\": \"^1.0.5\",\n    \"@fortawesome/fontawesome-free\": \"^6.5.1\",\n    \"@fortawesome/fontawesome-svg-core\": \"^6.5.1\",\n    \"@fortawesome/free-regular-svg-icons\": \"^6.5.1\",\n    \"@fortawesome/free-solid-svg-icons\": \"^6.5.1\",\n    \"@fortawesome/react-fontawesome\": \"^0.2.0\",\n    \"@matejmazur/react-katex\": \"^3.1.3\",\n    \"@radix-ui/react-icons\": \"^1.3.0\",\n    \"@radix-ui/react-select\": \"^2.0.0\",\n    \"@radix-ui/react-slot\": \"^1.0.2\",\n    \"@radix-ui/react-tooltip\": \"^1.0.7\",\n    \"@reduxjs/toolkit\": \"^2.2.1\",\n    \"@types/electron\": \"^1.6.10\",\n    \"@types/node\": \"^20.6.0\",\n    \"@types/react\": \"^18.2.21\",\n    \"@types/react-dom\": \"^18.2.7\",\n    \"autoprefixer\": \"^10.4.15\",\n    \"class-variance-authority\": \"^0.7.0\",\n    \"clsx\": \"^2.1.0\",\n    \"concurrently\": \"^8.2.1\",\n    \"cross-env\": \"^7.0.3\",\n    \"dprint\": \"^0.45.0\",\n    \"electron-context-menu\": \"^3.6.1\",\n    \"electron-serve\": \"^1.1.0\",\n    \"electron-squirrel-startup\": \"^1.0.0\",\n    \"electron-store\": \"^8.1.0\",\n    \"eslint\": \"8.41.0\",\n    \"eslint-config-next\": \"13.4.3\",\n    \"markdown-to-jsx\": \"^7.4.1\",\n    \"next\": \"13.4.3\",\n    \"postcss\": \"^8.4.29\",\n    \"react\": \"18.2.0\",\n    \"react-dom\": \"18.2.0\",\n    \"react-redux\": \"^9.1.0\",\n    \"react-resizable-panels\": \"^2.0.11\",\n    \"rxjs\": \"^7.8.1\",\n    \"tailwind-merge\": \"^2.2.1\",\n    \"tailwindcss\": \"^3.3.3\",\n    \"tailwindcss-animate\": \"^1.0.7\",\n    \"wait-on\": \"^7.0.1\"\n  },\n  \"devDependencies\": {\n    \"@typescript-eslint/eslint-plugin\": \"^6.21.0\",\n    \"@typescript-eslint/parser\": \"^6.21.0\",\n    \"dotenv\": \"^16.4.5\",\n    \"dprint\": \"^0.45.0\",\n    \"electron\": \"^26.2.0\",\n    \"electron-builder\": \"^24.6.4\",\n    \"electron-notarize\": \"^1.2.2\",\n    \"eslint\": \"^8.56.0\",\n    \"eslint-config-airbnb-base\": \"^15.0.0\",\n    \"eslint-plugin-compat\": \"^4.2.0\",\n    \"eslint-plugin-jest\": \"^27.6.3\",\n    \"eslint-plugin-jest-formatting\": \"^3.1.0\",\n    \"eslint-plugin-jsdoc\": \"^48.0.6\",\n    \"eslint-plugin-jsx-a11y\": \"^6.8.0\",\n    \"eslint-plugin-justinanastos\": \"^1.3.1\",\n    \"eslint-plugin-modules-newlines\": \"^0.0.7\",\n    \"eslint-plugin-no-unsanitized\": \"^4.0.2\",\n    \"eslint-plugin-react\": \"^7.33.2\",\n    
\"eslint-plugin-react-hooks\": \"^4.6.0\",\n    \"eslint-plugin-unused-imports\": \"^3.0.0\",\n    \"typescript\": \"^5.2.2\"\n  },\n  \"build\": {\n    \"appId\": \"mlx-chat\",\n    \"productName\": \"MLX Chat\",\n    \"afterSign\": \"notarize.js\",\n    \"win\": {\n      \"target\": [\n        \"nsis\"\n      ]\n    },\n    \"nsis\": {\n      \"oneClick\": false,\n      \"perMachine\": true,\n      \"allowToChangeInstallationDirectory\": true,\n      \"uninstallDisplayName\": \"MLX Chat\"\n    },\n    \"mac\": {\n      \"category\": \"your.app.category.type\",\n      \"target\": [\n        \"dmg\"\n      ],\n      \"gatekeeperAssess\": false,\n      \"hardenedRuntime\": true,\n      \"icon\": \"assets/icon.icns\",\n      \"entitlements\": \"./mac/entitlements.mac.inherit.plist\",\n      \"entitlementsInherit\": \"./mac/entitlements.mac.inherit.plist\"\n    },\n    \"dmg\": {\n      \"title\": \"MLX Chat Installer\",\n      \"sign\": false\n    },\n    \"extraFiles\": [\n      {\n        \"from\": \"assets\",\n        \"to\": \"resources\",\n        \"filter\": [\n          \"**/*\"\n        ]\n      },\n      {\n        \"from\": \"../dist\",\n        \"to\": \"resources/server\",\n        \"filter\": [\n          \"**/*\"\n        ]\n      }\n    ]\n  }\n}\n"
  },
  {
    "path": "app/postcss.config.js",
    "content": "module.exports = {\n  plugins: {\n    tailwindcss: {},\n    autoprefixer: {},\n  },\n}\n"
  },
  {
    "path": "app/src/AppProvider.tsx",
    "content": "'use client';\n\nimport {\n  useRef,\n} from 'react';\nimport {\n  Provider,\n} from 'react-redux';\nimport type {\n  AppStore,\n} from './lib/store';\nimport {\n  makeStore,\n} from './lib/store';\n\nexport default function StoreProvider({\n  children,\n}: {\n  children: React.ReactNode;\n}) {\n  const storeRef = useRef<AppStore>();\n  if (!storeRef.current) {\n    // Create the store instance the first time this renders\n    storeRef.current = makeStore();\n  }\n\n  return <Provider store={storeRef.current}>{children}</Provider>;\n}\n"
  },
  {
    "path": "app/src/app/globals.css",
    "content": "@tailwind base;\n@tailwind components;\n@tailwind utilities;\n\n@layer base {\n  :root {\n    --background: 0 0% 100%;\n    --foreground: 240 10% 3.9%;\n    --card: 0 0% 100%;\n    --card-foreground: 240 10% 3.9%;\n    --popover: 0 0% 100%;\n    --popover-foreground: 240 10% 3.9%;\n    --primary: 240 5.9% 10%;\n    --primary-foreground: 0 0% 98%;\n    --secondary: 240 4.8% 95.9%;\n    --secondary-foreground: 240 5.9% 10%;\n    --muted: 240 4.8% 95.9%;\n    --muted-foreground: 240 3.8% 46.1%;\n    --accent: 240 4.8% 95.9%;\n    --accent-foreground: 240 5.9% 10%;\n    --destructive: 0 72.22% 50.59%;\n    --destructive-foreground: 0 0% 98%;\n    --border: 240 5.9% 90%;\n    --input: 240 5.9% 90%;\n    --ring: 240 5% 64.9%;\n    --radius: 0.5rem;\n  }\n\n  @media screen and (prefers-color-scheme: dark) {\n      :root {\n      --background: 240 10% 3.9%;\n      --foreground: 0 0% 98%;\n      --card: 240 10% 3.9%;\n      --card-foreground: 0 0% 98%;\n      --popover: 240 10% 3.9%;\n      --popover-foreground: 0 0% 98%;\n      --primary: 0 0% 98%;\n      --primary-foreground: 240 5.9% 10%;\n      --secondary: 240 3.7% 15.9%;\n      --secondary-foreground: 0 0% 98%;\n      --muted: 240 3.7% 15.9%;\n      --muted-foreground: 240 5% 64.9%;\n      --accent: 240 3.7% 15.9%;\n      --accent-foreground: 0 0% 98%;\n      --destructive: 0 62.8% 30.6%;\n      --destructive-foreground: 0 85.7% 97.3%;\n      --border: 240 3.7% 15.9%;\n      --input: 240 3.7% 15.9%;\n      --ring: 240 4.9% 83.9%;\n    }\n  }\n}\n\n@layer base {\n  * {\n    @apply border-border;\n  }\n  body {\n    @apply text-foreground;\n    /* font-feature-settings: \"rlig\" 1, \"calt\" 1; */\n    font-synthesis-weight: none;\n    text-rendering: optimizeLegibility;\n  }\n}\n\n@layer utilities {\n  .step {\n    counter-increment: step;\n  }\n\n  .step:before {\n    @apply absolute w-9 h-9 bg-muted rounded-full font-mono font-medium text-center text-base inline-flex items-center justify-center -indent-px border-4 border-background;\n    @apply ml-[-50px] mt-[-4px];\n    content: counter(step);\n  }\n}\n\n@media (max-width: 640px) {\n  .container {\n    @apply px-4;\n  }\n}\n\n/* Update scrollbar when in dark mode */\n@media screen and (prefers-color-scheme: dark) {\n  ::-webkit-scrollbar-thumb {\n    background-color: hsl(var(--muted));\n    border-radius: 5px;\n    transition: all;\n  }\n  ::-webkit-scrollbar-thumb:hover {\n    background-color: hsl(255, 4%, 20%);\n  }\n}\n\n/* Update scrollbar when in light mode */\n@media screen and (prefers-color-scheme: light) {\n  ::-webkit-scrollbar-thumb {\n    background-color: rgb(38 38 38);\n    border-radius: 5px;\n    transition: all;\n  }\n  ::-webkit-scrollbar-thumb:hover {\n    background-color: hsl(255, 4%, 20%);\n  }\n}\n\n::-webkit-scrollbar {\n  width: 7px;\n}\n::-webkit-scrollbar-track {\n  background-color: transparent;\n}\n\nhtml {\n  font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont;\n}\n\nol,\nul,\nmenu {\n  list-style: outside;\n  margin: 0;\n  padding-left: 20px;\n}\n\n.drag {\n  -webkit-app-region: drag;\n}\n\n.no-drag {\n  -webkit-app-region: no-drag;\n}\n"
  },
  {
    "path": "app/src/app/layout.tsx",
    "content": "'use client';\n\nimport StoreProvider from '../AppProvider';\nimport './globals.css';\nimport '@fortawesome/fontawesome-svg-core/styles.css';\n// Prevent fontawesome from adding its CSS since we did it manually above:\nimport {\n  config,\n} from '@fortawesome/fontawesome-svg-core';\nimport {\n  TooltipProvider,\n} from '../components/ui/tooltip';\n\nconfig.autoAddCss = false;\n\nexport default function RootLayout({\n  children,\n}: {\n  children: React.ReactNode;\n}) {\n  return (\n    <html lang='en'>\n      <body\n        className={'min-h-screen overflow-y-hidden'}\n        style={{\n          userSelect: 'none',\n        }}\n      >\n        <TooltipProvider>\n          <StoreProvider>\n            {children}\n          </StoreProvider>\n        </TooltipProvider>\n      </body>\n    </html>\n  );\n}\n"
  },
  {
    "path": "app/src/app/page.tsx",
    "content": "'use client';\n\nimport {\n  faBan,\n  faCheckCircle,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n  FontAwesomeIcon,\n} from '@fortawesome/react-fontawesome';\nimport React, {\n  useEffect,\n  useState,\n} from 'react';\nimport Chat from '../components/chat/Chat';\nimport SelectDirectory from '../components/options/SelectDirectory';\nimport {\n  Button,\n} from '../components/ui/button';\nimport {\n  Tooltip,\n  TooltipContent,\n  TooltipTrigger,\n} from '../components/ui/tooltip';\nimport type {\n  ChatMessage,\n} from '../constants/chat';\nimport {\n  useAppDispatch,\n} from '../lib/hooks';\nimport {\n  startDirectoryIndexing,\n  stopDirectoryIndexing,\n} from '../lib/store';\n\nexport default function Home() {\n  const [selectedDirectory, setSelectedDirectory] = useState<string | null>(null);\n  const [chatHistory, setChatHistory] = useState<ChatMessage[]>([]);\n\n  const dispatch = useAppDispatch();\n\n  function handleOpen() {\n    if (typeof window !== 'undefined') {\n      window.electronAPI.selectDirectory();\n    }\n  }\n\n  useEffect(() => {\n    window.electronAPI.onSelectDirectory(async (customData: string[]) => {\n      setSelectedDirectory(customData[0]);\n      try {\n        dispatch(startDirectoryIndexing());\n        await fetch('http://localhost:8080/api/index', {\n          method: 'POST',\n          headers: {\n            'Content-Type': 'application/json',\n          },\n          body: JSON.stringify({\n            directory: customData[0],\n          }),\n        });\n        dispatch(stopDirectoryIndexing());\n        // TODO: spinner while indexing\n      } catch (error) {\n        // eslint-disable-next-line no-console\n        console.error('Error sending message: ', error);\n        dispatch(stopDirectoryIndexing());\n      }\n    });\n  }, []);\n\n  useEffect(() => {\n    window.electronAPI.onSelectDirectory(() => {\n      if (chatHistory.length) {\n        setChatHistory([\n          ...chatHistory,\n          { role: 'system', content: 'Assist' },\n        ]);\n      }\n    });\n  }, [chatHistory]);\n\n  const handleClearHistory = () => {\n    setChatHistory([]);\n    if (typeof window !== 'undefined') {\n      window.electronAPI.resizeWindow(99);\n    }\n  };\n\n  const clearDirectory = () => {\n    setSelectedDirectory(null);\n    if (chatHistory.length) {\n      setChatHistory([\n        ...chatHistory,\n        { role: 'system', content: 'Converse' },\n      ]);\n    }\n  };\n\n  return (\n    <main className='flex flex-col'>\n      <Chat\n        chatHistory={chatHistory}\n        setChatHistory={setChatHistory}\n        selectedDirectory={selectedDirectory}\n      />\n      <div className='border-t border-t-neut\n      ral-400 dark:border-t-neutral-700 pt-[5px] px-2'>\n        <div className='flex justify-between drag'>\n          {chatHistory.length\n            ? 
(\n              <Tooltip delayDuration={0}>\n                <TooltipTrigger>\n                  <Button\n                    className='bg-transparent no-drag text-neutral-800 gap-1 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-3 flex items-center cursor-pointer'\n                    onClick={handleClearHistory}\n                  >\n                    <FontAwesomeIcon icon={faCheckCircle} className='text-green-500' />\n                  </Button>\n                </TooltipTrigger>\n                <TooltipContent>Clear History</TooltipContent>\n              </Tooltip>\n            )\n            : (\n              <Button\n                className='bg-transparent no-drag text-neutral-800 gap-1 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-3 flex items-center cursor-pointer'\n                disabled={true}\n              >\n                <FontAwesomeIcon icon={faBan} className='text-red-400' />\n              </Button>\n            )}\n          <SelectDirectory\n            clearDirectory={clearDirectory}\n            handleOpen={handleOpen}\n            selectedDirectory={selectedDirectory}\n          />\n        </div>\n      </div>\n    </main>\n  );\n}\n"
  },
  {
    "path": "app/src/app/settings/page.tsx",
    "content": "'use client';\n\nimport type {\n  IconProp,\n} from '@fortawesome/fontawesome-svg-core';\nimport {\n  faCog,\n  faMessage,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n  FontAwesomeIcon,\n} from '@fortawesome/react-fontawesome';\nimport React, {\n  useEffect,\n} from 'react';\nimport SelectModel from '../../components/options/SelectModel';\nimport {\n  Textarea,\n} from '../../components/ui/textarea';\nimport {\n  convertToNiceShortcut,\n  useKeyboardShortcut,\n} from '../../lib/hooks';\nimport {\n  cn,\n} from '../../lib/utils';\n\nenum SETTINGS {\n  GENERAL,\n  PROMPTS,\n}\n\nfunction SettingsOption({\n  title,\n  icon,\n  onClick,\n  selected,\n}: {\n  title: string;\n  icon: IconProp;\n  onClick: () => void;\n  selected?: boolean;\n}) {\n  return (\n    <div\n      onClick={onClick}\n      className={cn(\n        'flex flex-col items-center h-[44px] gap-[2px] no-drag w-fit p-1 px-2 justify-center hover:bg-[#E2E3E2] dark:hover:bg-[#454544] active:bg-[#D5D6D5] dark:active:bg-[#525251] cursor-default rounded-md text-[#6C6C6C] dark:text-[#9C9C9B] active:text-[#202020] dark:active:text-[#EEEEEE]',\n        {\n          'bg-[#E2E3E2] dark:bg-[#454544]': selected,\n        },\n      )}\n    >\n      <FontAwesomeIcon\n        className={cn('text-[20px] pt-1', {\n          'dark:text-[#0D87FF] text-[#0066EB]': selected,\n        })}\n        icon={icon}\n      />\n      <h1\n        className={cn('text-[11px]', {\n          'dark:text-[#0D87FF] text-[#0066EB]': selected,\n        })}\n      >\n        {title}\n      </h1>\n    </div>\n  );\n}\n\nfunction GeneralSettings() {\n  const {\n    startListening,\n    stopListening,\n    shortcut,\n  } = useKeyboardShortcut();\n\n  const [keybind, setKeybind] = React.useState<string>(\n    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('keybind') : '⌘O',\n  );\n  const [model, setModel] = React.useState<string>(\n    typeof window !== 'undefined'\n      ? window.electronAPI.fetchSetting('model')\n      : 'mistralai/Mistral-7B-Instruct-v0.2',\n  );\n\n  useEffect(() => {\n    if (!shortcut) {\n      return;\n    }\n    setKeybind(shortcut);\n  }, [shortcut]);\n\n  return (\n    <div className='flex flex-col justify-center w-full items-center'>\n      <div className='flex items-center mt-2'>\n        <p className='text-sm'>Launch keybind:</p>\n        <input\n          className='rounded-sm text-[12px] text-center drop-shadow-sm bg-[#fffff] dark:bg-[#343432] border-0 h-[18px] w-28 ml-3 active:border-0 focus:border-0 outline-offset-2 focus:outline-blue-400'\n          type='text'\n          readOnly\n          value={convertToNiceShortcut(keybind)}\n          onFocus={startListening}\n          onBlur={() => {\n            stopListening();\n            if (typeof window !== 'undefined') {\n              window.electronAPI.updateSetting('keybind', shortcut);\n            }\n          }}\n        />\n      </div>\n      <div className='flex items-center mt-2'>\n        <p className='text-sm mr-2'>Default model:</p>\n        <SelectModel\n          selectedModel={model}\n          handleModelChange={(selectedModel) => {\n            setModel(selectedModel);\n            if (typeof window !== 'undefined' && selectedModel) {\n              window.electronAPI.updateSetting('model', selectedModel);\n            }\n          }}\n        />\n      </div>\n    </div>\n  );\n}\n\nfunction PromptSettings() {\n  const [personalization, setPersonalization] = React.useState<string>(\n    typeof window !== 'undefined' ? 
window.electronAPI.fetchSetting('personalization') : '',\n  );\n  const [response, setResponse] = React.useState<string>(\n    typeof window !== 'undefined' ? window.electronAPI.fetchSetting('customResponse') : '',\n  );\n\n  return (\n    <div className='flex flex-col justify-center w-full items-center gap-4'>\n      <div className='flex flex-col items-center mt-2 gap-2'>\n        <p className='text-sm flex-shrink-0 font-bold'>Personalization</p>\n        <Textarea\n          className='bg-[#C9C9C9] dark:bg-[#252523] border-[#B5B5B5] dark:border-[#3B3B39] border resize-none w-[300px]'\n          value={personalization}\n          onChange={(e) => {\n            setPersonalization(e.target.value);\n            if (typeof window !== 'undefined') {\n              window.electronAPI.updateSetting('personalization', e.target.value);\n            }\n          }}\n          rows={5}\n          placeholder={`Things to know about you... e.g.,\n  - I enjoy thought provoking conversation\n  - I am a fan of the arts and culture`}\n        />\n      </div>\n      <div className='flex flex-col items-center mt-2 gap-2'>\n        <p className='text-sm flex-shrink-0 font-bold'>Custom Response</p>\n        <Textarea\n          className='bg-[#C9C9C9] dark:bg-[#252523] border-[#B5B5B5] dark:border-[#3B3B39] border resize-none w-[300px]'\n          value={response}\n          onChange={(e) => {\n            setResponse(e.target.value);\n            if (typeof window !== 'undefined') {\n              window.electronAPI.updateSetting('customResponse', e.target.value);\n            }\n          }}\n          rows={5}\n          placeholder={`How to format responses... e.g.,\n  - Respond in a concise manner\n  - Do not use slang`}\n        />\n      </div>\n    </div>\n  );\n}\n\nexport default function Settings() {\n  const [selectedSetting, setSelectedSetting] = React.useState<SETTINGS>(SETTINGS.GENERAL);\n\n  return (\n    <main className='flex flex-col bg-[#F1F1F1] dark:bg-[#383736] h-screen'>\n      <div className='h-[81px] border-0 border-b border-b-neutral-300 dark:border-b-zinc-950 drag flex flex-col'>\n        <h1 className='text-[12px] font-bold text-[#6C6C6C] dark:text-[#9C9C9B] text-center pt-1'>\n          {selectedSetting === SETTINGS.GENERAL\n            ? 'General'\n            : 'Prompt'}\n        </h1>\n        <div className='flex justify-center items-center gap-[1px] mt-1'>\n          <SettingsOption\n            title='General'\n            icon={faCog}\n            onClick={() => setSelectedSetting(SETTINGS.GENERAL)}\n            selected={selectedSetting === SETTINGS.GENERAL}\n          />\n          <SettingsOption\n            title='Prompts'\n            icon={faMessage}\n            onClick={() => setSelectedSetting(SETTINGS.PROMPTS)}\n            selected={selectedSetting === SETTINGS.PROMPTS}\n          />\n        </div>\n      </div>\n      <div className='flex-grow dark:bg-[#292929] bg-[#EEEDEC]'>\n        {selectedSetting === SETTINGS.GENERAL\n          ? <GeneralSettings />\n          : <PromptSettings />}\n      </div>\n    </main>\n  );\n}\n"
  },
  {
    "path": "app/src/components/chat/Chat.tsx",
    "content": "import React from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/chat';\nimport {\n  useAppDispatch,\n} from '../../lib/hooks';\nimport {\n  startWaitingForResponse,\n  stopWaitingForResponse,\n} from '../../lib/store';\nimport {\n  cn,\n} from '../../lib/utils';\nimport ChatInput from './ChatInput';\nimport ChatMessages from './ChatMessages';\n\nconst Chat = ({\n  selectedDirectory,\n  chatHistory,\n  setChatHistory,\n}: {\n  selectedDirectory: string | null;\n  chatHistory: ChatMessage[];\n  setChatHistory: (chats: ChatMessage[]) => void;\n}) => {\n  const dispatch = useAppDispatch();\n  const sendMessage = async (message: string) => {\n    try {\n      if (chatHistory.length === 0) {\n        window.electronAPI.resizeWindow(500);\n      }\n      const newHistory = [\n        ...chatHistory,\n        { role: 'user' as const, content: message },\n      ];\n      setChatHistory(newHistory);\n      dispatch(startWaitingForResponse());\n      const response = await fetch('http://localhost:8080/api/query', {\n        method: 'POST',\n        headers: {\n          'Content-Type': 'application/json',\n        },\n        body: JSON.stringify({\n          messages: selectedDirectory\n            ? [{ role: 'user', content: message }]\n            : newHistory.filter((chat) => chat.role !== 'system'),\n          temperature: 0.7,\n          // eslint-disable-next-line @typescript-eslint/naming-convention\n          top_p: 1.0,\n          // eslint-disable-next-line @typescript-eslint/naming-convention\n          max_tokens: 200,\n          directory: selectedDirectory,\n          instructions: {\n            personalization: typeof window !== 'undefined'\n              ? window.electronAPI.fetchSetting('personalization')\n              : '',\n            response: typeof window !== 'undefined'\n              ? window.electronAPI.fetchSetting('customResponse')\n              : '',\n          },\n        }),\n      });\n      dispatch(stopWaitingForResponse());\n\n      const responseData = await response.json();\n      const assistantResponse = responseData.choices[0].message.content;\n\n      setChatHistory([\n        ...newHistory,\n        { role: 'assistant', content: assistantResponse },\n      ]);\n    } catch (error) {\n      dispatch(stopWaitingForResponse());\n      // eslint-disable-next-line no-console\n      console.error('Error sending message: ', error);\n    }\n  };\n\n  return (\n    <>\n      <div\n        className={cn('flex justify-center border-b-neutral-400 dark:border-b-neutral-700', {\n          'border-b': chatHistory.length > 0,\n        })}\n      >\n        <ChatInput sendMessage={sendMessage} />\n      </div>\n      <div\n        className={cn(\n          'flex-grow min-w-full border-0 flex h-0',\n          {\n            'h-[400px]': chatHistory.length > 0,\n          },\n        )}\n      >\n        <ChatMessages chatHistory={chatHistory} />\n      </div>\n    </>\n  );\n};\n\nexport default Chat;\n"
  },
  {
    "path": "app/src/components/chat/ChatInput.tsx",
    "content": "import React, {\n  useEffect,\n  useRef,\n  useState,\n} from 'react';\nimport {\n  useAppSelector,\n} from '../../lib/hooks';\nimport {\n  Input,\n} from '../ui/input';\n\nconst ChatInput = ({\n  sendMessage,\n}: {\n  sendMessage: (text: string) => void;\n}) => {\n  const [message, setMessage] = useState<string>('');\n  const inputRef = useRef<HTMLInputElement>(null);\n\n  const handleSend = (e: React.KeyboardEvent) => {\n    if (e.key !== 'Enter' || message.length === 0) {\n      return;\n    }\n    e.preventDefault();\n    setMessage('');\n    sendMessage(message);\n  };\n\n  // detect website focus and focus the input\n  const handleFocus = () => {\n    if (document.activeElement !== inputRef.current) {\n      inputRef.current?.focus();\n    }\n  };\n\n  useEffect(() => {\n    window.addEventListener('focus', handleFocus);\n    return () => {\n      window.removeEventListener('focus', handleFocus);\n    };\n  }, []);\n\n  const isDirectoryIndexing = useAppSelector((state) => state.isDirectoryIndexing);\n\n  return (\n    <div\n      className='w-full py-2 drag'\n      onClick={() => {\n        if (document.activeElement !== inputRef.current) {\n          inputRef.current?.focus();\n        }\n      }}\n    >\n      <Input\n        value={message}\n        onChange={(e) => setMessage(e.target.value)}\n        placeholder={isDirectoryIndexing ? 'Indexing your files..' : 'Enter prompt here'}\n        onKeyDown={handleSend}\n        ref={inputRef}\n        disabled={isDirectoryIndexing}\n        className={'text-xl no-drag border-0 focus-visible:outline-transparent focus-visible:ring-0 focus-visible:shadow-0 w-full shadow-0'}\n      />\n    </div>\n  );\n};\n\nexport default ChatInput;\n"
  },
  {
    "path": "app/src/components/chat/ChatMessage.tsx",
    "content": "import Markdown from 'markdown-to-jsx';\nimport React from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/chat';\n\nconst Message = ({\n  message,\n}: {\n  message: ChatMessage;\n}) => (\n  <div\n    className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'}`}\n  >\n    <div\n      className={`p-2 rounded-sm ${\n        message.role === 'user'\n          ? 'bg-blue-500 text-white'\n          : 'bg-[#E9E9EB] dark:bg-zinc-500'\n      }`}\n    >\n      <div className='text-md select-text'>\n        <Markdown\n          children={message.content ?? ''}\n        />\n      </div>\n    </div>\n  </div>\n);\n\nexport default Message;\n"
  },
  {
    "path": "app/src/components/chat/ChatMessages.tsx",
    "content": "/* eslint-disable function-paren-newline */\nimport {\n  faCircleNotch,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n  FontAwesomeIcon,\n} from '@fortawesome/react-fontawesome';\nimport React, {\n  useEffect,\n} from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/chat';\nimport {\n  useAppSelector,\n} from '../../lib/hooks';\nimport Message from './ChatMessage';\nimport SystemMessage from './SystemMessage';\n\nconst ChatMessages = ({\n  chatHistory,\n}: {\n  chatHistory: ChatMessage[];\n}) => {\n  const messagesRef = React.useRef<HTMLDivElement>(null);\n\n  const scrollToBottom = () => {\n    const scrollHeight = messagesRef.current?.scrollHeight;\n    const height = messagesRef.current?.clientHeight ?? 0;\n    const maxScrollTop = scrollHeight ? scrollHeight - height : 0;\n    if (messagesRef.current) {\n      messagesRef.current.scrollTop = maxScrollTop > 0 ? maxScrollTop : 0;\n    }\n  };\n\n  const isWaitingForResponse = useAppSelector((state) => state.isWaitingForResponse);\n\n  useEffect(() => {\n    // check if the user is not at the bottom of the chat\n    const currentScroll = messagesRef.current?.scrollTop ?? 0;\n    const scrollHeight = messagesRef.current?.scrollHeight;\n    const height = messagesRef.current?.clientHeight ?? 0;\n    const maxScrollTop = scrollHeight ? scrollHeight - height : 0;\n    const scrollInHistory = (maxScrollTop - currentScroll) > 200;\n\n    if (scrollInHistory && chatHistory[chatHistory.length - 1]?.role !== 'user') {\n      return;\n    }\n\n    scrollToBottom();\n  }, [chatHistory]);\n  return chatHistory.length\n    ? (\n      <div ref={messagesRef} className='flex flex-col flex-grow p-4 gap-4 overflow-y-scroll'>\n        {chatHistory.map((message, index) => (message.role !== 'system'\n          ? (\n            <Message\n              key={index}\n              message={message}\n            />\n          )\n          : (\n            <SystemMessage\n              key={index}\n              message={message}\n            />\n          ))\n        )}\n        {isWaitingForResponse\n          ? (\n            <div\n              className={'flex justify-start'}\n            >\n              <div\n                className={'p-2 rounded-sm'}\n              >\n                <div className='text-md select-text'>\n                  <FontAwesomeIcon className='animate-spin' icon={faCircleNotch} />\n                </div>\n              </div>\n            </div>\n          )\n          : null}\n      </div>\n    )\n    : null;\n};\n\nexport default ChatMessages;\n"
  },
  {
    "path": "app/src/components/chat/SystemMessage.tsx",
    "content": "import React from 'react';\nimport type {\n  ChatMessage,\n} from '../../constants/chat';\n\nconst Message = ({\n  message,\n}: {\n  message: ChatMessage;\n}) => (\n  <div\n    className={'flex w-full'}\n  >\n    <div\n      className={'rounded-sm w-full relative flex items-center'}\n    >\n      <div className='w-full h-[1px] bg-red-500 rounded-md' />\n      <div className='text-[12px] font-semibold select-text px-2 text-red-500 bg-transparent flex-grow rounded-md text-center py-1 whitespace-nowrap'>\n        {message.content}\n      </div>\n      <div className='w-full h-[1px] bg-red-500 rounded-md' />\n      <div className='arrow-tag text-[10px] p-0 absolute right-0 font-bold select-text uppercase pr-1 pl-1 text-white bg-red-500 flex-grow rounded-sm text-center whitespace-nowrap'>\n        <svg\n          className='absolute -left-[5px] top-[1px] z-[-1]'\n          aria-hidden='true'\n          role='img'\n          width='8'\n          height='13'\n          viewBox='0 0 8 13'\n        >\n          <path\n            className='fill-red-500 text-red-500'\n            stroke='currentColor'\n            fill='transparent'\n            d='M8.16639 0.5H9C10.933 0.5 12.5 2.067 12.5 4V9C12.5 10.933 10.933 12.5 9 12.5H8.16639C7.23921 12.5 6.34992 12.1321 5.69373 11.4771L0.707739 6.5L5.69373 1.52292C6.34992 0.86789 7.23921 0.5 8.16639 0.5Z'\n          >\n          </path>\n        </svg>\n        Mode\n      </div>\n    </div>\n  </div>\n);\n\nexport default Message;\n"
  },
  {
    "path": "app/src/components/options/SelectDirectory.tsx",
    "content": "import {\n  faCheckCircle,\n  faCircleNotch,\n  faXmark,\n} from '@fortawesome/free-solid-svg-icons';\nimport {\n  FontAwesomeIcon,\n} from '@fortawesome/react-fontawesome';\nimport React, {\n  useEffect,\n} from 'react';\nimport {\n  useAppSelector,\n  usePrevious,\n} from '../../lib/hooks';\nimport {\n  cn,\n} from '../../lib/utils';\nimport {\n  Button,\n} from '../ui/button';\n\nconst SelectDirectory = ({\n  handleOpen,\n  selectedDirectory,\n  clearDirectory,\n}: {\n  handleOpen: () => void;\n  selectedDirectory: string | null;\n  clearDirectory: () => void;\n}) => {\n  const shortenedDirectory = selectedDirectory\n    ? `/${selectedDirectory.split('/')[1]}/../${selectedDirectory.split('/').pop()}`\n    : 'Select Directory';\n  const [isCheckShowing, setIsCheckShowing] = React.useState(false);\n\n  const isDirectoryIndexing = useAppSelector((state) => state.isDirectoryIndexing);\n\n  const oldLoadingState = usePrevious(isDirectoryIndexing);\n\n  useEffect(() => {\n    if (oldLoadingState && !isDirectoryIndexing) {\n      setIsCheckShowing(true);\n      setTimeout(() => {\n        setIsCheckShowing(false);\n      }, 3000);\n    }\n  }, [isDirectoryIndexing]);\n\n  return (\n    <div className='flex no-drag items-center group'>\n      <Button\n        className={cn(\n          'bg-transparent text-neutral-800 z-0 dark:text-white text-sm font-normal shadow-none hover:bg-neutral-200 dark:hover:bg-neutral-700 transition-all border-0 border-zinc-600 w-fit rounded-md py-1 px-2 flex items-center cursor-pointer',\n          {\n            'hover:bg-transparent dark:hover:bg-transparent cursor-default': isDirectoryIndexing,\n          },\n        )}\n        onClick={isDirectoryIndexing ? undefined : handleOpen}\n      >\n        <div className='pr-1'>\n          {selectedDirectory && !isDirectoryIndexing && !isCheckShowing && (\n            <div\n              className='group-hover:opacity-100 opacity-0 px-1 z-10 hover:bg-neutral-300 dark:hover:bg-neutral-800 rounded-sm transition-all cursor-pointer'\n              onClick={(e) => {\n                e.stopPropagation();\n                clearDirectory();\n              }}\n            >\n              <FontAwesomeIcon\n                icon={faXmark}\n              />\n            </div>\n          )}\n          {isDirectoryIndexing && <FontAwesomeIcon className='animate-spin' icon={faCircleNotch} />}\n          {isCheckShowing && <FontAwesomeIcon className='text-green-500' icon={faCheckCircle} />}\n        </div>\n        {shortenedDirectory}\n      </Button>\n    </div>\n  );\n};\n\nexport default SelectDirectory;\n"
  },
  {
    "path": "app/src/components/options/SelectModel.tsx",
    "content": "import React from 'react';\nimport {\n  Select,\n  SelectContent,\n  SelectGroup,\n  SelectItem,\n  SelectLabel,\n  SelectTrigger,\n  SelectValue,\n} from '../ui/select';\n\nconst SelectModel = ({\n  selectedModel,\n  handleModelChange,\n}: {\n  selectedModel: string | null;\n  handleModelChange: (model: string) => void;\n}) => (\n  <div className='no-drag'>\n    <Select\n      value={selectedModel ?? ''}\n      onValueChange={(value) => handleModelChange(value)}\n    >\n      <SelectTrigger className='a-icon w-[140px] h-5 border-none shadow-transparent bg-white dark:bg-[#606160] transition-all border border-zinc-600 focus:ring-0 focus-within:ring-0 focus-visible:ring-0 peer-focus-within:ring-0 text-neutral-800 dark:text-white'>\n        <SelectValue placeholder='Select a model' />\n      </SelectTrigger>\n      <SelectContent>\n        <SelectGroup>\n          <SelectLabel>AI Model</SelectLabel>\n          <SelectItem value='mistralai/Mistral-7B-Instruct-v0.2'>Mistral7B</SelectItem>\n          <SelectItem value='google/gemma-7b-it'>Gemma7B</SelectItem>\n        </SelectGroup>\n      </SelectContent>\n    </Select>\n  </div>\n);\n\nexport default SelectModel;\n"
  },
  {
    "path": "app/src/components/ui/button.tsx",
    "content": "import {\n  Slot,\n} from '@radix-ui/react-slot';\nimport {\n  cva,\n  type VariantProps,\n} from 'class-variance-authority';\nimport * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nconst buttonVariants = cva(\n  'inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50',\n  {\n    variants: {\n      variant: {\n        default: 'bg-primary text-primary-foreground shadow hover:bg-primary/90',\n        destructive: 'bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90',\n        outline:\n          'border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground',\n        secondary: 'bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80',\n        ghost: 'hover:bg-accent hover:text-accent-foreground',\n        link: 'text-primary underline-offset-4 hover:underline',\n      },\n      size: {\n        default: 'h-9 px-4 py-2',\n        sm: 'h-8 rounded-md px-3 text-xs',\n        lg: 'h-10 rounded-md px-8',\n        icon: 'h-9 w-9',\n      },\n    },\n    defaultVariants: {\n      variant: 'default',\n      size: 'default',\n    },\n  },\n);\n\nexport interface ButtonProps\n  extends React.ButtonHTMLAttributes<HTMLButtonElement>, VariantProps<typeof buttonVariants>\n{\n  asChild?: boolean;\n}\n\nconst Button = React.forwardRef<HTMLButtonElement, ButtonProps>(\n  ({\n    className,\n    variant,\n    size,\n    asChild = false,\n    ...props\n  }, ref) => {\n    // eslint-disable-next-line @typescript-eslint/naming-convention\n    const Comp = asChild ? Slot : 'button';\n    return (\n      <Comp\n        className={cn(buttonVariants({ variant, size, className }))}\n        ref={ref}\n        {...props}\n      />\n    );\n  },\n);\nButton.displayName = 'Button';\n\nexport {\n  Button,\n  buttonVariants,\n};\n"
  },
  {
    "path": "app/src/components/ui/input.tsx",
    "content": "import * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nexport interface InputProps extends React.InputHTMLAttributes<HTMLInputElement> {}\n\nconst Input = React.forwardRef<HTMLInputElement, InputProps>(\n  ({ className, type, ...props }, ref) => (\n    <input\n      type={type}\n      className={cn(\n        'flex h-9 w-full rounded-md border border-input bg-transparent px-3 py-1 text-sm shadow-0 transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50',\n        className,\n      )}\n      ref={ref}\n      {...props}\n    />\n  ),\n);\nInput.displayName = 'Input';\n\nexport {\n  Input,\n};\n"
  },
  {
    "path": "app/src/components/ui/resizable.tsx",
    "content": "'use client';\n\nimport {\n  DragHandleDots2Icon,\n} from '@radix-ui/react-icons';\nimport * as ResizablePrimitive from 'react-resizable-panels';\nimport {\n  cn,\n} from '../../lib/utils';\n\nconst ResizablePanelGroup = ({\n  className,\n  ...props\n}: React.ComponentProps<typeof ResizablePrimitive.PanelGroup>) => (\n  <ResizablePrimitive.PanelGroup\n    className={cn(\n      'flex h-full w-full data-[panel-group-direction=vertical]:flex-col',\n      className,\n    )}\n    {...props}\n  />\n);\n\nconst ResizablePanel = ResizablePrimitive.Panel;\n\nconst ResizableHandle = ({\n  withHandle,\n  className,\n  ...props\n}: React.ComponentProps<typeof ResizablePrimitive.PanelResizeHandle> & {\n  withHandle?: boolean;\n}) => (\n  <ResizablePrimitive.PanelResizeHandle\n    className={cn(\n      'relative flex w-px items-center justify-center bg-border after:absolute after:inset-y-0 after:left-1/2 after:w-1 after:-translate-x-1/2 focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring focus-visible:ring-offset-1 data-[panel-group-direction=vertical]:h-px data-[panel-group-direction=vertical]:w-full data-[panel-group-direction=vertical]:after:left-0 data-[panel-group-direction=vertical]:after:h-1 data-[panel-group-direction=vertical]:after:w-full data-[panel-group-direction=vertical]:after:-translate-y-1/2 data-[panel-group-direction=vertical]:after:translate-x-0 [&[data-panel-group-direction=vertical]>div]:rotate-90',\n      className,\n    )}\n    {...props}\n  >\n    {withHandle && (\n      <div className='z-10 flex h-4 w-3 items-center justify-center rounded-sm border bg-border'>\n        <DragHandleDots2Icon className='h-2.5 w-2.5' />\n      </div>\n    )}\n  </ResizablePrimitive.PanelResizeHandle>\n);\n\nexport {\n  ResizableHandle,\n  ResizablePanel,\n  ResizablePanelGroup,\n};\n"
  },
  {
    "path": "app/src/components/ui/select.tsx",
    "content": "'use client';\n\nimport {\n  CaretSortIcon,\n  CheckIcon,\n  ChevronDownIcon,\n  ChevronUpIcon,\n} from '@radix-ui/react-icons';\nimport * as SelectPrimitive from '@radix-ui/react-select';\nimport * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nconst Select = SelectPrimitive.Root;\n\nconst SelectGroup = SelectPrimitive.Group;\n\nconst SelectValue = SelectPrimitive.Value;\n\nconst SelectTrigger = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.Trigger>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>\n>(({ className, children, ...props }, ref) => (\n  <SelectPrimitive.Trigger\n    ref={ref}\n    className={cn(\n      'flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus:outline-none disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1',\n      className,\n    )}\n    {...props}\n  >\n    {children}\n    {className?.includes('a-icon')\n      ? (\n        <SelectPrimitive.Icon asChild>\n          <CaretSortIcon className='h-4 w-4 opacity-50' />\n        </SelectPrimitive.Icon>\n      )\n      : null}\n  </SelectPrimitive.Trigger>\n));\nSelectTrigger.displayName = SelectPrimitive.Trigger.displayName;\n\nconst SelectScrollUpButton = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>\n>(({ className, ...props }, ref) => (\n  <SelectPrimitive.ScrollUpButton\n    ref={ref}\n    className={cn(\n      'flex cursor-default items-center justify-center py-1',\n      className,\n    )}\n    {...props}\n  >\n    <ChevronUpIcon />\n  </SelectPrimitive.ScrollUpButton>\n));\nSelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName;\n\nconst SelectScrollDownButton = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>\n>(({ className, ...props }, ref) => (\n  <SelectPrimitive.ScrollDownButton\n    ref={ref}\n    className={cn(\n      'flex cursor-default items-center justify-center py-1',\n      className,\n    )}\n    {...props}\n  >\n    <ChevronDownIcon />\n  </SelectPrimitive.ScrollDownButton>\n));\nSelectScrollDownButton.displayName = SelectPrimitive.ScrollDownButton.displayName;\n\nconst SelectContent = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.Content>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>\n>(({\n  className,\n  children,\n  position = 'popper',\n  ...props\n}, ref) => (\n  <SelectPrimitive.Portal>\n    <SelectPrimitive.Content\n      ref={ref}\n      className={cn(\n        'relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',\n        position === 'popper'\n          && 'data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1',\n        className,\n      )}\n      position={position}\n      {...props}\n    >\n      <SelectScrollUpButton />\n      
<SelectPrimitive.Viewport\n        className={cn(\n          'p-1',\n          position === 'popper'\n            && 'h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]',\n        )}\n      >\n        {children}\n      </SelectPrimitive.Viewport>\n      <SelectScrollDownButton />\n    </SelectPrimitive.Content>\n  </SelectPrimitive.Portal>\n));\nSelectContent.displayName = SelectPrimitive.Content.displayName;\n\nconst SelectLabel = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.Label>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>\n>(({ className, ...props }, ref) => (\n  <SelectPrimitive.Label\n    ref={ref}\n    className={cn('px-2 py-1.5 text-sm font-semibold', className)}\n    {...props}\n  />\n));\nSelectLabel.displayName = SelectPrimitive.Label.displayName;\n\nconst SelectItem = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.Item>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>\n>(({ className, children, ...props }, ref) => (\n  <SelectPrimitive.Item\n    ref={ref}\n    className={cn(\n      'relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-2 pr-8 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50',\n      className,\n    )}\n    {...props}\n  >\n    <span className='absolute right-2 flex h-3.5 w-3.5 items-center justify-center'>\n      <SelectPrimitive.ItemIndicator>\n        <CheckIcon className='h-4 w-4' />\n      </SelectPrimitive.ItemIndicator>\n    </span>\n    <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>\n  </SelectPrimitive.Item>\n));\nSelectItem.displayName = SelectPrimitive.Item.displayName;\n\nconst SelectSeparator = React.forwardRef<\n  React.ElementRef<typeof SelectPrimitive.Separator>,\n  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>\n>(({ className, ...props }, ref) => (\n  <SelectPrimitive.Separator\n    ref={ref}\n    className={cn('-mx-1 my-1 h-px bg-muted', className)}\n    {...props}\n  />\n));\nSelectSeparator.displayName = SelectPrimitive.Separator.displayName;\n\nexport {\n  Select,\n  SelectContent,\n  SelectGroup,\n  SelectItem,\n  SelectLabel,\n  SelectScrollDownButton,\n  SelectScrollUpButton,\n  SelectSeparator,\n  SelectTrigger,\n  SelectValue,\n};\n"
  },
  {
    "path": "app/src/components/ui/textarea.tsx",
    "content": "import * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nexport interface TextareaProps extends React.TextareaHTMLAttributes<HTMLTextAreaElement> {}\n\nconst Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(\n  ({ className, ...props }, ref) => (\n    <textarea\n      className={cn(\n        'flex min-h-[60px] w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50',\n        className,\n      )}\n      ref={ref}\n      {...props}\n    />\n  ),\n);\nTextarea.displayName = 'Textarea';\n\nexport {\n  Textarea,\n};\n"
  },
  {
    "path": "app/src/components/ui/tooltip.tsx",
    "content": "'use client';\n\nimport * as TooltipPrimitive from '@radix-ui/react-tooltip';\nimport * as React from 'react';\nimport {\n  cn,\n} from '../../lib/utils';\n\nconst TooltipProvider = TooltipPrimitive.Provider;\n\nconst Tooltip = TooltipPrimitive.Root;\n\nconst TooltipTrigger = TooltipPrimitive.Trigger;\n\nconst TooltipContent = React.forwardRef<\n  React.ElementRef<typeof TooltipPrimitive.Content>,\n  React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>\n>(({ className, sideOffset = 4, ...props }, ref) => (\n  <TooltipPrimitive.Content\n    ref={ref}\n    sideOffset={sideOffset}\n    className={cn(\n      'z-50 overflow-hidden rounded-md bg-primary px-3 py-1.5 text-xs text-primary-foreground animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',\n      className,\n    )}\n    {...props}\n  />\n));\nTooltipContent.displayName = TooltipPrimitive.Content.displayName;\n\nexport {\n  Tooltip,\n  TooltipContent,\n  TooltipProvider,\n  TooltipTrigger,\n};\n"
  },
  {
    "path": "app/src/constants/chat.tsx",
    "content": "export type ChatMessage = {\n  role: 'user' | 'assistant' | 'system';\n  content: string | null;\n};\n"
  },
  {
    "path": "app/src/lib/hooks.ts",
    "content": "import {\n  useEffect,\n  useRef,\n  useState,\n} from 'react';\nimport {\n  useDispatch,\n  useSelector,\n  useStore,\n} from 'react-redux';\nimport type {\n  TypedUseSelectorHook,\n} from 'react-redux';\nimport type {\n  AppDispatch,\n  AppStore,\n  RootState,\n} from './store';\n\n// Use throughout your app instead of plain `useDispatch` and `useSelector`\nexport const useAppDispatch: () => AppDispatch = useDispatch;\nexport const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;\nexport const useAppStore: () => AppStore = useStore;\n\nexport function usePrevious<T>(value: T): T | undefined {\n  const ref = useRef<T>();\n  useEffect(() => {\n    ref.current = value;\n  });\n  return ref.current;\n}\n\nexport function convertToNiceShortcut(shortcut: string) {\n  return shortcut.replace('Cmd', '⌘').replace('Option', '⌥').replace('Shift', '⇧').replaceAll(\n    '+',\n    '',\n  );\n}\n\nexport function useKeyboardShortcut() {\n  const [isListening, setIsListening] = useState(false);\n  const [shortcut, setShortcut] = useState('');\n\n  useEffect(() => {\n    const handleKeyDown = (event: KeyboardEvent) => {\n      if (!isListening) { return; }\n\n      // Prevent default action to avoid interfering with normal browser shortcuts\n      event.preventDefault();\n\n      // command key for mac icon\n      //\n\n      const keys = [];\n      if (event.ctrlKey) { keys.push('Ctrl'); }\n      if (event.shiftKey) { keys.push('Shift'); }\n      if (event.altKey) { keys.push('Option'); }\n      if (event.metaKey) {\n        keys.push('Cmd'); // Command key for Mac\n      }\n      // Avoid adding modifier keys alone, check if another key is also pressed\n      if (\n        event.key.length === 1\n        || (event.key !== 'Control' && event.key !== 'Shift' && event.key !== 'Alt'\n          && event.key !== 'Meta')\n      ) {\n        keys.push(event.key.toUpperCase());\n      }\n\n      const combination = keys.join('+');\n      setShortcut(combination);\n    };\n\n    if (isListening) {\n      window.addEventListener('keydown', handleKeyDown);\n    }\n\n    return () => {\n      window.removeEventListener('keydown', handleKeyDown);\n    };\n  }, [isListening]);\n\n  const startListening = () => setIsListening(true);\n  const stopListening = () => setIsListening(false);\n\n  return {\n    startListening,\n    stopListening,\n    shortcut,\n  };\n}\n"
  },
  {
    "path": "app/src/lib/store.ts",
    "content": "import {\n  configureStore,\n  createSlice,\n} from '@reduxjs/toolkit';\n\nconst globalSlice = createSlice({\n  name: 'global',\n  initialState: {\n    isDirectoryIndexing: false,\n    isWaitingForResponse: false,\n  },\n  reducers: {\n    startDirectoryIndexing: (state) => {\n      // eslint-disable-next-line no-param-reassign\n      state.isDirectoryIndexing = true;\n    },\n    stopDirectoryIndexing: (state) => {\n      // eslint-disable-next-line no-param-reassign\n      state.isDirectoryIndexing = false;\n    },\n    startWaitingForResponse: (state) => {\n      // eslint-disable-next-line no-param-reassign\n      state.isWaitingForResponse = true;\n    },\n    stopWaitingForResponse: (state) => {\n      // eslint-disable-next-line no-param-reassign\n      state.isWaitingForResponse = false;\n    },\n  },\n});\n\nexport const {\n  startDirectoryIndexing,\n  stopDirectoryIndexing,\n  startWaitingForResponse,\n  stopWaitingForResponse,\n} = globalSlice.actions;\n\nexport const makeStore = () =>\n  configureStore({\n    reducer: globalSlice.reducer,\n  });\n\n// Infer the type of makeStore\nexport type AppStore = ReturnType<typeof makeStore>;\n// Infer the `RootState` and `AppDispatch` types from the store itself\nexport type RootState = ReturnType<AppStore['getState']>;\nexport type AppDispatch = AppStore['dispatch'];\n"
  },
  {
    "path": "app/src/lib/utils.ts",
    "content": "import {\n  type ClassValue,\n  clsx,\n} from 'clsx';\nimport {\n  twMerge,\n} from 'tailwind-merge';\n\nexport function cn(...inputs: ClassValue[]) {\n  return twMerge(clsx(inputs));\n}\n"
  },
  {
    "path": "app/tailwind.config.main.js",
    "content": "/** @type {import('tailwindcss').Config} */\nmodule.exports = {\n  content: [\"./main/**/*.{js,ts,jsx,tsx,mdx,html}\"],\n  //   purge: ['./subdir/index.html',     './src/components/**/*.{js,ts,jsx,tsx,mdx}',\n  // ],\n  theme: {\n    extend: {},\n  },\n  plugins: [],\n};\n"
  },
  {
    "path": "app/tailwind.config.ts",
    "content": "/* eslint-disable @typescript-eslint/naming-convention */\nimport type {\n  Config,\n} from 'tailwindcss';\n\nconst config = {\n  darkMode: 'media',\n  content: [\n    './pages/**/*.{ts,tsx}',\n    './components/**/*.{ts,tsx}',\n    './app/**/*.{ts,tsx}',\n    './src/**/*.{ts,tsx}',\n  ],\n  prefix: '',\n  theme: {\n    container: {\n      center: true,\n      padding: '2rem',\n      screens: {\n        '2xl': '1400px',\n      },\n    },\n    extend: {\n      colors: {\n        border: 'hsl(var(--border))',\n        input: 'hsl(var(--input))',\n        ring: 'hsl(var(--ring))',\n        background: 'hsl(var(--background))',\n        foreground: 'hsl(var(--foreground))',\n        primary: {\n          DEFAULT: 'hsl(var(--primary))',\n          foreground: 'hsl(var(--primary-foreground))',\n        },\n        secondary: {\n          DEFAULT: 'hsl(var(--secondary))',\n          foreground: 'hsl(var(--secondary-foreground))',\n        },\n        destructive: {\n          DEFAULT: 'hsl(var(--destructive))',\n          foreground: 'hsl(var(--destructive-foreground))',\n        },\n        muted: {\n          DEFAULT: 'hsl(var(--muted))',\n          foreground: 'hsl(var(--muted-foreground))',\n        },\n        accent: {\n          DEFAULT: 'hsl(var(--accent))',\n          foreground: 'hsl(var(--accent-foreground))',\n        },\n        popover: {\n          DEFAULT: 'hsl(var(--popover))',\n          foreground: 'hsl(var(--popover-foreground))',\n        },\n        card: {\n          DEFAULT: 'hsl(var(--card))',\n          foreground: 'hsl(var(--card-foreground))',\n        },\n      },\n      borderRadius: {\n        lg: 'var(--radius)',\n        md: 'calc(var(--radius) - 2px)',\n        sm: 'calc(var(--radius) - 4px)',\n      },\n      keyframes: {\n        'accordion-down': {\n          from: { height: '0' },\n          to: { height: 'var(--radix-accordion-content-height)' },\n        },\n        'accordion-up': {\n          from: { height: 'var(--radix-accordion-content-height)' },\n          to: { height: '0' },\n        },\n      },\n      animation: {\n        'accordion-down': 'accordion-down 0.2s ease-out',\n        'accordion-up': 'accordion-up 0.2s ease-out',\n      },\n    },\n  },\n  // eslint-disable-next-line global-require\n  plugins: [require('tailwindcss-animate')],\n} satisfies Config;\n\nexport default config;\n"
  },
  {
    "path": "app/tsconfig.json",
    "content": "{\n  \"compilerOptions\": {\n    \"target\": \"es5\",\n    \"lib\": [\n      \"dom\",\n      \"dom.iterable\",\n      \"esnext\",\n    ],\n    \"allowJs\": true,\n    \"skipLibCheck\": true,\n    \"strict\": true,\n    \"forceConsistentCasingInFileNames\": true,\n    \"noEmit\": true,\n    \"esModuleInterop\": true,\n    \"module\": \"esnext\",\n    \"moduleResolution\": \"node\",\n    \"resolveJsonModule\": true,\n    \"isolatedModules\": true,\n    \"jsx\": \"preserve\",\n    \"incremental\": true,\n    \"plugins\": [\n      {\n        \"name\": \"next\",\n      },\n    ],\n    \"paths\": {\n      \"@/*\": [\n        \"./src/*\",\n      ],\n    },\n  },\n  \"include\": [\n    \"next-env.d.ts\",\n    \"**/*.ts\",\n    \"**/*.tsx\",\n    \".next/types/**/*.ts\",\n    \"build/types/**/*.ts\",\n    \"main/preload.ts\",\n    \"main/main.ts\",\n    \"out/types/**/*.ts\",\n    \".eslintrc.cjs\",\n  ],\n  \"exclude\": [\n    \"node_modules\",\n    \"out\",\n    \"build\",\n  ],\n}\n"
  },
  {
    "path": "runner.py",
    "content": "# Parent script to package (PyInstaller) server\n#\n# Example Usage:\n#\n# pyinstaller --onefile --collect-all mlx --copy-metadata opentelemetry-sdk \\\n# --hidden-import server.models --hidden-import server.models.gemma --hidden-import server.models.bert --hidden-import server.models.llama \\\n# runner.py\n\nfrom server import server\nserver.main()\n"
  },
  {
    "path": "runner.sh",
    "content": "#!/bin/bash\n\ncollect_modules=(\n  \"mlx\"\n  \"chromadb\"\n)\n\nhidden_imports=(\n  \"server.models\"\n  \"server.models.gemma\"\n  \"server.models.bert\"\n  \"server.models.llama\"\n)\n\nexclude_modules=(\n  \"matplotlib\"\n  \"pandas\"\n  \"PIL\"\n  \"IPython\"\n)\n\nmisc_params=(\n  \"--copy-metadata opentelemetry-sdk\"\n)\n\ncommand=\"pyinstaller --onefile runner.py\"\n\nfor module in \"${collect_modules[@]}\"; do\n  command+=\" --collect-all $module\"\ndone\nfor module in \"${hidden_imports[@]}\"; do\n  command+=\" --hidden-import $module\"\ndone\nfor module in \"${exclude_modules[@]}\"; do\n  command+=\" --exclude-module $module\"\ndone\nfor param in \"${misc_params[@]}\"; do\n  command+=\" $param\"\ndone\n\neval \"$command\"\n"
  },
  {
    "path": "server/__init__.py",
    "content": "from .utils import generate, load, convert\n\n__version__ = \"0.1.0\"\n"
  },
  {
    "path": "server/convert.py",
    "content": "import argparse\n\nfrom .utils import convert\n\n\ndef configure_parser() -> argparse.ArgumentParser:\n    \"\"\"\n    Configures and returns the argument parser for the script.\n\n    Returns:\n        argparse.ArgumentParser: Configured argument parser.\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Convert Hugging Face model to MLX format\"\n    )\n\n    parser.add_argument(\"--hf-path\", type=str,\n                        help=\"Path to the Hugging Face model.\")\n    parser.add_argument(\n        \"--mlx-path\", type=str, default=\"mlx_model\", help=\"Path to save the MLX model.\"\n    )\n    parser.add_argument(\n        \"-q\", \"--quantize\", help=\"Generate a quantized model.\", action=\"store_true\"\n    )\n    parser.add_argument(\n        \"--q-group-size\", help=\"Group size for quantization.\", type=int, default=64\n    )\n    parser.add_argument(\n        \"--q-bits\", help=\"Bits per weight for quantization.\", type=int, default=4\n    )\n    parser.add_argument(\n        \"--dtype\",\n        help=\"Type to save the parameters, ignored if -q is given.\",\n        type=str,\n        choices=[\"float16\", \"bfloat16\", \"float32\"],\n        default=\"float16\",\n    )\n    parser.add_argument(\n        \"--upload-repo\",\n        help=\"The Hugging Face repo to upload the model to.\",\n        type=str,\n        default=None,\n    )\n    return parser\n\n\nif __name__ == \"__main__\":\n    parser = configure_parser()\n    args = parser.parse_args()\n    convert(**vars(args))\n"
  },
  {
    "path": "server/models/__init__.py",
    "content": ""
  },
  {
    "path": "server/models/base.py",
    "content": "import inspect\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass BaseModelArgs:\n    @classmethod\n    def from_dict(cls, params):\n        return cls(\n            **{\n                k: v\n                for k, v in params.items()\n                if k in inspect.signature(cls).parameters\n            }\n        )\n"
  },
  {
    "path": "server/models/bert.py",
    "content": "import math\nimport inspect\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom dataclasses import dataclass\nfrom typing import List, Dict, Optional, Tuple, Union, Callable\n\nfrom .base import BaseModelArgs\n\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    classifier_dropout: float\n    hidden_act: str\n    hidden_dropout_prob: float\n    hidden_size: int\n    initializer_range: float\n    intermediate_size: int\n    layer_norm_eps: float\n    max_position_embeddings: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    pad_token_id: int\n    position_embedding_type: str\n    torch_dtype: str\n    type_vocab_size: int\n    use_cache: bool\n    vocab_size: int\n    chunk_size_feed_forward: int = None\n    attention_probs_dropout_prob: float = 0.0\n    is_decoder: bool = False\n    add_cross_attention: bool = False\n    output_attentions: bool = False\n    output_hidden_states: bool = False\n    use_return_dict: bool = True\n\n\ndef apply_chunking_to_forward(\n    forward_fn: Callable[..., mx.array], chunk_size: int, chunk_dim: int, *input_tensors\n) -> mx.array:\n    \"\"\"\n    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension\n    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.\n\n    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly\n    applying `forward_fn` to `input_tensors`.\n\n    Args:\n        forward_fn (`Callable[..., mx.array]`):\n            The forward function of the model.\n        chunk_size (`int`):\n            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.\n        chunk_dim (`int`):\n            The dimension over which the `input_tensors` should be chunked.\n        input_tensors (`Tuple[mx.array]`):\n            The input tensors of `forward_fn` which will be chunked\n\n    Returns:\n        `mx.array`: A tensor with the same shape as the `forward_fn` would have given if applied`.\n\n\n    Examples:\n\n    ```python\n    # rename the usual forward() fn to forward_chunk()\n    def __call___chunk(self, hidden_states):\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\n    # implement a chunked forward function\n    def __call__(self, hidden_states):\n        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)\n    ```\"\"\"\n\n    assert len(input_tensors) > 0, f\"{\n        input_tensors} has to be a tuple/list of tensors\"\n\n    # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility\n    num_args_in_forward_chunk_fn = len(\n        inspect.signature(forward_fn).parameters)\n    if num_args_in_forward_chunk_fn != len(input_tensors):\n        raise ValueError(\n            f\"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {\n                len(input_tensors)} input \"\n            \"tensors are given\"\n        )\n\n    if chunk_size > 0:\n        tensor_shape = input_tensors[0].shape[chunk_dim]\n        for input_tensor in input_tensors:\n            if input_tensor.shape[chunk_dim] != tensor_shape:\n                raise ValueError(\n                    f\"All input tenors have to be of the same shape: {\n                        tensor_shape}, \"\n                    f\"found shape {input_tensor.shape[chunk_dim]}\"\n             
\n                )\n\n        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:\n            raise ValueError(\n                f\"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk \"\n                f\"size {chunk_size}\"\n            )\n\n        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size\n\n        # chunk each input tensor (mx.split is the MLX counterpart of torch's Tensor.chunk)\n        input_tensors_chunks = tuple(mx.split(\n            input_tensor, num_chunks, axis=chunk_dim) for input_tensor in input_tensors)\n        # apply forward fn to every tuple\n        output_chunks = tuple(forward_fn(*input_tensors_chunk)\n                              for input_tensors_chunk in zip(*input_tensors_chunks))\n        # concatenate output at same dimension\n        return mx.concatenate(output_chunks, axis=chunk_dim)\n\n    return forward_fn(*input_tensors)\n\n\nclass BertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            config.vocab_size, config.hidden_size)\n        self.position_embeddings = nn.Embedding(\n            config.max_position_embeddings, config.hidden_size)\n        self.token_type_embeddings = nn.Embedding(\n            config.type_vocab_size, config.hidden_size)\n\n        self._position_ids = mx.expand_dims(\n            mx.arange(0, config.max_position_embeddings), axis=0)\n        self._token_type_ids = mx.zeros((self._position_ids.shape))\n\n        self.LayerNorm = nn.LayerNorm(\n            config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        self.position_embedding_type = getattr(\n            config, \"position_embedding_type\", \"absolute\")\n\n    def __call__(\n        self,\n        input_ids: Optional[float] = None,\n        token_type_ids: Optional[float] = None,\n        position_ids: Optional[float] = None,\n        inputs_embeds: Optional[float] = None,\n        past_key_values_length: int = 0,\n    ) -> mx.array:\n        if input_ids is not None:\n            input_shape = input_ids.shape\n        else:\n            input_shape = inputs_embeds.shape[:-1]\n\n        seq_length = input_shape[1]\n\n        if position_ids is None:\n            position_ids = self._position_ids[:, past_key_values_length:\n                                              seq_length + past_key_values_length]\n\n        # Setting the token_type_ids to the registered buffer in the constructor, where it is all zeros, usually occurs\n        # when it is auto-generated; the registered buffer helps users when tracing the model without passing\n        # token_type_ids, and solves issue #5664\n        if token_type_ids is None:\n            if hasattr(self, \"_token_type_ids\"):\n                buffered_token_type_ids = self._token_type_ids[:, :seq_length]\n                buffered_token_type_ids_expanded = mx.repeat(\n                    buffered_token_type_ids, input_shape[0], axis=0)\n                token_type_ids = buffered_token_type_ids_expanded.astype(\n                    mx.int8)\n            else:\n                token_type_ids = mx.zeros(\n                    input_shape, dtype=mx.float16)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        token_type_embeddings = self.token_type_embeddings(token_type_ids)\n\n        embeddings = inputs_embeds + 
token_type_embeddings\n        if self.position_embedding_type == \"absolute\":\n            position_embeddings = self.position_embeddings(position_ids)\n            embeddings += position_embeddings\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass BertSelfAttention(nn.Module):\n    def __init__(self, config, position_embedding_type=None):\n        super().__init__()\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, \"embedding_size\"):\n            raise ValueError(\n                f\"Hidden size ({config.hidden_size}) is not a multiple of \"\n                f\"the number of attention heads ({config.num_attention_heads})\"\n            )\n\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(\n            config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        self.key = nn.Linear(config.hidden_size, self.all_head_size)\n        self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.position_embedding_type = position_embedding_type or getattr(\n            config, \"position_embedding_type\", \"absolute\"\n        )\n        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n            self.max_position_embeddings = config.max_position_embeddings\n            self.distance_embedding = nn.Embedding(\n                2 * config.max_position_embeddings - 1, self.attention_head_size)\n\n        self.is_decoder = config.is_decoder\n\n    def transpose_for_scores(self, x: mx.array) -> mx.array:\n        new_x_shape = x.shape[\n            :-1] + (self.num_attention_heads, self.attention_head_size)\n        x = x.reshape(new_x_shape)\n        return x.transpose([0, 2, 1, 3])\n\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        attention_mask: Optional[float] = None,\n        head_mask: Optional[float] = None,\n        encoder_hidden_states: Optional[float] = None,\n        encoder_attention_mask: Optional[float] = None,\n        past_key_value: Optional[Tuple[Tuple[float]]] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[mx.array]:\n        mixed_query_layer = self.query(hidden_states)\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is not None\n\n        if is_cross_attention and past_key_value is not None:\n            # reuse k,v, cross_attentions\n            key_layer = past_key_value[0]\n            value_layer = past_key_value[1]\n            attention_mask = encoder_attention_mask\n        elif is_cross_attention:\n            key_layer = self.transpose_for_scores(\n                self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(\n                self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = 
self.transpose_for_scores(self.value(hidden_states))\n            key_layer = mx.concatenate([past_key_value[0], key_layer], axis=2)\n            value_layer = mx.concatenate([past_key_value[1], value_layer], axis=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        use_cache = past_key_value is not None\n        if self.is_decoder:\n            # if cross_attention save Tuple(mx.array, mx.array) of all cross attention key/value_states.\n            # Further calls to cross_attention layer can then reuse all cross-attention\n            # key/value_states (first \"if\" case)\n            # if uni-directional self-attention (decoder) save Tuple(mx.array, mx.array) of\n            # all previous decoder key/value_states. Further calls to uni-directional self-attention\n            # can concat previous decoder key/value_states to current projected key/value_states (third \"elif\" case)\n            # if encoder bi-directional self-attention `past_key_value` is always `None`\n            past_key_value = (key_layer, value_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = mx.matmul(\n            query_layer, key_layer.transpose([0, 1, 3, 2]))\n\n        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n            query_length, key_length = query_layer.shape[2], key_layer.shape[2]\n            if use_cache:\n                position_ids_l = mx.array(key_length - 1, dtype=mx.float16).reshape(\n                    -1, 1\n                )\n            else:\n                position_ids_l = mx.arange(\n                    query_length, dtype=mx.float16).reshape(-1, 1)\n            position_ids_r = mx.arange(\n                key_length, dtype=mx.float16).reshape(1, -1)\n            distance = position_ids_l - position_ids_r\n\n            positional_embedding = self.distance_embedding(\n                distance + self.max_position_embeddings - 1)\n            positional_embedding = positional_embedding.astype(\n                query_layer.dtype)  # fp16 compatibility\n\n            if self.position_embedding_type == \"relative_key\":\n                relative_position_scores = mx.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding)\n                attention_scores = attention_scores + relative_position_scores\n            elif self.position_embedding_type == \"relative_key_query\":\n                relative_position_scores_query = mx.einsum(\n                    \"bhld,lrd->bhlr\", query_layer, positional_embedding)\n                relative_position_scores_key = mx.einsum(\n                    \"bhrd,lrd->bhlr\", key_layer, positional_embedding)\n                attention_scores = attention_scores + \\\n                    relative_position_scores_query + relative_position_scores_key\n\n        attention_scores = attention_scores / \\\n            math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask (precomputed for all layers in the BertModel __call__ function)\n            attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = mx.softmax(attention_scores, axis=-1)\n\n        # This is 
actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs = attention_probs * head_mask\n\n        context_layer = mx.matmul(attention_probs, value_layer)\n\n        context_layer = context_layer.transpose([0, 2, 1, 3])\n        new_context_layer_shape = context_layer.shape[\n            :-2] + (self.all_head_size,)\n        context_layer = context_layer.reshape(new_context_layer_shape)\n\n        outputs = (context_layer, attention_probs) if output_attentions else (\n            context_layer,)\n\n        if self.is_decoder:\n            outputs = outputs + (past_key_value,)\n        return outputs\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(\n            config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, position_embedding_type=None):\n        super().__init__()\n        self.self = BertSelfAttention(\n            config, position_embedding_type=position_embedding_type)\n        self.output = BertSelfOutput(config)\n        self.pruned_heads = set()\n\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        attention_mask: Optional[float] = None,\n        head_mask: Optional[float] = None,\n        encoder_hidden_states: Optional[float] = None,\n        encoder_attention_mask: Optional[float] = None,\n        past_key_value: Optional[Tuple[Tuple[float]]] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[mx.array]:\n        self_outputs = self.self(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            encoder_hidden_states,\n            encoder_attention_mask,\n            past_key_value,\n            output_attentions,\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n        # add attentions if we output them\n        outputs = (attention_output,) + self_outputs[1:]\n        return outputs\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        self.intermediate_act_fn = nn.GELU()\n\n    def __call__(self, hidden_states: mx.array) -> mx.array:\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(\n            config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:\n        hidden_states 
= self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = BertAttention(config)\n        self.is_decoder = config.is_decoder\n        self.add_cross_attention = config.add_cross_attention\n        if self.add_cross_attention:\n            if not self.is_decoder:\n                raise ValueError(\n                    f\"{self} should be used as a decoder model if cross attention is added\")\n            self.crossattention = BertAttention(\n                config, position_embedding_type=\"absolute\")\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        attention_mask: Optional[float] = None,\n        head_mask: Optional[float] = None,\n        encoder_hidden_states: Optional[float] = None,\n        encoder_attention_mask: Optional[float] = None,\n        past_key_value: Optional[Tuple[Tuple[float]]] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[mx.array]:\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n        )\n        attention_output = self_attention_outputs[0]\n\n        # if decoder, the last output is tuple of self-attn cache\n        if self.is_decoder:\n            outputs = self_attention_outputs[1:-1]\n            present_key_value = self_attention_outputs[-1]\n        else:\n            # add self attentions if we output attention weights\n            outputs = self_attention_outputs[1:]\n\n        cross_attn_present_key_value = None\n        if self.is_decoder and encoder_hidden_states is not None:\n            if not hasattr(self, \"crossattention\"):\n                raise ValueError(\n                    f\"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers\"\n                    \" by setting `config.add_cross_attention=True`\"\n                )\n\n            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple\n            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None\n            cross_attention_outputs = self.crossattention(\n                attention_output,\n                attention_mask,\n                head_mask,\n                encoder_hidden_states,\n                encoder_attention_mask,\n                cross_attn_past_key_value,\n                output_attentions,\n            )\n            attention_output = cross_attention_outputs[0]\n            # add cross attentions if we output attention weights\n            outputs = outputs + cross_attention_outputs[1:-1]\n\n            # add cross-attn cache to positions 3,4 of present_key_value tuple\n       
     cross_attn_present_key_value = cross_attention_outputs[-1]\n            present_key_value = present_key_value + cross_attn_present_key_value\n\n        layer_output = apply_chunking_to_forward(\n            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output\n        )\n        outputs = (layer_output,) + outputs\n\n        # if decoder, return the attn key/values as the last output\n        if self.is_decoder:\n            outputs = outputs + (present_key_value,)\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.layer = [\n            BertLayer(config) for _ in range(config.num_hidden_layers)\n        ]\n        self.gradient_checkpointing = False\n\n    def __call__(\n        self,\n        hidden_states: mx.array,\n        attention_mask: Optional[float] = None,\n        head_mask: Optional[float] = None,\n        encoder_hidden_states: Optional[float] = None,\n        encoder_attention_mask: Optional[float] = None,\n        past_key_values: Optional[Tuple[Tuple[float]]] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = False,\n        output_hidden_states: Optional[bool] = False,\n        return_dict: Optional[bool] = True,\n    ) -> Union[Tuple[mx.array], Dict]:\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                use_cache = False\n\n        next_decoder_cache = () if use_cache else None\n        for i, layer_module in enumerate(self.layer):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    layer_module.__call__,\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n                if self.config.add_cross_attention:\n                    all_cross_attentions = 
all_cross_attentions + (layer_outputs[2],)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return dict(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def __call__(self, hidden_states: mx.array) -> mx.array:\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertModel(nn.Module):\n    \"\"\"\n\n    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n    cross-attention is added between the self-attention layers, following the architecture described in [Attention is\n    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n\n    To behave as a decoder the model needs to be initialized with the \`is_decoder\` argument of the configuration set\n    to \`True\`. 
To be used in a Seq2Seq model, the model needs to be initialized with both the \`is_decoder\` argument and\n    \`add_cross_attention\` set to \`True\`; an \`encoder_hidden_states\` is then expected as an input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config: ModelArgs, add_pooling_layer=True):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = BertEmbeddings(config)\n        self.encoder = BertEncoder(config)\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def get_extended_attention_mask(\n        self, attention_mask: mx.array, input_shape: Tuple[int], dtype: float = mx.float16\n    ) -> mx.array:\n        \"\"\"\n        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n\n        Arguments:\n            attention_mask (\`mx.array\`):\n                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n            input_shape (\`Tuple[int]\`):\n                The shape of the input to the model.\n\n        Returns:\n            \`mx.array\` The extended attention mask, cast to \`dtype\`.\n        \"\"\"\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if len(attention_mask.shape) == 3:\n            extended_attention_mask = attention_mask[:, None, :, :]\n        elif len(attention_mask.shape) == 2:\n            # Provided a padding mask of dimensions [batch_size, seq_length]\n            # - if the model is a decoder, apply a causal mask in addition to the padding mask\n            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n            if self.config.is_decoder:\n                # Combining the causal mask with the padding mask is not implemented\n                # in this port; fail loudly rather than fall through with an unbound name.\n                raise NotImplementedError(\n                    \"Decoder-style causal masking is not supported\")\n            else:\n                extended_attention_mask = attention_mask[:, None, None, :]\n        else:\n            raise ValueError(\n                f\"Wrong shape for input_ids (shape {input_shape}) \"\n                f\"or attention_mask (shape {attention_mask.shape})\"\n            )\n\n        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 0.0 for\n        # positions we want to attend and the dtype's smallest value for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        extended_attention_mask = extended_attention_mask.astype(dtype)  # fp16 compatibility\n        # -65504 is torch.finfo(torch.float16).min\n        extended_attention_mask = (1.0 - extended_attention_mask) * -65504\n        return extended_attention_mask\n\n    def get_head_mask(\n        self, head_mask: Optional[mx.array], num_hidden_layers: int, is_attention_chunked: bool = False\n    ) -> mx.array:\n        \"\"\"\n        Prepare the head mask if needed.\n\n        Args:\n            head_mask (\`mx.array\` with shape \`[num_heads]\` or \`[num_hidden_layers x num_heads]\`, *optional*):\n                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).\n            num_hidden_layers (\`int\`):\n   
             The number of hidden layers in the model.\n            is_attention_chunked (\`bool\`, *optional*, defaults to \`False\`):\n                Whether or not the attention scores are computed in chunks.\n\n        Returns:\n            \`mx.array\` with shape \`[num_hidden_layers x batch x num_heads x seq_length x seq_length]\` or list with\n            \`[None]\` for each layer.\n        \"\"\"\n        if head_mask is not None:\n            head_mask = self._convert_head_mask_to_5d(\n                head_mask, num_hidden_layers)\n            if is_attention_chunked is True:\n                # mx.array has no \`unsqueeze\`; add the trailing axis with expand_dims\n                head_mask = mx.expand_dims(head_mask, -1)\n        else:\n            head_mask = [None] * num_hidden_layers\n\n        return head_mask\n\n    def __call__(\n        self,\n        input_ids: Optional[mx.array] = None,\n        attention_mask: Optional[mx.array] = None,\n        token_type_ids: Optional[mx.array] = None,\n        position_ids: Optional[mx.array] = None,\n        head_mask: Optional[mx.array] = None,\n        inputs_embeds: Optional[mx.array] = None,\n        encoder_hidden_states: Optional[mx.array] = None,\n        encoder_attention_mask: Optional[mx.array] = None,\n        past_key_values: Optional[List[float]] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple[mx.array], dict]:\n        r\"\"\"\n        encoder_hidden_states  (\`mx.float16\` of shape \`(batch_size, sequence_length, hidden_size)\`, *optional*):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (\`mx.float16\` of shape \`(batch_size, sequence_length)\`, *optional*):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in \`[0, 1]\`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (\`tuple(tuple(mx.float16))\` of length \`config.n_layers\` with each tuple having 4 tensors of shape \`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)\`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n\n            If \`past_key_values\` are used, the user can optionally input only the last \`decoder_input_ids\` (those that\n            don't have their past key value states given to this model) of shape \`(batch_size, 1)\` instead of all\n            \`decoder_input_ids\` of shape \`(batch_size, sequence_length)\`.\n        use_cache (\`bool\`, *optional*):\n            If set to \`True\`, \`past_key_values\` key value states are returned and can be used to speed up decoding (see\n            \`past_key_values\`).\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if self.config.is_decoder:\n            use_cache = use_cache if use_cache is not None else self.config.use_cache\n        else:\n            use_cache = False\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.shape\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.shape[:-1]\n        else:\n            raise ValueError(\n                \"You have to specify either input_ids or inputs_embeds\")\n\n        batch_size, seq_length = input_shape\n\n        # past_key_values_length\n        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n\n        if attention_mask is None:\n            attention_mask = mx.ones(\n                (batch_size, seq_length + past_key_values_length))\n\n        if token_type_ids is None:\n            if hasattr(self.embeddings, \"_token_type_ids\"):\n                buffered_token_type_ids = self.embeddings._token_type_ids[:, :seq_length]\n                buffered_token_type_ids_expanded = mx.repeat(\n                    buffered_token_type_ids, batch_size, axis=0)\n                token_type_ids = buffered_token_type_ids_expanded.astype(\n                    mx.int8)\n            else:\n                # token type ids index an embedding table, so they must be integers\n                token_type_ids = mx.zeros(input_shape, dtype=mx.int32)\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        extended_attention_mask: mx.array = self.get_extended_attention_mask(\n            attention_mask, input_shape)\n        encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicates we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(\n            head_mask, self.config.num_hidden_layers)\n\n        embedding_output = self.embeddings(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            token_type_ids=token_type_ids,\n            inputs_embeds=inputs_embeds,\n            past_key_values_length=past_key_values_length,\n        )\n\n        
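# encoder_extended_attention_mask stays None above: this port only exercises\n        # the encoder path, so no cross-attention mask is consumed downstream.\n        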
encoder_outputs = self.encoder(\n            embedding_output,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = encoder_outputs['last_hidden_state'] if return_dict else encoder_outputs[0]\n        pooled_output = self.pooler(\n            sequence_output) if self.pooler is not None else None\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return dict(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs['past_key_values'],\n            hidden_states=encoder_outputs['hidden_states'],\n            attentions=encoder_outputs['attentions'],\n            cross_attentions=encoder_outputs['cross_attentions'],\n        )\n\n\nclass Model(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.model_type = args.model_type\n        self.model = BertModel(args)\n\n    def __call__(\n        self,\n        input_ids: mx.array,\n        attention_mask: Optional[mx.array] = None,\n    ):\n        return self.model(input_ids, attention_mask)\n\n    @staticmethod\n    def sanitize(weights):\n        # drop unused position_ids buffers and prefix bare keys with \"model.\"\n        return {\n            f'model.{k}' if 'model' not in k else k: v\n            for k, v in weights.items()\n            if 'embeddings.position_ids' not in k\n        }\n\n    @property\n    def layers(self):\n        # BertModel keeps its transformer blocks on encoder.layer\n        return self.model.encoder.layer\n"
  },
  {
    "path": "server/models/gemma.py",
    "content": "from dataclasses import dataclass\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .base import BaseModelArgs\n\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    hidden_size: int\n    num_hidden_layers: int\n    intermediate_size: int\n    num_attention_heads: int\n    head_dim: int\n    rms_norm_eps: float\n    vocab_size: int\n    num_key_value_heads: int = None\n    rope_theta: float = 10000\n    rope_traditional: bool = False\n\n\n@partial(mx.compile, shapeless=True)\ndef rms_norm(x, weight, eps):\n    x = x.astype(mx.float32)\n    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)\n    return (1.0 + weight) * x.astype(weight.dtype)\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dims: int, eps: float = 1e-5):\n        super().__init__()\n        self.weight = mx.ones((dims,))\n        self.eps = eps\n\n    def __call__(self, x):\n        return rms_norm(x, self.weight, self.eps)\n\n\nclass Attention(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n\n        dim = args.hidden_size\n        self.n_heads = n_heads = args.num_attention_heads\n        self.n_kv_heads = n_kv_heads = args.num_key_value_heads\n        self.head_dim = head_dim = args.head_dim\n\n        self.repeats = n_heads // n_kv_heads\n\n        self.scale = head_dim**-0.5\n\n        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)\n        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)\n        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)\n        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)\n\n        self.rope = nn.RoPE(\n            head_dim,\n            traditional=args.rope_traditional,\n            base=args.rope_theta,\n        )\n\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Tuple[mx.array, mx.array]] = None,\n    ) -> mx.array:\n        B, L, D = x.shape\n\n        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)\n\n        # Prepare the queries, keys and values for the attention computation\n        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)\n        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)\n        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)\n\n        if self.repeats > 1:\n            keys = mx.repeat(keys, self.repeats, axis=1)\n            values = mx.repeat(values, self.repeats, axis=1)\n\n        if cache is not None:\n            key_cache, value_cache = cache\n            queries = self.rope(queries, offset=key_cache.shape[2])\n            keys = self.rope(keys, offset=key_cache.shape[2])\n            keys = mx.concatenate([key_cache, keys], axis=2)\n            values = mx.concatenate([value_cache, values], axis=2)\n        else:\n            queries = self.rope(queries)\n            keys = self.rope(keys)\n\n        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)\n        if mask is not None:\n            scores += mask\n        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)\n        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)\n        return self.o_proj(output), (keys, values)\n\n\nclass MLP(nn.Module):\n    def __init__(self, dim, hidden_dim):\n        super().__init__()\n        self.gate_proj = nn.Linear(dim, hidden_dim, 
bias=False)\n        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)\n        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)\n\n    def __call__(self, x) -> mx.array:\n        return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.num_attention_heads = args.num_attention_heads\n        self.hidden_size = args.hidden_size\n        self.self_attn = Attention(args)\n        self.mlp = MLP(args.hidden_size, args.intermediate_size)\n        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)\n        self.args = args\n\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Tuple[mx.array, mx.array]] = None,\n    ) -> mx.array:\n        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)\n        h = x + r\n        r = self.mlp(self.post_attention_layernorm(h))\n        out = h + r\n        return out, cache\n\n\nclass GemmaModel(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.args = args\n        self.vocab_size = args.vocab_size\n        self.num_hidden_layers = args.num_hidden_layers\n        assert self.vocab_size > 0\n        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)\n        self.layers = [\n            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)\n        ]\n        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)\n\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache=None,\n    ):\n        h = self.embed_tokens(inputs)\n        h = h * (self.args.hidden_size**0.5)\n\n        mask = None\n        if h.shape[1] > 1:\n            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])\n            mask = mask.astype(h.dtype)\n\n        if cache is None:\n            cache = [None] * len(self.layers)\n\n        for e, layer in enumerate(self.layers):\n            h, cache[e] = layer(h, mask, cache[e])\n\n        return self.norm(h), cache\n\n\nclass Model(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.model_type = args.model_type\n        self.model = GemmaModel(args)\n\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache=None,\n    ):\n        out, cache = self.model(inputs, cache)\n        out = out @ self.model.embed_tokens.weight.T\n        return out, cache\n\n    @property\n    def layers(self):\n        return self.model.layers\n"
  },
  {
    "path": "server/models/layers.py",
    "content": "from functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\n@partial(mx.compile, shapeless=True)\ndef rms_norm(x, weight, eps):\n    x = x.astype(mx.float32)\n    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)\n    return weight * x.astype(weight.dtype)\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dims: int, eps: float = 1e-5):\n        super().__init__()\n        self.weight = mx.ones((dims,))\n        self.eps = eps\n\n    def __call__(self, x):\n        return rms_norm(x, self.weight, self.eps)\n\n\n@partial(mx.compile, shapeless=True)\ndef ln_norm(x, eps, weight=None, bias=None):\n    t = x.dtype\n    x = x.astype(mx.float32)\n    means = mx.mean(x, axis=-1, keepdims=True)\n    var = mx.var(x, axis=-1, keepdims=True)\n    x = (x - means) * mx.rsqrt(var + eps)\n    x = x.astype(t)\n    return weight * x + bias if weight is not None else x\n\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):\n        super().__init__()\n        if affine:\n            self.bias = mx.zeros((dims,))\n            self.weight = mx.ones((dims,))\n        self.eps = eps\n        self.dims = dims\n\n    def _extra_repr(self):\n        return f\"{self.dims}, eps={self.eps}, affine={'weight' in self}\"\n\n    def __call__(self, x: mx.array) -> mx.array:\n        if \"weight\" in self:\n            return ln_norm(x, self.eps, self.weight, self.bias)\n        else:\n            return ln_norm(x, self.eps)\n"
  },
  {
    "path": "server/models/llama.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple, Union\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .base import BaseModelArgs\nfrom .layers import RMSNorm\n\n\n@dataclass\nclass ModelArgs(BaseModelArgs):\n    model_type: str\n    hidden_size: int\n    num_hidden_layers: int\n    intermediate_size: int\n    num_attention_heads: int\n    rms_norm_eps: float\n    vocab_size: int\n    num_key_value_heads: int = None\n    rope_theta: float = 10000\n    rope_traditional: bool = False\n    rope_scaling: Optional[Dict[str, Union[float, str]]] = None\n\n    def __post_init__(self):\n        if self.num_key_value_heads is None:\n            self.num_key_value_heads = self.num_attention_heads\n\n        if self.rope_scaling:\n            required_keys = {\"factor\", \"type\"}\n            if not all(key in self.rope_scaling for key in required_keys):\n                raise ValueError(f\"rope_scaling must contain keys {required_keys}\")\n\n            if self.rope_scaling[\"type\"] != \"linear\":\n                raise ValueError(\"rope_scaling 'type' currently only supports 'linear'\")\n\n\nclass Attention(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n\n        dim = args.hidden_size\n        self.n_heads = n_heads = args.num_attention_heads\n        self.n_kv_heads = n_kv_heads = args.num_key_value_heads\n\n        self.repeats = n_heads // n_kv_heads\n\n        head_dim = args.hidden_size // n_heads\n        self.scale = head_dim**-0.5\n\n        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)\n        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)\n        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)\n        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)\n\n        rope_scale = (\n            1 / args.rope_scaling[\"factor\"]\n            if args.rope_scaling is not None and args.rope_scaling[\"type\"] == \"linear\"\n            else 1\n        )\n        self.rope = nn.RoPE(\n            head_dim,\n            traditional=args.rope_traditional,\n            base=args.rope_theta,\n            scale=rope_scale,\n        )\n\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Tuple[mx.array, mx.array]] = None,\n    ) -> mx.array:\n        B, L, D = x.shape\n\n        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)\n\n        # Prepare the queries, keys and values for the attention computation\n        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)\n        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)\n        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)\n\n        if self.repeats > 1:\n            keys = mx.repeat(keys, self.repeats, axis=1)\n            values = mx.repeat(values, self.repeats, axis=1)\n\n        if cache is not None:\n            key_cache, value_cache = cache\n            queries = self.rope(queries, offset=key_cache.shape[2])\n            keys = self.rope(keys, offset=key_cache.shape[2])\n            keys = mx.concatenate([key_cache, keys], axis=2)\n            values = mx.concatenate([value_cache, values], axis=2)\n        else:\n            queries = self.rope(queries)\n            keys = self.rope(keys)\n\n        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)\n        if mask is not None:\n            scores += mask\n        scores = 
mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)\n        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)\n        return self.o_proj(output), (keys, values)\n\n\nclass MLP(nn.Module):\n    def __init__(self, dim, hidden_dim):\n        super().__init__()\n        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)\n        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)\n        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)\n\n    def __call__(self, x) -> mx.array:\n        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.num_attention_heads = args.num_attention_heads\n        self.hidden_size = args.hidden_size\n        self.self_attn = Attention(args)\n        self.mlp = MLP(args.hidden_size, args.intermediate_size)\n        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)\n        self.args = args\n\n    def __call__(\n        self,\n        x: mx.array,\n        mask: Optional[mx.array] = None,\n        cache: Optional[Tuple[mx.array, mx.array]] = None,\n    ) -> mx.array:\n        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)\n        h = x + r\n        r = self.mlp(self.post_attention_layernorm(h))\n        out = h + r\n        return out, cache\n\n\nclass LlamaModel(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.args = args\n        self.vocab_size = args.vocab_size\n        self.num_hidden_layers = args.num_hidden_layers\n        assert self.vocab_size > 0\n        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)\n        self.layers = [\n            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)\n        ]\n        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)\n\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache=None,\n    ):\n        h = self.embed_tokens(inputs)\n\n        mask = None\n        if h.shape[1] > 1:\n            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])\n            mask = mask.astype(h.dtype)\n\n        if cache is None:\n            cache = [None] * len(self.layers)\n\n        for e, layer in enumerate(self.layers):\n            h, cache[e] = layer(h, mask, cache[e])\n\n        return self.norm(h), cache\n\n\nclass Model(nn.Module):\n    def __init__(self, args: ModelArgs):\n        super().__init__()\n        self.model_type = args.model_type\n        self.model = LlamaModel(args)\n        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)\n\n    def __call__(\n        self,\n        inputs: mx.array,\n        cache=None,\n    ):\n        out, cache = self.model(inputs, cache)\n        return self.lm_head(out), cache\n\n    @staticmethod\n    def sanitize(weights):\n        # Remove unused precomputed rotary freqs\n        return {\n            k: v for k, v in weights.items() if \"self_attn.rotary_emb.inv_freq\" not in k\n        }\n\n    @property\n    def layers(self):\n        return self.model.layers\n"
  },
  {
    "path": "server/py.typed",
    "content": "\n"
  },
  {
    "path": "server/requirements.txt",
    "content": "chromadb==0.4.23\nhuggingface_hub==0.20.3\nmlx==0.4.0\nmlx_data==0.0.2\ntransformers==4.38.1\npyinstaller==6.4.0\n"
  },
  {
    "path": "server/retriever/document.py",
    "content": "from typing import Any, Literal, Optional\n\n\nclass Document():\n    \"\"\"Class for storing a piece of text and associated metadata.\"\"\"\n\n    page_content: str\n    \"\"\"String text.\"\"\"\n    metadata: dict = dict()\n    \"\"\"Arbitrary metadata about the page content (e.g., source, relationships to other\n        documents, etc.).\n    \"\"\"\n    type: Literal[\"Document\"] = \"Document\"\n\n    def __init__(self, page_content: str, metadata: Optional[dict] = None, **kwargs: Any) -> None:\n        \"\"\"Pass page_content in as positional or named arg.\"\"\"\n        self.page_content = page_content\n        self.metadata = metadata or dict()\n\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n"
  },
  {
    "path": "server/retriever/embeddings.py",
    "content": "import os\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom transformers import PreTrainedTokenizer\nfrom abc import ABC, abstractmethod\nfrom typing import Any, List\n\nfrom ..utils import load, get_mlx_path, convert\n\n\nclass Embeddings(ABC):\n    \"\"\"Interface for embedding models.\"\"\"\n\n    @abstractmethod\n    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n        \"\"\"Embed search docs.\"\"\"\n\n    @abstractmethod\n    def embed_query(self, text: str) -> List[float]:\n        \"\"\"Embed query text.\"\"\"\n\n\nclass E5Embeddings(Embeddings):\n\n    model: Any = None\n    tokenizer: PreTrainedTokenizer = None\n\n    def __init__(self, hf_path: str = 'intfloat/multilingual-e5-small', quantize: bool = False):\n        mlx_path = get_mlx_path(hf_path, quantize=quantize)\n        if not os.path.isdir(mlx_path):\n            convert(hf_path, mlx_path, quantize=quantize)\n        self.model, self.tokenizer = load(mlx_path)\n\n    def _average_pool(self, last_hidden_states: mx.array,\n                      attention_mask: mx.array) -> mx.array:\n        last_hidden = mx.where(~attention_mask[..., None].astype(dtype=mx.bool_),\n                               0.0, last_hidden_states)\n        return mx.sum(last_hidden, axis=1) / mx.sum(attention_mask, axis=1, keepdims=True)\n\n    def embed_documents(self, texts: List[str], batch_size: int = 8) -> List[List[float]]:\n        embeddings = []\n        for i in range(0, len(texts), batch_size):\n            batch_texts = texts[i:i+batch_size]\n            batch_embeddings = self.embed_query(batch_texts, batch=True)\n            embeddings.extend(batch_embeddings)\n        return embeddings\n\n    def embed_query(self, texts: Any, batch: bool = False) -> List[Any]:\n        tokens = self.tokenizer(texts, max_length=512, padding=True,\n                                truncation=True, return_tensors='np',\n                                return_attention_mask=True)\n        tokens = {key: mx.array(v) for key, v in tokens.items()}\n        outputs = self.model(**tokens)\n        embeddings = self._average_pool(\n            outputs['last_hidden_state'], tokens['attention_mask'])\n        embeddings = embeddings / \\\n            mx.linalg.norm(embeddings, ord=2, axis=1, keepdims=True)\n\n        if batch:\n            return embeddings.tolist()  # -> List[List[float]]\n\n        return embeddings[0].tolist()  # -> List[float]\n\n\nclass ChatEmbeddings(Embeddings):\n\n    model: nn.Module = None\n    tokenizer: PreTrainedTokenizer = None\n\n    def __init__(self, model: nn.Module, tokenizer: PreTrainedTokenizer):\n        self.model = model\n        self.tokenizer = tokenizer\n\n    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n        return [self.embed_query(text) for text in texts]\n\n    def embed_query(self,  text: str) -> List[float]:\n        h = self.model.embed_tokens(mx.array(\n            self.tokenizer.encode(text, add_special_tokens=False)))\n        # normalized to have unit length\n        h = mx.mean(h, axis=0)\n        h = h / mx.linalg.norm(h)\n        return h.tolist()\n"
  },
  {
    "path": "server/retriever/loader.py",
    "content": "import os\nimport glob\nfrom typing import List, Optional\nfrom concurrent.futures import ThreadPoolExecutor\n\nfrom .document import Document\n\n\ndef directory_loader(directory: Optional[str] = None) -> Optional[List[Document]]:\n    if directory is not None and os.path.exists(directory):\n        allowed_extensions = ['.txt', '.md', '.csv', '.json', '.xml', '.ts']\n\n        def read_file(file_path):\n            _, file_extension = os.path.splitext(file_path)\n            if file_extension.lower() in allowed_extensions:\n                with open(file_path, 'r', encoding='utf-8') as file:\n                    return Document(page_content=file.read(), metadata={'source': file_path})\n\n        files = glob.glob(os.path.join(directory, '**', '*.*'), recursive=True)\n\n        with ThreadPoolExecutor() as executor:\n            return list(filter(None, executor.map(read_file, files)))\n    else:\n        raise FileNotFoundError(f\"Directory '{directory}' does not exist.\")\n"
  },
  {
    "path": "server/retriever/splitter.py",
    "content": "import re\nimport copy\n\nfrom abc import ABC, abstractmethod\nfrom typing import (\n    Any,\n    List,\n    Optional,\n    Callable,\n    Iterable\n)\n\nfrom .document import Document\n\n\ndef _split_text_with_regex(\n    text: str, separator: str, keep_separator: bool\n) -> List[str]:\n    # Now that we have the separator, split the text\n    if separator:\n        if keep_separator:\n            # The parentheses in the pattern keep the delimiters in the result.\n            _splits = re.split(f\"({separator})\", text)\n            splits = [_splits[i] + _splits[i + 1]\n                      for i in range(1, len(_splits), 2)]\n            if len(_splits) % 2 == 0:\n                splits += _splits[-1:]\n            splits = [_splits[0]] + splits\n        else:\n            splits = re.split(separator, text)\n    else:\n        splits = list(text)\n    return [s for s in splits if s != \"\"]\n\n\nclass TextSplitter(ABC):\n    \"\"\"Interface for splitting text into chunks.\"\"\"\n\n    def __init__(\n        self,\n        chunk_size: int = 4000,\n        chunk_overlap: int = 200,\n        length_function: Callable[[str], int] = len,\n        keep_separator: bool = False,\n        add_start_index: bool = False,\n        strip_whitespace: bool = True,\n    ) -> None:\n        \"\"\"Create a new TextSplitter.\n\n        Args:\n            chunk_size: Maximum size of chunks to return\n            chunk_overlap: Overlap in characters between chunks\n            length_function: Function that measures the length of given chunks\n            keep_separator: Whether to keep the separator in the chunks\n            add_start_index: If `True`, includes chunk's start index in metadata\n            strip_whitespace: If `True`, strips whitespace from the start and end of\n                              every document\n        \"\"\"\n        if chunk_overlap > chunk_size:\n            raise ValueError(f\"Got a larger chunk overlap ({\n                             chunk_overlap}) than chunk size ({chunk_size}), should be smaller.\")\n        self._chunk_size = chunk_size\n        self._chunk_overlap = chunk_overlap\n        self._length_function = length_function\n        self._keep_separator = keep_separator\n        self._add_start_index = add_start_index\n        self._strip_whitespace = strip_whitespace\n\n    @abstractmethod\n    def split_text(self, text: str) -> List[str]:\n        \"\"\"Split text into multiple components.\"\"\"\n\n    def create_documents(\n        self, texts: List[str], metadatas: Optional[List[dict]] = None\n    ) -> List[Document]:\n        \"\"\"Create documents from a list of texts.\"\"\"\n        _metadatas = metadatas or [{}] * len(texts)\n        documents = []\n        for i, text in enumerate(texts):\n            index = 0\n            previous_chunk_len = 0\n            for chunk in self.split_text(text):\n                metadata = copy.deepcopy(_metadatas[i])\n                if self._add_start_index:\n                    offset = index + previous_chunk_len - self._chunk_overlap\n                    index = text.find(chunk, max(0, offset))\n                    metadata[\"start_index\"] = index\n                    previous_chunk_len = len(chunk)\n                new_doc = Document(page_content=chunk, metadata=metadata)\n                documents.append(new_doc)\n        return documents\n\n    def split_documents(self, documents: Iterable[Document]) -> List[Document]:\n        \"\"\"Split documents.\"\"\"\n        texts, metadatas = [], []\n   
     for doc in documents:\n            texts.append(doc.page_content)\n            metadatas.append(doc.metadata)\n        return self.create_documents(texts, metadatas=metadatas)\n\n    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:\n        text = separator.join(docs)\n        if self._strip_whitespace:\n            text = text.strip()\n        if text == \"\":\n            return None\n        else:\n            return text\n\n    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:\n        # We now want to combine these smaller pieces into medium size\n        # chunks to send to the LLM.\n        separator_len = self._length_function(separator)\n\n        docs = []\n        current_doc: List[str] = []\n        total = 0\n        for d in splits:\n            _len = self._length_function(d)\n            if (\n                total + _len + (separator_len if len(current_doc) > 0 else 0)\n                > self._chunk_size\n            ):\n                if total > self._chunk_size:\n                    print(f\"Created a chunk of size {total}, \" +\n                          f\"which is longer than the specified {self._chunk_size}\")\n                if len(current_doc) > 0:\n                    doc = self._join_docs(current_doc, separator)\n                    if doc is not None:\n                        docs.append(doc)\n                    # Keep on popping if:\n                    # - we have a larger chunk than in the chunk overlap\n                    # - or if we still have any chunks and the length is long\n                    while total > self._chunk_overlap or (\n                        total + _len +\n                            (separator_len if len(current_doc) > 0 else 0)\n                        > self._chunk_size\n                        and total > 0\n                    ):\n                        total -= self._length_function(current_doc[0]) + (\n                            separator_len if len(current_doc) > 1 else 0\n                        )\n                        current_doc = current_doc[1:]\n            current_doc.append(d)\n            total += _len + (separator_len if len(current_doc) > 1 else 0)\n        doc = self._join_docs(current_doc, separator)\n        if doc is not None:\n            docs.append(doc)\n        return docs\n\n\nclass RecursiveCharacterTextSplitter(TextSplitter):\n    \"\"\"Splitting text by recursively looking at characters.\n\n    Recursively tries to split by different characters to find one\n    that works.\n    \"\"\"\n\n    def __init__(\n        self,\n        separators: Optional[List[str]] = None,\n        keep_separator: bool = True,\n        is_separator_regex: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"Create a new TextSplitter.\"\"\"\n        super().__init__(keep_separator=keep_separator, **kwargs)\n        self._separators = separators or [\"\\n\\n\", \"\\n\", \" \", \"\"]\n        self._is_separator_regex = is_separator_regex\n\n    def _split_text(self, text: str, separators: List[str]) -> List[str]:\n        \"\"\"Split incoming text and return chunks.\"\"\"\n        final_chunks = []\n        # Get appropriate separator to use\n        separator = separators[-1]\n        new_separators = []\n        for i, _s in enumerate(separators):\n            _separator = _s if self._is_separator_regex else re.escape(_s)\n            if _s == \"\":\n                separator = _s\n                break\n            if re.search(_separator, text):\n      
          separator = _s\n                new_separators = separators[i + 1:]\n                break\n\n        _separator = separator if self._is_separator_regex else re.escape(\n            separator)\n        splits = _split_text_with_regex(text, _separator, self._keep_separator)\n\n        # Now go merging things, recursively splitting longer texts.\n        _good_splits = []\n        _separator = \"\" if self._keep_separator else separator\n        for s in splits:\n            if self._length_function(s) < self._chunk_size:\n                _good_splits.append(s)\n            else:\n                if _good_splits:\n                    merged_text = self._merge_splits(_good_splits, _separator)\n                    final_chunks.extend(merged_text)\n                    _good_splits = []\n                if not new_separators:\n                    final_chunks.append(s)\n                else:\n                    other_info = self._split_text(s, new_separators)\n                    final_chunks.extend(other_info)\n        if _good_splits:\n            merged_text = self._merge_splits(_good_splits, _separator)\n            final_chunks.extend(merged_text)\n        return final_chunks\n\n    def split_text(self, text: str) -> List[str]:\n        return self._split_text(text, self._separators)\n"
  },
  {
    "path": "server/retriever/vectorstore.py",
    "content": "import uuid\nimport functools\nimport mlx.core as mx\n\nimport chromadb\nimport chromadb.config\n\nfrom chromadb.utils.batch_utils import create_batches\nfrom chromadb.api.types import ID, OneOrMany, Where, WhereDocument\nfrom typing import (\n    Any,\n    List,\n    Dict,\n    TypeVar,\n    Callable,\n    Iterable,\n    Optional,\n    Tuple,\n    Type,\n)\nfrom .document import Document\nfrom .embeddings import Embeddings\n\nChroma = TypeVar('Chroma', bound='Chroma')\n\n\nDEFAULT_K = 4  # Number of Documents to return.\n\n\ndef _results_to_docs(results: Any) -> List[Document]:\n    return [doc for doc, _ in _results_to_docs_and_scores(results)]\n\n\ndef _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:\n    return [\n        # TODO: Chroma can do batch querying,\n        # we shouldn't hard code to the 1st result\n        (Document(page_content=result[0], metadata=result[1] or {}), result[2])\n        for result in zip(\n            results[\"documents\"][0],\n            results[\"metadatas\"][0],\n            results[\"distances\"][0],\n        )\n    ]\n\n\ndef xor_args(*arg_groups: Tuple[str, ...]) -> Callable:\n    \"\"\"Validate specified keyword args are mutually exclusive.\"\"\"\n\n    def decorator(func: Callable) -> Callable:\n        @functools.wraps(func)\n        def wrapper(*args: Any, **kwargs: Any) -> Any:\n            \"\"\"Validate exactly one arg in each group is not None.\"\"\"\n            counts = [\n                sum(1 for arg in arg_group if kwargs.get(arg) is not None)\n                for arg_group in arg_groups\n            ]\n            invalid_groups = [\n                i for i, count in enumerate(counts) if count != 1]\n            if invalid_groups:\n                invalid_group_names = [\n                    \", \".join(arg_groups[i]) for i in invalid_groups]\n                raise ValueError(\n                    \"Exactly one argument in each of the following\"\n                    \" groups must be defined:\"\n                    f\" {', '.join(invalid_group_names)}\"\n                )\n            return func(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n\n\ndef cosine_similarity(\n    X: mx.array, T: mx.array, axis: int = 1\n) -> mx.array:\n    \"\"\"Row-wise cosine similarity between two equal-width matrices.\"\"\"\n    X, T = mx.array(X), mx.array(T)\n    X_norm = mx.linalg.norm(X, axis=axis)\n    T_norm = mx.linalg.norm(T, axis=axis)\n    similarity = X @ T.T / mx.outer(X_norm, T_norm)\n    return similarity\n\n\ndef maximal_marginal_relevance(\n    query_embedding: mx.array,\n    embedding_list: mx.array,\n    lambda_mult: float = 0.5,\n    k: int = 4,\n) -> List[int]:\n    \"\"\"Calculate maximal marginal relevance.\"\"\"\n    if min(k, len(embedding_list)) <= 0:\n        return []\n    if query_embedding.ndim == 1:\n        query_embedding = mx.expand_dims(query_embedding, axis=0)\n    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]\n    most_similar = int(mx.argmax(similarity_to_query).tolist())\n    idxs = [most_similar]\n    selected = mx.array([embedding_list[most_similar]])\n    while len(idxs) < min(k, len(embedding_list)):\n        best_score = -mx.inf\n        idx_to_add = -1\n        similarity_to_selected = cosine_similarity(embedding_list, selected)\n        for i, query_score in enumerate(similarity_to_query):\n            if i in idxs:\n                continue\n            redundant_score = max(similarity_to_selected[i])\n            
equation_score = (\n                lambda_mult * query_score - (1 - lambda_mult) * redundant_score\n            )\n            if equation_score > best_score:\n                best_score = equation_score\n                idx_to_add = i\n        idxs.append(idx_to_add)\n        selected = mx.concatenate([\n            selected, embedding_list[idx_to_add:idx_to_add+1]], axis=0)\n    return idxs\n\n\nclass Chroma():\n    \"\"\"\n    similarity_search\n    max_marginal_relevance_search\n    \"\"\"\n    _DEFAULT_COLLECTION_NAME = \"mlx-chat-app\"\n\n    def __init__(\n        self,\n        collection_name: str = _DEFAULT_COLLECTION_NAME,\n        embedding_function: Optional[Embeddings] = None,\n        persist_directory: Optional[str] = None,\n        client_settings: Optional[chromadb.config.Settings] = None,\n        collection_metadata: Optional[Dict] = None,\n        client: Optional[chromadb.Client] = None,\n        relevance_score_fn: Optional[Callable[[float], float]] = None,\n    ) -> None:\n\n        if client is not None:\n            self._client_settings = client_settings\n            self._client = client\n            self._persist_directory = persist_directory\n        else:\n            if client_settings:\n                # If client_settings is provided with persist_directory specified,\n                # then it is \"in-memory and persisting to disk\" mode.\n                client_settings.persist_directory = (\n                    persist_directory or client_settings.persist_directory\n                )\n                if client_settings.persist_directory is not None:\n                    # Maintain backwards compatibility with chromadb < 0.4.0\n                    major, minor, _ = chromadb.__version__.split(\".\")\n                    if int(major) == 0 and int(minor) < 4:\n                        client_settings.chroma_db_impl = \"duckdb+parquet\"\n\n                _client_settings = client_settings\n            elif persist_directory:\n                # Maintain backwards compatibility with chromadb < 0.4.0\n                major, minor, _ = chromadb.__version__.split(\".\")\n                if int(major) == 0 and int(minor) < 4:\n                    _client_settings = chromadb.config.Settings(\n                        chroma_db_impl=\"duckdb+parquet\",\n                    )\n                else:\n                    _client_settings = chromadb.config.Settings(\n                        is_persistent=True)\n                _client_settings.persist_directory = persist_directory\n            else:\n                _client_settings = chromadb.config.Settings()\n            self._client_settings = _client_settings\n            self._client = chromadb.Client(_client_settings)\n            self._persist_directory = (\n                _client_settings.persist_directory or persist_directory\n            )\n\n        self._embedding_function = embedding_function\n        self._collection = self._client.get_or_create_collection(\n            name=collection_name,\n            embedding_function=None,\n            metadata=collection_metadata,\n        )\n        self.override_relevance_score_fn = relevance_score_fn\n\n    @property\n    def embeddings(self) -> Optional[Embeddings]:\n        return self._embedding_function\n\n    @xor_args((\"query_texts\", \"query_embeddings\"))\n    def __query_collection(\n        self,\n        query_texts: Optional[List[str]] = None,\n        query_embeddings: Optional[List[List[float]]] = None,\n        n_results: int = 4,\n        
where: Optional[Dict[str, str]] = None,\n        where_document: Optional[Dict[str, str]] = None,\n        **kwargs: Any,\n    ) -> Dict[str, Any]:\n        \"\"\"Query the chroma collection.\"\"\"\n        return self._collection.query(\n            query_texts=query_texts,\n            query_embeddings=query_embeddings,\n            n_results=n_results,\n            where=where,\n            where_document=where_document,\n            **kwargs,\n        )\n\n    def add_texts(\n        self,\n        texts: Iterable[str],\n        metadatas: Optional[List[dict]] = None,\n        ids: Optional[List[str]] = None,\n        **kwargs: Any,\n    ) -> List[str]:\n        \"\"\"Run more texts through the embeddings and add to the vectorstore.\n\n        Args:\n            texts (Iterable[str]): Texts to add to the vectorstore.\n            metadatas (Optional[List[dict]], optional): Optional list of metadatas.\n            ids (Optional[List[str]], optional): Optional list of IDs.\n\n        Returns:\n            List[str]: List of IDs of the added texts.\n        \"\"\"\n        # TODO: Handle the case where the user doesn't provide ids on the Collection\n        # materialize texts first so a generator input isn't exhausted by id generation\n        texts = list(texts)\n        if ids is None:\n            ids = [str(uuid.uuid1()) for _ in texts]\n        embeddings = None\n        if self._embedding_function is not None:\n            embeddings = self._embedding_function.embed_documents(texts)\n        if metadatas:\n            # fill metadatas with empty dicts if somebody\n            # did not specify metadata for all texts\n            length_diff = len(texts) - len(metadatas)\n            if length_diff:\n                metadatas = metadatas + [{}] * length_diff\n            empty_ids = []\n            non_empty_ids = []\n            for idx, m in enumerate(metadatas):\n                if m:\n                    non_empty_ids.append(idx)\n                else:\n                    empty_ids.append(idx)\n            if non_empty_ids:\n                metadatas = [metadatas[idx] for idx in non_empty_ids]\n                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]\n                embeddings_with_metadatas = (\n                    [embeddings[idx]\n                        for idx in non_empty_ids] if embeddings else None\n                )\n                ids_with_metadata = [ids[idx] for idx in non_empty_ids]\n                try:\n                    self._collection.upsert(\n                        metadatas=metadatas,\n                        embeddings=embeddings_with_metadatas,\n                        documents=texts_with_metadatas,\n                        ids=ids_with_metadata,\n                    )\n                except ValueError as e:\n                    if \"Expected metadata value to be\" in str(e):\n                        msg = (\n                            \"Try filtering complex metadata from the document using \"\n                            \"langchain_community.vectorstores.utils.filter_complex_metadata.\"\n                        )\n                        raise ValueError(e.args[0] + \"\\n\\n\" + msg)\n                    else:\n                        raise e\n            if empty_ids:\n                texts_without_metadatas = [texts[j] for j in empty_ids]\n                embeddings_without_metadatas = (\n                    [embeddings[j] for j in empty_ids] if embeddings else None\n                )\n                ids_without_metadatas = [ids[j] for j in empty_ids]\n                self._collection.upsert(\n                 
   embeddings=embeddings_without_metadatas,\n                    documents=texts_without_metadatas,\n                    ids=ids_without_metadatas,\n                )\n        else:\n            self._collection.upsert(\n                embeddings=embeddings,\n                documents=texts,\n                ids=ids,\n            )\n        return ids\n\n    def similarity_search(\n        self,\n        query: str,\n        k: int = DEFAULT_K,\n        filter: Optional[Dict[str, str]] = None,\n        **kwargs: Any,\n    ) -> List[Document]:\n        \"\"\"Run similarity search with Chroma.\n\n        Args:\n            query (str): Query text to search for.\n            k (int): Number of results to return. Defaults to 4.\n            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.\n\n        Returns:\n            List[Document]: List of documents most similar to the query text.\n        \"\"\"\n        docs_and_scores = self.similarity_search_with_score(\n            query, k, filter=filter, **kwargs\n        )\n        return [doc for doc, _ in docs_and_scores]\n\n    def similarity_search_with_score(\n        self,\n        query: str,\n        k: int = DEFAULT_K,\n        filter: Optional[Dict[str, str]] = None,\n        where_document: Optional[Dict[str, str]] = None,\n        **kwargs: Any,\n    ) -> List[Tuple[Document, float]]:\n        \"\"\"Run similarity search with Chroma with distance.\n\n        Args:\n            query (str): Query text to search for.\n            k (int): Number of results to return. Defaults to 4.\n            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.\n\n        Returns:\n            List[Tuple[Document, float]]: List of documents most similar to\n            the query text and cosine distance in float for each.\n            Lower score represents more similarity.\n        \"\"\"\n        if self._embedding_function is None:\n            results = self.__query_collection(\n                query_texts=[query],\n                n_results=k,\n                where=filter,\n                where_document=where_document,\n                **kwargs,\n            )\n        else:\n            query_embedding = self._embedding_function.embed_query(query)\n            results = self.__query_collection(\n                query_embeddings=[query_embedding],\n                n_results=k,\n                where=filter,\n                where_document=where_document,\n                **kwargs,\n            )\n\n        return _results_to_docs_and_scores(results)\n\n    def max_marginal_relevance_search_by_vector(\n        self,\n        embedding: List[float],\n        k: int = DEFAULT_K,\n        fetch_k: int = 20,\n        lambda_mult: float = 0.5,\n        filter: Optional[Dict[str, str]] = None,\n        where_document: Optional[Dict[str, str]] = None,\n        **kwargs: Any,\n    ) -> List[Document]:\n        \"\"\"Return docs selected using the maximal marginal relevance.\n        Maximal marginal relevance optimizes for similarity to query AND diversity\n        among selected documents.\n\n        Args:\n            embedding: Embedding to look up documents similar to.\n            k: Number of Documents to return. 
Defaults to 4.\n            fetch_k: Number of Documents to fetch to pass to MMR algorithm.\n            lambda_mult: Number between 0 and 1 that determines the degree\n                        of diversity among the results with 0 corresponding\n                        to maximum diversity and 1 to minimum diversity.\n                        Defaults to 0.5.\n            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.\n\n        Returns:\n            List of Documents selected by maximal marginal relevance.\n        \"\"\"\n\n        results = self.__query_collection(\n            query_embeddings=embedding,\n            n_results=fetch_k,\n            where=filter,\n            where_document=where_document,\n            include=[\"metadatas\", \"documents\", \"distances\", \"embeddings\"],\n            **kwargs,\n        )\n        mmr_selected = maximal_marginal_relevance(\n            mx.array(embedding, dtype=mx.float32),\n            mx.array(results[\"embeddings\"][0]),\n            k=k,\n            lambda_mult=lambda_mult,\n        )\n\n        candidates = _results_to_docs(results)\n\n        selected_results = [r for i, r in enumerate(\n            candidates) if i in mmr_selected]\n        return selected_results\n\n    def max_marginal_relevance_search(\n        self,\n        query: str,\n        k: int = DEFAULT_K,\n        fetch_k: int = 20,\n        lambda_mult: float = 0.5,\n        filter: Optional[Dict[str, str]] = None,\n        where_document: Optional[Dict[str, str]] = None,\n        **kwargs: Any,\n    ) -> List[Document]:\n        \"\"\"Return docs selected using the maximal marginal relevance.\n        Maximal marginal relevance optimizes for similarity to query AND diversity\n        among selected documents.\n\n        Args:\n            query: Text to look up documents similar to.\n            k: Number of Documents to return. Defaults to 4.\n            fetch_k: Number of Documents to fetch to pass to MMR algorithm.\n            lambda_mult: Number between 0 and 1 that determines the degree\n                        of diversity among the results with 0 corresponding\n                        to maximum diversity and 1 to minimum diversity.\n                        Defaults to 0.5.\n            filter (Optional[Dict[str, str]]): Filter by metadata. 
Defaults to None.\n\n        Returns:\n            List of Documents selected by maximal marginal relevance.\n        \"\"\"\n        if self._embedding_function is None:\n            raise ValueError(\n                \"For MMR search, you must specify an embedding function on creation.\"\n            )\n\n        embedding = self._embedding_function.embed_query(query)\n        docs = self.max_marginal_relevance_search_by_vector(\n            embedding,\n            k,\n            fetch_k,\n            lambda_mult=lambda_mult,\n            filter=filter,\n            where_document=where_document,\n        )\n        return docs\n\n    def delete_collection(self) -> None:\n        \"\"\"Delete the collection.\"\"\"\n        self._client.delete_collection(self._collection.name)\n\n    def get(\n        self,\n        ids: Optional[OneOrMany[ID]] = None,\n        where: Optional[Where] = None,\n        limit: Optional[int] = None,\n        offset: Optional[int] = None,\n        where_document: Optional[WhereDocument] = None,\n        include: Optional[List[str]] = None,\n    ) -> Dict[str, Any]:\n        \"\"\"Gets items from the collection.\n\n        Args:\n            ids: The ids of the embeddings to get. Optional.\n            where: A Where type dict used to filter results by.\n                   E.g. `{\"color\": \"red\", \"price\": 4.20}`. Optional.\n            limit: The number of documents to return. Optional.\n            offset: The offset to start returning results from.\n                    Useful for paging results with limit. Optional.\n            where_document: A WhereDocument type dict used to filter by the documents.\n                            E.g. `{\"$contains\": \"hello\"}`. Optional.\n            include: A list of what to include in the results.\n                     Can contain `\"embeddings\"`, `\"metadatas\"`, `\"documents\"`.\n                     Ids are always included.\n                     Defaults to `[\"metadatas\", \"documents\"]`. 
Optional.\n        \"\"\"\n        kwargs = {\n            \"ids\": ids,\n            \"where\": where,\n            \"limit\": limit,\n            \"offset\": offset,\n            \"where_document\": where_document,\n        }\n\n        if include is not None:\n            kwargs[\"include\"] = include\n\n        return self._collection.get(**kwargs)\n\n    def update_document(self, document_id: str, document: Document) -> None:\n        \"\"\"Update a document in the collection.\n\n        Args:\n            document_id (str): ID of the document to update.\n            document (Document): Document to update.\n        \"\"\"\n        return self.update_documents([document_id], [document])\n\n    def update_documents(self, ids: List[str], documents: List[Document]) -> None:\n        \"\"\"Update multiple documents in the collection.\n\n        Args:\n            ids (List[str]): List of ids of the documents to update.\n            documents (List[Document]): List of documents to update.\n        \"\"\"\n        text = [document.page_content for document in documents]\n        metadata = [document.metadata for document in documents]\n        if self._embedding_function is None:\n            raise ValueError(\n                \"For update, you must specify an embedding function on creation.\"\n            )\n        embeddings = self._embedding_function.embed_documents(text)\n\n        # respect the client's maximum batch size when it advertises one\n        if hasattr(self._collection._client, \"max_batch_size\"):\n            for batch in create_batches(\n                api=self._collection._client,\n                ids=ids,\n                metadatas=metadata,\n                documents=text,\n                embeddings=embeddings,\n            ):\n                self._collection.update(\n                    ids=batch[0],\n                    embeddings=batch[1],\n                    documents=batch[3],\n                    metadatas=batch[2],\n                )\n        else:\n            self._collection.update(\n                ids=ids,\n                embeddings=embeddings,\n                documents=text,\n                metadatas=metadata,\n            )\n\n    @classmethod\n    def from_texts(\n        cls: Type[Chroma],\n        texts: List[str],\n        embedding: Optional[Embeddings] = None,\n        metadatas: Optional[List[dict]] = None,\n        ids: Optional[List[str]] = None,\n        collection_name: str = _DEFAULT_COLLECTION_NAME,\n        persist_directory: Optional[str] = None,\n        client_settings: Optional[chromadb.config.Settings] = None,\n        client: Optional[chromadb.Client] = None,\n        collection_metadata: Optional[Dict] = None,\n        **kwargs: Any,\n    ) -> Chroma:\n        \"\"\"Create a Chroma vectorstore from a list of raw texts.\n\n        If a persist_directory is specified, the collection will be persisted there.\n        Otherwise, the data will be ephemeral in-memory.\n\n        Args:\n            texts (List[str]): List of texts to add to the collection.\n            collection_name (str): Name of the collection to create.\n            persist_directory (Optional[str]): Directory to persist the collection.\n            embedding (Optional[Embeddings]): Embedding function. Defaults to None.\n            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.\n            ids (Optional[List[str]]): List of document IDs. 
Defaults to None.\n            client_settings (Optional[chromadb.config.Settings]): Chroma client settings\n            collection_metadata (Optional[Dict]): Collection configurations.\n                                                  Defaults to None.\n\n        Returns:\n            Chroma: Chroma vectorstore.\n        \"\"\"\n        chroma_collection = cls(\n            collection_name=collection_name,\n            embedding_function=embedding,\n            persist_directory=persist_directory,\n            client_settings=client_settings,\n            client=client,\n            collection_metadata=collection_metadata,\n            **kwargs,\n        )\n        if ids is None:\n            ids = [str(uuid.uuid1()) for _ in texts]\n        # respect the client's maximum batch size when it advertises one\n        if hasattr(chroma_collection._client, \"max_batch_size\"):\n            for batch in create_batches(\n                api=chroma_collection._client,\n                ids=ids,\n                metadatas=metadatas,\n                documents=texts,\n            ):\n                chroma_collection.add_texts(\n                    texts=batch[3] if batch[3] else [],\n                    metadatas=batch[2] if batch[2] else None,\n                    ids=batch[0],\n                )\n        else:\n            chroma_collection.add_texts(\n                texts=texts, metadatas=metadatas, ids=ids)\n        return chroma_collection\n\n    @classmethod\n    def from_documents(\n        cls: Type[Chroma],\n        documents: List[Document],\n        embedding: Optional[Embeddings] = None,\n        ids: Optional[List[str]] = None,\n        collection_name: str = _DEFAULT_COLLECTION_NAME,\n        persist_directory: Optional[str] = None,\n        client_settings: Optional[chromadb.config.Settings] = None,\n        client: Optional[chromadb.Client] = None,\n        collection_metadata: Optional[Dict] = None,\n        **kwargs: Any,\n    ) -> Chroma:\n        \"\"\"Create a Chroma vectorstore from a list of documents.\n\n        If a persist_directory is specified, the collection will be persisted there.\n        Otherwise, the data will be ephemeral in-memory.\n\n        Args:\n            collection_name (str): Name of the collection to create.\n            persist_directory (Optional[str]): Directory to persist the collection.\n            ids (Optional[List[str]]): List of document IDs. Defaults to None.\n            documents (List[Document]): List of documents to add to the vectorstore.\n            embedding (Optional[Embeddings]): Embedding function. 
Defaults to None.\n            client_settings (Optional[chromadb.config.Settings]): Chroma client settings\n            collection_metadata (Optional[Dict]): Collection configurations.\n                                                  Defaults to None.\n\n        Returns:\n            Chroma: Chroma vectorstore.\n        \"\"\"\n        texts = [doc.page_content for doc in documents]\n        metadatas = [doc.metadata for doc in documents]\n        return cls.from_texts(\n            texts=texts,\n            embedding=embedding,\n            metadatas=metadatas,\n            ids=ids,\n            collection_name=collection_name,\n            persist_directory=persist_directory,\n            client_settings=client_settings,\n            client=client,\n            collection_metadata=collection_metadata,\n            **kwargs,\n        )\n\n    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:\n        \"\"\"Delete by vector IDs.\n\n        Args:\n            ids: List of ids to delete.\n        \"\"\"\n        self._collection.delete(ids=ids)\n"
  },
  {
    "path": "server/server.py",
    "content": "import os\nimport sys\nimport json\nimport time\nimport uuid\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nfrom typing import List, Dict, Optional\nfrom transformers import PreTrainedTokenizer\n\nfrom .utils import load, generate_step, get_mlx_path, convert\n\nfrom .retriever.loader import directory_loader\nfrom .retriever.splitter import RecursiveCharacterTextSplitter\nfrom .retriever.vectorstore import Chroma\nfrom .retriever.embeddings import ChatEmbeddings, E5Embeddings\n\n_model: Optional[nn.Module] = None\n_tokenizer: Optional[PreTrainedTokenizer] = None\n_database: Optional[Chroma] = None\n\n\ndef load_model(model_path: str, adapter_file: Optional[str] = None):\n    global _model\n    global _tokenizer\n\n    models_to_quantize = ['mistral', 'llama', 'gemma']\n    quantize = any(variable in model_path for variable in models_to_quantize)\n\n    mlx_path = get_mlx_path(model_path, quantize=quantize)\n    if not os.path.isdir(mlx_path):\n        convert(model_path, mlx_path, quantize=quantize)\n\n    _model, _tokenizer = load(mlx_path, adapter_file=adapter_file)\n\n\ndef index_directory(directory: str, use_embedding: bool = True):\n    global _database\n    start_t = time.time()\n    raw_docs = directory_loader(directory)\n    text_splitter = RecursiveCharacterTextSplitter(\n        chunk_size=512, chunk_overlap=32, add_start_index=True\n    )\n    embedding = E5Embeddings(quantize=True) if use_embedding else ChatEmbeddings(\n        model=_model.model, tokenizer=_tokenizer)\n    splits = text_splitter.split_documents(raw_docs)\n    _database = Chroma.from_documents(\n        documents=splits,\n        embedding=embedding\n    )\n    print(f'>> indexed {len(splits)} documents in',\n          f'{time.time() - start_t:.2f}s', flush=True)\n\n\ndef create_response(chat_id, prompt, tokens, text):\n    response = {\n        'id': chat_id,\n        'object': 'chat.completion',\n        'created': int(time.time()),\n        'model':  _model.model_type,\n        'system_fingerprint': f'fp_{uuid.uuid4()}',\n        'choices': [\n            {\n                'index': 0,\n                'message': {\n                    'role': 'assistant',\n                    'content': text,\n                },\n                'logprobs': None,\n                'finish_reason': None,\n            }\n        ],\n        'usage': {\n            'prompt_tokens': len(prompt),\n            'completion_tokens': len(tokens),\n            'total_tokens': len(prompt) + len(tokens),\n        },\n    }\n    return response\n\n\ndef format_messages(messages: List[Dict], indexed_files: Optional[str], instructions: Optional[Dict]):\n    personalization = instructions.get(\n        'personalization', '').strip().replace('\\n', '; ')\n    response = instructions.get('response', '').strip().replace('\\n', '; ')\n\n    context = f\"with background knowledge of {\n        indexed_files.strip().replace('\\n', '; ')}\" if indexed_files else ''\n    audience = personalization if personalization else 'general'\n    style = response if response else 'technical, accurate, and professional'\n\n    messages[-1]['content'] = f\"\"\"\n<Context>\n  you are my personalized AI chatbot {context}\n</Context>\n<Objective>\n  respond to the following: {messages[-1]['content']}\n</Objective>\n<Style>\n  {style}\n</Style>\n<Tone>\n  friendly, helpful, and confident\n</Tone>\n<Audience>\n  {audience}\n</Audience>\n<Response>\n  brief, concise, and to the point. 
Please don't start with \"Sure, ...\"\n</Response>\n\"\"\".strip()\n\n\nclass APIHandler(BaseHTTPRequestHandler):\n    def _set_headers(self, status_code=200):\n        self.send_response(status_code)\n        self.send_header('Content-type', 'application/json')\n        self.send_header('Access-Control-Allow-Origin', '*')\n        self.send_header('Access-Control-Allow-Methods', '*')\n        self.send_header('Access-Control-Allow-Headers', '*')\n        self.end_headers()\n\n    def do_OPTIONS(self):\n        self._set_headers(204)\n\n    def do_POST(self):\n        \"\"\"\n        Endpoint: /api/index\n            Desc: indexes the directory\n            Body:\n                {\n                    directory: str\n                }\n\n        Endpoint: /api/init\n            Desc: initializes the model\n            Body:\n                {\n                    model: str\n                }\n\n        Endpoint: /api/query\n            Desc: handles messages requests (with directory index)\n            Body:\n                {\n                    messages: [ { role: str, content: str } ],\n                    max_tokens: int,\n                    repetition_penalty: float,\n                    repetition_context_size: int,\n                    temperature: float,\n                    top_p: float,\n                    instructions: {\n                        personalization: str,\n                        response: str\n                    },\n                    directory: str\n                }\n        \"\"\"\n        try:\n            post_data = self.rfile.read(int(self.headers['Content-Length']))\n            body = json.loads(post_data.decode('utf-8'))\n            method = {\n                '/api/index': self.index,\n                '/api/query': self.query,\n                '/api/init': self.init,\n            }\n            handle = method.get(self.path, None)\n            if handle is None:\n                self._set_headers(404)\n                self.wfile.write(b'Not Found')\n                return\n\n            response = handle(body)\n            self._set_headers(200)\n            self.wfile.write(json.dumps(response).encode('utf-8'))\n\n        except Exception as e:\n            print(f\"Error: {e}\", flush=True)\n            self._set_headers(500)\n            self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8'))\n\n    def index(self, body):\n        directory = body.get('directory', None)\n        index_directory(directory)\n        return {'directory': directory}\n\n    def init(self, body):\n        model = body.get('model', None)\n        load_model(model)\n        return {'model': model}\n\n    def query(self, body):\n        chat_id = f'chatcmpl-{uuid.uuid4()}'\n\n        directory = body.get('directory', None)\n        messages = body.get('messages', [])\n        instructions = body.get('instructions', None)\n\n        indexed_files = ''\n        if directory:\n            # emperically better than `similarity_search`\n            docs = _database.max_marginal_relevance_search(\n                messages[-1]['content'],\n                k=6  # number of documents to return\n            )\n            indexed_files = '\\n'.join([doc.page_content for doc in docs])\n\n            print(body, flush=True)\n            print(('\\n'+'--'*10+'\\n').join([\n                f'{doc.metadata}\\n{doc.page_content}' for doc in docs]), flush=True)\n\n        format_messages(messages, indexed_files, instructions)\n        print(messages, flush=True)\n\n        
prompt = mx.array(_tokenizer.encode(_tokenizer.apply_chat_template(\n            messages,\n            tokenize=False,\n            add_generation_prompt=True,\n        ), add_special_tokens=True))\n\n        max_tokens = body.get('max_tokens', 100)\n        repetition_penalty = body.get('repetition_penalty', None)\n        repetition_context_size = body.get('repetition_context_size', 20)\n        temperature = body.get('temperature', 1.0)\n        top_p = body.get('top_p', 1.0)\n\n        tokens = []\n        REPLACEMENT_CHAR = '\\ufffd'\n        for (token, prob), _ in zip(\n            generate_step(\n                prompt,\n                _model,\n                temperature,\n                repetition_penalty,\n                repetition_context_size,\n                top_p,\n            ),\n            range(max_tokens),\n        ):\n            if token == _tokenizer.eos_token_id:\n                break\n            tokens.append(token.item())\n\n        text = _tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, '')\n        # TODO: GEMMA IS OBSESSED WITH \"Sure, ...\"\n        if text.startswith('Sure, '):\n            text = text.split('\\n')\n            text[0] = text[0].removeprefix('Sure, ').capitalize()\n            text = '\\n'.join(text)\n        return create_response(chat_id, prompt, tokens, text)\n\n\ndef run(host: str, port: int, server_class=HTTPServer, handler_class=APIHandler):\n    server_address = (host, port)\n    httpd = server_class(server_address, handler_class)\n    print(f'Starting httpd at {host} on port {port}...', flush=True)\n    httpd.serve_forever()\n\n\ndef main():\n    # both flags are optional; unset values fall back to these defaults\n    args = {\n        '--host': '127.0.0.1',\n        '--port': 8080\n    }\n\n    i = 1\n    while i < len(sys.argv):\n        if sys.argv[i] not in args or i + 1 >= len(sys.argv):\n            print(f\"Unknown or incomplete argument: {sys.argv[i]}\")\n            print(\n                \"Usage: python script.py [--host <host_address>] [--port <port_number>]\")\n            sys.exit(1)\n        args[sys.argv[i]] = sys.argv[i + 1]\n        i += 2\n\n    host = args['--host']\n    port = int(args['--port'])\n\n    print(f'>> starting server on {host}:{port}', flush=True)\n    run(host, port)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "server/utils.py",
    "content": "import os\nimport copy\nimport glob\nimport shutil\nimport importlib\nimport json\nimport logging\nimport time\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, Generator, Optional, Tuple, Union\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom mlx.utils import tree_flatten\n\nfrom huggingface_hub import snapshot_download\nfrom transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer\n\n# Constants\nMODEL_REMAPPING = {\n    \"mistral\": \"llama\",  # mistral is compatible with llama\n    \"phi-msft\": \"phixtral\",\n}\n\nMAX_FILE_SIZE_GB = 5\n\nlinear_class_predicate = (\n    lambda m: isinstance(m, nn.Linear)\n    and m.weight.shape[0]\n    != 8  # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models\n)\n\n\ndef _get_classes(config: dict):\n    \"\"\"\n    Retrieve the model and model args classes based on the configuration.\n\n    Args:\n        config (dict): The model configuration.\n\n    Returns:\n        A tuple containing the Model class and the ModelArgs class.\n    \"\"\"\n    model_type = config[\"model_type\"]\n    model_type = MODEL_REMAPPING.get(model_type, model_type)\n    try:\n        arch = importlib.import_module(f\"server.models.{model_type}\")\n    except ImportError:\n        msg = f\"Model type {model_type} not supported.\"\n        logging.error(msg)\n        raise ValueError(msg)\n\n    return arch.Model, arch.ModelArgs\n\n\ndef get_model_path(path_or_hf_repo: str) -> Path:\n    \"\"\"\n    Ensures the model is available locally. If the path does not exist locally,\n    it is downloaded from the Hugging Face Hub.\n\n    Args:\n        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.\n\n    Returns:\n        Path: The path to the model.\n    \"\"\"\n    model_path = Path(path_or_hf_repo)\n    if not model_path.exists():\n        model_path = Path(\n            snapshot_download(\n                repo_id=path_or_hf_repo,\n                allow_patterns=[\n                    \"*.json\",\n                    \"*.safetensors\",\n                    \"*.py\",\n                    \"tokenizer.model\",\n                    \"*.tiktoken\",\n                ],\n            )\n        )\n    return model_path\n\n\ndef apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):\n    \"\"\"\n    Apply repetition penalty to specific logits based on the given context.\n\n    Paper: https://arxiv.org/abs/1909.05858\n\n    Args:\n        logits (mx.array): The logits produced by the language model.\n        generated_tokens (any): A list of N previous tokens.\n        penalty (float): The repetition penalty factor to be applied.\n\n    Returns:\n        logits (mx.array): Logits with repetition penalty applied to generated tokens.\n    \"\"\"\n    if len(generated_tokens) > 0:\n        indices = mx.array([token for token in generated_tokens])\n        selected_logits = logits[:, indices]\n        selected_logits = mx.where(\n            selected_logits < 0, selected_logits * penalty, selected_logits / penalty\n        )\n        logits[:, indices] = selected_logits\n    return logits\n\n\ndef generate_step(\n    prompt: mx.array,\n    model: nn.Module,\n    temp: 0.0,\n    repetition_penalty: Optional[float] = None,\n    repetition_context_size: Optional[int] = 20,\n    top_p: float = 1.0,\n) -> Generator[Tuple[mx.array, mx.array], None, None]:\n    \"\"\"\n    A generator producing text based on the given prompt from the model.\n\n    
Args:\n        prompt (mx.array): The input prompt.\n        model (nn.Module): The model to use for generation.\n        temp (float): The temperature for sampling, if 0 the argmax is used.\n        repetition_penalty (float, optional): The penalty factor for repeating tokens.\n        repetition_context_size (int, optional): The number of tokens to consider for repetition penalty (default 20).\n        top_p (float): Nucleus sampling threshold; sampling is restricted to the\n            smallest set of tokens whose cumulative probability exceeds top_p\n            (default 1.0, i.e. disabled).\n\n    Yields:\n        Generator[Tuple[mx.array, mx.array]]: A generator producing\n        one token and probability per call.\n    \"\"\"\n\n    def sample(logits: mx.array) -> Tuple[mx.array, float]:\n        softmax_logits = mx.softmax(logits)\n\n        if temp == 0:\n            token = mx.argmax(logits, axis=-1)\n        else:\n            if 0 < top_p < 1.0:\n                if (\n                    logits.dtype == mx.bfloat16\n                ):  # workaround: the contiguous_scan_inclusive_sum_bfloat16_bfloat16 kernel fails to load\n                    logits = logits.astype(mx.float32)\n                probs = mx.softmax(logits / temp, axis=-1)\n\n                # probs has batch size 1, so [::-1] is a no-op and both arrays stay\n                # in ascending order; the cumulative-sum mask below then keeps the\n                # top-p nucleus (the largest probabilities)\n                sorted_probs = mx.sort(probs)[::-1]\n                sorted_indices = mx.argsort(probs)[::-1]\n                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)\n\n                top_probs = mx.where(\n                    cumulative_probs > 1 - top_p,\n                    sorted_probs,\n                    mx.zeros_like(sorted_probs),\n                )\n                sorted_token = mx.random.categorical(mx.log(top_probs))\n                token = sorted_indices.squeeze(0)[sorted_token]\n            else:\n                token = mx.random.categorical(logits * (1 / temp))\n\n        prob = softmax_logits[0, token]\n        return token, prob\n\n    if repetition_penalty and (\n        repetition_penalty < 0 or not isinstance(repetition_penalty, float)\n    ):\n        raise ValueError(\n            f\"repetition_penalty must be a non-negative float, got {repetition_penalty}\"\n        )\n\n    y = prompt\n    cache = None\n\n    repetition_context = prompt.tolist()\n\n    if repetition_context_size:\n        repetition_context = repetition_context[-repetition_context_size:]\n\n    while True:\n        logits, cache = model(y[None], cache=cache)\n        logits = logits[:, -1, :]\n\n        if repetition_penalty:\n            logits = apply_repetition_penalty(\n                logits, repetition_context, repetition_penalty\n            )\n            y, prob = sample(logits)\n            repetition_context.append(y.item())\n        else:\n            y, prob = sample(logits)\n\n        if repetition_context_size:\n            if len(repetition_context) > repetition_context_size:\n                repetition_context = repetition_context[-repetition_context_size:]\n        yield y, prob\n\n\ndef generate(\n    model: nn.Module,\n    tokenizer: PreTrainedTokenizer,\n    prompt: str,\n    temp: float = 0.0,\n    max_tokens: int = 100,\n    verbose: bool = False,\n    formatter: Optional[Callable] = None,\n    repetition_penalty: Optional[float] = None,\n    repetition_context_size: Optional[int] = None,\n    top_p: float = 1.0,\n) -> str:\n    \"\"\"\n    Generate text from the model.\n\n    Args:\n       model (nn.Module): The language model.\n       tokenizer (PreTrainedTokenizer): The tokenizer.\n       prompt (str): The string prompt.\n       temp (float): The temperature for sampling (default 0).\n       max_tokens (int): The maximum number of tokens (default 100).\n       verbose (bool): If ``True``, 
print tokens and timing information\n           (default ``False``).\n       formatter (Optional[Callable]): A function which takes a token and a\n           probability and displays it.\n       repetition_penalty (float, optional): The penalty factor for repeating tokens.\n       repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.\n    \"\"\"\n\n    if verbose:\n        print(\"=\" * 10)\n        print(\"Prompt:\", prompt)\n\n    prompt_tokens = mx.array(tokenizer.encode(prompt))\n\n    tic = time.perf_counter()\n    tokens = []\n    skip = 0\n    REPLACEMENT_CHAR = \"\\ufffd\"\n\n    for (token, prob), n in zip(\n        generate_step(\n            prompt_tokens,\n            model,\n            temp,\n            repetition_penalty,\n            repetition_context_size,\n            top_p,\n        ),\n        range(max_tokens),\n    ):\n        if token == tokenizer.eos_token_id:\n            break\n        if n == 0:\n            prompt_time = time.perf_counter() - tic\n            tic = time.perf_counter()\n        tokens.append(token.item())\n\n        if verbose:\n            s = tokenizer.decode(tokens)\n            if formatter:\n                formatter(s[skip:], prob.item())\n                skip = len(s)\n            elif REPLACEMENT_CHAR not in s:\n                print(s[skip:], end=\"\", flush=True)\n                skip = len(s)\n\n    token_count = len(tokens)\n    token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, \"\")\n\n    if verbose:\n        print(token_string[skip:], flush=True)\n        gen_time = time.perf_counter() - tic\n        print(\"=\" * 10)\n        if token_count == 0:\n            print(\"No tokens generated for this prompt\")\n            return token_string\n        prompt_tps = prompt_tokens.size / prompt_time\n        gen_tps = (token_count - 1) / gen_time\n        print(f\"Prompt: {prompt_tps:.3f} tokens-per-sec\")\n        print(f\"Generation: {gen_tps:.3f} tokens-per-sec\")\n\n    return token_string\n\n\ndef load_model(model_path: Path, lazy: bool = False) -> nn.Module:\n    \"\"\"\n    Load and initialize the model from a given path.\n\n    Args:\n        model_path (Path): The path to load the model from.\n        lazy (bool): If False eval the model parameters to make sure they are\n            loaded in memory before returning, otherwise they will be loaded\n            when needed. 
Default: ``False``\n\n    Returns:\n        nn.Module: The loaded and initialized model.\n\n    Raises:\n        FileNotFoundError: If the weight files (.safetensors) are not found.\n        ValueError: If the model class or args class are not found or cannot be instantiated.\n    \"\"\"\n    try:\n        with open(model_path / \"config.json\", \"r\") as f:\n            config = json.load(f)\n            quantization = config.get(\"quantization\", None)\n    except FileNotFoundError:\n        logging.error(f\"Config file not found in {model_path}\")\n        raise\n\n    weight_files = glob.glob(str(model_path / \"*.safetensors\"))\n    if not weight_files:\n        logging.error(f\"No safetensors found in {model_path}\")\n        raise FileNotFoundError(f\"No safetensors found in {model_path}\")\n\n    weights = {}\n    for wf in weight_files:\n        weights.update(mx.load(wf))\n\n    model_class, model_args_class = _get_classes(config=config)\n    if hasattr(model_class, \"sanitize\"):\n        weights = model_class.sanitize(weights)\n\n    model_args = model_args_class.from_dict(config)\n    model = model_class(model_args)\n\n    if quantization is not None:\n        # for legacy models that don't have lm_head quant due to non-32 dims\n        if \"lm_head.scales\" not in weights:\n            vocab_size = config[\"vocab_size\"]\n            extended_linear_class_predicate = (\n                lambda layer: linear_class_predicate(layer)\n                and layer.weight.shape[0] != vocab_size\n            )\n            nn.QuantizedLinear.quantize_module(\n                model,\n                **quantization,\n                linear_class_predicate=extended_linear_class_predicate,\n            )\n        # for models that have lm_head quant\n        else:\n            nn.QuantizedLinear.quantize_module(\n                model,\n                **quantization,\n                linear_class_predicate=linear_class_predicate,\n            )\n\n    model.load_weights(list(weights.items()))\n\n    if not lazy:\n        mx.eval(model.parameters())\n\n    model.eval()\n    return model\n\n\ndef load(\n    path_or_hf_repo: str,\n    tokenizer_config: Optional[dict] = None,\n    adapter_file: Optional[str] = None,\n    lazy: bool = False,\n) -> Tuple[nn.Module, PreTrainedTokenizer]:\n    \"\"\"\n    Load the model and tokenizer from a given path or a huggingface repository.\n\n    Args:\n        path_or_hf_repo (str): The local path or the Hugging Face repository ID to load the model from.\n        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.\n            Defaults to None, which is treated as an empty dictionary.\n        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.\n            Defaults to None.\n        lazy (bool): If False eval the model parameters to make sure they are\n            loaded in memory before returning, otherwise they will be loaded\n            when needed. 
Default: ``False``\n    Returns:\n        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.\n\n    Raises:\n        FileNotFoundError: If config file or safetensors are not found.\n        ValueError: If model class or args class are not found.\n    \"\"\"\n    model_path = get_model_path(path_or_hf_repo)\n\n    model = load_model(model_path, lazy)\n    if adapter_file is not None:\n        # TODO: Apply LoRA layers\n        # model = apply_lora_layers(model, adapter_file)\n        model.eval()\n\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_path, **(tokenizer_config or {}))\n    return model, tokenizer\n\n\ndef fetch_from_hub(\n    model_path: Path, lazy: bool = False\n) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:\n    model = load_model(model_path, lazy)\n\n    config = AutoConfig.from_pretrained(model_path)\n    tokenizer = AutoTokenizer.from_pretrained(model_path)\n\n    return model, config.to_dict(), tokenizer\n\n\ndef make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:\n    \"\"\"\n    Splits the weights into smaller shards.\n\n    Args:\n        weights (dict): Model weights.\n        max_file_size_gb (int): Maximum size of each shard in gigabytes.\n\n    Returns:\n        list: List of weight shards.\n    \"\"\"\n    max_file_size_bytes = max_file_size_gb << 30\n    shards = []\n    shard, shard_size = {}, 0\n    for k, v in weights.items():\n        if shard_size + v.nbytes > max_file_size_bytes:\n            shards.append(shard)\n            shard, shard_size = {}, 0\n        shard[k] = v\n        shard_size += v.nbytes\n    shards.append(shard)\n    return shards\n\n\ndef upload_to_hub(path: str, upload_repo: str, hf_path: str):\n    \"\"\"\n    Uploads the model to Hugging Face hub.\n\n    Args:\n        path (str): Local path to the model.\n        upload_repo (str): Name of the HF repo to upload to.\n        hf_path (str): Path to the original Hugging Face model.\n    \"\"\"\n    from huggingface_hub import HfApi, ModelCard, logging\n\n    card = ModelCard.load(hf_path)\n    card.data.tags = [\n        \"mlx\"] if card.data.tags is None else card.data.tags + [\"mlx\"]\n    card.text = f\"\"\"\n# {upload_repo}\nThis model was converted to MLX format from [`{hf_path}`]().\nRefer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.\n## Use with mlx\n\n```bash\npip install mlx-lm\n```\n\n```python\nfrom mlx_lm import load, generate\n\nmodel, tokenizer = load(\"{upload_repo}\")\nresponse = generate(model, tokenizer, prompt=\"hello\", verbose=True)\n```\n\"\"\"\n    card.save(os.path.join(path, \"README.md\"))\n\n    logging.set_verbosity_info()\n\n    api = HfApi()\n    api.create_repo(repo_id=upload_repo, exist_ok=True)\n    api.upload_folder(\n        folder_path=path,\n        repo_id=upload_repo,\n        repo_type=\"model\",\n    )\n\n\ndef save_weights(\n    save_path: Union[str, Path],\n    weights: Dict[str, Any],\n    *,\n    donate_weights: bool = False,\n) -> None:\n    \"\"\"Save model weights into specified directory.\"\"\"\n    if isinstance(save_path, str):\n        save_path = Path(save_path)\n    save_path.mkdir(parents=True, exist_ok=True)\n\n    shards = make_shards(weights)\n    shards_count = len(shards)\n    shard_file_format = (\n        \"model-{:05d}-of-{:05d}.safetensors\"\n        if shards_count > 1\n        else \"model.safetensors\"\n    )\n\n    total_size = sum(v.nbytes for v in weights.values())\n    index_data = {\"metadata\": 
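# matches the Hugging Face safetensors index layout: total_size in bytes\n        # plus a weight_map from each tensor name to the shard file holding it\n        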
{\"total_size\": total_size}, \"weight_map\": {}}\n\n    # Write the weights and make sure no references are kept other than the\n    # necessary ones\n    if donate_weights:\n        weights.clear()\n        del weights\n\n    for i in range(len(shards)):\n        shard = shards[i]\n        shards[i] = None\n        shard_name = shard_file_format.format(i + 1, shards_count)\n        shard_path = save_path / shard_name\n\n        mx.save_safetensors(str(shard_path), shard)\n\n        for weight_name in shard.keys():\n            index_data[\"weight_map\"][weight_name] = shard_name\n        del shard\n\n    index_data[\"weight_map\"] = {\n        k: index_data[\"weight_map\"][k] for k in sorted(index_data[\"weight_map\"])\n    }\n\n    with open(save_path / \"model.safetensors.index.json\", \"w\") as f:\n        json.dump(\n            index_data,\n            f,\n            indent=4,\n        )\n\n\ndef quantize_model(\n    model: nn.Module, config: dict, q_group_size: int, q_bits: int\n) -> Tuple:\n    \"\"\"\n    Applies quantization to the model weights.\n\n    Args:\n        model (nn.Module): The model to be quantized.\n        config (dict): Model configuration.\n        q_group_size (int): Group size for quantization.\n        q_bits (int): Bits per weight for quantization.\n\n    Returns:\n        Tuple: Tuple containing quantized weights and config.\n    \"\"\"\n    quantized_config = copy.deepcopy(config)\n\n    nn.QuantizedLinear.quantize_module(\n        model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate\n    )\n    quantized_config[\"quantization\"] = {\n        \"group_size\": q_group_size, \"bits\": q_bits}\n    quantized_weights = dict(tree_flatten(model.parameters()))\n\n    return quantized_weights, quantized_config\n\n\ndef get_mlx_path(hf_path: str, quantize: bool = False) -> str:\n    default_home = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n    return os.path.join(\n        default_home, 'huggingface', 'hub', f'models--{hf_path.replace(\"/\", \"--\")}-mlx{\"-q\" if quantize else \"\"}')\n\n\ndef convert(\n    hf_path: str,\n    mlx_path: str = None,\n    quantize: bool = False,\n    q_group_size: int = 64,\n    q_bits: int = 4,\n    dtype: str = \"float16\",\n    upload_repo: str = None,\n    delete_old: bool = True,\n):\n    print(\"[INFO] Loading\", flush=True)\n    model_path = get_model_path(hf_path)\n    print(model_path, flush=True)\n    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)\n\n    weights = dict(tree_flatten(model.parameters()))\n    dtype = mx.float16 if quantize else getattr(mx, dtype)\n    weights = {k: v.astype(dtype) for k, v in weights.items()}\n\n    if quantize:\n        print(\"[INFO] Quantizing\", flush=True)\n        model.load_weights(list(weights.items()))\n        weights, config = quantize_model(model, config, q_group_size, q_bits)\n\n    if mlx_path is None:\n        mlx_path = get_mlx_path(hf_path, quantize)\n\n    if isinstance(mlx_path, str):\n        mlx_path = Path(mlx_path)\n\n    print(f\"[INFO] Saving to {mlx_path}\", flush=True)\n\n    del model\n    save_weights(mlx_path, weights, donate_weights=True)\n\n    py_files = glob.glob(str(model_path / \"*.py\"))\n    for file in py_files:\n        shutil.copy(file, mlx_path)\n\n    tokenizer.save_pretrained(mlx_path)\n\n    with open(mlx_path / \"config.json\", \"w\") as fid:\n        json.dump(config, fid, indent=4)\n\n    if upload_repo is not None:\n        upload_to_hub(mlx_path, upload_repo, hf_path)\n\n    if delete_old:\n   
     path_components = str(model_path).split(os.path.sep)\n        shutil.rmtree(os.path.sep.join(path_components[:-2]))\n"
  }
]